|
|
""" |
|
|
Extended Gradient Clipping Experiment: Testing Physics-of-AI Predictions |
|
|
|
|
|
This script tests two predictions from our Physics-of-AI analysis: |
|
|
|
|
|
Prediction 2: Representation Collapse |
|
|
- Hypothesis: Without clipping, the effective dimensionality of embeddings |
|
|
should show sudden drops at rare sample positions. |
|
|
- Test: Track PCA-based effective dimension throughout training. |
|
|
|
|
|
Prediction 4: Rare Sample Learning |
|
|
- Hypothesis: With clipping, the model should achieve better accuracy on rare samples. |
|
|
- Test: Track per-class accuracy throughout training. |
|
|
|
|
|
Based on Ziming Liu's Physics-of-AI framework and the unigram toy model analysis. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import random |
|
|
from typing import Dict, List, Tuple |
|
|
|
|
|
|
|
|
SEED = 42  # default seed used by dataset creation, model init and training


def set_seeds(seed=SEED):
    """
    Seed every RNG this script draws from, for reproducible runs.

    Seeds PyTorch, NumPy and Python's ``random`` (in that order) with ``seed``.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SimpleNextTokenModel(nn.Module):
    """
    Next-token predictor that sees only the current token.

    Pipeline: token id -> ``nn.Embedding`` -> ``nn.Linear`` -> logits over the
    vocabulary.

    Args:
        vocab_size: number of tokens; also the number of output logits.
        embedding_dim: width of the learned token embeddings.
    """

    def __init__(self, vocab_size=4, embedding_dim=16):
        super().__init__()
        # Creation order (embedding first) fixes which RNG draws each layer uses.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Map a tensor of token indices to unnormalized next-token logits."""
        return self.linear(self.embedding(x))

    def get_embeddings(self):
        """Return a gradient-free copy of the (vocab_size, embedding_dim) embedding table."""
        return self.embedding.weight.detach().clone()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_effective_dimension(embedding_matrix: torch.Tensor) -> float:
    """
    Compute effective dimensionality using PCA entropy.

    Following Ziming Liu's approach from the Unigram toy model analysis:
    "We define effective dimensionality via PCA entropy"

    Effective dimension = exp(entropy of normalized eigenvalues), where the
    eigenvalues are those of the covariance of the centered rows.

    Args:
        embedding_matrix: (vocab_size, embedding_dim) tensor

    Returns:
        Effective dimension. For rows with any spread it lies between 1 and
        min(vocab_size - 1, embedding_dim), because centering removes one
        degree of freedom. Returns 0.0 when there is no spread at all (fewer
        than two rows, or all rows identical), i.e. a fully collapsed
        representation.
    """
    # float64 for an accurate eigendecomposition; detach so tensors that
    # require grad can be passed in directly.
    data = embedding_matrix.detach().to(torch.float64)
    if data.shape[0] < 2:
        # A single point (or none) has no spread in any direction.
        return 0.0

    centered = data - data.mean(dim=0, keepdim=True)

    # Spectrum of the scatter matrix. The usual 1/(n-1) covariance factor is
    # omitted: it cancels in the normalization below.
    eigenvalues = torch.linalg.eigvalsh(centered.T @ centered)

    # Discard zero, negative and round-off-sized eigenvalues instead of
    # clamping them to a small floor. A floor gives those null directions
    # probability mass, so a fully collapsed matrix would report the *maximal*
    # dimension. The tolerance is relative to the data magnitude
    # (same idea as numpy.linalg.matrix_rank).
    scale = data.abs().max().item() ** 2
    tol = scale * max(data.shape) * torch.finfo(torch.float64).eps
    eigenvalues = eigenvalues[eigenvalues > tol]
    if eigenvalues.numel() == 0:
        return 0.0

    probs = eigenvalues / eigenvalues.sum()
    entropy = -torch.sum(probs * torch.log(probs))
    return torch.exp(entropy).item()
|
|
|
|
|
|
|
|
def compute_embedding_stats(embedding_matrix: torch.Tensor) -> Dict[str, float]:
    """
    Compute various statistics about the embedding matrix.

    Args:
        embedding_matrix: 2-D tensor with one embedding per row.

    Returns:
        Dictionary with embedding statistics:
          'effective_dim'   - result of ``compute_effective_dimension``
          'mean_norm'       - mean L2 norm of the rows
          'std_norm'        - (unbiased) standard deviation of the row norms
          'mean_cosine_sim' - mean cosine similarity over pairs of distinct rows
          'max_cosine_sim'  - max cosine similarity over pairs of distinct rows
        With fewer than two rows there are no pairs, so both cosine entries
        are NaN (as is 'std_norm').
    """
    eff_dim = compute_effective_dimension(embedding_matrix)

    norms = torch.norm(embedding_matrix, dim=1)

    # Pairwise cosine similarity; the epsilon keeps all-zero rows from
    # dividing by zero.
    normalized = embedding_matrix / (norms.unsqueeze(1) + 1e-10)
    cosine_sim = torch.mm(normalized, normalized.T)

    # Exclude self-similarity (the diagonal). The mask is created on the same
    # device as the data so boolean indexing also works for CUDA tensors.
    mask = ~torch.eye(cosine_sim.shape[0], dtype=torch.bool, device=cosine_sim.device)
    off_diag = cosine_sim[mask]

    if off_diag.numel() > 0:
        mean_cos = off_diag.mean().item()
        max_cos = off_diag.max().item()
    else:
        # No pairs to compare; max() of an empty tensor would raise.
        mean_cos = max_cos = float('nan')

    return {
        'effective_dim': eff_dim,
        'mean_norm': norms.mean().item(),
        'std_norm': norms.std().item(),
        'mean_cosine_sim': mean_cos,
        'max_cosine_sim': max_cos,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_per_class_accuracy(model: nn.Module, inputs: torch.Tensor,
                               targets: torch.Tensor,
                               num_classes: int = 4) -> Dict[int, float]:
    """
    Compute accuracy for each target class.

    The model is evaluated in eval mode without gradients. Its previous
    train/eval mode is restored before returning, so this is safe to call in
    the middle of a training loop.

    Args:
        model: The neural network
        inputs: Input token indices
        targets: Target token indices
        num_classes: Number of classes to report (keys ``0 .. num_classes-1``)

    Returns:
        Dictionary mapping class index to accuracy (fraction of samples with
        that target predicted correctly), or None for classes that never occur
        in ``targets``.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            predictions = model(inputs).argmax(dim=1)
    finally:
        # Leaving the model in eval mode would silently change how subsequent
        # training steps behave (e.g. dropout / batch-norm layers).
        model.train(was_training)

    accuracies = {}
    for class_idx in range(num_classes):
        mask = targets == class_idx
        if mask.any():
            accuracies[class_idx] = (predictions[mask] == class_idx).float().mean().item()
        else:
            accuracies[class_idx] = None

    return accuracies
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED):
    """
    Create a synthetic dataset with heavily imbalanced targets.

    Inputs are uniformly random token ids in [0, 4). Every target is 0 ('A')
    except for ``n_rare`` randomly chosen positions, whose target is 1 ('B').

    Returns:
        Tuple ``(inputs, targets, rare_indices)``, where ``rare_indices`` is
        the sorted list of positions that carry the rare target.
    """
    set_seeds(seed)

    token_ids = torch.randint(0, 4, (n_samples,))
    rare_positions = sorted(random.sample(range(n_samples), n_rare))

    labels = torch.zeros(n_samples, dtype=torch.long)
    labels[rare_positions] = 1

    return token_ids, labels, rare_positions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_with_tracking(inputs: torch.Tensor, targets: torch.Tensor,
                        rare_indices: List[int], clip_grad: bool = False,
                        max_norm: float = 1.0, n_epochs: int = 3,
                        lr: float = 0.1, init_weights=None,
                        track_every: int = 10) -> Dict:
    """
    Train with extended tracking of:
    - Loss, gradient norm, weight norm (every step)
    - Effective dimensionality of embeddings
    - Per-class accuracy

    Training is plain SGD with batch size 1, visiting the samples in order.

    Args:
        inputs, targets: Training data
        rare_indices: Indices of rare 'B' samples
        clip_grad: Whether to apply gradient clipping
        max_norm: Clipping threshold
        n_epochs: Number of epochs
        lr: Learning rate
        init_weights: Initial model weights (state_dict); skipped if falsy
        track_every: Track embedding stats and class accuracies every N steps,
            and additionally right after every step on a rare sample

    Returns:
        Dictionary with all tracked metrics. 'losses', 'grad_norms' and
        'weight_norms' have one entry per optimizer step ('grad_norms' are the
        total L2 norms measured before any clipping). 'effective_dims' and
        'embedding_stats' are sampled at 'effective_dim_steps', and
        'class_accuracies' at 'accuracy_steps'.
    """
    set_seeds(SEED)
    model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    if init_weights:
        model.load_state_dict({k: v.clone() for k, v in init_weights.items()})

    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    metrics = {
        'losses': [],
        'grad_norms': [],
        'weight_norms': [],
        'effective_dims': [],
        'effective_dim_steps': [],
        'class_accuracies': {0: [], 1: [], 2: [], 3: []},
        'accuracy_steps': [],
        'embedding_stats': [],
    }

    mode = "WITH" if clip_grad else "WITHOUT"
    print(f"\n{'='*60}")
    print(f"Training {mode} gradient clipping (max_norm={max_norm})")
    print(f"{'='*60}")

    def fmt_acc(acc):
        # Classes absent from the targets have accuracy None.
        return f"{acc:.3f}" if acc is not None else "N/A"

    # Set for O(1) membership tests inside the per-sample loop.
    rare_positions = set(rare_indices)
    step = 0
    n_samples = len(inputs)

    for epoch in range(n_epochs):
        model.train()
        epoch_losses = []

        for i in range(n_samples):
            x = inputs[i:i+1]
            y = targets[i:i+1]

            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()

            # clip_grad_norm_ returns the total L2 norm measured *before*
            # clipping; with max_norm=inf it only measures.
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm if clip_grad else float('inf'))

            optimizer.step()

            metrics['losses'].append(loss.item())
            metrics['grad_norms'].append(grad_norm.item())

            total_norm = sum(p.data.norm(2).item() ** 2 for p in model.parameters()) ** 0.5
            metrics['weight_norms'].append(total_norm)

            epoch_losses.append(loss.item())

            # Snapshot periodically and right after every rare-sample update.
            if step % track_every == 0 or i in rare_positions:
                emb_stats = compute_embedding_stats(model.get_embeddings())
                metrics['effective_dims'].append(emb_stats['effective_dim'])
                metrics['effective_dim_steps'].append(step)
                metrics['embedding_stats'].append(emb_stats)

                class_acc = compute_per_class_accuracy(model, inputs, targets)
                # The evaluation may have switched the model to eval mode;
                # the remaining steps of this epoch must train in train mode.
                model.train()
                for cls_idx in range(4):
                    acc = class_acc[cls_idx]
                    # Absent classes are recorded as 0.0 so every list stays
                    # aligned with 'accuracy_steps'.
                    metrics['class_accuracies'][cls_idx].append(acc if acc is not None else 0.0)
                metrics['accuracy_steps'].append(step)

            step += 1

        avg_loss = np.mean(epoch_losses)

        class_acc = compute_per_class_accuracy(model, inputs, targets)
        print(f"Epoch {epoch+1}/{n_epochs}: Avg Loss={avg_loss:.4f}")
        print(f"  Class Accuracies: A={fmt_acc(class_acc[0])}, B={fmt_acc(class_acc[1])}")

        eff_dim = compute_effective_dimension(model.get_embeddings())
        print(f"  Effective Dimension: {eff_dim:.3f}")

    return metrics
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_effective_dimension_comparison(metrics_no_clip: Dict, metrics_with_clip: Dict,
                                        rare_indices: List[int], filename: str,
                                        n_samples: int = 1000):
    """
    Plot effective dimensionality comparison.

    This tests Prediction 2: Without clipping, effective dimensionality
    should show sudden drops at rare sample positions.

    Args:
        metrics_no_clip, metrics_with_clip: metric dicts from training; reads
            'effective_dim_steps', 'effective_dims' and 'losses'.
        rare_indices: within-epoch positions of the rare samples.
        filename: path the figure is saved to.
        n_samples: samples per epoch, used to place the rare-sample markers
            in every epoch.
    """
    fig, axes = plt.subplots(2, 1, figsize=(14, 10))

    # Global step of every rare sample across all epochs.
    n_epochs = len(metrics_no_clip['losses']) // n_samples
    rare_steps = [epoch * n_samples + idx
                  for epoch in range(n_epochs) for idx in rare_indices]

    panels = [
        (axes[0], metrics_no_clip, 'b-',
         'WITHOUT Gradient Clipping - Embedding Effective Dimensionality', 'red'),
        (axes[1], metrics_with_clip, 'g-',
         'WITH Gradient Clipping - Embedding Effective Dimensionality', 'green'),
    ]
    for ax, metrics, fmt, title, title_color in panels:
        ax.plot(metrics['effective_dim_steps'], metrics['effective_dims'], fmt,
                linewidth=1.5, marker='o', markersize=3, alpha=0.7)
        ax.set_ylabel('Effective Dimension', fontsize=12)
        ax.set_title(title, fontsize=13, fontweight='bold', color=title_color)
        ax.grid(True, alpha=0.3)
        ax.set_ylim([0, 16])

        for step in rare_steps:
            ax.axvline(x=step, color='red', alpha=0.3, linewidth=1)

        # Legend entry via an empty proxy line. A dummy axvline placed off to
        # the side would be pulled into the x-range by axvline's autoscaling.
        ax.plot([], [], color='red', alpha=0.5, linewidth=2, label="Rare 'B' samples")
        ax.legend(loc='upper right')

    axes[1].set_xlabel('Training Step', fontsize=12)

    fig.suptitle('Prediction 2: Representation Collapse Test\n'
                 '(Hypothesis: Without clipping, effective dim drops at rare samples)',
                 fontsize=14, fontweight='bold', y=1.02)

    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Effective dimension plot saved to: {filename}")
|
|
|
|
|
|
|
|
def plot_class_accuracy_comparison(metrics_no_clip: Dict, metrics_with_clip: Dict,
                                   filename: str, n_common: int = 990,
                                   n_rare: int = 10):
    """
    Plot per-class accuracy comparison.

    This tests Prediction 4: With clipping, the model should achieve
    better accuracy on rare samples (class 'B').

    Args:
        metrics_no_clip, metrics_with_clip: metric dicts from training; reads
            'accuracy_steps' and 'class_accuracies'.
        filename: path the figure is saved to.
        n_common, n_rare: sample counts of classes 'A' and 'B', shown in the
            panel titles.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    steps_no = metrics_no_clip['accuracy_steps']
    steps_with = metrics_with_clip['accuracy_steps']

    # Top row: accuracy curves for the common class 'A' and the rare class 'B'.
    curve_panels = [
        (axes[0, 0], 0, f"Class 'A' (Common - {n_common} samples)", {}),
        (axes[0, 1], 1, f"Class 'B' (Rare - {n_rare} samples) ⭐ KEY PREDICTION",
         {'color': 'purple'}),
    ]
    for ax, cls, title, title_kw in curve_panels:
        ax.plot(steps_no, metrics_no_clip['class_accuracies'][cls], 'r-',
                linewidth=1.5, alpha=0.7, label='Without Clipping')
        ax.plot(steps_with, metrics_with_clip['class_accuracies'][cls], 'g-',
                linewidth=1.5, alpha=0.7, label='With Clipping')
        ax.set_ylabel('Accuracy', fontsize=11)
        ax.set_title(title, fontsize=12, fontweight='bold', **title_kw)
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_ylim([0, 1.05])

    # Bottom-left: per-snapshot difference on the rare class (positive = clipping helps).
    ax_diff = axes[1, 0]
    acc_b_no = np.array(metrics_no_clip['class_accuracies'][1])
    acc_b_with = np.array(metrics_with_clip['class_accuracies'][1])
    min_len = min(len(acc_b_no), len(acc_b_with))
    diff = acc_b_with[:min_len] - acc_b_no[:min_len]

    colors = ['green' if d >= 0 else 'red' for d in diff]
    ax_diff.bar(steps_no[:min_len], diff, color=colors, alpha=0.7, width=8)
    ax_diff.axhline(y=0, color='black', linestyle='-', linewidth=1)
    ax_diff.set_ylabel('Accuracy Difference\n(With Clip - Without Clip)', fontsize=11)
    ax_diff.set_xlabel('Training Step', fontsize=11)
    ax_diff.set_title("Rare Class 'B': Clipping Benefit", fontsize=12, fontweight='bold')
    ax_diff.grid(True, alpha=0.3)

    # Bottom-right: text summary of the final snapshot.
    ax_summary = axes[1, 1]
    ax_summary.axis('off')

    final_acc_a_no = metrics_no_clip['class_accuracies'][0][-1]
    final_acc_a_with = metrics_with_clip['class_accuracies'][0][-1]
    final_acc_b_no = metrics_no_clip['class_accuracies'][1][-1]
    final_acc_b_with = metrics_with_clip['class_accuracies'][1][-1]

    # A tie counts as "supported" (same >= criterion as the console verdict in
    # main); the explanation line states the tie explicitly instead of
    # contradicting the verdict.
    supported = final_acc_b_with >= final_acc_b_no
    verdict = '✅ PREDICTION SUPPORTED' if supported else '❌ PREDICTION NOT SUPPORTED'
    if final_acc_b_with > final_acc_b_no:
        effect = 'improves'
    elif final_acc_b_with == final_acc_b_no:
        effect = 'leaves unchanged'
    else:
        effect = 'does not improve'

    summary_lines = [
        "",
        "PREDICTION 4 TEST RESULTS",
        "═══════════════════════════════════════",
        "",
        "Hypothesis: With clipping, the model should",
        "achieve better accuracy on rare samples.",
        "",
        "FINAL ACCURACIES:",
        "─────────────────────────────────────────",
        "Class 'A' (Common):",
        f"  Without Clipping: {final_acc_a_no:.1%}",
        f"  With Clipping:    {final_acc_a_with:.1%}",
        f"  Difference:       {final_acc_a_with - final_acc_a_no:+.1%}",
        "",
        "Class 'B' (Rare):",
        f"  Without Clipping: {final_acc_b_no:.1%}",
        f"  With Clipping:    {final_acc_b_with:.1%}",
        f"  Difference:       {final_acc_b_with - final_acc_b_no:+.1%}",
        "",
        "─────────────────────────────────────────",
        f"VERDICT: {verdict}",
        f"(Clipping {effect} rare class accuracy)",
        "",
    ]
    summary_text = "\n".join(summary_lines)

    ax_summary.text(0.1, 0.5, summary_text, transform=ax_summary.transAxes,
                    fontsize=11, verticalalignment='center', fontfamily='monospace',
                    bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    fig.suptitle('Prediction 4: Rare Sample Learning Test\n'
                 '(Hypothesis: Clipping improves accuracy on rare samples)',
                 fontsize=14, fontweight='bold', y=1.02)

    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Class accuracy plot saved to: {filename}")
|
|
|
|
|
|
|
|
def plot_combined_analysis(metrics_no_clip: Dict, metrics_with_clip: Dict,
                           rare_indices: List[int], filename: str,
                           n_samples: int = 1000, max_norm: float = 1.0):
    """
    Create a comprehensive 6-panel analysis plot.

    Rows: embedding effective dimension (without / with clipping),
    per-class accuracy (common class 'A' / rare class 'B'), and per-step
    gradient and weight norms.

    Args:
        metrics_no_clip, metrics_with_clip: metric dicts from training.
        rare_indices: within-epoch positions of the rare samples.
        filename: path the figure is saved to.
        n_samples: samples per epoch, used to place rare-sample markers.
        max_norm: clipping threshold drawn on the gradient-norm panel; should
            match the value used for the clipped run.
    """
    fig = plt.figure(figsize=(18, 14))
    gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.25)

    # Global step of every rare sample across all epochs.
    n_epochs = len(metrics_no_clip['losses']) // n_samples
    rare_steps = [epoch * n_samples + idx
                  for epoch in range(n_epochs) for idx in rare_indices]

    # Row 1: effective dimension with rare-sample markers.
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    eff_panels = [
        (ax1, metrics_no_clip, 'b-', 'Effective Dim - WITHOUT Clipping', 'red'),
        (ax2, metrics_with_clip, 'g-', 'Effective Dim - WITH Clipping', 'green'),
    ]
    for ax, metrics, fmt, title, title_color in eff_panels:
        ax.plot(metrics['effective_dim_steps'], metrics['effective_dims'],
                fmt, linewidth=1.5, marker='o', markersize=2, alpha=0.7)
        ax.set_title(title, fontsize=12, fontweight='bold', color=title_color)
        ax.grid(True, alpha=0.3)
        ax.set_ylim([0, 16])
        for step in rare_steps:
            ax.axvline(x=step, color='red', alpha=0.2, linewidth=1)
    ax1.set_ylabel('Effective Dimension', fontsize=11)

    # Row 2: per-class accuracy, both runs overlaid.
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])
    acc_panels = [
        (ax3, 0, "Common Class 'A' Accuracy", {}),
        (ax4, 1, "Rare Class 'B' Accuracy ⭐", {'color': 'purple'}),
    ]
    for ax, cls, title, title_kw in acc_panels:
        ax.plot(metrics_no_clip['accuracy_steps'], metrics_no_clip['class_accuracies'][cls],
                'r-', linewidth=1.5, alpha=0.7, label='Without Clip')
        ax.plot(metrics_with_clip['accuracy_steps'], metrics_with_clip['class_accuracies'][cls],
                'g-', linewidth=1.5, alpha=0.7, label='With Clip')
        ax.set_title(title, fontsize=12, fontweight='bold', **title_kw)
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_ylim([0, 1.05])
    ax3.set_ylabel('Accuracy', fontsize=11)

    # Row 3: per-step norms. Each run uses its own step range so runs of
    # different lengths can still be plotted.
    ax5 = fig.add_subplot(gs[2, 0])
    ax6 = fig.add_subplot(gs[2, 1])
    steps_no = range(len(metrics_no_clip['grad_norms']))
    steps_with = range(len(metrics_with_clip['grad_norms']))

    ax5.plot(steps_no, metrics_no_clip['grad_norms'], 'r-', alpha=0.5, linewidth=0.5, label='Without Clip')
    ax5.plot(steps_with, metrics_with_clip['grad_norms'], 'g-', alpha=0.5, linewidth=0.5, label='With Clip')
    ax5.axhline(y=max_norm, color='black', linestyle='--', linewidth=2, label='Clip threshold')
    ax5.set_ylabel('Gradient Norm', fontsize=11)
    ax5.set_xlabel('Training Step', fontsize=11)
    ax5.set_title('Gradient Norms Comparison', fontsize=12, fontweight='bold')
    ax5.legend()
    ax5.grid(True, alpha=0.3)

    ax6.plot(steps_no, metrics_no_clip['weight_norms'], 'r-', alpha=0.7, linewidth=1, label='Without Clip')
    ax6.plot(steps_with, metrics_with_clip['weight_norms'], 'g-', alpha=0.7, linewidth=1, label='With Clip')
    ax6.set_xlabel('Training Step', fontsize=11)
    ax6.set_title('Weight Norms Comparison', fontsize=12, fontweight='bold')
    ax6.legend()
    ax6.grid(True, alpha=0.3)

    fig.suptitle('Extended Gradient Clipping Analysis: Testing Physics-of-AI Predictions\n'
                 '(Red vertical lines = rare sample positions)',
                 fontsize=14, fontweight='bold', y=1.01)

    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Combined analysis plot saved to: {filename}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """
    Run the full experiment end to end.

    Steps: build the imbalanced dataset, snapshot one initialization, and
    train twice from it (without and with gradient clipping). Then save the
    three analysis figures and print verdicts for Predictions 2 and 4.

    Returns:
        Dict with both runs' metrics, the rare indices and the two verdict flags.
    """
    rule = "=" * 70

    def verdict(flag):
        return '✅ SUPPORTED' if flag else '❌ NOT SUPPORTED'

    print(rule)
    print("EXTENDED GRADIENT CLIPPING EXPERIMENT")
    print("Testing Physics-of-AI Predictions")
    print(rule)

    inputs, targets, rare_indices = create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED)

    print(f"\nDataset created:")
    print(f"  Total samples: {len(inputs)}")
    for letter, cls in (('A', 0), ('B', 1)):
        print(f"  Target '{letter}' ({cls}): {(targets == cls).sum().item()}")
    print(f"  Rare 'B' indices: {rare_indices}")

    # One shared initialization so both runs start from identical weights.
    set_seeds(SEED)
    init_model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    init_weights = {name: param.clone() for name, param in init_model.state_dict().items()}

    init_eff_dim = compute_effective_dimension(init_model.get_embeddings())
    print(f"\nInitial embedding effective dimension: {init_eff_dim:.3f}")

    shared_kwargs = dict(n_epochs=3, lr=0.1, init_weights=init_weights, track_every=5)
    metrics_no_clip = train_with_tracking(inputs, targets, rare_indices,
                                          clip_grad=False, **shared_kwargs)
    metrics_with_clip = train_with_tracking(inputs, targets, rare_indices,
                                            clip_grad=True, max_norm=1.0, **shared_kwargs)

    print("\n" + rule)
    print("GENERATING ANALYSIS PLOTS")
    print(rule)

    plot_effective_dimension_comparison(metrics_no_clip, metrics_with_clip, rare_indices,
                                        "effective_dimension_comparison.png")
    plot_class_accuracy_comparison(metrics_no_clip, metrics_with_clip,
                                   "class_accuracy_comparison.png")
    plot_combined_analysis(metrics_no_clip, metrics_with_clip, rare_indices,
                           "combined_analysis.png")

    print("\n" + rule)
    print("PREDICTION TEST RESULTS")
    print(rule)

    # --- Prediction 2: representation collapse -----------------------------
    print("\n📊 PREDICTION 2: Representation Collapse")
    print("-" * 50)

    dims_no = metrics_no_clip['effective_dims']
    dims_with = metrics_with_clip['effective_dims']

    print(f"Effective Dimension Statistics:")
    for label, dims in (("WITHOUT", dims_no), ("WITH", dims_with)):
        print(f"  {label} Clipping:")
        print(f"    Initial: {dims[0]:.3f}")
        print(f"    Final:   {dims[-1]:.3f}")
        print(f"    Min:     {min(dims):.3f}")
        print(f"    Max:     {max(dims):.3f}")
        print(f"    Std:     {np.std(dims):.3f}")

    # Spread of the effective dimension is used as the proxy for sudden drops.
    collapse_supported = np.std(dims_no) > np.std(dims_with)
    spread_word = 'higher' if collapse_supported else 'lower'
    print(f"\n  Verdict: {verdict(collapse_supported)}")
    print(f"  (Without clipping has {spread_word} variance in effective dim)")

    # --- Prediction 4: rare sample learning --------------------------------
    print("\n📊 PREDICTION 4: Rare Sample Learning")
    print("-" * 50)

    final_acc_b_no = metrics_no_clip['class_accuracies'][1][-1]
    final_acc_b_with = metrics_with_clip['class_accuracies'][1][-1]

    print(f"Final Rare Class 'B' Accuracy:")
    print(f"  WITHOUT Clipping: {final_acc_b_no:.1%}")
    print(f"  WITH Clipping:    {final_acc_b_with:.1%}")
    print(f"  Difference:       {final_acc_b_with - final_acc_b_no:+.1%}")

    rare_learning_supported = final_acc_b_with >= final_acc_b_no
    print(f"\n  Verdict: {verdict(rare_learning_supported)}")

    return {
        'metrics_no_clip': metrics_no_clip,
        'metrics_with_clip': metrics_with_clip,
        'rare_indices': rare_indices,
        'prediction_2_supported': collapse_supported,
        'prediction_4_supported': rare_learning_supported,
    }
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Keep the returned results at module level for interactive inspection.
    results = main()
    for banner_line in ("\n" + "=" * 70, "EXPERIMENT COMPLETE!", "=" * 70):
        print(banner_line)
|
|
|