AbstractPhil commited on
Commit
e7065b5
Β·
verified Β·
1 Parent(s): 51ff6e9

Create 3_johanna_model_trainer.py

Browse files
Files changed (1) hide show
  1. 3_johanna_model_trainer.py +1212 -0
3_johanna_model_trainer.py ADDED
@@ -0,0 +1,1212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Johanna F-class Miniature Trainer
3
+ ==================================
4
+ Minimum-viable-battery research target. Sweeps small PatchSVAE
5
+ configurations to find the floor where self-assembly breaks.
6
+
7
+ Naming: johanna-F-S{img_size}-V{V}-D{D}-h{hidden}-d{depth}-p{patch}
8
+
9
+ Uploads to: AbstractPhil/geolip-svae-batteries
10
+ Structure: {config_name}/checkpoints/epoch_NNNN.pt
11
+ {config_name}/tensorboard/...
12
+ {config_name}/config.json
13
+ {config_name}/final_report.json
14
+
15
+ Trains from scratch (no pretrained init β€” that would mask failure modes).
16
+ Gaussian-only foundation (Tier 0) for fast MSE floor discovery.
17
+ Full 16-type dataset available via --all_types flag for battery-behavior
18
+ verification on viable candidates.
19
+
20
+ Diagnostic battery captured per report step:
21
+ Recon: train_recon, test_mse (per-noise-type when all_types)
22
+ Geometry: row_cv, ratio, erank, Sβ‚€, S_min, S_delta (binding)
23
+ Alpha: alpha mean/std (per cross-attn layer)
24
+ CV health: in 0.13-0.30 band?, proximity to target
25
+ Stability: grad_norm, recon_w, prox
26
+ Training: lr, epoch_time, batch_time
27
+
28
+ All scalars logged to TensorBoard. Full run JSON at finish.
29
+ """
30
+
31
+ import os
32
+ import sys
33
+ import json
34
+ import math
35
+ import time
36
+ import argparse
37
+ from dataclasses import dataclass, asdict, field
38
+ from typing import Optional, List, Dict
39
+
40
+ import numpy as np
41
+ import torch
42
+ import torch.nn as nn
43
+ import torch.nn.functional as F
44
+
45
+ try:
46
+ from tqdm.auto import tqdm
47
+ _HAS_TQDM = True
48
+ except ImportError:
49
+ _HAS_TQDM = False
50
+
51
+
52
+ # ── HuggingFace auth from Colab secrets (optional) ──────────────
53
+
54
+ try:
55
+ from google.colab import userdata
56
+ os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
57
+ from huggingface_hub import login
58
+ login(token=os.environ["HF_TOKEN"])
59
+ except Exception:
60
+ pass
61
+
62
+
63
+ # ── SVD Backend (fp64 internals, matches SVAE lineage) ──────────
64
+
65
+ try:
66
+ from geolip_core.linalg.eigh import FLEigh, _FL_MAX_N
67
+ _HAS_FL = True
68
+ except ImportError:
69
+ _HAS_FL = False
70
+
71
+
72
+ def _gram_eigh_svd(A):
73
+ """Gram + torch.linalg.eigh in fp64 with diagonal regularization."""
74
+ orig_dtype = A.dtype
75
+ with torch.amp.autocast('cuda', enabled=False):
76
+ A_d = A.double()
77
+ G = torch.bmm(A_d.transpose(1, 2), A_d)
78
+ G.diagonal(dim1=-2, dim2=-1).add_(1e-12) # conditioning regularizer
79
+ eigenvalues, V = torch.linalg.eigh(G)
80
+ eigenvalues = eigenvalues.flip(-1)
81
+ V = V.flip(-1)
82
+ S = torch.sqrt(eigenvalues.clamp(min=1e-24))
83
+ U = torch.bmm(A_d, V) / S.unsqueeze(1).clamp(min=1e-16)
84
+ Vh = V.transpose(-2, -1).contiguous()
85
+ return U.to(orig_dtype), S.to(orig_dtype), Vh.to(orig_dtype)
86
+
87
+
88
+ def _svd_fp64(A):
89
+ """Auto-dispatch SVD with fp64 internals."""
90
+ B, M, N = A.shape
91
+ if _HAS_FL and N <= _FL_MAX_N and A.is_cuda:
92
+ orig_dtype = A.dtype
93
+ with torch.amp.autocast('cuda', enabled=False):
94
+ A_d = A.double()
95
+ G = torch.bmm(A_d.transpose(1, 2), A_d)
96
+ G.diagonal(dim1=-2, dim2=-1).add_(1e-12)
97
+ eigenvalues, V = FLEigh()(G.float())
98
+ eigenvalues = eigenvalues.double().flip(-1)
99
+ V = V.double().flip(-1)
100
+ S = torch.sqrt(eigenvalues.clamp(min=1e-24))
101
+ U = torch.bmm(A_d, V) / S.unsqueeze(1).clamp(min=1e-16)
102
+ Vh = V.transpose(-2, -1).contiguous()
103
+ return U.to(orig_dtype), S.to(orig_dtype), Vh.to(orig_dtype)
104
+ else:
105
+ return _gram_eigh_svd(A)
106
+
107
+
108
+ # ── Cayley-Menger CV (fp64 determinant) ──────────────────────────
109
+
110
+ def cayley_menger_vol2(points):
111
+ B, N, D = points.shape
112
+ pts = points.double()
113
+ gram = torch.bmm(pts, pts.transpose(1, 2))
114
+ norms = torch.diagonal(gram, dim1=1, dim2=2)
115
+ d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
116
+ cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=torch.float64)
117
+ cm[:, 0, 1:] = 1.0
118
+ cm[:, 1:, 0] = 1.0
119
+ cm[:, 1:, 1:] = d2
120
+ k = N - 1
121
+ sign = (-1.0) ** (k + 1)
122
+ fact = math.factorial(k)
123
+ return sign * torch.linalg.det(cm) / ((2 ** k) * (fact ** 2))
124
+
125
+
126
+ def cv_of(emb, n_samples=200):
127
+ """CV of pentachoron volumes for a single (V, D) embedding."""
128
+ if emb.dim() != 2 or emb.shape[0] < 5:
129
+ return 0.0
130
+ N, D = emb.shape
131
+ pool = min(N, 512)
132
+ indices = torch.stack([torch.randperm(pool, device=emb.device)[:5]
133
+ for _ in range(n_samples)])
134
+ vol2 = cayley_menger_vol2(emb[:pool][indices])
135
+ valid = vol2 > 1e-20
136
+ if valid.sum() < 10:
137
+ return 0.0
138
+ vols = vol2[valid].sqrt()
139
+ return (vols.std() / (vols.mean() + 1e-8)).item()
140
+
141
+
142
+ # ── Omega Noise Dataset (16 types) ───────────────────────────────
143
+
144
+ class OmegaNoiseDataset(torch.utils.data.Dataset):
145
+ """16 noise types at arbitrary resolution, with optional type-restriction.
146
+
147
+ Args:
148
+ size: dataset length
149
+ img_size: spatial resolution
150
+ seed_rotate_every: re-seed every N calls (prevents epoch repetition)
151
+ allowed_types: iterable of int type indices, or None for all 16.
152
+ For Gaussian-only foundation use allowed_types=[0].
153
+ """
154
+ N_TYPES = 16
155
+
156
+ def __init__(self, size=1_000_000, img_size=128,
157
+ seed_rotate_every=1000, allowed_types=None):
158
+ self.size = size
159
+ self.img_size = img_size
160
+ self.seed_rotate_every = seed_rotate_every
161
+ self._rng = np.random.RandomState(42)
162
+ self._call_count = 0
163
+ self.allowed_types = (list(allowed_types) if allowed_types
164
+ else list(range(self.N_TYPES)))
165
+
166
+ def __len__(self):
167
+ return self.size
168
+
169
+ def _rotate_seed(self):
170
+ self._call_count += 1
171
+ if self._call_count % self.seed_rotate_every == 0:
172
+ new_seed = int.from_bytes(os.urandom(4), 'big')
173
+ self._rng = np.random.RandomState(new_seed)
174
+ torch.manual_seed(new_seed)
175
+
176
+ def _pink_noise(self, shape):
177
+ white = torch.randn(shape)
178
+ S = torch.fft.rfft2(white)
179
+ h, w = shape[-2], shape[-1]
180
+ fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, w // 2 + 1)
181
+ fx = torch.fft.rfftfreq(w).unsqueeze(0).expand(h, -1)
182
+ f = torch.sqrt(fx**2 + fy**2).clamp(min=1e-8)
183
+ return torch.fft.irfft2(S / f, s=(h, w))
184
+
185
+ def _brown_noise(self, shape):
186
+ white = torch.randn(shape)
187
+ S = torch.fft.rfft2(white)
188
+ h, w = shape[-2], shape[-1]
189
+ fy = torch.fft.fftfreq(h).unsqueeze(-1).expand(-1, w // 2 + 1)
190
+ fx = torch.fft.rfftfreq(w).unsqueeze(0).expand(h, -1)
191
+ f = (fx**2 + fy**2).clamp(min=1e-8)
192
+ return torch.fft.irfft2(S / f, s=(h, w))
193
+
194
+ def __getitem__(self, idx):
195
+ self._rotate_seed()
196
+ s = self.img_size
197
+ noise_type = self.allowed_types[idx % len(self.allowed_types)]
198
+
199
+ if noise_type == 0:
200
+ img = torch.randn(3, s, s)
201
+ elif noise_type == 1:
202
+ img = torch.rand(3, s, s) * 2 - 1
203
+ elif noise_type == 2:
204
+ img = (torch.rand(3, s, s) - 0.5) * 4
205
+ elif noise_type == 3:
206
+ lam = self._rng.uniform(0.5, 20.0)
207
+ img = torch.poisson(torch.full((3, s, s), lam)) / lam - 1.0
208
+ elif noise_type == 4:
209
+ img = self._pink_noise((3, s, s))
210
+ img = img / (img.std() + 1e-8)
211
+ elif noise_type == 5:
212
+ img = self._brown_noise((3, s, s))
213
+ img = img / (img.std() + 1e-8)
214
+ elif noise_type == 6:
215
+ img = torch.where(torch.rand(3, s, s) > 0.5,
216
+ torch.ones(3, s, s) * 2, torch.ones(3, s, s) * -2)
217
+ img = img + torch.randn(3, s, s) * 0.1
218
+ elif noise_type == 7:
219
+ mask = torch.rand(3, s, s) > 0.9
220
+ img = torch.randn(3, s, s) * mask.float() * 3
221
+ elif noise_type == 8:
222
+ block = self._rng.randint(2, 16)
223
+ small = torch.randn(3, s // block + 1, s // block + 1)
224
+ img = F.interpolate(small.unsqueeze(0), size=s, mode='nearest').squeeze(0)
225
+ elif noise_type == 9:
226
+ gy = torch.linspace(-2, 2, s).unsqueeze(1).expand(s, s)
227
+ gx = torch.linspace(-2, 2, s).unsqueeze(0).expand(s, s)
228
+ angle = self._rng.uniform(0, 2 * math.pi)
229
+ grad = math.cos(angle) * gx + math.sin(angle) * gy
230
+ img = grad.unsqueeze(0).expand(3, -1, -1) + torch.randn(3, s, s) * 0.5
231
+ elif noise_type == 10:
232
+ check_size = self._rng.randint(2, 16)
233
+ coords_y = torch.arange(s) // check_size
234
+ coords_x = torch.arange(s) // check_size
235
+ checker = ((coords_y.unsqueeze(1) + coords_x.unsqueeze(0)) % 2).float() * 2 - 1
236
+ img = checker.unsqueeze(0).expand(3, -1, -1) + torch.randn(3, s, s) * 0.3
237
+ elif noise_type == 11:
238
+ a = torch.randn(3, s, s)
239
+ b = torch.rand(3, s, s) * 2 - 1
240
+ alpha = self._rng.uniform(0.2, 0.8)
241
+ img = alpha * a + (1 - alpha) * b
242
+ elif noise_type == 12:
243
+ img = torch.zeros(3, s, s)
244
+ h2, w2 = s // 2, s // 2
245
+ img[:, :h2, :w2] = torch.randn(3, h2, w2)
246
+ img[:, :h2, w2:] = torch.rand(3, h2, w2) * 2 - 1
247
+ img[:, h2:, :w2] = self._pink_noise((3, h2, w2)) / 2
248
+ sp = torch.where(torch.rand(3, h2, w2) > 0.5,
249
+ torch.ones(3, h2, w2), -torch.ones(3, h2, w2))
250
+ img[:, h2:, w2:] = sp
251
+ elif noise_type == 13:
252
+ u = torch.rand(3, s, s)
253
+ img = torch.tan(math.pi * (u - 0.5)).clamp(-3, 3)
254
+ elif noise_type == 14:
255
+ img = torch.empty(3, s, s).exponential_(1.0) - 1.0
256
+ elif noise_type == 15:
257
+ u = torch.rand(3, s, s) - 0.5
258
+ img = -torch.sign(u) * torch.log1p(-2 * u.abs())
259
+ else:
260
+ raise ValueError(f"Unknown noise type {noise_type}")
261
+
262
+ return img.clamp(-4, 4).float(), noise_type
263
+
264
+
265
+ # ── Patch Utils + Small Components ───────────────────────────────
266
+
267
+ def extract_patches(images, patch_size):
268
+ B, C, H, W = images.shape
269
+ gh, gw = H // patch_size, W // patch_size
270
+ p = images.reshape(B, C, gh, patch_size, gw, patch_size)
271
+ p = p.permute(0, 2, 4, 1, 3, 5)
272
+ return p.reshape(B, gh * gw, C * patch_size * patch_size), gh, gw
273
+
274
+
275
+ def stitch_patches(patches, gh, gw, patch_size):
276
+ B = patches.shape[0]
277
+ p = patches.reshape(B, gh, gw, 3, patch_size, patch_size)
278
+ p = p.permute(0, 3, 1, 4, 2, 5)
279
+ return p.reshape(B, 3, gh * patch_size, gw * patch_size)
280
+
281
+
282
+ class BoundarySmooth(nn.Module):
283
+ def __init__(self, channels=3, mid=16):
284
+ super().__init__()
285
+ self.net = nn.Sequential(
286
+ nn.Conv2d(channels, mid, 3, padding=1), nn.GELU(),
287
+ nn.Conv2d(mid, channels, 3, padding=1))
288
+ nn.init.zeros_(self.net[-1].weight)
289
+ nn.init.zeros_(self.net[-1].bias)
290
+
291
+ def forward(self, x):
292
+ return x + self.net(x)
293
+
294
+
295
+ class SpectralCrossAttention(nn.Module):
296
+ """Multiplicative cross-attention on S vectors with bounded alpha."""
297
+ def __init__(self, D, n_heads=4, max_alpha=0.2, alpha_init=-2.0):
298
+ super().__init__()
299
+ heads = min(n_heads, D) if D % n_heads != 0 else n_heads
300
+ # ensure divisibility
301
+ while D % heads != 0 and heads > 1:
302
+ heads -= 1
303
+ self.n_heads = heads
304
+ self.head_dim = D // heads
305
+ self.max_alpha = max_alpha
306
+ self.qkv = nn.Linear(D, 3 * D)
307
+ self.out_proj = nn.Linear(D, D)
308
+ self.norm = nn.LayerNorm(D)
309
+ self.scale = self.head_dim ** -0.5
310
+ self.alpha_logits = nn.Parameter(torch.full((D,), alpha_init))
311
+
312
+ @property
313
+ def alpha(self):
314
+ return self.max_alpha * torch.sigmoid(self.alpha_logits)
315
+
316
+ def forward(self, S):
317
+ B, N, D = S.shape
318
+ S_normed = self.norm(S)
319
+ qkv = self.qkv(S_normed).reshape(B, N, 3, self.n_heads, self.head_dim)
320
+ qkv = qkv.permute(2, 0, 3, 1, 4)
321
+ q, k, v = qkv[0], qkv[1], qkv[2]
322
+ attn = (q @ k.transpose(-2, -1)) * self.scale
323
+ attn = attn.softmax(dim=-1)
324
+ out = (attn @ v).transpose(1, 2).reshape(B, N, D)
325
+ gate = torch.tanh(self.out_proj(out))
326
+ return S * (1.0 + self.alpha.unsqueeze(0).unsqueeze(0) * gate)
327
+
328
+
329
+ # ── PatchSVAE-F (miniature) ──────────────────────────────────────
330
+
331
+ class PatchSVAE_F(nn.Module):
332
+ """F-class miniature PatchSVAE.
333
+
334
+ Architecture matches Fresnel/Johanna reference verbatim, sized down.
335
+ All defense-stack mechanisms preserved:
336
+ - Sphere-normalized rows (F.normalize after enc_out)
337
+ - fp64 SVD with 1e-24 / 1e-16 floors and 1e-12 diag reg
338
+ - Orthogonal init on enc_out
339
+ - Multiplicative cross-attention with bounded alpha (max 0.2)
340
+ - Zero-initialized boundary smoothing
341
+ - No BatchNorm, no Dropout
342
+ """
343
+ def __init__(self, matrix_v=64, D=8, patch_size=16, hidden=128,
344
+ depth=1, n_cross_layers=1, n_heads=4,
345
+ max_alpha=0.2, alpha_init=-2.0):
346
+ super().__init__()
347
+ self.matrix_v = matrix_v
348
+ self.D = D
349
+ self.patch_size = patch_size
350
+ self.patch_dim = 3 * patch_size * patch_size
351
+ self.mat_dim = matrix_v * D
352
+ self.hidden = hidden
353
+ self.depth = depth
354
+ self.n_cross_layers = n_cross_layers
355
+
356
+ self.enc_in = nn.Linear(self.patch_dim, hidden)
357
+ self.enc_blocks = nn.ModuleList([
358
+ nn.Sequential(nn.LayerNorm(hidden), nn.Linear(hidden, hidden),
359
+ nn.GELU(), nn.Linear(hidden, hidden))
360
+ for _ in range(depth)])
361
+ self.enc_out = nn.Linear(hidden, self.mat_dim)
362
+ nn.init.orthogonal_(self.enc_out.weight)
363
+
364
+ self.dec_in = nn.Linear(self.mat_dim, hidden)
365
+ self.dec_blocks = nn.ModuleList([
366
+ nn.Sequential(nn.LayerNorm(hidden), nn.Linear(hidden, hidden),
367
+ nn.GELU(), nn.Linear(hidden, hidden))
368
+ for _ in range(depth)])
369
+ self.dec_out = nn.Linear(hidden, self.patch_dim)
370
+
371
+ self.cross_attn = nn.ModuleList([
372
+ SpectralCrossAttention(D, n_heads=n_heads,
373
+ max_alpha=max_alpha, alpha_init=alpha_init)
374
+ for _ in range(n_cross_layers)])
375
+ self.boundary_smooth = BoundarySmooth(channels=3, mid=16)
376
+
377
+ def encode_patches(self, patches):
378
+ B, N, _ = patches.shape
379
+ flat = patches.reshape(B * N, -1)
380
+ h = F.gelu(self.enc_in(flat))
381
+ for block in self.enc_blocks:
382
+ h = h + block(h)
383
+ M = self.enc_out(h).reshape(B * N, self.matrix_v, self.D)
384
+ M = F.normalize(M, dim=-1)
385
+ U, S, Vt = _svd_fp64(M)
386
+ U = U.reshape(B, N, self.matrix_v, self.D)
387
+ S = S.reshape(B, N, self.D)
388
+ Vt = Vt.reshape(B, N, self.D, self.D)
389
+ M = M.reshape(B, N, self.matrix_v, self.D)
390
+ S_coord = S
391
+ for layer in self.cross_attn:
392
+ S_coord = layer(S_coord)
393
+ return {'U': U, 'S_orig': S, 'S': S_coord, 'Vt': Vt, 'M': M}
394
+
395
+ def decode_patches(self, U, S, Vt):
396
+ B, N, V, D = U.shape
397
+ U_flat = U.reshape(B * N, V, D)
398
+ S_flat = S.reshape(B * N, D)
399
+ Vt_flat = Vt.reshape(B * N, D, D)
400
+ M_hat = torch.bmm(U_flat * S_flat.unsqueeze(1), Vt_flat)
401
+ h = F.gelu(self.dec_in(M_hat.reshape(B * N, -1)))
402
+ for block in self.dec_blocks:
403
+ h = h + block(h)
404
+ return self.dec_out(h).reshape(B, N, -1)
405
+
406
+ def forward(self, images):
407
+ patches, gh, gw = extract_patches(images, self.patch_size)
408
+ svd = self.encode_patches(patches)
409
+ decoded = self.decode_patches(svd['U'], svd['S'], svd['Vt'])
410
+ recon = stitch_patches(decoded, gh, gw, self.patch_size)
411
+ recon = self.boundary_smooth(recon)
412
+ return {'recon': recon, 'svd': svd, 'gh': gh, 'gw': gw}
413
+
414
+ @staticmethod
415
+ def effective_rank(S):
416
+ p = S / (S.sum(-1, keepdim=True) + 1e-8)
417
+ p = p.clamp(min=1e-8)
418
+ return (-(p * p.log()).sum(-1)).exp()
419
+
420
+
421
+ # ── Config dataclass ─────────────────────────────────────────────
422
+
423
+ @dataclass
424
+ class RunConfig:
425
+ """F-class run configuration.
426
+
427
+ Naming convention: johanna-F-S{img_size}-V{V}-D{D}-h{hidden}-d{depth}-p{patch}
428
+ """
429
+ # Architecture (the sweep axes)
430
+ matrix_v: int = 64
431
+ D: int = 8
432
+ patch_size: int = 16
433
+ hidden: int = 128
434
+ depth: int = 1
435
+ n_cross_layers: int = 1
436
+ n_heads: int = 4
437
+ max_alpha: float = 0.2
438
+ alpha_init: float = -2.0
439
+
440
+ # Training
441
+ img_size: int = 128
442
+ batch_size: int = 128
443
+ lr: float = 1e-3
444
+ epochs: int = 20
445
+ weight_decay: float = 0.0 # Phil: always pure Adam
446
+
447
+ # Loss / soft-hand
448
+ # Soft-hand guides against CV-EMA β€” geometric coherence signal.
449
+ # CV (Cayley-Menger pentachoron volume CV) measures whether the
450
+ # sphere-normalized rows are arranged with geometric consistency.
451
+ # We don't force CV toward a specific value β€” we track its own EMA
452
+ # and reward CV being near its own trajectory (auto-attenuating).
453
+ #
454
+ # This is the correct reading from Phil's SVAE lineage:
455
+ # - CV tells us the geometry is COHERENT (relational guidepost)
456
+ # - Recon MSE tells us the arrangement is VALID (reversible)
457
+ # - Both are needed; recon alone doesn't guarantee geometric structure
458
+ use_cv_ema: bool = True
459
+ cv_ema_alpha: float = 0.01 # EMA smoothing (slow is better)
460
+ cv_alignment_epochs: int = 1 # Pure-MSE epochs before soft-hand activates
461
+ cv_measure_every: int = 25 # Measure CV every N batches
462
+ cv_sigma_scale: float = 0.3 # Proximity width = sigma_scale Γ— cv_ema
463
+ boost: float = 0.5 # Max recon-weight boost when CV near EMA
464
+ cross_attn_clip: float = 0.5
465
+
466
+ # Data
467
+ allowed_types: List[int] = field(default_factory=lambda: [0]) # Gaussian only
468
+ train_size: int = 100_000
469
+ val_size: int = 2000
470
+ num_workers: int = 4
471
+
472
+ # Reporting
473
+ report_every: int = 500 # TB log cadence (batches)
474
+ major_report_every: int = 10 # Console major-report cadence (epochs)
475
+ save_every: int = 5 # Checkpoint save cadence (epochs)
476
+ seed: int = 42
477
+
478
+ # Upload
479
+ hf_repo: str = "AbstractPhil/geolip-svae-batteries"
480
+ upload: bool = True
481
+
482
+ def name(self) -> str:
483
+ return (f"johanna-F-S{self.img_size}-V{self.matrix_v}"
484
+ f"-D{self.D}-h{self.hidden}-d{self.depth}"
485
+ f"-p{self.patch_size}")
486
+
487
+
488
+ # ── Training runner ──────────────────────────────────────────────
489
+
490
+ def run(cfg: RunConfig, out_root: str = "/content/johanna_F_runs"):
491
+ torch.manual_seed(cfg.seed)
492
+ np.random.seed(cfg.seed)
493
+ torch.set_float32_matmul_precision('high')
494
+
495
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
496
+ run_name = cfg.name()
497
+ run_dir = os.path.join(out_root, run_name)
498
+ ckpt_dir = os.path.join(run_dir, "checkpoints")
499
+ tb_dir = os.path.join(run_dir, "tensorboard")
500
+ os.makedirs(ckpt_dir, exist_ok=True)
501
+ os.makedirs(tb_dir, exist_ok=True)
502
+
503
+ # Save config snapshot
504
+ with open(os.path.join(run_dir, "config.json"), 'w') as f:
505
+ json.dump(asdict(cfg), f, indent=2)
506
+
507
+ # Model
508
+ model = PatchSVAE_F(
509
+ matrix_v=cfg.matrix_v, D=cfg.D, patch_size=cfg.patch_size,
510
+ hidden=cfg.hidden, depth=cfg.depth,
511
+ n_cross_layers=cfg.n_cross_layers, n_heads=cfg.n_heads,
512
+ max_alpha=cfg.max_alpha, alpha_init=cfg.alpha_init,
513
+ ).to(device)
514
+
515
+ n_params = sum(p.numel() for p in model.parameters())
516
+ n_patches = (cfg.img_size // cfg.patch_size) ** 2
517
+
518
+ # TensorBoard
519
+ from torch.utils.tensorboard import SummaryWriter
520
+ writer = SummaryWriter(tb_dir)
521
+
522
+ # HF
523
+ hf_enabled = False
524
+ api = None
525
+ if cfg.upload:
526
+ try:
527
+ from huggingface_hub import HfApi
528
+ api = HfApi()
529
+ api.whoami()
530
+ hf_enabled = True
531
+ except Exception as e:
532
+ print(f" HF upload disabled: {e}")
533
+
534
+ def upload_file(local_path, remote_name):
535
+ if not hf_enabled:
536
+ return
537
+ try:
538
+ api.upload_file(
539
+ path_or_fileobj=local_path,
540
+ path_in_repo=f"{run_name}/{remote_name}",
541
+ repo_id=cfg.hf_repo, repo_type="model")
542
+ except Exception as e:
543
+ print(f" HF upload failed ({remote_name}): {e}")
544
+
545
+ def upload_folder(local_dir, remote_prefix):
546
+ if not hf_enabled:
547
+ return
548
+ try:
549
+ api.upload_folder(
550
+ folder_path=local_dir,
551
+ path_in_repo=f"{run_name}/{remote_prefix}",
552
+ repo_id=cfg.hf_repo, repo_type="model")
553
+ except Exception as e:
554
+ print(f" HF folder upload failed ({remote_prefix}): {e}")
555
+
556
+ # Datasets
557
+ train_ds = OmegaNoiseDataset(size=cfg.train_size, img_size=cfg.img_size,
558
+ allowed_types=cfg.allowed_types)
559
+ val_ds = OmegaNoiseDataset(size=cfg.val_size, img_size=cfg.img_size,
560
+ allowed_types=cfg.allowed_types)
561
+ train_loader = torch.utils.data.DataLoader(
562
+ train_ds, batch_size=cfg.batch_size, shuffle=True,
563
+ num_workers=cfg.num_workers, pin_memory=True, drop_last=True)
564
+ test_loader = torch.utils.data.DataLoader(
565
+ val_ds, batch_size=cfg.batch_size, shuffle=False,
566
+ num_workers=cfg.num_workers, pin_memory=True)
567
+
568
+ # Optimizer (pure Adam per Phil preference)
569
+ opt = torch.optim.Adam(model.parameters(), lr=cfg.lr,
570
+ weight_decay=cfg.weight_decay)
571
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)
572
+
573
+ # Header
574
+ print("=" * 100)
575
+ print(f"RUN: {run_name}")
576
+ print(f" V={cfg.matrix_v}, D={cfg.D}, hidden={cfg.hidden}, depth={cfg.depth}, "
577
+ f"patch={cfg.patch_size}, n_cross={cfg.n_cross_layers}")
578
+ print(f" patches/image={n_patches}, params={n_params:,}, batch={cfg.batch_size}")
579
+ print(f" types={cfg.allowed_types}, epochs={cfg.epochs}, lr={cfg.lr}")
580
+ print(f" HF repo={cfg.hf_repo}, upload={hf_enabled}")
581
+ print("=" * 100)
582
+ print(f" {'ep':>3} {'batch':>7} | {'loss':>7} {'recon':>7} {'test':>7} {'r_obs':>6} | "
583
+ f"{'S0':>6} {'SD':>6} {'ratio':>5} {'erank':>5} | "
584
+ f"{'cv':>6} {'cv_ema':>6} | "
585
+ f"{'prox':>5} {'rw':>4} {'S_del':>6} {'a_m':>5} | "
586
+ f"{'grad':>7} {'ph':>4}")
587
+ print("-" * 130)
588
+
589
+ best_mse = float('inf')
590
+ global_batch = 0
591
+ last_cv = 0.0
592
+ cv_ema = None # the guidance signal β€” CV's own trajectory
593
+ recon_ema_obs = None # observable only, used for logging
594
+ last_prox = 1.0
595
+ in_alignment = cfg.use_cv_ema
596
+ history = []
597
+
598
+ # Save initial config upload (creates repo structure early)
599
+ if hf_enabled:
600
+ upload_file(os.path.join(run_dir, "config.json"), "config.json")
601
+
602
+ for epoch in range(1, cfg.epochs + 1):
603
+ model.train()
604
+ tot_loss = tot_recon = n_seen = 0
605
+ epoch_max_grad = 0.0
606
+ t0 = time.time()
607
+
608
+ # Alignment phase: pure MSE, CV-EMA accumulates
609
+ if cfg.use_cv_ema:
610
+ in_alignment = epoch <= cfg.cv_alignment_epochs
611
+
612
+ # Progress bar for this epoch's batches
613
+ if _HAS_TQDM:
614
+ batch_iter = tqdm(train_loader, desc=f"ep {epoch:3d}/{cfg.epochs}",
615
+ leave=False, dynamic_ncols=True,
616
+ bar_format='{l_bar}{bar:20}{r_bar}')
617
+ else:
618
+ batch_iter = train_loader
619
+
620
+ for batch_idx, (images, _) in enumerate(batch_iter):
621
+ images = images.to(device, non_blocking=True)
622
+ opt.zero_grad()
623
+ out = model(images)
624
+ recon_loss = F.mse_loss(out['recon'], images)
625
+ recon_val = recon_loss.item()
626
+
627
+ with torch.no_grad():
628
+ # Track recon_ema as observable (for logging only)
629
+ if recon_ema_obs is None:
630
+ recon_ema_obs = recon_val
631
+ else:
632
+ recon_ema_obs = 0.99 * recon_ema_obs + 0.01 * recon_val
633
+
634
+ # Measure CV and update CV-EMA (the GUIDANCE signal)
635
+ if batch_idx % cfg.cv_measure_every == 0:
636
+ current_cv = cv_of(out['svd']['M'][0, 0])
637
+ if current_cv > 0:
638
+ last_cv = current_cv
639
+ if cv_ema is None:
640
+ cv_ema = current_cv
641
+ else:
642
+ cv_ema = ((1.0 - cfg.cv_ema_alpha) * cv_ema
643
+ + cfg.cv_ema_alpha * current_cv)
644
+
645
+ # Soft-hand proximity: is CURRENT CV near the EMA trajectory?
646
+ # This measures geometric coherence β€” whether the arrangement
647
+ # is settling (CV near its trend) or thrashing (CV deviating).
648
+ if cv_ema is not None and cv_ema > 1e-6:
649
+ sigma_adapt = max(cfg.cv_sigma_scale * cv_ema, 1e-6)
650
+ delta = last_cv - cv_ema
651
+ last_prox = math.exp(-(delta ** 2) / (2 * sigma_adapt ** 2))
652
+
653
+ # Soft-hand: boost recon weight when geometry is coherent.
654
+ # No penalty term β€” we never fight the model toward a specific CV.
655
+ # We only reward the recon path when geometry is relationally stable.
656
+ if in_alignment:
657
+ loss = recon_loss
658
+ recon_w = 1.0
659
+ else:
660
+ recon_w = 1.0 + cfg.boost * last_prox
661
+ loss = recon_w * recon_loss
662
+
663
+ loss.backward()
664
+
665
+ # Clip cross-attn only (rest stays free per SVAE recipe)
666
+ torch.nn.utils.clip_grad_norm_(
667
+ model.cross_attn.parameters(), max_norm=cfg.cross_attn_clip)
668
+
669
+ # Measure total grad norm for stability tracking
670
+ total_grad = sum(
671
+ p.grad.pow(2).sum().item()
672
+ for p in model.parameters() if p.grad is not None
673
+ ) ** 0.5
674
+ epoch_max_grad = max(epoch_max_grad, total_grad)
675
+
676
+ opt.step()
677
+
678
+ tot_loss += loss.item() * images.size(0)
679
+ tot_recon += recon_val * images.size(0)
680
+ n_seen += images.size(0)
681
+ global_batch += 1
682
+
683
+ # Update tqdm postfix with live stats
684
+ if _HAS_TQDM:
685
+ cvema_s = f"{cv_ema:.3f}" if cv_ema is not None else "---"
686
+ batch_iter.set_postfix_str(
687
+ f"r={recon_val:.4f} cv={last_cv:.3f} ema={cvema_s} "
688
+ f"px={last_prox:.2f} rw={recon_w:.2f}", refresh=False)
689
+
690
+ # Periodic TB + history snapshot (silent β€” no console print at batch level)
691
+ if global_batch % cfg.report_every == 0:
692
+ model.eval()
693
+ with torch.no_grad():
694
+ test_imgs, _ = next(iter(test_loader))
695
+ test_imgs = test_imgs.to(device)
696
+ t_out = model(test_imgs)
697
+ test_mse = F.mse_loss(t_out['recon'], test_imgs).item()
698
+
699
+ # Geometry
700
+ S_batch = t_out['svd']['S'] # (B, N, D)
701
+ S_orig = t_out['svd']['S_orig']
702
+ S_mean = S_batch.mean(dim=(0, 1)) # (D,)
703
+ S0 = S_mean[0].item()
704
+ SD = S_mean[-1].item()
705
+ ratio = S0 / (SD + 1e-8)
706
+ erank = PatchSVAE_F.effective_rank(
707
+ S_batch.reshape(-1, cfg.D)).mean().item()
708
+ s_delta = (S_batch - S_orig).abs().mean().item()
709
+
710
+ # Alpha across cross-attn layers
711
+ if cfg.n_cross_layers > 0:
712
+ alphas = [layer.alpha.detach() for layer in model.cross_attn]
713
+ alpha_mean = torch.stack([a.mean() for a in alphas]).mean().item()
714
+ alpha_std = torch.stack([a.std() for a in alphas]).mean().item()
715
+ else:
716
+ alpha_mean = 0.0
717
+ alpha_std = 0.0
718
+
719
+ # CV band check
720
+ cv_in_band = 0.13 <= last_cv <= 0.30
721
+
722
+ # TB logs (always)
723
+ writer.add_scalar('train/loss', tot_loss / n_seen, global_batch)
724
+ writer.add_scalar('train/recon', tot_recon / n_seen, global_batch)
725
+ writer.add_scalar('test/mse', test_mse, global_batch)
726
+ if recon_ema_obs is not None:
727
+ writer.add_scalar('stab/recon_ema_obs', recon_ema_obs, global_batch)
728
+ writer.add_scalar('geo/S0', S0, global_batch)
729
+ writer.add_scalar('geo/SD', SD, global_batch)
730
+ writer.add_scalar('geo/ratio', ratio, global_batch)
731
+ writer.add_scalar('geo/erank', erank, global_batch)
732
+ writer.add_scalar('geo/row_cv', last_cv, global_batch)
733
+ if cv_ema is not None:
734
+ writer.add_scalar('geo/cv_ema', cv_ema, global_batch)
735
+ writer.add_scalar('geo/cv_in_band', float(cv_in_band), global_batch)
736
+ writer.add_scalar('geo/S_delta', s_delta, global_batch)
737
+ writer.add_scalar('cross_attn/alpha_mean', alpha_mean, global_batch)
738
+ writer.add_scalar('cross_attn/alpha_std', alpha_std, global_batch)
739
+ writer.add_scalar('stab/prox', last_prox, global_batch)
740
+ writer.add_scalar('stab/recon_w', recon_w, global_batch)
741
+ writer.add_scalar('stab/epoch_max_grad', epoch_max_grad, global_batch)
742
+ writer.add_scalar('stab/lr', opt.param_groups[0]['lr'], global_batch)
743
+ writer.add_scalar('stab/in_alignment', float(in_alignment), global_batch)
744
+
745
+ history.append({
746
+ 'epoch': epoch, 'global_batch': global_batch,
747
+ 'train_recon': tot_recon / n_seen,
748
+ 'test_mse': test_mse,
749
+ 'recon_ema_obs': recon_ema_obs,
750
+ 'S0': S0, 'SD': SD, 'ratio': ratio, 'erank': erank,
751
+ 'row_cv': last_cv, 'cv_ema': cv_ema,
752
+ 'cv_in_band': cv_in_band,
753
+ 'S_delta': s_delta,
754
+ 'alpha_mean': alpha_mean, 'alpha_std': alpha_std,
755
+ 'grad_max': epoch_max_grad,
756
+ 'in_alignment': in_alignment,
757
+ 'prox': last_prox, 'recon_w': recon_w,
758
+ })
759
+
760
+ if test_mse < best_mse:
761
+ best_mse = test_mse
762
+ # Save best silently (upload only at save_every)
763
+ torch.save({
764
+ 'epoch': epoch, 'test_mse': test_mse,
765
+ 'global_batch': global_batch,
766
+ 'model_state_dict': model.state_dict(),
767
+ 'config': asdict(cfg),
768
+ }, os.path.join(ckpt_dir, 'best.pt'))
769
+
770
+ model.train()
771
+
772
+ if _HAS_TQDM:
773
+ batch_iter.close()
774
+
775
+ sched.step()
776
+ epoch_time = time.time() - t0
777
+
778
+ # Full-eval at epoch boundary
779
+ model.eval()
780
+ tot_test = test_n = 0
781
+ with torch.no_grad():
782
+ for t_imgs, _ in test_loader:
783
+ t_imgs = t_imgs.to(device)
784
+ t_out = model(t_imgs)
785
+ tot_test += F.mse_loss(t_out['recon'], t_imgs).item() * t_imgs.size(0)
786
+ test_n += t_imgs.size(0)
787
+ epoch_test_mse = tot_test / test_n
788
+ writer.add_scalar('epoch/test_mse', epoch_test_mse, epoch)
789
+ writer.add_scalar('epoch/time_s', epoch_time, epoch)
790
+ writer.add_scalar('epoch/max_grad', epoch_max_grad, epoch)
791
+
792
+ if epoch_test_mse < best_mse:
793
+ best_mse = epoch_test_mse
794
+ torch.save({
795
+ 'epoch': epoch, 'test_mse': epoch_test_mse,
796
+ 'global_batch': global_batch,
797
+ 'model_state_dict': model.state_dict(),
798
+ 'config': asdict(cfg),
799
+ }, os.path.join(ckpt_dir, 'best.pt'))
800
+
801
+ # MAJOR REPORT: every 10 epochs + first + last
802
+ is_major_report = (
803
+ epoch == 1 or epoch == cfg.epochs or
804
+ epoch % cfg.major_report_every == 0
805
+ )
806
+ if is_major_report:
807
+ cvema_s = f"{cv_ema:.4f}" if cv_ema is not None else "---"
808
+ rema_s = f"{recon_ema_obs:.4f}" if recon_ema_obs is not None else "---"
809
+ print(f" ep {epoch:3d}/{cfg.epochs}: "
810
+ f"test_mse={epoch_test_mse:.6f} best={best_mse:.6f} "
811
+ f"| cv_ema={cvema_s} recon_ema={rema_s} "
812
+ f"| max_grad={epoch_max_grad:.2f} "
813
+ f"| {epoch_time:.1f}s"
814
+ f"{' [ALGN]' if in_alignment else ' [HAND]'}")
815
+ else:
816
+ # Minimal per-epoch print (one line)
817
+ print(f" ep {epoch:3d}: mse={epoch_test_mse:.4f} "
818
+ f"grad={epoch_max_grad:.1f} {epoch_time:.1f}s")
819
+
820
+ # Periodic save + upload
821
+ if epoch % cfg.save_every == 0 or epoch == cfg.epochs:
822
+ ep_path = os.path.join(ckpt_dir, f"epoch_{epoch:04d}.pt")
823
+ torch.save({
824
+ 'epoch': epoch, 'test_mse': epoch_test_mse,
825
+ 'global_batch': global_batch,
826
+ 'model_state_dict': model.state_dict(),
827
+ 'optimizer_state_dict': opt.state_dict(),
828
+ 'scheduler_state_dict': sched.state_dict(),
829
+ 'config': asdict(cfg),
830
+ }, ep_path)
831
+ writer.flush()
832
+
833
+ if hf_enabled:
834
+ upload_file(ep_path, f"checkpoints/epoch_{epoch:04d}.pt")
835
+ best_path = os.path.join(ckpt_dir, 'best.pt')
836
+ if os.path.exists(best_path):
837
+ upload_file(best_path, "checkpoints/best.pt")
838
+ upload_folder(tb_dir, "tensorboard")
839
+
840
+ writer.close()
841
+
842
+ # Final report
843
+ final = {
844
+ 'run_name': run_name,
845
+ 'config': asdict(cfg),
846
+ 'n_params': n_params,
847
+ 'n_patches': n_patches,
848
+ 'best_test_mse': best_mse,
849
+ 'final_epoch_mse': epoch_test_mse,
850
+ 'final_cv_ema': cv_ema,
851
+ 'final_recon_ema_obs': recon_ema_obs,
852
+ 'final_S0': history[-1]['S0'] if history else None,
853
+ 'final_erank': history[-1]['erank'] if history else None,
854
+ 'final_row_cv': history[-1]['row_cv'] if history else None,
855
+ 'final_cv_in_band': history[-1]['cv_in_band'] if history else None,
856
+ 'final_S_delta': history[-1]['S_delta'] if history else None,
857
+ 'final_alpha_mean': history[-1]['alpha_mean'] if history else None,
858
+ 'history': history,
859
+ }
860
+ report_path = os.path.join(run_dir, "final_report.json")
861
+ with open(report_path, 'w') as f:
862
+ json.dump(final, f, indent=2)
863
+ if hf_enabled:
864
+ upload_file(report_path, "final_report.json")
865
+
866
+ print(f"\n RUN COMPLETE: {run_name}")
867
+ print(f" Best test MSE: {best_mse:.6f}")
868
+ cvema = f"{cv_ema:.4f}" if cv_ema is not None else "n/a"
869
+ rema = f"{recon_ema_obs:.4f}" if recon_ema_obs is not None else "n/a"
870
+ print(f" Final: S0={final['final_S0']:.3f}, erank={final['final_erank']:.3f}, "
871
+ f"cv={final['final_row_cv']:.4f} (cv_ema={cvema}), recon_ema_obs={rema}")
872
+ print(f" Checkpoints: {ckpt_dir}")
873
+ print(f" TensorBoard: {tb_dir}")
874
+ print(f" Report: {report_path}")
875
+ return final
876
+
877
+
878
+ # ── Smoke test entry point ───────────────────────────────────────
879
+
880
+ def smoke():
881
+ """Minimum viable smoke test β€” tiny config, 3 epochs, Gaussian only.
882
+
883
+ Soft-hand guides against CV-EMA (geometric coherence signal).
884
+ Epoch 1: alignment (pure MSE, cv_ema accumulates)
885
+ Epochs 2-3: hand active (proximity against cv_ema)
886
+ """
887
+ cfg = RunConfig(
888
+ matrix_v=32, D=4, patch_size=16, hidden=64, depth=1,
889
+ n_cross_layers=1, n_heads=2,
890
+ img_size=64,
891
+ batch_size=64,
892
+ lr=1e-3,
893
+ epochs=3,
894
+ allowed_types=[0],
895
+ train_size=5000,
896
+ val_size=500,
897
+ num_workers=2,
898
+ report_every=20,
899
+ save_every=1,
900
+ use_cv_ema=True,
901
+ cv_alignment_epochs=1,
902
+ cv_ema_alpha=0.05, # faster EMA for short smoke
903
+ cv_sigma_scale=0.3,
904
+ upload=True,
905
+ )
906
+ return run(cfg)
907
+
908
+
909
+ # ── F-class sweep ───────────────────────────────────────────────
910
+
911
+ def _johanna_base_cfg(**overrides):
912
+ """Base config for F-class (miniature) battery experiments.
913
+
914
+ Defaults are genuinely small β€” 16 patches per image, modest V/D,
915
+ thin substrate. Override any axis you're sweeping.
916
+
917
+ Reasonable overridable axes:
918
+ matrix_v, D, patch_size, hidden, depth, n_cross_layers,
919
+ img_size, batch_size, lr, epochs, allowed_types
920
+ """
921
+ base = dict(
922
+ # F-class defaults (small, stackable)
923
+ matrix_v=64, D=8, patch_size=16,
924
+ n_cross_layers=1, n_heads=4,
925
+ max_alpha=0.2, alpha_init=-2.0,
926
+
927
+ # Training conventions from Johanna lineage
928
+ img_size=64,
929
+ batch_size=128,
930
+ lr=1e-4,
931
+ epochs=30,
932
+ weight_decay=0.0,
933
+
934
+ allowed_types=list(range(16)),
935
+ train_size=1_280_000,
936
+ val_size=10_000,
937
+ num_workers=4,
938
+
939
+ report_every=500,
940
+ save_every=5,
941
+
942
+ # Soft-hand on CV-EMA (geometric coherence guidepost)
943
+ use_cv_ema=True,
944
+ cv_ema_alpha=0.01,
945
+ cv_alignment_epochs=2,
946
+ cv_sigma_scale=0.3,
947
+ cv_measure_every=50,
948
+ boost=0.5,
949
+ cross_attn_clip=0.5,
950
+
951
+ upload=True,
952
+ )
953
+ base.update(overrides)
954
+ return RunConfig(**base)
955
+
956
+
957
+ SWEEP_F_CLASS = [
958
+ # ═══════════════════════════════════════════════════════════════════
959
+ # F-class sweep: the experimental nursery.
960
+ #
961
+ # F-class models are NOT expected to succeed. They are research
962
+ # specimens β€” most will collapse, some will barely function,
963
+ # a few might surprise us. The sweep's job is to catalog failure
964
+ # modes at small scale, not to find a winning config.
965
+ #
966
+ # A-class = Johanna/Fresnel (17M, teachable workhorses)
967
+ # S-class = Freckles (2.5M, superior recon, too disorderly to teach)
968
+ # F-class = this sweep (0.03M-0.6M, expected to fail)
969
+ #
970
+ # Axes varied deliberately:
971
+ # - TINY overall (V,D,hidden,depth all small)
972
+ # - SMALL patchworks (few patches per image β€” less info to work with)
973
+ # - LARGE patchworks with SMALL internals (many patches + weak substrate)
974
+ # - UNUSUAL shapes (V >> D, D >> V equivalents, unusual depth ratios)
975
+ #
976
+ # Naming: johanna-F-S{S}-V{V}-D{D}-h{hidden}-d{depth}-p{patch}
977
+ # All runs log to HF `AbstractPhil/geolip-svae-batteries/<run_name>/`
978
+ # ═══════════════════════════════════════════════════════════════════
979
+
980
+ # ── TIER 1: D=16 spine (Johanna's universal-attractor dim, small scale) ──
981
+ # Question: does D=16 carry through when everything around it shrinks?
982
+ _johanna_base_cfg(img_size=64, matrix_v=64, D=16, hidden=64, depth=1, patch_size=16), # 16 patches, ~250K
983
+ _johanna_base_cfg(img_size=64, matrix_v=64, D=16, hidden=64, depth=1, patch_size=8), # 64 patches, ~176K
984
+ _johanna_base_cfg(img_size=64, matrix_v=128, D=16, hidden=128, depth=1, patch_size=8), # larger V+h, ~645K
985
+
986
+ # ── TIER 2: D-sweep at matched substrate (V=64, h=64, d=1, p=16) ──
987
+ # Does battery behavior survive D<16? (answers an unasked question of omega paper)
988
+ _johanna_base_cfg(img_size=64, matrix_v=64, D=8, hidden=64, depth=1, patch_size=16), # D=8
989
+ _johanna_base_cfg(img_size=64, matrix_v=64, D=4, hidden=64, depth=1, patch_size=16), # D=4
990
+ _johanna_base_cfg(img_size=64, matrix_v=64, D=2, hidden=64, depth=1, patch_size=16), # D=2 β€” likely collapse
991
+
992
+ # ── TIER 3: Substrate axis at (V=64, D=8) ──
993
+ # Which axis carries the self-assembly work β€” width or depth?
994
+ _johanna_base_cfg(img_size=64, matrix_v=64, D=8, hidden=128, depth=1, patch_size=16), # wider substrate
995
+ _johanna_base_cfg(img_size=64, matrix_v=64, D=8, hidden=64, depth=2, patch_size=16), # deeper substrate
996
+ _johanna_base_cfg(img_size=64, matrix_v=64, D=8, hidden=32, depth=1, patch_size=16), # starved width
997
+
998
+ # ── TIER 4: Patchwork sweeps β€” big patchworks with small internals ──
999
+ # Many weak cells vs few strong cells. Tests the "cells are SVAE-shaped
1000
+ # functions, not batteries" framing.
1001
+ _johanna_base_cfg(img_size=64, matrix_v=32, D=4, hidden=32, depth=1, patch_size=4), # 256 patches, tiny internals
1002
+ _johanna_base_cfg(img_size=64, matrix_v=32, D=4, hidden=64, depth=1, patch_size=4), # 256 patches, slightly more substrate
1003
+ _johanna_base_cfg(img_size=32, matrix_v=32, D=4, hidden=32, depth=1, patch_size=2, batch_size=256), # 256 patches @ S=32
1004
+
1005
+ # ── TIER 5: Small patchworks (few patches, stronger cells) ──
1006
+ # Inverse: few big cells. Easier for a wrapper to channel, but each
1007
+ # cell carries more load.
1008
+ _johanna_base_cfg(img_size=32, matrix_v=64, D=8, hidden=128, depth=1, patch_size=16, batch_size=256), # 4 patches
1009
+ _johanna_base_cfg(img_size=32, matrix_v=64, D=8, hidden=64, depth=1, patch_size=16, batch_size=256), # 4 patches thinner
1010
+
1011
+ # ── TIER 6: Unusual shapes (chaos measurements) ──
1012
+ # V >> D and D >> V ratios well outside Johanna's 16:1 ratio.
1013
+ _johanna_base_cfg(img_size=64, matrix_v=256, D=2, hidden=64, depth=1, patch_size=16), # V=128Γ—D ratio
1014
+ _johanna_base_cfg(img_size=64, matrix_v=16, D=16, hidden=64, depth=1, patch_size=16), # V=D (square matrix)
1015
+ _johanna_base_cfg(img_size=64, matrix_v=8, D=16, hidden=64, depth=1, patch_size=16), # V<D (wide matrix)
1016
+
1017
+ # ── TIER 7: Extreme smallness (almost-certain collapse) ──
1018
+ # Lower bound of what can be built. Mostly exists as a "what does total
1019
+ # collapse look like" reference.
1020
+ _johanna_base_cfg(img_size=32, matrix_v=16, D=4, hidden=16, depth=1, patch_size=8, batch_size=256), # ~20K params
1021
+ _johanna_base_cfg(img_size=16, matrix_v=8, D=2, hidden=8, depth=1, patch_size=4, batch_size=256), # absurdly small
1022
+ ]
1023
+
1024
+
1025
+ # ── Run existence check (for auto-resume across sessions) ────────
1026
+
1027
+ def _hf_run_exists(run_name: str, hf_repo: str) -> bool:
1028
+ """Check if this run has already been completed on HF.
1029
+ A run is considered complete if `<run_name>/final_report.json` exists.
1030
+ """
1031
+ try:
1032
+ from huggingface_hub import HfApi
1033
+ api = HfApi()
1034
+ # list_repo_files returns paths in the repo. Check for final_report.json.
1035
+ files = api.list_repo_files(repo_id=hf_repo, repo_type="model")
1036
+ marker = f"{run_name}/final_report.json"
1037
+ return marker in files
1038
+ except Exception as e:
1039
+ # If we can't check (auth, network, etc), assume not complete and run.
1040
+ print(f"[HF-check] could not verify completion of {run_name}: {e}")
1041
+ return False
1042
+
1043
+
1044
+ def sweep(configs=None, out_root="/content/johanna_F_runs",
1045
+ skip_on_error=True, skip_completed=True):
1046
+ """Run a list of RunConfigs sequentially. One per config, full training each.
1047
+
1048
+ Args:
1049
+ configs: List of RunConfig. Defaults to SWEEP_F_CLASS.
1050
+ out_root: Local output directory.
1051
+ skip_on_error: If True, log the error and continue to next config.
1052
+ skip_completed: If True, skip configs whose final_report.json is
1053
+ already on HF (for auto-resume across sessions).
1054
+ """
1055
+ if configs is None:
1056
+ configs = SWEEP_F_CLASS
1057
+
1058
+ print("\n" + "#" * 100)
1059
+ print(f"# F-CLASS SWEEP: {len(configs)} configurations")
1060
+ print(f"# Most will collapse. This is expected. Collapse is a data point.")
1061
+ print(f"# Results go to HF: {configs[0].hf_repo if configs else '(n/a)'}")
1062
+ print(f"# Use a separate readout cell to pool JSON metrics afterward.")
1063
+ print("#" * 100)
1064
+
1065
+ # Pre-flight: check which runs already exist on HF
1066
+ completed = set()
1067
+ if skip_completed and configs and configs[0].upload:
1068
+ print("\n[preflight] checking HF for already-completed runs...")
1069
+ for cfg in configs:
1070
+ if _hf_run_exists(cfg.name(), cfg.hf_repo):
1071
+ completed.add(cfg.name())
1072
+ print(f" βœ“ already complete: {cfg.name()}")
1073
+ print(f"[preflight] {len(completed)}/{len(configs)} already done β€” will skip\n")
1074
+
1075
+ for i, cfg in enumerate(configs):
1076
+ status = "SKIP" if cfg.name() in completed else "RUN "
1077
+ print(f"# {status} {i+1:2d}. {cfg.name()}")
1078
+ print("#" * 100 + "\n")
1079
+
1080
+ results = []
1081
+ for i, cfg in enumerate(configs):
1082
+ if cfg.name() in completed:
1083
+ results.append({
1084
+ 'run_name': cfg.name(),
1085
+ 'skipped': True,
1086
+ 'reason': 'already completed on HF',
1087
+ 'config': asdict(cfg),
1088
+ })
1089
+ continue
1090
+
1091
+ print(f"\n\n{'β–“' * 100}")
1092
+ print(f"β–“ [{i+1}/{len(configs)}] {cfg.name()}")
1093
+ print(f"{'β–“' * 100}")
1094
+ try:
1095
+ final = run(cfg, out_root=out_root)
1096
+ results.append(final)
1097
+ except Exception as e:
1098
+ print(f"\n[!] RUN FAILED: {cfg.name()}")
1099
+ print(f"[!] {type(e).__name__}: {e}")
1100
+ import traceback
1101
+ traceback.print_exc()
1102
+ results.append({'run_name': cfg.name(), 'error': str(e),
1103
+ 'error_type': type(e).__name__,
1104
+ 'config': asdict(cfg)})
1105
+ if not skip_on_error:
1106
+ raise
1107
+
1108
+ # Write combined summary
1109
+ summary_path = os.path.join(out_root, "sweep_summary.json")
1110
+ os.makedirs(out_root, exist_ok=True)
1111
+ summary = {
1112
+ 'n_runs': len(results),
1113
+ 'n_succeeded': sum(1 for r in results
1114
+ if 'error' not in r and not r.get('skipped')),
1115
+ 'n_skipped': sum(1 for r in results if r.get('skipped')),
1116
+ 'n_errored': sum(1 for r in results if 'error' in r),
1117
+ 'runs': [
1118
+ {
1119
+ 'run_name': r.get('run_name'),
1120
+ 'skipped': r.get('skipped', False),
1121
+ 'error': r.get('error'),
1122
+ 'best_test_mse': r.get('best_test_mse'),
1123
+ 'final_cv_ema': r.get('final_cv_ema'),
1124
+ 'final_recon_ema_obs': r.get('final_recon_ema_obs'),
1125
+ 'final_S0': r.get('final_S0'),
1126
+ 'final_erank': r.get('final_erank'),
1127
+ 'final_row_cv': r.get('final_row_cv'),
1128
+ 'final_alpha_mean': r.get('final_alpha_mean'),
1129
+ 'n_params': r.get('n_params'),
1130
+ 'n_patches': r.get('n_patches'),
1131
+ } for r in results
1132
+ ],
1133
+ }
1134
+ with open(summary_path, 'w') as f:
1135
+ json.dump(summary, f, indent=2)
1136
+ print(f"\n\n{'#' * 100}")
1137
+ print(f"# SWEEP DONE: {summary['n_succeeded']} succeeded, "
1138
+ f"{summary['n_skipped']} skipped, {summary['n_errored']} errored")
1139
+ print(f"# Summary: {summary_path}")
1140
+ print(f"{'#' * 100}\n")
1141
+
1142
+ # Upload summary to HF
1143
+ if configs and configs[0].upload:
1144
+ try:
1145
+ from huggingface_hub import HfApi
1146
+ api = HfApi()
1147
+ api.whoami()
1148
+ api.upload_file(
1149
+ path_or_fileobj=summary_path,
1150
+ path_in_repo="sweep_summary.json",
1151
+ repo_id=configs[0].hf_repo, repo_type="model")
1152
+ print(f"[HF] Uploaded sweep summary to {configs[0].hf_repo}/sweep_summary.json")
1153
+ except Exception as e:
1154
+ print(f"[HF] Summary upload failed: {e}")
1155
+
1156
+ return results
1157
+
1158
+
1159
+ def _in_notebook():
1160
+ """Detect Jupyter / Colab environment."""
1161
+ try:
1162
+ from IPython import get_ipython
1163
+ shell = get_ipython().__class__.__name__
1164
+ return shell in ('ZMQInteractiveShell', 'Shell', 'Google.Colab')
1165
+ except Exception:
1166
+ return False
1167
+
1168
+
1169
+ def parse_args(argv=None):
1170
+ p = argparse.ArgumentParser()
1171
+ p.add_argument('--smoke', action='store_true', help='Run smoke test')
1172
+ p.add_argument('--V', type=int, default=64)
1173
+ p.add_argument('--D', type=int, default=8)
1174
+ p.add_argument('--hidden', type=int, default=128)
1175
+ p.add_argument('--depth', type=int, default=1)
1176
+ p.add_argument('--patch', type=int, default=16)
1177
+ p.add_argument('--img_size', type=int, default=128)
1178
+ p.add_argument('--batch', type=int, default=128)
1179
+ p.add_argument('--epochs', type=int, default=20)
1180
+ p.add_argument('--lr', type=float, default=1e-3)
1181
+ p.add_argument('--n_cross', type=int, default=1)
1182
+ p.add_argument('--all_types', action='store_true')
1183
+ p.add_argument('--no_upload', action='store_true')
1184
+ return p.parse_args(argv)
1185
+
1186
+
1187
+ def _cli_main(argv=None):
1188
+ args = parse_args(argv)
1189
+ if args.smoke:
1190
+ return smoke()
1191
+ cfg = RunConfig(
1192
+ matrix_v=args.V, D=args.D, patch_size=args.patch,
1193
+ hidden=args.hidden, depth=args.depth,
1194
+ n_cross_layers=args.n_cross,
1195
+ img_size=args.img_size,
1196
+ batch_size=args.batch,
1197
+ epochs=args.epochs, lr=args.lr,
1198
+ allowed_types=list(range(16)) if args.all_types else [0],
1199
+ upload=not args.no_upload,
1200
+ )
1201
+ return run(cfg)
1202
+
1203
+
1204
+ if __name__ == "__main__" and not _in_notebook():
1205
+ _cli_main()
1206
+ elif __name__ == "__main__":
1207
+ # In Colab / Jupyter: do nothing automatically.
1208
+ print("johanna_F_trainer loaded in notebook mode.")
1209
+ print(" β†’ smoke test: smoke()")
1210
+ print(" β†’ single config: run(RunConfig(matrix_v=256, D=16, ...))")
1211
+ print(" β†’ full F-class sweep: sweep()")
1212
+ print(" β†’ custom sweep list: sweep([cfg1, cfg2, ...])")