cliChutes commited on
Commit
b8add4e
·
verified ·
1 Parent(s): 429a825

scorevision: push artifact

Browse files
Files changed (1) hide show
  1. miner.py +111 -0
miner.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import List, Tuple, Dict
3
+ import sys
4
+ import os
5
+
6
+ from numpy import ndarray
7
+ from pydantic import BaseModel
8
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
9
+
10
+ from ultralytics import YOLO
11
+ from team_cluster import TeamClassifier
12
+ from utils import (
13
+ BoundingBox,
14
+ Constants,
15
+ )
16
+ from inference import predict_batch
17
+ import torch
18
+ from pitch import get_cls_net
19
+ import yaml
20
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
21
+
22
+
23
+ class BoundingBox(BaseModel):
24
+ x1: int
25
+ y1: int
26
+ x2: int
27
+ y2: int
28
+ cls_id: int
29
+ conf: float
30
+
31
+
32
+ class TVFrameResult(BaseModel):
33
+ frame_id: int
34
+ boxes: List[BoundingBox]
35
+ keypoints: List[Tuple[int, int]]
36
+
37
+
38
+ class Miner:
39
+ SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
40
+ SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
41
+ SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
42
+ CORNER_INDICES = Constants.CORNER_INDICES
43
+ KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
44
+ CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
45
+ GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
46
+ MIN_SAMPLES_FOR_FIT = 16 # Minimum player crops needed before fitting TeamClassifier
47
+ MAX_SAMPLES_FOR_FIT = 1000 # Maximum samples to avoid overfitting
48
+
49
+ def __init__(self, path_hf_repo: Path) -> None:
50
+ try:
51
+ device = "cuda" if torch.cuda.is_available() else "cpu"
52
+ model_path = path_hf_repo / "football_object_detection.onnx"
53
+ self.bbox_model = YOLO(model_path)
54
+
55
+ print("BBox Model Loaded")
56
+
57
+ team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
58
+ self.team_classifier = TeamClassifier(
59
+ device=device,
60
+ batch_size=32,
61
+ model_name=str(team_model_path)
62
+ )
63
+ print("Team Classifier Loaded")
64
+
65
+ # Team classification state
66
+ self.team_classifier_fitted = False
67
+ self.player_crops_for_fit = []
68
+
69
+ model_kp_path = path_hf_repo / 'keypoint'
70
+ config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml'
71
+ cfg_kp = yaml.safe_load(open(config_kp_path, 'r'))
72
+
73
+ loaded_state_kp = torch.load(model_kp_path, map_location=device)
74
+ model = get_cls_net(cfg_kp)
75
+ model.load_state_dict(loaded_state_kp)
76
+ model.to(device)
77
+ model.eval()
78
+
79
+ self.keypoints_model = model
80
+ self.kp_threshold = 0.1
81
+ self.pitch_batch_size = 4
82
+ self.health = "healthy"
83
+ self.path_hf_repo = path_hf_repo
84
+ print("✅ Keypoints Model Loaded")
85
+ except Exception as e:
86
+ self.health = "❌ Miner initialization failed: " + str(e)
87
+ print(self.health)
88
+
89
+ def __repr__(self) -> str:
90
+ if self.health == 'healthy':
91
+ return (
92
+ f"health: {self.health}\n"
93
+ f"BBox Model: {type(self.bbox_model).__name__}\n"
94
+ f"Keypoints Model: {type(self.keypoints_model).__name__}"
95
+ )
96
+ else:
97
+ return self.health
98
+
99
+ def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]:
100
+ results = predict_batch(
101
+ self.bbox_model,
102
+ self.team_classifier,
103
+ self.keypoints_model,
104
+ batch_images,
105
+ offset,
106
+ n_keypoints,
107
+ self.pitch_batch_size,
108
+ self.kp_threshold,
109
+ self.path_hf_repo
110
+ )
111
+ return results