# Understanding Residual Connections: PlainMLP vs ResMLP

[MIT License](https://opensource.org/licenses/MIT) · [Python](https://www.python.org/downloads/) · [PyTorch](https://pytorch.org/)

A comprehensive visual deep dive into **why residual connections solve the vanishing gradient problem** and enable the training of deep neural networks.
## Key Finding

> With identical initialization and architecture, the **only difference being `+ x` (the residual connection)**, PlainMLP completely fails to learn (0% loss reduction) while ResMLP achieves a **99.5% loss reduction**.

| Model | Initial Loss | Final Loss | Loss Reduction |
|-------|-------------|------------|----------------|
| PlainMLP (20 layers) | 0.333 | 0.333 | **0%** ❌ |
| ResMLP (20 layers) | 13.826 | 0.063 | **99.5%** ✅ |
## Visual Results

### Training Loss Comparison



### Gradient Flow Analysis

| Layer | PlainMLP Gradient | ResMLP Gradient |
|-------|-------------------|-----------------|
| Layer 1 (earliest) | 8.65 × 10⁻¹⁹ | 3.78 × 10⁻³ ✅ |
| Layer 10 (middle) | 1.07 × 10⁻⁹ | 2.52 × 10⁻³ ✅ |
| Layer 20 (last) | 6.61 × 10⁻³ | 1.91 × 10⁻³ ✅ |
## Experimental Setup

### Task: Distant Identity (Y = X)

- **Input**: 1024 vectors of dimension 64, sampled from U(-1, 1)
- **Target**: Y = X (identity mapping)
- **Challenge**: Can a 20-layer network learn to simply pass its input through to the output?
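The dataset can be sketched in a few lines of PyTorch (tensor shapes taken from the description above; the seed is an arbitrary assumption for reproducibility):

```python
import torch

torch.manual_seed(0)  # arbitrary seed, for reproducibility

# 1024 samples, each a 64-dimensional vector drawn from U(-1, 1)
X = torch.rand(1024, 64) * 2 - 1
Y = X.clone()  # identity target: the network must reproduce its input
```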
### Architecture Comparison

| Component | PlainMLP | ResMLP |
|-----------|----------|--------|
| Layer operation | `x = ReLU(Linear(x))` | `x = x + ReLU(Linear(x))` |
| Depth | 20 layers | 20 layers |
| Hidden dimension | 64 | 64 |
| Parameters | 83,200 | 83,200 |
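A minimal sketch of the two architectures (my own naming; the actual classes live in `experiment_fair.py`). Note the single-line difference in `forward`:

```python
import torch
import torch.nn as nn

class PlainMLP(nn.Module):
    """20 stacked Linear + ReLU layers, no skip connections."""
    def __init__(self, dim=64, depth=20):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))      # x = ReLU(Linear(x))
        return x

class ResMLP(nn.Module):
    """Identical stack, but each layer adds its input back in."""
    def __init__(self, dim=64, depth=20):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))

    def forward(self, x):
        for layer in self.layers:
            x = x + torch.relu(layer(x))  # x = x + ReLU(Linear(x))
        return x
```

Each `Linear(64, 64)` layer has 64·64 + 64 = 4,160 parameters, so 20 layers give the 83,200 parameters in the table.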
### Fair Initialization (Critical!)

Both models use **identical initialization**:

- **Weights**: Kaiming (He) × (1/√20) scaling
- **Biases**: zero
- **No LayerNorm, no BatchNorm, no dropout**
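A sketch of what such an initialization might look like (the exact Kaiming variant, normal vs. uniform, is an assumption; the point is that both models receive identical weights):

```python
import math
import torch
import torch.nn as nn

def init_fair(model, depth=20, seed=42):
    """Kaiming (He) init scaled by 1/sqrt(depth), with zero biases."""
    torch.manual_seed(seed)  # same seed => identical weights for both models
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            with torch.no_grad():
                m.weight.mul_(1 / math.sqrt(depth))  # the 1/sqrt(20) scaling
            nn.init.zeros_(m.bias)
```

Calling `init_fair` with the same seed on both models guarantees that the only remaining difference between them is the `+ x` in the forward pass.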
## The Core Insight

The residual connection `x = x + f(x)` does ONE simple but profound thing:

> **It ensures that the gradient of the output with respect to the input is always at least 1.**

### Without residual (`x = f(x)`):

```
∂output/∂input = ∂f/∂x
This can be < 1, and (small)²⁰ → 0
```

### With residual (`x = x + f(x)`):

```
∂output/∂input = 1 + ∂f/∂x
This is always ≥ 1, so (≥1)²⁰ ≥ 1
```
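A toy numeric illustration of that chain-rule product (0.5 is an arbitrary stand-in for a per-layer gradient factor):

```python
# 20 layers multiply 20 per-layer gradient factors together
plain_factor = 0.5        # |df/dx| < 1 with no skip connection
res_factor = 1.0 + 0.5    # 1 + df/dx with the residual branch

plain_grad = plain_factor ** 20   # ~9.5e-07: the gradient vanishes
res_grad = res_factor ** 20       # ~3.3e+03: the gradient survives

print(f"plain: {plain_grad:.1e}, residual: {res_grad:.1e}")
```

In practice the residual factor stays near 1 rather than 1.5 (f is small at initialization), which is why ResMLP's measured gradients sit at ~10⁻³ across layers instead of exploding.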
## Repository Structure

```
resmlp_comparison/
├── README.md                  # This file
├── experiment_final.py        # Main experiment code
├── experiment_fair.py         # Fair comparison experiment
├── visualize_micro_world.py   # Visualization generation
├── results_fair.json          # Raw numerical results
├── report_final.md            # Detailed analysis report
│
├── plots_fair/                # Primary result plots
│   ├── training_loss.png
│   ├── gradient_magnitude.png
│   ├── activation_mean.png
│   └── activation_std.png
│
└── plots_micro/               # Educational visualizations
    ├── 1_signal_flow.png
    ├── 2_gradient_flow.png
    ├── 3_highway_concept.png
    ├── 4_chain_rule.png
    ├── 5_layer_transformation.png
    └── 6_learning_comparison.png
```
## Quick Start

### Installation

```bash
pip install torch numpy matplotlib
```

### Run the Experiments

```bash
# Run the main fair comparison experiment
python experiment_fair.py

# Generate micro-world visualizations
python visualize_micro_world.py
```
## Detailed Visualizations

### 1. Signal Flow Through Layers



- The PlainMLP signal collapses to near zero by layers 15-20
- The ResMLP signal stays stable through all layers

### 2. Gradient Flow (Backward Pass)



- PlainMLP: the gradient decays from 10⁻³ to 10⁻¹⁹ (essentially zero!)
- ResMLP: the gradient stays healthy at ~10⁻³ across ALL layers
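You can reproduce this kind of measurement by logging per-layer weight-gradient norms after a single backward pass. A self-contained sketch (the throwaway plain stack here is illustrative, not the experiment's exact code):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A throwaway 20-layer plain stack (Linear + ReLU, no residuals)
layers = nn.ModuleList(nn.Linear(64, 64) for _ in range(20))

x = torch.rand(32, 64) * 2 - 1
h = x
for layer in layers:
    h = torch.relu(layer(h))

loss = ((h - x) ** 2).mean()  # identity task, MSE loss
loss.backward()

# Print the gradient magnitude reaching each layer's weights
for i, layer in enumerate(layers, start=1):
    print(f"layer {i:2d}: |grad| = {layer.weight.grad.norm():.3e}")
```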
### 3. The Highway Concept



- The `+ x` creates a direct "gradient highway" for information flow

### 4. Chain Rule Mathematics



- A visual explanation of why gradients vanish mathematically
## Why This Matters

### Historical Context

Before ResNets (2015), training networks deeper than ~20 layers was extremely difficult due to vanishing gradients.

### The ResNet Revolution

He et al.'s simple insight enabled:

- **State-of-the-art ImageNet results** with 152-layer networks
- **The foundation for modern architectures**: Transformers use residual connections around every attention and feed-forward block
- **GPT, BERT, and Vision Transformers** all rely on this principle
## Reports

- [`report_final.md`](report_final.md) - Comprehensive analysis with all visualizations
- [`report_fair.md`](report_fair.md) - Fair comparison methodology
- [`report.md`](report.md) - Initial experiment report
## Citation

If you find this educational resource helpful, please consider citing:

```bibtex
@misc{resmlp_comparison,
  title={Understanding Residual Connections: A Visual Deep Dive},
  author={AmberLJC},
  year={2024},
  url={https://huggingface.co/AmberLJC/resmlp_comparison}
}
```
## License

MIT License - feel free to use this for educational purposes!

## Acknowledgments

Inspired by the seminal work:

- He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. CVPR.