"""
Comprehensive testing cell for BaselineViT (RoseFace-aware).

Run AFTER loading your model & checkpoint in Colab.
Assumes: model, get_cifar100_dataloaders are defined.
"""

import json
from itertools import combinations
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score
from tqdm import tqdm


# =========================
# RoseFace-aware utilities
# =========================

@torch.no_grad()
def principal_angle_overlap(class_pentachora):
    """
    Measure subspace overlap between classes (lower = better decoupling).

    For every class the vertices are centred and an orthonormal basis of the
    spanned subspace is extracted; the overlap of two classes is the Frobenius
    norm of ``U_a^T U_b``.

    Args:
        class_pentachora: sequence of objects exposing a ``vertices`` tensor
            (assumed [5, D] per class -- TODO confirm against the model).

    Returns:
        (mean_fro, std_fro) across all class pairs; (0.0, 0.0) when there are
        fewer than two classes.
    """
    C = len(class_pentachora)
    if C < 2:
        return 0.0, 0.0

    device = class_pentachora[0].vertices.device
    dtype = class_pentachora[0].vertices.dtype

    bases = []
    for p in class_pentachora:
        V = p.vertices.to(device=device, dtype=dtype)
        A = V - V.mean(dim=0, keepdim=True)
        # Centring removes one degree of freedom, so A has rank <= rows - 1.
        # A plain reduced QR would still return one column per vertex, the
        # last of which is an arbitrary direction outside the subspace.
        # Keep only right-singular vectors with non-negligible singular value.
        _, S, Vh = torch.linalg.svd(A, full_matrices=False)
        tol = S.max() * max(A.shape) * torch.finfo(S.dtype).eps
        bases.append(Vh[S > tol].t())  # [D, r]

    # Collect on-device and transfer once instead of one .item() per pair.
    overlaps = torch.stack([
        torch.linalg.norm(bases[a].t() @ bases[b], 'fro')
        for a, b in combinations(range(C), 2)
    ]).double().cpu().numpy()
    return float(overlaps.mean()), float(overlaps.std())


@torch.no_grad()
def face_usage_heatmap(model, features_proj, targets, norm_type='l1'):
    """
    Compute per-class face (triad) usage heatmap [C,10].

    For each class present in ``targets`` the 10 vertex triplets of its
    pentachoron are averaged into face directions; every sample of that class
    votes for the face with the highest dot product.

    Args:
        model: exposes ``num_classes`` and ``class_pentachora`` (each entry
            with ``vertices`` and, for ``'l1'``, ``vertices_norm``).
        features_proj: [N,D] L1-normalized (from model forward outputs).
        targets: [N] integer class labels.
        norm_type: ``'l1'`` uses ``vertices_norm`` and L1-normalises faces;
            anything else L2-normalises vertices and faces.

    Returns:
        Float tensor [C,10]; each row of a class that has samples sums to ~1,
        rows of absent classes stay zero.
    """
    device = features_proj.device
    C = model.num_classes
    # All 3-subsets of the 5 vertices, in lexicographic order.
    triplets = torch.tensor(list(combinations(range(5), 3)),
                            device=device, dtype=torch.long)  # [10,3]

    counts = torch.zeros(C, 10, device=device, dtype=torch.long)
    for cls in torch.unique(targets).tolist():
        f = features_proj[targets == cls]  # [b,D]
        p = model.class_pentachora[cls]
        if norm_type == 'l1':
            Vn = p.vertices_norm
            faces = Vn[triplets].mean(dim=1)  # [10,D]
            faces = faces / (faces.abs().sum(dim=-1, keepdim=True) + 1e-8)
        else:
            Vn = F.normalize(p.vertices, dim=-1)
            faces = F.normalize(Vn[triplets].mean(dim=1), dim=-1)
        winner = (f @ faces.t()).argmax(dim=1)  # [b]
        counts[cls] += torch.bincount(winner, minlength=10)

    counts = counts.float()
    return counts / (counts.sum(dim=1, keepdim=True) + 1e-9)  # [C,10]


@torch.no_grad()
def margin_stats(cos_pre, targets):
    """
    Compute margin Δ = pos - best_neg from PRE-margin cosines.

    Args:
        cos_pre: [B,C] cosine matrix.
        targets: [B] integer class labels.

    Returns:
        (mean, std) of Δ as Python floats; (nan, nan) for empty input.
    """
    if cos_pre.numel() == 0:
        return float('nan'), float('nan')
    pos = cos_pre.gather(1, targets.view(-1, 1)).squeeze(1)  # [B]
    target_mask = F.one_hot(targets, cos_pre.size(1)).bool()
    # -inf instead of a large finite constant: representable in every float
    # dtype (a literal like -1e9 overflows in half precision).
    neg = cos_pre.masked_fill(target_mask, float('-inf')).max(dim=1).values  # [B]
    delta = pos - neg
    return float(delta.mean()), float(delta.std())


# ============================================
# FEATURE EXTRACTION AND ANALYSIS (upgraded)
# ============================================

class FeatureAnalyzer:
    """
    Analyze feature capacity and geometric patterns.

    Now aware of RoseFace:
      - can run with or without margin at inference (margin_mode)
      - stores pre-margin cosines and post-margin cosines
    """

    _MARGIN_MODES = ('none', 'apply', 'both')

    def __init__(self, model, dataloader, device=None, margin_mode='none'):
        """
        margin_mode:
          'none'  -> don't pass targets to model (no margin at eval)
          'apply' -> pass targets (apply margin at eval)
          'both'  -> run both (twice); store *_nomargin and *_margin

        Raises:
            ValueError: if ``margin_mode`` is not one of the values above.
        """
        if margin_mode not in self._MARGIN_MODES:
            raise ValueError(
                f"margin_mode must be one of {self._MARGIN_MODES}, got {margin_mode!r}"
            )
        self.model = model
        self.dataloader = dataloader
        self.device = device if device is not None else next(model.parameters()).device
        self.model.eval()
        self.margin_mode = margin_mode

    def _num_classes(self, labels=None):
        """Class count from the model; falls back to the labels, then to 100."""
        n = getattr(self.model, 'num_classes', None)
        if n is None:
            if labels is not None and len(labels) > 0:
                n = int(labels.max()) + 1
            else:
                n = 100
        return int(n)

    def _forward_once(self, images, labels, apply_margin):
        """Run one forward pass and return the tensor outputs moved to CPU."""
        extra = {'targets': labels} if apply_margin else {}
        outputs = self.model(images, return_features=True, **extra)
        out = {
            key: val.detach().cpu()
            for key, val in outputs.items()
            if isinstance(val, torch.Tensor)
        }
        # Derive post-margin cosines (if RoseFace): cos_post = logits / s
        if getattr(self.model, 'head_type', 'legacy') == 'roseface':
            s = float(getattr(self.model, 'scale_s', 1.0))
            if s > 0 and 'logits' in out:
                out['cos_post'] = out['logits'] / s
        return out

    def extract_all_features(self, max_batches=None):
        """
        Extract features, pre-margin cosines, post-margin cosines (if available).

        Returns dict with keys:
          - cls_features
          - features_proj
          - similarities (pre-margin cos)
          - cos_post (post-margin cos; RoseFace only)
          - logits
          - labels
        If margin_mode == 'both', suffix *_margin keys are included.
        Keys the model did not produce map to an empty tensor.
        """
        collected = {}

        def append_batch(prefix, out_tensors, labels):
            batch = labels.size(0)
            for k, v in out_tensors.items():
                # Only per-sample tensors can be concatenated along dim 0;
                # scalars (e.g. a loss) or batch-independent outputs would
                # break torch.cat or pollute the result.
                if v.dim() == 0 or v.size(0) != batch:
                    continue
                collected.setdefault(f'{prefix}{k}', []).append(v)
            collected.setdefault(f'{prefix}labels', []).append(labels.cpu())

        run_plain = self.margin_mode in ('none', 'both')
        run_margin = self.margin_mode in ('apply', 'both')

        with torch.no_grad():
            for i, (images, labels) in enumerate(tqdm(self.dataloader, desc="Extracting features")):
                if max_batches is not None and i >= max_batches:
                    break
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                if run_plain:
                    append_batch('', self._forward_once(images, labels, apply_margin=False), labels)
                if run_margin:
                    append_batch('m_', self._forward_once(images, labels, apply_margin=True), labels)

        # concat everything we collected
        agg = {k: torch.cat(chunks, dim=0) for k, chunks in collected.items()}
        empty = torch.empty(0)

        # helper: pick normal key, else 'm_' fallback
        def pick(key: str):
            return agg.get(key, agg.get(f"m_{key}", empty))

        # unify view for downstream code
        if self.margin_mode == 'both':
            return {
                'cls_features': pick('features'),
                'features_proj': pick('features_proj'),
                'similarities': agg.get('similarities', empty),
                'cos_post': agg.get('cos_post', empty),
                'labels': agg.get('labels', empty),
                'similarities_margin': agg.get('m_similarities', empty),
                'cos_post_margin': agg.get('m_cos_post', empty),
                'logits': agg.get('logits', empty),
                'logits_margin': agg.get('m_logits', empty),
            }
        # works for BOTH margin_mode='none' and margin_mode='apply'
        return {
            'cls_features': pick('features'),
            'features_proj': pick('features_proj'),
            'similarities': pick('similarities'),  # pre-margin cosines
            'cos_post': pick('cos_post'),          # post-margin cosines (RoseFace)
            'labels': pick('labels'),
            'logits': pick('logits'),
        }

    def analyze_feature_collapse(self, features):
        """
        Print and return collapse diagnostics on ``features['cls_features']``.

        Returns:
            dict with 'unique_patterns' (KMeans elbow estimate), 'feature_std'
            (mean per-dimension std) and 'silhouette' (label silhouette on a
            random sample of at most 1000 points; NaN if it is undefined).
        """
        print("\n=== FEATURE COLLAPSE ANALYSIS ===")
        cls_features = features['cls_features'].numpy()
        labels = features['labels'].numpy()
        num_classes = self._num_classes(features['labels'])

        unique_patterns = self._count_unique_patterns(cls_features)
        print(f"Estimated unique patterns: {unique_patterns}/{num_classes} classes")

        feature_std = float(cls_features.std(axis=0).mean())
        print(f"Average feature std: {feature_std:.4f}")

        sample_size = min(1000, len(labels))
        indices = np.random.choice(len(labels), sample_size, replace=False)
        n_sampled_labels = len(np.unique(labels[indices]))
        # silhouette_score requires 2 <= n_labels <= n_samples - 1.
        if 2 <= n_sampled_labels <= sample_size - 1:
            silhouette = float(silhouette_score(cls_features[indices], labels[indices]))
        else:
            silhouette = float('nan')
        print(f"Silhouette score: {silhouette:.3f}")

        # centroid proximity count
        present = np.unique(labels)
        if len(present) > 0:
            centroids = np.stack([cls_features[labels == c].mean(axis=0) for c in present])
            d = np.linalg.norm(centroids[:, None] - centroids[None, :], axis=2)
            positive = d[d > 0]
            if positive.size > 0:
                # "close" = within the lowest 20% of non-zero centroid distances
                thr = np.percentile(positive, 20)
                close_pairs = (d < thr) & (d > 0)
                crowded = (close_pairs.sum(axis=1) > 2).sum()
            else:
                crowded = 0
            print(f"Classes with very similar features: {crowded}/{num_classes}")

        return {'unique_patterns': unique_patterns,
                'feature_std': feature_std,
                'silhouette': silhouette}

    def analyze_geometric_patterns(self, features):
        """
        Print and return statistics of the pre-margin cosine matrix.

        Returns:
            dict with 'max_cos', 'cos_std', 'high_sim_classes' and
            'discrimination_margin' (target cosine minus mean non-target cosine).
        """
        print("\n=== GEOMETRIC PATTERN ANALYSIS ===")
        sims = features['similarities']  # pre-margin cosines [N,C]
        labels = features['labels']
        num_classes = sims.size(1)

        max_cos = sims.max(dim=1).values.mean()
        min_cos = sims.min(dim=1).values.mean()
        cos_std = sims.std()
        print(f"Average max cosine: {max_cos:.3f}")
        print(f"Average min cosine: {min_cos:.3f}")
        print(f"Cosine std: {cos_std:.3f}")

        high_sim_threshold = sims.mean() + cos_std
        high_sim_count = (sims > high_sim_threshold).sum(dim=1).float().mean()
        print(f"Avg classes above (mean+std): {high_sim_count:.1f}/{num_classes}")

        target_cos = sims.gather(1, labels.view(-1, 1))  # [N,1]
        correct = target_cos.squeeze(1).mean().item()
        # mean over the non-target columns; max(...) avoids 0-division for C == 1
        wrong = (sims.sum(dim=1) - target_cos.squeeze(1)) / max(num_classes - 1, 1)
        margin = correct - wrong.mean().item()
        print(f"Avg cosine margin (correct - mean wrong): {margin:.3f}")

        # RoseFace-specific: if post cosines are present, compare deltas
        cos_post = features.get('cos_post')
        if cos_post is not None and cos_post.numel() > 0:
            # shift on target column
            pos_post = cos_post.gather(1, labels.view(-1, 1))
            shift = (pos_post - target_cos).mean().item()
            print(f"Avg target cosine shift (post - pre): {shift:.3f}")

        return {
            'max_cos': max_cos.item(),
            'cos_std': cos_std.item(),
            'high_sim_classes': high_sim_count.item(),
            'discrimination_margin': margin,
        }

    def test_linear_probe(self, features, num_epochs=50):
        """
        Train a full-batch linear classifier on ``cls_features`` and report accuracy.

        The data is split 80/20 after a fixed-seed shuffle so the held-out part
        does not depend on the dataloader order.

        Returns:
            Best held-out accuracy (fraction in [0,1]) seen at the periodic
            checks or after the final epoch.
        """
        print("\n=== LINEAR PROBE TEST ===")
        X = features['cls_features']
        y = features['labels']
        # Never allocate fewer outputs than the largest label requires.
        num_classes = max(self._num_classes(y), int(y.max()) + 1)

        perm = torch.randperm(len(y), generator=torch.Generator().manual_seed(0))
        n_train = int(0.8 * len(y))
        train_idx, test_idx = perm[:n_train], perm[n_train:]
        X_train, y_train = X[train_idx].to(self.device), y[train_idx].to(self.device)
        X_test, y_test = X[test_idx].to(self.device), y[test_idx].to(self.device)

        def held_out_accuracy():
            with torch.no_grad():
                return (probe(X_test).argmax(dim=1) == y_test).float().mean().item()

        best = 0.0
        # enable_grad: the probe must train even if the caller is in no_grad.
        with torch.enable_grad():
            probe = torch.nn.Linear(X_train.shape[1], num_classes).to(self.device)
            opt = torch.optim.Adam(probe.parameters(), lr=0.01)
            for epoch in range(num_epochs):
                loss = F.cross_entropy(probe(X_train), y_train)
                opt.zero_grad()
                loss.backward()
                opt.step()
                if epoch % 10 == 0:
                    acc = held_out_accuracy()
                    best = max(best, acc)
                    print(f" Epoch {epoch}: Test acc = {acc*100:.1f}%")

        best = max(best, held_out_accuracy())
        print(f"Best linear probe accuracy: {best*100:.1f}%")
        return best

    def visualize_features(self, features, method='tsne', n_samples=2000):
        """
        Plot a 2D embedding of a random feature sample and estimate cluster count.

        ``method='tsne'`` uses t-SNE; any other value uses PCA.

        Returns:
            (X2, n_populated): the [n,2] embedding and the number of non-empty
            KMeans clusters at the silhouette-optimal k in {30,35,...,55}.
        """
        print(f"\n=== FEATURE VISUALIZATION ({method.upper()}) ===")
        cls_features = features['cls_features'].numpy()
        labels = features['labels'].numpy()
        n_samples = min(n_samples, len(labels))
        idx = np.random.choice(len(labels), n_samples, replace=False)
        X, y = cls_features[idx], labels[idx]

        print(f"Reducing {n_samples} samples to 2D...")
        if method == 'tsne':
            # t-SNE requires perplexity < n_samples
            reducer = TSNE(n_components=2, random_state=42,
                           perplexity=min(30, max(n_samples - 1, 1)))
        else:
            reducer = PCA(n_components=2)
        X2 = reducer.fit_transform(X)

        plt.figure(figsize=(12, 9))
        scatter = plt.scatter(X2[:, 0], X2[:, 1], c=y, cmap='nipy_spectral', alpha=0.6, s=15)
        plt.title(f'Feature Space Visualization ({method.upper()})')
        plt.xlabel('Comp 1')
        plt.ylabel('Comp 2')

        print("Estimating visual clusters...")
        candidate_k = list(range(30, 60, 5))
        scores = []
        for k in candidate_k:
            assignment = KMeans(n_clusters=k, random_state=42, n_init=3).fit_predict(X2)
            scores.append(silhouette_score(X2, assignment))
        best_k = candidate_k[int(np.argmax(scores))]
        cluster_labels = KMeans(n_clusters=best_k, random_state=42, n_init=5).fit_predict(X2)
        n_populated = len(np.unique(cluster_labels))

        plt.text(0.02, 0.98, f'Estimated clusters: {n_populated}',
                 transform=plt.gca().transAxes, va='top', fontsize=12,
                 bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
        cbar = plt.colorbar(scatter,
                            ticks=np.arange(0, self._num_classes(features['labels']), 10))
        cbar.set_label('Class', rotation=270, labelpad=15)
        plt.tight_layout()
        plt.show()
        return X2, n_populated

    def analyze_pentachora_usage(self):
        """
        Print head configuration and inter-class prototype statistics.

        Returns:
            dict with 'mean_similarity' and 'max_similarity' of the off-diagonal
            entries of ``centroids @ centroids.T`` and 'mean_fro_overlap' from
            ``principal_angle_overlap``.
        """
        model = self.model
        print("\n=== PENTACHORA USAGE ANALYSIS ===")
        print(f"Classes: {model.num_classes}")
        print(f"Embed dim: {model.embed_dim} | Penta dim: {model.pentachora_dim}")
        print(f"Head: {getattr(model,'head_type','legacy')} | Prototype: {getattr(model,'prototype_mode','n/a')} | Margin: {getattr(model,'margin_type','n/a')}")
        if hasattr(model, 'to_pentachora_dim'):
            if isinstance(model.to_pentachora_dim, torch.nn.Linear):
                print(f"Projection: Linear {model.embed_dim}→{model.pentachora_dim}")
            else:
                print("Projection: Identity")

        # Inter-class centroid similarity (legacy view)
        # NOTE(review): a plain dot product; only a cosine if the centroids
        # returned by the model are unit-norm -- confirm.
        centroids = model.get_class_centroids()
        sim = centroids @ centroids.t()
        off_diag = ~torch.eye(model.num_classes, dtype=torch.bool, device=sim.device)
        off = sim[off_diag]
        print(f"\nCentroid sims: mean={off.mean():.3f} max={off.max():.3f} min={off.min():.3f}")

        # Principal-angle overlap
        mean_fro, std_fro = principal_angle_overlap(model.class_pentachora)
        print(f"Principal-angle Fro overlap: mean={mean_fro:.3f} ± {std_fro:.3f} (lower is better)")

        return {'mean_similarity': off.mean().item(),
                'max_similarity': off.max().item(),
                'mean_fro_overlap': mean_fro}

    def run_full_analysis(self):
        """
        Run every diagnostic on at most 50 batches and print a summary.

        Returns:
            dict with keys 'collapse', 'geometric', 'linear_probe',
            'pentachora' and 'visual_clusters'.
        """
        print("="*60); print("COMPREHENSIVE FEATURE ANALYSIS"); print("="*60)
        print("\nExtracting features (margin_mode =", self.margin_mode, ") ...")
        feats = self.extract_all_features(max_batches=50)
        print(f"✓ Extracted features from {len(feats['labels'])} samples")

        res = {}
        res['collapse'] = self.analyze_feature_collapse(feats)
        res['geometric'] = self.analyze_geometric_patterns(feats)

        # Margin stats from PRE-margin cosines
        if feats['similarities'].numel() > 0:
            mu, sig = margin_stats(feats['similarities'], feats['labels'])
            print(f"PRE-margin Δ (pos - bestneg): mean={mu:.3f}, std={sig:.3f}")

        # Face-usage heatmap
        if 'features_proj' in feats and feats['features_proj'].numel() > 0:
            heat = face_usage_heatmap(self.model,
                                      feats['features_proj'].to(self.device),
                                      feats['labels'].to(self.device),
                                      norm_type=getattr(self.model, 'norm_type', 'l1'))
            print("Face-usage heatmap computed [C,10] (display top 3 classes by mass):")
            # Heatmap rows are normalised to sum to 1, so their row sums cannot
            # rank classes; "mass" is taken as the per-class sample count.
            class_mass = torch.bincount(feats['labels'], minlength=heat.size(0))
            top3 = torch.topk(class_mass, k=min(3, heat.size(0))).indices.tolist()
            for c in top3:
                print(f" class {c}: {heat[c].cpu().numpy().round(3)}")

        res['linear_probe'] = self.test_linear_probe(feats)
        res['pentachora'] = self.analyze_pentachora_usage()

        # Visualizations
        _, n_tsne = self.visualize_features(feats, 'tsne')
        _, n_pca = self.visualize_features(feats, 'pca')
        res['visual_clusters'] = {'tsne': n_tsne, 'pca': n_pca}

        # Summary
        print("\n" + "="*60); print("DIAGNOSIS SUMMARY"); print("="*60)
        up = res['collapse']['unique_patterns']
        lp = res['linear_probe']
        if up <= 45 and lp <= 0.42:
            print(f"🔴 Compact regime: {up} unique patterns; linear probe {lp*100:.1f}%")
        elif up > 60:
            print(f"✅ Diverse regime: {up} unique patterns; linear probe {lp*100:.1f}%")
        else:
            print(f"⚡ Partial bottleneck: {up} unique patterns; linear probe {lp*100:.1f}%")
        return res

    # ------------------------------
    # Helpers (unchanged interface)
    # ------------------------------
    def _count_unique_patterns(self, features, method='elbow'):
        """
        Estimate the number of distinct feature clusters with KMeans.

        Uses at most the first 3000 rows. ``'elbow'`` picks the k in
        {20,25,...,75} with the largest absolute second difference of inertia;
        any other method picks the silhouette-optimal k in {30,35,...,55}.
        """
        X = features[:min(3000, len(features))]

        if method == 'elbow':
            candidate_k = list(range(20, 80, 5))
            inertias = [
                KMeans(n_clusters=k, random_state=42, n_init=3).fit(X).inertia_
                for k in candidate_k
            ]
            curvature = np.diff(inertias, n=2)
            if len(curvature) == 0:
                return 41
            # +1: the second difference at index i is centred on candidate i+1
            return candidate_k[int(np.argmax(np.abs(curvature))) + 1]

        candidate_k = list(range(30, 60, 5))
        scores = []
        for k in candidate_k:
            assignment = KMeans(n_clusters=k, random_state=42, n_init=3).fit_predict(X)
            scores.append(silhouette_score(X, assignment))
        return candidate_k[int(np.argmax(scores))]


# ============================================
# QUICK TEST (RoseFace-aware)
# ============================================

def quick_41_percent_test(model, test_loader, device=None, apply_margin_eval=False):
    """
    If apply_margin_eval=True, pass targets to model at eval (margin applied).
    Otherwise, evaluate without margin (classic).

    Returns:
        dict with 'accuracy' (percent), 'unique_patterns', 'is_bottlenecked'
        and 'pentachora_similarity'.
    """
    print("="*60); print("41% ACCURACY CAP HYPOTHESIS TEST"); print("="*60)
    model.eval()
    device = device if device is not None else next(model.parameters()).device

    # 1) Accuracy
    print("\n1. Verifying model accuracy...")
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images, targets=labels) if apply_margin_eval else model(images)
            pred = outputs['logits'].argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)
    # An empty loader yields 0% instead of a ZeroDivisionError.
    acc = 100 * correct / total if total else 0.0
    policy = "WITH margin" if apply_margin_eval else "NO margin"
    print(f" Test Accuracy ({policy}): {acc:.2f}%")
    is_at_cap = abs(acc - 41.0) < 3.0

    # 2) Focused analysis (small sample)
    print("\n2. Analyzing feature patterns...")
    margin_mode = 'apply' if apply_margin_eval else 'none'
    analyzer = FeatureAnalyzer(model, test_loader, device=device, margin_mode=margin_mode)
    feats = analyzer.extract_all_features(max_batches=20)

    # The offline head needs the RoseFace face weights and projected features.
    if hasattr(model, 'rose_face_weights') and feats['features_proj'].numel() > 0:
        acc_rose5 = offline_head_eval_rose5(model, feats['features_proj'].to(device), feats['labels'])
        print(f"Offline prototype eval (rose5, no margin): {acc_rose5*100:.2f}%")
    else:
        print("Offline prototype eval (rose5, no margin): skipped (not available for this model)")

    collapse = analyzer.analyze_feature_collapse(feats)
    pent = analyzer.analyze_pentachora_usage()
    n_patterns = collapse['unique_patterns']
    bottlenecked = n_patterns <= 45

    print("\n" + "="*60); print("VERDICT"); print("="*60)
    if is_at_cap and bottlenecked:
        print("🔴 41% CAP CONFIRMED")
        print(f" Acc: {acc:.1f}% | Unique patterns: {n_patterns}")
        print(" Likely geometric bottleneck.")
    elif bottlenecked:
        print("⚠️ FEATURE BOTTLENECK DETECTED")
        print(f" {n_patterns} patterns; Acc={acc:.1f}%")
    else:
        print("✅ NO 41% BOTTLENECK")
        print(f" {n_patterns} patterns; Acc={acc:.1f}%")

    return {
        'accuracy': acc,
        'unique_patterns': n_patterns,
        'is_bottlenecked': bottlenecked,
        'pentachora_similarity': pent['mean_similarity'],
    }


@torch.no_grad()
def offline_head_eval_rose5(model, features_proj, labels):
    """
    Classify projected features against prototypes rebuilt from the pentachora.

    Each class prototype is the L2-normalised sum of the vertex mean and half
    the mean of the weighted faces (``model.rose_face_weights`` applied to the
    L2-normalised vertices); prediction is the arg-max cosine.

    Returns:
        Accuracy as a fraction in [0,1].
    """
    # compute z_l2 (dual-norm bridge)
    z = features_proj
    z_l2 = z / (z.norm(p=2, dim=-1, keepdim=True) + 1e-12)

    # build rose5 prototypes [C,D]
    V = torch.stack([p.vertices for p in model.class_pentachora], dim=0).to(z.device, z.dtype)  # [C,5,D]
    V = V / (V.norm(p=2, dim=-1, keepdim=True) + 1e-12)
    W = model.rose_face_weights.to(z.device, z.dtype)  # [10,5]
    faces = torch.einsum('tf,cfd->ctd', W, V)
    proto = V.mean(dim=1) + 0.5 * faces.mean(dim=1)
    proto = proto / (proto.norm(p=2, dim=-1, keepdim=True) + 1e-12)  # [C,D]

    cos = z_l2 @ proto.t()  # [N,C]
    return (cos.argmax(dim=1) == labels.to(z.device)).float().mean().item()


# ==========================
# RUN ANALYSIS (example)
# ==========================
if __name__ == "__main__":
    # `model` and `get_cifar100_dataloaders` must already exist in the session.
    print("Starting RoseFace-aware feature analysis...")
    print(f"Model device: {next(model.parameters()).device}")

    # Dataloaders
    train_loader, test_loader, train_transforms = get_cifar100_dataloaders(batch_size=128)

    # Quick test in BOTH modes (optional): compare accuracy
    print("\nRunning quick 41% hypothesis test (NO margin at eval)...")
    res_nom = quick_41_percent_test(model, test_loader, apply_margin_eval=False)

    print("\nRunning quick 41% hypothesis test (WITH margin at eval)...")
    res_mar = quick_41_percent_test(model, test_loader, apply_margin_eval=True)

    # Full analysis with richer diagnostics (no margin at eval is typical)
    analyzer = FeatureAnalyzer(model, test_loader, margin_mode='none')
    full_results = analyzer.run_full_analysis()