# 🧠 Understanding Residual Connections: PlainMLP vs ResMLP
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-ee4c2c.svg)](https://pytorch.org/)
A comprehensive visual deep dive into **why residual connections solve the vanishing gradient problem** and enable training of deep neural networks.
## 🎯 Key Finding
> With identical initialization and architecture, the **only difference being `+ x` (residual connection)**, PlainMLP completely fails to learn (0% loss reduction) while ResMLP achieves **99.5% loss reduction**.
| Model | Initial Loss | Final Loss | Loss Reduction |
|-------|-------------|------------|----------------|
| PlainMLP (20 layers) | 0.333 | 0.333 | **0%** ❌ |
| ResMLP (20 layers) | 13.826 | 0.063 | **99.5%** ✅ |
## 📊 Visual Results
### Training Loss Comparison
![Training Loss](plots_fair/training_loss.png)
### Gradient Flow Analysis
| Layer | PlainMLP Gradient | ResMLP Gradient |
|-------|-------------------|-----------------|
| Layer 1 (earliest) | 8.65 × 10⁻¹⁹ 💀 | 3.78 × 10⁻³ ✅ |
| Layer 10 (middle) | 1.07 × 10⁻⁹ | 2.52 × 10⁻³ ✅ |
| Layer 20 (last) | 6.61 × 10⁻³ | 1.91 × 10⁻³ ✅ |
## 🔬 Experimental Setup
### Task: Distant Identity (Y = X)
- **Input**: 1024 vectors of dimension 64, sampled from U(-1, 1)
- **Target**: Y = X (identity mapping)
- **Challenge**: Can a 20-layer network learn to simply pass input to output?
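The data described above can be generated in a few lines of PyTorch (the seed and variable names here are illustrative, not taken from the experiment scripts):

```python
import torch

torch.manual_seed(0)  # illustrative seed, not from the experiment

# 1024 samples of dimension 64, drawn uniformly from [-1, 1]
X = torch.empty(1024, 64).uniform_(-1.0, 1.0)
Y = X.clone()  # identity target: the network must learn f(x) = x
```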
### Architecture Comparison
| Component | PlainMLP | ResMLP |
|-----------|----------|--------|
| Layer operation | `x = ReLU(Linear(x))` | `x = x + ReLU(Linear(x))` |
| Depth | 20 layers | 20 layers |
| Hidden dimension | 64 | 64 |
| Parameters | 83,200 | 83,200 |
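A minimal PyTorch sketch of the two architectures in the table (class names are illustrative; see `experiment_fair.py` for the actual code). Note that the only difference between the two `forward` methods is the `x +`:

```python
import torch
import torch.nn as nn

class PlainMLP(nn.Module):
    """20 stacked Linear+ReLU layers, no skip connections."""
    def __init__(self, dim: int = 64, depth: int = 20):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))       # x = ReLU(Linear(x))
        return x

class ResMLP(nn.Module):
    """Identical layers, but each block adds its input back."""
    def __init__(self, dim: int = 64, depth: int = 20):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))

    def forward(self, x):
        for layer in self.layers:
            x = x + torch.relu(layer(x))   # x = x + ReLU(Linear(x))
        return x
```

Both models have 20 × (64 × 64 + 64) = 83,200 parameters, matching the table.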
### Fair Initialization (Critical!)
Both models use **identical initialization**:
- **Weights**: Kaiming He × (1/√20) scaling
- **Biases**: Zero
- **No LayerNorm, no BatchNorm, no dropout**
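The initialization scheme can be sketched as below. The README does not specify whether the normal or uniform Kaiming variant was used, so this assumes `kaiming_normal_`; the function name is illustrative:

```python
import math
import torch
import torch.nn as nn

def fair_init(model: nn.Module, depth: int = 20) -> None:
    """Kaiming (He) init scaled by 1/sqrt(depth); zero biases."""
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            with torch.no_grad():
                m.weight.mul_(1.0 / math.sqrt(depth))  # the 1/sqrt(20) scaling
            nn.init.zeros_(m.bias)
```

Applying the same function to both models guarantees the comparison isolates the residual connection itself.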
## 🎓 The Core Insight
The residual connection `x = x + f(x)` does ONE simple but profound thing:
> **It adds an identity path, so each layer's gradient is `1 + ∂f/∂x` rather than `∂f/∂x`, and it cannot collapse to zero as depth grows.**
### Without residual (`x = f(x)`):
```
∂output/∂input = ∂f/∂x
This can be < 1, and (small)²⁰ → 0
```
### With residual (`x = x + f(x)`):
```
∂output/∂input = 1 + ∂f/∂x
The identity term keeps this near 1 even when ∂f/∂x → 0,
so the product over 20 layers does not vanish
```
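A toy scalar calculation makes the difference concrete. Suppose every layer's local derivative has magnitude 0.6 (a made-up value for illustration):

```python
# Toy 1-D illustration of the two formulas above.
depth = 20
plain_grad = 0.6 ** depth        # without residual: (df/dx)^20
res_grad = (1 + 0.6) ** depth    # with residual: (1 + df/dx)^20

print(f"plain:    {plain_grad:.2e}")   # 3.66e-05, effectively vanished
print(f"residual: {res_grad:.2e}")     # 1.21e+04, gradient survives
```

The same per-layer derivative produces a gradient nine orders of magnitude larger once the identity term is added.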
## πŸ“ Repository Structure
```
resmlp_comparison/
├── README.md                  # This file
├── experiment_final.py        # Main experiment code
├── experiment_fair.py         # Fair comparison experiment
├── visualize_micro_world.py   # Visualization generation
├── results_fair.json          # Raw numerical results
├── report_final.md            # Detailed analysis report
│
├── plots_fair/                # Primary result plots
│   ├── training_loss.png
│   ├── gradient_magnitude.png
│   ├── activation_mean.png
│   └── activation_std.png
│
└── plots_micro/               # Educational visualizations
    ├── 1_signal_flow.png
    ├── 2_gradient_flow.png
    ├── 3_highway_concept.png
    ├── 4_chain_rule.png
    ├── 5_layer_transformation.png
    └── 6_learning_comparison.png
```
## 🚀 Quick Start
### Installation
```bash
pip install torch numpy matplotlib
```
### Run Experiment
```bash
# Run the main fair comparison experiment
python experiment_fair.py
# Generate micro-world visualizations
python visualize_micro_world.py
```
## 📈 Detailed Visualizations
### 1. Signal Flow Through Layers
![Signal Flow](plots_micro/1_signal_flow.png)
- PlainMLP's signal collapses to near zero by layers 15-20
- ResMLP's signal stays stable through all layers
### 2. Gradient Flow (Backward Pass)
![Gradient Flow](plots_micro/2_gradient_flow.png)
- PlainMLP: Gradient decays from 10⁻³ to 10⁻¹⁹ (essentially zero!)
- ResMLP: Gradient stays healthy at ~10⁻³ across ALL layers
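The gradient gap can be reproduced in a short self-contained sketch (this is not `experiment_fair.py`; the scaled initialization here only approximates the scheme described above, by shrinking PyTorch's default `nn.Linear` weights by 1/√20):

```python
import torch
import torch.nn as nn

def first_layer_grad(residual: bool, depth: int = 20, dim: int = 64) -> float:
    """Mean |gradient| at the earliest layer after one backward pass."""
    torch.manual_seed(0)
    layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
    with torch.no_grad():  # approximate the 1/sqrt(20)-scaled initialization
        for layer in layers:
            layer.weight.mul_(1.0 / depth ** 0.5)
            layer.bias.zero_()

    x = torch.empty(128, dim).uniform_(-1.0, 1.0)
    out = x
    for layer in layers:
        f = torch.relu(layer(out))
        out = out + f if residual else f   # the only difference

    loss = ((out - x) ** 2).mean()         # identity task: Y = X
    loss.backward()
    return layers[0].weight.grad.abs().mean().item()

plain = first_layer_grad(residual=False)
res = first_layer_grad(residual=True)
print(f"plain: {plain:.2e}  residual: {res:.2e}")
```

With the residual path, the earliest layer receives a usable gradient; without it, the gradient at layer 1 is many orders of magnitude smaller.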
### 3. The Highway Concept
![Highway](plots_micro/3_highway_concept.png)
- The `+ x` creates a direct "gradient highway" for information flow
### 4. Chain Rule Mathematics
![Chain Rule](plots_micro/4_chain_rule.png)
- Visual explanation of why gradients vanish mathematically
## 🔑 Why This Matters
### Historical Context
Before ResNets (2015), training networks deeper than ~20 layers was extremely difficult due to vanishing gradients.
### The ResNet Revolution
He et al.'s simple insight enabled:
- **ImageNet SOTA** with 152 layers
- **Foundation for modern architectures**: Transformers use residual connections in every attention block
- **GPT, BERT, Vision Transformers** all rely on this principle
## 📚 Reports
- [`report_final.md`](report_final.md) - Comprehensive analysis with all visualizations
- [`report_fair.md`](report_fair.md) - Fair comparison methodology
- [`report.md`](report.md) - Initial experiment report
## 📖 Citation
If you find this educational resource helpful, please consider citing:
```bibtex
@misc{resmlp_comparison,
  title={Understanding Residual Connections: A Visual Deep Dive},
  author={AmberLJC},
  year={2024},
  url={https://huggingface.co/AmberLJC/resmlp_comparison}
}
```
## 📄 License
MIT License - feel free to use for educational purposes!
## πŸ™ Acknowledgments
Inspired by the seminal work:
- He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. CVPR.