# router-api / greenrouting/data/cascade.py
# spectralman — "Initial deploy: classifier + FastAPI router" (commit 6f0ff99, verified)
"""Difficulty cascade: runs each query against an ascending ladder of models,
grades the response, and derives a continuous `min_capable_log_params` label.
Memory strategy: load one rung at a time, run all queries, dump checkpoint, free
the weights, then advance to the next rung. Resumes from per-rung JSONL files.
"""
from __future__ import annotations
import json
import math
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Iterable, Optional
from greenrouting.data.graders import grade
from greenrouting.data.schema import LabeledQuery, RawQuery
@dataclass
class RungResult:
    """One graded sample produced by one cascade rung for one query.

    Rows of the per-rung JSONL checkpoints are written with ``asdict`` and
    read back with ``RungResult(**row)``, so the field names are the on-disk
    keys.
    """

    rung_id: str  # id of the rung (model) that produced the response
    params_b: float  # rung size; derive_difficulty multiplies it by 1e9, so presumably billions
    query_id: str  # id of the RawQuery this sample answers
    sample_index: int  # 0 is generated with temperature_first, later indices with temperature_resample
    response: str  # decoded, whitespace-stripped model output
    score: float  # value returned by the grader for this response
    response_tokens: int  # number of newly generated token ids
def _read_raw_manifest(path: str | Path) -> list[RawQuery]:
    """Parse a JSONL manifest into ``RawQuery`` objects.

    Each non-blank line must be a JSON object whose keys match the
    ``RawQuery`` constructor; blank or whitespace-only lines are ignored.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return [RawQuery(**json.loads(raw)) for raw in handle if raw.strip()]
def _read_rung_checkpoint(path: Path) -> dict[str, list[RungResult]]:
if not path.exists():
return {}
out: dict[str, list[RungResult]] = {}
with open(path, "r", encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
row = json.loads(line)
r = RungResult(**row)
out.setdefault(r.query_id, []).append(r)
return out
def _append_rung_checkpoint(path: Path, result: RungResult) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "a", encoding="utf-8") as f:
f.write(json.dumps(asdict(result)) + "\n")
def _load_model_and_tokenizer(hf_model: str):
    """Load the tokenizer and causal LM for ``hf_model``, ready for inference.

    With CUDA available the weights are loaded in bfloat16 with
    ``device_map="auto"``; otherwise they are loaded in float32 with no
    device map. A tokenizer lacking a pad id gets its EOS id as the pad id.

    Returns:
        ``(tokenizer, model)`` with the model switched to eval mode.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    use_gpu = torch.cuda.is_available()
    tokenizer = AutoTokenizer.from_pretrained(hf_model)
    if tokenizer.pad_token_id is None:
        # generation passes tok.pad_token_id explicitly, so make sure it is set
        tokenizer.pad_token_id = tokenizer.eos_token_id
    lm = AutoModelForCausalLM.from_pretrained(
        hf_model,
        torch_dtype=torch.bfloat16 if use_gpu else torch.float32,
        device_map="auto" if use_gpu else None,
    )
    lm.eval()
    return tokenizer, lm
def _free_model(model) -> None:
import gc
del model
gc.collect()
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
except Exception:
pass
def _format_prompt(tok, query: str) -> str:
if hasattr(tok, "apply_chat_template") and tok.chat_template:
messages = [{"role": "user", "content": query}]
return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
return f"### Instruction:\n{query}\n\n### Response:\n"
def _generate(tok, model, prompt: str, max_new_tokens: int, temperature: float) -> tuple[str, int]:
import torch
inputs = tok(prompt, return_tensors="pt").to(model.device)
do_sample = temperature > 0
with torch.no_grad():
out = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature if do_sample else 1.0,
pad_token_id=tok.pad_token_id,
)
new_tokens = out[0][inputs["input_ids"].shape[1]:]
response = tok.decode(new_tokens, skip_special_tokens=True)
return response.strip(), int(new_tokens.shape[0])
def run_rung(
    rung,
    queries: list[RawQuery],
    k_samples: int,
    max_new_tokens: int,
    temperature_first: float,
    temperature_resample: float,
    checkpoint_path: Path,
    progress: bool = True,
) -> list[RungResult]:
    """Generate and grade up to ``k_samples`` responses per query on one rung.

    Samples already present in ``checkpoint_path`` are reused. Only the
    missing sample indices are generated, and each new sample is appended
    to the checkpoint as soon as it is graded, so an interrupted run can
    resume. The model is loaded only when there is pending work, and its
    memory is released before returning.

    Args:
        rung: Rung config; ``id``, ``params_b`` and ``hf_model`` are read.
        queries: Queries to evaluate on this rung.
        k_samples: Target number of samples per query.
        max_new_tokens: Generation budget, also forwarded to the grader.
        temperature_first: Temperature for sample index 0.
        temperature_resample: Temperature for sample indices >= 1.
        checkpoint_path: Per-rung JSONL checkpoint to read and append to.
        progress: Print one status line per generated sample when True.

    Returns:
        Every checkpointed result (for all query ids in the file, not only
        those in ``queries``) followed by the newly generated ones.
    """
    existing = _read_rung_checkpoint(checkpoint_path)
    pending = [q for q in queries if len(existing.get(q.id, [])) < k_samples]
    results: list[RungResult] = [r for rs in existing.values() for r in rs]
    if not pending:
        return results

    tok, model = _load_model_and_tokenizer(rung.hf_model)
    try:
        for i, q in enumerate(pending):
            # The prompt depends only on the query, so render it once
            # rather than once per sample.
            prompt = _format_prompt(tok, q.text)
            done = len(existing.get(q.id, []))
            for s in range(done, k_samples):
                temp = temperature_first if s == 0 else temperature_resample
                start = time.perf_counter()
                response, n_tokens = _generate(tok, model, prompt, max_new_tokens, temp)
                score = grade(response, q.grader_metadata, max_new_tokens=max_new_tokens)
                rr = RungResult(
                    rung_id=rung.id,
                    params_b=rung.params_b,
                    query_id=q.id,
                    sample_index=s,
                    response=response,
                    score=score,
                    response_tokens=n_tokens,
                )
                _append_rung_checkpoint(checkpoint_path, rr)
                results.append(rr)
                if progress:
                    print(
                        f"  [{rung.id}] {i+1}/{len(pending)} sample={s} "
                        f"score={score:.2f} tok={n_tokens} t={time.perf_counter()-start:.1f}s"
                    )
    finally:
        # Drop this frame's references *before* collecting. _free_model's
        # own `del` only unbinds its parameter. While `model` was still
        # bound here, its gc.collect()/empty_cache() ran with the weights
        # still reachable and could not reclaim them.
        del tok, model
        _free_model(None)
    return results
def derive_difficulty(
    per_rung: dict[str, list[float]],
    rung_params_b: dict[str, float],
    pass_threshold: float,
) -> float:
    """Turn per-rung sample scores into a continuous ``min_capable_log_params``.

    Rungs are ordered by ``rung_params_b`` (multiplied by 1e9 before taking
    logs). A rung's pass rate is the mean of its scores, and rungs with no
    scores are ignored.

    - No rungs configured: ``log(8e9)``.
    - No rung has scores: ``log(2 * largest configured rung)``.
    - Smallest scored rung passes (mean >= ``pass_threshold``): ``log`` of
      its size.
    - Otherwise the first passing rung and the scored rung just below it
      bracket the answer. The result is linearly interpolated in log-params
      space by where ``pass_threshold`` falls between their pass rates,
      clamped to the bracket.
    - No scored rung passes: ``log(2 * largest scored rung)`` (out of pool).
    """
    ladder = sorted(rung_params_b.items(), key=lambda kv: kv[1])
    if not ladder:
        return math.log(8e9)

    # (rung_id, size in billions, mean score), smallest rung first.
    scored = [
        (rung_id, size_b, sum(per_rung[rung_id]) / len(per_rung[rung_id]))
        for rung_id, size_b in ladder
        if per_rung.get(rung_id)
    ]
    if not scored:
        return math.log(ladder[-1][1] * 1e9 * 2)

    _, first_size, first_rate = scored[0]
    if first_rate >= pass_threshold:
        return math.log(first_size * 1e9)

    for (_, lo_size, lo_rate), (_, hi_size, hi_rate) in zip(scored, scored[1:]):
        if hi_rate < pass_threshold:
            continue
        # lo_rate is known to fail here, so the bracket is [lo, hi].
        frac = max(0.0, min(1.0, (pass_threshold - lo_rate) / max(hi_rate - lo_rate, 1e-6)))
        log_lo = math.log(lo_size * 1e9)
        log_hi = math.log(hi_size * 1e9)
        return log_lo + frac * (log_hi - log_lo)

    return math.log(scored[-1][1] * 1e9 * 2)
def derive_length_bucket(response_token_counts: list[int]) -> str:
    """Bucket the mean response length (in tokens) into a coarse label.

    Mean < 100 gives ``"short"``, mean < 400 gives ``"medium"``, anything
    longer gives ``"long"``. An empty list defaults to ``"medium"``.
    """
    if not response_token_counts:
        return "medium"
    mean_tokens = sum(response_token_counts) / len(response_token_counts)
    for upper_bound, label in ((100, "short"), (400, "medium")):
        if mean_tokens < upper_bound:
            return label
    return "long"
def _run_local_rungs(config, queries: list[RawQuery]) -> dict[str, dict[str, list[RungResult]]]:
    """Run each locally-runnable rung and group its results by query id.

    Rungs with ``runs_locally`` false are reported and skipped. Returns
    ``{rung_id: {query_id: [RungResult, ...]}}`` for the rungs that ran.
    """
    out_dir = Path(config.output_dir)
    grouped: dict[str, dict[str, list[RungResult]]] = {}
    for rung in config.cascade.rungs:
        if not rung.runs_locally:
            print(f"[skip] {rung.id} marked as not runs_locally; configure remote backend.")
            continue
        checkpoint = out_dir / f"cascade_{config.profile_name}_{rung.id}.jsonl"
        rung_results = run_rung(
            rung,
            queries,
            k_samples=config.cascade.k_samples,
            max_new_tokens=config.cascade.max_new_tokens,
            temperature_first=config.cascade.temperature_first,
            temperature_resample=config.cascade.temperature_resample,
            checkpoint_path=checkpoint,
        )
        per_query: dict[str, list[RungResult]] = {}
        for result in rung_results:
            per_query.setdefault(result.query_id, []).append(result)
        grouped[rung.id] = per_query
        n_samples = sum(map(len, per_query.values()))
        print(f"[done] rung {rung.id}: {n_samples} samples")
    return grouped


def _label_query(q, per_rung_results, rung_params, cap_labels, pass_threshold) -> LabeledQuery | None:
    """Build the ``LabeledQuery`` for ``q``, or None if no rung sampled it.

    ``per_rung_results`` is ``{rung_id: {query_id: [RungResult, ...]}}``.
    ``cap_labels`` is looked up with ``.get(q.id, {})``.
    """
    scores_by_rung: dict[str, list[float]] = {}
    lengths: list[int] = []
    for rung_id, per_query in per_rung_results.items():
        for rr in per_query.get(q.id, []):
            scores_by_rung.setdefault(rung_id, []).append(rr.score)
            lengths.append(rr.response_tokens)
    if not scores_by_rung:
        return None
    mean_scores = {rung_id: sum(s) / len(s) for rung_id, s in scores_by_rung.items()}
    return LabeledQuery(
        raw=q,
        capabilities=cap_labels.get(q.id, {}),
        difficulty_log_params=derive_difficulty(scores_by_rung, rung_params, pass_threshold),
        length_bucket=derive_length_bucket(lengths),
        cascade_results={"per_rung_mean_scores": mean_scores},
    )


def run_cascade(
    config,
    raw_manifest_path: str | Path,
    capability_labels_path: str | Path,
    train_path: str | Path,
    test_path: str | Path,
    pass_threshold: float = 0.7,
) -> None:
    """Run the full difficulty cascade and write the labeled train/test split.

    Reads the raw manifest and capability labels, then runs every local
    rung (with per-rung checkpoints under ``config.output_dir``). Each query
    with at least one sample is labeled with a difficulty and a length
    bucket. The rows are handed to ``write_labeled_dataset`` using
    ``config.test_split`` and ``config.seed``.
    """
    from greenrouting.data.builder import (
        read_capability_labels,
        write_labeled_dataset,
    )

    queries = _read_raw_manifest(raw_manifest_path)
    cap_labels = read_capability_labels(capability_labels_path)
    per_rung_results = _run_local_rungs(config, queries)

    # Includes non-local rungs too; derive_difficulty ignores rungs without scores.
    rung_params = {r.id: r.params_b for r in config.cascade.rungs}
    labeled: list[LabeledQuery] = []
    for q in queries:
        row = _label_query(q, per_rung_results, rung_params, cap_labels, pass_threshold)
        if row is not None:
            labeled.append(row)

    write_labeled_dataset(
        train_path=train_path,
        test_path=test_path,
        rows=labeled,
        test_split=config.test_split,
        seed=config.seed,
    )
    print(f"[done] wrote {len(labeled)} labeled rows -> {train_path}, {test_path}")