from pathlib import Path
from typing import List, Tuple, Dict
import sys
import os

# Set before torch is imported (directly below, and transitively via
# ultralytics) so the CUDA caching allocator picks this up on first use.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

from numpy import ndarray
from pydantic import BaseModel

# Add this file's directory to the import path; presumably needed so the
# local modules imported below (team_cluster, utils, inference, pitch)
# resolve regardless of the working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from ultralytics import YOLO
from team_cluster import TeamClassifier
from utils import (
    BoundingBox,
    Constants,
)
from inference import predict_batch
import torch
from pitch import get_cls_net
import yaml


# NOTE(review): this definition shadows the `BoundingBox` imported from
# `utils` above; every later reference in this module uses this class.
class BoundingBox(BaseModel):
    """Axis-aligned box with integer corners, class id and confidence."""

    x1: int
    y1: int
    x2: int
    y2: int
    cls_id: int
    conf: float


class TVFrameResult(BaseModel):
    """Per-frame output: detected boxes and integer (x, y) keypoints."""

    frame_id: int
    boxes: List[BoundingBox]
    keypoints: List[Tuple[int, int]]


class Miner:
    """Holds the detection, team-classification and pitch-keypoint models.

    Construction never raises: any loading error is stored in ``health``
    (which equals ``"healthy"`` only when every model loaded).
    """

    # Tuning constants re-exported from utils.Constants.
    SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
    SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
    SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
    CORNER_INDICES = Constants.CORNER_INDICES
    KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
    CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
    GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
    MIN_SAMPLES_FOR_FIT = 16  # Minimum player crops needed before fitting TeamClassifier
    MAX_SAMPLES_FOR_FIT = 1000  # Maximum samples to avoid overfitting

    def __init__(self, path_hf_repo: Path) -> None:
        """Load all models from ``path_hf_repo``.

        Args:
            path_hf_repo: Directory holding ``football_object_detection.onnx``,
                ``osnet_model.pth.tar-100``, ``keypoint`` (a state dict) and
                ``hrnetv2_w48.yaml``. A ``str`` path is also accepted.

        Errors are caught and recorded in ``self.health`` rather than raised;
        on failure, attributes after the failing step are not set.
        """
        try:
            # Accept str as well as Path; the joins below need a Path.
            path_hf_repo = Path(path_hf_repo)
            device = "cuda" if torch.cuda.is_available() else "cpu"

            model_path = path_hf_repo / "football_object_detection.onnx"
            self.bbox_model = YOLO(model_path)
            print("BBox Model Loaded")

            team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
            self.team_classifier = TeamClassifier(
                device=device,
                batch_size=32,
                model_name=str(team_model_path),
            )
            print("Team Classifier Loaded")

            # Team classification state
            self.team_classifier_fitted = False
            self.player_crops_for_fit = []

            model_kp_path = path_hf_repo / 'keypoint'
            config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml'
            # Context manager so the config file handle is always closed.
            with open(config_kp_path, 'r') as config_file:
                cfg_kp = yaml.safe_load(config_file)

            loaded_state_kp = torch.load(model_kp_path, map_location=device)
            model = get_cls_net(cfg_kp)
            model.load_state_dict(loaded_state_kp)
            model.to(device)
            model.eval()

            self.keypoints_model = model
            self.kp_threshold = 0.1
            self.pitch_batch_size = 4
            self.health = "healthy"
            self.path_hf_repo = path_hf_repo
            print("✅ Keypoints Model Loaded")
        except Exception as e:
            # Deliberate best-effort: report the failure via `health`.
            self.health = "❌ Miner initialization failed: " + str(e)
            print(self.health)

    def __repr__(self) -> str:
        """Summarise loaded model types, or return the failure message."""
        if self.health == 'healthy':
            return (
                f"health: {self.health}\n"
                f"BBox Model: {type(self.bbox_model).__name__}\n"
                f"Keypoints Model: {type(self.keypoints_model).__name__}"
            )
        else:
            return self.health

    def predict_batch(
        self,
        batch_images: List[ndarray],
        offset: int,
        n_keypoints: int,
    ) -> List[TVFrameResult]:
        """Run the loaded models on a batch of frames.

        Delegates to the module-level ``inference.predict_batch`` (the bare
        name below resolves to that import, not to this method).

        Args:
            batch_images: Frames to process.
            offset: Passed through unchanged; presumably the frame index of
                ``batch_images[0]`` — TODO confirm in ``inference``.
            n_keypoints: Passed through unchanged to ``inference.predict_batch``.

        Returns:
            Whatever ``inference.predict_batch`` returns, annotated as one
            ``TVFrameResult`` per frame.
        """
        results = predict_batch(
            self.bbox_model,
            self.team_classifier,
            self.keypoints_model,
            batch_images,
            offset,
            n_keypoints,
            self.pitch_batch_size,
            self.kp_threshold,
            self.path_hf_repo,
        )
        return results