Spaces:
Sleeping
Sleeping
File size: 5,665 Bytes
"""Maps the GreenRouting classifier output to the partner's response schema.
Inputs:
- QueryProfile from greenrouting.classifier (8 capability probabilities,
continuous difficulty in log-parameters, length distribution)
- PartnerRegistry of candidate models (tier + per-category 1-10 scores + cost)
Outputs:
- capability_weights: dict[7-key partner schema -> float in 0..1]
- category: argmax over the 5-category public set
- complexity: simple|moderate|complex
- difficulty: integer 1..5
- chosen model_id from the registry
- energy_savings_pct vs an always-ultra-tier baseline
- reason string for the partner's audit log
"""
from __future__ import annotations
import math
from typing import Optional
from greenrouting.classifier.infer import QueryProfile
from partner_registry import PARTNER_SCORE_KEYS, PartnerModel, PartnerRegistry
# The five public category labels; these are the values pick_category can return.
PUBLIC_CATEGORIES: tuple[str, ...] = ("chat", "code", "math", "research", "creative")
# Complexity labels, in ascending order; these are the values pick_complexity can return.
COMPLEXITY_BUCKETS: tuple[str, ...] = ("simple", "moderate", "complex")
# Default baseline cost for energy_savings_pct (the module docstring describes
# the baseline as "always-ultra-tier").
ULTRA_BASELINE_COST: int = 10
def rebucket_capabilities(profile: QueryProfile) -> dict[str, float]:
    """Collapse the classifier's capability scores into the partner's seven score keys.

    Reads ``profile.capabilities`` and returns a dict keyed by ``coding``,
    ``math``, ``research``, ``creative``, ``chat``, ``roleplay`` and ``ideas``.
    Combined buckets are capped at 1.0 and every value is rounded to three
    decimal places.
    """
    caps = profile.capabilities

    def capped(value: float) -> float:
        # Summed buckets may exceed 1.0; cap them at 1.0.
        return min(1.0, value)

    raw_scores = {
        "coding": caps.code,
        "math": caps.math,
        # Research blends reasoning with knowledge.
        "research": capped(caps.reasoning + caps.knowledge),
        "creative": caps.creative,
        # Chat blends small talk with instruction following.
        "chat": capped(caps.simple_chat + caps.instruction),
        # Roleplay is taken as half of the creative signal.
        "roleplay": caps.creative * 0.5,
        "ideas": capped((caps.creative + caps.reasoning) * 0.4),
    }
    return {name: round(score, 3) for name, score in raw_scores.items()}
def pick_category(weights: dict[str, float]) -> str:
    """Return the public category with the highest weight.

    Only the five public keys are considered; any other key in ``weights``
    is ignored. Ties go to the earliest key in the order chat, coding, math,
    research, creative. The ``coding`` weight key is reported as ``code``.

    Raises:
        KeyError: if one of the five public keys is missing from ``weights``.
    """
    # Weight key -> public label; insertion order doubles as the tie-break order.
    public_labels = {
        "chat": "chat",
        "coding": "code",
        "math": "math",
        "research": "research",
        "creative": "creative",
    }
    winner = max(public_labels, key=lambda name: weights[name])
    return public_labels[winner]
def pick_complexity(profile: QueryProfile) -> str:
    """Bucket ``profile.difficulty_log_params`` into simple / moderate / complex.

    The value is compared against the natural log of 3e9 and 20e9; a value
    exactly on a threshold falls into the higher bucket.
    """
    difficulty = profile.difficulty_log_params
    # (label, exclusive upper bound in raw parameter count), ascending.
    ceilings = (("simple", 3e9), ("moderate", 20e9))
    for label, upper in ceilings:
        if difficulty < math.log(upper):
            return label
    return "complex"
def pick_difficulty_int(profile: QueryProfile) -> int:
    """Convert ``profile.difficulty_log_params`` into an integer rank from 1 to 5.

    The rank starts at 1 and goes up by one for each boundary (natural log of
    1e9, 5e9, 15e9 and 50e9) that the value reaches or exceeds.
    """
    difficulty = profile.difficulty_log_params
    # The boundaries are ascending, so counting how many are reached equals
    # walking them in order and stopping at the first one not reached.
    crossed = sum(difficulty >= math.log(billions * 1e9) for billions in (1, 5, 15, 50))
    return min(5, max(1, 1 + crossed))
def _allowed_tiers(difficulty: int) -> set[str]:
    """Return the set of tier names permitted for a difficulty rank.

    Ranks of 1 or below and rank 2 both allow lite/standard, rank 3 allows
    standard/pro, rank 4 allows pro/ultra, and anything else allows only ultra.
    """
    if difficulty <= 1:
        return {"lite", "standard"}
    tiers_by_rank = {
        2: {"lite", "standard"},
        3: {"standard", "pro"},
        4: {"pro", "ultra"},
    }
    # Any rank not listed above falls through to the ultra tier alone.
    return tiers_by_rank.get(difficulty, {"ultra"})
def quality_fit(model: PartnerModel, weights: dict[str, float]) -> float:
    """Return the weight-normalised score of ``model`` for the given weights.

    For each key in ``PARTNER_SCORE_KEYS`` the model's score (0 when the model
    has no entry for that key) is divided by 10 and multiplied by the key's
    weight. The sum of those products is divided by the sum of the weights.
    """
    keys = PARTNER_SCORE_KEYS
    denominator = sum(weights[key] for key in keys)
    if not denominator:
        # All-zero weights: divide by 1.0 so the result is 0 instead of an error.
        denominator = 1.0
    numerator = sum(
        weights[key] * (model.scores.get(key, 0) / 10.0) for key in keys
    )
    return numerator / denominator
def _best_ultra(registry: PartnerRegistry, weights: dict[str, float]) -> PartnerModel:
    """Return the best-fitting ultra-tier model, or the best model overall.

    If the registry has no ultra-tier models, every registered model is
    considered instead. Fit is measured with ``quality_fit``.
    """
    candidates = registry.by_tier("ultra") or registry.models

    def fit(model: PartnerModel) -> float:
        return quality_fit(model, weights)

    return max(candidates, key=fit)
def select_model(
    registry: PartnerRegistry,
    weights: dict[str, float],
    difficulty: int,
    is_ood: bool = False,
    quality_floor_ratio: float = 0.65,
) -> tuple[PartnerModel, bool]:
    """Choose a model from the registry and report whether we escalated.

    Args:
        registry: candidate models; must contain at least one model.
        weights: capability weights passed to ``quality_fit``.
        difficulty: difficulty rank that determines the allowed tiers.
        is_ood: when true, skip tier selection and escalate immediately.
        quality_floor_ratio: fraction of the best allowed model's fit that a
            model must reach to qualify.

    Returns:
        ``(chosen_model, escalated)``. ``escalated`` is true when the choice
        came from the ultra-tier fallback, which happens when ``is_ood`` is
        set or no model exists in the allowed tiers.

    Raises:
        ValueError: if the registry has no models.
    """
    if not registry.models:
        raise ValueError("partner registry is empty")

    def fit(model: PartnerModel) -> float:
        return quality_fit(model, weights)

    # Out-of-distribution input never consults the tier table.
    candidates = None if is_ood else registry.by_tier(*_allowed_tiers(difficulty))
    if not candidates:
        return _best_ultra(registry, weights), True

    strongest = max(candidates, key=fit)
    threshold = fit(strongest) * quality_floor_ratio
    shortlist = [model for model in candidates if fit(model) >= threshold]
    if shortlist:
        # Cheapest qualifying model wins; equal costs go to the better fit.
        cheapest = min(shortlist, key=lambda model: (model.cost, -fit(model)))
        return cheapest, False
    return strongest, False
def energy_savings_pct(chosen: PartnerModel, baseline_cost: int = ULTRA_BASELINE_COST) -> float:
    """Return the percentage of cost saved by ``chosen`` relative to the baseline.

    The result is clamped to the range 0..100. A baseline cost of zero or
    less yields 0.0.
    """
    if baseline_cost <= 0:
        return 0.0
    fraction_saved = (baseline_cost - chosen.cost) / baseline_cost
    # Clamp so a model costlier than the baseline reports 0, never a negative saving.
    bounded = max(0.0, min(1.0, fraction_saved))
    return bounded * 100.0
def build_reason(
    weights: dict[str, float],
    complexity: str,
    chosen: PartnerModel,
    escalated: bool,
    is_ood: bool = False,
) -> str:
    """Build the comma-separated explanation string for a routing decision.

    Args:
        weights: capability weights; the highest-weighted key is reported as
            dominant when its weight is at least 0.5.
        complexity: complexity label included in the text.
        chosen: selected model; its ``tier`` and ``cost`` are reported when
            the decision was not escalated.
        escalated: whether the selection fell back to the ultra tier.
        is_ood: whether the input was flagged as low-confidence.

    Raises:
        ValueError: if ``weights`` is empty.
    """
    top_cap = max(weights, key=weights.get)
    top_score = weights[top_cap]

    if is_ood:
        # Low-confidence input gets one fixed lead and no complexity note.
        parts = ["low-confidence input (escalated to ultra tier)"]
    else:
        if top_score >= 0.5:
            lead = f"{top_cap} dominant ({top_score:.2f})"
        else:
            lead = "mixed signal"
        parts = [lead, f"{complexity} difficulty"]

    if not escalated:
        parts.append(f"picked {chosen.tier} tier (cost {chosen.cost})")
    elif not is_ood:
        parts.append("escalated (no qualifying tier-allowed model)")
    return ", ".join(parts)
def fold_recent_context(message: str, recent: Optional[list[dict]]) -> str:
    """Prefix ``message`` with the start of the most recent prior turn.

    Args:
        message: the current message text.
        recent: earlier turns, oldest first; only the last entry is used. It
            should be a dict with a string ``"content"`` value.

    Returns:
        The first 200 characters of the last turn's content, a newline, then
        ``message``. ``message`` is returned unchanged when there is no usable
        prior content: ``recent`` is empty or ``None``, the last entry is not
        a dict, or its content is missing, empty or not a string.
    """
    if not recent:
        return message
    last = recent[-1]
    if not isinstance(last, dict):
        return message
    content = last.get("content")
    # Only plain text can be folded in. Non-string content (an int, a list of
    # parts) would raise on slicing or splice its repr into the text, so it
    # is treated the same as missing content.
    if not isinstance(content, str) or not content:
        return message
    return f"{content[:200]}\n{message}"
|