"""Capability labeling: turns RawQuery records into multi-label CapabilityLabel records. Aggregates up to four independent voters: - source_prior (always available; derived from source category) - heuristic (always available; deterministic keyword/regex rules) - gpt-4o (optional, requires OPENAI_API_KEY) - claude-sonnet (optional, requires ANTHROPIC_API_KEY) - gemini-pro (optional, requires GOOGLE_API_KEY) Designed to run once during dataset prep. The output is committed as a parquet so downstream training does not depend on API access. """ from __future__ import annotations import json import os import re import time from dataclasses import dataclass from typing import Iterable, Optional from greenrouting.data.schema import CapabilityLabel, CapabilityVotes, RawQuery from greenrouting.routing.registry import CAPABILITY_KEYS CATEGORY_TO_LABELS: dict[str, list[str]] = { "code": ["code"], "math": ["math"], "reasoning": ["reasoning"], "knowledge": ["knowledge"], "instruction": ["instruction"], "creative": ["creative"], "multilingual": ["multilingual"], "simple_chat": ["simple_chat"], } def source_prior_vote(category: str) -> dict[str, float]: labels = CATEGORY_TO_LABELS.get(category, []) return {k: (1.0 if k in labels else 0.0) for k in CAPABILITY_KEYS} _HEURISTIC_PATTERNS: dict[str, list[str]] = { "code": [ r"\b(code|function|class|def |algorithm|implement|debug|stack trace|api|sdk)\b", r"\b(python|javascript|typescript|rust|go|c\+\+|java|sql|html|css)\b", r"\b(refactor|unit test|regex|linter|compile|recursion)\b", r"```", ], "math": [ r"\b(calculate|compute|solve|equation|integral|derivative|matrix|vector|theorem|prove)\b", r"\b(probability|sum|product|mean|median|variance|standard deviation|percentage)\b", r"\d+\s*[+\-*/×÷=]\s*\d+", r"\b(arithmetic|fraction|geometry|algebra|trig)\b", ], "reasoning": [ r"\b(why|how does|explain|reason|because|therefore|argue|justify|implication)\b", r"\b(compare|contrast|analyze|evaluate|trade.?off|infer|deduce)\b", ], "knowledge": [ 
r"\b(who|what is|when did|where is|history|definition|capital|founded|named)\b", r"\b(country|continent|invented|discovered|president|prime minister)\b", ], "instruction": [ r"\b(write|draft|create|generate|produce|format|list|outline|step.?by.?step|summarize)\b", r"\b(rewrite|translate from|convert to|extract)\b", ], "creative": [ r"\b(story|poem|novel|character|plot|scene|metaphor|fictional|haiku|song lyric)\b", r"\b(write a (?:short )?(?:story|poem|haiku|song))\b", ], "multilingual": [ r"\b(translate|translation|en español|en français|auf deutsch|на русском|中文|日本語|한국어)\b", r"[Ѐ-ӿ一-鿿぀-ゟ゠-ヿ가-힣]", ], "simple_chat": [ r"^\s*(hi|hello|hey|thanks|thank you|good morning|good evening|sup|yo)\b", ], } def heuristic_vote(text: str) -> dict[str, float]: out = {k: 0.0 for k in CAPABILITY_KEYS} for cap, patterns in _HEURISTIC_PATTERNS.items(): for pat in patterns: if re.search(pat, text, flags=re.IGNORECASE | re.MULTILINE): out[cap] = 1.0 break if all(v == 0.0 for v in out.values()): if len(text.strip()) < 80: out["simple_chat"] = 1.0 else: out["instruction"] = 1.0 return out _LABELER_SYSTEM_PROMPT = ( "You are labeling AI queries by which capabilities they require. " "Capabilities: code, math, reasoning, knowledge, instruction, creative, multilingual, " "simple_chat. A query can require multiple capabilities. " "Reply with strict JSON only, in the form: " '{"code": 0|1, "math": 0|1, "reasoning": 0|1, "knowledge": 0|1, ' '"instruction": 0|1, "creative": 0|1, "multilingual": 0|1, "simple_chat": 0|1}.' ) def _user_prompt(query: str) -> str: return f"Query:\n{query}\n\nRespond with JSON only." 
def _parse_vote(raw: str) -> dict[str, float]:
    """Parse an LLM reply into a 0.0/1.0 vote per capability.

    Args:
        raw: The model's text reply, expected to contain one JSON object.

    Returns:
        A dict with one entry per key in ``CAPABILITY_KEYS``. A reply that
        cannot be parsed, or that parses to something other than a JSON
        object, yields an all-zero vote instead of raising.
    """
    try:
        data = json.loads(_extract_json(raw))
    except (TypeError, ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError; TypeError covers a non-str
        # reply; RecursionError covers pathologically nested JSON.
        data = None
    if not isinstance(data, dict):
        # A list or scalar has no per-capability fields to read.
        return {k: 0.0 for k in CAPABILITY_KEYS}
    return {k: (1.0 if _is_affirmative(data.get(k)) else 0.0) for k in CAPABILITY_KEYS}


def _is_affirmative(value: object) -> bool:
    """Return True when a JSON field value should count as a positive vote.

    Strings need special care: models sometimes reply ``"0"`` or ``"false"``,
    which are truthy as Python strings but clearly mean "no".
    """
    if isinstance(value, str):
        return value.strip().lower() not in {"", "0", "0.0", "false", "no", "none", "null"}
    return bool(value)


def _extract_json(text: str) -> str:
    """Return the span from the first ``{`` to the last ``}`` in ``text``.

    Falls back to ``text`` unchanged when no such span exists.
    """
    found = re.search(r"\{.*\}", text, flags=re.DOTALL)
    if found is None:
        return text
    return found.group(0)


def _gpt_vote(text: str) -> Optional[dict[str, float]]:
    """Ask OpenAI's gpt-4o-mini for a capability vote.

    Returns:
        The parsed vote, or ``None`` when ``OPENAI_API_KEY`` is unset or the
        ``openai`` package is not installed. Errors raised by the API call
        itself are not caught here.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        return None
    try:
        from openai import OpenAI
    except ImportError:
        return None
    client = OpenAI(api_key=api_key)
    resp = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": _LABELER_SYSTEM_PROMPT},
            {"role": "user", "content": _user_prompt(text)},
        ],
        temperature=0,
        response_format={"type": "json_object"},
    )
    return _parse_vote(resp.choices[0].message.content or "{}")


def _claude_vote(text: str) -> Optional[dict[str, float]]:
    """Ask Anthropic's claude-haiku-4-5 for a capability vote.

    Returns:
        The parsed vote, or ``None`` when ``ANTHROPIC_API_KEY`` is unset or
        the ``anthropic`` package is not installed. Errors raised by the API
        call itself are not caught here.
    """
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        return None
    try:
        import anthropic
    except ImportError:
        return None
    client = anthropic.Anthropic(api_key=api_key)
    resp = client.messages.create(
        model="claude-haiku-4-5",
        max_tokens=200,
        # Pinned to 0 like the GPT voter so labels are as reproducible as possible.
        temperature=0,
        system=_LABELER_SYSTEM_PROMPT,
        messages=[{"role": "user", "content": _user_prompt(text)}],
    )
    # Only text blocks carry the JSON reply; skip any other block types.
    body = "".join(b.text for b in resp.content if getattr(b, "type", "") == "text")
    return _parse_vote(body)


def _gemini_vote(text: str) -> Optional[dict[str, float]]:
    """Ask Google's gemini-1.5-flash for a capability vote.

    Returns:
        The parsed vote, or ``None`` when ``GOOGLE_API_KEY`` is unset or the
        ``google.generativeai`` package is not installed. A response without
        a usable text part is treated like an unparseable reply (all zeros).
        Errors raised by the API call itself are not caught here.
    """
    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        return None
    try:
        import google.generativeai as genai
    except ImportError:
        return None
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel(
        "gemini-1.5-flash",
        system_instruction=_LABELER_SYSTEM_PROMPT,
        # Pinned to 0 like the GPT voter so labels are as reproducible as possible.
        generation_config={"temperature": 0},
    )
    resp = model.generate_content(_user_prompt(text))
    try:
        body = resp.text
    except ValueError:
        # The ``text`` accessor raises when the response holds no text part
        # (for example a blocked response).
        body = ""
    return _parse_vote(body or "{}")


@dataclass
class LabelerConfig:
    """Switches and weights controlling which voters run in ``label_query``."""

    # Run the deterministic regex voter.
    use_heuristic: bool = True
    # Opt-in LLM voters; each additionally needs its API key and SDK.
    use_gpt: bool = False
    use_claude: bool = False
    use_gemini: bool = False
    # Weight of the source prior relative to each other voter (weight 1.0 each).
    source_prior_weight: float = 0.5
    # Pause after each LLM call, in seconds; 0 disables the pause.
    sleep_between_calls_s: float = 0.0


def aggregate_votes(votes: CapabilityVotes, source_prior_weight: float = 0.5) -> dict[str, float]:
    """Combine all available votes into one score per capability.

    Each present voter (heuristic, gpt, claude, gemini) has weight 1.0 and the
    source prior has weight ``source_prior_weight``; the score is the weighted
    mean of the votes.

    Args:
        votes: Container exposing ``source_prior``, ``heuristic``, ``gpt``,
            ``claude`` and ``gemini``; a voter that did not run is ``None``.
        source_prior_weight: Weight given to the source prior.

    Returns:
        A dict with one score per key in ``CAPABILITY_KEYS``. With no voters
        besides the prior, a copy of the prior is returned (all zeros if the
        prior is missing or empty).

    Raises:
        ValueError: If the combined weight is not positive, which can only
            happen with a negative ``source_prior_weight``.
    """
    voters = [v for v in (votes.heuristic, votes.gpt, votes.claude, votes.gemini) if v is not None]
    prior = votes.source_prior
    if not voters:
        return dict(prior) if prior else {k: 0.0 for k in CAPABILITY_KEYS}

    # A missing prior abstains, exactly like a missing vendor voter, instead
    # of crashing on ``None.get`` or diluting the other votes.
    prior_weight = source_prior_weight if prior is not None else 0.0
    total_weight = prior_weight + len(voters)
    if total_weight <= 0:
        raise ValueError(
            f"total vote weight must be positive, got {total_weight} "
            f"(source_prior_weight={source_prior_weight})"
        )

    result: dict[str, float] = {}
    for cap in CAPABILITY_KEYS:
        prior_term = prior_weight * float(prior.get(cap, 0.0)) if prior is not None else 0.0
        vendor_sum = sum(float(v.get(cap, 0.0)) for v in voters)
        result[cap] = (prior_term + vendor_sum) / total_weight
    return result


def label_query(query: RawQuery, config: LabelerConfig) -> CapabilityLabel:
    """Collect every enabled vote for ``query`` and aggregate them into a label.

    Args:
        query: The record to label; its ``id``, ``text`` and
            ``source_category`` are read.
        config: Which voters to run, the prior weight, and the pause between
            LLM calls.

    Returns:
        A ``CapabilityLabel`` holding the aggregated scores, the raw votes,
        and an ``aggregation_method`` string naming the voters that produced
        a vote (``"source_prior_only"`` when none did).
    """
    votes = CapabilityVotes(source_prior=source_prior_vote(query.source_category))
    if config.use_heuristic:
        votes.heuristic = heuristic_vote(query.text)

    llm_voters = (
        ("gpt", config.use_gpt, _gpt_vote),
        ("claude", config.use_claude, _claude_vote),
        ("gemini", config.use_gemini, _gemini_vote),
    )
    for name, enabled, vote_fn in llm_voters:
        if not enabled:
            continue
        vote = vote_fn(query.text)
        setattr(votes, name, vote)
        # The pause is rate limiting; a voter that returned None (missing key
        # or SDK) made no request, so there is nothing to wait for.
        if vote is not None and config.sleep_between_calls_s:
            time.sleep(config.sleep_between_calls_s)

    aggregated = aggregate_votes(votes, source_prior_weight=config.source_prior_weight)
    contributors = [
        name
        for name in ("heuristic", "gpt", "claude", "gemini")
        if getattr(votes, name) is not None
    ]
    method = "+".join(contributors) or "source_prior_only"
    return CapabilityLabel(
        query_id=query.id,
        capabilities=aggregated,
        votes=votes,
        aggregation_method=method,
    )


def label_queries(queries: Iterable[RawQuery], config: LabelerConfig) -> list[CapabilityLabel]:
    """Label every query in ``queries`` with the same ``config``, in order."""
    return [label_query(q, config) for q in queries]