import cv2
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from collections import Counter
from PIL import Image

class SceneClassifier:
    """Classify video scenes by majority vote over frames sampled from each scene."""

    def __init__(self, model_path: str = "2nzi/Image_Surf_NotSurf"):
        try:
            # Initialize the processor and the model
            self.processor = AutoImageProcessor.from_pretrained(
                "google/vit-base-patch16-224",
                use_fast=True
            )
            self.model = AutoModelForImageClassification.from_pretrained(
                model_path,
                trust_remote_code=True
            )
            self.id_to_label = self.model.config.id2label
        except Exception as e:
            print(f"[ERROR] Failed to load model: {e}")
            raise

    def _time_to_seconds(self, time_str: str) -> float:
        """Convert an "HH:MM:SS(.fff)" timestamp to seconds, e.g. "00:01:30.5" -> 90.5."""
        h, m, s = time_str.split(':')
        return int(h) * 3600 + int(m) * 60 + float(s)

    def _extract_frames(self, video_path: str, start_time: str, end_time: str, num_frames: int = 5) -> list:
        """Sample num_frames frames evenly spaced within the scene, endpoints excluded."""
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"[WARNING] Could not open video: {video_path}")
            return []
        start_sec = self._time_to_seconds(start_time)
        end_sec = self._time_to_seconds(end_time)
        scene_duration = end_sec - start_sec
        # Split the scene into num_frames + 1 intervals and sample at the interior
        # boundaries, so no frame lands exactly on a scene cut.
        frame_interval = scene_duration / (num_frames + 1)
        frames = []
        for i in range(num_frames):
            timestamp = start_sec + frame_interval * (i + 1)
            cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
            success, frame = cap.read()
            if success:
                # OpenCV decodes to BGR; convert to RGB for PIL
                image_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                frames.append(image_pil)
            else:
                print(f"[WARNING] Failed to extract frame at {timestamp:.2f} seconds")
        cap.release()
        return frames

    def _classify_frame(self, frame: Image.Image) -> dict:
        """Run one frame through the model and return the top label with its softmax probability."""
        inputs = self.processor(images=frame, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        confidence, predicted_class = torch.max(probs, dim=-1)
        return {
            "label": self.id_to_label[predicted_class.item()],
            "confidence": float(confidence.item())
        }

    def classify_scene(self, video_path: str, scene: dict) -> dict:
        """Classify a scene: sample frames, classify each, and take a majority vote."""
        print(f"[DEBUG] Classifying scene: {scene['start']} -> {scene['end']}")
        frames = self._extract_frames(video_path, scene["start"], scene["end"])
        if not frames:
            print("[WARNING] No frames extracted for classification")
            return {"recognized_sport": "Unknown", "confidence": 0.0}
        classifications = [self._classify_frame(frame) for frame in frames]
        labels = [c["label"] for c in classifications]
        label_counts = Counter(labels)
        predominant_label, count = label_counts.most_common(1)[0]
        # Average the confidence over only the frames that voted for the winning label
        confidence_avg = sum(
            c["confidence"] for c in classifications
            if c["label"] == predominant_label
        ) / count
        result = {
            "recognized_sport": predominant_label,
            "confidence": confidence_avg
        }
        print(f"[DEBUG] Classification result: {result}")
        return result
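

# Example usage: a minimal sketch of how the class above would be driven.
# The video path and the scene timestamps are hypothetical placeholders,
# not part of the original module; scenes are dicts with "start"/"end"
# keys in the "HH:MM:SS(.fff)" format that _time_to_seconds expects.
if __name__ == "__main__":
    classifier = SceneClassifier()
    scene = {"start": "00:00:05.000", "end": "00:00:12.000"}
    result = classifier.classify_scene("surf_session.mp4", scene)
    print(result)  # e.g. {"recognized_sport": <label>, "confidence": <mean probability>}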