# PlainMLP vs ResMLP: Fair Comparison on Distant Identity Task

## Executive Summary

This experiment compares a 20-layer PlainMLP against a 20-layer ResMLP on a synthetic "Distant Identity" task (Y = X), using **identical initialization** for both models to ensure a fair comparison.

**Key Finding**: With fair initialization, the PlainMLP shows **complete gradient vanishing** (gradients at layer 1 are ~10⁻¹⁹), making it essentially untrainable. The ResMLP achieves **5.3x lower loss** and maintains healthy gradient flow throughout all layers.

---

## Experimental Setup

### Models (IDENTICAL Initialization)

| Property | PlainMLP | ResMLP |
|----------|----------|--------|
| Architecture | `x = ReLU(Linear(x))` | `x = x + ReLU(Linear(x))` |
| Layers | 20 | 20 |
| Hidden Dimension | 64 | 64 |
| Parameters | 83,200 | 83,200 |
| Weight Init | Kaiming He × 1/√20 | Kaiming He × 1/√20 |
| Bias Init | Zero | Zero |

**Critical**: Both models use **identical** weight initialization (Kaiming He scaled by 1/√num_layers). The ONLY difference is the residual connection.

### Training Configuration

- **Task**: Learn the identity mapping Y = X
- **Data**: 1024 vectors, dimension 64, sampled from U(-1, 1)
- **Optimizer**: Adam (lr=1e-3)
- **Batch Size**: 64
- **Training Steps**: 500
- **Loss**: MSE

---

## Results

### 1. Training Loss Comparison

![Training Loss](training_loss.png)

| Metric | PlainMLP | ResMLP |
|--------|----------|--------|
| Initial Loss | 0.333 | 13.826 |
| Final Loss | 0.333 | 0.063 |
| Loss Reduction | **0%** | **99.5%** |
| Improvement | - | **5.3x better** |

**Key Observation**: The PlainMLP shows **zero learning**: its loss stays flat at ~0.33 throughout training. The ResMLP starts with a higher loss (due to accumulated residuals) but rapidly converges to 0.063.
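The flat PlainMLP loss of ~0.333 is no coincidence: for X ~ U(-1, 1), E[X²] = 1/3 ≈ 0.333, which is exactly the MSE of a model that always predicts 0. With the scaled initialization, the plain stack's activations shrink layer by layer until the output collapses to the zero biases. A minimal NumPy sketch of the forward pass (an illustrative reimplementation, not the experiment code; it assumes fan-in Kaiming Gaussian weights):

```python
import numpy as np

rng = np.random.default_rng(42)
dim, n_layers = 64, 20

# Kaiming He (fan-in) init for ReLU, scaled by 1/sqrt(num_layers),
# mirroring the setup described above; biases are zero.
scale = np.sqrt(2.0 / dim) / np.sqrt(n_layers)
weights = [rng.normal(0.0, scale, size=(dim, dim)) for _ in range(n_layers)]

x = rng.uniform(-1.0, 1.0, size=(1024, dim))  # data ~ U(-1, 1)
h = x
for W in weights:
    h = np.maximum(h @ W, 0.0)  # plain layer: x = ReLU(Linear(x))

mse = np.mean((h - x) ** 2)
print(f"output std: {h.std():.2e}")           # activations collapse toward 0
print(f"MSE vs identity target: {mse:.3f}")   # ≈ E[X^2] = 1/3 ≈ 0.333
```

Since the untrained network already outputs ≈0 and gradients cannot reach the early layers, the loss has nowhere to go from 0.333.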
### 2. Gradient Flow Analysis

![Gradient Magnitude](gradient_magnitude.png)

| Layer | PlainMLP Gradient | ResMLP Gradient |
|-------|-------------------|-----------------|
| Layer 1 (earliest) | **8.65 × 10⁻¹⁹** | 3.78 × 10⁻³ |
| Layer 10 (middle) | ~10⁻¹⁰ | ~2.5 × 10⁻³ |
| Layer 20 (last) | 6.61 × 10⁻³ | 1.91 × 10⁻³ |

**Critical Finding**: PlainMLP gradients at layer 1 are essentially **zero** (10⁻¹⁹ is numerical noise). This is the **vanishing gradient problem** in its most extreme form: the network cannot learn because gradients never reach the early layers. The ResMLP maintains gradients in the 10⁻³ range across all layers, which is healthy for learning.

### 3. Activation Statistics

![Activation Std](activation_std.png)

| Metric | PlainMLP | ResMLP |
|--------|----------|--------|
| Std Range | [0.0000, 0.1795] | [0.1348, 0.1767] |
| Layer 20 Std | ~0 | 0.135 |

**Key Observation**: PlainMLP activations **collapse to zero** in later layers; the signal is completely lost by the time it reaches the output. The ResMLP maintains stable activation statistics throughout.

![Activation Mean](activation_mean.png)

---

## Why This Happens

### PlainMLP: Multiplicative Gradient Path

In the PlainMLP, gradients must flow through **all 20 layers multiplicatively**:

```
∂L/∂x₁ = ∂L/∂x₂₀ × ∂x₂₀/∂x₁₉ × ... × ∂x₂/∂x₁
```

With small weights (scaled by 1/√20 ≈ 0.224), each multiplication shrinks the gradient. After 20 layers:

- Gradient scale ≈ (0.224)²⁰ ≈ 10⁻¹³ (theoretical)
- Actual: 10⁻¹⁹ (even worse, since ReLU additionally zeroes gradients on inactive units)

### ResMLP: Additive Gradient Path

In the ResMLP, the identity shortcut provides a **direct gradient path**:

```
∂L/∂x₁ = ∂L/∂x₂₀ × (1 + ∂f₂₀/∂x₁₉) × ... × (1 + ∂f₂/∂x₁)
```

The "1 +" terms ensure the gradients never vanish completely: even if the residual-branch gradients are small, the identity path preserves gradient flow.

---

## Conclusions
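The two formulas above can be checked with back-of-envelope arithmetic, treating each layer's Jacobian as the scalar weight scale 1/√20 (a simplification: real Jacobians are matrices and ReLU masks vary, but the orders of magnitude match the measurements):

```python
# Back-of-envelope comparison of the two gradient paths.
n_layers = 20
w = 1 / n_layers ** 0.5              # per-layer weight scale, 1/sqrt(20) ≈ 0.224

# Plain path: the gradient is a pure product of 20 small Jacobian factors.
plain = w ** n_layers
print(f"plain path: {plain:.1e}")    # ≈ 10^-13, the theoretical estimate above

# Residual path: each factor is (1 + branch term). Even in the worst case
# where every branch term is 0 (all ReLUs inactive), the identity keeps
# the product at exactly 1, so gradients still reach layer 1.
residual_worst_case = (1 + 0.0) ** n_layers
print(f"residual worst case: {residual_worst_case:.1f}")
```

This is why the measured PlainMLP gradient at layer 1 (~10⁻¹⁹) is catastrophic while the ResMLP stays in the 10⁻³ range: the plain path multiplies 20 contractions, whereas the residual path is lower-bounded by the identity.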
1. **Residual connections are essential for deep networks**: With identical initialization, the PlainMLP is completely untrainable (0% loss reduction) while the ResMLP achieves a 99.5% loss reduction.
2. **Vanishing gradients are catastrophic**: PlainMLP gradients at layer 1 are ~10⁻¹⁹, effectively zero. No amount of training can fix this.
3. **The identity shortcut is the key**: The only architectural difference is `x = f(x)` vs `x = x + f(x)`, yet it makes the difference between a dead network and a functional one.
4. **Fair comparison matters**: The previous experiment gave the PlainMLP standard Kaiming initialization while the ResMLP had the scaled initialization. This fair comparison shows the true benefit of residual connections.

---

## Reproducibility

```bash
cd projects/resmlp_comparison
python experiment_fair.py
```

All results are saved to `results_fair.json` and plots to `plots_fair/`.

---

*Experiment conducted with PyTorch, random seed 42 for reproducibility.*