AmberLJC commited on
Commit
7dbd1cf
·
verified ·
1 Parent(s): 214f5ea

Upload experiment_v2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. experiment_v2.py +423 -0
experiment_v2.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
PlainMLP vs ResMLP Comparison on Distant Identity Task (V2)

This experiment demonstrates the vanishing gradient problem in deep networks
and how residual connections solve it.

Key insight: The identity task Y=X is trivially solvable by a residual network
if it can learn to zero out the residual branch, but a plain network must
learn a complex composition of transformations.

V2 Changes:
- Use proper residual scaling (1/sqrt(num_layers)) to prevent explosion
- Better initialization for residual blocks
"""
15
+
16
import json
import os
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
22
+
23
# Seed both RNGs up front so data generation, weight initialization and
# batch sampling repeat from run to run.
torch.manual_seed(42)
np.random.seed(42)

# Experiment configuration
NUM_LAYERS = 20        # depth of both networks
HIDDEN_DIM = 64        # width of every layer; also the input/output size
NUM_SAMPLES = 1024     # number of rows in the synthetic dataset
TRAINING_STEPS = 500   # optimizer steps taken per model
LEARNING_RATE = 1e-3   # optimizer learning rate
BATCH_SIZE = 64        # rows sampled per training step

print(f"[Config] Layers: {NUM_LAYERS}, Hidden Dim: {HIDDEN_DIM}")
print(f"[Config] Samples: {NUM_SAMPLES}, Steps: {TRAINING_STEPS}, LR: {LEARNING_RATE}")
37
+
38
+
39
class PlainMLP(nn.Module):
    """Plain MLP: x = ReLU(Linear(x)) for each layer.

    This architecture suffers from vanishing gradients in deep networks because:
    1. Each ReLU zeros out negative values, losing information
    2. Gradients must flow through all layers multiplicatively
    3. The network must learn a complex function composition to approximate identity

    Args:
        dim: Input, hidden and output width (every layer is ``dim -> dim``).
        num_layers: Number of Linear+ReLU layers stacked in sequence.
    """

    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            layer = nn.Linear(dim, dim)
            # Kaiming He initialization, zero bias
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            nn.init.zeros_(layer.bias)
            self.layers.append(layer)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every layer in order; the output passes through a final ReLU."""
        for layer in self.layers:
            x = self.activation(layer(x))
        return x
63
+
64
+
65
class ResMLP(nn.Module):
    """Residual MLP: x = x + scale * ReLU(Linear(x)) for each layer.

    Key advantages for identity learning:
    1. Identity shortcut allows gradients to flow directly to early layers
    2. Network only needs to learn the residual (deviation from identity)
    3. For identity task, optimal solution is to zero the residual branch

    Uses scaling factor 1/sqrt(num_layers) to prevent activation explosion.

    Args:
        dim: Input, hidden and output width (every layer is ``dim -> dim``).
        num_layers: Number of residual blocks stacked in sequence.
    """

    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList()
        # Stored as a plain Python float (not np.float64) so the multiply in
        # forward() is an ordinary tensor-by-scalar op.
        self.scale = float(1.0 / np.sqrt(num_layers))  # Scaling to prevent explosion

        for _ in range(num_layers):
            layer = nn.Linear(dim, dim)
            # Kaiming He initialization, zero bias
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            nn.init.zeros_(layer.bias)
            self.layers.append(layer)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add each block's scaled ReLU(Linear(x)) branch onto the running x."""
        for layer in self.layers:
            x = x + self.scale * self.activation(layer(x))  # Scaled residual
        return x
93
+
94
+
95
def generate_identity_data(num_samples: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate synthetic data where Y = X, with X ~ U(-1, 1).

    This is the "Distant Identity" task - the network must learn to output
    exactly what it received as input, which is trivial for a single layer
    but challenging for deep networks without skip connections.

    Args:
        num_samples: Number of rows to generate.
        dim: Number of features per row.

    Returns:
        ``(X, Y)``, both of shape ``(num_samples, dim)``. ``Y`` is a separate
        copy of ``X``, not a view of it.
    """
    X = torch.empty(num_samples, dim).uniform_(-1, 1)
    Y = X.clone()  # Identity task: target equals input
    return X, Y
105
+
106
+
107
def train_model(model: nn.Module, X: torch.Tensor, Y: torch.Tensor,
                steps: int, lr: float, batch_size: int) -> List[float]:
    """Train ``model`` with Adam on an MSE objective and record the loss per step.

    Each step draws ``batch_size`` row indices uniformly with replacement,
    so a row may appear more than once in a batch.

    Args:
        model: Network to optimize in place.
        X: Inputs; rows are samples.
        Y: Targets; must have the same number of rows as ``X``.
        steps: Number of optimizer steps to take.
        lr: Adam learning rate.
        batch_size: Rows sampled per step.

    Returns:
        The mini-batch loss at every step (``steps`` floats).

    Raises:
        ValueError: If ``X`` and ``Y`` have a different number of rows.
    """
    num_samples = X.shape[0]
    if Y.shape[0] != num_samples:
        # Fail early; otherwise indexing Y below either raises a late
        # IndexError or silently ignores the extra target rows.
        raise ValueError(
            f"X and Y must have the same number of rows, got {num_samples} and {Y.shape[0]}"
        )

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    losses = []

    for step in range(steps):
        # Random batch sampling
        indices = torch.randint(0, num_samples, (batch_size,))
        batch_x = X[indices]
        batch_y = Y[indices]

        # Forward pass
        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y)

        # Backward pass
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        losses.append(loss_value)

        if step % 100 == 0:
            print(f" Step {step}/{steps}, Loss: {loss_value:.6f}")

    return losses
137
+
138
+
139
class ActivationGradientHook:
    """Hook to capture activations and gradients at each layer.

    The hooks are attached to the entries of ``model.layers``, so the captured
    "activations" are each layer's own output, and the captured gradients are
    the gradients with respect to that output.
    """

    def __init__(self):
        # Layer outputs, appended in forward-call order.
        self.activations: List[torch.Tensor] = []
        # Output gradients, appended in the order the backward hooks fire.
        self.gradients: List[torch.Tensor] = []
        # Handles returned by the register_* calls; needed to detach later.
        self.handles = []

    def register_hooks(self, model: nn.Module):
        """Register forward and backward hooks on each layer in ``model.layers``."""
        for layer in model.layers:
            # Forward hook to capture activations (output of linear layer)
            handle_fwd = layer.register_forward_hook(self._forward_hook)
            # Backward hook to capture gradients
            handle_bwd = layer.register_full_backward_hook(self._backward_hook)
            self.handles.extend([handle_fwd, handle_bwd])

    def _forward_hook(self, module, input, output):
        # Detached copy, so stored stats are unaffected by later in-place ops.
        self.activations.append(output.detach().clone())

    def _backward_hook(self, module, grad_input, grad_output):
        # grad_output[0] is the gradient w.r.t. the layer's output
        self.gradients.append(grad_output[0].detach().clone())

    def clear(self):
        """Drop all captured tensors; registered hooks stay attached."""
        self.activations = []
        self.gradients = []

    def remove_hooks(self):
        """Detach every registered hook; captured tensors are kept."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def get_activation_stats(self) -> Tuple[List[float], List[float]]:
        """Get mean and std of activations for each layer, in capture order."""
        means = [act.mean().item() for act in self.activations]
        stds = [act.std().item() for act in self.activations]
        return means, stds

    def get_gradient_norms(self) -> List[float]:
        """Get L2 norm of gradients for each layer, first layer first."""
        # Gradients are captured in reverse order (from output to input)
        norms = [grad.norm(2).item() for grad in reversed(self.gradients)]
        return norms
183
+
184
+
185
def analyze_final_state(model: nn.Module, dim: int, batch_size: int = 64) -> Dict:
    """Perform one forward/backward pass and capture activation/gradient stats.

    A fresh batch with X ~ U(-1, 1) and target Y = X is pushed through
    ``model``. Parameter gradients are zeroed before the pass and left
    populated afterwards.

    Args:
        model: Network exposing its layers as ``model.layers``.
        dim: Feature dimension of the probe batch.
        batch_size: Number of rows in the probe batch.

    Returns:
        Dict with ``activation_means``, ``activation_stds`` and
        ``gradient_norms`` (one entry per layer, first layer first) and
        ``final_loss`` (MSE on the probe batch).
    """
    hook = ActivationGradientHook()
    hook.register_hooks(model)

    try:
        # Generate new random batch
        X_test = torch.empty(batch_size, dim).uniform_(-1, 1)
        Y_test = X_test.clone()

        # Forward pass
        model.zero_grad()
        output = model(X_test)
        loss = nn.MSELoss()(output, Y_test)

        # Backward pass
        loss.backward()

        # Get statistics
        act_means, act_stds = hook.get_activation_stats()
        grad_norms = hook.get_gradient_norms()
    finally:
        # Always detach, even if the pass raised, so the model is not left
        # with stale hooks that keep capturing tensors on later calls.
        hook.remove_hooks()

    return {
        'activation_means': act_means,
        'activation_stds': act_stds,
        'gradient_norms': grad_norms,
        'final_loss': loss.item()
    }
214
+
215
+
216
def plot_training_loss(plain_losses: List[float], res_losses: List[float], save_path: str):
    """Plot training loss curves for both models on a log-scaled y axis.

    Args:
        plain_losses: Per-step loss history of the PlainMLP.
        res_losses: Per-step loss history of the ResMLP.
        save_path: Destination image path; its directory must already exist.
    """
    plt.figure(figsize=(10, 6))

    # Each curve gets its own x range so histories of different length still
    # plot instead of raising a shape mismatch.
    plt.plot(range(len(plain_losses)), plain_losses, label='PlainMLP',
             color='#e74c3c', alpha=0.8, linewidth=2)
    plt.plot(range(len(res_losses)), res_losses, label='ResMLP',
             color='#3498db', alpha=0.8, linewidth=2)

    plt.xlabel('Training Steps', fontsize=12)
    plt.ylabel('MSE Loss', fontsize=12)
    plt.title('Training Loss: PlainMLP vs ResMLP on Identity Task', fontsize=14)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.yscale('log')  # Log scale to see differences better

    # Add annotation about final losses (skipped for an empty history, which
    # has no final value to point at).
    if plain_losses:
        final_plain = plain_losses[-1]
        plt.annotate(f'PlainMLP final: {final_plain:.4f}',
                     xy=(len(plain_losses)-1, final_plain),
                     xytext=(len(plain_losses)*0.7, final_plain*2),
                     fontsize=10, color='#e74c3c',
                     arrowprops=dict(arrowstyle='->', color='#e74c3c', alpha=0.7))
    if res_losses:
        final_res = res_losses[-1]
        plt.annotate(f'ResMLP final: {final_res:.6f}',
                     xy=(len(res_losses)-1, final_res),
                     xytext=(len(res_losses)*0.7, final_res*0.1),
                     fontsize=10, color='#3498db',
                     arrowprops=dict(arrowstyle='->', color='#3498db', alpha=0.7))

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"[Plot] Saved training loss plot to {save_path}")
249
+
250
+
251
def _plot_per_layer(plain_values: List[float], res_values: List[float],
                    ylabel: str, title: str, save_path: str, *,
                    log_y: bool = False, shade_gap: bool = False,
                    zero_baseline: bool = False) -> None:
    """Draw one PlainMLP-vs-ResMLP per-layer line chart and save it.

    Layers are numbered from 1 on the x axis. The optional flags switch on
    the extras that differ between the three public plotting functions.
    """
    plt.figure(figsize=(10, 6))
    plain_layers = range(1, len(plain_values) + 1)
    res_layers = range(1, len(res_values) + 1)

    plt.plot(plain_layers, plain_values, 'o-', label='PlainMLP', color='#e74c3c',
             markersize=8, linewidth=2, markeredgecolor='white', markeredgewidth=1)
    plt.plot(res_layers, res_values, 's-', label='ResMLP', color='#3498db',
             markersize=8, linewidth=2, markeredgecolor='white', markeredgewidth=1)

    if zero_baseline:
        # Drawn before legend() so its label is included in the legend.
        plt.axhline(y=0, color='gray', linestyle='--', alpha=0.5, label='Zero baseline')

    plt.xlabel('Layer Depth', fontsize=12)
    plt.ylabel(ylabel, fontsize=12)
    plt.title(title, fontsize=14)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    if log_y:
        plt.yscale('log')

    # Shading needs both series on the same x grid; skip it otherwise.
    if shade_gap and len(plain_values) == len(res_values):
        plt.fill_between(plain_layers, plain_values, res_values, alpha=0.2, color='gray')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()


def plot_gradient_magnitudes(plain_grads: List[float], res_grads: List[float], save_path: str):
    """Plot gradient magnitude vs layer depth (log y axis, gap shaded)."""
    _plot_per_layer(plain_grads, res_grads,
                    'Gradient L2 Norm',
                    'Gradient Magnitude vs Layer Depth (After Training)',
                    save_path, log_y=True, shade_gap=True)
    print(f"[Plot] Saved gradient magnitude plot to {save_path}")


def plot_activation_means(plain_means: List[float], res_means: List[float], save_path: str):
    """Plot activation mean vs layer depth, with a dashed zero baseline."""
    _plot_per_layer(plain_means, res_means,
                    'Activation Mean',
                    'Activation Mean vs Layer Depth (After Training)',
                    save_path, zero_baseline=True)
    print(f"[Plot] Saved activation mean plot to {save_path}")


def plot_activation_stds(plain_stds: List[float], res_stds: List[float], save_path: str):
    """Plot activation std vs layer depth."""
    _plot_per_layer(plain_stds, res_stds,
                    'Activation Std',
                    'Activation Standard Deviation vs Layer Depth (After Training)',
                    save_path)
    print(f"[Plot] Saved activation std plot to {save_path}")
321
+
322
+
323
def main() -> Dict:
    """Run the full PlainMLP-vs-ResMLP experiment.

    Generates the identity dataset, trains both models with the module-level
    configuration, probes each trained model once for per-layer statistics,
    writes four plots under ``plots/`` and dumps everything to
    ``results.json``.

    Returns:
        The same nested dict that is written to ``results.json``.
    """
    print("=" * 60)
    print("PlainMLP vs ResMLP: Distant Identity Task Experiment (V2)")
    print("=" * 60)

    # Generate synthetic data
    print("\n[1] Generating synthetic identity data...")
    X, Y = generate_identity_data(NUM_SAMPLES, HIDDEN_DIM)
    print(f" Data shape: X={X.shape}, Y={Y.shape}")
    print(f" X range: [{X.min():.3f}, {X.max():.3f}]")

    # Initialize models
    print("\n[2] Initializing models...")
    plain_mlp = PlainMLP(HIDDEN_DIM, NUM_LAYERS)
    res_mlp = ResMLP(HIDDEN_DIM, NUM_LAYERS)

    plain_params = sum(p.numel() for p in plain_mlp.parameters())
    res_params = sum(p.numel() for p in res_mlp.parameters())
    print(f" PlainMLP parameters: {plain_params:,}")
    print(f" ResMLP parameters: {res_params:,}")
    print(f" ResMLP residual scale: {res_mlp.scale:.4f}")

    # Train PlainMLP
    print("\n[3] Training PlainMLP...")
    plain_losses = train_model(plain_mlp, X, Y, TRAINING_STEPS, LEARNING_RATE, BATCH_SIZE)
    print(f" Final loss: {plain_losses[-1]:.6f}")

    # Train ResMLP
    print("\n[4] Training ResMLP...")
    res_losses = train_model(res_mlp, X, Y, TRAINING_STEPS, LEARNING_RATE, BATCH_SIZE)
    print(f" Final loss: {res_losses[-1]:.6f}")

    # Final state analysis
    print("\n[5] Analyzing final state of trained models...")
    print(" Analyzing PlainMLP...")
    plain_stats = analyze_final_state(plain_mlp, HIDDEN_DIM)
    print(" Analyzing ResMLP...")
    res_stats = analyze_final_state(res_mlp, HIDDEN_DIM)

    # Print analysis summary
    plain_final = plain_stats['final_loss']
    res_final = res_stats['final_loss']
    # A residual loss of exactly 0.0 would otherwise raise ZeroDivisionError.
    improvement = plain_final / res_final if res_final > 0 else float('inf')
    plain_grads = plain_stats['gradient_norms']
    res_grads = res_stats['gradient_norms']
    plain_stds = plain_stats['activation_stds']
    res_stds = res_stats['activation_stds']

    print("\n[6] Analysis Summary:")
    print(f" PlainMLP - Final Loss: {plain_final:.6f}")
    print(f" ResMLP - Final Loss: {res_final:.6f}")
    print(f" Loss Improvement: {improvement:.1f}x")
    print(f"\n PlainMLP - Gradient norm range: [{min(plain_grads):.2e}, {max(plain_grads):.2e}]")
    print(f" ResMLP - Gradient norm range: [{min(res_grads):.2e}, {max(res_grads):.2e}]")
    print(f"\n PlainMLP - Activation std range: [{min(plain_stds):.4f}, {max(plain_stds):.4f}]")
    print(f" ResMLP - Activation std range: [{min(res_stds):.4f}, {max(res_stds):.4f}]")

    # Generate plots
    print("\n[7] Generating plots...")
    # savefig does not create missing directories; make sure the target exists.
    os.makedirs('plots', exist_ok=True)
    plot_training_loss(plain_losses, res_losses, 'plots/training_loss.png')
    plot_gradient_magnitudes(plain_grads, res_grads,
                             'plots/gradient_magnitude.png')
    plot_activation_means(plain_stats['activation_means'], res_stats['activation_means'],
                          'plots/activation_mean.png')
    plot_activation_stds(plain_stds, res_stds,
                         'plots/activation_std.png')

    # Save results to JSON for report
    results = {
        'config': {
            'num_layers': NUM_LAYERS,
            'hidden_dim': HIDDEN_DIM,
            'num_samples': NUM_SAMPLES,
            'training_steps': TRAINING_STEPS,
            'learning_rate': LEARNING_RATE,
            'batch_size': BATCH_SIZE,
            'residual_scale': float(res_mlp.scale)
        },
        'plain_mlp': {
            'final_loss': plain_losses[-1],
            'initial_loss': plain_losses[0],
            'loss_history': plain_losses,
            'gradient_norms': plain_grads,
            'activation_means': plain_stats['activation_means'],
            'activation_stds': plain_stds
        },
        'res_mlp': {
            'final_loss': res_losses[-1],
            'initial_loss': res_losses[0],
            'loss_history': res_losses,
            'gradient_norms': res_grads,
            'activation_means': res_stats['activation_means'],
            'activation_stds': res_stds
        }
    }

    with open('results.json', 'w') as f:
        json.dump(results, f, indent=2)
    print("\n[8] Results saved to results.json")

    print("\n" + "=" * 60)
    print("Experiment completed successfully!")
    print("=" * 60)

    return results
420
+
421
+
422
# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    results = main()