# Hugging Face Hub page header captured with the file (not part of the script):
# AmberLJC's picture
# Upload experiment.py with huggingface_hub
# 86f312f verified
"""
Gradient Clipping Experiment
This script demonstrates how gradient clipping stabilizes training by preventing
sudden large weight updates caused by rare, high-loss data points.
Experiment Setup:
- Simple model: Embedding(4, 16) -> Linear(16, 4)
- Vocabulary: ['A', 'B', 'C', 'D']
- Dataset: 1000 samples with imbalanced targets (990 'A', 10 'B')
- Compare training with and without gradient clipping
"""
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import random
# Default seed shared by every seeding call in this script
SEED = 42


def set_seeds(seed=SEED):
    """Seed the torch, NumPy and stdlib ``random`` generators with ``seed``."""
    # Same order as always: torch first, then NumPy, then the stdlib generator
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
# =============================================================================
# 1. MODEL DEFINITION
# =============================================================================
class SimpleNextTokenModel(nn.Module):
    """
    Minimal next-token predictor: an embedding lookup followed by one
    linear projection back onto the vocabulary.
    """

    def __init__(self, vocab_size=4, embedding_dim=16):
        super().__init__()
        # Created in this order (embedding, then linear) so that seeded
        # initialization consumes the RNG the same way every time.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.linear = nn.Linear(in_features=embedding_dim, out_features=vocab_size)

    def forward(self, x):
        """
        Map token indices to next-token logits.

        Args:
            x: Token indices of shape (batch_size,)
        Returns:
            Logits of shape (batch_size, vocab_size)
        """
        # (batch_size,) -> (batch_size, embedding_dim) -> (batch_size, vocab_size)
        return self.linear(self.embedding(x))
# =============================================================================
# 2. DATASET CREATION
# =============================================================================
def create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED, vocab_size=4):
    """
    Create a synthetic dataset with imbalanced targets.

    Every target is 'A' (index 0) except for ``n_rare`` randomly chosen
    positions, which are set to 'B' (index 1).

    Args:
        n_samples: Total number of samples
        n_rare: Number of rare 'B' samples
        seed: Random seed for reproducibility
        vocab_size: Number of distinct input tokens; inputs are drawn
            uniformly from ``[0, vocab_size)``. Defaults to 4 ('A'-'D').
    Returns:
        inputs: Random token indices, shape (n_samples,)
        targets: Target indices, shape (n_samples,), dtype long
        rare_indices: Sorted list of positions whose target is 'B'
    Raises:
        ValueError: Raised by ``random.sample`` if ``n_rare`` is negative
            or larger than ``n_samples``.
    """
    common_target, rare_target = 0, 1  # 'A' and 'B'

    # Seed first so that both the inputs and the rare positions are reproducible
    set_seeds(seed)

    # Random input tokens
    inputs = torch.randint(0, vocab_size, (n_samples,))

    # Start with every target set to the common class
    targets = torch.full((n_samples,), common_target, dtype=torch.long)

    # Pick the positions of the rare samples; sorting does not change which
    # positions are selected, only the order in which they are reported
    rare_indices = sorted(random.sample(range(n_samples), n_rare))
    if rare_indices:
        targets[rare_indices] = rare_target

    return inputs, targets, rare_indices
# =============================================================================
# 3. UTILITY FUNCTIONS
# =============================================================================
def compute_weight_norm(model):
    """Return the global L2 norm over every parameter tensor of ``model``."""
    squared_sum = 0.0
    for tensor in model.parameters():
        # Per-tensor L2 norm, squared so the tensors combine as one flat vector
        tensor_norm = tensor.data.norm(2).item()
        squared_sum += tensor_norm ** 2
    return squared_sum ** 0.5
def get_initial_weights(seed=SEED):
    """
    Build a freshly seeded model and return a copy of its state dict.

    The returned tensors are clones, so they are independent of the
    temporary model they were taken from.
    """
    set_seeds(seed)
    reference_model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    snapshot = {}
    for key, tensor in reference_model.state_dict().items():
        snapshot[key] = tensor.clone()
    return snapshot
def train_epoch(model, optimizer, criterion, inputs, targets, clip_grad=False, max_norm=1.0):
    """
    Train for one epoch, recording metrics at each step.

    Samples are processed one at a time (batch size 1) so that the effect
    of each individual rare sample on the gradient is visible.

    Args:
        model: The neural network
        optimizer: Optimizer whose ``step()`` applies the update
        criterion: Loss function called as ``criterion(logits, target)``
        inputs: Input token indices
        targets: Target token indices
        clip_grad: Whether to apply gradient clipping
        max_norm: Maximum gradient norm (if clipping)
    Returns:
        losses: List of losses per step
        grad_norms: List of gradient norms per step (before clipping)
        weight_norms: List of weight norms per step (after the update)
    """
    model.train()
    losses = []
    grad_norms = []
    weight_norms = []

    # clip_grad_norm_ returns the total norm measured BEFORE scaling, so one
    # call both measures and (optionally) clips. With an infinite threshold
    # the scale factor is clamped to 1 and the gradients are left unchanged.
    clip_threshold = max_norm if clip_grad else float('inf')

    # The parameter set does not change during the epoch; materialize it once
    params = list(model.parameters())

    # Train on each sample individually to see the effect of rare samples
    for i in range(len(inputs)):
        # Slicing (rather than indexing) keeps the batch dimension of size 1
        x = inputs[i:i + 1]
        y = targets[i:i + 1]

        optimizer.zero_grad()

        # Forward pass
        logits = model(x)
        loss = criterion(logits, y)

        # Backward pass
        loss.backward()

        # Measure the pre-clip gradient norm and clip in the same call
        grad_norm = torch.nn.utils.clip_grad_norm_(params, clip_threshold)

        # Update weights
        optimizer.step()

        # Record metrics
        losses.append(loss.item())
        grad_norms.append(grad_norm.item())
        weight_norms.append(compute_weight_norm(model))

    return losses, grad_norms, weight_norms
# =============================================================================
# 4. TRAINING FUNCTIONS
# =============================================================================
def run_training(inputs, targets, rare_indices, clip_grad=False, max_norm=1.0, n_epochs=3, lr=0.1, init_weights=None):
    """
    Train a fresh model for ``n_epochs`` epochs and collect per-step metrics.

    Args:
        inputs: Input token indices
        targets: Target token indices
        rare_indices: Indices of rare 'B' samples (accepted but not used here)
        clip_grad: Whether to apply gradient clipping
        max_norm: Maximum gradient norm threshold
        n_epochs: Number of training epochs
        lr: Learning rate
        init_weights: Initial model weights for reproducibility
    Returns:
        all_losses, all_grad_norms, all_weight_norms: Metrics across all steps
    """
    # Build the model from a fixed seed, then overwrite with the shared weights
    set_seeds(SEED)
    model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    if init_weights:
        model.load_state_dict(init_weights)

    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # One accumulator per metric: losses, gradient norms, weight norms
    history = ([], [], [])

    banner = '=' * 60
    label = "WITHOUT" if not clip_grad else "WITH"
    print(f"\n{banner}")
    print(f"Training {label} gradient clipping (max_norm={max_norm})")
    print(banner)

    for epoch_number in range(1, n_epochs + 1):
        epoch_metrics = train_epoch(
            model, optimizer, criterion, inputs, targets,
            clip_grad=clip_grad, max_norm=max_norm
        )
        for accumulated, fresh in zip(history, epoch_metrics):
            accumulated.extend(fresh)

        epoch_losses, epoch_grads, _ = epoch_metrics
        print(f"Epoch {epoch_number}/{n_epochs}: Avg Loss={np.mean(epoch_losses):.4f}, Max Grad Norm={np.max(epoch_grads):.4f}")

    all_losses, all_grad_norms, all_weight_norms = history
    return all_losses, all_grad_norms, all_weight_norms
# =============================================================================
# 5. PLOTTING FUNCTIONS
# =============================================================================
def _mark_rare_steps(ax, rare_indices, n_samples, n_epochs, total_steps):
    """Draw a faint red vertical line on ``ax`` at each step that used a rare sample."""
    for epoch in range(n_epochs):
        for idx in rare_indices:
            step = epoch * n_samples + idx
            # Skip positions that fall beyond the recorded steps
            if step < total_steps:
                ax.axvline(x=step, color='red', alpha=0.3, linewidth=0.5)


def plot_metrics(losses, grad_norms, weight_norms, title, filename, rare_indices=None, n_samples=1000, clip_threshold=None):
    """
    Plot training metrics: loss, gradient norm, and weight norm.

    The three metrics are drawn as stacked panels sharing the x-axis and
    the figure is saved to ``filename``.

    Args:
        losses: List of losses per step
        grad_norms: List of gradient norms per step
        weight_norms: List of weight norms per step
        title: Plot title
        filename: Output filename
        rare_indices: Indices of rare 'B' samples (for highlighting)
        n_samples: Number of samples per epoch
        clip_threshold: Gradient-norm threshold to draw as a horizontal
            line on the gradient panel. When None, a threshold of 1.0 is
            drawn only if ``title`` contains the standalone word "With" or
            "WITH"; titles such as "... WITHOUT ..." draw no line.
    """
    import re  # local import: only needed for the title check below

    fig, axes = plt.subplots(3, 1, figsize=(12, 10), sharex=True)
    steps = range(len(losses))

    if clip_threshold is None and re.search(r"\b(?:With|WITH)\b", title):
        # Whole-word match: a plain substring test would also hit "WITHOUT"
        clip_threshold = 1.0

    # Plot 1: Training Loss
    axes[0].plot(steps, losses, 'b-', alpha=0.7, linewidth=0.5)
    axes[0].set_ylabel('Training Loss', fontsize=12)
    axes[0].set_title(title, fontsize=14, fontweight='bold')
    axes[0].grid(True, alpha=0.3)

    # Plot 2: Gradient Norm
    axes[1].plot(steps, grad_norms, 'g-', alpha=0.7, linewidth=0.5)
    axes[1].set_ylabel('Gradient L2 Norm', fontsize=12)
    axes[1].grid(True, alpha=0.3)
    if clip_threshold is not None:
        axes[1].axhline(y=clip_threshold, color='red', linestyle='--',
                        label=f'Clip threshold ({clip_threshold})')
        axes[1].legend()

    # Highlight rare sample positions on the loss and gradient panels
    if rare_indices:
        n_epochs = len(losses) // n_samples
        _mark_rare_steps(axes[0], rare_indices, n_samples, n_epochs, len(losses))
        _mark_rare_steps(axes[1], rare_indices, n_samples, n_epochs, len(grad_norms))

    # Plot 3: Weight Norm
    axes[2].plot(steps, weight_norms, 'm-', alpha=0.7, linewidth=0.5)
    axes[2].set_ylabel('Weight L2 Norm', fontsize=12)
    axes[2].set_xlabel('Training Step', fontsize=12)
    axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Plot saved to: {filename}")
def plot_comparison(metrics_no_clip, metrics_with_clip, rare_indices, filename, n_samples=1000, max_norm=1.0):
    """
    Create side-by-side comparison plot.

    The left column shows the run without clipping, the right column the
    run with clipping; rows are loss, gradient norm and weight norm.

    Args:
        metrics_no_clip: (losses, grad_norms, weight_norms) without clipping
        metrics_with_clip: (losses, grad_norms, weight_norms) with clipping
        rare_indices: Indices of rare 'B' samples
        filename: Output filename
        n_samples: Number of samples per epoch
        max_norm: Clipping threshold used for the clipped run; shown in the
            right-hand title and as a horizontal line on its gradient panel
    """
    fig, axes = plt.subplots(3, 2, figsize=(16, 12))
    losses_no, grads_no, weights_no = metrics_no_clip
    losses_with, grads_with, weights_with = metrics_with_clip

    # Each run gets its own x-range so differing run lengths still plot
    steps_no = range(len(losses_no))
    steps_with = range(len(losses_with))

    # Column 1: Without Clipping
    axes[0, 0].plot(steps_no, losses_no, 'b-', alpha=0.7, linewidth=0.5)
    axes[0, 0].set_ylabel('Training Loss', fontsize=11)
    axes[0, 0].set_title('WITHOUT Gradient Clipping', fontsize=13, fontweight='bold', color='red')
    axes[0, 0].grid(True, alpha=0.3)
    axes[1, 0].plot(steps_no, grads_no, 'g-', alpha=0.7, linewidth=0.5)
    axes[1, 0].set_ylabel('Gradient L2 Norm', fontsize=11)
    axes[1, 0].grid(True, alpha=0.3)
    axes[2, 0].plot(steps_no, weights_no, 'm-', alpha=0.7, linewidth=0.5)
    axes[2, 0].set_ylabel('Weight L2 Norm', fontsize=11)
    axes[2, 0].set_xlabel('Training Step', fontsize=11)
    axes[2, 0].grid(True, alpha=0.3)

    # Column 2: With Clipping
    axes[0, 1].plot(steps_with, losses_with, 'b-', alpha=0.7, linewidth=0.5)
    axes[0, 1].set_title(f'WITH Gradient Clipping (max_norm={max_norm})', fontsize=13, fontweight='bold', color='green')
    axes[0, 1].grid(True, alpha=0.3)
    axes[1, 1].plot(steps_with, grads_with, 'g-', alpha=0.7, linewidth=0.5)
    axes[1, 1].axhline(y=max_norm, color='red', linestyle='--', linewidth=2, label='Clip threshold')
    axes[1, 1].legend(loc='upper right')
    axes[1, 1].grid(True, alpha=0.3)
    axes[2, 1].plot(steps_with, weights_with, 'm-', alpha=0.7, linewidth=0.5)
    axes[2, 1].set_xlabel('Training Step', fontsize=11)
    axes[2, 1].grid(True, alpha=0.3)

    # Highlight rare sample positions in all plots
    for col, total_steps in enumerate((len(losses_no), len(losses_with))):
        n_epochs = total_steps // n_samples
        for row in range(3):
            for epoch in range(n_epochs):
                for idx in rare_indices:
                    step = epoch * n_samples + idx
                    if step < total_steps:
                        axes[row, col].axvline(x=step, color='red', alpha=0.2, linewidth=0.5)

    # Legend entry for the rare-sample markers. An empty line is used as the
    # handle: it carries the label without adding any data, so it cannot
    # stretch the x-axis the way a dummy marker placed off-range would.
    axes[0, 0].plot([], [], color='red', alpha=0.5, linewidth=2, label="Rare 'B' samples")
    axes[0, 0].legend(loc='upper right')

    # Add overall title
    fig.suptitle('Effect of Gradient Clipping on Training Stability\n(Red lines indicate rare "B" samples)',
                 fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Comparison plot saved to: {filename}")
# =============================================================================
# 6. MAIN EXECUTION
# =============================================================================
def main():
    """
    Run the full experiment: build the dataset, train with and without
    gradient clipping from identical initial weights, save the plots and
    print summary statistics.

    Returns:
        Dict with per-run statistics under 'no_clip' and 'with_clip', plus
        the rare sample positions under 'rare_indices'.
    """
    rule = "=" * 60

    def summarize(grads, weights, losses):
        # Statistics reported for one training run
        return {
            'max_grad': max(grads),
            'mean_grad': np.mean(grads),
            'std_grad': np.std(grads),
            'final_weight': weights[-1],
            'final_loss': losses[-1]
        }

    def report(heading, summary):
        print(heading)
        print(f" Max Gradient Norm: {summary['max_grad']:.4f}")
        print(f" Mean Gradient Norm: {summary['mean_grad']:.4f}")
        print(f" Std Gradient Norm: {summary['std_grad']:.4f}")
        print(f" Final Weight Norm: {summary['final_weight']:.4f}")
        print(f" Final Loss: {summary['final_loss']:.4f}")

    print(rule)
    print("GRADIENT CLIPPING EXPERIMENT")
    print(rule)
    print("\nThis experiment demonstrates how gradient clipping stabilizes")
    print("training by preventing sudden large weight updates caused by")
    print("rare, high-loss data points.\n")

    # The dataset is built once and shared by both runs
    inputs, targets, rare_indices = create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED)
    print("Dataset created:")
    print(f" Total samples: {len(inputs)}")
    for token, index in (('A', 0), ('B', 1)):
        print(f" Target '{token}' ({index}): {(targets == index).sum().item()}")
    print(f" Rare 'B' indices: {rare_indices}")

    # Both runs start from the same initial weights
    init_weights = get_initial_weights(seed=SEED)

    # Baseline run: no clipping
    run_no_clip = run_training(
        inputs, targets, rare_indices,
        clip_grad=False, n_epochs=3, lr=0.1, init_weights=init_weights
    )
    # Comparison run: clipping at max_norm=1.0
    run_with_clip = run_training(
        inputs, targets, rare_indices,
        clip_grad=True, max_norm=1.0, n_epochs=3, lr=0.1, init_weights=init_weights
    )
    losses_no_clip, grads_no_clip, weights_no_clip = run_no_clip
    losses_with_clip, grads_with_clip, weights_with_clip = run_with_clip

    # One figure per run, then the side-by-side comparison
    print("\n" + rule)
    print("GENERATING PLOTS")
    print(rule)
    plot_metrics(
        losses_no_clip, grads_no_clip, weights_no_clip,
        "Training WITHOUT Gradient Clipping",
        "no_clipping.png",
        rare_indices
    )
    plot_metrics(
        losses_with_clip, grads_with_clip, weights_with_clip,
        "Training WITH Gradient Clipping (max_norm=1.0)",
        "with_clipping.png",
        rare_indices
    )
    plot_comparison(
        (losses_no_clip, grads_no_clip, weights_no_clip),
        (losses_with_clip, grads_with_clip, weights_with_clip),
        rare_indices,
        "comparison.png"
    )

    # Compute the statistics once; they feed both the printout and the return value
    summary_no_clip = summarize(grads_no_clip, weights_no_clip, losses_no_clip)
    summary_with_clip = summarize(grads_with_clip, weights_with_clip, losses_with_clip)

    print("\n" + rule)
    print("SUMMARY STATISTICS")
    print(rule)
    report("\nWithout Gradient Clipping:", summary_no_clip)
    report("\nWith Gradient Clipping (max_norm=1.0):", summary_with_clip)

    return {
        'no_clip': summary_no_clip,
        'with_clip': summary_with_clip,
        'rare_indices': rare_indices
    }
if __name__ == "__main__":
    # main() returns the summary statistics; they are kept in ``stats``
    stats = main()
    # Closing banner: blank line, rule, message, rule
    print("\n" + "=" * 60, "EXPERIMENT COMPLETE!", "=" * 60, sep="\n")