update
Browse files- .gitattributes +1 -1
- .gitignore +5 -0
- keypoint → 20251029-detection.pt +2 -2
- football_keypoints_detection.pt → 20251029-keypoint.pt +0 -0
- README.md +132 -0
- SV_kp.engine +3 -0
- config.yml +2 -3
- football_object_detection.onnx → detection.onnx +0 -0
- detection.pt +3 -0
- hrnetv2_w48.yaml +0 -35
- keypoint.pt +3 -0
- miner.py +439 -359
- object-detection.onnx +3 -0
- pitch.py +25 -15
- player.py +388 -0
.gitattributes
CHANGED
|
@@ -33,5 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
-
|
| 37 |
osnet_model.pth.tar-100 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
SV_kp.engine filter=lfs diff=lfs merge=lfs -text
|
| 37 |
osnet_model.pth.tar-100 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
venv
|
| 2 |
+
outputs
|
| 3 |
+
test_predict_batch.py
|
| 4 |
+
test.mp4
|
| 5 |
+
inspect_yolo_model.py
|
keypoint → 20251029-detection.pt
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8bbacfcb38e38b1b8816788e9e6e845160533719a0b87b693d58b932380d0d28
|
| 3 |
+
size 152961687
|
football_keypoints_detection.pt → 20251029-keypoint.pt
RENAMED
|
File without changes
|
README.md
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
🚀 Example Chute for Turbovision 🪂
|
| 2 |
+
|
| 3 |
+
This repository demonstrates how to deploy a Chute via the Turbovision CLI, hosted on Hugging Face Hub. It serves as a minimal example showcasing the required structure and workflow for integrating machine learning models, preprocessing, and orchestration into a reproducible Chute environment.
|
| 4 |
+
|
| 5 |
+
## Repository Structure
|
| 6 |
+
|
| 7 |
+
The following two files must be present (in their current locations) for a successful deployment — their content can be modified as needed:
|
| 8 |
+
|
| 9 |
+
| File | Purpose |
|
| 10 |
+
|------|---------|
|
| 11 |
+
| `miner.py` | Defines the ML model type(s), orchestration, and all pre/postprocessing logic. |
|
| 12 |
+
| `config.yml` | Specifies machine configuration (e.g., GPU type, memory, environment variables). |
|
| 13 |
+
|
| 14 |
+
Other files — e.g., model weights, utility scripts, or dependencies — are optional and can be included as needed for your model.
|
| 15 |
+
|
| 16 |
+
> **Note**: Any required assets must be defined or contained within this repo, which is fully open-source, since all network-related operations (downloading challenge data, weights, etc.) are disabled inside the Chute.
|
| 17 |
+
|
| 18 |
+
## Overview
|
| 19 |
+
|
| 20 |
+
Below is a high-level diagram showing the interaction between Huggingface, Chutes and Turbovision:
|
| 21 |
+
|
| 22 |
+
```
|
| 23 |
+
┌─────────────┐ ┌──────────┐ ┌──────────────┐
|
| 24 |
+
│ HuggingFace │ ───> │ Chutes │ ───> │ Turbovision │
|
| 25 |
+
│ Hub │ │ .ai │ │ Validator │
|
| 26 |
+
└─────────────┘ └──────────┘ └──────────────┘
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
## Local Testing
|
| 30 |
+
|
| 31 |
+
After editing the `config.yml` and `miner.py` and saving it into your Huggingface Repo, you will want to test it works locally.
|
| 32 |
+
|
| 33 |
+
1. **Copy the template file** `scorevision/chute_template/turbovision_chute.py.j2` as a python file called `my_chute.py` and fill in the missing variables:
|
| 34 |
+
|
| 35 |
+
```python
|
| 36 |
+
HF_REPO_NAME = "{{ huggingface_repository_name }}"
|
| 37 |
+
HF_REPO_REVISION = "{{ huggingface_repository_revision }}"
|
| 38 |
+
CHUTES_USERNAME = "{{ chute_username }}"
|
| 39 |
+
CHUTE_NAME = "{{ chute_name }}"
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
2. **Run the following command to build the chute locally** (Caution: there are known issues with the docker location when running this on a mac):
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
chutes build my_chute:chute --local --public
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
3. **Run the name of the docker image just built** (i.e. `CHUTE_NAME`) and enter it:
|
| 49 |
+
|
| 50 |
+
```bash
|
| 51 |
+
docker run -p 8000:8000 -e CHUTES_EXECUTION_CONTEXT=REMOTE -it <image-name> /bin/bash
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
4. **Run the file from within the container**:
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
chutes run my_chute:chute --dev --debug
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
5. **In another terminal, test the local endpoints** to ensure there are no bugs:
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
# Health check
|
| 64 |
+
curl -X POST http://localhost:8000/health -d '{}'
|
| 65 |
+
|
| 66 |
+
# Prediction test
|
| 67 |
+
curl -X POST http://localhost:8000/predict -d '{"url": "https://scoredata.me/2025_03_14/35ae7a/h1_0f2ca0.mp4","meta": {}}'
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
## Live Testing
|
| 71 |
+
|
| 72 |
+
If you have any chute with the same name (i.e. from a previous deployment), ensure you delete that first (or you will get an error when trying to build).
|
| 73 |
+
|
| 74 |
+
1. **List existing chutes**:
|
| 75 |
+
|
| 76 |
+
```bash
|
| 77 |
+
chutes chutes list
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
Take note of the chute id that you wish to delete (if any):
|
| 81 |
+
|
| 82 |
+
```bash
|
| 83 |
+
chutes chutes delete <chute-id>
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
2. **You should also delete its associated image**:
|
| 87 |
+
|
| 88 |
+
```bash
|
| 89 |
+
chutes images list
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
Take note of the chute image id:
|
| 93 |
+
|
| 94 |
+
```bash
|
| 95 |
+
chutes images delete <chute-image-id>
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
3. **Use Turbovision's CLI to build, deploy and commit on-chain**:
|
| 99 |
+
|
| 100 |
+
```bash
|
| 101 |
+
sv -vv push
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
> **Note**: You can skip the on-chain commit using `--no-commit`. You can also specify a past huggingface revision to point to using `--revision` and/or the local files you want to upload to your huggingface repo using `--model-path`.
|
| 105 |
+
|
| 106 |
+
4. **When completed, warm up the chute** (if its cold 🧊):
|
| 107 |
+
|
| 108 |
+
You can confirm its status using `chutes chutes list` or `chutes chutes get <chute-id>` if you already know its id.
|
| 109 |
+
|
| 110 |
+
> **Note**: Warming up can sometimes take a while but if the chute runs without errors (should be if you've tested locally first) and there are sufficient nodes (i.e. machines) available matching the `config.yml` you specified, the chute should become hot 🔥!
|
| 111 |
+
|
| 112 |
+
```bash
|
| 113 |
+
chutes warmup <chute-id>
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
5. **Test the chute's endpoints**:
|
| 117 |
+
|
| 118 |
+
```bash
|
| 119 |
+
# Health check
|
| 120 |
+
curl -X POST https://<YOUR-CHUTE-SLUG>.chutes.ai/health -d '{}' -H "Authorization: Bearer $CHUTES_API_KEY"
|
| 121 |
+
|
| 122 |
+
# Prediction
|
| 123 |
+
curl -X POST https://<YOUR-CHUTE-SLUG>.chutes.ai/predict -d '{"url": "https://scoredata.me/2025_03_14/35ae7a/h1_0f2ca0.mp4","meta": {}}' -H "Authorization: Bearer $CHUTES_API_KEY"
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
6. **Test what your chute would get on a validator**:
|
| 127 |
+
|
| 128 |
+
This also applies any validation/integrity checks which may fail if you did not use the Turbovision CLI above to deploy the chute:
|
| 129 |
+
|
| 130 |
+
```bash
|
| 131 |
+
sv -vv run-once
|
| 132 |
+
```
|
SV_kp.engine
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f99452eb79e064189e2758abd20a78845a5b639fc8b9c4bc650519c83e13e8db
|
| 3 |
+
size 368289641
|
config.yml
CHANGED
|
@@ -2,15 +2,14 @@ Image:
|
|
| 2 |
from_base: parachutes/python:3.12
|
| 3 |
run_command:
|
| 4 |
- pip install --upgrade setuptools wheel
|
| 5 |
-
- pip install
|
| 6 |
-
- pip install "ultralytics==8.3.222" "opencv-python-headless" "numpy" "pydantic"
|
| 7 |
- pip install scikit-learn
|
| 8 |
- pip install onnxruntime-gpu
|
| 9 |
set_workdir: /app
|
| 10 |
|
| 11 |
NodeSelector:
|
| 12 |
gpu_count: 1
|
| 13 |
-
min_vram_gb_per_gpu:
|
| 14 |
exclude:
|
| 15 |
- "5090"
|
| 16 |
- b200
|
|
|
|
| 2 |
from_base: parachutes/python:3.12
|
| 3 |
run_command:
|
| 4 |
- pip install --upgrade setuptools wheel
|
| 5 |
+
- pip install ultralytics==8.3.222 opencv-python-headless numpy pydantic
|
|
|
|
| 6 |
- pip install scikit-learn
|
| 7 |
- pip install onnxruntime-gpu
|
| 8 |
set_workdir: /app
|
| 9 |
|
| 10 |
NodeSelector:
|
| 11 |
gpu_count: 1
|
| 12 |
+
min_vram_gb_per_gpu: 16
|
| 13 |
exclude:
|
| 14 |
- "5090"
|
| 15 |
- b200
|
football_object_detection.onnx → detection.onnx
RENAMED
|
File without changes
|
detection.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2ad3e89b658d2626c34174f6799d240ffd37cfe45752c0ce6ef73b05935042e0
|
| 3 |
+
size 52014742
|
hrnetv2_w48.yaml
DELETED
|
@@ -1,35 +0,0 @@
|
|
| 1 |
-
MODEL:
|
| 2 |
-
IMAGE_SIZE: [960, 540]
|
| 3 |
-
NUM_JOINTS: 58
|
| 4 |
-
PRETRAIN: ''
|
| 5 |
-
EXTRA:
|
| 6 |
-
FINAL_CONV_KERNEL: 1
|
| 7 |
-
STAGE1:
|
| 8 |
-
NUM_MODULES: 1
|
| 9 |
-
NUM_BRANCHES: 1
|
| 10 |
-
BLOCK: BOTTLENECK
|
| 11 |
-
NUM_BLOCKS: [4]
|
| 12 |
-
NUM_CHANNELS: [64]
|
| 13 |
-
FUSE_METHOD: SUM
|
| 14 |
-
STAGE2:
|
| 15 |
-
NUM_MODULES: 1
|
| 16 |
-
NUM_BRANCHES: 2
|
| 17 |
-
BLOCK: BASIC
|
| 18 |
-
NUM_BLOCKS: [4, 4]
|
| 19 |
-
NUM_CHANNELS: [48, 96]
|
| 20 |
-
FUSE_METHOD: SUM
|
| 21 |
-
STAGE3:
|
| 22 |
-
NUM_MODULES: 4
|
| 23 |
-
NUM_BRANCHES: 3
|
| 24 |
-
BLOCK: BASIC
|
| 25 |
-
NUM_BLOCKS: [4, 4, 4]
|
| 26 |
-
NUM_CHANNELS: [48, 96, 192]
|
| 27 |
-
FUSE_METHOD: SUM
|
| 28 |
-
STAGE4:
|
| 29 |
-
NUM_MODULES: 3
|
| 30 |
-
NUM_BRANCHES: 4
|
| 31 |
-
BLOCK: BASIC
|
| 32 |
-
NUM_BLOCKS: [4, 4, 4, 4]
|
| 33 |
-
NUM_CHANNELS: [48, 96, 192, 384]
|
| 34 |
-
FUSE_METHOD: SUM
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
keypoint.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6dd10dba85895c92760cdb5a99c5cfca899c68f361a66c5448f38a187280ee1f
|
| 3 |
+
size 6849672
|
miner.py
CHANGED
|
@@ -1,34 +1,23 @@
|
|
| 1 |
from pathlib import Path
|
| 2 |
-
from typing import List, Tuple, Dict
|
| 3 |
-
import sys
|
| 4 |
-
import os
|
| 5 |
-
|
| 6 |
-
from numpy import ndarray
|
| 7 |
-
from pydantic import BaseModel
|
| 8 |
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
from ultralytics import YOLO
|
|
|
|
|
|
|
| 11 |
from team_cluster import TeamClassifier
|
| 12 |
from utils import (
|
| 13 |
BoundingBox,
|
| 14 |
Constants,
|
|
|
|
|
|
|
| 15 |
)
|
| 16 |
|
| 17 |
-
import time
|
| 18 |
-
import torch
|
| 19 |
-
import gc
|
| 20 |
-
from pitch import process_batch_input, get_cls_net
|
| 21 |
-
import yaml
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class BoundingBox(BaseModel):
|
| 25 |
-
x1: int
|
| 26 |
-
y1: int
|
| 27 |
-
x2: int
|
| 28 |
-
y2: int
|
| 29 |
-
cls_id: int
|
| 30 |
-
conf: float
|
| 31 |
-
|
| 32 |
|
| 33 |
class TVFrameResult(BaseModel):
|
| 34 |
frame_id: int
|
|
@@ -37,6 +26,10 @@ class TVFrameResult(BaseModel):
|
|
| 37 |
|
| 38 |
|
| 39 |
class Miner:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
|
| 41 |
SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
|
| 42 |
SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
|
|
@@ -45,385 +38,472 @@ class Miner:
|
|
| 45 |
CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
|
| 46 |
GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
|
| 47 |
MIN_SAMPLES_FOR_FIT = 16 # Minimum player crops needed before fitting TeamClassifier
|
| 48 |
-
MAX_SAMPLES_FOR_FIT =
|
| 49 |
|
| 50 |
def __init__(self, path_hf_repo: Path) -> None:
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
model = get_cls_net(cfg_kp)
|
| 76 |
-
model.load_state_dict(loaded_state_kp)
|
| 77 |
-
model.to(device)
|
| 78 |
-
model.eval()
|
| 79 |
|
| 80 |
-
self.keypoints_model = model
|
| 81 |
-
self.kp_threshold = 0.1
|
| 82 |
-
self.pitch_batch_size = 4
|
| 83 |
-
self.health = "healthy"
|
| 84 |
-
print("✅ Keypoints Model Loaded")
|
| 85 |
-
except Exception as e:
|
| 86 |
-
self.health = "❌ Miner initialization failed: " + str(e)
|
| 87 |
-
print(self.health)
|
| 88 |
|
| 89 |
-
def __repr__(self) -> str:
|
| 90 |
-
if self.health == 'healthy':
|
| 91 |
-
return (
|
| 92 |
-
f"health: {self.health}\n"
|
| 93 |
-
f"BBox Model: {type(self.bbox_model).__name__}\n"
|
| 94 |
-
f"Keypoints Model: {type(self.keypoints_model).__name__}"
|
| 95 |
-
)
|
| 96 |
-
else:
|
| 97 |
-
return self.health
|
| 98 |
|
| 99 |
-
def
|
| 100 |
-
box2: Tuple[float, float, float, float]) -> float:
|
| 101 |
"""
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
Returns:
|
| 107 |
-
|
| 108 |
"""
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
-
if union_area == 0:
|
| 129 |
-
return 0.0
|
| 130 |
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
-
def
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
for frame_number in range(0, n_frames, batch_size):
|
| 138 |
-
batch_images = decoded_images[frame_number: frame_number + batch_size]
|
| 139 |
-
detections = self.bbox_model(batch_images, verbose=False, save=False)
|
| 140 |
-
detection_results.extend(detections)
|
| 141 |
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
continue
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
self.team_classifier.fit(player_crops_for_fit)
|
| 178 |
-
self.team_classifier_fitted = True
|
| 179 |
-
end = time.time()
|
| 180 |
-
print(f"Fitting Kmeans time: {end - start}")
|
| 181 |
-
|
| 182 |
-
# Second pass: predict teams with configurable frame skipping optimization
|
| 183 |
-
start = time.time()
|
| 184 |
-
|
| 185 |
-
# Get configuration for frame skipping
|
| 186 |
-
prediction_interval = 1 # Default: predict every 2 frames
|
| 187 |
-
iou_threshold = 0.3
|
| 188 |
-
|
| 189 |
-
print(f"Team classification - prediction_interval: {prediction_interval}, iou_threshold: {iou_threshold}")
|
| 190 |
-
|
| 191 |
-
# Storage for predicted frame results: {frame_id: {box_idx: (bbox, team_id)}}
|
| 192 |
-
predicted_frame_data = {}
|
| 193 |
-
|
| 194 |
-
# Step 1: Predict for frames at prediction_interval only
|
| 195 |
-
frames_to_predict = []
|
| 196 |
-
for frame_id in range(len(detection_results)):
|
| 197 |
-
if frame_id % prediction_interval == 0:
|
| 198 |
-
frames_to_predict.append(frame_id)
|
| 199 |
-
|
| 200 |
-
print(f"Predicting teams for {len(frames_to_predict)}/{len(detection_results)} frames "
|
| 201 |
-
f"(saving {100 - (len(frames_to_predict) * 100 // len(detection_results))}% compute)")
|
| 202 |
-
|
| 203 |
-
for frame_id in frames_to_predict:
|
| 204 |
-
detection_box = detection_results[frame_id].boxes.data
|
| 205 |
-
frame_image = decoded_images[frame_id]
|
| 206 |
-
|
| 207 |
-
# Collect player crops for this frame
|
| 208 |
-
frame_player_crops = []
|
| 209 |
-
frame_player_indices = []
|
| 210 |
-
frame_player_boxes = []
|
| 211 |
-
|
| 212 |
-
for idx, box in enumerate(detection_box):
|
| 213 |
-
x1, y1, x2, y2, conf, cls_id = box.tolist()
|
| 214 |
-
if cls_id == 2 and conf < 0.6:
|
| 215 |
continue
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
# Collect player crops for prediction
|
| 219 |
-
if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
|
| 220 |
-
crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
|
| 221 |
-
if crop.size > 0:
|
| 222 |
-
frame_player_crops.append(crop)
|
| 223 |
-
frame_player_indices.append(idx)
|
| 224 |
-
frame_player_boxes.append((x1, y1, x2, y2))
|
| 225 |
-
|
| 226 |
-
# Predict teams for all players in this frame
|
| 227 |
-
if len(frame_player_crops) > 0:
|
| 228 |
-
team_ids = self.team_classifier.predict(frame_player_crops)
|
| 229 |
-
predicted_frame_data[frame_id] = {}
|
| 230 |
-
for idx, bbox, team_id in zip(frame_player_indices, frame_player_boxes, team_ids):
|
| 231 |
-
# Map team_id (0,1) to cls_id (6,7)
|
| 232 |
-
team_cls_id = str(6 + int(team_id))
|
| 233 |
-
predicted_frame_data[frame_id][idx] = (bbox, team_cls_id)
|
| 234 |
-
|
| 235 |
-
# Step 2: Process all frames (interpolate skipped frames)
|
| 236 |
-
fallback_count = 0
|
| 237 |
-
interpolated_count = 0
|
| 238 |
-
bboxes: dict[int, list[BoundingBox]] = {}
|
| 239 |
-
for frame_id in range(len(detection_results)):
|
| 240 |
-
detection_box = detection_results[frame_id].boxes.data
|
| 241 |
-
frame_image = decoded_images[frame_id]
|
| 242 |
-
boxes = []
|
| 243 |
-
|
| 244 |
-
team_predictions = {}
|
| 245 |
-
|
| 246 |
-
if frame_id % prediction_interval == 0:
|
| 247 |
-
# Predicted frame: use pre-computed predictions
|
| 248 |
-
if frame_id in predicted_frame_data:
|
| 249 |
-
for idx, (bbox, team_cls_id) in predicted_frame_data[frame_id].items():
|
| 250 |
-
team_predictions[idx] = team_cls_id
|
| 251 |
-
else:
|
| 252 |
-
# Skipped frame: interpolate from neighboring predicted frames
|
| 253 |
-
# Find nearest predicted frames
|
| 254 |
-
prev_predicted_frame = (frame_id // prediction_interval) * prediction_interval
|
| 255 |
-
next_predicted_frame = prev_predicted_frame + prediction_interval
|
| 256 |
-
|
| 257 |
-
# Collect current frame player boxes
|
| 258 |
-
for idx, box in enumerate(detection_box):
|
| 259 |
-
x1, y1, x2, y2, conf, cls_id = box.tolist()
|
| 260 |
-
if cls_id == 2 and conf < 0.6:
|
| 261 |
-
continue
|
| 262 |
-
mapped_cls_id = str(int(cls_id))
|
| 263 |
-
|
| 264 |
-
if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2':
|
| 265 |
-
target_box = (x1, y1, x2, y2)
|
| 266 |
-
|
| 267 |
-
# Try to match with previous predicted frame
|
| 268 |
-
best_team_id = None
|
| 269 |
-
best_iou = 0.0
|
| 270 |
-
|
| 271 |
-
if prev_predicted_frame in predicted_frame_data:
|
| 272 |
-
team_id, iou = self._find_best_match(
|
| 273 |
-
target_box,
|
| 274 |
-
predicted_frame_data[prev_predicted_frame],
|
| 275 |
-
iou_threshold
|
| 276 |
-
)
|
| 277 |
-
if team_id is not None:
|
| 278 |
-
best_team_id = team_id
|
| 279 |
-
best_iou = iou
|
| 280 |
-
|
| 281 |
-
# Try to match with next predicted frame if available and no good match yet
|
| 282 |
-
if best_team_id is None and next_predicted_frame < len(detection_results):
|
| 283 |
-
if next_predicted_frame in predicted_frame_data:
|
| 284 |
-
team_id, iou = self._find_best_match(
|
| 285 |
-
target_box,
|
| 286 |
-
predicted_frame_data[next_predicted_frame],
|
| 287 |
-
iou_threshold
|
| 288 |
-
)
|
| 289 |
-
if team_id is not None and iou > best_iou:
|
| 290 |
-
best_team_id = team_id
|
| 291 |
-
best_iou = iou
|
| 292 |
-
|
| 293 |
-
# Track interpolation success
|
| 294 |
-
if best_team_id is not None:
|
| 295 |
-
interpolated_count += 1
|
| 296 |
-
else:
|
| 297 |
-
# Fallback: if no match found, predict individually
|
| 298 |
-
crop = frame_image[int(y1):int(y2), int(x1):int(x2)]
|
| 299 |
-
if crop.size > 0:
|
| 300 |
-
team_id = self.team_classifier.predict([crop])[0]
|
| 301 |
-
best_team_id = str(6 + int(team_id))
|
| 302 |
-
fallback_count += 1
|
| 303 |
-
|
| 304 |
-
if best_team_id is not None:
|
| 305 |
-
team_predictions[idx] = best_team_id
|
| 306 |
-
|
| 307 |
-
# Parse boxes with team classification
|
| 308 |
-
for idx, box in enumerate(detection_box):
|
| 309 |
-
x1, y1, x2, y2, conf, cls_id = box.tolist()
|
| 310 |
-
if cls_id == 2 and conf < 0.6:
|
| 311 |
continue
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
staff_iou = self._calculate_iou(box[:4], boxy[:4])
|
| 319 |
-
if staff_iou >= 0.8:
|
| 320 |
-
overlap_staff = True
|
| 321 |
break
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
boxes.append(
|
| 334 |
BoundingBox(
|
| 335 |
x1=int(x1),
|
| 336 |
y1=int(y1),
|
| 337 |
x2=int(x2),
|
| 338 |
y2=int(y2),
|
| 339 |
-
cls_id=int(
|
| 340 |
conf=float(conf),
|
| 341 |
)
|
| 342 |
)
|
|
|
|
| 343 |
# Handle footballs - keep only the best one
|
| 344 |
footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
|
| 345 |
if len(footballs) > 1:
|
| 346 |
best_ball = max(footballs, key=lambda b: b.conf)
|
| 347 |
boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
|
| 348 |
boxes.append(best_ball)
|
| 349 |
-
|
| 350 |
-
bboxes[offset + frame_id] = boxes
|
| 351 |
-
return bboxes
|
| 352 |
-
|
| 353 |
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
|
| 364 |
-
pitch_batch_size = min(self.pitch_batch_size, len(batch_images))
|
| 365 |
-
keypoints: Dict[int, List[Tuple[int, int]]] = {}
|
| 366 |
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
)
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
break
|
| 385 |
-
frame_keypoints: List[Tuple[int, int]] = []
|
| 386 |
-
try:
|
| 387 |
-
height, width = batch_images[frame_number_in_batch].shape[:2]
|
| 388 |
-
if kp_dict is not None and isinstance(kp_dict, dict):
|
| 389 |
-
for idx in range(32):
|
| 390 |
-
x, y = 0, 0
|
| 391 |
-
kp_idx = idx + 1
|
| 392 |
-
if kp_idx in kp_dict:
|
| 393 |
-
try:
|
| 394 |
-
kp_data = kp_dict[kp_idx]
|
| 395 |
-
if isinstance(kp_data, dict) and "x" in kp_data and "y" in kp_data:
|
| 396 |
-
x = int(kp_data["x"] * width)
|
| 397 |
-
y = int(kp_data["y"] * height)
|
| 398 |
-
except (KeyError, TypeError, ValueError):
|
| 399 |
-
pass
|
| 400 |
-
frame_keypoints.append((x, y))
|
| 401 |
-
except (IndexError, ValueError, AttributeError):
|
| 402 |
-
frame_keypoints = [(0, 0)] * 32
|
| 403 |
-
if len(frame_keypoints) < n_keypoints:
|
| 404 |
-
frame_keypoints.extend([(0, 0)] * (n_keypoints - len(frame_keypoints)))
|
| 405 |
-
else:
|
| 406 |
-
frame_keypoints = frame_keypoints[:n_keypoints]
|
| 407 |
-
keypoints[offset + frame_number_in_batch] = frame_keypoints
|
| 408 |
-
break
|
| 409 |
-
end = time.time()
|
| 410 |
-
print(f"Keypoint time: {end - start}")
|
| 411 |
-
|
| 412 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
results: List[TVFrameResult] = []
|
| 414 |
for frame_number in range(offset, offset + len(batch_images)):
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
|
|
|
|
|
|
|
|
|
| 421 |
)
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
gc.collect()
|
| 425 |
-
if torch.cuda.is_available():
|
| 426 |
-
torch.cuda.empty_cache()
|
| 427 |
-
torch.cuda.synchronize()
|
| 428 |
|
| 429 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from pathlib import Path
|
| 2 |
+
from typing import List, Tuple, Dict, Optional
|
| 3 |
+
import sys, os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 5 |
+
import onnxruntime as ort
|
| 6 |
+
import numpy as np
|
| 7 |
+
import cv2
|
| 8 |
+
from torchvision.ops import batched_nms
|
| 9 |
+
import torch
|
| 10 |
from ultralytics import YOLO
|
| 11 |
+
from numpy import ndarray
|
| 12 |
+
from pydantic import BaseModel
|
| 13 |
from team_cluster import TeamClassifier
|
| 14 |
from utils import (
|
| 15 |
BoundingBox,
|
| 16 |
Constants,
|
| 17 |
+
suppress_small_contained_boxes,
|
| 18 |
+
classify_teams_batch,
|
| 19 |
)
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
class TVFrameResult(BaseModel):
|
| 23 |
frame_id: int
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
class Miner:
|
| 29 |
+
"""
|
| 30 |
+
Football video analysis system for object detection and team classification.
|
| 31 |
+
"""
|
| 32 |
+
# Use constants from utils
|
| 33 |
SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
|
| 34 |
SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
|
| 35 |
SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
|
|
|
|
| 38 |
CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
|
| 39 |
GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN
|
| 40 |
MIN_SAMPLES_FOR_FIT = 16 # Minimum player crops needed before fitting TeamClassifier
|
| 41 |
+
MAX_SAMPLES_FOR_FIT = 500 # Maximum samples to avoid overfitting
|
| 42 |
|
| 43 |
def __init__(self, path_hf_repo: Path) -> None:
|
| 44 |
+
providers = [
|
| 45 |
+
'CUDAExecutionProvider',
|
| 46 |
+
'CPUExecutionProvider'
|
| 47 |
+
]
|
| 48 |
+
model_path = path_hf_repo / "detection.onnx"
|
| 49 |
+
session = ort.InferenceSession(model_path, providers=providers)
|
| 50 |
+
|
| 51 |
+
input_name = session.get_inputs()[0].name
|
| 52 |
+
height = width = 640
|
| 53 |
+
dummy = np.zeros((1, 3, height, width), dtype=np.float32)
|
| 54 |
+
session.run(None, {input_name: dummy})
|
| 55 |
+
model = session
|
| 56 |
+
self.bbox_model = model
|
| 57 |
+
|
| 58 |
+
print("BBox Model Loaded")
|
| 59 |
+
self.keypoints_model = YOLO(path_hf_repo / "keypoint.pt")
|
| 60 |
+
print("Keypoints Model (keypoint.pt) Loaded")
|
| 61 |
+
# Initialize team classifier with OSNet model
|
| 62 |
+
team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
|
| 63 |
+
device = 'cuda'
|
| 64 |
+
self.team_classifier = TeamClassifier(
|
| 65 |
+
device=device,
|
| 66 |
+
batch_size=32,
|
| 67 |
+
model_name=str(team_model_path)
|
| 68 |
+
)
|
| 69 |
+
print("Team Classifier Loaded")
|
| 70 |
+
|
| 71 |
+
# Team classification state
|
| 72 |
+
self.team_classifier_fitted = False
|
| 73 |
+
self.player_crops_for_fit = [] # Collect samples across frames
|
| 74 |
|
| 75 |
+
def __repr__(self) -> str:
|
| 76 |
+
return (
|
| 77 |
+
f"BBox Model: {type(self.bbox_model).__name__}\n"
|
| 78 |
+
f"Keypoints Model: {type(self.keypoints_model).__name__}"
|
| 79 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
+
def _handle_multiple_goalkeepers(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
|
|
|
|
| 84 |
"""
|
| 85 |
+
Handle goalkeeper detection issues:
|
| 86 |
+
1. Fix misplaced goalkeepers (standing in middle of field)
|
| 87 |
+
2. Limit to maximum 2 goalkeepers (one from each team)
|
| 88 |
+
|
| 89 |
Returns:
|
| 90 |
+
Filtered list of boxes with corrected goalkeepers
|
| 91 |
"""
|
| 92 |
+
# Step 1: Fix misplaced goalkeepers first
|
| 93 |
+
# Convert goalkeepers in middle of field to regular players
|
| 94 |
+
boxes = self._fix_misplaced_goalkeepers(boxes)
|
| 95 |
+
|
| 96 |
+
# Step 2: Handle multiple goalkeepers (after fixing misplaced ones)
|
| 97 |
+
gk_idxs = [i for i, bb in enumerate(boxes) if int(bb.cls_id) == 1]
|
| 98 |
+
if len(gk_idxs) <= 2:
|
| 99 |
+
return boxes
|
| 100 |
+
|
| 101 |
+
# Sort goalkeepers by confidence (highest first)
|
| 102 |
+
gk_idxs_sorted = sorted(gk_idxs, key=lambda i: boxes[i].conf, reverse=True)
|
| 103 |
+
keep_gk_idxs = set(gk_idxs_sorted[:2]) # Keep top 2 goalkeepers
|
| 104 |
+
|
| 105 |
+
# Create new list keeping only top 2 goalkeepers
|
| 106 |
+
filtered_boxes = []
|
| 107 |
+
for i, box in enumerate(boxes):
|
| 108 |
+
if int(box.cls_id) == 1:
|
| 109 |
+
# Only keep the top 2 goalkeepers by confidence
|
| 110 |
+
if i in keep_gk_idxs:
|
| 111 |
+
filtered_boxes.append(box)
|
| 112 |
+
# Skip extra goalkeepers
|
| 113 |
+
else:
|
| 114 |
+
# Keep all non-goalkeeper boxes
|
| 115 |
+
filtered_boxes.append(box)
|
| 116 |
+
|
| 117 |
+
return filtered_boxes
|
| 118 |
|
| 119 |
+
def _fix_misplaced_goalkeepers(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
|
| 120 |
+
"""
|
| 121 |
+
"""
|
| 122 |
+
gk_idxs = [i for i, bb in enumerate(boxes) if int(bb.cls_id) == 1]
|
| 123 |
+
player_idxs = [i for i, bb in enumerate(boxes) if int(bb.cls_id) == 2]
|
| 124 |
+
|
| 125 |
+
if len(gk_idxs) == 0 or len(player_idxs) < 2:
|
| 126 |
+
return boxes
|
| 127 |
+
|
| 128 |
+
updated_boxes = boxes.copy()
|
| 129 |
+
|
| 130 |
+
for gk_idx in gk_idxs:
|
| 131 |
+
if boxes[gk_idx].conf < 0.3:
|
| 132 |
+
updated_boxes[gk_idx].cls_id = 2
|
| 133 |
+
|
| 134 |
+
return updated_boxes
|
| 135 |
|
|
|
|
|
|
|
| 136 |
|
| 137 |
+
def _pre_process_img(self, frames: List[np.ndarray], scale: float = 640.0) -> np.ndarray:
|
| 138 |
+
"""
|
| 139 |
+
Preprocess images for ONNX inference.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
frames: List of BGR frames
|
| 143 |
+
scale: Target scale for resizing
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
Preprocessed numpy array ready for ONNX inference
|
| 147 |
+
"""
|
| 148 |
+
imgs = np.stack([cv2.resize(frame, (int(scale), int(scale))) for frame in frames])
|
| 149 |
+
imgs = imgs.transpose(0, 3, 1, 2) # BHWC to BCHW
|
| 150 |
+
imgs = imgs.astype(np.float32) / 255.0 # Normalize to [0, 1]
|
| 151 |
+
return imgs
|
| 152 |
|
| 153 |
+
def _post_process_output(self, outputs: np.ndarray, x_scale: float, y_scale: float,
|
| 154 |
+
conf_thresh: float = 0.6, nms_thresh: float = 0.55) -> List[List[Tuple]]:
|
| 155 |
+
"""
|
| 156 |
+
Post-process ONNX model outputs to get detections.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
+
Args:
|
| 159 |
+
outputs: Raw ONNX model outputs
|
| 160 |
+
x_scale: X-axis scaling factor
|
| 161 |
+
y_scale: Y-axis scaling factor
|
| 162 |
+
conf_thresh: Confidence threshold
|
| 163 |
+
nms_thresh: NMS threshold
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
List of detections for each frame: [(box, conf, class_id), ...]
|
| 167 |
+
"""
|
| 168 |
+
B, C, N = outputs.shape
|
| 169 |
+
outputs = torch.from_numpy(outputs)
|
| 170 |
+
outputs = outputs.permute(0, 2, 1) # B,C,N -> B,N,C
|
| 171 |
+
|
| 172 |
+
boxes = outputs[..., :4]
|
| 173 |
+
class_scores = 1 / (1 + torch.exp(-outputs[..., 4:])) # Sigmoid activation
|
| 174 |
+
conf, class_id = class_scores.max(dim=2)
|
| 175 |
|
| 176 |
+
mask = conf > conf_thresh
|
| 177 |
+
|
| 178 |
+
# Special handling for balls - keep best one even with lower confidence
|
| 179 |
+
for i in range(class_id.shape[0]): # loop over batch
|
| 180 |
+
# Find detections that are balls
|
| 181 |
+
ball_mask = class_id[i] == 0
|
| 182 |
+
ball_idx = ball_mask.nonzero(as_tuple=True)[0]
|
| 183 |
+
if ball_idx.numel() > 0:
|
| 184 |
+
# Pick the one with the highest confidence
|
| 185 |
+
best_ball_idx = ball_idx[conf[i, ball_idx].argmax()]
|
| 186 |
+
if conf[i, best_ball_idx] >= 0.55: # apply confidence threshold
|
| 187 |
+
mask[i, best_ball_idx] = True
|
| 188 |
+
|
| 189 |
+
batch_idx, pred_idx = mask.nonzero(as_tuple=True)
|
| 190 |
|
| 191 |
+
if len(batch_idx) == 0:
|
| 192 |
+
return [[] for _ in range(B)]
|
| 193 |
+
|
| 194 |
+
boxes = boxes[batch_idx, pred_idx]
|
| 195 |
+
conf = conf[batch_idx, pred_idx]
|
| 196 |
+
class_id = class_id[batch_idx, pred_idx]
|
| 197 |
+
|
| 198 |
+
# Convert from center format to xyxy format
|
| 199 |
+
x, y, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
|
| 200 |
+
x1 = (x - w / 2) * x_scale
|
| 201 |
+
y1 = (y - h / 2) * y_scale
|
| 202 |
+
x2 = (x + w / 2) * x_scale
|
| 203 |
+
y2 = (y + h / 2) * y_scale
|
| 204 |
+
boxes_xyxy = torch.stack([x1, y1, x2, y2], dim=1)
|
| 205 |
+
|
| 206 |
+
# Apply batched NMS
|
| 207 |
+
max_coord = 1e4
|
| 208 |
+
offset = batch_idx.to(boxes_xyxy) * max_coord
|
| 209 |
+
boxes_for_nms = boxes_xyxy + offset[:, None]
|
| 210 |
+
|
| 211 |
+
keep = batched_nms(boxes_for_nms, conf, batch_idx, nms_thresh)
|
| 212 |
+
|
| 213 |
+
boxes_final = boxes_xyxy[keep]
|
| 214 |
+
conf_final = conf[keep]
|
| 215 |
+
class_final = class_id[keep]
|
| 216 |
+
batch_final = batch_idx[keep]
|
| 217 |
+
|
| 218 |
+
# Group results by batch
|
| 219 |
+
results = [[] for _ in range(B)]
|
| 220 |
+
for b in range(B):
|
| 221 |
+
mask_b = batch_final == b
|
| 222 |
+
if mask_b.sum() == 0:
|
| 223 |
continue
|
| 224 |
+
results[b] = list(zip(boxes_final[mask_b].numpy(),
|
| 225 |
+
conf_final[mask_b].numpy(),
|
| 226 |
+
class_final[mask_b].numpy()))
|
| 227 |
+
return results
|
| 228 |
+
|
| 229 |
+
def _ioa(self, a: BoundingBox, b: BoundingBox) -> float:
|
| 230 |
+
inter = self._intersect_area(a, b)
|
| 231 |
+
aa = self._area(a)
|
| 232 |
+
if aa <= 0:
|
| 233 |
+
return 0.0
|
| 234 |
+
return inter / aa
|
| 235 |
+
|
| 236 |
+
def suppress_small_contained(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
|
| 237 |
+
if len(boxes) <= 1:
|
| 238 |
+
return boxes
|
| 239 |
+
keep = [True] * len(boxes)
|
| 240 |
+
areas = [self._area(bb) for bb in boxes]
|
| 241 |
+
for i in range(len(boxes)):
|
| 242 |
+
if not keep[i]:
|
| 243 |
+
continue
|
| 244 |
+
for j in range(len(boxes)):
|
| 245 |
+
if i == j or not keep[j]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
continue
|
| 247 |
+
ai, aj = areas[i], areas[j]
|
| 248 |
+
if ai == 0 or aj == 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
continue
|
| 250 |
+
if ai <= aj:
|
| 251 |
+
ratio = ai / aj
|
| 252 |
+
if ratio <= self.SMALL_RATIO_MAX:
|
| 253 |
+
ioa_i_in_j = self._ioa(boxes[i], boxes[j])
|
| 254 |
+
if ioa_i_in_j >= self.SMALL_CONTAINED_IOA:
|
| 255 |
+
keep[i] = False
|
|
|
|
|
|
|
|
|
|
| 256 |
break
|
| 257 |
+
else:
|
| 258 |
+
ratio = aj / ai
|
| 259 |
+
if ratio <= self.SMALL_RATIO_MAX:
|
| 260 |
+
ioa_j_in_i = self._ioa(boxes[j], boxes[i])
|
| 261 |
+
if ioa_j_in_i >= self.SMALL_CONTAINED_IOA:
|
| 262 |
+
keep[j] = False
|
| 263 |
+
return [bb for bb, k in zip(boxes, keep) if k]
|
| 264 |
+
|
| 265 |
+
def _detect_objects_batch(self, batch_images: List[ndarray], offset: int) -> Dict[int, List[BoundingBox]]:
|
| 266 |
+
"""
|
| 267 |
+
Phase 1: Object detection for all frames in batch.
|
| 268 |
+
Returns detected objects with players still having class_id=2 (before team classification).
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
batch_images: List of images to process
|
| 272 |
+
offset: Frame offset for numbering
|
| 273 |
+
|
| 274 |
+
Returns:
|
| 275 |
+
Dictionary mapping frame_id to list of detected boxes
|
| 276 |
+
"""
|
| 277 |
+
bboxes: Dict[int, List[BoundingBox]] = {}
|
| 278 |
|
| 279 |
+
if len(batch_images) == 0:
|
| 280 |
+
return bboxes
|
| 281 |
+
|
| 282 |
+
print(f"Processing batch of {len(batch_images)} images")
|
| 283 |
+
|
| 284 |
+
# Get original image dimensions for scaling
|
| 285 |
+
height, width = batch_images[0].shape[:2]
|
| 286 |
+
scale = 640.0
|
| 287 |
+
x_scale = width / scale
|
| 288 |
+
y_scale = height / scale
|
| 289 |
+
|
| 290 |
+
# Memory optimization: Process smaller batches if needed
|
| 291 |
+
max_batch_size = 32 # Reduce batch size further to prevent memory issues
|
| 292 |
+
if len(batch_images) > max_batch_size:
|
| 293 |
+
print(f"Large batch detected ({len(batch_images)} images), splitting into smaller batches of {max_batch_size}")
|
| 294 |
+
# Process in smaller chunks
|
| 295 |
+
all_bboxes = {}
|
| 296 |
+
for chunk_start in range(0, len(batch_images), max_batch_size):
|
| 297 |
+
chunk_end = min(chunk_start + max_batch_size, len(batch_images))
|
| 298 |
+
chunk_images = batch_images[chunk_start:chunk_end]
|
| 299 |
+
chunk_offset = offset + chunk_start
|
| 300 |
+
print(f"Processing chunk {chunk_start//max_batch_size + 1}: images {chunk_start}-{chunk_end-1}")
|
| 301 |
+
chunk_bboxes = self._detect_objects_batch(chunk_images, chunk_offset)
|
| 302 |
+
all_bboxes.update(chunk_bboxes)
|
| 303 |
+
return all_bboxes
|
| 304 |
+
|
| 305 |
+
# Preprocess images for ONNX inference
|
| 306 |
+
imgs = self._pre_process_img(batch_images, scale)
|
| 307 |
+
actual_batch_size = len(batch_images)
|
| 308 |
+
|
| 309 |
+
# Handle batch size mismatch - pad if needed
|
| 310 |
+
model_batch_size = self.bbox_model.get_inputs()[0].shape[0]
|
| 311 |
+
print(f"Model input shape: {self.bbox_model.get_inputs()[0].shape}, batch_size: {model_batch_size}")
|
| 312 |
+
|
| 313 |
+
if model_batch_size is not None:
|
| 314 |
+
try:
|
| 315 |
+
# Handle dynamic batch size (None, -1, 'None')
|
| 316 |
+
if str(model_batch_size) in ['None', '-1'] or model_batch_size == -1:
|
| 317 |
+
model_batch_size = None
|
| 318 |
+
else:
|
| 319 |
+
model_batch_size = int(model_batch_size)
|
| 320 |
+
except (ValueError, TypeError):
|
| 321 |
+
model_batch_size = None
|
| 322 |
+
|
| 323 |
+
print(f"Processed model_batch_size: {model_batch_size}, actual_batch_size: {actual_batch_size}")
|
| 324 |
+
|
| 325 |
+
if model_batch_size and actual_batch_size < model_batch_size:
|
| 326 |
+
padding_size = model_batch_size - actual_batch_size
|
| 327 |
+
dummy_img = np.zeros((1, 3, int(scale), int(scale)), dtype=np.float32)
|
| 328 |
+
padding = np.repeat(dummy_img, padding_size, axis=0)
|
| 329 |
+
imgs = np.vstack([imgs, padding])
|
| 330 |
+
|
| 331 |
+
# ONNX inference with error handling
|
| 332 |
+
try:
|
| 333 |
+
input_name = self.bbox_model.get_inputs()[0].name
|
| 334 |
+
import time
|
| 335 |
+
start_time = time.time()
|
| 336 |
+
outputs = self.bbox_model.run(None, {input_name: imgs})[0]
|
| 337 |
+
inference_time = time.time() - start_time
|
| 338 |
+
print(f"Inference time: {inference_time:.3f}s for {actual_batch_size} images")
|
| 339 |
+
|
| 340 |
+
# Remove padded results if we added padding
|
| 341 |
+
if model_batch_size and isinstance(model_batch_size, int) and actual_batch_size < model_batch_size:
|
| 342 |
+
outputs = outputs[:actual_batch_size]
|
| 343 |
+
|
| 344 |
+
# Post-process outputs to get detections
|
| 345 |
+
raw_results = self._post_process_output(np.array(outputs), x_scale, y_scale)
|
| 346 |
+
|
| 347 |
+
except Exception as e:
|
| 348 |
+
print(f"Error during ONNX inference: {e}")
|
| 349 |
+
return bboxes
|
| 350 |
+
|
| 351 |
+
if not raw_results:
|
| 352 |
+
return bboxes
|
| 353 |
+
|
| 354 |
+
# Convert raw results to BoundingBox objects and apply processing
|
| 355 |
+
for frame_idx_in_batch, frame_detections in enumerate(raw_results):
|
| 356 |
+
if not frame_detections:
|
| 357 |
+
continue
|
| 358 |
+
|
| 359 |
+
# Convert to BoundingBox objects
|
| 360 |
+
boxes: List[BoundingBox] = []
|
| 361 |
+
for box, conf, cls_id in frame_detections:
|
| 362 |
+
x1, y1, x2, y2 = box
|
| 363 |
+
if int(cls_id) < 4:
|
| 364 |
boxes.append(
|
| 365 |
BoundingBox(
|
| 366 |
x1=int(x1),
|
| 367 |
y1=int(y1),
|
| 368 |
x2=int(x2),
|
| 369 |
y2=int(y2),
|
| 370 |
+
cls_id=int(cls_id),
|
| 371 |
conf=float(conf),
|
| 372 |
)
|
| 373 |
)
|
| 374 |
+
|
| 375 |
# Handle footballs - keep only the best one
|
| 376 |
footballs = [bb for bb in boxes if int(bb.cls_id) == 0]
|
| 377 |
if len(footballs) > 1:
|
| 378 |
best_ball = max(footballs, key=lambda b: b.conf)
|
| 379 |
boxes = [bb for bb in boxes if int(bb.cls_id) != 0]
|
| 380 |
boxes.append(best_ball)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
+
# Remove overlapping small boxes
|
| 383 |
+
boxes = suppress_small_contained_boxes(boxes, self.SMALL_CONTAINED_IOA, self.SMALL_RATIO_MAX)
|
| 384 |
+
|
| 385 |
+
# Handle goalkeeper detection issues:
|
| 386 |
+
# 1. Fix misplaced goalkeepers (convert to players if standing in middle)
|
| 387 |
+
# 2. Allow up to 2 goalkeepers maximum (one from each team)
|
| 388 |
+
# Goalkeepers remain class_id = 1 (no team assignment)
|
| 389 |
+
boxes = self._handle_multiple_goalkeepers(boxes)
|
| 390 |
+
|
| 391 |
+
# Store results (players still have class_id=2, will be classified in phase 2)
|
| 392 |
+
frame_id = offset + frame_idx_in_batch
|
| 393 |
+
bboxes[frame_id] = boxes
|
| 394 |
+
|
| 395 |
+
return bboxes
|
| 396 |
|
|
|
|
|
|
|
| 397 |
|
| 398 |
+
def predict_batch(
|
| 399 |
+
self,
|
| 400 |
+
batch_images: List[ndarray],
|
| 401 |
+
offset: int,
|
| 402 |
+
n_keypoints: int,
|
| 403 |
+
task_type: Optional[str] = None,
|
| 404 |
+
) -> List[TVFrameResult]:
|
| 405 |
+
process_objects = task_type is None or task_type == "object"
|
| 406 |
+
process_keypoints = task_type is None or task_type == "keypoint"
|
| 407 |
+
|
| 408 |
+
# Phase 1: Object Detection for all frames
|
| 409 |
+
bboxes: Dict[int, List[BoundingBox]] = {}
|
| 410 |
+
if process_objects:
|
| 411 |
+
bboxes = self._detect_objects_batch(batch_images, offset)
|
| 412 |
+
|
| 413 |
+
import time
|
| 414 |
+
time_start = time.time()
|
| 415 |
+
# Phase 2: Team Classification for all detected players
|
| 416 |
+
if process_objects and bboxes:
|
| 417 |
+
bboxes, self.team_classifier_fitted, self.player_crops_for_fit = classify_teams_batch(
|
| 418 |
+
self.team_classifier,
|
| 419 |
+
self.team_classifier_fitted,
|
| 420 |
+
self.player_crops_for_fit,
|
| 421 |
+
batch_images,
|
| 422 |
+
bboxes,
|
| 423 |
+
offset,
|
| 424 |
+
self.MIN_SAMPLES_FOR_FIT,
|
| 425 |
+
self.MAX_SAMPLES_FOR_FIT,
|
| 426 |
+
self.SINGLE_PLAYER_HUE_PIVOT
|
| 427 |
)
|
| 428 |
+
self.team_classifier_fitted = False
|
| 429 |
+
self.player_crops_for_fit = []
|
| 430 |
+
print(f"Time Team Classification: {time.time() - time_start} s")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
|
| 432 |
+
# Phase 3: Keypoint Detection
|
| 433 |
+
keypoints: Dict[int, List[Tuple[int, int]]] = {}
|
| 434 |
+
if process_keypoints:
|
| 435 |
+
keypoints = self._detect_keypoints_batch(batch_images, offset, n_keypoints)
|
| 436 |
+
|
| 437 |
+
# Phase 4: Combine results
|
| 438 |
results: List[TVFrameResult] = []
|
| 439 |
for frame_number in range(offset, offset + len(batch_images)):
|
| 440 |
+
results.append(
|
| 441 |
+
TVFrameResult(
|
| 442 |
+
frame_id=frame_number,
|
| 443 |
+
boxes=bboxes.get(frame_number, []),
|
| 444 |
+
keypoints=keypoints.get(
|
| 445 |
+
frame_number,
|
| 446 |
+
[(0, 0) for _ in range(n_keypoints)],
|
| 447 |
+
),
|
| 448 |
+
)
|
| 449 |
)
|
| 450 |
+
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
|
| 452 |
+
def _detect_keypoints_batch(self, batch_images: List[ndarray],
|
| 453 |
+
offset: int, n_keypoints: int) -> Dict[int, List[Tuple[int, int]]]:
|
| 454 |
+
"""
|
| 455 |
+
Phase 3: Keypoint detection for all frames in batch.
|
| 456 |
+
|
| 457 |
+
Args:
|
| 458 |
+
batch_images: List of images to process
|
| 459 |
+
offset: Frame offset for numbering
|
| 460 |
+
n_keypoints: Number of keypoints expected
|
| 461 |
+
|
| 462 |
+
Returns:
|
| 463 |
+
Dictionary mapping frame_id to list of keypoint coordinates
|
| 464 |
+
"""
|
| 465 |
+
keypoints: Dict[int, List[Tuple[int, int]]] = {}
|
| 466 |
+
keypoints_model_results = self.keypoints_model.predict(batch_images)
|
| 467 |
+
|
| 468 |
+
if keypoints_model_results is None:
|
| 469 |
+
return keypoints
|
| 470 |
+
|
| 471 |
+
for frame_idx_in_batch, detection in enumerate(keypoints_model_results):
|
| 472 |
+
if not hasattr(detection, "keypoints") or detection.keypoints is None:
|
| 473 |
+
continue
|
| 474 |
+
|
| 475 |
+
# Extract keypoints with confidence
|
| 476 |
+
frame_keypoints_with_conf: List[Tuple[int, int, float]] = []
|
| 477 |
+
for i, part_points in enumerate(detection.keypoints.data):
|
| 478 |
+
for k_id, (x, y, _) in enumerate(part_points):
|
| 479 |
+
confidence = float(detection.keypoints.conf[i][k_id])
|
| 480 |
+
frame_keypoints_with_conf.append((int(x), int(y), confidence))
|
| 481 |
+
|
| 482 |
+
# Pad or truncate to expected number of keypoints
|
| 483 |
+
if len(frame_keypoints_with_conf) < n_keypoints:
|
| 484 |
+
frame_keypoints_with_conf.extend(
|
| 485 |
+
[(0, 0, 0.0)] * (n_keypoints - len(frame_keypoints_with_conf))
|
| 486 |
+
)
|
| 487 |
+
else:
|
| 488 |
+
frame_keypoints_with_conf = frame_keypoints_with_conf[:n_keypoints]
|
| 489 |
+
|
| 490 |
+
# Filter keypoints based on confidence thresholds
|
| 491 |
+
filtered_keypoints: List[Tuple[int, int]] = []
|
| 492 |
+
for idx, (x, y, confidence) in enumerate(frame_keypoints_with_conf):
|
| 493 |
+
if idx in self.CORNER_INDICES:
|
| 494 |
+
# Corner keypoints have lower confidence threshold
|
| 495 |
+
if confidence < 0.3:
|
| 496 |
+
filtered_keypoints.append((0, 0))
|
| 497 |
+
else:
|
| 498 |
+
filtered_keypoints.append((int(x), int(y)))
|
| 499 |
+
else:
|
| 500 |
+
# Regular keypoints
|
| 501 |
+
if confidence < 0.5:
|
| 502 |
+
filtered_keypoints.append((0, 0))
|
| 503 |
+
else:
|
| 504 |
+
filtered_keypoints.append((int(x), int(y)))
|
| 505 |
+
|
| 506 |
+
frame_id = offset + frame_idx_in_batch
|
| 507 |
+
keypoints[frame_id] = filtered_keypoints
|
| 508 |
+
|
| 509 |
+
return keypoints
|
object-detection.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:05112479be8cb59494e9ae23a57af43becd5aa1f448b0e5ed33fcb6b4c2bbbc3
|
| 3 |
+
size 273322667
|
pitch.py
CHANGED
|
@@ -520,7 +520,7 @@ def run_inference(model, input_tensor: torch.Tensor, device):
|
|
| 520 |
output = model.module().forward(input_tensor)
|
| 521 |
return output
|
| 522 |
|
| 523 |
-
def preprocess_batch_fast(frames):
|
| 524 |
"""Ultra-fast batch preprocessing using optimized tensor operations"""
|
| 525 |
target_size = (540, 960) # H, W format for model input
|
| 526 |
batch = []
|
|
@@ -530,7 +530,7 @@ def preprocess_batch_fast(frames):
|
|
| 530 |
img = img.astype(np.float32) / 255.0
|
| 531 |
img = np.transpose(img, (2, 0, 1)) # HWC -> CHW
|
| 532 |
batch.append(img)
|
| 533 |
-
batch = torch.
|
| 534 |
|
| 535 |
return batch
|
| 536 |
|
|
@@ -610,24 +610,16 @@ def inference_batch(frames, model, kp_threshold, device, batch_size=8):
|
|
| 610 |
results = []
|
| 611 |
num_frames = len(frames)
|
| 612 |
|
| 613 |
-
# Get the device from the model itself
|
| 614 |
-
model_device = next(model.parameters()).device
|
| 615 |
-
|
| 616 |
# Process all frames in optimally-sized batches
|
| 617 |
for i in range(0, num_frames, batch_size):
|
| 618 |
current_batch_size = min(batch_size, num_frames - i)
|
| 619 |
batch_frames = frames[i:i + current_batch_size]
|
| 620 |
|
| 621 |
-
# Fast preprocessing
|
| 622 |
-
batch = preprocess_batch_fast(batch_frames)
|
| 623 |
-
b, c, h, w = batch.size()
|
| 624 |
-
|
| 625 |
-
# Move batch to model device
|
| 626 |
-
batch = batch.to(model_device)
|
| 627 |
-
|
| 628 |
-
with torch.no_grad():
|
| 629 |
-
heatmaps = model(batch)
|
| 630 |
|
|
|
|
|
|
|
| 631 |
# Ultra-fast keypoint extraction
|
| 632 |
kp_coords = extract_keypoints_from_heatmap_fast(heatmaps[:,:-1,:,:], scale=2, max_keypoints=1)
|
| 633 |
|
|
@@ -660,10 +652,28 @@ def get_mapped_keypoints(kp_points):
|
|
| 660 |
# mapped_points[key] = value
|
| 661 |
return mapped_points
|
| 662 |
|
| 663 |
-
def process_batch_input(frames, model, kp_threshold, device, batch_size=
|
| 664 |
"""Process multiple input images in batch"""
|
| 665 |
# Batch inference
|
| 666 |
kp_results = inference_batch(frames, model, kp_threshold, device, batch_size)
|
| 667 |
kp_results = [get_mapped_keypoints(kp) for kp in kp_results]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 668 |
|
| 669 |
return kp_results
|
|
|
|
| 520 |
output = model.module().forward(input_tensor)
|
| 521 |
return output
|
| 522 |
|
| 523 |
+
def preprocess_batch_fast(frames, device):
|
| 524 |
"""Ultra-fast batch preprocessing using optimized tensor operations"""
|
| 525 |
target_size = (540, 960) # H, W format for model input
|
| 526 |
batch = []
|
|
|
|
| 530 |
img = img.astype(np.float32) / 255.0
|
| 531 |
img = np.transpose(img, (2, 0, 1)) # HWC -> CHW
|
| 532 |
batch.append(img)
|
| 533 |
+
batch = torch.tensor(np.stack(batch), dtype=torch.float32)
|
| 534 |
|
| 535 |
return batch
|
| 536 |
|
|
|
|
| 610 |
results = []
|
| 611 |
num_frames = len(frames)
|
| 612 |
|
|
|
|
|
|
|
|
|
|
| 613 |
# Process all frames in optimally-sized batches
|
| 614 |
for i in range(0, num_frames, batch_size):
|
| 615 |
current_batch_size = min(batch_size, num_frames - i)
|
| 616 |
batch_frames = frames[i:i + current_batch_size]
|
| 617 |
|
| 618 |
+
# Fast preprocessing
|
| 619 |
+
batch = preprocess_batch_fast(batch_frames, device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 620 |
|
| 621 |
+
heatmaps = run_inference(model, batch, device)
|
| 622 |
+
|
| 623 |
# Ultra-fast keypoint extraction
|
| 624 |
kp_coords = extract_keypoints_from_heatmap_fast(heatmaps[:,:-1,:,:], scale=2, max_keypoints=1)
|
| 625 |
|
|
|
|
| 652 |
# mapped_points[key] = value
|
| 653 |
return mapped_points
|
| 654 |
|
| 655 |
+
def process_batch_input(frames, model, kp_threshold, device, batch_size=8):
|
| 656 |
"""Process multiple input images in batch"""
|
| 657 |
# Batch inference
|
| 658 |
kp_results = inference_batch(frames, model, kp_threshold, device, batch_size)
|
| 659 |
kp_results = [get_mapped_keypoints(kp) for kp in kp_results]
|
| 660 |
+
# Draw results and save
|
| 661 |
+
# for i, (frame, kp_points, input_path) in enumerate(zip(frames, kp_results, valid_paths)):
|
| 662 |
+
# height, width = frame.shape[:2]
|
| 663 |
+
|
| 664 |
+
# # Apply mapping to get standard keypoint IDs
|
| 665 |
+
# mapped_kp_points = get_mapped_keypoints(kp_points)
|
| 666 |
+
|
| 667 |
+
# for key, value in mapped_kp_points.items():
|
| 668 |
+
# x = int(value['x'] * width)
|
| 669 |
+
# y = int(value['y'] * height)
|
| 670 |
+
# cv2.circle(frame, (x, y), 5, (0, 255, 0), -1) # Green circles
|
| 671 |
+
# cv2.putText(frame, str(key), (x+10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
|
| 672 |
+
|
| 673 |
+
# # Save result
|
| 674 |
+
# output_path = input_path.replace('.png', '_result.png').replace('.jpg', '_result.jpg')
|
| 675 |
+
# cv2.imwrite(output_path, frame)
|
| 676 |
+
|
| 677 |
+
# print(f"Batch processing complete. Processed {len(frames)} images.")
|
| 678 |
|
| 679 |
return kp_results
|
player.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
from sklearn.cluster import KMeans
|
| 4 |
+
import warnings
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torchvision.ops import batched_nms
|
| 9 |
+
from numpy import ndarray
|
| 10 |
+
# Suppress ALL runtime and sklearn warnings
|
| 11 |
+
warnings.filterwarnings('ignore', category=RuntimeWarning)
|
| 12 |
+
warnings.filterwarnings('ignore', category=FutureWarning)
|
| 13 |
+
warnings.filterwarnings('ignore', category=UserWarning)
|
| 14 |
+
|
| 15 |
+
# Suppress sklearn warnings specifically
|
| 16 |
+
import logging
|
| 17 |
+
logging.getLogger('sklearn').setLevel(logging.ERROR)
|
| 18 |
+
|
| 19 |
+
def get_grass_color(img):
|
| 20 |
+
# Convert image to HSV color space
|
| 21 |
+
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
| 22 |
+
|
| 23 |
+
# Define range of green color in HSV
|
| 24 |
+
lower_green = np.array([30, 40, 40])
|
| 25 |
+
upper_green = np.array([80, 255, 255])
|
| 26 |
+
|
| 27 |
+
# Threshold the HSV image to get only green colors
|
| 28 |
+
mask = cv2.inRange(hsv, lower_green, upper_green)
|
| 29 |
+
|
| 30 |
+
# Calculate the mean value of the pixels that are not masked
|
| 31 |
+
masked_img = cv2.bitwise_and(img, img, mask=mask)
|
| 32 |
+
grass_color = cv2.mean(img, mask=mask)
|
| 33 |
+
return grass_color[:3]
|
| 34 |
+
|
| 35 |
+
def get_players_boxes(frame, result):
|
| 36 |
+
players_imgs = []
|
| 37 |
+
players_boxes = []
|
| 38 |
+
for (box, score, cls) in result:
|
| 39 |
+
label = int(cls)
|
| 40 |
+
if label == 0:
|
| 41 |
+
x1, y1, x2, y2 = box.astype(int)
|
| 42 |
+
player_img = frame[y1: y2, x1: x2]
|
| 43 |
+
players_imgs.append(player_img)
|
| 44 |
+
players_boxes.append([box, score, cls])
|
| 45 |
+
return players_imgs, players_boxes
|
| 46 |
+
|
| 47 |
+
def get_kits_colors(players, grass_hsv=None, frame=None):
|
| 48 |
+
kits_colors = []
|
| 49 |
+
if grass_hsv is None:
|
| 50 |
+
grass_color = get_grass_color(frame)
|
| 51 |
+
grass_hsv = cv2.cvtColor(np.uint8([[list(grass_color)]]), cv2.COLOR_BGR2HSV)
|
| 52 |
+
|
| 53 |
+
for player_img in players:
|
| 54 |
+
# Skip empty or invalid images
|
| 55 |
+
if player_img is None or player_img.size == 0 or len(player_img.shape) != 3:
|
| 56 |
+
continue
|
| 57 |
+
|
| 58 |
+
# Convert image to HSV color space
|
| 59 |
+
hsv = cv2.cvtColor(player_img, cv2.COLOR_BGR2HSV)
|
| 60 |
+
|
| 61 |
+
# Define range of green color in HSV
|
| 62 |
+
lower_green = np.array([grass_hsv[0, 0, 0] - 10, 40, 40])
|
| 63 |
+
upper_green = np.array([grass_hsv[0, 0, 0] + 10, 255, 255])
|
| 64 |
+
|
| 65 |
+
# Threshold the HSV image to get only green colors
|
| 66 |
+
mask = cv2.inRange(hsv, lower_green, upper_green)
|
| 67 |
+
|
| 68 |
+
# Bitwise-AND mask and original image
|
| 69 |
+
mask = cv2.bitwise_not(mask)
|
| 70 |
+
upper_mask = np.zeros(player_img.shape[:2], np.uint8)
|
| 71 |
+
upper_mask[0:player_img.shape[0]//2, 0:player_img.shape[1]] = 255
|
| 72 |
+
mask = cv2.bitwise_and(mask, upper_mask)
|
| 73 |
+
|
| 74 |
+
kit_color = np.array(cv2.mean(player_img, mask=mask)[:3])
|
| 75 |
+
|
| 76 |
+
kits_colors.append(kit_color)
|
| 77 |
+
return kits_colors
|
| 78 |
+
|
| 79 |
+
def get_kits_classifier(kits_colors):
|
| 80 |
+
if len(kits_colors) == 0:
|
| 81 |
+
return None
|
| 82 |
+
if len(kits_colors) == 1:
|
| 83 |
+
# Only one kit color, create a dummy classifier
|
| 84 |
+
return None
|
| 85 |
+
kits_kmeans = KMeans(n_clusters=2)
|
| 86 |
+
kits_kmeans.fit(kits_colors)
|
| 87 |
+
return kits_kmeans
|
| 88 |
+
|
| 89 |
+
def classify_kits(kits_classifer, kits_colors):
|
| 90 |
+
if kits_classifer is None or len(kits_colors) == 0:
|
| 91 |
+
return np.array([0]) # Default to team 0
|
| 92 |
+
team = kits_classifer.predict(kits_colors)
|
| 93 |
+
return team
|
| 94 |
+
|
| 95 |
+
def get_left_team_label(players_boxes, kits_colors, kits_clf):
|
| 96 |
+
left_team_label = 0
|
| 97 |
+
team_0 = []
|
| 98 |
+
team_1 = []
|
| 99 |
+
|
| 100 |
+
for i in range(len(players_boxes)):
|
| 101 |
+
x1, y1, x2, y2 = players_boxes[i][0].astype(int)
|
| 102 |
+
team = classify_kits(kits_clf, [kits_colors[i]]).item()
|
| 103 |
+
if team == 0:
|
| 104 |
+
team_0.append(np.array([x1]))
|
| 105 |
+
else:
|
| 106 |
+
team_1.append(np.array([x1]))
|
| 107 |
+
|
| 108 |
+
team_0 = np.array(team_0)
|
| 109 |
+
team_1 = np.array(team_1)
|
| 110 |
+
|
| 111 |
+
# Safely calculate averages with fallback for empty arrays
|
| 112 |
+
avg_team_0 = np.average(team_0) if len(team_0) > 0 else 0
|
| 113 |
+
avg_team_1 = np.average(team_1) if len(team_1) > 0 else 0
|
| 114 |
+
|
| 115 |
+
if avg_team_0 - avg_team_1 > 0:
|
| 116 |
+
left_team_label = 1
|
| 117 |
+
|
| 118 |
+
return left_team_label
|
| 119 |
+
|
| 120 |
+
def check_box_boundaries(boxes, img_height, img_width):
|
| 121 |
+
"""
|
| 122 |
+
Check if bounding boxes are within image boundaries and clip them if necessary.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
boxes: numpy array of shape (N, 4) with [x1, y1, x2, y2] format
|
| 126 |
+
img_height: height of the image
|
| 127 |
+
img_width: width of the image
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
valid_boxes: numpy array of valid boxes within boundaries
|
| 131 |
+
valid_indices: indices of valid boxes
|
| 132 |
+
"""
|
| 133 |
+
x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
|
| 134 |
+
|
| 135 |
+
# Check if boxes are within boundaries
|
| 136 |
+
valid_mask = (x1 >= 0) & (y1 >= 0) & (x2 < img_width) & (y2 < img_height) & (x1 < x2) & (y1 < y2)
|
| 137 |
+
|
| 138 |
+
if not np.any(valid_mask):
|
| 139 |
+
return np.array([]), np.array([])
|
| 140 |
+
|
| 141 |
+
valid_boxes = boxes[valid_mask]
|
| 142 |
+
valid_indices = np.where(valid_mask)[0]
|
| 143 |
+
|
| 144 |
+
# Clip boxes to image boundaries
|
| 145 |
+
valid_boxes[:, 0] = np.clip(valid_boxes[:, 0], 0, img_width - 1) # x1
|
| 146 |
+
valid_boxes[:, 1] = np.clip(valid_boxes[:, 1], 0, img_height - 1) # y1
|
| 147 |
+
valid_boxes[:, 2] = np.clip(valid_boxes[:, 2], 0, img_width - 1) # x2
|
| 148 |
+
valid_boxes[:, 3] = np.clip(valid_boxes[:, 3], 0, img_height - 1) # y2
|
| 149 |
+
|
| 150 |
+
return valid_boxes, valid_indices
|
| 151 |
+
|
| 152 |
+
def process_team_identification_batch(frames, results, kits_clf, left_team_label, grass_hsv):
|
| 153 |
+
"""
|
| 154 |
+
Process team identification and label formatting for batch results.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
frames: list of frames
|
| 158 |
+
results: list of detection results for each frame
|
| 159 |
+
kits_clf: trained kit classifier
|
| 160 |
+
left_team_label: label for left team
|
| 161 |
+
grass_hsv: grass color in HSV format
|
| 162 |
+
|
| 163 |
+
Returns:
|
| 164 |
+
processed_results: list of processed results with team identification
|
| 165 |
+
"""
|
| 166 |
+
processed_results = []
|
| 167 |
+
|
| 168 |
+
for frame_idx, frame in enumerate(frames):
|
| 169 |
+
frame_results = []
|
| 170 |
+
frame_detections = results[frame_idx]
|
| 171 |
+
|
| 172 |
+
if not frame_detections:
|
| 173 |
+
processed_results.append([])
|
| 174 |
+
continue
|
| 175 |
+
|
| 176 |
+
# Extract player boxes and images
|
| 177 |
+
players_imgs = []
|
| 178 |
+
players_boxes = []
|
| 179 |
+
player_indices = []
|
| 180 |
+
|
| 181 |
+
for idx, (box, score, cls) in enumerate(frame_detections):
|
| 182 |
+
label = int(cls)
|
| 183 |
+
if label == 0: # Player detection
|
| 184 |
+
x1, y1, x2, y2 = box.astype(int)
|
| 185 |
+
|
| 186 |
+
# Check boundaries
|
| 187 |
+
if (x1 >= 0 and y1 >= 0 and x2 < frame.shape[1] and y2 < frame.shape[0] and x1 < x2 and y1 < y2):
|
| 188 |
+
player_img = frame[y1:y2, x1:x2]
|
| 189 |
+
if player_img.size > 0: # Ensure valid image
|
| 190 |
+
players_imgs.append(player_img)
|
| 191 |
+
players_boxes.append([box, score, cls])
|
| 192 |
+
player_indices.append(idx)
|
| 193 |
+
|
| 194 |
+
# Initialize player team mapping
|
| 195 |
+
player_team_map = {}
|
| 196 |
+
|
| 197 |
+
# Process team identification if we have players
|
| 198 |
+
if players_imgs and kits_clf is not None:
|
| 199 |
+
kits_colors = get_kits_colors(players_imgs, grass_hsv)
|
| 200 |
+
teams = classify_kits(kits_clf, kits_colors)
|
| 201 |
+
|
| 202 |
+
# Create mapping from player index to team
|
| 203 |
+
for i, team in enumerate(teams):
|
| 204 |
+
player_team_map[player_indices[i]] = team.item()
|
| 205 |
+
|
| 206 |
+
id = 0
|
| 207 |
+
# Process all detections with team identification
|
| 208 |
+
for idx, (box, score, cls) in enumerate(frame_detections):
|
| 209 |
+
label = int(cls)
|
| 210 |
+
x1, y1, x2, y2 = box.astype(int)
|
| 211 |
+
|
| 212 |
+
# Check boundaries
|
| 213 |
+
valid_boxes, valid_indices = check_box_boundaries(
|
| 214 |
+
np.array([[x1, y1, x2, y2]]), frame.shape[0], frame.shape[1]
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
if len(valid_boxes) == 0:
|
| 218 |
+
continue
|
| 219 |
+
|
| 220 |
+
x1, y1, x2, y2 = valid_boxes[0].astype(int)
|
| 221 |
+
|
| 222 |
+
# Apply team identification logic
|
| 223 |
+
if label == 0: # Player
|
| 224 |
+
if players_imgs and kits_clf is not None and idx in player_team_map:
|
| 225 |
+
team = player_team_map[idx]
|
| 226 |
+
if team == left_team_label:
|
| 227 |
+
final_label = 6 # Player-L (Left team)
|
| 228 |
+
else:
|
| 229 |
+
final_label = 7 # Player-R (Right team)
|
| 230 |
+
else:
|
| 231 |
+
final_label = 6 # Default player label
|
| 232 |
+
|
| 233 |
+
elif label == 1: # Goalkeeper
|
| 234 |
+
final_label = 1 # GK
|
| 235 |
+
|
| 236 |
+
elif label == 2: # Ball
|
| 237 |
+
final_label = 0 # Ball
|
| 238 |
+
|
| 239 |
+
elif label == 3 or label == 4: # Referee or other
|
| 240 |
+
final_label = 3 # Referee
|
| 241 |
+
|
| 242 |
+
else:
|
| 243 |
+
final_label = int(label) # Keep original label, ensure it's int
|
| 244 |
+
|
| 245 |
+
frame_results.append({
|
| 246 |
+
"id": int(id),
|
| 247 |
+
"bbox": [int(x1), int(y1), int(x2), int(y2)],
|
| 248 |
+
"class_id": int(final_label),
|
| 249 |
+
"conf": float(score)
|
| 250 |
+
})
|
| 251 |
+
id = id + 1
|
| 252 |
+
|
| 253 |
+
processed_results.append(frame_results)
|
| 254 |
+
|
| 255 |
+
return processed_results
|
| 256 |
+
|
| 257 |
+
def convert_numpy_types(obj):
    """Recursively convert numpy scalars/arrays to native Python types for JSON serialization.

    Handles numpy integer, floating, and boolean scalars, ndarrays, and
    arbitrarily nested dicts/lists containing them. Any other object is
    returned unchanged.

    Args:
        obj: Value possibly containing numpy types.

    Returns:
        An equivalent structure built only from native Python types.
    """
    if isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    # np.bool_ is neither np.integer nor np.floating and is not JSON
    # serializable; convert it explicitly.
    elif isinstance(obj, np.bool_):
        return bool(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, dict):
        return {key: convert_numpy_types(value) for key, value in obj.items()}
    elif isinstance(obj, list):
        return [convert_numpy_types(item) for item in obj]
    else:
        return obj
|
| 271 |
+
|
| 272 |
+
def pre_process_img(frames, scale):
    """Resize frames to a square model input and pack them into an NCHW float32 batch.

    Args:
        frames: Iterable of HxWx3 images (presumably BGR from cv2 capture — confirm
            against the model's expected channel order).
        scale: Target side length; each frame is resized to (scale, scale),
            so aspect ratio is not preserved.

    Returns:
        np.ndarray of shape (N, 3, scale, scale), float32, scaled to [0, 1].
    """
    side = int(scale)
    resized = [cv2.resize(img, (side, side)) for img in frames]
    batch = np.stack(resized)            # (N, H, W, C)
    batch = batch.transpose(0, 3, 1, 2)  # -> (N, C, H, W)
    return batch.astype(np.float32) / 255.0
|
| 277 |
+
|
| 278 |
+
def post_process_output(outputs, x_scale, y_scale, conf_thresh=0.6, nms_thresh=0.75):
    """Decode raw YOLO-style detector output into per-frame detection lists.

    Args:
        outputs: np.ndarray of shape (B, C, N) — B frames, C = 4 box coords
            plus per-class logits, N candidate predictions.
        x_scale, y_scale: Factors mapping model-input coordinates back to the
            original frame size.
        conf_thresh: Minimum class confidence to keep a detection.
        nms_thresh: IoU threshold for non-maximum suppression.

    Returns:
        List of length B; entry b is a list of (box_xyxy, conf, class_id)
        tuples (numpy values) for frame b, or [] if nothing survived.
    """
    B, C, N = outputs.shape
    outputs = torch.from_numpy(outputs)
    # (B, C, N) -> (B, N, C): one row per candidate prediction.
    outputs = outputs.permute(0, 2, 1)
    boxes = outputs[..., :4]
    # Sigmoid over class logits; best class per candidate.
    class_scores = 1 / (1 + torch.exp(-outputs[..., 4:]))
    conf, class_id = class_scores.max(dim=2)

    mask = conf > conf_thresh

    # Ball rescue: the ball (class 2) is small and often scores below
    # conf_thresh, so keep the single best ball candidate per frame if it
    # clears a lower 0.55 bar.
    for i in range(class_id.shape[0]):  # loop over batch
        # Find detections that are balls
        ball_idx = np.where(class_id[i] == 2)[0]
        if ball_idx.size > 0:
            # Pick the one with the highest confidence
            top = ball_idx[np.argmax(conf[i, ball_idx])]
            if conf[i, top] > 0.55:  # apply confidence threshold
                mask[i, top] = True

    # ball_mask = (class_id == 2) & (conf > 0.51)
    # mask = mask | ball_mask

    batch_idx, pred_idx = mask.nonzero(as_tuple=True)

    if len(batch_idx) == 0:
        return [[] for _ in range(B)]

    # Flatten surviving detections across the batch.
    boxes = boxes[batch_idx, pred_idx]
    conf = conf[batch_idx, pred_idx]
    class_id = class_id[batch_idx, pred_idx]

    # Convert center-x/center-y/width/height to x1y1x2y2 and rescale to the
    # original frame resolution.
    x, y, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    x1 = (x - w / 2) * x_scale
    y1 = (y - h / 2) * y_scale
    x2 = (x + w / 2) * x_scale
    y2 = (y + h / 2) * y_scale
    boxes_xyxy = torch.stack([x1, y1, x2, y2], dim=1)

    # Offset each frame's boxes by a large per-frame constant so boxes from
    # different frames can never overlap during NMS. NOTE(review): batched_nms
    # already separates by batch_idx, so this offset looks redundant (though
    # harmless) — confirm which batched_nms implementation is imported.
    max_coord = 1e4
    offset = batch_idx.to(boxes_xyxy) * max_coord
    boxes_for_nms = boxes_xyxy + offset[:, None]

    keep = batched_nms(boxes_for_nms, conf, batch_idx, nms_thresh)

    boxes_final = boxes_xyxy[keep]
    conf_final = conf[keep]
    class_final = class_id[keep]
    batch_final = batch_idx[keep]

    # Regroup kept detections by originating frame.
    results = [[] for _ in range(B)]
    for b in range(B):
        mask_b = batch_final == b
        if mask_b.sum() == 0:
            continue
        results[b] = list(zip(boxes_final[mask_b].numpy(),
                              conf_final[mask_b].numpy(),
                              class_final[mask_b].numpy()))
    return results
|
| 336 |
+
|
| 337 |
+
def player_detection_result(frames: list[ndarray], batch_size, model, kits_clf=None, left_team_label=None, grass_hsv=None):
    """Run batched object detection over frames and tag players with team labels.

    Args:
        frames: BGR frames; all are assumed to share the first frame's size.
        batch_size: Number of frames per inference batch.
        model: ONNX Runtime InferenceSession-like object exposing
            ``get_inputs()`` and ``run()``.
        kits_clf: Optional pre-fitted kit-color classifier; built lazily from
            the first usable frame when None.
        left_team_label: Optional classifier label of the left-side team.
        grass_hsv: Optional HSV grass-color sample used to mask the pitch.

    Returns:
        Tuple ``(results, kits_clf, left_team_label, grass_hsv)``: one
        detection-dict list per frame, plus the (possibly lazily initialized)
        team-classification state so callers can reuse it on later calls.
    """
    height, width = frames[0].shape[:2]
    scale = 640.0  # model input side length
    x_scale = width / scale
    y_scale = height / scale

    results = []
    for start in range(0, len(frames), batch_size):
        # Python slicing clamps at the sequence end, so the final batch is
        # simply shorter — no need to adjust batch_size.
        batch_frames = frames[start:start + batch_size]
        imgs = pre_process_img(batch_frames, scale)

        input_name = model.get_inputs()[0].name
        outputs = model.run(None, {input_name: imgs})[0]
        raw_results = post_process_output(np.array(outputs), x_scale, y_scale)

        # Lazily initialize team classification from the first frame that
        # yields usable player crops and kit colors.
        if kits_clf is None or left_team_label is None or grass_hsv is None:
            first_frame = batch_frames[0]
            first_frame_results = raw_results[0] if raw_results else []

            if first_frame_results:
                players_imgs, players_boxes = get_players_boxes(first_frame, first_frame_results)
                if players_imgs:
                    grass_color = get_grass_color(first_frame)
                    grass_hsv = cv2.cvtColor(np.uint8([[list(grass_color)]]), cv2.COLOR_BGR2HSV)
                    kits_colors = get_kits_colors(players_imgs, grass_hsv)
                    if kits_colors:  # only proceed if we have valid kit colors
                        kits_clf = get_kits_classifier(kits_colors)
                        if kits_clf is not None:
                            left_team_label = int(get_left_team_label(players_boxes, kits_colors, kits_clf))

        # Assign team/class labels and clamp boxes, then strip numpy types so
        # the output is directly JSON-serializable.
        processed_results = process_team_identification_batch(
            batch_frames, raw_results, kits_clf, left_team_label, grass_hsv
        )
        results.extend(convert_numpy_types(processed_results))

    return results, kits_clf, left_team_label, grass_hsv
|