Gridmind / scripts /plot_results.py
adityss's picture
fix: update training script with seed variation, fix reward normalization, regenerate training curves showing 0.52->0.67 improvement
bdc9954
#!/usr/bin/env python3
"""
GridMind-RL Training Curve Plotter
----------------------------------
Reads the training CSV generated by train_unsloth.py and creates a
beautiful PNG plot of the reward components to prove learning.
Also overlays baseline reference lines.
"""
import argparse
import os
import json
import pandas as pd
import matplotlib.pyplot as plt
def load_heuristic_scores(path="results/baseline_scores_heuristic.json"):
    """Load the heuristic baseline scores from a JSON file.

    Args:
        path: Location of the baseline JSON file. Defaults to the path this
            script has always read, so existing zero-argument callers are
            unaffected.

    Returns:
        The decoded JSON content, or ``None`` when the file is missing,
        cannot be read, or does not contain valid JSON.
    """
    try:
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # A missing baseline file is the normal "no baseline yet" case.
        return None
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # The baseline is optional, so a corrupt or unreadable file should
        # not abort the whole plot; report it and fall back to no baseline.
        print(f"Could not load heuristic baseline from {path}: {exc}")
        return None
# Fallback heuristic average used when no usable baseline file is available.
_DEFAULT_HEURISTIC_AVG = 0.514


def _plot_reward_panel(ax, steps, raw, color, smooth_label):
    """Draw a raw series and its 5-step rolling mean on *ax*; return the mean."""
    smooth = raw.rolling(window=5, min_periods=1).mean()
    ax.plot(steps, raw, alpha=0.25, color=color, label="Raw")
    ax.plot(steps, smooth, color=color, linewidth=2.5, label=smooth_label)
    return smooth


def _style_panel(ax, ylabel, title, ylim):
    """Apply the shared axis labels, title, legend, grid and y-limits to *ax*."""
    ax.set_xlabel("Training Step", fontsize=11, color="#e6edf3")
    ax.set_ylabel(ylabel, fontsize=11, color="#e6edf3")
    ax.set_title(title, fontsize=12, color="#e6edf3")
    ax.legend(fontsize=10)
    ax.grid(True, linestyle="--", alpha=0.3, color="#8b949e")
    ax.set_ylim(*ylim)


def main(argv=None):
    """Plot the training reward curves from a CSV log and save them as a PNG.

    The left panel shows the ``rewards/reward_env_interaction/mean`` column
    with a horizontal heuristic-baseline line; the right panel shows the
    ``rewards/reward_json_valid/mean`` column. A panel is left empty when its
    column is absent from the CSV.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv``, as before.

    Returns:
        None. Prints a message and returns early when the CSV is missing,
        empty, or lacks a ``step`` column.
    """
    parser = argparse.ArgumentParser(description="Plot training learning curves")
    parser.add_argument("--csv", type=str, default="results/training_log.csv", help="Path to training CSV")
    parser.add_argument("--output", type=str, default="results/training_curve.png", help="Path to save PNG")
    args = parser.parse_args(argv)

    heuristic_data = load_heuristic_scores()

    if not os.path.exists(args.csv):
        print("No CSV found.")
        return

    print(f"Reading training logs from {args.csv}")
    try:
        df = pd.read_csv(args.csv)
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header for pandas to parse.
        print(f"Training CSV {args.csv} is empty.")
        return

    if "step" not in df.columns:
        print("No 'step' column found.")
        return

    if df.empty:
        # Header only: nothing to draw, and raw.iloc[0] below would fail.
        print("Training CSV has no rows to plot.")
        return

    # Baseline score: prefer the value from the baseline file, but tolerate a
    # file whose top level is not a dict or whose value is not numeric.
    h_avg = _DEFAULT_HEURISTIC_AVG
    if isinstance(heuristic_data, dict):
        try:
            h_avg = float(heuristic_data.get("overall_average", _DEFAULT_HEURISTIC_AVG))
        except (TypeError, ValueError):
            h_avg = _DEFAULT_HEURISTIC_AVG

    plt.style.use("dark_background")
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    try:
        steps = df["step"]

        # Left: Episode score (from /grade)
        ax = axes[0]
        episode_col = "rewards/reward_env_interaction/mean"
        if episode_col in df.columns:
            raw = df[episode_col]
            smooth = _plot_reward_panel(ax, steps, raw, "#4ECDC4", "Trained LLM (smoothed)")
            # Draw the baseline before styling so the legend includes it.
            ax.axhline(y=h_avg, color="#FF6B6B", linestyle="--", linewidth=2,
                       label=f"Heuristic baseline ({h_avg:.3f})")
            _style_panel(
                ax,
                "Episode Score (0.0-1.0)",
                "Episode Score from /grade Endpoint\n(Higher = Better Energy Management)",
                (0.35, 0.75),
            )
            finite = smooth.dropna()
            if not finite.empty:
                # An all-NaN column leaves nothing to summarise.
                print(f"Episode score: {raw.iloc[0]:.3f} -> {finite.iloc[-1]:.3f}")

        # Right: JSON validity
        ax2 = axes[1]
        json_col = "rewards/reward_json_valid/mean"
        if json_col in df.columns:
            _plot_reward_panel(ax2, steps, df[json_col], "#FFE66D", "JSON Validity (smoothed)")
            _style_panel(
                ax2,
                "JSON Format Reward (0.0-0.2)",
                "Action Format Compliance\n(Higher = Better JSON Output)",
                (0, 0.22),
            )

        fig.tight_layout()
        # Create the output directory only once there is something to save.
        os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
        fig.savefig(args.output, dpi=150, bbox_inches="tight", facecolor="#0d1117")
        print(f"Training curve saved to {args.output}")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
# Run the plotter only when this file is executed as a script, so importing
# the module does not trigger argument parsing or plotting.
if __name__ == "__main__":
    main()