| import os, cv2,json |
| import asyncio |
| from PIL import Image |
| from google import genai |
| from google.genai import types |
| from textwrap import dedent |
|
|
from dotenv import load_dotenv

# Load environment variables (e.g. GEMINI_API_KEY) from a local .env file, if present.
load_dotenv()


# Module-level Gemini client shared by the classifier below; the API key is
# read from the GEMINI_API_KEY environment variable.
client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
|
|
class GeminiSportsClassifier:
    """Classify images and videos into sports categories via the Gemini API.

    Images are downscaled and sent directly; videos are sampled into a few
    evenly spaced frames. Every classification returns a dict of the form
    ``{"file": <path>, "answer": <category | "not_sports" | "error">}``,
    with an extra ``"details"`` key when ``"answer"`` is ``"error"``.
    """

    # Extensions routed to the image pipeline vs. the video pipeline.
    IMAGE_EXTS = frozenset({".jpg", ".jpeg", ".png", ".webp"})
    VIDEO_EXTS = frozenset({".mp4", ".avi", ".mov"})

    def __init__(self):
        # Candidate labels offered verbatim to the model; the prompt asks it
        # to answer with one of these strings or fall back to "not_sports".
        self.categories = [
            "a photo of a marathon race on city roads", "a photo of long distance marathon runners with bib numbers",
            "a photo of a large group running marathon event", "a photo of Olympic marathon race",
            "a photo of sprint running race on track", "a photo of Olympic track and field athletics",
            "a photo of relay race baton passing", "a photo of hurdles race on track",
            "a photo of high jump Olympic event",
            "a photo of pole vault Olympic event", "a photo of long jump Olympic event",
            "a photo of javelin throw Olympic event",
            "a photo of discus throw Olympic event", "a photo of shot put Olympic event",
            "a photo of Olympic swimming competition",
            "a photo of Olympic cycling road race", "a photo of Olympic track cycling",
            "a photo of Olympic gymnastics performance",
            "a photo of Olympic boxing match", "a photo of Olympic wrestling match",
            "a photo of Olympic weightlifting competition",
            "a photo of cricket match on stadium", "a photo of cricket players batting and bowling",
            "a photo of football match in stadium",
            "a photo of soccer players playing match", "a photo of football goal scoring moment",
            "a photo of basketball game in indoor court",
            "a photo of volleyball match on court", "a photo of tennis match on court",
            "a photo of badminton match indoor stadium"
        ]

    def _prepare_image(self, img):
        """Downscale ``img`` in place to fit within 512x512 and return it."""
        img.thumbnail((512, 512))
        return img

    def _extract_frames(self, video_path, num_frames=2):
        """Sample ``num_frames`` evenly spaced frames from a video file.

        Returns a list of 480x270 RGB PIL images. The list may be shorter
        than requested -- or empty -- if the file cannot be opened or read.
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames <= 0:
                # Unopenable or empty video: nothing to sample.
                return frames
            for i in range(num_frames):
                # Evenly spaced sample points, skipping the very first and
                # last frames (which are often black/transition frames).
                pos = int((i + 1) * total_frames / (num_frames + 1))
                cap.set(cv2.CAP_PROP_POS_FRAMES, pos)
                ret, frame = cap.read()
                if ret:
                    frame = cv2.resize(frame, (480, 270))
                    # OpenCV decodes to BGR; PIL expects RGB.
                    frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        finally:
            # Always release the capture handle, even if decoding raised.
            cap.release()
        return frames

    async def classify(self, file_path):
        """Classify a single media file.

        Never raises: any failure (I/O, decoding, API call, bad JSON) is
        reported as an ``"error"`` answer so batch runs are not aborted by
        one bad file.
        """
        ext = os.path.splitext(file_path)[1].lower()
        try:
            media = []
            if ext in self.IMAGE_EXTS:
                # Close the file handle promptly; load() pulls the pixel data
                # into memory first so the Image stays usable afterwards.
                with Image.open(file_path) as img:
                    img.load()
                media.append(self._prepare_image(img))
            elif ext in self.VIDEO_EXTS:
                media = self._extract_frames(file_path)
                if not media:
                    # Avoid sending an empty-media prompt to the API.
                    return {"file": file_path, "answer": "error",
                            "details": "could not extract frames from video"}
            else:
                # Unsupported extension: not a media type we classify.
                return {"file": file_path, "answer": "not_sports"}

            prompt = dedent(f"""
                Analyze the provided media (image or video frames).
                1. If it matches one of these categories, return the category: {self.categories}
                2. If it is not related to these sports, return "not_sports".
                Output strictly in JSON: {{"answer": "category_name"}}
            """)

            response = await client.aio.models.generate_content(
                model="gemini-2.0-flash",
                contents=[prompt, *media],
                config=types.GenerateContentConfig(
                    response_mime_type='application/json'
                )
            )

            data = json.loads(response.text)
            # Expected shape is {"answer": ...}; tolerate any other JSON value.
            if isinstance(data, dict):
                return {"file": file_path, **data}
            return {"file": file_path, "answer": str(data)}

        except Exception as e:
            # Boundary handler: convert every failure into an "error" result.
            return {"file": file_path, "answer": "error", "details": str(e)}

    async def classify_batch(self, file_paths, max_concurrent=5):
        """Classify ``file_paths`` concurrently, ``max_concurrent`` at a time.

        Results are returned in the same order as the input paths.
        """
        semaphore = asyncio.Semaphore(max_concurrent)

        async def classify_with_limit(path):
            # Cap the number of in-flight API requests.
            async with semaphore:
                return await self.classify(path)

        return await asyncio.gather(*(classify_with_limit(p) for p in file_paths))
|
|
|
|
|
|