
🧠 Understanding Residual Connections: PlainMLP vs ResMLP

License: MIT Python 3.8+ PyTorch

A comprehensive visual deep dive into why residual connections solve the vanishing gradient problem and enable training of deep neural networks.

🎯 Key Finding

With identical initialization and architecture, where the only difference is the `+ x` residual connection, PlainMLP fails to learn at all (0% loss reduction) while ResMLP achieves a 99.5% loss reduction.

| Model | Initial Loss | Final Loss | Loss Reduction |
|---|---|---|---|
| PlainMLP (20 layers) | 0.333 | 0.333 | 0% ❌ |
| ResMLP (20 layers) | 13.826 | 0.063 | 99.5% ✅ |

📊 Visual Results

Training Loss Comparison

![Training Loss](plots_fair/training_loss.png)

Gradient Flow Analysis

| Layer | PlainMLP Gradient | ResMLP Gradient |
|---|---|---|
| Layer 1 (earliest) | 8.65 × 10⁻¹⁹ 💀 | 3.78 × 10⁻³ ✅ |
| Layer 10 (middle) | 1.07 × 10⁻⁹ | 2.52 × 10⁻³ ✅ |
| Layer 20 (last) | 6.61 × 10⁻³ | 1.91 × 10⁻³ ✅ |

🔬 Experimental Setup

Task: Distant Identity (Y = X)

  • Input: 1024 vectors of dimension 64, sampled from U(-1, 1)
  • Target: Y = X (identity mapping)
  • Challenge: Can a 20-layer network learn to simply pass input to output?
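The dataset above can be re-created in a few lines (a minimal sketch; the seed and exact sampling call are assumptions, not the repository's code):

```python
import torch

torch.manual_seed(0)                # illustrative seed
X = torch.rand(1024, 64) * 2 - 1    # 1024 vectors of dimension 64, uniform on [-1, 1)
Y = X.clone()                       # identity target: Y = X
```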

Architecture Comparison

| Component | PlainMLP | ResMLP |
|---|---|---|
| Layer operation | `x = ReLU(Linear(x))` | `x = x + ReLU(Linear(x))` |
| Depth | 20 layers | 20 layers |
| Hidden dimension | 64 | 64 |
| Parameters | 83,200 | 83,200 |
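The comparison above amounts to a one-line difference in the forward pass. A minimal sketch (class and variable names are illustrative, not the repository's exact code):

```python
import torch
import torch.nn as nn

class PlainMLP(nn.Module):
    def __init__(self, dim=64, depth=20):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))      # x = ReLU(Linear(x))
        return x

class ResMLP(nn.Module):
    def __init__(self, dim=64, depth=20):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))

    def forward(self, x):
        for layer in self.layers:
            x = x + torch.relu(layer(x))  # x = x + ReLU(Linear(x))
        return x
```

Both stacks contain 20 × (64×64 + 64) = 83,200 parameters, matching the table.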

Fair Initialization (Critical!)

Both models use identical initialization:

  • Weights: Kaiming (He) initialization × (1/√20) scaling
  • Biases: Zero
  • No LayerNorm, no BatchNorm, no dropout
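A sketch of applying this initialization (the helper name is assumed; only the recipe comes from the list above):

```python
import math
import torch
import torch.nn as nn

def init_fair(model, depth=20):
    """Kaiming (He) weights scaled by 1/sqrt(depth); zero biases."""
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            with torch.no_grad():
                m.weight.mul_(1 / math.sqrt(depth))
            nn.init.zeros_(m.bias)
```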

🎓 The Core Insight

The residual connection x = x + f(x) does ONE simple but profound thing:

It adds an identity term to every layer's gradient, so the backward signal always has a direct path from output to input.

Without residual (x = f(x)):

∂output/∂input = ∂f/∂x
This can be ≪ 1, and (small)²⁰ → 0

With residual (x = x + f(x)):

∂output/∂input = 1 + ∂f/∂x
The constant 1 keeps each factor near 1, so the product over 20 layers does not collapse to 0
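The arithmetic can be checked with a toy scalar example. Assume (purely for illustration) that every layer's local derivative ∂f/∂x is 0.5:

```python
# Without residual: the gradient is a product of 20 small local derivatives.
plain = 0.5 ** 20        # vanishes toward zero
# With residual: each factor is 1 + (local derivative), never a bare small number.
residual = 1.5 ** 20     # grows rather than collapsing
print(plain, residual)
```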

πŸ“ Repository Structure

```
resmlp_comparison/
├── README.md                    # This file
├── experiment_final.py          # Main experiment code
├── experiment_fair.py           # Fair comparison experiment
├── visualize_micro_world.py     # Visualization generation
├── results_fair.json            # Raw numerical results
├── report_final.md              # Detailed analysis report
│
├── plots_fair/                  # Primary result plots
│   ├── training_loss.png
│   ├── gradient_magnitude.png
│   ├── activation_mean.png
│   └── activation_std.png
│
└── plots_micro/                 # Educational visualizations
    ├── 1_signal_flow.png
    ├── 2_gradient_flow.png
    ├── 3_highway_concept.png
    ├── 4_chain_rule.png
    ├── 5_layer_transformation.png
    └── 6_learning_comparison.png
```

🚀 Quick Start

Installation

```bash
pip install torch numpy matplotlib
```

Run Experiment

```bash
# Run the main fair comparison experiment
python experiment_fair.py

# Generate micro-world visualizations
python visualize_micro_world.py
```

📈 Detailed Visualizations

1. Signal Flow Through Layers

![Signal Flow](plots_micro/1_signal_flow.png)

  • PlainMLP signal collapses to near-zero by layer 15-20
  • ResMLP signal stays stable throughout all layers

2. Gradient Flow (Backward Pass)

![Gradient Flow](plots_micro/2_gradient_flow.png)

  • PlainMLP: Gradient decays from 10⁻³ to 10⁻¹⁹ (essentially zero!)
  • ResMLP: Gradient stays healthy at ~10⁻³ across ALL layers
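Per-layer gradient magnitudes like those above can be probed with a short snippet (a sketch under assumed settings, using a plain linear stack for brevity; the real measurement lives in experiment_fair.py):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(*(nn.Linear(64, 64) for _ in range(20)))  # no residual connections
x = torch.rand(32, 64) * 2 - 1
loss = ((model(x) - x) ** 2).mean()   # MSE against the identity target
loss.backward()
for i, layer in enumerate(model, start=1):
    print(f"layer {i:2d} |grad| = {layer.weight.grad.norm().item():.2e}")
```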

3. The Highway Concept

![Highway](plots_micro/3_highway_concept.png)

  • The + x creates a direct "gradient highway" for information flow

4. Chain Rule Mathematics

![Chain Rule](plots_micro/4_chain_rule.png)

  • Visual explanation of why gradients vanish mathematically

🔑 Why This Matters

Historical Context

Before ResNets (2015), training networks deeper than ~20 layers was extremely difficult due to vanishing gradients.

The ResNet Revolution

He et al.'s simple insight enabled:

  • ImageNet SOTA with 152 layers
  • Foundation for modern architectures: Transformers use residual connections in every attention block
  • GPT, BERT, Vision Transformers all rely on this principle

📚 Reports

See report_final.md in the repository for the detailed analysis.

📖 Citation

If you find this educational resource helpful, please consider citing:

```bibtex
@misc{resmlp_comparison,
  title={Understanding Residual Connections: A Visual Deep Dive},
  author={AmberLJC},
  year={2024},
  url={https://huggingface.co/AmberLJC/resmlp_comparison}
}
```

📄 License

MIT License - feel free to use for educational purposes!

πŸ™ Acknowledgments

Inspired by the seminal work:

  • He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. CVPR.