Spaces:
Running
Running
| #!/usr/bin/env python | |
| """Generate the headline "training improves the model" plots. | |
| Produces: | |
| 08_training_progression.png | |
| LEFT: smoothed reward over training step (all runs) with 1.7B base | |
| reference line so judges see the policy gradient working. | |
| RIGHT: same-model eval before/after pairs with delta arrows. | |
| 09_training_diagnostics.png | |
| LEFT: reward std over training step (convergence signal). | |
| RIGHT: mean completion length over step (behaviour shift). | |
| Inputs: | |
| --log-history LABEL=PATH log_history.json per run (repeat) | |
| --summary PATH plots/runs_summary.json | |
| --out-dir PATH plots/ (default) | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import statistics | |
| from pathlib import Path | |
| from typing import Any | |
| _LABEL_COLORS: dict[str, str] = { | |
| "0.6B base": "#ffb74d", | |
| "Probe (0.6B, β=0)": "#1f77b4", | |
| "1.7B base": "#66bb6a", | |
| "Drift (1.7B, β=0)": "#e53935", | |
| "Anchor (1.7B, β=0.2)": "#2e7d32", | |
| "Restrain (1.7B, β=1.0)": "#0d47a1", | |
| "Champion (1.7B, β=0.3)": "#ff6f00", | |
| "4B base": "#5e35b1", | |
| "4B-instruct": "#00838f", | |
| } | |
| _FALLBACK = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", | |
| "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"] | |
def _color(label: str, i: int = 0) -> str:
    """Return the plot colour for *label*.

    A label listed in ``_LABEL_COLORS`` gets its fixed colour. Any other label
    gets the ``_FALLBACK`` entry at position ``i``, wrapping around once ``i``
    runs past the end of that list.
    """
    positional = _FALLBACK[i % len(_FALLBACK)]
    return _LABEL_COLORS.get(label, positional)
| def _rolling(vals: list[float], w: int) -> list[float]: | |
| out: list[float] = [] | |
| for i in range(len(vals)): | |
| chunk = vals[max(0, i - w + 1):i + 1] | |
| out.append(statistics.mean(chunk) if chunk else 0.0) | |
| return out | |
| def _extract_series(hist: list[dict]) -> dict[str, list]: | |
| """Pull step, reward, reward_std, completion length, and kl from log_history.""" | |
| steps, rewards, reward_stds, comp_lens, kls = [], [], [], [], [] | |
| for row in hist: | |
| step = row.get("step") | |
| if step is None: | |
| continue | |
| rew = row.get("reward") or row.get("rewards/reward_func/mean") | |
| if rew is None: | |
| continue | |
| steps.append(int(step)) | |
| rewards.append(float(rew)) | |
| reward_stds.append(float(row.get("reward_std") or row.get("rewards/reward_func/std") or 0.0)) | |
| comp_lens.append(float(row.get("completions/mean_length", 0.0))) | |
| kls.append(float(row.get("kl", 0.0))) | |
| return {"steps": steps, "rewards": rewards, "reward_stds": reward_stds, | |
| "comp_lens": comp_lens, "kls": kls} | |
| _BEFORE_AFTER_PAIRS: list[tuple[str, str]] = [ | |
| ("0.6B base", "Probe (0.6B, β=0)"), | |
| ("1.7B base", "Drift (1.7B, β=0)"), | |
| ("1.7B base", "Anchor (1.7B, β=0.2)"), | |
| ("1.7B base", "Restrain (1.7B, β=1.0)"), | |
| ("1.7B base", "Champion (1.7B, β=0.3)"), | |
| ] | |
def _plot_08_reward_panel(ax: Any, runs: dict[str, dict], summary_rows: list[dict],
                          window: int = 30) -> None:
    """Draw the left panel of plot 08: reward per step for every run.

    Each run gets a faint raw trace, a bold rolling mean over *window* steps,
    and an annotation with its last smoothed value. If *summary_rows* holds a
    positive ``avg_score`` for the "1.7B base" label, a dashed horizontal
    reference line is drawn at that score.
    """
    for i, (label, data) in enumerate(runs.items()):
        series = data["series"]
        steps = series["steps"]
        if not steps:
            continue
        raw = series["rewards"]
        smooth = _rolling(raw, window)
        col = _color(label, i)
        ax.plot(steps, raw, color=col, alpha=0.18, lw=0.8)
        ax.plot(steps, smooth, color=col, lw=2.5,
                label=f"{label} (rolling-{window})")
        if smooth:
            # Tag the end of the curve with its final smoothed value.
            ax.annotate(f"{smooth[-1]:.4f}",
                        xy=(steps[-1], smooth[-1]),
                        fontsize=8, fontweight="bold", color=col,
                        xytext=(5, 5), textcoords="offset points")

    # 0.0 doubles as "no usable score": the reference line is then omitted.
    base_17b_avg = 0.0
    for row in summary_rows:
        if row.get("label") == "1.7B base":
            base_17b_avg = row.get("avg_score") or 0.0
            break
    if base_17b_avg > 0:
        ax.axhline(base_17b_avg, color="#66bb6a", ls="--", lw=1.5, alpha=0.7)
        ax.text(5, base_17b_avg + 0.002, f"1.7B base eval avg = {base_17b_avg:.3f}",
                fontsize=8, color="#66bb6a", fontstyle="italic")

    ax.set_xlabel("Training step", fontsize=11)
    ax.set_ylabel("Mean rubric reward", fontsize=11)
    ax.set_title("Reward climbs over training\n(policy gradient is working)", fontsize=12)
    ax.legend(fontsize=8, loc="upper left", framealpha=0.9)
    ax.grid(alpha=0.3)
    ax.set_ylim(bottom=-0.005)


def _plot_08_before_after_panel(ax: Any, summary_rows: list[dict]) -> None:
    """Draw the right panel of plot 08: base vs trained eval score per pair.

    Only pairs from ``_BEFORE_AFTER_PAIRS`` whose two labels both have a score
    in *summary_rows* are drawn. Rows lacking ``label`` or ``avg_score`` are
    ignored instead of raising ``KeyError``, matching the tolerant ``.get``
    lookups used for the reference line in the reward panel.
    """
    import matplotlib.patches as mpatches

    # When a label occurs more than once, the last usable row wins.
    scores = {
        row["label"]: row["avg_score"]
        for row in summary_rows
        if row.get("label") is not None and row.get("avg_score") is not None
    }
    pairs = [(b, t) for b, t in _BEFORE_AFTER_PAIRS
             if b in scores and t in scores]
    bar_w = 0.35

    for idx, (base_lbl, trained_lbl) in enumerate(pairs):
        base_score = scores[base_lbl]
        trained_score = scores[trained_lbl]
        delta = trained_score - base_score

        ax.bar(idx - bar_w / 2, base_score, bar_w,
               color="#bdbdbd", edgecolor="white", linewidth=0.5)
        ax.bar(idx + bar_w / 2, trained_score, bar_w,
               color=_color(trained_lbl, idx), edgecolor="white", linewidth=0.5)

        # Signed delta centred above the taller bar of the pair.
        top = max(base_score, trained_score)
        sign = "+" if delta >= 0 else ""
        ax.text(idx, top + 0.008, f"{sign}{delta:.3f}",
                ha="center", va="bottom", fontsize=9, fontweight="bold",
                color="#2e7d32" if delta >= 0 else "#c62828")
        ax.text(idx - bar_w / 2, base_score + 0.002, f"{base_score:.3f}",
                ha="center", va="bottom", fontsize=7, color="#616161")
        ax.text(idx + bar_w / 2, trained_score + 0.002, f"{trained_score:.3f}",
                ha="center", va="bottom", fontsize=7, color="#212121")

    # Tick label: first word of the base label, then the text inside the
    # trained label's parentheses.
    pair_labels = []
    for b, t in pairs:
        short_t = t.split("(")[-1].rstrip(")")
        size = b.split(" ")[0]
        pair_labels.append(f"{size}\n{short_t}")

    ax.set_xticks(list(range(len(pairs))))
    ax.set_xticklabels(pair_labels, fontsize=9)
    ax.set_ylabel("Avg eval score (n=50)", fontsize=11)
    ax.set_title("Eval score: base (grey) vs trained (color)\nDelta labeled above each pair",
                 fontsize=12)
    ax.grid(alpha=0.3, axis="y")

    # Headroom above the tallest bar for the delta labels.
    vals = [scores[b] for b, _ in pairs] + [scores[t] for _, t in pairs]
    top_val = max(vals) if vals else 0.1
    ax.set_ylim(0, top_val * 1.4 + 0.02)

    grey_patch = mpatches.Patch(color="#bdbdbd", label="Base (untrained)")
    trained_patch = mpatches.Patch(color="#1f77b4", label="After GRPO")
    ax.legend(handles=[grey_patch, trained_patch], fontsize=8, loc="upper right")


def plot_08(runs: dict[str, dict], summary_rows: list[dict], out_path: Path) -> None:
    """Render the training-progression figure and save it to *out_path*.

    Args:
        runs: Maps a run label to ``{"series": ...}`` where the series dict
            provides at least ``steps`` and ``rewards`` lists.
        summary_rows: Eval summary rows; ``label`` and ``avg_score`` are the
            keys read here. May be empty, in which case the right panel has
            no bars.
        out_path: Destination image file.
    """
    import matplotlib.pyplot as plt

    fig, (ax_rew, ax_ba) = plt.subplots(1, 2, figsize=(16, 5.5),
                                        gridspec_kw={"width_ratios": [1.2, 1]})
    try:
        _plot_08_reward_panel(ax_rew, runs, summary_rows)
        _plot_08_before_after_panel(ax_ba, summary_rows)
        fig.suptitle("ClarifyRL — Training progression and evaluation improvement",
                     fontsize=14, fontweight="bold")
        fig.tight_layout(rect=[0, 0, 1, 0.94])
        fig.savefig(out_path, dpi=160, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print(f"[ok] {out_path}")
def plot_09(runs: dict[str, dict], out_path: Path) -> None:
    """Render the training-diagnostics figure and save it to *out_path*.

    Left panel: rolling mean of ``reward_stds`` per run. Right panel: rolling
    mean of ``comp_lens`` per run. Both use a 20-step window; runs whose
    series has no steps are skipped.
    """
    import matplotlib.pyplot as plt

    window = 20
    fig, (ax_std, ax_len) = plt.subplots(1, 2, figsize=(14, 5))

    for idx, (label, data) in enumerate(runs.items()):
        series = data["series"]
        steps = series["steps"]
        if not steps:
            continue
        colour = _color(label, idx)
        ax_std.plot(steps, _rolling(series["reward_stds"], window),
                    color=colour, lw=2, label=f"{label} (rolling-{window})")
        ax_len.plot(steps, _rolling(series["comp_lens"], window),
                    color=colour, lw=2, label=label)

    # Both panels share the same decoration; only y-label and title differ.
    decorations = (
        (ax_std,
         "Reward std (within batch)",
         "Reward variance over training\n(shrinking = policy converging)"),
        (ax_len,
         "Mean completion length (tokens)",
         "Completion length over training\n(tracks output verbosity shift)"),
    )
    for ax, y_label, title in decorations:
        ax.set_xlabel("Training step")
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend(fontsize=8)
        ax.grid(alpha=0.3)

    fig.suptitle("ClarifyRL — Training diagnostics", fontsize=13, fontweight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.94])
    fig.savefig(out_path, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print(f"[ok] {out_path}")
def _split_log_spec(spec: str) -> tuple[str, str]:
    """Split one ``--log-history`` value into ``(label, path)``.

    The split is on the LAST ``=`` because run labels may contain ``=``
    themselves (e.g. ``β=0``). When there is no label part, the path is used
    as the label too.
    """
    label, _, path = spec.rpartition("=")
    label = label.strip()
    path = path.strip()
    return (label or path), path


def main() -> None:
    """Parse CLI arguments, load the run logs and summary, and write both plots.

    Every ``--log-history`` file that exists is parsed as JSON and reduced to
    its plotted series; missing files are reported and skipped. The summary
    file is optional. Nothing is plotted when no log file could be loaded.
    """
    parser = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--log-history", action="append", default=[],
                        help="LABEL=PATH (can repeat)")
    parser.add_argument("--summary", default="plots/runs_summary.json",
                        help="Path to runs_summary.json")
    parser.add_argument("--out-dir", default="plots")
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    runs: dict[str, dict] = {}
    for spec in args.log_history:
        label, path = _split_log_spec(spec)
        log_path = Path(path)
        # is_file() rather than exists(): an empty path resolves to the
        # current directory, which exists but cannot be read as a log.
        if not log_path.is_file():
            print(f"[skip] {label}: {path} not found")
            continue
        hist = json.loads(log_path.read_text(encoding="utf-8"))
        runs[label] = {"series": _extract_series(hist)}

    summary_rows: list[dict] = []
    summary_path = Path(args.summary)
    if summary_path.is_file():
        # Read as UTF-8 explicitly: run labels contain non-ASCII characters
        # and must not depend on the platform's default encoding.
        summary_rows = json.loads(summary_path.read_text(encoding="utf-8")).get("rows", [])
    else:
        print(f"[warn] {args.summary} not found — before/after panel will be empty")

    if runs:
        plot_08(runs, summary_rows, out_dir / "08_training_progression.png")
        plot_09(runs, out_dir / "09_training_diagnostics.png")
    else:
        print("[skip] no log_history files provided")


if __name__ == "__main__":
    main()