|
|
""" |
|
|
Extended Gradient Clipping Experiment V2: Testing Physics-of-AI Predictions |
|
|
|
|
|
Key changes from V1: |
|
|
1. More epochs (10 instead of 3) to allow rare class learning |
|
|
2. Smaller learning rate (0.01) for more stable training |
|
|
3. More frequent tracking to catch dynamics |
|
|
4. Added loss tracking per class to understand learning dynamics |
|
|
|
|
|
Predictions being tested: |
|
|
- Prediction 2: Representation Collapse (effective dimensionality drops without clipping) |
|
|
- Prediction 4: Rare Sample Learning (clipping improves rare class accuracy) |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import random |
|
|
from typing import Dict, List, Tuple |
|
|
|
|
|
# One seed shared by dataset creation, initial weights and both training runs,
# so the clipped and unclipped runs differ only in whether clipping is applied.
SEED: int = 42
|
|
|
|
|
|
|
|
def set_seeds(seed=SEED):
    """Seed every random number generator this script draws from.

    PyTorch, NumPy and Python's ``random`` module are seeded, in that order,
    with the same value.

    Args:
        seed: Integer seed; defaults to the module-level ``SEED``.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
|
|
|
|
|
|
|
|
class SimpleNextTokenModel(nn.Module):
    """Minimal next-token model: an embedding table followed by a linear head.

    Args:
        vocab_size: Number of tokens; also the number of output logits.
        embedding_dim: Width of each token embedding.
    """

    def __init__(self, vocab_size=4, embedding_dim=16):
        super().__init__()
        # Creation order matters for reproducible initialization under a fixed seed.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` to unnormalized logits over the vocabulary."""
        return self.linear(self.embedding(x))

    def get_embeddings(self):
        """Return a gradient-free copy of the embedding table.

        The result has shape (vocab_size, embedding_dim); modifying it does
        not affect the model.
        """
        return self.embedding.weight.detach().clone()
|
|
|
|
|
|
|
|
def compute_effective_dimension(embedding_matrix: torch.Tensor) -> float:
    """PCA-based effective dimensionality using entropy.

    The eigenvalue spectrum of the unbiased covariance of the rows of
    ``embedding_matrix`` is normalized into a probability distribution, and
    ``exp(H)`` is returned, where ``H`` is its Shannon entropy. A value of 1
    means all variance lies along a single direction. A value of ``k`` means
    the variance is spread equally over ``k`` directions.

    Args:
        embedding_matrix: 2-D tensor with one row per item (e.g. token embedding).

    Returns:
        The effective dimension as a Python float. Returns 0.0 when the rows
        have no spread at all (fewer than two rows, or all rows identical),
        because such a set spans no directions.
    """
    n_rows = embedding_matrix.shape[0]
    if n_rows < 2:
        # Unbiased covariance divides by n - 1, which is zero here.
        return 0.0

    centered = embedding_matrix - embedding_matrix.mean(dim=0, keepdim=True)
    cov = torch.mm(centered.T, centered) / (n_rows - 1)

    # The covariance is symmetric PSD; slightly negative eigenvalues are round-off.
    eigenvalues = torch.linalg.eigvalsh(cov).clamp(min=0.0)
    total = eigenvalues.sum()
    if total.item() <= 1e-10:
        # Do not renormalize a (near-)zero spectrum. Flooring every eigenvalue
        # and renormalizing would produce a uniform distribution and report
        # the MAXIMUM dimension for a fully collapsed representation.
        return 0.0

    probs = eigenvalues / total
    # xlogy(0, 0) == 0, so directions with zero variance contribute no entropy.
    entropy = -torch.xlogy(probs, probs).sum()
    return torch.exp(entropy).item()
|
|
|
|
|
|
|
|
def compute_per_class_accuracy(model: nn.Module, inputs: torch.Tensor,
                               targets: torch.Tensor,
                               num_classes: int = 4) -> Dict[int, float]:
    """Compute accuracy for each target class.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored before returning, so the helper can be
    called in the middle of a training loop without side effects.

    Args:
        model: Module mapping ``inputs`` to logits whose class axis is dim 1.
        inputs: Model inputs for the evaluation samples.
        targets: Integer class labels, one per sample.
        num_classes: Report classes ``0 .. num_classes - 1`` (default 4,
            matching the 4-token vocabulary used in this script).

    Returns:
        Dict mapping each class index to the fraction of that class's samples
        predicted correctly, or ``None`` if the class has no samples.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            predictions = model(inputs).argmax(dim=1)
    finally:
        # Restore the caller's mode; the original helper left the model in eval mode.
        model.train(was_training)

    accuracies = {}
    for class_idx in range(num_classes):
        mask = targets == class_idx
        if mask.any():
            accuracies[class_idx] = (predictions[mask] == targets[mask]).float().mean().item()
        else:
            accuracies[class_idx] = None
    return accuracies
|
|
|
|
|
|
|
|
def compute_per_class_loss(model: nn.Module, inputs: torch.Tensor,
                           targets: torch.Tensor, criterion: nn.Module,
                           num_classes: int = 4) -> Dict[int, float]:
    """Compute the loss for each target class.

    For each class, ``criterion`` is applied to that class's samples only.
    With its default ``reduction='mean'`` (as used by ``nn.CrossEntropyLoss()``
    in this script), the result is the average loss over the class.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored before returning.

    Args:
        model: Module mapping ``inputs`` to logits.
        inputs: Model inputs for the evaluation samples.
        targets: Integer class labels, one per sample.
        criterion: Loss module called as ``criterion(logits, targets)``.
        num_classes: Report classes ``0 .. num_classes - 1`` (default 4).

    Returns:
        Dict mapping each class index to its loss, or ``None`` if the class
        has no samples.
    """
    was_training = model.training
    model.eval()
    losses = {}
    try:
        with torch.no_grad():
            logits = model(inputs)
            for class_idx in range(num_classes):
                mask = targets == class_idx
                if mask.any():
                    losses[class_idx] = criterion(logits[mask], targets[mask]).item()
                else:
                    losses[class_idx] = None
    finally:
        # Restore the caller's mode; the original helper left the model in eval mode.
        model.train(was_training)
    return losses
|
|
|
|
|
|
|
|
def create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED):
    """Build a token dataset in which a handful of samples carry a rare label.

    Every sample gets a random input token in ``[0, 4)`` and label 0, except
    ``n_rare`` randomly chosen positions, which are relabelled 1.

    NOTE(review): the rare positions are drawn independently of the input
    tokens, so label 1 is not predictable from the input. Any token-level
    rule that predicts 1 also misclassifies every common sample sharing that
    token. Confirm this is intended for the "rare sample learning" test.

    Args:
        n_samples: Total number of samples.
        n_rare: Number of samples labelled 1.
        seed: Seed passed to ``set_seeds`` before any sampling.

    Returns:
        ``(inputs, targets, rare_indices)``, where ``rare_indices`` is the
        sorted list of positions labelled 1.
    """
    set_seeds(seed)
    tokens = torch.randint(0, 4, (n_samples,))
    labels = torch.zeros(n_samples, dtype=torch.long)
    rare_positions = sorted(random.sample(range(n_samples), n_rare))
    labels[rare_positions] = 1
    return tokens, labels, rare_positions
|
|
|
|
|
|
|
|
def _format_optional(value, spec: str = ".3f") -> str:
    """Format ``value`` with ``spec``, or return "N/A" when it is None."""
    return "N/A" if value is None else format(value, spec)


def _record_snapshot(model, inputs, targets, criterion, metrics, step):
    """Append effective dimension and per-class accuracy/loss at ``step`` to ``metrics``.

    Classes with no samples (the helpers return None) are stored as 0.0, so
    every per-class series stays aligned with ``metrics['accuracy_steps']``.
    """
    metrics['effective_dims'].append(compute_effective_dimension(model.get_embeddings()))
    metrics['effective_dim_steps'].append(step)

    class_acc = compute_per_class_accuracy(model, inputs, targets)
    class_loss = compute_per_class_loss(model, inputs, targets, criterion)
    for cls_idx in range(4):
        acc = class_acc[cls_idx]
        cls_loss = class_loss[cls_idx]
        metrics['class_accuracies'][cls_idx].append(0.0 if acc is None else acc)
        metrics['class_losses'][cls_idx].append(0.0 if cls_loss is None else cls_loss)
    metrics['accuracy_steps'].append(step)


def train_with_tracking(inputs: torch.Tensor, targets: torch.Tensor,
                        rare_indices: List[int], clip_grad: bool = False,
                        max_norm: float = 1.0, n_epochs: int = 10,
                        lr: float = 0.01, init_weights=None,
                        track_every: int = 50) -> Dict:
    """
    Extended training with comprehensive tracking.

    Trains a fresh ``SimpleNextTokenModel`` with plain SGD, one sample per
    step, for ``n_epochs`` passes over ``inputs``/``targets``. If
    ``init_weights`` is given, it is loaded as the starting point.

    Args:
        inputs, targets: Full dataset, iterated in order every epoch.
        rare_indices: Sample positions whose per-step loss is logged separately.
        clip_grad: If True, clip the total gradient norm to ``max_norm``.
        max_norm: Clipping threshold (also printed in the run header).
        n_epochs: Number of passes over the data.
        lr: SGD learning rate.
        init_weights: Optional state dict to copy into the model.
        track_every: Take an evaluation snapshot every this many steps,
            starting at step 0.

    Returns:
        Metrics dict:
          ``losses``, ``grad_norms``, ``weight_norms``: one entry per step.
            ``grad_norms`` is the total L2 norm BEFORE any clipping.
          ``effective_dims``/``effective_dim_steps`` and ``class_accuracies``/
            ``class_losses``/``accuracy_steps``: snapshots on the full dataset.
          ``rare_sample_losses``/``rare_sample_steps``: loss at each step
            that processed a sample in ``rare_indices``.
    """
    set_seeds(SEED)
    model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    if init_weights:
        model.load_state_dict({k: v.clone() for k, v in init_weights.items()})

    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    metrics = {
        'losses': [],
        'grad_norms': [],
        'weight_norms': [],
        'effective_dims': [],
        'effective_dim_steps': [],
        'class_accuracies': {0: [], 1: [], 2: [], 3: []},
        'class_losses': {0: [], 1: [], 2: [], 3: []},
        'accuracy_steps': [],
        'rare_sample_losses': [],
        'rare_sample_steps': [],
    }

    mode = "WITH" if clip_grad else "WITHOUT"
    print(f"\n{'='*60}")
    print(f"Training {mode} gradient clipping (max_norm={max_norm})")
    print(f"Learning rate: {lr}, Epochs: {n_epochs}")
    print(f"{'='*60}")

    # O(1) membership test on every step instead of scanning the list.
    rare_set = set(rare_indices)
    step = 0
    n_samples = len(inputs)

    for epoch in range(n_epochs):
        model.train()
        epoch_losses = []

        for i in range(n_samples):
            x = inputs[i:i+1]
            y = targets[i:i+1]

            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()

            # With max_norm=inf this only measures the total norm (gradients
            # are scaled by 1.0), so the recorded value is the pre-clip norm.
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf'))
            if clip_grad:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            optimizer.step()

            loss_value = loss.item()
            metrics['losses'].append(loss_value)
            metrics['grad_norms'].append(grad_norm.item())
            total_norm = sum(p.data.norm(2).item() ** 2 for p in model.parameters()) ** 0.5
            metrics['weight_norms'].append(total_norm)
            epoch_losses.append(loss_value)

            if i in rare_set:
                metrics['rare_sample_losses'].append(loss_value)
                metrics['rare_sample_steps'].append(step)

            if step % track_every == 0:
                _record_snapshot(model, inputs, targets, criterion, metrics, step)
                # The evaluation helpers may switch the model to eval mode;
                # make sure training resumes in train mode.
                model.train()

            step += 1

        avg_loss = np.mean(epoch_losses)
        class_acc = compute_per_class_accuracy(model, inputs, targets)
        class_loss = compute_per_class_loss(model, inputs, targets, criterion)
        eff_dim = compute_effective_dimension(model.get_embeddings())

        # Both classes are None-guarded; previously only class B was.
        a_acc = _format_optional(class_acc[0])
        b_acc = _format_optional(class_acc[1])
        a_loss = _format_optional(class_loss[0])
        b_loss = _format_optional(class_loss[1])

        print(f"Epoch {epoch+1:2d}/{n_epochs}: Loss={avg_loss:.4f} | "
              f"Acc A={a_acc} B={b_acc} | "
              f"Loss A={a_loss} B={b_loss} | "
              f"EffDim={eff_dim:.3f}")

    return metrics
|
|
|
|
|
|
|
|
def plot_comprehensive_analysis(metrics_no_clip: Dict, metrics_with_clip: Dict,
                                rare_indices: List[int], filename: str,
                                n_samples: int = 1000, lr: float = 0.01,
                                max_norm: float = 1.0):
    """Create comprehensive 8-panel analysis and save it to ``filename``.

    Layout (4 rows x 2 columns):
      1. effective dimension over training, one panel per run (shared y-range);
      2. accuracy of class A (0) and class B (1), both runs overlaid;
      3. loss of class A and class B, both runs overlaid;
      4. per-step gradient norm (as recorded, i.e. before clipping) and
         total weight norm, both runs overlaid.

    Args:
        metrics_no_clip: Metrics dict from the run without clipping.
        metrics_with_clip: Metrics dict from the run with clipping.
        rare_indices: Rare-sample positions; only their count is used (title).
        filename: Output image path.
        n_samples: Samples per epoch; used to derive the epoch count for the title.
        lr: Learning rate shown in the title (the metrics do not record it).
        max_norm: Clipping threshold drawn on the gradient-norm panel.
    """
    n_epochs = len(metrics_no_clip['losses']) // n_samples if n_samples > 0 else 0
    n_rare = len(rare_indices)
    runs = ((metrics_no_clip, 'r-', 'Without Clip'),
            (metrics_with_clip, 'g-', 'With Clip'))

    # Keep the original [2.0, 3.5] window but widen it when any value falls
    # outside, so a drop below 2.0 (the predicted collapse) is not hidden.
    all_dims = list(metrics_no_clip['effective_dims']) + list(metrics_with_clip['effective_dims'])
    dim_ylim = [2.0, 3.5]
    if all_dims:
        dim_ylim = [min(dim_ylim[0], min(all_dims) - 0.05),
                    max(dim_ylim[1], max(all_dims) + 0.05)]

    fig = plt.figure(figsize=(20, 16))
    try:
        gs = fig.add_gridspec(4, 2, hspace=0.35, wspace=0.25)

        # Row 1: effective dimension, one panel per run.
        dim_panels = (
            (metrics_no_clip, 'b-', 'Effective Dim - WITHOUT Clipping', 'red'),
            (metrics_with_clip, 'g-', 'Effective Dim - WITH Clipping', 'green'),
        )
        for col, (metrics, style, title, title_color) in enumerate(dim_panels):
            ax = fig.add_subplot(gs[0, col])
            ax.plot(metrics['effective_dim_steps'], metrics['effective_dims'],
                    style, linewidth=2, marker='o', markersize=3)
            if col == 0:
                ax.set_ylabel('Effective Dimension', fontsize=11)
            ax.set_title(title, fontsize=12, fontweight='bold', color=title_color)
            ax.grid(True, alpha=0.3)
            ax.set_ylim(dim_ylim)

        def compare(ax, series_key, cls_idx, title, ylabel=None, title_color=None, ylim=None):
            """Overlay one per-class snapshot series from both runs on ``ax``."""
            for metrics, style, label in runs:
                ax.plot(metrics['accuracy_steps'], metrics[series_key][cls_idx],
                        style, linewidth=2, alpha=0.7, label=label)
            if ylabel is not None:
                ax.set_ylabel(ylabel, fontsize=11)
            title_kw = {} if title_color is None else {'color': title_color}
            ax.set_title(title, fontsize=12, fontweight='bold', **title_kw)
            ax.legend()
            ax.grid(True, alpha=0.3)
            if ylim is not None:
                ax.set_ylim(ylim)

        # Rows 2-3: per-class accuracy and loss.
        compare(fig.add_subplot(gs[1, 0]), 'class_accuracies', 0,
                "Common Class 'A' Accuracy", ylabel='Accuracy', ylim=[0, 1.05])
        compare(fig.add_subplot(gs[1, 1]), 'class_accuracies', 1,
                "Rare Class 'B' Accuracy [KEY PREDICTION]", title_color='purple',
                ylim=[0, 1.05])
        compare(fig.add_subplot(gs[2, 0]), 'class_losses', 0,
                "Common Class 'A' Loss", ylabel='Loss')
        compare(fig.add_subplot(gs[2, 1]), 'class_losses', 1, "Rare Class 'B' Loss")

        # Row 4: per-step gradient and weight norms. Each run uses its own
        # step axis, so runs of different lengths still plot correctly.
        ax7 = fig.add_subplot(gs[3, 0])
        ax8 = fig.add_subplot(gs[3, 1])
        for metrics, style, label in runs:
            ax7.plot(range(len(metrics['grad_norms'])), metrics['grad_norms'],
                     style, alpha=0.3, linewidth=0.5, label=label)
            ax8.plot(range(len(metrics['weight_norms'])), metrics['weight_norms'],
                     style, alpha=0.7, linewidth=1, label=label)

        ax7.axhline(y=max_norm, color='black', linestyle='--', linewidth=2, label='Clip threshold')
        ax7.set_ylabel('Gradient Norm', fontsize=11)
        ax7.set_xlabel('Training Step', fontsize=11)
        ax7.set_title('Gradient Norms', fontsize=12, fontweight='bold')
        ax7.legend()
        ax7.grid(True, alpha=0.3)

        ax8.set_xlabel('Training Step', fontsize=11)
        ax8.set_title('Weight Norms', fontsize=12, fontweight='bold')
        ax8.legend()
        ax8.grid(True, alpha=0.3)

        # Title is derived from the data rather than hard-coded.
        fig.suptitle('Extended Gradient Clipping Analysis: Testing Physics-of-AI Predictions\n'
                     f'({n_epochs} epochs, lr={lr}, {n_samples - n_rare} common / {n_rare} rare samples)',
                     fontsize=14, fontweight='bold', y=1.01)

        fig.savefig(filename, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    print(f"Comprehensive analysis saved to: {filename}")
|
|
|
|
|
|
|
|
def plot_rare_sample_dynamics(metrics_no_clip: Dict, metrics_with_clip: Dict,
                              filename: str, max_norm: float = 1.0):
    """Plot dynamics specifically at rare sample positions and save to ``filename``.

    Panels:
      1. loss at each rare-sample step;
      2. histogram of those losses;
      3. recorded (pre-clipping) gradient norms at rare-sample steps;
      4. a text summary.

    Rare-sample steps are read from each run's ``rare_sample_steps`` (recorded
    during training). The positions are therefore correct for any dataset,
    seed or epoch length, instead of relying on a hard-coded index list.

    Args:
        metrics_no_clip: Metrics dict from the run without clipping.
        metrics_with_clip: Metrics dict from the run with clipping.
        filename: Output image path.
        max_norm: Clipping threshold drawn on the gradient-norm panel.
    """

    def mean_or_nan(values):
        # np.mean([]) would warn and return nan; make that explicit.
        return float(np.mean(values)) if len(values) else float('nan')

    def pct_reduction(before, after):
        # Relative drop in percent; NaN instead of dividing by zero.
        return (before - after) / before * 100 if before else float('nan')

    def rare_grad_norms(metrics):
        norms = metrics['grad_norms']
        steps = [s for s in metrics['rare_sample_steps'] if s < len(norms)]
        return steps, [norms[s] for s in steps]

    mean_rare_loss_no = mean_or_nan(metrics_no_clip['rare_sample_losses'])
    mean_rare_loss_with = mean_or_nan(metrics_with_clip['rare_sample_losses'])
    rare_steps_no, rare_grad_norms_no = rare_grad_norms(metrics_no_clip)
    rare_steps_with, rare_grad_norms_with = rare_grad_norms(metrics_with_clip)
    mean_rare_grad_no = mean_or_nan(rare_grad_norms_no)
    mean_rare_grad_with = mean_or_nan(rare_grad_norms_with)

    final_acc_b_no = metrics_no_clip['class_accuracies'][1][-1] if metrics_no_clip['class_accuracies'][1] else 0
    final_acc_b_with = metrics_with_clip['class_accuracies'][1][-1] if metrics_with_clip['class_accuracies'][1] else 0

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    try:
        # Panel 1: loss each time a rare sample is processed.
        ax1 = axes[0, 0]
        ax1.plot(metrics_no_clip['rare_sample_steps'], metrics_no_clip['rare_sample_losses'],
                 'ro-', alpha=0.7, markersize=3, linewidth=0.5, label='Without Clip')
        ax1.plot(metrics_with_clip['rare_sample_steps'], metrics_with_clip['rare_sample_losses'],
                 'go-', alpha=0.7, markersize=3, linewidth=0.5, label='With Clip')
        ax1.set_ylabel('Loss at Rare Sample', fontsize=11)
        ax1.set_title('Loss When Encountering Rare Samples', fontsize=12, fontweight='bold')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Panel 2: distribution of those losses.
        ax2 = axes[0, 1]
        ax2.hist(metrics_no_clip['rare_sample_losses'], bins=30, alpha=0.5, color='red',
                 label=f"Without Clip (mean={mean_rare_loss_no:.3f})")
        ax2.hist(metrics_with_clip['rare_sample_losses'], bins=30, alpha=0.5, color='green',
                 label=f"With Clip (mean={mean_rare_loss_with:.3f})")
        ax2.set_xlabel('Loss', fontsize=11)
        ax2.set_ylabel('Count', fontsize=11)
        ax2.set_title('Distribution of Rare Sample Losses', fontsize=12, fontweight='bold')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # Panel 3: recorded gradient norms at rare-sample steps. These are
        # measured before clipping, in both runs.
        ax3 = axes[1, 0]
        ax3.scatter(rare_steps_no, rare_grad_norms_no, c='red', alpha=0.6, s=20, label='Without Clip')
        ax3.scatter(rare_steps_with, rare_grad_norms_with, c='green', alpha=0.6, s=20, label='With Clip')
        ax3.axhline(y=max_norm, color='black', linestyle='--', linewidth=2, label='Clip threshold')
        ax3.set_xlabel('Training Step', fontsize=11)
        ax3.set_ylabel('Gradient Norm (pre-clip)', fontsize=11)
        ax3.set_title('Gradient Norms at Rare Sample Positions', fontsize=12, fontweight='bold')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Panel 4: text summary. Plain ASCII rules replace the mis-encoded
        # box-drawing characters of the original.
        ax4 = axes[1, 1]
        ax4.axis('off')
        major_rule = "=" * 52
        minor_rule = "-" * 52
        summary_text = f"""
RARE SAMPLE DYNAMICS SUMMARY
{major_rule}

At Rare Sample Positions:
{minor_rule}
Mean Loss (WITHOUT Clipping):      {mean_rare_loss_no:.4f}
Mean Loss (WITH Clipping):         {mean_rare_loss_with:.4f}
Loss Reduction:                    {pct_reduction(mean_rare_loss_no, mean_rare_loss_with):+.1f}%

Mean Pre-Clip Grad Norm (WITHOUT): {mean_rare_grad_no:.4f}
Mean Pre-Clip Grad Norm (WITH):    {mean_rare_grad_with:.4f}
Grad Norm Reduction:               {pct_reduction(mean_rare_grad_no, mean_rare_grad_with):+.1f}%

Final Rare Class Accuracy:
{minor_rule}
WITHOUT Clipping: {final_acc_b_no:.1%}
WITH Clipping:    {final_acc_b_with:.1%}

{major_rule}
PHYSICS-OF-AI INTERPRETATION:

Gradient clipping acts as a "velocity limiter" in
weight space, preventing the model from making
sudden large updates when encountering rare samples.

This allows the model to gradually learn the rare
class pattern rather than overshooting and forgetting.
"""
        ax4.text(0.05, 0.5, summary_text, transform=ax4.transAxes,
                 fontsize=10, verticalalignment='center', fontfamily='monospace',
                 bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.9))

        fig.suptitle('Rare Sample Dynamics Analysis\n'
                     '(How the model behaves when encountering rare class B samples)',
                     fontsize=14, fontweight='bold', y=1.01)
        fig.tight_layout()
        fig.savefig(filename, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    print(f"Rare sample dynamics plot saved to: {filename}")
|
|
|
|
|
|
|
|
def main():
    """Run the clipped vs. unclipped experiment, save plots and print verdicts.

    Both runs share the same dataset, initial weights and hyperparameters and
    differ only in gradient clipping.

    Returns:
        Dict with ``metrics_no_clip``, ``metrics_with_clip`` and ``rare_indices``.
    """
    banner = "=" * 70
    print(banner)
    print("EXTENDED GRADIENT CLIPPING EXPERIMENT V2")
    print("Testing Physics-of-AI Predictions with Extended Training")
    print(banner)

    # Dataset.
    inputs, targets, rare_indices = create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED)
    n_common = (targets == 0).sum().item()
    n_rare = (targets == 1).sum().item()
    print(f"\nDataset: {len(inputs)} samples ({n_common} common, {n_rare} rare)")
    print(f"Rare indices: {rare_indices}")

    # Shared initial weights, so both runs start from the same point.
    set_seeds(SEED)
    init_model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    init_weights = {name: param.clone() for name, param in init_model.state_dict().items()}
    init_eff_dim = compute_effective_dimension(init_model.get_embeddings())
    print(f"Initial effective dimension: {init_eff_dim:.3f}")

    n_epochs = 10
    lr = 0.01
    max_norm = 1.0

    metrics_no_clip = train_with_tracking(
        inputs, targets, rare_indices,
        clip_grad=False, max_norm=max_norm, n_epochs=n_epochs, lr=lr,
        init_weights=init_weights, track_every=100
    )
    metrics_with_clip = train_with_tracking(
        inputs, targets, rare_indices,
        clip_grad=True, max_norm=max_norm, n_epochs=n_epochs, lr=lr,
        init_weights=init_weights, track_every=100
    )

    print("\n" + banner)
    print("GENERATING ANALYSIS PLOTS")
    print(banner)

    plot_comprehensive_analysis(
        metrics_no_clip, metrics_with_clip, rare_indices,
        "extended_analysis_v2.png", n_samples=len(inputs)
    )
    plot_rare_sample_dynamics(
        metrics_no_clip, metrics_with_clip,
        "rare_sample_dynamics.png"
    )

    print("\n" + banner)
    print("FINAL PREDICTION TEST RESULTS")
    print(banner)

    # Prediction 2. The printed values are standard deviations (np.std), not
    # variances. NOTE(review): the verdict compares volatility of the
    # effective dimension; the stated prediction is a *drop*, so the final
    # values are printed too for context.
    dims_no = metrics_no_clip['effective_dims']
    dims_with = metrics_with_clip['effective_dims']
    std_no = float(np.std(dims_no))
    std_with = float(np.std(dims_with))

    print("\n[PREDICTION 2] Representation Collapse:")
    print(f"  Effective Dim Std (WITHOUT): {std_no:.6f}")
    print(f"  Effective Dim Std (WITH):    {std_with:.6f}")
    print(f"  Final Effective Dim (WITHOUT): {dims_no[-1]:.3f} (initial {init_eff_dim:.3f})")
    print(f"  Final Effective Dim (WITH):    {dims_with[-1]:.3f} (initial {init_eff_dim:.3f})")
    print(f"  Verdict: {'SUPPORTED' if std_no > std_with else 'NOT SUPPORTED'}")

    # Prediction 4: clipping must strictly improve rare-class accuracy. A tie
    # (e.g. both runs at 0%) is no improvement and must not count as support.
    final_acc_b_no = metrics_no_clip['class_accuracies'][1][-1]
    final_acc_b_with = metrics_with_clip['class_accuracies'][1][-1]

    print("\n[PREDICTION 4] Rare Sample Learning:")
    print(f"  Final Rare Class Accuracy (WITHOUT): {final_acc_b_no:.1%}")
    print(f"  Final Rare Class Accuracy (WITH):    {final_acc_b_with:.1%}")
    print(f"  Verdict: {'SUPPORTED' if final_acc_b_with > final_acc_b_no else 'NOT SUPPORTED'}")

    return {
        'metrics_no_clip': metrics_no_clip,
        'metrics_with_clip': metrics_with_clip,
        'rare_indices': rare_indices,
    }
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the full experiment; the returned dict is kept for interactive inspection.
    results = main()
    banner = "=" * 70
    print("\n" + banner)
    print("EXPERIMENT COMPLETE!")
    print(banner)
|
|
|