tarto2 committed on
Commit
1ec792c
·
verified ·
1 Parent(s): 570d67e

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. keypoint_helper.py +116 -0
  2. miner.py +9 -5
keypoint_helper.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from typing import List, Tuple, Sequence, Any
5
+
6
# Template positions for the 32 pitch keypoints; list index is keypoint
# number minus one (the trailing comment on each row is the keypoint number).
# _generate_sparse_template_keypoints rescales these from a 1045 x 675
# template extent to the frame size.
# Only keypoints 15, 16, 31 and 32 carry a position; all other slots are
# (0, 0).
# NOTE(review): (0, 0) presumably means "no template position" -- confirm.
FOOTBALL_KEYPOINTS: list[tuple[int, int]] = [
    (0, 0),  # 1
    (0, 0),  # 2
    (0, 0),  # 3
    (0, 0),  # 4
    (0, 0),  # 5
    (0, 0),  # 6

    (0, 0),  # 7
    (0, 0),  # 8
    (0, 0),  # 9

    (0, 0),  # 10
    (0, 0),  # 11
    (0, 0),  # 12
    (0, 0),  # 13

    (0, 0),  # 14
    (527, 283),  # 15
    (527, 403),  # 16
    (0, 0),  # 17

    (0, 0),  # 18
    (0, 0),  # 19
    (0, 0),  # 20
    (0, 0),  # 21

    (0, 0),  # 22

    (0, 0),  # 23
    (0, 0),  # 24

    (0, 0),  # 25
    (0, 0),  # 26
    (0, 0),  # 27
    (0, 0),  # 28
    (0, 0),  # 29
    (0, 0),  # 30

    (405, 340),  # 31
    (645, 340),  # 32
]
48
+
49
def convert_keypoints_to_val_format(keypoints) -> list[tuple[int, ...]]:
    """Return *keypoints* as a list of tuples of built-in ``int`` values.

    Every component of every entry is passed through ``int()``, so float
    coordinates are truncated toward zero (not rounded). Entries keep
    however many components they had.

    Args:
        keypoints: Iterable of coordinate sequences, e.g. ``[(x, y), ...]``.
            ``None`` is treated as "no keypoints".

    Returns:
        A new list with one tuple per input entry; empty for empty or
        ``None`` input.
    """
    if keypoints is None:
        return []
    return [tuple(int(value) for value in pair) for pair in keypoints]
51
+
52
def predict_failed_indices(results_frames: Sequence[Any]) -> list[int]:
    """Return the indices of frames whose keypoint detection looks unusable.

    A keypoint counts as detected only when both of its coordinates are
    non-zero after ``int()`` truncation. A frame is reported as failed when
    it has four or fewer detected keypoints. A frame with no ``keypoints``
    attribute, or with ``keypoints`` set to ``None``, has zero.

    Args:
        results_frames: Per-frame result objects. Each is read through its
            ``keypoints`` attribute, an iterable of ``(x, y, ...)`` entries
            (for example a list of pairs or a 2-D array).

    Returns:
        Failed frame indices in ascending order; empty for empty input.
    """
    max_detected_for_failure = 4

    failed_indices: list[int] = []
    for frame_index, frame_result in enumerate(results_frames):
        frame_keypoints = getattr(frame_result, "keypoints", None)
        # Compare against None rather than using `or []`: truth-testing a
        # multi-element NumPy array raises ValueError.
        if frame_keypoints is None:
            frame_keypoints = ()
        # Index rather than unpack so entries carrying extra components
        # (e.g. a confidence value) are still accepted.
        detected = sum(
            1
            for kp in frame_keypoints
            if int(kp[0]) != 0 and int(kp[1]) != 0
        )
        if detected <= max_detected_for_failure:
            failed_indices.append(frame_index)
    return failed_indices
65
+
66
def _generate_sparse_template_keypoints(frame_width: int, frame_height: int) -> list[tuple[int, int]]:
    """Scale the FOOTBALL_KEYPOINTS template to a frame of the given size.

    The template is treated as spanning 1045 x 675 units. Each template
    point is multiplied by ``frame_width / 1045`` and ``frame_height / 675``
    and rounded to the nearest integer with the built-in ``round``.
    Template entries at ``(0, 0)`` therefore stay at ``(0, 0)``.

    Args:
        frame_width: Width of the target frame.
        frame_height: Height of the target frame.

    Returns:
        One ``(x, y)`` integer pair per entry of FOOTBALL_KEYPOINTS, in the
        same order.
    """
    template_max_x, template_max_y = 1045, 675
    sx = float(frame_width) / template_max_x
    sy = float(frame_height) / template_max_y
    # Iterate the template itself rather than a hard-coded range(32) so the
    # output always matches the template's actual length.
    return [
        (int(round(tx * sx)), int(round(ty * sy)))
        for tx, ty in FOOTBALL_KEYPOINTS
    ]
77
+
78
def fix_keypoints(
    results_frames: Sequence[Any],
    failed_indices: Sequence[int],
    frame_width: int,
    frame_height: int,
) -> list[Any]:
    """Overwrite the keypoints of failed frames with usable ones.

    The frame objects are modified in place through their ``keypoints``
    attribute and also returned in a new list.

    * If every frame failed, each frame receives the scaled sparse template
      from ``_generate_sparse_template_keypoints``.
    * Otherwise each failed frame receives a copy of the keypoints of the
      most recent successful frame before it. Failed frames that come
      before the first successful frame receive that first successful
      frame's keypoints.
    * Successful frames have their own keypoints normalised through
      ``convert_keypoints_to_val_format``.

    Args:
        results_frames: Per-frame result objects with a ``keypoints``
            attribute.
        failed_indices: Indices into ``results_frames`` to treat as failed;
            values that match no frame index are ignored.
        frame_width: Frame width, used only when every frame failed.
        frame_height: Frame height, used only when every frame failed.

    Returns:
        A new list holding the same (now updated) frame objects.
    """
    frames = list(results_frames)
    if not frames:
        return frames

    def _keypoints_as_ints(frame_result: Any) -> list:
        # Compare against None rather than using `or []`: truth-testing a
        # multi-element NumPy array raises ValueError.
        raw = getattr(frame_result, "keypoints", None)
        return convert_keypoints_to_val_format([] if raw is None else raw)

    failed_set = {int(i) for i in failed_indices}
    first_success = next(
        (i for i in range(len(frames)) if i not in failed_set), None
    )

    if first_success is None:
        # Nothing to propagate from: fall back to the static template.
        sparse_template = _generate_sparse_template_keypoints(frame_width, frame_height)
        for frame_result in frames:
            # A fresh list per frame so frames never share one list object.
            frame_result.keypoints = convert_keypoints_to_val_format(sparse_template)
        return frames

    # Seed with the first success so leading failed frames are back-filled.
    last_success_kps = _keypoints_as_ints(frames[first_success])

    for frame_index, frame_result in enumerate(frames):
        if frame_index not in failed_set:
            last_success_kps = _keypoints_as_ints(frame_result)
        # Hand each frame its own copy; last_success_kps stays private.
        frame_result.keypoints = list(last_success_kps)

    return frames
113
+
114
def run_keypoints_post_processing(results_frames: Sequence[Any], frame_width: int, frame_height: int) -> list[Any]:
    """Detect frames with unusable keypoints and repair them.

    Runs ``predict_failed_indices`` over the frames, then passes the result
    to ``fix_keypoints`` together with the frame dimensions.

    Args:
        results_frames: Per-frame result objects with a ``keypoints``
            attribute.
        frame_width: Frame width, forwarded to ``fix_keypoints``.
        frame_height: Frame height, forwarded to ``fix_keypoints``.

    Returns:
        Whatever ``fix_keypoints`` returns for these frames.
    """
    failed_indices = predict_failed_indices(results_frames)
    return fix_keypoints(results_frames, failed_indices, frame_width, frame_height)
miner.py CHANGED
@@ -6,7 +6,7 @@ import os
6
  from numpy import ndarray
7
  from pydantic import BaseModel
8
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
9
-
10
 
11
  from ultralytics import YOLO
12
  from team_cluster import TeamClassifier
@@ -46,7 +46,7 @@ class Miner:
46
  CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
47
  GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
48
  MIN_SAMPLES_FOR_FIT = 16 # Minimum player crops needed before fitting TeamClassifier
49
- MAX_SAMPLES_FOR_FIT = 1000 # Maximum samples to avoid overfitting
50
 
51
  def __init__(self, path_hf_repo: Path) -> None:
52
  try:
@@ -54,7 +54,7 @@ class Miner:
54
  model_path = path_hf_repo / "football_object_detection.onnx"
55
  self.bbox_model = YOLO(model_path)
56
 
57
- print("BBox Model Loaded")
58
 
59
  team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
60
  self.team_classifier = TeamClassifier(
@@ -62,7 +62,7 @@ class Miner:
62
  batch_size=32,
63
  model_name=str(team_model_path)
64
  )
65
- print("Team Classifier Loaded")
66
 
67
  # Team classification state
68
  self.team_classifier_fitted = False
@@ -146,7 +146,7 @@ class Miner:
146
  self.team_classifier_fitted = False
147
  start = time.time()
148
  # Collect player crops from first batch for fitting
149
- fit_sample_size = 1000
150
  player_crops_for_fit = []
151
 
152
  for frame_id in range(len(detection_results)):
@@ -422,6 +422,10 @@ class Miner:
422
  )
423
  results.append(result)
424
 
 
 
 
 
425
  gc.collect()
426
  if torch.cuda.is_available():
427
  torch.cuda.empty_cache()
 
6
  from numpy import ndarray
7
  from pydantic import BaseModel
8
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
9
+ from keypoint_helper import run_keypoints_post_processing
10
 
11
  from ultralytics import YOLO
12
  from team_cluster import TeamClassifier
 
46
  CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
47
  GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
48
  MIN_SAMPLES_FOR_FIT = 16 # Minimum player crops needed before fitting TeamClassifier
49
+ MAX_SAMPLES_FOR_FIT = 600 # Maximum samples to avoid overfitting
50
 
51
  def __init__(self, path_hf_repo: Path) -> None:
52
  try:
 
54
  model_path = path_hf_repo / "football_object_detection.onnx"
55
  self.bbox_model = YOLO(model_path)
56
 
57
+ print("BBox Model Loaded")
58
 
59
  team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
60
  self.team_classifier = TeamClassifier(
 
62
  batch_size=32,
63
  model_name=str(team_model_path)
64
  )
65
+ print("Team Classifier Loaded")
66
 
67
  # Team classification state
68
  self.team_classifier_fitted = False
 
146
  self.team_classifier_fitted = False
147
  start = time.time()
148
  # Collect player crops from first batch for fitting
149
+ fit_sample_size = 600
150
  player_crops_for_fit = []
151
 
152
  for frame_id in range(len(detection_results)):
 
422
  )
423
  results.append(result)
424
 
425
+ if len(batch_images) > 0:
426
+ h, w = batch_images[0].shape[:2]
427
+ results = run_keypoints_post_processing(results, w, h)
428
+
429
  gc.collect()
430
  if torch.cuda.is_available():
431
  torch.cuda.empty_cache()