""" Twin Stereo Diffusion v2 — Omega-Space Flow Matching ====================================================== Pre-encode everything. Diffuse on the manifold. Decode once. Training: 1. Pre-encode all images through Fresnel → S_f (per image) 2. Compute pooled basis: mean U_f, Vt_f across dataset (orthogonalized) 3. Flow matching on omega tokens: noise S directly, predict clean S 4. Denoiser lives entirely in omega space — no pixel-space ODE Inference: 1. Start from noise omega tokens (sampled from empirical noise distribution) 2. ODE in omega space: S_t → predict S_clean → flow step on S 3. Decode ONCE at the end: pooled basis (U_mean, Vt_mean) + predicted S → Fresnel decoder → pixels No iterative encode/decode. No pixel-space accumulation. The structural response IS the pooled spectral basis. """ import os import math import torch import torch.nn as nn import torch.nn.functional as F import torchvision import torchvision.transforms as T import numpy as np from tqdm import tqdm try: from google.colab import userdata os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN') from huggingface_hub import login login(token=os.environ["HF_TOKEN"]) except Exception: pass # ═══════════════════════════════════════════════════════════════ # FROZEN FRESNEL # ═══════════════════════════════════════════════════════════════ def load_fresnel(device='cuda'): from geolip_svae import load_model model, cfg = load_model(hf_version='v12_imagenet128', device=device) model.eval() for p in model.parameters(): p.requires_grad = False print(f" Fresnel-small: {sum(p.numel() for p in model.parameters()):,} params (frozen)") return model, cfg # ═══════════════════════════════════════════════════════════════ # DATASET # ═══════════════════════════════════════════════════════════════ IMG_MEAN = (0.4802, 0.4481, 0.3975) IMG_STD = (0.2770, 0.2691, 0.2821) class TinyImageNet128(torch.utils.data.Dataset): """TinyImageNet 200 classes, 64→128.""" def __init__(self, split='train'): from datasets import load_dataset self.ds = load_dataset('zh-plus/tiny-imagenet', split=split) self.transform = T.Compose([ T.Resize(128, interpolation=T.InterpolationMode.BILINEAR), T.ToTensor(), T.Normalize(IMG_MEAN, IMG_STD), ]) def __len__(self): return len(self.ds) def __getitem__(self, idx): item = self.ds[idx] img = item['image'] if img.mode != 'RGB': img = img.convert('RGB') return self.transform(img), item['label'] # ═══════════════════════════════════════════════════════════════ # PRE-ENCODE + POOLED BASIS # ═══════════════════════════════════════════════════════════════ @torch.no_grad() def pre_encode_with_basis(fresnel, dataset, device, batch_size=64): """Encode entire dataset, compute pooled orthogonal basis. 
@torch.no_grad()
def pre_encode_with_basis(fresnel, dataset, device, batch_size=64):
    """Encode the entire dataset and compute a pooled orthogonal basis.

    Returns:
        omega:      (N, 64, 16)   — all S_f
        labels:     (N,)          — all labels
        U_pool:     (64, 256, 16) — orthogonalized mean U per patch
        Vt_pool:    (64, 16, 16)  — orthogonalized mean Vt per patch
        omega_mean: (16,)         — mean singular value profile
        omega_std:  (16,)         — std singular value profile
    """
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    all_S, all_labels = [], []
    U_sum = torch.zeros(64, 256, 16, dtype=torch.float64, device=device)
    Vt_sum = torch.zeros(64, 16, 16, dtype=torch.float64, device=device)
    count = 0

    print(f"  Pre-encoding {len(dataset)} images through Fresnel...")
    for images, labs in tqdm(loader, desc="Encoding"):
        images = images.to(device)
        out = fresnel(images)
        S = out['svd']['S']    # (B, 64, 16)
        U = out['svd']['U']    # (B, 64, 256, 16)
        Vt = out['svd']['Vt']  # (B, 64, 16, 16)

        all_S.append(S.cpu())
        all_labels.append(labs)

        # Running sums for the pooled basis
        U_sum += U.double().sum(dim=0)    # (64, 256, 16)
        Vt_sum += Vt.double().sum(dim=0)  # (64, 16, 16)
        count += S.shape[0]

    omega = torch.cat(all_S, dim=0)
    labels = torch.cat(all_labels, dim=0)

    # ── Orthogonalize pooled basis via polar decomposition ──
    U_mean = (U_sum / count).float()    # (64, 256, 16)
    Vt_mean = (Vt_sum / count).float()  # (64, 16, 16)

    # Polar decomposition: nearest orthogonal matrix to the mean.
    # For U: SVD(U_mean) → U_orth @ Vt_orth gives the nearest orthogonal matrix.
    Uu, _, Uv = torch.linalg.svd(U_mean, full_matrices=False)
    U_pool = torch.bmm(Uu, Uv)   # (64, 256, 16)
    Vu, _, Vv = torch.linalg.svd(Vt_mean, full_matrices=False)
    Vt_pool = torch.bmm(Vu, Vv)  # (64, 16, 16)

    omega_mean = omega.mean(dim=(0, 1))
    omega_std = omega.std(dim=(0, 1))

    print(f"  Encoded: {omega.shape}, {labels.shape}")
    print(f"  Omega: mean={omega.mean():.3f} std={omega.std():.3f} "
          f"range=[{omega.min():.3f}, {omega.max():.3f}]")
    print(f"  Pooled basis: U={U_pool.shape}, Vt={Vt_pool.shape}")
    print(f"  Basis orthogonality check: ||U^T U - I|| = "
          f"{(torch.bmm(U_pool.transpose(-2, -1), U_pool) - torch.eye(16, device=device)).norm():.6f}")

    return omega, labels, U_pool, Vt_pool, omega_mean, omega_std


class PreEncodedDataset(torch.utils.data.Dataset):
    def __init__(self, omega, labels):
        self.omega = omega
        self.labels = labels

    def __len__(self):
        return len(self.omega)

    def __getitem__(self, idx):
        return self.omega[idx], self.labels[idx]


# ═══════════════════════════════════════════════════════════════
# DENOISER — PURE OMEGA SPACE
# ═══════════════════════════════════════════════════════════════

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        emb = math.log(10000) / (half - 1)
        emb = torch.exp(torch.arange(half, device=t.device, dtype=torch.float) * -emb)
        emb = t.unsqueeze(1) * emb.unsqueeze(0)
        return torch.cat([emb.sin(), emb.cos()], dim=1)


class AdaLN(nn.Module):
    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.proj = nn.Linear(cond_dim, dim * 2)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, cond):
        scale, shift = self.proj(cond).unsqueeze(1).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale) + shift
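# AdaLN zero-initializes its conditioning projection, so at initialization
# scale = shift = 0 and the modulation reduces to a plain LayerNorm; the
# conditioning signal is learned on top of that identity. A minimal check of
# this property (a sketch; `_adaln_starts_as_layernorm` is hypothetical and
# never called during training):
def _adaln_starts_as_layernorm():
    ada = AdaLN(dim=8, cond_dim=4)
    x, cond = torch.randn(2, 5, 8), torch.randn(2, 4)
    plain = nn.LayerNorm(8, elementwise_affine=False)(x)
    assert torch.allclose(ada(x, cond), plain, atol=1e-6)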
class OmegaBlock(nn.Module):
    def __init__(self, dim, n_heads, cond_dim):
        super().__init__()
        self.adaln1 = AdaLN(dim, cond_dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.adaln2 = AdaLN(dim, cond_dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x, cond):
        h = self.adaln1(x, cond)
        h, _ = self.attn(h, h, h)
        x = x + h
        return x + self.ff(self.adaln2(x, cond))


class OmegaDenoiser(nn.Module):
    """Predict clean S_f from noised S_t. Lives entirely in omega space.

    Input:
        S_t    (B, 64, 16) — noised omega tokens
        t      (B,)        — noise level
        labels (B,)        — class
    Output:
        S_0    (B, 64, 16) — predicted clean omega tokens
    """

    def __init__(self, n_patches=64, omega_dim=16, hidden=256, depth=8,
                 n_heads=8, n_classes=200):
        super().__init__()
        self.input_proj = nn.Linear(omega_dim, hidden)
        self.pos_emb = nn.Parameter(torch.randn(1, n_patches, hidden) * 0.02)
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(hidden),
            nn.Linear(hidden, hidden), nn.GELU(), nn.Linear(hidden, hidden))
        self.class_emb = nn.Embedding(n_classes, hidden)
        self.blocks = nn.ModuleList([
            OmegaBlock(hidden, n_heads, hidden) for _ in range(depth)])
        self.out_norm = nn.LayerNorm(hidden)
        self.out_proj = nn.Linear(hidden, omega_dim)
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, S_t, t, labels):
        h = self.input_proj(S_t) + self.pos_emb
        cond = self.time_emb(t) + self.class_emb(labels)
        for block in self.blocks:
            h = block(h, cond)
        # Residual prediction: with the zero-initialized output head, the
        # model starts as the identity map S_t → S_t.
        return S_t + self.out_proj(self.out_norm(h))


# ═══════════════════════════════════════════════════════════════
# FLOW MATCHING — OMEGA SPACE
# ═══════════════════════════════════════════════════════════════

def omega_flow_loss(model, S_clean, labels, omega_mean, omega_std, device):
    """Flow matching loss, entirely in omega space.

    Noise:  Gaussian in omega space, matched to the empirical distribution.
    Path:   S_t = (1 - t) * S_noise + t * S_clean
    Target: x₀-prediction (predict the clean singular values)
    """
    B = S_clean.shape[0]
    t = torch.rand(B, device=device)

    # Noise omega tokens drawn from the empirical distribution
    S_noise = omega_mean.to(device) + omega_std.to(device) * torch.randn_like(S_clean)

    # Interpolate along the straight-line path
    t_exp = t.view(B, 1, 1)
    S_t = (1 - t_exp) * S_noise + t_exp * S_clean

    # Predict clean tokens
    S_pred = model(S_t, t, labels)
    return F.mse_loss(S_pred, S_clean)
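# Although training uses x₀-prediction, sampling needs a velocity field. Along
# the straight path S_t = (1 - t) * S_noise + t * S_clean, the true velocity is
#   dS_t/dt = S_clean - S_noise = (S_clean - S_t) / (1 - t),
# so the predicted clean tokens convert directly into a flow step. A sketch of
# that conversion (a hypothetical helper; `sample_omega_ode` below inlines the
# same expression, with the same 1e-4 guard against division by zero as t → 1):
def _x0_to_velocity(S_pred, S_t, t_val, eps=1e-4):
    return (S_pred - S_t) / (1.0 - t_val + eps)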
@torch.no_grad()
def sample_omega_ode(model, labels, omega_mean, omega_std, n_steps=50, device='cuda'):
    """Euler ODE sampler in omega space. No pixel-space loop."""
    B = labels.shape[0]

    # Start from noise omega tokens
    S = omega_mean.to(device) + omega_std.to(device) * torch.randn(B, 64, 16, device=device)

    for step in range(n_steps):
        t_val = step / n_steps  # 0 → 1 (noise → clean)
        t = torch.full((B,), t_val, device=device)
        S_pred = model(S, t, labels)

        # Velocity toward the clean prediction
        dt = 1.0 / n_steps
        velocity = (S_pred - S) / (1.0 - t_val + 1e-4)
        S = S + dt * velocity

    return S


# ═══════════════════════════════════════════════════════════════
# TRAINING
# ═══════════════════════════════════════════════════════════════

def train(epochs=100, batch_size=256, lr=3e-4,
          hidden=256, depth=8, n_heads=8, device='cuda'):
    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    print("\n" + "=" * 70)
    print("TWIN STEREO v2 — Omega-Space Flow Matching")
    print("=" * 70)

    fresnel, f_cfg = load_fresnel(device)

    # ── Pre-encode ──
    print("\n  Loading TinyImageNet...")
    train_ds = TinyImageNet128(split='train')
    val_ds = TinyImageNet128(split='valid')

    train_omega, train_labels, U_pool, Vt_pool, omega_mean, omega_std = \
        pre_encode_with_basis(fresnel, train_ds, device)
    val_omega, val_labels, _, _, _, _ = \
        pre_encode_with_basis(fresnel, val_ds, device)

    # Move pooled basis to device
    U_pool = U_pool.to(device)
    Vt_pool = Vt_pool.to(device)

    # ── Dataloaders over pre-encoded tokens ──
    train_loader = torch.utils.data.DataLoader(
        PreEncodedDataset(train_omega, train_labels),
        batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        PreEncodedDataset(val_omega, val_labels),
        batch_size=batch_size, shuffle=False)

    # ── Denoiser ──
    denoiser = OmegaDenoiser(
        n_patches=64, omega_dim=16, hidden=hidden,
        depth=depth, n_heads=n_heads, n_classes=200).to(device)
    n_params = sum(p.numel() for p in denoiser.parameters())

    print(f"\n  OmegaDenoiser: {n_params:,} params")
    print(f"  Hidden={hidden}, Depth={depth}, Heads={n_heads}")
    print(f"  Training: {len(train_omega)} pre-encoded samples, batch={batch_size}")
    print("  Pure omega-space flow matching — no pixel ODE")
    print("=" * 70)

    opt = torch.optim.AdamW(denoiser.parameters(), lr=lr, weight_decay=0.01)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    save_dir = '/content/stereo_v2_checkpoints'
    os.makedirs(save_dir, exist_ok=True)
    best_val = float('inf')

    for epoch in range(1, epochs + 1):
        denoiser.train()
        total_loss, n = 0, 0
        for omega, labels in train_loader:
            omega = omega.to(device)
            labels = labels.to(device)
            loss = omega_flow_loss(denoiser, omega, labels,
                                   omega_mean, omega_std, device)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(denoiser.parameters(), max_norm=1.0)
            opt.step()
            total_loss += loss.item() * len(omega)
            n += len(omega)
        sched.step()

        # ── Validation ──
        denoiser.eval()
        val_loss, val_n = 0, 0
        with torch.no_grad():
            for omega, labels in val_loader:
                omega, labels = omega.to(device), labels.to(device)
                loss = omega_flow_loss(denoiser, omega, labels,
                                       omega_mean, omega_std, device)
                val_loss += loss.item() * len(omega)
                val_n += len(omega)

        train_l = total_loss / n
        val_l = val_loss / val_n

        if val_l < best_val:
            best_val = val_l
            torch.save({
                'epoch': epoch,
                'val_loss': val_l,
                'model_state_dict': denoiser.state_dict(),
                'U_pool': U_pool.cpu(),
                'Vt_pool': Vt_pool.cpu(),
                'omega_mean': omega_mean,
                'omega_std': omega_std,
                'config': {'hidden': hidden, 'depth': depth, 'n_heads': n_heads},
            }, os.path.join(save_dir, 'best.pt'))

        print(f"  ep{epoch:3d} | loss={train_l:.6f} val={val_l:.6f} best={best_val:.6f}")

        # ── Sample every epoch ──
        sample_and_decode(denoiser, fresnel, U_pool, Vt_pool,
                          omega_mean, omega_std, device, epoch, save_dir)

    print(f"\n  TRAINING COMPLETE — best val: {best_val:.6f}")
    return denoiser
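# At decode time each patch is rebuilt from the shared pooled basis and its
# predicted singular values, i.e. the rank-16 recomposition M = U @ diag(S) @ Vt.
# A minimal sketch of that recomposition (illustrative; it assumes
# `fresnel.decode_patches` recomposes U Σ Vᵀ internally before its pixel
# decoder, which is an assumption about geolip_svae, not a documented API):
def _recompose_svd(U, S, Vt):
    """(..., 256, 16) × (..., 16) × (..., 16, 16) → (..., 256, 16)."""
    return torch.einsum('...ik,...k,...kj->...ij', U, S, Vt)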
# ═══════════════════════════════════════════════════════════════
# SAMPLING + DECODE
# ═══════════════════════════════════════════════════════════════

@torch.no_grad()
def sample_and_decode(denoiser, fresnel, U_pool, Vt_pool, omega_mean, omega_std,
                      device, epoch, save_dir, n_samples=4, n_steps=50):
    """Sample omega tokens via the ODE, then decode once through Fresnel."""
    from geolip_svae.model import stitch_patches

    denoiser.eval()
    labels = torch.randint(0, 200, (n_samples,), device=device)

    # ── ODE in omega space ──
    S_pred = sample_omega_ode(denoiser, labels, omega_mean, omega_std,
                              n_steps=n_steps, device=device)

    # ── Decode ONCE through Fresnel with the pooled basis ──
    B, N, D = S_pred.shape
    U = U_pool.unsqueeze(0).expand(B, -1, -1, -1)    # (B, 64, 256, 16)
    Vt = Vt_pool.unsqueeze(0).expand(B, -1, -1, -1)  # (B, 64, 16, 16)
    decoded = fresnel.decode_patches(U, S_pred, Vt)
    ps = fresnel.patch_size
    gh = gw = int(math.sqrt(N))
    images = fresnel.boundary_smooth(stitch_patches(decoded, gh, gw, ps))

    # TODO: also decode a real training example for comparison — encode a real
    # image, take its actual S, decode with the pooled basis. This would test
    # whether the pooled basis alone reconstructs well.

    # ── Denormalize ──
    mean = torch.tensor(IMG_MEAN).reshape(1, 3, 1, 1).to(device)
    std = torch.tensor(IMG_STD).reshape(1, 3, 1, 1).to(device)
    images = (images * std + mean).clamp(0, 1).cpu()

    # ── Plot ──
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, n_samples, figsize=(n_samples * 3, 3))
    if n_samples == 1:
        axes = [axes]
    for i in range(n_samples):
        axes[i].imshow(images[i].permute(1, 2, 0).numpy())
        axes[i].set_title(f"class {labels[i].item()}", fontsize=8)
        axes[i].axis('off')
    plt.suptitle(f"Omega-Space Diffusion — Epoch {epoch}", fontsize=10)
    plt.tight_layout()
    fname = os.path.join(save_dir, f'omega_v2_ep{epoch:03d}.png')
    plt.savefig(fname, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"  Samples: {fname} | labels={labels.cpu().tolist()}")


# ═══════════════════════════════════════════════════════════════
# CLI
# ═══════════════════════════════════════════════════════════════

if __name__ == "__main__":
    torch.set_float32_matmul_precision('high')
    train(
        epochs=100,
        batch_size=256,  # pure omega space — no VAE per batch
        lr=3e-4,
        hidden=256,
        depth=8,
        n_heads=8,
    )
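# To sample from a saved checkpoint without retraining, a sketch (hypothetical;
# it only mirrors the fields written by the torch.save(...) call in train()):
#
#   ckpt = torch.load('/content/stereo_v2_checkpoints/best.pt', map_location='cuda')
#   cfg = ckpt['config']
#   denoiser = OmegaDenoiser(hidden=cfg['hidden'], depth=cfg['depth'],
#                            n_heads=cfg['n_heads']).cuda()
#   denoiser.load_state_dict(ckpt['model_state_dict'])
#   labels = torch.randint(0, 200, (4,), device='cuda')
#   S = sample_omega_ode(denoiser, labels, ckpt['omega_mean'], ckpt['omega_std'])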