vicky4s4s's picture
Upload 76 files
01e9350 verified
import os
import uuid
import asyncio
from typing import List
from fastapi import FastAPI, UploadFile, File, HTTPException
from dotenv import load_dotenv
from binary_classifier.classify import GeminiSportsClassifier
from embeddings_ import process_inputs
from s3_file_uploader import upload_to_s3
# Populate os.environ from a local .env file (if one exists) first, so the
# classifier constructed below can see those values in case it reads
# configuration from the environment (TODO confirm what it needs).
load_dotenv()

# The ASGI application; the title appears in the generated API docs.
app = FastAPI(title="Sports Classification API")

# One classifier instance created at import time and shared by every request.
classifier = GeminiSportsClassifier()

# Scratch directory for incoming uploads. It is a relative path, so it resolves
# against the process working directory. It is created eagerly (and tolerated
# if it already exists) so request handlers can write into it immediately.
UPLOAD_DIR = "temp_uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)
# Lower-cased file extensions treated as video; any other extension is "image".
_VIDEO_EXTENSIONS = frozenset({".mp4", ".avi", ".mov"})

# First-level classifier answers whose files are dropped (not embedded/uploaded).
_EXCLUDED_ANSWERS = frozenset({"not_sports", "error"})

# Uploads are streamed to disk in 1 MiB chunks so large videos are never held
# entirely in memory.
_UPLOAD_CHUNK_SIZE = 1024 * 1024


def _remove_quietly(path):
    """Delete *path*, ignoring OS errors (best-effort temp-file cleanup)."""
    try:
        os.remove(path)
    except OSError:
        # Deliberately best-effort: a failed delete must not mask the
        # request's real result or exception.
        pass


def _media_type(path):
    """Return "video" if *path* has a known video extension, else "image"."""
    ext = os.path.splitext(path)[1].lower()
    return "video" if ext in _VIDEO_EXTENSIONS else "image"


async def _save_upload(file):
    """Stream one UploadFile into UPLOAD_DIR under a random, unique name.

    Returns:
        (temp_path, original_name). ``original_name`` falls back to the
        generated temp name when the client sent no filename.

    If reading or writing fails part-way, the partially written file is
    removed before the exception propagates, so nothing is leaked.
    """
    original_name = file.filename
    # Keep only the client's extension; the stem is a UUID so user-supplied
    # names never become filesystem paths.
    ext = os.path.splitext(original_name or "")[1]
    temp_name = f"{uuid.uuid4()}{ext}"
    temp_path = os.path.join(UPLOAD_DIR, temp_name)
    try:
        with open(temp_path, "wb") as out:
            while chunk := await file.read(_UPLOAD_CHUNK_SIZE):
                out.write(chunk)
    except BaseException:
        # Also covers cancellation: never leave a half-written file behind.
        _remove_quietly(temp_path)
        raise
    return temp_path, original_name or temp_name


def _attach_embeddings(items, embedding_result):
    """Set ``item["embedding"]`` on each item from process_inputs' result.

    Image items consume ``embedding_result["images"]`` in order, and video
    items consume ``embedding_result["videos"]`` in order.
    NOTE(review): this assumes process_inputs returns per-type embeddings in
    the same relative order as the paths it was given. Verify against
    embeddings_.process_inputs.

    If fewer embeddings are returned than items of a type, the surplus items
    get ``None`` instead of raising IndexError.
    """
    images = embedding_result.get("images")
    videos = embedding_result.get("videos")
    # Explicit None checks (not `or`): these may be array-like objects whose
    # truth value is ambiguous.
    streams = {
        "image": iter(images if images is not None else ()),
        "video": iter(videos if videos is not None else ()),
    }
    for item in items:
        embedding = next(streams[item["media_type"]], None)
        item["embedding"] = embedding.tolist() if embedding is not None else None


async def _upload_all(items):
    """Upload every item's temp file to S3 concurrently and set ``s3_url``.

    Waits for ALL uploads to finish before returning or raising. The caller
    deletes the temp files afterwards, and a plain gather would re-raise on
    the first failure while sibling uploads were still reading their files.
    If any upload failed, the first failure (in item order) is re-raised.
    """
    async def upload_single(item):
        item["s3_url"] = await upload_to_s3(item["file_path"], item["file_name"])
        return item

    results = await asyncio.gather(
        *(upload_single(item) for item in items), return_exceptions=True
    )
    for result in results:
        if isinstance(result, BaseException):
            raise result
    return results


@app.post("/classify")
async def classify_media(files: List[UploadFile] = File(...)):
    """Classify uploaded media and embed and upload the sports-related files.

    Each file is saved to a temporary path and passed through the first-level
    classifier. Files whose answer is "not_sports" or "error" are dropped.
    The remaining files are embedded and uploaded to S3.

    Returns a list of dicts with keys file_path, file_name, media_type,
    answer, embedding and s3_url, or [] if nothing passes the filter.
    All temporary files are deleted before the response is returned.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")

    file_paths = []
    path_to_original = {}
    try:
        for file in files:
            temp_path, original_name = await _save_upload(file)
            # No await between save and registration, so a saved file is
            # always tracked for cleanup in `finally`.
            file_paths.append(temp_path)
            path_to_original[temp_path] = original_name

        first_level_results = await classifier.classify_batch(file_paths)

        filtered_items = [
            {
                "file_path": item["file"],
                "file_name": path_to_original.get(item["file"]),
                "media_type": _media_type(item["file"]),
                "answer": item.get("answer"),
            }
            for item in first_level_results
            if item.get("answer") not in _EXCLUDED_ANSWERS
        ]
        if not filtered_items:
            return []

        # NOTE(review): process_inputs is called synchronously on the event-loop
        # thread. If it performs slow model inference, it stalls other requests.
        # Consider asyncio.to_thread once its thread-safety is confirmed.
        embedding_result = process_inputs([item["file_path"] for item in filtered_items])
        _attach_embeddings(filtered_items, embedding_result)

        return await _upload_all(filtered_items)
    finally:
        for path in file_paths:
            _remove_quietly(path)