|
|
""" |
|
|
Gradient Clipping Experiment |
|
|
|
|
|
This script demonstrates how gradient clipping stabilizes training by preventing |
|
|
sudden large weight updates caused by rare, high-loss data points. |
|
|
|
|
|
Experiment Setup: |
|
|
- Simple model: Embedding(4, 16) -> Linear(16, 4) |
|
|
- Vocabulary: ['A', 'B', 'C', 'D'] |
|
|
- Dataset: 1000 samples with imbalanced targets (990 'A', 10 'B') |
|
|
- Compare training with and without gradient clipping |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import random |
|
|
|
|
|
|
|
|
SEED = 42  # single seed shared by all RNGs (torch, numpy, random) for reproducible runs
|
|
|
|
|
|
|
|
def set_seeds(seed=SEED):
    """Seed every RNG in use (torch, numpy, stdlib random) for reproducibility."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SimpleNextTokenModel(nn.Module):
    """
    Minimal next-token predictor: maps a token index to logits over the vocab.

    Architecture: Embedding -> Linear
    """

    def __init__(self, vocab_size=4, embedding_dim=16):
        super().__init__()
        # NOTE: layer creation order is kept (embedding first, then linear) so
        # seeded weight initialization stays reproducible across runs.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """
        Args:
            x: Token indices of shape (batch_size,)
        Returns:
            Logits of shape (batch_size, vocab_size)
        """
        return self.linear(self.embedding(x))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED):
    """
    Create a synthetic dataset with imbalanced targets.

    Inputs are uniform random token indices over the vocabulary
    ['A', 'B', 'C', 'D'] -> 0..3. Targets are overwhelmingly class 0 ('A')
    with `n_rare` randomly placed class 1 ('B') labels, which act as the
    rare high-loss samples in the experiment.

    Args:
        n_samples: Total number of samples
        n_rare: Number of rare 'B' samples (must be <= n_samples;
            random.sample raises ValueError otherwise)
        seed: Random seed for reproducibility

    Returns:
        inputs: Random token indices (0-3), shape (n_samples,)
        targets: LongTensor of labels, mostly 0 ('A') with n_rare 1s ('B')
        rare_indices: Sorted list of indices where target is 'B'
    """
    set_seeds(seed)

    inputs = torch.randint(0, 4, (n_samples,))

    # All targets default to 'A' (0) ...
    targets = torch.zeros(n_samples, dtype=torch.long)

    # ... then flip n_rare randomly chosen positions to 'B' (1).
    # random.sample draws without replacement, so exactly n_rare positions flip.
    rare_indices = random.sample(range(n_samples), n_rare)
    targets[rare_indices] = 1

    return inputs, targets, sorted(rare_indices)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_weight_norm(model):
    """Return the L2 norm over all parameters of *model* (flattened globally)."""
    squared_norms = [p.data.norm(2).item() ** 2 for p in model.parameters()]
    return sum(squared_norms) ** 0.5
|
|
|
|
|
|
|
|
def get_initial_weights(seed=SEED):
    """Build a freshly seeded model and return a cloned snapshot of its state dict.

    Used so that multiple training runs can start from byte-identical weights.
    """
    set_seeds(seed)
    reference = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    snapshot = {}
    for name, tensor in reference.state_dict().items():
        snapshot[name] = tensor.clone()
    return snapshot
|
|
|
|
|
|
|
|
def train_epoch(model, optimizer, criterion, inputs, targets, clip_grad=False, max_norm=1.0):
    """
    Train for one epoch (one SGD step per sample), recording metrics per step.

    Args:
        model: The neural network
        optimizer: SGD optimizer
        criterion: CrossEntropyLoss
        inputs: Input token indices
        targets: Target token indices
        clip_grad: Whether to apply gradient clipping
        max_norm: Maximum gradient norm (if clipping)

    Returns:
        losses: List of losses per step
        grad_norms: List of gradient norms per step (measured BEFORE clipping)
        weight_norms: List of weight norms per step
    """
    model.train()

    losses, grad_norms, weight_norms = [], [], []

    n_steps = len(inputs)
    for step in range(n_steps):
        # Slice (not index) to keep the batch dimension: shape (1,).
        x_batch = inputs[step:step + 1]
        y_batch = targets[step:step + 1]

        optimizer.zero_grad()
        loss = criterion(model(x_batch), y_batch)
        loss.backward()

        # clip_grad_norm_ with an infinite threshold is a no-op clip; we use it
        # purely to compute the total gradient norm before any real clipping.
        raw_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf'))

        if clip_grad:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        optimizer.step()

        losses.append(loss.item())
        grad_norms.append(raw_grad_norm.item())
        weight_norms.append(compute_weight_norm(model))

    return losses, grad_norms, weight_norms
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_training(inputs, targets, rare_indices, clip_grad=False, max_norm=1.0, n_epochs=3, lr=0.1, init_weights=None):
    """
    Run the complete multi-epoch training loop and collect per-step metrics.

    Args:
        inputs: Input token indices
        targets: Target token indices
        rare_indices: Indices of rare 'B' samples (unused here; kept for API symmetry)
        clip_grad: Whether to apply gradient clipping
        max_norm: Maximum gradient norm threshold
        n_epochs: Number of training epochs
        lr: Learning rate
        init_weights: Initial model weights for reproducibility

    Returns:
        all_losses, all_grad_norms, all_weight_norms: Metrics across all steps
    """
    set_seeds(SEED)
    model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    # Overwrite the fresh init with a shared snapshot so compared runs start identically.
    if init_weights:
        model.load_state_dict(init_weights)

    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    all_losses, all_grad_norms, all_weight_norms = [], [], []

    divider = '=' * 60
    mode = "WITH" if clip_grad else "WITHOUT"
    print(f"\n{divider}")
    print(f"Training {mode} gradient clipping (max_norm={max_norm})")
    print(divider)

    for epoch in range(n_epochs):
        losses, grad_norms, weight_norms = train_epoch(
            model, optimizer, criterion, inputs, targets,
            clip_grad=clip_grad, max_norm=max_norm
        )

        all_losses.extend(losses)
        all_grad_norms.extend(grad_norms)
        all_weight_norms.extend(weight_norms)

        avg_loss = np.mean(losses)
        max_grad = np.max(grad_norms)
        print(f"Epoch {epoch+1}/{n_epochs}: Avg Loss={avg_loss:.4f}, Max Grad Norm={max_grad:.4f}")

    return all_losses, all_grad_norms, all_weight_norms
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_metrics(losses, grad_norms, weight_norms, title, filename, rare_indices=None, n_samples=1000):
    """
    Plot training metrics: loss, gradient norm, and weight norm.

    Saves a 3-panel figure (loss / gradient norm / weight norm vs. training
    step) to `filename`. Steps where a rare 'B' sample was processed are
    highlighted with red vertical lines.

    Args:
        losses: List of losses per step
        grad_norms: List of gradient norms per step
        weight_norms: List of weight norms per step
        title: Plot title
        filename: Output filename
        rare_indices: Indices of rare 'B' samples (for highlighting)
        n_samples: Number of samples per epoch
    """
    fig, axes = plt.subplots(3, 1, figsize=(12, 10), sharex=True)

    steps = range(len(losses))
    n_epochs = len(losses) // n_samples

    # Panel 1: training loss.
    axes[0].plot(steps, losses, 'b-', alpha=0.7, linewidth=0.5)
    axes[0].set_ylabel('Training Loss', fontsize=12)
    axes[0].set_title(title, fontsize=14, fontweight='bold')
    axes[0].grid(True, alpha=0.3)

    # Mark every step at which a rare sample occurred (repeats each epoch).
    if rare_indices:
        for epoch in range(n_epochs):
            for idx in rare_indices:
                step = epoch * n_samples + idx
                if step < len(losses):
                    axes[0].axvline(x=step, color='red', alpha=0.3, linewidth=0.5)

    # Panel 2: gradient L2 norm (recorded before clipping).
    axes[1].plot(steps, grad_norms, 'g-', alpha=0.7, linewidth=0.5)
    axes[1].set_ylabel('Gradient L2 Norm', fontsize=12)
    axes[1].grid(True, alpha=0.3)

    # Show the clip threshold only on the clipped run's plot (detected via title).
    if "With" in title or "WITH" in title:
        axes[1].axhline(y=1.0, color='red', linestyle='--', label='Clip threshold (1.0)')
        axes[1].legend()

    if rare_indices:
        for epoch in range(n_epochs):
            for idx in rare_indices:
                step = epoch * n_samples + idx
                if step < len(grad_norms):
                    axes[1].axvline(x=step, color='red', alpha=0.3, linewidth=0.5)

    # Panel 3: weight L2 norm.
    axes[2].plot(steps, weight_norms, 'm-', alpha=0.7, linewidth=0.5)
    axes[2].set_ylabel('Weight L2 Norm', fontsize=12)
    axes[2].set_xlabel('Training Step', fontsize=12)
    axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close()

    # BUG FIX: report the actual output path; the message previously printed a
    # hard-coded "(unknown)" placeholder instead of the filename argument.
    print(f"Plot saved to: {filename}")
|
|
|
|
|
|
|
|
def plot_comparison(metrics_no_clip, metrics_with_clip, rare_indices, filename, n_samples=1000):
    """
    Create a side-by-side comparison plot and save it to `filename`.

    Left column: training without clipping; right column: with clipping.
    Rows are loss, gradient norm, and weight norm per training step. Red
    vertical lines mark steps where a rare 'B' sample was processed.

    Args:
        metrics_no_clip: (losses, grad_norms, weight_norms) without clipping
        metrics_with_clip: (losses, grad_norms, weight_norms) with clipping
        rare_indices: Indices of rare 'B' samples
        filename: Output filename
        n_samples: Number of samples per epoch
    """
    fig, axes = plt.subplots(3, 2, figsize=(16, 12))

    losses_no, grads_no, weights_no = metrics_no_clip
    losses_with, grads_with, weights_with = metrics_with_clip

    steps = range(len(losses_no))
    n_epochs = len(losses_no) // n_samples

    # Left column: WITHOUT clipping.
    axes[0, 0].plot(steps, losses_no, 'b-', alpha=0.7, linewidth=0.5)
    axes[0, 0].set_ylabel('Training Loss', fontsize=11)
    axes[0, 0].set_title('WITHOUT Gradient Clipping', fontsize=13, fontweight='bold', color='red')
    axes[0, 0].grid(True, alpha=0.3)

    axes[1, 0].plot(steps, grads_no, 'g-', alpha=0.7, linewidth=0.5)
    axes[1, 0].set_ylabel('Gradient L2 Norm', fontsize=11)
    axes[1, 0].grid(True, alpha=0.3)

    axes[2, 0].plot(steps, weights_no, 'm-', alpha=0.7, linewidth=0.5)
    axes[2, 0].set_ylabel('Weight L2 Norm', fontsize=11)
    axes[2, 0].set_xlabel('Training Step', fontsize=11)
    axes[2, 0].grid(True, alpha=0.3)

    # Right column: WITH clipping.
    axes[0, 1].plot(steps, losses_with, 'b-', alpha=0.7, linewidth=0.5)
    axes[0, 1].set_title('WITH Gradient Clipping (max_norm=1.0)', fontsize=13, fontweight='bold', color='green')
    axes[0, 1].grid(True, alpha=0.3)

    axes[1, 1].plot(steps, grads_with, 'g-', alpha=0.7, linewidth=0.5)
    axes[1, 1].axhline(y=1.0, color='red', linestyle='--', linewidth=2, label='Clip threshold')
    axes[1, 1].legend(loc='upper right')
    axes[1, 1].grid(True, alpha=0.3)

    axes[2, 1].plot(steps, weights_with, 'm-', alpha=0.7, linewidth=0.5)
    axes[2, 1].set_xlabel('Training Step', fontsize=11)
    axes[2, 1].grid(True, alpha=0.3)

    # Overlay rare-sample markers on every panel (repeats each epoch).
    for col in range(2):
        for row in range(3):
            for epoch in range(n_epochs):
                for idx in rare_indices:
                    step = epoch * n_samples + idx
                    if step < len(losses_no):
                        axes[row, col].axvline(x=step, color='red', alpha=0.2, linewidth=0.5)

    # Off-screen line (x=-100) exists only to create the legend entry.
    axes[0, 0].axvline(x=-100, color='red', alpha=0.5, linewidth=2, label="Rare 'B' samples")
    axes[0, 0].legend(loc='upper right')

    fig.suptitle('Effect of Gradient Clipping on Training Stability\n(Red lines indicate rare "B" samples)',
                 fontsize=14, fontweight='bold', y=1.02)

    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close()

    # BUG FIX: report the actual output path; the message previously printed a
    # hard-coded "(unknown)" placeholder instead of the filename argument.
    print(f"Comparison plot saved to: {filename}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Run the gradient-clipping experiment end to end.

    Builds the imbalanced dataset, trains identical models with and without
    gradient clipping, saves diagnostic plots (no_clipping.png,
    with_clipping.png, comparison.png), prints summary statistics, and
    returns them as a dict.

    Returns:
        dict with 'no_clip' and 'with_clip' summary metrics plus 'rare_indices'.
    """
    print("="*60)
    print("GRADIENT CLIPPING EXPERIMENT")
    print("="*60)
    print("\nThis experiment demonstrates how gradient clipping stabilizes")
    print("training by preventing sudden large weight updates caused by")
    print("rare, high-loss data points.\n")

    # 990 'A' targets vs 10 rare 'B' targets: the rare samples produce the
    # occasional high-loss gradient spikes this experiment studies.
    inputs, targets, rare_indices = create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED)

    print(f"Dataset created:")
    print(f" Total samples: {len(inputs)}")
    print(f" Target 'A' (0): {(targets == 0).sum().item()}")
    print(f" Target 'B' (1): {(targets == 1).sum().item()}")
    print(f" Rare 'B' indices: {rare_indices}")

    # Snapshot one set of initial weights so both runs start identically.
    init_weights = get_initial_weights(seed=SEED)

    # Baseline run: no clipping.
    losses_no_clip, grads_no_clip, weights_no_clip = run_training(
        inputs, targets, rare_indices,
        clip_grad=False, n_epochs=3, lr=0.1, init_weights=init_weights
    )

    # Same data and same initial weights, but gradients clipped at max_norm=1.0.
    losses_with_clip, grads_with_clip, weights_with_clip = run_training(
        inputs, targets, rare_indices,
        clip_grad=True, max_norm=1.0, n_epochs=3, lr=0.1, init_weights=init_weights
    )

    print("\n" + "="*60)
    print("GENERATING PLOTS")
    print("="*60)

    plot_metrics(
        losses_no_clip, grads_no_clip, weights_no_clip,
        "Training WITHOUT Gradient Clipping",
        "no_clipping.png",
        rare_indices
    )

    plot_metrics(
        losses_with_clip, grads_with_clip, weights_with_clip,
        "Training WITH Gradient Clipping (max_norm=1.0)",
        "with_clipping.png",
        rare_indices
    )

    plot_comparison(
        (losses_no_clip, grads_no_clip, weights_no_clip),
        (losses_with_clip, grads_with_clip, weights_with_clip),
        rare_indices,
        "comparison.png"
    )

    print("\n" + "="*60)
    print("SUMMARY STATISTICS")
    print("="*60)

    print("\nWithout Gradient Clipping:")
    print(f" Max Gradient Norm: {max(grads_no_clip):.4f}")
    print(f" Mean Gradient Norm: {np.mean(grads_no_clip):.4f}")
    print(f" Std Gradient Norm: {np.std(grads_no_clip):.4f}")
    print(f" Final Weight Norm: {weights_no_clip[-1]:.4f}")
    print(f" Final Loss: {losses_no_clip[-1]:.4f}")

    print("\nWith Gradient Clipping (max_norm=1.0):")
    print(f" Max Gradient Norm: {max(grads_with_clip):.4f}")
    print(f" Mean Gradient Norm: {np.mean(grads_with_clip):.4f}")
    print(f" Std Gradient Norm: {np.std(grads_with_clip):.4f}")
    print(f" Final Weight Norm: {weights_with_clip[-1]:.4f}")
    print(f" Final Loss: {losses_with_clip[-1]:.4f}")

    # Return the same statistics programmatically for downstream use.
    return {
        'no_clip': {
            'max_grad': max(grads_no_clip),
            'mean_grad': np.mean(grads_no_clip),
            'std_grad': np.std(grads_no_clip),
            'final_weight': weights_no_clip[-1],
            'final_loss': losses_no_clip[-1]
        },
        'with_clip': {
            'max_grad': max(grads_with_clip),
            'mean_grad': np.mean(grads_with_clip),
            'std_grad': np.std(grads_with_clip),
            'final_weight': weights_with_clip[-1],
            'final_loss': losses_with_clip[-1]
        },
        'rare_indices': rare_indices
    }
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the experiment and announce completion.
    stats = main()
    print("\n" + "="*60)
    print("EXPERIMENT COMPLETE!")
    print("="*60)
|
|
|