File size: 5,945 Bytes
02655d9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 | #!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Reward vs Batch Size Scaling Visualization
Visualizes how reward scales with batch size across different model sizes
"""
import matplotlib.pyplot as plt
import numpy as np
from typing import List, Tuple
def plot_reward_vs_batch_size(batch_sizes: List[int],
rewards: List[float],
model_sizes: List[float],
output_file: str = 'reward_vs_batch_size.png'):
"""
Create scatter plot showing reward vs batch size colored by model size
Args:
batch_sizes: List of batch sizes used
rewards: List of corresponding rewards
model_sizes: List of model size proxies
output_file: Output filename for the plot
"""
fig, ax = plt.subplots(figsize=(12, 7))
scatter = ax.scatter(batch_sizes, rewards, c=model_sizes,
s=100, alpha=0.6, cmap='viridis', edgecolors='black')
# Add trend line
z = np.polyfit(batch_sizes, rewards, 2)
p = np.poly1d(z)
x_trend = np.linspace(min(batch_sizes), max(batch_sizes), 100)
ax.plot(x_trend, p(x_trend), "r--", alpha=0.8, linewidth=2, label='Trend')
ax.set_xlabel('Batch Size', fontsize=12, fontweight='bold')
ax.set_ylabel('Reward', fontsize=12, fontweight='bold')
ax.set_title('Reward vs Batch Size Scaling\n(Colored by Model Size)',
fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, linestyle='--')
ax.legend()
# Add colorbar
cbar = plt.colorbar(scatter, ax=ax)
cbar.set_label('Model Size Proxy', fontsize=11, fontweight='bold')
plt.tight_layout()
plt.savefig(output_file, dpi=300, bbox_inches='tight')
print(f"β Reward vs batch size saved to {output_file}")
plt.close()
def plot_scaling_law_validation(model_sizes: List[float],
optimal_batch_sizes: List[int],
output_file: str = 'scaling_law_validation.png'):
"""
Validate batch_size β β(model_size) scaling law
Args:
model_sizes: List of model size proxies
optimal_batch_sizes: List of computed optimal batch sizes
output_file: Output filename for the plot
"""
fig, ax = plt.subplots(figsize=(10, 6))
# Plot actual data
ax.scatter(model_sizes, optimal_batch_sizes, s=100, alpha=0.7,
label='Actual', color='#3498db', edgecolors='black')
# Plot theoretical scaling law
base_batch = optimal_batch_sizes[0] / np.sqrt(model_sizes[0])
theoretical = [base_batch * np.sqrt(m) for m in model_sizes]
ax.plot(model_sizes, theoretical, 'r--', linewidth=2,
label='Theoretical: batch β β(model_size)')
ax.set_xlabel('Model Size Proxy', fontsize=12, fontweight='bold')
ax.set_ylabel('Optimal Batch Size', fontsize=12, fontweight='bold')
ax.set_title('Scaling Law Validation\nbatch_size β β(model_size)',
fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3, linestyle='--')
plt.tight_layout()
plt.savefig(output_file, dpi=300, bbox_inches='tight')
print(f"β Scaling law validation saved to {output_file}")
plt.close()
def plot_compute_efficiency_heatmap(batch_sizes: List[int],
model_sizes: List[float],
efficiencies: np.ndarray,
output_file: str = 'compute_efficiency_heatmap.png'):
"""
Create heatmap of compute efficiency across batch sizes and model sizes
Args:
batch_sizes: List of batch sizes
model_sizes: List of model sizes
efficiencies: 2D array of compute efficiencies
output_file: Output filename for the plot
"""
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(efficiencies, cmap='RdYlGn', aspect='auto',
interpolation='nearest')
ax.set_xticks(np.arange(len(batch_sizes)))
ax.set_yticks(np.arange(len(model_sizes)))
ax.set_xticklabels(batch_sizes)
ax.set_yticklabels([f'{m:.2f}' for m in model_sizes])
ax.set_xlabel('Batch Size', fontsize=12, fontweight='bold')
ax.set_ylabel('Model Size Proxy', fontsize=12, fontweight='bold')
ax.set_title('Compute Efficiency Heatmap\n(Reward per Second)',
fontsize=14, fontweight='bold')
# Add colorbar
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Efficiency (reward/sec)', fontsize=11, fontweight='bold')
# Add text annotations
for i in range(len(model_sizes)):
for j in range(len(batch_sizes)):
text = ax.text(j, i, f'{efficiencies[i, j]:.2f}',
ha="center", va="center", color="black", fontsize=8)
plt.tight_layout()
plt.savefig(output_file, dpi=300, bbox_inches='tight')
print(f"β Compute efficiency heatmap saved to {output_file}")
plt.close()
if __name__ == '__main__':
# Example usage
np.random.seed(42)
# Generate sample data
batch_sizes = [4, 6, 8, 10, 12, 14, 16, 18, 20]
model_sizes = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5]
rewards = [0.70 + 0.05 * np.sqrt(b) + 0.02 * np.random.randn()
for b in batch_sizes]
plot_reward_vs_batch_size(batch_sizes, rewards, model_sizes)
# Scaling law validation
optimal_batch_sizes = [int(8 * np.sqrt(m)) for m in model_sizes]
plot_scaling_law_validation(model_sizes, optimal_batch_sizes)
# Compute efficiency heatmap
efficiencies = np.random.uniform(5, 12, (len(model_sizes), len(batch_sizes)))
plot_compute_efficiency_heatmap(batch_sizes, model_sizes, efficiencies)
|