|
|
""" |
|
|
Final Gradient Clipping Experiment: Testing Physics-of-AI Predictions |
|
|
|
|
|
Key insights from previous experiments: |
|
|
1. With extreme imbalance (99:1), neither model learns rare class |
|
|
2. Gradient clipping's benefit is in STABILITY, not learning rare classes per se |
|
|
3. The key effect is on WEIGHT NORM STABILITY and GRADIENT SPIKE HANDLING |
|
|
|
|
|
This experiment tests: |
|
|
1. Prediction 2: Representation Collapse - effective dim variance without clipping |
|
|
2. Prediction 4: Rare Sample Learning - using moderate imbalance (80:20) |
|
|
3. NEW: Weight norm stability analysis |
|
|
4. NEW: Gradient spike analysis at rare sample positions |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import random |
|
|
from typing import Dict, List |
|
|
|
|
|
SEED = 42


def set_seeds(seed=SEED):
    """Seed every RNG the experiment touches so runs are reproducible.

    Seeds torch (CPU and, when present, all CUDA devices), NumPy, and the
    stdlib ``random`` module from the same integer.

    Args:
        seed: the seed value; defaults to the module-wide ``SEED``.
    """
    torch.manual_seed(seed)
    # Safe no-op on CPU-only machines; required for determinism on GPU runs.
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
|
|
|
|
|
|
|
|
class SimpleNextTokenModel(nn.Module):
    """Minimal next-token classifier: embedding lookup + linear head."""

    def __init__(self, vocab_size=4, embedding_dim=16):
        super().__init__()
        # Token id -> dense vector, then project back to vocab-sized logits.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Map a tensor of token ids to unnormalized class logits."""
        return self.linear(self.embedding(x))

    def get_embeddings(self):
        """Return a detached copy of the embedding weight matrix."""
        return self.embedding.weight.data.clone()
|
|
|
|
|
|
|
|
def compute_effective_dimension(embedding_matrix: torch.Tensor) -> float:
    """Entropy-based effective rank of the embedding covariance spectrum.

    Returns exp(H), where H is the Shannon entropy of the normalized
    eigenvalue distribution of the sample covariance matrix — i.e. how many
    principal directions the embeddings effectively occupy.
    """
    n_rows = embedding_matrix.shape[0]
    deviations = embedding_matrix - embedding_matrix.mean(dim=0, keepdim=True)
    covariance = deviations.T @ deviations / (n_rows - 1)
    # Clamp so that zero eigenvalues do not produce log(0) below.
    spectrum = torch.linalg.eigvalsh(covariance).clamp(min=1e-10)
    probs = spectrum / spectrum.sum()
    shannon = -(probs * probs.log()).sum()
    return float(torch.exp(shannon))
|
|
|
|
|
|
|
|
def compute_per_class_accuracy(model: nn.Module, inputs: torch.Tensor,
                               targets: torch.Tensor) -> Dict[int, float]:
    """Return accuracy per class (0..3); ``None`` for classes absent from targets."""
    model.eval()
    with torch.no_grad():
        preds = model(inputs).argmax(dim=1)

    per_class = {}
    for cls in range(4):
        selected = targets == cls
        if selected.any():
            hits = (preds[selected] == targets[selected]).float()
            per_class[cls] = hits.mean().item()
        else:
            # No samples of this class to score.
            per_class[cls] = None
    return per_class
|
|
|
|
|
|
|
|
def create_dataset_moderate_imbalance(n_samples=1000, rare_ratio=0.2, seed=SEED):
    """Create a random-token dataset with moderate class imbalance.

    Args:
        n_samples: total number of samples.
        rare_ratio: fraction of samples assigned the rare label 1; the rest
            keep the common label 0 (0.2 gives the 80:20 split).
        seed: RNG seed, applied to all generators via ``set_seeds``.

    Returns:
        Tuple ``(inputs, targets, rare_indices)``: random token ids in
        [0, 4), long-dtype labels (0 common / 1 rare), and the sorted
        positions of the rare samples.
    """
    set_seeds(seed)

    n_rare = int(n_samples * rare_ratio)
    # (The previous version also computed the unused common count; removed.)

    inputs = torch.randint(0, 4, (n_samples,))
    targets = torch.zeros(n_samples, dtype=torch.long)

    # Scatter the rare label over distinct, uniformly chosen positions.
    rare_indices = random.sample(range(n_samples), n_rare)
    targets[rare_indices] = 1

    return inputs, targets, sorted(rare_indices)
|
|
|
|
|
|
|
|
def create_dataset_extreme_imbalance(n_samples=1000, n_rare=10, seed=SEED):
    """Create dataset with extreme imbalance (99:1).

    Returns (inputs, targets, sorted rare positions): random token ids in
    [0, 4), labels that are 0 except at ``n_rare`` positions labeled 1.
    """
    set_seeds(seed)

    token_ids = torch.randint(0, 4, (n_samples,))
    labels = torch.zeros(n_samples, dtype=torch.long)

    # Choose n_rare distinct positions and flip them to the rare class.
    rare_positions = random.sample(range(n_samples), n_rare)
    labels[rare_positions] = 1

    return token_ids, labels, sorted(rare_positions)
|
|
|
|
|
|
|
|
def train_with_tracking(inputs: torch.Tensor, targets: torch.Tensor,
                        rare_indices: List[int], clip_grad: bool = False,
                        max_norm: float = 1.0, n_epochs: int = 10,
                        lr: float = 0.1, init_weights=None,
                        track_every: int = 50) -> Dict:
    """Train a fresh SimpleNextTokenModel one sample at a time, recording metrics.

    Args:
        inputs: 1-D tensor of token ids, consumed one sample per SGD step.
        targets: 1-D tensor of class labels, same length as ``inputs``.
        rare_indices: positions of rare-class samples. Not read inside this
            function; accepted so callers can thread dataset metadata through.
        clip_grad: if True, clip gradients to ``max_norm`` before each step.
        max_norm: clipping threshold used when ``clip_grad`` is True.
        n_epochs: number of full passes over the dataset (fixed order, no shuffle).
        lr: SGD learning rate.
        init_weights: optional state dict so clip/no-clip runs start from
            identical weights; cloned to avoid mutating the caller's copy.
        track_every: record effective dim / per-class accuracy every N steps.

    Returns:
        Dict with per-step lists ('losses', 'grad_norms', 'weight_norms',
        'weight_norm_changes') and periodically tracked series
        ('effective_dims'/'effective_dim_steps',
        'class_accuracies'/'accuracy_steps').
    """
    set_seeds(SEED)  # identical model init and RNG stream for every run
    model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    if init_weights:
        # Clone each tensor so repeated runs cannot mutate the shared init.
        model.load_state_dict({k: v.clone() for k, v in init_weights.items()})

    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    metrics = {
        'losses': [],
        'grad_norms': [],
        'weight_norms': [],
        'effective_dims': [],
        'effective_dim_steps': [],
        'class_accuracies': {0: [], 1: [], 2: [], 3: []},
        'accuracy_steps': [],
        'weight_norm_changes': [],
    }

    step = 0
    n_samples = len(inputs)
    prev_weight_norm = None

    for epoch in range(n_epochs):
        model.train()

        # Single-sample (batch size 1) SGD over the dataset in fixed order.
        for i in range(n_samples):
            x = inputs[i:i+1]
            y = targets[i:i+1]

            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()

            # clip_grad_norm_ with an infinite threshold never rescales; it is
            # used here purely to *measure* the total gradient norm.
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf'))

            if clip_grad:
                # Actual clipping to the experiment's threshold.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            optimizer.step()

            metrics['losses'].append(loss.item())
            metrics['grad_norms'].append(grad_norm.item())

            # Global L2 norm across all parameters, measured post-update.
            current_weight_norm = sum(p.data.norm(2).item() ** 2 for p in model.parameters()) ** 0.5
            metrics['weight_norms'].append(current_weight_norm)

            # Step-to-step weight-norm change as a stability signal (0 at step 0).
            if prev_weight_norm is not None:
                metrics['weight_norm_changes'].append(abs(current_weight_norm - prev_weight_norm))
            else:
                metrics['weight_norm_changes'].append(0)
            prev_weight_norm = current_weight_norm

            # Periodic (more expensive) diagnostics over the whole dataset.
            # NOTE(review): compute_per_class_accuracy switches the model to
            # eval mode; it is only restored to train mode at the next epoch.
            # Harmless here (no dropout/batchnorm) — confirm if layers change.
            if step % track_every == 0:
                emb_matrix = model.get_embeddings()
                eff_dim = compute_effective_dimension(emb_matrix)
                metrics['effective_dims'].append(eff_dim)
                metrics['effective_dim_steps'].append(step)

                class_acc = compute_per_class_accuracy(model, inputs, targets)
                for cls_idx in range(4):
                    if class_acc[cls_idx] is not None:
                        metrics['class_accuracies'][cls_idx].append(class_acc[cls_idx])
                    else:
                        # Absent class recorded as 0.0 to keep series aligned.
                        metrics['class_accuracies'][cls_idx].append(0.0)
                metrics['accuracy_steps'].append(step)

            step += 1

    return metrics
|
|
|
|
|
|
|
|
def run_experiment_suite():
    """Run complete experiment suite with both imbalance levels.

    Trains the same initial model with and without gradient clipping on an
    extreme (99:1, 5 epochs) and a moderate (80:20, 10 epochs) dataset.

    Returns:
        Dict keyed by 'extreme' and 'moderate'; each entry holds the
        'no_clip' and 'with_clip' metric dicts from ``train_with_tracking``
        plus the dataset's 'rare_indices'.
    """
    print("="*70)
    print("FINAL GRADIENT CLIPPING EXPERIMENT")
    print("Testing Physics-of-AI Predictions")
    print("="*70)

    # Shared initial weights so every run starts from the same point.
    set_seeds(SEED)
    init_model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    init_weights = {name: param.clone() for name, param in init_model.state_dict().items()}

    results = {}

    # ------------------------------------------------------------------
    # Experiment 1: extreme imbalance (99:1), 5 epochs.
    # ------------------------------------------------------------------
    print("\n" + "="*70)
    print("EXPERIMENT 1: EXTREME IMBALANCE (99:1)")
    print("="*70)

    inputs_extreme, targets_extreme, rare_extreme = create_dataset_extreme_imbalance(
        n_samples=1000, n_rare=10, seed=SEED
    )
    print(f"Dataset: {(targets_extreme == 0).sum().item()} common, {(targets_extreme == 1).sum().item()} rare")

    print("\nTraining WITHOUT clipping...")
    metrics_extreme_no_clip = train_with_tracking(
        inputs_extreme, targets_extreme, rare_extreme,
        clip_grad=False, n_epochs=5, lr=0.1,
        init_weights=init_weights, track_every=100
    )

    print("Training WITH clipping...")
    metrics_extreme_with_clip = train_with_tracking(
        inputs_extreme, targets_extreme, rare_extreme,
        clip_grad=True, max_norm=1.0, n_epochs=5, lr=0.1,
        init_weights=init_weights, track_every=100
    )

    results['extreme'] = {
        'no_clip': metrics_extreme_no_clip,
        'with_clip': metrics_extreme_with_clip,
        'rare_indices': rare_extreme
    }

    # ------------------------------------------------------------------
    # Experiment 2: moderate imbalance (80:20), 10 epochs.
    # ------------------------------------------------------------------
    print("\n" + "="*70)
    print("EXPERIMENT 2: MODERATE IMBALANCE (80:20)")
    print("="*70)

    inputs_moderate, targets_moderate, rare_moderate = create_dataset_moderate_imbalance(
        n_samples=1000, rare_ratio=0.2, seed=SEED
    )
    print(f"Dataset: {(targets_moderate == 0).sum().item()} common, {(targets_moderate == 1).sum().item()} rare")

    print("\nTraining WITHOUT clipping...")
    metrics_moderate_no_clip = train_with_tracking(
        inputs_moderate, targets_moderate, rare_moderate,
        clip_grad=False, n_epochs=10, lr=0.1,
        init_weights=init_weights, track_every=100
    )

    print("Training WITH clipping...")
    metrics_moderate_with_clip = train_with_tracking(
        inputs_moderate, targets_moderate, rare_moderate,
        clip_grad=True, max_norm=1.0, n_epochs=10, lr=0.1,
        init_weights=init_weights, track_every=100
    )

    results['moderate'] = {
        'no_clip': metrics_moderate_no_clip,
        'with_clip': metrics_moderate_with_clip,
        'rare_indices': rare_moderate
    }

    return results
|
|
|
|
|
|
|
|
def plot_final_comparison(results: Dict, filename: str):
    """Create and save the final 5x2 comparison figure.

    Rows (top to bottom): weight-norm evolution, weight-norm changes,
    gradient norms, effective dimensionality, rare-class accuracy.
    Columns: extreme (99:1) imbalance on the left, moderate (80:20) on the
    right. Every panel overlays the no-clipping run (red) and the clipping
    run (green).

    Args:
        results: output of ``run_experiment_suite()``.
        filename: path the PNG figure is written to.
    """
    fig = plt.figure(figsize=(20, 20))
    gs = fig.add_gridspec(5, 2, hspace=0.35, wspace=0.25)

    # Variant -> (color, legend label); plotted in this order in every panel.
    variants = (('no_clip', 'r', 'Without Clip'), ('with_clip', 'g', 'With Clip'))
    # Column -> (results key, row-0 title prefix, short title prefix).
    columns = (('extreme', 'EXTREME (99:1)', 'EXTREME'),
               ('moderate', 'MODERATE (80:20)', 'MODERATE'))

    def _plot_steps(ax, imbalance, key, alpha, linewidth):
        # Per-training-step series plotted against the implicit step index.
        for variant, color, label in variants:
            series = results[imbalance][variant][key]
            ax.plot(range(len(series)), series, color + '-',
                    alpha=alpha, linewidth=linewidth, label=label)

    def _finish(ax, title, ylabel=None):
        # Shared title/label/legend/grid styling for every panel.
        if ylabel:
            ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

    for col, (imb, long_prefix, prefix) in enumerate(columns):
        # Row 0: weight norm trajectory.
        ax = fig.add_subplot(gs[0, col])
        _plot_steps(ax, imb, 'weight_norms', alpha=0.7, linewidth=1)
        _finish(ax, f'{long_prefix} - Weight Norm Evolution',
                ylabel='Weight Norm' if col == 0 else None)

        # Row 1: step-to-step weight-norm changes (stability proxy).
        ax = fig.add_subplot(gs[1, col])
        _plot_steps(ax, imb, 'weight_norm_changes', alpha=0.5, linewidth=0.5)
        _finish(ax, f'{prefix} - Weight Norm Changes (Stability)',
                ylabel='|Weight Norm Change|' if col == 0 else None)

        # Row 2: raw gradient norms plus the clipping threshold.
        ax = fig.add_subplot(gs[2, col])
        _plot_steps(ax, imb, 'grad_norms', alpha=0.3, linewidth=0.5)
        ax.axhline(y=1.0, color='black', linestyle='--', linewidth=2, label='Clip threshold')
        _finish(ax, f'{prefix} - Gradient Norms',
                ylabel='Gradient Norm' if col == 0 else None)

        # Row 3: effective dimensionality, sampled every track_every steps.
        ax = fig.add_subplot(gs[3, col])
        for variant, color, label in variants:
            m = results[imb][variant]
            ax.plot(m['effective_dim_steps'], m['effective_dims'], color + '-o',
                    alpha=0.7, linewidth=2, markersize=4, label=label)
        _finish(ax, f'{prefix} - Effective Dimensionality',
                ylabel='Effective Dimension' if col == 0 else None)

        # Row 4: accuracy on the rare class (class index 1).
        ax = fig.add_subplot(gs[4, col])
        for variant, color, label in variants:
            m = results[imb][variant]
            ax.plot(m['accuracy_steps'], m['class_accuracies'][1], color + '-',
                    alpha=0.7, linewidth=2, label=label)
        ax.set_xlabel('Training Step', fontsize=11)
        _finish(ax, f'{prefix} - Rare Class Accuracy',
                ylabel='Rare Class B Accuracy' if col == 0 else None)
        ax.set_ylim([0, 1.05])

    fig.suptitle('Gradient Clipping Analysis: Physics-of-AI Predictions\n'
                 'Comparing Extreme (99:1) vs Moderate (80:20) Class Imbalance',
                 fontsize=14, fontweight='bold', y=1.01)

    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close()
    # BUG FIX: the message previously printed a literal instead of the path.
    print(f"Final comparison plot saved to: {filename}")
|
|
|
|
|
|
|
|
def compute_statistics(results: Dict) -> Dict:
    """Summarize stability and learning metrics for both imbalance levels.

    For each of 'extreme' and 'moderate', computes per-variant
    (no_clip / with_clip) summaries: weight-norm std, mean/max weight-norm
    change, max gradient norm, effective-dim std, and the final rare-class
    (class 1) accuracy (0 when never tracked).
    """
    summary = {}

    for level in ('extreme', 'moderate'):
        runs = {variant: results[level][variant]
                for variant in ('no_clip', 'with_clip')}

        def per_run(fn):
            # Apply a reducer to each variant's metric dict.
            return {variant: fn(m) for variant, m in runs.items()}

        summary[level] = {
            'weight_norm_std': per_run(lambda m: np.std(m['weight_norms'])),
            'weight_change_mean': per_run(lambda m: np.mean(m['weight_norm_changes'])),
            'weight_change_max': per_run(lambda m: np.max(m['weight_norm_changes'])),
            'grad_norm_max': per_run(lambda m: np.max(m['grad_norms'])),
            'effective_dim_std': per_run(lambda m: np.std(m['effective_dims'])),
            'final_rare_acc': per_run(
                lambda m: m['class_accuracies'][1][-1] if m['class_accuracies'][1] else 0),
        }

    return summary
|
|
|
|
|
|
|
|
def print_summary(stats: Dict):
    """Print a formatted prediction-by-prediction summary of the statistics."""
    banner = "="*70
    print("\n" + banner)
    print("EXPERIMENT SUMMARY")
    print(banner)

    labels = {'extreme': "EXTREME (99:1)", 'moderate': "MODERATE (80:20)"}
    for imbalance in ('extreme', 'moderate'):
        s = stats[imbalance]
        label = labels[imbalance]

        print(f"\n{label}")
        print("-" * 50)

        # Prediction 2: clipping should reduce effective-dim variance.
        eff = s['effective_dim_std']
        print(f"\n[PREDICTION 2] Representation Collapse (Effective Dim Variance):")
        print(f"  WITHOUT Clipping: {eff['no_clip']:.6f}")
        print(f"  WITH Clipping:    {eff['with_clip']:.6f}")
        verdict = 'SUPPORTED' if eff['no_clip'] > eff['with_clip'] else 'NOT SUPPORTED'
        print(f"  Verdict: {verdict}")

        # Prediction 4: clipping should not hurt rare-class accuracy.
        rare = s['final_rare_acc']
        print(f"\n[PREDICTION 4] Rare Sample Learning:")
        print(f"  Final Rare Accuracy (WITHOUT): {rare['no_clip']:.1%}")
        print(f"  Final Rare Accuracy (WITH):    {rare['with_clip']:.1%}")
        verdict = 'SUPPORTED' if rare['with_clip'] >= rare['no_clip'] else 'NOT SUPPORTED'
        print(f"  Verdict: {verdict}")

        print(f"\n[STABILITY] Weight Norm Analysis:")
        print(f"  Weight Norm Std (WITHOUT): {s['weight_norm_std']['no_clip']:.4f}")
        print(f"  Weight Norm Std (WITH):    {s['weight_norm_std']['with_clip']:.4f}")
        print(f"  Max Weight Change (WITHOUT): {s['weight_change_max']['no_clip']:.4f}")
        print(f"  Max Weight Change (WITH):    {s['weight_change_max']['with_clip']:.4f}")

        print(f"\n[GRADIENT] Analysis:")
        print(f"  Max Gradient Norm (WITHOUT): {s['grad_norm_max']['no_clip']:.4f}")
        print(f"  Max Gradient Norm (WITH):    {s['grad_norm_max']['with_clip']:.4f}")
        # Ratio relative to the fixed clip threshold of 1.0.
        print(f"  Clipping Ratio: {s['grad_norm_max']['no_clip'] / 1.0:.1f}x threshold")
|
|
|
|
|
|
|
|
def main():
    """Run the experiment suite, render the plots, and print the summary.

    Returns:
        Tuple ``(results, stats)`` — the raw metric dicts and their summary.
    """
    results = run_experiment_suite()

    print("\n" + "="*70)
    print("GENERATING PLOTS")
    print("="*70)

    plot_final_comparison(results, "final_comparison.png")

    summary_stats = compute_statistics(results)
    print_summary(summary_stats)

    return results, summary_stats
|
|
|
|
|
|
|
|
# Script entry point: run the full experiment suite and announce completion.
if __name__ == "__main__":
    results, stats = main()
    print("\n" + "="*70)
    print("EXPERIMENT COMPLETE!")
    print("="*70)
|
|