# 🧠 Understanding Residual Connections: PlainMLP vs ResMLP

[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/) [![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-ee4c2c.svg)](https://pytorch.org/)

A comprehensive visual deep dive into **why residual connections solve the vanishing gradient problem** and enable training of deep neural networks.

## 🎯 Key Finding

> With identical initialization and architecture, the **only difference being `+ x` (the residual connection)**, PlainMLP completely fails to learn (0% loss reduction) while ResMLP achieves **99.5% loss reduction**.

| Model | Initial Loss | Final Loss | Loss Reduction |
|-------|-------------|------------|----------------|
| PlainMLP (20 layers) | 0.333 | 0.333 | **0%** ❌ |
| ResMLP (20 layers) | 13.826 | 0.063 | **99.5%** ✅ |

## 📊 Visual Results

### Training Loss Comparison

![Training Loss](plots_fair/training_loss.png)

### Gradient Flow Analysis

| Layer | PlainMLP Gradient | ResMLP Gradient |
|-------|-------------------|-----------------|
| Layer 1 (earliest) | 8.65 × 10⁻¹⁹ 💀 | 3.78 × 10⁻³ ✅ |
| Layer 10 (middle) | 1.07 × 10⁻⁹ | 2.52 × 10⁻³ ✅ |
| Layer 20 (last) | 6.61 × 10⁻³ | 1.91 × 10⁻³ ✅ |

## 🔬 Experimental Setup

### Task: Distant Identity (Y = X)

- **Input**: 1024 vectors of dimension 64, sampled from U(-1, 1)
- **Target**: Y = X (identity mapping)
- **Challenge**: Can a 20-layer network learn to simply pass its input through to its output?

### Architecture Comparison

| Component | PlainMLP | ResMLP |
|-----------|----------|--------|
| Layer operation | `x = ReLU(Linear(x))` | `x = x + ReLU(Linear(x))` |
| Depth | 20 layers | 20 layers |
| Hidden dimension | 64 | 64 |
| Parameters | 83,200 | 83,200 |

### Fair Initialization (Critical!)
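In PyTorch terms, this scheme might look like the following sketch (the helper name `fair_init_` is illustrative and not taken from the repository's code):

```python
import math

import torch
import torch.nn as nn


def fair_init_(model: nn.Module, depth: int = 20) -> None:
    """Illustrative sketch: Kaiming (He) weights scaled by 1/sqrt(depth), zero biases."""
    for m in model.modules():
        if isinstance(m, nn.Linear):
            # He initialization for ReLU networks...
            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            # ...then damp every layer by 1/sqrt(depth) so both models start identically.
            with torch.no_grad():
                m.weight.mul_(1.0 / math.sqrt(depth))
            nn.init.zeros_(m.bias)
```

Applying the same in-place initializer to both models (after copying one state dict into the other, or by seeding identically) is what makes the comparison fair.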
Both models use **identical initialization**:

- **Weights**: Kaiming (He) initialization, scaled by 1/√20
- **Biases**: zero
- **No LayerNorm, no BatchNorm, no dropout**

## 🎓 The Core Insight

The residual connection `x = x + f(x)` does ONE simple but profound thing:

> **It ensures that the gradient of the output with respect to the input is always at least 1.**

### Without residual (`x = f(x)`):

```
∂output/∂input = ∂f/∂x

This can be < 1, and (small)²⁰ → 0
```

### With residual (`x = x + f(x)`):

```
∂output/∂input = 1 + ∂f/∂x

This is always ≥ 1, so (≥1)²⁰ ≥ 1
```

## 📁 Repository Structure

```
resmlp_comparison/
├── README.md                   # This file
├── experiment_final.py         # Main experiment code
├── experiment_fair.py          # Fair comparison experiment
├── visualize_micro_world.py    # Visualization generation
├── results_fair.json           # Raw numerical results
├── report_final.md             # Detailed analysis report
├── plots_fair/                 # Primary result plots
│   ├── training_loss.png
│   ├── gradient_magnitude.png
│   ├── activation_mean.png
│   └── activation_std.png
└── plots_micro/                # Educational visualizations
    ├── 1_signal_flow.png
    ├── 2_gradient_flow.png
    ├── 3_highway_concept.png
    ├── 4_chain_rule.png
    ├── 5_layer_transformation.png
    └── 6_learning_comparison.png
```

## 🚀 Quick Start

### Installation

```bash
pip install torch numpy matplotlib
```

### Run the Experiments

```bash
# Run the main fair-comparison experiment
python experiment_fair.py

# Generate the micro-world visualizations
python visualize_micro_world.py
```

## 📈 Detailed Visualizations

### 1. Signal Flow Through Layers

![Signal Flow](plots_micro/1_signal_flow.png)

- PlainMLP's signal collapses to near zero by layers 15-20
- ResMLP's signal stays stable throughout all layers

### 2. Gradient Flow (Backward Pass)

![Gradient Flow](plots_micro/2_gradient_flow.png)

- PlainMLP: gradients decay from 10⁻³ to 10⁻¹⁹ (essentially zero!)
- ResMLP: gradients stay healthy at ~10⁻³ across ALL layers

### 3. The Highway Concept

![Highway](plots_micro/3_highway_concept.png)

- The `+ x` creates a direct "gradient highway" for information flow

### 4. Chain Rule Mathematics

![Chain Rule](plots_micro/4_chain_rule.png)

- Visual explanation of why gradients vanish mathematically

## 🔑 Why This Matters

### Historical Context

Before ResNets (2015), training networks deeper than ~20 layers was extremely difficult due to vanishing gradients.

### The ResNet Revolution

He et al.'s simple insight enabled:

- **ImageNet SOTA** with 152 layers
- **Foundations of modern architectures**: Transformers use residual connections around every attention and feed-forward block
- **GPT, BERT, and Vision Transformers** all rely on this principle

## 📚 Reports

- [`report_final.md`](report_final.md) - Comprehensive analysis with all visualizations
- [`report_fair.md`](report_fair.md) - Fair comparison methodology
- [`report.md`](report.md) - Initial experiment report

## 📖 Citation

If you find this educational resource helpful, please consider citing:

```bibtex
@misc{resmlp_comparison,
  title={Understanding Residual Connections: A Visual Deep Dive},
  author={AmberLJC},
  year={2024},
  url={https://huggingface.co/AmberLJC/resmlp_comparison}
}
```

## 📄 License

MIT License - feel free to use for educational purposes!

## 🙏 Acknowledgments

Inspired by the seminal work:

- He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. CVPR.
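## 🧪 Appendix: Minimal Gradient Check

The gradient-highway claim can be checked end to end with a short script. This is a sketch, not the repository's experiment code: it assumes PyTorch, uses default `nn.Linear` initialization rather than the scaled scheme above, and all function names are illustrative. Depth (20) and width (64) match the experiment.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
DEPTH, DIM = 20, 64


def make_layers() -> nn.ModuleList:
    return nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(DEPTH)])


def forward(layers: nn.ModuleList, x: torch.Tensor, residual: bool) -> torch.Tensor:
    for layer in layers:
        h = torch.relu(layer(x))
        x = x + h if residual else h  # the "+ x" is the only difference
    return x


plain, res = make_layers(), make_layers()
res.load_state_dict(plain.state_dict())  # identical weights in both stacks

x = torch.rand(32, DIM) * 2 - 1  # inputs from U(-1, 1), as in the experiment

for name, layers, residual in [("plain", plain, False), ("res", res, True)]:
    out = forward(layers, x, residual)
    loss = ((out - x) ** 2).mean()  # identity target: Y = X
    loss.backward()
    g_first = layers[0].weight.grad.abs().mean().item()
    print(f"{name}: first-layer mean |grad| = {g_first:.3e}")
```

With identical weights, the plain stack's first-layer gradient comes out many orders of magnitude smaller than the residual stack's, mirroring the 10⁻¹⁹ vs 10⁻³ table above.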