"""Evaluate saved training checkpoints and chart accuracy / CoT length per step.

Checkpoint directories named ``checkpoint-<N>`` or ``*-step-<N>`` are discovered
either under a single ``PERMANENT_ROOT`` (labelled by ``RUN_LABEL``, default
``"permanent"``) or under the two roots ``PERMANENT_CW1`` / ``PERMANENT_CW5``.
Each checkpoint is evaluated with ``evaluate_one_model``; rank 0 then writes
per-checkpoint JSONL outputs, a JSON/CSV summary and three SVG charts under the
output root (``OUT_ROOT`` if set, otherwise a sibling ``eval_permanent`` dir).

Other environment variables: ``BASE_CONFIG`` (required), ``EVAL_MAX_SAMPLES``
(default 200), ``EVAL_BATCH_SIZE`` (default 4), ``ROLLOUT_SAMPLES`` (default 8).
"""

from __future__ import annotations

import csv
import json
import os
import re
from pathlib import Path
from xml.sax.saxutils import escape as _xml_escape

import torch.distributed as dist

import hackable  # noqa: F401
from hackable.utils import resolve_repo_path

from eval_sweep_models import (
    _init_distributed,
    _load_yaml,
    _resolve_local_model_dir,
    evaluate_one_model,
)

# Colors cycled across run labels; every chart uses the same sequence.
_PALETTE = ("#2563eb", "#dc2626", "#16a34a", "#9333ea", "#ca8a04", "#0891b2")

_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")
_STEP_SUFFIX_RE = re.compile(r"-step-(\d+)$")

# Column order of the CSV summary (also the keys of each summary row).
_SUMMARY_FIELDS = (
    "run_label",
    "checkpoint_step",
    "checkpoint_dir",
    "model_dir",
    "num_examples",
    "accuracy",
    "avg_cot_words",
)


def _parse_checkpoint_step(dirname: str) -> int | None:
    """Return the training step encoded in a checkpoint directory name.

    Accepts ``checkpoint-<N>`` (whole name) or any name ending in
    ``-step-<N>``; returns ``None`` for anything else.
    """
    m = _CHECKPOINT_RE.match(dirname) or _STEP_SUFFIX_RE.search(dirname)
    return int(m.group(1)) if m else None


def _discover_checkpoint_jobs(
    base_cfg: dict, permanent_root: Path, run_label: str
) -> list[tuple[str, int, str, Path, str]]:
    """List evaluation jobs for each checkpoint directory directly under a root.

    Returns ``(run_label, step, resolved_model_dir_str, resolved_path, dir_name)``
    tuples sorted by ``(step, dir_name)``. Plain files and directories whose
    names carry no step (see ``_parse_checkpoint_step``) are skipped.

    Raises:
        FileNotFoundError: if ``permanent_root`` is not a directory.
    """
    root = permanent_root.resolve()
    if not root.is_dir():
        raise FileNotFoundError(f"Not a directory: {root}")
    jobs: list[tuple[str, int, str, Path, str]] = []
    for p in sorted(root.iterdir()):
        if not p.is_dir():
            continue
        step = _parse_checkpoint_step(p.name)
        if step is None:
            continue
        resolved = _resolve_local_model_dir(base_cfg, str(p))
        jobs.append((run_label, step, str(resolved), resolved, p.name))
    jobs.sort(key=lambda job: (job[1], job[4]))
    return jobs


def _svg_open(width: int, height: int) -> str:
    """Return the opening ``<svg>`` tag followed by a white background rect."""
    return (
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" '
        f'viewBox="0 0 {width} {height}" font-family="sans-serif">\n'
        '<rect width="100%" height="100%" fill="white"/>'
    )


def _empty_svg(width: int, height: int, message: str) -> str:
    """Return a complete SVG document showing only ``message`` (XML-escaped)."""
    return (
        _svg_open(width, height)
        + f'\n<text x="{width / 2}" y="{height / 2}" text-anchor="middle" '
        f'font-size="16">{_xml_escape(message)}</text>\n</svg>'
    )


def _line_chart_svg(
    series: list[tuple[str, list[tuple[int, float]], str]],
    title: str,
    y_label: str,
    y_max: float,
    path: Path,
) -> None:
    """Write a multi-series line chart (value vs training step) as an SVG file.

    Args:
        series: ``(name, [(step, value), ...], color)`` per line. A series with
            one point is drawn as a dot; empty series are skipped entirely.
        title: Chart title (XML-escaped before embedding).
        y_label: Y-axis label (XML-escaped before embedding).
        y_max: Top of the Y axis; values are clamped to ``[0, y_max]``.
        path: Output file, overwritten.
    """
    width, height = 900, 420
    lm, rm, tm, bm = 70, 40, 50, 55  # left / right / top / bottom margins
    pw = width - lm - rm  # plot width
    ph = height - tm - bm  # plot height
    yb = tm + ph  # pixel row of the x axis

    all_steps = [step for _, pts, _ in series for step, _ in pts]
    if not all_steps:
        path.write_text(_empty_svg(width, height, f"{title} (no data)"), encoding="utf-8")
        return

    x_min, x_max = min(all_steps), max(all_steps)
    if x_max == x_min:
        x_max = x_min + 1  # a single step would otherwise divide by zero in sx()

    def sx(x: float) -> int:
        return lm + int((x - x_min) / (x_max - x_min) * pw)

    def sy(y: float) -> int:
        if y_max <= 0:
            return yb
        y = max(0.0, min(y_max, y))  # clamp so out-of-range values stay on canvas
        return yb - int((y / y_max) * ph)

    mid_y = tm + ph / 2
    parts: list[str] = [_svg_open(width, height)]

    # Horizontal gridlines + y tick labels (the i == 0 line would hide the axis).
    for i in range(5):
        val = (i / 4) * y_max
        yy = sy(val)
        if i:
            parts.append(
                f'<line x1="{lm}" y1="{yy}" x2="{lm + pw}" y2="{yy}" stroke="#e5e7eb"/>'
            )
        parts.append(
            f'<text x="{lm - 8}" y="{yy + 4}" text-anchor="end" font-size="11">{val:.2f}</text>'
        )
    # Integer step ticks; the set drops duplicates when the step range is short.
    for xv in sorted({round(x_min + (i / 4) * (x_max - x_min)) for i in range(5)}):
        xx = sx(xv)
        parts.append(f'<line x1="{xx}" y1="{yb}" x2="{xx}" y2="{yb + 5}" stroke="#333"/>')
        parts.append(
            f'<text x="{xx}" y="{yb + 18}" text-anchor="middle" font-size="11">{xv}</text>'
        )

    parts += [
        f'<text x="{width / 2}" y="{tm - 20}" text-anchor="middle" font-size="16" '
        f'font-weight="bold">{_xml_escape(title)}</text>',
        f'<text x="16" y="{mid_y}" text-anchor="middle" font-size="12" '
        f'transform="rotate(-90 16 {mid_y})">{_xml_escape(y_label)}</text>',
        f'<line x1="{lm}" y1="{yb}" x2="{lm + pw}" y2="{yb}" stroke="#333"/>',
        f'<line x1="{lm}" y1="{tm}" x2="{lm}" y2="{yb}" stroke="#333"/>',
        f'<text x="{lm + pw / 2}" y="{height - 12}" text-anchor="middle" '
        f'font-size="12">Training step</text>',
    ]

    legend_x = lm + pw - 200
    legend_y = tm + 8
    n_legend = 0  # counts drawn series only, so skipped ones leave no legend gap
    for name, pts, color in series:
        pts_sorted = sorted(pts, key=lambda z: z[0])
        if not pts_sorted:
            continue
        if len(pts_sorted) > 1:
            d = "M " + " L ".join(f"{sx(s)} {sy(v)}" for s, v in pts_sorted)
            parts.append(f'<path d="{d}" fill="none" stroke="{color}" stroke-width="2"/>')
        for s, v in pts_sorted:
            parts.append(f'<circle cx="{sx(s)}" cy="{sy(v)}" r="3.5" fill="{color}"/>')
        ly = legend_y + n_legend * 18
        parts.append(
            f'<line x1="{legend_x}" y1="{ly}" x2="{legend_x + 24}" y2="{ly}" '
            f'stroke="{color}" stroke-width="3"/>'
        )
        parts.append(
            f'<text x="{legend_x + 30}" y="{ly + 4}" font-size="12">{_xml_escape(name)}</text>'
        )
        n_legend += 1

    parts.append("</svg>")
    path.write_text("\n".join(parts), encoding="utf-8")


def _scatter_accuracy_vs_cot_svg(rows: list[dict], path: Path, title: str) -> None:
    """Write an SVG scatter of accuracy (y) against average CoT words (x).

    One color per ``run_label`` (first-appearance order in ``rows``, default
    label ``"run"``). A run with two or more checkpoints has its points joined
    in training-step order. Each point is labelled with its step and carries a
    hover tooltip. Rows must provide ``accuracy``, ``avg_cot_words`` and
    ``checkpoint_step``; ``checkpoint_dir`` is optional.
    """
    width, height = 640, 520
    lm, rm, tm, bm = 72, 160, 52, 64  # wide right margin leaves room for the legend
    pw = width - lm - rm
    ph = height - tm - bm
    yb = tm + ph

    if not rows:
        path.write_text(_empty_svg(width, height, f"{title} (no data)"), encoding="utf-8")
        return

    def label_of(r: dict) -> str:
        return str(r.get("run_label", "run"))

    labels = list(dict.fromkeys(label_of(r) for r in rows))  # ordered unique
    color_map = {lab: _PALETTE[i % len(_PALETTE)] for i, lab in enumerate(labels)}

    xs = [float(r["avg_cot_words"]) for r in rows]
    x_min, x_max = min(xs), max(xs)
    y_min, y_max = 0.0, 1.0
    if x_max <= x_min:
        x_max = x_min + 1.0
    pad = (x_max - x_min) * 0.06 + 1.0  # keep edge points off the axes
    x_min = max(0.0, x_min - pad)
    x_max = x_max + pad

    def sx(x: float) -> float:
        return lm + (x - x_min) / (x_max - x_min) * pw

    def sy(y: float) -> float:
        y = max(y_min, min(y_max, y))
        return yb - (y - y_min) / (y_max - y_min) * ph

    mid_y = tm + ph / 2
    parts: list[str] = [_svg_open(width, height)]

    for i in range(5):
        val = y_min + (i / 4) * (y_max - y_min)
        yy = sy(val)
        if i:
            parts.append(
                f'<line x1="{lm}" y1="{yy:.1f}" x2="{lm + pw}" y2="{yy:.1f}" stroke="#e5e7eb"/>'
            )
        parts.append(
            f'<text x="{lm - 8}" y="{yy + 4:.1f}" text-anchor="end" font-size="11">{val:.2f}</text>'
        )
    for i in range(5):
        xv = x_min + (i / 4) * (x_max - x_min)
        xx = sx(xv)
        parts.append(f'<line x1="{xx:.1f}" y1="{yb}" x2="{xx:.1f}" y2="{yb + 5}" stroke="#333"/>')
        parts.append(
            f'<text x="{xx:.1f}" y="{yb + 18}" text-anchor="middle" font-size="11">{xv:.0f}</text>'
        )

    parts += [
        f'<text x="{lm + pw / 2}" y="{tm - 22}" text-anchor="middle" font-size="15" '
        f'font-weight="bold">{_xml_escape(title)}</text>',
        f'<text x="{lm + pw / 2}" y="{height - 18}" text-anchor="middle" '
        f'font-size="12">Avg CoT length (words)</text>',
        f'<text x="20" y="{mid_y}" text-anchor="middle" font-size="12" '
        f'transform="rotate(-90 20 {mid_y})">Accuracy</text>',
        f'<line x1="{lm}" y1="{yb}" x2="{lm + pw}" y2="{yb}" stroke="#333"/>',
        f'<line x1="{lm}" y1="{tm}" x2="{lm}" y2="{yb}" stroke="#333"/>',
    ]

    # Trajectory per run, in training-step order.
    for lab in labels:
        sub = sorted(
            (r for r in rows if label_of(r) == lab), key=lambda r: int(r["checkpoint_step"])
        )
        if len(sub) >= 2:
            d = "M " + " L ".join(
                f'{sx(float(r["avg_cot_words"])):.1f} {sy(float(r["accuracy"])):.1f}' for r in sub
            )
            parts.append(
                f'<path d="{d}" fill="none" stroke="{color_map[lab]}" '
                f'stroke-width="1.5" opacity="0.5"/>'
            )

    for r in rows:
        color = color_map[label_of(r)]
        cx = sx(float(r["avg_cot_words"]))
        cy = sy(float(r["accuracy"]))
        step = int(r["checkpoint_step"])
        name = str(r.get("checkpoint_dir", f"step-{step}"))
        tip = (
            f"{name}: accuracy={float(r['accuracy']):.4f}, "
            f"avg_cot_words={float(r['avg_cot_words']):.2f}"
        )
        parts.append(
            f'<circle cx="{cx:.1f}" cy="{cy:.1f}" r="5" fill="{color}" stroke="white">'
            f"<title>{_xml_escape(tip)}</title></circle>"
            f'<text x="{cx + 7:.1f}" y="{cy - 7:.1f}" font-size="10" fill="{color}">{step}</text>'
        )

    legend_x = lm + pw + 14
    legend_y = tm + 4
    parts.append(
        f'<text x="{legend_x}" y="{legend_y}" font-size="12" font-weight="bold">Series</text>'
    )
    for idx, lab in enumerate(labels):
        ly = legend_y + 18 + idx * 20
        parts.append(f'<circle cx="{legend_x + 6}" cy="{ly - 4}" r="5" fill="{color_map[lab]}"/>')
        parts.append(
            f'<text x="{legend_x + 16}" y="{ly}" font-size="12">{_xml_escape(lab)}</text>'
        )

    parts.append("</svg>")
    path.write_text("\n".join(parts), encoding="utf-8")


def _resolve_out_root(default: Path) -> Path:
    """Return ``OUT_ROOT`` (resolved) if set and non-blank, else ``default``."""
    raw = os.environ.get("OUT_ROOT")
    if raw is None or not str(raw).strip():
        return resolve_repo_path(str(default))
    return resolve_repo_path(raw)


def _mean(records: list[dict], key: str) -> float:
    """Return the mean of ``float(rec[key])`` over ``records`` (0.0 when empty)."""
    if not records:
        return 0.0
    return sum(float(rec[key]) for rec in records) / len(records)


def _write_jsonl(path: Path, records: list[dict]) -> None:
    """Write ``records`` as ASCII-escaped JSON Lines, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        for rec in records:
            handle.write(json.dumps(rec, ensure_ascii=True) + "\n")


def _write_summary(rows: list[dict], out_root: Path) -> tuple[Path, Path]:
    """Write the per-checkpoint summary as JSON and CSV; return both paths."""
    summary_json = out_root / "permanent_checkpoints_eval.json"
    summary_csv = out_root / "permanent_checkpoints_eval.csv"
    summary_json.write_text(json.dumps(rows, indent=2), encoding="utf-8")
    with summary_csv.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=list(_SUMMARY_FIELDS))
        writer.writeheader()
        writer.writerows(rows)
    return summary_json, summary_csv


def _write_charts(rows: list[dict], out_root: Path) -> list[Path]:
    """Render the per-step line charts and the scatter; return paths written."""
    uniq_labels = sorted({str(r["run_label"]) for r in rows})

    def series_for(ykey: str) -> list[tuple[str, list[tuple[int, float]], str]]:
        out = []
        for i, lab in enumerate(uniq_labels):
            pts = [
                (int(r["checkpoint_step"]), float(r[ykey]))
                for r in rows
                if r["run_label"] == lab
            ]
            if pts:
                out.append((lab, pts, _PALETTE[i % len(_PALETTE)]))
        return out

    written: list[Path] = []

    acc_series = series_for("accuracy")
    if acc_series:
        acc_path = out_root / "accuracy_vs_step.svg"
        _line_chart_svg(
            acc_series, "GSM8K accuracy vs checkpoint step", "Accuracy", 1.0, acc_path
        )
        written.append(acc_path)

    cot_series = series_for("avg_cot_words")
    if cot_series:
        # At least 1.0 so the y scale never collapses when all lengths are 0.
        cot_max = max([1.0] + [float(r["avg_cot_words"]) for r in rows])
        cot_path = out_root / "avg_cot_vs_step.svg"
        _line_chart_svg(
            cot_series,
            "Average CoT length (words) vs checkpoint step",
            "Avg CoT words",
            cot_max,
            cot_path,
        )
        written.append(cot_path)

    scatter_path = out_root / "accuracy_vs_avg_cot_words.svg"
    _scatter_accuracy_vs_cot_svg(
        rows, scatter_path, "GSM8K accuracy vs average CoT length (words)"
    )
    written.append(scatter_path)
    return written


def main() -> None:
    """Evaluate every discovered checkpoint and write outputs, summary and charts.

    ``evaluate_one_model`` is called on every rank; only rank 0 uses the
    returned records and writes files. (Presumably the helper gathers results
    onto rank 0 — confirm in ``eval_sweep_models``.)
    """
    rank, _, _ = _init_distributed()
    base_cfg = _load_yaml(str(resolve_repo_path(os.environ["BASE_CONFIG"])))
    eval_max_samples = int(os.environ.get("EVAL_MAX_SAMPLES", "200"))
    eval_batch_size = int(os.environ.get("EVAL_BATCH_SIZE", "4"))
    rollout_n = int(os.environ.get("ROLLOUT_SAMPLES", "8"))

    permanent_root = os.environ.get("PERMANENT_ROOT", "").strip()
    if permanent_root:
        pr = resolve_repo_path(permanent_root)
        out_root = _resolve_out_root(pr / "eval_permanent")
        job_groups = [
            _discover_checkpoint_jobs(
                base_cfg, pr, os.environ.get("RUN_LABEL", "permanent")
            )
        ]
    else:
        cw1_root = resolve_repo_path(os.environ["PERMANENT_CW1"])
        cw5_root = resolve_repo_path(os.environ["PERMANENT_CW5"])
        out_root = _resolve_out_root(cw1_root.parent / "eval_permanent")
        job_groups = [
            _discover_checkpoint_jobs(base_cfg, cw1_root, "correctness_weight_1"),
            _discover_checkpoint_jobs(base_cfg, cw5_root, "correctness_weight_5"),
        ]
    all_jobs = [job for group in job_groups for job in group]

    if rank == 0:
        out_root.mkdir(parents=True, exist_ok=True)
        (out_root / "rollouts").mkdir(parents=True, exist_ok=True)
        (out_root / "full_outputs").mkdir(parents=True, exist_ok=True)
        if permanent_root:
            print(f"PERMANENT_ROOT: {pr} ({len(all_jobs)} checkpoints)")
        else:
            print(
                f"Found {len(job_groups[0])} checkpoints (cw=1), "
                f"{len(job_groups[1])} checkpoints (cw=5)"
            )
        for run_label, step, _, _, name in all_jobs:
            print(f" {run_label} step={step} ({name})")
    # Other ranks must not start before rank 0 has created the output dirs.
    if dist.is_initialized():
        dist.barrier()

    rows: list[dict] = []
    for run_label, step, _resolved_str, resolved_path, dir_name in all_jobs:
        records = evaluate_one_model(
            model_dir=resolved_path,
            base_cfg=base_cfg,
            eval_max_samples=eval_max_samples,
            batch_size=eval_batch_size,
        )
        if rank != 0:
            continue
        acc = _mean(records, "correctness")
        avg_cot = _mean(records, "cot_words")
        rows.append(
            {
                "run_label": run_label,
                "checkpoint_step": step,
                "checkpoint_dir": dir_name,
                "model_dir": str(resolved_path),
                "num_examples": len(records),
                "accuracy": acc,
                "avg_cot_words": avg_cot,
            }
        )
        _write_jsonl(
            out_root / "rollouts" / run_label / f"{dir_name}_rollouts.jsonl",
            records[:rollout_n],
        )
        _write_jsonl(
            out_root / "full_outputs" / run_label / f"{dir_name}_outputs.jsonl",
            records,
        )
        print(
            f"Eval {run_label} {dir_name}: acc={acc:.4f} avg_cot_words={avg_cot:.2f} n={len(records)}"
        )

    if dist.is_initialized():
        dist.barrier()
    if rank != 0:
        return

    rows.sort(key=lambda r: (r["run_label"], r["checkpoint_step"], r["checkpoint_dir"]))
    summary_json, summary_csv = _write_summary(rows, out_root)
    charts = _write_charts(rows, out_root)

    print(f"Saved: {summary_json}")
    print(f"Saved: {summary_csv}")
    for chart in charts:
        print(f"Saved: {chart}")
    print(f"Rollouts: {out_root / 'rollouts'}/<run_label>/")
    print(f"Full outputs: {out_root / 'full_outputs'}/<run_label>/")


if __name__ == "__main__":
    main()