# geolip-conduit-experiments / cell_4_proper_experiment_3.py
# AbstractPhil's picture
# Create cell_4_proper_experiment_3.py
# 1d62b3b verified
"""
Cell 4 β€” Theorem 3: Release Fidelity
=======================================
The light speeding back up after leaving the lens.
Full encode→SVD→decode round-trip reconstruction analysis.
NOT the SVD-only residual (which is ~1e-12).
The FULL decoder reconstruction β€” where the model chooses
what to preserve and what to lose.
Questions:
1. Does per-patch reconstruction error vary spatially?
2. Does it differ across classes?
3. Per-mode reconstruction: which modes matter for which patches?
4. Does the release residual map classify better than friction?
5. Combined release + friction + eigenvalues β€” full conduit test
6. Where does the model FAIL to reconstruct? Those are the boundaries.
"""
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
# Use the GPU when one is present; otherwise fall back to the CPU so the
# script can still run (slowly) on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ═══════════════════════════════════════════════════════════════
# LOAD
# ═══════════════════════════════════════════════════════════════
print("Loading Freckles v40 + CIFAR-10...")

from geolip_svae import load_model
from geolip_svae.model import extract_patches, stitch_patches
import torchvision
import torchvision.transforms as T

freckles, cfg = load_model(hf_version='v40_freckles_noise', device=device)
freckles.eval()
ps = freckles.patch_size  # 4

# Side length the CIFAR images are resized to; the patch grid below is
# derived from the same value so the two cannot drift apart.
IMG_SIZE = 64
transform = T.Compose([T.Resize(IMG_SIZE), T.ToTensor()])
cifar_test = torchvision.datasets.CIFAR10(
    root='/content/data', train=False, download=True, transform=transform)
# shuffle=False keeps the image order identical on every pass over `loader`.
loader = torch.utils.data.DataLoader(
    cifar_test, batch_size=64, shuffle=False, num_workers=4)

CLASSES = ['airplane', 'auto', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# Patch-grid height/width and total patches per image.
gh, gw = IMG_SIZE // ps, IMG_SIZE // ps  # 16, 16
n_patches = gh * gw  # 256
# ═══════════════════════════════════════════════════════════════
# 1. FULL ROUND-TRIP RECONSTRUCTION — Per-patch error maps
# ═══════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print(" 1. FULL ROUND-TRIP — Per-patch reconstruction error")
print("=" * 70)
print("\nCollecting per-patch reconstruction errors...\n")

# Per-class running sums over the patch grid: sum of errors, sum of squared
# errors, and number of images seen for each class.
class_error_sum = torch.zeros(10, gh, gw)
class_error_sq = torch.zeros(10, gh, gw)
class_counts = torch.zeros(10)

# Individual maps for probing; only the first `max_collect` images are kept,
# while the per-class sums above cover every image in the loader.
all_error_maps = []  # (error_map, label)
all_s_maps = []  # (S_map, label)
max_collect = 2000
n_collected = 0

for images, labels in tqdm(loader, desc="Reconstructing"):
    with torch.no_grad():
        images_gpu = images.to(device)
        out = freckles(images_gpu)
        recon = out['recon']
        B = images_gpu.shape[0]
        S = out['svd']['S']  # (B, N, D)

        # Split input and reconstruction into patches the same way, then
        # compare them patch by patch.
        input_patches, _, _ = extract_patches(images_gpu, ps)
        recon_patches, _, _ = extract_patches(recon, ps)

        # Mean squared error over the last (per-patch feature) axis.
        patch_mse = (input_patches - recon_patches).pow(2).mean(dim=-1)

        # Lay the per-patch values out on the (gh, gw) grid.
        error_map = patch_mse.reshape(B, gh, gw)
        s_map = S.reshape(B, gh, gw, -1)

    error_cpu = error_map.cpu()
    s_cpu = s_map.cpu()

    # Convert the labels once per batch rather than calling .item() per image.
    for i, c in enumerate(labels.tolist()):
        class_error_sum[c] += error_cpu[i]
        class_error_sq[c] += error_cpu[i].pow(2)
        class_counts[c] += 1
        if n_collected < max_collect:
            all_error_maps.append((error_cpu[i], c))
            all_s_maps.append((s_cpu[i], c))
            n_collected += 1

print(f"Collected {int(class_counts.sum().item())} images, "
      f"{n_collected} individual maps\n")
# ═══════════════════════════════════════════════════════════════
# 1a. SPATIAL STRUCTURE OF RECONSTRUCTION ERROR
# ═══════════════════════════════════════════════════════════════
print("=" * 70)
print(" 1a. SPATIAL STRUCTURE — Does recon error vary across patches?")
print("=" * 70)

# Coefficient of variation (std / mean) of the patch errors within each
# collected image; the 1e-10 guards against a zero mean.
per_image_cv = []
for error_map, _ in all_error_maps:
    flat = error_map.reshape(-1)
    per_image_cv.append((flat.std() / (flat.mean() + 1e-10)).item())
cv_arr = np.array(per_image_cv)

print(f"\n Per-image spatial CV of reconstruction error:")
print(f" Mean CV: {cv_arr.mean():.4f}")
print(f" Median CV: {np.median(cv_arr):.4f}")
print(f" Min CV: {cv_arr.min():.4f}")
print(f" Max CV: {cv_arr.max():.4f}")
# A mean CV above 0.1 is the script's threshold for "has spatial structure".
print(f" VERDICT: {'HAS SPATIAL STRUCTURE' if cv_arr.mean() > 0.1 else 'SPATIALLY UNIFORM'}")
# ═══════════════════════════════════════════════════════════════
# 1b. PER-CLASS RECONSTRUCTION ERROR
# ═══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" 1b. PER-CLASS RECONSTRUCTION ERROR")
print("=" * 70)

# Turn the running sums into per-class mean maps and E[x^2] - E[x]^2
# variance maps. Counts are clamped to 1 so an unseen class divides by 1.
per_class_n = class_counts[:, None, None].clamp(min=1)
class_means = class_error_sum / per_class_n
class_vars = class_error_sq / per_class_n - class_means.pow(2)

print(f"\n {'Class':<10s} {'Mean MSE':>10s} {'Std MSE':>10s} {'Max patch':>10s}")
print(f" {'-' * 42}")
# The mean/std/max printed here are taken over the patches of each class's
# mean map.
for class_name, mean_map in zip(CLASSES, class_means):
    print(f" {class_name:<10s} {mean_map.mean():10.6f} {mean_map.std():10.6f} {mean_map.max():10.6f}")

# Inter-class similarity: cosine between flattened class-mean maps, with the
# diagonal (self-similarity) masked out.
class_flat = class_means.reshape(10, -1)
class_flat_norm = F.normalize(class_flat, dim=-1)
cos_sim = class_flat_norm @ class_flat_norm.T
inter_mask = ~torch.eye(10, dtype=torch.bool)
off_diagonal = cos_sim[inter_mask]

print(f"\n Mean inter-class cosine similarity: {off_diagonal.mean():.6f}")
print(f" Min inter-class cosine similarity: {off_diagonal.min():.6f}")
if off_diagonal.min() < 0.99:
    verdict = 'DISTINCT PATTERNS'
else:
    verdict = 'SIMILAR PATTERNS'
print(f" VERDICT: {verdict}")
# ═══════════════════════════════════════════════════════════════
# 2. CENTER vs EDGE RECONSTRUCTION
# ═══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" 2. CENTER vs EDGE — Where does reconstruction fail?")
print("=" * 70)

# Region bounds are a quarter of the grid on each side, so they follow the
# grid size instead of assuming 16x16 (where this gives the 4:12 window).
qh, qw = gh // 4, gw // 4

# Center = the middle half of the grid in each direction; edge = the rest.
center_mask = torch.zeros(gh, gw, dtype=torch.bool)
center_mask[qh:gh - qh, qw:gw - qw] = True
edge_mask = ~center_mask

# Corner masks for finer granularity (the four quarter-sized corner squares;
# these are a subset of the edge region).
corner_mask = torch.zeros(gh, gw, dtype=torch.bool)
corner_mask[:qh, :qw] = True
corner_mask[:qh, gw - qw:] = True
corner_mask[gh - qh:, :qw] = True
corner_mask[gh - qh:, gw - qw:] = True

print(f"\n {'Class':<10s} {'Center':>8s} {'Edge':>8s} {'Corner':>8s} {'E/C ratio':>10s}")
print(f" {'-' * 48}")
for c in range(10):
    m = class_means[c]
    center = m[center_mask].mean().item()
    edge = m[edge_mask].mean().item()
    corner = m[corner_mask].mean().item()
    # Edge-to-center error ratio; 1e-10 avoids division by zero.
    ratio = edge / (center + 1e-10)
    print(f" {CLASSES[c]:<10s} {center:8.6f} {edge:8.6f} {corner:8.6f} {ratio:10.4f}")
# ═══════════════════════════════════════════════════════════════
# 3. PER-MODE RECONSTRUCTION — Which modes carry class signal?
# ═══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" 3. PER-MODE RECONSTRUCTION — Ablating SVD modes")
print("=" * 70)
print("\nReconstructing with individual modes...")

# Work on the first n_ablate test images only.
n_ablate = 256
subset = torch.utils.data.Subset(cifar_test, range(n_ablate))
ablate_loader = torch.utils.data.DataLoader(subset, batch_size=64)

# mode index -> list of per-batch energy-fraction tensors. Keys are created
# from the observed number of modes D rather than a fixed 4.
mode_errors = {}
mode_labels = []
full_errors = []

for images, labels in ablate_loader:
    with torch.no_grad():
        images_gpu = images.to(device)
        out = freckles(images_gpu)
        S = out['svd']['S']  # (B, N, D)
        B_img, N, D = S.shape

        # Full reconstruction error per patch
        recon = out['recon']
        input_p, _, _ = extract_patches(images_gpu, ps)
        recon_p, _, _ = extract_patches(recon, ps)
        full_err = (input_p - recon_p).pow(2).mean(dim=-1)  # (B, N)
        full_errors.append(full_err.cpu())

        # Per-mode contribution. The original author notes the decoder is not
        # a plain U @ diag(S) @ Vt product, so instead of decoding a single
        # mode this measures the fraction of total squared singular-value
        # energy in each mode. (The unused einsum "decode" was removed.)
        S2 = S.pow(2)
        energy_frac = (S2 / (S2.sum(dim=-1, keepdim=True) + 1e-10)).cpu()
        for k in range(D):
            mode_errors.setdefault(k, []).append(energy_frac[:, :, k])
    mode_labels.append(labels)

mode_labels = torch.cat(mode_labels)
full_errors = torch.cat(full_errors)  # (N_img, N_patches)

# Concatenate each mode's batches once, not once per class in the print loop.
mode_energy = {k: torch.cat(chunks) for k, chunks in mode_errors.items()}
n_modes = len(mode_energy)

print(f"\n Per-mode energy fraction (how much each mode contributes):")
print(f"\n {'Class':<10s}", end="")
for k in range(n_modes):
    print(f" {'Mode'+str(k):>8s}", end="")
print(f" {'FullMSE':>10s}")
print(f" {'-' * 50}")
for c in range(10):
    mask = mode_labels == c
    # Skip classes that do not appear in the subset.
    if mask.sum() == 0:
        continue
    print(f" {CLASSES[c]:<10s}", end="")
    for k in range(n_modes):
        energy = mode_energy[k][mask].mean().item()
        print(f" {energy:8.4f}", end="")
    ferr = full_errors[mask].mean().item()
    print(f" {ferr:10.6f}")
# ═══════════════════════════════════════════════════════════════
# 4. RECONSTRUCTION ERROR AS CLASSIFIER
# ═══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" 4. LINEAR PROBE — Reconstruction error maps as features")
print("=" * 70)

# One feature vector per collected image: its per-patch error map, flattened.
error_features = [error_map.reshape(-1) for error_map, _ in all_error_maps]
error_labels = [label for _, label in all_error_maps]
X_err = torch.stack(error_features)  # (N, gh * gw)
y_err = torch.tensor(error_labels)

N = len(y_err)
# Seeded local generator: the 80/20 split (and so the reported accuracy) is
# the same on every run, without touching the global RNG state.
perm = torch.randperm(N, generator=torch.Generator().manual_seed(0))
n_train = int(0.8 * N)
n_classes = 10
def ridge_probe(X, y, perm, n_train, name, lam=1.0, *, num_classes=None,
                class_names=None):
    """Fit a closed-form ridge classifier and print/return its test accuracy.

    The rows of ``X`` are split by ``perm``: the first ``n_train`` permuted
    indices form the training set and the rest form the test set. Features
    are standardized with training-set statistics, one-hot targets are
    regressed with an L2 penalty ``lam``, and the predicted class is the
    argmax of the regression outputs.

    Args:
        X: 2-D float feature tensor, one row per sample.
        y: 1-D integer label tensor aligned with the rows of ``X``.
        perm: 1-D index tensor giving the sample order used for the split.
        n_train: Number of leading entries of ``perm`` used for training.
        name: Label printed in front of the result line.
        lam: Ridge (L2) regularization strength.
        num_classes: Number of classes; defaults to the module-level
            ``n_classes``.
        class_names: Sequence of class names for the per-class table;
            defaults to the module-level ``CLASSES``.

    Returns:
        Test accuracy as a Python float (NaN when the test split is empty).
    """
    if num_classes is None:
        num_classes = n_classes
    if class_names is None:
        class_names = CLASSES

    train_idx, test_idx = perm[:n_train], perm[n_train:]
    X_tr, y_tr = X[train_idx], y[train_idx]
    X_te, y_te = X[test_idx], y[test_idx]

    # Standardize with training statistics only; the clamp keeps constant
    # features from dividing by zero.
    mu = X_tr.mean(0)
    sigma = X_tr.std(0).clamp(min=1e-8)
    X_tr_n = (X_tr - mu) / sigma
    X_te_n = (X_te - mu) / sigma

    # One-hot regression targets, matching the features' dtype and device.
    Y_oh = torch.zeros(len(y_tr), num_classes,
                       dtype=X_tr_n.dtype, device=X_tr_n.device)
    Y_oh.scatter_(1, y_tr.unsqueeze(1), 1.0)

    # Closed-form ridge solution: (X^T X + lam * I) W = X^T Y.
    gram = X_tr_n.T @ X_tr_n + lam * torch.eye(
        X_tr_n.shape[1], dtype=X_tr_n.dtype, device=X_tr_n.device)
    W = torch.linalg.solve(gram, X_tr_n.T @ Y_oh)

    train_acc = ((X_tr_n @ W).argmax(1) == y_tr).float().mean().item()
    # Test predictions are computed once and reused for the per-class table.
    preds = (X_te_n @ W).argmax(1)
    test_acc = (preds == y_te).float().mean().item()
    print(f" {name:<40s} dims={X.shape[1]:>5d} "
          f"train={train_acc:.1%} test={test_acc:.1%}")

    # Per-class accuracy. The header is printed before the first class that
    # has test samples, so it no longer depends on class 0 being present.
    header_printed = False
    for c in range(num_classes):
        cm = y_te == c
        if cm.sum() == 0:
            continue
        if not header_printed:
            print(f" {'Class':<10s} {'Acc':>6s}")
            print(f" {'-' * 18}")
            header_printed = True
        acc = (preds[cm] == y_te[cm]).float().mean().item()
        bar = '█' * int(acc * 20)
        print(f" {class_names[c]:<10s} {acc:5.1%} {bar}")
    return test_acc
print(f"\n Ridge probe comparison:\n")
acc_err = ridge_probe(X_err, y_err, perm, n_train, "Recon error spatial map")

# ═══════════════════════════════════════════════════════════════
# 5. COMBINED: RELEASE + EIGENVALUES + FRICTION
# ═══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" 5. FULL CONDUIT — Release error + eigenvalues + friction")
print("=" * 70)

# Rebuild combined features with release error included
from geolip_core.linalg.conduit import FLEighConduit
conduit = FLEighConduit().to(device)

combined_features = []
combined_labels = []
n_collected2 = 0

for images, labels in tqdm(loader, desc="Full conduit"):
    if n_collected2 >= max_collect:
        break
    with torch.no_grad():
        images_gpu = images.to(device)
        out = freckles(images_gpu)
        recon = out['recon']
        S = out['svd']['S']
        Vt = out['svd']['Vt']
        B_img, N, D = S.shape

        # Per-patch recon error
        input_p, _, _ = extract_patches(images_gpu, ps)
        recon_p, _, _ = extract_patches(recon, ps)
        patch_mse = (input_p - recon_p).pow(2).mean(dim=-1)  # (B, N)

        # Friction from conduit: build G = Vt^T diag(S^2) Vt per patch and
        # pass the flattened batch of D x D matrices to the conduit.
        S2 = S.pow(2)
        G = torch.einsum('bnij,bnj,bnjk->bnik',
                         Vt.transpose(-2, -1), S2, Vt)
        G_flat = G.reshape(B_img * N, D, D)
        packet = conduit(G_flat)
        fric = packet.friction.reshape(B_img, N, D)

        # Combine: error_map(gh*gw) + S_map(gh*gw×D) + friction_map(gh*gw×D)
        err_flat = patch_mse.reshape(B_img, gh * gw)
        s_flat = S.reshape(B_img, gh * gw * D)
        f_flat = fric.reshape(B_img, gh * gw * D)

        # Keep only as many rows as are still needed, and move the whole
        # batch to the CPU in one transfer instead of one per sample.
        take = min(B_img, max_collect - n_collected2)
        batch_feats = torch.cat([err_flat, s_flat, f_flat], dim=1)[:take].cpu()

    combined_features.extend(batch_feats.unbind(0))
    combined_labels.extend(labels[:take].tolist())
    n_collected2 += take

X_full = torch.stack(combined_features)
y_full = torch.tensor(combined_labels)
# Seeded local generator so the split, and the accuracies below, repeat
# from run to run without touching the global RNG state.
perm2 = torch.randperm(len(y_full), generator=torch.Generator().manual_seed(0))
n_train2 = int(0.8 * len(y_full))

print(f"\n Comparative linear probes:\n")

# Individual features. Column boundaries follow the concatenation order
# above and are derived from the grid size and mode count, not literals.
n_err = gh * gw
n_sv = gh * gw * D
X_err_only = X_full[:, :n_err]
X_s_only = X_full[:, n_err:n_err + n_sv]
X_f_only = X_full[:, n_err + n_sv:]

acc_err2 = ridge_probe(X_err_only, y_full, perm2, n_train2,
                       "Release error only")
print()
acc_s2 = ridge_probe(X_s_only, y_full, perm2, n_train2,
                     "Eigenvalues (S) only")
print()
acc_f2 = ridge_probe(X_f_only, y_full, perm2, n_train2,
                     "Friction only")

# Combinations
print(f"\n Combinations:\n")
X_err_s = torch.cat([X_err_only, X_s_only], dim=-1)
acc_err_s = ridge_probe(X_err_s, y_full, perm2, n_train2,
                        "Release + Eigenvalues")
X_err_f = torch.cat([X_err_only, X_f_only], dim=-1)
acc_err_f = ridge_probe(X_err_f, y_full, perm2, n_train2,
                        "Release + Friction")
acc_all = ridge_probe(X_full, y_full, perm2, n_train2,
                      "Release + Eigenvalues + Friction")
# ═══════════════════════════════════════════════════════════════
# 6. HIGH-ERROR PATCH ANALYSIS — Where does the model fail?
# ═══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" 6. HIGH-ERROR PATCHES — Where does reconstruction fail?")
print("=" * 70)

# For each class, find patches with highest error
print(f"\n Top error positions per class (patch coordinates):")
print(f" {'Class':<10s} {'Top 3 positions (row, col)':>40s} {'Error ratio':>12s}")
print(f" {'-' * 64}")
for c in range(10):
    cm = class_means[c]  # (gh, gw)
    flat = cm.reshape(-1)
    top3 = flat.argsort(descending=True)[:3]
    # Flat index -> (row, col) on the gw-wide grid.
    positions = [divmod(idx, gw) for idx in top3.tolist()]
    errs = [flat[idx].item() for idx in top3]
    mean_err = cm.mean().item()
    # How many times the class's mean error its single worst patch is.
    ratio = errs[0] / (mean_err + 1e-10)
    pos_str = ", ".join(f"({r},{c_})" for r, c_ in positions)
    print(f" {CLASSES[c]:<10s} {pos_str:>40s} {ratio:12.2f}x")

# Overall hot spots across all classes: patches whose mean error exceeds the
# grid mean by more than two standard deviations.
overall_error = class_error_sum.sum(0) / class_counts.sum()
hot_threshold = overall_error.mean() + 2 * overall_error.std()
hot_patches = (overall_error > hot_threshold).sum().item()
print(f"\n Overall error map:")
print(f" Mean: {overall_error.mean():.6f}")
print(f" Std: {overall_error.std():.6f}")
print(f" Hot patches (>2σ): {hot_patches}/{gh * gw}")
# ═══════════════════════════════════════════════════════════════
# SUMMARY
# ═══════════════════════════════════════════════════════════════
# The friction / eigenvalue figures quoted below are fixed numbers labelled
# "from Cell 3" in the text; this script does not recompute them. The two
# "Lift" lines compare probes that share the section-5 split (perm2).
print(f"\n{'=' * 70}")
print(" THEOREM 3: RELEASE FIDELITY — SUMMARY")
print("=" * 70)
print(f"""
SPATIAL STRUCTURE:
Recon error spatial CV: {cv_arr.mean():.4f}
(Friction spatial CV was: 0.0137)
CLASSIFICATION (ridge probe, test accuracy):
Chance: 10.0%
Friction maps: 24.3% (from Cell 3)
Eigenvalue (S) maps: 21.0% (from Cell 3)
Release error maps: {acc_err:.1%}
Release + Eigenvalues: {acc_err_s:.1%}
Release + Friction: {acc_err_f:.1%}
FULL CONDUIT (all three): {acc_all:.1%}
THE QUESTION ANSWERED:
Does the release signal carry class-discriminative information
that eigenvalues and friction do not?
Lift from release over eigenvalues: {(acc_err2 - acc_s2) * 100:+.1f}pp
Lift from full conduit over eigenvalues: {(acc_all - acc_s2) * 100:+.1f}pp
""")