| | """ |
| | Comprehensive testing cell for BaselineViT (RoseFace-aware) |
| | Run AFTER loading your model & checkpoint in Colab. |
| | Assumes: model, get_cifar100_dataloaders are defined. |
| | """ |
| |
|
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score
from tqdm import tqdm
| |
|
| | |
| | |
| | |
| |
|
@torch.no_grad()
def principal_angle_overlap(class_pentachora):
    """
    Measure subspace overlap between classes (lower = better decoupling).

    Each class contributes the subspace spanned by its centered vertices.
    Since 5 centered vertices span at most 4 dimensions, the basis is
    obtained from a rank-truncated SVD (a plain reduced QR would emit a
    numerically arbitrary 5th column that inflates the overlap).

    Args:
        class_pentachora: sequence of objects exposing a ``vertices``
            tensor of shape [k, D] (float dtype assumed).

    Returns:
        (mean_fro, std_fro): mean and std of the Frobenius norm of
        U_a^T U_b across all class pairs; (0.0, 0.0) when there are
        fewer than two classes.
    """
    if not class_pentachora:
        return 0.0, 0.0
    ref = class_pentachora[0].vertices
    device, dtype = ref.device, ref.dtype

    bases = []
    for p in class_pentachora:
        V = p.vertices.to(device=device, dtype=dtype)
        A = (V - V.mean(dim=0, keepdim=True)).t()  # [D, k], columns = centered vertices
        U, S, _ = torch.linalg.svd(A, full_matrices=False)
        # Keep only directions actually spanned (standard rank tolerance).
        tol = S.max() * max(A.shape) * torch.finfo(S.dtype).eps
        r = int((S > tol).sum().item())
        bases.append(U[:, :r])

    overlaps = []
    C = len(bases)
    for a in range(C):
        for b in range(a + 1, C):
            M = bases[a].t() @ bases[b]
            overlaps.append(torch.linalg.norm(M, 'fro').item())
    if not overlaps:
        return 0.0, 0.0
    return float(np.mean(overlaps)), float(np.std(overlaps))
| |
|
@torch.no_grad()
def face_usage_heatmap(model, features_proj, targets, norm_type='l1'):
    """
    Compute per-class face (triad) usage heatmap [C, 10].

    For each sample, the winning face is the triangular face (of the 10
    faces of the class pentachoron) whose normalized barycenter has the
    highest dot product with the sample's projected feature.

    Args:
        model: exposes ``num_classes`` and ``class_pentachora`` (each with
            ``vertices`` [5, D] and ``vertices_norm``).
        features_proj: [N, D] L1-normalized features (from model forward outputs).
        targets: [N] integer class labels.
        norm_type: 'l1' uses ``vertices_norm`` and L1-renormalizes barycenters;
            anything else L2-normalizes.

    Returns:
        [C, 10] float tensor; each row sums to ~1 (0 for unseen classes).
    """
    device = features_proj.device
    C = model.num_classes
    # The 10 triangular faces of a pentachoron (4-simplex on 5 vertices).
    triplets = torch.tensor([
        [0, 1, 2], [0, 1, 3], [0, 1, 4],
        [0, 2, 3], [0, 2, 4], [0, 3, 4],
        [1, 2, 3], [1, 2, 4], [1, 3, 4],
        [2, 3, 4],
    ], device=device, dtype=torch.long)
    counts = torch.zeros(C, 10, device=device, dtype=torch.long)

    for cls in torch.unique(targets):
        f = features_proj[targets == cls]
        p = model.class_pentachora[int(cls)]
        Vn = p.vertices_norm if norm_type == 'l1' else F.normalize(p.vertices, dim=-1)

        # Face barycenters, renormalized in the same geometry as the features.
        faces = Vn[triplets].mean(dim=1)  # [10, D]
        if norm_type == 'l1':
            faces = faces / (faces.abs().sum(dim=-1, keepdim=True) + 1e-8)
        else:
            faces = F.normalize(faces, dim=-1)

        winner = (f @ faces.t()).argmax(dim=1)  # best-matching face per sample
        counts[int(cls)] += torch.bincount(winner, minlength=10)

    counts = counts.float()
    return counts / (counts.sum(dim=1, keepdim=True) + 1e-9)
| |
|
@torch.no_grad()
def margin_stats(cos_pre, targets):
    """
    Compute margin Δ = pos - best_neg from PRE-margin cosines.

    Args:
        cos_pre: [N, C] pre-margin cosine similarities.
        targets: [N] integer class labels.

    Returns:
        (mean, std) of the per-sample margin as Python floats.
    """
    rows = torch.arange(cos_pre.size(0), device=cos_pre.device)
    pos = cos_pre[rows, targets]
    # Suppress the target column so the row max picks the strongest rival.
    rivals = cos_pre.clone()
    rivals[rows, targets] = -1e9
    best_neg = rivals.max(dim=1).values
    delta = pos - best_neg
    return float(delta.mean()), float(delta.std())
| |
|
| | |
| | |
| | |
| |
|
class FeatureAnalyzer:
    """
    Analyze feature capacity and geometric patterns.

    RoseFace-aware:
      - can run with or without margin at inference (margin_mode)
      - stores pre-margin cosines and post-margin cosines
    """

    def __init__(self, model, dataloader, device=None, margin_mode='none'):
        """
        Args:
            model: classifier whose forward supports ``return_features`` (and,
                for margin heads, an optional ``targets`` kwarg).
            dataloader: iterable yielding (images, labels) batches.
            device: inference device; defaults to the model's parameter device.
            margin_mode:
                'none'  -> don't pass targets to model (no margin at eval)
                'apply' -> pass targets (apply margin at eval)
                'both'  -> run both (twice); store *_nomargin and *_margin
        """
        self.model = model
        self.dataloader = dataloader
        self.device = device or next(model.parameters()).device
        self.model.eval()
        assert margin_mode in ('none', 'apply', 'both')
        self.margin_mode = margin_mode

    def _forward_once(self, images, labels, apply_margin):
        """Single forward pass; returns detached CPU copies of tensor outputs."""
        if apply_margin:
            outputs = self.model(images, return_features=True, targets=labels)
        else:
            outputs = self.model(images, return_features=True)

        out = {k: v.detach().cpu() for k, v in outputs.items() if isinstance(v, torch.Tensor)}

        # RoseFace logits are s * cos(...); divide the scale back out to
        # recover post-margin cosines for analysis.
        if getattr(self.model, 'head_type', 'legacy') == 'roseface':
            s = float(getattr(self.model, 'scale_s', 1.0))
            if s > 0 and 'logits' in out:
                out['cos_post'] = (out['logits'] / s)
        return out

    def extract_all_features(self, max_batches=None):
        """
        Extract features, pre-margin cosines, post-margin cosines (if available).
        Returns dict with keys:
          - cls_features
          - features_proj
          - similarities (pre-margin cos)
          - cos_post (post-margin cos; RoseFace only)
          - logits
          - labels
        If margin_mode == 'both', suffix *_nomargin / *_margin are included.
        """
        agg = {}

        def append_batch(prefix, out_tensors, labels):
            # Accumulate per-batch tensors under a prefixed key ('' or 'm_').
            for k, v in out_tensors.items():
                agg.setdefault(f'{prefix}{k}', []).append(v)
            agg.setdefault(f'{prefix}labels', []).append(labels.cpu())

        with torch.no_grad():
            for i, (images, labels) in enumerate(tqdm(self.dataloader, desc="Extracting features")):
                if max_batches is not None and i >= max_batches:
                    break
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)

                if self.margin_mode in ('none', 'both'):
                    out0 = self._forward_once(images, labels, apply_margin=False)
                    append_batch('', out0, labels)

                if self.margin_mode in ('apply', 'both'):
                    out1 = self._forward_once(images, labels, apply_margin=True)
                    append_batch('m_', out1, labels)

        # Concatenate the per-batch lists into single tensors.
        for k, lst in agg.items():
            agg[k] = torch.cat(lst, dim=0)

        def pick(key: str):
            # Prefer the no-margin tensor; fall back to the margin run.
            return agg.get(key, agg.get(f"m_{key}", torch.empty(0)))

        if self.margin_mode == 'both':
            features = {
                'cls_features': pick('features'),
                'features_proj': pick('features_proj'),
                'similarities': agg.get('similarities', torch.empty(0)),
                'cos_post': agg.get('cos_post', torch.empty(0)),
                'labels': agg.get('labels', torch.empty(0)),
                'similarities_margin': agg.get('m_similarities', torch.empty(0)),
                'cos_post_margin': agg.get('m_cos_post', torch.empty(0)),
                'logits': agg.get('logits', torch.empty(0)),
                'logits_margin': agg.get('m_logits', torch.empty(0)),
            }
        else:
            features = {
                'cls_features': pick('features'),
                'features_proj': pick('features_proj'),
                'similarities': pick('similarities'),
                'cos_post': pick('cos_post'),
                'labels': pick('labels'),
                'logits': pick('logits'),
            }
        return features

    def analyze_feature_collapse(self, features):
        """Estimate how many distinct feature patterns the model produces."""
        print("\n=== FEATURE COLLAPSE ANALYSIS ===")
        cls_features = features['cls_features'].numpy()
        unique_patterns = self._count_unique_patterns(cls_features)
        print(f"Estimated unique patterns: {unique_patterns}/100 classes")

        feature_std = cls_features.std(axis=0).mean()
        print(f"Average feature std: {feature_std:.4f}")

        labels = features['labels'].numpy()
        sample_size = min(1000, len(labels))
        indices = np.random.choice(len(labels), sample_size, replace=False)
        # silhouette_score raises ValueError if the sample holds < 2 classes.
        if len(np.unique(labels[indices])) > 1:
            silhouette = silhouette_score(cls_features[indices], labels[indices])
        else:
            silhouette = float('nan')
        print(f"Silhouette score: {silhouette:.3f}")

        # Count classes whose centroids sit unusually close to other centroids.
        class_features = {}
        for i in range(100):
            mask = labels == i
            if mask.sum() > 0:
                class_features[i] = cls_features[mask].mean(axis=0)
        if class_features:
            centroids = np.stack(list(class_features.values()))
            d = np.linalg.norm(centroids[:, None] - centroids[None, :], axis=2)
            thr = np.percentile(d[d > 0], 20)
            close_pairs = (d < thr) & (d > 0)
            classes_with_close_neighbors = close_pairs.sum(axis=1)
            print(f"Classes with very similar features: {(classes_with_close_neighbors > 2).sum()}/100")

        return {'unique_patterns': unique_patterns, 'feature_std': feature_std, 'silhouette': silhouette}

    def analyze_geometric_patterns(self, features):
        """Summarize the cosine-similarity geometry between features and prototypes."""
        print("\n=== GEOMETRIC PATTERN ANALYSIS ===")
        sims = features['similarities']
        print(f"Average max cosine: {sims.max(dim=1)[0].mean():.3f}")
        print(f"Average min cosine: {sims.min(dim=1)[0].mean():.3f}")
        print(f"Cosine std: {sims.std():.3f}")

        high_sim_threshold = sims.mean() + sims.std()
        high_sim_count = (sims > high_sim_threshold).sum(dim=1).float().mean()
        print(f"Avg classes above (mean+std): {high_sim_count:.1f}/100")

        labels = features['labels']
        correct = sims.gather(1, labels.view(-1,1)).squeeze(1).mean().item()
        wrong = (sims.sum(dim=1) - sims.gather(1, labels.view(-1,1)).squeeze(1)) / (sims.size(1)-1)
        margin = (correct - wrong.mean().item())
        print(f"Avg cosine margin (correct - mean wrong): {margin:.3f}")

        # If post-margin cosines exist (RoseFace), report how far the margin
        # shifted the target-class cosine.
        if 'cos_post' in features and features['cos_post'].numel() > 0:
            cos_post = features['cos_post']
            pos_pre = sims.gather(1, labels.view(-1,1))
            pos_post = cos_post.gather(1, labels.view(-1,1))
            shift = (pos_post - pos_pre).mean().item()
            print(f"Avg target cosine shift (post - pre): {shift:.3f}")

        return {
            'max_cos': sims.max(dim=1)[0].mean().item(),
            'cos_std': sims.std().item(),
            'high_sim_classes': high_sim_count.item(),
            'discrimination_margin': margin
        }

    def test_linear_probe(self, features, num_epochs=50):
        """Train a linear probe on extracted features; returns best test accuracy."""
        print("\n=== LINEAR PROBE TEST ===")
        X = features['cls_features']
        y = features['labels']
        # NOTE(review): sequential 80/20 split assumes the dataloader was
        # shuffled; a class-ordered loader would bias this probe — confirm.
        n_train = int(0.8 * len(y))
        X_train, y_train = X[:n_train], y[:n_train]
        X_test, y_test = X[n_train:], y[n_train:]

        probe = torch.nn.Linear(X_train.shape[1], 100).to(self.device)
        opt = torch.optim.Adam(probe.parameters(), lr=0.01)

        X_train = X_train.to(self.device); y_train = y_train.to(self.device)
        X_test = X_test.to(self.device); y_test = y_test.to(self.device)

        best = 0.0
        for epoch in range(num_epochs):
            # Full-batch gradient steps; the probe is tiny so this is cheap.
            logits = probe(X_train)
            loss = F.cross_entropy(logits, y_train)
            opt.zero_grad(); loss.backward(); opt.step()
            if epoch % 10 == 0:
                with torch.no_grad():
                    acc = (probe(X_test).argmax(dim=1) == y_test).float().mean().item()
                    best = max(best, acc)
                    print(f" Epoch {epoch}: Test acc = {acc*100:.1f}%")
        with torch.no_grad():
            final = (probe(X_test).argmax(dim=1) == y_test).float().mean().item()
            best = max(best, final)
        print(f"Best linear probe accuracy: {best*100:.1f}%")
        return best

    def visualize_features(self, features, method='tsne', n_samples=2000):
        """Project features to 2D (t-SNE or PCA), plot, and estimate cluster count."""
        print(f"\n=== FEATURE VISUALIZATION ({method.upper()}) ===")
        cls_features = features['cls_features'].numpy()
        labels = features['labels'].numpy()
        n_samples = min(n_samples, len(labels))
        idx = np.random.choice(len(labels), n_samples, replace=False)
        X = cls_features[idx]; y = labels[idx]

        print(f"Reducing {n_samples} samples to 2D...")
        reducer = TSNE(n_components=2, random_state=42, perplexity=30) if method=='tsne' else PCA(n_components=2)
        X2 = reducer.fit_transform(X)

        plt.figure(figsize=(12,9))
        scatter = plt.scatter(X2[:,0], X2[:,1], c=y, cmap='nipy_spectral', alpha=0.6, s=15)
        plt.title(f'Feature Space Visualization ({method.upper()})'); plt.xlabel('Comp 1'); plt.ylabel('Comp 2')

        print("Estimating visual clusters...")
        # Pick the K in [30, 55] with the best 2D silhouette, then report the
        # number of populated clusters at that K.
        silhouette_scores, K = [], list(range(30, 60, 5))
        for k in K:
            kmeans = KMeans(n_clusters=k, random_state=42, n_init=3)
            cls_lbl = kmeans.fit_predict(X2)
            silhouette_scores.append(silhouette_score(X2, cls_lbl))
        best_k = K[int(np.argmax(silhouette_scores))]
        kmeans = KMeans(n_clusters=best_k, random_state=42, n_init=5)
        cluster_labels = kmeans.fit_predict(X2)
        n_populated = len(np.unique(cluster_labels))
        plt.text(0.02, 0.98, f'Estimated clusters: {n_populated}', transform=plt.gca().transAxes,
                 va='top', fontsize=12, bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
        cbar = plt.colorbar(scatter, ticks=np.arange(0,100,10)); cbar.set_label('Class', rotation=270, labelpad=15)
        plt.tight_layout(); plt.show()
        return X2, n_populated

    def analyze_pentachora_usage(self):
        """Report prototype geometry stats: centroid similarities, subspace overlap."""
        print("\n=== PENTACHORA USAGE ANALYSIS ===")
        print(f"Classes: {self.model.num_classes}")
        print(f"Embed dim: {self.model.embed_dim} | Penta dim: {self.model.pentachora_dim}")
        print(f"Head: {getattr(self.model,'head_type','legacy')} | Prototype: {getattr(self.model,'prototype_mode','n/a')} | Margin: {getattr(self.model,'margin_type','n/a')}")
        if hasattr(self.model, 'to_pentachora_dim'):
            if isinstance(self.model.to_pentachora_dim, torch.nn.Linear):
                print(f"Projection: Linear {self.model.embed_dim}→{self.model.pentachora_dim}")
            else:
                print("Projection: Identity")

        # Pairwise centroid similarities, off-diagonal entries only.
        centroids = self.model.get_class_centroids()
        sim = centroids @ centroids.t()
        mask = ~torch.eye(self.model.num_classes, dtype=torch.bool, device=sim.device)
        off = sim[mask]
        print(f"\nCentroid sims: mean={off.mean():.3f} max={off.max():.3f} min={off.min():.3f}")

        mean_fro, std_fro = principal_angle_overlap(self.model.class_pentachora)
        print(f"Principal-angle Fro overlap: mean={mean_fro:.3f} ± {std_fro:.3f} (lower is better)")

        return {'mean_similarity': off.mean().item(), 'max_similarity': off.max().item(), 'mean_fro_overlap': mean_fro}

    def run_full_analysis(self):
        """Run the full battery of analyses and return a results dict."""
        print("="*60); print("COMPREHENSIVE FEATURE ANALYSIS"); print("="*60)
        print("\nExtracting features (margin_mode =", self.margin_mode, ") ...")
        feats = self.extract_all_features(max_batches=50)
        print(f"✓ Extracted features from {len(feats['labels'])} samples")

        res = {}
        res['collapse'] = self.analyze_feature_collapse(feats)
        res['geometric'] = self.analyze_geometric_patterns(feats)

        mu, sig = margin_stats(feats['similarities'], feats['labels'])
        print(f"PRE-margin Δ (pos - bestneg): mean={mu:.3f}, std={sig:.3f}")

        if 'features_proj' in feats and feats['features_proj'].numel() > 0:
            heat = face_usage_heatmap(self.model, feats['features_proj'].to(self.device), feats['labels'].to(self.device), norm_type=getattr(self.model,'norm_type','l1'))
            print("Face-usage heatmap computed [C,10] (display top 3 classes by mass):")
            class_mass = heat.sum(dim=1)
            top3 = torch.topk(class_mass, k=min(3, heat.size(0))).indices.tolist()
            for c in top3:
                print(f" class {c}: {heat[c].cpu().numpy().round(3)}")

        res['linear_probe'] = self.test_linear_probe(feats)
        res['pentachora'] = self.analyze_pentachora_usage()

        _, n_tsne = self.visualize_features(feats, 'tsne')
        _, n_pca = self.visualize_features(feats, 'pca')
        res['visual_clusters'] = {'tsne': n_tsne, 'pca': n_pca}

        print("\n" + "="*60); print("DIAGNOSIS SUMMARY"); print("="*60)
        up = res['collapse']['unique_patterns']; lp = res['linear_probe']
        if up <= 45 and lp <= 0.42:
            print(f"🔴 Compact regime: {up} unique patterns; linear probe {lp*100:.1f}%")
        elif up > 60:
            print(f"✅ Diverse regime: {up} unique patterns; linear probe {lp*100:.1f}%")
        else:
            print(f"⚡ Partial bottleneck: {up} unique patterns; linear probe {lp*100:.1f}%")

        return res

    def _count_unique_patterns(self, features, method='elbow'):
        """Estimate the number of distinct feature patterns via k-means."""
        X = features[:min(3000, len(features))]
        if method == 'elbow':
            # Knee of the inertia curve = largest second difference.
            inertias, K = [], list(range(20, 80, 5))
            for k in K:
                km = KMeans(n_clusters=k, random_state=42, n_init=3)
                km.fit(X); inertias.append(km.inertia_)
            diffs = np.diff(inertias); diffs2 = np.diff(diffs)
            if len(diffs2) > 0:
                elbow_idx = int(np.argmax(np.abs(diffs2))) + 1
                est = K[elbow_idx]
            else:
                est = 41
        else:
            scores, K = [], list(range(30, 60, 5))
            for k in K:
                km = KMeans(n_clusters=k, random_state=42, n_init=3)
                lbl = km.fit_predict(X)
                scores.append(silhouette_score(X, lbl))
            est = K[int(np.argmax(scores))]
        return est
| |
|
| | |
| | |
| | |
| |
|
def quick_41_percent_test(model, test_loader, device=None, apply_margin_eval=False):
    """
    Test whether the model is stuck at the ~41% accuracy cap.

    If apply_margin_eval=True, pass targets to model at eval (margin applied).
    Otherwise, evaluate without margin (classic).

    Returns:
        dict with accuracy, unique-pattern estimate, bottleneck flag and
        mean pentachora centroid similarity.
    """
    print("="*60); print("41% ACCURACY CAP HYPOTHESIS TEST"); print("="*60)
    model.eval()
    device = device or next(model.parameters()).device

    print("\n1. Verifying model accuracy...")
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images, targets=labels) if apply_margin_eval else model(images)
            pred = outputs['logits'].argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)
    acc = 100 * correct / max(total, 1)  # guard against an empty loader
    policy = "WITH margin" if apply_margin_eval else "NO margin"
    print(f" Test Accuracy ({policy}): {acc:.2f}%")

    is_at_cap = abs(acc - 41.0) < 3.0

    print("\n2. Analyzing feature patterns...")
    margin_mode = 'apply' if apply_margin_eval else 'none'
    analyzer = FeatureAnalyzer(model, test_loader, device=device, margin_mode=margin_mode)
    feats = analyzer.extract_all_features(max_batches=20)
    acc_rose5 = offline_head_eval_rose5(model, feats['features_proj'].to(device), feats['labels'])
    print(f"Offline prototype eval (rose5, no margin): {acc_rose5*100:.2f}%")
    collapse = analyzer.analyze_feature_collapse(feats)
    pent = analyzer.analyze_pentachora_usage()

    print("\n" + "="*60); print("VERDICT"); print("="*60)
    if is_at_cap and collapse['unique_patterns'] <= 45:
        print("🔴 41% CAP CONFIRMED")
        print(f" Acc: {acc:.1f}% | Unique patterns: {collapse['unique_patterns']}")
        print(" Likely geometric bottleneck.")
    elif collapse['unique_patterns'] <= 45:
        print("⚠️ FEATURE BOTTLENECK DETECTED")
        print(f" {collapse['unique_patterns']} patterns; Acc={acc:.1f}%")
    else:
        print("✅ NO 41% BOTTLENECK")
        print(f" {collapse['unique_patterns']} patterns; Acc={acc:.1f}%")

    return {
        'accuracy': acc,
        'unique_patterns': collapse['unique_patterns'],
        'is_bottlenecked': collapse['unique_patterns'] <= 45,
        'pentachora_similarity': pent['mean_similarity']
    }
| |
|
@torch.no_grad()
def offline_head_eval_rose5(model, features_proj, labels):
    """
    Offline nearest-prototype accuracy with the 'rose5' head geometry:
    each class prototype is the L2-normalized sum of its vertex centroid
    and 0.5 * the mean weighted-face centroid.
    """
    z_l2 = features_proj / (features_proj.norm(p=2, dim=-1, keepdim=True) + 1e-12)

    # Stack and L2-normalize all class vertices: [C, 5, D].
    vertices = torch.stack([p.vertices for p in model.class_pentachora], dim=0)
    vertices = vertices.to(z_l2.device, z_l2.dtype)
    vertices = vertices / (vertices.norm(p=2, dim=-1, keepdim=True) + 1e-12)

    weights = model.rose_face_weights.to(z_l2.device, z_l2.dtype)
    face_centroids = torch.einsum('tf,cfd->ctd', weights, vertices)

    proto = vertices.mean(dim=1) + 0.5 * face_centroids.mean(dim=1)
    proto = proto / (proto.norm(p=2, dim=-1, keepdim=True) + 1e-12)

    cos = z_l2 @ proto.t()
    hits = cos.argmax(dim=1) == labels.to(z_l2.device)
    return hits.float().mean().item()
| |
|
| | |
| | |
| | |
| |
|
| | if __name__ == "__main__": |
| | print("Starting RoseFace-aware feature analysis...") |
| | print(f"Model device: {next(model.parameters()).device}") |
| | |
| | train_loader, test_loader, train_transforms = get_cifar100_dataloaders(batch_size=128) |
| |
|
| | |
| | print("\nRunning quick 41% hypothesis test (NO margin at eval)...") |
| | res_nom = quick_41_percent_test(model, test_loader, apply_margin_eval=False) |
| |
|
| | print("\nRunning quick 41% hypothesis test (WITH margin at eval)...") |
| | res_mar = quick_41_percent_test(model, test_loader, apply_margin_eval=True) |
| |
|
| | |
| | analyzer = FeatureAnalyzer(model, test_loader, margin_mode='none') |
| | full_results = analyzer.run_full_analysis() |
| |
|