"""Sports Classification API.

Exposes ``POST /classify``, which saves uploaded media to a temporary
directory, runs ``GeminiSportsClassifier.classify_batch`` on it, computes
embeddings (via ``process_inputs``) for files not excluded by the classifier,
and uploads those files with ``upload_to_s3``.
"""

import os
import uuid
import asyncio
import contextlib
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, UploadFile, File, HTTPException
from dotenv import load_dotenv

from binary_classifier.classify import GeminiSportsClassifier
from embeddings_ import process_inputs
from s3_file_uploader import upload_to_s3

load_dotenv()

app = FastAPI(title="Sports Classification API")
classifier = GeminiSportsClassifier()

UPLOAD_DIR = "temp_uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Classifier answers that cause a file to be dropped before embedding/upload.
EXCLUDED_ANSWERS = frozenset({"not_sports", "error"})

# Lower-case extensions treated as video; every other extension is treated as an image.
VIDEO_EXTENSIONS = frozenset({".mp4", ".avi", ".mov"})

# Number of bytes read per call when streaming an upload to disk (1 MiB).
_UPLOAD_CHUNK_SIZE = 1024 * 1024


def _select_sports_items(
    first_level_results: List[Dict[str, Any]],
    path_to_original: Dict[str, Optional[str]],
) -> List[Dict[str, Any]]:
    """Build response items for classifier results not in ``EXCLUDED_ANSWERS``.

    Each result is expected to carry the temp path under ``"file"`` and the
    classifier's label under ``"answer"``.
    """
    items = []
    for result in first_level_results:
        answer = result.get("answer")
        if answer in EXCLUDED_ANSWERS:
            continue
        path = result["file"]
        ext = os.path.splitext(path)[1].lower()
        items.append({
            # NOTE(review): this temp path is deleted before the response is
            # returned; kept in the payload for backward compatibility.
            "file_path": path,
            "file_name": path_to_original.get(path),
            "media_type": "video" if ext in VIDEO_EXTENSIONS else "image",
            "answer": answer,
        })
    return items


def _attach_embeddings(
    items: List[Dict[str, Any]], embedding_result: Dict[str, Any]
) -> None:
    """Set ``item["embedding"]`` on every item from ``embedding_result``.

    Embeddings are assigned positionally per media type. This assumes that
    ``process_inputs`` returns ``"images"`` and ``"videos"`` in the same
    relative order as the paths it was given (TODO confirm). An item with no
    corresponding embedding gets ``None`` instead of raising ``IndexError``.
    """
    remaining = {
        "image": iter(embedding_result.get("images", [])),
        "video": iter(embedding_result.get("videos", [])),
    }
    for item in items:
        embedding = next(remaining[item["media_type"]], None)
        item["embedding"] = embedding.tolist() if embedding is not None else None


@app.post("/classify")
async def classify_media(files: List[UploadFile] = File(...)):
    """Classify uploaded media and return details for the files that are kept.

    Files the classifier labels ``not_sports`` or ``error`` are dropped. Each
    remaining file is returned with its original name, ``media_type`` (chosen
    by extension), classifier ``answer``, ``embedding`` (or null) and
    ``s3_url``. Temporary copies are removed when the request finishes.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")

    file_paths = []
    path_to_original = {}
    try:
        for upload in files:
            # filename is Optional on UploadFile; avoid splitext(None) -> TypeError.
            ext = os.path.splitext(upload.filename or "")[1]
            temp_path = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}{ext}")
            # Register before writing so a failed read/write is still cleaned up.
            file_paths.append(temp_path)
            path_to_original[temp_path] = upload.filename
            with open(temp_path, "wb") as out:
                # Stream in chunks instead of buffering the whole upload in memory.
                while True:
                    chunk = await upload.read(_UPLOAD_CHUNK_SIZE)
                    if not chunk:
                        break
                    out.write(chunk)

        first_level_results = await classifier.classify_batch(file_paths)
        sports_items = _select_sports_items(first_level_results, path_to_original)
        if not sports_items:
            return []

        # process_inputs is synchronous; run it in the default executor so
        # embedding computation does not block the event loop.
        loop = asyncio.get_running_loop()
        embedding_result = await loop.run_in_executor(
            None, process_inputs, [item["file_path"] for item in sports_items]
        )
        _attach_embeddings(sports_items, embedding_result)

        async def upload_single(item):
            item["s3_url"] = await upload_to_s3(item["file_path"], item["file_name"])
            return item

        return await asyncio.gather(*(upload_single(item) for item in sports_items))
    finally:
        for path in file_paths:
            # A path may be missing if open() failed before creating it.
            with contextlib.suppress(FileNotFoundError):
                os.remove(path)