# Source: geolip-diffusion-proto / constellation_diffusion_cifar10_trainer_standalone.py
# Committer: AbstractPhil
# Commit message: Rename constellation_diffusion.py to constellation_diffusion_cifar10_trainer_standalone.py
# Commit: a9031b9 (verified)
#!/usr/bin/env python3
"""
Flow Matching Diffusion with Constellation Relay Regulator
=============================================================
ODE-based flow matching (not DDPM) on CIFAR-10.
A constellation relay is attached to each middle-stage ConvBlock as a
geometric regulator; the relay LayerNorm-normalizes its input internally.
Flow matching:
    Forward: x_t = (1-t) * x_0 + t * ε,  with t ~ U[0, 1]
    Target:  v = ε - x_0
    Loss:    ||v_pred(x_t, t) - v||²  (mean squared error)
    Sample:  Euler ODE integration from t=1 (noise) → t=0 (data)
Architecture:
    Small UNet with ConvNeXt-style blocks
    Middle: ConvBlock + relay → self-attention → ConvBlock + relay
    Time + class conditioning via adaptive group normalization (AdaGroupNorm)
The relay operates on normalized features between blocks, snapping
geometry back toward the constellation reference frame after each
middle-stage conv perturbation (one relay before and one after the
self-attention).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import time
from tqdm import tqdm
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
# Pick the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Allow TensorFloat-32 for matmuls and cuDNN convolutions (faster on GPUs
# that support it, at reduced multiply precision).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# ══════════════════════════════════════════════════════════════════
# CONSTELLATION RELAY (adapted for feature maps)
# ══════════════════════════════════════════════════════════════════
class ConstellationRelay(nn.Module):
    """
    Geometric regulator for feature maps.

    The channel vector is LayerNorm-normalized and split into ``n_patches``
    patches of ``patch_dim`` channels. Each unit-normalized patch is compared
    (cosine distance, 1 - cos) against ``n_anchors`` anchors taken at
    ``n_phases`` evenly spaced points of the slerp path from the fixed
    ``home`` frame to the learnable ``anchors``. A per-patch two-layer MLP maps
    those distances back to ``patch_dim`` channels, and a per-patch sigmoid
    gate blends the MLP output with the normalized patch.

    Input: (B, C, H, W) feature map
    Mode: 'channel' — average-pool over H, W, relay on (B, C), then rescale
                      each channel of x by relayed / pooled (clamped to [-3, 3])
          'pixel'   — relay every pixel independently on (B*H*W, C);
                      expensive but thorough. Any mode other than 'channel'
                      takes this path.
    """

    def __init__(self, channels, patch_dim=16, n_anchors=16, n_phases=3,
                 pw_hidden=32, gate_init=-3.0, mode='channel'):
        super().__init__()
        assert channels % patch_dim == 0, (
            f"channels ({channels}) must be divisible by patch_dim ({patch_dim})")
        self.channels = channels
        self.patch_dim = patch_dim
        self.n_patches = channels // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases
        self.mode = mode
        P, A, d = self.n_patches, n_anchors, patch_dim

        # Fixed reference frame of unit-norm anchors per patch. The learnable
        # anchors start as an exact copy, so drift() starts near zero.
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        # Per-patch MLP: (n_phases * A) distances -> pw_hidden -> d.
        tri_dim = n_phases * A
        self.pw_w1 = nn.Parameter(torch.empty(P, tri_dim, pw_hidden))
        self.pw_b1 = nn.Parameter(torch.zeros(1, P, pw_hidden))
        self.pw_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))
        self.pw_b2 = nn.Parameter(torch.zeros(1, P, d))
        for p in range(P):
            nn.init.xavier_normal_(self.pw_w1.data[p])
            nn.init.xavier_normal_(self.pw_w2.data[p])
        self.pw_norm = nn.LayerNorm(d)

        # Per-patch blend gates; with the default gate_init=-3.0 the MLP
        # output starts with sigmoid(-3) ~= 0.047 of the blend.
        self.gates = nn.Parameter(torch.full((P,), gate_init))
        self.norm = nn.LayerNorm(channels)

    def drift(self):
        """Angle in radians between each home anchor and its learned copy; shape (P, A)."""
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        # Clamp keeps acos finite (and its gradient bounded) at |cos| = 1.
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))

    def at_phase(self, t):
        """Slerp between home (t=0) and the learned anchors (t=1); shape (P, A, d)."""
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        so = omega.sin().clamp(min=1e-7)
        return torch.sin((1-t)*omega)/so * h + torch.sin(t*omega)/so * c

    def _relay_core(self, x_flat):
        """x_flat: (N, C) → (N, C)

        Returns ``x_flat + blend`` where ``blend`` mixes, per patch, the MLP
        output (weight g) with the LayerNorm'd input patch (weight 1 - g).
        """
        N, C = x_flat.shape
        P, A, d = self.n_patches, self.n_anchors, self.patch_dim
        x_n = self.norm(x_flat)
        patches = x_n.reshape(N, P, d)
        patches_n = F.normalize(patches, dim=-1)
        phases = torch.linspace(0, 1, self.n_phases).tolist()
        tris = []
        for t in phases:
            at = F.normalize(self.at_phase(t), dim=-1)
            # Cosine distance of every patch to every anchor at this phase: (N, P, A)
            tris.append(1.0 - torch.einsum('npd,pad->npa', patches_n, at))
        tri = torch.cat(tris, dim=-1)  # (N, P, n_phases * A)
        h = F.gelu(torch.einsum('npt,pth->nph', tri, self.pw_w1) + self.pw_b1)
        pw = self.pw_norm(torch.einsum('nph,phd->npd', h, self.pw_w2) + self.pw_b2)
        g = self.gates.sigmoid().unsqueeze(0).unsqueeze(-1)  # (1, P, 1)
        blended = g * pw + (1-g) * patches
        return x_flat + blended.reshape(N, C)

    def forward(self, x):
        """Apply the relay to a (B, C, H, W) feature map; returns the same shape."""
        B, C, H, W = x.shape
        if self.mode == 'channel':
            # Global average pool → relay → broadcast back
            pooled = x.mean(dim=(-2, -1))  # (B, C)
            relayed = self._relay_core(pooled)  # (B, C)
            # Scale feature map by relay correction. The epsilon pushes the
            # denominator away from zero in pooled's own direction: a plain
            # "+ eps" flips the sign of (or zeroes) means in (-eps, 0).
            eps = 1e-8
            denom = torch.where(pooled >= 0, pooled + eps, pooled - eps)
            scale = (relayed / denom).unsqueeze(-1).unsqueeze(-1)
            return x * scale.clamp(-3, 3)  # prevent extreme scaling
        # Per-pixel relay — (B*H*W, C)
        x_flat = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
        out = self._relay_core(x_flat)
        return out.reshape(B, H, W, C).permute(0, 3, 1, 2)
# ══════════════════════════════════════════════════════════════════
# BUILDING BLOCKS
# ══════════════════════════════════════════════════════════════════
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a per-sample scalar.

    Maps t of shape (B,) to (B, dim): the first ``dim // 2`` columns are
    sin(t * f_i) and the next ``dim // 2`` are cos(t * f_i), for frequencies
    f_i = 10000^(-i / (dim // 2 - 1)). An odd ``dim`` gets one trailing zero
    column so the output width is always exactly ``dim``.

    NOTE(review): this file feeds t ~ U[0, 1]; at that scale the
    low-frequency columns barely vary. Many flow-matching setups rescale t
    (e.g. by 1000) before embedding — worth evaluating.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # Guard dim in {2, 3}: half - 1 == 0 would divide by zero; with a
        # single frequency the exponent is 0 anyway, so the divisor is moot.
        emb = math.log(10000) / max(half - 1, 1)
        emb = torch.exp(torch.arange(half, device=t.device, dtype=t.dtype) * -emb)
        emb = t.unsqueeze(-1) * emb.unsqueeze(0)
        out = torch.cat([emb.sin(), emb.cos()], dim=-1)
        if self.dim % 2:
            # Odd dim: pad so downstream Linear(dim, ...) layers get dim inputs.
            out = F.pad(out, (0, 1))
        return out
class AdaGroupNorm(nn.Module):
    """Group norm with adaptive scale/shift from conditioning.

    y = GroupNorm(x) * (1 + scale) + shift, where (scale, shift) is a linear
    projection of ``cond`` (B, cond_dim). The projection is zero-initialised,
    so the module starts out as a plain non-affine GroupNorm.
    """

    def __init__(self, channels, cond_dim, n_groups=8):
        super().__init__()
        # GroupNorm requires channels % groups == 0. Use the largest count
        # <= min(n_groups, channels) that divides channels: the same as the
        # requested count whenever that already divides, never a crash otherwise.
        groups = max(1, min(n_groups, channels))
        while channels % groups:
            groups -= 1
        self.gn = nn.GroupNorm(groups, channels, affine=False)
        self.proj = nn.Linear(cond_dim, channels * 2)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, cond):
        x = self.gn(x)
        # (B, 2C) -> (B, 2C, 1, 1), split into two (B, C, 1, 1) maps that
        # broadcast over H and W.
        scale, shift = self.proj(cond).unsqueeze(-1).unsqueeze(-1).chunk(2, dim=1)
        return x * (1 + scale) + shift
class ConvBlock(nn.Module):
    """ConvNeXt-style residual block with adaptive (conditioned) group norm.

    x -> 7x7 depthwise conv -> AdaGroupNorm(cond) -> 1x1 conv (4x wider)
      -> GELU -> 1x1 conv back to ``channels``; the result is added to x and,
    when ``use_relay`` is set, passed through a channel-mode ConstellationRelay.
    """

    def __init__(self, channels, cond_dim, use_relay=False):
        super().__init__()
        hidden = channels * 4
        self.dw_conv = nn.Conv2d(channels, channels, kernel_size=7, padding=3,
                                 groups=channels)
        self.norm = AdaGroupNorm(channels, cond_dim)
        self.pw1 = nn.Conv2d(channels, hidden, kernel_size=1)
        self.pw2 = nn.Conv2d(hidden, channels, kernel_size=1)
        self.act = nn.GELU()
        relay = None
        if use_relay:
            # Patch and anchor counts are capped at the channel count.
            cap = min(16, channels)
            relay = ConstellationRelay(
                channels, patch_dim=cap, n_anchors=cap, n_phases=3,
                pw_hidden=32, gate_init=-3.0, mode='channel')
        self.relay = relay

    def forward(self, x, cond):
        h = self.norm(self.dw_conv(x), cond)
        h = self.pw2(self.act(self.pw1(h)))
        out = x + h
        if self.relay is None:
            return out
        return self.relay(out)
class SelfAttnBlock(nn.Module):
    """Multi-head spatial self-attention over the H*W positions of a feature map.

    Pre-norm (GroupNorm) -> 1x1 conv to q, k, v -> scaled dot-product
    attention across positions within each head -> 1x1 output conv, added
    residually. The output conv is zero-initialised, so a fresh block is the
    identity.
    """

    def __init__(self, channels, n_heads=4):
        super().__init__()
        if channels % n_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by n_heads ({n_heads})")
        self.n_heads = n_heads
        self.head_dim = channels // n_heads
        # min() mirrors AdaGroupNorm so feature maps with < 8 channels work.
        self.norm = nn.GroupNorm(min(8, channels), channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.out = nn.Conv2d(channels, channels, 1)
        nn.init.zeros_(self.out.weight)
        nn.init.zeros_(self.out.bias)

    def forward(self, x):
        B, C, H, W = x.shape
        residual = x
        x = self.norm(x)
        # qkv channels are laid out as (3, heads, head_dim). SDPA expects
        # (..., seq_len, head_dim), so the H*W positions must sit on dim -2;
        # without this transpose it would attend over channels instead.
        qkv = self.qkv(x).reshape(B, 3, self.n_heads, self.head_dim, H * W)
        qkv = qkv.transpose(-1, -2).contiguous()  # (B, 3, heads, HW, head_dim)
        q, k, v = qkv.unbind(dim=1)
        attn = F.scaled_dot_product_attention(q, k, v)  # (B, heads, HW, head_dim)
        out = attn.transpose(-1, -2).reshape(B, C, H, W)
        return residual + self.out(out)
class Downsample(nn.Module):
    """Halve H and W with a stride-2 3x3 convolution; channel count is unchanged."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=channels, out_channels=channels,
                              kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        reduced = self.conv(x)
        return reduced
class Upsample(nn.Module):
    """Double H and W (nearest-neighbour), then smooth with a 3x3 convolution."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=channels, out_channels=channels,
                              kernel_size=3, padding=1)

    def forward(self, x):
        enlarged = F.interpolate(x, scale_factor=2, mode='nearest')
        smoothed = self.conv(enlarged)
        return smoothed
# ══════════════════════════════════════════════════════════════════
# FLOW MATCHING UNET
# ══════════════════════════════════════════════════════════════════
class FlowMatchUNet(nn.Module):
    """
    Clean UNet for flow matching.
    Explicit skip tracking — no dynamic insertion.
    Encoder: [64@32] → down → [128@16] → down → [256@8]
    Middle: [256@8] with attention + relay
    Decoder: [256@8] → up → [128@16] → up → [64@32]

    The channel@resolution figures assume the defaults (base_channels=64,
    channel_mults=(1, 2, 4)) and 32x32 inputs.

    forward(x, t, class_labels): x is the input image batch, t the per-sample
    flow time, class_labels integer class ids. Returns a tensor with
    in_channels channels at x's resolution (used as the predicted velocity by
    this script's training loop).
    """
    def __init__(
        self,
        in_channels=3,
        base_channels=64,
        channel_mults=(1, 2, 4),
        n_classes=10,
        cond_dim=256,
        use_relay=True,
    ):
        super().__init__()
        self.use_relay = use_relay
        self.channel_mults = channel_mults
        # Time + class conditioning: both map to cond_dim and are summed in
        # forward() into the single conditioning vector every ConvBlock sees.
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.GELU(),
            nn.Linear(cond_dim, cond_dim))
        self.class_emb = nn.Embedding(n_classes, cond_dim)
        # Input projection
        self.in_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)
        # Build encoder: 2 conv blocks per level, then downsample (none after
        # the last level). When a level changes width its first entry is
        # Sequential(1x1 conv, ConvBlock); forward() unpacks it by hand because
        # ConvBlock takes the extra cond argument.
        self.enc = nn.ModuleList()
        self.enc_down = nn.ModuleList()
        ch_in = base_channels
        enc_channels = [base_channels]  # track channels at each skip point
        # NOTE: the leading base_channels entry (in_conv output) is never
        # popped: the decoder loop below pops len(channel_mults) times.
        for i, mult in enumerate(channel_mults):
            ch_out = base_channels * mult
            self.enc.append(nn.ModuleList([
                ConvBlock(ch_in, cond_dim) if ch_in == ch_out
                else nn.Sequential(nn.Conv2d(ch_in, ch_out, 1),
                                   ConvBlock(ch_out, cond_dim)),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch_in = ch_out
            enc_channels.append(ch_out)
            if i < len(channel_mults) - 1:
                self.enc_down.append(Downsample(ch_out))
        # Middle: the only blocks built with a relay (when use_relay is True).
        mid_ch = ch_in
        self.mid_block1 = ConvBlock(mid_ch, cond_dim, use_relay=use_relay)
        self.mid_attn = SelfAttnBlock(mid_ch, n_heads=4)
        self.mid_block2 = ConvBlock(mid_ch, cond_dim, use_relay=use_relay)
        # Build decoder: upsample, concat skip, 2 conv blocks per level.
        # Levels are built deepest-first. dec_skip_proj[j] maps the
        # concatenation (ch_in + skip_ch) to ch_out. An Upsample is appended
        # after every level except the shallowest, so dec_up has
        # len(channel_mults) - 1 entries; forward() applies dec_up[j - 1] at
        # the start of decoder level j.
        self.dec_up = nn.ModuleList()
        self.dec_skip_proj = nn.ModuleList()
        self.dec = nn.ModuleList()
        for i in range(len(channel_mults) - 1, -1, -1):
            mult = channel_mults[i]
            ch_out = base_channels * mult
            skip_ch = enc_channels.pop()
            # Project concatenated channels
            self.dec_skip_proj.append(nn.Conv2d(ch_in + skip_ch, ch_out, 1))
            self.dec.append(nn.ModuleList([
                ConvBlock(ch_out, cond_dim),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch_in = ch_out
            if i > 0:
                self.dec_up.append(Upsample(ch_out))
        # Output. The output conv is zero-initialised, so the untrained model
        # outputs all zeros.
        self.out_norm = nn.GroupNorm(8, ch_in)
        self.out_conv = nn.Conv2d(ch_in, in_channels, 3, padding=1)
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x, t, class_labels):
        cond = self.time_emb(t) + self.class_emb(class_labels)
        h = self.in_conv(x)
        # NOTE: this first entry (in_conv output) stays on the stack; the
        # decoder pops only len(self.channel_mults) skips.
        skips = [h]
        # Encoder
        for i in range(len(self.channel_mults)):
            for block in self.enc[i]:
                if isinstance(block, ConvBlock):
                    h = block(h, cond)
                elif isinstance(block, nn.Sequential):
                    # Conv1x1 then ConvBlock
                    h = block[0](h)
                    h = block[1](h, cond)
                else:
                    # Not reached with the modules built in __init__.
                    h = block(h)
            skips.append(h)  # one skip per level, taken before downsampling
            if i < len(self.enc_down):
                h = self.enc_down[i](h)
        # Middle
        h = self.mid_block1(h, cond)
        h = self.mid_attn(h)
        h = self.mid_block2(h, cond)
        # Decoder: deepest level first, so skips pop in reverse encoder order.
        for i in range(len(self.channel_mults)):
            skip = skips.pop()
            # Upsample first if needed (except first decoder level)
            if i > 0:
                h = self.dec_up[i - 1](h)
            h = torch.cat([h, skip], dim=1)
            h = self.dec_skip_proj[i](h)
            for block in self.dec[i]:
                h = block(h, cond)
        h = self.out_norm(h)
        h = F.silu(h)
        return self.out_conv(h)
# ══════════════════════════════════════════════════════════════════
# FLOW MATCHING TRAINING
# ══════════════════════════════════════════════════════════════════
# Hyperparams
BATCH = 128
EPOCHS = 50
LR = 3e-4
BASE_CH = 64
USE_RELAY = True
N_CLASSES = 10
SAMPLE_EVERY = 5  # epochs between sample grids
N_SAMPLE_STEPS = 50  # Euler ODE steps for sampling

# Run banner: one print of newline-joined lines (same output as one print per line).
print("\n".join([
    "=" * 70,
    "FLOW MATCHING + CONSTELLATION RELAY REGULATOR",
    " Dataset: CIFAR-10",
    f" Base channels: {BASE_CH}",
    f" Relay: {USE_RELAY}",
    " Flow matching: ODE (conditional)",
    f" Sampler: Euler, {N_SAMPLE_STEPS} steps",
    f" Device: {DEVICE}",
    "=" * 70,
]))

# Data: random horizontal flips; ToTensor gives [0, 1], Normalize maps to [-1, 1].
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
train_ds = datasets.CIFAR10(root='./data', train=True, download=True,
                            transform=transform)
# NOTE(review): num_workers > 0 with module-level training code and no
# `if __name__ == "__main__":` guard re-imports this script in every worker
# on spawn-based platforms (Windows/macOS) — confirm the target platform.
train_loader = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=BATCH,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
)
print(f" Train: {len(train_ds):,} images")

# Model
model = FlowMatchUNet(
    in_channels=3,
    base_channels=BASE_CH,
    channel_mults=(1, 2, 4),
    n_classes=N_CLASSES,
    cond_dim=256,
    use_relay=USE_RELAY,
).to(DEVICE)
n_params = sum(map(torch.numel, model.parameters()))
relay_params = sum(param.numel() for pname, param in model.named_parameters()
                   if 'relay' in pname)
print(f" Total params: {n_params:,}")
print(f" Relay params: {relay_params:,} ({100*relay_params/n_params:.1f}%)")
# Count relay modules (summing booleans yields an int)
n_relays = sum(isinstance(mod, ConstellationRelay) for mod in model.modules())
print(f" Relay modules: {n_relays}")

# Optimisation: AdamW with a per-batch cosine schedule over the whole run.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=EPOCHS * len(train_loader), eta_min=1e-6)
scaler = torch.amp.GradScaler("cuda")
os.makedirs("samples", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)
@torch.no_grad()
def sample(model, n_samples=64, n_steps=50, class_label=None):
"""Euler ODE sampling from t=1 (noise) to t=0 (data)."""
model.eval()
B = n_samples
x = torch.randn(B, 3, 32, 32, device=DEVICE)
if class_label is not None:
labels = torch.full((B,), class_label, dtype=torch.long, device=DEVICE)
else:
labels = torch.randint(0, N_CLASSES, (B,), device=DEVICE)
dt = 1.0 / n_steps
for step in range(n_steps):
t_val = 1.0 - step * dt
t = torch.full((B,), t_val, device=DEVICE)
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
v = model(x, t, labels)
x = x - v * dt # Euler step: x_{t-dt} = x_t - v * dt
# Clamp to valid range
x = x.clamp(-1, 1)
return x, labels
# ══════════════════════════════════════════════════════════════════
# TRAINING LOOP
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*70}")
print(f"TRAINING — {EPOCHS} epochs")
print(f"{'='*70}")
# Display names for per-class sample files, indexed by class label.
class_names = ['plane', 'auto', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
best_loss = float('inf')
gs = 0  # global optimizer-step counter
for epoch in range(EPOCHS):
    model.train()
    t0 = time.time()
    total_loss = 0
    n = 0
    pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
    for images, labels in pbar:
        images = images.to(DEVICE, non_blocking=True)  # (B, 3, 32, 32) in [-1, 1]
        labels = labels.to(DEVICE, non_blocking=True)
        B = images.shape[0]
        # Flow matching: sample t ~ U[0, 1), compute x_t and target velocity
        t = torch.rand(B, device=DEVICE)
        eps = torch.randn_like(images)
        # x_t = (1-t) * x_0 + t * eps  (straight line from data to noise)
        t_b = t.view(B, 1, 1, 1)
        x_t = (1 - t_b) * images + t_b * eps
        # Target velocity: v = eps - x_0  (d x_t / dt along that line)
        v_target = eps - images
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            v_pred = model(x_t, t, labels)
            loss = F.mse_loss(v_pred, v_target)
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        # Unscale first so clipping applies to the true gradient norm.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()  # per batch: T_max is EPOCHS * len(train_loader)
        gs += 1
        total_loss += loss.item()
        n += 1
        if n % 20 == 0:
            pbar.set_postfix(loss=f"{total_loss/n:.4f}",
                             lr=f"{scheduler.get_last_lr()[0]:.1e}")
    elapsed = time.time() - t0
    avg_loss = total_loss / n
    # Checkpoint whenever the epoch-mean training loss improves
    mk = ""
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save({
            'state_dict': model.state_dict(),
            'epoch': epoch + 1,
            'loss': avg_loss,
            'use_relay': USE_RELAY,
        }, 'checkpoints/flow_match_best.pt')
        mk = " ★"
    print(f" E{epoch+1:3d}: loss={avg_loss:.4f} lr={scheduler.get_last_lr()[0]:.1e} "
          f"({elapsed:.0f}s){mk}")
    # Sample
    if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == 0:
        samples, _ = sample(model, n_samples=64, n_steps=N_SAMPLE_STEPS)
        # Denormalize
        samples = (samples + 1) / 2  # [-1,1] → [0,1]
        grid = make_grid(samples, nrow=8, normalize=False)
        save_image(grid, f'samples/epoch_{epoch+1:03d}.png')
        print(f" → Saved samples/epoch_{epoch+1:03d}.png")
        # Per-class samples
        if (epoch + 1) % (SAMPLE_EVERY * 2) == 0:
            for c in range(N_CLASSES):
                cs, _ = sample(model, n_samples=8, n_steps=N_SAMPLE_STEPS, class_label=c)
                cs = (cs + 1) / 2
                save_image(make_grid(cs, nrow=8),
                           f'samples/epoch_{epoch+1:03d}_class_{class_names[c]}.png')
    # Relay diagnostics (read-only, so no autograd graph is built)
    if USE_RELAY and (epoch + 1) % 10 == 0:
        print(" Relay diagnostics:")
        with torch.no_grad():
            for name, module in model.named_modules():
                if isinstance(module, ConstellationRelay):
                    drift = module.drift().mean().item()
                    gate = module.gates.sigmoid().mean().item()
                    print(f" {name}: drift={drift:.4f} rad "
                          f"({math.degrees(drift):.1f}°) gate={gate:.4f}")
print(f"\n{'='*70}")
print(f"DONE — Best loss: {best_loss:.4f}")
print(f" Params: {n_params:,} (relay: {relay_params:,})")
print(" Samples in: samples/")
print(f"{'='*70}")