# Committed by spectralman: "Initial deploy: classifier + FastAPI router"
# (commit 6f0ff99, verified)
"""Schema for queries and labels flowing through the data pipeline."""
from __future__ import annotations
from dataclasses import dataclass, field, asdict
from typing import Optional
from greenrouting.routing.registry import CAPABILITY_KEYS
# Query-length category names, in the order short -> medium -> long.
# Presumably the allowed values of ``LabeledQuery.length_bucket``
# (nothing in this module enforces that) — confirm with callers.
LENGTH_BUCKETS: tuple[str, str, str] = (
    "short",
    "medium",
    "long",
)
@dataclass
class RawQuery:
    """A query as it enters the pipeline, before any labels are attached.

    ``LabeledQuery`` wraps one of these as its ``raw`` field.
    """

    # Identifier of the query.
    id: str
    # The query text.
    text: str
    # Name of the source the query came from.
    source: str
    # Category assigned by that source.
    source_category: str
    # Whether a grader exists for this query.
    has_grader: bool = False
    # Free-form grader details; each instance gets its own empty dict by default.
    grader_metadata: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return every field of this query as a plain ``dict``.

        Uses ``dataclasses.asdict``, so nested containers such as
        ``grader_metadata`` are copied rather than shared with the instance.
        """
        as_plain = asdict(self)
        return as_plain
@dataclass
class CapabilityVotes:
    """Per-labeler capability scores collected for one query.

    Each attribute maps a capability key to a float score. ``source_prior``
    defaults to an empty dict; every other labeler defaults to ``None``,
    which ``vote_count`` treats as "did not vote".
    """

    # Scores derived from the query's source; defaults to an empty dict.
    source_prior: dict[str, float] = field(default_factory=dict)
    # Scores from the remaining labelers; ``None`` means no vote was cast.
    heuristic: Optional[dict[str, float]] = None
    gpt: Optional[dict[str, float]] = None
    claude: Optional[dict[str, float]] = None
    gemini: Optional[dict[str, float]] = None

    def vote_count(self) -> int:
        """Return how many of heuristic/gpt/claude/gemini are not ``None``.

        ``source_prior`` is not part of the count. An empty dict still
        counts as a vote; only ``None`` is skipped.
        """
        ballots = (self.heuristic, self.gpt, self.claude, self.gemini)
        cast = [ballot for ballot in ballots if ballot is not None]
        return len(cast)
@dataclass
class CapabilityLabel:
    """Aggregated capability scores for one query, plus the votes behind them."""

    # Identifier of the labeled query.
    query_id: str
    # Final per-capability scores.
    capabilities: dict[str, float]
    # The individual labeler votes.
    votes: CapabilityVotes
    # Name of the method used to aggregate ``votes`` into ``capabilities``.
    aggregation_method: str

    def to_record(self) -> dict:
        """Flatten this label into a single-level dict.

        Keys, in insertion order:
          * ``query_id`` and ``aggregation_method``, copied unchanged;
          * ``cap_<key>`` for each key in ``CAPABILITY_KEYS`` (0.0 if absent);
          * ``vote_<labeler>_<key>`` for each labeler in source_prior,
            heuristic, gpt, claude, gemini whose vote is not ``None``, with
            absent keys again filled with 0.0.

        Labelers whose vote is ``None`` add no keys, so records for different
        queries may have different key sets.
        """
        record = {
            "query_id": self.query_id,
            "aggregation_method": self.aggregation_method,
        }
        record.update(
            (f"cap_{key}", float(self.capabilities.get(key, 0.0)))
            for key in CAPABILITY_KEYS
        )
        for labeler in ("source_prior", "heuristic", "gpt", "claude", "gemini"):
            ballot = getattr(self.votes, labeler)
            if ballot is not None:
                record.update(
                    (f"vote_{labeler}_{key}", float(ballot.get(key, 0.0)))
                    for key in CAPABILITY_KEYS
                )
        return record
@dataclass
class LabeledQuery:
    """A ``RawQuery`` together with the labels computed for it."""

    # The original unlabeled query.
    raw: RawQuery
    # Per-capability scores for this query.
    capabilities: dict[str, float]
    # Difficulty value; presumably the log of a model-parameter count — confirm.
    difficulty_log_params: Optional[float]
    # Length category; presumably one of ``LENGTH_BUCKETS`` (not validated here).
    length_bucket: Optional[str]
    # Free-form cascade outputs; each instance gets its own empty dict by default.
    cascade_results: dict = field(default_factory=dict)

    def to_record(self) -> dict:
        """Flatten this labeled query into a single-level dict.

        Includes the raw query's id, text, source, source_category and
        has_grader, then difficulty_log_params and length_bucket, then
        ``cap_<key>`` for each key in ``CAPABILITY_KEYS`` (0.0 if absent).
        ``raw.grader_metadata`` and ``cascade_results`` are not included.
        """
        source_query = self.raw
        record = {
            attr: getattr(source_query, attr)
            for attr in ("id", "text", "source", "source_category", "has_grader")
        }
        record["difficulty_log_params"] = self.difficulty_log_params
        record["length_bucket"] = self.length_bucket
        for key in CAPABILITY_KEYS:
            record[f"cap_{key}"] = float(self.capabilities.get(key, 0.0))
        return record