"""
Twin Stereo Diffusion v2 — Omega-Space Flow Matching
======================================================
Pre-encode everything. Diffuse on the manifold. Decode once.

Training:
  1. Pre-encode all images through Fresnel → S_f (per image)
  2. Compute pooled basis: mean U_f, Vt_f across dataset (orthogonalized)
  3. Flow matching on omega tokens: noise S directly, predict clean S
  4. Denoiser lives entirely in omega space — no pixel-space ODE

Inference:
  1. Start from noise omega tokens (Gaussian matched to the per-dimension
     mean/std of the training omega tokens)
  2. ODE in omega space: S_t → predict S_clean → flow step on S
  3. Decode ONCE at the end: pooled basis (orthogonalized U_mean, Vt_mean)
     + predicted S → Fresnel decoder → pixels

No iterative encode/decode. No pixel-space accumulation.
The structural response IS the pooled spectral basis.
"""
|
|
| import os |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision |
| import torchvision.transforms as T |
| import numpy as np |
| from tqdm import tqdm |
|
|
def _login_to_hf_from_colab():
    """Best-effort Hugging Face Hub login for Google Colab sessions.

    Reads the ``HF_TOKEN`` Colab secret, exports it as the ``HF_TOKEN``
    environment variable and hands it to ``huggingface_hub.login``. Any
    ``Exception`` along the way (not running in Colab, secret unavailable,
    package missing, login failure) is ignored so the script still runs
    outside Colab.
    """
    try:
        from google.colab import userdata
        os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
        from huggingface_hub import login
        login(token=os.environ["HF_TOKEN"])
    except Exception:
        pass  # deliberate: outside Colab this path is expected to fail


_login_to_hf_from_colab()
|
|
|
|
| |
| |
| |
|
|
def load_fresnel(device='cuda'):
    """Load the pretrained Fresnel model (``v12_imagenet128``) and freeze it.

    Args:
        device: device the model is loaded onto (forwarded to ``load_model``).

    Returns:
        ``(model, cfg)`` as returned by ``geolip_svae.load_model``. The model
        is put in eval mode and all its parameters have ``requires_grad``
        disabled, so only the denoiser is trained.
    """
    from geolip_svae import load_model

    fresnel, cfg = load_model(hf_version='v12_imagenet128', device=device)
    fresnel.eval()
    fresnel.requires_grad_(False)  # same as clearing requires_grad on every parameter

    n_params = sum(param.numel() for param in fresnel.parameters())
    print(f" Fresnel-small: {n_params:,} params (frozen)")
    return fresnel, cfg
|
|
|
|
| |
| |
| |
|
|
# Per-channel RGB normalization statistics. Applied by TinyImageNet128's
# ``T.Normalize`` and inverted in ``sample_and_decode`` before plotting.
IMG_MEAN = (
    0.4802,  # R
    0.4481,  # G
    0.3975,  # B
)
IMG_STD = (
    0.2770,  # R
    0.2691,  # G
    0.2821,  # B
)
|
|
|
|
class TinyImageNet128(torch.utils.data.Dataset):
    """TinyImageNet (200 classes) upsampled from 64 to 128 pixels.

    Each sample is ``(image_tensor, label)``. The PIL image is converted to
    RGB if needed, then resized with bilinear interpolation
    (``T.Resize(128)``), converted to a tensor, and normalized with
    ``IMG_MEAN`` / ``IMG_STD``.
    """

    def __init__(self, split='train'):
        from datasets import load_dataset

        self.ds = load_dataset('zh-plus/tiny-imagenet', split=split)
        resize = T.Resize(128, interpolation=T.InterpolationMode.BILINEAR)
        normalize = T.Normalize(IMG_MEAN, IMG_STD)
        self.transform = T.Compose([resize, T.ToTensor(), normalize])

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        record = self.ds[idx]
        picture = record['image']
        rgb = picture if picture.mode == 'RGB' else picture.convert('RGB')
        return self.transform(rgb), record['label']
|
|
|
|
| |
| |
| |
|
|
def _nearest_orthonormal(M):
    """Return the polar factor ``U @ Vh`` of each matrix in ``M`` (thin SVD).

    This is the closest matrix with orthonormal columns (rows, for wide input)
    in Frobenius norm.
    """
    left, _, right_h = torch.linalg.svd(M, full_matrices=False)
    return left @ right_h


@torch.no_grad()
def pre_encode_with_basis(fresnel, dataset, device, batch_size=64, num_workers=4):
    """Encode an entire dataset through Fresnel and build a pooled SVD basis.

    Every image is passed through ``fresnel``. The per-patch singular values
    ``S`` are kept (moved to CPU). The singular-vector factors ``U`` / ``Vt``
    are summed in float64, averaged over the dataset, and projected back onto
    the nearest matrices with orthonormal columns/rows.

    Args:
        fresnel: frozen Fresnel model. ``fresnel(images)['svd']`` must provide
            ``'S'``, ``'U'`` and ``'Vt'`` tensors with a leading batch dim.
        dataset: map-style dataset yielding ``(image, label)`` pairs.
        device: device images are moved to before encoding.
        batch_size: encoding batch size.
        num_workers: DataLoader worker processes.

    Returns:
        omega:      (N, P, K) all S_f on CPU (e.g. (N, 64, 16))
        labels:     (N,) all labels
        U_pool:     orthogonalized mean U per patch (e.g. (64, 256, 16))
        Vt_pool:    orthogonalized mean Vt per patch (e.g. (64, 16, 16))
        omega_mean: (K,) mean singular-value profile over images and patches
        omega_std:  (K,) std singular-value profile over images and patches

    Raises:
        ValueError: if ``dataset`` yields no samples.
    """
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True)

    all_S, all_labels = [], []
    # Allocated from the first batch so the basis shape follows the encoder
    # output instead of being hard-coded.
    U_sum = Vt_sum = None
    count = 0

    print(f" Pre-encoding {len(dataset)} images through Fresnel...")
    for images, labs in tqdm(loader, desc="Encoding"):
        images = images.to(device)
        svd = fresnel(images)['svd']
        S, U, Vt = svd['S'], svd['U'], svd['Vt']

        all_S.append(S.cpu())
        all_labels.append(labs)

        # float64 keeps the running sum accurate over a large dataset.
        if U_sum is None:
            U_sum = torch.zeros(U.shape[1:], dtype=torch.float64, device=U.device)
            Vt_sum = torch.zeros(Vt.shape[1:], dtype=torch.float64, device=Vt.device)
        U_sum += U.double().sum(dim=0)
        Vt_sum += Vt.double().sum(dim=0)
        count += S.shape[0]

    if count == 0:
        raise ValueError("pre_encode_with_basis: dataset is empty")

    omega = torch.cat(all_S, dim=0)
    labels = torch.cat(all_labels, dim=0)

    U_mean = (U_sum / count).float()
    Vt_mean = (Vt_sum / count).float()

    # NOTE(review): singular vectors are only defined up to a joint sign flip
    # (u_k, v_k) -> (-u_k, -v_k). Unless the encoder fixes a sign convention,
    # averaging across images can partially cancel columns before this
    # orthogonalization. Worth confirming against geolip_svae.
    U_pool = _nearest_orthonormal(U_mean)
    Vt_pool = _nearest_orthonormal(Vt_mean)

    omega_mean = omega.mean(dim=(0, 1))
    omega_std = omega.std(dim=(0, 1))

    eye_u = torch.eye(U_pool.shape[-1], device=U_pool.device, dtype=U_pool.dtype)
    eye_v = torch.eye(Vt_pool.shape[-2], device=Vt_pool.device, dtype=Vt_pool.dtype)
    u_err = (U_pool.transpose(-2, -1) @ U_pool - eye_u).norm().item()
    v_err = (Vt_pool @ Vt_pool.transpose(-2, -1) - eye_v).norm().item()

    print(f" Encoded: {omega.shape}, {labels.shape}")
    print(f" Omega: mean={omega.mean():.3f} std={omega.std():.3f} "
          f"range=[{omega.min():.3f}, {omega.max():.3f}]")
    print(f" Pooled basis: U={U_pool.shape}, Vt={Vt_pool.shape}")
    print(f" Basis orthogonality check: ||U^T U - I|| = {u_err:.6f}")
    print(f" Basis orthogonality check: ||Vt Vt^T - I|| = {v_err:.6f}")

    return omega, labels, U_pool, Vt_pool, omega_mean, omega_std
|
|
|
|
class PreEncodedDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-computed omega tokens and their labels.

    Sample ``i`` is ``(omega[i], labels[i])``; the length is ``len(omega)``.
    """

    def __init__(self, omega, labels):
        self.omega, self.labels = omega, labels

    def __len__(self):
        return len(self.omega)

    def __getitem__(self, idx):
        sample = self.omega[idx]
        target = self.labels[idx]
        return sample, target
|
|
|
|
| |
| |
| |
|
|
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of one scalar per sample (e.g. flow time ``t``).

    With ``half = dim // 2`` geometric frequencies
    ``f_i = 10000 ** (-i / (half - 1))``, the output is
    ``[sin(t * f), cos(t * f)]``. When ``dim`` is odd, a zero column is
    appended, so the output width is always exactly ``dim``.

    Input:  t (B,)
    Output: (B, dim)
    """

    def __init__(self, dim):
        super().__init__()
        if dim < 2:
            raise ValueError(f"SinusoidalPosEmb needs dim >= 2, got {dim}")
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # Guard half == 1 (dim 2 or 3): log-spacing over (half - 1) intervals
        # would otherwise divide by zero.
        scale = math.log(10000) / max(half - 1, 1)
        freqs = torch.exp(torch.arange(half, device=t.device, dtype=torch.float) * -scale)
        args = t.unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([args.sin(), args.cos()], dim=1)
        if self.dim % 2:
            # Odd widths: pad so downstream Linear(dim, ...) layers get `dim` features.
            emb = F.pad(emb, (0, 1))
        return emb
|
|
|
|
class AdaLN(nn.Module):
    """Adaptive LayerNorm: a non-affine LayerNorm modulated by a condition.

    ``out = LN(x) * (1 + scale) + shift``, where ``(scale, shift)`` are the
    two halves of a linear projection of ``cond``. The projection is
    zero-initialized, so the module starts as a plain non-affine LayerNorm.

    Shapes as used in this file: x (B, L, dim), cond (B, cond_dim).
    """

    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.proj = nn.Linear(cond_dim, 2 * dim)
        for tensor in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(tensor)

    def forward(self, x, cond):
        # Insert a token axis so the modulation broadcasts over all tokens.
        modulation = self.proj(cond)[:, None]
        scale, shift = modulation.chunk(2, dim=-1)
        return self.norm(x) * (1 + scale) + shift
|
|
|
|
class OmegaBlock(nn.Module):
    """Pre-norm transformer block with AdaLN conditioning.

    x <- x + SelfAttention(AdaLN1(x, cond))
    x <- x + FeedForward(AdaLN2(x, cond))

    The feed-forward is Linear(dim, 4*dim) -> GELU -> Linear(4*dim, dim).
    Attention runs with ``batch_first=True``, i.e. x is (B, L, dim).
    """

    def __init__(self, dim, n_heads, cond_dim):
        super().__init__()
        # Construction order and attribute names match the checkpoint layout.
        self.adaln1 = AdaLN(dim, cond_dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.adaln2 = AdaLN(dim, cond_dim)
        ff_width = 4 * dim
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_width),
            nn.GELU(),
            nn.Linear(ff_width, dim),
        )

    def forward(self, x, cond):
        query = self.adaln1(x, cond)
        attended, _ = self.attn(query, query, query)
        x = x + attended
        ff_out = self.ff(self.adaln2(x, cond))
        return x + ff_out
|
|
|
|
class OmegaDenoiser(nn.Module):
    """Predict clean S_f from noised S_t. Lives entirely in omega space.

    Each token's singular values are projected to ``hidden`` and a learned
    positional embedding is added. The tokens then pass through ``depth``
    AdaLN-conditioned blocks, conditioned on ``time_emb(t) + class_emb(labels)``.
    The head is a residual correction added to ``S_t``. ``out_proj`` is
    zero-initialized, so an untrained model returns ``S_t`` unchanged.

    Input:  S_t    (B, n_patches, omega_dim) — noised omega tokens (64 x 16 by default)
            t      (B,) — flow time / noise level
            labels (B,) — integer class ids in [0, n_classes)

    Output: S_0 (B, n_patches, omega_dim) — predicted clean omega tokens
    """

    def __init__(self, n_patches=64, omega_dim=16, hidden=256,
                 depth=8, n_heads=8, n_classes=200):
        super().__init__()
        self.input_proj = nn.Linear(omega_dim, hidden)
        self.pos_emb = nn.Parameter(torch.randn(1, n_patches, hidden) * 0.02)

        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(hidden),
            nn.Linear(hidden, hidden), nn.GELU(),
            nn.Linear(hidden, hidden))

        self.class_emb = nn.Embedding(n_classes, hidden)

        # hidden must be divisible by n_heads (required by nn.MultiheadAttention).
        self.blocks = nn.ModuleList([
            OmegaBlock(hidden, n_heads, hidden) for _ in range(depth)])

        self.out_norm = nn.LayerNorm(hidden)
        self.out_proj = nn.Linear(hidden, omega_dim)
        # Zero-init: the residual head starts as the identity map S_t -> S_t.
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, S_t, t, labels):
        h = self.input_proj(S_t) + self.pos_emb
        cond = self.time_emb(t) + self.class_emb(labels)
        for block in self.blocks:
            h = block(h, cond)
        return S_t + self.out_proj(self.out_norm(h))
|
|
|
|
| |
| |
| |
|
|
def omega_flow_loss(model, S_clean, labels, omega_mean, omega_std, device):
    """Flow-matching loss computed entirely in omega space.

    Noise:  omega_mean + omega_std * N(0, I), i.e. a Gaussian matched to the
            per-dimension statistics of the data.
    Path:   S_t = (1 - t) * S_noise + t * S_clean, with t ~ U[0, 1) per sample.
    Target: x0-prediction — ``model(S_t, t, labels)`` is regressed onto
            ``S_clean`` with mean-squared error.

    Returns:
        Scalar MSE loss tensor.
    """
    batch = S_clean.shape[0]
    t = torch.rand(batch, device=device)

    mu = omega_mean.to(device)
    sigma = omega_std.to(device)
    S_noise = mu + sigma * torch.randn_like(S_clean)

    w = t[:, None, None]  # broadcast t over (patches, singular values)
    S_t = (1 - w) * S_noise + w * S_clean

    return F.mse_loss(model(S_t, t, labels), S_clean)
|
|
|
|
@torch.no_grad()
def sample_omega_ode(model, labels, omega_mean, omega_std,
                     n_steps=50, device='cuda', n_patches=64, omega_dim=16):
    """Euler ODE sampler in omega space. No pixel-space loop.

    Starts from ``omega_mean + omega_std * N(0, I)`` at ``t = 0``. Each step
    asks ``model`` for the clean tokens ``S_pred`` and follows the velocity of
    the straight flow-matching path ``S_t = (1 - t) S_noise + t S_clean``:
    ``v = (S_pred - S_t) / (1 - t)``. Because ``t <= (n_steps - 1) / n_steps``,
    the denominator is at least ``1 / n_steps``, so no epsilon is needed. As a
    result, the final step lands exactly on the last prediction. (An additive
    epsilon here made the last step undershoot, increasingly so for large
    ``n_steps``.)

    Args:
        model: callable ``model(S_t, t, labels) -> S_pred``.
        labels: (B,) class labels passed through to ``model``.
        omega_mean, omega_std: noise statistics, broadcastable to the token shape.
        n_steps: number of Euler steps (>= 1).
        device: device for the noise and time tensors.
        n_patches, omega_dim: token-grid shape (defaults match Fresnel: 64 x 16).

    Returns:
        (B, n_patches, omega_dim) sampled omega tokens.

    Raises:
        ValueError: if ``n_steps < 1``.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    B = labels.shape[0]
    S = omega_mean.to(device) + omega_std.to(device) * torch.randn(
        B, n_patches, omega_dim, device=device)

    dt = 1.0 / n_steps
    for step in range(n_steps):
        t_val = step / n_steps
        t = torch.full((B,), t_val, device=device)
        S_pred = model(S, t, labels)

        velocity = (S_pred - S) / (1.0 - t_val)
        S = S + dt * velocity

    return S
|
|
|
|
| |
| |
| |
|
|
def _train_one_epoch(denoiser, loader, opt, omega_mean, omega_std, device):
    """Run one optimization pass over ``loader``; return the per-sample mean loss."""
    denoiser.train()
    total_loss, n = 0.0, 0
    for omega, labels in loader:
        omega = omega.to(device)
        labels = labels.to(device)

        loss = omega_flow_loss(denoiser, omega, labels,
                               omega_mean, omega_std, device)
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(denoiser.parameters(), max_norm=1.0)
        opt.step()

        total_loss += loss.item() * len(omega)
        n += len(omega)
    return total_loss / n


@torch.no_grad()
def _evaluate(denoiser, loader, omega_mean, omega_std, device):
    """Return the per-sample mean flow loss over ``loader`` with the model in eval mode."""
    denoiser.eval()
    total_loss, n = 0.0, 0
    for omega, labels in loader:
        omega, labels = omega.to(device), labels.to(device)
        loss = omega_flow_loss(denoiser, omega, labels,
                               omega_mean, omega_std, device)
        total_loss += loss.item() * len(omega)
        n += len(omega)
    return total_loss / n


def train(epochs=100, batch_size=256, lr=3e-4, hidden=256, depth=8,
          n_heads=8, device='cuda', save_dir='/content/stereo_v2_checkpoints'):
    """Train the omega-space denoiser on pre-encoded TinyImageNet.

    Steps:
      1. Load the frozen Fresnel model.
      2. Pre-encode the train and validation splits. The pooled basis and the
         omega statistics come from the training split only.
      3. Run flow-matching training with AdamW and cosine LR decay.

    After every epoch the validation loss is computed. The best checkpoint is
    written to ``save_dir/best.pt``; it holds the denoiser weights, pooled
    basis, omega statistics and model config. A sample image strip is also
    saved each epoch.

    Args:
        epochs: number of epochs (also the cosine schedule length).
        batch_size: batch size for the train and validation loaders.
        lr: AdamW learning rate.
        hidden, depth, n_heads: OmegaDenoiser width, block count, heads.
        device: requested device; falls back to CPU without CUDA.
        save_dir: output directory for ``best.pt`` and sample images.

    Returns:
        The trained OmegaDenoiser (final-epoch weights, not necessarily best).

    Raises:
        ValueError: if the train loader yields no full batch or the
            validation loader is empty.
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    print("\n" + "=" * 70)
    print("TWIN STEREO v2 — Omega-Space Flow Matching")
    print("=" * 70)

    fresnel, _ = load_fresnel(device)

    print("\n Loading TinyImageNet...")
    train_ds = TinyImageNet128(split='train')
    val_ds = TinyImageNet128(split='valid')

    # Only the training split's basis/statistics are used; the validation pass
    # is kept for its omega tokens and labels.
    train_omega, train_labels, U_pool, Vt_pool, omega_mean, omega_std = \
        pre_encode_with_basis(fresnel, train_ds, device)
    val_omega, val_labels, _, _, _, _ = \
        pre_encode_with_basis(fresnel, val_ds, device)

    U_pool = U_pool.to(device)
    Vt_pool = Vt_pool.to(device)

    train_loader = torch.utils.data.DataLoader(
        PreEncodedDataset(train_omega, train_labels),
        batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        PreEncodedDataset(val_omega, val_labels),
        batch_size=batch_size, shuffle=False)
    # Without these guards an empty loader surfaces later as ZeroDivisionError.
    if len(train_loader) == 0:
        raise ValueError(
            f"training set ({len(train_omega)} samples) is smaller than "
            f"batch_size={batch_size}; drop_last=True leaves no batches")
    if len(val_loader) == 0:
        raise ValueError("validation set is empty")

    denoiser = OmegaDenoiser(
        n_patches=64, omega_dim=16, hidden=hidden,
        depth=depth, n_heads=n_heads, n_classes=200).to(device)

    n_params = sum(p.numel() for p in denoiser.parameters())
    print(f"\n OmegaDenoiser: {n_params:,} params")
    print(f" Hidden={hidden}, Depth={depth}, Heads={n_heads}")
    print(f" Training: {len(train_omega)} pre-encoded samples, batch={batch_size}")
    print(" Pure omega-space flow matching — no pixel ODE")
    print("=" * 70)

    opt = torch.optim.AdamW(denoiser.parameters(), lr=lr, weight_decay=0.01)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    os.makedirs(save_dir, exist_ok=True)
    best_val = float('inf')

    for epoch in range(1, epochs + 1):
        train_l = _train_one_epoch(denoiser, train_loader, opt,
                                   omega_mean, omega_std, device)
        sched.step()

        # omega_flow_loss draws fresh t and noise on every call, so the
        # validation loss (and best-checkpoint selection) is itself stochastic.
        val_l = _evaluate(denoiser, val_loader, omega_mean, omega_std, device)

        if val_l < best_val:
            best_val = val_l
            torch.save({
                'epoch': epoch, 'val_loss': val_l,
                'model_state_dict': denoiser.state_dict(),
                'U_pool': U_pool.cpu(),
                'Vt_pool': Vt_pool.cpu(),
                'omega_mean': omega_mean,
                'omega_std': omega_std,
                'config': {'hidden': hidden, 'depth': depth, 'n_heads': n_heads},
            }, os.path.join(save_dir, 'best.pt'))

        print(f" ep{epoch:3d} | loss={train_l:.6f} val={val_l:.6f} best={best_val:.6f}")

        sample_and_decode(denoiser, fresnel, U_pool, Vt_pool,
                          omega_mean, omega_std, device, epoch, save_dir)

    print(f"\n TRAINING COMPLETE — best val: {best_val:.6f}")
    return denoiser
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def sample_and_decode(denoiser, fresnel, U_pool, Vt_pool,
                      omega_mean, omega_std, device, epoch, save_dir,
                      n_samples=4, n_steps=50, n_classes=200):
    """Sample omega tokens via the ODE, decode once through Fresnel, save a PNG.

    Labels are drawn uniformly from ``[0, n_classes)``. The sampled singular
    values are combined with the pooled basis (broadcast over the batch),
    decoded patch-wise, stitched into a square grid, boundary-smoothed, and
    un-normalized with ``IMG_MEAN`` / ``IMG_STD``. The images are then saved
    as ``omega_v2_ep{epoch:03d}.png`` in ``save_dir``. The denoiser's
    train/eval mode is restored after sampling.

    Raises:
        ValueError: if ``n_samples < 1`` or the patch count is not a perfect square.
    """
    from geolip_svae.model import stitch_patches

    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    was_training = denoiser.training
    denoiser.eval()
    labels = torch.randint(0, n_classes, (n_samples,), device=device)

    # 1. Sample in omega space (no pixel-space loop).
    S_pred = sample_omega_ode(denoiser, labels, omega_mean, omega_std,
                              n_steps=n_steps, device=device)
    denoiser.train(was_training)

    # 2. Decode once with the shared pooled basis.
    B, N, _ = S_pred.shape
    U = U_pool.unsqueeze(0).expand(B, -1, -1, -1)
    Vt = Vt_pool.unsqueeze(0).expand(B, -1, -1, -1)
    decoded = fresnel.decode_patches(U, S_pred, Vt)

    # int(math.sqrt(N)) would silently mis-stitch a non-square patch grid.
    grid = math.isqrt(N)
    if grid * grid != N:
        raise ValueError(f"expected a square patch grid, got {N} patches")
    images = fresnel.boundary_smooth(
        stitch_patches(decoded, grid, grid, fresnel.patch_size))

    # 3. Undo the input normalization for display.
    mean = torch.tensor(IMG_MEAN, device=device).reshape(1, 3, 1, 1)
    std = torch.tensor(IMG_STD, device=device).reshape(1, 3, 1, 1)
    images = (images * std + mean).clamp(0, 1).cpu()

    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    # squeeze=False always yields a 2-D axes array, including n_samples == 1.
    fig, axes = plt.subplots(1, n_samples, figsize=(n_samples * 3, 3), squeeze=False)
    try:
        for ax, img, lab in zip(axes[0], images, labels.tolist()):
            ax.imshow(img.permute(1, 2, 0).numpy())
            ax.set_title(f"class {lab}", fontsize=8)
            ax.axis('off')
        plt.suptitle(f"Omega-Space Diffusion — Epoch {epoch}", fontsize=10)
        plt.tight_layout()
        fname = os.path.join(save_dir, f'omega_v2_ep{epoch:03d}.png')
        plt.savefig(fname, dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)  # never leak figures across epochs, even on error
    print(f" Samples: {fname} | labels={labels.cpu().tolist()}")
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Trade a little float32 matmul precision for speed on supported GPUs.
    torch.set_float32_matmul_precision('high')
    config = dict(
        epochs=100,
        batch_size=256,
        lr=3e-4,
        hidden=256,
        depth=8,
        n_heads=8,
    )
    train(**config)