# geolip-omega-diffusion-128 / prototype2_trainer.py
# AbstractPhil — "Update prototype2_trainer.py" (commit f6fae3f, verified)
"""
Twin Stereo Diffusion v2 — Omega-Space Flow Matching
======================================================
Pre-encode everything. Diffuse on the manifold. Decode once.
Training:
1. Pre-encode all images through Fresnel → S_f (per image)
2. Compute pooled basis: mean U_f, Vt_f across dataset (orthogonalized)
3. Flow matching on omega tokens: noise S directly, predict clean S
4. Denoiser lives entirely in omega space — no pixel-space ODE
Inference:
1. Start from noise omega tokens (Gaussian with the empirical per-index mean/std of S)
2. ODE in omega space: S_t → predict S_clean → flow step on S
3. Decode ONCE at the end: pooled basis (U_mean, Vt_mean) + predicted S → Fresnel decoder → pixels
No iterative encode/decode. No pixel-space accumulation.
The structural response IS the pooled spectral basis.
"""
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
import numpy as np
from tqdm import tqdm
def _hf_login_from_colab():
    """Best-effort Hugging Face Hub login using the Colab secret ``HF_TOKEN``.

    Copies the secret into ``os.environ["HF_TOKEN"]`` and calls
    ``huggingface_hub.login`` with it. Any failure (not running in Colab,
    missing secret, login error) is swallowed on purpose so the script
    continues without authenticating.
    """
    try:
        from google.colab import userdata
        os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
        from huggingface_hub import login
        login(token=os.environ["HF_TOKEN"])
    except Exception:
        pass  # deliberate: authentication is optional


_hf_login_from_colab()
# ═══════════════════════════════════════════════════════════════
# FROZEN FRESNEL
# ═══════════════════════════════════════════════════════════════
def load_fresnel(device='cuda'):
    """Load the pretrained Fresnel model (``v12_imagenet128``) and freeze it.

    Args:
        device: passed straight through to ``geolip_svae.load_model``.

    Returns:
        ``(model, cfg)`` as returned by ``load_model``. The model is switched
        to eval mode and every parameter has ``requires_grad = False``.
    """
    from geolip_svae import load_model

    fresnel, cfg = load_model(hf_version='v12_imagenet128', device=device)
    fresnel.eval()
    params = list(fresnel.parameters())
    for param in params:
        param.requires_grad = False
    total = sum(param.numel() for param in params)
    print(f" Fresnel-small: {total:,} params (frozen)")
    return fresnel, cfg
# ═══════════════════════════════════════════════════════════════
# DATASET
# ═══════════════════════════════════════════════════════════════
# Per-channel RGB normalization statistics: applied by the dataset transform
# below and inverted when decoded samples are plotted.
IMG_MEAN = (0.4802, 0.4481, 0.3975)
IMG_STD = (0.2770, 0.2691, 0.2821)


class TinyImageNet128(torch.utils.data.Dataset):
    """TinyImageNet (200 classes) from ``zh-plus/tiny-imagenet``, upsampled 64→128.

    Each item is ``(normalized_image_tensor, label)``. Images are resized
    (shorter side → 128, bilinear), converted to tensors and normalized with
    ``IMG_MEAN`` / ``IMG_STD``; non-RGB images are converted to RGB first.
    """

    def __init__(self, split='train'):
        from datasets import load_dataset

        self.ds = load_dataset('zh-plus/tiny-imagenet', split=split)
        steps = [
            T.Resize(128, interpolation=T.InterpolationMode.BILINEAR),
            T.ToTensor(),
            T.Normalize(IMG_MEAN, IMG_STD),
        ]
        self.transform = T.Compose(steps)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        record = self.ds[idx]
        picture = record['image']
        picture = picture if picture.mode == 'RGB' else picture.convert('RGB')
        return self.transform(picture), record['label']
# ═══════════════════════════════════════════════════════════════
# PRE-ENCODE + POOLED BASIS
# ═══════════════════════════════════════════════════════════════
def _nearest_orthonormal(M):
    """Polar factor of each matrix in a batch.

    For ``M = W diag(s) Z^T`` (thin SVD) this returns ``W Z^T``, the closest
    matrix to ``M`` in Frobenius norm whose columns are orthonormal.
    """
    W, _, Zt = torch.linalg.svd(M, full_matrices=False)
    return W @ Zt


@torch.no_grad()
def pre_encode_with_basis(fresnel, dataset, device, batch_size=64, num_workers=4):
    """Encode an entire dataset through Fresnel and build a pooled orthogonal basis.

    Every image is passed through ``fresnel``. Its per-patch singular values
    ``S`` are kept (on CPU), while ``U`` and ``Vt`` are only accumulated (in
    float64) into a running sum. The mean ``U`` / ``Vt`` is then projected
    back onto matrices with orthonormal columns via a polar decomposition.

    NOTE(review): SVD singular vectors are only defined up to a per-column
    sign. Unless the encoder canonicalises signs, a plain average can
    partially cancel — worth confirming for the pooled basis.

    Args:
        fresnel: frozen encoder; ``fresnel(images)['svd']`` must provide
            ``'S'``, ``'U'`` and ``'Vt'`` tensors with a leading batch dim.
        dataset: non-empty map-style dataset yielding ``(image, label)``.
        device: device images are moved to before encoding.
        batch_size: encoding batch size.
        num_workers: DataLoader worker processes.

    Returns:
        omega:      (N, 64, 16) — all S_f (CPU)
        labels:     (N,)        — all labels
        U_pool:     (64, 256, 16) — orthogonalized mean U per patch
        Vt_pool:    (64, 16, 16)  — orthogonalized mean Vt per patch
        omega_mean: (16,) — mean singular value profile
        omega_std:  (16,) — std singular value profile
        (Shapes shown are for Fresnel-small; the code follows whatever
        shapes the encoder returns.)

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("pre_encode_with_basis: dataset is empty")

    # Pinned host memory only helps host→CUDA copies.
    pin = torch.device(device).type == 'cuda'
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin)

    all_S, all_labels = [], []
    # Accumulators are created from the first batch so their shape follows
    # the encoder output; float64 keeps the long running sum accurate.
    U_sum = Vt_sum = None
    count = 0
    print(f" Pre-encoding {len(dataset)} images through Fresnel...")
    for images, labs in tqdm(loader, desc="Encoding"):
        images = images.to(device)
        svd = fresnel(images)['svd']
        S, U, Vt = svd['S'], svd['U'], svd['Vt']
        all_S.append(S.cpu())
        all_labels.append(labs)

        U_batch = U.double().sum(dim=0)
        Vt_batch = Vt.double().sum(dim=0)
        if U_sum is None:
            U_sum, Vt_sum = U_batch, Vt_batch
        else:
            U_sum += U_batch
            Vt_sum += Vt_batch
        count += S.shape[0]

    omega = torch.cat(all_S, dim=0)
    labels = torch.cat(all_labels, dim=0)

    # ── Orthogonalize pooled basis via polar decomposition ──
    U_mean = (U_sum / count).float()
    Vt_mean = (Vt_sum / count).float()
    U_pool = _nearest_orthonormal(U_mean)
    Vt_pool = _nearest_orthonormal(Vt_mean)

    omega_mean = omega.mean(dim=(0, 1))
    omega_std = omega.std(dim=(0, 1))

    eye_u = torch.eye(U_pool.shape[-1], device=U_pool.device, dtype=U_pool.dtype)
    eye_v = torch.eye(Vt_pool.shape[-2], device=Vt_pool.device, dtype=Vt_pool.dtype)
    u_err = (U_pool.transpose(-2, -1) @ U_pool - eye_u).norm()
    v_err = (Vt_pool @ Vt_pool.transpose(-2, -1) - eye_v).norm()

    print(f" Encoded: {omega.shape}, {labels.shape}")
    print(f" Omega: mean={omega.mean():.3f} std={omega.std():.3f} "
          f"range=[{omega.min():.3f}, {omega.max():.3f}]")
    print(f" Pooled basis: U={U_pool.shape}, Vt={Vt_pool.shape}")
    print(f" Basis orthogonality check: ||U^T U - I|| = {u_err:.6f} "
          f"||Vt Vt^T - I|| = {v_err:.6f}")
    return omega, labels, U_pool, Vt_pool, omega_mean, omega_std
class PreEncodedDataset(torch.utils.data.Dataset):
    """Map-style dataset over cached omega tokens and their labels.

    ``dataset[i]`` is ``(omega[i], labels[i])``; the length is ``len(omega)``.
    """

    def __init__(self, omega, labels):
        self.omega, self.labels = omega, labels

    def __len__(self):
        return len(self.omega)

    def __getitem__(self, idx):
        tokens = self.omega[idx]
        target = self.labels[idx]
        return tokens, target
# ═══════════════════════════════════════════════════════════════
# DENOISER β€” PURE OMEGA SPACE
# ═══════════════════════════════════════════════════════════════
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a batch of scalars (here: the diffusion time).

    With ``half = dim // 2`` geometric frequencies ``f_k = 10000^(-k/(half-1))``,
    the output for ``t`` of shape ``(B,)`` is
    ``[sin(t * f_0..f_{half-1}), cos(t * f_0..f_{half-1})]`` of shape
    ``(B, dim)``. An odd ``dim`` gets one trailing zero column so the width
    always equals ``dim``. For ``dim == 2`` the single frequency is 1.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when half == 1 (dim 2 or 3).
        scale = math.log(10000) / max(half - 1, 1)
        freqs = torch.exp(
            torch.arange(half, device=t.device, dtype=torch.float) * -scale)
        args = t.unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([args.sin(), args.cos()], dim=1)
        if self.dim % 2:
            # Pad so downstream Linear(dim, ...) layers get exactly `dim` features.
            emb = F.pad(emb, (0, 1))
        return emb
class AdaLN(nn.Module):
    """Adaptive LayerNorm: ``LN(x) * (1 + scale) + shift``.

    ``scale`` and ``shift`` are produced from the conditioning vector by a
    single linear layer that is zero-initialised, so the module starts out as
    a plain (affine-free) LayerNorm.
    """

    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.proj = nn.Linear(cond_dim, 2 * dim)
        for tensor in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(tensor)

    def forward(self, x, cond):
        # cond is expected to be (B, cond_dim); the inserted axis lets the
        # modulation broadcast over the token dimension of x.
        scale, shift = self.proj(cond).unsqueeze(1).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale) + shift
class OmegaBlock(nn.Module):
    """Pre-norm transformer block conditioned through AdaLN.

    ``x ← x + MHA(AdaLN1(x, c))`` then ``x ← x + FF(AdaLN2(x, c))``, where FF
    is a 4×-expansion GELU MLP and attention is batch-first self-attention.
    """

    def __init__(self, dim, n_heads, cond_dim):
        super().__init__()
        # Submodules are created in this fixed order; it determines parameter
        # initialisation under a given RNG seed.
        self.adaln1 = AdaLN(dim, cond_dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.adaln2 = AdaLN(dim, cond_dim)
        expanded = 4 * dim
        self.ff = nn.Sequential(
            nn.Linear(dim, expanded),
            nn.GELU(),
            nn.Linear(expanded, dim),
        )

    def forward(self, x, cond):
        q = self.adaln1(x, cond)
        attended = self.attn(q, q, q)[0]
        x = x + attended
        ff_out = self.ff(self.adaln2(x, cond))
        return x + ff_out
class OmegaDenoiser(nn.Module):
    """Predict clean S_f from noised S_t. Lives entirely in omega space.

    Input:  S_t    (B, n_patches, omega_dim) — noised omega tokens
            t      (B,) — interpolation time (0 = noise, 1 = clean in the
                   flow path used by this trainer)
            labels (B,) — integer class ids in [0, n_classes)
    Output: S_0    (B, n_patches, omega_dim) — predicted clean omega tokens

    The network predicts a residual on top of S_t. ``out_proj`` is
    zero-initialised, so an untrained model returns S_t unchanged.

    NOTE(review): t in [0, 1] is fed to SinusoidalPosEmb, whose frequencies
    are all <= 1, so most channels barely vary over the time range. Scaling t
    (e.g. by 1000) is common, but it is left unchanged here to keep existing
    checkpoints compatible.
    """

    def __init__(self, n_patches=64, omega_dim=16, hidden=256,
                 depth=8, n_heads=8, n_classes=200):
        super().__init__()
        self.input_proj = nn.Linear(omega_dim, hidden)
        # Learned per-patch position embedding (fixes the token count to n_patches).
        self.pos_emb = nn.Parameter(torch.randn(1, n_patches, hidden) * 0.02)
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(hidden),
            nn.Linear(hidden, hidden), nn.GELU(),
            nn.Linear(hidden, hidden))
        self.class_emb = nn.Embedding(n_classes, hidden)
        self.blocks = nn.ModuleList([
            OmegaBlock(hidden, n_heads, hidden) for _ in range(depth)])
        self.out_norm = nn.LayerNorm(hidden)
        self.out_proj = nn.Linear(hidden, omega_dim)
        # Zero init → the residual prediction starts at exactly S_t.
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, S_t, t, labels):
        h = self.input_proj(S_t) + self.pos_emb
        # Time and class are summed into a single conditioning vector per sample.
        cond = self.time_emb(t) + self.class_emb(labels)
        for block in self.blocks:
            h = block(h, cond)
        return S_t + self.out_proj(self.out_norm(h))
# ═══════════════════════════════════════════════════════════════
# FLOW MATCHING β€” OMEGA SPACE
# ═══════════════════════════════════════════════════════════════
def omega_flow_loss(model, S_clean, labels, omega_mean, omega_std, device):
    """Flow-matching loss (x0-prediction) computed entirely in omega space.

    Per sample, ``t ~ U[0, 1)`` is drawn. Gaussian noise uses the empirical
    per-index ``omega_mean`` / ``omega_std``. The model sees the straight-line
    interpolant ``S_t = (1 - t) * S_noise + t * S_clean`` and is scored by
    MSE against ``S_clean``.

    Returns:
        Scalar MSE loss tensor.
    """
    batch = S_clean.shape[0]
    # Random draws happen in this order: timesteps first, then noise.
    t = torch.rand(batch, device=device)
    noise_mu = omega_mean.to(device)
    noise_sigma = omega_std.to(device)
    S_noise = noise_mu + noise_sigma * torch.randn_like(S_clean)

    w = t[:, None, None]
    S_t = (1 - w) * S_noise + w * S_clean
    prediction = model(S_t, t, labels)
    return F.mse_loss(prediction, S_clean)
@torch.no_grad()
def sample_omega_ode(model, labels, omega_mean, omega_std,
                     n_steps=50, device='cuda', n_patches=64, omega_dim=None):
    """Euler ODE sampler in omega space (noise at t=0 → clean at t=1).

    Starts from Gaussian noise with per-index ``omega_mean`` / ``omega_std``.
    It then integrates the velocity implied by the model's clean prediction,
    ``v = (S_pred - S) / (1 - t)``, on the grid ``t_k = k / n_steps``. The
    last grid point is ``1 - dt``, so the final step lands (up to rounding)
    on the model's last clean prediction. No pixel-space work happens here.

    Args:
        model: callable ``model(S, t, labels) -> S_pred`` (x0-prediction).
        labels: (B,) class ids on ``device``.
        omega_mean, omega_std: noise statistics broadcastable to
            ``(B, n_patches, omega_dim)``.
        n_steps: number of Euler steps; 0 returns the initial noise.
        device: device for the sampled tensors.
        n_patches: number of omega tokens per sample.
        omega_dim: token width; defaults to ``omega_mean.shape[-1]``, or 16
            when ``omega_mean`` is a scalar.

    Returns:
        (B, n_patches, omega_dim) sampled omega tokens.
    """
    B = labels.shape[0]
    if omega_dim is None:
        omega_dim = omega_mean.shape[-1] if omega_mean.dim() > 0 else 16

    mu = omega_mean.to(device)
    sigma = omega_std.to(device)
    S = mu + sigma * torch.randn(B, n_patches, omega_dim, device=device)

    dt = 1.0 / max(n_steps, 1)
    for step in range(n_steps):
        t_val = step / n_steps  # <= 1 - dt, so the denominator below is >= dt > 0
        t = torch.full((B,), t_val, device=device)
        S_pred = model(S, t, labels)
        # Exact velocity of the straight noise→clean path. No epsilon is needed
        # (t_val never reaches 1), and an epsilon would keep the last step from
        # reaching S_pred.
        velocity = (S_pred - S) / (1.0 - t_val)
        S = S + dt * velocity
    return S
# ═══════════════════════════════════════════════════════════════
# TRAINING
# ═══════════════════════════════════════════════════════════════
def train(epochs=100, batch_size=256, lr=3e-4, hidden=256, depth=8,
          n_heads=8, device='cuda', val_seed=0):
    """Pre-encode TinyImageNet through frozen Fresnel, then train OmegaDenoiser.

    Pipeline:
      1. Load Fresnel.
      2. Pre-encode the train and valid splits. The pooled basis and omega
         statistics come from the train split only.
      3. Train with AdamW and a cosine LR schedule using flow matching on the
         cached omega tokens.
      4. Each epoch: validate, save ``best.pt`` on improvement, and write a
         decoded sample grid.

    Args:
        epochs, batch_size, lr: optimisation settings (cosine over ``epochs``).
        hidden, depth, n_heads: OmegaDenoiser width / depth / heads.
        device: requested device; falls back to CPU when CUDA is unavailable.
        val_seed: seed for the validation timesteps/noise, so validation loss
            (and thus best-checkpoint selection) is comparable across epochs.
            ``None`` draws fresh randomness every epoch (previous behaviour).

    Returns:
        The OmegaDenoiser after the last epoch (not necessarily the best one).

    Raises:
        ValueError: if ``batch_size`` exceeds the training set size (with
            ``drop_last=True`` there would be no batches).
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    print("\n" + "=" * 70)
    print("TWIN STEREO v2 β€” Omega-Space Flow Matching")
    print("=" * 70)
    fresnel, _ = load_fresnel(device)

    # ── Pre-encode ──
    print("\n Loading TinyImageNet...")
    train_ds = TinyImageNet128(split='train')
    val_ds = TinyImageNet128(split='valid')
    train_omega, train_labels, U_pool, Vt_pool, omega_mean, omega_std = \
        pre_encode_with_basis(fresnel, train_ds, device)
    # Only the validation tokens are kept; basis and stats come from train.
    val_omega, val_labels, _, _, _, _ = \
        pre_encode_with_basis(fresnel, val_ds, device)

    # Move the pooled basis and noise statistics to the device once, instead
    # of copying the statistics host→device on every batch.
    U_pool = U_pool.to(device)
    Vt_pool = Vt_pool.to(device)
    omega_mean = omega_mean.to(device)
    omega_std = omega_std.to(device)

    # ── Dataloaders on pre-encoded tokens ──
    train_loader = torch.utils.data.DataLoader(
        PreEncodedDataset(train_omega, train_labels),
        batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        PreEncodedDataset(val_omega, val_labels),
        batch_size=batch_size, shuffle=False)
    if len(train_loader) == 0:
        raise ValueError(
            f"batch_size={batch_size} exceeds the {len(train_omega)} training "
            "samples (drop_last=True leaves no batches)")

    # ── Denoiser ── (token count / width follow the encoded tokens)
    n_patches, omega_dim = train_omega.shape[1], train_omega.shape[2]
    denoiser = OmegaDenoiser(
        n_patches=n_patches, omega_dim=omega_dim, hidden=hidden,
        depth=depth, n_heads=n_heads, n_classes=200).to(device)
    n_params = sum(p.numel() for p in denoiser.parameters())
    print(f"\n OmegaDenoiser: {n_params:,} params")
    print(f" Hidden={hidden}, Depth={depth}, Heads={n_heads}")
    print(f" Training: {len(train_omega)} pre-encoded samples, batch={batch_size}")
    print(f" Pure omega-space flow matching β€” no pixel ODE")
    print("=" * 70)

    opt = torch.optim.AdamW(denoiser.parameters(), lr=lr, weight_decay=0.01)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    save_dir = '/content/stereo_v2_checkpoints'
    os.makedirs(save_dir, exist_ok=True)

    # The forked RNG must include the CUDA generator used for training;
    # otherwise re-seeding for validation would reset the training stream.
    if device.type == 'cuda':
        rng_devices = [device.index if device.index is not None
                       else torch.cuda.current_device()]
    else:
        rng_devices = []

    best_val = float('inf')
    for epoch in range(1, epochs + 1):
        denoiser.train()
        total_loss, n = 0.0, 0
        for omega, labels in train_loader:
            omega = omega.to(device)
            labels = labels.to(device)
            loss = omega_flow_loss(denoiser, omega, labels,
                                   omega_mean, omega_std, device)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(denoiser.parameters(), max_norm=1.0)
            opt.step()
            total_loss += loss.item() * len(omega)
            n += len(omega)
        sched.step()

        # ── Validation ──
        # Seeded inside a forked RNG: every epoch is scored on the same
        # timesteps/noise, and the training random stream is restored afterwards.
        denoiser.eval()
        val_loss, val_n = 0.0, 0
        with torch.no_grad(), torch.random.fork_rng(
                devices=rng_devices, enabled=val_seed is not None):
            if val_seed is not None:
                torch.manual_seed(val_seed)
            for omega, labels in val_loader:
                omega, labels = omega.to(device), labels.to(device)
                loss = omega_flow_loss(denoiser, omega, labels,
                                       omega_mean, omega_std, device)
                val_loss += loss.item() * len(omega)
                val_n += len(omega)

        train_l = total_loss / n
        val_l = val_loss / val_n
        if val_l < best_val:
            best_val = val_l
            torch.save({
                'epoch': epoch, 'val_loss': val_l,
                'model_state_dict': denoiser.state_dict(),
                'U_pool': U_pool.cpu(),
                'Vt_pool': Vt_pool.cpu(),
                'omega_mean': omega_mean.cpu(),
                'omega_std': omega_std.cpu(),
                'config': {'hidden': hidden, 'depth': depth, 'n_heads': n_heads},
            }, os.path.join(save_dir, 'best.pt'))
        print(f" ep{epoch:3d} | loss={train_l:.6f} val={val_l:.6f} best={best_val:.6f}")

        # ── Sample every epoch ──
        sample_and_decode(denoiser, fresnel, U_pool, Vt_pool,
                          omega_mean, omega_std, device, epoch, save_dir)

    print(f"\n TRAINING COMPLETE β€” best val: {best_val:.6f}")
    return denoiser
# ═══════════════════════════════════════════════════════════════
# SAMPLING + DECODE
# ═══════════════════════════════════════════════════════════════
@torch.no_grad()
def sample_and_decode(denoiser, fresnel, U_pool, Vt_pool,
                      omega_mean, omega_std, device, epoch, save_dir,
                      n_samples=4, n_steps=50, n_classes=200):
    """Sample omega tokens via the ODE, decode once through Fresnel, save a PNG.

    Steps:
      1. Draw random class labels.
      2. Run ``sample_omega_ode``.
      3. Combine the sampled singular values with the pooled basis
         (``U_pool`` / ``Vt_pool``, shared by every sample).
      4. Decode patch-wise, stitch on a square grid, boundary-smooth,
         de-normalise with ``IMG_MEAN`` / ``IMG_STD``.
      5. Plot to ``save_dir/omega_v2_ep{epoch:03d}.png``.

    Leaves ``denoiser`` in eval mode.

    Args:
        n_samples: number of images in the grid.
        n_steps: Euler steps for the omega-space ODE.
        n_classes: labels are drawn uniformly from ``[0, n_classes)``.

    Raises:
        ValueError: if the number of patches is not a perfect square.
    """
    from geolip_svae.model import stitch_patches
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    denoiser.eval()
    labels = torch.randint(0, n_classes, (n_samples,), device=device)

    # ── ODE in omega space ──
    S_pred = sample_omega_ode(denoiser, labels, omega_mean, omega_std,
                              n_steps=n_steps, device=device)

    # ── Decode ONCE through Fresnel with pooled basis ──
    B, N, _ = S_pred.shape
    grid = math.isqrt(N)
    if grid * grid != N:
        raise ValueError(f"cannot stitch {N} patches into a square grid")
    # Broadcast the single pooled basis over the batch without copying.
    U = U_pool.unsqueeze(0).expand(B, -1, -1, -1)
    Vt = Vt_pool.unsqueeze(0).expand(B, -1, -1, -1)
    decoded = fresnel.decode_patches(U, S_pred, Vt)
    images = fresnel.boundary_smooth(
        stitch_patches(decoded, grid, grid, fresnel.patch_size))

    # TODO: also decode a real training example's S with the pooled basis for
    # comparison, to test whether the pooled basis alone reconstructs well.

    # ── Denormalize ──
    mean = torch.tensor(IMG_MEAN, device=images.device).view(1, 3, 1, 1)
    std = torch.tensor(IMG_STD, device=images.device).view(1, 3, 1, 1)
    images = (images * std + mean).clamp(0, 1).cpu()

    # ── Plot ── (squeeze=False keeps a 2-D axes array even for one sample)
    fig, axes = plt.subplots(1, n_samples, figsize=(n_samples * 3, 3),
                             squeeze=False)
    for ax, img, lab in zip(axes[0], images, labels.tolist()):
        ax.imshow(img.permute(1, 2, 0).numpy())
        ax.set_title(f"class {lab}", fontsize=8)
        ax.axis('off')
    fig.suptitle(f"Omega-Space Diffusion β€” Epoch {epoch}", fontsize=10)
    fig.tight_layout()
    fname = os.path.join(save_dir, f'omega_v2_ep{epoch:03d}.png')
    fig.savefig(fname, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f" Samples: {fname} | labels={labels.cpu().tolist()}")
# ═══════════════════════════════════════════════════════════════
# CLI
# ═══════════════════════════════════════════════════════════════
if __name__ == "__main__":
    # 'high' allows float32 matmuls to use faster reduced-precision internal
    # paths (e.g. TF32) where the hardware supports them.
    torch.set_float32_matmul_precision('high')
    # Training runs on cached omega tokens (no VAE pass per batch), so a
    # large batch is cheap.
    cli_config = {
        'epochs': 100,
        'batch_size': 256,
        'lr': 3e-4,
        'hidden': 256,
        'depth': 8,
        'n_heads': 8,
    }
    train(**cli_config)