Spaces:
Sleeping
Sleeping
File size: 9,532 Bytes
"""Dataset orchestrator: source sampling -> capability labeling -> cascade plan -> parquet."""
from __future__ import annotations
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from greenrouting.data.capability_labeler import LabelerConfig, label_queries
from greenrouting.data.schema import CapabilityLabel, LabeledQuery, RawQuery
from greenrouting.data.sources import SOURCE_REGISTRY, sample_mix
from greenrouting.routing.registry import CAPABILITY_KEYS
@dataclass
class CascadeRungConfig:
    """Static description of one rung (one model) of the cascade."""

    # Rung identifier; BuildPlan.report() lists the rungs by this value.
    id: str
    # Model reference -- presumably a Hugging Face repo id (TODO confirm).
    hf_model: str
    # Model size -- presumably in billions of parameters (TODO confirm unit).
    params_b: float
    # Estimated decode throughput (tokens per second); read by
    # CascadeConfig.projected_seconds() to project wall time.
    decode_tokens_per_second_estimate: float
    # Defaults to True; how the flag is consumed is not visible in this module.
    runs_locally: bool = True
@dataclass
class CascadeConfig:
    """Sampling settings for the cascade plus the ordered list of its rungs."""

    rungs: list[CascadeRungConfig]
    # Multiplier applied to the query count when projecting inference volume.
    k_samples: int = 1
    # Token budget per generation used in the wall-time projection.
    max_new_tokens: int = 200
    # Temperatures are stored here; their consumers are outside this module.
    temperature_first: float = 0.0
    temperature_resample: float = 0.7

    def projected_seconds(self, n_queries: int) -> float:
        """Estimate the total decode time, in seconds, across every rung.

        Each rung is charged ``n_queries * k_samples`` generations of
        ``max_new_tokens`` tokens at its estimated decode rate, plus a flat
        0.4 seconds per generation.

        Args:
            n_queries: Number of queries pushed through the cascade.

        Returns:
            The summed estimate over all rungs; ``0.0`` when there are none.
        """
        generations = n_queries * self.k_samples
        elapsed = 0.0
        for rung in self.rungs:
            # Rates below 1 token/s are clamped so the division stays bounded.
            tokens_per_second = max(rung.decode_tokens_per_second_estimate, 1.0)
            elapsed += generations * self.max_new_tokens / tokens_per_second
            elapsed += generations * 0.4
        return elapsed
@dataclass
class BuildConfig:
    """Everything one dataset build needs: sampling mix, cascade and labeler."""

    profile_name: str
    target_total_queries: int
    # Fraction of rows held out for the test split.
    test_split: float
    seed: int
    # Source name -> numeric value from the YAML "sources" mapping.
    sources: dict[str, float]
    cascade: CascadeConfig
    labeler: LabelerConfig
    # Wall-time budget (minutes) that the projected cascade time is checked against.
    budget_minutes: float = 60.0
    output_dir: str = "data"
    # Optional path to a parquet file of previously computed capability labels.
    capability_labels_cache: Optional[str] = None

    @classmethod
    def from_yaml(cls, path: str | Path) -> "BuildConfig":
        """Build a config from a YAML profile file.

        The top-level keys ``profile_name``, ``target_total_queries``,
        ``test_split``, ``seed``, ``sources`` and ``cascade`` (with its
        ``rungs`` list) are required. Everything else falls back to the
        defaults spelled out below.

        Args:
            path: Location of the YAML file, read as UTF-8.

        Returns:
            The populated ``BuildConfig``.

        Raises:
            KeyError: If a required key is absent from the document.
        """
        import yaml

        with open(path, "r", encoding="utf-8") as handle:
            document = yaml.safe_load(handle)

        cascade_section = document["cascade"]
        rung_configs = [CascadeRungConfig(**spec) for spec in cascade_section["rungs"]]
        cascade_defaults = (
            ("k_samples", 1),
            ("max_new_tokens", 200),
            ("temperature_first", 0.0),
            ("temperature_resample", 0.7),
        )
        cascade = CascadeConfig(
            rungs=rung_configs,
            **{key: cascade_section.get(key, default) for key, default in cascade_defaults},
        )

        labeler_section = document.get("labeler", {})
        labeler_defaults = (
            ("use_heuristic", True),
            ("use_gpt", False),
            ("use_claude", False),
            ("use_gemini", False),
            ("source_prior_weight", 0.5),
            ("sleep_between_calls_s", 0.0),
        )
        labeler = LabelerConfig(
            **{key: labeler_section.get(key, default) for key, default in labeler_defaults}
        )

        required_keys = ("profile_name", "target_total_queries", "test_split", "seed", "sources")
        required = {key: document[key] for key in required_keys}
        return cls(
            **required,
            cascade=cascade,
            labeler=labeler,
            budget_minutes=document.get("budget_minutes", 60.0),
            output_dir=document.get("output_dir", "data"),
            capability_labels_cache=document.get("capability_labels_cache"),
        )
@dataclass
class BuildPlan:
    """Result of planning a build: the config plus the projected cascade cost."""

    config: BuildConfig
    n_queries: int
    cascade_seconds: float
    cascade_minutes: float
    over_budget: bool
    # Free-form warnings collected while planning.
    notes: list[str] = field(default_factory=list)

    def report(self) -> str:
        """Render the plan as a multi-line, human-readable summary.

        Returns:
            One line per setting, followed by a "Notes:" section when any
            notes were recorded. Lines are joined with newlines and there is
            no trailing newline.
        """
        cfg = self.config
        cascade = cfg.cascade
        source_mix = ", ".join(f"{name}={weight}" for name, weight in cfg.sources.items())
        rung_ids = ", ".join(rung.id for rung in cascade.rungs)
        # round(), not int(): binary floating point makes 0.29 * 100 equal to
        # 28.999999999999996, which int() would truncate to a misleading 28%.
        test_percent = round(cfg.test_split * 100)

        lines = [
            f"Profile: {cfg.profile_name}",
            f"Target queries: {cfg.target_total_queries}",
            f"Test split: {test_percent}%",
            f"Sources: {source_mix}",
            f"Cascade rungs: {rung_ids}",
            f"k_samples per rung: {cascade.k_samples}",
            f"Max new tokens: {cascade.max_new_tokens}",
            f"Estimated cascade wall time: {self.cascade_minutes:.1f} min",
            f"Configured budget: {cfg.budget_minutes:.1f} min",
            f"Over budget: {self.over_budget}",
        ]
        if self.notes:
            lines.append("Notes:")
            lines.extend(f" - {note}" for note in self.notes)
        return "\n".join(lines)
def plan(config: BuildConfig) -> BuildPlan:
    """Project the cascade cost for ``config`` and collect planning warnings.

    Warnings are appended in a fixed order: budget overrun, missing API keys
    for each enabled remote labeler (gpt, claude, gemini), then any source
    names that are not present in ``SOURCE_REGISTRY``.

    Args:
        config: The build configuration to evaluate.

    Returns:
        A ``BuildPlan`` carrying the projected time, the over-budget flag and
        the collected notes.
    """
    projected_seconds = config.cascade.projected_seconds(config.target_total_queries)
    projected_minutes = projected_seconds / 60.0
    exceeds_budget = projected_minutes > config.budget_minutes

    warnings: list[str] = []
    if exceeds_budget:
        warnings.append(
            f"cascade projected {projected_minutes:.1f} min exceeds budget {config.budget_minutes:.1f} min"
        )

    labeler = config.labeler
    key_checks = (
        (labeler.use_gpt, "OPENAI_API_KEY", "OPENAI_API_KEY missing; gpt vote will be skipped"),
        (labeler.use_claude, "ANTHROPIC_API_KEY", "ANTHROPIC_API_KEY missing; claude vote will be skipped"),
        (labeler.use_gemini, "GOOGLE_API_KEY", "GOOGLE_API_KEY missing; gemini vote will be skipped"),
    )
    for enabled, env_var, message in key_checks:
        # An unset or empty environment variable both count as missing.
        if enabled and not os.environ.get(env_var):
            warnings.append(message)

    warnings.extend(
        f"unknown source '{name}' in mix"
        for name in config.sources
        if name not in SOURCE_REGISTRY
    )

    return BuildPlan(
        config=config,
        n_queries=config.target_total_queries,
        cascade_seconds=projected_seconds,
        cascade_minutes=projected_minutes,
        over_budget=exceeds_budget,
        notes=warnings,
    )
def write_capability_labels(path: str | Path, labels: list[CapabilityLabel]) -> None:
    """Write capability labels to a parquet file.

    Each label contributes one row via its ``to_record()`` method. Missing
    parent directories of ``path`` are created first.

    Args:
        path: Destination parquet file.
        labels: Labels to serialize.
    """
    import pandas as pd

    records = [label.to_record() for label in labels]
    frame = pd.DataFrame(records)
    destination_dir = Path(path).parent
    destination_dir.mkdir(parents=True, exist_ok=True)
    frame.to_parquet(path, index=False)
def read_capability_labels(path: str | Path) -> dict[str, dict[str, float]]:
    """Load capability scores from a parquet file, keyed by query id.

    Every column whose name starts with ``cap_`` is treated as a capability
    score; the prefix is stripped to form the inner key and the value is
    coerced to ``float``.

    Args:
        path: Parquet file to read. It must contain a ``query_id`` column.

    Returns:
        Mapping of ``query_id`` to ``{capability_name: score}``. If a query id
        occurs more than once, the last row wins.
    """
    import pandas as pd

    prefix = "cap_"
    frame = pd.read_parquet(path)
    capability_columns = [name for name in frame.columns if name.startswith(prefix)]

    scores_by_query: dict[str, dict[str, float]] = {}
    for _, record in frame.iterrows():
        scores = {}
        for column in capability_columns:
            scores[column[len(prefix):]] = float(record[column])
        scores_by_query[record["query_id"]] = scores
    return scores_by_query
def write_raw_manifest(path: str | Path, queries: list[RawQuery]) -> None:
    """Dump the raw queries as JSON Lines (one JSON object per line).

    Missing parent directories of ``path`` are created, and an existing file
    is overwritten.

    Args:
        path: Destination file, written as UTF-8.
        queries: Objects exposing ``to_dict()``; each result becomes one line.
    """
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as manifest:
        manifest.writelines(json.dumps(query.to_dict()) + "\n" for query in queries)
def write_labeled_dataset(
    train_path: str | Path,
    test_path: str | Path,
    rows: list[LabeledQuery],
    test_split: float,
    seed: int,
) -> None:
    """Split the labeled rows into train/test sets and write both as parquet.

    Row indices are shuffled with a ``random.Random(seed)`` instance and the
    first ``max(1, int(len(rows) * test_split))`` shuffled indices form the
    test set. Train rows keep their original relative order.

    Args:
        train_path: Destination parquet file for the training rows.
        test_path: Destination parquet file for the held-out rows.
        rows: Labeled queries; each is serialized through ``to_record()``.
        test_split: Fraction of rows to hold out.
        seed: Seed for the shuffle, so the split is reproducible.
    """
    import pandas as pd
    import random as _random

    shuffled = list(range(len(rows)))
    _random.Random(seed).shuffle(shuffled)

    # At least one row is held out whenever there are rows, even for tiny splits.
    n_test = max(1, int(len(rows) * test_split))
    test_idx = set(shuffled[:n_test])

    train_records = [row.to_record() for i, row in enumerate(rows) if i not in test_idx]
    # Iterates the set directly, keeping the same record order as before.
    test_records = [rows[i].to_record() for i in test_idx]

    # Create the parent directory of BOTH outputs: the two files may live in
    # different directories, and previously only the train directory was made.
    for target in (train_path, test_path):
        Path(target).parent.mkdir(parents=True, exist_ok=True)

    pd.DataFrame(train_records).to_parquet(train_path, index=False)
    pd.DataFrame(test_records).to_parquet(test_path, index=False)
def build_seed_dataset(
    output_dir: str | Path,
    test_split: float = 0.15,
    seed: int = 42,
    suffix: str = "seed",
) -> tuple[Path, Path]:
    """Write the curated seed queries out as train/test parquet files.

    Neither the cascade nor the labeler is run here: every seed entry already
    provides its capability labels, a difficulty (converted to log-params)
    and a length bucket.

    Args:
        output_dir: Directory that receives both parquet files.
        test_split: Fraction of entries held out for the test file.
        seed: Seed controlling the train/test shuffle.
        suffix: Inserted into the file names (``train_<suffix>.parquet`` and
            ``test_<suffix>.parquet``).

    Returns:
        ``(train_path, test_path)`` of the files that were written.
    """
    from greenrouting.data.seed_dataset import (
        SEED_QUERIES,
        difficulty_log_params_from_b,
        seed_capability_dict,
    )

    labeled_rows: list[LabeledQuery] = [
        LabeledQuery(
            raw=RawQuery(
                id=f"seed-{index:04d}",
                text=seed_entry.text,
                source="seed",
                source_category=seed_entry.primary_category,
                has_grader=False,
                grader_metadata={},
            ),
            capabilities=seed_capability_dict(seed_entry, CAPABILITY_KEYS),
            difficulty_log_params=difficulty_log_params_from_b(seed_entry.difficulty_b),
            length_bucket=seed_entry.length,
            cascade_results={"source": "seed_curated"},
        )
        for index, seed_entry in enumerate(SEED_QUERIES)
    ]

    base_dir = Path(output_dir)
    train_path = base_dir / f"train_{suffix}.parquet"
    test_path = base_dir / f"test_{suffix}.parquet"
    write_labeled_dataset(train_path, test_path, labeled_rows, test_split=test_split, seed=seed)
    return train_path, test_path
def sample_and_label(config: BuildConfig) -> tuple[list[RawQuery], list[CapabilityLabel]]:
    """Sample the query mix and attach a capability label to each query.

    Labels found in the optional parquet cache are reused and wrapped with
    ``aggregation_method="cached"`` and empty votes. Only queries missing
    from the cache are sent to the labeler.

    Args:
        config: Build configuration supplying the source mix, target size,
            seed, labeler settings and optional cache path.

    Returns:
        ``(queries, labels)`` where ``labels`` follows the order of
        ``queries``. A query for which no label exists is skipped, so
        ``labels`` can be shorter than ``queries``.
    """
    # Imported once up front rather than on every cache hit inside the loop.
    from greenrouting.data.schema import CapabilityVotes

    queries = sample_mix(config.sources, config.target_total_queries, config.seed)

    cached: dict[str, dict[str, float]] = {}
    cache_path = config.capability_labels_cache
    if cache_path and Path(cache_path).exists():
        cached = read_capability_labels(cache_path)

    uncached_queries = [query for query in queries if query.id not in cached]
    # Skip the labeler call entirely when the cache covers every query.
    fresh_labels = label_queries(uncached_queries, config.labeler) if uncached_queries else []

    cached_labels: list[CapabilityLabel] = [
        CapabilityLabel(
            query_id=query.id,
            capabilities=cached[query.id],
            votes=CapabilityVotes(),
            aggregation_method="cached",
        )
        for query in queries
        if query.id in cached
    ]

    labels_by_id = {label.query_id: label for label in fresh_labels + cached_labels}
    aligned = [labels_by_id[query.id] for query in queries if query.id in labels_by_id]
    return queries, aligned
|