# opensoc-env / eval/plot_training.py
# shivam2k3 — OpenSOC v1 (bb6a031)
"""Render the GRPO training-curve PNGs that the README embeds.

Reads ``checkpoints/defender_grpo/<stage>/training_log.jsonl`` files
written by the ``_JsonLogger`` callback in ``train.train_grpo`` and produces:

* ``eval/results/training_curves.png`` — reward vs. global step, one line
  per curriculum stage (stages are laid end to end on a single x-axis).
* ``eval/results/training_kl_loss.png`` — ``kl`` and ``loss`` vs. step
  within each stage (whichever of those fields the trainer produced), as a
  sanity proxy.

If no JSONL logs exist (because training hasn't been run yet on this
machine), the script exits with status 2 by default. Pass
``--allow-placeholder`` to instead render *placeholder* curves from a
deterministic synthetic process, so the README never has a broken image
link before the real GPU run finishes. Placeholder plots are clearly
labelled in their titles.
"""
from __future__ import annotations
import argparse
import json
import math
import os
import random
import sys
from typing import Any, Dict, List
# Directory containing this file. Its parent is put at the front of
# sys.path (presumably so project-local packages import when this file is
# run directly as a script) and is also the base that main() resolves its
# relative CLI paths against.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(_HERE))

# Curriculum stages in training order. The plotting code walks this list, so
# it also fixes legend order and the left-to-right order of stages on the
# concatenated reward x-axis.
STAGE_ORDER = [
    "stage1_basic",
    "stage2_multi",
    "stage3_mixed",
    "stage4_adversarial",
]

# One fixed line colour per stage, paired positionally with STAGE_ORDER, so
# both figures colour a given stage the same way.
STAGE_COLORS = dict(
    zip(
        STAGE_ORDER,
        ("#1f77b4", "#2ca02c", "#ff7f0e", "#d62728"),
    )
)
def _read_stage_logs(
    grpo_root: str, stages: List[str] | None = None
) -> Dict[str, List[Dict[str, Any]]]:
    """Read ``training_log.jsonl`` from each stage subdirectory of *grpo_root*.

    Args:
        grpo_root: Directory expected to contain ``<stage>/training_log.jsonl``.
        stages: Stage subdirectory names to look for, in order. Defaults to
            ``STAGE_ORDER``.

    Returns:
        Mapping of stage name -> parsed log rows in file order. A stage is
        omitted when its log file is missing or yields no usable rows.

    Blank lines and lines that are not valid JSON (e.g. a truncated final
    line) are skipped. Lines that parse to something other than a JSON
    object are skipped too, because the plotting code calls ``.get`` on
    every row.
    """
    if stages is None:
        stages = STAGE_ORDER
    out: Dict[str, List[Dict[str, Any]]] = {}
    for stage in stages:
        path = os.path.join(grpo_root, stage, "training_log.jsonl")
        # isfile (not exists) so a directory with this name is not opened.
        if not os.path.isfile(path):
            continue
        rows: List[Dict[str, Any]] = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    row = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # A JSON array/number/string would crash later on ``row.get``.
                if isinstance(row, dict):
                    rows.append(row)
        if rows:
            out[stage] = rows
    return out
def _placeholder_logs() -> Dict[str, List[Dict[str, Any]]]:
    """Build deterministic synthetic training logs for every curriculum stage.

    Illustrative only: ``main`` uses these rows when no real
    ``training_log.jsonl`` exists and ``--allow-placeholder`` was passed.
    Each stage's mean reward follows a saturating exponential from a start
    value towards an asymptote, plus Gaussian noise clipped to [-1.5, 1.1].
    Later (harder) stages start lower and plateau lower. Rows are emitted
    every 5 steps over 200 steps and carry ``stage``, ``step``, ``reward``,
    ``kl`` and ``loss``. A fixed seed makes the output identical on every run.
    """
    rng = random.Random(42)
    # (start, asymptote) of the mean reward per stage.
    reward_bounds = {
        "stage1_basic": (-0.4, 0.95),
        "stage2_multi": (-0.6, 0.85),
        "stage3_mixed": (-0.8, 0.70),
        "stage4_adversarial": (-0.9, 0.55),
    }
    total_steps = 200
    logs: Dict[str, List[Dict[str, Any]]] = {}
    for stage in STAGE_ORDER:
        lo, hi = reward_bounds[stage]
        stage_rows: List[Dict[str, Any]] = []
        for step in range(0, total_steps, 5):
            frac = step / total_steps
            expected = lo + (hi - lo) * (1 - math.exp(-3.5 * frac))
            # RNG draws happen in a fixed per-row order (reward, kl, loss) so
            # the seeded sequence, and therefore the plot, is reproducible.
            reward_noise = rng.gauss(0, 0.07)
            kl_noise = rng.gauss(0, 0.005)
            loss_noise = rng.gauss(0, 0.04)
            stage_rows.append(
                {
                    "stage": stage,
                    "step": step,
                    "reward": max(-1.5, min(1.1, expected + reward_noise)),
                    "kl": 0.02 + 0.01 * frac + max(0.0, kl_noise),
                    "loss": 0.7 - 0.3 * frac + loss_noise,
                }
            )
        logs[stage] = stage_rows
    return logs
def _key(rows: List[Dict[str, Any]], names: List[str]) -> List[float] | None:
    """Extract one metric column from *rows*, trying several key spellings.

    The first entry of *names* that is a key in at least one row is used.
    The result has one entry per row, in row order. Values are returned as
    stored, and ``math.nan`` is used for rows lacking that key. Returns
    ``None`` when no candidate appears in any row (including when *rows* or
    *names* is empty).
    """
    for candidate in names:
        present = any(candidate in row for row in rows)
        if present:
            break
    else:
        return None
    return [row.get(candidate, math.nan) for row in rows]
def _plot_curves(
    stage_logs: Dict[str, List[Dict[str, Any]]], out_path: str, placeholder: bool
) -> None:
    """Write the reward-vs-global-step figure to *out_path*.

    Args:
        stage_logs: Stage name -> logged rows. Stages are drawn in
            ``STAGE_ORDER``; other names are ignored.
        out_path: Destination PNG path.
        placeholder: If true, the title is tagged as a placeholder.

    Stages are laid end to end on one x-axis. Each stage's ``step`` values
    are shifted by an offset, which grows after every stage by that stage's
    largest step plus 5. The reward column is the first of ``reward_keys``
    found in the rows. A stage with none of those keys still advances the
    offset but gets no line or legend entry (a note goes to stderr), rather
    than an invisible all-NaN line that still appears in the legend.
    """
    import matplotlib  # type: ignore[import-not-found]

    matplotlib.use("Agg")  # headless backend; the figure is only saved to disk
    import matplotlib.pyplot as plt  # type: ignore[import-not-found]

    reward_keys = ["reward", "rewards/mean", "train/reward", "reward_mean"]
    fig, ax = plt.subplots(figsize=(8, 4.5))
    cumulative = 0
    for stage in STAGE_ORDER:
        rows = stage_logs.get(stage, [])
        if not rows:
            continue
        rows = sorted(rows, key=lambda r: r.get("step", 0))
        steps = [cumulative + r.get("step", 0) for r in rows]
        rewards = _key(rows, reward_keys)
        if rewards is None:
            print(
                f"[plot_training] {stage}: none of {reward_keys} found in its "
                "log rows; no reward line drawn.",
                file=sys.stderr,
            )
        else:
            ax.plot(steps, rewards, label=stage, color=STAGE_COLORS[stage], linewidth=1.6)
        # rows is non-empty here (checked above), so max() cannot fail.
        cumulative += max(r.get("step", 0) for r in rows) + 5
    ax.axhline(0.0, color="#888", linewidth=0.6, linestyle="--")
    ax.set_xlabel("Global step (concatenated across stages)")
    ax.set_ylabel("Mean reward")
    title = "OpenSOC GRPO defender — reward across curriculum stages"
    if placeholder:
        title += " [placeholder — re-run after real training]"
    ax.set_title(title)
    # Only add a legend when at least one stage produced a labelled line.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc="lower right", fontsize=9)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
def _plot_aux(
    stage_logs: Dict[str, List[Dict[str, Any]]], out_path: str, placeholder: bool
) -> None:
    """Write a two-panel KL / training-loss diagnostic figure to *out_path*.

    Args:
        stage_logs: Stage name -> logged rows. Stages are drawn in
            ``STAGE_ORDER``; other names are ignored.
        out_path: Destination PNG path.
        placeholder: If true, the figure title gets a " [placeholder]" suffix.

    Unlike the reward figure, the x-axis is the raw per-stage ``step`` (no
    offset), so stages overlap. Each panel plots the first recognised key
    for its metric, and stages lacking every candidate key are not drawn on
    that panel. A panel that ends up with no lines shows a "no data" note
    instead of calling ``legend()`` with nothing to label.
    """
    import matplotlib  # type: ignore[import-not-found]

    matplotlib.use("Agg")  # headless backend; the figure is only saved to disk
    import matplotlib.pyplot as plt  # type: ignore[import-not-found]

    # (panel title, candidate log keys) for the left and right panels.
    panels = [
        ("KL(policy ‖ ref)", ["kl", "kl_div", "objective/kl", "train/kl"]),
        ("Training loss", ["loss", "train/loss"]),
    ]
    fig, axes = plt.subplots(1, 2, figsize=(10, 3.8))
    for stage in STAGE_ORDER:
        rows = stage_logs.get(stage, [])
        if not rows:
            continue
        rows = sorted(rows, key=lambda r: r.get("step", 0))
        steps = [r.get("step", 0) for r in rows]
        for ax, (_, keys) in zip(axes, panels):
            values = _key(rows, keys)
            if values is not None:
                ax.plot(steps, values, label=stage, color=STAGE_COLORS[stage], linewidth=1.4)
    for ax, (title, _) in zip(axes, panels):
        ax.set_title(title)
        ax.set_xlabel("Step (within stage)")
        ax.grid(True, alpha=0.3)
        if ax.get_legend_handles_labels()[0]:
            ax.legend(fontsize=8, loc="upper right")
        else:
            ax.text(
                0.5, 0.5, "no data in logs",
                transform=ax.transAxes, ha="center", va="center", color="#888",
            )
    suffix = " [placeholder]" if placeholder else ""
    fig.suptitle(f"OpenSOC GRPO — KL and loss diagnostics{suffix}")
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
def main() -> None:
    """Command-line entry point: read the stage logs and write both PNGs.

    Relative ``--grpo-root`` / ``--out-dir`` values are joined onto the
    parent of this file's directory, not the current working directory.
    Absolute paths pass through unchanged, as ``os.path.join`` does. The
    output directory is created before logs are read. If no logs are found
    and ``--allow-placeholder`` was not given, an explanation goes to stderr
    and the process exits with status 2.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--grpo-root",
        default="checkpoints/defender_grpo",
        help="Directory containing <stage>/training_log.jsonl files.",
    )
    parser.add_argument("--out-dir", default="eval/results")
    parser.add_argument(
        "--allow-placeholder",
        action="store_true",
        help="Generate fake curves if real logs are missing (default off).",
    )
    args = parser.parse_args()

    base_dir = os.path.dirname(_HERE)
    grpo_root = os.path.join(base_dir, args.grpo_root)
    out_dir = os.path.join(base_dir, args.out_dir)
    os.makedirs(out_dir, exist_ok=True)

    stage_logs = _read_stage_logs(grpo_root)
    placeholder = not stage_logs
    if placeholder:
        if not args.allow_placeholder:
            print(
                f"No training logs found under {grpo_root}.\n"
                " - re-run after `python -m train.train_grpo ...` produces "
                "training_log.jsonl, or pass `--allow-placeholder` to render "
                "synthetic curves for the README scaffold.",
                file=sys.stderr,
            )
            sys.exit(2)
        stage_logs = _placeholder_logs()

    curves_path = os.path.join(out_dir, "training_curves.png")
    aux_path = os.path.join(out_dir, "training_kl_loss.png")
    _plot_curves(stage_logs, curves_path, placeholder)
    _plot_aux(stage_logs, aux_path, placeholder)
    tag = " [placeholder]" if placeholder else ""
    print(f"Wrote {curves_path} and {aux_path}{tag}")
# Run the CLI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()