# 🧠 Understanding Residual Connections: PlainMLP vs ResMLP

[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-ee4c2c.svg)](https://pytorch.org/)

A comprehensive visual deep dive into **why residual connections solve the vanishing gradient problem** and enable training of deep neural networks.

## 🎯 Key Finding

> With identical initialization and architecture, where the **only difference is the `+ x` residual connection**, PlainMLP fails to learn entirely (0% loss reduction) while ResMLP achieves a **99.5% loss reduction**.

| Model | Initial Loss | Final Loss | Loss Reduction |
|-------|-------------|------------|----------------|
| PlainMLP (20 layers) | 0.333 | 0.333 | **0%** ❌ |
| ResMLP (20 layers) | 13.826 | 0.063 | **99.5%** ✅ |

## 📊 Visual Results

### Training Loss Comparison
![Training Loss](plots_fair/training_loss.png)

### Gradient Flow Analysis
| Layer | PlainMLP Gradient | ResMLP Gradient |
|-------|-------------------|-----------------|
| Layer 1 (earliest) | 8.65 × 10⁻¹⁹ 💀 | 3.78 × 10⁻³ ✅ |
| Layer 10 (middle) | 1.07 × 10⁻⁹ | 2.52 × 10⁻³ ✅ |
| Layer 20 (last) | 6.61 × 10⁻³ | 1.91 × 10⁻³ ✅ |

## 🔬 Experimental Setup

### Task: Distant Identity (Y = X)
- **Input**: 1024 vectors of dimension 64, sampled from U(-1, 1)
- **Target**: Y = X (identity mapping)
- **Challenge**: Can a 20-layer network learn to simply pass input to output?
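As a concrete sketch, the dataset described above can be generated in a few lines of PyTorch (variable names and the seed here are illustrative, not taken from the repo's code):

```python
import torch

# Illustrative data setup: 1024 samples of dimension 64 drawn from
# U(-1, 1); the target is the input itself (Y = X).
torch.manual_seed(0)
N, D = 1024, 64
X = torch.rand(N, D) * 2 - 1  # uniform on [-1, 1)
Y = X.clone()                 # identity target
```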

### Architecture Comparison

| Component | PlainMLP | ResMLP |
|-----------|----------|--------|
| Layer operation | `x = ReLU(Linear(x))` | `x = x + ReLU(Linear(x))` |
| Depth | 20 layers | 20 layers |
| Hidden dimension | 64 | 64 |
| Parameters | 83,200 | 83,200 |
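A minimal sketch of the two architectures from the table above (class and argument names are assumptions; the repo's actual definitions live in `experiment_fair.py`). Note that the only difference between the two `forward` methods is the `x +` term:

```python
import torch
import torch.nn as nn

class PlainMLP(nn.Module):
    """Each layer: x = ReLU(Linear(x)) -- no skip connection."""
    def __init__(self, dim=64, depth=20):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x

class ResMLP(nn.Module):
    """Each layer: x = x + ReLU(Linear(x)) -- the only difference."""
    def __init__(self, dim=64, depth=20):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))

    def forward(self, x):
        for layer in self.layers:
            x = x + torch.relu(layer(x))
        return x
```

Both models have 20 × (64 × 64 + 64) = 83,200 parameters, matching the table.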

### Fair Initialization (Critical!)
Both models use **identical initialization**:
- **Weights**: Kaiming (He) initialization, scaled by 1/√20
- **Biases**: Zero
- **No LayerNorm, no BatchNorm, no dropout**
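A sketch of that initialization scheme (whether the experiment uses the normal or uniform Kaiming variant is an assumption; the function name is illustrative):

```python
import math
import torch
import torch.nn as nn

def init_fair(model, depth=20):
    """Kaiming (He) init scaled by 1/sqrt(depth); biases zeroed."""
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            with torch.no_grad():
                m.weight.mul_(1.0 / math.sqrt(depth))  # 1/sqrt(20) scaling
            nn.init.zeros_(m.bias)
```

Applying the same function to both models guarantees the comparison isolates the residual connection itself.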

## 🎓 The Core Insight

The residual connection `x = x + f(x)` does ONE simple but profound thing:

> **It ensures that the gradient of the output with respect to the input is always at least 1.**

### Without residual (`x = f(x)`):
```
∂output/∂input = ∂f/∂x
This can have magnitude < 1, and (small)²⁰ → 0
```

### With residual (`x = x + f(x)`):
```
∂output/∂input = 1 + ∂f/∂x
With ∂f/∂x ≥ 0 this is always ≥ 1, so (≥1)²⁰ ≥ 1
```
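The difference between those two formulas compounds over depth. A quick numeric check (the per-layer derivative value 0.5 is arbitrary, chosen only to illustrate the scalar case):

```python
# Product of 20 per-layer derivatives, with and without the identity term.
d = 0.5                    # an arbitrary per-layer derivative below 1
plain = d ** 20            # without residual: collapses toward 0
residual = (1 + d) ** 20   # with residual: stays at or above 1
print(f"plain: {plain:.2e}, residual: {residual:.2e}")
```

Twenty multiplications by 0.5 leave essentially nothing of the gradient, while the identity term keeps the residual product well above 1.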

## πŸ“ Repository Structure

```
resmlp_comparison/
├── README.md                    # This file
├── experiment_final.py          # Main experiment code
├── experiment_fair.py           # Fair comparison experiment
├── visualize_micro_world.py     # Visualization generation
├── results_fair.json            # Raw numerical results
├── report_final.md              # Detailed analysis report
│
├── plots_fair/                  # Primary result plots
│   ├── training_loss.png
│   ├── gradient_magnitude.png
│   ├── activation_mean.png
│   └── activation_std.png
│
└── plots_micro/                 # Educational visualizations
    ├── 1_signal_flow.png
    ├── 2_gradient_flow.png
    ├── 3_highway_concept.png
    ├── 4_chain_rule.png
    ├── 5_layer_transformation.png
    └── 6_learning_comparison.png
```

## 🚀 Quick Start

### Installation
```bash
pip install torch numpy matplotlib
```

### Run Experiment
```bash
# Run the main fair comparison experiment
python experiment_fair.py

# Generate micro-world visualizations
python visualize_micro_world.py
```

## 📈 Detailed Visualizations

### 1. Signal Flow Through Layers
![Signal Flow](plots_micro/1_signal_flow.png)
- PlainMLP signal collapses to near-zero by layer 15-20
- ResMLP signal stays stable throughout all layers

### 2. Gradient Flow (Backward Pass)
![Gradient Flow](plots_micro/2_gradient_flow.png)
- PlainMLP: Gradient decays from 10⁻³ to 10⁻¹⁹ (essentially zero!)
- ResMLP: Gradient stays healthy at ~10⁻³ across ALL layers

### 3. The Highway Concept
![Highway](plots_micro/3_highway_concept.png)
- The `+ x` creates a direct "gradient highway" for information flow

### 4. Chain Rule Mathematics
![Chain Rule](plots_micro/4_chain_rule.png)
- Visual explanation of why gradients vanish mathematically

## 🔑 Why This Matters

### Historical Context
Before ResNets (2015), training networks deeper than ~20 layers was extremely difficult due to vanishing gradients.

### The ResNet Revolution
He et al.'s simple insight enabled:
- **ImageNet SOTA** with 152 layers
- **Foundation for modern architectures**: Transformers place a residual connection around every attention and feed-forward sublayer
- **GPT, BERT, Vision Transformers** all rely on this principle

## 📚 Reports

- [`report_final.md`](report_final.md) - Comprehensive analysis with all visualizations
- [`report_fair.md`](report_fair.md) - Fair comparison methodology
- [`report.md`](report.md) - Initial experiment report

## 📖 Citation

If you find this educational resource helpful, please consider citing:

```bibtex
@misc{resmlp_comparison,
  title={Understanding Residual Connections: A Visual Deep Dive},
  author={AmberLJC},
  year={2024},
  url={https://huggingface.co/AmberLJC/resmlp_comparison}
}
```

## 📄 License

MIT License - feel free to use for educational purposes!

## 🙏 Acknowledgments

Inspired by the seminal work:
- He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. CVPR.