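# Page header captured together with the file (not part of the program):
#   author avatar: AbstractPhil
#   commit message: "Create 18_j_cell_reformed.py"
#   commit: 486b1d0 (verified)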
# ═══════════════════════════════════════════════════════════════════════
# Cell J''' — per-patch axis-feature classifier
# ═══════════════════════════════════════════════════════════════════════
# Same task as J/J'/J''. Same architectures. Only the encode + feature
# extraction stages change.
#
# Key fix from J' and J'':
#   J'  : encode_axes(images, patch_idx=0) → [B, V, n_axes]
#         → max-pool over V → [B, n_axes]
#         Used 1 of 256 patches per tile.
#
#   J'' : same as J' but with V-stats instead of max-pool.
#         Still using 1 of 256 patches per tile.
#
#   J''': encode_axes(images)  # no patch_idx → [B, n_patches=256, V, n_axes]
#         → spatial stats over patches AND value stats over V
#         Uses ALL 256 patches per tile. 256× more spatial signal.
#
# Codebooks are calibrated with the new per-patch averaging path
# (sample_agg='mean', patch_agg='mean') so that they reflect the bank's
# spatial-mean response.
import json
import math
import time
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import geolip_svae.arrays
from transformers import AutoModel
# Reuse a model that already exists in this namespace (e.g. when the cell is
# re-run); otherwise load the checkpoint by name.
array_model = globals().get('array_model')
if array_model is None:
    array_model = AutoModel.from_pretrained("AbstractPhil/geolip-svae-h2-64")
    # NOTE(review): indentation was lost in this copy of the file; the
    # device/eval placement below is assumed to belong to the load branch
    # (a reused model is left on whatever device it is already on) -- confirm.
    array_model = (array_model.cuda().eval()
                   if torch.cuda.is_available() else array_model.eval())
# Device the model's parameters live on; tiles and codebooks are moved here.
DEVICE = next(array_model.parameters()).device
# Directory holding this experiment's results and the earlier cells' JSON files.
EXP_DIR = Path("/content/h2_64_exp")
EXP_DIR.mkdir(parents=True, exist_ok=True)
# Side length (pixels) of the square tiles each image is cut into for scanning.
TILE_SIZE = 64
# Class id -> name for the 16 pure-noise generators implemented by gen_noise.
NOISE_NAMES = {
    0: 'gaussian', 1: 'uniform', 2: 'uniform_scaled', 3: 'poisson',
    4: 'pink', 5: 'brown', 6: 'salt_pepper', 7: 'sparse_impulses',
    8: 'block_upsampled', 9: 'gradient_gaussian', 10: 'checker',
    11: 'gauss_uniform_mix', 12: 'four_quadrant',
    13: 'cauchy', 14: 'exponential', 15: 'laplace',
}
# ══════════════════════════════════════════════════════════════════════
# Noise generators — inlined from Cell J so this cell is self-contained
# ══════════════════════════════════════════════════════════════════════
def _pink(shape, rng):
    """Draw white Gaussian noise of ``shape`` and divide its 2-D spectrum by |f|.

    ``rng`` is the ``torch.Generator`` the white noise is drawn from. The last
    two entries of ``shape`` are the spatial height and width; the returned
    real tensor has the same shape. No amplitude normalisation is applied.

    NOTE(review): the zero-frequency bin is not removed -- its |f| is clamped
    to 1e-8, so the DC coefficient is divided by 1e-8. Kept unchanged so the
    output stays identical to the generator this was inlined from; confirm
    that this is intended.
    """
    height, width = shape[-2], shape[-1]
    white = torch.randn(shape, generator=rng)
    spectrum = torch.fft.rfft2(white)
    # rfft2 keeps every row frequency but only the non-negative column
    # frequencies, so the two axes use fftfreq and rfftfreq respectively.
    freq_rows = torch.fft.fftfreq(height)[:, None]
    freq_cols = torch.fft.rfftfreq(width)[None, :]
    radius = torch.sqrt(freq_cols**2 + freq_rows**2).clamp(min=1e-8)
    return torch.fft.irfft2(spectrum / radius, s=(height, width))
def _brown(shape, rng):
    """Draw white Gaussian noise of ``shape`` and divide its 2-D spectrum by |f|^2.

    ``rng`` is the ``torch.Generator`` the white noise is drawn from. The last
    two entries of ``shape`` are the spatial height and width; the returned
    real tensor has the same shape. No amplitude normalisation is applied.

    NOTE(review): the zero-frequency bin is not removed -- its |f|^2 is
    clamped to 1e-8, so the DC coefficient is divided by 1e-8. Kept unchanged
    so the output stays identical to the generator this was inlined from;
    confirm that this is intended.
    """
    height, width = shape[-2], shape[-1]
    white = torch.randn(shape, generator=rng)
    spectrum = torch.fft.rfft2(white)
    # Same half-spectrum frequency layout as rfft2 produces.
    freq_rows = torch.fft.fftfreq(height)[:, None]
    freq_cols = torch.fft.rfftfreq(width)[None, :]
    power = (freq_cols**2 + freq_rows**2).clamp(min=1e-8)
    return torch.fft.irfft2(spectrum / power, s=(height, width))
def gen_noise(noise_type, size, seed):
    """Generate one 3-channel ``size`` x ``size`` image of a single noise type.

    Args:
        noise_type: Integer id 0-15 selecting the generator; the ids match
            the module-level ``NOISE_NAMES`` table.
        size: Side length in pixels. Must be even for type 12
            (four_quadrant), which fills ``size // 2`` quadrants.
        seed: Integer seed. It seeds a ``torch.Generator`` (pixel draws) and
            a NumPy ``RandomState`` (per-image scalars such as the Poisson
            rate, block size or gradient angle).

    Returns:
        A float32 tensor of shape ``[3, size, size]`` clamped to ``[-4, 4]``.

    Raises:
        ValueError: If ``noise_type`` is not one of the 16 known ids.
    """
    rng_t = torch.Generator().manual_seed(seed)
    # RandomState only accepts seeds in [0, 2**32 - 1]; fold anything outside
    # that range instead of raising. In-range seeds are untouched, so seeds
    # that worked before still produce exactly the same images.
    rng_n = np.random.RandomState(seed % (1 << 32))
    s = size
    if noise_type == 0:
        # gaussian
        img = torch.randn(3, s, s, generator=rng_t)
    elif noise_type == 1:
        # uniform on [-1, 1)
        img = torch.rand(3, s, s, generator=rng_t) * 2 - 1
    elif noise_type == 2:
        # uniform on [-2, 2)
        img = (torch.rand(3, s, s, generator=rng_t) - 0.5) * 4
    elif noise_type == 3:
        # poisson with a per-image rate, rescaled to zero mean
        lam = rng_n.uniform(0.5, 20.0)
        img = torch.poisson(torch.full((3, s, s), lam), generator=rng_t) / lam - 1.0
    elif noise_type == 4:
        # pink, scaled to unit std
        img = _pink((3, s, s), rng_t)
        img = img / (img.std() + 1e-8)
    elif noise_type == 5:
        # brown, scaled to unit std
        img = _brown((3, s, s), rng_t)
        img = img / (img.std() + 1e-8)
    elif noise_type == 6:
        # salt & pepper: +/-2 per pixel plus a little gaussian jitter
        mask = torch.rand(3, s, s, generator=rng_t) > 0.5
        img = torch.where(mask, torch.ones(3, s, s) * 2, torch.ones(3, s, s) * -2)
        img = img + torch.randn(3, s, s, generator=rng_t) * 0.1
    elif noise_type == 7:
        # sparse impulses: roughly 10% of pixels are non-zero
        mask = torch.rand(3, s, s, generator=rng_t) > 0.9
        img = torch.randn(3, s, s, generator=rng_t) * mask.float() * 3
    elif noise_type == 8:
        # low-resolution gaussian field, nearest-neighbour upsampled
        block = rng_n.randint(2, 16)
        small = torch.randn(3, s // block + 1, s // block + 1, generator=rng_t)
        img = F.interpolate(small.unsqueeze(0), size=s, mode='nearest').squeeze(0)
    elif noise_type == 9:
        # linear gradient at a random angle plus gaussian noise
        gy = torch.linspace(-2, 2, s).unsqueeze(1).expand(s, s)
        gx = torch.linspace(-2, 2, s).unsqueeze(0).expand(s, s)
        angle = rng_n.uniform(0, 2 * math.pi)
        grad = math.cos(angle) * gx + math.sin(angle) * gy
        img = (grad.unsqueeze(0).expand(3, -1, -1)
               + torch.randn(3, s, s, generator=rng_t) * 0.5)
    elif noise_type == 10:
        # +/-1 checkerboard with a random square size plus gaussian noise
        cs = rng_n.randint(2, 16)
        cells = torch.arange(s) // cs
        checker = ((cells.unsqueeze(1) + cells.unsqueeze(0)) % 2).float() * 2 - 1
        img = (checker.unsqueeze(0).expand(3, -1, -1)
               + torch.randn(3, s, s, generator=rng_t) * 0.3)
    elif noise_type == 11:
        # convex mix of gaussian and uniform noise
        a = torch.randn(3, s, s, generator=rng_t)
        b = torch.rand(3, s, s, generator=rng_t) * 2 - 1
        alpha = rng_n.uniform(0.2, 0.8)
        img = alpha * a + (1 - alpha) * b
    elif noise_type == 12:
        # four quadrants: gaussian | uniform / pink | +/-1 binary
        img = torch.zeros(3, s, s)
        h2 = s // 2
        img[:, :h2, :h2] = torch.randn(3, h2, h2, generator=rng_t)
        img[:, :h2, h2:] = torch.rand(3, h2, h2, generator=rng_t) * 2 - 1
        img[:, h2:, :h2] = _pink((3, h2, h2), rng_t) / 2
        sp = torch.where(torch.rand(3, h2, h2, generator=rng_t) > 0.5,
                         torch.ones(3, h2, h2), -torch.ones(3, h2, h2))
        img[:, h2:, h2:] = sp
    elif noise_type == 13:
        # cauchy via the inverse CDF, clipped to [-3, 3]
        u = torch.rand(3, s, s, generator=rng_t)
        img = torch.tan(math.pi * (u - 0.5)).clamp(-3, 3)
    elif noise_type == 14:
        # exponential(1), shifted to zero mean
        img = torch.empty(3, s, s).exponential_(1.0, generator=rng_t) - 1.0
    elif noise_type == 15:
        # laplace via the inverse CDF
        u = torch.rand(3, s, s, generator=rng_t) - 0.5
        img = -torch.sign(u) * torch.log1p(-2 * u.abs())
    else:
        raise ValueError(f"Unknown noise_type {noise_type}")
    return img.clamp(-4, 4).float()
def gen_zone_matte(res, n_zones, seed):
    """Build a spatially mixed image: a square grid of zones, one noise type each.

    Args:
        res: Side length of the output image in pixels.
        n_zones: Number of zones; must be 4, 9 or 16 (a 2x2, 3x3 or 4x4 grid).
        seed: Integer seed controlling which noise types are chosen and the
            per-zone noise seeds.

    Returns:
        ``(img, zone_types, zone_map)`` where ``img`` is ``[3, res, res]``,
        ``zone_types`` lists the noise-type id of every zone in row-major
        order (drawn without replacement from the 16 types), and ``zone_map``
        is a ``[res, res]`` long tensor holding the zone index of each pixel.
        If ``res`` is not a multiple of the grid side, the pixels beyond the
        last full zone stay zero in both ``img`` and ``zone_map``.

    Raises:
        AssertionError: If ``n_zones`` is unsupported or the zone size
            ``res // side`` is odd.
    """
    assert n_zones in (4, 9, 16), "Use 2Γ—2, 3Γ—3, or 4Γ—4 grids"
    side = int(math.sqrt(n_zones))
    cell = res // side
    assert cell % 2 == 0, f"cell size {cell} must be even for noise generators"
    # NumPy's RandomState only accepts seeds in [0, 2**32 - 1]. Folding keeps
    # every in-range seed unchanged while making large seeds usable.
    seed_limit = 1 << 32
    rng_n = np.random.RandomState(seed % seed_limit)
    zone_types = rng_n.choice(16, size=n_zones, replace=False).tolist()
    img = torch.zeros(3, res, res)
    zone_map = torch.zeros(res, res, dtype=torch.long)
    for zi, nt in enumerate(zone_types):
        i, j = divmod(zi, side)
        rows = slice(i * cell, (i + 1) * cell)
        cols = slice(j * cell, (j + 1) * cell)
        # seed * 1000 leaves the 32-bit range once seed exceeds ~4.29M, which
        # used to crash inside the noise generator; fold it back into range.
        cell_seed = (seed * 1000 + zi + 1) % seed_limit
        img[:, rows, cols] = gen_noise(nt, cell, cell_seed)
        zone_map[rows, cols] = zi
    return img, zone_types, zone_map
# Banks ("batteries") scanned in this cell: ids 0-15 plus 19 and 20.
SUBSET_BATTERY_IDS = [*range(16), 19, 20]
# Phase name passed to the model together with every battery id.
SUBSET_PHASE = 'best'
N_BATTERIES_SUB = len(SUBSET_BATTERY_IDS)
# Composite (zone-matte) classes use the label ids right after the 16
# pure-noise classes.
COMP_LABELS = dict(zip((16, 17, 18), ('zone_4', 'zone_9', 'zone_16')))
N_CLASSES = len(COMP_LABELS) + 16
def label_name(n):
    """Return the display name for class or battery id ``n``.

    ``NOISE_NAMES`` is consulted first, then ``COMP_LABELS``; an id found in
    neither table is rendered as ``"?<n>"``.
    """
    if n in NOISE_NAMES:
        return NOISE_NAMES[n]
    if n in COMP_LABELS:
        return COMP_LABELS[n]
    return f"?{n}"
print("=" * 78, "PHASE J''' β€” PER-PATCH AXIS-FEATURE SCANNER", "=" * 78, sep="\n")
print(f"Subset: {N_BATTERIES_SUB} batteries, phase={SUBSET_PHASE}")
print("Per tile: 256 patches Γ— 32 V Γ— ~27 axes = ~221K activations")
# ══════════════════════════════════════════════════════════════════════
# Codebook calibration — use the new batched + per-patch API
# ══════════════════════════════════════════════════════════════════════
print("\nCalibrating codebooks (batched, per-patch averaging)...")
t0 = time.time()
# Fixed-seed Gaussian images so the calibration set is the same on every run.
g = torch.Generator()
g.manual_seed(42)
calib_imgs = torch.randn(512, 3, 64, 64, generator=g)
targets = [(bid, SUBSET_PHASE) for bid in SUBSET_BATTERY_IDS]
codebooks_dict = array_model.compute_axis_codebooks(
    targets=targets,
    calibration_images=calib_imgs,
    sample_agg='mean',
    patch_agg='mean',  # average across the patches of each image
    batch_size=64,
)
# Re-key by battery id alone and move each codebook next to the model.
codebooks = {}
for bid in SUBSET_BATTERY_IDS:
    codebooks[bid] = codebooks_dict[(bid, SUBSET_PHASE)].to(DEVICE)
for bid in SUBSET_BATTERY_IDS:
    n_axes_of_bank = codebooks[bid].shape[0]
    print(f" battery {bid:>2} ({label_name(bid):<22}): {n_axes_of_bank} axes")
# Widest codebook; per-bank features are zero-padded up to this many axes.
MAX_AXES = max(cb.shape[0] for cb in codebooks.values())
print(f"Calibration time: {time.time() - t0:.1f}s, MAX_AXES: {MAX_AXES}")
# ══════════════════════════════════════════════════════════════════════
# Per-patch tile scan — uses encode_axes WITHOUT patch_idx
# ══════════════════════════════════════════════════════════════════════
@torch.no_grad()
def perpatch_tile_scan(image, codebooks, battery_ids, tile_size=TILE_SIZE):
    """Scan an image tile by tile and condense per-patch axis activations.

    The image ``[C, H, W]`` is cut into non-overlapping ``tile_size`` squares
    (rows/columns that do not fill a whole tile are dropped). Every tile is
    encoded against every bank in ``battery_ids`` without ``patch_idx``; the
    result is unpacked as a 4-D tensor ``[tiles, patches, V, axes]`` (the
    file's notes give 256 patches x 32 V rows -- not checked here). The full
    activations are not kept; they are reduced right away.

    For each (tile, bank, axis) five statistics are taken over the joint
    patch x V dimension, in this order: max, mean, std, mean of the top 10
    values, and the entropy of a softmax over the joint dimension divided by
    ``log(patches * V)``.

    Args:
        image: Tensor ``[C, H, W]``.
        codebooks: Mapping battery id -> codebook tensor whose first
            dimension is that bank's number of axes.
        battery_ids: Banks to scan, in output order.
        tile_size: Tile side length in pixels.

    Returns:
        CPU float32 tensor ``[n_tiles, len(battery_ids), MAX_AXES, N_STATS]``.
        Banks with fewer than ``MAX_AXES`` axes leave the remaining axis slots
        at zero.
    """
    n_channels, height, width = image.shape
    tile_rows, tile_cols = height // tile_size, width // tile_size
    n_tiles = tile_rows * tile_cols
    # unfold twice -> [C, rows, cols, t, t]; bring the grid dims to the front
    # and merge them so that tiles are listed in row-major order.
    grid = image.unfold(1, tile_size, tile_size).unfold(2, tile_size, tile_size)
    tiles = (grid.permute(1, 2, 0, 3, 4).contiguous()
             .reshape(n_tiles, n_channels, tile_size, tile_size)
             .to(DEVICE))
    out = torch.zeros(n_tiles, len(battery_ids), MAX_AXES, N_STATS,
                      dtype=torch.float32)
    tile_batch = 32  # kept small because per-patch activations are large
    for bank_pos, bid in enumerate(battery_ids):
        codebook = codebooks[bid]
        n_axes_i = codebook.shape[0]
        for start in range(0, n_tiles, tile_batch):
            stop = min(start + tile_batch, n_tiles)
            acts = array_model.encode_axes(
                images=tiles[start:stop], battery_idx=bid,
                phase=SUBSET_PHASE, codebook=codebook,
            )
            n_batch, n_patches, n_vrows, n_ax = acts.shape
            n_joint = n_patches * n_vrows
            # Merge patches and V rows into one axis; the reductions below
            # run on the CPU copy.
            joint = acts.reshape(n_batch, n_joint, n_ax).cpu()
            # Low entropy: a few (patch, V-row) positions dominate; high
            # entropy: alignment is spread evenly over the joint dimension.
            probs = F.softmax(joint, dim=1)
            entropy = -(probs * (probs + 1e-12).log()).sum(dim=1) / math.log(n_joint)
            per_axis = [
                joint.max(dim=1).values,
                joint.mean(dim=1),
                joint.std(dim=1),
                joint.topk(min(10, n_joint), dim=1).values.mean(dim=1),
                entropy,
            ]
            out[start:stop, bank_pos, :n_axes_i, :] = torch.stack(per_axis, dim=-1)
    return out
# Number of statistics stored per (bank, axis) by perpatch_tile_scan.
N_STATS = 5 # max, mean, std, top10_mean, entropy
print(f"\nN_STATS per (bank, axis) per tile: {N_STATS}")
# ══════════════════════════════════════════════════════════════════════
# Build feature bank
# ══════════════════════════════════════════════════════════════════════
RESOLUTIONS = [256, 512, 1024]
N_IMAGES_PER_LABEL = 24
print(f"\nBuilding per-patch axis-stats scan bank...")
# scan_bank[res][(label, img_idx)] -> stats tensor from perpatch_tile_scan
scan_bank = {}
t0 = time.time()
for res in RESOLUTIONS:
    scan_bank[res] = {}
    n_tiles = (res // TILE_SIZE) ** 2
    print(f"\n res={res} ({n_tiles} tiles per image):")
    # Pure-noise classes: labels 0-15, one generator each.
    for nt in range(16):
        for img_idx in range(N_IMAGES_PER_LABEL):
            # Distinct seed per (resolution, noise type, image index).
            seed = 1_000_000 + res * 100 + nt * 100 + img_idx
            img = gen_noise(nt, res, seed)
            stats = perpatch_tile_scan(img, codebooks, SUBSET_BATTERY_IDS)
            scan_bank[res][(nt, img_idx)] = stats
        print(f" {label_name(nt):<22} done")
    # Composite classes: (number of zones, class label) for the zone mattes.
    for zone_n, zone_lbl in [(4, 16), (9, 17), (16, 18)]:
        side = int(math.sqrt(zone_n))
        # A grid is only usable when it divides the image evenly into
        # even-sized cells. None of 256/512/1024 is divisible by 3, so the
        # 3x3 grid (zone_9) is skipped at every resolution listed above.
        if res % side != 0 or (res // side) % 2 != 0:
            print(f" {label_name(zone_lbl):<22} SKIP")
            continue
        for img_idx in range(N_IMAGES_PER_LABEL):
            seed = 2_000_000 + res * 100 + zone_n * 10 + img_idx
            img, _, _ = gen_zone_matte(res, zone_n, seed)
            stats = perpatch_tile_scan(img, codebooks, SUBSET_BATTERY_IDS)
            scan_bank[res][(zone_lbl, img_idx)] = stats
        print(f" {label_name(zone_lbl):<22} done")
print(f"\nTotal scan time: {time.time() - t0:.1f}s")
# ══════════════════════════════════════════════════════════════════════
# Feature builders — A''' summary, B''' attn-pool over tiles
# ══════════════════════════════════════════════════════════════════════
def perpatch_summary_features(stats_scan):
    """Collapse the tile dimension into one flat feature vector.

    Args:
        stats_scan: Tensor whose first dimension indexes tiles, e.g.
            ``[n_tiles, n_banks, MAX_AXES, N_STATS]``.

    Returns:
        1-D tensor with twice as many entries as one tile has. For every
        per-tile entry it holds the mean over tiles followed directly by the
        max over tiles.
    """
    tile_mean = stats_scan.mean(dim=0)
    tile_max, _ = torch.max(stats_scan, dim=0)
    # Stacking on a new last axis interleaves (mean, max) per entry.
    paired = torch.stack((tile_mean, tile_max), dim=-1)
    return paired.flatten()
def perpatch_tile_grid_features(stats_scan, max_tiles=16, generator=None):
    """Build a fixed-size set of per-tile feature rows for attention pooling.

    Each tile's (bank, axis, stat) block is flattened into one row.

    Args:
        stats_scan: Tensor ``[n_tiles, n_banks, n_axes, n_stats]``.
        max_tiles: Number of rows in the result.
        generator: Optional ``torch.Generator`` for the random tile subset.
            ``None`` (the default) draws from the global RNG, as before; pass
            a seeded generator to make the selection reproducible.

    Returns:
        Tensor ``[max_tiles, n_banks * n_axes * n_stats]``. With at least
        ``max_tiles`` tiles, a random subset of ``max_tiles`` rows in random
        order; otherwise all rows followed by zero rows of padding.
    """
    n_tiles, n_banks, n_ax, n_st = stats_scan.shape
    width = n_banks * n_ax * n_st
    flat = stats_scan.reshape(n_tiles, width)
    if n_tiles >= max_tiles:
        idx = torch.randperm(n_tiles, generator=generator)[:max_tiles]
        return flat[idx]
    # new_zeros inherits dtype and device from the input, so the padding can
    # always be concatenated with it (torch.zeros would be float32 on CPU).
    pad = flat.new_zeros(max_tiles - n_tiles, width)
    return torch.cat([flat, pad], dim=0)
print("\nBuilding feature tensors per resolution...")
# Number of tile rows handed to the attention-pool classifier per image.
MAX_TILES = 16
features_A_by_res, features_B_by_res, labels_by_res = {}, {}, {}
for res in RESOLUTIONS:
    # Keys are (label, img_idx); values are the per-image stats tensors.
    bank = scan_bank[res]
    feat_A = [perpatch_summary_features(stats) for stats in bank.values()]
    feat_B = [perpatch_tile_grid_features(stats, max_tiles=MAX_TILES)
              for stats in bank.values()]
    labs = [lbl for lbl, _ in bank]
    features_A_by_res[res] = torch.stack(feat_A)
    features_B_by_res[res] = torch.stack(feat_B)
    labels_by_res[res] = torch.tensor(labs)
    n_samples = features_A_by_res[res].shape[0]
    dim_A = features_A_by_res[res].shape[1]
    shape_B = tuple(features_B_by_res[res].shape[1:])
    print(f" res={res}: {n_samples} samples, "
          f"A''' feat {dim_A}-dim, "
          f"B''' feat {shape_B}")
# ══════════════════════════════════════════════════════════════════════
# Classifiers (same as J/J'/J'' for fair comparison)
# ══════════════════════════════════════════════════════════════════════
class SummaryMLP(nn.Module):
    """One-hidden-layer MLP classifier over a flat feature vector."""

    def __init__(self, in_dim, hidden=128, n_classes=N_CLASSES):
        """Create Linear(in_dim, hidden) -> ReLU -> Linear(hidden, n_classes)."""
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for ``x`` (last dimension ``in_dim``)."""
        logits = self.net(x)
        return logits
class AttentionPoolMLP(nn.Module):
    """Softmax-attention pooling over tiles followed by an MLP classifier."""

    def __init__(self, n_features, hidden=128, n_classes=N_CLASSES):
        """Create the per-tile scorer and the classifier head."""
        super().__init__()
        # One scalar score per tile, computed from that tile's features.
        self.attn_scorer = nn.Linear(n_features, 1)
        head = [
            nn.Linear(n_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Pool ``x`` over dim 1 with learned weights, then classify.

        ``x`` is indexed as ``[batch, tiles, n_features]``: the softmax runs
        over dim 1 and the weighted sum removes that dimension.
        """
        tile_scores = self.attn_scorer(x).squeeze(-1)
        tile_weights = tile_scores.softmax(dim=1)
        pooled = (tile_weights.unsqueeze(-1) * x).sum(dim=1)
        return self.classifier(pooled)
def train_classifier(model_cls, train_x, train_y, test_x, test_y,
                     in_spec, n_epochs=200, lr=1e-2):
    """Train ``model_cls(in_spec)`` with Adam and report accuracies.

    Side effect: calls ``torch.manual_seed(42)``, which resets the global
    torch RNG so that initialisation and batch order repeat across calls.

    Args:
        model_cls: Callable building the model from ``in_spec``.
        train_x, train_y: Training inputs and integer class labels.
        test_x, test_y: Held-out inputs and labels.
        in_spec: Single argument passed to ``model_cls``.
        n_epochs: Number of passes over the training set; may be 0.
        lr: Adam learning rate.

    Returns:
        ``(test_acc, per_class, train_hist, test_hist)``: the final test
        accuracy, a dict mapping each class present in ``test_y`` to its
        test accuracy, and the per-epoch train / test accuracy lists.
    """
    torch.manual_seed(42)
    clf = model_cls(in_spec)
    optimizer = torch.optim.Adam(clf.parameters(), lr=lr)
    batch_size = 128
    n_train = train_x.shape[0]

    def accuracy(x, y):
        return (clf(x).argmax(dim=1) == y).float().mean().item()

    train_hist, test_hist = [], []
    for _ in range(n_epochs):
        perm = torch.randperm(n_train)
        clf.train()
        for i in range(0, n_train, batch_size):
            idx = perm[i:i + batch_size]
            loss = F.cross_entropy(clf(train_x[idx]), train_y[idx])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        clf.eval()
        with torch.no_grad():
            train_hist.append(accuracy(train_x, train_y))
            test_hist.append(accuracy(test_x, test_y))

    clf.eval()
    with torch.no_grad():
        preds = clf(test_x).argmax(dim=1)
    hits = preds == test_y
    # Taken from the final predictions rather than test_hist[-1], so that
    # n_epochs=0 returns the untrained accuracy instead of raising IndexError.
    # For n_epochs > 0 this is the same value as the last history entry.
    final_acc = hits.float().mean().item()
    per_class = {}
    for c in torch.unique(test_y).tolist():
        members = test_y == c
        per_class[c] = (hits & members).sum().item() / max(1, members.sum().item())
    return final_acc, per_class, train_hist, test_hist
# ══════════════════════════════════════════════════════════════════════
# Train + compare against all priors
# ══════════════════════════════════════════════════════════════════════
# Result files written by the earlier variants of this experiment.
ref_paths = {
    'cell_j': EXP_DIR / "results_expJ.json",
    'cell_jp': EXP_DIR / "results_expJ_axes.json",
    'cell_jpp': EXP_DIR / "results_expJ_vstats.json",
}
# refs[cell][resolution] -> {'A': accuracy, 'B': accuracy}, for every prior
# experiment whose results file exists on disk.
refs = {}
for k, p in ref_paths.items():
    if p.exists():
        with open(p, encoding="utf-8") as f:
            r = json.load(f)
        refs[k] = {int(res): {'A': v['accuracy_A'], 'B': v['accuracy_B']}
                   for res, v in r['per_resolution'].items()}
        print(f"\nLoaded {k}: {p}")

# (key in refs, cell name, feature description) of each earlier experiment.
_PRIOR_CELLS = [
    ('cell_j', "J", "MSE"),
    ('cell_jp', "J'", "max-axes"),
    ('cell_jpp', "J''", "V-stats"),
]


def _print_prior_deltas(clf_key, acc, res):
    """Print ``acc`` against each prior cell's accuracy for classifier ``clf_key``.

    ``clf_key`` is 'A' or 'B'. A prior that was not loaded, or that has no
    entry for ``res``, is skipped instead of raising KeyError.
    """
    for ref_key, cell, kind in _PRIOR_CELLS:
        prior = refs.get(ref_key, {}).get(res)
        if prior is None:
            continue
        d = acc - prior[clf_key]
        print(f" vs Cell {cell} {clf_key} ({kind}): {prior[clf_key]:.1%} Ξ” {d:+.1%}")


results = {}
for res in RESOLUTIONS:
    print(f"\n{'─' * 78}")
    print(f"Resolution {res}Γ—{res}")
    print(f"{'─' * 78}")
    # Fixed-seed 80/20 split over the images of this resolution.
    n_items = features_A_by_res[res].shape[0]
    rng = np.random.RandomState(42)
    indices = rng.permutation(n_items)
    n_train = int(n_items * 0.8)
    train_idx, test_idx = indices[:n_train], indices[n_train:]
    labels = labels_by_res[res]
    # A''' -- standardise with training-set statistics only.
    xA = features_A_by_res[res]
    mA, sA = xA[train_idx].mean(dim=0), xA[train_idx].std(dim=0).clamp(min=1e-8)
    xA = (xA - mA) / sA
    accA, per_class_A, tA, vA = train_classifier(
        SummaryMLP, xA[train_idx], labels[train_idx],
        xA[test_idx], labels[test_idx], in_spec=xA.shape[1])
    # B''' -- statistics are pooled over all tile rows of the training images.
    xB = features_B_by_res[res]
    flat = xB[train_idx].reshape(-1, xB.shape[-1])
    mB, sB = flat.mean(dim=0), flat.std(dim=0).clamp(min=1e-8)
    xB = (xB - mB) / sB
    accB, per_class_B, tB, vB = train_classifier(
        AttentionPoolMLP, xB[train_idx], labels[train_idx],
        xB[test_idx], labels[test_idx], in_spec=xB.shape[-1])
    print(f" A''' (per-patch summary): test={accA:.1%}")
    _print_prior_deltas('A', accA, res)
    print(f"\n B''' (per-patch attn): test={accB:.1%}")
    _print_prior_deltas('B', accB, res)
    print(f"\n {'Class':<22} {'A':>9} {'B':>9} {'Ξ”(B-A)':>9}")
    for c in sorted(per_class_A.keys()):
        a = per_class_A[c]
        b = per_class_B.get(c, 0.0)
        # Mark differences larger than one percentage point.
        sym = '+' if b > a + 0.01 else '-' if b < a - 0.01 else ' '
        print(f" {label_name(c):<22} {a:>9.1%} {b:>9.1%} {sym}{abs(b-a):>8.1%}")
    results[res] = {
        'accuracy_A': accA, 'accuracy_B': accB,
        'per_class_A': {label_name(c): per_class_A[c] for c in per_class_A},
        'per_class_B': {label_name(c): per_class_B.get(c, 0.0) for c in per_class_A},
        'train_curve_A': tA, 'test_curve_A': vA,
        'train_curve_B': tB, 'test_curve_B': vB,
    }
# ══════════════════════════════════════════════════════════════════════
# Plots + verdict
# ══════════════════════════════════════════════════════════════════════
# One panel per classifier: test-accuracy curves for every resolution plus
# the uniform-guess baseline.
fig, axes = plt.subplots(1, 2, figsize=(20, 6))
panels = [
    ("A''' per-patch summary MLP", 'test_curve_A', 'accuracy_A'),
    ("B''' per-patch attn-pool MLP", 'test_curve_B', 'accuracy_B'),
]
chance = 1 / N_CLASSES
for ax, (clf_label, key_curve, key_acc) in zip(axes, panels):
    for res in RESOLUTIONS:
        final_acc = results[res][key_acc]
        ax.plot(results[res][key_curve],
                label=f'{res} ({final_acc:.1%})',
                linewidth=1.5, alpha=0.85)
    ax.axhline(chance, color='gray', linestyle='--', linewidth=1,
               label=f'Random ({chance:.1%})')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Test accuracy')
    ax.set_title(clf_label)
    ax.legend(loc='lower right')
    ax.grid(linestyle=':', alpha=0.5)
    ax.set_ylim(0, 1.05)
plt.tight_layout()
plt.savefig(EXP_DIR / 'expJ_perpatch_curves.png', dpi=120, bbox_inches='tight')
plt.show()
print(f"\n{'=' * 78}")
print(f"PHASE J''' VERDICT β€” per-patch axis features")
print(f"{'=' * 78}")


def _ref_acc(cell, res, clf_key):
    """Return a prior cell's accuracy, or NaN when it is not available."""
    return refs.get(cell, {}).get(res, {}).get(clf_key, float('nan'))


# The comparison table and verdict need the Cell J (MSE) baseline.
if 'cell_j' in refs:
    print(f"\n{'Res':<6} | {'MSE':>7} {'maxax':>7} {'vstat':>7} {'perpatch':>9} "
          f"| {'MSE':>7} {'maxax':>7} {'vstat':>7} {'perpatch':>9}")
    print(f" | {'A':>7} {'A':>7} {'A':>7} {'A':>9} "
          f"| {'B':>7} {'B':>7} {'B':>7} {'B':>9}")
    print("-" * 100)
    for res in RESOLUTIONS:
        # Missing reference entries print as nan instead of raising KeyError.
        ja = _ref_acc('cell_j', res, 'A')
        jb = _ref_acc('cell_j', res, 'B')
        pa = _ref_acc('cell_jp', res, 'A')
        pb = _ref_acc('cell_jp', res, 'B')
        va = _ref_acc('cell_jpp', res, 'A')
        vb = _ref_acc('cell_jpp', res, 'B')
        ka = results[res]['accuracy_A']
        kb = results[res]['accuracy_B']
        print(f"{str(res):<6} | {ja:>6.1%} {pa:>6.1%} {va:>6.1%} {ka:>8.1%} "
              f"| {jb:>6.1%} {pb:>6.1%} {vb:>6.1%} {kb:>8.1%}")
    # Average only over resolutions the baseline actually covers.
    compared = [r for r in RESOLUTIONS if r in refs['cell_j']]
    if compared:
        avg_dA_mse = np.mean([results[r]['accuracy_A'] - refs['cell_j'][r]['A']
                              for r in compared])
        avg_dB_mse = np.mean([results[r]['accuracy_B'] - refs['cell_j'][r]['B']
                              for r in compared])
        print(f"\nMean delta vs Cell J (MSE baseline):")
        print(f" A (summary): {avg_dA_mse:+.1%}")
        print(f" B (attn): {avg_dB_mse:+.1%}")
        print()
        if avg_dA_mse > 0.03 and avg_dB_mse > 0.03:
            print("βœ“ PER-PATCH AXIS FEATURES BEAT MSE on both classifiers.")
            print(" The 256Γ— spatial signal was the missing piece.")
        elif avg_dA_mse > 0.03 or avg_dB_mse > 0.03:
            print("~ PER-PATCH AXIS FEATURES BEAT MSE on one classifier.")
            print(" Mixed result β€” investigate per-class.")
        # Inclusive bounds: a delta of exactly +/-0.03 counts as a match
        # rather than falling through to the "underperform" branch.
        elif abs(avg_dA_mse) <= 0.03 and abs(avg_dB_mse) <= 0.03:
            print("= PER-PATCH AXIS FEATURES MATCH MSE.")
            print(" Comparable performance, axis pipeline now competitive.")
        else:
            print("βœ— PER-PATCH AXIS FEATURES UNDERPERFORM MSE.")
            print(" Even with 256Γ— more spatial data, axes lose to MSE here.")
            print(" Reconstruction error remains the structurally optimal signal")
            print(" for noise discrimination given how the banks were trained.")
    else:
        print("\nNo Cell J reference covers these resolutions; verdict skipped.")
# Persist the headline numbers in the same per_resolution layout the loader
# for the earlier cells' files reads.
with open(EXP_DIR / 'results_expJ_perpatch.json', 'w') as f:
    json.dump({
        'subset_battery_ids': SUBSET_BATTERY_IDS,
        'subset_phase': SUBSET_PHASE,
        'n_classes': N_CLASSES,
        'n_stats': N_STATS,
        'stat_names': ['max', 'mean', 'std', 'top10_mean', 'entropy'],
        'codebook_sizes': {bid: codebooks[bid].shape[0]
                           for bid in SUBSET_BATTERY_IDS},
        'codebook_calibration': 'mean+mean (per-patch averaging)',
        'max_axes_padded_to': MAX_AXES,
        'per_resolution': {
            str(res): {
                'accuracy_A': results[res]['accuracy_A'],
                'accuracy_B': results[res]['accuracy_B'],
                'per_class_A': results[res]['per_class_A'],
                'per_class_B': results[res]['per_class_B'],
            }
            for res in RESOLUTIONS
        },
    }, f, indent=2, default=str)
print(f"\nSaved results_expJ_perpatch.json")