""" Twin Stereo Diffusion — Fresnel × Johanna Spectral Denoising ============================================================== Fresnel sees the clean image. Johanna sees the noise. Procrustes alignment between their spectral bases IS the noise. Training: clean image ──→ Fresnel ──→ (U_f, S_f, Vt_f) target noised image ──→ Johanna ──→ (U_j, S_j, Vt_j) input R = Procrustes(U_j → U_f) rotation = noise signature Denoiser(S_j, R, t, labels) → S_f predict clean magnitudes Inference: x_t ──→ Johanna ──→ S_j ──→ Denoiser ──→ S_pred decode(U_j, S_pred, Vt_j) ──→ x̂_0 flow step: x_{t-dt} final pass: x_0 ──→ Fresnel encode/decode ──→ crisp output """ import os import math import torch import torch.nn as nn import torch.nn.functional as F import torchvision import torchvision.transforms as T import numpy as np from tqdm import tqdm try: from google.colab import userdata os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN') from huggingface_hub import login login(token=os.environ["HF_TOKEN"]) except Exception: pass # ═══════════════════════════════════════════════════════════════ # FROZEN TWINS # ═══════════════════════════════════════════════════════════════ def load_twins(device='cuda'): """Load both frozen SVAE twins at 128×128.""" from geolip_svae import load_model fresnel, f_cfg = load_model(hf_version='v12_imagenet128', device=device) fresnel.eval() for p in fresnel.parameters(): p.requires_grad = False print(f" Fresnel-small loaded: {sum(p.numel() for p in fresnel.parameters()):,} params (frozen)") johanna, j_cfg = load_model(hf_version='v16_johanna_omega', device=device) johanna.eval() for p in johanna.parameters(): p.requires_grad = False print(f" Johanna-small loaded: {sum(p.numel() for p in johanna.parameters()):,} params (frozen)") return fresnel, johanna # ═══════════════════════════════════════════════════════════════ # PROCRUSTES ALIGNMENT # ═══════════════════════════════════════════════════════════════ def batched_procrustes(A, B): """Find orthogonal R such that A @ R ≈ B. Args: A: (batch, M, D) — source (Johanna's U) B: (batch, M, D) — target (Fresnel's U) Returns: R: (batch, D, D) — orthogonal rotation """ M = torch.bmm(B.transpose(-2, -1), A) # (batch, D, D) U, S, Vt = torch.linalg.svd(M) return torch.bmm(Vt.transpose(-2, -1), U.transpose(-2, -1)) def compute_procrustes_features(U_j, U_f, D=16): """Compute per-patch Procrustes rotation and extract features. 
def compute_procrustes_features(U_j, U_f, D=16):
    """Compute per-patch Procrustes rotation and extract features.

    Args:
        U_j: (B, N, V, D) — Johanna's left singular vectors
        U_f: (B, N, V, D) — Fresnel's left singular vectors
    Returns:
        R:      (B, N, D, D) — rotation matrices
        R_feat: (B, N, D*D)  — flattened rotation for projection
    """
    B, N, V, D = U_j.shape
    Uj = U_j.reshape(B * N, V, D)
    Uf = U_f.reshape(B * N, V, D)
    R = batched_procrustes(Uj, Uf)  # (B*N, D, D)
    R = R.reshape(B, N, D, D)
    R_feat = R.reshape(B, N, D * D)
    return R, R_feat


# ═══════════════════════════════════════════════════════════════
# TILED CIFAR-10 DATASET
# ═══════════════════════════════════════════════════════════════

CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)


class TiledCIFAR(torch.utils.data.Dataset):
    """4 CIFAR-10 images (32→64) tiled 2×2 into 128×128."""

    def __init__(self, train=True, n_samples=50000):
        self.n_samples = n_samples
        self.cifar = torchvision.datasets.CIFAR10(
            root='./data', train=train, download=True,
            transform=T.Compose([
                T.Resize(64, interpolation=T.InterpolationMode.BILINEAR),
                T.ToTensor(),
                T.Normalize(CIFAR_MEAN, CIFAR_STD),
            ]))
        self.n = len(self.cifar)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        ids = torch.randint(0, self.n, (4,))
        imgs, labels = [], []
        for i in ids:
            img, lab = self.cifar[i.item()]
            imgs.append(img)
            labels.append(lab)
        top = torch.cat([imgs[0], imgs[1]], dim=2)
        bot = torch.cat([imgs[2], imgs[3]], dim=2)
        return torch.cat([top, bot], dim=1), torch.tensor(labels, dtype=torch.long)


# ═══════════════════════════════════════════════════════════════
# NOISE SCHEDULE
# ═══════════════════════════════════════════════════════════════

def add_noise(x0, t):
    """Linear flow-matching interpolation: x_t = (1-t)*x0 + t*ε.

    Args:
        x0: (B, 3, 128, 128) clean images
        t:  (B,) timesteps in [0, 1]
    Returns:
        x_t: noised images
        eps: the noise that was added
    """
    eps = torch.randn_like(x0)
    t_exp = t.view(-1, 1, 1, 1)
    x_t = (1 - t_exp) * x0 + t_exp * eps
    return x_t, eps
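
# Worked example (added for clarity; a sketch, not part of the original code):
# under x_t = (1 - t) * x0 + t * ε, the velocity field is dx/dt = ε - x0, and
# given a clean estimate x̂_0 at level t it can be recovered as (x_t - x̂_0) / t.
# One Euler step of the reverse flow from t to t - dt is therefore
#     x_{t-dt} = x_t - dt * (x_t - x̂_0) / t = x_t + (dt / t) * (x̂_0 - x_t),
# i.e. a move *toward* the clean estimate. The sampling loops below implement
# this same update inline; `flow_step` is only an illustrative standalone helper.
def flow_step(x_t, x0_hat, t, dt, eps=1e-4):
    """Euler step of the linear flow from x_t toward the clean estimate x0_hat."""
    return x_t + dt * (x0_hat - x_t) / (t + eps)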
# ═══════════════════════════════════════════════════════════════
# SPECTRAL DENOISER
# ═══════════════════════════════════════════════════════════════

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        emb = math.log(10000) / (half - 1)
        emb = torch.exp(torch.arange(half, device=t.device, dtype=torch.float) * -emb)
        emb = t.unsqueeze(1) * emb.unsqueeze(0)
        return torch.cat([emb.sin(), emb.cos()], dim=1)


class AdaLN(nn.Module):
    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.proj = nn.Linear(cond_dim, dim * 2)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, cond):
        s = self.proj(cond).unsqueeze(1).chunk(2, dim=-1)
        return self.norm(x) * (1 + s[0]) + s[1]


class StereoBlock(nn.Module):
    """Transformer block with AdaLN and Procrustes-conditioned cross-path."""

    def __init__(self, dim, n_heads, cond_dim):
        super().__init__()
        self.adaln1 = AdaLN(dim, cond_dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.adaln2 = AdaLN(dim, cond_dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x, cond):
        h = self.adaln1(x, cond)
        h, _ = self.attn(h, h, h)
        x = x + h
        return x + self.ff(self.adaln2(x, cond))


class StereoDenoiser(nn.Module):
    """Predicts clean Fresnel omega tokens from noisy Johanna observations.

    Input:
        S_j    (B, N, D)  — Johanna's singular values
        R_feat (B, N, D²) — Procrustes rotation features
        t      (B,)       — noise level
        labels (B, 4)     — tile class labels
    Output:
        S_f_pred (B, N, D) — predicted clean Fresnel singular values
    """

    def __init__(self, n_patches=64, omega_dim=16, hidden=256, depth=8,
                 n_heads=8, n_classes=10, n_tiles=4):
        super().__init__()
        self.omega_dim = omega_dim
        D2 = omega_dim * omega_dim

        # Input: omega tokens + Procrustes features
        self.input_proj = nn.Linear(omega_dim + D2, hidden)
        self.input_proj_no_R = nn.Linear(omega_dim, hidden)

        # Positional embedding
        self.pos_emb = nn.Parameter(torch.randn(1, n_patches, hidden) * 0.02)

        # Timestep embedding
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(hidden),
            nn.Linear(hidden, hidden), nn.GELU(), nn.Linear(hidden, hidden))

        # Class embedding
        self.class_emb = nn.Embedding(n_classes, hidden // n_tiles)
        self.class_proj = nn.Linear(hidden, hidden)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            StereoBlock(hidden, n_heads, hidden) for _ in range(depth)])

        # Output
        self.out_norm = nn.LayerNorm(hidden)
        self.out_proj = nn.Linear(hidden, omega_dim)
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, S_j, t, labels, R_feat=None):
        B = S_j.shape[0]

        # Project input (with or without Procrustes features)
        if R_feat is not None:
            h = self.input_proj(torch.cat([S_j, R_feat], dim=-1))
        else:
            h = self.input_proj_no_R(S_j)
        h = h + self.pos_emb

        # Conditioning
        t_emb = self.time_emb(t)
        c_emb = self.class_proj(self.class_emb(labels).reshape(B, -1))
        cond = t_emb + c_emb

        # Transformer
        for block in self.blocks:
            h = block(h, cond)

        # Predict residual: S_f ≈ S_j + correction
        return S_j + self.out_proj(self.out_norm(h))
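
# Minimal shape check (added; illustrative only, assuming the default
# 64-patch / 16-dim omega layout): random Johanna singular values plus optional
# Procrustes features map to predicted Fresnel singular values of the same
# (B, N, D) shape. The helper itself is hypothetical and unused by training.
def _denoiser_shape_check():
    model = StereoDenoiser(n_patches=64, omega_dim=16, hidden=256, depth=2, n_heads=8)
    S_j = torch.randn(2, 64, 16)            # Johanna singular values
    R_feat = torch.randn(2, 64, 16 * 16)    # flattened Procrustes rotations
    t = torch.rand(2)                       # noise levels in [0, 1]
    labels = torch.randint(0, 10, (2, 4))   # 4 tile labels per image
    out = model(S_j, t, labels, R_feat)
    assert out.shape == (2, 64, 16)
    assert model(S_j, t, labels, R_feat=None).shape == (2, 64, 16)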
# ═══════════════════════════════════════════════════════════════
# TRAINING
# ═══════════════════════════════════════════════════════════════

def train(epochs=100, batch_size=64, lr=3e-4, hidden=256, depth=8, n_heads=8,
          n_train=50000, device='cuda'):
    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    print("\n" + "=" * 70)
    print("TWIN STEREO DIFFUSION — Fresnel × Johanna")
    print("=" * 70)

    # ── Load frozen twins ──
    fresnel, johanna = load_twins(device)

    # ── Data ──
    train_ds = TiledCIFAR(train=True, n_samples=n_train)
    val_ds = TiledCIFAR(train=False, n_samples=5000)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    # ── Denoiser ──
    denoiser = StereoDenoiser(
        n_patches=64, omega_dim=16, hidden=hidden,
        depth=depth, n_heads=n_heads).to(device)
    n_params = sum(p.numel() for p in denoiser.parameters())

    print(f"\n StereoDenoiser: {n_params:,} params")
    print(f" Hidden={hidden}, Depth={depth}, Heads={n_heads}")
    print(f" Training: {n_train} samples, batch={batch_size}")
    print(" Pipeline: Johanna(noised) + Procrustes → predict Fresnel(clean)")
    print("=" * 70)

    opt = torch.optim.AdamW(denoiser.parameters(), lr=lr, weight_decay=0.01)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    save_dir = '/content/stereo_checkpoints'
    os.makedirs(save_dir, exist_ok=True)
    best_val = float('inf')

    for epoch in range(1, epochs + 1):
        denoiser.train()
        total_loss, total_r_norm, n = 0, 0, 0

        pbar = tqdm(train_loader, desc=f"Ep {epoch}/{epochs}",
                    bar_format='{l_bar}{bar:20}{r_bar}')
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)
            B = images.shape[0]

            # ── Sample timestep ──
            t = torch.rand(B, device=device)

            # ── Noise the image ──
            x_noised, eps = add_noise(images, t)

            # ── Encode through both twins ──
            with torch.no_grad():
                f_out = fresnel(images)    # clean
                j_out = johanna(x_noised)  # noised
            S_f = f_out['svd']['S']  # target: (B, 64, 16)
            S_j = j_out['svd']['S']  # input:  (B, 64, 16)

            # ── Procrustes alignment ──
            with torch.no_grad():
                R, R_feat = compute_procrustes_features(
                    j_out['svd']['U'], f_out['svd']['U'])

            # ── Predict clean omega tokens ──
            S_pred = denoiser(S_j, t, labels, R_feat)
            loss = F.mse_loss(S_pred, S_f)

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(denoiser.parameters(), max_norm=1.0)
            opt.step()

            total_loss += loss.item() * B
            with torch.no_grad():
                total_r_norm += (R - torch.eye(16, device=device)).norm(dim=(-2, -1)).mean().item() * B
            n += B
            pbar.set_postfix_str(f"loss={loss.item():.6f}")

        sched.step()

        # ── Validation ──
        denoiser.eval()
        val_loss, val_n = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                B = images.shape[0]
                t = torch.rand(B, device=device)
                x_noised, _ = add_noise(images, t)
                f_out = fresnel(images)
                j_out = johanna(x_noised)
                _, R_feat = compute_procrustes_features(
                    j_out['svd']['U'], f_out['svd']['U'])
                S_pred = denoiser(j_out['svd']['S'], t, labels, R_feat)
                val_loss += F.mse_loss(S_pred, f_out['svd']['S']).item() * B
                val_n += B

        train_l = total_loss / n
        val_l = val_loss / val_n
        r_norm = total_r_norm / n

        if val_l < best_val:
            best_val = val_l
            torch.save({
                'epoch': epoch,
                'val_loss': val_l,
                'model_state_dict': denoiser.state_dict(),
                'config': {'hidden': hidden, 'depth': depth, 'n_heads': n_heads},
            }, os.path.join(save_dir, 'best.pt'))

        if epoch % 5 == 0 or epoch <= 5:
            print(f" ep{epoch:3d} | loss={train_l:.6f} val={val_l:.6f} "
                  f"best={best_val:.6f} ||R-I||={r_norm:.3f}")

        # ── Sample ──
        if epoch % 25 == 0:
            sample_stereo(denoiser, fresnel, johanna, device, epoch, save_dir)

    print(f"\n TRAINING COMPLETE — best val: {best_val:.6f}")
    return denoiser
# ═══════════════════════════════════════════════════════════════
# SAMPLING — ITERATIVE STEREO DENOISING
# ═══════════════════════════════════════════════════════════════

@torch.no_grad()
def sample_stereo(denoiser, fresnel, johanna, device, epoch, save_dir,
                  n_samples=4, n_steps=50):
    """Generate samples using iterative twin denoising.

    1. Start from pure noise x_T
    2. At each step:
        a. Johanna encodes x_t → (U_j, S_j, Vt_j)
        b. Denoiser predicts clean S_f from S_j
        c. Decode through Johanna's basis → x̂_0 estimate
        d. Flow step toward x̂_0
    3. Final pass: encode through Fresnel → decode with clean basis
    """
    from geolip_svae.model import stitch_patches

    denoiser.eval()
    labels = torch.randint(0, 10, (n_samples, 4), device=device)
    class_names = ['plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

    # Start from noise
    x = torch.randn(n_samples, 3, 128, 128, device=device)

    for step in range(n_steps):
        t_val = 1.0 - step / n_steps
        t = torch.full((n_samples,), t_val, device=device)

        # Johanna sees current state
        j_out = johanna(x)
        S_j = j_out['svd']['S']

        # Denoiser predicts clean omega tokens (no R at inference)
        S_pred = denoiser(S_j, t, labels, R_feat=None)

        # Decode through Johanna's basis
        decoded = johanna.decode_patches(
            j_out['svd']['U'], S_pred, j_out['svd']['Vt'])
        ps = johanna.patch_size
        gh = gw = int(math.sqrt(S_j.shape[1]))
        x_hat_0 = johanna.boundary_smooth(stitch_patches(decoded, gh, gw, ps))

        # Flow step toward clean estimate: x_{t-dt} = x_t + (dt / t) * (x̂_0 - x_t)
        if step < n_steps - 1:
            dt = 1.0 / n_steps
            velocity = (x_hat_0 - x) / (t_val + 1e-4)
            x = x + dt * velocity
        else:
            x = x_hat_0

    # ── Final Fresnel polish ──
    # Encode through Fresnel to get clean basis, re-decode
    f_out = fresnel(x)
    f_decoded = fresnel.decode_patches(
        f_out['svd']['U'], f_out['svd']['S'], f_out['svd']['Vt'])
    x_final = fresnel.boundary_smooth(stitch_patches(f_decoded, gh, gw, ps))

    # ── Denormalize and save ──
    mean = torch.tensor(CIFAR_MEAN).reshape(1, 3, 1, 1).to(device)
    std = torch.tensor(CIFAR_STD).reshape(1, 3, 1, 1).to(device)
    x_johanna = (x * std + mean).clamp(0, 1).cpu()
    x_fresnel = (x_final * std + mean).clamp(0, 1).cpu()

    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(n_samples, 2, figsize=(8, n_samples * 3))
    if n_samples == 1:
        axes = axes.reshape(1, -1)  # keep 2-D indexing when only one row
    for i in range(n_samples):
        tile_labels = [class_names[l] for l in labels[i].cpu().tolist()]
        axes[i, 0].imshow(x_johanna[i].permute(1, 2, 0).numpy())
        axes[i, 0].set_title(f"Johanna decode: {tile_labels}", fontsize=7)
        axes[i, 0].axis('off')
        axes[i, 1].imshow(x_fresnel[i].permute(1, 2, 0).numpy())
        axes[i, 1].set_title(f"Fresnel polish: {tile_labels}", fontsize=7)
        axes[i, 1].axis('off')
    plt.suptitle(f"Twin Stereo Diffusion — Epoch {epoch}", fontsize=10)
    plt.tight_layout()
    fname = os.path.join(save_dir, f'stereo_ep{epoch:03d}.png')
    plt.savefig(fname, dpi=150, bbox_inches='tight')
    plt.close()
    print(f" Samples saved: {fname}")
    print(f" Labels: {labels.cpu().tolist()}")
# ═══════════════════════════════════════════════════════════════
# ADVANCED SAMPLING — DUAL-ENCODE REFINEMENT
# ═══════════════════════════════════════════════════════════════

@torch.no_grad()
def sample_stereo_refined(denoiser, fresnel, johanna, labels, device, n_steps=50):
    """Two-pass refinement: use Fresnel to estimate R at inference.

    At each step:
        1. Johanna(x_t) → (U_j, S_j, Vt_j)
        2. Pass 1: Denoiser(S_j, t, labels) → S_pred (no R)
        3. Decode → x̂_0, encode through Fresnel → U_f_est
        4. R_est = Procrustes(U_j, U_f_est)
        5. Pass 2: Denoiser(S_j, t, labels, R_est) → S_refined
        6. Decode through Fresnel's estimated basis → x_{t-1}
    """
    from geolip_svae.model import stitch_patches

    B = labels.shape[0]
    x = torch.randn(B, 3, 128, 128, device=device)
    ps = johanna.patch_size

    for step in range(n_steps):
        t_val = 1.0 - step / n_steps
        t = torch.full((B,), t_val, device=device)

        # Johanna encodes current state
        j_out = johanna(x)
        S_j = j_out['svd']['S']
        gh = gw = int(math.sqrt(S_j.shape[1]))

        # Pass 1: predict without R
        S_pred_1 = denoiser(S_j, t, labels, R_feat=None)

        # Decode pass 1 through Johanna
        dec_1 = johanna.decode_patches(j_out['svd']['U'], S_pred_1, j_out['svd']['Vt'])
        x_est = johanna.boundary_smooth(stitch_patches(dec_1, gh, gw, ps))

        # Fresnel sees the estimate → get clean-style basis
        f_est = fresnel(x_est)

        # Procrustes: how far is Johanna's basis from Fresnel's?
        _, R_feat = compute_procrustes_features(
            j_out['svd']['U'], f_est['svd']['U'])

        # Pass 2: predict WITH R conditioning
        S_pred_2 = denoiser(S_j, t, labels, R_feat)

        # Decode through Fresnel's estimated basis
        dec_2 = fresnel.decode_patches(
            f_est['svd']['U'], S_pred_2, f_est['svd']['Vt'])
        x_clean = fresnel.boundary_smooth(stitch_patches(dec_2, gh, gw, ps))

        # Flow step toward the clean estimate (same update as in sample_stereo)
        if step < n_steps - 1:
            dt = 1.0 / n_steps
            velocity = (x_clean - x) / (t_val + 1e-4)
            x = x + dt * velocity
        else:
            x = x_clean

    return x


# ═══════════════════════════════════════════════════════════════
# CLI
# ═══════════════════════════════════════════════════════════════

if __name__ == "__main__":
    torch.set_float32_matmul_precision('high')
    train(
        epochs=100,
        batch_size=64,  # 2 VAE forwards per batch, keep it moderate
        lr=3e-4,
        hidden=256,
        depth=8,
        n_heads=8,
        n_train=50000,
    )
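
# Example usage after training (added as an illustrative sketch; it assumes a
# checkpoint was written by train() to /content/stereo_checkpoints/best.pt and
# that a CUDA device is available):
#
#   ckpt = torch.load('/content/stereo_checkpoints/best.pt', map_location='cuda')
#   denoiser = StereoDenoiser(**ckpt['config']).to('cuda')
#   denoiser.load_state_dict(ckpt['model_state_dict'])
#   denoiser.eval()
#   fresnel, johanna = load_twins('cuda')
#   labels = torch.randint(0, 10, (4, 4), device='cuda')  # 4 images × 4 tiles
#   samples = sample_stereo_refined(denoiser, fresnel, johanna, labels, 'cuda')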