MTerryJack's picture
subnet_bridge: copy winning miner repo into library
231934f verified
from __future__ import annotations
import importlib.util
import json
import os
import sys
from pathlib import Path
from typing import Any
import cv2
import numpy as np
def _load_local_miner_class():
miner_path = Path(__file__).resolve().parent / "miner.py"
spec = importlib.util.spec_from_file_location("manako_bridge_local_miner", str(miner_path))
if spec is None or spec.loader is None:
raise RuntimeError(f"Could not load miner module from {miner_path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
miner_class = getattr(module, "Miner", None)
if miner_class is None:
raise RuntimeError(f"miner.py does not export Miner in {miner_path}")
return miner_class
# Miner class taken from the miner.py next to this file; loaded once, at import time.
Miner = _load_local_miner_class()
# Labels indexed by the integer class id of each box (see _to_detection).
CLASS_NAMES = ['ball', 'person']
# Model-type tag that load_model returns alongside the miner instance.
MODEL_TYPE = 'ultralytics-yolo'
def _to_dict(value: Any) -> dict[str, Any]:
if isinstance(value, dict):
return value
if hasattr(value, "model_dump") and callable(value.model_dump):
dumped = value.model_dump()
if isinstance(dumped, dict):
return dumped
if hasattr(value, "__dict__"):
return dict(value.__dict__)
return {}
def _extract_boxes(frame_result: Any) -> list[Any]:
    """Return the ``"boxes"`` list of a frame result.

    The frame result is first converted with ``_to_dict``. An empty list is
    returned when the key is missing or its value is not a ``list``.
    """
    candidate = _to_dict(frame_result).get("boxes", [])
    return candidate if isinstance(candidate, list) else []
def _to_detection(box: Any) -> dict[str, Any]:
    """Convert a corner-format box into a centre/size detection dict.

    Reads ``cls_id``, ``x1``, ``y1``, ``x2``, ``y2`` and ``conf`` from the
    dict view of *box* (missing keys default to zero). Width and height are
    clamped to be non-negative, and the centre is derived from the clamped
    size. The ``"class"`` entry is the matching ``CLASS_NAMES`` label, or the
    class id as a string when the id is out of range.
    """
    fields = _to_dict(box)
    class_index = int(fields.get("cls_id", 0))
    left, top, right, bottom = (
        float(fields.get(corner, 0.0)) for corner in ("x1", "y1", "x2", "y2")
    )
    box_width = max(0.0, right - left)
    box_height = max(0.0, bottom - top)
    score = float(fields.get("conf", 0.0))

    if 0 <= class_index < len(CLASS_NAMES):
        label = CLASS_NAMES[class_index]
    else:
        label = str(class_index)

    return {
        "x": left + box_width / 2.0,
        "y": top + box_height / 2.0,
        "width": box_width,
        "height": box_height,
        "confidence": score,
        "class_id": class_index,
        "class": label,
    }
def _normalize_image_for_miner(image: Any) -> Any:
if image is None or hasattr(image, "shape"):
return image
if isinstance(image, (bytes, bytearray, memoryview)):
try:
buffer = np.frombuffer(bytes(image), dtype=np.uint8)
decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
if decoded is not None:
return decoded
except Exception:
return image
if hasattr(image, "convert") and callable(image.convert):
try:
rgb = image.convert("RGB")
array = np.array(rgb)
if getattr(array, "ndim", 0) == 3 and array.shape[-1] == 3:
return cv2.cvtColor(array, cv2.COLOR_RGB2BGR)
return array
except Exception:
return image
try:
array = np.asarray(image)
if getattr(array, "shape", None):
return array
except Exception:
return image
return image
def load_model(onnx_path: str | None = None, data_dir: str | None = None):
    """Instantiate the local ``Miner`` and wrap it in a model dict.

    Args:
        onnx_path: Accepted but unused.
        data_dir: Directory passed to ``Miner``; defaults to the directory
            containing this file when empty or ``None``.

    Returns:
        A dict with keys ``"miner"``, ``"model_type"`` and ``"class_names"``.
    """
    del onnx_path  # accepted by the signature but not used
    if data_dir:
        repo_dir = Path(data_dir)
    else:
        repo_dir = Path(__file__).resolve().parent
    return {
        "miner": Miner(repo_dir),
        "model_type": MODEL_TYPE,
        "class_names": CLASS_NAMES,
    }
def _candidate_keypoint_counts(miner: Any) -> list[int]:
counts: list[int] = [0]
for attr in ("n_keypoints", "num_keypoints", "keypoint_count", "num_joints"):
value = getattr(miner, attr, None)
if isinstance(value, int) and value > 0:
counts.append(value)
counts.append(32)
seen: set[int] = set()
ordered: list[int] = []
for count in counts:
if count in seen:
continue
seen.add(count)
ordered.append(count)
return ordered
def _predict_batch_with_fallbacks(miner: Any, image: Any) -> list[Any]:
    """Call ``miner.predict_batch`` on one image, retrying keypoint counts.

    The image is normalised once, then ``predict_batch`` is called with a
    single-image batch, ``offset=0`` and each value from
    ``_candidate_keypoint_counts`` until a call succeeds.

    Args:
        miner: Object exposing ``predict_batch(images, offset=..., n_keypoints=...)``.
        image: Input image in any form ``_normalize_image_for_miner`` accepts.

    Returns:
        Whatever the first successful ``predict_batch`` call returns.

    Raises:
        RuntimeError: if every candidate fails. The message lists each
            failure and the exception is chained to the last one so its
            traceback is preserved.
    """
    normalized_image = _normalize_image_for_miner(image)
    errors: list[str] = []
    last_error: Exception | None = None
    for n_keypoints in _candidate_keypoint_counts(miner):
        try:
            return miner.predict_batch([normalized_image], offset=0, n_keypoints=n_keypoints)
        except Exception as exc:
            # Best-effort: record the failure and move on to the next candidate.
            errors.append(f"n_keypoints={n_keypoints} -> {exc}")
            last_error = exc
    raise RuntimeError(
        "predict_batch failed for all keypoint candidates: " + " | ".join(errors)
    ) from last_error
def run_model(model: Any, image: Any = None, onnx_path: str | None = None, data_dir: str | None = None):
    """Run the miner on one image and return its detections.

    When *image* is ``None`` the first argument is treated as the image and
    a model is created with ``load_model(data_dir=data_dir)``.

    Args:
        model: Dict holding the miner under ``"miner"``, or the image itself
            when *image* is omitted.
        image: Image to run on.
        onnx_path: Accepted but unused.
        data_dir: Passed to ``load_model`` when a model has to be created.

    Returns:
        A one-element list holding the list of detection dicts for the
        image; ``[[]]`` when the miner returns an empty result.
    """
    del onnx_path  # accepted by the signature but not used
    if image is None:
        # Single-argument call style: run_model(image).
        model, image = load_model(data_dir=data_dir), model

    batch = _predict_batch_with_fallbacks(model["miner"], image)
    if not batch:
        return [[]]
    # Only the first frame is reported because a single image was submitted.
    return [[_to_detection(raw_box) for raw_box in _extract_boxes(batch[0])]]
def main() -> None:
    """Command-line entry point: detect objects in one image and print JSON.

    Expects the image path as the first command-line argument. Usage and
    read errors are written to stderr and exit with status 1.
    """
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: main.py <image_path>", file=sys.stderr)
        raise SystemExit(1)

    image_path = cli_args[0]
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        print(f"Could not read image: {image_path}", file=sys.stderr)
        raise SystemExit(1)

    # The miner is loaded from the directory that contains this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    detections = run_model(load_model(data_dir=script_dir), image)
    print(json.dumps(detections, indent=2))


if __name__ == "__main__":
    main()