Spaces:
Sleeping
Sleeping
| """Schema for queries and labels flowing through the data pipeline.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field, asdict | |
| from typing import Optional | |
| from greenrouting.routing.registry import CAPABILITY_KEYS | |
# Length-bucket labels, listed as short, medium, long.
# NOTE(review): presumably the allowed values for LabeledQuery.length_bucket;
# nothing in this file enforces that -- confirm against the bucketing code.
| LENGTH_BUCKETS: tuple[str, str, str] = ("short", "medium", "long") | |
@dataclass
class RawQuery:
    """A query record before any labels are attached.

    Declared as a dataclass so that the constructor is generated, the
    ``field(default_factory=dict)`` default is honoured per instance, and
    ``asdict`` can serialise the object.
    """

    id: str
    text: str
    source: str
    source_category: str
    has_grader: bool = False
    # A fresh dict per instance; a bare ``{}`` default would be shared.
    grader_metadata: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return all fields as a plain dict.

        ``dataclasses.asdict`` copies nested containers, so mutating the
        result does not affect this instance.
        """
        return asdict(self)
@dataclass
class CapabilityVotes:
    """Per-voter capability score maps collected for one query.

    Each populated attribute maps a capability key to a float score.
    ``source_prior`` always exists (empty dict by default); the remaining
    voters are ``None`` when that voter produced nothing.
    """

    # A fresh dict per instance; requires the dataclass machinery to work.
    source_prior: dict[str, float] = field(default_factory=dict)
    heuristic: Optional[dict[str, float]] = None
    gpt: Optional[dict[str, float]] = None
    claude: Optional[dict[str, float]] = None
    gemini: Optional[dict[str, float]] = None

    def vote_count(self) -> int:
        """Return how many optional voters are present (not ``None``).

        Only ``heuristic``, ``gpt``, ``claude`` and ``gemini`` are counted;
        ``source_prior`` is excluded. An empty dict still counts as a vote
        because the check is ``is not None``, not truthiness.
        """
        optional_votes = (self.heuristic, self.gpt, self.claude, self.gemini)
        return sum(1 for vote in optional_votes if vote is not None)
@dataclass
class CapabilityLabel:
    """Aggregated capability scores for one query, plus the votes behind them.

    Declared as a dataclass so that the four annotated fields become
    constructor parameters.
    """

    query_id: str
    capabilities: dict[str, float]
    votes: CapabilityVotes
    aggregation_method: str

    def to_record(self) -> dict:
        """Flatten the label into a single-level dict.

        Keys, in insertion order:

        * ``query_id`` and ``aggregation_method``;
        * ``cap_<key>`` for every key in ``CAPABILITY_KEYS``, with 0.0 for
          any key missing from ``capabilities``;
        * ``vote_<voter>_<key>`` for every voter whose vote is not ``None``
          and every key in ``CAPABILITY_KEYS``, with 0.0 for missing keys.

        Voters whose vote is ``None`` contribute no columns at all.
        """
        rec = {"query_id": self.query_id, "aggregation_method": self.aggregation_method}
        for k in CAPABILITY_KEYS:
            rec[f"cap_{k}"] = float(self.capabilities.get(k, 0.0))
        for vendor in ("source_prior", "heuristic", "gpt", "claude", "gemini"):
            vote = getattr(self.votes, vendor)
            if vote is None:
                # Absent voter: emit no columns rather than zero-filled ones.
                continue
            for k in CAPABILITY_KEYS:
                rec[f"vote_{vendor}_{k}"] = float(vote.get(k, 0.0))
        return rec
@dataclass
class LabeledQuery:
    """A raw query bundled with its labels.

    Declared as a dataclass so that the constructor is generated and
    ``cascade_results`` gets a fresh dict per instance.
    """

    raw: RawQuery
    capabilities: dict[str, float]
    difficulty_log_params: Optional[float]
    # NOTE(review): presumably one of LENGTH_BUCKETS; not validated here.
    length_bucket: Optional[str]
    cascade_results: dict = field(default_factory=dict)

    def to_record(self) -> dict:
        """Flatten the query and its labels into a single-level dict.

        The record holds the raw query's ``id``, ``text``, ``source``,
        ``source_category`` and ``has_grader``, then ``difficulty_log_params``
        and ``length_bucket``, then one ``cap_<key>`` float per entry in
        ``CAPABILITY_KEYS`` (0.0 when the key is missing from
        ``capabilities``). ``cascade_results`` and the raw query's
        ``grader_metadata`` are not included.
        """
        rec = {
            "id": self.raw.id,
            "text": self.raw.text,
            "source": self.raw.source,
            "source_category": self.raw.source_category,
            "has_grader": self.raw.has_grader,
            "difficulty_log_params": self.difficulty_log_params,
            "length_bucket": self.length_bucket,
        }
        for k in CAPABILITY_KEYS:
            rec[f"cap_{k}"] = float(self.capabilities.get(k, 0.0))
        return rec