# File size: 22,303 Bytes
# a986628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
"""
Cell 3 — Spatial Friction Map Analysis
========================================
The mean friction is uniform across classes (12.19 ± 0.08).
But the SPATIAL PATTERN of friction within images might differ.

Questions:
  1. Do friction maps have spatial structure? (or uniform per image)
  2. Does the spatial pattern differ across classes?
  3. Do edge/boundary patches have higher friction than interior?
  4. Is per-patch friction discriminative even if per-class mean is not?
  5. What does the friction map look like for individual images?
"""

import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

from geolip_core.linalg.conduit import FLEighConduit

# Use the GPU when one is available; otherwise run on the CPU instead of
# failing at the first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ═══════════════════════════════════════════════════════════════
# LOAD DATA
# ═══════════════════════════════════════════════════════════════

print("Loading Freckles v40 + CIFAR-10...")
from geolip_svae import load_model
import torchvision
import torchvision.transforms as T

# `cfg` is returned together with the model; nothing below reads it.
freckles, cfg = load_model(hf_version='v40_freckles_noise', device=device)
freckles.eval()

# Images are resized to 64 px and converted to tensors before inference.
transform = T.Compose([T.Resize(64), T.ToTensor()])
cifar_test = torchvision.datasets.CIFAR10(
    root='/content/data', train=False, download=True, transform=transform)
# shuffle=False keeps the image order fixed, so any later pass over `loader`
# visits the same images in the same order.
loader = torch.utils.data.DataLoader(
    cifar_test, batch_size=64, shuffle=False, num_workers=4)

CLASSES = ['airplane', 'auto', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

conduit = FLEighConduit().to(device)
gh, gw = 16, 16  # patch grid; per-image outputs are reshaped to (gh, gw, D)


# ═══════════════════════════════════════════════════════════════
# COLLECT SPATIAL FRICTION MAPS
# ═══════════════════════════════════════════════════════════════

print("Collecting spatial friction maps (full test set)...\n")

# Running per-class accumulators over the patch grid: (10, gh, gw, D=4)
class_friction_sum = torch.zeros(10, gh, gw, 4)
class_friction_sq = torch.zeros(10, gh, gw, 4)
class_settle_sum = torch.zeros(10, gh, gw, 4)
class_counts = torch.zeros(10)

# Individual maps are kept only for the first `max_collect` images; they feed
# the per-image statistics and the linear probes further down.
all_friction_maps = []  # list of (friction_map, label)
all_settle_maps = []    # list of (settle_map, label)

n_images_collected = 0
max_collect = 2000  # collect individual maps for first 2000 images

for images, labels in tqdm(loader, desc="Processing"):
    with torch.no_grad():
        out = freckles(images.to(device))
        S = out['svd']['S']       # (B, N, D)
        Vt = out['svd']['Vt']     # (B, N, D, D)
        B_img, N, D = S.shape

        # Per-patch Gram matrix G = V diag(S^2) V^T, with V = Vt^T
        S2 = S.pow(2)
        G = torch.einsum('bnij,bnj,bnjk->bnik',
                         Vt.transpose(-2, -1), S2, Vt)
        G_flat = G.reshape(B_img * N, D, D)

        packet = conduit(G_flat)

        # Reshape to spatial: (B, gh, gw, D)
        fric_map = packet.friction.reshape(B_img, gh, gw, D)
        sett_map = packet.settle.reshape(B_img, gh, gw, D)

    fric_cpu = fric_map.cpu()
    sett_cpu = sett_map.cpu()

    # Scatter-add the whole batch into the per-class accumulators in one call
    # per statistic instead of looping over images in Python. index_add_
    # needs matching dtypes, hence the explicit cast to the accumulator dtype.
    fric_acc = fric_cpu.to(class_friction_sum.dtype)
    sett_acc = sett_cpu.to(class_settle_sum.dtype)
    class_friction_sum.index_add_(0, labels, fric_acc)
    class_friction_sq.index_add_(0, labels, fric_acc.pow(2))
    class_settle_sum.index_add_(0, labels, sett_acc)
    class_counts.index_add_(0, labels, torch.ones(B_img))

    # Keep individual maps until the `max_collect` budget is used up.
    n_keep = min(B_img, max_collect - n_images_collected)
    for i in range(n_keep):
        c = labels[i].item()
        all_friction_maps.append((fric_cpu[i], c))
        all_settle_maps.append((sett_cpu[i], c))
    n_images_collected += n_keep

print(f"\nCollected {int(class_counts.sum().item())} images, "
      f"{n_images_collected} individual maps\n")


# ═══════════════════════════════════════════════════════════════
# 1. SPATIAL STRUCTURE WITHIN IMAGES
# ═══════════════════════════════════════════════════════════════

print("=" * 70)
print("  1. SPATIAL STRUCTURE — Do friction maps have spatial variance?")
print("=" * 70)

# Flatten each collected map to (gh*gw, 4) once; every statistic in this
# section reduces over the patch axis of a single image.
flat_maps = [fmap.reshape(-1, 4) for fmap, _ in all_friction_maps]

# Per-image spatial variance: does friction vary across patches within ONE image?
per_image_spatial_var = [
    (flat.var(dim=0), label)
    for flat, (_, label) in zip(flat_maps, all_friction_maps)
]
spatial_vars = torch.stack([v for v, _ in per_image_spatial_var])  # (N, 4)

mode_subscripts = "₀₁₂₃"
print(f"\n  Per-image spatial friction variance (across 256 patches):")
for d in range(4):
    print(f"  Mode {d} (S{mode_subscripts[d]}): "
          f"mean={spatial_vars[:, d].mean():.4f} std={spatial_vars[:, d].std():.4f}")

# Coefficient of variation: spatial_std / spatial_mean per image
spatial_means = torch.stack([flat.mean(0) for flat in flat_maps])
spatial_stds = torch.stack([flat.std(0) for flat in flat_maps])
spatial_cv = spatial_stds / (spatial_means + 1e-8)  # epsilon avoids 0/0

print(f"\n  Per-image spatial CV (std/mean):")
for d in range(4):
    print(f"    Mode {d}: CV mean={spatial_cv[:, d].mean():.4f} "
          f"median={spatial_cv[:, d].median():.4f} max={spatial_cv[:, d].max():.4f}")

# A mean CV above 0.1 is the threshold used here for "has structure".
has_spatial_structure = spatial_cv.mean() > 0.1
print(f"\n  VERDICT: {'HAS SPATIAL STRUCTURE' if has_spatial_structure else 'SPATIALLY UNIFORM'} "
      f"(mean CV = {spatial_cv.mean():.4f})")


# ═══════════════════════════════════════════════════════════════
# 2. PER-CLASS SPATIAL FRICTION PATTERNS
# ═══════════════════════════════════════════════════════════════

print(f"\n{'=' * 70}")
print("  2. PER-CLASS SPATIAL PATTERNS — Do classes have different friction maps?")
print("=" * 70)

# Average friction map per class, and its variance as E[x^2] - E[x]^2.
# Counts are clamped to 1 so an empty class yields zeros instead of NaN.
safe_counts = class_counts[:, None, None, None].clamp(min=1)
class_means = class_friction_sum / safe_counts
class_vars = class_friction_sq / safe_counts - class_means.pow(2)

# Flatten spatial maps and compare between classes
class_flat = class_means.reshape(10, -1)  # (10, gh*gw*4)

# Inter-class distance matrix
dists = torch.cdist(class_flat, class_flat)

print(f"\n  Inter-class friction map L2 distances:")
print(f"  {'':>10s}" + "".join(f" {name[:5]:>6s}" for name in CLASSES))
for c1 in range(10):
    row = "".join(f" {dists[c1, c2]:6.3f}" for c2 in range(10))
    print(f"  {CLASSES[c1][:10]:>10s}{row}")

# Mean inter-class distance (off-diagonal entries only)
inter_mask = ~torch.eye(10, dtype=torch.bool)
inter_dist = dists[inter_mask].mean().item()
print(f"\n  Mean inter-class distance: {inter_dist:.4f}")

# Cosine similarity between class friction maps
class_flat_norm = F.normalize(class_flat, dim=-1)
cos_sim = class_flat_norm @ class_flat_norm.T
cos_off_diag = cos_sim[inter_mask].mean().item()
cos_min = cos_sim[inter_mask].min().item()
print(f"  Mean cosine similarity:    {cos_off_diag:.6f}")
print(f"  Min cosine similarity:     {cos_min:.6f}")
print(f"  VERDICT: {'DISTINCT PATTERNS' if cos_min < 0.99 else 'NEARLY IDENTICAL PATTERNS'}")


# ═══════════════════════════════════════════════════════════════
# 3. CENTER vs EDGE FRICTION
# ═══════════════════════════════════════════════════════════════

print(f"\n{'=' * 70}")
print("  3. CENTER vs EDGE — Do boundary patches have higher friction?")
print("=" * 70)

# Define center and edge regions
center_mask = torch.zeros(gh, gw, dtype=torch.bool)
center_mask[4:12, 4:12] = True  # center 8×8
edge_mask = ~center_mask         # everything outside the center block

print(f"\n  {'Class':<10s} {'Center':>8s} {'Edge':>8s} {'Edge/Center':>12s}")
print(f"  {'-' * 40}")
for c in range(10):
    fric_c = class_means[c]  # (gh, gw, 4)
    # Boolean indexing selects the masked patches; the mean pools patches and modes.
    center_fric = fric_c[center_mask].mean().item()
    edge_fric = fric_c[edge_mask].mean().item()
    ratio = edge_fric / (center_fric + 1e-8)
    print(f"  {CLASSES[c]:<10s} {center_fric:8.3f} {edge_fric:8.3f} {ratio:12.4f}")


# ═══════════════════════════════════════════════════════════════
# 4. PER-PATCH-POSITION DISCRIMINABILITY
# ═══════════════════════════════════════════════════════════════

print(f"\n{'=' * 70}")
print("  4. PER-PATCH-POSITION DISCRIMINABILITY")
print("=" * 70)

# For each patch position (i,j) and mode, is friction discriminative across
# classes? Use inter-class variance / intra-class variance (F-statistic proxy).
# Both reductions run over the class axis, so one tensor expression covers
# every (position, mode) cell at once.
inter_var = class_means.var(dim=0)   # variance of the class means, (gh, gw, 4)
intra_var = class_vars.mean(dim=0)   # within-class variance averaged over classes
position_f_stat = inter_var / (intra_var + 1e-10)

# Summary
print(f"\n  F-statistic (inter-class var / intra-class var) per mode:")
for d in range(4):
    fs = position_f_stat[:, :, d]
    print(f"    Mode {d}: mean={fs.mean():.6f} max={fs.max():.6f} "
          f"top 5% threshold={fs.quantile(0.95):.6f}")

# Best discriminative positions: argmax is a flat index into the (gh, gw) grid.
for d in range(4):
    fs = position_f_stat[:, :, d]
    bi, bj = divmod(int(fs.argmax()), gw)
    print(f"    Mode {d} best position: ({bi}, {bj}) F={fs.max():.6f}")

overall_f = position_f_stat.mean(dim=-1)  # avg across modes
best_row, best_col = divmod(int(overall_f.argmax()), gw)
print(f"\n  Overall best discriminative patch position: "
      f"{best_row}, {best_col} "
      f"F={overall_f.max():.6f}")
print(f"  Overall mean F-statistic: {overall_f.mean():.6f}")
print(f"  VERDICT: {'POSITIONALLY DISCRIMINATIVE' if overall_f.max() > 0.01 else 'NOT DISCRIMINATIVE'}")


# ═══════════════════════════════════════════════════════════════
# 5. PER-MODE ANALYSIS — Which SVD mode carries most spatial variance?
# ═══════════════════════════════════════════════════════════════

print(f"\n{'=' * 70}")
print("  5. PER-MODE SPATIAL VARIANCE — Which mode has the most structure?")
print("=" * 70)

# Mean friction map over all images; it does not depend on the mode, so it is
# computed once outside the loop.
overall_mean_map = class_friction_sum.sum(0) / class_counts.sum()  # (gh, gw, 4)

for d in range(4):
    # Spatial variance of the mean map for this mode
    mode_map = overall_mean_map[:, :, d]
    sv = mode_map.var().item()
    sm = mode_map.mean().item()
    print(f"  Mode {d}: map_mean={sm:.4f} map_var={sv:.6f} map_cv={sv**0.5/(sm+1e-8):.4f}")


# ═══════════════════════════════════════════════════════════════
# 6. INDIVIDUAL IMAGE FRICTION MAPS
# ═══════════════════════════════════════════════════════════════

print(f"\n{'=' * 70}")
print("  6. SAMPLE FRICTION MAPS — Individual images")
print("=" * 70)

print(f"\n  {'Class':<10s} {'Img':>3s} {'Mean':>8s} {'Std':>8s} "
      f"{'Max':>8s} {'Entropy':>8s} {'HotSpots':>9s}")
print(f"  {'-' * 55}")

# Show friction statistics for 2 images per class
for c in range(10):
    maps_c = [fmap for fmap, lbl in all_friction_maps if lbl == c][:2]
    for idx, fric_map in enumerate(maps_c):
        # fric_map: (gh, gw, 4)
        flat = fric_map.reshape(-1, 4)
        fmean = flat.mean(0)
        fstd = flat.std(0)
        fmax = flat.max(0).values

        # Spatial entropy: how concentrated is the friction?
        # Per-patch totals are normalised to a distribution over patches.
        fric_total = flat.sum(dim=-1)  # per-patch total friction
        fric_prob = fric_total / (fric_total.sum() + 1e-8)
        entropy = -(fric_prob * (fric_prob + 1e-10).log()).sum().item()
        # A uniform distribution over all patches attains the maximum entropy.
        max_entropy = np.log(fric_total.numel())

        # Hot spots: patches with friction > 2× mean
        hot = (fric_total > 2 * fric_total.mean()).sum().item()

        print(f"  {CLASSES[c]:<10s} {idx:3d} {fmean.mean():8.2f} {fstd.mean():8.2f} "
              f"{fmax.max():8.2f} {entropy/max_entropy:8.3f} {hot:9d}")


# ═══════════════════════════════════════════════════════════════
# 7. FRICTION MAP AS CLASSIFIER — Linear probe on spatial friction
# ═══════════════════════════════════════════════════════════════

print(f"\n{'=' * 70}")
print("  7. LINEAR PROBE — Can flattened friction maps classify?")
print("=" * 70)

# Collect features and labels
features = [fric_map.reshape(-1) for fric_map, _ in all_friction_maps]  # each (gh*gw*4,)
labels_all = [label for _, label in all_friction_maps]

X = torch.stack(features)     # (N, gh*gw*4)
y = torch.tensor(labels_all)  # (N,)

# Random 80/20 train/test split (unseeded, so it differs between runs)
N = len(y)
perm = torch.randperm(N)
n_train = int(0.8 * N)
train_idx, test_idx = perm[:n_train], perm[n_train:]
X_train, y_train = X[train_idx], y[train_idx]
X_test, y_test = X[test_idx], y[test_idx]

# Standardize with statistics from the training split only; the clamp keeps
# constant features from dividing by zero.
mean = X_train.mean(0)
std = X_train.std(0).clamp(min=1e-6)
X_train_n = (X_train - mean) / std
X_test_n = (X_test - mean) / std

# Ridge regression onto one-hot targets (closed form, no training loop):
#   W = (X^T X + lam * I)^-1 X^T Y
lam = 1.0
n_classes = 10
Y_onehot = torch.zeros(n_train, n_classes)
Y_onehot.scatter_(1, y_train.unsqueeze(1), 1.0)

XtX = X_train_n.T @ X_train_n + lam * torch.eye(X_train_n.shape[1])
XtY = X_train_n.T @ Y_onehot
W = torch.linalg.solve(XtX, XtY)

# The predicted class is the column with the largest regression output.
train_pred = (X_train_n @ W).argmax(1)
test_pred = (X_test_n @ W).argmax(1)
train_acc = (train_pred == y_train).float().mean().item()
test_acc = (test_pred == y_test).float().mean().item()

print(f"\n  Features: flattened friction map ({X.shape[1]} dims)")
print(f"  Train: {n_train}, Test: {N - n_train}")
print(f"  Train accuracy: {train_acc:.1%}")
print(f"  Test accuracy:  {test_acc:.1%}")
print(f"  Chance:          10.0%")

# Per-class accuracy (classes absent from the test split are skipped)
print(f"\n  {'Class':<10s} {'Acc':>6s}")
print(f"  {'-' * 18}")
for c in range(n_classes):
    mask = y_test == c
    if mask.sum() > 0:
        acc = (test_pred[mask] == y_test[mask]).float().mean().item()
        bar = '█' * int(acc * 20)  # 20 blocks correspond to 100%
        print(f"  {CLASSES[c]:<10s} {acc:5.1%} {bar}")

print(f"\n  VERDICT: {'DISCRIMINATIVE' if test_acc > 0.15 else 'NOT DISCRIMINATIVE'} "
      f"spatial friction signal")


# ═══════════════════════════════════════════════════════════════
# 8. SETTLE MAP ANALYSIS — Same treatment for settle times
# ═══════════════════════════════════════════════════════════════

print(f"\n{'=' * 70}")
print("  8. SETTLE MAP — Spatial convergence patterns")
print("=" * 70)

# Flattened settle maps and their labels
settle_features = [sett_map.reshape(-1) for sett_map, _ in all_settle_maps]
settle_labels = [label for _, label in all_settle_maps]

X_s = torch.stack(settle_features)
y_s = torch.tensor(settle_labels)

# Independent random 80/20 split (not the same permutation as the friction probe)
perm_s = torch.randperm(len(y_s))
n_train_s = int(0.8 * len(y_s))
train_idx_s, test_idx_s = perm_s[:n_train_s], perm_s[n_train_s:]
X_train_s, y_train_s = X_s[train_idx_s], y_s[train_idx_s]
X_test_s, y_test_s = X_s[test_idx_s], y_s[test_idx_s]

# Standardize with training-split statistics only
mean_s = X_train_s.mean(0)
std_s = X_train_s.std(0).clamp(min=1e-6)
X_train_sn = (X_train_s - mean_s) / std_s
X_test_sn = (X_test_s - mean_s) / std_s

# Closed-form ridge regression onto one-hot targets; reuses `lam` and
# `n_classes` defined for the friction probe.
Y_onehot_s = torch.zeros(n_train_s, n_classes)
Y_onehot_s.scatter_(1, y_train_s.unsqueeze(1), 1.0)
XtX_s = X_train_sn.T @ X_train_sn + lam * torch.eye(X_train_sn.shape[1])
XtY_s = X_train_sn.T @ Y_onehot_s
W_s = torch.linalg.solve(XtX_s, XtY_s)

test_pred_s = (X_test_sn @ W_s).argmax(1)
test_acc_s = (test_pred_s == y_test_s).float().mean().item()

print(f"  Settle map linear probe:")
print(f"  Test accuracy: {test_acc_s:.1%}")
print(f"  VERDICT: {'DISCRIMINATIVE' if test_acc_s > 0.15 else 'NOT DISCRIMINATIVE'}")


# ═══════════════════════════════════════════════════════════════
# 9. COMBINED CONDUIT — friction + settle + eigenvalues
# ═══════════════════════════════════════════════════════════════

print(f"\n{'=' * 70}")
print("  9. COMBINED CONDUIT — All evidence stacked")
print("=" * 70)

# Also test: raw S values as spatial maps for comparison.
# NOTE(review): the script calls these "eigenvalues"; S holds the values whose
# squares build the Gram matrix G below — confirm which quantity is intended.
print("\n  Collecting eigenvalue spatial maps...")

eval_features = []
combined_features = []
combined_labels = []

# Second pass over the unshuffled loader, limited to the first `max_collect`
# images, because S was not stored during the first pass.
for images, labels_batch in loader:
    n_remaining = max_collect - len(combined_labels)
    if n_remaining <= 0:
        break
    with torch.no_grad():
        out = freckles(images.to(device))
        S = out['svd']['S']
        Vt = out['svd']['Vt']
        B_img, N, D = S.shape

        # Per-patch Gram matrix G = V diag(S^2) V^T, with V = Vt^T
        S2 = S.pow(2)
        G = torch.einsum('bnij,bnj,bnjk->bnik',
                         Vt.transpose(-2, -1), S2, Vt)
        G_flat = G.reshape(B_img * N, D, D)
        packet = conduit(G_flat)

        fric = packet.friction.reshape(B_img, gh, gw, D)
        sett = packet.settle.reshape(B_img, gh, gw, D)
        evals = S.reshape(B_img, gh, gw, D)  # S values as spatial map

    # Flatten per image and move each tensor to the CPU once per batch rather
    # than once (or twice) per image.
    n_take = min(B_img, n_remaining)
    fric_flat = fric[:n_take].reshape(n_take, -1).cpu()
    sett_flat = sett[:n_take].reshape(n_take, -1).cpu()
    eval_flat = evals[:n_take].reshape(n_take, -1).cpu()

    # Eigenvalue spatial map
    eval_features.extend(eval_flat.unbind(0))
    # Combined: friction + settle + eigenvalues (in that order)
    combined_features.extend(
        torch.cat([fric_flat, sett_flat, eval_flat], dim=1).unbind(0))
    combined_labels.extend(labels_batch[:n_take].tolist())

# Eigenvalue-only probe
X_e = torch.stack(eval_features)
y_e = torch.tensor(combined_labels)

perm_e = torch.randperm(len(y_e))
n_train_e = int(0.8 * len(y_e))

def ridge_probe(X, y, perm, n_train, name, *, lam=1.0, n_classes=10):
    """Fit a closed-form ridge classifier and report its test accuracy.

    The rows of ``X`` selected by ``perm[:n_train]`` form the training split
    and the remaining rows of ``perm`` the test split. Features are
    standardized with training statistics, regressed onto one-hot class
    targets, and the predicted class is the arg-max regression output.

    Args:
        X: 2-D float tensor of features, one row per sample.
        y: 1-D integer tensor of class labels aligned with ``X``.
        perm: 1-D index tensor giving the sample order for the split.
        n_train: Number of leading entries of ``perm`` used for training.
        name: Label printed alongside the result.
        lam: Ridge regularization strength (the original fixed value was 1.0).
        n_classes: Number of one-hot target columns (the original read the
            module-level ``n_classes``, which is 10).

    Returns:
        Test accuracy as a Python float in ``[0, 1]``.
    """
    train_idx, test_idx = perm[:n_train], perm[n_train:]
    X_tr, y_tr = X[train_idx], y[train_idx]
    X_te, y_te = X[test_idx], y[test_idx]

    # Standardize with training statistics only; the clamp keeps constant
    # features from dividing by zero.
    m = X_tr.mean(0)
    s = X_tr.std(0).clamp(min=1e-6)
    X_tr_n = (X_tr - m) / s
    X_te_n = (X_te - m) / s

    # Targets and regularizer take the feature dtype so that
    # torch.linalg.solve receives matching dtypes for any float input.
    Y_oh = torch.zeros(len(y_tr), n_classes, dtype=X_tr_n.dtype)
    Y_oh.scatter_(1, y_tr.unsqueeze(1), 1.0)
    reg = lam * torch.eye(X_tr_n.shape[1], dtype=X_tr_n.dtype)

    # W = (X^T X + lam * I)^-1 X^T Y
    W = torch.linalg.solve(X_tr_n.T @ X_tr_n + reg, X_tr_n.T @ Y_oh)
    acc = ((X_te_n @ W).argmax(1) == y_te).float().mean().item()
    print(f"  {name:<30s} dims={X.shape[1]:>5d}  test_acc={acc:.1%}")
    return acc

# All three feature sets come from the same leading `max_collect` images of
# the unshuffled loader, so a single permutation can index every one of them.
# Verify that alignment before relying on it.
if not (torch.equal(y, y_e) and torch.equal(y_s, y_e)):
    raise RuntimeError(
        "Friction/settle/eigenvalue features are not aligned sample-for-sample; "
        "a shared train/test split would be meaningless.")

# Every probe uses (perm_e, n_train_e), so the accuracies below really are
# measured on the same train/test split, as the printed header states.
print(f"\n  Linear probe comparison (all use same train/test split):\n")
acc_evals = ridge_probe(X_e, y_e, perm_e, n_train_e, "Eigenvalues (S) spatial")
acc_fric = ridge_probe(X, y_e, perm_e, n_train_e, "Friction spatial")
acc_sett = ridge_probe(X_s, y_e, perm_e, n_train_e, "Settle spatial")

X_c = torch.stack(combined_features)
acc_comb = ridge_probe(X_c, y_e, perm_e, n_train_e, "Combined (S+fric+settle)")

print(f"\n  Chance: 10.0%")
print(f"  VERDICT: Combined vs eigenvalues-only lift = "
      f"{(acc_comb - acc_evals) * 100:+.1f} percentage points")


# ═══════════════════════════════════════════════════════════════
# SUMMARY
# ═══════════════════════════════════════════════════════════════

# Percentage-point gain of the combined probe over the S-only probe.
conduit_lift_pp = (acc_comb - acc_evals) * 100

print(f"\n{'=' * 70}")
print("  SPATIAL FRICTION ANALYSIS — SUMMARY")
print("=" * 70)
print(f"  1. Spatial structure within images:   CV = {spatial_cv.mean():.4f}")
print(f"  2. Inter-class pattern distance:      cos_min = {cos_min:.6f}")
print(f"  3. Center vs edge asymmetry:          (see table above)")
print(f"  4. Per-position F-statistic:          max = {overall_f.max():.6f}")
print(f"  5. Friction map linear probe:         {test_acc:.1%}")
print(f"  6. Settle map linear probe:           {test_acc_s:.1%}")
print(f"  7. Eigenvalue map linear probe:       {acc_evals:.1%}")
print(f"  8. Combined conduit linear probe:     {acc_comb:.1%}")
print(f"  9. Conduit lift over eigenvalues:     {conduit_lift_pp:+.1f}pp")