tarto2 committed on
Commit
779104d
·
verified ·
1 Parent(s): 8fa61d8

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. keypoint_helper.py +116 -0
  2. miner.py +158 -364
keypoint_helper.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from typing import List, Tuple, Sequence, Any
5
+
6
+ FOOTBALL_KEYPOINTS: list[tuple[int, int]] = [
7
+ (0, 0), # 1
8
+ (0, 0), # 2
9
+ (0, 0), # 3
10
+ (0, 0), # 4
11
+ (0, 0), # 5
12
+ (0, 0), # 6
13
+
14
+ (0, 0), # 7
15
+ (0, 0), # 8
16
+ (0, 0), # 9
17
+
18
+ (0, 0), # 10
19
+ (0, 0), # 11
20
+ (0, 0), # 12
21
+ (0, 0), # 13
22
+
23
+ (0, 0), # 14
24
+ (527, 283), # 15
25
+ (527, 403), # 16
26
+ (0, 0), # 17
27
+
28
+ (0, 0), # 18
29
+ (0, 0), # 19
30
+ (0, 0), # 20
31
+ (0, 0), # 21
32
+
33
+ (0, 0), # 22
34
+
35
+ (0, 0), # 23
36
+ (0, 0), # 24
37
+
38
+ (0, 0), # 25
39
+ (0, 0), # 26
40
+ (0, 0), # 27
41
+ (0, 0), # 28
42
+ (0, 0), # 29
43
+ (0, 0), # 30
44
+
45
+ (405, 340), # 31
46
+ (645, 340), # 32
47
+ ]
48
+
49
+ def convert_keypoints_to_val_format(keypoints):
50
+ return [tuple(int(x) for x in pair) for pair in keypoints]
51
+
52
+ def predict_failed_indices(results_frames: Sequence[Any]) -> list[int]:
53
+
54
+ max_frames = len(results_frames)
55
+ if max_frames == 0:
56
+ return []
57
+
58
+ failed_indices: list[int] = []
59
+ for frame_index, frame_result in enumerate(results_frames):
60
+ frame_keypoints = getattr(frame_result, "keypoints", []) or []
61
+ non_zero_count = sum(1 for (x, y) in frame_keypoints if int(x) != 0 and int(y) != 0)
62
+ if non_zero_count <= 4:
63
+ failed_indices.append(frame_index)
64
+ return failed_indices
65
+
66
+ def _generate_sparse_template_keypoints(frame_width: int, frame_height: int) -> list[tuple[int, int]]:
67
+ template_max_x, template_max_y = (1045, 675)
68
+ sx = float(frame_width) / float(template_max_x if template_max_x != 0 else 1)
69
+ sy = float(frame_height) / float(template_max_y if template_max_y != 0 else 1)
70
+ scaled: list[tuple[int, int]] = []
71
+ for i in range(32):
72
+ tx, ty = FOOTBALL_KEYPOINTS[i]
73
+ x_scaled = int(round(tx * sx))
74
+ y_scaled = int(round(ty * sy))
75
+ scaled.append((x_scaled, y_scaled))
76
+ return scaled
77
+
78
+ def fix_keypoints(
79
+ results_frames: Sequence[Any],
80
+ failed_indices: Sequence[int],
81
+ frame_width: int,
82
+ frame_height: int,
83
+ ) -> list[Any]:
84
+ max_frames = len(results_frames)
85
+ if max_frames == 0:
86
+ return list(results_frames)
87
+
88
+ failed_set = set(int(i) for i in failed_indices)
89
+ all_indices = list(range(max_frames))
90
+ successful_indices = [i for i in all_indices if i not in failed_set]
91
+
92
+ if len(successful_indices) == 0:
93
+ sparse_template = _generate_sparse_template_keypoints(frame_width, frame_height)
94
+ for frame_result in results_frames:
95
+ setattr(frame_result, "keypoints", list(convert_keypoints_to_val_format(sparse_template)))
96
+ return list(results_frames)
97
+
98
+ seed_index = successful_indices[0]
99
+ seed_kps_raw = getattr(results_frames[seed_index], "keypoints", []) or []
100
+ last_success_kps = convert_keypoints_to_val_format(seed_kps_raw)
101
+
102
+ for frame_index in range(max_frames):
103
+ frame_result = results_frames[frame_index]
104
+ if frame_index in failed_set:
105
+ setattr(frame_result, "keypoints", list(last_success_kps))
106
+ else:
107
+ current_kps_raw = getattr(frame_result, "keypoints", []) or []
108
+ current_kps = convert_keypoints_to_val_format(current_kps_raw)
109
+ setattr(frame_result, "keypoints", list(current_kps))
110
+ last_success_kps = current_kps
111
+
112
+ return list(results_frames)
113
+
114
+ def run_keypoints_post_processing(results_frames: Sequence[Any], frame_width: int, frame_height: int) -> list[Any]:
115
+ failed_indices = predict_failed_indices(results_frames)
116
+ return fix_keypoints(results_frames, failed_indices, frame_width, frame_height)
miner.py CHANGED
@@ -1,26 +1,23 @@
1
  from pathlib import Path
2
- from typing import List, Tuple, Dict, Optional
3
  import sys
4
  import os
5
 
6
  from numpy import ndarray
7
  from pydantic import BaseModel
8
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 
9
 
10
  from ultralytics import YOLO
11
  from team_cluster import TeamClassifier
12
  from utils import (
13
  BoundingBox,
14
  Constants,
15
- classify_teams_batch,
16
  )
17
 
18
  import time
19
  import torch
20
  import gc
21
- import cv2
22
- import numpy as np
23
- from collections import defaultdict
24
  from pitch import process_batch_input, get_cls_net
25
  import yaml
26
 
@@ -49,7 +46,7 @@ class Miner:
49
  CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
50
  GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
51
  MIN_SAMPLES_FOR_FIT = 16 # Minimum player crops needed before fitting TeamClassifier
52
- MAX_SAMPLES_FOR_FIT = 1000 # Maximum samples to avoid overfitting
53
 
54
  def __init__(self, path_hf_repo: Path) -> None:
55
  try:
@@ -57,7 +54,7 @@ class Miner:
57
  model_path = path_hf_repo / "football_object_detection.onnx"
58
  self.bbox_model = YOLO(model_path)
59
 
60
- print(f"BBox Model Loaded: class name {self.bbox_model.names}")
61
 
62
  team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
63
  self.team_classifier = TeamClassifier(
@@ -71,8 +68,6 @@ class Miner:
71
  self.team_classifier_fitted = False
72
  self.player_crops_for_fit = []
73
 
74
- # self.keypoints_model = YOLO(path_hf_repo / "keypoint.pt")
75
-
76
  model_kp_path = path_hf_repo / 'keypoint'
77
  config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml'
78
  cfg_kp = yaml.safe_load(open(config_kp_path, 'r'))
@@ -84,8 +79,6 @@ class Miner:
84
  model.eval()
85
 
86
  self.keypoints_model = model
87
- print("Keypoints Model (keypoint.pt) Loaded")
88
-
89
  self.kp_threshold = 0.1
90
  self.pitch_batch_size = 4
91
  self.health = "healthy"
@@ -138,109 +131,6 @@ class Miner:
138
 
139
  return intersection_area / union_area
140
 
141
- def _extract_jersey_region(self, crop: ndarray) -> ndarray:
142
- """
143
- Extract jersey region (upper body) from player crop.
144
- For close-ups, focuses on upper 60%, for distant shots uses full crop.
145
- """
146
- if crop is None or crop.size == 0:
147
- return crop
148
-
149
- h, w = crop.shape[:2]
150
- if h < 10 or w < 10:
151
- return crop
152
-
153
- # For close-up shots, extract upper body (jersey region)
154
- is_closeup = h > 100 or (h * w) > 12000
155
- if is_closeup:
156
- # Upper 60% of the crop (jersey area, avoiding shorts)
157
- jersey_top = 0
158
- jersey_bottom = int(h * 0.60)
159
- jersey_left = max(0, int(w * 0.05))
160
- jersey_right = min(w, int(w * 0.95))
161
- return crop[jersey_top:jersey_bottom, jersey_left:jersey_right]
162
- return crop
163
-
164
- def _extract_color_signature(self, crop: ndarray) -> Optional[np.ndarray]:
165
- """
166
- Extract color signature from jersey region using HSV and LAB color spaces.
167
- Returns a feature vector with dominant colors and color statistics.
168
- """
169
- if crop is None or crop.size == 0:
170
- return None
171
-
172
- jersey_region = self._extract_jersey_region(crop)
173
- if jersey_region.size == 0:
174
- return None
175
-
176
- try:
177
- # Convert to HSV and LAB color spaces
178
- hsv = cv2.cvtColor(jersey_region, cv2.COLOR_BGR2HSV)
179
- lab = cv2.cvtColor(jersey_region, cv2.COLOR_BGR2LAB)
180
-
181
- # Reshape for processing
182
- hsv_flat = hsv.reshape(-1, 3).astype(np.float32)
183
- lab_flat = lab.reshape(-1, 3).astype(np.float32)
184
-
185
- # Compute statistics for HSV
186
- hsv_mean = np.mean(hsv_flat, axis=0) / 255.0
187
- hsv_std = np.std(hsv_flat, axis=0) / 255.0
188
-
189
- # Compute statistics for LAB
190
- lab_mean = np.mean(lab_flat, axis=0) / 255.0
191
- lab_std = np.std(lab_flat, axis=0) / 255.0
192
-
193
- # Dominant color (most frequent hue)
194
- hue_hist, _ = np.histogram(hsv_flat[:, 0], bins=36, range=(0, 180))
195
- dominant_hue = np.argmax(hue_hist) * 5 # Convert to hue value
196
-
197
- # Combine features
198
- color_features = np.concatenate([
199
- hsv_mean,
200
- hsv_std,
201
- lab_mean[:2], # L and A channels (B is less informative)
202
- lab_std[:2],
203
- [dominant_hue / 180.0] # Normalized dominant hue
204
- ])
205
-
206
- return color_features
207
- except Exception as e:
208
- print(f"Error extracting color signature: {e}")
209
- return None
210
-
211
- def _get_spatial_position(self, bbox: Tuple[float, float, float, float],
212
- frame_width: int, frame_height: int) -> Tuple[float, float]:
213
- """
214
- Get normalized spatial position of player on the pitch.
215
- Returns (x_normalized, y_normalized) where 0,0 is top-left.
216
- """
217
- x1, y1, x2, y2 = bbox
218
- center_x = (x1 + x2) / 2.0
219
- center_y = (y1 + y2) / 2.0
220
-
221
- # Normalize to [0, 1]
222
- x_norm = center_x / frame_width if frame_width > 0 else 0.5
223
- y_norm = center_y / frame_height if frame_height > 0 else 0.5
224
-
225
- return (x_norm, y_norm)
226
-
227
- def _find_best_match(self, target_box: Tuple[float, float, float, float],
228
- predicted_frame_data: Dict[int, Tuple[Tuple, str]],
229
- iou_threshold: float) -> Tuple[Optional[str], float]:
230
- """
231
- Find best matching box in predicted frame data using IoU.
232
- """
233
- best_iou = 0.0
234
- best_team_id = None
235
-
236
- for idx, (bbox, team_cls_id) in predicted_frame_data.items():
237
- iou = self._calculate_iou(target_box, bbox)
238
- if iou > best_iou and iou >= iou_threshold:
239
- best_iou = iou
240
- best_team_id = team_cls_id
241
-
242
- return (best_team_id, best_iou)
243
-
244
  def _detect_objects_batch(self, decoded_images: List[ndarray]) -> Dict[int, List[BoundingBox]]:
245
  batch_size = 16
246
  detection_results = []
@@ -253,203 +143,175 @@ class Miner:
253
  return detection_results
254
 
255
  def _team_classify(self, detection_results, decoded_images, offset):
256
- """
257
- Hybrid team classification combining:
258
- 1. Appearance features (OSNet)
259
- 2. Color signatures (HSV/LAB)
260
- 3. Spatial priors (left/right side of pitch)
261
- 4. Temporal tracking (same player = same team)
262
- """
263
  start = time.time()
264
-
265
- # Phase 1: Collect samples and fit appearance-based classifier
266
- fit_sample_size = min(self.MAX_SAMPLES_FOR_FIT, len(detection_results) * 10)
267
  player_crops_for_fit = []
268
-
269
  for frame_id in range(len(detection_results)):
270
  detection_box = detection_results[frame_id].boxes.data
271
  if len(detection_box) < 4:
272
  continue
273
-
274
  if len(player_crops_for_fit) < fit_sample_size:
275
  frame_image = decoded_images[frame_id]
276
  for box in detection_box:
277
  x1, y1, x2, y2, conf, cls_id = box.tolist()
278
- if conf < 0.5 or cls_id != 2:
279
  continue
280
- crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
281
- if crop.size > 0:
282
- player_crops_for_fit.append(crop)
283
-
 
 
 
 
284
  if self.team_classifier and not self.team_classifier_fitted and len(player_crops_for_fit) >= fit_sample_size:
285
- print(f"Fitting TeamClassifier (OSNet) with {len(player_crops_for_fit)} player crops")
286
  self.team_classifier.fit(player_crops_for_fit)
287
  self.team_classifier_fitted = True
288
  break
289
-
290
- if not self.team_classifier_fitted and len(player_crops_for_fit) >= self.MIN_SAMPLES_FOR_FIT:
291
  print(f"Fallback: Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
292
  self.team_classifier.fit(player_crops_for_fit)
293
  self.team_classifier_fitted = True
294
-
295
- print(f"Fitting time: {time.time() - start:.2f}s")
296
-
297
- # Phase 2: Hybrid classification for all frames
298
  start = time.time()
299
- bboxes: dict[int, list[BoundingBox]] = {}
300
-
301
- # Temporal tracking: {track_id: (team_id, confidence, last_frame)}
302
- player_tracks: Dict[Tuple, Tuple[int, float, int]] = {}
303
-
304
- # Spatial team assignment: track which team is on which side
305
- left_side_team = None
306
- right_side_team = None
307
-
 
 
 
308
  for frame_id in range(len(detection_results)):
 
 
 
 
 
 
 
309
  detection_box = detection_results[frame_id].boxes.data
310
  frame_image = decoded_images[frame_id]
311
- frame_h, frame_w = frame_image.shape[:2]
312
- boxes = []
313
-
314
- # Collect all players in this frame
315
- player_data = [] # (idx, crop, bbox, spatial_pos, color_sig)
316
-
317
  for idx, box in enumerate(detection_box):
318
  x1, y1, x2, y2, conf, cls_id = box.tolist()
319
- if cls_id != 2 or conf < 0.6:
320
- continue
321
-
322
- crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
323
- if crop.size == 0:
324
  continue
325
-
326
- bbox = (x1, y1, x2, y2)
327
- spatial_pos = self._get_spatial_position(bbox, frame_w, frame_h)
328
- color_sig = self._extract_color_signature(crop)
329
-
330
- player_data.append((idx, crop, bbox, spatial_pos, color_sig))
331
-
332
- if len(player_data) == 0:
333
- bboxes[offset + frame_id] = []
334
- continue
335
-
336
- # Step 1: Get appearance-based predictions (OSNet)
337
- appearance_predictions = {}
338
- if self.team_classifier and self.team_classifier_fitted:
339
- crops = [data[1] for data in player_data]
340
- appearance_team_ids = self.team_classifier.predict(crops)
341
- for (idx, _, _, _, _), team_id in zip(player_data, appearance_team_ids):
342
- appearance_predictions[idx] = team_id
343
-
344
- # Step 2: Extract color signatures and cluster
345
- color_signatures = []
346
- color_indices = []
347
- for idx, _, _, _, color_sig in player_data:
348
- if color_sig is not None:
349
- color_signatures.append(color_sig)
350
- color_indices.append(idx)
351
-
352
- color_predictions = {}
353
- if len(color_signatures) >= 4:
354
- try:
355
- from sklearn.cluster import KMeans
356
- color_kmeans = KMeans(n_clusters=2, random_state=42, n_init=10)
357
- color_clusters = color_kmeans.fit_predict(color_signatures)
358
- for idx, cluster_id in zip(color_indices, color_clusters):
359
- color_predictions[idx] = cluster_id
360
- except Exception as e:
361
- print(f"Color clustering failed: {e}")
362
-
363
- # Step 3: Apply spatial priors
364
- # Determine which team is on which side based on majority
365
- if left_side_team is None or right_side_team is None:
366
- left_side_players = [p for p in player_data if p[3][0] < 0.5] # x < 0.5
367
- right_side_players = [p for p in player_data if p[3][0] >= 0.5] # x >= 0.5
368
-
369
- if len(left_side_players) >= 2 and len(right_side_players) >= 2:
370
- # Use appearance predictions to determine sides
371
- left_teams = [appearance_predictions.get(p[0]) for p in left_side_players
372
- if p[0] in appearance_predictions]
373
- right_teams = [appearance_predictions.get(p[0]) for p in right_side_players
374
- if p[0] in appearance_predictions]
375
-
376
- if left_teams and right_teams:
377
- left_team_mode = max(set(left_teams), key=left_teams.count)
378
- right_team_mode = max(set(right_teams), key=right_teams.count)
379
-
380
- if left_team_mode != right_team_mode:
381
- left_side_team = left_team_mode
382
- right_side_team = right_team_mode
383
-
384
- # Step 4: Combine predictions with voting
385
- final_predictions = {}
386
- for idx, _, bbox, spatial_pos, _ in player_data:
387
- votes = []
388
- weights = []
389
-
390
- # Appearance vote (weight: 0.4)
391
- if idx in appearance_predictions:
392
- votes.append(appearance_predictions[idx])
393
- weights.append(0.4)
394
-
395
- # Color vote (weight: 0.3)
396
- if idx in color_predictions:
397
- votes.append(color_predictions[idx])
398
- weights.append(0.3)
399
-
400
- # Spatial vote (weight: 0.3)
401
- if left_side_team is not None and right_side_team is not None:
402
- x_pos, _ = spatial_pos
403
- if x_pos < 0.5:
404
- spatial_team = left_side_team
405
- else:
406
- spatial_team = right_side_team
407
- votes.append(spatial_team)
408
- weights.append(0.3)
409
-
410
- # Temporal vote (weight: 0.2) - match with previous frames
411
- if len(votes) > 0:
412
- # Simple temporal matching: find similar bbox in previous frames
413
- best_track_match = None
414
- best_track_iou = 0.0
415
- for track_key, (track_team, track_conf, track_frame) in player_tracks.items():
416
- if abs(track_frame - frame_id) <= 5: # Within 5 frames
417
- track_bbox = track_key
418
- iou = self._calculate_iou(bbox, track_bbox)
419
- if iou > best_track_iou and iou > 0.3:
420
- best_track_iou = iou
421
- best_track_match = track_team
422
-
423
- if best_track_match is not None:
424
- votes.append(best_track_match)
425
- weights.append(0.2)
426
-
427
- # Weighted voting
428
- if len(votes) > 0:
429
- team_0_score = sum(w for v, w in zip(votes, weights) if v == 0)
430
- team_1_score = sum(w for v, w in zip(votes, weights) if v == 1)
431
-
432
- if team_0_score > team_1_score:
433
- final_team = 0
434
- elif team_1_score > team_0_score:
435
- final_team = 1
436
- else:
437
- # Tie: use appearance prediction or first vote
438
- final_team = votes[0] if votes else 0
439
-
440
- final_predictions[idx] = final_team
441
-
442
- # Update tracking
443
- track_key = bbox
444
- player_tracks[track_key] = (final_team, max(team_0_score, team_1_score), frame_id)
445
-
446
- # Step 5: Generate output boxes
447
  for idx, box in enumerate(detection_box):
448
  x1, y1, x2, y2, conf, cls_id = box.tolist()
449
  if cls_id == 2 and conf < 0.6:
450
  continue
451
-
452
- # Check overlap with staff
453
  overlap_staff = False
454
  for idy, boxy in enumerate(detection_box):
455
  s_x1, s_y1, s_x2, s_y2, s_conf, s_cls_id = boxy.tolist()
@@ -460,13 +322,12 @@ class Miner:
460
  break
461
  if overlap_staff:
462
  continue
463
-
464
  mapped_cls_id = str(int(cls_id))
465
-
466
- # Override with team prediction
467
- if idx in final_predictions:
468
- mapped_cls_id = str(6 + int(final_predictions[idx]))
469
-
470
  if mapped_cls_id != '4':
471
  if int(mapped_cls_id) == 3 and conf < 0.5:
472
  continue
@@ -480,17 +341,14 @@ class Miner:
480
  conf=float(conf),
481
  )
482
  )
483
-
484
  # Handle footballs - keep only the best one
485
  footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
486
  if len(footballs) > 1:
487
  best_ball = max(footballs, key=lambda b: b.conf)
488
  boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
489
  boxes.append(best_ball)
490
-
491
- bboxes[offset + frame_id] = boxes
492
 
493
- print(f"Hybrid team classification time: {time.time() - start:.2f}s")
494
  return bboxes
495
 
496
 
@@ -499,19 +357,11 @@ class Miner:
499
  detection_results = self._detect_objects_batch(batch_images)
500
  end = time.time()
501
  print(f"Detection time: {end - start}")
502
-
503
- # Use hybrid team classification
504
  start = time.time()
505
  bboxes = self._team_classify(detection_results, batch_images, offset)
506
  end = time.time()
507
  print(f"Team classify time: {end - start}")
508
 
509
- # Phase 3: Keypoint Detection
510
- # keypoints: Dict[int, List[Tuple[int, int]]] = {}
511
-
512
- # keypoints = self._detect_keypoints_batch(batch_images, offset, n_keypoints)
513
-
514
-
515
  pitch_batch_size = min(self.pitch_batch_size, len(batch_images))
516
  keypoints: Dict[int, List[Tuple[int, int]]] = {}
517
 
@@ -560,81 +410,25 @@ class Miner:
560
  end = time.time()
561
  print(f"Keypoint time: {end - start}")
562
 
 
563
  results: List[TVFrameResult] = []
564
  for frame_number in range(offset, offset + len(batch_images)):
565
  frame_boxes = bboxes.get(frame_number, [])
 
566
  result = TVFrameResult(
567
  frame_id=frame_number,
568
  boxes=frame_boxes,
569
- keypoints=keypoints.get(
570
- frame_number,
571
- [(0, 0) for _ in range(n_keypoints)],
572
- ),
573
  )
574
  results.append(result)
575
 
 
 
 
 
576
  gc.collect()
577
  if torch.cuda.is_available():
578
  torch.cuda.empty_cache()
579
  torch.cuda.synchronize()
580
 
581
- return results
582
-
583
- def _detect_keypoints_batch(self, batch_images: List[ndarray],
584
- offset: int, n_keypoints: int) -> Dict[int, List[Tuple[int, int]]]:
585
- """
586
- Phase 3: Keypoint detection for all frames in batch.
587
-
588
- Args:
589
- batch_images: List of images to process
590
- offset: Frame offset for numbering
591
- n_keypoints: Number of keypoints expected
592
-
593
- Returns:
594
- Dictionary mapping frame_id to list of keypoint coordinates
595
- """
596
- keypoints: Dict[int, List[Tuple[int, int]]] = {}
597
- keypoints_model_results = self.keypoints_model.predict(batch_images)
598
-
599
- if keypoints_model_results is None:
600
- return keypoints
601
-
602
- for frame_idx_in_batch, detection in enumerate(keypoints_model_results):
603
- if not hasattr(detection, "keypoints") or detection.keypoints is None:
604
- continue
605
-
606
- # Extract keypoints with confidence
607
- frame_keypoints_with_conf: List[Tuple[int, int, float]] = []
608
- for i, part_points in enumerate(detection.keypoints.data):
609
- for k_id, (x, y, _) in enumerate(part_points):
610
- confidence = float(detection.keypoints.conf[i][k_id])
611
- frame_keypoints_with_conf.append((int(x), int(y), confidence))
612
-
613
- # Pad or truncate to expected number of keypoints
614
- if len(frame_keypoints_with_conf) < n_keypoints:
615
- frame_keypoints_with_conf.extend(
616
- [(0, 0, 0.0)] * (n_keypoints - len(frame_keypoints_with_conf))
617
- )
618
- else:
619
- frame_keypoints_with_conf = frame_keypoints_with_conf[:n_keypoints]
620
-
621
- # Filter keypoints based on confidence thresholds
622
- filtered_keypoints: List[Tuple[int, int]] = []
623
- for idx, (x, y, confidence) in enumerate(frame_keypoints_with_conf):
624
- if idx in self.CORNER_INDICES:
625
- # Corner keypoints have lower confidence threshold
626
- if confidence < 0.3:
627
- filtered_keypoints.append((0, 0))
628
- else:
629
- filtered_keypoints.append((int(x), int(y)))
630
- else:
631
- # Regular keypoints
632
- if confidence < 0.5:
633
- filtered_keypoints.append((0, 0))
634
- else:
635
- filtered_keypoints.append((int(x), int(y)))
636
-
637
- frame_id = offset + frame_idx_in_batch
638
- keypoints[frame_id] = filtered_keypoints
639
-
640
- return keypoints
 
1
  from pathlib import Path
2
+ from typing import List, Tuple, Dict
3
  import sys
4
  import os
5
 
6
  from numpy import ndarray
7
  from pydantic import BaseModel
8
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
9
+ from keypoint_helper import run_keypoints_post_processing
10
 
11
  from ultralytics import YOLO
12
  from team_cluster import TeamClassifier
13
  from utils import (
14
  BoundingBox,
15
  Constants,
 
16
  )
17
 
18
  import time
19
  import torch
20
  import gc
 
 
 
21
  from pitch import process_batch_input, get_cls_net
22
  import yaml
23
 
 
46
  CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
47
  GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
48
  MIN_SAMPLES_FOR_FIT = 16 # Minimum player crops needed before fitting TeamClassifier
49
+ MAX_SAMPLES_FOR_FIT = 600 # Maximum samples to avoid overfitting
50
 
51
  def __init__(self, path_hf_repo: Path) -> None:
52
  try:
 
54
  model_path = path_hf_repo / "football_object_detection.onnx"
55
  self.bbox_model = YOLO(model_path)
56
 
57
+ print("BBox Model Loaded")
58
 
59
  team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
60
  self.team_classifier = TeamClassifier(
 
68
  self.team_classifier_fitted = False
69
  self.player_crops_for_fit = []
70
 
 
 
71
  model_kp_path = path_hf_repo / 'keypoint'
72
  config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml'
73
  cfg_kp = yaml.safe_load(open(config_kp_path, 'r'))
 
79
  model.eval()
80
 
81
  self.keypoints_model = model
 
 
82
  self.kp_threshold = 0.1
83
  self.pitch_batch_size = 4
84
  self.health = "healthy"
 
131
 
132
  return intersection_area / union_area
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  def _detect_objects_batch(self, decoded_images: List[ndarray]) -> Dict[int, List[BoundingBox]]:
135
  batch_size = 16
136
  detection_results = []
 
143
  return detection_results
144
 
145
  def _team_classify(self, detection_results, decoded_images, offset):
146
+ self.team_classifier_fitted = False
 
 
 
 
 
 
147
  start = time.time()
148
+ # Collect player crops from first batch for fitting
149
+ fit_sample_size = 600
 
150
  player_crops_for_fit = []
151
+
152
  for frame_id in range(len(detection_results)):
153
  detection_box = detection_results[frame_id].boxes.data
154
  if len(detection_box) < 4:
155
  continue
156
+ # Collect player boxes for team classification fitting (first batch only)
157
  if len(player_crops_for_fit) < fit_sample_size:
158
  frame_image = decoded_images[frame_id]
159
  for box in detection_box:
160
  x1, y1, x2, y2, conf, cls_id = box.tolist()
161
+ if conf < 0.5:
162
  continue
163
+ mapped_cls_id = str(int(cls_id))
164
+ # Only collect player crops (cls_id = 2)
165
+ if mapped_cls_id == '2':
166
+ crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
167
+ if crop.size > 0:
168
+ player_crops_for_fit.append(crop)
169
+
170
+ # Fit team classifier after collecting samples
171
  if self.team_classifier and not self.team_classifier_fitted and len(player_crops_for_fit) >= fit_sample_size:
172
+ print(f"Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
173
  self.team_classifier.fit(player_crops_for_fit)
174
  self.team_classifier_fitted = True
175
  break
176
+ if not self.team_classifier_fitted and len(player_crops_for_fit) >= 16:
 
177
  print(f"Fallback: Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
178
  self.team_classifier.fit(player_crops_for_fit)
179
  self.team_classifier_fitted = True
180
+ end = time.time()
181
+ print(f"Fitting Kmeans time: {end - start}")
182
+
183
+ # Second pass: predict teams with configurable frame skipping optimization
184
  start = time.time()
185
+
186
+ # Get configuration for frame skipping
187
+ prediction_interval = 1 # Default: predict every 2 frames
188
+ iou_threshold = 0.3
189
+
190
+ print(f"Team classification - prediction_interval: {prediction_interval}, iou_threshold: {iou_threshold}")
191
+
192
+ # Storage for predicted frame results: {frame_id: {box_idx: (bbox, team_id)}}
193
+ predicted_frame_data = {}
194
+
195
+ # Step 1: Predict for frames at prediction_interval only
196
+ frames_to_predict = []
197
  for frame_id in range(len(detection_results)):
198
+ if frame_id % prediction_interval == 0:
199
+ frames_to_predict.append(frame_id)
200
+
201
+ print(f"Predicting teams for {len(frames_to_predict)}/{len(detection_results)} frames "
202
+ f"(saving {100 - (len(frames_to_predict) * 100 // len(detection_results))}% compute)")
203
+
204
+ for frame_id in frames_to_predict:
205
  detection_box = detection_results[frame_id].boxes.data
206
  frame_image = decoded_images[frame_id]
207
+
208
+ # Collect player crops for this frame
209
+ frame_player_crops = []
210
+ frame_player_indices = []
211
+ frame_player_boxes = []
212
+
213
  for idx, box in enumerate(detection_box):
214
  x1, y1, x2, y2, conf, cls_id = box.tolist()
215
+ if cls_id == 2 and conf < 0.6:
 
 
 
 
216
  continue
217
+ mapped_cls_id = str(int(cls_id))
218
+
219
+ # Collect player crops for prediction
220
+ if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
221
+ crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
222
+ if crop.size > 0:
223
+ frame_player_crops.append(crop)
224
+ frame_player_indices.append(idx)
225
+ frame_player_boxes.append((x1, y1, x2, y2))
226
+
227
+ # Predict teams for all players in this frame
228
+ if len(frame_player_crops) > 0:
229
+ team_ids = self.team_classifier.predict(frame_player_crops)
230
+ predicted_frame_data[frame_id] = {}
231
+ for idx, bbox, team_id in zip(frame_player_indices, frame_player_boxes, team_ids):
232
+ # Map team_id (0,1) to cls_id (6,7)
233
+ team_cls_id = str(6 + int(team_id))
234
+ predicted_frame_data[frame_id][idx] = (bbox, team_cls_id)
235
+
236
+ # Step 2: Process all frames (interpolate skipped frames)
237
+ fallback_count = 0
238
+ interpolated_count = 0
239
+ bboxes: dict[int, list[BoundingBox]] = {}
240
+ for frame_id in range(len(detection_results)):
241
+ detection_box = detection_results[frame_id].boxes.data
242
+ frame_image = decoded_images[frame_id]
243
+ boxes = []
244
+
245
+ team_predictions = {}
246
+
247
+ if frame_id % prediction_interval == 0:
248
+ # Predicted frame: use pre-computed predictions
249
+ if frame_id in predicted_frame_data:
250
+ for idx, (bbox, team_cls_id) in predicted_frame_data[frame_id].items():
251
+ team_predictions[idx] = team_cls_id
252
+ else:
253
+ # Skipped frame: interpolate from neighboring predicted frames
254
+ # Find nearest predicted frames
255
+ prev_predicted_frame = (frame_id // prediction_interval) * prediction_interval
256
+ next_predicted_frame = prev_predicted_frame + prediction_interval
257
+
258
+ # Collect current frame player boxes
259
+ for idx, box in enumerate(detection_box):
260
+ x1, y1, x2, y2, conf, cls_id = box.tolist()
261
+ if cls_id == 2 and conf < 0.6:
262
+ continue
263
+ mapped_cls_id = str(int(cls_id))
264
+
265
+ if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
266
+ target_box = (x1, y1, x2, y2)
267
+
268
+ # Try to match with previous predicted frame
269
+ best_team_id = None
270
+ best_iou = 0.0
271
+
272
+ if prev_predicted_frame in predicted_frame_data:
273
+ team_id, iou = self._find_best_match(
274
+ target_box,
275
+ predicted_frame_data[prev_predicted_frame],
276
+ iou_threshold
277
+ )
278
+ if team_id is not None:
279
+ best_team_id = team_id
280
+ best_iou = iou
281
+
282
+ # Try to match with next predicted frame if available and no good match yet
283
+ if best_team_id is None and next_predicted_frame < len(detection_results):
284
+ if next_predicted_frame in predicted_frame_data:
285
+ team_id, iou = self._find_best_match(
286
+ target_box,
287
+ predicted_frame_data[next_predicted_frame],
288
+ iou_threshold
289
+ )
290
+ if team_id is not None and iou > best_iou:
291
+ best_team_id = team_id
292
+ best_iou = iou
293
+
294
+ # Track interpolation success
295
+ if best_team_id is not None:
296
+ interpolated_count += 1
297
+ else:
298
+ # Fallback: if no match found, predict individually
299
+ crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
300
+ if crop.size > 0:
301
+ team_id = self.team_classifier.predict([crop])[0]
302
+ best_team_id = str(6 + int(team_id))
303
+ fallback_count += 1
304
+
305
+ if best_team_id is not None:
306
+ team_predictions[idx] = best_team_id
307
+
308
+ # Parse boxes with team classification
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  for idx, box in enumerate(detection_box):
310
  x1, y1, x2, y2, conf, cls_id = box.tolist()
311
  if cls_id == 2 and conf < 0.6:
312
  continue
313
+
314
+ # Check overlap with staff box
315
  overlap_staff = False
316
  for idy, boxy in enumerate(detection_box):
317
  s_x1, s_y1, s_x2, s_y2, s_conf, s_cls_id = boxy.tolist()
 
322
  break
323
  if overlap_staff:
324
  continue
325
+
326
  mapped_cls_id = str(int(cls_id))
327
+
328
+ # Override cls_id for players with team prediction
329
+ if idx in team_predictions:
330
+ mapped_cls_id = team_predictions[idx]
 
331
  if mapped_cls_id != '4':
332
  if int(mapped_cls_id) == 3 and conf < 0.5:
333
  continue
 
341
  conf=float(conf),
342
  )
343
  )
 
344
  # Handle footballs - keep only the best one
345
  footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
346
  if len(footballs) > 1:
347
  best_ball = max(footballs, key=lambda b: b.conf)
348
  boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
349
  boxes.append(best_ball)
 
 
350
 
351
+ bboxes[offset + frame_id] = boxes
352
  return bboxes
353
 
354
 
 
357
  detection_results = self._detect_objects_batch(batch_images)
358
  end = time.time()
359
  print(f"Detection time: {end - start}")
 
 
360
  start = time.time()
361
  bboxes = self._team_classify(detection_results, batch_images, offset)
362
  end = time.time()
363
  print(f"Team classify time: {end - start}")
364
 
 
 
 
 
 
 
365
  pitch_batch_size = min(self.pitch_batch_size, len(batch_images))
366
  keypoints: Dict[int, List[Tuple[int, int]]] = {}
367
 
 
410
  end = time.time()
411
  print(f"Keypoint time: {end - start}")
412
 
413
+
414
  results: List[TVFrameResult] = []
415
  for frame_number in range(offset, offset + len(batch_images)):
416
  frame_boxes = bboxes.get(frame_number, [])
417
+ frame_keypoints = keypoints.get(frame_number, [(0, 0) for _ in range(n_keypoints)])
418
  result = TVFrameResult(
419
  frame_id=frame_number,
420
  boxes=frame_boxes,
421
+ keypoints=frame_keypoints,
 
 
 
422
  )
423
  results.append(result)
424
 
425
+ if len(batch_images) > 0:
426
+ h, w = batch_images[0].shape[:2]
427
+ results = run_keypoints_post_processing(results, w, h)
428
+
429
  gc.collect()
430
  if torch.cuda.is_available():
431
  torch.cuda.empty_cache()
432
  torch.cuda.synchronize()
433
 
434
+ return results