# Hugging Face page header captured along with the file (uploader/commit
# metadata), commented out so the module parses as Python:
#   AbstractPhil — "Create probe.py" — commit 5624568 (verified)
"""
Comprehensive testing cell for BaselineViT (RoseFace-aware)
Run AFTER loading your model & checkpoint in Colab.
Assumes: model, get_cifar100_dataloaders are defined.
"""
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from pathlib import Path
import json
from tqdm import tqdm
# =========================
# RoseFace-aware utilities
# =========================
@torch.no_grad()
def principal_angle_overlap(class_pentachora):
    """
    Measure subspace overlap between classes (lower = better decoupling).

    Each class contributes an orthonormal basis for the span of its centered
    pentachoron vertices; pairwise overlap is the Frobenius norm of
    U_a^T U_b (the root sum of squared cosines of the principal angles).

    Args:
        class_pentachora: sequence of objects exposing a ``.vertices`` tensor
            of shape [5, D].

    Returns:
        (mean_fro, std_fro) across all class pairs; (0.0, 0.0) when there are
        fewer than two classes (including an empty sequence).
    """
    # Fix: guard before indexing [0] — the original raised IndexError on an
    # empty list even though it already returned (0.0, 0.0) for "no pairs".
    if len(class_pentachora) < 2:
        return 0.0, 0.0
    device = class_pentachora[0].vertices.device
    dtype = class_pentachora[0].vertices.dtype
    C = len(class_pentachora)
    U = []
    for p in class_pentachora:
        V = p.vertices.to(device=device, dtype=dtype)  # [5,D]
        c = V.mean(dim=0, keepdim=True)
        A = V - c  # [5,D] centered vertices
        # QR on D x 5 (A^T) → orthonormal basis in R^D
        Q, _ = torch.linalg.qr(A.t(), mode='reduced')  # [D, r]
        U.append(Q)
    overlaps = []
    for a in range(C):
        for b in range(a + 1, C):
            M = U[a].t() @ U[b]  # [r_a, r_b]
            overlaps.append(torch.linalg.norm(M, 'fro').item())
    if not overlaps:
        return 0.0, 0.0
    return float(np.mean(overlaps)), float(np.std(overlaps))
@torch.no_grad()
def face_usage_heatmap(model, features_proj, targets, norm_type='l1'):
    """
    Compute per-class face (triad) usage heatmap [C,10].
    features_proj: [N,D] L1-normalized (from model forward outputs)

    Each sample is assigned to the face (vertex triad) of its own class'
    pentachoron with the highest dot-product similarity; every row of the
    result is normalized to sum to 1 (all-zero rows stay zero).
    """
    device, dtype = features_proj.device, features_proj.dtype
    C = model.num_classes
    # The 10 vertex triads of a pentachoron (5 choose 3).
    triplets = torch.tensor([
        [0,1,2],[0,1,3],[0,1,4],
        [0,2,3],[0,2,4],[0,3,4],
        [1,2,3],[1,2,4],[1,3,4],
        [2,3,4]
    ], device=device, dtype=torch.long)
    counts = torch.zeros(C, 10, device=device, dtype=torch.long)
    for cls in torch.unique(targets):
        mask = (targets == cls)
        if mask.sum() == 0:
            continue
        feats = features_proj[mask]  # [b,D]
        penta = model.class_pentachora[int(cls)]
        Vn = penta.vertices_norm if norm_type == 'l1' else F.normalize(penta.vertices, dim=-1)  # [5,D]
        # Build all 10 face barycenters at once: [10,3,D] -> [10,D]
        F10 = Vn[triplets].mean(dim=1)
        if norm_type == 'l1':
            F10 = F10 / (F10.abs().sum(dim=-1, keepdim=True) + 1e-8)
        else:
            F10 = F.normalize(F10, dim=-1)
        sims = feats @ F10.t()       # [b,10]
        winner = sims.argmax(dim=1)  # [b] index of the best-matching face
        counts[int(cls)] += torch.bincount(winner, minlength=10)
    counts = counts.float()
    counts = counts / (counts.sum(dim=1, keepdim=True) + 1e-9)
    return counts  # [C,10]
@torch.no_grad()
def margin_stats(cos_pre, targets):
    """
    Compute margin Δ = pos - best_neg from PRE-margin cosines.

    cos_pre: [B,C] cosine logits before any angular margin is applied.
    targets: [B] integer class labels.
    Returns (mean, std) of the per-sample gap between the target cosine and
    the strongest non-target cosine.
    """
    rows = torch.arange(cos_pre.size(0), device=cos_pre.device)
    target_cos = cos_pre[rows, targets]        # [B] cosine at the label column
    competitors = cos_pre.clone()
    competitors[rows, targets] = -1e9          # knock out the target column
    runner_up = competitors.max(dim=1).values  # [B] best non-target cosine
    gap = target_cos - runner_up
    return float(gap.mean()), float(gap.std())
# ============================================
# FEATURE EXTRACTION AND ANALYSIS (upgraded)
# ============================================
class FeatureAnalyzer:
    """
    Analyze feature capacity and geometric patterns.
    Now aware of RoseFace:
    - can run with or without margin at inference (margin_mode)
    - stores pre-margin cosines and post-margin cosines
    """
    def __init__(self, model, dataloader, device=None, margin_mode='none'):
        """
        margin_mode:
        'none' -> don't pass targets to model (no margin at eval)
        'apply' -> pass targets (apply margin at eval)
        'both' -> run both (twice); store *_nomargin and *_margin
        """
        self.model = model
        self.dataloader = dataloader
        # Default to the device the model's parameters live on.
        self.device = device or next(model.parameters()).device
        self.model.eval()
        assert margin_mode in ('none','apply','both')
        self.margin_mode = margin_mode

    def _forward_once(self, images, labels, apply_margin):
        # forward; return dict of tensors on CPU
        if apply_margin:
            outputs = self.model(images, return_features=True, targets=labels)
        else:
            outputs = self.model(images, return_features=True)
        # Keep only tensor outputs; move to CPU so batches can be concatenated
        # without holding GPU memory.
        out = {k: v.detach().cpu() for k, v in outputs.items() if isinstance(v, torch.Tensor)}
        # Derive post-margin cosines (if RoseFace): cos_post = logits / s
        if getattr(self.model, 'head_type', 'legacy') == 'roseface':
            s = float(getattr(self.model, 'scale_s', 1.0))
            if s > 0 and 'logits' in out:
                out['cos_post'] = (out['logits'] / s)
        return out

    def extract_all_features(self, max_batches=None):
        """
        Extract features, pre-margin cosines, post-margin cosines (if available).
        Returns dict with keys:
        - cls_features
        - features_proj
        - similarities (pre-margin cos)
        - cos_post (post-margin cos; RoseFace only)
        - logits
        - labels
        If margin_mode == 'both', suffix *_nomargin / *_margin are included.
        """
        agg = {}
        def append_batch(prefix, out_tensors, labels):
            # initialize lists lazily; 'm_' prefix marks the margin-applied pass
            for k, v in out_tensors.items():
                agg.setdefault(f'{prefix}{k}', []).append(v)
            agg.setdefault(f'{prefix}labels', []).append(labels.cpu())
        with torch.no_grad():
            for i, (images, labels) in enumerate(tqdm(self.dataloader, desc="Extracting features")):
                if max_batches is not None and i >= max_batches:
                    break
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                if self.margin_mode in ('none', 'both'):
                    out0 = self._forward_once(images, labels, apply_margin=False)
                    append_batch('', out0, labels)
                if self.margin_mode in ('apply', 'both'):
                    out1 = self._forward_once(images, labels, apply_margin=True)
                    append_batch('m_', out1, labels)
        # concat everything we collected
        for k, lst in agg.items():
            agg[k] = torch.cat(lst, dim=0)
        # helper: pick normal key, else 'm_' fallback
        def pick(key: str):
            return agg.get(key, agg.get(f"m_{key}", torch.empty(0)))
        # unify view for downstream code
        if self.margin_mode == 'both':
            features = {
                'cls_features': pick('features'),
                'features_proj': pick('features_proj'),
                'similarities': agg.get('similarities', torch.empty(0)),
                'cos_post': agg.get('cos_post', torch.empty(0)),
                'labels': agg.get('labels', torch.empty(0)),
                'similarities_margin': agg.get('m_similarities', torch.empty(0)),
                'cos_post_margin': agg.get('m_cos_post', torch.empty(0)),
                'logits': agg.get('logits', torch.empty(0)),
                'logits_margin': agg.get('m_logits', torch.empty(0)),
            }
        else:
            # works for BOTH margin_mode='none' and margin_mode='apply'
            features = {
                'cls_features': pick('features'),
                'features_proj': pick('features_proj'),
                'similarities': pick('similarities'),  # pre-margin cosines
                'cos_post': pick('cos_post'),          # post-margin cosines (RoseFace)
                'labels': pick('labels'),
                'logits': pick('logits'),
            }
        return features

    def analyze_feature_collapse(self, features):
        """Report collapse indicators: unique-pattern estimate, feature std, silhouette."""
        print("\n=== FEATURE COLLAPSE ANALYSIS ===")
        cls_features = features['cls_features'].numpy()
        unique_patterns = self._count_unique_patterns(cls_features)
        print(f"Estimated unique patterns: {unique_patterns}/100 classes")
        feature_std = cls_features.std(axis=0).mean()
        print(f"Average feature std: {feature_std:.4f}")
        labels = features['labels'].numpy()
        # Silhouette on a random subsample to keep it fast.
        sample_size = min(1000, len(labels))
        indices = np.random.choice(len(labels), sample_size, replace=False)
        silhouette = silhouette_score(cls_features[indices], labels[indices])
        print(f"Silhouette score: {silhouette:.3f}")
        # centroid proximity count (100 classes hard-coded: CIFAR-100)
        class_features = {}
        for i in range(100):
            mask = labels == i
            if mask.sum() > 0:
                class_features[i] = cls_features[mask].mean(axis=0)
        if class_features:
            centroids = np.stack(list(class_features.values()))
            # Pairwise centroid distances; "close" = below the 20th percentile.
            d = np.linalg.norm(centroids[:, None] - centroids[None, :], axis=2)
            thr = np.percentile(d[d > 0], 20)
            close_pairs = (d < thr) & (d > 0)
            classes_with_close_neighbors = close_pairs.sum(axis=1)
            print(f"Classes with very similar features: {(classes_with_close_neighbors > 2).sum()}/100")
        return {'unique_patterns': unique_patterns, 'feature_std': feature_std, 'silhouette': silhouette}

    def analyze_geometric_patterns(self, features):
        """Summarize cosine geometry between samples and class anchors."""
        print("\n=== GEOMETRIC PATTERN ANALYSIS ===")
        sims = features['similarities']  # pre-margin cosines [N,C]
        print(f"Average max cosine: {sims.max(dim=1)[0].mean():.3f}")
        print(f"Average min cosine: {sims.min(dim=1)[0].mean():.3f}")
        print(f"Cosine std: {sims.std():.3f}")
        high_sim_threshold = sims.mean() + sims.std()
        high_sim_count = (sims > high_sim_threshold).sum(dim=1).float().mean()
        print(f"Avg classes above (mean+std): {high_sim_count:.1f}/100")
        labels = features['labels']
        correct = sims.gather(1, labels.view(-1,1)).squeeze(1).mean().item()
        # Mean cosine over the non-target columns, per sample.
        wrong = (sims.sum(dim=1) - sims.gather(1, labels.view(-1,1)).squeeze(1)) / (sims.size(1)-1)
        margin = (correct - wrong.mean().item())
        print(f"Avg cosine margin (correct - mean wrong): {margin:.3f}")
        # RoseFace-specific: if post cosines are present, compare deltas
        if 'cos_post' in features and features['cos_post'].numel() > 0:
            cos_post = features['cos_post']
            # shift on target column
            pos_pre = sims.gather(1, labels.view(-1,1))
            pos_post = cos_post.gather(1, labels.view(-1,1))
            shift = (pos_post - pos_pre).mean().item()
            print(f"Avg target cosine shift (post - pre): {shift:.3f}")
        return {
            'max_cos': sims.max(dim=1)[0].mean().item(),
            'cos_std': sims.std().item(),
            'high_sim_classes': high_sim_count.item(),
            'discrimination_margin': margin
        }

    def test_linear_probe(self, features, num_epochs=50):
        """Train a linear classifier on frozen CLS features; return best test accuracy."""
        print("\n=== LINEAR PROBE TEST ===")
        X = features['cls_features']
        y = features['labels']
        # Simple 80/20 split in extraction order (no shuffling).
        n_train = int(0.8 * len(y))
        X_train, y_train = X[:n_train], y[:n_train]
        X_test, y_test = X[n_train:], y[n_train:]
        probe = torch.nn.Linear(X_train.shape[1], 100).to(self.device)
        opt = torch.optim.Adam(probe.parameters(), lr=0.01)
        X_train = X_train.to(self.device); y_train = y_train.to(self.device)
        X_test = X_test.to(self.device); y_test = y_test.to(self.device)
        best = 0.0
        for epoch in range(num_epochs):
            # Full-batch gradient step.
            logits = probe(X_train)
            loss = F.cross_entropy(logits, y_train)
            opt.zero_grad(); loss.backward(); opt.step()
            if epoch % 10 == 0:
                with torch.no_grad():
                    acc = (probe(X_test).argmax(dim=1) == y_test).float().mean().item()
                best = max(best, acc)
                print(f" Epoch {epoch}: Test acc = {acc*100:.1f}%")
        with torch.no_grad():
            final = (probe(X_test).argmax(dim=1) == y_test).float().mean().item()
        best = max(best, final)
        print(f"Best linear probe accuracy: {best*100:.1f}%")
        return best

    def visualize_features(self, features, method='tsne', n_samples=2000):
        """Project features to 2D (t-SNE or PCA), plot, and estimate cluster count via KMeans."""
        print(f"\n=== FEATURE VISUALIZATION ({method.upper()}) ===")
        cls_features = features['cls_features'].numpy()
        labels = features['labels'].numpy()
        n_samples = min(n_samples, len(labels))
        idx = np.random.choice(len(labels), n_samples, replace=False)
        X = cls_features[idx]; y = labels[idx]
        print(f"Reducing {n_samples} samples to 2D...")
        reducer = TSNE(n_components=2, random_state=42, perplexity=30) if method=='tsne' else PCA(n_components=2)
        X2 = reducer.fit_transform(X)
        plt.figure(figsize=(12,9))
        scatter = plt.scatter(X2[:,0], X2[:,1], c=y, cmap='nipy_spectral', alpha=0.6, s=15)
        plt.title(f'Feature Space Visualization ({method.upper()})'); plt.xlabel('Comp 1'); plt.ylabel('Comp 2')
        print("Estimating visual clusters...")
        # Sweep k and keep the best silhouette on the 2D embedding.
        silhouette_scores, K = [], list(range(30, 60, 5))
        for k in K:
            kmeans = KMeans(n_clusters=k, random_state=42, n_init=3)
            cls_lbl = kmeans.fit_predict(X2)
            silhouette_scores.append(silhouette_score(X2, cls_lbl))
        best_k = K[int(np.argmax(silhouette_scores))]
        kmeans = KMeans(n_clusters=best_k, random_state=42, n_init=5)
        cluster_labels = kmeans.fit_predict(X2)
        n_populated = len(np.unique(cluster_labels))
        plt.text(0.02, 0.98, f'Estimated clusters: {n_populated}', transform=plt.gca().transAxes,
                 va='top', fontsize=12, bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
        cbar = plt.colorbar(scatter, ticks=np.arange(0,100,10)); cbar.set_label('Class', rotation=270, labelpad=15)
        plt.tight_layout(); plt.show()
        return X2, n_populated

    def analyze_pentachora_usage(self):
        """Report model geometry config, centroid similarity, and principal-angle overlap."""
        print("\n=== PENTACHORA USAGE ANALYSIS ===")
        print(f"Classes: {self.model.num_classes}")
        print(f"Embed dim: {self.model.embed_dim} | Penta dim: {self.model.pentachora_dim}")
        print(f"Head: {getattr(self.model,'head_type','legacy')} | Prototype: {getattr(self.model,'prototype_mode','n/a')} | Margin: {getattr(self.model,'margin_type','n/a')}")
        if hasattr(self.model, 'to_pentachora_dim'):
            if isinstance(self.model.to_pentachora_dim, torch.nn.Linear):
                # NOTE(review): the f-string concatenates the two dims with no
                # separator — an arrow glyph was likely lost; confirm intent.
                print(f"Projection: Linear {self.model.embed_dim}{self.model.pentachora_dim}")
            else:
                print("Projection: Identity")
        # Inter-class centroid similarity (legacy view)
        centroids = self.model.get_class_centroids()
        sim = centroids @ centroids.t()
        # Off-diagonal entries only.
        mask = ~torch.eye(self.model.num_classes, dtype=bool, device=sim.device)
        off = sim[mask]
        print(f"\nCentroid sims: mean={off.mean():.3f} max={off.max():.3f} min={off.min():.3f}")
        # Principal-angle overlap
        mean_fro, std_fro = principal_angle_overlap(self.model.class_pentachora)
        print(f"Principal-angle Fro overlap: mean={mean_fro:.3f} ± {std_fro:.3f} (lower is better)")
        return {'mean_similarity': off.mean().item(), 'max_similarity': off.max().item(), 'mean_fro_overlap': mean_fro}

    def run_full_analysis(self):
        """End-to-end diagnostic: collapse, geometry, margins, probe, pentachora, plots."""
        print("="*60); print("COMPREHENSIVE FEATURE ANALYSIS"); print("="*60)
        print("\nExtracting features (margin_mode =", self.margin_mode, ") ...")
        feats = self.extract_all_features(max_batches=50)
        print(f"✓ Extracted features from {len(feats['labels'])} samples")
        res = {}
        res['collapse'] = self.analyze_feature_collapse(feats)
        res['geometric'] = self.analyze_geometric_patterns(feats)
        # Margin stats from PRE-margin cosines
        mu, sig = margin_stats(feats['similarities'], feats['labels'])
        print(f"PRE-margin Δ (pos - bestneg): mean={mu:.3f}, std={sig:.3f}")
        # Face-usage heatmap
        if 'features_proj' in feats and feats['features_proj'].numel() > 0:
            heat = face_usage_heatmap(self.model, feats['features_proj'].to(self.device), feats['labels'].to(self.device), norm_type=getattr(self.model,'norm_type','l1'))
            print("Face-usage heatmap computed [C,10] (display top 3 classes by mass):")
            class_mass = heat.sum(dim=1)
            top3 = torch.topk(class_mass, k=min(3, heat.size(0))).indices.tolist()
            for c in top3:
                print(f" class {c}: {heat[c].cpu().numpy().round(3)}")
        res['linear_probe'] = self.test_linear_probe(feats)
        res['pentachora'] = self.analyze_pentachora_usage()
        # Visualizations
        _, n_tsne = self.visualize_features(feats, 'tsne')
        _, n_pca = self.visualize_features(feats, 'pca')
        res['visual_clusters'] = {'tsne': n_tsne, 'pca': n_pca}
        # Summary: thresholds 45 patterns / 42% probe mark the suspected cap.
        print("\n" + "="*60); print("DIAGNOSIS SUMMARY"); print("="*60)
        up = res['collapse']['unique_patterns']; lp = res['linear_probe']
        if up <= 45 and lp <= 0.42:
            print(f"🔴 Compact regime: {up} unique patterns; linear probe {lp*100:.1f}%")
        elif up > 60:
            print(f"✅ Diverse regime: {up} unique patterns; linear probe {lp*100:.1f}%")
        else:
            print(f"⚡ Partial bottleneck: {up} unique patterns; linear probe {lp*100:.1f}%")
        return res

    # ------------------------------
    # Helpers (unchanged interface)
    # ------------------------------
    def _count_unique_patterns(self, features, method='elbow'):
        """Estimate the number of distinct feature patterns via KMeans (elbow or silhouette)."""
        # Cap sample size for KMeans runtime.
        X = features[:min(3000, len(features))]
        if method == 'elbow':
            inertias, K = [], list(range(20, 80, 5))
            for k in K:
                km = KMeans(n_clusters=k, random_state=42, n_init=3)
                km.fit(X); inertias.append(km.inertia_)
            # Elbow = point of maximum curvature in the inertia curve.
            diffs = np.diff(inertias); diffs2 = np.diff(diffs)
            if len(diffs2) > 0:
                elbow_idx = int(np.argmax(np.abs(diffs2))) + 1
                est = K[elbow_idx]
            else:
                est = 41
        else:
            scores, K = [], list(range(30, 60, 5))
            for k in K:
                km = KMeans(n_clusters=k, random_state=42, n_init=3)
                lbl = km.fit_predict(X)
                scores.append(silhouette_score(X, lbl))
            est = K[int(np.argmax(scores))]
        return est
# ============================================
# QUICK TEST (RoseFace-aware)
# ============================================
def quick_41_percent_test(model, test_loader, device=None, apply_margin_eval=False):
    """
    If apply_margin_eval=True, pass targets to model at eval (margin applied).
    Otherwise, evaluate without margin (classic).

    Returns a dict with overall accuracy, the unique-pattern estimate, a
    bottleneck flag, and the mean inter-class centroid similarity.
    """
    print("="*60); print("41% ACCURACY CAP HYPOTHESIS TEST"); print("="*60)
    model.eval()
    device = device or next(model.parameters()).device
    # 1) Accuracy over the full test loader
    print("\n1. Verifying model accuracy...")
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images, targets=labels) if apply_margin_eval else model(images)
            pred = outputs['logits'].argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)
    acc = 100 * correct / total
    policy = "WITH margin" if apply_margin_eval else "NO margin"
    print(f" Test Accuracy ({policy}): {acc:.2f}%")
    # "At the cap" = within 3 points of the hypothesized 41% ceiling.
    is_at_cap = abs(acc - 41.0) < 3.0
    # 2) Focused analysis (small sample)
    print("\n2. Analyzing feature patterns...")
    margin_mode = 'apply' if apply_margin_eval else 'none'
    analyzer = FeatureAnalyzer(model, test_loader, device=device, margin_mode=margin_mode)
    feats = analyzer.extract_all_features(max_batches=20)
    # Offline nearest-prototype accuracy using rose5 prototypes (no margin).
    acc_rose5 = offline_head_eval_rose5(model, feats['features_proj'].to(device), feats['labels'])
    print(f"Offline prototype eval (rose5, no margin): {acc_rose5*100:.2f}%")
    collapse = analyzer.analyze_feature_collapse(feats)
    pent = analyzer.analyze_pentachora_usage()
    print("\n" + "="*60); print("VERDICT"); print("="*60)
    if is_at_cap and collapse['unique_patterns'] <= 45:
        print("🔴 41% CAP CONFIRMED")
        print(f" Acc: {acc:.1f}% | Unique patterns: {collapse['unique_patterns']}")
        print(" Likely geometric bottleneck.")
    elif collapse['unique_patterns'] <= 45:
        print("⚠️ FEATURE BOTTLENECK DETECTED")
        print(f" {collapse['unique_patterns']} patterns; Acc={acc:.1f}%")
    else:
        print("✅ NO 41% BOTTLENECK")
        print(f" {collapse['unique_patterns']} patterns; Acc={acc:.1f}%")
    return {
        'accuracy': acc,
        'unique_patterns': collapse['unique_patterns'],
        'is_bottlenecked': collapse['unique_patterns'] <= 45,
        'pentachora_similarity': pent['mean_similarity']
    }
@torch.no_grad()
def offline_head_eval_rose5(model, features_proj, labels):
    """
    Nearest-prototype accuracy with 'rose5' prototypes and no margin.

    Prototypes mix each class's L2-normalized pentachoron vertices with half
    of the mean rose-weighted face combination; prediction is max cosine.
    """
    eps = 1e-12
    # Dual-norm bridge: re-project the features onto the L2 unit sphere.
    z_l2 = features_proj / (features_proj.norm(p=2, dim=-1, keepdim=True) + eps)
    # Stack per-class vertices [C,5,D] and L2-normalize each vertex row.
    verts = torch.stack([p.vertices for p in model.class_pentachora], dim=0)
    verts = verts.to(z_l2.device, z_l2.dtype)
    verts = verts / (verts.norm(p=2, dim=-1, keepdim=True) + eps)
    weights = model.rose_face_weights.to(z_l2.device, z_l2.dtype)  # [10,5]
    face_mix = torch.einsum('tf,cfd->ctd', weights, verts)         # [C,10,D]
    proto = verts.mean(dim=1) + 0.5 * face_mix.mean(dim=1)
    proto = proto / (proto.norm(p=2, dim=-1, keepdim=True) + eps)  # [C,D]
    cosines = z_l2 @ proto.t()                                     # [N,C]
    hits = (cosines.argmax(dim=1) == labels.to(z_l2.device)).float()
    return hits.mean().item()
# ==========================
# RUN ANALYSIS (example)
# ==========================
if __name__ == "__main__":
    # NOTE(review): `model` and `get_cifar100_dataloaders` are expected to be
    # defined in the surrounding notebook session (see module docstring) —
    # this block is not runnable as a standalone script without them.
    print("Starting RoseFace-aware feature analysis...")
    print(f"Model device: {next(model.parameters()).device}")
    # Dataloaders
    train_loader, test_loader, train_transforms = get_cifar100_dataloaders(batch_size=128)
    # Quick test in BOTH modes (optional): compare accuracy
    print("\nRunning quick 41% hypothesis test (NO margin at eval)...")
    res_nom = quick_41_percent_test(model, test_loader, apply_margin_eval=False)
    print("\nRunning quick 41% hypothesis test (WITH margin at eval)...")
    res_mar = quick_41_percent_test(model, test_loader, apply_margin_eval=True)
    # Full analysis with richer diagnostics (no margin at eval is typical)
    analyzer = FeatureAnalyzer(model, test_loader, margin_mode='none')
    full_results = analyzer.run_full_analysis()