# PlainMLP vs ResMLP: Fair Comparison on Distant Identity Task

## Executive Summary

This experiment compares a 20-layer PlainMLP against a 20-layer ResMLP on a synthetic "Distant Identity" task (Y = X), using **identical initialization** for both models to ensure a fair comparison.

**Key Finding**: With fair initialization, the PlainMLP shows **complete gradient vanishing** (gradients at layer 1 are ~10⁻¹⁹), making it essentially untrainable. The ResMLP achieves **5.3x lower loss** and maintains healthy gradient flow throughout all layers.

---
## Experimental Setup

### Models (IDENTICAL Initialization)

| Property | PlainMLP | ResMLP |
|----------|----------|--------|
| Architecture | `x = ReLU(Linear(x))` | `x = x + ReLU(Linear(x))` |
| Layers | 20 | 20 |
| Hidden Dimension | 64 | 64 |
| Parameters | 83,200 | 83,200 |
| Weight Init | Kaiming (He) × 1/√20 | Kaiming (He) × 1/√20 |
| Bias Init | Zero | Zero |

**Critical**: Both models use **identical** weight initialization (Kaiming (He) init scaled by 1/√num_layers). The ONLY difference is the residual connection.
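A minimal PyTorch sketch of how the two models can be built from one shared block. Class and function names here are illustrative, not necessarily those used in `experiment_fair.py`:

```python
import math

import torch
import torch.nn as nn


class Block(nn.Module):
    """One layer: Linear + ReLU, optionally wrapped in a residual add."""

    def __init__(self, dim: int, num_layers: int, residual: bool):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        nn.init.kaiming_normal_(self.fc.weight)
        with torch.no_grad():
            self.fc.weight.mul_(1 / math.sqrt(num_layers))  # scale by 1/sqrt(L)
        nn.init.zeros_(self.fc.bias)
        self.residual = residual

    def forward(self, x):
        out = torch.relu(self.fc(x))
        return x + out if self.residual else out  # the ONLY difference


def make_mlp(dim=64, num_layers=20, residual=False, seed=42):
    # Reseeding with the same value gives both models identical weights.
    torch.manual_seed(seed)
    return nn.Sequential(*[Block(dim, num_layers, residual) for _ in range(num_layers)])
```

With `dim=64` and `num_layers=20` this gives 20 × (64×64 + 64) = 83,200 parameters, matching the table.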
### Training Configuration

- **Task**: Learn the identity mapping Y = X
- **Data**: 1024 vectors, dimension 64, sampled from U(-1, 1)
- **Optimizer**: Adam (lr=1e-3)
- **Batch Size**: 64
- **Training Steps**: 500
- **Loss**: MSE
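The configuration above amounts to a short training loop. A hedged sketch (the function name and exact mini-batch sampling are assumptions, not taken from `experiment_fair.py`):

```python
import torch


def train(model, steps=500, batch_size=64, dim=64, n=1024, lr=1e-3, seed=42):
    torch.manual_seed(seed)
    X = torch.rand(n, dim) * 2 - 1                 # 1024 vectors from U(-1, 1)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.MSELoss()
    for _ in range(steps):
        batch = X[torch.randint(0, n, (batch_size,))]
        loss = loss_fn(model(batch), batch)        # identity target: Y = X
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss.item()
```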
---

## Results

### 1. Training Loss Comparison



| Metric | PlainMLP | ResMLP |
|--------|----------|--------|
| Initial Loss | 0.333 | 13.826 |
| Final Loss | 0.333 | 0.063 |
| Loss Reduction | **0%** | **99.5%** |
| Improvement | - | **5.3x better** |

**Key Observation**: PlainMLP shows **zero learning**: the loss stays flat at ~0.33 throughout training. ResMLP starts with a higher loss (due to the residual branches accumulating at initialization) but rapidly converges to 0.063.
### 2. Gradient Flow Analysis



| Layer | PlainMLP Gradient | ResMLP Gradient |
|-------|-------------------|-----------------|
| Layer 1 (earliest) | **8.65 × 10⁻¹⁹** | 3.78 × 10⁻³ |
| Layer 10 (middle) | ~10⁻¹⁰ | ~2.5 × 10⁻³ |
| Layer 20 (last) | 6.61 × 10⁻³ | 1.91 × 10⁻³ |

**Critical Finding**: PlainMLP gradients at layer 1 are essentially **zero** (10⁻¹⁹ is numerical noise). This is the **vanishing gradient problem** in its most extreme form: the network cannot learn because gradients never reach the early layers.

ResMLP maintains gradients in the 10⁻³ range across all layers, which is healthy for learning.
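The per-layer magnitudes in the table can be measured with a single backward pass. A sketch (the helper name is an assumption; `experiment_fair.py` may collect these differently):

```python
import torch


def grad_norms(model, x):
    """Mean absolute gradient of each weight matrix after one backward pass."""
    model.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), x)  # identity-task loss
    loss.backward()
    return [p.grad.abs().mean().item()
            for name, p in model.named_parameters() if name.endswith("weight")]
```

Comparing the first and last entries of the returned list reproduces the layer-1 vs. layer-20 comparison above.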
### 3. Activation Statistics



| Metric | PlainMLP | ResMLP |
|--------|----------|--------|
| Std Range (across layers) | [0.0000, 0.1795] | [0.1348, 0.1767] |
| Layer 20 Std | ~0 | 0.135 |

**Key Observation**: PlainMLP activations **collapse to zero** in later layers; the signal is completely lost by the time it reaches the output. ResMLP maintains stable activation statistics throughout.


---

## Why This Happens

### PlainMLP: Multiplicative Gradient Path

In PlainMLP, gradients must flow through **all 20 layers multiplicatively**:

```
∂L/∂x₁ = ∂L/∂x₂₀ × ∂x₂₀/∂x₁₉ × ... × ∂x₂/∂x₁
```

With small weights (scaled by 1/√20 ≈ 0.224), each multiplication shrinks the gradient. After 20 layers:

- Gradient scale ≈ (0.224)²⁰ ≈ 10⁻¹³ (theoretical)
- Actual: ~10⁻¹⁹ (even worse, because ReLU additionally zeros out entries at each layer)
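The theoretical estimate can be checked directly:

```python
import math

scale = 1 / math.sqrt(20)   # per-layer weight scale, ≈ 0.224
decay = scale ** 20         # compounded over 20 multiplicative layers
# decay ≈ 9.8e-14, i.e. on the order of 10⁻¹³
```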
### ResMLP: Additive Gradient Path

In ResMLP, the identity shortcut provides a **direct gradient path**:

```
∂L/∂x₁ = ∂L/∂x₂₀ × (1 + ∂f₂₀/∂x₁₉) × ... × (1 + ∂f₂/∂x₁)
```

The "1 +" terms ensure gradients never vanish completely: even if the residual-branch gradients are small, the identity path preserves gradient flow.
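A scalar toy calculation makes the contrast between the two products concrete (the per-layer magnitude `j = 0.05` is an arbitrary illustrative assumption):

```python
j = 0.05                    # assumed small per-layer residual-branch gradient
plain_path = j ** 19        # multiplicative chain through 19 layer-to-layer factors
res_path = (1 + j) ** 19    # the same chain with the identity shortcut added
# plain_path ≈ 1.9e-25 (vanished); res_path ≈ 2.5 (healthy)
```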
---

## Conclusions

1. **Residual connections are essential for deep networks**: with identical initialization, PlainMLP is completely untrainable (0% loss reduction) while ResMLP achieves a 99.5% loss reduction.
2. **Vanishing gradients are catastrophic**: PlainMLP gradients at layer 1 are ~10⁻¹⁹, effectively zero. No amount of training can fix this.
3. **The identity shortcut is the key**: the only architectural difference is `x = f(x)` vs `x = x + f(x)`, yet it makes the difference between a dead network and a functional one.
4. **Fair comparison matters**: the previous experiment gave PlainMLP standard Kaiming init while ResMLP had scaled init. This fair comparison shows the true effect of residual connections.
---

## Reproducibility

```bash
cd projects/resmlp_comparison
python experiment_fair.py
```

All results are saved to `results_fair.json` and plots to `plots_fair/`.

---

*Experiment conducted with PyTorch, random seed 42 for reproducibility.*