Spaces:
Sleeping
Sleeping
| """Dataset orchestrator: source sampling -> capability labeling -> cascade plan -> parquet.""" | |
| from __future__ import annotations | |
| import json | |
| import os | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Optional | |
| from greenrouting.data.capability_labeler import LabelerConfig, label_queries | |
| from greenrouting.data.schema import CapabilityLabel, LabeledQuery, RawQuery | |
| from greenrouting.data.sources import SOURCE_REGISTRY, sample_mix | |
| from greenrouting.routing.registry import CAPABILITY_KEYS | |
@dataclass
class CascadeRungConfig:
    """Settings for a single rung (one model) of the cascade.

    Declared as a dataclass so it can be built from keyword arguments, which
    is how ``BuildConfig.from_yaml`` instantiates it from each YAML rung
    entry (``CascadeRungConfig(**r)``).
    """

    id: str  # short identifier; shown in the build plan report
    hf_model: str  # presumably a Hugging Face model name -- confirm with the cascade runner
    params_b: float  # presumably parameter count in billions -- confirm
    decode_tokens_per_second_estimate: float  # throughput estimate used for wall-time projection
    runs_locally: bool = True
@dataclass
class CascadeConfig:
    """Cascade-wide generation settings plus the ordered list of rungs.

    Declared as a dataclass so it can be constructed from keyword arguments.
    """

    rungs: list[CascadeRungConfig]
    k_samples: int = 1
    max_new_tokens: int = 200
    temperature_first: float = 0.0
    temperature_resample: float = 0.7

    def projected_seconds(self, n_queries: int) -> float:
        """Estimate the total wall time, in seconds, of running every rung.

        Each rung is charged ``n_queries * k_samples`` inferences. One
        inference costs ``max_new_tokens`` decoded at the rung's estimated
        tokens/second (clamped to at least 1.0 so a zero or missing estimate
        cannot divide by zero), plus a flat 0.4 s per inference.

        Args:
            n_queries: Number of queries pushed through the cascade.

        Returns:
            Projected seconds summed over all rungs; ``0.0`` with no rungs.
        """
        # The inference count does not depend on the rung, so compute it once.
        inferences = n_queries * self.k_samples
        total = 0.0
        for rung in self.rungs:
            tokens_per_second = max(rung.decode_tokens_per_second_estimate, 1.0)
            total += inferences * self.max_new_tokens / tokens_per_second
            total += inferences * 0.4
        return total
@dataclass
class BuildConfig:
    """Top-level settings for one dataset build.

    Declared as a dataclass so ``from_yaml`` (and callers) can construct it
    from keyword arguments.
    """

    profile_name: str
    target_total_queries: int
    test_split: float
    seed: int
    sources: dict[str, float]  # source name -> value passed to sample_mix (presumably a mix weight)
    cascade: CascadeConfig
    labeler: LabelerConfig
    budget_minutes: float = 60.0
    output_dir: str = "data"
    capability_labels_cache: Optional[str] = None

    @classmethod
    def from_yaml(cls, path: str | Path) -> "BuildConfig":
        """Load a build configuration from a YAML file.

        Required top-level keys: ``profile_name``, ``target_total_queries``,
        ``test_split``, ``seed``, ``sources`` and ``cascade`` (which must
        contain ``rungs``). The ``labeler`` section and the remaining keys
        are optional and fall back to the defaults spelled out below.

        Args:
            path: Location of the YAML file.

        Returns:
            A populated ``BuildConfig``.

        Raises:
            ValueError: If the file does not contain a mapping at top level
                (for example, an empty file).
            KeyError: If a required key is missing.
        """
        # Function-scope import: PyYAML is only needed when loading from file.
        import yaml

        with open(path, "r", encoding="utf-8") as f:
            raw = yaml.safe_load(f)
        if not isinstance(raw, dict):
            # safe_load returns None for an empty document; fail with a clear
            # message instead of an opaque subscript error further down.
            raise ValueError(
                f"{path}: expected a YAML mapping at the top level, "
                f"got {type(raw).__name__}"
            )

        cascade_raw = raw["cascade"]
        rungs = [CascadeRungConfig(**r) for r in cascade_raw["rungs"]]
        cascade = CascadeConfig(
            rungs=rungs,
            k_samples=cascade_raw.get("k_samples", 1),
            max_new_tokens=cascade_raw.get("max_new_tokens", 200),
            temperature_first=cascade_raw.get("temperature_first", 0.0),
            temperature_resample=cascade_raw.get("temperature_resample", 0.7),
        )

        # "labeler:" present with no body parses to None; treat it as empty.
        labeler_raw = raw.get("labeler") or {}
        labeler = LabelerConfig(
            use_heuristic=labeler_raw.get("use_heuristic", True),
            use_gpt=labeler_raw.get("use_gpt", False),
            use_claude=labeler_raw.get("use_claude", False),
            use_gemini=labeler_raw.get("use_gemini", False),
            source_prior_weight=labeler_raw.get("source_prior_weight", 0.5),
            sleep_between_calls_s=labeler_raw.get("sleep_between_calls_s", 0.0),
        )

        return cls(
            profile_name=raw["profile_name"],
            target_total_queries=raw["target_total_queries"],
            test_split=raw["test_split"],
            seed=raw["seed"],
            sources=raw["sources"],
            cascade=cascade,
            labeler=labeler,
            budget_minutes=raw.get("budget_minutes", 60.0),
            output_dir=raw.get("output_dir", "data"),
            capability_labels_cache=raw.get("capability_labels_cache"),
        )
@dataclass
class BuildPlan:
    """Result of a dry-run planning pass over a ``BuildConfig``.

    Declared as a dataclass: the ``notes`` field relies on
    ``field(default_factory=list)`` and ``plan`` builds instances with
    keyword arguments.
    """

    config: BuildConfig
    n_queries: int
    cascade_seconds: float
    cascade_minutes: float
    over_budget: bool
    notes: list[str] = field(default_factory=list)  # fresh list per instance

    def report(self) -> str:
        """Render the plan as a newline-joined, human-readable summary.

        Returns:
            One line per setting, followed by a ``Notes:`` section when any
            notes were recorded.
        """
        cfg = self.config
        source_mix = ", ".join(f"{k}={v}" for k, v in cfg.sources.items())
        rung_ids = ", ".join(r.id for r in cfg.cascade.rungs)
        lines = [
            f"Profile: {cfg.profile_name}",
            f"Target queries: {cfg.target_total_queries}",
            f"Test split: {int(cfg.test_split * 100)}%",
            f"Sources: {source_mix}",
            f"Cascade rungs: {rung_ids}",
            f"k_samples per rung: {cfg.cascade.k_samples}",
            f"Max new tokens: {cfg.cascade.max_new_tokens}",
            f"Estimated cascade wall time: {self.cascade_minutes:.1f} min",
            f"Configured budget: {cfg.budget_minutes:.1f} min",
            f"Over budget: {self.over_budget}",
        ]
        if self.notes:
            lines.append("Notes:")
            lines.extend(f" - {note}" for note in self.notes)
        return "\n".join(lines)
def plan(config: BuildConfig) -> BuildPlan:
    """Dry-run a build: project cascade time and collect configuration warnings.

    Notes are appended in a fixed order: budget overrun first, then missing
    API keys for each enabled remote labeler (gpt, claude, gemini), then any
    source names that are absent from ``SOURCE_REGISTRY``.

    Args:
        config: The build configuration to evaluate.

    Returns:
        A ``BuildPlan`` carrying the projection and the collected notes.
    """
    projected_s = config.cascade.projected_seconds(config.target_total_queries)
    projected_m = projected_s / 60.0
    exceeds_budget = projected_m > config.budget_minutes

    warnings: list[str] = []
    if exceeds_budget:
        warnings.append(
            f"cascade projected {projected_m:.1f} min exceeds budget {config.budget_minutes:.1f} min"
        )

    # (labeler enabled?, required environment variable, note when it is unset)
    labeler = config.labeler
    key_checks = (
        (labeler.use_gpt, "OPENAI_API_KEY", "OPENAI_API_KEY missing; gpt vote will be skipped"),
        (labeler.use_claude, "ANTHROPIC_API_KEY", "ANTHROPIC_API_KEY missing; claude vote will be skipped"),
        (labeler.use_gemini, "GOOGLE_API_KEY", "GOOGLE_API_KEY missing; gemini vote will be skipped"),
    )
    for enabled, env_var, message in key_checks:
        if enabled and not os.environ.get(env_var):
            warnings.append(message)

    warnings.extend(
        f"unknown source '{name}' in mix"
        for name in config.sources
        if name not in SOURCE_REGISTRY
    )

    return BuildPlan(
        config=config,
        n_queries=config.target_total_queries,
        cascade_seconds=projected_s,
        cascade_minutes=projected_m,
        over_budget=exceeds_budget,
        notes=warnings,
    )
def write_capability_labels(path: str | Path, labels: list[CapabilityLabel]) -> None:
    """Write capability labels to a parquet file, one row per label.

    Each label contributes the mapping returned by its ``to_record()``.
    Missing parent directories of ``path`` are created.

    Args:
        path: Destination parquet file.
        labels: Labels to serialize.
    """
    import pandas as pd

    records = [label.to_record() for label in labels]
    frame = pd.DataFrame(records)
    destination_dir = Path(path).parent
    destination_dir.mkdir(parents=True, exist_ok=True)
    frame.to_parquet(path, index=False)
def read_capability_labels(path: str | Path) -> dict[str, dict[str, float]]:
    """Load capability labels from parquet, keyed by query id.

    Columns whose name starts with ``cap_`` are treated as capability scores;
    the prefix is stripped to form the inner keys and values are coerced to
    ``float``. If a ``query_id`` appears more than once, the last row wins.

    Args:
        path: Parquet file to read; it must contain a ``query_id`` column.

    Returns:
        Mapping of query id to ``{capability: score}``.
    """
    import pandas as pd

    prefix = "cap_"
    frame = pd.read_parquet(path)
    score_columns = [name for name in frame.columns if name.startswith(prefix)]
    return {
        record["query_id"]: {
            name[len(prefix):]: float(record[name]) for name in score_columns
        }
        for _, record in frame.iterrows()
    }
def write_raw_manifest(path: str | Path, queries: list[RawQuery]) -> None:
    """Write the sampled queries as a JSON Lines manifest.

    One line is emitted per query: the JSON encoding of its ``to_dict()``.
    Missing parent directories of ``path`` are created and any existing file
    is overwritten.

    Args:
        path: Destination file.
        queries: Queries to serialize, written in the given order.
    """
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as manifest:
        for query in queries:
            encoded = json.dumps(query.to_dict())
            manifest.write(encoded + "\n")
def write_labeled_dataset(
    train_path: str | Path,
    test_path: str | Path,
    rows: list[LabeledQuery],
    test_split: float,
    seed: int,
) -> None:
    """Split labeled rows into train/test sets and write each as parquet.

    Membership of the test set is chosen by a seeded shuffle of the row
    indices, so the same ``seed`` and row count always select the same rows.
    At least one row is assigned to the test set whenever ``rows`` is
    non-empty. Both output files list their rows in the original input order.

    Args:
        train_path: Destination parquet file for the training rows.
        test_path: Destination parquet file for the test rows.
        rows: Labeled queries; each is serialized via ``to_record()``.
        test_split: Fraction of rows to place in the test set.
        seed: Seed for the shuffle that selects test rows.
    """
    import random as _random

    import pandas as pd

    order = list(range(len(rows)))
    _random.Random(seed).shuffle(order)
    n_test = max(1, int(len(rows) * test_split))
    test_idx = set(order[:n_test])

    # Walk rows in input order for both splits so the output does not depend
    # on set iteration order, which is an implementation detail.
    train_records = [row.to_record() for i, row in enumerate(rows) if i not in test_idx]
    test_records = [row.to_record() for i, row in enumerate(rows) if i in test_idx]

    # The two files may live in different directories; make sure both exist.
    for target in (train_path, test_path):
        Path(target).parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(train_records).to_parquet(train_path, index=False)
    pd.DataFrame(test_records).to_parquet(test_path, index=False)
def build_seed_dataset(
    output_dir: str | Path,
    test_split: float = 0.15,
    seed: int = 42,
    suffix: str = "seed",
) -> tuple[Path, Path]:
    """Write the curated seed queries out as train/test parquet files.

    Neither the cascade nor the labeler is invoked here: every label is
    derived directly from the seed entry (capabilities via
    ``seed_capability_dict``, difficulty via ``difficulty_log_params_from_b``,
    and the entry's own ``length``).

    Args:
        output_dir: Directory that receives both parquet files.
        test_split: Fraction of rows routed to the test file.
        seed: Seed for the train/test shuffle.
        suffix: Inserted into the file names, ``train_<suffix>.parquet`` and
            ``test_<suffix>.parquet``.

    Returns:
        ``(train_path, test_path)``.
    """
    from greenrouting.data.seed_dataset import (
        SEED_QUERIES,
        difficulty_log_params_from_b,
        seed_capability_dict,
    )

    def _as_labeled(index: int, entry) -> LabeledQuery:
        # Seed rows get stable, zero-padded ids and carry no grader.
        query = RawQuery(
            id=f"seed-{index:04d}",
            text=entry.text,
            source="seed",
            source_category=entry.primary_category,
            has_grader=False,
            grader_metadata={},
        )
        return LabeledQuery(
            raw=query,
            capabilities=seed_capability_dict(entry, CAPABILITY_KEYS),
            difficulty_log_params=difficulty_log_params_from_b(entry.difficulty_b),
            length_bucket=entry.length,
            cascade_results={"source": "seed_curated"},
        )

    labeled = [_as_labeled(i, entry) for i, entry in enumerate(SEED_QUERIES)]

    base = Path(output_dir)
    train_path = base / f"train_{suffix}.parquet"
    test_path = base / f"test_{suffix}.parquet"
    write_labeled_dataset(train_path, test_path, labeled, test_split=test_split, seed=seed)
    return train_path, test_path
def sample_and_label(config: BuildConfig) -> tuple[list[RawQuery], list[CapabilityLabel]]:
    """Sample the configured query mix and attach a capability label to each.

    Queries whose id is present in the optional labels cache reuse the cached
    capabilities (marked ``aggregation_method="cached"`` with empty votes);
    only the remaining queries are sent to the labeler.

    Args:
        config: Build configuration supplying the source mix, target size,
            seed, labeler settings and optional cache path.

    Returns:
        ``(queries, labels)`` where ``labels`` follows the order of
        ``queries``. A query for which no label exists is omitted from
        ``labels``, so the two lists may differ in length.
    """
    # Imported once up front rather than on every iteration of the loop below.
    from greenrouting.data.schema import CapabilityVotes

    queries = sample_mix(config.sources, config.target_total_queries, config.seed)

    cached: dict[str, dict[str, float]] = {}
    cache_path = config.capability_labels_cache
    if cache_path and Path(cache_path).exists():
        cached = read_capability_labels(cache_path)

    new_queries = [q for q in queries if q.id not in cached]
    new_labels = label_queries(new_queries, config.labeler) if new_queries else []

    # Fresh labels go in first; cached entries are written afterwards so they
    # take precedence if the same id ever shows up in both.
    by_id = {label.query_id: label for label in new_labels}
    for q in queries:
        if q.id in cached:
            by_id[q.id] = CapabilityLabel(
                query_id=q.id,
                capabilities=cached[q.id],
                votes=CapabilityVotes(),
                aggregation_method="cached",
            )

    aligned = [by_id[q.id] for q in queries if q.id in by_id]
    return queries, aligned