""" Cell 3 — Spatial Friction Map Analysis ======================================== The mean friction is uniform across classes (12.19 ± 0.08). But the SPATIAL PATTERN of friction within images might differ. Questions: 1. Do friction maps have spatial structure? (or uniform per image) 2. Does the spatial pattern differ across classes? 3. Do edge/boundary patches have higher friction than interior? 4. Is per-patch friction discriminative even if per-class mean is not? 5. What does the friction map look like for individual images? """ import torch import torch.nn.functional as F import numpy as np from tqdm import tqdm from geolip_core.linalg.conduit import FLEighConduit device = torch.device('cuda') # ═══════════════════════════════════════════════════════════════ # LOAD DATA # ═══════════════════════════════════════════════════════════════ print("Loading Freckles v40 + CIFAR-10...") from geolip_svae import load_model import torchvision import torchvision.transforms as T freckles, cfg = load_model(hf_version='v40_freckles_noise', device=device) freckles.eval() transform = T.Compose([T.Resize(64), T.ToTensor()]) cifar_test = torchvision.datasets.CIFAR10( root='/content/data', train=False, download=True, transform=transform) loader = torch.utils.data.DataLoader( cifar_test, batch_size=64, shuffle=False, num_workers=4) CLASSES = ['airplane', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] conduit = FLEighConduit().to(device) gh, gw = 16, 16 # patch grid # ═══════════════════════════════════════════════════════════════ # COLLECT SPATIAL FRICTION MAPS # ═══════════════════════════════════════════════════════════════ print("Collecting spatial friction maps (full test set)...\n") # Per-class friction maps: (10, gh, gw, D=4) class_friction_sum = torch.zeros(10, gh, gw, 4) class_friction_sq = torch.zeros(10, gh, gw, 4) class_settle_sum = torch.zeros(10, gh, gw, 4) class_counts = torch.zeros(10) # Also collect per-image statistics for discriminability 
analysis all_friction_maps = [] # list of (friction_map, label) all_settle_maps = [] n_images_collected = 0 max_collect = 2000 # collect individual maps for first 2000 images for images, labels in tqdm(loader, desc="Processing"): with torch.no_grad(): out = freckles(images.to(device)) S = out['svd']['S'] # (B, N, D) Vt = out['svd']['Vt'] # (B, N, D, D) B_img, N, D = S.shape # Build Gram matrices S2 = S.pow(2) G = torch.einsum('bnij,bnj,bnjk->bnik', Vt.transpose(-2, -1), S2, Vt) G_flat = G.reshape(B_img * N, D, D) packet = conduit(G_flat) # Reshape to spatial: (B, gh, gw, D) fric_map = packet.friction.reshape(B_img, gh, gw, D) sett_map = packet.settle.reshape(B_img, gh, gw, D) fric_cpu = fric_map.cpu() sett_cpu = sett_map.cpu() for i in range(B_img): c = labels[i].item() class_friction_sum[c] += fric_cpu[i] class_friction_sq[c] += fric_cpu[i].pow(2) class_settle_sum[c] += sett_cpu[i] class_counts[c] += 1 if n_images_collected < max_collect: all_friction_maps.append((fric_cpu[i], c)) all_settle_maps.append((sett_cpu[i], c)) n_images_collected += 1 print(f"\nCollected {int(class_counts.sum().item())} images, " f"{n_images_collected} individual maps\n") # ═══════════════════════════════════════════════════════════════ # 1. SPATIAL STRUCTURE WITHIN IMAGES # ═══════════════════════════════════════════════════════════════ print("=" * 70) print(" 1. SPATIAL STRUCTURE — Do friction maps have spatial variance?") print("=" * 70) # Per-image spatial variance: does friction vary across patches within ONE image? per_image_spatial_var = [] for fric_map, label in all_friction_maps: # fric_map: (gh, gw, 4) # Spatial variance: how much does friction vary across the 16x16 grid? 
per_mode_var = fric_map.reshape(-1, 4).var(dim=0) # var across 256 patches per_image_spatial_var.append((per_mode_var, label)) spatial_vars = torch.stack([v for v, _ in per_image_spatial_var]) # (N, 4) print(f"\n Per-image spatial friction variance (across 256 patches):") print(f" Mode 0 (S₀): mean={spatial_vars[:, 0].mean():.4f} std={spatial_vars[:, 0].std():.4f}") print(f" Mode 1 (S₁): mean={spatial_vars[:, 1].mean():.4f} std={spatial_vars[:, 1].std():.4f}") print(f" Mode 2 (S₂): mean={spatial_vars[:, 2].mean():.4f} std={spatial_vars[:, 2].std():.4f}") print(f" Mode 3 (S₃): mean={spatial_vars[:, 3].mean():.4f} std={spatial_vars[:, 3].std():.4f}") # Coefficient of variation: spatial_std / spatial_mean per image spatial_means = torch.stack([f.reshape(-1, 4).mean(0) for f, _ in all_friction_maps]) spatial_stds = torch.stack([f.reshape(-1, 4).std(0) for f, _ in all_friction_maps]) spatial_cv = spatial_stds / (spatial_means + 1e-8) print(f"\n Per-image spatial CV (std/mean):") for d in range(4): print(f" Mode {d}: CV mean={spatial_cv[:, d].mean():.4f} " f"median={spatial_cv[:, d].median():.4f} max={spatial_cv[:, d].max():.4f}") has_spatial_structure = spatial_cv.mean() > 0.1 print(f"\n VERDICT: {'HAS SPATIAL STRUCTURE' if has_spatial_structure else 'SPATIALLY UNIFORM'} " f"(mean CV = {spatial_cv.mean():.4f})") # ═══════════════════════════════════════════════════════════════ # 2. PER-CLASS SPATIAL FRICTION PATTERNS # ═══════════════════════════════════════════════════════════════ print(f"\n{'=' * 70}") print(" 2. 
PER-CLASS SPATIAL PATTERNS — Do classes have different friction maps?") print("=" * 70) # Average friction map per class class_means = class_friction_sum / class_counts[:, None, None, None].clamp(min=1) class_vars = class_friction_sq / class_counts[:, None, None, None].clamp(min=1) - class_means.pow(2) # Flatten spatial maps and compare between classes class_flat = class_means.reshape(10, -1) # (10, gh*gw*4) # Inter-class distance matrix dists = torch.cdist(class_flat, class_flat) print(f"\n Inter-class friction map L2 distances:") print(f" {'':>10s}", end="") for c in range(10): print(f" {CLASSES[c][:5]:>6s}", end="") print() for c1 in range(10): print(f" {CLASSES[c1][:10]:>10s}", end="") for c2 in range(10): print(f" {dists[c1, c2]:6.3f}", end="") print() # Mean inter-class vs intra-class distance inter_mask = ~torch.eye(10, dtype=torch.bool) inter_dist = dists[inter_mask].mean().item() print(f"\n Mean inter-class distance: {inter_dist:.4f}") # Cosine similarity between class friction maps class_flat_norm = F.normalize(class_flat, dim=-1) cos_sim = class_flat_norm @ class_flat_norm.T cos_off_diag = cos_sim[inter_mask].mean().item() cos_min = cos_sim[inter_mask].min().item() print(f" Mean cosine similarity: {cos_off_diag:.6f}") print(f" Min cosine similarity: {cos_min:.6f}") print(f" VERDICT: {'DISTINCT PATTERNS' if cos_min < 0.99 else 'NEARLY IDENTICAL PATTERNS'}") # ═══════════════════════════════════════════════════════════════ # 3. CENTER vs EDGE FRICTION # ═══════════════════════════════════════════════════════════════ print(f"\n{'=' * 70}") print(" 3. 
CENTER vs EDGE — Do boundary patches have higher friction?") print("=" * 70) # Define center and edge regions center_mask = torch.zeros(gh, gw, dtype=torch.bool) center_mask[4:12, 4:12] = True # center 8×8 edge_mask = ~center_mask # border ring for c in range(10): fric_c = class_means[c] # (gh, gw, 4) center_fric = fric_c[center_mask].mean().item() edge_fric = fric_c[edge_mask].mean().item() ratio = edge_fric / (center_fric + 1e-8) if c == 0: print(f"\n {'Class':<10s} {'Center':>8s} {'Edge':>8s} {'Edge/Center':>12s}") print(f" {'-' * 40}") print(f" {CLASSES[c]:<10s} {center_fric:8.3f} {edge_fric:8.3f} {ratio:12.4f}") # ═══════════════════════════════════════════════════════════════ # 4. PER-PATCH-POSITION DISCRIMINABILITY # ═══════════════════════════════════════════════════════════════ print(f"\n{'=' * 70}") print(" 4. PER-PATCH-POSITION DISCRIMINABILITY") print("=" * 70) # For each patch position (i,j), is friction discriminative across classes? # Use inter-class variance / intra-class variance ratio (F-statistic proxy) position_f_stat = torch.zeros(gh, gw, 4) for pi in range(gh): for pj in range(gw): for d in range(4): # Class means at this position c_means = class_means[:, pi, pj, d] # (10,) # Inter-class variance inter_var = c_means.var().item() # Intra-class variance (averaged) intra_var = class_vars[:, pi, pj, d].mean().item() position_f_stat[pi, pj, d] = inter_var / (intra_var + 1e-10) # Summary print(f"\n F-statistic (inter-class var / intra-class var) per mode:") for d in range(4): fs = position_f_stat[:, :, d] print(f" Mode {d}: mean={fs.mean():.6f} max={fs.max():.6f} " f"top 5% threshold={fs.quantile(0.95):.6f}") # Best discriminative positions for d in range(4): fs = position_f_stat[:, :, d] best_idx = fs.argmax() bi, bj = best_idx // gw, best_idx % gw print(f" Mode {d} best position: ({bi.item()}, {bj.item()}) F={fs.max():.6f}") overall_f = position_f_stat.mean(dim=-1) # avg across modes print(f"\n Overall best discriminative patch position: " 
f"{(overall_f.argmax() // gw).item()}, {(overall_f.argmax() % gw).item()} " f"F={overall_f.max():.6f}") print(f" Overall mean F-statistic: {overall_f.mean():.6f}") print(f" VERDICT: {'POSITIONALLY DISCRIMINATIVE' if overall_f.max() > 0.01 else 'NOT DISCRIMINATIVE'}") # ═══════════════════════════════════════════════════════════════ # 5. PER-MODE ANALYSIS — Which SVD mode carries most spatial variance? # ═══════════════════════════════════════════════════════════════ print(f"\n{'=' * 70}") print(" 5. PER-MODE SPATIAL VARIANCE — Which mode has the most structure?") print("=" * 70) for d in range(4): # Spatial variance of mean friction map (across all images) overall_mean_map = class_friction_sum.sum(0) / class_counts.sum() # (gh, gw, 4) mode_map = overall_mean_map[:, :, d] sv = mode_map.var().item() sm = mode_map.mean().item() print(f" Mode {d}: map_mean={sm:.4f} map_var={sv:.6f} map_cv={sv**0.5/(sm+1e-8):.4f}") # ═══════════════════════════════════════════════════════════════ # 6. INDIVIDUAL IMAGE FRICTION MAPS # ═══════════════════════════════════════════════════════════════ print(f"\n{'=' * 70}") print(" 6. SAMPLE FRICTION MAPS — Individual images") print("=" * 70) # Show friction statistics for 2 images per class for c in range(10): maps_c = [(f, l) for f, l in all_friction_maps if l == c][:2] for idx, (fric_map, _) in enumerate(maps_c): # fric_map: (gh, gw, 4) flat = fric_map.reshape(-1, 4) fmean = flat.mean(0) fstd = flat.std(0) fmin = flat.min(0).values fmax = flat.max(0).values # Spatial entropy: how concentrated is the friction? 
fric_total = flat.sum(dim=-1) # per-patch total friction fric_prob = fric_total / (fric_total.sum() + 1e-8) entropy = -(fric_prob * (fric_prob + 1e-10).log()).sum().item() max_entropy = np.log(256) # uniform = max entropy # Hot spots: patches with friction > 2× mean hot = (fric_total > 2 * fric_total.mean()).sum().item() if idx == 0 and c == 0: print(f"\n {'Class':<10s} {'Img':>3s} {'Mean':>8s} {'Std':>8s} " f"{'Max':>8s} {'Entropy':>8s} {'HotSpots':>9s}") print(f" {'-' * 55}") print(f" {CLASSES[c]:<10s} {idx:3d} {fmean.mean():8.2f} {fstd.mean():8.2f} " f"{fmax.max():8.2f} {entropy/max_entropy:8.3f} {hot:9d}") # ═══════════════════════════════════════════════════════════════ # 7. FRICTION MAP AS CLASSIFIER — Linear probe on spatial friction # ═══════════════════════════════════════════════════════════════ print(f"\n{'=' * 70}") print(" 7. LINEAR PROBE — Can flattened friction maps classify?") print("=" * 70) # Collect features and labels features = [] labels_all = [] for fric_map, label in all_friction_maps: features.append(fric_map.reshape(-1)) # (gh*gw*4,) = 1024 labels_all.append(label) X = torch.stack(features) # (N, 1024) y = torch.tensor(labels_all) # (N,) # Train/test split N = len(y) perm = torch.randperm(N) n_train = int(0.8 * N) X_train, y_train = X[perm[:n_train]], y[perm[:n_train]] X_test, y_test = X[perm[n_train:]], y[perm[n_train:]] # Standardize mean = X_train.mean(0) std = X_train.std(0).clamp(min=1e-6) X_train_n = (X_train - mean) / std X_test_n = (X_test - mean) / std # Ridge regression (closed form, no training loop) lam = 1.0 n_classes = 10 Y_onehot = torch.zeros(n_train, n_classes) Y_onehot.scatter_(1, y_train.unsqueeze(1), 1.0) XtX = X_train_n.T @ X_train_n + lam * torch.eye(X_train_n.shape[1]) XtY = X_train_n.T @ Y_onehot W = torch.linalg.solve(XtX, XtY) train_pred = (X_train_n @ W).argmax(1) test_pred = (X_test_n @ W).argmax(1) train_acc = (train_pred == y_train).float().mean().item() test_acc = (test_pred == y_test).float().mean().item() 
print(f"\n Features: flattened friction map ({X.shape[1]} dims)")
print(f" Train: {n_train}, Test: {N - n_train}")
print(f" Train accuracy: {train_acc:.1%}")
print(f" Test accuracy: {test_acc:.1%}")
print(f" Chance: 10.0%")

# Per-class accuracy of the friction probe on the held-out split
print(f"\n {'Class':<10s} {'Acc':>6s}")
print(f" {'-' * 18}")
for c in range(n_classes):
    mask = y_test == c
    if mask.sum() > 0:
        acc = (test_pred[mask] == y_test[mask]).float().mean().item()
        bar = '█' * int(acc * 20)
        print(f" {CLASSES[c]:<10s} {acc:5.1%} {bar}")

print(f"\n VERDICT: {'DISCRIMINATIVE' if test_acc > 0.15 else 'NOT DISCRIMINATIVE'} "
      f"spatial friction signal")


def _ridge_accuracy(X, y, perm, n_train, lam=1.0):
    """Fit a closed-form ridge probe and return its held-out accuracy.

    Rows ``perm[:n_train]`` are the training set and the remaining rows of ``perm``
    are the test set. Features are standardized with training statistics. The model
    regresses onto one-hot targets for ``n_classes`` classes, and predictions are the
    argmax.
    """
    X_tr, y_tr = X[perm[:n_train]], y[perm[:n_train]]
    X_te, y_te = X[perm[n_train:]], y[perm[n_train:]]
    m = X_tr.mean(0)
    s = X_tr.std(0).clamp(min=1e-6)
    X_tr_n = (X_tr - m) / s
    X_te_n = (X_te - m) / s
    Y_oh = torch.zeros(n_train, n_classes, dtype=X_tr_n.dtype)
    Y_oh.scatter_(1, y_tr.unsqueeze(1), 1.0)
    eye = torch.eye(X_tr_n.shape[1], dtype=X_tr_n.dtype)
    W = torch.linalg.solve(X_tr_n.T @ X_tr_n + lam * eye, X_tr_n.T @ Y_oh)
    return ((X_te_n @ W).argmax(1) == y_te).float().mean().item()


def ridge_probe(X, y, perm, n_train, name, lam=1.0):
    """Run ``_ridge_accuracy``, print a one-line report, and return the accuracy."""
    acc = _ridge_accuracy(X, y, perm, n_train, lam)
    print(f" {name:<30s} dims={X.shape[1]:>5d} test_acc={acc:.1%}")
    return acc


# ═══════════════════════════════════════════════════════════════
# 8. SETTLE MAP ANALYSIS — Same treatment for settle times
# ═══════════════════════════════════════════════════════════════

print(f"\n{'=' * 70}")
print(" 8. SETTLE MAP — Spatial convergence patterns")
print("=" * 70)

X_s = torch.stack([sett_map.reshape(-1) for sett_map, _ in all_settle_maps])
y_s = torch.tensor([label for _, label in all_settle_maps])

# The settle maps come from the same images, in the same order, as the friction maps.
# Reusing the friction split therefore lets the probes be compared directly.
if len(y_s) != len(y):
    raise RuntimeError(f"settle maps ({len(y_s)}) and friction maps ({len(y)}) differ in count")
perm_s, n_train_s = perm, n_train
test_acc_s = _ridge_accuracy(X_s, y_s, perm_s, n_train_s, lam=lam)

print(f" Settle map linear probe:")
print(f" Test accuracy: {test_acc_s:.1%}")
print(f" VERDICT: {'DISCRIMINATIVE' if test_acc_s > 0.15 else 'NOT DISCRIMINATIVE'}")

# ═══════════════════════════════════════════════════════════════
# 9. COMBINED CONDUIT — friction + settle + eigenvalues
# ═══════════════════════════════════════════════════════════════

print(f"\n{'=' * 70}")
print(" 9. COMBINED CONDUIT — All evidence stacked")
print("=" * 70)

# For comparison, also probe the raw S values as spatial maps. Note that S is the
# model's singular values. If Vt is orthonormal, the eigenvalues of the Gram matrix
# fed to the conduit are S². Friction/settle maps for the same images are already
# stored, so the second pass only needs the forward pass for S.
print("\n Collecting eigenvalue spatial maps...")
n_target = len(all_friction_maps)
eval_features = []   # one flattened (gh*gw*D) S map per image
combined_labels = []

with torch.no_grad():
    for images, labels_batch in loader:
        remaining = n_target - len(eval_features)
        if remaining <= 0:
            break
        S = freckles(images.to(device))['svd']['S']  # (B, N, D)
        B_img, n_patches, D = S.shape
        if n_patches != gh * gw:
            raise ValueError(f"expected {gh * gw} patches, got {n_patches}")
        take = min(B_img, remaining)
        eval_features.extend(S[:take].reshape(take, -1).cpu())
        combined_labels.extend(labels_batch[:take].tolist())

# The loader is unshuffled, so this pass should see the same images as the
# collection pass. Check it, because the combined features depend on that alignment.
if combined_labels != [label for _, label in all_friction_maps]:
    raise RuntimeError("second loader pass does not match stored friction/settle maps")

# Combined: friction + settle + S, per image
combined_features = [
    torch.cat([fric.reshape(-1), sett.reshape(-1), ev])
    for (fric, _), (sett, _), ev in zip(all_friction_maps, all_settle_maps, eval_features)
]

# Every probe uses the same train/test split as the friction probe in section 7.
X_e = torch.stack(eval_features)
y_e = torch.tensor(combined_labels)
perm_e, n_train_e = perm, n_train

print(f"\n Linear probe comparison (all use same train/test split):\n")
acc_evals = ridge_probe(X_e, y_e, perm_e, n_train_e, "Eigenvalues (S) spatial", lam=lam)
acc_fric = ridge_probe(X, y, perm, n_train, "Friction spatial", lam=lam)
acc_sett = ridge_probe(X_s, y_s, perm_s, n_train_s, "Settle spatial", lam=lam)
X_c = torch.stack(combined_features)
acc_comb = ridge_probe(X_c, y_e, perm_e, n_train_e, "Combined (S+fric+settle)", lam=lam)

print(f"\n Chance: 10.0%")
print(f" VERDICT: Combined vs eigenvalues-only lift = "
      f"{(acc_comb - acc_evals) * 100:+.1f} percentage points")

# ═══════════════════════════════════════════════════════════════
# SUMMARY
# ═══════════════════════════════════════════════════════════════

print(f"\n{'=' * 70}")
print(" SPATIAL FRICTION ANALYSIS — SUMMARY")
print("=" * 70)
print(f" 1. Spatial structure within images: CV = {spatial_cv.mean():.4f}")
print(f" 2. Inter-class pattern distance: cos_min = {cos_min:.6f}")
print(f" 3. Center vs edge asymmetry: (see table above)")
print(f" 4. Per-position F-statistic: max = {overall_f.max():.6f}")
print(f" 5. Friction map linear probe: {test_acc:.1%}")
print(f" 6. Settle map linear probe: {test_acc_s:.1%}")
print(f" 7. Eigenvalue map linear probe: {acc_evals:.1%}")
print(f" 8. Combined conduit linear probe: {acc_comb:.1%}")
print(f" 9. Conduit lift over eigenvalues: {(acc_comb - acc_evals)*100:+.1f}pp")