| # Understanding Residual Connections: A Visual Deep Dive | |
| ## Executive Summary | |
| This report presents a comprehensive comparison between a 20-layer PlainMLP and a 20-layer ResMLP on a synthetic "Distant Identity" task (Y = X). Through carefully controlled experiments and detailed visualizations, we demonstrate **why residual connections solve the vanishing gradient problem** and enable training of deep networks. | |
| **Key Finding**: With identical initialization and architecture, the only difference being the presence of `+ x` (residual connection), PlainMLP completely fails to learn (0% loss reduction) while ResMLP achieves 99.5% loss reduction. | |
| --- | |
| ## 1. Experimental Setup | |
| ### 1.1 Task: Distant Identity | |
| - **Input**: 1024 vectors of dimension 64, sampled from U(-1, 1) | |
| - **Target**: Y = X (identity mapping) | |
| - **Challenge**: Can a 20-layer network learn to simply pass input to output? | |
| ### 1.2 Architectures | |
| | Component | PlainMLP | ResMLP | | |
| |-----------|----------|--------| | |
| | Layer operation | `x = ReLU(Linear(x))` | `x = x + ReLU(Linear(x))` | | |
| | Depth | 20 layers | 20 layers | | |
| | Hidden dimension | 64 | 64 | | |
| | Parameters | 83,200 | 83,200 | | |
| | Normalization | None | None | | |
| ### 1.3 Fair Initialization (Critical!) | |
| Both models use **identical initialization**: | |
| - **Weights**: Kaiming He × (1/√20) scaling | |
| - **Biases**: Zero | |
| - **No LayerNorm, no BatchNorm, no dropout** | |
| The **ONLY difference** is the `+ x` residual connection. | |
| ### 1.4 Training Configuration | |
| - **Optimizer**: Adam (lr=1e-3) | |
| - **Loss**: MSE | |
| - **Batch size**: 64 | |
| - **Steps**: 500 | |
| - **Seed**: 42 | |
| --- | |
| ## 2. Results Overview | |
| ### 2.1 Training Performance | |
| | Metric | PlainMLP | ResMLP | | |
| |--------|----------|--------| | |
| | Initial Loss | 0.333 | 13.826 | | |
| | Final Loss | 0.333 | 0.063 | | |
| | **Loss Reduction** | **0%** | **99.5%** | | |
| | Final Loss Ratio | - | **5.3× better** | | |
| ### 2.2 Gradient Health (After Training) | |
| | Layer | PlainMLP Gradient | ResMLP Gradient | | |
| |-------|-------------------|-----------------| | |
| | Layer 1 (earliest) | 8.65 × 10⁻¹⁹ | 3.78 × 10⁻³ | | |
| | Layer 10 (middle) | 1.07 × 10⁻⁹ | 2.52 × 10⁻³ | | |
| | Layer 20 (last) | 6.61 × 10⁻³ | 1.91 × 10⁻³ | | |
| ### 2.3 Activation Statistics | |
| | Model | Activation Std (Min) | Activation Std (Max) | | |
| |-------|---------------------|---------------------| | |
| | PlainMLP | 0.0000 | 0.1795 | | |
| | ResMLP | 0.1348 | 0.1767 | | |
| --- | |
| ## 3. The Micro-World: Visual Explanations | |
| ### 3.1 Signal Flow Through Layers (Forward Pass) | |
|  | |
| **What's happening:** | |
| - **PlainMLP (Red)**: Signal strength starts healthy (~0.58) but **collapses to near-zero** by layer 15-20 | |
| - **ResMLP (Blue)**: Signal stays **stable around 0.13-0.18** throughout all 20 layers | |
| **Why PlainMLP signal dies:** | |
| Each ReLU activation kills approximately 50% of values (all negatives become 0). After 20 layers: | |
| ``` | |
| Signal survival ≈ 0.5²⁰ ≈ 0.000001 (one millionth!) | |
| ``` | |
| **Why ResMLP signal survives:** | |
| The `+ x` ensures the original signal is always added back, preventing complete collapse. | |
| --- | |
| ### 3.2 Gradient Flow Through Layers (Backward Pass) | |
|  | |
| **What's happening:** | |
| - **PlainMLP**: Gradient at layer 20 is ~10⁻³, but by layer 1 it's **10⁻¹⁹** (essentially zero!) | |
| - **ResMLP**: Gradient stays healthy at ~10⁻³ across ALL layers | |
| **The vanishing gradient problem visualized:** | |
| - PlainMLP gradients decay by ~10¹⁶ across 20 layers | |
| - ResMLP gradients stay within the same order of magnitude | |
| **Consequence**: Early layers in PlainMLP receive NO learning signal. They're frozen! | |
| --- | |
| ### 3.3 The Gradient Highway Concept | |
|  | |
| **The intuition:** | |
| **PlainMLP (Top)**: | |
| - Gradient must pass through EVERY layer sequentially | |
| - Like a winding mountain road with tollbooths at each turn | |
| - Each layer "taxes" the gradient, shrinking it | |
| **ResMLP (Bottom)**: | |
| - The `+ x` creates a **direct highway** (green line) | |
| - Gradients can flow on the express lane, bypassing transformations | |
| - Even if individual layers block gradients, the highway ensures flow | |
| **This is why ResNets can be 100+ layers deep!** | |
| --- | |
| ### 3.4 The Mathematics: Chain Rule Multiplication | |
|  | |
| **Why gradients vanish - the math:** | |
| **PlainMLP gradient (chain rule):** | |
| ``` | |
| ∂L/∂x₁ = ∂L/∂x₂₀ × ∂x₂₀/∂x₁₉ × ∂x₁₉/∂x₁₈ × ... × ∂x₂/∂x₁ | |
| Each term ∂xᵢ₊₁/∂xᵢ ≈ 0.7 (due to ReLU killing half the gradients) | |
| Result: ∂L/∂x₁ = ∂L/∂x₂₀ × 0.7²⁰ = ∂L/∂x₂₀ × 0.0000008 | |
| ``` | |
| **ResMLP gradient (chain rule):** | |
| ``` | |
| Since xᵢ₊₁ = xᵢ + f(xᵢ), we have: | |
| ∂xᵢ₊₁/∂xᵢ = 1 + ∂f/∂xᵢ ≈ 1 + small_value | |
| Result: ∂L/∂x₁ = ∂L/∂x₂₀ × (1+ε)²⁰ ≈ ∂L/∂x₂₀ × 1.0 | |
| ``` | |
| **The key insight**: The `+ x` adds a **"1"** to each gradient term, preventing the product from shrinking to zero! | |
| --- | |
| ### 3.5 Layer-by-Layer Transformation | |
|  | |
| **Four views of what happens to data:** | |
| 1. **Top-left (Vector Magnitude)**: PlainMLP vector norm shrinks to near-zero; ResMLP stays stable | |
| 2. **Top-right (2D Trajectory)**: | |
| - PlainMLP path (red) collapses toward origin | |
| - ResMLP path (blue) maintains meaningful position | |
| 3. **Bottom-left (PlainMLP Heatmap)**: Activations go dark (dead) in later layers | |
| 4. **Bottom-right (ResMLP Heatmap)**: Activations stay colorful (alive) throughout | |
| --- | |
| ### 3.6 Learning Comparison Summary | |
|  | |
| **The complete picture:** | |
| | Aspect | PlainMLP | ResMLP | | |
| |--------|----------|--------| | |
| | Loss Reduction | 0% | 99.5% | | |
| | Learning Status | FAILED | SUCCESS | | |
| | Gradient at L1 | 10⁻¹⁹ (dead) | 10⁻³ (healthy) | | |
| | Trainable? | NO | YES | | |
| --- | |
| ## 4. The Core Insight | |
| The residual connection `x = x + f(x)` does ONE simple but profound thing: | |
| > **It ensures that the gradient of the output with respect to the input is always at least 1.** | |
| ### Without residual (`x = f(x)`): | |
| ``` | |
| ∂output/∂input = ∂f/∂x | |
| This can be < 1, and (small)²⁰ → 0 | |
| ``` | |
| ### With residual (`x = x + f(x)`): | |
| ``` | |
| ∂output/∂input = 1 + ∂f/∂x | |
| This is always ≥ 1, so (≥1)²⁰ ≥ 1 | |
| ``` | |
| **This single change enables:** | |
| - 20-layer networks (this experiment) | |
| - 100-layer networks (ResNet-101) | |
| - 1000-layer networks (demonstrated in research) | |
| --- | |
| ## 5. Why This Matters | |
| ### 5.1 Historical Context | |
| Before ResNets (2015), training networks deeper than ~20 layers was extremely difficult. The vanishing gradient problem meant early layers couldn't learn. | |
| ### 5.2 The ResNet Revolution | |
| He et al.'s simple insight - add the input to the output - enabled: | |
| - **ImageNet SOTA** with 152 layers | |
| - **Foundation for modern architectures**: Transformers use residual connections in every attention block | |
| - **GPT, BERT, Vision Transformers** all rely on this principle | |
| ### 5.3 The Identity Mapping Perspective | |
| Another way to understand residuals: the network only needs to learn the **residual** (difference from identity), not the full transformation. Learning "do nothing" becomes trivially easy (just set weights to zero). | |
| --- | |
| ## 6. Reproducibility | |
| ### 6.1 Code | |
| All experiments can be reproduced using: | |
| ```bash | |
| cd projects/resmlp_comparison | |
| python experiment_fair.py # Run main experiment | |
| python visualize_micro_world.py # Generate visualizations | |
| ``` | |
| ### 6.2 Key Files | |
| - `experiment_fair.py`: Main experiment code | |
| - `visualize_micro_world.py`: Visualization generation | |
| - `results_fair.json`: Raw numerical results | |
| - `plots_fair/`: Primary result plots | |
| - `plots_micro/`: Micro-world explanation visualizations | |
| ### 6.3 Dependencies | |
| - PyTorch | |
| - NumPy | |
| - Matplotlib | |
| --- | |
| ## 7. Conclusion | |
| Through this controlled experiment, we've demonstrated: | |
| 1. **The Problem**: Deep networks without residual connections suffer catastrophic gradient vanishing (10⁻¹⁹ at layer 1) | |
| 2. **The Solution**: A simple `+ x` residual connection maintains healthy gradients (~10⁻³) throughout | |
| 3. **The Result**: 99.5% loss reduction with residuals vs. 0% without | |
| 4. **The Mechanism**: The residual adds a "1" to each gradient term in the chain rule, preventing multiplicative decay | |
| **The residual connection is perhaps the most important architectural innovation in deep learning history, enabling the training of arbitrarily deep networks.** | |
| --- | |
| ## Appendix: Numerical Results | |
| ### A.1 Loss History (Selected Steps) | |
| | Step | PlainMLP Loss | ResMLP Loss | | |
| |------|---------------|-------------| | |
| | 0 | 0.333 | 13.826 | | |
| | 100 | 0.333 | 0.328 | | |
| | 200 | 0.333 | 0.137 | | |
| | 300 | 0.333 | 0.091 | | |
| | 400 | 0.333 | 0.073 | | |
| | 500 | 0.333 | 0.063 | | |
| ### A.2 Gradient Norms by Layer (Final State) | |
| | Layer | PlainMLP | ResMLP | | |
| |-------|----------|--------| | |
| | 1 | 8.65e-19 | 3.78e-03 | | |
| | 5 | 2.15e-14 | 3.15e-03 | | |
| | 10 | 1.07e-09 | 2.52e-03 | | |
| | 15 | 5.29e-05 | 2.17e-03 | | |
| | 20 | 6.61e-03 | 1.91e-03 | | |
| ### A.3 Activation Statistics by Layer (Final State) | |
| | Layer | PlainMLP Std | ResMLP Std | | |
| |-------|--------------|------------| | |
| | 1 | 0.0000 | 0.1348 | | |
| | 5 | 0.0000 | 0.1456 | | |
| | 10 | 0.0001 | 0.1589 | | |
| | 15 | 0.0234 | 0.1678 | | |
| | 20 | 0.1795 | 0.1767 | | |