| from __future__ import annotations |
|
|
| import csv |
| import json |
| import os |
| import re |
| from pathlib import Path |
|
|
| import torch.distributed as dist |
|
|
| import hackable |
| from hackable.utils import resolve_repo_path |
| from eval_sweep_models import ( |
| _init_distributed, |
| _load_yaml, |
| _resolve_local_model_dir, |
| evaluate_one_model, |
| ) |
|
|
|
|
def _parse_checkpoint_step(dirname: str) -> int | None:
    """Extract the training step encoded in a checkpoint directory name.

    Two naming schemes are recognised, tried in this order:

    * the whole name is ``checkpoint-<digits>``
    * the name ends with ``-step-<digits>``

    Returns the step as an ``int``, or ``None`` when neither scheme applies.
    """
    # (finder, pattern) pairs; the first one that hits decides the result.
    schemes = (
        (re.match, r"^checkpoint-(\d+)$"),
        (re.search, r"-step-(\d+)$"),
    )
    for finder, pattern in schemes:
        hit = finder(pattern, dirname)
        if hit is not None:
            return int(hit.group(1))
    return None
|
|
|
|
def _discover_checkpoint_jobs(
    base_cfg: dict, permanent_root: Path, run_label: str
) -> list[tuple[str, int, str, Path, str]]:
    """Build one evaluation job per checkpoint directory under ``permanent_root``.

    Only immediate sub-directories whose name yields a step via
    ``_parse_checkpoint_step`` are kept. Each job is the tuple
    ``(run_label, step, resolved_model_dir_str, resolved_path, dir_name)``,
    where the resolved path is whatever ``_resolve_local_model_dir`` returns
    for that directory.

    The result is ordered by step, with the directory name as tie-breaker.

    Raises:
        FileNotFoundError: if ``permanent_root`` does not resolve to a directory.
    """
    base_dir = permanent_root.resolve()
    if not base_dir.is_dir():
        raise FileNotFoundError(f"Not a directory: {base_dir}")

    found: list[tuple[str, int, str, Path, str]] = []
    for entry in sorted(base_dir.iterdir()):
        if not entry.is_dir():
            continue
        step = _parse_checkpoint_step(entry.name)
        if step is None:
            # Not a checkpoint directory (e.g. an output folder); ignore it.
            continue
        model_dir = _resolve_local_model_dir(base_cfg, str(entry))
        found.append((run_label, step, str(model_dir), model_dir, entry.name))

    return sorted(found, key=lambda job: (job[1], job[4]))
|
|
|
|
def _line_chart_svg(
    series: list[tuple[str, list[tuple[int, float]], str]],
    title: str,
    y_label: str,
    y_max: float,
    path: Path,
) -> None:
    """Write a multi-series "metric vs training step" line chart as an SVG file.

    Args:
        series: ``(name, points, color)`` triples, where ``points`` is a list
            of ``(step, value)`` pairs. A series with a single point is drawn
            as a circle, one with two or more as a polyline, and an empty one
            is skipped entirely (its legend slot is left blank because legend
            rows are positioned by series index).
        title: Chart heading.
        y_label: Label drawn along the Y axis.
        y_max: Upper bound of the Y axis; values are clamped to ``[0, y_max]``.
            When it is not positive every point is drawn on the X axis.
        path: Destination file, written as UTF-8.

    All caller-supplied text (title, axis label, series names) is XML-escaped
    so that characters such as ``&`` or ``<`` cannot produce a malformed SVG.
    """
    # Function-scope import keeps the block self-contained.
    from xml.sax.saxutils import escape

    width = 900
    height = 420
    lm, rm, tm, bm = 70, 40, 50, 55  # left / right / top / bottom margins
    pw = width - lm - rm  # plot area width
    ph = height - tm - bm  # plot area height
    yb = tm + ph  # pixel row of the X axis

    safe_title = escape(title)
    safe_y_label = escape(y_label)

    all_steps = [step for _, pts, _ in series for step, _ in pts]
    if not all_steps:
        path.write_text(
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">'
            f'<text x="40" y="40">{safe_title} (no data)</text></svg>',
            encoding="utf-8",
        )
        return

    x_min, x_max = min(all_steps), max(all_steps)
    if x_max == x_min:
        # Avoid a zero-width X range (division by zero in ``sx``).
        x_max = x_min + 1

    def sx(x: int) -> int:
        return lm + int((x - x_min) / (x_max - x_min) * pw)

    def sy(y: float) -> int:
        y = max(0.0, min(y_max, y))
        return yb - int((y / y_max) * ph) if y_max > 0 else yb

    parts: list[str] = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">',
        '<rect width="100%" height="100%" fill="#ffffff"/>',
        f'<text x="{lm}" y="28" font-size="16" font-family="sans-serif">{safe_title}</text>',
        f'<text x="20" y="{tm + ph // 2}" font-size="12" font-family="sans-serif" '
        f'transform="rotate(-90 20 {tm + ph // 2})">{safe_y_label}</text>',
        f'<line x1="{lm}" y1="{yb}" x2="{lm + pw}" y2="{yb}" stroke="#111" stroke-width="2"/>',
        f'<line x1="{lm}" y1="{tm}" x2="{lm}" y2="{yb}" stroke="#111" stroke-width="2"/>',
        f'<text x="{lm + pw // 2}" y="{height - 12}" text-anchor="middle" '
        f'font-size="12" font-family="sans-serif">Training step</text>',
    ]

    # Five evenly spaced Y ticks from 0 to y_max.
    for i in range(5):
        val = (i / 4) * y_max
        yy = sy(val)
        parts.append(
            f'<line x1="{lm - 4}" y1="{yy}" x2="{lm}" y2="{yy}" stroke="#999"/>'
        )
        parts.append(
            f'<text x="{lm - 8}" y="{yy + 4}" text-anchor="end" font-size="10" '
            f'font-family="sans-serif">{val:.2f}</text>'
        )

    legend_x = lm + pw - 200
    legend_y = tm + 8
    for idx, (name, pts, color) in enumerate(series):
        ordered = sorted(pts, key=lambda point: point[0])
        if not ordered:
            continue
        # Colours land inside attribute values, so quotes must be escaped too.
        safe_color = escape(color, {'"': "&quot;"})
        if len(ordered) == 1:
            cx, cy = sx(ordered[0][0]), sy(ordered[0][1])
            parts.append(
                f'<circle cx="{cx}" cy="{cy}" r="4" fill="{safe_color}" stroke="#111"/>'
            )
        else:
            d = "M " + " L ".join(f"{sx(s)} {sy(v)}" for s, v in ordered)
            parts.append(
                f'<path d="{d}" fill="none" stroke="{safe_color}" stroke-width="2.5"/>'
            )
        parts.append(
            f'<rect x="{legend_x}" y="{legend_y + idx * 18}" width="10" height="10" fill="{safe_color}"/>'
        )
        parts.append(
            f'<text x="{legend_x + 16}" y="{legend_y + idx * 18 + 9}" font-size="11" '
            f'font-family="sans-serif">{escape(name)}</text>'
        )

    parts.append("</svg>")
    path.write_text("\n".join(parts), encoding="utf-8")
|
|
|
|
def _scatter_accuracy_vs_cot_svg(rows: list[dict], path: Path, title: str) -> None:
    """Write an accuracy-vs-CoT-length scatter plot as an SVG file.

    Args:
        rows: Summary dicts. Each must provide ``avg_cot_words`` (X value),
            ``accuracy`` (Y value, clamped to ``[0, 1]``) and
            ``checkpoint_step``. ``run_label`` (default ``"run"``) selects the
            colour/series and ``checkpoint_dir`` (default ``step-<step>``) is
            used in the hover tooltip.
        path: Destination file, written as UTF-8.
        title: Chart heading.

    Rows sharing a ``run_label`` get one colour; when a label has at least two
    rows a faint polyline connects them in increasing step order. Every point
    is annotated with its step number and a ``<title>`` tooltip.

    Title, legend labels and tooltips are XML-escaped so that characters such
    as ``&`` or ``<`` in labels or directory names cannot break the SVG.
    """
    # Function-scope import keeps the block self-contained.
    from xml.sax.saxutils import escape

    width = 640
    height = 520
    lm, rm, tm, bm = 72, 160, 52, 64  # wide right margin leaves room for the legend
    pw = width - lm - rm
    ph = height - tm - bm
    yb = tm + ph

    safe_title = escape(title)

    if not rows:
        path.write_text(
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">'
            f'<text x="40" y="40">{safe_title} (no data)</text></svg>',
            encoding="utf-8",
        )
        return

    # Group rows by label once; dict insertion order gives first-seen label order.
    by_label: dict[str, list[dict]] = {}
    for r in rows:
        by_label.setdefault(str(r.get("run_label", "run")), []).append(r)
    labels = list(by_label)

    colors = ["#2563eb", "#dc2626", "#16a34a", "#9333ea", "#ca8a04", "#0891b2"]
    color_map = {lab: colors[i % len(colors)] for i, lab in enumerate(labels)}

    xs = [float(r["avg_cot_words"]) for r in rows]
    x_min, x_max = min(xs), max(xs)
    y_min, y_max = 0.0, 1.0
    if x_max <= x_min:
        x_max = x_min + 1.0
    # Pad the X range so extreme points do not sit on the plot border.
    pad = (x_max - x_min) * 0.06 + 1.0
    x_min = max(0.0, x_min - pad)
    x_max = x_max + pad

    def sx(x: float) -> float:
        return lm + (x - x_min) / (x_max - x_min) * pw

    def sy(y: float) -> float:
        y = max(y_min, min(y_max, y))
        return yb - (y - y_min) / (y_max - y_min) * ph

    parts: list[str] = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">',
        '<rect width="100%" height="100%" fill="#fafafa"/>',
        f'<text x="{lm}" y="30" font-size="15" font-family="sans-serif">{safe_title}</text>',
        f'<text x="{width // 2}" y="{height - 18}" text-anchor="middle" font-size="12" '
        f'font-family="sans-serif">Avg CoT length (words)</text>',
        f'<text x="18" y="{tm + ph // 2}" font-size="12" font-family="sans-serif" '
        f'transform="rotate(-90 18 {tm + ph // 2})">Accuracy</text>',
        f'<line x1="{lm}" y1="{yb}" x2="{lm + pw}" y2="{yb}" stroke="#111" stroke-width="2"/>',
        f'<line x1="{lm}" y1="{tm}" x2="{lm}" y2="{yb}" stroke="#111" stroke-width="2"/>',
    ]

    # Y ticks.
    for i in range(5):
        val = y_min + (i / 4) * (y_max - y_min)
        yy = sy(val)
        parts.append(f'<line x1="{lm - 4}" y1="{yy}" x2="{lm}" y2="{yy}" stroke="#bbb"/>')
        parts.append(
            f'<text x="{lm - 8}" y="{yy + 4}" text-anchor="end" font-size="10" '
            f'font-family="sans-serif">{val:.2f}</text>'
        )

    # X ticks.
    for i in range(5):
        frac = i / 4
        xv = x_min + frac * (x_max - x_min)
        xx = sx(xv)
        parts.append(f'<line x1="{xx}" y1="{yb}" x2="{xx}" y2="{yb + 4}" stroke="#bbb"/>')
        parts.append(
            f'<text x="{xx}" y="{yb + 18}" text-anchor="middle" font-size="10" '
            f'font-family="sans-serif">{xv:.0f}</text>'
        )

    # Faint trajectory per label, following increasing training step.
    for lab in labels:
        sub = sorted(by_label[lab], key=lambda r: int(r["checkpoint_step"]))
        if len(sub) >= 2:
            d = "M " + " L ".join(
                f'{sx(float(r["avg_cot_words"])):.1f} {sy(float(r["accuracy"])):.1f}'
                for r in sub
            )
            parts.append(
                f'<path d="{d}" fill="none" stroke="{color_map[lab]}" '
                f'stroke-width="1.5" stroke-opacity="0.35"/>'
            )

    # One marker per row, with tooltip and step annotation.
    for r in rows:
        color = color_map[str(r.get("run_label", "run"))]
        accuracy = float(r["accuracy"])
        cot_words = float(r["avg_cot_words"])
        cx = sx(cot_words)
        cy = sy(accuracy)
        step = int(r["checkpoint_step"])
        name = str(r.get("checkpoint_dir", f"step-{step}"))
        tip = escape(f"{name}: accuracy={accuracy:.4f}, avg_cot_words={cot_words:.2f}")
        parts.append(
            f'<g><circle cx="{cx:.1f}" cy="{cy:.1f}" r="5" fill="{color}" stroke="#111" stroke-width="1">'
            f"<title>{tip}</title></circle>"
            f'<text x="{cx + 8:.1f}" y="{cy - 6:.1f}" font-size="9" font-family="sans-serif" fill="#333">{step}</text></g>'
        )

    legend_x = lm + pw + 14
    legend_y = tm + 4
    parts.append(
        f'<text x="{legend_x}" y="{legend_y}" font-size="11" font-family="sans-serif" font-weight="bold">Series</text>'
    )
    for idx, lab in enumerate(labels):
        cy = legend_y + 18 + idx * 20
        parts.append(
            f'<rect x="{legend_x}" y="{cy - 8}" width="10" height="10" fill="{color_map[lab]}"/>'
        )
        parts.append(
            f'<text x="{legend_x + 16}" y="{cy}" font-size="11" font-family="sans-serif">{escape(lab)}</text>'
        )

    parts.append("</svg>")
    path.write_text("\n".join(parts), encoding="utf-8")
|
|
|
|
def _resolve_out_root(default: Path) -> Path:
    """Choose the output directory for this run.

    The ``OUT_ROOT`` environment variable wins when it is set to something
    other than blank/whitespace; otherwise ``default`` is used. Either way the
    value is passed through ``resolve_repo_path`` before being returned.
    """
    override = os.environ.get("OUT_ROOT")
    if override is not None and str(override).strip():
        return resolve_repo_path(override)
    return resolve_repo_path(str(default))
|
|
|
|
def main() -> None:
    """Evaluate every discovered checkpoint and write summaries, rollouts and charts.

    Configuration is read from environment variables:

    * ``BASE_CONFIG`` (required): YAML config path.
    * ``EVAL_MAX_SAMPLES`` (default 200), ``EVAL_BATCH_SIZE`` (default 4),
      ``ROLLOUT_SAMPLES`` (default 8).
    * ``PERMANENT_ROOT``: when non-blank, the single directory to scan, with
      ``RUN_LABEL`` (default ``"permanent"``) as the series label.
    * ``PERMANENT_CW1`` / ``PERMANENT_CW5`` (both required when
      ``PERMANENT_ROOT`` is blank): two directories scanned as the
      ``correctness_weight_1`` and ``correctness_weight_5`` series.
    * ``OUT_ROOT``: optional output directory override.

    ``evaluate_one_model`` is called on every rank; only rank 0 aggregates the
    returned records, writes per-checkpoint JSONL files, the JSON/CSV summary
    and the three SVG charts.
    """
    env = os.environ
    rank, _, _ = _init_distributed()
    base_cfg = _load_yaml(str(resolve_repo_path(env["BASE_CONFIG"])))

    eval_max_samples = int(env.get("EVAL_MAX_SAMPLES", "200"))
    eval_batch_size = int(env.get("EVAL_BATCH_SIZE", "4"))
    rollout_n = int(env.get("ROLLOUT_SAMPLES", "8"))

    # --- Job discovery: single root, or the cw=1 / cw=5 pair -----------------
    permanent_root = env.get("PERMANENT_ROOT", "").strip()
    jobs_cw1: list = []
    jobs_cw5: list = []
    if permanent_root:
        pr = resolve_repo_path(permanent_root)
        run_label_single = env.get("RUN_LABEL", "permanent")
        out_root = _resolve_out_root(pr / "eval_permanent")
        all_jobs = _discover_checkpoint_jobs(base_cfg, pr, run_label_single)
    else:
        cw1_root = resolve_repo_path(env["PERMANENT_CW1"])
        cw5_root = resolve_repo_path(env["PERMANENT_CW5"])
        out_root = _resolve_out_root(cw1_root.parent / "eval_permanent")
        jobs_cw1 = _discover_checkpoint_jobs(base_cfg, cw1_root, "correctness_weight_1")
        jobs_cw5 = _discover_checkpoint_jobs(base_cfg, cw5_root, "correctness_weight_5")
        all_jobs = jobs_cw1 + jobs_cw5

    # --- Rank 0 prepares the output tree and reports what was found ----------
    if rank == 0:
        for directory in (out_root, out_root / "rollouts", out_root / "full_outputs"):
            directory.mkdir(parents=True, exist_ok=True)
        if permanent_root:
            print(f"PERMANENT_ROOT: {resolve_repo_path(permanent_root)} ({len(all_jobs)} checkpoints)")
        else:
            print(f"Found {len(jobs_cw1)} checkpoints (cw=1), {len(jobs_cw5)} checkpoints (cw=5)")
        # ``all_jobs`` is jobs_cw1 followed by jobs_cw5 in pair mode, so one
        # loop lists the checkpoints in the same order for both modes.
        for run_label, step, _, _, name in all_jobs:
            print(f" {run_label} step={step} ({name})")

    # Other ranks wait here until rank 0 has created the directories.
    if dist.is_initialized():
        dist.barrier()

    def write_jsonl(target: Path, items: list) -> None:
        # One JSON object per line; the parent directory is created on demand.
        target.parent.mkdir(parents=True, exist_ok=True)
        with target.open("w", encoding="utf-8") as handle:
            for item in items:
                handle.write(json.dumps(item, ensure_ascii=True) + "\n")

    # --- Per-checkpoint evaluation -------------------------------------------
    rows: list[dict] = []
    for run_label, step, _resolved_str, resolved_path, dir_name in all_jobs:
        records = evaluate_one_model(
            model_dir=resolved_path,
            base_cfg=base_cfg,
            eval_max_samples=eval_max_samples,
            batch_size=eval_batch_size,
        )
        if rank == 0:
            if records:
                count = len(records)
                acc = sum(float(r["correctness"]) for r in records) / count
                avg_cot = sum(float(r["cot_words"]) for r in records) / count
            else:
                acc = 0.0
                avg_cot = 0.0
            rows.append(
                {
                    "run_label": run_label,
                    "checkpoint_step": step,
                    "checkpoint_dir": dir_name,
                    "model_dir": str(resolved_path),
                    "num_examples": len(records),
                    "accuracy": acc,
                    "avg_cot_words": avg_cot,
                }
            )

            # A short sample of records first, then the complete set.
            write_jsonl(
                out_root / "rollouts" / run_label / f"{dir_name}_rollouts.jsonl",
                records[:rollout_n],
            )
            write_jsonl(
                out_root / "full_outputs" / run_label / f"{dir_name}_outputs.jsonl",
                records,
            )

            print(
                f"Eval {run_label} {dir_name}: acc={acc:.4f} avg_cot_words={avg_cot:.2f} n={len(records)}"
            )

        # Hold the other ranks until rank 0 has finished writing this checkpoint.
        if dist.is_initialized():
            dist.barrier()

    # Everything below is reporting, which only rank 0 performs.
    if rank != 0:
        return

    rows.sort(key=lambda r: (r["run_label"], r["checkpoint_step"], r["checkpoint_dir"]))

    # --- Summary tables ------------------------------------------------------
    summary_json = out_root / "permanent_checkpoints_eval.json"
    summary_csv = out_root / "permanent_checkpoints_eval.csv"
    summary_json.write_text(json.dumps(rows, indent=2), encoding="utf-8")
    columns = [
        "run_label",
        "checkpoint_step",
        "checkpoint_dir",
        "model_dir",
        "num_examples",
        "accuracy",
        "avg_cot_words",
    ]
    with summary_csv.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows(rows)

    # --- Charts --------------------------------------------------------------
    palette = ["#2563eb", "#dc2626", "#16a34a", "#9333ea", "#ca8a04", "#0891b2"]
    uniq_labels = sorted({str(r["run_label"]) for r in rows})

    def build_series(ykey: str) -> list[tuple[str, list[tuple[int, float]], str]]:
        # One (label, points, colour) entry per label that has at least one
        # row; the colour follows the label's position in ``uniq_labels``.
        built = []
        for i, lab in enumerate(uniq_labels):
            points = [
                (int(r["checkpoint_step"]), float(r[ykey]))
                for r in rows
                if r["run_label"] == lab
            ]
            if points:
                built.append((lab, points, palette[i % len(palette)]))
        return built

    acc_series = build_series("accuracy")
    cot_series = build_series("avg_cot_words")

    # The CoT chart's Y axis never shrinks below 1.0.
    cot_max = max([1.0, *(float(r["avg_cot_words"]) for r in rows)])

    accuracy_svg = out_root / "accuracy_vs_step.svg"
    cot_svg = out_root / "avg_cot_vs_step.svg"
    scatter_svg = out_root / "accuracy_vs_avg_cot_words.svg"

    if acc_series:
        _line_chart_svg(
            acc_series,
            "GSM8K accuracy vs checkpoint step",
            "Accuracy",
            1.0,
            accuracy_svg,
        )
    if cot_series:
        _line_chart_svg(
            cot_series,
            "Average CoT length (words) vs checkpoint step",
            "Avg CoT words",
            cot_max,
            cot_svg,
        )

    _scatter_accuracy_vs_cot_svg(
        rows,
        scatter_svg,
        "GSM8K accuracy vs average CoT length (words)",
    )

    print(f"Saved: {summary_json}")
    print(f"Saved: {summary_csv}")
    if acc_series:
        print(f"Saved: {accuracy_svg}")
    if cot_series:
        print(f"Saved: {cot_svg}")
    print(f"Saved: {scatter_svg}")
    print(f"Rollouts: {out_root / 'rollouts'}/<run_label>/")
    print(f"Full outputs: {out_root / 'full_outputs'}/<run_label>/")
|
|
|
|
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|