File size: 2,588 Bytes
a6f0611
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""Plot training curves from runs/metrics.jsonl.

Outputs:
  runs/reward_curve.png  - reward per episode + rolling mean
  runs/loss_curve.png    - 1.0 - grader_score (proxy "loss") per episode

Usage:
  python plot_metrics.py [--input runs/metrics.jsonl] [--out runs/] [--window 10]
"""

from __future__ import annotations

import argparse
import json
from collections import deque
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt


def _rolling(xs, w):
    out, acc = [], []
    for x in xs:
        acc.append(x)
        if len(acc) > w:
            acc.pop(0)
        out.append(sum(acc) / len(acc))
    return out


def _load_rows(src):
    """Parse the JSON-lines file ``src`` into a list, one item per non-blank line.

    Args:
        src: ``Path`` of the metrics file to read.

    Returns:
        A non-empty list of decoded JSON values, in file order.

    Raises:
        SystemExit: If the file cannot be read or decoded, a line is not
            valid JSON, or the file contains no rows.
    """
    try:
        # Decode as UTF-8 explicitly rather than the platform default.
        text = src.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as err:
        raise SystemExit(f"Cannot read {src}: {err}") from err

    rows = []
    for lineno, line in enumerate(text.splitlines(), start=1):
        if not line.strip():
            continue
        try:
            rows.append(json.loads(line))
        except json.JSONDecodeError as err:
            # Report the offending line so a corrupt log is easy to locate.
            raise SystemExit(f"{src}:{lineno}: invalid JSON: {err}") from err

    if not rows:
        raise SystemExit(f"No rows in {src}")
    return rows


def _proxy_loss(row):
    """Return ``max(0.0, 1.0 - grader_score)`` for one metrics row.

    When ``grader_score`` is absent or null, a score is derived from the
    row's ``recall`` and ``precision`` (each defaulting to 0.0).
    """
    score = row.get("grader_score")
    if score is None:
        recall = float(row.get("recall", 0.0))
        precision = float(row.get("precision", 0.0))
        # NOTE(review): the weights below presumably mirror the grader's own
        # scoring rule; confirm against the grader before changing them.
        if recall >= 0.8 and precision >= 0.7:
            score = 0.55 + 0.20 * recall + 0.15 * precision
        else:
            score = 0.30 * recall + 0.10 * precision
    # Clamp at zero so a score above 1.0 never produces a negative loss.
    return max(0.0, 1.0 - float(score))


def _save_curve(eps, values, window, path, *, label, ylabel, title, zero_line=False):
    """Plot ``values`` against ``eps`` with their rolling mean and save to ``path``.

    Args:
        eps: X-axis values (episode identifiers), one per entry of ``values``.
        values: Per-episode series to plot.
        window: Rolling-mean window size.
        path: Destination image path.
        label: Legend label for the raw series.
        ylabel: Y-axis label.
        title: Figure title.
        zero_line: If true, draw a thin horizontal reference line at y=0.
    """
    plt.figure(figsize=(8, 4.5))
    plt.plot(eps, values, alpha=0.4, label=label)
    plt.plot(eps, _rolling(values, window), linewidth=2, label=f"rolling mean (w={window})")
    if zero_line:
        plt.axhline(0, color="gray", linewidth=0.5)
    plt.xlabel("episode")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.savefig(path, dpi=120)
    # Close the figure so it does not stay registered with pyplot.
    plt.close()


def main() -> None:
    """Read the metrics file and write the reward and loss curve images.

    Command-line options:
        --input: JSON-lines metrics file (default ``runs/metrics.jsonl``).
        --out: Output directory, created if missing (default ``runs``).
        --window: Rolling-mean window size, at least 1 (default 10).

    Raises:
        SystemExit: On invalid arguments or unreadable, malformed, or empty
            input.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--input", default="runs/metrics.jsonl")
    p.add_argument("--out", default="runs")
    p.add_argument("--window", type=int, default=10)
    args = p.parse_args()
    if args.window < 1:
        # Reject early; a non-positive window cannot be averaged over.
        p.error("--window must be >= 1")

    src = Path(args.input)
    out = Path(args.out)

    # Load before creating the output directory so bad input leaves no
    # empty directory behind.
    rows = _load_rows(src)
    out.mkdir(parents=True, exist_ok=True)

    # Rows without an explicit episode fall back to their 1-based position.
    eps = [r.get("episode", i + 1) for i, r in enumerate(rows)]
    rewards = [float(r.get("reward", 0.0)) for r in rows]
    losses = [_proxy_loss(r) for r in rows]

    reward_path = out / "reward_curve.png"
    loss_path = out / "loss_curve.png"

    _save_curve(
        eps,
        rewards,
        args.window,
        reward_path,
        label="reward",
        ylabel="reward",
        title="Training reward",
        zero_line=True,
    )
    _save_curve(
        eps,
        losses,
        args.window,
        loss_path,
        label="loss (1 - grader)",
        ylabel="loss",
        title="Training loss proxy",
    )

    print(f"wrote {reward_path}")
    print(f"wrote {loss_path}")


# Run the plotting CLI only when this file is executed as a script, so
# importing the module has no side effects beyond its definitions.
if __name__ == "__main__":
    main()