"""Learned routing classifier — the confidence-gated upgrade to the regex router. smolcode's router historically guesses two things from cheap regex ([router.classify_specialty][engine.router.classify_specialty] and [router.classify_tier][engine.router.classify_tier]). This module adds tiny learned classifiers (SetFit backbone + light head, exported to int8 ONNX) that predict, per task: - **specialty** — which fine-tune family (16-way) - **tier** — a difficulty bucket -> the *starting* rung in the ladder - **escalate** — whether the task will likely need a bigger model Thinking level (off/low/high/xtra) is *derived* from (tier, escalate), not a separate model. The design is deliberately "pure upside": every prediction is gated by a calibrated confidence threshold. Below threshold — or if onnxruntime / the model artifacts aren't present at all — the field **falls back to the existing regex**, so we can never route worse than the status quo and rules-confident cases stay 100% deterministic. Heavy deps (onnxruntime, tokenizers, numpy) are imported lazily; if any is missing the classifier simply abstains everywhere and the regex drives routing. """ from __future__ import annotations import functools import json import os from pathlib import Path from pydantic import BaseModel, Field from .router import classify_specialty, classify_tier # Difficulty buckets the tier head predicts; mapped onto the ladder by # start = min(bucket, n_tiers - 1) — exactly classify_tier's clamping contract, # so the head stays ladder-length-agnostic. TIER_BUCKETS = 3 # Ordered thinking levels (matches smolcode-cli/src/router.rs Think enum). THINK_LEVELS = ("off", "low", "high", "xtra") # Default per-head confidence thresholds; overridden by router_clf.json's # "thresholds" map written at export/calibration time. 
_DEFAULT_TAU = {"specialty": 0.60, "tier": 0.55, "escalate": 0.65}

# Artifact root; overridable per instance (model_dir) or via the
# SMALLCODE_ROUTER_CLF_DIR environment variable.
_DEFAULT_DIR = Path(__file__).resolve().parent.parent / "finetune" / "router_clf" / "onnx"

# Label spellings the escalate head may use for its positive class. Labels are
# compared after str() + strip() + lower(), so 1 / True / "Escalate" all match.
_ESCALATE_TRUE = frozenset({"1", "true", "yes", "escalate"})


class RouteDecision(BaseModel):
    """The typed routing decision. `tier` is a start index into the active ladder."""

    specialty: str
    tier: int
    escalate: bool
    think: str
    # Per-field model confidence (0.0 when the field came from regex/default).
    confidences: dict[str, float] = Field(default_factory=dict)
    # Per-field provenance: "model" | "regex" | "default" — for telemetry/debugging.
    sources: dict[str, str] = Field(default_factory=dict)


def _softmax(row):
    """Turn a 1-D numpy score array into a probability vector.

    If the ONNX head already emits a probability distribution (non-negative,
    summing to ~1) it is returned as-is so the reported confidence stays
    honest; otherwise a max-shifted (overflow-safe) softmax is applied. The
    argmax is the same either way.
    """
    import numpy as np

    if row.min() >= 0.0 and abs(float(row.sum()) - 1.0) < 1e-3:
        return row
    e = np.exp(row - row.max())
    return e / e.sum()


def _log_failure(what: str) -> None:
    """Log the exception currently being handled at DEBUG level; never raises.

    The classifier is best-effort by design, so failures are swallowed — but
    they should still be discoverable when debugging a head that always abstains.
    """
    import logging

    logging.getLogger(__name__).debug(what, exc_info=True)


class _OnnxHead:
    """A single ONNX sequence-classification head + its tokenizer and label map."""

    def __init__(self, session, tokenizer, labels: list[str], input_names: set[str], max_len: int = 128) -> None:
        self.session = session
        self.tokenizer = tokenizer
        self.labels = labels
        self.input_names = input_names
        self.max_len = max_len

    @classmethod
    def try_load(cls, dpath: Path) -> "_OnnxHead | None":
        """Load model.onnx + tokenizer.json + labels.json from `dpath`.

        labels.json is either a bare list of labels or a dict with a "labels"
        list and an optional "max_len" (token budget, default 128).

        Returns:
            The loaded head, or None when any of the three files is missing or
            the label list is empty.

        Raises:
            Whatever onnxruntime / tokenizers / json raise on corrupt artifacts;
            RouteClassifier treats that as "head unavailable".
        """
        model_file = dpath / "model.onnx"
        tok_file = dpath / "tokenizer.json"
        labels_file = dpath / "labels.json"
        if not (model_file.exists() and tok_file.exists() and labels_file.exists()):
            return None

        meta = json.loads(labels_file.read_text(encoding="utf-8"))
        if isinstance(meta, dict):
            raw_labels = meta["labels"]
            max_len = int(meta.get("max_len", 128))
        else:
            raw_labels = meta
            max_len = 128
        # Normalize to str so int/bool labels ([0, 1], [false, true]) behave like
        # their string spellings in every downstream comparison.
        labels = [str(lbl) for lbl in raw_labels]
        if not labels:
            return None

        import onnxruntime as ort
        from tokenizers import Tokenizer

        sess = ort.InferenceSession(
            str(model_file),
            providers=["CPUExecutionProvider"],
        )
        tok = Tokenizer.from_file(str(tok_file))
        # Truncate inside the tokenizer rather than slicing ids afterwards, so
        # the post-processor's trailing special token (e.g. [SEP]) survives long
        # inputs. Never loosen a tighter limit already baked into tokenizer.json.
        current = tok.truncation
        if current is None or current.get("max_length", max_len) > max_len:
            tok.enable_truncation(max_length=max_len)
        input_names = {i.name for i in sess.get_inputs()}
        return cls(sess, tok, labels, input_names, max_len=max_len)

    def predict(self, text: str) -> tuple[str, float]:
        """Return (label, confidence) for the argmax class of `text`.

        Raises:
            ValueError: the tokenizer produced no tokens.
            IndexError: the model emits more classes than there are labels.
            Any onnxruntime error from `session.run`. RouteClassifier catches
            all of these and abstains to its fallback.
        """
        import numpy as np

        enc = self.tokenizer.encode(text)
        limit = self.max_len
        ids = enc.ids[:limit]
        if not ids:
            raise ValueError("tokenizer produced an empty encoding")
        feed = {"input_ids": np.asarray([ids], dtype=np.int64)}
        # Use the tokenizer's own mask / type ids: if tokenizer.json enables
        # padding, pad positions must be masked out (0), not attended to.
        # Only feed optional inputs the graph declares — ORT rejects unknown names.
        if "attention_mask" in self.input_names:
            feed["attention_mask"] = np.asarray([enc.attention_mask[:limit]], dtype=np.int64)
        if "token_type_ids" in self.input_names:
            feed["token_type_ids"] = np.asarray([enc.type_ids[:limit]], dtype=np.int64)
        out = self.session.run(None, feed)[0]
        probs = _softmax(np.asarray(out)[0])
        idx = int(probs.argmax())
        return self.labels[idx], float(probs[idx])


class RouteClassifier:
    """Loads the (optional) ONNX heads and turns a task string into a RouteDecision.

    Always safe to construct and to query: missing deps or artifacts -> empty
    `heads`, and any head that is absent, below its threshold, or fails at
    inference abstains to the regex/default baseline.
    """

    def __init__(self, model_dir: str | os.PathLike | None = None) -> None:
        self.model_dir = Path(
            model_dir or os.environ.get("SMALLCODE_ROUTER_CLF_DIR", _DEFAULT_DIR)
        )
        self.heads: dict[str, _OnnxHead] = {}
        self.thresholds = dict(_DEFAULT_TAU)
        self.think_map: dict | None = None
        self._load()

    def _load(self) -> None:
        """Best-effort load of router_clf.json and the three heads; never raises."""
        try:
            # The heavy trio — absent in a bare runtime, which is fine.
            import numpy  # noqa: F401
            import onnxruntime  # noqa: F401
            import tokenizers  # noqa: F401
        except Exception:
            return

        self._load_config(self.model_dir / "router_clf.json")
        for name in ("specialty", "tier", "escalate"):
            try:
                head = _OnnxHead.try_load(self.model_dir / name)
            except Exception:
                _log_failure(f"router_clf: failed to load {name!r} head; abstaining")
                head = None
            if head is not None:
                self.heads[name] = head

    def _load_config(self, cfg_path: Path) -> None:
        """Merge numeric thresholds and a dict think_map from `cfg_path`, if valid.

        Each field is validated independently, so one malformed entry does not
        discard the rest of the config.
        """
        if not cfg_path.exists():
            return
        try:
            cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
        except Exception:
            _log_failure(f"router_clf: unreadable config {cfg_path}; using defaults")
            return
        if not isinstance(cfg, dict):
            return
        thresholds = cfg.get("thresholds")
        if isinstance(thresholds, dict):
            for key, value in thresholds.items():
                try:
                    # Coerce now so a bad value can't blow up the comparison later.
                    self.thresholds[key] = float(value)
                except (TypeError, ValueError):
                    _log_failure(f"router_clf: ignoring non-numeric threshold {key!r}")
        think_map = cfg.get("think_map")
        if isinstance(think_map, dict):
            self.think_map = think_map

    @property
    def available(self) -> bool:
        return bool(self.heads)

    # --- per-decision helpers (model if confident, else regex/default) --------

    def _predict(self, name: str, task: str) -> tuple[str, float] | None:
        """Confident (label, confidence) from head `name`, or None to abstain.

        Abstains when the head isn't loaded, when inference raises (a broken
        artifact must degrade to the fallback, never break routing), or when
        the confidence is below the head's threshold.
        """
        head = self.heads.get(name)
        if head is None:
            return None
        try:
            label, conf = head.predict(task)
        except Exception:
            _log_failure(f"router_clf: {name!r} head failed at inference; abstaining")
            return None
        if conf >= self.thresholds.get(name, _DEFAULT_TAU[name]):
            return label, conf
        return None

    def pick_specialty(self, task: str, specialties=None) -> tuple[str, float, str]:
        """(specialty, confidence, source) for `task`.

        A confident model label outside `specialties` (when given) is rejected
        in favor of the regex.
        """
        pred = self._predict("specialty", task)
        if pred is not None:
            label, conf = pred
            if specialties is None or label in specialties:
                return label, conf, "model"
        return classify_specialty(task), 0.0, "regex"

    def pick_tier(self, task: str, n_tiers: int) -> tuple[int, float, str]:
        """(start rung, confidence, source) for `task` on an `n_tiers` ladder.

        The model's bucket is clamped into [0, n_tiers - 1]; a label that isn't
        an integer is treated as an abstention rather than silently mapped to 0.
        """
        pred = self._predict("tier", task)
        if pred is not None:
            label, conf = pred
            try:
                bucket = int(label)
            except (TypeError, ValueError):
                bucket = None
            if bucket is not None:
                return min(max(bucket, 0), max(n_tiers - 1, 0)), conf, "model"
        return classify_tier(task, n_tiers), 0.0, "regex"

    def pick_escalate(self, task: str) -> tuple[bool, float, str]:
        """(escalate?, confidence, source) for `task`."""
        pred = self._predict("escalate", task)
        if pred is not None:
            label, conf = pred
            return str(label).strip().lower() in _ESCALATE_TRUE, conf, "model"
        # No regex equivalent — default to "no escalation predicted".
        return False, 0.0, "default"

    def think_for(self, tier: int, n_tiers: int, escalate: bool) -> str:
        """Thinking level for (tier, escalate).

        A think_map from router_clf.json is consulted first — the exact
        "<tier>:<0|1>" key, then the tier-only "<tier>" key, with tier clamped
        to the ladder in both. Values not in THINK_LEVELS are ignored and
        default_think decides.
        """
        if isinstance(self.think_map, dict) and self.think_map:
            rung = min(tier, max(n_tiers - 1, 0))
            level = self.think_map.get(f"{rung}:{int(escalate)}") or self.think_map.get(str(rung))
            if level in THINK_LEVELS:
                return level
        return default_think(tier, n_tiers, escalate)

    def decide(self, task: str, *, specialties=None, n_tiers: int = 1) -> RouteDecision:
        """Combine the three per-field picks and the derived think level."""
        sp, sp_c, sp_s = self.pick_specialty(task, specialties)
        tier, t_c, t_s = self.pick_tier(task, n_tiers)
        esc, e_c, e_s = self.pick_escalate(task)
        return RouteDecision(
            specialty=sp,
            tier=tier,
            escalate=esc,
            think=self.think_for(tier, n_tiers, esc),
            confidences={"specialty": sp_c, "tier": t_c, "escalate": e_c},
            sources={"specialty": sp_s, "tier": t_s, "escalate": e_s},
        )


def default_think(tier: int, n_tiers: int, escalate: bool) -> str:
    """Monotone map: a higher start rung / predicted escalation -> more thinking.

    Single-rung ladders: "high" if escalating else "off". Otherwise the rung's
    position frac = tier / (n_tiers - 1) picks a band (top rung, upper half,
    lower half) and escalation bumps the level by one within that band.
    """
    if n_tiers <= 1:
        return "high" if escalate else "off"
    frac = tier / (n_tiers - 1)
    if frac >= 0.999:
        return "xtra" if escalate else "high"
    if frac >= 0.5:
        return "high" if escalate else "low"
    return "low" if escalate else "off"


@functools.lru_cache(maxsize=1)
def get_classifier() -> RouteClassifier:
    """Process-wide singleton (loads ONNX sessions once)."""
    return RouteClassifier()


def classify_route(task: str, *, specialties=None, n_tiers: int = 1) -> RouteDecision:
    """Public entry: a typed, confidence-gated routing decision for `task`."""
    return get_classifier().decide(task, specialties=specialties, n_tiers=n_tiers)