AmberLJC commited on
Commit
86f312f
·
verified ·
1 Parent(s): 11c4193

Upload experiment.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. experiment.py +470 -0
experiment.py ADDED
@@ -0,0 +1,470 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradient Clipping Experiment
3
+
4
+ This script demonstrates how gradient clipping stabilizes training by preventing
5
+ sudden large weight updates caused by rare, high-loss data points.
6
+
7
+ Experiment Setup:
8
+ - Simple model: Embedding(4, 16) -> Linear(16, 4)
9
+ - Vocabulary: ['A', 'B', 'C', 'D']
10
+ - Dataset: 1000 samples with imbalanced targets (990 'A', 10 'B')
11
+ - Compare training with and without gradient clipping
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.optim as optim
17
+ import numpy as np
18
+ import matplotlib.pyplot as plt
19
+ import random
20
+
21
+ # Set seeds for reproducibility
22
+ SEED = 42
23
+
24
+
25
+ def set_seeds(seed=SEED):
26
+ """Set all random seeds for reproducibility."""
27
+ torch.manual_seed(seed)
28
+ np.random.seed(seed)
29
+ random.seed(seed)
30
+
31
+
32
+ # =============================================================================
33
+ # 1. MODEL DEFINITION
34
+ # =============================================================================
35
+
36
+ class SimpleNextTokenModel(nn.Module):
37
+ """
38
+ Simple model that takes a token index and predicts the next token.
39
+ Architecture: Embedding -> Linear
40
+ """
41
+ def __init__(self, vocab_size=4, embedding_dim=16):
42
+ super().__init__()
43
+ self.embedding = nn.Embedding(vocab_size, embedding_dim)
44
+ self.linear = nn.Linear(embedding_dim, vocab_size)
45
+
46
+ def forward(self, x):
47
+ """
48
+ Args:
49
+ x: Token indices of shape (batch_size,)
50
+ Returns:
51
+ Logits of shape (batch_size, vocab_size)
52
+ """
53
+ embedded = self.embedding(x) # (batch_size, embedding_dim)
54
+ logits = self.linear(embedded) # (batch_size, vocab_size)
55
+ return logits
56
+
57
+
58
+ # =============================================================================
59
+ # 2. DATASET CREATION
60
+ # =============================================================================
61
+
62
+ def create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED):
63
+ """
64
+ Create a synthetic dataset with imbalanced targets.
65
+
66
+ Args:
67
+ n_samples: Total number of samples
68
+ n_rare: Number of rare 'B' samples
69
+ seed: Random seed for reproducibility
70
+
71
+ Returns:
72
+ inputs: Random token indices (0-3)
73
+ targets: 990 'A' (0) and 10 'B' (1)
74
+ rare_indices: Indices where target is 'B'
75
+ """
76
+ # Set seed for reproducibility
77
+ set_seeds(seed)
78
+
79
+ vocab = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
80
+
81
+ # Random input tokens
82
+ inputs = torch.randint(0, 4, (n_samples,))
83
+
84
+ # Create imbalanced targets: mostly 'A' (0), few 'B' (1)
85
+ targets = torch.zeros(n_samples, dtype=torch.long) # All 'A' initially
86
+
87
+ # Randomly select indices for rare 'B' samples
88
+ rare_indices = random.sample(range(n_samples), n_rare)
89
+ targets[rare_indices] = 1 # Set to 'B'
90
+
91
+ return inputs, targets, sorted(rare_indices)
92
+
93
+
94
+ # =============================================================================
95
+ # 3. UTILITY FUNCTIONS
96
+ # =============================================================================
97
+
98
+ def compute_weight_norm(model):
99
+ """Compute L2 norm of all model weights."""
100
+ total_norm = 0.0
101
+ for param in model.parameters():
102
+ total_norm += param.data.norm(2).item() ** 2
103
+ return total_norm ** 0.5
104
+
105
+
106
+ def get_initial_weights(seed=SEED):
107
+ """Get initial weights for reproducible model initialization."""
108
+ set_seeds(seed)
109
+ model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
110
+ return {name: param.clone() for name, param in model.state_dict().items()}
111
+
112
+
113
+ def train_epoch(model, optimizer, criterion, inputs, targets, clip_grad=False, max_norm=1.0):
114
+ """
115
+ Train for one epoch, recording metrics at each step.
116
+
117
+ Args:
118
+ model: The neural network
119
+ optimizer: SGD optimizer
120
+ criterion: CrossEntropyLoss
121
+ inputs: Input token indices
122
+ targets: Target token indices
123
+ clip_grad: Whether to apply gradient clipping
124
+ max_norm: Maximum gradient norm (if clipping)
125
+
126
+ Returns:
127
+ losses: List of losses per step
128
+ grad_norms: List of gradient norms per step (before clipping)
129
+ weight_norms: List of weight norms per step
130
+ """
131
+ model.train()
132
+
133
+ losses = []
134
+ grad_norms = []
135
+ weight_norms = []
136
+
137
+ # Train on each sample individually to see the effect of rare samples
138
+ for i in range(len(inputs)):
139
+ x = inputs[i:i+1] # Single sample
140
+ y = targets[i:i+1]
141
+
142
+ optimizer.zero_grad()
143
+
144
+ # Forward pass
145
+ logits = model(x)
146
+ loss = criterion(logits, y)
147
+
148
+ # Backward pass
149
+ loss.backward()
150
+
151
+ # Compute gradient norm BEFORE clipping
152
+ # Use a large value to just compute the norm without clipping
153
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf'))
154
+
155
+ # Apply gradient clipping if requested
156
+ if clip_grad:
157
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
158
+
159
+ # Update weights
160
+ optimizer.step()
161
+
162
+ # Record metrics
163
+ losses.append(loss.item())
164
+ grad_norms.append(grad_norm.item())
165
+ weight_norms.append(compute_weight_norm(model))
166
+
167
+ return losses, grad_norms, weight_norms
168
+
169
+
170
+ # =============================================================================
171
+ # 4. TRAINING FUNCTIONS
172
+ # =============================================================================
173
+
174
+ def run_training(inputs, targets, rare_indices, clip_grad=False, max_norm=1.0, n_epochs=3, lr=0.1, init_weights=None):
175
+ """
176
+ Run complete training loop.
177
+
178
+ Args:
179
+ inputs: Input token indices
180
+ targets: Target token indices
181
+ rare_indices: Indices of rare 'B' samples
182
+ clip_grad: Whether to apply gradient clipping
183
+ max_norm: Maximum gradient norm threshold
184
+ n_epochs: Number of training epochs
185
+ lr: Learning rate
186
+ init_weights: Initial model weights for reproducibility
187
+
188
+ Returns:
189
+ all_losses, all_grad_norms, all_weight_norms: Metrics across all steps
190
+ """
191
+ # Create fresh model with same initial weights
192
+ set_seeds(SEED)
193
+ model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
194
+ if init_weights:
195
+ model.load_state_dict(init_weights)
196
+
197
+ optimizer = optim.SGD(model.parameters(), lr=lr)
198
+ criterion = nn.CrossEntropyLoss()
199
+
200
+ all_losses = []
201
+ all_grad_norms = []
202
+ all_weight_norms = []
203
+
204
+ mode = "WITH" if clip_grad else "WITHOUT"
205
+ print(f"\n{'='*60}")
206
+ print(f"Training {mode} gradient clipping (max_norm={max_norm})")
207
+ print(f"{'='*60}")
208
+
209
+ for epoch in range(n_epochs):
210
+ losses, grad_norms, weight_norms = train_epoch(
211
+ model, optimizer, criterion, inputs, targets,
212
+ clip_grad=clip_grad, max_norm=max_norm
213
+ )
214
+
215
+ all_losses.extend(losses)
216
+ all_grad_norms.extend(grad_norms)
217
+ all_weight_norms.extend(weight_norms)
218
+
219
+ avg_loss = np.mean(losses)
220
+ max_grad = np.max(grad_norms)
221
+ print(f"Epoch {epoch+1}/{n_epochs}: Avg Loss={avg_loss:.4f}, Max Grad Norm={max_grad:.4f}")
222
+
223
+ return all_losses, all_grad_norms, all_weight_norms
224
+
225
+
226
+ # =============================================================================
227
+ # 5. PLOTTING FUNCTIONS
228
+ # =============================================================================
229
+
230
+ def plot_metrics(losses, grad_norms, weight_norms, title, filename, rare_indices=None, n_samples=1000):
231
+ """
232
+ Plot training metrics: loss, gradient norm, and weight norm.
233
+
234
+ Args:
235
+ losses: List of losses per step
236
+ grad_norms: List of gradient norms per step
237
+ weight_norms: List of weight norms per step
238
+ title: Plot title
239
+ filename: Output filename
240
+ rare_indices: Indices of rare 'B' samples (for highlighting)
241
+ n_samples: Number of samples per epoch
242
+ """
243
+ fig, axes = plt.subplots(3, 1, figsize=(12, 10), sharex=True)
244
+
245
+ steps = range(len(losses))
246
+ n_epochs = len(losses) // n_samples
247
+
248
+ # Plot 1: Training Loss
249
+ axes[0].plot(steps, losses, 'b-', alpha=0.7, linewidth=0.5)
250
+ axes[0].set_ylabel('Training Loss', fontsize=12)
251
+ axes[0].set_title(title, fontsize=14, fontweight='bold')
252
+ axes[0].grid(True, alpha=0.3)
253
+
254
+ # Highlight rare sample positions
255
+ if rare_indices:
256
+ for epoch in range(n_epochs):
257
+ for idx in rare_indices:
258
+ step = epoch * n_samples + idx
259
+ if step < len(losses):
260
+ axes[0].axvline(x=step, color='red', alpha=0.3, linewidth=0.5)
261
+
262
+ # Plot 2: Gradient Norm
263
+ axes[1].plot(steps, grad_norms, 'g-', alpha=0.7, linewidth=0.5)
264
+ axes[1].set_ylabel('Gradient L2 Norm', fontsize=12)
265
+ axes[1].grid(True, alpha=0.3)
266
+
267
+ # Add horizontal line at clipping threshold
268
+ if "With" in title or "WITH" in title:
269
+ axes[1].axhline(y=1.0, color='red', linestyle='--', label='Clip threshold (1.0)')
270
+ axes[1].legend()
271
+
272
+ if rare_indices:
273
+ for epoch in range(n_epochs):
274
+ for idx in rare_indices:
275
+ step = epoch * n_samples + idx
276
+ if step < len(grad_norms):
277
+ axes[1].axvline(x=step, color='red', alpha=0.3, linewidth=0.5)
278
+
279
+ # Plot 3: Weight Norm
280
+ axes[2].plot(steps, weight_norms, 'm-', alpha=0.7, linewidth=0.5)
281
+ axes[2].set_ylabel('Weight L2 Norm', fontsize=12)
282
+ axes[2].set_xlabel('Training Step', fontsize=12)
283
+ axes[2].grid(True, alpha=0.3)
284
+
285
+ plt.tight_layout()
286
+ plt.savefig(filename, dpi=150, bbox_inches='tight')
287
+ plt.close()
288
+
289
+ print(f"Plot saved to: {filename}")
290
+
291
+
292
+ def plot_comparison(metrics_no_clip, metrics_with_clip, rare_indices, filename, n_samples=1000):
293
+ """
294
+ Create side-by-side comparison plot.
295
+
296
+ Args:
297
+ metrics_no_clip: (losses, grad_norms, weight_norms) without clipping
298
+ metrics_with_clip: (losses, grad_norms, weight_norms) with clipping
299
+ rare_indices: Indices of rare 'B' samples
300
+ filename: Output filename
301
+ n_samples: Number of samples per epoch
302
+ """
303
+ fig, axes = plt.subplots(3, 2, figsize=(16, 12))
304
+
305
+ losses_no, grads_no, weights_no = metrics_no_clip
306
+ losses_with, grads_with, weights_with = metrics_with_clip
307
+
308
+ steps = range(len(losses_no))
309
+ n_epochs = len(losses_no) // n_samples
310
+
311
+ # Column 1: Without Clipping
312
+ axes[0, 0].plot(steps, losses_no, 'b-', alpha=0.7, linewidth=0.5)
313
+ axes[0, 0].set_ylabel('Training Loss', fontsize=11)
314
+ axes[0, 0].set_title('WITHOUT Gradient Clipping', fontsize=13, fontweight='bold', color='red')
315
+ axes[0, 0].grid(True, alpha=0.3)
316
+
317
+ axes[1, 0].plot(steps, grads_no, 'g-', alpha=0.7, linewidth=0.5)
318
+ axes[1, 0].set_ylabel('Gradient L2 Norm', fontsize=11)
319
+ axes[1, 0].grid(True, alpha=0.3)
320
+
321
+ axes[2, 0].plot(steps, weights_no, 'm-', alpha=0.7, linewidth=0.5)
322
+ axes[2, 0].set_ylabel('Weight L2 Norm', fontsize=11)
323
+ axes[2, 0].set_xlabel('Training Step', fontsize=11)
324
+ axes[2, 0].grid(True, alpha=0.3)
325
+
326
+ # Column 2: With Clipping
327
+ axes[0, 1].plot(steps, losses_with, 'b-', alpha=0.7, linewidth=0.5)
328
+ axes[0, 1].set_title('WITH Gradient Clipping (max_norm=1.0)', fontsize=13, fontweight='bold', color='green')
329
+ axes[0, 1].grid(True, alpha=0.3)
330
+
331
+ axes[1, 1].plot(steps, grads_with, 'g-', alpha=0.7, linewidth=0.5)
332
+ axes[1, 1].axhline(y=1.0, color='red', linestyle='--', linewidth=2, label='Clip threshold')
333
+ axes[1, 1].legend(loc='upper right')
334
+ axes[1, 1].grid(True, alpha=0.3)
335
+
336
+ axes[2, 1].plot(steps, weights_with, 'm-', alpha=0.7, linewidth=0.5)
337
+ axes[2, 1].set_xlabel('Training Step', fontsize=11)
338
+ axes[2, 1].grid(True, alpha=0.3)
339
+
340
+ # Highlight rare sample positions in all plots
341
+ for col in range(2):
342
+ for row in range(3):
343
+ for epoch in range(n_epochs):
344
+ for idx in rare_indices:
345
+ step = epoch * n_samples + idx
346
+ if step < len(losses_no):
347
+ axes[row, col].axvline(x=step, color='red', alpha=0.2, linewidth=0.5)
348
+
349
+ # Add legend for rare samples
350
+ axes[0, 0].axvline(x=-100, color='red', alpha=0.5, linewidth=2, label="Rare 'B' samples")
351
+ axes[0, 0].legend(loc='upper right')
352
+
353
+ # Add overall title
354
+ fig.suptitle('Effect of Gradient Clipping on Training Stability\n(Red lines indicate rare "B" samples)',
355
+ fontsize=14, fontweight='bold', y=1.02)
356
+
357
+ plt.tight_layout()
358
+ plt.savefig(filename, dpi=150, bbox_inches='tight')
359
+ plt.close()
360
+
361
+ print(f"Comparison plot saved to: {filename}")
362
+
363
+
364
+ # =============================================================================
365
+ # 6. MAIN EXECUTION
366
+ # =============================================================================
367
+
368
+ def main():
369
+ print("="*60)
370
+ print("GRADIENT CLIPPING EXPERIMENT")
371
+ print("="*60)
372
+ print("\nThis experiment demonstrates how gradient clipping stabilizes")
373
+ print("training by preventing sudden large weight updates caused by")
374
+ print("rare, high-loss data points.\n")
375
+
376
+ # Create dataset ONCE (used for both runs)
377
+ inputs, targets, rare_indices = create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED)
378
+
379
+ print(f"Dataset created:")
380
+ print(f" Total samples: {len(inputs)}")
381
+ print(f" Target 'A' (0): {(targets == 0).sum().item()}")
382
+ print(f" Target 'B' (1): {(targets == 1).sum().item()}")
383
+ print(f" Rare 'B' indices: {rare_indices}")
384
+
385
+ # Get initial weights (same for both runs)
386
+ init_weights = get_initial_weights(seed=SEED)
387
+
388
+ # Run training WITHOUT gradient clipping
389
+ losses_no_clip, grads_no_clip, weights_no_clip = run_training(
390
+ inputs, targets, rare_indices,
391
+ clip_grad=False, n_epochs=3, lr=0.1, init_weights=init_weights
392
+ )
393
+
394
+ # Run training WITH gradient clipping
395
+ losses_with_clip, grads_with_clip, weights_with_clip = run_training(
396
+ inputs, targets, rare_indices,
397
+ clip_grad=True, max_norm=1.0, n_epochs=3, lr=0.1, init_weights=init_weights
398
+ )
399
+
400
+ # Generate individual plots
401
+ print("\n" + "="*60)
402
+ print("GENERATING PLOTS")
403
+ print("="*60)
404
+
405
+ plot_metrics(
406
+ losses_no_clip, grads_no_clip, weights_no_clip,
407
+ "Training WITHOUT Gradient Clipping",
408
+ "no_clipping.png",
409
+ rare_indices
410
+ )
411
+
412
+ plot_metrics(
413
+ losses_with_clip, grads_with_clip, weights_with_clip,
414
+ "Training WITH Gradient Clipping (max_norm=1.0)",
415
+ "with_clipping.png",
416
+ rare_indices
417
+ )
418
+
419
+ # Generate comparison plot
420
+ plot_comparison(
421
+ (losses_no_clip, grads_no_clip, weights_no_clip),
422
+ (losses_with_clip, grads_with_clip, weights_with_clip),
423
+ rare_indices,
424
+ "comparison.png"
425
+ )
426
+
427
+ # Print summary statistics
428
+ print("\n" + "="*60)
429
+ print("SUMMARY STATISTICS")
430
+ print("="*60)
431
+
432
+ print("\nWithout Gradient Clipping:")
433
+ print(f" Max Gradient Norm: {max(grads_no_clip):.4f}")
434
+ print(f" Mean Gradient Norm: {np.mean(grads_no_clip):.4f}")
435
+ print(f" Std Gradient Norm: {np.std(grads_no_clip):.4f}")
436
+ print(f" Final Weight Norm: {weights_no_clip[-1]:.4f}")
437
+ print(f" Final Loss: {losses_no_clip[-1]:.4f}")
438
+
439
+ print("\nWith Gradient Clipping (max_norm=1.0):")
440
+ print(f" Max Gradient Norm: {max(grads_with_clip):.4f}")
441
+ print(f" Mean Gradient Norm: {np.mean(grads_with_clip):.4f}")
442
+ print(f" Std Gradient Norm: {np.std(grads_with_clip):.4f}")
443
+ print(f" Final Weight Norm: {weights_with_clip[-1]:.4f}")
444
+ print(f" Final Loss: {losses_with_clip[-1]:.4f}")
445
+
446
+ # Return statistics for report
447
+ return {
448
+ 'no_clip': {
449
+ 'max_grad': max(grads_no_clip),
450
+ 'mean_grad': np.mean(grads_no_clip),
451
+ 'std_grad': np.std(grads_no_clip),
452
+ 'final_weight': weights_no_clip[-1],
453
+ 'final_loss': losses_no_clip[-1]
454
+ },
455
+ 'with_clip': {
456
+ 'max_grad': max(grads_with_clip),
457
+ 'mean_grad': np.mean(grads_with_clip),
458
+ 'std_grad': np.std(grads_with_clip),
459
+ 'final_weight': weights_with_clip[-1],
460
+ 'final_loss': losses_with_clip[-1]
461
+ },
462
+ 'rare_indices': rare_indices
463
+ }
464
+
465
+
466
+ if __name__ == "__main__":
467
+ stats = main()
468
+ print("\n" + "="*60)
469
+ print("EXPERIMENT COMPLETE!")
470
+ print("="*60)