"""Dataset orchestrator: source sampling -> capability labeling -> cascade plan -> parquet."""

from __future__ import annotations

import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

from greenrouting.data.capability_labeler import LabelerConfig, label_queries
from greenrouting.data.schema import CapabilityLabel, LabeledQuery, RawQuery
from greenrouting.data.sources import SOURCE_REGISTRY, sample_mix
from greenrouting.routing.registry import CAPABILITY_KEYS


@dataclass
class CascadeRungConfig:
    """Configuration of a single cascade rung, as listed under ``cascade.rungs``."""

    id: str
    # NOTE(review): presumably a Hugging Face model identifier -- confirm.
    hf_model: str
    # NOTE(review): presumably the parameter count in billions -- confirm.
    params_b: float
    # Decode-throughput estimate used by CascadeConfig.projected_seconds.
    decode_tokens_per_second_estimate: float
    runs_locally: bool = True


@dataclass
class CascadeConfig:
    """Cascade-wide settings plus the ordered list of rungs."""

    rungs: list[CascadeRungConfig]
    k_samples: int = 1
    max_new_tokens: int = 200
    temperature_first: float = 0.0
    temperature_resample: float = 0.7

    def projected_seconds(self, n_queries: int) -> float:
        """Estimate the wall time, in seconds, of running every rung on every query.

        Each rung is charged ``n_queries * k_samples`` inferences. An inference
        costs ``max_new_tokens`` divided by the rung's tokens-per-second
        estimate (floored at 1.0 so a zero or tiny estimate cannot blow up the
        projection), plus a flat 0.4 s per inference.

        Args:
            n_queries: Number of queries that will be sent through the cascade.

        Returns:
            The projected total in seconds, summed over all rungs.
        """
        # Identical for every rung, so compute it once.
        inferences = n_queries * self.k_samples
        total = 0.0
        for rung in self.rungs:
            tokens_per_second = max(rung.decode_tokens_per_second_estimate, 1.0)
            total += inferences * self.max_new_tokens / tokens_per_second
            # NOTE(review): 0.4 s looks like a fixed per-call overhead -- confirm.
            total += inferences * 0.4
        return total


@dataclass
class BuildConfig:
    """Full description of one dataset build profile."""

    profile_name: str
    target_total_queries: int
    test_split: float
    seed: int
    # Source name -> value handed to sample_mix; names are checked against
    # SOURCE_REGISTRY by plan().
    sources: dict[str, float]
    cascade: CascadeConfig
    labeler: LabelerConfig
    budget_minutes: float = 60.0
    output_dir: str = "data"
    capability_labels_cache: Optional[str] = None

    @classmethod
    def from_yaml(cls, path: str | Path) -> "BuildConfig":
        """Load a build profile from a YAML file.

        Required top-level keys: ``profile_name``, ``target_total_queries``,
        ``test_split``, ``seed``, ``sources`` and ``cascade`` (which must
        contain ``rungs``). Every other key falls back to the default visible
        in the body below.

        Args:
            path: Location of the YAML profile.

        Returns:
            The populated ``BuildConfig``.

        Raises:
            KeyError: If a required key is absent.
        """
        import yaml

        with open(path, "r", encoding="utf-8") as f:
            raw = yaml.safe_load(f)

        cascade_raw = raw["cascade"]
        cascade = CascadeConfig(
            rungs=[CascadeRungConfig(**rung) for rung in cascade_raw["rungs"]],
            k_samples=cascade_raw.get("k_samples", 1),
            max_new_tokens=cascade_raw.get("max_new_tokens", 200),
            temperature_first=cascade_raw.get("temperature_first", 0.0),
            temperature_resample=cascade_raw.get("temperature_resample", 0.7),
        )

        labeler_raw = raw.get("labeler", {})
        labeler = LabelerConfig(
            use_heuristic=labeler_raw.get("use_heuristic", True),
            use_gpt=labeler_raw.get("use_gpt", False),
            use_claude=labeler_raw.get("use_claude", False),
            use_gemini=labeler_raw.get("use_gemini", False),
            source_prior_weight=labeler_raw.get("source_prior_weight", 0.5),
            sleep_between_calls_s=labeler_raw.get("sleep_between_calls_s", 0.0),
        )

        return cls(
            profile_name=raw["profile_name"],
            target_total_queries=raw["target_total_queries"],
            test_split=raw["test_split"],
            seed=raw["seed"],
            sources=raw["sources"],
            cascade=cascade,
            labeler=labeler,
            budget_minutes=raw.get("budget_minutes", 60.0),
            output_dir=raw.get("output_dir", "data"),
            capability_labels_cache=raw.get("capability_labels_cache"),
        )


@dataclass
class BuildPlan:
    """Outcome of plan(): the cost projection and any warnings for a config."""

    config: BuildConfig
    n_queries: int
    cascade_seconds: float
    cascade_minutes: float
    over_budget: bool
    notes: list[str] = field(default_factory=list)

    def report(self) -> str:
        """Render the plan as a newline-joined, human-readable summary."""
        cfg = self.config
        # round() rather than int(): 0.29 * 100 is 28.999999999999996 in
        # binary floating point, which int() would truncate to 28.
        test_percent = round(cfg.test_split * 100)
        lines = [
            f"Profile: {cfg.profile_name}",
            f"Target queries: {cfg.target_total_queries}",
            f"Test split: {test_percent}%",
            f"Sources: {', '.join(f'{k}={v}' for k, v in cfg.sources.items())}",
            f"Cascade rungs: {', '.join(r.id for r in cfg.cascade.rungs)}",
            f"k_samples per rung: {cfg.cascade.k_samples}",
            f"Max new tokens: {cfg.cascade.max_new_tokens}",
            f"Estimated cascade wall time: {self.cascade_minutes:.1f} min",
            f"Configured budget: {cfg.budget_minutes:.1f} min",
            f"Over budget: {self.over_budget}",
        ]
        if self.notes:
            lines.append("Notes:")
            lines.extend(f" - {note}" for note in self.notes)
        return "\n".join(lines)


def plan(config: BuildConfig) -> BuildPlan:
    """Project the cascade cost for ``config`` and collect warnings.

    Nothing is sampled or executed. Notes are added, in this order, for: a
    projection above ``budget_minutes``; each enabled remote voter whose API
    key environment variable is unset or empty; each source name that is not
    in ``SOURCE_REGISTRY``.

    Args:
        config: The build profile to evaluate.

    Returns:
        A ``BuildPlan`` carrying the projection and the notes.
    """
    notes: list[str] = []

    cascade_s = config.cascade.projected_seconds(config.target_total_queries)
    cascade_m = cascade_s / 60.0
    over_budget = cascade_m > config.budget_minutes
    if over_budget:
        notes.append(
            f"cascade projected {cascade_m:.1f} min exceeds budget {config.budget_minutes:.1f} min"
        )

    # (labeler flag, required environment variable, voter name in the note)
    key_requirements = (
        ("use_gpt", "OPENAI_API_KEY", "gpt"),
        ("use_claude", "ANTHROPIC_API_KEY", "claude"),
        ("use_gemini", "GOOGLE_API_KEY", "gemini"),
    )
    for flag, env_var, voter in key_requirements:
        if getattr(config.labeler, flag) and not os.environ.get(env_var):
            notes.append(f"{env_var} missing; {voter} vote will be skipped")

    for src in config.sources:
        if src not in SOURCE_REGISTRY:
            notes.append(f"unknown source '{src}' in mix")

    return BuildPlan(
        config=config,
        n_queries=config.target_total_queries,
        cascade_seconds=cascade_s,
        cascade_minutes=cascade_m,
        over_budget=over_budget,
        notes=notes,
    )


def write_capability_labels(path: str | Path, labels: list[CapabilityLabel]) -> None:
    """Write one parquet row per label (via ``to_record()``), creating parent dirs."""
    import pandas as pd

    df = pd.DataFrame([lbl.to_record() for lbl in labels])
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    df.to_parquet(path, index=False)


def read_capability_labels(path: str | Path) -> dict[str, dict[str, float]]:
    """Read a capability-label parquet file into ``{query_id: {capability: score}}``.

    Only columns whose name starts with ``cap_`` are read as scores; the prefix
    is stripped to form the capability key. If a ``query_id`` occurs more than
    once, the last row wins.

    Args:
        path: Parquet file containing a ``query_id`` column and ``cap_*`` columns.

    Returns:
        Mapping from query id to its capability-score dict.
    """
    import pandas as pd

    prefix = "cap_"
    df = pd.read_parquet(path)
    cap_cols = [c for c in df.columns if c.startswith(prefix)]
    return {
        row["query_id"]: {c[len(prefix):]: float(row[c]) for c in cap_cols}
        for _, row in df.iterrows()
    }


def write_raw_manifest(path: str | Path, queries: list[RawQuery]) -> None:
    """Write the queries as JSON Lines (one ``to_dict()`` per line), creating parent dirs."""
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for q in queries:
            f.write(json.dumps(q.to_dict()) + "\n")


def write_labeled_dataset(
    train_path: str | Path,
    test_path: str | Path,
    rows: list[LabeledQuery],
    test_split: float,
    seed: int,
) -> None:
    """Shuffle ``rows`` deterministically and write train/test parquet files.

    The test set holds ``int(len(rows) * test_split)`` rows but never fewer
    than one when ``rows`` is non-empty. Train rows keep their input order;
    test rows follow the iteration order of the selected index set.

    Args:
        train_path: Destination parquet file for the training rows.
        test_path: Destination parquet file for the test rows.
        rows: Labeled queries to split.
        test_split: Fraction of rows assigned to the test set.
        seed: Seed for the private ``random.Random`` that drives the shuffle.
    """
    import pandas as pd
    import random as _random

    rng = _random.Random(seed)
    indices = list(range(len(rows)))
    rng.shuffle(indices)

    n_test = max(1, int(len(rows) * test_split))
    test_idx = set(indices[:n_test])

    train_records = [row.to_record() for i, row in enumerate(rows) if i not in test_idx]
    test_records = [rows[i].to_record() for i in test_idx]

    # Create both parents: the two files may live in different directories.
    Path(train_path).parent.mkdir(parents=True, exist_ok=True)
    Path(test_path).parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(train_records).to_parquet(train_path, index=False)
    pd.DataFrame(test_records).to_parquet(test_path, index=False)


def build_seed_dataset(
    output_dir: str | Path,
    test_split: float = 0.15,
    seed: int = 42,
    suffix: str = "seed",
) -> tuple[Path, Path]:
    """Materialize the curated seed entries into train/test parquet files.

    Skips the cascade and the labeler: the seed entries already carry gold
    capability multi-labels, difficulty (in log_params), and length buckets.

    Args:
        output_dir: Directory that receives both parquet files.
        test_split: Fraction of entries written to the test file.
        seed: Seed for the train/test shuffle.
        suffix: Inserted into the file names ``train_<suffix>.parquet`` and
            ``test_<suffix>.parquet``.

    Returns:
        ``(train_path, test_path)``.
    """
    from greenrouting.data.seed_dataset import (
        SEED_QUERIES,
        difficulty_log_params_from_b,
        seed_capability_dict,
    )

    rows: list[LabeledQuery] = []
    for i, entry in enumerate(SEED_QUERIES):
        raw = RawQuery(
            id=f"seed-{i:04d}",
            text=entry.text,
            source="seed",
            source_category=entry.primary_category,
            has_grader=False,
            grader_metadata={},
        )
        rows.append(
            LabeledQuery(
                raw=raw,
                capabilities=seed_capability_dict(entry, CAPABILITY_KEYS),
                difficulty_log_params=difficulty_log_params_from_b(entry.difficulty_b),
                length_bucket=entry.length,
                cascade_results={"source": "seed_curated"},
            )
        )

    out = Path(output_dir)
    train_path = out / f"train_{suffix}.parquet"
    test_path = out / f"test_{suffix}.parquet"
    write_labeled_dataset(train_path, test_path, rows, test_split=test_split, seed=seed)
    return train_path, test_path


def sample_and_label(config: BuildConfig) -> tuple[list[RawQuery], list[CapabilityLabel]]:
    """Sample the configured source mix and attach a capability label to each query.

    Queries whose id is found in the optional labels cache reuse the cached
    capability scores (empty votes, ``aggregation_method="cached"``); only the
    remaining queries are sent to ``label_queries``.

    Args:
        config: Build profile supplying the sources, target size, seed,
            labeler settings and optional cache path.

    Returns:
        ``(queries, labels)`` with labels in query order. A query for which no
        label was produced is simply absent from ``labels``.
    """
    # Imported once per call instead of once per cached query.
    from greenrouting.data.schema import CapabilityVotes

    queries = sample_mix(config.sources, config.target_total_queries, config.seed)

    cached: dict[str, dict[str, float]] = {}
    cache_path = config.capability_labels_cache
    if cache_path and Path(cache_path).exists():
        cached = read_capability_labels(cache_path)

    new_queries = [q for q in queries if q.id not in cached]
    new_labels = label_queries(new_queries, config.labeler) if new_queries else []

    cached_labels: list[CapabilityLabel] = [
        CapabilityLabel(
            query_id=q.id,
            capabilities=cached[q.id],
            votes=CapabilityVotes(),
            aggregation_method="cached",
        )
        for q in queries
        if q.id in cached
    ]

    all_labels = new_labels + cached_labels
    by_id = {label.query_id: label for label in all_labels}
    aligned = [by_id[q.id] for q in queries if q.id in by_id]
    return queries, aligned