from pathlib import Path from typing import List, Tuple, Dict import sys import os from numpy import ndarray from pydantic import BaseModel sys.path.append(os.path.dirname(os.path.abspath(__file__))) from keypoint_helper import run_keypoints_post_processing from ultralytics import YOLO from team_cluster import TeamClassifier from utils import ( BoundingBox, Constants, ) from inference import predict_batch import time import torch import gc from pitch import process_batch_input, get_cls_net import yaml os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" class BoundingBox(BaseModel): x1: int y1: int x2: int y2: int cls_id: int conf: float class TVFrameResult(BaseModel): frame_id: int boxes: List[BoundingBox] keypoints: List[Tuple[int, int]] class Miner: SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT CORNER_INDICES = Constants.CORNER_INDICES KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN MIN_SAMPLES_FOR_FIT = 16 # Minimum player crops needed before fitting TeamClassifier MAX_SAMPLES_FOR_FIT = 600 # Maximum samples to avoid overfitting def __init__(self, path_hf_repo: Path) -> None: try: device = "cuda" if torch.cuda.is_available() else "cpu" model_path = path_hf_repo / "football_object_detection.onnx" self.bbox_model = YOLO(model_path) print("BBox Model Loaded") ball_path = path_hf_repo / "ball_detection.pt" self.ball_model = YOLO(ball_path) self.ball_model.to(device) for _ in range(3): dummy_input = torch.zeros(16, 3, 1024, 1024, device=device) self.ball_model(dummy_input) print("Ball Model Loaded") team_model_path = path_hf_repo / "osnet_model.pth.tar-100" self.team_classifier = TeamClassifier( device=device, batch_size=32, model_name=str(team_model_path) ) print("Team Classifier Loaded") # Team classification state self.team_classifier_fitted = False self.player_crops_for_fit = [] model_kp_path = path_hf_repo / 'keypoint' config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml' cfg_kp = yaml.safe_load(open(config_kp_path, 'r')) loaded_state_kp = torch.load(model_kp_path, map_location=device) model = get_cls_net(cfg_kp) model.load_state_dict(loaded_state_kp) model.to(device) model.eval() self.keypoints_model = model self.kp_threshold = 0.1 self.pitch_batch_size = 4 self.health = "healthy" print("✅ Keypoints Model Loaded") except Exception as e: self.health = "❌ Miner initialization failed: " + str(e) print(self.health) def __repr__(self) -> str: if self.health == 'healthy': return ( f"health: {self.health}\n" f"BBox Model: {type(self.bbox_model).__name__}\n" f"Keypoints Model: {type(self.keypoints_model).__name__}" ) else: return self.health def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]: results = predict_batch( self.bbox_model, self.ball_model, self.team_classifier, self.keypoints_model, batch_images, offset, n_keypoints, self.pitch_batch_size, self.kp_threshold ) return results