tarto2 commited on
Commit
e9d3644
·
1 Parent(s): 996e082

introduce osnet

Browse files
Files changed (6) hide show
  1. config.yml +6 -3
  2. miner.py +316 -267
  3. osnet_ain.pyc +0 -0
  4. pitch.py +1 -19
  5. team_cluster.pyc +0 -0
  6. utils.pyc +0 -0
config.yml CHANGED
@@ -2,13 +2,15 @@ Image:
2
  from_base: parachutes/python:3.12
3
  run_command:
4
  - pip install --upgrade setuptools wheel
 
5
  - pip install "ultralytics==8.3.222" "opencv-python-headless" "numpy" "pydantic"
6
- - pip install "tensorflow" "torch==2.7.1" "torchvision==0.22.1" "torch-tensorrt==2.7"
 
7
  set_workdir: /app
8
 
9
  NodeSelector:
10
  gpu_count: 1
11
- min_vram_gb_per_gpu: 16
12
  exclude:
13
  - "5090"
14
  - b200
@@ -19,4 +21,5 @@ Chute:
19
  timeout_seconds: 900
20
  concurrency: 4
21
  max_instances: 5
22
- scaling_threshold: 0.5
 
 
2
  from_base: parachutes/python:3.12
3
  run_command:
4
  - pip install --upgrade setuptools wheel
5
+ - pip install "torch==2.7.1" "torchvision==0.22.1"
6
  - pip install "ultralytics==8.3.222" "opencv-python-headless" "numpy" "pydantic"
7
+ - pip install scikit-learn
8
+ - pip install onnxruntime-gpu
9
  set_workdir: /app
10
 
11
  NodeSelector:
12
  gpu_count: 1
13
+ min_vram_gb_per_gpu: 24
14
  exclude:
15
  - "5090"
16
  - b200
 
21
  timeout_seconds: 900
22
  concurrency: 4
23
  max_instances: 5
24
+ scaling_threshold: 0.5
25
+ shutdown_after_seconds: 3600
miner.py CHANGED
@@ -4,41 +4,22 @@ import sys
4
  import os
5
 
6
  from numpy import ndarray
7
- import numpy as np
8
  from pydantic import BaseModel
9
- import cv2
10
-
11
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
12
 
13
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
14
- os.environ["OMP_NUM_THREADS"] = "16"
15
- os.environ["TF_NUM_INTRAOP_THREADS"] = "16"
16
- os.environ["TF_NUM_INTEROP_THREADS"] = "2"
17
- os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
18
- os.environ["ORT_LOGGING_LEVEL"] = "3"
19
- os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
20
-
21
- import logging
22
- import tensorflow as tf
23
- from tensorflow.keras import mixed_precision
24
- import torch._dynamo
25
  import torch
26
- # import torch_tensorrt
27
  import gc
28
- from ultralytics import YOLO
29
  from pitch import process_batch_input, get_cls_net
30
  import yaml
31
 
32
- logging.getLogger("tensorflow").setLevel(logging.ERROR)
33
- tf.config.threading.set_intra_op_parallelism_threads(16)
34
- tf.config.threading.set_inter_op_parallelism_threads(2)
35
- tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
36
- tf.get_logger().setLevel("ERROR")
37
- tf.autograph.set_verbosity(0)
38
- mixed_precision.set_global_policy("mixed_float16")
39
- tf.config.optimizer.set_jit(True)
40
- torch._dynamo.config.suppress_errors = True
41
-
42
 
43
  class BoundingBox(BaseModel):
44
  x1: int
@@ -56,260 +37,340 @@ class TVFrameResult(BaseModel):
56
 
57
 
58
  class Miner:
59
- QUASI_TOTAL_IOA: float = 0.90
60
- SMALL_CONTAINED_IOA: float = 0.85
61
- SMALL_RATIO_MAX: float = 0.50
62
- SINGLE_PLAYER_HUE_PIVOT: float = 90.0
63
-
64
- CLS_MAP = {
65
- 0: 0,
66
- 1: 1,
67
- 2: 6,
68
- 3: 7,
69
- 4: 3,
70
- }
71
 
72
  def __init__(self, path_hf_repo: Path) -> None:
73
- print(path_hf_repo / "best.pt")
74
- self.bbox_model = YOLO(path_hf_repo / "best.pt")
75
- print(" BBox Model (best.pt) Loaded")
76
- device = "cuda" if torch.cuda.is_available() else "cpu"
77
- # model_kp_path = path_hf_repo / "SV_kp.engine"
78
- # model_kp = torch_tensorrt.load(model_kp_path)
79
-
80
- model_kp_path = path_hf_repo / 'keypoint'
81
- config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml'
82
- cfg_kp = yaml.safe_load(open(config_kp_path, 'r'))
83
-
84
- loaded_state_kp = torch.load(model_kp_path, map_location=device)
85
- model = get_cls_net(cfg_kp)
86
- model.load_state_dict(loaded_state_kp)
87
- model.to(device)
88
- model.eval()
89
-
90
- # @torch.inference_mode()
91
- # def run_inference(model, input_tensor: torch.Tensor):
92
- # input_tensor = input_tensor.to(device).to(memory_format=torch.channels_last)
93
- # output = model.module().forward(input_tensor)
94
- # return output
95
-
96
- # run_inference(model_kp, torch.randn(8, 3, 540, 960, device=device, dtype=torch.float32))
97
- self.keypoints_model = model
98
- self.kp_threshold = 0.1
99
- self.pitch_batch_size = 8
100
- print("✅ Keypoints Model Loaded")
 
 
 
 
 
 
 
 
 
101
 
102
  def __repr__(self) -> str:
103
- return (
104
- f"BBox Model: {type(self.bbox_model).__name__}\n"
105
- f"Keypoints Model: {type(self.keypoints_model).__name__}"
106
- )
107
-
108
- @staticmethod
109
- def _clip_box_to_image(x1: int, y1: int, x2: int, y2: int, w: int, h: int) -> Tuple[int, int, int, int]:
110
- x1 = max(0, min(int(x1), w - 1))
111
- y1 = max(0, min(int(y1), h - 1))
112
- x2 = max(0, min(int(x2), w - 1))
113
- y2 = max(0, min(int(y2), h - 1))
114
- if x2 <= x1:
115
- x2 = min(w - 1, x1 + 1)
116
- if y2 <= y1:
117
- y2 = min(h - 1, y1 + 1)
118
- return x1, y1, x2, y2
119
-
120
- @staticmethod
121
- def _area(bb: BoundingBox) -> int:
122
- return max(0, bb.x2 - bb.x1) * max(0, bb.y2 - bb.y1)
123
-
124
- @staticmethod
125
- def _intersect_area(a: BoundingBox, b: BoundingBox) -> int:
126
- ix1 = max(a.x1, b.x1)
127
- iy1 = max(a.y1, b.y1)
128
- ix2 = min(a.x2, b.x2)
129
- iy2 = min(a.y2, b.y2)
130
- if ix2 <= ix1 or iy2 <= iy1:
131
- return 0
132
- return (ix2 - ix1) * (iy2 - iy1)
133
-
134
- @staticmethod
135
- def _center(bb: BoundingBox) -> Tuple[float, float]:
136
- return (0.5 * (bb.x1 + bb.x2), 0.5 * (bb.y1 + bb.y2))
137
-
138
- @staticmethod
139
- def _mean_hs(img_bgr: np.ndarray) -> Tuple[float, float]:
140
- hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
141
- return float(np.mean(hsv[:, :, 0])), float(np.mean(hsv[:, :, 1]))
142
-
143
- def _hs_feature_from_roi(self, img_bgr: np.ndarray, box: BoundingBox) -> np.ndarray:
144
- H, W = img_bgr.shape[:2]
145
- x1, y1, x2, y2 = self._clip_box_to_image(box.x1, box.y1, box.x2, box.y2, W, H)
146
- roi = img_bgr[y1:y2, x1:x2]
147
- if roi.size == 0:
148
- return np.array([0.0, 0.0], dtype=np.float32)
149
- hsv = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
150
- lower_green = np.array([35, 60, 60], dtype=np.uint8)
151
- upper_green = np.array([85, 255, 255], dtype=np.uint8)
152
- green_mask = cv2.inRange(hsv, lower_green, upper_green)
153
- non_green_mask = cv2.bitwise_not(green_mask)
154
- num_non_green = int(np.count_nonzero(non_green_mask))
155
- total = hsv.shape[0] * hsv.shape[1]
156
- if num_non_green > max(50, total // 20):
157
- h_vals = hsv[:, :, 0][non_green_mask > 0]
158
- s_vals = hsv[:, :, 1][non_green_mask > 0]
159
- h_mean = float(np.mean(h_vals)) if h_vals.size else 0.0
160
- s_mean = float(np.mean(s_vals)) if s_vals.size else 0.0
161
  else:
162
- h_mean, s_mean = self._mean_hs(roi)
163
- return np.array([h_mean, s_mean], dtype=np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
- def _ioa(self, a: BoundingBox, b: BoundingBox) -> float:
166
- inter = self._intersect_area(a, b)
167
- aa = self._area(a)
168
- if aa <= 0:
 
 
169
  return 0.0
170
- return inter / aa
171
-
172
- def suppress_quasi_total_containment(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
173
- if len(boxes) <= 1:
174
- return boxes
175
- keep = [True] * len(boxes)
176
- for i in range(len(boxes)):
177
- if not keep[i]:
178
- continue
179
- for j in range(len(boxes)):
180
- if i == j or not keep[j]:
181
- continue
182
- ioa_i_in_j = self._ioa(boxes[i], boxes[j])
183
- if ioa_i_in_j >= self.QUASI_TOTAL_IOA:
184
- keep[i] = False
185
- break
186
- return [bb for bb, k in zip(boxes, keep) if k]
187
-
188
- def suppress_small_contained(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
189
- if len(boxes) <= 1:
190
- return boxes
191
- keep = [True] * len(boxes)
192
- areas = [self._area(bb) for bb in boxes]
193
- for i in range(len(boxes)):
194
- if not keep[i]:
195
  continue
196
- for j in range(len(boxes)):
197
- if i == j or not keep[j]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  continue
199
- ai, aj = areas[i], areas[j]
200
- if ai == 0 or aj == 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  continue
202
- if ai <= aj:
203
- ratio = ai / aj
204
- if ratio <= self.SMALL_RATIO_MAX:
205
- ioa_i_in_j = self._ioa(boxes[i], boxes[j])
206
- if ioa_i_in_j >= self.SMALL_CONTAINED_IOA:
207
- keep[i] = False
 
 
 
208
  break
209
- else:
210
- ratio = aj / ai
211
- if ratio <= self.SMALL_RATIO_MAX:
212
- ioa_j_in_i = self._ioa(boxes[j], boxes[i])
213
- if ioa_j_in_i >= self.SMALL_CONTAINED_IOA:
214
- keep[j] = False
215
- return [bb for bb, k in zip(boxes, keep) if k]
216
-
217
- def _assign_players_two_clusters(self, features: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
218
- criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 20, 1.0)
219
- _, labels, centers = cv2.kmeans(
220
- np.float32(features),
221
- K=2,
222
- bestLabels=None,
223
- criteria=criteria,
224
- attempts=5,
225
- flags=cv2.KMEANS_PP_CENTERS,
226
- )
227
- return labels.reshape(-1), centers
228
-
229
- def _reclass_extra_goalkeepers(self, img_bgr: np.ndarray, boxes: List[BoundingBox], cluster_centers: np.ndarray | None) -> None:
230
- gk_idxs = [i for i, bb in enumerate(boxes) if int(bb.cls_id) == 1]
231
- if len(gk_idxs) <= 1:
232
- return
233
- gk_idxs_sorted = sorted(gk_idxs, key=lambda i: boxes[i].conf, reverse=True)
234
- keep_gk_idx = gk_idxs_sorted[0]
235
- to_reclass = gk_idxs_sorted[1:]
236
- for gki in to_reclass:
237
- hs_gk = self._hs_feature_from_roi(img_bgr, boxes[gki])
238
- if cluster_centers is not None:
239
- d0 = float(np.linalg.norm(hs_gk - cluster_centers[0]))
240
- d1 = float(np.linalg.norm(hs_gk - cluster_centers[1]))
241
- assign_cls = 6 if d0 <= d1 else 7
242
- else:
243
- assign_cls = 6 if float(hs_gk[0]) < self.SINGLE_PLAYER_HUE_PIVOT else 7
244
- boxes[gki].cls_id = int(assign_cls)
245
-
246
- def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]:
247
- bboxes: Dict[int, List[BoundingBox]] = {}
248
- bbox_model_results = self.bbox_model.predict(batch_images)
249
- if bbox_model_results is not None:
250
- for frame_idx_in_batch, detection in enumerate(bbox_model_results):
251
- if not hasattr(detection, "boxes") or detection.boxes is None:
252
  continue
253
- boxes: List[BoundingBox] = []
254
- for box in detection.boxes.data:
255
- x1, y1, x2, y2, conf, cls_id = box.tolist()
256
- # if cls_id == 3:
257
- # cls_id = 2
258
- # elif cls_id == 2:
259
- # cls_id = 3
 
 
260
  boxes.append(
261
  BoundingBox(
262
  x1=int(x1),
263
  y1=int(y1),
264
  x2=int(x2),
265
  y2=int(y2),
266
- cls_id=int(self.CLS_MAP[int(cls_id)]),
267
  conf=float(conf),
268
  )
269
  )
270
- footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
271
- if len(footballs) > 1:
272
- best_ball = max(footballs, key=lambda b: b.conf)
273
- boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
274
- boxes.append(best_ball)
275
- # boxes = self.suppress_quasi_total_containment(boxes)
276
- # boxes = self.suppress_small_contained(boxes)
277
- # img_bgr = batch_images[frame_idx_in_batch]
278
- # player_indices: List[int] = []
279
- # player_feats: List[np.ndarray] = []
280
- # for i, bb in enumerate(boxes):
281
- # if int(bb.cls_id) == 2:
282
- # hs = self._hs_feature_from_roi(img_bgr, bb)
283
- # player_indices.append(i)
284
- # player_feats.append(hs)
285
- # cluster_centers = None
286
- # n_players = len(player_feats)
287
- # if n_players >= 2:
288
- # feats = np.vstack(player_feats)
289
- # labels, centers = self._assign_players_two_clusters(feats)
290
- # order = np.argsort(centers[:, 0])
291
- # centers = centers[order]
292
- # remap = {old_idx: new_idx for new_idx, old_idx in enumerate(order)}
293
- # labels = np.vectorize(remap.get)(labels)
294
- # cluster_centers = centers
295
- # for idx_in_list, lbl in zip(player_indices, labels):
296
- # boxes[idx_in_list].cls_id = 6 if int(lbl) == 0 else 7
297
- # elif n_players == 1:
298
- # hue, _ = player_feats[0]
299
- # boxes[player_indices[0]].cls_id = 6 if float(hue) < self.SINGLE_PLAYER_HUE_PIVOT else 7
300
- # self._reclass_extra_goalkeepers(img_bgr, boxes, cluster_centers)
301
- bboxes[offset + frame_idx_in_batch] = boxes
302
 
303
  pitch_batch_size = min(self.pitch_batch_size, len(batch_images))
304
  keypoints: Dict[int, List[Tuple[int, int]]] = {}
 
 
305
  while True:
306
- # try:
307
  gc.collect()
308
  if torch.cuda.is_available():
309
- tf.keras.backend.clear_session()
310
  torch.cuda.empty_cache()
311
  torch.cuda.synchronize()
312
- device_str = "cuda" if torch.cuda.is_available() else "cpu"
313
  keypoints_result = process_batch_input(
314
  batch_images,
315
  self.keypoints_model,
@@ -344,21 +405,10 @@ class Miner:
344
  else:
345
  frame_keypoints = frame_keypoints[:n_keypoints]
346
  keypoints[offset + frame_number_in_batch] = frame_keypoints
347
- print("✅ Keypoints predicted")
348
  break
349
- # except RuntimeError as e:
350
- # print(self.pitch_batch_size)
351
- # print(e)
352
- # if "out of memory" in str(e):
353
- # if self.pitch_batch_size == 1:
354
- # break
355
- # self.pitch_batch_size = self.pitch_batch_size // 2 if self.pitch_batch_size > 1 else 1
356
- # pitch_batch_size = min(self.pitch_batch_size, len(batch_images))
357
- # else:
358
- # break
359
- # except Exception as e:
360
- # print(f"❌ Error during keypoints prediction: {e}")
361
- # break
362
 
363
  results: List[TVFrameResult] = []
364
  for frame_number in range(offset, offset + len(batch_images)):
@@ -373,7 +423,6 @@ class Miner:
373
 
374
  gc.collect()
375
  if torch.cuda.is_available():
376
- tf.keras.backend.clear_session()
377
  torch.cuda.empty_cache()
378
  torch.cuda.synchronize()
379
 
 
4
  import os
5
 
6
  from numpy import ndarray
 
7
  from pydantic import BaseModel
 
 
8
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
9
 
10
+ from ultralytics import YOLO
11
+ from team_cluster import TeamClassifier
12
+ from utils import (
13
+ BoundingBox,
14
+ Constants,
15
+ )
16
+
17
+ import time
 
 
 
 
18
  import torch
 
19
  import gc
 
20
  from pitch import process_batch_input, get_cls_net
21
  import yaml
22
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  class BoundingBox(BaseModel):
25
  x1: int
 
37
 
38
 
39
  class Miner:
40
+ SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
41
+ SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
42
+ SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
43
+ CORNER_INDICES = Constants.CORNER_INDICES
44
+ KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
45
+ CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
46
+ GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
47
+ MIN_SAMPLES_FOR_FIT = 16 # Minimum player crops needed before fitting TeamClassifier
48
+ MAX_SAMPLES_FOR_FIT = 600 # Maximum samples to avoid overfitting
 
 
 
49
 
50
  def __init__(self, path_hf_repo: Path) -> None:
51
+ try:
52
+ device = "cuda" if torch.cuda.is_available() else "cpu"
53
+ model_path = path_hf_repo / "football_object_detection.onnx"
54
+ self.bbox_model = YOLO(model_path)
55
+
56
+ print("BBox Model Loaded")
57
+
58
+ team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
59
+ self.team_classifier = TeamClassifier(
60
+ device=device,
61
+ batch_size=32,
62
+ model_name=str(team_model_path)
63
+ )
64
+ print("Team Classifier Loaded")
65
+
66
+ # Team classification state
67
+ self.team_classifier_fitted = False
68
+ self.player_crops_for_fit = []
69
+
70
+ model_kp_path = path_hf_repo / 'keypoint'
71
+ config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml'
72
+ cfg_kp = yaml.safe_load(open(config_kp_path, 'r'))
73
+
74
+ loaded_state_kp = torch.load(model_kp_path, map_location=device)
75
+ model = get_cls_net(cfg_kp)
76
+ model.load_state_dict(loaded_state_kp)
77
+ model.to(device)
78
+ model.eval()
79
+
80
+ self.keypoints_model = model
81
+ self.kp_threshold = 0.1
82
+ self.pitch_batch_size = 4
83
+ self.health = "healthy"
84
+ print("✅ Keypoints Model Loaded")
85
+ except Exception as e:
86
+ self.health = "❌ Miner initialization failed: " + str(e)
87
+ print(self.health)
88
 
89
  def __repr__(self) -> str:
90
+ if self.health == 'healthy':
91
+ return (
92
+ f"health: {self.health}\n"
93
+ f"BBox Model: {type(self.bbox_model).__name__}\n"
94
+ f"Keypoints Model: {type(self.keypoints_model).__name__}"
95
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  else:
97
+ return self.health
98
+
99
+ def _calculate_iou(self, box1: Tuple[float, float, float, float],
100
+ box2: Tuple[float, float, float, float]) -> float:
101
+ """
102
+ Calculate Intersection over Union (IoU) between two bounding boxes.
103
+ Args:
104
+ box1: (x1, y1, x2, y2)
105
+ box2: (x1, y1, x2, y2)
106
+ Returns:
107
+ IoU score (0-1)
108
+ """
109
+ x1_1, y1_1, x2_1, y2_1 = box1
110
+ x1_2, y1_2, x2_2, y2_2 = box2
111
+
112
+ # Calculate intersection area
113
+ x_left = max(x1_1, x1_2)
114
+ y_top = max(y1_1, y1_2)
115
+ x_right = min(x2_1, x2_2)
116
+ y_bottom = min(y2_1, y2_2)
117
+
118
+ if x_right < x_left or y_bottom < y_top:
119
+ return 0.0
120
+
121
+ intersection_area = (x_right - x_left) * (y_bottom - y_top)
122
 
123
+ # Calculate union area
124
+ box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
125
+ box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
126
+ union_area = box1_area + box2_area - intersection_area
127
+
128
+ if union_area == 0:
129
  return 0.0
130
+
131
+ return intersection_area / union_area
132
+
133
+ def _detect_objects_batch(self, decoded_images: List[ndarray]) -> Dict[int, List[BoundingBox]]:
134
+ batch_size = 16
135
+ detection_results = []
136
+ n_frames = len(decoded_images)
137
+ for frame_number in range(0, n_frames, batch_size):
138
+ batch_images = decoded_images[frame_number: frame_number + batch_size]
139
+ detections = self.bbox_model(batch_images, verbose=False, save=False)
140
+ detection_results.extend(detections)
141
+
142
+ return detection_results
143
+
144
+ def _team_classify(self, detection_results, decoded_images, offset):
145
+ self.team_classifier_fitted = False
146
+ start = time.time()
147
+ # Collect player crops from first batch for fitting
148
+ fit_sample_size = 600
149
+ player_crops_for_fit = []
150
+
151
+ for frame_id in range(len(detection_results)):
152
+ detection_box = detection_results[frame_id].boxes.data
153
+ if len(detection_box) < 4:
 
154
  continue
155
+ # Collect player boxes for team classification fitting (first batch only)
156
+ if len(player_crops_for_fit) < fit_sample_size:
157
+ frame_image = decoded_images[frame_id]
158
+ for box in detection_box:
159
+ x1, y1, x2, y2, conf, cls_id = box.tolist()
160
+ if conf < 0.5:
161
+ continue
162
+ mapped_cls_id = str(int(cls_id))
163
+ # Only collect player crops (cls_id = 2)
164
+ if mapped_cls_id == '2':
165
+ crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
166
+ if crop.size > 0:
167
+ player_crops_for_fit.append(crop)
168
+
169
+ # Fit team classifier after collecting samples
170
+ if self.team_classifier and not self.team_classifier_fitted and len(player_crops_for_fit) >= fit_sample_size:
171
+ print(f"Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
172
+ self.team_classifier.fit(player_crops_for_fit)
173
+ self.team_classifier_fitted = True
174
+ break
175
+ if not self.team_classifier_fitted and len(player_crops_for_fit) >= 16:
176
+ print(f"Fallback: Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
177
+ self.team_classifier.fit(player_crops_for_fit)
178
+ self.team_classifier_fitted = True
179
+ end = time.time()
180
+ print(f"Fitting Kmeans time: {end - start}")
181
+
182
+ # Second pass: predict teams with configurable frame skipping optimization
183
+ start = time.time()
184
+
185
+ # Get configuration for frame skipping
186
+ prediction_interval = 1 # Default: predict every 2 frames
187
+ iou_threshold = 0.3
188
+
189
+ print(f"Team classification - prediction_interval: {prediction_interval}, iou_threshold: {iou_threshold}")
190
+
191
+ # Storage for predicted frame results: {frame_id: {box_idx: (bbox, team_id)}}
192
+ predicted_frame_data = {}
193
+
194
+ # Step 1: Predict for frames at prediction_interval only
195
+ frames_to_predict = []
196
+ for frame_id in range(len(detection_results)):
197
+ if frame_id % prediction_interval == 0:
198
+ frames_to_predict.append(frame_id)
199
+
200
+ print(f"Predicting teams for {len(frames_to_predict)}/{len(detection_results)} frames "
201
+ f"(saving {100 - (len(frames_to_predict) * 100 // len(detection_results))}% compute)")
202
+
203
+ for frame_id in frames_to_predict:
204
+ detection_box = detection_results[frame_id].boxes.data
205
+ frame_image = decoded_images[frame_id]
206
+
207
+ # Collect player crops for this frame
208
+ frame_player_crops = []
209
+ frame_player_indices = []
210
+ frame_player_boxes = []
211
+
212
+ for idx, box in enumerate(detection_box):
213
+ x1, y1, x2, y2, conf, cls_id = box.tolist()
214
+ if cls_id == 2 and conf < 0.6:
215
  continue
216
+ mapped_cls_id = str(int(cls_id))
217
+
218
+ # Collect player crops for prediction
219
+ if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
220
+ crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
221
+ if crop.size > 0:
222
+ frame_player_crops.append(crop)
223
+ frame_player_indices.append(idx)
224
+ frame_player_boxes.append((x1, y1, x2, y2))
225
+
226
+ # Predict teams for all players in this frame
227
+ if len(frame_player_crops) > 0:
228
+ team_ids = self.team_classifier.predict(frame_player_crops)
229
+ predicted_frame_data[frame_id] = {}
230
+ for idx, bbox, team_id in zip(frame_player_indices, frame_player_boxes, team_ids):
231
+ # Map team_id (0,1) to cls_id (6,7)
232
+ team_cls_id = str(6 + int(team_id))
233
+ predicted_frame_data[frame_id][idx] = (bbox, team_cls_id)
234
+
235
+ # Step 2: Process all frames (interpolate skipped frames)
236
+ fallback_count = 0
237
+ interpolated_count = 0
238
+ bboxes: dict[int, list[BoundingBox]] = {}
239
+ for frame_id in range(len(detection_results)):
240
+ detection_box = detection_results[frame_id].boxes.data
241
+ frame_image = decoded_images[frame_id]
242
+ boxes = []
243
+
244
+ team_predictions = {}
245
+
246
+ if frame_id % prediction_interval == 0:
247
+ # Predicted frame: use pre-computed predictions
248
+ if frame_id in predicted_frame_data:
249
+ for idx, (bbox, team_cls_id) in predicted_frame_data[frame_id].items():
250
+ team_predictions[idx] = team_cls_id
251
+ else:
252
+ # Skipped frame: interpolate from neighboring predicted frames
253
+ # Find nearest predicted frames
254
+ prev_predicted_frame = (frame_id // prediction_interval) * prediction_interval
255
+ next_predicted_frame = prev_predicted_frame + prediction_interval
256
+
257
+ # Collect current frame player boxes
258
+ for idx, box in enumerate(detection_box):
259
+ x1, y1, x2, y2, conf, cls_id = box.tolist()
260
+ if cls_id == 2 and conf < 0.6:
261
+ continue
262
+ mapped_cls_id = str(int(cls_id))
263
+
264
+ if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
265
+ target_box = (x1, y1, x2, y2)
266
+
267
+ # Try to match with previous predicted frame
268
+ best_team_id = None
269
+ best_iou = 0.0
270
+
271
+ if prev_predicted_frame in predicted_frame_data:
272
+ team_id, iou = self._find_best_match(
273
+ target_box,
274
+ predicted_frame_data[prev_predicted_frame],
275
+ iou_threshold
276
+ )
277
+ if team_id is not None:
278
+ best_team_id = team_id
279
+ best_iou = iou
280
+
281
+ # Try to match with next predicted frame if available and no good match yet
282
+ if best_team_id is None and next_predicted_frame < len(detection_results):
283
+ if next_predicted_frame in predicted_frame_data:
284
+ team_id, iou = self._find_best_match(
285
+ target_box,
286
+ predicted_frame_data[next_predicted_frame],
287
+ iou_threshold
288
+ )
289
+ if team_id is not None and iou > best_iou:
290
+ best_team_id = team_id
291
+ best_iou = iou
292
+
293
+ # Track interpolation success
294
+ if best_team_id is not None:
295
+ interpolated_count += 1
296
+ else:
297
+ # Fallback: if no match found, predict individually
298
+ crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
299
+ if crop.size > 0:
300
+ team_id = self.team_classifier.predict([crop])[0]
301
+ best_team_id = str(6 + int(team_id))
302
+ fallback_count += 1
303
+
304
+ if best_team_id is not None:
305
+ team_predictions[idx] = best_team_id
306
+
307
+ # Parse boxes with team classification
308
+ for idx, box in enumerate(detection_box):
309
+ x1, y1, x2, y2, conf, cls_id = box.tolist()
310
+ if cls_id == 2 and conf < 0.6:
311
  continue
312
+
313
+ # Check overlap with staff box
314
+ overlap_staff = False
315
+ for idy, boxy in enumerate(detection_box):
316
+ s_x1, s_y1, s_x2, s_y2, s_conf, s_cls_id = boxy.tolist()
317
+ if cls_id == 2 and s_cls_id == 4:
318
+ staff_iou = self._calculate_iou(box[:4], boxy[:4])
319
+ if staff_iou >= 0.8:
320
+ overlap_staff = True
321
  break
322
+ if overlap_staff:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  continue
324
+
325
+ mapped_cls_id = str(int(cls_id))
326
+
327
+ # Override cls_id for players with team prediction
328
+ if idx in team_predictions:
329
+ mapped_cls_id = team_predictions[idx]
330
+ if mapped_cls_id != '4':
331
+ if int(mapped_cls_id) == 3 and conf < 0.5:
332
+ continue
333
  boxes.append(
334
  BoundingBox(
335
  x1=int(x1),
336
  y1=int(y1),
337
  x2=int(x2),
338
  y2=int(y2),
339
+ cls_id=int(mapped_cls_id),
340
  conf=float(conf),
341
  )
342
  )
343
+ # Handle footballs - keep only the best one
344
+ footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
345
+ if len(footballs) > 1:
346
+ best_ball = max(footballs, key=lambda b: b.conf)
347
+ boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
348
+ boxes.append(best_ball)
349
+
350
+ bboxes[offset + frame_id] = boxes
351
+ return bboxes
352
+
353
+
354
+ def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]:
355
+ start = time.time()
356
+ detection_results = self._detect_objects_batch(batch_images)
357
+ end = time.time()
358
+ print(f"Detection time: {end - start}")
359
+ start = time.time()
360
+ bboxes = self._team_classify(detection_results, batch_images, offset)
361
+ end = time.time()
362
+ print(f"Team classify time: {end - start}")
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
  pitch_batch_size = min(self.pitch_batch_size, len(batch_images))
365
  keypoints: Dict[int, List[Tuple[int, int]]] = {}
366
+
367
+ start = time.time()
368
  while True:
 
369
  gc.collect()
370
  if torch.cuda.is_available():
 
371
  torch.cuda.empty_cache()
372
  torch.cuda.synchronize()
373
+ device_str = "cuda"
374
  keypoints_result = process_batch_input(
375
  batch_images,
376
  self.keypoints_model,
 
405
  else:
406
  frame_keypoints = frame_keypoints[:n_keypoints]
407
  keypoints[offset + frame_number_in_batch] = frame_keypoints
 
408
  break
409
+ end = time.time()
410
+ print(f"Keypoint time: {end - start}")
411
+
 
 
 
 
 
 
 
 
 
 
412
 
413
  results: List[TVFrameResult] = []
414
  for frame_number in range(offset, offset + len(batch_images)):
 
423
 
424
  gc.collect()
425
  if torch.cuda.is_available():
 
426
  torch.cuda.empty_cache()
427
  torch.cuda.synchronize()
428
 
osnet_ain.pyc ADDED
Binary file (24.2 kB). View file
 
pitch.py CHANGED
@@ -660,28 +660,10 @@ def get_mapped_keypoints(kp_points):
660
  # mapped_points[key] = value
661
  return mapped_points
662
 
663
- def process_batch_input(frames, model, kp_threshold, device, batch_size=8):
664
  """Process multiple input images in batch"""
665
  # Batch inference
666
  kp_results = inference_batch(frames, model, kp_threshold, device, batch_size)
667
  kp_results = [get_mapped_keypoints(kp) for kp in kp_results]
668
- # Draw results and save
669
- # for i, (frame, kp_points, input_path) in enumerate(zip(frames, kp_results, valid_paths)):
670
- # height, width = frame.shape[:2]
671
-
672
- # # Apply mapping to get standard keypoint IDs
673
- # mapped_kp_points = get_mapped_keypoints(kp_points)
674
-
675
- # for key, value in mapped_kp_points.items():
676
- # x = int(value['x'] * width)
677
- # y = int(value['y'] * height)
678
- # cv2.circle(frame, (x, y), 5, (0, 255, 0), -1) # Green circles
679
- # cv2.putText(frame, str(key), (x+10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
680
-
681
- # # Save result
682
- # output_path = input_path.replace('.png', '_result.png').replace('.jpg', '_result.jpg')
683
- # cv2.imwrite(output_path, frame)
684
-
685
- # print(f"Batch processing complete. Processed {len(frames)} images.")
686
 
687
  return kp_results
 
660
  # mapped_points[key] = value
661
  return mapped_points
662
 
663
+ def process_batch_input(frames, model, kp_threshold, device, batch_size=16):
664
  """Process multiple input images in batch"""
665
  # Batch inference
666
  kp_results = inference_batch(frames, model, kp_threshold, device, batch_size)
667
  kp_results = [get_mapped_keypoints(kp) for kp in kp_results]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
668
 
669
  return kp_results
team_cluster.pyc ADDED
Binary file (7.62 kB). View file
 
utils.pyc ADDED
Binary file (20.6 kB). View file