"""Inference-time predictor that loads the trained artifact and conforms to the `Predictor` protocol used by the router.""" from __future__ import annotations import json import math from pathlib import Path from typing import Optional from greenrouting.classifier.infer import ( CapabilityProfile, LENGTH_BUCKETS, LENGTH_TOKEN_TARGETS, LENGTH_P90_MULTIPLIER, QueryProfile, ) from greenrouting.classifier.model import Encoder, ModelSpec, build_head from greenrouting.classifier.ood import is_ood from greenrouting.routing.registry import CAPABILITY_KEYS class TrainedPredictor: def __init__(self, artifact_dir: str | Path): self.artifact_dir = Path(artifact_dir) self._loaded = False self._encoder: Optional[Encoder] = None self._head = None self._spec: Optional[ModelSpec] = None self._temperature: float = 1.0 self._ood_stats = None self._ood_thresholds = None self._ood_min_confidence: float = 0.40 def _ensure_loaded(self) -> None: if self._loaded: return import numpy as np import torch meta_path = self.artifact_dir / "metadata.json" meta = json.loads(meta_path.read_text()) encoder_name = (self.artifact_dir / "encoder_name.txt").read_text().strip() self._spec = ModelSpec( encoder_name=encoder_name, embedding_dim=int(meta["embedding_dim"]), hidden_dim=int(meta["hidden_dim"]), n_capabilities=len(meta["capability_keys"]), n_length_buckets=len(meta["length_buckets"]), max_seq_len=int(meta.get("max_seq_len", 256)), ) self._diff_center = float(meta.get("diff_target_center", math.log(8e9))) self._encoder = Encoder(encoder_name, max_seq_len=self._spec.max_seq_len) head = build_head(self._spec) head.load_state_dict(torch.load(self.artifact_dir / "head.pt", map_location="cpu")) head.to(self._encoder.device).eval() self._head = head cal_path = self.artifact_dir / "calibration.json" if cal_path.exists(): self._temperature = float(json.loads(cal_path.read_text()).get("temperature", 1.0)) ood_path = self.artifact_dir / "ood_stats.npz" if ood_path.exists(): data = np.load(ood_path) if "centroid" in data.files and "reference" in data.files: self._ood_stats = { "centroid": data["centroid"], "reference": data["reference"], "k": int(data["k"]) if "k" in data.files else 5, } self._ood_thresholds = { "centroid_threshold": float(data["centroid_threshold"]), "knn_threshold": float(data["knn_threshold"]), } self._loaded = True def predict(self, query: str) -> QueryProfile: import torch import torch.nn.functional as F self._ensure_loaded() text = (query or "").strip() emb = self._encoder.embed([text]) with torch.no_grad(): out = self._head(emb) cap_logits = (out["cap_logits"] / max(self._temperature, 1e-3)) cap_probs = torch.sigmoid(cap_logits)[0].cpu().numpy().tolist() cap_dict = {k: float(v) for k, v in zip(CAPABILITY_KEYS, cap_probs)} diff_centered = float(out["diff"][0].item()) diff_log_params = diff_centered + self._diff_center len_probs = F.softmax(out["len_logits"][0], dim=-1).cpu().numpy().tolist() length_dist = {b: float(p) for b, p in zip(LENGTH_BUCKETS, len_probs)} confidence = max(cap_dict.values()) if cap_dict else 0.0 confidence_ood = confidence < self._ood_min_confidence geometric_ood = False if self._ood_stats is not None and self._ood_thresholds is not None: emb_np = emb[0].cpu().numpy() geometric_ood = is_ood(emb_np, self._ood_stats, self._ood_thresholds) ood_flag = confidence_ood or geometric_ood in_tokens = max(1, int(len(text.split()) * 1.3) + 4) out_p50 = int(round(sum(length_dist[b] * LENGTH_TOKEN_TARGETS[b] for b in LENGTH_BUCKETS))) long_w = length_dist.get("long", 0.0) out_p90 = int(round(out_p50 * LENGTH_P90_MULTIPLIER + long_w * LENGTH_TOKEN_TARGETS["long"] * 0.3)) return QueryProfile( capabilities=CapabilityProfile(**cap_dict), difficulty_log_params=diff_log_params, length_dist=length_dist, expected_input_tokens=in_tokens, expected_output_tokens_p50=out_p50, expected_output_tokens_p90=out_p90, confidence=confidence, is_ood=ood_flag, raw_query=text, debug={ "source": "trained", "temperature": self._temperature, "confidence_ood": bool(confidence_ood), "geometric_ood": bool(geometric_ood), }, )