AbstractPhil commited on
Commit
7b4d453
Β·
verified Β·
1 Parent(s): e3eb56b

Create analysis_v1.py

Browse files
Files changed (1) hide show
  1. analysis_v1.py +505 -0
analysis_v1.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Constellation Bottleneck β€” Full Analysis
3
+ ==========================================
4
+ Paste directly after the training cell.
5
+ Uses `model` already in memory.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import numpy as np
12
+ import math
13
+ import os
14
+ from torchvision import datasets, transforms
15
+ from torchvision.utils import save_image, make_grid
16
+
17
+ DEVICE = "cuda"
18
+ os.makedirs("analysis_bn", exist_ok=True)
19
+
20
+ def compute_cv(points, n_samples=1500, n_points=5):
21
+ N = points.shape[0]
22
+ if N < n_points: return float('nan')
23
+ points = F.normalize(points.to(DEVICE).float(), dim=-1)
24
+ vols = []
25
+ for _ in range(n_samples):
26
+ idx = torch.randperm(min(N, 5000), device=DEVICE)[:n_points]
27
+ pts = points[idx].unsqueeze(0)
28
+ gram = torch.bmm(pts, pts.transpose(1, 2))
29
+ norms = torch.diagonal(gram, dim1=1, dim2=2)
30
+ d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
31
+ d2 = F.relu(d2)
32
+ cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=torch.float32)
33
+ cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
34
+ v2 = -torch.linalg.det(cm) / 9216
35
+ if v2[0].item() > 1e-20:
36
+ vols.append(v2[0].sqrt().cpu())
37
+ if len(vols) < 50: return float('nan')
38
+ vt = torch.stack(vols)
39
+ return (vt.std() / (vt.mean() + 1e-8)).item()
40
+
41
+ def eff_dim(x):
42
+ x_c = x - x.mean(0, keepdim=True)
43
+ n = min(512, x.shape[0])
44
+ _, S, _ = torch.linalg.svd(x_c[:n].float(), full_matrices=False)
45
+ p = S / S.sum()
46
+ return p.pow(2).sum().reciprocal().item()
47
+
48
+ CLASS_NAMES = ['plane','auto','bird','cat','deer','dog','frog','horse','ship','truck']
49
+
50
+ model.eval()
51
+ bn = model.bottleneck
52
+
53
+ print("=" * 80)
54
+ print("CONSTELLATION BOTTLENECK β€” FULL ANALYSIS")
55
+ print(f" Params: {sum(p.numel() for p in model.parameters()):,}")
56
+ print(f" Bottleneck: {sum(p.numel() for p in bn.parameters()):,}")
57
+ print("=" * 80)
58
+
59
+ # Load test data
60
+ transform = transforms.Compose([
61
+ transforms.ToTensor(),
62
+ transforms.Normalize((0.5,)*3, (0.5,)*3),
63
+ ])
64
+ test_ds = datasets.CIFAR10('./data', train=False, download=True, transform=transform)
65
+ test_loader = torch.utils.data.DataLoader(test_ds, batch_size=256, shuffle=False)
66
+ images_test, labels_test = next(iter(test_loader))
67
+ images_test = images_test.to(DEVICE)
68
+ labels_test = labels_test.to(DEVICE)
69
+
70
+
71
+ # ══════════════════════════════════════════════════════════════════
72
+ # TEST 1: BOTTLENECK DIAGNOSTICS
73
+ # ══════════════════════════════════════════════════════════════════
74
+
75
+ print(f"\n{'━'*80}")
76
+ print("TEST 1: Bottleneck Diagnostics")
77
+ print(f"{'━'*80}")
78
+
79
+ drift = bn.drift().detach()
80
+ home = F.normalize(bn.home, dim=-1).detach()
81
+ curr = F.normalize(bn.anchors, dim=-1).detach()
82
+ P, A, d = home.shape
83
+
84
+ print(f" Patches: {P}, Anchors/patch: {A}, Patch dim: {d}")
85
+ print(f" Drift: mean={drift.mean():.6f} rad ({math.degrees(drift.mean()):.2f}Β°)")
86
+ print(f" std={drift.std():.6f} min={drift.min():.6f} max={drift.max():.6f}")
87
+ print(f" max degrees: {math.degrees(drift.max()):.2f}Β°")
88
+ print(f" Skip gate: {bn.skip_gate.sigmoid().item():.4f}")
89
+ print(f" Near 0.29154: {(drift - 0.29154).abs().lt(0.05).float().mean().item():.1%}")
90
+
91
+ # Per-patch drift
92
+ print(f"\n Per-patch drift:")
93
+ for p in range(P):
94
+ d_p = drift[p].mean().item()
95
+ d_max = drift[p].max().item()
96
+ marker = " β—„ 0.29" if abs(d_p - 0.29154) < 0.05 else ""
97
+ marker2 = " β—„ MAX near 0.29" if abs(d_max - 0.29154) < 0.05 else ""
98
+ print(f" P{p:2d}: mean={d_p:.4f} ({math.degrees(d_p):.1f}Β°) "
99
+ f"max={d_max:.4f} ({math.degrees(d_max):.1f}Β°){marker}{marker2}")
100
+
101
+ # Anchor pairwise spread
102
+ print(f"\n Anchor spread per patch:")
103
+ for p in range(min(8, P)):
104
+ sim = (curr[p] @ curr[p].T)
105
+ sim.fill_diagonal_(0)
106
+ print(f" P{p}: mean_cos={sim.mean():.4f} max={sim.max():.4f} min={sim.min():.4f}")
107
+
108
+ # Anchor effective dimensionality
109
+ print(f"\n Anchor effective dimensionality:")
110
+ for p in range(min(8, P)):
111
+ _, S, _ = torch.linalg.svd(curr[p].float(), full_matrices=False)
112
+ pr = S / S.sum()
113
+ ed = pr.pow(2).sum().reciprocal().item()
114
+ print(f" P{p}: eff_dim={ed:.1f} / {A}")
115
+
116
+ # Drift histogram β€” where do anchors cluster?
117
+ all_drifts = drift.flatten().cpu().numpy()
118
+ print(f"\n Drift distribution:")
119
+ bins = [0.0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40]
120
+ hist, _ = np.histogram(all_drifts, bins=bins)
121
+ for i in range(len(bins)-1):
122
+ bar = "β–ˆ" * hist[i]
123
+ print(f" {bins[i]:.2f}-{bins[i+1]:.2f}: {hist[i]:3d} {bar}")
124
+
125
+
126
+ # ══════════════════════════════════════════════════════════════════
127
+ # TEST 2: SPHERE REPRESENTATION β€” CV OF BOTTLENECK EMBEDDINGS
128
+ # ══════════════════════════════════════════════════════════════════
129
+
130
+ print(f"\n{'━'*80}")
131
+ print("TEST 2: Sphere Representation β€” CV of bottleneck embeddings")
132
+ print(f" These live on S^15. Does CV approach 0.20?")
133
+ print(f"{'━'*80}")
134
+
135
+ # Hook to capture sphere embeddings
136
+ sphere_embeddings = {}
137
+ tri_profiles = {}
138
+
139
+ def hook_sphere(module, input, output):
140
+ # The forward method: proj_in β†’ norm β†’ reshape β†’ normalize
141
+ # We need to grab AFTER L2 norm. Hook the full bottleneck
142
+ # and manually compute the sphere embedding.
143
+ pass
144
+
145
+ # Manually extract sphere embeddings at different timesteps
146
+ print(f"\n {'t':>6} {'CV_sphere':>10} {'CV_tri':>10} {'eff_d_sph':>10} "
147
+ f"{'eff_d_tri':>10} {'sph_norm':>10}")
148
+
149
+ for t_val in [0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0]:
150
+ B = images_test.shape[0]
151
+ t = torch.full((B,), t_val, device=DEVICE)
152
+ eps = torch.randn_like(images_test)
153
+ t_b = t.view(B, 1, 1, 1)
154
+ x_t = (1 - t_b) * images_test + t_b * eps
155
+
156
+ with torch.no_grad():
157
+ # Run encoder manually
158
+ cond = model.time_emb(t) + model.class_emb(labels_test)
159
+ h = model.in_conv(x_t)
160
+ skips = [h]
161
+ for i in range(len(model.channel_mults)):
162
+ for block in model.enc[i]:
163
+ if isinstance(block, nn.Sequential):
164
+ h = block[0](h); h = block[1](h, cond)
165
+ else:
166
+ h = block(h, cond)
167
+ skips.append(h)
168
+ if i < len(model.enc_down):
169
+ h = model.enc_down[i](h)
170
+
171
+ # Get sphere embedding
172
+ h_flat = h.reshape(B, -1)
173
+ emb = bn.proj_in(h_flat)
174
+ emb = bn.proj_in_norm(emb)
175
+ patches = emb.reshape(B, bn.n_patches, bn.patch_dim)
176
+ patches_n = F.normalize(patches, dim=-1)
177
+
178
+ # CV of sphere embeddings (flatten patches back to one vector)
179
+ sphere_flat = patches_n.reshape(B, -1) # (B, 256) on product of spheres
180
+ cv_sphere = compute_cv(sphere_flat, n_samples=1000)
181
+ ed_sphere = eff_dim(sphere_flat)
182
+ norm_sph = sphere_flat.norm(dim=-1).mean().item()
183
+
184
+ # Triangulation profile
185
+ tri = bn.triangulate(patches_n) # (B, 768)
186
+ cv_tri = compute_cv(tri, n_samples=1000)
187
+ ed_tri = eff_dim(tri)
188
+
189
+ # Per-patch CV
190
+ if t_val == 0.0:
191
+ print(f"\n Per-patch CV at t=0 (should be β‰ˆ0.20 if d=16):")
192
+ for p in range(min(8, bn.n_patches)):
193
+ patch_p = patches_n[:, p, :] # (B, 16) on S^15
194
+ cv_p = compute_cv(patch_p, n_samples=1000)
195
+ print(f" Patch {p}: CV={cv_p:.4f}")
196
+ print()
197
+
198
+ print(f" {t_val:>6.2f} {cv_sphere:>10.4f} {cv_tri:>10.4f} {ed_sphere:>10.1f} "
199
+ f"{ed_tri:>10.1f} {norm_sph:>10.4f}")
200
+
201
+
202
+ # ══════════════════════════════════════════════════════════════════
203
+ # TEST 3: PER-CLASS ANCHOR ROUTING
204
+ # ══════════════════════════════════════════════════════════════════
205
+
206
+ print(f"\n{'━'*80}")
207
+ print("TEST 3: Per-Class Anchor Routing")
208
+ print(f"{'━'*80}")
209
+
210
+ # Collect per-class nearest anchors across all patches
211
+ class_nearest = {c: [] for c in range(10)}
212
+ anchors_n = F.normalize(bn.anchors.detach(), dim=-1)
213
+
214
+ for images_b, labels_b in test_loader:
215
+ images_b = images_b.to(DEVICE)
216
+ labels_b = labels_b.to(DEVICE)
217
+ B = images_b.shape[0]
218
+ t = torch.zeros(B, device=DEVICE) # clean images
219
+
220
+ with torch.no_grad():
221
+ cond = model.time_emb(t) + model.class_emb(labels_b)
222
+ h = model.in_conv(images_b)
223
+ for i in range(len(model.channel_mults)):
224
+ for block in model.enc[i]:
225
+ if isinstance(block, nn.Sequential):
226
+ h = block[0](h); h = block[1](h, cond)
227
+ else:
228
+ h = block(h, cond)
229
+ if i < len(model.enc_down):
230
+ h = model.enc_down[i](h)
231
+
232
+ h_flat = h.reshape(B, -1)
233
+ emb = bn.proj_in_norm(bn.proj_in(h_flat))
234
+ patches = F.normalize(emb.reshape(B, bn.n_patches, bn.patch_dim), dim=-1)
235
+
236
+ # Nearest anchor per patch
237
+ cos = torch.einsum('bpd,pad->bpa', patches, anchors_n) # (B, P, A)
238
+ nearest = cos.argmax(dim=-1) # (B, P)
239
+
240
+ for i in range(B):
241
+ c = labels_b[i].item()
242
+ class_nearest[c].append(nearest[i].cpu())
243
+
244
+ if sum(len(v) for v in class_nearest.values()) > 5000:
245
+ break
246
+
247
+ # Show routing for first 4 patches
248
+ for p_idx in range(min(4, bn.n_patches)):
249
+ print(f"\n Patch {p_idx} β€” nearest anchor per class:")
250
+ print(f" {'class':>10}", end="")
251
+ for a in range(A):
252
+ print(f" {a:>4}", end="")
253
+ print()
254
+
255
+ for c in range(10):
256
+ if not class_nearest[c]:
257
+ continue
258
+ nearest_all = torch.stack(class_nearest[c]) # (N, P)
259
+ nearest_p = nearest_all[:, p_idx]
260
+ counts = torch.bincount(nearest_p, minlength=A).float()
261
+ counts = counts / counts.sum()
262
+ row = f" {CLASS_NAMES[c]:>10}"
263
+ for a in range(A):
264
+ pct = counts[a].item()
265
+ if pct > 0.15:
266
+ row += f" {pct:>3.0%}β–ˆ"
267
+ elif pct > 0.05:
268
+ row += f" {pct:>3.0%}β–‘"
269
+ else:
270
+ row += f" {pct:>3.0%}"
271
+ #row += f" {pct:>3.0%}"
272
+ print(row)
273
+
274
+ # Are anchor patterns class-specific?
275
+ print(f"\n Anchor routing entropy per class (lower = more concentrated):")
276
+ for c in range(10):
277
+ if not class_nearest[c]:
278
+ continue
279
+ nearest_all = torch.stack(class_nearest[c])
280
+ # Average across patches
281
+ total_entropy = 0
282
+ for p_idx in range(bn.n_patches):
283
+ counts = torch.bincount(nearest_all[:, p_idx], minlength=A).float()
284
+ counts = counts / counts.sum()
285
+ entropy = -(counts * (counts + 1e-8).log()).sum().item()
286
+ total_entropy += entropy
287
+ avg_entropy = total_entropy / bn.n_patches
288
+ max_entropy = math.log(A)
289
+ print(f" {CLASS_NAMES[c]:>10}: H={avg_entropy:.3f} / {max_entropy:.3f} "
290
+ f"({avg_entropy/max_entropy:.1%} of max)")
291
+
292
+
293
+ # ══════════════════════════════════════════════════════════════════
294
+ # TEST 4: SKIP GATE ANALYSIS
295
+ # ══════════════════════════════════════════════════════════════════
296
+
297
+ print(f"\n{'━'*80}")
298
+ print("TEST 4: Skip Gate β€” how much goes through constellation vs skip?")
299
+ print(f"{'━'*80}")
300
+
301
+ gate = bn.skip_gate.sigmoid().item()
302
+ print(f" Skip gate value: {gate:.4f}")
303
+ print(f" Skip path: {gate:.1%}")
304
+ print(f" Constellation path: {1-gate:.1%}")
305
+ print(f" Skip proj params: {sum(p.numel() for p in [bn.skip_proj.weight, bn.skip_proj.bias]):,}")
306
+ print(f" Patchwork params: {sum(p.numel() for p in bn.patchwork.parameters()):,}")
307
+ print(f"\n ⚠ skip_proj is Linear(16384, 16384) = "
308
+ f"{bn.skip_proj.weight.numel():,} params")
309
+ print(f" ⚠ This single layer is {bn.skip_proj.weight.numel()/1e6:.0f}M params β€” "
310
+ f"larger than the rest of the model combined")
311
+
312
+
313
+ # ══════════════════════════════════════════════════════════════════
314
+ # TEST 5: GENERATION β€” PER CLASS
315
+ # ══════════════════════════════════════════════════════════════════
316
+
317
+ print(f"\n{'━'*80}")
318
+ print("TEST 5: Generation Quality")
319
+ print(f"{'━'*80}")
320
+
321
+ print(f" {'class':>10} {'intra_cos':>10} {'std':>8} {'CV':>8} {'norm':>8}")
322
+
323
+ all_gen = []
324
+ for c in range(10):
325
+ imgs, _ = sample(model, 64, 50, class_label=c)
326
+ imgs = (imgs + 1) / 2 # to [0,1]
327
+ all_gen.append(imgs)
328
+
329
+ flat = imgs.reshape(64, -1)
330
+ flat_n = F.normalize(flat, dim=-1)
331
+ sim = flat_n @ flat_n.T
332
+ mask = ~torch.eye(64, device=DEVICE, dtype=torch.bool)
333
+ intra = sim[mask].mean().item()
334
+ std = sim[mask].std().item()
335
+ cv = compute_cv(flat, 500)
336
+ norm = flat.norm(dim=-1).mean().item()
337
+ print(f" {CLASS_NAMES[c]:>10} {intra:>10.4f} {std:>8.4f} {cv:>8.4f} {norm:>8.2f}")
338
+
339
+ save_image(make_grid(imgs[:16], nrow=4), f"analysis_bn/class_{CLASS_NAMES[c]}.png")
340
+
341
+ # All classes grid
342
+ all_grid = torch.cat([g[:4] for g in all_gen])
343
+ save_image(make_grid(all_grid, nrow=10), "analysis_bn/all_classes.png")
344
+
345
+
346
+ # ══════════════════════════════════════════════════════════════════
347
+ # TEST 6: ABLATION β€” SKIP ONLY vs CONSTELLATION ONLY
348
+ # ══════════════════════════════════════════════════════════════════
349
+
350
+ print(f"\n{'━'*80}")
351
+ print("TEST 6: Ablation β€” Skip-only vs Constellation-only")
352
+ print(f"{'━'*80}")
353
+
354
+ original_gate = bn.skip_gate.data.clone()
355
+
356
+ # A) Full model (as trained)
357
+ torch.manual_seed(999)
358
+ with torch.no_grad():
359
+ imgs_full, _ = sample(model, 32, 50, class_label=3)
360
+
361
+ # B) Skip only (gate β†’ +100, sigmoid β‰ˆ 1.0)
362
+ bn.skip_gate.data.fill_(100.0)
363
+ torch.manual_seed(999)
364
+ with torch.no_grad():
365
+ imgs_skip, _ = sample(model, 32, 50, class_label=3)
366
+
367
+ # C) Constellation only (gate β†’ -100, sigmoid β‰ˆ 0.0)
368
+ bn.skip_gate.data.fill_(-100.0)
369
+ torch.manual_seed(999)
370
+ with torch.no_grad():
371
+ imgs_const, _ = sample(model, 32, 50, class_label=3)
372
+
373
+ # Restore
374
+ bn.skip_gate.data.copy_(original_gate)
375
+
376
+ imgs_full_01 = (imgs_full + 1) / 2
377
+ imgs_skip_01 = (imgs_skip + 1) / 2
378
+ imgs_const_01 = (imgs_const + 1) / 2
379
+
380
+ # Compare
381
+ for name, imgs in [('skip_only', imgs_skip), ('const_only', imgs_const)]:
382
+ delta = (imgs_full - imgs).abs()
383
+ pixel_diff = delta.mean().item()
384
+ cos = F.cosine_similarity(
385
+ imgs_full.reshape(32, -1), imgs.reshape(32, -1)).mean().item()
386
+ print(f" {name:>15}: pixel_Ξ”={pixel_diff:.6f} cos_sim={cos:.6f} "
387
+ f"max_Ξ”={delta.max():.4f}")
388
+
389
+ # Save comparison: top=full, mid=skip_only, bot=constellation_only
390
+ comparison = torch.cat([imgs_full_01[:8], imgs_skip_01[:8], imgs_const_01[:8]])
391
+ save_image(make_grid(comparison, nrow=8), "analysis_bn/ablation_skip_vs_const.png")
392
+ print(f" βœ“ Saved (top=full, mid=skip_only, bot=constellation_only)")
393
+
394
+
395
+ # ══════════════════════════════════════════════════════════════════
396
+ # TEST 7: VELOCITY FIELD
397
+ # ══════════════════════════════════════════════════════════════════
398
+
399
+ print(f"\n{'━'*80}")
400
+ print("TEST 7: Velocity Field Quality")
401
+ print(f"{'━'*80}")
402
+
403
+ print(f" {'t':>6} {'v_norm':>10} {'vΒ·target':>10} {'mse':>10}")
404
+
405
+ for t_val in [0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95]:
406
+ B = 128
407
+ imgs_v = images_test[:B]
408
+ labs_v = labels_test[:B]
409
+ t = torch.full((B,), t_val, device=DEVICE)
410
+ eps = torch.randn_like(imgs_v)
411
+ t_b = t.view(B, 1, 1, 1)
412
+ x_t = (1 - t_b) * imgs_v + t_b * eps
413
+ v_target = eps - imgs_v
414
+
415
+ with torch.no_grad():
416
+ v_pred = model(x_t, t, labs_v)
417
+
418
+ v_norm = v_pred.reshape(B, -1).norm(dim=-1).mean().item()
419
+ v_cos = F.cosine_similarity(
420
+ v_pred.reshape(B, -1), v_target.reshape(B, -1)).mean().item()
421
+ mse = F.mse_loss(v_pred, v_target).item()
422
+ print(f" {t_val:>6.2f} {v_norm:>10.2f} {v_cos:>10.4f} {mse:>10.4f}")
423
+
424
+
425
+ # ══════════════════════════════════════════════════════════════════
426
+ # TEST 8: ODE TRAJECTORY β€” CV THROUGH GENERATION
427
+ # ══════════════════════════════════════════════════════════════════
428
+
429
+ print(f"\n{'━'*80}")
430
+ print("TEST 8: ODE Trajectory β€” geometry through generation")
431
+ print(f"{'━'*80}")
432
+
433
+ n_steps = 50
434
+ B_traj = 256
435
+ x = torch.randn(B_traj, 3, 32, 32, device=DEVICE)
436
+ labels_traj = torch.randint(0, 10, (B_traj,), device=DEVICE)
437
+ dt = 1.0 / n_steps
438
+
439
+ print(f" {'step':>6} {'t':>6} {'x_norm':>10} {'x_std':>10} {'CV':>8}")
440
+
441
+ for step in range(n_steps):
442
+ t_val = 1.0 - step * dt
443
+ t = torch.full((B_traj,), t_val, device=DEVICE)
444
+ with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
445
+ v = model(x, t, labels_traj)
446
+ x = x - v.float() * dt
447
+
448
+ if step in [0, 1, 5, 10, 20, 30, 40, 49]:
449
+ xf = x.reshape(B_traj, -1)
450
+ print(f" {step:>6} {t_val:>6.2f} {xf.norm(dim=-1).mean().item():>10.2f} "
451
+ f"{x.std().item():>10.4f} {compute_cv(xf, 500):>8.4f}")
452
+
453
+
454
+ # ══════════════════════════════════════════════════════════════════
455
+ # TEST 9: INTER vs INTRA CLASS
456
+ # ══════════════════════════════════════════════════════════════════
457
+
458
+ print(f"\n{'━'*80}")
459
+ print("TEST 9: Inter vs Intra Class Separation")
460
+ print(f"{'━'*80}")
461
+
462
+ intra_sims = []
463
+ inter_sims = []
464
+ for c in range(10):
465
+ flat = F.normalize(all_gen[c].reshape(64, -1), dim=-1)
466
+ sim = flat @ flat.T
467
+ mask = ~torch.eye(64, device=DEVICE, dtype=torch.bool)
468
+ intra_sims.append(sim[mask].mean().item())
469
+
470
+ for i in range(10):
471
+ for j in range(i+1, 10):
472
+ fi = F.normalize(all_gen[i].reshape(64, -1), dim=-1)
473
+ fj = F.normalize(all_gen[j].reshape(64, -1), dim=-1)
474
+ inter_sims.append((fi @ fj.T).mean().item())
475
+
476
+ print(f" Intra-class cos: {np.mean(intra_sims):.4f} Β± {np.std(intra_sims):.4f}")
477
+ print(f" Inter-class cos: {np.mean(inter_sims):.4f} Β± {np.std(inter_sims):.4f}")
478
+ ratio = np.mean(intra_sims) / (np.mean(inter_sims) + 1e-8)
479
+ print(f" Separation ratio: {ratio:.3f}Γ—")
480
+
481
+
482
+ # ══════════════════════════════════════════════════════════════════
483
+ # SUMMARY
484
+ # ═════════════════════════════════════════════════════════���════════
485
+
486
+ print(f"\n{'='*80}")
487
+ print("ANALYSIS COMPLETE")
488
+ print(f"{'='*80}")
489
+ print(f"""
490
+ Files in analysis_bn/:
491
+ class_*.png per-class samples
492
+ all_classes.png 4 per class grid
493
+ ablation_skip_vs_const.png top=full, mid=skip, bot=constellation
494
+
495
+ Key questions answered:
496
+ 1. Does per-patch CV β‰ˆ 0.20? (Test 2)
497
+ β†’ If yes, the bottleneck lives at the natural S^15 dimension
498
+ 2. Is anchor routing class-specific? (Test 3)
499
+ β†’ If entropy varies by class, constellation routes differently
500
+ 3. Does the skip path dominate? (Tests 4 & 6)
501
+ β†’ If skip_only β‰ˆ full, the 268M skip_proj IS the model
502
+ 4. Does constellation-only work at all? (Test 6)
503
+ β†’ The real test of whether geometric encoding carries signal
504
+ """)
505
+ print("=" * 80)