# geolip-omega-diffusion-128 / prototype1_trainer.py
# AbstractPhil — "Create prototype1_trainer.py" (commit 1daa7e1, verified)
"""
Twin Stereo Diffusion — Fresnel × Johanna Spectral Denoising
==============================================================
Fresnel sees the clean image. Johanna sees the noise.
The Procrustes alignment between their spectral bases IS the noise.
Training:
  clean image  ──→ Fresnel ──→ (U_f, S_f, Vt_f)   target
  noised image ──→ Johanna ──→ (U_j, S_j, Vt_j)   input
  R = Procrustes(U_j → U_f)                        rotation = noise signature
  Denoiser(S_j, R, t, labels) → S_f                predict clean magnitudes
Inference:
  x_t ──→ Johanna ──→ S_j ──→ Denoiser ──→ S_pred
  decode(U_j, S_pred, Vt_j) ──→ x̂_0
  flow step: x_{t-dt}
  final pass: x_0 ──→ Fresnel encode/decode ──→ crisp output
"""
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
import numpy as np
from tqdm import tqdm
# Optional Hugging Face login when running in Google Colab.
# Outside Colab (google.colab not importable) this is skipped silently.
# Inside Colab, failures (missing/empty secret, missing package, network) are
# reported instead of being swallowed, so a later model-loading error is not
# mysterious. Login stays best-effort: it never aborts the script.
try:
    from google.colab import userdata
except ImportError:
    userdata = None

if userdata is not None:
    try:
        _hf_token = userdata.get('HF_TOKEN')
        if not _hf_token:
            raise RuntimeError("Colab secret 'HF_TOKEN' is empty")
        os.environ["HF_TOKEN"] = _hf_token
        from huggingface_hub import login
        login(token=_hf_token)
    except Exception as exc:  # best effort: report, don't crash
        print(f" Hugging Face login skipped: {exc}")
# ═══════════════════════════════════════════════════════════════
# FROZEN TWINS
# ═══════════════════════════════════════════════════════════════
def load_twins(device='cuda'):
    """Load both frozen SVAE twins at 128×128.

    Fresnel is loaded from the ``v12_imagenet128`` release and Johanna from
    ``v16_johanna_omega``. Each model is switched to eval mode, every
    parameter gets ``requires_grad = False``, and its parameter count is
    printed.

    Returns:
        (fresnel, johanna)
    """
    from geolip_svae import load_model

    def _frozen(version, tag):
        # The config returned alongside the model is not needed here.
        model, _cfg = load_model(hf_version=version, device=device)
        model.eval()
        for param in model.parameters():
            param.requires_grad = False
        n_params = sum(param.numel() for param in model.parameters())
        print(f" {tag} loaded: {n_params:,} params (frozen)")
        return model

    fresnel = _frozen('v12_imagenet128', 'Fresnel-small')
    johanna = _frozen('v16_johanna_omega', 'Johanna-small')
    return fresnel, johanna
# ═══════════════════════════════════════════════════════════════
# PROCRUSTES ALIGNMENT
# ═══════════════════════════════════════════════════════════════
def batched_procrustes(A, B):
    """Solve the orthogonal Procrustes problem for a batch of matrix pairs.

    For each batch element, find the orthogonal R minimising ||A @ R - B||_F,
    i.e. A @ R ≈ B.

    Args:
        A: (..., M, D) — source matrices (Johanna's U).
        B: (..., M, D) — target matrices (Fresnel's U), same leading dims as A.

    Returns:
        R: (..., D, D) — orthogonal matrices. R is not forced to be a proper
           rotation: det(R) can be -1 (a reflection).
    """
    # With A^T B = P Σ Q^T the minimiser is R = P Q^T. The SVD is taken of the
    # transpose, M = B^T A = Q Σ P^T, so U = Q, Vh = P^T and R = Vh^T U^T.
    # `@` (unlike bmm) accepts any number of leading batch dimensions.
    M = B.transpose(-2, -1) @ A
    U, _, Vh = torch.linalg.svd(M)
    return Vh.transpose(-2, -1) @ U.transpose(-2, -1)
def compute_procrustes_features(U_j, U_f, D=16):
    """Per-patch orthogonal Procrustes alignment of Johanna's basis onto Fresnel's.

    Args:
        U_j: (B, N, V, D) — Johanna's left singular vectors.
        U_f: (B, N, V, D) — Fresnel's left singular vectors; must have exactly
            the same shape as ``U_j``.
        D: unused — the rank is read from ``U_j.shape[-1]``. Kept so existing
            call sites continue to work.

    Returns:
        R: (B, N, D, D) — orthogonal matrices with U_j[b, n] @ R[b, n] ≈ U_f[b, n].
        R_feat: (B, N, D*D) — the same matrices flattened, for projection.

    Raises:
        ValueError: if the inputs are not 4-D or their shapes differ. Without
            this check a same-numel but differently laid out ``U_f`` would be
            silently reshaped into meaningless matrices.
    """
    if U_j.dim() != 4 or U_j.shape != U_f.shape:
        raise ValueError(
            "U_j and U_f must both be (B, N, V, D) with identical shapes; "
            f"got {tuple(U_j.shape)} and {tuple(U_f.shape)}")
    B, N, V, D = U_j.shape
    # Fold patches into the batch so every patch gets its own rotation.
    R = batched_procrustes(U_j.reshape(B * N, V, D), U_f.reshape(B * N, V, D))
    R = R.reshape(B, N, D, D)
    R_feat = R.reshape(B, N, D * D)
    return R, R_feat
# ═══════════════════════════════════════════════════════════════
# TILED CIFAR-10 DATASET
# ═══════════════════════════════════════════════════════════════
# Commonly quoted per-channel CIFAR-10 statistics. TiledCIFAR normalises with
# them, and the sampler multiplies/adds them back before plotting.
CIFAR_MEAN = (
    0.4914,  # channel 0
    0.4822,  # channel 1
    0.4465,  # channel 2
)
CIFAR_STD = (
    0.2470,  # channel 0
    0.2435,  # channel 1
    0.2616,  # channel 2
)
class TiledCIFAR(torch.utils.data.Dataset):
    """128×128 mosaics built from four CIFAR-10 images (32→64) in a 2×2 grid.

    Each item ignores its index: four source images are drawn at random (with
    replacement) using torch's global RNG. ``len()`` is ``n_samples``, which
    only sets how many mosaics make up one pass.

    Items are ``(image, labels)``: image is (3, 128, 128) normalised with the
    CIFAR statistics, labels is a (4,) long tensor in tile order
    top-left, top-right, bottom-left, bottom-right.
    """

    def __init__(self, train=True, n_samples=50000):
        self.n_samples = n_samples
        preprocess = T.Compose([
            T.Resize(64, interpolation=T.InterpolationMode.BILINEAR),
            T.ToTensor(),
            T.Normalize(CIFAR_MEAN, CIFAR_STD),
        ])
        self.cifar = torchvision.datasets.CIFAR10(
            root='./data', train=train, download=True, transform=preprocess)
        self.n = len(self.cifar)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        picks = torch.randint(0, self.n, (4,)).tolist()
        samples = [self.cifar[k] for k in picks]
        tiles = [img for img, _ in samples]
        # Concatenate along width (dim 2) for each row, then stack rows (dim 1).
        upper = torch.cat(tiles[:2], dim=2)
        lower = torch.cat(tiles[2:], dim=2)
        mosaic = torch.cat([upper, lower], dim=1)
        tile_labels = torch.tensor([lab for _, lab in samples], dtype=torch.long)
        return mosaic, tile_labels
# ═══════════════════════════════════════════════════════════════
# NOISE SCHEDULE
# ═══════════════════════════════════════════════════════════════
def add_noise(x0, t, eps=None):
    """Linear flow-matching interpolation: x_t = (1 - t) * x0 + t * ε.

    Args:
        x0: (B, ...) clean batch, e.g. (B, 3, 128, 128) images.
        t: timesteps in [0, 1] — a (B,) tensor, or a scalar (Python number or
           0-d tensor) applied to the whole batch.
        eps: optional noise with the same shape as ``x0``. When omitted it is
           drawn from a standard normal.

    Returns:
        x_t: noised batch, same shape as ``x0``.
        eps: the noise that was mixed in.
    """
    if eps is None:
        eps = torch.randn_like(x0)
    # Reshape t to (B, 1, ..., 1) so it broadcasts over every non-batch dim
    # regardless of x0's rank (the old view(-1, 1, 1, 1) assumed 4-D).
    t_exp = torch.as_tensor(t, device=x0.device).reshape(
        -1, *([1] * (x0.dim() - 1)))
    x_t = (1 - t_exp) * x0 + t_exp * eps
    return x_t, eps
# ═══════════════════════════════════════════════════════════════
# SPECTRAL DENOISER
# ═══════════════════════════════════════════════════════════════
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of scalar timesteps.

    Maps a (B,) tensor of timesteps to (B, dim) features: ``dim // 2`` sines
    followed by ``dim // 2`` cosines. For dim >= 4 the frequencies are
    geometrically spaced from 1 down to 1/10000. An odd ``dim`` gets one
    trailing zero column so the output width always equals ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when half == 1 (dim 2 or 3).
        scale = math.log(10000) / max(half - 1, 1)
        freqs = torch.exp(
            torch.arange(half, device=t.device, dtype=torch.float) * -scale)
        args = t.unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([args.sin(), args.cos()], dim=1)
        if self.dim % 2:
            emb = F.pad(emb, (0, 1))  # keep output width == dim
        return emb
class AdaLN(nn.Module):
    """Adaptive LayerNorm: a parameter-free LayerNorm modulated by a condition.

    ``cond`` (B, cond_dim) is projected to a per-sample scale and shift that
    are broadcast over the token axis of ``x`` (B, L, dim). The projection is
    zero-initialised, so at initialisation the module is a plain LayerNorm.
    """

    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.proj = nn.Linear(cond_dim, dim * 2)
        # Zero init => scale = 0 and shift = 0 at the start of training.
        for tensor in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(tensor)

    def forward(self, x, cond):
        scale, shift = self.proj(cond).unsqueeze(1).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale) + shift
class StereoBlock(nn.Module):
    """Pre-norm transformer block with AdaLN-conditioned normalisation.

    Self-attention over the tokens followed by a 4× GELU MLP. Each is a
    residual branch whose input is normalised by an ``AdaLN`` driven by the
    (B, cond_dim) conditioning vector. The block has no Procrustes-specific
    path: it only sees the tokens and the conditioning vector.
    """

    def __init__(self, dim, n_heads, cond_dim):
        super().__init__()
        self.adaln1 = AdaLN(dim, cond_dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.adaln2 = AdaLN(dim, cond_dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x, cond):
        h = self.adaln1(x, cond)
        # The attention weights are discarded, so don't ask for them:
        # need_weights=False skips computing them and enables PyTorch's
        # optimised scaled-dot-product attention path.
        h, _ = self.attn(h, h, h, need_weights=False)
        x = x + h
        return x + self.ff(self.adaln2(x, cond))
class StereoDenoiser(nn.Module):
    """Predicts clean Fresnel omega tokens from noisy Johanna observations.

    Input:  S_j     (B, N, D)        Johanna's singular values
            R_feat  (B, N, D²)       flattened Procrustes rotations, optional;
                                     when None, ``input_proj_no_R`` is used
            t       (B,)             noise level
            labels  (B, n_tiles)     tile class labels
    Output: S_f_pred (B, N, D)       predicted clean Fresnel singular values,
                                     computed as S_j + a learned correction

    N should equal ``n_patches`` (length of the positional embedding) and D
    equals ``omega_dim``.
    """

    def __init__(self, n_patches=64, omega_dim=16, hidden=256,
                 depth=8, n_heads=8, n_classes=10, n_tiles=4):
        super().__init__()
        if hidden % n_tiles:
            # forward() concatenates n_tiles embeddings of width
            # hidden // n_tiles back to width `hidden`; a remainder would only
            # surface later as a shape error inside class_proj.
            raise ValueError(
                f"hidden ({hidden}) must be divisible by n_tiles ({n_tiles})")
        self.omega_dim = omega_dim
        d2 = omega_dim * omega_dim

        # Input stems: omega tokens + Procrustes features, or tokens alone.
        self.input_proj = nn.Linear(omega_dim + d2, hidden)
        self.input_proj_no_R = nn.Linear(omega_dim, hidden)

        # Learned positional embedding, one vector per patch.
        self.pos_emb = nn.Parameter(torch.randn(1, n_patches, hidden) * 0.02)

        # Timestep embedding: sinusoidal features followed by an MLP.
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(hidden),
            nn.Linear(hidden, hidden), nn.GELU(),
            nn.Linear(hidden, hidden))

        # Class embedding: one (hidden // n_tiles)-wide vector per tile label.
        self.class_emb = nn.Embedding(n_classes, hidden // n_tiles)
        self.class_proj = nn.Linear(hidden, hidden)

        # Transformer blocks, conditioned on time + class.
        self.blocks = nn.ModuleList([
            StereoBlock(hidden, n_heads, hidden) for _ in range(depth)])

        # Output head, zero-initialised so the model starts as S_j -> S_j.
        self.out_norm = nn.LayerNorm(hidden)
        self.out_proj = nn.Linear(hidden, omega_dim)
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, S_j, t, labels, R_feat=None):
        B = S_j.shape[0]

        # Project input (with or without Procrustes features).
        if R_feat is not None:
            h = self.input_proj(torch.cat([S_j, R_feat], dim=-1))
        else:
            h = self.input_proj_no_R(S_j)
        h = h + self.pos_emb

        # Conditioning: timestep embedding + concatenated per-tile classes.
        t_emb = self.time_emb(t)
        c_emb = self.class_proj(self.class_emb(labels).reshape(B, -1))
        cond = t_emb + c_emb

        for block in self.blocks:
            h = block(h, cond)

        # Predict a residual: S_f ≈ S_j + correction.
        return S_j + self.out_proj(self.out_norm(h))
# ═══════════════════════════════════════════════════════════════
# TRAINING
# ═══════════════════════════════════════════════════════════════
def _encode_twins(fresnel, johanna, images, t):
    """Noise `images` at time `t` and encode both views with the frozen twins.

    Returns (S_j, S_f, R, R_feat):
    - Johanna's singular values of the noised images;
    - Fresnel's singular values of the clean images;
    - the per-patch Procrustes rotations from Johanna's U to Fresnel's U,
      both square and flattened.
    """
    x_noised, _ = add_noise(images, t)
    f_out = fresnel(images)    # clean view -> target
    j_out = johanna(x_noised)  # noised view -> input
    R, R_feat = compute_procrustes_features(
        j_out['svd']['U'], f_out['svd']['U'])
    return j_out['svd']['S'], f_out['svd']['S'], R, R_feat


def _stereo_loss(denoiser, S_j, S_f, t, labels, R_feat, no_r_weight):
    """MSE of the R-conditioned prediction plus no_r_weight × MSE of the R-free one."""
    loss = F.mse_loss(denoiser(S_j, t, labels, R_feat), S_f)
    if no_r_weight:
        # Both samplers call the denoiser with R_feat=None, so that input
        # path (input_proj_no_R) must receive gradients too.
        loss = loss + no_r_weight * F.mse_loss(
            denoiser(S_j, t, labels, R_feat=None), S_f)
    return loss


def _train_epoch(denoiser, fresnel, johanna, loader, opt, device,
                 no_r_weight, desc):
    """Run one optimisation pass; return (mean loss, mean ||R - I||_F)."""
    denoiser.train()
    total_loss, total_r_norm, n = 0.0, 0.0, 0
    pbar = tqdm(loader, desc=desc, bar_format='{l_bar}{bar:20}{r_bar}')
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        B = images.shape[0]
        t = torch.rand(B, device=device)

        # Frozen twins + Procrustes are fixed feature extractors: no grads.
        with torch.no_grad():
            S_j, S_f, R, R_feat = _encode_twins(fresnel, johanna, images, t)
            eye = torch.eye(R.shape[-1], device=R.device, dtype=R.dtype)
            r_norm = (R - eye).norm(dim=(-2, -1)).mean().item()

        loss = _stereo_loss(denoiser, S_j, S_f, t, labels, R_feat, no_r_weight)
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(denoiser.parameters(), max_norm=1.0)
        opt.step()

        loss_val = loss.item()
        total_loss += loss_val * B
        total_r_norm += r_norm * B
        n += B
        pbar.set_postfix_str(f"loss={loss_val:.6f}")
    return total_loss / max(n, 1), total_r_norm / max(n, 1)


@torch.no_grad()
def _validate(denoiser, fresnel, johanna, loader, device, no_r_weight, seed):
    """Mean validation loss with tiles, timesteps and noise fixed by `seed`."""
    denoiser.eval()
    total, n = 0.0, 0
    # Fork the RNG so seeding here does not perturb the training stream. The
    # DataLoader draws its worker seeds from the global RNG when iteration
    # starts, so this also fixes which CIFAR images are tiled.
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            B = images.shape[0]
            t = torch.rand(B, device=device)
            S_j, S_f, _, R_feat = _encode_twins(fresnel, johanna, images, t)
            loss = _stereo_loss(denoiser, S_j, S_f, t, labels, R_feat,
                                no_r_weight)
            total += loss.item() * B
            n += B
    return total / max(n, 1)


def train(epochs=100, batch_size=64, lr=3e-4, hidden=256, depth=8,
          n_heads=8, n_train=50000, device='cuda',
          save_dir='/content/stereo_checkpoints', no_r_weight=1.0,
          val_seed=0):
    """Train a StereoDenoiser to map noisy Johanna spectra to clean Fresnel spectra.

    Each step noises tiled CIFAR images with ``add_noise`` at a uniform random
    t, encodes the clean images with Fresnel and the noised ones with Johanna
    (both frozen), and regresses Fresnel's singular values from Johanna's.

    The loss is MSE(with Procrustes features) + ``no_r_weight`` ×
    MSE(without them). The R-free path is what the samplers use at inference.

    Args:
        epochs, batch_size, lr: AdamW settings; cosine LR decay per epoch.
        hidden, depth, n_heads: StereoDenoiser size.
        n_train: tiled samples per training epoch.
        device: preferred device; falls back to CPU if CUDA is unavailable.
        save_dir: where ``best.pt`` and sample grids are written.
        no_r_weight: weight of the R-free loss term (0 trains only the
            R-conditioned path).
        val_seed: seed for validation, so every epoch is scored on the same
            tiles, timesteps and noise and "best" is comparable across epochs.

    Returns:
        The denoiser with final-epoch weights. The best-validation weights are
        saved to ``save_dir/best.pt``.
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    pin = device.type == 'cuda'  # pinned memory only helps host->GPU copies

    print("\n" + "=" * 70)
    print("TWIN STEREO DIFFUSION — Fresnel × Johanna")
    print("=" * 70)

    # ── Load frozen twins ──
    fresnel, johanna = load_twins(device)

    # ── Data ──
    train_ds = TiledCIFAR(train=True, n_samples=n_train)
    val_ds = TiledCIFAR(train=False, n_samples=5000)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=pin, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=pin)

    # ── Denoiser ──
    denoiser = StereoDenoiser(
        n_patches=64, omega_dim=16, hidden=hidden,
        depth=depth, n_heads=n_heads).to(device)
    n_params = sum(p.numel() for p in denoiser.parameters())
    print(f"\n StereoDenoiser: {n_params:,} params")
    print(f" Hidden={hidden}, Depth={depth}, Heads={n_heads}")
    print(f" Training: {n_train} samples, batch={batch_size}")
    print(f" Pipeline: Johanna(noised) + Procrustes → predict Fresnel(clean)")
    print(f" Loss: MSE(with R) + {no_r_weight} × MSE(no R, used at inference)")
    print("=" * 70)

    opt = torch.optim.AdamW(denoiser.parameters(), lr=lr, weight_decay=0.01)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    os.makedirs(save_dir, exist_ok=True)
    best_val = float('inf')

    for epoch in range(1, epochs + 1):
        train_l, r_norm = _train_epoch(
            denoiser, fresnel, johanna, train_loader, opt, device,
            no_r_weight, desc=f"Ep {epoch}/{epochs}")
        sched.step()

        val_l = _validate(denoiser, fresnel, johanna, val_loader, device,
                          no_r_weight, val_seed)

        if val_l < best_val:
            best_val = val_l
            torch.save({
                'epoch': epoch, 'val_loss': val_l,
                'model_state_dict': denoiser.state_dict(),
                'config': {'hidden': hidden, 'depth': depth,
                           'n_heads': n_heads, 'no_r_weight': no_r_weight},
            }, os.path.join(save_dir, 'best.pt'))

        if epoch % 5 == 0 or epoch <= 5:
            print(f" ep{epoch:3d} | loss={train_l:.6f} val={val_l:.6f} "
                  f"best={best_val:.6f} ||R-I||={r_norm:.3f}")

        # ── Sample ──
        if epoch % 25 == 0:
            sample_stereo(denoiser, fresnel, johanna, device, epoch, save_dir)

    print(f"\n TRAINING COMPLETE — best val: {best_val:.6f}")
    return denoiser
# ═══════════════════════════════════════════════════════════════
# SAMPLING β€” ITERATIVE STEREO DENOISING
# ═══════════════════════════════════════════════════════════════
@torch.no_grad()
def sample_stereo(denoiser, fresnel, johanna, device, epoch, save_dir,
                  n_samples=4, n_steps=50):
    """Generate samples using iterative twin denoising and save them as a PNG.

    1. Start from pure noise x_1 with random tile labels.
    2. At each of ``n_steps`` steps (t going from 1 towards 0):
       a. Johanna encodes x_t → (U_j, S_j, Vt_j).
       b. The denoiser predicts clean S_f from S_j (no R features at
          inference).
       c. Decoding (U_j, S_pred, Vt_j) through Johanna gives an estimate x̂_0.
       d. An Euler step of the linear flow moves x_t towards x̂_0. The last
          step jumps straight to x̂_0.
    3. Final pass: encode through Fresnel and decode with its own basis.

    Writes ``stereo_ep{epoch:03d}.png`` to ``save_dir``. Each row is one
    sample: Johanna decode on the left, Fresnel polish on the right.

    Raises:
        ValueError: if ``n_steps < 1``.
    """
    from geolip_svae.model import stitch_patches

    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    denoiser.eval()
    labels = torch.randint(0, 10, (n_samples, 4), device=device)
    class_names = ['plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']
    ps = johanna.patch_size
    dt = 1.0 / n_steps

    # Start from noise (t = 1).
    x = torch.randn(n_samples, 3, 128, 128, device=device)
    for step in range(n_steps):
        t_val = 1.0 - step / n_steps
        t = torch.full((n_samples,), t_val, device=device)

        # Johanna sees the current state; denoiser predicts clean magnitudes.
        j_out = johanna(x)
        S_j = j_out['svd']['S']
        S_pred = denoiser(S_j, t, labels, R_feat=None)

        # Decode through Johanna's basis to get the clean estimate x̂_0.
        decoded = johanna.decode_patches(
            j_out['svd']['U'], S_pred, j_out['svd']['Vt'])
        gh = gw = math.isqrt(S_j.shape[1])
        x_hat_0 = johanna.boundary_smooth(stitch_patches(decoded, gh, gw, ps))

        if step < n_steps - 1:
            # x_t = (1 - t) x_0 + t ε  =>  dx/dt = ε - x_0 ≈ (x_t - x̂_0) / t.
            # Stepping from t to t - dt subtracts dt · dx/dt, which moves x
            # towards x̂_0.
            x = x - dt * (x - x_hat_0) / (t_val + 1e-4)
        else:
            x = x_hat_0

    # ── Final Fresnel polish: re-encode with Fresnel, decode with its basis ──
    f_out = fresnel(x)
    f_decoded = fresnel.decode_patches(
        f_out['svd']['U'], f_out['svd']['S'], f_out['svd']['Vt'])
    x_final = fresnel.boundary_smooth(stitch_patches(f_decoded, gh, gw, ps))

    # ── Denormalize and save ──
    mean = torch.tensor(CIFAR_MEAN, device=device).reshape(1, 3, 1, 1)
    std = torch.tensor(CIFAR_STD, device=device).reshape(1, 3, 1, 1)
    x_johanna = (x * std + mean).clamp(0, 1).cpu()
    x_fresnel = (x_final * std + mean).clamp(0, 1).cpu()

    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    # squeeze=False keeps `axes` 2-D even when n_samples == 1.
    fig, axes = plt.subplots(n_samples, 2, figsize=(8, n_samples * 3),
                             squeeze=False)
    columns = ((x_johanna, "Johanna decode"), (x_fresnel, "Fresnel polish"))
    for i in range(n_samples):
        tile_labels = [class_names[c] for c in labels[i].cpu().tolist()]
        for col, (imgs, title) in enumerate(columns):
            ax = axes[i, col]
            ax.imshow(imgs[i].permute(1, 2, 0).numpy())
            ax.set_title(f"{title}: {tile_labels}", fontsize=7)
            ax.axis('off')
    fig.suptitle(f"Twin Stereo Diffusion — Epoch {epoch}", fontsize=10)
    fig.tight_layout()
    fname = os.path.join(save_dir, f'stereo_ep{epoch:03d}.png')
    fig.savefig(fname, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f" Samples saved: {fname}")
    print(f" Labels: {labels.cpu().tolist()}")
# ═══════════════════════════════════════════════════════════════
# ADVANCED SAMPLING β€” DUAL-ENCODE REFINEMENT
# ═══════════════════════════════════════════════════════════════
@torch.no_grad()
def sample_stereo_refined(denoiser, fresnel, johanna, labels, device,
                          n_steps=50):
    """Two-pass refinement: use Fresnel to estimate R at inference.

    At each step:
      1. Johanna(x_t) → (U_j, S_j, Vt_j)
      2. Pass 1: Denoiser(S_j, t, labels) → S_pred (no R)
      3. Decode through Johanna → x̂_0, then encode x̂_0 with Fresnel → U_f_est
      4. R_est = Procrustes(U_j, U_f_est)
      5. Pass 2: Denoiser(S_j, t, labels, R_est) → S_refined
      6. Decode S_refined through Fresnel's estimated basis → x_clean, then
         take an Euler flow step from x_t towards x_clean. The final step
         returns x_clean directly.

    Args:
        labels: (B, 4) tile class labels; B sets the number of samples.
            Presumably already on ``device``.
        n_steps: number of denoising steps (0 returns the initial noise).

    Returns:
        (B, 3, 128, 128) samples in the same normalised space as the
        training images (not denormalised).
    """
    from geolip_svae.model import stitch_patches

    B = labels.shape[0]
    x = torch.randn(B, 3, 128, 128, device=device)
    ps = johanna.patch_size
    dt = 1.0 / max(n_steps, 1)

    for step in range(n_steps):
        t_val = 1.0 - step / n_steps
        t = torch.full((B,), t_val, device=device)

        # Johanna encodes the current state.
        j_out = johanna(x)
        S_j = j_out['svd']['S']
        gh = gw = math.isqrt(S_j.shape[1])

        # Pass 1: predict without R, decode through Johanna.
        S_pred_1 = denoiser(S_j, t, labels, R_feat=None)
        dec_1 = johanna.decode_patches(
            j_out['svd']['U'], S_pred_1, j_out['svd']['Vt'])
        x_est = johanna.boundary_smooth(stitch_patches(dec_1, gh, gw, ps))

        # Fresnel sees the estimate -> clean-style basis; Procrustes measures
        # how far Johanna's basis is from it.
        f_est = fresnel(x_est)
        _, R_feat = compute_procrustes_features(
            j_out['svd']['U'], f_est['svd']['U'])

        # Pass 2: predict WITH R conditioning, decode through Fresnel's basis.
        S_pred_2 = denoiser(S_j, t, labels, R_feat)
        dec_2 = fresnel.decode_patches(
            f_est['svd']['U'], S_pred_2, f_est['svd']['Vt'])
        x_clean = fresnel.boundary_smooth(stitch_patches(dec_2, gh, gw, ps))

        if step < n_steps - 1:
            # dx/dt ≈ (x_t - x̂_0) / t for x_t = (1 - t) x_0 + t ε; stepping
            # t -> t - dt subtracts dt · dx/dt, moving x towards x_clean.
            x = x - dt * (x - x_clean) / (t_val + 1e-4)
        else:
            x = x_clean
    return x
# ═══════════════════════════════════════════════════════════════
# CLI
# ═══════════════════════════════════════════════════════════════
if __name__ == "__main__":
    # Permit faster reduced-precision float32 matmuls (e.g. TF32) where the
    # hardware supports them.
    torch.set_float32_matmul_precision('high')
    run_config = {
        'epochs': 100,
        'batch_size': 64,  # 2 VAE forwards per batch, keep it moderate
        'lr': 3e-4,
        'hidden': 256,
        'depth': 8,
        'n_heads': 8,
        'n_train': 50000,
    }
    train(**run_config)