| |
| """ |
| Flow Matching β Constellation Bottleneck |
| ========================================== |
| The constellation IS the bottleneck. Not a regulator. Not a side channel. |
| All information passes through S^15 triangulation. |
| |
| Architecture: |
| Encoder: 3Γ32Γ32 β 64Γ32 β 128Γ16 β 256Γ8 |
| Bottleneck: |
| flatten 256Γ8Γ8 = 16384 β Linear(16384, 256) β L2 normalize |
| β Constellation: 16 patches Γ 16d, 16 anchors, 3 phases |
| β Triangulation profile: 16 patches Γ 48 = 768 dims |
| β Condition injection: concat(tri, time_emb, class_emb) |
| β Patchwork MLP: 768+cond β 256 β 16384 β reshape 256Γ8Γ8 |
| Decoder: 256Γ8 β 128Γ16 β 64Γ32 β 3Γ32Γ32 |
| |
| The triangulation profile IS the representation. |
| Time and class conditioning enter at the triangulation level β |
| they modulate what the patchwork does with the geometric reading. |
| """ |
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import time
from tqdm import tqdm
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# TF32 matmuls trade a little precision for large speedups on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
|
|
|
|
| |
| |
| |
|
|
class ConstellationBottleneck(nn.Module):
    """
    The constellation as information bottleneck.

    Input:  (B, spatial_dim) flattened feature map
    Output: (B, spatial_dim) reconstructed through geometric encoding

    All information passes through S^(d-1) triangulation.
    Time + class conditioning is injected at the triangulation level.
    """

    def __init__(
        self,
        spatial_dim,
        embed_dim=256,
        patch_dim=16,
        n_anchors=16,
        n_phases=3,
        cond_dim=256,
        pw_hidden=512,
    ):
        super().__init__()
        self.spatial_dim = spatial_dim
        self.embed_dim = embed_dim
        self.patch_dim = patch_dim
        self.n_patches = embed_dim // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases

        P, A, d = self.n_patches, n_anchors, patch_dim

        # Project the flattened feature map into the embedding space.
        self.proj_in = nn.Linear(spatial_dim, embed_dim)
        self.proj_in_norm = nn.LayerNorm(embed_dim)

        # Anchor constellation: a frozen "home" configuration (buffer) plus a
        # learnable copy that may drift away from home during training.
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        # One distance per anchor per phase per patch.
        tri_dim_per_patch = A * n_phases
        total_tri_dim = P * tri_dim_per_patch

        # Patchwork MLP: decodes (triangulation profile || conditioning)
        # back to the flattened feature map.
        pw_input = total_tri_dim + cond_dim
        self.patchwork = nn.Sequential(
            nn.Linear(pw_input, pw_hidden),
            nn.GELU(),
            nn.LayerNorm(pw_hidden),
            nn.Linear(pw_hidden, pw_hidden),
            nn.GELU(),
            nn.LayerNorm(pw_hidden),
            nn.Linear(pw_hidden, spatial_dim),
        )

        # Gated linear skip around the bottleneck; the gate starts near
        # sigmoid(-2) ~ 0.12 so most information must flow through the
        # constellation path at init.
        self.skip_proj = nn.Linear(spatial_dim, spatial_dim)
        self.skip_gate = nn.Parameter(torch.tensor(-2.0))

    def drift(self):
        """Geodesic angle (radians), shape (P, A), from home to current anchors."""
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        # Clamp keeps acos off the +/-1 endpoints where its gradient blows up.
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))

    def at_phase(self, t):
        """Slerp anchors between home (t=0) and learned positions (t=1); (P, A, d)."""
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        so = omega.sin().clamp(min=1e-7)  # guard the sin(0) division
        return torch.sin((1 - t) * omega) / so * h + torch.sin(t * omega) / so * c

    def triangulate(self, emb_norm):
        """
        Multi-phase triangulation on the sphere.
        emb_norm: (B, P, d) normalized patches on S^(d-1)
        Returns: (B, P * A * n_phases) full triangulation profile
        """
        phases = torch.linspace(0, 1, self.n_phases, device=emb_norm.device).tolist()
        tris = []
        for t in phases:
            anchors_t = F.normalize(self.at_phase(t), dim=-1)
            # Cosine distance (1 - cos) from every patch to every anchor.
            cos = torch.einsum('bpd,pad->bpa', emb_norm, anchors_t)
            tris.append(1.0 - cos)

        tri = torch.cat(tris, dim=-1)
        return tri.reshape(emb_norm.shape[0], -1)

    def forward(self, x_flat, cond):
        """
        x_flat: (B, spatial_dim) -- flattened bottleneck features
        cond:   (B, cond_dim)    -- time + class conditioning
        Returns: (B, spatial_dim)
        """
        B = x_flat.shape[0]

        # Embed and split into patches on the sphere.
        emb = self.proj_in(x_flat)
        emb = self.proj_in_norm(emb)
        patches = emb.reshape(B, self.n_patches, self.patch_dim)
        patches_n = F.normalize(patches, dim=-1)

        # The triangulation profile IS the representation.
        tri_profile = self.triangulate(patches_n)

        # Conditioning enters at the triangulation level.
        pw_input = torch.cat([tri_profile, cond], dim=-1)
        decoded = self.patchwork(pw_input)

        # Gated mix of the linear skip path and the decoded path.
        skip = self.skip_proj(x_flat)
        gate = self.skip_gate.sigmoid()
        return gate * skip + (1 - gate) * decoded
|
|
|
|
| |
| |
| |
|
|
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding: (B,) scalar t -> (B, dim)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # Log-spaced frequencies from 1 down to 1/10000.
        emb = math.log(10000) / (half - 1)
        emb = torch.exp(torch.arange(half, device=t.device, dtype=t.dtype) * -emb)
        emb = t.unsqueeze(-1) * emb.unsqueeze(0)
        # First half sin, second half cos -> total self.dim features.
        return torch.cat([emb.sin(), emb.cos()], dim=-1)
|
|
|
|
class AdaGroupNorm(nn.Module):
    """GroupNorm whose scale/shift are predicted from a conditioning vector."""

    def __init__(self, channels, cond_dim, n_groups=8):
        super().__init__()
        # affine=False: the learned affine comes entirely from `cond`.
        self.gn = nn.GroupNorm(min(n_groups, channels), channels, affine=False)
        self.proj = nn.Linear(cond_dim, channels * 2)
        # Zero init -> identity modulation at the start of training.
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, cond):
        x = self.gn(x)
        # (B, 2C) -> (B, 2C, 1, 1) -> scale, shift each (B, C, 1, 1).
        scale, shift = self.proj(cond).unsqueeze(-1).unsqueeze(-1).chunk(2, dim=1)
        return x * (1 + scale) + shift
|
|
|
|
class ConvBlock(nn.Module):
    """ConvNeXt-style residual block: depthwise 7x7 -> AdaGN -> 1x1 MLP (4x)."""

    def __init__(self, channels, cond_dim):
        super().__init__()
        self.dw_conv = nn.Conv2d(channels, channels, 7, padding=3, groups=channels)
        self.norm = AdaGroupNorm(channels, cond_dim)
        self.pw1 = nn.Conv2d(channels, channels * 4, 1)
        self.pw2 = nn.Conv2d(channels * 4, channels, 1)
        self.act = nn.GELU()

    def forward(self, x, cond):
        residual = x
        x = self.dw_conv(x)
        x = self.norm(x, cond)  # conditioning enters through the norm
        x = self.pw1(x)
        x = self.act(x)
        x = self.pw2(x)
        return residual + x
|
|
|
|
class Downsample(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 conv (channels unchanged)."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)
|
|
|
|
class Upsample(nn.Module):
    """Double spatial resolution: nearest-neighbor upsample then 3x3 conv."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, padding=1)

    def forward(self, x):
        return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))
|
|
|
|
| |
| |
| |
|
|
class FlowMatchConstellationUNet(nn.Module):
    """
    UNet where the middle block IS the constellation.
    No attention. The constellation is the information bottleneck.

    32x32 -> 16x16 -> 8x8 -> flatten -> project -> S^15 -> triangulate
    -> patchwork(tri + time + class) -> project back -> 8x8 -> 16x16 -> 32x32
    """

    def __init__(
        self,
        in_channels=3,
        base_ch=64,
        channel_mults=(1, 2, 4),
        n_classes=10,
        cond_dim=256,
        embed_dim=256,
        n_anchors=16,
        n_phases=3,
        pw_hidden=512,
    ):
        super().__init__()
        self.channel_mults = channel_mults

        # Time (sinusoidal + MLP) and class (embedding) conditioning.
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.GELU(),
            nn.Linear(cond_dim, cond_dim))
        self.class_emb = nn.Embedding(n_classes, cond_dim)

        self.in_conv = nn.Conv2d(in_channels, base_ch, 3, padding=1)

        # --- encoder ---
        self.enc = nn.ModuleList()
        self.enc_down = nn.ModuleList()
        ch = base_ch
        enc_channels = [base_ch]

        for i, mult in enumerate(channel_mults):
            ch_out = base_ch * mult
            # First block widens channels with a 1x1 conv when needed.
            self.enc.append(nn.ModuleList([
                ConvBlock(ch, cond_dim) if ch == ch_out
                else nn.Sequential(nn.Conv2d(ch, ch_out, 1), ConvBlock(ch_out, cond_dim)),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch = ch_out
            enc_channels.append(ch)
            if i < len(channel_mults) - 1:
                self.enc_down.append(Downsample(ch))

        # --- constellation bottleneck ---
        mid_ch = ch
        # NOTE(review): assumes a 32x32 input with len(channel_mults)-1 = 2
        # downsamples, leaving an 8x8 mid feature map -- confirm before
        # reusing with other resolutions or mult schedules.
        spatial = 8 * 8
        spatial_dim = mid_ch * spatial

        self.bottleneck = ConstellationBottleneck(
            spatial_dim=spatial_dim,
            embed_dim=embed_dim,
            patch_dim=16,
            n_anchors=n_anchors,
            n_phases=n_phases,
            cond_dim=cond_dim,
            pw_hidden=pw_hidden,
        )
        self.mid_ch = mid_ch
        self.mid_spatial = spatial

        # --- decoder (mirror of the encoder, with skip concatenation) ---
        self.dec_up = nn.ModuleList()
        self.dec_skip_proj = nn.ModuleList()
        self.dec = nn.ModuleList()

        for i in range(len(channel_mults) - 1, -1, -1):
            ch_out = base_ch * channel_mults[i]
            skip_ch = enc_channels.pop()
            # 1x1 conv fuses [decoder features || encoder skip] -> ch_out.
            self.dec_skip_proj.append(nn.Conv2d(ch + skip_ch, ch_out, 1))
            self.dec.append(nn.ModuleList([
                ConvBlock(ch_out, cond_dim),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch = ch_out
            if i > 0:
                self.dec_up.append(Upsample(ch))

        self.out_norm = nn.GroupNorm(8, ch)
        self.out_conv = nn.Conv2d(ch, in_channels, 3, padding=1)
        # Zero-init output so the model predicts zero velocity at step 0.
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x, t, class_labels):
        """
        x: (B, C, 32, 32) noisy image, t: (B,) flow time, class_labels: (B,) long.
        Returns the predicted velocity field, same shape as x.
        """
        cond = self.time_emb(t) + self.class_emb(class_labels)
        h = self.in_conv(x)
        # The in_conv activation is also stored; it is never popped below
        # (3 pops from 4 entries) and is kept only for symmetry.
        skips = [h]

        # Encoder: two cond-modulated blocks per level, then downsample.
        for i in range(len(self.channel_mults)):
            for block in self.enc[i]:
                if isinstance(block, ConvBlock):
                    h = block(h, cond)
                elif isinstance(block, nn.Sequential):
                    h = block[0](h)        # 1x1 channel widening
                    h = block[1](h, cond)  # conditioned ConvBlock
            skips.append(h)
            if i < len(self.enc_down):
                h = self.enc_down[i](h)

        # All information passes through the constellation bottleneck.
        B, C, H, W = h.shape
        h_flat = h.reshape(B, -1)
        h_flat = self.bottleneck(h_flat, cond)
        h = h_flat.reshape(B, C, H, W)

        # Decoder: upsample (except at the deepest level), fuse skip, refine.
        for i in range(len(self.channel_mults)):
            skip = skips.pop()
            if i > 0:
                h = self.dec_up[i - 1](h)
            h = torch.cat([h, skip], dim=1)
            h = self.dec_skip_proj[i](h)
            for block in self.dec[i]:
                h = block(h, cond)

        h = self.out_norm(h)
        h = F.silu(h)
        return self.out_conv(h)
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def sample(model, n_samples=64, n_steps=50, class_label=None, n_classes=10):
    """
    Euler-integrate the learned flow from noise (t=1) back to data (t=0).

    class_label: fix every sample to one class, or None for random classes.
    Returns (images clamped to [-1, 1], labels used).
    """
    model.eval()
    x = torch.randn(n_samples, 3, 32, 32, device=DEVICE)
    if class_label is not None:
        labels = torch.full((n_samples,), class_label, dtype=torch.long, device=DEVICE)
    else:
        labels = torch.randint(0, n_classes, (n_samples,), device=DEVICE)

    dt = 1.0 / n_steps
    for step in range(n_steps):
        t_val = 1.0 - step * dt  # walk t from 1 down to dt
        t = torch.full((n_samples,), t_val, device=DEVICE)
        # Fix: autocast on the actual device rather than hard-coded "cuda",
        # so sampling also runs on CPU-only machines.
        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
            v = model(x, t, labels)
        # Euler step backwards along v = eps - x0: x_{t-dt} = x_t - v*dt.
        x = x - v.float() * dt
    return x.clamp(-1, 1), labels
|
|
|
|
| |
| |
| |
|
|
# --- training hyperparameters ---
BATCH = 128
EPOCHS = 50
LR = 3e-4
N_CLASSES = 10
SAMPLE_EVERY = 5  # epochs between sample grids

print("=" * 70)
print("FLOW MATCHING - CONSTELLATION BOTTLENECK")
print("  No attention. The constellation IS the bottleneck.")
print(f"  Device: {DEVICE}")
print("=" * 70)

# CIFAR-10 scaled to [-1, 1] with horizontal-flip augmentation.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
])
train_ds = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH, shuffle=True,
    num_workers=4, pin_memory=True, drop_last=True)
|
|
model = FlowMatchConstellationUNet(
    in_channels=3, base_ch=64, channel_mults=(1, 2, 4),
    n_classes=N_CLASSES, cond_dim=256, embed_dim=256,
    n_anchors=16, n_phases=3, pw_hidden=512,
).to(DEVICE)

n_params = sum(p.numel() for p in model.parameters())
n_bottleneck = sum(p.numel() for p in model.bottleneck.parameters())
print(f"  Total params: {n_params:,}")
print(f"  Bottleneck params: {n_bottleneck:,} ({100 * n_bottleneck / n_params:.1f}%)")
print(f"  Train: {len(train_ds):,} images")

# Smoke test: one forward pass to verify shapes before training starts.
with torch.no_grad():
    dummy = torch.randn(2, 3, 32, 32, device=DEVICE)
    t_dummy = torch.rand(2, device=DEVICE)
    c_dummy = torch.randint(0, 10, (2,), device=DEVICE)
    out = model(dummy, t_dummy, c_dummy)
    print(f"  Shape check: {dummy.shape} -> {out.shape} OK")

# Report the bottleneck geometry and its initial skip-gate value.
bn = model.bottleneck
drift = bn.drift()
print(f"  Bottleneck: {bn.spatial_dim}d -> {bn.embed_dim}d sphere "
      f"-> {bn.n_patches}p x {bn.patch_dim}d x {bn.n_anchors}A x {bn.n_phases}ph "
      f"= {bn.n_patches * bn.n_anchors * bn.n_phases} tri dims")
print(f"  Skip gate init: {bn.skip_gate.sigmoid().item():.4f}")

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=EPOCHS * len(train_loader), eta_min=1e-6)
# NOTE(review): GradScaler exists for float16; under bfloat16 autocast the
# scaling is effectively a no-op. Kept for parity with the original setup.
scaler = torch.amp.GradScaler("cuda")
|
|
| os.makedirs("samples_bn", exist_ok=True) |
| os.makedirs("checkpoints", exist_ok=True) |
|
|
| print(f"\n{'='*70}") |
| print(f"TRAINING β {EPOCHS} epochs") |
| print(f"{'='*70}") |
|
|
| best_loss = float('inf') |
|
|
for epoch in range(EPOCHS):
    model.train()
    t0 = time.time()
    total_loss = 0
    n = 0

    pbar = tqdm(train_loader, desc=f"E{epoch + 1:3d}/{EPOCHS}", unit="b")
    for images, labels in pbar:
        images = images.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)
        B = images.shape[0]

        # Rectified-flow pairing: x_t = (1-t)*x0 + t*eps, target v = eps - x0.
        t = torch.rand(B, device=DEVICE)
        eps = torch.randn_like(images)
        t_b = t.view(B, 1, 1, 1)
        x_t = (1 - t_b) * images + t_b * eps
        v_target = eps - images

        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            v_pred = model(x_t, t, labels)
            loss = F.mse_loss(v_pred, v_target)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        # Unscale before clipping so the clip threshold applies to true grads.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()
        n += 1
        if n % 20 == 0:
            pbar.set_postfix(loss=f"{total_loss / n:.4f}", lr=f"{scheduler.get_last_lr()[0]:.1e}")

    elapsed = time.time() - t0
    avg_loss = total_loss / n

    # Checkpoint whenever the average epoch loss improves.
    mk = ""
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save({
            'state_dict': model.state_dict(),
            'epoch': epoch + 1,
            'loss': avg_loss,
        }, 'checkpoints/constellation_bn_best.pt')
        mk = " *"  # marks epochs where the best checkpoint was updated

    print(f"  E{epoch + 1:3d}: loss={avg_loss:.4f} lr={scheduler.get_last_lr()[0]:.1e} "
          f"({elapsed:.0f}s){mk}")

    # Periodic bottleneck diagnostics: anchor drift + skip-gate opening.
    if (epoch + 1) % 10 == 0:
        bn = model.bottleneck
        drift = bn.drift()
        gate = bn.skip_gate.sigmoid().item()
        print(f"    Bottleneck: drift={drift.mean():.4f}rad ({math.degrees(drift.mean()):.1f}deg) "
              f"max={drift.max():.4f}rad gate={gate:.4f}")

    # Periodic sample grid with random classes (also after the first epoch).
    if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == 0:
        imgs, _ = sample(model, 64, 50)
        imgs = (imgs + 1) / 2  # [-1, 1] -> [0, 1] for saving
        save_image(make_grid(imgs, nrow=8), f'samples_bn/epoch_{epoch + 1:03d}.png')
        print(f"    -> samples_bn/epoch_{epoch + 1:03d}.png")

    # Less frequently: one row of class-conditional samples per class.
    if (epoch + 1) % (SAMPLE_EVERY * 2) == 0:
        class_names = ['plane', 'auto', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck']
        for c in range(N_CLASSES):
            cs, _ = sample(model, 8, 50, class_label=c)
            save_image(make_grid((cs + 1) / 2, nrow=8),
                       f'samples_bn/epoch_{epoch + 1:03d}_{class_names[c]}.png')
|
print(f"\n{'=' * 70}")
print(f"DONE - Best loss: {best_loss:.4f}")
print(f"  Params: {n_params:,} (bottleneck: {n_bottleneck:,})")
print(f"{'=' * 70}")