File size: 4,109 Bytes
210e540
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from pathlib import Path
from typing import List, Tuple, Dict
import sys
import os

from numpy import ndarray
from pydantic import BaseModel
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from keypoint_helper import run_keypoints_post_processing

from ultralytics import YOLO
from team_cluster import TeamClassifier
from utils import (
    BoundingBox, 
    Constants,
)
from inference import predict_batch
import time
import torch
import gc
from pitch import process_batch_input, get_cls_net
import yaml
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


class BoundingBox(BaseModel):
    # Pydantic model for a single detection box in a TVFrameResult.
    # NOTE(review): this definition shadows the ``BoundingBox`` imported from
    # ``utils`` at the top of the file; from here on, module-level references
    # to ``BoundingBox`` (e.g. TVFrameResult.boxes) use THIS class. Confirm
    # that is intended.
    x1: int  # presumably left / top-left x coordinate in pixels — confirm
    y1: int  # presumably top / top-left y coordinate in pixels — confirm
    x2: int  # presumably right / bottom-right x coordinate in pixels — confirm
    y2: int  # presumably bottom / bottom-right y coordinate in pixels — confirm
    cls_id: int  # class index of the detection (label mapping defined elsewhere)
    conf: float  # detection confidence score; range not enforced by this model


class TVFrameResult(BaseModel):
    # Pydantic model holding the per-frame output: detections plus pitch keypoints.
    frame_id: int  # frame identifier; presumably derived from the batch offset — confirm in inference.predict_batch
    boxes: List[BoundingBox]  # detections, typed with the BoundingBox model defined just above
    keypoints: List[Tuple[int, int]]  # integer (x, y) pairs; presumably pixel coordinates, one per pitch keypoint — confirm


class Miner:
    """Bundle the object-detection, ball, team-classification and keypoint models.

    All models are loaded from ``path_hf_repo`` in ``__init__``. Loading is
    best-effort: any exception is caught and recorded in ``self.health``
    instead of propagating, so callers should check ``health == "healthy"``
    (or ``repr(miner)``) before calling :meth:`predict_batch`.
    """

    # Tuning constants re-exported from ``utils.Constants`` so they are
    # reachable as class attributes.
    SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
    SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
    SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
    CORNER_INDICES = Constants.CORNER_INDICES
    KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
    CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
    GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
    MIN_SAMPLES_FOR_FIT = 16  # Minimum player crops needed before fitting TeamClassifier
    MAX_SAMPLES_FOR_FIT = 600  # Maximum samples to avoid overfitting

    # Dummy batch fed to the ball model during warm-up: (batch, channels, height, width).
    # At the default float32 dtype this is ~200 MB, so it is allocated once.
    _BALL_WARMUP_SHAPE = (16, 3, 1024, 1024)
    _BALL_WARMUP_ITERATIONS = 3

    def __init__(self, path_hf_repo: Path) -> None:
        """Load every model from ``path_hf_repo``.

        Args:
            path_hf_repo: Directory containing ``football_object_detection.onnx``,
                ``ball_detection.pt``, ``osnet_model.pth.tar-100``, the
                ``keypoint`` weights file and ``hrnetv2_w48.yaml``.

        Never raises. On success ``self.health`` is ``"healthy"``. On any
        failure it holds an error message, which is also printed. Attributes
        assigned before the failing step remain set; later ones do not exist.
        """
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"

            self.bbox_model = YOLO(path_hf_repo / "football_object_detection.onnx")
            print("BBox Model Loaded")

            self.ball_model = YOLO(path_hf_repo / "ball_detection.pt")
            self.ball_model.to(device)
            self._warmup_ball_model(device)
            print("Ball Model Loaded")

            team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
            self.team_classifier = TeamClassifier(
                device=device,
                batch_size=32,
                model_name=str(team_model_path),
            )
            print("Team Classifier Loaded")

            # Team classification state
            self.team_classifier_fitted = False
            self.player_crops_for_fit = []

            self.keypoints_model = self._load_keypoints_model(path_hf_repo, device)
            self.kp_threshold = 0.1
            self.pitch_batch_size = 4
            self.health = "healthy"
            print("✅ Keypoints Model Loaded")
        except Exception as e:
            # Deliberate best-effort: record the failure rather than crash, so
            # the caller can inspect ``health`` / ``repr(self)``.
            self.health = "❌ Miner initialization failed: " + str(e)
            print(self.health)

    def _warmup_ball_model(self, device: str) -> None:
        """Run a few dummy forward passes through ``self.ball_model`` on ``device``.

        Presumably this triggers one-time lazy initialisation (CUDA context,
        kernel selection) before real traffic arrives.
        """
        # Allocate the dummy batch once and reuse it for every pass. The
        # tensor is local to this helper, so it is released on return instead
        # of staying alive while the remaining models load.
        dummy_input = torch.zeros(*self._BALL_WARMUP_SHAPE, device=device)
        for _ in range(self._BALL_WARMUP_ITERATIONS):
            self.ball_model(dummy_input)

    @staticmethod
    def _load_keypoints_model(path_hf_repo: Path, device: str):
        """Build the keypoint network and return it on ``device`` in eval mode.

        The architecture comes from ``get_cls_net`` configured by
        ``hrnetv2_w48.yaml``; its state dict is read from the ``keypoint`` file.
        """
        model_kp_path = path_hf_repo / 'keypoint'
        config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml'
        # ``with`` guarantees the config file handle is closed.
        with open(config_kp_path, 'r', encoding='utf-8') as config_file:
            cfg_kp = yaml.safe_load(config_file)

        # NOTE(review): torch.load unpickles arbitrary objects. Only load
        # trusted checkpoints; consider ``weights_only=True`` (torch>=1.13)
        # since a plain state dict is expected here.
        loaded_state_kp = torch.load(model_kp_path, map_location=device)
        model = get_cls_net(cfg_kp)
        model.load_state_dict(loaded_state_kp)
        model.to(device)
        model.eval()
        return model

    def __repr__(self) -> str:
        """Summarise the loaded models, or return the init error message."""
        if self.health == 'healthy':
            return (
                f"health: {self.health}\n"
                f"BBox Model: {type(self.bbox_model).__name__}\n"
                f"Keypoints Model: {type(self.keypoints_model).__name__}"
            )
        else:
            return self.health

    def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]:
        """Run the full pipeline on a batch of frames.

        Delegates to the module-level ``predict_batch`` imported from
        ``inference``. Inside this method the bare name refers to that
        function, not to this method.

        Args:
            batch_images: Frames to process; presumably HxWxC image arrays — confirm.
            offset: Passed through unchanged; presumably the frame index of
                ``batch_images[0]`` — confirm in ``inference.predict_batch``.
            n_keypoints: Passed through unchanged; presumably the number of
                pitch keypoints expected per frame — confirm.

        Returns:
            Whatever ``inference.predict_batch`` returns (annotated as one
            ``TVFrameResult`` per frame).

        NOTE(review): if ``__init__`` failed, some model attributes do not
        exist and this call raises ``AttributeError``; check ``health`` first.
        """
        results = predict_batch(
            self.bbox_model,
            self.ball_model,
            self.team_classifier,
            self.keypoints_model,
            batch_images,
            offset,
            n_keypoints,
            self.pitch_batch_size,
            self.kp_threshold
        )
        return results