router-api / greenrouting /data /builder.py
spectralman's picture
Initial deploy: classifier + FastAPI router
6f0ff99 verified
"""Dataset orchestrator: source sampling -> capability labeling -> cascade plan -> parquet."""
from __future__ import annotations
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from greenrouting.data.capability_labeler import LabelerConfig, label_queries
from greenrouting.data.schema import CapabilityLabel, LabeledQuery, RawQuery
from greenrouting.data.sources import SOURCE_REGISTRY, sample_mix
from greenrouting.routing.registry import CAPABILITY_KEYS
@dataclass
class CascadeRungConfig:
    """Static description of a single rung (model) in the cascade.

    Instances are built from the ``cascade.rungs`` entries of the build YAML
    via ``CascadeRungConfig(**entry)``, so the field names double as the
    accepted YAML keys.
    """

    # Short identifier for the rung; shown in the build plan report.
    id: str
    # Model reference (presumably a Hugging Face repo id -- confirm with the
    # cascade runner, which is not part of this module).
    hf_model: str
    # Parameter count; the ``_b`` suffix suggests billions (TODO confirm).
    params_b: float
    # Estimated decode throughput; divides the token count when projecting
    # the cascade wall time.
    decode_tokens_per_second_estimate: float
    # Not read anywhere in this module; defaults to True.
    runs_locally: bool = True
@dataclass
class CascadeConfig:
    """Settings for the model cascade plus a rough wall-time estimator."""

    # Rungs in the order they are listed in the build config.
    rungs: list[CascadeRungConfig]
    # Number of inferences issued per query on every rung.
    k_samples: int = 1
    # Token count assumed for each inference when estimating decode time.
    max_new_tokens: int = 200
    # Sampling temperatures; not read in this module (presumably consumed by
    # the cascade runner -- confirm there).
    temperature_first: float = 0.0
    temperature_resample: float = 0.7

    def projected_seconds(self, n_queries: int) -> float:
        """Estimate the total cascade run time in seconds for ``n_queries``.

        Every rung is assumed to process every query ``k_samples`` times,
        decoding ``max_new_tokens`` tokens per inference. Throughput estimates
        below 1 token/s are clamped to 1.0, which also avoids dividing by zero.
        A flat 0.4 s is added per inference on top of the decode time.
        """
        # Identical for every rung, so compute it once up front.
        inference_count = n_queries * self.k_samples

        elapsed = 0.0
        for rung in self.rungs:
            tokens_per_second = max(rung.decode_tokens_per_second_estimate, 1.0)
            elapsed += inference_count * self.max_new_tokens / tokens_per_second
            elapsed += inference_count * 0.4
        return elapsed
@dataclass
class BuildConfig:
    """Full configuration for one dataset build, usually loaded from YAML."""

    # Label for this build profile; echoed in the plan report.
    profile_name: str
    # Number of queries requested from the source mix.
    target_total_queries: int
    # Fraction of rows written to the test split.
    test_split: float
    # Seed forwarded to source sampling.
    seed: int
    # Source name -> weight; names are checked against SOURCE_REGISTRY when
    # the build is planned.
    sources: dict[str, float]
    cascade: CascadeConfig
    labeler: LabelerConfig
    # Budget the projected cascade time is compared against.
    budget_minutes: float = 60.0
    output_dir: str = "data"
    # Optional parquet file of previously computed capability labels.
    capability_labels_cache: Optional[str] = None

    @classmethod
    def from_yaml(cls, path: str | Path) -> "BuildConfig":
        """Load a build configuration from the YAML file at ``path``.

        Required top-level keys: ``profile_name``, ``target_total_queries``,
        ``test_split``, ``seed``, ``sources`` and ``cascade`` (which must
        contain ``rungs``). The ``labeler`` section and the remaining keys are
        optional and fall back to the defaults spelled out below.

        Raises:
            KeyError: If a required key is missing.
        """
        import yaml

        with open(path, "r", encoding="utf-8") as f:
            raw = yaml.safe_load(f)

        cascade_raw = raw["cascade"]
        rungs = [CascadeRungConfig(**r) for r in cascade_raw["rungs"]]
        cascade = CascadeConfig(
            rungs=rungs,
            k_samples=cascade_raw.get("k_samples", 1),
            max_new_tokens=cascade_raw.get("max_new_tokens", 200),
            temperature_first=cascade_raw.get("temperature_first", 0.0),
            temperature_resample=cascade_raw.get("temperature_resample", 0.7),
        )

        # A bare ``labeler:`` key (e.g. with all children commented out)
        # parses to None rather than being absent, so ``.get("labeler", {})``
        # alone would hand back None and crash on the lookups below.
        labeler_raw = raw.get("labeler") or {}
        labeler = LabelerConfig(
            use_heuristic=labeler_raw.get("use_heuristic", True),
            use_gpt=labeler_raw.get("use_gpt", False),
            use_claude=labeler_raw.get("use_claude", False),
            use_gemini=labeler_raw.get("use_gemini", False),
            source_prior_weight=labeler_raw.get("source_prior_weight", 0.5),
            sleep_between_calls_s=labeler_raw.get("sleep_between_calls_s", 0.0),
        )

        return cls(
            profile_name=raw["profile_name"],
            target_total_queries=raw["target_total_queries"],
            test_split=raw["test_split"],
            seed=raw["seed"],
            sources=raw["sources"],
            cascade=cascade,
            labeler=labeler,
            budget_minutes=raw.get("budget_minutes", 60.0),
            output_dir=raw.get("output_dir", "data"),
            capability_labels_cache=raw.get("capability_labels_cache"),
        )
@dataclass
class BuildPlan:
    """Result of planning a build: projected cascade cost plus warnings."""

    config: BuildConfig
    n_queries: int
    # Projected cascade wall time, in seconds and in minutes.
    cascade_seconds: float
    cascade_minutes: float
    # True when the projected minutes exceed the configured budget.
    over_budget: bool
    # Human-readable warnings collected while planning.
    notes: list[str] = field(default_factory=list)

    def report(self) -> str:
        """Render the plan as a multi-line, human-readable summary.

        Returns:
            One line per setting, followed by a ``Notes:`` section when any
            notes were recorded. Lines are joined with ``\\n``.
        """
        cfg = self.config
        # round() rather than int(): products such as 0.29 * 100 evaluate to
        # 28.999999999999996, which int() would truncate to 28.
        test_percent = round(cfg.test_split * 100)
        source_mix = ", ".join(f"{k}={v}" for k, v in cfg.sources.items())
        rung_ids = ", ".join(r.id for r in cfg.cascade.rungs)

        lines = [
            f"Profile: {cfg.profile_name}",
            f"Target queries: {cfg.target_total_queries}",
            f"Test split: {test_percent}%",
            f"Sources: {source_mix}",
            f"Cascade rungs: {rung_ids}",
            f"k_samples per rung: {cfg.cascade.k_samples}",
            f"Max new tokens: {cfg.cascade.max_new_tokens}",
            f"Estimated cascade wall time: {self.cascade_minutes:.1f} min",
            f"Configured budget: {cfg.budget_minutes:.1f} min",
            f"Over budget: {self.over_budget}",
        ]
        if self.notes:
            lines.append("Notes:")
            lines.extend(f" - {note}" for note in self.notes)
        return "\n".join(lines)
def plan(config: BuildConfig) -> BuildPlan:
    """Project the cascade cost for ``config`` and gather pre-flight warnings.

    Nothing is sampled, labeled or written here. The returned plan carries
    notes for: a projected cascade time above the budget, enabled remote
    labelers whose API key environment variable is unset or empty, and source
    names that are not present in ``SOURCE_REGISTRY``.
    """
    projected_seconds = config.cascade.projected_seconds(config.target_total_queries)
    projected_minutes = projected_seconds / 60.0
    exceeds_budget = projected_minutes > config.budget_minutes

    warnings: list[str] = []
    if exceeds_budget:
        warnings.append(
            f"cascade projected {projected_minutes:.1f} min exceeds budget {config.budget_minutes:.1f} min"
        )

    # (labeler enabled?, required environment variable, note when it is absent)
    credential_checks = (
        (config.labeler.use_gpt, "OPENAI_API_KEY",
         "OPENAI_API_KEY missing; gpt vote will be skipped"),
        (config.labeler.use_claude, "ANTHROPIC_API_KEY",
         "ANTHROPIC_API_KEY missing; claude vote will be skipped"),
        (config.labeler.use_gemini, "GOOGLE_API_KEY",
         "GOOGLE_API_KEY missing; gemini vote will be skipped"),
    )
    for enabled, env_var, message in credential_checks:
        if enabled and not os.environ.get(env_var):
            warnings.append(message)

    warnings.extend(
        f"unknown source '{name}' in mix"
        for name in config.sources
        if name not in SOURCE_REGISTRY
    )

    return BuildPlan(
        config=config,
        n_queries=config.target_total_queries,
        cascade_seconds=projected_seconds,
        cascade_minutes=projected_minutes,
        over_budget=exceeds_budget,
        notes=warnings,
    )
def write_capability_labels(path: str | Path, labels: list[CapabilityLabel]) -> None:
import pandas as pd
df = pd.DataFrame([lbl.to_record() for lbl in labels])
Path(path).parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(path, index=False)
def read_capability_labels(path: str | Path) -> dict[str, dict[str, float]]:
import pandas as pd
df = pd.read_parquet(path)
out: dict[str, dict[str, float]] = {}
cap_cols = [c for c in df.columns if c.startswith("cap_")]
for _, row in df.iterrows():
out[row["query_id"]] = {c[4:]: float(row[c]) for c in cap_cols}
return out
def write_raw_manifest(path: str | Path, queries: list[RawQuery]) -> None:
    """Write the sampled raw queries to ``path`` as JSON Lines.

    One JSON object per line, produced from each query's ``to_dict()``.
    Missing parent directories are created; an existing file is overwritten.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as handle:
        handle.writelines(json.dumps(query.to_dict()) + "\n" for query in queries)
def write_labeled_dataset(
train_path: str | Path,
test_path: str | Path,
rows: list[LabeledQuery],
test_split: float,
seed: int,
) -> None:
import pandas as pd
import random as _random
rng = _random.Random(seed)
indices = list(range(len(rows)))
rng.shuffle(indices)
n_test = max(1, int(len(rows) * test_split))
test_idx = set(indices[:n_test])
train_records = [rows[i].to_record() for i in range(len(rows)) if i not in test_idx]
test_records = [rows[i].to_record() for i in test_idx]
Path(train_path).parent.mkdir(parents=True, exist_ok=True)
pd.DataFrame(train_records).to_parquet(train_path, index=False)
pd.DataFrame(test_records).to_parquet(test_path, index=False)
def build_seed_dataset(
    output_dir: str | Path,
    test_split: float = 0.15,
    seed: int = 42,
    suffix: str = "seed",
) -> tuple[Path, Path]:
    """Write the curated seed queries out as train/test parquet files.

    Neither the cascade nor the labeler runs here: every seed entry already
    provides its capability multi-labels, a difficulty (converted to
    log-params) and a length bucket.

    Returns:
        ``(train_path, test_path)``, i.e. ``train_<suffix>.parquet`` and
        ``test_<suffix>.parquet`` inside ``output_dir``.
    """
    from greenrouting.data.seed_dataset import (
        SEED_QUERIES,
        difficulty_log_params_from_b,
        seed_capability_dict,
    )

    def _labeled_from_seed(position: int, entry) -> LabeledQuery:
        # Seed ids are positional and zero-padded: seed-0000, seed-0001, ...
        query = RawQuery(
            id=f"seed-{position:04d}",
            text=entry.text,
            source="seed",
            source_category=entry.primary_category,
            has_grader=False,
            grader_metadata={},
        )
        return LabeledQuery(
            raw=query,
            capabilities=seed_capability_dict(entry, CAPABILITY_KEYS),
            difficulty_log_params=difficulty_log_params_from_b(entry.difficulty_b),
            length_bucket=entry.length,
            cascade_results={"source": "seed_curated"},
        )

    rows: list[LabeledQuery] = [
        _labeled_from_seed(position, entry)
        for position, entry in enumerate(SEED_QUERIES)
    ]

    base_dir = Path(output_dir)
    train_path = base_dir / f"train_{suffix}.parquet"
    test_path = base_dir / f"test_{suffix}.parquet"
    write_labeled_dataset(train_path, test_path, rows, test_split=test_split, seed=seed)
    return train_path, test_path
def sample_and_label(config: BuildConfig) -> tuple[list[RawQuery], list[CapabilityLabel]]:
    """Sample the configured source mix and attach a capability label to each query.

    Queries whose id is found in the optional labels cache reuse the cached
    capabilities (tagged ``aggregation_method="cached"`` with empty votes);
    only the remaining queries are sent to the labeler.

    Returns:
        ``(queries, labels)``. Labels follow the order of ``queries``; a query
        for which no label was produced is left out of the label list, so the
        two lists can differ in length.
    """
    queries = sample_mix(config.sources, config.target_total_queries, config.seed)

    cache_file = config.capability_labels_cache
    if cache_file and Path(cache_file).exists():
        cached = read_capability_labels(cache_file)
    else:
        cached = {}

    # Cache misses go through the labeler; skip the call when there are none.
    misses = [q for q in queries if q.id not in cached]
    fresh_labels = label_queries(misses, config.labeler) if misses else []

    hits = [q for q in queries if q.id in cached]
    reused_labels: list[CapabilityLabel] = []
    if hits:
        # Imported lazily, only once at least one cached label is rebuilt.
        from greenrouting.data.schema import CapabilityVotes

        reused_labels = [
            CapabilityLabel(
                query_id=q.id,
                capabilities=cached[q.id],
                votes=CapabilityVotes(),
                aggregation_method="cached",
            )
            for q in hits
        ]

    # On an id collision the later entry (a cached label) takes precedence.
    label_by_id = {label.query_id: label for label in fresh_labels + reused_labels}
    aligned = [label_by_id[q.id] for q in queries if q.id in label_by_id]
    return queries, aligned