|
|
|
|
|
|
|
|
"""
|
|
|
Reward vs Batch Size Scaling Visualization
|
|
|
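Visualizes how reward scales with batch size across different model sizes, checks the batch_size-proportional-to-sqrt(model_size) scaling law, and renders a compute-efficiency heatmap.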
|
|
|
"""
|
|
|
import matplotlib.pyplot as plt
|
|
|
import numpy as np
|
|
|
from typing import List, Tuple
|
|
|
|
|
|
|
|
|
def plot_reward_vs_batch_size(batch_sizes: List[int],
                              rewards: List[float],
                              model_sizes: List[float],
                              output_file: str = 'reward_vs_batch_size.png',
                              trend_degree: int = 2) -> None:
    """Create a scatter plot of reward vs batch size, colored by model size.

    A least-squares polynomial trend of degree ``trend_degree`` (quadratic by
    default) is overlaid when there are more distinct batch sizes than the
    degree. With fewer distinct points the fit is underdetermined, so the
    trend line and its legend entry are skipped.

    Args:
        batch_sizes: Batch size of each run (x-axis).
        rewards: Reward of each run, aligned with ``batch_sizes`` (y-axis).
        model_sizes: Model size proxy of each run, aligned with
            ``batch_sizes``; used as the marker color.
        output_file: Output filename for the plot (PNG, 300 dpi).
        trend_degree: Degree of the fitted trend polynomial.

    Raises:
        ValueError: If the inputs are empty or have different lengths.
    """
    x = np.asarray(batch_sizes, dtype=float)
    y = np.asarray(rewards, dtype=float)
    colors = np.asarray(model_sizes, dtype=float)
    if x.size == 0:
        raise ValueError("batch_sizes must not be empty")
    if not (x.shape == y.shape == colors.shape):
        raise ValueError(
            "batch_sizes, rewards and model_sizes must have the same length, "
            f"got {x.size}, {y.size} and {colors.size}")

    fig, ax = plt.subplots(figsize=(12, 7))
    try:
        scatter = ax.scatter(x, y, c=colors, s=100, alpha=0.6,
                             cmap='viridis', edgecolors='black')

        # A degree-d polynomial needs at least d + 1 distinct x values.
        if np.unique(x).size > trend_degree:
            trend = np.poly1d(np.polyfit(x, y, trend_degree))
            x_trend = np.linspace(x.min(), x.max(), 100)
            ax.plot(x_trend, trend(x_trend), "r--", alpha=0.8, linewidth=2,
                    label='Trend')
            # Only add a legend when there is a labelled artist to show.
            ax.legend()

        ax.set_xlabel('Batch Size', fontsize=12, fontweight='bold')
        ax.set_ylabel('Reward', fontsize=12, fontweight='bold')
        ax.set_title('Reward vs Batch Size Scaling\n(Colored by Model Size)',
                     fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3, linestyle='--')

        cbar = fig.colorbar(scatter, ax=ax)
        cbar.set_label('Model Size Proxy', fontsize=11, fontweight='bold')

        fig.tight_layout()
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)

    # \u2713 is a check mark; escaped so the source stays ASCII-safe.
    print(f"\u2713 Reward vs batch size saved to {output_file}")
|
|
|
|
|
|
|
|
|
def plot_scaling_law_validation(
        model_sizes: List[float],
        optimal_batch_sizes: List[int],
        output_file: str = 'scaling_law_validation.png') -> None:
    """Compare optimal batch sizes against batch_size proportional to sqrt(model_size).

    The reference curve is ``c * sqrt(model_size)``. The coefficient ``c`` is
    the least-squares fit over all points (a fit through the origin), so no
    single noisy or integer-truncated point dictates the curve. The curve is
    drawn on a sorted grid, so the input order does not matter.

    Args:
        model_sizes: List of model size proxies (must be non-negative).
        optimal_batch_sizes: Computed optimal batch size for each model size.
        output_file: Output filename for the plot (PNG, 300 dpi).

    Raises:
        ValueError: If the inputs are empty, differ in length, contain a
            negative model size, or every model size is zero.
    """
    sizes = np.asarray(model_sizes, dtype=float)
    batches = np.asarray(optimal_batch_sizes, dtype=float)
    if sizes.size == 0:
        raise ValueError("model_sizes must not be empty")
    if sizes.shape != batches.shape:
        raise ValueError(
            "model_sizes and optimal_batch_sizes must have the same length, "
            f"got {sizes.size} and {batches.size}")
    if np.any(sizes < 0):
        raise ValueError("model_sizes must be non-negative")

    # Minimizing sum_i (b_i - c * sqrt(m_i))**2 gives
    # c = sum_i(b_i * sqrt(m_i)) / sum_i(m_i).
    total_size = sizes.sum()
    if total_size == 0:
        raise ValueError("at least one model size must be positive")
    coeff = np.dot(batches, np.sqrt(sizes)) / total_size

    size_grid = np.linspace(sizes.min(), sizes.max(), 200)
    theoretical = coeff * np.sqrt(size_grid)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.scatter(sizes, batches, s=100, alpha=0.7,
                   label='Actual', color='#3498db', edgecolors='black')
        # \u221d is "proportional to", \u221a is the square-root sign.
        ax.plot(size_grid, theoretical, 'r--', linewidth=2,
                label='Theoretical: batch \u221d \u221a(model_size)')

        ax.set_xlabel('Model Size Proxy', fontsize=12, fontweight='bold')
        ax.set_ylabel('Optimal Batch Size', fontsize=12, fontweight='bold')
        ax.set_title('Scaling Law Validation\nbatch_size \u221d \u221a(model_size)',
                     fontsize=14, fontweight='bold')
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3, linestyle='--')

        fig.tight_layout()
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)

    # \u2713 is a check mark; escaped so the source stays ASCII-safe.
    print(f"\u2713 Scaling law validation saved to {output_file}")
|
|
|
|
|
|
|
|
|
def plot_compute_efficiency_heatmap(
        batch_sizes: List[int],
        model_sizes: List[float],
        efficiencies: np.ndarray,
        output_file: str = 'compute_efficiency_heatmap.png') -> None:
    """Create a heatmap of compute efficiency across batch sizes and model sizes.

    Each cell is annotated with its value to two decimals.

    Args:
        batch_sizes: Batch sizes labelling the columns.
        model_sizes: Model size proxies labelling the rows.
        efficiencies: 2D array-like of shape
            ``(len(model_sizes), len(batch_sizes))``. Row ``i`` corresponds
            to ``model_sizes[i]`` and column ``j`` to ``batch_sizes[j]``.
            Nested lists are accepted.
        output_file: Output filename for the plot (PNG, 300 dpi).

    Raises:
        ValueError: If ``efficiencies`` does not match the label lengths.
    """
    grid = np.asarray(efficiencies, dtype=float)
    expected_shape = (len(model_sizes), len(batch_sizes))
    # A mismatch would otherwise silently misalign tick labels and annotations.
    if grid.shape != expected_shape:
        raise ValueError(
            f"efficiencies has shape {grid.shape}, expected {expected_shape} "
            "(len(model_sizes), len(batch_sizes))")

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        im = ax.imshow(grid, cmap='RdYlGn', aspect='auto',
                       interpolation='nearest')

        ax.set_xticks(np.arange(len(batch_sizes)))
        ax.set_yticks(np.arange(len(model_sizes)))
        ax.set_xticklabels(batch_sizes)
        ax.set_yticklabels([f'{m:.2f}' for m in model_sizes])

        ax.set_xlabel('Batch Size', fontsize=12, fontweight='bold')
        ax.set_ylabel('Model Size Proxy', fontsize=12, fontweight='bold')
        ax.set_title('Compute Efficiency Heatmap\n(Reward per Second)',
                     fontsize=14, fontweight='bold')

        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label('Efficiency (reward/sec)', fontsize=11,
                       fontweight='bold')

        # imshow puts row i at y=i and column j at x=j.
        for (row, col), value in np.ndenumerate(grid):
            ax.text(col, row, f'{value:.2f}',
                    ha="center", va="center", color="black", fontsize=8)

        fig.tight_layout()
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)

    # \u2713 is a check mark; escaped so the source stays ASCII-safe.
    print(f"\u2713 Compute efficiency heatmap saved to {output_file}")
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Seed NumPy's global RNG so every demo figure is reproducible.
    np.random.seed(42)

    # Synthetic sweep of nine runs; run k pairs batch_sizes[k] with model_sizes[k].
    batch_sizes = list(range(4, 21, 2))
    model_sizes = [0.5 * step for step in range(1, 10)]

    # Reward grows with sqrt(batch size) plus one Gaussian noise draw per run,
    # drawn in batch-size order.
    rewards = []
    for size in batch_sizes:
        noise = 0.02 * np.random.randn()
        rewards.append(0.70 + 0.05 * np.sqrt(size) + noise)

    plot_reward_vs_batch_size(batch_sizes, rewards, model_sizes)

    # Ideal batch sizes following batch proportional to sqrt(model size),
    # truncated to integers.
    optimal_batch_sizes = [int(8 * np.sqrt(size)) for size in model_sizes]
    plot_scaling_law_validation(model_sizes, optimal_batch_sizes)

    # Random efficiency grid: one row per model size, one column per batch size.
    grid_shape = (len(model_sizes), len(batch_sizes))
    efficiencies = np.random.uniform(5, 12, grid_shape)
    plot_compute_efficiency_heatmap(batch_sizes, model_sizes, efficiencies)
|
|
|
|