# car-detection/server/detect.py
"""ONNX inference for car detection in aerial images."""
import base64
import os
from pathlib import Path
import cv2
import numpy as np
import onnxruntime as ort
# Repo root: this file lives in <root>/server/, so parent.parent is <root>.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Model key -> ONNX model file path. Each entry can be overridden with an
# environment variable (CAR_MODEL_PATH / SPOT_MODEL_PATH); otherwise it
# defaults to the exported model inside the training/ directory.
MODEL_PATHS: dict[str, Path] = {
    "cars": Path(
        os.environ.get(
            "CAR_MODEL_PATH",
            str(
                _PROJECT_ROOT
                / "training"
                / "exported_models"
                / "inference_model.sim.onnx"
            ),
        )
    ),
    "spots": Path(
        os.environ.get(
            "SPOT_MODEL_PATH",
            str(_PROJECT_ROOT / "training" / "spot_exported" / "inference_model.onnx"),
        )
    ),
}
# Model key -> ordered class names; list index matches the model's logit column.
MODEL_CLASSES: dict[str, list[str]] = {
    "cars": ["car"],
    "spots": ["empty", "occupied"],
}
# Per-class colors in BGR (OpenCV channel order); annotate_image cycles
# through this palette by class_id.
_CLASS_COLORS = [
    (0, 255, 0),  # green — class 0
    (0, 165, 255),  # orange — class 1
    (255, 0, 0),  # blue — class 2
    (0, 255, 255),  # yellow — class 3
]
def load_model(model_path: Path) -> ort.InferenceSession:
    """Load the ONNX model and return an inference session.

    Args:
        model_path: Filesystem path to a ``.onnx`` model file.

    Returns:
        An onnxruntime ``InferenceSession`` for the model.

    Raises:
        FileNotFoundError: If the model file does not exist. Checked up
            front because onnxruntime's own error for a bad path is cryptic.
    """
    if not model_path.exists():
        raise FileNotFoundError(f"ONNX model not found: {model_path}")
    return ort.InferenceSession(str(model_path))
def get_resolution(session: ort.InferenceSession) -> int:
    """Return the square input resolution expected by the model.

    The first input's shape is NCHW (e.g. ``[1, 3, H, W]``); the model is
    assumed to take square inputs, so only H is returned.
    """
    input_shape = session.get_inputs()[0].shape
    return int(input_shape[2])
def preprocess(
    image: np.ndarray, resolution: int
) -> tuple[np.ndarray, tuple[int, int]]:
    """Resize and normalize an image for ONNX inference.
    Returns the preprocessed tensor (1, 3, H, W) and original (h, w).
    """
    original_size = (image.shape[0], image.shape[1])
    scaled = cv2.resize(image, (resolution, resolution))
    # Reverse the channel axis (BGR -> RGB), then HWC -> CHW and scale
    # uint8 values into float32 [0, 1].
    rgb = scaled[:, :, ::-1]
    chw = np.transpose(rgb, (2, 0, 1)).astype(np.float32) / 255.0
    # Standardize with ImageNet channel statistics.
    imagenet_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
    imagenet_std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)
    normalized = (chw - imagenet_mean) / imagenet_std
    # Prepend the batch dimension.
    return normalized[np.newaxis, ...], original_size
def postprocess(
outputs: dict[str, np.ndarray],
orig_hw: tuple[int, int],
threshold: float,
class_names: list[str],
) -> list[dict]:
"""Convert ONNX outputs to a list of detection dicts.
Each dict has keys: "bbox" (list[float] xyxy), "score" (float),
"class_id" (int), "class_name" (str).
RF-DETR uses per-class sigmoid (not softmax). Each logit column is an
independent binary classifier — there is no "no-object" column.
"""
boxes = outputs["dets"].reshape(-1, 4)
logits = outputs["labels"].reshape(boxes.shape[0], -1)
num_classes = logits.shape[1]
# Sigmoid per logit (independent binary classifiers)
probs = 1.0 / (1.0 + np.exp(-logits))
# Best class per detection
class_ids = probs.argmax(axis=1)
scores = probs[np.arange(len(class_ids)), class_ids]
# Normalized cxcywh -> pixel xyxy
orig_h, orig_w = orig_hw
cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
xyxy = np.stack(
[
(cx - w / 2) * orig_w,
(cy - h / 2) * orig_h,
(cx + w / 2) * orig_w,
(cy + h / 2) * orig_h,
],
axis=1,
)
mask = scores >= threshold
xyxy = xyxy[mask]
scores = scores[mask]
class_ids = class_ids[mask]
return [
{
"bbox": box.tolist(),
"score": float(s),
"class_id": int(cid),
"class_name": class_names[cid] if cid < len(class_names) else str(cid),
}
for box, s, cid in zip(xyxy, scores, class_ids)
]
def run_detection(
    session: ort.InferenceSession,
    image: np.ndarray,
    threshold: float = 0.5,
    class_names: list[str] | None = None,
) -> list[dict]:
    """Run full detection pipeline on a BGR image."""
    if class_names is None:
        class_names = ["car"]
    # Discover the model's I/O names instead of hard-coding them.
    feed_name = session.get_inputs()[0].name
    out_names = [out.name for out in session.get_outputs()]
    tensor, orig_hw = preprocess(image, get_resolution(session))
    results = session.run(out_names, {feed_name: tensor})
    named_outputs = {name: arr for name, arr in zip(out_names, results)}
    return postprocess(named_outputs, orig_hw, threshold, class_names)
def annotate_image(image: np.ndarray, detections: list[dict]) -> np.ndarray:
    """Draw bounding boxes and score labels on a copy of the image.

    Args:
        image: BGR image to annotate (the input is not modified).
        detections: Detection dicts as produced by ``postprocess``.

    Returns:
        A new BGR image with one colored box and label per detection.
    """
    annotated = image.copy()
    for det in detections:
        class_id = det.get("class_id", 0)
        # Cycle through the palette when there are more classes than colors.
        color = _CLASS_COLORS[class_id % len(_CLASS_COLORS)]
        class_name = det.get("class_name", "")
        x1, y1, x2, y2 = [int(v) for v in det["bbox"]]
        cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
        label = f'{class_name} {det["score"]:.2f}'
        (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        # Fix: a label above a box touching the top edge would be clipped
        # off-image; draw it just inside the box instead in that case.
        label_y = y1 if y1 - th - 6 >= 0 else y1 + th + 6
        cv2.rectangle(
            annotated, (x1, label_y - th - 6), (x1 + tw + 4, label_y), color, -1
        )
        cv2.putText(
            annotated,
            label,
            (x1 + 2, label_y - 4),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 0, 0),
            1,
            cv2.LINE_AA,
        )
    return annotated
def image_to_data_uri(image: np.ndarray, quality: int = 85) -> str:
    """Encode a BGR image as a JPEG base64 data URI.

    Args:
        image: BGR image array.
        quality: JPEG quality, 0-100 (higher = larger, sharper).

    Returns:
        A ``data:image/jpeg;base64,...`` string.

    Raises:
        ValueError: If OpenCV fails to encode the image.
    """
    ok, buf = cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality])
    # imencode reports failure via its boolean flag, not an exception;
    # the original silently discarded it and could emit a broken URI.
    if not ok:
        raise ValueError("Failed to encode image as JPEG")
    b64 = base64.b64encode(buf.tobytes()).decode("ascii")
    return f"data:image/jpeg;base64,{b64}"