"""Difficulty cascade: runs each query against an ascending ladder of models, grades the response, and derives a continuous `min_capable_log_params` label. Memory strategy: load one rung at a time, run all queries, dump checkpoint, free the weights, then advance to the next rung. Resumes from per-rung JSONL files. """ from __future__ import annotations import json import math import time from dataclasses import asdict, dataclass from pathlib import Path from typing import Iterable, Optional from greenrouting.data.graders import grade from greenrouting.data.schema import LabeledQuery, RawQuery @dataclass class RungResult: rung_id: str params_b: float query_id: str sample_index: int response: str score: float response_tokens: int def _read_raw_manifest(path: str | Path) -> list[RawQuery]: queries: list[RawQuery] = [] with open(path, "r", encoding="utf-8") as f: for line in f: if not line.strip(): continue data = json.loads(line) queries.append(RawQuery(**data)) return queries def _read_rung_checkpoint(path: Path) -> dict[str, list[RungResult]]: if not path.exists(): return {} out: dict[str, list[RungResult]] = {} with open(path, "r", encoding="utf-8") as f: for line in f: if not line.strip(): continue row = json.loads(line) r = RungResult(**row) out.setdefault(r.query_id, []).append(r) return out def _append_rung_checkpoint(path: Path, result: RungResult) -> None: path.parent.mkdir(parents=True, exist_ok=True) with open(path, "a", encoding="utf-8") as f: f.write(json.dumps(asdict(result)) + "\n") def _load_model_and_tokenizer(hf_model: str): import torch from transformers import AutoModelForCausalLM, AutoTokenizer dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 tok = AutoTokenizer.from_pretrained(hf_model) if tok.pad_token_id is None: tok.pad_token_id = tok.eos_token_id model = AutoModelForCausalLM.from_pretrained( hf_model, torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None ) model.eval() return tok, model def _free_model(model) -> None: import gc del model gc.collect() try: import torch if torch.cuda.is_available(): torch.cuda.empty_cache() except Exception: pass def _format_prompt(tok, query: str) -> str: if hasattr(tok, "apply_chat_template") and tok.chat_template: messages = [{"role": "user", "content": query}] return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) return f"### Instruction:\n{query}\n\n### Response:\n" def _generate(tok, model, prompt: str, max_new_tokens: int, temperature: float) -> tuple[str, int]: import torch inputs = tok(prompt, return_tensors="pt").to(model.device) do_sample = temperature > 0 with torch.no_grad(): out = model.generate( **inputs, max_new_tokens=max_new_tokens, do_sample=do_sample, temperature=temperature if do_sample else 1.0, pad_token_id=tok.pad_token_id, ) new_tokens = out[0][inputs["input_ids"].shape[1]:] response = tok.decode(new_tokens, skip_special_tokens=True) return response.strip(), int(new_tokens.shape[0]) def run_rung( rung, queries: list[RawQuery], k_samples: int, max_new_tokens: int, temperature_first: float, temperature_resample: float, checkpoint_path: Path, progress: bool = True, ) -> list[RungResult]: existing = _read_rung_checkpoint(checkpoint_path) pending = [q for q in queries if len(existing.get(q.id, [])) < k_samples] results: list[RungResult] = [r for rs in existing.values() for r in rs] if not pending: return results tok, model = _load_model_and_tokenizer(rung.hf_model) try: for i, q in enumerate(pending): done = len(existing.get(q.id, [])) for s in range(done, k_samples): temp = temperature_first if s == 0 else temperature_resample prompt = _format_prompt(tok, q.text) start = time.time() response, n_tokens = _generate(tok, model, prompt, max_new_tokens, temp) score = grade(response, q.grader_metadata, max_new_tokens=max_new_tokens) rr = RungResult( rung_id=rung.id, params_b=rung.params_b, query_id=q.id, sample_index=s, response=response, score=score, response_tokens=n_tokens, ) _append_rung_checkpoint(checkpoint_path, rr) results.append(rr) if progress: print( f" [{rung.id}] {i+1}/{len(pending)} sample={s} " f"score={score:.2f} tok={n_tokens} t={time.time()-start:.1f}s" ) finally: _free_model(model) return results def derive_difficulty( per_rung: dict[str, list[float]], rung_params_b: dict[str, float], pass_threshold: float, ) -> float: """Continuous min_capable_log_params from per-rung mean scores. Logic: - sort rungs by parameter count - for each rung, mean score across samples is the "rung pass rate" - the smallest rung whose pass rate >= threshold defines the floor - linear interpolation in log(params) space between the failing and passing rung - if no rung passes, return log(largest_rung_params * 2) as out-of-pool - if smallest rung already passes, return log(smallest_rung_params) """ sorted_rungs = sorted(rung_params_b.items(), key=lambda kv: kv[1]) if not sorted_rungs: return math.log(8e9) means: list[tuple[str, float, float]] = [] for rung_id, params_b in sorted_rungs: scores = per_rung.get(rung_id, []) if not scores: continue means.append((rung_id, params_b, sum(scores) / len(scores))) if not means: return math.log(sorted_rungs[-1][1] * 1e9 * 2) if means[0][2] >= pass_threshold: return math.log(means[0][1] * 1e9) for i in range(1, len(means)): prev_id, prev_params, prev_score = means[i - 1] cur_id, cur_params, cur_score = means[i] if cur_score >= pass_threshold: denom = max(cur_score - prev_score, 1e-6) t = max(0.0, min(1.0, (pass_threshold - prev_score) / denom)) log_lo = math.log(prev_params * 1e9) log_hi = math.log(cur_params * 1e9) return log_lo + t * (log_hi - log_lo) return math.log(means[-1][1] * 1e9 * 2) def derive_length_bucket(response_token_counts: list[int]) -> str: if not response_token_counts: return "medium" avg = sum(response_token_counts) / len(response_token_counts) if avg < 100: return "short" if avg < 400: return "medium" return "long" def run_cascade( config, raw_manifest_path: str | Path, capability_labels_path: str | Path, train_path: str | Path, test_path: str | Path, pass_threshold: float = 0.7, ) -> None: from greenrouting.data.builder import ( read_capability_labels, write_labeled_dataset, ) queries = _read_raw_manifest(raw_manifest_path) cap_labels = read_capability_labels(capability_labels_path) per_rung_results: dict[str, dict[str, list[RungResult]]] = {} out_dir = Path(config.output_dir) for rung in config.cascade.rungs: if not rung.runs_locally: print(f"[skip] {rung.id} marked as not runs_locally; configure remote backend.") continue ckpt = out_dir / f"cascade_{config.profile_name}_{rung.id}.jsonl" results = run_rung( rung, queries, k_samples=config.cascade.k_samples, max_new_tokens=config.cascade.max_new_tokens, temperature_first=config.cascade.temperature_first, temperature_resample=config.cascade.temperature_resample, checkpoint_path=ckpt, ) by_query: dict[str, list[RungResult]] = {} for r in results: by_query.setdefault(r.query_id, []).append(r) per_rung_results[rung.id] = by_query print(f"[done] rung {rung.id}: {sum(len(v) for v in by_query.values())} samples") rung_params: dict[str, float] = {r.id: r.params_b for r in config.cascade.rungs} labeled: list[LabeledQuery] = [] for q in queries: per_rung_scores: dict[str, list[float]] = {} token_counts: list[int] = [] for rung_id, by_query in per_rung_results.items(): for rr in by_query.get(q.id, []): per_rung_scores.setdefault(rung_id, []).append(rr.score) token_counts.append(rr.response_tokens) if not per_rung_scores: continue difficulty = derive_difficulty(per_rung_scores, rung_params, pass_threshold) length_bucket = derive_length_bucket(token_counts) caps = cap_labels.get(q.id, {}) labeled.append(LabeledQuery( raw=q, capabilities=caps, difficulty_log_params=difficulty, length_bucket=length_bucket, cascade_results={ "per_rung_mean_scores": { k: sum(v) / len(v) for k, v in per_rung_scores.items() }, }, )) write_labeled_dataset( train_path=train_path, test_path=test_path, rows=labeled, test_split=config.test_split, seed=config.seed, ) print(f"[done] wrote {len(labeled)} labeled rows -> {train_path}, {test_path}")