Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| """ | |
| generate_plots.py — Build the two PNGs needed for hackathon submission | |
| ======================================================================= | |
| After a successful GRPO run on the Hugging Face Space we have: | |
| /tmp/vergil_grpo_output/trainer_state.json (written by TRL) | |
| /tmp/vergil_grpo_output/validation_log.json (written by us) | |
| This script renders: | |
| docs/plots/training_curve.png | |
| - reward / loss / KL across optimisation steps | |
| docs/plots/comparison.png | |
| - bar chart: naive vs trained-VERGIL on the eval suite | |
| (sourced from --results, default training_results/rl_training_results.json, | |
| or, if absent, from training_results/training_results.json) | |
| Usage | |
| ----- | |
| python scripts/generate_plots.py # use defaults | |
| python scripts/generate_plots.py \ | |
| --trainer-state /tmp/vergil_grpo_output/trainer_state.json \ | |
| --validation-log /tmp/vergil_grpo_output/validation_log.json \ | |
| --out-dir docs/plots | |
| The script never raises on missing inputs — it just prints what it | |
| *could* render and exits 0, so it's safe to call from CI or the Space's | |
| post-train hook. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| # Use a non-interactive backend BEFORE importing pyplot so this works | |
| # headlessly inside the HF Space container. | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt # noqa: E402 | |
# ── Style: dark theme that matches the demo UI ────────────────────────────
# The rc settings are grouped by concern and merged once; the resulting
# key/value pairs are exactly the ones applied to matplotlib's rcParams.
_THEME_COLOURS = {
    "figure.facecolor": "#0a0e1a",
    "axes.facecolor": "#0f1421",
    "axes.edgecolor": "#2d3f58",
    "axes.labelcolor": "#cbd5e1",
    "axes.titlecolor": "#e2e8f0",
    "xtick.color": "#94a3b8",
    "ytick.color": "#94a3b8",
    "grid.color": "#1e293b",
}
_THEME_TYPOGRAPHY = {
    "font.family": "sans-serif",
    "font.size": 11,
    "axes.titlesize": 14,
    "axes.titleweight": "bold",
}
_THEME_GRID_AND_EXPORT = {
    "axes.grid": True,
    "grid.alpha": 0.4,
    "savefig.dpi": 150,
    "savefig.bbox": "tight",
}
plt.rcParams.update({
    **_THEME_COLOURS,
    **_THEME_TYPOGRAPHY,
    **_THEME_GRID_AND_EXPORT,
})

# Shared palette used by every figure in this script.
BRAND = "#8b5cf6"    # purple (vergil)
BRAND2 = "#22d3ee"   # cyan (accent)
GOOD = "#34d399"     # green (rewards / completed)
BAD = "#fb7185"      # rose (failures)
NEUTRAL = "#64748b"  # slate (loss / aux)
| # ────────────────────────────────────────────────────────────────────────── | |
def _load_json(path: Path) -> Optional[Any]:
    """Read *path* as UTF-8 JSON and return the decoded value.

    Args:
        path: Location of the JSON document.

    Returns:
        The decoded JSON value, or ``None`` when the file does not exist or
        cannot be read/decoded. A missing file is silent (it is an expected
        situation for this script); every other failure is reported on stderr.
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the host locale.
        return json.loads(path.read_text(encoding="utf-8"))
    except FileNotFoundError:
        return None
    except (OSError, ValueError, RecursionError) as e:
        # OSError: unreadable path (directory, permissions, ...).
        # ValueError: json.JSONDecodeError and UnicodeDecodeError.
        # RecursionError: pathologically nested documents.
        print(f"[plots] failed to read {path}: {e}", file=sys.stderr)
        return None
| # ────────────────────────────────────────────────────────────────────────── | |
def plot_training_curve(trainer_state_path: Path,
                        validation_log_path: Path,
                        out_path: Path) -> bool:
    """Render the optimisation-step curves (reward, loss, KL) to *out_path*.

    The curves come from the ``log_history`` list inside the trainer-state
    JSON. If that history carries no usable ``reward`` values, the reward
    panel is fed from the validation log's ``mean_reward`` instead. When the
    trainer state is missing or has no history at all, a two-axis fallback
    figure (mean reward + fulfillment) is drawn from the validation log alone.

    Args:
        trainer_state_path: JSON file expected to hold a ``log_history`` list.
        validation_log_path: JSON file expected to hold a list of dicts with
            ``step`` / ``mean_reward`` / ``mean_fulfillment`` entries.
        out_path: Destination PNG; parent directories are created on demand.

    Returns:
        ``True`` if a figure was written, ``False`` if there was nothing to plot.
    """
    state = _load_json(trainer_state_path)
    val = _load_json(validation_log_path)

    # Guard the shape explicitly: a JSON file may legally hold a list/scalar.
    log = state.get("log_history") if isinstance(state, dict) else None
    if isinstance(log, list) and log:
        steps_r, rewards = _paired_series(log, "reward")
        if not rewards:
            # Reward not logged by the trainer: use the validation log instead.
            steps_r, rewards = _paired_series(val, "mean_reward")
        steps_l, losses = _paired_series(log, "loss")
        steps_k, kls = _paired_series(log, "kl")

        fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))

        ax = axes[0]
        if rewards:
            ax.plot(steps_r, rewards, color=GOOD, lw=2.2, marker="o",
                    markersize=4, markerfacecolor=GOOD, markeredgecolor="#0a0e1a")
            ax.fill_between(steps_r, rewards, min(rewards), color=GOOD, alpha=0.12)
        ax.set_title("Reward (group-relative mean)")
        ax.set_xlabel("optimization step")
        ax.set_ylabel("reward")

        ax = axes[1]
        if losses:
            ax.plot(steps_l, losses, color=NEUTRAL, lw=2, marker="s",
                    markersize=3.5)
        ax.set_title("Policy loss")
        ax.set_xlabel("optimization step")
        ax.set_ylabel("loss")

        ax = axes[2]
        if kls:
            ax.plot(steps_k, kls, color=BRAND2, lw=2, marker="^",
                    markersize=3.5)
        ax.set_title("KL(π‖π_ref)")
        ax.set_xlabel("optimization step")
        ax.set_ylabel("KL")

        fig.suptitle("VERGIL — GRPO training curves",
                     color="#e2e8f0", fontweight="bold", fontsize=15, y=1.02)
        fig.tight_layout()
        _save_and_close(fig, out_path)
        print(f"[plots] wrote {out_path}")
        return True

    # Fallback: validation log only
    rows = [v for v in val if isinstance(v, dict)] if isinstance(val, list) else []
    if rows:
        # A missing or non-numeric step falls back to the row's position.
        steps = [_number_or(v.get("step"), i) for i, v in enumerate(rows)]
        rewards = [_number_or(v.get("mean_reward"), 0.0) for v in rows]
        fulfill = [_number_or(v.get("mean_fulfillment"), 0.0) for v in rows]

        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(steps, rewards, color=GOOD, lw=2.2, marker="o", label="mean reward")
        ax.set_xlabel("validation step")
        ax.set_ylabel("mean reward", color=GOOD)
        ax.tick_params(axis="y", labelcolor=GOOD)
        ax2 = ax.twinx()
        ax2.plot(steps, fulfill, color=BRAND2, lw=2.2, marker="s", label="fulfillment")
        ax2.set_ylabel("fulfillment rate", color=BRAND2)
        ax2.tick_params(axis="y", labelcolor=BRAND2)
        ax.set_title("VERGIL — validation reward + fulfillment over training")
        _save_and_close(fig, out_path)
        print(f"[plots] wrote {out_path} (fallback)")
        return True

    print(f"[plots] skipped {out_path} — no trainer_state.json or validation_log.json")
    return False


def _is_number(value: Any) -> bool:
    """Return True for int/float values (bool is deliberately excluded)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _number_or(value: Any, default: float) -> float:
    """Return *value* when it is numeric, otherwise *default*."""
    return value if _is_number(value) else default


def _paired_series(entries: Any, key: str) -> Tuple[List[float], List[float]]:
    """Return aligned ``(steps, values)`` lists for *key*.

    Only dict entries that carry BOTH a numeric ``step`` and a numeric *key*
    contribute, so the two lists always have equal length and can be passed
    straight to ``ax.plot``. Anything that is not a list yields two empty lists.
    """
    steps: List[float] = []
    values: List[float] = []
    if not isinstance(entries, list):
        return steps, values
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        step, value = entry.get("step"), entry.get(key)
        if _is_number(step) and _is_number(value):
            steps.append(step)
            values.append(value)
    return steps, values


def _save_and_close(fig: Any, out_path: Path) -> None:
    """Write *fig* to *out_path*, always releasing the figure afterwards."""
    try:
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path)
    finally:
        # Close even when saving fails so repeated calls do not pile up figures.
        plt.close(fig)
| # ────────────────────────────────────────────────────────────────────────── | |
def _pick_metrics(d: Dict[str, Any]) -> Tuple[float, float, int, float]:
    """Return (reward, fulfillment, n_failed, avg_trust) from a metrics dict.

    Each metric is looked up under its primary key first and then under its
    alias (``total_reward``/``mean_reward``, ``fulfillment_rate``/
    ``mean_fulfillment``, ``avg_trust``/``trust_avg``). A key whose value is
    missing, ``null``, non-numeric or non-finite is skipped, so the alias (or
    finally the default of 0) is used instead of raising.
    """
    return (
        _first_number(d, ("total_reward", "mean_reward")),
        _first_number(d, ("fulfillment_rate", "mean_fulfillment")),
        int(_first_number(d, ("n_failed",))),
        _first_number(d, ("avg_trust", "trust_avg")),
    )


def _first_number(d: Dict[str, Any], keys: Tuple[str, ...],
                  default: float = 0.0) -> float:
    """Return the first finite numeric value found in *d* under *keys*."""
    for key in keys:
        value = d.get(key)
        if value is None:
            continue
        try:
            number = float(value)
        except (TypeError, ValueError):
            continue
        # NaN compares False and +/-inf fails the bound, leaving finite values only.
        if abs(number) < float("inf"):
            return number
    return default
def plot_comparison(results_path: Path,
                    out_path: Path) -> bool:
    """
    Draw a four-panel bar chart comparing the naive baseline with the
    trained agent (reward, fulfillment %, failures, trust %).

    The results JSON must be an object holding the baseline under "naive"
    or "baseline" and the trained agent under "vergil", "trained" or "rl".
    Returns True when the PNG was written, False when the input was
    missing or malformed.
    """
    payload = _load_json(results_path)
    if not isinstance(payload, dict):
        print(f"[plots] skipped {out_path} — {results_path} missing/invalid")
        return False

    baseline = payload.get("naive") or payload.get("baseline")
    agent = payload.get("vergil") or payload.get("trained") or payload.get("rl")
    if not isinstance(baseline, dict) or not isinstance(agent, dict):
        print(f"[plots] skipped {out_path} — need 'naive' + 'vergil' keys in results")
        return False

    n_r, n_f, n_fl, n_t = _pick_metrics(baseline)
    v_r, v_f, v_fl, v_t = _pick_metrics(agent)

    # One entry per panel: (title, bar heights, text drawn above each bar).
    panels = [
        ("Reward", (n_r, v_r), (f"{n_r:+.2f}", f"{v_r:+.2f}")),
        ("Fulfillment %", (n_f * 100, v_f * 100),
         (f"{n_f*100:.0f}%", f"{v_f*100:.0f}%")),
        ("Failures", (n_fl, v_fl), (f"{n_fl}", f"{v_fl}")),
        ("Trust %", (n_t * 100, v_t * 100),
         (f"{n_t*100:.0f}%", f"{v_t*100:.0f}%")),
    ]

    fig, axes = plt.subplots(1, 4, figsize=(15, 4.5))
    for panel_ax, (title, heights, captions) in zip(axes, panels):
        bars = panel_ax.bar(["Naive", "VERGIL"], list(heights),
                            color=[NEUTRAL, BRAND], width=0.55,
                            edgecolor="#0a0e1a", linewidth=1.5)
        panel_ax.set_title(title)
        panel_ax.bar_label(bars, labels=list(captions), padding=4,
                           color="#e2e8f0", fontweight="bold", fontsize=10)
        # Always include zero in the range, then add headroom for the labels.
        upper = max(heights[0], heights[1], 0)
        lower = min(heights[0], heights[1], 0)
        margin = 0.18 * max(abs(upper - lower), 1)
        panel_ax.set_ylim(lower - margin, upper + margin)

    fig.suptitle("Naive baseline vs VERGIL-trained agent",
                 color="#e2e8f0", fontweight="bold", fontsize=15, y=1.02)
    fig.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)
    plt.close(fig)
    print(f"[plots] wrote {out_path}")
    return True
| # ────────────────────────────────────────────────────────────────────────── | |
def main(argv: Optional[List[str]] = None) -> int:
    """Command-line entry point: render both plots and return exit code 0.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing ``main()`` callers are unaffected.

    Returns:
        Always 0. Plots whose inputs are missing are skipped with a message
        rather than treated as errors.
    """
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--trainer-state",
                   default="/tmp/vergil_grpo_output/trainer_state.json")
    p.add_argument("--validation-log",
                   default="/tmp/vergil_grpo_output/validation_log.json")
    p.add_argument("--results",
                   default="training_results/rl_training_results.json",
                   help="Naive vs trained comparison JSON (any of: rl_training_results.json, training_results.json)")
    p.add_argument("--out-dir", default="docs/plots")
    args = p.parse_args(argv)

    out = Path(args.out_dir)
    plot_training_curve(
        trainer_state_path=Path(args.trainer_state),
        validation_log_path=Path(args.validation_log),
        out_path=out / "training_curve.png",
    )

    # Try the user-specified results, then fall back through alternate names.
    # dict.fromkeys drops duplicates (the default --results equals the first
    # alternate) while keeping priority order.
    candidates = list(dict.fromkeys([
        Path(args.results),
        Path("training_results/rl_training_results.json"),
        Path("training_results/training_results.json"),
    ]))
    existing = [c for c in candidates if c.exists()]
    if not existing:
        print(f"[plots] skipped comparison.png — no results JSON found in {candidates}")
    for c in existing:
        # Keep falling back when a candidate exists but is unusable;
        # plot_comparison reports the reason for each one it rejects.
        if plot_comparison(c, out / "comparison.png"):
            break
    return 0


if __name__ == "__main__":
    sys.exit(main())