"""Schema for queries and labels flowing through the data pipeline."""

from __future__ import annotations

from dataclasses import dataclass, field, asdict
from typing import Optional

from greenrouting.routing.registry import CAPABILITY_KEYS

# Length bucket names; presumably the valid values for
# ``LabeledQuery.length_bucket`` (not enforced anywhere in this module).
LENGTH_BUCKETS: tuple[str, str, str] = ("short", "medium", "long")

# Every vote source stored on ``CapabilityVotes``, in the order its columns
# are emitted by ``CapabilityLabel.to_record``. Keeping the names in one
# place stops ``vote_count`` and ``to_record`` from drifting apart when a
# source is added.
_VOTE_SOURCES: tuple[str, ...] = (
    "source_prior",
    "heuristic",
    "gpt",
    "claude",
    "gemini",
)

# Vote sources that are optional (``None`` when absent) and therefore
# counted by ``CapabilityVotes.vote_count``; ``source_prior`` is a
# non-optional field and is excluded from the count.
_OPTIONAL_VOTE_SOURCES: tuple[str, ...] = ("heuristic", "gpt", "claude", "gemini")


def _capability_columns(prefix: str, scores: dict[str, float]) -> dict[str, float]:
    """Flatten *scores* into ``{prefix + key: float(score)}`` columns.

    One column is produced for every key in ``CAPABILITY_KEYS``, in its
    iteration order. Keys missing from *scores* are filled with ``0.0``;
    keys of *scores* that are not in ``CAPABILITY_KEYS`` are ignored.
    """
    return {f"{prefix}{k}": float(scores.get(k, 0.0)) for k in CAPABILITY_KEYS}


@dataclass
class RawQuery:
    """An unlabeled query record; ``LabeledQuery`` wraps one of these."""

    id: str
    text: str
    source: str
    source_category: str
    has_grader: bool = False
    # Free-form grader information; its schema is not defined in this module.
    grader_metadata: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return all fields as a plain dict.

        Uses ``dataclasses.asdict``, so ``grader_metadata`` is recursively
        copied rather than shared with this instance.
        """
        return asdict(self)


@dataclass
class CapabilityVotes:
    """Per-source capability score dicts behind a ``CapabilityLabel``.

    ``source_prior`` defaults to an empty dict; every other source is
    ``None`` when it did not provide a vote.
    """

    source_prior: dict[str, float] = field(default_factory=dict)
    heuristic: Optional[dict[str, float]] = None
    gpt: Optional[dict[str, float]] = None
    claude: Optional[dict[str, float]] = None
    gemini: Optional[dict[str, float]] = None

    def vote_count(self) -> int:
        """Return how many optional sources voted (``source_prior`` excluded)."""
        return sum(
            1 for name in _OPTIONAL_VOTE_SOURCES if getattr(self, name) is not None
        )


@dataclass
class CapabilityLabel:
    """Aggregated capability scores for one query, with the votes behind them."""

    query_id: str
    capabilities: dict[str, float]
    votes: CapabilityVotes
    # Name of the method used to combine ``votes`` into ``capabilities``.
    aggregation_method: str

    def to_record(self) -> dict:
        """Flatten the label into a single-level dict.

        Columns, in insertion order: ``query_id``, ``aggregation_method``,
        ``cap_<key>`` for every key in ``CAPABILITY_KEYS``, then
        ``vote_<source>_<key>`` for each vote source that is not ``None``.
        Missing capability keys become ``0.0``. ``source_prior`` defaults to
        an empty dict rather than ``None``, so its columns are emitted (as
        ``0.0``) unless a caller sets it to ``None``.
        """
        rec: dict = {
            "query_id": self.query_id,
            "aggregation_method": self.aggregation_method,
        }
        rec.update(_capability_columns("cap_", self.capabilities))
        for source in _VOTE_SOURCES:
            scores = getattr(self.votes, source)
            if scores is None:
                # Source did not vote: omit its columns instead of zero-filling.
                continue
            rec.update(_capability_columns(f"vote_{source}_", scores))
        return rec


@dataclass
class LabeledQuery:
    """A ``RawQuery`` together with the labels computed for it."""

    raw: RawQuery
    capabilities: dict[str, float]
    # NOTE(review): presumably a log-scale model-size difficulty estimate;
    # confirm against the labeling code.
    difficulty_log_params: Optional[float]
    # Presumably one of LENGTH_BUCKETS; not validated here.
    length_bucket: Optional[str]
    # Free-form results dict; not included in ``to_record()``.
    cascade_results: dict = field(default_factory=dict)

    def to_record(self) -> dict:
        """Flatten the query and its labels into a single-level dict.

        Includes ``id``, ``text``, ``source``, ``source_category`` and
        ``has_grader`` from ``raw``, the difficulty and length labels, and
        ``cap_<key>`` for every key in ``CAPABILITY_KEYS`` (missing keys
        become ``0.0``). ``grader_metadata`` and ``cascade_results`` are not
        included.
        """
        raw = self.raw
        rec: dict = {
            "id": raw.id,
            "text": raw.text,
            "source": raw.source,
            "source_category": raw.source_category,
            "has_grader": raw.has_grader,
            "difficulty_log_params": self.difficulty_log_params,
            "length_bucket": self.length_bucket,
        }
        rec.update(_capability_columns("cap_", self.capabilities))
        return rec