diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..4d140dfd2c4fad27787de95b270bc54f7374f552 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,35 +1,44 @@ -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ckpt filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text -*.lfs.* filter=lfs diff=lfs merge=lfs -text -*.mlmodel filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text -*.safetensors filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text -*.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +SV_kp.engine filter=lfs diff=lfs merge=lfs -text +osnet_model.pth.tar-100 filter=lfs diff=lfs merge=lfs -text +__pycache__/keypoint_helper_v2.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text +__pycache__/keypoint_helper_v2_optimized.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text +__pycache__/keypoint_helper_v2_optimized.cpython-312.pyc.1837368399824 filter=lfs diff=lfs merge=lfs -text +__pycache__/keypoint_helper_v2_optimized.cpython-312.pyc.2364780042192 filter=lfs diff=lfs merge=lfs -text +__pycache__/keypoint_helper_v2_optimized.cpython-312.pyc.2618992613328 filter=lfs diff=lfs merge=lfs -text +best.engine filter=lfs diff=lfs merge=lfs -text +keypoint filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c238a6d590ffcf0a05a6e9de61bb9bc3c1b66504 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +venv +outputs +outputs-keypoints +outputs-detections +*.mp4 diff --git a/20251029-detection.pt b/20251029-detection.pt new file mode 100644 index 0000000000000000000000000000000000000000..a7a78cd91115ac28a63888f4e289ca8bc1599272 --- /dev/null +++ b/20251029-detection.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8bbacfcb38e38b1b8816788e9e6e845160533719a0b87b693d58b932380d0d28 +size 152961687 diff --git a/20251029-keypoint.pt b/20251029-keypoint.pt new file mode 100644 index 0000000000000000000000000000000000000000..9c661ffe3061ab8b1cf96f91a67818cd51d7cbd2 --- /dev/null +++ b/20251029-keypoint.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6dd10dba85895c92760cdb5a99c5cfca899c68f361a66c5448f38a187280ee1f +size 6849672 diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bb022242b42a16ea5307dce57de4a3bc32007b00 --- /dev/null +++ b/README.md @@ -0,0 +1,132 @@ +🚀 Example Chute for Turbovision 🪂 + +This repository demonstrates how to deploy a Chute via the Turbovision CLI, hosted on Hugging Face Hub. It serves as a minimal example showcasing the required structure and workflow for integrating machine learning models, preprocessing, and orchestration into a reproducible Chute environment. + +## Repository Structure + +The following two files must be present (in their current locations) for a successful deployment — their content can be modified as needed: + +| File | Purpose | +|------|---------| +| `miner.py` | Defines the ML model type(s), orchestration, and all pre/postprocessing logic. | +| `config.yml` | Specifies machine configuration (e.g., GPU type, memory, environment variables). | + +Other files — e.g., model weights, utility scripts, or dependencies — are optional and can be included as needed for your model. + +> **Note**: Any required assets must be defined or contained within this repo, which is fully open-source, since all network-related operations (downloading challenge data, weights, etc.) are disabled inside the Chute. + +## Overview + +Below is a high-level diagram showing the interaction between Huggingface, Chutes and Turbovision: + +``` +┌─────────────┐ ┌──────────┐ ┌──────────────┐ +│ HuggingFace │ ───> │ Chutes │ ───> │ Turbovision │ +│ Hub │ │ .ai │ │ Validator │ +└─────────────┘ └──────────┘ └──────────────┘ +``` + +## Local Testing + +After editing the `config.yml` and `miner.py` and saving it into your Huggingface Repo, you will want to test it works locally. + +1. **Copy the template file** `scorevision/chute_template/turbovision_chute.py.j2` as a python file called `my_chute.py` and fill in the missing variables: + +```python +HF_REPO_NAME = "{{ huggingface_repository_name }}" +HF_REPO_REVISION = "{{ huggingface_repository_revision }}" +CHUTES_USERNAME = "{{ chute_username }}" +CHUTE_NAME = "{{ chute_name }}" +``` + +2. **Run the following command to build the chute locally** (Caution: there are known issues with the docker location when running this on a mac): + +```bash +chutes build my_chute:chute --local --public +``` + +3. **Run the name of the docker image just built** (i.e. `CHUTE_NAME`) and enter it: + +```bash +docker run -p 8000:8000 -e CHUTES_EXECUTION_CONTEXT=REMOTE -it /bin/bash +``` + +4. **Run the file from within the container**: + +```bash +chutes run my_chute:chute --dev --debug +``` + +5. **In another terminal, test the local endpoints** to ensure there are no bugs: + +```bash +# Health check +curl -X POST http://localhost:8000/health -d '{}' + +# Prediction test +curl -X POST http://localhost:8000/predict -d '{"url": "https://scoredata.me/2025_03_14/35ae7a/h1_0f2ca0.mp4","meta": {}}' +``` + +## Live Testing + +If you have any chute with the same name (i.e. from a previous deployment), ensure you delete that first (or you will get an error when trying to build). + +1. **List existing chutes**: + +```bash +chutes chutes list +``` + +Take note of the chute id that you wish to delete (if any): + +```bash +chutes chutes delete +``` + +2. **You should also delete its associated image**: + +```bash +chutes images list +``` + +Take note of the chute image id: + +```bash +chutes images delete +``` + +3. **Use Turbovision's CLI to build, deploy and commit on-chain**: + +```bash +sv -vv push +``` + +> **Note**: You can skip the on-chain commit using `--no-commit`. You can also specify a past huggingface revision to point to using `--revision` and/or the local files you want to upload to your huggingface repo using `--model-path`. + +4. **When completed, warm up the chute** (if its cold 🧊): + +You can confirm its status using `chutes chutes list` or `chutes chutes get ` if you already know its id. + +> **Note**: Warming up can sometimes take a while but if the chute runs without errors (should be if you've tested locally first) and there are sufficient nodes (i.e. machines) available matching the `config.yml` you specified, the chute should become hot 🔥! + +```bash +chutes warmup +``` + +5. **Test the chute's endpoints**: + +```bash +# Health check +curl -X POST https://.chutes.ai/health -d '{}' -H "Authorization: Bearer $CHUTES_API_KEY" + +# Prediction +curl -X POST https://.chutes.ai/predict -d '{"url": "https://scoredata.me/2025_03_14/35ae7a/h1_0f2ca0.mp4","meta": {}}' -H "Authorization: Bearer $CHUTES_API_KEY" +``` + +6. **Test what your chute would get on a validator**: + +This also applies any validation/integrity checks which may fail if you did not use the Turbovision CLI above to deploy the chute: + +```bash +sv -vv run-once +``` diff --git a/SV_kp.engine b/SV_kp.engine new file mode 100644 index 0000000000000000000000000000000000000000..4b069009a95b2795e4187168871c15352a785e5d --- /dev/null +++ b/SV_kp.engine @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f99452eb79e064189e2758abd20a78845a5b639fc8b9c4bc650519c83e13e8db +size 368289641 diff --git a/__pycache__/keypoint_evaluation.cpython-312.pyc b/__pycache__/keypoint_evaluation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2522d905b47cee1bdcf92b4e7588d7b58e694d1c Binary files /dev/null and b/__pycache__/keypoint_evaluation.cpython-312.pyc differ diff --git a/__pycache__/keypoint_helper.cpython-312.pyc b/__pycache__/keypoint_helper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4264a8d248d6a47204eb2ddee303ce80e2f9a4f1 Binary files /dev/null and b/__pycache__/keypoint_helper.cpython-312.pyc differ diff --git a/__pycache__/keypoint_helper_v2.cpython-312.pyc b/__pycache__/keypoint_helper_v2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73ae9fea02dd33dabde0df11705496d42987b2c3 --- /dev/null +++ b/__pycache__/keypoint_helper_v2.cpython-312.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6c301a3602090dab908225fea736b9f6211ecd4c1733cfda40dd3e274e67ed7 +size 119899 diff --git a/__pycache__/keypoint_helper_v2.cpython-312.pyc.2609775282608 b/__pycache__/keypoint_helper_v2.cpython-312.pyc.2609775282608 new file mode 100644 index 0000000000000000000000000000000000000000..0b4a8ce3f32d9e9d7756db0594bf6db2edc43e40 Binary files /dev/null and b/__pycache__/keypoint_helper_v2.cpython-312.pyc.2609775282608 differ diff --git a/__pycache__/keypoint_helper_v2_optimized.cpython-312.pyc b/__pycache__/keypoint_helper_v2_optimized.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef8e8b9b1efd393e73a5b2538d3b14390de61be1 --- /dev/null +++ b/__pycache__/keypoint_helper_v2_optimized.cpython-312.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d60d623ae2f0ce1ba3cfc4e42058914cb6acccaaf082ec098556a55c69ec99a2 +size 135087 diff --git a/__pycache__/keypoint_helper_v2_optimized.cpython-312.pyc.1837368399824 b/__pycache__/keypoint_helper_v2_optimized.cpython-312.pyc.1837368399824 new file mode 100644 index 0000000000000000000000000000000000000000..c93a86060b4f29a3e05d72da232414209bdfbf52 --- /dev/null +++ b/__pycache__/keypoint_helper_v2_optimized.cpython-312.pyc.1837368399824 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:07b377a3473645ff304d0070a9fc71891639557d2f2b4f19ffb0fc108bdc2666 +size 134432 diff --git a/__pycache__/keypoint_helper_v2_optimized.cpython-312.pyc.2364780042192 b/__pycache__/keypoint_helper_v2_optimized.cpython-312.pyc.2364780042192 new file mode 100644 index 0000000000000000000000000000000000000000..ae350f327d3d815e6676e0888064a21ef3370ae0 --- /dev/null +++ b/__pycache__/keypoint_helper_v2_optimized.cpython-312.pyc.2364780042192 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:34253c900b87954bd1f34881e33d9d8cf2fba247b4a65f17cd21673ba837d94d +size 133125 diff --git a/__pycache__/keypoint_helper_v2_optimized.cpython-312.pyc.2618992613328 b/__pycache__/keypoint_helper_v2_optimized.cpython-312.pyc.2618992613328 new file mode 100644 index 0000000000000000000000000000000000000000..fb2f24e1cd7f1a83ef972f66010f082b5944124b --- /dev/null +++ b/__pycache__/keypoint_helper_v2_optimized.cpython-312.pyc.2618992613328 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:86d3822ba6714e8dd6300f6d6e034c5c69191dca702caa4837326978c503fa0e +size 133215 diff --git a/__pycache__/miner.cpython-312.pyc b/__pycache__/miner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e1f294c5ac3895306ba26907ac0381c75b00c76 Binary files /dev/null and b/__pycache__/miner.cpython-312.pyc differ diff --git a/__pycache__/miner.cpython-312.pyc.2050184619568 b/__pycache__/miner.cpython-312.pyc.2050184619568 new file mode 100644 index 0000000000000000000000000000000000000000..29c9ec37f2453887174237613042f589aacda8e9 Binary files /dev/null and b/__pycache__/miner.cpython-312.pyc.2050184619568 differ diff --git a/__pycache__/miner.cpython-312.pyc.2701627401776 b/__pycache__/miner.cpython-312.pyc.2701627401776 new file mode 100644 index 0000000000000000000000000000000000000000..07d4ca730a6707313f32c4976ab82a0d52efe65d Binary files /dev/null and b/__pycache__/miner.cpython-312.pyc.2701627401776 differ diff --git a/__pycache__/miner1.cpython-312.pyc b/__pycache__/miner1.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5175f7b2e34eaf98568f52c7130e15374fa371e3 Binary files /dev/null and b/__pycache__/miner1.cpython-312.pyc differ diff --git a/__pycache__/miner2.cpython-312.pyc b/__pycache__/miner2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9120622ae64b693a2942cc98ba9d070dbb7ae5a0 Binary files /dev/null and b/__pycache__/miner2.cpython-312.pyc differ diff --git a/__pycache__/miner3.cpython-312.pyc b/__pycache__/miner3.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54e427faf2957d8f5665ce52789775e91e38419f Binary files /dev/null and b/__pycache__/miner3.cpython-312.pyc differ diff --git a/__pycache__/pitch.cpython-312.pyc b/__pycache__/pitch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cfaca49e4aa616ded9a5c74adf90f0df32f4f68 Binary files /dev/null and b/__pycache__/pitch.cpython-312.pyc differ diff --git a/__pycache__/test_predict_batch.cpython-312.pyc b/__pycache__/test_predict_batch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..107049c02141b836bac3eff98a2044dd022d157b Binary files /dev/null and b/__pycache__/test_predict_batch.cpython-312.pyc differ diff --git a/best.engine b/best.engine new file mode 100644 index 0000000000000000000000000000000000000000..4a507b5097f355a12dd2ba4d8fca47adcb8c8b9c --- /dev/null +++ b/best.engine @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d93cf7017bf7190a24f29e48548493aaf8ebd8f96a8257ebb8a0f42bd266e7b +size 9167745 diff --git a/best.onnx b/best.onnx new file mode 100644 index 0000000000000000000000000000000000000000..01d27dbdd561aa359d29d7ce7aad62fdf01ebc10 --- /dev/null +++ b/best.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f45602c5c3f13822c4bdf35d06b505dc4a47c94a14ed60943ccc61c6992433f +size 5908859 diff --git a/best.pt b/best.pt new file mode 100644 index 0000000000000000000000000000000000000000..086c8d8527410ac18d8b911f71ac21c879a720fb --- /dev/null +++ b/best.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce387539acef635b248dc99b1e34e24993de604db59aa5dfd3c6f8c696cac003 +size 5433178 diff --git a/config.yml b/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..328de3ff86a593dd74b9e0a2b0c002927977f19c --- /dev/null +++ b/config.yml @@ -0,0 +1,24 @@ +Image: + from_base: parachutes/python:3.12 + run_command: + - pip install --upgrade setuptools wheel + - pip install ultralytics==8.3.222 opencv-python-headless numpy pydantic + - pip install scikit-learn + - pip install onnxruntime-gpu + set_workdir: /app + +NodeSelector: + gpu_count: 1 + min_vram_gb_per_gpu: 16 + exclude: + - "5090" + - b200 + - h200 + - mi300x + +Chute: + timeout_seconds: 900 + concurrency: 4 + max_instances: 5 + scaling_threshold: 0.5 + shutdown_after_seconds: 3600 \ No newline at end of file diff --git a/detection.onnx b/detection.onnx new file mode 100644 index 0000000000000000000000000000000000000000..40107c76ac96cb10191563cfce28d67e800149eb --- /dev/null +++ b/detection.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b51470cb703f5a9a789df38674b67d4bbe7f8f31846d69dbc97ce484f790cf9 +size 10245169 diff --git a/detection.pt b/detection.pt new file mode 100644 index 0000000000000000000000000000000000000000..89b67e0d451a96f5a0f84e3a7996038253c5330b --- /dev/null +++ b/detection.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ad3e89b658d2626c34174f6799d240ffd37cfe45752c0ce6ef73b05935042e0 +size 52014742 diff --git a/evaluate_from_url.py b/evaluate_from_url.py new file mode 100644 index 0000000000000000000000000000000000000000..b2b00bf1edddf4c63ae3066cd6ec50f46c9e7812 --- /dev/null +++ b/evaluate_from_url.py @@ -0,0 +1,286 @@ +import argparse +import json +import tempfile +from pathlib import Path +from typing import List, Tuple, Dict +import urllib.request +import urllib.parse +import urllib.error + +import cv2 +import numpy as np + +from miner1 import TVFrameResult, BoundingBox +from keypoint_evaluation import ( + load_template_from_file, +) +from test_predict_batch import ( + evaluate_keypoints_batch, + visualize_keypoint_evaluation, +) + + +def fetch_json_data(url: str) -> dict: + """Fetch JSON data from URL.""" + print(f"Fetching data from {url}...") + + # Create a request with headers to avoid 403 errors + req = urllib.request.Request(url) + req.add_header('User-Agent', 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36') + req.add_header('Accept', 'application/json, text/plain, */*') + req.add_header('Accept-Language', 'en-US,en;q=0.9') + + try: + with urllib.request.urlopen(req) as response: + data = json.loads(response.read().decode('utf-8')) + predictions = data.get('predictions', {}) + frames_list = predictions.get('frames', []) + print(f"Successfully fetched data with {len(frames_list)} frames") + return data + except urllib.error.HTTPError as e: + print(f"HTTP Error {e.code}: {e.reason}") + if e.code == 403: + print("403 Forbidden: The server is blocking the request. This might require authentication or different headers.") + raise + except urllib.error.URLError as e: + print(f"URL Error: {e.reason}") + raise + + +def download_video(video_url: str, output_path: Path) -> Path: + """Download video from URL to local file.""" + print(f"Downloading video from {video_url}...") + output_path.parent.mkdir(parents=True, exist_ok=True) + urllib.request.urlretrieve(video_url, str(output_path)) + print(f"Video downloaded to {output_path}") + return output_path + + +def extract_frames_from_video(video_path: Path, frame_ids: List[int] = None) -> Dict[int, np.ndarray]: + """Extract frames from video, optionally only specific frame IDs.""" + print(f"Extracting frames from {video_path}...") + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + raise RuntimeError(f"Unable to open video: {video_path}") + + frames = {} + frame_count = 0 + + while True: + ret, frame = cap.read() + if not ret: + break + + if frame_ids is None or frame_count in frame_ids: + frames[frame_count] = frame + + frame_count += 1 + + cap.release() + print(f"Extracted {len(frames)} frames from video") + return frames + + +def convert_keypoints_format(json_keypoints: List[List[int]]) -> List[Tuple[int, int]]: + """Convert keypoints from JSON format [[x,y], [x,y], ...] to List[Tuple[int, int]].""" + return [(int(kp[0]), int(kp[1])) for kp in json_keypoints] + + +def convert_json_to_tvframe_results( + json_data: dict, + frames: Dict[int, np.ndarray], +) -> List[TVFrameResult]: + """ + Convert JSON data to TVFrameResult objects. + + Args: + json_data: JSON data containing predictions with frames, boxes, and keypoints + frames: Dictionary mapping frame_id to frame image + + Returns: + List of TVFrameResult objects + """ + predictions = json_data.get('predictions', {}) + frames_data = predictions.get('frames', []) + + results = [] + for frame_data in frames_data: + frame_id = frame_data.get('frame_id') + if frame_id not in frames: + print(f"Warning: Frame {frame_id} not found in extracted frames, skipping") + continue + + # Convert boxes + json_boxes = frame_data.get('boxes', []) + boxes = [] + for box_data in json_boxes: + box = BoundingBox( + x1=int(box_data.get('x1', 0)), + y1=int(box_data.get('y1', 0)), + x2=int(box_data.get('x2', 0)), + y2=int(box_data.get('y2', 0)), + cls_id=int(box_data.get('cls_id', 0)), + conf=float(box_data.get('conf', 0.0)), + ) + boxes.append(box) + + # Convert keypoints + json_keypoints = frame_data.get('keypoints', []) + keypoints = convert_keypoints_format(json_keypoints) + + result = TVFrameResult( + frame_id=frame_id, + boxes=boxes, + keypoints=keypoints, + ) + results.append(result) + + return results + + +def evaluate_keypoints_from_json( + json_data: dict, + frames: Dict[int, np.ndarray], + template_image: np.ndarray, + template_keypoints: List[Tuple[int, int]], + visualization_output_dir: Path = None, +) -> Dict[str, float]: + """ + Evaluate keypoint accuracy from JSON data using the same function as test_predict_batch.py. + + Args: + json_data: JSON data containing predictions with frames and keypoints + frames: Dictionary mapping frame_id to frame image + template_image: Template image for evaluation + template_keypoints: Template keypoints + visualization_output_dir: Optional directory to save visualization images + + Returns: + Dictionary with keypoint evaluation statistics + """ + # Convert JSON data to TVFrameResult objects + results = convert_json_to_tvframe_results(json_data, frames) + + if len(results) == 0: + print("No valid frames found in JSON data") + return { + "keypoint_avg_score": 0.0, + "keypoint_valid_frames": 0, + "keypoint_total_frames": 0, + } + + print(f"Evaluating {len(results)} frames using evaluate_keypoints_batch...") + + # Use the same evaluation function as test_predict_batch.py + stats = evaluate_keypoints_batch( + results=results, + original_frames=frames, + template_image=template_image, + template_keypoints=template_keypoints, + visualization_output_dir=visualization_output_dir, + ) + + print("\n=== Keypoint Evaluation Results ===") + print(f"Total frames: {stats['keypoint_total_frames']}") + print(f"Valid frames: {stats['keypoint_valid_frames']}") + print(f"Average score: {stats['keypoint_avg_score']:.3f}") + print(f"Max score: {stats['keypoint_max_score']:.3f}") + print(f"Min score: {stats['keypoint_min_score']:.3f}") + print(f"Frames with score > 0.5: {stats['keypoint_frames_above_0.5']}") + print(f"Frames with score > 0.7: {stats['keypoint_frames_above_0.7']}") + + return stats + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Fetch video and keypoint data from URL, evaluate keypoints, and visualize results." + ) + parser.add_argument( + "--url", + type=str, + default="https://pub-7b4130b6af75472f800371248bca15b6.r2.dev/scorevision/results_soccer/5Fnhz5fDihvno4DfssfRogL84VFvdDRRsgu19grbqEDPbJGv/responses/007115302-f9bd4226d1f4248c782a3179764e3203ce2fc520642eed4f7b02c40e61db55eb.json", + help="URL to fetch JSON data containing video_url and predictions.", + ) + parser.add_argument( + "--template-image", + type=Path, + default='football_pitch_template.png', + help="Path to football pitch template image.", + ) + parser.add_argument( + "--output-dir", + type=Path, + default='outputs/url_evaluation', + help="Directory to save visualizations and downloaded video.", + ) + parser.add_argument( + "--delete-video", + action="store_true", + help="Delete downloaded video file after processing (default: keep video).", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + # Create output directory + args.output_dir.mkdir(parents=True, exist_ok=True) + + # Fetch JSON data + json_data = fetch_json_data(args.url) + + # Get video URL + video_url = json_data.get('video_url') + if not video_url: + raise ValueError("No video_url found in JSON data") + + # Download video + video_filename = Path(urllib.parse.urlparse(video_url).path).name + if not video_filename: + video_filename = "video.mp4" + video_path = args.output_dir / video_filename + + download_video(video_url, video_path) + + # Get video filename without extension for folder naming + video_name_without_ext = Path(video_filename).stem + + # Get frame IDs from JSON + predictions = json_data.get('predictions', {}) + frames_data = predictions.get('frames', []) + frame_ids = [frame_data.get('frame_id') for frame_data in frames_data] + + # Extract frames from video + frames = extract_frames_from_video(video_path, frame_ids=frame_ids if frame_ids else None) + + # Load template + template_image, template_keypoints = load_template_from_file(str(args.template_image)) + + # Create visualization directory with video filename + visualization_dir = args.output_dir / f"visualizations_{video_name_without_ext}" + + # Evaluate keypoints + stats = evaluate_keypoints_from_json( + json_data=json_data, + frames=frames, + template_image=template_image, + template_keypoints=template_keypoints, + visualization_output_dir=visualization_dir, + ) + + # Clean up video if requested + if args.delete_video: + video_path.unlink() + print(f"Deleted video file: {video_path}") + else: + print(f"Video saved at: {video_path}") + + print(f"\nResults saved to: {args.output_dir}") + print(f"Visualizations saved to: {visualization_dir}") + + +if __name__ == "__main__": + main() + diff --git a/football_object_detection.pt b/football_object_detection.pt new file mode 100644 index 0000000000000000000000000000000000000000..a7a78cd91115ac28a63888f4e289ca8bc1599272 --- /dev/null +++ b/football_object_detection.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8bbacfcb38e38b1b8816788e9e6e845160533719a0b87b693d58b932380d0d28 +size 152961687 diff --git a/football_pitch_template.png b/football_pitch_template.png new file mode 100644 index 0000000000000000000000000000000000000000..9b6144e801e29d1b6c59b2d548644037161d70fc Binary files /dev/null and b/football_pitch_template.png differ diff --git a/hrnetv2_w48.yaml b/hrnetv2_w48.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ddde65c80a0d419363ed902c421e0edc6431c1a0 --- /dev/null +++ b/hrnetv2_w48.yaml @@ -0,0 +1,35 @@ +MODEL: + IMAGE_SIZE: [960, 540] + NUM_JOINTS: 58 + PRETRAIN: '' + EXTRA: + FINAL_CONV_KERNEL: 1 + STAGE1: + NUM_MODULES: 1 + NUM_BRANCHES: 1 + BLOCK: BOTTLENECK + NUM_BLOCKS: [4] + NUM_CHANNELS: [64] + FUSE_METHOD: SUM + STAGE2: + NUM_MODULES: 1 + NUM_BRANCHES: 2 + BLOCK: BASIC + NUM_BLOCKS: [4, 4] + NUM_CHANNELS: [48, 96] + FUSE_METHOD: SUM + STAGE3: + NUM_MODULES: 4 + NUM_BRANCHES: 3 + BLOCK: BASIC + NUM_BLOCKS: [4, 4, 4] + NUM_CHANNELS: [48, 96, 192] + FUSE_METHOD: SUM + STAGE4: + NUM_MODULES: 3 + NUM_BRANCHES: 4 + BLOCK: BASIC + NUM_BLOCKS: [4, 4, 4, 4] + NUM_CHANNELS: [48, 96, 192, 384] + FUSE_METHOD: SUM + diff --git a/inspect_yolo_model.py b/inspect_yolo_model.py new file mode 100644 index 0000000000000000000000000000000000000000..24ec6d8a82931adf49e091c2f189783be6a492c2 --- /dev/null +++ b/inspect_yolo_model.py @@ -0,0 +1,155 @@ +""" +Script to inspect a YOLO .pt model and determine its variant (nano, small, medium, large, xlarge). +""" +import argparse +from pathlib import Path +import torch +from ultralytics import YOLO + + +def inspect_yolo_model(model_path: Path): + """Inspect YOLO model to determine variant and architecture details.""" + print(f"Inspecting model: {model_path}") + print("=" * 60) + + # Method 1: Load with Ultralytics and check metadata + try: + model = YOLO(str(model_path)) + + # Check model info + print("\n--- Model Information ---") + print(f"Model type: {type(model.model)}") + + # Try to get model name from metadata + if hasattr(model, 'model') and hasattr(model.model, 'yaml'): + yaml_path = model.model.yaml + print(f"YAML config: {yaml_path}") + if yaml_path: + # Extract variant from yaml path + yaml_name = Path(yaml_path).stem if isinstance(yaml_path, (str, Path)) else str(yaml_path) + print(f"YAML name: {yaml_name}") + # Common patterns: yolo11n.yaml, yolo11s.yaml, yolo11m.yaml, yolo11l.yaml, yolo11x.yaml + # or yolov8n.yaml, yolov8s.yaml, etc. + if 'n' in yaml_name.lower(): + variant = "Nano (n)" + elif 's' in yaml_name.lower(): + variant = "Small (s)" + elif 'm' in yaml_name.lower(): + variant = "Medium (m)" + elif 'l' in yaml_name.lower(): + variant = "Large (l)" + elif 'x' in yaml_name.lower(): + variant = "XLarge (x)" + else: + variant = "Unknown" + print(f"Detected variant: {variant}") + + # Check model metadata if available + if hasattr(model.model, 'names'): + print(f"Number of classes: {len(model.model.names)}") + print(f"Class names: {list(model.model.names.values())[:5]}...") # Show first 5 + + # Get model info summary + print("\n--- Model Summary ---") + try: + info = model.info(verbose=False) + print(info) + except: + pass + + # Count parameters + if hasattr(model.model, 'parameters'): + total_params = sum(p.numel() for p in model.model.parameters()) + trainable_params = sum(p.numel() for p in model.model.parameters() if p.requires_grad) + print(f"\n--- Parameter Count ---") + print(f"Total parameters: {total_params:,}") + print(f"Trainable parameters: {trainable_params:,}") + + # Rough estimates for YOLO variants (these vary by version but give a ballpark) + if total_params < 3_000_000: + size_estimate = "Nano (n) - typically < 3M params" + elif total_params < 12_000_000: + size_estimate = "Small (s) - typically 3-12M params" + elif total_params < 26_000_000: + size_estimate = "Medium (m) - typically 12-26M params" + elif total_params < 44_000_000: + size_estimate = "Large (l) - typically 26-44M params" + else: + size_estimate = "XLarge (x) - typically > 44M params" + print(f"Size estimate: {size_estimate}") + + except Exception as e: + print(f"Error loading with Ultralytics: {e}") + print("\nTrying alternative method...") + + # Method 2: Direct PyTorch inspection + print("\n" + "=" * 60) + print("--- Direct PyTorch Inspection ---") + try: + checkpoint = torch.load(str(model_path), map_location='cpu') + + # Check for metadata + if 'model' in checkpoint: + model_dict = checkpoint['model'] + if isinstance(model_dict, dict): + # Look for architecture hints in state dict keys + print("Checking state dict keys for architecture hints...") + keys = list(model_dict.keys())[:10] # First 10 keys + for key in keys: + print(f" {key}") + + # Count layers + layer_count = len([k for k in model_dict.keys() if 'weight' in k or 'bias' in k]) + print(f"\nTotal weight/bias tensors: {layer_count}") + + # Check checkpoint metadata + if 'epoch' in checkpoint: + print(f"Training epoch: {checkpoint.get('epoch', 'N/A')}") + if 'best_fitness' in checkpoint: + print(f"Best fitness: {checkpoint.get('best_fitness', 'N/A')}") + + # File size + file_size_mb = model_path.stat().st_size / (1024 * 1024) + print(f"\nModel file size: {file_size_mb:.2f} MB") + + # Rough size estimates based on file size (very approximate) + if file_size_mb < 6: + size_estimate = "Likely Nano (n) - file < 6MB" + elif file_size_mb < 22: + size_estimate = "Likely Small (s) - file 6-22MB" + elif file_size_mb < 50: + size_estimate = "Likely Medium (m) - file 22-50MB" + elif file_size_mb < 85: + size_estimate = "Likely Large (l) - file 50-85MB" + else: + size_estimate = "Likely XLarge (x) - file > 85MB" + print(f"Size estimate from file: {size_estimate}") + + except Exception as e: + print(f"Error with direct PyTorch inspection: {e}") + + print("\n" + "=" * 60) + print("Inspection complete!") + + +def main(): + parser = argparse.ArgumentParser( + description="Inspect YOLO .pt model to determine variant" + ) + parser.add_argument( + "--model_path", + type=Path, + help="Path to YOLO .pt model file" + ) + args = parser.parse_args() + + if not args.model_path.exists(): + print(f"Error: Model file not found: {args.model_path}") + return + + inspect_yolo_model(args.model_path) + + +if __name__ == "__main__": + main() + diff --git a/keypoint b/keypoint new file mode 100644 index 0000000000000000000000000000000000000000..6e4f4b9786f99713afc974dd167b79d2a43e052d --- /dev/null +++ b/keypoint @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ea78fa76aaf94976a8eca428d6e3c59697a93430cba1a4603e20284b61f5113 +size 264964645 diff --git a/keypoint.pt b/keypoint.pt new file mode 100644 index 0000000000000000000000000000000000000000..9c661ffe3061ab8b1cf96f91a67818cd51d7cbd2 --- /dev/null +++ b/keypoint.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6dd10dba85895c92760cdb5a99c5cfca899c68f361a66c5448f38a187280ee1f +size 6849672 diff --git a/keypoint_evaluation.py b/keypoint_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..e1a11d3ab1673572e1428ce34bec248dc53671bc --- /dev/null +++ b/keypoint_evaluation.py @@ -0,0 +1,956 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +from typing import List, Tuple, Optional +from pathlib import Path +import numpy as np +from numpy import extract, ndarray, array, float32, uint8 +import copy + +import cv2 + +# Try to import PyTorch for GPU-accelerated warping +try: + import torch + import torch.nn.functional as F + TORCH_AVAILABLE = True +except ImportError: + TORCH_AVAILABLE = False + torch = None + F = None + +# Import cv2 functions +bitwise_and = cv2.bitwise_and +findHomography = cv2.findHomography +warpPerspective = cv2.warpPerspective +cvtColor = cv2.cvtColor +COLOR_BGR2GRAY = cv2.COLOR_BGR2GRAY +threshold = cv2.threshold +THRESH_BINARY = cv2.THRESH_BINARY +getStructuringElement = cv2.getStructuringElement +MORPH_RECT = cv2.MORPH_RECT +MORPH_TOPHAT = cv2.MORPH_TOPHAT +GaussianBlur = cv2.GaussianBlur +morphologyEx = cv2.morphologyEx +Canny = cv2.Canny +connectedComponents = cv2.connectedComponents +perspectiveTransform = cv2.perspectiveTransform +RETR_EXTERNAL = cv2.RETR_EXTERNAL +CHAIN_APPROX_SIMPLE = cv2.CHAIN_APPROX_SIMPLE +findContours = cv2.findContours +boundingRect = cv2.boundingRect +dilate = cv2.dilate + +logger = logging.getLogger(__name__) + +# Template keypoints constant - define your keypoints here +# Format: List of (x, y) tuples representing keypoint coordinates on the template image +TEMPLATE_KEYPOINTS: list[tuple[int, int]] = [ + (5, 5), # 1 + (5, 140), # 2 + (5, 250), # 3 + (5, 430), # 4 + (5, 540), # 5 + (5, 675), # 6 + # ------------- + (55, 250), # 7 + (55, 430), # 8 + # ------------- + (110, 340), # 9 + # ------------- + (165, 140), # 10 + (165, 270), # 11 + (165, 410), # 12 + (165, 540), # 13 + # ------------- + (527, 5), # 14 + (527, 253), # 15 + (527, 433), # 16 + (527, 675), # 17 + # ------------- + (888, 140), # 18 + (888, 270), # 19 + (888, 410), # 20 + (888, 540), # 21 + # ------------- + (940, 340), # 22 + # ------------- + (998, 250), # 23 + (998, 430), # 24 + # ------------- + (1045, 5), # 25 + (1045, 140), # 26 + (1045, 250), # 27 + (1045, 430), # 28 + (1045, 540), # 29 + (1045, 675), # 30 + # ------------- + (435, 340), # 31 + (615, 340), # 32 +] + +INDEX_KEYPOINT_CORNER_BOTTOM_LEFT = 5 +INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT = 29 +INDEX_KEYPOINT_CORNER_TOP_LEFT = 0 +INDEX_KEYPOINT_CORNER_TOP_RIGHT = 24 + + +class InvalidMask(Exception): + """Exception raised when mask validation fails.""" + pass + + +def has_a_wide_line(mask: ndarray, max_aspect_ratio: float = 1.0) -> bool: + contours, _ = findContours(mask, RETR_EXTERNAL, CHAIN_APPROX_SIMPLE) + for cnt in contours: + x, y, w, h = boundingRect(cnt) + aspect_ratio = min(w, h) / max(w, h) + # print(f"Aspect ratio: {aspect_ratio}, width: {w}, height: {h}") + if aspect_ratio >= max_aspect_ratio: + return True + return False + + +def is_bowtie(points: ndarray) -> bool: + def segments_intersect(p1: int, p2: int, q1: int, q2: int) -> bool: + def ccw(a: int, b: int, c: int): + return (c[1] - a[1]) * (b[0] - a[0]) > (b[1] - a[1]) * (c[0] - a[0]) + + return (ccw(p1, q1, q2) != ccw(p2, q1, q2)) and ( + ccw(p1, p2, q1) != ccw(p1, p2, q2) + ) + + pts = points.reshape(-1, 2) + edges = [(pts[0], pts[1]), (pts[1], pts[2]), (pts[2], pts[3]), (pts[3], pts[0])] + return segments_intersect(*edges[0], *edges[2]) or segments_intersect( + *edges[1], *edges[3] + ) + +def validate_mask_lines(mask: ndarray) -> None: + if mask.sum() == 0: + raise InvalidMask("No projected lines") + if mask.sum() == mask.size: + raise InvalidMask("Projected lines cover the entire image surface") + if has_a_wide_line(mask=mask): + raise InvalidMask("A projected line is too wide") + + +def validate_mask_ground(mask: ndarray) -> None: + num_labels, _ = connectedComponents(mask) + num_distinct_regions = num_labels - 1 + if num_distinct_regions > 1: + raise InvalidMask( + f"Projected ground should be a single object, detected {num_distinct_regions}" + ) + area_covered = mask.sum() / mask.size + if area_covered >= 0.9: + raise InvalidMask( + f"Projected ground covers more than {area_covered:.2f}% of the image surface which is unrealistic" + ) + + +def validate_projected_corners( + source_keypoints: list[tuple[int, int]], homography_matrix: ndarray +) -> None: + src_corners = array( + [ + source_keypoints[INDEX_KEYPOINT_CORNER_BOTTOM_LEFT], + source_keypoints[INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT], + source_keypoints[INDEX_KEYPOINT_CORNER_TOP_RIGHT], + source_keypoints[INDEX_KEYPOINT_CORNER_TOP_LEFT], + ], + dtype="float32", + )[None, :, :] + + warped_corners = perspectiveTransform(src_corners, homography_matrix)[0] + + if is_bowtie(warped_corners): + raise InvalidMask("Projection twisted!") + + +def project_image_using_keypoints( + image: ndarray, + source_keypoints: List[Tuple[int, int]], + destination_keypoints: List[Tuple[int, int]], + destination_width: int, + destination_height: int, + inverse: bool = False, +) -> ndarray: + """Project image using homography from source to destination keypoints.""" + filtered_src = [] + filtered_dst = [] + + for src_pt, dst_pt in zip(source_keypoints, destination_keypoints): + if dst_pt[0] == 0.0 and dst_pt[1] == 0.0: # ignore default / missing points + continue + filtered_src.append(src_pt) + filtered_dst.append(dst_pt) + + if len(filtered_src) < 4: + raise ValueError("At least 4 valid keypoints are required for homography.") + + source_points = array(filtered_src, dtype=float32) + destination_points = array(filtered_dst, dtype=float32) + + if inverse: + result = findHomography(destination_points, source_points) + if result is None: + raise ValueError("Failed to compute inverse homography.") + H_inv, _ = result + return warpPerspective(image, H_inv, (destination_width, destination_height)) + + result = findHomography(source_points, destination_points) + if result is None: + raise ValueError("Failed to compute homography.") + H, _ = result + projected_image = warpPerspective(image, H, (destination_width, destination_height)) + + validate_projected_corners(source_keypoints=source_keypoints, homography_matrix=H) + return projected_image + + +def extract_masks_for_ground_and_lines( + image: ndarray, +) -> Tuple[ndarray, ndarray]: + """Extract masks for ground (gray) and lines (white) from template image.""" + gray = cvtColor(image, COLOR_BGR2GRAY) + _, mask_ground = threshold(gray, 10, 255, THRESH_BINARY) + _, mask_lines = threshold(gray, 200, 255, THRESH_BINARY) + mask_ground_binary = (mask_ground > 0).astype(uint8) + mask_lines_binary = (mask_lines > 0).astype(uint8) + validate_mask_ground(mask=mask_ground_binary) + validate_mask_lines(mask=mask_lines_binary) + return mask_ground_binary, mask_lines_binary + + +def extract_masks_for_ground_and_lines_no_validation( + image: ndarray, +) -> Tuple[ndarray, ndarray]: + """ + Extract masks for ground (gray) and lines (white) from template image WITHOUT validation. + This is useful for line distribution analysis where exact fitting might create invalid masks + but we still want to analyze where lines are located. + """ + gray = cvtColor(image, COLOR_BGR2GRAY) + _, mask_ground = threshold(gray, 10, 255, THRESH_BINARY) + _, mask_lines = threshold(gray, 200, 255, THRESH_BINARY) + mask_ground_binary = (mask_ground > 0).astype(uint8) + mask_lines_binary = (mask_lines > 0).astype(uint8) + # No validation - return masks as-is + return mask_ground_binary, mask_lines_binary + + +def extract_mask_of_ground_lines_in_image( + image: ndarray, + ground_mask: ndarray, + blur_ksize: int = 5, + canny_low: int = 30, + canny_high: int = 100, + use_tophat: bool = True, + dilate_kernel_size: int = 3, + dilate_iterations: int = 3, +) -> ndarray: + """Extract line mask from image using edge detection on ground region.""" + gray = cvtColor(image, COLOR_BGR2GRAY) + + if use_tophat: + kernel = getStructuringElement(MORPH_RECT, (31, 31)) + gray = morphologyEx(gray, MORPH_TOPHAT, kernel) + + if blur_ksize and blur_ksize % 2 == 1: + gray = GaussianBlur(gray, (blur_ksize, blur_ksize), 0) + + image_edges = Canny(gray, canny_low, canny_high) + image_edges_on_ground = bitwise_and(image_edges, image_edges, mask=ground_mask) + + if dilate_kernel_size > 1: + dilate_kernel = getStructuringElement( + MORPH_RECT, (dilate_kernel_size, dilate_kernel_size) + ) + image_edges_on_ground = dilate( + image_edges_on_ground, dilate_kernel, iterations=dilate_iterations + ) + + return (image_edges_on_ground > 0).astype(uint8) + + +def evaluate_keypoints_for_frame( + template_keypoints: List[Tuple[int, int]], + frame_keypoints: List[Tuple[int, int]], + frame: ndarray, + floor_markings_template: ndarray, +) -> float: + """ + Evaluate keypoint accuracy for a single frame. + + Returns score between 0.0 and 1.0 based on overlap between + projected template lines and detected lines in frame. + """ + try: + warped_template = project_image_using_keypoints( + image=floor_markings_template, + source_keypoints=template_keypoints, + destination_keypoints=frame_keypoints, + destination_width=frame.shape[1], + destination_height=frame.shape[0], + ) + + mask_ground, mask_lines_expected = extract_masks_for_ground_and_lines( + image=warped_template + ) + + mask_lines_predicted = extract_mask_of_ground_lines_in_image( + image=frame, ground_mask=mask_ground + ) + + pixels_overlapping = bitwise_and( + mask_lines_expected, mask_lines_predicted + ).sum() + + pixels_on_lines = mask_lines_expected.sum() + + score = pixels_overlapping / (pixels_on_lines + 1e-8) + + return min(1.0, max(0.0, score)) # Clamp to [0, 1] + + except (InvalidMask, ValueError) as e: + print(f'InvalidMask or ValueError in keypoint evaluation: {e}') + return 0.0 + except Exception as e: + print(f'Unexpected error in keypoint evaluation: {e}') + return 0.0 + +def warp_image_pytorch( + image: ndarray, + homography_matrix: ndarray, + output_width: int, + output_height: int, + device: str = "cuda", +) -> ndarray: + """ + Warp image using PyTorch (GPU-accelerated) instead of cv2.warpPerspective. + + Args: + image: Input image to warp (H, W, C) numpy array + homography_matrix: 3x3 homography matrix + output_width: Output image width + output_height: Output image height + device: "cuda" or "cpu" + + Returns: + Warped image as numpy array + """ + if not TORCH_AVAILABLE: + # Fallback to OpenCV if PyTorch not available + return warpPerspective(image, homography_matrix, (output_width, output_height)) + + # Auto-detect device + if device == "cuda" and (not torch.cuda.is_available()): + device = "cpu" + + try: + # Convert to tensor and move to device + image_tensor = torch.from_numpy(image).to(device).float() + H = torch.from_numpy(homography_matrix).to(device).float() + + # Get image dimensions + h, w = image.shape[:2] + if len(image.shape) == 2: + # Grayscale + image_tensor = image_tensor.unsqueeze(2) # Add channel dimension + channels = 1 + else: + channels = image.shape[2] + + # Create coordinate grid for output image + y_coords, x_coords = torch.meshgrid( + torch.arange(0, output_height, device=device, dtype=torch.float32), + torch.arange(0, output_width, device=device, dtype=torch.float32), + indexing='ij' + ) + + # Apply inverse homography to get source coordinates + ones = torch.ones_like(x_coords) + coords = torch.stack([x_coords.flatten(), y_coords.flatten(), ones.flatten()], dim=0) + H_inv = torch.inverse(H) + src_coords = H_inv @ coords + src_coords = src_coords[:2] / (src_coords[2:3] + 1e-8) + + # Reshape and normalize to [-1, 1] for grid_sample + src_x = src_coords[0].reshape(output_height, output_width) + src_y = src_coords[1].reshape(output_height, output_width) + + # Normalize coordinates to [-1, 1] for grid_sample + src_x_norm = 2.0 * src_x / (w - 1) - 1.0 + src_y_norm = 2.0 * src_y / (h - 1) - 1.0 + grid = torch.stack([src_x_norm, src_y_norm], dim=-1).unsqueeze(0) # [1, H, W, 2] + + # Prepare image tensor: [1, C, H, W] + image_batch = image_tensor.permute(2, 0, 1).unsqueeze(0) + + # Warp using grid_sample + warped = F.grid_sample( + image_batch, grid, mode='bilinear', padding_mode='zeros', align_corners=True + ) + + # Convert back to numpy: [H, W, C] + warped = warped.squeeze(0).permute(1, 2, 0) + + # Remove channel dimension if grayscale + if channels == 1: + warped = warped.squeeze(2) + + # Convert to uint8 and return as numpy + warped_np = warped.cpu().numpy().clip(0, 255).astype(np.uint8) + return warped_np + + except Exception as e: + logger.error(f"PyTorch warping failed: {e}, falling back to OpenCV") + return warpPerspective(image, homography_matrix, (output_width, output_height)) + + +def evaluate_keypoints_for_frame_gpu( + template_keypoints: List[Tuple[int, int]], + frame_keypoints: List[Tuple[int, int]], + frame: ndarray, + floor_markings_template: ndarray, + device: str = "cuda", +) -> float: + """ + GPU-accelerated keypoint evaluation using PyTorch for warping. + + This function uses PyTorch's grid_sample for GPU-accelerated image warping + instead of cv2.warpPerspective, making it compatible with PyTorch CUDA. + + Args: + template_keypoints: Template keypoint coordinates + frame_keypoints: Frame keypoint coordinates + frame: Input frame image + floor_markings_template: Template image + device: "cuda" or "cpu" (auto-detects if CUDA available) + + Returns: + Score between 0.0 and 1.0 + """ + if not TORCH_AVAILABLE: + # Fallback to CPU version if PyTorch not available + return evaluate_keypoints_for_frame( + template_keypoints, frame_keypoints, frame, floor_markings_template + ) + + # Auto-detect device + if device == "cuda" and not torch.cuda.is_available(): + device = "cpu" + + try: + # Step 1: Compute homography (CPU - small operation) + filtered_src = [] + filtered_dst = [] + for src_pt, dst_pt in zip(template_keypoints, frame_keypoints): + if dst_pt[0] == 0.0 and dst_pt[1] == 0.0: + continue + filtered_src.append(src_pt) + filtered_dst.append(dst_pt) + + if len(filtered_src) < 4: + return 0.0 + + source_points = array(filtered_src, dtype=float32) + destination_points = array(filtered_dst, dtype=float32) + result = findHomography(source_points, destination_points) + if result is None: + return 0.0 + H, _ = result + + # Validate corners + src_corners = array([ + template_keypoints[INDEX_KEYPOINT_CORNER_BOTTOM_LEFT], + template_keypoints[INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT], + template_keypoints[INDEX_KEYPOINT_CORNER_TOP_RIGHT], + template_keypoints[INDEX_KEYPOINT_CORNER_TOP_LEFT], + ], dtype=float32)[None, :, :] + warped_corners = perspectiveTransform(src_corners, H)[0] + if is_bowtie(warped_corners): + return 0.0 + + # Step 2: Warp template using PyTorch (GPU-accelerated) + h, w = frame.shape[:2] + warped_template = warp_image_pytorch( + floor_markings_template, + H, + w, + h, + device=device + ) + + # Step 3: Extract masks (CPU - OpenCV operations) + mask_ground, mask_lines_expected = extract_masks_for_ground_and_lines( + image=warped_template + ) + + mask_lines_predicted = extract_mask_of_ground_lines_in_image( + image=frame, ground_mask=mask_ground + ) + + # Step 4: Compute overlap + pixels_overlapping = bitwise_and( + mask_lines_expected, mask_lines_predicted + ).sum() + + pixels_on_lines = mask_lines_expected.sum() + + score = pixels_overlapping / (pixels_on_lines + 1e-8) + return min(1.0, max(0.0, score)) + + except (InvalidMask, ValueError) as e: + logger.debug(f"Keypoint evaluation failed: {e}") + return 0.0 + except Exception as e: + logger.error(f"GPU evaluation failed: {e}, falling back to CPU") + return evaluate_keypoints_for_frame( + template_keypoints, frame_keypoints, frame, floor_markings_template + ) + + +# Cache for template GpuMat to avoid re-uploading on every frame +_template_gpumat_cache = None +_template_cache_key = None +_cuda_available_cache = None +_cuda_module_cache = None +_frame_gpumat_reusable = None # Reusable GpuMat for frames (same size) +_frame_gpumat_size = None # Size of the reusable frame GpuMat + +def evaluate_keypoints_for_frame_opencv_cuda( + template_keypoints: List[Tuple[int, int]], + frame_keypoints: List[Tuple[int, int]], + frame: ndarray, + floor_markings_template: ndarray, + device: str = "cuda", +) -> float: + """ + GPU-accelerated version using OpenCV CUDA (if available). + Falls back to CPU if CUDA not available. + + Note: opencv-python-headless doesn't include CUDA support, so this will + always fall back to CPU. Use evaluate_keypoints_for_frame_gpu for PyTorch GPU acceleration. + + Optimizations: + - Template GpuMat is cached to avoid re-uploading + - CUDA availability check is cached + - Frame GpuMat is reused when frame size matches + - Keypoint filtering optimized with list comprehension + + Args: + device: Ignored (kept for compatibility). OpenCV CUDA check is automatic. + """ + global _template_gpumat_cache, _template_cache_key + global _cuda_available_cache, _cuda_module_cache, _frame_gpumat_reusable, _frame_gpumat_size + + # Cache CUDA availability check (only check once) + if _cuda_available_cache is None: + cuda_available = False + cuda = None + try: + import cv2.cuda as cuda + # Check if cv2.cuda actually has CUDA functions (not just a stub) + if hasattr(cuda, 'warpPerspective'): + # Try to create a GpuMat to verify CUDA is actually working + try: + test_mat = cuda.GpuMat() + test_mat.upload(np.zeros((10, 10, 3), dtype=np.uint8)) + cuda_available = True + except (AttributeError, Exception): + # GpuMat exists but doesn't work (stub module) + cuda_available = False + except (ImportError, AttributeError): + cuda_available = False + + _cuda_available_cache = cuda_available + _cuda_module_cache = cuda + else: + cuda_available = _cuda_available_cache + cuda = _cuda_module_cache + + # Always use CPU version since opencv-python-headless doesn't have CUDA + # The check above will fail, so we fall back to CPU + if not cuda_available: + # Use CPU version (this is what will happen with opencv-python-headless) + return evaluate_keypoints_for_frame( + template_keypoints, frame_keypoints, frame, floor_markings_template + ) + + # If we get here, OpenCV CUDA is actually available (unlikely with opencv-python-headless) + try: + # Create cache key based on template image shape and a fast checksum + # Using shape + sum of corner pixels for fast comparison (much faster than full hash) + template_shape = floor_markings_template.shape + # Quick checksum: sum of corner pixels (fast to compute) + checksum = ( + int(floor_markings_template[0, 0].sum()) + + int(floor_markings_template[0, -1].sum()) + + int(floor_markings_template[-1, 0].sum()) + + int(floor_markings_template[-1, -1].sum()) + ) + current_cache_key = (template_shape, checksum) + + # Check if we need to update the cached GpuMat + if _template_gpumat_cache is None or _template_cache_key != current_cache_key: + # Upload template to GPU (only once or when template changes) + _template_gpumat_cache = cuda.GpuMat() + _template_gpumat_cache.upload(floor_markings_template) + _template_cache_key = current_cache_key + + # Optimize frame upload: reuse GpuMat if frame size matches + h, w = frame.shape[:2] + frame_shape = (h, w) + if _frame_gpumat_reusable is None or _frame_gpumat_size != frame_shape: + _frame_gpumat_reusable = cuda.GpuMat() + _frame_gpumat_size = frame_shape + gpu_frame = _frame_gpumat_reusable + gpu_frame.upload(frame) + + # Use cached template GpuMat + gpu_template = _template_gpumat_cache + + # Optimize keypoint filtering with list comprehension (faster than loop) + filtered_pairs = [(src_pt, dst_pt) for src_pt, dst_pt in zip(template_keypoints, frame_keypoints) + if not (dst_pt[0] == 0.0 and dst_pt[1] == 0.0)] + + if len(filtered_pairs) < 4: + return 0.0 + + # Unpack filtered pairs + filtered_src, filtered_dst = zip(*filtered_pairs) + + # Compute homography (CPU - small operation, fast) + source_points = array(filtered_src, dtype=float32) + destination_points = array(filtered_dst, dtype=float32) + result = findHomography(source_points, destination_points) + if result is None: + return 0.0 + H, _ = result + + # Warp on GPU + gpu_warped = cuda.warpPerspective(gpu_template, H, (w, h)) + + # Download for mask extraction (unavoidable - mask extraction uses CPU OpenCV) + warped_template = gpu_warped.download() + + # Rest of the pipeline (CPU operations - these are fast) + mask_ground, mask_lines_expected = extract_masks_for_ground_and_lines(warped_template) + mask_lines_predicted = extract_mask_of_ground_lines_in_image(frame, mask_ground) + + # Overlap computation (using cv2.bitwise_and for consistency) + pixels_overlapping = bitwise_and(mask_lines_expected, mask_lines_predicted).sum() + pixels_on_lines = mask_lines_expected.sum() + score = pixels_overlapping / (pixels_on_lines + 1e-8) + return min(1.0, max(0.0, score)) + + except Exception as e: + logger.error(f"OpenCV CUDA evaluation failed: {e}, falling back to CPU") + return evaluate_keypoints_for_frame( + template_keypoints, frame_keypoints, frame, floor_markings_template + ) + +def evaluate_keypoints_batch_gpu( + template_keypoints: List[Tuple[int, int]], + frame_keypoints_list: List[List[Tuple[int, int]]], + frames: List[ndarray], + floor_markings_template: ndarray, + device: str = "cuda", +) -> List[float]: + """ + Batch GPU-accelerated keypoint evaluation for multiple frames simultaneously. + + This function processes multiple frames in parallel using PyTorch batch operations, + which is much faster than evaluating frames one-by-one. + + Args: + template_keypoints: Template keypoint coordinates (same for all frames) + frame_keypoints_list: List of frame keypoint coordinates (one per frame) + frames: List of frame images (numpy arrays) + floor_markings_template: Template image + device: "cuda" or "cpu" + + Returns: + List of scores (one per frame) between 0.0 and 1.0 + """ + if not TORCH_AVAILABLE: + # Fallback to sequential CPU evaluation + return [ + evaluate_keypoints_for_frame( + template_keypoints, kp, frame, floor_markings_template + ) + for kp, frame in zip(frame_keypoints_list, frames) + ] + + # Auto-detect device + if device == "cuda" and not torch.cuda.is_available(): + device = "cpu" + + batch_size = len(frames) + if batch_size == 0: + return [] + + # Get frame dimensions (assuming all frames have same size) + h, w = frames[0].shape[:2] + + try: + # Step 1: Compute homographies for all frames (CPU - vectorized where possible) + homographies = [] + valid_indices = [] + + for idx, (frame_keypoints, frame) in enumerate(zip(frame_keypoints_list, frames)): + # Filter keypoints + filtered_pairs = [(src_pt, dst_pt) for src_pt, dst_pt in zip(template_keypoints, frame_keypoints) + if not (dst_pt[0] == 0.0 and dst_pt[1] == 0.0)] + + if len(filtered_pairs) < 4: + continue + + filtered_src, filtered_dst = zip(*filtered_pairs) + source_points = array(filtered_src, dtype=float32) + destination_points = array(filtered_dst, dtype=float32) + result = findHomography(source_points, destination_points) + if result is None: + continue + H, _ = result + + # Validate corners + src_corners = array([ + template_keypoints[INDEX_KEYPOINT_CORNER_BOTTOM_LEFT], + template_keypoints[INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT], + template_keypoints[INDEX_KEYPOINT_CORNER_TOP_RIGHT], + template_keypoints[INDEX_KEYPOINT_CORNER_TOP_LEFT], + ], dtype=float32)[None, :, :] + warped_corners = perspectiveTransform(src_corners, H)[0] + if not is_bowtie(warped_corners): + homographies.append(H) + valid_indices.append(idx) + + if len(homographies) == 0: + return [0.0] * batch_size + + # Step 2: Batch warp using PyTorch (much faster than sequential) + template_tensor = torch.from_numpy(floor_markings_template).to(device).float() + t_h, t_w = floor_markings_template.shape[:2] + + if len(floor_markings_template.shape) == 2: + template_tensor = template_tensor.unsqueeze(2) + t_channels = 1 + else: + t_channels = floor_markings_template.shape[2] + + # Prepare template batch: [B, C, H, W] + template_batch = template_tensor.permute(2, 0, 1).unsqueeze(0).repeat(len(homographies), 1, 1, 1) + + # Create coordinate grids for all frames + y_coords, x_coords = torch.meshgrid( + torch.arange(0, h, device=device, dtype=torch.float32), + torch.arange(0, w, device=device, dtype=torch.float32), + indexing='ij' + ) + ones = torch.ones_like(x_coords) + coords = torch.stack([x_coords.flatten(), y_coords.flatten(), ones.flatten()], dim=0) # [3, H*W] + + # Batch process homographies + H_tensors = torch.from_numpy(np.stack(homographies)).to(device).float() # [B, 3, 3] + H_inv_batch = torch.inverse(H_tensors) # [B, 3, 3] + + # Apply inverse homography for each frame: [B, 3, 3] @ [3, H*W] -> [B, 3, H*W] + coords_expanded = coords.unsqueeze(0).expand(len(homographies), -1, -1) # [B, 3, H*W] + src_coords_batch = torch.bmm(H_inv_batch, coords_expanded) # [B, 3, H*W] + src_coords_batch = src_coords_batch[:, :2] / (src_coords_batch[:, 2:3] + 1e-8) # [B, 2, H*W] + + # Reshape and normalize to [-1, 1] for grid_sample + src_x_batch = src_coords_batch[:, 0].reshape(len(homographies), h, w) + src_y_batch = src_coords_batch[:, 1].reshape(len(homographies), h, w) + src_x_norm = 2.0 * src_x_batch / (t_w - 1) - 1.0 + src_y_norm = 2.0 * src_y_batch / (t_h - 1) - 1.0 + grid_batch = torch.stack([src_x_norm, src_y_norm], dim=-1) # [B, H, W, 2] + + # Batch warp using grid_sample (all frames at once!) + warped_batch = F.grid_sample( + template_batch, grid_batch, mode='bilinear', padding_mode='zeros', align_corners=True + ) # [B, C, H, W] + + # Convert back to numpy: [B, H, W, C] + warped_batch = warped_batch.permute(0, 2, 3, 1) + if t_channels == 1: + warped_batch = warped_batch.squeeze(3) + warped_templates = warped_batch.cpu().numpy().clip(0, 255).astype(np.uint8) + + # Step 3: Batch mask extraction and evaluation on GPU + scores = [0.0] * batch_size + + # Convert to tensors for batch processing + warped_templates_tensor = torch.from_numpy(warped_templates).to(device).float() + frames_tensor = torch.from_numpy(np.stack([frames[i] for i in valid_indices])).to(device).float() + + # Batch extract masks for warped templates (GPU) + # Convert to grayscale + if len(warped_templates_tensor.shape) == 4: # [B, H, W, C] + gray_templates = (warped_templates_tensor[:, :, :, 0] * 0.299 + + warped_templates_tensor[:, :, :, 1] * 0.587 + + warped_templates_tensor[:, :, :, 2] * 0.114) + else: + gray_templates = warped_templates_tensor + + # Threshold for ground and lines (batch operation) + mask_ground_batch = (gray_templates > 10.0).float() # [B, H, W] + mask_lines_expected_batch = (gray_templates > 200.0).float() # [B, H, W] + + # Batch extract predicted lines from frames (GPU) + if len(frames_tensor.shape) == 4: # [B, H, W, C] + gray_frames = (frames_tensor[:, :, :, 0] * 0.299 + + frames_tensor[:, :, :, 1] * 0.587 + + frames_tensor[:, :, :, 2] * 0.114) + else: + gray_frames = frames_tensor + + # Simplified edge detection (batch Sobel) + # Sobel kernels + sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], + device=device, dtype=torch.float32).unsqueeze(0).unsqueeze(0) + sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], + device=device, dtype=torch.float32).unsqueeze(0).unsqueeze(0) + + # Apply Sobel to batch + gray_frames_batch = gray_frames.unsqueeze(1) # [B, 1, H, W] + grad_x_batch = F.conv2d(gray_frames_batch, sobel_x, padding=1) + grad_y_batch = F.conv2d(gray_frames_batch, sobel_y, padding=1) + magnitude_batch = torch.sqrt(grad_x_batch.squeeze(1) ** 2 + grad_y_batch.squeeze(1) ** 2 + 1e-8) + edges_batch = (magnitude_batch > 30.0).float() # [B, H, W] + + # Apply ground mask + mask_lines_predicted_batch = edges_batch * mask_ground_batch + + # Batch overlap computation (all on GPU!) + pixels_overlapping_batch = (mask_lines_expected_batch * mask_lines_predicted_batch).sum(dim=(1, 2)) # [B] + pixels_on_lines_batch = mask_lines_expected_batch.sum(dim=(1, 2)) # [B] + scores_batch = (pixels_overlapping_batch / (pixels_on_lines_batch + 1e-8)).cpu().numpy() + + # Fill in scores for valid indices + for batch_idx, valid_idx in enumerate(valid_indices): + scores[valid_idx] = min(1.0, max(0.0, float(scores_batch[batch_idx]))) + + return scores + + except Exception as e: + logger.error(f"Batch GPU evaluation failed: {e}, falling back to sequential CPU") + return [ + evaluate_keypoints_for_frame( + template_keypoints, kp, frame, floor_markings_template + ) + for kp, frame in zip(frame_keypoints_list, frames) + ] + + +def evaluate_keypoints_batch_for_frame( + template_keypoints: List[Tuple[int, int]], + frame_keypoints_list: List[List[Tuple[int, int]]], + frame: ndarray, + floor_markings_template: ndarray, + device: str = "cuda", + batch_size: int = 32, +) -> List[float]: + """ + Fast batch GPU evaluation of multiple keypoint sets for a single frame. + + This function evaluates multiple keypoint sets (e.g., from different models) + for the same frame using batch GPU processing, which is much faster than + evaluating them sequentially. + + Args: + template_keypoints: Template keypoint coordinates + frame_keypoints_list: List of frame keypoint coordinate sets to evaluate + frame: Single frame image (same for all keypoint sets) + floor_markings_template: Template image + device: "cuda" or "cpu" + batch_size: Number of keypoint sets to process in each GPU batch + + Returns: + List of scores (one per keypoint set) between 0.0 and 1.0 + """ + if len(frame_keypoints_list) == 0: + return [] + + if len(frame_keypoints_list) == 1: + # Single evaluation - use regular function + return [evaluate_keypoints_for_frame_opencv_cuda( + template_keypoints=template_keypoints, + frame_keypoints=frame_keypoints_list[0], + frame=frame, + floor_markings_template=floor_markings_template, + device=device + )] + + # For multiple keypoint sets, use batch processing + # Create list of frames (same frame repeated) + frames_list = [frame] * len(frame_keypoints_list) + + # Use batch GPU evaluation + try: + scores = evaluate_keypoints_batch_gpu( + template_keypoints=template_keypoints, + frame_keypoints_list=frame_keypoints_list, + frames=frames_list, + floor_markings_template=floor_markings_template, + device=device, + ) + return scores + except Exception as e: + logger.warning(f"Batch GPU evaluation failed: {e}, falling back to sequential") + # Fallback to sequential evaluation + scores = [] + for frame_keypoints in frame_keypoints_list: + try: + score = evaluate_keypoints_for_frame_opencv_cuda( + template_keypoints=template_keypoints, + frame_keypoints=frame_keypoints, + frame=frame, + floor_markings_template=floor_markings_template, + device=device + ) + scores.append(score) + except Exception as e2: + logger.debug(f"Error evaluating keypoints: {e2}") + scores.append(0.0) + return scores + + +def load_template_from_file( + template_image_path: str, +) -> Tuple[ndarray, List[Tuple[int, int]]]: + """ + Load template image and use TEMPLATE_KEYPOINTS constant for keypoints. + + Args: + template_image_path: Path to template image file + + Returns: + template_image: Loaded template image + template_keypoints: List of (x, y) keypoint coordinates from TEMPLATE_KEYPOINTS constant + """ + # Load template image + template_image = cv2.imread(template_image_path) + if template_image is None: + raise ValueError(f"Could not load template image from {template_image_path}") + + # Use TEMPLATE_KEYPOINTS constant + if len(TEMPLATE_KEYPOINTS) == 0: + raise ValueError( + "TEMPLATE_KEYPOINTS constant is empty. Please define keypoints in keypoint_evaluation.py" + ) + + if len(TEMPLATE_KEYPOINTS) < 4: + raise ValueError(f"TEMPLATE_KEYPOINTS must have at least 4 keypoints, found {len(TEMPLATE_KEYPOINTS)}") + + logger.info(f"Loaded template image: {template_image_path}") + logger.info(f"Using TEMPLATE_KEYPOINTS constant with {len(TEMPLATE_KEYPOINTS)} keypoints") + + return template_image, TEMPLATE_KEYPOINTS + + diff --git a/keypoint_helper.py b/keypoint_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..9a164cdc2ee345aa61b6312bc66244050a795f0d --- /dev/null +++ b/keypoint_helper.py @@ -0,0 +1,115 @@ + +import numpy as np +from typing import List, Tuple, Sequence, Any + +FOOTBALL_KEYPOINTS: list[tuple[int, int]] = [ + (0, 0), # 1 + (0, 0), # 2 + (0, 0), # 3 + (0, 0), # 4 + (0, 0), # 5 + (0, 0), # 6 + + (0, 0), # 7 + (0, 0), # 8 + (0, 0), # 9 + + (0, 0), # 10 + (0, 0), # 11 + (0, 0), # 12 + (0, 0), # 13 + + (0, 0), # 14 + (527, 283), # 15 + (527, 403), # 16 + (0, 0), # 17 + + (0, 0), # 18 + (0, 0), # 19 + (0, 0), # 20 + (0, 0), # 21 + + (0, 0), # 22 + + (0, 0), # 23 + (0, 0), # 24 + + (0, 0), # 25 + (0, 0), # 26 + (0, 0), # 27 + (0, 0), # 28 + (0, 0), # 29 + (0, 0), # 30 + + (405, 340), # 31 + (645, 340), # 32 +] + +def convert_keypoints_to_val_format(keypoints): + return [tuple(int(x) for x in pair) for pair in keypoints] + +def predict_failed_indices(results_frames: Sequence[Any]) -> list[int]: + + max_frames = len(results_frames) + if max_frames == 0: + return [] + + failed_indices: list[int] = [] + for frame_index, frame_result in enumerate(results_frames): + frame_keypoints = getattr(frame_result, "keypoints", []) or [] + non_zero_count = sum(1 for (x, y) in frame_keypoints if int(x) != 0 and int(y) != 0) + if non_zero_count <= 4: + failed_indices.append(frame_index) + return failed_indices + +def _generate_sparse_template_keypoints(frame_width: int, frame_height: int) -> list[tuple[int, int]]: + template_max_x, template_max_y = (1045, 675) + sx = float(frame_width) / float(template_max_x if template_max_x != 0 else 1) + sy = float(frame_height) / float(template_max_y if template_max_y != 0 else 1) + scaled: list[tuple[int, int]] = [] + for i in range(32): + tx, ty = FOOTBALL_KEYPOINTS[i] + x_scaled = int(round(tx * sx)) + y_scaled = int(round(ty * sy)) + scaled.append((x_scaled, y_scaled)) + return scaled + +def fix_keypoints( + results_frames: Sequence[Any], + failed_indices: Sequence[int], + frame_width: int, + frame_height: int, +) -> list[Any]: + max_frames = len(results_frames) + if max_frames == 0: + return list(results_frames) + + failed_set = set(int(i) for i in failed_indices) + all_indices = list(range(max_frames)) + successful_indices = [i for i in all_indices if i not in failed_set] + + if len(successful_indices) == 0: + sparse_template = _generate_sparse_template_keypoints(frame_width, frame_height) + for frame_result in results_frames: + setattr(frame_result, "keypoints", list(convert_keypoints_to_val_format(sparse_template))) + return list(results_frames) + + seed_index = successful_indices[0] + seed_kps_raw = getattr(results_frames[seed_index], "keypoints", []) or [] + last_success_kps = convert_keypoints_to_val_format(seed_kps_raw) + + for frame_index in range(max_frames): + frame_result = results_frames[frame_index] + if frame_index in failed_set: + setattr(frame_result, "keypoints", list(last_success_kps)) + else: + current_kps_raw = getattr(frame_result, "keypoints", []) or [] + current_kps = convert_keypoints_to_val_format(current_kps_raw) + setattr(frame_result, "keypoints", list(current_kps)) + last_success_kps = current_kps + + return list(results_frames) + +def run_keypoints_post_processing(results_frames: Sequence[Any], frame_width: int, frame_height: int) -> list[Any]: + failed_indices = predict_failed_indices(results_frames) + return fix_keypoints(results_frames, failed_indices, frame_width, frame_height) \ No newline at end of file diff --git a/keypoint_helper_v2.py b/keypoint_helper_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..e91d91cc49397b644e4d6d30f83fbfbf21efb650 --- /dev/null +++ b/keypoint_helper_v2.py @@ -0,0 +1,3720 @@ + +import time +import numpy as np +import cv2 +from typing import List, Tuple, Sequence, Any +from numpy import ndarray + +FOOTBALL_KEYPOINTS: list[tuple[int, int]] = [ + (0, 0), # 1 + (0, 0), # 2 + (0, 0), # 3 + (0, 0), # 4 + (0, 0), # 5 + (0, 0), # 6 + + (0, 0), # 7 + (0, 0), # 8 + (0, 0), # 9 + + (0, 0), # 10 + (0, 0), # 11 + (0, 0), # 12 + (0, 0), # 13 + + (0, 0), # 14 + (527, 283), # 15 + (527, 403), # 16 + (0, 0), # 17 + + (0, 0), # 18 + (0, 0), # 19 + (0, 0), # 20 + (0, 0), # 21 + + (0, 0), # 22 + + (0, 0), # 23 + (0, 0), # 24 + + (0, 0), # 25 + (0, 0), # 26 + (0, 0), # 27 + (0, 0), # 28 + (0, 0), # 29 + (0, 0), # 30 + + (405, 340), # 31 + (645, 340), # 32 +] + +def convert_keypoints_to_val_format(keypoints): + return [tuple(int(x) for x in pair) for pair in keypoints] + +def validate_with_nearby_keypoints( + kp_idx: int, + kp: tuple[int, int], + valid_indices: list[int], + result: list[tuple[int, int]], + template_keypoints: list[tuple[int, int]], + scale_factor: float = None, +) -> float: + """ + Validate a keypoint by checking distances to nearby keypoints on the same side. + + Returns validation score (lower is better), or None if validation not possible. + """ + template_kp = template_keypoints[kp_idx] + + # Define which keypoints are on the same side + # Left side: 10, 11, 12, 13 (indices 9, 10, 11, 12) + # Right side: 18, 19, 20, 21, 22, 23, 24, 25-30 (indices 17-29) + + left_side_indices = [9, 10, 11, 12] # Keypoints 10-13 + right_side_indices = list(range(17, 30)) # Keypoints 18-30 + + # Determine which side this keypoint should be on + if kp_idx in left_side_indices: + same_side_indices = left_side_indices + elif kp_idx in right_side_indices: + same_side_indices = right_side_indices + else: + return None # Can't validate + + # Find nearby keypoints on the same side that are detected + nearby_kps = [] + for nearby_idx in same_side_indices: + if nearby_idx != kp_idx and nearby_idx in valid_indices: + nearby_kp = result[nearby_idx] + nearby_template_kp = template_keypoints[nearby_idx] + nearby_kps.append((nearby_idx, nearby_kp, nearby_template_kp)) + + if len(nearby_kps) == 0: + return None # No nearby keypoints to validate with + + # Calculate distance errors to nearby keypoints + distance_errors = [] + for nearby_idx, nearby_kp, nearby_template_kp in nearby_kps: + # Detected distance + detected_dist = np.sqrt((kp[0] - nearby_kp[0])**2 + (kp[1] - nearby_kp[1])**2) + + # Template distance + template_dist = np.sqrt((template_kp[0] - nearby_template_kp[0])**2 + + (template_kp[1] - nearby_template_kp[1])**2) + + if template_dist > 0: + # Expected detected distance + if scale_factor: + expected_dist = template_dist * scale_factor + else: + expected_dist = template_dist + + if expected_dist > 0: + # Normalized error + error = abs(detected_dist - expected_dist) / expected_dist + distance_errors.append(error) + + if len(distance_errors) > 0: + return np.mean(distance_errors) + return None + +def remove_duplicate_detections( + keypoints: list[tuple[int, int]], + frame_width: int = None, + frame_height: int = None, +) -> list[tuple[int, int]]: + """ + Remove duplicate/conflicting keypoint detections using distance-based validation. + + Uses the principle that if two keypoints are detected very close together, + but in the template they should be far apart, one of them is likely wrong. + Validates each keypoint by checking if its distances to other keypoints + match the expected template distances. + + Args: + keypoints: List of 32 keypoints + frame_width: Optional frame width for validation + frame_height: Optional frame height for validation + + Returns: + Cleaned list of keypoints with duplicates removed + """ + if len(keypoints) != 32: + if len(keypoints) < 32: + keypoints = list(keypoints) + [(0, 0)] * (32 - len(keypoints)) + else: + keypoints = keypoints[:32] + + result = list(keypoints) + + try: + from keypoint_evaluation import TEMPLATE_KEYPOINTS + template_available = True + except ImportError: + template_available = False + + if not template_available: + return result + + # Get all valid detected keypoints + valid_indices = [] + for i in range(32): + if result[i][0] > 0 and result[i][1] > 0: + valid_indices.append(i) + + if len(valid_indices) < 2: + return result + + # Calculate scale factor from detected keypoints to template + # Use pairs of keypoints that are far apart in template to estimate scale + scale_factor = None + if len(valid_indices) >= 2: + max_template_dist = 0 + max_detected_dist = 0 + + for i in range(len(valid_indices)): + for j in range(i + 1, len(valid_indices)): + idx_i = valid_indices[i] + idx_j = valid_indices[j] + + template_i = TEMPLATE_KEYPOINTS[idx_i] + template_j = TEMPLATE_KEYPOINTS[idx_j] + template_dist = np.sqrt((template_i[0] - template_j[0])**2 + (template_i[1] - template_j[1])**2) + + kp_i = result[idx_i] + kp_j = result[idx_j] + detected_dist = np.sqrt((kp_i[0] - kp_j[0])**2 + (kp_i[1] - kp_j[1])**2) + + if template_dist > max_template_dist and detected_dist > 0: + max_template_dist = template_dist + max_detected_dist = detected_dist + + if max_template_dist > 0 and max_detected_dist > 0: + scale_factor = max_detected_dist / max_template_dist + + # For each keypoint, validate it by checking distances to other keypoints + keypoint_scores = {} + for idx in valid_indices: + kp = result[idx] + template_kp = TEMPLATE_KEYPOINTS[idx] + + # Calculate how well this keypoint's distances match template distances + distance_errors = [] + num_comparisons = 0 + + for other_idx in valid_indices: + if other_idx == idx: + continue + + other_kp = result[other_idx] + other_template_kp = TEMPLATE_KEYPOINTS[other_idx] + + # Calculate detected distance + detected_dist = np.sqrt((kp[0] - other_kp[0])**2 + (kp[1] - other_kp[1])**2) + + # Calculate template distance + template_dist = np.sqrt((template_kp[0] - other_template_kp[0])**2 + + (template_kp[1] - other_template_kp[1])**2) + + if template_dist > 50: # Only check keypoints that should be reasonably far apart + num_comparisons += 1 + + # Expected detected distance (scaled from template) + if scale_factor: + expected_dist = template_dist * scale_factor + else: + expected_dist = template_dist + + # Calculate error (normalized) + if expected_dist > 0: + error = abs(detected_dist - expected_dist) / expected_dist + distance_errors.append(error) + + # Score: lower is better (smaller distance errors) + if num_comparisons > 0: + avg_error = np.mean(distance_errors) + keypoint_scores[idx] = avg_error + else: + keypoint_scores[idx] = 0.0 + + # Find pairs of keypoints that are too close but should be far apart + conflicts = [] + for i in range(len(valid_indices)): + for j in range(i + 1, len(valid_indices)): + idx_i = valid_indices[i] + idx_j = valid_indices[j] + + kp_i = result[idx_i] + kp_j = result[idx_j] + + # Calculate detected distance + detected_dist = np.sqrt((kp_i[0] - kp_j[0])**2 + (kp_i[1] - kp_j[1])**2) + + # Calculate template distance + template_i = TEMPLATE_KEYPOINTS[idx_i] + template_j = TEMPLATE_KEYPOINTS[idx_j] + template_dist = np.sqrt((template_i[0] - template_j[0])**2 + + (template_i[1] - template_j[1])**2) + + # If template distance is large but detected distance is small, it's a conflict + if template_dist > 100 and detected_dist < 30: + # Enhanced validation: use nearby keypoints to determine which is correct + # For example, if we have 24 and 29, we can check distances to determine if it's 13 or 21 + score_i = keypoint_scores.get(idx_i, 1.0) + score_j = keypoint_scores.get(idx_j, 1.0) + + # Try to validate using nearby keypoints on the same side + # Keypoint 13 is on left side, keypoint 21 is on right side + # If we have right-side keypoints (like 24, 29), check distances + nearby_validation_i = validate_with_nearby_keypoints( + idx_i, kp_i, valid_indices, result, TEMPLATE_KEYPOINTS, scale_factor + ) + nearby_validation_j = validate_with_nearby_keypoints( + idx_j, kp_j, valid_indices, result, TEMPLATE_KEYPOINTS, scale_factor + ) + + # Prioritize nearby validation: if one has nearby validation and the other doesn't, + # prefer the one with nearby validation (it's more reliable) + validation_score_i = score_i + validation_score_j = score_j + + if nearby_validation_i is not None and nearby_validation_j is not None: + # Both have nearby validation, use those scores + validation_score_i = nearby_validation_i + validation_score_j = nearby_validation_j + elif nearby_validation_i is not None: + # Only i has nearby validation, prefer it (give it much better score) + validation_score_i = nearby_validation_i + validation_score_j = score_j + 1.0 # Penalize j for not having nearby validation + elif nearby_validation_j is not None: + # Only j has nearby validation, prefer it + validation_score_i = score_i + 1.0 # Penalize i for not having nearby validation + validation_score_j = nearby_validation_j + # If neither has nearby validation, use general distance scores + + # Remove the one with worse validation score + if validation_score_i > validation_score_j: + conflicts.append((idx_i, idx_j, validation_score_i, validation_score_j)) + else: + conflicts.append((idx_j, idx_i, validation_score_j, validation_score_i)) + + # Remove conflicting keypoints (keep the one with better score) + removed_indices = set() + for remove_idx, keep_idx, remove_score, keep_score in conflicts: + if remove_idx not in removed_indices: + print(f"Removing duplicate detection: keypoint {remove_idx+1} at {result[remove_idx]} conflicts with keypoint {keep_idx+1} at {result[keep_idx]} " + f"(detected distance: {np.sqrt((result[remove_idx][0] - result[keep_idx][0])**2 + (result[remove_idx][1] - result[keep_idx][1])**2):.1f}, " + f"template distance: {np.sqrt((TEMPLATE_KEYPOINTS[remove_idx][0] - TEMPLATE_KEYPOINTS[keep_idx][0])**2 + (TEMPLATE_KEYPOINTS[remove_idx][1] - TEMPLATE_KEYPOINTS[keep_idx][1])**2):.1f}). " + f"Keeping keypoint {keep_idx+1} (score: {keep_score:.3f} vs {remove_score:.3f}).") + result[remove_idx] = (0, 0) + removed_indices.add(remove_idx) + + return result + +def calculate_missing_keypoints( + keypoints: list[tuple[int, int]], + frame_width: int = None, + frame_height: int = None, +) -> list[tuple[int, int]]: + """ + Calculate missing keypoint coordinates for multiple cases: + 1. Given keypoints 14, 15, 16 (and possibly 17), and either 31 or 32, + calculate the missing center circle point (32 or 31). + 2. Given three or four of keypoints 18, 19, 20, 21 and any of 22-30, + calculate missing keypoint positions (like 22 or others) to prevent warping failures. + + Args: + keypoints: List of 32 keypoints (some may be (0,0) if missing) + frame_width: Optional frame width for validation + frame_height: Optional frame height for validation + + Returns: + Updated list of 32 keypoints with calculated missing keypoints filled in + """ + if len(keypoints) != 32: + # Pad or truncate to 32 + if len(keypoints) < 32: + keypoints = list(keypoints) + [(0, 0)] * (32 - len(keypoints)) + else: + keypoints = keypoints[:32] + + result = list(keypoints) + + # Helper to get keypoint + def get_kp(kp_idx): + if kp_idx < 0 or kp_idx >= 32: + return None + x, y = result[kp_idx] + + if x == 0 and y == 0: + return None + + return (x, y) + + + # Case 1: Find center x-coordinate from center line keypoints (14, 15, 16, or 17) + # Keypoints 14, 15, 16, 17 are on the center vertical line (indices 13, 14, 15, 16) + center_x = None + for center_kp_idx in [13, 14, 15, 16]: # 14, 15, 16, 17 (0-indexed) + kp = get_kp(center_kp_idx) + if kp: + center_x = kp[0] + break + + # If we have center line, calculate missing center circle point + if center_x is not None: + # Keypoint 31 is at index 30 (left side of center circle) + # Keypoint 32 is at index 31 (right side of center circle) + kp_31 = get_kp(30) # Keypoint 31 + kp_32 = get_kp(31) # Keypoint 32 + + if kp_31 and not kp_32: + # Given 31, calculate 32 by reflecting across center_x + # Formula: x_32 = center_x + (center_x - x_31) = 2*center_x - x_31 + # y_32 = y_31 (same y-coordinate, both on center horizontal line) + dx = center_x - kp_31[0] + result[31] = (int(round(center_x + dx)), kp_31[1]) + elif kp_32 and not kp_31: + # Given 32, calculate 31 by reflecting across center_x + # Formula: x_31 = center_x - (x_32 - center_x) = 2*center_x - x_32 + # y_31 = y_32 (same y-coordinate, both on center horizontal line) + dx = kp_32[0] - center_x + result[30] = (int(round(center_x - dx)), kp_32[1]) + + # Case 1.5: Unified handling of left side keypoints (1-13) + # Three parallel vertical lines on left side: + # - Line 1-6: keypoints 1, 2, 3, 4, 5, 6 (indices 0-5) + # - Line 7-8: keypoints 7, 8 (indices 6-7) + # - Line 10-13: keypoints 10, 11, 12, 13 (indices 9-12) + # Keypoint 9 (index 8) is between line 1-6 and line 10-13 + + # Collect all left-side keypoints (1-13, indices 0-12, excluding 9 which is center) + left_side_all = [] + line_1_6_points = [] # Indices 0-5 + line_7_8_points = [] # Indices 6-7 + line_10_13_points = [] # Indices 9-12 + + for idx in range(0, 13): # Keypoints 1-13 (indices 0-12) + if idx == 8: # Skip keypoint 9 (index 8) - it's a center point + continue + kp = get_kp(idx) + if kp: + left_side_all.append((idx, kp)) + if 0 <= idx <= 5: # Line 1-6 + line_1_6_points.append((idx, kp)) + elif 6 <= idx <= 7: # Line 7-8 + line_7_8_points.append((idx, kp)) + elif 9 <= idx <= 12: # Line 10-13 + line_10_13_points.append((idx, kp)) + + kp_9 = get_kp(8) # Keypoint 9 + if kp_9: + left_side_all.append((8, kp_9)) + + total_left_side_count = len(left_side_all) + + # If we have 6 or more points, no need to calculate more + if total_left_side_count >= 6: + pass # Don't calculate more points + elif total_left_side_count == 5: + # Check if 4 points are on one line and 1 on another line + counts_per_line = [ + len(line_1_6_points), + len(line_7_8_points), + len(line_10_13_points) + ] + + if max(counts_per_line) == 4 and sum(counts_per_line) == 4: + # 4 points on one line, need to calculate 1 more point on another line + # Determine which line has 4 points and calculate on a different line + if len(line_1_6_points) == 4: + # All 4 on line 1-6, calculate on line 10-13 or 7-8 + # Prefer line 10-13 (right edge of left side) + if len(line_10_13_points) == 0: + # Calculate a point on line 10-13 + # Fit line through 1-6 points + points_1_6 = np.array([[kp[0], kp[1]] for _, kp in line_1_6_points]) + x_coords = points_1_6[:, 0] + y_coords = points_1_6[:, 1] + A = np.vstack([x_coords, np.ones(len(x_coords))]).T + m_1_6, b_1_6 = np.linalg.lstsq(A, y_coords, rcond=None)[0] + + # Calculate a point on line 10-13 (parallel to 1-6) + # Use template y-coordinate for one of 10-13 points + template_ys_10_13 = [140, 270, 410, 540] # Template y for 10-13 + template_indices_10_13 = [9, 10, 11, 12] + + # Use median y from 1-6 points to estimate scale + median_y = np.median(y_coords) + + # Calculate x using parallel line geometry + # In template: line 10-13 is at x=165, line 1-6 is at x=5 + # Ratio: 165/5 = 33 + if abs(m_1_6) > 1e-6: + x_on_line_1_6 = (median_y - b_1_6) / m_1_6 + x_new = int(round(x_on_line_1_6 * 33)) + else: + x_new = int(round(np.median(x_coords) * 33)) + + # Find first missing index in 10-13 range + for template_y, idx in zip(template_ys_10_13, template_indices_10_13): + if result[idx] is None: + result[idx] = (x_new, int(round(median_y))) + break + elif len(line_10_13_points) == 4: + # All 4 on line 10-13, calculate on line 1-6 + # Similar logic but in reverse + points_10_13 = np.array([[kp[0], kp[1]] for _, kp in line_10_13_points]) + x_coords = points_10_13[:, 0] + y_coords = points_10_13[:, 1] + A = np.vstack([x_coords, np.ones(len(x_coords))]).T + m_10_13, b_10_13 = np.linalg.lstsq(A, y_coords, rcond=None)[0] + + # Calculate a point on line 1-6 + template_ys_1_6 = [5, 140, 250, 430, 540, 675] # Template y for 1-6 + template_indices_1_6 = [0, 1, 2, 3, 4, 5] + + median_y = np.median(y_coords) + + # Calculate x using parallel line geometry + # Ratio: 5/165 ≈ 0.0303 + if abs(m_10_13) > 1e-6: + x_on_line_10_13 = (median_y - b_10_13) / m_10_13 + x_new = int(round(x_on_line_10_13 * 0.0303)) + else: + x_new = int(round(np.median(x_coords) * 0.0303)) + + for template_y, idx in zip(template_ys_1_6, template_indices_1_6): + if result[idx] is None: + result[idx] = (x_new, int(round(median_y))) + break + elif total_left_side_count < 5: + # Need to calculate missing keypoints to get exactly 5 points + # Requirements: + # 1. Must have keypoint 9 (if possible) + # 2. 4 points shouldn't be all on one line (need distribution) + + # Template coordinates for reference + template_coords_left = { + 0: (5, 5), # 1 + 1: (5, 140), # 2 + 2: (5, 250), # 3 + 3: (5, 430), # 4 + 4: (5, 540), # 5 + 5: (5, 675), # 6 + 6: (55, 250), # 7 + 7: (55, 430), # 8 + 8: (110, 340), # 9 (what we're calculating) + 9: (165, 140), # 10 + 10: (165, 270), # 11 + 11: (165, 410), # 12 + 12: (165, 540), # 13 + } + + # Define line groups (vertical and horizontal lines) + # Vertical lines: 1-6, 7-8, 10-13 + # Horizontal lines: 2-10, 3-7, 4-8, 5-13 + line_groups_left = { + '1-6': ([0, 1, 2, 3, 4, 5], 'vertical'), # indices: 1, 2, 3, 4, 5, 6 + '7-8': ([6, 7], 'vertical'), # indices: 7, 8 + '10-13': ([9, 10, 11, 12], 'vertical'), # indices: 10, 11, 12, 13 + '2-10': ([1, 9], 'horizontal'), # indices: 2, 10 + '3-7': ([2, 6], 'horizontal'), # indices: 3, 7 + '4-8': ([3, 7], 'horizontal'), # indices: 4, 8 + '5-13': ([4, 12], 'horizontal'), # indices: 5, 13 + } + + # Collect all available points with their indices + all_available_points_left = {} + for idx, kp in line_1_6_points: + all_available_points_left[idx] = kp + for idx, kp in line_7_8_points: + all_available_points_left[idx] = kp + for idx, kp in line_10_13_points: + all_available_points_left[idx] = kp + + # Step 1: Find the best vertical line and best horizontal line separately + best_vertical_line_name_left = None + best_vertical_line_points_left = [] + max_vertical_points_left = 1 + + best_horizontal_line_name_left = None + best_horizontal_line_points_left = [] + max_horizontal_points_left = 1 + + for line_name, (indices, line_type) in line_groups_left.items(): + line_points = [(idx, all_available_points_left[idx]) for idx in indices if idx in all_available_points_left] + if line_type == 'vertical' and len(line_points) > max_vertical_points_left: + max_vertical_points_left = len(line_points) + best_vertical_line_name_left = line_name + best_vertical_line_points_left = line_points + elif line_type == 'horizontal' and len(line_points) > max_horizontal_points_left: + max_horizontal_points_left = len(line_points) + best_horizontal_line_name_left = line_name + best_horizontal_line_points_left = line_points + + # Check and calculate missing points on detected lines + # For vertical lines + if best_vertical_line_name_left is not None: + expected_indices = line_groups_left[best_vertical_line_name_left][0] + detected_indices = {idx for idx, _ in best_vertical_line_points_left} + missing_indices = [idx for idx in expected_indices if idx not in detected_indices] + + if len(missing_indices) > 0: + # Calculate missing points using template ratios + template_start = template_coords_left[best_vertical_line_points_left[0][0]] + template_end = template_coords_left[best_vertical_line_points_left[-1][0]] + frame_start = best_vertical_line_points_left[0][1] + frame_end = best_vertical_line_points_left[-1][1] + + for missing_idx in missing_indices: + template_missing = template_coords_left[missing_idx] + + # Calculate ratio along the line based on y-coordinate (vertical line) + template_y_start = template_start[1] + template_y_end = template_end[1] + template_y_missing = template_missing[1] + + if abs(template_y_end - template_y_start) > 1e-6: + ratio = (template_y_missing - template_y_start) / (template_y_end - template_y_start) + else: + ratio = 0.5 + + # Calculate frame coordinates + x_new = frame_start[0] + (frame_end[0] - frame_start[0]) * ratio + y_new = frame_start[1] + (frame_end[1] - frame_start[1]) * ratio + new_point = (int(round(x_new)), int(round(y_new))) + + # Add to result and update collections + result[missing_idx] = new_point + best_vertical_line_points_left.append((missing_idx, new_point)) + all_available_points_left[missing_idx] = new_point + total_left_side_count += 1 + max_vertical_points_left = len(best_vertical_line_points_left) + + # Sort by index to maintain order + best_vertical_line_points_left.sort(key=lambda x: x[0]) + + # Check if we can now form a horizontal line with the newly calculated points + for line_name, (indices, line_type) in line_groups_left.items(): + if line_type == 'horizontal': + line_points = [(idx, all_available_points_left[idx]) for idx in indices if idx in all_available_points_left] + if len(line_points) > max_horizontal_points_left: + max_horizontal_points_left = len(line_points) + best_horizontal_line_name_left = line_name + best_horizontal_line_points_left = line_points + + # For horizontal lines + if best_horizontal_line_name_left is not None: + expected_indices = line_groups_left[best_horizontal_line_name_left][0] + detected_indices = {idx for idx, _ in best_horizontal_line_points_left} + missing_indices = [idx for idx in expected_indices if idx not in detected_indices] + + if len(missing_indices) > 0: + # Calculate missing points using template ratios + template_start = template_coords_left[best_horizontal_line_points_left[0][0]] + template_end = template_coords_left[best_horizontal_line_points_left[-1][0]] + frame_start = best_horizontal_line_points_left[0][1] + frame_end = best_horizontal_line_points_left[-1][1] + + for missing_idx in missing_indices: + template_missing = template_coords_left[missing_idx] + + # Calculate ratio along the line based on x-coordinate (horizontal line) + template_x_start = template_start[0] + template_x_end = template_end[0] + template_x_missing = template_missing[0] + + if abs(template_x_end - template_x_start) > 1e-6: + ratio = (template_x_missing - template_x_start) / (template_x_end - template_x_start) + else: + ratio = 0.5 + + # Calculate frame coordinates + x_new = frame_start[0] + (frame_end[0] - frame_start[0]) * ratio + y_new = frame_start[1] + (frame_end[1] - frame_start[1]) * ratio + new_point = (int(round(x_new)), int(round(y_new))) + + # Add to result and update collections + result[missing_idx] = new_point + best_horizontal_line_points_left.append((missing_idx, new_point)) + all_available_points_left[missing_idx] = new_point + total_left_side_count += 1 + max_horizontal_points_left = len(best_horizontal_line_points_left) + + # Sort by index to maintain order + best_horizontal_line_points_left.sort(key=lambda x: x[0]) + + # Check if we can now form a vertical line with the newly calculated points + for line_name, (indices, line_type) in line_groups_left.items(): + if line_type == 'vertical': + line_points = [(idx, all_available_points_left[idx]) for idx in indices if idx in all_available_points_left] + if len(line_points) > max_vertical_points_left: + max_vertical_points_left = len(line_points) + best_vertical_line_name_left = line_name + best_vertical_line_points_left = line_points + + # If we only have one direction, try to calculate the other direction line + # Similar logic to right side, adapted for left side structure + if best_vertical_line_name_left is not None and best_horizontal_line_name_left is None: + # We have vertical line but no horizontal line + # Find an off-line point (not on the vertical line) + off_line_point = None + off_line_idx = None + vertical_line_indices = line_groups_left[best_vertical_line_name_left][0] + for idx, kp in all_available_points_left.items(): + if idx not in vertical_line_indices: + off_line_point = kp + off_line_idx = idx + break + + if off_line_point is not None: + # Convert off_line_point to numpy array for arithmetic operations + off_line_point = np.array(off_line_point) + + # Project off_line_point onto vertical line + template_off_line = template_coords_left[off_line_idx] + + template_vertical_start_index = best_vertical_line_points_left[0][0] + template_vertical_end_index = best_vertical_line_points_left[-1][0] + + template_vertical_start = template_coords_left[template_vertical_start_index] + template_vertical_end = template_coords_left[template_vertical_end_index] + + # Project at same y as off_line_point + template_y_off = template_off_line[1] + template_y_vertical_start = template_vertical_start[1] + template_y_vertical_end = template_vertical_end[1] + + if abs(template_y_vertical_end - template_y_vertical_start) > 1e-6: + ratio_proj = (template_y_off - template_y_vertical_start) / (template_y_vertical_end - template_y_vertical_start) + else: + ratio_proj = 0.5 + + frame_vertical_start = best_vertical_line_points_left[0][1] + frame_vertical_end = best_vertical_line_points_left[-1][1] + proj_x = frame_vertical_start[0] + (frame_vertical_end[0] - frame_vertical_start[0]) * ratio_proj + proj_y = frame_vertical_start[1] + (frame_vertical_end[1] - frame_vertical_start[1]) * ratio_proj + proj_point = np.array([proj_x, proj_y]) + + # Calculate horizontal line points based on which vertical line we have + if best_vertical_line_name_left == '10-13': + # Line 10-13: can calculate points on horizontal lines 2-10, 5-13 + if off_line_idx == 1: # Point 2 (index 1) is off-line, calculate point 10 (index 9) + kp_10 = np.array(best_vertical_line_points_left[0][1]) # 10 point + kp_2 = off_line_point + (kp_10 - proj_point) + result[1] = tuple(kp_2.astype(int)) + total_left_side_count += 1 + all_available_points_left[1] = tuple(kp_2.astype(int)) + elif off_line_idx == 4: # Point 5 (index 4) is off-line, calculate point 13 (index 12) + kp_13 = np.array(best_vertical_line_points_left[-1][1]) # 13 point + kp_5 = off_line_point + (kp_13 - proj_point) + result[4] = tuple(kp_5.astype(int)) + total_left_side_count += 1 + all_available_points_left[4] = tuple(kp_5.astype(int)) + + elif best_vertical_line_name_left == '1-6': + # Line 1-6: can calculate points on horizontal lines 2-10, 3-7, 4-8, 5-13 + if off_line_idx == 6 or off_line_idx == 7: # Point 7 or 8 is off-line, calculate point 3 or 4 + template_off = template_coords_left[off_line_idx] + template_3 = template_coords_left[2] # 3 point, index 2 + template_4 = template_coords_left[3] # 4 point, index 3 + template_7 = template_coords_left[6] # 7 point, index 6 + template_8 = template_coords_left[7] # 8 point, index 7 + + if off_line_idx == 6: # Point 7, calculate point 3 + ratio = (template_3[0] - template_7[0]) / (template_7[0] - template_off[0]) if abs(template_7[0] - template_off[0]) > 1e-6 else 0.5 + kp_3 = proj_point + (off_line_point - proj_point) * ratio + result[2] = tuple(kp_3.astype(int)) + total_left_side_count += 1 + all_available_points_left[2] = tuple(kp_3.astype(int)) + else: # Point 8, calculate point 4 + ratio = (template_4[0] - template_8[0]) / (template_8[0] - template_off[0]) if abs(template_8[0] - template_off[0]) > 1e-6 else 0.5 + kp_4 = proj_point + (off_line_point - proj_point) * ratio + result[3] = tuple(kp_4.astype(int)) + total_left_side_count += 1 + all_available_points_left[3] = tuple(kp_4.astype(int)) + elif off_line_idx == 9 or off_line_idx == 12: # Point 10 or 13 is off-line, calculate point 2 or 5 + if off_line_idx == 9: # Point 10, calculate point 2 + kp_2 = off_line_point + (np.array(best_vertical_line_points_left[1][1]) - proj_point) + result[1] = tuple(kp_2.astype(int)) + total_left_side_count += 1 + all_available_points_left[1] = tuple(kp_2.astype(int)) + else: # Point 13, calculate point 5 + kp_5 = off_line_point + (np.array(best_vertical_line_points_left[4][1]) - proj_point) + result[4] = tuple(kp_5.astype(int)) + total_left_side_count += 1 + all_available_points_left[4] = tuple(kp_5.astype(int)) + + elif best_vertical_line_name_left == '7-8': + # Line 7-8: can calculate points on horizontal lines 3-7, 4-8 + if off_line_idx == 2 or off_line_idx == 3: # Point 3 or 4 is off-line, calculate point 7 or 8 + if off_line_idx == 2: # Point 3, calculate point 7 + kp_7 = off_line_point + (np.array(best_vertical_line_points_left[0][1]) - proj_point) + result[6] = tuple(kp_7.astype(int)) + total_left_side_count += 1 + all_available_points_left[6] = tuple(kp_7.astype(int)) + else: # Point 4, calculate point 8 + kp_8 = off_line_point + (np.array(best_vertical_line_points_left[-1][1]) - proj_point) + result[7] = tuple(kp_8.astype(int)) + total_left_side_count += 1 + all_available_points_left[7] = tuple(kp_8.astype(int)) + + # Check if we can now form a horizontal line with the newly calculated points + for line_name, (indices, line_type) in line_groups_left.items(): + if line_type == 'horizontal': + line_points = [(idx, all_available_points_left[idx]) for idx in indices if idx in all_available_points_left] + if len(line_points) > max_horizontal_points_left: + max_horizontal_points_left = len(line_points) + best_horizontal_line_name_left = line_name + best_horizontal_line_points_left = line_points + + elif best_horizontal_line_name_left is not None and best_vertical_line_name_left is None: + # We have horizontal line but no vertical line + # Find an off-line point (not on the horizontal line) + off_line_point = None + off_line_idx = None + horizontal_line_indices = line_groups_left[best_horizontal_line_name_left][0] + for idx, kp in all_available_points_left.items(): + if idx not in horizontal_line_indices: + off_line_point = kp + off_line_idx = idx + break + + if off_line_point is not None: + # Project off_line_point onto horizontal line + template_off_line = template_coords_left[off_line_idx] + template_horizontal_start = template_coords_left[best_horizontal_line_points_left[0][0]] + template_horizontal_end = template_coords_left[best_horizontal_line_points_left[-1][0]] + + # Project at same x as off_line_point + template_x_off = template_off_line[0] + template_x_horizontal_start = template_horizontal_start[0] + template_x_horizontal_end = template_horizontal_end[0] + + if abs(template_x_horizontal_end - template_x_horizontal_start) > 1e-6: + ratio_proj = (template_x_off - template_x_horizontal_start) / (template_x_horizontal_end - template_x_horizontal_start) + else: + ratio_proj = 0.5 + + frame_horizontal_start = best_horizontal_line_points_left[0][1] + frame_horizontal_end = best_horizontal_line_points_left[-1][1] + proj_x = frame_horizontal_start[0] + (frame_horizontal_end[0] - frame_horizontal_start[0]) * ratio_proj + proj_y = frame_horizontal_start[1] + (frame_horizontal_end[1] - frame_horizontal_start[1]) * ratio_proj + proj_point = np.array([proj_x, proj_y]) + off_line_point = np.array(off_line_point) + + # Calculate vertical line points based on which horizontal line we have + if best_horizontal_line_name_left == '2-10': + # Line 2-10: can calculate points on vertical lines 1-6, 10-13 + if off_line_idx == 0 or off_line_idx == 5: # Point 1 or 6 is off-line, calculate point 2 + kp_2 = off_line_point + (np.array(best_horizontal_line_points_left[0][1]) - proj_point) + result[1] = tuple(kp_2.astype(int)) + total_left_side_count += 1 + all_available_points_left[1] = tuple(kp_2.astype(int)) + elif off_line_idx == 9 or off_line_idx == 12: # Point 10 or 13 is off-line, calculate point 10 + kp_10 = off_line_point + (np.array(best_horizontal_line_points_left[-1][1]) - proj_point) + result[9] = tuple(kp_10.astype(int)) + total_left_side_count += 1 + all_available_points_left[9] = tuple(kp_10.astype(int)) + + elif best_horizontal_line_name_left == '3-7': + # Line 3-7: can calculate points on vertical lines 1-6, 7-8 + if off_line_idx == 0 or off_line_idx == 5: # Point 1 or 6 is off-line, calculate point 3 + kp_3 = off_line_point + (np.array(best_horizontal_line_points_left[0][1]) - proj_point) + result[2] = tuple(kp_3.astype(int)) + total_left_side_count += 1 + all_available_points_left[2] = tuple(kp_3.astype(int)) + elif off_line_idx == 6 or off_line_idx == 7: # Point 7 or 8 is off-line, calculate point 7 + kp_7 = off_line_point + (np.array(best_horizontal_line_points_left[-1][1]) - proj_point) + result[6] = tuple(kp_7.astype(int)) + total_left_side_count += 1 + all_available_points_left[6] = tuple(kp_7.astype(int)) + + elif best_horizontal_line_name_left == '4-8': + # Line 4-8: can calculate points on vertical lines 1-6, 7-8 + if off_line_idx == 0 or off_line_idx == 5: # Point 1 or 6 is off-line, calculate point 4 + kp_4 = off_line_point + (np.array(best_horizontal_line_points_left[0][1]) - proj_point) + result[3] = tuple(kp_4.astype(int)) + total_left_side_count += 1 + all_available_points_left[3] = tuple(kp_4.astype(int)) + elif off_line_idx == 6 or off_line_idx == 7: # Point 7 or 8 is off-line, calculate point 8 + kp_8 = off_line_point + (np.array(best_horizontal_line_points_left[-1][1]) - proj_point) + result[7] = tuple(kp_8.astype(int)) + total_left_side_count += 1 + all_available_points_left[7] = tuple(kp_8.astype(int)) + + elif best_horizontal_line_name_left == '5-13': + # Line 5-13: can calculate points on vertical lines 1-6, 10-13 + if off_line_idx == 0 or off_line_idx == 5: # Point 1 or 6 is off-line, calculate point 5 + kp_5 = off_line_point + (np.array(best_horizontal_line_points_left[0][1]) - proj_point) + result[4] = tuple(kp_5.astype(int)) + total_left_side_count += 1 + all_available_points_left[4] = tuple(kp_5.astype(int)) + elif off_line_idx == 9 or off_line_idx == 12: # Point 10 or 13 is off-line, calculate point 13 + kp_13 = off_line_point + (np.array(best_horizontal_line_points_left[-1][1]) - proj_point) + result[12] = tuple(kp_13.astype(int)) + total_left_side_count += 1 + all_available_points_left[12] = tuple(kp_13.astype(int)) + + # Check if we can now form a vertical line with the newly calculated points + for line_name, (indices, line_type) in line_groups_left.items(): + if line_type == 'vertical': + line_points = [(idx, all_available_points_left[idx]) for idx in indices if idx in all_available_points_left] + if len(line_points) > max_vertical_points_left: + max_vertical_points_left = len(line_points) + best_vertical_line_name_left = line_name + best_vertical_line_points_left = line_points + + # Calculate keypoint 9 if we have at least one line + if best_vertical_line_name_left is not None and best_horizontal_line_name_left is not None: + if kp_9 is None: + print(f"Calculating keypoint 9 using both vertical and horizontal lines: {best_vertical_line_name_left} and {best_horizontal_line_name_left}") + + template_x_9 = 110 + template_y_9 = 340 + + # Project keypoint 9 onto vertical line + template_vertical_start = template_coords_left[best_vertical_line_points_left[0][0]] + template_vertical_end = template_coords_left[best_vertical_line_points_left[-1][0]] + + # Project at y=340 (same y as keypoint 9) + template_y_vertical_start = template_vertical_start[1] + template_y_vertical_end = template_vertical_end[1] + + if abs(template_y_vertical_end - template_y_vertical_start) > 1e-6: + ratio_9_vertical = (template_y_9 - template_y_vertical_start) / (template_y_vertical_end - template_y_vertical_start) + else: + ratio_9_vertical = 0.5 + + frame_vertical_start = best_vertical_line_points_left[0][1] + frame_vertical_end = best_vertical_line_points_left[-1][1] + proj_9_on_vertical_x = frame_vertical_start[0] + (frame_vertical_end[0] - frame_vertical_start[0]) * ratio_9_vertical + proj_9_on_vertical_y = frame_vertical_start[1] + (frame_vertical_end[1] - frame_vertical_start[1]) * ratio_9_vertical + proj_9_on_vertical = (proj_9_on_vertical_x, proj_9_on_vertical_y) + + # Project keypoint 9 onto horizontal line + template_horizontal_start = template_coords_left[best_horizontal_line_points_left[0][0]] + template_horizontal_end = template_coords_left[best_horizontal_line_points_left[-1][0]] + + # Project at x=110 (same x as keypoint 9) + template_x_horizontal_start = template_horizontal_start[0] + template_x_horizontal_end = template_horizontal_end[0] + + if abs(template_x_horizontal_end - template_x_horizontal_start) > 1e-6: + ratio_9_horizontal = (template_x_9 - template_x_horizontal_start) / (template_x_horizontal_end - template_x_horizontal_start) + else: + ratio_9_horizontal = 0.5 + + frame_horizontal_start = best_horizontal_line_points_left[0][1] + frame_horizontal_end = best_horizontal_line_points_left[-1][1] + proj_9_on_horizontal_x = frame_horizontal_start[0] + (frame_horizontal_end[0] - frame_horizontal_start[0]) * ratio_9_horizontal + proj_9_on_horizontal_y = frame_horizontal_start[1] + (frame_horizontal_end[1] - frame_horizontal_start[1]) * ratio_9_horizontal + proj_9_on_horizontal = (proj_9_on_horizontal_x, proj_9_on_horizontal_y) + + # Calculate keypoint 9 as intersection of two lines + # Line 1: Passes through proj_9_on_vertical, parallel to best_horizontal_line + # Line 2: Passes through proj_9_on_horizontal, parallel to best_vertical_line + + # Calculate direction vector of best_horizontal_line + horizontal_dir_x = frame_horizontal_end[0] - frame_horizontal_start[0] + horizontal_dir_y = frame_horizontal_end[1] - frame_horizontal_start[1] + horizontal_dir_length = np.sqrt(horizontal_dir_x**2 + horizontal_dir_y**2) + + # Calculate direction vector of best_vertical_line + vertical_dir_x = frame_vertical_end[0] - frame_vertical_start[0] + vertical_dir_y = frame_vertical_end[1] - frame_vertical_start[1] + vertical_dir_length = np.sqrt(vertical_dir_x**2 + vertical_dir_y**2) + + if horizontal_dir_length > 1e-6 and vertical_dir_length > 1e-6: + # Normalize direction vectors + horizontal_dir_x /= horizontal_dir_length + horizontal_dir_y /= horizontal_dir_length + vertical_dir_x /= vertical_dir_length + vertical_dir_y /= vertical_dir_length + + # Find intersection: proj_9_on_vertical + t * horizontal_dir = proj_9_on_horizontal + s * vertical_dir + A = np.array([ + [horizontal_dir_x, -vertical_dir_x], + [horizontal_dir_y, -vertical_dir_y] + ]) + b = np.array([ + proj_9_on_horizontal[0] - proj_9_on_vertical[0], + proj_9_on_horizontal[1] - proj_9_on_vertical[1] + ]) + + try: + t, s = np.linalg.solve(A, b) + + # Calculate intersection point using line 1 + x_9 = proj_9_on_vertical[0] + t * horizontal_dir_x + y_9 = proj_9_on_vertical[1] + t * horizontal_dir_y + + result[8] = (int(round(x_9)), int(round(y_9))) + total_left_side_count += 1 + except np.linalg.LinAlgError: + # Lines are parallel or nearly parallel, use simple intersection + x_9 = proj_9_on_vertical[0] + y_9 = proj_9_on_horizontal[1] + result[8] = (int(round(x_9)), int(round(y_9))) + total_left_side_count += 1 + else: + # Fallback: use simple intersection + x_9 = proj_9_on_vertical[0] + y_9 = proj_9_on_horizontal[1] + result[8] = (int(round(x_9)), int(round(y_9))) + total_left_side_count += 1 + + print(f"total_left_side_count: {total_left_side_count}, result: {result}") + if total_left_side_count > 5: + pass # Continue to right side logic + + # Calculate m_line and b_line from best vertical or horizontal line for use in calculating other points + m_line_left = None + b_line_left = None + best_line_for_calc_left = None + best_line_type_for_calc_left = None + + if best_vertical_line_name_left is not None and len(best_vertical_line_points_left) >= 2: + best_line_for_calc_left = best_vertical_line_points_left + best_line_type_for_calc_left = 'vertical' + points_array = np.array([[kp[0], kp[1]] for _, kp in best_vertical_line_points_left]) + x_coords = points_array[:, 0] + y_coords = points_array[:, 1] + A = np.vstack([x_coords, np.ones(len(x_coords))]).T + m_line_left, b_line_left = np.linalg.lstsq(A, y_coords, rcond=None)[0] + elif best_horizontal_line_name_left is not None and len(best_horizontal_line_points_left) >= 2: + best_line_for_calc_left = best_horizontal_line_points_left + best_line_type_for_calc_left = 'horizontal' + points_array = np.array([[kp[0], kp[1]] for _, kp in best_horizontal_line_points_left]) + x_coords = points_array[:, 0] + y_coords = points_array[:, 1] + A = np.vstack([x_coords, np.ones(len(x_coords))]).T + m_line_left, b_line_left = np.linalg.lstsq(A, y_coords, rcond=None)[0] + + # Calculate missing points to reach exactly 5 points + # Ensure 4 points aren't all on one line + if total_left_side_count < 5 and (m_line_left is not None or (best_line_for_calc_left is not None and best_line_type_for_calc_left == 'vertical')): + # Check current distribution + counts_per_line = [ + len(line_1_6_points), + len(line_7_8_points), + len(line_10_13_points) + ] + + # Calculate points on line 1-6 if needed + template_ys_1_6 = [5, 140, 250, 430, 540, 675] + template_indices_1_6 = [0, 1, 2, 3, 4, 5] + + if best_vertical_line_name_left == '10-13': + # Construct parallel line 1-6 from line 10-13 + for template_y, idx in zip(template_ys_1_6, template_indices_1_6): + if result[idx] is None and total_left_side_count < 5: + # Check if adding this point would put 4 on one line + new_counts = counts_per_line.copy() + new_counts[0] += 1 # Adding to line 1-6 + if max(new_counts) >= 4 and total_left_side_count == 4: + # Would have 4 on one line, skip + continue + + # Calculate y using scale from template + ref_ys = [kp[1] for _, kp in line_10_13_points] + ref_template_ys = [140, 270, 410, 540] + ref_indices = [9, 10, 11, 12] + + matched_template_ys = [] + for ref_idx, ref_kp in line_10_13_points: + if ref_idx in ref_indices: + template_idx = ref_indices.index(ref_idx) + matched_template_ys.append((ref_template_ys[template_idx], ref_kp[1])) + + if len(matched_template_ys) >= 1: + ref_template_y, ref_frame_y = matched_template_ys[0] + if ref_template_y > 0: + scale = ref_frame_y / ref_template_y + y_new = int(round(template_y * scale)) + else: + y_new = ref_frame_y + else: + y_new = int(round(np.median(ref_ys))) if ref_ys else template_y + + # Calculate x using parallel line geometry + if abs(m_line_left) > 1e-6: + x_on_line_10_13 = (y_new - b_line_left) / m_line_left + x_new = int(round(x_on_line_10_13 * 0.0303)) # 5/165 + else: + x_new = int(round(np.median([kp[0] for _, kp in line_10_13_points]) * 0.0303)) + + result[idx] = (x_new, y_new) + total_left_side_count += 1 + if total_left_side_count >= 5: + break + elif best_vertical_line_name_left == '1-6': + # Calculate missing points on line 1-6 + for template_y, idx in zip(template_ys_1_6, template_indices_1_6): + if result[idx] is None and total_left_side_count < 5: + # Check if adding this point would put 4 on one line + new_counts = counts_per_line.copy() + new_counts[0] += 1 # Adding to line 1-6 + if max(new_counts) >= 4 and total_left_side_count == 4: + # Would have 4 on one line, skip + continue + + # Calculate x on the line + if abs(m_line_left) > 1e-6: + x_new = (template_y - b_line_left) / m_line_left + else: + x_new = np.median([kp[0] for _, kp in line_1_6_points]) + + # Scale y based on available points + ref_ys = [kp[1] for _, kp in line_1_6_points] + ref_template_ys = [] + for ref_idx, _ in line_1_6_points: + if ref_idx in template_indices_1_6: + template_idx = template_indices_1_6.index(ref_idx) + ref_template_ys.append(template_ys_1_6[template_idx]) + + if len(ref_ys) >= 1 and len(ref_template_ys) >= 1: + ref_template_y = ref_template_ys[0] + ref_frame_y = ref_ys[0] + if ref_template_y > 0: + scale = ref_frame_y / ref_template_y + y_new = int(round(template_y * scale)) + else: + y_new = ref_frame_y + else: + y_new = int(round(np.median(ref_ys))) if ref_ys else template_y + + result[idx] = (int(round(x_new)), y_new) + total_left_side_count += 1 + if total_left_side_count >= 5: + break + + print(f"total_left_side_count: {total_left_side_count}, result: {result}") + + # Case 2: Unified handling of right side keypoints (18-30) + # Three parallel lines on right side: + # - Line 18-21: keypoints 18, 19, 20, 21 (indices 17-20) + # - Line 23-24: keypoints 23, 24 (indices 22-23) + # - Line 25-30: keypoints 25, 26, 27, 28, 29, 30 (indices 24-29) + # Keypoint 22 (index 21) is between line 18-21 and line 25-30 + + # Collect all right-side keypoints (18-30, indices 17-29) + right_side_all = [] + line_18_21_points = [] # Indices 17-20 + line_23_24_points = [] # Indices 22-23 + line_25_30_points = [] # Indices 24-29 + + for idx in range(17, 30): # Keypoints 18-30 (indices 17-29) + kp = get_kp(idx) + if kp: + right_side_all.append((idx, kp)) + if 17 <= idx <= 20: # Line 18-21 + line_18_21_points.append((idx, kp)) + elif 22 <= idx <= 23: # Line 23-24 + line_23_24_points.append((idx, kp)) + elif 24 <= idx <= 29: # Line 25-30 + line_25_30_points.append((idx, kp)) + + kp_22 = get_kp(21) # Keypoint 22 + if kp_22: + right_side_all.append((21, kp_22)) + + total_right_side_count = len(right_side_all) + + # If we have 6 or more points, no need to calculate more + if total_right_side_count >= 6: + pass # Don't calculate more points + elif total_right_side_count == 5: + # Check if 4 points are on one line and 1 on another line + counts_per_line = [ + len(line_18_21_points), + len(line_23_24_points), + len(line_25_30_points) + ] + + if max(counts_per_line) == 4 and sum(counts_per_line) == 4: + # 4 points on one line, need to calculate 1 more point on another line + # Determine which line has 4 points and calculate on a different line + if len(line_18_21_points) == 4: + # All 4 on line 18-21, calculate on line 25-30 or 23-24 + # Prefer line 25-30 (right edge) + if len(line_25_30_points) == 0: + # Calculate a point on line 25-30 + # Fit line through 18-21 points + points_18_21 = np.array([[kp[0], kp[1]] for _, kp in line_18_21_points]) + x_coords = points_18_21[:, 0] + y_coords = points_18_21[:, 1] + A = np.vstack([x_coords, np.ones(len(x_coords))]).T + m_18_21, b_18_21 = np.linalg.lstsq(A, y_coords, rcond=None)[0] + + # Calculate a point on line 25-30 (parallel to 18-21) + # Use template y-coordinate for one of 25-30 points + template_ys_25_30 = [5, 140, 250, 430, 540, 675] # Template y for 25-30 + template_indices_25_30 = [24, 25, 26, 27, 28, 29] + + # Use median y from 18-21 points to estimate scale + median_y = np.median(y_coords) + # Find closest template y + ref_template_y = min(template_ys_25_30, key=lambda ty: abs(ty - np.median([kp[1] for _, kp in line_18_21_points]))) + ref_idx = template_ys_25_30.index(ref_template_y) + + # Calculate y for the new point + y_new = int(round(median_y)) + + # Calculate x using parallel line geometry + # In template: line 25-30 is at x=1045, line 18-21 is at x=888 + # Ratio: 1045/888 ≈ 1.177 + if abs(m_18_21) > 1e-6: + x_on_line_18_21 = (y_new - b_18_21) / m_18_21 + x_new = int(round(x_on_line_18_21 * 1.177)) + else: + x_new = int(round(np.median(x_coords) * 1.177)) + + # Find first missing index in 25-30 range + for template_y, idx in zip(template_ys_25_30, template_indices_25_30): + if result[idx] is None: + result[idx] = (x_new, y_new) + break + elif len(line_25_30_points) == 4: + # All 4 on line 25-30, calculate on line 18-21 + # Similar logic but in reverse + points_25_30 = np.array([[kp[0], kp[1]] for _, kp in line_25_30_points]) + x_coords = points_25_30[:, 0] + y_coords = points_25_30[:, 1] + A = np.vstack([x_coords, np.ones(len(x_coords))]).T + m_25_30, b_25_30 = np.linalg.lstsq(A, y_coords, rcond=None)[0] + + # Calculate a point on line 18-21 + template_ys_18_21 = [140, 270, 410, 540] # Template y for 18-21 + template_indices_18_21 = [17, 18, 19, 20] + + median_y = np.median(y_coords) + + # Calculate x using parallel line geometry + # Ratio: 888/1045 ≈ 0.850 + if abs(m_25_30) > 1e-6: + x_on_line_25_30 = (median_y - b_25_30) / m_25_30 + x_new = int(round(x_on_line_25_30 * 0.850)) + else: + x_new = int(round(np.median(x_coords) * 0.850)) + + for template_y, idx in zip(template_ys_18_21, template_indices_18_21): + if result[idx] is None: + result[idx] = (x_new, int(round(median_y))) + break + elif total_right_side_count < 5: + # Need to calculate missing keypoints to get exactly 5 points + # Requirements: + # 1. Must have keypoint 22 + # 2. 4 points shouldn't be all on one line (need distribution) + + # Template coordinates for reference + template_coords = { + 17: (888, 140), # 18 + 18: (888, 270), # 19 + 19: (888, 410), # 20 + 20: (888, 540), # 21 + 21: (940, 340), # 22 (what we're calculating) + 22: (998, 250), # 23 + 23: (998, 430), # 24 + 24: (1045, 5), # 25 + 25: (1045, 140), # 26 + 26: (1045, 250), # 27 + 27: (1045, 430), # 28 + 28: (1045, 540), # 29 + 29: (1045, 675), # 30 + } + + # Define line groups (vertical and horizontal lines) + # Vertical lines: 18-21, 23-24, 25-30 + # Horizontal lines: 18-26, 23-27, 24-28, 21-29 + line_groups = { + '18-21': ([17, 18, 19, 20], 'vertical'), # indices: 18, 19, 20, 21 + '23-24': ([22, 23], 'vertical'), # indices: 23, 24 + '25-30': ([24, 25, 26, 27, 28, 29], 'vertical'), # indices: 25, 26, 27, 28, 29, 30 + '18-26': ([17, 25], 'horizontal'), # indices: 18, 26 + '23-27': ([22, 26], 'horizontal'), # indices: 23, 27 + '24-28': ([23, 27], 'horizontal'), # indices: 24, 28 + '21-29': ([20, 28], 'horizontal'), # indices: 21, 29 + } + + # Collect all available points with their indices + all_available_points = {} + for idx, kp in line_18_21_points: + all_available_points[idx] = kp + for idx, kp in line_23_24_points: + all_available_points[idx] = kp + for idx, kp in line_25_30_points: + all_available_points[idx] = kp + + # Step 1: Find the best vertical line and best horizontal line separately + best_vertical_line_name = None + best_vertical_line_points = [] + max_vertical_points = 1 + + best_horizontal_line_name = None + best_horizontal_line_points = [] + max_horizontal_points = 1 + + for line_name, (indices, line_type) in line_groups.items(): + line_points = [(idx, all_available_points[idx]) for idx in indices if idx in all_available_points] + if line_type == 'vertical' and len(line_points) > max_vertical_points: + max_vertical_points = len(line_points) + best_vertical_line_name = line_name + best_vertical_line_points = line_points + elif line_type == 'horizontal' and len(line_points) > max_horizontal_points: + max_horizontal_points = len(line_points) + best_horizontal_line_name = line_name + best_horizontal_line_points = line_points + + # Check and calculate missing points on detected lines + # For vertical lines + if best_vertical_line_name is not None: + expected_indices = line_groups[best_vertical_line_name][0] + detected_indices = {idx for idx, _ in best_vertical_line_points} + missing_indices = [idx for idx in expected_indices if idx not in detected_indices] + + if len(missing_indices) > 0: + # Calculate missing points using template ratios + template_start = template_coords[best_vertical_line_points[0][0]] + template_end = template_coords[best_vertical_line_points[-1][0]] + frame_start = best_vertical_line_points[0][1] + frame_end = best_vertical_line_points[-1][1] + + for missing_idx in missing_indices: + template_missing = template_coords[missing_idx] + + # Calculate ratio along the line based on y-coordinate (vertical line) + template_y_start = template_start[1] + template_y_end = template_end[1] + template_y_missing = template_missing[1] + + if abs(template_y_end - template_y_start) > 1e-6: + ratio = (template_y_missing - template_y_start) / (template_y_end - template_y_start) + else: + ratio = 0.5 + + # Calculate frame coordinates + x_new = frame_start[0] + (frame_end[0] - frame_start[0]) * ratio + y_new = frame_start[1] + (frame_end[1] - frame_start[1]) * ratio + new_point = (int(round(x_new)), int(round(y_new))) + + # Add to result and update collections + result[missing_idx] = new_point + best_vertical_line_points.append((missing_idx, new_point)) + all_available_points[missing_idx] = new_point + total_right_side_count += 1 + max_vertical_points = len(best_vertical_line_points) + + # Sort by index to maintain order + best_vertical_line_points.sort(key=lambda x: x[0]) + + # Check if we can now form a horizontal line with the newly calculated points + for line_name, (indices, line_type) in line_groups.items(): + if line_type == 'horizontal': + line_points = [(idx, all_available_points[idx]) for idx in indices if idx in all_available_points] + if len(line_points) > max_horizontal_points: + max_horizontal_points = len(line_points) + best_horizontal_line_name = line_name + best_horizontal_line_points = line_points + + # For horizontal lines + if best_horizontal_line_name is not None: + expected_indices = line_groups[best_horizontal_line_name][0] + detected_indices = {idx for idx, _ in best_horizontal_line_points} + missing_indices = [idx for idx in expected_indices if idx not in detected_indices] + + if len(missing_indices) > 0: + # Calculate missing points using template ratios + template_start = template_coords[best_horizontal_line_points[0][0]] + template_end = template_coords[best_horizontal_line_points[-1][0]] + frame_start = best_horizontal_line_points[0][1] + frame_end = best_horizontal_line_points[-1][1] + + for missing_idx in missing_indices: + template_missing = template_coords[missing_idx] + + # Calculate ratio along the line based on x-coordinate (horizontal line) + template_x_start = template_start[0] + template_x_end = template_end[0] + template_x_missing = template_missing[0] + + if abs(template_x_end - template_x_start) > 1e-6: + ratio = (template_x_missing - template_x_start) / (template_x_end - template_x_start) + else: + ratio = 0.5 + + # Calculate frame coordinates + x_new = frame_start[0] + (frame_end[0] - frame_start[0]) * ratio + y_new = frame_start[1] + (frame_end[1] - frame_start[1]) * ratio + new_point = (int(round(x_new)), int(round(y_new))) + + # Add to result and update collections + result[missing_idx] = new_point + best_horizontal_line_points.append((missing_idx, new_point)) + all_available_points[missing_idx] = new_point + total_right_side_count += 1 + max_horizontal_points = len(best_horizontal_line_points) + + # Sort by index to maintain order + best_horizontal_line_points.sort(key=lambda x: x[0]) + + # Check if we can now form a vertical line with the newly calculated points + for line_name, (indices, line_type) in line_groups.items(): + if line_type == 'vertical': + line_points = [(idx, all_available_points[idx]) for idx in indices if idx in all_available_points] + if len(line_points) > max_vertical_points: + max_vertical_points = len(line_points) + best_vertical_line_name = line_name + best_vertical_line_points = line_points + + # If we only have one direction, try to calculate the other direction line + if best_vertical_line_name is not None and best_horizontal_line_name is None: + # possible cases: + # line is 25-30 and off line point is 19, then we can calculate 18 so get horizontal line 18-26 + # line is 25-30 and off line point is 20, then we can calculate 18 so get horizontal line 18-26 + # line is 18-21 and off line point is 23, then we can calculate 27 so get horizontal line 23-27 + # line is 18-21 and off line point is 24, then we can calculate 28 so get horizontal line 24-28 + # line is 18-21 and off line point is 25, then we can calculate 26 so get horizontal line 18-26 + # line is 18-21 and off line point is 27, then we can calculate 26 so get horizontal line 18-26 + # line is 18-21 and off line point is 28, then we can calculate 29 so get horizontal line 21-29 + # line is 18-21 and off line point is 30, then we can calculate 29 so get horizontal line 21-29 + # line is 23-24 and off line point is 18, then we can calculate 26 so get horizontal line 18-26 + # line is 23-24 and off line point is 19, then we can calculate 18 so get horizontal line 18-26 + # line is 23-24 and off line point is 20, then we can calculate 21 so get horizontal line 21-29 + # line is 23-24 and off line point is 21, then we can calculate 29 so get horizontal line 21-29 + # line is 23-24 and off line point is 25, then we can calculate 27 so get horizontal line 23-27 + # line is 23-24 and off line point is 26, then we can calculate 27 so get horizontal line 23-27 + # line is 23-24 and off line point is 29, then we can calculate 28 so get horizontal line 24-28 + # line is 23-24 and off line point is 30, then we can calculate 28 so get horizontal line 24-28 + # We have vertical line but no horizontal line + # Find an off-line point (not on the vertical line) + off_line_point = None + off_line_idx = None + vertical_line_indices = line_groups[best_vertical_line_name][0] + for idx, kp in all_available_points.items(): + if idx not in vertical_line_indices: + off_line_point = kp + off_line_idx = idx + break + + if off_line_point is not None: + # Convert off_line_point to numpy array for arithmetic operations + off_line_point = np.array(off_line_point) + + # Project off_line_point onto vertical line + template_off_line = template_coords[off_line_idx] + + template_vertical_start_index = best_vertical_line_points[0][0] + template_vertical_end_index = best_vertical_line_points[-1][0] + + template_vertical_start = template_coords[template_vertical_start_index] + template_vertical_end = template_coords[template_vertical_end_index] + + # Project at same y as off_line_point + template_y_off = template_off_line[1] + template_y_vertical_start = template_vertical_start[1] + template_y_vertical_end = template_vertical_end[1] + + if abs(template_y_vertical_end - template_y_vertical_start) > 1e-6: + ratio_proj = (template_y_off - template_y_vertical_start) / (template_y_vertical_end - template_y_vertical_start) + else: + ratio_proj = 0.5 + + frame_vertical_start = best_vertical_line_points[0][1] + frame_vertical_end = best_vertical_line_points[-1][1] + proj_x = frame_vertical_start[0] + (frame_vertical_end[0] - frame_vertical_start[0]) * ratio_proj + proj_y = frame_vertical_start[1] + (frame_vertical_end[1] - frame_vertical_start[1]) * ratio_proj + proj_point = np.array([proj_x, proj_y]) + + if best_vertical_line_name == '25-30' and len(best_vertical_line_points) == 6: + if off_line_idx == 18 or off_line_idx == 19: # 19 or 20 point is off line point, so we can calculate 18 + kp_26 = np.array(best_vertical_line_points[1][1]) # 26 point + + kp_18 = off_line_point + (kp_26 - proj_point) + result[17] = tuple(kp_18.astype(int)) + total_right_side_count += 1 + all_available_points[17] = tuple(kp_18.astype(int)) # 18 point is now available, index is 17 + + if best_vertical_line_name == '18-21' and len(best_vertical_line_points) == 4: + if off_line_idx == 22 or off_line_idx == 23: # 23 or 24 point is off line point, so we can calculate 27 + template_19 = template_coords[18] # 19 point, index is 18 + template_23 = template_coords[22] # 23 point, index is 22 + template_27 = template_coords[26] # 27 point, index is 26 + + ratio = (template_27[0] - template_19[0]) / (template_23[0] - template_19[0]) # ratio in x coordinates because y coordinates are the same + + expected_point = proj_point + (off_line_point - proj_point) * ratio + + if off_line_idx == 22: + result[26] = tuple(expected_point.astype(int)) # 27 point, index is 26 + total_right_side_count += 1 + all_available_points[26] = tuple(expected_point.astype(int)) # 27 point is now available, index is 26 + else: + result[27] = tuple(expected_point.astype(int)) # 28 point, index is 27 + total_right_side_count += 1 + all_available_points[27] = tuple(expected_point.astype(int)) # 28 point is now available, index is 27 + + if off_line_idx == 24 or off_line_idx == 26: # 25 or 27 point is off line point, so we can calculate 26 + kp_18 = np.array(best_vertical_line_points[0][1]) # 18 point + kp_26 = off_line_point + (kp_18 - proj_point) + + result[25] = tuple(kp_26.astype(int)) + total_right_side_count += 1 + all_available_points[25] = tuple(kp_26.astype(int)) # 26 point is now available, index is 25 + + if off_line_idx == 27 or off_line_idx == 29: # 28 or 30 point is off line point, so we can calculate 29 + kp_21 = np.array(best_vertical_line_points[-1][1]) # 21 point + kp_29 = off_line_point + (kp_21 - proj_point) + + result[28] = tuple(kp_29.astype(int)) + total_right_side_count += 1 + all_available_points[28] = tuple(kp_29.astype(int)) # 29 point is now available, index is 28 + + + if best_vertical_line_name == '23-24' and len(best_vertical_line_points) == 2: + if off_line_idx == 17 or off_line_idx == 18 or off_line_idx == 19 or off_line_idx == 20: # 18 or 19 or 20 or 21 point is off line point, so we can calculate 26 + template_18 = template_coords[17] # 18 point, index is 17 + template_26 = template_coords[25] # 26 point, index is 25 + template_23 = template_coords[22] # 23 point, index is 22 + + ratio_26 = (template_26[0] - template_18[0]) / (template_23[0] - template_18[0]) # ratio in x coordinates because y coordinates are the same + + kp_18 = None + if off_line_idx == 17: + kp_18 = off_line_point + elif off_line_idx == 18 or off_line_idx == 19 or off_line_idx == 20: + template_off_line = template_coords[off_line_idx] + ratio = (template_18[1] - template_off_line[1]) / (template_23[1] - template_off_line[1]) + kp_18 = off_line_point + (np.array(best_vertical_line_points[0][1]) - proj_point) * ratio + + if kp_18 is not None: + kp_26 = kp_18 + (proj_point - off_line_point) * ratio_26 + result[25] = tuple(kp_26.astype(int)) + total_right_side_count += 1 + all_available_points[25] = tuple(kp_26.astype(int)) # 26 point is now available, index is 25 + + if off_line_idx == 24 or off_line_idx == 25: # 25 or 26 point is off line point, so we can calculate 27 + kp_27 = off_line_point + (np.array(best_vertical_line_points[0][1]) - proj_point) + + result[26] = tuple(kp_27.astype(int)) + total_right_side_count += 1 + all_available_points[26] = tuple(kp_27.astype(int)) # 27 point is now available, index is 26 + + if off_line_idx == 28 or off_line_idx == 29: # 29 or 30 point is off line point, so we can calculate 29 + kp_29 = off_line_point + (np.array(best_vertical_line_points[-1][1]) - proj_point) + + result[28] = tuple(kp_29.astype(int)) + total_right_side_count += 1 + all_available_points[28] = tuple(kp_29.astype(int)) # 29 point is now available, index is 28 + + + # Check if we can now form a horizontal line with the newly calculated points + for line_name, (indices, line_type) in line_groups.items(): + if line_type == 'horizontal': + line_points = [(idx, all_available_points[idx]) for idx in indices if idx in all_available_points] + if len(line_points) > max_horizontal_points: + max_horizontal_points = len(line_points) + best_horizontal_line_name = line_name + best_horizontal_line_points = line_points + + + elif best_horizontal_line_name is not None and best_vertical_line_name is None: + # possible cases: + # line is 18-26 and off line point is 23, then we can calculate 27 so get vertical line 25-30 + # line is 18-26 and off line point is 24, then we can calculate 28 so get vertical line 25-30 + # line is 23-27 and off line point is 18, then we can calculate 26 so get vertical line 25-30 + # line is 23-27 and off line point is 19, then we can calculate 18 so get vertical line 18-21 + # line is 23-27 and off line point is 20, then we can calculate 18 so get vertical line 18-21 + # line is 23-27 and off line point is 21, then we can calculate 29 so get vertical line 25-30 + # line is 24-28 and off line point is 18, then we can calculate 26 so get vertical line 25-30 + # line is 24-28 and off line point is 19, then we can calculate 21 so get vertical line 18-21 + # line is 24-28 and off line point is 20, then we can calculate 21 so get vertical line 18-21 + # line is 24-28 and off line point is 21, then we can calculate 29 so get vertical line 25-30 + # line is 21-29 and off line point is 23, then we can calculate 27 so get vertical line 25-30 + # line is 21-29 and off line point is 24, then we can calculate 28 so get vertical line 25-30 + # We have horizontal line but no vertical line + # Find an off-line point (not on the horizontal line) + off_line_point = None + off_line_idx = None + horizontal_line_indices = line_groups[best_horizontal_line_name][0] + for idx, kp in all_available_points.items(): + if idx not in horizontal_line_indices: + off_line_point = kp + off_line_idx = idx + break + + if off_line_point is not None: + # Project off_line_point onto horizontal line + template_off_line = template_coords[off_line_idx] + template_horizontal_start = template_coords[best_horizontal_line_points[0][0]] + template_horizontal_end = template_coords[best_horizontal_line_points[-1][0]] + + # Project at same x as off_line_point + template_x_off = template_off_line[0] + template_x_horizontal_start = template_horizontal_start[0] + template_x_horizontal_end = template_horizontal_end[0] + + if abs(template_x_horizontal_end - template_x_horizontal_start) > 1e-6: + ratio_proj = (template_x_off - template_x_horizontal_start) / (template_x_horizontal_end - template_x_horizontal_start) + else: + ratio_proj = 0.5 + + frame_horizontal_start = best_horizontal_line_points[0][1] + frame_horizontal_end = best_horizontal_line_points[-1][1] + proj_x = frame_horizontal_start[0] + (frame_horizontal_end[0] - frame_horizontal_start[0]) * ratio_proj + proj_y = frame_horizontal_start[1] + (frame_horizontal_end[1] - frame_horizontal_start[1]) * ratio_proj + proj_point = np.array([proj_x, proj_y]) + + if best_horizontal_line_name == '18-26': + if off_line_idx == 22 or off_line_idx == 23: # 23 or 24 point is off line point, so we can calculate 27 or 28 + template_18 = template_coords[best_horizontal_line_points[0][0]] # 18 point, index is 17 + template_26 = template_coords[best_horizontal_line_points[-1][0]] # 26 point, index is 25 + template_23 = template_coords[off_line_idx] # 23 or 24 point, index is 22 or 23 + + ratio_26 = (template_26[0] - template_23[0]) / (template_26[0] - template_18[0]) # ratio in x coordinates because y coordinates are the same + + detected_point = off_line_point + (np.array(best_horizontal_line_points[-1][1]) - np.array(best_horizontal_line_points[0][1])) * ratio_26 + + if off_line_idx == 22: + result[26] = tuple(detected_point.astype(int)) + total_right_side_count += 1 + all_available_points[26] = tuple(detected_point.astype(int)) # 26 point is now available, index is 26 + else: + result[27] = tuple(detected_point.astype(int)) + total_right_side_count += 1 + all_available_points[27] = tuple(detected_point.astype(int)) # 27 point is now available, index is 27 + + if best_horizontal_line_name == '23-27': + if off_line_idx == 17 or off_line_idx == 20: + template_18 = template_coords[17] # 18 point, index is 17 + template_26 = template_coords[25] # 26 point, index is 25 + template_23 = template_coords[best_horizontal_line_points[0][0]] # 23 , index is 22 + + ratio_26 = (template_26[0] - template_18[0]) / (template_26[0] - template_23[0]) # ratio in x coordinates because y coordinates are the same + + detected_point = off_line_point + (np.array(best_horizontal_line_points[-1][1]) - np.array(best_horizontal_line_points[0][1])) * ratio_26 + + if off_line_idx == 17: + result[25] = tuple(detected_point.astype(int)) + total_right_side_count += 1 + all_available_points[25] = tuple(detected_point.astype(int)) # 26 point is now available, index is 25 + else: + result[28] = tuple(detected_point.astype(int)) + total_right_side_count += 1 + all_available_points[28] = tuple(detected_point.astype(int)) # 29 point is now available, index is 28 + + if off_line_idx == 18 or off_line_idx == 19: # 19 or 20 point is off line point, so we can calculate 18 + template_18 = template_coords[17] # 18 point, index is 17 + template_off_line = template_coords[off_line_idx] + template_23 = template_coords[best_horizontal_line_points[0][0]] # 23 point, index is 22 + + ratio = (template_off_line[1] - template_18[1]) / (template_off_line[1] - template_23[1]) + kp_18 = off_line_point + (proj_point - off_line_point) * ratio + + result[17] = tuple(kp_18.astype(int)) + total_right_side_count += 1 + all_available_points[17] = tuple(kp_18.astype(int)) # 18 point is now available, index is 17 + + if best_horizontal_line_name == '24-28': + if off_line_idx == 17 or off_line_idx == 20: + template_18 = template_coords[17] # 18 point, index is 17 + template_26 = template_coords[25] # 26 point, index is 25 + template_24 = template_coords[best_horizontal_line_points[0][0]] # 24 , index is 23 + + ratio_26 = (template_26[0] - template_18[0]) / (template_26[0] - template_24[0]) # ratio in x coordinates because y coordinates are the same + + detected_point = off_line_point + (np.array(best_horizontal_line_points[-1][1]) - np.array(best_horizontal_line_points[0][1])) * ratio_26 + + if off_line_idx == 17: + result[25] = tuple(detected_point.astype(int)) + total_right_side_count += 1 + all_available_points[25] = tuple(detected_point.astype(int)) # 26 point is now available, index is 25 + else: + result[28] = tuple(detected_point.astype(int)) + total_right_side_count += 1 + all_available_points[28] = tuple(detected_point.astype(int)) # 29 point is now available, index is 28 + + if off_line_idx == 18 or off_line_idx == 19: # 19 or 20 point is off line point, so we can calculate 18 + template_21 = template_coords[20] # 21 point, index is 20 + template_off_line = template_coords[off_line_idx] + template_24 = template_coords[best_horizontal_line_points[0][0]] # 24 point, index is 23 + + ratio = (template_21[1] - template_off_line[1]) / (template_24[1] - template_off_line[1]) + kp_21 = off_line_point + (proj_point - off_line_point) * ratio + + result[20] = tuple(kp_18.astype(int)) + total_right_side_count += 1 + all_available_points[20] = tuple(kp_18.astype(int)) # 21 point is now available, index is 20 + + if best_horizontal_line_name == '21-29': + if off_line_idx == 22 or off_line_idx == 23: # 23 or 24 point is off line point, so we can calculate 27 or 28 + template_21 = template_coords[best_horizontal_line_points[0][0]] # 21 point, index is 20 + template_29 = template_coords[best_horizontal_line_points[-1][0]] # 29 point, index is 28 + template_23 = template_coords[off_line_idx] # 23 or 24 point, index is 22 or 23 + + ratio_29 = (template_29[0] - template_23[0]) / (template_29[0] - template_21[0]) # ratio in x coordinates because y coordinates are the same + + detected_point = off_line_point + (np.array(best_horizontal_line_points[-1][1]) - np.array(best_horizontal_line_points[0][1])) * ratio_29 + + if off_line_idx == 22: + result[26] = tuple(detected_point.astype(int)) + total_right_side_count += 1 + all_available_points[26] = tuple(detected_point.astype(int)) # 26 point is now available, index is 26 + else: + result[27] = tuple(detected_point.astype(int)) + total_right_side_count += 1 + all_available_points[27] = tuple(detected_point.astype(int)) # 27 point is now available, index is 27 + + # Check if we can now form a vertical line with the newly calculated points + for line_name, (indices, line_type) in line_groups.items(): + if line_type == 'vertical': + line_points = [(idx, all_available_points[idx]) for idx in indices if idx in all_available_points] + if len(line_points) > max_vertical_points: + max_vertical_points = len(line_points) + best_vertical_line_name = line_name + best_vertical_line_points = line_points + + # Calculate keypoint 22 if we have at least one line + if best_vertical_line_name is not None and best_horizontal_line_name is not None: + if kp_22 is None: + print(f"Calculating keypoint 22 using both vertical and horizontal lines: {best_vertical_line_name} and {best_horizontal_line_name}") + + template_x_22 = 940 + template_y_22 = 340 + + # Step 2: Project keypoint 22 onto vertical line (if available) + + template_vertical_start = template_coords[best_vertical_line_points[0][0]] + template_vertical_end = template_coords[best_vertical_line_points[-1][0]] + + # Project at y=340 (same y as keypoint 22) + template_y_vertical_start = template_vertical_start[1] + template_y_vertical_end = template_vertical_end[1] + + if abs(template_y_vertical_end - template_y_vertical_start) > 1e-6: + ratio_22_vertical = (template_y_22 - template_y_vertical_start) / (template_y_vertical_end - template_y_vertical_start) + else: + ratio_22_vertical = 0.5 + + frame_vertical_start = best_vertical_line_points[0][1] + frame_vertical_end = best_vertical_line_points[-1][1] + proj_22_on_vertical_x = frame_vertical_start[0] + (frame_vertical_end[0] - frame_vertical_start[0]) * ratio_22_vertical + proj_22_on_vertical_y = frame_vertical_start[1] + (frame_vertical_end[1] - frame_vertical_start[1]) * ratio_22_vertical + proj_22_on_vertical = (proj_22_on_vertical_x, proj_22_on_vertical_y) + + # Step 3: Project keypoint 22 onto horizontal line (if available) + + template_horizontal_start = template_coords[best_horizontal_line_points[0][0]] + template_horizontal_end = template_coords[best_horizontal_line_points[-1][0]] + + # Project at x=940 (same x as keypoint 22) + template_x_horizontal_start = template_horizontal_start[0] + template_x_horizontal_end = template_horizontal_end[0] + + if abs(template_x_horizontal_end - template_x_horizontal_start) > 1e-6: + ratio_22_horizontal = (template_x_22 - template_x_horizontal_start) / (template_x_horizontal_end - template_x_horizontal_start) + else: + ratio_22_horizontal = 0.5 + + frame_horizontal_start = best_horizontal_line_points[0][1] + frame_horizontal_end = best_horizontal_line_points[-1][1] + proj_22_on_horizontal_x = frame_horizontal_start[0] + (frame_horizontal_end[0] - frame_horizontal_start[0]) * ratio_22_horizontal + proj_22_on_horizontal_y = frame_horizontal_start[1] + (frame_horizontal_end[1] - frame_horizontal_start[1]) * ratio_22_horizontal + proj_22_on_horizontal = (proj_22_on_horizontal_x, proj_22_on_horizontal_y) + + # Step 4: Calculate keypoint 22 as intersection of two lines + # Line 1: Passes through proj_22_on_vertical, parallel to best_horizontal_line + # Line 2: Passes through proj_22_on_horizontal, parallel to best_vertical_line + + # Calculate direction vector of best_horizontal_line + horizontal_dir_x = frame_horizontal_end[0] - frame_horizontal_start[0] + horizontal_dir_y = frame_horizontal_end[1] - frame_horizontal_start[1] + horizontal_dir_length = np.sqrt(horizontal_dir_x**2 + horizontal_dir_y**2) + + # Calculate direction vector of best_vertical_line + vertical_dir_x = frame_vertical_end[0] - frame_vertical_start[0] + vertical_dir_y = frame_vertical_end[1] - frame_vertical_start[1] + vertical_dir_length = np.sqrt(vertical_dir_x**2 + vertical_dir_y**2) + + if horizontal_dir_length > 1e-6 and vertical_dir_length > 1e-6: + # Normalize direction vectors + horizontal_dir_x /= horizontal_dir_length + horizontal_dir_y /= horizontal_dir_length + vertical_dir_x /= vertical_dir_length + vertical_dir_y /= vertical_dir_length + + # Line 1: passes through proj_22_on_vertical with direction of best_horizontal_line + # Parametric: p1 = proj_22_on_vertical + t * horizontal_dir + # Line 2: passes through proj_22_on_horizontal with direction of best_vertical_line + # Parametric: p2 = proj_22_on_horizontal + s * vertical_dir + + # Find intersection: proj_22_on_vertical + t * horizontal_dir = proj_22_on_horizontal + s * vertical_dir + # This gives us: + # proj_22_on_vertical[0] + t * horizontal_dir_x = proj_22_on_horizontal[0] + s * vertical_dir_x + # proj_22_on_vertical[1] + t * horizontal_dir_y = proj_22_on_horizontal[1] + s * vertical_dir_y + + # Rearranging: + # t * horizontal_dir_x - s * vertical_dir_x = proj_22_on_horizontal[0] - proj_22_on_vertical[0] + # t * horizontal_dir_y - s * vertical_dir_y = proj_22_on_horizontal[1] - proj_22_on_vertical[1] + + # Solve for t and s using linear algebra + A = np.array([ + [horizontal_dir_x, -vertical_dir_x], + [horizontal_dir_y, -vertical_dir_y] + ]) + b = np.array([ + proj_22_on_horizontal[0] - proj_22_on_vertical[0], + proj_22_on_horizontal[1] - proj_22_on_vertical[1] + ]) + + try: + t, s = np.linalg.solve(A, b) + + # Calculate intersection point using line 1 + x_22 = proj_22_on_vertical[0] + t * horizontal_dir_x + y_22 = proj_22_on_vertical[1] + t * horizontal_dir_y + + result[21] = (int(round(x_22)), int(round(y_22))) + total_right_side_count += 1 + except np.linalg.LinAlgError: + # Lines are parallel or nearly parallel, use simple intersection + # If lines are parallel, use the projection points directly + x_22 = proj_22_on_vertical[0] + y_22 = proj_22_on_horizontal[1] + result[21] = (int(round(x_22)), int(round(y_22))) + total_right_side_count += 1 + else: + # Fallback: use simple intersection + x_22 = proj_22_on_vertical[0] + y_22 = proj_22_on_horizontal[1] + result[21] = (int(round(x_22)), int(round(y_22))) + total_right_side_count += 1 + + print(f"total_right_side_count: {total_right_side_count}, result: {result}") + if total_right_side_count > 5: + return result + + # Calculate m_line and b_line from best vertical or horizontal line for use in calculating other points + m_line = None + b_line = None + best_line_for_calc = None + best_line_type_for_calc = None + + if best_vertical_line_name is not None and len(best_vertical_line_points) >= 2: + best_line_for_calc = best_vertical_line_points + best_line_type_for_calc = 'vertical' + points_array = np.array([[kp[0], kp[1]] for _, kp in best_vertical_line_points]) + x_coords = points_array[:, 0] + y_coords = points_array[:, 1] + A = np.vstack([x_coords, np.ones(len(x_coords))]).T + m_line, b_line = np.linalg.lstsq(A, y_coords, rcond=None)[0] + elif best_horizontal_line_name is not None and len(best_horizontal_line_points) >= 2: + best_line_for_calc = best_horizontal_line_points + best_line_type_for_calc = 'horizontal' + points_array = np.array([[kp[0], kp[1]] for _, kp in best_horizontal_line_points]) + x_coords = points_array[:, 0] + y_coords = points_array[:, 1] + A = np.vstack([x_coords, np.ones(len(x_coords))]).T + m_line, b_line = np.linalg.lstsq(A, y_coords, rcond=None)[0] + + # Calculate missing points to reach exactly 5 points + # Ensure 4 points aren't all on one line + if total_right_side_count < 5 and (m_line is not None or (best_line_for_calc is not None and best_line_type_for_calc == 'vertical')): + # Check current distribution + counts_per_line = [ + len(line_18_21_points), + len(line_23_24_points), + len(line_25_30_points) + ] + + # Calculate points on line 18-21 if needed + template_ys_18_21 = [140, 270, 410, 540] + template_indices_18_21 = [17, 18, 19, 20] + + if best_vertical_line_name == '25-30': + # Construct parallel line 18-21 from line 25-30 + for template_y, idx in zip(template_ys_18_21, template_indices_18_21): + if result[idx] is None and total_right_side_count < 5: + # Check if adding this point would put 4 on one line + new_counts = counts_per_line.copy() + new_counts[0] += 1 # Adding to line 18-21 + if max(new_counts) >= 4 and total_right_side_count == 4: + # Would have 4 on one line, skip + continue + + # Calculate y using scale from template + ref_ys = [kp[1] for _, kp in line_25_30_points] + ref_template_ys = [5, 140, 250, 430, 540, 675] + ref_indices = [24, 25, 26, 27, 28, 29] + + matched_template_ys = [] + for ref_idx, ref_kp in line_25_30_points: + if ref_idx in ref_indices: + template_idx = ref_indices.index(ref_idx) + matched_template_ys.append((ref_template_ys[template_idx], ref_kp[1])) + + if len(matched_template_ys) >= 1: + ref_template_y, ref_frame_y = matched_template_ys[0] + if ref_template_y > 0: + scale = ref_frame_y / ref_template_y + y_new = int(round(template_y * scale)) + else: + y_new = ref_frame_y + else: + y_new = int(round(np.median(ref_ys))) if ref_ys else template_y + + # Calculate x using parallel line geometry + if abs(m_line) > 1e-6: + x_on_line_25_30 = (y_new - b_line) / m_line + x_new = int(round(x_on_line_25_30 * 0.850)) + else: + x_new = int(round(np.median([kp[0] for _, kp in line_25_30_points]) * 0.850)) + + result[idx] = (x_new, y_new) + total_right_side_count += 1 + if total_right_side_count >= 5: + break + elif best_vertical_line_name == '18-21': + # Calculate missing points on line 18-21 + for template_y, idx in zip(template_ys_18_21, template_indices_18_21): + if result[idx] is None and total_right_side_count < 5: + # Check if adding this point would put 4 on one line + new_counts = counts_per_line.copy() + new_counts[0] += 1 # Adding to line 18-21 + if max(new_counts) >= 4 and total_right_side_count == 4: + # Would have 4 on one line, skip + continue + + # Calculate x on the line + if abs(m_line) > 1e-6: + x_new = (template_y - b_line) / m_line + else: + x_new = np.median([kp[0] for _, kp in line_18_21_points]) + + # Scale y based on available points + ref_ys = [kp[1] for _, kp in line_18_21_points] + ref_template_ys = [] + for ref_idx, _ in line_18_21_points: + if ref_idx in template_indices_18_21: + template_idx = template_indices_18_21.index(ref_idx) + ref_template_ys.append(template_ys_18_21[template_idx]) + + if len(ref_ys) >= 1 and len(ref_template_ys) >= 1: + ref_template_y = ref_template_ys[0] + ref_frame_y = ref_ys[0] + if ref_template_y > 0: + scale = ref_frame_y / ref_template_y + y_new = int(round(template_y * scale)) + else: + y_new = ref_frame_y + else: + y_new = int(round(np.median(ref_ys))) if ref_ys else template_y + + result[idx] = (int(round(x_new)), y_new) + total_right_side_count += 1 + if total_right_side_count >= 5: + break + + # Note: The unified approach above handles all cases (2a and 2b combined) + # Legacy code removed - all logic is now in the unified case 2 above + + return result + +def check_keypoints_would_cause_invalid_mask( + frame_keypoints: list[tuple[int, int]], + template_keypoints: list[tuple[int, int]] = None, + frame: np.ndarray = None, + floor_markings_template: np.ndarray = None, + return_warped_data: bool = False, +) -> tuple[bool, str] | tuple[bool, str, tuple]: + """ + Check if keypoints would cause InvalidMask errors during evaluation. + + Args: + frame_keypoints: Frame keypoints to check + template_keypoints: Template keypoints (defaults to TEMPLATE_KEYPOINTS) + frame: Optional frame image for full validation + floor_markings_template: Optional template image for full validation + + Returns: + Tuple of (would_cause_error, error_message) + """ + try: + from keypoint_evaluation import ( + validate_projected_corners, + TEMPLATE_KEYPOINTS, + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + findHomography, + InvalidMask, + ) + + if template_keypoints is None: + template_keypoints = TEMPLATE_KEYPOINTS + + # Filter valid keypoints + filtered_template = [] + filtered_frame = [] + + for i, (t_kp, f_kp) in enumerate(zip(template_keypoints, frame_keypoints)): + if f_kp[0] > 0 and f_kp[1] > 0: + filtered_template.append(t_kp) + filtered_frame.append(f_kp) + + if len(filtered_template) < 4: + if return_warped_data: + return (True, "Not enough keypoints for homography", None) + return (True, "Not enough keypoints for homography") + + # Compute homography + src_pts = np.array(filtered_template, dtype=np.float32) + dst_pts = np.array(filtered_frame, dtype=np.float32) + + result = findHomography(src_pts, dst_pts) + if result is None: + if return_warped_data: + return (True, "Failed to compute homography", None) + return (True, "Failed to compute homography") + H, _ = result + + # Check for twisted projection (bowtie) + try: + validate_projected_corners( + source_keypoints=template_keypoints, + homography_matrix=H + ) + except Exception as e: + error_msg = "Projection twisted (bowtie)" if "twisted" in str(e).lower() or "Projection twisted" in str(e).lower() else str(e) + if return_warped_data: + return (True, error_msg, None) + return (True, error_msg) + + # If frame and template are provided, check mask validation + if frame is not None and floor_markings_template is not None: + try: + from keypoint_evaluation import ( + project_image_using_keypoints, + extract_masks_for_ground_and_lines, + InvalidMask, + ) + + # project_image_using_keypoints can raise InvalidMask from validate_projected_corners + try: + # start_time = time.time() + warped_template = project_image_using_keypoints( + image=floor_markings_template, + source_keypoints=template_keypoints, + destination_keypoints=frame_keypoints, + destination_width=frame.shape[1], + destination_height=frame.shape[0], + ) + # end_time = time.time() + # print(f"project_image_using_keypoints time: {end_time - start_time} seconds") + except InvalidMask as e: + if return_warped_data: + return (True, f"Projection validation failed: {e}", None) + return (True, f"Projection validation failed: {e}") + except Exception as e: + # Other errors (e.g., ValueError from homography failure) + if return_warped_data: + return (True, f"Projection failed: {e}", None) + return (True, f"Projection failed: {e}") + + # extract_masks_for_ground_and_lines can raise InvalidMask from validation + try: + mask_ground, mask_lines_expected = extract_masks_for_ground_and_lines( + image=warped_template + ) + except InvalidMask as e: + if return_warped_data: + return (True, f"Mask extraction validation failed: {e}", None) + return (True, f"Mask extraction validation failed: {e}") + except Exception as e: + if return_warped_data: + return (True, f"Mask extraction failed: {e}", None) + return (True, f"Mask extraction failed: {e}") + + # Additional explicit validation (though extract_masks_for_ground_and_lines already validates) + from keypoint_evaluation import validate_mask_lines, validate_mask_ground + try: + validate_mask_lines(mask_lines_expected) + except InvalidMask as e: + if return_warped_data: + return (True, f"Mask lines validation failed: {e}", None) + return (True, f"Mask lines validation failed: {e}") + except Exception as e: + if return_warped_data: + return (True, f"Mask lines validation error: {e}", None) + return (True, f"Mask lines validation error: {e}") + + try: + validate_mask_ground(mask_ground) + except InvalidMask as e: + if return_warped_data: + return (True, f"Mask ground validation failed: {e}", None) + return (True, f"Mask ground validation failed: {e}") + except Exception as e: + if return_warped_data: + return (True, f"Mask ground validation error: {e}", None) + return (True, f"Mask ground validation error: {e}") + + # If return_warped_data is True and validation passed, return the computed data + if return_warped_data: + return (False, "", (warped_template, mask_ground, mask_lines_expected)) + + except ImportError: + # If keypoint_evaluation is not available, skip validation + pass + except InvalidMask as e: + # Catch any InvalidMask that wasn't caught above + if return_warped_data: + return (True, f"InvalidMask error: {e}", None) + return (True, f"InvalidMask error: {e}") + except Exception as e: + # If we can't check masks for other reasons, assume it's okay + # Don't let exceptions propagate + pass + + # If we get here, keypoints should be valid + if return_warped_data: + return (False, "", None) # No warped data if frame/template not provided + return (False, "") + + except ImportError: + # If keypoint_evaluation is not available, skip validation + if return_warped_data: + return (False, "", None) + return (False, "") + except Exception as e: + # Any other error - assume it would cause problems + if return_warped_data: + return (True, f"Validation error: {e}", None) + return (True, f"Validation error: {e}") + + +def evaluate_keypoints_with_cached_data( + frame: np.ndarray, + mask_ground: np.ndarray, + mask_lines_expected: np.ndarray, +) -> float: + """ + Evaluate keypoints using pre-computed warped template and masks. + This avoids redundant computation when we already have the warped data from validation. + + Args: + frame: Frame image + mask_ground: Pre-computed ground mask from warped template + mask_lines_expected: Pre-computed expected lines mask from warped template + + Returns: + Score between 0.0 and 1.0 + """ + try: + from keypoint_evaluation import ( + extract_mask_of_ground_lines_in_image, + bitwise_and, + ) + + # Only need to extract predicted lines from frame (uses cached mask_ground) + mask_lines_predicted = extract_mask_of_ground_lines_in_image( + image=frame, ground_mask=mask_ground + ) + + pixels_overlapping = bitwise_and( + mask_lines_expected, mask_lines_predicted + ).sum() + + pixels_on_lines = mask_lines_expected.sum() + + score = pixels_overlapping / (pixels_on_lines + 1e-8) + + return min(1.0, max(0.0, score)) # Clamp to [0, 1] + + except Exception as e: + print(f'Error in cached keypoint evaluation: {e}') + return 0.0 + + +def check_and_evaluate_keypoints( + frame_keypoints: list[tuple[int, int]], + template_keypoints: list[tuple[int, int]], + frame: np.ndarray, + floor_markings_template: np.ndarray, +) -> tuple[bool, float]: + """ + Check if keypoints would cause InvalidMask errors and evaluate them in one call. + This reuses the warped template and masks computed during validation for evaluation. + + Args: + frame_keypoints: Frame keypoints to check and evaluate + template_keypoints: Template keypoints + frame: Frame image + floor_markings_template: Template image + + Returns: + Tuple of (is_valid, score). If is_valid is False, score is 0.0. + """ + # Check with return_warped_data=True to get cached data + # start_time = time.time() + check_result = check_keypoints_would_cause_invalid_mask( + frame_keypoints, template_keypoints, frame, floor_markings_template, + return_warped_data=True + ) + # end_time = time.time() + # print(f"check_keypoints_would_cause_invalid_mask time: {end_time - start_time} seconds") + + if len(check_result) == 3: + would_cause_error, error_msg, warped_data = check_result + else: + would_cause_error, error_msg = check_result + warped_data = None + + if would_cause_error: + return (False, 0.0) + + # If we have cached warped data, use it for fast evaluation + if warped_data is not None: + _, mask_ground, mask_lines_expected = warped_data + try: + score = evaluate_keypoints_with_cached_data( + frame, mask_ground, mask_lines_expected + ) + return (True, score) + except Exception as e: + print(f'Error evaluating with cached data: {e}') + return (True, 0.0) + + # Fallback to regular evaluation if no cached data + try: + from keypoint_evaluation import evaluate_keypoints_for_frame + score = evaluate_keypoints_for_frame( + template_keypoints, frame_keypoints, frame, floor_markings_template + ) + return (True, score) + except Exception as e: + print(f'Error in regular evaluation: {e}') + return (True, 0.0) + + +def adjust_keypoints_to_avoid_invalid_mask( + frame_keypoints: list[tuple[int, int]], + template_keypoints: list[tuple[int, int]] = None, + frame: np.ndarray = None, + floor_markings_template: np.ndarray = None, + max_iterations: int = 5, +) -> list[tuple[int, int]]: + """ + Adjust keypoints to avoid InvalidMask errors. + + This function tries to fix common issues: + 1. Twisted projection (bowtie) - adjusts corner keypoints + 2. Ground covers too much - shrinks projected area by moving corners inward + 3. Other mask validation issues - adjusts keypoints to improve projection + + Args: + frame_keypoints: Frame keypoints to adjust + template_keypoints: Template keypoints + frame: Optional frame image for validation + floor_markings_template: Optional template image for validation + max_iterations: Maximum number of adjustment iterations + + Returns: + Adjusted keypoints that should avoid InvalidMask errors + """ + adjusted = list(frame_keypoints) + + # Check if adjustment is needed + would_cause_error, error_msg = check_keypoints_would_cause_invalid_mask( + adjusted, template_keypoints, frame, floor_markings_template + ) + print(f"Would cause error: {would_cause_error}, error_msg: {error_msg}") + if not would_cause_error: + return (True,adjusted) + + # Try to fix twisted projection (most common issue) + if "twisted" in error_msg.lower() or "bowtie" in error_msg.lower() or "Projection twisted" in error_msg.lower(): + # Use the existing _adjust_keypoints_to_pass_validation function + adjusted = _adjust_keypoints_to_pass_validation( + adjusted, template_keypoints, + frame.shape[1] if frame is not None else None, + frame.shape[0] if frame is not None else None + ) + + # Check again after adjustment + would_cause_error, error_msg = check_keypoints_would_cause_invalid_mask( + adjusted, template_keypoints, frame, floor_markings_template + ) + + if not would_cause_error: + return (True,adjusted) + + start_time = time.time() + # Handle "a projected line is too wide" error + # This happens when projected lines are too thick/wide (aspect ratio too high) + if "too wide" in error_msg.lower() or "wide line" in error_msg.lower(): + print(f"Adjusting keypoints to fix 'a projected line is too wide' error") + try: + # This error usually means the projection is creating lines that are too thick + # Strategy: Adjust keypoints to reduce projection distortion + + valid_keypoints = [] + for idx in range(len(adjusted)): + x, y = adjusted[idx] + if x == 0 and y == 0: + continue + valid_keypoints.append((idx, x, y)) + + if len(valid_keypoints) >= 4: + # Calculate center and spread of keypoints + center_x = sum(x for _, x, y in valid_keypoints) / len(valid_keypoints) + center_y = sum(y for _, x, y in valid_keypoints) / len(valid_keypoints) + + # Calculate distances from center + distances = [] + for idx, x, y in valid_keypoints: + dist = np.sqrt((x - center_x)**2 + (y - center_y)**2) + distances.append((idx, x, y, dist)) + + # Sort by distance + distances.sort(key=lambda d: d[3], reverse=True) + + # Strategy 1: Try moving keypoints slightly outward to reduce compression + # This can help if keypoints are too close together causing wide lines + best_wide_kps = None + best_wide_score = -1.0 + + # Try expanding keypoints slightly (opposite of shrinking) + for expand_factor in [1.02, 1.05, 1.08, 1.10]: + test_kps = list(adjusted) + for idx, x, y, dist in distances: + # Move keypoint slightly away from center + new_x = int(round(center_x + (x - center_x) * expand_factor)) + new_y = int(round(center_y + (y - center_y) * expand_factor)) + test_kps[idx] = (new_x, new_y) + + # Check and get cached warped data if available + check_result = check_keypoints_would_cause_invalid_mask( + test_kps, template_keypoints, frame, floor_markings_template, + return_warped_data=(frame is not None and floor_markings_template is not None) + ) + + if len(check_result) == 3: + would_cause_error, test_error_msg, warped_data = check_result + else: + would_cause_error, test_error_msg = check_result + warped_data = None + + if not would_cause_error: + # Evaluate score for this adjustment + if frame is not None and floor_markings_template is not None: + try: + if warped_data is not None: + # Use cached warped data for faster evaluation + _, mask_ground, mask_lines_expected = warped_data + score = evaluate_keypoints_with_cached_data( + frame, mask_ground, mask_lines_expected + ) + else: + # Fallback to regular evaluation + from keypoint_evaluation import evaluate_keypoints_for_frame + score = evaluate_keypoints_for_frame( + template_keypoints, test_kps, frame, floor_markings_template + ) + if score > best_wide_score: + best_wide_score = score + best_wide_kps = test_kps + print(f"Found valid wide-line adjustment (expand) with factor {expand_factor}, score: {score:.4f}") + except Exception: + # If score evaluation fails, use this adjustment anyway + print(f"Successfully adjusted keypoints for wide line (expand) with factor {expand_factor}") + return (True,test_kps) + else: + print(f"Successfully adjusted keypoints for wide line (expand) with factor {expand_factor}") + return (True,test_kps) + + # Strategy 2: If expanding didn't work, try adjusting individual keypoints + # Move keypoints that are too close together slightly apart + for idx, x, y, dist in distances: + # Try small adjustments to this keypoint + for adjust_x in [-3, -2, -1, 1, 2, 3]: + for adjust_y in [-3, -2, -1, 1, 2, 3]: + test_kps = list(adjusted) + test_kps[idx] = (x + adjust_x, y + adjust_y) + + # Use optimized check_and_evaluate to reuse warped data + if frame is not None and floor_markings_template is not None: + is_valid, score = check_and_evaluate_keypoints( + test_kps, template_keypoints, frame, floor_markings_template + ) + if is_valid: + if score > best_wide_score: + best_wide_score = score + best_wide_kps = test_kps + print(f"Found valid wide-line adjustment (perturb) for keypoint {idx}, score: {score:.4f}") + else: + would_cause_error, _ = check_keypoints_would_cause_invalid_mask( + test_kps, template_keypoints, frame, floor_markings_template + ) + if not would_cause_error: + return (True, test_kps) + + # Return the best scoring adjustment if we found any + if best_wide_kps is not None: + end_time = time.time() + print(f"Returning best scoring wide-line adjustment time: {end_time - start_time} seconds") + print(f"Returning best scoring wide-line adjustment with score: {best_wide_score:.4f}") + return (True,best_wide_kps) + + # Strategy 3: Try slight shrinking (opposite approach - reduce projection area) + for shrink_factor in [0.98, 0.96, 0.94]: + test_kps = list(adjusted) + for idx, x, y, dist in distances: + new_x = int(round(center_x + (x - center_x) * shrink_factor)) + new_y = int(round(center_y + (y - center_y) * shrink_factor)) + test_kps[idx] = (new_x, new_y) + + # Use optimized check_and_evaluate to reuse warped data + if frame is not None and floor_markings_template is not None: + try: + is_valid, score = check_and_evaluate_keypoints( + test_kps, template_keypoints, frame, floor_markings_template + ) + if is_valid: + if score > best_wide_score: + best_wide_score = score + best_wide_kps = test_kps + print(f"Found valid wide-line adjustment (shrink) with factor {shrink_factor}, score: {score:.4f}") + except Exception: + would_cause_error, _ = check_keypoints_would_cause_invalid_mask( + test_kps, template_keypoints, frame, floor_markings_template + ) + if not would_cause_error: + return (True, test_kps) + else: + would_cause_error, _ = check_keypoints_would_cause_invalid_mask( + test_kps, template_keypoints, frame, floor_markings_template + ) + if not would_cause_error: + return (True, test_kps) + + if best_wide_kps is not None: + print(f"Returning best scoring wide-line adjustment with score: {best_wide_score:.4f}") + return (True,best_wide_kps) + except Exception as e: + print(f"Error in wide line adjustment: {e}") + pass + + # Handle "projected ground should be a single object" error + # This happens when the ground mask has multiple disconnected regions + if "should be a single" in error_msg.lower() or "single object" in error_msg.lower() or "distinct regions" in error_msg.lower(): + print(f"Adjusting keypoints to fix 'projected ground should be a single object' error") + try: + # This error usually means the projection creates gaps/holes in the ground mask + # We need to adjust keypoints to make the projection more continuous + + # Strategy 1: Try moving keypoints closer together to reduce gaps + valid_keypoints = [] + for idx in range(len(adjusted)): + x, y = adjusted[idx] + if x == 0 and y == 0: + continue + valid_keypoints.append((idx, x, y)) + + if len(valid_keypoints) >= 4: + # Calculate center of all keypoints + center_x = sum(x for _, x, y in valid_keypoints) / len(valid_keypoints) + center_y = sum(y for _, x, y in valid_keypoints) / len(valid_keypoints) + + # Try moving keypoints slightly closer to center to reduce fragmentation + # Use smaller adjustments to preserve geometry + best_single_kps = None + best_single_score = -1.0 + + for shrink_factor in [0.98, 0.96, 0.94, 0.92, 0.90]: + test_kps = list(adjusted) + for idx, x, y in valid_keypoints: + # Move keypoint slightly toward center + new_x = int(round(center_x + (x - center_x) * shrink_factor)) + new_y = int(round(center_y + (y - center_y) * shrink_factor)) + test_kps[idx] = (new_x, new_y) + + # Use optimized check_and_evaluate to reuse warped data + if frame is not None and floor_markings_template is not None: + try: + is_valid, score = check_and_evaluate_keypoints( + test_kps, template_keypoints, frame, floor_markings_template + ) + if is_valid: + if score > best_single_score: + best_single_score = score + best_single_kps = test_kps + print(f"Found valid single-object adjustment with shrink_factor {shrink_factor}, score: {score:.4f}") + except Exception: + # If score evaluation fails, use this adjustment anyway + print(f"Successfully adjusted keypoints for single object with shrink_factor {shrink_factor}") + return (True, test_kps) + else: + # No frame/template for score evaluation, use first valid adjustment + would_cause_error, _ = check_keypoints_would_cause_invalid_mask( + test_kps, template_keypoints, frame, floor_markings_template + ) + if not would_cause_error: + print(f"Successfully adjusted keypoints for single object with shrink_factor {shrink_factor}") + return (True, test_kps) + + # Return the best scoring adjustment if we found any + if best_single_kps is not None: + print(f"Returning best scoring single-object adjustment with score: {best_single_score:.4f}") + return (True,best_single_kps) + + # Strategy 2: If moving toward center didn't work, try adjusting boundary keypoints + # Calculate distances from center + distances = [] + for idx, x, y in valid_keypoints: + dist = np.sqrt((x - center_x)**2 + (y - center_y)**2) + distances.append((idx, x, y, dist)) + + # Sort by distance (farthest first) + distances.sort(key=lambda d: d[3], reverse=True) + + # Try adjusting the farthest keypoints (which might be causing fragmentation) + for shrink_factor in [0.95, 0.90, 0.85]: + test_kps = list(adjusted) + # Adjust top 25% of farthest keypoints + boundary_count = max(1, len(distances) // 4) + for idx, x, y, dist in distances[:boundary_count]: + new_x = int(round(center_x + (x - center_x) * shrink_factor)) + new_y = int(round(center_y + (y - center_y) * shrink_factor)) + test_kps[idx] = (new_x, new_y) + + # Use optimized check_and_evaluate to reuse warped data + if frame is not None and floor_markings_template is not None: + try: + is_valid, score = check_and_evaluate_keypoints( + test_kps, template_keypoints, frame, floor_markings_template + ) + if is_valid: + print(f"Found valid boundary adjustment for single object with shrink_factor {shrink_factor}, score: {score:.4f}") + return (True, test_kps) + except Exception: + return (True, test_kps) + else: + would_cause_error, _ = check_keypoints_would_cause_invalid_mask( + test_kps, template_keypoints, frame, floor_markings_template + ) + if not would_cause_error: + return (True, test_kps) + except Exception as e: + print(f"Error in single object adjustment: {e}") + pass + + # Handle "ground covers too much" error by shrinking the projected area + if "ground covers" in error_msg.lower() or "covers more than" in error_msg.lower(): + print(f"Adjusting keypoints to avoid 'ground covers too much' error") + try: + from keypoint_evaluation import ( + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + ) + + # First, try adjusting corners if available + corner_indices = [ + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + ] + + # Get corner keypoints + corners = [] + center_x, center_y = 0, 0 + valid_corners = 0 + + for corner_idx in corner_indices: + if corner_idx < len(adjusted): + x, y = adjusted[corner_idx] + if x == 0 and y == 0: + continue + corners.append((corner_idx, x, y)) + center_x += x + center_y += y + valid_corners += 1 + + if valid_corners >= 4: + center_x /= valid_corners + center_y /= valid_corners + + # Move corners inward toward center to reduce projected area + # Try different shrink factors, starting with smaller adjustments to preserve score + # Use score-aware adjustment: evaluate score for each candidate and pick the best + best_corner_kps = None + best_corner_score = -1.0 + + for shrink_factor in [0.95, 0.90, 0.85, 0.80, 0.75, 0.70, 0.65]: + test_kps = list(adjusted) + for corner_idx, x, y in corners: + # Move corner toward center + new_x = int(round(center_x + (x - center_x) * shrink_factor)) + new_y = int(round(center_y + (y - center_y) * shrink_factor)) + test_kps[corner_idx] = (new_x, new_y) + + # Use optimized check_and_evaluate to reuse warped data + if frame is not None and floor_markings_template is not None: + try: + is_valid, score = check_and_evaluate_keypoints( + test_kps, template_keypoints, frame, floor_markings_template + ) + if is_valid: + if score > best_corner_score: + best_corner_score = score + best_corner_kps = test_kps + print(f"Found valid corner adjustment with shrink_factor {shrink_factor}, score: {score:.4f}") + except Exception: + # If score evaluation fails, use this adjustment anyway + return (True, test_kps) + else: + # No frame/template for score evaluation, use first valid adjustment + would_cause_error, _ = check_keypoints_would_cause_invalid_mask( + test_kps, template_keypoints, frame, floor_markings_template + ) + if not would_cause_error: + return (True, test_kps) + + # Return the best scoring corner adjustment if we found any + if best_corner_kps is not None: + print(f"Returning best scoring corner adjustment with score: {best_corner_score:.4f}") + return (True,best_corner_kps) + + # If corners adjustment didn't work or we don't have enough corners, + # try adjusting individual keypoints one at a time + # This handles cases where non-corner keypoints (like 15, 16, 17, 31, 32) are causing the issue + valid_keypoints = [] + all_center_x, all_center_y = 0, 0 + valid_count = 0 + + for idx in range(len(adjusted)): + x, y = adjusted[idx] + if x == 0 and y == 0: + continue + valid_keypoints.append((idx, x, y)) + all_center_x += x + all_center_y += y + valid_count += 1 + + if valid_count >= 4: + all_center_x /= valid_count + all_center_y /= valid_count + + # Calculate distances from center for each keypoint + # Try adjusting keypoints farthest from center first (most likely to cause coverage issues) + distances = [] + for idx, x, y in valid_keypoints: + dist = np.sqrt((x - all_center_x)**2 + (y - all_center_y)**2) + distances.append((idx, x, y, dist)) + + # Sort by distance (farthest first) - these are most likely causing the coverage issue + distances.sort(key=lambda d: d[3], reverse=True) + + # Try adjusting each keypoint individually, starting with farthest from center + # Use score-aware adjustment: evaluate score for each candidate and pick the best + best_kps = None + best_score = -1.0 + + for idx, x, y, dist in distances: + # Try different shrink factors for this single keypoint + # Start with smaller adjustments to preserve score better + for shrink_factor in [0.98, 0.95, 0.92, 0.90, 0.85, 0.80, 0.75, 0.70, 0.65]: + test_kps = list(adjusted) # Start with original adjusted keypoints + # Only adjust this one keypoint + new_x = int(round(all_center_x + (x - all_center_x) * shrink_factor)) + new_y = int(round(all_center_y + (y - all_center_y) * shrink_factor)) + test_kps[idx] = (new_x, new_y) + + # Use optimized check_and_evaluate to reuse warped data + if frame is not None and floor_markings_template is not None: + try: + is_valid, score = check_and_evaluate_keypoints( + test_kps, template_keypoints, frame, floor_markings_template + ) + if is_valid: + if score > best_score: + best_score = score + best_kps = test_kps + print(f"Found valid adjustment for keypoint {idx} with shrink_factor {shrink_factor}, score: {score:.4f}") + except Exception: + # If score evaluation fails, use this adjustment anyway + print(f"Successfully adjusted keypoint {idx} (distance {dist:.1f} from center) with shrink_factor {shrink_factor}") + return (True, test_kps) + else: + # No frame/template for score evaluation, use first valid adjustment + would_cause_error, _ = check_keypoints_would_cause_invalid_mask( + test_kps, template_keypoints, frame, floor_markings_template + ) + if not would_cause_error: + print(f"Successfully adjusted keypoint {idx} (distance {dist:.1f} from center) with shrink_factor {shrink_factor}") + return (True, test_kps) + + # Return the best scoring adjustment if we found any + if best_kps is not None: + print(f"Returning best scoring adjustment with score: {best_score:.4f}") + return (True,best_kps) + + # If adjusting individual keypoints didn't work, try adjusting pairs of keypoints + # (but only if we have enough keypoints) + if valid_count >= 6: + # Try adjusting the two farthest keypoints together + # Use score-aware adjustment here too + best_pair_kps = None + best_pair_score = -1.0 + + for shrink_factor in [0.95, 0.90, 0.85, 0.80, 0.75, 0.70]: + test_kps = list(adjusted) + # Adjust top 2 farthest keypoints + for idx, x, y, dist in distances[:2]: + new_x = int(round(all_center_x + (x - all_center_x) * shrink_factor)) + new_y = int(round(all_center_y + (y - all_center_y) * shrink_factor)) + test_kps[idx] = (new_x, new_y) + + # Use optimized check_and_evaluate to reuse warped data + if frame is not None and floor_markings_template is not None: + try: + is_valid, score = check_and_evaluate_keypoints( + test_kps, template_keypoints, frame, floor_markings_template + ) + if is_valid: + if score > best_pair_score: + best_pair_score = score + best_pair_kps = test_kps + print(f"Found valid pair adjustment with shrink_factor {shrink_factor}, score: {score:.4f}") + except Exception: + # If score evaluation fails, use this adjustment anyway + print(f"Successfully adjusted 2 farthest keypoints with shrink_factor {shrink_factor}") + return (True, test_kps) + else: + # No frame/template for score evaluation, use first valid adjustment + would_cause_error, _ = check_keypoints_would_cause_invalid_mask( + test_kps, template_keypoints, frame, floor_markings_template + ) + if not would_cause_error: + print(f"Successfully adjusted 2 farthest keypoints with shrink_factor {shrink_factor}") + return (True, test_kps) + + # Return the best scoring pair adjustment if we found any + if best_pair_kps is not None: + print(f"Returning best scoring pair adjustment with score: {best_pair_score:.4f}") + return (True,best_pair_kps) + except Exception as e: + print(f"Error in ground coverage adjustment: {e}") + pass + + # If still causing errors, try small perturbations to corner keypoints + # This helps with mask validation issues + if would_cause_error and max_iterations > 0: + try: + from keypoint_evaluation import ( + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + ) + + corner_indices = [ + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + ] + + # Try small adjustments to corners and evaluate to find the best one + best_corner_kps = None + best_corner_score = -1.0 + + for corner_idx in corner_indices: + if corner_idx < len(adjusted): + x, y = adjusted[corner_idx] + if x == 0 and y == 0: + continue + # Try small perturbations + for dx in [-5, -3, -1, 1, 3, 5]: + for dy in [-5, -3, -1, 1, 3, 5]: + test_kps = list(adjusted) + test_kps[corner_idx] = (x + dx, y + dy) + + # Use optimized check_and_evaluate to reuse warped data + if frame is not None and floor_markings_template is not None: + try: + is_valid, score = check_and_evaluate_keypoints( + test_kps, template_keypoints, frame, floor_markings_template + ) + if is_valid: + if score > best_corner_score: + best_corner_score = score + best_corner_kps = test_kps + print(f"Found valid corner perturbation: corner {corner_idx} with adjust ({dx}, {dy}), score: {score:.4f}") + except Exception: + pass + else: + # No frame/template, just check validation + would_cause_error, _ = check_keypoints_would_cause_invalid_mask( + test_kps, template_keypoints, frame, floor_markings_template + ) + if not would_cause_error: + return (True, test_kps) + + # Return the best scoring corner adjustment if we found any + if best_corner_kps is not None: + print(f"Returning best scoring corner perturbation with score: {best_corner_score:.4f}") + return (True, best_corner_kps) + except Exception: + pass + + # If we can't fix it, return adjusted (best effort) + return (False,adjusted) + + +def _validate_keypoints_corners( + frame_keypoints: list[tuple[int, int]], + template_keypoints: list[tuple[int, int]] = None, +) -> bool: + """ + Validate that frame keypoints can form a valid homography with template keypoints + (corners don't create twisted projection). + + Returns True if validation passes, False otherwise. + """ + try: + from keypoint_evaluation import ( + validate_projected_corners, + TEMPLATE_KEYPOINTS, + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + ) + + # Use provided template_keypoints or default TEMPLATE_KEYPOINTS + if template_keypoints is None: + template_keypoints = TEMPLATE_KEYPOINTS + + # Filter valid keypoints (non-zero) + filtered_template = [] + filtered_frame = [] + + for i, (t_kp, f_kp) in enumerate(zip(template_keypoints, frame_keypoints)): + if f_kp[0] > 0 and f_kp[1] > 0: # Frame keypoint is valid + filtered_template.append(t_kp) + filtered_frame.append(f_kp) + + if len(filtered_template) < 4: + return False # Not enough keypoints for homography + + # Compute homography from template to frame + src_pts = np.array(filtered_template, dtype=np.float32) + dst_pts = np.array(filtered_frame, dtype=np.float32) + + H, mask = cv2.findHomography(src_pts, dst_pts) + + if H is None: + return False # Homography computation failed + + # Validate corners using the homography + try: + validate_projected_corners( + source_keypoints=template_keypoints, + homography_matrix=H + ) + return True # Validation passed + except Exception: + return False # Validation failed (twisted projection) + + except ImportError: + # If keypoint_evaluation is not available, skip validation + return True + except Exception: + # Any other error - assume invalid + return False + +def predict_failed_indices( + results_frames: Sequence[Any], + template_keypoints: list[tuple[int, int]] = None, + frame_width: int = None, + frame_height: int = None, + frames: List[np.ndarray] = None, + floor_markings_template: np.ndarray = None, + offset: int = 0, +) -> list[int]: + """ + Predict failed frame indices based on: + 1. Having <= 4 valid keypoints (after calculating missing ones) + 2. Failing validate_projected_corners validation (twisted projection) + + For each frame, tries to calculate missing keypoints first. If after calculation + we have more than 5 keypoints, the frame is not marked as failed. + """ + max_frames = len(results_frames) + if max_frames == 0: + return [] + + failed_indices: list[int] = [] + for frame_index, frame_result in enumerate(results_frames): + frame_keypoints = getattr(frame_result, "keypoints", []) or [] + original_count = sum(1 for (x, y) in frame_keypoints if int(x) != 0 and int(y) != 0) + + # Try to calculate missing keypoints + # First, remove duplicate/conflicting detections (e.g., same point detected as both 13 and 21) + cleaned_keypoints = remove_duplicate_detections( + frame_keypoints, frame_width, frame_height + ) + + valid_keypoint_indices = [idx for idx, kp in enumerate(cleaned_keypoints) if kp[0] != 0 and kp[1] != 0] + + if len(valid_keypoint_indices) > 5: + calculated_keypoints = cleaned_keypoints + else: + left_side_indices_range = range(0, 13) + right_side_indices_range = range(17, 30) + + side_check_set = set() + if len(valid_keypoint_indices) >= 4: + for idx in valid_keypoint_indices: + if idx in left_side_indices_range: + side_check_set.add("left") + elif idx in right_side_indices_range: + side_check_set.add("right") + else: + side_check_set.add("center") + + if len(side_check_set) > 1: + calculated_keypoints = cleaned_keypoints + else: + # Then calculate missing keypoints + calculated_keypoints = calculate_missing_keypoints( + cleaned_keypoints, frame_width, frame_height + ) + + # Get frame image if available + frame_image = None + if frames is not None and frame_index < len(frames): + frame_image = frames[frame_index] + + + original_frame_number = offset + frame_index + print(f"Frame {original_frame_number} (index {frame_index}): original_count: {original_count}, cleaned_keypoints: {len([kp for kp in cleaned_keypoints if kp[0] != 0 and kp[1] != 0])}, calculated_keypoints: {len([kp for kp in calculated_keypoints if kp[0] != 0 and kp[1] != 0])}") + + + start_time = time.time() + adjusted_success, calculated_keypoints = adjust_keypoints_to_avoid_invalid_mask( + calculated_keypoints, template_keypoints, frame_image, floor_markings_template + ) + end_time = time.time() + print(f"adjust_keypoints_to_avoid_invalid_mask time: {end_time - start_time} seconds") + if not adjusted_success: + failed_indices.append(frame_index) + continue + + print(f"after adjustment, calculated_keypoints: {calculated_keypoints}") + + # Update the frame result with calculated keypoints + setattr(frame_result, "keypoints", list(calculated_keypoints)) + + + return failed_indices + +def _generate_sparse_template_keypoints( + frame_width: int, + frame_height: int, + frame_image: np.ndarray = None, + template_image: np.ndarray = None, + template_keypoints: list[tuple[int, int]] = None +) -> list[tuple[int, int]]: + # Calculate template dimensions from template_keypoints if available, otherwise use default + if template_keypoints is not None and len(template_keypoints) > 0: + valid_template_points = [(x, y) for x, y in template_keypoints if x > 0 and y > 0] + if len(valid_template_points) > 0: + template_max_x = max(x for x, y in valid_template_points) + template_max_y = max(y for x, y in valid_template_points) + else: + template_max_x, template_max_y = (1045, 675) # Default fallback + else: + template_max_x, template_max_y = (1045, 675) # Default fallback + + # Calculate scaling factors for both dimensions + sx = float(frame_width) / float(template_max_x if template_max_x != 0 else 1) + sy = float(frame_height) / float(template_max_y if template_max_y != 0 else 1) + + # Always use uniform scaling to preserve pitch geometry and aspect ratio + # This prevents distortion that creates square contours (like 3x3, 4x4) which fail the wide line check + # Uniform scaling ensures the pitch maintains its shape and avoids twisted projections + uniform_scale = min(sx, sy) + + # Scale down significantly to create a much smaller pitch in the warped template + # Use a small fraction of the uniform scale to make the pitch as small as possible + # This creates a small pitch centered in the frame, avoiding edge artifacts + scale_factor = 0.25 # Use 25% of the frame-filling scale to make pitch much smaller + uniform_scale = uniform_scale * scale_factor + + # Ensure minimum scale to avoid keypoints being too close together + # Very small scales cause warping artifacts that create square contours (1x1, 2x2 pixels) + # These single-pixel artifacts trigger the "too wide" error + # Use a fixed minimum scale based on template dimensions to ensure keypoints are spaced properly + # This prevents warping artifacts regardless of frame size + # Template is 1045x675, need sufficient scale to avoid 1x1 pixel artifacts from warping + # Higher minimum scale ensures warped template doesn't create tiny square artifacts + min_scale_absolute = 0.5 # Fixed minimum 50% of template size to avoid 1x1 pixel warping artifacts + # Higher scale is necessary to prevent warping interpolation from creating single-pixel squares + uniform_scale = max(uniform_scale, min_scale_absolute) + + # Analyze line distribution to determine which keypoints to use and where to place them + # Default: use center line keypoints (15, 16, 31, 32) - indices 14, 15, 30, 31 + selected_keypoint_indices = set([14, 15, 30, 31]) # Default: 15, 16, 31, 32 + line_distribution = None # Will store: (top_count, bottom_count, left_count, right_count, total_count) + + # If we have line distribution analysis, select appropriate keypoints + if frame_image is not None and template_image is not None and template_keypoints is not None: + try: + from keypoint_evaluation import ( + project_image_using_keypoints, + extract_masks_for_ground_and_lines_no_validation, + extract_mask_of_ground_lines_in_image + ) + + # Generate initial keypoints for analysis using EXACT FITTING (full frame coverage) + # This ensures we get correct line distribution analysis + # Use non-uniform scaling to fit exactly to frame dimensions + initial_sx = float(frame_width) / float(template_max_x if template_max_x != 0 else 1) + initial_sy = float(frame_height) / float(template_max_y if template_max_y != 0 else 1) + initial_scaled = [] + num_template_kps = len(template_keypoints) if template_keypoints is not None else 32 + for i in range(max(32, num_template_kps)): # Ensure we have at least 32 keypoints + if i < num_template_kps: + tx, ty = template_keypoints[i] + if tx > 0 and ty > 0: # Only scale non-zero keypoints + # Use non-uniform scaling for exact fit + x_scaled = int(round(tx * initial_sx)) + y_scaled = int(round(ty * initial_sy)) + # Clamp to frame bounds + x_scaled = max(0, min(x_scaled, frame_width - 1)) + y_scaled = max(0, min(y_scaled, frame_height - 1)) + initial_scaled.append((x_scaled, y_scaled)) + else: + initial_scaled.append((0, 0)) + else: + initial_scaled.append((0, 0)) # Pad to 32 if template_keypoints has fewer + + # With exact fitting, keypoints already fill the frame, no centering needed + initial_centered = initial_scaled + + if len(initial_scaled) > 0: + try: + warped_template = project_image_using_keypoints( + image=template_image, + source_keypoints=template_keypoints, + destination_keypoints=initial_centered, + destination_width=frame_width, + destination_height=frame_height, + ) + + # Use non-validating version for line distribution analysis + # Exact fitting might create invalid masks, but we still want to analyze line distribution + mask_ground, mask_lines = extract_masks_for_ground_and_lines_no_validation(image=warped_template) + mask_lines_predicted = extract_mask_of_ground_lines_in_image( + image=frame_image, ground_mask=mask_ground + ) + + h, w = mask_lines_predicted.shape + top_half = mask_lines_predicted[:h//2, :] + bottom_half = mask_lines_predicted[h//2:, :] + left_half = mask_lines_predicted[:, :w//2] + right_half = mask_lines_predicted[:, w//2:] + + top_line_count = np.sum(top_half > 0) + bottom_line_count = np.sum(bottom_half > 0) + left_line_count = np.sum(left_half > 0) + right_line_count = np.sum(right_half > 0) + total_line_count = top_line_count + bottom_line_count + + line_distribution = (top_line_count, bottom_line_count, left_line_count, right_line_count, total_line_count) + + # print(f"top_line_count: {top_line_count}, bottom_line_count: {bottom_line_count}, left_line_count: {left_line_count}, right_line_count: {right_line_count}, total_line_count: {total_line_count}") + + if total_line_count > 100: # Only use analysis if enough lines detected + # Select keypoints based on where lines are detected + # If lines at top -> use top part of pitch (so top part aligns with top where lines are) + # If lines at bottom -> use bottom part of pitch (so bottom part aligns with bottom where lines are) + # If lines at left -> use left part of pitch (so left part aligns with left where lines are) + # If lines at right -> use right part of pitch (so right part aligns with right where lines are) + + # Define keypoint sets for different regions + top_part = set([0, 1, 2, 9, 13, 14, 24, 25, 26]) # Top part of pitch + bottom_part = set([3, 4, 5, 12, 15, 16, 27, 28, 29]) # Bottom part of pitch + left_part = set(list(range(0, 13)) + [6, 7, 8]) # Left part of pitch + right_part = set(list(range(17, 30)) + [21, 22, 23]) # Right part of pitch + center_reference = set([13, 14, 15, 16, 30, 31]) # Center line and circle + + # Select vertical region - match where lines are detected + vertical_selection = set() + if top_line_count > bottom_line_count: + # Lines at top -> use top part of pitch + vertical_selection = top_part + else: + # Lines at bottom -> use bottom part of pitch + vertical_selection = bottom_part + + # Select horizontal region - match where lines are detected + horizontal_selection = set() + if left_line_count > right_line_count: + # Lines at left -> use left part of pitch + horizontal_selection = left_part + else: + # Lines at right -> use right part of pitch + horizontal_selection = right_part + + # Use intersection when both conditions are met, otherwise use the stronger signal + # This ensures we select a coherent region (e.g., bottom-right corner) rather than union + vertical_diff = abs(top_line_count - bottom_line_count) + horizontal_diff = abs(left_line_count - right_line_count) + + if vertical_diff > horizontal_diff * 1.5: + # Vertical signal is much stronger - use only vertical selection + selected_keypoint_indices = vertical_selection + elif horizontal_diff > vertical_diff * 1.5: + # Horizontal signal is much stronger - use only horizontal selection + selected_keypoint_indices = horizontal_selection + else: + # Both signals are similar - use intersection to get corner region + selected_keypoint_indices = vertical_selection & horizontal_selection + # If intersection is too small, fall back to union + if len(selected_keypoint_indices) < 4: + selected_keypoint_indices = vertical_selection | horizontal_selection + + # Always include center line and center circle for reference + selected_keypoint_indices.update(center_reference) + + # Ensure we have at least 4 keypoints + if len(selected_keypoint_indices) < 4: + selected_keypoint_indices = set([14, 15, 30, 31]) # Fallback to default + except Exception: + pass # Use default keypoints if analysis fails + except Exception: + pass # Use default keypoints if analysis fails + + # Generate scaled keypoints only for selected indices + # Use template_keypoints if available, otherwise fall back to FOOTBALL_KEYPOINTS + source_keypoints = template_keypoints if template_keypoints is not None else FOOTBALL_KEYPOINTS + num_keypoints = len(source_keypoints) if source_keypoints is not None else 32 + + scaled: list[tuple[int, int]] = [] + for i in range(num_keypoints): + if i in selected_keypoint_indices and i < len(source_keypoints): + tx, ty = source_keypoints[i] + if tx > 0 and ty > 0: # Only scale non-zero keypoints + x_scaled = int(round(tx * uniform_scale)) + y_scaled = int(round(ty * uniform_scale)) + scaled.append((x_scaled, y_scaled)) + else: + scaled.append((0, 0)) + else: + scaled.append((0, 0)) # Set unselected keypoints to (0, 0) + + # Ensure minimum spacing between keypoints to avoid warping artifacts + # Very close keypoints can create single-pixel square contours during warping + # Check if any keypoints are too close and adjust scale if needed + min_spacing = 5 # Minimum 5 pixels between keypoints to avoid 1x1 artifacts + needs_adjustment = False + for i in range(len(scaled)): + if scaled[i][0] == 0 and scaled[i][1] == 0: + continue + x1, y1 = scaled[i] + for j in range(i + 1, len(scaled)): + if scaled[j][0] == 0 and scaled[j][1] == 0: + continue + x2, y2 = scaled[j] + dist = np.sqrt((x2 - x1)**2 + (y2 - y1)**2) + if dist > 0 and dist < min_spacing: + needs_adjustment = True + break + if needs_adjustment: + break + + # If keypoints are too close, slightly increase scale to maintain minimum spacing + if needs_adjustment and uniform_scale < 0.25: + uniform_scale = uniform_scale * 1.2 # Increase by 20% to ensure spacing + uniform_scale = min(uniform_scale, 0.25) # Cap at 25% to keep it small + # Recalculate selected keypoints with adjusted scale + scaled = [] + for k in range(num_keypoints): + if k in selected_keypoint_indices and k < len(source_keypoints): + tx, ty = source_keypoints[k] + if tx > 0 and ty > 0: # Only scale non-zero keypoints + x_scaled = int(round(tx * uniform_scale)) + y_scaled = int(round(ty * uniform_scale)) + scaled.append((x_scaled, y_scaled)) + else: + scaled.append((0, 0)) + else: + scaled.append((0, 0)) + + # Use line distribution analysis (already computed above) to determine optimal pitch placement + offset_x = 0 + offset_y = 0 + + if line_distribution is not None: + top_line_count, bottom_line_count, left_line_count, right_line_count, total_line_count = line_distribution + + # Adjust keypoint placement based on line distribution + valid_points = [(x, y) for x, y in scaled if x > 0 and y > 0] + if len(valid_points) > 0: + scaled_width = max(x for x, y in valid_points) + scaled_height = max(y for x, y in valid_points) + + margin = 5 + offset_x = max(margin, (frame_width - scaled_width) // 2) + offset_x = min(offset_x, frame_width - scaled_width - margin) + offset_x = max(0, offset_x) + + # Only use line distribution analysis if we detected a reasonable number of lines + # Otherwise fall back to default centering + if total_line_count > 100: # Minimum threshold to trust the analysis + # Calculate the bounding box of selected keypoints in scaled coordinates + selected_scaled_points = [(x, y) for i, (x, y) in enumerate(scaled) + if i in selected_keypoint_indices and x > 0 and y > 0] + + if len(selected_scaled_points) > 0: + min_y_selected = min(y for x, y in selected_scaled_points) + max_y_selected = max(y for x, y in selected_scaled_points) + min_x_selected = min(x for x, y in selected_scaled_points) + max_x_selected = max(x for x, y in selected_scaled_points) + + # Determine which part of template the selected keypoints represent + # Check template coordinates to determine if selected keypoints are from top/bottom/left/right + if template_keypoints is not None: + selected_template_points = [(template_keypoints[i][0], template_keypoints[i][1]) + for i in selected_keypoint_indices + if i < len(template_keypoints) and template_keypoints[i][0] > 0] + + if len(selected_template_points) > 0: + template_min_y = min(y for x, y in selected_template_points) + template_max_y = max(y for x, y in selected_template_points) + template_min_x = min(x for x, y in selected_template_points) + template_max_x = max(x for x, y in selected_template_points) + + template_height = max(y for x, y in template_keypoints if x > 0) if template_keypoints else 675 + template_width = max(x for x, y in template_keypoints if x > 0) if template_keypoints else 1045 + + # Determine if selected keypoints are from top, bottom, left, or right part + is_top_part = template_min_y < template_height * 0.4 # Top 40% of template + is_bottom_part = template_max_y > template_height * 0.6 # Bottom 40% of template + is_left_part = template_min_x < template_width * 0.4 # Left 40% of template + is_right_part = template_max_x > template_width * 0.6 # Right 40% of template + + # Position selected keypoints to align with where lines are detected + if top_line_count > bottom_line_count: + # Lines detected at top -> align selected keypoints with top region + if is_bottom_part: + # Selected bottom part -> position so its top edge aligns with top of frame + offset_y = margin - min_y_selected # Shift so min_y_selected aligns with margin + elif is_top_part: + # Selected top part -> position at top + offset_y = margin - min_y_selected + else: + # Mixed or center -> position at top + offset_y = margin + else: + # Lines detected at bottom -> align selected keypoints with bottom region + if is_top_part: + # Selected top part -> position so its bottom edge aligns with bottom of frame + offset_y = frame_height - max_y_selected - margin # Shift so max_y_selected aligns with bottom + elif is_bottom_part: + # Selected bottom part -> position at bottom + offset_y = frame_height - max_y_selected - margin + else: + # Mixed or center -> position at bottom + offset_y = frame_height - scaled_height - margin + + # Horizontal alignment + if left_line_count > right_line_count: + # Lines detected at left -> align selected keypoints with left region + if is_right_part: + offset_x = margin - min_x_selected + elif is_left_part: + offset_x = margin - min_x_selected + else: + offset_x = max(margin, (frame_width - scaled_width) // 2) + else: + # Lines detected at right -> align selected keypoints with right region + if is_left_part: + offset_x = frame_width - max_x_selected - margin + elif is_right_part: + offset_x = frame_width - max_x_selected - margin + else: + offset_x = max(margin, (frame_width - scaled_width) // 2) + else: + # Fallback to simple positioning + if top_line_count > bottom_line_count: + offset_y = margin + else: + offset_y = frame_height - scaled_height - margin + else: + # Fallback if no template_keypoints available + if top_line_count > bottom_line_count: + offset_y = margin + else: + offset_y = frame_height - scaled_height - margin + + # Ensure reasonable bounds - keep pitch within frame + min_visible_height = scaled_height * 0.3 + offset_y = max(margin, min(offset_y, frame_height - scaled_height - margin)) + + # Ensure at least some portion of pitch is visible + if offset_y + scaled_height < min_visible_height: + offset_y = frame_height - min_visible_height - margin + if offset_y > frame_height - min_visible_height: + offset_y = margin + else: + # Not enough lines detected, use default centering + offset_y = max(margin, (frame_height - scaled_height) // 2) + offset_y = min(offset_y, frame_height - scaled_height - margin) + offset_y = max(0, offset_y) + else: + # Default centering if no line distribution analysis + valid_points = [(x, y) for x, y in scaled if x > 0 and y > 0] + if len(valid_points) > 0: + scaled_width = max(x for x, y in valid_points) + scaled_height = max(y for x, y in valid_points) + margin = 5 + offset_x = max(margin, (frame_width - scaled_width) // 2) + offset_y = max(margin, (frame_height - scaled_height) // 2) + offset_x = min(offset_x, frame_width - scaled_width - margin) + offset_y = min(offset_y, frame_height - scaled_height - margin) + offset_x = max(0, offset_x) + offset_y = max(0, offset_y) + + # Apply centering offset + centered = [] + for x, y in scaled: + if x == 0 and y == 0: + centered.append((0, 0)) + else: + new_x = x + offset_x + new_y = y + offset_y + # Allow negative y coordinates (pitch extends above frame) + # But ensure x coordinates are within frame bounds to avoid warping artifacts + new_x = max(0, min(new_x, frame_width - 1)) + # Allow negative y, but ensure at least some keypoints are in frame + # This prevents large square artifacts from warping + centered.append((new_x, new_y)) + + # Ensure at least some keypoints have positive y coordinates (visible in frame) + # This prevents warping from creating large square artifacts + visible_keypoints = [kp for kp in centered if kp[1] > 0] + if len(visible_keypoints) < 4: + # Not enough visible keypoints - adjust offset_y to ensure visibility + # This prevents warping artifacts that create large squares + min_y = min(y for x, y in centered if y != 0) if visible_keypoints else 0 + if min_y < 0: + adjustment = abs(min_y) + 10 # Push down by at least 10 pixels + centered = [] + for x, y in scaled: + if x == 0 and y == 0: + centered.append((0, 0)) + else: + new_x = x + offset_x + new_y = y + offset_y + adjustment + new_x = max(0, min(new_x, frame_width - 1)) + new_y = max(0, new_y) # Ensure at least some are visible + centered.append((new_x, new_y)) + + return centered + +def _adjust_keypoints_to_pass_validation( + keypoints: list[tuple[int, int]], + template_keypoints: list[tuple[int, int]] = None, + frame_width: int = None, + frame_height: int = None, +) -> list[tuple[int, int]]: + """ + Adjust keypoints to pass validate_projected_corners. + If validation fails, try to fix by ensuring corners form a valid quadrilateral. + """ + if _validate_keypoints_corners(keypoints, template_keypoints): + return keypoints # Already valid + + # If validation fails, try to fix by ensuring corner keypoints are in correct order + try: + from keypoint_evaluation import ( + TEMPLATE_KEYPOINTS, + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + ) + + if template_keypoints is None: + template_keypoints = TEMPLATE_KEYPOINTS + + # Get corner indices + corner_indices = [ + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + ] + + # Check if we have all corner keypoints + corners = [] + for idx in corner_indices: + if idx < len(keypoints): + x, y = keypoints[idx] + if x > 0 and y > 0: + corners.append((x, y, idx)) + + if len(corners) < 4: + # Not enough corners - can't fix, return original + return keypoints + + # Extract corner coordinates + corner_coords = [(x, y) for x, y, _ in corners] + + # Check if corners form a bowtie (twisted quadrilateral) + # A bowtie occurs when opposite edges intersect + def segments_intersect(p1, p2, q1, q2): + """Check if line segments p1-p2 and q1-q2 intersect.""" + def ccw(a, b, c): + return (c[1] - a[1]) * (b[0] - a[0]) > (b[1] - a[1]) * (c[0] - a[0]) + return (ccw(p1, q1, q2) != ccw(p2, q1, q2)) and (ccw(p1, p2, q1) != ccw(p1, p2, q2)) + + # Try different corner orderings to find a valid one + # Current order: top-left, top-right, bottom-right, bottom-left + # If this creates a bowtie, we need to reorder + + # Sort corners by position to get proper order + # Top row (smaller y values) + top_corners = sorted([c for c in corners if c[1] <= np.mean([c[1] for c in corners])], + key=lambda c: c[0]) + # Bottom row (larger y values) + bottom_corners = sorted([c for c in corners if c[1] > np.mean([c[1] for c in corners])], + key=lambda c: c[0]) + + # If we have 2 top and 2 bottom corners, ensure proper ordering + if len(top_corners) == 2 and len(bottom_corners) == 2: + # Ensure left < right + if top_corners[0][0] > top_corners[1][0]: + top_corners = top_corners[::-1] + if bottom_corners[0][0] > bottom_corners[1][0]: + bottom_corners = bottom_corners[::-1] + + # Reconstruct with proper order: top-left, top-right, bottom-right, bottom-left + result = list(keypoints) + + # Map to corner indices + corner_mapping = { + INDEX_KEYPOINT_CORNER_TOP_LEFT: top_corners[0], + INDEX_KEYPOINT_CORNER_TOP_RIGHT: top_corners[1], + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT: bottom_corners[1], + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT: bottom_corners[0], + } + + for corner_idx, (x, y, _) in corner_mapping.items(): + if corner_idx < len(result): + result[corner_idx] = (x, y) + + # Validate the adjusted keypoints + if _validate_keypoints_corners(result, template_keypoints): + return result + + # Alternative: If we can't fix by reordering, try using template-based scaling + # for corners only, keeping other keypoints as-is + if len(corners) >= 4: + # Calculate scale from non-corner keypoints if available + non_corner_kps = [(i, keypoints[i]) for i in range(len(keypoints)) + if i not in corner_indices and keypoints[i][0] > 0 and keypoints[i][1] > 0] + + if len(non_corner_kps) >= 2: + # Use template scaling approach + scales_x = [] + scales_y = [] + for idx, (x, y) in non_corner_kps: + if idx < len(template_keypoints): + tx, ty = template_keypoints[idx] + if tx > 0: + scales_x.append(x / tx) + if ty > 0: + scales_y.append(y / ty) + + if scales_x and scales_y: + avg_scale_x = sum(scales_x) / len(scales_x) + avg_scale_y = sum(scales_y) / len(scales_y) + + result = list(keypoints) + # Recalculate corners using template scaling + for corner_idx in corner_indices: + if corner_idx < len(template_keypoints): + tx, ty = template_keypoints[corner_idx] + new_x = int(round(tx * avg_scale_x)) + new_y = int(round(ty * avg_scale_y)) + if corner_idx < len(result): + result[corner_idx] = (new_x, new_y) + + # Validate again + if _validate_keypoints_corners(result, template_keypoints): + return result + + except Exception: + pass + + # If we can't fix, return original + return keypoints + +def fix_keypoints( + results_frames: Sequence[Any], + failed_indices: Sequence[int], + frame_width: int, + frame_height: int, + template_keypoints: list[tuple[int, int]] = None, + frames: List[np.ndarray] = None, + floor_markings_template: np.ndarray = None, + offset: int = 0, +) -> list[Any]: + max_frames = len(results_frames) + if max_frames == 0: + return list(results_frames) + + failed_set = set(int(i) for i in failed_indices) + all_indices = list(range(max_frames)) + successful_indices = [i for i in all_indices if i not in failed_set] + + # Use actual frame dimensions instead of hardcoded values + # Using actual dimensions ensures keypoints match the frame and avoid warping artifacts + # if len(successful_indices) == 0: + # for frame_result in results_frames: + # setattr(frame_result, "keypoints", list(convert_keypoints_to_val_format(sparse_template))) + # return list(results_frames) + + # seed_index = successful_indices[0] + # seed_kps_raw = getattr(results_frames[seed_index], "keypoints", []) or [] + # last_success_kps = convert_keypoints_to_val_format(seed_kps_raw) + + # # Validate and adjust seed keypoints + # last_success_kps = _adjust_keypoints_to_pass_validation( + # last_success_kps, template_keypoints, frame_width, frame_height + # ) + + last_success_kps = None + + for frame_index in range(max_frames): + frame_result = results_frames[frame_index] + + if frame_index in failed_set and last_success_kps is not None: + # Substitute last_success_kps and validate/adjust + adjusted_kps = _adjust_keypoints_to_pass_validation( + last_success_kps, template_keypoints, frame_width, frame_height + ) + + setattr(frame_result, "keypoints", list(adjusted_kps)) + + else: + current_kps_raw = getattr(frame_result, "keypoints", []) or [] + current_kps = convert_keypoints_to_val_format(current_kps_raw) + + last_success_kps = current_kps + + setattr(frame_result, "keypoints", list(current_kps)) + + # Get original detected keypoints from frame_result + original_keypoints_raw = getattr(frame_result, "keypoints", []) or [] + original_keypoints = convert_keypoints_to_val_format(original_keypoints_raw) + + # Check if original keypoints are valid (have at least some non-zero keypoints) + original_keypoints_valid = len([kp for kp in original_keypoints if kp[0] != 0 or kp[1] != 0]) >= 4 + + # Generate sparse template keypoints with line distribution analysis for this frame + frame_for_analysis = None + template_for_analysis = None + if frames is not None and frame_index < len(frames): + frame_for_analysis = frames[frame_index] + if floor_markings_template is not None: + template_for_analysis = floor_markings_template + + # Generate sparse template keypoints for this specific frame + sparse_template = _generate_sparse_template_keypoints( + frame_width, + frame_height, + frame_image=frame_for_analysis, + template_image=template_for_analysis, + template_keypoints=template_keypoints + ) + + # Evaluate both keypoint sets and choose the one with higher score + final_keypoints = sparse_template + sparse_score = 0.0 + original_score = 0.0 + + # Only evaluate if we have frame/template and original keypoints might be better + if frame_for_analysis is not None and template_for_analysis is not None and template_keypoints is not None: + try: + from keypoint_evaluation import evaluate_keypoints_for_frame + + # Evaluate sparse template keypoints first + try: + sparse_score = evaluate_keypoints_for_frame( + template_keypoints=template_keypoints, + frame_keypoints=sparse_template, + frame=frame_for_analysis, + floor_markings_template=template_for_analysis, + ) + except Exception as e: + # If evaluation fails, use sparse_template as fallback + sparse_score = 0.0 + + # Only evaluate original keypoints if: + # 1. They are valid + # 2. Sparse score is not already very high (>= 0.8) - skip if sparse is already good + should_evaluate_original = original_keypoints_valid and sparse_score < 0.8 + + if should_evaluate_original: + try: + original_score = evaluate_keypoints_for_frame( + template_keypoints=template_keypoints, + frame_keypoints=original_keypoints, + frame=frame_for_analysis, + floor_markings_template=template_for_analysis, + ) + except Exception as e: + # If evaluation fails, keep sparse_template + original_score = 0.0 + else: + if not original_keypoints_valid: + # Original keypoints are invalid, set score to -1 to ensure sparse_template is used + original_score = -1.0 + else: + # Sparse score is already high, skip expensive original evaluation + original_score = sparse_score - 0.01 # Slightly lower to prefer sparse + + # Choose the keypoints with higher score + if original_score > sparse_score: + final_keypoints = original_keypoints + print(f"Frame {frame_index}: Using original keypoints (score: {original_score:.4f} > sparse: {sparse_score:.4f})") + else: + final_keypoints = sparse_template + if original_keypoints_valid: + print(f"Frame {frame_index}: Using sparse template keypoints (score: {sparse_score:.4f} >= original: {original_score:.4f})") + else: + print(f"Frame {frame_index}: Using sparse template keypoints (score: {sparse_score:.4f}, original keypoints invalid)") + + except Exception as e: + # If evaluation fails completely, use sparse_template as default + print(f"Frame {frame_index}: Could not evaluate keypoints, using sparse template: {e}") + final_keypoints = sparse_template + else: + # If we don't have frame/template for evaluation, use sparse_template + final_keypoints = sparse_template + + setattr(frame_result, "keypoints", list(convert_keypoints_to_val_format(final_keypoints))) + # frame_image = frames[frame_index] + + # if frame_index in failed_set: + # # Substitute last_success_kps and validate/adjust + # adjusted_kps = _adjust_keypoints_to_pass_validation( + # last_success_kps, template_keypoints, frame_width, frame_height + # ) + + # check_success, adjusted_kps = adjust_keypoints_to_avoid_invalid_mask( + # adjusted_kps, template_keypoints, frame_image, floor_markings_template + # ) + + # if check_success: + # setattr(frame_result, "keypoints", list(adjusted_kps)) + # else: + # setattr(frame_result, "keypoints", list(convert_keypoints_to_val_format(sparse_template))) + + # # setattr(frame_result, "keypoints", list(convert_keypoints_to_val_format(sparse_template))) + # else: + # current_kps_raw = getattr(frame_result, "keypoints", []) or [] + # current_kps = convert_keypoints_to_val_format(current_kps_raw) + + # last_success_kps = current_kps + + # setattr(frame_result, "keypoints", list(current_kps)) + + + + return list(results_frames) + +def run_keypoints_post_processing( + results_frames: Sequence[Any], + frame_width: int, + frame_height: int, + frames: List[np.ndarray] = None, + template_keypoints: list[tuple[int, int]] = None, + floor_markings_template: np.ndarray = None, + template_image_path: str = None, + offset: int = 0, +) -> list[Any]: + """ + Post-process keypoints with validation and adjustment. + + Args: + results_frames: Sequence of frame results with keypoints + frame_width: Frame width + frame_height: Frame height + frames: Optional list of frame images for validation + template_keypoints: Optional template keypoints (defaults to TEMPLATE_KEYPOINTS) + floor_markings_template: Optional template image for validation + template_image_path: Optional path to template image (will load if not provided) + offset: Frame offset for tracking (defaults to 0) + + Returns: + List of processed frame results + """ + # Load template_keypoints and floor_markings_template if not provided + if template_keypoints is None or floor_markings_template is None: + try: + from keypoint_evaluation import ( + load_template_from_file, + TEMPLATE_KEYPOINTS, + ) + + if template_keypoints is None: + template_keypoints = TEMPLATE_KEYPOINTS + + if floor_markings_template is None: + if template_image_path is None: + # Try to find template in common locations + from pathlib import Path + possible_paths = [ + Path("football_pitch_template.png"), + Path("templates/football_pitch_template.png"), + ] + template_image_path = None + for path in possible_paths: + if path.exists(): + template_image_path = str(path) + break + + if template_image_path is not None: + loaded_template_image, loaded_template_keypoints = load_template_from_file(template_image_path) + floor_markings_template = loaded_template_image + if template_keypoints is None: + template_keypoints = loaded_template_keypoints + else: + # If not found, we'll skip full validation but still do basic checks + print("Warning: Template image not found, skipping full mask validation") + except ImportError: + pass + except Exception as e: + print(f"Warning: Could not load template: {e}") + + failed_indices = predict_failed_indices( + results_frames, template_keypoints, frame_width, frame_height, frames, floor_markings_template, offset + ) + + return fix_keypoints( + results_frames, failed_indices, frame_width, frame_height, + template_keypoints, frames, floor_markings_template, offset + ) \ No newline at end of file diff --git a/keypoint_helper_v2_optimized.py b/keypoint_helper_v2_optimized.py new file mode 100644 index 0000000000000000000000000000000000000000..8b9f3f78a3caf1a8e62e34551da1dd1aa03cb3dc --- /dev/null +++ b/keypoint_helper_v2_optimized.py @@ -0,0 +1,4119 @@ + +import time +import numpy as np +import cv2 +from typing import List, Tuple, Sequence, Any +from numpy import ndarray +from multiprocessing import cpu_count +from functools import partial +import copy +import threading +from pathlib import Path + +# Module-level template variables (initialized lazily) +_TEMPLATE_KEYPOINTS: list[tuple[int, int]] = None +_TEMPLATE_IMAGE: np.ndarray = None +# Cached template dimensions for performance (default values) +_TEMPLATE_MAX_X: int = 1045 +_TEMPLATE_MAX_Y: int = 675 + + +def _initialize_template_variables(template_keypoints=None, template_image=None): + """ + Initialize module-level template variables. + Called once from run_keypoints_post_processing. + + Args: + template_keypoints: Optional template keypoints (pre-loaded) + template_image: Optional template image (pre-loaded from miner constructor) + """ + global _TEMPLATE_KEYPOINTS, _TEMPLATE_IMAGE + + if _TEMPLATE_KEYPOINTS is None or _TEMPLATE_IMAGE is None: + try: + from keypoint_evaluation import ( + TEMPLATE_KEYPOINTS, + ) + + # Set template keypoints (use provided or use default) + if _TEMPLATE_KEYPOINTS is None: + if template_keypoints is not None: + _TEMPLATE_KEYPOINTS = template_keypoints + else: + _TEMPLATE_KEYPOINTS = TEMPLATE_KEYPOINTS + + # Set template image (use provided pre-loaded image) + if _TEMPLATE_IMAGE is None: + if template_image is not None: + # Use pre-loaded template image (from miner constructor) + _TEMPLATE_IMAGE = template_image + else: + print("Warning: Template image not provided, some validation may be skipped") + + # Cache template dimensions for performance + global _TEMPLATE_MAX_X, _TEMPLATE_MAX_Y + if _TEMPLATE_KEYPOINTS is not None and len(_TEMPLATE_KEYPOINTS) > 0: + valid_template_points = [(x, y) for x, y in _TEMPLATE_KEYPOINTS if x > 0 and y > 0] + if len(valid_template_points) > 0: + _TEMPLATE_MAX_X = max(x for x, y in valid_template_points) + _TEMPLATE_MAX_Y = max(y for x, y in valid_template_points) + except ImportError: + pass + except Exception as e: + print(f"Warning: Could not load template: {e}") + +FOOTBALL_KEYPOINTS: list[tuple[int, int]] = [ + (0, 0), # 1 + (0, 0), # 2 + (0, 0), # 3 + (0, 0), # 4 + (0, 0), # 5 + (0, 0), # 6 + + (0, 0), # 7 + (0, 0), # 8 + (0, 0), # 9 + + (0, 0), # 10 + (0, 0), # 11 + (0, 0), # 12 + (0, 0), # 13 + + (0, 0), # 14 + (527, 283), # 15 + (527, 403), # 16 + (0, 0), # 17 + + (0, 0), # 18 + (0, 0), # 19 + (0, 0), # 20 + (0, 0), # 21 + + (0, 0), # 22 + + (0, 0), # 23 + (0, 0), # 24 + + (0, 0), # 25 + (0, 0), # 26 + (0, 0), # 27 + (0, 0), # 28 + (0, 0), # 29 + (0, 0), # 30 + + (405, 340), # 31 + (645, 340), # 32 +] + +def convert_keypoints_to_val_format(keypoints): + return [tuple(int(x) for x in pair) for pair in keypoints] + +def validate_with_nearby_keypoints( + kp_idx: int, + kp: tuple[int, int], + valid_indices: list[int], + result: list[tuple[int, int]], + template_keypoints: list[tuple[int, int]], + scale_factor: float = None, +) -> float: + """ + Validate a keypoint by checking distances to nearby keypoints on the same side. + + Returns validation score (lower is better), or None if validation not possible. + """ + template_kp = template_keypoints[kp_idx] + + # Define which keypoints are on the same side + # Left side: 10, 11, 12, 13 (indices 9, 10, 11, 12) + # Right side: 18, 19, 20, 21, 22, 23, 24, 25-30 (indices 17-29) + + left_side_indices = [9, 10, 11, 12] # Keypoints 10-13 + right_side_indices = list(range(17, 30)) # Keypoints 18-30 + + # Determine which side this keypoint should be on + if kp_idx in left_side_indices: + same_side_indices = left_side_indices + elif kp_idx in right_side_indices: + same_side_indices = right_side_indices + else: + return None # Can't validate + + # Find nearby keypoints on the same side that are detected + nearby_kps = [] + for nearby_idx in same_side_indices: + if nearby_idx != kp_idx and nearby_idx in valid_indices: + nearby_kp = result[nearby_idx] + nearby_template_kp = template_keypoints[nearby_idx] + nearby_kps.append((nearby_idx, nearby_kp, nearby_template_kp)) + + if len(nearby_kps) == 0: + return None # No nearby keypoints to validate with + + # Calculate distance errors to nearby keypoints + distance_errors = [] + for nearby_idx, nearby_kp, nearby_template_kp in nearby_kps: + # Detected distance + detected_dist = np.sqrt((kp[0] - nearby_kp[0])**2 + (kp[1] - nearby_kp[1])**2) + + # Template distance + template_dist = np.sqrt((template_kp[0] - nearby_template_kp[0])**2 + + (template_kp[1] - nearby_template_kp[1])**2) + + if template_dist > 0: + # Expected detected distance + if scale_factor: + expected_dist = template_dist * scale_factor + else: + expected_dist = template_dist + + if expected_dist > 0: + # Normalized error + error = abs(detected_dist - expected_dist) / expected_dist + distance_errors.append(error) + + if len(distance_errors) > 0: + return np.mean(distance_errors) + return None + +def remove_duplicate_detections( + keypoints: list[tuple[int, int]], + frame_width: int = None, + frame_height: int = None, +) -> list[tuple[int, int]]: + """ + Remove duplicate/conflicting keypoint detections using distance-based validation. + + Uses the principle that if two keypoints are detected very close together, + but in the template they should be far apart, one of them is likely wrong. + Validates each keypoint by checking if its distances to other keypoints + match the expected template distances. + + Args: + keypoints: List of 32 keypoints + frame_width: Optional frame width for validation + frame_height: Optional frame height for validation + + Returns: + Cleaned list of keypoints with duplicates removed + """ + if len(keypoints) != 32: + if len(keypoints) < 32: + keypoints = list(keypoints) + [(0, 0)] * (32 - len(keypoints)) + else: + keypoints = keypoints[:32] + + result = list(keypoints) + + try: + from keypoint_evaluation import TEMPLATE_KEYPOINTS + template_available = True + except ImportError: + template_available = False + + if not template_available: + return result + + # Get all valid detected keypoints + valid_indices = [] + for i in range(32): + if result[i][0] > 0 and result[i][1] > 0: + valid_indices.append(i) + + if len(valid_indices) < 2: + return result + + # Calculate scale factor from detected keypoints to template + # Use pairs of keypoints that are far apart in template to estimate scale + scale_factor = None + if len(valid_indices) >= 2: + max_template_dist = 0 + max_detected_dist = 0 + + for i in range(len(valid_indices)): + for j in range(i + 1, len(valid_indices)): + idx_i = valid_indices[i] + idx_j = valid_indices[j] + + template_i = TEMPLATE_KEYPOINTS[idx_i] + template_j = TEMPLATE_KEYPOINTS[idx_j] + template_dist = np.sqrt((template_i[0] - template_j[0])**2 + (template_i[1] - template_j[1])**2) + + kp_i = result[idx_i] + kp_j = result[idx_j] + detected_dist = np.sqrt((kp_i[0] - kp_j[0])**2 + (kp_i[1] - kp_j[1])**2) + + if template_dist > max_template_dist and detected_dist > 0: + max_template_dist = template_dist + max_detected_dist = detected_dist + + if max_template_dist > 0 and max_detected_dist > 0: + scale_factor = max_detected_dist / max_template_dist + + # For each keypoint, validate it by checking distances to other keypoints + keypoint_scores = {} + for idx in valid_indices: + kp = result[idx] + template_kp = TEMPLATE_KEYPOINTS[idx] + + # Calculate how well this keypoint's distances match template distances + distance_errors = [] + num_comparisons = 0 + + for other_idx in valid_indices: + if other_idx == idx: + continue + + other_kp = result[other_idx] + other_template_kp = TEMPLATE_KEYPOINTS[other_idx] + + # Calculate detected distance + detected_dist = np.sqrt((kp[0] - other_kp[0])**2 + (kp[1] - other_kp[1])**2) + + # Calculate template distance + template_dist = np.sqrt((template_kp[0] - other_template_kp[0])**2 + + (template_kp[1] - other_template_kp[1])**2) + + if template_dist > 50: # Only check keypoints that should be reasonably far apart + num_comparisons += 1 + + # Expected detected distance (scaled from template) + if scale_factor: + expected_dist = template_dist * scale_factor + else: + expected_dist = template_dist + + # Calculate error (normalized) + if expected_dist > 0: + error = abs(detected_dist - expected_dist) / expected_dist + distance_errors.append(error) + + # Score: lower is better (smaller distance errors) + if num_comparisons > 0: + avg_error = np.mean(distance_errors) + keypoint_scores[idx] = avg_error + else: + keypoint_scores[idx] = 0.0 + + # Find pairs of keypoints that are too close but should be far apart + conflicts = [] + for i in range(len(valid_indices)): + for j in range(i + 1, len(valid_indices)): + idx_i = valid_indices[i] + idx_j = valid_indices[j] + + kp_i = result[idx_i] + kp_j = result[idx_j] + + # Calculate detected distance + detected_dist = np.sqrt((kp_i[0] - kp_j[0])**2 + (kp_i[1] - kp_j[1])**2) + + # Calculate template distance + template_i = TEMPLATE_KEYPOINTS[idx_i] + template_j = TEMPLATE_KEYPOINTS[idx_j] + template_dist = np.sqrt((template_i[0] - template_j[0])**2 + + (template_i[1] - template_j[1])**2) + + # If template distance is large but detected distance is small, it's a conflict + if template_dist > 100 and detected_dist < 30: + # Enhanced validation: use nearby keypoints to determine which is correct + # For example, if we have 24 and 29, we can check distances to determine if it's 13 or 21 + score_i = keypoint_scores.get(idx_i, 1.0) + score_j = keypoint_scores.get(idx_j, 1.0) + + # Try to validate using nearby keypoints on the same side + # Keypoint 13 is on left side, keypoint 21 is on right side + # If we have right-side keypoints (like 24, 29), check distances + nearby_validation_i = validate_with_nearby_keypoints( + idx_i, kp_i, valid_indices, result, TEMPLATE_KEYPOINTS, scale_factor + ) + nearby_validation_j = validate_with_nearby_keypoints( + idx_j, kp_j, valid_indices, result, TEMPLATE_KEYPOINTS, scale_factor + ) + + # Prioritize nearby validation: if one has nearby validation and the other doesn't, + # prefer the one with nearby validation (it's more reliable) + validation_score_i = score_i + validation_score_j = score_j + + if nearby_validation_i is not None and nearby_validation_j is not None: + # Both have nearby validation, use those scores + validation_score_i = nearby_validation_i + validation_score_j = nearby_validation_j + elif nearby_validation_i is not None: + # Only i has nearby validation, prefer it (give it much better score) + validation_score_i = nearby_validation_i + validation_score_j = score_j + 1.0 # Penalize j for not having nearby validation + elif nearby_validation_j is not None: + # Only j has nearby validation, prefer it + validation_score_i = score_i + 1.0 # Penalize i for not having nearby validation + validation_score_j = nearby_validation_j + # If neither has nearby validation, use general distance scores + + # Remove the one with worse validation score + if validation_score_i > validation_score_j: + conflicts.append((idx_i, idx_j, validation_score_i, validation_score_j)) + else: + conflicts.append((idx_j, idx_i, validation_score_j, validation_score_i)) + + # Remove conflicting keypoints (keep the one with better score) + removed_indices = set() + for remove_idx, keep_idx, remove_score, keep_score in conflicts: + if remove_idx not in removed_indices: + print(f"Removing duplicate detection: keypoint {remove_idx+1} at {result[remove_idx]} conflicts with keypoint {keep_idx+1} at {result[keep_idx]} " + f"(detected distance: {np.sqrt((result[remove_idx][0] - result[keep_idx][0])**2 + (result[remove_idx][1] - result[keep_idx][1])**2):.1f}, " + f"template distance: {np.sqrt((TEMPLATE_KEYPOINTS[remove_idx][0] - TEMPLATE_KEYPOINTS[keep_idx][0])**2 + (TEMPLATE_KEYPOINTS[remove_idx][1] - TEMPLATE_KEYPOINTS[keep_idx][1])**2):.1f}). " + f"Keeping keypoint {keep_idx+1} (score: {keep_score:.3f} vs {remove_score:.3f}).") + result[remove_idx] = (0, 0) + removed_indices.add(remove_idx) + + return result + +def calculate_missing_keypoints( + keypoints: list[tuple[int, int]], + frame_width: int = None, + frame_height: int = None, +) -> list[tuple[int, int]]: + """ + Calculate missing keypoint coordinates for multiple cases: + 1. Given keypoints 14, 15, 16 (and possibly 17), and either 31 or 32, + calculate the missing center circle point (32 or 31). + 2. Given three or four of keypoints 18, 19, 20, 21 and any of 22-30, + calculate missing keypoint positions (like 22 or others) to prevent warping failures. + + Args: + keypoints: List of 32 keypoints (some may be (0,0) if missing) + frame_width: Optional frame width for validation + frame_height: Optional frame height for validation + + Returns: + Updated list of 32 keypoints with calculated missing keypoints filled in + """ + if len(keypoints) != 32: + # Pad or truncate to 32 + if len(keypoints) < 32: + keypoints = list(keypoints) + [(0, 0)] * (32 - len(keypoints)) + else: + keypoints = keypoints[:32] + + result = list(keypoints) + + # Helper to get keypoint + def get_kp(kp_idx): + if kp_idx < 0 or kp_idx >= 32: + return None + x, y = result[kp_idx] + + if x == 0 and y == 0: + return None + + return (x, y) + + + # Case 1: Find center x-coordinate from center line keypoints (14, 15, 16, or 17) + # Keypoints 14, 15, 16, 17 are on the center vertical line (indices 13, 14, 15, 16) + center_x = None + for center_kp_idx in [13, 14, 15, 16]: # 14, 15, 16, 17 (0-indexed) + kp = get_kp(center_kp_idx) + if kp: + center_x = kp[0] + break + + # If we have center line, calculate missing center circle point + if center_x is not None: + # Keypoint 31 is at index 30 (left side of center circle) + # Keypoint 32 is at index 31 (right side of center circle) + kp_31 = get_kp(30) # Keypoint 31 + kp_32 = get_kp(31) # Keypoint 32 + + if kp_31 and not kp_32: + # Given 31, calculate 32 by reflecting across center_x + # Formula: x_32 = center_x + (center_x - x_31) = 2*center_x - x_31 + # y_32 = y_31 (same y-coordinate, both on center horizontal line) + dx = center_x - kp_31[0] + result[31] = (int(round(center_x + dx)), kp_31[1]) + elif kp_32 and not kp_31: + # Given 32, calculate 31 by reflecting across center_x + # Formula: x_31 = center_x - (x_32 - center_x) = 2*center_x - x_32 + # y_31 = y_32 (same y-coordinate, both on center horizontal line) + dx = kp_32[0] - center_x + result[30] = (int(round(center_x - dx)), kp_32[1]) + + # Case 1.5: Unified handling of left side keypoints (1-13) + # Three parallel vertical lines on left side: + # - Line 1-6: keypoints 1, 2, 3, 4, 5, 6 (indices 0-5) + # - Line 7-8: keypoints 7, 8 (indices 6-7) + # - Line 10-13: keypoints 10, 11, 12, 13 (indices 9-12) + # Keypoint 9 (index 8) is between line 1-6 and line 10-13 + + # Collect all left-side keypoints (1-13, indices 0-12, excluding 9 which is center) + left_side_all = [] + line_1_6_points = [] # Indices 0-5 + line_7_8_points = [] # Indices 6-7 + line_10_13_points = [] # Indices 9-12 + + for idx in range(0, 13): # Keypoints 1-13 (indices 0-12) + if idx == 8: # Skip keypoint 9 (index 8) - it's a center point + continue + kp = get_kp(idx) + if kp: + left_side_all.append((idx, kp)) + if 0 <= idx <= 5: # Line 1-6 + line_1_6_points.append((idx, kp)) + elif 6 <= idx <= 7: # Line 7-8 + line_7_8_points.append((idx, kp)) + elif 9 <= idx <= 12: # Line 10-13 + line_10_13_points.append((idx, kp)) + + kp_9 = get_kp(8) # Keypoint 9 + if kp_9: + left_side_all.append((8, kp_9)) + + total_left_side_count = len(left_side_all) + + # If we have 6 or more points, no need to calculate more + if total_left_side_count >= 6: + pass # Don't calculate more points + elif total_left_side_count == 5: + # Check if 4 points are on one line and 1 on another line + counts_per_line = [ + len(line_1_6_points), + len(line_7_8_points), + len(line_10_13_points) + ] + + if max(counts_per_line) == 4 and sum(counts_per_line) == 4: + # 4 points on one line, need to calculate 1 more point on another line + # Determine which line has 4 points and calculate on a different line + if len(line_1_6_points) == 4: + # All 4 on line 1-6, calculate on line 10-13 or 7-8 + # Prefer line 10-13 (right edge of left side) + if len(line_10_13_points) == 0: + # Calculate a point on line 10-13 + # Fit line through 1-6 points + points_1_6 = np.array([[kp[0], kp[1]] for _, kp in line_1_6_points]) + x_coords = points_1_6[:, 0] + y_coords = points_1_6[:, 1] + A = np.vstack([x_coords, np.ones(len(x_coords))]).T + m_1_6, b_1_6 = np.linalg.lstsq(A, y_coords, rcond=None)[0] + + # Calculate a point on line 10-13 (parallel to 1-6) + # Use template y-coordinate for one of 10-13 points + template_ys_10_13 = [140, 270, 410, 540] # Template y for 10-13 + template_indices_10_13 = [9, 10, 11, 12] + + # Use median y from 1-6 points to estimate scale + median_y = np.median(y_coords) + + # Calculate x using parallel line geometry + # In template: line 10-13 is at x=165, line 1-6 is at x=5 + # Ratio: 165/5 = 33 + if abs(m_1_6) > 1e-6: + x_on_line_1_6 = (median_y - b_1_6) / m_1_6 + x_new = int(round(x_on_line_1_6 * 33)) + else: + x_new = int(round(np.median(x_coords) * 33)) + + # Find first missing index in 10-13 range + for template_y, idx in zip(template_ys_10_13, template_indices_10_13): + if result[idx] is None: + result[idx] = (x_new, int(round(median_y))) + break + elif len(line_10_13_points) == 4: + # All 4 on line 10-13, calculate on line 1-6 + # Similar logic but in reverse + points_10_13 = np.array([[kp[0], kp[1]] for _, kp in line_10_13_points]) + x_coords = points_10_13[:, 0] + y_coords = points_10_13[:, 1] + A = np.vstack([x_coords, np.ones(len(x_coords))]).T + m_10_13, b_10_13 = np.linalg.lstsq(A, y_coords, rcond=None)[0] + + # Calculate a point on line 1-6 + template_ys_1_6 = [5, 140, 250, 430, 540, 675] # Template y for 1-6 + template_indices_1_6 = [0, 1, 2, 3, 4, 5] + + median_y = np.median(y_coords) + + # Calculate x using parallel line geometry + # Ratio: 5/165 ≈ 0.0303 + if abs(m_10_13) > 1e-6: + x_on_line_10_13 = (median_y - b_10_13) / m_10_13 + x_new = int(round(x_on_line_10_13 * 0.0303)) + else: + x_new = int(round(np.median(x_coords) * 0.0303)) + + for template_y, idx in zip(template_ys_1_6, template_indices_1_6): + if result[idx] is None: + result[idx] = (x_new, int(round(median_y))) + break + elif total_left_side_count < 5: + # Need to calculate missing keypoints to get exactly 5 points + # Requirements: + # 1. Must have keypoint 9 (if possible) + # 2. 4 points shouldn't be all on one line (need distribution) + + # Template coordinates for reference + template_coords_left = { + 0: (5, 5), # 1 + 1: (5, 140), # 2 + 2: (5, 250), # 3 + 3: (5, 430), # 4 + 4: (5, 540), # 5 + 5: (5, 675), # 6 + 6: (55, 250), # 7 + 7: (55, 430), # 8 + 8: (110, 340), # 9 (what we're calculating) + 9: (165, 140), # 10 + 10: (165, 270), # 11 + 11: (165, 410), # 12 + 12: (165, 540), # 13 + } + + # Define line groups (vertical and horizontal lines) + # Vertical lines: 1-6, 7-8, 10-13 + # Horizontal lines: 2-10, 3-7, 4-8, 5-13 + line_groups_left = { + '1-6': ([0, 1, 2, 3, 4, 5], 'vertical'), # indices: 1, 2, 3, 4, 5, 6 + '7-8': ([6, 7], 'vertical'), # indices: 7, 8 + '10-13': ([9, 10, 11, 12], 'vertical'), # indices: 10, 11, 12, 13 + '2-10': ([1, 9], 'horizontal'), # indices: 2, 10 + '3-7': ([2, 6], 'horizontal'), # indices: 3, 7 + '4-8': ([3, 7], 'horizontal'), # indices: 4, 8 + '5-13': ([4, 12], 'horizontal'), # indices: 5, 13 + } + + # Collect all available points with their indices + all_available_points_left = {} + for idx, kp in line_1_6_points: + all_available_points_left[idx] = kp + for idx, kp in line_7_8_points: + all_available_points_left[idx] = kp + for idx, kp in line_10_13_points: + all_available_points_left[idx] = kp + + # Step 1: Find the best vertical line and best horizontal line separately + best_vertical_line_name_left = None + best_vertical_line_points_left = [] + max_vertical_points_left = 1 + + best_horizontal_line_name_left = None + best_horizontal_line_points_left = [] + max_horizontal_points_left = 1 + + for line_name, (indices, line_type) in line_groups_left.items(): + line_points = [(idx, all_available_points_left[idx]) for idx in indices if idx in all_available_points_left] + if line_type == 'vertical' and len(line_points) > max_vertical_points_left: + max_vertical_points_left = len(line_points) + best_vertical_line_name_left = line_name + best_vertical_line_points_left = line_points + elif line_type == 'horizontal' and len(line_points) > max_horizontal_points_left: + max_horizontal_points_left = len(line_points) + best_horizontal_line_name_left = line_name + best_horizontal_line_points_left = line_points + + # Check and calculate missing points on detected lines + # For vertical lines + if best_vertical_line_name_left is not None: + expected_indices = line_groups_left[best_vertical_line_name_left][0] + detected_indices = {idx for idx, _ in best_vertical_line_points_left} + missing_indices = [idx for idx in expected_indices if idx not in detected_indices] + + if len(missing_indices) > 0: + # Calculate missing points using template ratios + template_start = template_coords_left[best_vertical_line_points_left[0][0]] + template_end = template_coords_left[best_vertical_line_points_left[-1][0]] + frame_start = best_vertical_line_points_left[0][1] + frame_end = best_vertical_line_points_left[-1][1] + + for missing_idx in missing_indices: + template_missing = template_coords_left[missing_idx] + + # Calculate ratio along the line based on y-coordinate (vertical line) + template_y_start = template_start[1] + template_y_end = template_end[1] + template_y_missing = template_missing[1] + + if abs(template_y_end - template_y_start) > 1e-6: + ratio = (template_y_missing - template_y_start) / (template_y_end - template_y_start) + else: + ratio = 0.5 + + # Calculate frame coordinates + x_new = frame_start[0] + (frame_end[0] - frame_start[0]) * ratio + y_new = frame_start[1] + (frame_end[1] - frame_start[1]) * ratio + new_point = (int(round(x_new)), int(round(y_new))) + + # Add to result and update collections + result[missing_idx] = new_point + best_vertical_line_points_left.append((missing_idx, new_point)) + all_available_points_left[missing_idx] = new_point + total_left_side_count += 1 + max_vertical_points_left = len(best_vertical_line_points_left) + + # Sort by index to maintain order + best_vertical_line_points_left.sort(key=lambda x: x[0]) + + # Check if we can now form a horizontal line with the newly calculated points + for line_name, (indices, line_type) in line_groups_left.items(): + if line_type == 'horizontal': + line_points = [(idx, all_available_points_left[idx]) for idx in indices if idx in all_available_points_left] + if len(line_points) > max_horizontal_points_left: + max_horizontal_points_left = len(line_points) + best_horizontal_line_name_left = line_name + best_horizontal_line_points_left = line_points + + # For horizontal lines + if best_horizontal_line_name_left is not None: + expected_indices = line_groups_left[best_horizontal_line_name_left][0] + detected_indices = {idx for idx, _ in best_horizontal_line_points_left} + missing_indices = [idx for idx in expected_indices if idx not in detected_indices] + + if len(missing_indices) > 0: + # Calculate missing points using template ratios + template_start = template_coords_left[best_horizontal_line_points_left[0][0]] + template_end = template_coords_left[best_horizontal_line_points_left[-1][0]] + frame_start = best_horizontal_line_points_left[0][1] + frame_end = best_horizontal_line_points_left[-1][1] + + for missing_idx in missing_indices: + template_missing = template_coords_left[missing_idx] + + # Calculate ratio along the line based on x-coordinate (horizontal line) + template_x_start = template_start[0] + template_x_end = template_end[0] + template_x_missing = template_missing[0] + + if abs(template_x_end - template_x_start) > 1e-6: + ratio = (template_x_missing - template_x_start) / (template_x_end - template_x_start) + else: + ratio = 0.5 + + # Calculate frame coordinates + x_new = frame_start[0] + (frame_end[0] - frame_start[0]) * ratio + y_new = frame_start[1] + (frame_end[1] - frame_start[1]) * ratio + new_point = (int(round(x_new)), int(round(y_new))) + + # Add to result and update collections + result[missing_idx] = new_point + best_horizontal_line_points_left.append((missing_idx, new_point)) + all_available_points_left[missing_idx] = new_point + total_left_side_count += 1 + max_horizontal_points_left = len(best_horizontal_line_points_left) + + # Sort by index to maintain order + best_horizontal_line_points_left.sort(key=lambda x: x[0]) + + # Check if we can now form a vertical line with the newly calculated points + for line_name, (indices, line_type) in line_groups_left.items(): + if line_type == 'vertical': + line_points = [(idx, all_available_points_left[idx]) for idx in indices if idx in all_available_points_left] + if len(line_points) > max_vertical_points_left: + max_vertical_points_left = len(line_points) + best_vertical_line_name_left = line_name + best_vertical_line_points_left = line_points + + # If we only have one direction, try to calculate the other direction line + # Similar logic to right side, adapted for left side structure + if best_vertical_line_name_left is not None and best_horizontal_line_name_left is None: + # We have vertical line but no horizontal line + # Find an off-line point (not on the vertical line) + off_line_point = None + off_line_idx = None + vertical_line_indices = line_groups_left[best_vertical_line_name_left][0] + for idx, kp in all_available_points_left.items(): + if idx not in vertical_line_indices: + off_line_point = kp + off_line_idx = idx + break + + if off_line_point is not None: + # Convert off_line_point to numpy array for arithmetic operations + off_line_point = np.array(off_line_point) + + # Project off_line_point onto vertical line + template_off_line = template_coords_left[off_line_idx] + + template_vertical_start_index = best_vertical_line_points_left[0][0] + template_vertical_end_index = best_vertical_line_points_left[-1][0] + + template_vertical_start = template_coords_left[template_vertical_start_index] + template_vertical_end = template_coords_left[template_vertical_end_index] + + # Project at same y as off_line_point + template_y_off = template_off_line[1] + template_y_vertical_start = template_vertical_start[1] + template_y_vertical_end = template_vertical_end[1] + + if abs(template_y_vertical_end - template_y_vertical_start) > 1e-6: + ratio_proj = (template_y_off - template_y_vertical_start) / (template_y_vertical_end - template_y_vertical_start) + else: + ratio_proj = 0.5 + + frame_vertical_start = best_vertical_line_points_left[0][1] + frame_vertical_end = best_vertical_line_points_left[-1][1] + proj_x = frame_vertical_start[0] + (frame_vertical_end[0] - frame_vertical_start[0]) * ratio_proj + proj_y = frame_vertical_start[1] + (frame_vertical_end[1] - frame_vertical_start[1]) * ratio_proj + proj_point = np.array([proj_x, proj_y]) + + # Calculate horizontal line points based on which vertical line we have + if best_vertical_line_name_left == '10-13': + # Line 10-13: can calculate points on horizontal lines 2-10, 5-13 + if off_line_idx == 1: # Point 2 (index 1) is off-line, calculate point 10 (index 9) + kp_10 = np.array(best_vertical_line_points_left[0][1]) # 10 point + kp_2 = off_line_point + (kp_10 - proj_point) + result[1] = tuple(kp_2.astype(int)) + total_left_side_count += 1 + all_available_points_left[1] = tuple(kp_2.astype(int)) + elif off_line_idx == 4: # Point 5 (index 4) is off-line, calculate point 13 (index 12) + kp_13 = np.array(best_vertical_line_points_left[-1][1]) # 13 point + kp_5 = off_line_point + (kp_13 - proj_point) + result[4] = tuple(kp_5.astype(int)) + total_left_side_count += 1 + all_available_points_left[4] = tuple(kp_5.astype(int)) + + elif best_vertical_line_name_left == '1-6': + # Line 1-6: can calculate points on horizontal lines 2-10, 3-7, 4-8, 5-13 + if off_line_idx == 6 or off_line_idx == 7: # Point 7 or 8 is off-line, calculate point 3 or 4 + template_off = template_coords_left[off_line_idx] + template_3 = template_coords_left[2] # 3 point, index 2 + template_4 = template_coords_left[3] # 4 point, index 3 + template_7 = template_coords_left[6] # 7 point, index 6 + template_8 = template_coords_left[7] # 8 point, index 7 + + if off_line_idx == 6: # Point 7, calculate point 3 + ratio = (template_3[0] - template_7[0]) / (template_7[0] - template_off[0]) if abs(template_7[0] - template_off[0]) > 1e-6 else 0.5 + kp_3 = proj_point + (off_line_point - proj_point) * ratio + result[2] = tuple(kp_3.astype(int)) + total_left_side_count += 1 + all_available_points_left[2] = tuple(kp_3.astype(int)) + else: # Point 8, calculate point 4 + ratio = (template_4[0] - template_8[0]) / (template_8[0] - template_off[0]) if abs(template_8[0] - template_off[0]) > 1e-6 else 0.5 + kp_4 = proj_point + (off_line_point - proj_point) * ratio + result[3] = tuple(kp_4.astype(int)) + total_left_side_count += 1 + all_available_points_left[3] = tuple(kp_4.astype(int)) + elif off_line_idx == 9 or off_line_idx == 12: # Point 10 or 13 is off-line, calculate point 2 or 5 + if off_line_idx == 9: # Point 10, calculate point 2 + kp_2 = off_line_point + (np.array(best_vertical_line_points_left[1][1]) - proj_point) + result[1] = tuple(kp_2.astype(int)) + total_left_side_count += 1 + all_available_points_left[1] = tuple(kp_2.astype(int)) + else: # Point 13, calculate point 5 + kp_5 = off_line_point + (np.array(best_vertical_line_points_left[4][1]) - proj_point) + result[4] = tuple(kp_5.astype(int)) + total_left_side_count += 1 + all_available_points_left[4] = tuple(kp_5.astype(int)) + + elif best_vertical_line_name_left == '7-8': + # Line 7-8: can calculate points on horizontal lines 3-7, 4-8 + if off_line_idx == 2 or off_line_idx == 3: # Point 3 or 4 is off-line, calculate point 7 or 8 + if off_line_idx == 2: # Point 3, calculate point 7 + kp_7 = off_line_point + (np.array(best_vertical_line_points_left[0][1]) - proj_point) + result[6] = tuple(kp_7.astype(int)) + total_left_side_count += 1 + all_available_points_left[6] = tuple(kp_7.astype(int)) + else: # Point 4, calculate point 8 + kp_8 = off_line_point + (np.array(best_vertical_line_points_left[-1][1]) - proj_point) + result[7] = tuple(kp_8.astype(int)) + total_left_side_count += 1 + all_available_points_left[7] = tuple(kp_8.astype(int)) + + # Check if we can now form a horizontal line with the newly calculated points + for line_name, (indices, line_type) in line_groups_left.items(): + if line_type == 'horizontal': + line_points = [(idx, all_available_points_left[idx]) for idx in indices if idx in all_available_points_left] + if len(line_points) > max_horizontal_points_left: + max_horizontal_points_left = len(line_points) + best_horizontal_line_name_left = line_name + best_horizontal_line_points_left = line_points + + elif best_horizontal_line_name_left is not None and best_vertical_line_name_left is None: + # We have horizontal line but no vertical line + # Find an off-line point (not on the horizontal line) + off_line_point = None + off_line_idx = None + horizontal_line_indices = line_groups_left[best_horizontal_line_name_left][0] + for idx, kp in all_available_points_left.items(): + if idx not in horizontal_line_indices: + off_line_point = kp + off_line_idx = idx + break + + if off_line_point is not None: + # Project off_line_point onto horizontal line + template_off_line = template_coords_left[off_line_idx] + template_horizontal_start = template_coords_left[best_horizontal_line_points_left[0][0]] + template_horizontal_end = template_coords_left[best_horizontal_line_points_left[-1][0]] + + # Project at same x as off_line_point + template_x_off = template_off_line[0] + template_x_horizontal_start = template_horizontal_start[0] + template_x_horizontal_end = template_horizontal_end[0] + + if abs(template_x_horizontal_end - template_x_horizontal_start) > 1e-6: + ratio_proj = (template_x_off - template_x_horizontal_start) / (template_x_horizontal_end - template_x_horizontal_start) + else: + ratio_proj = 0.5 + + frame_horizontal_start = best_horizontal_line_points_left[0][1] + frame_horizontal_end = best_horizontal_line_points_left[-1][1] + proj_x = frame_horizontal_start[0] + (frame_horizontal_end[0] - frame_horizontal_start[0]) * ratio_proj + proj_y = frame_horizontal_start[1] + (frame_horizontal_end[1] - frame_horizontal_start[1]) * ratio_proj + proj_point = np.array([proj_x, proj_y]) + off_line_point = np.array(off_line_point) + + # Calculate vertical line points based on which horizontal line we have + if best_horizontal_line_name_left == '2-10': + # Line 2-10: can calculate points on vertical lines 1-6, 10-13 + if off_line_idx == 0 or off_line_idx == 5: # Point 1 or 6 is off-line, calculate point 2 + kp_2 = off_line_point + (np.array(best_horizontal_line_points_left[0][1]) - proj_point) + result[1] = tuple(kp_2.astype(int)) + total_left_side_count += 1 + all_available_points_left[1] = tuple(kp_2.astype(int)) + elif off_line_idx == 9 or off_line_idx == 12: # Point 10 or 13 is off-line, calculate point 10 + kp_10 = off_line_point + (np.array(best_horizontal_line_points_left[-1][1]) - proj_point) + result[9] = tuple(kp_10.astype(int)) + total_left_side_count += 1 + all_available_points_left[9] = tuple(kp_10.astype(int)) + + elif best_horizontal_line_name_left == '3-7': + # Line 3-7: can calculate points on vertical lines 1-6, 7-8 + if off_line_idx == 0 or off_line_idx == 5: # Point 1 or 6 is off-line, calculate point 3 + kp_3 = off_line_point + (np.array(best_horizontal_line_points_left[0][1]) - proj_point) + result[2] = tuple(kp_3.astype(int)) + total_left_side_count += 1 + all_available_points_left[2] = tuple(kp_3.astype(int)) + elif off_line_idx == 6 or off_line_idx == 7: # Point 7 or 8 is off-line, calculate point 7 + kp_7 = off_line_point + (np.array(best_horizontal_line_points_left[-1][1]) - proj_point) + result[6] = tuple(kp_7.astype(int)) + total_left_side_count += 1 + all_available_points_left[6] = tuple(kp_7.astype(int)) + + elif best_horizontal_line_name_left == '4-8': + # Line 4-8: can calculate points on vertical lines 1-6, 7-8 + if off_line_idx == 0 or off_line_idx == 5: # Point 1 or 6 is off-line, calculate point 4 + kp_4 = off_line_point + (np.array(best_horizontal_line_points_left[0][1]) - proj_point) + result[3] = tuple(kp_4.astype(int)) + total_left_side_count += 1 + all_available_points_left[3] = tuple(kp_4.astype(int)) + elif off_line_idx == 6 or off_line_idx == 7: # Point 7 or 8 is off-line, calculate point 8 + kp_8 = off_line_point + (np.array(best_horizontal_line_points_left[-1][1]) - proj_point) + result[7] = tuple(kp_8.astype(int)) + total_left_side_count += 1 + all_available_points_left[7] = tuple(kp_8.astype(int)) + + elif best_horizontal_line_name_left == '5-13': + # Line 5-13: can calculate points on vertical lines 1-6, 10-13 + if off_line_idx == 0 or off_line_idx == 5: # Point 1 or 6 is off-line, calculate point 5 + kp_5 = off_line_point + (np.array(best_horizontal_line_points_left[0][1]) - proj_point) + result[4] = tuple(kp_5.astype(int)) + total_left_side_count += 1 + all_available_points_left[4] = tuple(kp_5.astype(int)) + elif off_line_idx == 9 or off_line_idx == 12: # Point 10 or 13 is off-line, calculate point 13 + kp_13 = off_line_point + (np.array(best_horizontal_line_points_left[-1][1]) - proj_point) + result[12] = tuple(kp_13.astype(int)) + total_left_side_count += 1 + all_available_points_left[12] = tuple(kp_13.astype(int)) + + # Check if we can now form a vertical line with the newly calculated points + for line_name, (indices, line_type) in line_groups_left.items(): + if line_type == 'vertical': + line_points = [(idx, all_available_points_left[idx]) for idx in indices if idx in all_available_points_left] + if len(line_points) > max_vertical_points_left: + max_vertical_points_left = len(line_points) + best_vertical_line_name_left = line_name + best_vertical_line_points_left = line_points + + # Calculate keypoint 9 if we have at least one line + if best_vertical_line_name_left is not None and best_horizontal_line_name_left is not None: + if kp_9 is None: + print(f"Calculating keypoint 9 using both vertical and horizontal lines: {best_vertical_line_name_left} and {best_horizontal_line_name_left}") + + template_x_9 = 110 + template_y_9 = 340 + + # Project keypoint 9 onto vertical line + template_vertical_start = template_coords_left[best_vertical_line_points_left[0][0]] + template_vertical_end = template_coords_left[best_vertical_line_points_left[-1][0]] + + # Project at y=340 (same y as keypoint 9) + template_y_vertical_start = template_vertical_start[1] + template_y_vertical_end = template_vertical_end[1] + + if abs(template_y_vertical_end - template_y_vertical_start) > 1e-6: + ratio_9_vertical = (template_y_9 - template_y_vertical_start) / (template_y_vertical_end - template_y_vertical_start) + else: + ratio_9_vertical = 0.5 + + frame_vertical_start = best_vertical_line_points_left[0][1] + frame_vertical_end = best_vertical_line_points_left[-1][1] + proj_9_on_vertical_x = frame_vertical_start[0] + (frame_vertical_end[0] - frame_vertical_start[0]) * ratio_9_vertical + proj_9_on_vertical_y = frame_vertical_start[1] + (frame_vertical_end[1] - frame_vertical_start[1]) * ratio_9_vertical + proj_9_on_vertical = (proj_9_on_vertical_x, proj_9_on_vertical_y) + + # Project keypoint 9 onto horizontal line + template_horizontal_start = template_coords_left[best_horizontal_line_points_left[0][0]] + template_horizontal_end = template_coords_left[best_horizontal_line_points_left[-1][0]] + + # Project at x=110 (same x as keypoint 9) + template_x_horizontal_start = template_horizontal_start[0] + template_x_horizontal_end = template_horizontal_end[0] + + if abs(template_x_horizontal_end - template_x_horizontal_start) > 1e-6: + ratio_9_horizontal = (template_x_9 - template_x_horizontal_start) / (template_x_horizontal_end - template_x_horizontal_start) + else: + ratio_9_horizontal = 0.5 + + frame_horizontal_start = best_horizontal_line_points_left[0][1] + frame_horizontal_end = best_horizontal_line_points_left[-1][1] + proj_9_on_horizontal_x = frame_horizontal_start[0] + (frame_horizontal_end[0] - frame_horizontal_start[0]) * ratio_9_horizontal + proj_9_on_horizontal_y = frame_horizontal_start[1] + (frame_horizontal_end[1] - frame_horizontal_start[1]) * ratio_9_horizontal + proj_9_on_horizontal = (proj_9_on_horizontal_x, proj_9_on_horizontal_y) + + # Calculate keypoint 9 as intersection of two lines + # Line 1: Passes through proj_9_on_vertical, parallel to best_horizontal_line + # Line 2: Passes through proj_9_on_horizontal, parallel to best_vertical_line + + # Calculate direction vector of best_horizontal_line + horizontal_dir_x = frame_horizontal_end[0] - frame_horizontal_start[0] + horizontal_dir_y = frame_horizontal_end[1] - frame_horizontal_start[1] + horizontal_dir_length = np.sqrt(horizontal_dir_x**2 + horizontal_dir_y**2) + + # Calculate direction vector of best_vertical_line + vertical_dir_x = frame_vertical_end[0] - frame_vertical_start[0] + vertical_dir_y = frame_vertical_end[1] - frame_vertical_start[1] + vertical_dir_length = np.sqrt(vertical_dir_x**2 + vertical_dir_y**2) + + if horizontal_dir_length > 1e-6 and vertical_dir_length > 1e-6: + # Normalize direction vectors + horizontal_dir_x /= horizontal_dir_length + horizontal_dir_y /= horizontal_dir_length + vertical_dir_x /= vertical_dir_length + vertical_dir_y /= vertical_dir_length + + # Find intersection: proj_9_on_vertical + t * horizontal_dir = proj_9_on_horizontal + s * vertical_dir + A = np.array([ + [horizontal_dir_x, -vertical_dir_x], + [horizontal_dir_y, -vertical_dir_y] + ]) + b = np.array([ + proj_9_on_horizontal[0] - proj_9_on_vertical[0], + proj_9_on_horizontal[1] - proj_9_on_vertical[1] + ]) + + try: + t, s = np.linalg.solve(A, b) + + # Calculate intersection point using line 1 + x_9 = proj_9_on_vertical[0] + t * horizontal_dir_x + y_9 = proj_9_on_vertical[1] + t * horizontal_dir_y + + result[8] = (int(round(x_9)), int(round(y_9))) + total_left_side_count += 1 + except np.linalg.LinAlgError: + # Lines are parallel or nearly parallel, use simple intersection + x_9 = proj_9_on_vertical[0] + y_9 = proj_9_on_horizontal[1] + result[8] = (int(round(x_9)), int(round(y_9))) + total_left_side_count += 1 + else: + # Fallback: use simple intersection + x_9 = proj_9_on_vertical[0] + y_9 = proj_9_on_horizontal[1] + result[8] = (int(round(x_9)), int(round(y_9))) + total_left_side_count += 1 + + print(f"total_left_side_count: {total_left_side_count}, result: {result}") + if total_left_side_count > 5: + pass # Continue to right side logic + + # Calculate m_line and b_line from best vertical or horizontal line for use in calculating other points + m_line_left = None + b_line_left = None + best_line_for_calc_left = None + best_line_type_for_calc_left = None + + if best_vertical_line_name_left is not None and len(best_vertical_line_points_left) >= 2: + best_line_for_calc_left = best_vertical_line_points_left + best_line_type_for_calc_left = 'vertical' + points_array = np.array([[kp[0], kp[1]] for _, kp in best_vertical_line_points_left]) + x_coords = points_array[:, 0] + y_coords = points_array[:, 1] + A = np.vstack([x_coords, np.ones(len(x_coords))]).T + m_line_left, b_line_left = np.linalg.lstsq(A, y_coords, rcond=None)[0] + elif best_horizontal_line_name_left is not None and len(best_horizontal_line_points_left) >= 2: + best_line_for_calc_left = best_horizontal_line_points_left + best_line_type_for_calc_left = 'horizontal' + points_array = np.array([[kp[0], kp[1]] for _, kp in best_horizontal_line_points_left]) + x_coords = points_array[:, 0] + y_coords = points_array[:, 1] + A = np.vstack([x_coords, np.ones(len(x_coords))]).T + m_line_left, b_line_left = np.linalg.lstsq(A, y_coords, rcond=None)[0] + + # Calculate missing points to reach exactly 5 points + # Ensure 4 points aren't all on one line + if total_left_side_count < 5 and (m_line_left is not None or (best_line_for_calc_left is not None and best_line_type_for_calc_left == 'vertical')): + # Check current distribution + counts_per_line = [ + len(line_1_6_points), + len(line_7_8_points), + len(line_10_13_points) + ] + + # Calculate points on line 1-6 if needed + template_ys_1_6 = [5, 140, 250, 430, 540, 675] + template_indices_1_6 = [0, 1, 2, 3, 4, 5] + + if best_vertical_line_name_left == '10-13': + # Construct parallel line 1-6 from line 10-13 + for template_y, idx in zip(template_ys_1_6, template_indices_1_6): + if result[idx] is None and total_left_side_count < 5: + # Check if adding this point would put 4 on one line + new_counts = counts_per_line.copy() + new_counts[0] += 1 # Adding to line 1-6 + if max(new_counts) >= 4 and total_left_side_count == 4: + # Would have 4 on one line, skip + continue + + # Calculate y using scale from template + ref_ys = [kp[1] for _, kp in line_10_13_points] + ref_template_ys = [140, 270, 410, 540] + ref_indices = [9, 10, 11, 12] + + matched_template_ys = [] + for ref_idx, ref_kp in line_10_13_points: + if ref_idx in ref_indices: + template_idx = ref_indices.index(ref_idx) + matched_template_ys.append((ref_template_ys[template_idx], ref_kp[1])) + + if len(matched_template_ys) >= 1: + ref_template_y, ref_frame_y = matched_template_ys[0] + if ref_template_y > 0: + scale = ref_frame_y / ref_template_y + y_new = int(round(template_y * scale)) + else: + y_new = ref_frame_y + else: + y_new = int(round(np.median(ref_ys))) if ref_ys else template_y + + # Calculate x using parallel line geometry + if abs(m_line_left) > 1e-6: + x_on_line_10_13 = (y_new - b_line_left) / m_line_left + x_new = int(round(x_on_line_10_13 * 0.0303)) # 5/165 + else: + x_new = int(round(np.median([kp[0] for _, kp in line_10_13_points]) * 0.0303)) + + result[idx] = (x_new, y_new) + total_left_side_count += 1 + if total_left_side_count >= 5: + break + elif best_vertical_line_name_left == '1-6': + # Calculate missing points on line 1-6 + for template_y, idx in zip(template_ys_1_6, template_indices_1_6): + if result[idx] is None and total_left_side_count < 5: + # Check if adding this point would put 4 on one line + new_counts = counts_per_line.copy() + new_counts[0] += 1 # Adding to line 1-6 + if max(new_counts) >= 4 and total_left_side_count == 4: + # Would have 4 on one line, skip + continue + + # Calculate x on the line + if abs(m_line_left) > 1e-6: + x_new = (template_y - b_line_left) / m_line_left + else: + x_new = np.median([kp[0] for _, kp in line_1_6_points]) + + # Scale y based on available points + ref_ys = [kp[1] for _, kp in line_1_6_points] + ref_template_ys = [] + for ref_idx, _ in line_1_6_points: + if ref_idx in template_indices_1_6: + template_idx = template_indices_1_6.index(ref_idx) + ref_template_ys.append(template_ys_1_6[template_idx]) + + if len(ref_ys) >= 1 and len(ref_template_ys) >= 1: + ref_template_y = ref_template_ys[0] + ref_frame_y = ref_ys[0] + if ref_template_y > 0: + scale = ref_frame_y / ref_template_y + y_new = int(round(template_y * scale)) + else: + y_new = ref_frame_y + else: + y_new = int(round(np.median(ref_ys))) if ref_ys else template_y + + result[idx] = (int(round(x_new)), y_new) + total_left_side_count += 1 + if total_left_side_count >= 5: + break + + print(f"total_left_side_count: {total_left_side_count}, result: {result}") + + # Case 2: Unified handling of right side keypoints (18-30) + # Three parallel lines on right side: + # - Line 18-21: keypoints 18, 19, 20, 21 (indices 17-20) + # - Line 23-24: keypoints 23, 24 (indices 22-23) + # - Line 25-30: keypoints 25, 26, 27, 28, 29, 30 (indices 24-29) + # Keypoint 22 (index 21) is between line 18-21 and line 25-30 + + # Collect all right-side keypoints (18-30, indices 17-29) + right_side_all = [] + line_18_21_points = [] # Indices 17-20 + line_23_24_points = [] # Indices 22-23 + line_25_30_points = [] # Indices 24-29 + + for idx in range(17, 30): # Keypoints 18-30 (indices 17-29) + kp = get_kp(idx) + if kp: + right_side_all.append((idx, kp)) + if 17 <= idx <= 20: # Line 18-21 + line_18_21_points.append((idx, kp)) + elif 22 <= idx <= 23: # Line 23-24 + line_23_24_points.append((idx, kp)) + elif 24 <= idx <= 29: # Line 25-30 + line_25_30_points.append((idx, kp)) + + kp_22 = get_kp(21) # Keypoint 22 + if kp_22: + right_side_all.append((21, kp_22)) + + total_right_side_count = len(right_side_all) + + # If we have 6 or more points, no need to calculate more + if total_right_side_count >= 6: + pass # Don't calculate more points + elif total_right_side_count == 5: + # Check if 4 points are on one line and 1 on another line + counts_per_line = [ + len(line_18_21_points), + len(line_23_24_points), + len(line_25_30_points) + ] + + if max(counts_per_line) == 4 and sum(counts_per_line) == 4: + # 4 points on one line, need to calculate 1 more point on another line + # Determine which line has 4 points and calculate on a different line + if len(line_18_21_points) == 4: + # All 4 on line 18-21, calculate on line 25-30 or 23-24 + # Prefer line 25-30 (right edge) + if len(line_25_30_points) == 0: + # Calculate a point on line 25-30 + # Fit line through 18-21 points + points_18_21 = np.array([[kp[0], kp[1]] for _, kp in line_18_21_points]) + x_coords = points_18_21[:, 0] + y_coords = points_18_21[:, 1] + A = np.vstack([x_coords, np.ones(len(x_coords))]).T + m_18_21, b_18_21 = np.linalg.lstsq(A, y_coords, rcond=None)[0] + + # Calculate a point on line 25-30 (parallel to 18-21) + # Use template y-coordinate for one of 25-30 points + template_ys_25_30 = [5, 140, 250, 430, 540, 675] # Template y for 25-30 + template_indices_25_30 = [24, 25, 26, 27, 28, 29] + + # Use median y from 18-21 points to estimate scale + median_y = np.median(y_coords) + # Find closest template y + ref_template_y = min(template_ys_25_30, key=lambda ty: abs(ty - np.median([kp[1] for _, kp in line_18_21_points]))) + ref_idx = template_ys_25_30.index(ref_template_y) + + # Calculate y for the new point + y_new = int(round(median_y)) + + # Calculate x using parallel line geometry + # In template: line 25-30 is at x=1045, line 18-21 is at x=888 + # Ratio: 1045/888 ≈ 1.177 + if abs(m_18_21) > 1e-6: + x_on_line_18_21 = (y_new - b_18_21) / m_18_21 + x_new = int(round(x_on_line_18_21 * 1.177)) + else: + x_new = int(round(np.median(x_coords) * 1.177)) + + # Find first missing index in 25-30 range + for template_y, idx in zip(template_ys_25_30, template_indices_25_30): + if result[idx] is None: + result[idx] = (x_new, y_new) + break + elif len(line_25_30_points) == 4: + # All 4 on line 25-30, calculate on line 18-21 + # Similar logic but in reverse + points_25_30 = np.array([[kp[0], kp[1]] for _, kp in line_25_30_points]) + x_coords = points_25_30[:, 0] + y_coords = points_25_30[:, 1] + A = np.vstack([x_coords, np.ones(len(x_coords))]).T + m_25_30, b_25_30 = np.linalg.lstsq(A, y_coords, rcond=None)[0] + + # Calculate a point on line 18-21 + template_ys_18_21 = [140, 270, 410, 540] # Template y for 18-21 + template_indices_18_21 = [17, 18, 19, 20] + + median_y = np.median(y_coords) + + # Calculate x using parallel line geometry + # Ratio: 888/1045 ≈ 0.850 + if abs(m_25_30) > 1e-6: + x_on_line_25_30 = (median_y - b_25_30) / m_25_30 + x_new = int(round(x_on_line_25_30 * 0.850)) + else: + x_new = int(round(np.median(x_coords) * 0.850)) + + for template_y, idx in zip(template_ys_18_21, template_indices_18_21): + if result[idx] is None: + result[idx] = (x_new, int(round(median_y))) + break + elif total_right_side_count < 5: + # Need to calculate missing keypoints to get exactly 5 points + # Requirements: + # 1. Must have keypoint 22 + # 2. 4 points shouldn't be all on one line (need distribution) + + # Template coordinates for reference + template_coords = { + 17: (888, 140), # 18 + 18: (888, 270), # 19 + 19: (888, 410), # 20 + 20: (888, 540), # 21 + 21: (940, 340), # 22 (what we're calculating) + 22: (998, 250), # 23 + 23: (998, 430), # 24 + 24: (1045, 5), # 25 + 25: (1045, 140), # 26 + 26: (1045, 250), # 27 + 27: (1045, 430), # 28 + 28: (1045, 540), # 29 + 29: (1045, 675), # 30 + } + + # Define line groups (vertical and horizontal lines) + # Vertical lines: 18-21, 23-24, 25-30 + # Horizontal lines: 18-26, 23-27, 24-28, 21-29 + line_groups = { + '18-21': ([17, 18, 19, 20], 'vertical'), # indices: 18, 19, 20, 21 + '23-24': ([22, 23], 'vertical'), # indices: 23, 24 + '25-30': ([24, 25, 26, 27, 28, 29], 'vertical'), # indices: 25, 26, 27, 28, 29, 30 + '18-26': ([17, 25], 'horizontal'), # indices: 18, 26 + '23-27': ([22, 26], 'horizontal'), # indices: 23, 27 + '24-28': ([23, 27], 'horizontal'), # indices: 24, 28 + '21-29': ([20, 28], 'horizontal'), # indices: 21, 29 + } + + # Collect all available points with their indices + all_available_points = {} + for idx, kp in line_18_21_points: + all_available_points[idx] = kp + for idx, kp in line_23_24_points: + all_available_points[idx] = kp + for idx, kp in line_25_30_points: + all_available_points[idx] = kp + + # Step 1: Find the best vertical line and best horizontal line separately + best_vertical_line_name = None + best_vertical_line_points = [] + max_vertical_points = 1 + + best_horizontal_line_name = None + best_horizontal_line_points = [] + max_horizontal_points = 1 + + for line_name, (indices, line_type) in line_groups.items(): + line_points = [(idx, all_available_points[idx]) for idx in indices if idx in all_available_points] + if line_type == 'vertical' and len(line_points) > max_vertical_points: + max_vertical_points = len(line_points) + best_vertical_line_name = line_name + best_vertical_line_points = line_points + elif line_type == 'horizontal' and len(line_points) > max_horizontal_points: + max_horizontal_points = len(line_points) + best_horizontal_line_name = line_name + best_horizontal_line_points = line_points + + # Check and calculate missing points on detected lines + # For vertical lines + if best_vertical_line_name is not None: + expected_indices = line_groups[best_vertical_line_name][0] + detected_indices = {idx for idx, _ in best_vertical_line_points} + missing_indices = [idx for idx in expected_indices if idx not in detected_indices] + + if len(missing_indices) > 0: + # Calculate missing points using template ratios + template_start = template_coords[best_vertical_line_points[0][0]] + template_end = template_coords[best_vertical_line_points[-1][0]] + frame_start = best_vertical_line_points[0][1] + frame_end = best_vertical_line_points[-1][1] + + for missing_idx in missing_indices: + template_missing = template_coords[missing_idx] + + # Calculate ratio along the line based on y-coordinate (vertical line) + template_y_start = template_start[1] + template_y_end = template_end[1] + template_y_missing = template_missing[1] + + if abs(template_y_end - template_y_start) > 1e-6: + ratio = (template_y_missing - template_y_start) / (template_y_end - template_y_start) + else: + ratio = 0.5 + + # Calculate frame coordinates + x_new = frame_start[0] + (frame_end[0] - frame_start[0]) * ratio + y_new = frame_start[1] + (frame_end[1] - frame_start[1]) * ratio + new_point = (int(round(x_new)), int(round(y_new))) + + # Add to result and update collections + result[missing_idx] = new_point + best_vertical_line_points.append((missing_idx, new_point)) + all_available_points[missing_idx] = new_point + total_right_side_count += 1 + max_vertical_points = len(best_vertical_line_points) + + # Sort by index to maintain order + best_vertical_line_points.sort(key=lambda x: x[0]) + + # Check if we can now form a horizontal line with the newly calculated points + for line_name, (indices, line_type) in line_groups.items(): + if line_type == 'horizontal': + line_points = [(idx, all_available_points[idx]) for idx in indices if idx in all_available_points] + if len(line_points) > max_horizontal_points: + max_horizontal_points = len(line_points) + best_horizontal_line_name = line_name + best_horizontal_line_points = line_points + + # For horizontal lines + if best_horizontal_line_name is not None: + expected_indices = line_groups[best_horizontal_line_name][0] + detected_indices = {idx for idx, _ in best_horizontal_line_points} + missing_indices = [idx for idx in expected_indices if idx not in detected_indices] + + if len(missing_indices) > 0: + # Calculate missing points using template ratios + template_start = template_coords[best_horizontal_line_points[0][0]] + template_end = template_coords[best_horizontal_line_points[-1][0]] + frame_start = best_horizontal_line_points[0][1] + frame_end = best_horizontal_line_points[-1][1] + + for missing_idx in missing_indices: + template_missing = template_coords[missing_idx] + + # Calculate ratio along the line based on x-coordinate (horizontal line) + template_x_start = template_start[0] + template_x_end = template_end[0] + template_x_missing = template_missing[0] + + if abs(template_x_end - template_x_start) > 1e-6: + ratio = (template_x_missing - template_x_start) / (template_x_end - template_x_start) + else: + ratio = 0.5 + + # Calculate frame coordinates + x_new = frame_start[0] + (frame_end[0] - frame_start[0]) * ratio + y_new = frame_start[1] + (frame_end[1] - frame_start[1]) * ratio + new_point = (int(round(x_new)), int(round(y_new))) + + # Add to result and update collections + result[missing_idx] = new_point + best_horizontal_line_points.append((missing_idx, new_point)) + all_available_points[missing_idx] = new_point + total_right_side_count += 1 + max_horizontal_points = len(best_horizontal_line_points) + + # Sort by index to maintain order + best_horizontal_line_points.sort(key=lambda x: x[0]) + + # Check if we can now form a vertical line with the newly calculated points + for line_name, (indices, line_type) in line_groups.items(): + if line_type == 'vertical': + line_points = [(idx, all_available_points[idx]) for idx in indices if idx in all_available_points] + if len(line_points) > max_vertical_points: + max_vertical_points = len(line_points) + best_vertical_line_name = line_name + best_vertical_line_points = line_points + + # If we only have one direction, try to calculate the other direction line + if best_vertical_line_name is not None and best_horizontal_line_name is None: + # possible cases: + # line is 25-30 and off line point is 19, then we can calculate 18 so get horizontal line 18-26 + # line is 25-30 and off line point is 20, then we can calculate 18 so get horizontal line 18-26 + # line is 18-21 and off line point is 23, then we can calculate 27 so get horizontal line 23-27 + # line is 18-21 and off line point is 24, then we can calculate 28 so get horizontal line 24-28 + # line is 18-21 and off line point is 25, then we can calculate 26 so get horizontal line 18-26 + # line is 18-21 and off line point is 27, then we can calculate 26 so get horizontal line 18-26 + # line is 18-21 and off line point is 28, then we can calculate 29 so get horizontal line 21-29 + # line is 18-21 and off line point is 30, then we can calculate 29 so get horizontal line 21-29 + # line is 23-24 and off line point is 18, then we can calculate 26 so get horizontal line 18-26 + # line is 23-24 and off line point is 19, then we can calculate 18 so get horizontal line 18-26 + # line is 23-24 and off line point is 20, then we can calculate 21 so get horizontal line 21-29 + # line is 23-24 and off line point is 21, then we can calculate 29 so get horizontal line 21-29 + # line is 23-24 and off line point is 25, then we can calculate 27 so get horizontal line 23-27 + # line is 23-24 and off line point is 26, then we can calculate 27 so get horizontal line 23-27 + # line is 23-24 and off line point is 29, then we can calculate 28 so get horizontal line 24-28 + # line is 23-24 and off line point is 30, then we can calculate 28 so get horizontal line 24-28 + # We have vertical line but no horizontal line + # Find an off-line point (not on the vertical line) + off_line_point = None + off_line_idx = None + vertical_line_indices = line_groups[best_vertical_line_name][0] + for idx, kp in all_available_points.items(): + if idx not in vertical_line_indices: + off_line_point = kp + off_line_idx = idx + break + + if off_line_point is not None: + # Convert off_line_point to numpy array for arithmetic operations + off_line_point = np.array(off_line_point) + + # Project off_line_point onto vertical line + template_off_line = template_coords[off_line_idx] + + template_vertical_start_index = best_vertical_line_points[0][0] + template_vertical_end_index = best_vertical_line_points[-1][0] + + template_vertical_start = template_coords[template_vertical_start_index] + template_vertical_end = template_coords[template_vertical_end_index] + + # Project at same y as off_line_point + template_y_off = template_off_line[1] + template_y_vertical_start = template_vertical_start[1] + template_y_vertical_end = template_vertical_end[1] + + if abs(template_y_vertical_end - template_y_vertical_start) > 1e-6: + ratio_proj = (template_y_off - template_y_vertical_start) / (template_y_vertical_end - template_y_vertical_start) + else: + ratio_proj = 0.5 + + frame_vertical_start = best_vertical_line_points[0][1] + frame_vertical_end = best_vertical_line_points[-1][1] + proj_x = frame_vertical_start[0] + (frame_vertical_end[0] - frame_vertical_start[0]) * ratio_proj + proj_y = frame_vertical_start[1] + (frame_vertical_end[1] - frame_vertical_start[1]) * ratio_proj + proj_point = np.array([proj_x, proj_y]) + + if best_vertical_line_name == '25-30' and len(best_vertical_line_points) == 6: + if off_line_idx == 18 or off_line_idx == 19: # 19 or 20 point is off line point, so we can calculate 18 + kp_26 = np.array(best_vertical_line_points[1][1]) # 26 point + + kp_18 = off_line_point + (kp_26 - proj_point) + result[17] = tuple(kp_18.astype(int)) + total_right_side_count += 1 + all_available_points[17] = tuple(kp_18.astype(int)) # 18 point is now available, index is 17 + + if best_vertical_line_name == '18-21' and len(best_vertical_line_points) == 4: + if off_line_idx == 22 or off_line_idx == 23: # 23 or 24 point is off line point, so we can calculate 27 + template_19 = template_coords[18] # 19 point, index is 18 + template_23 = template_coords[22] # 23 point, index is 22 + template_27 = template_coords[26] # 27 point, index is 26 + + ratio = (template_27[0] - template_19[0]) / (template_23[0] - template_19[0]) # ratio in x coordinates because y coordinates are the same + + expected_point = proj_point + (off_line_point - proj_point) * ratio + + if off_line_idx == 22: + result[26] = tuple(expected_point.astype(int)) # 27 point, index is 26 + total_right_side_count += 1 + all_available_points[26] = tuple(expected_point.astype(int)) # 27 point is now available, index is 26 + else: + result[27] = tuple(expected_point.astype(int)) # 28 point, index is 27 + total_right_side_count += 1 + all_available_points[27] = tuple(expected_point.astype(int)) # 28 point is now available, index is 27 + + if off_line_idx == 24 or off_line_idx == 26: # 25 or 27 point is off line point, so we can calculate 26 + kp_18 = np.array(best_vertical_line_points[0][1]) # 18 point + kp_26 = off_line_point + (kp_18 - proj_point) + + result[25] = tuple(kp_26.astype(int)) + total_right_side_count += 1 + all_available_points[25] = tuple(kp_26.astype(int)) # 26 point is now available, index is 25 + + if off_line_idx == 27 or off_line_idx == 29: # 28 or 30 point is off line point, so we can calculate 29 + kp_21 = np.array(best_vertical_line_points[-1][1]) # 21 point + kp_29 = off_line_point + (kp_21 - proj_point) + + result[28] = tuple(kp_29.astype(int)) + total_right_side_count += 1 + all_available_points[28] = tuple(kp_29.astype(int)) # 29 point is now available, index is 28 + + + if best_vertical_line_name == '23-24' and len(best_vertical_line_points) == 2: + if off_line_idx == 17 or off_line_idx == 18 or off_line_idx == 19 or off_line_idx == 20: # 18 or 19 or 20 or 21 point is off line point, so we can calculate 26 + template_18 = template_coords[17] # 18 point, index is 17 + template_26 = template_coords[25] # 26 point, index is 25 + template_23 = template_coords[22] # 23 point, index is 22 + + ratio_26 = (template_26[0] - template_18[0]) / (template_23[0] - template_18[0]) # ratio in x coordinates because y coordinates are the same + + kp_18 = None + if off_line_idx == 17: + kp_18 = off_line_point + elif off_line_idx == 18 or off_line_idx == 19 or off_line_idx == 20: + template_off_line = template_coords[off_line_idx] + ratio = (template_18[1] - template_off_line[1]) / (template_23[1] - template_off_line[1]) + kp_18 = off_line_point + (np.array(best_vertical_line_points[0][1]) - proj_point) * ratio + + if kp_18 is not None: + kp_26 = kp_18 + (proj_point - off_line_point) * ratio_26 + result[25] = tuple(kp_26.astype(int)) + total_right_side_count += 1 + all_available_points[25] = tuple(kp_26.astype(int)) # 26 point is now available, index is 25 + + if off_line_idx == 24 or off_line_idx == 25: # 25 or 26 point is off line point, so we can calculate 27 + kp_27 = off_line_point + (np.array(best_vertical_line_points[0][1]) - proj_point) + + result[26] = tuple(kp_27.astype(int)) + total_right_side_count += 1 + all_available_points[26] = tuple(kp_27.astype(int)) # 27 point is now available, index is 26 + + if off_line_idx == 28 or off_line_idx == 29: # 29 or 30 point is off line point, so we can calculate 29 + kp_29 = off_line_point + (np.array(best_vertical_line_points[-1][1]) - proj_point) + + result[28] = tuple(kp_29.astype(int)) + total_right_side_count += 1 + all_available_points[28] = tuple(kp_29.astype(int)) # 29 point is now available, index is 28 + + + # Check if we can now form a horizontal line with the newly calculated points + for line_name, (indices, line_type) in line_groups.items(): + if line_type == 'horizontal': + line_points = [(idx, all_available_points[idx]) for idx in indices if idx in all_available_points] + if len(line_points) > max_horizontal_points: + max_horizontal_points = len(line_points) + best_horizontal_line_name = line_name + best_horizontal_line_points = line_points + + + elif best_horizontal_line_name is not None and best_vertical_line_name is None: + # possible cases: + # line is 18-26 and off line point is 23, then we can calculate 27 so get vertical line 25-30 + # line is 18-26 and off line point is 24, then we can calculate 28 so get vertical line 25-30 + # line is 23-27 and off line point is 18, then we can calculate 26 so get vertical line 25-30 + # line is 23-27 and off line point is 19, then we can calculate 18 so get vertical line 18-21 + # line is 23-27 and off line point is 20, then we can calculate 18 so get vertical line 18-21 + # line is 23-27 and off line point is 21, then we can calculate 29 so get vertical line 25-30 + # line is 24-28 and off line point is 18, then we can calculate 26 so get vertical line 25-30 + # line is 24-28 and off line point is 19, then we can calculate 21 so get vertical line 18-21 + # line is 24-28 and off line point is 20, then we can calculate 21 so get vertical line 18-21 + # line is 24-28 and off line point is 21, then we can calculate 29 so get vertical line 25-30 + # line is 21-29 and off line point is 23, then we can calculate 27 so get vertical line 25-30 + # line is 21-29 and off line point is 24, then we can calculate 28 so get vertical line 25-30 + # We have horizontal line but no vertical line + # Find an off-line point (not on the horizontal line) + off_line_point = None + off_line_idx = None + horizontal_line_indices = line_groups[best_horizontal_line_name][0] + for idx, kp in all_available_points.items(): + if idx not in horizontal_line_indices: + off_line_point = kp + off_line_idx = idx + break + + if off_line_point is not None: + # Project off_line_point onto horizontal line + template_off_line = template_coords[off_line_idx] + template_horizontal_start = template_coords[best_horizontal_line_points[0][0]] + template_horizontal_end = template_coords[best_horizontal_line_points[-1][0]] + + # Project at same x as off_line_point + template_x_off = template_off_line[0] + template_x_horizontal_start = template_horizontal_start[0] + template_x_horizontal_end = template_horizontal_end[0] + + if abs(template_x_horizontal_end - template_x_horizontal_start) > 1e-6: + ratio_proj = (template_x_off - template_x_horizontal_start) / (template_x_horizontal_end - template_x_horizontal_start) + else: + ratio_proj = 0.5 + + frame_horizontal_start = best_horizontal_line_points[0][1] + frame_horizontal_end = best_horizontal_line_points[-1][1] + proj_x = frame_horizontal_start[0] + (frame_horizontal_end[0] - frame_horizontal_start[0]) * ratio_proj + proj_y = frame_horizontal_start[1] + (frame_horizontal_end[1] - frame_horizontal_start[1]) * ratio_proj + proj_point = np.array([proj_x, proj_y]) + + if best_horizontal_line_name == '18-26': + if off_line_idx == 22 or off_line_idx == 23: # 23 or 24 point is off line point, so we can calculate 27 or 28 + template_18 = template_coords[best_horizontal_line_points[0][0]] # 18 point, index is 17 + template_26 = template_coords[best_horizontal_line_points[-1][0]] # 26 point, index is 25 + template_23 = template_coords[off_line_idx] # 23 or 24 point, index is 22 or 23 + + ratio_26 = (template_26[0] - template_23[0]) / (template_26[0] - template_18[0]) # ratio in x coordinates because y coordinates are the same + + detected_point = off_line_point + (np.array(best_horizontal_line_points[-1][1]) - np.array(best_horizontal_line_points[0][1])) * ratio_26 + + if off_line_idx == 22: + result[26] = tuple(detected_point.astype(int)) + total_right_side_count += 1 + all_available_points[26] = tuple(detected_point.astype(int)) # 26 point is now available, index is 26 + else: + result[27] = tuple(detected_point.astype(int)) + total_right_side_count += 1 + all_available_points[27] = tuple(detected_point.astype(int)) # 27 point is now available, index is 27 + + if best_horizontal_line_name == '23-27': + if off_line_idx == 17 or off_line_idx == 20: + template_18 = template_coords[17] # 18 point, index is 17 + template_26 = template_coords[25] # 26 point, index is 25 + template_23 = template_coords[best_horizontal_line_points[0][0]] # 23 , index is 22 + + ratio_26 = (template_26[0] - template_18[0]) / (template_26[0] - template_23[0]) # ratio in x coordinates because y coordinates are the same + + detected_point = off_line_point + (np.array(best_horizontal_line_points[-1][1]) - np.array(best_horizontal_line_points[0][1])) * ratio_26 + + if off_line_idx == 17: + result[25] = tuple(detected_point.astype(int)) + total_right_side_count += 1 + all_available_points[25] = tuple(detected_point.astype(int)) # 26 point is now available, index is 25 + else: + result[28] = tuple(detected_point.astype(int)) + total_right_side_count += 1 + all_available_points[28] = tuple(detected_point.astype(int)) # 29 point is now available, index is 28 + + if off_line_idx == 18 or off_line_idx == 19: # 19 or 20 point is off line point, so we can calculate 18 + template_18 = template_coords[17] # 18 point, index is 17 + template_off_line = template_coords[off_line_idx] + template_23 = template_coords[best_horizontal_line_points[0][0]] # 23 point, index is 22 + + ratio = (template_off_line[1] - template_18[1]) / (template_off_line[1] - template_23[1]) + kp_18 = off_line_point + (proj_point - off_line_point) * ratio + + result[17] = tuple(kp_18.astype(int)) + total_right_side_count += 1 + all_available_points[17] = tuple(kp_18.astype(int)) # 18 point is now available, index is 17 + + if best_horizontal_line_name == '24-28': + if off_line_idx == 17 or off_line_idx == 20: + template_18 = template_coords[17] # 18 point, index is 17 + template_26 = template_coords[25] # 26 point, index is 25 + template_24 = template_coords[best_horizontal_line_points[0][0]] # 24 , index is 23 + + ratio_26 = (template_26[0] - template_18[0]) / (template_26[0] - template_24[0]) # ratio in x coordinates because y coordinates are the same + + detected_point = off_line_point + (np.array(best_horizontal_line_points[-1][1]) - np.array(best_horizontal_line_points[0][1])) * ratio_26 + + if off_line_idx == 17: + result[25] = tuple(detected_point.astype(int)) + total_right_side_count += 1 + all_available_points[25] = tuple(detected_point.astype(int)) # 26 point is now available, index is 25 + else: + result[28] = tuple(detected_point.astype(int)) + total_right_side_count += 1 + all_available_points[28] = tuple(detected_point.astype(int)) # 29 point is now available, index is 28 + + if off_line_idx == 18 or off_line_idx == 19: # 19 or 20 point is off line point, so we can calculate 18 + template_21 = template_coords[20] # 21 point, index is 20 + template_off_line = template_coords[off_line_idx] + template_24 = template_coords[best_horizontal_line_points[0][0]] # 24 point, index is 23 + + ratio = (template_21[1] - template_off_line[1]) / (template_24[1] - template_off_line[1]) + kp_21 = off_line_point + (proj_point - off_line_point) * ratio + + result[20] = tuple(kp_18.astype(int)) + total_right_side_count += 1 + all_available_points[20] = tuple(kp_18.astype(int)) # 21 point is now available, index is 20 + + if best_horizontal_line_name == '21-29': + if off_line_idx == 22 or off_line_idx == 23: # 23 or 24 point is off line point, so we can calculate 27 or 28 + template_21 = template_coords[best_horizontal_line_points[0][0]] # 21 point, index is 20 + template_29 = template_coords[best_horizontal_line_points[-1][0]] # 29 point, index is 28 + template_23 = template_coords[off_line_idx] # 23 or 24 point, index is 22 or 23 + + ratio_29 = (template_29[0] - template_23[0]) / (template_29[0] - template_21[0]) # ratio in x coordinates because y coordinates are the same + + detected_point = off_line_point + (np.array(best_horizontal_line_points[-1][1]) - np.array(best_horizontal_line_points[0][1])) * ratio_29 + + if off_line_idx == 22: + result[26] = tuple(detected_point.astype(int)) + total_right_side_count += 1 + all_available_points[26] = tuple(detected_point.astype(int)) # 26 point is now available, index is 26 + else: + result[27] = tuple(detected_point.astype(int)) + total_right_side_count += 1 + all_available_points[27] = tuple(detected_point.astype(int)) # 27 point is now available, index is 27 + + # Check if we can now form a vertical line with the newly calculated points + for line_name, (indices, line_type) in line_groups.items(): + if line_type == 'vertical': + line_points = [(idx, all_available_points[idx]) for idx in indices if idx in all_available_points] + if len(line_points) > max_vertical_points: + max_vertical_points = len(line_points) + best_vertical_line_name = line_name + best_vertical_line_points = line_points + + # Calculate keypoint 22 if we have at least one line + if best_vertical_line_name is not None and best_horizontal_line_name is not None: + if kp_22 is None: + print(f"Calculating keypoint 22 using both vertical and horizontal lines: {best_vertical_line_name} and {best_horizontal_line_name}") + + template_x_22 = 940 + template_y_22 = 340 + + # Step 2: Project keypoint 22 onto vertical line (if available) + + template_vertical_start = template_coords[best_vertical_line_points[0][0]] + template_vertical_end = template_coords[best_vertical_line_points[-1][0]] + + # Project at y=340 (same y as keypoint 22) + template_y_vertical_start = template_vertical_start[1] + template_y_vertical_end = template_vertical_end[1] + + if abs(template_y_vertical_end - template_y_vertical_start) > 1e-6: + ratio_22_vertical = (template_y_22 - template_y_vertical_start) / (template_y_vertical_end - template_y_vertical_start) + else: + ratio_22_vertical = 0.5 + + frame_vertical_start = best_vertical_line_points[0][1] + frame_vertical_end = best_vertical_line_points[-1][1] + proj_22_on_vertical_x = frame_vertical_start[0] + (frame_vertical_end[0] - frame_vertical_start[0]) * ratio_22_vertical + proj_22_on_vertical_y = frame_vertical_start[1] + (frame_vertical_end[1] - frame_vertical_start[1]) * ratio_22_vertical + proj_22_on_vertical = (proj_22_on_vertical_x, proj_22_on_vertical_y) + + # Step 3: Project keypoint 22 onto horizontal line (if available) + + template_horizontal_start = template_coords[best_horizontal_line_points[0][0]] + template_horizontal_end = template_coords[best_horizontal_line_points[-1][0]] + + # Project at x=940 (same x as keypoint 22) + template_x_horizontal_start = template_horizontal_start[0] + template_x_horizontal_end = template_horizontal_end[0] + + if abs(template_x_horizontal_end - template_x_horizontal_start) > 1e-6: + ratio_22_horizontal = (template_x_22 - template_x_horizontal_start) / (template_x_horizontal_end - template_x_horizontal_start) + else: + ratio_22_horizontal = 0.5 + + frame_horizontal_start = best_horizontal_line_points[0][1] + frame_horizontal_end = best_horizontal_line_points[-1][1] + proj_22_on_horizontal_x = frame_horizontal_start[0] + (frame_horizontal_end[0] - frame_horizontal_start[0]) * ratio_22_horizontal + proj_22_on_horizontal_y = frame_horizontal_start[1] + (frame_horizontal_end[1] - frame_horizontal_start[1]) * ratio_22_horizontal + proj_22_on_horizontal = (proj_22_on_horizontal_x, proj_22_on_horizontal_y) + + # Step 4: Calculate keypoint 22 as intersection of two lines + # Line 1: Passes through proj_22_on_vertical, parallel to best_horizontal_line + # Line 2: Passes through proj_22_on_horizontal, parallel to best_vertical_line + + # Calculate direction vector of best_horizontal_line + horizontal_dir_x = frame_horizontal_end[0] - frame_horizontal_start[0] + horizontal_dir_y = frame_horizontal_end[1] - frame_horizontal_start[1] + horizontal_dir_length = np.sqrt(horizontal_dir_x**2 + horizontal_dir_y**2) + + # Calculate direction vector of best_vertical_line + vertical_dir_x = frame_vertical_end[0] - frame_vertical_start[0] + vertical_dir_y = frame_vertical_end[1] - frame_vertical_start[1] + vertical_dir_length = np.sqrt(vertical_dir_x**2 + vertical_dir_y**2) + + if horizontal_dir_length > 1e-6 and vertical_dir_length > 1e-6: + # Normalize direction vectors + horizontal_dir_x /= horizontal_dir_length + horizontal_dir_y /= horizontal_dir_length + vertical_dir_x /= vertical_dir_length + vertical_dir_y /= vertical_dir_length + + # Line 1: passes through proj_22_on_vertical with direction of best_horizontal_line + # Parametric: p1 = proj_22_on_vertical + t * horizontal_dir + # Line 2: passes through proj_22_on_horizontal with direction of best_vertical_line + # Parametric: p2 = proj_22_on_horizontal + s * vertical_dir + + # Find intersection: proj_22_on_vertical + t * horizontal_dir = proj_22_on_horizontal + s * vertical_dir + # This gives us: + # proj_22_on_vertical[0] + t * horizontal_dir_x = proj_22_on_horizontal[0] + s * vertical_dir_x + # proj_22_on_vertical[1] + t * horizontal_dir_y = proj_22_on_horizontal[1] + s * vertical_dir_y + + # Rearranging: + # t * horizontal_dir_x - s * vertical_dir_x = proj_22_on_horizontal[0] - proj_22_on_vertical[0] + # t * horizontal_dir_y - s * vertical_dir_y = proj_22_on_horizontal[1] - proj_22_on_vertical[1] + + # Solve for t and s using linear algebra + A = np.array([ + [horizontal_dir_x, -vertical_dir_x], + [horizontal_dir_y, -vertical_dir_y] + ]) + b = np.array([ + proj_22_on_horizontal[0] - proj_22_on_vertical[0], + proj_22_on_horizontal[1] - proj_22_on_vertical[1] + ]) + + try: + t, s = np.linalg.solve(A, b) + + # Calculate intersection point using line 1 + x_22 = proj_22_on_vertical[0] + t * horizontal_dir_x + y_22 = proj_22_on_vertical[1] + t * horizontal_dir_y + + result[21] = (int(round(x_22)), int(round(y_22))) + total_right_side_count += 1 + except np.linalg.LinAlgError: + # Lines are parallel or nearly parallel, use simple intersection + # If lines are parallel, use the projection points directly + x_22 = proj_22_on_vertical[0] + y_22 = proj_22_on_horizontal[1] + result[21] = (int(round(x_22)), int(round(y_22))) + total_right_side_count += 1 + else: + # Fallback: use simple intersection + x_22 = proj_22_on_vertical[0] + y_22 = proj_22_on_horizontal[1] + result[21] = (int(round(x_22)), int(round(y_22))) + total_right_side_count += 1 + + print(f"total_right_side_count: {total_right_side_count}, result: {result}") + if total_right_side_count > 5: + return result + + # Calculate m_line and b_line from best vertical or horizontal line for use in calculating other points + m_line = None + b_line = None + best_line_for_calc = None + best_line_type_for_calc = None + + if best_vertical_line_name is not None and len(best_vertical_line_points) >= 2: + best_line_for_calc = best_vertical_line_points + best_line_type_for_calc = 'vertical' + points_array = np.array([[kp[0], kp[1]] for _, kp in best_vertical_line_points]) + x_coords = points_array[:, 0] + y_coords = points_array[:, 1] + A = np.vstack([x_coords, np.ones(len(x_coords))]).T + m_line, b_line = np.linalg.lstsq(A, y_coords, rcond=None)[0] + elif best_horizontal_line_name is not None and len(best_horizontal_line_points) >= 2: + best_line_for_calc = best_horizontal_line_points + best_line_type_for_calc = 'horizontal' + points_array = np.array([[kp[0], kp[1]] for _, kp in best_horizontal_line_points]) + x_coords = points_array[:, 0] + y_coords = points_array[:, 1] + A = np.vstack([x_coords, np.ones(len(x_coords))]).T + m_line, b_line = np.linalg.lstsq(A, y_coords, rcond=None)[0] + + # Calculate missing points to reach exactly 5 points + # Ensure 4 points aren't all on one line + if total_right_side_count < 5 and (m_line is not None or (best_line_for_calc is not None and best_line_type_for_calc == 'vertical')): + # Check current distribution + counts_per_line = [ + len(line_18_21_points), + len(line_23_24_points), + len(line_25_30_points) + ] + + # Calculate points on line 18-21 if needed + template_ys_18_21 = [140, 270, 410, 540] + template_indices_18_21 = [17, 18, 19, 20] + + if best_vertical_line_name == '25-30': + # Construct parallel line 18-21 from line 25-30 + for template_y, idx in zip(template_ys_18_21, template_indices_18_21): + if result[idx] is None and total_right_side_count < 5: + # Check if adding this point would put 4 on one line + new_counts = counts_per_line.copy() + new_counts[0] += 1 # Adding to line 18-21 + if max(new_counts) >= 4 and total_right_side_count == 4: + # Would have 4 on one line, skip + continue + + # Calculate y using scale from template + ref_ys = [kp[1] for _, kp in line_25_30_points] + ref_template_ys = [5, 140, 250, 430, 540, 675] + ref_indices = [24, 25, 26, 27, 28, 29] + + matched_template_ys = [] + for ref_idx, ref_kp in line_25_30_points: + if ref_idx in ref_indices: + template_idx = ref_indices.index(ref_idx) + matched_template_ys.append((ref_template_ys[template_idx], ref_kp[1])) + + if len(matched_template_ys) >= 1: + ref_template_y, ref_frame_y = matched_template_ys[0] + if ref_template_y > 0: + scale = ref_frame_y / ref_template_y + y_new = int(round(template_y * scale)) + else: + y_new = ref_frame_y + else: + y_new = int(round(np.median(ref_ys))) if ref_ys else template_y + + # Calculate x using parallel line geometry + if abs(m_line) > 1e-6: + x_on_line_25_30 = (y_new - b_line) / m_line + x_new = int(round(x_on_line_25_30 * 0.850)) + else: + x_new = int(round(np.median([kp[0] for _, kp in line_25_30_points]) * 0.850)) + + result[idx] = (x_new, y_new) + total_right_side_count += 1 + if total_right_side_count >= 5: + break + elif best_vertical_line_name == '18-21': + # Calculate missing points on line 18-21 + for template_y, idx in zip(template_ys_18_21, template_indices_18_21): + if result[idx] is None and total_right_side_count < 5: + # Check if adding this point would put 4 on one line + new_counts = counts_per_line.copy() + new_counts[0] += 1 # Adding to line 18-21 + if max(new_counts) >= 4 and total_right_side_count == 4: + # Would have 4 on one line, skip + continue + + # Calculate x on the line + if abs(m_line) > 1e-6: + x_new = (template_y - b_line) / m_line + else: + x_new = np.median([kp[0] for _, kp in line_18_21_points]) + + # Scale y based on available points + ref_ys = [kp[1] for _, kp in line_18_21_points] + ref_template_ys = [] + for ref_idx, _ in line_18_21_points: + if ref_idx in template_indices_18_21: + template_idx = template_indices_18_21.index(ref_idx) + ref_template_ys.append(template_ys_18_21[template_idx]) + + if len(ref_ys) >= 1 and len(ref_template_ys) >= 1: + ref_template_y = ref_template_ys[0] + ref_frame_y = ref_ys[0] + if ref_template_y > 0: + scale = ref_frame_y / ref_template_y + y_new = int(round(template_y * scale)) + else: + y_new = ref_frame_y + else: + y_new = int(round(np.median(ref_ys))) if ref_ys else template_y + + result[idx] = (int(round(x_new)), y_new) + total_right_side_count += 1 + if total_right_side_count >= 5: + break + + # Note: The unified approach above handles all cases (2a and 2b combined) + # Legacy code removed - all logic is now in the unified case 2 above + + return result + +def check_keypoints_would_cause_invalid_mask( + frame_keypoints: list[tuple[int, int]], + template_keypoints: list[tuple[int, int]] = None, + frame: np.ndarray = None, + floor_markings_template: np.ndarray = None, + return_warped_data: bool = False, +) -> tuple[bool, str] | tuple[bool, str, tuple]: + """ + Check if keypoints would cause InvalidMask errors during evaluation. + + Args: + frame_keypoints: Frame keypoints to check + template_keypoints: Template keypoints (defaults to TEMPLATE_KEYPOINTS) + frame: Optional frame image for full validation + floor_markings_template: Optional template image for full validation + + Returns: + Tuple of (would_cause_error, error_message) + """ + try: + from keypoint_evaluation import ( + validate_projected_corners, + TEMPLATE_KEYPOINTS, + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + findHomography, + InvalidMask, + ) + + if template_keypoints is None: + template_keypoints = TEMPLATE_KEYPOINTS + + # Filter valid keypoints + filtered_template = [] + filtered_frame = [] + + for i, (t_kp, f_kp) in enumerate(zip(template_keypoints, frame_keypoints)): + if f_kp[0] > 0 and f_kp[1] > 0: + filtered_template.append(t_kp) + filtered_frame.append(f_kp) + + if len(filtered_template) < 4: + if return_warped_data: + return (True, "Not enough keypoints for homography", None) + return (True, "Not enough keypoints for homography") + + # Compute homography + src_pts = np.array(filtered_template, dtype=np.float32) + dst_pts = np.array(filtered_frame, dtype=np.float32) + + result = findHomography(src_pts, dst_pts) + if result is None: + if return_warped_data: + return (True, "Failed to compute homography", None) + return (True, "Failed to compute homography") + H, _ = result + + # Check for twisted projection (bowtie) + try: + validate_projected_corners( + source_keypoints=template_keypoints, + homography_matrix=H + ) + except Exception as e: + error_msg = "Projection twisted (bowtie)" if "twisted" in str(e).lower() or "Projection twisted" in str(e).lower() else str(e) + if return_warped_data: + return (True, error_msg, None) + return (True, error_msg) + + # If frame and template are provided, check mask validation + if frame is not None and floor_markings_template is not None: + try: + from keypoint_evaluation import ( + project_image_using_keypoints, + extract_masks_for_ground_and_lines, + InvalidMask, + ) + + # project_image_using_keypoints can raise InvalidMask from validate_projected_corners + try: + # start_time = time.time() + warped_template = project_image_using_keypoints( + image=floor_markings_template, + source_keypoints=template_keypoints, + destination_keypoints=frame_keypoints, + destination_width=frame.shape[1], + destination_height=frame.shape[0], + ) + # end_time = time.time() + # print(f"project_image_using_keypoints time: {end_time - start_time} seconds") + except InvalidMask as e: + if return_warped_data: + return (True, f"Projection validation failed: {e}", None) + return (True, f"Projection validation failed: {e}") + except Exception as e: + # Other errors (e.g., ValueError from homography failure) + if return_warped_data: + return (True, f"Projection failed: {e}", None) + return (True, f"Projection failed: {e}") + + # extract_masks_for_ground_and_lines can raise InvalidMask from validation + try: + mask_ground, mask_lines_expected = extract_masks_for_ground_and_lines( + image=warped_template + ) + except InvalidMask as e: + if return_warped_data: + return (True, f"Mask extraction validation failed: {e}", None) + return (True, f"Mask extraction validation failed: {e}") + except Exception as e: + if return_warped_data: + return (True, f"Mask extraction failed: {e}", None) + return (True, f"Mask extraction failed: {e}") + + # Additional explicit validation (though extract_masks_for_ground_and_lines already validates) + from keypoint_evaluation import validate_mask_lines, validate_mask_ground + try: + validate_mask_lines(mask_lines_expected) + except InvalidMask as e: + if return_warped_data: + return (True, f"Mask lines validation failed: {e}", None) + return (True, f"Mask lines validation failed: {e}") + except Exception as e: + if return_warped_data: + return (True, f"Mask lines validation error: {e}", None) + return (True, f"Mask lines validation error: {e}") + + try: + validate_mask_ground(mask_ground) + except InvalidMask as e: + if return_warped_data: + return (True, f"Mask ground validation failed: {e}", None) + return (True, f"Mask ground validation failed: {e}") + except Exception as e: + if return_warped_data: + return (True, f"Mask ground validation error: {e}", None) + return (True, f"Mask ground validation error: {e}") + + # If return_warped_data is True and validation passed, return the computed data + if return_warped_data: + return (False, "", (warped_template, mask_ground, mask_lines_expected)) + + except ImportError: + # If keypoint_evaluation is not available, skip validation + pass + except InvalidMask as e: + # Catch any InvalidMask that wasn't caught above + if return_warped_data: + return (True, f"InvalidMask error: {e}", None) + return (True, f"InvalidMask error: {e}") + except Exception as e: + # If we can't check masks for other reasons, assume it's okay + # Don't let exceptions propagate + pass + + # If we get here, keypoints should be valid + if return_warped_data: + return (False, "", None) # No warped data if frame/template not provided + return (False, "") + + except ImportError: + # If keypoint_evaluation is not available, skip validation + if return_warped_data: + return (False, "", None) + return (False, "") + except Exception as e: + # Any other error - assume it would cause problems + if return_warped_data: + return (True, f"Validation error: {e}", None) + return (True, f"Validation error: {e}") + + +def evaluate_keypoints_with_cached_data( + frame: np.ndarray, + mask_ground: np.ndarray, + mask_lines_expected: np.ndarray, +) -> float: + """ + Evaluate keypoints using pre-computed warped template and masks. + This avoids redundant computation when we already have the warped data from validation. + + Args: + frame: Frame image + mask_ground: Pre-computed ground mask from warped template + mask_lines_expected: Pre-computed expected lines mask from warped template + + Returns: + Score between 0.0 and 1.0 + """ + try: + from keypoint_evaluation import ( + extract_mask_of_ground_lines_in_image, + bitwise_and, + ) + + # Only need to extract predicted lines from frame (uses cached mask_ground) + mask_lines_predicted = extract_mask_of_ground_lines_in_image( + image=frame, ground_mask=mask_ground + ) + + pixels_overlapping = bitwise_and( + mask_lines_expected, mask_lines_predicted + ).sum() + + pixels_on_lines = mask_lines_expected.sum() + + score = pixels_overlapping / (pixels_on_lines + 1e-8) + + return min(1.0, max(0.0, score)) # Clamp to [0, 1] + + except Exception as e: + print(f'Error in cached keypoint evaluation: {e}') + return 0.0 + + +def check_and_evaluate_keypoints( + frame_keypoints: list[tuple[int, int]], + frame: np.ndarray, +) -> tuple[bool, float, str]: + """ + Check if keypoints would cause InvalidMask errors and evaluate them in one call. + This reuses the warped template and masks computed during validation for evaluation. + + Args: + frame_keypoints: Frame keypoints to check and evaluate + frame: Frame image + + Returns: + Tuple of (is_valid, score, error_msg). + - If is_valid is True, score is the evaluation score and error_msg is empty string. + - If is_valid is False, score is 0.0 and error_msg contains the error message. + """ + # Check with return_warped_data=True to get cached data + # start_time = time.time() + check_result = check_keypoints_would_cause_invalid_mask( + frame_keypoints, _TEMPLATE_KEYPOINTS, frame, _TEMPLATE_IMAGE, + return_warped_data=True + ) + # end_time = time.time() + # print(f"check_keypoints_would_cause_invalid_mask time: {end_time - start_time} seconds") + + if len(check_result) == 3: + would_cause_error, error_msg, warped_data = check_result + else: + would_cause_error, error_msg = check_result + warped_data = None + + if would_cause_error: + return (False, 0.0, error_msg) + + # If we have cached warped data, use it for fast evaluation + if warped_data is not None: + _, mask_ground, mask_lines_expected = warped_data + try: + score = evaluate_keypoints_with_cached_data( + frame, mask_ground, mask_lines_expected + ) + return (True, score, "") + except Exception as e: + print(f'Error evaluating with cached data: {e}') + return (True, 0.0, "") + + # Fallback to regular evaluation if no cached data + try: + from keypoint_evaluation import evaluate_keypoints_for_frame + score = evaluate_keypoints_for_frame( + _TEMPLATE_KEYPOINTS, frame_keypoints, frame, _TEMPLATE_IMAGE + ) + return (True, score, "") + except Exception as e: + print(f'Error in regular evaluation: {e}') + return (True, 0.0, "") + + +# ============================================================================ +# MULTIPROCESSING WORKER FUNCTIONS +# ============================================================================ + +def _evaluate_batch_of_candidates(args): + """ + Worker function to evaluate a batch of keypoint candidates. + Uses threading, so we can share the frame/template without pickling overhead. + OpenCV operations are thread-safe for read operations, so no locking needed. + """ + candidate_batch, frame = args + + results = [] + for test_kps, candidate_metadata in candidate_batch: + # Match the exact behavior of sequential evaluation + # Only catch exceptions silently like the sequential version does + try: + if frame is not None and _TEMPLATE_IMAGE is not None: + is_valid, score, _ = check_and_evaluate_keypoints( + test_kps, frame + ) + # Only append valid results with positive scores (matching sequential behavior) + if is_valid: + results.append((is_valid, score, test_kps, candidate_metadata)) + except Exception: + # Silently skip like the sequential version - don't add invalid results + # This matches the original behavior exactly + pass + + return results + + +def evaluate_keypoints_candidates_parallel( + candidate_kps_list: List[List[Tuple[int, int]]], + candidate_metadata: List[Any], + frame: np.ndarray, + num_workers: int = None, +) -> Tuple[bool, float, List[Tuple[int, int]], Any]: + """ + Evaluate multiple keypoint candidates in parallel using threading. + Threading is faster than multiprocessing here because: + 1. OpenCV releases GIL, so threads can run in parallel + 2. No pickling overhead for large arrays (frame, template) + 3. Lower overhead than spawning processes + """ + if len(candidate_kps_list) == 0: + return (False, -1.0, None, None) + + if num_workers is None: + # Cap workers to avoid overhead with too many threads + # Optimal range is typically 8-32 workers depending on workload + # Too many threads cause context switching overhead and contention + # Cap at 32 even if CPU count is higher (e.g., cloud servers with 96+ CPUs) + max_cpu_workers = min(32, cpu_count()) # Cap at 32 to avoid overhead + max_workers = min(max_cpu_workers, len(candidate_kps_list)) + num_workers = max(1, max_workers) + + # For small numbers of candidates, use sequential evaluation + # Threading overhead isn't worth it for very small batches + # Lowered threshold to ensure we don't miss candidates due to batching issues + if len(candidate_kps_list) < 10: + best_result = None + best_score = -1.0 + for test_kps, metadata in zip(candidate_kps_list, candidate_metadata): + try: + is_valid, score, _ = check_and_evaluate_keypoints( + test_kps, frame + ) + if is_valid and score > best_score: + best_score = score + best_result = (is_valid, score, test_kps, metadata) + except Exception: + pass + else: + # Check if we're on Linux - ThreadPoolExecutor doesn't work well with opencv-python-headless + import platform + is_linux = platform.system().lower() == 'linux' + + # Use parallel processing for larger batches + if is_linux: + # Use ProcessPoolExecutor on Linux (multiprocessing) - works because each process has its own GIL + from concurrent.futures import ProcessPoolExecutor, as_completed + else: + # Use ThreadPoolExecutor on Windows/Other (threading) - OpenCV releases GIL + from concurrent.futures import ThreadPoolExecutor, as_completed + + # Split candidates into batches for each worker + # Ensure we process ALL candidates - use ceiling division + batch_size = max(1, (len(candidate_kps_list) + num_workers - 1) // num_workers) + batches = [] + total_candidates_in_batches = 0 + for i in range(0, len(candidate_kps_list), batch_size): + batch = list(zip( + candidate_kps_list[i:i+batch_size], + candidate_metadata[i:i+batch_size] + )) + if len(batch) > 0: # Only add non-empty batches + batches.append((batch, frame)) + total_candidates_in_batches += len(batch) + + # Verify we're processing all candidates + if total_candidates_in_batches != len(candidate_kps_list): + print(f"Warning: Batch mismatch! Expected {len(candidate_kps_list)} candidates, got {total_candidates_in_batches}") + + best_result = None + best_score = -1.0 + + try: + if is_linux: + executor_class = ProcessPoolExecutor + else: + executor_class = ThreadPoolExecutor + + with executor_class(max_workers=num_workers) as executor: + futures = [executor.submit(_evaluate_batch_of_candidates, args) for args in batches] + + all_results = [] + for future in as_completed(futures): + try: + batch_results = future.result() + if batch_results: # Only extend if we have results + all_results.extend(batch_results) + except Exception as e: + # Log but continue processing other batches + print(f"Error processing batch result: {e}") + import traceback + traceback.print_exc() + pass + + # Debug: Check if we got results from all batches + if len(all_results) == 0: + print(f"Warning: No valid results from parallel evaluation of {len(candidate_kps_list)} candidates") + + # Process all results and find the best one + # This ensures we compare ALL candidates, not just within batches + # Match the exact logic from sequential evaluation + for result in all_results: + if result is not None: + is_valid, score, test_kps, metadata = result + # Ensure score is numeric for comparison + try: + score = float(score) if score is not None else 0.0 + except (ValueError, TypeError): + score = 0.0 + # Match sequential evaluation: only update if valid and score is better + if is_valid and score > best_score: + best_score = score + best_result = (is_valid, score, test_kps, metadata) + except Exception as e: + print(f"Threading evaluation failed: {e}, falling back to sequential") + for test_kps, metadata in zip(candidate_kps_list, candidate_metadata): + try: + is_valid, score, _ = check_and_evaluate_keypoints( + test_kps, frame + ) + if is_valid and score > best_score: + best_score = score + best_result = (is_valid, score, test_kps, metadata) + except Exception: + pass + + if best_result is not None: + return best_result + + return (False, -1.0, None, None) + + +def _process_single_frame_for_prediction(args): + """ + Worker function to process a single frame for failed index prediction. + Returns: (frame_index, score, adjusted_success) + - score: evaluation score of the calculated keypoints (0.0 if failed or invalid) + - adjusted_success: True if keypoints were successfully adjusted, False otherwise + """ + frame_index, frame_result, frame_width, frame_height, frame_image, offset = args + + try: + from keypoint_helper_v2_optimized import ( + remove_duplicate_detections, + calculate_missing_keypoints, + adjust_keypoints_to_avoid_invalid_mask, + ) + + frame_keypoints = getattr(frame_result, "keypoints", []) or [] + original_count = sum(1 for (x, y) in frame_keypoints if int(x) != 0 and int(y) != 0) + + cleaned_keypoints = remove_duplicate_detections( + frame_keypoints, frame_width, frame_height + ) + + valid_keypoint_indices = [idx for idx, kp in enumerate(cleaned_keypoints) if kp[0] != 0 and kp[1] != 0] + + if len(valid_keypoint_indices) > 5: + calculated_keypoints = cleaned_keypoints + else: + left_side_indices_range = range(0, 13) + right_side_indices_range = range(17, 30) + + side_check_set = set() + if len(valid_keypoint_indices) >= 4: + for idx in valid_keypoint_indices: + if idx in left_side_indices_range: + side_check_set.add("left") + elif idx in right_side_indices_range: + side_check_set.add("right") + else: + side_check_set.add("center") + + if len(side_check_set) > 1: + calculated_keypoints = cleaned_keypoints + else: + calculated_keypoints = calculate_missing_keypoints( + cleaned_keypoints, frame_width, frame_height + ) + + original_frame_number = offset + frame_index + print(f"Frame {original_frame_number} (index {frame_index}): original_count: {original_count}, cleaned_keypoints: {len([kp for kp in cleaned_keypoints if kp[0] != 0 and kp[1] != 0])}, calculated_keypoints: {len([kp for kp in calculated_keypoints if kp[0] != 0 and kp[1] != 0])}") + + start_time = time.time() + adjusted_success, calculated_keypoints, score = adjust_keypoints_to_avoid_invalid_mask( + calculated_keypoints, frame_image + ) + end_time = time.time() + print(f"adjust_keypoints_to_avoid_invalid_mask time: {end_time - start_time} seconds") + + if not adjusted_success: + return (frame_index, 0.0, False) # Failed, score is 0.0 + + print(f"after adjustment, calculated_keypoints: {calculated_keypoints}, score: {score:.4f}") + setattr(frame_result, "keypoints", list(calculated_keypoints)) + + return (frame_index, score, True) # Success with score + except Exception as e: + print(f"Error processing frame {frame_index}: {e}") + return (frame_index, 0.0, False) # Failed on error, score is 0.0 + + +def _generate_sparse_keypoints_for_frame(args): + """ + Worker function to generate sparse keypoints for a single frame. + Returns: (frame_index, sparse_keypoints) + """ + frame_index, frame_width, frame_height, frame_image = args + + try: + from keypoint_helper_v2_optimized import ( + _generate_sparse_template_keypoints, + ) + + sparse_keypoints = _generate_sparse_template_keypoints( + frame_width, + frame_height, + frame_image=frame_image, + ) + + return (frame_index, sparse_keypoints) + except Exception as e: + print(f"Error generating sparse keypoints for frame {frame_index}: {e}") + # Return empty keypoints on error + return (frame_index, [(0, 0)] * 32) + + +def _evaluate_keypoints_for_frame(args): + """ + Worker function to evaluate both sparse and calculated keypoints for a single frame. + Returns: (frame_index, sparse_score, calculated_score, sparse_keypoints, calculated_keypoints) + """ + frame_index, sparse_keypoints, calculated_keypoints, frame_image, pre_calculated_score = args + + sparse_score = 0.0 + calculated_score = 0.0 + + # Use pre-calculated score if available (from _process_single_frame_for_prediction) + if pre_calculated_score is not None and pre_calculated_score > 0.0: + calculated_score = pre_calculated_score + print(f"Frame {frame_index}: Using pre-calculated score: {calculated_score:.4f}") + else: + # Need to evaluate calculated keypoints + calculated_score = 0.0 + + try: + from keypoint_evaluation import evaluate_keypoints_for_frame + + # Evaluate sparse keypoints + if frame_image is not None and _TEMPLATE_IMAGE is not None and _TEMPLATE_KEYPOINTS is not None: + try: + sparse_score = evaluate_keypoints_for_frame( + template_keypoints=_TEMPLATE_KEYPOINTS, + frame_keypoints=sparse_keypoints, + frame=frame_image, + floor_markings_template=_TEMPLATE_IMAGE, + ) + except Exception: + sparse_score = 0.0 + + # Evaluate calculated keypoints only if not pre-calculated + if pre_calculated_score is None or pre_calculated_score <= 0.0: + calculated_keypoints_valid = len([kp for kp in calculated_keypoints if kp[0] != 0 or kp[1] != 0]) >= 4 + if calculated_keypoints_valid: + try: + calculated_score = evaluate_keypoints_for_frame( + template_keypoints=_TEMPLATE_KEYPOINTS, + frame_keypoints=calculated_keypoints, + frame=frame_image, + floor_markings_template=_TEMPLATE_IMAGE, + ) + except Exception: + calculated_score = 0.0 + else: + calculated_score = -1.0 + except Exception as e: + print(f"Error evaluating keypoints for frame {frame_index}: {e}") + + return (frame_index, sparse_score, calculated_score, sparse_keypoints, calculated_keypoints) + +def _calculate_keypoints_score( + keypoints: list[tuple[int, int]], + frame: np.ndarray, +) -> float: + """ + Helper function to calculate score for keypoints. + Returns 0.0 if evaluation fails or keypoints are invalid. + """ + score = 0.0 + try: + from keypoint_evaluation import evaluate_keypoints_for_frame + + # Check if keypoints are valid (at least 4 non-zero keypoints) + keypoints_valid = len([kp for kp in keypoints if kp[0] != 0 or kp[1] != 0]) >= 4 + if keypoints_valid and frame is not None and _TEMPLATE_IMAGE is not None and _TEMPLATE_KEYPOINTS is not None: + try: + score = evaluate_keypoints_for_frame( + template_keypoints=_TEMPLATE_KEYPOINTS, + frame_keypoints=keypoints, + frame=frame, + floor_markings_template=_TEMPLATE_IMAGE, + ) + except Exception: + score = 0.0 + except Exception: + score = 0.0 + + return score + + +def adjust_keypoints_to_avoid_invalid_mask( + frame_keypoints: list[tuple[int, int]], + frame: np.ndarray = None, + max_iterations: int = 5, + num_workers: int = None, +) -> tuple[bool, list[tuple[int, int]], float]: + """ + Adjust keypoints to avoid InvalidMask errors. + + This function tries to fix common issues: + 1. Twisted projection (bowtie) - adjusts corner keypoints + 2. Ground covers too much - shrinks projected area by moving corners inward + 3. Other mask validation issues - adjusts keypoints to improve projection + + Args: + frame_keypoints: Frame keypoints to adjust + frame: Optional frame image for validation + max_iterations: Maximum number of adjustment iterations + num_workers: Number of workers for parallel evaluation + + Returns: + Tuple of (success, adjusted_keypoints, score): + - success: True if keypoints were successfully adjusted, False otherwise + - adjusted_keypoints: Adjusted keypoints + - score: Evaluation score of the adjusted keypoints (0.0 if failed or invalid) + """ + adjusted = list(frame_keypoints) + + # Check if adjustment is needed and evaluate score in one call + # This reuses warped data from validation for efficient evaluation + error_msg = "" + would_cause_error = False + + is_valid, score, error_msg = check_and_evaluate_keypoints( + adjusted, frame + ) + if is_valid: + return (True, adjusted, score) + # error_msg is already available from check_and_evaluate_keypoints + would_cause_error = True # Keypoints are invalid + + + print(f"Would cause error: {would_cause_error}, error_msg: {error_msg}") + + # Try to fix twisted projection (most common issue) + if "twisted" in error_msg.lower() or "bowtie" in error_msg.lower() or "Projection twisted" in error_msg.lower(): + # Use the existing _adjust_keypoints_to_pass_validation function + adjusted = _adjust_keypoints_to_pass_validation( + adjusted, + frame.shape[1] if frame is not None else None, + frame.shape[0] if frame is not None else None + ) + + # Check again after adjustment and evaluate score + if frame is not None and _TEMPLATE_IMAGE is not None and _TEMPLATE_KEYPOINTS is not None: + is_valid, score, error_msg = check_and_evaluate_keypoints( + adjusted, frame + ) + if is_valid: + return (True, adjusted, score) + # error_msg is already available from check_and_evaluate_keypoints + else: + would_cause_error, error_msg = check_keypoints_would_cause_invalid_mask( + adjusted, _TEMPLATE_KEYPOINTS, frame, _TEMPLATE_IMAGE + ) + if not would_cause_error: + score = 0.0 + return (True, adjusted, score) + + # Handle "a projected line is too wide" error + # This happens when projected lines are too thick/wide (aspect ratio too high) + if "too wide" in error_msg.lower() or "wide line" in error_msg.lower(): + print(f"Adjusting keypoints to fix 'a projected line is too wide' error") + try: + # This error usually means the projection is creating lines that are too thick + # Strategy: Adjust keypoints to reduce projection distortion + + valid_keypoints = [] + for idx in range(len(adjusted)): + x, y = adjusted[idx] + if x == 0 and y == 0: + continue + valid_keypoints.append((idx, x, y)) + + if len(valid_keypoints) >= 4: + # Calculate center and spread of keypoints + center_x = sum(x for _, x, y in valid_keypoints) / len(valid_keypoints) + center_y = sum(y for _, x, y in valid_keypoints) / len(valid_keypoints) + + # Calculate distances from center + distances = [] + for idx, x, y in valid_keypoints: + dist = np.sqrt((x - center_x)**2 + (y - center_y)**2) + distances.append((idx, x, y, dist)) + + # Sort by distance + distances.sort(key=lambda d: d[3], reverse=True) + + # Strategy 1: Try moving keypoints slightly outward to reduce compression + # This can help if keypoints are too close together causing wide lines + best_wide_kps = None + best_wide_score = -1.0 + + # Collect all candidate keypoints first, then evaluate in parallel + candidate_kps_list = [] + candidate_metadata = [] + + # Strategy 1: Try expanding keypoints slightly (opposite of shrinking) + # Reduced from 4 to 2 candidates for faster computation + for expand_factor in [1.05, 1.10]: + test_kps = list(adjusted) + for idx, x, y, dist in distances: + new_x = int(round(center_x + (x - center_x) * expand_factor)) + new_y = int(round(center_y + (y - center_y) * expand_factor)) + test_kps[idx] = (new_x, new_y) + + # Add directly - validation and evaluation will be done in parallel + candidate_kps_list.append(test_kps) + candidate_metadata.append(('expand', expand_factor)) + + # Strategy 2: Try adjusting individual keypoints (only top 2 farthest, reduced adjustments) + # Reduced from 6x6=36 per keypoint to 3x3=9, and only test top 2 keypoints + for idx, x, y, dist in distances[:2]: + for adjust_x in [-2, 0, 2]: + for adjust_y in [-2, 0, 2]: + if adjust_x == 0 and adjust_y == 0: + continue # Skip no-op + test_kps = list(adjusted) + test_kps[idx] = (x + adjust_x, y + adjust_y) + + # Add directly - validation and evaluation will be done in parallel + candidate_kps_list.append(test_kps) + candidate_metadata.append(('perturb', idx, adjust_x, adjust_y)) + + # Strategy 3: Try slight shrinking (opposite approach - reduce projection area) + # Reduced from 3 to 2 candidates + for shrink_factor in [0.96, 0.94]: + test_kps = list(adjusted) + for idx, x, y, dist in distances: + new_x = int(round(center_x + (x - center_x) * shrink_factor)) + new_y = int(round(center_y + (y - center_y) * shrink_factor)) + test_kps[idx] = (new_x, new_y) + + # Add directly - validation and evaluation will be done in parallel + candidate_kps_list.append(test_kps) + candidate_metadata.append(('shrink', shrink_factor)) + + # Evaluate all candidates in parallel + if len(candidate_kps_list) > 0: + print(f"Evaluating {len(candidate_kps_list)} wide-line candidates in parallel...") + eval_start = time.time() + is_valid, score, best_kps, best_meta = evaluate_keypoints_candidates_parallel( + candidate_kps_list, candidate_metadata, + frame, num_workers + ) + eval_time = time.time() - eval_start + print(f"Parallel evaluation took {eval_time:.2f} seconds for {len(candidate_kps_list)} candidates") + + if is_valid and score > best_wide_score: + best_wide_score = score + best_wide_kps = best_kps + print(f"Found best wide-line adjustment: {best_meta}, score: {score:.4f}") + + if best_wide_kps is not None: + # Score is already calculated in evaluate_keypoints_candidates_parallel + return (True, best_wide_kps, best_wide_score) + except Exception as e: + print(f"Error in wide line adjustment: {e}") + pass + + # Handle "projected ground should be a single object" error + # This happens when the ground mask has multiple disconnected regions + if "should be a single" in error_msg.lower() or "single object" in error_msg.lower() or "distinct regions" in error_msg.lower(): + print(f"Adjusting keypoints to fix 'projected ground should be a single object' error (optimized)") + try: + valid_keypoints = [] + for idx in range(len(adjusted)): + x, y = adjusted[idx] + if x == 0 and y == 0: + continue + valid_keypoints.append((idx, x, y)) + + if len(valid_keypoints) >= 4: + center_x = sum(x for _, x, y in valid_keypoints) / len(valid_keypoints) + center_y = sum(y for _, x, y in valid_keypoints) / len(valid_keypoints) + + candidate_kps_list = [] + candidate_metadata = [] + + # Strategy 1: Move keypoints closer to center + # Reduced from 5 to 3 candidates + for shrink_factor in [0.96, 0.92, 0.90]: + test_kps = list(adjusted) + for idx, x, y in valid_keypoints: + new_x = int(round(center_x + (x - center_x) * shrink_factor)) + new_y = int(round(center_y + (y - center_y) * shrink_factor)) + test_kps[idx] = (new_x, new_y) + + # Add directly - validation and evaluation will be done in parallel + candidate_kps_list.append(test_kps) + candidate_metadata.append(('shrink', shrink_factor)) + + # Strategy 2: Adjust boundary keypoints + distances = [] + for idx, x, y in valid_keypoints: + dist = np.sqrt((x - center_x)**2 + (y - center_y)**2) + distances.append((idx, x, y, dist)) + distances.sort(key=lambda d: d[3], reverse=True) + + # Reduced from 3 to 2 candidates + for shrink_factor in [0.90, 0.85]: + test_kps = list(adjusted) + boundary_count = max(1, len(distances) // 4) + for idx, x, y, dist in distances[:boundary_count]: + new_x = int(round(center_x + (x - center_x) * shrink_factor)) + new_y = int(round(center_y + (y - center_y) * shrink_factor)) + test_kps[idx] = (new_x, new_y) + + # Add directly - validation and evaluation will be done in parallel + candidate_kps_list.append(test_kps) + candidate_metadata.append(('boundary', shrink_factor)) + + # Evaluate all candidates in parallel + if len(candidate_kps_list) > 0: + print(f"Evaluating {len(candidate_kps_list)} single-object candidates in parallel...") + eval_start = time.time() + is_valid, score, best_kps, best_meta = evaluate_keypoints_candidates_parallel( + candidate_kps_list, candidate_metadata, + frame, num_workers + ) + eval_time = time.time() - eval_start + print(f"Parallel evaluation took {eval_time:.2f} seconds for {len(candidate_kps_list)} candidates") + + if is_valid: + print(f"Found best single-object adjustment: {best_meta}, score: {score:.4f}") + # Score is already calculated in evaluate_keypoints_candidates_parallel + return (True, best_kps, score) + except Exception as e: + print(f"Error in optimized single object adjustment: {e}") + pass + + # Handle "ground covers too much" error by shrinking the projected area + if "ground covers" in error_msg.lower() or "covers more than" in error_msg.lower(): + print(f"Adjusting keypoints to avoid 'ground covers too much' error") + try: + from keypoint_evaluation import ( + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + ) + + # First, try adjusting corners if available + corner_indices = [ + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + ] + + # Get corner keypoints + corners = [] + center_x, center_y = 0, 0 + valid_corners = 0 + + for corner_idx in corner_indices: + if corner_idx < len(adjusted): + x, y = adjusted[corner_idx] + if x == 0 and y == 0: + continue + corners.append((corner_idx, x, y)) + center_x += x + center_y += y + valid_corners += 1 + + if valid_corners >= 4: + center_x /= valid_corners + center_y /= valid_corners + + candidate_kps_list = [] + candidate_metadata = [] + + # Move corners inward + # Reduced from 7 to 4 candidates for faster computation + for shrink_factor in [0.90, 0.85, 0.75, 0.65]: + test_kps = list(adjusted) + for corner_idx, x, y in corners: + new_x = int(round(center_x + (x - center_x) * shrink_factor)) + new_y = int(round(center_y + (y - center_y) * shrink_factor)) + test_kps[corner_idx] = (new_x, new_y) + + # Add directly - validation and evaluation will be done in parallel + candidate_kps_list.append(test_kps) + candidate_metadata.append(('corner', shrink_factor)) + + # Evaluate all candidates in parallel + if len(candidate_kps_list) > 0: + print(f"Evaluating {len(candidate_kps_list)} corner adjustment candidates in parallel...") + eval_start = time.time() + is_valid, score, best_kps, best_meta = evaluate_keypoints_candidates_parallel( + candidate_kps_list, candidate_metadata, + frame, num_workers + ) + eval_time = time.time() - eval_start + print(f"Parallel evaluation took {eval_time:.2f} seconds for {len(candidate_kps_list)} candidates") + + if is_valid: + print(f"Found best corner adjustment: {best_meta}, score: {score:.4f}") + # Score is already calculated in evaluate_keypoints_candidates_parallel + return (True, best_kps, score) + + # If corners adjustment didn't work or we don't have enough corners, + # try adjusting individual keypoints one at a time + # This handles cases where non-corner keypoints (like 15, 16, 17, 31, 32) are causing the issue + valid_keypoints = [] + all_center_x, all_center_y = 0, 0 + valid_count = 0 + + for idx in range(len(adjusted)): + x, y = adjusted[idx] + if x == 0 and y == 0: + continue + valid_keypoints.append((idx, x, y)) + all_center_x += x + all_center_y += y + valid_count += 1 + + if valid_count >= 4: + all_center_x /= valid_count + all_center_y /= valid_count + + # Calculate distances from center for each keypoint + # Try adjusting keypoints farthest from center first (most likely to cause coverage issues) + distances = [] + for idx, x, y in valid_keypoints: + dist = np.sqrt((x - all_center_x)**2 + (y - all_center_y)**2) + distances.append((idx, x, y, dist)) + + # Sort by distance (farthest first) - these are most likely causing the coverage issue + distances.sort(key=lambda d: d[3], reverse=True) + + # Collect all candidate keypoints for parallel evaluation + candidate_kps_list = [] + candidate_metadata = [] + + # Try adjusting each keypoint individually + # Reduced: only test top 3 farthest keypoints, and reduce shrink factors from 9 to 4 + for idx, x, y, dist in distances[:3]: + for shrink_factor in [0.95, 0.90, 0.80, 0.70]: + test_kps = list(adjusted) + new_x = int(round(all_center_x + (x - all_center_x) * shrink_factor)) + new_y = int(round(all_center_y + (y - all_center_y) * shrink_factor)) + test_kps[idx] = (new_x, new_y) + + # Add directly - validation and evaluation will be done in parallel + candidate_kps_list.append(test_kps) + candidate_metadata.append(('individual', idx, shrink_factor)) + + # Try adjusting pairs + # Reduced from 6 to 3 candidates + if valid_count >= 6: + for shrink_factor in [0.90, 0.80, 0.70]: + test_kps = list(adjusted) + for idx, x, y, dist in distances[:2]: + new_x = int(round(all_center_x + (x - all_center_x) * shrink_factor)) + new_y = int(round(all_center_y + (y - all_center_y) * shrink_factor)) + test_kps[idx] = (new_x, new_y) + + # Add directly - validation and evaluation will be done in parallel + candidate_kps_list.append(test_kps) + candidate_metadata.append(('pair', shrink_factor)) + + # Evaluate all candidates in parallel + if len(candidate_kps_list) > 0: + print(f"Evaluating {len(candidate_kps_list)} ground-coverage candidates in parallel...") + eval_start = time.time() + is_valid, score, best_kps, best_meta = evaluate_keypoints_candidates_parallel( + candidate_kps_list, candidate_metadata, + frame, num_workers + ) + eval_time = time.time() - eval_start + print(f"Parallel evaluation took {eval_time:.2f} seconds for {len(candidate_kps_list)} candidates") + + if is_valid: + print(f"Found best ground-coverage adjustment: {best_meta}, score: {score:.4f}") + # Score is already calculated in evaluate_keypoints_candidates_parallel + return (True, best_kps, score) + except Exception as e: + print(f"Error in ground coverage adjustment: {e}") + pass + + # If still causing errors, try small perturbations to corner keypoints + # This helps with mask validation issues + if would_cause_error and max_iterations > 0: + try: + from keypoint_evaluation import ( + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + ) + + corner_indices = [ + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + ] + + # Collect all corner perturbation candidates + candidate_kps_list = [] + candidate_metadata = [] + + # Reduced corner perturbations: from 6x6=36 per corner to 3x3=9 per corner + # Also skip (0,0) to avoid no-op + for corner_idx in corner_indices: + if corner_idx < len(adjusted): + x, y = adjusted[corner_idx] + if x == 0 and y == 0: + continue + for dx in [-3, 0, 3]: + for dy in [-3, 0, 3]: + if dx == 0 and dy == 0: + continue # Skip no-op + test_kps = list(adjusted) + test_kps[corner_idx] = (x + dx, y + dy) + + # Add directly - validation and evaluation will be done in parallel + candidate_kps_list.append(test_kps) + candidate_metadata.append(('corner_perturb', corner_idx, dx, dy)) + + # Evaluate all candidates in parallel + if len(candidate_kps_list) > 0: + print(f"Evaluating {len(candidate_kps_list)} corner perturbation candidates in parallel...") + eval_start = time.time() + is_valid, score, best_kps, best_meta = evaluate_keypoints_candidates_parallel( + candidate_kps_list, candidate_metadata, + frame, num_workers + ) + eval_time = time.time() - eval_start + print(f"Parallel evaluation took {eval_time:.2f} seconds for {len(candidate_kps_list)} candidates") + + if is_valid: + print(f"Found best corner perturbation: {best_meta}, score: {score:.4f}") + # Score is already calculated in evaluate_keypoints_candidates_parallel + return (True, best_kps, score) + except Exception: + pass + + # If we can't fix it, return adjusted (best effort) with score 0.0 + score = _calculate_keypoints_score(adjusted, frame) + return (False, adjusted, score) + + +def _validate_keypoints_corners( + frame_keypoints: list[tuple[int, int]], + template_keypoints: list[tuple[int, int]] = None, +) -> bool: + """ + Validate that frame keypoints can form a valid homography with template keypoints + (corners don't create twisted projection). + + Returns True if validation passes, False otherwise. + """ + try: + from keypoint_evaluation import ( + validate_projected_corners, + TEMPLATE_KEYPOINTS, + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + ) + + # Use provided template_keypoints or default TEMPLATE_KEYPOINTS + if template_keypoints is None: + template_keypoints = TEMPLATE_KEYPOINTS + + # Filter valid keypoints (non-zero) + filtered_template = [] + filtered_frame = [] + + for i, (t_kp, f_kp) in enumerate(zip(template_keypoints, frame_keypoints)): + if f_kp[0] > 0 and f_kp[1] > 0: # Frame keypoint is valid + filtered_template.append(t_kp) + filtered_frame.append(f_kp) + + if len(filtered_template) < 4: + return False # Not enough keypoints for homography + + # Compute homography from template to frame + src_pts = np.array(filtered_template, dtype=np.float32) + dst_pts = np.array(filtered_frame, dtype=np.float32) + + H, mask = cv2.findHomography(src_pts, dst_pts) + + if H is None: + return False # Homography computation failed + + # Validate corners using the homography + try: + validate_projected_corners( + source_keypoints=template_keypoints, + homography_matrix=H + ) + return True # Validation passed + except Exception: + return False # Validation failed (twisted projection) + + except ImportError: + # If keypoint_evaluation is not available, skip validation + return True + except Exception: + # Any other error - assume invalid + return False + +def calculate_and_adjust_keypoints( + results_frames: Sequence[Any], + frame_width: int = None, + frame_height: int = None, + frames: List[np.ndarray] = None, + offset: int = 0, + num_workers: int = None, +) -> list[tuple[int, float, bool]]: + """ + Calculate missing keypoints, adjust them to avoid invalid masks, and evaluate scores. + Processes frames in parallel using threading. + + For each frame: + 1. Calculates missing keypoints if needed + 2. Adjusts keypoints to avoid InvalidMask errors + 3. Evaluates the adjusted keypoints and calculates a score + + Args: + results_frames: Sequence of frame results with keypoints + frame_width: Frame width + frame_height: Frame height + frames: Optional list of frame images for validation + offset: Frame offset for tracking + num_workers: Number of worker threads (defaults to cpu_count()) + + Returns: + List of tuples (frame_index, score, adjusted_success) for all frames: + - frame_index: Index of the frame + - score: Evaluation score of the adjusted keypoints (0.0 if failed) + - adjusted_success: True if keypoints were successfully adjusted, False otherwise + """ + max_frames = len(results_frames) + if max_frames == 0: + return [] + + if num_workers is None: + # Cap workers to avoid overhead with too many threads + # Optimal range is typically 8-32 workers depending on workload + # Too many threads cause context switching overhead and contention + # Cap at 32 even if CPU count is higher (e.g., cloud servers with 96+ CPUs) + max_cpu_workers = min(32, cpu_count()) # Cap at 32 to avoid overhead + max_workers = min(max_cpu_workers, max_frames) + num_workers = max(1, max_workers) + + # Prepare arguments for each frame + # Note: With spawn method, each worker will pickle/unpickle the data anyway + # So we pass references - copying here would be redundant + args_list = [] + for frame_index, frame_result in enumerate(results_frames): + frame_image = None + if frames is not None and frame_index < len(frames): + frame_image = frames[frame_index] # Pass reference, will be copied by pickle + + args_list.append(( + frame_index, frame_result, frame_width, frame_height, + frame_image, offset + )) + + results = [] + + # Check if we're on Linux - ThreadPoolExecutor doesn't work well with opencv-python-headless + # OpenCV headless doesn't release GIL properly on Linux, so use ProcessPoolExecutor instead + import platform + is_linux = platform.system().lower() == 'linux' + + # Use parallel processing for larger batches + if max_frames >= 4 and num_workers > 1: + try: + if is_linux: + # Use ProcessPoolExecutor on Linux (multiprocessing) - works because each process has its own GIL + from concurrent.futures import ProcessPoolExecutor, as_completed + print(f"Linux detected: Processing {max_frames} frames in parallel using {num_workers} processes (ProcessPoolExecutor)...") + with ProcessPoolExecutor(max_workers=num_workers) as executor: + futures = {executor.submit(_process_single_frame_for_prediction, args): args for args in args_list} + + for future in as_completed(futures): + try: + frame_index, score, adjusted_success = future.result() + results.append((frame_index, score, adjusted_success)) + except Exception as e: + print(f"Error getting result from worker: {e}") + # If we can't get the result, mark as failed with score 0.0 + args = futures[future] + frame_index = args[0] + results.append((frame_index, 0.0, False)) + else: + # Use ThreadPoolExecutor on Windows/Other (threading) - OpenCV releases GIL + from concurrent.futures import ThreadPoolExecutor, as_completed + print(f"Processing {max_frames} frames in parallel using {num_workers} workers (ThreadPoolExecutor)...") + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = {executor.submit(_process_single_frame_for_prediction, args): args for args in args_list} + + for future in as_completed(futures): + try: + frame_index, score, adjusted_success = future.result() + results.append((frame_index, score, adjusted_success)) + except Exception as e: + print(f"Error getting result from worker: {e}") + # If we can't get the result, mark as failed with score 0.0 + args = futures[future] + frame_index = args[0] + results.append((frame_index, 0.0, False)) + except Exception as e: + print(f"Parallel processing failed: {e}, falling back to sequential") + # Fallback to sequential + for args in args_list: + try: + frame_index, score, adjusted_success = _process_single_frame_for_prediction(args) + results.append((frame_index, score, adjusted_success)) + except Exception as e: + print(f"Error processing frame: {e}") + # Mark as failed if exception occurs + frame_index = args[0] + results.append((frame_index, 0.0, False)) + else: + # Sequential processing for small batches + for args in args_list: + try: + frame_index, score, adjusted_success = _process_single_frame_for_prediction(args) + results.append((frame_index, score, adjusted_success)) + except Exception as e: + print(f"Error processing frame: {e}") + # Mark as failed if exception occurs + frame_index = args[0] + results.append((frame_index, 0.0, False)) + + # Sort results by frame_index to ensure consistent ordering + results.sort(key=lambda x: x[0]) + return results + +def _generate_sparse_template_keypoints( + frame_width: int, + frame_height: int, + frame_image: np.ndarray = None, +) -> list[tuple[int, int]]: + # Use cached template dimensions for performance + template_max_x = _TEMPLATE_MAX_X + template_max_y = _TEMPLATE_MAX_Y + + # Calculate scaling factors for both dimensions + sx = float(frame_width) / float(template_max_x if template_max_x != 0 else 1) + sy = float(frame_height) / float(template_max_y if template_max_y != 0 else 1) + + # Always use uniform scaling to preserve pitch geometry and aspect ratio + # This prevents distortion that creates square contours (like 3x3, 4x4) which fail the wide line check + # Uniform scaling ensures the pitch maintains its shape and avoids twisted projections + uniform_scale = min(sx, sy) + + # Scale down significantly to create a much smaller pitch in the warped template + # Use a small fraction of the uniform scale to make the pitch as small as possible + # This creates a small pitch centered in the frame, avoiding edge artifacts + scale_factor = 0.15 # Use 15% of the frame-filling scale to make pitch much smaller + uniform_scale = uniform_scale * scale_factor + + # Ensure minimum scale to avoid keypoints being too close together + # Very small scales cause warping artifacts that create square contours (1x1, 2x2 pixels) + # These single-pixel artifacts trigger the "too wide" error + # Use a fixed minimum scale based on template dimensions to ensure keypoints are spaced properly + # This prevents warping artifacts regardless of frame size + # Template is 1045x675, need sufficient scale to avoid 1x1 pixel artifacts from warping + # Higher minimum scale ensures warped template doesn't create tiny square artifacts + min_scale_absolute = 0.3 # Fixed minimum 30% of template size to avoid 1x1 pixel warping artifacts + # Higher scale is necessary to prevent warping interpolation from creating single-pixel squares + uniform_scale = max(uniform_scale, min_scale_absolute) + + # Use only corner keypoints for sparse template + # Get corner indices from keypoint_evaluation + try: + from keypoint_evaluation import ( + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + ) + selected_keypoint_indices = set([ + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + ]) + except ImportError: + # Fallback to default corner indices if import fails + # Based on typical template: top-left=0, top-right=24, bottom-right=29, bottom-left=5 + selected_keypoint_indices = set([0, 24, 29, 5]) # Default corner indices + + line_distribution = None # Will store: (total_count, best_region_center, max_density) + + # If we have line distribution analysis, select appropriate keypoints + if frame_image is not None and _TEMPLATE_IMAGE is not None and _TEMPLATE_KEYPOINTS is not None: + try: + from keypoint_evaluation import ( + project_image_using_keypoints, + extract_masks_for_ground_and_lines_no_validation, + extract_mask_of_ground_lines_in_image + ) + + # Generate initial keypoints for analysis using EXACT FITTING (full frame coverage) + # This ensures we get correct line distribution analysis + # Use non-uniform scaling to fit exactly to frame dimensions + # Optimized with NumPy for better performance + initial_sx = float(frame_width) / float(template_max_x if template_max_x != 0 else 1) + initial_sy = float(frame_height) / float(template_max_y if template_max_y != 0 else 1) + num_template_kps = len(_TEMPLATE_KEYPOINTS) if _TEMPLATE_KEYPOINTS is not None else 32 + num_kps = max(32, num_template_kps) # Ensure we have at least 32 keypoints + + # Use NumPy for vectorized scaling + if _TEMPLATE_KEYPOINTS is not None and len(_TEMPLATE_KEYPOINTS) >= num_kps: + template_array = np.array(_TEMPLATE_KEYPOINTS[:num_kps], dtype=np.float32) + else: + # Pad with zeros if needed + template_array = np.zeros((num_kps, 2), dtype=np.float32) + if _TEMPLATE_KEYPOINTS is not None: + template_array[:len(_TEMPLATE_KEYPOINTS)] = _TEMPLATE_KEYPOINTS + + # Vectorized scaling and clamping + scaled_array = template_array.copy() + scaled_array[:, 0] = np.clip(np.round(template_array[:, 0] * initial_sx), 0, frame_width - 1) + scaled_array[:, 1] = np.clip(np.round(template_array[:, 1] * initial_sy), 0, frame_height - 1) + + # Set zero keypoints to (0, 0) + mask = (template_array[:, 0] <= 0) | (template_array[:, 1] <= 0) + scaled_array[mask] = 0 + + # Convert to list of tuples + initial_scaled = [(int(x), int(y)) for x, y in scaled_array] + + # With exact fitting, keypoints already fill the frame, no centering needed + initial_centered = initial_scaled + + if len(initial_scaled) > 0: + try: + warped_template = project_image_using_keypoints( + image=_TEMPLATE_IMAGE, + source_keypoints=_TEMPLATE_KEYPOINTS, + destination_keypoints=initial_centered, + destination_width=frame_width, + destination_height=frame_height, + ) + + # Use non-validating version for line distribution analysis + # Exact fitting might create invalid masks, but we still want to analyze line distribution + mask_ground, mask_lines = extract_masks_for_ground_and_lines_no_validation(image=warped_template) + mask_lines_predicted = extract_mask_of_ground_lines_in_image( + image=frame_image, ground_mask=mask_ground + ) + + h, w = mask_lines_predicted.shape + + # Density-based region analysis: Find arbitrary region with highest line density + # Optimized with NumPy convolution for much faster computation + region_size_ratio = 0.35 # Region will be 35% of frame size + region_w = max(50, int(w * region_size_ratio)) + region_h = max(50, int(h * region_size_ratio)) + + # Use larger step size for faster computation (less precise but much faster) + # Increase step size significantly to reduce iterations + step_size = max(20, min(region_w // 3, region_h // 3, w // 10, h // 10)) + + # Optimized sliding window using NumPy + max_density = 0.0 + best_region_center = None + + # Pre-compute valid regions to avoid repeated calculations + y_starts = list(range(0, h - region_h + 1, step_size)) + x_starts = list(range(0, w - region_w + 1, step_size)) + + # Use vectorized operations where possible + for y_start in y_starts: + y_end = min(y_start + region_h, h) + for x_start in x_starts: + x_end = min(x_start + region_w, w) + + # Extract region and compute density in one operation + region_mask = mask_lines_predicted[y_start:y_end, x_start:x_end] + region_area = (x_end - x_start) * (y_end - y_start) + + if region_area == 0: + continue + + # Vectorized line count and density calculation + line_count = np.count_nonzero(region_mask) + density = float(line_count) / float(region_area) + + # Track region with highest density + if density > max_density: + max_density = density + best_region_center = ((x_start + x_end) // 2, (y_start + y_end) // 2) + + # If no region found, use frame center as fallback + if best_region_center is None: + best_region_center = (w // 2, h // 2) + max_density = 0.0 + + # Calculate total line count for validation + total_line_count = np.sum(mask_lines_predicted > 0) + + line_distribution = (total_line_count, best_region_center, max_density) + + print(f"Density-based region analysis: center={best_region_center}, density={max_density:.4f}, total_lines={total_line_count}") + except Exception: + pass # Use default keypoints if analysis fails + except Exception: + pass # Use default keypoints if analysis fails + + # Generate scaled keypoints only for selected indices + # Use _TEMPLATE_KEYPOINTS if available, otherwise fall back to FOOTBALL_KEYPOINTS + source_keypoints = _TEMPLATE_KEYPOINTS if _TEMPLATE_KEYPOINTS is not None else FOOTBALL_KEYPOINTS + num_keypoints = len(source_keypoints) if source_keypoints is not None else 32 + + scaled: list[tuple[int, int]] = [] + for i in range(num_keypoints): + if i in selected_keypoint_indices and i < len(source_keypoints): + tx, ty = source_keypoints[i] + if tx > 0 and ty > 0: # Only scale non-zero keypoints + x_scaled = int(round(tx * uniform_scale)) + y_scaled = int(round(ty * uniform_scale)) + scaled.append((x_scaled, y_scaled)) + else: + scaled.append((0, 0)) + else: + scaled.append((0, 0)) # Set unselected keypoints to (0, 0) + + # Ensure minimum spacing between keypoints to avoid warping artifacts + # Very close keypoints can create single-pixel square contours during warping + # Check if any keypoints are too close and adjust scale if needed + # Optimized with NumPy for better performance + min_spacing = 5 # Minimum 5 pixels between keypoints to avoid 1x1 artifacts + min_spacing_sq = min_spacing * min_spacing # Use squared distance to avoid sqrt + + # Extract valid keypoints (non-zero) for distance checking + valid_kps = np.array([(x, y) for x, y in scaled if x != 0 or y != 0], dtype=np.float32) + needs_adjustment = False + + if len(valid_kps) > 1: + # Use NumPy broadcasting for efficient distance calculation + # Compute pairwise squared distances + diff = valid_kps[:, None, :] - valid_kps[None, :, :] # Shape: (n, n, 2) + dist_sq = np.sum(diff ** 2, axis=2) # Shape: (n, n) + + # Set diagonal to large value to ignore self-distances + np.fill_diagonal(dist_sq, min_spacing_sq + 1) + + # Check if any distance is below threshold + if np.any(dist_sq < min_spacing_sq): + needs_adjustment = True + + # If keypoints are too close, slightly increase scale to maintain minimum spacing + if needs_adjustment and uniform_scale < 0.25: + uniform_scale = uniform_scale * 1.2 # Increase by 20% to ensure spacing + uniform_scale = min(uniform_scale, 0.25) # Cap at 25% to keep it small + # Recalculate selected keypoints with adjusted scale using NumPy + source_array = np.array(source_keypoints[:num_keypoints] if len(source_keypoints) >= num_keypoints + else source_keypoints + [(0, 0)] * (num_keypoints - len(source_keypoints)), + dtype=np.float32) + + # Create mask for selected indices + selected_mask = np.array([i in selected_keypoint_indices for i in range(num_keypoints)], dtype=bool) + valid_mask = (source_array[:, 0] > 0) & (source_array[:, 1] > 0) + final_mask = selected_mask & valid_mask + + # Vectorized scaling + scaled_array = source_array.copy() + scaled_array[final_mask, 0] = np.round(source_array[final_mask, 0] * uniform_scale) + scaled_array[final_mask, 1] = np.round(source_array[final_mask, 1] * uniform_scale) + scaled_array[~final_mask] = 0 + + # Convert to list of tuples + scaled = [(int(x), int(y)) for x, y in scaled_array] + + # Use line distribution analysis (already computed above) to determine optimal pitch placement + offset_x = 0 + offset_y = 0 + + if line_distribution is not None: + # Extract line distribution data (new format: total_count, best_region_center, max_density) + if len(line_distribution) >= 3: + total_line_count, best_region_center, max_density = line_distribution + else: + # Fallback if format is unexpected + total_line_count = line_distribution[0] if len(line_distribution) > 0 else 0 + best_region_center = None + max_density = 0.0 + + # Adjust keypoint placement based on line distribution + valid_points = [(x, y) for x, y in scaled if x > 0 and y > 0] + if len(valid_points) > 0: + scaled_width = max(x for x, y in valid_points) + scaled_height = max(y for x, y in valid_points) + + margin = 5 + + # Only use line distribution analysis if we detected a reasonable number of lines and found a good region + # Otherwise fall back to default centering + if total_line_count > 100 and best_region_center is not None and max_density > 0.01: # Minimum threshold to trust the analysis + # Use density-based region analysis: center sparse template on the region with highest density + target_center_x, target_center_y = best_region_center + + # Calculate offset to center the scaled template on the target region center + # The template center should align with the target region center + scaled_center_x = scaled_width // 2 + scaled_center_y = scaled_height // 2 + + offset_x = target_center_x - scaled_center_x + offset_y = target_center_y - scaled_center_y + + # Ensure template stays within frame bounds + offset_x = max(margin, min(offset_x, frame_width - scaled_width - margin)) + offset_y = max(margin, min(offset_y, frame_height - scaled_height - margin)) + + print(f"Positioning sparse template: target_center=({target_center_x}, {target_center_y}), offset=({offset_x}, {offset_y}), scaled_size=({scaled_width}, {scaled_height}), density={max_density:.4f}") + else: # Fallback to default centering + # Simple center positioning + offset_x = max(margin, (frame_width - scaled_width) // 2) + offset_y = max(margin, (frame_height - scaled_height) // 2) + offset_x = min(offset_x, frame_width - scaled_width - margin) + offset_y = min(offset_y, frame_height - scaled_height - margin) + offset_x = max(0, offset_x) + offset_y = max(0, offset_y) + else: + # Default centering if no line distribution analysis + valid_points = [(x, y) for x, y in scaled if x > 0 and y > 0] + if len(valid_points) > 0: + scaled_width = max(x for x, y in valid_points) + scaled_height = max(y for x, y in valid_points) + margin = 5 + offset_x = max(margin, (frame_width - scaled_width) // 2) + offset_y = max(margin, (frame_height - scaled_height) // 2) + offset_x = min(offset_x, frame_width - scaled_width - margin) + offset_y = min(offset_y, frame_height - scaled_height - margin) + offset_x = max(0, offset_x) + offset_y = max(0, offset_y) + + # Lightweight vertical adjustment: Try small vertical offsets to align top/bottom edge with lines + # This improves overlap without much speed penalty + if frame_image is not None and _TEMPLATE_IMAGE is not None and _TEMPLATE_KEYPOINTS is not None and line_distribution is not None: + try: + total_line_count, best_region_center, max_density = line_distribution + if total_line_count > 100: # Only adjust if we have enough lines + from keypoint_evaluation import ( + project_image_using_keypoints, + extract_masks_for_ground_and_lines_no_validation, + extract_mask_of_ground_lines_in_image + ) + + # Get initial positioned keypoints + initial_centered = [] + for x, y in scaled: + if x == 0 and y == 0: + initial_centered.append((0, 0)) + else: + new_x = x + offset_x + new_y = y + offset_y + new_x = max(0, min(new_x, frame_width - 1)) + initial_centered.append((new_x, new_y)) + + # Try small vertical adjustments (only 5 positions for speed) + best_adjusted_offset_y = offset_y + best_overlap = 0.0 + + # Try vertical offsets: -15, -7, 0, 7, 15 pixels + vertical_adjustments = [-15, -7, 0, 7, 15] + for adj_y in vertical_adjustments: + test_offset_y = offset_y + adj_y + + # Ensure within bounds + test_offset_y = max(margin, min(test_offset_y, frame_height - scaled_height - margin)) + + # Generate test keypoints with adjusted vertical position + test_centered = [] + for x, y in scaled: + if x == 0 and y == 0: + test_centered.append((0, 0)) + else: + new_x = x + offset_x + new_y = y + test_offset_y + new_x = max(0, min(new_x, frame_width - 1)) + new_y = max(0, min(new_y, frame_height - 1)) + test_centered.append((new_x, new_y)) + + # Quick validation: check spacing + test_corners = [test_centered[idx] for idx in sorted(selected_keypoint_indices) + if idx < len(test_centered) and test_centered[idx][0] > 0] + + if len(test_corners) == 4: + min_dist = float('inf') + for i in range(len(test_corners)): + for j in range(i + 1, len(test_corners)): + x1, y1 = test_corners[i] + x2, y2 = test_corners[j] + dist = np.sqrt((x2 - x1)**2 + (y2 - y1)**2) + min_dist = min(min_dist, dist) + + min_required_dist = max(30, min(frame_width, frame_height) * 0.1) + if min_dist < min_required_dist: + continue # Skip if corners too close + + # Project and calculate overlap + try: + warped = project_image_using_keypoints( + image=_TEMPLATE_IMAGE, + source_keypoints=_TEMPLATE_KEYPOINTS, + destination_keypoints=test_centered, + destination_width=frame_width, + destination_height=frame_height, + ) + + mask_ground_test, mask_lines_expected = extract_masks_for_ground_and_lines_no_validation(image=warped) + mask_lines_predicted = extract_mask_of_ground_lines_in_image( + image=frame_image, ground_mask=mask_ground_test + ) + + # Calculate overlap + overlap_mask = (mask_lines_expected > 0) & (mask_lines_predicted > 0) + expected_pixels = np.sum(mask_lines_expected > 0) + + if expected_pixels > 0: + overlap = np.sum(overlap_mask) / float(expected_pixels) + + if overlap > best_overlap: + best_overlap = overlap + best_adjusted_offset_y = test_offset_y + except Exception: + continue # Skip if projection fails + # Use the best vertical offset + if best_overlap > 0.0: + offset_y = best_adjusted_offset_y + print(f"Vertical adjustment: best overlap={best_overlap:.4f}, adjusted offset_y={offset_y}") + except Exception as e: + print(f"Vertical adjustment error: {e}") + pass # Continue with original offset if adjustment fails + + # Apply centering offset + centered = [] + for x, y in scaled: + if x == 0 and y == 0: + centered.append((0, 0)) + else: + new_x = x + offset_x + new_y = y + offset_y + # Allow negative y coordinates (pitch extends above frame) + # But ensure x coordinates are within frame bounds to avoid warping artifacts + new_x = max(0, min(new_x, frame_width - 1)) + # Allow negative y, but ensure at least some keypoints are in frame + # This prevents large square artifacts from warping + centered.append((new_x, new_y)) + + # Ensure at least some keypoints have positive y coordinates (visible in frame) + # This prevents warping from creating large square artifacts + visible_keypoints = [kp for kp in centered if kp[1] > 0] + if len(visible_keypoints) < 4: + # Not enough visible keypoints - adjust offset_y to ensure visibility + # This prevents warping artifacts that create large squares + min_y = min(y for x, y in centered if y != 0) if visible_keypoints else 0 + if min_y < 0: + adjustment = abs(min_y) + 10 # Push down by at least 10 pixels + centered = [] + for x, y in scaled: + if x == 0 and y == 0: + centered.append((0, 0)) + else: + new_x = x + offset_x + new_y = y + offset_y + adjustment + new_x = max(0, min(new_x, frame_width - 1)) + new_y = max(0, new_y) # Ensure at least some are visible + centered.append((new_x, new_y)) + return centered + +# def _generate_sparse_template_keypoints( +# frame_width: int, +# frame_height: int, +# frame_image: np.ndarray = None, +# template_image: np.ndarray = None, +# template_keypoints: list[tuple[int, int]] = None, +# ) -> list[tuple[int, int]]: +# """ +# Generate sparse template keypoints that fill the frame exactly. +# We map the template bounds to the frame bounds (non-uniform scale), +# so the warped template covers the full frame without manual shifts. +# """ +# # Infer template dimensions from provided keypoints if available +# if template_keypoints is not None and len(template_keypoints) > 0: +# valid_template_points = [(x, y) for x, y in template_keypoints if x > 0 and y > 0] +# if len(valid_template_points) > 0: +# template_max_x = max(x for x, y in valid_template_points) +# template_max_y = max(y for x, y in valid_template_points) +# else: +# template_max_x, template_max_y = (1045, 675) +# else: +# template_max_x, template_max_y = (1045, 675) + +# # Non-uniform scale to fit the frame exactly (may stretch if aspect differs) +# sx = float(frame_width) / float(template_max_x if template_max_x != 0 else 1) +# sy = float(frame_height) / float(template_max_y if template_max_y != 0 else 1) + +# source_keypoints = template_keypoints if template_keypoints is not None else FOOTBALL_KEYPOINTS +# num_kps = len(source_keypoints) if source_keypoints is not None else 32 + +# scaled: list[tuple[int, int]] = [] +# for i in range(num_kps): +# tx, ty = source_keypoints[i] +# if tx > 0 and ty > 0: +# x_scaled = int(round(tx * sx)) +# y_scaled = int(round(ty * sy)) +# # Clamp to frame bounds +# x_scaled = max(0, min(x_scaled, frame_width - 1)) +# y_scaled = max(0, min(y_scaled, frame_height - 1)) +# scaled.append((x_scaled, y_scaled)) +# else: +# scaled.append((0, 0)) + +# return scaled + +def _adjust_keypoints_to_pass_validation( + keypoints: list[tuple[int, int]], + frame_width: int = None, + frame_height: int = None, +) -> list[tuple[int, int]]: + """ + Adjust keypoints to pass validate_projected_corners. + If validation fails, try to fix by ensuring corners form a valid quadrilateral. + """ + if _validate_keypoints_corners(keypoints, _TEMPLATE_KEYPOINTS): + return keypoints # Already valid + + # If validation fails, try to fix by ensuring corner keypoints are in correct order + try: + from keypoint_evaluation import ( + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + ) + + template_keypoints = _TEMPLATE_KEYPOINTS + + # Get corner indices + corner_indices = [ + INDEX_KEYPOINT_CORNER_TOP_LEFT, + INDEX_KEYPOINT_CORNER_TOP_RIGHT, + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT, + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT, + ] + + # Check if we have all corner keypoints + corners = [] + for idx in corner_indices: + if idx < len(keypoints): + x, y = keypoints[idx] + if x > 0 and y > 0: + corners.append((x, y, idx)) + + if len(corners) < 4: + # Not enough corners - can't fix, return original + return keypoints + + # Extract corner coordinates + corner_coords = [(x, y) for x, y, _ in corners] + + # Check if corners form a bowtie (twisted quadrilateral) + # A bowtie occurs when opposite edges intersect + def segments_intersect(p1, p2, q1, q2): + """Check if line segments p1-p2 and q1-q2 intersect.""" + def ccw(a, b, c): + return (c[1] - a[1]) * (b[0] - a[0]) > (b[1] - a[1]) * (c[0] - a[0]) + return (ccw(p1, q1, q2) != ccw(p2, q1, q2)) and (ccw(p1, p2, q1) != ccw(p1, p2, q2)) + + # Try different corner orderings to find a valid one + # Current order: top-left, top-right, bottom-right, bottom-left + # If this creates a bowtie, we need to reorder + + # Sort corners by position to get proper order + # Top row (smaller y values) + top_corners = sorted([c for c in corners if c[1] <= np.mean([c[1] for c in corners])], + key=lambda c: c[0]) + # Bottom row (larger y values) + bottom_corners = sorted([c for c in corners if c[1] > np.mean([c[1] for c in corners])], + key=lambda c: c[0]) + + # If we have 2 top and 2 bottom corners, ensure proper ordering + if len(top_corners) == 2 and len(bottom_corners) == 2: + # Ensure left < right + if top_corners[0][0] > top_corners[1][0]: + top_corners = top_corners[::-1] + if bottom_corners[0][0] > bottom_corners[1][0]: + bottom_corners = bottom_corners[::-1] + + # Reconstruct with proper order: top-left, top-right, bottom-right, bottom-left + result = list(keypoints) + + # Map to corner indices + corner_mapping = { + INDEX_KEYPOINT_CORNER_TOP_LEFT: top_corners[0], + INDEX_KEYPOINT_CORNER_TOP_RIGHT: top_corners[1], + INDEX_KEYPOINT_CORNER_BOTTOM_RIGHT: bottom_corners[1], + INDEX_KEYPOINT_CORNER_BOTTOM_LEFT: bottom_corners[0], + } + + for corner_idx, (x, y, _) in corner_mapping.items(): + if corner_idx < len(result): + result[corner_idx] = (x, y) + + # Validate the adjusted keypoints + if _validate_keypoints_corners(result, _TEMPLATE_KEYPOINTS): + return result + + # Alternative: If we can't fix by reordering, try using template-based scaling + # for corners only, keeping other keypoints as-is + if len(corners) >= 4: + # Calculate scale from non-corner keypoints if available + non_corner_kps = [(i, keypoints[i]) for i in range(len(keypoints)) + if i not in corner_indices and keypoints[i][0] > 0 and keypoints[i][1] > 0] + + if len(non_corner_kps) >= 2: + # Use template scaling approach + scales_x = [] + scales_y = [] + for idx, (x, y) in non_corner_kps: + if idx < len(template_keypoints): + tx, ty = template_keypoints[idx] + if tx > 0: + scales_x.append(x / tx) + if ty > 0: + scales_y.append(y / ty) + + if scales_x and scales_y: + avg_scale_x = sum(scales_x) / len(scales_x) + avg_scale_y = sum(scales_y) / len(scales_y) + + result = list(keypoints) + # Recalculate corners using template scaling + for corner_idx in corner_indices: + if corner_idx < len(template_keypoints): + tx, ty = template_keypoints[corner_idx] + new_x = int(round(tx * avg_scale_x)) + new_y = int(round(ty * avg_scale_y)) + if corner_idx < len(result): + result[corner_idx] = (new_x, new_y) + + # Validate again + if _validate_keypoints_corners(result, _TEMPLATE_KEYPOINTS): + return result + + except Exception: + pass + + # If we can't fix, return original + return keypoints + +def fix_keypoints( + results_frames: Sequence[Any], + frame_results: list[tuple[int, float, bool]], + frame_width: int, + frame_height: int, + frames: List[np.ndarray] = None, + offset: int = 0, + num_workers: int = None, +) -> list[Any]: + """ + Optimized version using batch-first approach: + 1. Generate sparse keypoints for ALL frames first + 2. Evaluate both sparse and calculated keypoints for ALL frames + 3. Choose the one with bigger score per frame + + Args: + results_frames: Sequence of frame results with keypoints + frame_results: List of tuples (frame_index, score, adjusted_success) from calculate_and_adjust_keypoints + frame_width: Frame width + frame_height: Frame height + frames: Optional list of frame images for validation + offset: Frame offset for tracking + num_workers: Number of worker threads (defaults to cpu_count()) + + Returns: + List of processed frame results + """ + max_frames = len(results_frames) + if max_frames == 0: + return list(results_frames) + + # Create a dictionary mapping frame_index to (score, adjusted_success) for quick lookup + frame_results_dict = {frame_index: (score, adjusted_success) for frame_index, score, adjusted_success in frame_results} + + + if num_workers is None: + # Cap workers to avoid overhead with too many threads + # Optimal range is typically 8-32 workers depending on workload + # Too many threads cause context switching overhead and contention + # Cap at 32 even if CPU count is higher (e.g., cloud servers with 96+ CPUs) + max_cpu_workers = min(32, cpu_count()) # Cap at 32 to avoid overhead + max_workers = min(max_cpu_workers, max_frames) + num_workers = max(1, max_workers) + + # Step 1: Extract calculated keypoints and pre-calculated scores from frame_results + # (already calculated in calculate_and_adjust_keypoints) + from keypoint_helper_v2_optimized import convert_keypoints_to_val_format + + calculated_keypoints_list = [] + pre_calculated_scores = {} + last_success_kps = None + + for frame_index in range(max_frames): + frame_result = results_frames[frame_index] + current_kps_raw = getattr(frame_result, "keypoints", []) or [] + calculated_kps = convert_keypoints_to_val_format(current_kps_raw) + + # Get pre-calculated score from frame_results (from calculate_and_adjust_keypoints) + if frame_index in frame_results_dict: + score, adjusted_success = frame_results_dict[frame_index] + if adjusted_success: # Only use valid scores + pre_calculated_scores[frame_index] = score + calculated_keypoints_list.append(calculated_kps) + last_success_kps = calculated_kps + else: + if last_success_kps is not None: + calculated_keypoints_list.append(last_success_kps) + else: + calculated_keypoints_list.append(calculated_kps) + else: + if last_success_kps is not None: + calculated_keypoints_list.append(last_success_kps) + else: + calculated_keypoints_list.append(calculated_kps) + + # Step 2: Generate sparse keypoints for ALL frames in parallel + print(f"Generating sparse keypoints for {max_frames} frames...") + sparse_args_list = [] + for frame_index in range(max_frames): + frame_for_analysis = None + if frames is not None and frame_index < len(frames): + frame_for_analysis = frames[frame_index] + + sparse_args_list.append(( + frame_index, frame_width, frame_height, + frame_for_analysis + )) + + sparse_keypoints_dict = {} + # Check if we're on Linux - use ProcessPoolExecutor instead of ThreadPoolExecutor + import platform + is_linux = platform.system().lower() == 'linux' + + if max_frames >= 4 and num_workers > 1: + try: + if is_linux: + from concurrent.futures import ProcessPoolExecutor, as_completed + executor_class = ProcessPoolExecutor + else: + from concurrent.futures import ThreadPoolExecutor, as_completed + executor_class = ThreadPoolExecutor + + with executor_class(max_workers=num_workers) as executor: + futures = [executor.submit(_generate_sparse_keypoints_for_frame, args) for args in sparse_args_list] + + for future in as_completed(futures): + try: + frame_idx, sparse_kps = future.result() + sparse_keypoints_dict[frame_idx] = sparse_kps + except Exception as e: + print(f"Error generating sparse keypoints: {e}") + except Exception as e: + print(f"Parallel processing failed for sparse generation: {e}, falling back to sequential") + for args in sparse_args_list: + try: + frame_idx, sparse_kps = _generate_sparse_keypoints_for_frame(args) + sparse_keypoints_dict[frame_idx] = sparse_kps + except Exception: + pass + else: + # Sequential for small batches + for args in sparse_args_list: + try: + frame_idx, sparse_kps = _generate_sparse_keypoints_for_frame(args) + sparse_keypoints_dict[frame_idx] = sparse_kps + except Exception: + pass + + # Ensure we have sparse keypoints for all frames + for frame_index in range(max_frames): + if frame_index not in sparse_keypoints_dict: + sparse_keypoints_dict[frame_index] = [(0, 0)] * 32 + + # Step 3: Evaluate both sparse and calculated keypoints for ALL frames in parallel + print(f"Evaluating sparse and calculated keypoints for {max_frames} frames...") + eval_args_list = [] + for frame_index in range(max_frames): + sparse_kps = sparse_keypoints_dict[frame_index] + calculated_kps = calculated_keypoints_list[frame_index] + + frame_for_analysis = None + if frames is not None and frame_index < len(frames): + frame_for_analysis = frames[frame_index] + + # Get pre-calculated score if available + pre_calculated_score = pre_calculated_scores.get(frame_index, None) + + eval_args_list.append(( + frame_index, sparse_kps, calculated_kps, + frame_for_analysis, pre_calculated_score + )) + + evaluation_results = {} + if max_frames >= 4 and num_workers > 1: + try: + if is_linux: + from concurrent.futures import ProcessPoolExecutor, as_completed + executor_class = ProcessPoolExecutor + else: + from concurrent.futures import ThreadPoolExecutor, as_completed + executor_class = ThreadPoolExecutor + + with executor_class(max_workers=num_workers) as executor: + futures = [executor.submit(_evaluate_keypoints_for_frame, args) for args in eval_args_list] + + for future in as_completed(futures): + try: + frame_idx, sparse_score, calculated_score, sparse_kps, calculated_kps = future.result() + evaluation_results[frame_idx] = (sparse_score, calculated_score, sparse_kps, calculated_kps) + except Exception as e: + print(f"Error evaluating keypoints: {e}") + except Exception as e: + print(f"Threading failed for evaluation: {e}, falling back to sequential") + for args in eval_args_list: + try: + frame_idx, sparse_score, calculated_score, sparse_kps, calculated_kps = _evaluate_keypoints_for_frame(args) + evaluation_results[frame_idx] = (sparse_score, calculated_score, sparse_kps, calculated_kps) + except Exception: + pass + else: + # Sequential for small batches + for args in eval_args_list: + try: + frame_idx, sparse_score, calculated_score, sparse_kps, calculated_kps = _evaluate_keypoints_for_frame(args) + evaluation_results[frame_idx] = (sparse_score, calculated_score, sparse_kps, calculated_kps) + except Exception: + pass + + # Step 4: Choose the keypoint set with bigger score per frame + print(f"Choosing best keypoints for {max_frames} frames...") + + for frame_index in range(max_frames): + frame_result = results_frames[frame_index] + + # Get evaluation results for this frame + if frame_index in evaluation_results: + sparse_score, calculated_score, sparse_kps, calculated_kps = evaluation_results[frame_index] + + # Choose the one with bigger score + if calculated_score > sparse_score: + final_keypoints = calculated_kps + print(f"Frame {frame_index}: Using calculated keypoints (score: {calculated_score:.4f} > sparse: {sparse_score:.4f})") + else: + final_keypoints = sparse_kps + print(f"Frame {frame_index}: Using sparse keypoints (score: {sparse_score:.4f} >= calculated: {calculated_score:.4f})") + else: + # Fallback to sparse if evaluation failed + final_keypoints = sparse_keypoints_dict.get(frame_index, [(0, 0)] * 32) + print(f"Frame {frame_index}: Using sparse keypoints (evaluation failed)") + + setattr(frame_result, "keypoints", list(convert_keypoints_to_val_format(final_keypoints))) + + return list(results_frames) + +def run_keypoints_post_processing( + results_frames: Sequence[Any], + frame_width: int, + frame_height: int, + frames: List[np.ndarray] = None, + template_keypoints: list[tuple[int, int]] = None, + template_image: np.ndarray = None, + offset: int = 0, + num_workers: int = None, +) -> list[Any]: + """ + Optimized post-processing with multiprocessing support. + + Args: + results_frames: Sequence of frame results with keypoints + frame_width: Frame width + frame_height: Frame height + frames: Optional list of frame images for validation + template_keypoints: Optional template keypoints (defaults to TEMPLATE_KEYPOINTS) + template_image: Optional pre-loaded template image (from miner constructor) + offset: Frame offset for tracking (defaults to 0) + num_workers: Number of worker processes for multiprocessing (defaults to cpu_count()) + + Returns: + List of processed frame results + """ + # Initialize module-level template variables (use pre-loaded template_image) + _initialize_template_variables(template_keypoints, template_image) + + # Calculate and adjust keypoints for all frames, getting scores and success status + frame_results = calculate_and_adjust_keypoints( + results_frames, frame_width, frame_height, + frames, offset, num_workers + ) + + return fix_keypoints( + results_frames, frame_results, frame_width, frame_height, + frames, offset, num_workers + ) \ No newline at end of file diff --git a/miner.py b/miner.py new file mode 100644 index 0000000000000000000000000000000000000000..b701a140f7bfea5af0dfec66d9eed35f8e082c6b --- /dev/null +++ b/miner.py @@ -0,0 +1,881 @@ +from pathlib import Path +from typing import List, Tuple, Dict, Optional +import sys +import os + +from numpy import ndarray +from pydantic import BaseModel + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from keypoint_helper import run_keypoints_post_processing +from keypoint_helper_v2 import run_keypoints_post_processing as run_keypoints_post_processing_v2 + +from ultralytics import YOLO +from team_cluster import TeamClassifier +from utils import ( + BoundingBox, + Constants, +) + +import time +import torch +import gc +import cv2 +import numpy as np +from collections import defaultdict +from pitch import process_batch_input, get_cls_net +from keypoint_evaluation import ( + evaluate_keypoints_for_frame, + evaluate_keypoints_for_frame_gpu, + load_template_from_file, + evaluate_keypoints_for_frame_opencv_cuda, + evaluate_keypoints_batch_for_frame, +) + +import yaml + + +class BoundingBox(BaseModel): + x1: int + y1: int + x2: int + y2: int + cls_id: int + conf: float + + +class TVFrameResult(BaseModel): + frame_id: int + boxes: List[BoundingBox] + keypoints: List[Tuple[int, int]] + + +class Miner: + SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA + SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX + SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT + CORNER_INDICES = Constants.CORNER_INDICES + KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE + CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE + GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN + MIN_SAMPLES_FOR_FIT = 16 # Minimum player crops needed before fitting TeamClassifier + MAX_SAMPLES_FOR_FIT = 600 # Maximum samples to avoid overfitting + + def __init__(self, path_hf_repo: Path) -> None: + try: + device = "cuda" if torch.cuda.is_available() else "cpu" + model_path = path_hf_repo / "detection.onnx" + self.bbox_model = YOLO(model_path) + + print(f"BBox Model Loaded: class name {self.bbox_model.names}") + + team_model_path = path_hf_repo / "osnet_model.pth.tar-100" + self.team_classifier = TeamClassifier( + device=device, + batch_size=32, + model_name=str(team_model_path) + ) + print("Team Classifier Loaded") + + # Team classification state + self.team_classifier_fitted = False + self.player_crops_for_fit = [] + + self.keypoints_model_yolo = YOLO(path_hf_repo / "keypoint.pt") + + model_kp_path = path_hf_repo / 'keypoint' + config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml' + cfg_kp = yaml.safe_load(open(config_kp_path, 'r')) + + loaded_state_kp = torch.load(model_kp_path, map_location=device) + model = get_cls_net(cfg_kp) + model.load_state_dict(loaded_state_kp) + model.to(device) + model.eval() + + self.keypoints_model = model + print("Keypoints Model (keypoint.pt) Loaded") + + template_image_path = path_hf_repo / "football_pitch_template.png" + self.template_image, self.template_keypoints = load_template_from_file(str(template_image_path)) + + self.kp_threshold = 0.1 + self.pitch_batch_size = 4 + self.health = "healthy" + + print("✅ Keypoints Model Loaded") + except Exception as e: + self.health = "❌ Miner initialization failed: " + str(e) + print(self.health) + + def __repr__(self) -> str: + if self.health == 'healthy': + return ( + f"health: {self.health}\n" + f"BBox Model: {type(self.bbox_model).__name__}\n" + f"Keypoints Model: {type(self.keypoints_model).__name__}" + ) + else: + return self.health + + def _calculate_iou(self, box1: Tuple[float, float, float, float], + box2: Tuple[float, float, float, float]) -> float: + """ + Calculate Intersection over Union (IoU) between two bounding boxes. + Args: + box1: (x1, y1, x2, y2) + box2: (x1, y1, x2, y2) + Returns: + IoU score (0-1) + """ + x1_1, y1_1, x2_1, y2_1 = box1 + x1_2, y1_2, x2_2, y2_2 = box2 + + # Calculate intersection area + x_left = max(x1_1, x1_2) + y_top = max(y1_1, y1_2) + x_right = min(x2_1, x2_2) + y_bottom = min(y2_1, y2_2) + + if x_right < x_left or y_bottom < y_top: + return 0.0 + + intersection_area = (x_right - x_left) * (y_bottom - y_top) + + # Calculate union area + box1_area = (x2_1 - x1_1) * (y2_1 - y1_1) + box2_area = (x2_2 - x1_2) * (y2_2 - y1_2) + union_area = box1_area + box2_area - intersection_area + + if union_area == 0: + return 0.0 + + return intersection_area / union_area + + def _extract_jersey_region(self, crop: ndarray) -> ndarray: + """ + Extract jersey region (upper body) from player crop. + For close-ups, focuses on upper 60%, for distant shots uses full crop. + """ + if crop is None or crop.size == 0: + return crop + + h, w = crop.shape[:2] + if h < 10 or w < 10: + return crop + + # For close-up shots, extract upper body (jersey region) + is_closeup = h > 100 or (h * w) > 12000 + if is_closeup: + # Upper 60% of the crop (jersey area, avoiding shorts) + jersey_top = 0 + jersey_bottom = int(h * 0.60) + jersey_left = max(0, int(w * 0.05)) + jersey_right = min(w, int(w * 0.95)) + return crop[jersey_top:jersey_bottom, jersey_left:jersey_right] + return crop + + def _extract_color_signature(self, crop: ndarray) -> Optional[np.ndarray]: + """ + Extract color signature from jersey region using HSV and LAB color spaces. + Returns a feature vector with dominant colors and color statistics. + """ + if crop is None or crop.size == 0: + return None + + jersey_region = self._extract_jersey_region(crop) + if jersey_region.size == 0: + return None + + try: + # Convert to HSV and LAB color spaces + hsv = cv2.cvtColor(jersey_region, cv2.COLOR_BGR2HSV) + lab = cv2.cvtColor(jersey_region, cv2.COLOR_BGR2LAB) + + # Reshape for processing + hsv_flat = hsv.reshape(-1, 3).astype(np.float32) + lab_flat = lab.reshape(-1, 3).astype(np.float32) + + # Compute statistics for HSV + hsv_mean = np.mean(hsv_flat, axis=0) / 255.0 + hsv_std = np.std(hsv_flat, axis=0) / 255.0 + + # Compute statistics for LAB + lab_mean = np.mean(lab_flat, axis=0) / 255.0 + lab_std = np.std(lab_flat, axis=0) / 255.0 + + # Dominant color (most frequent hue) + hue_hist, _ = np.histogram(hsv_flat[:, 0], bins=36, range=(0, 180)) + dominant_hue = np.argmax(hue_hist) * 5 # Convert to hue value + + # Combine features + color_features = np.concatenate([ + hsv_mean, + hsv_std, + lab_mean[:2], # L and A channels (B is less informative) + lab_std[:2], + [dominant_hue / 180.0] # Normalized dominant hue + ]) + + return color_features + except Exception as e: + print(f"Error extracting color signature: {e}") + return None + + def _get_spatial_position(self, bbox: Tuple[float, float, float, float], + frame_width: int, frame_height: int) -> Tuple[float, float]: + """ + Get normalized spatial position of player on the pitch. + Returns (x_normalized, y_normalized) where 0,0 is top-left. + """ + x1, y1, x2, y2 = bbox + center_x = (x1 + x2) / 2.0 + center_y = (y1 + y2) / 2.0 + + # Normalize to [0, 1] + x_norm = center_x / frame_width if frame_width > 0 else 0.5 + y_norm = center_y / frame_height if frame_height > 0 else 0.5 + + return (x_norm, y_norm) + + def _find_best_match(self, target_box: Tuple[float, float, float, float], + predicted_frame_data: Dict[int, Tuple[Tuple, str]], + iou_threshold: float) -> Tuple[Optional[str], float]: + """ + Find best matching box in predicted frame data using IoU. + """ + best_iou = 0.0 + best_team_id = None + + for idx, (bbox, team_cls_id) in predicted_frame_data.items(): + iou = self._calculate_iou(target_box, bbox) + if iou > best_iou and iou >= iou_threshold: + best_iou = iou + best_team_id = team_cls_id + + return (best_team_id, best_iou) + + def _detect_objects_batch(self, decoded_images: List[ndarray]) -> Dict[int, List[BoundingBox]]: + batch_size = 16 + detection_results = [] + n_frames = len(decoded_images) + for frame_number in range(0, n_frames, batch_size): + batch_images = decoded_images[frame_number: frame_number + batch_size] + detections = self.bbox_model(batch_images, verbose=False, save=False) + detection_results.extend(detections) + + return detection_results + + def _team_classify(self, detection_results, decoded_images, offset): + self.team_classifier_fitted = False + start = time.time() + # Collect player crops from first batch for fitting + fit_sample_size = 600 + player_crops_for_fit = [] + + for frame_id in range(len(detection_results)): + detection_box = detection_results[frame_id].boxes.data + if len(detection_box) < 4: + continue + # Collect player boxes for team classification fitting (first batch only) + if len(player_crops_for_fit) < fit_sample_size: + frame_image = decoded_images[frame_id] + for box in detection_box: + x1, y1, x2, y2, conf, cls_id = box.tolist() + if conf < 0.5: + continue + mapped_cls_id = str(int(cls_id)) + # Only collect player crops (cls_id = 2) + if mapped_cls_id == '2': + crop = frame_image[int(y1):int(y2), int(x1):int(x2)] + if crop.size > 0: + player_crops_for_fit.append(crop) + + # Fit team classifier after collecting samples + if self.team_classifier and not self.team_classifier_fitted and len(player_crops_for_fit) >= fit_sample_size: + print(f"Fitting TeamClassifier with {len(player_crops_for_fit)} player crops") + self.team_classifier.fit(player_crops_for_fit) + self.team_classifier_fitted = True + break + if not self.team_classifier_fitted and len(player_crops_for_fit) >= 16: + print(f"Fallback: Fitting TeamClassifier with {len(player_crops_for_fit)} player crops") + self.team_classifier.fit(player_crops_for_fit) + self.team_classifier_fitted = True + end = time.time() + print(f"Fitting Kmeans time: {end - start}") + + # Second pass: predict teams with configurable frame skipping optimization + start = time.time() + + # Get configuration for frame skipping + prediction_interval = 1 # Default: predict every 2 frames + iou_threshold = 0.3 + + print(f"Team classification - prediction_interval: {prediction_interval}, iou_threshold: {iou_threshold}") + + # Storage for predicted frame results: {frame_id: {box_idx: (bbox, team_id)}} + predicted_frame_data = {} + + # Step 1: Predict for frames at prediction_interval only + frames_to_predict = [] + for frame_id in range(len(detection_results)): + if frame_id % prediction_interval == 0: + frames_to_predict.append(frame_id) + + print(f"Predicting teams for {len(frames_to_predict)}/{len(detection_results)} frames " + f"(saving {100 - (len(frames_to_predict) * 100 // len(detection_results))}% compute)") + + for frame_id in frames_to_predict: + detection_box = detection_results[frame_id].boxes.data + frame_image = decoded_images[frame_id] + + # Collect player crops for this frame + frame_player_crops = [] + frame_player_indices = [] + frame_player_boxes = [] + + for idx, box in enumerate(detection_box): + x1, y1, x2, y2, conf, cls_id = box.tolist() + if cls_id == 2 and conf < 0.6: + continue + mapped_cls_id = str(int(cls_id)) + + # Collect player crops for prediction + if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2': + crop = frame_image[int(y1):int(y2), int(x1):int(x2)] + if crop.size > 0: + frame_player_crops.append(crop) + frame_player_indices.append(idx) + frame_player_boxes.append((x1, y1, x2, y2)) + + # Predict teams for all players in this frame + if len(frame_player_crops) > 0: + team_ids = self.team_classifier.predict(frame_player_crops) + predicted_frame_data[frame_id] = {} + for idx, bbox, team_id in zip(frame_player_indices, frame_player_boxes, team_ids): + # Map team_id (0,1) to cls_id (6,7) + team_cls_id = str(6 + int(team_id)) + predicted_frame_data[frame_id][idx] = (bbox, team_cls_id) + + # Step 2: Process all frames (interpolate skipped frames) + fallback_count = 0 + interpolated_count = 0 + bboxes: dict[int, list[BoundingBox]] = {} + for frame_id in range(len(detection_results)): + detection_box = detection_results[frame_id].boxes.data + frame_image = decoded_images[frame_id] + boxes = [] + + team_predictions = {} + + if frame_id % prediction_interval == 0: + # Predicted frame: use pre-computed predictions + if frame_id in predicted_frame_data: + for idx, (bbox, team_cls_id) in predicted_frame_data[frame_id].items(): + team_predictions[idx] = team_cls_id + else: + # Skipped frame: interpolate from neighboring predicted frames + # Find nearest predicted frames + prev_predicted_frame = (frame_id // prediction_interval) * prediction_interval + next_predicted_frame = prev_predicted_frame + prediction_interval + + # Collect current frame player boxes + for idx, box in enumerate(detection_box): + x1, y1, x2, y2, conf, cls_id = box.tolist() + if cls_id == 2 and conf < 0.6: + continue + mapped_cls_id = str(int(cls_id)) + + if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2': + target_box = (x1, y1, x2, y2) + + # Try to match with previous predicted frame + best_team_id = None + best_iou = 0.0 + + if prev_predicted_frame in predicted_frame_data: + team_id, iou = self._find_best_match( + target_box, + predicted_frame_data[prev_predicted_frame], + iou_threshold + ) + if team_id is not None: + best_team_id = team_id + best_iou = iou + + # Try to match with next predicted frame if available and no good match yet + if best_team_id is None and next_predicted_frame < len(detection_results): + if next_predicted_frame in predicted_frame_data: + team_id, iou = self._find_best_match( + target_box, + predicted_frame_data[next_predicted_frame], + iou_threshold + ) + if team_id is not None and iou > best_iou: + best_team_id = team_id + best_iou = iou + + # Track interpolation success + if best_team_id is not None: + interpolated_count += 1 + else: + # Fallback: if no match found, predict individually + crop = frame_image[int(y1):int(y2), int(x1):int(x2)] + if crop.size > 0: + team_id = self.team_classifier.predict([crop])[0] + best_team_id = str(6 + int(team_id)) + fallback_count += 1 + + if best_team_id is not None: + team_predictions[idx] = best_team_id + + # Parse boxes with team classification + for idx, box in enumerate(detection_box): + x1, y1, x2, y2, conf, cls_id = box.tolist() + if cls_id == 2 and conf < 0.6: + continue + + # Check overlap with staff box + overlap_staff = False + for idy, boxy in enumerate(detection_box): + s_x1, s_y1, s_x2, s_y2, s_conf, s_cls_id = boxy.tolist() + if cls_id == 2 and s_cls_id == 4: + staff_iou = self._calculate_iou(box[:4], boxy[:4]) + if staff_iou >= 0.8: + overlap_staff = True + break + if overlap_staff: + continue + + mapped_cls_id = str(int(cls_id)) + + # Override cls_id for players with team prediction + if idx in team_predictions: + mapped_cls_id = team_predictions[idx] + if mapped_cls_id != '4': + if int(mapped_cls_id) == 3 and conf < 0.5: + continue + boxes.append( + BoundingBox( + x1=int(x1), + y1=int(y1), + x2=int(x2), + y2=int(y2), + cls_id=int(mapped_cls_id), + conf=float(conf), + ) + ) + # Handle footballs - keep only the best one + footballs = [bb for bb in boxes if int(bb.cls_id) == 0] + if len(footballs) > 1: + best_ball = max(footballs, key=lambda b: b.conf) + boxes = [bb for bb in boxes if int(bb.cls_id) != 0] + boxes.append(best_ball) + + bboxes[offset + frame_id] = boxes + return bboxes + + + def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]: + start = time.time() + detection_results = self._detect_objects_batch(batch_images) + end = time.time() + print(f"Detection time: {end - start}") + + # Use hybrid team classification + start = time.time() + bboxes = self._team_classify(detection_results, batch_images, offset) + end = time.time() + print(f"Team classify time: {end - start}") + + # Phase 3: Keypoint Detection + start = time.time() + keypoints_yolo: Dict[int, List[Tuple[int, int]]] = {} + + keypoints_yolo = self._detect_keypoints_batch(batch_images, offset, n_keypoints) + + + pitch_batch_size = min(self.pitch_batch_size, len(batch_images)) + keypoints: Dict[int, List[Tuple[int, int]]] = {} + + start = time.time() + last_score = 0 + last_valid_keypoints = None + while True: + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + device_str = "cuda" + keypoints_result = process_batch_input( + batch_images, + self.keypoints_model, + self.kp_threshold, + device_str, + batch_size=pitch_batch_size, + ) + if keypoints_result is not None and len(keypoints_result) > 0: + for frame_number_in_batch, kp_dict in enumerate(keypoints_result): + if frame_number_in_batch >= len(batch_images): + break + frame_keypoints: List[Tuple[int, int]] = [] + try: + height, width = batch_images[frame_number_in_batch].shape[:2] + if kp_dict is not None and isinstance(kp_dict, dict): + for idx in range(32): + x, y = 0, 0 + kp_idx = idx + 1 + if kp_idx in kp_dict: + try: + kp_data = kp_dict[kp_idx] + if isinstance(kp_data, dict) and "x" in kp_data and "y" in kp_data: + x = int(kp_data["x"] * width) + y = int(kp_data["y"] * height) + except (KeyError, TypeError, ValueError): + pass + frame_keypoints.append((x, y)) + except (IndexError, ValueError, AttributeError): + frame_keypoints = [(0, 0)] * 32 + if len(frame_keypoints) < n_keypoints: + frame_keypoints.extend([(0, 0)] * (n_keypoints - len(frame_keypoints))) + else: + frame_keypoints = frame_keypoints[:n_keypoints] + + time1 = time.time() + frame_keypoints_yolo = keypoints_yolo.get(offset + frame_number_in_batch, frame_keypoints) + + valid_keypoints_count = 0 + valid_keypoints_yolo_count = 0 + for kp in frame_keypoints: + if kp[0] != 0.0 or kp[1] != 0.0: + valid_keypoints_count += 1 + if valid_keypoints_count > 3: + break + + for kp in frame_keypoints_yolo: + if kp[0] != 0.0 or kp[1] != 0.0: + valid_keypoints_yolo_count += 1 + if valid_keypoints_yolo_count > 3: + break + + # Evaluate and select best keypoints (using batch evaluation for speed) + if valid_keypoints_count > 3 and valid_keypoints_yolo_count > 3: + try: + last_valid_keypoints = keypoints.get(offset + frame_number_in_batch - 1, frame_keypoints) + # Evaluate both keypoint sets in batch (much faster!) + scores = evaluate_keypoints_batch_for_frame( + template_keypoints=self.template_keypoints, + frame_keypoints_list=[frame_keypoints, frame_keypoints_yolo, last_valid_keypoints], + frame=batch_images[frame_number_in_batch], + floor_markings_template=self.template_image, + device="cuda" + ) + score = scores[0] + score_yolo = scores[1] + last_score = scores[2] + + if last_score > score and last_score > score_yolo: + frame_keypoints = last_valid_keypoints + elif score_yolo > score: + frame_keypoints = frame_keypoints_yolo + last_score = score_yolo + else: + last_score = score + + last_valid_keypoints = frame_keypoints + + except Exception as e: + # Fallback: use YOLO if available, otherwise use pitch model + if valid_keypoints_yolo_count > 3: + frame_keypoints = frame_keypoints_yolo + elif valid_keypoints_yolo_count > 3: + # Only YOLO has valid keypoints + frame_keypoints = frame_keypoints_yolo + else: + if last_valid_keypoints is not None: + frame_keypoints = last_valid_keypoints + + time2 = time.time() + print(f"Keypoint evaluation time: {time2 - time1}") + + keypoints[offset + frame_number_in_batch] = frame_keypoints + break + end = time.time() + print(f"Keypoint time: {end - start}") + + results: List[TVFrameResult] = [] + for frame_number in range(offset, offset + len(batch_images)): + frame_boxes = bboxes.get(frame_number, []) + result = TVFrameResult( + frame_id=frame_number, + boxes=frame_boxes, + keypoints=keypoints.get( + frame_number, + [(0, 0) for _ in range(n_keypoints)], + ), + ) + results.append(result) + + start = time.time() + if len(batch_images) > 0: + h, w = batch_images[0].shape[:2] + results = run_keypoints_post_processing_v2( + results, w, h, + frames=batch_images, + template_keypoints=self.template_keypoints, + floor_markings_template=self.template_image, + offset=offset + ) + end = time.time() + print(f"Keypoint post processing time: {end - start}") + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + return results + + def _detect_keypoints_batch(self, batch_images: List[ndarray], + offset: int, n_keypoints: int) -> Dict[int, List[Tuple[int, int]]]: + """ + Phase 3: Keypoint detection for all frames in batch. + + Args: + batch_images: List of images to process + offset: Frame offset for numbering + n_keypoints: Number of keypoints expected + + Returns: + Dictionary mapping frame_id to list of keypoint coordinates + """ + keypoints: Dict[int, List[Tuple[int, int]]] = {} + keypoints_model_results = self.keypoints_model_yolo.predict(batch_images) + + if keypoints_model_results is None: + return keypoints + + for frame_idx_in_batch, detection in enumerate(keypoints_model_results): + if not hasattr(detection, "keypoints") or detection.keypoints is None: + continue + + # Extract keypoints with confidence + frame_keypoints_with_conf: List[Tuple[int, int, float]] = [] + for i, part_points in enumerate(detection.keypoints.data): + for k_id, (x, y, _) in enumerate(part_points): + confidence = float(detection.keypoints.conf[i][k_id]) + frame_keypoints_with_conf.append((int(x), int(y), confidence)) + + # Pad or truncate to expected number of keypoints + if len(frame_keypoints_with_conf) < n_keypoints: + frame_keypoints_with_conf.extend( + [(0, 0, 0.0)] * (n_keypoints - len(frame_keypoints_with_conf)) + ) + else: + frame_keypoints_with_conf = frame_keypoints_with_conf[:n_keypoints] + + # Filter keypoints based on confidence thresholds + filtered_keypoints: List[Tuple[int, int]] = [] + for idx, (x, y, confidence) in enumerate(frame_keypoints_with_conf): + if idx in self.CORNER_INDICES: + # Corner keypoints have lower confidence threshold + if confidence < 0.3: + filtered_keypoints.append((0, 0)) + else: + filtered_keypoints.append((int(x), int(y))) + else: + # Regular keypoints + if confidence < 0.5: + filtered_keypoints.append((0, 0)) + else: + filtered_keypoints.append((int(x), int(y))) + + frame_id = offset + frame_idx_in_batch + keypoints[frame_id] = filtered_keypoints + + return keypoints + + def predict_keypoints( + self, + images: List[ndarray], + n_keypoints: int = 32, + batch_size: Optional[int] = None, + conf_threshold: float = 0.5, + corner_conf_threshold: float = 0.3, + verbose: bool = False + ) -> Dict[int, List[Tuple[int, int]]]: + """ + Standalone function for keypoint detection on a list of images. + Optimized for maximum prediction speed. + + Args: + images: List of images (numpy arrays) to process + n_keypoints: Number of keypoints expected per frame (default: 32) + batch_size: Batch size for YOLO prediction (None = auto, uses all images) + conf_threshold: Confidence threshold for regular keypoints (default: 0.5) + corner_conf_threshold: Confidence threshold for corner keypoints (default: 0.3) + verbose: Whether to print progress information + + Returns: + Dictionary mapping frame index to list of keypoint coordinates (x, y) + Frame indices start from 0 + """ + if not images: + return {} + + keypoints: Dict[int, List[Tuple[int, int]]] = {} + + # Use provided batch_size or process all at once for maximum speed + if batch_size is None: + batch_size = len(images) + + # Process in batches for optimal GPU utilization + for batch_start in range(0, len(images), batch_size): + batch_end = min(batch_start + batch_size, len(images)) + batch_images = images[batch_start:batch_end] + + if verbose: + print(f"Processing keypoints batch {batch_start}-{batch_end-1} ({len(batch_images)} images)") + + # YOLO keypoint prediction (optimized batch processing) + keypoints_model_results = self.keypoints_model_yolo.predict( + batch_images, + verbose=False, + save=False, + conf=0.1, # Lower conf for detection, we filter later + ) + + if keypoints_model_results is None: + # Fill with empty keypoints for this batch + for frame_idx in range(batch_start, batch_end): + keypoints[frame_idx] = [(0, 0)] * n_keypoints + continue + + # Process each frame in the batch + for batch_idx, detection in enumerate(keypoints_model_results): + frame_idx = batch_start + batch_idx + + if not hasattr(detection, "keypoints") or detection.keypoints is None: + keypoints[frame_idx] = [(0, 0)] * n_keypoints + continue + + # Extract keypoints with confidence + frame_keypoints_with_conf: List[Tuple[int, int, float]] = [] + try: + for i, part_points in enumerate(detection.keypoints.data): + for k_id, (x, y, _) in enumerate(part_points): + confidence = float(detection.keypoints.conf[i][k_id]) + frame_keypoints_with_conf.append((int(x), int(y), confidence)) + except (AttributeError, IndexError, TypeError): + keypoints[frame_idx] = [(0, 0)] * n_keypoints + continue + + # Pad or truncate to expected number of keypoints + if len(frame_keypoints_with_conf) < n_keypoints: + frame_keypoints_with_conf.extend( + [(0, 0, 0.0)] * (n_keypoints - len(frame_keypoints_with_conf)) + ) + else: + frame_keypoints_with_conf = frame_keypoints_with_conf[:n_keypoints] + + # Filter keypoints based on confidence thresholds + filtered_keypoints: List[Tuple[int, int]] = [] + for idx, (x, y, confidence) in enumerate(frame_keypoints_with_conf): + if idx in self.CORNER_INDICES: + # Corner keypoints have lower confidence threshold + if confidence < corner_conf_threshold: + filtered_keypoints.append((0, 0)) + else: + filtered_keypoints.append((int(x), int(y))) + else: + # Regular keypoints + if confidence < conf_threshold: + filtered_keypoints.append((0, 0)) + else: + filtered_keypoints.append((int(x), int(y))) + + keypoints[frame_idx] = filtered_keypoints + + return keypoints + + def predict_objects( + self, + images: List[ndarray], + batch_size: Optional[int] = 16, + conf_threshold: float = 0.5, + iou_threshold: float = 0.45, + classes: Optional[List[int]] = None, + verbose: bool = False, + ) -> Dict[int, List[BoundingBox]]: + """ + Standalone high-throughput object detection function. + Runs the YOLO detector directly on raw images while skipping + any team-classification or keypoint stages for maximum FPS. + + Args: + images: List of frames (BGR numpy arrays). + batch_size: Number of frames per inference pass. Use None to process + all frames at once (fastest but highest memory usage). + conf_threshold: Detection confidence threshold. + iou_threshold: IoU threshold for NMS within YOLO. + classes: Optional list of class IDs to keep (None = all classes). + verbose: Whether to print per-batch progress from YOLO. + + Returns: + Dict mapping frame index -> list of BoundingBox predictions. + """ + if not images: + return {} + + detections: Dict[int, List[BoundingBox]] = {} + effective_batch = len(images) if batch_size is None else max(1, batch_size) + + for batch_start in range(0, len(images), effective_batch): + batch_end = min(batch_start + effective_batch, len(images)) + batch_images = images[batch_start:batch_end] + + start = time.time() + yolo_results = self.bbox_model( + batch_images, + conf=conf_threshold, + iou=iou_threshold, + classes=classes, + verbose=verbose, + save=False, + ) + end = time.time() + print(f"YOLO time: {end - start}") + + for local_idx, result in enumerate(yolo_results): + frame_idx = batch_start + local_idx + frame_boxes: List[BoundingBox] = [] + + if not hasattr(result, "boxes") or result.boxes is None: + detections[frame_idx] = frame_boxes + continue + + boxes_tensor = result.boxes.data + if boxes_tensor is None: + detections[frame_idx] = frame_boxes + continue + + for box in boxes_tensor: + try: + x1, y1, x2, y2, conf, cls_id = box.tolist() + frame_boxes.append( + BoundingBox( + x1=int(x1), + y1=int(y1), + x2=int(x2), + y2=int(y2), + cls_id=int(cls_id), + conf=float(conf), + ) + ) + except (ValueError, TypeError): + continue + + detections[frame_idx] = frame_boxes + + return detections + \ No newline at end of file diff --git a/miner1.py b/miner1.py new file mode 100644 index 0000000000000000000000000000000000000000..055d5cf272a6cc5b69f6cea2cf63160d8414b451 --- /dev/null +++ b/miner1.py @@ -0,0 +1,685 @@ +from pathlib import Path +from typing import List, Tuple, Dict, Optional +import sys +import os + +from numpy import ndarray +from pydantic import BaseModel + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from keypoint_helper import run_keypoints_post_processing +from keypoint_helper_v2 import run_keypoints_post_processing as run_keypoints_post_processing_v2 + +from ultralytics import YOLO +from team_cluster import TeamClassifier +from utils import ( + BoundingBox, + Constants, +) + +import time +import torch +import gc +import cv2 +import numpy as np +from collections import defaultdict +from pitch import process_batch_input, get_cls_net +from keypoint_evaluation import ( + evaluate_keypoints_for_frame, + evaluate_keypoints_for_frame_gpu, + load_template_from_file, + evaluate_keypoints_for_frame_opencv_cuda, + evaluate_keypoints_batch_for_frame, +) + +import yaml + + +class BoundingBox(BaseModel): + x1: int + y1: int + x2: int + y2: int + cls_id: int + conf: float + + +class TVFrameResult(BaseModel): + frame_id: int + boxes: List[BoundingBox] + keypoints: List[Tuple[int, int]] + + +class Miner: + SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA + SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX + SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT + CORNER_INDICES = Constants.CORNER_INDICES + KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE + CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE + GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN + MIN_SAMPLES_FOR_FIT = 16 # Minimum player crops needed before fitting TeamClassifier + MAX_SAMPLES_FOR_FIT = 600 # Maximum samples to avoid overfitting + + def __init__(self, path_hf_repo: Path) -> None: + try: + device = "cuda" if torch.cuda.is_available() else "cpu" + model_path = path_hf_repo / "detection.onnx" + self.bbox_model = YOLO(model_path) + + print(f"BBox Model Loaded: class name {self.bbox_model.names}") + + team_model_path = path_hf_repo / "osnet_model.pth.tar-100" + self.team_classifier = TeamClassifier( + device=device, + batch_size=32, + model_name=str(team_model_path) + ) + print("Team Classifier Loaded") + + # Team classification state + self.team_classifier_fitted = False + self.player_crops_for_fit = [] + + self.keypoints_model_yolo = YOLO(path_hf_repo / "keypoint.pt") + + model_kp_path = path_hf_repo / 'keypoint' + config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml' + cfg_kp = yaml.safe_load(open(config_kp_path, 'r')) + + loaded_state_kp = torch.load(model_kp_path, map_location=device) + model = get_cls_net(cfg_kp) + model.load_state_dict(loaded_state_kp) + model.to(device) + model.eval() + + self.keypoints_model = model + print("Keypoints Model (keypoint.pt) Loaded") + + template_image_path = path_hf_repo / "football_pitch_template.png" + self.template_image, self.template_keypoints = load_template_from_file(str(template_image_path)) + + self.kp_threshold = 0.1 + self.pitch_batch_size = 4 + self.health = "healthy" + + print("✅ Keypoints Model Loaded") + except Exception as e: + self.health = "❌ Miner initialization failed: " + str(e) + print(self.health) + + def __repr__(self) -> str: + if self.health == 'healthy': + return ( + f"health: {self.health}\n" + f"BBox Model: {type(self.bbox_model).__name__}\n" + f"Keypoints Model: {type(self.keypoints_model).__name__}" + ) + else: + return self.health + + def _calculate_iou(self, box1: Tuple[float, float, float, float], + box2: Tuple[float, float, float, float]) -> float: + """ + Calculate Intersection over Union (IoU) between two bounding boxes. + Args: + box1: (x1, y1, x2, y2) + box2: (x1, y1, x2, y2) + Returns: + IoU score (0-1) + """ + x1_1, y1_1, x2_1, y2_1 = box1 + x1_2, y1_2, x2_2, y2_2 = box2 + + # Calculate intersection area + x_left = max(x1_1, x1_2) + y_top = max(y1_1, y1_2) + x_right = min(x2_1, x2_2) + y_bottom = min(y2_1, y2_2) + + if x_right < x_left or y_bottom < y_top: + return 0.0 + + intersection_area = (x_right - x_left) * (y_bottom - y_top) + + # Calculate union area + box1_area = (x2_1 - x1_1) * (y2_1 - y1_1) + box2_area = (x2_2 - x1_2) * (y2_2 - y1_2) + union_area = box1_area + box2_area - intersection_area + + if union_area == 0: + return 0.0 + + return intersection_area / union_area + + def _extract_jersey_region(self, crop: ndarray) -> ndarray: + """ + Extract jersey region (upper body) from player crop. + For close-ups, focuses on upper 60%, for distant shots uses full crop. + """ + if crop is None or crop.size == 0: + return crop + + h, w = crop.shape[:2] + if h < 10 or w < 10: + return crop + + # For close-up shots, extract upper body (jersey region) + is_closeup = h > 100 or (h * w) > 12000 + if is_closeup: + # Upper 60% of the crop (jersey area, avoiding shorts) + jersey_top = 0 + jersey_bottom = int(h * 0.60) + jersey_left = max(0, int(w * 0.05)) + jersey_right = min(w, int(w * 0.95)) + return crop[jersey_top:jersey_bottom, jersey_left:jersey_right] + return crop + + def _extract_color_signature(self, crop: ndarray) -> Optional[np.ndarray]: + """ + Extract color signature from jersey region using HSV and LAB color spaces. + Returns a feature vector with dominant colors and color statistics. + """ + if crop is None or crop.size == 0: + return None + + jersey_region = self._extract_jersey_region(crop) + if jersey_region.size == 0: + return None + + try: + # Convert to HSV and LAB color spaces + hsv = cv2.cvtColor(jersey_region, cv2.COLOR_BGR2HSV) + lab = cv2.cvtColor(jersey_region, cv2.COLOR_BGR2LAB) + + # Reshape for processing + hsv_flat = hsv.reshape(-1, 3).astype(np.float32) + lab_flat = lab.reshape(-1, 3).astype(np.float32) + + # Compute statistics for HSV + hsv_mean = np.mean(hsv_flat, axis=0) / 255.0 + hsv_std = np.std(hsv_flat, axis=0) / 255.0 + + # Compute statistics for LAB + lab_mean = np.mean(lab_flat, axis=0) / 255.0 + lab_std = np.std(lab_flat, axis=0) / 255.0 + + # Dominant color (most frequent hue) + hue_hist, _ = np.histogram(hsv_flat[:, 0], bins=36, range=(0, 180)) + dominant_hue = np.argmax(hue_hist) * 5 # Convert to hue value + + # Combine features + color_features = np.concatenate([ + hsv_mean, + hsv_std, + lab_mean[:2], # L and A channels (B is less informative) + lab_std[:2], + [dominant_hue / 180.0] # Normalized dominant hue + ]) + + return color_features + except Exception as e: + print(f"Error extracting color signature: {e}") + return None + + def _get_spatial_position(self, bbox: Tuple[float, float, float, float], + frame_width: int, frame_height: int) -> Tuple[float, float]: + """ + Get normalized spatial position of player on the pitch. + Returns (x_normalized, y_normalized) where 0,0 is top-left. + """ + x1, y1, x2, y2 = bbox + center_x = (x1 + x2) / 2.0 + center_y = (y1 + y2) / 2.0 + + # Normalize to [0, 1] + x_norm = center_x / frame_width if frame_width > 0 else 0.5 + y_norm = center_y / frame_height if frame_height > 0 else 0.5 + + return (x_norm, y_norm) + + def _find_best_match(self, target_box: Tuple[float, float, float, float], + predicted_frame_data: Dict[int, Tuple[Tuple, str]], + iou_threshold: float) -> Tuple[Optional[str], float]: + """ + Find best matching box in predicted frame data using IoU. + """ + best_iou = 0.0 + best_team_id = None + + for idx, (bbox, team_cls_id) in predicted_frame_data.items(): + iou = self._calculate_iou(target_box, bbox) + if iou > best_iou and iou >= iou_threshold: + best_iou = iou + best_team_id = team_cls_id + + return (best_team_id, best_iou) + + def _detect_objects_batch(self, decoded_images: List[ndarray]) -> Dict[int, List[BoundingBox]]: + batch_size = 16 + detection_results = [] + n_frames = len(decoded_images) + for frame_number in range(0, n_frames, batch_size): + batch_images = decoded_images[frame_number: frame_number + batch_size] + detections = self.bbox_model(batch_images, verbose=False, save=False) + detection_results.extend(detections) + + return detection_results + + def _team_classify(self, detection_results, decoded_images, offset): + self.team_classifier_fitted = False + start = time.time() + # Collect player crops from first batch for fitting + fit_sample_size = 600 + player_crops_for_fit = [] + + for frame_id in range(len(detection_results)): + detection_box = detection_results[frame_id].boxes.data + if len(detection_box) < 4: + continue + # Collect player boxes for team classification fitting (first batch only) + if len(player_crops_for_fit) < fit_sample_size: + frame_image = decoded_images[frame_id] + for box in detection_box: + x1, y1, x2, y2, conf, cls_id = box.tolist() + if conf < 0.5: + continue + mapped_cls_id = str(int(cls_id)) + # Only collect player crops (cls_id = 2) + if mapped_cls_id == '2': + crop = frame_image[int(y1):int(y2), int(x1):int(x2)] + if crop.size > 0: + player_crops_for_fit.append(crop) + + # Fit team classifier after collecting samples + if self.team_classifier and not self.team_classifier_fitted and len(player_crops_for_fit) >= fit_sample_size: + print(f"Fitting TeamClassifier with {len(player_crops_for_fit)} player crops") + self.team_classifier.fit(player_crops_for_fit) + self.team_classifier_fitted = True + break + if not self.team_classifier_fitted and len(player_crops_for_fit) >= 16: + print(f"Fallback: Fitting TeamClassifier with {len(player_crops_for_fit)} player crops") + self.team_classifier.fit(player_crops_for_fit) + self.team_classifier_fitted = True + end = time.time() + print(f"Fitting Kmeans time: {end - start}") + + # Second pass: predict teams with configurable frame skipping optimization + start = time.time() + + # Get configuration for frame skipping + prediction_interval = 1 # Default: predict every 2 frames + iou_threshold = 0.3 + + print(f"Team classification - prediction_interval: {prediction_interval}, iou_threshold: {iou_threshold}") + + # Storage for predicted frame results: {frame_id: {box_idx: (bbox, team_id)}} + predicted_frame_data = {} + + # Step 1: Predict for frames at prediction_interval only + frames_to_predict = [] + for frame_id in range(len(detection_results)): + if frame_id % prediction_interval == 0: + frames_to_predict.append(frame_id) + + print(f"Predicting teams for {len(frames_to_predict)}/{len(detection_results)} frames " + f"(saving {100 - (len(frames_to_predict) * 100 // len(detection_results))}% compute)") + + for frame_id in frames_to_predict: + detection_box = detection_results[frame_id].boxes.data + frame_image = decoded_images[frame_id] + + # Collect player crops for this frame + frame_player_crops = [] + frame_player_indices = [] + frame_player_boxes = [] + + for idx, box in enumerate(detection_box): + x1, y1, x2, y2, conf, cls_id = box.tolist() + if cls_id == 2 and conf < 0.6: + continue + mapped_cls_id = str(int(cls_id)) + + # Collect player crops for prediction + if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2': + crop = frame_image[int(y1):int(y2), int(x1):int(x2)] + if crop.size > 0: + frame_player_crops.append(crop) + frame_player_indices.append(idx) + frame_player_boxes.append((x1, y1, x2, y2)) + + # Predict teams for all players in this frame + if len(frame_player_crops) > 0: + team_ids = self.team_classifier.predict(frame_player_crops) + predicted_frame_data[frame_id] = {} + for idx, bbox, team_id in zip(frame_player_indices, frame_player_boxes, team_ids): + # Map team_id (0,1) to cls_id (6,7) + team_cls_id = str(6 + int(team_id)) + predicted_frame_data[frame_id][idx] = (bbox, team_cls_id) + + # Step 2: Process all frames (interpolate skipped frames) + fallback_count = 0 + interpolated_count = 0 + bboxes: dict[int, list[BoundingBox]] = {} + for frame_id in range(len(detection_results)): + detection_box = detection_results[frame_id].boxes.data + frame_image = decoded_images[frame_id] + boxes = [] + + team_predictions = {} + + if frame_id % prediction_interval == 0: + # Predicted frame: use pre-computed predictions + if frame_id in predicted_frame_data: + for idx, (bbox, team_cls_id) in predicted_frame_data[frame_id].items(): + team_predictions[idx] = team_cls_id + else: + # Skipped frame: interpolate from neighboring predicted frames + # Find nearest predicted frames + prev_predicted_frame = (frame_id // prediction_interval) * prediction_interval + next_predicted_frame = prev_predicted_frame + prediction_interval + + # Collect current frame player boxes + for idx, box in enumerate(detection_box): + x1, y1, x2, y2, conf, cls_id = box.tolist() + if cls_id == 2 and conf < 0.6: + continue + mapped_cls_id = str(int(cls_id)) + + if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2': + target_box = (x1, y1, x2, y2) + + # Try to match with previous predicted frame + best_team_id = None + best_iou = 0.0 + + if prev_predicted_frame in predicted_frame_data: + team_id, iou = self._find_best_match( + target_box, + predicted_frame_data[prev_predicted_frame], + iou_threshold + ) + if team_id is not None: + best_team_id = team_id + best_iou = iou + + # Try to match with next predicted frame if available and no good match yet + if best_team_id is None and next_predicted_frame < len(detection_results): + if next_predicted_frame in predicted_frame_data: + team_id, iou = self._find_best_match( + target_box, + predicted_frame_data[next_predicted_frame], + iou_threshold + ) + if team_id is not None and iou > best_iou: + best_team_id = team_id + best_iou = iou + + # Track interpolation success + if best_team_id is not None: + interpolated_count += 1 + else: + # Fallback: if no match found, predict individually + crop = frame_image[int(y1):int(y2), int(x1):int(x2)] + if crop.size > 0: + team_id = self.team_classifier.predict([crop])[0] + best_team_id = str(6 + int(team_id)) + fallback_count += 1 + + if best_team_id is not None: + team_predictions[idx] = best_team_id + + # Parse boxes with team classification + for idx, box in enumerate(detection_box): + x1, y1, x2, y2, conf, cls_id = box.tolist() + if cls_id == 2 and conf < 0.6: + continue + + # Check overlap with staff box + overlap_staff = False + for idy, boxy in enumerate(detection_box): + s_x1, s_y1, s_x2, s_y2, s_conf, s_cls_id = boxy.tolist() + if cls_id == 2 and s_cls_id == 4: + staff_iou = self._calculate_iou(box[:4], boxy[:4]) + if staff_iou >= 0.8: + overlap_staff = True + break + if overlap_staff: + continue + + mapped_cls_id = str(int(cls_id)) + + # Override cls_id for players with team prediction + if idx in team_predictions: + mapped_cls_id = team_predictions[idx] + if mapped_cls_id != '4': + if int(mapped_cls_id) == 3 and conf < 0.5: + continue + boxes.append( + BoundingBox( + x1=int(x1), + y1=int(y1), + x2=int(x2), + y2=int(y2), + cls_id=int(mapped_cls_id), + conf=float(conf), + ) + ) + # Handle footballs - keep only the best one + footballs = [bb for bb in boxes if int(bb.cls_id) == 0] + if len(footballs) > 1: + best_ball = max(footballs, key=lambda b: b.conf) + boxes = [bb for bb in boxes if int(bb.cls_id) != 0] + boxes.append(best_ball) + + bboxes[offset + frame_id] = boxes + return bboxes + + + def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]: + print('=' * 10) + print(f"Offset: {offset}, Batch size: {len(batch_images)}") + print('=' * 10) + + start = time.time() + detection_results = self._detect_objects_batch(batch_images) + end = time.time() + print(f"Detection time: {end - start}") + + # Use hybrid team classification + start = time.time() + bboxes = self._team_classify(detection_results, batch_images, offset) + end = time.time() + print(f"Team classify time: {end - start}") + + # Phase 3: Keypoint Detection + keypoints_yolo: Dict[int, List[Tuple[int, int]]] = {} + + keypoints_yolo = self._detect_keypoints_batch(batch_images, offset, n_keypoints) + + + # pitch_batch_size = min(self.pitch_batch_size, len(batch_images)) + # keypoints: Dict[int, List[Tuple[int, int]]] = {} + + # start = time.time() + # while True: + # gc.collect() + # if torch.cuda.is_available(): + # torch.cuda.empty_cache() + # torch.cuda.synchronize() + # device_str = "cuda" + # keypoints_result = process_batch_input( + # batch_images, + # self.keypoints_model, + # self.kp_threshold, + # device_str, + # batch_size=pitch_batch_size, + # ) + # if keypoints_result is not None and len(keypoints_result) > 0: + # for frame_number_in_batch, kp_dict in enumerate(keypoints_result): + # if frame_number_in_batch >= len(batch_images): + # break + # frame_keypoints: List[Tuple[int, int]] = [] + # try: + # height, width = batch_images[frame_number_in_batch].shape[:2] + # if kp_dict is not None and isinstance(kp_dict, dict): + # for idx in range(32): + # x, y = 0, 0 + # kp_idx = idx + 1 + # if kp_idx in kp_dict: + # try: + # kp_data = kp_dict[kp_idx] + # if isinstance(kp_data, dict) and "x" in kp_data and "y" in kp_data: + # x = int(kp_data["x"] * width) + # y = int(kp_data["y"] * height) + # except (KeyError, TypeError, ValueError): + # pass + # frame_keypoints.append((x, y)) + # except (IndexError, ValueError, AttributeError): + # frame_keypoints = [(0, 0)] * 32 + # if len(frame_keypoints) < n_keypoints: + # frame_keypoints.extend([(0, 0)] * (n_keypoints - len(frame_keypoints))) + # else: + # frame_keypoints = frame_keypoints[:n_keypoints] + + # # time1 = time.time() + # # frame_keypoints_yolo = keypoints_yolo.get(offset + frame_number_in_batch, frame_keypoints) + + # # valid_keypoints_count = 0 + # # valid_keypoints_yolo_count = 0 + # # for kp in frame_keypoints: + # # if kp[0] != 0.0 or kp[1] != 0.0: + # # valid_keypoints_count += 1 + # # if valid_keypoints_count > 3: + # # break + + # # for kp in frame_keypoints_yolo: + # # if kp[0] != 0.0 or kp[1] != 0.0: + # # valid_keypoints_yolo_count += 1 + # # if valid_keypoints_yolo_count > 3: + # # break + + # # # Evaluate and select best keypoints (using batch evaluation for speed) + # # if valid_keypoints_count > 3 and valid_keypoints_yolo_count > 3: + # # try: + # # # Evaluate both keypoint sets in batch (much faster!) + # # scores = evaluate_keypoints_batch_for_frame( + # # template_keypoints=self.template_keypoints, + # # frame_keypoints_list=[frame_keypoints, frame_keypoints_yolo], + # # frame=batch_images[frame_number_in_batch], + # # floor_markings_template=self.template_image, + # # device="cuda" + # # ) + # # score = scores[0] + # # score_yolo = scores[1] + + # # # Select the one with higher score + # # if score_yolo > score: + # # frame_keypoints = frame_keypoints_yolo + # # except Exception as e: + # # # Fallback: use YOLO if available, otherwise use pitch model + # # if valid_keypoints_yolo_count > 3: + # # frame_keypoints = frame_keypoints_yolo + # # elif valid_keypoints_yolo_count > 3: + # # # Only YOLO has valid keypoints + # # frame_keypoints = frame_keypoints_yolo + # # time2 = time.time() + # # print(f"Keypoint evaluation time: {time2 - time1}") + + # keypoints[offset + frame_number_in_batch] = frame_keypoints + # break + # end = time.time() + # print(f"Keypoint time: {end - start}") + + results: List[TVFrameResult] = [] + for frame_number in range(offset, offset + len(batch_images)): + frame_boxes = bboxes.get(frame_number, []) + result = TVFrameResult( + frame_id=frame_number, + boxes=frame_boxes, + keypoints=keypoints_yolo.get( + frame_number, + [(0, 0) for _ in range(n_keypoints)], + ), + ) + results.append(result) + + start = time.time() + if len(batch_images) > 0: + h, w = batch_images[0].shape[:2] + results = run_keypoints_post_processing_v2( + results, w, h, + frames=batch_images, + template_keypoints=self.template_keypoints, + floor_markings_template=self.template_image, + offset=offset + ) + end = time.time() + print(f"Keypoint post processing time: {end - start}") + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + return results + + def _detect_keypoints_batch(self, batch_images: List[ndarray], + offset: int, n_keypoints: int) -> Dict[int, List[Tuple[int, int]]]: + """ + Phase 3: Keypoint detection for all frames in batch. + + Args: + batch_images: List of images to process + offset: Frame offset for numbering + n_keypoints: Number of keypoints expected + + Returns: + Dictionary mapping frame_id to list of keypoint coordinates + """ + keypoints: Dict[int, List[Tuple[int, int]]] = {} + keypoints_model_results = self.keypoints_model_yolo.predict(batch_images) + + if keypoints_model_results is None: + return keypoints + + for frame_idx_in_batch, detection in enumerate(keypoints_model_results): + if not hasattr(detection, "keypoints") or detection.keypoints is None: + continue + + # Extract keypoints with confidence + frame_keypoints_with_conf: List[Tuple[int, int, float]] = [] + for i, part_points in enumerate(detection.keypoints.data): + for k_id, (x, y, _) in enumerate(part_points): + confidence = float(detection.keypoints.conf[i][k_id]) + frame_keypoints_with_conf.append((int(x), int(y), confidence)) + + # Pad or truncate to expected number of keypoints + if len(frame_keypoints_with_conf) < n_keypoints: + frame_keypoints_with_conf.extend( + [(0, 0, 0.0)] * (n_keypoints - len(frame_keypoints_with_conf)) + ) + else: + frame_keypoints_with_conf = frame_keypoints_with_conf[:n_keypoints] + + # Filter keypoints based on confidence thresholds + filtered_keypoints: List[Tuple[int, int]] = [] + for idx, (x, y, confidence) in enumerate(frame_keypoints_with_conf): + if idx in self.CORNER_INDICES: + # Corner keypoints have lower confidence threshold + if confidence < 0.3: + filtered_keypoints.append((0, 0)) + else: + filtered_keypoints.append((int(x), int(y))) + else: + # Regular keypoints + if confidence < 0.5: + filtered_keypoints.append((0, 0)) + else: + filtered_keypoints.append((int(x), int(y))) + + frame_id = offset + frame_idx_in_batch + keypoints[frame_id] = filtered_keypoints + + return keypoints + \ No newline at end of file diff --git a/miner2.py b/miner2.py new file mode 100644 index 0000000000000000000000000000000000000000..af87f6e7c11fa6c18871b65664535113b9f7a9c6 --- /dev/null +++ b/miner2.py @@ -0,0 +1,953 @@ +from pathlib import Path +from typing import List, Tuple, Dict, Optional +import sys +import os + +from numpy import ndarray +from pydantic import BaseModel + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from keypoint_helper import run_keypoints_post_processing +from keypoint_helper_v2 import run_keypoints_post_processing as run_keypoints_post_processing_v2 + +from ultralytics import YOLO +from team_cluster import TeamClassifier +from utils import ( + BoundingBox, + Constants, +) + +import time +import torch +import gc +import cv2 +import numpy as np +from collections import defaultdict +from pitch import process_batch_input, get_cls_net +from keypoint_evaluation import ( + evaluate_keypoints_for_frame, + evaluate_keypoints_for_frame_gpu, + load_template_from_file, + evaluate_keypoints_for_frame_opencv_cuda, + evaluate_keypoints_batch_for_frame, +) + +import yaml + + +class BoundingBox(BaseModel): + x1: int + y1: int + x2: int + y2: int + cls_id: int + conf: float + + +class TVFrameResult(BaseModel): + frame_id: int + boxes: List[BoundingBox] + keypoints: List[Tuple[int, int]] + + +class Miner: + SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA + SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX + SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT + CORNER_INDICES = Constants.CORNER_INDICES + KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE + CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE + GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN + MIN_SAMPLES_FOR_FIT = 16 # Minimum player crops needed before fitting TeamClassifier + MAX_SAMPLES_FOR_FIT = 600 # Maximum samples to avoid overfitting + + def __init__(self, path_hf_repo: Path) -> None: + try: + device = "cuda" if torch.cuda.is_available() else "cpu" + model_path = path_hf_repo / "detection.onnx" + self.bbox_model = YOLO(model_path) + + print(f"BBox Model Loaded: class name {self.bbox_model.names}") + + team_model_path = path_hf_repo / "osnet_model.pth.tar-100" + self.team_classifier = TeamClassifier( + device=device, + batch_size=32, + model_name=str(team_model_path) + ) + print("Team Classifier Loaded") + + self.last_score = 0 + self.last_valid_keypoints = None + # Team classification state + self.team_classifier_fitted = False + self.player_crops_for_fit = [] + + self.keypoints_model_yolo = YOLO(path_hf_repo / "keypoint.pt") + + model_kp_path = path_hf_repo / 'keypoint' + config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml' + cfg_kp = yaml.safe_load(open(config_kp_path, 'r')) + + loaded_state_kp = torch.load(model_kp_path, map_location=device) + model = get_cls_net(cfg_kp) + model.load_state_dict(loaded_state_kp) + model.to(device) + model.eval() + + self.keypoints_model = model + print("Keypoints Model (keypoint.pt) Loaded") + + template_image_path = path_hf_repo / "football_pitch_template.png" + self.template_image, self.template_keypoints = load_template_from_file(str(template_image_path)) + + self.kp_threshold = 0.1 + self.pitch_batch_size = 4 + self.health = "healthy" + + print("✅ Keypoints Model Loaded") + except Exception as e: + self.health = "❌ Miner initialization failed: " + str(e) + print(self.health) + + def __repr__(self) -> str: + if self.health == 'healthy': + return ( + f"health: {self.health}\n" + f"BBox Model: {type(self.bbox_model).__name__}\n" + f"Keypoints Model: {type(self.keypoints_model).__name__}" + ) + else: + return self.health + + def _calculate_iou(self, box1: Tuple[float, float, float, float], + box2: Tuple[float, float, float, float]) -> float: + """ + Calculate Intersection over Union (IoU) between two bounding boxes. + Args: + box1: (x1, y1, x2, y2) + box2: (x1, y1, x2, y2) + Returns: + IoU score (0-1) + """ + x1_1, y1_1, x2_1, y2_1 = box1 + x1_2, y1_2, x2_2, y2_2 = box2 + + # Calculate intersection area + x_left = max(x1_1, x1_2) + y_top = max(y1_1, y1_2) + x_right = min(x2_1, x2_2) + y_bottom = min(y2_1, y2_2) + + if x_right < x_left or y_bottom < y_top: + return 0.0 + + intersection_area = (x_right - x_left) * (y_bottom - y_top) + + # Calculate union area + box1_area = (x2_1 - x1_1) * (y2_1 - y1_1) + box2_area = (x2_2 - x1_2) * (y2_2 - y1_2) + union_area = box1_area + box2_area - intersection_area + + if union_area == 0: + return 0.0 + + return intersection_area / union_area + + def _extract_jersey_region(self, crop: ndarray) -> ndarray: + """ + Extract jersey region (upper body) from player crop. + For close-ups, focuses on upper 60%, for distant shots uses full crop. + """ + if crop is None or crop.size == 0: + return crop + + h, w = crop.shape[:2] + if h < 10 or w < 10: + return crop + + # For close-up shots, extract upper body (jersey region) + is_closeup = h > 100 or (h * w) > 12000 + if is_closeup: + # Upper 60% of the crop (jersey area, avoiding shorts) + jersey_top = 0 + jersey_bottom = int(h * 0.60) + jersey_left = max(0, int(w * 0.05)) + jersey_right = min(w, int(w * 0.95)) + return crop[jersey_top:jersey_bottom, jersey_left:jersey_right] + return crop + + def _extract_color_signature(self, crop: ndarray) -> Optional[np.ndarray]: + """ + Extract color signature from jersey region using HSV and LAB color spaces. + Returns a feature vector with dominant colors and color statistics. + """ + if crop is None or crop.size == 0: + return None + + jersey_region = self._extract_jersey_region(crop) + if jersey_region.size == 0: + return None + + try: + # Convert to HSV and LAB color spaces + hsv = cv2.cvtColor(jersey_region, cv2.COLOR_BGR2HSV) + lab = cv2.cvtColor(jersey_region, cv2.COLOR_BGR2LAB) + + # Reshape for processing + hsv_flat = hsv.reshape(-1, 3).astype(np.float32) + lab_flat = lab.reshape(-1, 3).astype(np.float32) + + # Compute statistics for HSV + hsv_mean = np.mean(hsv_flat, axis=0) / 255.0 + hsv_std = np.std(hsv_flat, axis=0) / 255.0 + + # Compute statistics for LAB + lab_mean = np.mean(lab_flat, axis=0) / 255.0 + lab_std = np.std(lab_flat, axis=0) / 255.0 + + # Dominant color (most frequent hue) + hue_hist, _ = np.histogram(hsv_flat[:, 0], bins=36, range=(0, 180)) + dominant_hue = np.argmax(hue_hist) * 5 # Convert to hue value + + # Combine features + color_features = np.concatenate([ + hsv_mean, + hsv_std, + lab_mean[:2], # L and A channels (B is less informative) + lab_std[:2], + [dominant_hue / 180.0] # Normalized dominant hue + ]) + + return color_features + except Exception as e: + print(f"Error extracting color signature: {e}") + return None + + def _get_spatial_position(self, bbox: Tuple[float, float, float, float], + frame_width: int, frame_height: int) -> Tuple[float, float]: + """ + Get normalized spatial position of player on the pitch. + Returns (x_normalized, y_normalized) where 0,0 is top-left. + """ + x1, y1, x2, y2 = bbox + center_x = (x1 + x2) / 2.0 + center_y = (y1 + y2) / 2.0 + + # Normalize to [0, 1] + x_norm = center_x / frame_width if frame_width > 0 else 0.5 + y_norm = center_y / frame_height if frame_height > 0 else 0.5 + + return (x_norm, y_norm) + + def _find_best_match(self, target_box: Tuple[float, float, float, float], + predicted_frame_data: Dict[int, Tuple[Tuple, str]], + iou_threshold: float) -> Tuple[Optional[str], float]: + """ + Find best matching box in predicted frame data using IoU. + """ + best_iou = 0.0 + best_team_id = None + + for idx, (bbox, team_cls_id) in predicted_frame_data.items(): + iou = self._calculate_iou(target_box, bbox) + if iou > best_iou and iou >= iou_threshold: + best_iou = iou + best_team_id = team_cls_id + + return (best_team_id, best_iou) + + def _detect_objects_batch(self, decoded_images: List[ndarray]) -> Dict[int, List[BoundingBox]]: + batch_size = 16 + detection_results = [] + n_frames = len(decoded_images) + for frame_number in range(0, n_frames, batch_size): + batch_images = decoded_images[frame_number: frame_number + batch_size] + detections = self.bbox_model(batch_images, verbose=False, save=False) + detection_results.extend(detections) + + return detection_results + + def _team_classify(self, detection_results, decoded_images, offset): + self.team_classifier_fitted = False + start = time.time() + # Collect player crops from first batch for fitting + fit_sample_size = 600 + player_crops_for_fit = [] + + for frame_id in range(len(detection_results)): + detection_box = detection_results[frame_id].boxes.data + if len(detection_box) < 4: + continue + # Collect player boxes for team classification fitting (first batch only) + if len(player_crops_for_fit) < fit_sample_size: + frame_image = decoded_images[frame_id] + for box in detection_box: + x1, y1, x2, y2, conf, cls_id = box.tolist() + if conf < 0.5: + continue + mapped_cls_id = str(int(cls_id)) + # Only collect player crops (cls_id = 2) + if mapped_cls_id == '2': + crop = frame_image[int(y1):int(y2), int(x1):int(x2)] + if crop.size > 0: + player_crops_for_fit.append(crop) + + # Fit team classifier after collecting samples + if self.team_classifier and not self.team_classifier_fitted and len(player_crops_for_fit) >= fit_sample_size: + print(f"Fitting TeamClassifier with {len(player_crops_for_fit)} player crops") + self.team_classifier.fit(player_crops_for_fit) + self.team_classifier_fitted = True + break + if not self.team_classifier_fitted and len(player_crops_for_fit) >= 16: + print(f"Fallback: Fitting TeamClassifier with {len(player_crops_for_fit)} player crops") + self.team_classifier.fit(player_crops_for_fit) + self.team_classifier_fitted = True + end = time.time() + print(f"Fitting Kmeans time: {end - start}") + + # Second pass: predict teams with configurable frame skipping optimization + start = time.time() + + # Get configuration for frame skipping + prediction_interval = 1 # Default: predict every 2 frames + iou_threshold = 0.3 + + print(f"Team classification - prediction_interval: {prediction_interval}, iou_threshold: {iou_threshold}") + + # Storage for predicted frame results: {frame_id: {box_idx: (bbox, team_id)}} + predicted_frame_data = {} + + # Step 1: Predict for frames at prediction_interval only + frames_to_predict = [] + for frame_id in range(len(detection_results)): + if frame_id % prediction_interval == 0: + frames_to_predict.append(frame_id) + + print(f"Predicting teams for {len(frames_to_predict)}/{len(detection_results)} frames " + f"(saving {100 - (len(frames_to_predict) * 100 // len(detection_results))}% compute)") + + for frame_id in frames_to_predict: + detection_box = detection_results[frame_id].boxes.data + frame_image = decoded_images[frame_id] + + # Collect player crops for this frame + frame_player_crops = [] + frame_player_indices = [] + frame_player_boxes = [] + + for idx, box in enumerate(detection_box): + x1, y1, x2, y2, conf, cls_id = box.tolist() + if cls_id == 2 and conf < 0.6: + continue + mapped_cls_id = str(int(cls_id)) + + # Collect player crops for prediction + if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2': + crop = frame_image[int(y1):int(y2), int(x1):int(x2)] + if crop.size > 0: + frame_player_crops.append(crop) + frame_player_indices.append(idx) + frame_player_boxes.append((x1, y1, x2, y2)) + + # Predict teams for all players in this frame + if len(frame_player_crops) > 0: + team_ids = self.team_classifier.predict(frame_player_crops) + predicted_frame_data[frame_id] = {} + for idx, bbox, team_id in zip(frame_player_indices, frame_player_boxes, team_ids): + # Map team_id (0,1) to cls_id (6,7) + team_cls_id = str(6 + int(team_id)) + predicted_frame_data[frame_id][idx] = (bbox, team_cls_id) + + # Step 2: Process all frames (interpolate skipped frames) + fallback_count = 0 + interpolated_count = 0 + bboxes: dict[int, list[BoundingBox]] = {} + for frame_id in range(len(detection_results)): + detection_box = detection_results[frame_id].boxes.data + frame_image = decoded_images[frame_id] + boxes = [] + + team_predictions = {} + + if frame_id % prediction_interval == 0: + # Predicted frame: use pre-computed predictions + if frame_id in predicted_frame_data: + for idx, (bbox, team_cls_id) in predicted_frame_data[frame_id].items(): + team_predictions[idx] = team_cls_id + else: + # Skipped frame: interpolate from neighboring predicted frames + # Find nearest predicted frames + prev_predicted_frame = (frame_id // prediction_interval) * prediction_interval + next_predicted_frame = prev_predicted_frame + prediction_interval + + # Collect current frame player boxes + for idx, box in enumerate(detection_box): + x1, y1, x2, y2, conf, cls_id = box.tolist() + if cls_id == 2 and conf < 0.6: + continue + mapped_cls_id = str(int(cls_id)) + + if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2': + target_box = (x1, y1, x2, y2) + + # Try to match with previous predicted frame + best_team_id = None + best_iou = 0.0 + + if prev_predicted_frame in predicted_frame_data: + team_id, iou = self._find_best_match( + target_box, + predicted_frame_data[prev_predicted_frame], + iou_threshold + ) + if team_id is not None: + best_team_id = team_id + best_iou = iou + + # Try to match with next predicted frame if available and no good match yet + if best_team_id is None and next_predicted_frame < len(detection_results): + if next_predicted_frame in predicted_frame_data: + team_id, iou = self._find_best_match( + target_box, + predicted_frame_data[next_predicted_frame], + iou_threshold + ) + if team_id is not None and iou > best_iou: + best_team_id = team_id + best_iou = iou + + # Track interpolation success + if best_team_id is not None: + interpolated_count += 1 + else: + # Fallback: if no match found, predict individually + crop = frame_image[int(y1):int(y2), int(x1):int(x2)] + if crop.size > 0: + team_id = self.team_classifier.predict([crop])[0] + best_team_id = str(6 + int(team_id)) + fallback_count += 1 + + if best_team_id is not None: + team_predictions[idx] = best_team_id + + # Parse boxes with team classification + for idx, box in enumerate(detection_box): + x1, y1, x2, y2, conf, cls_id = box.tolist() + if cls_id == 2 and conf < 0.6: + continue + + # Check overlap with staff box + overlap_staff = False + for idy, boxy in enumerate(detection_box): + s_x1, s_y1, s_x2, s_y2, s_conf, s_cls_id = boxy.tolist() + if cls_id == 2 and s_cls_id == 4: + staff_iou = self._calculate_iou(box[:4], boxy[:4]) + if staff_iou >= 0.8: + overlap_staff = True + break + if overlap_staff: + continue + + mapped_cls_id = str(int(cls_id)) + + # Override cls_id for players with team prediction + if idx in team_predictions: + mapped_cls_id = team_predictions[idx] + if mapped_cls_id != '4': + if int(mapped_cls_id) == 3 and conf < 0.5: + continue + boxes.append( + BoundingBox( + x1=int(x1), + y1=int(y1), + x2=int(x2), + y2=int(y2), + cls_id=int(mapped_cls_id), + conf=float(conf), + ) + ) + # Handle footballs - keep only the best one + footballs = [bb for bb in boxes if int(bb.cls_id) == 0] + if len(footballs) > 1: + best_ball = max(footballs, key=lambda b: b.conf) + boxes = [bb for bb in boxes if int(bb.cls_id) != 0] + boxes.append(best_ball) + + bboxes[offset + frame_id] = boxes + return bboxes + + + def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]: + start = time.time() + detection_results = self._detect_objects_batch(batch_images) + end = time.time() + print(f"Detection time: {end - start}") + + # Use hybrid team classification + start = time.time() + bboxes = self._team_classify(detection_results, batch_images, offset) + end = time.time() + print(f"Team classify time: {end - start}") + + # Phase 3: Keypoint Detection + start = time.time() + keypoints_yolo: Dict[int, List[Tuple[int, int]]] = {} + + keypoints_yolo = self._detect_keypoints_batch(batch_images, offset, n_keypoints) + + + pitch_batch_size = min(self.pitch_batch_size, len(batch_images)) + keypoints: Dict[int, List[Tuple[int, int]]] = {} + + start = time.time() + + while True: + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + device_str = "cuda" + keypoints_result = process_batch_input( + batch_images, + self.keypoints_model, + self.kp_threshold, + device_str, + batch_size=pitch_batch_size, + ) + if keypoints_result is not None and len(keypoints_result) > 0: + for frame_number_in_batch, kp_dict in enumerate(keypoints_result): + if frame_number_in_batch >= len(batch_images): + break + frame_keypoints: List[Tuple[int, int]] = [] + try: + height, width = batch_images[frame_number_in_batch].shape[:2] + if kp_dict is not None and isinstance(kp_dict, dict): + for idx in range(32): + x, y = 0, 0 + kp_idx = idx + 1 + if kp_idx in kp_dict: + try: + kp_data = kp_dict[kp_idx] + if isinstance(kp_data, dict) and "x" in kp_data and "y" in kp_data: + x = int(kp_data["x"] * width) + y = int(kp_data["y"] * height) + except (KeyError, TypeError, ValueError): + pass + frame_keypoints.append((x, y)) + except (IndexError, ValueError, AttributeError): + frame_keypoints = [(0, 0)] * 32 + if len(frame_keypoints) < n_keypoints: + frame_keypoints.extend([(0, 0)] * (n_keypoints - len(frame_keypoints))) + else: + frame_keypoints = frame_keypoints[:n_keypoints] + + # time1 = time.time() + # frame_keypoints_yolo = keypoints_yolo.get(offset + frame_number_in_batch, frame_keypoints) + + # valid_keypoints_count = 0 + # valid_keypoints_yolo_count = 0 + # for kp in frame_keypoints: + # if kp[0] != 0.0 or kp[1] != 0.0: + # valid_keypoints_count += 1 + # if valid_keypoints_count > 3: + # break + + # for kp in frame_keypoints_yolo: + # if kp[0] != 0.0 or kp[1] != 0.0: + # valid_keypoints_yolo_count += 1 + # if valid_keypoints_yolo_count > 3: + # break + + # # Evaluate and select best keypoints (using batch evaluation for speed) + # if valid_keypoints_count > 3 and valid_keypoints_yolo_count > 3: + # try: + # last_valid_keypoints = keypoints.get(offset + frame_number_in_batch - 1, frame_keypoints) + # # Evaluate both keypoint sets in batch (much faster!) + # scores = evaluate_keypoints_batch_for_frame( + # template_keypoints=self.template_keypoints, + # frame_keypoints_list=[frame_keypoints, frame_keypoints_yolo, last_valid_keypoints], + # frame=batch_images[frame_number_in_batch], + # floor_markings_template=self.template_image, + # device="cuda" + # ) + # score = scores[0] + # score_yolo = scores[1] + # last_score = scores[2] + + # if last_score > score and last_score > score_yolo: + # frame_keypoints = last_valid_keypoints + # if score_yolo > score: + # frame_keypoints = frame_keypoints_yolo + # last_score = score_yolo + # else: + # last_score = score + + # last_valid_keypoints = frame_keypoints + + # except Exception as e: + # # Fallback: use YOLO if available, otherwise use pitch model + # if valid_keypoints_yolo_count > 3: + # frame_keypoints = frame_keypoints_yolo + # elif valid_keypoints_yolo_count > 3: + # # Only YOLO has valid keypoints + # frame_keypoints = frame_keypoints_yolo + # else: + # if last_valid_keypoints is not None: + # frame_keypoints = last_valid_keypoints + + # time2 = time.time() + # print(f"Keypoint evaluation time: {time2 - time1}") + + keypoints[offset + frame_number_in_batch] = frame_keypoints + break + end = time.time() + print(f"Keypoint time: {end - start}") + + results: List[TVFrameResult] = [] + for frame_number in range(offset, offset + len(batch_images)): + frame_boxes = bboxes.get(frame_number, []) + result = TVFrameResult( + frame_id=frame_number, + boxes=frame_boxes, + keypoints=keypoints.get( + frame_number, + [(0, 0) for _ in range(n_keypoints)], + ), + ) + results.append(result) + + results_yolo: List[TVFrameResult] = [] + for frame_number in range(offset, offset + len(batch_images)): + frame_boxes = bboxes.get(frame_number, []) + result = TVFrameResult( + frame_id=frame_number, + boxes=frame_boxes, + keypoints=keypoints_yolo.get( + frame_number, + [(0, 0) for _ in range(n_keypoints)], + ), + ) + results_yolo.append(result) + + start = time.time() + if len(batch_images) > 0: + h, w = batch_images[0].shape[:2] + results = run_keypoints_post_processing_v2( + results, w, h, + frames=batch_images, + template_keypoints=self.template_keypoints, + floor_markings_template=self.template_image, + offset=offset + ) + results_yolo = run_keypoints_post_processing_v2( + results_yolo, w, h, + frames=batch_images, + template_keypoints=self.template_keypoints, + floor_markings_template=self.template_image, + offset=offset + ) + end = time.time() + print(f"Keypoint post processing time: {end - start}") + + final_keypoints: Dict[int, List[Tuple[int, int]]] = {} + + for frame_number_in_batch, (result, result_yolo) in enumerate(zip(results, results_yolo)): + frame_keypoints = result.keypoints + try: + if self.last_valid_keypoints is None: + self.last_valid_keypoints = final_keypoints.get(offset + frame_number_in_batch - 1, self.last_valid_keypoints) + # Evaluate both keypoint sets in batch (much faster!) + scores = evaluate_keypoints_batch_for_frame( + template_keypoints=self.template_keypoints, + frame_keypoints_list=[result.keypoints, result_yolo.keypoints, self.last_valid_keypoints], + frame=batch_images[frame_number_in_batch], + floor_markings_template=self.template_image, + device="cuda" + ) + score = scores[0] + score_yolo = scores[1] + self.last_score = scores[2] + + if self.last_score > score and self.last_score > score_yolo: + frame_keypoints = self.last_valid_keypoints + elif score_yolo > score: + frame_keypoints = result_yolo.keypoints + self.last_score = score_yolo + else: + self.last_score = score + + + except Exception as e: + # Fallback: use YOLO if available, otherwise use pitch model + print('Error: ', e) + + self.last_valid_keypoints = frame_keypoints + + final_keypoints[offset + frame_number_in_batch] = frame_keypoints + + + final_results: List[TVFrameResult] = [] + for frame_number in range(offset, offset + len(batch_images)): + frame_boxes = bboxes.get(frame_number, []) + result = TVFrameResult( + frame_id=frame_number, + boxes=frame_boxes, + keypoints=final_keypoints.get( + frame_number, + [(0, 0) for _ in range(n_keypoints)], + ), + ) + final_results.append(result) + + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + return final_results + + def _detect_keypoints_batch(self, batch_images: List[ndarray], + offset: int, n_keypoints: int) -> Dict[int, List[Tuple[int, int]]]: + """ + Phase 3: Keypoint detection for all frames in batch. + + Args: + batch_images: List of images to process + offset: Frame offset for numbering + n_keypoints: Number of keypoints expected + + Returns: + Dictionary mapping frame_id to list of keypoint coordinates + """ + keypoints: Dict[int, List[Tuple[int, int]]] = {} + keypoints_model_results = self.keypoints_model_yolo.predict(batch_images) + + if keypoints_model_results is None: + return keypoints + + for frame_idx_in_batch, detection in enumerate(keypoints_model_results): + if not hasattr(detection, "keypoints") or detection.keypoints is None: + continue + + # Extract keypoints with confidence + frame_keypoints_with_conf: List[Tuple[int, int, float]] = [] + for i, part_points in enumerate(detection.keypoints.data): + for k_id, (x, y, _) in enumerate(part_points): + confidence = float(detection.keypoints.conf[i][k_id]) + frame_keypoints_with_conf.append((int(x), int(y), confidence)) + + # Pad or truncate to expected number of keypoints + if len(frame_keypoints_with_conf) < n_keypoints: + frame_keypoints_with_conf.extend( + [(0, 0, 0.0)] * (n_keypoints - len(frame_keypoints_with_conf)) + ) + else: + frame_keypoints_with_conf = frame_keypoints_with_conf[:n_keypoints] + + # Filter keypoints based on confidence thresholds + filtered_keypoints: List[Tuple[int, int]] = [] + for idx, (x, y, confidence) in enumerate(frame_keypoints_with_conf): + if idx in self.CORNER_INDICES: + # Corner keypoints have lower confidence threshold + if confidence < 0.3: + filtered_keypoints.append((0, 0)) + else: + filtered_keypoints.append((int(x), int(y))) + else: + # Regular keypoints + if confidence < 0.5: + filtered_keypoints.append((0, 0)) + else: + filtered_keypoints.append((int(x), int(y))) + + frame_id = offset + frame_idx_in_batch + keypoints[frame_id] = filtered_keypoints + + return keypoints + + def predict_keypoints( + self, + images: List[ndarray], + n_keypoints: int = 32, + batch_size: Optional[int] = None, + conf_threshold: float = 0.5, + corner_conf_threshold: float = 0.3, + verbose: bool = False + ) -> Dict[int, List[Tuple[int, int]]]: + """ + Standalone function for keypoint detection on a list of images. + Optimized for maximum prediction speed. + + Args: + images: List of images (numpy arrays) to process + n_keypoints: Number of keypoints expected per frame (default: 32) + batch_size: Batch size for YOLO prediction (None = auto, uses all images) + conf_threshold: Confidence threshold for regular keypoints (default: 0.5) + corner_conf_threshold: Confidence threshold for corner keypoints (default: 0.3) + verbose: Whether to print progress information + + Returns: + Dictionary mapping frame index to list of keypoint coordinates (x, y) + Frame indices start from 0 + """ + if not images: + return {} + + keypoints: Dict[int, List[Tuple[int, int]]] = {} + + # Use provided batch_size or process all at once for maximum speed + if batch_size is None: + batch_size = len(images) + + # Process in batches for optimal GPU utilization + for batch_start in range(0, len(images), batch_size): + batch_end = min(batch_start + batch_size, len(images)) + batch_images = images[batch_start:batch_end] + + if verbose: + print(f"Processing keypoints batch {batch_start}-{batch_end-1} ({len(batch_images)} images)") + + # YOLO keypoint prediction (optimized batch processing) + keypoints_model_results = self.keypoints_model_yolo.predict( + batch_images, + verbose=False, + save=False, + conf=0.1, # Lower conf for detection, we filter later + ) + + if keypoints_model_results is None: + # Fill with empty keypoints for this batch + for frame_idx in range(batch_start, batch_end): + keypoints[frame_idx] = [(0, 0)] * n_keypoints + continue + + # Process each frame in the batch + for batch_idx, detection in enumerate(keypoints_model_results): + frame_idx = batch_start + batch_idx + + if not hasattr(detection, "keypoints") or detection.keypoints is None: + keypoints[frame_idx] = [(0, 0)] * n_keypoints + continue + + # Extract keypoints with confidence + frame_keypoints_with_conf: List[Tuple[int, int, float]] = [] + try: + for i, part_points in enumerate(detection.keypoints.data): + for k_id, (x, y, _) in enumerate(part_points): + confidence = float(detection.keypoints.conf[i][k_id]) + frame_keypoints_with_conf.append((int(x), int(y), confidence)) + except (AttributeError, IndexError, TypeError): + keypoints[frame_idx] = [(0, 0)] * n_keypoints + continue + + # Pad or truncate to expected number of keypoints + if len(frame_keypoints_with_conf) < n_keypoints: + frame_keypoints_with_conf.extend( + [(0, 0, 0.0)] * (n_keypoints - len(frame_keypoints_with_conf)) + ) + else: + frame_keypoints_with_conf = frame_keypoints_with_conf[:n_keypoints] + + # Filter keypoints based on confidence thresholds + filtered_keypoints: List[Tuple[int, int]] = [] + for idx, (x, y, confidence) in enumerate(frame_keypoints_with_conf): + if idx in self.CORNER_INDICES: + # Corner keypoints have lower confidence threshold + if confidence < corner_conf_threshold: + filtered_keypoints.append((0, 0)) + else: + filtered_keypoints.append((int(x), int(y))) + else: + # Regular keypoints + if confidence < conf_threshold: + filtered_keypoints.append((0, 0)) + else: + filtered_keypoints.append((int(x), int(y))) + + keypoints[frame_idx] = filtered_keypoints + + return keypoints + + def predict_objects( + self, + images: List[ndarray], + batch_size: Optional[int] = 16, + conf_threshold: float = 0.5, + iou_threshold: float = 0.45, + classes: Optional[List[int]] = None, + verbose: bool = False, + ) -> Dict[int, List[BoundingBox]]: + """ + Standalone high-throughput object detection function. + Runs the YOLO detector directly on raw images while skipping + any team-classification or keypoint stages for maximum FPS. + + Args: + images: List of frames (BGR numpy arrays). + batch_size: Number of frames per inference pass. Use None to process + all frames at once (fastest but highest memory usage). + conf_threshold: Detection confidence threshold. + iou_threshold: IoU threshold for NMS within YOLO. + classes: Optional list of class IDs to keep (None = all classes). + verbose: Whether to print per-batch progress from YOLO. + + Returns: + Dict mapping frame index -> list of BoundingBox predictions. + """ + if not images: + return {} + + detections: Dict[int, List[BoundingBox]] = {} + effective_batch = len(images) if batch_size is None else max(1, batch_size) + + for batch_start in range(0, len(images), effective_batch): + batch_end = min(batch_start + effective_batch, len(images)) + batch_images = images[batch_start:batch_end] + + start = time.time() + yolo_results = self.bbox_model( + batch_images, + conf=conf_threshold, + iou=iou_threshold, + classes=classes, + verbose=verbose, + save=False, + ) + end = time.time() + print(f"YOLO time: {end - start}") + + for local_idx, result in enumerate(yolo_results): + frame_idx = batch_start + local_idx + frame_boxes: List[BoundingBox] = [] + + if not hasattr(result, "boxes") or result.boxes is None: + detections[frame_idx] = frame_boxes + continue + + boxes_tensor = result.boxes.data + if boxes_tensor is None: + detections[frame_idx] = frame_boxes + continue + + for box in boxes_tensor: + try: + x1, y1, x2, y2, conf, cls_id = box.tolist() + frame_boxes.append( + BoundingBox( + x1=int(x1), + y1=int(y1), + x2=int(x2), + y2=int(y2), + cls_id=int(cls_id), + conf=float(conf), + ) + ) + except (ValueError, TypeError): + continue + + detections[frame_idx] = frame_boxes + + return detections + \ No newline at end of file diff --git a/miner3.py b/miner3.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8b4e0740808ec700c7bd75ca90fa66fde983e0 --- /dev/null +++ b/miner3.py @@ -0,0 +1,952 @@ +from pathlib import Path +from typing import List, Tuple, Dict, Optional +import sys +import os +import psutil + +from numpy import ndarray +from pydantic import BaseModel +from multiprocessing import cpu_count + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from keypoint_helper_v2_optimized import run_keypoints_post_processing + +from ultralytics import YOLO +from team_cluster import TeamClassifier +from utils import ( + BoundingBox, + Constants, +) + +import time +import torch +import gc +import cv2 +import numpy as np +from collections import defaultdict +from pitch import process_batch_input, get_cls_net +from keypoint_evaluation import ( + evaluate_keypoints_for_frame, + evaluate_keypoints_for_frame_gpu, + load_template_from_file, + evaluate_keypoints_for_frame_opencv_cuda, + evaluate_keypoints_batch_for_frame, +) + +import yaml + + +class BoundingBox(BaseModel): + x1: int + y1: int + x2: int + y2: int + cls_id: int + conf: float + + +class TVFrameResult(BaseModel): + frame_id: int + boxes: List[BoundingBox] + keypoints: List[Tuple[int, int]] + + +class Miner: + SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA + SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX + SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT + CORNER_INDICES = Constants.CORNER_INDICES + KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE + 0.3 + CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE + GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN + MIN_SAMPLES_FOR_FIT = 16 # Minimum player crops needed before fitting TeamClassifier + MAX_SAMPLES_FOR_FIT = 1000 # Maximum samples to avoid overfitting + + def __init__(self, path_hf_repo: Path) -> None: + try: + + device = "cuda" if torch.cuda.is_available() else "cpu" + model_path = path_hf_repo / "detection.onnx" + self.bbox_model = YOLO(model_path) + + print(f"BBox Model Loaded: class name {self.bbox_model.names}") + + team_model_path = path_hf_repo / "osnet_model.pth.tar-100" + self.team_classifier = TeamClassifier( + device=device, + batch_size=32, + model_name=str(team_model_path) + ) + print("Team Classifier Loaded") + + self.last_score = 0 + self.last_valid_keypoints = None + # Team classification state + self.team_classifier_fitted = False + self.player_crops_for_fit = [] + + self.keypoints_model_yolo = YOLO(path_hf_repo / "keypoint.pt") + + model_kp_path = path_hf_repo / 'keypoint' + config_kp_path = path_hf_repo / 'hrnetv2_w48.yaml' + cfg_kp = yaml.safe_load(open(config_kp_path, 'r')) + + loaded_state_kp = torch.load(model_kp_path, map_location=device) + model = get_cls_net(cfg_kp) + model.load_state_dict(loaded_state_kp) + model.to(device) + model.eval() + + self.keypoints_model = model + print("Keypoints Model (keypoint.pt) Loaded") + + template_image_path = path_hf_repo / "football_pitch_template.png" + self.template_image, self.template_keypoints = load_template_from_file(str(template_image_path)) + + self.kp_threshold = 0.3 + self.pitch_batch_size = 4 + self.health = "healthy" + + print("✅ Keypoints Model Loaded") + except Exception as e: + self.health = "❌ Miner initialization failed: " + str(e) + print(self.health) + + def __repr__(self) -> str: + if self.health == 'healthy': + return ( + f"health: {self.health}\n" + f"BBox Model: {type(self.bbox_model).__name__}\n" + f"Keypoints Model: {type(self.keypoints_model).__name__}" + f"CPU Count: {cpu_count()}\n" + f"CPU Speed: {psutil.cpu_freq().current/1000:.2f} GHz" + ) + else: + return self.health + + def _calculate_iou(self, box1: Tuple[float, float, float, float], + box2: Tuple[float, float, float, float]) -> float: + """ + Calculate Intersection over Union (IoU) between two bounding boxes. + Args: + box1: (x1, y1, x2, y2) + box2: (x1, y1, x2, y2) + Returns: + IoU score (0-1) + """ + x1_1, y1_1, x2_1, y2_1 = box1 + x1_2, y1_2, x2_2, y2_2 = box2 + + # Calculate intersection area + x_left = max(x1_1, x1_2) + y_top = max(y1_1, y1_2) + x_right = min(x2_1, x2_2) + y_bottom = min(y2_1, y2_2) + + if x_right < x_left or y_bottom < y_top: + return 0.0 + + intersection_area = (x_right - x_left) * (y_bottom - y_top) + + # Calculate union area + box1_area = (x2_1 - x1_1) * (y2_1 - y1_1) + box2_area = (x2_2 - x1_2) * (y2_2 - y1_2) + union_area = box1_area + box2_area - intersection_area + + if union_area == 0: + return 0.0 + + return intersection_area / union_area + + def _extract_jersey_region(self, crop: ndarray) -> ndarray: + """ + Extract jersey region (upper body) from player crop. + For close-ups, focuses on upper 60%, for distant shots uses full crop. + """ + if crop is None or crop.size == 0: + return crop + + h, w = crop.shape[:2] + if h < 10 or w < 10: + return crop + + # For close-up shots, extract upper body (jersey region) + is_closeup = h > 100 or (h * w) > 12000 + if is_closeup: + # Upper 60% of the crop (jersey area, avoiding shorts) + jersey_top = 0 + jersey_bottom = int(h * 0.60) + jersey_left = max(0, int(w * 0.05)) + jersey_right = min(w, int(w * 0.95)) + return crop[jersey_top:jersey_bottom, jersey_left:jersey_right] + return crop + + def _extract_color_signature(self, crop: ndarray) -> Optional[np.ndarray]: + """ + Extract color signature from jersey region using HSV and LAB color spaces. + Returns a feature vector with dominant colors and color statistics. + """ + if crop is None or crop.size == 0: + return None + + jersey_region = self._extract_jersey_region(crop) + if jersey_region.size == 0: + return None + + try: + # Convert to HSV and LAB color spaces + hsv = cv2.cvtColor(jersey_region, cv2.COLOR_BGR2HSV) + lab = cv2.cvtColor(jersey_region, cv2.COLOR_BGR2LAB) + + # Reshape for processing + hsv_flat = hsv.reshape(-1, 3).astype(np.float32) + lab_flat = lab.reshape(-1, 3).astype(np.float32) + + # Compute statistics for HSV + hsv_mean = np.mean(hsv_flat, axis=0) / 255.0 + hsv_std = np.std(hsv_flat, axis=0) / 255.0 + + # Compute statistics for LAB + lab_mean = np.mean(lab_flat, axis=0) / 255.0 + lab_std = np.std(lab_flat, axis=0) / 255.0 + + # Dominant color (most frequent hue) + hue_hist, _ = np.histogram(hsv_flat[:, 0], bins=36, range=(0, 180)) + dominant_hue = np.argmax(hue_hist) * 5 # Convert to hue value + + # Combine features + color_features = np.concatenate([ + hsv_mean, + hsv_std, + lab_mean[:2], # L and A channels (B is less informative) + lab_std[:2], + [dominant_hue / 180.0] # Normalized dominant hue + ]) + + return color_features + except Exception as e: + print(f"Error extracting color signature: {e}") + return None + + def _get_spatial_position(self, bbox: Tuple[float, float, float, float], + frame_width: int, frame_height: int) -> Tuple[float, float]: + """ + Get normalized spatial position of player on the pitch. + Returns (x_normalized, y_normalized) where 0,0 is top-left. + """ + x1, y1, x2, y2 = bbox + center_x = (x1 + x2) / 2.0 + center_y = (y1 + y2) / 2.0 + + # Normalize to [0, 1] + x_norm = center_x / frame_width if frame_width > 0 else 0.5 + y_norm = center_y / frame_height if frame_height > 0 else 0.5 + + return (x_norm, y_norm) + + def _find_best_match(self, target_box: Tuple[float, float, float, float], + predicted_frame_data: Dict[int, Tuple[Tuple, str]], + iou_threshold: float) -> Tuple[Optional[str], float]: + """ + Find best matching box in predicted frame data using IoU. + Optimized with vectorized calculations when possible. + """ + if len(predicted_frame_data) == 0: + return (None, 0.0) + + # Vectorized IoU calculation for better performance + target_array = np.array(target_box, dtype=np.float32) + bboxes_array = np.array([bbox for bbox, _ in predicted_frame_data.values()], dtype=np.float32) + team_ids = [team_cls_id for _, team_cls_id in predicted_frame_data.values()] + + # Calculate IoU for all boxes at once using vectorization + # Extract coordinates + t_x1, t_y1, t_x2, t_y2 = target_array + b_x1 = bboxes_array[:, 0] + b_y1 = bboxes_array[:, 1] + b_x2 = bboxes_array[:, 2] + b_y2 = bboxes_array[:, 3] + + # Calculate intersection + x_left = np.maximum(t_x1, b_x1) + y_top = np.maximum(t_y1, b_y1) + x_right = np.minimum(t_x2, b_x2) + y_bottom = np.minimum(t_y2, b_y2) + + # Intersection area + intersection = np.maximum(0, x_right - x_left) * np.maximum(0, y_bottom - y_top) + + # Union area + target_area = (t_x2 - t_x1) * (t_y2 - t_y1) + bbox_areas = (b_x2 - b_x1) * (b_y2 - b_y1) + union = target_area + bbox_areas - intersection + + # IoU (avoid division by zero) + ious = np.where(union > 0, intersection / union, 0.0) + + # Find best match above threshold + valid_mask = ious >= iou_threshold + if np.any(valid_mask): + best_idx = np.argmax(ious) + if ious[best_idx] >= iou_threshold: + return (team_ids[best_idx], float(ious[best_idx])) + + return (None, 0.0) + + def _detect_objects_batch(self, decoded_images: List[ndarray]) -> Dict[int, List[BoundingBox]]: + batch_size = 16 + detection_results = [] + n_frames = len(decoded_images) + for frame_number in range(0, n_frames, batch_size): + batch_images = decoded_images[frame_number: frame_number + batch_size] + detections = self.bbox_model(batch_images, verbose=False, save=False) + detection_results.extend(detections) + + return detection_results + + def _team_classify(self, detection_results, decoded_images, offset): + self.team_classifier_fitted = False + start = time.time() + # Collect player crops from first batch for fitting + fit_sample_size = 1000 + player_crops_for_fit = [] + + for frame_id in range(len(detection_results)): + detection_box = detection_results[frame_id].boxes.data + if len(detection_box) < 4: + continue + # Collect player boxes for team classification fitting (first batch only) + if len(player_crops_for_fit) < fit_sample_size: + frame_image = decoded_images[frame_id] + for box in detection_box: + x1, y1, x2, y2, conf, cls_id = box.tolist() + if conf < 0.5: + continue + mapped_cls_id = str(int(cls_id)) + # Only collect player crops (cls_id = 2) + if mapped_cls_id == '2': + crop = frame_image[int(y1):int(y2), int(x1):int(x2)] + if crop.size > 0: + player_crops_for_fit.append(crop) + + # Fit team classifier after collecting samples + if self.team_classifier and not self.team_classifier_fitted and len(player_crops_for_fit) >= fit_sample_size: + print(f"Fitting TeamClassifier with {len(player_crops_for_fit)} player crops") + self.team_classifier.fit(player_crops_for_fit) + self.team_classifier_fitted = True + break + if not self.team_classifier_fitted and len(player_crops_for_fit) >= 16: + print(f"Fallback: Fitting TeamClassifier with {len(player_crops_for_fit)} player crops") + self.team_classifier.fit(player_crops_for_fit) + self.team_classifier_fitted = True + end = time.time() + print(f"Fitting Kmeans time: {end - start}") + + # Second pass: predict teams with configurable frame skipping optimization + start = time.time() + + # Get configuration for frame skipping + prediction_interval = 1 # Default: predict every 2 frames + iou_threshold = 0.3 + + print(f"Team classification - prediction_interval: {prediction_interval}, iou_threshold: {iou_threshold}") + + # Storage for predicted frame results: {frame_id: {box_idx: (bbox, team_id)}} + predicted_frame_data = {} + + # Step 1: Predict for frames at prediction_interval only + frames_to_predict = [] + for frame_id in range(len(detection_results)): + if frame_id % prediction_interval == 0: + frames_to_predict.append(frame_id) + + print(f"Predicting teams for {len(frames_to_predict)}/{len(detection_results)} frames " + f"(saving {100 - (len(frames_to_predict) * 100 // len(detection_results))}% compute)") + + for frame_id in frames_to_predict: + detection_box = detection_results[frame_id].boxes.data + frame_image = decoded_images[frame_id] + + # Collect player crops for this frame + frame_player_crops = [] + frame_player_indices = [] + frame_player_boxes = [] + + for idx, box in enumerate(detection_box): + x1, y1, x2, y2, conf, cls_id = box.tolist() + if cls_id == 2 and conf < 0.6: + continue + mapped_cls_id = str(int(cls_id)) + + # Collect player crops for prediction + if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2': + crop = frame_image[int(y1):int(y2), int(x1):int(x2)] + if crop.size > 0: + frame_player_crops.append(crop) + frame_player_indices.append(idx) + frame_player_boxes.append((x1, y1, x2, y2)) + + # Predict teams for all players in this frame + if len(frame_player_crops) > 0: + team_ids = self.team_classifier.predict(frame_player_crops) + predicted_frame_data[frame_id] = {} + for idx, bbox, team_id in zip(frame_player_indices, frame_player_boxes, team_ids): + # Map team_id (0,1) to cls_id (6,7) + team_cls_id = str(6 + int(team_id)) + predicted_frame_data[frame_id][idx] = (bbox, team_cls_id) + + # Step 2: Process all frames (interpolate skipped frames) + fallback_count = 0 + interpolated_count = 0 + bboxes: dict[int, list[BoundingBox]] = {} + for frame_id in range(len(detection_results)): + detection_box = detection_results[frame_id].boxes.data + frame_image = decoded_images[frame_id] + boxes = [] + + team_predictions = {} + + if frame_id % prediction_interval == 0: + # Predicted frame: use pre-computed predictions + if frame_id in predicted_frame_data: + for idx, (bbox, team_cls_id) in predicted_frame_data[frame_id].items(): + team_predictions[idx] = team_cls_id + else: + # Skipped frame: interpolate from neighboring predicted frames + # Find nearest predicted frames + prev_predicted_frame = (frame_id // prediction_interval) * prediction_interval + next_predicted_frame = prev_predicted_frame + prediction_interval + + # Collect current frame player boxes and fallback crops for batch prediction + fallback_crops = [] + fallback_indices = [] + + for idx, box in enumerate(detection_box): + x1, y1, x2, y2, conf, cls_id = box.tolist() + if cls_id == 2 and conf < 0.6: + continue + mapped_cls_id = str(int(cls_id)) + + if self.team_classifier and self.team_classifier_fitted and mapped_cls_id == '2': + target_box = (x1, y1, x2, y2) + + # Try to match with previous predicted frame + best_team_id = None + best_iou = 0.0 + + if prev_predicted_frame in predicted_frame_data: + team_id, iou = self._find_best_match( + target_box, + predicted_frame_data[prev_predicted_frame], + iou_threshold + ) + if team_id is not None: + best_team_id = team_id + best_iou = iou + + # Try to match with next predicted frame if available and no good match yet + if best_team_id is None and next_predicted_frame < len(detection_results): + if next_predicted_frame in predicted_frame_data: + team_id, iou = self._find_best_match( + target_box, + predicted_frame_data[next_predicted_frame], + iou_threshold + ) + if team_id is not None and iou > best_iou: + best_team_id = team_id + best_iou = iou + + # Track interpolation success + if best_team_id is not None: + interpolated_count += 1 + team_predictions[idx] = best_team_id + else: + # Collect fallback crops for batch prediction + crop = frame_image[int(y1):int(y2), int(x1):int(x2)] + if crop.size > 0: + fallback_crops.append(crop) + fallback_indices.append(idx) + + # Batch predict all fallback crops at once (much faster than individual calls) + if len(fallback_crops) > 0: + fallback_team_ids = self.team_classifier.predict(fallback_crops) + for idx, team_id in zip(fallback_indices, fallback_team_ids): + team_predictions[idx] = str(6 + int(team_id)) + fallback_count += 1 + + # Pre-filter staff boxes once per frame (optimization) + staff_boxes = [] + for idy, boxy in enumerate(detection_box): + s_x1, s_y1, s_x2, s_y2, s_conf, s_cls_id = boxy.tolist() + if s_cls_id == 4: + staff_boxes.append((s_x1, s_y1, s_x2, s_y2)) + + # Pre-compute player boxes for vectorized staff overlap check (if many players) + player_boxes_for_staff_check = [] + player_indices_for_staff_check = [] + if len(staff_boxes) > 0: + for idx, box in enumerate(detection_box): + x1, y1, x2, y2, conf, cls_id = box.tolist() + if cls_id == 2 and conf >= 0.6: + player_boxes_for_staff_check.append((x1, y1, x2, y2)) + player_indices_for_staff_check.append(idx) + + # Vectorized staff overlap check if we have players and staff + staff_overlap_mask = set() + if len(staff_boxes) > 0 and len(player_boxes_for_staff_check) > 0: + # Use vectorized IoU calculation for all player-staff pairs + staff_array = np.array(staff_boxes, dtype=np.float32) + player_array = np.array(player_boxes_for_staff_check, dtype=np.float32) + + # Broadcast to compute all pairwise IoUs + for player_idx, player_box in enumerate(player_boxes_for_staff_check): + p_x1, p_y1, p_x2, p_y2 = player_box + s_x1 = staff_array[:, 0] + s_y1 = staff_array[:, 1] + s_x2 = staff_array[:, 2] + s_y2 = staff_array[:, 3] + + # Vectorized IoU calculation + x_left = np.maximum(p_x1, s_x1) + y_top = np.maximum(p_y1, s_y1) + x_right = np.minimum(p_x2, s_x2) + y_bottom = np.minimum(p_y2, s_y2) + + intersection = np.maximum(0, x_right - x_left) * np.maximum(0, y_bottom - y_top) + player_area = (p_x2 - p_x1) * (p_y2 - p_y1) + staff_areas = (s_x2 - s_x1) * (s_y2 - s_y1) + union = player_area + staff_areas - intersection + + ious = np.where(union > 0, intersection / union, 0.0) + if np.any(ious >= 0.8): + staff_overlap_mask.add(player_indices_for_staff_check[player_idx]) + + # Parse boxes with team classification + for idx, box in enumerate(detection_box): + x1, y1, x2, y2, conf, cls_id = box.tolist() + if cls_id == 2 and conf < 0.6: + continue + + # Check overlap with staff box (using pre-computed mask) + if idx in staff_overlap_mask: + continue + + mapped_cls_id = str(int(cls_id)) + + # Override cls_id for players with team prediction + if idx in team_predictions: + mapped_cls_id = team_predictions[idx] + if mapped_cls_id != '4': + if int(mapped_cls_id) == 3 and conf < 0.5: + continue + boxes.append( + BoundingBox( + x1=int(x1), + y1=int(y1), + x2=int(x2), + y2=int(y2), + cls_id=int(mapped_cls_id), + conf=float(conf), + ) + ) + # Handle footballs - keep only the best one + footballs = [bb for bb in boxes if int(bb.cls_id) == 0] + if len(footballs) > 1: + best_ball = max(footballs, key=lambda b: b.conf) + boxes = [bb for bb in boxes if int(bb.cls_id) != 0] + boxes.append(best_ball) + + bboxes[offset + frame_id] = boxes + return bboxes + + + def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]: + start = time.time() + detection_results = self._detect_objects_batch(batch_images) + end = time.time() + print(f"Detection time: {end - start}") + + # Use hybrid team classification + start = time.time() + bboxes = self._team_classify(detection_results, batch_images, offset) + end = time.time() + print(f"Team classify time: {end - start}") + + # Phase 3: Keypoint Detection + start = time.time() + + + pitch_batch_size = min(self.pitch_batch_size, len(batch_images)) + keypoints: Dict[int, List[Tuple[int, int]]] = {} + + start = time.time() + + while True: + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + device_str = "cuda" + keypoints_result = process_batch_input( + batch_images, + self.keypoints_model, + self.kp_threshold, + device_str, + batch_size=pitch_batch_size, + ) + if keypoints_result is not None and len(keypoints_result) > 0: + for frame_number_in_batch, kp_dict in enumerate(keypoints_result): + if frame_number_in_batch >= len(batch_images): + break + frame_keypoints: List[Tuple[int, int]] = [] + try: + height, width = batch_images[frame_number_in_batch].shape[:2] + if kp_dict is not None and isinstance(kp_dict, dict): + for idx in range(32): + x, y = 0, 0 + kp_idx = idx + 1 + if kp_idx in kp_dict: + try: + kp_data = kp_dict[kp_idx] + if isinstance(kp_data, dict) and "x" in kp_data and "y" in kp_data: + x = int(kp_data["x"] * width) + y = int(kp_data["y"] * height) + except (KeyError, TypeError, ValueError): + pass + frame_keypoints.append((x, y)) + except (IndexError, ValueError, AttributeError): + frame_keypoints = [(0, 0)] * 32 + if len(frame_keypoints) < n_keypoints: + frame_keypoints.extend([(0, 0)] * (n_keypoints - len(frame_keypoints))) + else: + frame_keypoints = frame_keypoints[:n_keypoints] + + keypoints[offset + frame_number_in_batch] = frame_keypoints + break + end = time.time() + print(f"Keypoint time: {end - start}") + + results: List[TVFrameResult] = [] + for frame_number in range(offset, offset + len(batch_images)): + frame_boxes = bboxes.get(frame_number, []) + result = TVFrameResult( + frame_id=frame_number, + boxes=frame_boxes, + keypoints=keypoints.get( + frame_number, + [(0, 0) for _ in range(n_keypoints)], + ), + ) + results.append(result) + + start = time.time() + if len(batch_images) > 0: + h, w = batch_images[0].shape[:2] + results = run_keypoints_post_processing( + results, w, h, + frames=batch_images, + offset=offset, + template_keypoints=self.template_keypoints, + template_image=self.template_image, + ) + end = time.time() + print(f"Keypoint post processing time: {end - start}") + + final_keypoints: Dict[int, List[Tuple[int, int]]] = {} + + for frame_number_in_batch, result in enumerate(results): + frame_keypoints = result.keypoints + try: + if self.last_valid_keypoints is None: + self.last_valid_keypoints = final_keypoints.get(offset + frame_number_in_batch - 1, self.last_valid_keypoints) + # Evaluate both keypoint sets in batch (much faster!) + scores = evaluate_keypoints_batch_for_frame( + template_keypoints=self.template_keypoints, + frame_keypoints_list=[result.keypoints, self.last_valid_keypoints], + frame=batch_images[frame_number_in_batch], + floor_markings_template=self.template_image, + device="cuda" + ) + score = scores[0] + self.last_score = scores[1] + + if self.last_score > score: + frame_keypoints = self.last_valid_keypoints + else: + self.last_score = score + + + except Exception as e: + # Fallback: use YOLO if available, otherwise use pitch model + print('Error: ', e) + + self.last_valid_keypoints = frame_keypoints + + final_keypoints[offset + frame_number_in_batch] = frame_keypoints + + + final_results: List[TVFrameResult] = [] + for frame_number in range(offset, offset + len(batch_images)): + frame_boxes = bboxes.get(frame_number, []) + result = TVFrameResult( + frame_id=frame_number, + boxes=frame_boxes, + keypoints=final_keypoints.get( + frame_number, + [(0, 0) for _ in range(n_keypoints)], + ), + ) + final_results.append(result) + + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + return final_results + # return results + + def _detect_keypoints_batch(self, batch_images: List[ndarray], + offset: int, n_keypoints: int) -> Dict[int, List[Tuple[int, int]]]: + """ + Phase 3: Keypoint detection for all frames in batch. + + Args: + batch_images: List of images to process + offset: Frame offset for numbering + n_keypoints: Number of keypoints expected + + Returns: + Dictionary mapping frame_id to list of keypoint coordinates + """ + keypoints: Dict[int, List[Tuple[int, int]]] = {} + keypoints_model_results = self.keypoints_model_yolo.predict(batch_images) + + if keypoints_model_results is None: + return keypoints + + for frame_idx_in_batch, detection in enumerate(keypoints_model_results): + if not hasattr(detection, "keypoints") or detection.keypoints is None: + continue + + # Extract keypoints with confidence + frame_keypoints_with_conf: List[Tuple[int, int, float]] = [] + for i, part_points in enumerate(detection.keypoints.data): + for k_id, (x, y, _) in enumerate(part_points): + confidence = float(detection.keypoints.conf[i][k_id]) + frame_keypoints_with_conf.append((int(x), int(y), confidence)) + + # Pad or truncate to expected number of keypoints + if len(frame_keypoints_with_conf) < n_keypoints: + frame_keypoints_with_conf.extend( + [(0, 0, 0.0)] * (n_keypoints - len(frame_keypoints_with_conf)) + ) + else: + frame_keypoints_with_conf = frame_keypoints_with_conf[:n_keypoints] + + # Filter keypoints based on confidence thresholds + filtered_keypoints: List[Tuple[int, int]] = [] + for idx, (x, y, confidence) in enumerate(frame_keypoints_with_conf): + if idx in self.CORNER_INDICES: + # Corner keypoints have lower confidence threshold + if confidence < 0.3: + filtered_keypoints.append((0, 0)) + else: + filtered_keypoints.append((int(x), int(y))) + else: + # Regular keypoints + if confidence < 0.5: + filtered_keypoints.append((0, 0)) + else: + filtered_keypoints.append((int(x), int(y))) + + frame_id = offset + frame_idx_in_batch + keypoints[frame_id] = filtered_keypoints + + return keypoints + + def predict_keypoints( + self, + images: List[ndarray], + n_keypoints: int = 32, + batch_size: Optional[int] = None, + conf_threshold: float = 0.5, + corner_conf_threshold: float = 0.3, + verbose: bool = False + ) -> Dict[int, List[Tuple[int, int]]]: + """ + Standalone function for keypoint detection on a list of images. + Optimized for maximum prediction speed. + + Args: + images: List of images (numpy arrays) to process + n_keypoints: Number of keypoints expected per frame (default: 32) + batch_size: Batch size for YOLO prediction (None = auto, uses all images) + conf_threshold: Confidence threshold for regular keypoints (default: 0.5) + corner_conf_threshold: Confidence threshold for corner keypoints (default: 0.3) + verbose: Whether to print progress information + + Returns: + Dictionary mapping frame index to list of keypoint coordinates (x, y) + Frame indices start from 0 + """ + if not images: + return {} + + keypoints: Dict[int, List[Tuple[int, int]]] = {} + + # Use provided batch_size or process all at once for maximum speed + if batch_size is None: + batch_size = len(images) + + # Process in batches for optimal GPU utilization + for batch_start in range(0, len(images), batch_size): + batch_end = min(batch_start + batch_size, len(images)) + batch_images = images[batch_start:batch_end] + + if verbose: + print(f"Processing keypoints batch {batch_start}-{batch_end-1} ({len(batch_images)} images)") + + # YOLO keypoint prediction (optimized batch processing) + keypoints_model_results = self.keypoints_model_yolo.predict( + batch_images, + verbose=False, + save=False, + conf=0.1, # Lower conf for detection, we filter later + ) + + if keypoints_model_results is None: + # Fill with empty keypoints for this batch + for frame_idx in range(batch_start, batch_end): + keypoints[frame_idx] = [(0, 0)] * n_keypoints + continue + + # Process each frame in the batch + for batch_idx, detection in enumerate(keypoints_model_results): + frame_idx = batch_start + batch_idx + + if not hasattr(detection, "keypoints") or detection.keypoints is None: + keypoints[frame_idx] = [(0, 0)] * n_keypoints + continue + + # Extract keypoints with confidence + frame_keypoints_with_conf: List[Tuple[int, int, float]] = [] + try: + for i, part_points in enumerate(detection.keypoints.data): + for k_id, (x, y, _) in enumerate(part_points): + confidence = float(detection.keypoints.conf[i][k_id]) + frame_keypoints_with_conf.append((int(x), int(y), confidence)) + except (AttributeError, IndexError, TypeError): + keypoints[frame_idx] = [(0, 0)] * n_keypoints + continue + + # Pad or truncate to expected number of keypoints + if len(frame_keypoints_with_conf) < n_keypoints: + frame_keypoints_with_conf.extend( + [(0, 0, 0.0)] * (n_keypoints - len(frame_keypoints_with_conf)) + ) + else: + frame_keypoints_with_conf = frame_keypoints_with_conf[:n_keypoints] + + # Filter keypoints based on confidence thresholds + filtered_keypoints: List[Tuple[int, int]] = [] + for idx, (x, y, confidence) in enumerate(frame_keypoints_with_conf): + if idx in self.CORNER_INDICES: + # Corner keypoints have lower confidence threshold + if confidence < corner_conf_threshold: + filtered_keypoints.append((0, 0)) + else: + filtered_keypoints.append((int(x), int(y))) + else: + # Regular keypoints + if confidence < conf_threshold: + filtered_keypoints.append((0, 0)) + else: + filtered_keypoints.append((int(x), int(y))) + + keypoints[frame_idx] = filtered_keypoints + + return keypoints + + def predict_objects( + self, + images: List[ndarray], + batch_size: Optional[int] = 16, + conf_threshold: float = 0.5, + iou_threshold: float = 0.45, + classes: Optional[List[int]] = None, + verbose: bool = False, + ) -> Dict[int, List[BoundingBox]]: + """ + Standalone high-throughput object detection function. + Runs the YOLO detector directly on raw images while skipping + any team-classification or keypoint stages for maximum FPS. + + Args: + images: List of frames (BGR numpy arrays). + batch_size: Number of frames per inference pass. Use None to process + all frames at once (fastest but highest memory usage). + conf_threshold: Detection confidence threshold. + iou_threshold: IoU threshold for NMS within YOLO. + classes: Optional list of class IDs to keep (None = all classes). + verbose: Whether to print per-batch progress from YOLO. + + Returns: + Dict mapping frame index -> list of BoundingBox predictions. + """ + if not images: + return {} + + detections: Dict[int, List[BoundingBox]] = {} + effective_batch = len(images) if batch_size is None else max(1, batch_size) + + for batch_start in range(0, len(images), effective_batch): + batch_end = min(batch_start + effective_batch, len(images)) + batch_images = images[batch_start:batch_end] + + start = time.time() + yolo_results = self.bbox_model( + batch_images, + conf=conf_threshold, + iou=iou_threshold, + classes=classes, + verbose=verbose, + save=False, + ) + end = time.time() + print(f"YOLO time: {end - start}") + + for local_idx, result in enumerate(yolo_results): + frame_idx = batch_start + local_idx + frame_boxes: List[BoundingBox] = [] + + if not hasattr(result, "boxes") or result.boxes is None: + detections[frame_idx] = frame_boxes + continue + + boxes_tensor = result.boxes.data + if boxes_tensor is None: + detections[frame_idx] = frame_boxes + continue + + for box in boxes_tensor: + try: + x1, y1, x2, y2, conf, cls_id = box.tolist() + frame_boxes.append( + BoundingBox( + x1=int(x1), + y1=int(y1), + x2=int(x2), + y2=int(y2), + cls_id=int(cls_id), + conf=float(conf), + ) + ) + except (ValueError, TypeError): + continue + + detections[frame_idx] = frame_boxes + + return detections + \ No newline at end of file diff --git a/object-detection.onnx b/object-detection.onnx new file mode 100644 index 0000000000000000000000000000000000000000..169b7f74931fe53b342cd91349776c3f2308b7bf --- /dev/null +++ b/object-detection.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:05112479be8cb59494e9ae23a57af43becd5aa1f448b0e5ed33fcb6b4c2bbbc3 +size 273322667 diff --git a/osnet_ain.pyc b/osnet_ain.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e3f3da762ddd7bae2d1d20b60e3c4c866a0694a Binary files /dev/null and b/osnet_ain.pyc differ diff --git a/osnet_model.pth.tar-100 b/osnet_model.pth.tar-100 new file mode 100644 index 0000000000000000000000000000000000000000..b2cd761d202c765f6dc44e318a03b71ecdf665da --- /dev/null +++ b/osnet_model.pth.tar-100 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:64873ef0e8abf28df31facd113f27634e2d085a2dcf8d19123409b1d0e2566c8 +size 36189526 diff --git a/pitch.py b/pitch.py new file mode 100644 index 0000000000000000000000000000000000000000..3764f6f0547e73fd08642f938f9ad989bdaa1253 --- /dev/null +++ b/pitch.py @@ -0,0 +1,687 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import time +from typing import List, Optional, Tuple + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as T +import torchvision.transforms.functional as f +from pydantic import BaseModel + +import logging +logger = logging.getLogger(__name__) + + +class BoundingBox(BaseModel): + x1: int + y1: int + x2: int + y2: int + cls_id: int + conf: float + + +class TVFrameResult(BaseModel): + frame_id: int + boxes: list[BoundingBox] + keypoints: list[tuple[int, int]] + +BatchNorm2d = nn.BatchNorm2d +BN_MOMENTUM = 0.1 + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, + stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, + bias=False) + self.bn3 = BatchNorm2d(planes * self.expansion, + momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=True) + + def _check_branches(self, num_branches, blocks, num_blocks, + num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm2d(num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), + BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM))) + # nn.Upsample(scale_factor=2**(j-i), mode='nearest'))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + elif j > i: + y = y + F.interpolate( + self.fuse_layers[i][j](x[j]), + size=[x[i].shape[2], x[i].shape[3]], + mode='bilinear') + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class HighResolutionNet(nn.Module): + + def __init__(self, config, **kwargs): + self.inplanes = 64 + extra = config['MODEL']['EXTRA'] + super(HighResolutionNet, self).__init__() + + # stem net + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn1 = BatchNorm2d(self.inplanes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(self.inplanes, self.inplanes, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn2 = BatchNorm2d(self.inplanes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.sf = nn.Softmax(dim=1) + self.layer1 = self._make_layer(Bottleneck, 64, 64, 4) + + self.stage2_cfg = extra['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer( + [256], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = extra['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = extra['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=True) + + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + final_inp_channels = sum(pre_stage_channels) + self.inplanes + + self.head = nn.Sequential(nn.Sequential( + nn.Conv2d( + in_channels=final_inp_channels, + out_channels=final_inp_channels, + kernel_size=1), + BatchNorm2d(final_inp_channels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True), + nn.Conv2d( + in_channels=final_inp_channels, + out_channels=config['MODEL']['NUM_JOINTS'], + kernel_size=extra['FINAL_CONV_KERNEL']), + nn.Softmax(dim=1))) + + + + def _make_head(self, x, x_skip): + x = self.upsample(x) + x = torch.cat([x, x_skip], dim=1) + x = self.head(x) + + return x + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + BatchNorm2d( + num_channels_cur_layer[i], momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + BatchNorm2d(outchannels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + modules.append( + HighResolutionModule(num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + # h, w = x.size(2), x.size(3) + x = self.conv1(x) + x_skip = x.clone() + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + x = self.stage4(x_list) + + # Head Part + height, width = x[0].size(2), x[0].size(3) + x1 = F.interpolate(x[1], size=(height, width), mode='bilinear', align_corners=False) + x2 = F.interpolate(x[2], size=(height, width), mode='bilinear', align_corners=False) + x3 = F.interpolate(x[3], size=(height, width), mode='bilinear', align_corners=False) + x = torch.cat([x[0], x1, x2, x3], 1) + x = self._make_head(x, x_skip) + + return x + + def init_weights(self, pretrained=''): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + #nn.init.normal_(m.weight, std=0.001) + #nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + if pretrained != '': + if os.path.isfile(pretrained): + pretrained_dict = torch.load(pretrained) + model_dict = self.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() + if k in model_dict.keys()} + model_dict.update(pretrained_dict) + self.load_state_dict(model_dict) + else: + sys.exit(f'Weights {pretrained} not found.') + + +def get_cls_net(config, pretrained='', **kwargs): + """Create keypoint detection model with softmax activation""" + model = HighResolutionNet(config, **kwargs) + model.init_weights(pretrained) + return model + + +def get_cls_net_l(config, pretrained='', **kwargs): + """Create line detection model with sigmoid activation""" + model = HighResolutionNet(config, **kwargs) + model.init_weights(pretrained) + + # After loading weights, replace just the activation function + # The saved model expects the nested Sequential structure + inner_seq = model.head[0] + # Replace softmax (index 4) with sigmoid + model.head[0][4] = nn.Sigmoid() + + return model + +# Simplified utility functions - removed complex Gaussian generation functions +# These were mainly used for training data generation, not inference + + + +# generate_gaussian_array_vectorized_dist_l function removed - not used in current implementation +@torch.inference_mode() +def run_inference(model, input_tensor: torch.Tensor, device): + input_tensor = input_tensor.to(device).to(memory_format=torch.channels_last) + output = model.module().forward(input_tensor) + return output + +def preprocess_batch_fast(frames): + """Ultra-fast batch preprocessing using optimized tensor operations""" + target_size = (540, 960) # H, W format for model input + batch = [] + for i, frame in enumerate(frames): + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + img = cv2.resize(frame_rgb, (target_size[1], target_size[0])) + img = img.astype(np.float32) / 255.0 + img = np.transpose(img, (2, 0, 1)) # HWC -> CHW + batch.append(img) + batch = torch.from_numpy(np.stack(batch)).float() + + return batch + +def extract_keypoints_from_heatmap(heatmap: torch.Tensor, scale: int = 2, max_keypoints: int = 1): + """Optimized keypoint extraction from heatmaps""" + batch_size, n_channels, height, width = heatmap.shape + + # Find local maxima using max pooling (keep on GPU) + kernel = 3 + pad = 1 + max_pooled = F.max_pool2d(heatmap, kernel, stride=1, padding=pad) + local_maxima = (max_pooled == heatmap) + heatmap = heatmap * local_maxima + + # Get top keypoints (keep on GPU longer) + scores, indices = torch.topk(heatmap.view(batch_size, n_channels, -1), max_keypoints, sorted=False) + y_coords = torch.div(indices, width, rounding_mode="floor") + x_coords = indices % width + + # Optimized tensor operations + x_coords = x_coords * scale + y_coords = y_coords * scale + + # Create result tensor directly on GPU + results = torch.stack([x_coords.float(), y_coords.float(), scores], dim=-1) + + return results + + +def extract_keypoints_from_heatmap_fast(heatmap: torch.Tensor, scale: int = 2, max_keypoints: int = 1): + """Ultra-fast keypoint extraction optimized for speed""" + batch_size, n_channels, height, width = heatmap.shape + + # Simplified local maxima detection (faster but slightly less accurate) + max_pooled = F.max_pool2d(heatmap, 3, stride=1, padding=1) + local_maxima = (max_pooled == heatmap) + + # Apply mask and get top keypoints in one go + masked_heatmap = heatmap * local_maxima + flat_heatmap = masked_heatmap.view(batch_size, n_channels, -1) + scores, indices = torch.topk(flat_heatmap, max_keypoints, dim=-1, sorted=False) + + # Vectorized coordinate calculation + y_coords = torch.div(indices, width, rounding_mode="floor") * scale + x_coords = (indices % width) * scale + + # Stack results efficiently + results = torch.stack([x_coords.float(), y_coords.float(), scores], dim=-1) + return results + + +def process_keypoints_vectorized(kp_coords, kp_threshold, w, h, batch_size): + """Ultra-fast vectorized keypoint processing""" + batch_results = [] + + # Convert to numpy once for faster CPU operations + kp_np = kp_coords.cpu().numpy() + + for batch_idx in range(batch_size): + kp_dict = {} + # Vectorized threshold check + valid_kps = kp_np[batch_idx, :, 0, 2] > kp_threshold + valid_indices = np.where(valid_kps)[0] + + for ch_idx in valid_indices: + x = float(kp_np[batch_idx, ch_idx, 0, 0]) / w + y = float(kp_np[batch_idx, ch_idx, 0, 1]) / h + p = float(kp_np[batch_idx, ch_idx, 0, 2]) + kp_dict[ch_idx + 1] = {'x': x, 'y': y, 'p': p} + + batch_results.append(kp_dict) + + return batch_results + +def inference_batch(frames, model, kp_threshold, device, batch_size=8): + """Optimized batch inference for multiple frames""" + results = [] + num_frames = len(frames) + + # Get the device from the model itself + model_device = next(model.parameters()).device + + # Process all frames in optimally-sized batches + for i in range(0, num_frames, batch_size): + current_batch_size = min(batch_size, num_frames - i) + batch_frames = frames[i:i + current_batch_size] + + # Fast preprocessing - create on CPU first + batch = preprocess_batch_fast(batch_frames) + b, c, h, w = batch.size() + + # Move batch to model device + batch = batch.to(model_device) + + with torch.no_grad(): + heatmaps = model(batch) + + # Ultra-fast keypoint extraction + kp_coords = extract_keypoints_from_heatmap_fast(heatmaps[:,:-1,:,:], scale=2, max_keypoints=1) + + # Vectorized batch processing - no loops + batch_results = process_keypoints_vectorized(kp_coords, kp_threshold, 960, 540, current_batch_size) + results.extend(batch_results) + + # Minimal cleanup + del heatmaps, kp_coords, batch + + return results + +# Keypoint mapping from detection indices to standard football pitch keypoint IDs +map_keypoints = { + 1: 1, 2: 14, 3: 25, 4: 2, 5: 10, 6: 18, 7: 26, 8: 3, 9: 7, 10: 23, + 11: 27, 20: 4, 21: 8, 22: 24, 23: 28, 24: 5, 25: 13, 26: 21, 27: 29, + 28: 6, 29: 17, 30: 30, 31: 11, 32: 15, 33: 19, 34: 12, 35: 16, 36: 20, + 45: 9, 50: 31, 52: 32, 57: 22 +} + +def get_mapped_keypoints(kp_points): + """Apply keypoint mapping to detection results""" + mapped_points = {} + for key, value in kp_points.items(): + if key in map_keypoints: + mapped_key = map_keypoints[key] + mapped_points[mapped_key] = value + # else: + # Keep unmapped keypoints with original key + # mapped_points[key] = value + return mapped_points + +def process_batch_input(frames, model, kp_threshold, device, batch_size=8): + """Process multiple input images in batch""" + # Batch inference + kp_results = inference_batch(frames, model, kp_threshold, device, batch_size) + kp_results = [get_mapped_keypoints(kp) for kp in kp_results] + # Draw results and save + # for i, (frame, kp_points, input_path) in enumerate(zip(frames, kp_results, valid_paths)): + # height, width = frame.shape[:2] + + # # Apply mapping to get standard keypoint IDs + # mapped_kp_points = get_mapped_keypoints(kp_points) + + # for key, value in mapped_kp_points.items(): + # x = int(value['x'] * width) + # y = int(value['y'] * height) + # cv2.circle(frame, (x, y), 5, (0, 255, 0), -1) # Green circles + # cv2.putText(frame, str(key), (x+10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2) + + # # Save result + # output_path = input_path.replace('.png', '_result.png').replace('.jpg', '_result.jpg') + # cv2.imwrite(output_path, frame) + + # print(f"Batch processing complete. Processed {len(frames)} images.") + + return kp_results \ No newline at end of file diff --git a/player.pt b/player.pt new file mode 100644 index 0000000000000000000000000000000000000000..4b852fde60015132ece56c0ffa8968114ffd4af6 --- /dev/null +++ b/player.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce9fc31f61e6f156f786077abb8eef36b0836bda1ef07d1d0ba82d43ae0ecd0b +size 22540152 diff --git a/player.py b/player.py new file mode 100644 index 0000000000000000000000000000000000000000..dd05e65b8979cf98266a6855a58aca77c1e9270f --- /dev/null +++ b/player.py @@ -0,0 +1,389 @@ +import cv2 +import numpy as np +from sklearn.cluster import KMeans +import warnings +import time + +import torch +from torchvision.ops import batched_nms +from numpy import ndarray +# Suppress ALL runtime and sklearn warnings +warnings.filterwarnings('ignore', category=RuntimeWarning) +warnings.filterwarnings('ignore', category=FutureWarning) +warnings.filterwarnings('ignore', category=UserWarning) + +# Suppress sklearn warnings specifically +import logging +logging.getLogger('sklearn').setLevel(logging.ERROR) + +def get_grass_color(img): + # Convert image to HSV color space + hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) + + # Define range of green color in HSV + lower_green = np.array([30, 40, 40]) + upper_green = np.array([80, 255, 255]) + + # Threshold the HSV image to get only green colors + mask = cv2.inRange(hsv, lower_green, upper_green) + + # Calculate the mean value of the pixels that are not masked + masked_img = cv2.bitwise_and(img, img, mask=mask) + grass_color = cv2.mean(img, mask=mask) + return grass_color[:3] + +def get_players_boxes(frame, result): + players_imgs = [] + players_boxes = [] + for (box, score, cls) in result: + label = int(cls) + if label == 0: + x1, y1, x2, y2 = box.astype(int) + player_img = frame[y1: y2, x1: x2] + players_imgs.append(player_img) + players_boxes.append([box, score, cls]) + return players_imgs, players_boxes + +def get_kits_colors(players, grass_hsv=None, frame=None): + kits_colors = [] + if grass_hsv is None: + grass_color = get_grass_color(frame) + grass_hsv = cv2.cvtColor(np.uint8([[list(grass_color)]]), cv2.COLOR_BGR2HSV) + + for player_img in players: + # Skip empty or invalid images + if player_img is None or player_img.size == 0 or len(player_img.shape) != 3: + continue + + # Convert image to HSV color space + hsv = cv2.cvtColor(player_img, cv2.COLOR_BGR2HSV) + + # Define range of green color in HSV + lower_green = np.array([grass_hsv[0, 0, 0] - 10, 40, 40]) + upper_green = np.array([grass_hsv[0, 0, 0] + 10, 255, 255]) + + # Threshold the HSV image to get only green colors + mask = cv2.inRange(hsv, lower_green, upper_green) + + # Bitwise-AND mask and original image + mask = cv2.bitwise_not(mask) + upper_mask = np.zeros(player_img.shape[:2], np.uint8) + upper_mask[0:player_img.shape[0]//2, 0:player_img.shape[1]] = 255 + mask = cv2.bitwise_and(mask, upper_mask) + + kit_color = np.array(cv2.mean(player_img, mask=mask)[:3]) + + kits_colors.append(kit_color) + return kits_colors + +def get_kits_classifier(kits_colors): + if len(kits_colors) == 0: + return None + if len(kits_colors) == 1: + # Only one kit color, create a dummy classifier + return None + kits_kmeans = KMeans(n_clusters=2) + kits_kmeans.fit(kits_colors) + return kits_kmeans + +def classify_kits(kits_classifer, kits_colors): + if kits_classifer is None or len(kits_colors) == 0: + return np.array([0]) # Default to team 0 + team = kits_classifer.predict(kits_colors) + return team + +def get_left_team_label(players_boxes, kits_colors, kits_clf): + left_team_label = 0 + team_0 = [] + team_1 = [] + + for i in range(len(players_boxes)): + x1, y1, x2, y2 = players_boxes[i][0].astype(int) + team = classify_kits(kits_clf, [kits_colors[i]]).item() + if team == 0: + team_0.append(np.array([x1])) + else: + team_1.append(np.array([x1])) + + team_0 = np.array(team_0) + team_1 = np.array(team_1) + + # Safely calculate averages with fallback for empty arrays + avg_team_0 = np.average(team_0) if len(team_0) > 0 else 0 + avg_team_1 = np.average(team_1) if len(team_1) > 0 else 0 + + if avg_team_0 - avg_team_1 > 0: + left_team_label = 1 + + return left_team_label + +def check_box_boundaries(boxes, img_height, img_width): + """ + Check if bounding boxes are within image boundaries and clip them if necessary. + + Args: + boxes: numpy array of shape (N, 4) with [x1, y1, x2, y2] format + img_height: height of the image + img_width: width of the image + + Returns: + valid_boxes: numpy array of valid boxes within boundaries + valid_indices: indices of valid boxes + """ + x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3] + + # Check if boxes are within boundaries + valid_mask = (x1 >= 0) & (y1 >= 0) & (x2 < img_width) & (y2 < img_height) & (x1 < x2) & (y1 < y2) + + if not np.any(valid_mask): + return np.array([]), np.array([]) + + valid_boxes = boxes[valid_mask] + valid_indices = np.where(valid_mask)[0] + + # Clip boxes to image boundaries + valid_boxes[:, 0] = np.clip(valid_boxes[:, 0], 0, img_width - 1) # x1 + valid_boxes[:, 1] = np.clip(valid_boxes[:, 1], 0, img_height - 1) # y1 + valid_boxes[:, 2] = np.clip(valid_boxes[:, 2], 0, img_width - 1) # x2 + valid_boxes[:, 3] = np.clip(valid_boxes[:, 3], 0, img_height - 1) # y2 + + return valid_boxes, valid_indices + +def process_team_identification_batch(frames, results, kits_clf, left_team_label, grass_hsv): + """ + Process team identification and label formatting for batch results. + + Args: + frames: list of frames + results: list of detection results for each frame + kits_clf: trained kit classifier + left_team_label: label for left team + grass_hsv: grass color in HSV format + + Returns: + processed_results: list of processed results with team identification + """ + processed_results = [] + + for frame_idx, frame in enumerate(frames): + frame_results = [] + frame_detections = results[frame_idx] + + if not frame_detections: + processed_results.append([]) + continue + + # Extract player boxes and images + players_imgs = [] + players_boxes = [] + player_indices = [] + + for idx, (box, score, cls) in enumerate(frame_detections): + label = int(cls) + if label == 0: # Player detection + x1, y1, x2, y2 = box.astype(int) + + # Check boundaries + if (x1 >= 0 and y1 >= 0 and x2 < frame.shape[1] and y2 < frame.shape[0] and x1 < x2 and y1 < y2): + player_img = frame[y1:y2, x1:x2] + if player_img.size > 0: # Ensure valid image + players_imgs.append(player_img) + players_boxes.append([box, score, cls]) + player_indices.append(idx) + + # Initialize player team mapping + player_team_map = {} + + # Process team identification if we have players + if players_imgs and kits_clf is not None: + kits_colors = get_kits_colors(players_imgs, grass_hsv) + teams = classify_kits(kits_clf, kits_colors) + + # Create mapping from player index to team + for i, team in enumerate(teams): + player_team_map[player_indices[i]] = team.item() + + id = 0 + # Process all detections with team identification + for idx, (box, score, cls) in enumerate(frame_detections): + label = int(cls) + x1, y1, x2, y2 = box.astype(int) + + # Check boundaries + valid_boxes, valid_indices = check_box_boundaries( + np.array([[x1, y1, x2, y2]]), frame.shape[0], frame.shape[1] + ) + + if len(valid_boxes) == 0: + continue + + x1, y1, x2, y2 = valid_boxes[0].astype(int) + + # Apply team identification logic + if label == 0: # Player + if players_imgs and kits_clf is not None and idx in player_team_map: + team = player_team_map[idx] + if team == left_team_label: + final_label = 6 # Player-L (Left team) + else: + final_label = 7 # Player-R (Right team) + else: + final_label = 6 # Default player label + + elif label == 1: # Goalkeeper + final_label = 1 # GK + + elif label == 2: # Ball + final_label = 0 # Ball + + elif label == 3 or label == 4: # Referee or other + final_label = 3 # Referee + + else: + continue + # final_label = int(label) # Keep original label, ensure it's int + + frame_results.append({ + "id": int(id), + "bbox": [int(x1), int(y1), int(x2), int(y2)], + "class_id": int(final_label), + "conf": float(score) + }) + id = id + 1 + + processed_results.append(frame_results) + + return processed_results + +def convert_numpy_types(obj): + """Convert numpy types to native Python types for JSON serialization.""" + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, dict): + return {key: convert_numpy_types(value) for key, value in obj.items()} + elif isinstance(obj, list): + return [convert_numpy_types(item) for item in obj] + else: + return obj + +def pre_process_img(frames, scale): + imgs = np.stack([cv2.resize(frame, (int(scale), int(scale))) for frame in frames]) + imgs = imgs.transpose(0, 3, 1, 2) + imgs = imgs.astype(np.float32) / 255.0 # Normalize + return imgs + +def post_process_output(outputs, x_scale, y_scale, conf_thresh=0.6, nms_thresh=0.75): + B, C, N = outputs.shape + outputs = torch.from_numpy(outputs) + outputs = outputs.permute(0, 2, 1) + boxes = outputs[..., :4] + class_scores = 1 / (1 + torch.exp(-outputs[..., 4:])) + conf, class_id = class_scores.max(dim=2) + + mask = conf > conf_thresh + + for i in range(class_id.shape[0]): # loop over batch + # Find detections that are balls + ball_idx = np.where(class_id[i] == 2)[0] + if ball_idx.size > 0: + # Pick the one with the highest confidence + top = ball_idx[np.argmax(conf[i, ball_idx])] + if conf[i, top] > 0.55: # apply confidence threshold + mask[i, top] = True + + # ball_mask = (class_id == 2) & (conf > 0.51) + # mask = mask | ball_mask + + batch_idx, pred_idx = mask.nonzero(as_tuple=True) + + if len(batch_idx) == 0: + return [[] for _ in range(B)] + + boxes = boxes[batch_idx, pred_idx] + conf = conf[batch_idx, pred_idx] + class_id = class_id[batch_idx, pred_idx] + + x, y, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3] + x1 = (x - w / 2) * x_scale + y1 = (y - h / 2) * y_scale + x2 = (x + w / 2) * x_scale + y2 = (y + h / 2) * y_scale + boxes_xyxy = torch.stack([x1, y1, x2, y2], dim=1) + + max_coord = 1e4 + offset = batch_idx.to(boxes_xyxy) * max_coord + boxes_for_nms = boxes_xyxy + offset[:, None] + + keep = batched_nms(boxes_for_nms, conf, batch_idx, nms_thresh) + + boxes_final = boxes_xyxy[keep] + conf_final = conf[keep] + class_final = class_id[keep] + batch_final = batch_idx[keep] + + results = [[] for _ in range(B)] + for b in range(B): + mask_b = batch_final == b + if mask_b.sum() == 0: + continue + results[b] = list(zip(boxes_final[mask_b].numpy(), + conf_final[mask_b].numpy(), + class_final[mask_b].numpy())) + return results + +def player_detection_result(frames: list[ndarray], batch_size, model, kits_clf=None, left_team_label=None, grass_hsv=None): + start_time = time.time() + # input_layer = model.input(0) + # output_layer = model.output(0) + height, width = frames[0].shape[:2] + scale = 640.0 + x_scale = width / scale + y_scale = height / scale + + # infer_queue = AsyncInferQueue(model, len(frames)) + + infer_time = time.time() + kits_clf = kits_clf + left_team_label = left_team_label + grass_hsv = grass_hsv + results = [] + for i in range(0, len(frames), batch_size): + if i + batch_size > len(frames): + batch_size = len(frames) - i + batch_frames = frames[i:i + batch_size] + imgs = pre_process_img(batch_frames, scale) + + input_name = model.get_inputs()[0].name + outputs = model.run(None, {input_name: imgs})[0] + raw_results = post_process_output(np.array(outputs), x_scale, y_scale) + + if kits_clf is None or left_team_label is None or grass_hsv is None: + # Use first frame to initialize team classification + first_frame = batch_frames[0] + first_frame_results = raw_results[0] if raw_results else [] + + if first_frame_results: + players_imgs, players_boxes = get_players_boxes(first_frame, first_frame_results) + if players_imgs: + grass_color = get_grass_color(first_frame) + grass_hsv = cv2.cvtColor(np.uint8([[list(grass_color)]]), cv2.COLOR_BGR2HSV) + kits_colors = get_kits_colors(players_imgs, grass_hsv) + if kits_colors: # Only proceed if we have valid kit colors + kits_clf = get_kits_classifier(kits_colors) + if kits_clf is not None: + left_team_label = int(get_left_team_label(players_boxes, kits_colors, kits_clf)) + + # Process team identification and boundary checking + processed_results = process_team_identification_batch( + batch_frames, raw_results, kits_clf, left_team_label, grass_hsv + ) + + processed_results = convert_numpy_types(processed_results) + results.extend(processed_results) + + # Return the same format as before for compatibility + return results, kits_clf, left_team_label, grass_hsv \ No newline at end of file diff --git a/team_cluster.pyc b/team_cluster.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65a79a86f7949cf2cf91eedc7c9ab22a6a83e1f2 Binary files /dev/null and b/team_cluster.pyc differ diff --git a/test_predict_batch.py b/test_predict_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..5143dc2fc566b8cf29e85ef5cd0f75ee2bb5b13a --- /dev/null +++ b/test_predict_batch.py @@ -0,0 +1,770 @@ +import argparse +import time +from collections import defaultdict +from pathlib import Path +from typing import List, Tuple, Dict + +import cv2 +import numpy as np + +from miner3 import Miner, TVFrameResult, BoundingBox +from keypoint_evaluation import ( + evaluate_keypoints_for_frame, + evaluate_keypoints_for_frame_opencv_cuda, + evaluate_keypoints_batch_gpu, + load_template_from_file, + project_image_using_keypoints, + extract_masks_for_ground_and_lines, + extract_mask_of_ground_lines_in_image, + extract_masks_for_ground_and_lines_no_validation, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run Miner.predict_batch on a video and visualize results." + ) + parser.add_argument( + "--repo-path", + type=Path, + default="", + help="Path to the HuggingFace/SecretVision repository (models, configs).", + ) + parser.add_argument( + "--video-path", + type=Path, + default="2025_06_28_e40fec95_39d4f90f11cd419b89c620a6442d37_1414c99f.mp4", + help="Path to the input video file.", + ) + parser.add_argument( + "--output-video", + type=Path, + default='outputs/annotated.mp4', + help="Optional path to save an annotated video.", + ) + parser.add_argument( + "--output-dir", + type=Path, + default='outputs/frames', + help="Optional directory to dump annotated frames.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=64, + help="Number of frames per predict_batch call.", + ) + parser.add_argument( + "--stride", + type=int, + default=1, + help="Sample every Nth frame from the video.", + ) + parser.add_argument( + "--max-frames", + type=int, + default=None, + help="Maximum number of frames to process (after stride).", + ) + parser.add_argument( + "--visualize-keypoints", + type=Path, + default="outputs/keypoints_visualizations", + help="Optional directory to save keypoint evaluation visualizations (warped template + original template for all frames).", + ) + parser.add_argument( + "--n-keypoints", + type=int, + default=32, + help="Number of keypoints Miner should return per frame.", + ) + parser.add_argument( + "--template-image", + type=Path, + default='football_pitch_template.png', + help="Path to football pitch template image (default: football_pitch_template.png in repo path).", + ) + return parser.parse_args() + + +def draw_keypoints(frame: np.ndarray, keypoints: List[Tuple[int, int]]) -> None: + for x, y in keypoints: + if x == 0 and y == 0: + continue + cv2.circle(frame, (x, y), radius=2, color=(0, 255, 255), thickness=-1) + + +def draw_boxes(frame: np.ndarray, boxes: List[BoundingBox]) -> None: + color_map = { + 0: (0, 255, 255), # football + 1: (0, 165, 255), # referee + 2: (0, 255, 0), # generic player + 3: (255, 0, 0), # goalkeeper + 4: (128, 128, 128), # staff + 5: (255, 255, 0), # coach/etc. + 6: (255, 0, 255), # team A + 7: (0, 128, 255), # team B + } + for box in boxes: + color = color_map.get(box.cls_id, (255, 255, 255)) + cv2.rectangle(frame, (box.x1, box.y1), (box.x2, box.y2), color, 2) + label = f"{box.cls_id}:{box.conf:.2f}" + cv2.putText( + frame, + label, + (box.x1, max(10, box.y1 - 5)), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + color, + 1, + lineType=cv2.LINE_AA, + ) + + +def annotate_frame(frame: np.ndarray, result: TVFrameResult) -> np.ndarray: + annotated = frame.copy() + draw_boxes(annotated, result.boxes) + draw_keypoints(annotated, result.keypoints) + cv2.putText( + annotated, + f"Frame {result.frame_id}", + (10, 20), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (255, 255, 255), + 2, + lineType=cv2.LINE_AA, + ) + return annotated + + +def ensure_output_dir(path: Path) -> None: + if path is not None: + path.mkdir(parents=True, exist_ok=True) + + +def aggregate_stats(results: List[TVFrameResult]) -> Dict[str, float]: + class_counts = defaultdict(int) + team_counts = defaultdict(int) + total_boxes = 0 + for res in results: + for box in res.boxes: + class_counts[box.cls_id] += 1 + if box.cls_id in (6, 7): + team_counts[box.cls_id] += 1 + total_boxes += 1 + stats = { + "frames": len(results), + "boxes": total_boxes, + } + for cls_id, count in sorted(class_counts.items()): + stats[f"class_{cls_id}_count"] = count + for team_id, count in sorted(team_counts.items()): + stats[f"team_{team_id}_count"] = count + return stats + + +def visualize_keypoint_evaluation( + frame: np.ndarray, + frame_keypoints: List[Tuple[int, int]], + template_image: np.ndarray, + template_keypoints: List[Tuple[int, int]], + score: float, + output_path: Path, + frame_id: int, +) -> np.ndarray: + """ + Visualize keypoint evaluation by drawing warped template and original template side by side. + + Args: + frame: Original frame image + frame_keypoints: Keypoints detected in the frame + template_image: Original template image + template_keypoints: Template keypoints + score: Evaluation score + output_path: Path to save the visualization + frame_id: Frame ID for labeling + + Returns: + Visualization image with warped template and original template side by side + """ + # Try to warp template to frame, but handle twisted projection gracefully + warped_template = None + mask_lines_expected = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8) + mask_lines_predicted = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8) + is_twisted = False + + try: + # Warp template to frame + warped_template = project_image_using_keypoints( + image=template_image, + source_keypoints=template_keypoints, + destination_keypoints=frame_keypoints, + destination_width=frame.shape[1], + destination_height=frame.shape[0], + ) + + # Extract masks for visualization + try: + mask_ground, mask_lines_expected = extract_masks_for_ground_and_lines( + image=warped_template + ) + mask_lines_predicted = extract_mask_of_ground_lines_in_image( + image=frame, ground_mask=mask_ground + ) + except Exception as e: + # If mask extraction fails, create empty masks + mask_lines_expected = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8) + mask_lines_predicted = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8) + except Exception as e: + # If warping fails (e.g., twisted projection), create a blank warped template + # but still draw keypoints + is_twisted = "twisted" in str(e).lower() or "Projection twisted" in str(e) + warped_template = None + print(f"Warning: Could not warp template for frame {frame_id}: {e}") + + # Always create visualization, even if warping failed + # Resize template to match frame height for side-by-side display + template_resized = cv2.resize( + template_image, + (int(template_image.shape[1] * frame.shape[0] / template_image.shape[0]), frame.shape[0]) + ) + + # Create side-by-side visualization: Frame | Warped Template | Original Template + h, w = frame.shape[:2] + template_h, template_w = template_resized.shape[:2] + spacing = 10 + vis_width = w + spacing + w + spacing + template_w + 20 # Frame + spacing + Warped + spacing + Template + margin + # Calculate number of non-zero keypoints to determine height needed + # Include all keypoints except (0, 0) which means "not detected" + num_valid_keypoints = sum(1 for x, y in frame_keypoints if not (x == 0 and y == 0)) + max_lines_per_column = 10 + num_columns = (num_valid_keypoints + max_lines_per_column - 1) // max_lines_per_column + keypoint_text_height = 55 + min(max_lines_per_column, num_valid_keypoints) * 18 # Base offset + lines + vis_height = max(h, template_h) + max(60, keypoint_text_height) # Extra space for text and keypoints + + visualization = np.ones((vis_height, vis_width, 3), dtype=np.uint8) * 255 + + # Place frame on left + frame_with_mask = frame.copy() + # Overlay predicted lines (green) on frame + mask_predicted_colored = np.zeros_like(frame_with_mask) + mask_predicted_colored[:, :, 1] = mask_lines_predicted * 255 # Green channel + frame_with_mask = cv2.addWeighted(frame_with_mask, 0.7, mask_predicted_colored, 0.3, 0) + visualization[:h, :w] = frame_with_mask + + # Place warped template in middle + warped_x = w + spacing + if warped_template is not None: + warped_with_mask = warped_template.copy() + # Overlay expected lines (blue) on warped template + mask_expected_colored = np.zeros_like(warped_with_mask) + mask_expected_colored[:, :, 0] = mask_lines_expected * 255 # Blue channel + warped_with_mask = cv2.addWeighted(warped_with_mask, 0.7, mask_expected_colored, 0.3, 0) + # Also overlay predicted lines (green) for comparison + mask_predicted_colored_warped = np.zeros_like(warped_with_mask) + mask_predicted_colored_warped[:, :, 1] = mask_lines_predicted * 255 # Green channel + warped_with_mask = cv2.addWeighted(warped_with_mask, 0.8, mask_predicted_colored_warped, 0.2, 0) + visualization[:h, warped_x:warped_x+w] = warped_with_mask + else: + # If warping failed, show a blank/error image + error_img = np.zeros((h, w, 3), dtype=np.uint8) + cv2.putText( + error_img, "Warping Failed", (w//4, h//2), + cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2 + ) + visualization[:h, warped_x:warped_x+w] = error_img + + # Place original template on right + template_x = warped_x + w + spacing + visualization[:template_h, template_x:template_x+template_w] = template_resized + + # Draw keypoints on frame (ALWAYS draw, even if warping failed) + # Only skip (0, 0) which means "not detected", but allow negative coordinates + for i, (x, y) in enumerate(frame_keypoints): + if not (x == 0 and y == 0): + # Clamp coordinates to visualization bounds for drawing + draw_x = max(0, min(x, vis_width - 1)) + draw_y = max(0, min(y, vis_height - 1)) + cv2.circle(visualization, (draw_x, draw_y), 5, (0, 255, 0), -1) + cv2.putText( + visualization, str(i+1), (draw_x+8, draw_y-8), + cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1 + ) + + # Add labels and score + cv2.putText( + visualization, "Original Frame (Green=Predicted Lines)", (10, h + 20), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2 + ) + warped_label = f"Warped Template (Blue=Expected, Green=Predicted, Score: {score:.3f})" + if is_twisted: + warped_label += " [TWISTED]" + cv2.putText( + visualization, warped_label, (warped_x, h + 20), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255) if is_twisted else (0, 0, 0), 2 + ) + cv2.putText( + visualization, "Original Template", (template_x, template_h + 20), + cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2 + ) + + cv2.putText( + visualization, f"Frame {frame_id}", (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 0, 0), 2 + ) + + # Display keypoint coordinates at bottom left of whole image + line_height = 18 + font_scale = 0.4 + font_thickness = 1 + + # Format keypoints: show index and coordinates for non-zero keypoints + keypoint_lines = [] + for i, (x, y) in enumerate(frame_keypoints): + # Include all keypoints except (0, 0) which means "not detected" + # Display negative coordinates as well + if not (x == 0 and y == 0): + keypoint_lines.append(f"KP{i+1}: ({x},{y})") + + # Display keypoints in columns to save space, starting from bottom + max_lines_per_column = 10 + num_columns = (len(keypoint_lines) + max_lines_per_column - 1) // max_lines_per_column + column_width = 150 + + # Starting y position from bottom + start_y_bottom = vis_height - 10 # Start 10 pixels from bottom + + for col_idx in range(num_columns): + start_idx = col_idx * max_lines_per_column + end_idx = min(start_idx + max_lines_per_column, len(keypoint_lines)) + x_pos = 10 + col_idx * column_width + column_lines = keypoint_lines[start_idx:end_idx] + num_lines_in_column = len(column_lines) + + for line_idx, kp_line in enumerate(column_lines): + # Calculate y position from bottom (working upwards) + # Last line in column is at start_y_bottom, first line is above it + y_pos = start_y_bottom - (num_lines_in_column - line_idx - 1) * line_height + cv2.putText( + visualization, kp_line, (x_pos, y_pos), + cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), font_thickness + ) + + # Save visualization + output_path.parent.mkdir(parents=True, exist_ok=True) + cv2.imwrite(str(output_path), visualization) + + return visualization + + +def evaluate_keypoints_batch( + results: List[TVFrameResult], + original_frames: Dict[int, np.ndarray], + template_image: np.ndarray, + template_keypoints: List[Tuple[int, int]], + visualization_output_dir: Path = None, +) -> Dict[str, float]: + """ + Evaluate keypoint accuracy for a batch of results. + + Args: + results: List of TVFrameResult objects with keypoints + original_frames: Dictionary mapping frame_id to frame image + template_image: Template image for evaluation + template_keypoints: Template keypoints + visualization_output_dir: Optional directory to save visualization images for all frames + + Returns: + Dictionary with keypoint evaluation statistics + """ + frame_scores = [] + valid_frames = 0 + + for result in results: + frame_id = result.frame_id + if frame_id not in original_frames: + continue + + frame_image = original_frames[frame_id] + frame_keypoints = result.keypoints + + # Need at least 4 valid keypoints for homography + valid_keypoints = [kp for kp in frame_keypoints if kp[0] != 0.0 or kp[1] != 0.0] + if len(valid_keypoints) < 4: + score = 0.0 + frame_scores.append(score) + # Still visualize even if invalid + if visualization_output_dir: + vis_path = visualization_output_dir / f"frame_{frame_id:06d}_score_{score:.3f}_invalid.jpg" + visualize_keypoint_evaluation( + frame=frame_image, + frame_keypoints=frame_keypoints, + template_image=template_image, + template_keypoints=template_keypoints, + score=score, + output_path=vis_path, + frame_id=frame_id, + ) + continue + + if len(frame_keypoints) != len(template_keypoints): + score = 0.0 + frame_scores.append(score) + # Still visualize even if mismatch + if visualization_output_dir: + vis_path = visualization_output_dir / f"frame_{frame_id:06d}_score_{score:.3f}_mismatch.jpg" + visualize_keypoint_evaluation( + frame=frame_image, + frame_keypoints=frame_keypoints, + template_image=template_image, + template_keypoints=template_keypoints, + score=score, + output_path=vis_path, + frame_id=frame_id, + ) + continue + + try: + score = evaluate_keypoints_for_frame( + template_keypoints=template_keypoints, + frame_keypoints=frame_keypoints, + frame=frame_image, + floor_markings_template=template_image.copy(), + ) + print(f'Frame {frame_id} score: {score}') + frame_scores.append(score) + valid_frames += 1 + + # Visualize all frames + if visualization_output_dir: + vis_path = visualization_output_dir / f"frame_{frame_id:06d}_score_{score:.3f}.jpg" + visualize_keypoint_evaluation( + frame=frame_image, + frame_keypoints=frame_keypoints, + template_image=template_image, + template_keypoints=template_keypoints, + score=score, + output_path=vis_path, + frame_id=frame_id, + ) + except Exception as e: + print(f"Error evaluating keypoints for frame {frame_id}: {e}") + score = 0.0 + frame_scores.append(score) + # Visualize failed frames too + if visualization_output_dir: + vis_path = visualization_output_dir / f"frame_{frame_id:06d}_score_{score:.3f}_error.jpg" + visualize_keypoint_evaluation( + frame=frame_image, + frame_keypoints=frame_keypoints, + template_image=template_image, + template_keypoints=template_keypoints, + score=score, + output_path=vis_path, + frame_id=frame_id, + ) + + if len(frame_scores) == 0: + return { + "keypoint_avg_score": 0.0, + "keypoint_valid_frames": 0, + "keypoint_total_frames": len(results), + } + + return { + "keypoint_avg_score": sum(frame_scores) / len(frame_scores), + "keypoint_max_score": max(frame_scores), + "keypoint_min_score": min(frame_scores), + "keypoint_valid_frames": valid_frames, + "keypoint_total_frames": len(results), + "keypoint_frames_above_0.5": sum(1 for s in frame_scores if s > 0.5), + "keypoint_frames_above_0.7": sum(1 for s in frame_scores if s > 0.7), + } + + +def evaluate_keypoints_batch_fast( + results: List[TVFrameResult], + original_frames: Dict[int, np.ndarray], + template_image: np.ndarray, + template_keypoints: List[Tuple[int, int]], + batch_size: int = 32, +) -> Dict[str, float]: + """ + Fast batch GPU evaluation of keypoint accuracy for multiple frames simultaneously. + + This function uses batch GPU processing to evaluate frames in smaller batches, + which is 5-10x faster than sequential evaluation while avoiding memory issues. + + Args: + results: List of TVFrameResult objects + original_frames: Dictionary mapping frame_id to frame image + template_image: Template image for evaluation + template_keypoints: Template keypoints + batch_size: Number of frames to process in each GPU batch (default: 8) + + Returns: + Dictionary with keypoint evaluation statistics + """ + # Prepare batch data + frame_keypoints_list = [] + frames_list = [] + result_indices = [] + + for idx, result in enumerate(results): + frame_id = result.frame_id + if frame_id not in original_frames: + continue + + frame_image = original_frames[frame_id] + frame_keypoints = result.keypoints + + # Need at least 4 valid keypoints for homography + valid_keypoints = [kp for kp in frame_keypoints if kp[0] != 0.0 or kp[1] != 0.0] + if len(valid_keypoints) < 4: + continue + + if len(frame_keypoints) != len(template_keypoints): + continue + + frame_keypoints_list.append(frame_keypoints) + frames_list.append(frame_image) + result_indices.append(idx) + + if len(frames_list) == 0: + return { + "keypoint_avg_score": 0.0, + "keypoint_valid_frames": 0, + "keypoint_total_frames": len(results), + } + + # Process in smaller batches to avoid memory issues + all_scores = [] + all_result_indices = [] + + num_batches = (len(frames_list) + batch_size - 1) // batch_size + + for batch_idx in range(num_batches): + start_idx = batch_idx * batch_size + end_idx = min(start_idx + batch_size, len(frames_list)) + + batch_frames = frames_list[start_idx:end_idx] + batch_keypoints = frame_keypoints_list[start_idx:end_idx] + batch_indices = result_indices[start_idx:end_idx] + + # Use batch GPU evaluation for this chunk + try: + scores_batch = evaluate_keypoints_batch_gpu( + template_keypoints=template_keypoints, + frame_keypoints_list=batch_keypoints, + frames=batch_frames, + floor_markings_template=template_image, + device="cuda", + ) + all_scores.extend(scores_batch) + all_result_indices.extend(batch_indices) + except Exception as e: + print(f"Error in batch GPU evaluation (batch {batch_idx + 1}/{num_batches}): {e}, falling back to sequential for this batch") + # Fallback to sequential evaluation for this batch + for frame_keypoints, frame_image, result_idx in zip(batch_keypoints, batch_frames, batch_indices): + try: + score = evaluate_keypoints_for_frame_opencv_cuda( + template_keypoints=template_keypoints, + frame_keypoints=frame_keypoints, + frame=frame_image, + floor_markings_template=template_image.copy(), + ) + all_scores.append(score) + all_result_indices.append(result_idx) + except Exception as e2: + print(f"Error evaluating keypoints: {e2}") + all_scores.append(0.0) + all_result_indices.append(result_idx) + + # Map scores back to all results (0.0 for frames that weren't evaluated) + frame_scores = [0.0] * len(results) + valid_frames = 0 + for result_idx, score in zip(all_result_indices, all_scores): + frame_scores[result_idx] = score + if score > 0.0: + valid_frames += 1 + + if len([s for s in frame_scores if s > 0.0]) == 0: + return { + "keypoint_avg_score": 0.0, + "keypoint_valid_frames": 0, + "keypoint_total_frames": len(results), + } + + # Calculate statistics only on valid scores + valid_scores = [s for s in frame_scores if s > 0.0] + + return { + "keypoint_avg_score": sum(valid_scores) / len(valid_scores) if valid_scores else 0.0, + "keypoint_max_score": max(valid_scores) if valid_scores else 0.0, + "keypoint_min_score": min(valid_scores) if valid_scores else 0.0, + "keypoint_valid_frames": valid_frames, + "keypoint_total_frames": len(results), + "keypoint_frames_above_0.5": sum(1 for s in valid_scores if s > 0.5), + "keypoint_frames_above_0.7": sum(1 for s in valid_scores if s > 0.7), + } + + +def process_batches( + miner: Miner, + frames: List[np.ndarray], + frame_ids: List[int], + n_keypoints: int, +) -> List[TVFrameResult]: + start = time.time() + results = miner.predict_batch(frames, offset=frame_ids[0], n_keypoints=n_keypoints) + end = time.time() + print( + f"[Batch frames {frame_ids[0]}-{frame_ids[-1]}] " + f"predict_batch latency: {end - start:.2f}s " + f"({len(frames) / (end - start + 1e-6):.2f} FPS)" + ) + return results + + +def main() -> None: + args = parse_args() + miner = Miner(args.repo_path) + + cap = cv2.VideoCapture(str(args.video_path)) + if not cap.isOpened(): + raise RuntimeError(f"Unable to open video: {args.video_path}") + + ensure_output_dir(args.output_dir) + + # Get video dimensions + fps = cap.get(cv2.CAP_PROP_FPS) or 25.0 + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + # Determine template image path + if args.template_image: + template_image_path = args.template_image + else: + # Use default: football_pitch_template.png in repo path + template_image_path = args.repo_path / "football_pitch_template.png" + + if not template_image_path.exists(): + raise ValueError( + f"Template image not found: {template_image_path}. " + f"Please provide --template-image path or place football_pitch_template.png in repo path." + ) + + # Load template for keypoint evaluation + print(f"Loading template from {template_image_path}") + template_image, template_keypoints = load_template_from_file(str(template_image_path)) + print(f"Loaded template with {len(template_keypoints)} keypoints") + + writer = None + if args.output_video: + args.output_video.parent.mkdir(parents=True, exist_ok=True) + writer = cv2.VideoWriter( + str(args.output_video), + cv2.VideoWriter_fourcc(*"mp4v"), + fps / args.stride, + (width, height), + ) + + processed_results: List[TVFrameResult] = [] + frames_buffer: List[np.ndarray] = [] + frame_ids_buffer: List[int] = [] + original_frames: Dict[int, np.ndarray] = {} # Store original frames for evaluation + processed_frames = 0 + source_frame_idx = 0 + + start_time = time.time() + while True: + ret, frame = cap.read() + if not ret: + break + if source_frame_idx % args.stride != 0: + source_frame_idx += 1 + continue + + frames_buffer.append(frame) + frame_ids_buffer.append(source_frame_idx) + original_frames[source_frame_idx] = frame.copy() # Store for evaluation + processed_frames += 1 + source_frame_idx += 1 + + if args.max_frames and processed_frames >= args.max_frames: + break + if len(frames_buffer) < args.batch_size: + continue + + batch_results = process_batches( + miner, frames_buffer, frame_ids_buffer, args.n_keypoints + ) + processed_results.extend(batch_results) + for res, original in zip(batch_results, frames_buffer): + annotated = annotate_frame(original, res) + if writer: + writer.write(annotated) + if args.output_dir: + frame_path = args.output_dir / f"frame_{res.frame_id:06d}.jpg" + cv2.imwrite(str(frame_path), annotated) + frames_buffer.clear() + frame_ids_buffer.clear() + + # Flush remaining frames + if frames_buffer: + batch_results = process_batches( + miner, frames_buffer, frame_ids_buffer, args.n_keypoints + ) + processed_results.extend(batch_results) + for res, original in zip(batch_results, frames_buffer): + annotated = annotate_frame(original, res) + if writer: + writer.write(annotated) + if args.output_dir: + frame_path = args.output_dir / f"frame_{res.frame_id:06d}.jpg" + cv2.imwrite(str(frame_path), annotated) + + cap.release() + if writer: + writer.release() + + stats = aggregate_stats(processed_results) + + end_time = time.time() + print(f"Total time taken: {end_time - start_time:.2f} seconds") + + # Evaluate keypoints (using fast batch GPU evaluation) + time_start = time.time() + print("\n===== Evaluating Keypoints =====") + keypoint_stats = evaluate_keypoints_batch( + processed_results, + original_frames, + template_image, + template_keypoints, + visualization_output_dir=args.visualize_keypoints, + ) + time_end = time.time() + print(f"Keypoint evaluation time: {time_end - time_start:.2f} seconds") + + print("\n===== Summary =====") + for key, value in stats.items(): + print(f"{key}: {value}") + if stats["frames"]: + avg_boxes = stats["boxes"] / stats["frames"] + print(f"Average boxes per frame: {avg_boxes:.2f}") + + print("\n===== Keypoint Evaluation =====") + for key, value in keypoint_stats.items(): + print(f"{key}: {value}") + if keypoint_stats["keypoint_total_frames"] > 0: + valid_ratio = keypoint_stats["keypoint_valid_frames"] / keypoint_stats["keypoint_total_frames"] + print(f"Keypoint evaluation success rate: {valid_ratio:.2%}") + + print("Done.") + + +if __name__ == "__main__": + main() + + diff --git a/test_predict_keypoints_video.py b/test_predict_keypoints_video.py new file mode 100644 index 0000000000000000000000000000000000000000..84a918af8c0f14c929a5e175b91432792842ed46 --- /dev/null +++ b/test_predict_keypoints_video.py @@ -0,0 +1,335 @@ +import argparse +import time +from pathlib import Path +from typing import List, Dict, Tuple +import sys +import os + +import cv2 +import numpy as np + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from miner import Miner + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Test keypoint prediction on video file with maximum speed optimization." + ) + parser.add_argument( + "--repo-path", + type=Path, + default="", + help="Path to the HuggingFace/SecretVision repository (models, configs).", + ) + parser.add_argument( + "--video-path", + type=Path, + default="test.mp4", + help="Path to the input video file.", + ) + parser.add_argument( + "--output-video", + type=Path, + default="outputs-keypoints/annotated.mp4", + help="Optional path to save an annotated video with keypoints.", + ) + parser.add_argument( + "--output-dir", + type=Path, + default="outputs-keypoints/frames", + help="Optional directory to save annotated frames.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Batch size for keypoint prediction (None = auto, processes all frames at once for max speed).", + ) + parser.add_argument( + "--stride", + type=int, + default=1, + help="Sample every Nth frame from the video (1 = all frames).", + ) + parser.add_argument( + "--max-frames", + type=int, + default=None, + help="Maximum number of frames to process (after stride).", + ) + parser.add_argument( + "--n-keypoints", + type=int, + default=32, + help="Number of keypoints expected per frame.", + ) + parser.add_argument( + "--conf-threshold", + type=float, + default=0.5, + help="Confidence threshold for regular keypoints.", + ) + parser.add_argument( + "--corner-conf-threshold", + type=float, + default=0.3, + help="Confidence threshold for corner keypoints.", + ) + parser.add_argument( + "--no-visualization", + action="store_true", + help="Skip visualization to maximize speed.", + ) + return parser.parse_args() + + +def draw_keypoints(frame: np.ndarray, keypoints: List[Tuple[int, int]], + color: Tuple[int, int, int] = (0, 255, 255)) -> None: + """Draw keypoints on frame.""" + for x, y in keypoints: + if x == 0 and y == 0: + continue + cv2.circle(frame, (x, y), radius=3, color=color, thickness=-1) + cv2.circle(frame, (x, y), radius=5, color=(0, 0, 0), thickness=1) + + +def annotate_frame(frame: np.ndarray, keypoints: List[Tuple[int, int]], + frame_id: int) -> np.ndarray: + """Annotate frame with keypoints and frame ID.""" + annotated = frame.copy() + draw_keypoints(annotated, keypoints) + + # Count valid keypoints + valid_count = sum(1 for kp in keypoints if kp[0] != 0 or kp[1] != 0) + + # Draw frame info + info_text = f"Frame {frame_id} | Keypoints: {valid_count}/{len(keypoints)}" + cv2.putText( + annotated, + info_text, + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (255, 255, 255), + 2, + lineType=cv2.LINE_AA, + ) + return annotated + + +def load_video_frames(video_path: Path, stride: int = 1, max_frames: int = None) -> List[np.ndarray]: + """Load frames from video file.""" + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + raise RuntimeError(f"Unable to open video: {video_path}") + + frames = [] + frame_count = 0 + source_frame_idx = 0 + + print(f"Loading frames from video: {video_path}") + while True: + ret, frame = cap.read() + if not ret: + break + + if source_frame_idx % stride != 0: + source_frame_idx += 1 + continue + + frames.append(frame) + frame_count += 1 + source_frame_idx += 1 + + if max_frames and frame_count >= max_frames: + break + + if frame_count % 100 == 0: + print(f"Loaded {frame_count} frames...") + + cap.release() + print(f"Total frames loaded: {len(frames)}") + return frames + + +def save_results( + frames: List[np.ndarray], + keypoints_dict: Dict[int, List[Tuple[int, int]]], + output_video: Path = None, + output_dir: Path = None, + fps: float = 25.0, + width: int = None, + height: int = None, +) -> None: + """Save annotated frames and/or video.""" + if output_video is None and output_dir is None: + return + + if width is None or height is None: + height, width = frames[0].shape[:2] + + writer = None + if output_video: + output_video.parent.mkdir(parents=True, exist_ok=True) + writer = cv2.VideoWriter( + str(output_video), + cv2.VideoWriter_fourcc(*"mp4v"), + fps, + (width, height), + ) + print(f"Saving annotated video to: {output_video}") + + if output_dir: + output_dir.mkdir(parents=True, exist_ok=True) + print(f"Saving annotated frames to: {output_dir}") + + for frame_idx, frame in enumerate(frames): + keypoints = keypoints_dict.get(frame_idx, []) + annotated = annotate_frame(frame, keypoints, frame_idx) + + if writer: + writer.write(annotated) + + if output_dir: + frame_path = output_dir / f"frame_{frame_idx:06d}.jpg" + cv2.imwrite(str(frame_path), annotated) + + if (frame_idx + 1) % 100 == 0: + print(f"Saved {frame_idx + 1}/{len(frames)} frames...") + + if writer: + writer.release() + print(f"Video saved: {output_video}") + + +def calculate_statistics(keypoints_dict: Dict[int, List[Tuple[int, int]]]) -> Dict[str, float]: + """Calculate keypoint detection statistics.""" + total_frames = len(keypoints_dict) + if total_frames == 0: + return { + "total_frames": 0, + "avg_valid_keypoints": 0.0, + "max_valid_keypoints": 0, + "min_valid_keypoints": 0, + "frames_with_keypoints": 0, + } + + valid_counts = [] + frames_with_keypoints = 0 + + for keypoints in keypoints_dict.values(): + valid_count = sum(1 for kp in keypoints if kp[0] != 0 or kp[1] != 0) + valid_counts.append(valid_count) + if valid_count > 0: + frames_with_keypoints += 1 + + return { + "total_frames": total_frames, + "avg_valid_keypoints": sum(valid_counts) / len(valid_counts) if valid_counts else 0.0, + "max_valid_keypoints": max(valid_counts) if valid_counts else 0, + "min_valid_keypoints": min(valid_counts) if valid_counts else 0, + "frames_with_keypoints": frames_with_keypoints, + "keypoint_detection_rate": frames_with_keypoints / total_frames if total_frames > 0 else 0.0, + } + + +def main() -> None: + args = parse_args() + + # Initialize miner + print("Initializing Miner...") + init_start = time.time() + miner = Miner(args.repo_path) + init_time = time.time() - init_start + print(f"Miner initialized in {init_time:.2f} seconds") + + # Load video frames + print("\n" + "="*60) + print("Loading video frames...") + load_start = time.time() + frames = load_video_frames(args.video_path, args.stride, args.max_frames) + load_time = time.time() - load_start + print(f"Frames loaded in {load_time:.2f} seconds") + + if len(frames) == 0: + print("No frames loaded. Exiting.") + return + + # Get video properties for output + height, width = frames[0].shape[:2] + cap = cv2.VideoCapture(str(args.video_path)) + fps = cap.get(cv2.CAP_PROP_FPS) or 25.0 + cap.release() + + # Predict keypoints + print("\n" + "="*60) + print("Predicting keypoints...") + print(f"Total frames: {len(frames)}") + print(f"Batch size: {args.batch_size if args.batch_size else 'auto (all frames)'}") + print(f"Confidence threshold: {args.conf_threshold}") + print(f"Corner confidence threshold: {args.corner_conf_threshold}") + + predict_start = time.time() + keypoints_dict = miner.predict_keypoints( + images=frames, + n_keypoints=args.n_keypoints, + batch_size=args.batch_size, + conf_threshold=args.conf_threshold, + corner_conf_threshold=args.corner_conf_threshold, + verbose=True, + ) + predict_time = time.time() - predict_start + + # Calculate performance metrics + total_frames = len(frames) + fps_achieved = total_frames / predict_time if predict_time > 0 else 0 + time_per_frame = predict_time / total_frames if total_frames > 0 else 0 + + # Print performance results + print("\n" + "="*60) + print("KEYPOINT PREDICTION PERFORMANCE") + print("="*60) + print(f"Total frames processed: {total_frames}") + print(f"Total prediction time: {predict_time:.3f} seconds") + print(f"Average time per frame: {time_per_frame*1000:.2f} ms") + print(f"Throughput: {fps_achieved:.2f} FPS") + print(f"Batch processing: {'Yes' if args.batch_size else 'No (single batch)'}") + + # Calculate and print statistics + stats = calculate_statistics(keypoints_dict) + print("\n" + "="*60) + print("KEYPOINT DETECTION STATISTICS") + print("="*60) + for key, value in stats.items(): + if isinstance(value, float): + print(f"{key}: {value:.2f}") + else: + print(f"{key}: {value}") + + # Save results if requested + if not args.no_visualization and (args.output_video or args.output_dir): + print("\n" + "="*60) + print("Saving results...") + save_start = time.time() + save_results( + frames=frames, + keypoints_dict=keypoints_dict, + output_video=args.output_video, + output_dir=args.output_dir, + fps=fps / args.stride, + width=width, + height=height, + ) + save_time = time.time() - save_start + print(f"Results saved in {save_time:.2f} seconds") + + print("\n" + "="*60) + print("Done!") + print("="*60) + + +if __name__ == "__main__": + main() + diff --git a/test_predict_objects_video.py b/test_predict_objects_video.py new file mode 100644 index 0000000000000000000000000000000000000000..f85c3ba28957370b9fb279ff9df516b52c227231 --- /dev/null +++ b/test_predict_objects_video.py @@ -0,0 +1,688 @@ +import argparse +import time +from pathlib import Path +from typing import Dict, List, Tuple, Optional +import sys +import os +import queue +import threading +import tempfile +from urllib.parse import urlparse + +import cv2 +import requests + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from miner import Miner, BoundingBox + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="High-speed object detection benchmark on a video file." + ) + parser.add_argument( + "--repo-path", + type=Path, + default="", + help="Path to the HuggingFace/SecretVision repository (models, configs).", + ) + parser.add_argument( + "--video-path", + type=str, + default="test.mp4", + help="Path to the input video file or URL (http:// or https://).", + ) + parser.add_argument( + "--video-url", + type=str, + default=None, + help="URL to download video from (alternative to --video-path).", + ) + parser.add_argument( + "--output-video", + type=Path, + default="outputs-detections/annotated.mp4", + help="Optional path to save an annotated video with detections.", + ) + parser.add_argument( + "--output-dir", + type=Path, + default="outputs-detections/frames", + help="Optional directory to save annotated frames.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Batch size for YOLO inference (None = process all frames at once).", + ) + parser.add_argument( + "--stride", + type=int, + default=1, + help="Sample every Nth frame from the video.", + ) + parser.add_argument( + "--max-frames", + type=int, + default=None, + help="Maximum number of frames to process (after stride).", + ) + parser.add_argument( + "--conf-threshold", + type=float, + default=0.5, + help="Confidence threshold for detections.", + ) + parser.add_argument( + "--iou-threshold", + type=float, + default=0.45, + help="IoU threshold used by YOLO NMS.", + ) + parser.add_argument( + "--classes", + type=int, + nargs="+", + default=None, + help="Optional list of class IDs to keep (default: all classes).", + ) + parser.add_argument( + "--no-visualization", + action="store_true", + help="Skip saving annotated frames/video to maximize throughput.", + ) + return parser.parse_args() + + +def draw_boxes(frame, boxes: List[BoundingBox]) -> None: + """Draw bounding boxes on a frame.""" + if not boxes: + return + + color_map = { + 0: (0, 255, 255), # ball - cyan + 1: (0, 165, 255), # goalkeeper - orange + 2: (0, 255, 0), # player - green + 3: (255, 0, 0), # referee - blue + 4: (128, 128, 128), # gray + 5: (255, 255, 0), # cyan + 6: (255, 0, 255), # magenta + 7: (0, 128, 255), # orange + } + + h, w = frame.shape[:2] + + for box in boxes: + # Validate and clamp coordinates + x1 = max(0, min(int(box.x1), w - 1)) + y1 = max(0, min(int(box.y1), h - 1)) + x2 = max(x1 + 1, min(int(box.x2), w)) + y2 = max(y1 + 1, min(int(box.y2), h)) + + if x2 <= x1 or y2 <= y1: + continue # Skip invalid boxes + + color = color_map.get(box.cls_id, (255, 255, 255)) + cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2) + label = f"{box.cls_id}:{box.conf:.2f}" + cv2.putText( + frame, + label, + (x1, max(12, y1 - 6)), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + color, + 1, + lineType=cv2.LINE_AA, + ) + + +def annotate_frame(frame, boxes: List[BoundingBox], frame_id: int) -> cv2.Mat: + annotated = frame.copy() + draw_boxes(annotated, boxes) + info = f"Frame {frame_id} | Boxes: {len(boxes)}" + cv2.putText( + annotated, + info, + (10, 25), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (255, 255, 255), + 2, + lineType=cv2.LINE_AA, + ) + return annotated + + +def download_video_from_url(url: str, temp_dir: Optional[Path] = None) -> Path: + """Download video from URL to a temporary file.""" + print(f"Downloading video from {url}...") + download_start = time.time() + + response = requests.get(url, stream=True, timeout=30) + response.raise_for_status() + + if temp_dir is None: + temp_dir = Path(tempfile.gettempdir()) + else: + temp_dir.mkdir(parents=True, exist_ok=True) + + # Get filename from URL or use a temp name + parsed_url = urlparse(url) + filename = os.path.basename(parsed_url.path) or "video.mp4" + temp_file = temp_dir / f"temp_{int(time.time())}_{filename}" + + with open(temp_file, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + download_time = time.time() - download_start + print(f"Download completed in {download_time:.3f}s") + return temp_file + + +def stream_video_frames( + video_path: Path, + frame_queue: queue.Queue, + stride: int = 1, + max_frames: Optional[int] = None, + stop_event: Optional[threading.Event] = None, +) -> Tuple[int, float]: + """ + Decode video frames in a separate thread and put them in a queue. + Returns: (total_frames_decoded, fps) + """ + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + raise RuntimeError(f"Unable to open video: {video_path}") + + fps = cap.get(cv2.CAP_PROP_FPS) or 25.0 + frame_count = 0 + source_idx = 0 + decode_start = time.time() + + print(f"Decoding frames from {video_path}...") + try: + while True: + if stop_event and stop_event.is_set(): + break + + ret, frame = cap.read() + if not ret: + break + + if source_idx % stride == 0: + frame_queue.put((frame_count, frame)) + frame_count += 1 + if max_frames and frame_count >= max_frames: + break + if frame_count % 100 == 0: + print(f"Decoded {frame_count} frames...") + + source_idx += 1 + finally: + cap.release() + frame_queue.put((None, None)) # Sentinel to signal end + + decode_time = time.time() - decode_start + print(f"Total frames decoded: {frame_count} in {decode_time:.3f}s") + return frame_count, fps + + +def load_video_frames( + video_path: Path, stride: int = 1, max_frames: Optional[int] = None +) -> List[cv2.Mat]: + """Legacy function: load all frames into memory (non-streaming).""" + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + raise RuntimeError(f"Unable to open video: {video_path}") + + frames: List[cv2.Mat] = [] + frame_count = 0 + source_idx = 0 + + print(f"Loading frames from {video_path}") + while True: + ret, frame = cap.read() + if not ret: + break + + if source_idx % stride == 0: + frames.append(frame) + frame_count += 1 + if max_frames and frame_count >= max_frames: + break + if frame_count % 100 == 0: + print(f"Loaded {frame_count} frames...") + + source_idx += 1 + + cap.release() + print(f"Total frames loaded: {len(frames)}") + return frames + + +def save_results( + frames: List[cv2.Mat], + detections: Dict[int, List[BoundingBox]], + output_video: Optional[Path], + output_dir: Optional[Path], + fps: float, +) -> None: + if output_video is None and output_dir is None: + return + + if not frames: + print("No frames to save.") + return + + height, width = frames[0].shape[:2] + writer = None + if output_video: + output_video.parent.mkdir(parents=True, exist_ok=True) + writer = cv2.VideoWriter( + str(output_video), + cv2.VideoWriter_fourcc(*"mp4v"), + fps, + (width, height), + ) + print(f"Saving annotated video to {output_video}") + + if output_dir: + output_dir.mkdir(parents=True, exist_ok=True) + print(f"Saving annotated frames to {output_dir}") + + for frame_idx, frame in enumerate(frames): + boxes = detections.get(frame_idx, []) + annotated = annotate_frame(frame, boxes, frame_idx) + + if writer: + writer.write(annotated) + if output_dir: + frame_path = output_dir / f"frame_{frame_idx:06d}.jpg" + cv2.imwrite(str(frame_path), annotated) + + if (frame_idx + 1) % 100 == 0: + print(f"Saved {frame_idx + 1}/{len(frames)} frames...") + + if writer: + writer.release() + print(f"Video saved to {output_video}") + + +def aggregate_stats(detections: Dict[int, List[BoundingBox]]) -> Dict[str, float]: + total_frames = len(detections) + total_boxes = sum(len(boxes) for boxes in detections.values()) + + class_counts: Dict[int, int] = {} + for boxes in detections.values(): + for box in boxes: + class_counts[box.cls_id] = class_counts.get(box.cls_id, 0) + 1 + + stats: Dict[str, float] = { + "frames": total_frames, + "boxes": total_boxes, + } + stats["avg_boxes_per_frame"] = ( + total_boxes / total_frames if total_frames > 0 else 0.0 + ) + for cls_id, count in sorted(class_counts.items()): + stats[f"class_{cls_id}_count"] = count + + return stats + + +def detection_worker( + miner: Miner, + frame_queue: queue.Queue, + result_queue: queue.Queue, + batch_size: int, + conf_threshold: float, + iou_threshold: float, + classes: Optional[List[int]], + stop_event: threading.Event, +) -> None: + """ + Worker thread that processes frames for detection. + Takes frames from frame_queue and puts results in result_queue. + """ + frame_batch: List[cv2.Mat] = [] + frame_indices: List[int] = [] + + while True: + if stop_event.is_set(): + break + + try: + item = frame_queue.get(timeout=0.5) + frame_idx, frame = item + + if frame_idx is None: # Sentinel - decoding finished + # Process remaining frames in batch + if frame_batch: + batch_detections = miner.predict_objects( + images=frame_batch, + batch_size=None, + conf_threshold=conf_threshold, + iou_threshold=iou_threshold, + classes=classes, + verbose=False, + ) + + result_queue.put(('batch', { + 'indices': frame_indices, + 'detections': batch_detections, + 'frames': frame_batch.copy(), + })) + + result_queue.put(('done', None)) + break + + frame_batch.append(frame) + frame_indices.append(frame_idx) + + # Process batch when full + if len(frame_batch) >= batch_size: + batch_detections = miner.predict_objects( + images=frame_batch, + batch_size=None, + conf_threshold=conf_threshold, + iou_threshold=iou_threshold, + classes=classes, + verbose=False, + ) + + # Debug: Check what we got + total_boxes_in_batch = sum(len(boxes) for boxes in batch_detections.values()) + if total_boxes_in_batch > 0: + print(f"Detection worker: Processed batch of {len(frame_batch)} frames, " + f"found {total_boxes_in_batch} boxes, " + f"detection keys: {list(batch_detections.keys())}, " + f"frame indices: {frame_indices[:5]}...") + + result_queue.put(('batch', { + 'indices': frame_indices.copy(), + 'detections': batch_detections, + 'frames': frame_batch.copy(), + })) + + frame_batch.clear() + frame_indices.clear() + + except queue.Empty: + continue + except Exception as e: + print(f"Error in detection worker: {e}") + result_queue.put(('error', str(e))) + break + + +def process_video_streaming( + miner: Miner, + video_path: Path, + batch_size: Optional[int], + conf_threshold: float, + iou_threshold: float, + classes: Optional[List[int]], + stride: int, + max_frames: Optional[int], +) -> Tuple[Dict[int, List[BoundingBox]], List[cv2.Mat], float, float]: + """ + Process video with truly parallel decode and detection. + Decode thread and detection thread run simultaneously. + Returns: (detections, frames, fps, total_time) + """ + frame_queue: queue.Queue = queue.Queue(maxsize=50) # Buffer for decoded frames + result_queue: queue.Queue = queue.Queue() # Results from detection + frames_queue: queue.Queue = queue.Queue() # Store all decoded frames separately + stop_event = threading.Event() + + effective_batch = batch_size if batch_size else 16 + + # Modified decode function that also stores frames + def decode_and_store_frames(): + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + raise RuntimeError(f"Unable to open video: {video_path}") + + fps = cap.get(cv2.CAP_PROP_FPS) or 25.0 + frame_count = 0 + source_idx = 0 + decode_start = time.time() + + print(f"Decoding frames from {video_path}...") + try: + while True: + if stop_event.is_set(): + break + + ret, frame = cap.read() + if not ret: + break + + if source_idx % stride == 0: + frame_queue.put((frame_count, frame)) + frames_queue.put((frame_count, frame)) # Store frame separately + frame_count += 1 + if max_frames and frame_count >= max_frames: + break + if frame_count % 100 == 0: + print(f"Decoded {frame_count} frames...") + + source_idx += 1 + finally: + cap.release() + frame_queue.put((None, None)) # Sentinel to signal end + frames_queue.put((None, None)) # Sentinel for frames queue + + decode_time = time.time() - decode_start + print(f"Total frames decoded: {frame_count} in {decode_time:.3f}s") + return frame_count, fps + + # Start decode thread + decode_thread = threading.Thread( + target=decode_and_store_frames, + daemon=True, + ) + + # Start detection thread + detect_thread = threading.Thread( + target=detection_worker, + args=(miner, frame_queue, result_queue, effective_batch, + conf_threshold, iou_threshold, classes, stop_event), + daemon=True, + ) + + print("\n" + "=" * 60) + print("Running parallel decode + detection...") + print(f"Batch size: {effective_batch}") + print(f"Conf threshold: {conf_threshold}") + print(f"IoU threshold: {iou_threshold}") + if classes: + print(f"Classes filtered: {classes}") + + total_time_start = time.time() + decode_thread.start() + detect_thread.start() + + # Collect all decoded frames first (independent of detection) + frames_dict: Dict[int, cv2.Mat] = {} + while True: + try: + frame_idx, frame = frames_queue.get(timeout=1.0) + if frame_idx is None: + break + frames_dict[frame_idx] = frame + except queue.Empty: + if not decode_thread.is_alive(): + break + continue + + # Collect results from detection thread + all_batches = [] # Store all batch results + frames_processed = 0 + detection_done = False + + while not detection_done: + try: + result_type, result_data = result_queue.get(timeout=2.0) + + if result_type == 'batch': + batch_boxes = sum(len(boxes) for boxes in result_data['detections'].values()) + all_batches.append(result_data) + frames_processed += len(result_data['indices']) + if batch_boxes > 0: + print(f"Collected batch: {len(result_data['indices'])} frames, {batch_boxes} boxes") + if frames_processed % 100 == 0: + print(f"Processed {frames_processed} frames...") + + elif result_type == 'done': + detection_done = True + break + + elif result_type == 'error': + print(f"Detection error: {result_data}") + detection_done = True + break + + except queue.Empty: + # Check if threads are still alive + if not detect_thread.is_alive(): + detection_done = True + break + continue + + # Assemble detections in correct order + detections: Dict[int, List[BoundingBox]] = {} + + print(f"Debug: Assembling detections from {len(all_batches)} batches...") + for batch_idx, batch_data in enumerate(all_batches): + batch_indices = batch_data['indices'] + batch_detections = batch_data['detections'] + + # Debug first batch + if batch_idx == 0: + print(f"Debug batch 0: {len(batch_indices)} frame indices, " + f"detection keys: {list(batch_detections.keys())}, " + f"total boxes in batch: {sum(len(boxes) for boxes in batch_detections.values())}") + + for local_idx, orig_idx in enumerate(batch_indices): + boxes = batch_detections.get(local_idx, []) + detections[orig_idx] = boxes + if batch_idx == 0 and local_idx < 3 and len(boxes) > 0: + print(f"Debug: Frame {orig_idx} (local_idx {local_idx}) has {len(boxes)} boxes") + + # Convert frames_dict to ordered list + if frames_dict: + max_idx = max(frames_dict.keys()) + frames = [frames_dict[i] for i in range(max_idx + 1) if i in frames_dict] + + # Debug: Print detection statistics + total_detections = sum(len(boxes) for boxes in detections.values()) + frames_with_detections = sum(1 for boxes in detections.values() if len(boxes) > 0) + print(f"Debug: {len(frames)} frames, {len(detections)} detection entries, " + f"{total_detections} total boxes, {frames_with_detections} frames with detections") + else: + frames = [] + + # Wait for threads to finish + decode_thread.join(timeout=5.0) + detect_thread.join(timeout=10.0) + total_time = time.time() - total_time_start + + # Get FPS from video metadata + cap = cv2.VideoCapture(str(video_path)) + fps = cap.get(cv2.CAP_PROP_FPS) or 25.0 + cap.release() + + return detections, frames, fps, total_time + + +def main() -> None: + args = parse_args() + + print("Initializing Miner...") + init_start = time.time() + miner = Miner(args.repo_path) + print(f"Miner initialized in {time.time() - init_start:.2f}s") + + # Handle URL or local file + video_path = args.video_url if args.video_url else args.video_path + temp_file = None + + # Check if it's a URL + if str(video_path).startswith(('http://', 'https://')): + print("\n" + "=" * 60) + temp_file = download_video_from_url(str(video_path)) + video_path = temp_file + + # Use streaming mode for parallel processing + print("\n" + "=" * 60) + process_start = time.time() + detections, frames, fps, total_time = process_video_streaming( + miner=miner, + video_path=Path(video_path), + batch_size=args.batch_size, + conf_threshold=args.conf_threshold, + iou_threshold=args.iou_threshold, + classes=args.classes, + stride=args.stride, + max_frames=args.max_frames, + ) + + # Clean up temp file if downloaded + if temp_file and temp_file.exists(): + try: + temp_file.unlink() + print(f"Cleaned up temporary file: {temp_file}") + except Exception as e: + print(f"Warning: Could not delete temp file {temp_file}: {e}") + + total_frames = len(frames) + fps_achieved = total_frames / total_time if total_time > 0 else 0.0 + time_per_frame = total_time / total_frames if total_frames > 0 else 0.0 + + print("\n" + "=" * 60) + print("OBJECT DETECTION PERFORMANCE") + print("=" * 60) + print(f"Total frames processed: {total_frames}") + print(f"Total processing time: {total_time:.3f}s") + print(f"Average time per frame: {time_per_frame*1000:.2f} ms") + print(f"Throughput: {fps_achieved:.2f} FPS") + + stats = aggregate_stats(detections) + print("\n" + "=" * 60) + print("DETECTION STATISTICS") + print("=" * 60) + for key, value in stats.items(): + if isinstance(value, float): + print(f"{key}: {value:.2f}") + else: + print(f"{key}: {value}") + + if not args.no_visualization and (args.output_video or args.output_dir) and frames: + print("\n" + "=" * 60) + print("Saving annotated outputs...") + save_start = time.time() + save_results( + frames=frames, + detections=detections, + output_video=args.output_video, + output_dir=args.output_dir, + fps=fps / args.stride, + ) + print(f"Outputs saved in {time.time() - save_start:.2f}s") + elif not frames: + print("\n" + "=" * 60) + print("No frames processed. Skipping output saving.") + + print("\n" + "=" * 60) + print("Done!") + print("=" * 60) + + +if __name__ == "__main__": + main() + diff --git a/utils.pyc b/utils.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9a1d87ac5483a28b0c295772d4b0c1d5f7f454f Binary files /dev/null and b/utils.pyc differ