# resmlp_comparison / experiment_fair.py
# (Hugging Face Hub page residue removed: uploaded by AmberLJC, revision 3f891b2.)
"""
PlainMLP vs ResMLP Comparison - FAIR INITIALIZATION VERSION
This version uses IDENTICAL initialization for both models to ensure
a fair comparison. Both use:
- Kaiming He initialization
- Weight scaling by 1/sqrt(num_layers)
- Zero bias initialization
The ONLY difference is the residual connection: x = x + f(x) vs x = f(x)
"""
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple
import json
import os
# Set random seeds for reproducibility (both torch and numpy are used below)
torch.manual_seed(42)
np.random.seed(42)
# Configuration: hyperparameters shared by BOTH models so the comparison is fair
NUM_LAYERS = 20        # depth of each MLP
HIDDEN_DIM = 64        # width of every layer (input dim == output dim == hidden dim)
NUM_SAMPLES = 1024     # size of the synthetic identity dataset
TRAINING_STEPS = 500   # optimizer steps per model
LEARNING_RATE = 1e-3   # Adam learning rate
BATCH_SIZE = 64        # minibatch size (sampled with replacement)
print(f"[Config] Layers: {NUM_LAYERS}, Hidden Dim: {HIDDEN_DIM}")
print(f"[Config] Samples: {NUM_SAMPLES}, Steps: {TRAINING_STEPS}, LR: {LEARNING_RATE}")
print(f"[Config] FAIR COMPARISON: Both models use identical initialization")
class PlainMLP(nn.Module):
    """Plain MLP: x = ReLU(Linear(x)) for each layer.

    Uses the SAME initialization as ResMLP so the only difference between
    the two models is the residual connection:
    - Kaiming He initialization (fan_in, ReLU gain)
    - Weight scaling by 1/sqrt(num_layers)
    - Zero bias

    Args:
        dim: input/output/hidden dimension of every layer.
        num_layers: number of Linear+ReLU layers.
    """
    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList()
        self.num_layers = num_layers
        for _ in range(num_layers):
            layer = nn.Linear(dim, dim)
            # SAME initialization as ResMLP
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            # Scale inside no_grad instead of mutating `.data`
            # (the `.data` idiom bypasses autograd bookkeeping and is discouraged).
            with torch.no_grad():
                layer.weight.mul_(1.0 / np.sqrt(num_layers))  # same scaling!
            nn.init.zeros_(layer.bias)
            self.layers.append(layer)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every layer in sequence WITHOUT residual connections."""
        for layer in self.layers:
            x = self.activation(layer(x))  # NO residual connection
        return x
class ResMLP(nn.Module):
    """Residual MLP: x = x + ReLU(Linear(x)) for each layer.

    Uses the SAME initialization as PlainMLP so the only difference between
    the two models is the residual connection:
    - Kaiming He initialization (fan_in, ReLU gain)
    - Weight scaling by 1/sqrt(num_layers)
    - Zero bias

    Args:
        dim: input/output/hidden dimension of every layer.
        num_layers: number of residual Linear+ReLU layers.
    """
    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList()
        self.num_layers = num_layers
        for _ in range(num_layers):
            layer = nn.Linear(dim, dim)
            # Same initialization as PlainMLP
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            # Scale inside no_grad instead of mutating `.data`
            # (the `.data` idiom bypasses autograd bookkeeping and is discouraged).
            with torch.no_grad():
                layer.weight.mul_(1.0 / np.sqrt(num_layers))  # same scaling
            nn.init.zeros_(layer.bias)
            self.layers.append(layer)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every layer in sequence WITH residual connections."""
        for layer in self.layers:
            x = x + self.activation(layer(x))  # WITH residual connection
        return x
def generate_identity_data(num_samples: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build an identity-task dataset: targets are exact copies of the inputs.

    Inputs are drawn uniformly from [-1, 1]; the learning target is Y = X.
    """
    features = torch.empty(num_samples, dim).uniform_(-1, 1)
    targets = features.clone()  # independent copy so in-place edits can't alias
    return features, targets
def train_model(model: nn.Module, X: torch.Tensor, Y: torch.Tensor,
                steps: int, lr: float, batch_size: int) -> List[float]:
    """Train `model` with Adam on MSE minibatches and return per-step losses."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    losses: List[float] = []
    num_samples = X.shape[0]
    for step in range(steps):
        # Sample a random minibatch (with replacement).
        idx = torch.randint(0, num_samples, (batch_size,))
        optimizer.zero_grad()
        # Forward, loss, backward, update.
        loss = criterion(model(X[idx]), Y[idx])
        loss.backward()
        optimizer.step()
        step_loss = loss.item()
        losses.append(step_loss)
        if step % 100 == 0:
            print(f" Step {step}/{steps}, Loss: {step_loss:.6f}")
    return losses
class ActivationGradientHook:
"""Hook to capture activations and gradients at each layer"""
def __init__(self):
self.activations: List[torch.Tensor] = []
self.gradients: List[torch.Tensor] = []
self.handles = []
def register_hooks(self, model: nn.Module):
"""Register forward and backward hooks on each layer"""
for layer in model.layers:
handle_fwd = layer.register_forward_hook(self._forward_hook)
handle_bwd = layer.register_full_backward_hook(self._backward_hook)
self.handles.extend([handle_fwd, handle_bwd])
def _forward_hook(self, module, input, output):
self.activations.append(output.detach().clone())
def _backward_hook(self, module, grad_input, grad_output):
self.gradients.append(grad_output[0].detach().clone())
def clear(self):
self.activations = []
self.gradients = []
def remove_hooks(self):
for handle in self.handles:
handle.remove()
self.handles = []
def get_activation_stats(self) -> Tuple[List[float], List[float]]:
"""Get mean and std of activations for each layer"""
means = [act.mean().item() for act in self.activations]
stds = [act.std().item() for act in self.activations]
return means, stds
def get_gradient_norms(self) -> List[float]:
"""Get L2 norm of gradients for each layer (in forward order)"""
norms = [grad.norm(2).item() for grad in reversed(self.gradients)]
return norms
def analyze_final_state(model: nn.Module, dim: int, batch_size: int = 64) -> Dict:
    """Run one forward/backward pass on a fresh batch and collect layer stats.

    Returns a dict with per-layer activation means/stds, gradient L2 norms
    (in forward order), and the MSE loss on the probe batch.
    """
    recorder = ActivationGradientHook()
    recorder.register_hooks(model)
    # Fresh probe batch from the same U(-1, 1) identity-task distribution.
    probe_x = torch.empty(batch_size, dim).uniform_(-1, 1)
    probe_y = probe_x.clone()
    # Forward then backward so both hook directions fire.
    model.zero_grad()
    prediction = model(probe_x)
    loss = nn.MSELoss()(prediction, probe_y)
    loss.backward()
    act_means, act_stds = recorder.get_activation_stats()
    grad_norms = recorder.get_gradient_norms()
    recorder.remove_hooks()
    return {
        'activation_means': act_means,
        'activation_stds': act_stds,
        'gradient_norms': grad_norms,
        'final_loss': loss.item()
    }
def plot_training_loss(plain_losses: List[float], res_losses: List[float], save_path: str):
    """Plot both models' training-loss curves (log scale) and save to save_path."""
    fig, ax = plt.subplots(figsize=(10, 6))
    x_axis = range(len(plain_losses))
    ax.plot(x_axis, plain_losses, label='PlainMLP (20 layers)', color='#e74c3c',
            alpha=0.8, linewidth=2)
    ax.plot(x_axis, res_losses, label='ResMLP (20 layers)', color='#3498db',
            alpha=0.8, linewidth=2)
    ax.set_xlabel('Training Steps', fontsize=12)
    ax.set_ylabel('MSE Loss', fontsize=12)
    ax.set_title('Training Loss: PlainMLP vs ResMLP (FAIR Initialization)\nIdentity Task (Y = X)', fontsize=14)
    ax.legend(fontsize=11, loc='upper right')
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')
    # Annotate the final losses in a text box.
    last_plain = plain_losses[-1]
    last_res = res_losses[-1]
    textstr = f'Final Loss:\n PlainMLP: {last_plain:.4f}\n ResMLP: {last_res:.4f}\n Improvement: {last_plain/last_res:.1f}x'
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
    ax.text(0.02, 0.02, textstr, transform=ax.transAxes, fontsize=10,
            verticalalignment='bottom', bbox=props)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[Plot] Saved training loss plot to {save_path}")
def plot_gradient_magnitudes(plain_grads: List[float], res_grads: List[float], save_path: str):
    """Plot per-layer gradient L2 norm vs depth (log scale) and save to save_path."""
    fig, ax = plt.subplots(figsize=(10, 6))
    depth = range(1, len(plain_grads) + 1)
    # Shared marker styling for both series.
    marker_kw = dict(markersize=8, linewidth=2, markeredgecolor='white', markeredgewidth=1)
    ax.plot(depth, plain_grads, 'o-', label='PlainMLP', color='#e74c3c', **marker_kw)
    ax.plot(depth, res_grads, 's-', label='ResMLP', color='#3498db', **marker_kw)
    ax.set_xlabel('Layer Depth (1 = first layer, 20 = last layer)', fontsize=12)
    ax.set_ylabel('Gradient L2 Norm (log scale)', fontsize=12)
    ax.set_title('Gradient Magnitude vs Layer Depth (Fair Initialization)', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[Plot] Saved gradient magnitude plot to {save_path}")
def plot_activation_means(plain_means: List[float], res_means: List[float], save_path: str):
    """Plot per-layer activation means vs depth and save to save_path."""
    fig, ax = plt.subplots(figsize=(10, 6))
    depth = range(1, len(plain_means) + 1)
    # Shared marker styling for both series.
    marker_kw = dict(markersize=8, linewidth=2, markeredgecolor='white', markeredgewidth=1)
    ax.plot(depth, plain_means, 'o-', label='PlainMLP', color='#e74c3c', **marker_kw)
    ax.plot(depth, res_means, 's-', label='ResMLP', color='#3498db', **marker_kw)
    # Zero reference line to make drift visible.
    ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
    ax.set_xlabel('Layer Depth', fontsize=12)
    ax.set_ylabel('Activation Mean', fontsize=12)
    ax.set_title('Activation Mean vs Layer Depth (Fair Initialization)', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[Plot] Saved activation mean plot to {save_path}")
def plot_activation_stds(plain_stds: List[float], res_stds: List[float], save_path: str):
    """Plot per-layer activation standard deviations vs depth and save to save_path."""
    fig, ax = plt.subplots(figsize=(10, 6))
    depth = range(1, len(plain_stds) + 1)
    # Shared marker styling for both series.
    marker_kw = dict(markersize=8, linewidth=2, markeredgecolor='white', markeredgewidth=1)
    ax.plot(depth, plain_stds, 'o-', label='PlainMLP', color='#e74c3c', **marker_kw)
    ax.plot(depth, res_stds, 's-', label='ResMLP', color='#3498db', **marker_kw)
    ax.set_xlabel('Layer Depth', fontsize=12)
    ax.set_ylabel('Activation Standard Deviation', fontsize=12)
    ax.set_title('Activation Std vs Layer Depth (Fair Initialization)', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[Plot] Saved activation std plot to {save_path}")
def main():
    """Run the full fair-comparison experiment end to end.

    Trains PlainMLP and ResMLP (identical initialization) on the identity
    task, analyzes gradient/activation statistics of the trained models,
    saves four plots under plots_fair/ and all metrics to results_fair.json,
    and returns the results dict.
    """
    print("=" * 70)
    print("PlainMLP vs ResMLP: FAIR COMPARISON (Identical Initialization)")
    print("=" * 70)
    # Ensure plots directory exists
    os.makedirs('plots_fair', exist_ok=True)
    # Generate synthetic data
    print("\n[1] Generating synthetic identity data...")
    X, Y = generate_identity_data(NUM_SAMPLES, HIDDEN_DIM)
    print(f" Data shape: X={X.shape}, Y={Y.shape}")
    print(f" X range: [{X.min():.3f}, {X.max():.3f}]")
    print(f" Task: Learn Y = X (identity mapping)")
    # Initialize models
    print("\n[2] Initializing models with IDENTICAL initialization...")
    print(" Both use: Kaiming He + 1/sqrt(num_layers) scaling + zero bias")
    plain_mlp = PlainMLP(HIDDEN_DIM, NUM_LAYERS)
    res_mlp = ResMLP(HIDDEN_DIM, NUM_LAYERS)
    plain_params = sum(p.numel() for p in plain_mlp.parameters())
    res_params = sum(p.numel() for p in res_mlp.parameters())
    print(f" PlainMLP parameters: {plain_params:,}")
    print(f" ResMLP parameters: {res_params:,}")
    # Verify initialization is the same
    # NOTE(review): the two models draw independent random weights, so the
    # norms show *scale* parity, not weight-for-weight equality.
    print("\n Verifying initialization parity...")
    plain_w_norm = sum(p.norm().item() for p in plain_mlp.parameters())
    res_w_norm = sum(p.norm().item() for p in res_mlp.parameters())
    print(f" PlainMLP total weight norm: {plain_w_norm:.4f}")
    print(f" ResMLP total weight norm: {res_w_norm:.4f}")
    # Train PlainMLP
    print("\n[3] Training PlainMLP...")
    plain_losses = train_model(plain_mlp, X, Y, TRAINING_STEPS, LEARNING_RATE, BATCH_SIZE)
    print(f" Final loss: {plain_losses[-1]:.6f}")
    # Train ResMLP
    print("\n[4] Training ResMLP...")
    res_losses = train_model(res_mlp, X, Y, TRAINING_STEPS, LEARNING_RATE, BATCH_SIZE)
    print(f" Final loss: {res_losses[-1]:.6f}")
    # Calculate improvement (final-loss ratio; > 1 means ResMLP did better)
    improvement = plain_losses[-1] / res_losses[-1]
    print(f"\n >>> ResMLP achieves {improvement:.1f}x lower loss than PlainMLP <<<")
    # Final state analysis: one forward/backward probe on each trained model
    print("\n[5] Analyzing final state of trained models...")
    print(" Running forward/backward pass on new random batch...")
    print(" Analyzing PlainMLP...")
    plain_stats = analyze_final_state(plain_mlp, HIDDEN_DIM)
    print(" Analyzing ResMLP...")
    res_stats = analyze_final_state(res_mlp, HIDDEN_DIM)
    # Print detailed analysis
    print("\n[6] Detailed Analysis:")
    print("\n === Loss Comparison ===")
    print(f" PlainMLP - Initial: {plain_losses[0]:.4f}, Final: {plain_losses[-1]:.4f}")
    print(f" ResMLP - Initial: {res_losses[0]:.4f}, Final: {res_losses[-1]:.4f}")
    print("\n === Gradient Flow (L2 norms) ===")
    print(f" PlainMLP - Layer 1: {plain_stats['gradient_norms'][0]:.2e}, Layer 20: {plain_stats['gradient_norms'][-1]:.2e}")
    print(f" ResMLP - Layer 1: {res_stats['gradient_norms'][0]:.2e}, Layer 20: {res_stats['gradient_norms'][-1]:.2e}")
    print("\n === Activation Statistics ===")
    print(f" PlainMLP - Std range: [{min(plain_stats['activation_stds']):.4f}, {max(plain_stats['activation_stds']):.4f}]")
    print(f" ResMLP - Std range: [{min(res_stats['activation_stds']):.4f}, {max(res_stats['activation_stds']):.4f}]")
    # Generate plots
    print("\n[7] Generating plots...")
    plot_training_loss(plain_losses, res_losses, 'plots_fair/training_loss.png')
    plot_gradient_magnitudes(plain_stats['gradient_norms'], res_stats['gradient_norms'],
                             'plots_fair/gradient_magnitude.png')
    plot_activation_means(plain_stats['activation_means'], res_stats['activation_means'],
                          'plots_fair/activation_mean.png')
    plot_activation_stds(plain_stats['activation_stds'], res_stats['activation_stds'],
                         'plots_fair/activation_std.png')
    # Save results to JSON (full loss histories plus summary ranges)
    results = {
        'config': {
            'num_layers': NUM_LAYERS,
            'hidden_dim': HIDDEN_DIM,
            'num_samples': NUM_SAMPLES,
            'training_steps': TRAINING_STEPS,
            'learning_rate': LEARNING_RATE,
            'batch_size': BATCH_SIZE,
            'initialization': 'Kaiming He + 1/sqrt(num_layers) scaling (IDENTICAL for both)'
        },
        'plain_mlp': {
            'final_loss': plain_losses[-1],
            'initial_loss': plain_losses[0],
            'loss_history': plain_losses,
            'gradient_norms': plain_stats['gradient_norms'],
            'activation_means': plain_stats['activation_means'],
            'activation_stds': plain_stats['activation_stds']
        },
        'res_mlp': {
            'final_loss': res_losses[-1],
            'initial_loss': res_losses[0],
            'loss_history': res_losses,
            'gradient_norms': res_stats['gradient_norms'],
            'activation_means': res_stats['activation_means'],
            'activation_stds': res_stats['activation_stds']
        },
        'summary': {
            'loss_improvement': improvement,
            'plain_grad_range': [min(plain_stats['gradient_norms']), max(plain_stats['gradient_norms'])],
            'res_grad_range': [min(res_stats['gradient_norms']), max(res_stats['gradient_norms'])],
            'plain_std_range': [min(plain_stats['activation_stds']), max(plain_stats['activation_stds'])],
            'res_std_range': [min(res_stats['activation_stds']), max(res_stats['activation_stds'])]
        }
    }
    with open('results_fair.json', 'w') as f:
        json.dump(results, f, indent=2)
    print("\n[8] Results saved to results_fair.json")
    print("\n" + "=" * 70)
    print("FAIR COMPARISON Experiment completed successfully!")
    print("=" * 70)
    return results
# Script entry point: run the experiment only when executed directly.
if __name__ == "__main__":
    results = main()