| """ |
infer.py — Two-stage ONNX inference for the Field Station Space.
| |
| Mirrors the on-device forager_ml pipeline (domain router -> expert routing -> |
| abstention) but runs on CPU via onnxruntime instead of the Hailo NPU. |
| |
| Stage 1: domain router classifies the frame into berry / mushroom / plant / other. |
| Stage 2: route to the matching expert(s); for multi-expert domains run both and |
| keep the higher-confidence call. Abstain (UNKNOWN) when the router is |
| unsure, the domain is "other", or the winning expert is below threshold. |
| |
| Preprocessing note: these ONNX files are the bare timm tf_efficientnet_lite2 |
| models (no normalization baked in), so inputs are ImageNet-normalized |
[1, 3, 224, 224] — NOT the [0, 255] NHWC input the HEF expects.
| """ |
|
|
| import json |
| import os |
|
|
| import numpy as np |
| import onnxruntime as ort |
| from PIL import Image |
|
|
# Directory holding the <name>_logits.onnx / <name>_classes.json artifacts:
# a "models" folder next to the directory that contains this file.
MODELS_DIR = os.path.join(
    os.path.dirname(__file__),
    "..",
    "models",
)

# Stage-1 model that assigns each frame to a domain.
ROUTER = "domain_router_v2"

# Stage-2 species experts; each is loaded alongside the router by Pipeline.
EXPERTS = [
    "berry_expert",
    "highvalue_expert",
    "medicinals_expert",
]
|
|
| |
| |
| |
| |
| |
| |
| DOMAIN_EXPERTS: dict[str, str] = { |
| "berry": "berry_expert", |
| "mushroom": "highvalue_expert", |
| "plant": "medicinals_expert", |
| } |
|
|
| |
| ROUTER_CONFIDENCE_THRESHOLD = 0.74 |
| EXPERT_CONFIDENCE_THRESHOLD = 0.75 |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
# Optional energy-based out-of-domain suppression. When enabled, Pipeline reads
# per-model cut-offs from energy_thresholds.json under this percentile key.
ENABLE_ENERGY_SUPPRESSION = False
ENERGY_SUPPRESS_PERCENTILE = "p90"

# Preprocessing geometry and normalization: resize the short side to
# _RESIZE_SHORT, center-crop _CROP x _CROP, then apply ImageNet mean/std.
_CROP = 224
_RESIZE_SHORT = int(_CROP * 1.14)
_IMAGENET_MEAN = np.asarray([0.485, 0.456, 0.406], dtype=np.float32)
_IMAGENET_STD = np.asarray([0.229, 0.224, 0.225], dtype=np.float32)
|
|
|
|
def preprocess(img: Image.Image) -> np.ndarray:
    """Turn a PIL image into the ONNX models' input tensor.

    Steps: convert to RGB; resize (bilinear, aspect ratio kept) so the shorter
    side is ``_RESIZE_SHORT`` pixels; center-crop a ``_CROP`` x ``_CROP``
    square; scale to [0, 1]; normalize per channel with the ImageNet mean/std;
    reorder HWC -> NCHW and add a batch axis.

    Returns:
        Contiguous float32 array of shape [1, 3, _CROP, _CROP] (224 x 224).
    """
    rgb = img.convert("RGB")
    width, height = rgb.size
    if width > height:
        new_size = (round(_RESIZE_SHORT * width / height), _RESIZE_SHORT)
    else:
        new_size = (_RESIZE_SHORT, round(_RESIZE_SHORT * height / width))
    resized = rgb.resize(new_size, Image.BILINEAR)

    # Offsets that center the crop window inside the resized image.
    x0 = (new_size[0] - _CROP) // 2
    y0 = (new_size[1] - _CROP) // 2
    patch = resized.crop((x0, y0, x0 + _CROP, y0 + _CROP))

    hwc = np.asarray(patch, dtype=np.float32) / 255.0
    hwc = (hwc - _IMAGENET_MEAN) / _IMAGENET_STD
    chw = hwc.transpose(2, 0, 1)
    return np.ascontiguousarray(chw[np.newaxis, ...], dtype=np.float32)
|
|
|
|
def _softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax taken over every element of ``logits``.

    The maximum is subtracted before exponentiating so large logits cannot
    overflow; mathematically the result is unchanged.
    """
    weights = np.exp(logits - np.max(logits))
    return weights / np.sum(weights)
|
|
|
|
def _energy(logits: np.ndarray) -> float:
    """Energy score E(x) = -logsumexp(logits); higher = more out-of-domain.

    Uses the max-shift trick so exp() cannot overflow on large logits.
    """
    peak = logits.max()
    log_sum_exp = peak + np.log(np.sum(np.exp(logits - peak)))
    return -float(log_sum_exp)
|
|
|
|
class Pipeline:
    """Loads all ONNX sessions once and runs the two-stage identification.

    Stage 1 runs the domain router; stage 2 runs the expert that
    ``DOMAIN_EXPERTS`` maps the routed domain to. Every model and label file is
    read in ``__init__``, so a missing artifact fails at construction time
    rather than on the first ``identify`` call.
    """

    def __init__(self, models_dir: str = MODELS_DIR):
        """Open one CPU inference session per model and load its class labels.

        Args:
            models_dir: Directory containing ``<name>_logits.onnx`` and
                ``<name>_classes.json`` for the router and every expert, plus
                an optional ``energy_thresholds.json``.
        """
        self._sessions: dict[str, ort.InferenceSession] = {}
        self._input_names: dict[str, str] = {}
        self._classes: dict[str, list[str]] = {}
        for name in [ROUTER, *EXPERTS]:
            session = ort.InferenceSession(
                os.path.join(models_dir, f"{name}_logits.onnx"),
                providers=["CPUExecutionProvider"],
            )
            self._sessions[name] = session
            # Feed the graph by its declared input name rather than assuming
            # the exporter called it "input".
            self._input_names[name] = session.get_inputs()[0].name
            # Label lists may contain non-ASCII names; don't depend on the locale.
            with open(
                os.path.join(models_dir, f"{name}_classes.json"), encoding="utf-8"
            ) as f:
                self._classes[name] = json.load(f)

        # Per-model energy cut-offs. The table maps model name ->
        # {percentile key -> cut-off}. It stays empty (no suppression) unless the
        # feature flag is on and the file exists.
        self._energy_thr: dict[str, float] = {}
        thr_path = os.path.join(models_dir, "energy_thresholds.json")
        if ENABLE_ENERGY_SUPPRESSION and os.path.exists(thr_path):
            with open(thr_path, encoding="utf-8") as f:
                table = json.load(f)
            self._energy_thr = {
                n: v[ENERGY_SUPPRESS_PERCENTILE] for n, v in table.items()
            }

    def _run(self, name: str, x: np.ndarray) -> tuple[str, float, float]:
        """Run model ``name`` on a preprocessed batch of one.

        Returns:
            (top_class, top_confidence, energy). ``top_confidence`` is the
            softmax probability of the arg-max class; ``energy`` is
            ``_energy(logits)``.
        """
        session = self._sessions[name]
        logits = session.run(None, {self._input_names[name]: x})[0][0]
        probs = _softmax(logits)
        idx = int(probs.argmax())
        return self._classes[name][idx], float(probs[idx]), _energy(logits)

    def identify(self, img: Image.Image) -> dict:
        """Classify ``img`` with the router, then with the matching expert.

        Returns a dict that always contains:
            domain, domain_confidence, abstain
        When ``abstain`` is True it also contains ``reason``, one of:
            "uncertain_domain"  - router confidence below ROUTER_CONFIDENCE_THRESHOLD
            "off_domain"        - routed domain has no expert, or the expert's
                                  energy exceeds its suppression cut-off
            "uncertain_species" - expert confidence below EXPERT_CONFIDENCE_THRESHOLD
        When ``abstain`` is False it also contains ``calls``, a one-element
        list of {expert, species, confidence, energy}.
        """
        x = preprocess(img)

        # Stage 1: domain routing.
        domain, dconf, _ = self._run(ROUTER, x)
        out = {"domain": domain, "domain_confidence": dconf}

        if dconf < ROUTER_CONFIDENCE_THRESHOLD:
            return {**out, "abstain": True, "reason": "uncertain_domain"}
        if domain not in DOMAIN_EXPERTS:
            return {**out, "abstain": True, "reason": "off_domain"}

        # Stage 2: the routed domain's expert.
        ename = DOMAIN_EXPERTS[domain]
        species, conf, energy = self._run(ename, x)

        # Higher energy = more out-of-domain; suppress above the cut-off.
        thr = self._energy_thr.get(ename)
        if thr is not None and energy > thr:
            return {**out, "abstain": True, "reason": "off_domain"}

        # Abstain when the winning expert itself is unsure (module contract);
        # previously EXPERT_CONFIDENCE_THRESHOLD was defined but never applied.
        if conf < EXPERT_CONFIDENCE_THRESHOLD:
            return {**out, "abstain": True, "reason": "uncertain_species"}

        call = {
            "expert": ename,
            "species": species,
            "confidence": conf,
            "energy": round(energy, 4),
        }
        return {**out, "abstain": False, "calls": [call]}
|
|