"""Generate training plots from TRL trainer_state.json or JSONL logs + eval results.

Plots produced
--------------
  1. reward_curve.png      — mean reward ± std band + smoothed trend
  2. kl_loss_curve.png     — KL divergence & policy loss on twin axes
  3. completion_stats.png  — mean completion length + clipped-ratio line
  4. bypass_bars.png       — RL-trained vs handcrafted baseline (eval results)
  5. per_category.png      — per-scenario-category breakdown (eval results)

Usage
-----
    python scripts/make_plots.py \\
        --trainer-state /path/to/trainer_state.json \\
        --eval docs/eval_results.json \\
        --out docs/plots/
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple


# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------

def _parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what a plain ``_parse_args()``
            call has always done.

    Returns:
        Namespace with ``logs``, ``trainer_state``, ``out`` and ``eval``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--logs", type=str, default="logs/",
                   help="Directory of JSONL log files (legacy fallback)")
    p.add_argument("--trainer-state", type=str, default=None,
                   help="Path to TRL trainer_state.json (preferred)")
    p.add_argument("--out", type=str, default="docs/plots/")
    p.add_argument("--eval", type=str, default="docs/eval_results.json")
    return p.parse_args(argv)


# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------

def _row_step(row: Dict[str, Any]) -> Any:
    """Return the numeric step of *row* (``step``, else ``global_step``), or None."""
    for key in ("step", "global_step"):
        step = row.get(key)
        # bool is an int subclass but is never a meaningful step value.
        if isinstance(step, (int, float)) and not isinstance(step, bool):
            return step
    return None


def _load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Read one JSONL file and return its JSON-object rows.

    Blank lines are ignored. Lines that are not valid JSON, or that decode to
    something other than an object, are skipped; the number skipped is
    reported on stdout instead of being dropped silently.
    """
    rows: List[Dict[str, Any]] = []
    skipped = 0
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            if isinstance(record, dict):
                rows.append(record)
            else:
                # Downstream code calls .get() on every row.
                skipped += 1
    if skipped:
        print(f"Skipped {skipped} malformed line(s) in {path}")
    return rows


def _load_all_logs(logs_dir: Path) -> List[Dict[str, Any]]:
    """Load every ``*.jsonl`` file in *logs_dir* into one step-sorted row list.

    Each returned row is guaranteed to carry a numeric ``"step"`` key: rows
    that only have ``"global_step"`` get it copied to ``"step"``, and rows
    with neither are dropped because they cannot be placed on a step axis.
    """
    rows: List[Dict[str, Any]] = []
    dropped = 0
    for p in sorted(logs_dir.glob("*.jsonl")):
        for record in _load_jsonl(p):
            step = _row_step(record)
            if step is None:
                dropped += 1
                continue
            record["step"] = step
            rows.append(record)
    if dropped:
        print(f"Dropped {dropped} log row(s) without a step in {logs_dir}")
    rows.sort(key=lambda r: r["step"])
    return rows


# TRL GRPO key → normalised key mapping
_KEY_MAP = {
    "reward": "reward/mean",
    "rewards/mean": "reward/mean",
    "reward/mean": "reward/mean",
    "reward_std": "reward/std",
    "rewards/std": "reward/std",
    "reward/std": "reward/std",
    "kl": "kl",
    "loss": "loss",
    "train/loss": "loss",
    "learning_rate": "lr",
    "completions/mean_length": "completion/mean_length",
    "completions/clipped_ratio": "completion/clipped_ratio",
}


def _load_trainer_state(state_path: Path) -> List[Dict[str, Any]]:
    """Parse TRL trainer_state.json log_history into a normalised row list.

    Args:
        state_path: Location of ``trainer_state.json``.

    Returns:
        Rows sorted by ``"step"``, each holding the step plus whichever
        normalised metric keys the log entry provided. Entries without a
        step or without any recognised metric are left out. An empty list is
        returned (after printing the reason) when the file is missing or is
        not a readable JSON object, so the caller can fall back to JSONL logs.
    """
    if not state_path.exists():
        print(f"trainer_state.json not found at {state_path}")
        return []

    try:
        with open(state_path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:  # ValueError covers JSON/Unicode errors
        print(f"Could not read trainer state at {state_path}: {exc}")
        return []
    if not isinstance(data, dict):
        print(f"Unexpected trainer state format at {state_path}")
        return []

    rows: List[Dict[str, Any]] = []
    for entry in data.get("log_history") or []:
        if not isinstance(entry, dict):
            continue
        step = entry.get("step")
        if step is None:
            continue
        row: Dict[str, Any] = {"step": step}
        for src, dst in _KEY_MAP.items():
            if src in entry:
                row[dst] = entry[src]
        if len(row) > 1:  # has at least one metric besides step
            rows.append(row)

    rows.sort(key=lambda r: r["step"])
    print(f"Loaded {len(rows)} log entries from {state_path}")
    return rows


def _extract(rows: List[Dict[str, Any]], key: str) -> Tuple[List[int], List[float]]:
    """Collect the ``(steps, values)`` series for metric *key*.

    Rows are skipped when they lack the metric, lack a step, or hold a value
    that cannot be converted to ``float``; the two returned lists therefore
    always have equal length.
    """
    steps, vals = [], []
    for r in rows:
        v = r.get(key)
        step = r.get("step")
        if v is None or step is None:
            continue
        try:
            value = float(v)
        except (TypeError, ValueError):
            continue  # non-numeric log value; nothing to plot
        steps.append(step)
        vals.append(value)
    return steps, vals


# ---------------------------------------------------------------------------
# Plot helpers
# ---------------------------------------------------------------------------

def _smooth(vals: List[float], window: int) -> List[float]:
    """Return the moving average of *vals* over *window* points.

    Uses a "valid" convolution, so the result has ``len(vals) - window + 1``
    items. *vals* is returned unchanged when the window is 1 or less, or is
    longer than the series.
    """
    import numpy as np

    if window <= 1 or len(vals) < window:
        return vals
    return np.convolve(vals, np.ones(window) / window, mode="valid").tolist()


BLUE = "#3b82f6"
DBLUE = "#1d4ed8"
RED = "#ef4444"
GREEN = "#22c55e"
ORANGE = "#f97316"
PURPLE = "#a855f7"
GRAY = "#94a3b8"


# ---------------------------------------------------------------------------
# Plot 1: Reward curve with ±std band
# ---------------------------------------------------------------------------
def _save_figure(fig: Any, out_path: Path) -> None:
    """Tighten the layout of *fig*, write it to *out_path*, close it and report."""
    import matplotlib.pyplot as plt

    fig.tight_layout()
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    # Close this specific figure rather than whichever one happens to be current.
    plt.close(fig)
    print(f"Saved {out_path}")


def _load_eval_results(eval_path: Path, plot_name: str) -> Optional[Dict[str, Any]]:
    """Load the eval-results JSON object, or print why *plot_name* is skipped.

    Returns ``None`` when the file is missing, unreadable, not valid JSON, or
    does not contain a JSON object at the top level.
    """
    if not eval_path.exists():
        print(f"Eval results not found at {eval_path} — skipping {plot_name}")
        return None
    try:
        with open(eval_path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:  # ValueError covers JSON/Unicode errors
        print(f"Could not read eval results at {eval_path} ({exc}) — skipping {plot_name}")
        return None
    if not isinstance(data, dict):
        print(f"Unexpected eval results format at {eval_path} — skipping {plot_name}")
        return None
    return data


def _annotate_bars(ax: Any, bars: Any, color: str) -> None:
    """Write a percentage label above every bar taller than 1%."""
    for bar in bars:
        h = bar.get_height()
        if h > 0.01:
            ax.text(bar.get_x() + bar.get_width() / 2, h + 0.01,
                    f"{h:.0%}", ha="center", va="bottom", fontsize=9, color=color)


def _plot_reward_curve(rows: List[Dict[str, Any]], out_dir: Path) -> None:
    """Plot raw and smoothed mean reward, with a ±1 std band when available.

    Writes ``reward_curve.png`` into *out_dir*; prints a notice and returns
    without plotting when *rows* contain no ``reward/mean`` values.
    """
    steps_r, rewards = _extract(rows, "reward/mean")
    steps_s, stds = _extract(rows, "reward/std")
    if not steps_r:
        print("No reward data — skipping reward_curve.png")
        return

    import matplotlib.pyplot as plt
    import numpy as np

    # Window scales with the run length (about 1/15 of the logged points).
    window = max(1, len(rewards) // 15)
    smoothed = _smooth(rewards, window)
    # A "valid" moving average loses the first window-1 points.
    smooth_steps = steps_r[window - 1:] if window > 1 else steps_r

    fig, ax = plt.subplots(figsize=(10, 5))
    lowest = min(rewards)

    # std band — only when every reward has a std logged at the same step
    if steps_s and steps_s == steps_r:
        r_arr = np.array(rewards)
        s_arr = np.array(stds)
        ax.fill_between(steps_r, r_arr - s_arr, r_arr + s_arr,
                        alpha=0.15, color=BLUE, label="±1 std")
        lowest = min(lowest, float((r_arr - s_arr).min()))

    ax.plot(steps_r, rewards, alpha=0.35, color=BLUE, linewidth=0.9, label="raw reward")
    ax.plot(smooth_steps, smoothed, color=DBLUE, linewidth=2.2,
            label=f"smoothed (w={window})")
    ax.axhline(0, color="gray", linestyle="--", linewidth=0.6)
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Mean Reward")
    # Report the step count actually present in the logs.
    ax.set_title(f"InjectArena — GRPO Reward Curve ({max(steps_r)} steps)")
    ax.legend(loc="lower right")
    # Anchor the axis at zero only when that does not cut off negative data.
    if lowest >= 0:
        ax.set_ylim(bottom=0)
    ax.grid(alpha=0.25)

    _save_figure(fig, out_dir / "reward_curve.png")


# ---------------------------------------------------------------------------
# Plot 2: KL divergence + policy loss on twin axes
# ---------------------------------------------------------------------------

def _plot_kl_loss(rows: List[Dict[str, Any]], out_dir: Path) -> None:
    """Plot KL divergence (left axis) and policy loss (right axis).

    Writes ``kl_loss_curve.png`` into *out_dir*; prints a notice and returns
    without plotting when neither series is present in *rows*.
    """
    steps_kl, kls = _extract(rows, "kl")
    steps_l, losses = _extract(rows, "loss")
    if not steps_kl and not steps_l:
        print("No KL/loss data — skipping kl_loss_curve.png")
        return

    import matplotlib.pyplot as plt

    fig, ax1 = plt.subplots(figsize=(10, 4))
    handles: List[Any] = []  # data lines only, for the combined legend

    if steps_kl:
        handles += ax1.plot(steps_kl, kls, color=PURPLE, linewidth=1.8,
                            label="KL divergence")
        ax1.set_ylabel("KL Divergence", color=PURPLE)
        ax1.tick_params(axis="y", labelcolor=PURPLE)

    if steps_l:
        ax2 = ax1.twinx()
        handles += ax2.plot(steps_l, losses, color=RED, linewidth=1.8,
                            linestyle="--", label="Policy loss")
        ax2.set_ylabel("Policy Loss", color=RED)
        ax2.tick_params(axis="y", labelcolor=RED)

    ax1.set_xlabel("Training Step")
    ax1.set_title("InjectArena — KL Divergence & Policy Loss")
    ax1.grid(alpha=0.2)

    # Combined legend
    if handles:
        ax1.legend(handles, [h.get_label() for h in handles], loc="upper right")

    _save_figure(fig, out_dir / "kl_loss_curve.png")


# ---------------------------------------------------------------------------
# Plot 3: Completion length + clipped ratio
# ---------------------------------------------------------------------------

def _plot_completion_stats(rows: List[Dict[str, Any]], out_dir: Path) -> None:
    """Plot mean completion length (left axis) and clipped ratio (right axis).

    Writes ``completion_stats.png`` into *out_dir*; prints a notice and
    returns without plotting when neither series is present in *rows*.
    """
    steps_l, lengths = _extract(rows, "completion/mean_length")
    steps_c, clipped = _extract(rows, "completion/clipped_ratio")
    if not steps_l and not steps_c:
        print("No completion stats — skipping completion_stats.png")
        return

    import matplotlib.pyplot as plt

    fig, ax1 = plt.subplots(figsize=(10, 4))
    handles: List[Any] = []  # data lines only, for the combined legend

    if steps_l:
        handles += ax1.plot(steps_l, lengths, color=ORANGE, linewidth=1.8,
                            label="Mean completion length (tokens)")
        ax1.set_ylabel("Mean Length (tokens)", color=ORANGE)
        ax1.tick_params(axis="y", labelcolor=ORANGE)

    if steps_c:
        ax2 = ax1.twinx()
        handles += ax2.plot(steps_c, clipped, color=RED, linewidth=1.8,
                            linestyle="--", label="Clipped ratio (hit max_len)")
        ax2.set_ylabel("Clipped Ratio", color=RED)
        ax2.set_ylim(0, 1.05)
        ax2.tick_params(axis="y", labelcolor=RED)
        # Guide line at 100%; kept out of `handles` so its auto-generated
        # underscore label never reaches the legend.
        ax2.axhline(1.0, color=RED, linestyle=":", linewidth=0.7, alpha=0.5)

    ax1.set_xlabel("Training Step")
    ax1.set_title("InjectArena — Completion Length & Clipping")
    ax1.grid(alpha=0.2)

    if handles:
        ax1.legend(handles, [h.get_label() for h in handles], loc="upper right")

    _save_figure(fig, out_dir / "completion_stats.png")


# ---------------------------------------------------------------------------
# Plot 4: Bypass bars (eval results vs baseline)
# ---------------------------------------------------------------------------

def _plot_bypass_bars(eval_path: Path, out_dir: Path) -> None:
    """Plot eval rates next to the handcrafted-baseline rates as grouped bars.

    Writes ``bypass_bars.png`` into *out_dir*. Metrics missing (or null) in
    the eval file are drawn as 0. Prints a notice and returns without
    plotting when the eval file is missing or unreadable.
    """
    data = _load_eval_results(eval_path, "bypass_bars.png")
    if data is None:
        return

    import matplotlib.pyplot as plt
    import numpy as np

    metrics = {
        "PG2 Bypass": data.get("pg2_bypass_rate") or 0,
        "FW Bypass": data.get("fw_bypass_rate") or 0,
        "Task Success": data.get("task_success_rate") or 0,
        "Composed Bypass": data.get("composed_bypass_rate") or 0,
    }
    baselines = {
        "PG2 Bypass": 0.15,
        "FW Bypass": 0.20,
        "Task Success": 0.05,
        "Composed Bypass": 0.02,
    }

    x = np.arange(len(metrics))
    width = 0.35

    fig, ax = plt.subplots(figsize=(9, 5))
    bars1 = ax.bar(x - width / 2, [baselines[k] for k in metrics], width,
                   label="Handcrafted Baseline", color=GRAY, edgecolor="white")
    bars2 = ax.bar(x + width / 2, [metrics[k] for k in metrics], width,
                   label="InjectArena (RL-trained)", color=BLUE, edgecolor="white")

    ax.set_ylabel("Rate")
    ax.set_title("InjectArena — Attacker Performance vs Baseline")
    ax.set_xticks(x)
    ax.set_xticklabels(list(metrics))
    ax.set_ylim(0, 1.05)
    ax.legend()
    ax.grid(axis="y", alpha=0.3)

    _annotate_bars(ax, bars1, "#475569")
    _annotate_bars(ax, bars2, DBLUE)

    _save_figure(fig, out_dir / "bypass_bars.png")


# ---------------------------------------------------------------------------
# Plot 5: Per-category breakdown
# ---------------------------------------------------------------------------

def _plot_per_category(eval_path: Path, out_dir: Path) -> None:
    """Plot task-success and composed-bypass rates for each scenario category.

    Writes ``per_category.png`` into *out_dir*. A category entry lacking a
    metric is drawn as 0. Prints a notice and returns without plotting when
    the eval file is missing/unreadable or has no ``per_category`` data.
    """
    data = _load_eval_results(eval_path, "per_category.png")
    if data is None:
        return

    per_cat = data.get("per_category")
    if not isinstance(per_cat, dict) or not per_cat:
        print("No per_category data in eval results — skipping per_category.png")
        return

    import matplotlib.pyplot as plt
    import numpy as np

    def _rate(entry: Any, key: str) -> Any:
        # Tolerate categories that omit a metric instead of raising KeyError.
        return (entry.get(key) or 0) if isinstance(entry, dict) else 0

    cats = list(per_cat)
    task_rates = [_rate(per_cat[c], "task_success") for c in cats]
    bypass_rates = [_rate(per_cat[c], "composed_bypass") for c in cats]

    x = np.arange(len(cats))
    width = 0.35

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.bar(x - width / 2, task_rates, width, label="Task Success",
           color=GREEN, edgecolor="white")
    ax.bar(x + width / 2, bypass_rates, width, label="Composed Bypass",
           color=BLUE, edgecolor="white")
    ax.set_ylabel("Rate")
    ax.set_title("InjectArena — Per-Category Breakdown")
    ax.set_xticks(x)
    ax.set_xticklabels(cats, rotation=15, ha="right")
    ax.set_ylim(0, 1.05)
    ax.legend()
    ax.grid(axis="y", alpha=0.3)

    _save_figure(fig, out_dir / "per_category.png")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    """Parse CLI options, load training/eval data and write every available plot."""
    args = _parse_args()
    out_dir = Path(args.out)
    eval_path = Path(args.eval)
    out_dir.mkdir(parents=True, exist_ok=True)

    try:
        import matplotlib
        # Non-interactive backend: figures are only ever written to files.
        matplotlib.use("Agg")
    except ImportError:
        print("matplotlib not installed — pip install matplotlib")
        return

    # Load training log rows (trainer_state preferred, JSONL fallback)
    rows: List[Dict[str, Any]] = []
    if args.trainer_state:
        rows = _load_trainer_state(Path(args.trainer_state))
    if not rows:
        logs_dir = Path(args.logs)
        if logs_dir.exists():
            rows = _load_all_logs(logs_dir)
            if rows:
                print(f"Loaded {len(rows)} log rows from {logs_dir}")

    # Training plots (require rows)
    if rows:
        _plot_reward_curve(rows, out_dir)
        _plot_kl_loss(rows, out_dir)
        _plot_completion_stats(rows, out_dir)
    else:
        print("No training log data found — skipping reward/KL/completion plots.")

    # Eval plots (require eval results JSON)
    _plot_bypass_bars(eval_path, out_dir)
    _plot_per_category(eval_path, out_dir)

    print("\nAll plots done.")
    for p in sorted(out_dir.glob("*.png")):
        print(f"  {p}")


if __name__ == "__main__":
    main()