#!/usr/bin/env python
"""Generate the 5 submission plots from training logs + eval JSONs.

Inputs (all optional — script emits whatever it can):
    --log-history   <output_dir>/log_history.json (from train_grpo.py)
    --eval-policy   outputs/eval_policy.json
    --eval-base     outputs/eval_<model>_base.json
    --eval-trained  outputs/eval_<model>_trained.json
    --out-dir       plots/ (default)

Output PNGs in --out-dir:
    01_reward_loss_curves.png   reward + loss vs training step
    02_per_family_bars.png      avg score per task family (policy / base / trained)
    03_component_breakdown.png  per-rubric-component bars (FieldMatch / InfoGain / ...)
    04_before_after.png         grouped bars of aggregate metrics (policy / base / trained)
    05_question_efficiency.png  hist of questions asked per scenario (policy / base / trained)

If multiple sources are present, each plot overlays the comparable series.
If a source is missing, the plot quietly omits that series and continues.

Usage:
    python scripts/make_plots.py \\
        --log-history clarify-rl-grpo-qwen3-0.6b/log_history.json \\
        --eval-policy outputs/eval_policy.json \\
        --eval-base outputs/eval_qwen3-0.6b_base.json \\
        --eval-trained outputs/eval_qwen3-0.6b_trained.json
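
Multi-run overlay (labels and paths below are illustrative; any LABEL=PATH
pairs work, including labels that contain '='):

    python scripts/make_plots.py \\
        --log-history "Drift (1.7B, β=0)=runs/drift/log_history.json" \\
        --log-history "Anchor (1.7B, β=0.2)=runs/anchor/log_history.json" \\
        --eval "1.7B base=outputs/eval_1.7b_base.json" \\
        --eval "Anchor (1.7B, β=0.2)=outputs/eval_anchor.json"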
| """ | |
from __future__ import annotations

import argparse
import json
import statistics
from collections import defaultdict
from pathlib import Path
from typing import Any

# Matplotlib is lazy-imported inside each plot function: it keeps `--help`
# snappy and avoids a hard matplotlib dependency when every plot is skipped.

def _load_json(path: str | None) -> Any:
    if not path:
        return None
    p = Path(path)
    if not p.exists():
        print(f"[skip] {path} not found")
        return None
    return json.loads(p.read_text())


def _safe_mean(xs: list[float]) -> float:
    return statistics.mean(xs) if xs else 0.0

# Stable label -> color map. Same label = same color in every plot, no
# collisions across the 8 series we render in the submission deck. Values
# are picked from matplotlib's tab10 plus Material-palette overrides to keep
# the headline trios (Run 1 / Run 2 / Run 4) clearly distinct.
_LABEL_COLORS: dict[str, str] = {
    "policy (deterministic)": "#9e9e9e",  # neutral grey
    "policy (baseline)": "#9e9e9e",
    "0.6B base": "#ffb74d",               # warm orange
    "Probe (0.6B, β=0)": "#1f77b4",       # strong blue — exploratory
    "1.7B base": "#66bb6a",               # mid green
    "Drift (1.7B, β=0)": "#e53935",       # red — collapsed
    "Anchor (1.7B, β=0.2)": "#2e7d32",    # deep green — KL anchor recovers
    "Restrain (1.7B, β=1.0)": "#0d47a1",  # dark blue — too anchored
    "Champion (1.7B, β=0.3)": "#ff6f00",  # orange — winner, beats base
    "4B base": "#5e35b1",                 # purple — ceiling marker
    "4B-instruct": "#00838f",             # teal
    "4B GRPO (Run 3)": "#ff6f00",         # amber (same hex as Champion)
    "untrained": "#ffb74d",
    "trained": "#1f77b4",
}

# Fallback palette when a label isn't pre-mapped: matplotlib's tab10 colors,
# already chosen to contrast with the named ones above.
_FALLBACK_COLORS: list[str] = [
    "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
    "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf",
]

def _color_for(label: str, fallback_index: int = 0) -> str:
    """Return a stable, distinct color for `label`."""
    if label in _LABEL_COLORS:
        return _LABEL_COLORS[label]
    return _FALLBACK_COLORS[fallback_index % len(_FALLBACK_COLORS)]

def _autoscale_top(values: list[float], floor: float = 0.10, headroom: float = 1.20) -> float:
    """Pick a y-axis top so the data fills the visible range.

    `floor` is the minimum top we want even when values are tiny (so a single
    very low bar doesn't end up at 100% of the axis and look misleading).
    """
    if not values:
        return 1.0
    top = max(values) * headroom
    return max(top, floor)
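
# Illustrative values: _autoscale_top([0.02, 0.05]) -> 0.10 (floor wins);
# _autoscale_top([0.5]) -> 0.6 (max * headroom).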

# ---------------------------------------------------------------------------
# Plot 1 — Reward + Loss curves
# ---------------------------------------------------------------------------
def plot_reward_loss_curves(log_history: list[dict], out_path: Path) -> None:
    if not log_history:
        print("[skip] reward/loss curves — no log_history")
        return
    import matplotlib.pyplot as plt

    steps_loss: list[int] = []
    losses: list[float] = []
    steps_rew: list[int] = []
    rewards: list[float] = []
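    # Assumed row shape (HF Trainer-style log dict), e.g.:
    # {"step": 10, "loss": 0.42, "reward": 0.13, "kl": 0.007}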
    for row in log_history:
        step = row.get("step")
        if step is None:
            continue
        if "loss" in row and isinstance(row["loss"], (int, float)):
            steps_loss.append(step)
            losses.append(float(row["loss"]))
        if "reward" in row and isinstance(row["reward"], (int, float)):
            steps_rew.append(step)
            rewards.append(float(row["reward"]))
        elif "rewards/reward_func/mean" in row:
            steps_rew.append(step)
            rewards.append(float(row["rewards/reward_func/mean"]))
    if not steps_loss and not steps_rew:
        print("[skip] reward/loss curves — log_history has no loss or reward")
        return

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    if steps_loss:
        axes[0].plot(steps_loss, losses, color="tab:red", lw=1.5)
    axes[0].set_title("Training loss (advantage-weighted)")
    axes[0].set_xlabel("Step")
    axes[0].set_ylabel("Loss")
    axes[0].grid(alpha=0.3)
    if steps_rew:
        axes[1].plot(steps_rew, rewards, color="tab:blue", lw=1.5)
        # Rolling mean for trend
        window = max(1, len(rewards) // 20)
        if window >= 3 and len(rewards) >= window:
            roll = [
                _safe_mean(rewards[max(0, i - window): i + 1])
                for i in range(len(rewards))
            ]
            axes[1].plot(steps_rew, roll, color="tab:blue", lw=2.5, alpha=0.4, label=f"rolling-mean ({window})")
            axes[1].legend()
    axes[1].set_title("Reward per training step")
    axes[1].set_xlabel("Step")
    axes[1].set_ylabel("Reward (rubric score)")
    axes[1].grid(alpha=0.3)
    fig.suptitle("ClarifyRL GRPO training")
    fig.tight_layout()
    fig.savefig(out_path, dpi=160)
    plt.close(fig)
    print(f"[ok] {out_path}")

def plot_reward_loss_curves_multi(runs: dict[str, list[dict]], out_path: Path) -> None:
    """Overlay multiple training runs: LEFT = KL divergence, RIGHT = reward.

    The loss panel is replaced by KL divergence because GRPO loss is
    intrinsically noisy and hard to interpret, while KL directly shows
    the anchor working (Run 4 stays ~0.005-0.010 throughout).
    """
    runs = {label: hist for label, hist in runs.items() if hist}
    if not runs:
        print("[skip] reward/loss curves — no runs")
        return
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    has_kl = False
    window = 30
    for i, (label, hist) in enumerate(runs.items()):
        steps_kl: list[int] = []
        kls: list[float] = []
        steps_rew: list[int] = []
        rewards: list[float] = []
        for row in hist:
            step = row.get("step")
            if step is None:
                continue
            kl = row.get("kl")
            if isinstance(kl, (int, float)):
                steps_kl.append(step)
                kls.append(float(kl))
            if "reward" in row and isinstance(row["reward"], (int, float)):
                steps_rew.append(step)
                rewards.append(float(row["reward"]))
            elif "rewards/reward_func/mean" in row:
                steps_rew.append(step)
                rewards.append(float(row["rewards/reward_func/mean"]))
        color = _color_for(label, i)
        if steps_kl and any(v > 0 for v in kls):
            has_kl = True
            roll_kl = [_safe_mean(kls[max(0, j - window): j + 1]) for j in range(len(kls))]
            axes[0].plot(steps_kl, kls, color=color, lw=0.8, alpha=0.25)
            axes[0].plot(steps_kl, roll_kl, color=color, lw=2.5, label=label)
        if steps_rew:
            axes[1].plot(steps_rew, rewards, color=color, lw=0.8, alpha=0.2)
            roll = [_safe_mean(rewards[max(0, j - window): j + 1]) for j in range(len(rewards))]
            axes[1].plot(steps_rew, roll, color=color, lw=2.5,
                         label=f"{label} (rolling-{window})")
            if roll:
                axes[1].annotate(f"{roll[-1]:.4f}",
                                 xy=(steps_rew[-1], roll[-1]),
                                 fontsize=8, fontweight="bold", color=color,
                                 xytext=(5, 3), textcoords="offset points")
    if has_kl:
        axes[0].set_title("KL divergence from reference policy")
        axes[0].set_xlabel("Step")
        axes[0].set_ylabel("KL (nats)")
        axes[0].legend(fontsize=8)
    else:
        axes[0].set_title("KL divergence (no KL data — runs used beta=0)")
        axes[0].text(0.5, 0.5, "No KL anchor active\n(beta=0 for all runs shown)",
                     transform=axes[0].transAxes, ha="center", va="center",
                     fontsize=12, color="gray")
    axes[0].grid(alpha=0.3)
    axes[1].set_title(f"Reward per training step (rolling-{window})")
    axes[1].set_xlabel("Step")
    axes[1].set_ylabel("Reward (rubric mean)")
    axes[1].legend(fontsize=8)
    axes[1].grid(alpha=0.3)
    axes[1].set_ylim(bottom=-0.005)
    fig.suptitle("ClarifyRL GRPO training curves", fontsize=13, fontweight="bold")
    fig.tight_layout()
    fig.savefig(out_path, dpi=160)
    plt.close(fig)
    print(f"[ok] {out_path}")

# ---------------------------------------------------------------------------
# Plot 2 — Per-family bars
# ---------------------------------------------------------------------------
def plot_per_family_bars(evals: dict[str, dict | None], out_path: Path) -> None:
    series = {label: ev for label, ev in evals.items() if ev}
    if not series:
        print("[skip] per-family — no eval JSONs")
        return
    import matplotlib.pyplot as plt

    families = sorted({
        r.get("family", "?")
        for ev in series.values()
        for r in ev.get("results", [])
    })
    matrix: dict[str, dict[str, float]] = {}
    for label, ev in series.items():
        per_family = defaultdict(list)
        for r in ev.get("results", []):
            per_family[r.get("family", "?")].append(r.get("final_score", 0.0))
        matrix[label] = {fam: _safe_mean(per_family.get(fam, [])) for fam in families}

    n_groups = len(families)
    n_series = len(series)
    width = 0.8 / max(1, n_series)
    x = list(range(n_groups))
    fig, ax = plt.subplots(figsize=(max(8, n_groups * 1.4), 5))
    all_values: list[float] = []
    for i, (label, scores) in enumerate(matrix.items()):
        vals = [scores[fam] for fam in families]
        all_values.extend(vals)
        ax.bar(
            [xi + (i - (n_series - 1) / 2) * width for xi in x],
            vals,
            width=width,
            label=label,
            color=_color_for(label, i),
            edgecolor="white",
            linewidth=0.5,
        )
    ax.set_xticks(x)
    ax.set_xticklabels(families, rotation=15, ha="right")
    ax.set_ylabel("Avg final score (n=50, eval v4)")
    ax.set_title("Avg score per task family — base vs trained, all model sizes")
    ax.set_ylim(0.0, _autoscale_top(all_values, floor=0.40))
    ax.grid(alpha=0.3, axis="y")
    ax.legend(loc="upper right", fontsize=8, framealpha=0.95)
    fig.tight_layout()
    fig.savefig(out_path, dpi=160)
    plt.close(fig)
    print(f"[ok] {out_path}")

# ---------------------------------------------------------------------------
# Plot 3 — Per-component breakdown
# ---------------------------------------------------------------------------
def plot_component_breakdown(evals: dict[str, dict | None], out_path: Path) -> None:
    series = {label: ev for label, ev in evals.items() if ev}
    if not series:
        print("[skip] component breakdown — no eval JSONs")
        return
    import matplotlib.pyplot as plt

    # Real keys in the eval JSON are e.g. `FieldMatchRubric` (with suffix).
    # We display the short name on the axis but look up `<short>Rubric` first.
    component_names = ["FormatCheck", "FieldMatch", "InfoGain", "QuestionEfficiency", "HallucinationCheck"]

    def _bd_get(bd: dict, c: str) -> float | None:
        for key in (f"{c}Rubric", c, c.lower(), c.replace("Check", "")):
            v = bd.get(key)
            if v is not None:
                try:
                    return float(v)
                except (TypeError, ValueError):
                    return None
        return None
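    # Lookup order example: "HallucinationCheck" tries "HallucinationCheckRubric",
    # "HallucinationCheck", "hallucinationcheck", then "Hallucination"; the first
    # key present wins (non-numeric values yield None).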
    # Average ONLY across scenarios where a score breakdown was actually
    # produced (i.e. the rubric ran). Format-failed scenarios with `{}` get
    # filtered out so the bars represent "given the rubric ran, here is each
    # component's contribution" — which is what the README claims.
    matrix: dict[str, dict[str, float]] = {}
    coverage: dict[str, int] = {}
    for label, ev in series.items():
        sums: dict[str, list[float]] = {c: [] for c in component_names}
        n_scored = 0
        for r in ev.get("results", []):
            bd = r.get("score_breakdown") or {}
            if not bd:
                continue
            n_scored += 1
            for c in component_names:
                v = _bd_get(bd, c)
                if v is not None:
                    sums[c].append(v)
        matrix[label] = {c: _safe_mean(sums[c]) for c in component_names}
        coverage[label] = n_scored

    n_comps = len(component_names)
    n_series = len(series)
    width = 0.8 / max(1, n_series)
    x = list(range(n_comps))
    fig, ax = plt.subplots(figsize=(11, 5.5))
    all_values: list[float] = []
    for i, (label, scores) in enumerate(matrix.items()):
        vals = [scores[c] for c in component_names]
        all_values.extend(vals)
        n = coverage.get(label, 0)
        n_total = len(series[label].get("results", []))
        legend_label = f"{label} (n_scored={n}/{n_total})"
        ax.bar(
            [xi + (i - (n_series - 1) / 2) * width for xi in x],
            vals,
            width=width,
            label=legend_label,
            color=_color_for(label, i),
            edgecolor="white",
            linewidth=0.5,
        )
    ax.set_xticks(x)
    ax.set_xticklabels(component_names, rotation=10, ha="right")
    ax.set_ylabel("Avg component score (when rubric ran)")
    ax.set_title("Reward breakdown by rubric component (conditional on score being computed)")
    ax.set_ylim(0.0, _autoscale_top(all_values, floor=1.0))
    ax.grid(alpha=0.3, axis="y")
    ax.legend(loc="upper right", fontsize=7, framealpha=0.95)
    fig.tight_layout()
    fig.savefig(out_path, dpi=160)
    plt.close(fig)
    print(f"[ok] {out_path}")

# ---------------------------------------------------------------------------
# Plot 4 — Before / after summary
# ---------------------------------------------------------------------------
def plot_before_after(evals: dict[str, dict | None], out_path: Path) -> None:
    series = {label: ev for label, ev in evals.items() if ev}
    if not series:
        print("[skip] before/after — no eval JSONs")
        return
    import matplotlib.pyplot as plt

    metrics = {
        "avg_score": "Avg final score",
        "completion_rate": "Completion rate",
    }
    n_series = len(series)
    n_metrics = len(metrics)
    width = 0.8 / max(1, n_series)
    x = list(range(n_metrics))
    fig, ax = plt.subplots(figsize=(11, 5.5))
    all_values: list[float] = []
    for i, (label, ev) in enumerate(series.items()):
        s = ev.get("summary", {}) or {}
        vals = [float(s.get(k, 0.0)) for k in metrics]
        all_values.extend(vals)
        bars = ax.bar(
            [xi + (i - (n_series - 1) / 2) * width for xi in x],
            vals,
            width=width,
            label=label,
            color=_color_for(label, i),
            edgecolor="white",
            linewidth=0.5,
        )
        for b, v in zip(bars, vals):
            if v > 0:
                ax.text(b.get_x() + b.get_width() / 2, v + 0.005,
                        f"{v:.3f}", ha="center", va="bottom", fontsize=7)
    ax.set_xticks(x)
    ax.set_xticklabels(list(metrics.values()), rotation=0, ha="center")
    ax.set_ylim(0.0, _autoscale_top(all_values, floor=0.30))
    ax.set_ylabel("Score / rate (n=50, eval v4)")
    ax.set_title("Aggregate eval metrics — base vs trained, all model sizes")
    ax.grid(alpha=0.3, axis="y")
    ax.legend(loc="upper left", fontsize=8, framealpha=0.95)
    fig.tight_layout()
    fig.savefig(out_path, dpi=160)
    plt.close(fig)
    print(f"[ok] {out_path}")

# ---------------------------------------------------------------------------
# Plot 5 — Question efficiency histogram
# ---------------------------------------------------------------------------
def plot_question_efficiency(evals: dict[str, dict | None], out_path: Path) -> None:
    series = {label: ev for label, ev in evals.items() if ev}
    if not series:
        print("[skip] question hist — no eval JSONs")
        return
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(10, 5.5))
    bins = list(range(0, 8))  # edges 0..7 -> one bar per count 0..6 (7s land in the last bar)
    for i, (label, ev) in enumerate(series.items()):
        qs = [r.get("questions_asked", 0) for r in ev.get("results", [])]
        ax.hist(
            qs,
            bins=bins,
            alpha=0.55,
            label=f"{label} (mean={_safe_mean(qs):.2f})",
            color=_color_for(label, i),
            edgecolor="black",
            align="left",
        )
    ax.set_xticks(list(range(0, 7)))
    ax.set_xlabel("Questions asked per scenario")
    ax.set_ylabel("Count")
    ax.set_title("Question efficiency (lower is better, given same score)")
    ax.legend()
    ax.grid(alpha=0.3, axis="y")
    fig.tight_layout()
    fig.savefig(out_path, dpi=160)
    plt.close(fig)
    print(f"[ok] {out_path}")

# ---------------------------------------------------------------------------
# main
# ---------------------------------------------------------------------------
def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--log-history", action="append", default=[],
                        help="LABEL=PATH log_history.json (can repeat for multi-run)")
    parser.add_argument("--eval-policy", default=None, help="Path to outputs/eval_policy.json")
    parser.add_argument("--eval-base", default=None, help="Path to outputs/eval_<model>_base.json")
    parser.add_argument("--eval-trained", default=None, help="Path to outputs/eval_<model>_trained.json")
    parser.add_argument("--eval", action="append", default=[],
                        help="LABEL=PATH eval JSON (can repeat to add extra series)")
    parser.add_argument("--out-dir", default="plots", help="Output directory for PNGs")
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    evals: dict[str, dict | None] = {}
    if args.eval_policy:
        evals["policy (baseline)"] = _load_json(args.eval_policy)
    if args.eval_base:
        evals["untrained"] = _load_json(args.eval_base)
    if args.eval_trained:
        evals["trained"] = _load_json(args.eval_trained)
    for spec in args.eval:
        if "=" not in spec:
            print(f"[warn] --eval expects LABEL=PATH, got {spec!r}; skipping")
            continue
        # rpartition: labels may legitimately contain '=' (e.g. "Run X, beta=0")
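        # e.g. "Drift (1.7B, β=0)=outputs/eval_drift.json" splits into
        #      label "Drift (1.7B, β=0)" and path "outputs/eval_drift.json"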
        label, _, path = spec.rpartition("=")
        label = label.strip()
        path = path.strip()
        loaded = _load_json(path)
        if loaded is not None:
            evals[label] = loaded

    if args.log_history and len(args.log_history) == 1 and "=" not in args.log_history[0]:
        log_history = _load_json(args.log_history[0]) or []
        plot_reward_loss_curves(log_history, out_dir / "01_reward_loss_curves.png")
    elif args.log_history:
        labelled: dict[str, list] = {}
        for spec in args.log_history:
            # Use rpartition so labels can contain '=' (e.g. "Run X, beta=0").
            # Path is always the trailing token after the LAST '='.
            label, sep, path = spec.rpartition("=")
            if not sep:
                # Fallback: no '=' at all; rpartition leaves the whole spec in
                # `path`, so reuse it as the label too.
                label = path
            data = _load_json(path.strip())
            if data is not None:
                labelled[label.strip()] = data
        plot_reward_loss_curves_multi(labelled, out_dir / "01_reward_loss_curves.png")
    else:
        print("[skip] reward/loss curves — no log_history")
    plot_per_family_bars(evals, out_dir / "02_per_family_bars.png")
    plot_component_breakdown(evals, out_dir / "03_component_breakdown.png")
    plot_before_after(evals, out_dir / "04_before_after.png")
    plot_question_efficiency(evals, out_dir / "05_question_efficiency.png")
    print()
    print(f"All plots written to: {out_dir.resolve()}")


if __name__ == "__main__":
    main()