"""Sports media classifier backed by the Gemini API.

Classifies images, or frames sampled from videos, into a fixed list of
sports categories using Gemini JSON-mode generation.
"""

import asyncio
import json
import os
from textwrap import dedent

import cv2
from dotenv import load_dotenv
from google import genai
from google.genai import types
from PIL import Image

load_dotenv()

# One shared client for the module; requests go through the async API (client.aio).
client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))


class GeminiSportsClassifier:
    """Classify image/video files into sports categories via Gemini.

    Use :meth:`classify` for a single file or :meth:`classify_batch` for
    many files with bounded concurrency. Results are plain dicts with at
    least ``"file"`` and ``"answer"`` keys.
    """

    def __init__(self):
        # Natural-language descriptions of each target category; the full
        # list is interpolated into the prompt sent to the model.
        self.categories = [
            "a photo of a marathon race on city roads",
            "a photo of long distance marathon runners with bib numbers",
            "a photo of a large group running marathon event",
            "a photo of Olympic marathon race",
            "a photo of sprint running race on track",
            "a photo of Olympic track and field athletics",
            "a photo of relay race baton passing",
            "a photo of hurdles race on track",
            "a photo of high jump Olympic event",
            "a photo of pole vault Olympic event",
            "a photo of long jump Olympic event",
            "a photo of javelin throw Olympic event",
            "a photo of discus throw Olympic event",
            "a photo of shot put Olympic event",
            "a photo of Olympic swimming competition",
            "a photo of Olympic cycling road race",
            "a photo of Olympic track cycling",
            "a photo of Olympic gymnastics performance",
            "a photo of Olympic boxing match",
            "a photo of Olympic wrestling match",
            "a photo of Olympic weightlifting competition",
            "a photo of cricket match on stadium",
            "a photo of cricket players batting and bowling",
            "a photo of football match in stadium",
            "a photo of soccer players playing match",
            "a photo of football goal scoring moment",
            "a photo of basketball game in indoor court",
            "a photo of volleyball match on court",
            "a photo of tennis match on court",
            "a photo of badminton match indoor stadium",
        ]

    def _prepare_image(self, img):
        """Downscale *img* in place to fit within 512x512 and return it."""
        img.thumbnail((512, 512))
        return img

    def _extract_frames(self, video_path, num_frames=2):
        """Sample up to *num_frames* evenly spaced frames from a video.

        Returns a list of 480x270 RGB PIL images. The list may be shorter
        than *num_frames* — or empty — if the file cannot be opened, has no
        frames, or individual reads fail.
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            # Unreadable/missing file: bail out with an empty result instead
            # of silently looping over a zero-frame capture.
            if not cap.isOpened():
                return frames
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames <= 0:
                return frames
            for i in range(num_frames):
                # Evenly spaced positions, avoiding the first/last frame.
                pos = int((i + 1) * total_frames / (num_frames + 1))
                cap.set(cv2.CAP_PROP_POS_FRAMES, pos)
                ret, frame = cap.read()
                if ret:
                    frame = cv2.resize(frame, (480, 270))
                    # OpenCV decodes to BGR; PIL expects RGB.
                    frames.append(
                        Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    )
        finally:
            # Always release the native capture handle, even if a read or
            # resize raises — the original leaked it on exceptions.
            cap.release()
        return frames

    async def classify(self, file_path):
        """Classify a single media file.

        Returns a dict with at least ``"file"`` and ``"answer"``; ``answer``
        is a category string, ``"not_sports"``, or ``"error"`` (with a
        ``"details"`` message) when anything fails.
        """
        ext = os.path.splitext(file_path)[1].lower()
        try:
            media = []
            if ext in ('.jpg', '.jpeg', '.png', '.webp'):
                # Copy inside the context manager: the file handle is closed
                # promptly (the original leaked it) while the pixel data
                # remains usable after the `with` block.
                with Image.open(file_path) as raw:
                    media.append(self._prepare_image(raw.copy()))
            elif ext in ('.mp4', '.avi', '.mov'):
                media = self._extract_frames(file_path)
            else:
                # Unsupported extension: not something we attempt to classify.
                return {"file": file_path, "answer": "not_sports"}

            if not media:
                # e.g. corrupt/unreadable video — don't call the API with an
                # empty media list.
                return {
                    "file": file_path,
                    "answer": "error",
                    "details": "no media could be extracted",
                }

            prompt = dedent(f"""
                Analyze the provided media (image or video frames).
                1. If it matches one of these categories, return the category: {self.categories}
                2. If it is not related to these sports, return "not_sports".
                Output strictly in JSON: {{"answer": "category_name"}}
            """)
            response = await client.aio.models.generate_content(
                model="gemini-2.0-flash",
                contents=[prompt, *media],
                config=types.GenerateContentConfig(
                    response_mime_type='application/json'
                ),
            )
            data = json.loads(response.text)
            if isinstance(data, dict):
                return {"file": file_path, **data}
            # Model returned valid JSON that isn't an object; surface it as-is.
            return {"file": file_path, "answer": str(data)}
        except Exception as e:
            # Deliberate best-effort boundary: a failure on one file is
            # reported in its result dict rather than aborting the batch.
            return {"file": file_path, "answer": "error", "details": str(e)}

    async def classify_batch(self, file_paths, max_concurrent=5):
        """Classify many files concurrently, at most *max_concurrent* at once.

        Returns one result dict per input path, in input order.
        """
        semaphore = asyncio.Semaphore(max_concurrent)

        async def classify_with_limit(path):
            # Bound in-flight API calls to avoid rate-limit/quota blowups.
            async with semaphore:
                return await self.classify(path)

        tasks = [classify_with_limit(path) for path in file_paths]
        return await asyncio.gather(*tasks)