| """ |
| ablation_trainer.py |
| =================== |
| Ablation trainer adapter for PatchSVAE_F (Johanna F-class). |
| |
| Takes an ablation config dict, builds a proper RunConfig with overrides, |
| instantiates a PatchSVAE_F_Ablation subclass with the needed hooks, |
| runs the real training loop with batch-limit early stop, measures CV |
| throughout, computes Group N uniformity diagnostic at the end, and |
| returns a result dict ready for upload. |
| |
| Imports from johanna_F_trainer.py. Drop this in alongside it in Colab. |
| |
| Ablation hooks implemented: |
| Group A (seeds): pure seed variation via RunConfig.seed |
| Group B (noise types): overrides['noise_types'] -> RunConfig.allowed_types |
| Group C (optimizer): adam/sgd/adamw/lbfgs via build_optimizer |
| Group D (scheduler): cosine/constant/linear/cosine_warm_restarts/one_cycle |
| Group E (soft-hand): use_soft_hand + boost + cv_penalty + hard_cv_target |
| Group F (activation): enc_in activation function swap |
| Group G (row_norm): sphere/none/layer_norm/scale_only |
| Group H (SVD): fp64/fp32/batch_shared/linear_readout |
| Group I (cross-attn): n_cross_layers + max_alpha |
| Group J (capacity): V and hidden overrides (within LOW band) |
| Group K (batch size): batch_size override |
| Group L (init): orthogonal/kaiming_normal/xavier_uniform/normal_0_02 |
| Group M (brute SGD): optimizer + lr + momentum + grad_clip |
| """ |
|
|
| import os |
| import math |
| import time |
| from dataclasses import asdict, replace |
| from typing import Dict, Any, Optional, List |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
# Group F: activation name -> elementwise function applied right after enc_in.
ACTIVATIONS = dict(
    gelu=F.gelu,
    relu=F.relu,
    silu=F.silu,
    tanh=torch.tanh,
    identity=lambda t: t,
)
|
|
|
|
def row_normalize(M: torch.Tensor, mode: str) -> torch.Tensor:
    """Group G: normalize the rows (last dim) of the encoded matrix ``M``.

    Modes:
        'sphere'     -- scale every row to unit L2 norm.
        'none'       -- return ``M`` unchanged (the same object).
        'layer_norm' -- per-row zero mean / unit (biased) variance, no affine.
        'scale_only' -- divide by the mean row norm taken over dim -2, so
                        relative row lengths are preserved.

    Raises:
        ValueError: if ``mode`` is not one of the names above.
    """
    if mode == 'none':
        return M
    if mode == 'sphere':
        return F.normalize(M, dim=-1)
    if mode == 'layer_norm':
        centered = M - M.mean(dim=-1, keepdim=True)
        variance = M.var(dim=-1, keepdim=True, unbiased=False)
        return centered / torch.sqrt(variance + 1e-8)
    if mode == 'scale_only':
        mean_row_norm = M.norm(dim=-1, keepdim=True).mean(dim=-2, keepdim=True)
        return M / (mean_row_norm + 1e-8)
    raise ValueError(f"unknown row_norm mode: {mode}")
|
|
|
|
def init_weights(module: nn.Module, scheme: str) -> None:
    """Group L: re-initialize every ``nn.Linear`` inside ``module``.

    Each Linear weight is drawn with ``scheme`` and its bias (if any) is set
    to zero. Modules are visited in ``module.modules()`` order, so RNG
    consumption is deterministic for a given seed.

    Supported schemes:
        'orthogonal', 'kaiming_normal', 'xavier_uniform',
        'normal_0_02' (normal with mean 0.0, std 0.02).

    Raises:
        ValueError: if ``scheme`` is not supported. An unrecognized name used
            to be a silent no-op, which made a mistyped ablation run with the
            default initialization.
    """
    initializers = {
        'orthogonal': nn.init.orthogonal_,
        'kaiming_normal': nn.init.kaiming_normal_,
        'xavier_uniform': nn.init.xavier_uniform_,
        'normal_0_02': lambda w: nn.init.normal_(w, mean=0.0, std=0.02),
    }
    try:
        init_fn = initializers[scheme]
    except KeyError:
        raise ValueError(f"unknown init scheme: {scheme}") from None

    for m in module.modules():
        if isinstance(m, nn.Linear):
            init_fn(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
|
|
|
|
class PatchSVAE_F_Ablation(PatchSVAE_F):
    """PatchSVAE_F with ablation hooks for the F/G/H/L groups.

    Extra keyword arguments (everything else is forwarded to PatchSVAE_F):
        activation: key into ``ACTIVATIONS``, applied after ``enc_in``
            (Group F).
        row_norm: ``row_normalize`` mode for the encoded matrix (Group G).
        svd_mode: 'fp64' (default, via ``_svd_fp64``), 'fp32' (Gram-matrix
            eigendecomposition in M's dtype), 'batch_shared' (one SVD per
            image over all of its patch rows), or 'linear_readout' (same as
            ``linear_readout=True``) (Group H).
        linear_readout: replace the decomposition by a readout that is a
            learned square Linear map (``match_params=True``) or the
            identity (Group H).
        match_params: see ``linear_readout``.
        init_scheme: ``init_weights`` scheme for all Linear layers (Group L).
            'orthogonal' leaves PatchSVAE_F's own initialization untouched.

    At default settings (gelu / sphere / fp64 / orthogonal / no linear
    readout) this is intended to behave identically to PatchSVAE_F.

    Raises:
        KeyError: unknown ``activation``.
        ValueError: unknown ``row_norm`` or ``svd_mode``. These are checked at
            construction, so a mistyped config cannot silently run the fp64
            baseline.
    """

    _SVD_MODES = ('fp64', 'fp32', 'batch_shared', 'linear_readout')
    _ROW_NORM_MODES = ('sphere', 'none', 'layer_norm', 'scale_only')

    def __init__(self, *args,
                 activation: str = 'gelu',
                 row_norm: str = 'sphere',
                 svd_mode: str = 'fp64',
                 linear_readout: bool = False,
                 match_params: bool = True,
                 init_scheme: str = 'orthogonal',
                 **kwargs):
        # Validate the ablation switches before building any parameters.
        activation_fn = ACTIVATIONS[activation]
        if row_norm not in self._ROW_NORM_MODES:
            raise ValueError(f"unknown row_norm mode: {row_norm}")
        if svd_mode not in self._SVD_MODES:
            raise ValueError(f"unknown svd_mode: {svd_mode}")

        super().__init__(*args, **kwargs)
        self.activation_fn = activation_fn
        self.row_norm_mode = row_norm
        self.svd_mode = svd_mode
        # svd_mode='linear_readout' is accepted as a synonym for the flag.
        self.linear_readout = bool(linear_readout) or svd_mode == 'linear_readout'

        if self.linear_readout:
            # Readout over the flattened (matrix_v * D) patch matrix.
            readout_dim = self.matrix_v * self.D
            self.readout = (nn.Linear(readout_dim, readout_dim)
                            if match_params else nn.Identity())

        if init_scheme != 'orthogonal':
            init_weights(self, init_scheme)
            # Keep enc_out orthogonal whatever scheme the other layers use.
            # NOTE(review): confirm enc_out should NOT also be re-drawn when
            # init_scheme == 'orthogonal' (that would re-sample weights
            # PatchSVAE_F already initialized and break RNG parity).
            nn.init.orthogonal_(self.enc_out.weight)

    def encode_patches(self, patches):
        """Encode patches and factor each patch matrix.

        ``patches`` is a 3-D tensor (B, N, patch_dim). Each patch is mapped to
        a (matrix_v, D) matrix ``M``, which is row-normalized and then
        decomposed according to ``linear_readout`` / ``svd_mode``.

        Returns a dict with:
            U      (B, N, matrix_v, D)
            S_orig (B, N, D)  spectrum before the cross-attention layers
            S      (B, N, D)  spectrum after the cross-attention layers
            Vt     (B, N, D, D)
            M      (B, N, matrix_v, D)  the row-normalized matrix
        """
        B, N, _ = patches.shape
        flat = patches.reshape(B * N, -1)

        # Residual MLP encoder; the Group F activation follows enc_in only.
        h = self.activation_fn(self.enc_in(flat))
        for block in self.enc_blocks:
            h = h + block(h)

        M = self.enc_out(h).reshape(B * N, self.matrix_v, self.D)
        M = row_normalize(M, self.row_norm_mode)

        if self.linear_readout:
            flat_M = M.reshape(B * N, -1)
            M_hat = self.readout(flat_M).reshape(B * N, self.matrix_v, self.D)
            # Stand-in factors: U is the readout itself, S its column norms,
            # Vt the identity.
            # NOTE(review): unlike the SVD branches, U @ diag(S) @ Vt equals
            # M_hat with every column scaled by its own norm, not M_hat --
            # confirm whether U should be M_hat / S.
            U = M_hat
            S = M_hat.norm(dim=-2)
            Vt = torch.eye(self.D, device=M.device, dtype=M.dtype
                           ).unsqueeze(0).expand(B * N, -1, -1)
        elif self.svd_mode == 'fp32':
            # Eigendecomposition of the D x D Gram matrix M^T M (plus a small
            # diagonal jitter), computed in M's dtype. eigh returns ascending
            # eigenvalues, so flip everything to descending order.
            G = torch.bmm(M.transpose(1, 2), M)
            G.diagonal(dim1=-2, dim2=-1).add_(1e-6)
            eigenvalues, Vmat = torch.linalg.eigh(G)
            eigenvalues = eigenvalues.flip(-1)
            Vmat = Vmat.flip(-1)
            S = torch.sqrt(eigenvalues.clamp(min=1e-12))
            U = torch.bmm(M, Vmat) / S.unsqueeze(1).clamp(min=1e-8)
            Vt = Vmat.transpose(-2, -1).contiguous()
        elif self.svd_mode == 'batch_shared':
            # One SVD per image over all of its patch rows stacked together.
            # Every patch shares that image's S and Vt, with U = M V / S.
            M_batched = M.reshape(B, N * self.matrix_v, self.D)
            _, S_b, Vt_b = _svd_fp64(M_batched)
            S = S_b.unsqueeze(1).expand(-1, N, -1).reshape(B * N, self.D)
            Vt = Vt_b.unsqueeze(1).expand(-1, N, -1, -1).reshape(B * N, self.D, self.D)
            U = torch.bmm(M, Vt.transpose(-2, -1)) / S.unsqueeze(1).clamp(min=1e-16)
        else:
            # 'fp64' (svd_mode is validated in __init__).
            U, S, Vt = _svd_fp64(M)

        U = U.reshape(B, N, self.matrix_v, self.D)
        S = S.reshape(B, N, self.D)
        Vt = Vt.reshape(B, N, self.D, self.D)
        M = M.reshape(B, N, self.matrix_v, self.D)

        # Apply the cross-attention layers to the spectrum; S_orig keeps the
        # pre-attention values.
        S_coord = S
        for layer in self.cross_attn:
            S_coord = layer(S_coord)
        return {'U': U, 'S_orig': S, 'S': S_coord, 'Vt': Vt, 'M': M}
|
|
|
|
| |
| |
| |
|
|
def build_optimizer(model: nn.Module, overrides: Dict[str, Any],
                    base_lr: float) -> torch.optim.Optimizer:
    """Groups C, M: build the optimizer selected by ``overrides``.

    Recognized override keys:
        optimizer: 'adam' (default), 'adamw', 'sgd' or 'lbfgs'.
        lr: learning rate; defaults to ``base_lr``.
        weight_decay: weight decay (default 0.0), used by adam, adamw and
            sgd. torch.optim.LBFGS has no weight-decay option, so it is
            ignored there.
        momentum: SGD momentum (default 0.0); ignored by the others.

    L-BFGS uses fixed ``max_iter=20`` and ``history_size=10``.

    Raises:
        ValueError: unknown optimizer name.
    """
    opt_name = overrides.get('optimizer', 'adam')
    lr = overrides.get('lr', base_lr)
    wd = overrides.get('weight_decay', 0.0)
    momentum = overrides.get('momentum', 0.0)
    params = model.parameters()

    if opt_name == 'adam':
        return torch.optim.Adam(params, lr=lr, weight_decay=wd)
    if opt_name == 'adamw':
        return torch.optim.AdamW(params, lr=lr, weight_decay=wd)
    if opt_name == 'sgd':
        # Honour weight_decay here too; it used to be read but dropped.
        return torch.optim.SGD(params, lr=lr, momentum=momentum,
                               weight_decay=wd)
    if opt_name == 'lbfgs':
        return torch.optim.LBFGS(params, lr=lr, max_iter=20, history_size=10)
    raise ValueError(f"unknown optimizer: {opt_name}")
|
|
|
|
def build_scheduler(opt: torch.optim.Optimizer, overrides: Dict[str, Any],
                    total_steps: int):
    """Group D: build the LR scheduler selected by ``overrides['scheduler']``.

    Names:
        'cosine' (default): CosineAnnealingLR with T_max=total_steps.
        'constant': no scheduler; returns None.
        'linear': LinearLR from 1.0x to 0.01x over total_steps.
        'cosine_warm_restarts' (alias 'warm_restart'):
            CosineAnnealingWarmRestarts with T_0=overrides.get('T_0', 1000).
        'one_cycle': OneCycleLR peaking at the optimizer's current lr.
            Momentum / beta1 is cycled only if the optimizer has one.

    The scheduler is expected to be stepped once per training batch.

    Raises:
        ValueError: unknown scheduler name.
    """
    sched_name = overrides.get('scheduler', 'cosine')
    if sched_name == 'warm_restart':
        sched_name = 'cosine_warm_restarts'

    if sched_name == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps)
    if sched_name == 'constant':
        return None
    if sched_name == 'linear':
        return torch.optim.lr_scheduler.LinearLR(
            opt, start_factor=1.0, end_factor=0.01, total_iters=total_steps)
    if sched_name == 'cosine_warm_restarts':
        T_0 = overrides.get('T_0', 1000)
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=T_0)
    if sched_name == 'one_cycle':
        # OneCycleLR raises ValueError when asked to cycle momentum on an
        # optimizer with neither 'momentum' nor 'betas' (e.g. L-BFGS).
        can_cycle = 'momentum' in opt.defaults or 'betas' in opt.defaults
        return torch.optim.lr_scheduler.OneCycleLR(
            opt, max_lr=opt.param_groups[0]['lr'], total_steps=total_steps,
            cycle_momentum=can_cycle)
    raise ValueError(f"unknown scheduler: {sched_name}")
|
|
|
|
| |
| |
| |
|
|
# Memo for uniform_sphere_cv_prediction, keyed by (V, D, n_samples).
_UNIFORM_CV_CACHE: Dict[tuple, float] = {}


def uniform_sphere_cv_prediction(D: int, V: int = 64, n_samples: int = 2000,
                                 device: str = 'cuda') -> float:
    """Reference CV for ``V`` rows drawn uniformly on S^(D-1).

    A single fixed-seed (12345) float64 sample of V unit vectors in R^D is
    scored with ``cv_of(..., n_samples=n_samples)``. The result is memoized
    per (V, D, n_samples). The matrix is moved to ``device`` only when CUDA
    is available; otherwise it stays on the CPU.
    """
    cache_key = (V, D, n_samples)
    if cache_key in _UNIFORM_CV_CACHE:
        return _UNIFORM_CV_CACHE[cache_key]

    rng = torch.Generator(device='cpu').manual_seed(12345)
    rows = torch.randn(V, D, generator=rng, dtype=torch.float64)
    rows = rows / rows.norm(dim=-1, keepdim=True)
    if torch.cuda.is_available():
        rows = rows.to(device)

    result = cv_of(rows, n_samples=n_samples)
    _UNIFORM_CV_CACHE[cache_key] = result
    return result
|
|
|
|
| |
| |
| |
|
|
def build_run_config(ablation_config: Dict[str, Any]) -> RunConfig:
    """Assemble the RunConfig for one ablation.

    Starts from the band's architecture (``BAND_REPS[ablation_config['band']]``)
    plus fixed short-run training defaults and ``ablation_config['seed']``,
    then applies ``ablation_config['overrides']`` in two passes:

    1. Alias keys, in this order: 'noise_types' -> allowed_types,
       'V' -> matrix_v, 'n_cross' -> n_cross_layers.
    2. Keys that name a RunConfig field directly, in override order. These
       win over an alias that targets the same field.

    Any other override key is left for the caller to interpret.
    """
    band_spec = BAND_REPS[ablation_config['band']]
    overrides = ablation_config['overrides']

    cfg = RunConfig(
        matrix_v=band_spec['V'],
        D=band_spec['D'],
        patch_size=band_spec['patch_size'],
        hidden=band_spec['hidden'],
        depth=band_spec['depth'],
        n_cross_layers=band_spec['n_cross'],
        img_size=band_spec['img_size'],
        batch_size=128,
        lr=1e-4,
        epochs=1,
        weight_decay=0.0,
        use_cv_ema=True,
        cv_alignment_epochs=0,
        cv_measure_every=50,
        boost=0.5,
        allowed_types=list(range(16)),
        train_size=1_000_000,
        val_size=10_000,
        num_workers=2,
        report_every=100,
        seed=ablation_config['seed'],
        upload=False,
    )

    alias_to_field = (
        ('noise_types', 'allowed_types'),
        ('V', 'matrix_v'),
        ('n_cross', 'n_cross_layers'),
    )
    for alias, field_name in alias_to_field:
        if alias in overrides:
            cfg = replace(cfg, **{field_name: overrides[alias]})

    passthrough = frozenset((
        'batch_size', 'lr', 'weight_decay', 'boost',
        'allowed_types', 'n_cross_layers', 'max_alpha',
        'matrix_v', 'D', 'hidden', 'patch_size',
        'depth', 'n_heads', 'cv_measure_every',
    ))
    for key, value in overrides.items():
        if key in passthrough:
            cfg = replace(cfg, **{key: value})

    return cfg
|
|
|
|
| |
| |
| |
|
|
def save_checkpoint(
    ckpt_path: str,
    epoch: int,
    model: nn.Module,
    opt: torch.optim.Optimizer,
    sched: Optional[Any],
    state: Dict[str, Any],
    ablation_config: Dict[str, Any],
    run_config: Any,
) -> None:
    """Save complete resumable state to ``ckpt_path``.

    Includes everything needed to continue training:
      - model weights
      - optimizer state (momentum buffers, LBFGS history, etc.)
      - LR scheduler state (None when ``sched`` is None)
      - EMA / soft-hand state (cv_ema, recon_ema_obs, last_prox, last_cv)
      - cv_trajectory and global_batch
      - RNG state (torch, numpy, and cuda when available)
      - ablation_config, and the int/float/str/bool/list fields of the
        ``run_config`` dataclass (so a resume can be checked for a match)
      - params_finite: True if every model parameter is finite

    The file is written to ``ckpt_path + '.tmp'`` and then atomically renamed
    with ``os.replace``. An interrupted save therefore never leaves a
    truncated file under the final name. On failure the temp file is removed
    and the original exception is re-raised.
    """
    with torch.no_grad():
        params_finite = all(torch.isfinite(p).all().item()
                            for p in model.parameters())

    ckpt = {
        'epoch': epoch,
        'model_state': model.state_dict(),
        'optimizer_state': opt.state_dict(),
        'scheduler_state': sched.state_dict() if sched is not None else None,
        'ema_state': {
            'cv_ema': state.get('cv_ema'),
            'recon_ema_obs': state.get('recon_ema_obs'),
            'last_prox': state.get('last_prox', 1.0),
            'last_cv': state.get('last_cv', 0.0),
        },
        'cv_trajectory': state.get('cv_trajectory', []),
        'global_batch': state.get('global_batch', 0),
        'rng_state': {
            'torch': torch.get_rng_state(),
            'numpy': np.random.get_state(),
            'cuda': (torch.cuda.get_rng_state_all()
                     if torch.cuda.is_available() else None),
        },
        'ablation_config': ablation_config,
        'run_config': {k: v for k, v in asdict(run_config).items()
                       if isinstance(v, (int, float, str, bool, list))},
        'params_finite': params_finite,
    }

    tmp_path = f"{ckpt_path}.tmp"
    try:
        torch.save(ckpt, tmp_path)
        os.replace(tmp_path, ckpt_path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original error
        # is what the caller needs to see.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
|
|
|
|
def load_checkpoint(
    ckpt_path: str,
    model: nn.Module,
    opt: torch.optim.Optimizer,
    sched: Optional[Any] = None,
    restore_rng: bool = True,
) -> Dict[str, Any]:
    """Restore a ``save_checkpoint`` file into existing objects.

    The model and optimizer states are always loaded. The scheduler state is
    loaded only when both ``sched`` and the stored state are present. With
    ``restore_rng`` the torch and numpy RNG states are restored, and the CUDA
    one as well when CUDA is available and it was saved.

    Returns:
        dict with keys epoch, ema_state, cv_trajectory, global_batch,
        params_finite, ablation_config, run_config. Older checkpoints that
        lack cv_trajectory / global_batch / params_finite get [] / 0 / True.
    """
    # weights_only=False unpickles arbitrary objects (the checkpoint holds
    # numpy RNG state and config dicts); only load checkpoints you trust.
    payload = torch.load(ckpt_path, map_location='cpu', weights_only=False)

    model.load_state_dict(payload['model_state'])
    opt.load_state_dict(payload['optimizer_state'])
    stored_sched = payload.get('scheduler_state')
    if sched is not None and stored_sched is not None:
        sched.load_state_dict(stored_sched)

    if restore_rng:
        rng_state = payload['rng_state']
        torch.set_rng_state(rng_state['torch'])
        np.random.set_state(rng_state['numpy'])
        cuda_state = rng_state.get('cuda')
        if torch.cuda.is_available() and cuda_state is not None:
            torch.cuda.set_rng_state_all(cuda_state)

    return dict(
        epoch=payload['epoch'],
        ema_state=payload['ema_state'],
        cv_trajectory=payload.get('cv_trajectory', []),
        global_batch=payload.get('global_batch', 0),
        params_finite=payload.get('params_finite', True),
        ablation_config=payload['ablation_config'],
        run_config=payload['run_config'],
    )
|
|
|
|
| |
| |
| |
|
|
def _all_params_finite(model: nn.Module) -> bool:
    """Return True when no parameter of ``model`` contains NaN or Inf."""
    with torch.no_grad():
        return all(torch.isfinite(p).all().item() for p in model.parameters())


def _soft_hand_loss(recon_loss: torch.Tensor, *,
                    last_prox: float,
                    cv_ema: Optional[float],
                    use_soft_hand: bool,
                    cv_measurement_only: bool,
                    boost_factor: float,
                    hard_cv_target: Optional[float],
                    cv_penalty: float) -> torch.Tensor:
    """Group E: training loss derived from the reconstruction MSE.

    With soft-hand active (and not a measurement-only run) the MSE is scaled
    by ``1 + boost_factor * last_prox``. When ``hard_cv_target`` is set, a CV
    EMA exists and ``cv_penalty > 0``, the term
    ``cv_penalty * (cv_ema - hard_cv_target) ** 2`` is added.
    """
    if use_soft_hand and not cv_measurement_only:
        loss = (1.0 + boost_factor * last_prox) * recon_loss
    else:
        loss = recon_loss
    if hard_cv_target is not None and cv_ema is not None and cv_penalty > 0:
        # NOTE(review): cv_ema comes from cv_of() measurements taken outside
        # this forward pass. If it is a plain number, this term shifts the
        # loss value but contributes no gradient -- confirm that is intended.
        loss = loss + cv_penalty * (cv_ema - hard_cv_target) ** 2
    return loss


def _evaluate_epoch(model: nn.Module, cfg: Any, overrides: Dict[str, Any],
                    device: torch.device) -> Dict[str, Any]:
    """End-of-epoch diagnostics; the caller puts ``model`` in eval mode.

    Geometry (CV, spectrum, effective rank) is measured on one batch of
    noise type 0. Test MSE is the mean per-image MSE for each noise type in
    ``overrides['test_noise_types']`` (default: types 0..15). The dataset and
    loader construction order matches the original inline code, so RNG
    consumption is unchanged.

    Returns a dict with keys final_cv, observed_cv_precise, S0, SD, ratio,
    erank, test_mse_per_noise and test_mse_final.
    """
    test_noise_types = overrides.get('test_noise_types', list(range(16)))
    test_samples_per_noise = overrides.get('test_samples_per_noise', 256)
    test_batch_size = overrides.get('test_batch_size', 64)

    with torch.no_grad():
        geom_ds = OmegaNoiseDataset(
            size=test_batch_size, img_size=cfg.img_size,
            allowed_types=[0])
        geom_loader = torch.utils.data.DataLoader(
            geom_ds, batch_size=test_batch_size, shuffle=False,
            num_workers=0, pin_memory=True, drop_last=True)
        geom_imgs, _ = next(iter(geom_loader))
        geom_imgs = geom_imgs.to(device)
        t_out = model(geom_imgs)

        final_cv = cv_of(t_out['svd']['M'][0, 0], n_samples=500)
        S_final = t_out['svd']['S'].mean(dim=(0, 1))
        S0 = S_final[0].item()
        SD = S_final[-1].item()
        erank = PatchSVAE_F.effective_rank(
            t_out['svd']['S'].reshape(-1, cfg.D)).mean().item()
        observed_cv_precise = cv_of(t_out['svd']['M'][0, 0], n_samples=2000)

        test_mse_per_noise = {}
        for nt in test_noise_types:
            nt_ds = OmegaNoiseDataset(
                size=test_samples_per_noise,
                img_size=cfg.img_size,
                allowed_types=[nt])
            nt_loader = torch.utils.data.DataLoader(
                nt_ds, batch_size=test_batch_size, shuffle=False,
                num_workers=0, pin_memory=True, drop_last=False)
            mse_chunks = []
            for imgs, _ in nt_loader:
                imgs = imgs.to(device)
                out = model(imgs)
                mse_chunks.append(
                    F.mse_loss(out['recon'], imgs,
                               reduction='none').mean(dim=(1, 2, 3)))
            test_mse_per_noise[nt] = torch.cat(mse_chunks).mean().item()

    test_mse_final = sum(test_mse_per_noise.values()) / max(
        1, len(test_mse_per_noise))
    return {
        'final_cv': final_cv,
        'observed_cv_precise': observed_cv_precise,
        'S0': S0,
        'SD': SD,
        'ratio': S0 / (SD + 1e-8),
        'erank': erank,
        'test_mse_per_noise': test_mse_per_noise,
        'test_mse_final': test_mse_final,
    }


def run_ablation_config(
    ablation_config: Dict[str, Any],
    output_dir: str,
    batch_limit: Optional[int] = 1000,
    num_epochs: int = 1,
    resume_from: Optional[str] = None,
) -> Dict[str, Any]:
    """Train and evaluate one ablation config and return a result dict.

    Args:
        ablation_config: dict with 'band', 'seed' and 'overrides' (see
            build_run_config). ``overrides`` also carries the optimizer,
            scheduler, soft-hand, model-hook and test-set switches read here.
        output_dir: directory for TensorBoard logs and per-epoch checkpoints
            (``epoch_{k}_checkpoint.pt``); created if missing.
        batch_limit: training batches per epoch, counted on the global batch
            counter. A falsy value means one full pass over the loader.
        num_epochs: number of epochs to run in this call; must be >= 1.
        resume_from: optional checkpoint written by save_checkpoint. Model,
            optimizer, scheduler, EMA and RNG state are restored from it, and
            epoch numbering continues from the stored epoch.

    Returns:
        Dict with the final CV / band diagnostics, test MSE (overall and per
        noise type), spectrum statistics, per-epoch metrics, and the CV and
        training-loss trajectories.

    Raises:
        ValueError: if ``num_epochs`` < 1. The result metrics are produced
            inside the epoch loop, so zero epochs has no result to return.
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    cfg = build_run_config(ablation_config)
    overrides = ablation_config['overrides']

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.set_float32_matmul_precision('high')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs(output_dir, exist_ok=True)

    tb_dir = os.path.join(output_dir, "tensorboard")
    os.makedirs(tb_dir, exist_ok=True)
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(tb_dir)
    # Everything below runs inside try/finally so the writer is closed even
    # when training or evaluation raises.
    try:
        # ---- Model (Groups F/G/H/I/L hooks) ------------------------------
        model = PatchSVAE_F_Ablation(
            matrix_v=cfg.matrix_v, D=cfg.D, patch_size=cfg.patch_size,
            hidden=cfg.hidden, depth=cfg.depth,
            n_cross_layers=cfg.n_cross_layers, n_heads=cfg.n_heads,
            max_alpha=overrides.get('max_alpha', cfg.max_alpha),
            alpha_init=cfg.alpha_init,
            activation=overrides.get('activation', 'gelu'),
            row_norm=overrides.get('row_norm', 'sphere'),
            svd_mode=overrides.get('svd', 'fp64'),
            linear_readout=overrides.get('linear_readout', False),
            match_params=overrides.get('match_params', True),
            init_scheme=overrides.get('init', 'orthogonal'),
        ).to(device)
        n_params = sum(p.numel() for p in model.parameters())

        # ---- Data (Group B noise types, Group K batch size) --------------
        train_ds = OmegaNoiseDataset(
            size=cfg.train_size, img_size=cfg.img_size,
            allowed_types=cfg.allowed_types)
        # The validation dataset is not evaluated in this function. It is
        # still constructed because dataset construction may advance the
        # RNG stream (TODO confirm), and removing it could change results.
        val_ds = OmegaNoiseDataset(  # noqa: F841
            size=cfg.val_size, img_size=cfg.img_size,
            allowed_types=cfg.allowed_types)
        train_loader = torch.utils.data.DataLoader(
            train_ds, batch_size=cfg.batch_size, shuffle=True,
            num_workers=cfg.num_workers, pin_memory=True, drop_last=True,
            persistent_workers=cfg.num_workers > 0)

        # ---- Optimization (Groups C, D, M) -------------------------------
        # The scheduler is stepped once per batch for every epoch of this
        # call, so its horizon must cover all of them. Otherwise cosine
        # re-warms after epoch 1 and OneCycleLR raises past total_steps.
        # NOTE: on resume, load_checkpoint restores the scheduler's own
        # horizon (part of its state_dict), so this value governs fresh runs.
        steps_per_epoch = (batch_limit if batch_limit
                           else (cfg.train_size // cfg.batch_size))
        effective_steps = steps_per_epoch * num_epochs
        opt = build_optimizer(model, overrides, cfg.lr)
        sched = build_scheduler(opt, overrides, total_steps=effective_steps)
        grad_clip = overrides.get('grad_clip', None)

        # ---- Soft-hand switches (Group E) --------------------------------
        use_soft_hand = overrides.get('soft_hand', True)
        cv_penalty = overrides.get('cv_penalty', 0.0)
        hard_cv_target = overrides.get('hard_cv_target', None)
        cv_measurement_only = overrides.get('cv_measurement_only', False)
        boost_factor = cfg.boost if use_soft_hand else 0.0
        loss_options = dict(
            use_soft_hand=use_soft_hand,
            cv_measurement_only=cv_measurement_only,
            boost_factor=boost_factor,
            hard_cv_target=hard_cv_target,
            cv_penalty=cv_penalty,
        )

        start_time = time.time()
        model.train()

        # Running training state; save_checkpoint persists all of it except
        # train_loss_trajectory.
        last_cv = 0.0
        cv_ema = None
        recon_ema_obs = None
        last_prox = 1.0
        cv_trajectory = []
        train_loss_trajectory = []
        global_batch = 0
        start_epoch = 0

        if resume_from is not None:
            resumed = load_checkpoint(resume_from, model, opt, sched,
                                      restore_rng=True)
            start_epoch = resumed['epoch']
            last_cv = resumed['ema_state'].get('last_cv', 0.0)
            cv_ema = resumed['ema_state'].get('cv_ema')
            recon_ema_obs = resumed['ema_state'].get('recon_ema_obs')
            last_prox = resumed['ema_state'].get('last_prox', 1.0)
            cv_trajectory = resumed.get('cv_trajectory', [])
            global_batch = resumed.get('global_batch', 0)
            print(f" Resumed from epoch {start_epoch}, global_batch {global_batch}")

        per_epoch_metrics = []

        for epoch in range(start_epoch, start_epoch + num_epochs):
            model.train()
            epoch_start = time.time()
            # Cumulative target on the global counter, so a resumed run
            # finishes exactly where an uninterrupted one would.
            epoch_batch_target = batch_limit * (epoch + 1) if batch_limit else None

            for images, _ in train_loader:
                if (epoch_batch_target is not None
                        and global_batch >= epoch_batch_target):
                    break

                images = images.to(device, non_blocking=True)
                opt.zero_grad()

                if isinstance(opt, torch.optim.LBFGS):
                    # L-BFGS may evaluate the closure several times per step.
                    # Gradient clipping is not applied on this path.
                    def closure():
                        opt.zero_grad()
                        _out = model(images)
                        _loss = _soft_hand_loss(
                            F.mse_loss(_out['recon'], images),
                            last_prox=last_prox, cv_ema=cv_ema,
                            **loss_options)
                        _loss.backward()
                        return _loss

                    opt.step(closure)
                    # Metrics come from a fresh forward pass after the step.
                    with torch.no_grad():
                        out = model(images)
                        recon_val = F.mse_loss(out['recon'], images).item()
                else:
                    out = model(images)
                    recon_loss = F.mse_loss(out['recon'], images)
                    recon_val = recon_loss.item()
                    loss = _soft_hand_loss(recon_loss, last_prox=last_prox,
                                           cv_ema=cv_ema, **loss_options)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        model.cross_attn.parameters(),
                        max_norm=cfg.cross_attn_clip)
                    if grad_clip is not None:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), max_norm=grad_clip)
                    opt.step()

                # Observed-reconstruction EMA over plain Python floats.
                if recon_ema_obs is None:
                    recon_ema_obs = recon_val
                else:
                    recon_ema_obs = 0.99 * recon_ema_obs + 0.01 * recon_val

                train_loss_trajectory.append({
                    'batch': global_batch,
                    'recon': recon_val,
                })
                writer.add_scalar('train/recon', recon_val, global_batch)
                writer.add_scalar('train/recon_ema', recon_ema_obs, global_batch)
                writer.add_scalar('train/lr', opt.param_groups[0]['lr'], global_batch)

                if global_batch % cfg.cv_measure_every == 0:
                    # Diagnostics only: no autograd graph needed.
                    with torch.no_grad():
                        current_cv = cv_of(out['svd']['M'][0, 0])
                        S_now = out['svd']['S'][0, 0]
                        s_first = S_now[0].item()
                        s_last = S_now[-1].item()
                        s_ratio = (S_now[0] / (S_now[-1] + 1e-8)).item()

                    # Non-positive CV readings are not folded into the EMA.
                    if current_cv > 0:
                        last_cv = current_cv
                        if cv_ema is None:
                            cv_ema = current_cv
                        else:
                            cv_ema = ((1.0 - cfg.cv_ema_alpha) * cv_ema
                                      + cfg.cv_ema_alpha * current_cv)
                        cv_trajectory.append({
                            'batch': global_batch,
                            'cv': current_cv,
                            'cv_ema': cv_ema,
                            'recon': recon_val,
                        })
                        writer.add_scalar('geo/cv', current_cv, global_batch)
                        writer.add_scalar('geo/cv_ema', cv_ema, global_batch)

                    writer.add_scalar('geo/S0', s_first, global_batch)
                    writer.add_scalar('geo/SD', s_last, global_batch)
                    writer.add_scalar('geo/ratio', s_ratio, global_batch)

                    # Gaussian proximity of the latest CV to its EMA (width
                    # cv_sigma_scale * cv_ema); it scales the recon boost.
                    if cv_ema is not None and cv_ema > 1e-6:
                        sigma_adapt = max(cfg.cv_sigma_scale * cv_ema, 1e-6)
                        delta = last_cv - cv_ema
                        last_prox = math.exp(-(delta ** 2) / (2 * sigma_adapt ** 2))
                        writer.add_scalar('stab/prox', last_prox, global_batch)

                if sched is not None:
                    sched.step()

                global_batch += 1

            # ---- End-of-epoch evaluation ---------------------------------
            model.eval()
            metrics = _evaluate_epoch(model, cfg, overrides, device)
            final_cv = metrics['final_cv']
            observed_cv_precise = metrics['observed_cv_precise']
            S0 = metrics['S0']
            SD = metrics['SD']
            ratio = metrics['ratio']
            erank = metrics['erank']
            test_mse_per_noise = metrics['test_mse_per_noise']
            test_mse_final = metrics['test_mse_final']

            uniform_cv = uniform_sphere_cv_prediction(
                cfg.D, V=cfg.matrix_v,
                device='cuda' if torch.cuda.is_available() else 'cpu')
            band_deviation = observed_cv_precise - uniform_cv

            # Classify from the training CV EMA when available, otherwise
            # from the end-of-epoch probe.
            classify_on = cv_ema if cv_ema is not None else final_cv
            predicted_band = band_classifier(classify_on)
            expected_band = ablation_config['band']
            cv_ema_or_zero = cv_ema if cv_ema is not None else 0.0

            writer.add_scalar('summary/test_mse_final', test_mse_final, global_batch)
            writer.add_scalar('summary/cv_ema_final', cv_ema_or_zero, global_batch)
            writer.add_scalar('summary/observed_sphere_cv', observed_cv_precise, global_batch)
            writer.add_scalar('summary/uniform_sphere_cv_pred', uniform_cv, global_batch)
            writer.add_scalar('summary/band_deviation', band_deviation, global_batch)
            writer.add_scalar('summary/erank', erank, global_batch)

            ckpt_path = os.path.join(output_dir, f'epoch_{epoch+1}_checkpoint.pt')
            save_checkpoint(
                ckpt_path=ckpt_path,
                epoch=epoch + 1,
                model=model,
                opt=opt,
                sched=sched,
                state={
                    'cv_ema': cv_ema,
                    'recon_ema_obs': recon_ema_obs,
                    'last_prox': last_prox,
                    'last_cv': last_cv,
                    'cv_trajectory': cv_trajectory,
                    'global_batch': global_batch,
                },
                ablation_config=ablation_config,
                run_config=cfg,
            )

            per_epoch_metrics.append({
                'epoch': epoch + 1,
                'test_mse': test_mse_final,
                'test_mse_per_noise': {int(k): float(v)
                                       for k, v in test_mse_per_noise.items()},
                'cv_ema': cv_ema_or_zero,
                'observed_sphere_cv': observed_cv_precise,
                'band_deviation': band_deviation,
                'erank': erank,
                'params_finite': _all_params_finite(model),
                'wallclock_seconds': time.time() - epoch_start,
                'checkpoint_path': ckpt_path,
            })

            writer.add_scalar('epoch/test_mse', test_mse_final, epoch + 1)
            writer.add_scalar('epoch/cv_ema', cv_ema_or_zero, epoch + 1)
            writer.add_scalar('epoch/observed_sphere_cv', observed_cv_precise, epoch + 1)
            writer.flush()

        # ---- Result (values from the last epoch) -------------------------
        final_params_finite = _all_params_finite(model)
        wallclock = time.time() - start_time

        return {
            'config': ablation_config,
            'run_config': {k: v for k, v in asdict(cfg).items()
                           if isinstance(v, (int, float, str, bool, list))},

            'cv_ema_final': cv_ema if cv_ema is not None else 0.0,
            'cv_last': last_cv,
            'predicted_band': predicted_band,
            'expected_band': expected_band,
            'band_match': predicted_band == expected_band,

            'test_mse': test_mse_final,
            'test_mse_per_noise': {int(k): float(v)
                                   for k, v in test_mse_per_noise.items()},
            'recon_ema': recon_ema_obs if recon_ema_obs is not None else 0.0,

            'S0': S0,
            'SD': SD,
            'ratio': ratio,
            'erank': erank,

            'observed_sphere_cv': observed_cv_precise,
            'uniform_sphere_cv_prediction': uniform_cv,
            'band_deviation': band_deviation,

            'params_finite': final_params_finite,

            'num_epochs_run': num_epochs,
            'start_epoch': start_epoch,
            'per_epoch_metrics': per_epoch_metrics,

            'params_count': n_params,
            'wallclock_seconds': wallclock,
            'batches_completed': global_batch,
            'batch_limit': batch_limit,
            'cv_trajectory': cv_trajectory,
            'train_loss_trajectory': train_loss_trajectory,
        }
    finally:
        writer.close()