router-api / greenrouting /data /capability_labeler.py
spectralman's picture
Initial deploy: classifier + FastAPI router
6f0ff99 verified
"""Capability labeling: turns RawQuery records into multi-label CapabilityLabel records.
Aggregates a source prior with up to four independent voters:
- source_prior (always available; derived from source category; applied as a weighted prior)
- heuristic (always available; deterministic keyword/regex rules)
- gpt (optional, requires OPENAI_API_KEY; currently calls gpt-4o-mini)
- claude (optional, requires ANTHROPIC_API_KEY; currently calls claude-haiku-4-5)
- gemini (optional, requires GOOGLE_API_KEY; currently calls gemini-1.5-flash)
Designed to run once during dataset prep. The output is committed as a parquet so
downstream training does not depend on API access.
"""
from __future__ import annotations
import json
import os
import re
import time
from dataclasses import dataclass
from typing import Iterable, Optional
from greenrouting.data.schema import CapabilityLabel, CapabilityVotes, RawQuery
from greenrouting.routing.registry import CAPABILITY_KEYS
# Maps a source category to the capability labels that category implies.
# Every known category currently implies exactly the capability of the same name.
CATEGORY_TO_LABELS: dict[str, list[str]] = {
    category: [category]
    for category in (
        "code",
        "math",
        "reasoning",
        "knowledge",
        "instruction",
        "creative",
        "multilingual",
        "simple_chat",
    )
}
def source_prior_vote(category: str) -> dict[str, float]:
    """Build the prior vote implied by a query's source category.

    Capabilities listed for ``category`` in ``CATEGORY_TO_LABELS`` score 1.0 and
    every other key in ``CAPABILITY_KEYS`` scores 0.0. A category that is not in
    the mapping produces an all-zero vote.
    """
    implied = CATEGORY_TO_LABELS.get(category, [])
    vote: dict[str, float] = {}
    for capability in CAPABILITY_KEYS:
        vote[capability] = 1.0 if capability in implied else 0.0
    return vote
_HEURISTIC_PATTERNS: dict[str, list[str]] = {
"code": [
r"\b(code|function|class|def |algorithm|implement|debug|stack trace|api|sdk)\b",
r"\b(python|javascript|typescript|rust|go|c\+\+|java|sql|html|css)\b",
r"\b(refactor|unit test|regex|linter|compile|recursion)\b",
r"```",
],
"math": [
r"\b(calculate|compute|solve|equation|integral|derivative|matrix|vector|theorem|prove)\b",
r"\b(probability|sum|product|mean|median|variance|standard deviation|percentage)\b",
r"\d+\s*[+\-*/×÷=]\s*\d+",
r"\b(arithmetic|fraction|geometry|algebra|trig)\b",
],
"reasoning": [
r"\b(why|how does|explain|reason|because|therefore|argue|justify|implication)\b",
r"\b(compare|contrast|analyze|evaluate|trade.?off|infer|deduce)\b",
],
"knowledge": [
r"\b(who|what is|when did|where is|history|definition|capital|founded|named)\b",
r"\b(country|continent|invented|discovered|president|prime minister)\b",
],
"instruction": [
r"\b(write|draft|create|generate|produce|format|list|outline|step.?by.?step|summarize)\b",
r"\b(rewrite|translate from|convert to|extract)\b",
],
"creative": [
r"\b(story|poem|novel|character|plot|scene|metaphor|fictional|haiku|song lyric)\b",
r"\b(write a (?:short )?(?:story|poem|haiku|song))\b",
],
"multilingual": [
r"\b(translate|translation|en español|en français|auf deutsch|на русском|中文|日本語|한국어)\b",
r"[Ѐ-ӿ一-鿿぀-ゟ゠-ヿ가-힣]",
],
"simple_chat": [
r"^\s*(hi|hello|hey|thanks|thank you|good morning|good evening|sup|yo)\b",
],
}
def heuristic_vote(text: str) -> dict[str, float]:
    """Vote on capabilities using the keyword/regex rules in ``_HEURISTIC_PATTERNS``.

    A capability scores 1.0 when any of its patterns matches ``text``
    (case-insensitive, multiline). When nothing matches at all, the vote falls
    back to ``simple_chat`` for texts shorter than 80 stripped characters and
    to ``instruction`` otherwise.
    """
    search_flags = re.IGNORECASE | re.MULTILINE
    scores = dict.fromkeys(CAPABILITY_KEYS, 0.0)
    for capability, patterns in _HEURISTIC_PATTERNS.items():
        if any(re.search(pattern, text, flags=search_flags) for pattern in patterns):
            scores[capability] = 1.0

    if any(score != 0.0 for score in scores.values()):
        return scores

    # No rule fired: guess from length alone.
    fallback = "simple_chat" if len(text.strip()) < 80 else "instruction"
    scores[fallback] = 1.0
    return scores
_LABELER_SYSTEM_PROMPT = (
"You are labeling AI queries by which capabilities they require. "
"Capabilities: code, math, reasoning, knowledge, instruction, creative, multilingual, "
"simple_chat. A query can require multiple capabilities. "
"Reply with strict JSON only, in the form: "
'{"code": 0|1, "math": 0|1, "reasoning": 0|1, "knowledge": 0|1, '
'"instruction": 0|1, "creative": 0|1, "multilingual": 0|1, "simple_chat": 0|1}.'
)
def _user_prompt(query: str) -> str:
return f"Query:\n{query}\n\nRespond with JSON only."
def _parse_vote(raw: str) -> dict[str, float]:
    """Convert an LLM reply into a 0.0/1.0 vote for every capability key.

    The JSON object is located with ``_extract_json`` and decoded. Each key in
    ``CAPABILITY_KEYS`` scores 1.0 when its value is truthy and 0.0 otherwise;
    missing keys score 0.0. String values that spell a negative ("0", "false",
    "no", ...) are treated as 0.0 rather than as truthy non-empty strings.

    Any reply that cannot be decoded, or that decodes to something other than
    a JSON object (list, number, null, ...), yields an all-zero vote instead of
    raising.
    """
    try:
        data = json.loads(_extract_json(raw))
    except Exception:  # best-effort: an unusable reply must not abort labeling
        data = None
    if not isinstance(data, dict):
        # Valid JSON that is not an object has no per-capability flags.
        return {k: 0.0 for k in CAPABILITY_KEYS}

    negative_strings = {"", "0", "0.0", "false", "no"}

    def _as_flag(value: object) -> float:
        if isinstance(value, str):
            return 0.0 if value.strip().lower() in negative_strings else 1.0
        return 1.0 if value else 0.0

    return {k: _as_flag(data.get(k)) for k in CAPABILITY_KEYS}
def _extract_json(text: str) -> str:
match = re.search(r"\{.*\}", text, flags=re.DOTALL)
return match.group(0) if match else text
def _gpt_vote(text: str) -> Optional[dict[str, float]]:
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
return None
try:
from openai import OpenAI
except ImportError:
return None
client = OpenAI(api_key=api_key)
resp = client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{"role": "system", "content": _LABELER_SYSTEM_PROMPT},
{"role": "user", "content": _user_prompt(text)},
],
temperature=0,
response_format={"type": "json_object"},
)
return _parse_vote(resp.choices[0].message.content or "{}")
def _claude_vote(text: str) -> Optional[dict[str, float]]:
api_key = os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
return None
try:
import anthropic
except ImportError:
return None
client = anthropic.Anthropic(api_key=api_key)
resp = client.messages.create(
model="claude-haiku-4-5",
max_tokens=200,
system=_LABELER_SYSTEM_PROMPT,
messages=[{"role": "user", "content": _user_prompt(text)}],
)
body = "".join(b.text for b in resp.content if getattr(b, "type", "") == "text")
return _parse_vote(body)
def _gemini_vote(text: str) -> Optional[dict[str, float]]:
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
return None
try:
import google.generativeai as genai
except ImportError:
return None
genai.configure(api_key=api_key)
model = genai.GenerativeModel("gemini-1.5-flash", system_instruction=_LABELER_SYSTEM_PROMPT)
resp = model.generate_content(_user_prompt(text))
return _parse_vote(resp.text or "{}")
@dataclass
class LabelerConfig:
    """Switches and weights that control how a query is labeled."""

    # Which voters to run. The vendor voters are opt-in; each one returns no
    # vote when its API key or client package is unavailable.
    use_heuristic: bool = True
    use_gpt: bool = False
    use_claude: bool = False
    use_gemini: bool = False

    # Weight of the source-category prior in the aggregate (each voter counts 1).
    source_prior_weight: float = 0.5

    # Pause, in seconds, applied after each enabled vendor voter.
    sleep_between_calls_s: float = 0.0
def aggregate_votes(votes: CapabilityVotes, source_prior_weight: float = 0.5) -> dict[str, float]:
    """Combine the source prior and the available voters into one score per capability.

    Each present voter (heuristic, gpt, claude, gemini) counts with weight 1
    and the source prior with ``source_prior_weight``; the score for a
    capability is the weighted mean of those votes. Missing keys in any vote
    count as 0.0.

    With no voters at all the source prior is returned as-is (copied), or an
    all-zero vote when the prior is missing or empty. A missing prior is also
    tolerated when voters are present: it contributes 0.0 to every capability
    while its weight stays in the denominator, exactly as an empty prior does.
    """
    # Normalise once so both branches treat a missing prior the same way.
    prior = votes.source_prior or {}
    voters = [v for v in (votes.heuristic, votes.gpt, votes.claude, votes.gemini) if v is not None]
    if not voters:
        return dict(prior) if prior else {k: 0.0 for k in CAPABILITY_KEYS}

    total_weight = source_prior_weight + len(voters)
    result: dict[str, float] = {}
    for cap in CAPABILITY_KEYS:
        prior_term = source_prior_weight * float(prior.get(cap, 0.0))
        voter_sum = sum(float(v.get(cap, 0.0)) for v in voters)
        result[cap] = (prior_term + voter_sum) / total_weight
    return result
def label_query(query: RawQuery, config: LabelerConfig) -> CapabilityLabel:
    """Label one query by collecting the enabled votes and aggregating them.

    The source prior is always computed from ``query.source_category``. The
    heuristic and vendor voters run only when enabled in ``config``. After a
    vendor voter actually returns a vote, the function pauses for
    ``config.sleep_between_calls_s`` seconds; no pause is taken when the voter
    returned ``None`` (no API key / client package, so no call was made) or
    when the configured pause is not positive.

    The returned label records the aggregated scores, the raw votes, and an
    ``aggregation_method`` string naming the voters that produced a vote
    joined by ``+`` (``"source_prior_only"`` when none did).
    """

    def _pause_after(vote: Optional[dict[str, float]]) -> None:
        # Throttle only after a real API round-trip; guard against negative
        # values, which time.sleep rejects.
        if vote is not None and config.sleep_between_calls_s > 0:
            time.sleep(config.sleep_between_calls_s)

    votes = CapabilityVotes(source_prior=source_prior_vote(query.source_category))
    if config.use_heuristic:
        votes.heuristic = heuristic_vote(query.text)
    if config.use_gpt:
        gpt_vote = _gpt_vote(query.text)
        votes.gpt = gpt_vote
        _pause_after(gpt_vote)
    if config.use_claude:
        claude_vote = _claude_vote(query.text)
        votes.claude = claude_vote
        _pause_after(claude_vote)
    if config.use_gemini:
        gemini_vote = _gemini_vote(query.text)
        votes.gemini = gemini_vote
        _pause_after(gemini_vote)

    aggregated = aggregate_votes(votes, source_prior_weight=config.source_prior_weight)

    contributors = [
        name
        for name, vote in (
            ("heuristic", votes.heuristic),
            ("gpt", votes.gpt),
            ("claude", votes.claude),
            ("gemini", votes.gemini),
        )
        if vote is not None
    ]
    method = "+".join(contributors) or "source_prior_only"

    return CapabilityLabel(
        query_id=query.id,
        capabilities=aggregated,
        votes=votes,
        aggregation_method=method,
    )
def label_queries(queries: Iterable[RawQuery], config: LabelerConfig) -> list[CapabilityLabel]:
    """Label every query in ``queries`` with ``label_query``, preserving input order."""
    labels: list[CapabilityLabel] = []
    for query in queries:
        labels.append(label_query(query, config))
    return labels