vergil-training / scripts /generate_plots.py
Laksh718
feat(submission): OpenEnv shim + plot pipeline + demo Space deploy + docs
ce44f4b
#!/usr/bin/env python3
"""
generate_plots.py — Build the two PNGs needed for hackathon submission
=======================================================================
After a successful GRPO run on the Hugging Face Space we have:
/tmp/vergil_grpo_output/trainer_state.json (written by TRL)
/tmp/vergil_grpo_output/validation_log.json (written by us)
This script renders:
docs/plots/training_curve.png
- reward / loss / KL across optimisation steps
docs/plots/comparison.png
- bar chart: naive vs trained-VERGIL on the eval suite
(sourced from the --results JSON, default
training_results/rl_training_results.json, falling back to
training_results/training_results.json)
Usage
-----
python scripts/generate_plots.py # use defaults
python scripts/generate_plots.py \
--trainer-state /tmp/vergil_grpo_output/trainer_state.json \
--validation-log /tmp/vergil_grpo_output/validation_log.json \
--out-dir docs/plots
The script never raises on missing inputs — it just prints what it
*could* render and exits 0, so it's safe to call from CI or the Space's
post-train hook.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# Use a non-interactive backend BEFORE importing pyplot so this works
# headlessly inside the HF Space container.
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt # noqa: E402
# ── Palette shared by every chart ─────────────────────────────────────────
BRAND, BRAND2 = "#8b5cf6", "#22d3ee"   # purple (vergil), cyan (accent)
GOOD = "#34d399"      # green  — rewards / completed
BAD = "#fb7185"       # rose   — failures
NEUTRAL = "#64748b"   # slate  — loss / auxiliary series

# ── Style: dark theme that matches the demo UI ────────────────────────────
plt.rcParams.update({
    # Canvas and axes chrome
    "figure.facecolor": "#0a0e1a",
    "axes.facecolor": "#0f1421",
    "axes.edgecolor": "#2d3f58",
    "axes.labelcolor": "#cbd5e1",
    "axes.titlecolor": "#e2e8f0",
    # Ticks and background grid
    "xtick.color": "#94a3b8",
    "ytick.color": "#94a3b8",
    "axes.grid": True,
    "grid.color": "#1e293b",
    "grid.alpha": 0.4,
    # Typography
    "font.family": "sans-serif",
    "font.size": 11,
    "axes.titlesize": 14,
    "axes.titleweight": "bold",
    # Output defaults applied by every fig.savefig() call
    "savefig.dpi": 150,
    "savefig.bbox": "tight",
})
# ──────────────────────────────────────────────────────────────────────────
def _load_json(path: Path) -> Optional[Any]:
if not path.exists():
return None
try:
return json.loads(path.read_text())
except Exception as e:
print(f"[plots] failed to read {path}: {e}", file=sys.stderr)
return None
# ──────────────────────────────────────────────────────────────────────────
def _paired_series(entries: List[Any], key: str) -> Tuple[List[Any], List[Any]]:
    """
    Return ``(steps, values)`` for dict entries that have both ``step`` and *key* non-null.

    Collecting both in one pass keeps the x and y lists the same length even
    when some log entries lack a step or carry ``null`` values. Non-dict
    entries are ignored.
    """
    steps: List[Any] = []
    values: List[Any] = []
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        step, value = entry.get("step"), entry.get(key)
        if step is None or value is None:
            continue
        steps.append(step)
        values.append(value)
    return steps, values


def _render_trainer_state_curves(log: List[Any],
                                 val_rows: List[Any],
                                 out_path: Path) -> bool:
    """
    Draw three side-by-side panels (reward / loss / KL) from TRL's ``log_history``.

    Returns False, writing nothing, when none of the three series has data.
    """
    steps_r, rewards = _paired_series(log, "reward")
    if not rewards:
        # Reward isn't logged separately, so fall back to the validation log.
        steps_r, rewards = _paired_series(val_rows, "mean_reward")
    steps_l, losses = _paired_series(log, "loss")
    steps_k, kls = _paired_series(log, "kl")
    if not (rewards or losses or kls):
        return False

    fig, (ax_r, ax_l, ax_k) = plt.subplots(1, 3, figsize=(15, 4.5))
    try:
        if rewards:
            ax_r.plot(steps_r, rewards, color=GOOD, lw=2.2, marker="o",
                      markersize=4, markerfacecolor=GOOD, markeredgecolor="#0a0e1a")
            # Shade between the curve and its minimum to emphasise the trend.
            ax_r.fill_between(steps_r, rewards, min(rewards), color=GOOD, alpha=0.12)
        ax_r.set_title("Reward (group-relative mean)")
        ax_r.set_xlabel("optimization step")
        ax_r.set_ylabel("reward")

        if losses:
            ax_l.plot(steps_l, losses, color=NEUTRAL, lw=2, marker="s",
                      markersize=3.5)
        ax_l.set_title("Policy loss")
        ax_l.set_xlabel("optimization step")
        ax_l.set_ylabel("loss")

        if kls:
            ax_k.plot(steps_k, kls, color=BRAND2, lw=2, marker="^",
                      markersize=3.5)
        ax_k.set_title("KL(π‖π_ref)")
        ax_k.set_xlabel("optimization step")
        ax_k.set_ylabel("KL")

        fig.suptitle("VERGIL — GRPO training curves",
                     color="#e2e8f0", fontweight="bold", fontsize=15, y=1.02)
        fig.tight_layout()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)
    print(f"[plots] wrote {out_path}")
    return True


def _render_validation_curves(val_rows: List[Any], out_path: Path) -> bool:
    """
    Draw mean reward and fulfillment on twin y-axes from validation_log.json rows.

    A row without ``step`` uses its list index; missing metrics default to
    0.0. Returns False, writing nothing, when no row is a dict.
    """
    indexed = [(i, row) for i, row in enumerate(val_rows) if isinstance(row, dict)]
    if not indexed:
        return False
    steps = [row.get("step", i) for i, row in indexed]
    rewards = [row.get("mean_reward", 0.0) for _, row in indexed]
    fulfill = [row.get("mean_fulfillment", 0.0) for _, row in indexed]

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(steps, rewards, color=GOOD, lw=2.2, marker="o", label="mean reward")
        ax.set_xlabel("validation step")
        ax.set_ylabel("mean reward", color=GOOD)
        ax.tick_params(axis="y", labelcolor=GOOD)
        # Fulfillment is a rate on a different scale, so it gets its own y-axis.
        ax2 = ax.twinx()
        ax2.plot(steps, fulfill, color=BRAND2, lw=2.2, marker="s", label="fulfillment")
        ax2.set_ylabel("fulfillment rate", color=BRAND2)
        ax2.tick_params(axis="y", labelcolor=BRAND2)
        ax.set_title("VERGIL — validation reward + fulfillment over training")
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path)
    finally:
        plt.close(fig)
    print(f"[plots] wrote {out_path} (fallback)")
    return True


def plot_training_curve(trainer_state_path: Path,
                        validation_log_path: Path,
                        out_path: Path) -> bool:
    """
    Render the optimisation-step curves: reward, loss, kl.

    The primary source is ``log_history`` in trainer_state.json. If that
    log has no reward series, reward is taken from the validation log's
    ``mean_reward``. When trainer_state.json is absent, malformed, or has
    nothing plottable, the plot falls back to validation_log.json
    (mean_reward / mean_fulfillment), e.g. for baseline-only runs.

    Args:
        trainer_state_path: path to the trainer's trainer_state.json.
        validation_log_path: path to validation_log.json.
        out_path: PNG destination; parent directories are created.

    Returns:
        True if *out_path* was written, False if neither input had data.
    """
    state = _load_json(trainer_state_path)
    val = _load_json(validation_log_path)
    # The validation log is consumed as a list of per-step dicts; any other
    # JSON shape is treated as absent rather than crashing.
    val_rows: List[Any] = val if isinstance(val, list) else []

    log = state.get("log_history") if isinstance(state, dict) else None
    if isinstance(log, list) and log and _render_trainer_state_curves(log, val_rows, out_path):
        return True

    if val_rows and _render_validation_curves(val_rows, out_path):
        return True

    print(f"[plots] skipped {out_path} — no trainer_state.json or validation_log.json")
    return False
# ──────────────────────────────────────────────────────────────────────────
def _pick_metrics(d: Dict[str, Any]) -> Tuple[float, float, int, float]:
"""Return (reward, fulfillment, n_failed, avg_trust) from a metrics dict."""
return (
float(d.get("total_reward", d.get("mean_reward", 0.0))),
float(d.get("fulfillment_rate", d.get("mean_fulfillment", 0.0))),
int(d.get("n_failed", 0)),
float(d.get("avg_trust", d.get("trust_avg", 0.0))),
)
def plot_comparison(results_path: Path,
                    out_path: Path) -> bool:
    """
    Bar chart: baseline (naive) vs trained-VERGIL across our 4 KPIs.

    *results_path* must contain a JSON object with a baseline metrics dict
    under ``naive`` or ``baseline``. It must also have a trained metrics dict
    under ``vergil``, ``trained`` or ``rl``. Each metrics dict is read with
    ``_pick_metrics``.

    Args:
        results_path: comparison JSON to read.
        out_path: PNG destination; parent directories are created.

    Returns:
        True if *out_path* was written. False (after printing why) if the
        JSON is missing, malformed, or lacks either side of the comparison.
    """
    data = _load_json(results_path)
    if not isinstance(data, dict):
        print(f"[plots] skipped {out_path} — {results_path} missing/invalid")
        return False

    naive = data.get("naive") or data.get("baseline")
    trained = data.get("vergil") or data.get("trained") or data.get("rl")
    if not (isinstance(naive, dict) and isinstance(trained, dict)):
        print(f"[plots] skipped {out_path} — need 'naive' + 'vergil' keys in results")
        return False

    n_r, n_f, n_fl, n_t = _pick_metrics(naive)
    v_r, v_f, v_fl, v_t = _pick_metrics(trained)
    # (panel title, naive value, vergil value, naive label, vergil label).
    # Rates are plotted as percentages.
    metrics = [
        ("Reward", n_r, v_r, f"{n_r:+.2f}", f"{v_r:+.2f}"),
        ("Fulfillment %", n_f * 100, v_f * 100, f"{n_f*100:.0f}%", f"{v_f*100:.0f}%"),
        ("Failures", n_fl, v_fl, f"{n_fl}", f"{v_fl}"),
        ("Trust %", n_t * 100, v_t * 100, f"{n_t*100:.0f}%", f"{v_t*100:.0f}%"),
    ]

    fig, axes = plt.subplots(1, 4, figsize=(15, 4.5))
    try:
        for ax, (label, nv, vv, ntxt, vtxt) in zip(axes, metrics):
            bars = ax.bar(["Naive", "VERGIL"], [nv, vv],
                          color=[NEUTRAL, BRAND], width=0.55, edgecolor="#0a0e1a", linewidth=1.5)
            ax.set_title(label)
            ax.bar_label(bars, labels=[ntxt, vtxt], padding=4,
                         color="#e2e8f0", fontweight="bold", fontsize=10)
            # Always include 0 in the range and pad both ends by at least
            # 0.18, leaving room for the labels even when both bars are 0.
            ymax = max(nv, vv, 0)
            ymin = min(nv, vv, 0)
            pad = 0.18 * max(abs(ymax - ymin), 1)
            ax.set_ylim(ymin - pad, ymax + pad)
        fig.suptitle("Naive baseline vs VERGIL-trained agent",
                     color="#e2e8f0", fontweight="bold", fontsize=15, y=1.02)
        fig.tight_layout()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)
    print(f"[plots] wrote {out_path}")
    return True
# ──────────────────────────────────────────────────────────────────────────
def main() -> int:
    """
    CLI entry point: render every plot the available inputs allow.

    Always returns 0. Missing or unusable inputs are reported on stdout and
    skipped, so the script is safe to call from CI or a post-train hook.
    """
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--trainer-state",
                   default="/tmp/vergil_grpo_output/trainer_state.json")
    p.add_argument("--validation-log",
                   default="/tmp/vergil_grpo_output/validation_log.json")
    p.add_argument("--results",
                   default="training_results/rl_training_results.json",
                   help="Naive vs trained comparison JSON (any of: rl_training_results.json, training_results.json)")
    p.add_argument("--out-dir", default="docs/plots")
    args = p.parse_args()
    out = Path(args.out_dir)

    plot_training_curve(
        trainer_state_path=Path(args.trainer_state),
        validation_log_path=Path(args.validation_log),
        out_path=out / "training_curve.png",
    )

    # Try the user-specified results first, then the alternate names.
    # dict.fromkeys removes duplicates (the default --results value is also
    # the first fallback) while preserving order.
    candidates = list(dict.fromkeys([
        Path(args.results),
        Path("training_results/rl_training_results.json"),
        Path("training_results/training_results.json"),
    ]))
    existing = [c for c in candidates if c.exists()]
    if not existing:
        searched = ", ".join(str(c) for c in candidates)
        print(f"[plots] skipped comparison.png — no results JSON found in [{searched}]")
    else:
        # If a file exists but lacks usable naive/trained metrics,
        # still try the remaining candidates.
        for c in existing:
            if plot_comparison(c, out / "comparison.png"):
                break
    return 0


if __name__ == "__main__":
    sys.exit(main())