"""
Twin Stereo Diffusion — Fresnel × Johanna Spectral Denoising
==============================================================
Fresnel sees the clean image. Johanna sees the noise.
Procrustes alignment between their spectral bases IS the noise.

Training:
    clean image  ──→ Fresnel ──→ (U_f, S_f, Vt_f)    target
    noised image ──→ Johanna ──→ (U_j, S_j, Vt_j)    input
    R = Procrustes(U_j → U_f)                        rotation = noise signature
    Denoiser(S_j, R, t, labels) → S_f                predict clean magnitudes

Inference:
    x_t ──→ Johanna ──→ S_j ──→ Denoiser ──→ S_pred
    decode(U_j, S_pred, Vt_j) ──→ x̂_0
    flow step: x_{t-dt}
    final pass: x_0 ──→ Fresnel encode/decode ──→ crisp output
"""
|
|
| import os |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision |
| import torchvision.transforms as T |
| import numpy as np |
| from tqdm import tqdm |
|
|
# Optional Colab bootstrap: read the Hugging Face token from the Colab secret
# store, export it as HF_TOKEN and log in. This is best-effort on purpose:
# any failure (not running on Colab, missing secret, missing package, failed
# login) is ignored so the module can still be imported.
try:
    from google.colab import userdata
    _hf_token = userdata.get('HF_TOKEN')
    os.environ["HF_TOKEN"] = _hf_token
    from huggingface_hub import login
    login(token=_hf_token)
except Exception:
    pass
|
|
|
|
| |
| |
| |
|
|
def load_twins(device='cuda'):
    """Load the Fresnel and Johanna models and freeze all their parameters.

    Both models come from ``geolip_svae.load_model`` (checkpoints
    'v12_imagenet128' and 'v16_johanna_omega'). Each is switched to eval mode
    and every parameter gets ``requires_grad = False``.

    Args:
        device: forwarded unchanged to ``load_model``.

    Returns:
        (fresnel, johanna): the two frozen models. The config object that
        ``load_model`` returns alongside each model is discarded.
    """
    from geolip_svae import load_model

    def _n_params(model):
        return sum(p.numel() for p in model.parameters())

    fresnel, _ = load_model(hf_version='v12_imagenet128', device=device)
    fresnel.eval()
    for param in fresnel.parameters():
        param.requires_grad = False
    print(f" Fresnel-small loaded: {_n_params(fresnel):,} params (frozen)")

    johanna, _ = load_model(hf_version='v16_johanna_omega', device=device)
    johanna.eval()
    for param in johanna.parameters():
        param.requires_grad = False
    print(f" Johanna-small loaded: {_n_params(johanna):,} params (frozen)")

    return fresnel, johanna
|
|
|
|
| |
| |
| |
|
|
def batched_procrustes(A, B):
    """Solve the orthogonal Procrustes problem: find orthogonal R with A @ R ≈ B.

    With the SVD ``B^T A = U S Vh``, the minimiser of ``||A @ R - B||_F`` over
    orthogonal matrices is ``R = Vh^T U^T``. No determinant correction is
    applied, so R may be a reflection (det = -1) as well as a rotation.

    Args:
        A: (..., M, D) source matrices (e.g. Johanna's U).
        B: (..., M, D) target matrices (e.g. Fresnel's U).
           Any number of leading batch dimensions is accepted, including none.

    Returns:
        R: (..., D, D) orthogonal matrices.
    """
    # `@` and torch.linalg.svd both broadcast over leading batch dimensions,
    # so the function is not restricted to exactly 3-D input as bmm was.
    cross = B.transpose(-2, -1) @ A
    left, _, right_h = torch.linalg.svd(cross)
    return right_h.transpose(-2, -1) @ left.transpose(-2, -1)
|
|
|
|
def compute_procrustes_features(U_j, U_f, D=16):
    """Compute one Procrustes matrix per patch and a flattened copy of it.

    Args:
        U_j: (B, N, V, D) Johanna's left singular vectors.
        U_f: (B, N, V, D) Fresnel's left singular vectors.
        D:   kept for call compatibility only. It is overwritten by the last
             dimension of ``U_j``, so the value passed in has no effect.

    Returns:
        R:      (B, N, D, D) matrices from ``batched_procrustes``.
        R_feat: (B, N, D*D) the same matrices flattened per patch.
    """
    batch, n_patches, n_rows, D = U_j.shape
    n_flat = batch * n_patches
    # Fold (batch, patch) into one axis so every patch is solved independently.
    R = batched_procrustes(
        U_j.reshape(n_flat, n_rows, D),
        U_f.reshape(n_flat, n_rows, D),
    ).reshape(batch, n_patches, D, D)
    return R, R.reshape(batch, n_patches, D * D)
|
|
|
|
| |
| |
| |
|
|
# Per-channel normalization statistics. TiledCIFAR passes them to T.Normalize,
# and sample_stereo uses them again to undo the normalization before plotting.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)
|
|
|
|
class TiledCIFAR(torch.utils.data.Dataset):
    """2×2 mosaics of four randomly drawn CIFAR-10 images.

    Each CIFAR image is resized to 64 (bilinear), converted to a tensor and
    normalized with CIFAR_MEAN / CIFAR_STD. Four of them are concatenated
    into one 128×128 image, together with their four class labels.
    """

    def __init__(self, train=True, n_samples=50000):
        """Download/open CIFAR-10 under ./data.

        Args:
            train: selects the CIFAR-10 train or test split.
            n_samples: the length reported by ``__len__``. It is independent
                of the size of the underlying CIFAR split.
        """
        preprocess = T.Compose([
            T.Resize(64, interpolation=T.InterpolationMode.BILINEAR),
            T.ToTensor(),
            T.Normalize(CIFAR_MEAN, CIFAR_STD),
        ])
        self.n_samples = n_samples
        self.cifar = torchvision.datasets.CIFAR10(
            root='./data', train=train, download=True, transform=preprocess)
        self.n = len(self.cifar)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        """Return (mosaic, labels) for four images drawn at random.

        ``idx`` is ignored: each call draws four fresh indices with
        ``torch.randint``, so the same ``idx`` yields different mosaics.
        Labels are ordered top-left, top-right, bottom-left, bottom-right.
        """
        picks = torch.randint(0, self.n, (4,)).tolist()
        tiles, classes = zip(*(self.cifar[k] for k in picks))
        # Concatenate along width (dim=2) to form each row, then stack the two
        # rows along height (dim=1).
        rows = [torch.cat(tiles[start:start + 2], dim=2) for start in (0, 2)]
        mosaic = torch.cat(rows, dim=1)
        return mosaic, torch.tensor(classes, dtype=torch.long)
|
|
|
|
| |
| |
| |
|
|
def add_noise(x0, t):
    """Linear flow-matching interpolation: x_t = (1 - t) * x0 + t * ε.

    Args:
        x0: (B, ...) clean samples, e.g. (B, 3, 128, 128) images. Any rank
            >= 1 is accepted; the first dimension is the batch.
        t:  (B,) interpolation weights in [0, 1]. t = 0 returns x0 and
            t = 1 returns pure noise.

    Returns:
        x_t: noised samples, same shape as x0.
        eps: the standard-normal noise that was mixed in, same shape as x0.
    """
    eps = torch.randn_like(x0)
    # One trailing singleton axis per non-batch dimension of x0, so t
    # broadcasts per sample for inputs of any rank, not only 4-D images.
    t_exp = t.reshape(-1, *([1] * (x0.dim() - 1)))
    x_t = (1 - t_exp) * x0 + t_exp * eps
    return x_t, eps
|
|
|
|
| |
| |
| |
|
|
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a scalar (e.g. a timestep) into `dim` features.

    The first ``dim // 2`` features are sines and the next ``dim // 2`` are
    cosines, at geometrically spaced frequencies from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed t.

        Args:
            t: (B,) tensor of scalars.

        Returns:
            (B, dim) tensor. For odd ``dim`` the last column is zero padding,
            so the output width always equals ``dim``.
        """
        half = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when half == 1 (dim 2 or 3).
        # For dim >= 4 the value is unchanged.
        log_step = math.log(10000) / max(half - 1, 1)
        freqs = torch.exp(
            torch.arange(half, device=t.device, dtype=torch.float) * -log_step)
        angles = t.unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([angles.sin(), angles.cos()], dim=1)
        if self.dim % 2 == 1:
            # sin/cos pairs only cover 2 * half columns; pad the missing one.
            emb = F.pad(emb, (0, 1))
        return emb
|
|
|
|
class AdaLN(nn.Module):
    """LayerNorm whose scale and shift are predicted from a conditioning vector.

    ``proj`` maps the conditioning vector to ``2 * dim`` values, split into a
    scale and a shift. Both weight and bias of ``proj`` start at zero, so a
    freshly built layer behaves as a plain affine-free LayerNorm.
    """

    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.proj = nn.Linear(cond_dim, dim * 2)
        for param in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(param)

    def forward(self, x, cond):
        """Apply the conditioned normalization.

        Args:
            x:    (B, N, dim) token features.
            cond: (B, cond_dim) conditioning vector, shared by all N tokens.

        Returns:
            (B, N, dim) tensor ``norm(x) * (1 + scale) + shift``.
        """
        # unsqueeze(1) adds the token axis so scale/shift broadcast over N.
        scale, shift = self.proj(cond).unsqueeze(1).chunk(2, dim=-1)
        normed = self.norm(x)
        return normed * (1 + scale) + shift
|
|
|
|
class StereoBlock(nn.Module):
    """Pre-norm transformer block conditioned through AdaLN.

    Two residual branches: self-attention over the tokens, then a 4×-wide
    GELU feed-forward. Each branch normalizes its input with its own AdaLN,
    driven by the same conditioning vector.
    """

    def __init__(self, dim, n_heads, cond_dim):
        super().__init__()
        self.adaln1 = AdaLN(dim, cond_dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.adaln2 = AdaLN(dim, cond_dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, x, cond):
        """Run both residual branches.

        Args:
            x:    (B, N, dim) tokens.
            cond: (B, cond_dim) conditioning vector.

        Returns:
            (B, N, dim) updated tokens.
        """
        attn_in = self.adaln1(x, cond)
        attn_out, _ = self.attn(attn_in, attn_in, attn_in)
        x = x + attn_out
        ff_out = self.ff(self.adaln2(x, cond))
        return x + ff_out
|
|
|
|
class StereoDenoiser(nn.Module):
    """Predicts clean Fresnel singular values from noisy Johanna observations.

    Input:  S_j    (B, N, D)   Johanna's singular values
            t      (B,)        noise level
            labels (B, T)      one class id per tile (T = n_tiles)
            R_feat (B, N, D²)  optional flattened Procrustes matrices

    Output: S_f_pred (B, N, D) predicted clean Fresnel singular values

    The network predicts a residual that is added to ``S_j``. Because
    ``out_proj`` starts at zero, an untrained model returns ``S_j`` unchanged.
    """

    def __init__(self, n_patches=64, omega_dim=16, hidden=256,
                 depth=8, n_heads=8, n_classes=10, n_tiles=4):
        """Build the denoiser.

        Args:
            n_patches: number of tokens N; fixes the positional table size.
            omega_dim: number of singular values D per token.
            hidden:    transformer width.
            depth:     number of StereoBlocks.
            n_heads:   attention heads per block.
            n_classes: size of the class vocabulary.
            n_tiles:   number of class labels per sample.
        """
        super().__init__()
        self.omega_dim = omega_dim
        D2 = omega_dim * omega_dim

        # Two token embeddings: one for [S_j, R_feat] and one for S_j alone.
        # forward() picks between them depending on whether R_feat is given.
        self.input_proj = nn.Linear(omega_dim + D2, hidden)
        self.input_proj_no_R = nn.Linear(omega_dim, hidden)

        # Learned per-patch positional embedding.
        self.pos_emb = nn.Parameter(torch.randn(1, n_patches, hidden) * 0.02)

        # Timestep conditioning.
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(hidden),
            nn.Linear(hidden, hidden), nn.GELU(),
            nn.Linear(hidden, hidden))

        # Class conditioning: each tile label gets a (hidden // n_tiles)-wide
        # embedding, and the n_tiles embeddings are concatenated. Size the
        # projection by that real concatenated width, so a `hidden` that is
        # not a multiple of n_tiles no longer crashes in forward(). For
        # divisible sizes this equals `hidden`, exactly as before.
        tile_dim = hidden // n_tiles
        self.class_emb = nn.Embedding(n_classes, tile_dim)
        self.class_proj = nn.Linear(tile_dim * n_tiles, hidden)

        self.blocks = nn.ModuleList([
            StereoBlock(hidden, n_heads, hidden) for _ in range(depth)])

        # Zero-initialized head so the initial prediction is the identity.
        self.out_norm = nn.LayerNorm(hidden)
        self.out_proj = nn.Linear(hidden, omega_dim)
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, S_j, t, labels, R_feat=None):
        """Predict clean singular values.

        Args:
            S_j:    (B, N, D) noisy singular values; N must equal n_patches.
            t:      (B,) noise level.
            labels: (B, n_tiles) integer class ids.
            R_feat: optional (B, N, D*D) Procrustes features. When None, the
                    ``input_proj_no_R`` embedding is used instead.

        Returns:
            (B, N, D) tensor ``S_j + residual``.
        """
        B = S_j.shape[0]

        if R_feat is not None:
            h = self.input_proj(torch.cat([S_j, R_feat], dim=-1))
        else:
            h = self.input_proj_no_R(S_j)
        h = h + self.pos_emb

        # One conditioning vector per sample: timestep plus tile classes.
        t_emb = self.time_emb(t)
        c_emb = self.class_proj(self.class_emb(labels).reshape(B, -1))
        cond = t_emb + c_emb

        for block in self.blocks:
            h = block(h, cond)

        return S_j + self.out_proj(self.out_norm(h))
|
|
|
|
| |
| |
| |
|
|
def train(epochs=100, batch_size=64, lr=3e-4, hidden=256, depth=8,
          n_heads=8, n_train=50000, device='cuda',
          save_dir='/content/stereo_checkpoints', r_drop_prob=0.5):
    """Train a StereoDenoiser against the two frozen twins.

    Per batch: draw t ~ U(0, 1), noise the images with ``add_noise``, run
    Fresnel on the clean images and Johanna on the noised ones, compute the
    per-patch Procrustes features, and regress the denoiser output onto
    Fresnel's singular values with an MSE loss.

    Args:
        epochs:      number of epochs (also the cosine schedule length).
        batch_size:  batch size for both loaders.
        lr:          AdamW learning rate.
        hidden, depth, n_heads: StereoDenoiser architecture.
        n_train:     number of mosaics per training epoch.
        device:      requested device. Falls back to 'cpu' when CUDA is
                     unavailable.
        save_dir:    directory for ``best.pt`` and sample images.
        r_drop_prob: probability, per batch, of hiding the Procrustes
                     features from the denoiser. Both samplers call the
                     denoiser with ``R_feat=None``, which goes through
                     ``input_proj_no_R``. If R is always provided during
                     training, that layer never receives a gradient and
                     sampling runs on its random initialization. Use 0.0 to
                     reproduce the old always-with-R training.

    Returns:
        The trained denoiser (final-epoch weights; the best-validation
        weights are in ``<save_dir>/best.pt``).

    Raises:
        ValueError: if ``r_drop_prob`` is outside [0, 1].

    Note:
        Validation always feeds R to the denoiser, so ``val_loss`` measures
        the with-R path only.
    """
    if not 0.0 <= r_drop_prob <= 1.0:
        raise ValueError(f"r_drop_prob must be in [0, 1], got {r_drop_prob}")

    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    print("\n" + "=" * 70)
    print("TWIN STEREO DIFFUSION β Fresnel Γ Johanna")
    print("=" * 70)

    fresnel, johanna = load_twins(device)

    train_ds = TiledCIFAR(train=True, n_samples=n_train)
    val_ds = TiledCIFAR(train=False, n_samples=5000)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    n_patches, omega_dim = 64, 16
    denoiser = StereoDenoiser(
        n_patches=n_patches, omega_dim=omega_dim, hidden=hidden,
        depth=depth, n_heads=n_heads).to(device)

    n_params = sum(p.numel() for p in denoiser.parameters())
    print(f"\n StereoDenoiser: {n_params:,} params")
    print(f" Hidden={hidden}, Depth={depth}, Heads={n_heads}")
    print(f" Training: {n_train} samples, batch={batch_size}")
    print(f" Pipeline: Johanna(noised) + Procrustes β predict Fresnel(clean)")
    print("=" * 70)

    opt = torch.optim.AdamW(denoiser.parameters(), lr=lr, weight_decay=0.01)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    os.makedirs(save_dir, exist_ok=True)
    best_val = float('inf')

    for epoch in range(1, epochs + 1):
        denoiser.train()
        total_loss, total_r_norm, n = 0.0, 0.0, 0

        pbar = tqdm(train_loader, desc=f"Ep {epoch}/{epochs}",
                    bar_format='{l_bar}{bar:20}{r_bar}')
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)
            B = images.shape[0]

            t = torch.rand(B, device=device)
            x_noised, _ = add_noise(images, t)

            # The twins are frozen; nothing here needs a graph.
            with torch.no_grad():
                f_out = fresnel(images)
                j_out = johanna(x_noised)
                R, R_feat = compute_procrustes_features(
                    j_out['svd']['U'], f_out['svd']['U'])

            S_f = f_out['svd']['S']
            S_j = j_out['svd']['S']

            # Randomly hide R so the no-R input path used at sampling time
            # is trained as well.
            use_R = torch.rand(()).item() >= r_drop_prob
            S_pred = denoiser(S_j, t, labels, R_feat if use_R else None)
            loss = F.mse_loss(S_pred, S_f)

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(denoiser.parameters(), max_norm=1.0)
            opt.step()

            loss_val = loss.item()
            total_loss += loss_val * B
            with torch.no_grad():
                # Mean Frobenius distance of R from identity (diagnostic only).
                eye = torch.eye(R.shape[-1], device=R.device, dtype=R.dtype)
                total_r_norm += (R - eye).norm(dim=(-2, -1)).mean().item() * B
            n += B
            pbar.set_postfix_str(f"loss={loss_val:.6f}")

        sched.step()

        denoiser.eval()
        val_loss, val_n = 0.0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                B = images.shape[0]
                t = torch.rand(B, device=device)
                x_noised, _ = add_noise(images, t)
                f_out = fresnel(images)
                j_out = johanna(x_noised)
                _, R_feat = compute_procrustes_features(
                    j_out['svd']['U'], f_out['svd']['U'])
                S_pred = denoiser(j_out['svd']['S'], t, labels, R_feat)
                val_loss += F.mse_loss(S_pred, f_out['svd']['S']).item() * B
                val_n += B

        train_l = total_loss / n
        val_l = val_loss / val_n
        r_norm = total_r_norm / n

        if val_l < best_val:
            best_val = val_l
            torch.save({
                'epoch': epoch, 'val_loss': val_l,
                'model_state_dict': denoiser.state_dict(),
                'config': {'hidden': hidden, 'depth': depth, 'n_heads': n_heads,
                           'n_patches': n_patches, 'omega_dim': omega_dim},
            }, os.path.join(save_dir, 'best.pt'))

        if epoch % 5 == 0 or epoch <= 5:
            print(f" ep{epoch:3d} | loss={train_l:.6f} val={val_l:.6f} "
                  f"best={best_val:.6f} ||R-I||={r_norm:.3f}")

        # Periodic qualitative check.
        if epoch % 25 == 0:
            sample_stereo(denoiser, fresnel, johanna, device, epoch, save_dir)

    print(f"\n TRAINING COMPLETE β best val: {best_val:.6f}")
    return denoiser
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def sample_stereo(denoiser, fresnel, johanna, device, epoch, save_dir,
                  n_samples=4, n_steps=50):
    """Generate samples with iterative twin denoising and save a figure.

    1. Start from pure noise x at t = 1.
    2. At each step:
       a. Johanna encodes x_t → (U_j, S_j, Vt_j)
       b. The denoiser predicts clean singular values from S_j (no R)
       c. Decode through Johanna's basis → x̂_0 estimate
       d. Flow step toward x̂_0 (the last step jumps to x̂_0)
    3. Final pass: encode the result with Fresnel and decode it again.

    Args:
        denoiser: StereoDenoiser; switched to eval mode and left there.
        fresnel, johanna: the frozen twin models.
        device:   device for noise, labels and timesteps.
        epoch:    used only in the figure title and file name.
        save_dir: directory that receives ``stereo_ep{epoch:03d}.png``.
        n_samples: number of images to generate.
        n_steps:  number of sampling steps.

    Returns:
        None. Writes a PNG with the Johanna-decoded and Fresnel-polished
        images side by side, and prints its path and the sampled labels.
    """
    from geolip_svae.model import stitch_patches

    denoiser.eval()

    class_names = ['plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']
    labels = torch.randint(0, len(class_names), (n_samples, 4), device=device)

    x = torch.randn(n_samples, 3, 128, 128, device=device)

    # Hoisted so the final Fresnel pass still works when n_steps == 0.
    ps = johanna.patch_size
    dt = 1.0 / n_steps if n_steps > 0 else 0.0

    for step in range(n_steps):
        t_val = 1.0 - step / n_steps
        t = torch.full((n_samples,), t_val, device=device)

        j_out = johanna(x)
        S_j = j_out['svd']['S']

        S_pred = denoiser(S_j, t, labels, R_feat=None)

        decoded = johanna.decode_patches(
            j_out['svd']['U'], S_pred, j_out['svd']['Vt'])
        gh = gw = int(math.sqrt(S_j.shape[1]))
        x_hat_0 = johanna.boundary_smooth(stitch_patches(decoded, gh, gw, ps))

        if step < n_steps - 1:
            # With x_t = (1 - t) x0 + t ε (see add_noise), dx_t/dt = ε - x0
            # = (x_t - x0) / t. Integrating from t down to t - dt subtracts
            # dt * velocity, which moves x TOWARD x̂_0. The previous code
            # used (x̂_0 - x) here and so stepped away from the estimate.
            velocity = (x - x_hat_0) / (t_val + 1e-4)
            x = x - dt * velocity
        else:
            x = x_hat_0

    # Final polish: round-trip through Fresnel. The patch grid is read from
    # Fresnel's own output rather than from the last loop iteration.
    f_out = fresnel(x)
    f_decoded = fresnel.decode_patches(
        f_out['svd']['U'], f_out['svd']['S'], f_out['svd']['Vt'])
    gh = gw = int(math.sqrt(f_out['svd']['S'].shape[1]))
    x_final = fresnel.boundary_smooth(stitch_patches(f_decoded, gh, gw, ps))

    # Undo the dataset normalization for display.
    mean = torch.tensor(CIFAR_MEAN).reshape(1, 3, 1, 1).to(device)
    std = torch.tensor(CIFAR_STD).reshape(1, 3, 1, 1).to(device)

    x_johanna = (x * std + mean).clamp(0, 1).cpu()
    x_fresnel = (x_final * std + mean).clamp(0, 1).cpu()

    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    # squeeze=False keeps `axes` 2-D even when n_samples == 1. The previous
    # axes.unsqueeze(0) failed because `axes` is a numpy array, not a tensor.
    fig, axes = plt.subplots(n_samples, 2, figsize=(8, n_samples * 3),
                             squeeze=False)
    for i in range(n_samples):
        tile_labels = [class_names[c] for c in labels[i].cpu().tolist()]
        axes[i, 0].imshow(x_johanna[i].permute(1, 2, 0).numpy())
        axes[i, 0].set_title(f"Johanna decode: {tile_labels}", fontsize=7)
        axes[i, 0].axis('off')
        axes[i, 1].imshow(x_fresnel[i].permute(1, 2, 0).numpy())
        axes[i, 1].set_title(f"Fresnel polish: {tile_labels}", fontsize=7)
        axes[i, 1].axis('off')
    plt.suptitle(f"Twin Stereo Diffusion β Epoch {epoch}", fontsize=10)
    plt.tight_layout()
    fname = os.path.join(save_dir, f'stereo_ep{epoch:03d}.png')
    plt.savefig(fname, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f" Samples saved: {fname}")
    print(f" Labels: {labels.cpu().tolist()}")
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def sample_stereo_refined(denoiser, fresnel, johanna, labels, device,
                          n_steps=50):
    """Two-pass refinement: use Fresnel to estimate R at inference.

    At each step:
      1. Johanna(x_t) → (U_j, S_j, Vt_j)
      2. Pass 1: Denoiser(S_j, t, labels) → S_pred (no R)
      3. Decode → x̂_0, encode it with Fresnel → U_f_est
      4. R_est = Procrustes(U_j, U_f_est)
      5. Pass 2: Denoiser(S_j, t, labels, R_est) → S_refined
      6. Decode S_refined through Fresnel's estimated basis → x_clean
      7. Flow step toward x_clean (the last step jumps to x_clean)

    Args:
        denoiser: StereoDenoiser. Run in eval mode; its previous
                  train/eval mode is restored before returning.
        fresnel, johanna: the frozen twin models.
        labels:   (B, 4) integer tile labels; B sets the number of samples.
        device:   device for the initial noise and the timesteps.
        n_steps:  number of sampling steps.

    Returns:
        (B, 3, 128, 128) generated images, in the normalized space of the
        training data. With ``n_steps == 0`` the initial noise is returned.
    """
    from geolip_svae.model import stitch_patches

    # Match sample_stereo (eval mode) without leaking the mode change.
    was_training = denoiser.training
    denoiser.eval()

    B = labels.shape[0]
    x = torch.randn(B, 3, 128, 128, device=device)
    ps = johanna.patch_size
    dt = 1.0 / n_steps if n_steps > 0 else 0.0

    for step in range(n_steps):
        t_val = 1.0 - step / n_steps
        t = torch.full((B,), t_val, device=device)

        j_out = johanna(x)
        S_j = j_out['svd']['S']
        gh = gw = int(math.sqrt(S_j.shape[1]))

        # Pass 1: prediction without Procrustes features.
        S_pred_1 = denoiser(S_j, t, labels, R_feat=None)
        dec_1 = johanna.decode_patches(
            j_out['svd']['U'], S_pred_1, j_out['svd']['Vt'])
        x_est = johanna.boundary_smooth(stitch_patches(dec_1, gh, gw, ps))

        # Estimate the clean basis by encoding the first-pass image.
        f_est = fresnel(x_est)
        _, R_feat = compute_procrustes_features(
            j_out['svd']['U'], f_est['svd']['U'])

        # Pass 2: prediction conditioned on the estimated Procrustes matrices.
        S_pred_2 = denoiser(S_j, t, labels, R_feat)
        dec_2 = fresnel.decode_patches(
            f_est['svd']['U'], S_pred_2, f_est['svd']['Vt'])
        x_clean = fresnel.boundary_smooth(stitch_patches(dec_2, gh, gw, ps))

        if step < n_steps - 1:
            # With x_t = (1 - t) x0 + t ε (see add_noise), dx_t/dt
            # = (x_t - x0) / t. Stepping from t to t - dt subtracts
            # dt * velocity, which moves x TOWARD x_clean. The previous code
            # used (x_clean - x) here and so stepped away from the estimate.
            velocity = (x - x_clean) / (t_val + 1e-4)
            x = x - dt * velocity
        else:
            x = x_clean

    denoiser.train(was_training)
    return x
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Script entry point: set the float32 matmul precision to 'high', then
    # run training with the hyperparameters spelled out explicitly.
    torch.set_float32_matmul_precision('high')

    run_config = {
        'epochs': 100,
        'batch_size': 64,
        'lr': 3e-4,
        'hidden': 256,
        'depth': 8,
        'n_heads': 8,
        'n_train': 50000,
    }
    train(**run_config)