""" infer.py — Two-stage ONNX inference for the Field Station Space. Mirrors the on-device forager_ml pipeline (domain router -> expert routing -> abstention) but runs on CPU via onnxruntime instead of the Hailo NPU. Stage 1: domain router classifies the frame into berry / mushroom / plant / other. Stage 2: route to the matching expert(s); for multi-expert domains run both and keep the higher-confidence call. Abstain (UNKNOWN) when the router is unsure, the domain is "other", or the winning expert is below threshold. Preprocessing note: these ONNX files are the bare timm tf_efficientnet_lite2 models (no normalization baked in), so inputs are ImageNet-normalized [1, 3, 224, 224] — NOT the [0,255] NHWC the HEF expects. """ import json import os import numpy as np import onnxruntime as ort from PIL import Image MODELS_DIR = os.path.join(os.path.dirname(__file__), "..", "models") ROUTER = "domain_router_v2" # The psychedelics/mycologist expert is intentionally NOT shipped in this public # Space: it is never routed to (mushroom -> highvalue only) and a psilocybin # identifier invites policy scrutiny for zero functional gain. It still lives in # forager_ml and can ship as its own model repo. EXPERTS = ["berry_expert", "highvalue_expert", "medicinals_expert"] # Router domain -> the ONE expert that owns it. Single-expert routing (no # cross-expert voting): an off-domain expert never gets to misclassify an input # it doesn't own — e.g. highvalue never sees a plant, so it can't call a hemlock # "ramps". The deadly plants live in medicinals (0% toxic-as-edible FAR). # "other" is intentionally absent => abstain. The mycologist/psychedelics expert # is held out of the live path (weak on real photos; benched). DOMAIN_EXPERTS: dict[str, str] = { "berry": "berry_expert", "mushroom": "highvalue_expert", "plant": "medicinals_expert", } # Gates (match the on-device convergence thresholds). 
ROUTER_CONFIDENCE_THRESHOLD = 0.74  # below this the router call is "uncertain_domain"
EXPERT_CONFIDENCE_THRESHOLD = 0.75  # below this the expert call is "low_confidence"

# Energy-OOD gate: when enabled, an expert's call is dropped if its input energy
# exceeds the in-domain threshold (i.e. the frame is out-of-domain for that
# expert). It was introduced to stop an off-domain expert from out-voting the
# correct one — e.g. highvalue calling a hemlock "ramps". Thresholds are fp32
# in-domain percentiles in models/energy_thresholds.json (correct -logsumexp
# energy).
# With single-expert routing there are no competing votes to suppress, and
# val-calibrated thresholds over-fire on real photos. Off by default; the
# router's "other" class + the confidence gates carry OOD.
ENABLE_ENERGY_SUPPRESSION = False
ENERGY_SUPPRESS_PERCENTILE = "p90"

_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
_RESIZE_SHORT = int(224 * 1.14)  # 255, matches the training/val transform
_CROP = 224


def preprocess(img: Image.Image) -> np.ndarray:
    """PIL image -> ImageNet-normalized float32 NCHW [1, 3, 224, 224].

    Steps: convert to RGB, bilinear-resize so the short side is
    ``_RESIZE_SHORT`` (aspect ratio kept), center-crop ``_CROP`` x ``_CROP``,
    scale to [0, 1], normalize per channel with the ImageNet mean/std, and
    move channels first with a leading batch axis.
    """
    img = img.convert("RGB")
    w, h = img.size
    # Scale the SHORT side to _RESIZE_SHORT; the long side follows the ratio.
    if w <= h:
        nw, nh = _RESIZE_SHORT, round(_RESIZE_SHORT * h / w)
    else:
        nw, nh = round(_RESIZE_SHORT * w / h), _RESIZE_SHORT
    img = img.resize((nw, nh), Image.BILINEAR)

    left = (nw - _CROP) // 2
    top = (nh - _CROP) // 2
    img = img.crop((left, top, left + _CROP, top + _CROP))

    x = np.asarray(img, dtype=np.float32) / 255.0  # HWC, [0,1]
    x = (x - _IMAGENET_MEAN) / _IMAGENET_STD  # ImageNet normalize
    x = np.transpose(x, (2, 0, 1))[None]  # 1, C, H, W
    return np.ascontiguousarray(x, dtype=np.float32)


def _softmax(logits: np.ndarray) -> np.ndarray:
    """Softmax over all elements of ``logits`` (max-shifted for stability)."""
    z = logits - logits.max()
    e = np.exp(z)
    return e / e.sum()


def _energy(logits: np.ndarray) -> float:
    """Correct energy E(x) = -logsumexp(logits). Higher = more out-of-domain."""
    m = logits.max()  # shift by the max so exp() cannot overflow
    return -float(m + np.log(np.exp(logits - m).sum()))


class Pipeline:
    """Loads all ONNX sessions once and runs the two-stage identification."""

    def __init__(self, models_dir: str = MODELS_DIR):
        """Load the router and every expert from ``models_dir``.

        For each model name in ``[ROUTER, *EXPERTS]`` this reads
        ``<name>_logits.onnx`` (CPU execution provider) and
        ``<name>_classes.json`` (the class-name list indexed by logit
        position). When ``ENABLE_ENERGY_SUPPRESSION`` is set and
        ``energy_thresholds.json`` exists, the per-model threshold at
        ``ENERGY_SUPPRESS_PERCENTILE`` is loaded as well.
        """
        self._sessions: dict[str, ort.InferenceSession] = {}
        self._classes: dict[str, list[str]] = {}
        for name in [ROUTER, *EXPERTS]:
            self._sessions[name] = ort.InferenceSession(
                os.path.join(models_dir, f"{name}_logits.onnx"),
                providers=["CPUExecutionProvider"],
            )
            # JSON is UTF-8; don't let the platform locale decide how
            # non-ASCII class names are decoded.
            classes_path = os.path.join(models_dir, f"{name}_classes.json")
            with open(classes_path, encoding="utf-8") as f:
                self._classes[name] = json.load(f)

        # Empty mapping => the energy gate in identify() is a no-op.
        self._energy_thr: dict[str, float] = {}
        thr_path = os.path.join(models_dir, "energy_thresholds.json")
        if ENABLE_ENERGY_SUPPRESSION and os.path.exists(thr_path):
            with open(thr_path, encoding="utf-8") as f:
                table = json.load(f)
            self._energy_thr = {
                n: v[ENERGY_SUPPRESS_PERCENTILE] for n, v in table.items()
            }

    def _run(self, name: str, x: np.ndarray) -> tuple[str, float, float]:
        """Returns (top_class, top_confidence, energy).

        Runs model ``name`` on the preprocessed batch ``x`` and reads the
        first output's first row as the logits vector.
        """
        session = self._sessions[name]
        # Ask the session for its input name instead of assuming the export
        # called it "input", so a re-exported model keeps working.
        input_name = session.get_inputs()[0].name
        logits = session.run(None, {input_name: x})[0][0]
        probs = _softmax(logits)
        idx = int(probs.argmax())
        return self._classes[name][idx], float(probs[idx]), _energy(logits)

    def identify(self, img: Image.Image) -> dict:
        """Run router + owning expert on ``img`` and describe the call.

        Every result carries ``domain`` and ``domain_confidence`` (the
        router's top class and its softmax probability) plus ``abstain``.

        Abstentions (``abstain`` is True) also carry ``reason``:
          * ``"uncertain_domain"`` — router confidence is below
            ``ROUTER_CONFIDENCE_THRESHOLD``.
          * ``"off_domain"`` — the domain has no entry in ``DOMAIN_EXPERTS``
            (e.g. "other"), or the optional energy gate rejected the frame.
          * ``"low_confidence"`` — the expert's top probability is below
            ``EXPERT_CONFIDENCE_THRESHOLD``; the rejected call is still
            included under ``calls`` for display/debugging.

        Accepted results (``abstain`` is False) carry ``calls``: a one-item
        list of ``{expert, species, confidence, energy}``.
        """
        x = preprocess(img)

        # ── Stage 1: domain router ───────────────────────────────────────────
        domain, dconf, _ = self._run(ROUTER, x)
        out = {"domain": domain, "domain_confidence": dconf}
        # An unsure router takes precedence over an unrouted domain.
        if dconf < ROUTER_CONFIDENCE_THRESHOLD:
            return {**out, "abstain": True, "reason": "uncertain_domain"}
        ename = DOMAIN_EXPERTS.get(domain)
        if ename is None:
            return {**out, "abstain": True, "reason": "off_domain"}

        # ── Stage 2: run the single expert that owns this domain. Optional
        # energy gate abstains if the frame is out-of-domain for that expert.
        species, conf, energy = self._run(ename, x)
        thr = self._energy_thr.get(ename)
        if thr is not None and energy > thr:
            return {**out, "abstain": True, "reason": "off_domain"}

        call = {
            "expert": ename,
            "species": species,
            "confidence": conf,
            "energy": round(energy, 4),
        }
        # Expert confidence gate: a weak top-1 must not be presented as an ID.
        if conf < EXPERT_CONFIDENCE_THRESHOLD:
            return {
                **out,
                "abstain": True,
                "reason": "low_confidence",
                "calls": [call],
            }
        return {**out, "abstain": False, "calls": [call]}