#!/usr/bin/env python3
"""
GeoLIP Core — Full Analysis + Sphere Visualizations
=====================================================
Auto-detects CIFAR-10 vs CIFAR-100 from checkpoint config.
"""

import math
import os
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CKPT = "checkpoints/geolip_core_best.pt"
OUT_DIR = "analysis_out"
BATCH = 256

# ── HuggingFace push ──
HF_REPO_ID = "AbstractPhil/geolip-constellation-core"
HF_PUSH = True

CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

os.makedirs(OUT_DIR, exist_ok=True)

print("=" * 70)
print("GEOLIP CORE — ANALYSIS + SPHERE VISUALIZATIONS")
print(f" Checkpoint: {CKPT}")
print(f" Output: {OUT_DIR}/")
print("=" * 70)

# ══════════════════════════════════════════════════════════════════
# LOAD — auto-detect dataset from config
# ══════════════════════════════════════════════════════════════════

# NOTE(security): weights_only=False unpickles arbitrary Python objects.
# Only load checkpoints that come from a trusted source.
ckpt = torch.load(CKPT, map_location="cpu", weights_only=False)
cfg = ckpt["config"]
N_CLASSES = cfg.get('num_classes', 10)
print(f" Epoch: {ckpt['epoch']} Val acc: {ckpt['val_acc']:.1f}%")
print(f" Config: output_dim={cfg.get('output_dim')}, "
      f"n_anchors={cfg.get('n_anchors')}, "
      f"n_comp={cfg.get('n_comp')}, d_comp={cfg.get('d_comp')}, "
      f"num_classes={N_CLASSES}")

# Up to 10 classes is treated as CIFAR-10, anything larger as CIFAR-100.
if N_CLASSES <= 10:
    CLASS_NAMES = CIFAR10_CLASSES[:N_CLASSES]
    ds_cls = datasets.CIFAR10
    ds_name = "CIFAR-10"
else:
    ds_cls = datasets.CIFAR100
    ds_name = "CIFAR-100"
    # Temporary dataset instance, used only to read the class-name list.
    _tmp = datasets.CIFAR100(root='./data', train=False, download=True)
    CLASS_NAMES = _tmp.classes
    del _tmp
print(f" Dataset: {ds_name} ({N_CLASSES} classes)")

# NOTE(review): GeoLIPCore is neither imported nor defined in this file as
# shown — presumably it comes from an earlier notebook cell or a missing
# import. Confirm before running this as a standalone script.
model = GeoLIPCore(**cfg).to(DEVICE)
model.load_state_dict(ckpt["state_dict"])
model.eval()

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
val_ds = ds_cls(root='./data', train=False, download=True, transform=val_transform)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH, shuffle=False, num_workers=2, pin_memory=True)
total_params = sum(p.numel() for p in model.parameters())

# ══════════════════════════════════════════════════════════════════
# COLLECT ALL EMBEDDINGS + PREDICTIONS
# ══════════════════════════════════════════════════════════════════

print("\n Collecting embeddings...")
all_embs, all_tris, all_nearest = [], [], []
all_labels, all_preds, all_logits = [], [], []
with torch.no_grad():
    for imgs, lbls in val_loader:
        imgs = imgs.to(DEVICE)
        # The model output is read as a dict with the keys used below.
        out = model(imgs)
        all_embs.append(out['embedding'].float().cpu())
        all_tris.append(out['triangulation'].float().cpu())
        all_nearest.append(out['nearest'].cpu())
        all_labels.append(lbls)
        all_preds.append(out['logits'].argmax(-1).cpu())
        all_logits.append(out['logits'].float().cpu())

embs = torch.cat(all_embs)
tris = torch.cat(all_tris)
nearest = torch.cat(all_nearest)
labels = torch.cat(all_labels)
preds = torch.cat(all_preds)
logits = torch.cat(all_logits)
embs_n = F.normalize(embs, dim=-1)

val_acc = (preds == labels).float().mean().item() * 100
print(f" Val accuracy: {val_acc:.1f}%")
print(f" Embeddings: {embs.shape}")

# ══════════════════════════════════════════════════════════════════
# ANCHOR PUSH — drag anchors to where the data lives
# ══════════════════════════════════════════════════════════════════

N_PUSH_STEPS = 30
PUSH_LR = 0.5
print(f"\n Pushing anchors toward CLASS centroids ({N_PUSH_STEPS} steps, lr={PUSH_LR})...")

# Before stats
anchors_before = model.constellation.anchors.detach().float().cpu().clone()
anch_n_before = F.normalize(anchors_before, dim=-1)
cos_before = (embs_n @ anch_n_before.T).max(dim=1).values.mean().item()
print(f" Before: mean nearest_cos = {cos_before:.4f}")

# Push using class centroids
emb_device = embs.to(DEVICE)
lbl_device = labels.to(DEVICE)

if hasattr(model, 'push_anchors_to_centroids'):
    # Prefer the model's own push routine when it provides one.
    for step in range(N_PUSH_STEPS):
        moved = model.push_anchors_to_centroids(emb_device, lbl_device, lr=PUSH_LR)
        if (step + 1) % 10 == 0:
            an_tmp = F.normalize(
                model.constellation.anchors.detach().float().cpu(), dim=-1)
            c_tmp = (embs_n @ an_tmp.T).max(dim=1).values.mean().item()
            print(f" Step {step+1:3d}: nearest_cos = {c_tmp:.4f}, moved = {moved}")
else:
    # Inline class-centroid push
    with torch.no_grad():
        # `.data` is written in place below, so the model's anchors move.
        anchors_param = model.constellation.anchors.data
        emb_dev = F.normalize(emb_device, dim=-1)

        # Compute class centroids once
        classes = lbl_device.unique()
        n_cls = classes.shape[0]
        centroids = []
        for c in classes:
            mask = lbl_device == c
            centroids.append(
                F.normalize(emb_dev[mask].mean(0, keepdim=True), dim=-1))
        centroids = torch.cat(centroids, dim=0)  # (C, D)

        # Each class may claim at most anchors_per_class + 1 anchors in the
        # greedy assignment below.
        n_a = anchors_param.shape[0]
        anchors_per_class = n_a // n_cls

        for step in range(N_PUSH_STEPS):
            an = F.normalize(anchors_param, dim=-1)
            cos_ac = an @ centroids.T  # (A, C)

            # Greedy assign: walk (anchor, class) pairs from most to least
            # similar. Bookkeeping is done in Python lists so each iteration
            # does not have to read a value back from a device tensor.
            assigned_list = [-1] * n_a
            cls_count = [0] * n_cls
            n_assigned = 0
            _, flat_idx = cos_ac.flatten().sort(descending=True)
            for flat in flat_idx.tolist():
                a, c_idx = divmod(flat, n_cls)
                if assigned_list[a] >= 0:
                    continue
                if cls_count[c_idx] >= anchors_per_class + 1:
                    continue
                assigned_list[a] = c_idx
                cls_count[c_idx] += 1
                n_assigned += 1
                if n_assigned == n_a:
                    break
            assigned = torch.tensor(assigned_list, dtype=torch.long, device=DEVICE)

            # Fallback: any anchor still unassigned takes its nearest centroid.
            unassigned = (assigned < 0).nonzero(as_tuple=True)[0]
            if len(unassigned) > 0:
                assigned[unassigned] = (an[unassigned] @ centroids.T).argmax(dim=1)

            # Push each anchor toward its class centroid
            for a in range(n_a):
                target = centroids[assigned[a].item()]
                # rank = how many earlier anchors share this anchor's class.
                rank = (assigned[:a] == assigned[a]).sum().item()
                if rank > 0:
                    # Jitter the target for every anchor after the first in a
                    # class so they do not all move to the same point. The
                    # component of the noise along `target` is removed first.
                    noise = torch.randn_like(target) * 0.05
                    noise = noise - (noise * target).sum() * target
                    target = F.normalize(
                        (target + noise).unsqueeze(0), dim=-1).squeeze(0)
                anchors_param[a] = F.normalize(
                    (an[a] + PUSH_LR * (target - an[a])).unsqueeze(0),
                    dim=-1).squeeze(0)

            if (step + 1) % 10 == 0:
                an_tmp = F.normalize(anchors_param, dim=-1)
                c_tmp = (emb_dev @ an_tmp.T).max(dim=1).values.mean().item()
                print(f" Step {step+1:3d}: nearest_cos = {c_tmp:.4f}")

# After stats
anchors = model.constellation.anchors.detach().float().cpu()
anchors_n = F.normalize(anchors, dim=-1)
n_anchors = anchors.shape[0]
cos_after = (embs_n @ anchors_n.T).max(dim=1).values.mean().item()
drift = (anch_n_before - anchors_n).norm(dim=-1).mean().item()
print(f" After: mean nearest_cos = {cos_after:.4f} (Δ={cos_after - cos_before:+.4f})")
print(f" Anchor drift: {drift:.4f}")

# Re-triangulate with pushed anchors. This overwrites the `tris` / `nearest`
# collected from the model with (1 - cosine) to the pushed anchors.
with torch.no_grad():
    new_cos = embs_n @ anchors_n.T
    tris = 1.0 - new_cos
    nearest = new_cos.argmax(dim=1)
print(f" Anchors: {anchors.shape}")

# ══════════════════════════════════════════════════════════════════
# AUDIT 1: NUMERIC HEALTH
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*70}")
print("AUDIT 1: NUMERIC HEALTH")
print(f"{'='*70}")

issues = []
for name, param in model.named_parameters():
    p = param.detach().float()
    n_nan = torch.isnan(p).sum().item()
    n_inf = torch.isinf(p).sum().item()
    # std() is only computed when the tensor has more than one element.
    p_std = p.std().item() if p.numel() > 1 else 0
    flags = []
    if n_nan > 0:
        flags.append(f"NaN={n_nan}")
    if n_inf > 0:
        flags.append(f"inf={n_inf}")
    if p_std < 1e-8 and p.numel() > 1:
        flags.append(f"COLLAPSED(std={p_std:.2e})")
    if flags:
        print(f" ⚠ {name:<50} {' '.join(flags)}")
        issues.append(name)
if not issues:
    print(f" ✓ All {total_params:,} parameters clean")

# ══════════════════════════════════════════════════════════════════
# AUDIT 2: PER-CLASS ACCURACY
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*70}")
print("AUDIT 2: PER-CLASS ACCURACY")
print(f"{'='*70}")

class_accs = []
for c in range(N_CLASSES):
    mask = labels == c
    # Classes with no samples are recorded as 0% rather than NaN.
    acc = (preds[mask] == c).float().mean().item() * 100 if mask.sum() > 0 else 0
    class_accs.append(acc)
if N_CLASSES <= 10:
    for c in range(N_CLASSES):
        print(f" {CLASS_NAMES[c]:<12}: {class_accs[c]:5.1f}%")
else:
    # Too many classes to list: show only the 10 worst and the 10 best.
    sorted_idx = sorted(range(N_CLASSES), key=lambda c: class_accs[c])
    print(" Bottom 10:")
    for c in sorted_idx[:10]:
        print(f" {CLASS_NAMES[c]:<20}: {class_accs[c]:5.1f}%")
    print(" Top 10:")
    for c in sorted_idx[-10:]:
        print(f" {CLASS_NAMES[c]:<20}: {class_accs[c]:5.1f}%")
print(f" Mean: {np.mean(class_accs):.1f}% "
      f"Median: {np.median(class_accs):.1f}% "
      f"Std: {np.std(class_accs):.1f}%")

# ══════════════════════════════════════════════════════════════════
# AUDIT 3: EMBEDDING SPACE
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*70}")
print("AUDIT 3: EMBEDDING SPACE")
print(f"{'='*70}")

# Pairwise cosine similarity over the first n_sample embeddings.
n_sample = min(2000, len(embs))
sim = embs_n[:n_sample] @ embs_n[:n_sample].T
sim_mask = ~torch.eye(n_sample, dtype=torch.bool)  # off-diagonal entries
labels_s = labels[:n_sample]
same_class = labels_s.unsqueeze(0) == labels_s.unsqueeze(1)
same_not_self = same_class & sim_mask
diff_class = ~same_class & sim_mask

# "Self-similarity" below is the mean over ALL off-diagonal pairs.
self_sim = sim[sim_mask].mean().item()
same_cos = sim[same_not_self].mean().item() if same_not_self.any() else 0
diff_cos = sim[diff_class].mean().item() if diff_class.any() else 0
gap = same_cos - diff_cos

# Effective dimension: inverse of the sum of squared normalised singular
# values, computed on the first 512 embeddings.
_, S, _ = torch.linalg.svd(embs_n[:512].float(), full_matrices=False)
p = S / S.sum()
eff_dim = p.pow(2).sum().reciprocal().item()

print(f" Self-similarity: {self_sim:.4f}")
print(f" Same-class cos: {same_cos:.4f}")
print(f" Diff-class cos: {diff_cos:.4f}")
print(f" Gap: {gap:.4f}")
print(f" Effective dim: {eff_dim:.1f}/{embs.shape[1]}")

# ══════════════════════════════════════════════════════════════════
# AUDIT 4: CONSTELLATION HEALTH
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*70}")
print("AUDIT 4: CONSTELLATION HEALTH")
print(f"{'='*70}")

anch_sim = anchors_n @ anchors_n.T
anch_mask = ~torch.eye(n_anchors, dtype=torch.bool)
anch_off = anch_sim[anch_mask]
n_active = nearest.unique().numel()
# Number of samples whose nearest anchor is each anchor.
counts = torch.bincount(nearest, minlength=n_anchors)

print(f" Anchors: {n_anchors} × {anchors.shape[1]}")
print(f" Pairwise cos: mean={anch_off.mean():.4f} max={anch_off.max():.4f}")
print(f" Active: {n_active}/{n_anchors}")
print(f" Utilization: min={counts.min().item()} max={counts.max().item()} "
      f"mean={counts.float().mean():.1f} std={counts.float().std():.1f}")

# ══════════════════════════════════════════════════════════════════
# AUDIT 5: PENTACHORON CV
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*70}")
print("AUDIT 5: PENTACHORON CV")
print(f"{'='*70}")

sample = embs_n[:2000].to(DEVICE)
vols = []
with torch.no_grad():
    for _ in range(500):
        # Five random embeddings form one 4-simplex (pentachoron).
        idx = torch.randperm(len(sample), device=DEVICE)[:5]
        pts = sample[idx].unsqueeze(0).float()
        gram = torch.bmm(pts, pts.transpose(1, 2))
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        # Squared pairwise distances from the Gram matrix; relu clamps the
        # small negatives that rounding can produce.
        d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
        d2 = F.relu(d2)
        # Cayley–Menger matrix: border of ones around the squared distances.
        cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=torch.float32)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = d2
        # Squared 4-volume: -det(CM) / (2^4 * (4!)^2) = -det(CM) / 9216.
        v2 = -torch.linalg.det(cm) / 9216
        if v2[0].item() > 1e-20:
            vols.append(v2[0].sqrt().cpu())

if len(vols) > 10:
    vt = torch.stack(vols)
    # Coefficient of variation of the sampled volumes.
    v_cv = (vt.std() / (vt.mean() + 1e-8)).item()
    band = "✓ IN BAND" if 0.18 <= v_cv <= 0.25 else "✗ outside"
    print(f" CV: {v_cv:.4f} ({band})")
    print(f" Vol mean: {vt.mean():.6f} std: {vt.std():.6f}")
else:
    v_cv = 0
    print(f" ⚠ Not enough valid pentachora ({len(vols)})")

# ══════════════════════════════════════════════════════════════════
# AUDIT 6: CONFIDENCE CALIBRATION
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*70}")
print("AUDIT 6: CONFIDENCE CALIBRATION")
print(f"{'='*70}")

probs = logits.softmax(-1)
conf = probs.max(dim=1).values
correct_mask = preds == labels

print(f" Correct: mean_conf={conf[correct_mask].mean():.4f} "
      f"std={conf[correct_mask].std():.4f}")
if (~correct_mask).any():
    wrong_conf = conf[~correct_mask]
    overconf = (wrong_conf > 0.9).sum().item()
    print(f" Wrong: mean_conf={wrong_conf.mean():.4f} "
          f"std={wrong_conf.std():.4f}")
    print(f" Overconfident wrong (>0.9): {overconf}/{wrong_conf.numel()} "
          f"({100*overconf/max(wrong_conf.numel(),1):.1f}%)")

# ══════════════════════════════════════════════════════════════════
# AUDIT 7: GRADIENT FLOW
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*70}")
print("AUDIT 7: GRADIENT FLOW")
print(f"{'='*70}")

model.train()
model.zero_grad()
imgs_g, lbls_g = next(iter(val_loader))
imgs_g = imgs_g[:16].to(DEVICE)
lbls_g = lbls_g[:16].to(DEVICE)
# bf16 autocast is only enabled when actually running on CUDA. With
# enabled=False the context is a no-op, which is what a "cuda" autocast
# already degrades to (with a warning) on a CPU-only machine.
with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=(DEVICE == "cuda")):
    out = model(imgs_g)
    loss = F.cross_entropy(out['logits'], lbls_g) + 0.1 * out['embedding'].mean()
loss.backward()

grad_by_mod = defaultdict(list)
missing_grad = []  # trainable parameters that received no gradient
for name, param in model.named_parameters():
    if param.grad is None:
        if param.requires_grad:
            missing_grad.append(name)
        continue
    gn = param.grad.detach().float().norm().item()
    # Group by the first matching substring of the parameter name.
    if "encoder" in name:
        mod = "encoder"
    elif "constellation" in name:
        mod = "constellation"
    elif "patchwork" in name:
        mod = "patchwork"
    elif "classifier" in name:
        mod = "classifier"
    else:
        mod = "other"
    grad_by_mod[mod].append(gn)

for mod in sorted(grad_by_mod):
    mod_norms = grad_by_mod[mod]
    print(f" {mod:<15}: mean={np.mean(mod_norms):.6f} max={np.max(mod_norms):.6f} "
          f"({len(mod_norms)} params)")

# Only claim full gradient coverage when it was actually observed.
if missing_grad:
    shown = ', '.join(missing_grad[:5])
    more = f" (+{len(missing_grad) - 5} more)" if len(missing_grad) > 5 else ""
    print(f" ⚠ {len(missing_grad)} trainable parameters received no gradient: "
          f"{shown}{more}")
else:
    print(" ✓ All parameters receive gradient")
model.eval()

# ══════════════════════════════════════════════════════════════════
# VISUALIZATIONS
# ══════════════════════════════════════════════════════════════════

try:
    import matplotlib
    matplotlib.use('Agg')  # non-interactive backend; figures are only saved
    import matplotlib.pyplot as plt
    HAS_PLT = True
except ImportError:
    HAS_PLT = False
    print("\n ⚠ matplotlib not available, skipping visualizations")

if HAS_PLT:
    if N_CLASSES <= 10:
        CLASS_COLORS = [
            '#e6194b', '#3cb44b', '#4363d8', '#f58231', '#911eb4',
            '#42d4f4', '#f032e6', '#bfef45', '#469990', '#dcbeff']
    else:
        # Vibrant HSV spiral — one saturated colour per class
        import colorsys
        CLASS_COLORS = []
        for i in range(N_CLASSES):
            # Golden angle rotation for max hue separation
            hue = (i * 0.618033988749895) % 1.0
            # Alternate saturation/value for neighboring hues
            sat = 0.75 + 0.25 * (i % 3) / 2
            val = 0.85 + 0.15 * ((i + 1) % 2)
            r, g, b = colorsys.hsv_to_rgb(hue, sat, val)
            CLASS_COLORS.append(
                f'#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}')

    # Dark theme for all plots — makes colors pop
    plt.style.use('dark_background')
    plt.rcParams.update({
        'figure.facecolor': '#1a1a2e',
        'axes.facecolor': '#16213e',
        'axes.edgecolor': '#444466',
        'axes.labelcolor': '#e0e0e0',
        'text.color': '#e0e0e0',
        'xtick.color': '#aaaacc',
        'ytick.color': '#aaaacc',
        'grid.color': '#333355',
        'legend.facecolor': '#1a1a2e',
        'legend.edgecolor': '#444466',
    })

    print(f"\n{'='*70}")
    print("VISUALIZATIONS")
    print(f"{'='*70}")

    def save_fig(filename, dpi=200):
        """Save the current matplotlib figure as OUT_DIR/filename."""
        plt.savefig(f'{OUT_DIR}/{filename}', dpi=dpi)

    def _pair_histogram(rows, cols, n_rows, n_cols):
        """Count (rows[i], cols[i]) pairs into an (n_rows, n_cols) long tensor.

        `rows` and `cols` are 1-D integer tensors of equal length. Pairs with
        an index outside the requested shape are ignored.
        """
        keep = (rows >= 0) & (rows < n_rows) & (cols >= 0) & (cols < n_cols)
        flat = rows[keep] * n_cols + cols[keep]
        return torch.bincount(flat, minlength=n_rows * n_cols).view(n_rows, n_cols)

    # ── Sphere grid helpers ──
    def draw_sphere_grid_2d(ax, radius, n_meridians=24):
        """Draw a circular reference grid of the given radius on a 2D axes.

        Draws an outer ring (white over a wider cyan ring), dashed inner rings
        at 0.5 and 0.75 of the radius, `n_meridians` radial tick marks,
        crosshairs, and a text label showing the radius.
        """
        print(f" >>> DRAWING 2D GRID: radius={radius:.4f}, lw=5, white+cyan")
        theta = np.linspace(0, 2 * np.pi, 500)
        xr = radius * np.cos(theta)
        yr = radius * np.sin(theta)
        # Cyan glow (fat, behind)
        ax.plot(xr, yr, color='#00e5ff', alpha=0.6, lw=9, zorder=49)
        # White ring on top
        ax.plot(xr, yr, color='white', alpha=1.0, lw=5, zorder=50,
                solid_capstyle='round')
        # Inner rings — dashed cyan
        for frac in [0.5, 0.75]:
            ax.plot(frac * xr, frac * yr, color='#00e5ff', alpha=0.5, lw=2,
                    linestyle='--', zorder=50)
        # Meridian ticks — short radial segments straddling the outer ring
        for i in range(n_meridians):
            a = 2 * np.pi * i / n_meridians
            r0, r1 = radius * 0.92, radius * 1.08
            ax.plot([r0*np.cos(a), r1*np.cos(a)],
                    [r0*np.sin(a), r1*np.sin(a)],
                    color='white', alpha=0.8, lw=2, zorder=50)
        # Crosshairs
        s = radius * 1.15
        ax.plot([-s, s], [0, 0], color='#00e5ff', alpha=0.3, lw=1.5, zorder=49)
        ax.plot([0, 0], [-s, s], color='#00e5ff', alpha=0.3, lw=1.5, zorder=49)
        # Text label showing the radius that was drawn
        ax.text(radius * 0.72, radius * 0.72, f'r={radius:.2f}',
                color='#00e5ff', fontsize=10, fontweight='bold',
                alpha=0.9, zorder=51)

    def draw_sphere_grid_3d(ax, radius, n_lines=16):
        """Draw a wireframe sphere of the given radius on a 3D axes.

        Draws latitude rings, `n_lines` longitude meridians, and a thicker
        cyan equator.
        """
        print(f" >>> DRAWING 3D WIREFRAME: radius={radius:.4f}, lw=1.2+3")
        theta = np.linspace(0, 2 * np.pi, 80)
        phi = np.linspace(0, np.pi, 40)
        # Latitude rings (the two poles are skipped)
        for p in np.linspace(0, np.pi, n_lines + 1)[1:-1]:
            r = radius * np.sin(p)
            z = radius * np.cos(p)
            ax.plot(r * np.cos(theta), r * np.sin(theta),
                    z * np.ones_like(theta),
                    color='white', alpha=0.4, lw=1.2)
        # Longitude meridians
        for t in np.linspace(0, 2 * np.pi, n_lines, endpoint=False):
            x = radius * np.sin(phi) * np.cos(t)
            y = radius * np.sin(phi) * np.sin(t)
            z = radius * np.cos(phi)
            ax.plot(x, y, z, color='white', alpha=0.4, lw=1.2)
        # Equator — bright cyan, extra thick
        ax.plot(radius * np.cos(theta), radius * np.sin(theta),
                np.zeros_like(theta),
                color='#00e5ff', alpha=0.9, lw=3)

    # PCA basis: fitted on the centred first 5000 embeddings, then applied to
    # the un-centred normalised embeddings and anchors.
    embs_c = embs_n[:5000] - embs_n[:5000].mean(0, keepdim=True)
    _, _, Vt = torch.linalg.svd(embs_c, full_matrices=False)
    proj_2d = (embs_n @ Vt[:2].T).numpy()
    proj_3d = (embs_n @ Vt[:3].T).numpy()
    anch_2d = (anchors_n @ Vt[:2].T).numpy()
    anch_3d = (anchors_n @ Vt[:3].T).numpy()
    proj_labels = labels.numpy()

    # Compute sphere radius from projected data (95th percentile of radii)
    emb_radii_2d = np.sqrt(proj_2d[:5000, 0]**2 + proj_2d[:5000, 1]**2)
    sphere_r_2d = np.percentile(emb_radii_2d, 95)
    emb_radii_3d = np.sqrt((proj_3d[:3000]**2).sum(axis=1))
    sphere_r_3d = np.percentile(emb_radii_3d, 95)

    # Sanity: if projections are tiny, use data range instead
    data_range_2d = max(np.abs(proj_2d[:5000]).max(), np.abs(anch_2d).max())
    data_range_3d = max(np.abs(proj_3d[:3000]).max(), np.abs(anch_3d).max())
    if sphere_r_2d < 0.01:
        sphere_r_2d = data_range_2d * 0.9
    if sphere_r_3d < 0.01:
        sphere_r_3d = data_range_3d * 0.9
    print(f" Sphere radius (2D): {sphere_r_2d:.4f} (3D): {sphere_r_3d:.4f}")
    print(f" Data range (2D): {data_range_2d:.4f} (3D): {data_range_3d:.4f}")

    # ── [1] PCA embedding space ──
    print(" [1/8] PCA projection...")
    fig, ax = plt.subplots(1, 1, figsize=(12, 10))
    for c in range(N_CLASSES):
        mask = proj_labels[:5000] == c
        if mask.sum() == 0:
            continue
        # Legend entries only when there are few enough classes to read.
        lbl = CLASS_NAMES[c] if N_CLASSES <= 20 else None
        ax.scatter(proj_2d[:5000][mask, 0], proj_2d[:5000][mask, 1],
                   c=CLASS_COLORS[c], s=4, alpha=0.5, label=lbl, zorder=2)
    ax.scatter(anch_2d[:, 0], anch_2d[:, 1], c='#FFD700', s=60, marker='*',
               edgecolors='white', linewidths=0.3, zorder=5, label='anchors')
    # Grid drawn LAST — on top of everything
    draw_sphere_grid_2d(ax, sphere_r_2d)
    if N_CLASSES <= 20:
        ax.legend(fontsize=7, markerscale=2, loc='upper right', ncol=2)
    ax.set_title(f'GeoLIP Core — PCA Embedding Space ({ds_name})\n'
                 f'val={val_acc:.1f}% | {total_params:,} params | '
                 f'CV={v_cv:.4f} | {n_active}/{n_anchors} anchors',
                 fontsize=11)
    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.15, color='#555577')
    plt.tight_layout()
    save_fig('01_pca_embedding_space.png')
    plt.close()

    # ── [2] Triangulation connections ──
    print(" [2/8] Triangulation connections...")
    fig, ax = plt.subplots(1, 1, figsize=(12, 10))
    subset = min(500, len(embs))
    # One faint line from each of the first `subset` samples to its anchor.
    for i in range(subset):
        a_idx = nearest[i].item()
        ax.plot([proj_2d[i, 0], anch_2d[a_idx, 0]],
                [proj_2d[i, 1], anch_2d[a_idx, 1]],
                c=CLASS_COLORS[labels[i].item()], alpha=0.1, linewidth=0.5)
    for c in range(N_CLASSES):
        mask = proj_labels[:5000] == c
        if mask.sum() == 0:
            continue
        ax.scatter(proj_2d[:5000][mask, 0], proj_2d[:5000][mask, 1],
                   c=CLASS_COLORS[c], s=5, alpha=0.4, zorder=2)
    ax.scatter(anch_2d[:, 0], anch_2d[:, 1], c='#FFD700', s=80, marker='*',
               edgecolors='white', linewidths=0.3, zorder=5)
    if n_anchors <= 128:
        # Label each used anchor with the most frequent class assigned to it.
        for a in range(n_anchors):
            a_mask = nearest == a
            if a_mask.sum() > 0:
                dom_class = labels[a_mask].mode().values.item()
                ax.annotate(str(dom_class), (anch_2d[a, 0], anch_2d[a, 1]),
                            fontsize=4, ha='center', va='center',
                            color='white', fontweight='bold',
                            bbox=dict(boxstyle='round,pad=0.1',
                                      fc=CLASS_COLORS[dom_class],
                                      ec='#FFD700', linewidth=0.5, alpha=0.85))
    # Grid drawn LAST
    draw_sphere_grid_2d(ax, sphere_r_2d)
    ax.set_title(f'Triangulation: Image → Nearest Anchor ({ds_name})', fontsize=11)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.15, color='#555577')
    plt.tight_layout()
    save_fig('02_triangulation_connections.png')
    plt.close()

    # ── [3] 3D sphere ──
    print(" [3/8] 3D sphere projection...")
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')
    n_3d = min(3000, len(embs))
    # At most the first 20 classes are drawn in 3D.
    for c in range(min(N_CLASSES, 20)):
        mask = proj_labels[:n_3d] == c
        if mask.sum() == 0:
            continue
        ax.scatter(proj_3d[:n_3d][mask, 0], proj_3d[:n_3d][mask, 1],
                   proj_3d[:n_3d][mask, 2],
                   c=CLASS_COLORS[c], s=5, alpha=0.4,
                   label=CLASS_NAMES[c] if N_CLASSES <= 20 else None)
    ax.scatter(anch_3d[:, 0], anch_3d[:, 1], anch_3d[:, 2],
               c='#FFD700', s=40, marker='*',
               edgecolors='white', linewidths=0.3, zorder=5)
    # Wireframe drawn after the data points
    draw_sphere_grid_3d(ax, sphere_r_3d)
    if N_CLASSES <= 20:
        ax.legend(fontsize=6, markerscale=2, loc='upper left', ncol=2)
    ax.set_title(f'3D PCA — Constellation on the Sphere\n'
                 f'{n_anchors} anchors, {N_CLASSES} classes', fontsize=11)
    try:
        ax.set_box_aspect([1, 1, 1])
    except AttributeError:
        pass  # older matplotlib
    ax.xaxis.pane.fill = False
    ax.yaxis.pane.fill = False
    ax.zaxis.pane.fill = False
    plt.tight_layout()
    save_fig('03_3d_sphere.png')
    plt.close()

    # ── [4] Anchor-Class heatmap ──
    print(" [4/8] Anchor-class assignment matrix...")
    # assign_mat[c, a] = number of class-c samples whose nearest anchor is a.
    assign_mat = _pair_histogram(labels, nearest, N_CLASSES, n_anchors).float()
    assign_norm = assign_mat / (assign_mat.sum(dim=1, keepdim=True) + 1e-8)
    # Order anchors by the class that uses them most.
    peak_class = assign_norm.argmax(dim=0)
    sort_order = peak_class.argsort()
    assign_sorted = assign_norm[:, sort_order]
    h = max(6, N_CLASSES * 0.12)
    fig, ax = plt.subplots(1, 1, figsize=(16, h))
    im = ax.imshow(assign_sorted.numpy(), aspect='auto', cmap='inferno')
    if N_CLASSES <= 30:
        ax.set_yticks(range(N_CLASSES))
        ax.set_yticklabels(CLASS_NAMES, fontsize=max(4, 9 - N_CLASSES // 15))
    ax.set_xlabel('Anchor index (sorted by peak class)')
    ax.set_title(f'Class → Anchor Assignment ({ds_name})', fontsize=11)
    plt.colorbar(im, ax=ax, shrink=0.8)
    plt.tight_layout()
    save_fig('04_anchor_class_heatmap.png')
    plt.close()

    # ── [5] Triangulation profiles ──
    print(" [5/8] Class triangulation profiles...")
    if N_CLASSES <= 10:
        show_classes = list(range(N_CLASSES))
    else:
        sorted_by_acc = sorted(range(N_CLASSES), key=lambda c: class_accs[c])
        show_classes = sorted_by_acc[:5] + sorted_by_acc[-5:]
    nrows, ncols = 2, 5
    fig, axes = plt.subplots(nrows, ncols, figsize=(20, 8))
    for idx, c in enumerate(show_classes):
        ax = axes[idx // ncols][idx % ncols]
        c_tris = tris[labels == c]
        if len(c_tris) == 0:
            continue
        mean_tri = c_tris.mean(0).numpy()
        std_tri = c_tris.std(0).numpy()
        x = np.arange(n_anchors)
        color = CLASS_COLORS[c]
        ax.fill_between(x, mean_tri - std_tri, mean_tri + std_tri,
                        alpha=0.3, color=color)
        ax.plot(x, mean_tri, color=color, linewidth=1.5)
        ax.set_title(f'{CLASS_NAMES[c]} ({class_accs[c]:.0f}%)',
                     fontsize=9, fontweight='bold', color=color)
        ax.set_ylim(0, max(1.6, mean_tri.max() * 1.2))
        ax.tick_params(labelsize=5)
    # Hide panels that have no class to show (fewer than 10 classes).
    for idx in range(len(show_classes), nrows * ncols):
        axes[idx // ncols][idx % ncols].set_visible(False)
    tag = "all classes" if N_CLASSES <= 10 else "5 worst + 5 best"
    plt.suptitle(f'Triangulation Fingerprints ({tag})', fontsize=12)
    plt.tight_layout()
    save_fig('05_triangulation_profiles.png')
    plt.close()

    # ── [6] Anchor utilization ──
    print(" [6/8] Anchor utilization...")
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    sorted_counts, _ = counts.sort(descending=True)
    # Unused anchors (count 0) are drawn in red.
    ax1.bar(range(n_anchors), sorted_counts.numpy(),
            color=['#00BCD4' if n > 0 else '#FF5252'
                   for n in sorted_counts.tolist()],
            width=1.0)
    ax1.set_xlabel('Anchor (sorted)')
    ax1.set_ylabel('Assigned samples')
    ax1.set_title(f'Anchor Utilization ({n_active}/{n_anchors} active)')
    # Reference line: count per anchor if samples were spread evenly.
    ax1.axhline(y=len(labels) / n_anchors, color='#888899',
                linestyle='--', alpha=0.5)

    # Per-class anchor entropy: entropy of each class's distribution over
    # nearest anchors (rows of the normalised class × anchor matrix).
    entropies = (-(assign_norm * (assign_norm + 1e-10).log()).sum(dim=1)).tolist()
    if N_CLASSES <= 20:
        ax2.barh(range(N_CLASSES), entropies,
                 color=[CLASS_COLORS[c] for c in range(N_CLASSES)])
        ax2.set_yticks(range(N_CLASSES))
        ax2.set_yticklabels(CLASS_NAMES, fontsize=8)
        ax2.set_xlabel('Anchor assignment entropy')
    else:
        ax2.hist(entropies, bins=30, color='#00BCD4', edgecolor='#333355')
        ax2.set_xlabel('Anchor assignment entropy')
        ax2.set_ylabel('Number of classes')

    # Gini coefficient of anchor utilization from the ascending cumulative
    # sum: G = (n + 1 - 2 * sum(cum) / total) / n. This is 0 for perfectly
    # even usage and (n - 1) / n when one anchor takes every sample.
    c_sorted = counts.float().sort().values
    cum = c_sorted.cumsum(0)
    n_sorted = len(c_sorted)
    gini = ((n_sorted + 1 - 2 * cum.sum() / (c_sorted.sum() + 1e-8))
            / n_sorted).item()
    ax2.set_title(f'Anchor Spread (Gini={gini:.3f})')
    plt.tight_layout()
    save_fig('06_anchor_utilization.png')
    plt.close()

    # ── [7] Patchwork compartment responses ──
    print(" [7/8] Patchwork compartment responses...")
    n_comp = cfg.get('n_comp', 8)
    # NOTE(review): presumably `asgn` holds one compartment id per anchor; it
    # is used below as a column mask over `tris` — confirm against the model.
    asgn = model.patchwork.asgn.cpu()
    show_c = show_classes
    ncols_pw = min(4, n_comp)
    nrows_pw = math.ceil(n_comp / ncols_pw)
    # squeeze=False always yields a 2-D array of axes, whatever the grid size.
    fig, axes = plt.subplots(nrows_pw, ncols_pw,
                             figsize=(4 * ncols_pw, 3 * nrows_pw),
                             squeeze=False)
    axes_flat = list(axes.flat)
    for k in range(min(n_comp, len(axes_flat))):
        ax = axes_flat[k]
        comp_tris = tris[:, asgn == k]
        class_means = []
        class_labels_show = []
        for c in show_c:
            cls_tris = comp_tris[labels == c]
            if len(cls_tris) > 0:
                class_means.append(cls_tris.mean(0).numpy())
                class_labels_show.append(CLASS_NAMES[c])
        if not class_means:
            continue
        class_means = np.stack(class_means)
        ax.imshow(class_means, aspect='auto', cmap='plasma')
        ax.set_yticks(range(len(class_labels_show)))
        ax.set_yticklabels(class_labels_show, fontsize=6)
        ax.set_title(f'Comp {k}', fontsize=9)
    # Hide grid cells beyond the number of compartments.
    for k in range(n_comp, len(axes_flat)):
        axes_flat[k].set_visible(False)
    plt.suptitle('Patchwork Compartment Responses by Class', fontsize=12)
    plt.tight_layout()
    save_fig('07_patchwork_compartments.png')
    plt.close()

    # ── [8] Confusion matrix ──
    print(" [8/8] Confusion matrix...")
    # conf_mat[t, p] = number of samples with true class t predicted as p.
    conf_mat = _pair_histogram(labels, preds, N_CLASSES, N_CLASSES)
    # Row-normalised to percentages of each true class.
    conf_pct = conf_mat.float() / (conf_mat.sum(dim=1, keepdim=True) + 1e-8) * 100
    if N_CLASSES <= 20:
        fig, ax = plt.subplots(1, 1, figsize=(8, 7))
        im = ax.imshow(conf_pct.numpy(), cmap='magma', vmin=0, vmax=100)
        for i in range(N_CLASSES):
            for j in range(N_CLASSES):
                v = conf_pct[i, j].item()
                # Dark text on bright cells, light text on dark cells.
                ax.text(j, i, f'{v:.0f}', ha='center', va='center',
                        fontsize=max(4, 8 - N_CLASSES // 5),
                        color='black' if v > 60 else '#e0e0e0')
        ax.set_xticks(range(N_CLASSES))
        ax.set_yticks(range(N_CLASSES))
        ax.set_xticklabels(CLASS_NAMES, rotation=45, ha='right', fontsize=7)
        ax.set_yticklabels(CLASS_NAMES, fontsize=7)
    else:
        fig, ax = plt.subplots(1, 1, figsize=(14, 12))
        im = ax.imshow(conf_pct.numpy(), cmap='magma', vmin=0, vmax=100)
    ax.set_xlabel('Predicted class')
    ax.set_ylabel('True class')
    ax.set_title(f'Confusion Matrix — {val_acc:.1f}% ({ds_name})', fontsize=11)
    plt.colorbar(im, ax=ax, shrink=0.8)
    plt.tight_layout()
    save_fig('08_confusion_matrix.png')
    plt.close()

    print(f"\n ✓ All 8 visualizations saved to {OUT_DIR}/")

# ══════════════════════════════════════════════════════════════════
# SUMMARY
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*70}")
print("SUMMARY")
print(f"{'='*70}")
print(f" Dataset: {ds_name} ({N_CLASSES} classes)")
print(f" Params: {total_params:,}")
print(f" Val accuracy: {val_acc:.1f}%")
print(f" Eff dim: {eff_dim:.1f}/{embs.shape[1]}")
print(f" Same-class cos: {same_cos:.4f}")
print(f" Diff-class cos: {diff_cos:.4f}")
print(f" Gap: {gap:.4f}")
print(f" CV: {v_cv:.4f}")
print(f" Anchors active: {n_active}/{n_anchors}")

worst_i = min(range(N_CLASSES), key=lambda c: class_accs[c])
best_i = max(range(N_CLASSES), key=lambda c: class_accs[c])
print(f" Worst class: {CLASS_NAMES[worst_i]} ({class_accs[worst_i]:.1f}%)")
print(f" Best class: {CLASS_NAMES[best_i]} ({class_accs[best_i]:.1f}%)")

# Threshold checks on the metrics gathered above.
warnings = []
if n_active < n_anchors * 0.5:
    warnings.append(f"Anchor collapse: {n_active}/{n_anchors}")
if eff_dim < 5:
    warnings.append(f"Embedding collapse: eff_dim={eff_dim:.1f}")
if gap < 0.02:
    warnings.append(f"Low class separation: gap={gap:.4f}")

if warnings:
    print(f"\n ⚠ WARNINGS: {', '.join(warnings)}")
else:
    print("\n ✓ All diagnostics healthy")

print(f"\n{'='*70}")
print("ANALYSIS COMPLETE")
print(f"{'='*70}")

# ══════════════════════════════════════════════════════════════════
# PUSH IMAGES TO HUGGINGFACE
# ══════════════════════════════════════════════════════════════════

if HF_PUSH:
    # Imported here so huggingface_hub is only required when pushing.
    from huggingface_hub import upload_folder
    print(f"\n Uploading {OUT_DIR}/ → {HF_REPO_ID}/analysis/ ...")
    upload_folder(
        repo_id=HF_REPO_ID,
        folder_path=OUT_DIR,
        path_in_repo="analysis",
        commit_message=f"Analysis: val={val_acc:.1f}% CV={v_cv:.4f} {n_active}/{n_anchors} anchors",
    )
    print(f" ✓ Done: https://huggingface.co/{HF_REPO_ID}/tree/main/analysis")