AmberLJC commited on
Commit
11c4193
·
verified ·
1 Parent(s): 76666a7

Upload extended_experiment.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. extended_experiment.py +698 -0
extended_experiment.py ADDED
@@ -0,0 +1,698 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Extended Gradient Clipping Experiment: Testing Physics-of-AI Predictions
3
+
4
+ This script tests two predictions from our Physics-of-AI analysis:
5
+
6
+ Prediction 2: Representation Collapse
7
+ - Hypothesis: Without clipping, the effective dimensionality of embeddings
8
+ should show sudden drops at rare sample positions.
9
+ - Test: Track PCA-based effective dimension throughout training.
10
+
11
+ Prediction 4: Rare Sample Learning
12
+ - Hypothesis: With clipping, the model should achieve better accuracy on rare samples.
13
+ - Test: Track per-class accuracy throughout training.
14
+
15
+ Based on Ziming Liu's Physics-of-AI framework and the unigram toy model analysis.
16
+ """
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.optim as optim
21
+ import numpy as np
22
+ import matplotlib.pyplot as plt
23
+ import random
24
+ from typing import Dict, List, Tuple
25
+
26
+ # Set seeds for reproducibility
27
+ SEED = 42
28
+
29
+
30
+ def set_seeds(seed=SEED):
31
+ """Set all random seeds for reproducibility."""
32
+ torch.manual_seed(seed)
33
+ np.random.seed(seed)
34
+ random.seed(seed)
35
+
36
+
37
+ # =============================================================================
38
+ # 1. MODEL DEFINITION
39
+ # =============================================================================
40
+
41
+ class SimpleNextTokenModel(nn.Module):
42
+ """
43
+ Simple model that takes a token index and predicts the next token.
44
+ Architecture: Embedding -> Linear
45
+ """
46
+ def __init__(self, vocab_size=4, embedding_dim=16):
47
+ super().__init__()
48
+ self.embedding = nn.Embedding(vocab_size, embedding_dim)
49
+ self.linear = nn.Linear(embedding_dim, vocab_size)
50
+
51
+ def forward(self, x):
52
+ embedded = self.embedding(x)
53
+ logits = self.linear(embedded)
54
+ return logits
55
+
56
+ def get_embeddings(self):
57
+ """Return the embedding matrix for analysis."""
58
+ return self.embedding.weight.data.clone()
59
+
60
+
61
+ # =============================================================================
62
+ # 2. EFFECTIVE DIMENSIONALITY (PCA-based)
63
+ # =============================================================================
64
+
65
+ def compute_effective_dimension(embedding_matrix: torch.Tensor) -> float:
66
+ """
67
+ Compute effective dimensionality using PCA entropy.
68
+
69
+ Following Ziming Liu's approach from the Unigram toy model analysis:
70
+ "We define effective dimensionality via PCA entropy"
71
+
72
+ Effective dimension = exp(entropy of normalized eigenvalues)
73
+
74
+ Args:
75
+ embedding_matrix: (vocab_size, embedding_dim) tensor
76
+
77
+ Returns:
78
+ Effective dimension (float between 1 and embedding_dim)
79
+ """
80
+ # Center the embeddings
81
+ centered = embedding_matrix - embedding_matrix.mean(dim=0, keepdim=True)
82
+
83
+ # Compute covariance matrix
84
+ cov = torch.mm(centered.T, centered) / (embedding_matrix.shape[0] - 1)
85
+
86
+ # Get eigenvalues
87
+ eigenvalues = torch.linalg.eigvalsh(cov)
88
+ eigenvalues = torch.clamp(eigenvalues, min=1e-10) # Avoid log(0)
89
+
90
+ # Normalize to get probability distribution
91
+ eigenvalues = eigenvalues / eigenvalues.sum()
92
+
93
+ # Compute entropy
94
+ entropy = -torch.sum(eigenvalues * torch.log(eigenvalues))
95
+
96
+ # Effective dimension = exp(entropy)
97
+ effective_dim = torch.exp(entropy).item()
98
+
99
+ return effective_dim
100
+
101
+
102
+ def compute_embedding_stats(embedding_matrix: torch.Tensor) -> Dict[str, float]:
103
+ """
104
+ Compute various statistics about the embedding matrix.
105
+
106
+ Returns:
107
+ Dictionary with embedding statistics
108
+ """
109
+ # Effective dimension
110
+ eff_dim = compute_effective_dimension(embedding_matrix)
111
+
112
+ # Embedding norms per token
113
+ norms = torch.norm(embedding_matrix, dim=1)
114
+
115
+ # Pairwise cosine similarities
116
+ normalized = embedding_matrix / (norms.unsqueeze(1) + 1e-10)
117
+ cosine_sim = torch.mm(normalized, normalized.T)
118
+ # Get off-diagonal elements (exclude self-similarity)
119
+ mask = ~torch.eye(cosine_sim.shape[0], dtype=bool)
120
+ off_diag = cosine_sim[mask]
121
+
122
+ return {
123
+ 'effective_dim': eff_dim,
124
+ 'mean_norm': norms.mean().item(),
125
+ 'std_norm': norms.std().item(),
126
+ 'mean_cosine_sim': off_diag.mean().item(),
127
+ 'max_cosine_sim': off_diag.max().item(),
128
+ }
129
+
130
+
131
+ # =============================================================================
132
+ # 3. PER-CLASS ACCURACY
133
+ # =============================================================================
134
+
135
+ def compute_per_class_accuracy(model: nn.Module, inputs: torch.Tensor,
136
+ targets: torch.Tensor) -> Dict[int, float]:
137
+ """
138
+ Compute accuracy for each target class.
139
+
140
+ Args:
141
+ model: The neural network
142
+ inputs: Input token indices
143
+ targets: Target token indices
144
+
145
+ Returns:
146
+ Dictionary mapping class index to accuracy
147
+ """
148
+ model.eval()
149
+ with torch.no_grad():
150
+ logits = model(inputs)
151
+ predictions = logits.argmax(dim=1)
152
+
153
+ accuracies = {}
154
+ for class_idx in range(4): # Vocab size = 4
155
+ mask = targets == class_idx
156
+ if mask.sum() > 0:
157
+ correct = (predictions[mask] == targets[mask]).float().mean().item()
158
+ accuracies[class_idx] = correct
159
+ else:
160
+ accuracies[class_idx] = None # No samples of this class
161
+
162
+ return accuracies
163
+
164
+
165
+ # =============================================================================
166
+ # 4. DATASET CREATION
167
+ # =============================================================================
168
+
169
+ def create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED):
170
+ """
171
+ Create a synthetic dataset with imbalanced targets.
172
+ """
173
+ set_seeds(seed)
174
+
175
+ inputs = torch.randint(0, 4, (n_samples,))
176
+ targets = torch.zeros(n_samples, dtype=torch.long)
177
+
178
+ rare_indices = random.sample(range(n_samples), n_rare)
179
+ targets[rare_indices] = 1 # Set to 'B'
180
+
181
+ return inputs, targets, sorted(rare_indices)
182
+
183
+
184
+ # =============================================================================
185
+ # 5. EXTENDED TRAINING LOOP
186
+ # =============================================================================
187
+
188
+ def train_with_tracking(inputs: torch.Tensor, targets: torch.Tensor,
189
+ rare_indices: List[int], clip_grad: bool = False,
190
+ max_norm: float = 1.0, n_epochs: int = 3,
191
+ lr: float = 0.1, init_weights=None,
192
+ track_every: int = 10) -> Dict:
193
+ """
194
+ Train with extended tracking of:
195
+ - Loss, gradient norm, weight norm (as before)
196
+ - Effective dimensionality of embeddings
197
+ - Per-class accuracy
198
+
199
+ Args:
200
+ inputs, targets: Training data
201
+ rare_indices: Indices of rare 'B' samples
202
+ clip_grad: Whether to apply gradient clipping
203
+ max_norm: Clipping threshold
204
+ n_epochs: Number of epochs
205
+ lr: Learning rate
206
+ init_weights: Initial model weights
207
+ track_every: Track embedding stats every N steps
208
+
209
+ Returns:
210
+ Dictionary with all tracked metrics
211
+ """
212
+ set_seeds(SEED)
213
+ model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
214
+ if init_weights:
215
+ model.load_state_dict({k: v.clone() for k, v in init_weights.items()})
216
+
217
+ optimizer = optim.SGD(model.parameters(), lr=lr)
218
+ criterion = nn.CrossEntropyLoss()
219
+
220
+ # Tracking arrays
221
+ metrics = {
222
+ 'losses': [],
223
+ 'grad_norms': [],
224
+ 'weight_norms': [],
225
+ 'effective_dims': [],
226
+ 'effective_dim_steps': [],
227
+ 'class_accuracies': {0: [], 1: [], 2: [], 3: []}, # A, B, C, D
228
+ 'accuracy_steps': [],
229
+ 'embedding_stats': [],
230
+ }
231
+
232
+ mode = "WITH" if clip_grad else "WITHOUT"
233
+ print(f"\n{'='*60}")
234
+ print(f"Training {mode} gradient clipping (max_norm={max_norm})")
235
+ print(f"{'='*60}")
236
+
237
+ step = 0
238
+ n_samples = len(inputs)
239
+
240
+ for epoch in range(n_epochs):
241
+ model.train()
242
+ epoch_losses = []
243
+
244
+ for i in range(n_samples):
245
+ x = inputs[i:i+1]
246
+ y = targets[i:i+1]
247
+
248
+ optimizer.zero_grad()
249
+ logits = model(x)
250
+ loss = criterion(logits, y)
251
+ loss.backward()
252
+
253
+ # Compute gradient norm BEFORE clipping
254
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf'))
255
+
256
+ # Apply clipping if requested
257
+ if clip_grad:
258
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
259
+
260
+ optimizer.step()
261
+
262
+ # Record basic metrics
263
+ metrics['losses'].append(loss.item())
264
+ metrics['grad_norms'].append(grad_norm.item())
265
+
266
+ # Weight norm
267
+ total_norm = sum(p.data.norm(2).item() ** 2 for p in model.parameters()) ** 0.5
268
+ metrics['weight_norms'].append(total_norm)
269
+
270
+ epoch_losses.append(loss.item())
271
+
272
+ # Track embedding stats periodically OR at rare sample positions
273
+ is_rare_position = i in rare_indices
274
+ should_track = (step % track_every == 0) or is_rare_position
275
+
276
+ if should_track:
277
+ emb_matrix = model.get_embeddings()
278
+ emb_stats = compute_embedding_stats(emb_matrix)
279
+
280
+ metrics['effective_dims'].append(emb_stats['effective_dim'])
281
+ metrics['effective_dim_steps'].append(step)
282
+ metrics['embedding_stats'].append(emb_stats)
283
+
284
+ # Per-class accuracy
285
+ class_acc = compute_per_class_accuracy(model, inputs, targets)
286
+ for cls_idx in range(4):
287
+ if class_acc[cls_idx] is not None:
288
+ metrics['class_accuracies'][cls_idx].append(class_acc[cls_idx])
289
+ else:
290
+ metrics['class_accuracies'][cls_idx].append(0.0)
291
+ metrics['accuracy_steps'].append(step)
292
+
293
+ step += 1
294
+
295
+ avg_loss = np.mean(epoch_losses)
296
+
297
+ # End of epoch: compute full accuracy
298
+ class_acc = compute_per_class_accuracy(model, inputs, targets)
299
+ print(f"Epoch {epoch+1}/{n_epochs}: Avg Loss={avg_loss:.4f}")
300
+ b_acc = f"{class_acc[1]:.3f}" if class_acc[1] is not None else "N/A"
301
+ print(f" Class Accuracies: A={class_acc[0]:.3f}, B={b_acc}")
302
+
303
+ eff_dim = compute_effective_dimension(model.get_embeddings())
304
+ print(f" Effective Dimension: {eff_dim:.3f}")
305
+
306
+ return metrics
307
+
308
+
309
+ # =============================================================================
310
+ # 6. PLOTTING FUNCTIONS
311
+ # =============================================================================
312
+
313
+ def plot_effective_dimension_comparison(metrics_no_clip: Dict, metrics_with_clip: Dict,
314
+ rare_indices: List[int], filename: str,
315
+ n_samples: int = 1000):
316
+ """
317
+ Plot effective dimensionality comparison.
318
+
319
+ This tests Prediction 2: Without clipping, effective dimensionality
320
+ should show sudden drops at rare sample positions.
321
+ """
322
+ fig, axes = plt.subplots(2, 1, figsize=(14, 10))
323
+
324
+ # Plot 1: Without Clipping
325
+ ax1 = axes[0]
326
+ steps_no = metrics_no_clip['effective_dim_steps']
327
+ dims_no = metrics_no_clip['effective_dims']
328
+
329
+ ax1.plot(steps_no, dims_no, 'b-', linewidth=1.5, marker='o', markersize=3, alpha=0.7)
330
+ ax1.set_ylabel('Effective Dimension', fontsize=12)
331
+ ax1.set_title('WITHOUT Gradient Clipping - Embedding Effective Dimensionality',
332
+ fontsize=13, fontweight='bold', color='red')
333
+ ax1.grid(True, alpha=0.3)
334
+ ax1.set_ylim([0, 16]) # Max is embedding_dim=16
335
+
336
+ # Mark rare sample positions
337
+ n_epochs = len(metrics_no_clip['losses']) // n_samples
338
+ for epoch in range(n_epochs):
339
+ for idx in rare_indices:
340
+ step = epoch * n_samples + idx
341
+ ax1.axvline(x=step, color='red', alpha=0.3, linewidth=1)
342
+
343
+ # Add annotation
344
+ ax1.axvline(x=-100, color='red', alpha=0.5, linewidth=2, label="Rare 'B' samples")
345
+ ax1.legend(loc='upper right')
346
+
347
+ # Plot 2: With Clipping
348
+ ax2 = axes[1]
349
+ steps_with = metrics_with_clip['effective_dim_steps']
350
+ dims_with = metrics_with_clip['effective_dims']
351
+
352
+ ax2.plot(steps_with, dims_with, 'g-', linewidth=1.5, marker='o', markersize=3, alpha=0.7)
353
+ ax2.set_ylabel('Effective Dimension', fontsize=12)
354
+ ax2.set_xlabel('Training Step', fontsize=12)
355
+ ax2.set_title('WITH Gradient Clipping - Embedding Effective Dimensionality',
356
+ fontsize=13, fontweight='bold', color='green')
357
+ ax2.grid(True, alpha=0.3)
358
+ ax2.set_ylim([0, 16])
359
+
360
+ for epoch in range(n_epochs):
361
+ for idx in rare_indices:
362
+ step = epoch * n_samples + idx
363
+ ax2.axvline(x=step, color='red', alpha=0.3, linewidth=1)
364
+
365
+ ax2.axvline(x=-100, color='red', alpha=0.5, linewidth=2, label="Rare 'B' samples")
366
+ ax2.legend(loc='upper right')
367
+
368
+ fig.suptitle('Prediction 2: Representation Collapse Test\n'
369
+ '(Hypothesis: Without clipping, effective dim drops at rare samples)',
370
+ fontsize=14, fontweight='bold', y=1.02)
371
+
372
+ plt.tight_layout()
373
+ plt.savefig(filename, dpi=150, bbox_inches='tight')
374
+ plt.close()
375
+ print(f"Effective dimension plot saved to: {filename}")
376
+
377
+
378
+ def plot_class_accuracy_comparison(metrics_no_clip: Dict, metrics_with_clip: Dict,
379
+ filename: str):
380
+ """
381
+ Plot per-class accuracy comparison.
382
+
383
+ This tests Prediction 4: With clipping, the model should achieve
384
+ better accuracy on rare samples (class 'B').
385
+ """
386
+ fig, axes = plt.subplots(2, 2, figsize=(14, 10))
387
+
388
+ # Class A (common) - Without vs With
389
+ ax_a = axes[0, 0]
390
+ steps_no = metrics_no_clip['accuracy_steps']
391
+ steps_with = metrics_with_clip['accuracy_steps']
392
+
393
+ ax_a.plot(steps_no, metrics_no_clip['class_accuracies'][0], 'r-',
394
+ linewidth=1.5, alpha=0.7, label='Without Clipping')
395
+ ax_a.plot(steps_with, metrics_with_clip['class_accuracies'][0], 'g-',
396
+ linewidth=1.5, alpha=0.7, label='With Clipping')
397
+ ax_a.set_ylabel('Accuracy', fontsize=11)
398
+ ax_a.set_title("Class 'A' (Common - 990 samples)", fontsize=12, fontweight='bold')
399
+ ax_a.legend()
400
+ ax_a.grid(True, alpha=0.3)
401
+ ax_a.set_ylim([0, 1.05])
402
+
403
+ # Class B (rare) - Without vs With
404
+ ax_b = axes[0, 1]
405
+ ax_b.plot(steps_no, metrics_no_clip['class_accuracies'][1], 'r-',
406
+ linewidth=1.5, alpha=0.7, label='Without Clipping')
407
+ ax_b.plot(steps_with, metrics_with_clip['class_accuracies'][1], 'g-',
408
+ linewidth=1.5, alpha=0.7, label='With Clipping')
409
+ ax_b.set_ylabel('Accuracy', fontsize=11)
410
+ ax_b.set_title("Class 'B' (Rare - 10 samples) ⭐ KEY PREDICTION",
411
+ fontsize=12, fontweight='bold', color='purple')
412
+ ax_b.legend()
413
+ ax_b.grid(True, alpha=0.3)
414
+ ax_b.set_ylim([0, 1.05])
415
+
416
+ # Accuracy difference (With - Without) for rare class
417
+ ax_diff = axes[1, 0]
418
+ acc_b_no = np.array(metrics_no_clip['class_accuracies'][1])
419
+ acc_b_with = np.array(metrics_with_clip['class_accuracies'][1])
420
+ min_len = min(len(acc_b_no), len(acc_b_with))
421
+ diff = acc_b_with[:min_len] - acc_b_no[:min_len]
422
+
423
+ colors = ['green' if d >= 0 else 'red' for d in diff]
424
+ ax_diff.bar(steps_no[:min_len], diff, color=colors, alpha=0.7, width=8)
425
+ ax_diff.axhline(y=0, color='black', linestyle='-', linewidth=1)
426
+ ax_diff.set_ylabel('Accuracy Difference\n(With Clip - Without Clip)', fontsize=11)
427
+ ax_diff.set_xlabel('Training Step', fontsize=11)
428
+ ax_diff.set_title("Rare Class 'B': Clipping Benefit", fontsize=12, fontweight='bold')
429
+ ax_diff.grid(True, alpha=0.3)
430
+
431
+ # Summary statistics
432
+ ax_summary = axes[1, 1]
433
+ ax_summary.axis('off')
434
+
435
+ # Compute final accuracies
436
+ final_acc_a_no = metrics_no_clip['class_accuracies'][0][-1]
437
+ final_acc_a_with = metrics_with_clip['class_accuracies'][0][-1]
438
+ final_acc_b_no = metrics_no_clip['class_accuracies'][1][-1]
439
+ final_acc_b_with = metrics_with_clip['class_accuracies'][1][-1]
440
+
441
+ summary_text = f"""
442
+ PREDICTION 4 TEST RESULTS
443
+ ═══════════════════════════════════════
444
+
445
+ Hypothesis: With clipping, the model should
446
+ achieve better accuracy on rare samples.
447
+
448
+ FINAL ACCURACIES:
449
+ ─────────────────────────────────────────
450
+ Class 'A' (Common):
451
+ Without Clipping: {final_acc_a_no:.1%}
452
+ With Clipping: {final_acc_a_with:.1%}
453
+ Difference: {final_acc_a_with - final_acc_a_no:+.1%}
454
+
455
+ Class 'B' (Rare):
456
+ Without Clipping: {final_acc_b_no:.1%}
457
+ With Clipping: {final_acc_b_with:.1%}
458
+ Difference: {final_acc_b_with - final_acc_b_no:+.1%}
459
+
460
+ ─────────────────────────────────────────
461
+ VERDICT: {'✅ PREDICTION SUPPORTED' if final_acc_b_with >= final_acc_b_no else '❌ PREDICTION NOT SUPPORTED'}
462
+ (Clipping {'improves' if final_acc_b_with > final_acc_b_no else 'does not improve'} rare class accuracy)
463
+ """
464
+
465
+ ax_summary.text(0.1, 0.5, summary_text, transform=ax_summary.transAxes,
466
+ fontsize=11, verticalalignment='center', fontfamily='monospace',
467
+ bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))
468
+
469
+ fig.suptitle('Prediction 4: Rare Sample Learning Test\n'
470
+ '(Hypothesis: Clipping improves accuracy on rare samples)',
471
+ fontsize=14, fontweight='bold', y=1.02)
472
+
473
+ plt.tight_layout()
474
+ plt.savefig(filename, dpi=150, bbox_inches='tight')
475
+ plt.close()
476
+ print(f"Class accuracy plot saved to: {filename}")
477
+
478
+
479
+ def plot_combined_analysis(metrics_no_clip: Dict, metrics_with_clip: Dict,
480
+ rare_indices: List[int], filename: str,
481
+ n_samples: int = 1000):
482
+ """
483
+ Create a comprehensive 6-panel analysis plot.
484
+ """
485
+ fig = plt.figure(figsize=(18, 14))
486
+
487
+ # Create grid
488
+ gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.25)
489
+
490
+ n_epochs = len(metrics_no_clip['losses']) // n_samples
491
+
492
+ # Row 1: Effective Dimension
493
+ ax1 = fig.add_subplot(gs[0, 0])
494
+ ax2 = fig.add_subplot(gs[0, 1])
495
+
496
+ # Without clipping
497
+ ax1.plot(metrics_no_clip['effective_dim_steps'], metrics_no_clip['effective_dims'],
498
+ 'b-', linewidth=1.5, marker='o', markersize=2, alpha=0.7)
499
+ ax1.set_ylabel('Effective Dimension', fontsize=11)
500
+ ax1.set_title('Effective Dim - WITHOUT Clipping', fontsize=12, fontweight='bold', color='red')
501
+ ax1.grid(True, alpha=0.3)
502
+ ax1.set_ylim([0, 16])
503
+ for epoch in range(n_epochs):
504
+ for idx in rare_indices:
505
+ ax1.axvline(x=epoch * n_samples + idx, color='red', alpha=0.2, linewidth=1)
506
+
507
+ # With clipping
508
+ ax2.plot(metrics_with_clip['effective_dim_steps'], metrics_with_clip['effective_dims'],
509
+ 'g-', linewidth=1.5, marker='o', markersize=2, alpha=0.7)
510
+ ax2.set_title('Effective Dim - WITH Clipping', fontsize=12, fontweight='bold', color='green')
511
+ ax2.grid(True, alpha=0.3)
512
+ ax2.set_ylim([0, 16])
513
+ for epoch in range(n_epochs):
514
+ for idx in rare_indices:
515
+ ax2.axvline(x=epoch * n_samples + idx, color='red', alpha=0.2, linewidth=1)
516
+
517
+ # Row 2: Class Accuracies
518
+ ax3 = fig.add_subplot(gs[1, 0])
519
+ ax4 = fig.add_subplot(gs[1, 1])
520
+
521
+ # Common class A
522
+ ax3.plot(metrics_no_clip['accuracy_steps'], metrics_no_clip['class_accuracies'][0],
523
+ 'r-', linewidth=1.5, alpha=0.7, label='Without Clip')
524
+ ax3.plot(metrics_with_clip['accuracy_steps'], metrics_with_clip['class_accuracies'][0],
525
+ 'g-', linewidth=1.5, alpha=0.7, label='With Clip')
526
+ ax3.set_ylabel('Accuracy', fontsize=11)
527
+ ax3.set_title("Common Class 'A' Accuracy", fontsize=12, fontweight='bold')
528
+ ax3.legend()
529
+ ax3.grid(True, alpha=0.3)
530
+ ax3.set_ylim([0, 1.05])
531
+
532
+ # Rare class B
533
+ ax4.plot(metrics_no_clip['accuracy_steps'], metrics_no_clip['class_accuracies'][1],
534
+ 'r-', linewidth=1.5, alpha=0.7, label='Without Clip')
535
+ ax4.plot(metrics_with_clip['accuracy_steps'], metrics_with_clip['class_accuracies'][1],
536
+ 'g-', linewidth=1.5, alpha=0.7, label='With Clip')
537
+ ax4.set_title("Rare Class 'B' Accuracy ⭐", fontsize=12, fontweight='bold', color='purple')
538
+ ax4.legend()
539
+ ax4.grid(True, alpha=0.3)
540
+ ax4.set_ylim([0, 1.05])
541
+
542
+ # Row 3: Gradient Norms and Weight Norms
543
+ ax5 = fig.add_subplot(gs[2, 0])
544
+ ax6 = fig.add_subplot(gs[2, 1])
545
+
546
+ steps = range(len(metrics_no_clip['grad_norms']))
547
+
548
+ # Gradient norms
549
+ ax5.plot(steps, metrics_no_clip['grad_norms'], 'r-', alpha=0.5, linewidth=0.5, label='Without Clip')
550
+ ax5.plot(steps, metrics_with_clip['grad_norms'], 'g-', alpha=0.5, linewidth=0.5, label='With Clip')
551
+ ax5.axhline(y=1.0, color='black', linestyle='--', linewidth=2, label='Clip threshold')
552
+ ax5.set_ylabel('Gradient Norm', fontsize=11)
553
+ ax5.set_xlabel('Training Step', fontsize=11)
554
+ ax5.set_title('Gradient Norms Comparison', fontsize=12, fontweight='bold')
555
+ ax5.legend()
556
+ ax5.grid(True, alpha=0.3)
557
+
558
+ # Weight norms
559
+ ax6.plot(steps, metrics_no_clip['weight_norms'], 'r-', alpha=0.7, linewidth=1, label='Without Clip')
560
+ ax6.plot(steps, metrics_with_clip['weight_norms'], 'g-', alpha=0.7, linewidth=1, label='With Clip')
561
+ ax6.set_xlabel('Training Step', fontsize=11)
562
+ ax6.set_title('Weight Norms Comparison', fontsize=12, fontweight='bold')
563
+ ax6.legend()
564
+ ax6.grid(True, alpha=0.3)
565
+
566
+ fig.suptitle('Extended Gradient Clipping Analysis: Testing Physics-of-AI Predictions\n'
567
+ '(Red vertical lines = rare sample positions)',
568
+ fontsize=14, fontweight='bold', y=1.01)
569
+
570
+ plt.savefig(filename, dpi=150, bbox_inches='tight')
571
+ plt.close()
572
+ print(f"Combined analysis plot saved to: {filename}")
573
+
574
+
575
+ # =============================================================================
576
+ # 7. MAIN EXECUTION
577
+ # =============================================================================
578
+
579
+ def main():
580
+ print("="*70)
581
+ print("EXTENDED GRADIENT CLIPPING EXPERIMENT")
582
+ print("Testing Physics-of-AI Predictions")
583
+ print("="*70)
584
+
585
+ # Create dataset
586
+ inputs, targets, rare_indices = create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED)
587
+
588
+ print(f"\nDataset created:")
589
+ print(f" Total samples: {len(inputs)}")
590
+ print(f" Target 'A' (0): {(targets == 0).sum().item()}")
591
+ print(f" Target 'B' (1): {(targets == 1).sum().item()}")
592
+ print(f" Rare 'B' indices: {rare_indices}")
593
+
594
+ # Get initial weights
595
+ set_seeds(SEED)
596
+ init_model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
597
+ init_weights = {name: param.clone() for name, param in init_model.state_dict().items()}
598
+
599
+ # Initial effective dimension
600
+ init_eff_dim = compute_effective_dimension(init_model.get_embeddings())
601
+ print(f"\nInitial embedding effective dimension: {init_eff_dim:.3f}")
602
+
603
+ # Run training WITHOUT gradient clipping
604
+ metrics_no_clip = train_with_tracking(
605
+ inputs, targets, rare_indices,
606
+ clip_grad=False, n_epochs=3, lr=0.1,
607
+ init_weights=init_weights, track_every=5
608
+ )
609
+
610
+ # Run training WITH gradient clipping
611
+ metrics_with_clip = train_with_tracking(
612
+ inputs, targets, rare_indices,
613
+ clip_grad=True, max_norm=1.0, n_epochs=3, lr=0.1,
614
+ init_weights=init_weights, track_every=5
615
+ )
616
+
617
+ # Generate plots
618
+ print("\n" + "="*70)
619
+ print("GENERATING ANALYSIS PLOTS")
620
+ print("="*70)
621
+
622
+ plot_effective_dimension_comparison(
623
+ metrics_no_clip, metrics_with_clip, rare_indices,
624
+ "effective_dimension_comparison.png"
625
+ )
626
+
627
+ plot_class_accuracy_comparison(
628
+ metrics_no_clip, metrics_with_clip,
629
+ "class_accuracy_comparison.png"
630
+ )
631
+
632
+ plot_combined_analysis(
633
+ metrics_no_clip, metrics_with_clip, rare_indices,
634
+ "combined_analysis.png"
635
+ )
636
+
637
+ # Print summary
638
+ print("\n" + "="*70)
639
+ print("PREDICTION TEST RESULTS")
640
+ print("="*70)
641
+
642
+ # Prediction 2: Representation Collapse
643
+ print("\n📊 PREDICTION 2: Representation Collapse")
644
+ print("-" * 50)
645
+
646
+ dims_no = metrics_no_clip['effective_dims']
647
+ dims_with = metrics_with_clip['effective_dims']
648
+
649
+ print(f"Effective Dimension Statistics:")
650
+ print(f" WITHOUT Clipping:")
651
+ print(f" Initial: {dims_no[0]:.3f}")
652
+ print(f" Final: {dims_no[-1]:.3f}")
653
+ print(f" Min: {min(dims_no):.3f}")
654
+ print(f" Max: {max(dims_no):.3f}")
655
+ print(f" Std: {np.std(dims_no):.3f}")
656
+
657
+ print(f" WITH Clipping:")
658
+ print(f" Initial: {dims_with[0]:.3f}")
659
+ print(f" Final: {dims_with[-1]:.3f}")
660
+ print(f" Min: {min(dims_with):.3f}")
661
+ print(f" Max: {max(dims_with):.3f}")
662
+ print(f" Std: {np.std(dims_with):.3f}")
663
+
664
+ # Check if without clipping has more variance (indicating sudden drops)
665
+ collapse_supported = np.std(dims_no) > np.std(dims_with)
666
+ print(f"\n Verdict: {'✅ SUPPORTED' if collapse_supported else '❌ NOT SUPPORTED'}")
667
+ print(f" (Without clipping has {'higher' if collapse_supported else 'lower'} variance in effective dim)")
668
+
669
+ # Prediction 4: Rare Sample Learning
670
+ print("\n📊 PREDICTION 4: Rare Sample Learning")
671
+ print("-" * 50)
672
+
673
+ final_acc_b_no = metrics_no_clip['class_accuracies'][1][-1]
674
+ final_acc_b_with = metrics_with_clip['class_accuracies'][1][-1]
675
+
676
+ print(f"Final Rare Class 'B' Accuracy:")
677
+ print(f" WITHOUT Clipping: {final_acc_b_no:.1%}")
678
+ print(f" WITH Clipping: {final_acc_b_with:.1%}")
679
+ print(f" Difference: {final_acc_b_with - final_acc_b_no:+.1%}")
680
+
681
+ rare_learning_supported = final_acc_b_with >= final_acc_b_no
682
+ print(f"\n Verdict: {'✅ SUPPORTED' if rare_learning_supported else '❌ NOT SUPPORTED'}")
683
+
684
+ # Return results for further analysis
685
+ return {
686
+ 'metrics_no_clip': metrics_no_clip,
687
+ 'metrics_with_clip': metrics_with_clip,
688
+ 'rare_indices': rare_indices,
689
+ 'prediction_2_supported': collapse_supported,
690
+ 'prediction_4_supported': rare_learning_supported,
691
+ }
692
+
693
+
694
+ if __name__ == "__main__":
695
+ results = main()
696
+ print("\n" + "="*70)
697
+ print("EXPERIMENT COMPLETE!")
698
+ print("="*70)