AbstractPhil committed on
Commit
2f0b0b7
·
verified ·
1 Parent(s): 51cbe1d

Create analyze_weights.py

Browse files
Files changed (1) hide show
  1. analyze_weights.py +636 -0
analyze_weights.py ADDED
@@ -0,0 +1,636 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GeoLIP Core — Full Analysis + Sphere Visualizations
4
+ =====================================================
5
+ Auto-detects CIFAR-10 vs CIFAR-100 from checkpoint config.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import numpy as np
12
+ import math
13
+ import os
14
+ from collections import defaultdict
15
+ from torchvision import datasets, transforms
16
+
# Runtime configuration for the analysis run.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # CPU fallback
CKPT = "checkpoints/geolip_core_best.pt"  # checkpoint file to analyze
OUT_DIR = "analysis_out"                  # directory that receives the figures
BATCH = 256                               # DataLoader batch size

# Per-channel mean/std handed to transforms.Normalize.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)

CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

os.makedirs(OUT_DIR, exist_ok=True)

print("=" * 70)
print("GEOLIP CORE — ANALYSIS + SPHERE VISUALIZATIONS")
print(f" Checkpoint: {CKPT}")
print(f" Output: {OUT_DIR}/")
print("=" * 70)
35
+
# ══════════════════════════════════════════════════════════════════
# LOAD — auto-detect dataset from config
# ══════════════════════════════════════════════════════════════════

# SECURITY: weights_only=False unpickles arbitrary Python objects. Only run
# this script on checkpoints that come from a trusted source.
ckpt = torch.load(CKPT, map_location="cpu", weights_only=False)
cfg = ckpt["config"]
N_CLASSES = cfg.get('num_classes', 10)
print(f" Epoch: {ckpt['epoch']} Val acc: {ckpt['val_acc']:.1f}%")
print(f" Config: output_dim={cfg.get('output_dim')}, "
      f"n_anchors={cfg.get('n_anchors')}, "
      f"n_comp={cfg.get('n_comp')}, d_comp={cfg.get('d_comp')}, "
      f"num_classes={N_CLASSES}")

# Up to 10 classes is treated as CIFAR-10, anything larger as CIFAR-100.
if N_CLASSES <= 10:
    ds_cls, ds_name = datasets.CIFAR10, "CIFAR-10"
else:
    ds_cls, ds_name = datasets.CIFAR100, "CIFAR-100"

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
# Build the validation set once and reuse it for the class names, instead of
# instantiating a throw-away second copy of the dataset just to read .classes.
val_ds = ds_cls(root='./data', train=False, download=True, transform=val_transform)
CLASS_NAMES = CIFAR10_CLASSES[:N_CLASSES] if N_CLASSES <= 10 else val_ds.classes

print(f" Dataset: {ds_name} ({N_CLASSES} classes)")

# NOTE(review): GeoLIPCore is neither imported nor defined in the visible part
# of this file; it must be brought into scope before this line can run.
model = GeoLIPCore(**cfg).to(DEVICE)
model.load_state_dict(ckpt["state_dict"])
model.eval()

# Pinned host memory only helps host-to-GPU copies, so skip it on CPU runs.
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH, shuffle=False, num_workers=2,
    pin_memory=(DEVICE == "cuda"))

total_params = sum(p.numel() for p in model.parameters())
75
+
# ══════════════════════════════════════════════════════════════════
# COLLECT ALL EMBEDDINGS + PREDICTIONS
# ══════════════════════════════════════════════════════════════════

print("\n Collecting embeddings...")

# One list of per-batch CPU tensors per quantity; concatenated after the loop.
batches = defaultdict(list)

with torch.no_grad():
    for imgs, lbls in val_loader:
        out = model(imgs.to(DEVICE))
        batches['embs'].append(out['embedding'].float().cpu())
        batches['tris'].append(out['triangulation'].float().cpu())
        batches['nearest'].append(out['nearest'].cpu())
        batches['labels'].append(lbls)
        batch_logits = out['logits']
        batches['preds'].append(batch_logits.argmax(-1).cpu())
        batches['logits'].append(batch_logits.float().cpu())

embs, tris, nearest, labels, preds, logits = (
    torch.cat(batches[key])
    for key in ('embs', 'tris', 'nearest', 'labels', 'preds', 'logits')
)

# Anchors are read from the model's constellation; unit-normalised copies of
# anchors and embeddings are kept for the cosine-based audits.
anchors = model.constellation.anchors.detach().float().cpu()
n_anchors = anchors.shape[0]
anchors_n = F.normalize(anchors, dim=-1)
embs_n = F.normalize(embs, dim=-1)

val_acc = preds.eq(labels).float().mean().item() * 100
print(f" Val accuracy: {val_acc:.1f}%")
print(f" Embeddings: {embs.shape}")
print(f" Anchors: {anchors.shape}")
110
+
# ══════════════════════════════════════════════════════════════════
# AUDIT 1: NUMERIC HEALTH
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*70}")
print("AUDIT 1: NUMERIC HEALTH")
print(f"{'='*70}")

# Flag every parameter tensor that contains NaN/inf values or whose values
# have (near-)zero standard deviation. Names of flagged tensors go to `issues`.
issues = []
for name, param in model.named_parameters():
    values = param.detach().float()
    n_nan = torch.isnan(values).sum().item()
    n_inf = torch.isinf(values).sum().item()
    multi = values.numel() > 1          # std is undefined for a single element
    p_std = values.std().item() if multi else 0
    flags = []
    if n_nan > 0:
        flags.append(f"NaN={n_nan}")
    if n_inf > 0:
        flags.append(f"inf={n_inf}")
    if multi and p_std < 1e-8:
        flags.append(f"COLLAPSED(std={p_std:.2e})")
    if flags:
        print(f" ⚠ {name:<50} {' '.join(flags)}")
        issues.append(name)

if not issues:
    print(f" ✓ All {total_params:,} parameters clean")
135
+
# ══════════════════════════════════════════════════════════════════
# AUDIT 2: PER-CLASS ACCURACY
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*70}")
print("AUDIT 2: PER-CLASS ACCURACY")
print(f"{'='*70}")

# Accuracy in percent for each class; a class without samples is recorded as 0.
class_accs = []
for c in range(N_CLASSES):
    mask = labels == c
    if mask.sum() > 0:
        acc = (preds[mask] == c).float().mean().item() * 100
    else:
        acc = 0
    class_accs.append(acc)

if N_CLASSES > 10:
    # Too many classes to list: show the ten weakest and the ten strongest.
    sorted_idx = sorted(range(N_CLASSES), key=lambda c: class_accs[c])
    for heading, picks in ((" Bottom 10:", sorted_idx[:10]),
                           (" Top 10:", sorted_idx[-10:])):
        print(heading)
        for c in picks:
            print(f" {CLASS_NAMES[c]:<20}: {class_accs[c]:5.1f}%")
else:
    for c in range(N_CLASSES):
        print(f" {CLASS_NAMES[c]:<12}: {class_accs[c]:5.1f}%")

# NOTE(review): indentation was lost in the source; this summary line is
# assumed to run for both branches — confirm against the original file.
print(f" Mean: {np.mean(class_accs):.1f}% "
      f"Median: {np.median(class_accs):.1f}% "
      f"Std: {np.std(class_accs):.1f}%")
164
+
# ══════════════════════════════════════════════════════════════════
# AUDIT 3: EMBEDDING SPACE
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*70}")
print("AUDIT 3: EMBEDDING SPACE")
print(f"{'='*70}")

# Pairwise cosine similarities over the first (at most) 2000 embeddings.
n_sample = min(2000, len(embs))
head = embs_n[:n_sample]
sim = head @ head.T
sim_mask = ~torch.eye(n_sample, dtype=torch.bool)      # True off the diagonal
labels_s = labels[:n_sample]
same_class = labels_s.unsqueeze(0) == labels_s.unsqueeze(1)
same_not_self = same_class & sim_mask
diff_class = ~same_class & sim_mask

self_sim = sim[sim_mask].mean().item()
if same_not_self.any():
    same_cos = sim[same_not_self].mean().item()
else:
    same_cos = 0
if diff_class.any():
    diff_cos = sim[diff_class].mean().item()
else:
    diff_cos = 0
gap = same_cos - diff_cos

# Effective dimension: inverse sum of squared, sum-normalised singular values
# of the first 512 normalised embeddings.
S = torch.linalg.svd(embs_n[:512].float(), full_matrices=False)[1]
p = S / S.sum()
eff_dim = p.pow(2).sum().reciprocal().item()

for line in (
    f" Self-similarity: {self_sim:.4f}",
    f" Same-class cos: {same_cos:.4f}",
    f" Diff-class cos: {diff_cos:.4f}",
    f" Gap: {gap:.4f}",
    f" Effective dim: {eff_dim:.1f}/{embs.shape[1]}",
):
    print(line)
195
+
# ══════════════════════════════════════════════════════════════════
# AUDIT 4: CONSTELLATION HEALTH
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*70}")
print("AUDIT 4: CONSTELLATION HEALTH")
print(f"{'='*70}")

# Off-diagonal cosine similarities between the normalised anchors.
anch_sim = anchors_n @ anchors_n.T
anch_mask = ~torch.eye(n_anchors, dtype=torch.bool)
anch_off = anch_sim[anch_mask]
n_active = nearest.unique().numel()      # anchors that are nearest to >= 1 sample

# Samples assigned to each anchor. A single bincount replaces the original
# per-anchor Python loop of (nearest == a).sum() calls.
counts = torch.bincount(nearest.flatten(), minlength=n_anchors)

print(f" Anchors: {n_anchors} × {anchors.shape[1]}")
print(f" Pairwise cos: mean={anch_off.mean():.4f} max={anch_off.max():.4f}")
print(f" Active: {n_active}/{n_anchors}")
print(f" Utilization: min={counts.min().item()} max={counts.max().item()} "
      f"mean={counts.float().mean():.1f} std={counts.float().std():.1f}")
218
+
# ══════════════════════════════════════════════════════════════════
# AUDIT 5: PENTACHORON CV
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*70}")
print("AUDIT 5: PENTACHORON CV")
print(f"{'='*70}")

# Draw 500 random 5-point subsets from the first (at most) 2000 normalised
# embeddings and compute each subset's volume from a bordered 6x6 matrix of
# squared pairwise distances: v2 = -det(cm) / 9216.
sample = embs_n[:2000].to(DEVICE)
n_pts = len(sample)
vols = []
if n_pts >= 5:      # a 5-point draw is impossible with fewer samples
    # The border of ones never changes, so allocate the matrix once.
    cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=torch.float32)
    cm[:, 0, 1:] = 1
    cm[:, 1:, 0] = 1
    with torch.no_grad():
        for _ in range(500):
            idx = torch.randperm(n_pts, device=DEVICE)[:5]
            pts = sample[idx].unsqueeze(0).float()
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            # Squared distances from the Gram matrix; relu clamps the small
            # negative values that rounding can produce.
            d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
            cm[:, 1:, 1:] = d2
            v2 = -torch.linalg.det(cm) / 9216
            if v2[0].item() > 1e-20:    # skip degenerate / non-positive volumes
                vols.append(v2[0].sqrt().cpu())

if len(vols) > 10:
    vt = torch.stack(vols)
    v_cv = (vt.std() / (vt.mean() + 1e-8)).item()   # coefficient of variation
    band = "✓ IN BAND" if 0.18 <= v_cv <= 0.25 else "✗ outside"
    print(f" CV: {v_cv:.4f} ({band})")
    print(f" Vol mean: {vt.mean():.6f} std: {vt.std():.6f}")
else:
    v_cv = 0
    print(f" ⚠ Not enough valid pentachora ({len(vols)})")
252
+
# ══════════════════════════════════════════════════════════════════
# AUDIT 6: CONFIDENCE CALIBRATION
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*70}")
print("AUDIT 6: CONFIDENCE CALIBRATION")
print(f"{'='*70}")

# Confidence = largest softmax probability of each prediction.
probs = logits.softmax(-1)
conf = probs.max(dim=1).values
correct_mask = preds == labels
wrong_mask = ~correct_mask

right_conf = conf[correct_mask]
print(f" Correct: mean_conf={right_conf.mean():.4f} "
      f"std={right_conf.std():.4f}")

if wrong_mask.any():
    wrong_conf = conf[wrong_mask]
    # Misclassified samples predicted with more than 0.9 confidence.
    overconf = (wrong_conf > 0.9).sum().item()
    print(f" Wrong: mean_conf={wrong_conf.mean():.4f} "
          f"std={wrong_conf.std():.4f}")
    print(f" Overconfident wrong (>0.9): {overconf}/{wrong_conf.numel()} "
          f"({100*overconf/max(wrong_conf.numel(),1):.1f}%)")
274
+
# ══════════════════════════════════════════════════════════════════
# AUDIT 7: GRADIENT FLOW
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*70}")
print("AUDIT 7: GRADIENT FLOW")
print(f"{'='*70}")

# One forward/backward pass on 16 validation images in train mode, then
# per-module statistics of the parameter gradient norms.
model.train()
model.zero_grad()
imgs_g, lbls_g = next(iter(val_loader))
imgs_g = imgs_g[:16].to(DEVICE)
lbls_g = lbls_g[:16].to(DEVICE)

# Autocast on the device actually in use; the original hard-coded "cuda",
# which does not match CPU-only runs.
with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
    out = model(imgs_g)
    loss = F.cross_entropy(out['logits'], lbls_g) + 0.1 * out['embedding'].mean()
loss.backward()

# First matching substring of the parameter name decides the bucket.
MODULE_TAGS = ("encoder", "constellation", "patchwork", "classifier")
grad_by_mod = defaultdict(list)
missing_grad = []       # trainable parameters that received no gradient
for name, param in model.named_parameters():
    if param.grad is None:
        if param.requires_grad:
            missing_grad.append(name)
        continue
    gn = param.grad.detach().float().norm().item()
    mod = next((tag for tag in MODULE_TAGS if tag in name), "other")
    grad_by_mod[mod].append(gn)

for mod in sorted(grad_by_mod):
    grad_norms = grad_by_mod[mod]
    print(f" {mod:<15}: mean={np.mean(grad_norms):.6f} max={np.max(grad_norms):.6f} "
          f"({len(grad_norms)} params)")

# Report success only when it is true; the original printed it unconditionally.
if missing_grad:
    print(f" ⚠ {len(missing_grad)} parameters received no gradient: "
          f"{', '.join(missing_grad)}")
else:
    print(f" ✓ All parameters receive gradient")

model.zero_grad()       # drop the probe gradients
model.eval()
311
+
312
+
# ══════════════════════════════════════════════════════════════════
# VISUALIZATIONS
# ══════════════════════════════════════════════════════════════════

# matplotlib is optional: without it the eight figures are skipped.
try:
    import matplotlib
    matplotlib.use('Agg')       # non-interactive backend; figures go to files
    import matplotlib.pyplot as plt
    HAS_PLT = True
except ImportError:
    HAS_PLT = False
    print("\n ⚠ matplotlib not available, skipping visualizations")

if HAS_PLT:
    if N_CLASSES <= 10:
        CLASS_COLORS = [
            '#e6194b', '#3cb44b', '#4363d8', '#f58231', '#911eb4',
            '#42d4f4', '#f032e6', '#bfef45', '#469990', '#dcbeff']
    else:
        # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is the
        # supported accessor. Colours repeat every 20 classes.
        cmap = plt.get_cmap('tab20')
        CLASS_COLORS = [matplotlib.colors.rgb2hex(cmap(i % 20)) for i in range(N_CLASSES)]

    print(f"\n{'='*70}")
    print("VISUALIZATIONS")
    print(f"{'='*70}")

    # PCA basis: principal directions of the first 5000 centred embeddings.
    # All embeddings and the anchors are then projected onto that basis.
    embs_c = embs_n[:5000] - embs_n[:5000].mean(0, keepdim=True)
    _, _, Vt = torch.linalg.svd(embs_c, full_matrices=False)
    proj_2d = (embs_n @ Vt[:2].T).numpy()
    proj_3d = (embs_n @ Vt[:3].T).numpy()
    anch_2d = (anchors_n @ Vt[:2].T).numpy()
    anch_3d = (anchors_n @ Vt[:3].T).numpy()
    proj_labels = labels.numpy()

    # ── [1] PCA embedding space ──
    print(" [1/8] PCA projection...")
    fig, ax = plt.subplots(1, 1, figsize=(12, 10))
    for c in range(N_CLASSES):
        mask = proj_labels[:5000] == c
        if mask.sum() == 0:
            continue
        lbl = CLASS_NAMES[c] if N_CLASSES <= 20 else None
        ax.scatter(proj_2d[:5000][mask, 0], proj_2d[:5000][mask, 1],
                   c=CLASS_COLORS[c], s=4, alpha=0.3, label=lbl)
    ax.scatter(anch_2d[:, 0], anch_2d[:, 1],
               c='black', s=60, marker='*', zorder=5, label='anchors')
    if N_CLASSES <= 20:
        ax.legend(fontsize=7, markerscale=2, loc='upper right', ncol=2)
    ax.set_title(f'GeoLIP Core — PCA Embedding Space ({ds_name})\n'
                 f'val={val_acc:.1f}% | {total_params:,} params | '
                 f'CV={v_cv:.4f} | {n_active}/{n_anchors} anchors', fontsize=11)
    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.grid(True, alpha=0.2)
    plt.tight_layout()
    plt.savefig(f'{OUT_DIR}/01_pca_embedding_space.png', dpi=200)
    plt.close()

    # ── [2] Triangulation connections ──
    print(" [2/8] Triangulation connections...")
    fig, ax = plt.subplots(1, 1, figsize=(12, 10))
    subset = min(500, len(embs))
    # Faint line from each of the first `subset` samples to its nearest anchor.
    for i in range(subset):
        a_idx = nearest[i].item()
        ax.plot([proj_2d[i, 0], anch_2d[a_idx, 0]],
                [proj_2d[i, 1], anch_2d[a_idx, 1]],
                c=CLASS_COLORS[labels[i].item()], alpha=0.06, linewidth=0.4)
    for c in range(N_CLASSES):
        mask = proj_labels[:5000] == c
        if mask.sum() == 0:
            continue
        ax.scatter(proj_2d[:5000][mask, 0], proj_2d[:5000][mask, 1],
                   c=CLASS_COLORS[c], s=3, alpha=0.25)
    ax.scatter(anch_2d[:, 0], anch_2d[:, 1],
               c='black', s=80, marker='*', zorder=5)
    if n_anchors <= 128:
        # Label every used anchor with the most frequent class among its samples.
        for a in range(n_anchors):
            a_mask = nearest == a
            if a_mask.sum() > 0:
                dom_class = labels[a_mask].mode().values.item()
                ax.annotate(str(dom_class), (anch_2d[a, 0], anch_2d[a, 1]),
                            fontsize=4, ha='center', va='center',
                            color='white', fontweight='bold',
                            bbox=dict(boxstyle='round,pad=0.1',
                                      fc=CLASS_COLORS[dom_class], alpha=0.7))
    ax.set_title(f'Triangulation: Image → Nearest Anchor ({ds_name})', fontsize=11)
    ax.grid(True, alpha=0.2)
    plt.tight_layout()
    plt.savefig(f'{OUT_DIR}/02_triangulation_connections.png', dpi=200)
    plt.close()

    # ── [3] 3D sphere ──
    print(" [3/8] 3D sphere projection...")
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')
    n_3d = min(3000, len(embs))
    # Only the first 20 classes are drawn in 3D.
    for c in range(min(N_CLASSES, 20)):
        mask = proj_labels[:n_3d] == c
        if mask.sum() == 0:
            continue
        ax.scatter(proj_3d[:n_3d][mask, 0], proj_3d[:n_3d][mask, 1],
                   proj_3d[:n_3d][mask, 2],
                   c=CLASS_COLORS[c], s=3, alpha=0.25,
                   label=CLASS_NAMES[c] if N_CLASSES <= 20 else None)
    ax.scatter(anch_3d[:, 0], anch_3d[:, 1], anch_3d[:, 2],
               c='black', s=40, marker='*', zorder=5)
    if N_CLASSES <= 20:
        ax.legend(fontsize=6, markerscale=2, loc='upper left', ncol=2)
    ax.set_title(f'3D PCA — Constellation on the Sphere\n'
                 f'{n_anchors} anchors, {N_CLASSES} classes', fontsize=11)
    plt.tight_layout()
    plt.savefig(f'{OUT_DIR}/03_3d_sphere.png', dpi=200)
    plt.close()

    # ── [4] Anchor-Class heatmap ──
    print(" [4/8] Anchor-class assignment matrix...")
    # assign_mat[c, a] = number of class-c samples whose nearest anchor is a.
    # One bincount per class replaces the nested class x anchor Python loops.
    assign_mat = torch.stack([
        torch.bincount(nearest[labels == c], minlength=n_anchors)
        for c in range(N_CLASSES)
    ]).float()
    # Row-normalise so each class row is a distribution over anchors.
    assign_norm = assign_mat / (assign_mat.sum(dim=1, keepdim=True) + 1e-8)

    # Order anchor columns by the class in which each anchor peaks.
    peak_class = assign_norm.argmax(dim=0)
    sort_order = peak_class.argsort()
    assign_sorted = assign_norm[:, sort_order]

    h = max(6, N_CLASSES * 0.12)
    fig, ax = plt.subplots(1, 1, figsize=(16, h))
    im = ax.imshow(assign_sorted.numpy(), aspect='auto', cmap='YlOrRd')
    if N_CLASSES <= 30:
        ax.set_yticks(range(N_CLASSES))
        ax.set_yticklabels(CLASS_NAMES, fontsize=max(4, 9 - N_CLASSES // 15))
    ax.set_xlabel('Anchor index (sorted by peak class)')
    ax.set_title(f'Class → Anchor Assignment ({ds_name})', fontsize=11)
    plt.colorbar(im, ax=ax, shrink=0.8)
    plt.tight_layout()
    plt.savefig(f'{OUT_DIR}/04_anchor_class_heatmap.png', dpi=200)
    plt.close()

    # ── [5] Triangulation profiles ──
    print(" [5/8] Class triangulation profiles...")
    if N_CLASSES <= 10:
        show_classes = list(range(N_CLASSES))
    else:
        sorted_by_acc = sorted(range(N_CLASSES), key=lambda c: class_accs[c])
        show_classes = sorted_by_acc[:5] + sorted_by_acc[-5:]

    nrows, ncols = 2, 5
    fig, axes = plt.subplots(nrows, ncols, figsize=(20, 8))
    for idx, c in enumerate(show_classes):
        ax = axes[idx // ncols][idx % ncols]
        c_tris = tris[labels == c]
        if len(c_tris) == 0:
            continue
        mean_tri = c_tris.mean(0).numpy()
        std_tri = c_tris.std(0).numpy()
        x = np.arange(n_anchors)
        color = CLASS_COLORS[c]
        ax.fill_between(x, mean_tri - std_tri, mean_tri + std_tri,
                        alpha=0.3, color=color)
        ax.plot(x, mean_tri, color=color, linewidth=1.5)
        ax.set_title(f'{CLASS_NAMES[c]} ({class_accs[c]:.0f}%)',
                     fontsize=9, fontweight='bold', color=color)
        ax.set_ylim(0, max(1.6, mean_tri.max() * 1.2))
        ax.tick_params(labelsize=5)
    # Hide grid cells with no class assigned (fewer than 10 classes shown).
    for ax in list(axes.flat)[len(show_classes):]:
        ax.set_visible(False)
    tag = "all classes" if N_CLASSES <= 10 else "5 worst + 5 best"
    plt.suptitle(f'Triangulation Fingerprints ({tag})', fontsize=12)
    plt.tight_layout()
    plt.savefig(f'{OUT_DIR}/05_triangulation_profiles.png', dpi=200)
    plt.close()

    # ── [6] Anchor utilization ──
    print(" [6/8] Anchor utilization...")
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    sorted_counts, _ = counts.sort(descending=True)
    # Blue bars for used anchors, red for anchors with no assigned sample.
    ax1.bar(range(n_anchors), sorted_counts.numpy(),
            color=['#2196F3' if n > 0 else '#F44336' for n in sorted_counts.tolist()],
            width=1.0)
    ax1.set_xlabel('Anchor (sorted)')
    ax1.set_ylabel('Assigned samples')
    ax1.set_title(f'Anchor Utilization ({n_active}/{n_anchors} active)')
    # Reference line: count per anchor if samples were spread evenly.
    ax1.axhline(y=len(labels) / n_anchors, color='gray', linestyle='--', alpha=0.5)

    # Per-class anchor entropy, taken from the class -> anchor distributions
    # already computed for the heatmap.
    entropies = (-(assign_norm * (assign_norm + 1e-10).log()).sum(dim=1)).tolist()

    if N_CLASSES <= 20:
        ax2.barh(range(N_CLASSES), entropies,
                 color=[CLASS_COLORS[c] for c in range(N_CLASSES)])
        ax2.set_yticks(range(N_CLASSES))
        ax2.set_yticklabels(CLASS_NAMES, fontsize=8)
        ax2.set_xlabel('Anchor assignment entropy')
    else:
        ax2.hist(entropies, bins=30, color='steelblue', edgecolor='white')
        ax2.set_xlabel('Anchor assignment entropy')
        ax2.set_ylabel('Number of classes')

    # Gini coefficient of the anchor counts (sorted ascending):
    #   G = (n + 1 - 2 * sum(cumsum) / total) / n
    # The original omitted the (n + 1) / n term and returned -1/n for a
    # perfectly even distribution instead of 0.
    c_sorted = counts.float().sort().values
    n_c = len(c_sorted)
    total = c_sorted.sum().item()
    if total > 0:
        gini = (n_c + 1 - 2 * c_sorted.cumsum(0).sum().item() / total) / n_c
    else:
        gini = 0.0
    ax2.set_title(f'Anchor Spread (Gini={gini:.3f})')
    plt.tight_layout()
    plt.savefig(f'{OUT_DIR}/06_anchor_utilization.png', dpi=200)
    plt.close()

    # ── [7] Patchwork compartment responses ──
    print(" [7/8] Patchwork compartment responses...")
    n_comp = cfg.get('n_comp', 8)
    # NOTE(review): asgn is used as a per-anchor compartment index to select
    # columns of `tris` — confirm against the patchwork module.
    asgn = model.patchwork.asgn.cpu()

    if N_CLASSES <= 10:
        show_c = list(range(N_CLASSES))
    else:
        show_c = show_classes

    ncols_pw = min(4, n_comp)
    nrows_pw = math.ceil(n_comp / ncols_pw)
    # squeeze=False always returns a 2-D axes array, so no shape special-casing.
    fig, axes = plt.subplots(nrows_pw, ncols_pw,
                             figsize=(4 * ncols_pw, 3 * nrows_pw), squeeze=False)
    axes_flat = list(axes.flat)

    for k in range(min(n_comp, len(axes_flat))):
        ax = axes_flat[k]
        comp_tris = tris[:, asgn == k]
        class_means = []
        class_labels_show = []
        for c in show_c:
            rows = comp_tris[labels == c]
            if len(rows) > 0:
                class_means.append(rows.mean(0).numpy())
                class_labels_show.append(CLASS_NAMES[c])
        if not class_means:
            continue
        class_means = np.stack(class_means)
        ax.imshow(class_means, aspect='auto', cmap='viridis')
        ax.set_yticks(range(len(class_labels_show)))
        ax.set_yticklabels(class_labels_show, fontsize=6)
        ax.set_title(f'Comp {k}', fontsize=9)
    for k in range(n_comp, len(axes_flat)):
        axes_flat[k].set_visible(False)
    plt.suptitle('Patchwork Compartment Responses by Class', fontsize=12)
    plt.tight_layout()
    plt.savefig(f'{OUT_DIR}/07_patchwork_compartments.png', dpi=200)
    plt.close()

    # ── [8] Confusion matrix ──
    print(" [8/8] Confusion matrix...")
    # conf_mat[t, p] = samples with true class t predicted as p. Encoding each
    # (t, p) pair as t * N_CLASSES + p lets one bincount replace the
    # per-sample Python loop.
    conf_mat = torch.bincount(
        labels * N_CLASSES + preds, minlength=N_CLASSES * N_CLASSES
    ).view(N_CLASSES, N_CLASSES)
    conf_pct = conf_mat.float() / (conf_mat.sum(dim=1, keepdim=True) + 1e-8) * 100

    if N_CLASSES <= 20:
        fig, ax = plt.subplots(1, 1, figsize=(8, 7))
        im = ax.imshow(conf_pct.numpy(), cmap='Blues', vmin=0, vmax=100)
        for i in range(N_CLASSES):
            for j in range(N_CLASSES):
                v = conf_pct[i, j].item()
                ax.text(j, i, f'{v:.0f}', ha='center', va='center',
                        fontsize=max(4, 8 - N_CLASSES // 5),
                        color='white' if v > 50 else 'black')
        ax.set_xticks(range(N_CLASSES))
        ax.set_yticks(range(N_CLASSES))
        ax.set_xticklabels(CLASS_NAMES, rotation=45, ha='right', fontsize=7)
        ax.set_yticklabels(CLASS_NAMES, fontsize=7)
    else:
        fig, ax = plt.subplots(1, 1, figsize=(14, 12))
        im = ax.imshow(conf_pct.numpy(), cmap='Blues', vmin=0, vmax=100)
        ax.set_xlabel('Predicted class')
        ax.set_ylabel('True class')
    ax.set_title(f'Confusion Matrix — {val_acc:.1f}% ({ds_name})', fontsize=11)
    plt.colorbar(im, ax=ax, shrink=0.8)
    plt.tight_layout()
    plt.savefig(f'{OUT_DIR}/08_confusion_matrix.png', dpi=200)
    plt.close()

    print(f"\n ✓ All 8 visualizations saved to {OUT_DIR}/")
597
+
598
+
# ══════════════════════════════════════════════════════════════════
# SUMMARY
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*70}")
print("SUMMARY")
print(f"{'='*70}")
print(f" Dataset: {ds_name} ({N_CLASSES} classes)")
print(f" Params: {total_params:,}")
print(f" Val accuracy: {val_acc:.1f}%")
print(f" Eff dim: {eff_dim:.1f}/{embs.shape[1]}")
print(f" Same-class cos: {same_cos:.4f}")
print(f" Diff-class cos: {diff_cos:.4f}")
print(f" Gap: {gap:.4f}")
print(f" CV: {v_cv:.4f}")
print(f" Anchors active: {n_active}/{n_anchors}")

worst_i = min(range(N_CLASSES), key=lambda c: class_accs[c])
best_i = max(range(N_CLASSES), key=lambda c: class_accs[c])
print(f" Worst class: {CLASS_NAMES[worst_i]} ({class_accs[worst_i]:.1f}%)")
print(f" Best class: {CLASS_NAMES[best_i]} ({class_accs[best_i]:.1f}%)")

# Collect diagnostic warnings. The list is deliberately not called `warnings`,
# which would shadow the standard-library module of that name.
warning_msgs = []
if n_active < n_anchors * 0.5:
    warning_msgs.append(f"Anchor collapse: {n_active}/{n_anchors}")
if eff_dim < 5:
    warning_msgs.append(f"Embedding collapse: eff_dim={eff_dim:.1f}")
if gap < 0.02:
    warning_msgs.append(f"Low class separation: gap={gap:.4f}")

if warning_msgs:
    print(f"\n ⚠ WARNINGS: {', '.join(warning_msgs)}")
else:
    print(f"\n ✓ All diagnostics healthy")

print(f"\n{'='*70}")
print("ANALYSIS COMPLETE")
print(f"{'='*70}")