""" Final Gradient Clipping Experiment: Testing Physics-of-AI Predictions Key insights from previous experiments: 1. With extreme imbalance (99:1), neither model learns rare class 2. Gradient clipping's benefit is in STABILITY, not learning rare classes per se 3. The key effect is on WEIGHT NORM STABILITY and GRADIENT SPIKE HANDLING This experiment tests: 1. Prediction 2: Representation Collapse - effective dim variance without clipping 2. Prediction 4: Rare Sample Learning - using moderate imbalance (80:20) 3. NEW: Weight norm stability analysis 4. NEW: Gradient spike analysis at rare sample positions """ import torch import torch.nn as nn import torch.optim as optim import numpy as np import matplotlib.pyplot as plt import random from typing import Dict, List SEED = 42 def set_seeds(seed=SEED): torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) class SimpleNextTokenModel(nn.Module): def __init__(self, vocab_size=4, embedding_dim=16): super().__init__() self.embedding = nn.Embedding(vocab_size, embedding_dim) self.linear = nn.Linear(embedding_dim, vocab_size) def forward(self, x): embedded = self.embedding(x) logits = self.linear(embedded) return logits def get_embeddings(self): return self.embedding.weight.data.clone() def compute_effective_dimension(embedding_matrix: torch.Tensor) -> float: """PCA-based effective dimensionality.""" centered = embedding_matrix - embedding_matrix.mean(dim=0, keepdim=True) cov = torch.mm(centered.T, centered) / (embedding_matrix.shape[0] - 1) eigenvalues = torch.linalg.eigvalsh(cov) eigenvalues = torch.clamp(eigenvalues, min=1e-10) eigenvalues = eigenvalues / eigenvalues.sum() entropy = -torch.sum(eigenvalues * torch.log(eigenvalues)) return torch.exp(entropy).item() def compute_per_class_accuracy(model: nn.Module, inputs: torch.Tensor, targets: torch.Tensor) -> Dict[int, float]: model.eval() with torch.no_grad(): logits = model(inputs) predictions = logits.argmax(dim=1) accuracies = {} for class_idx in range(4): mask = targets == class_idx if mask.sum() > 0: correct = (predictions[mask] == targets[mask]).float().mean().item() accuracies[class_idx] = correct else: accuracies[class_idx] = None return accuracies def create_dataset_moderate_imbalance(n_samples=1000, rare_ratio=0.2, seed=SEED): """Create dataset with moderate imbalance (e.g., 80:20).""" set_seeds(seed) n_rare = int(n_samples * rare_ratio) n_common = n_samples - n_rare inputs = torch.randint(0, 4, (n_samples,)) targets = torch.zeros(n_samples, dtype=torch.long) rare_indices = random.sample(range(n_samples), n_rare) targets[rare_indices] = 1 return inputs, targets, sorted(rare_indices) def create_dataset_extreme_imbalance(n_samples=1000, n_rare=10, seed=SEED): """Create dataset with extreme imbalance (99:1).""" set_seeds(seed) inputs = torch.randint(0, 4, (n_samples,)) targets = torch.zeros(n_samples, dtype=torch.long) rare_indices = random.sample(range(n_samples), n_rare) targets[rare_indices] = 1 return inputs, targets, sorted(rare_indices) def train_with_tracking(inputs: torch.Tensor, targets: torch.Tensor, rare_indices: List[int], clip_grad: bool = False, max_norm: float = 1.0, n_epochs: int = 10, lr: float = 0.1, init_weights=None, track_every: int = 50) -> Dict: """Training with comprehensive tracking.""" set_seeds(SEED) model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16) if init_weights: model.load_state_dict({k: v.clone() for k, v in init_weights.items()}) optimizer = optim.SGD(model.parameters(), lr=lr) criterion = nn.CrossEntropyLoss() metrics = { 'losses': [], 
def train_with_tracking(inputs: torch.Tensor, targets: torch.Tensor,
                        rare_indices: List[int], clip_grad: bool = False,
                        max_norm: float = 1.0, n_epochs: int = 10,
                        lr: float = 0.1, init_weights=None,
                        track_every: int = 50) -> Dict:
    """Training with comprehensive tracking.

    rare_indices is accepted for post-hoc alignment of per-step metrics with
    rare-sample positions; training itself does not use it.
    """
    set_seeds(SEED)
    model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    if init_weights is not None:
        model.load_state_dict({k: v.clone() for k, v in init_weights.items()})
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    metrics = {
        'losses': [],
        'grad_norms': [],
        'weight_norms': [],
        'effective_dims': [],
        'effective_dim_steps': [],
        'class_accuracies': {0: [], 1: [], 2: [], 3: []},
        'accuracy_steps': [],
        'weight_norm_changes': [],  # track sudden jumps in the weight norm
    }

    step = 0
    n_samples = len(inputs)
    prev_weight_norm = None

    for epoch in range(n_epochs):
        model.train()
        for i in range(n_samples):
            x = inputs[i:i+1]
            y = targets[i:i+1]

            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()

            # max_norm=inf leaves gradients untouched but returns the total
            # norm, so both runs record the same pre-clip gradient norm.
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf'))
            if clip_grad:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            optimizer.step()

            metrics['losses'].append(loss.item())
            metrics['grad_norms'].append(grad_norm.item())

            current_weight_norm = sum(
                p.data.norm(2).item() ** 2 for p in model.parameters()
            ) ** 0.5
            metrics['weight_norms'].append(current_weight_norm)

            # Track weight norm change
            if prev_weight_norm is not None:
                metrics['weight_norm_changes'].append(abs(current_weight_norm - prev_weight_norm))
            else:
                metrics['weight_norm_changes'].append(0)
            prev_weight_norm = current_weight_norm

            # Track periodically
            if step % track_every == 0:
                emb_matrix = model.get_embeddings()
                eff_dim = compute_effective_dimension(emb_matrix)
                metrics['effective_dims'].append(eff_dim)
                metrics['effective_dim_steps'].append(step)

                class_acc = compute_per_class_accuracy(model, inputs, targets)
                for cls_idx in range(4):
                    if class_acc[cls_idx] is not None:
                        metrics['class_accuracies'][cls_idx].append(class_acc[cls_idx])
                    else:
                        metrics['class_accuracies'][cls_idx].append(0.0)
                metrics['accuracy_steps'].append(step)
                # compute_per_class_accuracy switches the model to eval mode;
                # restore train mode (harmless for this model, which has no
                # dropout or batch norm, but keeps the state consistent).
                model.train()

            step += 1

    return metrics
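
# The module docstring promises a gradient-spike analysis at rare sample
# positions. Training visits samples in index order, so step = epoch *
# n_samples + i, and the per-step gradient norms recorded above can be aligned
# with rare_indices after the fact. The helper below is an illustrative sketch
# of that analysis (its name and interface are additions, not part of the
# original script).
def grad_norms_at_rare_steps(metrics: Dict, rare_indices: List[int],
                             n_samples: int, n_epochs: int) -> Dict[str, float]:
    rare_steps = {epoch * n_samples + i
                  for epoch in range(n_epochs) for i in rare_indices}
    rare = [g for s, g in enumerate(metrics['grad_norms']) if s in rare_steps]
    common = [g for s, g in enumerate(metrics['grad_norms']) if s not in rare_steps]
    return {
        'rare_mean': float(np.mean(rare)),
        'common_mean': float(np.mean(common)),
        # A ratio > 1 means gradients spike at rare-sample steps.
        'spike_ratio': float(np.mean(rare) / max(np.mean(common), 1e-12)),
    }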
def run_experiment_suite():
    """Run the complete experiment suite with both imbalance levels."""
    print("="*70)
    print("FINAL GRADIENT CLIPPING EXPERIMENT")
    print("Testing Physics-of-AI Predictions")
    print("="*70)

    # Get initial weights shared by all runs, so every model starts identically
    set_seeds(SEED)
    init_model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    init_weights = {name: param.clone() for name, param in init_model.state_dict().items()}

    results = {}

    # =========================================================================
    # EXPERIMENT 1: Extreme Imbalance (99:1) - Original Setup
    # =========================================================================
    print("\n" + "="*70)
    print("EXPERIMENT 1: EXTREME IMBALANCE (99:1)")
    print("="*70)

    inputs_extreme, targets_extreme, rare_extreme = create_dataset_extreme_imbalance(
        n_samples=1000, n_rare=10, seed=SEED
    )
    print(f"Dataset: {(targets_extreme == 0).sum().item()} common, "
          f"{(targets_extreme == 1).sum().item()} rare")

    print("\nTraining WITHOUT clipping...")
    metrics_extreme_no_clip = train_with_tracking(
        inputs_extreme, targets_extreme, rare_extreme,
        clip_grad=False, n_epochs=5, lr=0.1,
        init_weights=init_weights, track_every=100
    )

    print("Training WITH clipping...")
    metrics_extreme_with_clip = train_with_tracking(
        inputs_extreme, targets_extreme, rare_extreme,
        clip_grad=True, max_norm=1.0, n_epochs=5, lr=0.1,
        init_weights=init_weights, track_every=100
    )

    results['extreme'] = {
        'no_clip': metrics_extreme_no_clip,
        'with_clip': metrics_extreme_with_clip,
        'rare_indices': rare_extreme
    }

    # =========================================================================
    # EXPERIMENT 2: Moderate Imbalance (80:20)
    # =========================================================================
    print("\n" + "="*70)
    print("EXPERIMENT 2: MODERATE IMBALANCE (80:20)")
    print("="*70)

    inputs_moderate, targets_moderate, rare_moderate = create_dataset_moderate_imbalance(
        n_samples=1000, rare_ratio=0.2, seed=SEED
    )
    print(f"Dataset: {(targets_moderate == 0).sum().item()} common, "
          f"{(targets_moderate == 1).sum().item()} rare")

    print("\nTraining WITHOUT clipping...")
    metrics_moderate_no_clip = train_with_tracking(
        inputs_moderate, targets_moderate, rare_moderate,
        clip_grad=False, n_epochs=10, lr=0.1,
        init_weights=init_weights, track_every=100
    )

    print("Training WITH clipping...")
    metrics_moderate_with_clip = train_with_tracking(
        inputs_moderate, targets_moderate, rare_moderate,
        clip_grad=True, max_norm=1.0, n_epochs=10, lr=0.1,
        init_weights=init_weights, track_every=100
    )

    results['moderate'] = {
        'no_clip': metrics_moderate_no_clip,
        'with_clip': metrics_moderate_with_clip,
        'rare_indices': rare_moderate
    }

    return results
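
# How often clipping actually fires is easy to read off the recorded pre-clip
# norms. This helper is an illustrative addition (not in the original script)
# that complements the gradient-norm plots below.
def clip_activation_rate(metrics: Dict, max_norm: float = 1.0) -> float:
    """Fraction of steps whose pre-clip gradient norm exceeds the threshold."""
    return float((np.asarray(metrics['grad_norms']) > max_norm).mean())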
def plot_final_comparison(results: Dict, filename: str):
    """Create the final comparison plot."""
    fig = plt.figure(figsize=(20, 20))
    gs = fig.add_gridspec(5, 2, hspace=0.35, wspace=0.25)

    # =========================================================================
    # Row 1: Weight Norm Stability (Key Physics-of-AI Insight)
    # =========================================================================
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])

    # Extreme imbalance
    steps = range(len(results['extreme']['no_clip']['weight_norms']))
    ax1.plot(steps, results['extreme']['no_clip']['weight_norms'], 'r-',
             alpha=0.7, linewidth=1, label='Without Clip')
    ax1.plot(steps, results['extreme']['with_clip']['weight_norms'], 'g-',
             alpha=0.7, linewidth=1, label='With Clip')
    ax1.set_ylabel('Weight Norm', fontsize=11)
    ax1.set_title('EXTREME (99:1) - Weight Norm Evolution', fontsize=12, fontweight='bold')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Moderate imbalance
    steps = range(len(results['moderate']['no_clip']['weight_norms']))
    ax2.plot(steps, results['moderate']['no_clip']['weight_norms'], 'r-',
             alpha=0.7, linewidth=1, label='Without Clip')
    ax2.plot(steps, results['moderate']['with_clip']['weight_norms'], 'g-',
             alpha=0.7, linewidth=1, label='With Clip')
    ax2.set_title('MODERATE (80:20) - Weight Norm Evolution', fontsize=12, fontweight='bold')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # =========================================================================
    # Row 2: Weight Norm Changes (Stability Metric)
    # =========================================================================
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])

    # Extreme
    steps = range(len(results['extreme']['no_clip']['weight_norm_changes']))
    ax3.plot(steps, results['extreme']['no_clip']['weight_norm_changes'], 'r-',
             alpha=0.5, linewidth=0.5, label='Without Clip')
    ax3.plot(steps, results['extreme']['with_clip']['weight_norm_changes'], 'g-',
             alpha=0.5, linewidth=0.5, label='With Clip')
    ax3.set_ylabel('|Weight Norm Change|', fontsize=11)
    ax3.set_title('EXTREME - Weight Norm Changes (Stability)', fontsize=12, fontweight='bold')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    # Moderate
    steps = range(len(results['moderate']['no_clip']['weight_norm_changes']))
    ax4.plot(steps, results['moderate']['no_clip']['weight_norm_changes'], 'r-',
             alpha=0.5, linewidth=0.5, label='Without Clip')
    ax4.plot(steps, results['moderate']['with_clip']['weight_norm_changes'], 'g-',
             alpha=0.5, linewidth=0.5, label='With Clip')
    ax4.set_title('MODERATE - Weight Norm Changes (Stability)', fontsize=12, fontweight='bold')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    # =========================================================================
    # Row 3: Gradient Norms
    # =========================================================================
    ax5 = fig.add_subplot(gs[2, 0])
    ax6 = fig.add_subplot(gs[2, 1])

    # Extreme
    steps = range(len(results['extreme']['no_clip']['grad_norms']))
    ax5.plot(steps, results['extreme']['no_clip']['grad_norms'], 'r-',
             alpha=0.3, linewidth=0.5, label='Without Clip')
    ax5.plot(steps, results['extreme']['with_clip']['grad_norms'], 'g-',
             alpha=0.3, linewidth=0.5, label='With Clip')
    ax5.axhline(y=1.0, color='black', linestyle='--', linewidth=2, label='Clip threshold')
    ax5.set_ylabel('Gradient Norm', fontsize=11)
    ax5.set_title('EXTREME - Gradient Norms', fontsize=12, fontweight='bold')
    ax5.legend()
    ax5.grid(True, alpha=0.3)

    # Moderate
    steps = range(len(results['moderate']['no_clip']['grad_norms']))
    ax6.plot(steps, results['moderate']['no_clip']['grad_norms'], 'r-',
             alpha=0.3, linewidth=0.5, label='Without Clip')
    ax6.plot(steps, results['moderate']['with_clip']['grad_norms'], 'g-',
             alpha=0.3, linewidth=0.5, label='With Clip')
    ax6.axhline(y=1.0, color='black', linestyle='--', linewidth=2, label='Clip threshold')
    ax6.set_title('MODERATE - Gradient Norms', fontsize=12, fontweight='bold')
    ax6.legend()
    ax6.grid(True, alpha=0.3)

    # =========================================================================
    # Row 4: Effective Dimension
    # =========================================================================
    ax7 = fig.add_subplot(gs[3, 0])
    ax8 = fig.add_subplot(gs[3, 1])

    # Extreme
    ax7.plot(results['extreme']['no_clip']['effective_dim_steps'],
             results['extreme']['no_clip']['effective_dims'], 'r-o',
             alpha=0.7, linewidth=2, markersize=4, label='Without Clip')
    ax7.plot(results['extreme']['with_clip']['effective_dim_steps'],
             results['extreme']['with_clip']['effective_dims'], 'g-o',
             alpha=0.7, linewidth=2, markersize=4, label='With Clip')
    ax7.set_ylabel('Effective Dimension', fontsize=11)
    ax7.set_title('EXTREME - Effective Dimensionality', fontsize=12, fontweight='bold')
    ax7.legend()
    ax7.grid(True, alpha=0.3)

    # Moderate
    ax8.plot(results['moderate']['no_clip']['effective_dim_steps'],
             results['moderate']['no_clip']['effective_dims'], 'r-o',
             alpha=0.7, linewidth=2, markersize=4, label='Without Clip')
    ax8.plot(results['moderate']['with_clip']['effective_dim_steps'],
             results['moderate']['with_clip']['effective_dims'], 'g-o',
             alpha=0.7, linewidth=2, markersize=4, label='With Clip')
    ax8.set_title('MODERATE - Effective Dimensionality', fontsize=12, fontweight='bold')
    ax8.legend()
    ax8.grid(True, alpha=0.3)

    # =========================================================================
    # Row 5: Class Accuracies
    # =========================================================================
    ax9 = fig.add_subplot(gs[4, 0])
    ax10 = fig.add_subplot(gs[4, 1])

    # Extreme - Rare class B
    ax9.plot(results['extreme']['no_clip']['accuracy_steps'],
             results['extreme']['no_clip']['class_accuracies'][1], 'r-',
             alpha=0.7, linewidth=2, label='Without Clip')
    ax9.plot(results['extreme']['with_clip']['accuracy_steps'],
             results['extreme']['with_clip']['class_accuracies'][1], 'g-',
             alpha=0.7, linewidth=2, label='With Clip')
    ax9.set_ylabel('Rare Class B Accuracy', fontsize=11)
    ax9.set_xlabel('Training Step', fontsize=11)
    ax9.set_title('EXTREME - Rare Class Accuracy', fontsize=12, fontweight='bold')
    ax9.legend()
    ax9.grid(True, alpha=0.3)
    ax9.set_ylim([0, 1.05])

    # Moderate - Rare class B
    ax10.plot(results['moderate']['no_clip']['accuracy_steps'],
              results['moderate']['no_clip']['class_accuracies'][1], 'r-',
              alpha=0.7, linewidth=2, label='Without Clip')
    ax10.plot(results['moderate']['with_clip']['accuracy_steps'],
              results['moderate']['with_clip']['class_accuracies'][1], 'g-',
              alpha=0.7, linewidth=2, label='With Clip')
    ax10.set_xlabel('Training Step', fontsize=11)
    ax10.set_title('MODERATE - Rare Class Accuracy', fontsize=12, fontweight='bold')
    ax10.legend()
    ax10.grid(True, alpha=0.3)
    ax10.set_ylim([0, 1.05])

    fig.suptitle('Gradient Clipping Analysis: Physics-of-AI Predictions\n'
                 'Comparing Extreme (99:1) vs Moderate (80:20) Class Imbalance',
                 fontsize=14, fontweight='bold', y=1.01)

    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Final comparison plot saved to: {filename}")
def compute_statistics(results: Dict) -> Dict:
    """Compute summary statistics for all experiments."""
    stats = {}
    for imbalance in ['extreme', 'moderate']:
        no_clip = results[imbalance]['no_clip']
        with_clip = results[imbalance]['with_clip']
        stats[imbalance] = {
            'weight_norm_std': {
                'no_clip': np.std(no_clip['weight_norms']),
                'with_clip': np.std(with_clip['weight_norms']),
            },
            'weight_change_mean': {
                'no_clip': np.mean(no_clip['weight_norm_changes']),
                'with_clip': np.mean(with_clip['weight_norm_changes']),
            },
            'weight_change_max': {
                'no_clip': np.max(no_clip['weight_norm_changes']),
                'with_clip': np.max(with_clip['weight_norm_changes']),
            },
            'grad_norm_max': {
                'no_clip': np.max(no_clip['grad_norms']),
                'with_clip': np.max(with_clip['grad_norms']),
            },
            'effective_dim_std': {
                'no_clip': np.std(no_clip['effective_dims']),
                'with_clip': np.std(with_clip['effective_dims']),
            },
            'final_rare_acc': {
                'no_clip': no_clip['class_accuracies'][1][-1] if no_clip['class_accuracies'][1] else 0,
                'with_clip': with_clip['class_accuracies'][1][-1] if with_clip['class_accuracies'][1] else 0,
            },
        }
    return stats


def print_summary(stats: Dict):
    """Print formatted summary."""
    print("\n" + "="*70)
    print("EXPERIMENT SUMMARY")
    print("="*70)

    for imbalance in ['extreme', 'moderate']:
        s = stats[imbalance]
        label = "EXTREME (99:1)" if imbalance == 'extreme' else "MODERATE (80:20)"
        print(f"\n{label}")
        print("-" * 50)

        print(f"\n[PREDICTION 2] Representation Collapse (Effective Dim Variance):")
        print(f"  WITHOUT Clipping: {s['effective_dim_std']['no_clip']:.6f}")
        print(f"  WITH Clipping:    {s['effective_dim_std']['with_clip']:.6f}")
        supported = s['effective_dim_std']['no_clip'] > s['effective_dim_std']['with_clip']
        print(f"  Verdict: {'SUPPORTED' if supported else 'NOT SUPPORTED'}")

        print(f"\n[PREDICTION 4] Rare Sample Learning:")
        print(f"  Final Rare Accuracy (WITHOUT): {s['final_rare_acc']['no_clip']:.1%}")
        print(f"  Final Rare Accuracy (WITH):    {s['final_rare_acc']['with_clip']:.1%}")
        supported = s['final_rare_acc']['with_clip'] >= s['final_rare_acc']['no_clip']
        print(f"  Verdict: {'SUPPORTED' if supported else 'NOT SUPPORTED'}")

        print(f"\n[STABILITY] Weight Norm Analysis:")
        print(f"  Weight Norm Std (WITHOUT): {s['weight_norm_std']['no_clip']:.4f}")
        print(f"  Weight Norm Std (WITH):    {s['weight_norm_std']['with_clip']:.4f}")
        print(f"  Max Weight Change (WITHOUT): {s['weight_change_max']['no_clip']:.4f}")
        print(f"  Max Weight Change (WITH):    {s['weight_change_max']['with_clip']:.4f}")

        print(f"\n[GRADIENT] Analysis:")
        print(f"  Max Gradient Norm (WITHOUT): {s['grad_norm_max']['no_clip']:.4f}")
        print(f"  Max Gradient Norm (WITH):    {s['grad_norm_max']['with_clip']:.4f}")
        print(f"  Clipping Ratio: {s['grad_norm_max']['no_clip'] / 1.0:.1f}x threshold")


def main():
    # Run experiments
    results = run_experiment_suite()

    # Generate plots
    print("\n" + "="*70)
    print("GENERATING PLOTS")
    print("="*70)
    plot_final_comparison(results, "final_comparison.png")

    # Compute and print statistics
    stats = compute_statistics(results)
    print_summary(stats)

    return results, stats
== "__main__": results, stats = main() print("\n" + "="*70) print("EXPERIMENT COMPLETE!") print("="*70)