# svd-triton / svd_conv_cifar100_analysis.py
# (Renamed from svd_conv_cifar100_test.py — commit e2643f8, author AbstractPhil)
# @title Analysis β€” SVD Test Model Geometric Structure
#
# Probes the trained model_svd_test for:
# 1. Embedding geometry: CV, cosine similarity, effective dimension
# 2. Per-class structure: inter/intra class cosine, class separation
# 3. SVD feature analysis: S spectrum, Vh structure per depth
# 4. Feature attribution: how much does SVD vs conv contribute?
# 5. kNN accuracy at different k values
"""
================================================================================
SVD MODEL GEOMETRIC ANALYSIS β€” 5000 samples
Val accuracy: 70.2%
================================================================================
────────────────────────────────────────────────────────────
1. EMBEDDING GEOMETRY
────────────────────────────────────────────────────────────
SVD (264-d):
CV: 0.1241
Cosine sim: 0.9989 Β± 0.0002
Eff dimension: 131.2 / 264
Energy top-5: 5.5% top-10: 10.4% top-20: 19.5%
Conv (384-d):
CV: 0.1320
Cosine sim: 0.5440 Β± 0.0678
Eff dimension: 201.0 / 384
Energy top-5: 7.6% top-10: 13.2% top-20: 22.2%
Combined (648-d):
CV: 0.4140
Cosine sim: 0.9510 Β± 0.0163
Eff dimension: 240.3 / 648
Energy top-5: 6.9% top-10: 11.9% top-20: 20.1%
────────────────────────────────────────────────────────────
2. PER-CLASS STRUCTURE (combined features)
────────────────────────────────────────────────────────────
Inter-class cosine: 0.9838 Β± 0.0067 (max: 0.9989)
Intra-class cosine: 0.9833 Β± 0.0023
Separation ratio: 1.00x (higher = better separated)
────────────────────────────────────────────────────────────
3. SVD SPECTRUM PER DEPTH
────────────────────────────────────────────────────────────
Depth 0 (32Γ—32, rank=32):
S mean: [24.44, 15.68, 12.95, 11.27, 9.76, 8.71...]
Energy: top-1=14.4% top-5=42.9% top-10=63.7%
Entropy: 3.132 / 3.466 (90% of max)
Vh diag: 3.1% energy on diagonal
Cond ratio: 26.5x
Depth 1 (16Γ—16, rank=32):
S mean: [16.59, 10.65, 9.05, 7.93, 7.07, 6.40...]
Energy: top-1=12.4% top-5=38.0% top-10=58.2%
Entropy: 3.233 / 3.466 (93% of max)
Vh diag: 3.1% energy on diagonal
Cond ratio: 14.8x
Depth 2 (8Γ—8, rank=32):
S mean: [6.35, 4.89, 4.21, 3.72, 3.35, 3.06...]
Energy: top-1=10.7% top-5=37.9% top-10=59.9%
Entropy: 3.214 / 3.466 (93% of max)
Vh diag: 3.0% energy on diagonal
Cond ratio: 18.2x
Depth 3 (4Γ—4, rank=32):
S mean: [7.20, 3.67, 2.70, 2.19, 1.82, 1.53...]
Energy: top-1=27.2% top-5=66.9% top-10=88.8%
Entropy: 2.372 / 3.466 (68% of max)
Vh diag: 3.1% energy on diagonal
Cond ratio: 7198608.5x
────────────────────────────────────────────────────────────
4. FEATURE ATTRIBUTION β€” SVD vs Conv
────────────────────────────────────────────────────────────
Full features: 70.2%
Zero SVD (264-d): 9.9% (drop: +60.3)
Zero conv (384-d): 0.8% (drop: +69.4)
SVD contribution: 60.3 points
Conv contribution: 69.4 points
Per-depth SVD ablation (zero one depth at a time):
Zero depth 0: 57.7% (drop: +12.6)
Zero depth 1: 58.4% (drop: +11.8)
Zero depth 2: 59.9% (drop: +10.4)
Zero depth 3: 61.5% (drop: +8.7)
────────────────────────────────────────────────────────────
5. kNN ACCURACY
────────────────────────────────────────────────────────────
SVD only kNN-1: 1.2%
SVD only kNN-5: 1.3%
SVD only kNN-10: 1.4%
Conv only kNN-1: 53.4%
Conv only kNN-5: 57.2%
Conv only kNN-10: 58.6%
Combined kNN-1: 48.5%
Combined kNN-5: 52.8%
Combined kNN-10: 55.7%
────────────────────────────────────────────────────────────
6. GEOMETRIC CONSTANTS
────────────────────────────────────────────────────────────
CV multi-scale: {3: 0.2597, 4: 0.3323, 5: 0.38, 6: 0.4283, 7: 0.4337, 8: 0.2682}
CV pentachoron: 0.3800 β€” outside band
Embedding norms: 65.7458 Β± 1.4738
================================================================================
ANALYSIS COMPLETE
================================================================================
"""
import math
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
@torch.no_grad()
def analyze_svd_model(model, val_loader, device, n_max=5000):
    """Comprehensive geometric analysis of the trained SVD test model.

    Re-runs the model stage by stage over ``val_loader`` to collect the
    per-depth SVD outputs (S, Vh), the pooled conv features, and the full
    classifier input, then prints a six-section report to stdout:

      1. Embedding geometry — CV, pairwise cosine, effective dimension
      2. Per-class structure — inter/intra-class cosine, separation ratio
      3. SVD spectrum per depth — energy, entropy, Vh diagonality, condition
      4. Feature attribution — zero-ablation of SVD / conv / per-depth dims
      5. kNN accuracy — SVD-only, conv-only, combined features
      6. Geometric constants — multi-scale CV, embedding-norm distribution

    Args:
        model: trained SVD test model. Must expose ``stages``, ``pools``,
            ``to_svd``, ``svd_rank``, ``final_pool``, ``classifier`` and
            ``_extract_svd_features``. The hard-coded slices below assume
            4 depths x 66 SVD dims (264 total) followed by 384 conv dims
            — TODO confirm against the model definition if it changes.
        val_loader: iterable yielding ``(images, labels)`` batches.
        device: torch device used for collection and analysis.
        n_max: cap on the number of validation samples analysed.

    Relies on module-level helpers ``gram_eigh_svd``, ``cv_metric``,
    ``knn_accuracy`` and ``cv_multi_scale`` defined elsewhere in the
    project. Returns None; all results are printed.
    """
    model.eval()
    model = model.to(device)
    # ── Collect all features ──
    all_svd_feats = []   # per-depth SVD features
    all_conv_feats = []  # pooled conv features
    all_combined = []    # full classifier input
    all_logits = []
    all_labels = []
    all_S = [[], [], [], []]   # singular values per depth
    all_Vh = [[], [], [], []]  # rotation matrices per depth
    n_collected = 0
    for images, labels in val_loader:
        if n_collected >= n_max:
            break
        images, labels = images.to(device), labels.to(device)
        B = images.shape[0]
        # Run through stages manually to collect intermediates
        h = images
        svd_feats_batch = []
        for i, (stage, pool, proj) in enumerate(zip(model.stages, model.pools, model.to_svd)):
            h = stage(h)
            h_svd = proj(h)
            H, W = h_svd.shape[2], h_svd.shape[3]
            h_flat = h_svd.permute(0, 2, 3, 1).reshape(B, H * W, model.svd_rank)
            # SVD in fp32 with autocast disabled: half precision is
            # numerically unstable for small singular values. (Grad is
            # already off via the @torch.no_grad() decorator, so the
            # former nested no_grad context was redundant and is gone.)
            with torch.amp.autocast('cuda', enabled=False):
                _, S, Vh = gram_eigh_svd(h_flat.float())
                S = S.clamp(min=1e-6)
            all_S[i].append(S.cpu())
            all_Vh[i].append(Vh.cpu())
            svd_feats_batch.append(model._extract_svd_features(S, Vh))
            h = pool(h)
        conv_feat = model.final_pool(h).flatten(1)
        combined = torch.cat(svd_feats_batch + [conv_feat], dim=-1)
        logits = model.classifier(combined)
        all_svd_feats.append(torch.cat(svd_feats_batch, dim=-1).cpu())
        all_conv_feats.append(conv_feat.cpu())
        all_combined.append(combined.cpu())
        all_logits.append(logits.cpu())
        all_labels.append(labels.cpu())
        n_collected += B
    # Truncate to exactly n_max (the last batch may overshoot).
    svd_feats = torch.cat(all_svd_feats)[:n_max]
    conv_feats = torch.cat(all_conv_feats)[:n_max]
    combined = torch.cat(all_combined)[:n_max]
    logits = torch.cat(all_logits)[:n_max]
    labels = torch.cat(all_labels)[:n_max]
    S_all = [torch.cat(s)[:n_max] for s in all_S]
    Vh_all = [torch.cat(v)[:n_max] for v in all_Vh]
    acc = (logits.argmax(-1) == labels).float().mean().item() * 100
    n = svd_feats.shape[0]
    print(f"\n{'='*80}")
    print(f" SVD MODEL GEOMETRIC ANALYSIS — {n} samples")
    print(f" Val accuracy: {acc:.1f}%")
    print(f"{'='*80}")
    # ════════════════════════════════════════════════════════════════════
    # 1. EMBEDDING GEOMETRY
    # ════════════════════════════════════════════════════════════════════
    print(f"\n{'─'*60}")
    print(f" 1. EMBEDDING GEOMETRY")
    print(f"{'─'*60}")
    for name, feats in [("SVD (264-d)", svd_feats), ("Conv (384-d)", conv_feats),
                        ("Combined (648-d)", combined)]:
        feats_n = F.normalize(feats.float(), dim=-1).to(device)
        # CV: coefficient-of-variation style metric (project helper)
        cv = cv_metric(feats_n, n_samples=200)
        # Pairwise cosine similarity on a subsample (full n x n is large)
        sub = feats_n[:min(2000, n)]
        sim = sub @ sub.T
        mask = ~torch.eye(sub.shape[0], dtype=torch.bool, device=device)
        cos_mean = sim[mask].mean().item()
        cos_std = sim[mask].std().item()
        # Effective (participation-ratio) dimension via singular values;
        # hoist the float()/slice so it is computed once, not twice.
        feats_sub = feats.float()[:2000]
        centered = feats_sub - feats_sub.mean(0)
        sv = torch.linalg.svdvals(centered.to(device))
        sv_norm = sv / sv.sum()
        eff_dim = (sv.sum() ** 2 / (sv ** 2).sum()).item()
        top5_energy = sv_norm[:5].sum().item()
        top10_energy = sv_norm[:10].sum().item()
        top20_energy = sv_norm[:20].sum().item()
        print(f"\n {name}:")
        print(f" CV: {cv:.4f}")
        print(f" Cosine sim: {cos_mean:.4f} ± {cos_std:.4f}")
        print(f" Eff dimension: {eff_dim:.1f} / {feats.shape[1]}")
        print(f" Energy top-5: {top5_energy*100:.1f}% top-10: {top10_energy*100:.1f}% top-20: {top20_energy*100:.1f}%")
    # ════════════════════════════════════════════════════════════════════
    # 2. PER-CLASS STRUCTURE
    # ════════════════════════════════════════════════════════════════════
    print(f"\n{'─'*60}")
    print(f" 2. PER-CLASS STRUCTURE (combined features)")
    print(f"{'─'*60}")
    combined_n = F.normalize(combined.float(), dim=-1).to(device)
    classes = labels.unique().sort().values
    # Class centroids: renormalized mean of the normalized embeddings.
    centroids = []
    for c in classes:
        mask_c = labels == c
        if mask_c.sum() > 0:
            centroids.append(F.normalize(combined_n[mask_c].mean(0, keepdim=True), dim=-1))
    centroids = torch.cat(centroids)  # (n_classes, D)
    # Inter-class cosine (centroid-to-centroid)
    inter_sim = centroids @ centroids.T
    inter_mask = ~torch.eye(centroids.shape[0], dtype=torch.bool, device=device)
    inter_mean = inter_sim[inter_mask].mean().item()
    inter_std = inter_sim[inter_mask].std().item()
    inter_max = inter_sim[inter_mask].max().item()
    # Intra-class cosine (samples to their centroid). Indexing centroids[i]
    # is safe: every class from labels.unique() has >=1 sample, so the
    # centroid list above is parallel to `classes`.
    intra_cos = []
    for i, c in enumerate(classes):
        mask_c = labels == c
        if mask_c.sum() > 1:
            sims = (combined_n[mask_c] @ centroids[i]).mean().item()
            intra_cos.append(sims)
    intra_mean = np.mean(intra_cos)
    intra_std = np.std(intra_cos)
    # Class separation ratio (intra should exceed inter when separated)
    sep_ratio = intra_mean / (inter_mean + 1e-8)
    print(f" Inter-class cosine: {inter_mean:.4f} ± {inter_std:.4f} (max: {inter_max:.4f})")
    print(f" Intra-class cosine: {intra_mean:.4f} ± {intra_std:.4f}")
    print(f" Separation ratio: {sep_ratio:.2f}x (higher = better separated)")
    # ════════════════════════════════════════════════════════════════════
    # 3. SVD FEATURE ANALYSIS PER DEPTH
    # ════════════════════════════════════════════════════════════════════
    print(f"\n{'─'*60}")
    print(f" 3. SVD SPECTRUM PER DEPTH")
    print(f"{'─'*60}")
    spatial_names = ["32×32", "16×16", "8×8", "4×4"]
    for i in range(4):
        S = S_all[i]    # (n, k)
        Vh = Vh_all[i]  # (n, k, k)
        k = S.shape[1]
        # Mean spectrum across samples (unused s_std removed)
        s_mean = S.mean(0)
        # Energy concentration in the leading singular values
        s_norm = S / S.sum(dim=-1, keepdim=True).clamp(min=1e-8)
        top1_energy = s_norm[:, 0].mean().item()
        top5_energy = s_norm[:, :5].sum(-1).mean().item()
        top10_energy = s_norm[:, :min(10, k)].sum(-1).mean().item()
        # Spectral entropy vs. the log(k) maximum of a flat spectrum
        s_ent = -(s_norm * (s_norm.clamp(min=1e-8)).log()).sum(-1).mean().item()
        max_ent = math.log(k)
        # Vh structure: how diagonal is the rotation?
        vh_diag = Vh.diagonal(dim1=-2, dim2=-1)  # (n, k)
        diag_energy = vh_diag.pow(2).sum(-1).mean().item()
        total_energy = Vh.pow(2).sum((-2, -1)).mean().item()
        diag_ratio = diag_energy / (total_energy + 1e-8)
        # Condition number proxy: ratio of largest to smallest S
        cond = (S[:, 0] / S[:, -1].clamp(min=1e-8)).mean().item()
        print(f"\n Depth {i} ({spatial_names[i]}, rank={k}):")
        print(f" S mean: [{', '.join(f'{v:.2f}' for v in s_mean[:6].tolist())}{'...' if k > 6 else ''}]")
        print(f" Energy: top-1={top1_energy*100:.1f}% top-5={top5_energy*100:.1f}% top-10={top10_energy*100:.1f}%")
        print(f" Entropy: {s_ent:.3f} / {max_ent:.3f} ({s_ent/max_ent*100:.0f}% of max)")
        print(f" Vh diag: {diag_ratio*100:.1f}% energy on diagonal")
        print(f" Cond ratio: {cond:.1f}x")
    # ════════════════════════════════════════════════════════════════════
    # 4. FEATURE ATTRIBUTION — SVD vs Conv
    # ════════════════════════════════════════════════════════════════════
    print(f"\n{'─'*60}")
    print(f" 4. FEATURE ATTRIBUTION — SVD vs Conv")
    print(f"{'─'*60}")
    # Test classifier with zeroed-out SVD or conv features
    model_device = next(model.parameters()).device
    # Full features
    full_logits = model.classifier(combined.to(model_device))
    full_acc = (full_logits.argmax(-1) == labels.to(model_device)).float().mean().item() * 100
    # Zero SVD features (first 264 dims = 4 depths x 66)
    no_svd = combined.clone()
    no_svd[:, :264] = 0.0
    no_svd_logits = model.classifier(no_svd.to(model_device))
    no_svd_acc = (no_svd_logits.argmax(-1) == labels.to(model_device)).float().mean().item() * 100
    # Zero conv features (last 384 dims)
    no_conv = combined.clone()
    no_conv[:, 264:] = 0.0
    no_conv_logits = model.classifier(no_conv.to(model_device))
    no_conv_acc = (no_conv_logits.argmax(-1) == labels.to(model_device)).float().mean().item() * 100
    # Per-depth SVD ablation (66 dims per depth)
    depth_accs = []
    for d in range(4):
        ablated = combined.clone()
        start = d * 66
        ablated[:, start:start+66] = 0.0
        abl_logits = model.classifier(ablated.to(model_device))
        abl_acc = (abl_logits.argmax(-1) == labels.to(model_device)).float().mean().item() * 100
        depth_accs.append(abl_acc)
    print(f" Full features: {full_acc:.1f}%")
    print(f" Zero SVD (264-d): {no_svd_acc:.1f}% (drop: {full_acc - no_svd_acc:+.1f})")
    print(f" Zero conv (384-d): {no_conv_acc:.1f}% (drop: {full_acc - no_conv_acc:+.1f})")
    print(f" SVD contribution: {full_acc - no_svd_acc:.1f} points")
    print(f" Conv contribution: {full_acc - no_conv_acc:.1f} points")
    print(f"\n Per-depth SVD ablation (zero one depth at a time):")
    for d in range(4):
        drop = full_acc - depth_accs[d]
        print(f" Zero depth {d}: {depth_accs[d]:.1f}% (drop: {drop:+.1f})")
    # ════════════════════════════════════════════════════════════════════
    # 5. kNN ACCURACY
    # ════════════════════════════════════════════════════════════════════
    print(f"\n{'─'*60}")
    print(f" 5. kNN ACCURACY")
    print(f"{'─'*60}")
    for name, feats in [("SVD only", svd_feats), ("Conv only", conv_feats),
                        ("Combined", combined)]:
        feats_n = F.normalize(feats.float(), dim=-1).to(device)
        sub = feats_n[:min(5000, n)]
        sub_labels = labels[:min(5000, n)].to(device)
        for k_val in [1, 5, 10]:
            knn = knn_accuracy(sub, sub_labels, k=k_val) * 100
            print(f" {name:>12} kNN-{k_val}: {knn:.1f}%")
        print()
    # ════════════════════════════════════════════════════════════════════
    # 6. GEOMETRIC CONSTANTS CHECK
    # ════════════════════════════════════════════════════════════════════
    print(f"{'─'*60}")
    print(f" 6. GEOMETRIC CONSTANTS")
    print(f"{'─'*60}")
    combined_n = F.normalize(combined.float(), dim=-1).to(device)
    # Multi-scale CV (project helper; returns {scale: cv})
    cv_scales = cv_multi_scale(combined_n[:2000], scales=(3, 4, 5, 6, 7, 8))
    print(f" CV multi-scale: {cv_scales}")
    # Check for 0.20-0.23 universal attractor band at scale 5
    cv5 = cv_scales.get(5, None)
    if cv5 is not None:
        in_band = 0.20 <= cv5 <= 0.23
        print(f" CV pentachoron: {cv5:.4f} — {'IN BAND [0.20-0.23]' if in_band else 'outside band'}")
    # Norm distribution of the raw (unnormalized) combined embedding
    norms = combined.float().norm(dim=-1)
    print(f" Embedding norms: {norms.mean():.4f} ± {norms.std():.4f}")
    print(f"\n{'='*80}")
    print(f" ANALYSIS COMPLETE")
    print(f"{'='*80}")
# ── Run ──────────────────────────────────────────────────────────────────────
# Kick off the full geometric analysis of the trained SVD test model
# (n_max=5000 is the function default, spelled out here for clarity).
analyze_svd_model(model_svd_test, val_loader, device, n_max=5000)