#!/usr/bin/env python3
"""
Flow Matching Diffusion with Constellation Relay Regulator
=============================================================
ODE-based flow matching (not DDPM) on CIFAR-10.
A constellation relay is attached to the middle ConvBlocks as a
geometric regulator.

Flow matching:
    Forward:  x_t = (1-t) * x_0 + t * ε
    Target:   v = ε - x_0
    Loss:     ||v_pred(x_t, t) - v||²
    Sample:   Euler ODE from t=1 → t=0

Architecture:
    Small UNet with ConvNeXt-style blocks
    Middle: ConvBlock (+relay) → self-attention → ConvBlock (+relay)
    Time + class conditioning via adaptive group normalization

Running this file trains the model: it downloads CIFAR-10 into ./data,
writes sample grids to samples/ and the best checkpoint to checkpoints/.
"""

import math
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Mixed precision is only enabled on CUDA; on CPU the autocast/GradScaler
# calls below are explicit no-ops instead of warning about a missing GPU.
USE_AMP = DEVICE == "cuda"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# ══════════════════════════════════════════════════════════════════
# CONSTELLATION RELAY (adapted for feature maps)
# ══════════════════════════════════════════════════════════════════

class ConstellationRelay(nn.Module):
    """
    Geometric regulator for feature maps.

    The channel vector is split into ``channels // patch_dim`` patches.
    Each patch is compared (1 - dot product of unit vectors) against a set
    of learned anchors interpolated between their frozen initial position
    (``home``) and their current position (``anchors``); those distances
    feed a small per-patch MLP whose output is gated into the residual.

    Input: (B, C, H, W) feature map
    Mode:  'channel' — pool spatial, relay on (B, C), rescale the map
           'pixel'   — relay on (B*H*W, C) — expensive but thorough
    """

    def __init__(self, channels, patch_dim=16, n_anchors=16, n_phases=3,
                 pw_hidden=32, gate_init=-3.0, mode='channel'):
        """
        Args:
            channels: number of feature channels; must be divisible by
                ``patch_dim``.
            patch_dim: size of each channel patch.
            n_anchors: anchors per patch.
            n_phases: number of interpolation phases sampled in [0, 1].
            pw_hidden: hidden width of the per-patch MLP.
            gate_init: initial pre-sigmoid value of every patch gate.
            mode: 'channel' or anything else (treated as per-pixel).
        """
        super().__init__()
        assert channels % patch_dim == 0
        self.channels = channels
        self.patch_dim = patch_dim
        self.n_patches = channels // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases
        self.mode = mode

        P, A, d = self.n_patches, n_anchors, patch_dim

        # Frozen reference anchors (buffer) and their trainable copy.
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        # Per-patch two-layer MLP operating on the concatenated distances.
        tri_dim = n_phases * A
        self.pw_w1 = nn.Parameter(torch.empty(P, tri_dim, pw_hidden))
        self.pw_b1 = nn.Parameter(torch.zeros(1, P, pw_hidden))
        self.pw_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))
        self.pw_b2 = nn.Parameter(torch.zeros(1, P, d))
        for p in range(P):
            nn.init.xavier_normal_(self.pw_w1.data[p])
            nn.init.xavier_normal_(self.pw_w2.data[p])
        self.pw_norm = nn.LayerNorm(d)

        self.gates = nn.Parameter(torch.full((P,), gate_init))
        self.norm = nn.LayerNorm(channels)

    def drift(self):
        """Return the angle (radians) between each home and current anchor, shape (P, A)."""
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        # Clamp keeps acos away from ±1 where its derivative is unbounded.
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))

    def at_phase(self, t):
        """Spherical interpolation between home (t=0) and current anchors (t=1)."""
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        so = omega.sin().clamp(min=1e-7)
        return torch.sin((1 - t) * omega) / so * h + torch.sin(t * omega) / so * c

    def _relay_core(self, x_flat):
        """x_flat: (N, C) → (N, C)"""
        N, C = x_flat.shape
        P, d = self.n_patches, self.patch_dim

        x_n = self.norm(x_flat)
        patches = x_n.reshape(N, P, d)
        patches_n = F.normalize(patches, dim=-1)

        # Distance of every patch to every anchor at each phase.
        phases = torch.linspace(0, 1, self.n_phases).tolist()
        tris = []
        for t in phases:
            at = F.normalize(self.at_phase(t), dim=-1)
            tris.append(1.0 - torch.einsum('npd,pad->npa', patches_n, at))
        tri = torch.cat(tris, dim=-1)

        h = F.gelu(torch.einsum('npt,pth->nph', tri, self.pw_w1) + self.pw_b1)
        pw = self.pw_norm(torch.einsum('nph,phd->npd', h, self.pw_w2) + self.pw_b2)

        # Sigmoid gate blends the MLP output with the normalized patches.
        g = self.gates.sigmoid().unsqueeze(0).unsqueeze(-1)
        blended = g * pw + (1 - g) * patches
        return x_flat + blended.reshape(N, C)

    def forward(self, x):
        """x: (B, C, H, W)"""
        B, C, H, W = x.shape

        if self.mode == 'channel':
            # Global average pool → relay → broadcast back
            pooled = x.mean(dim=(-2, -1))       # (B, C)
            relayed = self._relay_core(pooled)  # (B, C)
            # Scale feature map by relay correction.
            # NOTE(review): the epsilon is added unsigned, so a pooled value
            # close to -1e-8 still gives a near-zero denominator; the clamp
            # below is what bounds the resulting scale.
            scale = (relayed / (pooled + 1e-8)).unsqueeze(-1).unsqueeze(-1)
            return x * scale.clamp(-3, 3)  # prevent extreme scaling

        # Per-pixel relay — (B*H*W, C)
        x_flat = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
        out = self._relay_core(x_flat)
        return out.reshape(B, H, W, C).permute(0, 3, 1, 2)


# ══════════════════════════════════════════════════════════════════
# BUILDING BLOCKS
# ══════════════════════════════════════════════════════════════════

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of scalars (used for time t)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """t: (B,) → (B, 2 * (dim // 2)), sines followed by cosines."""
        half = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when dim == 2 (half == 1).
        emb = math.log(10000) / max(half - 1, 1)
        emb = torch.exp(torch.arange(half, device=t.device, dtype=t.dtype) * -emb)
        emb = t.unsqueeze(-1) * emb.unsqueeze(0)
        return torch.cat([emb.sin(), emb.cos()], dim=-1)


class AdaGroupNorm(nn.Module):
    """Group norm with adaptive scale/shift from conditioning."""

    def __init__(self, channels, cond_dim, n_groups=8):
        super().__init__()
        self.gn = nn.GroupNorm(min(n_groups, channels), channels, affine=False)
        self.proj = nn.Linear(cond_dim, channels * 2)
        # Zero init: the layer starts as a plain (non-affine) GroupNorm.
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, cond):
        """x: (B, C, H, W), cond: (B, cond_dim) → (B, C, H, W)."""
        x = self.gn(x)
        scale, shift = self.proj(cond).unsqueeze(-1).unsqueeze(-1).chunk(2, dim=1)
        return x * (1 + scale) + shift


class ConvBlock(nn.Module):
    """ConvNeXt-style block with adaptive norm."""

    def __init__(self, channels, cond_dim, use_relay=False):
        super().__init__()
        self.dw_conv = nn.Conv2d(channels, channels, 7, padding=3, groups=channels)
        self.norm = AdaGroupNorm(channels, cond_dim)
        self.pw1 = nn.Conv2d(channels, channels * 4, 1)
        self.pw2 = nn.Conv2d(channels * 4, channels, 1)
        self.act = nn.GELU()

        self.relay = ConstellationRelay(
            channels, patch_dim=min(16, channels), n_anchors=min(16, channels),
            n_phases=3, pw_hidden=32, gate_init=-3.0,
            mode='channel') if use_relay else None

    def forward(self, x, cond):
        """Depthwise conv → adaptive norm → 1x1 MLP → residual → optional relay."""
        residual = x
        x = self.dw_conv(x)
        x = self.norm(x, cond)
        x = self.pw1(x)
        x = self.act(x)
        x = self.pw2(x)
        x = residual + x

        if self.relay is not None:
            x = self.relay(x)
        return x


class SelfAttnBlock(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map."""

    def __init__(self, channels, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = channels // n_heads
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.out = nn.Conv2d(channels, channels, 1)
        # Zero-initialised output projection: the block starts as identity.
        nn.init.zeros_(self.out.weight)
        nn.init.zeros_(self.out.bias)

    def forward(self, x):
        """x: (B, C, H, W) → (B, C, H, W)."""
        B, C, H, W = x.shape
        residual = x
        x = self.norm(x)
        qkv = self.qkv(x).reshape(B, 3, self.n_heads, self.head_dim, H * W)
        # scaled_dot_product_attention treats the second-to-last axis as the
        # sequence and the last as the embedding, so move the H*W positions
        # to the sequence axis: (B, heads, head_dim, HW) → (B, heads, HW, head_dim).
        # Without the transpose, attention would mix channels, not positions.
        q, k, v = (part.transpose(-1, -2).contiguous() for part in qkv.unbind(dim=1))
        attn = F.scaled_dot_product_attention(q, k, v)  # (B, heads, HW, head_dim)
        out = attn.transpose(-1, -2).reshape(B, C, H, W)
        return residual + self.out(out)


class Downsample(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


class Upsample(nn.Module):
    """Double spatial resolution (nearest neighbour) followed by a 3x3 convolution."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x)


# ══════════════════════════════════════════════════════════════════
# FLOW MATCHING UNET
# ══════════════════════════════════════════════════════════════════

class FlowMatchUNet(nn.Module):
    """
    Clean UNet for flow matching.
    Explicit skip tracking — no dynamic insertion.

    With the defaults (base 64, mults (1, 2, 4), 32x32 input):
        Encoder: [64@32] → down → [128@16] → down → [256@8]
        Middle:  [256@8] with attention + relay
        Decoder: [256@8] → up → [128@16] → up → [64@32]
    """

    def __init__(
        self,
        in_channels=3,
        base_channels=64,
        channel_mults=(1, 2, 4),
        n_classes=10,
        cond_dim=256,
        use_relay=True,
    ):
        super().__init__()
        self.use_relay = use_relay
        self.channel_mults = channel_mults

        # Time + class conditioning
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.GELU(),
            nn.Linear(cond_dim, cond_dim))
        self.class_emb = nn.Embedding(n_classes, cond_dim)

        # Input projection
        self.in_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Build encoder: 2 conv blocks per level, then downsample
        self.enc = nn.ModuleList()
        self.enc_down = nn.ModuleList()
        ch_in = base_channels
        enc_channels = [base_channels]  # track channels at each skip point

        for i, mult in enumerate(channel_mults):
            ch_out = base_channels * mult
            self.enc.append(nn.ModuleList([
                # A 1x1 conv changes the width first when the level widens.
                ConvBlock(ch_in, cond_dim) if ch_in == ch_out
                else nn.Sequential(nn.Conv2d(ch_in, ch_out, 1),
                                   ConvBlock(ch_out, cond_dim)),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch_in = ch_out
            enc_channels.append(ch_out)
            if i < len(channel_mults) - 1:
                self.enc_down.append(Downsample(ch_out))

        # Middle
        mid_ch = ch_in
        self.mid_block1 = ConvBlock(mid_ch, cond_dim, use_relay=use_relay)
        self.mid_attn = SelfAttnBlock(mid_ch, n_heads=4)
        self.mid_block2 = ConvBlock(mid_ch, cond_dim, use_relay=use_relay)

        # Build decoder: upsample, concat skip, 2 conv blocks per level
        self.dec_up = nn.ModuleList()
        self.dec_skip_proj = nn.ModuleList()
        self.dec = nn.ModuleList()

        for i in range(len(channel_mults) - 1, -1, -1):
            mult = channel_mults[i]
            ch_out = base_channels * mult
            skip_ch = enc_channels.pop()
            # Project concatenated channels
            self.dec_skip_proj.append(nn.Conv2d(ch_in + skip_ch, ch_out, 1))
            self.dec.append(nn.ModuleList([
                ConvBlock(ch_out, cond_dim),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch_in = ch_out
            if i > 0:
                self.dec_up.append(Upsample(ch_out))

        # Output (zero-initialised so the initial prediction is exactly 0)
        self.out_norm = nn.GroupNorm(8, ch_in)
        self.out_conv = nn.Conv2d(ch_in, in_channels, 3, padding=1)
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x, t, class_labels):
        """
        Predict the velocity field.

        Args:
            x: (B, in_channels, H, W) noisy input x_t.
            t: (B,) float times.
            class_labels: (B,) integer class indices.

        Returns:
            Tensor with the same shape as ``x``.
        """
        cond = self.time_emb(t) + self.class_emb(class_labels)

        h = self.in_conv(x)
        # The in_conv output is stored first; the decoder pops one skip per
        # level, so this first entry is never consumed.
        skips = [h]

        # Encoder
        for i in range(len(self.channel_mults)):
            for block in self.enc[i]:
                if isinstance(block, ConvBlock):
                    h = block(h, cond)
                elif isinstance(block, nn.Sequential):
                    # Conv1x1 then ConvBlock
                    h = block[0](h)
                    h = block[1](h, cond)
                else:
                    h = block(h)
            skips.append(h)
            if i < len(self.enc_down):
                h = self.enc_down[i](h)

        # Middle
        h = self.mid_block1(h, cond)
        h = self.mid_attn(h)
        h = self.mid_block2(h, cond)

        # Decoder
        for i in range(len(self.channel_mults)):
            skip = skips.pop()
            # Upsample first if needed (except first decoder level)
            if i > 0:
                h = self.dec_up[i - 1](h)
            h = torch.cat([h, skip], dim=1)
            h = self.dec_skip_proj[i](h)
            for block in self.dec[i]:
                h = block(h, cond)

        h = self.out_norm(h)
        h = F.silu(h)
        return self.out_conv(h)


# ══════════════════════════════════════════════════════════════════
# FLOW MATCHING TRAINING
# ══════════════════════════════════════════════════════════════════

# Hyperparams
BATCH = 128
EPOCHS = 50
LR = 3e-4
BASE_CH = 64
USE_RELAY = True
N_CLASSES = 10
SAMPLE_EVERY = 5
N_SAMPLE_STEPS = 50  # Euler ODE steps for sampling

print("=" * 70)
print("FLOW MATCHING + CONSTELLATION RELAY REGULATOR")
print(f" Dataset: CIFAR-10")
print(f" Base channels: {BASE_CH}")
print(f" Relay: {USE_RELAY}")
print(f" Flow matching: ODE (conditional)")
print(f" Sampler: Euler, {N_SAMPLE_STEPS} steps")
print(f" Device: {DEVICE}")
print("=" * 70)

# Data
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_ds = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH, shuffle=True,
    num_workers=4, pin_memory=True, drop_last=True)

print(f" Train: {len(train_ds):,} images")

# Model
model = FlowMatchUNet(
    in_channels=3, base_channels=BASE_CH,
    channel_mults=(1, 2, 4), n_classes=N_CLASSES,
    cond_dim=256, use_relay=USE_RELAY
).to(DEVICE)

n_params = sum(p.numel() for p in model.parameters())
relay_params = sum(p.numel() for n, p in model.named_parameters() if 'relay' in n)
print(f" Total params: {n_params:,}")
print(f" Relay params: {relay_params:,} ({100*relay_params/n_params:.1f}%)")

# Count relay modules
n_relays = sum(1 for m in model.modules() if isinstance(m, ConstellationRelay))
print(f" Relay modules: {n_relays}")

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
# The scheduler is stepped once per batch, hence T_max in optimizer steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=EPOCHS * len(train_loader), eta_min=1e-6)

# Disabled off-GPU so the scaler is a transparent pass-through on CPU.
scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP)

os.makedirs("samples", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)


@torch.no_grad()
def sample(model, n_samples=64, n_steps=50, class_label=None):
    """
    Euler ODE sampling from t=1 (noise) to t=0 (data).

    Args:
        model: velocity network called as ``model(x, t, labels)``.
        n_samples: number of 3x32x32 images to generate.
        n_steps: number of Euler steps.
        class_label: if given, every sample uses this class; otherwise
            classes are drawn uniformly from ``range(N_CLASSES)``.

    Returns:
        (x, labels): images clamped to [-1, 1] and the labels used.

    Note: the model is switched to eval mode and left there.
    """
    model.eval()
    B = n_samples
    x = torch.randn(B, 3, 32, 32, device=DEVICE)

    if class_label is not None:
        labels = torch.full((B,), class_label, dtype=torch.long, device=DEVICE)
    else:
        labels = torch.randint(0, N_CLASSES, (B,), device=DEVICE)

    dt = 1.0 / n_steps
    for step in range(n_steps):
        t_val = 1.0 - step * dt
        t = torch.full((B,), t_val, device=DEVICE)

        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16, enabled=USE_AMP):
            v = model(x, t, labels)

        x = x - v * dt  # Euler step: x_{t-dt} = x_t - v * dt

    # Clamp to valid range
    x = x.clamp(-1, 1)
    return x, labels


# ══════════════════════════════════════════════════════════════════
# TRAINING LOOP
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*70}")
print(f"TRAINING — {EPOCHS} epochs")
print(f"{'='*70}")

best_loss = float('inf')
gs = 0  # global optimizer-step counter

for epoch in range(EPOCHS):
    model.train()
    t0 = time.time()
    total_loss = 0
    n = 0

    pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
    for images, labels in pbar:
        images = images.to(DEVICE, non_blocking=True)  # (B, 3, 32, 32) in [-1, 1]
        labels = labels.to(DEVICE, non_blocking=True)
        B = images.shape[0]

        # Flow matching: sample t, compute x_t and target velocity
        t = torch.rand(B, device=DEVICE)
        eps = torch.randn_like(images)

        # x_t = (1-t) * x_0 + t * eps
        t_b = t.view(B, 1, 1, 1)
        x_t = (1 - t_b) * images + t_b * eps

        # Target velocity: v = eps - x_0
        v_target = eps - images

        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16, enabled=USE_AMP):
            v_pred = model(x_t, t, labels)
            loss = F.mse_loss(v_pred, v_target)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to true gradients.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        gs += 1

        total_loss += loss.item()
        n += 1

        if n % 20 == 0:
            pbar.set_postfix(loss=f"{total_loss/n:.4f}",
                             lr=f"{scheduler.get_last_lr()[0]:.1e}")

    elapsed = time.time() - t0
    avg_loss = total_loss / n

    # Checkpoint (keyed on the epoch's mean training loss)
    mk = ""
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save({
            'state_dict': model.state_dict(),
            'epoch': epoch + 1,
            'loss': avg_loss,
            'use_relay': USE_RELAY,
        }, 'checkpoints/flow_match_best.pt')
        mk = " ★"

    print(f" E{epoch+1:3d}: loss={avg_loss:.4f} lr={scheduler.get_last_lr()[0]:.1e} "
          f"({elapsed:.0f}s){mk}")

    # Sample
    if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == 0:
        samples, sample_labels = sample(model, n_samples=64, n_steps=N_SAMPLE_STEPS)
        # Denormalize
        samples = (samples + 1) / 2  # [-1,1] → [0,1]
        grid = make_grid(samples, nrow=8, normalize=False)
        save_image(grid, f'samples/epoch_{epoch+1:03d}.png')
        print(f" → Saved samples/epoch_{epoch+1:03d}.png")

        # Per-class samples
        if (epoch + 1) % (SAMPLE_EVERY * 2) == 0:
            class_names = ['plane', 'auto', 'bird', 'cat', 'deer',
                           'dog', 'frog', 'horse', 'ship', 'truck']
            for c in range(N_CLASSES):
                cs, _ = sample(model, n_samples=8, n_steps=N_SAMPLE_STEPS, class_label=c)
                cs = (cs + 1) / 2
                save_image(make_grid(cs, nrow=8),
                           f'samples/epoch_{epoch+1:03d}_class_{class_names[c]}.png')

    # Relay diagnostics
    if USE_RELAY and (epoch + 1) % 10 == 0:
        print(f" Relay diagnostics:")
        for name, module in model.named_modules():
            if isinstance(module, ConstellationRelay):
                drift = module.drift().mean().item()
                gate = module.gates.sigmoid().mean().item()
                print(f" {name}: drift={drift:.4f} rad "
                      f"({math.degrees(drift):.1f}°) gate={gate:.4f}")

print(f"\n{'='*70}")
print(f"DONE — Best loss: {best_loss:.4f}")
print(f" Params: {n_params:,} (relay: {relay_params:,})")
print(f" Samples in: samples/")
print(f"{'='*70}")