Create hypersphere_convergence_analysis.py

#1
Files changed (1) hide show
  1. hypersphere_convergence_analysis.py +595 -0
hypersphere_convergence_analysis.py ADDED
@@ -0,0 +1,595 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GEOLIP HYPERSPHERE MANIFOLD VISUALIZATION
4
+ ==========================================
5
+ 6-panel manifold view + 3-panel expert perspective divergence.
6
+ S^255 projected to S^2 via PCA.
7
+ """
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import numpy as np
12
+ import matplotlib
13
+ matplotlib.use('Agg')
14
+ import matplotlib.pyplot as plt
15
+ from mpl_toolkits.mplot3d import Axes3D
16
+ import math
17
+
18
DEVICE = "cpu"

# ══════════════════════════════════════════════════════════════════
# LOAD + EMBED
# ══════════════════════════════════════════════════════════════════

print("Loading soup...")
# SECURITY NOTE(review): weights_only=False makes torch.load run the full
# pickle machinery, which can execute arbitrary code embedded in the file.
# Only point this script at checkpoints from a trusted source. It is kept
# as-is because the checkpoint also stores plain-Python entries ("config",
# "mAP") next to the state dict.
ckpt = torch.load("checkpoints/dual_stream_best.pt", map_location=DEVICE, weights_only=False)

# Flat name -> tensor mapping; every later section reads weights from here.
sd = ckpt["state_dict"]
D_ANCHOR = ckpt["config"]["d_anchor"]
N_ANCHORS = ckpt["config"]["n_anchors"]

# Anchors are L2-normalised once so that dot products against normalised
# embeddings are cosine similarities.
anchors = F.normalize(sd["constellation.anchors"], dim=-1)
30
+
31
# Dataset config names of the feature extractors; the list order is the
# expert index used for the "projectors.{i}" checkpoint keys.
EXPERTS = [
    "clip_l14_openai",
    "dinov2_b14",
    "siglip_b16_384",
]

# Class names indexed by label id. The trailing comments give the id range
# of each row, since later code refers to classes by numeric index.
COCO_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane",             # 0-4
    "bus", "train", "truck", "boat", "traffic light",                 # 5-9
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",    # 10-14
    "cat", "dog", "horse", "sheep", "cow",                            # 15-19
    "elephant", "bear", "zebra", "giraffe", "backpack",               # 20-24
    "umbrella", "handbag", "tie", "suitcase", "frisbee",              # 25-29
    "skis", "snowboard", "sports ball", "kite", "baseball bat",       # 30-34
    "baseball glove", "skateboard", "surfboard", "tennis racket",     # 35-38
    "bottle",                                                         # 39
    "wine glass", "cup", "fork", "knife", "spoon",                    # 40-44
    "bowl", "banana", "apple", "sandwich", "orange",                  # 45-49
    "broccoli", "carrot", "hot dog", "pizza", "donut",                # 50-54
    "cake", "chair", "couch", "potted plant", "bed",                  # 55-59
    "dining table", "toilet", "tv", "laptop", "mouse",                # 60-64
    "remote", "keyboard", "cell phone", "microwave", "oven",          # 65-69
    "toaster", "sink", "refrigerator", "book", "clock",               # 70-74
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",     # 75-79
]
47
+
48
print("Loading features...")
from datasets import load_dataset

# The first expert's "val" split defines the row order that every other
# expert's features are aligned to.
ref = load_dataset("AbstractPhil/bulk-coco-features", EXPERTS[0], split="val")
val_ids = ref["image_id"]
N_val = len(val_ids)
val_id_map = dict(zip(val_ids, range(N_val)))

# Multi-hot label matrix (one row per image); label ids >= 80 are skipped.
val_labels = torch.zeros(N_val, 80)
for row_idx, label_ids in enumerate(ref["labels"]):
    for label_id in label_ids:
        if label_id >= 80:
            continue
        val_labels[row_idx, label_id] = 1.0

# Raw per-expert features, placed at the row given by val_id_map. Images
# that an expert's split does not contain keep an all-zero row.
val_raw = {}
for name in EXPERTS:
    ds = load_dataset("AbstractPhil/bulk-coco-features", name, split="val")
    feats = torch.zeros(N_val, 768)
    for row in ds:
        slot = val_id_map.get(row["image_id"])
        if slot is not None:
            feats[slot] = torch.tensor(row["features"], dtype=torch.float32)
    val_raw[name] = feats
    del ds
67
+
68
def project_expert(feats, i):
    """Project raw features of expert ``i`` using weights stored in ``sd``.

    Uses the ``projectors.{i}.proj_shared`` weights when that key prefix is
    present in ``sd`` and ``projectors.{i}.proj`` otherwise. The features go
    through the ``.0`` affine map, are standardised over the last axis
    (biased variance, eps 1e-5), scaled and shifted by the ``.1`` weight and
    bias, and finally L2-normalised along the last axis.

    Args:
        feats: Tensor whose last axis matches the input width of the
            ``.0.weight`` matrix.
        i: Index of the expert's projector in ``sd``.

    Returns:
        Tensor of unit-norm projected features.
    """
    shared_prefix = f"projectors.{i}.proj_shared"
    if f"{shared_prefix}.0.weight" in sd:
        prefix = shared_prefix
    else:
        prefix = f"projectors.{i}.proj"

    weight, bias = sd[f"{prefix}.0.weight"], sd[f"{prefix}.0.bias"]
    gain, shift = sd[f"{prefix}.1.weight"], sd[f"{prefix}.1.bias"]

    hidden = feats @ weight.T + bias
    mean = hidden.mean(-1, keepdim=True)
    variance = hidden.var(-1, keepdim=True, unbiased=False)
    standardised = (hidden - mean) / (variance + 1e-5).sqrt() * gain + shift
    return F.normalize(standardised, dim=-1)
78
+
79
print("Generating embeddings...")
with torch.no_grad():
    projected = []
    for i, name in enumerate(EXPERTS):
        projected.append(project_expert(val_raw[name], i))
    # Fused embedding: mean of the three expert projections, re-normalised.
    fused = F.normalize(sum(projected) / 3, dim=-1)

# ══════════════════════════════════════════════════════════════════
# PCA → 3D
# ══════════════════════════════════════════════════════════════════

emb = fused.numpy()
emb_centered = emb - emb.mean(axis=0, keepdims=True)
# The principal axes are estimated from the first 5000 centred rows only.
U, S, Vt = np.linalg.svd(emb_centered[:5000], full_matrices=False)
pca3 = Vt[:3]

# Note: the *uncentred* embeddings and anchors are projected onto the axes.
emb_3d = emb @ pca3.T
anchors_3d = anchors.numpy() @ pca3.T

squared_singular = S**2
var_explained = squared_singular[:3] / squared_singular.sum()
print(f"PCA 3D variance: {var_explained.sum()*100:.1f}% "
      f"({var_explained[0]*100:.1f}%, {var_explained[1]*100:.1f}%, {var_explained[2]*100:.1f}%)")
99
+
100
def to_sphere(pts):
    """Rescale points along the last axis to (almost exactly) unit length.

    A small constant (1e-8) is added to each norm before dividing, so an
    all-zero point maps to zeros instead of producing NaNs.

    Args:
        pts: Array of points; the last axis holds the coordinates.

    Returns:
        Array of the same shape with each point divided by ``norm + 1e-8``.
    """
    lengths = np.linalg.norm(pts, axis=-1, keepdims=True)
    safe_lengths = lengths + 1e-8
    return pts / safe_lengths
103
+
104
emb_s = to_sphere(emb_3d)
anchors_s = to_sphere(anchors_3d)

# Reference sphere wireframe: a 60 x 30 (azimuth x polar) unit-sphere grid.
phi = np.linspace(0, 2*np.pi, 60)
theta = np.linspace(0, np.pi, 30)
sin_theta, cos_theta = np.sin(theta), np.cos(theta)
xs = np.outer(np.cos(phi), sin_theta)
ys = np.outer(np.sin(phi), sin_theta)
zs = np.outer(np.ones_like(phi), cos_theta)

# Primary class per image (most specific): among the classes present on an
# image, pick the one with the fewest occurrences across the split. Images
# without any label keep the default index 0.
class_freq = val_labels.sum(0).numpy()
primary_class = np.zeros(N_val, dtype=int)
labels_np = val_labels.numpy()
for img_idx, label_row in enumerate(labels_np):
    present = np.flatnonzero(label_row > 0)
    if present.size:
        rarest = class_freq[present].argmin()
        primary_class[img_idx] = present[rarest]

# 20-colour palette; class ids wrap around modulo 20.
cmap20 = plt.cm.tab20(np.linspace(0, 1, 20))
class_colors = cmap20[primary_class % 20]
124
+
125
+
126
+ # ══════════════════════════════════════════════════════════════════
127
+ # HELPER
128
+ # ══════════════════════════════════════════════════════════════════
129
+
130
def setup_ax(ax, title):
    """Apply the shared dark styling to a 3D axes and draw the guide sphere.

    Sets a black background, transparent panes with gray edges, gray
    PC1/PC2/PC3 axis labels and ticks, a white title, a faint white
    wireframe built from the module-level ``xs``/``ys``/``zs`` grids (shrunk
    to radius 0.98) and fixed limits of [-1.3, 1.3] on all three axes.

    Args:
        ax: A matplotlib 3D axes.
        title: Title text for the panel.
    """
    ax.set_facecolor('black')

    axes3 = (ax.xaxis, ax.yaxis, ax.zaxis)
    for axis in axes3:
        axis.pane.fill = False
    for axis in axes3:
        axis.pane.set_edgecolor('gray')

    label_style = dict(color='gray', fontsize=8)
    ax.set_xlabel('PC1', **label_style)
    ax.set_ylabel('PC2', **label_style)
    ax.set_zlabel('PC3', **label_style)
    ax.tick_params(colors='gray', labelsize=6)
    ax.set_title(title, color='white', fontsize=11, pad=10)

    ax.plot_wireframe(xs*0.98, ys*0.98, zs*0.98,
                      alpha=0.03, color='white', linewidth=0.3)

    for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
        set_limits(-1.3, 1.3)
142
+
143
+
144
+ # ══════════════════════════════════════════════════════════════════
145
+ # FIGURE 1: 6-PANEL MANIFOLD VIEW
146
+ # ══════════════════════════════════════════════════════════════════
147
+
148
print("Rendering figure 1...")
fig = plt.figure(figsize=(24, 16), facecolor='black')
fig1_title = (
    'GeoLIP Hypersphere Manifold — S²⁵⁵ projected to S²\n'
    f'{N_ANCHORS} anchors × {D_ANCHOR}-d × 3 experts | mAP={ckpt["mAP"]:.3f} | eff_dim=76.9'
)
fig.suptitle(fig1_title, color='white', fontsize=16, y=0.98)

# Panel 1: Full manifold — every embedding coloured by its primary class,
# anchors overlaid as red triangles.
ax1 = fig.add_subplot(231, projection='3d')
setup_ax(ax1, f'Full Manifold — {N_val} embeddings + {N_ANCHORS} anchors')
ax1.scatter(*emb_s.T, c=class_colors, s=1, alpha=0.3)
ax1.scatter(*anchors_s.T, c='red', s=8, alpha=0.6, marker='^')

# Panel 2: Class centroids — mean embedding of all images carrying a class,
# projected with the same PCA axes. Classes with no images stay at zero.
ax2 = fig.add_subplot(232, projection='3d')
setup_ax(ax2, '80 COCO Class Centroids')
centroids = np.zeros((80, emb.shape[1]))
label_matrix = val_labels.numpy()
for cls in range(80):
    members = label_matrix[:, cls] > 0
    if members.any():
        centroids[cls] = emb[members].mean(0)
centroids_3d = to_sphere(centroids @ pca3.T)

# Marker area grows with the number of images per class.
sizes = val_labels.sum(0).numpy()
sizes_scaled = 20 + 200 * (sizes / sizes.max())
colors80 = plt.cm.hsv(np.linspace(0, 0.95, 80))
ax2.scatter(*centroids_3d.T,
            c=colors80, s=sizes_scaled, alpha=0.8,
            edgecolors='white', linewidth=0.3)

# Text labels for a hand-picked subset of classes, pushed outward by 15%
# and shown only when the class has more than 30 images.
labelled_classes = [0, 2, 14, 15, 16, 22, 23, 56, 62]
for cls in labelled_classes:
    if sizes[cls] > 30:
        ax2.text(centroids_3d[cls, 0]*1.15, centroids_3d[cls, 1]*1.15,
                 centroids_3d[cls, 2]*1.15,
                 COCO_CLASSES[cls], color='white', fontsize=7, ha='center')
182
+
183
# Panel 3: 50 random with anchor connections — each sampled embedding is
# joined to the anchor with the highest cosine similarity.
ax3 = fig.add_subplot(233, projection='3d')
setup_ax(ax3, '50 Random — nearest anchor connections')
np.random.seed(42)  # the sample below and the one in panel 4 share this stream
idx50 = np.random.choice(N_val, 50, replace=False)
emb_50 = emb_s[idx50]
colors_50 = class_colors[idx50]
with torch.no_grad():
    cos_50 = fused[idx50] @ anchors.T
    nearest_50 = cos_50.argmax(-1).numpy()
ax3.scatter(*anchors_s.T, c='red', s=4, alpha=0.2, marker='^')
ax3.scatter(*emb_50.T,
            c=colors_50, s=40, alpha=0.9, edgecolors='white', linewidth=0.5)
for point, anchor_idx in zip(emb_50, nearest_50):
    target = anchors_s[anchor_idx]
    ax3.plot([point[0], target[0]],
             [point[1], target[1]],
             [point[2], target[2]],
             color='yellow', alpha=0.3, linewidth=0.5)

# Panel 4: 10 random — triangulation heatmap. Anchors are coloured by their
# mean cosine similarity to the 10 samples, min-max scaled to [0, 1].
ax4 = fig.add_subplot(234, projection='3d')
setup_ax(ax4, '10 Random — anchor affinity heatmap')
idx10 = np.random.choice(N_val, 10, replace=False)
emb_10 = emb_s[idx10]
with torch.no_grad():
    cos_10 = (fused[idx10] @ anchors.T).numpy()
mean_cos = cos_10.mean(0)
cos_lo, cos_hi = mean_cos.min(), mean_cos.max()
anchor_heat = (mean_cos - cos_lo) / (cos_hi - cos_lo + 1e-8)
anchor_colors = plt.cm.hot(anchor_heat)
ax4.scatter(*anchors_s.T, c=anchor_colors, s=10, alpha=0.6)
ax4.scatter(*emb_10.T,
            c='cyan', s=80, alpha=1.0, edgecolors='white', linewidth=1, zorder=10)
218
+
219
# Panel 5: Single encoding — one fixed image (index 42), its similarity to
# every anchor, and lines to its five most similar anchors.
ax5 = fig.add_subplot(235, projection='3d')
single_idx = 42
single_class = primary_class[single_idx]
setup_ax(ax5, f'Single Encoding: "{COCO_CLASSES[single_class]}" — top 5 anchors')
with torch.no_grad():
    cos_single = (fused[single_idx] @ anchors.T).numpy()
single_lo, single_hi = cos_single.min(), cos_single.max()
single_heat = (cos_single - single_lo) / (single_hi - single_lo + 1e-8)
single_colors = plt.cm.plasma(single_heat)
# Cubing the heat keeps weakly-matching anchors small.
single_sizes = 2 + 50 * single_heat**3
ax5.scatter(*anchors_s.T, c=single_colors, s=single_sizes, alpha=0.7)
single_pt = emb_s[single_idx]
ax5.scatter([single_pt[0]], [single_pt[1]], [single_pt[2]],
            c='lime', s=150, alpha=1.0, edgecolors='white', linewidth=2,
            zorder=10, marker='*')
top5 = np.argsort(cos_single)[::-1][:5]
for anchor_idx in top5:
    target = anchors_s[anchor_idx]
    ax5.plot([single_pt[0], target[0]],
             [single_pt[1], target[1]],
             [single_pt[2], target[2]],
             color='lime', alpha=0.6, linewidth=1.5)

# Panel 6: Radial deviation — the un-normalised PCA projections, scaled so
# the largest radius is 1 and coloured by deviation from the mean radius.
ax6 = fig.add_subplot(236, projection='3d')
radii = np.linalg.norm(emb_3d, axis=-1)
setup_ax(ax6, f'PCA Projection Radii — mean={radii.mean():.4f} std={radii.std():.4f}')
radius_dev = radii - radii.mean()
dev_lo, dev_hi = radius_dev.min(), radius_dev.max()
dev_norm = (radius_dev - dev_lo) / (dev_hi - dev_lo + 1e-8)
dev_colors = plt.cm.coolwarm(dev_norm)
scale = 1.0 / radii.max()
scaled_3d = emb_3d * scale
ax6.scatter(*scaled_3d.T, c=dev_colors, s=2, alpha=0.4)

# Leave the top 5% of the canvas free for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("hypersphere_manifold.png", dpi=200, facecolor='black',
            bbox_inches='tight', pad_inches=0.3)
print("Saved: hypersphere_manifold.png")
plt.close()
258
+
259
+
260
+ # ══════════════════════════════════════════════════════════════════
261
+ # FIGURE 2: EXPERT PERSPECTIVES
262
+ # ══════════════════════════════════════════════════════════════════
263
+
264
print("Rendering figure 2...")
fig2 = plt.figure(figsize=(21, 7), facecolor='black')
fig2.suptitle('Expert Perspective Divergence — Same sphere, three lenses',
              color='white', fontsize=14, y=1.02)

# Per-expert rotation / whitener / mean tensors are read from the checkpoint
# when present; otherwise identity matrices and zero means stand in.
has_expert_rot = "constellation.expert_rotations.0" in sd
if has_expert_rot:
    def _per_expert(prefix):
        """Collect the three per-expert tensors stored under ``prefix``."""
        return [sd[f"{prefix}.{i}"] for i in range(3)]

    expert_R = _per_expert("constellation.expert_rotations")
    expert_W = _per_expert("constellation.expert_whiteners")
    expert_mu = _per_expert("constellation.expert_means")
else:
    expert_R = [torch.eye(D_ANCHOR) for _ in range(3)]
    expert_W = [torch.eye(D_ANCHOR) for _ in range(3)]
    expert_mu = [torch.zeros(D_ANCHOR) for _ in range(3)]

with torch.no_grad():
    for i, name in enumerate(EXPERTS):
        ax = fig2.add_subplot(1, 3, i+1, projection='3d')
        native_prefix = f"projectors.{i}.proj_native"

        if has_expert_rot:
            # Centre, whiten and rotate the fused embedding for this expert.
            whitened = (fused.float() - expert_mu[i]) @ expert_W[i]
            rotated = F.normalize(whitened @ expert_R[i].T, dim=-1)
        elif f"{native_prefix}.0.weight" in sd:
            # Same affine + standardise + scale/shift pipeline as
            # project_expert, applied with the "proj_native" weights.
            weight = sd[f"{native_prefix}.0.weight"]
            bias = sd[f"{native_prefix}.0.bias"]
            gain = sd[f"{native_prefix}.1.weight"]
            shift = sd[f"{native_prefix}.1.bias"]
            hidden = val_raw[name] @ weight.T + bias
            mean = hidden.mean(-1, keepdim=True)
            variance = hidden.var(-1, keepdim=True, unbiased=False)
            hidden = (hidden - mean) / (variance + 1e-5).sqrt() * gain + shift
            rotated = F.normalize(hidden, dim=-1)
        else:
            rotated = projected[i]

        # Each expert gets its own PCA basis (fitted on the first 5000
        # centred rows), so axes are not comparable across panels.
        rot_np = rotated.numpy()
        rot_c = rot_np - rot_np.mean(axis=0, keepdims=True)
        _, S_r, Vt_r = np.linalg.svd(rot_c[:5000], full_matrices=False)
        rot_3d = to_sphere(rot_np @ Vt_r[:3].T)

        var_exp = S_r[:3]**2 / (S_r**2).sum()
        setup_ax(ax, f'{name[:25]}\nPC variance: {var_exp.sum()*100:.1f}%')
        ax.scatter(*rot_3d.T, c=class_colors, s=2, alpha=0.4)

plt.tight_layout()
plt.savefig("expert_perspectives.png", dpi=200, facecolor='black',
            bbox_inches='tight', pad_inches=0.3)
print("Saved: expert_perspectives.png")
plt.close()
314
+
315
+
316
+ # ══════════════════════════════════════════════════════════════════
317
+ # FIGURE 3: ANCHORS ONLY
318
+ # ══════════════════════════════════════════════════════════════════
319
+
320
print("Rendering figure 3 — anchors only...")

# Anchor visit counts for coloring: how many embeddings have each anchor as
# their most similar one. torch.bincount does the tally in one call instead
# of a Python-level loop over every embedding; .float() keeps the float32
# dtype that the original zero-initialised counter had.
with torch.no_grad():
    cos_all = fused @ anchors.T
    nearest_all = cos_all.argmax(dim=-1)
    vc = torch.bincount(nearest_all, minlength=N_ANCHORS).float()
vc_np = vc.numpy()

fig3 = plt.figure(figsize=(24, 8), facecolor='black')
fig3.suptitle(f'Constellation — {N_ANCHORS} anchors × {D_ANCHOR}-d on S²⁵⁵',
              color='white', fontsize=14, y=1.02)

# Panel 1: Anchors colored by visit count
ax_a1 = fig3.add_subplot(131, projection='3d')
setup_ax(ax_a1, f'Anchor Utilization — {int((vc_np>0).sum())}/{N_ANCHORS} active')

# Log-compressed counts, scaled so the busiest anchor has heat 1.
heat = np.zeros(N_ANCHORS)
active_mask = vc_np > 0
heat[active_mask] = np.log1p(vc_np[active_mask])
heat = heat / (heat.max() + 1e-8)
a_colors = plt.cm.inferno(heat)
a_sizes = 5 + 60 * heat

# Dead anchors (never the nearest anchor of any embedding) in blue.
# NOTE(review): a legend label is set here but no legend is ever drawn, so
# the dead count is not visible in the saved figure — confirm intent.
dead_mask = vc_np == 0
ax_a1.scatter(anchors_s[dead_mask, 0], anchors_s[dead_mask, 1], anchors_s[dead_mask, 2],
              c='dodgerblue', s=8, alpha=0.4, marker='x', label=f'dead ({int(dead_mask.sum())})')
ax_a1.scatter(anchors_s[active_mask, 0], anchors_s[active_mask, 1], anchors_s[active_mask, 2],
              c=a_colors[active_mask], s=a_sizes[active_mask], alpha=0.8)
350
+
351
# Panel 2: Anchors colored by nearest neighbor distance — for each anchor,
# the highest cosine similarity to any *other* anchor (diagonal masked to -1).
ax_a2 = fig3.add_subplot(132, projection='3d')
anchors_np = anchors.numpy()
anchor_sim = anchors_np @ anchors_np.T
np.fill_diagonal(anchor_sim, -1)
max_neighbor_cos = anchor_sim.max(axis=1)
nn_lo, nn_hi = max_neighbor_cos.min(), max_neighbor_cos.max()
nn_heat = (max_neighbor_cos - nn_lo) / (nn_hi - nn_lo + 1e-8)
nn_colors = plt.cm.viridis(nn_heat)
setup_ax(ax_a2, f'Anchor Isolation — nearest neighbor cosine\n'
                f'mean={max_neighbor_cos.mean():.3f} max={max_neighbor_cos.max():.3f}')
ax_a2.scatter(*anchors_s.T, c=nn_colors, s=15, alpha=0.8)

# Panel 3: Anchors colored by expert divergence at that anchor — the std
# across experts of the (1 - cosine) distance, averaged over all images.
ax_a3 = fig3.add_subplot(133, projection='3d')
with torch.no_grad():
    # Pick the per-expert embeddings: checkpoint rotations if available,
    # else the native projector weights, else the shared projections.
    if has_expert_rot:
        expert_views = []
        for i in range(3):
            whitened = (fused.float() - expert_mu[i]) @ expert_W[i]
            expert_views.append(F.normalize(whitened @ expert_R[i].T, dim=-1))
    elif "projectors.0.proj_native.0.weight" in sd:
        def _pn(feats, i):
            """Project ``feats`` with expert ``i``'s ``proj_native`` weights."""
            base = f"projectors.{i}.proj_native"
            weight, bias = sd[f"{base}.0.weight"], sd[f"{base}.0.bias"]
            gain, shift = sd[f"{base}.1.weight"], sd[f"{base}.1.bias"]
            hidden = feats @ weight.T + bias
            mean = hidden.mean(-1, keepdim=True)
            variance = hidden.var(-1, keepdim=True, unbiased=False)
            hidden = (hidden - mean) / (variance + 1e-5).sqrt() * gain + shift
            return F.normalize(hidden, dim=-1)

        expert_views = [_pn(val_raw[name], i) for i, name in enumerate(EXPERTS)]
    else:
        expert_views = list(projected)

    expert_tri_stack = [1.0 - (view @ anchors.T) for view in expert_views]
    tri_stack = torch.stack(expert_tri_stack, dim=-1)
    per_anchor_div = tri_stack.std(dim=-1).mean(dim=0).numpy()

div_lo, div_hi = per_anchor_div.min(), per_anchor_div.max()
div_heat = (per_anchor_div - div_lo) / (div_hi - div_lo + 1e-8)
div_colors = plt.cm.coolwarm(div_heat)
setup_ax(ax_a3, f'Expert Divergence per Anchor\n'
                f'mean={per_anchor_div.mean():.4f} range=[{per_anchor_div.min():.4f}, {per_anchor_div.max():.4f}]')
ax_a3.scatter(*anchors_s.T, c=div_colors, s=15, alpha=0.8)
398
+
399
# Add connections between closest anchor pairs (top 20).
#
# The pairs are selected ONCE and then drawn on both panels. Previously the
# selection ran inside the per-panel loop while consuming the same working
# matrix, so the first panel received pairs ranked 1-20 and the second
# panel silently received pairs ranked 21-40.
flat_sim = anchor_sim.copy()
np.fill_diagonal(flat_sim, -999)

# Never ask for more pairs than exist (matters only for tiny anchor sets).
n_anchor_rows = flat_sim.shape[0]
n_links = min(20, n_anchor_rows * (n_anchor_rows - 1) // 2)

closest_pairs = []
for _ in range(n_links):
    i_a, j_a = np.unravel_index(np.argmax(flat_sim), flat_sim.shape)
    # Knock out both symmetric entries so the pair is not selected twice.
    flat_sim[i_a, j_a] = -999
    flat_sim[j_a, i_a] = -999
    closest_pairs.append((i_a, j_a))

for panel_ax in [ax_a1, ax_a2]:
    for i_a, j_a in closest_pairs:
        panel_ax.plot([anchors_s[i_a, 0], anchors_s[j_a, 0]],
                      [anchors_s[i_a, 1], anchors_s[j_a, 1]],
                      [anchors_s[i_a, 2], anchors_s[j_a, 2]],
                      color='white', alpha=0.15, linewidth=0.5)

plt.tight_layout()
plt.savefig("anchors_only.png", dpi=200, facecolor='black',
            bbox_inches='tight', pad_inches=0.3)
print("Saved: anchors_only.png")
plt.close()
417
+
418
+
419
+ # ══════════════════════════════════════════════════════════════════
420
+ # FIGURE 4: PAIRWISE EXPERT DIFFERENCES
421
+ # ══════════════════════════════════════════════════════════════════
422
+
423
print("Rendering figure 4 — pairwise expert diffs...")

with torch.no_grad():
    # Compute per-expert triangulations: (1 - cosine) from every image
    # embedding to every anchor, one matrix per expert.
    # For dual-stream: use native projectors (the actual expert perspectives)
    # For fused constellation: use expert rotations
    if has_expert_rot:
        # Fused constellation: rotate through R/W/mu
        expert_views = []
        for i in range(3):
            whitened = (fused.float() - expert_mu[i]) @ expert_W[i]
            expert_views.append(F.normalize(whitened @ expert_R[i].T, dim=-1))
    elif "projectors.0.proj_native.0.weight" in sd:
        # Dual-stream: use native projector embeddings
        def _proj_native(feats, i):
            """Project ``feats`` with expert ``i``'s ``proj_native`` weights."""
            base = f"projectors.{i}.proj_native"
            weight, bias = sd[f"{base}.0.weight"], sd[f"{base}.0.bias"]
            gain, shift = sd[f"{base}.1.weight"], sd[f"{base}.1.bias"]
            hidden = feats @ weight.T + bias
            mean = hidden.mean(-1, keepdim=True)
            variance = hidden.var(-1, keepdim=True, unbiased=False)
            hidden = (hidden - mean) / (variance + 1e-5).sqrt() * gain + shift
            return F.normalize(hidden, dim=-1)

        expert_views = [_proj_native(val_raw[name], i)
                        for i, name in enumerate(EXPERTS)]
    else:
        # Fallback: use shared projections (will be near-identical)
        expert_views = list(projected)

    expert_tris = [1.0 - (view @ anchors.T) for view in expert_views]

    # Pairwise diffs, in EXPERTS order (CLIP, DINOv2, SigLIP).
    tri_clip, tri_dino, tri_siglip = expert_tris
    diff_cd = tri_clip - tri_dino
    diff_cs = tri_clip - tri_siglip
    diff_ds = tri_dino - tri_siglip
    diffs = [diff_cd, diff_cs, diff_ds]
    diff_names = ["CLIP − DINOv2", "CLIP − SigLIP", "DINOv2 − SigLIP"]

    # The first expert's triangulation serves as the "absolute" reference.
    abs_tri = tri_clip

print(f"\n Pairwise diff statistics:")
for name, d in zip(diff_names, diffs):
    print(f" {name:20s}: mean={d.mean():.6f} std={d.std():.6f} "
          f"min={d.min():.6f} max={d.max():.6f}")
print(f" Absolute tri std: {abs_tri.std():.6f}")

# Spread of the first pairwise diff relative to the absolute spread.
diff_std = diffs[0].std().item()
abs_std = abs_tri.std().item()
if abs_std > 1e-10:
    print(f" Ratio (diff/abs): {diff_std / abs_std:.4f}")
else:
    print(f" Ratio (diff/abs): N/A (zero abs std)")
478
+
479
def _participation_ratio(singular_values):
    """Return 1 / sum(p**2) for p = singular_values / singular_values.sum()."""
    weights = singular_values / singular_values.sum()
    return float((weights**2).sum()**-1)


# PCA of the diff space: the three pairwise diffs concatenated per image,
# with the basis fitted on the first 5000 centred rows.
diff_stacked = torch.cat(diffs, dim=-1).numpy()
diff_centered = diff_stacked - diff_stacked.mean(axis=0, keepdims=True)
_, S_diff, Vt_diff = np.linalg.svd(diff_centered[:5000], full_matrices=False)

# Guard against zero SVDs
s_sum = (S_diff**2).sum()
diff_is_degenerate = not (s_sum > 1e-20)
if diff_is_degenerate:
    diff_3d = np.zeros((len(diff_centered), 3))
    var_diff = np.zeros(3)
    eff_dim_diff = 0.0
else:
    diff_3d = to_sphere(diff_centered @ Vt_diff[:3].T)
    var_diff = S_diff[:3]**2 / s_sum
    eff_dim_diff = _participation_ratio(S_diff)
print(f"\n Diff space effective dim: {eff_dim_diff:.1f}")
print(f" Diff PCA 3D variance: {var_diff.sum()*100:.1f}%")

# Same analysis on the absolute triangulation alone.
abs_stacked = abs_tri.numpy()
abs_centered = abs_stacked - abs_stacked.mean(axis=0, keepdims=True)
_, S_abs, Vt_abs = np.linalg.svd(abs_centered[:5000], full_matrices=False)
if S_abs.sum() > 1e-20:
    abs_eff = _participation_ratio(S_abs)
else:
    abs_eff = 0.0
print(f" Absolute tri effective dim: {abs_eff:.1f}")

# ...and on absolute + diffs together, to see what the diffs add.
full_stacked = np.concatenate([abs_stacked, diff_stacked], axis=-1)
full_centered = full_stacked - full_stacked.mean(axis=0, keepdims=True)
_, S_full, Vt_full = np.linalg.svd(full_centered[:5000], full_matrices=False)
if S_full.sum() > 1e-20:
    full_eff = _participation_ratio(S_full)
    full_3d = to_sphere(full_centered @ Vt_full[:3].T)
else:
    full_eff = 0.0
    full_3d = np.zeros((len(full_centered), 3))
print(f" Full (abs+diffs) effective dim: {full_eff:.1f}")
print(f" Information gain from diffs: {full_eff - abs_eff:.1f} dimensions")
510
+
511
fig4 = plt.figure(figsize=(28, 14), facecolor='black')
fig4_title = (
    'Expert Pairwise Differences — Where the discriminative signal lives\n'
    f'Diff eff_dim={eff_dim_diff:.1f} | Abs eff_dim={abs_eff:.1f} | '
    f'Combined eff_dim={full_eff:.1f} | Info gain: +{full_eff-abs_eff:.1f} dims'
)
fig4.suptitle(fig4_title, color='white', fontsize=14, y=0.98)

# Row 1: Three pairwise diff distributions on sphere. Points sit at the
# fused-embedding positions; colour encodes the per-image L2 norm of the
# diff vector, min-max scaled within each panel.
for slot, (name, d) in enumerate(zip(diff_names, diffs), start=1):
    ax = fig4.add_subplot(2, 4, slot, projection='3d')
    d_np = d.numpy()

    # Per-image: magnitude of diff vector
    diff_mag = np.linalg.norm(d_np, axis=-1)
    mag_lo, mag_hi = diff_mag.min(), diff_mag.max()
    mag_heat = (diff_mag - mag_lo) / (mag_hi - mag_lo + 1e-8)
    mag_colors = plt.cm.magma(mag_heat)

    setup_ax(ax, f'{name}\nstd={d_np.std():.5f}')
    ax.scatter(*emb_s.T, c=mag_colors, s=2, alpha=0.5)

# Panel 4: Diff space PCA
ax_dp = fig4.add_subplot(244, projection='3d')
setup_ax(ax_dp, f'Diff Space PCA\neff_dim={eff_dim_diff:.1f} var={var_diff.sum()*100:.1f}%')
ax_dp.scatter(*diff_3d.T, c=class_colors, s=2, alpha=0.4)
537
+
538
# Row 2: Per-anchor diff analysis
# Per-anchor mean absolute diff (where do experts disagree most?)
with torch.no_grad():
    per_anchor_cd = diff_cd.abs().mean(dim=0).numpy()
    per_anchor_cs = diff_cs.abs().mean(dim=0).numpy()
    per_anchor_ds = diff_ds.abs().mean(dim=0).numpy()
per_anchor_total = (per_anchor_cd + per_anchor_cs + per_anchor_ds) / 3

# Panel 5: Anchor-level divergence map (total) — colour and size both follow
# the min-max scaled mean of the three per-pair disagreements.
ax_a = fig4.add_subplot(245, projection='3d')
total_lo, total_hi = per_anchor_total.min(), per_anchor_total.max()
total_heat = (per_anchor_total - total_lo) / (total_hi - total_lo + 1e-8)
total_colors = plt.cm.hot(total_heat)
total_sizes = 5 + 40 * total_heat
setup_ax(ax_a, f'Anchor Divergence (all pairs)\n'
               f'range=[{per_anchor_total.min():.5f}, {per_anchor_total.max():.5f}]')
ax_a.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
             c=total_colors, s=total_sizes, alpha=0.8)

# Panel 6: Abs tri PCA vs diff PCA side by side
ax_abs = fig4.add_subplot(246, projection='3d')
abs_3d = to_sphere(abs_centered @ Vt_abs[:3].T)
# Guarded like the diff-space variance: a degenerate (all-zero) spectrum
# reports 0% instead of dividing 0 by 0 and printing "nan%".
abs_energy = (S_abs**2).sum()
var_abs_3 = S_abs[:3]**2 / abs_energy if abs_energy > 1e-20 else np.zeros(3)
setup_ax(ax_abs, f'Absolute Tri PCA\neff_dim={abs_eff:.1f} var={var_abs_3.sum()*100:.1f}%')
ax_abs.scatter(abs_3d[:, 0], abs_3d[:, 1], abs_3d[:, 2],
               c=class_colors, s=2, alpha=0.4)

# Panel 7: Combined PCA
ax_full = fig4.add_subplot(247, projection='3d')
full_energy = (S_full**2).sum()
var_full_3 = S_full[:3]**2 / full_energy if full_energy > 1e-20 else np.zeros(3)
setup_ax(ax_full, f'Combined (abs+diffs) PCA\neff_dim={full_eff:.1f} var={var_full_3.sum()*100:.1f}%')
ax_full.scatter(full_3d[:, 0], full_3d[:, 1], full_3d[:, 2],
                c=class_colors, s=2, alpha=0.4)
570
+
571
# Panel 8: Histogram of diff magnitudes — one density histogram per expert
# pair, over the per-image L2 norm of the diff vector.
ax_hist = fig4.add_subplot(248)
ax_hist.set_facecolor('black')
pair_colors = ['#ff6b6b', '#4ecdc4', '#ffe66d']
for name, d, color in zip(diff_names, diffs, pair_colors):
    d_np = d.numpy()
    per_image_mag = np.linalg.norm(d_np, axis=-1)
    ax_hist.hist(per_image_mag, bins=50, alpha=0.6, color=color,
                 label=name, density=True)

ax_hist.set_xlabel('Diff magnitude (L2)', color='white', fontsize=9)
ax_hist.set_ylabel('Density', color='white', fontsize=9)
ax_hist.set_title('Per-image diff magnitudes', color='white', fontsize=11)
ax_hist.legend(fontsize=8, facecolor='black', edgecolor='gray',
               labelcolor='white')
ax_hist.tick_params(colors='gray', labelsize=7)

# Keep only the bottom/left spines, in gray, to match the dark 3D panels.
for side in ('bottom', 'left'):
    ax_hist.spines[side].set_color('gray')
for side in ('top', 'right'):
    ax_hist.spines[side].set_visible(False)

# Leave the top 5% of the canvas free for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("pairwise_diffs.png", dpi=200, facecolor='black',
            bbox_inches='tight', pad_inches=0.3)
print("Saved: pairwise_diffs.png")
plt.close()

print("\nDone.")