Spaces:
Running
Running
File size: 2,196 Bytes
2733f3f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 | """Plot training-curve PNGs from JSONL trajectories.
Usage:
python scripts/plot_curves.py train/data/eval_sweep.jsonl --output eval/results/training_curve.png
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import matplotlib.pyplot as plt
def main() -> None:
    """Parse CLI arguments, read a JSONL file, and write a line-plot PNG.

    Each non-blank line of the input must be a JSON object.

    For every row:
    - The x value is ``row[--x-field]``. If that field is missing, x falls
      back to the row's 0-based position among the non-blank rows.
    - The y value is ``row[--y-field]``. If that field is missing, y falls
      back to 0.

    With ``--group-by``, rows are split into one line per distinct value of
    that field, and a legend is drawn. Rows lacking the field go to the
    ``"default"`` group. Points in every line are sorted by x, so
    out-of-order rows do not make the curve double back on itself.

    If the input has no rows, a message is printed and nothing is written.
    If a line is not a valid JSON object, the program exits through
    ``argparse`` (status 2) with the file name and line number.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("input", type=str, help="JSONL file with one row per episode/step.")
    parser.add_argument("--output", type=str, default="eval/results/curve.png")
    parser.add_argument("--x-field", default="step", help="Field on the x-axis (default: step).")
    parser.add_argument("--y-field", default="final_score",
                        help="Field on the y-axis (default: final_score).")
    parser.add_argument("--group-by", default=None, help="Optional grouping field (e.g. policy).")
    parser.add_argument("--title", default=None)
    args = parser.parse_args()

    in_path = Path(args.input)
    rows = []
    # JSONL is UTF-8 by convention; do not depend on the platform locale.
    text = in_path.read_text(encoding="utf-8")
    for lineno, line in enumerate(text.splitlines(), start=1):
        if not line.strip():
            continue
        try:
            obj = json.loads(line)
        except json.JSONDecodeError as exc:
            parser.error(f"{in_path}:{lineno}: invalid JSON ({exc.msg})")
        if not isinstance(obj, dict):
            parser.error(f"{in_path}:{lineno}: expected a JSON object per line")
        rows.append(obj)
    if not rows:
        print("Empty input — nothing to plot.")
        return

    # Build {label: [(x, y), ...]}. Without --group-by there is a single
    # unlabeled series (label None). The x fallback is the row's position in
    # both modes, so grouped rows without an x no longer all collapse onto 0.
    series: dict[str | None, list[tuple[object, object]]] = {}
    for index, row in enumerate(rows):
        label = str(row.get(args.group_by, "default")) if args.group_by else None
        point = (row.get(args.x_field, index), row.get(args.y_field, 0))
        series.setdefault(label, []).append(point)

    fig, ax = plt.subplots(figsize=(9, 5))
    for label, points in series.items():
        # Sort on x only. Plain tuple ordering would compare y values when
        # x values tie, which raises TypeError for incomparable y's
        # (e.g. None vs float). The stable sort keeps file order on ties.
        points.sort(key=lambda point: point[0])
        xs = [x for x, _ in points]
        ys = [y for _, y in points]
        if label is None:
            ax.plot(xs, ys, marker=".", linewidth=1.5)
        else:
            ax.plot(xs, ys, label=label, marker=".", linewidth=1.5, alpha=0.8)
    if args.group_by:
        ax.legend()

    ax.set_xlabel(args.x_field)
    ax.set_ylabel(args.y_field)
    ax.set_title(args.title or f"{args.y_field} over {args.x_field}")
    ax.grid(alpha=0.3)

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(out_path, dpi=140, bbox_inches="tight")
    # Release the figure explicitly rather than relying on process exit.
    plt.close(fig)
    print(f"Saved -> {out_path}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing the module
    # (e.g. from tests) must not parse sys.argv or touch the filesystem.
    main()
|