# neuralese_temp/src/eval_permanent_checkpoints.py
# Uploaded by psidharth567
# Commit message: Export neuralese codebase (cache and .env excluded).
# Commit: dbc69f3
from __future__ import annotations
import csv
import json
import os
import re
from pathlib import Path
import torch.distributed as dist
import hackable # noqa: F401
from hackable.utils import resolve_repo_path
from eval_sweep_models import (
_init_distributed,
_load_yaml,
_resolve_local_model_dir,
evaluate_one_model,
)
def _parse_checkpoint_step(dirname: str) -> int | None:
    """Extract the training step encoded in a checkpoint directory name.

    Two layouts are recognised, tried in this order:

    * a name that is exactly ``checkpoint-<digits>``;
    * any name that ends with ``-step-<digits>``.

    Returns:
        The step as an ``int``, or ``None`` when neither layout matches.
    """
    attempts = (
        (re.match, r"^checkpoint-(\d+)$"),
        (re.search, r"-step-(\d+)$"),
    )
    for matcher, pattern in attempts:
        hit = matcher(pattern, dirname)
        if hit is not None:
            return int(hit.group(1))
    return None
def _discover_checkpoint_jobs(
    base_cfg: dict, permanent_root: Path, run_label: str
) -> list[tuple[str, int, str, Path, str]]:
    """Collect one evaluation job per checkpoint sub-directory of ``permanent_root``.

    Only immediate sub-directories whose name yields a step via
    ``_parse_checkpoint_step`` are kept; each one is passed through
    ``_resolve_local_model_dir`` together with ``base_cfg``.

    Returns:
        Tuples ``(run_label, step, resolved_model_dir_str, resolved_path,
        dir_name)`` ordered by step, then by directory name.

    Raises:
        FileNotFoundError: if the resolved ``permanent_root`` is not a directory.
    """
    base_dir = permanent_root.resolve()
    if not base_dir.is_dir():
        raise FileNotFoundError(f"Not a directory: {base_dir}")

    found: list[tuple[str, int, str, Path, str]] = []
    for entry in sorted(base_dir.iterdir()):
        # Plain files are never checkpoints, so the name is only parsed for directories.
        step = _parse_checkpoint_step(entry.name) if entry.is_dir() else None
        if step is not None:
            model_dir = _resolve_local_model_dir(base_cfg, str(entry))
            found.append((run_label, step, str(model_dir), model_dir, entry.name))

    return sorted(found, key=lambda job: (job[1], job[4]))
def _line_chart_svg(
    series: list[tuple[str, list[tuple[int, float]], str]],
    title: str,
    y_label: str,
    y_max: float,
    path: Path,
) -> None:
    """Write a multi-series line chart (value vs. training step) as an SVG file.

    Args:
        series: ``(name, points, color)`` triples where ``points`` is a list of
            ``(step, value)`` pairs. A series with one point is drawn as a
            circle, one with two or more as a polyline through the points in
            step order, and an empty series is skipped (including its legend
            entry).
        title: Heading drawn above the plot.
        y_label: Rotated label drawn beside the y axis.
        y_max: Top of the y axis; plotted values are clamped to ``[0, y_max]``.
        path: Destination file, written as UTF-8 text.

    ``title``, ``y_label`` and series names are XML-escaped before being
    embedded, so text containing ``&``, ``<`` or ``>`` still produces a
    well-formed document. Colors are inserted into attributes verbatim.
    """
    # Function-scope import keeps this block independent of the file's import list.
    from xml.sax.saxutils import escape

    width = 900
    height = 420
    lm, rm, tm, bm = 70, 40, 50, 55  # left / right / top / bottom margins
    pw = width - lm - rm  # plot-area width
    ph = height - tm - bm  # plot-area height
    yb = tm + ph  # pixel row of the x axis

    safe_title = escape(str(title))
    safe_y_label = escape(str(y_label))

    all_steps = [step for _, pts, _ in series for step, _ in pts]
    if not all_steps:
        path.write_text(
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">'
            f'<text x="40" y="40">{safe_title} (no data)</text></svg>',
            encoding="utf-8",
        )
        return

    x_min, x_max = min(all_steps), max(all_steps)
    if x_max == x_min:
        # Avoid a zero-width range (division by zero in ``sx``).
        x_max = x_min + 1

    def sx(x: int) -> int:
        """Map a training step to a pixel column."""
        return lm + int((x - x_min) / (x_max - x_min) * pw)

    def sy(y: float) -> int:
        """Map a value to a pixel row, clamping to ``[0, y_max]``."""
        y = max(0.0, min(y_max, y))
        return yb - int((y / y_max) * ph) if y_max > 0 else yb

    parts: list[str] = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">',
        '<rect width="100%" height="100%" fill="#ffffff"/>',
        f'<text x="{lm}" y="28" font-size="16" font-family="sans-serif">{safe_title}</text>',
        f'<text x="20" y="{tm + ph // 2}" font-size="12" font-family="sans-serif" '
        f'transform="rotate(-90 20 {tm + ph // 2})">{safe_y_label}</text>',
        f'<line x1="{lm}" y1="{yb}" x2="{lm + pw}" y2="{yb}" stroke="#111" stroke-width="2"/>',
        f'<line x1="{lm}" y1="{tm}" x2="{lm}" y2="{yb}" stroke="#111" stroke-width="2"/>',
        f'<text x="{lm + pw // 2}" y="{height - 12}" text-anchor="middle" '
        f'font-size="12" font-family="sans-serif">Training step</text>',
    ]

    # Five evenly spaced y-axis ticks from 0 to y_max.
    for i in range(5):
        val = (i / 4) * y_max
        yy = sy(val)
        parts.append(
            f'<line x1="{lm - 4}" y1="{yy}" x2="{lm}" y2="{yy}" stroke="#999"/>'
        )
        parts.append(
            f'<text x="{lm - 8}" y="{yy + 4}" text-anchor="end" font-size="10" '
            f'font-family="sans-serif">{val:.2f}</text>'
        )

    legend_x = lm + pw - 200
    legend_y = tm + 8
    for idx, (name, pts, color) in enumerate(series):
        pts_sorted = sorted(pts, key=lambda z: z[0])
        if not pts_sorted:
            continue
        if len(pts_sorted) == 1:
            # A single point cannot form a line; mark it with a dot instead.
            cx, cy = sx(pts_sorted[0][0]), sy(pts_sorted[0][1])
            parts.append(
                f'<circle cx="{cx}" cy="{cy}" r="4" fill="{color}" stroke="#111"/>'
            )
        else:
            d = "M " + " L ".join(f"{sx(s)} {sy(v)}" for s, v in pts_sorted)
            parts.append(
                f'<path d="{d}" fill="none" stroke="{color}" stroke-width="2.5"/>'
            )
        # Legend rows are positioned by the series index, so skipped (empty)
        # series leave a gap rather than shifting later entries.
        parts.append(
            f'<rect x="{legend_x}" y="{legend_y + idx * 18}" width="10" height="10" fill="{color}"/>'
        )
        parts.append(
            f'<text x="{legend_x + 16}" y="{legend_y + idx * 18 + 9}" font-size="11" '
            f'font-family="sans-serif">{escape(str(name))}</text>'
        )

    parts.append("</svg>")
    path.write_text("\n".join(parts), encoding="utf-8")
def _scatter_accuracy_vs_cot_svg(rows: list[dict], path: Path, title: str) -> None:
    """Write a scatter plot of accuracy against average CoT length as an SVG file.

    Each row contributes one labelled dot at ``(avg_cot_words, accuracy)``.
    Rows are grouped by ``run_label`` (default ``"run"``); every group gets
    one color and, when it has at least two rows, a faint polyline joining
    its dots in ``checkpoint_step`` order.

    Args:
        rows: Dicts providing ``avg_cot_words``, ``accuracy`` and
            ``checkpoint_step``; ``run_label`` and ``checkpoint_dir`` are
            optional.
        path: Destination file, written as UTF-8 text.
        title: Heading drawn above the plot.

    The title, legend labels and tooltips are XML-escaped before being
    embedded, so values containing ``&``, ``<`` or ``>`` still produce a
    well-formed document.
    """
    # Function-scope import keeps this block independent of the file's import list.
    from xml.sax.saxutils import escape

    width = 640
    height = 520
    lm, rm, tm, bm = 72, 160, 52, 64  # left / right / top / bottom margins
    pw = width - lm - rm  # plot-area width
    ph = height - tm - bm  # plot-area height
    yb = tm + ph  # pixel row of the x axis

    safe_title = escape(str(title))

    if not rows:
        path.write_text(
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">'
            f'<text x="40" y="40">{safe_title} (no data)</text></svg>',
            encoding="utf-8",
        )
        return

    def label_of(row: dict) -> str:
        """Return the row's series label, falling back to ``"run"``."""
        return str(row.get("run_label", "run"))

    # Group rows per label in one pass; dict insertion order keeps labels in
    # order of first appearance, which fixes the color assignment.
    by_label: dict[str, list[dict]] = {}
    for r in rows:
        by_label.setdefault(label_of(r), []).append(r)
    labels = list(by_label)

    colors = ["#2563eb", "#dc2626", "#16a34a", "#9333ea", "#ca8a04", "#0891b2"]
    color_map = {lab: colors[i % len(colors)] for i, lab in enumerate(labels)}

    xs = [float(r["avg_cot_words"]) for r in rows]
    x_min, x_max = min(xs), max(xs)
    y_min, y_max = 0.0, 1.0
    if x_max <= x_min:
        x_max = x_min + 1.0
    # Pad the x range so extreme dots are not drawn on the plot border.
    pad = (x_max - x_min) * 0.06 + 1.0
    x_min = max(0.0, x_min - pad)
    x_max = x_max + pad

    def sx(x: float) -> float:
        """Map an average CoT length to a pixel column."""
        return lm + (x - x_min) / (x_max - x_min) * pw

    def sy(y: float) -> float:
        """Map an accuracy to a pixel row, clamping to ``[y_min, y_max]``."""
        y = max(y_min, min(y_max, y))
        return yb - (y - y_min) / (y_max - y_min) * ph

    parts: list[str] = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">',
        '<rect width="100%" height="100%" fill="#fafafa"/>',
        f'<text x="{lm}" y="30" font-size="15" font-family="sans-serif">{safe_title}</text>',
        f'<text x="{width // 2}" y="{height - 18}" text-anchor="middle" font-size="12" '
        f'font-family="sans-serif">Avg CoT length (words)</text>',
        f'<text x="18" y="{tm + ph // 2}" font-size="12" font-family="sans-serif" '
        f'transform="rotate(-90 18 {tm + ph // 2})">Accuracy</text>',
        f'<line x1="{lm}" y1="{yb}" x2="{lm + pw}" y2="{yb}" stroke="#111" stroke-width="2"/>',
        f'<line x1="{lm}" y1="{tm}" x2="{lm}" y2="{yb}" stroke="#111" stroke-width="2"/>',
    ]

    # Five evenly spaced ticks on each axis.
    for i in range(5):
        val = y_min + (i / 4) * (y_max - y_min)
        yy = sy(val)
        parts.append(f'<line x1="{lm - 4}" y1="{yy}" x2="{lm}" y2="{yy}" stroke="#bbb"/>')
        parts.append(
            f'<text x="{lm - 8}" y="{yy + 4}" text-anchor="end" font-size="10" '
            f'font-family="sans-serif">{val:.2f}</text>'
        )
    for i in range(5):
        frac = i / 4
        xv = x_min + frac * (x_max - x_min)
        xx = sx(xv)
        parts.append(f'<line x1="{xx}" y1="{yb}" x2="{xx}" y2="{yb + 4}" stroke="#bbb"/>')
        parts.append(
            f'<text x="{xx}" y="{yb + 18}" text-anchor="middle" font-size="10" '
            f'font-family="sans-serif">{xv:.0f}</text>'
        )

    # Faint trajectory per label, joining its checkpoints in step order.
    for lab in labels:
        sub = sorted(by_label[lab], key=lambda r: int(r["checkpoint_step"]))
        color = color_map[lab]
        if len(sub) >= 2:
            d = "M " + " L ".join(
                f'{sx(float(r["avg_cot_words"])):.1f} {sy(float(r["accuracy"])):.1f}'
                for r in sub
            )
            parts.append(
                f'<path d="{d}" fill="none" stroke="{color}" stroke-width="1.5" stroke-opacity="0.35"/>'
            )

    # One dot per row, with a hover tooltip and the step number beside it.
    for r in rows:
        color = color_map[label_of(r)]
        cx = sx(float(r["avg_cot_words"]))
        cy = sy(float(r["accuracy"]))
        step = int(r["checkpoint_step"])
        name = str(r.get("checkpoint_dir", f"step-{step}"))
        tip = f"{name}: accuracy={float(r['accuracy']):.4f}, avg_cot_words={float(r['avg_cot_words']):.2f}"
        parts.append(
            f'<g><circle cx="{cx:.1f}" cy="{cy:.1f}" r="5" fill="{color}" stroke="#111" stroke-width="1">'
            f"<title>{escape(tip)}</title></circle>"
            f'<text x="{cx + 8:.1f}" y="{cy - 6:.1f}" font-size="9" font-family="sans-serif" fill="#333">{step}</text></g>'
        )

    legend_x = lm + pw + 14
    legend_y = tm + 4
    parts.append(
        f'<text x="{legend_x}" y="{legend_y}" font-size="11" font-family="sans-serif" font-weight="bold">Series</text>'
    )
    for idx, lab in enumerate(labels):
        cy = legend_y + 18 + idx * 20
        parts.append(
            f'<rect x="{legend_x}" y="{cy - 8}" width="10" height="10" fill="{color_map[lab]}"/>'
        )
        parts.append(
            f'<text x="{legend_x + 16}" y="{cy}" font-size="11" font-family="sans-serif">{escape(lab)}</text>'
        )

    parts.append("</svg>")
    path.write_text("\n".join(parts), encoding="utf-8")
def _resolve_out_root(default: Path) -> Path:
    """Pick the output directory, letting the ``OUT_ROOT`` env var override ``default``.

    ``OUT_ROOT`` is used (as given, not stripped) when it is set and not
    blank; otherwise ``default`` is used. Either way the choice is passed
    through ``resolve_repo_path``.
    """
    override = os.environ.get("OUT_ROOT")
    has_override = override is not None and bool(str(override).strip())
    target = override if has_override else str(default)
    return resolve_repo_path(target)
def main() -> None:
    """Evaluate every discovered checkpoint and write summaries, rollouts and charts.

    Configuration comes from environment variables:

    * ``BASE_CONFIG`` (required): path to the YAML base config.
    * ``EVAL_MAX_SAMPLES`` (default ``200``), ``EVAL_BATCH_SIZE`` (default
      ``4``) and ``ROLLOUT_SAMPLES`` (default ``8``).
    * ``PERMANENT_ROOT``: when non-blank, a single checkpoint root whose rows
      are labelled with ``RUN_LABEL`` (default ``"permanent"``). Otherwise
      ``PERMANENT_CW1`` and ``PERMANENT_CW5`` are both required and labelled
      ``correctness_weight_1`` / ``correctness_weight_5``.
    * ``OUT_ROOT``: optional output directory override.

    Every rank calls ``evaluate_one_model`` for every checkpoint. Within this
    function only rank 0 prints and writes files; the other ranks return
    after the final barrier.

    Raises:
        KeyError: if a required environment variable is missing.
    """
    rank, _, _ = _init_distributed()
    base_cfg = _load_yaml(str(resolve_repo_path(os.environ["BASE_CONFIG"])))
    eval_max_samples = int(os.environ.get("EVAL_MAX_SAMPLES", "200"))
    eval_batch_size = int(os.environ.get("EVAL_BATCH_SIZE", "4"))
    rollout_n = int(os.environ.get("ROLLOUT_SAMPLES", "8"))
    permanent_root = os.environ.get("PERMANENT_ROOT", "").strip()

    if permanent_root:
        # Single-root mode: one run label, output defaults to a folder inside the root.
        pr = resolve_repo_path(permanent_root)
        run_label_single = os.environ.get("RUN_LABEL", "permanent")
        out_default = pr / "eval_permanent"
        out_root = _resolve_out_root(out_default)
        jobs_single = _discover_checkpoint_jobs(base_cfg, pr, run_label_single)
        all_jobs = jobs_single
        jobs_cw1: list = []
        jobs_cw5: list = []
    else:
        # Two-root mode: output defaults to a sibling of the cw1 root.
        cw1_root = resolve_repo_path(os.environ["PERMANENT_CW1"])
        cw5_root = resolve_repo_path(os.environ["PERMANENT_CW5"])
        out_default = cw1_root.parent / "eval_permanent"
        out_root = _resolve_out_root(out_default)
        jobs_cw1 = _discover_checkpoint_jobs(base_cfg, cw1_root, "correctness_weight_1")
        jobs_cw5 = _discover_checkpoint_jobs(base_cfg, cw5_root, "correctness_weight_5")
        all_jobs = jobs_cw1 + jobs_cw5

    if rank == 0:
        # Rank 0 creates the output tree and lists what will be evaluated.
        out_root.mkdir(parents=True, exist_ok=True)
        (out_root / "rollouts").mkdir(parents=True, exist_ok=True)
        (out_root / "full_outputs").mkdir(parents=True, exist_ok=True)
        if permanent_root:
            print(f"PERMANENT_ROOT: {resolve_repo_path(permanent_root)} ({len(all_jobs)} checkpoints)")
            for run_label, step, _, _, name in all_jobs:
                print(f"  {run_label} step={step} ({name})")
        else:
            print(f"Found {len(jobs_cw1)} checkpoints (cw=1), {len(jobs_cw5)} checkpoints (cw=5)")
            for jl in (jobs_cw1, jobs_cw5):
                for run_label, step, _, _, name in jl:
                    print(f"  {run_label} step={step} ({name})")

    # Hold all ranks until rank 0 has finished creating the output directories.
    if dist.is_initialized():
        dist.barrier()

    rows: list[dict] = []
    for run_label, step, _resolved_str, resolved_path, dir_name in all_jobs:
        # Called on every rank; only rank 0 consumes the returned records below.
        records = evaluate_one_model(
            model_dir=resolved_path,
            base_cfg=base_cfg,
            eval_max_samples=eval_max_samples,
            batch_size=eval_batch_size,
        )
        if rank == 0:
            # Means over all records; 0.0 when nothing came back.
            acc = sum(float(r["correctness"]) for r in records) / len(records) if records else 0.0
            avg_cot = sum(float(r["cot_words"]) for r in records) / len(records) if records else 0.0
            row = {
                "run_label": run_label,
                "checkpoint_step": step,
                "checkpoint_dir": dir_name,
                "model_dir": str(resolved_path),
                "num_examples": len(records),
                "accuracy": acc,
                "avg_cot_words": avg_cot,
            }
            rows.append(row)

            # First ``rollout_n`` records as a quick-look sample.
            rollout_dir = out_root / "rollouts" / run_label
            rollout_dir.mkdir(parents=True, exist_ok=True)
            rollout_path = rollout_dir / f"{dir_name}_rollouts.jsonl"
            with rollout_path.open("w", encoding="utf-8") as handle:
                for rec in records[:rollout_n]:
                    handle.write(json.dumps(rec, ensure_ascii=True) + "\n")

            # Every record for this checkpoint.
            full_path = out_root / "full_outputs" / run_label / f"{dir_name}_outputs.jsonl"
            full_path.parent.mkdir(parents=True, exist_ok=True)
            with full_path.open("w", encoding="utf-8") as handle:
                for rec in records:
                    handle.write(json.dumps(rec, ensure_ascii=True) + "\n")
            print(
                f"Eval {run_label} {dir_name}: acc={acc:.4f} avg_cot_words={avg_cot:.2f} n={len(records)}"
            )

    # NOTE(review): assumed to run once after the loop rather than once per
    # checkpoint — confirm against the intended structure.
    if dist.is_initialized():
        dist.barrier()

    # Everything below (summaries and charts) is rank-0 only.
    if rank != 0:
        return

    rows.sort(key=lambda r: (r["run_label"], r["checkpoint_step"], r["checkpoint_dir"]))
    summary_json = out_root / "permanent_checkpoints_eval.json"
    summary_csv = out_root / "permanent_checkpoints_eval.csv"
    summary_json.write_text(json.dumps(rows, indent=2), encoding="utf-8")
    with summary_csv.open("w", encoding="utf-8", newline="") as handle:
        w = csv.DictWriter(
            handle,
            fieldnames=[
                "run_label",
                "checkpoint_step",
                "checkpoint_dir",
                "model_dir",
                "num_examples",
                "accuracy",
                "avg_cot_words",
            ],
        )
        w.writeheader()
        for row in rows:
            w.writerow(row)

    # Builds the (step, value) points of one run label for the metric ``ykey``.
    def series_for(label: str, ykey: str) -> list[tuple[int, float]]:
        return [
            (int(r["checkpoint_step"]), float(r[ykey]))
            for r in rows
            if r["run_label"] == label
        ]

    palette = ["#2563eb", "#dc2626", "#16a34a", "#9333ea", "#ca8a04", "#0891b2"]
    # Colors are assigned by position in the alphabetically sorted label list.
    uniq_labels = sorted({str(r["run_label"]) for r in rows})
    acc_series = [
        (lab, series_for(lab, "accuracy"), palette[i % len(palette)])
        for i, lab in enumerate(uniq_labels)
        if series_for(lab, "accuracy")
    ]
    cot_series = [
        (lab, series_for(lab, "avg_cot_words"), palette[i % len(palette)])
        for i, lab in enumerate(uniq_labels)
        if series_for(lab, "avg_cot_words")
    ]

    # y-axis ceiling for the CoT chart: the largest average, but never below 1.0.
    cot_max = 1.0
    for r in rows:
        cot_max = max(cot_max, float(r["avg_cot_words"]))

    if acc_series:
        _line_chart_svg(
            acc_series,
            "GSM8K accuracy vs checkpoint step",
            "Accuracy",
            1.0,
            out_root / "accuracy_vs_step.svg",
        )
    if cot_series:
        _line_chart_svg(
            cot_series,
            "Average CoT length (words) vs checkpoint step",
            "Avg CoT words",
            cot_max,
            out_root / "avg_cot_vs_step.svg",
        )
    # The scatter plot is always written; it renders a "(no data)" SVG for empty rows.
    _scatter_accuracy_vs_cot_svg(
        rows,
        out_root / "accuracy_vs_avg_cot_words.svg",
        "GSM8K accuracy vs average CoT length (words)",
    )

    print(f"Saved: {summary_json}")
    print(f"Saved: {summary_csv}")
    if acc_series:
        print(f"Saved: {out_root / 'accuracy_vs_step.svg'}")
    if cot_series:
        print(f"Saved: {out_root / 'avg_cot_vs_step.svg'}")
    print(f"Saved: {out_root / 'accuracy_vs_avg_cot_words.svg'}")
    print(f"Rollouts: {out_root / 'rollouts'}/<run_label>/")
    print(f"Full outputs: {out_root / 'full_outputs'}/<run_label>/")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()