HomesteaderLabs's picture
Deploy Forager's Field Station
0e3ea0a verified
"""
infer.py — Two-stage ONNX inference for the Field Station Space.
Mirrors the on-device forager_ml pipeline (domain router -> expert routing ->
abstention) but runs on CPU via onnxruntime instead of the Hailo NPU.
Stage 1: domain router classifies the frame into berry / mushroom / plant / other.
Stage 2: route to the single expert that owns that domain (no cross-expert
voting). Abstain (UNKNOWN) when the router is unsure, the domain is "other",
or the expert's top call is below threshold.
Preprocessing note: these ONNX files are the bare timm tf_efficientnet_lite2
models (no normalization baked in), so inputs are ImageNet-normalized
[1, 3, 224, 224] — NOT the [0,255] NHWC the HEF expects.
"""
import json
import os
import numpy as np
import onnxruntime as ort
from PIL import Image
# Model artefacts are looked up in ../models relative to this file.
MODELS_DIR = os.path.join(
    os.path.dirname(__file__),
    "..",
    "models",
)
ROUTER = "domain_router_v2"  # stage-1 domain classifier
# Experts loaded by this Space. The psychedelics/mycologist expert is
# deliberately left out of the public build: nothing routes to it (mushroom
# goes to highvalue only), and shipping a psilocybin identifier would invite
# policy scrutiny with no functional benefit. It remains in forager_ml and can
# be published as its own model repo.
EXPERTS = [
    "berry_expert",
    "highvalue_expert",
    "medicinals_expert",
]
# Each router domain maps to exactly ONE owning expert; there is no
# cross-expert voting. An expert therefore never scores an input from a domain
# it does not own (highvalue never sees a plant, so it cannot label a hemlock
# "ramps"); the deadly plants are handled by medicinals (0% toxic-as-edible
# FAR). "other" has no entry on purpose, so routing to it leads to abstention.
# The mycologist/psychedelics expert is kept off the live path (weak on real
# photos; benched).
DOMAIN_EXPERTS: dict[str, str] = dict(
    berry="berry_expert",
    mushroom="highvalue_expert",
    plant="medicinals_expert",
)
# Confidence gates; values match the on-device convergence thresholds.
ROUTER_CONFIDENCE_THRESHOLD = 0.74  # router top-1 softmax probability gate
EXPERT_CONFIDENCE_THRESHOLD = 0.75  # expert top-1 softmax probability gate
# Energy-based OOD gate. Energy is E(x) = -logsumexp(logits) (see _energy).
# When enabled, the pipeline abstains if the expert's energy exceeds that
# expert's in-domain threshold, taken from models/energy_thresholds.json
# (fp32 in-domain percentiles; the column is ENERGY_SUPPRESS_PERCENTILE).
# It was built to stop an off-domain expert out-voting the correct one (e.g.
# highvalue calling a hemlock "ramps"). Under single-expert routing there are
# no competing votes, and the val-calibrated thresholds over-fire on real
# photos, so it is off by default; the router's "other" class plus the
# confidence gates carry OOD instead.
ENABLE_ENERGY_SUPPRESSION = False
ENERGY_SUPPRESS_PERCENTILE = "p90"
# ImageNet per-channel (RGB) normalization statistics.
_IMAGENET_MEAN = np.array((0.485, 0.456, 0.406), dtype=np.float32)
_IMAGENET_STD = np.array((0.229, 0.224, 0.225), dtype=np.float32)
# Eval geometry: resize the short side to _RESIZE_SHORT, then center-crop _CROP.
_CROP = 224
_RESIZE_SHORT = int(_CROP * 1.14)  # == 255, matches the training/val transform
def preprocess(img: Image.Image) -> np.ndarray:
    """Turn a PIL image into the router/expert input tensor.

    Steps: apply any EXIF orientation, convert to RGB, resize (bilinear) so
    the short side is _RESIZE_SHORT px with the aspect ratio kept, take the
    central _CROP x _CROP window, scale to [0, 1] and normalize per channel
    with the ImageNet mean/std.

    Args:
        img: any PIL image; its mode is converted to RGB.

    Returns:
        C-contiguous float32 array of shape (1, 3, _CROP, _CROP), i.e. NCHW.
    """
    from PIL import ImageOps

    # Phone cameras usually record rotation in the EXIF Orientation tag rather
    # than rotating the pixels, and PIL does not apply it on open, so an
    # upright plant could reach the model sideways. exif_transpose is a no-op
    # (returns a copy) when there is no tag. Best-effort: a malformed EXIF
    # block must not fail the whole request, so fall back to the raw pixels.
    try:
        img = ImageOps.exif_transpose(img)
    except Exception:
        pass
    img = img.convert("RGB")
    w, h = img.size
    # Scale so the shorter side becomes _RESIZE_SHORT, keeping aspect ratio.
    if w <= h:
        nw, nh = _RESIZE_SHORT, round(_RESIZE_SHORT * h / w)
    else:
        nw, nh = round(_RESIZE_SHORT * w / h), _RESIZE_SHORT
    img = img.resize((nw, nh), Image.BILINEAR)
    # Center crop; both resized sides are >= _RESIZE_SHORT (> _CROP), so the
    # box always lies inside the image.
    left = (nw - _CROP) // 2
    top = (nh - _CROP) // 2
    img = img.crop((left, top, left + _CROP, top + _CROP))
    x = np.asarray(img, dtype=np.float32) / 255.0  # HWC, [0,1]
    x = (x - _IMAGENET_MEAN) / _IMAGENET_STD  # ImageNet normalize
    x = np.transpose(x, (2, 0, 1))[None]  # 1, C, H, W
    return np.ascontiguousarray(x, dtype=np.float32)
def _softmax(logits: np.ndarray) -> np.ndarray:
z = logits - logits.max()
e = np.exp(z)
return e / e.sum()
def _energy(logits: np.ndarray) -> float:
"""Correct energy E(x) = -logsumexp(logits). Higher = more out-of-domain."""
m = logits.max()
return -float(m + np.log(np.exp(logits - m).sum()))
class Pipeline:
    """Loads all ONNX sessions once and runs the two-stage identification.

    Stage 1 runs the domain router; stage 2 runs the single expert that owns
    the routed domain (DOMAIN_EXPERTS). Whenever a gate fails the pipeline
    abstains instead of guessing.
    """

    def __init__(self, models_dir: str = MODELS_DIR):
        """Open the router and expert sessions plus their class-name lists.

        Args:
            models_dir: directory holding ``<name>_logits.onnx`` and
                ``<name>_classes.json`` for ROUTER and every name in EXPERTS,
                and optionally ``energy_thresholds.json``.

        Errors from onnxruntime or ``open()`` (e.g. a missing file) propagate
        unchanged.
        """
        self._sessions: dict[str, ort.InferenceSession] = {}
        self._classes: dict[str, list[str]] = {}
        for name in [ROUTER, *EXPERTS]:
            self._sessions[name] = ort.InferenceSession(
                os.path.join(models_dir, f"{name}_logits.onnx"),
                providers=["CPUExecutionProvider"],
            )
            # Explicit UTF-8: class names may contain non-ASCII characters, and
            # the platform default encoding is not guaranteed to be UTF-8.
            classes_path = os.path.join(models_dir, f"{name}_classes.json")
            with open(classes_path, encoding="utf-8") as f:
                self._classes[name] = json.load(f)
        # Per-expert energy thresholds; stays empty (gate disabled) unless
        # suppression is enabled and the thresholds file exists.
        self._energy_thr: dict[str, float] = {}
        thr_path = os.path.join(models_dir, "energy_thresholds.json")
        if ENABLE_ENERGY_SUPPRESSION and os.path.exists(thr_path):
            with open(thr_path, encoding="utf-8") as f:
                table = json.load(f)
            self._energy_thr = {n: v[ENERGY_SUPPRESS_PERCENTILE] for n, v in table.items()}

    def _run(self, name: str, x: np.ndarray) -> tuple[str, float, float]:
        """Run model *name* on the preprocessed batch *x*.

        Returns (top_class, top_confidence, energy) for the first item of the
        batch. top_confidence is the softmax probability of top_class, and
        energy is -logsumexp of the raw logits.
        """
        session = self._sessions[name]
        # Feed by the graph's own input name rather than assuming the exporter
        # called it "input".
        feed = {session.get_inputs()[0].name: x}
        logits = session.run(None, feed)[0][0]
        probs = _softmax(logits)
        idx = int(probs.argmax())
        return self._classes[name][idx], float(probs[idx]), _energy(logits)

    def identify(self, img: Image.Image) -> dict:
        """Identify *img* with the router -> single-expert pipeline.

        Returns a dict that always has ``domain``, ``domain_confidence`` and
        ``abstain``. When ``abstain`` is True it also has ``reason``:
          - "uncertain_domain": router confidence < ROUTER_CONFIDENCE_THRESHOLD
          - "off_domain": the routed domain has no expert (e.g. "other"), or
            the optional energy gate rejected the frame
          - "low_confidence": the expert's top class has confidence
            < EXPERT_CONFIDENCE_THRESHOLD
        When ``abstain`` is False it has ``calls``: a one-element list of
        ``{"expert", "species", "confidence", "energy"}``.
        """
        x = preprocess(img)
        # ── Stage 1: domain router ───────────────────────────────────────────
        domain, dconf, _ = self._run(ROUTER, x)
        out = {"domain": domain, "domain_confidence": dconf}
        # An unsure router wins over "no expert for this domain".
        if dconf < ROUTER_CONFIDENCE_THRESHOLD:
            return {**out, "abstain": True, "reason": "uncertain_domain"}
        if domain not in DOMAIN_EXPERTS:
            return {**out, "abstain": True, "reason": "off_domain"}
        # ── Stage 2: run the single expert that owns this domain ────────────
        ename = DOMAIN_EXPERTS[domain]
        species, conf, energy = self._run(ename, x)
        # Optional energy gate: frame looks out-of-domain for this expert.
        thr = self._energy_thr.get(ename)
        if thr is not None and energy > thr:
            return {**out, "abstain": True, "reason": "off_domain"}
        # Expert gate: never report a species the expert itself is unsure of.
        # This is a foraging identifier, so a low-confidence guess is worse
        # than UNKNOWN.
        if conf < EXPERT_CONFIDENCE_THRESHOLD:
            return {**out, "abstain": True, "reason": "low_confidence"}
        call = {"expert": ename, "species": species, "confidence": conf, "energy": round(energy, 4)}
        return {**out, "abstain": False, "calls": [call]}