# PlainMLP vs ResMLP: Fair Comparison on Distant Identity Task
## Executive Summary
This experiment compares a 20-layer PlainMLP against a 20-layer ResMLP on a synthetic "Distant Identity" task (Y = X), using **identical initialization** for both models to ensure a fair comparison.
**Key Finding**: With fair initialization, the PlainMLP shows **complete gradient vanishing** (gradients at layer 1 are ~10⁻¹⁹), making it essentially untrainable. The ResMLP reaches a **5.3x lower final loss** and maintains healthy gradient flow through all layers.
---
## Experimental Setup
### Models (IDENTICAL Initialization)
| Property | PlainMLP | ResMLP |
|----------|----------|--------|
| Architecture | `x = ReLU(Linear(x))` | `x = x + ReLU(Linear(x))` |
| Layers | 20 | 20 |
| Hidden Dimension | 64 | 64 |
| Parameters | 83,200 | 83,200 |
| Weight Init | Kaiming He × 1/√20 | Kaiming He × 1/√20 |
| Bias Init | Zero | Zero |
**Critical**: Both models use **identical** weight initialization (Kaiming He scaled by 1/√num_layers). The ONLY difference is the residual connection.
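The two architectures and their shared initialization can be sketched as follows. This is a minimal reconstruction from the table above, not the report's actual `experiment_fair.py`: the class and function names are illustrative, and the choice of the normal (vs uniform) Kaiming variant is an assumption.

```python
# Sketch of both models with identical initialization (names illustrative;
# the Kaiming-normal variant is an assumption not stated in the report).
import math
import torch
import torch.nn as nn

NUM_LAYERS, HIDDEN = 20, 64

def make_layers(seed=42):
    """20 x Linear(64, 64): Kaiming He init scaled by 1/sqrt(20), zero bias."""
    torch.manual_seed(seed)  # same seed -> bit-identical weights in both models
    layers = nn.ModuleList([nn.Linear(HIDDEN, HIDDEN) for _ in range(NUM_LAYERS)])
    for lin in layers:
        nn.init.kaiming_normal_(lin.weight, nonlinearity='relu')
        lin.weight.data.mul_(1 / math.sqrt(NUM_LAYERS))  # scale by 1/sqrt(20)
        nn.init.zeros_(lin.bias)
    return layers

class PlainMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = make_layers()

    def forward(self, x):
        for lin in self.layers:
            x = torch.relu(lin(x))       # x = ReLU(Linear(x))
        return x

class ResMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = make_layers()

    def forward(self, x):
        for lin in self.layers:
            x = x + torch.relu(lin(x))   # x = x + ReLU(Linear(x))
        return x

plain, res = PlainMLP(), ResMLP()
print(sum(p.numel() for p in plain.parameters()))  # → 83200 = 20 * (64*64 + 64)
```

With a fixed seed inside `make_layers`, both models start from literally the same weights, so the residual connection is the only difference.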
### Training Configuration
- **Task**: Learn identity mapping Y = X
- **Data**: 1024 vectors, dimension 64, sampled from U(-1, 1)
- **Optimizer**: Adam (lr=1e-3)
- **Batch Size**: 64
- **Training Steps**: 500
- **Loss**: MSE
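The configuration above can be sketched as a short training loop. For brevity and self-containment, a single `nn.Linear` stands in for the 20-layer models; the report's actual script is `experiment_fair.py`.

```python
# Minimal sketch of the training setup (stand-in one-layer model; the report
# trains the 20-layer PlainMLP and ResMLP with the same configuration).
import torch
import torch.nn as nn

torch.manual_seed(42)                      # seed used in the report
X = torch.rand(1024, 64) * 2 - 1           # 1024 vectors, dim 64, ~U(-1, 1)
Y = X.clone()                              # identity task: Y = X

model = nn.Linear(64, 64)                  # stand-in for the 20-layer MLPs
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

for step in range(500):                    # 500 training steps
    idx = torch.randint(0, len(X), (64,))  # random batch of size 64
    loss = loss_fn(model(X[idx]), Y[idx])
    opt.zero_grad()
    loss.backward()
    opt.step()

print(f"final batch loss: {loss.item():.3f}")
```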
---
## Results
### 1. Training Loss Comparison
![Training Loss](training_loss.png)
| Metric | PlainMLP | ResMLP |
|--------|----------|--------|
| Initial Loss | 0.333 | 13.826 |
| Final Loss | 0.333 | 0.063 |
| Loss Reduction | **0%** | **99.5%** |
| Improvement | - | **5.3x lower final loss** |
**Key Observation**: PlainMLP shows **zero learning**: the loss stays flat at ~0.33 throughout training. ResMLP starts from a higher loss (the accumulated residual branches inflate its initial output scale) but rapidly converges to 0.063.
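The quoted improvement figures follow directly from the table, as a quick arithmetic check confirms:

```python
# Sanity-check of the ratios quoted in the table above.
plain_final, res_initial, res_final = 0.333, 13.826, 0.063

ratio = plain_final / res_final                        # ResMLP's final-loss advantage
reduction = 100 * (res_initial - res_final) / res_initial
print(f"{ratio:.1f}x lower, {reduction:.1f}% reduction")  # → 5.3x lower, 99.5% reduction
```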
### 2. Gradient Flow Analysis
![Gradient Magnitude](gradient_magnitude.png)
| Layer | PlainMLP Gradient | ResMLP Gradient |
|-------|-------------------|-----------------|
| Layer 1 (earliest) | **8.65 × 10⁻¹⁹** | 3.78 × 10⁻³ |
| Layer 10 (middle) | ~10⁻¹⁰ | ~2.5 × 10⁻³ |
| Layer 20 (last) | 6.61 × 10⁻³ | 1.91 × 10⁻³ |
**Critical Finding**: PlainMLP gradients at layer 1 are essentially **zero** (10⁻¹⁹ is numerical noise). This is the **vanishing gradient problem** in its most extreme form. The network cannot learn because gradients don't reach early layers.
ResMLP maintains gradients in the 10⁻³ range across all layers, a healthy regime for learning.
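One way such per-layer gradient magnitudes can be probed (the report's exact measurement method is assumed) is a single backward pass through a plain 20-layer stack, recording the gradient that reaches each layer's output:

```python
# Probe gradient flow in a plain stack at initialization: retain the gradient
# at every layer's activation and compare the first and last layers.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)
layers = nn.ModuleList([nn.Linear(64, 64) for _ in range(20)])
for lin in layers:
    nn.init.kaiming_normal_(lin.weight, nonlinearity='relu')
    lin.weight.data.mul_(1 / math.sqrt(20))  # Kaiming He scaled by 1/sqrt(20)
    nn.init.zeros_(lin.bias)

x = torch.rand(64, 64) * 2 - 1
y = x.clone()

acts = []
out = x
for lin in layers:                 # plain path: x = ReLU(Linear(x))
    out = torch.relu(lin(out))
    out.retain_grad()              # keep dL/d(activation) for inspection
    acts.append(out)

F.mse_loss(out, y).backward()
g_first = acts[0].grad.abs().mean().item()
g_last = acts[-1].grad.abs().mean().item()
print(f"layer 1: {g_first:.2e}  layer 20: {g_last:.2e}")
# g_first sits many orders of magnitude below g_last, mirroring the table above.
```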
### 3. Activation Statistics
![Activation Std](activation_std.png)
| Metric | PlainMLP | ResMLP |
|--------|----------|--------|
| Std Range | [0.0000, 0.1795] | [0.1348, 0.1767] |
| Layer 20 Std | ~0 | 0.135 |
**Key Observation**: PlainMLP activations **collapse to zero** in later layers. The signal is completely lost by the time it reaches the output. ResMLP maintains stable activation statistics throughout.
![Activation Mean](activation_mean.png)
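The collapse is visible in a single forward pass per model. The sketch below uses the same assumed initialization as above and records each layer's output std; exact values will differ from the report's figures.

```python
# Record per-layer activation std for a plain vs residual forward pass.
import math
import torch
import torch.nn as nn

def make_layers(seed=42):
    """20 x Linear(64, 64), Kaiming He scaled by 1/sqrt(20), zero bias."""
    torch.manual_seed(seed)
    layers = [nn.Linear(64, 64) for _ in range(20)]
    for lin in layers:
        nn.init.kaiming_normal_(lin.weight, nonlinearity='relu')
        lin.weight.data.mul_(1 / math.sqrt(20))
        nn.init.zeros_(lin.bias)
    return layers

torch.manual_seed(0)
x = torch.rand(256, 64) * 2 - 1        # inputs ~ U(-1, 1)

def layer_stds(residual):
    out, stds = x, []
    for lin in make_layers():
        branch = torch.relu(lin(out))
        out = out + branch if residual else branch
        stds.append(out.std().item())
    return stds

plain_stds, res_stds = layer_stds(False), layer_stds(True)
print(f"PlainMLP layer-20 std: {plain_stds[-1]:.2e}")  # collapses toward 0
print(f"ResMLP   layer-20 std: {res_stds[-1]:.2f}")    # no collapse
```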
---
## Why This Happens
### PlainMLP: Multiplicative Gradient Path
In PlainMLP, gradients must flow through **all 20 layers multiplicatively**:
```
∂L/∂x₁ = ∂L/∂x₂₀ × ∂x₂₀/∂x₁₉ × ... × ∂x₂/∂x₁
```
With small weights (scaled by 1/√20 ≈ 0.224), each multiplication shrinks the gradient. After 20 layers:
- Gradient scale ≈ (0.224)²⁰ ≈ 10⁻¹³ (theoretical)
- Actual: 10⁻¹⁹ (even worse due to ReLU zeros)
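The theoretical estimate is a one-line computation: twenty factors of 1/√20.

```python
# Numeric check of the theoretical gradient scale quoted above.
scale = (1 / 20 ** 0.5) ** 20   # = 20**-10
print(f"{scale:.1e}")           # → 9.8e-14, i.e. ~10^-13
```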
### ResMLP: Additive Gradient Path
In ResMLP, the identity shortcut provides a **direct gradient path**:
```
∂L/∂x₁ = ∂L/∂x₂₀ × (1 + ∂f₂₀/∂x₁₉) × ... × (1 + ∂f₂/∂x₁)
```
The "1 +" terms ensure gradients never vanish completely. Even if the residual branch gradients are small, the identity path preserves gradient flow.
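A scalar caricature makes the contrast concrete. The 0.224 factor matches the 1/√20 weight scaling; real per-layer Jacobians vary in sign and scale, so this only illustrates the trend:

```python
# Toy chain-rule products over 20 layers: w per layer vs (1 + w) per layer.
w = 0.224                        # illustrative per-layer factor (~1/sqrt(20))
plain_path = w ** 20             # purely multiplicative: vanishes
residual_path = (1 + w) ** 20    # the "1 +" keeps the product away from zero
print(f"plain: {plain_path:.1e}, residual: {residual_path:.1f}")
```

In the real network the residual product does not actually grow like this toy, since the branch Jacobians fluctuate; the point is only that the additive identity term prevents collapse to zero.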
---
## Conclusions
1. **Residual connections are essential for deep networks**: With identical initialization, PlainMLP is completely untrainable (0% loss reduction) while ResMLP achieves 99.5% loss reduction.
2. **Vanishing gradients are catastrophic**: PlainMLP gradients at layer 1 are ~10⁻¹⁹, effectively zero. No amount of training can fix this.
3. **The identity shortcut is the key**: The only architectural difference is `x = f(x)` vs `x = x + f(x)`, yet this makes the difference between a dead network and a functional one.
4. **Fair comparison matters**: The previous experiment gave PlainMLP standard Kaiming init while ResMLP had scaled init. This fair comparison shows the true power of residual connections.
---
## Reproducibility
```bash
cd projects/resmlp_comparison
python experiment_fair.py
```
All results are saved to `results_fair.json` and plots to `plots_fair/`.
---
*Experiment conducted with PyTorch, random seed 42 for reproducibility.*