| | from pathlib import Path
|
| | from typing import List, Tuple, Dict
|
| | import sys
|
| | import os
|
| |
|
| | from numpy import ndarray
|
| | from pydantic import BaseModel
|
| | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| | from keypoint_helper import run_keypoints_post_processing
|
| |
|
| | from ultralytics import YOLO
|
| | from team_cluster import TeamClassifier
|
| | from utils import (
|
| | BoundingBox,
|
| | Constants,
|
| | )
|
| | from inference import predict_batch
|
| | import time
|
| | import torch
|
| | import gc
|
| | from pitch import process_batch_input, get_cls_net
|
| | import yaml
|
# Configure PyTorch's CUDA caching allocator to use "expandable segments",
# which the PyTorch docs recommend when allocation sizes vary between calls
# (it reduces fragmentation). NOTE(review): the allocator reads this setting
# when it initialises, so it presumably has to be in place before the first
# CUDA allocation; keep it ahead of any GPU work.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})
|
| |
|
| |
|
class BoundingBox(BaseModel):
    """One detection box with its class id and confidence.

    NOTE(review): this definition shadows the ``BoundingBox`` imported from
    ``utils`` at the top of the module, so every later reference to
    ``BoundingBox`` in this module uses this model. Confirm the two are meant
    to be interchangeable.
    """

    # Box corners as integers — presumably (x1, y1) top-left and (x2, y2)
    # bottom-right in pixel coordinates; confirm against the producer.
    x1: int
    y1: int
    x2: int
    y2: int

    # Predicted class index and its confidence score.
    cls_id: int
    conf: float
|
| |
|
| |
|
class TVFrameResult(BaseModel):
    """Result record for a single frame: its id, detections and keypoints."""

    # Identifier of the frame this result belongs to.
    frame_id: int

    # Detections for the frame, using the ``BoundingBox`` model defined in
    # this module.
    boxes: List[BoundingBox]

    # Keypoints as (x, y) integer pairs — presumably pitch keypoints in pixel
    # space; confirm against ``inference.predict_batch``.
    keypoints: List[Tuple[int, int]]
|
| |
|
| |
|
class Miner:
    """Owns the detection, team-classification and pitch-keypoint models.

    ``__init__`` catches any ``Exception`` raised while loading the models.
    Instead of propagating it, it stores a failure message in ``self.health``
    and prints it; on success ``self.health`` is ``"healthy"``. This lets the
    process start and report the problem through ``repr(miner)``.

    Attributes set by a successful ``__init__``:
        bbox_model: YOLO detector loaded from ``football_object_detection.onnx``.
        ball_model: YOLO detector loaded from ``ball_detection.pt``.
        team_classifier: ``TeamClassifier`` built from ``osnet_model.pth.tar-100``.
        team_classifier_fitted: starts as ``False``.
        player_crops_for_fit: starts as an empty list.
        keypoints_model: keypoint network built by ``get_cls_net``, in eval mode.
        kp_threshold: keypoint threshold passed to ``predict_batch`` (0.1).
        pitch_batch_size: batch size passed to ``predict_batch`` (4).
        health: ``"healthy"`` or the initialization failure message.
    """

    # Class-level aliases of the shared tuning values in ``utils.Constants``.
    SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
    SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
    SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
    CORNER_INDICES = Constants.CORNER_INDICES
    KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
    CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
    GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN

    # NOTE(review): presumably the min/max number of player crops gathered
    # before fitting the team classifier; their consumer is not visible here.
    MIN_SAMPLES_FOR_FIT = 16
    MAX_SAMPLES_FOR_FIT = 600

    # Ball-model warm-up: number of dummy forward passes, and the shape of
    # the all-zero dummy batch (presumably N, C, H, W).
    _BALL_WARMUP_ITERS = 3
    _BALL_WARMUP_SHAPE = (16, 3, 1024, 1024)

    def __init__(self, path_hf_repo: Path) -> None:
        """Load every model from ``path_hf_repo`` and record the outcome in ``self.health``.

        Args:
            path_hf_repo: Directory that holds the model artefacts. A ``str``
                is also accepted and converted to ``Path``.
        """
        try:
            # Accept plain strings too; ``str / str`` would otherwise raise.
            repo = Path(path_hf_repo)
            device = "cuda" if torch.cuda.is_available() else "cpu"

            # The ONNX detector is not moved with ``.to(device)``.
            self.bbox_model = YOLO(repo / "football_object_detection.onnx")
            print("BBox Model Loaded")

            self.ball_model = self._load_ball_model(repo / "ball_detection.pt", device)
            print("Ball Model Loaded")

            self.team_classifier = TeamClassifier(
                device=device,
                batch_size=32,
                model_name=str(repo / "osnet_model.pth.tar-100"),
            )
            print("Team Classifier Loaded")

            self.team_classifier_fitted = False
            self.player_crops_for_fit = []

            self.keypoints_model = self._load_keypoints_model(
                repo / "keypoint",
                repo / "hrnetv2_w48.yaml",
                device,
            )
            self.kp_threshold = 0.1
            self.pitch_batch_size = 4
            self.health = "healthy"
            print("✅ Keypoints Model Loaded")
        except Exception as e:
            # Deliberately non-fatal: record the failure so the caller can
            # surface it via repr() / predict_batch instead of crashing here.
            self.health = "❌ Miner initialization failed: " + str(e)
            print(self.health)

    def _load_ball_model(self, weights_path: Path, device: str) -> "YOLO":
        """Load the ball detector onto ``device`` and run the warm-up passes."""
        model = YOLO(weights_path)
        model.to(device)
        # Run throw-away passes on an all-zero batch, presumably so one-time
        # setup cost is paid at start-up rather than on the first real batch.
        # The tensor is built once (not once per pass) and released afterwards.
        dummy_input = torch.zeros(*self._BALL_WARMUP_SHAPE, device=device)
        for _ in range(self._BALL_WARMUP_ITERS):
            model(dummy_input)
        del dummy_input
        return model

    @staticmethod
    def _load_keypoints_model(weights_path: Path, config_path: Path, device: str) -> "torch.nn.Module":
        """Build the keypoint network from its YAML config, load its weights, and return it in eval mode."""
        # ``with`` closes the config file deterministically.
        with open(config_path, "r", encoding="utf-8") as config_file:
            cfg_kp = yaml.safe_load(config_file)

        # SECURITY NOTE(review): torch.load unpickles the checkpoint. On torch
        # versions where ``weights_only`` defaults to False (before 2.6), a
        # malicious repo could execute code here. Consider passing
        # ``weights_only=True`` once the file is confirmed to be a plain
        # state_dict.
        loaded_state_kp = torch.load(weights_path, map_location=device)
        model = get_cls_net(cfg_kp)
        model.load_state_dict(loaded_state_kp)
        model.to(device)
        model.eval()
        return model

    def __repr__(self) -> str:
        """Summarise the loaded models, or return the failure message."""
        if self.health != 'healthy':
            return self.health
        return (
            f"health: {self.health}\n"
            f"BBox Model: {type(self.bbox_model).__name__}\n"
            f"Keypoints Model: {type(self.keypoints_model).__name__}"
        )

    def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]:
        """Run the full pipeline on a batch of frames.

        Delegates to the module-level ``predict_batch`` imported from
        ``inference``. Inside this method the bare name resolves to that
        global function, not to this method.

        Args:
            batch_images: Frames to process.
            offset: Forwarded unchanged to ``inference.predict_batch``.
            n_keypoints: Forwarded unchanged to ``inference.predict_batch``.

        Returns:
            Whatever ``inference.predict_batch`` returns (annotated as a list
            of ``TVFrameResult``).

        Raises:
            RuntimeError: if ``__init__`` failed. This replaces the opaque
                AttributeError a missing model attribute would otherwise cause.
        """
        if self.health != "healthy":
            raise RuntimeError(f"Miner is not healthy: {self.health}")
        return predict_batch(
            self.bbox_model,
            self.ball_model,
            self.team_classifier,
            self.keypoints_model,
            batch_images,
            offset,
            n_keypoints,
            self.pitch_batch_size,
            self.kp_threshold,
        )