Spaces:
Running
Running
| """Plot training-curve PNGs from JSONL trajectories. | |
| Usage: | |
| python scripts/plot_curves.py train/data/eval_sweep.jsonl --output eval/results/training_curve.png | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser (same arguments and defaults as before)."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's line breaks so the usage example stays readable.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("input", type=str, help="JSONL file with one row per episode/step.")
    parser.add_argument("--output", type=str, default="eval/results/curve.png")
    parser.add_argument("--x-field", default="step", help="Field on the x-axis (default: step).")
    parser.add_argument("--y-field", default="final_score",
                        help="Field on the y-axis (default: final_score).")
    parser.add_argument("--group-by", default=None, help="Optional grouping field (e.g. policy).")
    parser.add_argument("--title", default=None)
    return parser


def _load_rows(path: Path) -> list:
    """Parse a JSONL file into a list of decoded rows, skipping blank lines.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON; the message
            is prefixed with the file path and 1-based line number.
    """
    rows = []
    # Iterate physical lines instead of str.splitlines(): the latter also splits on
    # U+2028/U+2029, which may legally appear unescaped inside JSON strings.
    with path.open(encoding="utf-8") as fh:
        for lineno, line in enumerate(fh, start=1):
            if not line.strip():
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as err:
                raise json.JSONDecodeError(
                    f"{path}, line {lineno}: {err.msg}", err.doc, err.pos
                ) from err
    return rows


def _collect_series(
    rows: list,
    x_field: str,
    y_field: str,
    group_by: str | None,
) -> dict[str | None, tuple[list, list]]:
    """Split rows into plottable series, each sorted by x.

    Returns a mapping ``label -> (xs, ys)``. The label is ``str(row[group_by])``
    (``"default"`` when the row lacks that field) if ``group_by`` is truthy,
    otherwise every row goes into a single series labelled ``None``.

    Rows whose ``y_field`` is missing or null are skipped rather than plotted
    as 0. A missing or null ``x_field`` falls back to the row's 0-based position
    in the input, in both grouped and ungrouped mode.
    """
    buckets: dict[str | None, list[tuple]] = {}
    for index, row in enumerate(rows):
        y = row.get(y_field)
        if y is None:
            continue  # don't fabricate a point for rows that never logged y_field
        x = row.get(x_field)
        if x is None:
            x = index
        label = str(row.get(group_by, "default")) if group_by else None
        buckets.setdefault(label, []).append((x, y))

    series: dict[str | None, tuple[list, list]] = {}
    for label, pairs in buckets.items():
        # Sort on x only: ties keep file order and y values are never compared.
        pairs.sort(key=lambda pair: pair[0])
        series[label] = ([x for x, _ in pairs], [y for _, y in pairs])
    return series


def main(argv: list[str] | None = None) -> None:
    """Plot ``--y-field`` against ``--x-field`` from a JSONL file and save a PNG.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``; ``None`` (the
            default) keeps the normal command-line behavior.
    """
    args = _build_parser().parse_args(argv)

    rows = _load_rows(Path(args.input))
    if not rows:
        print("Empty input - nothing to plot.")
        return

    series = _collect_series(rows, args.x_field, args.y_field, args.group_by)
    if not series:
        print(f"No rows have a value for {args.y_field!r} - nothing to plot.")
        return

    out_path = Path(args.output)
    fig, ax = plt.subplots(figsize=(9, 5))
    try:
        for label, (xs, ys) in series.items():
            if args.group_by:
                ax.plot(xs, ys, label=label, marker=".", linewidth=1.5, alpha=0.8)
            else:
                ax.plot(xs, ys, marker=".", linewidth=1.5)
        if args.group_by:
            ax.legend()
        ax.set_xlabel(args.x_field)
        ax.set_ylabel(args.y_field)
        ax.set_title(args.title or f"{args.y_field} over {args.x_field}")
        ax.grid(alpha=0.3)

        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.tight_layout()
        fig.savefig(out_path, dpi=140, bbox_inches="tight")
    finally:
        # Release the figure even if saving fails (matters when main() is called repeatedly).
        plt.close(fig)
    print(f"Saved -> {out_path}")
if __name__ == "__main__":
    # Only run the CLI when executed as a script, not when imported.
    main()