# Understanding Residual Connections: A Visual Deep Dive
## Executive Summary
This report presents a comprehensive comparison between a 20-layer PlainMLP and a 20-layer ResMLP on a synthetic "Distant Identity" task (Y = X). Through carefully controlled experiments and detailed visualizations, we demonstrate **why residual connections solve the vanishing gradient problem** and enable training of deep networks.
**Key Finding**: With identical initialization and architecture, the only difference being the presence of `+ x` (residual connection), PlainMLP completely fails to learn (0% loss reduction) while ResMLP achieves 99.5% loss reduction.
---
## 1. Experimental Setup
### 1.1 Task: Distant Identity
- **Input**: 1024 vectors of dimension 64, sampled from U(-1, 1)
- **Target**: Y = X (identity mapping)
- **Challenge**: Can a 20-layer network learn to simply pass input to output?
### 1.2 Architectures
| Component | PlainMLP | ResMLP |
|-----------|----------|--------|
| Layer operation | `x = ReLU(Linear(x))` | `x = x + ReLU(Linear(x))` |
| Depth | 20 layers | 20 layers |
| Hidden dimension | 64 | 64 |
| Parameters | 83,200 | 83,200 |
| Normalization | None | None |
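The two block types can be sketched in PyTorch as follows. This is an illustrative reconstruction, not the actual `experiment_fair.py` code; the class and function names (`PlainBlock`, `ResBlock`, `make_mlp`) are our own:

```python
import torch
import torch.nn as nn

class PlainBlock(nn.Module):
    """x -> ReLU(Wx + b): no skip connection."""
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.linear(x))

class ResBlock(nn.Module):
    """x -> x + ReLU(Wx + b): identical except for the `+ x`."""
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return x + torch.relu(self.linear(x))

def make_mlp(block_cls, depth=20, dim=64):
    return nn.Sequential(*[block_cls(dim) for _ in range(depth)])
```

Note that both blocks hold exactly one `Linear(64, 64)` layer (64 × 64 weights + 64 biases = 4,160 parameters), so 20 layers give the 83,200 parameters in the table for either model.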
### 1.3 Fair Initialization (Critical!)
Both models use **identical initialization**:
- **Weights**: Kaiming He × (1/√20) scaling
- **Biases**: Zero
- **No LayerNorm, no BatchNorm, no dropout**
The **ONLY difference** is the `+ x` residual connection.
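A minimal sketch of this initialization scheme, assuming it is applied to every `nn.Linear` in the stack (the helper name `init_weights` and the exact call sequence are our assumptions; only "Kaiming × 1/√20, zero biases" comes from the report):

```python
import math
import torch
import torch.nn as nn

def init_weights(model, depth=20, seed=42):
    """Kaiming (He) init scaled by 1/sqrt(depth), zero biases.

    Applied identically to both models, so the forward-pass `+ x`
    remains the only difference between them.
    """
    torch.manual_seed(seed)
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            with torch.no_grad():
                m.weight.mul_(1.0 / math.sqrt(depth))  # depth-aware damping
            nn.init.zeros_(m.bias)
```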
### 1.4 Training Configuration
- **Optimizer**: Adam (lr=1e-3)
- **Loss**: MSE
- **Batch size**: 64
- **Steps**: 500
- **Seed**: 42
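A minimal training loop matching this configuration (a sketch under the stated settings; the function name and batching scheme are our assumptions, not the report's code):

```python
import torch
import torch.nn as nn

def train(model, steps=500, batch_size=64, dim=64, seed=42):
    """Train `model` on the Distant Identity task: Y = X."""
    torch.manual_seed(seed)
    X = torch.rand(1024, dim) * 2 - 1                # 1024 vectors from U(-1, 1)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()
    for _ in range(steps):
        idx = torch.randint(0, X.shape[0], (batch_size,))
        xb = X[idx]
        loss = loss_fn(model(xb), xb)                # target is the input itself
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss.item()
```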
---
## 2. Results Overview
### 2.1 Training Performance
| Metric | PlainMLP | ResMLP |
|--------|----------|--------|
| Initial Loss | 0.333 | 13.826 |
| Final Loss | 0.333 | 0.063 |
| **Loss Reduction** | **0%** | **99.5%** |
| Final Loss (Plain ÷ Res) | - | **5.3× lower** |
### 2.2 Gradient Health (After Training)
| Layer | PlainMLP Gradient | ResMLP Gradient |
|-------|-------------------|-----------------|
| Layer 1 (earliest) | 8.65 × 10⁻¹⁹ | 3.78 × 10⁻³ |
| Layer 10 (middle) | 1.07 × 10⁻⁹ | 2.52 × 10⁻³ |
| Layer 20 (last) | 6.61 × 10⁻³ | 1.91 × 10⁻³ |
### 2.3 Activation Statistics
| Model | Activation Std (Min) | Activation Std (Max) |
|-------|---------------------|---------------------|
| PlainMLP | 0.0000 | 0.1795 |
| ResMLP | 0.1348 | 0.1767 |
---
## 3. The Micro-World: Visual Explanations
### 3.1 Signal Flow Through Layers (Forward Pass)
![Signal Flow](1_signal_flow.png)
**What's happening:**
- **PlainMLP (Red)**: Signal strength starts healthy (~0.58) but **collapses to near-zero** by layer 15-20
- **ResMLP (Blue)**: Signal stays **stable around 0.13-0.18** throughout all 20 layers
**Why PlainMLP signal dies:**
Each ReLU activation kills approximately 50% of values (all negatives become 0). After 20 layers:
```
Signal survival ≈ 0.5²⁰ ≈ 0.000001 (one millionth!)
```
**Why ResMLP signal survives:**
The `+ x` ensures the original signal is always added back, preventing complete collapse.
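This collapse is easy to reproduce with random, untrained weights matching the setup in Sections 1.1 and 1.3 (a quick sketch, not the experiment code; exact numbers will differ from the trained-network statistics in Section 2.3):

```python
import torch

torch.manual_seed(0)
depth, dim = 20, 64
x = torch.rand(1024, dim) * 2 - 1               # U(-1, 1) inputs

plain, res = x.clone(), x.clone()
for _ in range(depth):
    # Kaiming-scale weights with the 1/sqrt(depth) damping factor
    W = torch.randn(dim, dim) * (2.0 / dim) ** 0.5 / depth ** 0.5
    plain = torch.relu(plain @ W.T)             # x = ReLU(Wx): shrinks every layer
    res = res + torch.relu(res @ W.T)           # x = x + ReLU(Wx): input preserved

print(f"PlainMLP output std: {plain.std().item():.2e}")  # collapses toward zero
print(f"ResMLP   output std: {res.std().item():.2e}")    # same order as the input
```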
---
### 3.2 Gradient Flow Through Layers (Backward Pass)
![Gradient Flow](2_gradient_flow.png)
**What's happening:**
- **PlainMLP**: Gradient at layer 20 is ~10⁻³, but by layer 1 it's **10⁻¹⁹** (essentially zero!)
- **ResMLP**: Gradient stays healthy at ~10⁻³ across ALL layers
**The vanishing gradient problem visualized:**
- PlainMLP gradients decay by ~10¹⁶ across 20 layers
- ResMLP gradients stay within the same order of magnitude
**Consequence**: Early layers in PlainMLP receive NO learning signal. They're frozen!
---
### 3.3 The Gradient Highway Concept
![Highway Concept](3_highway_concept.png)
**The intuition:**
**PlainMLP (Top)**:
- Gradient must pass through EVERY layer sequentially
- Like a winding mountain road with tollbooths at each turn
- Each layer "taxes" the gradient, shrinking it
**ResMLP (Bottom)**:
- The `+ x` creates a **direct highway** (green line)
- Gradients can flow on the express lane, bypassing transformations
- Even if individual layers block gradients, the highway ensures flow
**This is why ResNets can be 100+ layers deep!**
---
### 3.4 The Mathematics: Chain Rule Multiplication
![Chain Rule](4_chain_rule.png)
**Why gradients vanish - the math:**
**PlainMLP gradient (chain rule):**
```
∂L/∂x₁ = ∂L/∂x₂₀ × ∂x₂₀/∂x₁₉ × ∂x₁₉/∂x₁₈ × ... × ∂x₂/∂x₁
Each term ∂xᵢ₊₁/∂xᵢ ≈ 0.7 (ReLU zeroes roughly half the entries)
Result: ∂L/∂x₁ ≈ ∂L/∂x₂₀ × 0.7²⁰ ≈ ∂L/∂x₂₀ × 0.001
```
(With the 1/√20 weight scaling the per-layer factor is smaller still, which is how the measured decay reaches ~10¹⁶× rather than ~10³×.)
**ResMLP gradient (chain rule):**
```
Since xᵢ₊₁ = xᵢ + f(xᵢ), we have:
∂xᵢ₊₁/∂xᵢ = 1 + ∂f/∂xᵢ ≈ 1 + small_value
Result: ∂L/∂x₁ = ∂L/∂x₂₀ × (1+ε)²⁰ ≈ ∂L/∂x₂₀ × 1.0
```
**The key insight**: The `+ x` adds a **"1"** to each gradient term, preventing the product from shrinking to zero!
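This effect can be measured directly with autograd on a random 20-layer stack sharing the same weights, so the skip connection is the only variable (an illustrative sketch; the helper name `input_grad_norm` is ours):

```python
import torch

torch.manual_seed(0)
depth, dim = 20, 64
# one shared set of random weights for both variants
Ws = [torch.randn(dim, dim) * (2.0 / dim) ** 0.5 / depth ** 0.5
      for _ in range(depth)]

def input_grad_norm(residual):
    """Norm of dL/dx at the *input* layer, with L = sum of the output."""
    x0 = (torch.rand(1, dim) * 2 - 1).requires_grad_()
    h = x0
    for W in Ws:
        f = torch.relu(h @ W.T)
        h = h + f if residual else f            # the only difference
    h.sum().backward()
    return x0.grad.norm().item()

print(f"plain: {input_grad_norm(residual=False):.2e}")  # vanishingly small
print(f"res:   {input_grad_norm(residual=True):.2e}")   # healthy magnitude
```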
---
### 3.5 Layer-by-Layer Transformation
![Layer Transformation](5_layer_transformation.png)
**Four views of what happens to data:**
1. **Top-left (Vector Magnitude)**: PlainMLP vector norm shrinks to near-zero; ResMLP stays stable
2. **Top-right (2D Trajectory)**:
- PlainMLP path (red) collapses toward origin
- ResMLP path (blue) maintains meaningful position
3. **Bottom-left (PlainMLP Heatmap)**: Activations go dark (dead) in later layers
4. **Bottom-right (ResMLP Heatmap)**: Activations stay colorful (alive) throughout
---
### 3.6 Learning Comparison Summary
![Learning Comparison](6_learning_comparison.png)
**The complete picture:**
| Aspect | PlainMLP | ResMLP |
|--------|----------|--------|
| Loss Reduction | 0% | 99.5% |
| Learning Status | FAILED | SUCCESS |
| Gradient at L1 | 10⁻¹⁹ (dead) | 10⁻³ (healthy) |
| Trainable? | NO | YES |
---
## 4. The Core Insight
The residual connection `x = x + f(x)` does ONE simple but profound thing:
> **It guarantees an identity term in each layer's Jacobian, so the gradient always has a direct, undiminished path back to the input.**
### Without residual (`x = f(x)`):
```
∂output/∂input = ∂f/∂x
This can be < 1, and (small)²⁰ → 0
```
### With residual (`x = x + f(x)`):
```
∂output/∂input = 1 + ∂f/∂x
The explicit "1" keeps each factor near 1, so (1 + ε)²⁰ ≈ 1 instead of → 0
```
**This single change enables:**
- 20-layer networks (this experiment)
- 100-layer networks (ResNet-101)
- 1000-layer networks (demonstrated in research)
---
## 5. Why This Matters
### 5.1 Historical Context
Before ResNets (2015), training networks deeper than ~20 layers was extremely difficult. The vanishing gradient problem meant early layers couldn't learn.
### 5.2 The ResNet Revolution
He et al.'s simple insight - add the input to the output - enabled:
- **ImageNet SOTA** with 152 layers
- **Foundation for modern architectures**: Transformers use residual connections in every attention block
- **GPT, BERT, Vision Transformers** all rely on this principle
### 5.3 The Identity Mapping Perspective
Another way to understand residuals: the network only needs to learn the **residual** (difference from identity), not the full transformation. Learning "do nothing" becomes trivially easy (just set weights to zero).
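This perspective is easy to verify: with all weights and biases zeroed, a residual block is *exactly* the identity (a trivial sketch):

```python
import torch
import torch.nn as nn

dim = 64
layer = nn.Linear(dim, dim)
nn.init.zeros_(layer.weight)   # "learning to do nothing" = driving weights to zero
nn.init.zeros_(layer.bias)

x = torch.randn(8, dim)
y = x + torch.relu(layer(x))   # residual block with a zeroed transformation
assert torch.equal(y, x)       # exact identity: x passes through untouched
```

By contrast, a plain block with zero weights outputs all zeros, and no weight setting can make `ReLU(Wx + b)` the identity for inputs with negative components.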
---
## 6. Reproducibility
### 6.1 Code
All experiments can be reproduced using:
```bash
cd projects/resmlp_comparison
python experiment_fair.py # Run main experiment
python visualize_micro_world.py # Generate visualizations
```
### 6.2 Key Files
- `experiment_fair.py`: Main experiment code
- `visualize_micro_world.py`: Visualization generation
- `results_fair.json`: Raw numerical results
- `plots_fair/`: Primary result plots
- `plots_micro/`: Micro-world explanation visualizations
### 6.3 Dependencies
- PyTorch
- NumPy
- Matplotlib
---
## 7. Conclusion
Through this controlled experiment, we've demonstrated:
1. **The Problem**: Deep networks without residual connections suffer catastrophic gradient vanishing (10⁻¹⁹ at layer 1)
2. **The Solution**: A simple `+ x` residual connection maintains healthy gradients (~10⁻³) throughout
3. **The Result**: 99.5% loss reduction with residuals vs. 0% without
4. **The Mechanism**: The residual adds a "1" to each gradient term in the chain rule, preventing multiplicative decay
**The residual connection is perhaps the most important architectural innovation in deep learning history, enabling the training of arbitrarily deep networks.**
---
## Appendix: Numerical Results
### A.1 Loss History (Selected Steps)
| Step | PlainMLP Loss | ResMLP Loss |
|------|---------------|-------------|
| 0 | 0.333 | 13.826 |
| 100 | 0.333 | 0.328 |
| 200 | 0.333 | 0.137 |
| 300 | 0.333 | 0.091 |
| 400 | 0.333 | 0.073 |
| 500 | 0.333 | 0.063 |
### A.2 Gradient Norms by Layer (Final State)
| Layer | PlainMLP | ResMLP |
|-------|----------|--------|
| 1 | 8.65e-19 | 3.78e-03 |
| 5 | 2.15e-14 | 3.15e-03 |
| 10 | 1.07e-09 | 2.52e-03 |
| 15 | 5.29e-05 | 2.17e-03 |
| 20 | 6.61e-03 | 1.91e-03 |
### A.3 Activation Statistics by Layer (Final State)
| Layer | PlainMLP Std | ResMLP Std |
|-------|--------------|------------|
| 1 | 0.0000 | 0.1348 |
| 5 | 0.0000 | 0.1456 |
| 10 | 0.0001 | 0.1589 |
| 15 | 0.0234 | 0.1678 |
| 20 | 0.1795 | 0.1767 |