"""Maps the GreenRouting classifier output to the partner's response schema. Inputs: - QueryProfile from greenrouting.classifier (8 capability probabilities, continuous difficulty in log-parameters, length distribution) - PartnerRegistry of candidate models (tier + per-category 1-10 scores + cost) Outputs: - capability_weights: dict[7-key partner schema -> float in 0..1] - category: argmax over the 5-category public set - complexity: simple|moderate|complex - difficulty: integer 1..5 - chosen model_id from the registry - energy_savings_pct vs an always-ultra-tier baseline - reason string for the partner's audit log """ from __future__ import annotations import math from typing import Optional from greenrouting.classifier.infer import QueryProfile from partner_registry import PARTNER_SCORE_KEYS, PartnerModel, PartnerRegistry PUBLIC_CATEGORIES: tuple[str, ...] = ("chat", "code", "math", "research", "creative") COMPLEXITY_BUCKETS: tuple[str, ...] = ("simple", "moderate", "complex") ULTRA_BASELINE_COST: int = 10 def rebucket_capabilities(profile: QueryProfile) -> dict[str, float]: """Map our 8 internal capabilities to the partner's 7 score categories.""" c = profile.capabilities coding = c.code math_ = c.math research = min(1.0, c.reasoning + c.knowledge) creative = c.creative chat = min(1.0, c.simple_chat + c.instruction) roleplay = c.creative * 0.5 ideas = min(1.0, (c.creative + c.reasoning) * 0.4) return { "coding": round(coding, 3), "math": round(math_, 3), "research": round(research, 3), "creative": round(creative, 3), "chat": round(chat, 3), "roleplay": round(roleplay, 3), "ideas": round(ideas, 3), } def pick_category(weights: dict[str, float]) -> str: public = {k: weights[k] for k in ("chat", "coding", "math", "research", "creative")} top = max(public, key=public.get) if top == "coding": return "code" return top def pick_complexity(profile: QueryProfile) -> str: log_p = profile.difficulty_log_params if log_p < math.log(3e9): return "simple" if log_p < math.log(20e9): return "moderate" return "complex" def pick_difficulty_int(profile: QueryProfile) -> int: log_p = profile.difficulty_log_params boundaries = [math.log(b * 1e9) for b in (1, 5, 15, 50)] rank = 1 for b in boundaries: if log_p >= b: rank += 1 else: break return min(5, max(1, rank)) def _allowed_tiers(difficulty: int) -> set[str]: if difficulty <= 1: return {"lite", "standard"} if difficulty == 2: return {"lite", "standard"} if difficulty == 3: return {"standard", "pro"} if difficulty == 4: return {"pro", "ultra"} return {"ultra"} def quality_fit(model: PartnerModel, weights: dict[str, float]) -> float: total_weight = sum(weights[k] for k in PARTNER_SCORE_KEYS) or 1.0 weighted = sum(weights[k] * (model.scores.get(k, 0) / 10.0) for k in PARTNER_SCORE_KEYS) return weighted / total_weight def _best_ultra(registry: PartnerRegistry, weights: dict[str, float]) -> PartnerModel: ultras = registry.by_tier("ultra") pool = ultras if ultras else registry.models return max(pool, key=lambda m: quality_fit(m, weights)) def select_model( registry: PartnerRegistry, weights: dict[str, float], difficulty: int, is_ood: bool = False, quality_floor_ratio: float = 0.65, ) -> tuple[PartnerModel, bool]: """Returns (chosen_model, escalated). Escalated means we fell back to the ultra-tier anchor (low confidence in the prediction).""" if not registry.models: raise ValueError("partner registry is empty") if is_ood: return _best_ultra(registry, weights), True allowed = registry.by_tier(*_allowed_tiers(difficulty)) if not allowed: return _best_ultra(registry, weights), True best_allowed = max(allowed, key=lambda m: quality_fit(m, weights)) floor = quality_fit(best_allowed, weights) * quality_floor_ratio qualifying = [m for m in allowed if quality_fit(m, weights) >= floor] if not qualifying: return best_allowed, False chosen = min(qualifying, key=lambda m: (m.cost, -quality_fit(m, weights))) return chosen, False def energy_savings_pct(chosen: PartnerModel, baseline_cost: int = ULTRA_BASELINE_COST) -> float: if baseline_cost <= 0: return 0.0 saved = (baseline_cost - chosen.cost) / baseline_cost return max(0.0, min(1.0, saved)) * 100.0 def build_reason( weights: dict[str, float], complexity: str, chosen: PartnerModel, escalated: bool, is_ood: bool = False, ) -> str: top_cap, top_score = max(weights.items(), key=lambda kv: kv[1]) bits: list[str] = [] if is_ood: bits.append("low-confidence input (escalated to ultra tier)") elif top_score >= 0.5: bits.append(f"{top_cap} dominant ({top_score:.2f})") else: bits.append("mixed signal") if not is_ood: bits.append(f"{complexity} difficulty") if escalated and not is_ood: bits.append("escalated (no qualifying tier-allowed model)") elif not escalated: bits.append(f"picked {chosen.tier} tier (cost {chosen.cost})") return ", ".join(bits) def fold_recent_context(message: str, recent: Optional[list[dict]]) -> str: if not recent: return message last = recent[-1] content = (last.get("content") or "")[:200] if isinstance(last, dict) else "" if not content: return message return f"{content}\n{message}"