Create 4_monkey_patched_trainer.py
Browse files- 4_monkey_patched_trainer.py +825 -0
4_monkey_patched_trainer.py
ADDED
|
@@ -0,0 +1,825 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ablation_trainer.py
|
| 3 |
+
===================
|
| 4 |
+
Ablation trainer adapter for PatchSVAE_F (Johanna F-class).
|
| 5 |
+
|
| 6 |
+
Takes an ablation config dict, builds a proper RunConfig with overrides,
|
| 7 |
+
instantiates a PatchSVAE_F_Ablation subclass with the needed hooks,
|
| 8 |
+
runs the real training loop with batch-limit early stop, measures CV
|
| 9 |
+
throughout, computes Group N uniformity diagnostic at the end, and
|
| 10 |
+
returns a result dict ready for upload.
|
| 11 |
+
|
| 12 |
+
Imports from johanna_F_trainer.py. Drop this in alongside it in Colab.
|
| 13 |
+
|
| 14 |
+
Ablation hooks implemented:
|
| 15 |
+
Group A (seeds): pure seed variation via RunConfig.seed
|
| 16 |
+
Group B (noise types): overrides['noise_types'] -> RunConfig.allowed_types
|
| 17 |
+
Group C (optimizer): adam/sgd/adamw/lbfgs via build_optimizer
|
| 18 |
+
Group D (scheduler): cosine/constant/linear/warm_restart/one_cycle
|
| 19 |
+
Group E (soft-hand): soft_hand + boost + cv_penalty + hard_cv_target
|
| 20 |
+
Group F (activation): enc_in activation function swap
|
| 21 |
+
Group G (row_norm): sphere/none/layer_norm/scale_only
|
| 22 |
+
Group H (SVD): svd = fp64/fp32/batch_shared, or linear_readout=True (+ match_params)
|
| 23 |
+
Group I (cross-attn): n_cross_layers + max_alpha
|
| 24 |
+
Group J (capacity): V and hidden overrides (within LOW band)
|
| 25 |
+
Group K (batch size): batch_size override
|
| 26 |
+
Group L (init): orthogonal/kaiming_normal/xavier_uniform/normal_0_02
|
| 27 |
+
Group M (brute SGD): optimizer + lr + momentum + grad_clip
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
import os
|
| 31 |
+
import math
|
| 32 |
+
import time
|
| 33 |
+
from dataclasses import asdict, replace
|
| 34 |
+
from typing import Dict, Any, Optional, List
|
| 35 |
+
|
| 36 |
+
import numpy as np
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn as nn
|
| 39 |
+
import torch.nn.functional as F
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 43 |
+
# Ablation hooks (Groups F/G/H/L) -- implemented as a PatchSVAE_F subclass
|
| 44 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 45 |
+
|
| 46 |
+
# Group F: activation callables selectable by name. 'identity' passes the
# tensor through untouched.
ACTIVATIONS = dict(
    gelu=F.gelu,
    relu=F.relu,
    silu=F.silu,
    tanh=torch.tanh,
    identity=lambda x: x,
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def row_normalize(M: torch.Tensor, mode: str) -> torch.Tensor:
    """Apply one of the Group G row-normalization variants to ``M``.

    Rows are taken along the last dimension of ``M``.

    Modes:
        'sphere'     -- L2-normalize every row with ``F.normalize``.
        'none'       -- return ``M`` itself, untouched.
        'layer_norm' -- subtract each row's mean and divide by its biased
                        standard deviation (1e-8 is added to the variance).
        'scale_only' -- divide all rows by the mean row norm taken over
                        dim -2; rows are rescaled but not forced to unit
                        length.

    Raises:
        ValueError: if ``mode`` is none of the names above.
    """
    if mode == 'none':
        return M
    if mode == 'sphere':
        return F.normalize(M, dim=-1)
    if mode == 'layer_norm':
        row_mean = M.mean(dim=-1, keepdim=True)
        row_var = M.var(dim=-1, keepdim=True, unbiased=False)
        return (M - row_mean) / (row_var + 1e-8).sqrt()
    if mode == 'scale_only':
        per_row = M.norm(dim=-1, keepdim=True)
        avg_norm = per_row.mean(dim=-2, keepdim=True)
        return M / (avg_norm + 1e-8)
    raise ValueError(f"unknown row_norm mode: {mode}")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def init_weights(module: nn.Module, scheme: str) -> None:
    """Group L: re-initialize every ``nn.Linear`` inside ``module`` in place.

    The weight of each Linear layer is filled according to ``scheme`` and
    its bias, when present, is zeroed.

    Args:
        module: root module; all sub-modules are visited via ``modules()``.
        scheme: one of 'orthogonal', 'kaiming_normal', 'xavier_uniform',
            'normal_0_02' (normal with mean 0.0 and std 0.02).

    Raises:
        ValueError: if ``scheme`` is not one of the names above. An
            unrecognized name used to be a silent no-op, which let a
            mistyped ablation config run with whatever initialization the
            model already had.
    """
    weight_inits = {
        'orthogonal': nn.init.orthogonal_,
        'kaiming_normal': nn.init.kaiming_normal_,
        'xavier_uniform': nn.init.xavier_uniform_,
        'normal_0_02': lambda w: nn.init.normal_(w, mean=0.0, std=0.02),
    }
    try:
        init_weight = weight_inits[scheme]
    except KeyError:
        raise ValueError(f"unknown init scheme: {scheme}") from None

    for m in module.modules():
        if isinstance(m, nn.Linear):
            init_weight(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class PatchSVAE_F_Ablation(PatchSVAE_F):
    """PatchSVAE_F with ablation hooks for F/G/H/L groups.

    At default settings (gelu / sphere / fp64 / orthogonal / no linear
    readout) this behaves identically to PatchSVAE_F.

    Hook arguments (all other positional/keyword arguments are forwarded
    unchanged to ``PatchSVAE_F.__init__``):
        activation: key into ``ACTIVATIONS``; applied to the ``enc_in``
            output in ``encode_patches``.
        row_norm: mode string handed to ``row_normalize``.
        svd_mode: 'fp32' or 'batch_shared' select the alternative
            decompositions in ``encode_patches``; any other value takes
            the ``_svd_fp64`` path.
        linear_readout: when True the decomposition is skipped and
            ``self.readout`` produces stand-in U/S/Vt tensors.
        match_params: only read when ``linear_readout`` is True; picks
            ``nn.Linear(V*D, V*D)`` (True) or ``nn.Identity`` (False).
        init_scheme: any value other than 'orthogonal' re-initializes the
            Linear layers through ``init_weights``.
    """
    def __init__(self, *args,
                 activation: str = 'gelu',
                 row_norm: str = 'sphere',
                 svd_mode: str = 'fp64',
                 linear_readout: bool = False,
                 match_params: bool = True,
                 init_scheme: str = 'orthogonal',
                 **kwargs):
        super().__init__(*args, **kwargs)
        # Raises KeyError for an activation name missing from ACTIVATIONS.
        self.activation_fn = ACTIVATIONS[activation]
        self.row_norm_mode = row_norm
        self.svd_mode = svd_mode
        self.linear_readout = linear_readout

        if linear_readout:
            # Readout operates on the flattened (matrix_v * D) encoded matrix.
            readout_dim = self.matrix_v * self.D
            if match_params:
                self.readout = nn.Linear(readout_dim, readout_dim)
            else:
                self.readout = nn.Identity()

        # Re-initialize under the requested scheme
        if init_scheme != 'orthogonal':
            # The readout layer (if any) already exists here, so it is
            # re-initialized together with the inherited Linear layers.
            init_weights(self, init_scheme)
            # Re-apply orthogonal to enc_out specifically -- load-bearing per
            # the architecture docs
            nn.init.orthogonal_(self.enc_out.weight)

    def encode_patches(self, patches):
        """Encode patches into U/S/Vt factors with the ablation hooks applied.

        ``patches`` is unpacked as (B, N, features). Returns a dict with:
            'U'      -- reshaped to (B, N, matrix_v, D)
            'S_orig' -- reshaped to (B, N, D), before the cross-attn layers
            'S'      -- 'S_orig' passed through every layer in self.cross_attn
            'Vt'     -- reshaped to (B, N, D, D)
            'M'      -- the row-normalized encoded matrix, (B, N, matrix_v, D)
        """
        B, N, _ = patches.shape
        flat = patches.reshape(B * N, -1)

        # F group: activation swap on the enc_in
        h = self.activation_fn(self.enc_in(flat))
        for block in self.enc_blocks:
            # Inner block activations (GELU inside the nn.Sequential) remain.
            # For a complete F ablation this isn't a perfect swap but it's
            # representative of the outer activation.
            h = h + block(h)

        M = self.enc_out(h).reshape(B * N, self.matrix_v, self.D)

        # G group: row normalization mode
        M = row_normalize(M, self.row_norm_mode)

        # H group: SVD variant or linear-readout replacement
        if self.linear_readout:
            flat_M = M.reshape(B * N, -1)
            M_hat = self.readout(flat_M).reshape(B * N, self.matrix_v, self.D)
            # Synthetic U/S/Vt so downstream code runs unchanged
            U = M_hat
            S = M_hat.norm(dim=-2)  # column norms as stand-in singular values
            Vt = torch.eye(self.D, device=M.device, dtype=M.dtype
                           ).unsqueeze(0).expand(B * N, -1, -1)
        elif self.svd_mode == 'fp32':
            # Same algorithm as fp64 path but without the autocast disable
            # Eigendecomposition of the D x D Gram matrix M^T M.
            G = torch.bmm(M.transpose(1, 2), M)
            G.diagonal(dim1=-2, dim2=-1).add_(1e-6)  # relaxed reg for fp32
            eigenvalues, Vmat = torch.linalg.eigh(G)
            # eigh returns ascending eigenvalues; flip to descending order.
            eigenvalues = eigenvalues.flip(-1)
            Vmat = Vmat.flip(-1)
            S = torch.sqrt(eigenvalues.clamp(min=1e-12))
            U = torch.bmm(M, Vmat) / S.unsqueeze(1).clamp(min=1e-8)
            Vt = Vmat.transpose(-2, -1).contiguous()
        elif self.svd_mode == 'batch_shared':
            # One SVD per batch instead of per patch; S and Vt replicated across N
            M_batched = M.reshape(B, N * self.matrix_v, self.D)
            U_b, S_b, Vt_b = _svd_fp64(M_batched)
            S = S_b.unsqueeze(1).expand(-1, N, -1).reshape(B * N, self.D)
            Vt = Vt_b.unsqueeze(1).expand(-1, N, -1, -1).reshape(B * N, self.D, self.D)
            # U is recomputed per patch from the shared S / Vt; U_b is unused.
            U = torch.bmm(M, Vt.transpose(-2, -1)) / S.unsqueeze(1).clamp(min=1e-16)
        else:  # 'fp64' default
            U, S, Vt = _svd_fp64(M)

        # Restore the separate batch (B) and patch (N) dimensions.
        U = U.reshape(B, N, self.matrix_v, self.D)
        S = S.reshape(B, N, self.D)
        Vt = Vt.reshape(B, N, self.D, self.D)
        M = M.reshape(B, N, self.matrix_v, self.D)
        S_coord = S
        for layer in self.cross_attn:
            S_coord = layer(S_coord)
        return {'U': U, 'S_orig': S, 'S': S_coord, 'Vt': Vt, 'M': M}
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 193 |
+
# Optimizer / scheduler builders
|
| 194 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 195 |
+
|
| 196 |
+
def build_optimizer(model: nn.Module, overrides: Dict[str, Any],
                    base_lr: float) -> torch.optim.Optimizer:
    """Groups C, M: build the optimizer named by ``overrides['optimizer']``.

    Args:
        model: module whose parameters are optimized.
        overrides: ablation overrides. Keys read here:
            'optimizer' ('adam' default, 'adamw', 'sgd', 'lbfgs'),
            'lr' (defaults to ``base_lr``), 'weight_decay' (default 0.0),
            'momentum' (default 0.0, SGD only) and 'line_search_fn'
            (LBFGS only, default 'strong_wolfe').
        base_lr: learning rate used when ``overrides`` has no 'lr'.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: if the optimizer name is not recognized.
    """
    opt_name = overrides.get('optimizer', 'adam')
    lr = overrides.get('lr', base_lr)
    wd = overrides.get('weight_decay', 0.0)
    momentum = overrides.get('momentum', 0.0)

    if opt_name == 'adam':
        return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    if opt_name == 'adamw':
        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    if opt_name == 'sgd':
        # weight_decay was previously read but never forwarded to SGD, so a
        # 'weight_decay' override was silently ignored for this optimizer.
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum,
                               weight_decay=wd)
    if opt_name == 'lbfgs':
        # The LBFGS training path deliberately skips gradient clipping and
        # relies on the strong-Wolfe line search for step safety, so the
        # line search has to be enabled here (torch's default is None).
        # Pass overrides['line_search_fn'] = None to opt out explicitly.
        return torch.optim.LBFGS(
            model.parameters(), lr=lr, max_iter=20, history_size=10,
            line_search_fn=overrides.get('line_search_fn', 'strong_wolfe'))
    raise ValueError(f"unknown optimizer: {opt_name}")
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def build_scheduler(opt: torch.optim.Optimizer, overrides: Dict[str, Any],
                    total_steps: int):
    """Group D: build the LR scheduler named by ``overrides['scheduler']``.

    Recognized names: 'cosine' (default), 'constant', 'linear',
    'cosine_warm_restarts' (period from ``overrides['T_0']``, default 1000)
    and 'one_cycle' (peak LR taken from the optimizer's first param group).

    Returns:
        The scheduler instance, or ``None`` for 'constant'.

    Raises:
        ValueError: if the scheduler name is not recognized.
    """
    lr_sched = torch.optim.lr_scheduler
    sched_name = overrides.get('scheduler', 'cosine')

    if sched_name == 'constant':
        # No scheduler object at all: the optimizer keeps its initial LR.
        return None
    if sched_name == 'cosine':
        return lr_sched.CosineAnnealingLR(opt, T_max=total_steps)
    if sched_name == 'linear':
        return lr_sched.LinearLR(
            opt, start_factor=1.0, end_factor=0.01, total_iters=total_steps)
    if sched_name == 'cosine_warm_restarts':
        restart_period = overrides.get('T_0', 1000)
        return lr_sched.CosineAnnealingWarmRestarts(opt, T_0=restart_period)
    if sched_name == 'one_cycle':
        peak_lr = opt.param_groups[0]['lr']
        return lr_sched.OneCycleLR(opt, max_lr=peak_lr, total_steps=total_steps)
    raise ValueError(f"unknown scheduler: {sched_name}")
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 240 |
+
# Group N -- uniform sphere CV prediction (cached per V, D, n_samples)
|
| 241 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 242 |
+
|
| 243 |
+
# Memo of already-computed predictions, keyed by (V, D, n_samples).
_UNIFORM_CV_CACHE: Dict[tuple, float] = {}


def uniform_sphere_cv_prediction(D: int, V: int = 64, n_samples: int = 2000,
                                 device: str = 'cuda') -> float:
    """CV prediction for uniformly random rows on S^(D-1).

    Draws a (V, D) float64 Gaussian matrix from a fixed-seed CPU generator
    (seed 12345), scales each row to unit norm and passes it to ``cv_of``.
    The result is memoized per (V, D, n_samples); ``device`` is not part of
    the cache key.

    Args:
        D: row dimensionality.
        V: number of rows.
        n_samples: forwarded to ``cv_of``.
        device: where the matrix is moved before measuring; only used when
            CUDA is available, otherwise the matrix stays on the CPU.

    Returns:
        Whatever ``cv_of`` returns (annotated as float -- confirm against
        ``cv_of``).
    """
    cache_key = (V, D, n_samples)
    try:
        return _UNIFORM_CV_CACHE[cache_key]
    except KeyError:
        pass

    rng = torch.Generator(device='cpu').manual_seed(12345)
    rows = torch.randn(V, D, generator=rng, dtype=torch.float64)
    rows = rows / rows.norm(dim=-1, keepdim=True)
    if torch.cuda.is_available():
        rows = rows.to(device)

    prediction = cv_of(rows, n_samples=n_samples)
    _UNIFORM_CV_CACHE[cache_key] = prediction
    return prediction
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 263 |
+
# RunConfig construction from band + overrides
|
| 264 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 265 |
+
|
| 266 |
+
def build_run_config(ablation_config: Dict[str, Any]) -> RunConfig:
    """Build a RunConfig from band defaults plus the ablation's overrides.

    ``ablation_config`` must provide 'band' (a key into ``BAND_REPS``),
    'seed' and 'overrides'. Architecture fields come from the band entry,
    the remaining fields are fixed defaults, and overrides are applied on
    top: first the renamed keys, then keys that match RunConfig field names
    directly. When both forms target the same field the direct one wins
    because it is applied last. Override keys outside those two sets are
    ignored here.
    """
    band = BAND_REPS[ablation_config['band']]
    overrides = ablation_config['overrides']

    base_fields = dict(
        matrix_v=band['V'],
        D=band['D'],
        patch_size=band['patch_size'],
        hidden=band['hidden'],
        depth=band['depth'],
        n_cross_layers=band['n_cross'],
        img_size=band['img_size'],
        batch_size=128,
        lr=1e-4,
        epochs=1,
        weight_decay=0.0,
        use_cv_ema=True,
        cv_alignment_epochs=0,
        cv_measure_every=50,
        boost=0.5,
        allowed_types=list(range(16)),
        train_size=1_000_000,
        val_size=10_000,
        num_workers=2,
        report_every=100,
        seed=ablation_config['seed'],
        upload=False,
    )
    cfg = RunConfig(**base_fields)

    # Field remappings: some override keys don't match RunConfig names
    renamed_keys = (
        ('noise_types', 'allowed_types'),
        ('V', 'matrix_v'),
        ('n_cross', 'n_cross_layers'),
    )
    for override_key, field_name in renamed_keys:
        if override_key in overrides:
            cfg = replace(cfg, **{field_name: overrides[override_key]})

    # Direct field overrides
    direct_fields = {'batch_size', 'lr', 'weight_decay', 'boost',
                     'allowed_types', 'n_cross_layers', 'max_alpha',
                     'matrix_v', 'D', 'hidden', 'patch_size',
                     'depth', 'n_heads', 'cv_measure_every'}
    for field_name, value in overrides.items():
        if field_name in direct_fields:
            cfg = replace(cfg, **{field_name: value})

    return cfg
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 317 |
+
# Checkpoint save/load -- resume-capable state
|
| 318 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 319 |
+
|
| 320 |
+
def save_checkpoint(
    ckpt_path: str,
    epoch: int,
    model: nn.Module,
    opt: torch.optim.Optimizer,
    sched: Optional[Any],
    state: Dict[str, Any],
    ablation_config: Dict[str, Any],
    run_config: Any,
) -> None:
    """Save complete resumable state.

    Includes everything needed to continue training:
      - model weights
      - optimizer state (momentum buffers, LBFGS history, etc.)
      - LR scheduler state (None when ``sched`` is None)
      - EMA / soft-hand state (cv_ema, recon_ema_obs, last_prox, last_cv)
      - RNG state (torch, cuda, numpy) for reproducibility
      - cv_trajectory list and global_batch counter taken from ``state``
      - ablation_config and run_config (so we can verify match on resume)
      - params_finite flag: True if all model parameters are finite

    ``run_config`` must be a dataclass instance; only its fields whose
    values are int, float, str, bool or list are stored.

    The checkpoint is written to ``<ckpt_path>.tmp`` first and then moved
    into place with ``os.replace``, so an interrupted save never leaves a
    truncated file where the previous good checkpoint was.
    """
    with torch.no_grad():
        params_finite = all(torch.isfinite(p).all().item()
                            for p in model.parameters())

    ckpt = {
        'epoch': epoch,
        'model_state': model.state_dict(),
        'optimizer_state': opt.state_dict(),
        'scheduler_state': sched.state_dict() if sched is not None else None,
        'ema_state': {
            'cv_ema': state.get('cv_ema'),
            'recon_ema_obs': state.get('recon_ema_obs'),
            'last_prox': state.get('last_prox', 1.0),
            'last_cv': state.get('last_cv', 0.0),
        },
        'cv_trajectory': state.get('cv_trajectory', []),
        'global_batch': state.get('global_batch', 0),
        'rng_state': {
            'torch': torch.get_rng_state(),
            'numpy': np.random.get_state(),
            'cuda': (torch.cuda.get_rng_state_all()
                     if torch.cuda.is_available() else None),
        },
        'ablation_config': ablation_config,
        'run_config': {k: v for k, v in asdict(run_config).items()
                       if isinstance(v, (int, float, str, bool, list))},
        'params_finite': params_finite,
    }

    # Write-then-rename keeps the existing checkpoint intact if the process
    # dies (or the disk fills up) in the middle of torch.save.
    tmp_path = f"{ckpt_path}.tmp"
    try:
        torch.save(ckpt, tmp_path)
        os.replace(tmp_path, ckpt_path)
    finally:
        # Only present here if the save or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def load_checkpoint(
    ckpt_path: str,
    model: nn.Module,
    opt: torch.optim.Optimizer,
    sched: Optional[Any] = None,
    restore_rng: bool = True,
) -> Dict[str, Any]:
    """Load checkpoint into existing model/opt/sched and return state.

    ``model`` and ``opt`` are always restored. ``sched`` is restored only
    when it is given and the checkpoint holds a scheduler state. With
    ``restore_rng`` the torch and numpy RNG states are reinstated, and the
    CUDA states too when CUDA is available and they were saved.

    Returns dict with keys: epoch, ema_state, cv_trajectory, global_batch,
    params_finite, ablation_config, run_config.
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoint files from a trusted source.
    payload = torch.load(ckpt_path, map_location='cpu', weights_only=False)

    model.load_state_dict(payload['model_state'])
    opt.load_state_dict(payload['optimizer_state'])
    if sched is not None:
        saved_sched = payload.get('scheduler_state')
        if saved_sched is not None:
            sched.load_state_dict(saved_sched)

    if restore_rng:
        rng = payload['rng_state']
        torch.set_rng_state(rng['torch'])
        np.random.set_state(rng['numpy'])
        if torch.cuda.is_available():
            cuda_states = rng.get('cuda')
            if cuda_states is not None:
                torch.cuda.set_rng_state_all(cuda_states)

    return dict(
        epoch=payload['epoch'],
        ema_state=payload['ema_state'],
        cv_trajectory=payload.get('cv_trajectory', []),
        global_batch=payload.get('global_batch', 0),
        params_finite=payload.get('params_finite', True),
        ablation_config=payload['ablation_config'],
        run_config=payload['run_config'],
    )
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 411 |
+
# Main ablation run function
|
| 412 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 413 |
+
|
| 414 |
+
def run_ablation_config(
|
| 415 |
+
ablation_config: Dict[str, Any],
|
| 416 |
+
output_dir: str,
|
| 417 |
+
batch_limit: Optional[int] = 1000,
|
| 418 |
+
num_epochs: int = 1,
|
| 419 |
+
resume_from: Optional[str] = None,
|
| 420 |
+
) -> Dict[str, Any]:
|
| 421 |
+
"""Run one ablation config and return a result dict."""
|
| 422 |
+
cfg = build_run_config(ablation_config)
|
| 423 |
+
overrides = ablation_config['overrides']
|
| 424 |
+
|
| 425 |
+
torch.manual_seed(cfg.seed)
|
| 426 |
+
np.random.seed(cfg.seed)
|
| 427 |
+
torch.set_float32_matmul_precision('high')
|
| 428 |
+
|
| 429 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 430 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 431 |
+
|
| 432 |
+
# βββ TensorBoard writer βββββββββββββββββββββββββββββββββββββββ
|
| 433 |
+
tb_dir = os.path.join(output_dir, "tensorboard")
|
| 434 |
+
os.makedirs(tb_dir, exist_ok=True)
|
| 435 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 436 |
+
writer = SummaryWriter(tb_dir)
|
| 437 |
+
|
| 438 |
+
# βββ Model with ablation hooks ββββββββββββββββββββββββββββββββ
|
| 439 |
+
model = PatchSVAE_F_Ablation(
|
| 440 |
+
matrix_v=cfg.matrix_v, D=cfg.D, patch_size=cfg.patch_size,
|
| 441 |
+
hidden=cfg.hidden, depth=cfg.depth,
|
| 442 |
+
n_cross_layers=cfg.n_cross_layers, n_heads=cfg.n_heads,
|
| 443 |
+
max_alpha=overrides.get('max_alpha', cfg.max_alpha),
|
| 444 |
+
alpha_init=cfg.alpha_init,
|
| 445 |
+
# ablation hooks
|
| 446 |
+
activation=overrides.get('activation', 'gelu'),
|
| 447 |
+
row_norm=overrides.get('row_norm', 'sphere'),
|
| 448 |
+
svd_mode=overrides.get('svd', 'fp64'),
|
| 449 |
+
linear_readout=overrides.get('linear_readout', False),
|
| 450 |
+
match_params=overrides.get('match_params', True),
|
| 451 |
+
init_scheme=overrides.get('init', 'orthogonal'),
|
| 452 |
+
).to(device)
|
| 453 |
+
|
| 454 |
+
n_params = sum(p.numel() for p in model.parameters())
|
| 455 |
+
|
| 456 |
+
# βββ Data βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 457 |
+
train_ds = OmegaNoiseDataset(
|
| 458 |
+
size=cfg.train_size, img_size=cfg.img_size,
|
| 459 |
+
allowed_types=cfg.allowed_types)
|
| 460 |
+
val_ds = OmegaNoiseDataset(
|
| 461 |
+
size=cfg.val_size, img_size=cfg.img_size,
|
| 462 |
+
allowed_types=cfg.allowed_types)
|
| 463 |
+
train_loader = torch.utils.data.DataLoader(
|
| 464 |
+
train_ds, batch_size=cfg.batch_size, shuffle=True,
|
| 465 |
+
num_workers=cfg.num_workers, pin_memory=True, drop_last=True,
|
| 466 |
+
persistent_workers=cfg.num_workers > 0)
|
| 467 |
+
val_loader = torch.utils.data.DataLoader(
|
| 468 |
+
val_ds, batch_size=cfg.batch_size, shuffle=False,
|
| 469 |
+
num_workers=cfg.num_workers, pin_memory=True,
|
| 470 |
+
persistent_workers=cfg.num_workers > 0)
|
| 471 |
+
|
| 472 |
+
# βββ Optimizer + scheduler ββββββββββββββββββββββββββββββββββββ
|
| 473 |
+
effective_steps = batch_limit if batch_limit else (cfg.train_size // cfg.batch_size)
|
| 474 |
+
opt = build_optimizer(model, overrides, cfg.lr)
|
| 475 |
+
sched = build_scheduler(opt, overrides, total_steps=effective_steps)
|
| 476 |
+
grad_clip = overrides.get('grad_clip', None)
|
| 477 |
+
|
| 478 |
+
# βββ Soft-hand variants (Group E) βββββββββββββββββββββββββββββ
|
| 479 |
+
use_soft_hand = overrides.get('soft_hand', True)
|
| 480 |
+
cv_penalty = overrides.get('cv_penalty', 0.0)
|
| 481 |
+
hard_cv_target = overrides.get('hard_cv_target', None)
|
| 482 |
+
cv_measurement_only = overrides.get('cv_measurement_only', False)
|
| 483 |
+
boost_factor = cfg.boost if use_soft_hand else 0.0
|
| 484 |
+
|
| 485 |
+
# βββ Training loop ββββββββββββββββββββββββββββββββββββββββββββ
|
| 486 |
+
start_time = time.time()
|
| 487 |
+
model.train()
|
| 488 |
+
|
| 489 |
+
# State initialization (may be overwritten by resume)
|
| 490 |
+
last_cv = 0.0
|
| 491 |
+
cv_ema = None
|
| 492 |
+
recon_ema_obs = None
|
| 493 |
+
last_prox = 1.0
|
| 494 |
+
cv_trajectory = []
|
| 495 |
+
train_loss_trajectory = [] # per-step recon MSE, independent of CV measurement
|
| 496 |
+
global_batch = 0
|
| 497 |
+
start_epoch = 0
|
| 498 |
+
|
| 499 |
+
# βββ Resume from checkpoint if provided βββββββββββββββββββββββ
|
| 500 |
+
if resume_from is not None:
|
| 501 |
+
resumed = load_checkpoint(resume_from, model, opt, sched, restore_rng=True)
|
| 502 |
+
start_epoch = resumed['epoch'] # next epoch to run
|
| 503 |
+
last_cv = resumed['ema_state'].get('last_cv', 0.0)
|
| 504 |
+
cv_ema = resumed['ema_state'].get('cv_ema')
|
| 505 |
+
recon_ema_obs = resumed['ema_state'].get('recon_ema_obs')
|
| 506 |
+
last_prox = resumed['ema_state'].get('last_prox', 1.0)
|
| 507 |
+
cv_trajectory = resumed.get('cv_trajectory', [])
|
| 508 |
+
global_batch = resumed.get('global_batch', 0)
|
| 509 |
+
print(f" Resumed from epoch {start_epoch}, global_batch {global_batch}")
|
| 510 |
+
|
| 511 |
+
# βββ Per-epoch tracking βββββββββββββββββββββββββββββββββββββββ
|
| 512 |
+
per_epoch_metrics = []
|
| 513 |
+
|
| 514 |
+
for epoch in range(start_epoch, start_epoch + num_epochs):
|
| 515 |
+
model.train()
|
| 516 |
+
epoch_start = time.time()
|
| 517 |
+
epoch_batch_target = batch_limit * (epoch + 1) if batch_limit else None
|
| 518 |
+
|
| 519 |
+
for images, _ in train_loader:
|
| 520 |
+
if epoch_batch_target is not None and global_batch >= epoch_batch_target:
|
| 521 |
+
break
|
| 522 |
+
|
| 523 |
+
images = images.to(device, non_blocking=True)
|
| 524 |
+
opt.zero_grad()
|
| 525 |
+
|
| 526 |
+
if isinstance(opt, torch.optim.LBFGS):
|
| 527 |
+
# βββ LBFGS path ββββββββββββββββββββββββββββββββββββββ
|
| 528 |
+
# Closure builds the SAME loss as the Adam path: pure MSE,
|
| 529 |
+
# plus soft-hand boost, plus optional hard_cv_target penalty.
|
| 530 |
+
# Closure may be called multiple times per outer step (line
|
| 531 |
+
# search); soft-hand uses last_prox from the previous outer
|
| 532 |
+
# batch's CV measurement, which is constant across inner calls.
|
| 533 |
+
#
|
| 534 |
+
# CRITICAL: NO gradient clipping inside the closure.
|
| 535 |
+
# LBFGS's Hessian approximation uses (s_k, y_k) = (Ξparam,
|
| 536 |
+
# Ξgrad) pairs across steps. Clipping grad norm bounds y_k
|
| 537 |
+
# artificially while s_k stays large, causing H β y/s to be
|
| 538 |
+
# underestimated β Hβ»ΒΉ becomes huge β step size explodes.
|
| 539 |
+
# Step safety for LBFGS comes from strong_wolfe line search,
|
| 540 |
+
# not gradient clipping. (Original bug fix in 000079 moved
|
| 541 |
+
# clipping INTO the closure; that fix itself was the bug β
|
| 542 |
+
# verified 2026-04-24 via Q-sweep Rank-1 G-MSE=4.5e26.)
|
| 543 |
+
def closure():
    """Recompute the loss and its gradients for one LBFGS evaluation.

    Clears existing gradients, runs ``model`` on the current ``images``
    batch, and forms the loss from the MSE between the ``'recon'`` output
    and the input. The MSE is scaled by ``1 + boost_factor * last_prox``
    when soft-hand is enabled and the run is not CV-measurement-only, and
    a ``cv_penalty``-weighted squared distance between ``cv_ema`` and
    ``hard_cv_target`` is added when all three are usable.

    Returns:
        The loss value after ``backward()`` has been called on it.
    """
    opt.zero_grad()
    reconstruction = model(images)['recon']
    mse = F.mse_loss(reconstruction, images)

    # Soft-hand boost: scale the reconstruction term by a factor derived
    # from last_prox, which is read from the enclosing scope at call time.
    soft_hand_active = use_soft_hand and not cv_measurement_only
    total = (1.0 + boost_factor * last_prox) * mse if soft_hand_active else mse

    # Optional hard CV-target penalty.
    # NOTE(review): cv_ema is updated under torch.no_grad() in the
    # measurement block, so this term looks like a gradient-free offset
    # on the loss value — confirm that is intended.
    penalise_cv = (
        hard_cv_target is not None
        and cv_ema is not None
        and cv_penalty > 0
    )
    if penalise_cv:
        cv_gap_sq = (cv_ema - hard_cv_target) ** 2
        total = total + cv_penalty * cv_gap_sq

    total.backward()
    # Deliberately no gradient clipping in this closure: gradients are
    # handed to LBFGS unmodified.
    return total
|
| 561 |
+
|
| 562 |
+
opt.step(closure)
|
| 563 |
+
# Post-step forward to get the settled state for measurement
|
| 564 |
+
with torch.no_grad():
|
| 565 |
+
out = model(images)
|
| 566 |
+
recon_val = F.mse_loss(out['recon'], images).item()
|
| 567 |
+
else:
|
| 568 |
+
# ─── Adam / SGD / AdamW path ─────────────────────────
|
| 569 |
+
out = model(images)
|
| 570 |
+
recon_loss = F.mse_loss(out['recon'], images)
|
| 571 |
+
recon_val = recon_loss.item()
|
| 572 |
+
|
| 573 |
+
# Build loss with E-group ablations
|
| 574 |
+
if use_soft_hand and not cv_measurement_only:
|
| 575 |
+
recon_w = 1.0 + boost_factor * last_prox
|
| 576 |
+
loss = recon_w * recon_loss
|
| 577 |
+
else:
|
| 578 |
+
loss = recon_loss
|
| 579 |
+
|
| 580 |
+
if hard_cv_target is not None and cv_ema is not None and cv_penalty > 0:
|
| 581 |
+
cv_loss_val = (cv_ema - hard_cv_target) ** 2
|
| 582 |
+
loss = loss + cv_penalty * cv_loss_val
|
| 583 |
+
|
| 584 |
+
loss.backward()
|
| 585 |
+
torch.nn.utils.clip_grad_norm_(
|
| 586 |
+
model.cross_attn.parameters(), max_norm=cfg.cross_attn_clip)
|
| 587 |
+
if grad_clip is not None:
|
| 588 |
+
torch.nn.utils.clip_grad_norm_(
|
| 589 |
+
model.parameters(), max_norm=grad_clip)
|
| 590 |
+
opt.step()
|
| 591 |
+
|
| 592 |
+
# ─── Shared post-step measurement block ───────────────────
|
| 593 |
+
# Runs for BOTH LBFGS and non-LBFGS paths. Updates EMA, measures
|
| 594 |
+
# CV at intervals, computes prox for next batch's soft-hand.
|
| 595 |
+
with torch.no_grad():
|
| 596 |
+
if recon_ema_obs is None:
|
| 597 |
+
recon_ema_obs = recon_val
|
| 598 |
+
else:
|
| 599 |
+
recon_ema_obs = 0.99 * recon_ema_obs + 0.01 * recon_val
|
| 600 |
+
|
| 601 |
+
# Per-step loss trajectory — independent of CV measurement
|
| 602 |
+
# so small-V configs (where cv_of returns 0) still have a
|
| 603 |
+
# visible training curve.
|
| 604 |
+
train_loss_trajectory.append({
|
| 605 |
+
'batch': global_batch,
|
| 606 |
+
'recon': recon_val,
|
| 607 |
+
})
|
| 608 |
+
|
| 609 |
+
# TB: per-batch scalars (cheap)
|
| 610 |
+
writer.add_scalar('train/recon', recon_val, global_batch)
|
| 611 |
+
writer.add_scalar('train/recon_ema', recon_ema_obs, global_batch)
|
| 612 |
+
writer.add_scalar('train/lr', opt.param_groups[0]['lr'], global_batch)
|
| 613 |
+
|
| 614 |
+
if global_batch % cfg.cv_measure_every == 0:
|
| 615 |
+
current_cv = cv_of(out['svd']['M'][0, 0])
|
| 616 |
+
if current_cv > 0:
|
| 617 |
+
last_cv = current_cv
|
| 618 |
+
if cv_ema is None:
|
| 619 |
+
cv_ema = current_cv
|
| 620 |
+
else:
|
| 621 |
+
cv_ema = ((1.0 - cfg.cv_ema_alpha) * cv_ema
|
| 622 |
+
+ cfg.cv_ema_alpha * current_cv)
|
| 623 |
+
cv_trajectory.append({
|
| 624 |
+
'batch': global_batch,
|
| 625 |
+
'cv': current_cv,
|
| 626 |
+
'cv_ema': cv_ema,
|
| 627 |
+
'recon': recon_val,
|
| 628 |
+
})
|
| 629 |
+
# TB: geometric measurements (when CV is measured)
|
| 630 |
+
writer.add_scalar('geo/cv', current_cv, global_batch)
|
| 631 |
+
writer.add_scalar('geo/cv_ema', cv_ema, global_batch)
|
| 632 |
+
# S spectrum diagnostic
|
| 633 |
+
S_now = out['svd']['S'][0, 0]
|
| 634 |
+
writer.add_scalar('geo/S0', S_now[0].item(), global_batch)
|
| 635 |
+
writer.add_scalar('geo/SD', S_now[-1].item(), global_batch)
|
| 636 |
+
writer.add_scalar('geo/ratio',
|
| 637 |
+
(S_now[0] / (S_now[-1] + 1e-8)).item(),
|
| 638 |
+
global_batch)
|
| 639 |
+
if cv_ema is not None and cv_ema > 1e-6:
|
| 640 |
+
sigma_adapt = max(cfg.cv_sigma_scale * cv_ema, 1e-6)
|
| 641 |
+
delta = last_cv - cv_ema
|
| 642 |
+
last_prox = math.exp(-(delta ** 2) / (2 * sigma_adapt ** 2))
|
| 643 |
+
writer.add_scalar('stab/prox', last_prox, global_batch)
|
| 644 |
+
|
| 645 |
+
if sched is not None:
|
| 646 |
+
sched.step()
|
| 647 |
+
|
| 648 |
+
global_batch += 1
|
| 649 |
+
|
| 650 |
+
# ─── Final evaluation ─────────────────────────────────────────
|
| 651 |
+
# PROPER TEST STAGE: evaluate on all 16 noise types separately.
|
| 652 |
+
# Previous behavior (one batch from val_loader using training
|
| 653 |
+
# allowed_types) was a validation metric, not a test metric —
|
| 654 |
+
# gaussian-only-trained batteries never saw pink/brown/poisson
|
| 655 |
+
# at eval time, so the old test_mse_final was invalid for
|
| 656 |
+
# detecting which noises the model suffered at.
|
| 657 |
+
model.eval()
|
| 658 |
+
|
| 659 |
+
test_noise_types = overrides.get('test_noise_types', list(range(16)))
|
| 660 |
+
test_samples_per_noise = overrides.get('test_samples_per_noise', 256)
|
| 661 |
+
test_batch_size = overrides.get('test_batch_size', 64)
|
| 662 |
+
|
| 663 |
+
test_mse_per_noise = {} # noise_type (int) β mean MSE
|
| 664 |
+
|
| 665 |
+
# First: run one canonical batch for geometric measurements.
|
| 666 |
+
# Use gaussian (noise_type=0) so measurements are comparable
|
| 667 |
+
# across all configs regardless of training distribution.
|
| 668 |
+
with torch.no_grad():
|
| 669 |
+
geom_ds = OmegaNoiseDataset(
|
| 670 |
+
size=test_batch_size, img_size=cfg.img_size,
|
| 671 |
+
allowed_types=[0])
|
| 672 |
+
geom_loader = torch.utils.data.DataLoader(
|
| 673 |
+
geom_ds, batch_size=test_batch_size, shuffle=False,
|
| 674 |
+
num_workers=0, pin_memory=True, drop_last=True)
|
| 675 |
+
geom_imgs, _ = next(iter(geom_loader))
|
| 676 |
+
geom_imgs = geom_imgs.to(device)
|
| 677 |
+
t_out = model(geom_imgs)
|
| 678 |
+
|
| 679 |
+
final_cv = cv_of(t_out['svd']['M'][0, 0], n_samples=500)
|
| 680 |
+
S_final = t_out['svd']['S'].mean(dim=(0, 1))
|
| 681 |
+
S0 = S_final[0].item()
|
| 682 |
+
SD = S_final[-1].item()
|
| 683 |
+
ratio = S0 / (SD + 1e-8)
|
| 684 |
+
erank = PatchSVAE_F.effective_rank(
|
| 685 |
+
t_out['svd']['S'].reshape(-1, cfg.D)).mean().item()
|
| 686 |
+
observed_cv_precise = cv_of(
|
| 687 |
+
t_out['svd']['M'][0, 0], n_samples=2000)
|
| 688 |
+
|
| 689 |
+
# Second: per-noise test MSE. Separate dataset per noise type so
|
| 690 |
+
# each is pure, test_samples_per_noise samples each.
|
| 691 |
+
with torch.no_grad():
|
| 692 |
+
for nt in test_noise_types:
|
| 693 |
+
nt_ds = OmegaNoiseDataset(
|
| 694 |
+
size=test_samples_per_noise,
|
| 695 |
+
img_size=cfg.img_size,
|
| 696 |
+
allowed_types=[nt])
|
| 697 |
+
nt_loader = torch.utils.data.DataLoader(
|
| 698 |
+
nt_ds, batch_size=test_batch_size, shuffle=False,
|
| 699 |
+
num_workers=0, pin_memory=True, drop_last=False)
|
| 700 |
+
mse_chunks = []
|
| 701 |
+
for imgs, _ in nt_loader:
|
| 702 |
+
imgs = imgs.to(device)
|
| 703 |
+
out = model(imgs)
|
| 704 |
+
mse = F.mse_loss(out['recon'], imgs,
|
| 705 |
+
reduction='none').mean(dim=(1, 2, 3))
|
| 706 |
+
mse_chunks.append(mse)
|
| 707 |
+
test_mse_per_noise[nt] = torch.cat(mse_chunks).mean().item()
|
| 708 |
+
|
| 709 |
+
# Backward-compat aggregate — mean across tested noises
|
| 710 |
+
test_mse_final = sum(test_mse_per_noise.values()) / max(
|
| 711 |
+
1, len(test_mse_per_noise))
|
| 712 |
+
|
| 713 |
+
uniform_cv = uniform_sphere_cv_prediction(
|
| 714 |
+
cfg.D, V=cfg.matrix_v,
|
| 715 |
+
device='cuda' if torch.cuda.is_available() else 'cpu')
|
| 716 |
+
|
| 717 |
+
# Band classification
|
| 718 |
+
classify_on = cv_ema if cv_ema is not None else final_cv
|
| 719 |
+
predicted_band = band_classifier(classify_on)
|
| 720 |
+
expected_band = ablation_config['band']
|
| 721 |
+
|
| 722 |
+
wallclock = time.time() - start_time
|
| 723 |
+
|
| 724 |
+
# Final TB summary scalars for this epoch (also written to epoch/* below)
|
| 725 |
+
writer.add_scalar('summary/test_mse_final', test_mse_final, global_batch)
|
| 726 |
+
writer.add_scalar('summary/cv_ema_final',
|
| 727 |
+
cv_ema if cv_ema is not None else 0.0, global_batch)
|
| 728 |
+
writer.add_scalar('summary/observed_sphere_cv', observed_cv_precise, global_batch)
|
| 729 |
+
writer.add_scalar('summary/uniform_sphere_cv_pred', uniform_cv, global_batch)
|
| 730 |
+
writer.add_scalar('summary/band_deviation',
|
| 731 |
+
observed_cv_precise - uniform_cv, global_batch)
|
| 732 |
+
writer.add_scalar('summary/erank', erank, global_batch)
|
| 733 |
+
|
| 734 |
+
# ─── Per-epoch checkpoint save ───────────────────────────────
|
| 735 |
+
ckpt_path = os.path.join(output_dir, f'epoch_{epoch+1}_checkpoint.pt')
|
| 736 |
+
save_checkpoint(
|
| 737 |
+
ckpt_path=ckpt_path,
|
| 738 |
+
epoch=epoch + 1, # next epoch to run on resume
|
| 739 |
+
model=model,
|
| 740 |
+
opt=opt,
|
| 741 |
+
sched=sched,
|
| 742 |
+
state={
|
| 743 |
+
'cv_ema': cv_ema,
|
| 744 |
+
'recon_ema_obs': recon_ema_obs,
|
| 745 |
+
'last_prox': last_prox,
|
| 746 |
+
'last_cv': last_cv,
|
| 747 |
+
'cv_trajectory': cv_trajectory,
|
| 748 |
+
'global_batch': global_batch,
|
| 749 |
+
},
|
| 750 |
+
ablation_config=ablation_config,
|
| 751 |
+
run_config=cfg,
|
| 752 |
+
)
|
| 753 |
+
|
| 754 |
+
# ─── Record per-epoch metrics ────────────────────────────────
|
| 755 |
+
with torch.no_grad():
|
| 756 |
+
params_finite = all(torch.isfinite(p).all().item()
|
| 757 |
+
for p in model.parameters())
|
| 758 |
+
per_epoch_metrics.append({
|
| 759 |
+
'epoch': epoch + 1,
|
| 760 |
+
'test_mse': test_mse_final,
|
| 761 |
+
'test_mse_per_noise': {int(k): float(v)
|
| 762 |
+
for k, v in test_mse_per_noise.items()},
|
| 763 |
+
'cv_ema': cv_ema if cv_ema is not None else 0.0,
|
| 764 |
+
'observed_sphere_cv': observed_cv_precise,
|
| 765 |
+
'band_deviation': observed_cv_precise - uniform_cv,
|
| 766 |
+
'erank': erank,
|
| 767 |
+
'params_finite': params_finite,
|
| 768 |
+
'wallclock_seconds': time.time() - epoch_start,
|
| 769 |
+
'checkpoint_path': ckpt_path,
|
| 770 |
+
})
|
| 771 |
+
|
| 772 |
+
# TB: per-epoch summary
|
| 773 |
+
writer.add_scalar('epoch/test_mse', test_mse_final, epoch + 1)
|
| 774 |
+
writer.add_scalar('epoch/cv_ema', cv_ema if cv_ema is not None else 0.0, epoch + 1)
|
| 775 |
+
writer.add_scalar('epoch/observed_sphere_cv', observed_cv_precise, epoch + 1)
|
| 776 |
+
|
| 777 |
+
# ─── End of epoch loop ────────────────────────────────────────────
|
| 778 |
+
writer.flush()
|
| 779 |
+
writer.close()
|
| 780 |
+
|
| 781 |
+
# Compute params_finite once more at the very end
|
| 782 |
+
with torch.no_grad():
|
| 783 |
+
final_params_finite = all(torch.isfinite(p).all().item()
|
| 784 |
+
for p in model.parameters())
|
| 785 |
+
|
| 786 |
+
wallclock = time.time() - start_time
|
| 787 |
+
|
| 788 |
+
return {
|
| 789 |
+
'config': ablation_config,
|
| 790 |
+
'run_config': {k: v for k, v in asdict(cfg).items()
|
| 791 |
+
if isinstance(v, (int, float, str, bool, list))},
|
| 792 |
+
# Classification
|
| 793 |
+
'cv_ema_final': cv_ema if cv_ema is not None else 0.0,
|
| 794 |
+
'cv_last': last_cv,
|
| 795 |
+
'predicted_band': predicted_band,
|
| 796 |
+
'expected_band': expected_band,
|
| 797 |
+
'band_match': predicted_band == expected_band,
|
| 798 |
+
# Reconstruction
|
| 799 |
+
'test_mse': test_mse_final,
|
| 800 |
+
'test_mse_per_noise': {int(k): float(v)
|
| 801 |
+
for k, v in test_mse_per_noise.items()},
|
| 802 |
+
'recon_ema': recon_ema_obs if recon_ema_obs is not None else 0.0,
|
| 803 |
+
# Geometry
|
| 804 |
+
'S0': S0,
|
| 805 |
+
'SD': SD,
|
| 806 |
+
'ratio': ratio,
|
| 807 |
+
'erank': erank,
|
| 808 |
+
# Group N
|
| 809 |
+
'observed_sphere_cv': observed_cv_precise,
|
| 810 |
+
'uniform_sphere_cv_prediction': uniform_cv,
|
| 811 |
+
'band_deviation': observed_cv_precise - uniform_cv,
|
| 812 |
+
# Finite-params flag (False means training went to NaN/Inf)
|
| 813 |
+
'params_finite': final_params_finite,
|
| 814 |
+
# Multi-epoch tracking
|
| 815 |
+
'num_epochs_run': num_epochs,
|
| 816 |
+
'start_epoch': start_epoch,
|
| 817 |
+
'per_epoch_metrics': per_epoch_metrics,
|
| 818 |
+
# Bookkeeping
|
| 819 |
+
'params_count': n_params,
|
| 820 |
+
'wallclock_seconds': wallclock,
|
| 821 |
+
'batches_completed': global_batch,
|
| 822 |
+
'batch_limit': batch_limit,
|
| 823 |
+
'cv_trajectory': cv_trajectory,
|
| 824 |
+
'train_loss_trajectory': train_loss_trajectory,
|
| 825 |
+
}
|