Spaces:
Sleeping
Sleeping
"""Plot training curves from runs/metrics.jsonl.

Outputs:
    runs/reward_curve.png - reward per episode + rolling mean
    runs/loss_curve.png   - 1.0 - grader_score (proxy "loss") per episode

Usage:
    python plot_metrics.py [--input runs/metrics.jsonl] [--out runs/] [--window 10]
"""
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| def _rolling(xs, w): | |
| out, acc = [], [] | |
| for x in xs: | |
| acc.append(x) | |
| if len(acc) > w: | |
| acc.pop(0) | |
| out.append(sum(acc) / len(acc)) | |
| return out | |
def _load_rows(src):
    """Parse the JSON Lines file ``src`` into a list, skipping blank lines.

    The file is iterated line by line rather than split with
    ``str.splitlines()``, which also breaks on characters such as U+2028
    that may appear unescaped inside a JSON string and would cut a valid
    record in half.

    Raises:
        SystemExit: If the file cannot be read, or a non-blank line is not
            valid JSON (the message includes the 1-based line number).
    """
    rows = []
    try:
        with src.open(encoding="utf-8") as fh:
            for lineno, line in enumerate(fh, start=1):
                if not line.strip():
                    continue
                try:
                    rows.append(json.loads(line))
                except json.JSONDecodeError as err:
                    raise SystemExit(f"{src}:{lineno}: invalid JSON: {err}") from err
    except OSError as err:
        raise SystemExit(f"Cannot read {src}: {err}") from err
    return rows


def _proxy_loss(row):
    """Return the proxy loss ``max(0.0, 1.0 - score)`` for one metrics row.

    ``score`` is the row's ``grader_score`` when present. Otherwise it is
    derived from ``recall`` and ``precision`` (each defaulting to 0.0) with
    a two-tier formula.
    """
    score = row.get("grader_score")
    if score is None:
        recall = float(row.get("recall", 0.0))
        precision = float(row.get("precision", 0.0))
        # NOTE(review): these weights presumably mirror the grader's own
        # scoring formula -- confirm against the grader implementation.
        if recall >= 0.8 and precision >= 0.7:
            score = 0.55 + 0.20 * recall + 0.15 * precision
        else:
            score = 0.30 * recall + 0.10 * precision
    # Clamp at zero so a score above 1.0 never yields a negative loss.
    return max(0.0, 1.0 - float(score))


def main() -> None:
    """Read the metrics file and write the reward and proxy-loss plots.

    Command-line options:
        --input   JSON Lines metrics file (default ``runs/metrics.jsonl``).
        --out     Output directory, created if missing (default ``runs``).
        --window  Rolling-mean window size, at least 1 (default 10).

    Writes ``reward_curve.png`` and ``loss_curve.png`` into the output
    directory and prints both paths.

    Raises:
        SystemExit: On invalid arguments, an unreadable or malformed input
            file, or an input file with no rows.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--input", default="runs/metrics.jsonl")
    p.add_argument("--out", default="runs")
    p.add_argument("--window", type=int, default=10)
    args = p.parse_args()
    if args.window < 1:
        # A non-positive window leaves nothing to average and would fail
        # later with a ZeroDivisionError; reject it with a usage message.
        p.error(f"--window must be >= 1 (got {args.window})")

    src = Path(args.input)
    out = Path(args.out)
    out.mkdir(parents=True, exist_ok=True)

    rows = _load_rows(src)
    if not rows:
        raise SystemExit(f"No rows in {src}")

    # Rows without an explicit "episode" fall back to their 1-based position.
    eps = [r.get("episode", i + 1) for i, r in enumerate(rows)]
    rewards = [float(r.get("reward", 0.0)) for r in rows]
    rolling_label = f"rolling mean (w={args.window})"

    plt.figure(figsize=(8, 4.5))
    plt.plot(eps, rewards, alpha=0.4, label="reward")
    plt.plot(eps, _rolling(rewards, args.window), linewidth=2, label=rolling_label)
    plt.axhline(0, color="gray", linewidth=0.5)
    plt.xlabel("episode")
    plt.ylabel("reward")
    plt.title("Training reward")
    plt.legend()
    plt.savefig(out / "reward_curve.png", dpi=120)
    plt.close()

    # Proxy loss = 1 - grader score; if missing, derived from recall + precision.
    losses = [_proxy_loss(r) for r in rows]

    plt.figure(figsize=(8, 4.5))
    plt.plot(eps, losses, alpha=0.4, label="loss (1 - grader)")
    plt.plot(eps, _rolling(losses, args.window), linewidth=2, label=rolling_label)
    plt.xlabel("episode")
    plt.ylabel("loss")
    plt.title("Training loss proxy")
    plt.legend()
    plt.savefig(out / "loss_curve.png", dpi=120)
    plt.close()

    print(f"wrote {out / 'reward_curve.png'}")
    print(f"wrote {out / 'loss_curve.png'}")
# Run the plotting CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()