File size: 3,708 Bytes
b8add4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from pathlib import Path
from typing import List, Tuple, Dict
import sys
import os

from numpy import ndarray
from pydantic import BaseModel
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from ultralytics import YOLO
from team_cluster import TeamClassifier
from utils import (
    BoundingBox, 
    Constants,
)
from inference import predict_batch
import torch
from pitch import get_cls_net
import yaml
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


class BoundingBox(BaseModel):
    """Box given by two integer corner points plus a class id and a score.

    NOTE(review): this class rebinds the name ``BoundingBox`` that is also
    imported from ``utils`` at the top of the module, so every later use in
    this module refers to this model. Confirm the two are meant to be
    interchangeable.
    """

    # Corner coordinates; presumably (x1, y1) = top-left and (x2, y2) =
    # bottom-right in pixels — TODO confirm against the producer in `inference`.
    x1: int
    y1: int
    x2: int
    y2: int
    cls_id: int  # class index; the id -> label mapping is defined outside this file
    conf: float  # presumably the detector's confidence score — verify with the producer


class TVFrameResult(BaseModel):
    """Per-frame result; ``Miner.predict_batch`` is annotated to return a list of these."""

    frame_id: int  # presumably the absolute frame index (batch offset + position) — TODO confirm in `inference.predict_batch`
    boxes: List[BoundingBox]  # resolves to the BoundingBox pydantic model defined in this module
    keypoints: List[Tuple[int, int]]  # integer pairs, presumably (x, y) pixel positions — verify against the producer


class Miner:
    """Load the detection, team-classification and keypoint models and run them.

    Construction never raises. Any exception while loading is caught, and its
    message is stored in ``self.health``. A fully loaded miner has
    ``health == "healthy"``; check it (or ``repr(miner)``) before predicting.
    """

    # Tuning constants re-exported from ``utils.Constants``.
    SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
    SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
    SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
    CORNER_INDICES = Constants.CORNER_INDICES
    KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
    CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
    GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
    MIN_SAMPLES_FOR_FIT = 16  # Minimum player crops needed before fitting TeamClassifier
    MAX_SAMPLES_FOR_FIT = 1000  # Maximum samples to avoid overfitting

    def __init__(self, path_hf_repo: Path) -> None:
        """Load every model from ``path_hf_repo``.

        Args:
            path_hf_repo: Directory expected to contain
                ``football_object_detection.onnx`` (loaded with ``YOLO``),
                ``osnet_model.pth.tar-100`` (passed to ``TeamClassifier``),
                ``keypoint`` (state dict for the ``get_cls_net`` model) and
                ``hrnetv2_w48.yaml`` (config for ``get_cls_net``).

        Failures are not raised; they are recorded in ``self.health`` and
        printed.
        """
        # Assign this before the try so the attribute exists even when loading
        # fails part-way.
        self.path_hf_repo = path_hf_repo
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model_path = path_hf_repo / "football_object_detection.onnx"
            self.bbox_model = YOLO(model_path)

            print("BBox Model Loaded")

            team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
            self.team_classifier = TeamClassifier(
                device=device,
                batch_size=32,
                model_name=str(team_model_path)
            )
            print("Team Classifier Loaded")

            # Team classification state.
            # NOTE(review): nothing in this file reads these two attributes;
            # they are presumably consumed elsewhere. Confirm before removing.
            self.team_classifier_fitted = False
            self.player_crops_for_fit = []

            model_kp_path = path_hf_repo / 'keypoint'
            config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml'
            # Use a context manager so the config file handle is closed
            # deterministically rather than whenever it is garbage-collected.
            with open(config_kp_path, 'r') as config_file:
                cfg_kp = yaml.safe_load(config_file)

            # NOTE(review): torch.load unpickles the file. If the repo contents
            # are not fully trusted, consider weights_only=True (torch >= 1.13),
            # which is sufficient for a plain state dict.
            loaded_state_kp = torch.load(model_kp_path, map_location=device)
            model = get_cls_net(cfg_kp)
            model.load_state_dict(loaded_state_kp)
            model.to(device)
            model.eval()

            self.keypoints_model = model
            # Both values below are forwarded unchanged to inference.predict_batch.
            self.kp_threshold = 0.1
            self.pitch_batch_size = 4
            self.health = "healthy"
            print("✅ Keypoints Model Loaded")
        except Exception as e:
            # Deliberately best-effort: record the failure instead of raising,
            # so callers can still construct the miner and inspect `health`.
            self.health = "❌ Miner initialization failed: " + str(e)
            print(self.health)

    def __repr__(self) -> str:
        """Summarize the loaded model types, or return the failure message."""
        if self.health == 'healthy':
            return (
                f"health: {self.health}\n"
                f"BBox Model: {type(self.bbox_model).__name__}\n"
                f"Keypoints Model: {type(self.keypoints_model).__name__}"
            )
        else:
            return self.health

    def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]:
        """Run the loaded models on a batch of frames.

        Delegates to the module-level ``predict_batch`` imported from
        ``inference``. Inside this method the bare name resolves to that
        global function, not to this method.

        Args:
            batch_images: Frames to process.
            offset: Forwarded unchanged to ``inference.predict_batch``;
                presumably the index of the first frame in the batch. TODO:
                confirm.
            n_keypoints: Forwarded unchanged to ``inference.predict_batch``.

        Returns:
            The value returned by ``inference.predict_batch``, annotated as
            ``List[TVFrameResult]``.

        Raises:
            RuntimeError: If model loading failed in ``__init__``. The message
                includes the recorded ``health`` string.
        """
        if self.health != "healthy":
            # Fail with the recorded load error instead of an opaque
            # AttributeError on a half-initialized instance.
            raise RuntimeError(f"Miner is not healthy: {self.health}")
        return predict_batch(
            self.bbox_model,
            self.team_classifier,
            self.keypoints_model,
            batch_images,
            offset,
            n_keypoints,
            self.pitch_batch_size,
            self.kp_threshold,
            self.path_hf_repo
        )