"""Render the GRPO training-curve PNGs that the README embeds. Reads ``checkpoints/defender_grpo//training_log.jsonl`` files written by the `_JsonLogger` callback in `train.train_grpo` and produces: * ``eval/results/training_curves.png`` — reward vs global step, one line per curriculum stage. * ``eval/results/format_compliance.png`` — `kl` and `loss` vs step (whichever fields the trainer produced) as a sanity proxy. If no JSONL logs exist (because training hasn't been run yet on this machine), the script generates *placeholder* curves from a deterministic synthetic process so the README never has a broken image link before the real GPU run finishes. The placeholder file is clearly labelled. """ from __future__ import annotations import argparse import json import math import os import random import sys from typing import Any, Dict, List _HERE = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, os.path.dirname(_HERE)) STAGE_ORDER = [ "stage1_basic", "stage2_multi", "stage3_mixed", "stage4_adversarial", ] STAGE_COLORS = { "stage1_basic": "#1f77b4", "stage2_multi": "#2ca02c", "stage3_mixed": "#ff7f0e", "stage4_adversarial": "#d62728", } def _read_stage_logs(grpo_root: str) -> Dict[str, List[Dict[str, Any]]]: """Read training_log.jsonl from each stage subdirectory.""" out: Dict[str, List[Dict[str, Any]]] = {} for stage in STAGE_ORDER: path = os.path.join(grpo_root, stage, "training_log.jsonl") if not os.path.exists(path): continue rows: List[Dict[str, Any]] = [] with open(path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue try: rows.append(json.loads(line)) except json.JSONDecodeError: continue if rows: out[stage] = rows return out def _placeholder_logs() -> Dict[str, List[Dict[str, Any]]]: """Make synthetic-but-believable curves so the README has a plot. Each stage's reward starts low and asymptotes; later stages start lower because they're harder. Designed to look like a noisy sigmoid: this is illustrative only and is overwritten the moment real logs land in checkpoints/defender_grpo//training_log.jsonl. """ rng = random.Random(42) out: Dict[str, List[Dict[str, Any]]] = {} starts = {"stage1_basic": -0.4, "stage2_multi": -0.6, "stage3_mixed": -0.8, "stage4_adversarial": -0.9} asymptotes = { "stage1_basic": 0.95, "stage2_multi": 0.85, "stage3_mixed": 0.70, "stage4_adversarial": 0.55, } for stage in STAGE_ORDER: rows = [] n_steps = 200 a, b = starts[stage], asymptotes[stage] for step in range(0, n_steps, 5): t = step / n_steps mean = a + (b - a) * (1 - math.exp(-3.5 * t)) noise = rng.gauss(0, 0.07) rows.append({ "stage": stage, "step": step, "reward": max(-1.5, min(1.1, mean + noise)), "kl": 0.02 + 0.01 * t + max(0.0, rng.gauss(0, 0.005)), "loss": 0.7 - 0.3 * t + rng.gauss(0, 0.04), }) out[stage] = rows return out def _key(rows: List[Dict[str, Any]], names: List[str]) -> List[float] | None: """Return values for the first matching key, else None.""" for name in names: if any(name in r for r in rows): return [r.get(name, math.nan) for r in rows] return None def _plot_curves(stage_logs: Dict[str, List[Dict[str, Any]]], out_path: str, placeholder: bool): import matplotlib # type: ignore[import-not-found] matplotlib.use("Agg") import matplotlib.pyplot as plt # type: ignore[import-not-found] fig, ax = plt.subplots(figsize=(8, 4.5)) cumulative = 0 for stage in STAGE_ORDER: rows = stage_logs.get(stage, []) if not rows: continue rows = sorted(rows, key=lambda r: r.get("step", 0)) steps = [cumulative + r.get("step", 0) for r in rows] rewards = _key(rows, ["reward", "rewards/mean", "train/reward", "reward_mean"]) or [ math.nan ] * len(rows) ax.plot(steps, rewards, label=stage, color=STAGE_COLORS[stage], linewidth=1.6) if rows: cumulative += max(r.get("step", 0) for r in rows) + 5 ax.axhline(0.0, color="#888", linewidth=0.6, linestyle="--") ax.set_xlabel("Global step (concatenated across stages)") ax.set_ylabel("Mean reward") title = "OpenSOC GRPO defender — reward across curriculum stages" if placeholder: title += " [placeholder — re-run after real training]" ax.set_title(title) ax.legend(loc="lower right", fontsize=9) ax.grid(True, alpha=0.3) fig.tight_layout() fig.savefig(out_path, dpi=150) plt.close(fig) def _plot_aux(stage_logs: Dict[str, List[Dict[str, Any]]], out_path: str, placeholder: bool): import matplotlib # type: ignore[import-not-found] matplotlib.use("Agg") import matplotlib.pyplot as plt # type: ignore[import-not-found] fig, axes = plt.subplots(1, 2, figsize=(10, 3.8)) for stage in STAGE_ORDER: rows = stage_logs.get(stage, []) if not rows: continue rows = sorted(rows, key=lambda r: r.get("step", 0)) steps = [r.get("step", 0) for r in rows] kl = _key(rows, ["kl", "kl_div", "objective/kl", "train/kl"]) loss = _key(rows, ["loss", "train/loss"]) if kl is not None: axes[0].plot(steps, kl, label=stage, color=STAGE_COLORS[stage], linewidth=1.4) if loss is not None: axes[1].plot(steps, loss, label=stage, color=STAGE_COLORS[stage], linewidth=1.4) axes[0].set_title("KL(policy ‖ ref)") axes[0].set_xlabel("Step (within stage)") axes[0].grid(True, alpha=0.3) axes[0].legend(fontsize=8, loc="upper right") axes[1].set_title("Training loss") axes[1].set_xlabel("Step (within stage)") axes[1].grid(True, alpha=0.3) axes[1].legend(fontsize=8, loc="upper right") suffix = " [placeholder]" if placeholder else "" fig.suptitle(f"OpenSOC GRPO — KL and loss diagnostics{suffix}") fig.tight_layout() fig.savefig(out_path, dpi=150) plt.close(fig) def main() -> None: parser = argparse.ArgumentParser() parser.add_argument( "--grpo-root", default="checkpoints/defender_grpo", help="Directory containing /training_log.jsonl files.", ) parser.add_argument("--out-dir", default="eval/results") parser.add_argument( "--allow-placeholder", action="store_true", help="Generate fake curves if real logs are missing (default off).", ) args = parser.parse_args() grpo_root = os.path.join(os.path.dirname(_HERE), args.grpo_root) out_dir = os.path.join(os.path.dirname(_HERE), args.out_dir) os.makedirs(out_dir, exist_ok=True) stage_logs = _read_stage_logs(grpo_root) placeholder = False if not stage_logs: if not args.allow_placeholder: print( f"No training logs found under {grpo_root}.\n" " - re-run after `python -m train.train_grpo ...` produces " "training_log.jsonl, or pass `--allow-placeholder` to render " "synthetic curves for the README scaffold.", file=sys.stderr, ) sys.exit(2) stage_logs = _placeholder_logs() placeholder = True curves_path = os.path.join(out_dir, "training_curves.png") aux_path = os.path.join(out_dir, "training_kl_loss.png") _plot_curves(stage_logs, curves_path, placeholder) _plot_aux(stage_logs, aux_path, placeholder) print(f"Wrote {curves_path} and {aux_path}" + (" [placeholder]" if placeholder else "")) if __name__ == "__main__": main()