tarto2 commited on
Commit
424a89a
·
verified ·
1 Parent(s): 43eec27

Update miner.py

Browse files
Files changed (1) hide show
  1. miner.py +433 -433
miner.py CHANGED
@@ -1,434 +1,434 @@
1
- from pathlib import Path
2
- from typing import List, Tuple, Dict
3
- import sys
4
- import os
5
-
6
- from numpy import ndarray
7
- from pydantic import BaseModel
8
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
9
- from keypoint_helper import run_keypoints_post_processing
10
-
11
- from ultralytics import YOLO
12
- from team_cluster import TeamClassifier
13
- from utils import (
14
- BoundingBox,
15
- Constants,
16
- )
17
-
18
- import time
19
- import torch
20
- import gc
21
- from pitch import process_batch_input, get_cls_net
22
- import yaml
23
-
24
-
25
class BoundingBox(BaseModel):
    # NOTE(review): this redefinition shadows the BoundingBox imported from
    # `utils` above — confirm which definition callers actually expect.
    # Axis-aligned box in pixel coordinates: (x1, y1) top-left, (x2, y2) bottom-right.
    x1: int
    y1: int
    x2: int
    y2: int
    cls_id: int  # detection class id; players are re-mapped to team ids 6/7 downstream
    conf: float  # detector confidence score
32
-
33
-
34
class TVFrameResult(BaseModel):
    # Per-frame pipeline output: detections plus pitch keypoints for one video frame.
    frame_id: int
    boxes: List[BoundingBox]
    keypoints: List[Tuple[int, int]]  # (x, y) pixel coords; (0, 0) marks "not detected"
38
-
39
-
40
class Miner:
    """Football video miner: object detection, team classification and pitch keypoints."""

    # Tunables sourced from the shared Constants namespace in `utils`.
    SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
    SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
    SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
    CORNER_INDICES = Constants.CORNER_INDICES
    KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
    CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
    GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
    MIN_SAMPLES_FOR_FIT = 16  # Minimum player crops needed before fitting TeamClassifier
    MAX_SAMPLES_FOR_FIT = 1000  # Maximum samples to avoid overfitting
50
-
51
    def __init__(self, path_hf_repo: Path) -> None:
        """Load detector, team classifier and keypoint model from *path_hf_repo*.

        Never raises: any failure is captured in ``self.health`` so callers can
        inspect the error message instead of handling an exception.
        """
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model_path = path_hf_repo / "football_object_detection.onnx"
            self.bbox_model = YOLO(model_path)

            print("BBox Model Loaded")

            team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
            self.team_classifier = TeamClassifier(
                device=device,
                batch_size=32,
                model_name=str(team_model_path)
            )
            print("Team Classifier Loaded")

            # Team classification state
            self.team_classifier_fitted = False
            self.player_crops_for_fit = []

            model_kp_path = path_hf_repo / 'keypoint'
            config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml'
            # NOTE(review): the file handle opened here is never closed — a
            # `with open(...)` block would be safer.
            cfg_kp = yaml.safe_load(open(config_kp_path, 'r'))

            loaded_state_kp = torch.load(model_kp_path, map_location=device)
            model = get_cls_net(cfg_kp)
            model.load_state_dict(loaded_state_kp)
            model.to(device)
            model.eval()

            self.keypoints_model = model
            self.kp_threshold = 0.1       # keypoint confidence threshold
            self.pitch_batch_size = 4     # frames per keypoint-model forward pass
            self.health = "healthy"
            print("✅ Keypoints Model Loaded")
        except Exception as e:
            # Record the failure so __repr__ can report it.
            self.health = "❌ Miner initialization failed: " + str(e)
            print(self.health)
89
-
90
    def __repr__(self) -> str:
        """Summarise loaded models when healthy; otherwise return the failure message."""
        if self.health == 'healthy':
            return (
                f"health: {self.health}\n"
                f"BBox Model: {type(self.bbox_model).__name__}\n"
                f"Keypoints Model: {type(self.keypoints_model).__name__}"
            )
        else:
            return self.health
99
-
100
    def _calculate_iou(self, box1: Tuple[float, float, float, float],
                       box2: Tuple[float, float, float, float]) -> float:
        """
        Calculate Intersection over Union (IoU) between two bounding boxes.
        Args:
            box1: (x1, y1, x2, y2)
            box2: (x1, y1, x2, y2)
        Returns:
            IoU score (0-1); 0.0 for disjoint or degenerate (zero-area) boxes.
        """
        x1_1, y1_1, x2_1, y2_1 = box1
        x1_2, y1_2, x2_2, y2_2 = box2

        # Calculate intersection area
        x_left = max(x1_1, x1_2)
        y_top = max(y1_1, y1_2)
        x_right = min(x2_1, x2_2)
        y_bottom = min(y2_1, y2_2)

        # Disjoint boxes: the overlap rectangle is empty.
        if x_right < x_left or y_bottom < y_top:
            return 0.0

        intersection_area = (x_right - x_left) * (y_bottom - y_top)

        # Calculate union area
        box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
        box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
        union_area = box1_area + box2_area - intersection_area

        # Guard against division by zero for degenerate boxes.
        if union_area == 0:
            return 0.0

        return intersection_area / union_area
133
-
134
    def _detect_objects_batch(self, decoded_images: List[ndarray]) -> List:
        """Run the YOLO detector over *decoded_images* in chunks of 16 frames.

        Returns one ultralytics result object per input frame, in input order.
        (Annotation corrected: the original declared Dict[int, List[BoundingBox]]
        but a flat list is returned.)
        """
        batch_size = 16
        detection_results = []
        n_frames = len(decoded_images)
        for frame_number in range(0, n_frames, batch_size):
            batch_images = decoded_images[frame_number: frame_number + batch_size]
            detections = self.bbox_model(batch_images, verbose=False, save=False)
            detection_results.extend(detections)

        return detection_results
144
-
145
    def _team_classify(self, detection_results, decoded_images, offset):
        """Assign team ids to player detections and build per-frame box lists.

        Fits the TeamClassifier on player crops harvested from this batch, then
        predicts a team (cls_id 6 or 7) per player box.  Returns a dict mapping
        absolute frame index (offset + local index) -> list[BoundingBox].

        NOTE(review): an empty *detection_results* raises ZeroDivisionError in
        the stats print below — confirm callers never pass an empty batch.
        """
        self.team_classifier_fitted = False
        start = time.time()
        # Collect player crops from first batch for fitting.
        # NOTE(review): duplicates the class constant MAX_SAMPLES_FOR_FIT.
        fit_sample_size = 1000
        player_crops_for_fit = []

        for frame_id in range(len(detection_results)):
            detection_box = detection_results[frame_id].boxes.data
            if len(detection_box) < 4:
                continue  # too few detections to be a useful fitting frame
            # Collect player boxes for team classification fitting (first batch only)
            if len(player_crops_for_fit) < fit_sample_size:
                frame_image = decoded_images[frame_id]
                for box in detection_box:
                    x1, y1, x2, y2, conf, cls_id = box.tolist()
                    if conf < 0.5:
                        continue
                    mapped_cls_id = str(int(cls_id))
                    # Only collect player crops (cls_id = 2)
                    if mapped_cls_id == '2':
                        crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                        if crop.size > 0:
                            player_crops_for_fit.append(crop)

            # Fit team classifier as soon as a full sample has been collected.
            if self.team_classifier and not self.team_classifier_fitted and len(player_crops_for_fit) >= fit_sample_size:
                print(f"Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
                self.team_classifier.fit(player_crops_for_fit)
                self.team_classifier_fitted = True
                break
        # Fallback: fit on whatever was harvested, if above the minimum (16
        # duplicates the class constant MIN_SAMPLES_FOR_FIT).
        if not self.team_classifier_fitted and len(player_crops_for_fit) >= 16:
            print(f"Fallback: Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
            self.team_classifier.fit(player_crops_for_fit)
            self.team_classifier_fitted = True
        end = time.time()
        print(f"Fitting Kmeans time: {end - start}")

        # Second pass: predict teams with configurable frame skipping optimization
        start = time.time()

        # Frame-skipping configuration: 1 = predict every frame (no skipping).
        prediction_interval = 1
        iou_threshold = 0.3  # min IoU to carry a team id onto a skipped frame

        print(f"Team classification - prediction_interval: {prediction_interval}, iou_threshold: {iou_threshold}")

        # Storage for predicted frame results: {frame_id: {box_idx: (bbox, team_id)}}
        predicted_frame_data = {}

        # Step 1: Predict for frames at prediction_interval only
        frames_to_predict = []
        for frame_id in range(len(detection_results)):
            if frame_id % prediction_interval == 0:
                frames_to_predict.append(frame_id)

        print(f"Predicting teams for {len(frames_to_predict)}/{len(detection_results)} frames "
              f"(saving {100 - (len(frames_to_predict) * 100 // len(detection_results))}% compute)")

        for frame_id in frames_to_predict:
            detection_box = detection_results[frame_id].boxes.data
            frame_image = decoded_images[frame_id]

            # Collect player crops for this frame
            frame_player_crops = []
            frame_player_indices = []
            frame_player_boxes = []

            for idx, box in enumerate(detection_box):
                x1, y1, x2, y2, conf, cls_id = box.tolist()
                if cls_id == 2 and conf < 0.6:
                    continue  # drop low-confidence players
                mapped_cls_id = str(int(cls_id))

                # Collect player crops for prediction
                if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
                    crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                    if crop.size > 0:
                        frame_player_crops.append(crop)
                        frame_player_indices.append(idx)
                        frame_player_boxes.append((x1, y1, x2, y2))

            # Predict teams for all players in this frame
            if len(frame_player_crops) > 0:
                team_ids = self.team_classifier.predict(frame_player_crops)
                predicted_frame_data[frame_id] = {}
                for idx, bbox, team_id in zip(frame_player_indices, frame_player_boxes, team_ids):
                    # Map team_id (0,1) to cls_id (6,7)
                    team_cls_id = str(6 + int(team_id))
                    predicted_frame_data[frame_id][idx] = (bbox, team_cls_id)

        # Step 2: Process all frames (interpolate skipped frames)
        fallback_count = 0
        interpolated_count = 0
        bboxes: dict[int, list[BoundingBox]] = {}
        for frame_id in range(len(detection_results)):
            detection_box = detection_results[frame_id].boxes.data
            frame_image = decoded_images[frame_id]
            boxes = []

            team_predictions = {}

            if frame_id % prediction_interval == 0:
                # Predicted frame: use pre-computed predictions
                if frame_id in predicted_frame_data:
                    for idx, (bbox, team_cls_id) in predicted_frame_data[frame_id].items():
                        team_predictions[idx] = team_cls_id
            else:
                # Skipped frame: interpolate from neighboring predicted frames.
                # NOTE(review): self._find_best_match is not defined in this
                # file; this branch is unreachable while prediction_interval
                # is 1 — confirm the helper exists before raising the interval.
                prev_predicted_frame = (frame_id // prediction_interval) * prediction_interval
                next_predicted_frame = prev_predicted_frame + prediction_interval

                # Collect current frame player boxes
                for idx, box in enumerate(detection_box):
                    x1, y1, x2, y2, conf, cls_id = box.tolist()
                    if cls_id == 2 and conf < 0.6:
                        continue
                    mapped_cls_id = str(int(cls_id))

                    if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
                        target_box = (x1, y1, x2, y2)

                        # Try to match with previous predicted frame
                        best_team_id = None
                        best_iou = 0.0

                        if prev_predicted_frame in predicted_frame_data:
                            team_id, iou = self._find_best_match(
                                target_box,
                                predicted_frame_data[prev_predicted_frame],
                                iou_threshold
                            )
                            if team_id is not None:
                                best_team_id = team_id
                                best_iou = iou

                        # Try to match with next predicted frame if available and no good match yet
                        if best_team_id is None and next_predicted_frame < len(detection_results):
                            if next_predicted_frame in predicted_frame_data:
                                team_id, iou = self._find_best_match(
                                    target_box,
                                    predicted_frame_data[next_predicted_frame],
                                    iou_threshold
                                )
                                if team_id is not None and iou > best_iou:
                                    best_team_id = team_id
                                    best_iou = iou

                        # Track interpolation success
                        if best_team_id is not None:
                            interpolated_count += 1
                        else:
                            # Fallback: if no match found, predict individually
                            crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
                            if crop.size > 0:
                                team_id = self.team_classifier.predict([crop])[0]
                                best_team_id = str(6 + int(team_id))
                                fallback_count += 1

                        if best_team_id is not None:
                            team_predictions[idx] = best_team_id

            # Parse boxes with team classification
            for idx, box in enumerate(detection_box):
                x1, y1, x2, y2, conf, cls_id = box.tolist()
                if cls_id == 2 and conf < 0.6:
                    continue

                # Drop players almost fully overlapped by a staff box (cls 4).
                overlap_staff = False
                for idy, boxy in enumerate(detection_box):
                    s_x1, s_y1, s_x2, s_y2, s_conf, s_cls_id = boxy.tolist()
                    if cls_id == 2 and s_cls_id == 4:
                        staff_iou = self._calculate_iou(box[:4], boxy[:4])
                        if staff_iou >= 0.8:
                            overlap_staff = True
                            break
                if overlap_staff:
                    continue

                mapped_cls_id = str(int(cls_id))

                # Override cls_id for players with team prediction
                if idx in team_predictions:
                    mapped_cls_id = team_predictions[idx]
                # Staff boxes (cls 4) are excluded from the output entirely.
                if mapped_cls_id != '4':
                    # Class 3 detections need a higher bar than the generic one.
                    if int(mapped_cls_id) == 3 and conf < 0.5:
                        continue
                    boxes.append(
                        BoundingBox(
                            x1=int(x1),
                            y1=int(y1),
                            x2=int(x2),
                            y2=int(y2),
                            cls_id=int(mapped_cls_id),
                            conf=float(conf),
                        )
                    )
            # Handle footballs - keep only the highest-confidence one (cls 0).
            footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
            if len(footballs) > 1:
                best_ball = max(footballs, key=lambda b: b.conf)
                boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
                boxes.append(best_ball)

            bboxes[offset + frame_id] = boxes
        return bboxes
353
-
354
-
355
    def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]:
        """Detect objects, classify teams and extract pitch keypoints for a batch.

        Args:
            batch_images: decoded frames (arrays with .shape[:2] == (H, W)), in order.
            offset: absolute index of the first frame in this batch.
            n_keypoints: number of (x, y) keypoints each result must contain.

        Returns:
            One TVFrameResult per input frame; absent keypoints are (0, 0).
        """
        start = time.time()
        detection_results = self._detect_objects_batch(batch_images)
        end = time.time()
        print(f"Detection time: {end - start}")
        start = time.time()
        bboxes = self._team_classify(detection_results, batch_images, offset)
        end = time.time()
        print(f"Team classify time: {end - start}")

        pitch_batch_size = min(self.pitch_batch_size, len(batch_images))
        keypoints: Dict[int, List[Tuple[int, int]]] = {}

        start = time.time()
        # NOTE(review): this loop runs exactly once (unconditional break below).
        while True:
            # Free as much GPU memory as possible before the keypoint pass.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            # NOTE(review): "cuda" is passed even when CUDA is unavailable —
            # confirm process_batch_input tolerates CPU-only hosts.
            device_str = "cuda"
            keypoints_result = process_batch_input(
                batch_images,
                self.keypoints_model,
                self.kp_threshold,
                device_str,
                batch_size=pitch_batch_size,
            )
            if keypoints_result is not None and len(keypoints_result) > 0:
                for frame_number_in_batch, kp_dict in enumerate(keypoints_result):
                    if frame_number_in_batch >= len(batch_images):
                        break
                    frame_keypoints: List[Tuple[int, int]] = []
                    try:
                        height, width = batch_images[frame_number_in_batch].shape[:2]
                        if kp_dict is not None and isinstance(kp_dict, dict):
                            # Keypoint ids are 1-based; emit 32 slots, (0, 0) when absent.
                            for idx in range(32):
                                x, y = 0, 0
                                kp_idx = idx + 1
                                if kp_idx in kp_dict:
                                    try:
                                        kp_data = kp_dict[kp_idx]
                                        if isinstance(kp_data, dict) and "x" in kp_data and "y" in kp_data:
                                            # Normalised coords -> pixel coords.
                                            x = int(kp_data["x"] * width)
                                            y = int(kp_data["y"] * height)
                                    except (KeyError, TypeError, ValueError):
                                        pass
                                frame_keypoints.append((x, y))
                    except (IndexError, ValueError, AttributeError):
                        frame_keypoints = [(0, 0)] * 32
                    # Pad or trim to exactly n_keypoints entries.
                    if len(frame_keypoints) < n_keypoints:
                        frame_keypoints.extend([(0, 0)] * (n_keypoints - len(frame_keypoints)))
                    else:
                        frame_keypoints = frame_keypoints[:n_keypoints]
                    keypoints[offset + frame_number_in_batch] = frame_keypoints
            break
        end = time.time()
        print(f"Keypoint time: {end - start}")

        results: List[TVFrameResult] = []
        for frame_number in range(offset, offset + len(batch_images)):
            frame_boxes = bboxes.get(frame_number, [])
            frame_keypoints = keypoints.get(frame_number, [(0, 0) for _ in range(n_keypoints)])
            result = TVFrameResult(
                frame_id=frame_number,
                boxes=frame_boxes,
                keypoints=frame_keypoints,
            )
            results.append(result)

        if len(batch_images) > 0:
            h, w = batch_images[0].shape[:2]
            results = run_keypoints_post_processing(results, w, h)

        # Release GPU memory before returning to the caller.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        return results
 
1
+ from pathlib import Path
2
+ from typing import List, Tuple, Dict
3
+ import sys
4
+ import os
5
+
6
+ from numpy import ndarray
7
+ from pydantic import BaseModel
8
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
9
+ from keypoint_helper import run_keypoints_post_processing
10
+
11
+ from ultralytics import YOLO
12
+ from team_cluster import TeamClassifier
13
+ from utils import (
14
+ BoundingBox,
15
+ Constants,
16
+ )
17
+
18
+ import time
19
+ import torch
20
+ import gc
21
+ from pitch import process_batch_input, get_cls_net
22
+ import yaml
23
+
24
+
25
class BoundingBox(BaseModel):
    # NOTE(review): this redefinition shadows the BoundingBox imported from
    # `utils` above — confirm which definition callers actually expect.
    # Axis-aligned box in pixel coordinates: (x1, y1) top-left, (x2, y2) bottom-right.
    x1: int
    y1: int
    x2: int
    y2: int
    cls_id: int  # detection class id; players are re-mapped to team ids 6/7 downstream
    conf: float  # detector confidence score
32
+
33
+
34
class TVFrameResult(BaseModel):
    # Per-frame pipeline output: detections plus pitch keypoints for one video frame.
    frame_id: int
    boxes: List[BoundingBox]
    keypoints: List[Tuple[int, int]]  # (x, y) pixel coords; (0, 0) marks "not detected"
38
+
39
+
40
class Miner:
    """Football video miner: object detection, team classification and pitch keypoints."""

    # Tunables sourced from the shared Constants namespace in `utils`.
    SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
    SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
    SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
    CORNER_INDICES = Constants.CORNER_INDICES
    KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
    CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
    GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
    MIN_SAMPLES_FOR_FIT = 16  # Minimum player crops needed before fitting TeamClassifier
    MAX_SAMPLES_FOR_FIT = 700  # Maximum samples to avoid overfitting
50
+
51
+ def __init__(self, path_hf_repo: Path) -> None:
52
+ try:
53
+ device = "cuda" if torch.cuda.is_available() else "cpu"
54
+ model_path = path_hf_repo / "football_object_detection.onnx"
55
+ self.bbox_model = YOLO(model_path)
56
+
57
+ print("BBox Model Loaded")
58
+
59
+ team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
60
+ self.team_classifier = TeamClassifier(
61
+ device=device,
62
+ batch_size=32,
63
+ model_name=str(team_model_path)
64
+ )
65
+ print("Team Classifier Loaded")
66
+
67
+ # Team classification state
68
+ self.team_classifier_fitted = False
69
+ self.player_crops_for_fit = []
70
+
71
+ model_kp_path = path_hf_repo / 'keypoint'
72
+ config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml'
73
+ cfg_kp = yaml.safe_load(open(config_kp_path, 'r'))
74
+
75
+ loaded_state_kp = torch.load(model_kp_path, map_location=device)
76
+ model = get_cls_net(cfg_kp)
77
+ model.load_state_dict(loaded_state_kp)
78
+ model.to(device)
79
+ model.eval()
80
+
81
+ self.keypoints_model = model
82
+ self.kp_threshold = 0.1
83
+ self.pitch_batch_size = 4
84
+ self.health = "healthy"
85
+ print("✅ Keypoints Model Loaded")
86
+ except Exception as e:
87
+ self.health = "❌ Miner initialization failed: " + str(e)
88
+ print(self.health)
89
+
90
+ def __repr__(self) -> str:
91
+ if self.health == 'healthy':
92
+ return (
93
+ f"health: {self.health}\n"
94
+ f"BBox Model: {type(self.bbox_model).__name__}\n"
95
+ f"Keypoints Model: {type(self.keypoints_model).__name__}"
96
+ )
97
+ else:
98
+ return self.health
99
+
100
+ def _calculate_iou(self, box1: Tuple[float, float, float, float],
101
+ box2: Tuple[float, float, float, float]) -> float:
102
+ """
103
+ Calculate Intersection over Union (IoU) between two bounding boxes.
104
+ Args:
105
+ box1: (x1, y1, x2, y2)
106
+ box2: (x1, y1, x2, y2)
107
+ Returns:
108
+ IoU score (0-1)
109
+ """
110
+ x1_1, y1_1, x2_1, y2_1 = box1
111
+ x1_2, y1_2, x2_2, y2_2 = box2
112
+
113
+ # Calculate intersection area
114
+ x_left = max(x1_1, x1_2)
115
+ y_top = max(y1_1, y1_2)
116
+ x_right = min(x2_1, x2_2)
117
+ y_bottom = min(y2_1, y2_2)
118
+
119
+ if x_right < x_left or y_bottom < y_top:
120
+ return 0.0
121
+
122
+ intersection_area = (x_right - x_left) * (y_bottom - y_top)
123
+
124
+ # Calculate union area
125
+ box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
126
+ box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
127
+ union_area = box1_area + box2_area - intersection_area
128
+
129
+ if union_area == 0:
130
+ return 0.0
131
+
132
+ return intersection_area / union_area
133
+
134
+ def _detect_objects_batch(self, decoded_images: List[ndarray]) -> Dict[int, List[BoundingBox]]:
135
+ batch_size = 16
136
+ detection_results = []
137
+ n_frames = len(decoded_images)
138
+ for frame_number in range(0, n_frames, batch_size):
139
+ batch_images = decoded_images[frame_number: frame_number + batch_size]
140
+ detections = self.bbox_model(batch_images, verbose=False, save=False)
141
+ detection_results.extend(detections)
142
+
143
+ return detection_results
144
+
145
+ def _team_classify(self, detection_results, decoded_images, offset):
146
+ self.team_classifier_fitted = False
147
+ start = time.time()
148
+ # Collect player crops from first batch for fitting
149
+ fit_sample_size = 700
150
+ player_crops_for_fit = []
151
+
152
+ for frame_id in range(len(detection_results)):
153
+ detection_box = detection_results[frame_id].boxes.data
154
+ if len(detection_box) < 4:
155
+ continue
156
+ # Collect player boxes for team classification fitting (first batch only)
157
+ if len(player_crops_for_fit) < fit_sample_size:
158
+ frame_image = decoded_images[frame_id]
159
+ for box in detection_box:
160
+ x1, y1, x2, y2, conf, cls_id = box.tolist()
161
+ if conf < 0.5:
162
+ continue
163
+ mapped_cls_id = str(int(cls_id))
164
+ # Only collect player crops (cls_id = 2)
165
+ if mapped_cls_id == '2':
166
+ crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
167
+ if crop.size > 0:
168
+ player_crops_for_fit.append(crop)
169
+
170
+ # Fit team classifier after collecting samples
171
+ if self.team_classifier and not self.team_classifier_fitted and len(player_crops_for_fit) >= fit_sample_size:
172
+ print(f"Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
173
+ self.team_classifier.fit(player_crops_for_fit)
174
+ self.team_classifier_fitted = True
175
+ break
176
+ if not self.team_classifier_fitted and len(player_crops_for_fit) >= 16:
177
+ print(f"Fallback: Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
178
+ self.team_classifier.fit(player_crops_for_fit)
179
+ self.team_classifier_fitted = True
180
+ end = time.time()
181
+ print(f"Fitting Kmeans time: {end - start}")
182
+
183
+ # Second pass: predict teams with configurable frame skipping optimization
184
+ start = time.time()
185
+
186
+ # Get configuration for frame skipping
187
+ prediction_interval = 1 # Default: predict every 2 frames
188
+ iou_threshold = 0.3
189
+
190
+ print(f"Team classification - prediction_interval: {prediction_interval}, iou_threshold: {iou_threshold}")
191
+
192
+ # Storage for predicted frame results: {frame_id: {box_idx: (bbox, team_id)}}
193
+ predicted_frame_data = {}
194
+
195
+ # Step 1: Predict for frames at prediction_interval only
196
+ frames_to_predict = []
197
+ for frame_id in range(len(detection_results)):
198
+ if frame_id % prediction_interval == 0:
199
+ frames_to_predict.append(frame_id)
200
+
201
+ print(f"Predicting teams for {len(frames_to_predict)}/{len(detection_results)} frames "
202
+ f"(saving {100 - (len(frames_to_predict) * 100 // len(detection_results))}% compute)")
203
+
204
+ for frame_id in frames_to_predict:
205
+ detection_box = detection_results[frame_id].boxes.data
206
+ frame_image = decoded_images[frame_id]
207
+
208
+ # Collect player crops for this frame
209
+ frame_player_crops = []
210
+ frame_player_indices = []
211
+ frame_player_boxes = []
212
+
213
+ for idx, box in enumerate(detection_box):
214
+ x1, y1, x2, y2, conf, cls_id = box.tolist()
215
+ if cls_id == 2 and conf < 0.6:
216
+ continue
217
+ mapped_cls_id = str(int(cls_id))
218
+
219
+ # Collect player crops for prediction
220
+ if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
221
+ crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
222
+ if crop.size > 0:
223
+ frame_player_crops.append(crop)
224
+ frame_player_indices.append(idx)
225
+ frame_player_boxes.append((x1, y1, x2, y2))
226
+
227
+ # Predict teams for all players in this frame
228
+ if len(frame_player_crops) > 0:
229
+ team_ids = self.team_classifier.predict(frame_player_crops)
230
+ predicted_frame_data[frame_id] = {}
231
+ for idx, bbox, team_id in zip(frame_player_indices, frame_player_boxes, team_ids):
232
+ # Map team_id (0,1) to cls_id (6,7)
233
+ team_cls_id = str(6 + int(team_id))
234
+ predicted_frame_data[frame_id][idx] = (bbox, team_cls_id)
235
+
236
+ # Step 2: Process all frames (interpolate skipped frames)
237
+ fallback_count = 0
238
+ interpolated_count = 0
239
+ bboxes: dict[int, list[BoundingBox]] = {}
240
+ for frame_id in range(len(detection_results)):
241
+ detection_box = detection_results[frame_id].boxes.data
242
+ frame_image = decoded_images[frame_id]
243
+ boxes = []
244
+
245
+ team_predictions = {}
246
+
247
+ if frame_id % prediction_interval == 0:
248
+ # Predicted frame: use pre-computed predictions
249
+ if frame_id in predicted_frame_data:
250
+ for idx, (bbox, team_cls_id) in predicted_frame_data[frame_id].items():
251
+ team_predictions[idx] = team_cls_id
252
+ else:
253
+ # Skipped frame: interpolate from neighboring predicted frames
254
+ # Find nearest predicted frames
255
+ prev_predicted_frame = (frame_id // prediction_interval) * prediction_interval
256
+ next_predicted_frame = prev_predicted_frame + prediction_interval
257
+
258
+ # Collect current frame player boxes
259
+ for idx, box in enumerate(detection_box):
260
+ x1, y1, x2, y2, conf, cls_id = box.tolist()
261
+ if cls_id == 2 and conf < 0.6:
262
+ continue
263
+ mapped_cls_id = str(int(cls_id))
264
+
265
+ if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
266
+ target_box = (x1, y1, x2, y2)
267
+
268
+ # Try to match with previous predicted frame
269
+ best_team_id = None
270
+ best_iou = 0.0
271
+
272
+ if prev_predicted_frame in predicted_frame_data:
273
+ team_id, iou = self._find_best_match(
274
+ target_box,
275
+ predicted_frame_data[prev_predicted_frame],
276
+ iou_threshold
277
+ )
278
+ if team_id is not None:
279
+ best_team_id = team_id
280
+ best_iou = iou
281
+
282
+ # Try to match with next predicted frame if available and no good match yet
283
+ if best_team_id is None and next_predicted_frame < len(detection_results):
284
+ if next_predicted_frame in predicted_frame_data:
285
+ team_id, iou = self._find_best_match(
286
+ target_box,
287
+ predicted_frame_data[next_predicted_frame],
288
+ iou_threshold
289
+ )
290
+ if team_id is not None and iou > best_iou:
291
+ best_team_id = team_id
292
+ best_iou = iou
293
+
294
+ # Track interpolation success
295
+ if best_team_id is not None:
296
+ interpolated_count += 1
297
+ else:
298
+ # Fallback: if no match found, predict individually
299
+ crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
300
+ if crop.size > 0:
301
+ team_id = self.team_classifier.predict([crop])[0]
302
+ best_team_id = str(6 + int(team_id))
303
+ fallback_count += 1
304
+
305
+ if best_team_id is not None:
306
+ team_predictions[idx] = best_team_id
307
+
308
+ # Parse boxes with team classification
309
+ for idx, box in enumerate(detection_box):
310
+ x1, y1, x2, y2, conf, cls_id = box.tolist()
311
+ if cls_id == 2 and conf < 0.6:
312
+ continue
313
+
314
+ # Check overlap with staff box
315
+ overlap_staff = False
316
+ for idy, boxy in enumerate(detection_box):
317
+ s_x1, s_y1, s_x2, s_y2, s_conf, s_cls_id = boxy.tolist()
318
+ if cls_id == 2 and s_cls_id == 4:
319
+ staff_iou = self._calculate_iou(box[:4], boxy[:4])
320
+ if staff_iou >= 0.8:
321
+ overlap_staff = True
322
+ break
323
+ if overlap_staff:
324
+ continue
325
+
326
+ mapped_cls_id = str(int(cls_id))
327
+
328
+ # Override cls_id for players with team prediction
329
+ if idx in team_predictions:
330
+ mapped_cls_id = team_predictions[idx]
331
+ if mapped_cls_id != '4':
332
+ if int(mapped_cls_id) == 3 and conf < 0.5:
333
+ continue
334
+ boxes.append(
335
+ BoundingBox(
336
+ x1=int(x1),
337
+ y1=int(y1),
338
+ x2=int(x2),
339
+ y2=int(y2),
340
+ cls_id=int(mapped_cls_id),
341
+ conf=float(conf),
342
+ )
343
+ )
344
+ # Handle footballs - keep only the best one
345
+ footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
346
+ if len(footballs) > 1:
347
+ best_ball = max(footballs, key=lambda b: b.conf)
348
+ boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
349
+ boxes.append(best_ball)
350
+
351
+ bboxes[offset + frame_id] = boxes
352
+ return bboxes
353
+
354
+
355
+ def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]:
356
+ start = time.time()
357
+ detection_results = self._detect_objects_batch(batch_images)
358
+ end = time.time()
359
+ print(f"Detection time: {end - start}")
360
+ start = time.time()
361
+ bboxes = self._team_classify(detection_results, batch_images, offset)
362
+ end = time.time()
363
+ print(f"Team classify time: {end - start}")
364
+
365
+ pitch_batch_size = min(self.pitch_batch_size, len(batch_images))
366
+ keypoints: Dict[int, List[Tuple[int, int]]] = {}
367
+
368
+ start = time.time()
369
+ while True:
370
+ gc.collect()
371
+ if torch.cuda.is_available():
372
+ torch.cuda.empty_cache()
373
+ torch.cuda.synchronize()
374
+ device_str = "cuda"
375
+ keypoints_result = process_batch_input(
376
+ batch_images,
377
+ self.keypoints_model,
378
+ self.kp_threshold,
379
+ device_str,
380
+ batch_size=pitch_batch_size,
381
+ )
382
+ if keypoints_result is not None and len(keypoints_result) > 0:
383
+ for frame_number_in_batch, kp_dict in enumerate(keypoints_result):
384
+ if frame_number_in_batch >= len(batch_images):
385
+ break
386
+ frame_keypoints: List[Tuple[int, int]] = []
387
+ try:
388
+ height, width = batch_images[frame_number_in_batch].shape[:2]
389
+ if kp_dict is not None and isinstance(kp_dict, dict):
390
+ for idx in range(32):
391
+ x, y = 0, 0
392
+ kp_idx = idx + 1
393
+ if kp_idx in kp_dict:
394
+ try:
395
+ kp_data = kp_dict[kp_idx]
396
+ if isinstance(kp_data, dict) and "x" in kp_data and "y" in kp_data:
397
+ x = int(kp_data["x"] * width)
398
+ y = int(kp_data["y"] * height)
399
+ except (KeyError, TypeError, ValueError):
400
+ pass
401
+ frame_keypoints.append((x, y))
402
+ except (IndexError, ValueError, AttributeError):
403
+ frame_keypoints = [(0, 0)] * 32
404
+ if len(frame_keypoints) < n_keypoints:
405
+ frame_keypoints.extend([(0, 0)] * (n_keypoints - len(frame_keypoints)))
406
+ else:
407
+ frame_keypoints = frame_keypoints[:n_keypoints]
408
+ keypoints[offset + frame_number_in_batch] = frame_keypoints
409
+ break
410
+ end = time.time()
411
+ print(f"Keypoint time: {end - start}")
412
+
413
+
414
+ results: List[TVFrameResult] = []
415
+ for frame_number in range(offset, offset + len(batch_images)):
416
+ frame_boxes = bboxes.get(frame_number, [])
417
+ frame_keypoints = keypoints.get(frame_number, [(0, 0) for _ in range(n_keypoints)])
418
+ result = TVFrameResult(
419
+ frame_id=frame_number,
420
+ boxes=frame_boxes,
421
+ keypoints=frame_keypoints,
422
+ )
423
+ results.append(result)
424
+
425
+ if len(batch_images) > 0:
426
+ h, w = batch_images[0].shape[:2]
427
+ results = run_keypoints_post_processing(results, w, h)
428
+
429
+ gc.collect()
430
+ if torch.cuda.is_available():
431
+ torch.cuda.empty_cache()
432
+ torch.cuda.synchronize()
433
+
434
  return results