AbstractPhil committed on
Commit
7f5d157
·
verified ·
1 Parent(s): cededba

Create 4_monkey_patched_trainer.py

Browse files
Files changed (1) hide show
  1. 4_monkey_patched_trainer.py +825 -0
4_monkey_patched_trainer.py ADDED
@@ -0,0 +1,825 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ablation_trainer.py
3
+ ===================
4
+ Ablation trainer adapter for PatchSVAE_F (Johanna F-class).
5
+
6
+ Takes an ablation config dict, builds a proper RunConfig with overrides,
7
+ instantiates a PatchSVAE_F_Ablation subclass with the needed hooks,
8
+ runs the real training loop with batch-limit early stop, measures CV
9
+ throughout, computes Group N uniformity diagnostic at the end, and
10
+ returns a result dict ready for upload.
11
+
12
+ Imports from johanna_F_trainer.py. Drop this in alongside it in Colab.
13
+
14
+ Ablation hooks implemented:
15
+ Group A (seeds): pure seed variation via RunConfig.seed
16
+ Group B (noise types): overrides['noise_types'] → RunConfig.allowed_types
17
+ Group C (optimizer): adam/sgd/adamw/lbfgs via build_optimizer
18
+ Group D (scheduler): cosine/constant/linear/cosine_warm_restarts/one_cycle
19
+ Group E (soft-hand): use_soft_hand + boost + cv_penalty + hard_cv_target
20
+ Group F (activation): enc_in activation function swap
21
+ Group G (row_norm): sphere/none/layer_norm/scale_only
22
+ Group H (SVD): fp64/fp32/batch_shared/linear_readout
23
+ Group I (cross-attn): n_cross_layers + max_alpha
24
+ Group J (capacity): V and hidden overrides (within LOW band)
25
+ Group K (batch size): batch_size override
26
+ Group L (init): orthogonal/kaiming_normal/xavier_uniform/normal_0_02
27
+ Group M (brute SGD): optimizer + lr + momentum + grad_clip
28
+ """
29
+
30
+ import os
31
+ import math
32
+ import time
33
+ from dataclasses import asdict, replace
34
+ from typing import Dict, Any, Optional, List
35
+
36
+ import numpy as np
37
+ import torch
38
+ import torch.nn as nn
39
+ import torch.nn.functional as F
40
+
41
+
42
+ # ────────────────────────────────────────────────────────────────────────
43
+ # Ablation hooks (Groups F/G/H/L) — implemented as a PatchSVAE_F subclass
44
+ # ────────────────────────────────────────────────────────────────────────
45
+
46
# Group F: lookup from activation name to the callable applied to the
# output of ``enc_in``. 'identity' passes its input through unchanged.
ACTIVATIONS = dict(
    gelu=F.gelu,
    relu=F.relu,
    silu=F.silu,
    tanh=torch.tanh,
    identity=lambda x: x,
)
53
+
54
+
55
def row_normalize(M: torch.Tensor, mode: str) -> torch.Tensor:
    """Group G: normalize the rows (last axis) of ``M`` according to ``mode``.

    Modes:
        'sphere'     -- L2-normalize each row via ``F.normalize(dim=-1)``.
        'none'       -- return ``M`` itself, untouched.
        'layer_norm' -- subtract the per-row mean and divide by the per-row
                        biased standard deviation (1e-8 added under the sqrt).
        'scale_only' -- divide every row by the mean of the row norms taken
                        over dim=-2 (plus 1e-8); rows are rescaled jointly
                        and are not forced to unit length.

    Raises:
        ValueError: if ``mode`` is none of the names above.
    """
    if mode == 'none':
        return M

    if mode == 'sphere':
        return F.normalize(M, dim=-1)

    if mode == 'layer_norm':
        row_mean = M.mean(dim=-1, keepdim=True)
        row_var = M.var(dim=-1, keepdim=True, unbiased=False)
        return (M - row_mean) / (row_var + 1e-8).sqrt()

    if mode == 'scale_only':
        # One shared scale per matrix: the average row length over dim=-2.
        shared_scale = M.norm(dim=-1, keepdim=True).mean(dim=-2, keepdim=True)
        return M / (shared_scale + 1e-8)

    raise ValueError(f"unknown row_norm mode: {mode}")
72
+
73
+
74
def init_weights(module: nn.Module, scheme: str) -> None:
    """Group L: re-initialize every ``nn.Linear`` inside ``module`` in place.

    For each Linear layer found by ``module.modules()`` the weight is drawn
    from the selected scheme and the bias (when present) is zeroed.

    Args:
        module: Root module whose Linear sub-layers are re-initialized.
        scheme: One of 'orthogonal', 'kaiming_normal', 'xavier_uniform',
            'normal_0_02' (normal with mean 0.0 and std 0.02).

    Raises:
        ValueError: if ``scheme`` is not a known name. An unknown scheme used
            to be ignored silently, which let a mistyped ablation arm run
            with whatever initialization the model already had.
    """
    weight_initializers = {
        'orthogonal': nn.init.orthogonal_,
        'kaiming_normal': nn.init.kaiming_normal_,
        'xavier_uniform': nn.init.xavier_uniform_,
        'normal_0_02': lambda w: nn.init.normal_(w, mean=0.0, std=0.02),
    }
    try:
        init_weight = weight_initializers[scheme]
    except KeyError:
        raise ValueError(f"unknown init scheme: {scheme}") from None

    for m in module.modules():
        if isinstance(m, nn.Linear):
            init_weight(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
100
+
101
+
102
class PatchSVAE_F_Ablation(PatchSVAE_F):
    """PatchSVAE_F with ablation hooks for F/G/H/L groups.

    At default settings (gelu / sphere / fp64 / orthogonal / no linear
    readout) this behaves identically to PatchSVAE_F.
    (NOTE(review): that equivalence depends on PatchSVAE_F.encode_patches,
    which is defined outside this file section -- confirm against the base.)

    Keyword-only ablation arguments (everything else is forwarded to
    ``PatchSVAE_F.__init__``):
        activation: key into ``ACTIVATIONS``; applied to the ``enc_in`` output.
        row_norm: mode string handed to ``row_normalize``.
        svd_mode: 'fp32' or 'batch_shared'; any other value takes the
            default ``_svd_fp64`` path.
        linear_readout: when True, the SVD is replaced by a readout layer.
        match_params: with ``linear_readout``, use a square ``nn.Linear``
            readout (True) or ``nn.Identity`` (False).
        init_scheme: name passed to ``init_weights`` unless it is
            'orthogonal', in which case no re-initialization happens here.
    """
    def __init__(self, *args,
                 activation: str = 'gelu',
                 row_norm: str = 'sphere',
                 svd_mode: str = 'fp64',
                 linear_readout: bool = False,
                 match_params: bool = True,
                 init_scheme: str = 'orthogonal',
                 **kwargs):
        super().__init__(*args, **kwargs)
        # Raises KeyError for an activation name missing from ACTIVATIONS.
        self.activation_fn = ACTIVATIONS[activation]
        self.row_norm_mode = row_norm
        self.svd_mode = svd_mode
        self.linear_readout = linear_readout

        if linear_readout:
            # The readout acts on the flattened (matrix_v * D) encoded matrix.
            readout_dim = self.matrix_v * self.D
            if match_params:
                self.readout = nn.Linear(readout_dim, readout_dim)
            else:
                self.readout = nn.Identity()

        # Re-initialize under the requested scheme
        if init_scheme != 'orthogonal':
            init_weights(self, init_scheme)
            # Re-apply orthogonal to enc_out specifically — load-bearing per
            # the architecture docs
            # NOTE(review): confirm this is meant to run only after a
            # non-default re-initialization (it is nested under the `if`).
            nn.init.orthogonal_(self.enc_out.weight)

    def encode_patches(self, patches):
        """Encode patches into per-patch matrices and their decomposition.

        ``patches`` is unpacked as a 3-D tensor (B, N, features). Returns a
        dict with keys 'U', 'S_orig', 'S', 'Vt' and 'M', each reshaped to a
        leading (B, N, ...) layout; 'S' is 'S_orig' after passing through
        every layer of ``self.cross_attn``.
        """
        B, N, _ = patches.shape
        flat = patches.reshape(B * N, -1)

        # F group: activation swap on the enc_in
        h = self.activation_fn(self.enc_in(flat))
        for block in self.enc_blocks:
            # Inner block activations (GELU inside the nn.Sequential) remain.
            # For a complete F ablation this isn't a perfect swap but it's
            # representative of the outer activation.
            h = h + block(h)

        # One (matrix_v x D) matrix per patch.
        M = self.enc_out(h).reshape(B * N, self.matrix_v, self.D)

        # G group: row normalization mode
        M = row_normalize(M, self.row_norm_mode)

        # H group: SVD variant or linear-readout replacement
        if self.linear_readout:
            flat_M = M.reshape(B * N, -1)
            M_hat = self.readout(flat_M).reshape(B * N, self.matrix_v, self.D)
            # Synthetic U/S/Vt so downstream code runs unchanged
            U = M_hat
            S = M_hat.norm(dim=-2)  # column norms as stand-in singular values
            Vt = torch.eye(self.D, device=M.device, dtype=M.dtype
                           ).unsqueeze(0).expand(B * N, -1, -1)
        elif self.svd_mode == 'fp32':
            # Same algorithm as fp64 path but without the autocast disable
            # Eigendecomposition of the Gram matrix G = M^T M, in M's dtype.
            G = torch.bmm(M.transpose(1, 2), M)
            G.diagonal(dim1=-2, dim2=-1).add_(1e-6)  # relaxed reg for fp32
            eigenvalues, Vmat = torch.linalg.eigh(G)
            # eigh returns ascending eigenvalues; flip to descending order.
            eigenvalues = eigenvalues.flip(-1)
            Vmat = Vmat.flip(-1)
            S = torch.sqrt(eigenvalues.clamp(min=1e-12))
            U = torch.bmm(M, Vmat) / S.unsqueeze(1).clamp(min=1e-8)
            Vt = Vmat.transpose(-2, -1).contiguous()
        elif self.svd_mode == 'batch_shared':
            # One SVD per batch instead of per patch; S and Vt replicated across N
            M_batched = M.reshape(B, N * self.matrix_v, self.D)
            # The U factor of the stacked matrix is discarded; a per-patch U
            # is rebuilt below from the shared S and Vt.
            U_b, S_b, Vt_b = _svd_fp64(M_batched)
            S = S_b.unsqueeze(1).expand(-1, N, -1).reshape(B * N, self.D)
            Vt = Vt_b.unsqueeze(1).expand(-1, N, -1, -1).reshape(B * N, self.D, self.D)
            U = torch.bmm(M, Vt.transpose(-2, -1)) / S.unsqueeze(1).clamp(min=1e-16)
        else:  # 'fp64' default
            U, S, Vt = _svd_fp64(M)

        # Restore the (B, N, ...) layout for downstream consumers.
        U = U.reshape(B, N, self.matrix_v, self.D)
        S = S.reshape(B, N, self.D)
        Vt = Vt.reshape(B, N, self.D, self.D)
        M = M.reshape(B, N, self.matrix_v, self.D)
        # Cross-attention layers refine the singular values; the untouched
        # values are kept under 'S_orig'.
        S_coord = S
        for layer in self.cross_attn:
            S_coord = layer(S_coord)
        return {'U': U, 'S_orig': S, 'S': S_coord, 'Vt': Vt, 'M': M}
190
+
191
+
192
+ # ────────────────────────────────────────────────────────────────────────
193
+ # Optimizer / scheduler builders
194
+ # ────────────────────────────────────────────────────────────────────────
195
+
196
def build_optimizer(model: nn.Module, overrides: Dict[str, Any],
                    base_lr: float) -> torch.optim.Optimizer:
    """Groups C, M: build the optimizer selected by ``overrides``.

    Override keys read (all optional):
        'optimizer':    'adam' (default), 'adamw', 'sgd' or 'lbfgs'.
        'lr':           learning rate; falls back to ``base_lr``.
        'weight_decay': default 0.0; applied to adam, adamw and sgd.
        'momentum':     default 0.0; used by sgd only.

    LBFGS is created with ``max_iter=20`` and ``history_size=10`` and takes
    neither weight decay nor momentum.

    Raises:
        ValueError: if the optimizer name is not recognised.
    """
    opt_name = overrides.get('optimizer', 'adam')
    lr = overrides.get('lr', base_lr)
    wd = overrides.get('weight_decay', 0.0)
    momentum = overrides.get('momentum', 0.0)

    if opt_name == 'adam':
        return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    if opt_name == 'adamw':
        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    if opt_name == 'sgd':
        # weight_decay used to be dropped silently on this branch; pass it
        # through so a 'weight_decay' override behaves the same for every
        # optimizer that supports it (the default of 0.0 is unchanged).
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum,
                               weight_decay=wd)
    if opt_name == 'lbfgs':
        return torch.optim.LBFGS(model.parameters(), lr=lr,
                                 max_iter=20, history_size=10)
    raise ValueError(f"unknown optimizer: {opt_name}")
215
+
216
+
217
def build_scheduler(opt: torch.optim.Optimizer, overrides: Dict[str, Any],
                    total_steps: int):
    """Group D: create the LR scheduler named by ``overrides['scheduler']``.

    Recognised names (default 'cosine'):
        'cosine'               -- CosineAnnealingLR with T_max=total_steps.
        'constant'             -- no scheduler; returns None.
        'linear'               -- LinearLR from factor 1.0 down to 0.01 over
                                  total_steps iterations.
        'cosine_warm_restarts' -- CosineAnnealingWarmRestarts with
                                  T_0=overrides['T_0'] (default 1000).
        'one_cycle'            -- OneCycleLR whose max_lr is the optimizer's
                                  current first-group learning rate.

    Raises:
        ValueError: for any other name.
    """
    schedulers = torch.optim.lr_scheduler
    sched_name = overrides.get('scheduler', 'cosine')

    # Lazily-evaluated factories: only the chosen scheduler is constructed,
    # since constructing one touches the optimizer's param groups.
    factories = {
        'cosine': lambda: schedulers.CosineAnnealingLR(opt, T_max=total_steps),
        'constant': lambda: None,
        'linear': lambda: schedulers.LinearLR(
            opt, start_factor=1.0, end_factor=0.01, total_iters=total_steps),
        'cosine_warm_restarts': lambda: schedulers.CosineAnnealingWarmRestarts(
            opt, T_0=overrides.get('T_0', 1000)),
        'one_cycle': lambda: schedulers.OneCycleLR(
            opt, max_lr=opt.param_groups[0]['lr'], total_steps=total_steps),
    }

    if sched_name not in factories:
        raise ValueError(f"unknown scheduler: {sched_name}")
    return factories[sched_name]()
237
+
238
+
239
+ # ────────────────────────────────────────────────────────────────────────
240
+ # Group N — uniform sphere CV prediction (cached per V,D)
241
+ # ────────────────────────────────────────────────────────────────────────
242
+
243
# Memo of previously computed predictions, keyed by (V, D, n_samples).
_UNIFORM_CV_CACHE: Dict[tuple, float] = {}


def uniform_sphere_cv_prediction(D: int, V: int = 64, n_samples: int = 2000,
                                 device: str = 'cuda') -> float:
    """Return the CV of V uniformly random unit rows on S^(D-1).

    The rows are drawn in float64 from a CPU generator with a fixed seed
    (12345), normalized to unit length, and scored with ``cv_of``. Results
    are memoized per (V, D, n_samples); ``device`` is not part of the key.
    The sample is moved to ``device`` only when CUDA is available.
    """
    cache_key = (V, D, n_samples)
    try:
        return _UNIFORM_CV_CACHE[cache_key]
    except KeyError:
        pass

    rng = torch.Generator(device='cpu').manual_seed(12345)
    rows = torch.randn(V, D, generator=rng, dtype=torch.float64)
    rows = rows / rows.norm(dim=-1, keepdim=True)
    if torch.cuda.is_available():
        rows = rows.to(device)

    prediction = cv_of(rows, n_samples=n_samples)
    _UNIFORM_CV_CACHE[cache_key] = prediction
    return prediction
260
+
261
+
262
+ # ────────────────────────────────────────────────────────────────────────
263
+ # RunConfig construction from band + overrides
264
+ # ────────────────────────────────────────────────────────────────────────
265
+
266
def build_run_config(ablation_config: Dict[str, Any]) -> RunConfig:
    """Assemble the RunConfig for one ablation.

    Starts from the band's representative geometry
    (``BAND_REPS[ablation_config['band']]``) plus fixed training defaults,
    then layers ``ablation_config['overrides']`` on top: first the keys that
    need renaming, then the keys that map one-to-one onto RunConfig fields.
    Override keys outside both sets are left for other consumers.
    """
    band_defaults = BAND_REPS[ablation_config['band']]
    user_overrides = ablation_config['overrides']

    cfg = RunConfig(
        matrix_v=band_defaults['V'],
        D=band_defaults['D'],
        patch_size=band_defaults['patch_size'],
        hidden=band_defaults['hidden'],
        depth=band_defaults['depth'],
        n_cross_layers=band_defaults['n_cross'],
        img_size=band_defaults['img_size'],
        batch_size=128,
        lr=1e-4,
        epochs=1,
        weight_decay=0.0,
        use_cv_ema=True,
        cv_alignment_epochs=0,
        cv_measure_every=50,
        boost=0.5,
        allowed_types=list(range(16)),
        train_size=1_000_000,
        val_size=10_000,
        num_workers=2,
        report_every=100,
        seed=ablation_config['seed'],
        upload=False,
    )

    # Override keys whose names differ from the RunConfig field they set.
    renamed_keys = (
        ('noise_types', 'allowed_types'),
        ('V', 'matrix_v'),
        ('n_cross', 'n_cross_layers'),
    )
    for override_key, field_name in renamed_keys:
        if override_key in user_overrides:
            cfg = replace(cfg, **{field_name: user_overrides[override_key]})

    # Override keys that are RunConfig field names already. Applied after
    # the renamed keys, so a direct field wins when both spellings appear.
    passthrough_fields = {'batch_size', 'lr', 'weight_decay', 'boost',
                          'allowed_types', 'n_cross_layers', 'max_alpha',
                          'matrix_v', 'D', 'hidden', 'patch_size',
                          'depth', 'n_heads', 'cv_measure_every'}
    for key, value in user_overrides.items():
        if key in passthrough_fields:
            cfg = replace(cfg, **{key: value})

    return cfg
314
+
315
+
316
+ # ────────────────────────────────────────────────────────────────────────
317
+ # Checkpoint save/load — resume-capable state
318
+ # ────────────────────────────────────────────────────────────────────────
319
+
320
def save_checkpoint(
    ckpt_path: str,
    epoch: int,
    model: nn.Module,
    opt: torch.optim.Optimizer,
    sched: Optional[Any],
    state: Dict[str, Any],
    ablation_config: Dict[str, Any],
    run_config: Any,
) -> None:
    """Write a resumable training snapshot to ``ckpt_path`` via ``torch.save``.

    The snapshot holds:
      - 'epoch' as given by the caller
      - model, optimizer and (when ``sched`` is not None) scheduler state dicts
      - 'ema_state': cv_ema, recon_ema_obs, last_prox (default 1.0) and
        last_cv (default 0.0) taken from ``state``
      - 'cv_trajectory' (default []) and 'global_batch' (default 0)
      - torch / numpy / CUDA RNG states (CUDA entry is None without CUDA)
      - ``ablation_config`` as-is, plus the int/float/str/bool/list fields of
        ``run_config`` (a dataclass instance)
      - 'params_finite': True when every model parameter is finite
    """
    # Scan for NaN/Inf parameters; stop at the first offending tensor.
    params_finite = True
    with torch.no_grad():
        for param in model.parameters():
            if not torch.isfinite(param).all().item():
                params_finite = False
                break

    scheduler_state = None if sched is None else sched.state_dict()

    ema_state = {
        'cv_ema': state.get('cv_ema'),
        'recon_ema_obs': state.get('recon_ema_obs'),
        'last_prox': state.get('last_prox', 1.0),
        'last_cv': state.get('last_cv', 0.0),
    }

    cuda_rng = None
    if torch.cuda.is_available():
        cuda_rng = torch.cuda.get_rng_state_all()
    rng_state = {
        'torch': torch.get_rng_state(),
        'numpy': np.random.get_state(),
        'cuda': cuda_rng,
    }

    # Keep only plainly serializable config fields.
    simple_types = (int, float, str, bool, list)
    run_config_fields = {}
    for field_name, field_value in asdict(run_config).items():
        if isinstance(field_value, simple_types):
            run_config_fields[field_name] = field_value

    torch.save(
        {
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': opt.state_dict(),
            'scheduler_state': scheduler_state,
            'ema_state': ema_state,
            'cv_trajectory': state.get('cv_trajectory', []),
            'global_batch': state.get('global_batch', 0),
            'rng_state': rng_state,
            'ablation_config': ablation_config,
            'run_config': run_config_fields,
            'params_finite': params_finite,
        },
        ckpt_path,
    )
371
+
372
+
373
def load_checkpoint(
    ckpt_path: str,
    model: nn.Module,
    opt: torch.optim.Optimizer,
    sched: Optional[Any] = None,
    restore_rng: bool = True,
) -> Dict[str, Any]:
    """Restore a snapshot into ``model`` / ``opt`` / ``sched`` in place.

    The file is unpickled on CPU with ``weights_only=False``, so only load
    checkpoints from a trusted source. The scheduler is restored only when
    both ``sched`` and a saved scheduler state exist. With ``restore_rng``
    the torch and numpy RNG states are reset, and the CUDA states too when
    CUDA is available and they were saved.

    Returns:
        Dict with keys: epoch, ema_state, cv_trajectory, global_batch,
        params_finite, ablation_config, run_config.
    """
    snapshot = torch.load(ckpt_path, map_location='cpu', weights_only=False)

    model.load_state_dict(snapshot['model_state'])
    opt.load_state_dict(snapshot['optimizer_state'])

    saved_sched_state = snapshot.get('scheduler_state')
    if not (sched is None or saved_sched_state is None):
        sched.load_state_dict(snapshot['scheduler_state'])

    if restore_rng:
        rng = snapshot['rng_state']
        torch.set_rng_state(rng['torch'])
        np.random.set_state(rng['numpy'])
        if torch.cuda.is_available() and rng.get('cuda') is not None:
            torch.cuda.set_rng_state_all(rng['cuda'])

    resumed = {
        'epoch': snapshot['epoch'],
        'ema_state': snapshot['ema_state'],
    }
    # Optional entries fall back to neutral defaults.
    resumed['cv_trajectory'] = snapshot.get('cv_trajectory', [])
    resumed['global_batch'] = snapshot.get('global_batch', 0)
    resumed['params_finite'] = snapshot.get('params_finite', True)
    resumed['ablation_config'] = snapshot['ablation_config']
    resumed['run_config'] = snapshot['run_config']
    return resumed
408
+
409
+
410
+ # ────────────────────────────────────────────────────────────────────────
411
+ # Main ablation run function
412
+ # ────────────────────────────────────────────────────────────────────────
413
+
414
def run_ablation_config(
    ablation_config: Dict[str, Any],
    output_dir: str,
    batch_limit: Optional[int] = 1000,
    num_epochs: int = 1,
    resume_from: Optional[str] = None,
) -> Dict[str, Any]:
    """Run one ablation config and return a result dict.

    Args:
        ablation_config: Dict with at least 'band', 'seed' and 'overrides'.
        output_dir: Directory for TensorBoard logs (``tensorboard/``) and the
            per-epoch checkpoints; created if missing.
        batch_limit: Number of training batches per epoch. A falsy value
            removes the limit and each epoch consumes the whole loader.
        num_epochs: Number of epochs to run in this call (must be >= 1).
        resume_from: Optional checkpoint path written by ``save_checkpoint``;
            training continues from the epoch stored there.

    Returns:
        Result dict with the final classification, reconstruction and
        geometry metrics, per-epoch metrics, and the CV / loss trajectories.

    Raises:
        ValueError: if ``num_epochs`` is smaller than 1. The evaluation that
            fills the result dict runs once per epoch, so a run with zero
            epochs has nothing to report.
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    cfg = build_run_config(ablation_config)

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.set_float32_matmul_precision('high')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs(output_dir, exist_ok=True)

    # --- TensorBoard writer -------------------------------------------
    tb_dir = os.path.join(output_dir, "tensorboard")
    os.makedirs(tb_dir, exist_ok=True)
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(tb_dir)
    try:
        return _run_ablation_with_writer(
            ablation_config, output_dir, batch_limit, num_epochs,
            resume_from, cfg, device, writer)
    finally:
        # Release the event file even when training raises part-way.
        writer.close()


def _run_ablation_with_writer(
    ablation_config: Dict[str, Any],
    output_dir: str,
    batch_limit: Optional[int],
    num_epochs: int,
    resume_from: Optional[str],
    cfg: Any,
    device: torch.device,
    writer: Any,
) -> Dict[str, Any]:
    """Build model/data/optimizer, train, evaluate and checkpoint each epoch.

    Logs to ``writer`` but does not close it; the caller owns the writer.
    """
    overrides = ablation_config['overrides']

    # --- Model with ablation hooks ------------------------------------
    model = PatchSVAE_F_Ablation(
        matrix_v=cfg.matrix_v, D=cfg.D, patch_size=cfg.patch_size,
        hidden=cfg.hidden, depth=cfg.depth,
        n_cross_layers=cfg.n_cross_layers, n_heads=cfg.n_heads,
        max_alpha=overrides.get('max_alpha', cfg.max_alpha),
        alpha_init=cfg.alpha_init,
        # ablation hooks
        activation=overrides.get('activation', 'gelu'),
        row_norm=overrides.get('row_norm', 'sphere'),
        svd_mode=overrides.get('svd', 'fp64'),
        linear_readout=overrides.get('linear_readout', False),
        match_params=overrides.get('match_params', True),
        init_scheme=overrides.get('init', 'orthogonal'),
    ).to(device)

    n_params = sum(p.numel() for p in model.parameters())

    # --- Data ---------------------------------------------------------
    train_ds = OmegaNoiseDataset(
        size=cfg.train_size, img_size=cfg.img_size,
        allowed_types=cfg.allowed_types)
    # The validation set and loader are not read below (evaluation builds its
    # own per-noise datasets). They are still constructed in the original
    # order in case dataset construction has side effects -- TODO confirm
    # and drop if it has none.
    val_ds = OmegaNoiseDataset(
        size=cfg.val_size, img_size=cfg.img_size,
        allowed_types=cfg.allowed_types)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=cfg.batch_size, shuffle=True,
        num_workers=cfg.num_workers, pin_memory=True, drop_last=True,
        persistent_workers=cfg.num_workers > 0)
    val_loader = torch.utils.data.DataLoader(  # noqa: F841 (see note above)
        val_ds, batch_size=cfg.batch_size, shuffle=False,
        num_workers=cfg.num_workers, pin_memory=True,
        persistent_workers=cfg.num_workers > 0)

    # --- Optimizer + scheduler ----------------------------------------
    # NOTE(review): the scheduler horizon covers ONE epoch of batches, while
    # the loop below steps it on every batch of every epoch. With
    # num_epochs > 1 (or after a resume) 'one_cycle' is stepped past its
    # total_steps and 'cosine' past T_max -- confirm this is intended.
    effective_steps = batch_limit if batch_limit else (cfg.train_size // cfg.batch_size)
    opt = build_optimizer(model, overrides, cfg.lr)
    sched = build_scheduler(opt, overrides, total_steps=effective_steps)
    grad_clip = overrides.get('grad_clip', None)

    # --- Soft-hand variants (Group E) ---------------------------------
    use_soft_hand = overrides.get('soft_hand', True)
    cv_penalty = overrides.get('cv_penalty', 0.0)
    hard_cv_target = overrides.get('hard_cv_target', None)
    cv_measurement_only = overrides.get('cv_measurement_only', False)
    boost_factor = cfg.boost if use_soft_hand else 0.0

    # --- Training loop ------------------------------------------------
    start_time = time.time()
    model.train()

    # State initialization (may be overwritten by resume)
    last_cv = 0.0
    cv_ema = None
    recon_ema_obs = None
    last_prox = 1.0
    cv_trajectory = []
    train_loss_trajectory = []  # per-step recon MSE, independent of CV measurement
    global_batch = 0
    start_epoch = 0

    # --- Resume from checkpoint if provided ---------------------------
    if resume_from is not None:
        resumed = load_checkpoint(resume_from, model, opt, sched, restore_rng=True)
        start_epoch = resumed['epoch']  # next epoch to run
        last_cv = resumed['ema_state'].get('last_cv', 0.0)
        cv_ema = resumed['ema_state'].get('cv_ema')
        recon_ema_obs = resumed['ema_state'].get('recon_ema_obs')
        last_prox = resumed['ema_state'].get('last_prox', 1.0)
        cv_trajectory = resumed.get('cv_trajectory', [])
        global_batch = resumed.get('global_batch', 0)
        print(f"  Resumed from epoch {start_epoch}, global_batch {global_batch}")

    # --- Per-epoch tracking -------------------------------------------
    per_epoch_metrics = []

    for epoch in range(start_epoch, start_epoch + num_epochs):
        model.train()
        epoch_start = time.time()
        # Cumulative batch budget: epoch k stops once global_batch reaches
        # batch_limit * (k + 1), which also holds after a resume.
        epoch_batch_target = batch_limit * (epoch + 1) if batch_limit else None

        for images, _ in train_loader:
            if epoch_batch_target is not None and global_batch >= epoch_batch_target:
                break

            images = images.to(device, non_blocking=True)
            opt.zero_grad()

            if isinstance(opt, torch.optim.LBFGS):
                # --- LBFGS path ---------------------------------------
                # Closure builds the SAME loss as the Adam path: pure MSE,
                # plus soft-hand boost, plus optional hard_cv_target penalty.
                # Closure may be called multiple times per outer step (line
                # search); soft-hand uses last_prox from the previous outer
                # batch's CV measurement, which is constant across inner calls.
                #
                # CRITICAL: NO gradient clipping inside the closure.
                # LBFGS's Hessian approximation uses (s_k, y_k) = (delta-param,
                # delta-grad) pairs across steps. Clipping grad norm bounds y_k
                # artificially while s_k stays large, causing H ~ y/s to be
                # underestimated -> H^-1 becomes huge -> step size explodes.
                # Step safety for LBFGS comes from strong_wolfe line search,
                # not gradient clipping. (Original bug fix in 000079 moved
                # clipping INTO the closure; that fix itself was the bug --
                # verified 2026-04-24 via Q-sweep Rank-1 G-MSE=4.5e26.)
                def closure():
                    opt.zero_grad()
                    _out = model(images)
                    _recon = F.mse_loss(_out['recon'], images)

                    if use_soft_hand and not cv_measurement_only:
                        _recon_w = 1.0 + boost_factor * last_prox
                        _loss = _recon_w * _recon
                    else:
                        _loss = _recon

                    if hard_cv_target is not None and cv_ema is not None and cv_penalty > 0:
                        _cv_loss = (cv_ema - hard_cv_target) ** 2
                        _loss = _loss + cv_penalty * _cv_loss

                    _loss.backward()
                    # NO GRADIENT CLIPPING HERE -- see comment above.
                    return _loss

                opt.step(closure)
                # Post-step forward to get the settled state for measurement
                with torch.no_grad():
                    out = model(images)
                    recon_val = F.mse_loss(out['recon'], images).item()
            else:
                # --- Adam / SGD / AdamW path --------------------------
                out = model(images)
                recon_loss = F.mse_loss(out['recon'], images)
                recon_val = recon_loss.item()

                # Build loss with E-group ablations
                if use_soft_hand and not cv_measurement_only:
                    recon_w = 1.0 + boost_factor * last_prox
                    loss = recon_w * recon_loss
                else:
                    loss = recon_loss

                if hard_cv_target is not None and cv_ema is not None and cv_penalty > 0:
                    cv_loss_val = (cv_ema - hard_cv_target) ** 2
                    loss = loss + cv_penalty * cv_loss_val

                loss.backward()
                # Cross-attention gradients are always clipped; the global
                # clip is applied on top only when 'grad_clip' is set.
                torch.nn.utils.clip_grad_norm_(
                    model.cross_attn.parameters(), max_norm=cfg.cross_attn_clip)
                if grad_clip is not None:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), max_norm=grad_clip)
                opt.step()

            # --- Shared post-step measurement block -------------------
            # Runs for BOTH LBFGS and non-LBFGS paths. Updates EMA, measures
            # CV at intervals, computes prox for next batch's soft-hand.
            with torch.no_grad():
                if recon_ema_obs is None:
                    recon_ema_obs = recon_val
                else:
                    recon_ema_obs = 0.99 * recon_ema_obs + 0.01 * recon_val

                # Per-step loss trajectory -- independent of CV measurement
                # so small-V configs (where cv_of returns 0) still have a
                # visible training curve.
                train_loss_trajectory.append({
                    'batch': global_batch,
                    'recon': recon_val,
                })

                # TB: per-batch scalars (cheap)
                writer.add_scalar('train/recon', recon_val, global_batch)
                writer.add_scalar('train/recon_ema', recon_ema_obs, global_batch)
                writer.add_scalar('train/lr', opt.param_groups[0]['lr'], global_batch)

                if global_batch % cfg.cv_measure_every == 0:
                    # CV is sampled from the first patch of the first image.
                    current_cv = cv_of(out['svd']['M'][0, 0])
                    if current_cv > 0:
                        last_cv = current_cv
                        if cv_ema is None:
                            cv_ema = current_cv
                        else:
                            cv_ema = ((1.0 - cfg.cv_ema_alpha) * cv_ema
                                      + cfg.cv_ema_alpha * current_cv)
                        cv_trajectory.append({
                            'batch': global_batch,
                            'cv': current_cv,
                            'cv_ema': cv_ema,
                            'recon': recon_val,
                        })
                        # TB: geometric measurements (when CV is measured)
                        writer.add_scalar('geo/cv', current_cv, global_batch)
                        writer.add_scalar('geo/cv_ema', cv_ema, global_batch)
                        # S spectrum diagnostic
                        S_now = out['svd']['S'][0, 0]
                        writer.add_scalar('geo/S0', S_now[0].item(), global_batch)
                        writer.add_scalar('geo/SD', S_now[-1].item(), global_batch)
                        writer.add_scalar('geo/ratio',
                                          (S_now[0] / (S_now[-1] + 1e-8)).item(),
                                          global_batch)
                    if cv_ema is not None and cv_ema > 1e-6:
                        # Gaussian proximity of the latest CV to its EMA;
                        # feeds the soft-hand weight of the next batch.
                        sigma_adapt = max(cfg.cv_sigma_scale * cv_ema, 1e-6)
                        delta = last_cv - cv_ema
                        last_prox = math.exp(-(delta ** 2) / (2 * sigma_adapt ** 2))
                        writer.add_scalar('stab/prox', last_prox, global_batch)

            if sched is not None:
                sched.step()

            global_batch += 1

        # --- Final evaluation (runs at the end of every epoch) ---------
        # PROPER TEST STAGE: evaluate on all 16 noise types separately.
        # Previous behavior (one batch from val_loader using training
        # allowed_types) was a validation metric, not a test metric --
        # gaussian-only-trained batteries never saw pink/brown/poisson
        # at eval time, so the old test_mse_final was invalid for
        # detecting which noises the model suffered at.
        model.eval()

        test_noise_types = overrides.get('test_noise_types', list(range(16)))
        test_samples_per_noise = overrides.get('test_samples_per_noise', 256)
        test_batch_size = overrides.get('test_batch_size', 64)

        test_mse_per_noise = {}  # noise_type (int) -> mean MSE

        # First: run one canonical batch for geometric measurements.
        # Use gaussian (noise_type=0) so measurements are comparable
        # across all configs regardless of training distribution.
        with torch.no_grad():
            geom_ds = OmegaNoiseDataset(
                size=test_batch_size, img_size=cfg.img_size,
                allowed_types=[0])
            geom_loader = torch.utils.data.DataLoader(
                geom_ds, batch_size=test_batch_size, shuffle=False,
                num_workers=0, pin_memory=True, drop_last=True)
            geom_imgs, _ = next(iter(geom_loader))
            geom_imgs = geom_imgs.to(device)
            t_out = model(geom_imgs)

            final_cv = cv_of(t_out['svd']['M'][0, 0], n_samples=500)
            S_final = t_out['svd']['S'].mean(dim=(0, 1))
            S0 = S_final[0].item()
            SD = S_final[-1].item()
            ratio = S0 / (SD + 1e-8)
            erank = PatchSVAE_F.effective_rank(
                t_out['svd']['S'].reshape(-1, cfg.D)).mean().item()
            observed_cv_precise = cv_of(
                t_out['svd']['M'][0, 0], n_samples=2000)

        # Second: per-noise test MSE. Separate dataset per noise type so
        # each is pure, test_samples_per_noise samples each.
        with torch.no_grad():
            for nt in test_noise_types:
                nt_ds = OmegaNoiseDataset(
                    size=test_samples_per_noise,
                    img_size=cfg.img_size,
                    allowed_types=[nt])
                nt_loader = torch.utils.data.DataLoader(
                    nt_ds, batch_size=test_batch_size, shuffle=False,
                    num_workers=0, pin_memory=True, drop_last=False)
                mse_chunks = []
                for imgs, _ in nt_loader:
                    imgs = imgs.to(device)
                    out = model(imgs)
                    # Per-sample MSE so a short last batch is weighted fairly.
                    mse = F.mse_loss(out['recon'], imgs,
                                     reduction='none').mean(dim=(1, 2, 3))
                    mse_chunks.append(mse)
                test_mse_per_noise[nt] = torch.cat(mse_chunks).mean().item()

        # Backward-compat aggregate -- mean across tested noises
        test_mse_final = sum(test_mse_per_noise.values()) / max(
            1, len(test_mse_per_noise))

        uniform_cv = uniform_sphere_cv_prediction(
            cfg.D, V=cfg.matrix_v,
            device='cuda' if torch.cuda.is_available() else 'cpu')

        # Band classification
        classify_on = cv_ema if cv_ema is not None else final_cv
        predicted_band = band_classifier(classify_on)
        expected_band = ablation_config['band']

        # Final TB summary scalars for this epoch (also written to epoch/* below)
        writer.add_scalar('summary/test_mse_final', test_mse_final, global_batch)
        writer.add_scalar('summary/cv_ema_final',
                          cv_ema if cv_ema is not None else 0.0, global_batch)
        writer.add_scalar('summary/observed_sphere_cv', observed_cv_precise, global_batch)
        writer.add_scalar('summary/uniform_sphere_cv_pred', uniform_cv, global_batch)
        writer.add_scalar('summary/band_deviation',
                          observed_cv_precise - uniform_cv, global_batch)
        writer.add_scalar('summary/erank', erank, global_batch)

        # --- Per-epoch checkpoint save ---------------------------------
        ckpt_path = os.path.join(output_dir, f'epoch_{epoch+1}_checkpoint.pt')
        save_checkpoint(
            ckpt_path=ckpt_path,
            epoch=epoch + 1,  # next epoch to run on resume
            model=model,
            opt=opt,
            sched=sched,
            state={
                'cv_ema': cv_ema,
                'recon_ema_obs': recon_ema_obs,
                'last_prox': last_prox,
                'last_cv': last_cv,
                'cv_trajectory': cv_trajectory,
                'global_batch': global_batch,
            },
            ablation_config=ablation_config,
            run_config=cfg,
        )

        # --- Record per-epoch metrics ----------------------------------
        with torch.no_grad():
            params_finite = all(torch.isfinite(p).all().item()
                                for p in model.parameters())
        per_epoch_metrics.append({
            'epoch': epoch + 1,
            'test_mse': test_mse_final,
            'test_mse_per_noise': {int(k): float(v)
                                   for k, v in test_mse_per_noise.items()},
            'cv_ema': cv_ema if cv_ema is not None else 0.0,
            'observed_sphere_cv': observed_cv_precise,
            'band_deviation': observed_cv_precise - uniform_cv,
            'erank': erank,
            'params_finite': params_finite,
            'wallclock_seconds': time.time() - epoch_start,
            'checkpoint_path': ckpt_path,
        })

        # TB: per-epoch summary
        writer.add_scalar('epoch/test_mse', test_mse_final, epoch + 1)
        writer.add_scalar('epoch/cv_ema', cv_ema if cv_ema is not None else 0.0, epoch + 1)
        writer.add_scalar('epoch/observed_sphere_cv', observed_cv_precise, epoch + 1)

    # --- End of epoch loop ---------------------------------------------
    writer.flush()

    # Compute params_finite once more at the very end
    with torch.no_grad():
        final_params_finite = all(torch.isfinite(p).all().item()
                                  for p in model.parameters())

    wallclock = time.time() - start_time

    # Evaluation values below are those of the LAST epoch run.
    return {
        'config': ablation_config,
        'run_config': {k: v for k, v in asdict(cfg).items()
                       if isinstance(v, (int, float, str, bool, list))},
        # Classification
        'cv_ema_final': cv_ema if cv_ema is not None else 0.0,
        'cv_last': last_cv,
        'predicted_band': predicted_band,
        'expected_band': expected_band,
        'band_match': predicted_band == expected_band,
        # Reconstruction
        'test_mse': test_mse_final,
        'test_mse_per_noise': {int(k): float(v)
                               for k, v in test_mse_per_noise.items()},
        'recon_ema': recon_ema_obs if recon_ema_obs is not None else 0.0,
        # Geometry
        'S0': S0,
        'SD': SD,
        'ratio': ratio,
        'erank': erank,
        # Group N
        'observed_sphere_cv': observed_cv_precise,
        'uniform_sphere_cv_prediction': uniform_cv,
        'band_deviation': observed_cv_precise - uniform_cv,
        # Finite-params flag (False means training went to NaN/Inf)
        'params_finite': final_params_finite,
        # Multi-epoch tracking
        'num_epochs_run': num_epochs,
        'start_epoch': start_epoch,
        'per_epoch_metrics': per_epoch_metrics,
        # Bookkeeping
        'params_count': n_params,
        'wallclock_seconds': wallclock,
        'batches_completed': global_batch,
        'batch_limit': batch_limit,
        'cv_trajectory': cv_trajectory,
        'train_loss_trajectory': train_loss_trajectory,
    }