# gradient_clipping_experiment / final_experiment.py
# Uploaded by AmberLJC via huggingface_hub (commit 6050e11, verified)
"""
Final Gradient Clipping Experiment: Testing Physics-of-AI Predictions
Key insights from previous experiments:
1. With extreme imbalance (99:1), neither model learns rare class
2. Gradient clipping's benefit is in STABILITY, not learning rare classes per se
3. The key effect is on WEIGHT NORM STABILITY and GRADIENT SPIKE HANDLING
This experiment tests:
1. Prediction 2: Representation Collapse - effective dim variance without clipping
2. Prediction 4: Rare Sample Learning - using moderate imbalance (80:20)
3. NEW: Weight norm stability analysis
4. NEW: Gradient spike analysis at rare sample positions
"""
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import random
from typing import Dict, List
SEED = 42


def set_seeds(seed: int = SEED) -> None:
    """Seed every RNG the experiment touches (torch, numpy, stdlib random)."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
class SimpleNextTokenModel(nn.Module):
    """Minimal next-token model: an embedding lookup followed by a linear
    projection back to vocabulary logits."""

    def __init__(self, vocab_size=4, embedding_dim=16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Map integer token ids to per-token vocabulary logits."""
        return self.linear(self.embedding(x))

    def get_embeddings(self):
        """Return a detached copy of the embedding weight matrix."""
        return self.embedding.weight.data.clone()
def compute_effective_dimension(embedding_matrix: torch.Tensor) -> float:
    """PCA-based effective dimensionality.

    Computes exp(H) where H is the Shannon entropy of the normalized
    eigenvalue spectrum of the sample covariance: isotropic spread over k
    directions gives ~k, a collapsed (rank-1) representation gives ~1.
    """
    mean = embedding_matrix.mean(dim=0, keepdim=True)
    deviations = embedding_matrix - mean
    rows = embedding_matrix.shape[0]
    covariance = deviations.T @ deviations / (rows - 1)
    # Clamp so zero eigenvalues don't produce log(0).
    spectrum = torch.linalg.eigvalsh(covariance).clamp(min=1e-10)
    spectrum = spectrum / spectrum.sum()
    shannon = -(spectrum * spectrum.log()).sum()
    return shannon.exp().item()
def compute_per_class_accuracy(model: nn.Module, inputs: torch.Tensor,
                               targets: torch.Tensor,
                               num_classes: int = 4) -> Dict[int, float]:
    """Compute accuracy separately for each target class.

    Args:
        model: module mapping ``inputs`` to logits of shape ``(n, num_classes)``.
        inputs: input tensor forwarded through the model.
        targets: integer class labels, shape ``(n,)``.
        num_classes: number of classes to report on (default 4 keeps the
            original hard-coded behavior; generalized for reuse).

    Returns:
        Dict mapping class index -> accuracy in [0, 1], or ``None`` for
        classes absent from ``targets``.
    """
    model.eval()
    with torch.no_grad():
        logits = model(inputs)
        predictions = logits.argmax(dim=1)
    accuracies = {}
    for class_idx in range(num_classes):
        mask = targets == class_idx
        if mask.sum() > 0:
            correct = (predictions[mask] == targets[mask]).float().mean().item()
            accuracies[class_idx] = correct
        else:
            # No samples of this class: accuracy is undefined, not 0.
            accuracies[class_idx] = None
    return accuracies
def create_dataset_moderate_imbalance(n_samples=1000, rare_ratio=0.2, seed=SEED):
    """Create dataset with moderate imbalance (e.g., 80:20).

    Returns (inputs, targets, sorted rare-sample indices); the rare class
    is label 1, everything else label 0.
    """
    set_seeds(seed)
    n_rare = int(n_samples * rare_ratio)
    inputs = torch.randint(0, 4, (n_samples,))
    targets = torch.zeros(n_samples, dtype=torch.long)
    rare_positions = random.sample(range(n_samples), n_rare)
    targets[rare_positions] = 1
    return inputs, targets, sorted(rare_positions)
def create_dataset_extreme_imbalance(n_samples=1000, n_rare=10, seed=SEED):
    """Create dataset with extreme imbalance (99:1).

    Same layout as the moderate variant, but the rare count is given
    directly instead of as a ratio.
    """
    set_seeds(seed)
    inputs = torch.randint(0, 4, (n_samples,))
    targets = torch.zeros(n_samples, dtype=torch.long)
    rare_positions = random.sample(range(n_samples), n_rare)
    targets[rare_positions] = 1
    return inputs, targets, sorted(rare_positions)
def train_with_tracking(inputs: torch.Tensor, targets: torch.Tensor,
                        rare_indices: List[int], clip_grad: bool = False,
                        max_norm: float = 1.0, n_epochs: int = 10,
                        lr: float = 0.1, init_weights=None,
                        track_every: int = 50) -> Dict:
    """Train a fresh SimpleNextTokenModel sample-by-sample with SGD,
    recording loss, gradient norms, weight norms, effective embedding
    dimension, and per-class accuracy.

    Args:
        inputs/targets: full dataset; samples are visited in order,
            batch size 1, for ``n_epochs`` passes.
        rare_indices: positions of rare-class samples.
            NOTE(review): not referenced in this body — presumably kept
            for API symmetry with the dataset builders; confirm before
            removing.
        clip_grad / max_norm: if True, clip total gradient norm to
            ``max_norm`` after the unclipped norm has been measured.
        init_weights: optional state_dict used so clip/no-clip runs start
            from identical parameters.
        track_every: step interval for the expensive metrics (effective
            dim, per-class accuracy).

    Returns:
        Dict of metric lists; per-step lists have n_epochs * len(inputs)
        entries, periodic lists align with their ``*_steps`` entries.
    """
    set_seeds(SEED)
    model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    if init_weights:
        # Clone so repeated runs can't mutate the shared initial weights.
        model.load_state_dict({k: v.clone() for k, v in init_weights.items()})
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    metrics = {
        'losses': [],
        'grad_norms': [],
        'weight_norms': [],
        'effective_dims': [],
        'effective_dim_steps': [],
        'class_accuracies': {0: [], 1: [], 2: [], 3: []},
        'accuracy_steps': [],
        'weight_norm_changes': [],  # |Δ weight norm| per step; stability signal
    }
    step = 0
    n_samples = len(inputs)
    prev_weight_norm = None
    for epoch in range(n_epochs):
        model.train()
        for i in range(n_samples):
            # Batch size 1: each sample produces its own update.
            x = inputs[i:i+1]
            y = targets[i:i+1]
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            # max_norm=inf never rescales; this call is used purely to get
            # the total (pre-clip) gradient norm.
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf'))
            if clip_grad:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            metrics['losses'].append(loss.item())
            metrics['grad_norms'].append(grad_norm.item())
            # Global L2 norm over all parameters (sqrt of sum of squared norms).
            current_weight_norm = sum(p.data.norm(2).item() ** 2 for p in model.parameters()) ** 0.5
            metrics['weight_norms'].append(current_weight_norm)
            # Track weight norm change (first entry is 0 by construction).
            if prev_weight_norm is not None:
                metrics['weight_norm_changes'].append(abs(current_weight_norm - prev_weight_norm))
            else:
                metrics['weight_norm_changes'].append(0)
            prev_weight_norm = current_weight_norm
            # Periodic (expensive) metrics: effective dim + per-class accuracy.
            if step % track_every == 0:
                emb_matrix = model.get_embeddings()
                eff_dim = compute_effective_dimension(emb_matrix)
                metrics['effective_dims'].append(eff_dim)
                metrics['effective_dim_steps'].append(step)
                class_acc = compute_per_class_accuracy(model, inputs, targets)
                for cls_idx in range(4):
                    if class_acc[cls_idx] is not None:
                        metrics['class_accuracies'][cls_idx].append(class_acc[cls_idx])
                    else:
                        # Absent classes recorded as 0.0 to keep series aligned.
                        metrics['class_accuracies'][cls_idx].append(0.0)
                metrics['accuracy_steps'].append(step)
            step += 1
    return metrics
def run_experiment_suite():
    """Run the extreme (99:1) and moderate (80:20) experiments, each with
    and without gradient clipping, and return all collected metrics."""
    def banner(title):
        # Section header used for each experiment.
        print("\n" + "=" * 70)
        print(title)
        print("=" * 70)

    print("=" * 70)
    print("FINAL GRADIENT CLIPPING EXPERIMENT")
    print("Testing Physics-of-AI Predictions")
    print("=" * 70)

    # Snapshot one set of initial weights so clip/no-clip runs start identical.
    set_seeds(SEED)
    reference_model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    init_weights = {name: param.clone() for name, param in reference_model.state_dict().items()}

    results = {}

    # ---- Experiment 1: extreme imbalance (99:1), original setup ----
    banner("EXPERIMENT 1: EXTREME IMBALANCE (99:1)")
    inputs, targets, rare = create_dataset_extreme_imbalance(
        n_samples=1000, n_rare=10, seed=SEED
    )
    print(f"Dataset: {(targets == 0).sum().item()} common, {(targets == 1).sum().item()} rare")
    print("\nTraining WITHOUT clipping...")
    baseline = train_with_tracking(
        inputs, targets, rare,
        clip_grad=False, n_epochs=5, lr=0.1,
        init_weights=init_weights, track_every=100
    )
    print("Training WITH clipping...")
    clipped = train_with_tracking(
        inputs, targets, rare,
        clip_grad=True, max_norm=1.0, n_epochs=5, lr=0.1,
        init_weights=init_weights, track_every=100
    )
    results['extreme'] = {
        'no_clip': baseline,
        'with_clip': clipped,
        'rare_indices': rare,
    }

    # ---- Experiment 2: moderate imbalance (80:20) ----
    banner("EXPERIMENT 2: MODERATE IMBALANCE (80:20)")
    inputs, targets, rare = create_dataset_moderate_imbalance(
        n_samples=1000, rare_ratio=0.2, seed=SEED
    )
    print(f"Dataset: {(targets == 0).sum().item()} common, {(targets == 1).sum().item()} rare")
    print("\nTraining WITHOUT clipping...")
    baseline = train_with_tracking(
        inputs, targets, rare,
        clip_grad=False, n_epochs=10, lr=0.1,
        init_weights=init_weights, track_every=100
    )
    print("Training WITH clipping...")
    clipped = train_with_tracking(
        inputs, targets, rare,
        clip_grad=True, max_norm=1.0, n_epochs=10, lr=0.1,
        init_weights=init_weights, track_every=100
    )
    results['moderate'] = {
        'no_clip': baseline,
        'with_clip': clipped,
        'rare_indices': rare,
    }
    return results
def plot_final_comparison(results: Dict, filename: str):
    """Render the 5x2 comparison grid and save it to ``filename``.

    Rows: weight norms, weight-norm changes, gradient norms, effective
    dimension, rare-class accuracy. Left column = extreme (99:1),
    right column = moderate (80:20); red = no clipping, green = clipped.
    """
    fig = plt.figure(figsize=(20, 20))
    gs = fig.add_gridspec(5, 2, hspace=0.35, wspace=0.25)
    # =========================================================================
    # Row 1: Weight Norm Stability (Key Physics-of-AI Insight)
    # =========================================================================
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    # Extreme imbalance
    steps = range(len(results['extreme']['no_clip']['weight_norms']))
    ax1.plot(steps, results['extreme']['no_clip']['weight_norms'], 'r-', alpha=0.7, linewidth=1, label='Without Clip')
    ax1.plot(steps, results['extreme']['with_clip']['weight_norms'], 'g-', alpha=0.7, linewidth=1, label='With Clip')
    ax1.set_ylabel('Weight Norm', fontsize=11)
    ax1.set_title('EXTREME (99:1) - Weight Norm Evolution', fontsize=12, fontweight='bold')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    # Moderate imbalance
    steps = range(len(results['moderate']['no_clip']['weight_norms']))
    ax2.plot(steps, results['moderate']['no_clip']['weight_norms'], 'r-', alpha=0.7, linewidth=1, label='Without Clip')
    ax2.plot(steps, results['moderate']['with_clip']['weight_norms'], 'g-', alpha=0.7, linewidth=1, label='With Clip')
    ax2.set_title('MODERATE (80:20) - Weight Norm Evolution', fontsize=12, fontweight='bold')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    # =========================================================================
    # Row 2: Weight Norm Changes (Stability Metric)
    # =========================================================================
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])
    # Extreme
    steps = range(len(results['extreme']['no_clip']['weight_norm_changes']))
    ax3.plot(steps, results['extreme']['no_clip']['weight_norm_changes'], 'r-', alpha=0.5, linewidth=0.5, label='Without Clip')
    ax3.plot(steps, results['extreme']['with_clip']['weight_norm_changes'], 'g-', alpha=0.5, linewidth=0.5, label='With Clip')
    ax3.set_ylabel('|Weight Norm Change|', fontsize=11)
    ax3.set_title('EXTREME - Weight Norm Changes (Stability)', fontsize=12, fontweight='bold')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    # Moderate
    steps = range(len(results['moderate']['no_clip']['weight_norm_changes']))
    ax4.plot(steps, results['moderate']['no_clip']['weight_norm_changes'], 'r-', alpha=0.5, linewidth=0.5, label='Without Clip')
    ax4.plot(steps, results['moderate']['with_clip']['weight_norm_changes'], 'g-', alpha=0.5, linewidth=0.5, label='With Clip')
    ax4.set_title('MODERATE - Weight Norm Changes (Stability)', fontsize=12, fontweight='bold')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    # =========================================================================
    # Row 3: Gradient Norms
    # =========================================================================
    ax5 = fig.add_subplot(gs[2, 0])
    ax6 = fig.add_subplot(gs[2, 1])
    # Extreme
    steps = range(len(results['extreme']['no_clip']['grad_norms']))
    ax5.plot(steps, results['extreme']['no_clip']['grad_norms'], 'r-', alpha=0.3, linewidth=0.5, label='Without Clip')
    ax5.plot(steps, results['extreme']['with_clip']['grad_norms'], 'g-', alpha=0.3, linewidth=0.5, label='With Clip')
    # Dashed line marks the max_norm=1.0 threshold used in training.
    ax5.axhline(y=1.0, color='black', linestyle='--', linewidth=2, label='Clip threshold')
    ax5.set_ylabel('Gradient Norm', fontsize=11)
    ax5.set_title('EXTREME - Gradient Norms', fontsize=12, fontweight='bold')
    ax5.legend()
    ax5.grid(True, alpha=0.3)
    # Moderate
    steps = range(len(results['moderate']['no_clip']['grad_norms']))
    ax6.plot(steps, results['moderate']['no_clip']['grad_norms'], 'r-', alpha=0.3, linewidth=0.5, label='Without Clip')
    ax6.plot(steps, results['moderate']['with_clip']['grad_norms'], 'g-', alpha=0.3, linewidth=0.5, label='With Clip')
    ax6.axhline(y=1.0, color='black', linestyle='--', linewidth=2, label='Clip threshold')
    ax6.set_title('MODERATE - Gradient Norms', fontsize=12, fontweight='bold')
    ax6.legend()
    ax6.grid(True, alpha=0.3)
    # =========================================================================
    # Row 4: Effective Dimension
    # =========================================================================
    ax7 = fig.add_subplot(gs[3, 0])
    ax8 = fig.add_subplot(gs[3, 1])
    # Extreme
    ax7.plot(results['extreme']['no_clip']['effective_dim_steps'],
             results['extreme']['no_clip']['effective_dims'],
             'r-o', alpha=0.7, linewidth=2, markersize=4, label='Without Clip')
    ax7.plot(results['extreme']['with_clip']['effective_dim_steps'],
             results['extreme']['with_clip']['effective_dims'],
             'g-o', alpha=0.7, linewidth=2, markersize=4, label='With Clip')
    ax7.set_ylabel('Effective Dimension', fontsize=11)
    ax7.set_title('EXTREME - Effective Dimensionality', fontsize=12, fontweight='bold')
    ax7.legend()
    ax7.grid(True, alpha=0.3)
    # Moderate
    ax8.plot(results['moderate']['no_clip']['effective_dim_steps'],
             results['moderate']['no_clip']['effective_dims'],
             'r-o', alpha=0.7, linewidth=2, markersize=4, label='Without Clip')
    ax8.plot(results['moderate']['with_clip']['effective_dim_steps'],
             results['moderate']['with_clip']['effective_dims'],
             'g-o', alpha=0.7, linewidth=2, markersize=4, label='With Clip')
    ax8.set_title('MODERATE - Effective Dimensionality', fontsize=12, fontweight='bold')
    ax8.legend()
    ax8.grid(True, alpha=0.3)
    # =========================================================================
    # Row 5: Class Accuracies
    # =========================================================================
    ax9 = fig.add_subplot(gs[4, 0])
    ax10 = fig.add_subplot(gs[4, 1])
    # Extreme - Rare class B
    ax9.plot(results['extreme']['no_clip']['accuracy_steps'],
             results['extreme']['no_clip']['class_accuracies'][1],
             'r-', alpha=0.7, linewidth=2, label='Without Clip')
    ax9.plot(results['extreme']['with_clip']['accuracy_steps'],
             results['extreme']['with_clip']['class_accuracies'][1],
             'g-', alpha=0.7, linewidth=2, label='With Clip')
    ax9.set_ylabel('Rare Class B Accuracy', fontsize=11)
    ax9.set_xlabel('Training Step', fontsize=11)
    ax9.set_title('EXTREME - Rare Class Accuracy', fontsize=12, fontweight='bold')
    ax9.legend()
    ax9.grid(True, alpha=0.3)
    ax9.set_ylim([0, 1.05])
    # Moderate - Rare class B
    ax10.plot(results['moderate']['no_clip']['accuracy_steps'],
              results['moderate']['no_clip']['class_accuracies'][1],
              'r-', alpha=0.7, linewidth=2, label='Without Clip')
    ax10.plot(results['moderate']['with_clip']['accuracy_steps'],
              results['moderate']['with_clip']['class_accuracies'][1],
              'g-', alpha=0.7, linewidth=2, label='With Clip')
    ax10.set_xlabel('Training Step', fontsize=11)
    ax10.set_title('MODERATE - Rare Class Accuracy', fontsize=12, fontweight='bold')
    ax10.legend()
    ax10.grid(True, alpha=0.3)
    ax10.set_ylim([0, 1.05])
    fig.suptitle('Gradient Clipping Analysis: Physics-of-AI Predictions\n'
                 'Comparing Extreme (99:1) vs Moderate (80:20) Class Imbalance',
                 fontsize=14, fontweight='bold', y=1.01)
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close()
    # BUGFIX: previously printed the literal placeholder "(unknown)" instead
    # of the actual output path.
    print(f"Final comparison plot saved to: {filename}")
def compute_statistics(results: Dict) -> Dict:
    """Reduce each run's metric series to summary statistics.

    For both imbalance regimes, reports std/mean/max of the stability
    series, the max gradient norm, the effective-dimension variance, and
    the final rare-class (label 1) accuracy, each as a
    ``{'no_clip': ..., 'with_clip': ...}`` pair.
    """
    summary = {}
    for regime in ('extreme', 'moderate'):
        unclipped = results[regime]['no_clip']
        clipped = results[regime]['with_clip']

        def paired(reducer, key):
            # Apply the same reducer to both runs' series.
            return {'no_clip': reducer(unclipped[key]),
                    'with_clip': reducer(clipped[key])}

        final_rare = {}
        for label, run in (('no_clip', unclipped), ('with_clip', clipped)):
            series = run['class_accuracies'][1]
            # Empty series (never tracked) falls back to 0.
            final_rare[label] = series[-1] if series else 0

        summary[regime] = {
            'weight_norm_std': paired(np.std, 'weight_norms'),
            'weight_change_mean': paired(np.mean, 'weight_norm_changes'),
            'weight_change_max': paired(np.max, 'weight_norm_changes'),
            'grad_norm_max': paired(np.max, 'grad_norms'),
            'effective_dim_std': paired(np.std, 'effective_dims'),
            'final_rare_acc': final_rare,
        }
    return summary
def print_summary(stats: Dict):
    """Print a formatted verdict report for each imbalance regime."""
    print("\n" + "=" * 70)
    print("EXPERIMENT SUMMARY")
    print("=" * 70)
    regime_labels = {'extreme': "EXTREME (99:1)", 'moderate': "MODERATE (80:20)"}
    for imbalance in ('extreme', 'moderate'):
        s = stats[imbalance]
        print(f"\n{regime_labels[imbalance]}")
        print("-" * 50)
        # Prediction 2: clipping should reduce effective-dim variance.
        print(f"\n[PREDICTION 2] Representation Collapse (Effective Dim Variance):")
        print(f" WITHOUT Clipping: {s['effective_dim_std']['no_clip']:.6f}")
        print(f" WITH Clipping: {s['effective_dim_std']['with_clip']:.6f}")
        verdict = 'SUPPORTED' if s['effective_dim_std']['no_clip'] > s['effective_dim_std']['with_clip'] else 'NOT SUPPORTED'
        print(f" Verdict: {verdict}")
        # Prediction 4: clipping should not hurt rare-class accuracy.
        print(f"\n[PREDICTION 4] Rare Sample Learning:")
        print(f" Final Rare Accuracy (WITHOUT): {s['final_rare_acc']['no_clip']:.1%}")
        print(f" Final Rare Accuracy (WITH): {s['final_rare_acc']['with_clip']:.1%}")
        verdict = 'SUPPORTED' if s['final_rare_acc']['with_clip'] >= s['final_rare_acc']['no_clip'] else 'NOT SUPPORTED'
        print(f" Verdict: {verdict}")
        print(f"\n[STABILITY] Weight Norm Analysis:")
        print(f" Weight Norm Std (WITHOUT): {s['weight_norm_std']['no_clip']:.4f}")
        print(f" Weight Norm Std (WITH): {s['weight_norm_std']['with_clip']:.4f}")
        print(f" Max Weight Change (WITHOUT): {s['weight_change_max']['no_clip']:.4f}")
        print(f" Max Weight Change (WITH): {s['weight_change_max']['with_clip']:.4f}")
        print(f"\n[GRADIENT] Analysis:")
        print(f" Max Gradient Norm (WITHOUT): {s['grad_norm_max']['no_clip']:.4f}")
        print(f" Max Gradient Norm (WITH): {s['grad_norm_max']['with_clip']:.4f}")
        # Ratio relative to the max_norm=1.0 threshold used in training.
        print(f" Clipping Ratio: {s['grad_norm_max']['no_clip'] / 1.0:.1f}x threshold")
def main():
    """Run the full suite, render the comparison plot, and print stats."""
    results = run_experiment_suite()
    separator = "=" * 70
    print("\n" + separator)
    print("GENERATING PLOTS")
    print(separator)
    plot_final_comparison(results, "final_comparison.png")
    stats = compute_statistics(results)
    print_summary(stats)
    return results, stats
if __name__ == "__main__":
results, stats = main()
print("\n" + "="*70)
print("EXPERIMENT COMPLETE!")
print("="*70)