#!/usr/bin/env python3
"""
Flow Matching — Constellation Bottleneck
==========================================

The constellation IS the bottleneck. Not a regulator. Not a side channel.
All information passes through S^15 triangulation.

Architecture:
    Encoder: 3×32×32 → 64×32 → 128×16 → 256×8
    Bottleneck:
        flatten 256×8×8 = 16384 → Linear(16384, 256) → L2 normalize
        → Constellation: 16 patches × 16d, 16 anchors, 3 phases
        → Triangulation profile: 16 patches × 48 = 768 dims
        → Condition injection: concat(tri, time_emb, class_emb)
        → Patchwork MLP: 768+cond → 256 → 16384 → reshape 256×8×8
    Decoder: 256×8 → 128×16 → 64×32 → 3×32×32

The triangulation profile IS the representation. Time and class
conditioning enter at the triangulation level — they modulate what the
patchwork does with the geometric reading.
"""
import math
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Used to gate autocast / GradScaler so a CPU-only run works without warnings.
USE_CUDA = DEVICE == "cuda"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# ══════════════════════════════════════════════════════════════════
# CONSTELLATION BOTTLENECK
# ══════════════════════════════════════════════════════════════════

class ConstellationBottleneck(nn.Module):
    """
    The constellation as information bottleneck.

    Input:  (B, spatial_dim) flattened feature map
    Output: (B, spatial_dim) reconstructed through geometric encoding

    All information passes through S^(d-1) triangulation.
    Time + class conditioning injected at the triangulation level.
    """

    def __init__(
        self,
        spatial_dim,        # 256*8*8 = 16384
        embed_dim=256,      # project to this before sphere
        patch_dim=16,
        n_anchors=16,
        n_phases=3,
        cond_dim=256,
        pw_hidden=512,
    ):
        super().__init__()
        self.spatial_dim = spatial_dim
        self.embed_dim = embed_dim
        self.patch_dim = patch_dim
        self.n_patches = embed_dim // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases
        P, A, d = self.n_patches, n_anchors, patch_dim

        # Project feature map → embedding sphere
        self.proj_in = nn.Linear(spatial_dim, embed_dim)
        self.proj_in_norm = nn.LayerNorm(embed_dim)

        # Constellation anchors: `home` is a frozen reference configuration,
        # `anchors` is the learnable copy; `drift` measures how far the
        # learnable anchors have rotated away from home.
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        # Triangulation → total dims = P * (A * n_phases)
        tri_dim_per_patch = A * n_phases
        total_tri_dim = P * tri_dim_per_patch

        # Patchwork reads triangulation + conditioning.
        # This is where time and class information enter.
        pw_input = total_tri_dim + cond_dim
        self.patchwork = nn.Sequential(
            nn.Linear(pw_input, pw_hidden), nn.GELU(), nn.LayerNorm(pw_hidden),
            nn.Linear(pw_hidden, pw_hidden), nn.GELU(), nn.LayerNorm(pw_hidden),
            nn.Linear(pw_hidden, spatial_dim),
        )

        # Skip projection — residual through the bottleneck
        self.skip_proj = nn.Linear(spatial_dim, spatial_dim)
        self.skip_gate = nn.Parameter(torch.tensor(-2.0))  # sigmoid ≈ 0.12

    def drift(self):
        """Geodesic angle (radians), per patch/anchor, between home and learned anchors."""
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))

    def at_phase(self, t):
        """Slerp between home (t=0) and learned anchors (t=1) along the great circle."""
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        so = omega.sin().clamp(min=1e-7)  # guard sin(omega)≈0 when drift≈0
        return torch.sin((1 - t) * omega) / so * h + torch.sin(t * omega) / so * c

    def triangulate(self, emb_norm):
        """
        Multi-phase triangulation on the sphere.

        emb_norm: (B, P, d) normalized patches on S^(d-1)
        Returns:  (B, P * A * n_phases) full triangulation profile
        """
        phases = torch.linspace(0, 1, self.n_phases, device=emb_norm.device).tolist()
        tris = []
        for t in phases:
            anchors_t = F.normalize(self.at_phase(t), dim=-1)  # (P, A, d)
            cos = torch.einsum('bpd,pad->bpa', emb_norm, anchors_t)
            tris.append(1.0 - cos)  # cosine distance to each anchor
        # (B, P, A*phases) → flatten → (B, P*A*phases)
        tri = torch.cat(tris, dim=-1)
        return tri.reshape(emb_norm.shape[0], -1)

    def forward(self, x_flat, cond):
        """
        x_flat: (B, spatial_dim) — flattened bottleneck features
        cond:   (B, cond_dim) — time + class conditioning
        Returns: (B, spatial_dim)
        """
        B = x_flat.shape[0]

        # Project to embedding space → normalize to sphere
        emb = self.proj_in(x_flat)
        emb = self.proj_in_norm(emb)
        patches = emb.reshape(B, self.n_patches, self.patch_dim)
        patches_n = F.normalize(patches, dim=-1)  # on S^(d-1)

        # Triangulate — the geometric encoding
        tri_profile = self.triangulate(patches_n)  # (B, P*A*phases)

        # Inject conditioning at the triangulation level
        pw_input = torch.cat([tri_profile, cond], dim=-1)

        # Patchwork reads the geometric profile + conditioning
        decoded = self.patchwork(pw_input)  # (B, spatial_dim)

        # Gated skip connection through the bottleneck
        skip = self.skip_proj(x_flat)
        gate = self.skip_gate.sigmoid()
        return gate * skip + (1 - gate) * decoded


# ══════════════════════════════════════════════════════════════════
# UNET BUILDING BLOCKS
# ══════════════════════════════════════════════════════════════════

class SinusoidalPosEmb(nn.Module):
    """Standard transformer-style sinusoidal embedding of a scalar time t → (B, dim)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        emb = math.log(10000) / (half - 1)
        emb = torch.exp(torch.arange(half, device=t.device, dtype=t.dtype) * -emb)
        emb = t.unsqueeze(-1) * emb.unsqueeze(0)
        return torch.cat([emb.sin(), emb.cos()], dim=-1)


class AdaGroupNorm(nn.Module):
    """GroupNorm whose scale/shift are predicted from the conditioning vector.

    proj is zero-initialized so the block starts as a plain (affine-free)
    GroupNorm: scale=0, shift=0 → x * (1 + 0) + 0.
    """

    def __init__(self, channels, cond_dim, n_groups=8):
        super().__init__()
        self.gn = nn.GroupNorm(min(n_groups, channels), channels, affine=False)
        self.proj = nn.Linear(cond_dim, channels * 2)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, cond):
        x = self.gn(x)
        scale, shift = self.proj(cond).unsqueeze(-1).unsqueeze(-1).chunk(2, dim=1)
        return x * (1 + scale) + shift


class ConvBlock(nn.Module):
    """ConvNeXt-style residual block: depthwise 7×7 → AdaGN → pointwise MLP (×4)."""

    def __init__(self, channels, cond_dim):
        super().__init__()
        self.dw_conv = nn.Conv2d(channels, channels, 7, padding=3, groups=channels)
        self.norm = AdaGroupNorm(channels, cond_dim)
        self.pw1 = nn.Conv2d(channels, channels * 4, 1)
        self.pw2 = nn.Conv2d(channels * 4, channels, 1)
        self.act = nn.GELU()

    def forward(self, x, cond):
        residual = x
        x = self.dw_conv(x)
        x = self.norm(x, cond)
        x = self.pw1(x)
        x = self.act(x)
        x = self.pw2(x)
        return residual + x


class Downsample(nn.Module):
    """Strided 3×3 conv, halves spatial resolution."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


class Upsample(nn.Module):
    """Nearest-neighbor ×2 upsample followed by a 3×3 conv."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, padding=1)

    def forward(self, x):
        return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))


# ══════════════════════════════════════════════════════════════════
# FLOW MATCHING UNET WITH CONSTELLATION BOTTLENECK
# ══════════════════════════════════════════════════════════════════

class FlowMatchConstellationUNet(nn.Module):
    """
    UNet where the middle block IS the constellation. No attention.
    The constellation is the information bottleneck.

    32×32 → 16×16 → 8×8 → flatten → project → S^15 →
    triangulate → patchwork(tri + time + class) →
    project back → 8×8 → 16×16 → 32×32
    """

    def __init__(
        self,
        in_channels=3,
        base_ch=64,
        channel_mults=(1, 2, 4),
        n_classes=10,
        cond_dim=256,
        embed_dim=256,
        n_anchors=16,
        n_phases=3,
        pw_hidden=512,
    ):
        super().__init__()
        self.channel_mults = channel_mults

        # Conditioning: sinusoidal time embedding + learned class embedding,
        # summed into a single cond vector fed to every AdaGroupNorm.
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.GELU(),
            nn.Linear(cond_dim, cond_dim))
        self.class_emb = nn.Embedding(n_classes, cond_dim)

        # Input
        self.in_conv = nn.Conv2d(in_channels, base_ch, 3, padding=1)

        # Encoder
        self.enc = nn.ModuleList()
        self.enc_down = nn.ModuleList()
        ch = base_ch
        enc_channels = [base_ch]
        for i, mult in enumerate(channel_mults):
            ch_out = base_ch * mult
            # First block widens channels via a 1×1 conv when the stage changes width.
            self.enc.append(nn.ModuleList([
                ConvBlock(ch, cond_dim) if ch == ch_out
                else nn.Sequential(nn.Conv2d(ch, ch_out, 1),
                                   ConvBlock(ch_out, cond_dim)),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch = ch_out
            enc_channels.append(ch)
            if i < len(channel_mults) - 1:
                self.enc_down.append(Downsample(ch))

        # Constellation bottleneck.
        # At this point: (B, ch, 8, 8) where ch = base_ch * channel_mults[-1]
        mid_ch = ch
        spatial = 8 * 8  # after two downsamples from 32
        spatial_dim = mid_ch * spatial
        self.bottleneck = ConstellationBottleneck(
            spatial_dim=spatial_dim,
            embed_dim=embed_dim,
            patch_dim=16,
            n_anchors=n_anchors,
            n_phases=n_phases,
            cond_dim=cond_dim,
            pw_hidden=pw_hidden,
        )
        self.mid_ch = mid_ch
        self.mid_spatial = spatial

        # Decoder (mirrors the encoder, consuming skips in reverse order)
        self.dec_up = nn.ModuleList()
        self.dec_skip_proj = nn.ModuleList()
        self.dec = nn.ModuleList()
        for i in range(len(channel_mults) - 1, -1, -1):
            ch_out = base_ch * channel_mults[i]
            skip_ch = enc_channels.pop()
            self.dec_skip_proj.append(nn.Conv2d(ch + skip_ch, ch_out, 1))
            self.dec.append(nn.ModuleList([
                ConvBlock(ch_out, cond_dim),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch = ch_out
            if i > 0:
                self.dec_up.append(Upsample(ch))

        # Output head, zero-initialized so the initial velocity field is 0.
        self.out_norm = nn.GroupNorm(8, ch)
        self.out_conv = nn.Conv2d(ch, in_channels, 3, padding=1)
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x, t, class_labels):
        """Predict the flow-matching velocity field v(x, t | class)."""
        cond = self.time_emb(t) + self.class_emb(class_labels)

        h = self.in_conv(x)
        # One skip per encoder stage; the decoder pops them in reverse.
        skips = []

        # Encoder
        for i in range(len(self.channel_mults)):
            for block in self.enc[i]:
                if isinstance(block, ConvBlock):
                    h = block(h, cond)
                elif isinstance(block, nn.Sequential):
                    h = block[0](h)        # 1×1 channel widening
                    h = block[1](h, cond)  # ConvBlock needs cond
            skips.append(h)
            if i < len(self.enc_down):
                h = self.enc_down[i](h)

        # ★ CONSTELLATION BOTTLENECK ★
        B, C, H, W = h.shape
        h_flat = h.reshape(B, -1)                 # (B, C*H*W)
        h_flat = self.bottleneck(h_flat, cond)    # through S^15
        h = h_flat.reshape(B, C, H, W)

        # Decoder
        for i in range(len(self.channel_mults)):
            skip = skips.pop()
            if i > 0:
                h = self.dec_up[i - 1](h)
            h = torch.cat([h, skip], dim=1)
            h = self.dec_skip_proj[i](h)
            for block in self.dec[i]:
                h = block(h, cond)

        h = self.out_norm(h)
        h = F.silu(h)
        return self.out_conv(h)


# ══════════════════════════════════════════════════════════════════
# SAMPLING
# ══════════════════════════════════════════════════════════════════

@torch.no_grad()
def sample(model, n_samples=64, n_steps=50, class_label=None, n_classes=10):
    """Euler-integrate the learned ODE from t=1 (noise) toward t=0 (data).

    Returns (images clamped to [-1, 1], class labels used).
    """
    model.eval()
    x = torch.randn(n_samples, 3, 32, 32, device=DEVICE)
    if class_label is not None:
        labels = torch.full((n_samples,), class_label, dtype=torch.long, device=DEVICE)
    else:
        labels = torch.randint(0, n_classes, (n_samples,), device=DEVICE)

    dt = 1.0 / n_steps
    for step in range(n_steps):
        t_val = 1.0 - step * dt
        t = torch.full((n_samples,), t_val, device=DEVICE)
        # Autocast only when actually on CUDA; on CPU this is a no-op.
        with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=USE_CUDA):
            v = model(x, t, labels)
        x = x - v.float() * dt  # x_t moves against v = eps - x0 as t decreases
    return x.clamp(-1, 1), labels


# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════

BATCH = 128
EPOCHS = 50
LR = 3e-4
N_CLASSES = 10
SAMPLE_EVERY = 5


def main():
    """Train the model on CIFAR-10 and periodically write sample grids + checkpoints."""
    print("=" * 70)
    print("FLOW MATCHING — CONSTELLATION BOTTLENECK")
    print("  No attention. The constellation IS the bottleneck.")
    print(f"  Device: {DEVICE}")
    print("=" * 70)

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5,) * 3, (0.5,) * 3),
    ])
    train_ds = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=BATCH, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)

    model = FlowMatchConstellationUNet(
        in_channels=3,
        base_ch=64,
        channel_mults=(1, 2, 4),
        n_classes=N_CLASSES,
        cond_dim=256,
        embed_dim=256,
        n_anchors=16,
        n_phases=3,
        pw_hidden=512,
    ).to(DEVICE)

    n_params = sum(p.numel() for p in model.parameters())
    n_bottleneck = sum(p.numel() for p in model.bottleneck.parameters())
    print(f"  Total params: {n_params:,}")
    print(f"  Bottleneck params: {n_bottleneck:,} ({100*n_bottleneck/n_params:.1f}%)")
    print(f"  Train: {len(train_ds):,} images")

    # Verify shapes with a dummy forward pass before committing to training.
    with torch.no_grad():
        dummy = torch.randn(2, 3, 32, 32, device=DEVICE)
        t_dummy = torch.rand(2, device=DEVICE)
        c_dummy = torch.randint(0, 10, (2,), device=DEVICE)
        out = model(dummy, t_dummy, c_dummy)
        print(f"  Shape check: {dummy.shape} → {out.shape} ✓")

    # Show bottleneck info
    bn = model.bottleneck
    drift = bn.drift()
    print(f"  Bottleneck: {bn.spatial_dim}d → {bn.embed_dim}d sphere "
          f"→ {bn.n_patches}p × {bn.patch_dim}d × {bn.n_anchors}A × {bn.n_phases}ph "
          f"= {bn.n_patches * bn.n_anchors * bn.n_phases} tri dims")
    print(f"  Skip gate init: {bn.skip_gate.sigmoid().item():.4f}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=EPOCHS * len(train_loader), eta_min=1e-6)
    # NOTE: with bf16 autocast gradient scaling is not strictly required, but
    # scale→unscale is mathematically a no-op and kept for parity with fp16 runs.
    scaler = torch.amp.GradScaler("cuda", enabled=USE_CUDA)

    os.makedirs("samples_bn", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    print(f"\n{'='*70}")
    print(f"TRAINING — {EPOCHS} epochs")
    print(f"{'='*70}")

    best_loss = float('inf')
    for epoch in range(EPOCHS):
        model.train()
        t0 = time.time()
        total_loss = 0
        n = 0
        pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
        for images, labels in pbar:
            images = images.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)
            B = images.shape[0]

            # Linear interpolation path: x_t = (1-t)·data + t·noise,
            # target velocity v = d x_t / dt = eps - data.
            t = torch.rand(B, device=DEVICE)
            eps = torch.randn_like(images)
            t_b = t.view(B, 1, 1, 1)
            x_t = (1 - t_b) * images + t_b * eps
            v_target = eps - images

            with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=USE_CUDA):
                v_pred = model(x_t, t, labels)
                loss = F.mse_loss(v_pred, v_target)

            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale before clipping on true grads
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            total_loss += loss.item()
            n += 1
            if n % 20 == 0:
                pbar.set_postfix(loss=f"{total_loss/n:.4f}",
                                 lr=f"{scheduler.get_last_lr()[0]:.1e}")

        elapsed = time.time() - t0
        avg_loss = total_loss / n
        mk = ""
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                'state_dict': model.state_dict(),
                'epoch': epoch + 1,
                'loss': avg_loss,
            }, 'checkpoints/constellation_bn_best.pt')
            mk = " ★"

        print(f"  E{epoch+1:3d}: loss={avg_loss:.4f} lr={scheduler.get_last_lr()[0]:.1e} "
              f"({elapsed:.0f}s){mk}")

        # Diagnostics
        if (epoch + 1) % 10 == 0:
            bn = model.bottleneck
            drift = bn.drift()
            gate = bn.skip_gate.sigmoid().item()
            print(f"  Bottleneck: drift={drift.mean():.4f}rad "
                  f"({math.degrees(drift.mean().item()):.1f}°) "
                  f"max={drift.max():.4f}rad gate={gate:.4f}")

        # Sample
        if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == 0:
            imgs, _ = sample(model, 64, 50)
            imgs = (imgs + 1) / 2
            save_image(make_grid(imgs, nrow=8), f'samples_bn/epoch_{epoch+1:03d}.png')
            print(f"  → samples_bn/epoch_{epoch+1:03d}.png")

        if (epoch + 1) % (SAMPLE_EVERY * 2) == 0:
            class_names = ['plane', 'auto', 'bird', 'cat', 'deer',
                           'dog', 'frog', 'horse', 'ship', 'truck']
            for c in range(N_CLASSES):
                cs, _ = sample(model, 8, 50, class_label=c)
                save_image(make_grid((cs + 1) / 2, nrow=8),
                           f'samples_bn/epoch_{epoch+1:03d}_{class_names[c]}.png')

    print(f"\n{'='*70}")
    print(f"DONE — Best loss: {best_loss:.4f}")
    print(f"  Params: {n_params:,} (bottleneck: {n_bottleneck:,})")
    print(f"{'='*70}")


if __name__ == "__main__":
    main()