# Understanding Residual Connections: PlainMLP vs ResMLP
[License: MIT](https://opensource.org/licenses/MIT) · [Python](https://www.python.org/downloads/) · [PyTorch](https://pytorch.org/)
A comprehensive visual deep dive into **why residual connections solve the vanishing gradient problem** and enable training of deep neural networks.
## Key Finding
> With identical initialization and architecture, the **only difference being `+ x` (residual connection)**, PlainMLP completely fails to learn (0% loss reduction) while ResMLP achieves **99.5% loss reduction**.
| Model | Initial Loss | Final Loss | Loss Reduction |
|-------|-------------|------------|----------------|
| PlainMLP (20 layers) | 0.333 | 0.333 | **0%** |
| ResMLP (20 layers) | 13.826 | 0.063 | **99.5%** |
## Visual Results
### Training Loss Comparison
![Training loss comparison](plots_fair/training_loss.png)
### Gradient Flow Analysis
| Layer | PlainMLP Gradient | ResMLP Gradient |
|-------|-------------------|-----------------|
| Layer 1 (earliest) | 8.65 × 10⁻¹⁹ | 3.78 × 10⁻³ |
| Layer 10 (middle) | 1.07 × 10⁻⁹ | 2.52 × 10⁻³ |
| Layer 20 (last) | 6.61 × 10⁻³ | 1.91 × 10⁻³ |
## Experimental Setup
### Task: Distant Identity (Y = X)
- **Input**: 1024 vectors of dimension 64, sampled from U(-1, 1)
- **Target**: Y = X (identity mapping)
- **Challenge**: Can a 20-layer network learn to simply pass input to output?
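A minimal sketch of how such a dataset could be generated (illustrative only, not the repository's exact code):

```python
import torch

# Illustrative data generation for the identity task: 1024 samples of
# dimension 64, drawn uniformly from [-1, 1]; the target is the input itself.
torch.manual_seed(0)
X = torch.rand(1024, 64) * 2 - 1   # U(-1, 1)
Y = X.clone()                      # identity target: Y = X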
### Architecture Comparison
| Component | PlainMLP | ResMLP |
|-----------|----------|--------|
| Layer operation | `x = ReLU(Linear(x))` | `x = x + ReLU(Linear(x))` |
| Depth | 20 layers | 20 layers |
| Hidden dimension | 64 | 64 |
| Parameters | 83,200 | 83,200 |
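The two layer operations above can be sketched in PyTorch as follows (class and attribute names are illustrative and may differ from the code in `experiment_fair.py`):

```python
import torch
import torch.nn as nn

class PlainMLP(nn.Module):
    """20 x Linear(64, 64) + ReLU, no skip connections: x = ReLU(Linear(x)).
    Each Linear(64, 64) has 4,160 parameters, so 20 layers give 83,200."""
    def __init__(self, dim=64, depth=20):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x

class ResMLP(PlainMLP):
    """Identical stack, but each layer adds its input back: x = x + ReLU(Linear(x))."""
    def forward(self, x):
        for layer in self.layers:
            x = x + torch.relu(layer(x))
        return x
```

Subclassing here is just to emphasize that the parameters and shapes are identical; only the `+ x` in `forward` differs.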
### Fair Initialization (Critical!)
Both models use **identical initialization**:
- **Weights**: Kaiming (He) initialization, scaled by 1/√20
- **Biases**: Zero
- **No LayerNorm, no BatchNorm, no dropout**
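A hedged sketch of what such an initialization could look like (the helper name `init_fair` is made up for illustration):

```python
import math
import torch
import torch.nn as nn

def init_fair(model, depth=20):
    """Identical init for both models: Kaiming (He) weights scaled by 1/sqrt(depth),
    zero biases, no normalization layers."""
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            with torch.no_grad():
                m.weight.mul_(1 / math.sqrt(depth))
            nn.init.zeros_(m.bias)
```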
## The Core Insight
The residual connection `x = x + f(x)` does ONE simple but profound thing:
> **It ensures that the gradient of the output with respect to the input is always at least 1.**
### Without residual (`x = f(x)`):
```
∂output/∂input = ∂f/∂x
This can be < 1, and (small)²⁰ → 0
```
### With residual (`x = x + f(x)`):
```
∂output/∂input = 1 + ∂f/∂x
This is always ≥ 1, so (≥1)²⁰ ≥ 1
```
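A toy numeric illustration of the two cases above, using a made-up per-layer derivative of 0.5 standing in for "small":

```python
# Chain rule over 20 layers with an illustrative per-layer derivative of 0.5.
per_layer = 0.5

plain_grad = per_layer ** 20        # (0.5)^20 ≈ 9.5e-07 -> effectively zero
res_grad = (1 + per_layer) ** 20    # (1.5)^20 ≈ 3.3e+03 -> never falls below 1

print(f"plain: {plain_grad:.1e}  residual: {res_grad:.1e}")
```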
## Repository Structure
```
resmlp_comparison/
├── README.md                   # This file
├── experiment_final.py         # Main experiment code
├── experiment_fair.py          # Fair comparison experiment
├── visualize_micro_world.py    # Visualization generation
├── results_fair.json           # Raw numerical results
├── report_final.md             # Detailed analysis report
│
├── plots_fair/                 # Primary result plots
│   ├── training_loss.png
│   ├── gradient_magnitude.png
│   ├── activation_mean.png
│   └── activation_std.png
│
└── plots_micro/                # Educational visualizations
    ├── 1_signal_flow.png
    ├── 2_gradient_flow.png
    ├── 3_highway_concept.png
    ├── 4_chain_rule.png
    ├── 5_layer_transformation.png
    └── 6_learning_comparison.png
```
## Quick Start
### Installation
```bash
pip install torch numpy matplotlib
```
### Run Experiment
```bash
# Run the main fair comparison experiment
python experiment_fair.py
# Generate micro-world visualizations
python visualize_micro_world.py
```
## Detailed Visualizations
### 1. Signal Flow Through Layers
![Signal flow through layers](plots_micro/1_signal_flow.png)
- PlainMLP signal collapses to near-zero by layer 15-20
- ResMLP signal stays stable throughout all layers
### 2. Gradient Flow (Backward Pass)
![Gradient flow (backward pass)](plots_micro/2_gradient_flow.png)
- PlainMLP: Gradient decays from 10⁻³ to 10⁻¹⁹ (essentially zero!)
- ResMLP: Gradient stays healthy at ~10⁻³ across ALL layers
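One way such per-layer gradient magnitudes could be measured is sketched below, reusing the hypothetical `ResMLP`/`PlainMLP` and `init_fair` sketches from earlier (illustrative only, not the repository's plotting code):

```python
import torch
import torch.nn.functional as F

# Record per-layer weight-gradient norms after one backward pass on the
# identity task. Swap in PlainMLP() to see the gradients vanish.
model = ResMLP()
init_fair(model)

X = torch.rand(1024, 64) * 2 - 1
loss = F.mse_loss(model(X), X)   # target is the input itself
loss.backward()

for i, layer in enumerate(model.layers, start=1):
    print(f"layer {i:2d}: |grad| = {layer.weight.grad.norm():.3e}")
```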
### 3. The Highway Concept
![The highway concept](plots_micro/3_highway_concept.png)
- The `+ x` creates a direct "gradient highway" for information flow
### 4. Chain Rule Mathematics
![Chain rule mathematics](plots_micro/4_chain_rule.png)
- Visual explanation of why gradients vanish mathematically
## Why This Matters
### Historical Context
Before ResNets (2015), training networks deeper than ~20 layers was extremely difficult due to vanishing gradients.
### The ResNet Revolution
He et al.'s simple insight enabled:
- **ImageNet SOTA** with 152 layers
- **Foundation for modern architectures**: Transformers use residual connections in every attention block
- **GPT, BERT, Vision Transformers** all rely on this principle
## Reports
- [`report_final.md`](report_final.md) - Comprehensive analysis with all visualizations
- [`report_fair.md`](report_fair.md) - Fair comparison methodology
- [`report.md`](report.md) - Initial experiment report
## Citation
If you find this educational resource helpful, please consider citing:
```bibtex
@misc{resmlp_comparison,
  title={Understanding Residual Connections: A Visual Deep Dive},
  author={AmberLJC},
  year={2024},
  url={https://huggingface.co/AmberLJC/resmlp_comparison}
}
```
## License
MIT License - feel free to use for educational purposes!
## Acknowledgments
Inspired by the seminal work:
- He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. CVPR.