AmberLJC committed on
Commit
eb31f6a
·
verified ·
1 Parent(s): 85ee51e

Upload experiment_final.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. experiment_final.py +456 -0
experiment_final.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PlainMLP vs ResMLP Comparison on Distant Identity Task (Final Version)
3
+
4
+ This experiment demonstrates the vanishing gradient problem in deep networks
5
+ and how residual connections solve it.
6
+
7
+ Key Design Choices:
8
+ 1. PlainMLP: Standard x = ReLU(Linear(x)) - suffers from vanishing gradients
9
+ 2. ResMLP: x = x + ReLU(Linear(x)) with zero-initialized bias and small weight scale
10
+ - This allows the network to start as near-identity and learn deviations
11
+ - Gradients can flow through the skip connection even when residual branch is small
12
+
13
+ The "Distant Identity" task (Y=X) is particularly revealing because:
14
+ - ResMLP can trivially solve it by zeroing the residual branch (identity shortcut)
15
+ - PlainMLP must learn a complex function composition to approximate identity
16
+ - With ReLU, PlainMLP can never perfectly learn identity (negative values are zeroed)
17
+ """
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import numpy as np
22
+ import matplotlib.pyplot as plt
23
+ from typing import Dict, List, Tuple
24
+ import json
25
+ import os
26
+
27
# Set random seeds for reproducibility (torch for model init/batching, numpy for weight scaling helpers)
torch.manual_seed(42)
np.random.seed(42)

# Configuration
NUM_LAYERS = 20        # depth of both MLPs; deep enough to expose vanishing gradients
HIDDEN_DIM = 64        # width of every Linear layer; also the data dimensionality of the identity task
NUM_SAMPLES = 1024     # size of the synthetic dataset
TRAINING_STEPS = 500   # optimizer steps per model
LEARNING_RATE = 1e-3   # Adam learning rate
BATCH_SIZE = 64        # minibatch size sampled (with replacement) each step

print(f"[Config] Layers: {NUM_LAYERS}, Hidden Dim: {HIDDEN_DIM}")
print(f"[Config] Samples: {NUM_SAMPLES}, Steps: {TRAINING_STEPS}, LR: {LEARNING_RATE}")
41
+
42
+
43
class PlainMLP(nn.Module):
    """Plain MLP applying x = ReLU(Linear(x)) at every layer.

    Drawbacks demonstrated by this experiment:
    1. Vanishing gradients - gradients must pass multiplicatively through
       every layer on the way back.
    2. Information loss - each ReLU zeroes the negative half of its input.
    3. Hard optimization - identity must be learned as an exact composition
       of all layers.
    """

    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList()
        self.activation = nn.ReLU()
        for _ in range(num_layers):
            self.layers.append(self._build_layer(dim))

    @staticmethod
    def _build_layer(dim: int) -> nn.Linear:
        # Kaiming He initialization suited to ReLU, with zeroed bias.
        linear = nn.Linear(dim, dim)
        nn.init.kaiming_normal_(linear.weight, mode='fan_in', nonlinearity='relu')
        nn.init.zeros_(linear.bias)
        return linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x
        for block in self.layers:
            out = self.activation(block(out))
        return out
67
+
68
+
69
class ResMLP(nn.Module):
    """Residual MLP applying x = x + ReLU(Linear(x)) at every layer.

    Why it wins on the identity task:
    1. The skip connection carries gradients straight to early layers.
    2. Each layer only has to model a deviation from identity.
    3. The optimal solution here is simply a zeroed residual branch.

    The branch weights are scaled by 1/sqrt(num_layers) after Kaiming init
    so the stack starts close to the identity map, keeps activations from
    exploding, and lets residuals grow gradually during training.
    """

    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList()
        self.num_layers = num_layers
        for _ in range(num_layers):
            self.layers.append(self._build_branch(dim, num_layers))
        self.activation = nn.ReLU()

    @staticmethod
    def _build_branch(dim: int, num_layers: int) -> nn.Linear:
        # Kaiming init, then shrink by 1/sqrt(depth) so the residual branch
        # starts small; bias is zeroed for a near-identity start.
        linear = nn.Linear(dim, dim)
        nn.init.kaiming_normal_(linear.weight, mode='fan_in', nonlinearity='relu')
        linear.weight.data *= 1.0 / np.sqrt(num_layers)
        nn.init.zeros_(linear.bias)
        return linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x
        for branch in self.layers:
            out = out + self.activation(branch(out))  # skip connection
        return out
102
+
103
+
104
def generate_identity_data(num_samples: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate synthetic data for the identity task: Y = X with X ~ U(-1, 1)."""
    inputs = torch.empty(num_samples, dim).uniform_(-1, 1)
    return inputs, inputs.clone()
109
+
110
+
111
def train_model(model: nn.Module, X: torch.Tensor, Y: torch.Tensor,
                steps: int, lr: float, batch_size: int) -> List[float]:
    """Train model with Adam on MSE and return the per-step loss history."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    total = X.shape[0]
    history: List[float] = []

    for step in range(steps):
        # Sample a random minibatch (with replacement).
        idx = torch.randint(0, total, (batch_size,))

        optimizer.zero_grad()
        prediction = model(X[idx])
        loss = criterion(prediction, Y[idx])
        loss.backward()
        optimizer.step()

        history.append(loss.item())
        if step % 100 == 0:
            print(f" Step {step}/{steps}, Loss: {loss.item():.6f}")

    return history
141
+
142
+
143
class ActivationGradientHook:
    """Capture per-layer activations and gradients via PyTorch hooks.

    Hooks are attached to each module in ``model.layers`` (the nn.Linear
    layers in PlainMLP/ResMLP), so what gets recorded is each Linear's
    OUTPUT — i.e. pre-activation values, before ReLU (and, for ResMLP,
    before the residual add).

    The lists accumulate across passes; call ``clear()`` between
    forward/backward passes if the object is reused.
    """

    def __init__(self):
        # Filled in forward order (first layer first) by _forward_hook.
        self.activations: List[torch.Tensor] = []
        # Filled in BACKWARD order (last layer first) by _backward_hook;
        # get_gradient_norms() reverses them back into forward order.
        self.gradients: List[torch.Tensor] = []
        # Hook handles kept so they can be detached in remove_hooks().
        self.handles = []

    def register_hooks(self, model: nn.Module):
        """Register forward and backward hooks on each layer"""
        for layer in model.layers:
            handle_fwd = layer.register_forward_hook(self._forward_hook)
            handle_bwd = layer.register_full_backward_hook(self._backward_hook)
            self.handles.extend([handle_fwd, handle_bwd])

    def _forward_hook(self, module, input, output):
        # Detach + clone so stored tensors neither hold the graph alive
        # nor alias buffers that later ops might mutate.
        self.activations.append(output.detach().clone())

    def _backward_hook(self, module, grad_input, grad_output):
        # grad_output[0] is the gradient w.r.t. the module's output.
        self.gradients.append(grad_output[0].detach().clone())

    def clear(self):
        # Reset recorded tensors (handles stay registered).
        self.activations = []
        self.gradients = []

    def remove_hooks(self):
        # Detach all hooks; safe to call once after analysis.
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def get_activation_stats(self) -> Tuple[List[float], List[float]]:
        """Get mean and std of activations for each layer"""
        means = [act.mean().item() for act in self.activations]
        stds = [act.std().item() for act in self.activations]
        return means, stds

    def get_gradient_norms(self) -> List[float]:
        """Get L2 norm of gradients for each layer (in forward order)"""
        # Backward hooks fire last-layer-first, so reverse to forward order.
        norms = [grad.norm(2).item() for grad in reversed(self.gradients)]
        return norms
183
+
184
+
185
def analyze_final_state(model: nn.Module, dim: int, batch_size: int = 64) -> Dict:
    """Run one forward/backward pass on a fresh random batch and collect
    per-layer activation statistics and gradient norms via hooks.

    Returns a dict with 'activation_means', 'activation_stds',
    'gradient_norms' and 'final_loss'.
    """
    recorder = ActivationGradientHook()
    recorder.register_hooks(model)

    # Fresh random identity-task batch, X ~ U(-1, 1), target = input.
    probe_x = torch.empty(batch_size, dim).uniform_(-1, 1)
    probe_y = probe_x.clone()

    # One forward + backward pass to populate the hooks.
    model.zero_grad()
    loss = nn.MSELoss()(model(probe_x), probe_y)
    loss.backward()

    act_means, act_stds = recorder.get_activation_stats()
    grad_norms = recorder.get_gradient_norms()
    recorder.remove_hooks()

    return {
        'activation_means': act_means,
        'activation_stds': act_stds,
        'gradient_norms': grad_norms,
        'final_loss': loss.item()
    }
214
+
215
+
216
def plot_training_loss(plain_losses: List[float], res_losses: List[float], save_path: str):
    """Plot log-scale training loss curves for both models and save a PNG.

    Args:
        plain_losses: per-step MSE losses for PlainMLP.
        res_losses: per-step MSE losses for ResMLP.
        save_path: output path for the PNG file.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    steps = range(len(plain_losses))

    ax.plot(steps, plain_losses, label='PlainMLP (20 layers)', color='#e74c3c',
            alpha=0.8, linewidth=2)
    ax.plot(steps, res_losses, label='ResMLP (20 layers)', color='#3498db',
            alpha=0.8, linewidth=2)

    ax.set_xlabel('Training Steps', fontsize=12)
    ax.set_ylabel('MSE Loss', fontsize=12)
    ax.set_title('Training Loss: PlainMLP vs ResMLP on Identity Task (Y = X)', fontsize=14)
    ax.legend(fontsize=11, loc='upper right')
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')

    # Add final loss annotations
    final_plain = plain_losses[-1]
    final_res = res_losses[-1]
    # BUGFIX: guard against ZeroDivisionError if ResMLP reaches exactly 0 loss.
    ratio = final_plain / final_res if final_res > 0 else float('inf')

    # Text box with final results
    textstr = f'Final Loss:\n PlainMLP: {final_plain:.4f}\n ResMLP: {final_res:.4f}\n Improvement: {ratio:.1f}x'
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
    ax.text(0.02, 0.02, textstr, transform=ax.transAxes, fontsize=10,
            verticalalignment='bottom', bbox=props)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[Plot] Saved training loss plot to {save_path}")
247
+
248
+
249
def plot_gradient_magnitudes(plain_grads: List[float], res_grads: List[float], save_path: str):
    """Plot gradient L2 norm vs layer depth (log scale) and save a PNG.

    Args:
        plain_grads: per-layer gradient norms for PlainMLP, forward order.
        res_grads: per-layer gradient norms for ResMLP, forward order.
        save_path: output path for the PNG file.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    layers = range(1, len(plain_grads) + 1)

    ax.plot(layers, plain_grads, 'o-', label='PlainMLP', color='#e74c3c',
            markersize=8, linewidth=2, markeredgecolor='white', markeredgewidth=1)
    ax.plot(layers, res_grads, 's-', label='ResMLP', color='#3498db',
            markersize=8, linewidth=2, markeredgecolor='white', markeredgewidth=1)

    ax.set_xlabel('Layer Depth (1 = first layer, 20 = last layer)', fontsize=12)
    ax.set_ylabel('Gradient L2 Norm (log scale)', fontsize=12)
    ax.set_title('Gradient Magnitude vs Layer Depth (After 500 Training Steps)', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')

    # Highlight the gradient difference
    ax.fill_between(layers, plain_grads, res_grads, alpha=0.15, color='gray')

    # Add annotation about gradient flow.
    # BUGFIX: the original indexed res_grads[9] unconditionally, raising
    # IndexError for networks shallower than 10 layers; annotate only when valid.
    if len(res_grads) >= 10:
        ax.annotate('Gradients flow more\nuniformly in ResMLP',
                    xy=(10, res_grads[9]), xytext=(5, res_grads[9]*5),
                    fontsize=10, color='#3498db',
                    arrowprops=dict(arrowstyle='->', color='#3498db', alpha=0.7))

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[Plot] Saved gradient magnitude plot to {save_path}")
279
+
280
+
281
def plot_activation_means(plain_means: List[float], res_means: List[float], save_path: str):
    """Plot per-layer activation means for both models and save a PNG."""
    marker_style = dict(markersize=8, linewidth=2,
                        markeredgecolor='white', markeredgewidth=1)
    fig, ax = plt.subplots(figsize=(10, 6))
    depth = range(1, len(plain_means) + 1)

    ax.plot(depth, plain_means, 'o-', label='PlainMLP', color='#e74c3c', **marker_style)
    ax.plot(depth, res_means, 's-', label='ResMLP', color='#3498db', **marker_style)

    # Zero reference line for reading the sign of the means.
    ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5, linewidth=1)

    ax.set_xlabel('Layer Depth', fontsize=12)
    ax.set_ylabel('Activation Mean', fontsize=12)
    ax.set_title('Activation Mean vs Layer Depth (After Training)', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[Plot] Saved activation mean plot to {save_path}")
303
+
304
+
305
def plot_activation_stds(plain_stds: List[float], res_stds: List[float], save_path: str):
    """Plot per-layer activation standard deviations and save a PNG.

    Args:
        plain_stds: per-layer activation stds for PlainMLP, forward order.
        res_stds: per-layer activation stds for ResMLP, forward order.
        save_path: output path for the PNG file.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    layers = range(1, len(plain_stds) + 1)

    ax.plot(layers, plain_stds, 'o-', label='PlainMLP', color='#e74c3c',
            markersize=8, linewidth=2, markeredgecolor='white', markeredgewidth=1)
    ax.plot(layers, res_stds, 's-', label='ResMLP', color='#3498db',
            markersize=8, linewidth=2, markeredgecolor='white', markeredgewidth=1)

    ax.set_xlabel('Layer Depth', fontsize=12)
    ax.set_ylabel('Activation Standard Deviation', fontsize=12)
    ax.set_title('Activation Std vs Layer Depth (After Training)', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    # Annotations reference fixed layer indices (15 and 18).
    # BUGFIX: the original indexed res_stds[14] / plain_stds[17] unconditionally,
    # raising IndexError for networks shallower than 15 / 18 layers.
    if len(res_stds) >= 15:
        ax.annotate('ResMLP maintains\nstable activations',
                    xy=(15, res_stds[14]), xytext=(10, res_stds[14]*1.3),
                    fontsize=10, color='#3498db',
                    arrowprops=dict(arrowstyle='->', color='#3498db', alpha=0.7))

    if len(plain_stds) >= 18:
        ax.annotate('PlainMLP activations\ndegrade through layers',
                    xy=(18, plain_stds[17]), xytext=(12, plain_stds[17]*0.5),
                    fontsize=10, color='#e74c3c',
                    arrowprops=dict(arrowstyle='->', color='#e74c3c', alpha=0.7))

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[Plot] Saved activation std plot to {save_path}")
336
+
337
+
338
def main():
    """Run the full PlainMLP vs ResMLP comparison experiment.

    Pipeline: generate identity data -> train both models -> analyze
    gradients/activations on a fresh batch -> write four plots to plots/
    and all metrics to results.json.

    Returns:
        dict: the same results structure that is serialized to results.json.
    """
    print("=" * 60)
    print("PlainMLP vs ResMLP: Distant Identity Task Experiment")
    print("=" * 60)

    # Ensure plots directory exists
    os.makedirs('plots', exist_ok=True)

    # Generate synthetic data
    print("\n[1] Generating synthetic identity data...")
    X, Y = generate_identity_data(NUM_SAMPLES, HIDDEN_DIM)
    print(f" Data shape: X={X.shape}, Y={Y.shape}")
    print(f" X range: [{X.min():.3f}, {X.max():.3f}]")
    print(f" Task: Learn Y = X (identity mapping)")

    # Initialize models (same depth/width so parameter counts match)
    print("\n[2] Initializing models...")
    plain_mlp = PlainMLP(HIDDEN_DIM, NUM_LAYERS)
    res_mlp = ResMLP(HIDDEN_DIM, NUM_LAYERS)

    plain_params = sum(p.numel() for p in plain_mlp.parameters())
    res_params = sum(p.numel() for p in res_mlp.parameters())
    print(f" PlainMLP parameters: {plain_params:,}")
    print(f" ResMLP parameters: {res_params:,}")

    # Train PlainMLP
    print("\n[3] Training PlainMLP...")
    plain_losses = train_model(plain_mlp, X, Y, TRAINING_STEPS, LEARNING_RATE, BATCH_SIZE)
    print(f" Final loss: {plain_losses[-1]:.6f}")

    # Train ResMLP
    print("\n[4] Training ResMLP...")
    res_losses = train_model(res_mlp, X, Y, TRAINING_STEPS, LEARNING_RATE, BATCH_SIZE)
    print(f" Final loss: {res_losses[-1]:.6f}")

    # Calculate improvement
    # NOTE(review): assumes res_losses[-1] != 0; an exactly-zero final MSE would
    # raise ZeroDivisionError here (practically unreachable with float MSE).
    improvement = plain_losses[-1] / res_losses[-1]
    print(f"\n >>> ResMLP achieves {improvement:.1f}x lower loss than PlainMLP <<<")

    # Final state analysis: one fwd/bwd pass with hooks on a fresh batch
    print("\n[5] Analyzing final state of trained models...")
    print(" Running forward/backward pass on new random batch...")
    print(" Analyzing PlainMLP...")
    plain_stats = analyze_final_state(plain_mlp, HIDDEN_DIM)
    print(" Analyzing ResMLP...")
    res_stats = analyze_final_state(res_mlp, HIDDEN_DIM)

    # Print detailed analysis
    print("\n[6] Detailed Analysis:")
    print("\n === Loss Comparison ===")
    print(f" PlainMLP - Initial: {plain_losses[0]:.4f}, Final: {plain_losses[-1]:.4f}")
    print(f" ResMLP - Initial: {res_losses[0]:.4f}, Final: {res_losses[-1]:.4f}")

    print("\n === Gradient Flow (L2 norms) ===")
    print(f" PlainMLP - Layer 1: {plain_stats['gradient_norms'][0]:.2e}, Layer 20: {plain_stats['gradient_norms'][-1]:.2e}")
    print(f" ResMLP - Layer 1: {res_stats['gradient_norms'][0]:.2e}, Layer 20: {res_stats['gradient_norms'][-1]:.2e}")

    print("\n === Activation Statistics ===")
    print(f" PlainMLP - Std range: [{min(plain_stats['activation_stds']):.4f}, {max(plain_stats['activation_stds']):.4f}]")
    print(f" ResMLP - Std range: [{min(res_stats['activation_stds']):.4f}, {max(res_stats['activation_stds']):.4f}]")

    # Generate plots
    print("\n[7] Generating plots...")
    plot_training_loss(plain_losses, res_losses, 'plots/training_loss.png')
    plot_gradient_magnitudes(plain_stats['gradient_norms'], res_stats['gradient_norms'],
                             'plots/gradient_magnitude.png')
    plot_activation_means(plain_stats['activation_means'], res_stats['activation_means'],
                          'plots/activation_mean.png')
    plot_activation_stds(plain_stats['activation_stds'], res_stats['activation_stds'],
                         'plots/activation_std.png')

    # Save results to JSON (config + full per-model histories + summary)
    results = {
        'config': {
            'num_layers': NUM_LAYERS,
            'hidden_dim': HIDDEN_DIM,
            'num_samples': NUM_SAMPLES,
            'training_steps': TRAINING_STEPS,
            'learning_rate': LEARNING_RATE,
            'batch_size': BATCH_SIZE
        },
        'plain_mlp': {
            'final_loss': plain_losses[-1],
            'initial_loss': plain_losses[0],
            'loss_history': plain_losses,
            'gradient_norms': plain_stats['gradient_norms'],
            'activation_means': plain_stats['activation_means'],
            'activation_stds': plain_stats['activation_stds']
        },
        'res_mlp': {
            'final_loss': res_losses[-1],
            'initial_loss': res_losses[0],
            'loss_history': res_losses,
            'gradient_norms': res_stats['gradient_norms'],
            'activation_means': res_stats['activation_means'],
            'activation_stds': res_stats['activation_stds']
        },
        'summary': {
            'loss_improvement': improvement,
            'plain_grad_range': [min(plain_stats['gradient_norms']), max(plain_stats['gradient_norms'])],
            'res_grad_range': [min(res_stats['gradient_norms']), max(res_stats['gradient_norms'])],
            'plain_std_range': [min(plain_stats['activation_stds']), max(plain_stats['activation_stds'])],
            'res_std_range': [min(res_stats['activation_stds']), max(res_stats['activation_stds'])]
        }
    }

    with open('results.json', 'w') as f:
        json.dump(results, f, indent=2)
    print("\n[8] Results saved to results.json")

    print("\n" + "=" * 60)
    print("Experiment completed successfully!")
    print("=" * 60)

    return results
453
+
454
+
455
# Script entry point: run the experiment; the results dict is kept at module
# level, which is convenient for interactive use (e.g. `python -i`).
if __name__ == "__main__":
    results = main()