
# PlainMLP vs ResMLP: Fair Comparison on Distant Identity Task

## Executive Summary

This experiment compares a 20-layer PlainMLP against a 20-layer ResMLP on a synthetic "Distant Identity" task (Y = X), using identical initialization for both models to ensure a fair comparison.

**Key Finding:** With fair initialization, the PlainMLP shows complete gradient vanishing (gradients at layer 1 are ~10⁻¹⁹), making it essentially untrainable. The ResMLP achieves 5.3× lower final loss and maintains healthy gradient flow through all layers.


## Experimental Setup

### Models (Identical Initialization)

| Property | PlainMLP | ResMLP |
|---|---|---|
| Architecture | `x = ReLU(Linear(x))` | `x = x + ReLU(Linear(x))` |
| Layers | 20 | 20 |
| Hidden Dimension | 64 | 64 |
| Parameters | 83,200 | 83,200 |
| Weight Init | Kaiming He × 1/√20 | Kaiming He × 1/√20 |
| Bias Init | Zero | Zero |

**Critical:** Both models use identical weight initialization (Kaiming He scaled by 1/√num_layers). The **only** difference is the residual connection.
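The two architectures and the shared scaled-Kaiming initialization can be sketched as follows. This is a minimal PyTorch sketch; the class and helper names are illustrative, not taken from `experiment_fair.py`.

```python
import math
import torch
import torch.nn as nn

NUM_LAYERS, HIDDEN = 20, 64

def scaled_kaiming_(linear: nn.Linear, num_layers: int) -> None:
    # Kaiming He init for ReLU, then scale weights by 1/sqrt(num_layers);
    # biases are zeroed, as in the table above.
    nn.init.kaiming_normal_(linear.weight, nonlinearity="relu")
    with torch.no_grad():
        linear.weight.mul_(1.0 / math.sqrt(num_layers))
    nn.init.zeros_(linear.bias)

class PlainMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(HIDDEN, HIDDEN) for _ in range(NUM_LAYERS))
        for layer in self.layers:
            scaled_kaiming_(layer, NUM_LAYERS)

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))       # x = ReLU(Linear(x))
        return x

class ResMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(HIDDEN, HIDDEN) for _ in range(NUM_LAYERS))
        for layer in self.layers:
            scaled_kaiming_(layer, NUM_LAYERS)

    def forward(self, x):
        for layer in self.layers:
            x = x + torch.relu(layer(x))   # x = x + ReLU(Linear(x))
        return x
```

Each of the 20 layers contributes 64 × 64 weights plus 64 biases, giving 20 × 4,160 = 83,200 parameters for both models.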

### Training Configuration

- **Task:** Learn the identity mapping Y = X
- **Data:** 1024 vectors, dimension 64, sampled from U(-1, 1)
- **Optimizer:** Adam (lr = 1e-3)
- **Batch Size:** 64
- **Training Steps:** 500
- **Loss:** MSE
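A training loop matching this configuration could look like the sketch below. The single `nn.Linear` stand-in model is illustrative only; the report trains the 20-layer PlainMLP and ResMLP described above.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)  # seed used in the report

# 1024 vectors of dimension 64 sampled from U(-1, 1); the target is the input itself.
X = torch.empty(1024, 64).uniform_(-1.0, 1.0)

model = nn.Linear(64, 64)  # stand-in; replace with PlainMLP() or ResMLP()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

init_loss = loss_fn(model(X), X).item()
for step in range(500):                        # 500 training steps
    idx = torch.randint(0, X.size(0), (64,))   # batch size 64
    batch = X[idx]
    loss = loss_fn(model(batch), batch)        # learn Y = X
    opt.zero_grad()
    loss.backward()
    opt.step()
final_loss = loss_fn(model(X), X).item()
print(f"initial {init_loss:.3f} -> final {final_loss:.3f}")
```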

## Results

### 1. Training Loss Comparison

*(Figure: training loss curves for both models)*

| Metric | PlainMLP | ResMLP |
|---|---|---|
| Initial Loss | 0.333 | 13.826 |
| Final Loss | 0.333 | 0.063 |
| Loss Reduction | 0% | 99.5% |
| Improvement | - | 5.3× better |

**Key Observation:** PlainMLP shows zero learning: the loss stays flat at ~0.33 throughout training. ResMLP starts with a higher loss (due to the accumulated residual stream inflating the initial output scale) but rapidly converges to 0.063.

### 2. Gradient Flow Analysis

*(Figure: gradient magnitude per layer)*

| Layer | PlainMLP Gradient | ResMLP Gradient |
|---|---|---|
| Layer 1 (earliest) | 8.65 × 10⁻¹⁹ | 3.78 × 10⁻³ |
| Layer 10 (middle) | ~10⁻¹⁰ | ~2.5 × 10⁻³ |
| Layer 20 (last) | 6.61 × 10⁻³ | 1.91 × 10⁻³ |

**Critical Finding:** PlainMLP gradients at layer 1 are essentially zero (10⁻¹⁹ is numerical noise). This is the vanishing gradient problem in its most extreme form: the network cannot learn because gradients never reach the early layers.

ResMLP maintains gradients in the 10⁻³ range across all layers, a healthy range for learning.
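Per-layer gradient magnitudes like those in the table can be read directly from the weight `.grad` tensors after a single backward pass. A sketch for a plain stack follows; the batch size and random seed are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layers = nn.ModuleList(nn.Linear(64, 64) for _ in range(20))

x = torch.empty(32, 64).uniform_(-1.0, 1.0)
h = x
for layer in layers:
    h = torch.relu(layer(h))           # plain stack: x = ReLU(Linear(x))

loss = nn.functional.mse_loss(h, x)    # identity target Y = X
loss.backward()

# Mean absolute weight gradient per layer, earliest to last.
for i, layer in enumerate(layers, start=1):
    print(f"layer {i:2d}: mean |grad| = {layer.weight.grad.abs().mean().item():.2e}")
```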

### 3. Activation Statistics

*(Figure: activation standard deviation per layer)*

| Metric | PlainMLP | ResMLP |
|---|---|---|
| Std Range | [0.0000, 0.1795] | [0.1348, 0.1767] |
| Layer 20 Std | ~0 | 0.135 |

**Key Observation:** PlainMLP activations collapse to zero in later layers; the signal is completely lost by the time it reaches the output. ResMLP maintains stable activation statistics throughout.

*(Figure: activation mean per layer)*


## Why This Happens

### PlainMLP: Multiplicative Gradient Path

In PlainMLP, gradients must flow through all 20 layers multiplicatively:

∂L/∂x₁ = ∂L/∂x₂₀ × ∂x₂₀/∂x₁₉ × ... × ∂x₂/∂x₁

With small weights (scaled by 1/√20 ≈ 0.224), each multiplication shrinks the gradient. After 20 layers:

- Gradient scale ≈ (0.224)²⁰ ≈ 10⁻¹³ (theoretical)
- Actual: ~10⁻¹⁹ (even smaller, because ReLU zeros out roughly half the paths)
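The theoretical estimate follows directly from the per-layer scaling; since (1/√20)²⁰ = 20⁻¹⁰, it can be checked in one line:

```python
import math

scale = 1 / math.sqrt(20)   # ≈ 0.2236, the per-layer weight scaling
grad_scale = scale ** 20    # equals 20**-10 = 9.77e-14, i.e. ~1e-13
print(f"{grad_scale:.2e}")  # -> 9.77e-14
```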

### ResMLP: Additive Gradient Path

In ResMLP, the identity shortcut provides a direct gradient path:

∂L/∂x₁ = ∂L/∂x₂₀ × (1 + ∂f₂₀/∂x₁₉) × ... × (1 + ∂f₂/∂x₁)

The "1 +" terms ensure gradients never vanish completely. Even if the residual branch gradients are small, the identity path preserves gradient flow.


## Conclusions

1. **Residual connections are essential for deep networks.** With identical initialization, PlainMLP is completely untrainable (0% loss reduction) while ResMLP achieves a 99.5% loss reduction.

2. **Vanishing gradients are catastrophic.** PlainMLP gradients at layer 1 are ~10⁻¹⁹, effectively zero. No amount of training can fix this.

3. **The identity shortcut is the key.** The only architectural difference is x = f(x) vs x = x + f(x), yet this makes the difference between a dead network and a functional one.

4. **Fair comparison matters.** The previous experiment gave PlainMLP standard Kaiming init while ResMLP had scaled init; this fair comparison isolates the true effect of the residual connection.


## Reproducibility

```bash
cd projects/resmlp_comparison
python experiment_fair.py
```

All results are saved to `results_fair.json` and plots to `plots_fair/`.


*Experiment conducted with PyTorch, random seed 42 for reproducibility.*