"""
Constellation Diffusion — Analysis
==================================
Paste after training. Uses `model` (its bottleneck is taken as `bn`),
`sample`, and `best_loss` from memory.
"""
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import math |
| import os |
| from torchvision import datasets, transforms |
| from torchvision.utils import save_image, make_grid |
|
|
# Use the GPU when one is present; otherwise fall back to CPU so tensor
# creation with device=DEVICE does not fail on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Output directory for the sample grids written in TEST 5.
os.makedirs("analysis_cd", exist_ok=True)
|
|
def compute_cv(points, n_samples=1500, n_points=5, device=None):
    """Coefficient of variation (std / mean) of random simplex volumes.

    Rows of ``points`` are L2-normalised, then ``n_samples`` random subsets of
    ``n_points`` distinct rows are drawn from the first ``min(N, 5000)`` rows.
    Each subset spans an ``(n_points - 1)``-simplex whose volume comes from
    the Cayley-Menger determinant; degenerate simplices (squared volume
    <= 1e-20) are discarded.

    Args:
        points: 2-D tensor ``(N, D)``, one point per row.
        n_samples: number of random simplices to draw.
        n_points: vertices per simplex (default 5, i.e. a 4-simplex).
        device: device to compute on; ``None`` means the module-level
            ``DEVICE``.

    Returns:
        ``std(vol) / (mean(vol) + 1e-8)`` as a Python float, or NaN when
        fewer than ``n_points`` rows are available or fewer than 50
        non-degenerate simplices were found.
    """
    if device is None:
        device = DEVICE
    pool = min(points.shape[0], 5000)
    if pool < n_points or n_samples <= 0:
        return float('nan')
    pts = F.normalize(points.to(device).float(), dim=-1)

    # Draw every subset at once: the positions of the n_points largest of
    # `pool` i.i.d. uniforms form a uniformly random subset without
    # replacement. This replaces n_samples sequential randperm calls, each of
    # which forced a GPU sync via .item().
    idx = torch.rand(n_samples, pool, device=device).topk(n_points, dim=1).indices
    simp = pts[idx]                                   # (n_samples, n_points, D)

    # Pairwise squared distances; relu clamps tiny negative rounding errors.
    gram = simp @ simp.transpose(1, 2)
    sq = torch.diagonal(gram, dim1=1, dim2=2)
    d2 = F.relu(sq.unsqueeze(2) + sq.unsqueeze(1) - 2 * gram)

    # Cayley-Menger matrix [[0, 1^T], [1, D2]] of size n_points + 1.
    size = n_points + 1
    cm = torch.ones(n_samples, size, size, device=device, dtype=torch.float32)
    cm[:, 0, 0] = 0
    cm[:, 1:, 1:] = d2

    # For a k-simplex (k = n_points - 1):
    #   V^2 = (-1)^(k+1) * det(CM) / (2^k * (k!)^2)
    # k = 4 gives -det / 9216, the constant previously hard-coded (which made
    # every n_points other than 5 fail).
    k = n_points - 1
    scale = (-1) ** (k + 1) / (2 ** k * math.factorial(k) ** 2)
    v2 = torch.linalg.det(cm) * scale

    vols = v2[v2 > 1e-20].sqrt()
    if vols.numel() < 50:
        return float('nan')
    return (vols.std() / (vols.mean() + 1e-8)).item()
|
|
def eff_dim(x):
    """Participation ratio of the singular-value spectrum of ``x``.

    ``x`` is centred with the mean of ALL its rows, then only the first
    (at most) 512 rows are decomposed. With p_i = s_i / sum(s), the result is
    1 / sum(p_i^2), i.e. (sum s)^2 / sum(s^2), returned as a Python float.
    """
    centred = x - x.mean(dim=0, keepdim=True)
    rows = centred[: min(512, x.shape[0])].float()
    _, sing, _ = torch.linalg.svd(rows, full_matrices=False)
    weights = sing / sing.sum()
    return weights.pow(2).sum().reciprocal().item()
|
|
# CIFAR-10 label names, indexed by class id.
CLASS_NAMES = ['plane', 'auto', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Put the trained model in eval mode before any measurement.
model.eval()
bn = model.bottleneck

print("=" * 80)
print("CONSTELLATION DIFFUSION — PURE BOTTLENECK ANALYSIS")
n_params = sum(p.numel() for p in model.parameters())
n_bn = sum(p.numel() for p in bn.parameters())
print(f"  Total: {n_params:,}  Bottleneck: {n_bn:,} ({100*n_bn/n_params:.1f}%)")
# Size of the bottleneck code, as counted here: patches x anchors x phases.
n_code = bn.n_patches * bn.n_anchors * bn.n_phases
print(f"  Compression: {bn.spatial_dim} → {n_code} "
      f"({bn.spatial_dim / n_code:.1f}×)")
print("=" * 80)
|
|
| |
# CIFAR-10 test split, each channel mapped to (x - 0.5) / 0.5 after ToTensor.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
test_ds = datasets.CIFAR10(root='./data', train=False, download=True,
                           transform=transform)
test_loader = torch.utils.data.DataLoader(dataset=test_ds, batch_size=256,
                                          shuffle=False)
|
|
|
|
| |
@torch.no_grad()
def get_sphere_embeddings(images, labels, t_val=0.0):
    """Encode noised images and project them onto the bottleneck's unit sphere.

    The images are noised to level ``t_val`` with fresh Gaussian noise
    (``x_t = (1 - t) * images + t * eps``). The conditioned encoder then runs:
    input conv, each encoder stage, and the down-samplers. Its output is
    flattened, passed through ``bn.proj_in``, split into ``bn.n_patches``
    patches of ``bn.patch_dim``, and each patch is L2-normalised.

    Returns:
        patches_n: (B, n_patches, patch_dim) unit-norm patches (described as
            S^15, which presumes patch_dim == 16 -- TODO confirm).
        tri: ``bn.triangulate(patches_n)``; shape depends on the bottleneck.
        h_flat: (B, -1) flattened encoder output, before projection.
    """
    batch = images.shape[0]
    t = torch.full((batch,), t_val, device=DEVICE)
    noise = torch.randn_like(images)
    t4 = t[:, None, None, None]
    x_t = (1 - t4) * images + t4 * noise

    cond = model.time_emb(t) + model.class_emb(labels)
    h = model.in_conv(x_t)
    for level in range(len(model.ch_mults)):
        for layer in model.enc[level]:
            if not isinstance(layer, nn.Sequential):
                h = layer(h, cond)
                continue
            # Sequential stages: unconditioned first module, conditioned second.
            h = layer[1](layer[0](h), cond)
        if level < len(model.enc_down):
            h = model.enc_down[level](h)

    h_flat = h.reshape(batch, -1)
    patches = bn.proj_in(h_flat).reshape(batch, bn.n_patches, bn.patch_dim)
    patches_n = F.normalize(patches, dim=-1)
    return patches_n, bn.triangulate(patches_n), h_flat
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'─'*80}")
print("TEST 1: Drift & Anchor Diagnostics")
print(f"{'─'*80}")

# Reference drift angle (radians) that every check below compares against.
BINDING = 0.29154

with torch.no_grad():
    drift = bn.drift().detach()   # indexed as drift[p] below; presumably (P, A)
    home = F.normalize(bn.home, dim=-1).detach()
    curr = F.normalize(bn.anchors, dim=-1).detach()
P, A, d = home.shape              # patches, anchors per patch, anchor dim

print(f"  Drift: mean={drift.mean():.6f} rad ({math.degrees(drift.mean().item()):.2f}°)")
print(f"         max={drift.max():.6f} rad ({math.degrees(drift.max().item()):.2f}°)")
print(f"  Near 0.29154: {(drift - BINDING).abs().lt(0.05).float().mean().item():.1%}")
print(f"  Near 0.29154 (±0.03): {(drift - BINDING).abs().lt(0.03).float().mean().item():.1%}")

all_d = drift.flatten().cpu().numpy()
print(f"\n  Drift distribution ({len(all_d)} anchors):")
# The last edge is +inf, so drifts above 0.50 land in an overflow bin instead
# of being silently dropped by np.histogram (which ignores out-of-range values).
bins = [0.0, 0.05, 0.10, 0.15, 0.20, 0.25, BINDING, 0.35, 0.40, 0.50, np.inf]
hist, _ = np.histogram(all_d, bins=bins)
for i in range(len(bins) - 1):
    # One block per two anchors, plus one so non-empty bins are always visible.
    bar = "█" * (hist[i] // 2 + (1 if hist[i] > 0 else 0))
    label = "  ← BINDING" if bins[i + 1] == BINDING else ""
    print(f"    {bins[i]:.3f}-{bins[i+1]:.3f}: {hist[i]:3d} {bar}{label}")

print(f"\n  Per-patch drift summary:")
for p in range(P):
    d_mean = drift[p].mean().item()
    d_max = drift[p].max().item()
    n_near = (drift[p] - BINDING).abs().lt(0.05).sum().item()
    flags = []
    if abs(d_mean - BINDING) < 0.05:
        flags.append("MEAN≈0.29")
    if abs(d_max - BINDING) < 0.05:
        flags.append("MAX≈0.29")
    if d_max > BINDING:
        flags.append("CROSSED")
    flag_str = "  ← " + ", ".join(flags) if flags else ""
    print(f"    P{p:2d}: mean={d_mean:.4f} max={d_max:.4f} near={n_near}/{A}{flag_str}")

# Participation ratio of each patch's anchor set. Unlike eff_dim, the anchors
# are not centred before the SVD.
print(f"\n  Anchor effective dimensionality:")
for p in range(P):
    _, S, _ = torch.linalg.svd(curr[p].float(), full_matrices=False)
    pr = S / S.sum()
    ed = pr.pow(2).sum().reciprocal().item()
    print(f"    P{p:2d}: {ed:.1f} / {A}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'─'*80}")
print("TEST 2: Sphere Geometry — per-patch CV across timesteps")
print(f"{'─'*80}")

# Fixed evaluation batch (first test batch); TESTs 4 and 6 reuse it.
images_t, labels_t = next(iter(test_loader))
images_t, labels_t = images_t.to(DEVICE), labels_t.to(DEVICE)

# Per-patch simplex-volume CV of clean (t = 0) embeddings.
patches_n, tri, _ = get_sphere_embeddings(images_t, labels_t, 0.0)
print(f"\n  Per-patch CV at t=0.0 (natural S^15 = 0.20):")
for p in range(patches_n.shape[1]):
    cv_p = compute_cv(patches_n[:, p, :], 1000)
    print(f"    P{p:2d}: CV={cv_p:.4f}")

# Whole-embedding geometry as the noise level grows; each row draws fresh noise.
print(f"\n  {'t':>6} {'CV_sphere':>10} {'CV_tri':>10} {'eff_d_sph':>10} {'eff_d_tri':>10}")
for t_val in [0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0]:
    pn, tr, _ = get_sphere_embeddings(images_t, labels_t, t_val)
    sph_flat = pn.reshape(pn.shape[0], -1)
    cv_s = compute_cv(sph_flat, 1000)
    cv_t = compute_cv(tr, 1000)
    ed_s = eff_dim(sph_flat)
    ed_t = eff_dim(tr)
    print(f"  {t_val:>6.2f} {cv_s:>10.4f} {cv_t:>10.4f} {ed_s:>10.1f} {ed_t:>10.1f}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'─'*80}")
print("TEST 3: Per-Class Anchor Routing")
print(f"{'─'*80}")

class_nearest = {c: [] for c in range(10)}
anchors_n = F.normalize(bn.anchors.detach(), dim=-1)

# For each test image, record the index of the nearest (max-cosine) anchor
# in every patch, grouped by class. Stops after the first batch that pushes
# the total past 8000 images.
n_seen = 0
for imgs_b, labs_b in test_loader:
    imgs_b, labs_b = imgs_b.to(DEVICE), labs_b.to(DEVICE)
    pn, _, _ = get_sphere_embeddings(imgs_b, labs_b, 0.0)
    cos = torch.einsum('bpd,pad->bpa', pn, anchors_n)
    nearest = cos.argmax(dim=-1).cpu()            # (B, P) anchor indices
    # One host transfer per batch instead of a GPU sync (.item()) per image.
    for lab, row in zip(labs_b.tolist(), nearest):
        class_nearest[lab].append(row)
    n_seen += imgs_b.shape[0]
    if n_seen > 8000:
        break

# Stack each class once; reused by every patch table and the utilization count.
stacked = {c: torch.stack(v) for c, v in class_nearest.items() if v}

# Routing tables for the first (up to) four patches: share of each class
# routed to each anchor, plus the routing entropy in nats.
for p_idx in range(min(4, P)):
    print(f"\n  Patch {p_idx}:")
    print(f"    {'class':>8}", end="")
    for a in range(A):
        print(f" {a:>4}", end="")
    print("  entropy")

    for c, nearest_all in stacked.items():
        counts = torch.bincount(nearest_all[:, p_idx], minlength=A).float()
        counts = counts / counts.sum()
        entropy = -(counts * (counts + 1e-8).log()).sum().item()

        row = f"    {CLASS_NAMES[c]:>8}"
        for a in range(A):
            pct = counts[a].item()
            if pct > 0.15:
                row += f" {pct:>3.0%}█"
            elif pct > 0.08:
                row += f" {pct:>3.0%}▒"
            elif pct > 0.02:
                row += f" {pct:>3.0%} "
            else:
                row += "    ."             # 5 chars wide, like every other cell
        row += f"  {entropy:.2f}"
        print(row)

all_nearest = torch.cat(list(stacked.values()))
unique_per_patch = [all_nearest[:, p_idx].unique().numel() for p_idx in range(P)]
print(f"\n  Unique anchors per patch: {unique_per_patch}")
print(f"  Mean utilization: {np.mean(unique_per_patch):.1f}/{A} "
      f"({100*np.mean(unique_per_patch)/A:.0f}%)")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'─'*80}")
print("TEST 4: Reconstruction Fidelity — what survives 768 dims?")
print(f"{'─'*80}")

print(f"  {'t':>6} {'input_norm':>12} {'output_norm':>12} {'cos_sim':>10} "
      f"{'rel_error':>10} {'mse':>10}")

B = images_t.shape[0]
for t_val in [0.0, 0.25, 0.5, 0.75, 1.0]:
    # Everything, including the conditioning embedding, runs without autograd;
    # building cond outside no_grad needlessly recorded a graph.
    with torch.no_grad():
        t = torch.full((B,), t_val, device=DEVICE)
        eps = torch.randn_like(images_t)
        t_b = t.view(B, 1, 1, 1)
        x_t = (1 - t_b) * images_t + t_b * eps
        cond = model.time_emb(t) + model.class_emb(labels_t)

        # Same encoder pass as get_sphere_embeddings.
        h = model.in_conv(x_t)
        for i in range(len(model.ch_mults)):
            for block in model.enc[i]:
                if isinstance(block, nn.Sequential):
                    h = block[0](h)
                    h = block[1](h, cond)
                else:
                    h = block(h, cond)
            if i < len(model.enc_down):
                h = model.enc_down[i](h)

        # Round-trip the flattened encoder features through the bottleneck.
        h_flat = h.reshape(B, -1)
        h_reconstructed = bn(h_flat, cond)

    in_norm = h_flat.norm(dim=-1).mean().item()
    out_norm = h_reconstructed.norm(dim=-1).mean().item()
    cos = F.cosine_similarity(h_flat, h_reconstructed).mean().item()
    # Mean residual norm relative to the mean input norm (ratio of means).
    rel_err = (h_flat - h_reconstructed).norm(dim=-1).mean().item() / (in_norm + 1e-8)
    mse = F.mse_loss(h_flat, h_reconstructed).item()

    print(f"  {t_val:>6.2f} {in_norm:>12.2f} {out_norm:>12.2f} {cos:>10.6f} "
          f"{rel_err:>10.4f} {mse:>10.2f}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'─'*80}")
print("TEST 5: Generation Quality — per class")
print(f"{'─'*80}")

print(f"  {'class':>8} {'intra_cos':>10} {'std':>8} {'CV':>8} {'norm':>8}")

N_GEN = 64  # samples per class; TEST 8 reads all_gen
all_gen = []
# Off-diagonal mask for the pairwise-similarity matrix (same for every class).
mask = ~torch.eye(N_GEN, device=DEVICE, dtype=torch.bool)
for c in range(10):
    with torch.no_grad():
        # `sample` comes from the training session; assumed signature
        # (model, n, steps, cls=...) -> (images, extra) -- confirm.
        imgs, _ = sample(model, N_GEN, 50, cls=c)
    # Assumes samples are in the [-1, 1] training range; map to [0, 1].
    imgs = (imgs + 1) / 2
    all_gen.append(imgs)

    flat = imgs.reshape(N_GEN, -1)
    flat_n = F.normalize(flat, dim=-1)
    off = (flat_n @ flat_n.T)[mask]     # pairwise cosines, self-pairs excluded
    print(f"  {CLASS_NAMES[c]:>8} {off.mean().item():>10.4f} "
          f"{off.std().item():>8.4f} {compute_cv(flat, 500):>8.4f} "
          f"{flat.norm(dim=-1).mean().item():>8.2f}")

    save_image(make_grid(imgs[:16], nrow=4), f"analysis_cd/class_{CLASS_NAMES[c]}.png")

# Overview grid: 4 samples per class, one class per column block.
all_grid = torch.cat([g[:4] for g in all_gen])
save_image(make_grid(all_grid, nrow=10), "analysis_cd/all_classes.png")
print(f"  ✓ Saved to analysis_cd/")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'─'*80}")
print("TEST 6: Velocity Field Quality")
print(f"{'─'*80}")

print(f"  {'t':>6} {'v_norm':>10} {'v·target':>10} {'mse':>10}")

# Fixed sub-batch of the evaluation images; does not depend on t, so hoisted.
B = min(128, images_t.shape[0])
imgs_v = images_t[:B]
labs_v = labels_t[:B]

for t_val in [0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95]:
    t = torch.full((B,), t_val, device=DEVICE)
    eps = torch.randn_like(imgs_v)
    t_b = t.view(B, 1, 1, 1)
    # Linear path x_t = (1 - t) * x0 + t * eps, so d x_t / dt = eps - x0.
    x_t = (1 - t_b) * imgs_v + t_b * eps
    v_target = eps - imgs_v

    with torch.no_grad():
        v_pred = model(x_t, t, labs_v)
    # 'v·target' column: mean per-sample cosine between prediction and target.
    v_cos = F.cosine_similarity(
        v_pred.reshape(B, -1), v_target.reshape(B, -1)).mean().item()
    mse = F.mse_loss(v_pred, v_target).item()
    v_norm = v_pred.reshape(B, -1).norm(dim=-1).mean().item()
    print(f"  {t_val:>6.2f} {v_norm:>10.2f} {v_cos:>10.4f} {mse:>10.4f}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'─'*80}")
print("TEST 7: ODE Trajectory — geometry through generation")
print(f"{'─'*80}")

B_traj = 256
N_STEPS = 50
x = torch.randn(B_traj, 3, 32, 32, device=DEVICE)
labs_traj = torch.randint(0, 10, (B_traj,), device=DEVICE)
dt = 1.0 / N_STEPS

# Euler integration of dx/dt = v from t = 1 (noise) down to t = 0.
# Each printed row describes x AFTER the step, i.e. at t = 1 - (step + 1) * dt,
# so the last row is the finished sample at t = 0.
print(f"  {'step':>6} {'t':>6} {'norm':>10} {'std':>10} {'CV':>8}")
report_steps = {0, 1, 5, 10, 20, 30, 40, N_STEPS - 1}
for step in range(N_STEPS):
    t = torch.full((B_traj,), 1.0 - step * dt, device=DEVICE)
    # Autocast on whichever device x lives on (was hard-coded to "cuda").
    with torch.no_grad(), torch.amp.autocast(x.device.type, dtype=torch.bfloat16):
        v = model(x, t, labs_traj)
    x = x - v.float() * dt
    if step in report_steps:
        t_after = 1.0 - (step + 1) * dt
        xf = x.reshape(B_traj, -1)
        print(f"  {step:>6} {t_after:>6.2f} {xf.norm(dim=-1).mean().item():>10.2f} "
              f"{x.std().item():>10.4f} {compute_cv(xf, 500):>8.4f}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'─'*80}")
print("TEST 8: Class Separation")
print(f"{'─'*80}")

# Flattened, L2-normalised generated images per class (from TEST 5), computed
# once instead of once per class pair.
feats = [F.normalize(g.reshape(g.shape[0], -1), dim=-1) for g in all_gen]

intra, inter = [], []
for f in feats:
    s = f @ f.T
    m = ~torch.eye(f.shape[0], device=f.device, dtype=torch.bool)
    intra.append(s[m].mean().item())            # mean off-diagonal cosine

for i, fi in enumerate(feats):
    for fj in feats[i + 1:]:
        inter.append((fi @ fj.T).mean().item())  # mean cross-class cosine

print(f"  Intra-class cos: {np.mean(intra):.4f} ± {np.std(intra):.4f}")
print(f"  Inter-class cos: {np.mean(inter):.4f} ± {np.std(inter):.4f}")
print(f"  Separation ratio: {np.mean(intra) / (np.mean(inter) + 1e-8):.3f}×")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'─'*80}")
print("TEST 9: Comparison Summary")
print(f"{'─'*80}")

# `best_loss` comes from the training session; print "n/a" rather than
# crashing at the very end if it is not defined.
best_val = globals().get("best_loss")
best_str = f"{best_val:.4f}" if best_val is not None else "n/a"

# v1/v2 columns are fixed numbers from earlier runs; only v3 is computed here.
print(f"""
  {'':>25} {'Regulator':>12} {'Skip BN':>12} {'Pure BN':>12}
  {'':>25} {'(v1)':>12} {'(v2)':>12} {'(v3)':>12}
  {'─'*73}
  {'Relay/BN params':>25} {'76K':>12} {'281M':>12} {f'{n_bn:,}':>12}
  {'Total params':>25} {'6.1M':>12} {'287M':>12} {f'{n_params:,}':>12}
  {'Best loss':>25} {'0.1900':>12} {'0.1757':>12} {best_str:>12}
  {'Constellation signal':>25} {'6%':>12} {'88%':>12} {'100%':>12}
  {'Skip params':>25} {'0':>12} {'268M':>12} {'0':>12}
  {'Anchor routing':>25} {'2 active':>12} {'class-spec':>12} {'(see T3)':>12}
""")

# Reference drift angle (radians) used by the drift checks.
BINDING = 0.29154

with torch.no_grad():
    drift = bn.drift().detach()
near = (drift - BINDING).abs().lt(0.05).float().mean().item()
near_tight = (drift - BINDING).abs().lt(0.03).float().mean().item()
crossed = (drift > BINDING).float().mean().item()

print(f"  Final drift stats:")
print(f"    Mean: {drift.mean():.6f} rad ({math.degrees(drift.mean().item()):.2f}°)")
print(f"    Max:  {drift.max():.6f} rad ({math.degrees(drift.max().item()):.2f}°)")
print(f"    Near 0.29154: {near:.1%} (±0.05)  {near_tight:.1%} (±0.03)")
print(f"    Crossed 0.29: {crossed:.1%}")
|
|
|
|
# Closing banner: blank line, rule, title, rule.
print("\n" + "=" * 80, "ANALYSIS COMPLETE", "=" * 80, sep="\n")