Spaces:
Sleeping
Sleeping
| """Capability labeling: turns RawQuery records into multi-label CapabilityLabel records. | |
| Aggregates up to five independent voters: | |
| - source_prior (always available; derived from source category) | |
| - heuristic (enabled by default; deterministic keyword/regex rules) | |
| - gpt-4o-mini (optional, requires OPENAI_API_KEY and the openai package) | |
| - claude-haiku-4-5 (optional, requires ANTHROPIC_API_KEY and the anthropic package) | |
| - gemini-1.5-flash (optional, requires GOOGLE_API_KEY and the google-generativeai package) | |
| Designed to run once during dataset prep. The output is committed as a parquet so | |
| downstream training does not depend on API access. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| import time | |
| from dataclasses import dataclass | |
| from typing import Iterable, Optional | |
| from greenrouting.data.schema import CapabilityLabel, CapabilityVotes, RawQuery | |
| from greenrouting.routing.registry import CAPABILITY_KEYS | |
# Maps a dataset source category to the capability labels it implies.
# Every category currently implies exactly one label: the one sharing its name.
CATEGORY_TO_LABELS: dict[str, list[str]] = {
    category: [category]
    for category in (
        "code",
        "math",
        "reasoning",
        "knowledge",
        "instruction",
        "creative",
        "multilingual",
        "simple_chat",
    )
}
def source_prior_vote(category: str) -> dict[str, float]:
    """Build the source-prior vote for a source category.

    Every key in ``CAPABILITY_KEYS`` appears in the result. Keys listed for
    ``category`` in ``CATEGORY_TO_LABELS`` score 1.0, all others score 0.0.
    A category that is not in the mapping yields an all-zero vote.
    """
    implied = CATEGORY_TO_LABELS.get(category, [])
    vote: dict[str, float] = {}
    for key in CAPABILITY_KEYS:
        vote[key] = 1.0 if key in implied else 0.0
    return vote
# Regex rules for the heuristic voter, keyed by capability.
# A capability fires when ANY of its patterns matches; heuristic_vote applies
# them with re.IGNORECASE | re.MULTILINE, so patterns are written in lower case.
_HEURISTIC_PATTERNS: dict[str, list[str]] = {
    "code": [
        r"\b(code|function|class|def |algorithm|implement|debug|stack trace|api|sdk)\b",
        r"\b(python|javascript|typescript|rust|go|java|sql|html|css)\b",
        # "c++" ends in a non-word character, so a trailing \b can never match
        # before a space or punctuation; it needs its own boundary-free rule.
        r"(?<!\w)c\+\+",
        r"\b(refactor|unit test|regex|linter|compile|recursion)\b",
        r"```",
    ],
    "math": [
        r"\b(calculate|compute|solve|equation|integral|derivative|matrix|vector|theorem|prove)\b",
        r"\b(probability|sum|product|mean|median|variance|standard deviation|percentage)\b",
        r"\d+\s*[+\-*/×÷=]\s*\d+",
        r"\b(arithmetic|fraction|geometry|algebra|trig)\b",
    ],
    "reasoning": [
        r"\b(why|how does|explain|reason|because|therefore|argue|justify|implication)\b",
        r"\b(compare|contrast|analyze|evaluate|trade.?off|infer|deduce)\b",
    ],
    "knowledge": [
        r"\b(who|what is|when did|where is|history|definition|capital|founded|named)\b",
        r"\b(country|continent|invented|discovered|president|prime minister)\b",
    ],
    "instruction": [
        r"\b(write|draft|create|generate|produce|format|list|outline|step.?by.?step|summarize)\b",
        r"\b(rewrite|translate from|convert to|extract)\b",
    ],
    "creative": [
        r"\b(story|poem|novel|character|plot|scene|metaphor|fictional|haiku|song lyric)\b",
        r"\b(write a (?:short )?(?:story|poem|haiku|song))\b",
    ],
    "multilingual": [
        r"\b(translate|translation|en español|en français|auf deutsch|на русском|中文|日本語|한국어)\b",
        # Script detection: Cyrillic, CJK ideographs, Hiragana, Katakana and
        # Hangul syllables. Ranges are spelled as \u escapes so every endpoint
        # is explicit and no stray literal "-" can end up inside the class.
        r"[\u0400-\u04FF\u4E00-\u9FFF\u3040-\u309F\u30A0-\u30FF\uAC00-\uD7A3]",
    ],
    "simple_chat": [
        r"^\s*(hi|hello|hey|thanks|thank you|good morning|good evening|sup|yo)\b",
    ],
}
def heuristic_vote(text: str) -> dict[str, float]:
    """Score ``text`` against the keyword/regex rules in ``_HEURISTIC_PATTERNS``.

    A capability gets 1.0 when any of its patterns matches (case-insensitive,
    multiline), otherwise 0.0. When nothing matches at all, a fallback label is
    chosen by length: stripped text shorter than 80 characters is treated as
    ``simple_chat``, anything longer as ``instruction``.
    """
    search_flags = re.IGNORECASE | re.MULTILINE
    scores = {key: 0.0 for key in CAPABILITY_KEYS}

    for capability, regexes in _HEURISTIC_PATTERNS.items():
        if any(re.search(rx, text, flags=search_flags) for rx in regexes):
            scores[capability] = 1.0

    if any(score != 0.0 for score in scores.values()):
        return scores

    # No rule fired: fall back on a length-based guess.
    fallback = "simple_chat" if len(text.strip()) < 80 else "instruction"
    scores[fallback] = 1.0
    return scores
# System prompt shared by all LLM voters. It lists the capability names and
# requests a strict-JSON reply with one 0/1 flag per capability.
_LABELER_SYSTEM_PROMPT = " ".join(
    [
        "You are labeling AI queries by which capabilities they require.",
        "Capabilities: code, math, reasoning, knowledge, instruction, creative, "
        "multilingual, simple_chat.",
        "A query can require multiple capabilities.",
        "Reply with strict JSON only, in the form:",
        '{"code": 0|1, "math": 0|1, "reasoning": 0|1, "knowledge": 0|1, '
        '"instruction": 0|1, "creative": 0|1, "multilingual": 0|1, "simple_chat": 0|1}.',
    ]
)
def _user_prompt(query: str) -> str:
    """Wrap ``query`` in the user-message template sent to the LLM voters."""
    prompt = f"Query:\n{query}\n\nRespond with JSON only."
    return prompt
def _parse_vote(raw: str) -> dict[str, float]:
    """Turn an LLM reply into a 0.0/1.0 vote over ``CAPABILITY_KEYS``.

    The JSON object is located with ``_extract_json`` and each capability is
    scored 1.0 when its value is truthy, else 0.0. Parsing is best-effort:
    a reply that is not valid JSON, or whose JSON is not an object (e.g. a
    list or a bare number), yields an all-zero vote instead of raising.
    """
    try:
        data = json.loads(_extract_json(raw))
    except (ValueError, TypeError, AttributeError, RecursionError):
        # ValueError covers json.JSONDecodeError; the others cover non-string
        # input and pathologically nested replies.
        data = None

    if not isinstance(data, dict):
        # Valid JSON that is not an object has no per-capability flags to read.
        return {k: 0.0 for k in CAPABILITY_KEYS}
    return {k: (1.0 if data.get(k) else 0.0) for k in CAPABILITY_KEYS}
def _extract_json(text: str) -> str:
    """Return the substring of ``text`` most likely to be the JSON object.

    Each ``{`` is tried in order as the start of a JSON object and the first
    one that decodes is returned, so surrounding chatter, code fences, or later
    braces do not corrupt the result. If no object decodes, the function falls
    back to the span from the first ``{`` to the last ``}``, or to ``text``
    unchanged when it contains no such span.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            _, end = decoder.raw_decode(text, start)
        except ValueError:
            # Not a decodable object at this brace; try the next one.
            start = text.find("{", start + 1)
            continue
        return text[start:end]

    # Nothing decodable: keep the legacy greedy behaviour for the caller.
    match = re.search(r"\{.*\}", text, flags=re.DOTALL)
    return match.group(0) if match else text
def _gpt_vote(text: str) -> Optional[dict[str, float]]:
    """Ask the OpenAI chat model (``gpt-4o-mini``) to label ``text``.

    Returns ``None`` when ``OPENAI_API_KEY`` is unset/empty or the ``openai``
    package cannot be imported. Otherwise returns the parsed vote; an empty
    reply is parsed as ``"{}"``.
    """
    key = os.environ.get("OPENAI_API_KEY")
    if not key:
        return None
    try:
        from openai import OpenAI
    except ImportError:
        return None

    openai_client = OpenAI(api_key=key)
    chat_messages = [
        {"role": "system", "content": _LABELER_SYSTEM_PROMPT},
        {"role": "user", "content": _user_prompt(text)},
    ]
    completion = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=chat_messages,
        temperature=0,
        response_format={"type": "json_object"},
    )
    reply = completion.choices[0].message.content
    return _parse_vote(reply or "{}")
def _claude_vote(text: str) -> Optional[dict[str, float]]:
    """Ask the Anthropic model (``claude-haiku-4-5``) to label ``text``.

    Returns ``None`` when ``ANTHROPIC_API_KEY`` is unset/empty or the
    ``anthropic`` package cannot be imported. Otherwise the text blocks of the
    reply are concatenated and parsed into a vote.
    """
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        return None
    try:
        import anthropic
    except ImportError:
        return None

    client = anthropic.Anthropic(api_key=api_key)
    resp = client.messages.create(
        model="claude-haiku-4-5",
        max_tokens=200,
        # Match the GPT voter: labels are generated once and committed, so
        # sampling randomness is unwanted.
        temperature=0,
        system=_LABELER_SYSTEM_PROMPT,
        messages=[{"role": "user", "content": _user_prompt(text)}],
    )
    # Only text blocks carry the JSON answer; skip any other block type.
    body = "".join(b.text for b in resp.content if getattr(b, "type", "") == "text")
    return _parse_vote(body)
def _gemini_vote(text: str) -> Optional[dict[str, float]]:
    """Ask the Google model (``gemini-1.5-flash``) to label ``text``.

    Returns ``None`` when ``GOOGLE_API_KEY`` is unset/empty or the
    ``google.generativeai`` package cannot be imported. Otherwise returns the
    parsed vote; a reply with no usable text is parsed as ``"{}"``.
    """
    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        return None
    try:
        import google.generativeai as genai
    except ImportError:
        return None

    genai.configure(api_key=api_key)
    model = genai.GenerativeModel(
        "gemini-1.5-flash",
        system_instruction=_LABELER_SYSTEM_PROMPT,
        # Match the GPT voter: deterministic decoding for one-off labeling.
        generation_config={"temperature": 0},
    )
    resp = model.generate_content(_user_prompt(text))
    try:
        body = resp.text
    except ValueError:
        # NOTE(review): the SDK's `.text` accessor raises ValueError when the
        # response holds no text part (e.g. a blocked candidate); treat that
        # like an empty reply rather than aborting the run.
        body = ""
    return _parse_vote(body or "{}")
@dataclass
class LabelerConfig:
    """Options controlling which voters run and how their votes are combined.

    Declared as a dataclass so callers can override individual options at
    construction time, e.g. ``LabelerConfig(use_gpt=True)``; ``LabelerConfig()``
    keeps every default.
    """

    # Run the deterministic keyword/regex voter.
    use_heuristic: bool = True
    # Run the respective LLM voter; each one still returns no vote when its
    # API key or client package is missing.
    use_gpt: bool = False
    use_claude: bool = False
    use_gemini: bool = False
    # Weight of the source-prior vote, passed to aggregate_votes.
    source_prior_weight: float = 0.5
    # Seconds to sleep after each LLM voter call; 0 disables the pause.
    sleep_between_calls_s: float = 0.0
def aggregate_votes(votes: CapabilityVotes, source_prior_weight: float = 0.5) -> dict[str, float]:
    """Combine the individual votes into one score per capability.

    Each present voter (heuristic, gpt, claude, gemini) counts with weight 1
    and the source prior with ``source_prior_weight``; the score for a
    capability is the weighted sum divided by the total weight. When no voter
    is present, the source prior is returned as-is (or an all-zero vote if it
    is missing or empty).

    A missing source prior is treated as empty in both branches, so it never
    raises when voters are present.
    """
    prior = votes.source_prior or {}
    voters = [v for v in (votes.heuristic, votes.gpt, votes.claude, votes.gemini) if v is not None]
    if not voters:
        return dict(prior) if prior else {k: 0.0 for k in CAPABILITY_KEYS}

    total_weight = source_prior_weight + len(voters)
    result: dict[str, float] = {}
    for cap in CAPABILITY_KEYS:
        prior_term = source_prior_weight * float(prior.get(cap, 0.0))
        vendor_sum = sum(float(v.get(cap, 0.0)) for v in voters)
        result[cap] = (prior_term + vendor_sum) / total_weight
    return result
def label_query(query: RawQuery, config: LabelerConfig) -> CapabilityLabel:
    """Label one query by collecting the enabled votes and aggregating them.

    The source prior is always computed. The heuristic and the LLM voters run
    only when enabled in ``config``, and the LLM voters run in the order gpt,
    claude, gemini. After each LLM voter call the function sleeps for
    ``config.sleep_between_calls_s`` seconds when that value is non-zero.
    ``aggregation_method`` joins the names of the voters that produced a vote
    with ``+``, or is ``source_prior_only`` when none did.
    """
    votes = CapabilityVotes(source_prior=source_prior_vote(query.source_category))
    if config.use_heuristic:
        votes.heuristic = heuristic_vote(query.text)

    # (attribute on votes, flag on config, voter function), in call order.
    llm_voters = (
        ("gpt", "use_gpt", _gpt_vote),
        ("claude", "use_claude", _claude_vote),
        ("gemini", "use_gemini", _gemini_vote),
    )
    for vote_attr, flag_attr, ask in llm_voters:
        if not getattr(config, flag_attr):
            continue
        setattr(votes, vote_attr, ask(query.text))
        if config.sleep_between_calls_s:
            time.sleep(config.sleep_between_calls_s)

    aggregated = aggregate_votes(votes, source_prior_weight=config.source_prior_weight)

    contributors = [
        name
        for name in ("heuristic", "gpt", "claude", "gemini")
        if getattr(votes, name) is not None
    ]
    method = "+".join(contributors) or "source_prior_only"

    return CapabilityLabel(
        query_id=query.id,
        capabilities=aggregated,
        votes=votes,
        aggregation_method=method,
    )
def label_queries(queries: Iterable[RawQuery], config: LabelerConfig) -> list[CapabilityLabel]:
    """Label every query in ``queries`` with ``label_query``, preserving order."""
    labels: list[CapabilityLabel] = []
    for raw_query in queries:
        labels.append(label_query(raw_query, config))
    return labels