"""
PlainMLP vs ResMLP Comparison - FAIR INITIALIZATION VERSION

This version uses the same initialization scheme for both models to ensure a
fair comparison. Both use:
- Kaiming He initialization
- Weight scaling by 1/sqrt(num_layers)
- Zero bias initialization

The ONLY architectural difference is the residual connection:
x = x + f(x) vs x = f(x)

Note: "identical" refers to the initialization *scheme*. The two models are
constructed one after the other from the same seeded RNG stream, so their
actual weight values are independent draws, not copies of each other.
"""

import json
import os
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Configuration
NUM_LAYERS = 20
HIDDEN_DIM = 64
NUM_SAMPLES = 1024
TRAINING_STEPS = 500
LEARNING_RATE = 1e-3
BATCH_SIZE = 64

print(f"[Config] Layers: {NUM_LAYERS}, Hidden Dim: {HIDDEN_DIM}")
print(f"[Config] Samples: {NUM_SAMPLES}, Steps: {TRAINING_STEPS}, LR: {LEARNING_RATE}")
print(f"[Config] FAIR COMPARISON: Both models use identical initialization")

# Colors shared by every plot so the two models are always drawn consistently.
_PLAIN_COLOR = '#e74c3c'
_RES_COLOR = '#3498db'


def _build_scaled_layers(dim: int, num_layers: int) -> nn.ModuleList:
    """Build ``num_layers`` square Linear layers with the shared init scheme.

    Each layer is created, Kaiming-normal initialized (fan_in, ReLU gain),
    scaled by ``1/sqrt(num_layers)``, and given a zero bias. The per-layer
    order of operations is kept as create -> kaiming -> scale -> zero bias so
    the RNG stream is consumed exactly as when the loop was written inline.

    Args:
        dim: Input and output width of every layer.
        num_layers: Number of layers to create.

    Returns:
        An ``nn.ModuleList`` holding the initialized layers.
    """
    layers = nn.ModuleList()
    # Guarded so that a zero-layer model does not evaluate 1/sqrt(0); the
    # scale is never applied in that case because the loop body does not run.
    scale = 1.0 / float(np.sqrt(num_layers)) if num_layers > 0 else 1.0
    for _ in range(num_layers):
        layer = nn.Linear(dim, dim)
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
        # In-place scaling without autograd tracking (replaces `.data *=`).
        with torch.no_grad():
            layer.weight.mul_(scale)
        nn.init.zeros_(layer.bias)
        layers.append(layer)
    return layers


class PlainMLP(nn.Module):
    """Plain MLP: x = ReLU(Linear(x)) for each layer.

    Uses the same initialization scheme as ResMLP:
    - Kaiming He initialization
    - Weight scaling by 1/sqrt(num_layers)
    - Zero bias
    """

    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = _build_scaled_layers(dim, num_layers)
        self.num_layers = num_layers
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every layer followed by ReLU, with NO residual connection."""
        for layer in self.layers:
            x = self.activation(layer(x))
        return x


class ResMLP(nn.Module):
    """Residual MLP: x = x + ReLU(Linear(x)) for each layer.

    Uses the same initialization scheme as PlainMLP:
    - Kaiming He initialization
    - Weight scaling by 1/sqrt(num_layers)
    - Zero bias
    """

    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = _build_scaled_layers(dim, num_layers)
        self.num_layers = num_layers
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every layer WITH a residual (skip) connection around it."""
        for layer in self.layers:
            x = x + self.activation(layer(x))
        return x


def generate_identity_data(num_samples: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate synthetic data where Y = X, with X ~ U(-1, 1).

    Args:
        num_samples: Number of rows to draw.
        dim: Number of features per row.

    Returns:
        ``(X, Y)`` where ``Y`` is an independent copy (clone) of ``X``.
    """
    X = torch.empty(num_samples, dim).uniform_(-1, 1)
    Y = X.clone()
    return X, Y


def train_model(model: nn.Module, X: torch.Tensor, Y: torch.Tensor,
                steps: int, lr: float, batch_size: int) -> List[float]:
    """Train ``model`` with Adam on MSE and record the loss at each step.

    Args:
        model: Network to optimize in place.
        X: Input tensor; rows are sampled along dimension 0.
        Y: Target tensor, indexed with the same row indices as ``X``.
        steps: Number of optimizer steps to take.
        lr: Adam learning rate.
        batch_size: Rows drawn (with replacement) per step.

    Returns:
        The per-step training loss as Python floats, one entry per step.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    losses: List[float] = []
    num_samples = X.shape[0]

    for step in range(steps):
        # Random batch sampling (torch.randint samples with replacement)
        indices = torch.randint(0, num_samples, (batch_size,))
        batch_x = X[indices]
        batch_y = Y[indices]

        # Forward pass
        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y)

        # Backward pass
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        losses.append(loss_value)
        if step % 100 == 0:
            print(f" Step {step}/{steps}, Loss: {loss_value:.6f}")

    return losses


class ActivationGradientHook:
    """Capture per-layer outputs and output-gradients via module hooks.

    Forward hooks append each hooked layer's output (detached copy) to
    ``activations``; full backward hooks append the gradient with respect to
    each layer's output to ``gradients`` in the order the hooks fire.
    """

    def __init__(self):
        self.activations: List[torch.Tensor] = []
        self.gradients: List[torch.Tensor] = []
        self.handles: List = []

    def register_hooks(self, model: nn.Module):
        """Register forward and backward hooks on every entry of ``model.layers``."""
        for layer in model.layers:
            handle_fwd = layer.register_forward_hook(self._forward_hook)
            handle_bwd = layer.register_full_backward_hook(self._backward_hook)
            self.handles.extend([handle_fwd, handle_bwd])

    def _forward_hook(self, module, inputs, output):
        # Detach + clone so stored tensors are independent of the autograd graph.
        self.activations.append(output.detach().clone())

    def _backward_hook(self, module, grad_input, grad_output):
        # grad_output[0] is the gradient w.r.t. the layer's (first) output.
        self.gradients.append(grad_output[0].detach().clone())

    def clear(self):
        """Drop all captured activations and gradients (hooks stay registered)."""
        self.activations = []
        self.gradients = []

    def remove_hooks(self):
        """Detach every registered hook from its module."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def get_activation_stats(self) -> Tuple[List[float], List[float]]:
        """Return ``(means, stds)`` of the captured activations, one per capture."""
        means = [act.mean().item() for act in self.activations]
        stds = [act.std().item() for act in self.activations]
        return means, stds

    def get_gradient_norms(self) -> List[float]:
        """Return the L2 norm of each captured gradient, in forward layer order.

        The captured list is reversed, which relies on the backward hooks
        having fired from the last layer to the first.
        """
        return [grad.norm(2).item() for grad in reversed(self.gradients)]


def analyze_final_state(model: nn.Module, dim: int, batch_size: int = 64) -> Dict:
    """Run one forward/backward pass and capture activation/gradient stats.

    A fresh batch ``X ~ U(-1, 1)`` with target ``Y = X`` is pushed through
    ``model`` while hooks record each layer's output and output-gradient.

    Args:
        model: Model exposing a ``layers`` iterable of modules to hook.
        dim: Feature width of the generated batch.
        batch_size: Number of rows in the generated batch.

    Returns:
        Dict with keys ``activation_means``, ``activation_stds``,
        ``gradient_norms`` (lists, one value per layer in forward order) and
        ``final_loss`` (float MSE on the generated batch).
    """
    hook = ActivationGradientHook()
    hook.register_hooks(model)
    try:
        # Generate new random batch
        X_test = torch.empty(batch_size, dim).uniform_(-1, 1)
        Y_test = X_test.clone()

        # Forward pass
        model.zero_grad()
        output = model(X_test)
        loss = nn.MSELoss()(output, Y_test)

        # Backward pass
        loss.backward()

        # Get statistics
        act_means, act_stds = hook.get_activation_stats()
        grad_norms = hook.get_gradient_norms()
    finally:
        # Always detach the hooks, even if the forward/backward pass raised,
        # so the model is not left with stale hooks attached.
        hook.remove_hooks()

    return {
        'activation_means': act_means,
        'activation_stds': act_stds,
        'gradient_norms': grad_norms,
        'final_loss': loss.item(),
    }


def _loss_ratio(numerator: float, denominator: float) -> float:
    """Return ``numerator / denominator``, or ``inf`` when the denominator is 0."""
    if denominator == 0:
        return float('inf')
    return numerator / denominator


def _save_and_close(fig, save_path: str) -> None:
    """Tighten the layout, write ``fig`` to ``save_path`` and close it."""
    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)


def _plot_per_layer(plain_values: List[float], res_values: List[float], save_path: str, *,
                    xlabel: str, ylabel: str, title: str,
                    log_scale: bool = False, zero_line: bool = False) -> None:
    """Draw one per-layer statistic for both models and save the figure.

    Args:
        plain_values: PlainMLP statistic, one value per layer (forward order).
        res_values: ResMLP statistic, one value per layer (forward order).
        save_path: Destination image path.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        title: Figure title.
        log_scale: Use a logarithmic y-axis when True.
        zero_line: Draw a dashed horizontal reference line at y=0 when True.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    layers = range(1, len(plain_values) + 1)

    ax.plot(layers, plain_values, 'o-', label='PlainMLP', color=_PLAIN_COLOR,
            markersize=8, linewidth=2, markeredgecolor='white', markeredgewidth=1)
    ax.plot(layers, res_values, 's-', label='ResMLP', color=_RES_COLOR,
            markersize=8, linewidth=2, markeredgecolor='white', markeredgewidth=1)
    if zero_line:
        ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5, linewidth=1)

    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    if log_scale:
        ax.set_yscale('log')

    _save_and_close(fig, save_path)


def plot_training_loss(plain_losses: List[float], res_losses: List[float], save_path: str):
    """Plot training loss curves for both models on a log-scaled y-axis.

    Args:
        plain_losses: Per-step PlainMLP training loss (must be non-empty).
        res_losses: Per-step ResMLP training loss (must be non-empty).
        save_path: Destination image path.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    steps = range(len(plain_losses))

    # Legend labels follow the configured depth instead of a literal "20".
    ax.plot(steps, plain_losses, label=f'PlainMLP ({NUM_LAYERS} layers)',
            color=_PLAIN_COLOR, alpha=0.8, linewidth=2)
    ax.plot(steps, res_losses, label=f'ResMLP ({NUM_LAYERS} layers)',
            color=_RES_COLOR, alpha=0.8, linewidth=2)

    ax.set_xlabel('Training Steps', fontsize=12)
    ax.set_ylabel('MSE Loss', fontsize=12)
    ax.set_title('Training Loss: PlainMLP vs ResMLP (FAIR Initialization)\nIdentity Task (Y = X)', fontsize=14)
    ax.legend(fontsize=11, loc='upper right')
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')

    # Add final loss annotations
    final_plain = plain_losses[-1]
    final_res = res_losses[-1]
    ratio = _loss_ratio(final_plain, final_res)

    # Text box with final results
    textstr = f'Final Loss:\n PlainMLP: {final_plain:.4f}\n ResMLP: {final_res:.4f}\n Improvement: {ratio:.1f}x'
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
    ax.text(0.02, 0.02, textstr, transform=ax.transAxes, fontsize=10,
            verticalalignment='bottom', bbox=props)

    _save_and_close(fig, save_path)
    print(f"[Plot] Saved training loss plot to {save_path}")


def plot_gradient_magnitudes(plain_grads: List[float], res_grads: List[float], save_path: str):
    """Plot gradient L2 norm vs layer depth (log-scaled y-axis)."""
    _plot_per_layer(
        plain_grads, res_grads, save_path,
        # The last-layer index comes from the data rather than a literal "20".
        xlabel=f'Layer Depth (1 = first layer, {len(plain_grads)} = last layer)',
        ylabel='Gradient L2 Norm (log scale)',
        title='Gradient Magnitude vs Layer Depth (Fair Initialization)',
        log_scale=True,
    )
    print(f"[Plot] Saved gradient magnitude plot to {save_path}")


def plot_activation_means(plain_means: List[float], res_means: List[float], save_path: str):
    """Plot activation mean vs layer depth, with a reference line at zero."""
    _plot_per_layer(
        plain_means, res_means, save_path,
        xlabel='Layer Depth',
        ylabel='Activation Mean',
        title='Activation Mean vs Layer Depth (Fair Initialization)',
        zero_line=True,
    )
    print(f"[Plot] Saved activation mean plot to {save_path}")


def plot_activation_stds(plain_stds: List[float], res_stds: List[float], save_path: str):
    """Plot activation standard deviation vs layer depth."""
    _plot_per_layer(
        plain_stds, res_stds, save_path,
        xlabel='Layer Depth',
        ylabel='Activation Standard Deviation',
        title='Activation Std vs Layer Depth (Fair Initialization)',
    )
    print(f"[Plot] Saved activation std plot to {save_path}")


def main():
    """Run the full experiment: train both models, analyze, plot, save JSON.

    Writes four PNG files under ``plots_fair/`` and ``results_fair.json`` in
    the current working directory.

    Returns:
        The results dictionary that is also written to ``results_fair.json``.
    """
    print("=" * 70)
    print("PlainMLP vs ResMLP: FAIR COMPARISON (Identical Initialization)")
    print("=" * 70)

    # Ensure plots directory exists
    os.makedirs('plots_fair', exist_ok=True)

    # Generate synthetic data
    print("\n[1] Generating synthetic identity data...")
    X, Y = generate_identity_data(NUM_SAMPLES, HIDDEN_DIM)
    print(f" Data shape: X={X.shape}, Y={Y.shape}")
    print(f" X range: [{X.min():.3f}, {X.max():.3f}]")
    print(" Task: Learn Y = X (identity mapping)")

    # Initialize models
    print("\n[2] Initializing models with IDENTICAL initialization...")
    print(" Both use: Kaiming He + 1/sqrt(num_layers) scaling + zero bias")
    plain_mlp = PlainMLP(HIDDEN_DIM, NUM_LAYERS)
    res_mlp = ResMLP(HIDDEN_DIM, NUM_LAYERS)

    plain_params = sum(p.numel() for p in plain_mlp.parameters())
    res_params = sum(p.numel() for p in res_mlp.parameters())
    print(f" PlainMLP parameters: {plain_params:,}")
    print(f" ResMLP parameters: {res_params:,}")

    # Report summed parameter norms. This is informational only: the two
    # models hold independent random draws, so the totals are not compared.
    print("\n Verifying initialization parity...")
    plain_w_norm = sum(p.norm().item() for p in plain_mlp.parameters())
    res_w_norm = sum(p.norm().item() for p in res_mlp.parameters())
    print(f" PlainMLP total weight norm: {plain_w_norm:.4f}")
    print(f" ResMLP total weight norm: {res_w_norm:.4f}")

    # Train PlainMLP
    print("\n[3] Training PlainMLP...")
    plain_losses = train_model(plain_mlp, X, Y, TRAINING_STEPS, LEARNING_RATE, BATCH_SIZE)
    print(f" Final loss: {plain_losses[-1]:.6f}")

    # Train ResMLP
    print("\n[4] Training ResMLP...")
    res_losses = train_model(res_mlp, X, Y, TRAINING_STEPS, LEARNING_RATE, BATCH_SIZE)
    print(f" Final loss: {res_losses[-1]:.6f}")

    # Calculate improvement (inf instead of a crash if ResMLP loss is exactly 0)
    improvement = _loss_ratio(plain_losses[-1], res_losses[-1])
    print(f"\n >>> ResMLP achieves {improvement:.1f}x lower loss than PlainMLP <<<")

    # Final state analysis
    print("\n[5] Analyzing final state of trained models...")
    print(" Running forward/backward pass on new random batch...")

    print(" Analyzing PlainMLP...")
    plain_stats = analyze_final_state(plain_mlp, HIDDEN_DIM)

    print(" Analyzing ResMLP...")
    res_stats = analyze_final_state(res_mlp, HIDDEN_DIM)

    plain_grads = plain_stats['gradient_norms']
    res_grads = res_stats['gradient_norms']
    plain_stds = plain_stats['activation_stds']
    res_stds = res_stats['activation_stds']

    # Print detailed analysis
    print("\n[6] Detailed Analysis:")
    print("\n === Loss Comparison ===")
    print(f" PlainMLP - Initial: {plain_losses[0]:.4f}, Final: {plain_losses[-1]:.4f}")
    print(f" ResMLP - Initial: {res_losses[0]:.4f}, Final: {res_losses[-1]:.4f}")

    print("\n === Gradient Flow (L2 norms) ===")
    # The last-layer label tracks NUM_LAYERS instead of a literal "20".
    print(f" PlainMLP - Layer 1: {plain_grads[0]:.2e}, Layer {NUM_LAYERS}: {plain_grads[-1]:.2e}")
    print(f" ResMLP - Layer 1: {res_grads[0]:.2e}, Layer {NUM_LAYERS}: {res_grads[-1]:.2e}")

    print("\n === Activation Statistics ===")
    print(f" PlainMLP - Std range: [{min(plain_stds):.4f}, {max(plain_stds):.4f}]")
    print(f" ResMLP - Std range: [{min(res_stds):.4f}, {max(res_stds):.4f}]")

    # Generate plots
    print("\n[7] Generating plots...")
    plot_training_loss(plain_losses, res_losses, 'plots_fair/training_loss.png')
    plot_gradient_magnitudes(plain_grads, res_grads, 'plots_fair/gradient_magnitude.png')
    plot_activation_means(plain_stats['activation_means'], res_stats['activation_means'],
                          'plots_fair/activation_mean.png')
    plot_activation_stds(plain_stds, res_stds, 'plots_fair/activation_std.png')

    # Save results to JSON
    results = {
        'config': {
            'num_layers': NUM_LAYERS,
            'hidden_dim': HIDDEN_DIM,
            'num_samples': NUM_SAMPLES,
            'training_steps': TRAINING_STEPS,
            'learning_rate': LEARNING_RATE,
            'batch_size': BATCH_SIZE,
            'initialization': 'Kaiming He + 1/sqrt(num_layers) scaling (IDENTICAL for both)',
        },
        'plain_mlp': {
            'final_loss': plain_losses[-1],
            'initial_loss': plain_losses[0],
            'loss_history': plain_losses,
            'gradient_norms': plain_grads,
            'activation_means': plain_stats['activation_means'],
            'activation_stds': plain_stds,
        },
        'res_mlp': {
            'final_loss': res_losses[-1],
            'initial_loss': res_losses[0],
            'loss_history': res_losses,
            'gradient_norms': res_grads,
            'activation_means': res_stats['activation_means'],
            'activation_stds': res_stds,
        },
        'summary': {
            'loss_improvement': improvement,
            'plain_grad_range': [min(plain_grads), max(plain_grads)],
            'res_grad_range': [min(res_grads), max(res_grads)],
            'plain_std_range': [min(plain_stds), max(plain_stds)],
            'res_std_range': [min(res_stds), max(res_stds)],
        },
    }

    with open('results_fair.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)
    print("\n[8] Results saved to results_fair.json")

    print("\n" + "=" * 70)
    print("FAIR COMPARISON Experiment completed successfully!")
    print("=" * 70)

    return results


if __name__ == "__main__":
    results = main()