Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| GridMind-RL Training Curve Plotter | |
| ---------------------------------- | |
| Reads the training CSV generated by train_unsloth.py and creates a | |
| beautiful PNG plot of the reward components to prove learning. | |
| Also overlays baseline reference lines. | |
| """ | |
| import argparse | |
| import os | |
| import json | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
def load_heuristic_scores(path="results/baseline_scores_heuristic.json"):
    """Load heuristic baseline scores from a JSON file.

    Args:
        path: Location of the baseline JSON file. Defaults to the path this
            script has always used, so existing callers are unaffected.

    Returns:
        The decoded JSON object as a ``dict``, or ``None`` when the file is
        missing, unreadable, not valid JSON, or does not hold a JSON object.
        Returning ``None`` (rather than raising) lets the plot fall back to
        its built-in baseline value instead of aborting.
    """
    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # No baseline has been generated yet: a normal, silent fallback.
        return None
    except (OSError, ValueError) as err:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
        # OSError covers e.g. permission problems or a directory at `path`.
        print(f"Warning: could not read heuristic baseline {path}: {err}")
        return None
    # The caller reads keys via .get(), so only a JSON object is usable.
    if not isinstance(data, dict):
        print(f"Warning: heuristic baseline {path} is not a JSON object; ignoring it.")
        return None
    return data
def _plot_metric(ax, df, col, *, color, smooth_label, ylabel, title,
                 default_ylim, baseline=None):
    """Draw one training-curve panel: raw and 5-row rolling-mean curves of ``col``.

    Rows where ``step`` or ``col`` is NaN are dropped before smoothing. This
    keeps rows that did not record this metric from breaking the line or
    occupying slots in the rolling window.

    The y-limits start at ``default_ylim`` and are widened only when the data
    (or ``baseline``) falls outside them. The widening adds a 5% pad, so
    nothing is silently clipped.

    Args:
        ax: Matplotlib axes to draw on.
        df: Training log with a ``step`` column.
        col: Name of the metric column to plot.
        color: Line colour for both raw and smoothed curves.
        smooth_label: Legend label for the smoothed curve.
        ylabel: Y-axis label.
        title: Panel title.
        default_ylim: ``(low, high)`` y-limits used when the data fits inside.
        baseline: Optional y value drawn as a dashed reference line.

    Returns:
        ``(raw, smooth)`` Series of the plotted points, or ``None`` when
        ``col`` is absent or has no usable values. In that case a note is
        drawn on the panel instead.
    """
    ax.set_xlabel("Training Step", fontsize=11, color="#e6edf3")
    ax.set_ylabel(ylabel, fontsize=11, color="#e6edf3")
    ax.set_title(title, fontsize=12, color="#e6edf3")
    ax.grid(True, linestyle="--", alpha=0.3, color="#8b949e")

    valid = df[["step", col]].dropna() if col in df.columns else None
    if valid is None or valid.empty:
        ax.text(0.5, 0.5, f"No data for '{col}'", transform=ax.transAxes,
                ha="center", va="center", color="#8b949e")
        return None

    steps, raw = valid["step"], valid[col]
    smooth = raw.rolling(window=5, min_periods=1).mean()
    ax.plot(steps, raw, alpha=0.25, color=color, label="Raw")
    ax.plot(steps, smooth, color=color, linewidth=2.5, label=smooth_label)

    extremes = [raw.min(), raw.max()]
    if baseline is not None:
        ax.axhline(y=baseline, color="#FF6B6B", linestyle="--", linewidth=2,
                   label=f"Heuristic baseline ({baseline:.3f})")
        extremes.append(baseline)

    low, high = default_ylim
    pad = 0.05 * (high - low)
    if min(extremes) < low:
        low = min(extremes) - pad
    if max(extremes) > high:
        high = max(extremes) + pad
    ax.set_ylim(low, high)
    ax.legend(fontsize=10)
    return raw, smooth


def main():
    """Plot episode-score and JSON-validity training curves to a PNG.

    Command-line options:
        --csv     Training log CSV (default ``results/training_log.csv``).
                  It must contain a ``step`` column.
        --output  Destination PNG (default ``results/training_curve.png``).
                  Its parent directory is created if needed.

    If the CSV is missing or has no ``step`` column, a message is printed
    and no plot is written.
    """
    parser = argparse.ArgumentParser(description="Plot training learning curves")
    parser.add_argument("--csv", type=str, default="results/training_log.csv", help="Path to training CSV")
    parser.add_argument("--output", type=str, default="results/training_curve.png", help="Path to save PNG")
    args = parser.parse_args()
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    heuristic_data = load_heuristic_scores()

    if not os.path.exists(args.csv):
        print("No CSV found.")
        return
    print(f"Reading training logs from {args.csv}")
    df = pd.read_csv(args.csv)
    if "step" not in df.columns:
        print("No 'step' column found.")
        return

    # Baseline scores: 0.514 is the overall heuristic average from real runs.
    # It is overridden by the baseline JSON only when that holds a usable
    # number; a string or null would otherwise crash the ":.3f" formatting.
    h_avg = 0.514
    if isinstance(heuristic_data, dict):
        candidate = heuristic_data.get("overall_average", h_avg)
        if isinstance(candidate, (int, float)) and not isinstance(candidate, bool):
            h_avg = float(candidate)

    plt.style.use("dark_background")
    fig, (score_ax, json_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # Left: episode score (from /grade), compared against the heuristic baseline.
    episode_col = "rewards/reward_env_interaction/mean"
    episode = _plot_metric(
        score_ax, df, episode_col,
        color="#4ECDC4",
        smooth_label="Trained LLM (smoothed)",
        ylabel="Episode Score (0.0-1.0)",
        title="Episode Score from /grade Endpoint\n(Higher = Better Energy Management)",
        default_ylim=(0.35, 0.75),
        baseline=h_avg,
    )
    if episode is None:
        print(f"Column '{episode_col}' has no data; episode-score panel left empty.")
    else:
        raw, smooth = episode
        print(f"Episode score: {raw.iloc[0]:.3f} -> {smooth.iloc[-1]:.3f}")

    # Right: JSON validity.
    json_col = "rewards/reward_json_valid/mean"
    json_curve = _plot_metric(
        json_ax, df, json_col,
        color="#FFE66D",
        smooth_label="JSON Validity (smoothed)",
        ylabel="JSON Format Reward (0.0-0.2)",
        title="Action Format Compliance\n(Higher = Better JSON Output)",
        default_ylim=(0, 0.22),
    )
    if json_curve is None:
        print(f"Column '{json_col}' has no data; JSON-validity panel left empty.")

    plt.tight_layout()
    plt.savefig(args.output, dpi=150, bbox_inches="tight", facecolor="#0d1117")
    # Release the figure's memory; this matters if main() is called repeatedly in one process.
    plt.close(fig)
    print(f"Training curve saved to {args.output}")
if __name__ == "__main__":
    # Produce the plot only when run as a script; importing this module has no side effects.
    main()