Spaces:
Sleeping
Sleeping
| """Maps the GreenRouting classifier output to the partner's response schema. | |
| Inputs: | |
| - QueryProfile from greenrouting.classifier (8 capability probabilities, | |
| continuous difficulty in log-parameters, length distribution) | |
| - PartnerRegistry of candidate models (tier + per-category 1-10 scores + cost) | |
| Outputs: | |
| - capability_weights: dict[7-key partner schema -> float in 0..1] | |
| - category: argmax over the 5-category public set | |
| - complexity: simple|moderate|complex | |
| - difficulty: integer 1..5 | |
| - chosen model_id from the registry | |
| - energy_savings_pct vs an always-ultra-tier baseline | |
| - reason string for the partner's audit log | |
| """ | |
| from __future__ import annotations | |
| import math | |
| from typing import Optional | |
| from greenrouting.classifier.infer import QueryProfile | |
| from partner_registry import PARTNER_SCORE_KEYS, PartnerModel, PartnerRegistry | |
# Category labels exposed to the partner. The module docstring describes the
# reported category as an argmax over this five-item public set.
| PUBLIC_CATEGORIES: tuple[str, ...] = ("chat", "code", "math", "research", "creative") | |
# Complexity labels in ascending order; these are the three strings that
# pick_complexity can return.
| COMPLEXITY_BUCKETS: tuple[str, ...] = ("simple", "moderate", "complex") | |
# Default baseline cost for energy_savings_pct (the "always-ultra-tier"
# reference point). NOTE(review): presumably on the same scale as
# PartnerModel.cost -- confirm against the registry.
| ULTRA_BASELINE_COST: int = 10 | |
def rebucket_capabilities(profile: QueryProfile) -> dict[str, float]:
    """Collapse the classifier's capability probabilities into partner score keys.

    Reads ``profile.capabilities`` and returns a dict with the keys
    ``coding``, ``math``, ``research``, ``creative``, ``chat``, ``roleplay``
    and ``ideas`` (in that order). Every value is rounded to three decimals.
    """
    caps = profile.capabilities
    raw_weights = {
        "coding": caps.code,
        "math": caps.math,
        # Merged buckets are summed and capped at 1.0.
        "research": min(1.0, caps.reasoning + caps.knowledge),
        "creative": caps.creative,
        "chat": min(1.0, caps.simple_chat + caps.instruction),
        # Derived buckets: scaled-down blends of existing signals.
        "roleplay": caps.creative * 0.5,
        "ideas": min(1.0, (caps.creative + caps.reasoning) * 0.4),
    }
    return {key: round(value, 3) for key, value in raw_weights.items()}
def pick_category(weights: dict[str, float]) -> str:
    """Return the public category label with the largest weight.

    Only the ``chat``, ``coding``, ``math``, ``research`` and ``creative``
    weights are considered; ``coding`` is reported as ``"code"``. On a tie
    the earliest key in that order wins.
    """
    key_to_label = (
        ("chat", "chat"),
        ("coding", "code"),
        ("math", "math"),
        ("research", "research"),
        ("creative", "creative"),
    )
    # Look every weight up first so a missing key fails before any comparison.
    scored = [(weights[key], label) for key, label in key_to_label]
    best_value, best_label = scored[0]
    for value, label in scored[1:]:
        # Strictly greater keeps the first maximum, matching max().
        if value > best_value:
            best_value, best_label = value, label
    return best_label
def pick_complexity(profile: QueryProfile) -> str:
    """Bucket ``profile.difficulty_log_params`` into a complexity label.

    Values below ``log(3e9)`` are ``"simple"``, values below ``log(20e9)``
    are ``"moderate"``, and everything else is ``"complex"``.
    """
    difficulty = profile.difficulty_log_params
    ceilings = ((3e9, "simple"), (20e9, "moderate"))
    for param_count, label in ceilings:
        if difficulty < math.log(param_count):
            return label
    return "complex"
def pick_difficulty_int(profile: QueryProfile) -> int:
    """Convert ``profile.difficulty_log_params`` into an integer rank from 1 to 5.

    The rank is 1 plus the number of thresholds (natural log of 1e9, 5e9,
    15e9 and 50e9) that the value meets or exceeds.
    """
    score = profile.difficulty_log_params
    thresholds = [math.log(billions * 1e9) for billions in (1, 5, 15, 50)]
    # Thresholds ascend, so counting the ones met equals walking them in
    # order and stopping at the first miss. The result is always within 1..5.
    met = sum(1 for threshold in thresholds if score >= threshold)
    return 1 + met
def _allowed_tiers(difficulty: int) -> set[str]:
    """Return the tier names a query of the given difficulty rank may use.

    Ranks up to 1, and rank 2, allow lite/standard; rank 3 allows
    standard/pro; rank 4 allows pro/ultra; anything else allows ultra only.
    """
    # Ranks 1 and 2 share the same cheap tiers.
    if difficulty <= 1 or difficulty == 2:
        return {"lite", "standard"}
    if difficulty == 3:
        return {"standard", "pro"}
    return {"pro", "ultra"} if difficulty == 4 else {"ultra"}
def quality_fit(model: PartnerModel, weights: dict[str, float]) -> float:
    """Return the weight-averaged score of ``model`` over ``PARTNER_SCORE_KEYS``.

    Each ``model.scores`` entry is divided by 10 before weighting, and a key
    missing from ``model.scores`` counts as 0. If the weights sum to zero the
    divisor falls back to 1.0, so the result is 0 rather than an error.
    """
    keys = PARTNER_SCORE_KEYS
    denominator = sum(weights[key] for key in keys)
    if not denominator:
        # All-zero weights: avoid dividing by zero.
        denominator = 1.0
    numerator = sum(weights[key] * (model.scores.get(key, 0) / 10.0) for key in keys)
    return numerator / denominator
def _best_ultra(registry: PartnerRegistry, weights: dict[str, float]) -> PartnerModel:
    """Return the best-fitting ultra-tier model for ``weights``.

    If ``registry.by_tier("ultra")`` is empty, every model in
    ``registry.models`` is considered instead. Fit is measured with
    ``quality_fit``.
    """
    candidates = registry.by_tier("ultra")
    if not candidates:
        # No ultra tier registered: fall back to the whole registry.
        candidates = registry.models
    return max(candidates, key=lambda candidate: quality_fit(candidate, weights))
def select_model(
    registry: PartnerRegistry,
    weights: dict[str, float],
    difficulty: int,
    is_ood: bool = False,
    quality_floor_ratio: float = 0.65,
) -> tuple[PartnerModel, bool]:
    """Pick the cheapest adequately-fitting model for a query.

    Args:
        registry: Candidate models; must expose ``models`` and ``by_tier``.
        weights: Partner-schema capability weights used by ``quality_fit``.
        difficulty: Integer difficulty rank, mapped to tiers by
            ``_allowed_tiers``.
        is_ood: When true, skip tier selection and escalate immediately.
        quality_floor_ratio: Fraction of the best allowed model's fit that a
            model must reach to be considered.

    Returns:
        ``(chosen_model, escalated)``. ``escalated`` is True when the result
        came from the ultra-tier fallback (``_best_ultra``): either the input
        was flagged out-of-distribution, or no model exists in the allowed
        tiers.

    Raises:
        ValueError: If the registry has no models.
    """
    if not registry.models:
        raise ValueError("partner registry is empty")
    if is_ood:
        return _best_ultra(registry, weights), True

    # _allowed_tiers returns a set of strings, whose iteration order depends
    # on the per-process hash seed. Sorting fixes the argument order so that
    # tie-breaking below cannot vary between runs.
    # NOTE(review): assumes by_tier returns the models whose tier is among
    # the given names -- confirm against PartnerRegistry.
    allowed = registry.by_tier(*sorted(_allowed_tiers(difficulty)))
    if not allowed:
        return _best_ultra(registry, weights), True

    # Score every candidate once instead of re-running quality_fit in each
    # max/filter/min pass.
    scored = [(model, quality_fit(model, weights)) for model in allowed]
    best_allowed, best_fit = max(scored, key=lambda pair: pair[1])
    floor = best_fit * quality_floor_ratio
    qualifying = [pair for pair in scored if pair[1] >= floor]
    if not qualifying:
        # Only reachable when the floor exceeds the best fit (e.g. a ratio
        # above 1); keep the best allowed model rather than escalating.
        return best_allowed, False

    # Cheapest first; among equal costs prefer the better fit.
    chosen, _ = min(qualifying, key=lambda pair: (pair[0].cost, -pair[1]))
    return chosen, False
def energy_savings_pct(chosen: PartnerModel, baseline_cost: int = ULTRA_BASELINE_COST) -> float:
    """Return the percentage of cost saved by ``chosen`` relative to a baseline.

    The result is clamped to the range 0..100. A non-positive
    ``baseline_cost`` yields 0.0.
    """
    if baseline_cost <= 0:
        return 0.0
    fraction = (baseline_cost - chosen.cost) / baseline_cost
    # Clamp to [0, 1]: a model costlier than the baseline saves nothing.
    capped = min(1.0, fraction)
    clamped = max(0.0, capped)
    return clamped * 100.0
def build_reason(
    weights: dict[str, float],
    complexity: str,
    chosen: PartnerModel,
    escalated: bool,
    is_ood: bool = False,
) -> str:
    """Compose the comma-separated explanation string for a routing decision.

    For out-of-distribution input the string carries only the low-confidence
    notice, plus the picked tier if the call was not escalated. Otherwise it
    lists the dominant capability (or "mixed signal" when the top weight is
    below 0.5), the complexity label, and either the escalation notice or the
    picked tier and cost.
    """
    leader, leader_weight = max(weights.items(), key=lambda item: item[1])

    if is_ood:
        parts: list[str] = ["low-confidence input (escalated to ultra tier)"]
    else:
        if leader_weight >= 0.5:
            headline = f"{leader} dominant ({leader_weight:.2f})"
        else:
            headline = "mixed signal"
        parts = [headline, f"{complexity} difficulty"]
        if escalated:
            parts.append("escalated (no qualifying tier-allowed model)")

    if not escalated:
        parts.append(f"picked {chosen.tier} tier (cost {chosen.cost})")
    return ", ".join(parts)
def fold_recent_context(message: str, recent: Optional[list[dict]]) -> str:
    """Prefix ``message`` with a snippet of the most recent context entry.

    Args:
        message: The current user message.
        recent: Prior conversation entries, or None/empty for no context.
            Only the last entry is inspected.

    Returns:
        ``"<snippet>\\n<message>"``, where the snippet is the first 200
        characters of the last entry's ``"content"``. ``message`` is returned
        unchanged when there is no usable text: no entries, a last entry that
        is not a dict, or content that is missing, empty or not a string.
    """
    if not recent:
        return message
    last = recent[-1]
    if not isinstance(last, dict):
        return message
    content = last.get("content")
    # Only non-empty text is folded in. Other truthy payloads (numbers, lists
    # of content parts) used to raise on slicing or leak a repr into the
    # result, so they are now treated as "no context".
    if not isinstance(content, str) or not content:
        return message
    return f"{content[:200]}\n{message}"