#!/usr/bin/env python3
# Renamed from constellation_diffusion.py to
# constellation_diffusion_cifar10_trainer_standalone.py (commit a9031b9).
| """ | |
| Flow Matching Diffusion with Constellation Relay Regulator | |
| ============================================================= | |
| ODE-based flow matching (not DDPM) on CIFAR-10. | |
| Constellation relay inserted at LayerNorm boundaries as | |
| geometric regulator. | |
| Flow matching: | |
| Forward: x_t = (1-t) * x_0 + t * Ξ΅ | |
| Target: v = Ξ΅ - x_0 | |
| Loss: ||v_pred(x_t, t) - v||Β² | |
| Sample: Euler ODE from t=1 β t=0 | |
| Architecture: | |
| Small UNet with ConvNeXt blocks | |
| Middle: self-attention + constellation relay after each norm | |
| Time + class conditioning via adaptive normalization | |
| The relay operates at the normalized manifold between blocks, | |
| snapping geometry back to the constellation reference frame | |
| after each attention + conv perturbation. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import math | |
| import os | |
| import time | |
| from tqdm import tqdm | |
| from torchvision import datasets, transforms | |
| from torchvision.utils import save_image, make_grid | |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
# ──────────────────────────────────────────────────────────────────
# CONSTELLATION RELAY (adapted for feature maps)
# ──────────────────────────────────────────────────────────────────
class ConstellationRelay(nn.Module):
    """
    Geometric regulator for feature maps.

    Channels are split into P patches of size `patch_dim`; each patch is
    compared (cosine distance) against slerp-interpolated anchor frames,
    and a small per-patch MLP predicts a gated residual correction.

    Input: (B, C, H, W) feature map
    Mode:  'channel' — pool spatial dims, relay on (B, C), rescale the map
           'pixel'   — relay on (B*H*W, C); expensive but thorough
    """

    def __init__(self, channels, patch_dim=16, n_anchors=16, n_phases=3,
                 pw_hidden=32, gate_init=-3.0, mode='channel'):
        super().__init__()
        assert channels % patch_dim == 0
        self.channels = channels
        self.patch_dim = patch_dim
        self.n_patches = channels // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases
        self.mode = mode
        P, A, d = self.n_patches, n_anchors, patch_dim
        # Frozen reference frame ('home') plus a learnable copy ('anchors')
        # that may drift during training.
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())
        # Per-patch two-layer MLP, batched over patches via einsum.
        tri_dim = n_phases * A
        self.pw_w1 = nn.Parameter(torch.empty(P, tri_dim, pw_hidden))
        self.pw_b1 = nn.Parameter(torch.zeros(1, P, pw_hidden))
        self.pw_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))
        self.pw_b2 = nn.Parameter(torch.zeros(1, P, d))
        for p in range(P):
            nn.init.xavier_normal_(self.pw_w1.data[p])
            nn.init.xavier_normal_(self.pw_w2.data[p])
        self.pw_norm = nn.LayerNorm(d)
        # gate_init=-3 => sigmoid ~ 0.05: relay starts almost disabled.
        self.gates = nn.Parameter(torch.full((P,), gate_init))
        self.norm = nn.LayerNorm(channels)
        # The interpolation phases are constant; compute once instead of
        # rebuilding the linspace on every _relay_core call.
        self._phases = torch.linspace(0, 1, n_phases).tolist()

    def drift(self):
        """Angular distance (radians), per patch/anchor, from home frame."""
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))

    def at_phase(self, t):
        """Slerp between home (t=0) and current anchors (t=1)."""
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        so = omega.sin().clamp(min=1e-7)  # guard sin(0) division
        return torch.sin((1 - t) * omega) / so * h + torch.sin(t * omega) / so * c

    def _relay_core(self, x_flat):
        """x_flat: (N, C) -> (N, C); adds a gated geometric correction."""
        N, C = x_flat.shape
        P, A, d = self.n_patches, self.n_anchors, self.patch_dim
        x_n = self.norm(x_flat)
        patches = x_n.reshape(N, P, d)
        patches_n = F.normalize(patches, dim=-1)
        # Cosine-distance "triangulation" against anchors at each phase.
        tris = []
        for t in self._phases:
            at = F.normalize(self.at_phase(t), dim=-1)
            tris.append(1.0 - torch.einsum('npd,pad->npa', patches_n, at))
        tri = torch.cat(tris, dim=-1)  # (N, P, n_phases*A)
        h = F.gelu(torch.einsum('npt,pth->nph', tri, self.pw_w1) + self.pw_b1)
        pw = self.pw_norm(torch.einsum('nph,phd->npd', h, self.pw_w2) + self.pw_b2)
        g = self.gates.sigmoid().unsqueeze(0).unsqueeze(-1)  # per-patch gate
        blended = g * pw + (1 - g) * patches
        return x_flat + blended.reshape(N, C)

    def forward(self, x):
        """x: (B, C, H, W) -> (B, C, H, W)"""
        B, C, H, W = x.shape
        if self.mode == 'channel':
            # Global average pool -> relay -> broadcast back as a scale.
            pooled = x.mean(dim=(-2, -1))       # (B, C)
            relayed = self._relay_core(pooled)  # (B, C)
            # FIX: sign-aware epsilon. Pooled activations can be negative,
            # and the original `pooled + 1e-8` denominator explodes when
            # pooled ~= -1e-8 (epsilon pushed it TOWARD zero).
            eps = torch.full_like(pooled, 1e-8)
            denom = pooled + torch.where(pooled >= 0, eps, -eps)
            scale = (relayed / denom).unsqueeze(-1).unsqueeze(-1)
            return x * scale.clamp(-3, 3)  # prevent extreme scaling
        else:
            # Per-pixel relay on (B*H*W, C).
            x_flat = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
            out = self._relay_core(x_flat)
            return out.reshape(B, H, W, C).permute(0, 3, 1, 2)
# ──────────────────────────────────────────────────────────────────
# BUILDING BLOCKS
# ──────────────────────────────────────────────────────────────────
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of a scalar timestep.

    Maps t of shape (B,) to (B, dim): first dim//2 entries are sines,
    the rest cosines, over log-spaced frequencies from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half_dim = self.dim // 2
        # Log-spaced frequencies: exp(-k * log(10000)/(half_dim-1)).
        decay = -math.log(10000) / (half_dim - 1)
        freqs = torch.exp(decay * torch.arange(half_dim, device=t.device, dtype=t.dtype))
        angles = t.unsqueeze(-1) * freqs
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class AdaGroupNorm(nn.Module):
    """GroupNorm whose scale/shift are predicted from a conditioning vector.

    The projection is zero-initialized, so the module starts as a plain
    (non-affine) GroupNorm and learns its modulation from there.
    """

    def __init__(self, channels, cond_dim, n_groups=8):
        super().__init__()
        self.gn = nn.GroupNorm(min(n_groups, channels), channels, affine=False)
        self.proj = nn.Linear(cond_dim, channels * 2)
        # Identity modulation at init: scale delta = 0, shift = 0.
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, cond):
        normed = self.gn(x)
        # (B, 2C) -> (B, 2C, 1, 1) -> two (B, C, 1, 1) halves.
        params = self.proj(cond)[:, :, None, None]
        scale, shift = params.chunk(2, dim=1)
        return normed * (scale + 1) + shift
class ConvBlock(nn.Module):
    """ConvNeXt-style residual block with adaptive (conditioned) norm.

    Pipeline: depthwise 7x7 conv -> AdaGroupNorm(cond) -> 1x1 expand (4x)
    -> GELU -> 1x1 project, added back to the input. Optionally followed
    by a channel-mode ConstellationRelay.
    """

    def __init__(self, channels, cond_dim, use_relay=False):
        super().__init__()
        self.dw_conv = nn.Conv2d(channels, channels, 7, padding=3, groups=channels)
        self.norm = AdaGroupNorm(channels, cond_dim)
        self.pw1 = nn.Conv2d(channels, channels * 4, 1)
        self.pw2 = nn.Conv2d(channels * 4, channels, 1)
        self.act = nn.GELU()
        if use_relay:
            self.relay = ConstellationRelay(
                channels, patch_dim=min(16, channels),
                n_anchors=min(16, channels),
                n_phases=3, pw_hidden=32, gate_init=-3.0,
                mode='channel')
        else:
            self.relay = None

    def forward(self, x, cond):
        branch = self.dw_conv(x)
        branch = self.norm(branch, cond)
        branch = self.pw2(self.act(self.pw1(branch)))
        out = x + branch
        if self.relay is not None:
            out = self.relay(out)
        return out
class SelfAttnBlock(nn.Module):
    """Spatial self-attention over an (B, C, H, W) feature map.

    Pre-norm residual block; the output projection is zero-initialized,
    so at init the block is an exact identity (pure residual).
    """

    def __init__(self, channels, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = channels // n_heads
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.out = nn.Conv2d(channels, channels, 1)
        nn.init.zeros_(self.out.weight)
        nn.init.zeros_(self.out.bias)

    def forward(self, x):
        B, C, H, W = x.shape
        residual = x
        x = self.norm(x)
        qkv = self.qkv(x).reshape(B, 3, self.n_heads, self.head_dim, H * W)
        # FIX: scaled_dot_product_attention expects (..., tokens, embed).
        # The original passed (B, heads, head_dim, H*W), so attention mixed
        # head_dim "tokens" (and scaled by 1/sqrt(H*W)) instead of attending
        # across spatial positions. Transpose to (B, heads, H*W, head_dim).
        q, k, v = (qkv[:, i].transpose(-2, -1) for i in range(3))
        attn = F.scaled_dot_product_attention(q, k, v)
        # Back to channel-first layout: (B, heads, H*W, hd) -> (B, C, H, W).
        out = attn.transpose(-2, -1).reshape(B, C, H, W)
        return residual + self.out(out)
class Downsample(nn.Module):
    """Halve spatial resolution with a strided 3x3 convolution."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)
class Upsample(nn.Module):
    """Double spatial resolution: nearest-neighbor interp then a 3x3 conv."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(upsampled)
# ──────────────────────────────────────────────────────────────────
# FLOW MATCHING UNET
# ──────────────────────────────────────────────────────────────────
class FlowMatchUNet(nn.Module):
    """
    Clean UNet for flow matching.
    Explicit skip tracking — no dynamic insertion.

    Encoder: [64@32] -> down -> [128@16] -> down -> [256@8]
    Middle:  [256@8] with attention + relay
    Decoder: [256@8] -> up -> [128@16] -> up -> [64@32]

    forward(x, t, class_labels) returns a tensor shaped like x
    (the predicted velocity field for flow matching).
    """
    def __init__(
        self,
        in_channels=3,
        base_channels=64,
        channel_mults=(1, 2, 4),
        n_classes=10,
        cond_dim=256,
        use_relay=True,
    ):
        super().__init__()
        self.use_relay = use_relay
        self.channel_mults = channel_mults
        # Time + class conditioning (summed into one cond vector in forward)
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.GELU(),
            nn.Linear(cond_dim, cond_dim))
        self.class_emb = nn.Embedding(n_classes, cond_dim)
        # Input projection
        self.in_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)
        # Build encoder: 2 conv blocks per level, then downsample
        self.enc = nn.ModuleList()
        self.enc_down = nn.ModuleList()
        ch_in = base_channels
        enc_channels = [base_channels]  # track channels at each skip point
        for i, mult in enumerate(channel_mults):
            ch_out = base_channels * mult
            # When the channel count changes at this level, the first block
            # is a 1x1-projection + ConvBlock pair wrapped in a Sequential
            # (forward unpacks it manually to pass `cond`).
            self.enc.append(nn.ModuleList([
                ConvBlock(ch_in, cond_dim) if ch_in == ch_out
                else nn.Sequential(nn.Conv2d(ch_in, ch_out, 1),
                                   ConvBlock(ch_out, cond_dim)),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch_in = ch_out
            enc_channels.append(ch_out)
            if i < len(channel_mults) - 1:
                self.enc_down.append(Downsample(ch_out))
        # Middle: relays (if enabled) live only at the bottleneck blocks
        mid_ch = ch_in
        self.mid_block1 = ConvBlock(mid_ch, cond_dim, use_relay=use_relay)
        self.mid_attn = SelfAttnBlock(mid_ch, n_heads=4)
        self.mid_block2 = ConvBlock(mid_ch, cond_dim, use_relay=use_relay)
        # Build decoder: upsample, concat skip, 2 conv blocks per level
        self.dec_up = nn.ModuleList()
        self.dec_skip_proj = nn.ModuleList()
        self.dec = nn.ModuleList()
        for i in range(len(channel_mults) - 1, -1, -1):
            mult = channel_mults[i]
            ch_out = base_channels * mult
            # pop() mirrors the pop order of `skips` in forward; the very
            # first entry (the in_conv skip) is intentionally never used.
            skip_ch = enc_channels.pop()
            # Project concatenated channels
            self.dec_skip_proj.append(nn.Conv2d(ch_in + skip_ch, ch_out, 1))
            self.dec.append(nn.ModuleList([
                ConvBlock(ch_out, cond_dim),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch_in = ch_out
            if i > 0:
                self.dec_up.append(Upsample(ch_out))
        # Output head; zero-init so the net starts as the zero velocity field
        self.out_norm = nn.GroupNorm(8, ch_in)
        self.out_conv = nn.Conv2d(ch_in, in_channels, 3, padding=1)
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x, t, class_labels):
        # x: (B, C, H, W); t: (B,) in [0, 1]; class_labels: (B,) int64
        cond = self.time_emb(t) + self.class_emb(class_labels)
        h = self.in_conv(x)
        skips = [h]  # in_conv skip is tracked but never popped by the decoder
        # Encoder
        for i in range(len(self.channel_mults)):
            for block in self.enc[i]:
                if isinstance(block, ConvBlock):
                    h = block(h, cond)
                elif isinstance(block, nn.Sequential):
                    # Conv1x1 then ConvBlock (cond must be passed manually)
                    h = block[0](h)
                    h = block[1](h, cond)
                else:
                    h = block(h)
            skips.append(h)
            if i < len(self.enc_down):
                h = self.enc_down[i](h)
        # Middle
        h = self.mid_block1(h, cond)
        h = self.mid_attn(h)
        h = self.mid_block2(h, cond)
        # Decoder: skips popped deepest-first, mirroring the encoder
        for i in range(len(self.channel_mults)):
            skip = skips.pop()
            # Upsample first if needed (except first decoder level)
            if i > 0:
                h = self.dec_up[i - 1](h)
            h = torch.cat([h, skip], dim=1)
            h = self.dec_skip_proj[i](h)
            for block in self.dec[i]:
                h = block(h, cond)
        h = self.out_norm(h)
        h = F.silu(h)
        return self.out_conv(h)
# ──────────────────────────────────────────────────────────────────
# FLOW MATCHING TRAINING
# ──────────────────────────────────────────────────────────────────
# Hyperparams
BATCH = 128
EPOCHS = 50
LR = 3e-4
BASE_CH = 64
USE_RELAY = True
N_CLASSES = 10
SAMPLE_EVERY = 5  # epochs between sample grids
N_SAMPLE_STEPS = 50  # Euler ODE steps for sampling
print("=" * 70)
print("FLOW MATCHING + CONSTELLATION RELAY REGULATOR")
print(f" Dataset: CIFAR-10")
print(f" Base channels: {BASE_CH}")
print(f" Relay: {USE_RELAY}")
print(f" Flow matching: ODE (conditional)")
print(f" Sampler: Euler, {N_SAMPLE_STEPS} steps")
print(f" Device: {DEVICE}")
print("=" * 70)
# Data: random flips, then normalize to [-1, 1] (matches the flow prior range)
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_ds = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH, shuffle=True,
    num_workers=4, pin_memory=True, drop_last=True)
print(f" Train: {len(train_ds):,} images")
# Model
model = FlowMatchUNet(
    in_channels=3, base_channels=BASE_CH,
    channel_mults=(1, 2, 4), n_classes=N_CLASSES,
    cond_dim=256, use_relay=USE_RELAY
).to(DEVICE)
n_params = sum(p.numel() for p in model.parameters())
relay_params = sum(p.numel() for n, p in model.named_parameters() if 'relay' in n)
print(f" Total params: {n_params:,}")
print(f" Relay params: {relay_params:,} ({100*relay_params/n_params:.1f}%)")
# Count relay modules
n_relays = sum(1 for m in model.modules() if isinstance(m, ConstellationRelay))
print(f" Relay modules: {n_relays}")
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
# Per-step cosine decay over the whole run (scheduler.step() runs each batch)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=EPOCHS * len(train_loader), eta_min=1e-6)
# NOTE(review): GradScaler matters for float16 autocast; training below uses
# bfloat16, where loss scaling is effectively a no-op safeguard — confirm intent.
scaler = torch.amp.GradScaler("cuda")
os.makedirs("samples", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)
@torch.no_grad()
def sample(model, n_samples=64, n_steps=50, class_label=None):
    """Euler ODE sampling from t=1 (noise) to t=0 (data).

    Args:
        model: velocity-field network called as model(x, t, labels).
        n_samples: number of images to draw.
        n_steps: number of Euler steps from t=1 down to t=0.
        class_label: fixed class id for every sample; random classes if None.

    Returns:
        (x, labels): x is (n_samples, 3, 32, 32) clamped to [-1, 1];
        labels is the (n_samples,) class tensor used for conditioning.

    Note: leaves the model in eval mode; callers re-enable train mode.
    """
    # FIX: decorated with @torch.no_grad() — the original built an autograd
    # graph through all n_steps model calls, wasting memory during sampling.
    model.eval()
    B = n_samples
    x = torch.randn(B, 3, 32, 32, device=DEVICE)
    if class_label is not None:
        labels = torch.full((B,), class_label, dtype=torch.long, device=DEVICE)
    else:
        labels = torch.randint(0, N_CLASSES, (B,), device=DEVICE)
    dt = 1.0 / n_steps
    for step in range(n_steps):
        t_val = 1.0 - step * dt
        t = torch.full((B,), t_val, device=DEVICE)
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            v = model(x, t, labels)
        x = x - v * dt  # Euler step: x_{t-dt} = x_t - v * dt
    # Clamp only the final sample to the valid image range; intermediate ODE
    # states may legitimately leave [-1, 1].
    x = x.clamp(-1, 1)
    return x, labels
# ──────────────────────────────────────────────────────────────────
# TRAINING LOOP
# ──────────────────────────────────────────────────────────────────
print(f"\n{'='*70}")
print(f"TRAINING β {EPOCHS} epochs")
print(f"{'='*70}")
best_loss = float('inf')
gs = 0  # global step counter
for epoch in range(EPOCHS):
    model.train()  # sample() below switches to eval; re-enable each epoch
    t0 = time.time()
    total_loss = 0
    n = 0
    pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
    for images, labels in pbar:
        images = images.to(DEVICE, non_blocking=True)  # (B, 3, 32, 32) in [-1, 1]
        labels = labels.to(DEVICE, non_blocking=True)
        B = images.shape[0]
        # Flow matching: sample t, compute x_t and target velocity
        t = torch.rand(B, device=DEVICE)
        eps = torch.randn_like(images)
        # x_t = (1-t) * x_0 + t * eps
        t_b = t.view(B, 1, 1, 1)
        x_t = (1 - t_b) * images + t_b * eps
        # Target velocity: v = eps - x_0
        v_target = eps - images
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            v_pred = model(x_t, t, labels)
            loss = F.mse_loss(v_pred, v_target)
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        # Unscale before clipping so the 1.0 threshold applies to true grads
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()  # per-batch cosine schedule
        gs += 1
        total_loss += loss.item()
        n += 1
        if n % 20 == 0:
            pbar.set_postfix(loss=f"{total_loss/n:.4f}", lr=f"{scheduler.get_last_lr()[0]:.1e}")
    elapsed = time.time() - t0
    avg_loss = total_loss / n
    # Checkpoint (best by average training loss)
    mk = ""
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save({
            'state_dict': model.state_dict(),
            'epoch': epoch + 1,
            'loss': avg_loss,
            'use_relay': USE_RELAY,
        }, 'checkpoints/flow_match_best.pt')
        mk = " β "
    print(f" E{epoch+1:3d}: loss={avg_loss:.4f} lr={scheduler.get_last_lr()[0]:.1e} "
          f"({elapsed:.0f}s){mk}")
    # Sample a grid periodically (and after the first epoch)
    if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == 0:
        samples, sample_labels = sample(model, n_samples=64, n_steps=N_SAMPLE_STEPS)
        # Denormalize [-1, 1] -> [0, 1]
        samples = (samples + 1) / 2
        grid = make_grid(samples, nrow=8, normalize=False)
        save_image(grid, f'samples/epoch_{epoch+1:03d}.png')
        print(f" β Saved samples/epoch_{epoch+1:03d}.png")
        # Per-class samples (every other sampling epoch)
        if (epoch + 1) % (SAMPLE_EVERY * 2) == 0:
            class_names = ['plane', 'auto', 'bird', 'cat', 'deer',
                           'dog', 'frog', 'horse', 'ship', 'truck']
            for c in range(N_CLASSES):
                cs, _ = sample(model, n_samples=8, n_steps=N_SAMPLE_STEPS, class_label=c)
                cs = (cs + 1) / 2
                save_image(make_grid(cs, nrow=8),
                           f'samples/epoch_{epoch+1:03d}_class_{class_names[c]}.png')
    # Relay diagnostics: how far anchors drifted and how open the gates are
    if USE_RELAY and (epoch + 1) % 10 == 0:
        print(f" Relay diagnostics:")
        for name, module in model.named_modules():
            if isinstance(module, ConstellationRelay):
                drift = module.drift().mean().item()
                gate = module.gates.sigmoid().mean().item()
                print(f" {name}: drift={drift:.4f} rad "
                      f"({math.degrees(drift):.1f}Β°) gate={gate:.4f}")
print(f"\n{'='*70}")
print(f"DONE β Best loss: {best_loss:.4f}")
print(f" Params: {n_params:,} (relay: {relay_params:,})")
print(f" Samples in: samples/")
print(f"{'='*70}")