File size: 5,581 Bytes
f3f6f5d
 
 
84f8376
f3f6f5d
 
 
 
 
 
 
 
4bce717
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3f6f5d
4bce717
 
f3f6f5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4bce717
f3f6f5d
 
 
4bce717
 
a2e9c4d
 
 
f3f6f5d
 
4bce717
a2e9c4d
4bce717
a2e9c4d
 
4bce717
a2e9c4d
 
 
f3f6f5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4bce717
f3f6f5d
 
4bce717
 
 
 
 
 
 
f3f6f5d
 
 
 
 
 
 
4bce717
f3f6f5d
 
4bce717
 
f3f6f5d
 
 
 
 
 
 
4bce717
f3f6f5d
 
 
 
 
 
4bce717
 
 
 
f3f6f5d
 
4bce717
f3f6f5d
 
 
4bce717
 
 
 
 
 
 
 
f3f6f5d
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""ONNX inference for car detection in aerial images."""

import base64
import os
from pathlib import Path

import cv2
import numpy as np
import onnxruntime as ort

# Repository root, assuming this file lives one directory below it
# (e.g. <root>/<pkg>/this_file.py) — TODO confirm against project layout.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Default on-disk locations of the exported ONNX models, overridable via
# the CAR_MODEL_PATH / SPOT_MODEL_PATH environment variables.
MODEL_PATHS: dict[str, Path] = {
    "cars": Path(
        os.environ.get(
            "CAR_MODEL_PATH",
            str(
                _PROJECT_ROOT
                / "training"
                / "exported_models"
                / "inference_model.sim.onnx"
            ),
        )
    ),
    "spots": Path(
        os.environ.get(
            "SPOT_MODEL_PATH",
            str(_PROJECT_ROOT / "training" / "spot_exported" / "inference_model.onnx"),
        )
    ),
}

# Class names per model; list index is presumably the model's logit column
# (class id) — verify against the training label order.
MODEL_CLASSES: dict[str, list[str]] = {
    "cars": ["car"],
    "spots": ["empty", "occupied"],
}

# Per-class colors in BGR (OpenCV channel order); indexed modulo length
# by annotate_image, so any class id maps to some color.
_CLASS_COLORS = [
    (0, 255, 0),  # green — class 0
    (0, 165, 255),  # orange — class 1
    (255, 0, 0),  # blue — class 2
    (0, 255, 255),  # yellow — class 3
]


def load_model(model_path: Path) -> ort.InferenceSession:
    """Create an ONNX Runtime inference session for the model at *model_path*."""
    session = ort.InferenceSession(str(model_path))
    return session


def get_resolution(session: ort.InferenceSession) -> int:
    """Return the square input resolution the model expects.

    Reads the first input's shape — assumed NCHW, e.g. [1, 3, H, W] —
    and returns H as the resolution.
    """
    first_input = session.get_inputs()[0]
    return int(first_input.shape[2])


def preprocess(
    image: np.ndarray, resolution: int
) -> tuple[np.ndarray, tuple[int, int]]:
    """Resize and normalize a BGR image for ONNX inference.

    Returns the preprocessed tensor (1, 3, H, W) and the original (h, w).
    """
    original_size = (image.shape[0], image.shape[1])
    resized = cv2.resize(image, (resolution, resolution))
    # BGR -> RGB, then HWC -> CHW, then scale uint8 into float32 [0, 1]
    rgb = resized[:, :, ::-1]
    chw = rgb.transpose(2, 0, 1).astype(np.float32) / 255.0
    # Standard ImageNet channel statistics
    imagenet_mean = np.float32([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    imagenet_std = np.float32([0.229, 0.224, 0.225]).reshape(3, 1, 1)
    normalized = (chw - imagenet_mean) / imagenet_std
    # Prepend the batch axis
    return normalized[np.newaxis, ...], original_size


def postprocess(
    outputs: dict[str, np.ndarray],
    orig_hw: tuple[int, int],
    threshold: float,
    class_names: list[str],
) -> list[dict]:
    """Convert ONNX outputs to a list of detection dicts.

    Each dict has keys: "bbox" (list[float] xyxy), "score" (float),
    "class_id" (int), "class_name" (str).

    RF-DETR uses per-class sigmoid (not softmax). Each logit column is an
    independent binary classifier — there is no "no-object" column.
    """
    boxes = outputs["dets"].reshape(-1, 4)
    logits = outputs["labels"].reshape(boxes.shape[0], -1)
    num_classes = logits.shape[1]

    # Sigmoid per logit (independent binary classifiers)
    probs = 1.0 / (1.0 + np.exp(-logits))

    # Best class per detection
    class_ids = probs.argmax(axis=1)
    scores = probs[np.arange(len(class_ids)), class_ids]

    # Normalized cxcywh -> pixel xyxy
    orig_h, orig_w = orig_hw
    cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    xyxy = np.stack(
        [
            (cx - w / 2) * orig_w,
            (cy - h / 2) * orig_h,
            (cx + w / 2) * orig_w,
            (cy + h / 2) * orig_h,
        ],
        axis=1,
    )

    mask = scores >= threshold
    xyxy = xyxy[mask]
    scores = scores[mask]
    class_ids = class_ids[mask]

    return [
        {
            "bbox": box.tolist(),
            "score": float(s),
            "class_id": int(cid),
            "class_name": class_names[cid] if cid < len(class_names) else str(cid),
        }
        for box, s, cid in zip(xyxy, scores, class_ids)
    ]


def run_detection(
    session: ort.InferenceSession,
    image: np.ndarray,
    threshold: float = 0.5,
    class_names: list[str] | None = None,
) -> list[dict]:
    """Run the full detection pipeline on a BGR image.

    Preprocesses *image* to the model's expected resolution, runs the ONNX
    session, and postprocesses the outputs into detection dicts. When
    *class_names* is None, a single "car" class is assumed.
    """
    names = class_names if class_names is not None else ["car"]
    input_name = session.get_inputs()[0].name
    output_names = [out.name for out in session.get_outputs()]
    side = get_resolution(session)

    tensor, original_hw = preprocess(image, side)
    raw = session.run(output_names, {input_name: tensor})
    named_outputs = dict(zip(output_names, raw))
    return postprocess(named_outputs, original_hw, threshold, names)


def annotate_image(image: np.ndarray, detections: list[dict]) -> np.ndarray:
    """Return a copy of *image* with detection boxes and score labels drawn."""
    out = image.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX
    for det in detections:
        cid = det.get("class_id", 0)
        # Modulo keeps any class id mapped to a palette color
        box_color = _CLASS_COLORS[cid % len(_CLASS_COLORS)]
        name = det.get("class_name", "")

        x1, y1, x2, y2 = (int(coord) for coord in det["bbox"])
        cv2.rectangle(out, (x1, y1), (x2, y2), box_color, 2)

        caption = f'{name} {det["score"]:.2f}'
        (text_w, text_h), _baseline = cv2.getTextSize(caption, font, 0.5, 1)
        # Filled background strip so the black text stays readable
        cv2.rectangle(out, (x1, y1 - text_h - 6), (x1 + text_w + 4, y1), box_color, -1)
        cv2.putText(
            out,
            caption,
            (x1 + 2, y1 - 4),
            font,
            0.5,
            (0, 0, 0),
            1,
            cv2.LINE_AA,
        )
    return out


def image_to_data_uri(image: np.ndarray, quality: int = 85) -> str:
    """Encode a BGR image as a JPEG base64 data URI.

    Args:
        image: BGR image array (OpenCV channel order).
        quality: JPEG quality, 0-100; higher means better quality and
            larger output.

    Returns:
        A "data:image/jpeg;base64,..." URI string.

    Raises:
        ValueError: if OpenCV fails to encode the image.
    """
    # cv2.imencode returns (success, buffer); the original discarded the
    # flag, which turned an encode failure into an opaque AttributeError.
    ok, buf = cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality])
    if not ok:
        raise ValueError("Failed to encode image as JPEG")
    b64 = base64.b64encode(buf.tobytes()).decode("ascii")
    return f"data:image/jpeg;base64,{b64}"