Spaces:
Running
Running
| """Render the GRPO training-curve PNGs that the README embeds. | |
| Reads ``checkpoints/defender_grpo/<stage>/training_log.jsonl`` files | |
| written by the `_JsonLogger` callback in `train.train_grpo` and produces: | |
| * ``eval/results/training_curves.png`` — reward vs global step, | |
| one line per curriculum stage. | |
| * ``eval/results/training_kl_loss.png`` — `kl` and `loss` vs step | |
| (whichever fields the trainer | |
| produced) as a sanity proxy. | |
| If no JSONL logs exist (because training hasn't been run yet on this | |
| machine), the script exits with status 2 unless ``--allow-placeholder`` | |
| is passed. With that flag it generates *placeholder* curves from a | |
| deterministic synthetic process so the README never has a broken image | |
| link before the real GPU run finishes. The placeholder plots are clearly | |
| labelled in their titles. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import math | |
| import os | |
| import random | |
| import sys | |
| from typing import Any, Dict, List | |
# Directory containing this script. Its parent directory is pushed to the
# front of ``sys.path`` (presumably so project packages resolve when the
# script is run directly -- confirm against the repository layout).
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(_HERE))

# Curriculum stages in the order they are read and plotted.
STAGE_ORDER = [
    "stage1_basic",
    "stage2_multi",
    "stage3_mixed",
    "stage4_adversarial",
]

# Line colour per stage, paired positionally with STAGE_ORDER.
STAGE_COLORS = dict(
    zip(
        STAGE_ORDER,
        ("#1f77b4", "#2ca02c", "#ff7f0e", "#d62728"),
    )
)
def _read_stage_logs(grpo_root: str) -> Dict[str, List[Dict[str, Any]]]:
    """Load the per-stage ``training_log.jsonl`` files found under *grpo_root*.

    Args:
        grpo_root: Directory expected to hold one sub-directory per entry of
            ``STAGE_ORDER``, each containing a ``training_log.jsonl`` file.

    Returns:
        Mapping of stage name to its parsed rows, in file order. A stage is
        omitted when its log file is missing or yields no usable row.
    """
    out: Dict[str, List[Dict[str, Any]]] = {}
    for stage in STAGE_ORDER:
        path = os.path.join(grpo_root, stage, "training_log.jsonl")
        # isfile (not exists): a directory with that name would make open() raise.
        if not os.path.isfile(path):
            continue
        rows: List[Dict[str, Any]] = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    # Best effort: skip lines that are not valid JSON rather
                    # than abort the whole plot.
                    continue
                # json.loads accepts any JSON value (numbers, strings, lists,
                # null); only objects are usable as rows because callers do
                # ``row.get(...)`` and ``key in row``.
                if isinstance(record, dict):
                    rows.append(record)
        if rows:
            out[stage] = rows
    return out
def _placeholder_logs() -> Dict[str, List[Dict[str, Any]]]:
    """Fabricate deterministic stand-in logs for every stage in ``STAGE_ORDER``.

    Each stage gets one row every 5 steps over 200 steps. The mean reward
    rises from a per-stage start value towards a per-stage asymptote along
    ``1 - exp(-3.5 * t)``, with Gaussian noise from a generator seeded with
    42, so repeated calls return identical data. Later stages start lower
    and level off lower. The values are illustrative only.

    Returns:
        Mapping of stage name to a list of rows with the keys ``stage``,
        ``step``, ``reward``, ``kl`` and ``loss``.
    """
    rng = random.Random(42)
    # stage -> (starting mean reward, asymptotic mean reward)
    reward_range = {
        "stage1_basic": (-0.4, 0.95),
        "stage2_multi": (-0.6, 0.85),
        "stage3_mixed": (-0.8, 0.70),
        "stage4_adversarial": (-0.9, 0.55),
    }
    n_steps = 200

    def _row(stage: str, step: int, low: float, high: float) -> Dict[str, Any]:
        # Draw order (reward noise, then kl, then loss) fixes the output.
        frac = step / n_steps
        mean = low + (high - low) * (1 - math.exp(-3.5 * frac))
        noise = rng.gauss(0, 0.07)
        return {
            "stage": stage,
            "step": step,
            "reward": max(-1.5, min(1.1, mean + noise)),
            "kl": 0.02 + 0.01 * frac + max(0.0, rng.gauss(0, 0.005)),
            "loss": 0.7 - 0.3 * frac + rng.gauss(0, 0.04),
        }

    out: Dict[str, List[Dict[str, Any]]] = {}
    for stage in STAGE_ORDER:
        low, high = reward_range[stage]
        out[stage] = [_row(stage, step, low, high) for step in range(0, n_steps, 5)]
    return out
def _key(rows: List[Dict[str, Any]], names: List[str]) -> List[float] | None:
    """Extract one column from *rows* using the first candidate key present.

    Args:
        rows: Log rows to read from.
        names: Candidate keys in priority order.

    Returns:
        One value per row for the first name that appears in at least one
        row, with ``math.nan`` where a row lacks it. ``None`` if no candidate
        appears in any row.
    """
    chosen = next(
        (candidate for candidate in names if any(candidate in row for row in rows)),
        None,
    )
    if chosen is None:
        return None
    return [row.get(chosen, math.nan) for row in rows]
def _plot_curves(stage_logs: Dict[str, List[Dict[str, Any]]], out_path: str, placeholder: bool):
    """Plot reward against a concatenated step axis and save it to *out_path*.

    Stages are drawn in ``STAGE_ORDER``. Each stage's steps are shifted by a
    running offset so the curves sit end to end on one x-axis.

    Args:
        stage_logs: Mapping of stage name to its log rows. Stages that are
            absent or empty are skipped.
        out_path: Destination file for the figure.
        placeholder: When true, the title is marked as placeholder data.
    """
    import matplotlib  # type: ignore[import-not-found]

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt  # type: ignore[import-not-found]

    reward_keys = ["reward", "rewards/mean", "train/reward", "reward_mean"]
    fig, ax = plt.subplots(figsize=(8, 4.5))
    offset = 0
    for stage in STAGE_ORDER:
        entries = stage_logs.get(stage, [])
        if not entries:
            continue
        ordered = sorted(entries, key=lambda row: row.get("step", 0))
        local_steps = [row.get("step", 0) for row in ordered]
        global_steps = [offset + step for step in local_steps]
        rewards = _key(ordered, reward_keys)
        if not rewards:
            # No recognised reward key: draw NaNs so the stage still gets a
            # legend entry.
            rewards = [math.nan] * len(ordered)
        ax.plot(global_steps, rewards, label=stage, color=STAGE_COLORS[stage], linewidth=1.6)
        # Start the next stage just past this one's last step.
        offset += max(local_steps) + 5

    ax.axhline(0.0, color="#888", linewidth=0.6, linestyle="--")
    ax.set_xlabel("Global step (concatenated across stages)")
    ax.set_ylabel("Mean reward")
    marker = " [placeholder — re-run after real training]" if placeholder else ""
    ax.set_title("OpenSOC GRPO defender — reward across curriculum stages" + marker)
    ax.legend(loc="lower right", fontsize=9)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
def _plot_aux(stage_logs: Dict[str, List[Dict[str, Any]]], out_path: str, placeholder: bool):
    """Plot KL and loss per stage side by side and save the figure to *out_path*.

    The left panel shows the first KL-like key found in a stage's rows and
    the right panel the first loss-like key. A stage lacking both keys adds
    nothing to either panel. Steps are plotted as logged, without
    concatenation across stages.

    Args:
        stage_logs: Mapping of stage name to its log rows. Stages that are
            absent or empty are skipped.
        out_path: Destination file for the figure.
        placeholder: When true, the figure title is marked as placeholder data.
    """
    import matplotlib  # type: ignore[import-not-found]

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt  # type: ignore[import-not-found]

    fig, axes = plt.subplots(1, 2, figsize=(10, 3.8))
    # (axis, panel title, candidate row keys in priority order)
    panels = (
        (axes[0], "KL(policy ‖ ref)", ["kl", "kl_div", "objective/kl", "train/kl"]),
        (axes[1], "Training loss", ["loss", "train/loss"]),
    )

    for stage in STAGE_ORDER:
        entries = stage_logs.get(stage, [])
        if not entries:
            continue
        ordered = sorted(entries, key=lambda row: row.get("step", 0))
        xs = [row.get("step", 0) for row in ordered]
        for axis, _title, candidates in panels:
            series = _key(ordered, candidates)
            if series is not None:
                axis.plot(xs, series, label=stage, color=STAGE_COLORS[stage], linewidth=1.4)

    for axis, title, _candidates in panels:
        axis.set_title(title)
        axis.set_xlabel("Step (within stage)")
        axis.grid(True, alpha=0.3)
        axis.legend(fontsize=8, loc="upper right")

    suffix = " [placeholder]" if placeholder else ""
    fig.suptitle(f"OpenSOC GRPO — KL and loss diagnostics{suffix}")
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
def main() -> None:
    """Parse the command line, load the stage logs and write both PNG figures.

    ``--grpo-root`` and ``--out-dir`` are joined onto the parent of this
    script's directory. When no stage has any log rows the process exits
    with status 2, unless ``--allow-placeholder`` is given, in which case
    synthetic logs are plotted and the figures are labelled as placeholders.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--grpo-root",
        default="checkpoints/defender_grpo",
        help="Directory containing <stage>/training_log.jsonl files.",
    )
    parser.add_argument("--out-dir", default="eval/results")
    parser.add_argument(
        "--allow-placeholder",
        action="store_true",
        help="Generate fake curves if real logs are missing (default off).",
    )
    args = parser.parse_args()

    base_dir = os.path.dirname(_HERE)
    grpo_root = os.path.join(base_dir, args.grpo_root)
    out_dir = os.path.join(base_dir, args.out_dir)
    os.makedirs(out_dir, exist_ok=True)

    stage_logs = _read_stage_logs(grpo_root)
    placeholder = not stage_logs
    if placeholder:
        if not args.allow_placeholder:
            message = (
                f"No training logs found under {grpo_root}.\n"
                " - re-run after `python -m train.train_grpo ...` produces "
                "training_log.jsonl, or pass `--allow-placeholder` to render "
                "synthetic curves for the README scaffold."
            )
            print(message, file=sys.stderr)
            sys.exit(2)
        stage_logs = _placeholder_logs()

    curves_path = os.path.join(out_dir, "training_curves.png")
    aux_path = os.path.join(out_dir, "training_kl_loss.png")
    _plot_curves(stage_logs, curves_path, placeholder)
    _plot_aux(stage_logs, aux_path, placeholder)

    tag = " [placeholder]" if placeholder else ""
    print(f"Wrote {curves_path} and {aux_path}{tag}")


if __name__ == "__main__":
    main()