gloriforge committed
Commit 55e8e9c · 0 Parent(s)

Duplicate from gloriforge/turbo_1_1
.gitattributes ADDED
@@ -0,0 +1,36 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
osnet_model.pth.tar-100 filter=lfs diff=lfs merge=lfs -text
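Every pattern above routes matching files to Git LFS rather than storing them in plain git. A rough way to see which filenames are affected is the sketch below; note it uses Python's `fnmatch` as an approximation of gitattributes glob semantics, which is accurate for the simple `*.ext` patterns here but not for directory patterns like `saved_model/**/*`.

```python
# Rough check of which filenames the simple LFS patterns above would match.
# fnmatch is an approximation of git's pattern matching, good enough for
# flat "*.ext"-style patterns (not for "saved_model/**/*").
from fnmatch import fnmatch

LFS_PATTERNS = [
    "*.7z", "*.bin", "*.ckpt", "*.pt", "*.pth", "*.safetensors",
    "*.zip", "*tfevents*", "osnet_model.pth.tar-100",
]

def routed_to_lfs(filename: str) -> bool:
    """Return True if any of the (simple) LFS patterns matches the filename."""
    return any(fnmatch(filename, pattern) for pattern in LFS_PATTERNS)
```

For example, `keypoint_detect.pt` (added below in this commit) matches `*.pt`, while `miner.py` matches nothing and stays in plain git.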
README.md ADDED
@@ -0,0 +1,92 @@
# 🚀 Example Chute for Turbovision 🪂

This repository demonstrates how to deploy a **Chute** via the **Turbovision CLI**, hosted on **Hugging Face Hub**.
It serves as a minimal example showcasing the required structure and workflow for integrating machine learning models, preprocessing, and orchestration into a reproducible Chute environment.

## Repository Structure
The following two files **must be present** (in their current locations) for a successful deployment; their content can be modified as needed:

| File | Purpose |
|------|---------|
| `miner.py` | Defines the ML model type(s), orchestration, and all pre/postprocessing logic. |
| `config.yml` | Specifies the machine configuration (e.g., GPU type, memory, environment variables). |

Other files (e.g., model weights, utility scripts, or dependencies) are **optional** and can be included as needed for your model. Note: any required assets must be defined or contained **within this repo**, which is fully open-source, since all network-related operations (downloading challenge data, weights, etc.) are disabled **inside the Chute**.
## Overview

Below is a high-level diagram showing the interaction between Hugging Face, Chutes, and Turbovision:

![](../images/miner.png)
## Local Testing
After editing `config.yml` and `miner.py` and saving them to your Hugging Face repo, you will want to test that everything works locally.

1. Copy the file `scorevision/chute_tmeplate/turbovision_chute.py.j2` to a Python file called `my_chute.py` and fill in the missing variables:
```python
HF_REPO_NAME = "{{ huggingface_repository_name }}"
HF_REPO_REVISION = "{{ huggingface_repository_revision }}"
CHUTES_USERNAME = "{{ chute_username }}"
CHUTE_NAME = "{{ chute_name }}"
```

2. Run the following command to build the chute locally (caution: there are known issues with the Docker location when running this on a Mac):
```bash
chutes build my_chute:chute --local --public
```

3. Run the Docker image you just built (its name is `CHUTE_NAME`) and enter the container:
```bash
docker run -p 8000:8000 -e CHUTES_EXECUTION_CONTEXT=REMOTE -it <image-name> /bin/bash
```

4. Run the file from within the container:
```bash
chutes run my_chute:chute --dev --debug
```

5. In another terminal, test the local endpoints to ensure there are no bugs:
```bash
curl -X POST http://localhost:8000/health -d '{}'
curl -X POST http://localhost:8000/predict -d '{"url": "https://scoredata.me/2025_03_14/35ae7a/h1_0f2ca0.mp4","meta": {}}'
```
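The local endpoint checks can also be scripted. The sketch below (stdlib only; it assumes the chute is serving on `localhost:8000` as in the local-testing steps above) builds the same JSON requests as the two curl commands:

```python
import json
import urllib.request

BASE = "http://localhost:8000"  # local chute started inside the container

# Same body as the curl /predict example above
PREDICT_PAYLOAD = {"url": "https://scoredata.me/2025_03_14/35ae7a/h1_0f2ca0.mp4", "meta": {}}

def post_json(endpoint: str, payload: dict) -> urllib.request.Request:
    """Build a JSON POST request for a chute endpoint."""
    return urllib.request.Request(
        BASE + endpoint,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

# To actually send (requires the container from the previous step to be running):
# with urllib.request.urlopen(post_json("/health", {})) as resp:
#     print(resp.status, resp.read())
```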
## Live Testing
1. If you have any chute with the same name (i.e. from a previous deployment), ensure you delete it first (or you will get an error when trying to build):
```bash
chutes chutes list
```
Take note of the chute id that you wish to delete (if any):
```bash
chutes chutes delete <chute-id>
```

You should also delete its associated image:
```bash
chutes images list
```
Take note of the chute image id:
```bash
chutes images delete <chute-image-id>
```

2. Use Turbovision's CLI to build, deploy, and commit on-chain (note: you can skip the on-chain commit using `--no-commit`; you can also point to a past Hugging Face revision using `--revision`, and/or specify the local files to upload to your Hugging Face repo using `--model-path`):
```bash
sv -vv push
```

3. When completed, warm up the chute (if it's cold 🧊). You can confirm its status using `chutes chutes list`, or `chutes chutes get <chute-id>` if you already know its id. Note: warming up can sometimes take a while, but if the chute runs without errors (it should, if you've tested locally first) and there are sufficient nodes (i.e. machines) available matching the `config.yml` you specified, the chute should become hot 🔥!
```bash
chutes warmup <chute-id>
```

4. Test the chute's endpoints:
```bash
curl -X POST https://<YOUR-CHUTE-SLUG>.chutes.ai/health -d '{}' -H "Authorization: Bearer $CHUTES_API_KEY"
curl -X POST https://<YOUR-CHUTE-SLUG>.chutes.ai/predict -d '{"url": "https://scoredata.me/2025_03_14/35ae7a/h1_0f2ca0.mp4","meta": {}}' -H "Authorization: Bearer $CHUTES_API_KEY"
```

5. Test what your chute would score on a validator (this also applies any validation/integrity checks, which may fail if you did not use the Turbovision CLI above to deploy the chute):
```bash
sv -vv run-once
```
chute_config.yml ADDED
@@ -0,0 +1,31 @@
Image:
  from_base: parachutes/python:3.12
  run_command:
    - pip install --upgrade setuptools wheel
    - pip install huggingface_hub==0.19.4 opencv-python-headless
    - pip install "ultralytics==8.3.222" "numpy" "pydantic"
    - pip install --index-url https://download.pytorch.org/whl/cu128 "torch==2.7.1" "torchvision==0.22.1"
    - pip install scikit-learn cryptography
    - pip install onnxruntime-gpu numba
    - pip install cython==3.2.2
  set_workdir: /app

NodeSelector:
  gpu_count: 1
  min_vram_gb_per_gpu: 24
  min_memory_gb: 32
  min_cpu_count: 32

  exclude:
    - "5090"
    - b200
    - h200
    - mi300x

Chute:
  timeout_seconds: 900
  concurrency: 4
  max_instances: 5
  scaling_threshold: 0.8
  shutdown_after_seconds: 604800
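The `NodeSelector` block above describes the machines the chute may be scheduled on: minimum resource thresholds plus an exclusion list of GPU models. A small PyYAML sketch makes the selection logic concrete (the YAML string is copied from the file above; `machine_ok` is a hypothetical helper for illustration, not part of the Chutes scheduler, and PyYAML is assumed available since `miner.py` itself imports `yaml`):

```python
import yaml

# Excerpt copied from chute_config.yml above
NODE_SELECTOR_YAML = """
NodeSelector:
  gpu_count: 1
  min_vram_gb_per_gpu: 24
  min_memory_gb: 32
  min_cpu_count: 32
  exclude:
    - "5090"
    - b200
    - h200
    - mi300x
"""

selector = yaml.safe_load(NODE_SELECTOR_YAML)["NodeSelector"]

def machine_ok(gpu_model: str, vram_gb: int, memory_gb: int, cpus: int) -> bool:
    """Hypothetical check: a node qualifies if it meets the minimums
    and its GPU model is not on the exclusion list."""
    return (
        vram_gb >= selector["min_vram_gb_per_gpu"]
        and memory_gb >= selector["min_memory_gb"]
        and cpus >= selector["min_cpu_count"]
        and gpu_model not in selector["exclude"]
    )
```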
football_pitch_template.png ADDED
hrnetv2_w48.yaml ADDED
@@ -0,0 +1,35 @@
MODEL:
  IMAGE_SIZE: [960, 540]
  NUM_JOINTS: 58
  PRETRAIN: ''
  EXTRA:
    FINAL_CONV_KERNEL: 1
    STAGE1:
      NUM_MODULES: 1
      NUM_BRANCHES: 1
      BLOCK: BOTTLENECK
      NUM_BLOCKS: [4]
      NUM_CHANNELS: [64]
      FUSE_METHOD: SUM
    STAGE2:
      NUM_MODULES: 1
      NUM_BRANCHES: 2
      BLOCK: BASIC
      NUM_BLOCKS: [4, 4]
      NUM_CHANNELS: [48, 96]
      FUSE_METHOD: SUM
    STAGE3:
      NUM_MODULES: 4
      NUM_BRANCHES: 3
      BLOCK: BASIC
      NUM_BLOCKS: [4, 4, 4]
      NUM_CHANNELS: [48, 96, 192]
      FUSE_METHOD: SUM
    STAGE4:
      NUM_MODULES: 3
      NUM_BRANCHES: 4
      BLOCK: BASIC
      NUM_BLOCKS: [4, 4, 4, 4]
      NUM_CHANNELS: [48, 96, 192, 384]
      FUSE_METHOD: SUM
keypoint_detect.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7ea78fa76aaf94976a8eca428d6e3c59697a93430cba1a4603e20284b61f5113
size 264964645
miner.py ADDED
@@ -0,0 +1,2613 @@
1
+ import time
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ from pathlib import Path
6
+ from numpy import ndarray
7
+ from pydantic import BaseModel
8
+ from ultralytics import YOLO
9
+ import os
10
+
11
+ from typing import Iterable, Generator, List, TypeVar, Tuple, Sequence, Any, Dict, Optional
12
+ from collections import deque, OrderedDict, defaultdict
13
+ import threading
14
+ from itertools import combinations
15
+ from concurrent.futures import ThreadPoolExecutor
16
+ import yaml
17
+ from cv2 import (
18
+ bitwise_and,
19
+ findHomography,
20
+ warpPerspective,
21
+ cvtColor,
22
+ COLOR_BGR2GRAY,
23
+ threshold,
24
+ THRESH_BINARY,
25
+ getStructuringElement,
26
+ MORPH_RECT,
27
+ MORPH_TOPHAT,
28
+ GaussianBlur,
29
+ morphologyEx,
30
+ Canny,
31
+ connectedComponents,
32
+ perspectiveTransform,
33
+ RETR_EXTERNAL,
34
+ CHAIN_APPROX_SIMPLE,
35
+ findContours,
36
+ boundingRect,
37
+ dilate,
38
+ imread,
39
+ countNonZero
40
+ )
41
+ import gc
42
+
43
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
44
+ class BoundingBox(BaseModel):
45
+ x1: int
46
+ y1: int
47
+ x2: int
48
+ y2: int
49
+ cls_id: int
50
+ conf: float
51
+ track_id: int | None = None
52
+
53
+
54
+ class TVFrameResult(BaseModel):
55
+ frame_id: int
56
+ boxes: list[BoundingBox]
57
+ keypoints: list[tuple[int, int]]
58
+
59
+ V = TypeVar("V")
60
+ kp_threshold = 0.3
61
+
62
+ def create_batches(sequence: Iterable[V], batch_size: int) -> Generator[List[V], None, None]:
63
+ batch_size = max(batch_size, 1)
64
+ current_batch = []
65
+ for element in sequence:
66
+ if len(current_batch) == batch_size:
67
+ yield current_batch
68
+ current_batch = []
69
+ current_batch.append(element)
70
+ if current_batch:
71
+ yield current_batch
72
+
73
+ from torch import nn
74
+ from torch.nn import functional as F
75
+ from sklearn.cluster import KMeans
76
+ from PIL import Image
77
+ from collections import defaultdict
78
+
79
+ _OSNET_MODEL = None
80
+ team_classifier_path = None
81
+
82
+ BALL_ID = 0
83
+ GK_ID = 1
84
+ PLAYER_ID = 2
85
+ REF_ID = 3
86
+ TEAM_1_ID = 6
87
+ TEAM_2_ID = 7
88
+
89
+ pretrained_urls = {
90
+ 'osnet_x1_0':
91
+ 'https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY',
92
+ }
93
+
94
+ class ConvLayer(nn.Module):
95
+ """Convolution layer (conv + bn + relu)."""
96
+
97
+ def __init__(
98
+ self,
99
+ in_channels,
100
+ out_channels,
101
+ kernel_size,
102
+ stride=1,
103
+ padding=0,
104
+ groups=1,
105
+ IN=False
106
+ ):
107
+ super(ConvLayer, self).__init__()
108
+ self.conv = nn.Conv2d(
109
+ in_channels,
110
+ out_channels,
111
+ kernel_size,
112
+ stride=stride,
113
+ padding=padding,
114
+ bias=False,
115
+ groups=groups
116
+ )
117
+ if IN:
118
+ self.bn = nn.InstanceNorm2d(out_channels, affine=True)
119
+ else:
120
+ self.bn = nn.BatchNorm2d(out_channels)
121
+ self.relu = nn.ReLU(inplace=True)
122
+
123
+ def forward(self, x):
124
+ x = self.conv(x)
125
+ x = self.bn(x)
126
+ x = self.relu(x)
127
+ return x
128
+
129
+
130
+ class Conv1x1(nn.Module):
131
+ """1x1 convolution + bn + relu."""
132
+
133
+ def __init__(self, in_channels, out_channels, stride=1, groups=1):
134
+ super(Conv1x1, self).__init__()
135
+ self.conv = nn.Conv2d(
136
+ in_channels,
137
+ out_channels,
138
+ 1,
139
+ stride=stride,
140
+ padding=0,
141
+ bias=False,
142
+ groups=groups
143
+ )
144
+ self.bn = nn.BatchNorm2d(out_channels)
145
+ self.relu = nn.ReLU(inplace=True)
146
+
147
+ def forward(self, x):
148
+ x = self.conv(x)
149
+ x = self.bn(x)
150
+ x = self.relu(x)
151
+ return x
152
+
153
+
154
+ class Conv1x1Linear(nn.Module):
155
+ """1x1 convolution + bn (w/o non-linearity)."""
156
+
157
+ def __init__(self, in_channels, out_channels, stride=1):
158
+ super(Conv1x1Linear, self).__init__()
159
+ self.conv = nn.Conv2d(
160
+ in_channels, out_channels, 1, stride=stride, padding=0, bias=False
161
+ )
162
+ self.bn = nn.BatchNorm2d(out_channels)
163
+
164
+ def forward(self, x):
165
+ x = self.conv(x)
166
+ x = self.bn(x)
167
+ return x
168
+
169
+
170
+ class Conv3x3(nn.Module):
171
+ """3x3 convolution + bn + relu."""
172
+
173
+ def __init__(self, in_channels, out_channels, stride=1, groups=1):
174
+ super(Conv3x3, self).__init__()
175
+ self.conv = nn.Conv2d(
176
+ in_channels,
177
+ out_channels,
178
+ 3,
179
+ stride=stride,
180
+ padding=1,
181
+ bias=False,
182
+ groups=groups
183
+ )
184
+ self.bn = nn.BatchNorm2d(out_channels)
185
+ self.relu = nn.ReLU(inplace=True)
186
+
187
+ def forward(self, x):
188
+ x = self.conv(x)
189
+ x = self.bn(x)
190
+ x = self.relu(x)
191
+ return x
192
+
193
+
194
+ class LightConv3x3(nn.Module):
195
+ """Lightweight 3x3 convolution.
196
+
197
+ 1x1 (linear) + dw 3x3 (nonlinear).
198
+ """
199
+
200
+ def __init__(self, in_channels, out_channels):
201
+ super(LightConv3x3, self).__init__()
202
+ self.conv1 = nn.Conv2d(
203
+ in_channels, out_channels, 1, stride=1, padding=0, bias=False
204
+ )
205
+ self.conv2 = nn.Conv2d(
206
+ out_channels,
207
+ out_channels,
208
+ 3,
209
+ stride=1,
210
+ padding=1,
211
+ bias=False,
212
+ groups=out_channels
213
+ )
214
+ self.bn = nn.BatchNorm2d(out_channels)
215
+ self.relu = nn.ReLU(inplace=True)
216
+
217
+ def forward(self, x):
218
+ x = self.conv1(x)
219
+ x = self.conv2(x)
220
+ x = self.bn(x)
221
+ x = self.relu(x)
222
+ return x
223
+
224
+
225
+ class ChannelGate(nn.Module):
226
+
227
+ def __init__(
228
+ self,
229
+ in_channels,
230
+ num_gates=None,
231
+ return_gates=False,
232
+ gate_activation='sigmoid',
233
+ reduction=16,
234
+ layer_norm=False
235
+ ):
236
+ super(ChannelGate, self).__init__()
237
+ if num_gates is None:
238
+ num_gates = in_channels
239
+ self.return_gates = return_gates
240
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
241
+ self.fc1 = nn.Conv2d(
242
+ in_channels,
243
+ in_channels // reduction,
244
+ kernel_size=1,
245
+ bias=True,
246
+ padding=0
247
+ )
248
+ self.norm1 = None
249
+ if layer_norm:
250
+ self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
251
+ self.relu = nn.ReLU(inplace=True)
252
+ self.fc2 = nn.Conv2d(
253
+ in_channels // reduction,
254
+ num_gates,
255
+ kernel_size=1,
256
+ bias=True,
257
+ padding=0
258
+ )
259
+ if gate_activation == 'sigmoid':
260
+ self.gate_activation = nn.Sigmoid()
261
+ elif gate_activation == 'relu':
262
+ self.gate_activation = nn.ReLU(inplace=True)
263
+ elif gate_activation == 'linear':
264
+ self.gate_activation = None
265
+ else:
266
+ raise RuntimeError(
267
+ "Unknown gate activation: {}".format(gate_activation)
268
+ )
269
+
270
+ def forward(self, x):
271
+ input = x
272
+ x = self.global_avgpool(x)
273
+ x = self.fc1(x)
274
+ if self.norm1 is not None:
275
+ x = self.norm1(x)
276
+ x = self.relu(x)
277
+ x = self.fc2(x)
278
+ if self.gate_activation is not None:
279
+ x = self.gate_activation(x)
280
+ if self.return_gates:
281
+ return x
282
+ return input * x
283
+
284
+
285
+ class OSBlock(nn.Module):
286
+ """Omni-scale feature learning block."""
287
+
288
+ def __init__(
289
+ self,
290
+ in_channels,
291
+ out_channels,
292
+ IN=False,
293
+ bottleneck_reduction=4,
294
+ **kwargs
295
+ ):
296
+ super(OSBlock, self).__init__()
297
+ mid_channels = out_channels // bottleneck_reduction
298
+ self.conv1 = Conv1x1(in_channels, mid_channels)
299
+ self.conv2a = LightConv3x3(mid_channels, mid_channels)
300
+ self.conv2b = nn.Sequential(
301
+ LightConv3x3(mid_channels, mid_channels),
302
+ LightConv3x3(mid_channels, mid_channels),
303
+ )
304
+ self.conv2c = nn.Sequential(
305
+ LightConv3x3(mid_channels, mid_channels),
306
+ LightConv3x3(mid_channels, mid_channels),
307
+ LightConv3x3(mid_channels, mid_channels),
308
+ )
309
+ self.conv2d = nn.Sequential(
310
+ LightConv3x3(mid_channels, mid_channels),
311
+ LightConv3x3(mid_channels, mid_channels),
312
+ LightConv3x3(mid_channels, mid_channels),
313
+ LightConv3x3(mid_channels, mid_channels),
314
+ )
315
+ self.gate = ChannelGate(mid_channels)
316
+ self.conv3 = Conv1x1Linear(mid_channels, out_channels)
317
+ self.downsample = None
318
+ if in_channels != out_channels:
319
+ self.downsample = Conv1x1Linear(in_channels, out_channels)
320
+ self.IN = None
321
+ if IN:
322
+ self.IN = nn.InstanceNorm2d(out_channels, affine=True)
323
+
324
+ def forward(self, x):
325
+ identity = x
326
+ x1 = self.conv1(x)
327
+ x2a = self.conv2a(x1)
328
+ x2b = self.conv2b(x1)
329
+ x2c = self.conv2c(x1)
330
+ x2d = self.conv2d(x1)
331
+ x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d)
332
+ x3 = self.conv3(x2)
333
+ if self.downsample is not None:
334
+ identity = self.downsample(identity)
335
+ out = x3 + identity
336
+ if self.IN is not None:
337
+ out = self.IN(out)
338
+ return F.relu(out)
339
+
340
+
341
+ class OSNet(nn.Module):
342
+
343
+ def __init__(
344
+ self,
345
+ num_classes,
346
+ blocks,
347
+ layers,
348
+ channels,
349
+ feature_dim=512,
350
+ loss='softmax',
351
+ IN=False,
352
+ **kwargs
353
+ ):
354
+ super(OSNet, self).__init__()
355
+ num_blocks = len(blocks)
356
+ assert num_blocks == len(layers)
357
+ assert num_blocks == len(channels) - 1
358
+ self.loss = loss
359
+ self.feature_dim = feature_dim
360
+
361
+ # convolutional backbone
362
+ self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
363
+ self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
364
+ self.conv2 = self._make_layer(
365
+ blocks[0],
366
+ layers[0],
367
+ channels[0],
368
+ channels[1],
369
+ reduce_spatial_size=True,
370
+ IN=IN
371
+ )
372
+ self.conv3 = self._make_layer(
373
+ blocks[1],
374
+ layers[1],
375
+ channels[1],
376
+ channels[2],
377
+ reduce_spatial_size=True
378
+ )
379
+ self.conv4 = self._make_layer(
380
+ blocks[2],
381
+ layers[2],
382
+ channels[2],
383
+ channels[3],
384
+ reduce_spatial_size=False
385
+ )
386
+ self.conv5 = Conv1x1(channels[3], channels[3])
387
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
388
+ # fully connected layer
389
+ self.fc = self._construct_fc_layer(
390
+ self.feature_dim, channels[3], dropout_p=None
391
+ )
392
+ # identity classification layer
393
+ self.classifier = nn.Linear(self.feature_dim, num_classes)
394
+
395
+ self._init_params()
396
+
397
+ def _make_layer(
398
+ self,
399
+ block,
400
+ layer,
401
+ in_channels,
402
+ out_channels,
403
+ reduce_spatial_size,
404
+ IN=False
405
+ ):
406
+ layers = []
407
+
408
+ layers.append(block(in_channels, out_channels, IN=IN))
409
+ for i in range(1, layer):
410
+ layers.append(block(out_channels, out_channels, IN=IN))
411
+
412
+ if reduce_spatial_size:
413
+ layers.append(
414
+ nn.Sequential(
415
+ Conv1x1(out_channels, out_channels),
416
+ nn.AvgPool2d(2, stride=2)
417
+ )
418
+ )
419
+
420
+ return nn.Sequential(*layers)
421
+
422
+ def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
423
+ if fc_dims is None or fc_dims < 0:
424
+ self.feature_dim = input_dim
425
+ return None
426
+
427
+ if isinstance(fc_dims, int):
428
+ fc_dims = [fc_dims]
429
+
430
+ layers = []
431
+ for dim in fc_dims:
432
+ layers.append(nn.Linear(input_dim, dim))
433
+ layers.append(nn.BatchNorm1d(dim))
434
+ layers.append(nn.ReLU(inplace=True))
435
+ if dropout_p is not None:
436
+ layers.append(nn.Dropout(p=dropout_p))
437
+ input_dim = dim
438
+
439
+ self.feature_dim = fc_dims[-1]
440
+
441
+ return nn.Sequential(*layers)
442
+
443
+ def _init_params(self):
444
+ for m in self.modules():
445
+ if isinstance(m, nn.Conv2d):
446
+ nn.init.kaiming_normal_(
447
+ m.weight, mode='fan_out', nonlinearity='relu'
448
+ )
449
+ if m.bias is not None:
450
+ nn.init.constant_(m.bias, 0)
451
+
452
+ elif isinstance(m, nn.BatchNorm2d):
453
+ nn.init.constant_(m.weight, 1)
454
+ nn.init.constant_(m.bias, 0)
455
+
456
+ elif isinstance(m, nn.BatchNorm1d):
457
+ nn.init.constant_(m.weight, 1)
458
+ nn.init.constant_(m.bias, 0)
459
+
460
+ elif isinstance(m, nn.Linear):
461
+ nn.init.normal_(m.weight, 0, 0.01)
462
+ if m.bias is not None:
463
+ nn.init.constant_(m.bias, 0)
464
+
465
+ def featuremaps(self, x):
466
+ x = self.conv1(x)
467
+ x = self.maxpool(x)
468
+ x = self.conv2(x)
469
+ x = self.conv3(x)
470
+ x = self.conv4(x)
471
+ x = self.conv5(x)
472
+ return x
473
+
474
+ def forward(self, x, return_featuremaps=False):
475
+ x = self.featuremaps(x)
476
+ if return_featuremaps:
477
+ return x
478
+ v = self.global_avgpool(x)
479
+ v = v.view(v.size(0), -1)
480
+ if self.fc is not None:
481
+ v = self.fc(v)
482
+ if not self.training:
483
+ return v
484
+ y = self.classifier(v)
485
+ if self.loss == 'softmax':
486
+ return y
487
+ elif self.loss == 'triplet':
488
+ return y, v
489
+ else:
490
+ raise KeyError("Unsupported loss: {}".format(self.loss))
491
+
492
+
493
+ def init_pretrained_weights(model, key=''):
494
+ import os
495
+ import errno
496
+ import gdown
497
+ from collections import OrderedDict
498
+
499
+ def _get_torch_home():
500
+ ENV_TORCH_HOME = 'TORCH_HOME'
501
+ ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
502
+ DEFAULT_CACHE_DIR = '~/.cache'
503
+ torch_home = os.path.expanduser(
504
+ os.getenv(
505
+ ENV_TORCH_HOME,
506
+ os.path.join(
507
+ os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
508
+ )
509
+ )
510
+ )
511
+ return torch_home
512
+
513
+ torch_home = _get_torch_home()
514
+ model_dir = os.path.join(torch_home, 'checkpoints')
515
+ try:
516
+ os.makedirs(model_dir)
517
+ except OSError as e:
518
+ if e.errno == errno.EEXIST:
519
+ # Directory already exists, ignore.
520
+ pass
521
+ else:
522
+ # Unexpected OSError, re-raise.
523
+ raise
524
+ filename = key + '_imagenet.pth'
525
+ cached_file = os.path.join(model_dir, filename)
526
+
527
+ if not os.path.exists(cached_file):
528
+ gdown.download(pretrained_urls[key], cached_file, quiet=False)
529
+
530
+ state_dict = torch.load(cached_file)
531
+ model_dict = model.state_dict()
532
+ new_state_dict = OrderedDict()
533
+ matched_layers, discarded_layers = [], []
534
+
535
+ for k, v in state_dict.items():
536
+ if k.startswith('module.'):
537
+ k = k[7:] # discard module.
538
+
539
+ if k in model_dict and model_dict[k].size() == v.size():
540
+ new_state_dict[k] = v
541
+ matched_layers.append(k)
542
+ else:
543
+ discarded_layers.append(k)
544
+
545
+ model_dict.update(new_state_dict)
546
+ model.load_state_dict(model_dict)
547
+
548
+ if len(matched_layers) == 0:
549
+ print(
550
+ 'The pretrained weights from "{}" cannot be loaded, '
551
+ 'please check the key names manually '
552
+ '(** ignored and continue **)'.format(cached_file)
553
+ )
554
+ else:
555
+ print(
556
+ 'Successfully loaded imagenet pretrained weights from "{}"'.
557
+ format(cached_file)
558
+ )
559
+ if len(discarded_layers) > 0:
560
+ print(
561
+ '** The following layers are discarded '
562
+ 'due to unmatched keys or layer size: {}'.
563
+ format(discarded_layers)
564
+ )
565
+
566
+
567
+ def osnet_x1_0(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
568
+ # standard size (width x1.0)
569
+ model = OSNet(
570
+ num_classes,
571
+ blocks=[OSBlock, OSBlock, OSBlock],
572
+ layers=[2, 2, 2],
573
+ channels=[64, 256, 384, 512],
574
+ loss=loss,
575
+ **kwargs
576
+ )
577
+ # if pretrained:
578
+ # init_pretrained_weights(model, key='osnet_x1_0')
579
+ return model
580
+
581
+ from typing import Generator, Iterable
582
+ import torchvision.transforms as T
583
+ from collections import OrderedDict
584
+ import os.path as osp
585
+
586
+ def load_checkpoint(fpath):
587
+ fpath = osp.abspath(osp.expanduser(fpath))
588
+ map_location = None if torch.cuda.is_available() else 'cpu'
589
+ # weights_only=False allows checkpoints that contain numpy/other objects (e.g. model.pth.tar-100)
590
+ checkpoint = torch.load(fpath, map_location=map_location, weights_only=False)
591
+ return checkpoint
592
+
593
+ def load_pretrained_weights(model, weight_path):
594
+ checkpoint = load_checkpoint(weight_path)
595
+ if 'state_dict' in checkpoint:
596
+ state_dict = checkpoint['state_dict']
597
+ else:
598
+ state_dict = checkpoint
599
+ model_dict = model.state_dict()
600
+ new_state_dict = OrderedDict()
601
+ matched_layers, discarded_layers = ([], [])
602
+ for k, v in state_dict.items():
603
+ if k.startswith('module.'):
604
+ k = k[7:]
605
+ if k in model_dict and model_dict[k].size() == v.size():
606
+ new_state_dict[k] = v
607
+ matched_layers.append(k)
608
+ else:
609
+ discarded_layers.append(k)
610
+ model_dict.update(new_state_dict)
611
+ model.load_state_dict(model_dict)
612
+
613
+ def load_osnet(device="cuda", weight_path=None):
614
+ """Build osnet_x1_0 and load weights from model.pth.tar-100 via load_pretrained_weights."""
615
+ model = osnet_x1_0(num_classes=1, loss='softmax', pretrained=False, use_gpu=device == 'cuda')
616
+ # if weight_path is None:
617
+ # weight_path = Path(__file__).resolve().parent / "model.pth.tar-100"
618
+ weight_path = Path(weight_path)
619
+ if weight_path.exists():
620
+ load_pretrained_weights(model, str(weight_path))
621
+ model.eval()
622
+ model.to(device)
623
+ return model
624
+
625
+ def filter_player_boxes(
626
+ boxes: List[BoundingBox],
627
+ min_area: int = 1500
628
+ ) -> List[BoundingBox]:
629
+
630
+ players = []
631
+ for b in boxes:
632
+ if b.cls_id != 2: # only players
633
+ continue
634
+ # area = (b.x2 - b.x1) * (b.y2 - b.y1)
635
+ # if area < min_area:
636
+ # continue
637
+
638
+ players.append(b)
639
+
640
+ return players
641
+
642
+ # OSNet preprocess (same as team_cluster: Resize, ToTensor, ImageNet normalize)
643
+ OSNET_IMAGE_SIZE = (64, 32) # (height, width)
644
+ OSNET_PREPROCESS = T.Compose([
645
+ T.Resize(OSNET_IMAGE_SIZE),
646
+ T.ToTensor(),
647
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
648
+ ])
649
+
650
+ def crop_upper_body(frame: np.ndarray, box: BoundingBox) -> np.ndarray:
651
+ # h = box.y2 - box.y1
652
+ # y2 = box.y1 + int(0.6 * h)
653
+
654
+ return frame[
655
+ max(0, box.y1):max(0, box.y2),
656
+ max(0, box.x1):max(0, box.x2)
657
+ ]
658
+
659
+ def preprocess_osnet(crop: np.ndarray) -> torch.Tensor:
660
+ """BGR crop -> RGB PIL -> Resize, ToTensor, ImageNet Normalize (same as team_cluster)."""
661
+ rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
662
+ pil = Image.fromarray(rgb)
663
+ return OSNET_PREPROCESS(pil)
664
+
665
+ @torch.no_grad()
666
+ def extract_osnet_embeddings(
667
+ frames: List[np.ndarray],
668
+ # batch_boxes: List[List[BoundingBox]],
669
+ batch_boxes: dict[int, List[BoundingBox]],
670
+ device="cuda",
671
+ batch_size=4
672
+ ) -> Tuple[np.ndarray, List[BoundingBox]]:
673
+
674
+ crops = []
675
+ meta = []
676
+ for frame, (frame_index, boxes) in zip(frames, batch_boxes.items()):
677
+ players = filter_player_boxes(boxes)
678
+
679
+ for box in players:
680
+ crop = crop_upper_body(frame, box)
681
+ if crop.size == 0:
682
+ continue
683
+
684
+ crops.append(preprocess_osnet(crop))
685
+ meta.append(box)
686
+
687
+ if not crops:
688
+ return None, None
689
+
690
+ all_embeddings = []
691
+
692
+ with torch.no_grad():  # redundant with the @torch.no_grad() decorator above
693
+ for start in range(0, len(crops), batch_size):
694
+ end = start + batch_size
695
+ batch = torch.stack(crops[start:end]).float().to(device)
696
+ embeddings_chunk = _OSNET_MODEL(batch) # (chunk_size, 256)
697
+ all_embeddings.append(embeddings_chunk.cpu())
698
+ del batch, embeddings_chunk
699
+
700
+ embeddings = torch.cat(all_embeddings, dim=0).numpy()
701
+ # embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
702
+
703
+ return embeddings, meta
704
+
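`extract_osnet_embeddings` walks the crops in fixed-size chunks and concatenates the per-chunk outputs. The looping pattern in isolation, with a stand-in `embed_fn` instead of the model:

```python
import numpy as np

def embed_in_chunks(items, batch_size, embed_fn):
    """Run embed_fn over fixed-size chunks and concatenate the results,
    the same looping pattern used for the OSNet forward passes."""
    chunks = []
    for start in range(0, len(items), batch_size):
        chunks.append(embed_fn(items[start:start + batch_size]))
    return np.concatenate(chunks, axis=0)

data = np.arange(10, dtype=np.float32).reshape(10, 1)
out = embed_in_chunks(data, batch_size=4, embed_fn=lambda b: b * 2)
print(out.shape)  # (10, 1)
```

The last chunk may be shorter than `batch_size` (here 4, 4, 2), which `np.concatenate` handles without padding.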
705
+ def aggregate_by_track(
706
+ embeddings: np.ndarray,
707
+ meta: List[BoundingBox]
708
+ ):
709
+ track_map = defaultdict(list)
710
+ box_map = {}
711
+
712
+
713
+ for emb, box in zip(embeddings, meta):
714
+ key = box.track_id if box.track_id is not None else id(box)
715
+ track_map[key].append(emb)
716
+ box_map[key] = box
717
+
718
+ agg_embeddings = []
719
+ agg_boxes = []
720
+
721
+ for key, embs in track_map.items():
722
+ mean_emb = np.mean(embs, axis=0)
723
+ mean_emb /= np.linalg.norm(mean_emb) + 1e-12
724
+
725
+ agg_embeddings.append(mean_emb)
726
+ agg_boxes.append(box_map[key])
727
+
728
+ return np.array(agg_embeddings), agg_boxes
729
+
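`aggregate_by_track` mean-pools each track's embeddings and L2-normalizes the result before clustering. The same aggregation as a small numpy sketch (`aggregate_tracks` is an illustrative name):

```python
import numpy as np
from collections import defaultdict

def aggregate_tracks(embeddings, track_ids):
    """Mean-pool embeddings per track id, then L2-normalize each mean,
    mirroring the aggregation applied before clustering."""
    buckets = defaultdict(list)
    for emb, tid in zip(embeddings, track_ids):
        buckets[tid].append(emb)
    keys = sorted(buckets)
    means = np.stack([np.mean(buckets[k], axis=0) for k in keys])
    means /= np.linalg.norm(means, axis=1, keepdims=True) + 1e-12
    return keys, means

embs = np.array([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]])
keys, means = aggregate_tracks(embs, [7, 7, 8])
print(keys)  # [7, 8]
```

Normalizing the per-track means makes the subsequent k-means distances depend on embedding direction (appearance) rather than magnitude.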
730
+ def cluster_teams(embeddings: np.ndarray):
731
+ if len(embeddings) < 2:
732
+ return None
733
+
734
+ kmeans = KMeans(n_clusters=2, n_init=2, random_state=42)
735
+ return kmeans.fit_predict(embeddings)
736
+
737
+ def update_team_ids(
738
+ boxes: List[BoundingBox],
739
+ labels: np.ndarray
740
+ ):
741
+ for box, label in zip(boxes, labels):
742
+ box.cls_id = TEAM_1_ID if label == 0 else TEAM_2_ID
743
+
744
+ def classify_teams_batch(
745
+ frames: List[np.ndarray],
746
+ # batch_boxes: List[List[BoundingBox]],
747
+ batch_boxes: dict[int, List[BoundingBox]],
748
+ batch_size,
749
+ device="cuda"
750
+ ):
751
+ # Fallback: OSNet embeddings + aggregate by track + KMeans
752
+ embeddings, meta = extract_osnet_embeddings(
753
+ frames, batch_boxes, device, batch_size
754
+ )
755
+ if embeddings is None:
756
+ return
757
+ embeddings, agg_boxes = aggregate_by_track(embeddings, meta)
758
+ n = len(embeddings)
759
+ if n == 0:
760
+ return
761
+ if n == 1:
762
+ agg_boxes[0].cls_id = TEAM_1_ID
763
+ return
764
+
765
+ kmeans = KMeans(n_clusters=2, n_init=2, random_state=42)
766
+ kmeans.fit(embeddings)
767
+ centroids = kmeans.cluster_centers_ # (2, dim)
768
+ # print("Clusters' centers:")
769
+ # for i, c in enumerate(centroids):
770
+ # print(f" cluster_{i}: shape={c.shape}, norm={np.linalg.norm(c):.4f}, mean={np.mean(c):.4f}")
771
+ c0, c1 = centroids[0], centroids[1]
772
+ norm_0 = np.linalg.norm(c0)
773
+ norm_1 = np.linalg.norm(c1)
774
+ # Similarity (cosine), distance (L2), square error (SSE) between the two centers
775
+ similarity = np.dot(c0, c1) / (norm_0 * norm_1 + 1e-12)
776
+ distance = np.linalg.norm(c0 - c1)
777
+ square_error = np.sum((c0 - c1) ** 2)
778
+ # print(f" Between centers: similarity(cosine)={similarity:.4f}, distance(L2)={distance:.4f}, square_error(SSE)={square_error:.4f}")
779
+ if similarity > 0.95:
780
+ # Centers too similar: treat as one cluster (all same team)
781
+ for b in agg_boxes:
782
+ b.cls_id = TEAM_1_ID
783
+ # print(" Similarity > 0.95: using single cluster (all assigned to team 1).")
784
+ return
785
+ # Flip labels so the cluster with the larger centroid norm gets label 0 (mapped to team 1)
786
+ if norm_0 <= norm_1:
787
+ kmeans.labels_ = 1 - kmeans.labels_
788
+ update_team_ids(agg_boxes, kmeans.labels_)
789
+
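The similarity check above decides whether the two k-means centers really describe two different kits. The cosine computation in isolation, with the same `1e-12` guard against zero norms:

```python
import numpy as np

def centroid_similarity(c0, c1, eps=1e-12):
    """Cosine similarity between cluster centers; values near 1 mean the
    two 'teams' are visually indistinguishable and should be merged."""
    return float(np.dot(c0, c1) / (np.linalg.norm(c0) * np.linalg.norm(c1) + eps))

same = centroid_similarity(np.array([1.0, 1.0]), np.array([2.0, 2.0]))
orth = centroid_similarity(np.array([1.0, 0.0]), np.array([0.0, 1.0]))
print(round(same, 4), round(orth, 4))  # 1.0 0.0
```

Parallel centers score ~1.0 regardless of magnitude, which is what triggers the single-cluster fallback at the 0.95 threshold.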
790
+ def get_cls_net(config, pretrained='', **kwargs):
791
+ """Create keypoint detection model with softmax activation"""
792
+
793
+
794
+ def conv3x3(in_planes, out_planes, stride=1):
795
+ """3x3 convolution with padding"""
796
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
797
+ stride=stride, padding=1, bias=False)
798
+
799
+ class BasicBlock(nn.Module):
800
+ expansion = 1
801
+
802
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
803
+ super(BasicBlock, self).__init__()
804
+ self.conv1 = conv3x3(inplanes, planes, stride)
805
+ self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
806
+ self.relu = nn.ReLU(inplace=True)
807
+ self.conv2 = conv3x3(planes, planes)
808
+ self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
809
+ self.downsample = downsample
810
+ self.stride = stride
811
+
812
+ def forward(self, x):
813
+ residual = x
814
+
815
+ out = self.conv1(x)
816
+ out = self.bn1(out)
817
+ out = self.relu(out)
818
+
819
+ out = self.conv2(out)
820
+ out = self.bn2(out)
821
+
822
+ if self.downsample is not None:
823
+ residual = self.downsample(x)
824
+
825
+ out += residual
826
+ out = self.relu(out)
827
+
828
+ return out
829
+
830
+ class Bottleneck(nn.Module):
831
+ expansion = 4
832
+
833
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
834
+ super(Bottleneck, self).__init__()
835
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
836
+ self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
837
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
838
+ padding=1, bias=False)
839
+ self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
840
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
841
+ bias=False)
842
+ self.bn3 = BatchNorm2d(planes * self.expansion,
843
+ momentum=BN_MOMENTUM)
844
+ self.relu = nn.ReLU(inplace=True)
845
+ self.downsample = downsample
846
+ self.stride = stride
847
+
848
+ def forward(self, x):
849
+ residual = x
850
+
851
+ out = self.conv1(x)
852
+ out = self.bn1(out)
853
+ out = self.relu(out)
854
+
855
+ out = self.conv2(out)
856
+ out = self.bn2(out)
857
+ out = self.relu(out)
858
+
859
+ out = self.conv3(out)
860
+ out = self.bn3(out)
861
+
862
+ if self.downsample is not None:
863
+ residual = self.downsample(x)
864
+
865
+ out += residual
866
+ out = self.relu(out)
867
+
868
+ return out
869
+
870
+ BatchNorm2d = nn.BatchNorm2d
871
+ BN_MOMENTUM = 0.1
872
+ blocks_dict = {
873
+ 'BASIC': BasicBlock,
874
+ 'BOTTLENECK': Bottleneck
875
+ }
876
+ class HighResolutionModule(nn.Module):
877
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
878
+ num_channels, fuse_method, multi_scale_output=True):
879
+ super(HighResolutionModule, self).__init__()
880
+ self._check_branches(
881
+ num_branches, blocks, num_blocks, num_inchannels, num_channels)
882
+
883
+ self.num_inchannels = num_inchannels
884
+ self.fuse_method = fuse_method
885
+ self.num_branches = num_branches
886
+
887
+ self.multi_scale_output = multi_scale_output
888
+
889
+ self.branches = self._make_branches(
890
+ num_branches, blocks, num_blocks, num_channels)
891
+ self.fuse_layers = self._make_fuse_layers()
892
+ self.relu = nn.ReLU(inplace=True)
893
+
894
+ def _check_branches(self, num_branches, blocks, num_blocks,
895
+ num_inchannels, num_channels):
896
+ if num_branches != len(num_blocks):
897
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
898
+ num_branches, len(num_blocks))
899
+ raise ValueError(error_msg)
900
+
901
+ if num_branches != len(num_channels):
902
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
903
+ num_branches, len(num_channels))
904
+ raise ValueError(error_msg)
905
+
906
+ if num_branches != len(num_inchannels):
907
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
908
+ num_branches, len(num_inchannels))
909
+ raise ValueError(error_msg)
910
+
911
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
912
+ stride=1):
913
+ downsample = None
914
+ if stride != 1 or \
915
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
916
+ downsample = nn.Sequential(
917
+ nn.Conv2d(self.num_inchannels[branch_index],
918
+ num_channels[branch_index] * block.expansion,
919
+ kernel_size=1, stride=stride, bias=False),
920
+ BatchNorm2d(num_channels[branch_index] * block.expansion,
921
+ momentum=BN_MOMENTUM),
922
+ )
923
+
924
+ layers = []
925
+ layers.append(block(self.num_inchannels[branch_index],
926
+ num_channels[branch_index], stride, downsample))
927
+ self.num_inchannels[branch_index] = \
928
+ num_channels[branch_index] * block.expansion
929
+ for i in range(1, num_blocks[branch_index]):
930
+ layers.append(block(self.num_inchannels[branch_index],
931
+ num_channels[branch_index]))
932
+
933
+ return nn.Sequential(*layers)
934
+
935
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
936
+ branches = []
937
+
938
+ for i in range(num_branches):
939
+ branches.append(
940
+ self._make_one_branch(i, block, num_blocks, num_channels))
941
+
942
+ return nn.ModuleList(branches)
943
+
944
+ def _make_fuse_layers(self):
945
+ if self.num_branches == 1:
946
+ return None
947
+
948
+ num_branches = self.num_branches
949
+ num_inchannels = self.num_inchannels
950
+ fuse_layers = []
951
+ for i in range(num_branches if self.multi_scale_output else 1):
952
+ fuse_layer = []
953
+ for j in range(num_branches):
954
+ if j > i:
955
+ fuse_layer.append(nn.Sequential(
956
+ nn.Conv2d(num_inchannels[j],
957
+ num_inchannels[i],
958
+ 1,
959
+ 1,
960
+ 0,
961
+ bias=False),
962
+ BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
963
+ # nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
964
+ elif j == i:
965
+ fuse_layer.append(None)
966
+ else:
967
+ conv3x3s = []
968
+ for k in range(i - j):
969
+ if k == i - j - 1:
970
+ num_outchannels_conv3x3 = num_inchannels[i]
971
+ conv3x3s.append(nn.Sequential(
972
+ nn.Conv2d(num_inchannels[j],
973
+ num_outchannels_conv3x3,
974
+ 3, 2, 1, bias=False),
975
+ BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM)))
976
+ else:
977
+ num_outchannels_conv3x3 = num_inchannels[j]
978
+ conv3x3s.append(nn.Sequential(
979
+ nn.Conv2d(num_inchannels[j],
980
+ num_outchannels_conv3x3,
981
+ 3, 2, 1, bias=False),
982
+ BatchNorm2d(num_outchannels_conv3x3,
983
+ momentum=BN_MOMENTUM),
984
+ nn.ReLU(inplace=True)))
985
+ fuse_layer.append(nn.Sequential(*conv3x3s))
986
+ fuse_layers.append(nn.ModuleList(fuse_layer))
987
+
988
+ return nn.ModuleList(fuse_layers)
989
+
990
+ def get_num_inchannels(self):
991
+ return self.num_inchannels
992
+
993
+ def forward(self, x):
994
+ if self.num_branches == 1:
995
+ return [self.branches[0](x[0])]
996
+
997
+ for i in range(self.num_branches):
998
+ x[i] = self.branches[i](x[i])
999
+
1000
+ x_fuse = []
1001
+ for i in range(len(self.fuse_layers)):
1002
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
1003
+ for j in range(1, self.num_branches):
1004
+ if i == j:
1005
+ y = y + x[j]
1006
+ elif j > i:
1007
+ y = y + F.interpolate(
1008
+ self.fuse_layers[i][j](x[j]),
1009
+ size=[x[i].shape[2], x[i].shape[3]],
1010
+ mode='bilinear')
1011
+ else:
1012
+ y = y + self.fuse_layers[i][j](x[j])
1013
+ x_fuse.append(self.relu(y))
1014
+
1015
+ return x_fuse
1016
+
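In `_make_fuse_layers`, fusing a higher-resolution branch `j` into a lower-resolution branch `i` (j < i) chains `i - j` stride-2 convolutions, each halving the spatial size. A tiny sketch of that resolution arithmetic (`downsample_chain` is an illustrative helper, not part of the model):

```python
def downsample_chain(res, i, j):
    """For j < i, fusing branch j into branch i chains (i - j) stride-2
    convolutions, each halving the spatial resolution."""
    assert j < i
    for _ in range(i - j):
        res = (res[0] // 2, res[1] // 2)
    return res

# branch 0 at 128x128 fused into branch 2 -> two stride-2 convs
print(downsample_chain((128, 128), i=2, j=0))  # (32, 32)
```

The converse case (j > i) uses a 1x1 conv plus bilinear upsampling in `forward`, so only the downsampling direction needs a conv chain.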
1017
+ class HighResolutionNet(nn.Module):
1018
+
1019
+ def __init__(self, config, lines=False, **kwargs):
1020
+ self.inplanes = 64
1021
+ self.lines = lines
1022
+ extra = config['MODEL']['EXTRA']
1023
+ super(HighResolutionNet, self).__init__()
1024
+
1025
+ # stem net
1026
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=2, padding=1,
1027
+ bias=False)
1028
+ self.bn1 = BatchNorm2d(self.inplanes, momentum=BN_MOMENTUM)
1029
+ self.conv2 = nn.Conv2d(self.inplanes, self.inplanes, kernel_size=3, stride=2, padding=1,
1030
+ bias=False)
1031
+ self.bn2 = BatchNorm2d(self.inplanes, momentum=BN_MOMENTUM)
1032
+ self.relu = nn.ReLU(inplace=True)
1033
+ self.sf = nn.Softmax(dim=1)
1034
+ self.layer1 = self._make_layer(Bottleneck, 64, 64, 4)
1035
+
1036
+ self.stage2_cfg = extra['STAGE2']
1037
+ num_channels = self.stage2_cfg['NUM_CHANNELS']
1038
+ block = blocks_dict[self.stage2_cfg['BLOCK']]
1039
+ num_channels = [
1040
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
1041
+ self.transition1 = self._make_transition_layer(
1042
+ [256], num_channels)
1043
+ self.stage2, pre_stage_channels = self._make_stage(
1044
+ self.stage2_cfg, num_channels)
1045
+
1046
+ self.stage3_cfg = extra['STAGE3']
1047
+ num_channels = self.stage3_cfg['NUM_CHANNELS']
1048
+ block = blocks_dict[self.stage3_cfg['BLOCK']]
1049
+ num_channels = [
1050
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
1051
+ self.transition2 = self._make_transition_layer(
1052
+ pre_stage_channels, num_channels)
1053
+ self.stage3, pre_stage_channels = self._make_stage(
1054
+ self.stage3_cfg, num_channels)
1055
+
1056
+ self.stage4_cfg = extra['STAGE4']
1057
+ num_channels = self.stage4_cfg['NUM_CHANNELS']
1058
+ block = blocks_dict[self.stage4_cfg['BLOCK']]
1059
+ num_channels = [
1060
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
1061
+ self.transition3 = self._make_transition_layer(
1062
+ pre_stage_channels, num_channels)
1063
+ self.stage4, pre_stage_channels = self._make_stage(
1064
+ self.stage4_cfg, num_channels, multi_scale_output=True)
1065
+
1066
+ self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
1067
+ final_inp_channels = sum(pre_stage_channels) + self.inplanes
1068
+
1069
+ self.head = nn.Sequential(nn.Sequential(
1070
+ nn.Conv2d(
1071
+ in_channels=final_inp_channels,
1072
+ out_channels=final_inp_channels,
1073
+ kernel_size=1),
1074
+ BatchNorm2d(final_inp_channels, momentum=BN_MOMENTUM),
1075
+ nn.ReLU(inplace=True),
1076
+ nn.Conv2d(
1077
+ in_channels=final_inp_channels,
1078
+ out_channels=config['MODEL']['NUM_JOINTS'],
1079
+ kernel_size=extra['FINAL_CONV_KERNEL']),
1080
+ nn.Softmax(dim=1) if not self.lines else nn.Sigmoid()))
1081
+
1082
+
1083
+
1084
+ def _make_head(self, x, x_skip):
1085
+ x = self.upsample(x)
1086
+ x = torch.cat([x, x_skip], dim=1)
1087
+ x = self.head(x)
1088
+
1089
+ return x
1090
+
1091
+ def _make_transition_layer(
1092
+ self, num_channels_pre_layer, num_channels_cur_layer):
1093
+ num_branches_cur = len(num_channels_cur_layer)
1094
+ num_branches_pre = len(num_channels_pre_layer)
1095
+
1096
+ transition_layers = []
1097
+ for i in range(num_branches_cur):
1098
+ if i < num_branches_pre:
1099
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
1100
+ transition_layers.append(nn.Sequential(
1101
+ nn.Conv2d(num_channels_pre_layer[i],
1102
+ num_channels_cur_layer[i],
1103
+ 3,
1104
+ 1,
1105
+ 1,
1106
+ bias=False),
1107
+ BatchNorm2d(
1108
+ num_channels_cur_layer[i], momentum=BN_MOMENTUM),
1109
+ nn.ReLU(inplace=True)))
1110
+ else:
1111
+ transition_layers.append(None)
1112
+ else:
1113
+ conv3x3s = []
1114
+ for j in range(i + 1 - num_branches_pre):
1115
+ inchannels = num_channels_pre_layer[-1]
1116
+ outchannels = num_channels_cur_layer[i] \
1117
+ if j == i - num_branches_pre else inchannels
1118
+ conv3x3s.append(nn.Sequential(
1119
+ nn.Conv2d(
1120
+ inchannels, outchannels, 3, 2, 1, bias=False),
1121
+ BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
1122
+ nn.ReLU(inplace=True)))
1123
+ transition_layers.append(nn.Sequential(*conv3x3s))
1124
+
1125
+ return nn.ModuleList(transition_layers)
1126
+
1127
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
1128
+ downsample = None
1129
+ if stride != 1 or inplanes != planes * block.expansion:
1130
+ downsample = nn.Sequential(
1131
+ nn.Conv2d(inplanes, planes * block.expansion,
1132
+ kernel_size=1, stride=stride, bias=False),
1133
+ BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
1134
+ )
1135
+
1136
+ layers = []
1137
+ layers.append(block(inplanes, planes, stride, downsample))
1138
+ inplanes = planes * block.expansion
1139
+ for i in range(1, blocks):
1140
+ layers.append(block(inplanes, planes))
1141
+
1142
+ return nn.Sequential(*layers)
1143
+
1144
+ def _make_stage(self, layer_config, num_inchannels,
1145
+ multi_scale_output=True):
1146
+ num_modules = layer_config['NUM_MODULES']
1147
+ num_branches = layer_config['NUM_BRANCHES']
1148
+ num_blocks = layer_config['NUM_BLOCKS']
1149
+ num_channels = layer_config['NUM_CHANNELS']
1150
+ block = blocks_dict[layer_config['BLOCK']]
1151
+ fuse_method = layer_config['FUSE_METHOD']
1152
+
1153
+ modules = []
1154
+ for i in range(num_modules):
1155
+ # multi_scale_output only takes effect in the last module
1156
+ if not multi_scale_output and i == num_modules - 1:
1157
+ reset_multi_scale_output = False
1158
+ else:
1159
+ reset_multi_scale_output = True
1160
+ modules.append(
1161
+ HighResolutionModule(num_branches,
1162
+ block,
1163
+ num_blocks,
1164
+ num_inchannels,
1165
+ num_channels,
1166
+ fuse_method,
1167
+ reset_multi_scale_output)
1168
+ )
1169
+ num_inchannels = modules[-1].get_num_inchannels()
1170
+
1171
+ return nn.Sequential(*modules), num_inchannels
1172
+
1173
+ def forward(self, x):
1174
+ # h, w = x.size(2), x.size(3)
1175
+ x = self.conv1(x)
1176
+ x_skip = x.clone()
1177
+ x = self.bn1(x)
1178
+ x = self.relu(x)
1179
+ x = self.conv2(x)
1180
+ x = self.bn2(x)
1181
+ x = self.relu(x)
1182
+ x = self.layer1(x)
1183
+
1184
+ x_list = []
1185
+ for i in range(self.stage2_cfg['NUM_BRANCHES']):
1186
+ if self.transition1[i] is not None:
1187
+ x_list.append(self.transition1[i](x))
1188
+ else:
1189
+ x_list.append(x)
1190
+ y_list = self.stage2(x_list)
1191
+
1192
+ x_list = []
1193
+ for i in range(self.stage3_cfg['NUM_BRANCHES']):
1194
+ if self.transition2[i] is not None:
1195
+ x_list.append(self.transition2[i](y_list[-1]))
1196
+ else:
1197
+ x_list.append(y_list[i])
1198
+ y_list = self.stage3(x_list)
1199
+
1200
+ x_list = []
1201
+ for i in range(self.stage4_cfg['NUM_BRANCHES']):
1202
+ if self.transition3[i] is not None:
1203
+ x_list.append(self.transition3[i](y_list[-1]))
1204
+ else:
1205
+ x_list.append(y_list[i])
1206
+ x = self.stage4(x_list)
1207
+
1208
+ # Head Part
1209
+ height, width = x[0].size(2), x[0].size(3)
1210
+ x1 = F.interpolate(x[1], size=(height, width), mode='bilinear', align_corners=False)
1211
+ x2 = F.interpolate(x[2], size=(height, width), mode='bilinear', align_corners=False)
1212
+ x3 = F.interpolate(x[3], size=(height, width), mode='bilinear', align_corners=False)
1213
+ x = torch.cat([x[0], x1, x2, x3], 1)
1214
+ x = self._make_head(x, x_skip)
1215
+
1216
+ return x
1217
+
1218
+ def init_weights(self, pretrained=''):
1219
+ for m in self.modules():
1220
+ if isinstance(m, nn.Conv2d):
1221
+ if not self.lines:
1222
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
1223
+ else:
1224
+ nn.init.normal_(m.weight, std=0.001)
1225
+ #nn.init.normal_(m.weight, std=0.001)
1226
+ #nn.init.constant_(m.bias, 0)
1227
+ elif isinstance(m, nn.BatchNorm2d):
1228
+ nn.init.constant_(m.weight, 1)
1229
+ nn.init.constant_(m.bias, 0)
1230
+ if pretrained != '':
1231
+ if os.path.isfile(pretrained):
1232
+ pretrained_dict = torch.load(pretrained)
1233
+ model_dict = self.state_dict()
1234
+ pretrained_dict = {k: v for k, v in pretrained_dict.items()
1235
+ if k in model_dict.keys()}
1236
+ model_dict.update(pretrained_dict)
1237
+ self.load_state_dict(model_dict)
1238
+ else:
1239
+ sys.exit(f'Weights {pretrained} not found.')
1240
+
1241
+ model = HighResolutionNet(config, **kwargs)
1242
+ model.init_weights(pretrained)
1243
+ return model
1244
+ # Keypoint Inference
1245
+ def load_kp_model(path, device):
1246
+ config_kp_path = path / 'hrnetv2_w48.yaml'
1247
+ cfg_kp = yaml.safe_load(open(config_kp_path, 'r'))
1248
+
1249
+ loaded_state_kp = torch.load(path / "keypoint_detect.pt", map_location=device, weights_only=False)
1250
+ model = get_cls_net(cfg_kp)
1251
+ model.load_state_dict(loaded_state_kp)
1252
+ model.to(device)
1253
+ model.eval()
1254
+ return model
1255
+
1256
+ def preprocess_batch_fast(frames):
1257
+ """Ultra-fast batch preprocessing using optimized tensor operations"""
1258
+ target_size = (540, 960) # H, W format for model input
1259
+ batch = []
1260
+ for i, frame in enumerate(frames):
1261
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
1262
+ img = cv2.resize(frame_rgb, (target_size[1], target_size[0]))
1263
+ img = img.astype(np.float32) / 255.0
1264
+ img = np.transpose(img, (2, 0, 1)) # HWC -> CHW
1265
+ batch.append(img)
1266
+ batch = torch.from_numpy(np.stack(batch)).float()
1267
+
1268
+ return batch
1269
+
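`preprocess_batch_fast` scales each frame to [0, 1], transposes HWC to CHW, and stacks the batch. The shape bookkeeping in numpy alone (resize and BGR->RGB left out; `to_batch` is an illustrative name):

```python
import numpy as np

def to_batch(frames):
    """Stack HxWx3 frames into an NxCxHxW float batch, mirroring the
    scale + transpose + stack steps of the preprocessing above."""
    batch = [np.transpose(f.astype(np.float32) / 255.0, (2, 0, 1)) for f in frames]
    return np.stack(batch)

frames = [np.full((540, 960, 3), 255, dtype=np.uint8) for _ in range(2)]
batch = to_batch(frames)
print(batch.shape)  # (2, 3, 540, 960)
```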
1270
+ def extract_keypoints_from_heatmap_fast(heatmap: torch.Tensor, scale: int = 2, max_keypoints: int = 1):
1271
+ """Ultra-fast keypoint extraction optimized for speed"""
1272
+ batch_size, n_channels, height, width = heatmap.shape
1273
+
1274
+ # Simplified local maxima detection (faster but slightly less accurate)
1275
+ max_pooled = F.max_pool2d(heatmap, 3, stride=1, padding=1)
1276
+ local_maxima = (max_pooled == heatmap)
1277
+
1278
+ # Apply mask and get top keypoints in one go
1279
+ masked_heatmap = heatmap * local_maxima
1280
+ flat_heatmap = masked_heatmap.view(batch_size, n_channels, -1)
1281
+ scores, indices = torch.topk(flat_heatmap, max_keypoints, dim=-1, sorted=False)
1282
+
1283
+ # Vectorized coordinate calculation
1284
+ y_coords = torch.div(indices, width, rounding_mode="floor") * scale
1285
+ x_coords = (indices % width) * scale
1286
+
1287
+ # Stack results efficiently
1288
+ results = torch.stack([x_coords.float(), y_coords.float(), scores], dim=-1)
1289
+ return results
1290
+
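`extract_keypoints_from_heatmap_fast` recovers (x, y) from the flattened top-k indices with a floor-divide and a modulo, then rescales to image coordinates. The conversion in isolation:

```python
import numpy as np

def flat_to_xy(indices, width, scale=2):
    """Convert flattened heatmap argmax indices back to (x, y) pixel
    coordinates, applying the heatmap-to-image scale factor."""
    y = (indices // width) * scale
    x = (indices % width) * scale
    return x, y

# a 4x5 heatmap with its peak at flat index 13 -> row 2, col 3
x, y = flat_to_xy(np.array([13]), width=5, scale=2)
print(x, y)  # [6] [4]
```

The `scale=2` matches the heatmap being at half the input resolution, so heatmap coordinates are doubled back into input-pixel space.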
1291
+ def process_keypoints_vectorized(kp_coords, kp_threshold, w, h, batch_size):
1292
+ """Ultra-fast vectorized keypoint processing"""
1293
+ batch_results = []
1294
+
1295
+ # Convert to numpy once for faster CPU operations
1296
+ kp_np = kp_coords.cpu().numpy()
1297
+
1298
+ for batch_idx in range(batch_size):
1299
+ kp_dict = {}
1300
+ # Vectorized threshold check
1301
+ valid_kps = kp_np[batch_idx, :, 0, 2] > kp_threshold
1302
+ valid_indices = np.where(valid_kps)[0]
1303
+
1304
+ for ch_idx in valid_indices:
1305
+ x = float(kp_np[batch_idx, ch_idx, 0, 0]) / w
1306
+ y = float(kp_np[batch_idx, ch_idx, 0, 1]) / h
1307
+ p = float(kp_np[batch_idx, ch_idx, 0, 2])
1308
+ kp_dict[ch_idx + 1] = {'x': x, 'y': y, 'p': p}
1309
+
1310
+ batch_results.append(kp_dict)
1311
+
1312
+ return batch_results
1313
+
1314
+ def inference_batch(frames, model, kp_threshold, device, batch_size=8):
1315
+ """Optimized batch inference for multiple frames"""
1316
+ results = []
1317
+ num_frames = len(frames)
1318
+
1319
+ # Get the device from the model itself
1320
+ model_device = next(model.parameters()).device
1321
+
1322
+ # Process all frames in optimally-sized batches
1323
+ for i in range(0, num_frames, batch_size):
1324
+ current_batch_size = min(batch_size, num_frames - i)
1325
+ batch_frames = frames[i:i + current_batch_size]
1326
+
1327
+ # Fast preprocessing - create on CPU first
1328
+ batch = preprocess_batch_fast(batch_frames)
1329
+ b, c, h, w = batch.size()
1330
+
1331
+ # Move batch to model device
1332
+ batch = batch.to(model_device)
1333
+
1334
+ with torch.inference_mode():
1335
+ heatmaps = model(batch)
1336
+
1337
+ # Ultra-fast keypoint extraction
1338
+ kp_coords = extract_keypoints_from_heatmap_fast(heatmaps[:,:-1,:,:], scale=2, max_keypoints=1)
1339
+
1340
+ # Collect thresholded keypoints for this batch
1341
+ batch_results = process_keypoints_vectorized(kp_coords, kp_threshold, 960, 540, current_batch_size)
1342
+ results.extend(batch_results)
1343
+
1344
+ del heatmaps, kp_coords, batch, batch_results, batch_frames
1345
+
1346
+ return results
1347
+
1348
+ map_keypoints = {
1349
+ 1: 1, 2: 14, 3: 25, 4: 2, 5: 10, 6: 18, 7: 26, 8: 3, 9: 7, 10: 23,
1350
+ 11: 27, 20: 4, 21: 8, 22: 24, 23: 28, 24: 5, 25: 13, 26: 21, 27: 29,
1351
+ 28: 6, 29: 17, 30: 30, 31: 11, 32: 15, 33: 19, 34: 12, 35: 16, 36: 20,
1352
+ 45: 9, 50: 31, 52: 32, 57: 22
1353
+ }
1354
+ def get_mapped_keypoints(kp_points):
1355
+ """Apply keypoint mapping to detection results"""
1356
+ mapped_points = {}
1357
+ for key, value in kp_points.items():
1358
+ if key in map_keypoints:
1359
+ mapped_key = map_keypoints[key]
1360
+ mapped_points[mapped_key] = value
1361
+ # else:
1362
+ # Keep unmapped keypoints with original key
1363
+ # mapped_points[key] = value
1364
+ return mapped_points
1365
+
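`get_mapped_keypoints` renumbers detector channels into template ids via `map_keypoints`, silently dropping unmapped channels. The same dictionary remap as a one-liner (the sample mapping and values below are made up):

```python
def remap(points, mapping):
    """Renumber detected keypoints through a channel->template id table;
    unmapped channels are dropped, as in get_mapped_keypoints."""
    return {mapping[k]: v for k, v in points.items() if k in mapping}

mapping = {1: 1, 2: 14, 45: 9}           # subset of map_keypoints
detections = {2: 'a', 45: 'b', 99: 'c'}  # channel 99 has no template id
print(remap(detections, mapping))  # {14: 'a', 9: 'b'}
```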
1366
+ def process_batch_input(frames, model, kp_threshold, device='cpu', batch_size=16):
1367
+ """Process multiple input images in batch"""
1368
+ # Batch inference
1369
+ kp_results = inference_batch(frames, model, kp_threshold, device, batch_size)
1370
+ kp_results = [get_mapped_keypoints(kp) for kp in kp_results]
1371
+
1372
+ return kp_results
1373
+
1374
+
1375
+ def convert_keypoints_to_val_format(keypoints):
1376
+ return [tuple(int(x) for x in pair) for pair in keypoints]
1377
+
1378
+ def normalize_keypoints(keypoints_result, batch_images, n_keypoints):
1379
+ keypoints = []
1380
+ if keypoints_result is not None and len(keypoints_result) > 0:
1381
+ for frame_number_in_batch, kp_dict in enumerate(keypoints_result):
1382
+ if frame_number_in_batch >= len(batch_images):
1383
+ break
1384
+ frame_keypoints: List[Tuple[int, int]] = []
1385
+ try:
1386
+ height, width = batch_images[frame_number_in_batch].shape[:2]
1387
+ if kp_dict is not None and isinstance(kp_dict, dict):
1388
+ for idx in range(32):
1389
+ x, y, p = 0, 0, 0
1390
+ kp_idx = idx + 1
1391
+ if kp_idx in kp_dict:
1392
+ try:
1393
+ kp_data = kp_dict[kp_idx]
1394
+ if isinstance(kp_data, dict) and "x" in kp_data and "y" in kp_data:
1395
+ x = int(kp_data["x"] * width)
1396
+ y = int(kp_data["y"] * height)
1397
+ except Exception:
1398
+ pass
1399
+ frame_keypoints.append((x, y))
1400
+ except (IndexError, ValueError, AttributeError):
1401
+ frame_keypoints = [(0, 0)] * 32
1402
+ if len(frame_keypoints) < n_keypoints:
1403
+ frame_keypoints.extend([(0, 0)] * (n_keypoints - len(frame_keypoints)))
1404
+ else:
1405
+ frame_keypoints = frame_keypoints[:n_keypoints]
1406
+ keypoints.append(frame_keypoints)
1407
+ return keypoints
1408
+
1409
+ def fix_keypoints(frame_keypoints: list[tuple[int, int]], n_keypoints: int) -> list[tuple[int, int]]:
1410
+ # Pad or trim to exact n_keypoints
1411
+ if len(frame_keypoints) < n_keypoints:
1412
+ frame_keypoints += [(0, 0)] * (n_keypoints - len(frame_keypoints))
1413
+ elif len(frame_keypoints) > n_keypoints:
1414
+ frame_keypoints = frame_keypoints[:n_keypoints]
1415
+
1416
+ if frame_keypoints[2] != (0, 0) and frame_keypoints[4] != (0, 0) and frame_keypoints[3] == (0, 0):
1417
+ frame_keypoints[3] = frame_keypoints[4]
1418
+ frame_keypoints[4] = (0, 0)
1419
+
1420
+ if frame_keypoints[0] != (0, 0) and frame_keypoints[4] != (0, 0) and frame_keypoints[1] == (0, 0):
1421
+ frame_keypoints[1] = frame_keypoints[4]
1422
+ frame_keypoints[4] = (0, 0)
1423
+
1424
+ if frame_keypoints[2] != (0, 0) and frame_keypoints[3] != (0, 0) and frame_keypoints[1] == (0, 0) and frame_keypoints[3][0] > frame_keypoints[2][0]:
1425
+ frame_keypoints[1] = frame_keypoints[3]
1426
+ frame_keypoints[3] = (0, 0)
1427
+
1428
+ if frame_keypoints[28] != (0, 0) and frame_keypoints[25] == (0, 0) and frame_keypoints[26] != (0, 0) and frame_keypoints[26][0] > frame_keypoints[28][0]:
1429
+ frame_keypoints[25] = frame_keypoints[28]
1430
+ frame_keypoints[28] = (0, 0)
1431
+
1432
+ if frame_keypoints[24] != (0, 0) and frame_keypoints[28] != (0, 0) and frame_keypoints[25] == (0, 0):
1433
+ frame_keypoints[25] = frame_keypoints[28]
1434
+ frame_keypoints[28] = (0, 0)
1435
+
1436
+ if frame_keypoints[24] != (0, 0) and frame_keypoints[27] != (0, 0) and frame_keypoints[26] == (0, 0):
1437
+ frame_keypoints[26] = frame_keypoints[27]
1438
+ frame_keypoints[27] = (0, 0)
1439
+
1440
+ if frame_keypoints[28] != (0, 0) and frame_keypoints[23] == (0, 0) and frame_keypoints[20] != (0, 0) and frame_keypoints[20][1] > frame_keypoints[23][1]:
1441
+ frame_keypoints[23] = frame_keypoints[20]
1442
+ frame_keypoints[20] = (0, 0)
1443
+
1448
+
1449
+ return frame_keypoints
1450
+
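`fix_keypoints` first forces the list to exactly `n_keypoints` entries before applying its swap heuristics. That pad-or-trim step on its own (`pad_or_trim` is an illustrative name):

```python
def pad_or_trim(points, n):
    """Force a keypoint list to exactly n entries: pad with (0, 0)
    placeholders or trim the tail, as fix_keypoints does first."""
    if len(points) < n:
        points = points + [(0, 0)] * (n - len(points))
    return points[:n]

print(pad_or_trim([(1, 2)], 3))      # [(1, 2), (0, 0), (0, 0)]
print(pad_or_trim([(1, 2)] * 5, 3))  # [(1, 2), (1, 2), (1, 2)]
```

Using `(0, 0)` as the "missing" sentinel is what makes the later `!= (0, 0)` checks in the heuristics meaningful.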
1451
+ def challenge_template(path_hf_repo) -> ndarray:
1452
+ return imread(f"{path_hf_repo}/football_pitch_template.png")
1453
+
1454
+ current_path = str(os.path.dirname(os.path.abspath(__file__)))
1455
+ template_image = challenge_template(current_path)
1456
+ template_image_gray = cvtColor(template_image, COLOR_BGR2GRAY)
1457
+ _sparse_template_cache: dict[tuple[int, int], list[tuple[int, int]]] = {}
1458
+ _shared_eval_executor: ThreadPoolExecutor | None = None
1459
+
1460
+ class MaxSizeCache(OrderedDict):
1461
+ """
1462
+ Fixed-size dictionary behaving like a deque(maxlen=N).
1463
+ Stores key–value pairs; re-inserting a key refreshes its position, and the oldest entry is evicted when full.
1464
+ """
1465
+
1466
+ def __init__(self, maxlen=500):
1467
+ super().__init__()
1468
+ self.maxlen = maxlen
1469
+ self._lock = threading.Lock()
1470
+
1471
+ def set(self, key, value):
1472
+ """Insert or update an item. Evicts oldest if full."""
1473
+ with self._lock:
1474
+ if key in self:
1475
+ del self[key] # refresh position
1476
+ super().__setitem__(key, value)
1477
+
1478
+ if len(self) > self.maxlen:
1479
+ self.popitem(last=False) # remove oldest
1480
+
1481
+ def get(self, key, default=None):
1482
+ """Retrieve an item without changing order."""
1483
+ with self._lock:
1484
+ return super().get(key, default)
1485
+
1486
+ def exists(self, key):
1487
+ """Check if a key exists."""
1488
+ with self._lock:
1489
+ return key in self
1490
+
1491
+ def load(self, data_dict):
1492
+ """
1493
+ Load initial data into cache.
1494
+ Oldest items evicted if data exceeds maxlen.
1495
+ """
1496
+ for k, v in data_dict.items():
1497
+ self.set(k, v)
1498
+
1499
+ def __repr__(self):
1500
+ return f"MaxSizeCache(maxlen={self.maxlen}, data={dict(self)})"
1501
+ cached = MaxSizeCache()
1502
+ _per_key_locks = defaultdict(threading.Lock)
1503
+
1504
+ def get_or_compute_masks(key, compute_fn):
1505
+ lock = _per_key_locks[key]
1506
+ with lock:
1507
+ if cached.exists(key):
1508
+ return cached.get(key)
1509
+ # compute once
1510
+ masks = compute_fn()
1511
+ cached.set(key, masks)
1512
+ return masks
1513
+
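`MaxSizeCache` is an OrderedDict with a bounded size and oldest-first eviction. A stripped-down sketch of the same eviction behaviour, without the locking (`TinyCache` is hypothetical):

```python
from collections import OrderedDict

class TinyCache(OrderedDict):
    """Minimal fixed-size cache with oldest-first eviction, mirroring
    the behaviour of MaxSizeCache (thread locking omitted)."""
    def __init__(self, maxlen):
        super().__init__()
        self.maxlen = maxlen

    def set(self, key, value):
        if key in self:
            del self[key]             # refresh insertion order
        self[key] = value
        if len(self) > self.maxlen:
            self.popitem(last=False)  # evict oldest entry

c = TinyCache(maxlen=2)
c.set('a', 1); c.set('b', 2); c.set('c', 3)
print(list(c))  # ['b', 'c']
```

Deleting and re-inserting on update means a frequently rewritten key is never the eviction victim, which matters when the cache holds per-frame masks that are recomputed in place.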
1514
+ INDEX_KEYPOINT_CORNER_BOTTOM_LEFT = 5
1515
+ INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT = 29
1516
+ INDEX_KEYPOINT_CORNER_TOP_LEFT = 0
1517
+ INDEX_KEYPOINT_CORNER_TOP_RIGHT = 24
1518
+
1519
+ KEYPOINTS: list[tuple[int, int]] = [
1520
+ (5, 5), # 1
1521
+ (5, 140), # 2
1522
+ (5, 250), # 3
1523
+ (5, 430), # 4
1524
+ (5, 540), # 5
1525
+ (5, 675), # 6
1526
+ # -------------
1527
+ (55, 250), # 7
1528
+ (55, 430), # 8
1529
+ # -------------
1530
+ (110, 340), # 9
1531
+ # -------------
1532
+ (165, 140), # 10
1533
+ (165, 270), # 11
1534
+ (165, 410), # 12
1535
+ (165, 540), # 13
1536
+ # -------------
1537
+ (527, 5), # 14
1538
+ (527, 253), # 15
1539
+ (527, 433), # 16
1540
+ (527, 675), # 17
1541
+ # -------------
1542
+ (888, 140), # 18
1543
+ (888, 270), # 19
1544
+ (888, 410), # 20
1545
+ (888, 540), # 21
1546
+ # -------------
1547
+ (940, 340), # 22
1548
+ # -------------
1549
+ (998, 250), # 23
1550
+ (998, 430), # 24
1551
+ # -------------
1552
+ (1045, 5), # 25
1553
+ (1045, 140), # 26
1554
+ (1045, 250), # 27
1555
+ (1045, 430), # 28
1556
+ (1045, 540), # 29
1557
+ (1045, 675), # 30
1558
+ # -------------
1559
+ (435, 340), # 31
1560
+ (615, 340), # 32
1561
+ ]
1562
+
1563
+ KEYPOINTS_NP = np.asarray(KEYPOINTS, dtype=np.float32)
1564
+
1565
+ FOOTBALL_KEYPOINTS: list[tuple[int, int]] = [
1566
+ (0, 0), # 1
1567
+ (0, 0), # 2
1568
+ (0, 0), # 3
1569
+ (0, 0), # 4
1570
+ (0, 0), # 5
1571
+ (0, 0), # 6
1572
+
1573
+ (0, 0), # 7
1574
+ (0, 0), # 8
1575
+ (0, 0), # 9
1576
+
1577
+ (0, 0), # 10
1578
+ (0, 0), # 11
1579
+ (0, 0), # 12
1580
+ (0, 0), # 13
1581
+
1582
+ (0, 0), # 14
1583
+ (527, 283), # 15
1584
+ (527, 403), # 16
1585
+ (0, 0), # 17
1586
+
1587
+ (0, 0), # 18
1588
+ (0, 0), # 19
1589
+ (0, 0), # 20
1590
+ (0, 0), # 21
1591
+
1592
+ (0, 0), # 22
1593
+
1594
+ (0, 0), # 23
1595
+ (0, 0), # 24
1596
+
1597
+ (0, 0), # 25
1598
+ (0, 0), # 26
1599
+ (0, 0), # 27
1600
+ (0, 0), # 28
1601
+ (0, 0), # 29
1602
+ (0, 0), # 30
1603
+
1604
+ (405, 340), # 31
1605
+ (645, 340), # 32
1606
+ ]
1607
+
1608
+ FOOTBALL_KEYPOINTS_NP = np.asarray(FOOTBALL_KEYPOINTS, dtype=np.float32)
1609
+
1610
+ groups = {
1611
+ 1: [2, 3, 7, 10],
1612
+ 2: [1, 3, 7, 10],
1613
+ 3: [2, 4, 7, 8],
1614
+ 4: [3, 5, 8, 7],
1615
+ 5: [4, 8, 6, 3],
1616
+ 6: [5, 4, 8, 13],
1617
+ 7: [3, 8, 9, 10],
1618
+ 8: [4, 7, 9, 13],
1619
+ 9: [7, 8, 11, 12],
1620
+ 10: [9, 11, 7, 2],
1621
+ 11: [9, 10, 12, 31],
1622
+ 12: [9, 11, 13, 31],
1623
+ 13: [9, 12, 8, 5],
1624
+ 14: [15, 31, 32, 16],
1625
+ 15: [31, 16, 32, 14],
1626
+ 16: [31, 15, 32, 17],
1627
+ 17: [31, 16, 32, 15],
1628
+ 18: [19, 22, 23, 26],
1629
+ 19: [18, 22, 20, 32],
1630
+ 20: [19, 22, 21, 32],
1631
+ 21: [20, 22, 24, 29],
1632
+ 22: [23, 24, 19, 20],
1633
+ 23: [27, 24, 22, 28],
1634
+ 24: [28, 23, 22, 27],
1635
+ 25: [26, 27, 23, 18],
1636
+ 26: [25, 27, 23, 18],
1637
+ 27: [26, 23, 28, 24],
1638
+ 28: [27, 24, 29, 23],
1639
+ 29: [28, 30, 24, 21],
1640
+ 30: [29, 28, 24, 21],
1641
+ 31: [15, 16, 32, 14],
1642
+ 32: [15, 31, 16, 14]
1643
+ }
1644
+
1645
+ base_temps = [(0, 0)] * 32
1646
+
1647
+ _TEMPLATE_MAX_X: int = 1045
1648
+ _TEMPLATE_MAX_Y: int = 675
1649
+
1650
+ # Precomputed group arrays for faster neighbor lookup (0-based).
1651
+ GROUPS_ARRAY = [np.asarray(groups[i], dtype=np.int32) - 1 for i in range(1, 33)]
1652
+
1653
+ kernel = getStructuringElement(MORPH_RECT, (31, 31))
1654
+ dilate_kernel = getStructuringElement(
1655
+ MORPH_RECT, (3, 3)
1656
+ )
1657
+
1658
+ class InvalidMask(Exception):
1659
+ pass
1660
+
1661
+ def has_a_wide_line(mask: ndarray, max_aspect_ratio: float = 1.0) -> bool:
1662
+ contours, _ = findContours(mask, RETR_EXTERNAL, CHAIN_APPROX_SIMPLE)
1663
+ for cnt in contours:
1664
+ x, y, w, h = boundingRect(cnt)
1665
+ # Skip degenerate contours with zero width or height
1666
+ if w == 0 or h == 0:
1667
+ continue
1668
+ aspect_ratio = min(w, h) / max(w, h)
1669
+ if aspect_ratio >= max_aspect_ratio:
1670
+ return True
1671
+ return False
1672
+
1673
+ def is_bowtie(points: ndarray) -> bool:
1674
+ def segments_intersect(p1: ndarray, p2: ndarray, q1: ndarray, q2: ndarray) -> bool:
1675
+ def ccw(a: ndarray, b: ndarray, c: ndarray) -> bool:
1676
+ return (c[1] - a[1]) * (b[0] - a[0]) > (b[1] - a[1]) * (c[0] - a[0])
1677
+
1678
+ return (ccw(p1, q1, q2) != ccw(p2, q1, q2)) and (
1679
+ ccw(p1, p2, q1) != ccw(p1, p2, q2)
1680
+ )
1681
+
1682
+ pts = points.reshape(-1, 2)
1683
+ edges = [(pts[0], pts[1]), (pts[1], pts[2]), (pts[2], pts[3]), (pts[3], pts[0])]
1684
+ return segments_intersect(*edges[0], *edges[2]) or segments_intersect(
1685
+ *edges[1], *edges[3]
1686
+ )
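The CCW-based segment test above flags a quadrilateral whose opposite edges cross (a "bowtie"), which signals a twisted projection. Restated standalone for illustration:

```python
import numpy as np

def is_bowtie(points: np.ndarray) -> bool:
    def segments_intersect(p1, p2, q1, q2) -> bool:
        def ccw(a, b, c):
            return (c[1] - a[1]) * (b[0] - a[0]) > (b[1] - a[1]) * (c[0] - a[0])
        return (ccw(p1, q1, q2) != ccw(p2, q1, q2)) and (
            ccw(p1, p2, q1) != ccw(p1, p2, q2)
        )

    pts = points.reshape(-1, 2)
    edges = [(pts[0], pts[1]), (pts[1], pts[2]), (pts[2], pts[3]), (pts[3], pts[0])]
    # A proper quad has non-crossing opposite edges; a bowtie crosses them.
    return segments_intersect(*edges[0], *edges[2]) or segments_intersect(
        *edges[1], *edges[3]
    )

square = np.array([(0, 0), (1, 0), (1, 1), (0, 1)], dtype=np.float32)  # proper quad
bowtie = np.array([(0, 0), (1, 1), (1, 0), (0, 1)], dtype=np.float32)  # self-intersecting
```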
1687
+
1688
+ def validate_mask_lines(mask: ndarray) -> None:
1689
+ # Use fast count instead of sum when possible
1690
+ nonzero_count = countNonZero(mask)
1691
+ if nonzero_count == 0:
1692
+ raise InvalidMask("No projected lines")
1693
+ if nonzero_count == mask.size:
1694
+ raise InvalidMask("Projected lines cover the entire image surface")
1695
+ # Contour-based check that no projected line is unrealistically wide
1696
+ if has_a_wide_line(mask=mask):
1697
+ raise InvalidMask("A projected line is too wide")
1698
+
1699
+ def validate_mask_ground(mask: ndarray) -> None:
1700
+ num_labels, _ = connectedComponents(mask)
1701
+ num_distinct_regions = num_labels - 1
1702
+ if num_distinct_regions > 1:
1703
+ raise InvalidMask(
1704
+ f"Projected ground should be a single object, detected {num_distinct_regions}"
1705
+ )
1706
+ area_covered = mask.sum() / mask.size
1707
+ if area_covered >= 0.9:
1708
+ raise InvalidMask(
1709
+ f"Projected ground covers more than {area_covered:.2f}% of the image surface which is unrealistic"
1710
+ )
1711
+
1712
+ def validate_projected_corners(
1713
+ source_keypoints: list[tuple[int, int]], homography_matrix: ndarray
1714
+ ) -> None:
1715
+ # Vectorized: use fancy indexing to extract corners
1716
+ corner_indices = np.array([
1717
+ INDEX_KEYPOINT_CORNER_BOTTOM_LEFT,
1718
+ INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT,
1719
+ INDEX_KEYPOINT_CORNER_TOP_RIGHT,
1720
+ INDEX_KEYPOINT_CORNER_TOP_LEFT
1721
+ ], dtype=np.int32)
1722
+
1723
+ # Convert to array once and index
1724
+ if isinstance(source_keypoints, np.ndarray):
1725
+ src_corners = source_keypoints[corner_indices]
1726
+ else:
1727
+ src_arr = np.array(source_keypoints, dtype=np.float32)
1728
+ src_corners = src_arr[corner_indices]
1729
+
1730
+ src_corners = src_corners[None, :, :]
1731
+ warped_corners = perspectiveTransform(src_corners, homography_matrix)[0]
1732
+
1733
+ if is_bowtie(warped_corners):
1734
+ raise InvalidMask("Projection twisted!")
1735
+
1736
+ def project_image_using_keypoints(
1737
+ image: ndarray,
1738
+ source_keypoints: list[tuple[int, int]],
1739
+ destination_keypoints: list[tuple[int, int]],
1740
+ destination_width: int,
1741
+ destination_height: int,
1742
+ inverse: bool = False,
1743
+ ) -> ndarray:
1744
+ # Vectorized filtering: convert to arrays and filter with boolean mask
1745
+ src_arr = np.array(source_keypoints, dtype=np.float32)
1746
+ dst_arr = np.array(destination_keypoints, dtype=np.float32)
1747
+
1748
+ # Vectorized mask: filter out (0, 0) destination points
1749
+ valid_mask = ~((dst_arr[:, 0] == 0) & (dst_arr[:, 1] == 0))
1750
+
1751
+ source_points = src_arr[valid_mask]
1752
+ destination_points = dst_arr[valid_mask]
1753
+
1754
+ H, _ = findHomography(source_points, destination_points)
1755
+ if H is None:
1756
+ raise InvalidMask("Homography not found")
1757
+ validate_projected_corners(source_keypoints=source_keypoints, homography_matrix=H)
1758
+
1759
+ projected_image = warpPerspective(image, H, (destination_width, destination_height))
1760
+
1761
+ return projected_image
1762
+
1763
+ def extract_masks_for_ground_and_lines(image: ndarray) -> tuple[ndarray, ndarray]:
1764
+ """Assumes the template is coloured such that ground = gray, lines = white, background = black."""
1765
+ # gray = cvtColor(image, COLOR_BGR2GRAY)
1766
+ gray = image
1767
+
1768
+ _, mask_ground = threshold(gray, 10, 1, THRESH_BINARY)
1769
+
1770
+ nonzero_pts = cv2.findNonZero(mask_ground)
+ if nonzero_pts is None:
+ raise InvalidMask("No projected ground")
+ x, y, w, h = cv2.boundingRect(nonzero_pts)
1771
+ rect_size = w * h
1772
+ area_size = countNonZero(mask_ground)
1773
+ is_rect = area_size == rect_size
1774
+
1775
+ if is_rect:
1776
+ raise InvalidMask(
1777
+ f"Projected ground should not be rectangular"
1778
+ )
1779
+
1780
+ total_pixels = mask_ground.size
1781
+ ground_nonzero = int(countNonZero(mask_ground))
1782
+ if ground_nonzero == 0:
1783
+ raise InvalidMask("No projected ground")
1784
+ area_covered = ground_nonzero / float(total_pixels)
1785
+ if area_covered >= 0.9:
1786
+ raise InvalidMask(f"Projected ground covers more than {area_covered:.2f}% of the image surface which is unrealistic")
1787
+
1788
+ validate_mask_ground(mask=mask_ground)
1789
+
1790
+ _, mask_lines = threshold(gray, 200, 1, THRESH_BINARY)
1791
+ validate_mask_lines(mask=mask_lines)
1792
+ return mask_ground, mask_lines
1793
+
1794
+
1795
+ def get_edge_mask(x, y, W, H, t):
1796
+ """Uses bitmasking instead of sets for speed."""
1797
+ mask = 0
1798
+ if x <= t: mask |= 1 # Left
1799
+ if x >= W - t: mask |= 2 # Right
1800
+ if y <= t: mask |= 4 # Top
1801
+ if y >= H - t: mask |= 8 # Bottom
1802
+ return mask
1803
+
1804
+ def both_points_same_direction_fast(A, B, W, H, t=100):
1805
+ mask_a = get_edge_mask(A[0], A[1], W, H, t)
1806
+ if mask_a == 0: return False
1807
+
1808
+ mask_b = get_edge_mask(B[0], B[1], W, H, t)
1809
+ if mask_b == 0: return False
1810
+
1811
+ # Bitwise AND: if any bit matches, they share an edge
1812
+ return (mask_a & mask_b) != 0
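The bitmask above encodes which frame borders a point hugs (1 = left, 2 = right, 4 = top, 8 = bottom), so a single bitwise AND decides whether two points sit near the same border. Standalone sketch:

```python
def get_edge_mask(x, y, W, H, t):
    mask = 0
    if x <= t: mask |= 1      # left
    if x >= W - t: mask |= 2  # right
    if y <= t: mask |= 4      # top
    if y >= H - t: mask |= 8  # bottom
    return mask

def both_points_same_direction_fast(A, B, W, H, t=100):
    mask_a = get_edge_mask(A[0], A[1], W, H, t)
    if mask_a == 0:
        return False
    mask_b = get_edge_mask(B[0], B[1], W, H, t)
    # Shared set bit means both points hug the same border.
    return mask_b != 0 and (mask_a & mask_b) != 0

# Two points hugging the left border share bit 1; left vs right share nothing.
same = both_points_same_direction_fast((5, 200), (20, 900), 1920, 1080)
diff = both_points_same_direction_fast((5, 500), (1900, 500), 1920, 1080)
```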
1813
+
1814
+ def canonical(obj):
1815
+ # numpy arrays -> keep order
1816
+ if isinstance(obj, np.ndarray):
1817
+ return canonical(obj.tolist())
1818
+
1819
+ # ordered sequences
1820
+ if isinstance(obj, (list, tuple)):
1821
+ return tuple(canonical(x) for x in obj)
1822
+
1823
+ # unordered sets
1824
+ if isinstance(obj, set):
1825
+ return tuple(sorted(canonical(x) for x in obj))
1826
+
1827
+ # dictionaries (keys may not be ordered)
1828
+ if isinstance(obj, dict):
1829
+ return tuple((k, canonical(v)) for k, v in sorted(obj.items()))
1830
+
1831
+ return obj # primitive types
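`canonical` turns nested containers into hashable tuples, sorting only the unordered ones (sets and dict items), so logically equal structures yield identical keys. For example:

```python
import numpy as np

def canonical(obj):
    if isinstance(obj, np.ndarray):
        return canonical(obj.tolist())      # keep array order
    if isinstance(obj, (list, tuple)):
        return tuple(canonical(x) for x in obj)
    if isinstance(obj, set):
        return tuple(sorted(canonical(x) for x in obj))  # normalise unordered sets
    if isinstance(obj, dict):
        return tuple((k, canonical(v)) for k, v in sorted(obj.items()))
    return obj

# Same logical content, different container types and insertion orders:
a = canonical({"b": [1, 2], "a": {3, 2}})
b = canonical({"a": {2, 3}, "b": (1, 2)})
```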
1832
+
1833
+ def fast_cache_key(frame_keypoints, w, h):
1834
+ # Byte-based key avoids deep recursion/tuples while preserving order.
1835
+ # Optimize: check if already array to avoid copy
1836
+ if isinstance(frame_keypoints, np.ndarray):
1837
+ if frame_keypoints.dtype == np.int32:
1838
+ arr = frame_keypoints
1839
+ else:
1840
+ arr = frame_keypoints.astype(np.int32)
1841
+ else:
1842
+ arr = np.asarray(frame_keypoints, dtype=np.int32)
1843
+ return (arr.tobytes(), int(w), int(h))
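`fast_cache_key` sidesteps that recursion on the hot path: keypoints are flattened to int32 bytes, so a list of tuples and an equivalent ndarray produce the same key. Sketch:

```python
import numpy as np

def fast_cache_key(frame_keypoints, w, h):
    # Reuse the array when it is already int32, otherwise convert once.
    if isinstance(frame_keypoints, np.ndarray):
        arr = frame_keypoints if frame_keypoints.dtype == np.int32 else frame_keypoints.astype(np.int32)
    else:
        arr = np.asarray(frame_keypoints, dtype=np.int32)
    return (arr.tobytes(), int(w), int(h))

k1 = fast_cache_key([(5, 5), (55, 250)], 1920, 1080)
k2 = fast_cache_key(np.array([[5, 5], [55, 250]], dtype=np.int32), 1920, 1080)
```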
1844
+
1845
+ blacklists = [
1846
+ [23, 24, 27, 28],
1847
+ [7, 8, 3, 4],
1848
+ [2, 10, 1, 14],
1849
+ [18, 26, 14, 25],
1850
+ [5, 13, 6, 17],
1851
+ [21, 29, 17, 30],
1852
+ [10, 11, 2, 3],
1853
+ [10, 11, 2, 7],
1854
+ [12, 13, 4, 5],
1855
+ [12, 13, 5, 8],
1856
+ [18, 19, 26, 27],
1857
+ [18, 19, 26, 23],
1858
+ [20, 21, 24, 29],
1859
+ [20, 21, 28, 29],
1860
+ [8, 4, 5, 13],
1861
+ [3, 7, 2, 10],
1862
+ [23, 27, 18, 26],
1863
+ [24, 28, 21, 29]
1864
+ ]
1865
+
1866
+ prepared_blacklists = [(set(bl), bl[0]-1, bl[1]-1) for bl in blacklists]
1867
+
1868
+ def evaluate_keypoints_for_frame(
1869
+ frame_keypoints: list[tuple[int, int]],
1870
+ frame_index,
1871
+ h,
1872
+ w,
1873
+ precomputed_key=None,
1874
+ ) -> float:
1875
+ global cached
1876
+ # key = canonical((frame_keypoints, w, h))
1877
+ key = precomputed_key or fast_cache_key(frame_keypoints, w, h)
1878
+ template_keypoints = KEYPOINTS
1879
+ floor_markings_template = template_image_gray
1880
+ # start = time.time()
1881
+
1882
+ try:
1883
+ # h, w = frame.shape[:2]
1884
+ def compute_masks_for_key(frame_keypoints, w, h):
1885
+ try:
1886
+ non_idxs_set = {i + 1 for i, kpt in enumerate(frame_keypoints) if kpt[0] != 0 or kpt[1] != 0}
1887
+ for bl_set, idx0, idx1 in prepared_blacklists:
1888
+ if non_idxs_set.issubset(bl_set):
1889
+ if both_points_same_direction_fast(frame_keypoints[idx0], frame_keypoints[idx1], w, h):
1890
+ return None, 0, None
1891
+
1892
+ warped_template = project_image_using_keypoints(
1893
+ image=floor_markings_template,
1894
+ source_keypoints=template_keypoints,
1895
+ destination_keypoints=frame_keypoints,
1896
+ destination_width=w,
1897
+ destination_height=h,
1898
+ )
1899
+ mask_ground, mask_lines_expected = extract_masks_for_ground_and_lines(
1900
+ image=warped_template
1901
+ )
1902
+ mask_expected_on_ground = mask_lines_expected
1903
+
1904
+ ys, xs = np.where(mask_lines_expected == 1)
1905
+
1906
+ if len(xs) == 0:
1907
+ bbox = None # no foreground pixels
1908
+ else:
1909
+ min_x = xs.min()
1910
+ max_x = xs.max()
1911
+ min_y = ys.min()
1912
+ max_y = ys.max()
1913
+ bbox = (min_x, min_y, max_x, max_y)
1914
+ bbox_area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) if bbox is not None else 1
1915
+ frame_area = h * w
1916
+
1917
+ if (bbox_area / frame_area) < 0.2:
1918
+ return None, 0, None
1919
+
1920
+ pixels_on_lines = int(countNonZero(mask_expected_on_ground))
1921
+ return mask_expected_on_ground, pixels_on_lines, mask_ground
1922
+ except Exception:
1923
+ return None, 0, None
1924
+
1925
+ mask_expected_on_ground, pixels_on_lines, mask_ground = get_or_compute_masks(
1926
+ key, lambda: compute_masks_for_key(frame_keypoints, w, h)
1927
+ )
1928
+ if mask_expected_on_ground is None or pixels_on_lines == 0 or mask_ground is None:
1929
+ return 0.0
1930
+
1931
+ image_edges = check_frame[frame_index]
1932
+
1933
+ h, w = mask_expected_on_ground.shape[:2]
1934
+ work_buffer = np.zeros((h, w), dtype=np.uint8)
1935
+ bitwise_and(
1936
+ image_edges,
1937
+ image_edges,
1938
+ dst=work_buffer,
1939
+ mask=mask_ground
1940
+ )
1941
+ dilate(work_buffer, dilate_kernel, dst=work_buffer, iterations=3)
1942
+ threshold(work_buffer, 0, 255, cv2.THRESH_BINARY, dst=work_buffer)
1943
+ pixels_predicted_count = countNonZero(work_buffer)
1944
+ bitwise_and(work_buffer, mask_expected_on_ground, dst=work_buffer)
1945
+ pixels_overlapping = countNonZero(work_buffer)
1946
+ pixels_rest = pixels_predicted_count - pixels_overlapping
1947
+ total_pixels = pixels_predicted_count + pixels_on_lines - pixels_overlapping
1948
+ if total_pixels > 0 and (pixels_rest / total_pixels) > 0.9:
1949
+ return 0.0
1950
+ score = pixels_overlapping / (pixels_on_lines + 1e-8)
1951
+ return score
1952
+ except Exception:
1953
+ pass # any failure while scoring a candidate yields 0.0
1954
+ return 0.0
1955
+
1956
+ def _generate_sparse_template_keypoints(frame_width: int, frame_height: int) -> list[tuple[int, int]]:
1957
+ key = (int(frame_width), int(frame_height))
1958
+ if key in _sparse_template_cache:
1959
+ return _sparse_template_cache[key]
1960
+ template_max_x, template_max_y = (_TEMPLATE_MAX_X, _TEMPLATE_MAX_Y)
1961
+ sx = float(frame_width) / float(template_max_x if template_max_x != 0 else 1)
1962
+ sy = float(frame_height) / float(template_max_y if template_max_y != 0 else 1)
1963
+ # Vectorized scaling and rounding
1964
+ scale_factors = np.array([sx, sy], dtype=np.float32)
1965
+ scaled_np = np.round(FOOTBALL_KEYPOINTS_NP * scale_factors).astype(np.int32)
1966
+ scaled = [(int(x), int(y)) for x, y in scaled_np]
1967
+ _sparse_template_cache[key] = scaled
1968
+ return scaled
1969
+
1970
+ def convert_keypoints_to_val_format(keypoints):
1971
+ # Vectorized: convert to numpy, cast, then back to list of tuples
1972
+ if not keypoints:
1973
+ return []
1974
+ arr = np.asarray(keypoints, dtype=np.int32)
1975
+ return [(int(x), int(y)) for x, y in arr]
1976
+
1977
+
1978
+ def are_collinear(pts, eps=1e-9):
1979
+ pts = np.asarray(pts)
1980
+ if len(pts) < 3:
1981
+ return True
1982
+ a, b, c = pts[:3]
1983
+ area = np.abs(np.cross(b - a, c - a))
1984
+ return area < eps
1985
+
1986
+ def line_to_line_transform(P1, P2, Q1, Q2):
1987
+ """
1988
+ Compute 2D affine transformation mapping line segment P1P2 -> Q1Q2
1989
+ Optimized version reducing allocations.
1990
+
1991
+ Parameters:
1992
+ P1, P2: source points (x, y)
1993
+ Q1, Q2: target points (x, y)
1994
+
1995
+ Returns:
1996
+ M: 3x3 homogeneous transformation matrix
1997
+ """
1998
+ P1 = np.asarray(P1, dtype=np.float64)
1999
+ P2 = np.asarray(P2, dtype=np.float64)
2000
+ Q1 = np.asarray(Q1, dtype=np.float64)
2001
+ Q2 = np.asarray(Q2, dtype=np.float64)
2002
+
2003
+ # Source and target vectors
2004
+ v_s = P2 - P1
2005
+ v_t = Q2 - Q1
2006
+
2007
+ # Scale factor (using hypot for better numerical stability)
2008
+ norm_s = np.hypot(v_s[0], v_s[1])
2009
+ norm_t = np.hypot(v_t[0], v_t[1])
2010
+ s = norm_t / norm_s
2011
+
2012
+ # Rotation angle
2013
+ theta = np.arctan2(v_t[1], v_t[0]) - np.arctan2(v_s[1], v_s[0])
2014
+
2015
+ # Precompute sin/cos
2016
+ cos_theta = np.cos(theta)
2017
+ sin_theta = np.sin(theta)
2018
+
2019
+ # 2x2 scaled rotation components
2020
+ sr00 = s * cos_theta
2021
+ sr01 = -s * sin_theta
2022
+ sr10 = s * sin_theta
2023
+ sr11 = s * cos_theta
2024
+
2025
+ # Translation (direct computation avoiding matrix mul)
2026
+ t0 = Q1[0] - (sr00 * P1[0] + sr01 * P1[1])
2027
+ t1 = Q1[1] - (sr10 * P1[0] + sr11 * P1[1])
2028
+
2029
+ # Homogeneous 3x3 matrix (direct construction)
2030
+ M = np.array([
2031
+ [sr00, sr01, t0],
2032
+ [sr10, sr11, t1],
2033
+ [0.0, 0.0, 1.0]
2034
+ ], dtype=np.float64)
2035
+
2036
+ return M
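The similarity transform above composes scale, rotation, and translation so the segment endpoints map exactly. A quick standalone check that P2 lands on Q2:

```python
import numpy as np

def line_to_line_transform(P1, P2, Q1, Q2):
    P1, P2 = np.asarray(P1, float), np.asarray(P2, float)
    Q1, Q2 = np.asarray(Q1, float), np.asarray(Q2, float)
    v_s, v_t = P2 - P1, Q2 - Q1
    s = np.hypot(v_t[0], v_t[1]) / np.hypot(v_s[0], v_s[1])  # scale
    theta = np.arctan2(v_t[1], v_t[0]) - np.arctan2(v_s[1], v_s[0])  # rotation
    c, si = np.cos(theta), np.sin(theta)
    sr = np.array([[s * c, -s * si], [s * si, s * c]])  # scaled rotation
    M = np.eye(3)
    M[:2, :2] = sr
    M[:2, 2] = Q1 - sr @ P1  # translation pinning P1 onto Q1
    return M

# Horizontal unit segment -> vertical segment of length 2.
M = line_to_line_transform((0, 0), (1, 0), (2, 2), (2, 4))
end = M @ np.array([1.0, 0.0, 1.0])  # image of P2
```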
2037
+
2038
+ def three_point_affine(P, Q):
2039
+ P = np.array(P, dtype=np.float64)
2040
+ Q = np.array(Q, dtype=np.float64)
2041
+ n = P.shape[0]
2042
+
2043
+ # Vectorized construction of least-squares system
2044
+ x, y = P[:, 0], P[:, 1]
2045
+ u, v = Q[:, 0], Q[:, 1]
2046
+
2047
+ # Pre-allocate A matrix
2048
+ A = np.zeros((2*n, 6), dtype=np.float64)
2049
+ A[0::2, 0] = x
2050
+ A[0::2, 1] = y
2051
+ A[0::2, 2] = 1
2052
+ A[1::2, 3] = x
2053
+ A[1::2, 4] = y
2054
+ A[1::2, 5] = 1
2055
+
2056
+ # Vectorized b vector
2057
+ b = np.empty(2*n, dtype=np.float64)
2058
+ b[0::2] = u
2059
+ b[1::2] = v
2060
+
2061
+ # Solve least squares (robust to collinear points)
2062
+ params, _, _, _ = np.linalg.lstsq(A, b, rcond=None)
2063
+ a, b_, e, c, d, f = params
2064
+
2065
+ # Homogeneous transformation matrix
2066
+ M = np.array([
2067
+ [a, b_, e],
2068
+ [c, d, f],
2069
+ [0, 0, 1]
2070
+ ], dtype=np.float64)
2071
+
2072
+ return M
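The least-squares system above stacks two rows per correspondence; with three non-collinear points the affine map is exactly determined. A compact sketch verifying a pure scale-by-2:

```python
import numpy as np

def three_point_affine(P, Q):
    P, Q = np.asarray(P, float), np.asarray(Q, float)
    n = P.shape[0]
    A = np.zeros((2 * n, 6))
    A[0::2, 0], A[0::2, 1], A[0::2, 2] = P[:, 0], P[:, 1], 1  # rows for u
    A[1::2, 3], A[1::2, 4], A[1::2, 5] = P[:, 0], P[:, 1], 1  # rows for v
    b = np.empty(2 * n)
    b[0::2], b[1::2] = Q[:, 0], Q[:, 1]
    a, b_, e, c, d, f = np.linalg.lstsq(A, b, rcond=None)[0]
    return np.array([[a, b_, e], [c, d, f], [0, 0, 1]])

# Each source point maps to double its coordinates, so M should be 2*I.
M = three_point_affine([(0, 0), (1, 0), (0, 1)], [(0, 0), (2, 0), (0, 2)])
pt = M @ np.array([1.0, 1.0, 1.0])
```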
2073
+
2074
+ def affine_from_4_points(src_pts, dst_pts):
2075
+ """
2076
+ Compute a 2D affine transformation from 4 source points to 4 target points using least-squares.
2077
+ Vectorized version for better performance.
2078
+
2079
+ Parameters:
2080
+ src_pts: list of 4 source points [(x1,y1),..., (x4,y4)]
2081
+ dst_pts: list of 4 target points [(u1,v1),..., (u4,v4)]
2082
+
2083
+ Returns:
2084
+ 3x3 homogeneous affine transformation matrix
2085
+ """
2086
+ P = np.array(src_pts, dtype=np.float64)
2087
+ Q = np.array(dst_pts, dtype=np.float64)
2088
+
2089
+ # Vectorized construction of 8x6 system (2 eqs per point)
2090
+ x, y = P[:, 0], P[:, 1]
2091
+ u, v = Q[:, 0], Q[:, 1]
2092
+
2093
+ A = np.zeros((8, 6), dtype=np.float64)
2094
+ A[0::2, 0] = x
2095
+ A[0::2, 1] = y
2096
+ A[0::2, 2] = 1
2097
+ A[1::2, 3] = x
2098
+ A[1::2, 4] = y
2099
+ A[1::2, 5] = 1
2100
+
2101
+ b = np.empty(8, dtype=np.float64)
2102
+ b[0::2] = u
2103
+ b[1::2] = v
2104
+
2105
+ # Solve least-squares
2106
+ params, _, _, _ = np.linalg.lstsq(A, b, rcond=None)
2107
+ a, b_, e, c, d, f = params
2108
+
2109
+ # Construct 3x3 affine matrix
2110
+ M = np.array([
2111
+ [a, b_, e],
2112
+ [c, d, f],
2113
+ [0, 0, 1]
2114
+ ], dtype=np.float64)
2115
+ return M
2116
+
2117
+ def four_point_homography(src_pts, dst_pts):
2118
+ """
2119
+ Compute 2D homography mapping 4 source points to 4 target points.
2120
+ Vectorized version for better performance.
2121
+
2122
+ src_pts: list of 4 source points [(x1,y1),..., (x4,y4)]
2123
+ dst_pts: list of 4 target points [(u1,v1),..., (u4,v4)]
2124
+
2125
+ Returns:
2126
+ 3x3 homography matrix
2127
+ """
2128
+ # Vectorized construction of A matrix
2129
+ src = np.array(src_pts, dtype=np.float64)
2130
+ dst = np.array(dst_pts, dtype=np.float64)
2131
+
2132
+ x, y = src[:, 0], src[:, 1]
2133
+ u, v = dst[:, 0], dst[:, 1]
2134
+
2135
+ # Pre-allocate A matrix
2136
+ A = np.zeros((8, 9), dtype=np.float64)
2137
+ A[0::2, 0] = -x
2138
+ A[0::2, 1] = -y
2139
+ A[0::2, 2] = -1
2140
+ A[0::2, 6] = x * u
2141
+ A[0::2, 7] = y * u
2142
+ A[0::2, 8] = u
2143
+
2144
+ A[1::2, 3] = -x
2145
+ A[1::2, 4] = -y
2146
+ A[1::2, 5] = -1
2147
+ A[1::2, 6] = x * v
2148
+ A[1::2, 7] = y * v
2149
+ A[1::2, 8] = v
2150
+
2151
+ # Solve Ah=0 using SVD
2152
+ _, _, Vt = np.linalg.svd(A)
2153
+ h = Vt[-1, :] # last row of V^T
2154
+ H = h.reshape(3, 3)
2155
+
2156
+ # Normalize
2157
+ H /= H[2, 2]
2158
+ return H
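The rows of A encode the cross-multiplied constraints of the direct linear transform (DLT); the SVD null vector is the flattened homography. A standalone check that the recovered H maps the fourth correspondence exactly:

```python
import numpy as np

def four_point_homography(src_pts, dst_pts):
    src = np.array(src_pts, dtype=np.float64)
    dst = np.array(dst_pts, dtype=np.float64)
    x, y = src[:, 0], src[:, 1]
    u, v = dst[:, 0], dst[:, 1]
    A = np.zeros((8, 9), dtype=np.float64)
    A[0::2, 0], A[0::2, 1], A[0::2, 2] = -x, -y, -1
    A[0::2, 6], A[0::2, 7], A[0::2, 8] = x * u, y * u, u
    A[1::2, 3], A[1::2, 4], A[1::2, 5] = -x, -y, -1
    A[1::2, 6], A[1::2, 7], A[1::2, 8] = x * v, y * v, v
    _, _, Vt = np.linalg.svd(A)
    H = Vt[-1].reshape(3, 3)  # null vector of A
    return H / H[2, 2]

src = [(0, 0), (1, 0), (1, 1), (0, 1)]
dst = [(0, 0), (2, 0), (2.5, 1.5), (0, 1)]
H = four_point_homography(src, dst)
p = H @ np.array([1.0, 1.0, 1.0])
mapped = p[:2] / p[2]  # perspective divide; should equal dst[2]
```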
2159
+
2160
+ def unique_points(src, dst):
2161
+ src, dst = np.asarray(src, float), np.asarray(dst, float)
2162
+ # Vectorized filtering for zero points
2163
+ src_nonzero = ~np.all(np.abs(src) < 1e-9, axis=1)
2164
+ dst_nonzero = ~np.all(np.abs(dst) < 1e-9, axis=1)
2165
+ valid_mask = src_nonzero & dst_nonzero
2166
+
2167
+ if not valid_mask.any():
2168
+ return np.array([]), np.array([])
2169
+
2170
+ src_valid = src[valid_mask]
2171
+ dst_valid = dst[valid_mask]
2172
+
2173
+ # Remove duplicates using numpy unique
2174
+ _, unique_idx = np.unique(src_valid, axis=0, return_index=True)
2175
+ unique_idx.sort() # preserve order
2176
+
2177
+ return src_valid[unique_idx], dst_valid[unique_idx]
2178
+
2179
+ def robust_transform(src_pts, dst_pts):
2180
+ src, dst = unique_points(src_pts, dst_pts)
2181
+ n = len(src)
2182
+ if n >= 4:
2183
+ if are_collinear(src) or are_collinear(dst):
2184
+ H = affine_from_4_points(src, dst)
2185
+ return lambda pt: apply_transform(H, pt)
2186
+ else:
2187
+ H = four_point_homography(src, dst)
2188
+ return lambda pt: apply_homo_transform(H, pt)
2189
+ elif n == 3:
2190
+ H = three_point_affine(src, dst)
2191
+ elif n == 2:
2192
+ H = line_to_line_transform(src[0], src[1], dst[0], dst[1])
2193
+ elif n == 1:
2194
+ t = dst[0] - src[0]
2195
+ H = np.eye(3)
2196
+ H[:2, 2] = t
2197
+ else:
2198
+ H = np.eye(3)
2199
+ return lambda pt: apply_transform(H, pt)
2200
+
2201
+ def apply_homo_transform(M, P):
2202
+ # Optimized: direct indexing instead of array creation
2203
+ x, y = P[0], P[1]
2204
+
2205
+ # Apply transformation with pre-computed homogeneous coords
2206
+ w = M[2, 0] * x + M[2, 1] * y + M[2, 2]
2207
+ x_new = (M[0, 0] * x + M[0, 1] * y + M[0, 2]) / w
2208
+ y_new = (M[1, 0] * x + M[1, 1] * y + M[1, 2]) / w
2209
+
2210
+ # Return the transformed point (same contract as apply_transform)
2211
+ return (int(x_new), int(y_new))
2212
+
2213
+ def apply_transform(M, P):
2214
+ """
2215
+ Transform a single 2D point using a 3x3 transformation matrix H.
2216
+ Optimized version avoiding array creation.
2217
+
2218
+ Args:
2219
+ H : 3x3 numpy array
2220
+ Transformation matrix (homography, affine, similarity, etc.)
2221
+ point : (x, y) array-like
2222
+ Single point coordinates to transform.
2223
+
2224
+ Returns:
2225
+ (x', y') : Transformed point coordinates
2226
+ """
2227
+ # Direct computation without intermediate arrays
2228
+ x, y = P[0], P[1]
2229
+ x_new = M[0, 0] * x + M[0, 1] * y + M[0, 2]
2230
+ y_new = M[1, 0] * x + M[1, 1] * y + M[1, 2]
2231
+ return (int(x_new), int(y_new))
2232
+
2233
+ def pick_pt(points):
2234
+ # Fully vectorized neighbor expansion preserving original order.
2235
+ if not points:
2236
+ return []
2237
+ pts_arr = np.asarray(points, dtype=np.int32)
2238
+ seen = np.zeros(32, dtype=bool)
2239
+ valid_mask = (pts_arr >= 0) & (pts_arr < 32)
2240
+ seen[pts_arr[valid_mask]] = True
2241
+
2242
+ out_seen = np.zeros(32, dtype=bool)
2243
+ out = []
2244
+ for p in pts_arr[valid_mask]:
2245
+ neigh = GROUPS_ARRAY[p]
2246
+ candidates = neigh[~seen[neigh] & ~out_seen[neigh]]
2247
+ out_seen[candidates] = True
2248
+ out.extend(candidates.tolist())
2249
+ return out
2250
+
2251
+ def make_possible_keypoints(all_keypoints, frame_width, frame_height, limit=2):
2252
+ # Early exit for empty input
2253
+ if not all_keypoints:
2254
+ return []
2255
+
2256
+ results = []
2257
+
2258
+ for keypoints in all_keypoints:
2259
2260
+ # np.asarray is smart: it avoids copying if the input is already
2261
+ # the right type/shape, but allows it if conversion is needed.
2262
+ arr = np.asarray(keypoints, dtype=np.int32)
2263
+
2264
+ # Basic shape validation
2265
+ if arr.ndim != 2 or arr.shape[1] != 2:
2266
+ continue
2267
+
2268
+ # Fast Masking and Counting
2269
+ mask = (arr[:, 0] != 0) & (arr[:, 1] != 0)
2270
+ non_zero_count = mask.sum()
2271
+
2272
+ # Logic Flow
2273
+ if non_zero_count > 4:
2274
+ results.append(keypoints)
2275
+ continue
2276
+
2277
+ if non_zero_count < 2:
2278
+ continue
2279
+
2280
+ # If exactly 4, we append the original BUT continue to try and find the 5th
2281
+ if non_zero_count == 4:
2282
+ results.append(keypoints)
2283
+
2284
+ # Prepare Transformation Data
2285
+ non_zero_idxs = np.flatnonzero(mask)
2286
+
2287
+ # Assuming KEYPOINTS_NP is available globally
2288
+ src = KEYPOINTS_NP[non_zero_idxs]
2289
+ dest = arr[non_zero_idxs].astype(np.float32)
2290
+
2291
+ try:
2292
+ # transform_func is calculated once
2293
+ transform_func = robust_transform(src, dest)
2294
+ except Exception:
2295
+ continue
2296
+
2297
+ # Get candidate indices to check
2298
+ candidate_idxs = pick_pt(non_zero_idxs.tolist())
2299
+ if not candidate_idxs:
2300
+ continue
2301
+
2302
+ # Pre-calculate Valid Projections
2303
+ valid_cache = {}
2304
+ valid_real_idxs = []
2305
+
2306
+ for idx in candidate_idxs:
2307
+ # Transform point
2308
+ t_pt = transform_func(KEYPOINTS_NP[idx])
2309
+
2310
+ # Unroll checks for speed
2311
+ tx, ty = t_pt[0], t_pt[1]
2312
+
2313
+ # Boundary check
2314
+ if 0 <= tx < frame_width and 0 <= ty < frame_height:
2315
+ valid_cache[idx] = (int(tx), int(ty))
2316
+ valid_real_idxs.append(idx)
2317
+
2318
+ # Check if we have enough valid points to satisfy the request
2319
+ n_missing = 5 - non_zero_count
2320
+ if len(valid_real_idxs) < n_missing:
2321
+ continue
2322
+
2323
+ # Generate Combinations
2324
+ cnt = 0
2325
+ for group in combinations(valid_real_idxs, n_missing):
2326
+ if cnt >= limit:
2327
+ break
2328
+ cnt += 1
2329
+
2330
+ # Create the result list
2331
+ # A shallow copy of the list is much faster than recreating a numpy object array.
2332
+ new_result = list(keypoints)
2333
+
2334
+ # Fill in the missing points from our cache
2335
+ for idx in group:
2336
+ new_result[idx] = valid_cache[idx]
2337
+
2338
+ results.append(new_result)
2339
+
2340
+ return results
2341
+
2342
+ def _get_shared_eval_executor(max_workers: int) -> ThreadPoolExecutor:
2343
+ global _shared_eval_executor
2344
+ if _shared_eval_executor is None:
2345
+ _shared_eval_executor = ThreadPoolExecutor(max_workers=max_workers)
2346
+ return _shared_eval_executor
2347
+
2348
+ def evaluates(jobs, h, w, total_frames: int):
2349
+ # start_time = time.time()
2350
+ if len(jobs) == 0:
2351
+ return []
2352
+
2353
+ unique_jobs = [] # (job, frame_index, key_bytes)
2354
+ seen = set()
2355
+
2356
+ for (job, frame_index) in jobs:
2357
+ try:
2358
+ # Optimize: check if already array
2359
+ if isinstance(job, np.ndarray):
2360
+ key_bytes = job.astype(np.int32).tobytes() if job.dtype != np.int32 else job.tobytes()
2361
+ else:
2362
+ key_bytes = np.asarray(job, dtype=np.int32).tobytes()
2363
+
2364
+ sig = (frame_index, key_bytes)
2365
+ if sig in seen:
2366
+ continue
2367
+ seen.add(sig)
2368
+ unique_jobs.append((job, frame_index, key_bytes))
2369
+ except Exception as e:
2370
+ continue
2371
+
2372
+ if len(unique_jobs) <= 10:
2373
+ scores_unique = [
2374
+ evaluate_keypoints_for_frame(job, frame_index, h, w, precomputed_key=(key_bytes, w, h))
2375
+ for (job, frame_index, key_bytes) in unique_jobs
2376
+ ]
2377
+ else:
2378
+ cpu_count = max(1, (os.cpu_count() or 1))
2379
+ max_workers = min(max(2, cpu_count), 8)
2380
+
2381
+ chunk_size = 500
2382
+ scores_unique = []
2383
+ ex = _get_shared_eval_executor(max_workers)
2384
+
2385
+ for i in range(0, len(unique_jobs), chunk_size):
2386
+ chunk = unique_jobs[i:i + chunk_size]
2387
+ scores_unique.extend(
2388
+ ex.map(
2389
+ lambda pair: evaluate_keypoints_for_frame(pair[0], pair[1], h, w, precomputed_key=(pair[2], w, h)),
2390
+ chunk,
2391
+ )
2392
+ )
2393
+ scores = np.full(total_frames, -1.0, dtype=np.float32)
2394
+ results = [[(0, 0)] * 32 for _ in range(total_frames)]
2395
+
2396
+ for score, (k, frame_index, _) in zip(scores_unique, unique_jobs):
2397
+ if score > scores[frame_index]:
2398
+ scores[frame_index] = score
2399
+ results[frame_index] = k
2400
+
2401
+ return results
2402
+
2403
+ def fix_keypoints_pri(
2404
+ results_frames,
2405
+ frame_width: int,
2406
+ frame_height: int
2407
+ ) -> list[Any]:
2408
+ sparse_template = convert_keypoints_to_val_format(_generate_sparse_template_keypoints(frame_width, frame_height))
2409
+ max_frames = len(results_frames)
2410
+ limit = 30
2411
+ before = deque(maxlen=limit)
2412
+ after = deque(maxlen=limit)
2413
+
2414
+ all_possible = [None] * max_frames
2415
+ for i in range(max_frames):
2416
+ all_possible[i] = make_possible_keypoints([results_frames[i]], frame_width, frame_height)
2417
+ for i in range(1, min(limit, max_frames)):
2418
+ after.append(all_possible[i])
2419
+
2420
+ current = all_possible[0] if max_frames > 0 else []
2421
+ total_jobs = []
2422
+
2423
+ for frame_index in range(max_frames):
2424
+ if frame_index < max_frames - limit:
2425
+ future_idx = frame_index + limit
2426
+ if all_possible[future_idx] is None:
2427
+ all_possible[future_idx] = make_possible_keypoints([results_frames[future_idx]], frame_width, frame_height)
2428
+ after.append(all_possible[future_idx])
2429
+
2430
+ frame_jobs = [(kpts, frame_index) for kpts in current]
2431
+ for t in after:
2432
+ frame_jobs.extend([(kpts, frame_index) for kpts in t])
2433
+ for t in before:
2434
+ frame_jobs.extend([(kpts, frame_index) for kpts in t])
2435
+ frame_jobs.append((sparse_template, frame_index))
2436
+
2437
+ total_jobs.extend(frame_jobs)
2438
+
2439
+ before.append(current)
2440
+
2441
+ if len(after) != 0:
2442
+ current = after.popleft()
2443
+
2444
+ start_time = time.time()
2445
+ results = evaluates(total_jobs, frame_height, frame_width, max_frames)
2446
+ print(f"Evaluation time: {time.time() - start_time}")
2447
+ return results
2448
+
2449
+
2450
+ def normalize_results(frame_results, threshold):
2451
+ if not frame_results:
2452
+ return []
2453
+
2454
+ results_array = []
2455
+ for result in frame_results:
2456
+ arr = np.array(result, dtype=np.float32) # (N, 3)
2457
+ if arr.size == 0:
2458
+ results_array.append([])
2459
+ continue
2460
+
2461
+ mask = arr[:, 2] > threshold # (N,)
2462
+ scaled = arr[:, :2] # (N, 2)
2463
+ scaled = np.where(mask[:, None], scaled, 0) # Apply mask
2464
+ results_array.append([(int(x), int(y)) for x, y in scaled])
2465
+
2466
+ return results_array
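`normalize_results` keeps a keypoint's (x, y) only when its confidence exceeds the threshold, zeroing the rest. Standalone sketch:

```python
import numpy as np

def normalize_results(frame_results, threshold):
    out = []
    for result in frame_results:
        arr = np.array(result, dtype=np.float32)  # (N, 3): x, y, confidence
        if arr.size == 0:
            out.append([])
            continue
        mask = arr[:, 2] > threshold
        scaled = np.where(mask[:, None], arr[:, :2], 0)  # zero low-confidence points
        out.append([(int(x), int(y)) for x, y in scaled])
    return out

# One frame, two keypoints: only the confident one survives.
res = normalize_results([[(10, 20, 0.9), (30, 40, 0.1)]], threshold=0.3)
```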
2467
+
2468
+ def convert_to_gray(image):
2469
+ gray = cvtColor(image, COLOR_BGR2GRAY)
2470
+ gray = morphologyEx(gray, MORPH_TOPHAT, kernel, dst=gray)
2471
+ GaussianBlur(gray, (5, 5), 0, dst=gray)
+ image_edges = Canny(gray, 30, 100)
+ return image_edges
+
+
+ class Miner:
+     def __init__(self, path_hf_repo: Path) -> None:
+         global _OSNET_MODEL, team_classifier_path
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.device = device
+         self.path_hf_repo = path_hf_repo
+
+         print("✅ Loading YOLO models...")
+         self.bbox_model = YOLO(path_hf_repo / "player_detect.pt")
+
+         print("✅ Loading Team Classifier...")
+         self.keypoints_model = load_kp_model(path_hf_repo, device)
+         self.pitch_batch_size = 4
+         self.osnet_batch_size = 8
+         self.kp_threshold = 0.3
+
+         team_classifier_path = path_hf_repo / "osnet_model.pth.tar-100"
+         _OSNET_MODEL = load_osnet(device, team_classifier_path)
+
+         print("✅ All models loaded")
+
+     def predict_batch(self, batch_images: list[ndarray], offset: int, n_keypoints: int):
+         start = time.time()
+         # ---------- YOLO ----------
+         bboxes = {}
+         bbox_model_results = self.bbox_model.predict(batch_images, verbose=False)
+         print(f"Detect objects: {time.time() - start}")
+
+         start = time.time()
+         track_id = 0
+         for frame_number_in_batch, detection in enumerate(bbox_model_results):
+             boxes: list[BoundingBox] = []
+             for box in detection.boxes.data:
+                 x1, y1, x2, y2, conf, cls_id = box.tolist()
+                 temp_track_id = None
+                 if cls_id == PLAYER_ID:
+                     track_id += 1
+                     temp_track_id = track_id
+
+                 boxes.append(
+                     BoundingBox(
+                         x1=int(x1), y1=int(y1),
+                         x2=int(x2), y2=int(y2),
+                         cls_id=int(cls_id),
+                         conf=float(conf),
+                         track_id=temp_track_id,
+                     )
+                 )
+
+             # Keep only the single most confident ball detection.
+             ball_idxs = [i for i, b in enumerate(boxes) if b.cls_id == BALL_ID]
+             if len(ball_idxs) > 1:
+                 best_i = max(ball_idxs, key=lambda i: boxes[i].conf)
+                 boxes = [
+                     b for i, b in enumerate(boxes)
+                     if not (b.cls_id == BALL_ID and i != best_i)
+                 ]
+
+             # Keep only the most confident goalkeeper; demote the rest to players.
+             gk_idxs = [i for i, b in enumerate(boxes) if b.cls_id == GK_ID]
+             if len(gk_idxs) > 1:
+                 best_gk_i = max(gk_idxs, key=lambda i: boxes[i].conf)
+                 for i in gk_idxs:
+                     if i != best_gk_i:
+                         boxes[i].cls_id = PLAYER_ID
+                         track_id += 1
+                         boxes[i].track_id = track_id
+
+             # Keep at most three referees (highest confidence); demote the rest to players.
+             ref_idxs = [i for i, b in enumerate(boxes) if b.cls_id == REF_ID]
+             if len(ref_idxs) > 3:
+                 ref_idxs_sorted = sorted(ref_idxs, key=lambda i: boxes[i].conf, reverse=True)
+                 keep = set(ref_idxs_sorted[:3])
+                 for i in ref_idxs:
+                     if i not in keep:
+                         boxes[i].cls_id = PLAYER_ID
+                         track_id += 1
+                         boxes[i].track_id = track_id
+
+             bboxes[offset + frame_number_in_batch] = boxes
+
+         classify_teams_batch(
+             frames=batch_images,    # list[np.ndarray]
+             batch_boxes=bboxes,     # dict[int, list[BoundingBox]]
+             batch_size=self.osnet_batch_size,
+             device=self.device,
+         )
+         print("Finished team classification")
+         print(f"Object tracking: {time.time() - start}")
+
+         start = time.time()
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             torch.cuda.synchronize()
+
+         pitch_size = min(self.pitch_batch_size, len(batch_images))
+         device_str = "cuda" if torch.cuda.is_available() else "cpu"
+         keypoints_result = process_batch_input(
+             batch_images,
+             self.keypoints_model,
+             self.kp_threshold,
+             device_str,
+             batch_size=pitch_size,
+         )
+         print(f"Keypoint detection: {time.time() - start}")
+
+         start = time.time()
+         keypoints = normalize_keypoints(keypoints_result, batch_images, n_keypoints)
+         for idx, kpts in enumerate(keypoints):
+             keypoints[idx] = fix_keypoints(kpts, n_keypoints)
+
+         h, w = batch_images[0].shape[:2]
+         keypoints_by_frame = fix_keypoints_pri(keypoints, w, h)
+         print(f"Fix keypoints: {time.time() - start}")
+
+         results = []
+         for i in range(len(batch_images)):
+             frame_number = offset + i
+             results.append(
+                 TVFrameResult(
+                     frame_id=frame_number,
+                     boxes=bboxes.get(frame_number, []),
+                     keypoints=convert_keypoints_to_val_format(
+                         keypoints_by_frame[frame_number - offset]
+                     ),
+                 )
+             )
+
+         return results
+
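The per-frame pruning in `predict_batch` (one ball, one goalkeeper, at most three referees, with surplus detections demoted to players under fresh track IDs) can be sketched standalone. This is a minimal, dependency-free illustration: the class-ID constants, the `Box` dataclass, and `prune_detections` are hypothetical stand-ins for the repo's `BoundingBox` and module-level IDs, not its actual definitions.

```python
# Minimal sketch of the confidence-based pruning rules (assumed class-ID mapping).
from dataclasses import dataclass
from typing import Optional

BALL_ID, GK_ID, PLAYER_ID, REF_ID = 0, 1, 2, 3  # assumption, not the repo's values

@dataclass
class Box:
    cls_id: int
    conf: float
    track_id: Optional[int] = None

def prune_detections(boxes: list[Box], next_track_id: int) -> int:
    """Apply the dedup rules in place; return the updated track-id counter."""
    # Ball: keep only the single most confident detection, drop the rest.
    ball_idxs = [i for i, b in enumerate(boxes) if b.cls_id == BALL_ID]
    if len(ball_idxs) > 1:
        best = max(ball_idxs, key=lambda i: boxes[i].conf)
        boxes[:] = [b for i, b in enumerate(boxes)
                    if not (b.cls_id == BALL_ID and i != best)]

    def demote_extras(cls_id: int, keep_n: int, counter: int) -> int:
        # Keep the keep_n most confident boxes of this class; relabel the
        # rest as players and assign each a fresh track id.
        idxs = [i for i, b in enumerate(boxes) if b.cls_id == cls_id]
        if len(idxs) > keep_n:
            keep = set(sorted(idxs, key=lambda i: boxes[i].conf, reverse=True)[:keep_n])
            for i in idxs:
                if i not in keep:
                    boxes[i].cls_id = PLAYER_ID
                    counter += 1
                    boxes[i].track_id = counter
        return counter

    next_track_id = demote_extras(GK_ID, 1, next_track_id)   # at most one goalkeeper
    next_track_id = demote_extras(REF_ID, 3, next_track_id)  # at most three referees
    return next_track_id
```

Demoting (rather than dropping) extra goalkeepers and referees preserves the detections for downstream team classification; only duplicate balls are discarded outright.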
osnet_model.pth.tar-100 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a1d9415749a81b4a86c0d22f7014855ae5570ad85e985720180dd50e23005700
+ size 40032239
player_detect.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2522f266ca93f910e2bfd65734c8985062ecf4ac13cd62bab6cc375aa19a4527
+ size 22541418