aivertex95827 committed
Commit 6a81599 · 0 parent(s)

Duplicate from aivertex95827/turbo4_1
.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ osnet_model.pth.tar-100 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,92 @@
+ # 🚀 Example Chute for Turbovision 🪂
+
+ This repository demonstrates how to deploy a **Chute** via the **Turbovision CLI**, hosted on **Hugging Face Hub**.
+ It serves as a minimal example showcasing the required structure and workflow for integrating machine learning models, preprocessing, and orchestration into a reproducible Chute environment.
+
+ ## Repository Structure
+ The following two files **must be present** (in their current locations) for a successful deployment — their content can be modified as needed:
+
+ | File | Purpose |
+ |------|----------|
+ | `miner.py` | Defines the ML model type(s), orchestration, and all pre/postprocessing logic. |
+ | `config.yml` | Specifies machine configuration (e.g., GPU type, memory, environment variables). |
+
+ Other files — e.g., model weights, utility scripts, or dependencies — are **optional** and can be included as needed for your model. Note: any required assets must be defined or contained **within this repo**, which is fully open-source, since all network-related operations (downloading challenge data, weights, etc.) are disabled **inside the Chute**.
+
+ ## Overview
+
+ Below is a high-level diagram showing the interaction between Hugging Face, Chutes, and Turbovision:
+
+ ![](../images/miner.png)
+
+ ## Local Testing
+ After editing `config.yml` and `miner.py` and saving them to your Hugging Face repo, you will want to test that everything works locally.
+
+ 1. Copy the file `scorevision/chute_tmeplate/turbovision_chute.py.j2` as a Python file called `my_chute.py` and fill in the missing variables (a filled-in example follows the template below):
+ ```python
+ HF_REPO_NAME = "{{ huggingface_repository_name }}"
+ HF_REPO_REVISION = "{{ huggingface_repository_revision }}"
+ CHUTES_USERNAME = "{{ chute_username }}"
+ CHUTE_NAME = "{{ chute_name }}"
+ ```
+
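+ For illustration, the same block filled in with this repository's own coordinates (a sketch — `CHUTE_NAME` here is a made-up placeholder, and your values will differ):
+ ```python
+ HF_REPO_NAME = "aivertex95827/turbo4_1"   # Hugging Face repo (user/name)
+ HF_REPO_REVISION = "6a81599"              # a specific commit revision
+ CHUTES_USERNAME = "aivertex95827"         # your Chutes username
+ CHUTE_NAME = "turbo4-example"             # hypothetical chute name
+ ```
+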
+ 2. Run the following command to build the chute locally (Caution: there are known issues with the Docker location when running this on a Mac):
+ ```bash
+ chutes build my_chute:chute --local --public
+ ```
+
+ 3. Run the Docker image you just built (its name is based on `CHUTE_NAME`) and enter the container:
+ ```bash
+ docker run -p 8000:8000 -e CHUTES_EXECUTION_CONTEXT=REMOTE -it <image-name> /bin/bash
+ ```
+
+ 4. Run the chute file from within the container:
+ ```bash
+ chutes run my_chute:chute --dev --debug
+ ```
+
+ 5. In another terminal, test the local endpoints to ensure there are no bugs:
+ ```bash
+ curl -X POST http://localhost:8000/health -d '{}'
+ curl -X POST http://localhost:8000/predict -d '{"url": "https://scoredata.me/2025_03_14/35ae7a/h1_0f2ca0.mp4","meta": {}}'
+ ```
+
+ ## Live Testing
+ 1. If you have a chute with the same name (i.e., from a previous deployment), delete it first (or you will get an error when trying to build):
+ ```bash
+ chutes chutes list
+ ```
+ Take note of the ID of the chute you wish to delete (if any):
+ ```bash
+ chutes chutes delete <chute-id>
+ ```
+
+ You should also delete its associated image:
+ ```bash
+ chutes images list
+ ```
+ Take note of the chute image ID:
+ ```bash
+ chutes images delete <chute-image-id>
+ ```
+
+ 2. Use Turbovision's CLI to build, deploy, and commit on-chain. (Note: you can skip the on-chain commit using `--no-commit`; you can also specify a past Hugging Face revision to point to using `--revision`, and/or the local files you want to upload to your Hugging Face repo using `--model-path`. An example follows below.)
+ ```bash
+ sv -vv push
+ ```
+
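+ For example, a sketch of a redeploy from an earlier revision that skips the on-chain commit (the revision and path are placeholders):
+ ```bash
+ sv -vv push --no-commit --revision <huggingface-revision> --model-path <local-model-dir>
+ ```
+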
+ 3. When completed, warm up the chute (if it's cold 🧊). You can confirm its status using `chutes chutes list`, or `chutes chutes get <chute-id>` if you already know its ID. Note: warming up can sometimes take a while, but if the chute runs without errors (it should, if you've tested locally first) and there are sufficient nodes (i.e., machines) available matching the `config.yml` you specified, the chute should become hot 🔥!
+ ```bash
+ chutes warmup <chute-id>
+ ```
+
+ 4. Test the chute's endpoints:
+ ```bash
+ curl -X POST https://<YOUR-CHUTE-SLUG>.chutes.ai/health -d '{}' -H "Authorization: Bearer $CHUTES_API_KEY"
+ curl -X POST https://<YOUR-CHUTE-SLUG>.chutes.ai/predict -d '{"url": "https://scoredata.me/2025_03_14/35ae7a/h1_0f2ca0.mp4","meta": {}}' -H "Authorization: Bearer $CHUTES_API_KEY"
+ ```
+
+ 5. Test what your chute would get on a validator (this also applies any validation/integrity checks, which may fail if you did not use the Turbovision CLI above to deploy the chute):
+ ```bash
+ sv -vv run-once
+ ```
chute_config.yml ADDED
@@ -0,0 +1,29 @@
+ Image:
+   from_base: parachutes/python:3.12
+   run_command:
+     - pip install --upgrade setuptools wheel
+     - pip install --index-url https://download.pytorch.org/whl/cu128 torch torchvision
+     - pip install "ultralytics==8.3.222" "opencv-python-headless" "numpy" "pydantic" "Pillow"
+     - pip install scikit-learn
+     - pip install lap
+     - pip install scipy
+     - pip install onnxruntime-gpu
+   set_workdir: /app
+   readme: "Image for chutes"
+
+ NodeSelector:
+   gpu_count: 1
+   min_vram_gb_per_gpu: 48
+   min_memory_gb: 32
+   min_cpu_count: 32
+   exclude:
+     - b200
+     - h200
+     - mi300x
+
+ Chute:
+   timeout_seconds: 900
+   concurrency: 4
+   max_instances: 5
+   scaling_threshold: 0.5
+   shutdown_after_seconds: 288000
football_pitch_template.png ADDED
hrnetv2_w48.yaml ADDED
@@ -0,0 +1,35 @@
+ MODEL:
+   IMAGE_SIZE: [960, 540]
+   NUM_JOINTS: 58
+   PRETRAIN: ''
+   EXTRA:
+     FINAL_CONV_KERNEL: 1
+     STAGE1:
+       NUM_MODULES: 1
+       NUM_BRANCHES: 1
+       BLOCK: BOTTLENECK
+       NUM_BLOCKS: [4]
+       NUM_CHANNELS: [64]
+       FUSE_METHOD: SUM
+     STAGE2:
+       NUM_MODULES: 1
+       NUM_BRANCHES: 2
+       BLOCK: BASIC
+       NUM_BLOCKS: [4, 4]
+       NUM_CHANNELS: [48, 96]
+       FUSE_METHOD: SUM
+     STAGE3:
+       NUM_MODULES: 4
+       NUM_BRANCHES: 3
+       BLOCK: BASIC
+       NUM_BLOCKS: [4, 4, 4]
+       NUM_CHANNELS: [48, 96, 192]
+       FUSE_METHOD: SUM
+     STAGE4:
+       NUM_MODULES: 3
+       NUM_BRANCHES: 4
+       BLOCK: BASIC
+       NUM_BLOCKS: [4, 4, 4, 4]
+       NUM_CHANNELS: [48, 96, 192, 384]
+       FUSE_METHOD: SUM
+
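A minimal sketch of how a config like this can be consumed, assuming the `get_cls_net`/`HighResolutionNet` definitions from `miner.py` below (the loading call itself is illustrative, not part of this repo):

```python
import yaml

# Parse the HRNet config above; HighResolutionNet reads MODEL.EXTRA,
# MODEL.NUM_JOINTS, etc. from this dict.
with open("hrnetv2_w48.yaml") as f:
    cfg = yaml.safe_load(f)

model = get_cls_net(cfg, pretrained="")  # empty string skips weight loading
```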
keypoint_detect.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7ea78fa76aaf94976a8eca428d6e3c59697a93430cba1a4603e20284b61f5113
+ size 264964645
miner.py ADDED
@@ -0,0 +1,1595 @@
+ from pathlib import Path
+ from concurrent.futures import ThreadPoolExecutor
+ from ultralytics import YOLO
+ from numpy import ndarray
+ from pydantic import BaseModel
+ from typing import List, Tuple, Optional, Dict, Any
+ import numpy as np
+ import cv2
+ from sklearn.cluster import KMeans
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import yaml
+ import gc
+ import os
+ import sys
+ from collections import OrderedDict, defaultdict
+ from PIL import Image
+ import torchvision.transforms as T
+
+ try:
+     from scipy.optimize import linear_sum_assignment as _linear_sum_assignment
+ except ImportError:
+     _linear_sum_assignment = None
+
+ # ── Grass / kit helpers ────────────────────────────────
+
+ def get_grass_color(img: np.ndarray) -> Tuple[int, int, int]:
+     if img is None or img.size == 0:
+         return (0, 0, 0)
+     hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
+     lower_green = np.array([30, 40, 40])
+     upper_green = np.array([80, 255, 255])
+     mask = cv2.inRange(hsv, lower_green, upper_green)
+     grass_color = cv2.mean(img, mask=mask)
+     return grass_color[:3]
+
+ def get_players_boxes(result):
+     players_imgs, players_boxes = [], []
+     for box in result.boxes:
+         label = int(box.cls.cpu().numpy()[0])
+         if label == 2:
+             x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
+             crop = result.orig_img[y1:y2, x1:x2]
+             if crop.size > 0:
+                 players_imgs.append(crop)
+                 players_boxes.append((x1, y1, x2, y2))
+     return players_imgs, players_boxes
+
+ def get_kits_colors(players, grass_hsv=None, frame=None):
+     kits_colors = []
+     if grass_hsv is None:
+         grass_color = get_grass_color(frame)
+         grass_hsv = cv2.cvtColor(np.uint8([[list(grass_color)]]), cv2.COLOR_BGR2HSV)
+     for player_img in players:
+         hsv = cv2.cvtColor(player_img, cv2.COLOR_BGR2HSV)
+         lower_green = np.array([grass_hsv[0, 0, 0] - 10, 40, 40])
+         upper_green = np.array([grass_hsv[0, 0, 0] + 10, 255, 255])
+         mask = cv2.inRange(hsv, lower_green, upper_green)
+         mask = cv2.bitwise_not(mask)
+         upper_mask = np.zeros(player_img.shape[:2], np.uint8)
+         upper_mask[0:player_img.shape[0] // 2, :] = 255
+         mask = cv2.bitwise_and(mask, upper_mask)
+         kit_color = np.array(cv2.mean(player_img, mask=mask)[:3])
+         kits_colors.append(kit_color)
+     return kits_colors
+
+
+ # ── Person detection (new-2 style: tracking, votes, adjust) ───
+
+ # Internal class IDs: goalkeeper=1, player=2, referee=3
+ # Validator output: 0=player, 1=referee, 2=goalkeeper
+ _C_GOALKEEPER = 1
+ _C_PLAYER = 2
+ _C_REFEREE = 3
+ _CLS_TO_VALIDATOR: Dict[int, int] = {_C_PLAYER: 0, _C_REFEREE: 1, _C_GOALKEEPER: 2}
+
+ # Person model: 0=player, 1=referee, 2=goalkeeper (person-detection-model.onnx)
+ PERSON_MODEL_IMG_SIZE = 640
+ PERSON_CONF = 0.4
+ PERSON_HALF = True  # FP16 on GPU for faster inference
+ TRACK_IOU_THRESH = 0.3
+ TRACK_IOU_HIGH = 0.4
+ TRACK_IOU_LOW = 0.2
+ TRACK_MAX_AGE = 3
+ TRACK_USE_VELOCITY = True
+ NOISE_MIN_APPEARANCES = 5
+ NOISE_TAIL_FRAMES = 4
+ CLASS_VOTE_MAJORITY = 3
+ INTERP_TRACK_GAPS = True
+ ENABLE_BOX_SMOOTHING = False
+ BOX_SMOOTH_WINDOW = 8
+ OVERLAP_IOU = 0.91
+
+
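+ # Tracking notes: matching runs in two stages — confident matches at
+ # TRACK_IOU_HIGH first, then leftovers at TRACK_IOU_LOW — and an unmatched
+ # track survives up to TRACK_MAX_AGE frames before it is dropped
+ # (see _assign_person_track_ids below).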
+ def _iou_box4(a: Tuple[float, float, float, float], b: Tuple[float, float, float, float]) -> float:
+     ax1, ay1, ax2, ay2 = a
+     bx1, by1, bx2, by2 = b
+     ix1, iy1 = max(ax1, bx1), max(ay1, by1)
+     ix2, iy2 = min(ax2, bx2), min(ay2, by2)
+     iw, ih = max(0.0, ix2 - ix1), max(0.0, iy2 - iy1)
+     inter = iw * ih
+     if inter <= 0:
+         return 0.0
+     area_a = (ax2 - ax1) * (ay2 - ay1)
+     area_b = (bx2 - bx1) * (by2 - by1)
+     union = area_a + area_b - inter
+     return inter / union if union > 0 else 0.0
+
+
+ def _match_tracks_detections(
+     prev_list: List[Tuple[int, Tuple[float, float, float, float]]],
+     curr_boxes: List[Tuple[float, float, float, float]],
+     iou_thresh: float,
+     exclude_prev: set,
+     exclude_curr: set,
+ ) -> List[Tuple[int, int]]:
+     prev_filtered = [(pi, tid, pbox) for pi, (tid, pbox) in enumerate(prev_list) if pi not in exclude_prev]
+     curr_filtered = [(ci, cbox) for ci, cbox in enumerate(curr_boxes) if ci not in exclude_curr]
+     if not prev_filtered or not curr_filtered:
+         return []
+     n_prev, n_curr = len(prev_filtered), len(curr_filtered)
+     iou_mat = np.zeros((n_prev, n_curr), dtype=np.float64)
+     for i, (_, _, pbox) in enumerate(prev_filtered):
+         for j, (_, cbox) in enumerate(curr_filtered):
+             iou_mat[i, j] = _iou_box4(pbox, cbox)
+     cost = 1.0 - iou_mat
+     cost[iou_mat < iou_thresh] = 1e9
+     if _linear_sum_assignment is not None:
+         row_ind, col_ind = _linear_sum_assignment(cost)
+         matches = [
+             (prev_filtered[row_ind[k]][0], curr_filtered[col_ind[k]][0])
+             for k in range(len(row_ind))
+             if cost[row_ind[k], col_ind[k]] < 1.0
+         ]
+     else:
+         matches = []
+         iou_pairs = [
+             (iou_mat[i, j], i, j)
+             for i in range(n_prev)
+             for j in range(n_curr)
+             if iou_mat[i, j] >= iou_thresh
+         ]
+         iou_pairs.sort(key=lambda x: -x[0])
+         used_prev, used_curr = set(), set()
+         for _, i, j in iou_pairs:
+             pi = prev_filtered[i][0]
+             ci = curr_filtered[j][0]
+             if pi in used_prev or ci in used_curr:
+                 continue
+             matches.append((pi, ci))
+             used_prev.add(pi)
+             used_curr.add(ci)
+     return matches
+
+
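+ # _predict_box extrapolates a track one step with constant velocity: the
+ # center moves by the last observed center delta while width/height are kept.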
+ def _predict_box(prev: Tuple[float, float, float, float], last: Tuple[float, float, float, float]) -> Tuple[float, float, float, float]:
+     px1, py1, px2, py2 = prev
+     lx1, ly1, lx2, ly2 = last
+     pcx = 0.5 * (px1 + px2)
+     pcy = 0.5 * (py1 + py2)
+     lcx = 0.5 * (lx1 + lx2)
+     lcy = 0.5 * (ly1 + ly2)
+     w = lx2 - lx1
+     h = ly2 - ly1
+     ncx = 2.0 * lcx - pcx
+     ncy = 2.0 * lcy - pcy
+     return (ncx - w * 0.5, ncy - h * 0.5, ncx + w * 0.5, ncy + h * 0.5)
+
+
+ def _assign_person_track_ids(
+     prev_state: Dict[int, Tuple[Tuple[float, float, float, float], Tuple[float, float, float, float], int]],
+     next_id: int,
+     results: list,
+     iou_thresh: float = TRACK_IOU_THRESH,
+     iou_high: float = TRACK_IOU_HIGH,
+     iou_low: float = TRACK_IOU_LOW,
+     max_age: int = TRACK_MAX_AGE,
+     use_velocity: bool = TRACK_USE_VELOCITY,
+ ) -> Tuple[Dict[int, Tuple[Tuple[float, float, float, float], Tuple[float, float, float, float], int]], int, List[List[int]]]:
+     state = {tid: (prev_box, last_box, age) for tid, (prev_box, last_box, age) in prev_state.items()}
+     nid = next_id
+     ids_per_result: List[List[int]] = []
+     for result in results:
+         if getattr(result, "boxes", None) is None or len(result.boxes) == 0:
+             state = {
+                 tid: (prev_box, last_box, age + 1)
+                 for tid, (prev_box, last_box, age) in state.items()
+                 if age + 1 <= max_age
+             }
+             ids_per_result.append([])
+             continue
+         b = result.boxes
+         xyxy = b.xyxy.cpu().numpy()
+         curr_boxes = [tuple(float(x) for x in row) for row in xyxy]
+         prev_list: List[Tuple[int, Tuple[float, float, float, float]]] = []
+         for tid, (prev_box, last_box, _age) in state.items():
+             if use_velocity and (prev_box != last_box):
+                 pbox = _predict_box(prev_box, last_box)
+             else:
+                 pbox = last_box
+             prev_list.append((tid, pbox))
+         stage1 = _match_tracks_detections(prev_list, curr_boxes, iou_high, set(), set())
+         assigned_prev = {pi for pi, _ in stage1}
+         assigned_curr = {ci for _, ci in stage1}
+         stage2 = _match_tracks_detections(prev_list, curr_boxes, iou_low, assigned_prev, assigned_curr)
+         for pi, ci in stage2:
+             assigned_prev.add(pi)
+             assigned_curr.add(ci)
+         tid_per_curr: Dict[int, int] = {}
+         for pi, ci in stage1 + stage2:
+             tid_per_curr[ci] = prev_list[pi][0]
+         ids: List[int] = []
+         new_state: Dict[int, Tuple[Tuple[float, float, float, float], Tuple[float, float, float, float], int]] = {}
+         for ci, cbox in enumerate(curr_boxes):
+             if ci in tid_per_curr:
+                 tid = tid_per_curr[ci]
+                 _prev, last_box, _ = state[tid]
+                 new_state[tid] = (last_box, cbox, 0)
+             else:
+                 tid = nid
+                 nid += 1
+                 new_state[tid] = (cbox, cbox, 0)
+             ids.append(tid)
+         for pi in range(len(prev_list)):
+             if pi in assigned_prev:
+                 continue
+             tid = prev_list[pi][0]
+             prev_box, last_box, age = state[tid]
+             if age + 1 <= max_age:
+                 new_state[tid] = (prev_box, last_box, age + 1)
+         state = new_state
+         ids_per_result.append(ids)
+     return (state, nid, ids_per_result)
+
+
+ def _iou_bbox(a: "BoundingBox", b: "BoundingBox") -> float:
+     ax1, ay1, ax2, ay2 = int(a.x1), int(a.y1), int(a.x2), int(a.y2)
+     bx1, by1, bx2, by2 = int(b.x1), int(b.y1), int(b.x2), int(b.y2)
+     ix1, iy1 = max(ax1, bx1), max(ay1, by1)
+     ix2, iy2 = min(ax2, bx2), min(ay2, by2)
+     iw, ih = max(0, ix2 - ix1), max(0, iy2 - iy1)
+     inter = iw * ih
+     if inter <= 0:
+         return 0.0
+     area_a = (ax2 - ax1) * (ay2 - ay1)
+     area_b = (bx2 - bx1) * (by2 - by1)
+     union = area_a + area_b - inter
+     return inter / union if union > 0 else 0.0
+
+
+ def _adjust_boxes(
+     bboxes: List["BoundingBox"],
+     frame_width: int,
+     frame_height: int,
+     overlap_iou: float = OVERLAP_IOU,
+     do_goalkeeper_dedup: bool = True,
+     do_referee_disambiguation: bool = True,
+ ) -> List["BoundingBox"]:
+     """Overlap NMS, goalkeeper dedup, referee disambiguation (no ball)."""
+     kept: List[BoundingBox] = list(bboxes or [])
+     W, H = int(frame_width), int(frame_height)
+     cy = 0.5 * float(H)
+     if overlap_iou > 0 and len(kept) > 1:
+         non_balls = [bb for bb in kept if int(bb.cls_id) != 0]
+         if len(non_balls) > 1:
+             non_balls_sorted = sorted(non_balls, key=lambda bb: float(bb.conf), reverse=True)
+             kept_nb = []
+             for cand in non_balls_sorted:
+                 skip = False
+                 for k in kept_nb:
+                     iou = _iou_bbox(cand, k)
+                     if iou >= overlap_iou:
+                         skip = True
+                         break
+                     if (
+                         abs(int(cand.x1) - int(k.x1)) <= 3
+                         and abs(int(cand.y1) - int(k.y1)) <= 3
+                         and abs(int(cand.x2) - int(k.x2)) <= 3
+                         and abs(int(cand.y2) - int(k.y2)) <= 3
+                         and iou > 0.85
+                     ):
+                         skip = True
+                         break
+                 if not skip:
+                     kept_nb.append(cand)
+             kept = kept_nb
+     if do_goalkeeper_dedup:
+         gks = [bb for bb in kept if int(bb.cls_id) == _C_GOALKEEPER]
+         if len(gks) > 1:
+             best_gk = max(gks, key=lambda bb: float(bb.conf))
+             best_gk_conf = float(best_gk.conf)
+             deduped = []
+             for bb in kept:
+                 if int(bb.cls_id) == _C_GOALKEEPER:
+                     if float(bb.conf) < best_gk_conf or (float(bb.conf) == best_gk_conf and bb is not best_gk):
+                         deduped.append(BoundingBox(x1=bb.x1, y1=bb.y1, x2=bb.x2, y2=bb.y2, cls_id=_C_PLAYER, conf=float(bb.conf), team_id=bb.team_id, track_id=bb.track_id))
+                     else:
+                         deduped.append(bb)
+                 else:
+                     deduped.append(bb)
+             kept = deduped
+     if do_referee_disambiguation:
+         refs = [bb for bb in kept if int(bb.cls_id) == _C_REFEREE]
+         if len(refs) > 1:
+             best_ref = min(refs, key=lambda bb: (0.5 * (bb.y1 + bb.y2) - cy) ** 2)
+             kept = [bb for bb in kept if int(bb.cls_id) != _C_REFEREE or bb is best_ref]
+     return kept
+
+
+ # ── OSNet team classification (turbo_7 style) ────────────────
+
+ TEAM_1_ID = 6
+ TEAM_2_ID = 7
+ PLAYER_CLS_ID = 2
+ _OSNET_MODEL = None
+ osnet_weight_path = None
+
+ OSNET_IMAGE_SIZE = (64, 32)  # (height, width)
+ OSNET_PREPROCESS = T.Compose([
+     T.Resize(OSNET_IMAGE_SIZE),
+     T.ToTensor(),
+     T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ])
+
+
+ def _crop_upper_body(frame: ndarray, box: "BoundingBox") -> ndarray:
+     return frame[
+         max(0, box.y1):max(0, box.y2),
+         max(0, box.x1):max(0, box.x2)
+     ]
+
+
+ def _preprocess_osnet(crop: ndarray) -> torch.Tensor:
+     rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
+     pil = Image.fromarray(rgb)
+     return OSNET_PREPROCESS(pil)
+
+
+ def _filter_player_boxes(boxes: List["BoundingBox"]) -> List["BoundingBox"]:
+     return [b for b in boxes if b.cls_id == PLAYER_CLS_ID]
+
+
+ def _extract_osnet_embeddings(
+     frames: List[ndarray],
+     batch_boxes: Dict[int, List["BoundingBox"]],
+     device: str = "cuda",
+ ) -> Tuple[Optional[ndarray], Optional[List["BoundingBox"]]]:
+     global _OSNET_MODEL
+     crops = []
+     meta = []
+     sorted_frame_ids = sorted(batch_boxes.keys())
+     for idx, frame_idx in enumerate(sorted_frame_ids):
+         frame = frames[idx] if idx < len(frames) else None
+         if frame is None:
+             continue
+         boxes = batch_boxes[frame_idx]
+         players = _filter_player_boxes(boxes)
+         for box in players:
+             crop = _crop_upper_body(frame, box)
+             if crop.size == 0:
+                 continue
+             crops.append(_preprocess_osnet(crop))
+             meta.append(box)
+     if not crops:
+         return None, None
+     batch = torch.stack(crops).to(device, non_blocking=True).float()
+     use_amp = device == "cuda"
+     with torch.inference_mode():
+         with torch.amp.autocast("cuda", enabled=use_amp):
+             embeddings = _OSNET_MODEL(batch)
+     del batch
+     embeddings = embeddings.cpu().numpy()
+     return embeddings, meta
+
+
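+ # Team clustering below votes once per track: embeddings are averaged and
+ # L2-normalized per track_id before KMeans (see _aggregate_by_track and
+ # _classify_teams_batch), so a long track counts the same as a short one.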
+ def _aggregate_by_track(
+     embeddings: ndarray,
+     meta: List["BoundingBox"],
+ ) -> Tuple[ndarray, List["BoundingBox"]]:
+     track_map = defaultdict(list)
+     box_map = {}
+     for emb, box in zip(embeddings, meta):
+         key = box.track_id if box.track_id is not None else id(box)
+         track_map[key].append(emb)
+         box_map[key] = box
+     agg_embeddings = []
+     agg_boxes = []
+     for key, embs in track_map.items():
+         mean_emb = np.mean(embs, axis=0)
+         norm = np.linalg.norm(mean_emb)
+         if norm > 1e-12:
+             mean_emb /= norm
+         agg_embeddings.append(mean_emb)
+         agg_boxes.append(box_map[key])
+     return np.array(agg_embeddings), agg_boxes
+
+
+ def _update_team_ids(boxes: List["BoundingBox"], labels: ndarray) -> None:
+     for box, label in zip(boxes, labels):
+         # box.cls_id = TEAM_1_ID if label == 0 else TEAM_2_ID
+         box.team_id = 1 if label == 0 else 2
+
+
+ def _classify_teams_batch(
+     frames: List[ndarray],
+     batch_boxes: Dict[int, List["BoundingBox"]],
+     device: str = "cuda",
+ ) -> None:
+     embeddings, meta = _extract_osnet_embeddings(frames, batch_boxes, device)
+     if embeddings is None:
+         return
+     embeddings, agg_boxes = _aggregate_by_track(embeddings, meta)
+     n = len(embeddings)
+     if n == 0:
+         return
+     if n == 1:
+         agg_boxes[0].cls_id = TEAM_1_ID
+         return
+     kmeans = KMeans(n_clusters=2, n_init=2, random_state=42)
+     kmeans.fit(embeddings)
+     centroids = kmeans.cluster_centers_
+     c0, c1 = centroids[0], centroids[1]
+     norm_0 = np.linalg.norm(c0)
+     norm_1 = np.linalg.norm(c1)
+     similarity = np.dot(c0, c1) / (norm_0 * norm_1 + 1e-12)
+     if similarity > 0.95:
+         for b in agg_boxes:
+             b.cls_id = TEAM_1_ID
+         return
+     if norm_0 <= norm_1:
+         kmeans.labels_ = 1 - kmeans.labels_
+     _update_team_ids(agg_boxes, kmeans.labels_)
+
+
+ class ConvLayer(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, IN=False):
+         super().__init__()
+         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False, groups=groups)
+         self.bn = nn.InstanceNorm2d(out_channels, affine=True) if IN else nn.BatchNorm2d(out_channels)
+         self.relu = nn.ReLU()
+
+     def forward(self, x):
+         return self.relu(self.bn(self.conv(x)))
+
+
+ class Conv1x1(nn.Module):
+     def __init__(self, in_channels, out_channels, stride=1, groups=1):
+         super().__init__()
+         self.conv = nn.Conv2d(in_channels, out_channels, 1, stride=stride, padding=0, bias=False, groups=groups)
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.relu = nn.ReLU()
+
+     def forward(self, x):
+         return self.relu(self.bn(self.conv(x)))
+
+
+ class Conv1x1Linear(nn.Module):
+     def __init__(self, in_channels, out_channels, stride=1, bn=True):
+         super().__init__()
+         self.conv = nn.Conv2d(in_channels, out_channels, 1, stride=stride, padding=0, bias=False)
+         self.bn = nn.BatchNorm2d(out_channels) if bn else None
+
+     def forward(self, x):
+         x = self.conv(x)
+         return self.bn(x) if self.bn is not None else x
+
+
+ class Conv3x3(nn.Module):
+     def __init__(self, in_channels, out_channels, stride=1, groups=1):
+         super().__init__()
+         self.conv = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False, groups=groups)
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.relu = nn.ReLU()
+
+     def forward(self, x):
+         return self.relu(self.bn(self.conv(x)))
+
+
+ class LightConv3x3(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.conv1 = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)
+         self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False, groups=out_channels)
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.relu = nn.ReLU()
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.conv2(x)
+         return self.relu(self.bn(x))
+
+
+ class LightConvStream(nn.Module):
+     def __init__(self, in_channels, out_channels, depth):
+         super().__init__()
+         layers = [LightConv3x3(in_channels, out_channels)]
+         for _ in range(depth - 1):
+             layers.append(LightConv3x3(out_channels, out_channels))
+         self.layers = nn.Sequential(*layers)
+
+     def forward(self, x):
+         return self.layers(x)
+
+
+ class ChannelGate(nn.Module):
+     def __init__(self, in_channels, num_gates=None, return_gates=False, gate_activation='sigmoid', reduction=16, layer_norm=False):
+         super().__init__()
+         if num_gates is None:
+             num_gates = in_channels
+         self.return_gates = return_gates
+         self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+         self.fc1 = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1, bias=True, padding=0)
+         self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1)) if layer_norm else None
+         self.relu = nn.ReLU()
+         self.fc2 = nn.Conv2d(in_channels // reduction, num_gates, kernel_size=1, bias=True, padding=0)
+         self.gate_activation = nn.Sigmoid() if gate_activation == 'sigmoid' else nn.ReLU()
+
+     def forward(self, x):
+         input = x
+         x = self.global_avgpool(x)
+         x = self.fc1(x)
+         if self.norm1 is not None:
+             x = self.norm1(x)
+         x = self.relu(x)
+         x = self.fc2(x)
+         if self.gate_activation is not None:
+             x = self.gate_activation(x)
+         return x if self.return_gates else input * x
+
+
+ class OSBlockX1(nn.Module):
+     def __init__(self, in_channels, out_channels, IN=False, bottleneck_reduction=4):
+         super().__init__()
+         mid_channels = out_channels // bottleneck_reduction
+         self.conv1 = Conv1x1(in_channels, mid_channels)
+         self.conv2a = LightConv3x3(mid_channels, mid_channels)
+         self.conv2b = nn.Sequential(LightConv3x3(mid_channels, mid_channels), LightConv3x3(mid_channels, mid_channels))
+         self.conv2c = nn.Sequential(LightConv3x3(mid_channels, mid_channels), LightConv3x3(mid_channels, mid_channels), LightConv3x3(mid_channels, mid_channels))
+         self.conv2d = nn.Sequential(LightConv3x3(mid_channels, mid_channels), LightConv3x3(mid_channels, mid_channels), LightConv3x3(mid_channels, mid_channels), LightConv3x3(mid_channels, mid_channels))
+         self.gate = ChannelGate(mid_channels)
+         self.conv3 = Conv1x1Linear(mid_channels, out_channels)
+         self.downsample = Conv1x1Linear(in_channels, out_channels) if in_channels != out_channels else None
+         self.IN = nn.InstanceNorm2d(out_channels, affine=True) if IN else None
+
+     def forward(self, x):
+         identity = x
+         x1 = self.conv1(x)
+         x2 = self.gate(self.conv2a(x1)) + self.gate(self.conv2b(x1)) + self.gate(self.conv2c(x1)) + self.gate(self.conv2d(x1))
+         x3 = self.conv3(x2)
+         if self.downsample is not None:
+             identity = self.downsample(identity)
+         out = x3 + identity
+         if self.IN is not None:
+             out = self.IN(out)
+         return F.relu(out)
+
+
+ class OSNetX1(nn.Module):
+     def __init__(self, num_classes, blocks, layers, channels, feature_dim=512, loss='softmax', IN=False):
+         super().__init__()
+         self.loss = loss
+         self.feature_dim = feature_dim
+         self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
+         self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
+         self.conv2 = self._make_layer(blocks[0], layers[0], channels[0], channels[1], reduce_spatial_size=True, IN=IN)
+         self.conv3 = self._make_layer(blocks[1], layers[1], channels[1], channels[2], reduce_spatial_size=True)
+         self.conv4 = self._make_layer(blocks[2], layers[2], channels[2], channels[3], reduce_spatial_size=False)
+         self.conv5 = Conv1x1(channels[3], channels[3])
+         self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+         self.fc = self._construct_fc_layer(feature_dim, channels[3], dropout_p=None)
+         self.classifier = nn.Linear(self.feature_dim, num_classes)
+         self._init_params()
+
+     def _make_layer(self, block, layer, in_channels, out_channels, reduce_spatial_size, IN=False):
+         layers_list = [block(in_channels, out_channels, IN=IN)]
+         for _ in range(1, layer):
+             layers_list.append(block(out_channels, out_channels, IN=IN))
+         if reduce_spatial_size:
+             layers_list.append(nn.Sequential(Conv1x1(out_channels, out_channels), nn.AvgPool2d(2, stride=2)))
+         return nn.Sequential(*layers_list)
+
+     def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
+         if fc_dims is None or fc_dims < 0:
+             self.feature_dim = input_dim
+             return None
+         if isinstance(fc_dims, int):
+             fc_dims = [fc_dims]
+         layers_list = []
+         for dim in fc_dims:
+             layers_list.append(nn.Linear(input_dim, dim))
+             layers_list.append(nn.BatchNorm1d(dim))
+             layers_list.append(nn.ReLU(inplace=True))
+             if dropout_p is not None:
+                 layers_list.append(nn.Dropout(p=dropout_p))
+             input_dim = dim
+         self.feature_dim = fc_dims[-1]
+         return nn.Sequential(*layers_list)
+
+     def _init_params(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.BatchNorm2d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.BatchNorm1d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.InstanceNorm2d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.Linear):
+                 nn.init.normal_(m.weight, 0, 0.01)
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+
+     def forward(self, x, return_featuremaps=False):
+         x = self.conv1(x)
+         x = self.maxpool(x)
+         x = self.conv2(x)
+         x = self.conv3(x)
+         x = self.conv4(x)
+         x = self.conv5(x)
+         if return_featuremaps:
+             return x
+         v = self.global_avgpool(x)
+         v = v.view(v.size(0), -1)
+         if self.fc is not None:
+             v = self.fc(v)
+         if not self.training:
+             return v
+         y = self.classifier(v)
+         if self.loss == 'softmax':
+             return y
+         elif self.loss == 'triplet':
+             return y, v
+         raise KeyError(f"Unsupported loss: {self.loss}")
+
+
+ def osnet_x1_0(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
+     return OSNetX1(
+         num_classes,
+         blocks=[OSBlockX1, OSBlockX1, OSBlockX1],
+         layers=[2, 2, 2],
+         channels=[64, 256, 384, 512],
+         loss=loss,
+         **kwargs,
+     )
+
+
+ def load_checkpoint_osnet(fpath):
+     fpath = os.path.abspath(os.path.expanduser(fpath))
+     map_location = None if torch.cuda.is_available() else 'cpu'
+     checkpoint = torch.load(fpath, map_location=map_location, weights_only=False)
+     return checkpoint
+
+
+ def load_pretrained_weights_osnet(model, weight_path):
+     checkpoint = load_checkpoint_osnet(weight_path)
+     state_dict = checkpoint.get('state_dict', checkpoint)
+     model_dict = model.state_dict()
+     new_state_dict = OrderedDict()
+     for k, v in state_dict.items():
+         if k.startswith('module.'):
+             k = k[7:]
+         if k in model_dict and model_dict[k].size() == v.size():
+             new_state_dict[k] = v
+     model_dict.update(new_state_dict)
+     model.load_state_dict(model_dict)
+
+
+ def load_osnet(device="cuda", weight_path=None):
+     model = osnet_x1_0(num_classes=1, loss='softmax', pretrained=False)
+     weight_path = Path(weight_path) if weight_path else None
+     if weight_path and weight_path.exists():
+         load_pretrained_weights_osnet(model, str(weight_path))
+     model.eval()
+     model.to(device)
+     return model
+
+
+ def _resolve_player_cls_id(model: YOLO, fallback: int = PLAYER_CLS_ID) -> int:
+     names = getattr(model, "names", None)
+     if not names:
+         names = getattr(getattr(model, "model", None), "names", None)
+     if isinstance(names, dict):
+         for idx, name in names.items():
+             if str(name).lower() in ("player", "players"):
+                 return int(idx)
+     if isinstance(names, list):
+         for idx, name in enumerate(names):
+             if str(name).lower() in ("player", "players"):
+                 return int(idx)
+     return fallback
+
+
+ # ── HRNet architecture ───────────────────────────────────────────
+
+ BatchNorm2d = nn.BatchNorm2d
+ BN_MOMENTUM = 0.1
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super().__init__()
+         self.conv1 = conv3x3(inplanes, planes, stride)
+         self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(planes, planes)
+         self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+         out = self.relu(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         if self.downsample is not None:
+             residual = self.downsample(x)
+         return self.relu(out + residual)
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super().__init__()
+         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+         self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+         self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+         self.bn3 = BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+         out = self.relu(self.bn1(self.conv1(x)))
+         out = self.relu(self.bn2(self.conv2(out)))
+         out = self.bn3(self.conv3(out))
+         if self.downsample is not None:
+             residual = self.downsample(x)
+         return self.relu(out + residual)
+
+ blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
+
+ class HighResolutionModule(nn.Module):
+     def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
+                  num_channels, fuse_method, multi_scale_output=True):
+         super().__init__()
+         self.num_inchannels = num_inchannels
+         self.fuse_method = fuse_method
+         self.num_branches = num_branches
+         self.multi_scale_output = multi_scale_output
+         self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels)
+         self.fuse_layers = self._make_fuse_layers()
+         self.relu = nn.ReLU(inplace=True)
+
+     def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
+         downsample = None
+         if stride != 1 or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM),
+             )
+         layers = [block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)]
+         self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
+         for _ in range(1, num_blocks[branch_index]):
+             layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))
+         return nn.Sequential(*layers)
+
+     def _make_branches(self, num_branches, block, num_blocks, num_channels):
+         return nn.ModuleList([self._make_one_branch(i, block, num_blocks, num_channels) for i in range(num_branches)])
+
+     def _make_fuse_layers(self):
+         if self.num_branches == 1:
+             return None
+         num_branches = self.num_branches
+         num_inchannels = self.num_inchannels
+         fuse_layers = []
+         for i in range(num_branches if self.multi_scale_output else 1):
+             fuse_layer = []
+             for j in range(num_branches):
+                 if j > i:
+                     fuse_layer.append(nn.Sequential(
+                         nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
+                         BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
+                 elif j == i:
+                     fuse_layer.append(None)
+                 else:
+                     conv3x3s = []
+                     for k in range(i - j):
+                         if k == i - j - 1:
+                             conv3x3s.append(nn.Sequential(
+                                 nn.Conv2d(num_inchannels[j], num_inchannels[i], 3, 2, 1, bias=False),
+                                 BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
+                         else:
+                             conv3x3s.append(nn.Sequential(
+                                 nn.Conv2d(num_inchannels[j], num_inchannels[j], 3, 2, 1, bias=False),
+                                 BatchNorm2d(num_inchannels[j], momentum=BN_MOMENTUM),
+                                 nn.ReLU(inplace=True)))
+                     fuse_layer.append(nn.Sequential(*conv3x3s))
+             fuse_layers.append(nn.ModuleList(fuse_layer))
+         return nn.ModuleList(fuse_layers)
+
+     def get_num_inchannels(self):
+         return self.num_inchannels
+
+     def forward(self, x):
+         if self.num_branches == 1:
+             return [self.branches[0](x[0])]
+         for i in range(self.num_branches):
+             x[i] = self.branches[i](x[i])
+         x_fuse = []
+         for i in range(len(self.fuse_layers)):
+             y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+             for j in range(1, self.num_branches):
+                 if i == j:
+                     y = y + x[j]
+                 elif j > i:
+                     y = y + F.interpolate(self.fuse_layers[i][j](x[j]),
+                                           size=[x[i].shape[2], x[i].shape[3]], mode='bilinear')
+                 else:
+                     y = y + self.fuse_layers[i][j](x[j])
+             x_fuse.append(self.relu(y))
+         return x_fuse
+
+ class HighResolutionNet(nn.Module):
+     def __init__(self, config, lines=False, **kwargs):
+         self.inplanes = 64
+         self.lines = lines
+         extra = config['MODEL']['EXTRA']
+         super().__init__()
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
+         self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM)
+         self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
+         self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM)
+         self.relu = nn.ReLU(inplace=True)
+         self.layer1 = self._make_layer(Bottleneck, 64, 64, 4)
+
+         self.stage2_cfg = extra['STAGE2']
+         num_channels = self.stage2_cfg['NUM_CHANNELS']
+         block = blocks_dict[self.stage2_cfg['BLOCK']]
+         num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
+         self.transition1 = self._make_transition_layer([256], num_channels)
+         self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels)
+
+         self.stage3_cfg = extra['STAGE3']
+         num_channels = self.stage3_cfg['NUM_CHANNELS']
+         block = blocks_dict[self.stage3_cfg['BLOCK']]
+         num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
+         self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
+         self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels)
+
+         self.stage4_cfg = extra['STAGE4']
+         num_channels = self.stage4_cfg['NUM_CHANNELS']
+         block = blocks_dict[self.stage4_cfg['BLOCK']]
+         num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
+         self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
+         self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True)
+
+         self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
+         final_inp_channels = sum(pre_stage_channels) + self.inplanes
+         self.head = nn.Sequential(nn.Sequential(
+             nn.Conv2d(final_inp_channels, final_inp_channels, kernel_size=1),
+             BatchNorm2d(final_inp_channels, momentum=BN_MOMENTUM),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(final_inp_channels, config['MODEL']['NUM_JOINTS'], kernel_size=extra['FINAL_CONV_KERNEL']),
+             nn.Softmax(dim=1) if not self.lines else nn.Sigmoid()))
+
+     def _make_head(self, x, x_skip):
+         x = self.upsample(x)
+         x = torch.cat([x, x_skip], dim=1)
+         return self.head(x)
+
+     def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
+         num_branches_cur = len(num_channels_cur_layer)
+         num_branches_pre = len(num_channels_pre_layer)
+         transition_layers = []
+         for i in range(num_branches_cur):
+             if i < num_branches_pre:
+                 if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+                     transition_layers.append(nn.Sequential(
+                         nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False),
+                         BatchNorm2d(num_channels_cur_layer[i], momentum=BN_MOMENTUM),
+                         nn.ReLU(inplace=True)))
+                 else:
+                     transition_layers.append(None)
+             else:
+                 conv3x3s = []
+                 for j in range(i + 1 - num_branches_pre):
+                     inchannels = num_channels_pre_layer[-1]
+                     outchannels = num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels
+                     conv3x3s.append(nn.Sequential(
+                         nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
+                         BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
+                         nn.ReLU(inplace=True)))
+                 transition_layers.append(nn.Sequential(*conv3x3s))
+         return nn.ModuleList(transition_layers)
+
+     def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+         downsample = None
+         if stride != 1 or inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
+                 BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
+             )
+         layers = [block(inplanes, planes, stride, downsample)]
+         inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(block(inplanes, planes))
+         return nn.Sequential(*layers)
+
+     def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):
+         num_modules = layer_config['NUM_MODULES']
+         num_branches = layer_config['NUM_BRANCHES']
+         num_blocks = layer_config['NUM_BLOCKS']
+         num_channels = layer_config['NUM_CHANNELS']
+         block = blocks_dict[layer_config['BLOCK']]
+         fuse_method = layer_config['FUSE_METHOD']
+         modules = []
+         for i in range(num_modules):
+             reset_multi_scale_output = True if multi_scale_output or i < num_modules - 1 else False
+             modules.append(HighResolutionModule(
+                 num_branches, block, num_blocks, num_inchannels,
+                 num_channels, fuse_method, reset_multi_scale_output))
+             num_inchannels = modules[-1].get_num_inchannels()
+         return nn.Sequential(*modules), num_inchannels
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x_skip = x.clone()
+         x = self.relu(self.bn1(x))
+         x = self.relu(self.bn2(self.conv2(x)))
+         x = self.layer1(x)
+
+         x_list = []
+         for i in range(self.stage2_cfg['NUM_BRANCHES']):
+             x_list.append(self.transition1[i](x) if self.transition1[i] is not None else x)
+         y_list = self.stage2(x_list)
+
+         x_list = []
+         for i in range(self.stage3_cfg['NUM_BRANCHES']):
+             x_list.append(self.transition2[i](y_list[-1]) if self.transition2[i] is not None else y_list[i])
+         y_list = self.stage3(x_list)
+
+         x_list = []
+         for i in range(self.stage4_cfg['NUM_BRANCHES']):
+             x_list.append(self.transition3[i](y_list[-1]) if self.transition3[i] is not None else y_list[i])
+         x = self.stage4(x_list)
+
+         height, width = x[0].size(2), x[0].size(3)
+         x1 = F.interpolate(x[1], size=(height, width), mode='bilinear', align_corners=False)
+         x2 = F.interpolate(x[2], size=(height, width), mode='bilinear', align_corners=False)
+         x3 = F.interpolate(x[3], size=(height, width), mode='bilinear', align_corners=False)
+         x = torch.cat([x[0], x1, x2, x3], 1)
+         return self._make_head(x, x_skip)
+
+     def init_weights(self, pretrained=''):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+             elif isinstance(m, nn.BatchNorm2d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+         if pretrained:
+             if os.path.isfile(pretrained):
+                 pretrained_dict = torch.load(pretrained)
+                 model_dict = self.state_dict()
+                 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+                 model_dict.update(pretrained_dict)
+                 self.load_state_dict(model_dict)
+             else:
+                 sys.exit(f'Weights {pretrained} not found.')
+
+ def get_cls_net(config, pretrained='', **kwargs):
+     model = HighResolutionNet(config, **kwargs)
+     model.init_weights(pretrained)
+     return model
+
+
+ # ── Keypoint mapping & inference helpers ─────────────────────────
+
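+ # map_keypoints translates an HRNet heatmap channel (1-based) to the 1..32
+ # output keypoint index used downstream; channels absent from this dict are
+ # dropped by _apply_keypoint_mapping.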
+ map_keypoints = {
+     1: 1, 2: 14, 3: 25, 4: 2, 5: 10, 6: 18, 7: 26, 8: 3, 9: 7, 10: 23,
+     11: 27, 20: 4, 21: 8, 22: 24, 23: 28, 24: 5, 25: 13, 26: 21, 27: 29,
+     28: 6, 29: 17, 30: 30, 31: 11, 32: 15, 33: 19, 34: 12, 35: 16, 36: 20,
+     45: 9, 50: 31, 52: 32, 57: 22
+ }
+
+ # Template keypoints for homography refinement (new-5 style)
+ TEMPLATE_F0: List[Tuple[float, float]] = [
+     (5, 5), (5, 140), (5, 250), (5, 430), (5, 540), (5, 675), (55, 250), (55, 430),
+     (110, 340), (165, 140), (165, 270), (165, 410), (165, 540), (527, 5), (527, 253),
+     (527, 433), (527, 675), (888, 140), (888, 270), (888, 410), (888, 540), (940, 340),
+     (998, 250), (998, 430), (1045, 5), (1045, 140), (1045, 250), (1045, 430), (1045, 540),
+     (1045, 675), (435, 340), (615, 340),
+ ]
+ TEMPLATE_F1: List[Tuple[float, float]] = [
+     (2.5, 2.5), (2.5, 139.5), (2.5, 249.5), (2.5, 430.5), (2.5, 540.5), (2.5, 678),
+     (54.5, 249.5), (54.5, 430.5), (110.5, 340.5), (164.5, 139.5), (164.5, 269), (164.5, 411),
+     (164.5, 540.5), (525, 2.5), (525, 249.5), (525, 430.5), (525, 678), (886.5, 139.5),
+     (886.5, 269), (886.5, 411), (886.5, 540.5), (940.5, 340.5), (998, 249.5), (998, 430.5),
+     (1048, 2.5), (1048, 139.5), (1048, 249.5), (1048, 430.5), (1048, 540.5), (1048, 678),
+     (434.5, 340), (615.5, 340),
+ ]
+ HOMOGRAPHY_FILL_ONLY_VALID = True
+ KP_THRESHOLD = 0.2  # new-5 style (was 0.3)
+ # HRNet input resolution: smaller = faster; 540x960 is used here, matching the model config's IMAGE_SIZE [960, 540]
+ KP_H, KP_W = 540, 960
+ HRNET_BATCH_SIZE = 24  # larger batch = faster (if GPU mem allows)
+
+
+ def _preprocess_batch(frames):
+     target_h, target_w = KP_H, KP_W
+     batch = []
+     for frame in frames:
+         img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         img = cv2.resize(img, (target_w, target_h)).astype(np.float32) / 255.0
+         batch.append(np.transpose(img, (2, 0, 1)))
+     return torch.from_numpy(np.stack(batch)).float()
+
+
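+ # Keypoint decoding: a 3x3 max-pool keeps only local maxima per heatmap
+ # channel (a cheap NMS), then top-1 per channel yields (x, y, score) scaled
+ # back toward the KP_W x KP_H input resolution.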
1031
+ def _extract_keypoints(heatmap: torch.Tensor, scale: int = 2):
1032
+ b, c, h, w = heatmap.shape
1033
+ max_pooled = F.max_pool2d(heatmap, 3, stride=1, padding=1)
1034
+ local_maxima = (max_pooled == heatmap)
1035
+ masked = heatmap * local_maxima
1036
+ flat = masked.view(b, c, -1)
1037
+ scores, indices = torch.topk(flat, 1, dim=-1, sorted=False)
1038
+ y_coords = torch.div(indices, w, rounding_mode="floor") * scale
1039
+ x_coords = (indices % w) * scale
1040
+ return torch.stack([x_coords.float(), y_coords.float(), scores], dim=-1)
1041
+
1042
+
1043
+ def _process_keypoints(kp_coords, threshold, w, h, batch_size):
1044
+ kp_np = kp_coords.cpu().numpy()
1045
+ results = []
1046
+ for b_idx in range(batch_size):
1047
+ kp_dict = {}
1048
+ valid = np.where(kp_np[b_idx, :, 0, 2] > threshold)[0]
1049
+ for ch_idx in valid:
1050
+ kp_dict[ch_idx + 1] = {
1051
+ 'x': float(kp_np[b_idx, ch_idx, 0, 0]) / w,
1052
+ 'y': float(kp_np[b_idx, ch_idx, 0, 1]) / h,
1053
+ 'p': float(kp_np[b_idx, ch_idx, 0, 2]),
1054
+ }
1055
+ results.append(kp_dict)
1056
+ return results
1057
+
1058
+
1059
+ def _run_hrnet_batch(frames, model, threshold, batch_size=16):
1060
+ if not frames or model is None:
1061
+ return []
1062
+ device = next(model.parameters()).device
1063
+ use_amp = device.type == "cuda"
1064
+ results = []
1065
+ for i in range(0, len(frames), batch_size):
1066
+ chunk = frames[i:i + batch_size]
1067
+ batch = _preprocess_batch(chunk).to(device, non_blocking=True)
1068
+ with torch.inference_mode():
1069
+ with torch.amp.autocast("cuda", enabled=use_amp):
1070
+ heatmaps = model(batch)
1071
+ kp_coords = _extract_keypoints(heatmaps[:, :-1, :, :], scale=2)
1072
+ batch_kps = _process_keypoints(kp_coords, threshold, KP_W, KP_H, len(chunk))
1073
+ results.extend(batch_kps)
1074
+ del heatmaps, kp_coords, batch
1075
+ if results:
1076
+ gc.collect()
1077
+ return results
1078
+
1079
+
1080
+ def _apply_keypoint_mapping(kp_dict):
1081
+ return {map_keypoints[k]: v for k, v in kp_dict.items() if k in map_keypoints}
1082
+
1083
+
1084
+ def _normalize_keypoints(kp_results, frames, n_keypoints):
1085
+ keypoints = []
1086
+ max_frames = min(len(kp_results), len(frames))
1087
+ for i in range(max_frames):
1088
+ kp_dict = kp_results[i]
1089
+ h, w = frames[i].shape[:2]
1090
+ frame_kps = []
1091
+ for idx in range(n_keypoints):
1092
+ kp_idx = idx + 1
1093
+ x, y = 0, 0
1094
+ if kp_idx in kp_dict:
1095
+ d = kp_dict[kp_idx]
1096
+ if isinstance(d, dict) and 'x' in d:
1097
+ x = int(d['x'] * w)
1098
+ y = int(d['y'] * h)
1099
+ frame_kps.append((x, y))
1100
+ keypoints.append(frame_kps)
1101
+ return keypoints
1102
+
1103
+
1104
+ def _fix_keypoints(kps: list, n: int) -> list:
1105
+ if len(kps) < n:
1106
+ kps += [(0, 0)] * (n - len(kps))
1107
+ elif len(kps) > n:
1108
+ kps = kps[:n]
1109
+
1110
+ if kps[2] != (0,0) and kps[4] != (0,0) and kps[3] == (0,0):
1111
+ kps[3] = kps[4]; kps[4] = (0,0)
1112
+ if kps[0] != (0,0) and kps[4] != (0,0) and kps[1] == (0,0):
1113
+ kps[1] = kps[4]; kps[4] = (0,0)
1114
+ if kps[2] != (0,0) and kps[3] != (0,0) and kps[1] == (0,0) and kps[3][0] > kps[2][0]:
1115
+ kps[1] = kps[3]; kps[3] = (0,0)
1116
+ if kps[28] != (0,0) and kps[25] == (0,0) and kps[26] != (0,0) and kps[26][0] > kps[28][0]:
1117
+ kps[25] = kps[28]; kps[28] = (0,0)
1118
+ if kps[24] != (0,0) and kps[28] != (0,0) and kps[25] == (0,0):
1119
+ kps[25] = kps[28]; kps[28] = (0,0)
1120
+ if kps[24] != (0,0) and kps[27] != (0,0) and kps[26] == (0,0):
1121
+ kps[26] = kps[27]; kps[27] = (0,0)
1122
+ if kps[28] != (0,0) and kps[23] == (0,0) and kps[20] != (0,0) and kps[20][1] > kps[23][1]:
1123
+ kps[23] = kps[20]; kps[20] = (0,0)
1124
+ return kps
1125
+
1126
+
1127
+ def _keypoints_to_float(keypoints: list) -> List[List[float]]:
+     """Convert keypoints to [[x, y], ...] float format for homography."""
+     return [[float(x), float(y)] for x, y in keypoints]
+
+
+ def _keypoints_to_int(keypoints: list) -> List[Tuple[int, int]]:
+     """Convert keypoints to [(x, y), ...] integer format."""
+     return [(int(round(float(kp[0]))), int(round(float(kp[1])))) for kp in keypoints]
+
+
+ def _apply_homography_refinement(
+     keypoints: List[List[float]],
+     frame: np.ndarray,
+     n_keypoints: int,
+ ) -> List[List[float]]:
+     """Refine keypoints using homography from template to frame (new-5 style)."""
+     if n_keypoints != 32 or len(TEMPLATE_F0) != 32 or len(TEMPLATE_F1) != 32:
+         return keypoints
+     frame_height, frame_width = frame.shape[:2]
+     valid_src: List[Tuple[float, float]] = []
+     valid_dst: List[Tuple[float, float]] = []
+     valid_indices: List[int] = []
+     for kp_idx, kp in enumerate(keypoints):
+         if kp and len(kp) >= 2:
+             x, y = float(kp[0]), float(kp[1])
+             if not (abs(x) < 1e-6 and abs(y) < 1e-6) and 0 <= x < frame_width and 0 <= y < frame_height:
+                 valid_src.append(TEMPLATE_F1[kp_idx])
+                 valid_dst.append((x, y))
+                 valid_indices.append(kp_idx)
+     if len(valid_src) < 4:
+         return keypoints
+     src_pts = np.array(valid_src, dtype=np.float32)
+     dst_pts = np.array(valid_dst, dtype=np.float32)
+     H, _ = cv2.findHomography(src_pts, dst_pts)
+     if H is None:
+         return keypoints
+     all_template_points = np.array(TEMPLATE_F0, dtype=np.float32).reshape(-1, 1, 2)
+     adjusted_points = cv2.perspectiveTransform(all_template_points, H)
+     adjusted_points = adjusted_points.reshape(-1, 2)
+     adj_x = adjusted_points[:32, 0]
+     adj_y = adjusted_points[:32, 1]
+     valid_mask = (adj_x >= 0) & (adj_y >= 0) & (adj_x < frame_width) & (adj_y < frame_height)
+     valid_indices_set = set(valid_indices)
+     adjusted_kps: List[List[float]] = [[0.0, 0.0] for _ in range(32)]
+     for i in np.where(valid_mask)[0]:
+         if not HOMOGRAPHY_FILL_ONLY_VALID or i in valid_indices_set:
+             adjusted_kps[i] = [float(adj_x[i]), float(adj_y[i])]
+     return adjusted_kps
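+
+ # Design note: cv2.findHomography is used with its default least-squares fit
+ # over all point pairs. Passing cv2.RANSAC with a reprojection threshold, e.g.
+ # cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 3.0), would make the fit
+ # robust to a few misdetected keypoints; the default is kept here to preserve
+ # the original behavior.
+
+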
+ # ── Pydantic models ───────────────────────────────────────────────────────────
+
+ # Team assignment: 6 = team 1, 7 = team 2
+ TEAM_1_ID = 6
+ TEAM_2_ID = 7
+ PLAYER_CLS_ID = 2
+
+
+ class BoundingBox(BaseModel):
+     x1: int
+     y1: int
+     x2: int
+     y2: int
+     cls_id: int
+     conf: float
+     team_id: Optional[int] = None
+     track_id: Optional[int] = None
+
+
+ class TVFrameResult(BaseModel):
+     frame_id: int
+     boxes: list[BoundingBox]
+     keypoints: List[Tuple[int, int]]  # [(x, y), ...] integer coordinates
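+
+ # Construction sketch (assumed values, shown only to illustrate the schema):
+ #     TVFrameResult(
+ #         frame_id=0,
+ #         boxes=[BoundingBox(x1=10, y1=20, x2=50, y2=120,
+ #                            cls_id=0, conf=0.88, team_id=TEAM_1_ID, track_id=3)],
+ #         keypoints=[(0, 0)] * 32,
+ #     )
+
+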
+ def _smooth_boxes(
+     results: List[TVFrameResult],
+     window: int = BOX_SMOOTH_WINDOW,
+     tids_by_frame: Optional[Dict[int, List[Optional[int]]]] = None,
+ ) -> List[TVFrameResult]:
+     """Temporal box smoothing by track ID."""
+     if window <= 1 or not results:
+         return results
+     fid_to_idx = {r.frame_id: i for i, r in enumerate(results)}
+     trajectories: Dict[int, List[Tuple[int, int, BoundingBox]]] = {}
+     for i, r in enumerate(results):
+         for j, bb in enumerate(r.boxes):
+             tid = tids_by_frame.get(r.frame_id, [None] * len(r.boxes))[j] if tids_by_frame else bb.track_id
+             if tid is not None and tid >= 0:
+                 tid = int(tid)
+                 if tid not in trajectories:
+                     trajectories[tid] = []
+                 trajectories[tid].append((r.frame_id, j, bb))
+     smoothed: Dict[Tuple[int, int], Tuple[int, int, int, int]] = {}
+     half = window // 2
+     for tid, items in trajectories.items():
+         items.sort(key=lambda x: x[0])
+         n = len(items)
+         for k in range(n):
+             fid, box_idx, bb = items[k]
+             result_idx = fid_to_idx[fid]
+             lo = max(0, k - half)
+             hi = min(n, k + half + 1)
+             cx_list = [0.5 * (items[m][2].x1 + items[m][2].x2) for m in range(lo, hi)]
+             cy_list = [0.5 * (items[m][2].y1 + items[m][2].y2) for m in range(lo, hi)]
+             w_list = [items[m][2].x2 - items[m][2].x1 for m in range(lo, hi)]
+             h_list = [items[m][2].y2 - items[m][2].y1 for m in range(lo, hi)]
+             cx_avg = sum(cx_list) / len(cx_list)
+             cy_avg = sum(cy_list) / len(cy_list)
+             w_avg = sum(w_list) / len(w_list)
+             h_avg = sum(h_list) / len(h_list)
+             x1_new = int(round(cx_avg - w_avg / 2))
+             y1_new = int(round(cy_avg - h_avg / 2))
+             x2_new = int(round(cx_avg + w_avg / 2))
+             y2_new = int(round(cy_avg + h_avg / 2))
+             smoothed[(result_idx, box_idx)] = (x1_new, y1_new, x2_new, y2_new)
+     out: List[TVFrameResult] = []
+     for i, r in enumerate(results):
+         new_boxes: List[BoundingBox] = []
+         for j, bb in enumerate(r.boxes):
+             key = (i, j)
+             if key in smoothed:
+                 x1, y1, x2, y2 = smoothed[key]
+                 new_boxes.append(BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2, cls_id=int(bb.cls_id), conf=round(float(bb.conf), 2), team_id=bb.team_id, track_id=bb.track_id))
+             else:
+                 new_boxes.append(BoundingBox(x1=int(bb.x1), y1=int(bb.y1), x2=int(bb.x2), y2=int(bb.y2), cls_id=int(bb.cls_id), conf=round(float(bb.conf), 2), team_id=bb.team_id, track_id=bb.track_id))
+         out.append(TVFrameResult(frame_id=r.frame_id, boxes=new_boxes, keypoints=r.keypoints))
+     return out
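+
+ # Smoothing example: with window=5 (half=2), the box at position k in its
+ # track is replaced by the mean center and mean size of the boxes at positions
+ # k-2..k+2 (clamped at the track ends), damping per-frame jitter without
+ # displacing sustained motion.
+
+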
+ # ── Miner ─────────────────────────────────────────────────────────────────────
+
+ class Miner:
+     def __init__(self, path_hf_repo: Path) -> None:
+         self.path_hf_repo = Path(path_hf_repo)
+         self.is_start = False
+         self._executor = ThreadPoolExecutor(max_workers=2)
+
+         global _OSNET_MODEL, osnet_weight_path
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.device = device
+
+         # Person model: prefer ONNX (new-2 style), fall back to .pt
+         models_dir = self.path_hf_repo
+         person_onnx = models_dir / "player_detect.onnx"
+         person_pt = models_dir / "player_detect.pt"
+         self._person_model_onnx = person_onnx.exists()
+         if person_onnx.exists():
+             self.bbox_model = YOLO(str(person_onnx), task="detect")
+             print("✅ Person Model Loaded (ONNX)")
+         elif person_pt.exists():
+             self.bbox_model = YOLO(str(person_pt), task="detect")
+             print("✅ Person Model Loaded (.pt)")
+         else:
+             self.bbox_model = None
+             print("⚠️ Person model not found (tried player_detect.onnx and player_detect.pt)")
+
+         # OSNet team classifier
+         osnet_weight_path = self.path_hf_repo / "osnet_model.pth.tar-100"
+         if osnet_weight_path.exists():
+             _OSNET_MODEL = load_osnet(device, osnet_weight_path)
+             print("✅ Team Classifier Loaded (OSNet)")
+         else:
+             _OSNET_MODEL = None
+             print(f"⚠️ OSNet weights not found at {osnet_weight_path}. Using HSV fallback.")
+
+         # Keypoints model: HRNet
+         kp_config_file = "hrnetv2_w48.yaml"
+         kp_weights_file = "keypoint_detect.pt"
+         config_path = Path(kp_config_file) if Path(kp_config_file).exists() else self.path_hf_repo / kp_config_file
+         weights_path = Path(kp_weights_file) if Path(kp_weights_file).exists() else self.path_hf_repo / kp_weights_file
+         with open(config_path, 'r') as f:
+             cfg = yaml.safe_load(f)
+         hrnet = get_cls_net(cfg)
+         state = torch.load(weights_path, map_location=device, weights_only=False)
+         hrnet.load_state_dict(state)
+         hrnet.to(device).eval()
+         self.keypoints_model = hrnet
+         print("✅ HRNet Keypoints Model Loaded")
+
+         # Person detection state (new-2 style)
+         self._person_tracker_state: Dict[int, Tuple[Tuple[float, float, float, float], Tuple[float, float, float, float], int]] = {}
+         self._person_tracker_next_id = 0
+         self._track_id_to_team_votes: Dict[int, Dict[str, int]] = {}
+         self._track_id_to_class_votes: Dict[int, Dict[int, int]] = {}
+         self._prev_batch_tail_tid_counts: Dict[int, int] = {}
+
+     def reset_for_new_video(self) -> None:
+         self._person_tracker_state.clear()
+         self._person_tracker_next_id = 0
+         self._track_id_to_team_votes.clear()
+         self._track_id_to_class_votes.clear()
+         self._prev_batch_tail_tid_counts.clear()
+
+     def __repr__(self) -> str:
+         return (
+             f"BBox Model: {type(self.bbox_model).__name__}\n"
+             f"Keypoints Model: {type(self.keypoints_model).__name__}\n"
+             f"Team Clustering: OSNet + KMeans"
+         )
+
+     def _bbox_task(self, images: list[ndarray], offset: int = 0) -> list[list[BoundingBox]]:
+         """Person detection pipeline (new-2 style): tracking, class votes, OSNet teams, adjust."""
+         if not images:
+             return []
+         if self.bbox_model is None:
+             return [[] for _ in images]
+         try:
+             kw = {"imgsz": PERSON_MODEL_IMG_SIZE, "conf": PERSON_CONF, "verbose": False}
+             if PERSON_HALF and not self._person_model_onnx:
+                 try:
+                     if next(self.bbox_model.model.parameters()).is_cuda:
+                         kw["half"] = True
+                 except Exception:
+                     pass
+             batch_res = self.bbox_model(images, **kw)
+         except Exception:
+             return [[] for _ in images]
+         if not isinstance(batch_res, list):
+             batch_res = [batch_res] if batch_res is not None else []
+         self._person_tracker_state, self._person_tracker_next_id, person_track_ids = _assign_person_track_ids(
+             self._person_tracker_state, self._person_tracker_next_id, batch_res, TRACK_IOU_THRESH
+         )
+         person_res = batch_res
+
+         # Parse boxes: ONNX 0=player, 1=referee, 2=goalkeeper; .pt 0=ball(skip), 1=GK, 2=player, 3=referee
+         bboxes_by_frame: Dict[int, List[BoundingBox]] = {}
+         track_ids_by_frame: Dict[int, List[Optional[int]]] = {}
+         for i, det_p in enumerate(person_res):
+             frame_id = offset + i
+             boxes_raw: List[BoundingBox] = []
+             track_ids_raw: List[Optional[int]] = []
+             if det_p is not None and getattr(det_p, "boxes", None) is not None and len(det_p.boxes) > 0:
+                 b = det_p.boxes
+                 xyxy = b.xyxy.cpu().numpy()
+                 confs = b.conf.cpu().numpy() if b.conf is not None else np.ones(len(xyxy), dtype=np.float32)
+                 clss = b.cls.cpu().numpy().astype(int) if b.cls is not None else np.zeros(len(xyxy), dtype=np.int32)
+                 tids = person_track_ids[i] if i < len(person_track_ids) and len(person_track_ids[i]) == len(clss) else [-1] * len(clss)
+                 for (x1, y1, x2, y2), c, cf, tid in zip(xyxy, clss, confs, tids):
+                     c, tid = int(c), int(tid)
+                     x1r, y1r, x2r, y2r = int(round(x1)), int(round(y1)), int(round(x2)), int(round(y2))
+                     tid_out = tid if tid >= 0 else None
+                     if self._person_model_onnx:
+                         if c == 0:
+                             boxes_raw.append(BoundingBox(x1=x1r, y1=y1r, x2=x2r, y2=y2r, cls_id=_C_PLAYER, conf=float(cf), team_id=None, track_id=tid_out))
+                             track_ids_raw.append(tid_out)
+                         elif c == 1:
+                             boxes_raw.append(BoundingBox(x1=x1r, y1=y1r, x2=x2r, y2=y2r, cls_id=_C_REFEREE, conf=float(cf), team_id=None, track_id=tid_out))
+                             track_ids_raw.append(tid_out)
+                         elif c == 2:
+                             boxes_raw.append(BoundingBox(x1=x1r, y1=y1r, x2=x2r, y2=y2r, cls_id=_C_GOALKEEPER, conf=float(cf), team_id=None, track_id=tid_out))
+                             track_ids_raw.append(tid_out)
+                     else:
+                         if c == 0:
+                             continue
+                         internal_cls = {1: _C_GOALKEEPER, 2: _C_PLAYER, 3: _C_REFEREE}.get(c, _C_PLAYER)
+                         boxes_raw.append(BoundingBox(x1=x1r, y1=y1r, x2=x2r, y2=y2r, cls_id=internal_cls, conf=float(cf), team_id=None, track_id=tid_out))
+                         track_ids_raw.append(tid_out)
+             bboxes_by_frame[frame_id] = boxes_raw
+             track_ids_by_frame[frame_id] = track_ids_raw
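+
+         # Note (hedged): with the ONNX model, a detection with c == 2 is kept
+         # under the internal id _C_GOALKEEPER; internal ids are remapped to
+         # validator ids (0=player, 1=referee, 2=goalkeeper) via
+         # _CLS_TO_VALIDATOR only at the end of this method.
+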
+         # Noise filter: remove short tracks in tail
+         if len(images) > NOISE_TAIL_FRAMES:
+             tid_counts: Dict[int, int] = {}
+             tid_first_frame: Dict[int, int] = {}
+             for fid in range(offset, offset + len(images)):
+                 for tid in track_ids_by_frame.get(fid, []):
+                     if tid is not None and tid >= 0:
+                         t = int(tid)
+                         tid_counts[t] = tid_counts.get(t, 0) + 1
+                         if t not in tid_first_frame or fid < tid_first_frame[t]:
+                             tid_first_frame[t] = fid
+             for t, prev_count in self._prev_batch_tail_tid_counts.items():
+                 tid_counts[t] = tid_counts.get(t, 0) + prev_count
+                 if prev_count > 0:
+                     tid_first_frame[t] = offset + len(images)
+             boundary = offset + len(images) - NOISE_TAIL_FRAMES
+             noise_tids = {t for t, count in tid_counts.items() if count < NOISE_MIN_APPEARANCES and tid_first_frame.get(t, 0) < boundary}
+             for fid in range(offset, offset + len(images)):
+                 boxes = bboxes_by_frame.get(fid, [])
+                 tids = track_ids_by_frame.get(fid, [None] * len(boxes))
+                 keep = [j for j in range(len(boxes)) if tids[j] is None or int(tids[j]) not in noise_tids]
+                 bboxes_by_frame[fid] = [boxes[j] for j in keep]
+                 track_ids_by_frame[fid] = [tids[j] for j in keep]
+             tail_start = offset + len(images) - NOISE_TAIL_FRAMES
+             self._prev_batch_tail_tid_counts = {}
+             for fid in range(tail_start, offset + len(images)):
+                 for tid in track_ids_by_frame.get(fid, []):
+                     if tid is not None and tid >= 0:
+                         t = int(tid)
+                         self._prev_batch_tail_tid_counts[t] = self._prev_batch_tail_tid_counts.get(t, 0) + 1
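+
+         # Rationale (hedged reading of the thresholds above): a track first
+         # seen well before the batch tail that appears fewer than
+         # NOISE_MIN_APPEARANCES times is treated as detector noise; tracks
+         # carried over from the previous batch tail are exempt, since they may
+         # simply straddle the batch boundary.
+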
+         # Class votes: collect votes per track (skip redundant IoU stabilization)
+         for i in range(len(images)):
+             frame_id = offset + i
+             boxes_raw = bboxes_by_frame[frame_id]
+             track_ids_raw = track_ids_by_frame[frame_id]
+             for idx, bb in enumerate(boxes_raw):
+                 tid = track_ids_raw[idx] if idx < len(track_ids_raw) else bb.track_id
+                 if tid is not None and int(tid) >= 0:
+                     if tid not in self._track_id_to_class_votes:
+                         self._track_id_to_class_votes[tid] = {}
+                     self._track_id_to_class_votes[tid][int(bb.cls_id)] = self._track_id_to_class_votes[tid].get(int(bb.cls_id), 0) + 1
+
+         # Class votes: majority over track
+         for fid in range(offset, offset + len(images)):
+             new_boxes: List[BoundingBox] = []
+             tids_fid = track_ids_by_frame.get(fid, [None] * len(bboxes_by_frame[fid]))
+             for box_idx, box in enumerate(bboxes_by_frame[fid]):
+                 tid = tids_fid[box_idx] if box_idx < len(tids_fid) else None
+                 if tid is not None and tid >= 0 and tid in self._track_id_to_class_votes:
+                     votes = self._track_id_to_class_votes[tid]
+                     ref_votes = votes.get(_C_REFEREE, 0)
+                     gk_votes = votes.get(_C_GOALKEEPER, 0)
+                     if ref_votes > CLASS_VOTE_MAJORITY:
+                         majority_cls = _C_REFEREE
+                     elif gk_votes > CLASS_VOTE_MAJORITY:
+                         majority_cls = _C_GOALKEEPER
+                     else:
+                         majority_cls = max(votes.items(), key=lambda x: x[1])[0]
+                     new_boxes.append(BoundingBox(x1=box.x1, y1=box.y1, x2=box.x2, y2=box.y2, cls_id=majority_cls, conf=box.conf, team_id=None, track_id=tid))
+                 else:
+                     new_boxes.append(box)
+             bboxes_by_frame[fid] = new_boxes
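+
+         # Vote example (illustrative counts): a track with
+         # {_C_PLAYER: 40, _C_REFEREE: 9} stays a player, but once referee or
+         # goalkeeper votes exceed CLASS_VOTE_MAJORITY the whole track is
+         # relabelled, deliberately biasing toward the rarer classes.
+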
+         # Interpolate track gaps
+         if INTERP_TRACK_GAPS and len(images) > 1:
+             track_to_frames: Dict[int, List[Tuple[int, BoundingBox]]] = {}
+             for fid in range(offset, offset + len(images)):
+                 for bb, tid in zip(bboxes_by_frame[fid], track_ids_by_frame.get(fid, [])):
+                     if tid is not None and int(tid) >= 0:
+                         track_to_frames.setdefault(int(tid), []).append((fid, bb))
+             to_add: Dict[int, List[Tuple[BoundingBox, int]]] = {}
+             for t, pairs in track_to_frames.items():
+                 pairs.sort(key=lambda p: p[0])
+                 for i in range(len(pairs) - 1):
+                     f1, b1 = pairs[i]
+                     f2, b2 = pairs[i + 1]
+                     if f2 - f1 <= 1:
+                         continue
+                     for g in range(f1 + 1, f2):
+                         w = (g - f1) / (f2 - f1)
+                         interp = BoundingBox(
+                             x1=int(round((1 - w) * b1.x1 + w * b2.x1)),
+                             y1=int(round((1 - w) * b1.y1 + w * b2.y1)),
+                             x2=int(round((1 - w) * b1.x2 + w * b2.x2)),
+                             y2=int(round((1 - w) * b1.y2 + w * b2.y2)),
+                             cls_id=b2.cls_id, conf=b2.conf, team_id=b2.team_id, track_id=t
+                         )
+                         to_add.setdefault(g, []).append((interp, t))
+             for g, add_list in to_add.items():
+                 bboxes_by_frame[g] = list(bboxes_by_frame.get(g, []))
+                 track_ids_by_frame[g] = list(track_ids_by_frame.get(g, []))
+                 for interp_box, tid in add_list:
+                     bboxes_by_frame[g].append(interp_box)
+                     track_ids_by_frame[g].append(tid)
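+
+         # Interpolation example: a track observed at frames 10 and 14 gets
+         # boxes synthesized for frames 11-13; at frame 12, w = 0.5, so every
+         # coordinate is the midpoint of the two observed boxes.
+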
+         # OSNet team classification
+         try:
+             batch_boxes_for_osnet = {offset + i: bboxes_by_frame.get(offset + i, []) for i in range(len(images))}
+             _classify_teams_batch(images, batch_boxes_for_osnet, self.device)
+             for fid in batch_boxes_for_osnet:
+                 bboxes_by_frame[fid] = batch_boxes_for_osnet[fid]
+         except Exception:
+             pass
+
+         # Team votes
+         reid_team_per_frame: List[List[Optional[str]]] = []
+         for fi in range(len(images)):
+             frame_id = offset + fi
+             boxes_f = bboxes_by_frame.get(frame_id, [])
+             tids_f = track_ids_by_frame.get(frame_id, [])
+             row: List[Optional[str]] = []
+             for bi, box in enumerate(boxes_f):
+                 tid = tids_f[bi] if bi < len(tids_f) else box.track_id
+                 team_str = str(box.team_id) if box.team_id is not None else None
+                 if tid is not None and tid >= 0 and team_str:
+                     if tid not in self._track_id_to_team_votes:
+                         self._track_id_to_team_votes[tid] = {}
+                     self._track_id_to_team_votes[tid][team_str] = self._track_id_to_team_votes[tid].get(team_str, 0) + 1
+                 row.append(team_str)
+             reid_team_per_frame.append(row)
+         for fid in range(offset, offset + len(images)):
+             fi = fid - offset
+             new_boxes = []
+             tids_fid = track_ids_by_frame.get(fid, [None] * len(bboxes_by_frame[fid]))
+             for box_idx, box in enumerate(bboxes_by_frame[fid]):
+                 tid = tids_fid[box_idx] if box_idx < len(tids_fid) else box.track_id
+                 team_from_reid = reid_team_per_frame[fi][box_idx] if fi < len(reid_team_per_frame) and box_idx < len(reid_team_per_frame[fi]) else None
+                 default_team = team_from_reid or (str(box.team_id) if box.team_id is not None else None)
+                 if tid is not None and tid >= 0 and tid in self._track_id_to_team_votes and self._track_id_to_team_votes[tid]:
+                     majority_team = max(self._track_id_to_team_votes[tid].items(), key=lambda x: x[1])[0]
+                 else:
+                     majority_team = default_team
+                 team_id_out = int(majority_team) if majority_team and majority_team.isdigit() else None  # non-numeric labels fall back to None instead of raising
+                 new_boxes.append(BoundingBox(x1=box.x1, y1=box.y1, x2=box.x2, y2=box.y2, cls_id=box.cls_id, conf=box.conf, team_id=team_id_out, track_id=tid))
+             bboxes_by_frame[fid] = new_boxes
+
+         # Adjust boxes: overlap NMS, GK dedup, referee disambiguation
+         H, W = images[0].shape[:2] if images else (0, 0)
+         for fid in range(offset, offset + len(images)):
+             orig = bboxes_by_frame[fid]
+             tids = track_ids_by_frame.get(fid, [None] * len(orig))
+             adjusted = _adjust_boxes(orig, W, H, do_goalkeeper_dedup=True, do_referee_disambiguation=True)
+             adjusted_tids: List[Optional[int]] = []
+             used = set()
+             for ab in adjusted:
+                 for oi, ob in enumerate(orig):
+                     if oi in used:
+                         continue
+                     if ob.x1 == ab.x1 and ob.y1 == ab.y1 and ob.x2 == ab.x2 and ob.y2 == ab.y2:
+                         adjusted_tids.append(tids[oi] if oi < len(tids) else None)
+                         used.add(oi)
+                         break
+                 else:
+                     adjusted_tids.append(None)  # adjusted box has no exact match among the originals
+             bboxes_by_frame[fid] = adjusted
+             track_ids_by_frame[fid] = adjusted_tids  # keep track IDs aligned with the adjusted boxes
+
+         # Output: validator cls_id (0=player, 1=referee, 2=goalkeeper)
+         out: List[List[BoundingBox]] = []
+         for i in range(len(images)):
+             boxes = bboxes_by_frame.get(offset + i, [])
+             for bb in boxes:
+                 bb.cls_id = _CLS_TO_VALIDATOR.get(int(bb.cls_id), int(bb.cls_id))
+             out.append(boxes)
+         return out
+
+     def _keypoint_task(self, images: list[ndarray], n_keypoints: int) -> list[list]:
+         """HRNet keypoints (the homography refinement below is currently disabled)."""
+         if not images:
+             return []
+         if self.keypoints_model is None:
+             return [[(0, 0)] * n_keypoints for _ in images]
+         try:
+             raw_kps = _run_hrnet_batch(images, self.keypoints_model, KP_THRESHOLD, batch_size=HRNET_BATCH_SIZE)
+         except Exception:
+             return [[(0, 0)] * n_keypoints for _ in images]
+         raw_kps = [_apply_keypoint_mapping(kp) for kp in raw_kps] if raw_kps else []
+         keypoints = _normalize_keypoints(raw_kps, images, n_keypoints) if raw_kps else [[(0, 0)] * n_keypoints for _ in images]
+         keypoints = [_fix_keypoints(kps, n_keypoints) for kps in keypoints]
+         keypoints = [_keypoints_to_float(kps) for kps in keypoints]
+         # if n_keypoints == 32 and len(TEMPLATE_F0) == 32 and len(TEMPLATE_F1) == 32:
+         #     for idx in range(len(images)):
+         #         try:
+         #             keypoints[idx] = _apply_homography_refinement(keypoints[idx], images[idx], n_keypoints)
+         #         except Exception:
+         #             pass
+         # keypoints = [_keypoints_to_int(kps) for kps in keypoints]
+         return keypoints
+
+     def predict_batch(
+         self,
+         batch_images: list[ndarray],
+         offset: int,
+         n_keypoints: int,
+     ) -> list[TVFrameResult]:
+         if not self.is_start:
+             self.is_start = True
+
+         images = list(batch_images)
+         if offset == 0:
+             self.reset_for_new_video()
+             gc.collect()
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+         # Run bbox (batched YOLO) and keypoints in parallel
+         future_bbox = self._executor.submit(self._bbox_task, images, offset)
+         future_kp = self._executor.submit(self._keypoint_task, images, n_keypoints)
+         bbox_per_frame = future_bbox.result()
+         keypoints = future_kp.result()
+
+         return [
+             TVFrameResult(frame_id=offset + i, boxes=bbox_per_frame[i], keypoints=keypoints[i])
+             for i in range(len(images))
+         ]
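+
+
+ # Usage sketch (hedged; the path and dummy frame are illustrative assumptions,
+ # not part of the Turbovision contract):
+ #     miner = Miner(Path("."))  # repo root containing the model weights
+ #     frames = [np.zeros((720, 1280, 3), dtype=np.uint8)]
+ #     results = miner.predict_batch(frames, offset=0, n_keypoints=32)
+ #     print(results[0].frame_id, len(results[0].boxes), len(results[0].keypoints))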
osnet_model.pth.tar-100 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:45e1de9d329b534c16f450d99a898c516f8b237dcea471053242c2d4c76b4ace
+ size 26846063
player_detect.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:14728d02dc5248ef57eda8e336feab57adf977abf6f46d33f08f0a13183e53ea
+ size 81533225
player_detect.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:934be460f78c594cc98078027f280c23385c9897e3e761e438559b3193233b46
+ size 19209626