AmberLJC committed on
Commit
85ee51e
·
verified ·
1 Parent(s): 538e428

Upload experiment.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. experiment.py +362 -0
experiment.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PlainMLP vs ResMLP Comparison on Distant Identity Task
3
+
4
+ This experiment demonstrates the vanishing gradient problem in deep networks
5
+ and how residual connections solve it.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import numpy as np
11
+ import matplotlib.pyplot as plt
12
+ from typing import Dict, List, Tuple
13
+ import json
14
+
15
+ # Set random seeds for reproducibility
16
+ torch.manual_seed(42)
17
+ np.random.seed(42)
18
+
19
+ # Configuration
20
+ NUM_LAYERS = 20
21
+ HIDDEN_DIM = 64
22
+ NUM_SAMPLES = 1024
23
+ TRAINING_STEPS = 500
24
+ LEARNING_RATE = 1e-3
25
+ BATCH_SIZE = 64
26
+
27
+ print(f"[Config] Layers: {NUM_LAYERS}, Hidden Dim: {HIDDEN_DIM}")
28
+ print(f"[Config] Samples: {NUM_SAMPLES}, Steps: {TRAINING_STEPS}, LR: {LEARNING_RATE}")
29
+
30
+
31
class PlainMLP(nn.Module):
    """Plain MLP: every layer applies x = ReLU(Linear(x)) with no skip path."""

    def __init__(self, dim: int, num_layers: int):
        super().__init__()

        def make_layer() -> nn.Linear:
            # He (Kaiming) initialization, matched to the ReLU nonlinearity.
            fc = nn.Linear(dim, dim)
            nn.init.kaiming_normal_(fc.weight, mode='fan_in', nonlinearity='relu')
            nn.init.zeros_(fc.bias)
            return fc

        self.layers = nn.ModuleList(make_layer() for _ in range(num_layers))
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Push x through every Linear+ReLU layer in sequence."""
        out = x
        for fc in self.layers:
            out = self.activation(fc(out))
        return out
49
+
50
+
51
class ResMLP(nn.Module):
    """Residual MLP: every layer computes x = x + ReLU(Linear(x))."""

    def __init__(self, dim: int, num_layers: int):
        super().__init__()

        def make_layer() -> nn.Linear:
            # He (Kaiming) initialization, matched to the ReLU nonlinearity.
            fc = nn.Linear(dim, dim)
            nn.init.kaiming_normal_(fc.weight, mode='fan_in', nonlinearity='relu')
            nn.init.zeros_(fc.bias)
            return fc

        self.layers = nn.ModuleList(make_layer() for _ in range(num_layers))
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply each layer with a skip connection around the nonlinearity."""
        out = x
        for fc in self.layers:
            out = out + self.activation(fc(out))  # residual connection
        return out
69
+
70
+
71
def generate_identity_data(num_samples: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build an identity-task dataset where the target is a copy of the input.

    Inputs are drawn i.i.d. from U(-1, 1); targets are an independent copy
    so the two tensors do not share storage.
    """
    inputs = torch.empty(num_samples, dim).uniform_(-1, 1)
    targets = inputs.clone()
    return inputs, targets
76
+
77
+
78
def train_model(model: nn.Module, X: torch.Tensor, Y: torch.Tensor,
                steps: int, lr: float, batch_size: int) -> List[float]:
    """Run `steps` Adam updates on random minibatches of (X, Y).

    Returns the MSE loss recorded at every step. Progress is printed
    every 100 steps.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    mse = nn.MSELoss()
    history: List[float] = []
    n = X.shape[0]

    for step in range(steps):
        # Sample a minibatch with replacement.
        idx = torch.randint(0, n, (batch_size,))

        opt.zero_grad()
        loss = mse(model(X[idx]), Y[idx])
        loss.backward()
        opt.step()

        history.append(loss.item())
        if step % 100 == 0:
            print(f" Step {step}/{steps}, Loss: {loss.item():.6f}")

    return history
108
+
109
+
110
class ActivationGradientHook:
    """Record per-layer activations (forward) and gradients (backward).

    Expects the hooked model to expose an iterable `model.layers` of
    submodules.
    """

    def __init__(self):
        self.activations: List[torch.Tensor] = []
        self.gradients: List[torch.Tensor] = []
        self.handles = []

    def register_hooks(self, model: nn.Module):
        """Attach a forward hook and a full-backward hook to each layer."""
        for layer in model.layers:
            self.handles.append(layer.register_forward_hook(self._forward_hook))
            self.handles.append(layer.register_full_backward_hook(self._backward_hook))

    def _forward_hook(self, module, input, output):
        # Snapshot the layer output without keeping it in the autograd graph.
        self.activations.append(output.detach().clone())

    def _backward_hook(self, module, grad_input, grad_output):
        # grad_output[0] is the gradient w.r.t. the layer's output.
        self.gradients.append(grad_output[0].detach().clone())

    def clear(self):
        """Drop recorded tensors; hooks stay registered."""
        self.activations = []
        self.gradients = []

    def remove_hooks(self):
        """Detach every registered hook from the model."""
        while self.handles:
            self.handles.pop().remove()

    def get_activation_stats(self) -> Tuple[List[float], List[float]]:
        """Return (means, stds) of the captured activations, one entry per layer."""
        means = [a.mean().item() for a in self.activations]
        stds = [a.std().item() for a in self.activations]
        return means, stds

    def get_gradient_norms(self) -> List[float]:
        """Return per-layer gradient L2 norms ordered from first to last layer."""
        # Backward hooks fire from output to input, so reverse into depth order.
        return [g.norm(2).item() for g in reversed(self.gradients)]
154
+
155
+
156
def analyze_final_state(model: nn.Module, dim: int, batch_size: int = 64) -> Dict:
    """Run one forward/backward pass on fresh data and summarize layer stats.

    Returns a dict with per-layer activation means/stds, per-layer gradient
    L2 norms, and the MSE loss on the probe batch.
    """
    hook = ActivationGradientHook()
    hook.register_hooks(model)

    # Fresh probe batch from the same U(-1, 1) identity-task distribution.
    probe = torch.empty(batch_size, dim).uniform_(-1, 1)
    target = probe.clone()

    model.zero_grad()
    loss = nn.MSELoss()(model(probe), target)
    loss.backward()

    act_means, act_stds = hook.get_activation_stats()
    grad_norms = hook.get_gradient_norms()

    hook.remove_hooks()

    return {
        'activation_means': act_means,
        'activation_stds': act_stds,
        'gradient_norms': grad_norms,
        'final_loss': loss.item()
    }
185
+
186
+
187
def plot_training_loss(plain_losses: List[float], res_losses: List[float], save_path: str):
    """Plot both loss curves on a shared log-scale axis and save as a PNG."""
    fig, ax = plt.subplots(figsize=(10, 6))
    xs = range(len(plain_losses))

    ax.plot(xs, plain_losses, label='PlainMLP', color='red', alpha=0.8)
    ax.plot(xs, res_losses, label='ResMLP', color='blue', alpha=0.8)

    ax.set_xlabel('Training Steps', fontsize=12)
    ax.set_ylabel('MSE Loss', fontsize=12)
    ax.set_title('Training Loss: PlainMLP vs ResMLP on Identity Task', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    # Log scale makes the gap between the two curves easier to see.
    ax.set_yscale('log')

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"[Plot] Saved training loss plot to {save_path}")
206
+
207
+
208
def plot_gradient_magnitudes(plain_grads: List[float], res_grads: List[float], save_path: str):
    """Plot per-layer gradient L2 norms on a log-scale axis and save as a PNG."""
    fig, ax = plt.subplots(figsize=(10, 6))
    depth = range(1, len(plain_grads) + 1)

    ax.plot(depth, plain_grads, 'o-', label='PlainMLP', color='red', markersize=6)
    ax.plot(depth, res_grads, 's-', label='ResMLP', color='blue', markersize=6)

    ax.set_xlabel('Layer Depth', fontsize=12)
    ax.set_ylabel('Gradient L2 Norm', fontsize=12)
    ax.set_title('Gradient Magnitude vs Layer Depth (After Training)', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    # Log scale: vanishing gradients span orders of magnitude.
    ax.set_yscale('log')

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"[Plot] Saved gradient magnitude plot to {save_path}")
227
+
228
+
229
def plot_activation_means(plain_means: List[float], res_means: List[float], save_path: str):
    """Plot per-layer activation means for both models and save as a PNG."""
    fig, ax = plt.subplots(figsize=(10, 6))
    depth = range(1, len(plain_means) + 1)

    ax.plot(depth, plain_means, 'o-', label='PlainMLP', color='red', markersize=6)
    ax.plot(depth, res_means, 's-', label='ResMLP', color='blue', markersize=6)

    ax.set_xlabel('Layer Depth', fontsize=12)
    ax.set_ylabel('Activation Mean', fontsize=12)
    ax.set_title('Activation Mean vs Layer Depth (After Training)', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"[Plot] Saved activation mean plot to {save_path}")
247
+
248
+
249
def plot_activation_stds(plain_stds: List[float], res_stds: List[float], save_path: str):
    """Plot per-layer activation standard deviations and save as a PNG."""
    fig, ax = plt.subplots(figsize=(10, 6))
    depth = range(1, len(plain_stds) + 1)

    ax.plot(depth, plain_stds, 'o-', label='PlainMLP', color='red', markersize=6)
    ax.plot(depth, res_stds, 's-', label='ResMLP', color='blue', markersize=6)

    ax.set_xlabel('Layer Depth', fontsize=12)
    ax.set_ylabel('Activation Std', fontsize=12)
    ax.set_title('Activation Standard Deviation vs Layer Depth (After Training)', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"[Plot] Saved activation std plot to {save_path}")
267
+
268
+
269
def main():
    """Run the full PlainMLP-vs-ResMLP experiment end to end.

    Generates the identity dataset, trains both models, analyzes their
    post-training activation/gradient statistics, writes four plots under
    plots/, and dumps all numbers to results.json.

    Returns:
        The results dict that is also serialized to results.json.
    """
    # Local import: only main() touches the filesystem layout.
    import os

    print("=" * 60)
    print("PlainMLP vs ResMLP: Distant Identity Task Experiment")
    print("=" * 60)

    # Generate synthetic data
    print("\n[1] Generating synthetic identity data...")
    X, Y = generate_identity_data(NUM_SAMPLES, HIDDEN_DIM)
    print(f" Data shape: X={X.shape}, Y={Y.shape}")
    print(f" X range: [{X.min():.3f}, {X.max():.3f}]")

    # Initialize models
    print("\n[2] Initializing models...")
    plain_mlp = PlainMLP(HIDDEN_DIM, NUM_LAYERS)
    res_mlp = ResMLP(HIDDEN_DIM, NUM_LAYERS)

    plain_params = sum(p.numel() for p in plain_mlp.parameters())
    res_params = sum(p.numel() for p in res_mlp.parameters())
    print(f" PlainMLP parameters: {plain_params:,}")
    print(f" ResMLP parameters: {res_params:,}")

    # Train PlainMLP
    print("\n[3] Training PlainMLP...")
    plain_losses = train_model(plain_mlp, X, Y, TRAINING_STEPS, LEARNING_RATE, BATCH_SIZE)
    print(f" Final loss: {plain_losses[-1]:.6f}")

    # Train ResMLP
    print("\n[4] Training ResMLP...")
    res_losses = train_model(res_mlp, X, Y, TRAINING_STEPS, LEARNING_RATE, BATCH_SIZE)
    print(f" Final loss: {res_losses[-1]:.6f}")

    # Final state analysis
    print("\n[5] Analyzing final state of trained models...")
    print(" Analyzing PlainMLP...")
    plain_stats = analyze_final_state(plain_mlp, HIDDEN_DIM)
    print(" Analyzing ResMLP...")
    res_stats = analyze_final_state(res_mlp, HIDDEN_DIM)

    # Print analysis summary
    print("\n[6] Analysis Summary:")
    print(f" PlainMLP - Final Loss: {plain_stats['final_loss']:.6f}")
    print(f" ResMLP - Final Loss: {res_stats['final_loss']:.6f}")
    print(f" PlainMLP - Gradient norm range: [{min(plain_stats['gradient_norms']):.2e}, {max(plain_stats['gradient_norms']):.2e}]")
    print(f" ResMLP - Gradient norm range: [{min(res_stats['gradient_norms']):.2e}, {max(res_stats['gradient_norms']):.2e}]")

    # Generate plots
    print("\n[7] Generating plots...")
    # BUG FIX: plt.savefig raises FileNotFoundError when the target
    # directory is missing, so create plots/ before the first save.
    os.makedirs('plots', exist_ok=True)
    plot_training_loss(plain_losses, res_losses, 'plots/training_loss.png')
    plot_gradient_magnitudes(plain_stats['gradient_norms'], res_stats['gradient_norms'],
                             'plots/gradient_magnitude.png')
    plot_activation_means(plain_stats['activation_means'], res_stats['activation_means'],
                          'plots/activation_mean.png')
    plot_activation_stds(plain_stats['activation_stds'], res_stats['activation_stds'],
                         'plots/activation_std.png')

    # Save results to JSON for report
    results = {
        'config': {
            'num_layers': NUM_LAYERS,
            'hidden_dim': HIDDEN_DIM,
            'num_samples': NUM_SAMPLES,
            'training_steps': TRAINING_STEPS,
            'learning_rate': LEARNING_RATE,
            'batch_size': BATCH_SIZE
        },
        'plain_mlp': {
            'final_loss': plain_losses[-1],
            'initial_loss': plain_losses[0],
            'gradient_norms': plain_stats['gradient_norms'],
            'activation_means': plain_stats['activation_means'],
            'activation_stds': plain_stats['activation_stds']
        },
        'res_mlp': {
            'final_loss': res_losses[-1],
            'initial_loss': res_losses[0],
            'gradient_norms': res_stats['gradient_norms'],
            'activation_means': res_stats['activation_means'],
            'activation_stds': res_stats['activation_stds']
        }
    }

    with open('results.json', 'w') as f:
        json.dump(results, f, indent=2)
    print("\n[8] Results saved to results.json")

    print("\n" + "=" * 60)
    print("Experiment completed successfully!")
    print("=" * 60)

    return results


if __name__ == "__main__":
    results = main()