| from __future__ import annotations |
|
|
| import json |
| import os |
| import sys |
| from pathlib import Path |
| from typing import Any |
|
|
| import cv2 |
|
|
| from miner import Miner |
|
|
|
|
# Detector class labels; position in the list is the model's integer class id.
CLASS_NAMES = ["football", "player", "pitch"]
# Model family identifier reported by ``load_model``.
MODEL_TYPE = "ultralytics-yolo"
|
|
|
|
| def _to_dict(value: Any) -> dict[str, Any]: |
| if isinstance(value, dict): |
| return value |
| if hasattr(value, "model_dump") and callable(value.model_dump): |
| dumped = value.model_dump() |
| if isinstance(dumped, dict): |
| return dumped |
| if hasattr(value, "__dict__"): |
| return dict(value.__dict__) |
| return {} |
|
|
|
|
def _extract_boxes(frame_result: Any) -> list[Any]:
    """Return the raw box entries stored under ``"boxes"`` in one frame result.

    ``frame_result`` is first normalised with ``_to_dict``, so it may be a dict,
    an object with a callable ``model_dump()``, or a plain attribute object.
    In ``run_model`` it is an element of the value returned by
    ``Miner.predict_batch``.

    Args:
        frame_result: A single per-frame result.

    Returns:
        The ``"boxes"`` value itself when it is a list (not copied), a new list
        when it is a tuple, and an empty list when the key is missing or holds
        any other type (``None``, a string, a mapping, ...).
    """
    boxes = _to_dict(frame_result).get("boxes")
    if isinstance(boxes, list):
        return boxes
    # Previously a tuple fell through to ``[]`` and silently produced zero
    # detections for the frame; treat it as the equivalent sequence it is.
    if isinstance(boxes, tuple):
        return list(boxes)
    return []
|
|
|
|
def _to_detection(box: Any) -> dict[str, Any]:
    """Convert one raw corner-format box into a centre-format detection dict.

    The box is normalised with ``_to_dict`` and read through the keys
    ``cls_id``, ``x1``, ``y1``, ``x2``, ``y2`` and ``conf``. Any missing key
    defaults to 0. Negative extents (``x2 < x1`` or ``y2 < y1``) are clamped to
    zero, in which case the centre coordinate collapses onto ``x1`` / ``y1``.

    Returns:
        A dict with keys ``x``, ``y`` (box centre), ``width``, ``height``,
        ``confidence``, ``class_id`` and ``class``. ``class`` is the matching
        ``CLASS_NAMES`` entry, or ``str(class_id)`` when the id is outside that
        list.
    """
    fields = _to_dict(box)
    class_id = int(fields.get("cls_id", 0))
    left, top, right, bottom = (
        float(fields.get(key, 0.0)) for key in ("x1", "y1", "x2", "y2")
    )

    # Clamp so an inverted box yields a zero-size box, never a negative one.
    box_w = max(0.0, right - left)
    box_h = max(0.0, bottom - top)
    confidence = float(fields.get("conf", 0.0))

    if 0 <= class_id < len(CLASS_NAMES):
        label = CLASS_NAMES[class_id]
    else:
        label = str(class_id)

    return {
        "x": left + box_w / 2.0,
        "y": top + box_h / 2.0,
        "width": box_w,
        "height": box_h,
        "confidence": confidence,
        "class_id": class_id,
        "class": label,
    }
|
|
|
|
def load_model(onnx_path: str | None = None, data_dir: str | None = None):
    """Build the model handle consumed by ``run_model``.

    Args:
        onnx_path: Accepted for interface compatibility and ignored.
        data_dir: Directory passed to ``Miner``. When falsy (``None`` or ``""``),
            the resolved directory containing this file is used instead.

    Returns:
        A dict with keys ``"miner"`` (the ``Miner`` instance), ``"model_type"``
        (``MODEL_TYPE``) and ``"class_names"`` (the module's ``CLASS_NAMES``
        list itself, not a copy).
    """
    del onnx_path  # unused; kept so callers passing it keep working
    if data_dir:
        repo_dir = Path(data_dir)
    else:
        repo_dir = Path(__file__).resolve().parent
    return {
        "miner": Miner(repo_dir),
        "model_type": MODEL_TYPE,
        "class_names": CLASS_NAMES,
    }
|
|
|
|
def run_model(model: Any, image: Any = None, onnx_path: str | None = None, data_dir: str | None = None):
    """Run detection on a single image and return ``[detections]``.

    Two calling conventions are supported:

    * ``run_model(model, image)``: ``model`` is the dict returned by ``load_model``.
    * ``run_model(image)``: the lone positional argument is taken as the image,
      and a fresh model is loaded via ``load_model(data_dir=data_dir)`` on every
      call.

    Args:
        model: The ``load_model`` dict, or the image when ``image`` is ``None``.
        image: The image passed to ``Miner.predict_batch``.
        onnx_path: Accepted for interface compatibility and ignored.
        data_dir: Only used when a model has to be loaded on the fly.

    Returns:
        A one-element list holding the detection dicts for the image (see
        ``_to_detection``), or ``[[]]`` when the miner returns no results.
    """
    del onnx_path  # unused; kept for signature compatibility
    if image is not None:
        handle = model
    else:
        image, handle = model, load_model(data_dir=data_dir)

    frames = handle["miner"].predict_batch([image], offset=0, n_keypoints=0)
    if not frames:
        return [[]]
    return [list(map(_to_detection, _extract_boxes(frames[0])))]
|
|
|
|
def main() -> None:
    """Command-line entry point: detect objects in one image and print JSON.

    Usage: ``main.py <image_path>``. The image is read with ``cv2.imread`` using
    ``cv2.IMREAD_COLOR``, the model is loaded from this file's directory, and
    ``run_model``'s output is printed to stdout as JSON indented by 2 spaces.
    Exits with status 1, after writing a message to stderr, when the path
    argument is missing or the image cannot be read.
    """
    args = sys.argv[1:]
    if not args:
        print("Usage: main.py <image_path>", file=sys.stderr)
        raise SystemExit(1)
    image_path = args[0]

    frame = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if frame is None:
        # cv2.imread returns None instead of raising on an unreadable file.
        print(f"Could not read image: {image_path}", file=sys.stderr)
        raise SystemExit(1)

    here = os.path.dirname(os.path.abspath(__file__))
    detections = run_model(load_model(data_dir=here), frame)
    print(json.dumps(detections, indent=2))
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|