#!/usr/bin/env python3 """ generate_plots.py — Build the two PNGs needed for hackathon submission ======================================================================= After a successful GRPO run on the Hugging Face Space we have: /tmp/vergil_grpo_output/trainer_state.json (written by TRL) /tmp/vergil_grpo_output/validation_log.json (written by us) This script renders: docs/plots/training_curve.png - reward / loss / KL across optimisation steps docs/plots/comparison.png - bar chart: naive vs trained-VERGIL on the eval suite (sourced from training_results/rl_training_results.json or, if absent, from the validation_log.json baseline_*) Usage ----- python scripts/generate_plots.py # use defaults python scripts/generate_plots.py \ --trainer-state /tmp/vergil_grpo_output/trainer_state.json \ --validation-log /tmp/vergil_grpo_output/validation_log.json \ --out-dir docs/plots The script never raises on missing inputs — it just prints what it *could* render and exits 0, so it's safe to call from CI or the Space's post-train hook. """ from __future__ import annotations import argparse import json import os import sys from pathlib import Path from typing import Any, Dict, List, Optional, Tuple # Use a non-interactive backend BEFORE importing pyplot so this works # headlessly inside the HF Space container. 
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402

# ── Style: dark theme that matches the demo UI ────────────────────────────
plt.rcParams.update({
    "figure.facecolor": "#0a0e1a",
    "axes.facecolor": "#0f1421",
    "axes.edgecolor": "#2d3f58",
    "axes.labelcolor": "#cbd5e1",
    "axes.titlecolor": "#e2e8f0",
    "xtick.color": "#94a3b8",
    "ytick.color": "#94a3b8",
    "grid.color": "#1e293b",
    "font.family": "sans-serif",
    "font.size": 11,
    "axes.titlesize": 14,
    "axes.titleweight": "bold",
    "axes.grid": True,
    "grid.alpha": 0.4,
    "savefig.dpi": 150,
    "savefig.bbox": "tight",
})

BRAND = "#8b5cf6"    # purple (vergil)
BRAND2 = "#22d3ee"   # cyan (accent)
GOOD = "#34d399"     # green (rewards / completed)
BAD = "#fb7185"      # rose (failures)
NEUTRAL = "#64748b"  # slate (loss / aux)


# ──────────────────────────────────────────────────────────────────────────
def _load_json(path: Path) -> Optional[Any]:
    """Parse *path* as JSON; return None if it is missing or unreadable.

    Read/parse failures are reported on stderr instead of raised so every
    input to this script stays optional.
    """
    if not path.exists():
        return None
    try:
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        return json.loads(path.read_text(encoding="utf-8"))
    except Exception as e:  # best-effort by design: never crash on a bad input
        print(f"[plots] failed to read {path}: {e}", file=sys.stderr)
        return None


def _as_float(value: Any) -> Optional[float]:
    """Coerce *value* to float; return None for missing / non-numeric values."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _dict_entries(obj: Any) -> List[Dict[str, Any]]:
    """Return the dict elements of *obj* when it is a list, else ``[]``.

    Lets a JSON file with an unexpected top-level shape degrade to "nothing to
    plot" instead of raising AttributeError on ``.get``.
    """
    if not isinstance(obj, list):
        return []
    return [entry for entry in obj if isinstance(entry, dict)]


def _series(entries: List[Dict[str, Any]], key: str) -> Tuple[List[Any], List[float]]:
    """Collect ``(steps, values)`` from every entry holding a numeric *key*.

    Steps and values are taken from the same entries, so both lists always
    have equal length (``Axes.plot`` requires that). An entry without a usable
    ``"step"`` falls back to its position in *entries*.
    """
    steps: List[Any] = []
    values: List[float] = []
    for index, entry in enumerate(entries):
        value = _as_float(entry.get(key))
        if value is None:
            continue
        step = entry.get("step")
        steps.append(index if step is None else step)
        values.append(value)
    return steps, values


def _save_figure(fig: Any, out_path: Path, note: str = "") -> None:
    """Write *fig* to *out_path* (creating parent dirs) and always close it."""
    try:
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path)
    finally:
        # Release the figure even if saving fails so it isn't leaked.
        plt.close(fig)
    print(f"[plots] wrote {out_path}{note}")


# ──────────────────────────────────────────────────────────────────────────
def _render_trainer_curves(reward: Tuple[List[Any], List[float]],
                           loss: Tuple[List[Any], List[float]],
                           kl: Tuple[List[Any], List[float]],
                           out_path: Path) -> None:
    """Draw the three-panel reward / loss / KL figure and save it."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))

    ax = axes[0]
    steps_r, rewards = reward
    if rewards:
        ax.plot(steps_r, rewards, color=GOOD, lw=2.2, marker="o", markersize=4,
                markerfacecolor=GOOD, markeredgecolor="#0a0e1a")
        ax.fill_between(steps_r, rewards, min(rewards), color=GOOD, alpha=0.12)
    ax.set_title("Reward (group-relative mean)")
    ax.set_xlabel("optimization step")
    ax.set_ylabel("reward")

    ax = axes[1]
    steps_l, losses = loss
    if losses:
        ax.plot(steps_l, losses, color=NEUTRAL, lw=2, marker="s", markersize=3.5)
    ax.set_title("Policy loss")
    ax.set_xlabel("optimization step")
    ax.set_ylabel("loss")

    ax = axes[2]
    steps_k, kls = kl
    if kls:
        ax.plot(steps_k, kls, color=BRAND2, lw=2, marker="^", markersize=3.5)
    ax.set_title("KL(π‖π_ref)")
    ax.set_xlabel("optimization step")
    ax.set_ylabel("KL")

    fig.suptitle("VERGIL — GRPO training curves", color="#e2e8f0",
                 fontweight="bold", fontsize=15, y=1.02)
    fig.tight_layout()
    _save_figure(fig, out_path)


def _render_validation_curves(reward: Tuple[List[Any], List[float]],
                              fulfillment: Tuple[List[Any], List[float]],
                              out_path: Path) -> None:
    """Draw validation reward + fulfillment on twin y-axes and save it."""
    steps_r, rewards = reward
    steps_f, fulfill = fulfillment
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(steps_r, rewards, color=GOOD, lw=2.2, marker="o", label="mean reward")
    ax.set_xlabel("validation step")
    ax.set_ylabel("mean reward", color=GOOD)
    ax.tick_params(axis="y", labelcolor=GOOD)

    ax2 = ax.twinx()
    ax2.plot(steps_f, fulfill, color=BRAND2, lw=2.2, marker="s", label="fulfillment")
    ax2.set_ylabel("fulfillment rate", color=BRAND2)
    ax2.tick_params(axis="y", labelcolor=BRAND2)

    ax.set_title("VERGIL — validation reward + fulfillment over training")
    _save_figure(fig, out_path, " (fallback)")


def plot_training_curve(trainer_state_path: Path,
                        validation_log_path: Path,
                        out_path: Path) -> bool:
    """
    Render the optimisation-step curves: reward, loss, kl.

    Curves are read from ``log_history`` in trainer_state.json. When that file
    is absent, malformed, or logs none of reward/loss/kl (e.g. baseline-only
    runs), falls back to validation_log.json (mean_reward / mean_fulfillment).

    Returns True if a PNG was written to *out_path*, False if neither input
    contained anything plottable.
    """
    state = _load_json(trainer_state_path)
    val = _dict_entries(_load_json(validation_log_path))
    log = _dict_entries(state.get("log_history")) if isinstance(state, dict) else []

    if log:
        reward = _series(log, "reward")
        # If reward isn't logged separately, fall back to the validation log
        if not reward[1] and val:
            reward = _series(val, "mean_reward")
        loss = _series(log, "loss")
        kl = _series(log, "kl")
        if reward[1] or loss[1] or kl[1]:
            _render_trainer_curves(reward, loss, kl, out_path)
            return True

    # Fallback: validation log only. Entries lacking a metric are skipped
    # rather than plotted as 0.0, which would fake real measurements.
    reward = _series(val, "mean_reward")
    fulfillment = _series(val, "mean_fulfillment")
    if reward[1] or fulfillment[1]:
        _render_validation_curves(reward, fulfillment, out_path)
        return True

    print(f"[plots] skipped {out_path} — no plottable data in "
          f"trainer_state.json or validation_log.json")
    return False


# ──────────────────────────────────────────────────────────────────────────
def _first_number(d: Dict[str, Any], *keys: str) -> float:
    """Return the first numeric value found under *keys* in *d*, else 0.0."""
    for key in keys:
        value = _as_float(d.get(key))
        if value is not None:
            return value
    return 0.0


def _pick_metrics(d: Dict[str, Any]) -> Tuple[float, float, int, float]:
    """Return (reward, fulfillment, n_failed, avg_trust) from a metrics dict.

    Each metric accepts a primary key and an alias; missing, null, or
    non-numeric values fall through to the alias and then to 0.
    """
    return (
        _first_number(d, "total_reward", "mean_reward"),
        _first_number(d, "fulfillment_rate", "mean_fulfillment"),
        int(_first_number(d, "n_failed")),
        _first_number(d, "avg_trust", "trust_avg"),
    )


def plot_comparison(results_path: Path, out_path: Path) -> bool:
    """
    Bar chart: baseline (naive) vs trained-VERGIL across our 4 KPIs.

    *results_path* must hold a JSON object with a baseline entry under
    ``naive``/``baseline`` and a trained entry under ``vergil``/``trained``/``rl``.
    Returns True if the PNG was written, False if the input was unusable.
    """
    data = _load_json(results_path)
    if not isinstance(data, dict):
        print(f"[plots] skipped {out_path} — {results_path} missing/invalid")
        return False

    naive = data.get("naive") or data.get("baseline")
    trained = data.get("vergil") or data.get("trained") or data.get("rl")
    if not (isinstance(naive, dict) and isinstance(trained, dict)):
        print(f"[plots] skipped {out_path} — need 'naive' + 'vergil' keys in results")
        return False

    n_r, n_f, n_fl, n_t = _pick_metrics(naive)
    v_r, v_f, v_fl, v_t = _pick_metrics(trained)

    # (title, naive value, vergil value, naive label, vergil label)
    metrics = [
        ("Reward", n_r, v_r, f"{n_r:+.2f}", f"{v_r:+.2f}"),
        ("Fulfillment %", n_f * 100, v_f * 100, f"{n_f*100:.0f}%", f"{v_f*100:.0f}%"),
        ("Failures", n_fl, v_fl, f"{n_fl}", f"{v_fl}"),
        ("Trust %", n_t * 100, v_t * 100, f"{n_t*100:.0f}%", f"{v_t*100:.0f}%"),
    ]

    fig, axes = plt.subplots(1, 4, figsize=(15, 4.5))
    for ax, (label, nv, vv, ntxt, vtxt) in zip(axes, metrics):
        bars = ax.bar(["Naive", "VERGIL"], [nv, vv],
                      color=[NEUTRAL, BRAND], width=0.55,
                      edgecolor="#0a0e1a", linewidth=1.5)
        ax.set_title(label)
        ax.bar_label(bars, labels=[ntxt, vtxt], padding=4,
                     color="#e2e8f0", fontweight="bold", fontsize=10)
        # Always include 0 so negative rewards and all-zero bars stay visible,
        # and pad by at least 1 unit so flat data still gets headroom.
        ymax = max(nv, vv, 0)
        ymin = min(nv, vv, 0)
        pad = 0.18 * max(abs(ymax - ymin), 1)
        ax.set_ylim(ymin - pad, ymax + pad)

    fig.suptitle("Naive baseline vs VERGIL-trained agent", color="#e2e8f0",
                 fontweight="bold", fontsize=15, y=1.02)
    fig.tight_layout()
    _save_figure(fig, out_path)
    return True


# ──────────────────────────────────────────────────────────────────────────
def main(argv: Optional[List[str]] = None) -> int:
    """CLI entry point: render every plot whose inputs are usable; returns 0.

    *argv* defaults to ``sys.argv[1:]`` (argparse behaviour) and exists so the
    CLI can be driven programmatically.
    """
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--trainer-state", default="/tmp/vergil_grpo_output/trainer_state.json")
    p.add_argument("--validation-log", default="/tmp/vergil_grpo_output/validation_log.json")
    p.add_argument("--results", default="training_results/rl_training_results.json",
                   help="Naive vs trained comparison JSON (any of: rl_training_results.json, training_results.json)")
    p.add_argument("--out-dir", default="docs/plots")
    args = p.parse_args(argv)

    out = Path(args.out_dir)

    plot_training_curve(
        trainer_state_path=Path(args.trainer_state),
        validation_log_path=Path(args.validation_log),
        out_path=out / "training_curve.png",
    )

    # Try the user-specified results, then fall back through alternate names.
    # dict.fromkeys drops duplicates (the default --results is also a
    # fallback) while preserving priority order.
    # NOTE(review): the module docstring mentions a validation_log.json
    # baseline_* fallback for comparison.png; that is not implemented here.
    candidates = list(dict.fromkeys([
        Path(args.results),
        Path("training_results/rl_training_results.json"),
        Path("training_results/training_results.json"),
    ]))
    existing = [c for c in candidates if c.exists()]
    if not existing:
        print(f"[plots] skipped comparison.png — no results JSON found in {candidates}")
    # An existing-but-invalid file must not hide a valid alternate.
    for c in existing:
        if plot_comparison(c, out / "comparison.png"):
            break

    return 0


if __name__ == "__main__":
    sys.exit(main())