AbstractPhil's picture
Create trainer_model.py
cf76b2b verified
#!/usr/bin/env python3
"""
Flow Matching — Constellation Bottleneck
==========================================
The constellation IS the bottleneck. Not a regulator. Not a side channel.
All information passes through S^15 triangulation.
Architecture:
Encoder: 3×32×32 → 64×32 → 128×16 → 256×8
Bottleneck:
flatten 256×8×8 = 16384 → Linear(16384, 256) → LayerNorm → per-patch L2 normalize
→ Constellation: 16 patches × 16d, 16 anchors, 3 phases
→ Triangulation profile: 16 patches × 48 = 768 dims
→ Condition injection: concat(tri, time_emb + class_emb)
→ Patchwork MLP: 768+cond → 512 → 512 → 16384 → reshape 256×8×8
Decoder: 256×8 → 128×16 → 64×32 → 3×32×32
The triangulation profile IS the representation.
Time and class conditioning enter at the triangulation level —
they modulate what the patchwork does with the geometric reading.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import time
from tqdm import tqdm
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
# Let matmuls and cuDNN convolutions use TF32 kernels where the GPU offers
# them (faster, slightly reduced precision). Harmless when no GPU is present.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# Device string used by the rest of the script; falls back to CPU when CUDA
# is not available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ══════════════════════════════════════════════════════════════════
# CONSTELLATION BOTTLENECK
# ══════════════════════════════════════════════════════════════════
class ConstellationBottleneck(nn.Module):
    """
    The constellation as information bottleneck.

    Input:  (B, spatial_dim) flattened feature map
    Output: (B, spatial_dim) reconstructed through geometric encoding

    Pipeline (see ``forward``):
      1. ``proj_in`` followed by LayerNorm maps the input to ``embed_dim``.
      2. The embedding is split into ``embed_dim // patch_dim`` patches, each
         L2-normalised onto S^(patch_dim-1).
      3. Every patch is compared against its ``n_anchors`` anchors at
         ``n_phases`` evenly spaced points on the great-circle path from the
         frozen ``home`` position to the learnable ``anchors`` position; the
         feature is ``1 - cosine similarity``.
      4. The flattened profile is concatenated with the conditioning vector
         (time + class) and decoded by the ``patchwork`` MLP.
      5. The decoding is blended with a linear skip of the raw input through
         a single learned scalar gate.

    Args:
        spatial_dim: length of the flattened feature map (e.g. 256*8*8).
        embed_dim: width of the projection that is placed on the sphere.
        patch_dim: dimensionality of each patch; must divide ``embed_dim``.
        n_anchors: anchors per patch.
        n_phases: number of interpolation points between home and anchor.
        cond_dim: width of the conditioning vector passed to ``forward``.
        pw_hidden: hidden width of the patchwork MLP.

    Raises:
        ValueError: if ``embed_dim`` is not a positive multiple of
            ``patch_dim`` or if ``n_phases`` is smaller than 1.
    """

    def __init__(
        self,
        spatial_dim,      # 256*8*8 = 16384
        embed_dim=256,    # project to this before sphere
        patch_dim=16,
        n_anchors=16,
        n_phases=3,
        cond_dim=256,
        pw_hidden=512,
    ):
        super().__init__()
        # Fail at construction time instead of with an opaque reshape/cat
        # error on the first forward pass.
        if patch_dim <= 0 or embed_dim <= 0 or embed_dim % patch_dim != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be a positive multiple of "
                f"patch_dim ({patch_dim})"
            )
        if n_phases < 1:
            raise ValueError(f"n_phases must be >= 1, got {n_phases}")

        self.spatial_dim = spatial_dim
        self.embed_dim = embed_dim
        self.patch_dim = patch_dim
        self.n_patches = embed_dim // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases
        P, A, d = self.n_patches, n_anchors, patch_dim

        # Project feature map → embedding sphere
        self.proj_in = nn.Linear(spatial_dim, embed_dim)
        self.proj_in_norm = nn.LayerNorm(embed_dim)

        # Constellation anchors: `home` is a frozen copy of the initial
        # positions (buffer), `anchors` is the trainable copy that may drift.
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        # Triangulation → total dims = P * (A * n_phases)
        tri_dim_per_patch = A * n_phases
        total_tri_dim = P * tri_dim_per_patch

        # Patchwork reads triangulation + conditioning.
        # This is where time and class information enter.
        pw_input = total_tri_dim + cond_dim
        self.patchwork = nn.Sequential(
            nn.Linear(pw_input, pw_hidden),
            nn.GELU(),
            nn.LayerNorm(pw_hidden),
            nn.Linear(pw_hidden, pw_hidden),
            nn.GELU(),
            nn.LayerNorm(pw_hidden),
            nn.Linear(pw_hidden, spatial_dim),
        )

        # Skip projection — residual through the bottleneck.
        # NOTE: this is a dense spatial_dim x spatial_dim layer, so its weight
        # count grows quadratically with the flattened feature size.
        self.skip_proj = nn.Linear(spatial_dim, spatial_dim)
        self.skip_gate = nn.Parameter(torch.tensor(-2.0))  # sigmoid ≈ 0.12

    def drift(self):
        """Return the (P, A) angle in radians between each anchor and its home."""
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        # Clamp keeps acos (and its gradient) finite when the vectors coincide.
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))

    def at_phase(self, t):
        """
        Spherical interpolation between home (t=0) and current anchors (t=1).

        t: Python scalar interpolation weight.
        Returns: (P, A, d) interpolated anchor positions (not re-normalised).
        """
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        # Lower bound avoids dividing by zero for a vanishing angle.
        so = omega.sin().clamp(min=1e-7)
        return torch.sin((1 - t) * omega) / so * h + torch.sin(t * omega) / so * c

    def triangulate(self, emb_norm):
        """
        Multi-phase triangulation on the sphere.

        emb_norm: (B, P, d) normalized patches on S^(d-1)
        Returns: (B, P * A * n_phases) full triangulation profile; within each
                 patch the features are ordered phase by phase.
        """
        # Built on the CPU on purpose: materialising the phase list from a
        # device tensor would force a host-device sync on every forward.
        phases = torch.linspace(0, 1, self.n_phases).tolist()
        tris = []
        for t in phases:
            anchors_t = F.normalize(self.at_phase(t), dim=-1)  # (P, A, d)
            cos = torch.einsum('bpd,pad->bpa', emb_norm, anchors_t)
            tris.append(1.0 - cos)
        # (B, P, A*phases) → flatten → (B, P*A*phases)
        tri = torch.cat(tris, dim=-1)
        return tri.reshape(emb_norm.shape[0], -1)

    def forward(self, x_flat, cond):
        """
        x_flat: (B, spatial_dim) — flattened bottleneck features
        cond: (B, cond_dim) — time + class conditioning
        Returns: (B, spatial_dim)
        """
        B = x_flat.shape[0]

        # Project to embedding space → normalize each patch onto the sphere
        emb = self.proj_in(x_flat)
        emb = self.proj_in_norm(emb)
        patches = emb.reshape(B, self.n_patches, self.patch_dim)
        patches_n = F.normalize(patches, dim=-1)  # on S^(d-1)

        # Triangulate — the geometric encoding
        tri_profile = self.triangulate(patches_n)  # (B, P*A*phases)

        # Inject conditioning at the triangulation level
        pw_input = torch.cat([tri_profile, cond], dim=-1)

        # Patchwork reads the geometric profile + conditioning
        decoded = self.patchwork(pw_input)  # (B, spatial_dim)

        # Gated skip connection through the bottleneck
        skip = self.skip_proj(x_flat)
        gate = self.skip_gate.sigmoid()
        return gate * skip + (1 - gate) * decoded
# ══════════════════════════════════════════════════════════════════
# UNET BUILDING BLOCKS
# ══════════════════════════════════════════════════════════════════
class SinusoidalPosEmb(nn.Module):
    """
    Fixed (non-learned) sinusoidal embedding of one scalar per sample.

    ``forward`` maps ``t`` of shape (B,) to (B, dim): the first ``dim // 2``
    columns are ``sin(t * f_k)`` and the next ``dim // 2`` are
    ``cos(t * f_k)``, with frequencies ``f_k`` spaced geometrically from 1
    down to 1/10000. For odd ``dim`` a trailing zero column is appended so
    the output width always equals ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # With a single frequency (dim of 2 or 3) `half - 1` is zero; the
        # lower bound keeps the division defined and yields f_0 = 1.
        log_step = math.log(10000) / max(half - 1, 1)
        freqs = torch.exp(
            torch.arange(half, device=t.device, dtype=t.dtype) * -log_step
        )
        angles = t.unsqueeze(-1) * freqs.unsqueeze(0)
        emb = torch.cat([angles.sin(), angles.cos()], dim=-1)
        if self.dim % 2 == 1:
            # Pad the last dimension on the right so downstream layers sized
            # for `dim` receive exactly `dim` features.
            emb = F.pad(emb, (0, 1))
        return emb
class AdaGroupNorm(nn.Module):
    """
    GroupNorm without its own affine parameters, modulated by a conditioning
    vector: ``GroupNorm(x) * (1 + scale) + shift`` where ``scale`` and
    ``shift`` are per-channel values predicted from ``cond``.

    The predictor is zero-initialised, so the layer starts out as a plain
    (non-affine) GroupNorm.
    """

    def __init__(self, channels, cond_dim, n_groups=8):
        super().__init__()
        # Never ask for more groups than there are channels.
        groups = min(n_groups, channels)
        self.gn = nn.GroupNorm(groups, channels, affine=False)
        self.proj = nn.Linear(cond_dim, channels * 2)
        for param in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(param)

    def forward(self, x, cond):
        normed = self.gn(x)
        # Append two singleton axes so the modulation broadcasts over H and W,
        # then split the channel axis into the scale half and the shift half.
        modulation = self.proj(cond)[..., None, None]
        scale, shift = modulation.chunk(2, dim=1)
        return normed * (1 + scale) + shift
class ConvBlock(nn.Module):
    """
    Residual block: 7x7 depthwise conv → conditioned group norm →
    1x1 expansion to 4x channels → GELU → 1x1 projection back, added to the
    block input. Channel count and spatial size are unchanged.
    """

    def __init__(self, channels, cond_dim):
        super().__init__()
        expanded = channels * 4
        # groups=channels makes the 7x7 convolution depthwise.
        self.dw_conv = nn.Conv2d(channels, channels, 7, padding=3, groups=channels)
        self.norm = AdaGroupNorm(channels, cond_dim)
        self.pw1 = nn.Conv2d(channels, expanded, 1)
        self.pw2 = nn.Conv2d(expanded, channels, 1)
        self.act = nn.GELU()

    def forward(self, x, cond):
        branch = self.norm(self.dw_conv(x), cond)
        branch = self.pw2(self.act(self.pw1(branch)))
        return x + branch
class Downsample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        reduced = self.conv(x)
        return reduced
class Upsample(nn.Module):
    """Double the spatial resolution (nearest neighbour), then a 3x3 conv."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, kernel_size=3, padding=1)

    def forward(self, x):
        enlarged = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(enlarged)
# ══════════════════════════════════════════════════════════════════
# FLOW MATCHING UNET WITH CONSTELLATION BOTTLENECK
# ══════════════════════════════════════════════════════════════════
class FlowMatchConstellationUNet(nn.Module):
    """
    UNet where the middle block IS the constellation.
    No attention. The constellation is the information bottleneck.
    32×32 → 16×16 → 8×8 → flatten → project → S^15 → triangulate
    → patchwork(tri + time + class) → project back → 8×8 → 16×16 → 32×32

    Args:
        in_channels: channels of the input (and of the predicted output).
        base_ch: channel width at full resolution.
        channel_mults: per-level multipliers of ``base_ch``; there is one
            2x downsample between consecutive levels.
        n_classes: size of the class-embedding table.
        cond_dim: width of the conditioning vector (time + class embedding).
        embed_dim, n_anchors, n_phases, pw_hidden: forwarded to
            ``ConstellationBottleneck`` (its ``patch_dim`` is fixed at 16).
        image_size: height/width of the square inputs. The bottleneck's
            linear layers are sized from it, so inputs at ``forward`` time
            must have this resolution.

    Raises:
        ValueError: if ``channel_mults`` is empty or ``image_size`` is not a
            positive multiple of ``2 ** (len(channel_mults) - 1)``.
    """

    def __init__(
        self,
        in_channels=3,
        base_ch=64,
        channel_mults=(1, 2, 4),
        n_classes=10,
        cond_dim=256,
        embed_dim=256,
        n_anchors=16,
        n_phases=3,
        pw_hidden=512,
        image_size=32,
    ):
        super().__init__()
        if len(channel_mults) == 0:
            raise ValueError("channel_mults must contain at least one multiplier")
        n_down = len(channel_mults) - 1
        if image_size <= 0 or image_size % (2 ** n_down) != 0:
            # Each decoder upsample exactly doubles the size, so the encoder
            # halvings must be exact for the skip tensors to line up.
            raise ValueError(
                f"image_size ({image_size}) must be a positive multiple of "
                f"{2 ** n_down} for {n_down} downsampling stage(s)"
            )
        self.channel_mults = channel_mults
        self.image_size = image_size

        # Conditioning
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.GELU(),
            nn.Linear(cond_dim, cond_dim))
        self.class_emb = nn.Embedding(n_classes, cond_dim)

        # Input
        self.in_conv = nn.Conv2d(in_channels, base_ch, 3, padding=1)

        # Encoder
        self.enc = nn.ModuleList()
        self.enc_down = nn.ModuleList()
        ch = base_ch
        enc_channels = []  # output width of each encoder level, for the skips
        for i, mult in enumerate(channel_mults):
            ch_out = base_ch * mult
            self.enc.append(nn.ModuleList([
                # When the width changes, a 1x1 conv adapts it before the block.
                ConvBlock(ch, cond_dim) if ch == ch_out
                else nn.Sequential(nn.Conv2d(ch, ch_out, 1), ConvBlock(ch_out, cond_dim)),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch = ch_out
            enc_channels.append(ch)
            if i < len(channel_mults) - 1:
                self.enc_down.append(Downsample(ch))

        # Constellation bottleneck
        # At this point: (B, ch, side, side) where ch = base_ch * channel_mults[-1]
        # and side = image_size / 2**n_down (8 for the default 32 / 2**2).
        mid_ch = ch
        side = image_size // (2 ** n_down)
        spatial = side * side
        spatial_dim = mid_ch * spatial
        self.bottleneck = ConstellationBottleneck(
            spatial_dim=spatial_dim,
            embed_dim=embed_dim,
            patch_dim=16,
            n_anchors=n_anchors,
            n_phases=n_phases,
            cond_dim=cond_dim,
            pw_hidden=pw_hidden,
        )
        self.mid_ch = mid_ch
        self.mid_spatial = spatial

        # Decoder (coarsest level first; one skip is consumed per level)
        self.dec_up = nn.ModuleList()
        self.dec_skip_proj = nn.ModuleList()
        self.dec = nn.ModuleList()
        for i in reversed(range(len(channel_mults))):
            ch_out = base_ch * channel_mults[i]
            skip_ch = enc_channels.pop()
            # 1x1 conv fuses [decoder features, encoder skip] down to ch_out.
            self.dec_skip_proj.append(nn.Conv2d(ch + skip_ch, ch_out, 1))
            self.dec.append(nn.ModuleList([
                ConvBlock(ch_out, cond_dim),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch = ch_out
            if i > 0:
                self.dec_up.append(Upsample(ch))

        # Output: zero-initialised so the network predicts exactly zero at init.
        self.out_norm = nn.GroupNorm(8, ch)
        self.out_conv = nn.Conv2d(ch, in_channels, 3, padding=1)
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x, t, class_labels):
        """
        x: (B, in_channels, image_size, image_size) input
        t: (B,) time value per sample
        class_labels: (B,) integer class indices for the embedding table
        Returns: tensor with ``in_channels`` channels at the input resolution.
        """
        cond = self.time_emb(t) + self.class_emb(class_labels)
        h = self.in_conv(x)
        skips = []

        # Encoder
        for level, blocks in enumerate(self.enc):
            for block in blocks:
                if isinstance(block, nn.Sequential):
                    # (1x1 conv, ConvBlock): only the ConvBlock takes `cond`.
                    h = block[0](h)
                    h = block[1](h, cond)
                else:
                    h = block(h, cond)
            skips.append(h)
            if level < len(self.enc_down):
                h = self.enc_down[level](h)

        # ★ CONSTELLATION BOTTLENECK ★
        B, C, H, W = h.shape
        h_flat = h.reshape(B, -1)  # (B, C*H*W)
        h_flat = self.bottleneck(h_flat, cond)  # through S^15
        h = h_flat.reshape(B, C, H, W)

        # Decoder
        for level, (fuse, blocks) in enumerate(zip(self.dec_skip_proj, self.dec)):
            skip = skips.pop()
            if level > 0:
                h = self.dec_up[level - 1](h)
            h = fuse(torch.cat([h, skip], dim=1))
            for block in blocks:
                h = block(h, cond)

        h = self.out_norm(h)
        h = F.silu(h)
        return self.out_conv(h)
# ══════════════════════════════════════════════════════════════════
# SAMPLING
# ══════════════════════════════════════════════════════════════════
@torch.no_grad()
def sample(model, n_samples=64, n_steps=50, class_label=None, n_classes=10):
model.eval()
x = torch.randn(n_samples, 3, 32, 32, device=DEVICE)
if class_label is not None:
labels = torch.full((n_samples,), class_label, dtype=torch.long, device=DEVICE)
else:
labels = torch.randint(0, n_classes, (n_samples,), device=DEVICE)
dt = 1.0 / n_steps
for step in range(n_steps):
t_val = 1.0 - step * dt
t = torch.full((n_samples,), t_val, device=DEVICE)
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
v = model(x, t, labels)
x = x - v.float() * dt
return x.clamp(-1, 1), labels
# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════
BATCH = 128
EPOCHS = 50
LR = 3e-4
N_CLASSES = 10
SAMPLE_EVERY = 5


def main():
    """
    Train the constellation-bottleneck flow-matching UNet on CIFAR-10.

    Downloads CIFAR-10 to ./data, trains for EPOCHS epochs with AdamW and a
    per-step cosine schedule, saves the checkpoint with the lowest mean
    epoch training loss to checkpoints/constellation_bn_best.pt, and
    periodically writes sample grids to samples_bn/.

    Kept inside a function behind the ``__main__`` guard because the
    DataLoader uses worker processes: on platforms that start workers by
    re-importing this module, unguarded top-level training code would be
    executed again in every worker.
    """
    # Mixed precision is only meaningful on CUDA; disabling it explicitly
    # elsewhere keeps the numerics identical to a disabled autocast/scaler.
    use_amp = DEVICE == "cuda"

    print("=" * 70)
    print("FLOW MATCHING — CONSTELLATION BOTTLENECK")
    print(" No attention. The constellation IS the bottleneck.")
    print(f" Device: {DEVICE}")
    print("=" * 70)

    # Images are scaled to [-1, 1], matching the clamp range used by sample().
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5,) * 3, (0.5,) * 3),
    ])
    train_ds = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=BATCH, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)

    model = FlowMatchConstellationUNet(
        in_channels=3, base_ch=64, channel_mults=(1, 2, 4),
        n_classes=N_CLASSES, cond_dim=256, embed_dim=256,
        n_anchors=16, n_phases=3, pw_hidden=512,
    ).to(DEVICE)
    n_params = sum(p.numel() for p in model.parameters())
    n_bottleneck = sum(p.numel() for p in model.bottleneck.parameters())
    print(f" Total params: {n_params:,}")
    print(f" Bottleneck params: {n_bottleneck:,} ({100*n_bottleneck/n_params:.1f}%)")
    print(f" Train: {len(train_ds):,} images")

    # Verify shapes
    with torch.no_grad():
        dummy = torch.randn(2, 3, 32, 32, device=DEVICE)
        t_dummy = torch.rand(2, device=DEVICE)
        c_dummy = torch.randint(0, N_CLASSES, (2,), device=DEVICE)
        out = model(dummy, t_dummy, c_dummy)
    print(f" Shape check: {dummy.shape} → {out.shape} ✓")

    # Show bottleneck info
    bn = model.bottleneck
    print(f" Bottleneck: {bn.spatial_dim}d → {bn.embed_dim}d sphere "
          f"→ {bn.n_patches}p × {bn.patch_dim}d × {bn.n_anchors}A × {bn.n_phases}ph "
          f"= {bn.n_patches * bn.n_anchors * bn.n_phases} tri dims")
    print(f" Skip gate init: {bn.skip_gate.sigmoid().item():.4f}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
    # The scheduler is stepped once per batch, hence T_max in optimizer steps.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=EPOCHS * len(train_loader), eta_min=1e-6)
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    os.makedirs("samples_bn", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    print(f"\n{'='*70}")
    print(f"TRAINING — {EPOCHS} epochs")
    print(f"{'='*70}")

    class_names = ['plane', 'auto', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']
    best_loss = float('inf')
    for epoch in range(EPOCHS):
        model.train()
        t0 = time.time()
        total_loss = 0
        n = 0
        pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
        for images, labels in pbar:
            images = images.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)
            B = images.shape[0]

            # Linear path from data (t=0) to noise (t=1); the regression
            # target is its time derivative, eps - images.
            t = torch.rand(B, device=DEVICE)
            eps = torch.randn_like(images)
            t_b = t.view(B, 1, 1, 1)
            x_t = (1 - t_b) * images + t_b * eps
            v_target = eps - images

            with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=use_amp):
                v_pred = model(x_t, t, labels)
                loss = F.mse_loss(v_pred, v_target)

            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true grads.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            total_loss += loss.item()
            n += 1
            if n % 20 == 0:
                pbar.set_postfix(loss=f"{total_loss/n:.4f}",
                                 lr=f"{scheduler.get_last_lr()[0]:.1e}")

        elapsed = time.time() - t0
        avg_loss = total_loss / n

        mk = ""
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                'state_dict': model.state_dict(),
                'epoch': epoch + 1,
                'loss': avg_loss,
            }, 'checkpoints/constellation_bn_best.pt')
            mk = " ★"
        print(f" E{epoch+1:3d}: loss={avg_loss:.4f} lr={scheduler.get_last_lr()[0]:.1e} "
              f"({elapsed:.0f}s){mk}")

        # Diagnostics
        if (epoch + 1) % 10 == 0:
            bn = model.bottleneck
            drift = bn.drift().detach()
            mean_drift = drift.mean().item()
            gate = bn.skip_gate.sigmoid().item()
            print(f" Bottleneck: drift={mean_drift:.4f}rad ({math.degrees(mean_drift):.1f}°) "
                  f"max={drift.max().item():.4f}rad gate={gate:.4f}")

        # Sample
        if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == 0:
            imgs, _ = sample(model, 64, 50, n_classes=N_CLASSES)
            imgs = (imgs + 1) / 2  # [-1, 1] → [0, 1] for saving
            save_image(make_grid(imgs, nrow=8), f'samples_bn/epoch_{epoch+1:03d}.png')
            print(f" → samples_bn/epoch_{epoch+1:03d}.png")

        # Per-class grids, half as often as the mixed grid
        if (epoch + 1) % (SAMPLE_EVERY * 2) == 0:
            for c in range(N_CLASSES):
                cs, _ = sample(model, 8, 50, class_label=c, n_classes=N_CLASSES)
                save_image(make_grid((cs + 1) / 2, nrow=8),
                           f'samples_bn/epoch_{epoch+1:03d}_{class_names[c]}.png')

    print(f"\n{'='*70}")
    print(f"DONE — Best loss: {best_loss:.4f}")
    print(f" Params: {n_params:,} (bottleneck: {n_bottleneck:,})")
    print(f"{'='*70}")


if __name__ == "__main__":
    main()