""" Gradient Clipping Experiment This script demonstrates how gradient clipping stabilizes training by preventing sudden large weight updates caused by rare, high-loss data points. Experiment Setup: - Simple model: Embedding(4, 16) -> Linear(16, 4) - Vocabulary: ['A', 'B', 'C', 'D'] - Dataset: 1000 samples with imbalanced targets (990 'A', 10 'B') - Compare training with and without gradient clipping """ import torch import torch.nn as nn import torch.optim as optim import numpy as np import matplotlib.pyplot as plt import random # Set seeds for reproducibility SEED = 42 def set_seeds(seed=SEED): """Set all random seeds for reproducibility.""" torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) # ============================================================================= # 1. MODEL DEFINITION # ============================================================================= class SimpleNextTokenModel(nn.Module): """ Simple model that takes a token index and predicts the next token. Architecture: Embedding -> Linear """ def __init__(self, vocab_size=4, embedding_dim=16): super().__init__() self.embedding = nn.Embedding(vocab_size, embedding_dim) self.linear = nn.Linear(embedding_dim, vocab_size) def forward(self, x): """ Args: x: Token indices of shape (batch_size,) Returns: Logits of shape (batch_size, vocab_size) """ embedded = self.embedding(x) # (batch_size, embedding_dim) logits = self.linear(embedded) # (batch_size, vocab_size) return logits # ============================================================================= # 2. DATASET CREATION # ============================================================================= def create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED): """ Create a synthetic dataset with imbalanced targets. Args: n_samples: Total number of samples n_rare: Number of rare 'B' samples seed: Random seed for reproducibility Returns: inputs: Random token indices (0-3) targets: 990 'A' (0) and 10 'B' (1) rare_indices: Indices where target is 'B' """ # Set seed for reproducibility set_seeds(seed) vocab = {'A': 0, 'B': 1, 'C': 2, 'D': 3} # Random input tokens inputs = torch.randint(0, 4, (n_samples,)) # Create imbalanced targets: mostly 'A' (0), few 'B' (1) targets = torch.zeros(n_samples, dtype=torch.long) # All 'A' initially # Randomly select indices for rare 'B' samples rare_indices = random.sample(range(n_samples), n_rare) targets[rare_indices] = 1 # Set to 'B' return inputs, targets, sorted(rare_indices) # ============================================================================= # 3. UTILITY FUNCTIONS # ============================================================================= def compute_weight_norm(model): """Compute L2 norm of all model weights.""" total_norm = 0.0 for param in model.parameters(): total_norm += param.data.norm(2).item() ** 2 return total_norm ** 0.5 def get_initial_weights(seed=SEED): """Get initial weights for reproducible model initialization.""" set_seeds(seed) model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16) return {name: param.clone() for name, param in model.state_dict().items()} def train_epoch(model, optimizer, criterion, inputs, targets, clip_grad=False, max_norm=1.0): """ Train for one epoch, recording metrics at each step. Args: model: The neural network optimizer: SGD optimizer criterion: CrossEntropyLoss inputs: Input token indices targets: Target token indices clip_grad: Whether to apply gradient clipping max_norm: Maximum gradient norm (if clipping) Returns: losses: List of losses per step grad_norms: List of gradient norms per step (before clipping) weight_norms: List of weight norms per step """ model.train() losses = [] grad_norms = [] weight_norms = [] # Train on each sample individually to see the effect of rare samples for i in range(len(inputs)): x = inputs[i:i+1] # Single sample y = targets[i:i+1] optimizer.zero_grad() # Forward pass logits = model(x) loss = criterion(logits, y) # Backward pass loss.backward() # Compute gradient norm BEFORE clipping # Use a large value to just compute the norm without clipping grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf')) # Apply gradient clipping if requested if clip_grad: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) # Update weights optimizer.step() # Record metrics losses.append(loss.item()) grad_norms.append(grad_norm.item()) weight_norms.append(compute_weight_norm(model)) return losses, grad_norms, weight_norms # ============================================================================= # 4. TRAINING FUNCTIONS # ============================================================================= def run_training(inputs, targets, rare_indices, clip_grad=False, max_norm=1.0, n_epochs=3, lr=0.1, init_weights=None): """ Run complete training loop. Args: inputs: Input token indices targets: Target token indices rare_indices: Indices of rare 'B' samples clip_grad: Whether to apply gradient clipping max_norm: Maximum gradient norm threshold n_epochs: Number of training epochs lr: Learning rate init_weights: Initial model weights for reproducibility Returns: all_losses, all_grad_norms, all_weight_norms: Metrics across all steps """ # Create fresh model with same initial weights set_seeds(SEED) model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16) if init_weights: model.load_state_dict(init_weights) optimizer = optim.SGD(model.parameters(), lr=lr) criterion = nn.CrossEntropyLoss() all_losses = [] all_grad_norms = [] all_weight_norms = [] mode = "WITH" if clip_grad else "WITHOUT" print(f"\n{'='*60}") print(f"Training {mode} gradient clipping (max_norm={max_norm})") print(f"{'='*60}") for epoch in range(n_epochs): losses, grad_norms, weight_norms = train_epoch( model, optimizer, criterion, inputs, targets, clip_grad=clip_grad, max_norm=max_norm ) all_losses.extend(losses) all_grad_norms.extend(grad_norms) all_weight_norms.extend(weight_norms) avg_loss = np.mean(losses) max_grad = np.max(grad_norms) print(f"Epoch {epoch+1}/{n_epochs}: Avg Loss={avg_loss:.4f}, Max Grad Norm={max_grad:.4f}") return all_losses, all_grad_norms, all_weight_norms # ============================================================================= # 5. PLOTTING FUNCTIONS # ============================================================================= def plot_metrics(losses, grad_norms, weight_norms, title, filename, rare_indices=None, n_samples=1000): """ Plot training metrics: loss, gradient norm, and weight norm. Args: losses: List of losses per step grad_norms: List of gradient norms per step weight_norms: List of weight norms per step title: Plot title filename: Output filename rare_indices: Indices of rare 'B' samples (for highlighting) n_samples: Number of samples per epoch """ fig, axes = plt.subplots(3, 1, figsize=(12, 10), sharex=True) steps = range(len(losses)) n_epochs = len(losses) // n_samples # Plot 1: Training Loss axes[0].plot(steps, losses, 'b-', alpha=0.7, linewidth=0.5) axes[0].set_ylabel('Training Loss', fontsize=12) axes[0].set_title(title, fontsize=14, fontweight='bold') axes[0].grid(True, alpha=0.3) # Highlight rare sample positions if rare_indices: for epoch in range(n_epochs): for idx in rare_indices: step = epoch * n_samples + idx if step < len(losses): axes[0].axvline(x=step, color='red', alpha=0.3, linewidth=0.5) # Plot 2: Gradient Norm axes[1].plot(steps, grad_norms, 'g-', alpha=0.7, linewidth=0.5) axes[1].set_ylabel('Gradient L2 Norm', fontsize=12) axes[1].grid(True, alpha=0.3) # Add horizontal line at clipping threshold if "With" in title or "WITH" in title: axes[1].axhline(y=1.0, color='red', linestyle='--', label='Clip threshold (1.0)') axes[1].legend() if rare_indices: for epoch in range(n_epochs): for idx in rare_indices: step = epoch * n_samples + idx if step < len(grad_norms): axes[1].axvline(x=step, color='red', alpha=0.3, linewidth=0.5) # Plot 3: Weight Norm axes[2].plot(steps, weight_norms, 'm-', alpha=0.7, linewidth=0.5) axes[2].set_ylabel('Weight L2 Norm', fontsize=12) axes[2].set_xlabel('Training Step', fontsize=12) axes[2].grid(True, alpha=0.3) plt.tight_layout() plt.savefig(filename, dpi=150, bbox_inches='tight') plt.close() print(f"Plot saved to: {filename}") def plot_comparison(metrics_no_clip, metrics_with_clip, rare_indices, filename, n_samples=1000): """ Create side-by-side comparison plot. Args: metrics_no_clip: (losses, grad_norms, weight_norms) without clipping metrics_with_clip: (losses, grad_norms, weight_norms) with clipping rare_indices: Indices of rare 'B' samples filename: Output filename n_samples: Number of samples per epoch """ fig, axes = plt.subplots(3, 2, figsize=(16, 12)) losses_no, grads_no, weights_no = metrics_no_clip losses_with, grads_with, weights_with = metrics_with_clip steps = range(len(losses_no)) n_epochs = len(losses_no) // n_samples # Column 1: Without Clipping axes[0, 0].plot(steps, losses_no, 'b-', alpha=0.7, linewidth=0.5) axes[0, 0].set_ylabel('Training Loss', fontsize=11) axes[0, 0].set_title('WITHOUT Gradient Clipping', fontsize=13, fontweight='bold', color='red') axes[0, 0].grid(True, alpha=0.3) axes[1, 0].plot(steps, grads_no, 'g-', alpha=0.7, linewidth=0.5) axes[1, 0].set_ylabel('Gradient L2 Norm', fontsize=11) axes[1, 0].grid(True, alpha=0.3) axes[2, 0].plot(steps, weights_no, 'm-', alpha=0.7, linewidth=0.5) axes[2, 0].set_ylabel('Weight L2 Norm', fontsize=11) axes[2, 0].set_xlabel('Training Step', fontsize=11) axes[2, 0].grid(True, alpha=0.3) # Column 2: With Clipping axes[0, 1].plot(steps, losses_with, 'b-', alpha=0.7, linewidth=0.5) axes[0, 1].set_title('WITH Gradient Clipping (max_norm=1.0)', fontsize=13, fontweight='bold', color='green') axes[0, 1].grid(True, alpha=0.3) axes[1, 1].plot(steps, grads_with, 'g-', alpha=0.7, linewidth=0.5) axes[1, 1].axhline(y=1.0, color='red', linestyle='--', linewidth=2, label='Clip threshold') axes[1, 1].legend(loc='upper right') axes[1, 1].grid(True, alpha=0.3) axes[2, 1].plot(steps, weights_with, 'm-', alpha=0.7, linewidth=0.5) axes[2, 1].set_xlabel('Training Step', fontsize=11) axes[2, 1].grid(True, alpha=0.3) # Highlight rare sample positions in all plots for col in range(2): for row in range(3): for epoch in range(n_epochs): for idx in rare_indices: step = epoch * n_samples + idx if step < len(losses_no): axes[row, col].axvline(x=step, color='red', alpha=0.2, linewidth=0.5) # Add legend for rare samples axes[0, 0].axvline(x=-100, color='red', alpha=0.5, linewidth=2, label="Rare 'B' samples") axes[0, 0].legend(loc='upper right') # Add overall title fig.suptitle('Effect of Gradient Clipping on Training Stability\n(Red lines indicate rare "B" samples)', fontsize=14, fontweight='bold', y=1.02) plt.tight_layout() plt.savefig(filename, dpi=150, bbox_inches='tight') plt.close() print(f"Comparison plot saved to: {filename}") # ============================================================================= # 6. MAIN EXECUTION # ============================================================================= def main(): print("="*60) print("GRADIENT CLIPPING EXPERIMENT") print("="*60) print("\nThis experiment demonstrates how gradient clipping stabilizes") print("training by preventing sudden large weight updates caused by") print("rare, high-loss data points.\n") # Create dataset ONCE (used for both runs) inputs, targets, rare_indices = create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED) print(f"Dataset created:") print(f" Total samples: {len(inputs)}") print(f" Target 'A' (0): {(targets == 0).sum().item()}") print(f" Target 'B' (1): {(targets == 1).sum().item()}") print(f" Rare 'B' indices: {rare_indices}") # Get initial weights (same for both runs) init_weights = get_initial_weights(seed=SEED) # Run training WITHOUT gradient clipping losses_no_clip, grads_no_clip, weights_no_clip = run_training( inputs, targets, rare_indices, clip_grad=False, n_epochs=3, lr=0.1, init_weights=init_weights ) # Run training WITH gradient clipping losses_with_clip, grads_with_clip, weights_with_clip = run_training( inputs, targets, rare_indices, clip_grad=True, max_norm=1.0, n_epochs=3, lr=0.1, init_weights=init_weights ) # Generate individual plots print("\n" + "="*60) print("GENERATING PLOTS") print("="*60) plot_metrics( losses_no_clip, grads_no_clip, weights_no_clip, "Training WITHOUT Gradient Clipping", "no_clipping.png", rare_indices ) plot_metrics( losses_with_clip, grads_with_clip, weights_with_clip, "Training WITH Gradient Clipping (max_norm=1.0)", "with_clipping.png", rare_indices ) # Generate comparison plot plot_comparison( (losses_no_clip, grads_no_clip, weights_no_clip), (losses_with_clip, grads_with_clip, weights_with_clip), rare_indices, "comparison.png" ) # Print summary statistics print("\n" + "="*60) print("SUMMARY STATISTICS") print("="*60) print("\nWithout Gradient Clipping:") print(f" Max Gradient Norm: {max(grads_no_clip):.4f}") print(f" Mean Gradient Norm: {np.mean(grads_no_clip):.4f}") print(f" Std Gradient Norm: {np.std(grads_no_clip):.4f}") print(f" Final Weight Norm: {weights_no_clip[-1]:.4f}") print(f" Final Loss: {losses_no_clip[-1]:.4f}") print("\nWith Gradient Clipping (max_norm=1.0):") print(f" Max Gradient Norm: {max(grads_with_clip):.4f}") print(f" Mean Gradient Norm: {np.mean(grads_with_clip):.4f}") print(f" Std Gradient Norm: {np.std(grads_with_clip):.4f}") print(f" Final Weight Norm: {weights_with_clip[-1]:.4f}") print(f" Final Loss: {losses_with_clip[-1]:.4f}") # Return statistics for report return { 'no_clip': { 'max_grad': max(grads_no_clip), 'mean_grad': np.mean(grads_no_clip), 'std_grad': np.std(grads_no_clip), 'final_weight': weights_no_clip[-1], 'final_loss': losses_no_clip[-1] }, 'with_clip': { 'max_grad': max(grads_with_clip), 'mean_grad': np.mean(grads_with_clip), 'std_grad': np.std(grads_with_clip), 'final_weight': weights_with_clip[-1], 'final_loss': losses_with_clip[-1] }, 'rare_indices': rare_indices } if __name__ == "__main__": stats = main() print("\n" + "="*60) print("EXPERIMENT COMPLETE!") print("="*60)