PlainMLP vs ResMLP Comparison - Distant Identity Task
Objective
Compare a 20-layer PlainMLP with a 20-layer ResMLP on a synthetic "Distant Identity" task (targets equal inputs) to demonstrate the vanishing-gradient problem and show how residual connections mitigate it.
Tasks
Phase 1: Implementation
- Implement PlainMLP (20 Linear+ReLU layers, hidden dim 64, Kaiming (He) initialization)
- Implement ResMLP (same depth and width, with a residual skip connection around each layer, Kaiming (He) initialization)
- Generate synthetic data (1024 vectors of dim 64 sampled from U(-1, 1); targets Y = X, i.e. the identity map)
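A minimal sketch of Phase 1. The class and function names (PlainMLP, ResMLP, make_data) follow the task wording but are otherwise my own; the exact init scheme (kaiming_normal_ with zero biases) is one reasonable reading of "Kaiming He init":

```python
import torch
import torch.nn as nn

def _init_linears(layers):
    # Kaiming (He) init for ReLU networks; biases start at zero.
    for layer in layers:
        nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
        nn.init.zeros_(layer.bias)

class PlainMLP(nn.Module):
    """Depth x (Linear -> ReLU), no skip connections."""
    def __init__(self, dim=64, depth=20):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        _init_linears(self.layers)

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x

class ResMLP(nn.Module):
    """Same stack, but each block adds its input back: x + ReLU(Wx + b)."""
    def __init__(self, dim=64, depth=20):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        _init_linears(self.layers)

    def forward(self, x):
        for layer in self.layers:
            x = x + torch.relu(layer(x))
        return x

def make_data(n=1024, dim=64):
    """Inputs sampled from U(-1, 1); targets are the inputs themselves."""
    X = torch.empty(n, dim).uniform_(-1.0, 1.0)
    return X, X.clone()
```

The residual form `x + relu(layer(x))` keeps an identity path through every block, which is exactly what makes the distant-identity target easy for ResMLP and hard for PlainMLP.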
Phase 2: Training
- Train both models for 500 steps with Adam (lr=1e-3)
- Record MSE loss at each step
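The training phase can be sketched as a single loop; the `train` helper below is illustrative (its name and signature are mine), using full-batch Adam at lr = 1e-3 as specified:

```python
import torch
import torch.nn as nn

def train(model, X, Y, steps=500, lr=1e-3):
    """Train with Adam on MSE loss; return the per-step loss history."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    history = []
    for _ in range(steps):
        opt.zero_grad()
        loss = loss_fn(model(X), Y)
        loss.backward()
        opt.step()
        history.append(loss.item())  # record MSE at every step
    return history
```

Running this on both models with the same data yields the two loss curves compared in Phase 4.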
Phase 3: Final State Analysis
- Implement PyTorch hooks to capture gradients and activations
- Perform a forward/backward pass on a fresh random batch
- Record the L2 norm of the gradient at each layer
- Record the mean and std of the activations at each layer
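One way to implement the analysis pass, assuming the models expose their Linear layers as submodules (the `analyze` name and return format are mine): forward hooks collect activation statistics, and the per-layer weight-gradient L2 norms are read after `backward()`:

```python
import torch
import torch.nn as nn

def analyze(model, x, y):
    """One forward/backward pass on (x, y); returns per-layer
    weight-gradient L2 norms and (mean, std) of each layer's output."""
    act_stats, handles = [], []

    def hook(module, inputs, output):
        # Forward hook: record statistics of this layer's activation.
        act_stats.append((output.mean().item(), output.std().item()))

    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
    for m in linears:
        handles.append(m.register_forward_hook(hook))

    model.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()

    for h in handles:
        h.remove()  # always detach hooks after use

    grad_norms = [m.weight.grad.norm().item() for m in linears]
    return grad_norms, act_stats
```

For PlainMLP the gradient norms should decay sharply toward the input layers; for ResMLP they should stay roughly flat across depth.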
Phase 4: Visualization & Reporting
- Plot Training Loss vs Steps (both models)
- Plot Gradient Magnitude vs Layer Depth
- Plot Activation Mean vs Layer Depth
- Plot Activation Std vs Layer Depth
- Write summary report with analysis
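All four plots share the same shape (one curve per model), so a single helper covers them; this sketch uses matplotlib with the headless Agg backend, and the function name and arguments are illustrative:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend: render to file, no display needed
import matplotlib.pyplot as plt

def plot_comparison(plain_vals, res_vals, ylabel, title, path, xlabel="Step"):
    """Overlay one metric for both models and save the figure to `path`."""
    plt.figure()
    plt.plot(plain_vals, label="PlainMLP")
    plt.plot(res_vals, label="ResMLP")
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.savefig(path)
    plt.close()
```

For the loss and gradient-magnitude plots a log y-scale (`plt.yscale("log")`) tends to make the vanishing-gradient gap far more visible; the depth plots would pass `xlabel="Layer"`.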
Expected Outcomes
- PlainMLP: gradients vanish toward the early layers; the identity function is learned poorly
- ResMLP: gradient norms stay stable across depth; the identity function is learned successfully