tarto2 committed
Commit acf7a04 · 1 Parent(s): 08db6c9
.gitattributes CHANGED
@@ -33,5 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
- keypoint filter=lfs diff=lfs merge=lfs -text
+ SV_kp.engine filter=lfs diff=lfs merge=lfs -text
  osnet_model.pth.tar-100 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
+ venv
+ outputs
+ test_predict_batch.py
+ test.mp4
+ inspect_yolo_model.py
keypoint → 20251029-detection.pt RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:7ea78fa76aaf94976a8eca428d6e3c59697a93430cba1a4603e20284b61f5113
- size 264964645
+ oid sha256:8bbacfcb38e38b1b8816788e9e6e845160533719a0b87b693d58b932380d0d28
+ size 152961687
football_keypoints_detection.pt → 20251029-keypoint.pt RENAMED
File without changes
README.md ADDED
@@ -0,0 +1,132 @@
+ # 🚀 Example Chute for Turbovision 🪂
+
+ This repository demonstrates how to deploy a Chute via the Turbovision CLI, hosted on Hugging Face Hub. It serves as a minimal example showcasing the required structure and workflow for integrating machine learning models, preprocessing, and orchestration into a reproducible Chute environment.
+
+ ## Repository Structure
+
+ The following two files must be present (in their current locations) for a successful deployment — their content can be modified as needed:
+
+ | File | Purpose |
+ |------|---------|
+ | `miner.py` | Defines the ML model type(s), orchestration, and all pre/postprocessing logic. |
+ | `config.yml` | Specifies machine configuration (e.g., GPU type, memory, environment variables). |
+
+ Other files — e.g., model weights, utility scripts, or dependencies — are optional and can be included as needed for your model.
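
To make the expected shape of `miner.py` concrete, a minimal skeleton might look like the following. This is a sketch only: the class, field, and method names mirror this repo's `miner.py`, but plain dataclasses stand in for the pydantic models it actually uses, and the bodies are stubs.

```python
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple

@dataclass
class BoundingBox:
    x1: int
    y1: int
    x2: int
    y2: int
    cls_id: int
    conf: float

@dataclass
class TVFrameResult:
    frame_id: int
    boxes: List[BoundingBox]
    keypoints: List[Tuple[int, int]]

class Miner:
    def __init__(self, path_hf_repo: Path) -> None:
        # Load model weights from the cloned Hugging Face repo here.
        self.path_hf_repo = path_hf_repo

    def predict_batch(self, batch_images, offset: int, n_keypoints: int) -> List[TVFrameResult]:
        # Stub: one result per frame, empty detections and zeroed keypoints.
        return [
            TVFrameResult(frame_id=offset + i, boxes=[], keypoints=[(0, 0)] * n_keypoints)
            for i in range(len(batch_images))
        ]
```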
+
+ > **Note**: Any required assets must be defined or contained within this repo, which is fully open-source, since all network-related operations (downloading challenge data, weights, etc.) are disabled inside the Chute.
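
For orientation, a `config.yml` has roughly the following shape (values here are illustrative; the actual file committed in this change appears further down in this diff):

```yaml
Image:
  from_base: parachutes/python:3.12
  run_command:
    - pip install --upgrade setuptools wheel
    - pip install ultralytics opencv-python-headless numpy pydantic
  set_workdir: /app

NodeSelector:
  gpu_count: 1
  min_vram_gb_per_gpu: 16
```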
+
+ ## Overview
+
+ Below is a high-level diagram showing the interaction between Hugging Face, Chutes, and Turbovision:
+
+ ```
+ ┌─────────────┐      ┌──────────┐      ┌──────────────┐
+ │ HuggingFace │ ───> │  Chutes  │ ───> │ Turbovision  │
+ │     Hub     │      │    .ai   │      │  Validator   │
+ └─────────────┘      └──────────┘      └──────────────┘
+ ```
+
+ ## Local Testing
+
+ After editing `config.yml` and `miner.py` and saving them into your Hugging Face repo, you will want to test that everything works locally.
+
+ 1. **Copy the template file** `scorevision/chute_template/turbovision_chute.py.j2` to a Python file called `my_chute.py` and fill in the missing variables:
+
+ ```python
+ HF_REPO_NAME = "{{ huggingface_repository_name }}"
+ HF_REPO_REVISION = "{{ huggingface_repository_revision }}"
+ CHUTES_USERNAME = "{{ chute_username }}"
+ CHUTE_NAME = "{{ chute_name }}"
+ ```
+
+ 2. **Run the following command to build the chute locally** (caution: there are known issues with the Docker location when running this on a Mac):
+
+ ```bash
+ chutes build my_chute:chute --local --public
+ ```
+
+ 3. **Run the Docker image just built** (its name is `CHUTE_NAME`) and enter it:
+
+ ```bash
+ docker run -p 8000:8000 -e CHUTES_EXECUTION_CONTEXT=REMOTE -it <image-name> /bin/bash
+ ```
+
+ 4. **Run the chute from within the container**:
+
+ ```bash
+ chutes run my_chute:chute --dev --debug
+ ```
+
+ 5. **In another terminal, test the local endpoints** to ensure there are no bugs:
+
+ ```bash
+ # Health check
+ curl -X POST http://localhost:8000/health -d '{}'
+
+ # Prediction test
+ curl -X POST http://localhost:8000/predict -d '{"url": "https://scoredata.me/2025_03_14/35ae7a/h1_0f2ca0.mp4","meta": {}}'
+ ```
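
The same two checks can also be driven from Python using only the standard library. This is a sketch; it assumes the chute is serving on `localhost:8000` as in the steps above, and only builds the requests (sending them requires the server to be up):

```python
import json
from urllib import request

def build_post(url: str, payload: dict) -> request.Request:
    # Mirrors the curl calls above: a POST with a JSON body.
    return request.Request(
        url,
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

health = build_post("http://localhost:8000/health", {})
predict = build_post(
    "http://localhost:8000/predict",
    {"url": "https://scoredata.me/2025_03_14/35ae7a/h1_0f2ca0.mp4", "meta": {}},
)
# Send with e.g.: request.urlopen(predict).read()
```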
+
+ ## Live Testing
+
+ If you have a chute with the same name (e.g. from a previous deployment), ensure you delete it first (or you will get an error when trying to build).
+
+ 1. **List existing chutes**:
+
+ ```bash
+ chutes chutes list
+ ```
+
+ Take note of the chute id that you wish to delete (if any):
+
+ ```bash
+ chutes chutes delete <chute-id>
+ ```
+
+ 2. **You should also delete its associated image**:
+
+ ```bash
+ chutes images list
+ ```
+
+ Take note of the chute image id:
+
+ ```bash
+ chutes images delete <chute-image-id>
+ ```
+
+ 3. **Use Turbovision's CLI to build, deploy, and commit on-chain**:
+
+ ```bash
+ sv -vv push
+ ```
+
+ > **Note**: You can skip the on-chain commit using `--no-commit`. You can also point to a past Hugging Face revision using `--revision`, and/or specify the local files to upload to your Hugging Face repo using `--model-path`.
+
+ 4. **When completed, warm up the chute** (if it's cold 🧊):
+
+ You can confirm its status using `chutes chutes list`, or `chutes chutes get <chute-id>` if you already know its id.
+
+ > **Note**: Warming up can sometimes take a while, but if the chute runs without errors (it should if you've tested locally first) and there are sufficient nodes (i.e. machines) available matching the `config.yml` you specified, the chute should become hot 🔥!
+
+ ```bash
+ chutes warmup <chute-id>
+ ```
+
+ 5. **Test the chute's endpoints**:
+
+ ```bash
+ # Health check
+ curl -X POST https://<YOUR-CHUTE-SLUG>.chutes.ai/health -d '{}' -H "Authorization: Bearer $CHUTES_API_KEY"
+
+ # Prediction
+ curl -X POST https://<YOUR-CHUTE-SLUG>.chutes.ai/predict -d '{"url": "https://scoredata.me/2025_03_14/35ae7a/h1_0f2ca0.mp4","meta": {}}' -H "Authorization: Bearer $CHUTES_API_KEY"
+ ```
+
+ 6. **Test what your chute would get on a validator**:
+
+ This also applies any validation/integrity checks, which may fail if you did not use the Turbovision CLI above to deploy the chute:
+
+ ```bash
+ sv -vv run-once
+ ```
SV_kp.engine ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f99452eb79e064189e2758abd20a78845a5b639fc8b9c4bc650519c83e13e8db
+ size 368289641
config.yml CHANGED
@@ -2,15 +2,14 @@ Image:
2
  from_base: parachutes/python:3.12
3
  run_command:
4
  - pip install --upgrade setuptools wheel
5
- - pip install "torch==2.7.1" "torchvision==0.22.1"
6
- - pip install "ultralytics==8.3.222" "opencv-python-headless" "numpy" "pydantic"
7
  - pip install scikit-learn
8
  - pip install onnxruntime-gpu
9
  set_workdir: /app
10
 
11
  NodeSelector:
12
  gpu_count: 1
13
- min_vram_gb_per_gpu: 24
14
  exclude:
15
  - "5090"
16
  - b200
 
2
  from_base: parachutes/python:3.12
3
  run_command:
4
  - pip install --upgrade setuptools wheel
5
+ - pip install ultralytics==8.3.222 opencv-python-headless numpy pydantic
 
6
  - pip install scikit-learn
7
  - pip install onnxruntime-gpu
8
  set_workdir: /app
9
 
10
  NodeSelector:
11
  gpu_count: 1
12
+ min_vram_gb_per_gpu: 16
13
  exclude:
14
  - "5090"
15
  - b200
football_object_detection.onnx → detection.onnx RENAMED
File without changes
detection.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2ad3e89b658d2626c34174f6799d240ffd37cfe45752c0ce6ef73b05935042e0
+ size 52014742
hrnetv2_w48.yaml DELETED
@@ -1,35 +0,0 @@
1
- MODEL:
2
- IMAGE_SIZE: [960, 540]
3
- NUM_JOINTS: 58
4
- PRETRAIN: ''
5
- EXTRA:
6
- FINAL_CONV_KERNEL: 1
7
- STAGE1:
8
- NUM_MODULES: 1
9
- NUM_BRANCHES: 1
10
- BLOCK: BOTTLENECK
11
- NUM_BLOCKS: [4]
12
- NUM_CHANNELS: [64]
13
- FUSE_METHOD: SUM
14
- STAGE2:
15
- NUM_MODULES: 1
16
- NUM_BRANCHES: 2
17
- BLOCK: BASIC
18
- NUM_BLOCKS: [4, 4]
19
- NUM_CHANNELS: [48, 96]
20
- FUSE_METHOD: SUM
21
- STAGE3:
22
- NUM_MODULES: 4
23
- NUM_BRANCHES: 3
24
- BLOCK: BASIC
25
- NUM_BLOCKS: [4, 4, 4]
26
- NUM_CHANNELS: [48, 96, 192]
27
- FUSE_METHOD: SUM
28
- STAGE4:
29
- NUM_MODULES: 3
30
- NUM_BRANCHES: 4
31
- BLOCK: BASIC
32
- NUM_BLOCKS: [4, 4, 4, 4]
33
- NUM_CHANNELS: [48, 96, 192, 384]
34
- FUSE_METHOD: SUM
35
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
keypoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6dd10dba85895c92760cdb5a99c5cfca899c68f361a66c5448f38a187280ee1f
+ size 6849672
miner.py CHANGED
@@ -1,34 +1,23 @@
  from pathlib import Path
- from typing import List, Tuple, Dict
- import sys
- import os
-
- from numpy import ndarray
- from pydantic import BaseModel
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
  from ultralytics import YOLO
  from team_cluster import TeamClassifier
  from utils import (
      BoundingBox,
      Constants,
  )

- import time
- import torch
- import gc
- from pitch import process_batch_input, get_cls_net
- import yaml
-
-
- class BoundingBox(BaseModel):
-     x1: int
-     y1: int
-     x2: int
-     y2: int
-     cls_id: int
-     conf: float
-

  class TVFrameResult(BaseModel):
      frame_id: int
@@ -37,6 +26,10 @@ class TVFrameResult(BaseModel):


  class Miner:
      SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
      SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
      SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
@@ -45,385 +38,472 @@ class Miner:
      CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
      GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
      MIN_SAMPLES_FOR_FIT = 16  # Minimum player crops needed before fitting TeamClassifier
-     MAX_SAMPLES_FOR_FIT = 600  # Maximum samples to avoid overfitting

      def __init__(self, path_hf_repo: Path) -> None:
-         try:
-             device = "cuda" if torch.cuda.is_available() else "cpu"
-             model_path = path_hf_repo / "football_object_detection.onnx"
-             self.bbox_model = YOLO(model_path)
-
-             print("BBox Model Loaded")
-
-             team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
-             self.team_classifier = TeamClassifier(
-                 device=device,
-                 batch_size=32,
-                 model_name=str(team_model_path)
-             )
-             print("Team Classifier Loaded")
-
-             # Team classification state
-             self.team_classifier_fitted = False
-             self.player_crops_for_fit = []
-
-             model_kp_path = path_hf_repo / 'keypoint'
-             config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml'
-             cfg_kp = yaml.safe_load(open(config_kp_path, 'r'))
-
-             loaded_state_kp = torch.load(model_kp_path, map_location=device)
-             model = get_cls_net(cfg_kp)
-             model.load_state_dict(loaded_state_kp)
-             model.to(device)
-             model.eval()
-
-             self.keypoints_model = model
-             self.kp_threshold = 0.1
-             self.pitch_batch_size = 4
-             self.health = "healthy"
-             print("✅ Keypoints Model Loaded")
-         except Exception as e:
-             self.health = "❌ Miner initialization failed: " + str(e)
-             print(self.health)

-     def __repr__(self) -> str:
-         if self.health == 'healthy':
-             return (
-                 f"health: {self.health}\n"
-                 f"BBox Model: {type(self.bbox_model).__name__}\n"
-                 f"Keypoints Model: {type(self.keypoints_model).__name__}"
-             )
-         else:
-             return self.health

-     def _calculate_iou(self, box1: Tuple[float, float, float, float],
-                        box2: Tuple[float, float, float, float]) -> float:
          """
-         Calculate Intersection over Union (IoU) between two bounding boxes.
-         Args:
-             box1: (x1, y1, x2, y2)
-             box2: (x1, y1, x2, y2)
          Returns:
-             IoU score (0-1)
          """
-         x1_1, y1_1, x2_1, y2_1 = box1
-         x1_2, y1_2, x2_2, y2_2 = box2
-
-         # Calculate intersection area
-         x_left = max(x1_1, x1_2)
-         y_top = max(y1_1, y1_2)
-         x_right = min(x2_1, x2_2)
-         y_bottom = min(y2_1, y2_2)
-
-         if x_right < x_left or y_bottom < y_top:
-             return 0.0
-
-         intersection_area = (x_right - x_left) * (y_bottom - y_top)

-         # Calculate union area
-         box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
-         box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
-         union_area = box1_area + box2_area - intersection_area

-         if union_area == 0:
-             return 0.0

-         return intersection_area / union_area

-     def _detect_objects_batch(self, decoded_images: List[ndarray]) -> Dict[int, List[BoundingBox]]:
-         batch_size = 16
-         detection_results = []
-         n_frames = len(decoded_images)
-         for frame_number in range(0, n_frames, batch_size):
-             batch_images = decoded_images[frame_number: frame_number + batch_size]
-             detections = self.bbox_model(batch_images, verbose=False, save=False)
-             detection_results.extend(detections)

-         return detection_results

-     def _team_classify(self, detection_results, decoded_images, offset):
-         self.team_classifier_fitted = False
-         start = time.time()
-         # Collect player crops from first batch for fitting
-         fit_sample_size = 600
-         player_crops_for_fit = []

-         for frame_id in range(len(detection_results)):
-             detection_box = detection_results[frame_id].boxes.data
-             if len(detection_box) < 4:
                  continue
-             # Collect player boxes for team classification fitting (first batch only)
-             if len(player_crops_for_fit) < fit_sample_size:
-                 frame_image = decoded_images[frame_id]
-                 for box in detection_box:
-                     x1, y1, x2, y2, conf, cls_id = box.tolist()
-                     if conf < 0.5:
-                         continue
-                     mapped_cls_id = str(int(cls_id))
-                     # Only collect player crops (cls_id = 2)
-                     if mapped_cls_id == '2':
-                         crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
-                         if crop.size > 0:
-                             player_crops_for_fit.append(crop)
-
-             # Fit team classifier after collecting samples
-             if self.team_classifier and not self.team_classifier_fitted and len(player_crops_for_fit) >= fit_sample_size:
-                 print(f"Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
-                 self.team_classifier.fit(player_crops_for_fit)
-                 self.team_classifier_fitted = True
-                 break
-         if not self.team_classifier_fitted and len(player_crops_for_fit) >= 16:
-             print(f"Fallback: Fitting TeamClassifier with {len(player_crops_for_fit)} player crops")
-             self.team_classifier.fit(player_crops_for_fit)
-             self.team_classifier_fitted = True
-         end = time.time()
-         print(f"Fitting Kmeans time: {end - start}")
-
-         # Second pass: predict teams with configurable frame skipping optimization
-         start = time.time()
-
-         # Get configuration for frame skipping
-         prediction_interval = 1  # Default: predict every 2 frames
-         iou_threshold = 0.3
-
-         print(f"Team classification - prediction_interval: {prediction_interval}, iou_threshold: {iou_threshold}")
-
-         # Storage for predicted frame results: {frame_id: {box_idx: (bbox, team_id)}}
-         predicted_frame_data = {}
-
-         # Step 1: Predict for frames at prediction_interval only
-         frames_to_predict = []
-         for frame_id in range(len(detection_results)):
-             if frame_id % prediction_interval == 0:
-                 frames_to_predict.append(frame_id)
-
-         print(f"Predicting teams for {len(frames_to_predict)}/{len(detection_results)} frames "
-               f"(saving {100 - (len(frames_to_predict) * 100 // len(detection_results))}% compute)")
-
-         for frame_id in frames_to_predict:
-             detection_box = detection_results[frame_id].boxes.data
-             frame_image = decoded_images[frame_id]
-
-             # Collect player crops for this frame
-             frame_player_crops = []
-             frame_player_indices = []
-             frame_player_boxes = []
-
-             for idx, box in enumerate(detection_box):
-                 x1, y1, x2, y2, conf, cls_id = box.tolist()
-                 if cls_id == 2 and conf < 0.6:
                      continue
-                 mapped_cls_id = str(int(cls_id))
-
-                 # Collect player crops for prediction
-                 if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
-                     crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
-                     if crop.size > 0:
-                         frame_player_crops.append(crop)
-                         frame_player_indices.append(idx)
-                         frame_player_boxes.append((x1, y1, x2, y2))
-
-             # Predict teams for all players in this frame
-             if len(frame_player_crops) > 0:
-                 team_ids = self.team_classifier.predict(frame_player_crops)
-                 predicted_frame_data[frame_id] = {}
-                 for idx, bbox, team_id in zip(frame_player_indices, frame_player_boxes, team_ids):
-                     # Map team_id (0,1) to cls_id (6,7)
-                     team_cls_id = str(6 + int(team_id))
-                     predicted_frame_data[frame_id][idx] = (bbox, team_cls_id)
-
-         # Step 2: Process all frames (interpolate skipped frames)
-         fallback_count = 0
-         interpolated_count = 0
-         bboxes: dict[int, list[BoundingBox]] = {}
-         for frame_id in range(len(detection_results)):
-             detection_box = detection_results[frame_id].boxes.data
-             frame_image = decoded_images[frame_id]
-             boxes = []
-
-             team_predictions = {}
-
-             if frame_id % prediction_interval == 0:
-                 # Predicted frame: use pre-computed predictions
-                 if frame_id in predicted_frame_data:
-                     for idx, (bbox, team_cls_id) in predicted_frame_data[frame_id].items():
-                         team_predictions[idx] = team_cls_id
-             else:
-                 # Skipped frame: interpolate from neighboring predicted frames
-                 # Find nearest predicted frames
-                 prev_predicted_frame = (frame_id // prediction_interval) * prediction_interval
-                 next_predicted_frame = prev_predicted_frame + prediction_interval
-
-                 # Collect current frame player boxes
-                 for idx, box in enumerate(detection_box):
-                     x1, y1, x2, y2, conf, cls_id = box.tolist()
-                     if cls_id == 2 and conf < 0.6:
-                         continue
-                     mapped_cls_id = str(int(cls_id))
-
-                     if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
-                         target_box = (x1, y1, x2, y2)
-
-                         # Try to match with previous predicted frame
-                         best_team_id = None
-                         best_iou = 0.0
-
-                         if prev_predicted_frame in predicted_frame_data:
-                             team_id, iou = self._find_best_match(
-                                 target_box,
-                                 predicted_frame_data[prev_predicted_frame],
-                                 iou_threshold
-                             )
-                             if team_id is not None:
-                                 best_team_id = team_id
-                                 best_iou = iou
-
-                         # Try to match with next predicted frame if available and no good match yet
-                         if best_team_id is None and next_predicted_frame < len(detection_results):
-                             if next_predicted_frame in predicted_frame_data:
-                                 team_id, iou = self._find_best_match(
-                                     target_box,
-                                     predicted_frame_data[next_predicted_frame],
-                                     iou_threshold
-                                 )
-                                 if team_id is not None and iou > best_iou:
-                                     best_team_id = team_id
-                                     best_iou = iou
-
-                         # Track interpolation success
-                         if best_team_id is not None:
-                             interpolated_count += 1
-                         else:
-                             # Fallback: if no match found, predict individually
-                             crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
-                             if crop.size > 0:
-                                 team_id = self.team_classifier.predict([crop])[0]
-                                 best_team_id = str(6 + int(team_id))
-                                 fallback_count += 1
-
-                         if best_team_id is not None:
-                             team_predictions[idx] = best_team_id
-
-             # Parse boxes with team classification
-             for idx, box in enumerate(detection_box):
-                 x1, y1, x2, y2, conf, cls_id = box.tolist()
-                 if cls_id == 2 and conf < 0.6:
                      continue
-
-                 # Check overlap with staff box
-                 overlap_staff = False
-                 for idy, boxy in enumerate(detection_box):
-                     s_x1, s_y1, s_x2, s_y2, s_conf, s_cls_id = boxy.tolist()
-                     if cls_id == 2 and s_cls_id == 4:
-                         staff_iou = self._calculate_iou(box[:4], boxy[:4])
-                         if staff_iou >= 0.8:
-                             overlap_staff = True
                              break
-                 if overlap_staff:
-                     continue
-
-                 mapped_cls_id = str(int(cls_id))
-
-                 # Override cls_id for players with team prediction
-                 if idx in team_predictions:
-                     mapped_cls_id = team_predictions[idx]
-                 if mapped_cls_id != '4':
-                     if int(mapped_cls_id) == 3 and conf < 0.5:
-                         continue
                  boxes.append(
                      BoundingBox(
                          x1=int(x1),
                          y1=int(y1),
                          x2=int(x2),
                          y2=int(y2),
-                         cls_id=int(mapped_cls_id),
                          conf=float(conf),
                      )
                  )
              # Handle footballs - keep only the best one
              footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
              if len(footballs) > 1:
                  best_ball = max(footballs, key=lambda b: b.conf)
                  boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
                  boxes.append(best_ball)
-
-             bboxes[offset + frame_id] = boxes
-         return bboxes
-

-     def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]:
-         start = time.time()
-         detection_results = self._detect_objects_batch(batch_images)
-         end = time.time()
-         print(f"Detection time: {end - start}")
-         start = time.time()
-         bboxes = self._team_classify(detection_results, batch_images, offset)
-         end = time.time()
-         print(f"Team classify time: {end - start}")

-         pitch_batch_size = min(self.pitch_batch_size, len(batch_images))
-         keypoints: Dict[int, List[Tuple[int, int]]] = {}

-         start = time.time()
-         while True:
-             gc.collect()
-             if torch.cuda.is_available():
-                 torch.cuda.empty_cache()
-                 torch.cuda.synchronize()
-             device_str = "cuda"
-             keypoints_result = process_batch_input(
-                 batch_images,
-                 self.keypoints_model,
-                 self.kp_threshold,
-                 device_str,
-                 batch_size=pitch_batch_size,
              )
-             if keypoints_result is not None and len(keypoints_result) > 0:
-                 for frame_number_in_batch, kp_dict in enumerate(keypoints_result):
-                     if frame_number_in_batch >= len(batch_images):
-                         break
-                     frame_keypoints: List[Tuple[int, int]] = []
-                     try:
-                         height, width = batch_images[frame_number_in_batch].shape[:2]
-                         if kp_dict is not None and isinstance(kp_dict, dict):
-                             for idx in range(32):
-                                 x, y = 0, 0
-                                 kp_idx = idx + 1
-                                 if kp_idx in kp_dict:
-                                     try:
-                                         kp_data = kp_dict[kp_idx]
-                                         if isinstance(kp_data, dict) and "x" in kp_data and "y" in kp_data:
-                                             x = int(kp_data["x"] * width)
-                                             y = int(kp_data["y"] * height)
-                                     except (KeyError, TypeError, ValueError):
-                                         pass
-                                 frame_keypoints.append((x, y))
-                     except (IndexError, ValueError, AttributeError):
-                         frame_keypoints = [(0, 0)] * 32
-                     if len(frame_keypoints) < n_keypoints:
-                         frame_keypoints.extend([(0, 0)] * (n_keypoints - len(frame_keypoints)))
-                     else:
-                         frame_keypoints = frame_keypoints[:n_keypoints]
-                     keypoints[offset + frame_number_in_batch] = frame_keypoints
-             break
-         end = time.time()
-         print(f"Keypoint time: {end - start}")
-
          results: List[TVFrameResult] = []
          for frame_number in range(offset, offset + len(batch_images)):
-             frame_boxes = bboxes.get(frame_number, [])
-             frame_keypoints = keypoints.get(frame_number, [(0, 0) for _ in range(n_keypoints)])
-             result = TVFrameResult(
-                 frame_id=frame_number,
-                 boxes=frame_boxes,
-                 keypoints=frame_keypoints,
              )
-             results.append(result)
-
-         gc.collect()
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-             torch.cuda.synchronize()

-         return results
1
  from pathlib import Path
2
+ from typing import List, Tuple, Dict, Optional
3
+ import sys, os
 
 
 
 
4
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
5
+ import onnxruntime as ort
6
+ import numpy as np
7
+ import cv2
8
+ from torchvision.ops import batched_nms
9
+ import torch
10
  from ultralytics import YOLO
11
+ from numpy import ndarray
12
+ from pydantic import BaseModel
13
  from team_cluster import TeamClassifier
14
  from utils import (
15
  BoundingBox,
16
  Constants,
17
+ suppress_small_contained_boxes,
18
+ classify_teams_batch,
19
  )
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  class TVFrameResult(BaseModel):
23
  frame_id: int
 
26
 
27
 
28
  class Miner:
29
+ """
30
+ Football video analysis system for object detection and team classification.
31
+ """
32
+ # Use constants from utils
33
  SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
34
  SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
35
  SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
 
38
  CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
39
  GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
40
  MIN_SAMPLES_FOR_FIT = 16 # Minimum player crops needed before fitting TeamClassifier
41
+ MAX_SAMPLES_FOR_FIT = 500 # Maximum samples to avoid overfitting
42
 
43
  def __init__(self, path_hf_repo: Path) -> None:
44
+ providers = [
45
+ 'CUDAExecutionProvider',
46
+ 'CPUExecutionProvider'
47
+ ]
48
+ model_path = path_hf_repo / "detection.onnx"
49
+ session = ort.InferenceSession(model_path, providers=providers)
50
+
51
+ input_name = session.get_inputs()[0].name
52
+ height = width = 640
53
+ dummy = np.zeros((1, 3, height, width), dtype=np.float32)
54
+ session.run(None, {input_name: dummy})
55
+ model = session
56
+ self.bbox_model = model
57
+
58
+ print("BBox Model Loaded")
59
+ self.keypoints_model = YOLO(path_hf_repo / "keypoint.pt")
60
+ print("Keypoints Model (keypoint.pt) Loaded")
61
+ # Initialize team classifier with OSNet model
62
+ team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
63
+ device = 'cuda'
64
+ self.team_classifier = TeamClassifier(
65
+ device=device,
66
+ batch_size=32,
67
+ model_name=str(team_model_path)
68
+ )
69
+ print("Team Classifier Loaded")
70
+
71
+ # Team classification state
72
+ self.team_classifier_fitted = False
73
+ self.player_crops_for_fit = [] # Collect samples across frames
74
 
75
+ def __repr__(self) -> str:
76
+ return (
77
+ f"BBox Model: {type(self.bbox_model).__name__}\n"
78
+ f"Keypoints Model: {type(self.keypoints_model).__name__}"
79
+ )
 
 
 
 
80
 
 
 
 
 
 
 
 
 
81
 
 
 
 
 
 
 
 
 
 
82
 
83
+ def _handle_multiple_goalkeepers(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
 
84
  """
85
+ Handle goalkeeper detection issues:
86
+ 1. Fix misplaced goalkeepers (standing in middle of field)
87
+ 2. Limit to maximum 2 goalkeepers (one from each team)
88
+
89
  Returns:
90
+ Filtered list of boxes with corrected goalkeepers
91
  """
92
+ # Step 1: Fix misplaced goalkeepers first
93
+ # Convert goalkeepers in middle of field to regular players
94
+ boxes = self._fix_misplaced_goalkeepers(boxes)
95
+
96
+ # Step 2: Handle multiple goalkeepers (after fixing misplaced ones)
97
+ gk_idxs = [i for i, bb in enumerate(boxes) if int(bb.cls_id) == 1]
98
+ if len(gk_idxs) <= 2:
99
+ return boxes
100
+
101
+ # Sort goalkeepers by confidence (highest first)
102
+ gk_idxs_sorted = sorted(gk_idxs, key=lambda i: boxes[i].conf, reverse=True)
103
+ keep_gk_idxs = set(gk_idxs_sorted[:2]) # Keep top 2 goalkeepers
104
+
105
+ # Create new list keeping only top 2 goalkeepers
106
+ filtered_boxes = []
107
+ for i, box in enumerate(boxes):
108
+ if int(box.cls_id) == 1:
109
+ # Only keep the top 2 goalkeepers by confidence
110
+ if i in keep_gk_idxs:
111
+ filtered_boxes.append(box)
112
+ # Skip extra goalkeepers
113
+ else:
114
+ # Keep all non-goalkeeper boxes
115
+ filtered_boxes.append(box)
116
+
117
+ return filtered_boxes
118
 
119
+ def _fix_misplaced_goalkeepers(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
120
+ """
121
+ """
122
+ gk_idxs = [i for i, bb in enumerate(boxes) if int(bb.cls_id) == 1]
123
+ player_idxs = [i for i, bb in enumerate(boxes) if int(bb.cls_id) == 2]
124
+
125
+ if len(gk_idxs) == 0 or len(player_idxs) < 2:
126
+ return boxes
127
+
128
+ updated_boxes = boxes.copy()
129
+
130
+ for gk_idx in gk_idxs:
131
+ if boxes[gk_idx].conf < 0.3:
132
+ updated_boxes[gk_idx].cls_id = 2
133
+
134
+ return updated_boxes
135
 
 
 
136
 
137
+ def _pre_process_img(self, frames: List[np.ndarray], scale: float = 640.0) -> np.ndarray:
138
+ """
139
+ Preprocess images for ONNX inference.
140
+
141
+ Args:
142
+ frames: List of BGR frames
143
+ scale: Target scale for resizing
144
+
145
+ Returns:
146
+ Preprocessed numpy array ready for ONNX inference
147
+ """
148
+ imgs = np.stack([cv2.resize(frame, (int(scale), int(scale))) for frame in frames])
149
+ imgs = imgs.transpose(0, 3, 1, 2) # BHWC to BCHW
150
+ imgs = imgs.astype(np.float32) / 255.0 # Normalize to [0, 1]
151
+ return imgs
152
 
153
+ def _post_process_output(self, outputs: np.ndarray, x_scale: float, y_scale: float,
154
+ conf_thresh: float = 0.6, nms_thresh: float = 0.55) -> List[List[Tuple]]:
155
+ """
156
+ Post-process ONNX model outputs to get detections.
 
 
 
 
157
 
158
+ Args:
159
+        outputs: Raw ONNX model outputs
+        x_scale: X-axis scaling factor
+        y_scale: Y-axis scaling factor
+        conf_thresh: Confidence threshold
+        nms_thresh: NMS threshold
+
+        Returns:
+            List of detections for each frame: [(box, conf, class_id), ...]
+        """
+        B, C, N = outputs.shape
+        outputs = torch.from_numpy(outputs)
+        outputs = outputs.permute(0, 2, 1)  # B,C,N -> B,N,C
+
+        boxes = outputs[..., :4]
+        class_scores = 1 / (1 + torch.exp(-outputs[..., 4:]))  # Sigmoid activation
+        conf, class_id = class_scores.max(dim=2)

+        mask = conf > conf_thresh
+
+        # Special handling for balls - keep best one even with lower confidence
+        for i in range(class_id.shape[0]):  # loop over batch
+            # Find detections that are balls
+            ball_mask = class_id[i] == 0
+            ball_idx = ball_mask.nonzero(as_tuple=True)[0]
+            if ball_idx.numel() > 0:
+                # Pick the one with the highest confidence
+                best_ball_idx = ball_idx[conf[i, ball_idx].argmax()]
+                if conf[i, best_ball_idx] >= 0.55:  # apply confidence threshold
+                    mask[i, best_ball_idx] = True
+
+        batch_idx, pred_idx = mask.nonzero(as_tuple=True)

+        if len(batch_idx) == 0:
+            return [[] for _ in range(B)]
+
+        boxes = boxes[batch_idx, pred_idx]
+        conf = conf[batch_idx, pred_idx]
+        class_id = class_id[batch_idx, pred_idx]
+
+        # Convert from center format to xyxy format
+        x, y, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
+        x1 = (x - w / 2) * x_scale
+        y1 = (y - h / 2) * y_scale
+        x2 = (x + w / 2) * x_scale
+        y2 = (y + h / 2) * y_scale
+        boxes_xyxy = torch.stack([x1, y1, x2, y2], dim=1)
+
+        # Apply batched NMS
+        max_coord = 1e4
+        offset = batch_idx.to(boxes_xyxy) * max_coord
+        boxes_for_nms = boxes_xyxy + offset[:, None]
+
+        keep = batched_nms(boxes_for_nms, conf, batch_idx, nms_thresh)
+
+        boxes_final = boxes_xyxy[keep]
+        conf_final = conf[keep]
+        class_final = class_id[keep]
+        batch_final = batch_idx[keep]
+
+        # Group results by batch
+        results = [[] for _ in range(B)]
+        for b in range(B):
+            mask_b = batch_final == b
+            if mask_b.sum() == 0:
                continue
+            results[b] = list(zip(boxes_final[mask_b].numpy(),
+                                  conf_final[mask_b].numpy(),
+                                  class_final[mask_b].numpy()))
+        return results
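
The per-frame offset added before NMS lets one NMS call serve the whole batch: boxes from different frames are shifted far apart so they can never overlap. A minimal NumPy sketch of the same trick (greedy NMS written by hand for illustration; the diff itself uses `torchvision.ops.batched_nms`):

```python
import numpy as np

def nms(boxes, scores, thresh):
    """Greedy NMS on [x1, y1, x2, y2] boxes; returns kept indices by descending score."""
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        # Intersection of the top-scoring box with all remaining boxes
        xx1 = np.maximum(boxes[i, 0], boxes[rest, 0])
        yy1 = np.maximum(boxes[i, 1], boxes[rest, 1])
        xx2 = np.minimum(boxes[i, 2], boxes[rest, 2])
        yy2 = np.minimum(boxes[i, 3], boxes[rest, 3])
        inter = np.clip(xx2 - xx1, 0, None) * np.clip(yy2 - yy1, 0, None)
        area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        area_r = (boxes[rest, 2] - boxes[rest, 0]) * (boxes[rest, 3] - boxes[rest, 1])
        iou = inter / (area_i + area_r - inter)
        order = rest[iou <= thresh]
    return np.array(keep)

def batched_nms_offset(boxes, scores, batch_idx, thresh, max_coord=1e4):
    """One NMS call for a whole batch via per-frame coordinate offsets."""
    shifted = boxes + (batch_idx.astype(float) * max_coord)[:, None]
    return nms(shifted, scores, thresh)
```

Identical boxes from different frames survive the offset variant but collapse under plain NMS, which is exactly the behavior the batched post-processing relies on.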
+
+    def _ioa(self, a: BoundingBox, b: BoundingBox) -> float:
+        inter = self._intersect_area(a, b)
+        aa = self._area(a)
+        if aa <= 0:
+            return 0.0
+        return inter / aa
+
+    def suppress_small_contained(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
+        if len(boxes) <= 1:
+            return boxes
+        keep = [True] * len(boxes)
+        areas = [self._area(bb) for bb in boxes]
+        for i in range(len(boxes)):
+            if not keep[i]:
+                continue
+            for j in range(len(boxes)):
+                if i == j or not keep[j]:
                    continue
+                ai, aj = areas[i], areas[j]
+                if ai == 0 or aj == 0:
                    continue
+                if ai <= aj:
+                    ratio = ai / aj
+                    if ratio <= self.SMALL_RATIO_MAX:
+                        ioa_i_in_j = self._ioa(boxes[i], boxes[j])
+                        if ioa_i_in_j >= self.SMALL_CONTAINED_IOA:
+                            keep[i] = False
                            break
+                else:
+                    ratio = aj / ai
+                    if ratio <= self.SMALL_RATIO_MAX:
+                        ioa_j_in_i = self._ioa(boxes[j], boxes[i])
+                        if ioa_j_in_i >= self.SMALL_CONTAINED_IOA:
+                            keep[j] = False
+        return [bb for bb, k in zip(boxes, keep) if k]
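
The containment test above uses intersection-over-area (IoA) rather than IoU: a box is dropped when it is much smaller than another box (area ratio at most `SMALL_RATIO_MAX`) and mostly inside it (IoA at least `SMALL_CONTAINED_IOA`). A standalone sketch on plain `(x1, y1, x2, y2)` tuples, with illustrative threshold values (the class constants are defined elsewhere in the diff):

```python
def area(b):
    return max(0, b[2] - b[0]) * max(0, b[3] - b[1])

def ioa(a, b):
    """Intersection area of a and b, divided by the area of a."""
    ix = max(0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0, min(a[3], b[3]) - max(a[1], b[1]))
    return (ix * iy) / area(a) if area(a) > 0 else 0.0

def suppress_small_contained(boxes, ioa_thresh=0.8, ratio_max=0.5):
    """Drop boxes that are small relative to, and mostly inside, another kept box."""
    keep = [True] * len(boxes)
    for i, bi in enumerate(boxes):
        for j, bj in enumerate(boxes):
            if i == j or not keep[i] or not keep[j]:
                continue
            if area(bi) <= ratio_max * area(bj) and ioa(bi, bj) >= ioa_thresh:
                keep[i] = False
    return [b for b, k in zip(boxes, keep) if k]
```

A small box fully inside a much larger one is removed, while non-overlapping boxes are untouched — useful for duplicate detections of the same player at different scales.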
+
+    def _detect_objects_batch(self, batch_images: List[ndarray], offset: int) -> Dict[int, List[BoundingBox]]:
+        """
+        Phase 1: Object detection for all frames in batch.
+        Returns detected objects with players still having class_id=2 (before team classification).
+
+        Args:
+            batch_images: List of images to process
+            offset: Frame offset for numbering
+
+        Returns:
+            Dictionary mapping frame_id to list of detected boxes
+        """
+        bboxes: Dict[int, List[BoundingBox]] = {}

+        if len(batch_images) == 0:
+            return bboxes
+
+        print(f"Processing batch of {len(batch_images)} images")
+
+        # Get original image dimensions for scaling
+        height, width = batch_images[0].shape[:2]
+        scale = 640.0
+        x_scale = width / scale
+        y_scale = height / scale
+
+        # Memory optimization: Process smaller batches if needed
+        max_batch_size = 32  # Reduce batch size further to prevent memory issues
+        if len(batch_images) > max_batch_size:
+            print(f"Large batch detected ({len(batch_images)} images), splitting into smaller batches of {max_batch_size}")
+            # Process in smaller chunks
+            all_bboxes = {}
+            for chunk_start in range(0, len(batch_images), max_batch_size):
+                chunk_end = min(chunk_start + max_batch_size, len(batch_images))
+                chunk_images = batch_images[chunk_start:chunk_end]
+                chunk_offset = offset + chunk_start
+                print(f"Processing chunk {chunk_start//max_batch_size + 1}: images {chunk_start}-{chunk_end-1}")
+                chunk_bboxes = self._detect_objects_batch(chunk_images, chunk_offset)
+                all_bboxes.update(chunk_bboxes)
+            return all_bboxes
+
+        # Preprocess images for ONNX inference
+        imgs = self._pre_process_img(batch_images, scale)
+        actual_batch_size = len(batch_images)
+
+        # Handle batch size mismatch - pad if needed
+        model_batch_size = self.bbox_model.get_inputs()[0].shape[0]
+        print(f"Model input shape: {self.bbox_model.get_inputs()[0].shape}, batch_size: {model_batch_size}")
+
+        if model_batch_size is not None:
+            try:
+                # Handle dynamic batch size (None, -1, 'None')
+                if str(model_batch_size) in ['None', '-1'] or model_batch_size == -1:
+                    model_batch_size = None
+                else:
+                    model_batch_size = int(model_batch_size)
+            except (ValueError, TypeError):
+                model_batch_size = None
+
+        print(f"Processed model_batch_size: {model_batch_size}, actual_batch_size: {actual_batch_size}")
+
+        if model_batch_size and actual_batch_size < model_batch_size:
+            padding_size = model_batch_size - actual_batch_size
+            dummy_img = np.zeros((1, 3, int(scale), int(scale)), dtype=np.float32)
+            padding = np.repeat(dummy_img, padding_size, axis=0)
+            imgs = np.vstack([imgs, padding])
+
+        # ONNX inference with error handling
+        try:
+            input_name = self.bbox_model.get_inputs()[0].name
+            import time
+            start_time = time.time()
+            outputs = self.bbox_model.run(None, {input_name: imgs})[0]
+            inference_time = time.time() - start_time
+            print(f"Inference time: {inference_time:.3f}s for {actual_batch_size} images")
+
+            # Remove padded results if we added padding
+            if model_batch_size and isinstance(model_batch_size, int) and actual_batch_size < model_batch_size:
+                outputs = outputs[:actual_batch_size]
+
+            # Post-process outputs to get detections
+            raw_results = self._post_process_output(np.array(outputs), x_scale, y_scale)
+
+        except Exception as e:
+            print(f"Error during ONNX inference: {e}")
+            return bboxes
+
+        if not raw_results:
+            return bboxes
+
+        # Convert raw results to BoundingBox objects and apply processing
+        for frame_idx_in_batch, frame_detections in enumerate(raw_results):
+            if not frame_detections:
+                continue
+
+            # Convert to BoundingBox objects
+            boxes: List[BoundingBox] = []
+            for box, conf, cls_id in frame_detections:
+                x1, y1, x2, y2 = box
+                if int(cls_id) < 4:
                    boxes.append(
                        BoundingBox(
                            x1=int(x1),
                            y1=int(y1),
                            x2=int(x2),
                            y2=int(y2),
+                            cls_id=int(cls_id),
                            conf=float(conf),
                        )
                    )
+
            # Handle footballs - keep only the best one
            footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
            if len(footballs) > 1:
                best_ball = max(footballs, key=lambda b: b.conf)
                boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
                boxes.append(best_ball)
+            # Remove overlapping small boxes
+            boxes = suppress_small_contained_boxes(boxes, self.SMALL_CONTAINED_IOA, self.SMALL_RATIO_MAX)
+
+            # Handle goalkeeper detection issues:
+            # 1. Fix misplaced goalkeepers (convert to players if standing in middle)
+            # 2. Allow up to 2 goalkeepers maximum (one from each team)
+            # Goalkeepers remain class_id = 1 (no team assignment)
+            boxes = self._handle_multiple_goalkeepers(boxes)
+
+            # Store results (players still have class_id=2, will be classified in phase 2)
+            frame_id = offset + frame_idx_in_batch
+            bboxes[frame_id] = boxes
+
+        return bboxes
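
The pad-then-trim dance for a fixed ONNX input batch can be sketched in isolation; `model_batch_size` here plays the role of the ONNX input's static first dimension:

```python
import numpy as np

def pad_to_model_batch(imgs, model_batch_size):
    """Pad an (N, C, H, W) batch with zero images up to a fixed model batch size.

    Returns the (possibly padded) batch and the original count, so outputs
    for the dummy images can be discarded after inference.
    """
    actual = imgs.shape[0]
    if model_batch_size is None or actual >= model_batch_size:
        return imgs, actual
    pad = np.zeros((model_batch_size - actual, *imgs.shape[1:]), dtype=imgs.dtype)
    return np.vstack([imgs, pad]), actual

# After inference, results for the padded dummies are sliced away:
# outputs = outputs[:actual]
```

A dynamic-batch model (`model_batch_size` of `None`) passes through unchanged, matching the `'None'`/`-1` handling in the diff.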
+
+    def predict_batch(
+        self,
+        batch_images: List[ndarray],
+        offset: int,
+        n_keypoints: int,
+        task_type: Optional[str] = None,
+    ) -> List[TVFrameResult]:
+        process_objects = task_type is None or task_type == "object"
+        process_keypoints = task_type is None or task_type == "keypoint"
+
+        # Phase 1: Object Detection for all frames
+        bboxes: Dict[int, List[BoundingBox]] = {}
+        if process_objects:
+            bboxes = self._detect_objects_batch(batch_images, offset)
+
+        import time
+        time_start = time.time()
+        # Phase 2: Team Classification for all detected players
+        if process_objects and bboxes:
+            bboxes, self.team_classifier_fitted, self.player_crops_for_fit = classify_teams_batch(
+                self.team_classifier,
+                self.team_classifier_fitted,
+                self.player_crops_for_fit,
+                batch_images,
+                bboxes,
+                offset,
+                self.MIN_SAMPLES_FOR_FIT,
+                self.MAX_SAMPLES_FOR_FIT,
+                self.SINGLE_PLAYER_HUE_PIVOT
            )
+            self.team_classifier_fitted = False
+            self.player_crops_for_fit = []
+            print(f"Time Team Classification: {time.time() - time_start} s")
+        # Phase 3: Keypoint Detection
+        keypoints: Dict[int, List[Tuple[int, int]]] = {}
+        if process_keypoints:
+            keypoints = self._detect_keypoints_batch(batch_images, offset, n_keypoints)
+
+        # Phase 4: Combine results
        results: List[TVFrameResult] = []
        for frame_number in range(offset, offset + len(batch_images)):
+            results.append(
+                TVFrameResult(
+                    frame_id=frame_number,
+                    boxes=bboxes.get(frame_number, []),
+                    keypoints=keypoints.get(
+                        frame_number,
+                        [(0, 0) for _ in range(n_keypoints)],
+                    ),
+                )
            )
+        return results
+    def _detect_keypoints_batch(self, batch_images: List[ndarray],
+                                offset: int, n_keypoints: int) -> Dict[int, List[Tuple[int, int]]]:
+        """
+        Phase 3: Keypoint detection for all frames in batch.
+
+        Args:
+            batch_images: List of images to process
+            offset: Frame offset for numbering
+            n_keypoints: Number of keypoints expected
+
+        Returns:
+            Dictionary mapping frame_id to list of keypoint coordinates
+        """
+        keypoints: Dict[int, List[Tuple[int, int]]] = {}
+        keypoints_model_results = self.keypoints_model.predict(batch_images)
+
+        if keypoints_model_results is None:
+            return keypoints
+
+        for frame_idx_in_batch, detection in enumerate(keypoints_model_results):
+            if not hasattr(detection, "keypoints") or detection.keypoints is None:
+                continue
+
+            # Extract keypoints with confidence
+            frame_keypoints_with_conf: List[Tuple[int, int, float]] = []
+            for i, part_points in enumerate(detection.keypoints.data):
+                for k_id, (x, y, _) in enumerate(part_points):
+                    confidence = float(detection.keypoints.conf[i][k_id])
+                    frame_keypoints_with_conf.append((int(x), int(y), confidence))
+
+            # Pad or truncate to expected number of keypoints
+            if len(frame_keypoints_with_conf) < n_keypoints:
+                frame_keypoints_with_conf.extend(
+                    [(0, 0, 0.0)] * (n_keypoints - len(frame_keypoints_with_conf))
+                )
+            else:
+                frame_keypoints_with_conf = frame_keypoints_with_conf[:n_keypoints]
+
+            # Filter keypoints based on confidence thresholds
+            filtered_keypoints: List[Tuple[int, int]] = []
+            for idx, (x, y, confidence) in enumerate(frame_keypoints_with_conf):
+                if idx in self.CORNER_INDICES:
+                    # Corner keypoints have lower confidence threshold
+                    if confidence < 0.3:
+                        filtered_keypoints.append((0, 0))
+                    else:
+                        filtered_keypoints.append((int(x), int(y)))
+                else:
+                    # Regular keypoints
+                    if confidence < 0.5:
+                        filtered_keypoints.append((0, 0))
+                    else:
+                        filtered_keypoints.append((int(x), int(y)))
+
+            frame_id = offset + frame_idx_in_batch
+            keypoints[frame_id] = filtered_keypoints
+
+        return keypoints
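
The pad/truncate-then-threshold step can be isolated: low-confidence points collapse to the `(0, 0)` sentinel, with a looser threshold for corner indices. A sketch with hypothetical corner indices (the real `CORNER_INDICES` set is defined elsewhere in the class):

```python
def filter_keypoints(kps_with_conf, n_keypoints, corner_indices=frozenset({0, 1}),
                     corner_thresh=0.3, regular_thresh=0.5):
    """Pad/truncate to n_keypoints, then zero out low-confidence points."""
    kps = list(kps_with_conf)[:n_keypoints]
    kps += [(0, 0, 0.0)] * (n_keypoints - len(kps))  # pad with zero-confidence sentinels
    out = []
    for idx, (x, y, conf) in enumerate(kps):
        thresh = corner_thresh if idx in corner_indices else regular_thresh
        out.append((int(x), int(y)) if conf >= thresh else (0, 0))
    return out
```

Padded slots carry confidence `0.0`, so they always fall below both thresholds and come out as `(0, 0)` — the same sentinel `predict_batch` uses for frames with no keypoint result at all.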
object-detection.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:05112479be8cb59494e9ae23a57af43becd5aa1f448b0e5ed33fcb6b4c2bbbc3
+ size 273322667
pitch.py CHANGED
@@ -520,7 +520,7 @@ def run_inference(model, input_tensor: torch.Tensor, device):
     output = model.module().forward(input_tensor)
     return output
 
-def preprocess_batch_fast(frames):
+def preprocess_batch_fast(frames, device):
     """Ultra-fast batch preprocessing using optimized tensor operations"""
     target_size = (540, 960)  # H, W format for model input
     batch = []
@@ -530,7 +530,7 @@ def preprocess_batch_fast(frames):
         img = img.astype(np.float32) / 255.0
         img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
         batch.append(img)
-    batch = torch.from_numpy(np.stack(batch)).float()
+    batch = torch.tensor(np.stack(batch), dtype=torch.float32)
 
     return batch
 
@@ -610,24 +610,16 @@ def inference_batch(frames, model, kp_threshold, device, batch_size=8):
     results = []
     num_frames = len(frames)
 
-    # Get the device from the model itself
-    model_device = next(model.parameters()).device
-
     # Process all frames in optimally-sized batches
     for i in range(0, num_frames, batch_size):
         current_batch_size = min(batch_size, num_frames - i)
         batch_frames = frames[i:i + current_batch_size]
 
-        # Fast preprocessing - create on CPU first
-        batch = preprocess_batch_fast(batch_frames)
-        b, c, h, w = batch.size()
-
-        # Move batch to model device
-        batch = batch.to(model_device)
-
-        with torch.no_grad():
-            heatmaps = model(batch)
+        # Fast preprocessing
+        batch = preprocess_batch_fast(batch_frames, device)
+
+        heatmaps = run_inference(model, batch, device)
+
         # Ultra-fast keypoint extraction
         kp_coords = extract_keypoints_from_heatmap_fast(heatmaps[:,:-1,:,:], scale=2, max_keypoints=1)
 
@@ -660,10 +652,28 @@ def get_mapped_keypoints(kp_points):
     # mapped_points[key] = value
     return mapped_points
 
-def process_batch_input(frames, model, kp_threshold, device, batch_size=16):
+def process_batch_input(frames, model, kp_threshold, device, batch_size=8):
     """Process multiple input images in batch"""
     # Batch inference
    kp_results = inference_batch(frames, model, kp_threshold, device, batch_size)
    kp_results = [get_mapped_keypoints(kp) for kp in kp_results]
+    # Draw results and save
+    # for i, (frame, kp_points, input_path) in enumerate(zip(frames, kp_results, valid_paths)):
+    #     height, width = frame.shape[:2]
+
+    #     # Apply mapping to get standard keypoint IDs
+    #     mapped_kp_points = get_mapped_keypoints(kp_points)
+
+    #     for key, value in mapped_kp_points.items():
+    #         x = int(value['x'] * width)
+    #         y = int(value['y'] * height)
+    #         cv2.circle(frame, (x, y), 5, (0, 255, 0), -1)  # Green circles
+    #         cv2.putText(frame, str(key), (x+10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
+
+    #     # Save result
+    #     output_path = input_path.replace('.png', '_result.png').replace('.jpg', '_result.jpg')
+    #     cv2.imwrite(output_path, frame)
+
+    # print(f"Batch processing complete. Processed {len(frames)} images.")
 
    return kp_results
player.py ADDED
@@ -0,0 +1,388 @@
+import cv2
+import numpy as np
+from sklearn.cluster import KMeans
+import warnings
+import time
+
+import torch
+from torchvision.ops import batched_nms
+from numpy import ndarray
+
+# Suppress ALL runtime and sklearn warnings
+warnings.filterwarnings('ignore', category=RuntimeWarning)
+warnings.filterwarnings('ignore', category=FutureWarning)
+warnings.filterwarnings('ignore', category=UserWarning)
+
+# Suppress sklearn warnings specifically
+import logging
+logging.getLogger('sklearn').setLevel(logging.ERROR)
+
+def get_grass_color(img):
+    # Convert image to HSV color space
+    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
+
+    # Define range of green color in HSV
+    lower_green = np.array([30, 40, 40])
+    upper_green = np.array([80, 255, 255])
+
+    # Threshold the HSV image to get only green colors
+    mask = cv2.inRange(hsv, lower_green, upper_green)
+
+    # Calculate the mean value of the pixels inside the green mask
+    masked_img = cv2.bitwise_and(img, img, mask=mask)
+    grass_color = cv2.mean(img, mask=mask)
+    return grass_color[:3]
+def get_players_boxes(frame, result):
+    players_imgs = []
+    players_boxes = []
+    for (box, score, cls) in result:
+        label = int(cls)
+        if label == 0:
+            x1, y1, x2, y2 = box.astype(int)
+            player_img = frame[y1: y2, x1: x2]
+            players_imgs.append(player_img)
+            players_boxes.append([box, score, cls])
+    return players_imgs, players_boxes
+
+def get_kits_colors(players, grass_hsv=None, frame=None):
+    kits_colors = []
+    if grass_hsv is None:
+        grass_color = get_grass_color(frame)
+        grass_hsv = cv2.cvtColor(np.uint8([[list(grass_color)]]), cv2.COLOR_BGR2HSV)
+
+    for player_img in players:
+        # Skip empty or invalid images
+        if player_img is None or player_img.size == 0 or len(player_img.shape) != 3:
+            continue
+
+        # Convert image to HSV color space
+        hsv = cv2.cvtColor(player_img, cv2.COLOR_BGR2HSV)
+
+        # Define range of green color in HSV around the grass hue
+        lower_green = np.array([grass_hsv[0, 0, 0] - 10, 40, 40])
+        upper_green = np.array([grass_hsv[0, 0, 0] + 10, 255, 255])
+
+        # Threshold the HSV image to get only green colors
+        mask = cv2.inRange(hsv, lower_green, upper_green)
+
+        # Invert the grass mask and keep only the upper half of the crop (the shirt)
+        mask = cv2.bitwise_not(mask)
+        upper_mask = np.zeros(player_img.shape[:2], np.uint8)
+        upper_mask[0:player_img.shape[0]//2, 0:player_img.shape[1]] = 255
+        mask = cv2.bitwise_and(mask, upper_mask)
+
+        kit_color = np.array(cv2.mean(player_img, mask=mask)[:3])
+
+        kits_colors.append(kit_color)
+    return kits_colors
+
+def get_kits_classifier(kits_colors):
+    if len(kits_colors) == 0:
+        return None
+    if len(kits_colors) == 1:
+        # Only one kit color, cannot fit a two-cluster classifier
+        return None
+    kits_kmeans = KMeans(n_clusters=2)
+    kits_kmeans.fit(kits_colors)
+    return kits_kmeans
+
+def classify_kits(kits_classifier, kits_colors):
+    if kits_classifier is None or len(kits_colors) == 0:
+        return np.array([0])  # Default to team 0
+    team = kits_classifier.predict(kits_colors)
+    return team
+
+def get_left_team_label(players_boxes, kits_colors, kits_clf):
+    left_team_label = 0
+    team_0 = []
+    team_1 = []
+
+    for i in range(len(players_boxes)):
+        x1, y1, x2, y2 = players_boxes[i][0].astype(int)
+        team = classify_kits(kits_clf, [kits_colors[i]]).item()
+        if team == 0:
+            team_0.append(np.array([x1]))
+        else:
+            team_1.append(np.array([x1]))
+
+    team_0 = np.array(team_0)
+    team_1 = np.array(team_1)
+
+    # Safely calculate averages with fallback for empty arrays
+    avg_team_0 = np.average(team_0) if len(team_0) > 0 else 0
+    avg_team_1 = np.average(team_1) if len(team_1) > 0 else 0
+
+    if avg_team_0 - avg_team_1 > 0:
+        left_team_label = 1
+
+    return left_team_label
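
Team assignment here is just a 2-cluster k-means over mean shirt colors. A minimal sketch with synthetic BGR kit colors (the values are invented for illustration; `n_init` and `random_state` are set only to keep the toy example deterministic):

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical mean kit colors (BGR): a dark team and a bright team
kit_colors = np.array([
    [20, 20, 30], [25, 18, 28], [22, 24, 33],            # dark kits
    [240, 235, 230], [250, 240, 245], [235, 228, 240],   # bright kits
], dtype=float)

clf = KMeans(n_clusters=2, n_init=10, random_state=0).fit(kit_colors)
teams = clf.predict(kit_colors)  # cluster label per player crop
```

Which cluster gets label 0 is arbitrary, which is exactly why `get_left_team_label` follows up by comparing the mean x-position of the two clusters.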
+def check_box_boundaries(boxes, img_height, img_width):
+    """
+    Check if bounding boxes are within image boundaries and clip them if necessary.
+
+    Args:
+        boxes: numpy array of shape (N, 4) with [x1, y1, x2, y2] format
+        img_height: height of the image
+        img_width: width of the image
+
+    Returns:
+        valid_boxes: numpy array of valid boxes within boundaries
+        valid_indices: indices of valid boxes
+    """
+    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
+
+    # Check if boxes are within boundaries
+    valid_mask = (x1 >= 0) & (y1 >= 0) & (x2 < img_width) & (y2 < img_height) & (x1 < x2) & (y1 < y2)
+
+    if not np.any(valid_mask):
+        return np.array([]), np.array([])
+
+    valid_boxes = boxes[valid_mask]
+    valid_indices = np.where(valid_mask)[0]
+
+    # Clip boxes to image boundaries
+    valid_boxes[:, 0] = np.clip(valid_boxes[:, 0], 0, img_width - 1)   # x1
+    valid_boxes[:, 1] = np.clip(valid_boxes[:, 1], 0, img_height - 1)  # y1
+    valid_boxes[:, 2] = np.clip(valid_boxes[:, 2], 0, img_width - 1)   # x2
+    valid_boxes[:, 3] = np.clip(valid_boxes[:, 3], 0, img_height - 1)  # y2
+
+    return valid_boxes, valid_indices
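
A quick check of the boundary filter's behavior, re-declaring a trimmed-down version of the function so the snippet runs standalone (the clipping step is omitted here since the mask already rejects out-of-bounds boxes):

```python
import numpy as np

def check_box_boundaries(boxes, img_height, img_width):
    """Keep boxes fully inside the image with positive area; return them plus their indices."""
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    valid = (x1 >= 0) & (y1 >= 0) & (x2 < img_width) & (y2 < img_height) & (x1 < x2) & (y1 < y2)
    if not np.any(valid):
        return np.array([]), np.array([])
    return boxes[valid], np.where(valid)[0]

boxes = np.array([
    [10, 10, 50, 50],   # fully inside -> kept
    [-5, 10, 50, 50],   # negative x1 -> dropped
    [10, 10, 700, 50],  # x2 beyond width -> dropped
    [30, 30, 20, 40],   # x1 >= x2 -> dropped
])
valid_boxes, valid_idx = check_box_boundaries(boxes, img_height=480, img_width=640)
```

Returning the surviving indices alongside the boxes is what lets the caller map a clipped box back to its original detection.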
+def process_team_identification_batch(frames, results, kits_clf, left_team_label, grass_hsv):
+    """
+    Process team identification and label formatting for batch results.
+
+    Args:
+        frames: list of frames
+        results: list of detection results for each frame
+        kits_clf: trained kit classifier
+        left_team_label: label for left team
+        grass_hsv: grass color in HSV format
+
+    Returns:
+        processed_results: list of processed results with team identification
+    """
+    processed_results = []
+
+    for frame_idx, frame in enumerate(frames):
+        frame_results = []
+        frame_detections = results[frame_idx]
+
+        if not frame_detections:
+            processed_results.append([])
+            continue
+
+        # Extract player boxes and images
+        players_imgs = []
+        players_boxes = []
+        player_indices = []
+
+        for idx, (box, score, cls) in enumerate(frame_detections):
+            label = int(cls)
+            if label == 0:  # Player detection
+                x1, y1, x2, y2 = box.astype(int)
+
+                # Check boundaries
+                if (x1 >= 0 and y1 >= 0 and x2 < frame.shape[1] and y2 < frame.shape[0] and x1 < x2 and y1 < y2):
+                    player_img = frame[y1:y2, x1:x2]
+                    if player_img.size > 0:  # Ensure valid image
+                        players_imgs.append(player_img)
+                        players_boxes.append([box, score, cls])
+                        player_indices.append(idx)
+
+        # Initialize player team mapping
+        player_team_map = {}
+
+        # Process team identification if we have players
+        if players_imgs and kits_clf is not None:
+            kits_colors = get_kits_colors(players_imgs, grass_hsv)
+            teams = classify_kits(kits_clf, kits_colors)
+
+            # Create mapping from player index to team
+            for i, team in enumerate(teams):
+                player_team_map[player_indices[i]] = team.item()
+
+        id = 0
+        # Process all detections with team identification
+        for idx, (box, score, cls) in enumerate(frame_detections):
+            label = int(cls)
+            x1, y1, x2, y2 = box.astype(int)
+
+            # Check boundaries
+            valid_boxes, valid_indices = check_box_boundaries(
+                np.array([[x1, y1, x2, y2]]), frame.shape[0], frame.shape[1]
+            )
+
+            if len(valid_boxes) == 0:
+                continue
+
+            x1, y1, x2, y2 = valid_boxes[0].astype(int)
+
+            # Apply team identification logic
+            if label == 0:  # Player
+                if players_imgs and kits_clf is not None and idx in player_team_map:
+                    team = player_team_map[idx]
+                    if team == left_team_label:
+                        final_label = 6  # Player-L (Left team)
+                    else:
+                        final_label = 7  # Player-R (Right team)
+                else:
+                    final_label = 6  # Default player label
+
+            elif label == 1:  # Goalkeeper
+                final_label = 1  # GK
+
+            elif label == 2:  # Ball
+                final_label = 0  # Ball
+
+            elif label == 3 or label == 4:  # Referee or other
+                final_label = 3  # Referee
+
+            else:
+                final_label = int(label)  # Keep original label, ensure it's int
+
+            frame_results.append({
+                "id": int(id),
+                "bbox": [int(x1), int(y1), int(x2), int(y2)],
+                "class_id": int(final_label),
+                "conf": float(score)
+            })
+            id = id + 1
+
+        processed_results.append(frame_results)
+
+    return processed_results
+
+def convert_numpy_types(obj):
+    """Convert numpy types to native Python types for JSON serialization."""
+    if isinstance(obj, np.integer):
+        return int(obj)
+    elif isinstance(obj, np.floating):
+        return float(obj)
+    elif isinstance(obj, np.ndarray):
+        return obj.tolist()
+    elif isinstance(obj, dict):
+        return {key: convert_numpy_types(value) for key, value in obj.items()}
+    elif isinstance(obj, list):
+        return [convert_numpy_types(item) for item in obj]
+    else:
+        return obj
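
The recursive conversion matters because `json.dumps` rejects NumPy scalars and arrays. A standalone copy of the helper shows the round trip on a detection-shaped dict:

```python
import json
import numpy as np

def convert_numpy_types(obj):
    """Recursively convert numpy scalars/arrays to plain Python types."""
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {k: convert_numpy_types(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_types(v) for v in obj]
    return obj

result = {"class_id": np.int64(6), "conf": np.float32(0.875), "bbox": np.array([1, 2, 3, 4])}
clean = convert_numpy_types(result)
text = json.dumps(clean)  # would raise TypeError on the raw `result`
```

The order of the `isinstance` checks is deliberate: container types are descended into only after scalar and array cases are handled.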
+def pre_process_img(frames, scale):
+    imgs = np.stack([cv2.resize(frame, (int(scale), int(scale))) for frame in frames])
+    imgs = imgs.transpose(0, 3, 1, 2)  # NHWC -> NCHW
+    imgs = imgs.astype(np.float32) / 255.0  # Normalize
+    return imgs
+
+def post_process_output(outputs, x_scale, y_scale, conf_thresh=0.6, nms_thresh=0.75):
+    B, C, N = outputs.shape
+    outputs = torch.from_numpy(outputs)
+    outputs = outputs.permute(0, 2, 1)
+    boxes = outputs[..., :4]
+    class_scores = 1 / (1 + torch.exp(-outputs[..., 4:]))
+    conf, class_id = class_scores.max(dim=2)
+
+    mask = conf > conf_thresh
+
+    for i in range(class_id.shape[0]):  # loop over batch
+        # Find detections that are balls
+        ball_idx = np.where(class_id[i] == 2)[0]
+        if ball_idx.size > 0:
+            # Pick the one with the highest confidence
+            top = ball_idx[np.argmax(conf[i, ball_idx])]
+            if conf[i, top] > 0.55:  # apply confidence threshold
+                mask[i, top] = True
+
+    # ball_mask = (class_id == 2) & (conf > 0.51)
+    # mask = mask | ball_mask
+
+    batch_idx, pred_idx = mask.nonzero(as_tuple=True)
+
+    if len(batch_idx) == 0:
+        return [[] for _ in range(B)]
+
+    boxes = boxes[batch_idx, pred_idx]
+    conf = conf[batch_idx, pred_idx]
+    class_id = class_id[batch_idx, pred_idx]
+
+    x, y, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
+    x1 = (x - w / 2) * x_scale
+    y1 = (y - h / 2) * y_scale
+    x2 = (x + w / 2) * x_scale
+    y2 = (y + h / 2) * y_scale
+    boxes_xyxy = torch.stack([x1, y1, x2, y2], dim=1)
+
+    max_coord = 1e4
+    offset = batch_idx.to(boxes_xyxy) * max_coord
+    boxes_for_nms = boxes_xyxy + offset[:, None]
+
+    keep = batched_nms(boxes_for_nms, conf, batch_idx, nms_thresh)
+
+    boxes_final = boxes_xyxy[keep]
+    conf_final = conf[keep]
+    class_final = class_id[keep]
+    batch_final = batch_idx[keep]
+
+    results = [[] for _ in range(B)]
+    for b in range(B):
+        mask_b = batch_final == b
+        if mask_b.sum() == 0:
+            continue
+        results[b] = list(zip(boxes_final[mask_b].numpy(),
+                              conf_final[mask_b].numpy(),
+                              class_final[mask_b].numpy()))
+    return results
+
+def player_detection_result(frames: list[ndarray], batch_size, model, kits_clf=None, left_team_label=None, grass_hsv=None):
+    start_time = time.time()
+    # input_layer = model.input(0)
+    # output_layer = model.output(0)
+    height, width = frames[0].shape[:2]
+    scale = 640.0
+    x_scale = width / scale
+    y_scale = height / scale
+
+    # infer_queue = AsyncInferQueue(model, len(frames))
+
+    infer_time = time.time()
+    kits_clf = kits_clf
+    left_team_label = left_team_label
+    grass_hsv = grass_hsv
+    results = []
+    for i in range(0, len(frames), batch_size):
+        if i + batch_size > len(frames):
+            batch_size = len(frames) - i
+        batch_frames = frames[i:i + batch_size]
+        imgs = pre_process_img(batch_frames, scale)
+
+        input_name = model.get_inputs()[0].name
+        outputs = model.run(None, {input_name: imgs})[0]
+        raw_results = post_process_output(np.array(outputs), x_scale, y_scale)
+
+        if kits_clf is None or left_team_label is None or grass_hsv is None:
+            # Use first frame to initialize team classification
+            first_frame = batch_frames[0]
+            first_frame_results = raw_results[0] if raw_results else []
+
+            if first_frame_results:
+                players_imgs, players_boxes = get_players_boxes(first_frame, first_frame_results)
+                if players_imgs:
+                    grass_color = get_grass_color(first_frame)
+                    grass_hsv = cv2.cvtColor(np.uint8([[list(grass_color)]]), cv2.COLOR_BGR2HSV)
+                    kits_colors = get_kits_colors(players_imgs, grass_hsv)
+                    if kits_colors:  # Only proceed if we have valid kit colors
+                        kits_clf = get_kits_classifier(kits_colors)
+                        if kits_clf is not None:
+                            left_team_label = int(get_left_team_label(players_boxes, kits_colors, kits_clf))
+
+        # Process team identification and boundary checking
+        processed_results = process_team_identification_batch(
+            batch_frames, raw_results, kits_clf, left_team_label, grass_hsv
+        )
+
+        processed_results = convert_numpy_types(processed_results)
+        results.extend(processed_results)
+
+    # Return the same format as before for compatibility
+    return results, kits_clf, left_team_label, grass_hsv