vision-coder-openenv / scripts /generate_training_curve.py
amaljoe88's picture
deploy: sync 247eae2b from GitHub Actions
58191e8 verified
"""Generate training_curve.png from assets/train.jsonl."""
import json
import sys
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Repository root: this script lives one directory below it.
ROOT = Path(__file__).parent.parent
# Input: JSON Lines training log, one record per line.
DATA = ROOT.joinpath("assets", "train.jsonl")
# Output: rendered figure, written next to the input log.
OUT = ROOT.joinpath("assets", "training_curve.png")
def smooth(vals, w=3):
    """Return a centred moving average of ``vals``.

    Each output element is the mean of the input elements that lie at most
    ``w`` positions to either side of it. Windows are truncated at both ends
    of the sequence, so the result has the same length as the input.

    Args:
        vals: Sequence of numbers supporting slicing and ``len``.
        w: Half-width of the averaging window (full window is ``2 * w + 1``).

    Returns:
        A new list with one averaged value per input element.
    """

    def window_mean(centre):
        # max(0, ...) keeps the left edge from wrapping to a negative index;
        # the right edge is clipped by slicing itself.
        window = vals[max(0, centre - w):centre + w + 1]
        return sum(window) / len(window)

    return [window_mean(centre) for centre in range(len(vals))]
def main():
    """Plot reward and loss curves from the training log and save them as a PNG.

    Reads ``DATA`` as JSON Lines, skipping blank lines. Every record must
    contain ``iter``, ``easy``, ``medium``, ``hard`` and ``mean``; ``loss`` is
    optional, and records where it is missing or null are left out of the
    loss panel. The two-panel figure is written to ``OUT``.

    Raises:
        SystemExit: If ``DATA`` contains no records.
        FileNotFoundError: If ``DATA`` does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the required keys.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    text = DATA.read_text(encoding="utf-8")
    rows = [json.loads(line) for line in text.splitlines() if line.strip()]
    if not rows:
        # Without this guard max(iters) below fails with an opaque ValueError.
        sys.exit(f"No training records found in {DATA}")

    iters = [r["iter"] for r in rows]
    easy = [r["easy"] for r in rows]
    medium = [r["medium"] for r in rows]
    hard = [r["hard"] for r in rows]
    mean = [r["mean"] for r in rows]
    loss = [r.get("loss") for r in rows]
    last_iter = max(iters)

    colors = {"easy": "#3b82f6", "medium": "#22c55e", "hard": "#ef4444", "mean": "#facc15"}
    # Both legends share the same appearance.
    legend_style = {
        "framealpha": 0.2,
        "labelcolor": "white",
        "facecolor": "#161b22",
        "edgecolor": "#30363d",
    }

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 7), facecolor="#0d1117")
    fig.subplots_adjust(hspace=0.35)

    # Apply the same dark colour scheme to both panels.
    for ax in (ax1, ax2):
        ax.set_facecolor("#161b22")
        for spine in ax.spines.values():
            spine.set_edgecolor("#30363d")
        ax.tick_params(colors="#8b949e")
        ax.xaxis.label.set_color("#8b949e")
        ax.yaxis.label.set_color("#8b949e")
        ax.title.set_color("#e6edf3")
        ax.grid(True, alpha=0.15, color="#30363d")

    # Top panel: raw per-difficulty rewards plus the smoothed mean.
    ax1.plot(iters, easy, color=colors["easy"], linewidth=1.4, label="Easy", alpha=0.85)
    ax1.plot(iters, medium, color=colors["medium"], linewidth=1.4, label="Medium", alpha=0.85)
    ax1.plot(iters, hard, color=colors["hard"], linewidth=1.4, label="Hard", alpha=0.85)
    ax1.plot(iters, smooth(mean), color=colors["mean"], linewidth=2.2, linestyle="--", label="Mean (smoothed)")
    ax1.set_xlabel("Episode")
    ax1.set_ylabel("Reward")
    ax1.set_title("Training Reward Progression")
    ax1.legend(**legend_style)
    ax1.set_xlim(0, last_iter)

    # Bottom panel: only the records that actually carry a loss value.
    loss_points = [(it, v) for it, v in zip(iters, loss) if v is not None]
    loss_iters = [it for it, _ in loss_points]
    loss_vals = [v for _, v in loss_points]
    ax2.plot(loss_iters, loss_vals, color="#a78bfa", linewidth=1.4, marker="o", markersize=3, label="GRPO loss")
    ax2.set_xlabel("Episode")
    ax2.set_ylabel("Loss")
    ax2.set_title("Training Loss (GRPO)")
    ax2.legend(**legend_style)
    ax2.set_xlim(0, last_iter)

    fig.savefig(OUT, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close(fig)
    print(f"Saved {OUT}")


# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()