smolcode / engine /route_clf.py
seanpoyner's picture
Upload folder using huggingface_hub
daea45b verified
Raw
History Blame Contribute Delete
9.69 kB
"""Learned routing classifier — the confidence-gated upgrade to the regex router.
smolcode's router historically guesses two things from cheap regex
([router.classify_specialty][engine.router.classify_specialty] and
[router.classify_tier][engine.router.classify_tier]). This module adds tiny
learned classifiers (SetFit backbone + light head, exported to int8 ONNX) that
predict, per task:
- **specialty** — which fine-tune family (16-way)
- **tier** — a difficulty bucket -> the *starting* rung in the ladder
- **escalate** — whether the task will likely need a bigger model
Thinking level (off/low/high/xtra) is *derived* from (tier, escalate), not a
separate model.
The design is deliberately "pure upside": every prediction is gated by a
calibrated confidence threshold. Below threshold — or if onnxruntime / the model
artifacts aren't present at all — the field **falls back to the existing regex**,
so we can never route worse than the status quo and rules-confident cases stay
100% deterministic.
Heavy deps (onnxruntime, tokenizers, numpy) are imported lazily; if any is
missing the classifier simply abstains everywhere and the regex drives routing.
"""
from __future__ import annotations

import functools
import json
import logging
import os
from pathlib import Path

from pydantic import BaseModel, Field

from .router import classify_specialty, classify_tier
# Number of difficulty buckets the tier head predicts. A bucket is turned into a
# ladder start index with start = min(bucket, n_tiers - 1) — the same clamping
# contract as classify_tier — so the head does not depend on the ladder length.
TIER_BUCKETS = 3

# Thinking levels in ascending order (matches smolcode-cli/src/router.rs Think enum).
THINK_LEVELS = ("off", "low", "high", "xtra")

# Per-head confidence thresholds used unless router_clf.json supplies its own
# "thresholds" map (that map is written at export/calibration time).
_DEFAULT_TAU = dict(specialty=0.60, tier=0.55, escalate=0.65)

# Default artifact directory: finetune/router_clf/onnx under the directory that
# sits two levels above this file.
_DEFAULT_DIR = Path(__file__).resolve().parent.parent.joinpath(
    "finetune", "router_clf", "onnx"
)
class RouteDecision(BaseModel):
    """Typed outcome of routing one task.

    `tier` is a start index into the ladder that is active for the call, not an
    absolute model identifier.
    """

    # Chosen specialty label (the fine-tune family to route to).
    specialty: str
    # Start rung in the active ladder.
    tier: int
    # True when a bigger model is predicted to be needed.
    escalate: bool
    # Thinking level derived from (tier, escalate); one of THINK_LEVELS.
    think: str
    # Model confidence for each field; 0.0 when that field came from regex/default.
    confidences: dict[str, float] = Field(default_factory=dict)
    # Where each field came from: "model" | "regex" | "default" (telemetry/debugging).
    sources: dict[str, str] = Field(default_factory=dict)
def _softmax(row):
    """Convert a 1-D numpy score vector into class probabilities.

    A vector that already looks like a distribution (no negative entries and a
    sum within 1e-3 of 1.0) is returned unchanged so the reported confidence is
    not flattened by a second normalization. Anything else goes through a
    max-shifted softmax.
    """
    import numpy as np

    total = float(row.sum())
    already_normalized = row.min() >= 0.0 and abs(total - 1.0) < 1e-3
    if already_normalized:
        return row

    # Subtracting the max keeps exp() from overflowing on large logits.
    exps = np.exp(row - row.max())
    return exps / exps.sum()
class _OnnxHead:
    """A single ONNX sequence-classification head plus its tokenizer and label map.

    Attributes:
        session: Object exposing `run(output_names, feed)` (an ONNX Runtime
            inference session when built by `try_load`).
        tokenizer: Object exposing `encode(text)` whose result has `.ids` and,
            optionally, `.attention_mask`.
        labels: Class labels, indexed by the model's output position.
        input_names: Names of the inputs the model declares.
        max_len: Maximum number of token ids fed to the model.
    """

    def __init__(self, session, tokenizer, labels: list[str], input_names: set[str],
                 max_len: int = 128) -> None:
        self.session = session
        self.tokenizer = tokenizer
        self.labels = labels
        self.input_names = input_names
        self.max_len = max_len

    @classmethod
    def try_load(cls, dpath: Path) -> "_OnnxHead | None":
        """Load model.onnx + tokenizer.json + labels.json from `dpath`.

        Returns None when any of the three files is missing (checked before the
        heavy imports, so this path needs neither onnxruntime nor tokenizers).
        Errors raised while importing or parsing existing artifacts propagate to
        the caller.

        labels.json may be either a bare list of labels or an object with a
        "labels" list and an optional "max_len".
        """
        model_file = dpath / "model.onnx"
        tok_file = dpath / "tokenizer.json"
        labels_file = dpath / "labels.json"
        if not (model_file.exists() and tok_file.exists() and labels_file.exists()):
            return None

        import onnxruntime as ort
        from tokenizers import Tokenizer

        sess = ort.InferenceSession(
            str(model_file), providers=["CPUExecutionProvider"],
        )
        tok = Tokenizer.from_file(str(tok_file))
        # Explicit UTF-8: JSON artifacts must not depend on the platform locale.
        meta = json.loads(labels_file.read_text(encoding="utf-8"))
        if isinstance(meta, dict):
            labels = meta["labels"]
            max_len = int(meta.get("max_len", 128))
        else:
            labels = list(meta)
            max_len = 128
        input_names = {i.name for i in sess.get_inputs()}
        return cls(sess, tok, labels, input_names, max_len=max_len)

    def predict(self, text: str) -> tuple[str, float]:
        """Return (label, confidence) for the argmax class of `text`."""
        import numpy as np

        enc = self.tokenizer.encode(text)
        # NOTE(review): plain slicing also drops any trailing special token on
        # over-long inputs — confirm this matches how the head was trained.
        ids = list(enc.ids)[: self.max_len]

        # Honor the tokenizer's own attention mask so that padding emitted by a
        # padding-enabled tokenizer.json is not attended to. Fall back to
        # all-ones when the encoding carries no usable mask.
        mask = list(getattr(enc, "attention_mask", None) or ())[: len(ids)]
        if len(mask) != len(ids):
            mask = [1] * len(ids)

        feed = {"input_ids": np.asarray([ids], dtype=np.int64)}
        # Only feed optional inputs the model actually declares; the runtime
        # rejects feeds that name unknown inputs.
        if "attention_mask" in self.input_names:
            feed["attention_mask"] = np.asarray([mask], dtype=np.int64)
        if "token_type_ids" in self.input_names:
            feed["token_type_ids"] = np.zeros((1, len(ids)), dtype=np.int64)

        out = self.session.run(None, feed)[0]
        probs = _softmax(np.asarray(out)[0])
        idx = int(probs.argmax())
        return self.labels[idx], float(probs[idx])
class RouteClassifier:
    """Loads the (optional) ONNX heads and turns a task string into a RouteDecision.

    Always safe to construct and to call: missing deps or artifacts leave
    `heads` empty, and a head that is absent, under-confident, or raises while
    predicting makes that field abstain to the regex baseline (or, for
    `escalate`, to the "no escalation" default).
    """

    _log = logging.getLogger(__name__)
    _HEAD_NAMES = ("specialty", "tier", "escalate")
    # Labels (after str/strip/lower normalization) that mean "escalate".
    _TRUE_LABELS = frozenset({"1", "true", "yes", "escalate"})

    def __init__(self, model_dir: str | os.PathLike | None = None) -> None:
        """Resolve the artifact directory and load whatever is available.

        Directory precedence: the `model_dir` argument, then the
        SMALLCODE_ROUTER_CLF_DIR environment variable, then the module default.
        """
        self.model_dir = Path(
            model_dir or os.environ.get("SMALLCODE_ROUTER_CLF_DIR", _DEFAULT_DIR)
        )
        self.heads: dict[str, _OnnxHead] = {}
        self.thresholds = dict(_DEFAULT_TAU)
        self.think_map: dict | None = None
        self._load()

    def _load(self) -> None:
        """Populate thresholds, think_map and heads; never raises on bad artifacts."""
        try:  # the heavy trio — absent in a bare runtime, which is fine.
            import numpy  # noqa: F401
            import onnxruntime  # noqa: F401
            import tokenizers  # noqa: F401
        except Exception:
            self._log.debug("router classifier deps unavailable; regex only",
                            exc_info=True)
            return

        self._load_config(self.model_dir / "router_clf.json")

        for name in self._HEAD_NAMES:
            try:
                head = _OnnxHead.try_load(self.model_dir / name)
            except Exception:
                # Artifacts exist but are unusable: keep routing on regex, but
                # leave a trace instead of failing silently.
                self._log.warning("could not load router head %r from %s",
                                  name, self.model_dir, exc_info=True)
                head = None
            if head is not None:
                self.heads[name] = head

    def _load_config(self, cfg_path: Path) -> None:
        """Apply router_clf.json ("thresholds", "think_map") if present and valid."""
        if not cfg_path.exists():
            return
        try:
            cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            self._log.warning("ignoring unreadable %s", cfg_path, exc_info=True)
            return
        if not isinstance(cfg, dict):
            self._log.warning("ignoring %s: expected a JSON object", cfg_path)
            return

        taus = cfg.get("thresholds")
        if isinstance(taus, dict):
            for head_name, tau in taus.items():
                # Coerce now so a bad value cannot blow up a comparison at
                # routing time; a bad entry keeps its existing threshold.
                try:
                    self.thresholds[head_name] = float(tau)
                except (TypeError, ValueError):
                    self._log.warning("ignoring non-numeric threshold %r for %r",
                                      tau, head_name)

        think_map = cfg.get("think_map")
        if isinstance(think_map, dict):
            self.think_map = think_map

    @property
    def available(self) -> bool:
        """True when at least one ONNX head was loaded."""
        return bool(self.heads)

    def _confident(self, name: str, task: str):
        """Return (label, conf) from head `name` if it clears its threshold, else None.

        None also covers "head not loaded" and "head raised while predicting",
        so callers have a single abstain path.
        """
        head = self.heads.get(name)
        if head is None:
            return None
        try:
            label, conf = head.predict(task)
        except Exception:
            # A failing head must never break routing; abstain instead.
            self._log.debug("router head %r failed; abstaining", name, exc_info=True)
            return None
        if conf >= self.thresholds[name]:
            return label, conf
        return None

    # --- per-decision helpers (model if confident, else regex/default) --------
    def pick_specialty(self, task: str, specialties=None) -> tuple[str, float, str]:
        """Return (specialty, confidence, source).

        The model's label is used only when it is confident and, if
        `specialties` is given, is a member of it; otherwise the regex
        classifier decides with confidence 0.0.
        """
        hit = self._confident("specialty", task)
        if hit is not None:
            label, conf = hit
            if specialties is None or label in specialties:
                return label, conf, "model"
        return classify_specialty(task), 0.0, "regex"

    def pick_tier(self, task: str, n_tiers: int) -> tuple[int, float, str]:
        """Return (start_index, confidence, source), clamped to [0, n_tiers - 1].

        A confident label that cannot be parsed as an integer bucket is treated
        as an abstention and falls back to the regex tier.
        """
        hit = self._confident("tier", task)
        if hit is not None:
            label, conf = hit
            try:
                bucket = int(label)
            except (TypeError, ValueError):
                bucket = None
            if bucket is not None:
                top = max(n_tiers - 1, 0)
                return max(0, min(bucket, top)), conf, "model"
        return classify_tier(task, n_tiers), 0.0, "regex"

    def pick_escalate(self, task: str) -> tuple[bool, float, str]:
        """Return (escalate, confidence, source).

        Labels are normalized with str/strip/lower before matching, so `1`,
        `"True"` and `"escalate"` all count as positive.
        """
        hit = self._confident("escalate", task)
        if hit is not None:
            label, conf = hit
            return str(label).strip().lower() in self._TRUE_LABELS, conf, "model"
        # No regex equivalent — default to "no escalation predicted".
        return False, 0.0, "default"

    def think_for(self, tier: int, n_tiers: int, escalate: bool) -> str:
        """Return the thinking level for (tier, escalate).

        Looks up think_map under "<clamped tier>:<0|1>", then under "<tier>";
        a missing or unknown level falls back to `default_think`.
        """
        if self.think_map:
            # Same clamp as pick_tier so n_tiers == 0 yields rung 0, not -1.
            rung = min(tier, max(n_tiers - 1, 0))
            key = f"{rung}:{int(escalate)}"
            lvl = self.think_map.get(key) or self.think_map.get(str(tier))
            if lvl in THINK_LEVELS:
                return lvl
        return default_think(tier, n_tiers, escalate)

    def decide(self, task: str, *, specialties=None, n_tiers: int = 1) -> RouteDecision:
        """Run all three pickers and assemble the typed decision for `task`."""
        sp, sp_c, sp_s = self.pick_specialty(task, specialties)
        tier, t_c, t_s = self.pick_tier(task, n_tiers)
        esc, e_c, e_s = self.pick_escalate(task)
        return RouteDecision(
            specialty=sp,
            tier=tier,
            escalate=esc,
            think=self.think_for(tier, n_tiers, esc),
            confidences={"specialty": sp_c, "tier": t_c, "escalate": e_c},
            sources={"specialty": sp_s, "tier": t_s, "escalate": e_s},
        )
def default_think(tier: int, n_tiers: int, escalate: bool) -> str:
    """Map (start rung, predicted escalation) to a thinking level, monotonically.

    With a ladder of at most one rung the answer is "high" when escalating and
    "off" otherwise. For longer ladders the rung's relative position selects a
    base level — bottom half "off", upper half "low", top rung "high" — and a
    predicted escalation bumps it one step up.
    """
    bump = 1 if escalate else 0
    if n_tiers <= 1:
        return ("off", "high")[bump]

    position = tier / (n_tiers - 1)
    if position >= 0.999:
        band = 2
    elif position >= 0.5:
        band = 1
    else:
        band = 0
    return ("off", "low", "high", "xtra")[band + bump]
@functools.lru_cache(maxsize=1)
def get_classifier() -> RouteClassifier:
    """Return the shared RouteClassifier, building it on the first call.

    The cache holds the single instance, so ONNX sessions are loaded once per
    process.
    """
    shared = RouteClassifier()
    return shared
def classify_route(task: str, *, specialties=None, n_tiers: int = 1) -> RouteDecision:
    """Public entry point: route `task` with the process-wide classifier.

    `specialties` and `n_tiers` are forwarded unchanged to
    `RouteClassifier.decide`; the result is a typed, confidence-gated decision.
    """
    router = get_classifier()
    return router.decide(task, specialties=specialties, n_tiers=n_tiers)