|
|
""" |
|
|
Comprehensive testing cell for BaselineViT (RoseFace-aware) |
|
|
Run AFTER loading your model & checkpoint in Colab. |
|
|
Assumes: `model` and `get_cifar100_dataloaders` are already defined in the notebook session.
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
from sklearn.manifold import TSNE |
|
|
from sklearn.decomposition import PCA |
|
|
from sklearn.cluster import KMeans |
|
|
from sklearn.metrics import silhouette_score |
|
|
from pathlib import Path |
|
|
import json |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def principal_angle_overlap(class_pentachora):
    """
    Measure subspace overlap between classes (lower = better decoupling).

    The vertices of each class pentachoron span an affine subspace whose
    direction space is spanned by the edge vectors ``V[i] - V[0]``. This is the
    same space the mean-centred vertices span, so for 5 vertices it has at
    most 4 dimensions. An orthonormal basis is taken from the reduced QR of
    those edge vectors.

    Note: QR of the 5 *centred* vertices would always return 5 columns even
    though the centred rows sum to zero. The 5th column would then be an
    arbitrary, round-off-determined direction that inflates the overlap.

    For each class pair the overlap is ``||U_a^T U_b||_F``, i.e. the square
    root of the summed squared cosines of the principal angles between the two
    subspaces (0 = orthogonal, 2 = identical 4-dim subspaces).

    Args:
        class_pentachora: sequence of objects with a ``vertices`` tensor of
            shape [num_vertices, D].

    Returns:
        (mean_fro, std_fro) over all class pairs (population std, as
        ``np.std``). Returns (0.0, 0.0) when fewer than two classes are given.
    """
    if len(class_pentachora) < 2:
        return 0.0, 0.0

    device = class_pentachora[0].vertices.device
    dtype = class_pentachora[0].vertices.dtype
    # torch.linalg.qr does not support half / bfloat16 inputs.
    if dtype in (torch.float16, torch.bfloat16):
        dtype = torch.float32

    bases = []
    for p in class_pentachora:
        V = p.vertices.to(device=device, dtype=dtype)
        edges = V[1:] - V[:1]  # [num_vertices - 1, D]: direction space of the simplex
        Q, _ = torch.linalg.qr(edges.t(), mode='reduced')  # [D, min(D, num_vertices - 1)]
        bases.append(Q)

    overlaps = [
        torch.linalg.norm(bases[a].t() @ bases[b], 'fro').item()
        for a in range(len(bases))
        for b in range(a + 1, len(bases))
    ]
    return float(np.mean(overlaps)), float(np.std(overlaps))
|
|
|
|
|
@torch.no_grad()
def face_usage_heatmap(model, features_proj, targets, norm_type='l1'):
    """
    Compute a per-class face (triad) usage heatmap of shape [C, 10].

    For each class present in ``targets``:
      1. Form the 10 triangular faces of that class's pentachoron as the mean
         of their three vertices, re-normalised.
      2. Assign every sample of the class to the face with the highest
         dot-product similarity.
    The resulting counts are then row-normalised into usage fractions.

    Args:
        model: object exposing ``num_classes`` and ``class_pentachora``. Each
            pentachoron has ``vertices``, and ``vertices_norm`` when
            ``norm_type == 'l1'``.
        features_proj: [N, D] projected features (L1-normalised when taken
            from the model forward outputs).
        targets: [N] integer class labels.
        norm_type: 'l1' uses ``vertices_norm`` and L1-renormalises faces; any
            other value L2-normalises vertices and faces.

    Returns:
        float tensor [C, 10] on ``features_proj.device``. Each row sums to ~1
        for classes present in ``targets`` and is all zeros otherwise.
    """
    device, dtype = features_proj.device, features_proj.dtype
    C = model.num_classes
    # All 10 vertex triplets (5 choose 3) of a pentachoron.
    triplets = torch.tensor([
        [0, 1, 2], [0, 1, 3], [0, 1, 4],
        [0, 2, 3], [0, 2, 4], [0, 3, 4],
        [1, 2, 3], [1, 2, 4], [1, 3, 4],
        [2, 3, 4],
    ], device=device, dtype=torch.long)

    counts = torch.zeros(C, 10, device=device, dtype=torch.long)
    targets = targets.to(device)

    for cls in torch.unique(targets).tolist():
        f = features_proj[targets == cls]
        p = model.class_pentachora[cls]
        Vn = p.vertices_norm if norm_type == 'l1' else F.normalize(p.vertices, dim=-1)
        # Align with the features so the similarity matmul cannot hit a dtype/device mismatch.
        Vn = Vn.to(device=device, dtype=dtype)

        faces = Vn[triplets].mean(dim=1)  # [10, D] barycenters of each triad
        if norm_type == 'l1':
            faces = faces / (faces.abs().sum(dim=-1, keepdim=True) + 1e-8)
        else:
            faces = F.normalize(faces, dim=-1)

        winner = (f @ faces.t()).argmax(dim=1)
        counts[cls] += torch.bincount(winner, minlength=10)

    counts = counts.float()
    return counts / (counts.sum(dim=1, keepdim=True) + 1e-9)
|
|
|
|
|
@torch.no_grad()
def margin_stats(cos_pre, targets):
    """
    Compute the margin Δ = pos - best_neg from PRE-margin cosines.

    ``pos`` is each sample's cosine to its own class. ``best_neg`` is its
    highest cosine to any other class. So Δ > 0 means the sample is ranked
    correctly.

    Args:
        cos_pre: [N, C] cosine similarities (C >= 2 for a finite result).
        targets: [N] integer class labels (any integer dtype).

    Returns:
        (mean, std) of Δ as Python floats. The std is the unbiased estimate
        from ``Tensor.std``.
    """
    # Reduce low-precision inputs in float32 so reductions are supported and accurate.
    cos = cos_pre.float() if cos_pre.dtype in (torch.float16, torch.bfloat16) else cos_pre
    # gather / one_hot require int64 indices on the same device.
    idx = targets.to(cos.device).long()

    pos = cos.gather(1, idx.view(-1, 1)).squeeze(1)
    is_target = F.one_hot(idx, cos.size(1)).bool()
    # -inf is representable in every floating dtype, unlike -1e9 which
    # overflows float16 and makes masked_fill raise.
    neg = cos.masked_fill(is_target, float('-inf')).max(dim=1).values

    delta = pos - neg
    return float(delta.mean()), float(delta.std())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FeatureAnalyzer:
    """
    Analyze feature capacity and geometric patterns of a RoseFace-aware model.

    The analyzer can run the model with or without the angular margin at
    inference (``margin_mode``). It stores pre-margin cosines
    (``similarities``) and, for RoseFace heads, post-margin cosines
    (``cos_post`` = logits / scale_s).
    """

    def __init__(self, model, dataloader, device=None, margin_mode='none'):
        """
        Args:
            model: module whose forward accepts ``return_features=True`` and
                optionally ``targets=``, and returns a dict of tensors.
            dataloader: iterable of (images, labels) batches.
            device: device to run on; defaults to the model's parameter device.
            margin_mode:
                'none'  -> don't pass targets to model (no margin at eval)
                'apply' -> pass targets (apply margin at eval)
                'both'  -> run both (twice); store *_nomargin and *_margin

        Raises:
            ValueError: if ``margin_mode`` is not one of the values above.
        """
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if margin_mode not in ('none', 'apply', 'both'):
            raise ValueError(
                f"margin_mode must be 'none', 'apply' or 'both', got {margin_mode!r}")
        self.model = model
        self.dataloader = dataloader
        self.device = device or next(model.parameters()).device
        self.model.eval()
        self.margin_mode = margin_mode

    def _num_classes(self):
        """Return the model's class count (falls back to 100, i.e. CIFAR-100)."""
        return int(getattr(self.model, 'num_classes', 100))

    def _forward_once(self, images, labels, apply_margin):
        """
        Run one forward pass and return its tensor outputs, detached on CPU.

        Only tensors with at least one dimension are kept. 0-dim tensors (for
        example a scalar loss the model may return when given targets) cannot
        be concatenated across batches. For RoseFace heads,
        ``cos_post = logits / scale_s`` is added.
        """
        if apply_margin:
            outputs = self.model(images, return_features=True, targets=labels)
        else:
            outputs = self.model(images, return_features=True)

        out = {
            k: v.detach().cpu()
            for k, v in outputs.items()
            if isinstance(v, torch.Tensor) and v.dim() > 0
        }

        if getattr(self.model, 'head_type', 'legacy') == 'roseface':
            s = float(getattr(self.model, 'scale_s', 1.0))
            if s > 0 and 'logits' in out:
                out['cos_post'] = out['logits'] / s
        return out

    def extract_all_features(self, max_batches=None):
        """
        Extract features, pre-margin cosines, and post-margin cosines (if available).

        Returns a dict with keys:
            - cls_features
            - features_proj
            - similarities (pre-margin cos)
            - cos_post (post-margin cos; RoseFace only)
            - logits
            - labels
        If margin_mode == 'both', the margin-run counterparts are added as
        similarities_margin / cos_post_margin / logits_margin.
        Missing entries are empty tensors.
        """
        agg = {}

        def append_batch(prefix, out_tensors, labels):
            # Accumulate per-batch tensors under '<prefix><key>'.
            for k, v in out_tensors.items():
                agg.setdefault(f'{prefix}{k}', []).append(v)
            agg.setdefault(f'{prefix}labels', []).append(labels.cpu())

        with torch.no_grad():
            for i, (images, labels) in enumerate(tqdm(self.dataloader, desc="Extracting features")):
                if max_batches is not None and i >= max_batches:
                    break
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)

                if self.margin_mode in ('none', 'both'):
                    append_batch('', self._forward_once(images, labels, apply_margin=False), labels)
                if self.margin_mode in ('apply', 'both'):
                    append_batch('m_', self._forward_once(images, labels, apply_margin=True), labels)

        for k, lst in agg.items():
            agg[k] = torch.cat(lst, dim=0)

        def pick(key: str):
            # Prefer the no-margin run; fall back to the margin run ('m_' prefix).
            return agg.get(key, agg.get(f"m_{key}", torch.empty(0)))

        if self.margin_mode == 'both':
            return {
                'cls_features': pick('features'),
                'features_proj': pick('features_proj'),
                'similarities': agg.get('similarities', torch.empty(0)),
                'cos_post': agg.get('cos_post', torch.empty(0)),
                'labels': agg.get('labels', torch.empty(0)),
                'similarities_margin': agg.get('m_similarities', torch.empty(0)),
                'cos_post_margin': agg.get('m_cos_post', torch.empty(0)),
                'logits': agg.get('logits', torch.empty(0)),
                'logits_margin': agg.get('m_logits', torch.empty(0)),
            }
        return {
            'cls_features': pick('features'),
            'features_proj': pick('features_proj'),
            'similarities': pick('similarities'),
            'cos_post': pick('cos_post'),
            'labels': pick('labels'),
            'logits': pick('logits'),
        }

    def analyze_feature_collapse(self, features):
        """
        Report signs of feature collapse in the CLS features.

        Returns:
            dict with:
              - 'unique_patterns': k-means elbow estimate of cluster count;
              - 'feature_std': mean per-dimension std;
              - 'silhouette': label silhouette on a random subsample of at
                most 1000 points, NaN when undefined.
        """
        print("\n=== FEATURE COLLAPSE ANALYSIS ===")
        num_classes = self._num_classes()
        cls_features = features['cls_features'].numpy()
        unique_patterns = self._count_unique_patterns(cls_features)
        print(f"Estimated unique patterns: {unique_patterns}/{num_classes} classes")

        feature_std = cls_features.std(axis=0).mean()
        print(f"Average feature std: {feature_std:.4f}")

        labels = features['labels'].numpy()
        sample_size = min(1000, len(labels))
        indices = np.random.choice(len(labels), sample_size, replace=False)
        sample_labels = labels[indices]
        # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1.
        n_distinct = len(np.unique(sample_labels))
        if 2 <= n_distinct <= sample_size - 1:
            silhouette = silhouette_score(cls_features[indices], sample_labels)
            print(f"Silhouette score: {silhouette:.3f}")
        else:
            silhouette = float('nan')
            print(f"Silhouette score: undefined ({n_distinct} distinct labels in {sample_size} samples)")

        # Per-class centroids; count classes with more than 2 unusually close neighbours.
        present = np.unique(labels)
        if present.size:
            centroids = np.stack([cls_features[labels == c].mean(axis=0) for c in present])
            d = np.linalg.norm(centroids[:, None] - centroids[None, :], axis=2)
            positive = d[d > 0]
            if positive.size:
                thr = np.percentile(positive, 20)
                close_pairs = (d < thr) & (d > 0)
                n_crowded = int((close_pairs.sum(axis=1) > 2).sum())
                print(f"Classes with very similar features: {n_crowded}/{num_classes}")

        return {'unique_patterns': unique_patterns, 'feature_std': feature_std, 'silhouette': silhouette}

    def analyze_geometric_patterns(self, features):
        """
        Summarise the pre-margin cosine similarities.

        Also prints the mean target-cosine shift post - pre when post-margin
        cosines are available.

        Returns:
            dict with 'max_cos', 'cos_std', 'high_sim_classes' and
            'discrimination_margin' (mean target cos - mean non-target cos).
        """
        print("\n=== GEOMETRIC PATTERN ANALYSIS ===")
        sims = features['similarities']
        n_cls = sims.size(1)
        print(f"Average max cosine: {sims.max(dim=1)[0].mean():.3f}")
        print(f"Average min cosine: {sims.min(dim=1)[0].mean():.3f}")
        print(f"Cosine std: {sims.std():.3f}")

        high_sim_threshold = sims.mean() + sims.std()
        high_sim_count = (sims > high_sim_threshold).sum(dim=1).float().mean()
        print(f"Avg classes above (mean+std): {high_sim_count:.1f}/{n_cls}")

        labels = features['labels'].long().view(-1, 1)
        pos_pre = sims.gather(1, labels)
        correct = pos_pre.squeeze(1).mean().item()
        wrong = (sims.sum(dim=1) - pos_pre.squeeze(1)) / (n_cls - 1)
        margin = correct - wrong.mean().item()
        print(f"Avg cosine margin (correct - mean wrong): {margin:.3f}")

        if 'cos_post' in features and features['cos_post'].numel() > 0:
            pos_post = features['cos_post'].gather(1, labels)
            shift = (pos_post - pos_pre).mean().item()
            print(f"Avg target cosine shift (post - pre): {shift:.3f}")

        return {
            'max_cos': sims.max(dim=1)[0].mean().item(),
            'cos_std': sims.std().item(),
            'high_sim_classes': high_sim_count.item(),
            'discrimination_margin': margin,
        }

    def test_linear_probe(self, features, num_epochs=50):
        """
        Train a linear classifier on frozen CLS features and report test accuracy.

        The first 80% of samples train the probe; the rest are the test split.
        Test accuracy is checked every 10 epochs and once after training. The
        best of those checks is returned, which is an optimistic
        (test-selected) estimate.

        Returns:
            best test accuracy in [0, 1]; 0.0 if either split would be empty.
        """
        print("\n=== LINEAR PROBE TEST ===")
        # Upcast so half-precision features match the float32 probe weights.
        X = features['cls_features'].float()
        y = features['labels'].long()
        n_train = int(0.8 * len(y))
        if n_train == 0 or n_train == len(y):
            print(f"Not enough samples for a train/test split ({len(y)}); skipping.")
            return 0.0

        # Size the output layer from the model, never below the largest label.
        n_out = max(self._num_classes(), int(y.max()) + 1)
        X_train, y_train = X[:n_train].to(self.device), y[:n_train].to(self.device)
        X_test, y_test = X[n_train:].to(self.device), y[n_train:].to(self.device)

        probe = torch.nn.Linear(X_train.shape[1], n_out).to(self.device)
        opt = torch.optim.Adam(probe.parameters(), lr=0.01)

        def _test_acc():
            with torch.no_grad():
                return (probe(X_test).argmax(dim=1) == y_test).float().mean().item()

        best = 0.0
        # enable_grad so the probe still trains if the caller is inside no_grad().
        with torch.enable_grad():
            for epoch in range(num_epochs):
                loss = F.cross_entropy(probe(X_train), y_train)
                opt.zero_grad(); loss.backward(); opt.step()
                if epoch % 10 == 0:
                    acc = _test_acc()
                    best = max(best, acc)
                    print(f"  Epoch {epoch}: Test acc = {acc*100:.1f}%")

        best = max(best, _test_acc())
        print(f"Best linear probe accuracy: {best*100:.1f}%")
        return best

    def visualize_features(self, features, method='tsne', n_samples=2000):
        """
        Plot a 2-D embedding (t-SNE or PCA) of a random feature subsample.

        Also estimates the number of visual clusters via a silhouette-selected
        KMeans on the 2-D points.

        Returns:
            (X2, n_populated): the 2-D embedding and the estimated cluster
            count (0 when too few samples to cluster).
        """
        print(f"\n=== FEATURE VISUALIZATION ({method.upper()}) ===")
        cls_features = features['cls_features'].numpy()
        labels = features['labels'].numpy()
        n_samples = min(n_samples, len(labels))
        idx = np.random.choice(len(labels), n_samples, replace=False)
        X = cls_features[idx]; y = labels[idx]

        print(f"Reducing {n_samples} samples to 2D...")
        if method == 'tsne':
            # t-SNE requires perplexity < n_samples.
            reducer = TSNE(n_components=2, random_state=42, perplexity=min(30, max(n_samples - 1, 1)))
        else:
            reducer = PCA(n_components=2)
        X2 = reducer.fit_transform(X)

        plt.figure(figsize=(12, 9))
        scatter = plt.scatter(X2[:, 0], X2[:, 1], c=y, cmap='nipy_spectral', alpha=0.6, s=15)
        plt.title(f'Feature Space Visualization ({method.upper()})'); plt.xlabel('Comp 1'); plt.ylabel('Comp 2')

        print("Estimating visual clusters...")
        # KMeans + silhouette need 2 <= k <= n_samples - 1.
        K = [k for k in range(30, 60, 5) if k <= n_samples - 1]
        n_populated = 0
        if K:
            silhouette_scores = []
            for k in K:
                kmeans = KMeans(n_clusters=k, random_state=42, n_init=3)
                silhouette_scores.append(silhouette_score(X2, kmeans.fit_predict(X2)))
            best_k = K[int(np.argmax(silhouette_scores))]
            kmeans = KMeans(n_clusters=best_k, random_state=42, n_init=5)
            cluster_labels = kmeans.fit_predict(X2)
            n_populated = len(np.unique(cluster_labels))
            plt.text(0.02, 0.98, f'Estimated clusters: {n_populated}', transform=plt.gca().transAxes,
                     va='top', fontsize=12, bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        cbar = plt.colorbar(scatter, ticks=np.arange(0, self._num_classes(), 10))
        cbar.set_label('Class', rotation=270, labelpad=15)
        plt.tight_layout(); plt.show()
        return X2, n_populated

    def analyze_pentachora_usage(self):
        """
        Print head configuration plus class-centroid and subspace-overlap statistics.

        Returns:
            dict with 'mean_similarity' and 'max_similarity' (off-diagonal
            centroid similarities) and 'mean_fro_overlap'.
        """
        print("\n=== PENTACHORA USAGE ANALYSIS ===")
        print(f"Classes: {self.model.num_classes}")
        print(f"Embed dim: {self.model.embed_dim} | Penta dim: {self.model.pentachora_dim}")
        print(f"Head: {getattr(self.model,'head_type','legacy')} | Prototype: {getattr(self.model,'prototype_mode','n/a')} | Margin: {getattr(self.model,'margin_type','n/a')}")
        if hasattr(self.model, 'to_pentachora_dim'):
            if isinstance(self.model.to_pentachora_dim, torch.nn.Linear):
                print(f"Projection: Linear {self.model.embed_dim}→{self.model.pentachora_dim}")
            else:
                print("Projection: Identity")

        with torch.no_grad():
            centroids = self.model.get_class_centroids()
            sim = centroids @ centroids.t()
            # Size the mask from sim itself so it always matches.
            mask = ~torch.eye(sim.size(0), dtype=torch.bool, device=sim.device)
            off = sim[mask]
        if off.numel() == 0:
            mean_sim = max_sim = min_sim = 0.0
        else:
            mean_sim, max_sim, min_sim = off.mean().item(), off.max().item(), off.min().item()
        print(f"\nCentroid sims: mean={mean_sim:.3f} max={max_sim:.3f} min={min_sim:.3f}")

        mean_fro, std_fro = principal_angle_overlap(self.model.class_pentachora)
        print(f"Principal-angle Fro overlap: mean={mean_fro:.3f} ± {std_fro:.3f} (lower is better)")

        return {'mean_similarity': mean_sim, 'max_similarity': max_sim, 'mean_fro_overlap': mean_fro}

    def run_full_analysis(self):
        """
        Run every analysis on up to 50 batches and print a diagnosis summary.

        Returns:
            dict with 'collapse', 'geometric', 'margin', 'face_usage' (only when
            projected features exist), 'linear_probe', 'pentachora' and
            'visual_clusters'.
        """
        print("="*60); print("COMPREHENSIVE FEATURE ANALYSIS"); print("="*60)
        print("\nExtracting features (margin_mode =", self.margin_mode, ") ...")
        feats = self.extract_all_features(max_batches=50)
        print(f"✓ Extracted features from {len(feats['labels'])} samples")

        res = {}
        res['collapse'] = self.analyze_feature_collapse(feats)
        res['geometric'] = self.analyze_geometric_patterns(feats)

        mu, sig = margin_stats(feats['similarities'], feats['labels'])
        print(f"PRE-margin Δ (pos - bestneg): mean={mu:.3f}, std={sig:.3f}")
        res['margin'] = {'mean': mu, 'std': sig}

        if 'features_proj' in feats and feats['features_proj'].numel() > 0:
            heat = face_usage_heatmap(self.model, feats['features_proj'].to(self.device), feats['labels'].to(self.device), norm_type=getattr(self.model,'norm_type','l1'))
            res['face_usage'] = heat.cpu()
            # Rows are normalised to sum to 1, so total mass cannot rank classes;
            # show the classes whose samples concentrate most on a single face.
            peak = heat.max(dim=1).values
            top3 = torch.topk(peak, k=min(3, heat.size(0))).indices.tolist()
            print("Face-usage heatmap computed [C,10] (display top 3 classes by peak face share):")
            for c in top3:
                print(f"  class {c}: {heat[c].cpu().numpy().round(3)}")

        res['linear_probe'] = self.test_linear_probe(feats)
        res['pentachora'] = self.analyze_pentachora_usage()

        _, n_tsne = self.visualize_features(feats, 'tsne')
        _, n_pca = self.visualize_features(feats, 'pca')
        res['visual_clusters'] = {'tsne': n_tsne, 'pca': n_pca}

        print("\n" + "="*60); print("DIAGNOSIS SUMMARY"); print("="*60)
        up = res['collapse']['unique_patterns']; lp = res['linear_probe']
        if up <= 45 and lp <= 0.42:
            print(f"🔴 Compact regime: {up} unique patterns; linear probe {lp*100:.1f}%")
        elif up > 60:
            print(f"✅ Diverse regime: {up} unique patterns; linear probe {lp*100:.1f}%")
        else:
            print(f"⚡ Partial bottleneck: {up} unique patterns; linear probe {lp*100:.1f}%")
        return res

    def _count_unique_patterns(self, features, method='elbow'):
        """
        Estimate the number of distinct feature clusters with KMeans.

        Uses at most the first 3000 rows.
        - 'elbow': the k in 20..75 with the largest |second difference| of inertia.
        - otherwise: the silhouette-optimal k in 30..55.
        Candidate ks larger than the sample count are skipped. If too few
        candidates remain, the largest valid k (or the sample count) is
        returned.
        """
        X = features[:min(3000, len(features))]
        n = len(X)
        if method == 'elbow':
            # KMeans needs n_samples >= n_clusters.
            K = [k for k in range(20, 80, 5) if k <= n]
            if len(K) < 3:
                return K[-1] if K else n
            inertias = [KMeans(n_clusters=k, random_state=42, n_init=3).fit(X).inertia_ for k in K]
            diffs2 = np.diff(np.diff(inertias))
            # diffs2[i] is centred on K[i + 1].
            return K[int(np.argmax(np.abs(diffs2))) + 1]

        # silhouette_score needs n_clusters <= n_samples - 1.
        K = [k for k in range(30, 60, 5) if k <= n - 1]
        if not K:
            return n
        scores = [silhouette_score(X, KMeans(n_clusters=k, random_state=42, n_init=3).fit_predict(X)) for k in K]
        return K[int(np.argmax(scores))]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def quick_41_percent_test(model, test_loader, device=None, apply_margin_eval=False):
    """
    Check whether the model sits at the suspected ~41% accuracy cap.

    Steps:
      1. Top-1 accuracy over the whole ``test_loader``.
      2. Feature analysis on the first 20 batches: feature collapse and
         pentachora centroid similarity. When the model exposes
         ``rose_face_weights`` and projected features were produced, an
         offline nearest-prototype ("rose5") evaluation is also run.

    Args:
        model: classifier whose forward returns a dict with a 'logits' entry.
        test_loader: iterable of (images, labels) batches.
        device: evaluation device; defaults to the model's parameter device.
        apply_margin_eval: if True, pass targets to the model at eval (margin
            applied). Otherwise evaluate without margin (classic).

    Returns:
        dict with:
          - 'accuracy' (percent);
          - 'unique_patterns';
          - 'is_bottlenecked';
          - 'pentachora_similarity';
          - 'offline_rose5_acc' (fraction in [0, 1], or None when the offline
            evaluation was skipped).

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    print("="*60); print("41% ACCURACY CAP HYPOTHESIS TEST"); print("="*60)
    model.eval()
    device = device or next(model.parameters()).device

    print("\n1. Verifying model accuracy...")
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images, targets=labels) if apply_margin_eval else model(images)
            pred = outputs['logits'].argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")
    acc = 100 * correct / total
    policy = "WITH margin" if apply_margin_eval else "NO margin"
    print(f"  Test Accuracy ({policy}): {acc:.2f}%")

    # "At the cap" means within 3 percentage points of 41%.
    is_at_cap = abs(acc - 41.0) < 3.0

    print("\n2. Analyzing feature patterns...")
    margin_mode = 'apply' if apply_margin_eval else 'none'
    analyzer = FeatureAnalyzer(model, test_loader, device=device, margin_mode=margin_mode)
    feats = analyzer.extract_all_features(max_batches=20)

    # The rose5 prototype eval needs RoseFace face weights and projected features.
    acc_rose5 = None
    if hasattr(model, 'rose_face_weights') and feats['features_proj'].numel() > 0:
        acc_rose5 = offline_head_eval_rose5(model, feats['features_proj'].to(device), feats['labels'])
        print(f"Offline prototype eval (rose5, no margin): {acc_rose5*100:.2f}%")
    else:
        print("Offline prototype eval (rose5) skipped: no rose_face_weights or no projected features")

    collapse = analyzer.analyze_feature_collapse(feats)
    pent = analyzer.analyze_pentachora_usage()

    print("\n" + "="*60); print("VERDICT"); print("="*60)
    if is_at_cap and collapse['unique_patterns'] <= 45:
        print("🔴 41% CAP CONFIRMED")
        print(f"   Acc: {acc:.1f}% | Unique patterns: {collapse['unique_patterns']}")
        print("   Likely geometric bottleneck.")
    elif collapse['unique_patterns'] <= 45:
        print("⚠️ FEATURE BOTTLENECK DETECTED")
        print(f"   {collapse['unique_patterns']} patterns; Acc={acc:.1f}%")
    else:
        print("✅ NO 41% BOTTLENECK")
        print(f"   {collapse['unique_patterns']} patterns; Acc={acc:.1f}%")

    return {
        'accuracy': acc,
        'unique_patterns': collapse['unique_patterns'],
        'is_bottlenecked': collapse['unique_patterns'] <= 45,
        'pentachora_similarity': pent['mean_similarity'],
        'offline_rose5_acc': acc_rose5,
    }
|
|
|
|
|
@torch.no_grad()
def offline_head_eval_rose5(model, features_proj, labels):
    """
    Offline nearest-prototype top-1 accuracy using one "rose5" prototype per class.

    The prototype of class c is built from its L2-normalised pentachoron
    vertices as ``mean(vertices) + 0.5 * mean(rose_face_weights @ vertices)``,
    then L2-normalised again. Each feature is also L2-normalised and assigned
    to the prototype with the highest cosine. No margin is involved.

    Args:
        model: exposes ``class_pentachora`` (each with ``vertices``) and
            ``rose_face_weights`` (combines the vertex axis into faces).
        features_proj: [N, D] projected features.
        labels: [N] integer class labels (moved to the features' device).

    Returns:
        fraction of correctly classified samples, as a Python float.
    """
    eps = 1e-12

    def _l2_unit(t):
        # Row-wise L2 normalisation with an additive epsilon (not clamped).
        return t / (t.norm(p=2, dim=-1, keepdim=True) + eps)

    target_device, target_dtype = features_proj.device, features_proj.dtype
    feats_unit = _l2_unit(features_proj)

    stacked = torch.stack([penta.vertices for penta in model.class_pentachora], dim=0)
    verts = _l2_unit(stacked.to(target_device, target_dtype))          # [C, 5, D]
    face_w = model.rose_face_weights.to(target_device, target_dtype)   # [T, 5]
    face_mix = torch.einsum('tf,cfd->ctd', face_w, verts)              # [C, T, D]

    prototypes = _l2_unit(verts.mean(dim=1) + 0.5 * face_mix.mean(dim=1))
    predictions = (feats_unit @ prototypes.t()).argmax(dim=1)
    hits = predictions == labels.to(target_device)
    return hits.float().mean().item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Colab entry point. `model` and `get_cifar100_dataloaders` are not defined
    # in this file; they must already exist in the notebook session.
    print("Starting RoseFace-aware feature analysis...")
    print(f"Model device: {next(model.parameters()).device}")

    # get_cifar100_dataloaders returns three values; only test_loader is used
    # in this block (train_loader / train_transforms are bound but unused here).
    train_loader, test_loader, train_transforms = get_cifar100_dataloaders(batch_size=128)

    # 1) Accuracy + collapse check evaluating WITHOUT margin (targets not passed).
    print("\nRunning quick 41% hypothesis test (NO margin at eval)...")
    res_nom = quick_41_percent_test(model, test_loader, apply_margin_eval=False)

    # 2) Same check WITH the margin applied at eval (targets passed to the model).
    print("\nRunning quick 41% hypothesis test (WITH margin at eval)...")
    res_mar = quick_41_percent_test(model, test_loader, apply_margin_eval=True)

    # 3) Full diagnostic suite (collapse, geometry, margins, face usage, linear
    #    probe, pentachora stats, t-SNE/PCA plots), run without margin.
    analyzer = FeatureAnalyzer(model, test_loader, margin_mode='none')
    full_results = analyzer.run_full_analysis()
|
|
|