File size: 4,139 Bytes
12263fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfc5996
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Cloud Arena Visualization — Mathematical Model

import os

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np


# Dashboard colour palette (hex RGB strings passed straight to matplotlib).
REF_BG = "#0e1117"      # figure / axes background, also used when saving
REF_CYAN = "#00d4ff"    # reward curves
REF_AMBER = "#ffa500"   # savings curves
REF_NEON = "#39ff14"    # security curves
TEXT_COLOR = "#e6e6e6"  # titles, axis labels and tick labels


def smooth(y, box_pts=50):
    """Return the moving average of ``y`` over a window of ``box_pts`` samples.

    The average is computed with ``np.convolve(..., mode='valid')``, so the
    result holds ``len(y) - box_pts + 1`` values: only positions where the
    window fits entirely inside ``y`` are produced.

    When ``y`` has fewer than ``box_pts`` samples no averaging is possible
    and ``y`` itself is returned untouched.
    """
    if len(y) >= box_pts:
        # Uniform kernel whose weights sum to 1.
        window = np.ones(box_pts)
        kernel = window / box_pts
        return np.convolve(y, kernel, mode='valid')
    return y


def generate_dashboard(callback, output_path="outputs/training_dashboard.png"):
    """Render a three-panel training dashboard and save it as an image.

    Panels, left to right:

    * "Learning Curve"       -- ``callback.episode_rewards``
    * "Cost Optimization %"  -- ``callback.episode_savings`` (y-axis 0-100)
    * "Security Score"       -- ``callback.episode_security`` (y-axis 0-1)

    Each panel draws the raw series faintly and the ``smooth``-ed series on
    top in the same colour.

    Args:
        callback: Object exposing ``episode_rewards``, ``episode_savings``
            and ``episode_security``; each is converted with ``np.array``.
            Presumably one value per episode -- confirm against the caller.
        output_path: Destination image path. Missing parent directories are
            created before saving.

    Returns:
        ``output_path``, unchanged.
    """
    # (series, line colour, panel title, fixed y-limits or None)
    panels = (
        (np.array(callback.episode_rewards), REF_CYAN, "Learning Curve", None),
        (np.array(callback.episode_savings), REF_AMBER, "Cost Optimization %", (0, 100)),
        (np.array(callback.episode_security), REF_NEON, "Security Score", (0, 1)),
    )

    fig, axes = plt.subplots(1, 3, figsize=(22, 6), facecolor=REF_BG)
    try:
        for ax, (values, color, title, ylim) in zip(axes, panels):
            ax.set_facecolor(REF_BG)
            ax.grid(True, alpha=0.05, color='white')
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['left'].set_color('#333333')
            ax.spines['bottom'].set_color('#333333')
            ax.tick_params(colors=TEXT_COLOR, labelsize=10)

            ax.plot(values, color=color, alpha=0.15)

            # smooth() uses a 'valid' convolution, so the result is shorter
            # than the input. Right-align it so each averaged point sits on
            # the last episode of its window instead of being shifted left
            # of the raw curve.
            trend = smooth(values)
            start = len(values) - len(trend)
            ax.plot(np.arange(start, len(values)), trend, color=color, lw=3)

            ax.set_title(title, color=TEXT_COLOR, fontsize=14, fontweight='bold')
            if ylim is not None:
                ax.set_ylim(*ylim)

        fig.tight_layout()

        # The default path lives under "outputs/", which may not exist yet.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path, dpi=200, bbox_inches='tight', facecolor=REF_BG)
    finally:
        # Close this specific figure even if plotting or saving failed, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
    return output_path


def generate_grpo_dashboard(all_results, all_stats, output_path="outputs/grpo_dashboard.png"):
    """Render a 2x2 GRPO comparison dashboard and save it as an image.

    One line per model is drawn in each panel:

    * top-left     -- smoothed rewards from ``all_results``
    * top-right    -- ``"kl"`` values from ``all_stats``
    * bottom-left  -- ``"entropy"`` values from ``all_stats``
    * bottom-right -- ``"veto_rate"`` values from ``all_stats`` (y-axis 0-1)

    Args:
        all_results: Mapping of model name to a sequence of rewards.
        all_stats: Mapping of model name to a list of dicts. The keys
            ``"kl"``, ``"entropy"`` and ``"veto_rate"`` are read from each
            dict; a missing key is plotted as ``0.0``. A model with no entry
            gets empty stat curves.
        output_path: Destination image path. Missing parent directories are
            created before saving.

    Returns:
        ``output_path``, unchanged.
    """
    fig, axs = plt.subplots(2, 2, figsize=(16, 10), facecolor=REF_BG)
    try:
        ax1, ax2, ax3, ax4 = axs.flatten()
        for ax in (ax1, ax2, ax3, ax4):
            ax.set_facecolor(REF_BG)
            ax.grid(True, alpha=0.08, color="white")
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)
            ax.spines["left"].set_color("#333333")
            ax.spines["bottom"].set_color("#333333")
            ax.tick_params(colors=TEXT_COLOR, labelsize=9)

        # The first three entries are the module's shared accent colours;
        # the palette is cycled when there are more than five models.
        palette = [REF_CYAN, REF_AMBER, REF_NEON, "#ff6b6b", "#b47eff"]
        stat_panels = ((ax2, "kl"), (ax3, "entropy"), (ax4, "veto_rate"))

        for i, (name, rewards) in enumerate(all_results.items()):
            c = palette[i % len(palette)]

            rewards = np.array(rewards)
            # Window is a fifth of the run length, clamped to [3, 20].
            window = min(20, max(3, len(rewards) // 5))
            trend = smooth(rewards, box_pts=window)
            # smooth() returns a shorter 'valid' result; right-align it so
            # each averaged point sits on the last episode of its window.
            start = len(rewards) - len(trend)
            ax1.plot(np.arange(start, len(rewards)), trend, color=c, lw=2, label=name)

            stats = all_stats.get(name, [])
            for ax, key in stat_panels:
                ax.plot([s.get(key, 0.0) for s in stats], color=c, lw=1.8, label=name)

        # (axes, title, y-axis label); every panel shares the "Episode" x-label.
        labels = (
            (ax1, "GRPO Reward (Smoothed)", "Reward"),
            (ax2, "KL Trend", "KL"),
            (ax3, "Entropy Trend", "Entropy"),
            (ax4, "Safety Violation / Veto Rate", "Rate"),
        )
        for ax, title, ylabel in labels:
            ax.set_title(title, color=TEXT_COLOR, fontsize=12, fontweight="bold")
            ax.set_xlabel("Episode", color=TEXT_COLOR)
            ax.set_ylabel(ylabel, color=TEXT_COLOR)

        # Skip the legend when nothing was plotted; matplotlib warns about
        # a legend with no labelled artists.
        if all_results:
            ax1.legend(facecolor="#1a1a2e", edgecolor="#333", labelcolor=TEXT_COLOR, fontsize=8)
        ax4.set_ylim(0, 1)

        fig.tight_layout()

        # The default path lives under "outputs/", which may not exist yet.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path, dpi=200, bbox_inches="tight", facecolor=REF_BG)
    finally:
        # Close this specific figure even if plotting or saving failed, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
    return output_path