File size: 5,906 Bytes
from __future__ import annotations
import importlib.util
import json
import os
import sys
from pathlib import Path
from typing import Any
import cv2
import numpy as np
def _load_local_miner_class():
    """Load ``miner.py`` from this file's directory and return its ``Miner``.

    The file is imported by path under the private module name
    ``manako_bridge_local_miner`` rather than through ``sys.path``.

    Returns:
        The ``Miner`` attribute exported by the loaded module.

    Raises:
        RuntimeError: If no import spec/loader can be created for ``miner.py``
            or the executed module has no ``Miner`` attribute.

    Any exception raised while executing ``miner.py`` propagates unchanged.
    """
    module_name = "manako_bridge_local_miner"
    miner_path = Path(__file__).resolve().parent / "miner.py"
    spec = importlib.util.spec_from_file_location(module_name, str(miner_path))
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Could not load miner module from {miner_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing, as the importlib "import a source file
    # directly" recipe does: machinery that resolves a class's home module via
    # sys.modules[cls.__module__] (e.g. pickle, type-hint resolution) cannot
    # find classes defined in miner.py otherwise.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module registered.
        sys.modules.pop(module_name, None)
        raise
    miner_class = getattr(module, "Miner", None)
    if miner_class is None:
        raise RuntimeError(f"miner.py does not export Miner in {miner_path}")
    return miner_class
# Resolved once at import time; load_model() instantiates this class.
Miner = _load_local_miner_class()

# Fallback label list, used when a miner instance exposes no usable
# ``class_names`` of its own.
CLASS_NAMES = ["fire", "smoke", "fire extinguisher"]

# Value reported under "model_type" in the dict returned by load_model().
MODEL_TYPE = "onnxruntime"
def _to_dict(value: Any) -> dict[str, Any]:
    """Coerce *value* into a plain dictionary on a best-effort basis.

    Resolution order:
      1. a ``dict`` is returned as-is (same object, no copy);
      2. an object with a callable ``model_dump`` is dumped, and the result is
         returned if it is a ``dict``;
      3. an object with ``__dict__`` yields a shallow copy of that mapping;
      4. anything else yields an empty dict.
    """
    if isinstance(value, dict):
        return value

    dump = getattr(value, "model_dump", None)
    if callable(dump):
        dumped = dump()
        if isinstance(dumped, dict):
            return dumped

    try:
        attributes = value.__dict__
    except AttributeError:
        return {}
    return dict(attributes)
def _extract_boxes(frame_result: Any) -> list[Any]:
    """Return the ``"boxes"`` list of one frame result, or ``[]`` if unusable.

    The frame is first converted with ``_to_dict``. A missing ``"boxes"`` key
    and a ``"boxes"`` value that is not a ``list`` both produce an empty list.
    """
    candidate = _to_dict(frame_result).get("boxes", [])
    return candidate if isinstance(candidate, list) else []
def _resolve_runtime_class_names(miner: Any) -> list[str]:
    """Pick the label list to use for *miner*.

    A non-empty ``list``/``tuple`` found on ``miner.class_names`` wins, with
    every entry converted through ``str``. In every other case a fresh copy of
    the module-level ``CLASS_NAMES`` is returned.
    """
    declared = getattr(miner, "class_names", None)
    if isinstance(declared, (list, tuple)):
        names = [str(entry) for entry in declared]
    else:
        names = []
    return names or list(CLASS_NAMES)
def _to_detection(box: Any, class_names: list[str]) -> dict[str, Any]:
    """Convert a corner-format box into a centre-format detection dict.

    The box is read through ``_to_dict`` using the keys ``x1``, ``y1``, ``x2``,
    ``y2``, ``conf`` and ``cls_id``; any missing key counts as zero.

    Args:
        box: Box object or mapping carrying the keys above.
        class_names: Labels indexed by class id.

    Returns:
        A dict with ``x``/``y`` (box centre), ``width``/``height`` (clamped to
        be non-negative), ``confidence``, ``class_id`` and ``class``. The
        label falls back to ``str(class_id)`` when the id is outside
        ``class_names``.
    """
    fields = _to_dict(box)

    def _number(key: str) -> float:
        return float(fields.get(key, 0.0))

    class_index = int(fields.get("cls_id", 0))
    left = _number("x1")
    top = _number("y1")
    right = _number("x2")
    bottom = _number("y2")

    # Inverted corners collapse to a zero-sized box instead of going negative.
    box_width = max(0.0, right - left)
    box_height = max(0.0, bottom - top)
    confidence = _number("conf")

    if 0 <= class_index < len(class_names):
        label = class_names[class_index]
    else:
        label = str(class_index)

    return {
        "x": left + box_width / 2.0,
        "y": top + box_height / 2.0,
        "width": box_width,
        "height": box_height,
        "confidence": confidence,
        "class_id": class_index,
        "class": label,
    }
def _normalize_image_for_miner(image: Any) -> Any:
    """Best-effort conversion of *image* into an array for the miner.

    Handling by input kind:
      * ``None`` or anything exposing ``.shape`` is returned untouched.
      * ``bytes`` / ``bytearray`` / ``memoryview`` are decoded with
        ``cv2.imdecode(..., cv2.IMREAD_COLOR)``. If decoding yields nothing or
        raises, the original object is returned unchanged.
      * Objects with a callable ``convert`` (presumably PIL-style images) are
        converted to ``"RGB"`` and turned into an array; 3-channel arrays are
        reordered from RGB to BGR.
      * Anything else goes through ``np.asarray``; the array is used only when
        its shape is non-empty.

    Conversion errors are swallowed and the input is returned as given.
    """
    if image is None or hasattr(image, "shape"):
        return image

    if isinstance(image, (bytes, bytearray, memoryview)):
        try:
            buffer = np.frombuffer(bytes(image), dtype=np.uint8)
            decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
        except Exception:
            return image
        # An undecodable buffer must not fall through to np.asarray below:
        # that turned bytearray/memoryview input into a flat 1-D uint8 array
        # while plain bytes came back unchanged.
        return image if decoded is None else decoded

    convert = getattr(image, "convert", None)
    if callable(convert):
        try:
            array = np.array(convert("RGB"))
            if getattr(array, "ndim", 0) == 3 and array.shape[-1] == 3:
                return cv2.cvtColor(array, cv2.COLOR_RGB2BGR)
            return array
        except Exception:
            return image

    try:
        array = np.asarray(image)
    except Exception:
        return image
    # A 0-d result (empty shape tuple) means the input was a scalar-like
    # object, so hand the original back instead.
    return array if getattr(array, "shape", None) else image
def load_model(onnx_path: str | None = None, data_dir: str | None = None):
    """Instantiate the local miner and bundle it with its metadata.

    Args:
        onnx_path: Accepted but not used by this loader.
        data_dir: Directory handed to ``Miner``. When falsy, the directory
            containing this file is used instead.

    Returns:
        A dict with the keys ``"miner"`` (the instance), ``"model_type"``
        (``MODEL_TYPE``) and ``"class_names"`` (labels resolved for the miner).
    """
    del onnx_path
    if data_dir:
        repo_dir = Path(data_dir)
    else:
        repo_dir = Path(__file__).resolve().parent
    miner = Miner(repo_dir)
    return {
        "miner": miner,
        "model_type": MODEL_TYPE,
        "class_names": _resolve_runtime_class_names(miner),
    }
def _candidate_keypoint_counts(miner: Any) -> list[int]:
    """List the keypoint counts to try, in priority order and without repeats.

    The order is: ``0`` first, then every positive ``int`` found on the miner
    attributes ``n_keypoints``, ``num_keypoints``, ``keypoint_count`` and
    ``num_joints`` (in that order), and finally ``32``.
    """
    probed = (
        getattr(miner, name, None)
        for name in ("n_keypoints", "num_keypoints", "keypoint_count", "num_joints")
    )
    discovered = [value for value in probed if isinstance(value, int) and value > 0]
    # dict.fromkeys drops duplicates while keeping first-seen order.
    return list(dict.fromkeys([0, *discovered, 32]))
def _predict_batch_with_fallbacks(miner: Any, image: Any) -> list[Any]:
    """Run ``miner.predict_batch`` on one image, retrying across keypoint counts.

    The image is normalised once, then ``predict_batch`` is called with a
    single-item batch and ``offset=0`` for each candidate ``n_keypoints`` until
    one call succeeds.

    Args:
        miner: Object exposing ``predict_batch(images, offset=..., n_keypoints=...)``.
        image: Input image in any form ``_normalize_image_for_miner`` accepts.

    Returns:
        Whatever the first successful ``predict_batch`` call returns.

    Raises:
        RuntimeError: If every candidate fails. The message lists each
            failure, and the last underlying exception is attached as the
            cause so its traceback is not lost.
    """
    normalized_image = _normalize_image_for_miner(image)
    errors: list[str] = []
    last_error: Exception | None = None
    for n_keypoints in _candidate_keypoint_counts(miner):
        try:
            return miner.predict_batch([normalized_image], offset=0, n_keypoints=n_keypoints)
        except Exception as exc:
            # Deliberately broad: any failure just moves on to the next count.
            errors.append(f"n_keypoints={n_keypoints} -> {exc}")
            last_error = exc
    raise RuntimeError(
        "predict_batch failed for all keypoint candidates: " + " | ".join(errors)
    ) from last_error
def run_model(model: Any, image: Any = None, onnx_path: str | None = None, data_dir: str | None = None):
    """Run detection on a single image and return its detections.

    Args:
        model: Dict produced by ``load_model``. When ``image`` is ``None``,
            this argument is treated as the image and a model is loaded here.
        image: Image to analyse, or ``None`` for the single-argument form.
        onnx_path: Accepted but not used.
        data_dir: Forwarded to ``load_model`` when a model is loaded here.

    Returns:
        A list holding one inner list of detection dicts for the first frame
        result; ``[[]]`` when prediction returns nothing.
    """
    del onnx_path
    if image is None:
        # Single-argument call: the positional value is the image itself.
        model, image = load_model(data_dir=data_dir), model

    miner = model["miner"]
    labels = model.get("class_names")
    if not isinstance(labels, list):
        labels = list(CLASS_NAMES)

    batch = _predict_batch_with_fallbacks(miner, image)
    if not batch:
        return [[]]
    return [[_to_detection(box, labels) for box in _extract_boxes(batch[0])]]
def main() -> None:
    """Command-line entry point: detect on one image file and print JSON.

    Expects the image path as the first command-line argument. Exits with
    status 1, after writing a message to stderr, when the argument is missing
    or the image cannot be read. The model is loaded with this file's
    directory as ``data_dir``.
    """
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: main.py <image_path>", file=sys.stderr)
        raise SystemExit(1)

    image_path = cli_args[0]
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        print(f"Could not read image: {image_path}", file=sys.stderr)
        raise SystemExit(1)

    here = os.path.dirname(os.path.abspath(__file__))
    detections = run_model(load_model(data_dir=here), image)
    print(json.dumps(detections, indent=2))
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
|