File size: 3,652 Bytes
c395f6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdc9954
 
 
 
 
c395f6a
 
 
 
 
 
 
 
bdc9954
c395f6a
bdc9954
c395f6a
 
bdc9954
c395f6a
 
bdc9954
c395f6a
bdc9954
 
c395f6a
 
bdc9954
 
 
 
c395f6a
bdc9954
 
c395f6a
bdc9954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c395f6a
bdc9954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c395f6a
 
bdc9954
 
c395f6a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python3
"""
GridMind-RL Training Curve Plotter
----------------------------------
Reads the training CSV generated by train_unsloth.py and plots two reward
components (episode score from /grade and JSON-format validity) to a PNG,
overlaying the heuristic baseline as a reference line on the episode-score panel.
"""

import argparse
import os
import json
import pandas as pd
import matplotlib.pyplot as plt

def load_heuristic_scores(path="results/baseline_scores_heuristic.json"):
    """Load heuristic baseline scores from a JSON file.

    Args:
        path: JSON file to read. Defaults to the location the script has
            always used, so existing callers are unaffected.

    Returns:
        The decoded JSON content, or ``None`` if *path* does not exist.
    """
    # EAFP: avoids the exists()/open() race and a second filesystem lookup.
    try:
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return None

_DEFAULT_HEURISTIC_AVG = 0.514  # overall heuristic average from real runs; used when no JSON is available
_SMOOTHING_WINDOW = 5  # rolling-mean window, measured in logged rows
_TEXT_COLOR = "#e6edf3"
_GRID_COLOR = "#8b949e"


def _plot_raw_and_smoothed(ax, steps, raw, color, smooth_label):
    """Draw *raw* faintly plus its rolling mean on *ax*; return the smoothed series."""
    # min_periods=1 so the first few points still get a (partial-window) mean.
    smooth = raw.rolling(window=_SMOOTHING_WINDOW, min_periods=1).mean()
    ax.plot(steps, raw, alpha=0.25, color=color, label="Raw")
    ax.plot(steps, smooth, color=color, linewidth=2.5, label=smooth_label)
    return smooth


def _decorate_axes(ax, ylabel, title):
    """Apply the shared axis labels, title, legend and dashed grid to *ax*."""
    ax.set_xlabel("Training Step", fontsize=11, color=_TEXT_COLOR)
    ax.set_ylabel(ylabel, fontsize=11, color=_TEXT_COLOR)
    ax.set_title(title, fontsize=12, color=_TEXT_COLOR)
    ax.legend(fontsize=10)
    ax.grid(True, linestyle="--", alpha=0.3, color=_GRID_COLOR)


def _covering_ylim(default, data, *extra):
    """Return the *default* (low, high) y-limits widened to include all values.

    NaNs in the *data* Series are ignored; *extra* scalars (e.g. a baseline
    value) are included too. The default range is never narrowed, so the usual
    view is unchanged when every value already lies inside it.
    """
    values = [float(v) for v in data.dropna()]
    values.extend(float(v) for v in extra)
    low, high = default
    if values:
        low = min(low, min(values))
        high = max(high, max(values))
    return low, high


def main():
    """Plot episode-score and JSON-validity reward curves from a training CSV.

    Command-line options:
        --csv     training log to read (default results/training_log.csv)
        --output  PNG path to write (default results/training_curve.png)

    The CSV must contain a ``step`` column; each reward panel is drawn only if
    its column is present. Prints a message and returns without writing a PNG
    if the CSV is missing, lacks ``step``, or has no rows with a step value.
    """
    parser = argparse.ArgumentParser(description="Plot training learning curves")
    parser.add_argument("--csv", type=str, default="results/training_log.csv", help="Path to training CSV")
    parser.add_argument("--output", type=str, default="results/training_curve.png", help="Path to save PNG")
    args = parser.parse_args()

    heuristic_data = load_heuristic_scores()

    if not os.path.exists(args.csv):
        print("No CSV found.")
        return

    print(f"Reading training logs from {args.csv}")
    df = pd.read_csv(args.csv)
    if "step" not in df.columns:
        print("No 'step' column found.")
        return

    # Rows without a step cannot be placed on the x-axis. Sorting keeps the
    # rolling mean over chronologically adjacent points even if the log is unordered.
    df = df.dropna(subset=["step"]).sort_values("step", kind="stable")
    if df.empty:
        print("No logged steps found.")
        return

    # Baseline from our real runs; fall back to the recorded average if the
    # JSON is absent or not an object.
    h_avg = _DEFAULT_HEURISTIC_AVG
    if isinstance(heuristic_data, dict):
        h_avg = float(heuristic_data.get("overall_average", _DEFAULT_HEURISTIC_AVG))

    steps = df["step"]
    # style.context restores the global rcParams afterwards instead of leaking the theme.
    with plt.style.context("dark_background"):
        fig, (ax, ax2) = plt.subplots(1, 2, figsize=(14, 5))

        # Left: episode score reported by the /grade endpoint.
        episode_col = "rewards/reward_env_interaction/mean"
        if episode_col in df.columns:
            raw = df[episode_col]
            smooth = _plot_raw_and_smoothed(ax, steps, raw, "#4ECDC4", "Trained LLM (smoothed)")
            ax.axhline(y=h_avg, color="#FF6B6B", linestyle="--", linewidth=2,
                       label=f"Heuristic baseline ({h_avg:.3f})")
            _decorate_axes(
                ax,
                "Episode Score (0.0-1.0)",
                "Episode Score from /grade Endpoint\n(Higher = Better Energy Management)",
            )
            # Widen rather than clip when scores or the baseline leave the usual range.
            ax.set_ylim(*_covering_ylim((0.35, 0.75), raw, h_avg))
            observed = raw.dropna()
            if not observed.empty:
                print(f"Episode score: {observed.iloc[0]:.3f} -> {smooth.dropna().iloc[-1]:.3f}")
        else:
            print(f"Column '{episode_col}' not found; episode-score panel left empty.")

        # Right: JSON validity of the generated actions.
        json_col = "rewards/reward_json_valid/mean"
        if json_col in df.columns:
            raw = df[json_col]
            _plot_raw_and_smoothed(ax2, steps, raw, "#FFE66D", "JSON Validity (smoothed)")
            _decorate_axes(
                ax2,
                "JSON Format Reward (0.0-0.2)",
                "Action Format Compliance\n(Higher = Better JSON Output)",
            )
            ax2.set_ylim(*_covering_ylim((0, 0.22), raw))
        else:
            print(f"Column '{json_col}' not found; JSON-validity panel left empty.")

        fig.tight_layout()
        # Create the output directory only once there is something to save.
        os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
        fig.savefig(args.output, dpi=150, bbox_inches="tight", facecolor="#0d1117")
        plt.close(fig)
    print(f"Training curve saved to {args.output}")


if __name__ == "__main__":
    main()