"""Render the GRPO training-curve PNGs that the README embeds.

Reads ``checkpoints/defender_grpo/<stage>/training_log.jsonl`` files
written by the `_JsonLogger` callback in `train.train_grpo` and produces:

* ``eval/results/training_curves.png`` — reward vs global step,
  one line per curriculum stage.
* ``eval/results/training_kl_loss.png`` — `kl` and `loss` vs step
  (whichever fields the trainer produced) as a sanity proxy.

If no JSONL logs exist (because training hasn't been run yet on this
machine), the script exits with status 2 by default. Pass
``--allow-placeholder`` to instead generate *placeholder* curves from a
deterministic synthetic process, so the README never has a broken image
link before the real GPU run finishes. Placeholder figures are labelled
as such in their titles.
"""
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import os |
| import random |
| import sys |
| from typing import Any, Dict, List |
|
|
# Directory containing this script; its parent is put first on sys.path so
# project packages can be imported when the file is run directly.
_HERE = os.path.split(os.path.abspath(__file__))[0]
sys.path.insert(0, os.path.split(_HERE)[0])


# One row per curriculum stage, in training order: (log subdirectory, colour).
# Both figures draw a given stage in the same colour so they read together.
_STAGE_TABLE = (
    ("stage1_basic", "#1f77b4"),
    ("stage2_multi", "#2ca02c"),
    ("stage3_mixed", "#ff7f0e"),
    ("stage4_adversarial", "#d62728"),
)
STAGE_ORDER = [stage for stage, _colour in _STAGE_TABLE]
STAGE_COLORS = {stage: colour for stage, colour in _STAGE_TABLE}
|
|
|
|
| def _read_stage_logs(grpo_root: str) -> Dict[str, List[Dict[str, Any]]]: |
| """Read training_log.jsonl from each stage subdirectory.""" |
| out: Dict[str, List[Dict[str, Any]]] = {} |
| for stage in STAGE_ORDER: |
| path = os.path.join(grpo_root, stage, "training_log.jsonl") |
| if not os.path.exists(path): |
| continue |
| rows: List[Dict[str, Any]] = [] |
| with open(path, "r", encoding="utf-8") as f: |
| for line in f: |
| line = line.strip() |
| if not line: |
| continue |
| try: |
| rows.append(json.loads(line)) |
| except json.JSONDecodeError: |
| continue |
| if rows: |
| out[stage] = rows |
| return out |
|
|
|
|
def _placeholder_logs() -> Dict[str, List[Dict[str, Any]]]:
    """Build deterministic synthetic logs so the README has something to show.

    For every stage in ``STAGE_ORDER``, 40 records are emitted at steps
    0, 5, ..., 195. The reward follows a noisy saturating exponential,
    ``start + (ceiling - start) * (1 - exp(-3.5 * t))`` with ``t = step / 200``,
    clamped to [-1.5, 1.1]. Later (harder) stages start lower and plateau
    lower. ``kl`` drifts slightly upward and ``loss`` slightly downward.

    A fixed seed (42) makes every run produce identical numbers. The data is
    illustrative only and has no connection to any real training run.
    """
    # stage -> (starting reward, reward plateau)
    curve = {
        "stage1_basic": (-0.4, 0.95),
        "stage2_multi": (-0.6, 0.85),
        "stage3_mixed": (-0.8, 0.70),
        "stage4_adversarial": (-0.9, 0.55),
    }
    total_steps = 200
    rng = random.Random(42)
    logs: Dict[str, List[Dict[str, Any]]] = {}
    for stage in STAGE_ORDER:
        start, ceiling = curve[stage]
        series: List[Dict[str, Any]] = []
        for step in range(0, total_steps, 5):
            frac = step / total_steps
            expected = start + (ceiling - start) * (1 - math.exp(-3.5 * frac))
            # Draw the three noise terms in a fixed order: reward, kl, loss.
            # The seeded stream, and hence the plot, depends on this order.
            reward_noise = rng.gauss(0, 0.07)
            kl_noise = rng.gauss(0, 0.005)
            loss_noise = rng.gauss(0, 0.04)
            series.append(
                {
                    "stage": stage,
                    "step": step,
                    "reward": max(-1.5, min(1.1, expected + reward_noise)),
                    "kl": 0.02 + 0.01 * frac + max(0.0, kl_noise),
                    "loss": 0.7 - 0.3 * frac + loss_noise,
                }
            )
        logs[stage] = series
    return logs
|
|
|
|
| def _key(rows: List[Dict[str, Any]], names: List[str]) -> List[float] | None: |
| """Return values for the first matching key, else None.""" |
| for name in names: |
| if any(name in r for r in rows): |
| return [r.get(name, math.nan) for r in rows] |
| return None |
|
|
|
|
def _plot_curves(stage_logs: Dict[str, List[Dict[str, Any]]], out_path: str, placeholder: bool) -> None:
    """Plot mean reward per stage on one concatenated global-step axis.

    Stages are drawn in ``STAGE_ORDER``. Each stage's own ``step`` values are
    shifted by a running offset, so the curriculum reads left to right. After
    each stage, the offset grows by that stage's last step plus 5. The reward
    series is the first of several known field spellings present in the rows
    (via ``_key``).

    A stage with no reward field is not drawn, because an all-NaN line would
    still get a misleading legend entry. It does still advance the offset, so
    later stages keep their position.

    Args:
        stage_logs: stage name -> list of log records (dicts).
        out_path: PNG file to write.
        placeholder: if True, the title marks the figure as synthetic.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(8, 4.5))
    try:
        cumulative = 0
        for stage in STAGE_ORDER:
            rows = stage_logs.get(stage, [])
            if not rows:
                continue
            rows = sorted(rows, key=lambda r: r.get("step", 0))
            steps = [cumulative + r.get("step", 0) for r in rows]
            rewards = _key(rows, ["reward", "rewards/mean", "train/reward", "reward_mean"])
            if rewards is not None:
                ax.plot(steps, rewards, label=stage, color=STAGE_COLORS[stage], linewidth=1.6)
            # rows is non-empty and sorted by step, so the last row holds the
            # largest step. The +5 leaves a small gap before the next stage.
            cumulative += rows[-1].get("step", 0) + 5

        ax.axhline(0.0, color="#888", linewidth=0.6, linestyle="--")
        ax.set_xlabel("Global step (concatenated across stages)")
        ax.set_ylabel("Mean reward")
        title = "OpenSOC GRPO defender — reward across curriculum stages"
        if placeholder:
            title += " [placeholder — re-run after real training]"
        ax.set_title(title)
        # Only add a legend when something labelled was drawn; otherwise
        # matplotlib warns and renders an empty legend box.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc="lower right", fontsize=9)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        # Release the figure even if rendering/saving fails.
        plt.close(fig)
|
|
|
|
def _plot_aux(stage_logs: Dict[str, List[Dict[str, Any]]], out_path: str, placeholder: bool) -> None:
    """Plot KL divergence (left) and training loss (right) per stage.

    The x-axis is each stage's own ``step`` (not concatenated across stages),
    so stage curves overlap. Each metric comes from the first known field
    spelling present in a stage's rows (via ``_key``). A stage missing a
    metric is simply absent from that panel. A panel with no data at all
    shows a "no data" note instead of an empty legend.

    Args:
        stage_logs: stage name -> list of log records (dicts).
        out_path: PNG file to write.
        placeholder: if True, the figure title is suffixed with
            ``[placeholder]``.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 2, figsize=(10, 3.8))
    kl_ax, loss_ax = axes
    try:
        for stage in STAGE_ORDER:
            rows = stage_logs.get(stage, [])
            if not rows:
                continue
            rows = sorted(rows, key=lambda r: r.get("step", 0))
            steps = [r.get("step", 0) for r in rows]
            color = STAGE_COLORS[stage]
            kl = _key(rows, ["kl", "kl_div", "objective/kl", "train/kl"])
            if kl is not None:
                kl_ax.plot(steps, kl, label=stage, color=color, linewidth=1.4)
            loss = _key(rows, ["loss", "train/loss"])
            if loss is not None:
                loss_ax.plot(steps, loss, label=stage, color=color, linewidth=1.4)

        for ax, title in ((kl_ax, "KL(policy ‖ ref)"), (loss_ax, "Training loss")):
            ax.set_title(title)
            ax.set_xlabel("Step (within stage)")
            ax.grid(True, alpha=0.3)
            # Avoid matplotlib's "no artists with labels" warning and an
            # empty legend box when no stage logged this metric.
            if ax.get_legend_handles_labels()[0]:
                ax.legend(fontsize=8, loc="upper right")
            else:
                ax.text(0.5, 0.5, "no data", transform=ax.transAxes,
                        ha="center", va="center", color="#888")
        suffix = " [placeholder]" if placeholder else ""
        fig.suptitle(f"OpenSOC GRPO — KL and loss diagnostics{suffix}")
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        # Release the figure even if rendering/saving fails.
        plt.close(fig)
|
|
|
|
def main() -> None:
    """Command-line entry point: read the GRPO logs and render both PNGs.

    Relative ``--grpo-root`` and ``--out-dir`` values are resolved against the
    parent of this script's directory (presumably the repository root), not
    the current working directory, so the script behaves the same wherever it
    is launched from. Absolute values are used unchanged, because
    ``os.path.join`` discards earlier components when a later one is absolute.

    If no stage logs are found and ``--allow-placeholder`` is not given, an
    explanation is printed to stderr and the process exits with status 2
    without creating the output directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--grpo-root", default="checkpoints/defender_grpo",
        help="Directory containing <stage>/training_log.jsonl files.",
    )
    parser.add_argument("--out-dir", default="eval/results")
    parser.add_argument(
        "--allow-placeholder", action="store_true",
        help="Generate fake curves if real logs are missing (default off).",
    )
    args = parser.parse_args()

    base_dir = os.path.dirname(_HERE)
    grpo_root = os.path.join(base_dir, args.grpo_root)
    out_dir = os.path.join(base_dir, args.out_dir)

    stage_logs = _read_stage_logs(grpo_root)
    placeholder = False
    if not stage_logs:
        if not args.allow_placeholder:
            print(
                f"No training logs found under {grpo_root}.\n"
                " - re-run after `python -m train.train_grpo ...` produces "
                "training_log.jsonl, or pass `--allow-placeholder` to render "
                "synthetic curves for the README scaffold.",
                file=sys.stderr,
            )
            sys.exit(2)
        stage_logs = _placeholder_logs()
        placeholder = True

    # Create the output directory only once there is something to write, so
    # the error exit above leaves the filesystem untouched.
    os.makedirs(out_dir, exist_ok=True)
    curves_path = os.path.join(out_dir, "training_curves.png")
    aux_path = os.path.join(out_dir, "training_kl_loss.png")
    _plot_curves(stage_logs, curves_path, placeholder)
    _plot_aux(stage_logs, aux_path, placeholder)

    print(f"Wrote {curves_path} and {aux_path}" + (" [placeholder]" if placeholder else ""))


# Run the CLI only when executed as a script; importing just defines helpers.
if __name__ == "__main__":
    main()
|
|