File size: 5,945 Bytes
02655d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""

Reward vs Batch Size Scaling Visualization

Visualizes how reward scales with batch size across different model sizes

"""
import matplotlib.pyplot as plt
import numpy as np
from typing import List, Tuple


def plot_reward_vs_batch_size(batch_sizes: List[int],
                              rewards: List[float],
                              model_sizes: List[float],
                              output_file: str = 'reward_vs_batch_size.png'):
    """Scatter reward against batch size, color-coding points by model size.

    A quadratic trend line fitted to the data is overlaid, and the figure
    is saved to ``output_file``.

    Args:
        batch_sizes: Batch sizes used in each run.
        rewards: Reward obtained for each run.
        model_sizes: Model-size proxy for each run (drives point color).
        output_file: Destination filename for the saved figure.
    """
    fig, axis = plt.subplots(figsize=(12, 7))

    points = axis.scatter(batch_sizes, rewards, c=model_sizes, s=100,
                          alpha=0.6, cmap='viridis', edgecolors='black')

    # Overlay a degree-2 polynomial trend fitted to (batch_size, reward).
    coeffs = np.polyfit(batch_sizes, rewards, 2)
    trend_fn = np.poly1d(coeffs)
    xs = np.linspace(min(batch_sizes), max(batch_sizes), 100)
    axis.plot(xs, trend_fn(xs), "r--", alpha=0.8, linewidth=2, label='Trend')

    axis.set_xlabel('Batch Size', fontsize=12, fontweight='bold')
    axis.set_ylabel('Reward', fontsize=12, fontweight='bold')
    axis.set_title('Reward vs Batch Size Scaling\n(Colored by Model Size)', 
                 fontsize=14, fontweight='bold')
    axis.grid(True, alpha=0.3, linestyle='--')
    axis.legend()

    # Colorbar keyed to the model-size proxy used for point color.
    colorbar = plt.colorbar(points, ax=axis)
    colorbar.set_label('Model Size Proxy', fontsize=11, fontweight='bold')

    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"βœ“ Reward vs batch size saved to {output_file}")
    plt.close()


def plot_scaling_law_validation(model_sizes: List[float],
                                optimal_batch_sizes: List[int],
                                output_file: str = 'scaling_law_validation.png'):
    """Compare measured optimal batch sizes against the √(model_size) law.

    Plots the actual (model_size, optimal_batch_size) points alongside the
    theoretical curve batch ∝ √(model_size), calibrated so the curve passes
    through the first data point, then saves the figure to ``output_file``.

    Args:
        model_sizes: Model-size proxies (x axis).
        optimal_batch_sizes: Computed optimal batch size per model size.
        output_file: Destination filename for the saved figure.
    """
    fig, axis = plt.subplots(figsize=(10, 6))

    # Measured data points.
    axis.scatter(model_sizes, optimal_batch_sizes, s=100, alpha=0.7,
                 label='Actual', color='#3498db', edgecolors='black')

    # Theoretical curve, anchored at the first observation so it is
    # directly comparable to the measured values.
    scale = optimal_batch_sizes[0] / np.sqrt(model_sizes[0])
    predicted = [scale * np.sqrt(size) for size in model_sizes]
    axis.plot(model_sizes, predicted, 'r--', linewidth=2,
              label='Theoretical: batch ∝ √(model_size)')

    axis.set_xlabel('Model Size Proxy', fontsize=12, fontweight='bold')
    axis.set_ylabel('Optimal Batch Size', fontsize=12, fontweight='bold')
    axis.set_title('Scaling Law Validation\nbatch_size ∝ √(model_size)', 
                 fontsize=14, fontweight='bold')
    axis.legend(fontsize=11)
    axis.grid(True, alpha=0.3, linestyle='--')

    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"βœ“ Scaling law validation saved to {output_file}")
    plt.close()


def plot_compute_efficiency_heatmap(batch_sizes: List[int],
                                   model_sizes: List[float],
                                   efficiencies: np.ndarray,
                                   output_file: str = 'compute_efficiency_heatmap.png'):
    """
    Create heatmap of compute efficiency across batch sizes and model sizes.

    Each cell (i, j) shows ``efficiencies[i, j]`` for model size ``i`` and
    batch size ``j``, annotated with its numeric value, and the figure is
    saved to ``output_file``.

    Args:
        batch_sizes: List of batch sizes (columns of the heatmap)
        model_sizes: List of model sizes (rows of the heatmap)
        efficiencies: 2D array of compute efficiencies, shape
            (len(model_sizes), len(batch_sizes))
        output_file: Output filename for the plot
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    im = ax.imshow(efficiencies, cmap='RdYlGn', aspect='auto',
                   interpolation='nearest')

    # Label each row/column with its actual batch-size / model-size value
    # instead of the default integer indices.
    ax.set_xticks(np.arange(len(batch_sizes)))
    ax.set_yticks(np.arange(len(model_sizes)))
    ax.set_xticklabels(batch_sizes)
    ax.set_yticklabels([f'{m:.2f}' for m in model_sizes])

    ax.set_xlabel('Batch Size', fontsize=12, fontweight='bold')
    ax.set_ylabel('Model Size Proxy', fontsize=12, fontweight='bold')
    ax.set_title('Compute Efficiency Heatmap\n(Reward per Second)', 
                 fontsize=14, fontweight='bold')

    # Add colorbar
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('Efficiency (reward/sec)', fontsize=11, fontweight='bold')

    # Annotate each cell with its value (fix: the returned Text artist was
    # previously bound to an unused local; ax.text already attaches it).
    for i in range(len(model_sizes)):
        for j in range(len(batch_sizes)):
            ax.text(j, i, f'{efficiencies[i, j]:.2f}',
                    ha="center", va="center", color="black", fontsize=8)

    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"βœ“ Compute efficiency heatmap saved to {output_file}")
    plt.close()


if __name__ == '__main__':
    # Demo driver: fabricate data and render all three figures.
    # Fixed seed keeps the demo output reproducible.
    np.random.seed(42)

    # Synthetic experiment grid.
    demo_batches = [4, 6, 8, 10, 12, 14, 16, 18, 20]
    demo_models = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5]
    # Reward grows roughly with sqrt(batch size) plus Gaussian noise.
    demo_rewards = []
    for batch in demo_batches:
        demo_rewards.append(0.70 + 0.05 * np.sqrt(batch)
                            + 0.02 * np.random.randn())

    plot_reward_vs_batch_size(demo_batches, demo_rewards, demo_models)

    # Optimal batch sizes following the sqrt scaling law.
    demo_optima = [int(8 * np.sqrt(size)) for size in demo_models]
    plot_scaling_law_validation(demo_models, demo_optima)

    # Random efficiency grid (model sizes x batch sizes) for the heatmap.
    demo_efficiencies = np.random.uniform(
        5, 12, (len(demo_models), len(demo_batches)))
    plot_compute_efficiency_heatmap(demo_batches, demo_models,
                                    demo_efficiencies)