Spaces:
Sleeping
Sleeping
| """Difficulty cascade: runs each query against an ascending ladder of models, | |
| grades the response, and derives a continuous `min_capable_log_params` label. | |
| Memory strategy: load one rung at a time, run all queries, dump checkpoint, free | |
| the weights, then advance to the next rung. Resumes from per-rung JSONL files. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import math | |
| import time | |
| from dataclasses import asdict, dataclass | |
| from pathlib import Path | |
| from typing import Iterable, Optional | |
| from greenrouting.data.graders import grade | |
| from greenrouting.data.schema import LabeledQuery, RawQuery | |
@dataclass
class RungResult:
    """One graded generation produced by a single rung for a single query.

    Declared as a dataclass because the rest of the module constructs it from
    keyword arguments (including ``RungResult(**row)`` when reloading a
    checkpoint) and serialises it with ``dataclasses.asdict``.
    """

    rung_id: str  # id of the rung that produced the response
    params_b: float  # rung size; multiplied by 1e9 when difficulty is derived
    query_id: str  # id of the query that was answered
    sample_index: int  # 0-based index of this sample for the (rung, query) pair
    response: str  # decoded, stripped completion text
    score: float  # value returned by the grader for this response
    response_tokens: int  # number of newly generated tokens
def _read_raw_manifest(path: str | Path) -> list[RawQuery]:
    """Load every query from a JSONL manifest.

    Blank (whitespace-only) lines are skipped. Each remaining line is parsed
    as JSON and its keys are passed as keyword arguments to ``RawQuery``.

    Args:
        path: Location of the UTF-8 encoded JSONL manifest.

    Returns:
        The parsed queries, in file order.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return [RawQuery(**json.loads(raw)) for raw in handle if raw.strip()]
def _read_rung_checkpoint(path: Path) -> dict[str, list[RungResult]]:
    """Reload previously saved results for one rung, grouped by query id.

    Args:
        path: JSONL checkpoint file; it does not have to exist.

    Returns:
        A mapping from query id to that query's results in file order. The
        mapping is empty when the checkpoint file is missing.
    """
    grouped: dict[str, list[RungResult]] = {}
    if path.exists():
        with open(path, "r", encoding="utf-8") as handle:
            for raw in handle:
                # Tolerate blank lines between records.
                if raw.strip():
                    result = RungResult(**json.loads(raw))
                    grouped.setdefault(result.query_id, []).append(result)
    return grouped
def _append_rung_checkpoint(path: Path, result: RungResult) -> None:
    """Append one result to the rung's JSONL checkpoint.

    The parent directory is created on demand, and the file is opened in
    append mode so earlier records are kept.

    Args:
        path: Checkpoint file to append to.
        result: Result to serialise as a single JSON line.
    """
    checkpoint_dir = path.parent
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    with open(path, "a", encoding="utf-8") as handle:
        record = json.dumps(asdict(result))
        handle.write(record + "\n")
def _load_model_and_tokenizer(hf_model: str):
    """Load a causal LM and its tokenizer, ready for inference.

    With CUDA available the weights are loaded as bfloat16 with
    ``device_map="auto"``; otherwise as float32 with no device map. If the
    tokenizer has no pad token id, the EOS token id is used instead.

    Args:
        hf_model: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been put in eval mode.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    if torch.cuda.is_available():
        dtype = torch.bfloat16
    else:
        dtype = torch.float32

    tokenizer = AutoTokenizer.from_pretrained(hf_model)
    if tokenizer.pad_token_id is None:
        # generate() is later given pad_token_id, so make sure one exists.
        tokenizer.pad_token_id = tokenizer.eos_token_id

    placement = "auto" if torch.cuda.is_available() else None
    lm = AutoModelForCausalLM.from_pretrained(
        hf_model, torch_dtype=dtype, device_map=placement
    )
    lm.eval()
    return tokenizer, lm
| def _free_model(model) -> None: | |
| import gc | |
| del model | |
| gc.collect() | |
| try: | |
| import torch | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| except Exception: | |
| pass | |
| def _format_prompt(tok, query: str) -> str: | |
| if hasattr(tok, "apply_chat_template") and tok.chat_template: | |
| messages = [{"role": "user", "content": query}] | |
| return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) | |
| return f"### Instruction:\n{query}\n\n### Response:\n" | |
| def _generate(tok, model, prompt: str, max_new_tokens: int, temperature: float) -> tuple[str, int]: | |
| import torch | |
| inputs = tok(prompt, return_tensors="pt").to(model.device) | |
| do_sample = temperature > 0 | |
| with torch.no_grad(): | |
| out = model.generate( | |
| **inputs, | |
| max_new_tokens=max_new_tokens, | |
| do_sample=do_sample, | |
| temperature=temperature if do_sample else 1.0, | |
| pad_token_id=tok.pad_token_id, | |
| ) | |
| new_tokens = out[0][inputs["input_ids"].shape[1]:] | |
| response = tok.decode(new_tokens, skip_special_tokens=True) | |
| return response.strip(), int(new_tokens.shape[0]) | |
def run_rung(
    rung,
    queries: list[RawQuery],
    k_samples: int,
    max_new_tokens: int,
    temperature_first: float,
    temperature_resample: float,
    checkpoint_path: Path,
    progress: bool = True,
) -> list[RungResult]:
    """Run one rung over all queries, resuming from its checkpoint.

    Queries that already have ``k_samples`` results in the checkpoint are
    skipped. If nothing is pending the model is never loaded. Each new sample
    is appended to the checkpoint as soon as it has been graded.

    Args:
        rung: Rung config; ``id``, ``params_b`` and ``hf_model`` are read.
        queries: Queries to answer.
        k_samples: Number of samples wanted per query.
        max_new_tokens: Generation budget, also forwarded to the grader.
        temperature_first: Temperature for sample index 0.
        temperature_resample: Temperature for every later sample.
        checkpoint_path: JSONL file used for resume and for new results.
        progress: When True, print one status line per generated sample.

    Returns:
        Every result found in the checkpoint plus the newly generated ones.
    """
    existing = _read_rung_checkpoint(checkpoint_path)
    pending = [q for q in queries if len(existing.get(q.id, [])) < k_samples]
    results: list[RungResult] = [r for rs in existing.values() for r in rs]
    if not pending:
        return results

    tok, model = _load_model_and_tokenizer(rung.hf_model)
    try:
        for i, q in enumerate(pending):
            done = len(existing.get(q.id, []))
            # The prompt depends only on the query, so build it once per query
            # rather than once per sample.
            prompt = _format_prompt(tok, q.text)
            for s in range(done, k_samples):
                temp = temperature_first if s == 0 else temperature_resample
                start = time.time()
                response, n_tokens = _generate(tok, model, prompt, max_new_tokens, temp)
                score = grade(response, q.grader_metadata, max_new_tokens=max_new_tokens)
                rr = RungResult(
                    rung_id=rung.id,
                    params_b=rung.params_b,
                    query_id=q.id,
                    sample_index=s,
                    response=response,
                    score=score,
                    response_tokens=n_tokens,
                )
                _append_rung_checkpoint(checkpoint_path, rr)
                results.append(rr)
                if progress:
                    print(
                        f" [{rung.id}] {i+1}/{len(pending)} sample={s} "
                        f"score={score:.2f} tok={n_tokens} t={time.time()-start:.1f}s"
                    )
    finally:
        # `_free_model` can only delete its own parameter binding. Clear the
        # references held by this frame first; otherwise the weights are still
        # reachable while it runs the garbage collector and empties the CUDA
        # cache, and nothing is reclaimed before the next rung is loaded.
        model = tok = None
        _free_model(None)
    return results
def derive_difficulty(
    per_rung: dict[str, list[float]],
    rung_params_b: dict[str, float],
    pass_threshold: float,
) -> float:
    """Estimate ``min_capable_log_params`` from per-rung sample scores.

    Rungs are ordered by parameter count, and a rung's pass rate is the mean
    of its scores. Rungs without scores are ignored. The outcome is:

    - ``log(8e9)`` when no rungs are configured at all;
    - ``log(2 * largest configured rung)`` when no rung has any score;
    - ``log(smallest scored rung)`` when that rung already meets the threshold;
    - otherwise, for the first scored rung meeting the threshold, a linear
      interpolation in log-parameter space between it and the scored rung
      just below it, positioned where the threshold falls between their
      pass rates;
    - ``log(2 * largest scored rung)`` when no rung meets the threshold
      (treated as out-of-pool).

    Args:
        per_rung: Scores for each rung id.
        rung_params_b: Parameter count of each rung, in billions.
        pass_threshold: Minimum mean score for a rung to count as capable.

    Returns:
        Natural log of the estimated minimum capable parameter count.
    """
    ladder = sorted(rung_params_b, key=rung_params_b.get)
    if not ladder:
        return math.log(8e9)

    # (params_b, mean score) for every rung that actually has scores.
    graded: list[tuple[float, float]] = [
        (rung_params_b[rung_id], sum(per_rung[rung_id]) / len(per_rung[rung_id]))
        for rung_id in ladder
        if per_rung.get(rung_id)
    ]
    if not graded:
        return math.log(rung_params_b[ladder[-1]] * 1e9 * 2)

    smallest_params, smallest_mean = graded[0]
    if smallest_mean >= pass_threshold:
        return math.log(smallest_params * 1e9)

    for (lo_params, lo_mean), (hi_params, hi_mean) in zip(graded, graded[1:]):
        if hi_mean >= pass_threshold:
            # Guard against a vanishing gap between the two pass rates.
            gap = max(hi_mean - lo_mean, 1e-6)
            frac = max(0.0, min(1.0, (pass_threshold - lo_mean) / gap))
            log_lo = math.log(lo_params * 1e9)
            log_hi = math.log(hi_params * 1e9)
            return log_lo + frac * (log_hi - log_lo)

    return math.log(graded[-1][0] * 1e9 * 2)
def derive_length_bucket(response_token_counts: list[int]) -> str:
    """Classify responses by their mean token count.

    Args:
        response_token_counts: Token counts of the observed responses.

    Returns:
        ``"short"`` for a mean below 100, ``"medium"`` for a mean below 400
        (also the result for an empty list), and ``"long"`` otherwise.
    """
    if response_token_counts:
        mean_tokens = sum(response_token_counts) / len(response_token_counts)
        # Exclusive upper bounds, checked from the smallest bucket upwards.
        for upper_bound, label in ((100, "short"), (400, "medium")):
            if mean_tokens < upper_bound:
                return label
        return "long"
    return "medium"
def run_cascade(
    config,
    raw_manifest_path: str | Path,
    capability_labels_path: str | Path,
    train_path: str | Path,
    test_path: str | Path,
    pass_threshold: float = 0.7,
) -> None:
    """Run the full cascade and write the labeled train/test datasets.

    Every rung with ``runs_locally`` set is run through ``run_rung``, using a
    per-rung checkpoint under ``config.output_dir``; other rungs are skipped
    with a message. Each query that received at least one result is then
    labeled with a difficulty, a length bucket, its capability labels and its
    per-rung mean scores, and the rows are handed to
    ``write_labeled_dataset``.

    Args:
        config: Run configuration; ``output_dir``, ``profile_name``,
            ``cascade``, ``test_split`` and ``seed`` are read.
        raw_manifest_path: JSONL manifest of raw queries.
        capability_labels_path: File passed to ``read_capability_labels``.
        train_path: Destination of the training split.
        test_path: Destination of the test split.
        pass_threshold: Mean score a rung needs to count as capable.
    """
    from greenrouting.data.builder import (
        read_capability_labels,
        write_labeled_dataset,
    )

    queries = _read_raw_manifest(raw_manifest_path)
    cap_labels = read_capability_labels(capability_labels_path)

    # Phase 1: run (or resume) each local rung, grouping results by query id.
    grouped_by_rung: dict[str, dict[str, list[RungResult]]] = {}
    checkpoint_dir = Path(config.output_dir)
    for rung in config.cascade.rungs:
        if not rung.runs_locally:
            print(f"[skip] {rung.id} marked as not runs_locally; configure remote backend.")
            continue
        checkpoint = checkpoint_dir / f"cascade_{config.profile_name}_{rung.id}.jsonl"
        rung_results = run_rung(
            rung,
            queries,
            k_samples=config.cascade.k_samples,
            max_new_tokens=config.cascade.max_new_tokens,
            temperature_first=config.cascade.temperature_first,
            temperature_resample=config.cascade.temperature_resample,
            checkpoint_path=checkpoint,
        )
        grouped: dict[str, list[RungResult]] = {}
        for item in rung_results:
            grouped.setdefault(item.query_id, []).append(item)
        grouped_by_rung[rung.id] = grouped
        n_samples = sum(len(samples) for samples in grouped.values())
        print(f"[done] rung {rung.id}: {n_samples} samples")

    # Phase 2: turn the per-rung results into one labeled row per query.
    rung_params: dict[str, float] = {r.id: r.params_b for r in config.cascade.rungs}
    labeled: list[LabeledQuery] = []
    for q in queries:
        samples_by_rung = {
            rung_id: by_query[q.id]
            for rung_id, by_query in grouped_by_rung.items()
            if by_query.get(q.id)
        }
        if not samples_by_rung:
            # No rung produced anything for this query; nothing to label.
            continue

        per_rung_scores: dict[str, list[float]] = {
            rung_id: [rr.score for rr in samples]
            for rung_id, samples in samples_by_rung.items()
        }
        token_counts: list[int] = [
            rr.response_tokens
            for samples in samples_by_rung.values()
            for rr in samples
        ]

        difficulty = derive_difficulty(per_rung_scores, rung_params, pass_threshold)
        length_bucket = derive_length_bucket(token_counts)
        caps = cap_labels.get(q.id, {})
        mean_scores = {
            rung_id: sum(scores) / len(scores)
            for rung_id, scores in per_rung_scores.items()
        }
        labeled.append(
            LabeledQuery(
                raw=q,
                capabilities=caps,
                difficulty_log_params=difficulty,
                length_bucket=length_bucket,
                cascade_results={"per_rung_mean_scores": mean_scores},
            )
        )

    write_labeled_dataset(
        train_path=train_path,
        test_path=test_path,
        rows=labeled,
        test_split=config.test_split,
        seed=config.seed,
    )
    print(f"[done] wrote {len(labeled)} labeled rows -> {train_path}, {test_path}")