"""
PlainMLP vs ResMLP Comparison on the Distant Identity Task (Final Version)

This experiment demonstrates the vanishing gradient problem in deep networks
and how residual connections solve it.

Key design choices:
1. PlainMLP: standard x = ReLU(Linear(x)) - suffers from vanishing gradients.
2. ResMLP: x = x + ReLU(Linear(x)) with zero-initialized biases and a small
   weight scale.
   - This lets the network start as a near-identity map and learn deviations.
   - Gradients can flow through the skip connection even when the residual
     branch is small.

The "Distant Identity" task (Y = X) is particularly revealing because:
- ResMLP can trivially solve it by zeroing the residual branch (identity shortcut).
- PlainMLP must learn a complex function composition to approximate identity.
- With ReLU, PlainMLP can never learn identity exactly (negative values are zeroed).
"""
|
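# Why the skip connection matters (a sketch of the standard argument, not
# code from the experiment): for a residual stack x_{l+1} = x_l + F_l(x_l),
# the Jacobian of the output with respect to an early activation is
# prod_k (I + J_k), which always contains an identity term, so gradient
# reaches early layers no matter how small each branch Jacobian J_k is.
# For a plain stack x_{l+1} = F_l(x_l), the Jacobian is prod_k J_k, whose
# norm tends to shrink (or explode) geometrically with depth.
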
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple
import json
import os

# Reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Experiment configuration
NUM_LAYERS = 20
HIDDEN_DIM = 64
NUM_SAMPLES = 1024
TRAINING_STEPS = 500
LEARNING_RATE = 1e-3
BATCH_SIZE = 64

print(f"[Config] Layers: {NUM_LAYERS}, Hidden Dim: {HIDDEN_DIM}")
print(f"[Config] Samples: {NUM_SAMPLES}, Steps: {TRAINING_STEPS}, LR: {LEARNING_RATE}")
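
# With these settings each model stacks 20 Linear(64, 64) layers, i.e.
# 20 * (64 * 64 + 64) = 83,200 trainable parameters per model; the skip
# connection adds no parameters, so the comparison is parameter-matched.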


class PlainMLP(nn.Module):
    """Plain MLP: x = ReLU(Linear(x)) for each layer.

    This architecture suffers from:
    1. Vanishing gradients - gradients must flow through all layers multiplicatively.
    2. Information loss - ReLU zeros negative values at each layer.
    3. Complex optimization - it must learn an exact function composition to
       approximate identity.
    """

    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            layer = nn.Linear(dim, dim)
            # He (Kaiming) initialization, appropriate for ReLU nonlinearities
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            nn.init.zeros_(layer.bias)
            self.layers.append(layer)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = self.activation(layer(x))
        return x
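

# Worked bound (a consequence of the final ReLU above, not computed by the
# script): PlainMLP's output is elementwise non-negative, while targets
# Y = X are uniform on (-1, 1). The best non-negative prediction for a
# target x is max(x, 0), so the per-element MSE is floored at
# E[(max(x, 0) - x)^2] = integral from -1 to 0 of x^2 * (1/2) dx = 1/6,
# i.e. roughly 0.167, no matter how long PlainMLP trains.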


class ResMLP(nn.Module):
    """Residual MLP: x = x + ReLU(Linear(x)) for each layer.

    Key advantages:
    1. Identity shortcut - gradients flow directly to early layers via the
       skip connection.
    2. Residual learning - the network learns a deviation from identity, not
       the full mapping.
    3. For the identity task - the optimal solution is to zero the residual
       branch.

    Uses small weight initialization (scaled by 1/sqrt(num_layers)) to:
    - start with near-identity behavior,
    - prevent activation explosion,
    - allow gradual learning of the residuals.
    """

    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList()
        self.num_layers = num_layers

        for _ in range(num_layers):
            layer = nn.Linear(dim, dim)
            # He (Kaiming) initialization, then scale down by 1/sqrt(num_layers)
            # so the accumulated residual branches do not blow up with depth
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            layer.weight.data *= 1.0 / np.sqrt(num_layers)
            nn.init.zeros_(layer.bias)
            self.layers.append(layer)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = x + self.activation(layer(x))
        return x
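

# Note: if every Linear layer's weights and biases were exactly zero, each
# block would compute x + ReLU(0) = x, making the whole ResMLP exactly the
# identity map. This is the "trivial solution" referenced in the module
# docstring: training only needs to shrink the residual branches toward zero.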


def generate_identity_data(num_samples: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate synthetic data for the identity task: Y = X, with X ~ U(-1, 1)."""
    X = torch.empty(num_samples, dim).uniform_(-1, 1)
    Y = X.clone()
    return X, Y
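

# Design note: the data dimension equals HIDDEN_DIM because both models map
# R^dim -> R^dim with no separate input or output projection layers.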


def train_model(model: nn.Module, X: torch.Tensor, Y: torch.Tensor,
                steps: int, lr: float, batch_size: int) -> List[float]:
    """Train the model with Adam on MSE and record the loss at each step."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    losses = []

    num_samples = X.shape[0]

    for step in range(steps):
        # Sample a minibatch with replacement (no epoch structure)
        indices = torch.randint(0, num_samples, (batch_size,))
        batch_x = X[indices]
        batch_y = Y[indices]

        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        if step % 100 == 0:
            print(f"    Step {step}/{steps}, Loss: {loss.item():.6f}")

    return losses
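

# Note on sampling: because batches are drawn with replacement via
# torch.randint, over 500 steps of batch size 64 the loop draws
# 500 * 64 = 32,000 samples, i.e. each of the 1,024 training points is
# seen about 31 times on average.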


class ActivationGradientHook:
    """Hooks that capture per-layer activations and gradients.

    The hooks are registered on the nn.Linear modules, so the captured
    "activations" are the Linear outputs (pre-ReLU, and pre-skip in ResMLP)
    and the captured gradients are taken with respect to those outputs.
    """

    def __init__(self):
        self.activations: List[torch.Tensor] = []
        self.gradients: List[torch.Tensor] = []
        self.handles = []

    def register_hooks(self, model: nn.Module):
        """Register forward and backward hooks on each layer."""
        for layer in model.layers:
            handle_fwd = layer.register_forward_hook(self._forward_hook)
            handle_bwd = layer.register_full_backward_hook(self._backward_hook)
            self.handles.extend([handle_fwd, handle_bwd])

    def _forward_hook(self, module, input, output):
        self.activations.append(output.detach().clone())

    def _backward_hook(self, module, grad_input, grad_output):
        self.gradients.append(grad_output[0].detach().clone())

    def clear(self):
        self.activations = []
        self.gradients = []

    def remove_hooks(self):
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def get_activation_stats(self) -> Tuple[List[float], List[float]]:
        """Get the mean and std of the activations for each layer."""
        means = [act.mean().item() for act in self.activations]
        stds = [act.std().item() for act in self.activations]
        return means, stds

    def get_gradient_norms(self) -> List[float]:
        """Get the L2 norm of the gradients for each layer, in forward order.

        Backward hooks fire in reverse execution order (last layer first),
        so the captured list is reversed to align with layer depth.
        """
        norms = [grad.norm(2).item() for grad in reversed(self.gradients)]
        return norms
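

# Typical usage (a sketch mirroring analyze_final_state below):
#   hook = ActivationGradientHook()
#   hook.register_hooks(model)
#   loss = nn.MSELoss()(model(x), y)
#   loss.backward()
#   means, stds = hook.get_activation_stats()
#   norms = hook.get_gradient_norms()  # index 0 = first layer (farthest from the loss)
#   hook.remove_hooks()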


def analyze_final_state(model: nn.Module, dim: int, batch_size: int = 64) -> Dict:
    """Run one forward/backward pass and capture activation/gradient stats."""
    hook = ActivationGradientHook()
    hook.register_hooks(model)

    # Fresh random batch from the same distribution as the training data
    X_test = torch.empty(batch_size, dim).uniform_(-1, 1)
    Y_test = X_test.clone()

    model.zero_grad()
    output = model(X_test)
    loss = nn.MSELoss()(output, Y_test)
    loss.backward()

    act_means, act_stds = hook.get_activation_stats()
    grad_norms = hook.get_gradient_norms()

    hook.remove_hooks()

    return {
        'activation_means': act_means,
        'activation_stds': act_stds,
        'gradient_norms': grad_norms,
        'final_loss': loss.item()
    }
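

# Interpretation caveat: because the hooks sit on the Linear modules, the
# ResMLP gradient norms measure what flows into each residual *branch*; the
# gradient carried by the skip path itself bypasses these hooks entirely.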


def plot_training_loss(plain_losses: List[float], res_losses: List[float], save_path: str):
    """Plot training loss curves for both models."""
    fig, ax = plt.subplots(figsize=(10, 6))
    steps = range(len(plain_losses))

    ax.plot(steps, plain_losses, label=f'PlainMLP ({NUM_LAYERS} layers)', color='#e74c3c',
            alpha=0.8, linewidth=2)
    ax.plot(steps, res_losses, label=f'ResMLP ({NUM_LAYERS} layers)', color='#3498db',
            alpha=0.8, linewidth=2)

    ax.set_xlabel('Training Steps', fontsize=12)
    ax.set_ylabel('MSE Loss', fontsize=12)
    ax.set_title('Training Loss: PlainMLP vs ResMLP on Identity Task (Y = X)', fontsize=14)
    ax.legend(fontsize=11, loc='upper right')
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')

    # Summary box with final losses
    final_plain = plain_losses[-1]
    final_res = res_losses[-1]
    textstr = (f'Final Loss:\n  PlainMLP: {final_plain:.4f}\n  ResMLP: {final_res:.4f}\n'
               f'  Improvement: {final_plain / final_res:.1f}x')
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
    ax.text(0.02, 0.02, textstr, transform=ax.transAxes, fontsize=10,
            verticalalignment='bottom', bbox=props)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[Plot] Saved training loss plot to {save_path}")


def plot_gradient_magnitudes(plain_grads: List[float], res_grads: List[float], save_path: str):
    """Plot gradient magnitude vs. layer depth."""
    fig, ax = plt.subplots(figsize=(10, 6))
    layers = range(1, len(plain_grads) + 1)

    ax.plot(layers, plain_grads, 'o-', label='PlainMLP', color='#e74c3c',
            markersize=8, linewidth=2, markeredgecolor='white', markeredgewidth=1)
    ax.plot(layers, res_grads, 's-', label='ResMLP', color='#3498db',
            markersize=8, linewidth=2, markeredgecolor='white', markeredgewidth=1)

    ax.set_xlabel(f'Layer Depth (1 = first layer, {NUM_LAYERS} = last layer)', fontsize=12)
    ax.set_ylabel('Gradient L2 Norm (log scale)', fontsize=12)
    ax.set_title(f'Gradient Magnitude vs Layer Depth (After {TRAINING_STEPS} Training Steps)', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')

    # Shade the gap between the two curves
    ax.fill_between(layers, plain_grads, res_grads, alpha=0.15, color='gray')

    ax.annotate('Gradients flow more\nuniformly in ResMLP',
                xy=(10, res_grads[9]), xytext=(5, res_grads[9] * 5),
                fontsize=10, color='#3498db',
                arrowprops=dict(arrowstyle='->', color='#3498db', alpha=0.7))

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[Plot] Saved gradient magnitude plot to {save_path}")


def plot_activation_means(plain_means: List[float], res_means: List[float], save_path: str):
    """Plot activation mean vs. layer depth."""
    fig, ax = plt.subplots(figsize=(10, 6))
    layers = range(1, len(plain_means) + 1)

    ax.plot(layers, plain_means, 'o-', label='PlainMLP', color='#e74c3c',
            markersize=8, linewidth=2, markeredgecolor='white', markeredgewidth=1)
    ax.plot(layers, res_means, 's-', label='ResMLP', color='#3498db',
            markersize=8, linewidth=2, markeredgecolor='white', markeredgewidth=1)

    ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5, linewidth=1)

    ax.set_xlabel('Layer Depth', fontsize=12)
    ax.set_ylabel('Activation Mean', fontsize=12)
    ax.set_title('Activation Mean vs Layer Depth (After Training)', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[Plot] Saved activation mean plot to {save_path}")


def plot_activation_stds(plain_stds: List[float], res_stds: List[float], save_path: str):
    """Plot activation std vs. layer depth."""
    fig, ax = plt.subplots(figsize=(10, 6))
    layers = range(1, len(plain_stds) + 1)

    ax.plot(layers, plain_stds, 'o-', label='PlainMLP', color='#e74c3c',
            markersize=8, linewidth=2, markeredgecolor='white', markeredgewidth=1)
    ax.plot(layers, res_stds, 's-', label='ResMLP', color='#3498db',
            markersize=8, linewidth=2, markeredgecolor='white', markeredgewidth=1)

    ax.set_xlabel('Layer Depth', fontsize=12)
    ax.set_ylabel('Activation Standard Deviation', fontsize=12)
    ax.set_title('Activation Std vs Layer Depth (After Training)', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    ax.annotate('ResMLP maintains\nstable activations',
                xy=(15, res_stds[14]), xytext=(10, res_stds[14] * 1.3),
                fontsize=10, color='#3498db',
                arrowprops=dict(arrowstyle='->', color='#3498db', alpha=0.7))

    ax.annotate('PlainMLP activations\ndegrade through layers',
                xy=(18, plain_stds[17]), xytext=(12, plain_stds[17] * 0.5),
                fontsize=10, color='#e74c3c',
                arrowprops=dict(arrowstyle='->', color='#e74c3c', alpha=0.7))

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[Plot] Saved activation std plot to {save_path}")


def main():
    print("=" * 60)
    print("PlainMLP vs ResMLP: Distant Identity Task Experiment")
    print("=" * 60)

    os.makedirs('plots', exist_ok=True)

    print("\n[1] Generating synthetic identity data...")
    X, Y = generate_identity_data(NUM_SAMPLES, HIDDEN_DIM)
    print(f"    Data shape: X={X.shape}, Y={Y.shape}")
    print(f"    X range: [{X.min():.3f}, {X.max():.3f}]")
    print("    Task: Learn Y = X (identity mapping)")

    print("\n[2] Initializing models...")
    plain_mlp = PlainMLP(HIDDEN_DIM, NUM_LAYERS)
    res_mlp = ResMLP(HIDDEN_DIM, NUM_LAYERS)

    plain_params = sum(p.numel() for p in plain_mlp.parameters())
    res_params = sum(p.numel() for p in res_mlp.parameters())
    print(f"    PlainMLP parameters: {plain_params:,}")
    print(f"    ResMLP parameters: {res_params:,}")

    print("\n[3] Training PlainMLP...")
    plain_losses = train_model(plain_mlp, X, Y, TRAINING_STEPS, LEARNING_RATE, BATCH_SIZE)
    print(f"    Final loss: {plain_losses[-1]:.6f}")

    print("\n[4] Training ResMLP...")
    res_losses = train_model(res_mlp, X, Y, TRAINING_STEPS, LEARNING_RATE, BATCH_SIZE)
    print(f"    Final loss: {res_losses[-1]:.6f}")

    improvement = plain_losses[-1] / res_losses[-1]
    print(f"\n  >>> ResMLP achieves {improvement:.1f}x lower loss than PlainMLP <<<")

    print("\n[5] Analyzing final state of trained models...")
    print("    Running forward/backward pass on a new random batch...")
    print("    Analyzing PlainMLP...")
    plain_stats = analyze_final_state(plain_mlp, HIDDEN_DIM)
    print("    Analyzing ResMLP...")
    res_stats = analyze_final_state(res_mlp, HIDDEN_DIM)

    print("\n[6] Detailed Analysis:")
    print("\n  === Loss Comparison ===")
    print(f"  PlainMLP - Initial: {plain_losses[0]:.4f}, Final: {plain_losses[-1]:.4f}")
    print(f"  ResMLP   - Initial: {res_losses[0]:.4f}, Final: {res_losses[-1]:.4f}")

    print("\n  === Gradient Flow (L2 norms) ===")
    print(f"  PlainMLP - Layer 1: {plain_stats['gradient_norms'][0]:.2e}, Layer {NUM_LAYERS}: {plain_stats['gradient_norms'][-1]:.2e}")
    print(f"  ResMLP   - Layer 1: {res_stats['gradient_norms'][0]:.2e}, Layer {NUM_LAYERS}: {res_stats['gradient_norms'][-1]:.2e}")

    print("\n  === Activation Statistics ===")
    print(f"  PlainMLP - Std range: [{min(plain_stats['activation_stds']):.4f}, {max(plain_stats['activation_stds']):.4f}]")
    print(f"  ResMLP   - Std range: [{min(res_stats['activation_stds']):.4f}, {max(res_stats['activation_stds']):.4f}]")

    print("\n[7] Generating plots...")
    plot_training_loss(plain_losses, res_losses, 'plots/training_loss.png')
    plot_gradient_magnitudes(plain_stats['gradient_norms'], res_stats['gradient_norms'],
                             'plots/gradient_magnitude.png')
    plot_activation_means(plain_stats['activation_means'], res_stats['activation_means'],
                          'plots/activation_mean.png')
    plot_activation_stds(plain_stats['activation_stds'], res_stats['activation_stds'],
                         'plots/activation_std.png')

    # Dump the configuration, loss histories, and summary statistics to disk
    results = {
        'config': {
            'num_layers': NUM_LAYERS,
            'hidden_dim': HIDDEN_DIM,
            'num_samples': NUM_SAMPLES,
            'training_steps': TRAINING_STEPS,
            'learning_rate': LEARNING_RATE,
            'batch_size': BATCH_SIZE
        },
        'plain_mlp': {
            'final_loss': plain_losses[-1],
            'initial_loss': plain_losses[0],
            'loss_history': plain_losses,
            'gradient_norms': plain_stats['gradient_norms'],
            'activation_means': plain_stats['activation_means'],
            'activation_stds': plain_stats['activation_stds']
        },
        'res_mlp': {
            'final_loss': res_losses[-1],
            'initial_loss': res_losses[0],
            'loss_history': res_losses,
            'gradient_norms': res_stats['gradient_norms'],
            'activation_means': res_stats['activation_means'],
            'activation_stds': res_stats['activation_stds']
        },
        'summary': {
            'loss_improvement': improvement,
            'plain_grad_range': [min(plain_stats['gradient_norms']), max(plain_stats['gradient_norms'])],
            'res_grad_range': [min(res_stats['gradient_norms']), max(res_stats['gradient_norms'])],
            'plain_std_range': [min(plain_stats['activation_stds']), max(plain_stats['activation_stds'])],
            'res_std_range': [min(res_stats['activation_stds']), max(res_stats['activation_stds'])]
        }
    }

    with open('results.json', 'w') as f:
        json.dump(results, f, indent=2)
    print("\n[8] Results saved to results.json")

    print("\n" + "=" * 60)
    print("Experiment completed successfully!")
    print("=" * 60)

    return results


if __name__ == "__main__":
    results = main()