| import os |
| import uuid |
| import asyncio |
| from typing import List |
| from fastapi import FastAPI, UploadFile, File, HTTPException |
| from dotenv import load_dotenv |
|
|
| from binary_classifier.classify import GeminiSportsClassifier |
| from embeddings_ import process_inputs |
| from s3_file_uploader import upload_to_s3 |
|
|
# Copy variables from a local .env file into os.environ. Kept first so that
# anything constructed below can see them (presumably the classifier reads its
# credentials from the environment — TODO confirm).
load_dotenv()

app = FastAPI(
    title="Sports Classification API",
)

# A single classifier instance, built once at import time and shared by all
# requests handled by this module.
classifier = GeminiSportsClassifier()

# Scratch directory (relative to the process working directory) where uploads
# are staged while a request is processed; created eagerly at import time.
UPLOAD_DIR = "temp_uploads"
os.makedirs(
    UPLOAD_DIR,
    exist_ok=True,
)
|
|
# Lower-case file extensions treated as video; every other extension is treated
# as an image.
# NOTE(review): this must agree with how `process_inputs` splits its "images" /
# "videos" outputs, otherwise embeddings are matched to the wrong files — confirm.
_VIDEO_EXTENSIONS = frozenset({".mp4", ".avi", ".mov"})

# First-level classifier answers that exclude a file from further processing.
_REJECTED_ANSWERS = frozenset({"not_sports", "error"})


def _media_type(path):
    """Return "video" if *path* has a known video extension, otherwise "image"."""
    ext = os.path.splitext(path)[1].lower()
    return "video" if ext in _VIDEO_EXTENSIONS else "image"


def _filter_sports_items(results, path_to_original):
    """Turn classifier results whose answer is not rejected into response dicts.

    Each kept result becomes a dict with keys ``file_path``, ``file_name``,
    ``media_type`` and ``answer`` (in that order); input order is preserved.
    """
    return [
        {
            "file_path": item["file"],
            "file_name": path_to_original.get(item["file"]),
            "media_type": _media_type(item["file"]),
            "answer": item.get("answer"),
        }
        for item in results
        if item.get("answer") not in _REJECTED_ANSWERS
    ]


def _attach_embeddings(items, embedding_result):
    """Set ``item["embedding"]`` on every item from *embedding_result*.

    Image items consume ``embedding_result["images"]`` and video items consume
    ``embedding_result["videos"]``, each in order. An item whose list is missing,
    ``None``, or too short gets ``None`` instead of raising IndexError.
    """
    pools = {}
    for kind, key in (("image", "images"), ("video", "videos")):
        pool = embedding_result.get(key)
        # Explicit None check: the pool may be an array whose truth value is ambiguous.
        pools[kind] = [] if pool is None else pool
    cursors = {"image": 0, "video": 0}
    for item in items:
        kind = "image" if item["media_type"] == "image" else "video"
        pool = pools[kind]
        idx = cursors[kind]
        item["embedding"] = pool[idx].tolist() if idx < len(pool) else None
        cursors[kind] = idx + 1


def _remove_temp_files(paths):
    """Delete every path in *paths*, ignoring paths that do not exist."""
    for path in paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            # The file was never created (e.g. the upload read failed) or is already gone.
            pass


@app.post("/classify")
async def classify_media(files: List[UploadFile] = File(...)):
    """Classify uploaded media, then embed and upload the sports-related files.

    Steps:
      1. Save every upload to a uniquely named temp file under UPLOAD_DIR.
      2. Run ``classifier.classify_batch`` on all temp files.
      3. Drop results whose answer is "not_sports" or "error".
      4. Compute embeddings for the remaining files with ``process_inputs``.
      5. Upload the remaining files concurrently with ``upload_to_s3``.

    Returns:
        A list with one dict per kept file (in classifier-result order) with keys
        ``file_path``, ``file_name``, ``media_type``, ``answer``, ``embedding`` and
        ``s3_url``; an empty list if no file was kept.

    Raises:
        HTTPException: 400 if no files were uploaded.

    All temp files are removed before the function returns, including on error.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")
    file_paths = []
    path_to_original = {}
    try:
        for upload in files:
            # Only the extension comes from the client-supplied name; the base name
            # is a random UUID. A missing filename yields no extension.
            ext = os.path.splitext(upload.filename or "")[1]
            temp_path = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}{ext}")
            # Register the path for cleanup before the file is created, so a failed
            # read/write cannot leave an orphaned temp file behind.
            file_paths.append(temp_path)
            path_to_original[temp_path] = upload.filename
            content = await upload.read()
            with open(temp_path, "wb") as out:
                out.write(content)

        first_level_results = await classifier.classify_batch(file_paths)
        filtered_items = _filter_sports_items(first_level_results, path_to_original)
        if not filtered_items:
            return []

        # NOTE(review): process_inputs is called synchronously here and blocks the
        # event loop while it runs; consider offloading it to a worker thread if it
        # is slow and safe to run concurrently.
        embedding_result = process_inputs([item["file_path"] for item in filtered_items])
        _attach_embeddings(filtered_items, embedding_result)

        async def upload_single(item):
            item["s3_url"] = await upload_to_s3(item["file_path"], item["file_name"])
            return item

        # The temp files must still exist here; cleanup runs only in `finally`.
        return await asyncio.gather(*(upload_single(item) for item in filtered_items))
    finally:
        _remove_temp_files(file_paths)