import json import matplotlib.pyplot as plt from pathlib import Path state_path = "/teamspace/studios/this_studio/wedding-model/checkpoint-300/trainer_state.json" with open(state_path) as f: state = json.load(f) history = state.get("log_history", []) steps = [] rewards = [] for entry in history: if "reward" in entry and "step" in entry: steps.append(entry["step"]) rewards.append(entry["reward"]) plt.figure(figsize=(10,4)) plt.plot(steps, rewards, marker='o', color='b') plt.xlabel("Training step") plt.ylabel("Episode reward") plt.title("Wedding Planner Agent — reward over training") plt.grid(True) out_dir = Path("/teamspace/studios/this_studio/wedding-planner-env/assets") out_dir.mkdir(exist_ok=True) out_path = out_dir / "reward_curve.png" plt.savefig(out_path, dpi=150) print(f"Plot saved to {out_path}")