MTerryJack commited on
Commit
45c4687
·
verified ·
1 Parent(s): 47307f1

subnet_bridge: add element runtime files

Browse files
Files changed (2) hide show
  1. main.py +99 -0
  2. pyproject.toml +13 -0
main.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import sys
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ import cv2
10
+
11
+ from miner import Miner
12
+
13
+
14
+ CLASS_NAMES = ['football', 'player', 'pitch']
15
+ MODEL_TYPE = 'ultralytics-yolo'
16
+
17
+
18
+ def _to_dict(value: Any) -> dict[str, Any]:
19
+ if isinstance(value, dict):
20
+ return value
21
+ if hasattr(value, "model_dump") and callable(value.model_dump):
22
+ dumped = value.model_dump()
23
+ if isinstance(dumped, dict):
24
+ return dumped
25
+ if hasattr(value, "__dict__"):
26
+ return dict(value.__dict__)
27
+ return {}
28
+
29
+
30
+ def _extract_boxes(frame_result: Any) -> list[Any]:
31
+ frame = _to_dict(frame_result)
32
+ boxes = frame.get("boxes", [])
33
+ if isinstance(boxes, list):
34
+ return boxes
35
+ return []
36
+
37
+
38
+ def _to_detection(box: Any) -> dict[str, Any]:
39
+ payload = _to_dict(box)
40
+ cls_id = int(payload.get("cls_id", 0))
41
+ x1 = float(payload.get("x1", 0.0))
42
+ y1 = float(payload.get("y1", 0.0))
43
+ x2 = float(payload.get("x2", 0.0))
44
+ y2 = float(payload.get("y2", 0.0))
45
+ width = max(0.0, x2 - x1)
46
+ height = max(0.0, y2 - y1)
47
+ return {
48
+ "x": x1 + width / 2.0,
49
+ "y": y1 + height / 2.0,
50
+ "width": width,
51
+ "height": height,
52
+ "confidence": float(payload.get("conf", 0.0)),
53
+ "class_id": cls_id,
54
+ "class": CLASS_NAMES[cls_id] if 0 <= cls_id < len(CLASS_NAMES) else str(cls_id),
55
+ }
56
+
57
+
58
+ def load_model(onnx_path: str | None = None, data_dir: str | None = None):
59
+ del onnx_path
60
+ repo_dir = Path(data_dir) if data_dir else Path(__file__).resolve().parent
61
+ miner = Miner(repo_dir)
62
+ return {
63
+ "miner": miner,
64
+ "model_type": MODEL_TYPE,
65
+ "class_names": CLASS_NAMES,
66
+ }
67
+
68
+
69
+ def run_model(model: Any, image: Any = None, onnx_path: str | None = None, data_dir: str | None = None):
70
+ del onnx_path
71
+ if image is None:
72
+ image = model
73
+ model = load_model(data_dir=data_dir)
74
+ miner = model["miner"]
75
+ results = miner.predict_batch([image], offset=0, n_keypoints=0)
76
+ if not results:
77
+ return [[]]
78
+ frame_boxes = _extract_boxes(results[0])
79
+ detections = [_to_detection(box) for box in frame_boxes]
80
+ return [detections]
81
+
82
+
83
+ def main() -> None:
84
+ if len(sys.argv) < 2:
85
+ print("Usage: main.py <image_path>", file=sys.stderr)
86
+ raise SystemExit(1)
87
+ image_path = sys.argv[1]
88
+ image = cv2.imread(image_path, cv2.IMREAD_COLOR)
89
+ if image is None:
90
+ print(f"Could not read image: {image_path}", file=sys.stderr)
91
+ raise SystemExit(1)
92
+ data_dir = os.path.dirname(os.path.abspath(__file__))
93
+ model = load_model(data_dir=data_dir)
94
+ output = run_model(model, image)
95
+ print(json.dumps(output, indent=2))
96
+
97
+
98
+ if __name__ == "__main__":
99
+ main()
pyproject.toml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "miner-element-adapter"
3
+ version = "0.1.0"
4
+ requires-python = ">=3.9"
5
+
6
+ dependencies = [
7
+ "huggingface_hub==0.19.4",
8
+ "ultralytics==8.2.40",
9
+ "torch<2.6",
10
+ "opencv-python-headless",
11
+ "pydantic>=2.0",
12
+ "numpy>=1.23",
13
+ ]