File size: 5,906 Bytes
87d005c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from __future__ import annotations

import importlib.util
import json
import os
import sys
from pathlib import Path
from typing import Any

import cv2
import numpy as np


def _load_local_miner_class():
    """Import the sibling ``miner.py`` by file path and return its ``Miner`` class.

    The file next to this module is loaded directly through an import spec
    under the module name ``manako_bridge_local_miner`` and executed.

    Raises:
        RuntimeError: if no usable import spec/loader can be built for the
            file, or if the executed module has no ``Miner`` attribute.
    """
    here = Path(__file__).resolve().parent
    source_file = here / "miner.py"
    module_spec = importlib.util.spec_from_file_location(
        "manako_bridge_local_miner", str(source_file)
    )
    if module_spec is None or module_spec.loader is None:
        raise RuntimeError(f"Could not load miner module from {source_file}")

    loaded_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded_module)

    found_class = getattr(loaded_module, "Miner", None)
    if found_class is None:
        raise RuntimeError(f"miner.py does not export Miner in {source_file}")
    return found_class


# Resolve the Miner class at import time: importing this module executes
# miner.py and raises RuntimeError if it cannot be loaded.
Miner = _load_local_miner_class()


# Fallback class labels, used when the loaded miner does not expose a
# non-empty ``class_names`` list/tuple; entry i is the label for class id i.
CLASS_NAMES = ['fire', 'smoke', 'fire extinguisher']
# Reported under the "model_type" key of the dict returned by load_model().
MODEL_TYPE = 'onnxruntime'


def _to_dict(value: Any) -> dict[str, Any]:
    if isinstance(value, dict):
        return value
    if hasattr(value, "model_dump") and callable(value.model_dump):
        dumped = value.model_dump()
        if isinstance(dumped, dict):
            return dumped
    if hasattr(value, "__dict__"):
        return dict(value.__dict__)
    return {}


def _extract_boxes(frame_result: Any) -> list[Any]:
    """Return the ``boxes`` list of one frame result, or an empty list.

    The frame is first normalised with ``_to_dict``. A missing ``boxes`` key,
    or a value that is not a ``list`` (a tuple included), yields ``[]``.
    """
    candidate = _to_dict(frame_result).get("boxes")
    return candidate if isinstance(candidate, list) else []


def _resolve_runtime_class_names(miner: Any) -> list[str]:
    value = getattr(miner, "class_names", None)
    if isinstance(value, (list, tuple)):
        resolved = [str(item) for item in value]
        if resolved:
            return resolved
    return list(CLASS_NAMES)


def _to_detection(box: Any, class_names: list[str]) -> dict[str, Any]:
    """Convert one corner-format box into a centre-format detection dict.

    Reads ``cls_id``, ``x1``, ``y1``, ``x2``, ``y2`` and ``conf`` (each
    defaulting to 0 when absent) from the box after ``_to_dict``
    normalisation. Width and height are clamped at zero, and the centre is
    offset from ``(x1, y1)`` by half of them.

    The ``class`` entry is ``class_names[class_id]`` when the id is in range;
    otherwise it is the id rendered as a string.
    """
    fields = _to_dict(box)
    class_id = int(fields.get("cls_id", 0))
    left, top, right, bottom = (
        float(fields.get(key, 0.0)) for key in ("x1", "y1", "x2", "y2")
    )
    box_w = max(0.0, right - left)
    box_h = max(0.0, bottom - top)
    confidence = float(fields.get("conf", 0.0))

    if 0 <= class_id < len(class_names):
        label = class_names[class_id]
    else:
        label = str(class_id)

    return {
        "x": left + box_w / 2.0,
        "y": top + box_h / 2.0,
        "width": box_w,
        "height": box_h,
        "confidence": confidence,
        "class_id": class_id,
        "class": label,
    }


def _normalize_image_for_miner(image: Any) -> Any:
    if image is None or hasattr(image, "shape"):
        return image
    if isinstance(image, (bytes, bytearray, memoryview)):
        try:
            buffer = np.frombuffer(bytes(image), dtype=np.uint8)
            decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
            if decoded is not None:
                return decoded
        except Exception:
            return image
    if hasattr(image, "convert") and callable(image.convert):
        try:
            rgb = image.convert("RGB")
            array = np.array(rgb)
            if getattr(array, "ndim", 0) == 3 and array.shape[-1] == 3:
                return cv2.cvtColor(array, cv2.COLOR_RGB2BGR)
            return array
        except Exception:
            return image
    try:
        array = np.asarray(image)
        if getattr(array, "shape", None):
            return array
    except Exception:
        return image
    return image


def load_model(onnx_path: str | None = None, data_dir: str | None = None):
    """Instantiate the local ``Miner`` and bundle it with its metadata.

    Args:
        onnx_path: Accepted for interface compatibility; ignored.
        data_dir: Directory passed to ``Miner``. Falsy values fall back to the
            directory containing this file.

    Returns:
        A dict with keys ``"miner"`` (the ``Miner`` instance), ``"model_type"``
        (``MODEL_TYPE``) and ``"class_names"`` (labels resolved from the miner).
    """
    del onnx_path  # unused; kept in the signature for existing callers
    if data_dir:
        repo_dir = Path(data_dir)
    else:
        repo_dir = Path(__file__).resolve().parent
    miner = Miner(repo_dir)
    return {
        "miner": miner,
        "model_type": MODEL_TYPE,
        "class_names": _resolve_runtime_class_names(miner),
    }


def _candidate_keypoint_counts(miner: Any) -> list[int]:
    counts: list[int] = [0]
    for attr in ("n_keypoints", "num_keypoints", "keypoint_count", "num_joints"):
        value = getattr(miner, attr, None)
        if isinstance(value, int) and value > 0:
            counts.append(value)
    counts.append(32)

    seen: set[int] = set()
    ordered: list[int] = []
    for count in counts:
        if count in seen:
            continue
        seen.add(count)
        ordered.append(count)
    return ordered


def _predict_batch_with_fallbacks(miner: Any, image: Any) -> list[Any]:
    """Run ``miner.predict_batch`` on one image, retrying keypoint counts.

    The image is normalised once with ``_normalize_image_for_miner``. Then
    ``predict_batch([image], offset=0, n_keypoints=k)`` is attempted for each
    ``k`` from ``_candidate_keypoint_counts``, and the first call that does
    not raise is returned.

    Raises:
        RuntimeError: if every candidate fails. The message summarises each
            attempt, and the last underlying exception is chained as
            ``__cause__`` so its traceback is not lost.
    """
    normalized_image = _normalize_image_for_miner(image)
    errors: list[str] = []
    last_error: Exception | None = None
    for n_keypoints in _candidate_keypoint_counts(miner):
        try:
            return miner.predict_batch([normalized_image], offset=0, n_keypoints=n_keypoints)
        except Exception as exc:
            # Include the type: str(exc) is empty for e.g. a bare
            # AssertionError, and it hides that a KeyError occurred.
            errors.append(f"n_keypoints={n_keypoints} -> {type(exc).__name__}: {exc}")
            last_error = exc
    # Raised outside the except block, so chain explicitly; otherwise the
    # original traceback would be discarded.
    raise RuntimeError(
        "predict_batch failed for all keypoint candidates: " + " | ".join(errors)
    ) from last_error


def run_model(model: Any, image: Any = None, onnx_path: str | None = None, data_dir: str | None = None):
    """Detect objects in one image and return ``[detections]``.

    Two call forms are accepted:

    * ``run_model(model, image)``, where ``model`` is the dict from ``load_model``;
    * ``run_model(image)``: when ``image`` is ``None``, the first argument is
      treated as the image and a fresh model is loaded from ``data_dir``.

    ``onnx_path`` is ignored. The result is a one-element list holding the
    detection dicts built from the first frame the miner returns. That inner
    list is empty when the miner returns no frames.
    """
    del onnx_path
    if image is None:
        # Single-argument form: the positional "model" is really the image.
        model, image = load_model(data_dir=data_dir), model

    miner = model["miner"]
    labels = model.get("class_names")
    if not isinstance(labels, list):
        labels = list(CLASS_NAMES)

    frames = _predict_batch_with_fallbacks(miner, image)
    first_frame_boxes = _extract_boxes(frames[0]) if frames else []
    return [[_to_detection(box, labels) for box in first_frame_boxes]]


def main() -> None:
    """Command-line entry point: ``main.py <image_path>``.

    Reads the image with ``cv2.imread`` in colour mode and loads the model
    from this file's directory. It then runs detection and prints the result
    as JSON indented by 2 spaces. Exits with status 1, after writing a message
    to stderr, when no path is given or the image cannot be read.
    """
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: main.py <image_path>", file=sys.stderr)
        raise SystemExit(1)

    image_path = cli_args[0]
    frame = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if frame is None:
        print(f"Could not read image: {image_path}", file=sys.stderr)
        raise SystemExit(1)

    here = os.path.dirname(os.path.abspath(__file__))
    result = run_model(load_model(data_dir=here), frame)
    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()