#!/usr/bin/env python
"""Generate the headline "training improves the model" plots.

Produces:
  08_training_progression.png
      LEFT:  smoothed reward over training step (all runs) with 1.7B base
             reference line so judges see the policy gradient working.
      RIGHT: same-model eval before/after pairs with delta labels.
  09_training_diagnostics.png
      LEFT:  reward std over training step (convergence signal).
      RIGHT: mean completion length over step (behaviour shift).

Inputs:
  --log-history LABEL=PATH   log_history.json per run (repeat)
  --summary PATH             plots/runs_summary.json
  --out-dir PATH             plots/ (default)
"""

from __future__ import annotations

import argparse
import json
import statistics
from pathlib import Path
from typing import Any

# Fixed colour per known run label so a run looks the same in every figure.
_LABEL_COLORS: dict[str, str] = {
    "0.6B base": "#ffb74d",
    "Probe (0.6B, β=0)": "#1f77b4",
    "1.7B base": "#66bb6a",
    "Drift (1.7B, β=0)": "#e53935",
    "Anchor (1.7B, β=0.2)": "#2e7d32",
    "Restrain (1.7B, β=1.0)": "#0d47a1",
    "Champion (1.7B, β=0.3)": "#ff6f00",
    "4B base": "#5e35b1",
    "4B-instruct": "#00838f",
}

# Colour cycle used for labels that are not listed in _LABEL_COLORS.
_FALLBACK = [
    "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
    "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf",
]


def _color(label: str, i: int = 0) -> str:
    """Return the colour for *label*, cycling through _FALLBACK by *i* if unknown."""
    return _LABEL_COLORS.get(label, _FALLBACK[i % len(_FALLBACK)])


def _rolling(vals: list[float], w: int) -> list[float]:
    """Return the trailing rolling mean of *vals* over a window of *w* items.

    Early positions average over however many values are available so far.
    A non-positive *w* yields an empty window at every position, which is
    reported as 0.0.
    """
    out: list[float] = []
    for end in range(1, len(vals) + 1):
        chunk = vals[max(0, end - w):end]
        out.append(statistics.mean(chunk) if chunk else 0.0)
    return out


def _first_present(
    row: dict[str, Any],
    keys: tuple[str, ...],
    default: float | None = None,
) -> float | None:
    """Return the first non-None value among *keys* in *row* as a float.

    Uses an explicit ``is not None`` test rather than truthiness so that a
    legitimate value of 0.0 is not mistaken for a missing one.
    """
    for key in keys:
        value = row.get(key)
        if value is not None:
            return float(value)
    return default


def _extract_series(hist: list[dict]) -> dict[str, list]:
    """Pull step, reward, reward_std, completion length, and kl from log_history.

    Rows without a ``step`` or without any reward value are skipped, so all
    returned lists have the same length and line up index by index.  Missing
    (or None) std / completion-length / kl values become 0.0.
    """
    steps: list[int] = []
    rewards: list[float] = []
    reward_stds: list[float] = []
    comp_lens: list[float] = []
    kls: list[float] = []

    for row in hist:
        step = row.get("step")
        if step is None:
            continue
        reward = _first_present(row, ("reward", "rewards/reward_func/mean"))
        if reward is None:
            continue
        steps.append(int(step))
        rewards.append(reward)
        reward_stds.append(
            _first_present(row, ("reward_std", "rewards/reward_func/std"), 0.0)
        )
        comp_lens.append(_first_present(row, ("completions/mean_length",), 0.0))
        kls.append(_first_present(row, ("kl",), 0.0))

    return {
        "steps": steps,
        "rewards": rewards,
        "reward_stds": reward_stds,
        "comp_lens": comp_lens,
        "kls": kls,
    }


# (base label, trained label) pairs compared in the before/after bar panel.
_BEFORE_AFTER_PAIRS: list[tuple[str, str]] = [
    ("0.6B base", "Probe (0.6B, β=0)"),
    ("1.7B base", "Drift (1.7B, β=0)"),
    ("1.7B base", "Anchor (1.7B, β=0.2)"),
    ("1.7B base", "Restrain (1.7B, β=1.0)"),
    ("1.7B base", "Champion (1.7B, β=0.3)"),
]


def _draw_reward_panel(
    ax: Any,
    runs: dict[str, dict],
    summary_rows: list[dict],
    window: int = 30,
) -> None:
    """Draw raw and rolling-mean reward per run plus the 1.7B base reference line."""
    for i, (label, data) in enumerate(runs.items()):
        series = data["series"]
        if not series["steps"]:
            continue
        raw = series["rewards"]
        smooth = _rolling(raw, window)
        col = _color(label, i)
        # Faint raw trace underneath, bold smoothed trace on top.
        ax.plot(series["steps"], raw, color=col, alpha=0.18, lw=0.8)
        ax.plot(series["steps"], smooth, color=col, lw=2.5,
                label=f"{label} (rolling-{window})")
        if smooth:
            ax.annotate(f"{smooth[-1]:.4f}",
                        xy=(series["steps"][-1], smooth[-1]),
                        fontsize=8, fontweight="bold", color=col,
                        xytext=(5, 5), textcoords="offset points")

    # First summary row labelled "1.7B base" supplies the reference level;
    # a missing/None/zero score means no reference line is drawn.
    base_17b_avg = 0.0
    for row in summary_rows:
        if row.get("label") == "1.7B base":
            base_17b_avg = float(row.get("avg_score") or 0.0)
            break
    if base_17b_avg > 0:
        ax.axhline(base_17b_avg, color="#66bb6a", ls="--", lw=1.5, alpha=0.7)
        ax.text(5, base_17b_avg + 0.002,
                f"1.7B base eval avg = {base_17b_avg:.3f}",
                fontsize=8, color="#66bb6a", fontstyle="italic")

    ax.set_xlabel("Training step", fontsize=11)
    ax.set_ylabel("Mean rubric reward", fontsize=11)
    ax.set_title("Reward climbs over training\n(policy gradient is working)",
                 fontsize=12)
    # Only build a legend when at least one run was actually plotted.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=8, loc="upper left", framealpha=0.9)
    ax.grid(alpha=0.3)
    ax.set_ylim(bottom=-0.005)


def _draw_before_after_panel(ax: Any, summary_rows: list[dict], mpatches: Any) -> None:
    """Draw grouped base-vs-trained eval bars with the delta above each pair."""
    # Ignore incomplete rows instead of raising KeyError on them.
    by_label = {
        row["label"]: row
        for row in summary_rows
        if row.get("label") is not None and row.get("avg_score") is not None
    }
    pairs = [(b, t) for b, t in _BEFORE_AFTER_PAIRS
             if b in by_label and t in by_label]
    x_pos = list(range(len(pairs)))
    bar_w = 0.35

    for idx, (base_lbl, trained_lbl) in enumerate(pairs):
        base_score = by_label[base_lbl]["avg_score"]
        trained_score = by_label[trained_lbl]["avg_score"]
        delta = trained_score - base_score
        improved = delta >= 0

        ax.bar(idx - bar_w / 2, base_score, bar_w, color="#bdbdbd",
               edgecolor="white", linewidth=0.5)
        ax.bar(idx + bar_w / 2, trained_score, bar_w,
               color=_color(trained_lbl, idx), edgecolor="white", linewidth=0.5)

        top = max(base_score, trained_score)
        sign = "+" if improved else ""
        ax.text(idx, top + 0.008, f"{sign}{delta:.3f}", ha="center",
                va="bottom", fontsize=9, fontweight="bold",
                color="#2e7d32" if improved else "#c62828")
        ax.text(idx - bar_w / 2, base_score + 0.002, f"{base_score:.3f}",
                ha="center", va="bottom", fontsize=7, color="#616161")
        ax.text(idx + bar_w / 2, trained_score + 0.002, f"{trained_score:.3f}",
                ha="center", va="bottom", fontsize=7, color="#212121")

    # Tick label: model size on the first line, the trained run's
    # parenthesised settings on the second.
    pair_labels = []
    for b, t in pairs:
        short_t = t.split("(")[-1].rstrip(")")
        size = b.split(" ")[0]
        pair_labels.append(f"{size}\n{short_t}")

    ax.set_xticks(x_pos)
    ax.set_xticklabels(pair_labels, fontsize=9)
    ax.set_ylabel("Avg eval score (n=50)", fontsize=11)
    ax.set_title("Eval score: base (grey) vs trained (color)\n"
                 "Delta labeled above each pair", fontsize=12)
    ax.grid(alpha=0.3, axis="y")

    vals = ([by_label[b]["avg_score"] for b, _ in pairs]
            + [by_label[t]["avg_score"] for _, t in pairs])
    top_val = max(vals) if vals else 0.1
    # Headroom above the tallest bar for the score and delta labels.
    ax.set_ylim(0, top_val * 1.4 + 0.02)

    grey_patch = mpatches.Patch(color="#bdbdbd", label="Base (untrained)")
    trained_patch = mpatches.Patch(color="#1f77b4", label="After GRPO")
    ax.legend(handles=[grey_patch, trained_patch], fontsize=8, loc="upper right")


def plot_08(runs: dict[str, dict], summary_rows: list[dict], out_path: Path) -> None:
    """Write the training-progression figure to *out_path*.

    Args:
        runs: label -> {"series": ...} where the series is the dict returned
            by ``_extract_series``.
        summary_rows: eval summary rows; rows are read via their ``label``
            and ``avg_score`` keys. May be empty, in which case the right
            panel has no bars.
        out_path: destination image path.
    """
    # Imported lazily so the module can be imported without matplotlib.
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches

    fig, (ax_rew, ax_ba) = plt.subplots(
        1, 2, figsize=(16, 5.5), gridspec_kw={"width_ratios": [1.2, 1]}
    )

    # --- LEFT: Smoothed reward over step ---
    _draw_reward_panel(ax_rew, runs, summary_rows, window=30)

    # --- RIGHT: Before/after eval bars ---
    _draw_before_after_panel(ax_ba, summary_rows, mpatches)

    fig.suptitle("ClarifyRL — Training progression and evaluation improvement",
                 fontsize=14, fontweight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.94])
    fig.savefig(out_path, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print(f"[ok] {out_path}")


def plot_09(runs: dict[str, dict], out_path: Path) -> None:
    """Write the training-diagnostics figure to *out_path*.

    The left panel shows the rolling mean of reward std per run, the right
    panel the rolling mean of completion length per run.

    Args:
        runs: label -> {"series": ...} where the series is the dict returned
            by ``_extract_series``.
        out_path: destination image path.
    """
    # Imported lazily so the module can be imported without matplotlib.
    import matplotlib.pyplot as plt

    fig, (ax_std, ax_len) = plt.subplots(1, 2, figsize=(14, 5))
    window = 20

    for i, (label, data) in enumerate(runs.items()):
        series = data["series"]
        if not series["steps"]:
            continue
        col = _color(label, i)
        smooth_std = _rolling(series["reward_stds"], window)
        ax_std.plot(series["steps"], smooth_std, color=col, lw=2,
                    label=f"{label} (rolling-{window})")
        smooth_len = _rolling(series["comp_lens"], window)
        ax_len.plot(series["steps"], smooth_len, color=col, lw=2, label=label)

    ax_std.set_xlabel("Training step")
    ax_std.set_ylabel("Reward std (within batch)")
    ax_std.set_title("Reward variance over training\n(shrinking = policy converging)")
    ax_std.grid(alpha=0.3)

    ax_len.set_xlabel("Training step")
    ax_len.set_ylabel("Mean completion length (tokens)")
    ax_len.set_title("Completion length over training\n(tracks output verbosity shift)")
    ax_len.grid(alpha=0.3)

    # Only build legends when at least one run was actually plotted.
    for ax in (ax_std, ax_len):
        if ax.get_legend_handles_labels()[0]:
            ax.legend(fontsize=8)

    fig.suptitle("ClarifyRL — Training diagnostics", fontsize=13, fontweight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.94])
    fig.savefig(out_path, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print(f"[ok] {out_path}")


def main() -> None:
    """Parse CLI arguments, load the run logs and summary, and write both figures."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--log-history", action="append", default=[],
                        help="LABEL=PATH (can repeat)")
    parser.add_argument("--summary", default="plots/runs_summary.json",
                        help="Path to runs_summary.json")
    parser.add_argument("--out-dir", default="plots")
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    runs: dict[str, dict] = {}
    for spec in args.log_history:
        # Split on the LAST "=" because labels themselves may contain "="
        # (e.g. "Drift (1.7B, β=0)").
        label, sep, path = spec.rpartition("=")
        label = label.strip()
        path = path.strip()
        if not sep or not label or not path:
            # An empty path would resolve to the current directory, which
            # exists but cannot be read as a file, so reject it up front.
            print(f"[skip] {spec!r}: expected LABEL=PATH")
            continue
        log_path = Path(path)
        if not log_path.is_file():
            print(f"[skip] {label}: {path} not found")
            continue
        hist = json.loads(log_path.read_text(encoding="utf-8"))
        runs[label] = {"series": _extract_series(hist)}

    summary_rows: list[dict] = []
    summary_path = Path(args.summary)
    if summary_path.is_file():
        summary_rows = json.loads(
            summary_path.read_text(encoding="utf-8")
        ).get("rows", [])
    else:
        print(f"[warn] {args.summary} not found — before/after panel will be empty")

    if runs:
        plot_08(runs, summary_rows, out_dir / "08_training_progression.png")
        plot_09(runs, out_dir / "09_training_diagnostics.png")
    else:
        print("[skip] no log_history files provided")


if __name__ == "__main__":
    main()