# Source: wedding-planner-7b / plot_reward.py
# (uploaded by shivanandh033; commit 318eb16, message "fine")
import json
import matplotlib.pyplot as plt
from pathlib import Path
# Default locations used when the script is run without overrides.
STATE_PATH = "/teamspace/studios/this_studio/wedding-model/checkpoint-300/trainer_state.json"
OUT_DIR = Path("/teamspace/studios/this_studio/wedding-planner-env/assets")


def extract_reward_series(history):
    """Return ``(steps, rewards)`` lists taken from a log history.

    Only entries that contain both a ``"step"`` and a ``"reward"`` key are
    kept; every other entry is skipped. Order of the input is preserved.

    Args:
        history: Iterable of dict-like log entries.

    Returns:
        A tuple ``(steps, rewards)`` of two equally long lists.
    """
    points = [
        (entry["step"], entry["reward"])
        for entry in history
        if "reward" in entry and "step" in entry
    ]
    steps = [step for step, _ in points]
    rewards = [reward for _, reward in points]
    return steps, rewards


def prepare_output_path(out_dir, filename="reward_curve.png"):
    """Create *out_dir* (including missing parents) and return the file path.

    Args:
        out_dir: Directory the plot should be written to.
        filename: Name of the image file inside *out_dir*.

    Returns:
        ``Path`` of the image file inside the (now existing) directory.
    """
    out_dir = Path(out_dir)
    # parents=True: a missing intermediate directory must not abort the run.
    out_dir.mkdir(parents=True, exist_ok=True)
    return out_dir / filename


def plot_reward_curve(steps, rewards, out_path):
    """Draw reward versus training step and save the figure to *out_path*.

    Args:
        steps: X values (training steps).
        rewards: Y values (reward logged at each step).
        out_path: Destination of the saved image.
    """
    fig = plt.figure(figsize=(10, 4))
    try:
        plt.plot(steps, rewards, marker='o', color='b')
        plt.xlabel("Training step")
        plt.ylabel("Episode reward")
        plt.title("Wedding Planner Agent — reward over training")
        plt.grid(True)
        plt.savefig(out_path, dpi=150)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def main(state_path=STATE_PATH, out_dir=OUT_DIR):
    """Load the trainer state, plot its reward history and save the image.

    Args:
        state_path: Path to the JSON file holding a ``"log_history"`` list.
        out_dir: Directory that receives ``reward_curve.png``.

    Returns:
        ``Path`` of the saved plot.
    """
    with open(state_path, encoding="utf-8") as f:
        state = json.load(f)

    steps, rewards = extract_reward_series(state.get("log_history", []))
    if not steps:
        # Still write the (empty) plot as before, but make the cause visible.
        print(f"Warning: no reward entries found in {state_path}")

    out_path = prepare_output_path(out_dir)
    plot_reward_curve(steps, rewards, out_path)
    print(f"Plot saved to {out_path}")
    return out_path


if __name__ == "__main__":
    main()