router-api / greenrouting /classifier /trained_predictor.py
spectralman's picture
Initial deploy: classifier + FastAPI router
6f0ff99 verified
"""Inference-time predictor that loads the trained artifact and conforms to
the `Predictor` protocol used by the router."""
from __future__ import annotations
import json
import math
from pathlib import Path
from typing import Optional
from greenrouting.classifier.infer import (
CapabilityProfile,
LENGTH_BUCKETS,
LENGTH_TOKEN_TARGETS,
LENGTH_P90_MULTIPLIER,
QueryProfile,
)
from greenrouting.classifier.model import Encoder, ModelSpec, build_head
from greenrouting.classifier.ood import is_ood
from greenrouting.routing.registry import CAPABILITY_KEYS
class TrainedPredictor:
    """Predictor that scores queries with a trained head on top of an encoder.

    Construction only records the artifact location; nothing is read from
    disk until the first call to :meth:`predict`, which triggers a one-time
    load of everything in ``artifact_dir``:

    * ``metadata.json`` -- must contain ``embedding_dim``, ``hidden_dim``,
      ``capability_keys`` and ``length_buckets``; may contain
      ``max_seq_len`` (default 256) and ``diff_target_center``.
    * ``encoder_name.txt`` -- name passed to ``Encoder``.
    * ``head.pt`` -- state dict for the head built by ``build_head``.
    * ``calibration.json`` (optional) -- ``temperature`` for the capability
      logits.
    * ``ood_stats.npz`` (optional) -- ``centroid``, ``reference``, optional
      ``k``, plus ``centroid_threshold`` and ``knn_threshold``.
    """

    # Fallback offset added back to the head's centred difficulty output when
    # metadata.json does not provide ``diff_target_center``.
    _DEFAULT_DIFF_CENTER = math.log(8e9)

    def __init__(self, artifact_dir: str | Path):
        """Remember where the artifact lives; loading is deferred.

        Args:
            artifact_dir: Directory holding the trained artifact files.
        """
        self.artifact_dir = Path(artifact_dir)
        self._loaded = False
        self._encoder: Optional[Encoder] = None
        self._head = None
        self._spec: Optional[ModelSpec] = None
        self._temperature: float = 1.0
        # Initialised here (not only in _ensure_loaded) so the attribute
        # always exists, like every other piece of instance state.
        self._diff_center: float = self._DEFAULT_DIFF_CENTER
        self._ood_stats: Optional[dict] = None
        self._ood_thresholds: Optional[dict] = None
        # Queries whose best capability probability falls below this value
        # are flagged as out-of-distribution.
        self._ood_min_confidence: float = 0.40

    def _ensure_loaded(self) -> None:
        """Load the encoder, head, calibration and OOD statistics once.

        Subsequent calls return immediately. ``_loaded`` is only set after
        every step succeeded, so a failed load is retried on the next call.

        Raises:
            FileNotFoundError: If ``metadata.json``, ``encoder_name.txt`` or
                ``head.pt`` is missing.
            KeyError: If a required metadata field is missing, or if
                ``ood_stats.npz`` has ``centroid``/``reference`` but lacks
                one of the threshold entries.
        """
        if self._loaded:
            return

        # Heavy dependencies are imported lazily, only when a load happens.
        import numpy as np
        import torch

        meta_path = self.artifact_dir / "metadata.json"
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
        encoder_name = (
            (self.artifact_dir / "encoder_name.txt")
            .read_text(encoding="utf-8")
            .strip()
        )

        self._spec = ModelSpec(
            encoder_name=encoder_name,
            embedding_dim=int(meta["embedding_dim"]),
            hidden_dim=int(meta["hidden_dim"]),
            n_capabilities=len(meta["capability_keys"]),
            n_length_buckets=len(meta["length_buckets"]),
            max_seq_len=int(meta.get("max_seq_len", 256)),
        )
        self._diff_center = float(
            meta.get("diff_target_center", self._DEFAULT_DIFF_CENTER)
        )

        self._encoder = Encoder(encoder_name, max_seq_len=self._spec.max_seq_len)

        head = build_head(self._spec)
        # NOTE(security): torch.load unpickles the file; only point this at
        # trusted artifacts. Consider weights_only=True if the deployed torch
        # version supports it.
        head.load_state_dict(
            torch.load(self.artifact_dir / "head.pt", map_location="cpu")
        )
        head.to(self._encoder.device).eval()
        self._head = head

        cal_path = self.artifact_dir / "calibration.json"
        if cal_path.exists():
            calibration = json.loads(cal_path.read_text(encoding="utf-8"))
            self._temperature = float(calibration.get("temperature", 1.0))

        ood_path = self.artifact_dir / "ood_stats.npz"
        if ood_path.exists():
            # np.load on an .npz returns a lazily-read archive that holds the
            # file open; the context manager closes it once the arrays have
            # been pulled into memory.
            with np.load(ood_path) as data:
                if "centroid" in data.files and "reference" in data.files:
                    self._ood_stats = {
                        "centroid": data["centroid"],
                        "reference": data["reference"],
                        "k": int(data["k"]) if "k" in data.files else 5,
                    }
                    self._ood_thresholds = {
                        "centroid_threshold": float(data["centroid_threshold"]),
                        "knn_threshold": float(data["knn_threshold"]),
                    }

        self._loaded = True

    def predict(self, query: str) -> QueryProfile:
        """Build a ``QueryProfile`` for a single query string.

        Args:
            query: Raw user query. ``None`` or empty input is treated as the
                empty string; surrounding whitespace is stripped.

        Returns:
            A ``QueryProfile`` with per-capability probabilities, the
            difficulty estimate, the output-length distribution, token count
            estimates, a confidence score and the combined OOD flag.
        """
        import torch
        import torch.nn.functional as F

        self._ensure_loaded()

        text = (query or "").strip()
        emb = self._encoder.embed([text])
        with torch.no_grad():
            out = self._head(emb)

        # Temperature-scale the capability logits; the floor guards against a
        # zero or negative temperature in calibration.json.
        cap_logits = out["cap_logits"] / max(self._temperature, 1e-3)
        # [0] picks the only query in the one-element batch.
        cap_probs = torch.sigmoid(cap_logits)[0].cpu().numpy().tolist()
        # NOTE(review): probabilities are paired with the module-level
        # CAPABILITY_KEYS / LENGTH_BUCKETS, not with the key lists stored in
        # metadata.json -- assumes both have the same order; verify.
        cap_dict = {k: float(v) for k, v in zip(CAPABILITY_KEYS, cap_probs)}

        # The head's difficulty output is centred; add the offset back.
        diff_centered = float(out["diff"][0].item())
        diff_log_params = diff_centered + self._diff_center

        len_probs = F.softmax(out["len_logits"][0], dim=-1).cpu().numpy().tolist()
        length_dist = {b: float(p) for b, p in zip(LENGTH_BUCKETS, len_probs)}

        # Confidence is the highest single capability probability.
        confidence = max(cap_dict.values()) if cap_dict else 0.0
        confidence_ood = confidence < self._ood_min_confidence

        # The embedding-based check only runs when OOD statistics were loaded.
        geometric_ood = False
        if self._ood_stats is not None and self._ood_thresholds is not None:
            emb_np = emb[0].cpu().numpy()
            geometric_ood = is_ood(emb_np, self._ood_stats, self._ood_thresholds)
        ood_flag = confidence_ood or geometric_ood

        # Input size: whitespace word count scaled by 1.3 plus 4, at least 1.
        in_tokens = max(1, int(len(text.split()) * 1.3) + 4)
        # Median output size: bucket targets weighted by bucket probability.
        out_p50 = int(
            round(
                sum(length_dist[b] * LENGTH_TOKEN_TARGETS[b] for b in LENGTH_BUCKETS)
            )
        )
        # p90 scales the median and adds extra headroom in proportion to the
        # probability mass on the "long" bucket.
        long_w = length_dist.get("long", 0.0)
        out_p90 = int(
            round(
                out_p50 * LENGTH_P90_MULTIPLIER
                + long_w * LENGTH_TOKEN_TARGETS["long"] * 0.3
            )
        )

        return QueryProfile(
            capabilities=CapabilityProfile(**cap_dict),
            difficulty_log_params=diff_log_params,
            length_dist=length_dist,
            expected_input_tokens=in_tokens,
            expected_output_tokens_p50=out_p50,
            expected_output_tokens_p90=out_p90,
            confidence=confidence,
            is_ood=ood_flag,
            raw_query=text,
            debug={
                "source": "trained",
                "temperature": self._temperature,
                "confidence_ood": bool(confidence_ood),
                "geometric_ood": bool(geometric_ood),
            },
        )