| """ |
Cell 4 - Theorem 3: Release Fidelity
=======================================
The light speeding back up after leaving the lens.

Full encode -> SVD -> decode round-trip reconstruction analysis.
NOT the SVD-only residual (which is ~1e-12).
The FULL decoder reconstruction - where the model chooses
what to preserve and what to lose.

Questions:
1. Does per-patch reconstruction error vary spatially?
2. Does it differ across classes?
3. Per-mode reconstruction: which modes matter for which patches?
4. Does the release residual map classify better than friction?
5. Combined release + friction + eigenvalues - full conduit test
6. Where does the model FAIL to reconstruct? Those are the boundaries.
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from tqdm import tqdm |
|
|
# Prefer the GPU, but fall back to CPU so the cell still runs on hosts
# without CUDA instead of failing at the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("Loading Freckles v40 + CIFAR-10...")
from geolip_svae import load_model
from geolip_svae.model import extract_patches, stitch_patches
import torchvision
import torchvision.transforms as T

freckles, cfg = load_model(hf_version='v40_freckles_noise', device=device)
freckles.eval()

# Side length the CIFAR images are resized to; the patch grid below is
# derived from the same value so the two cannot drift apart.
IMG_SIZE = 64

ps = freckles.patch_size
transform = T.Compose([T.Resize(IMG_SIZE), T.ToTensor()])
cifar_test = torchvision.datasets.CIFAR10(
    root='/content/data', train=False, download=True, transform=transform)
loader = torch.utils.data.DataLoader(
    cifar_test, batch_size=64, shuffle=False, num_workers=4)

CLASSES = ['airplane', 'auto', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# Patch-grid height/width and total number of patches per image.
gh, gw = IMG_SIZE // ps, IMG_SIZE // ps
n_patches = gh * gw
|
|
|
|
| |
| |
| |
|
|
print("\n" + "=" * 70)
print(" 1. FULL ROUND-TRIP β Per-patch reconstruction error")
print("=" * 70)

print("\nCollecting per-patch reconstruction errors...\n")

# Per-class running sums over the (gh, gw) patch grid: sum of the error,
# sum of the squared error, and number of images seen per class.
class_error_sum = torch.zeros(10, gh, gw)
class_error_sq = torch.zeros(10, gh, gw)
class_counts = torch.zeros(10)

# Individual (map, label) pairs for the first `max_collect` images.
all_error_maps = []
all_s_maps = []
max_collect = 2000
n_collected = 0

# NOTE(review): indentation was lost in this copy of the file; the extent of
# the no_grad block is reconstructed - confirm against the notebook.
for images, labels in tqdm(loader, desc="Reconstructing"):
    with torch.no_grad():
        images_gpu = images.to(device)
        out = freckles(images_gpu)
        recon = out['recon']

        B = images_gpu.shape[0]
        S = out['svd']['S']

        # Compare input and reconstruction patch by patch.
        input_patches, _, _ = extract_patches(images_gpu, ps)
        recon_patches, _, _ = extract_patches(recon, ps)

        # Mean squared error over the last axis of each patch.
        patch_mse = (input_patches - recon_patches).pow(2).mean(dim=-1)

    # Lay the per-patch values back out on the patch grid.
    error_map = patch_mse.reshape(B, gh, gw)
    s_map = S.reshape(B, gh, gw, -1)

    error_cpu = error_map.cpu()
    s_cpu = s_map.cpu()

    # Accumulate the class statistics for the whole batch at once instead of
    # one Python-level iteration per image.
    err_acc = error_cpu.to(class_error_sum.dtype)
    class_error_sum.index_add_(0, labels, err_acc)
    class_error_sq.index_add_(0, labels, err_acc.pow(2))
    class_counts += torch.bincount(labels, minlength=10).to(class_counts.dtype)

    # Keep individual maps only until `max_collect` images have been stored.
    take = min(B, max_collect - n_collected)
    for i in range(take):
        c = labels[i].item()
        all_error_maps.append((error_cpu[i], c))
        all_s_maps.append((s_cpu[i], c))
    n_collected += take

print(f"Collected {int(class_counts.sum().item())} images, "
      f"{n_collected} individual maps\n")
|
|
|
|
| |
| |
| |
|
|
print("=" * 70)
print(" 1a. SPATIAL STRUCTURE β Does recon error vary across patches?")
print("=" * 70)

# Coefficient of variation (std / mean) of the patch errors within each
# collected image; the small epsilon guards against a zero mean.
per_image_cv = [
    (values.std() / (values.mean() + 1e-10)).item()
    for values in (emap.reshape(-1) for emap, _ in all_error_maps)
]

cv_arr = np.array(per_image_cv)
spatial_verdict = ('HAS SPATIAL STRUCTURE' if cv_arr.mean() > 0.1
                   else 'SPATIALLY UNIFORM')

print("\n Per-image spatial CV of reconstruction error:")
print(f" Mean CV: {cv_arr.mean():.4f}")
print(f" Median CV: {np.median(cv_arr):.4f}")
print(f" Min CV: {cv_arr.min():.4f}")
print(f" Max CV: {cv_arr.max():.4f}")
print(f" VERDICT: {spatial_verdict}")
|
|
|
|
| |
| |
| |
|
|
print("\n" + "=" * 70)
print(" 1b. PER-CLASS RECONSTRUCTION ERROR")
print("=" * 70)

# Per-class mean error map and (biased) variance map, E[x^2] - E[x]^2.
# Counts are clamped to 1 so a class with no images divides by one, not zero.
images_per_class = class_counts.clamp(min=1)[:, None, None]
class_means = class_error_sum / images_per_class
class_vars = class_error_sq / images_per_class - class_means.pow(2)

print(f"\n {'Class':<10s} {'Mean MSE':>10s} {'Std MSE':>10s} {'Max patch':>10s}")
print(" " + "-" * 42)
# The mean, std and max below are taken across grid positions of the
# class-mean map.
for class_name, mean_map in zip(CLASSES, class_means):
    print(f" {class_name:<10s} {mean_map.mean():10.6f} "
          f"{mean_map.std():10.6f} {mean_map.max():10.6f}")

# Cosine similarity between the flattened class-mean maps, looking only at
# the off-diagonal (class A vs. class B) entries.
class_flat = class_means.reshape(10, -1)
class_flat_norm = F.normalize(class_flat, dim=-1)
cos_sim = class_flat_norm @ class_flat_norm.T
inter_mask = ~torch.eye(10, dtype=torch.bool)
between_classes = cos_sim[inter_mask]
pattern_verdict = ('DISTINCT PATTERNS' if between_classes.min() < 0.99
                   else 'SIMILAR PATTERNS')
print(f"\n Mean inter-class cosine similarity: {between_classes.mean():.6f}")
print(f" Min inter-class cosine similarity: {between_classes.min():.6f}")
print(f" VERDICT: {pattern_verdict}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'=' * 70}")
print(" 2. CENTER vs EDGE β Where does reconstruction fail?")
print("=" * 70)

# Region bounds are expressed as a quarter of the grid rather than literal
# indices, so the masks follow `gh`/`gw`. On a 16x16 grid this is exactly
# the previous 4:12 center and 4x4 corners.
qh, qw = gh // 4, gw // 4

# Center: the middle half of the grid in each direction; edge: the rest.
center_mask = torch.zeros(gh, gw, dtype=torch.bool)
center_mask[qh:gh - qh, qw:gw - qw] = True
edge_mask = ~center_mask

# Corners: the four quarter-sized blocks at the grid corners.
corner_mask = torch.zeros(gh, gw, dtype=torch.bool)
corner_mask[:qh, :qw] = True
corner_mask[:qh, gw - qw:] = True
corner_mask[gh - qh:, :qw] = True
corner_mask[gh - qh:, gw - qw:] = True

print(f"\n {'Class':<10s} {'Center':>8s} {'Edge':>8s} {'Corner':>8s} {'E/C ratio':>10s}")
print(f" {'-' * 48}")
for c in range(10):
    m = class_means[c]
    center = m[center_mask].mean().item()
    edge = m[edge_mask].mean().item()
    corner = m[corner_mask].mean().item()
    # Edge-to-center ratio; epsilon avoids dividing by an exactly-zero center.
    ratio = edge / (center + 1e-10)
    print(f" {CLASSES[c]:<10s} {center:8.6f} {edge:8.6f} {corner:8.6f} {ratio:10.4f}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'=' * 70}")
print(" 3. PER-MODE RECONSTRUCTION β Ablating SVD modes")
print("=" * 70)

print("\nReconstructing with individual modes...")

# Work on the first `n_ablate` test images only.
n_ablate = 256
subset = torch.utils.data.Subset(cifar_test, range(n_ablate))
ablate_loader = torch.utils.data.DataLoader(subset, batch_size=64)

# mode index -> list of per-batch energy-fraction tensors. Keys are created
# on demand so the number of modes follows the model's S instead of being
# fixed at 4.
mode_errors = {}
mode_labels = []
full_errors = []

# NOTE(review): indentation was lost in this copy of the file; the extent of
# the no_grad block is reconstructed - confirm against the notebook.
for images, labels in ablate_loader:
    with torch.no_grad():
        images_gpu = images.to(device)
        out = freckles(images_gpu)

        S = out['svd']['S']
        D = S.shape[-1]

        # Full round-trip error per patch, for the last column of the table.
        recon = out['recon']
        input_p, _, _ = extract_patches(images_gpu, ps)
        recon_p, _, _ = extract_patches(recon, ps)
        full_err = (input_p - recon_p).pow(2).mean(dim=-1)
        full_errors.append(full_err.cpu())

        # The only per-mode quantity used downstream is the energy fraction
        # S_k^2 / sum_j S_j^2. The single-mode product of U, S and Vt that
        # used to be formed here was never read, so it is no longer computed.
        total_energy = S.pow(2).sum(dim=-1) + 1e-10
        for k in range(D):
            mode_energy = S[:, :, k].pow(2) / total_energy
            mode_errors.setdefault(k, []).append(mode_energy.cpu())

    mode_labels.append(labels)

mode_labels = torch.cat(mode_labels)
full_errors = torch.cat(full_errors)

# Concatenate each mode's batches once, not once per class row.
mode_energy_all = {k: torch.cat(chunks) for k, chunks in mode_errors.items()}
n_modes = len(mode_energy_all)

print(f"\n Per-mode energy fraction (how much each mode contributes):")
print(f"\n {'Class':<10s}", end="")
for k in range(n_modes):
    print(f" {'Mode'+str(k):>8s}", end="")
print(f" {'FullMSE':>10s}")
print(f" {'-' * 50}")

for c in range(10):
    mask = mode_labels == c
    if mask.sum() == 0:
        # No image of this class in the ablation subset.
        continue
    print(f" {CLASSES[c]:<10s}", end="")
    for k in range(n_modes):
        energy = mode_energy_all[k][mask].mean().item()
        print(f" {energy:8.4f}", end="")
    ferr = full_errors[mask].mean().item()
    print(f" {ferr:10.6f}")
|
|
|
|
| |
| |
| |
|
|
print("\n" + "=" * 70)
print(" 4. LINEAR PROBE β Reconstruction error maps as features")
print("=" * 70)

# One flattened error map per collected image, with its class label.
error_features = [emap.reshape(-1) for emap, _ in all_error_maps]
error_labels = [lbl for _, lbl in all_error_maps]

X_err = torch.stack(error_features)
y_err = torch.tensor(error_labels)

# Random 80/20 train/test split (unseeded, so it differs between runs).
n_classes = 10
N = len(y_err)
perm = torch.randperm(N)
n_train = int(0.8 * N)
|
|
def ridge_probe(X, y, perm, n_train, name, lam=1.0,
                num_classes=None, class_names=None):
    """Fit a ridge-regression classifier on one-hot targets and report accuracy.

    Rows ``perm[:n_train]`` are used for training and ``perm[n_train:]`` for
    testing. Features are standardised with the training mean and standard
    deviation, the weights solve ``(X^T X + lam * I) W = X^T Y``, and the
    prediction is the argmax over the class columns. One summary line and a
    per-class test-accuracy table are printed.

    Args:
        X: 2-D feature tensor, one row per sample.
        y: 1-D integer label tensor aligned with the rows of ``X``.
        perm: Index tensor giving the sample order used for the split.
        n_train: Number of leading entries of ``perm`` used for training.
        name: Label printed on the summary line.
        lam: Ridge regularisation strength.
        num_classes: Number of classes; defaults to the module-level
            ``n_classes``.
        class_names: Display names indexed by class id; defaults to the
            module-level ``CLASSES``.

    Returns:
        Test accuracy as a Python float.
    """
    if num_classes is None:
        num_classes = n_classes
    if class_names is None:
        class_names = CLASSES

    train_idx, test_idx = perm[:n_train], perm[n_train:]
    X_tr, y_tr = X[train_idx], y[train_idx]
    X_te, y_te = X[test_idx], y[test_idx]

    # Standardise with training statistics only; the clamp keeps constant
    # features from dividing by zero.
    m = X_tr.mean(0)
    s = X_tr.std(0).clamp(min=1e-8)
    X_tr_n = (X_tr - m) / s
    X_te_n = (X_te - m) / s

    Y_oh = torch.zeros(len(y_tr), num_classes)
    Y_oh.scatter_(1, y_tr.unsqueeze(1), 1.0)

    gram = X_tr_n.T @ X_tr_n + lam * torch.eye(X_tr_n.shape[1])
    W = torch.linalg.solve(gram, X_tr_n.T @ Y_oh)

    train_acc = ((X_tr_n @ W).argmax(1) == y_tr).float().mean().item()
    # Test predictions are computed once and reused for the per-class table.
    preds = (X_te_n @ W).argmax(1)
    test_acc = (preds == y_te).float().mean().item()
    print(f" {name:<40s} dims={X.shape[1]:>5d} "
          f"train={train_acc:.1%} test={test_acc:.1%}")

    header_printed = False
    for c in range(num_classes):
        cm = y_te == c
        if cm.sum() == 0:
            # Class absent from the test split: nothing to report.
            continue
        if not header_printed:
            # Printed before the first reported class, so the table keeps its
            # header even when class 0 has no test samples.
            print(f" {'Class':<10s} {'Acc':>6s}")
            print(f" {'-' * 18}")
            header_printed = True
        acc = (preds[cm] == y_te[cm]).float().mean().item()
        # NOTE(review): the bar glyph looks mis-encoded in this copy of the
        # file (presumably a block character) - confirm.
        bar = 'β' * int(acc * 20)
        print(f" {class_names[c]:<10s} {acc:5.1%} {bar}")
    return test_acc
|
|
# Probe using only the flattened reconstruction-error maps as features.
print("\n Ridge probe comparison:\n")
acc_err = ridge_probe(
    X_err, y_err, perm, n_train, name="Recon error spatial map")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'=' * 70}")
print(" 5. FULL CONDUIT β Release error + eigenvalues + friction")
print("=" * 70)

from geolip_core.linalg.conduit import FLEighConduit
conduit = FLEighConduit().to(device)

# One feature row per image: [patch errors | singular values | friction].
combined_features = []
combined_labels = []
n_collected2 = 0

# NOTE(review): indentation was lost in this copy of the file; the tensor work
# below is assumed to run under no_grad - confirm against the notebook.
for images, labels in tqdm(loader, desc="Full conduit"):
    if n_collected2 >= max_collect:
        break
    with torch.no_grad():
        images_gpu = images.to(device)
        out = freckles(images_gpu)
        recon = out['recon']
        S = out['svd']['S']
        Vt = out['svd']['Vt']
        B_img, N, D = S.shape

        # Release signal: per-patch reconstruction MSE.
        input_p, _, _ = extract_patches(images_gpu, ps)
        recon_p, _, _ = extract_patches(recon, ps)
        patch_mse = (input_p - recon_p).pow(2).mean(dim=-1)

        # Per-patch D x D matrix V diag(S^2) V^T, handed to the conduit in
        # one flat batch; its `friction` output is reshaped back per patch.
        S2 = S.pow(2)
        G = torch.einsum('bnij,bnj,bnjk->bnik',
                         Vt.transpose(-2, -1), S2, Vt)
        G_flat = G.reshape(B_img * N, D, D)
        packet = conduit(G_flat)
        fric = packet.friction.reshape(B_img, N, D)

    # Flatten each signal to one row per image.
    err_flat = patch_mse.reshape(B_img, gh * gw)
    s_flat = S.reshape(B_img, gh * gw * D)
    f_flat = fric.reshape(B_img, gh * gw * D)

    # Build the rows on the device and copy the kept ones to the host in one
    # transfer per batch instead of three per sample.
    take = min(B_img, max_collect - n_collected2)
    batch_rows = torch.cat([err_flat, s_flat, f_flat], dim=1)[:take].cpu()
    combined_features.extend(batch_rows.unbind(0))
    combined_labels.extend(labels[:take].tolist())
    n_collected2 += take

X_full = torch.stack(combined_features)
y_full = torch.tensor(combined_labels)
|
|
# Random 80/20 split shared by every probe in this section (unseeded).
perm2 = torch.randperm(len(y_full))
n_train2 = int(0.8 * len(y_full))

print(f"\n Comparative linear probes:\n")

# Each row of X_full is [errors | singular values | friction], with gh * gw
# error values followed by two equally wide blocks. The offsets are derived
# from the grid and the row width instead of the literals 256 and 256 * 4.
n_err = gh * gw
modes_per_patch = (X_full.shape[1] - n_err) // (2 * n_err)
s_end = n_err + n_err * modes_per_patch
X_err_only = X_full[:, :n_err]
X_s_only = X_full[:, n_err:s_end]
X_f_only = X_full[:, s_end:]

acc_err2 = ridge_probe(X_err_only, y_full, perm2, n_train2,
                       "Release error only")
print()
acc_s2 = ridge_probe(X_s_only, y_full, perm2, n_train2,
                     "Eigenvalues (S) only")
print()
acc_f2 = ridge_probe(X_f_only, y_full, perm2, n_train2,
                     "Friction only")

# Feature combinations, evaluated on the same split.
print(f"\n Combinations:\n")
X_err_s = torch.cat([X_err_only, X_s_only], dim=-1)
acc_err_s = ridge_probe(X_err_s, y_full, perm2, n_train2,
                        "Release + Eigenvalues")

X_err_f = torch.cat([X_err_only, X_f_only], dim=-1)
acc_err_f = ridge_probe(X_err_f, y_full, perm2, n_train2,
                        "Release + Friction")

acc_all = ridge_probe(X_full, y_full, perm2, n_train2,
                      "Release + Eigenvalues + Friction")
|
|
|
|
| |
| |
| |
|
|
print("\n" + "=" * 70)
print(" 6. HIGH-ERROR PATCHES β Where does reconstruction fail?")
print("=" * 70)

print("\n Top error positions per class (patch coordinates):")
print(f" {'Class':<10s} {'Top 3 positions (row, col)':>40s} {'Error ratio':>12s}")
print(" " + "-" * 64)

for class_name, mean_map in zip(CLASSES, class_means):
    scores = mean_map.reshape(-1)
    # Flat indices of the three largest class-mean errors, worst first.
    worst = scores.argsort(descending=True)[:3].tolist()
    # Flat index -> (row, col) on the patch grid.
    coords = [divmod(flat_idx, gw) for flat_idx in worst]
    peak_err = scores[worst[0]].item()
    # How far the worst patch sits above the class's average patch error.
    ratio = peak_err / (mean_map.mean().item() + 1e-10)
    pos_str = ", ".join(f"({row},{col})" for row, col in coords)
    print(f" {class_name:<10s} {pos_str:>40s} {ratio:12.2f}x")

# Class-agnostic mean error map; a patch is "hot" when it lies more than two
# standard deviations above the mean of that map.
overall_error = class_error_sum.sum(0) / class_counts.sum()
overall_mean = overall_error.mean()
overall_std = overall_error.std()
hot_threshold = overall_mean + 2 * overall_std
hot_patches = (overall_error > hot_threshold).sum().item()
print("\n Overall error map:")
print(f" Mean: {overall_mean:.6f}")
print(f" Std: {overall_std:.6f}")
print(f" Hot patches (>2Ο): {hot_patches}/{gh * gw}")
|
|
|
|
| |
| |
| |
|
|
print("\n" + "=" * 70)
print(" THEOREM 3: RELEASE FIDELITY β SUMMARY")
print("=" * 70)

# Values pulled out of the template for readability. The lifts are in
# percentage points relative to the eigenvalue-only probe of section 5.
mean_spatial_cv = cv_arr.mean()
release_lift_pp = (acc_err2 - acc_s2) * 100
conduit_lift_pp = (acc_all - acc_s2) * 100

summary_text = f"""
SPATIAL STRUCTURE:
Recon error spatial CV: {mean_spatial_cv:.4f}
(Friction spatial CV was: 0.0137)

CLASSIFICATION (ridge probe, test accuracy):
Chance: 10.0%
Friction maps: 24.3% (from Cell 3)
Eigenvalue (S) maps: 21.0% (from Cell 3)
Release error maps: {acc_err:.1%}
Release + Eigenvalues: {acc_err_s:.1%}
Release + Friction: {acc_err_f:.1%}
FULL CONDUIT (all three): {acc_all:.1%}

THE QUESTION ANSWERED:
Does the release signal carry class-discriminative information
that eigenvalues and friction do not?
Lift from release over eigenvalues: {release_lift_pp:+.1f}pp
Lift from full conduit over eigenvalues: {conduit_lift_pp:+.1f}pp
"""
print(summary_text)