Spaces:
Sleeping
Sleeping
| """Inference-time predictor that loads the trained artifact and conforms to | |
| the `Predictor` protocol used by the router.""" | |
| from __future__ import annotations | |
| import json | |
| import math | |
| from pathlib import Path | |
| from typing import Optional | |
| from greenrouting.classifier.infer import ( | |
| CapabilityProfile, | |
| LENGTH_BUCKETS, | |
| LENGTH_TOKEN_TARGETS, | |
| LENGTH_P90_MULTIPLIER, | |
| QueryProfile, | |
| ) | |
| from greenrouting.classifier.model import Encoder, ModelSpec, build_head | |
| from greenrouting.classifier.ood import is_ood | |
| from greenrouting.routing.registry import CAPABILITY_KEYS | |
class TrainedPredictor:
    """Predictor backed by a trained classifier artifact on disk.

    Nothing is read until the first :meth:`predict` call. The artifact
    directory is expected to contain:

    * ``metadata.json`` -- ``embedding_dim``, ``hidden_dim``,
      ``capability_keys``, ``length_buckets`` and optionally ``max_seq_len``
      (default 256) and ``diff_target_center`` (default ``log(8e9)``).
    * ``encoder_name.txt`` -- encoder identifier passed to ``Encoder``.
    * ``head.pt`` -- state dict for the head returned by ``build_head``.
    * ``calibration.json`` (optional) -- ``temperature`` used to scale the
      capability logits.
    * ``ood_stats.npz`` (optional) -- ``centroid`` / ``reference`` arrays,
      optional ``k`` (default 5) and the ``centroid_threshold`` /
      ``knn_threshold`` scalars passed to ``is_ood``.
    """

    def __init__(
        self,
        artifact_dir: str | Path,
        *,
        ood_min_confidence: float = 0.40,
    ):
        """Record where the artifact lives; no files are read here.

        Args:
            artifact_dir: Directory containing the trained artifact files.
            ood_min_confidence: A query whose highest capability probability
                is below this value is flagged as out-of-distribution.
        """
        self.artifact_dir = Path(artifact_dir)
        self._loaded = False
        self._encoder: Optional[Encoder] = None
        self._head = None
        self._spec: Optional[ModelSpec] = None
        self._temperature: float = 1.0
        # Added back to the head's centered difficulty output; replaced from
        # metadata on load.
        self._diff_center: float = math.log(8e9)
        # Output order of the head's capability / length-bucket logits, taken
        # from the artifact metadata on load.
        self._cap_keys: list[str] = []
        self._len_buckets: list[str] = []
        self._ood_stats = None
        self._ood_thresholds = None
        self._ood_min_confidence: float = float(ood_min_confidence)

    @staticmethod
    def _resolve_key_order(artifact_keys, expected_keys, what: str) -> list[str]:
        """Return the artifact's key order after checking it matches ``expected_keys``.

        The head emits one logit per artifact key, in artifact order, so its
        outputs must be labelled with that order, not the registry's.

        Raises:
            ValueError: if the artifact keys contain duplicates or are not
                exactly the expected set.
        """
        keys = list(artifact_keys)
        expected = set(expected_keys)
        if len(keys) != len(set(keys)) or set(keys) != expected:
            raise ValueError(
                f"artifact {what} keys {keys!r} do not match expected "
                f"{sorted(expected)!r}"
            )
        return keys

    def _ensure_loaded(self) -> None:
        """Load metadata, encoder, head, calibration and OOD stats (first call only).

        Raises:
            ValueError: if the artifact's capability keys or length buckets
                do not match the registry's.
        """
        if self._loaded:
            return
        import numpy as np
        import torch

        meta_path = self.artifact_dir / "metadata.json"
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
        encoder_name = (
            (self.artifact_dir / "encoder_name.txt").read_text(encoding="utf-8").strip()
        )
        self._cap_keys = self._resolve_key_order(
            meta["capability_keys"], CAPABILITY_KEYS, "capability"
        )
        self._len_buckets = self._resolve_key_order(
            meta["length_buckets"], LENGTH_BUCKETS, "length bucket"
        )
        self._spec = ModelSpec(
            encoder_name=encoder_name,
            embedding_dim=int(meta["embedding_dim"]),
            hidden_dim=int(meta["hidden_dim"]),
            n_capabilities=len(self._cap_keys),
            n_length_buckets=len(self._len_buckets),
            max_seq_len=int(meta.get("max_seq_len", 256)),
        )
        self._diff_center = float(meta.get("diff_target_center", math.log(8e9)))

        self._encoder = Encoder(encoder_name, max_seq_len=self._spec.max_seq_len)
        head = build_head(self._spec)
        # head.pt is only a state dict, so restrict unpickling to tensors and
        # plain containers instead of arbitrary objects.
        state = torch.load(
            self.artifact_dir / "head.pt", map_location="cpu", weights_only=True
        )
        head.load_state_dict(state)
        head.to(self._encoder.device).eval()
        self._head = head

        cal_path = self.artifact_dir / "calibration.json"
        if cal_path.exists():
            calibration = json.loads(cal_path.read_text(encoding="utf-8"))
            self._temperature = float(calibration.get("temperature", 1.0))

        ood_path = self.artifact_dir / "ood_stats.npz"
        if ood_path.exists():
            # np.load on an .npz keeps the archive open; the context manager
            # closes it. Indexing reads each array fully into memory first.
            with np.load(ood_path) as data:
                files = set(data.files)
                if {"centroid", "reference"} <= files:
                    self._ood_stats = {
                        "centroid": data["centroid"],
                        "reference": data["reference"],
                        "k": int(data["k"]) if "k" in files else 5,
                    }
                    self._ood_thresholds = {
                        "centroid_threshold": float(data["centroid_threshold"]),
                        "knn_threshold": float(data["knn_threshold"]),
                    }
        self._loaded = True

    def predict(self, query: str) -> QueryProfile:
        """Profile ``query`` for the router.

        Args:
            query: Raw user query. ``None`` is treated as the empty string and
                surrounding whitespace is stripped.

        Returns:
            A ``QueryProfile`` with:

            * per-capability probabilities;
            * the predicted difficulty (``difficulty_log_params``);
            * the output-length bucket distribution;
            * input/output token estimates;
            * a confidence score (the highest capability probability);
            * an OOD flag set by either the confidence floor or the
              geometric ``is_ood`` check.
        """
        import torch
        import torch.nn.functional as F

        self._ensure_loaded()
        text = (query or "").strip()

        # Inference only: no autograd bookkeeping for encoder or head.
        with torch.no_grad():
            emb = self._encoder.embed([text])
            out = self._head(emb)

        # Temperature-scaled sigmoid; the floor guards against dividing by ~0.
        cap_logits = out["cap_logits"] / max(self._temperature, 1e-3)
        cap_probs = torch.sigmoid(cap_logits)[0].cpu().numpy().tolist()
        # Label with the artifact's key order, i.e. the order the head emits.
        cap_dict = {k: float(v) for k, v in zip(self._cap_keys, cap_probs)}

        # The head's difficulty output is centered; add the artifact's center back.
        diff_centered = float(out["diff"][0].item())
        diff_log_params = diff_centered + self._diff_center

        len_probs = F.softmax(out["len_logits"][0], dim=-1).cpu().numpy().tolist()
        length_dist = {b: float(p) for b, p in zip(self._len_buckets, len_probs)}

        confidence = max(cap_dict.values()) if cap_dict else 0.0
        confidence_ood = confidence < self._ood_min_confidence
        geometric_ood = False
        if self._ood_stats is not None and self._ood_thresholds is not None:
            emb_np = emb[0].detach().cpu().numpy()
            geometric_ood = is_ood(emb_np, self._ood_stats, self._ood_thresholds)
        ood_flag = confidence_ood or geometric_ood

        # Heuristic input-token estimate from the whitespace word count.
        in_tokens = max(1, int(len(text.split()) * 1.3) + 4)
        # p50: token target averaged over the predicted bucket distribution.
        out_p50 = int(
            round(sum(length_dist[b] * LENGTH_TOKEN_TARGETS[b] for b in LENGTH_BUCKETS))
        )
        # p90: scaled p50 plus extra headroom proportional to P("long").
        long_w = length_dist.get("long", 0.0)
        out_p90 = int(
            round(
                out_p50 * LENGTH_P90_MULTIPLIER
                + long_w * LENGTH_TOKEN_TARGETS["long"] * 0.3
            )
        )
        return QueryProfile(
            capabilities=CapabilityProfile(**cap_dict),
            difficulty_log_params=diff_log_params,
            length_dist=length_dist,
            expected_input_tokens=in_tokens,
            expected_output_tokens_p50=out_p50,
            expected_output_tokens_p90=out_p90,
            confidence=confidence,
            is_ood=ood_flag,
            raw_query=text,
            debug={
                "source": "trained",
                "temperature": self._temperature,
                "confidence_ood": bool(confidence_ood),
                "geometric_ood": bool(geometric_ood),
            },
        )