geolip-spherical-diffusion-proto / modeling_trainer_v2.py
AbstractPhil's picture
Create modeling_trainer_v2.py
f5ff727 verified
#!/usr/bin/env python3
"""
Constellation Diffusion
========================
Everything through the sphere. No skip projection. No attention.
The constellation IS the model's information bottleneck.
16384d encoder output → 256d sphere → 768d triangulation
→ conditioned patchwork → 16384d reconstruction
The patchwork must carry ALL information through 768 geometric
measurements. If it works, diffusion is solved through triangulation.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import time
from tqdm import tqdm
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
# Pick the compute device once at import time; the rest of the script reuses it.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Allow TF32 in cuDNN ops and CUDA matmuls.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# ══════════════════════════════════════════════════════════════════
# CONSTELLATION BOTTLENECK — NO SKIP
# ══════════════════════════════════════════════════════════════════
class ConstellationBottleneck(nn.Module):
    """
    Constellation bottleneck: re-encodes a flat feature vector as distances
    to anchor points on unit spheres, then reconstructs it with an MLP.

    Flow:
        (B, spatial) → proj_in(spatial, embed) → LN → reshape → L2 norm
        → triangulate: P patches × A anchors × n_phases = tri_dim
        → concat(tri, cond)
        → deep patchwork with residual blocks
        → proj_out(hidden, spatial)

    Args:
        spatial_dim: Size of the flattened input (C*H*W from the encoder).
        embed_dim: Width of the projected embedding. Must be a positive
            multiple of ``patch_dim``.
        patch_dim: Dimension ``d`` of each patch; patches are L2-normalised
            and therefore lie on S^(d-1).
        n_anchors: Number of anchors ``A`` per patch.
        n_phases: Number of evenly spaced interpolation phases in [0, 1]
            between the frozen ``home`` anchors and the learnable ``anchors``.
        cond_dim: Size of the conditioning vector concatenated to the
            triangulation profile.
        pw_hidden: Hidden width of the patchwork MLP.
        pw_depth: Number of residual blocks in the patchwork.

    Raises:
        ValueError: If ``embed_dim`` is not a positive multiple of
            ``patch_dim``, or if ``n_anchors`` / ``n_phases`` is below 1.
    """

    def __init__(
        self,
        spatial_dim,        # C*H*W from encoder
        embed_dim=256,
        patch_dim=16,
        n_anchors=16,
        n_phases=3,
        cond_dim=256,
        pw_hidden=1024,
        pw_depth=4,         # number of residual blocks in patchwork
    ):
        super().__init__()
        # Fail fast: a non-divisible embed_dim used to surface only later as
        # an opaque reshape error inside forward().
        if patch_dim < 1 or embed_dim < 1 or embed_dim % patch_dim != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be a positive multiple of "
                f"patch_dim ({patch_dim})")
        if n_anchors < 1:
            raise ValueError(f"n_anchors must be >= 1, got {n_anchors}")
        if n_phases < 1:
            raise ValueError(f"n_phases must be >= 1, got {n_phases}")

        self.spatial_dim = spatial_dim
        self.embed_dim = embed_dim
        self.patch_dim = patch_dim
        self.n_patches = embed_dim // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases
        P, A, d = self.n_patches, n_anchors, patch_dim

        # Phase values are plain Python floats computed once on the CPU, so
        # triangulate() never has to copy a tensor back from the device.
        self._phases = tuple(torch.linspace(0.0, 1.0, n_phases).tolist())

        # Encoder projection → sphere
        self.proj_in = nn.Sequential(
            nn.Linear(spatial_dim, embed_dim),
            nn.LayerNorm(embed_dim),
        )

        # Constellation anchors: `home` is a frozen buffer holding the initial
        # unit vectors; `anchors` starts as a copy of it and is trainable.
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        # Triangulation dimensions (defaults: 16 * 16 * 3 = 768)
        tri_dim = P * A * n_phases
        self.tri_dim = tri_dim

        # Conditioning projection: triangulation + cond → patchwork width
        pw_input = tri_dim + cond_dim
        self.input_proj = nn.Sequential(
            nn.Linear(pw_input, pw_hidden),
            nn.GELU(),
            nn.LayerNorm(pw_hidden),
        )

        # Deep patchwork: residual MLP blocks (the residual add is in forward)
        self.pw_blocks = nn.ModuleList()
        for _ in range(pw_depth):
            self.pw_blocks.append(nn.Sequential(
                nn.Linear(pw_hidden, pw_hidden),
                nn.GELU(),
                nn.LayerNorm(pw_hidden),
                nn.Linear(pw_hidden, pw_hidden),
                nn.GELU(),
                nn.LayerNorm(pw_hidden),
            ))

        # Reconstruction head
        self.proj_out = nn.Sequential(
            nn.Linear(pw_hidden, pw_hidden),
            nn.GELU(),
            nn.Linear(pw_hidden, spatial_dim),
        )

    def _unit_pair(self):
        """Return the L2-normalised ``(home, anchors)`` tensors."""
        return F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)

    @staticmethod
    def _angle(h, c):
        """Angle in radians between unit vectors ``h`` and ``c`` (last dim)."""
        # The clamp keeps acos away from the ends of its domain, where its
        # derivative is unbounded.
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))

    @staticmethod
    def _slerp(h, c, omega, t):
        """Spherical interpolation from ``h`` (t=0) to ``c`` (t=1).

        ``omega`` is the angle between ``h`` and ``c`` with a trailing
        singleton dimension so it broadcasts over the vector dimension.
        """
        so = omega.sin().clamp(min=1e-7)  # guard the division for tiny angles
        return torch.sin((1 - t) * omega) / so * h + torch.sin(t * omega) / so * c

    def drift(self):
        """Return the (P, A) angles in radians between each home anchor and
        its current learnable anchor."""
        h, c = self._unit_pair()
        return self._angle(h, c)

    def at_phase(self, t):
        """Return the (P, A, d) anchors interpolated at phase ``t``.

        ``t = 0`` gives the home anchors and ``t = 1`` the learned anchors.
        The result is not re-normalised here.
        """
        h, c = self._unit_pair()
        omega = self._angle(h, c).unsqueeze(-1)
        return self._slerp(h, c, omega, t)

    def triangulate(self, patches_n):
        """
        patches_n: (B, P, d) normalized on S^(d-1)
        Returns: (B, P*A*n_phases) full triangulation profile

        Each entry is ``1 - cos`` between a patch and one anchor at one
        phase. Within a patch, features are ordered phase-major, then anchor.
        """
        # Normalisation and drift do not depend on the phase: compute once.
        h, c = self._unit_pair()
        omega = self._angle(h, c).unsqueeze(-1)

        tris = []
        for t in self._phases:
            anchors_t = F.normalize(self._slerp(h, c, omega, t), dim=-1)
            cos = torch.einsum('bpd,pad->bpa', patches_n, anchors_t)
            tris.append(1.0 - cos)
        return torch.cat(tris, dim=-1).reshape(patches_n.shape[0], -1)

    def forward(self, x_flat, cond):
        """
        x_flat: (B, spatial_dim)
        cond: (B, cond_dim)
        Returns: (B, spatial_dim)
        """
        # Project to sphere
        emb = self.proj_in(x_flat)                      # (B, embed_dim)
        B = emb.shape[0]
        patches = emb.reshape(B, self.n_patches, self.patch_dim)
        patches_n = F.normalize(patches, dim=-1)        # on S^(d-1)

        # Triangulate: the geometric encoding
        tri = self.triangulate(patches_n)               # (B, tri_dim)

        # Inject conditioning
        pw_in = torch.cat([tri, cond], dim=-1)          # (B, tri_dim + cond_dim)

        # Deep patchwork with residual connections
        h = self.input_proj(pw_in)
        for block in self.pw_blocks:
            h = h + block(h)

        # Reconstruct spatial features
        return self.proj_out(h)
# ══════════════════════════════════════════════════════════════════
# UNET BUILDING BLOCKS
# ══════════════════════════════════════════════════════════════════
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a batch of scalars.

    Produces ``[sin(t * f_0..f_{h-1}), cos(t * f_0..f_{h-1})]`` with
    ``h = dim // 2`` frequencies spaced geometrically from 1 down to 1/10000.

    Args:
        dim: Number of output features. If ``dim`` is odd, a single zero
            column is appended so the output always has exactly ``dim``
            features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed ``t`` of shape (B,) into a (B, dim) tensor."""
        half = self.dim // 2
        # With a single frequency (dim 2 or 3) `half - 1` is zero; clamp the
        # divisor so that frequency is simply 1 instead of raising.
        step = math.log(10000) / max(half - 1, 1)
        freqs = torch.exp(torch.arange(half, device=t.device, dtype=t.dtype) * -step)
        args = t.unsqueeze(-1) * freqs.unsqueeze(0)
        parts = [args.sin(), args.cos()]
        if self.dim % 2:
            # Odd dim: pad with one zero feature so downstream layers that
            # expect `dim` inputs keep working.
            parts.append(args.new_zeros((*args.shape[:-1], 1)))
        return torch.cat(parts, dim=-1)
class AdaGroupNorm(nn.Module):
    """Group normalisation whose scale and shift are predicted from a
    conditioning vector.

    The projection is zero-initialised, so at construction the module
    behaves exactly like a non-affine GroupNorm.

    Args:
        channels: Number of feature channels of the input.
        cond_dim: Size of the conditioning vector.
        n_groups: Preferred number of groups. The value actually used is the
            largest divisor of ``channels`` that does not exceed
            ``min(n_groups, channels)``.
    """

    def __init__(self, channels, cond_dim, n_groups=8):
        super().__init__()
        # GroupNorm requires channels % groups == 0. Step down to the nearest
        # valid group count instead of failing for widths such as 12 or 20.
        groups = min(n_groups, channels)
        while groups > 1 and channels % groups:
            groups -= 1
        self.gn = nn.GroupNorm(groups, channels, affine=False)
        self.proj = nn.Linear(cond_dim, channels * 2)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, cond):
        """Normalise ``x`` (B, C, H, W) and modulate it with ``cond`` (B, cond_dim)."""
        x = self.gn(x)
        # (B, 2C) → (B, 2C, 1, 1) → per-channel scale and shift
        s, sh = self.proj(cond).unsqueeze(-1).unsqueeze(-1).chunk(2, dim=1)
        return x * (1 + s) + sh
class ConvBlock(nn.Module):
    """Residual block: 7x7 depthwise conv, conditioned norm, then a 1x1
    expand (4x) / GELU / 1x1 project pair, added back onto the input.

    Args:
        channels: Input and output channel count.
        cond_dim: Size of the conditioning vector given to the norm layer.
    """

    def __init__(self, channels, cond_dim):
        super().__init__()
        expanded = channels * 4
        self.dw = nn.Conv2d(channels, channels, kernel_size=7, padding=3, groups=channels)
        self.norm = AdaGroupNorm(channels, cond_dim)
        self.pw1 = nn.Conv2d(channels, expanded, kernel_size=1)
        self.pw2 = nn.Conv2d(expanded, channels, kernel_size=1)
        self.act = nn.GELU()

    def forward(self, x, cond):
        """Return ``x`` plus the conditioned residual branch."""
        branch = self.norm(self.dw(x), cond)
        branch = self.pw2(self.act(self.pw1(branch)))
        return x + branch
class Downsample(nn.Module):
    """Stride-2 3x3 convolution that keeps the channel count and maps a
    spatial size H to ceil(H / 2)."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=ch, out_channels=ch, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        reduced = self.conv(x)
        return reduced
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution that
    keeps the channel count."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1)

    def forward(self, x):
        enlarged = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(enlarged)
# ══════════════════════════════════════════════════════════════════
# CONSTELLATION DIFFUSION UNET
# ══════════════════════════════════════════════════════════════════
class ConstellationDiffusionUNet(nn.Module):
    """
    UNet whose middle block is a ConstellationBottleneck. No attention layers.

    The deepest encoder feature map is flattened, passed through the
    bottleneck together with the conditioning vector, and reshaped back into
    a feature map for the decoder.

    NOTE(review): encoder skip connections are still concatenated into every
    decoder level, including the feature map captured at bottleneck
    resolution (the very tensor that is fed to the bottleneck). The
    constellation is therefore not the only path from input to output.

    Args:
        in_ch: Channels of the input and output images.
        base_ch: Channel width of the first level.
        ch_mults: Per-level channel multipliers; one encoder and one decoder
            level per entry, with a downsample between consecutive levels.
        n_classes: Number of class labels for the class embedding.
        cond_dim: Size of the conditioning vector (time + class).
        embed_dim, n_anchors, n_phases, pw_hidden, pw_depth: Forwarded to
            ConstellationBottleneck.
        image_size: Height/width of the square input images. Must be
            divisible by ``2 ** (len(ch_mults) - 1)``.
        patch_dim: Patch dimension forwarded to ConstellationBottleneck.

    Raises:
        ValueError: If ``ch_mults`` is empty or ``image_size`` cannot be
            halved evenly once per downsample.
    """

    def __init__(
        self,
        in_ch=3,
        base_ch=64,
        ch_mults=(1, 2, 4),
        n_classes=10,
        cond_dim=256,
        embed_dim=256,
        n_anchors=16,
        n_phases=3,
        pw_hidden=1024,
        pw_depth=4,
        image_size=32,
        patch_dim=16,
    ):
        super().__init__()
        if len(ch_mults) == 0:
            raise ValueError("ch_mults must contain at least one multiplier")
        n_down = len(ch_mults) - 1
        if image_size < 1 or image_size % (2 ** n_down) != 0:
            raise ValueError(
                f"image_size ({image_size}) must be a positive multiple of "
                f"{2 ** n_down} for {len(ch_mults)} levels")
        self.ch_mults = ch_mults
        self.image_size = image_size

        # Conditioning
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.GELU(),
            nn.Linear(cond_dim, cond_dim))
        self.class_emb = nn.Embedding(n_classes, cond_dim)
        self.in_conv = nn.Conv2d(in_ch, base_ch, 3, padding=1)

        # Encoder
        self.enc = nn.ModuleList()
        self.enc_down = nn.ModuleList()
        ch = base_ch
        enc_channels = [base_ch]
        for i, m in enumerate(ch_mults):
            ch_out = base_ch * m
            self.enc.append(nn.ModuleList([
                # A 1x1 conv changes the width first when the level widens.
                ConvBlock(ch, cond_dim) if ch == ch_out
                else nn.Sequential(nn.Conv2d(ch, ch_out, 1), ConvBlock(ch_out, cond_dim)),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch = ch_out
            enc_channels.append(ch)
            if i < n_down:
                self.enc_down.append(Downsample(ch))

        # Constellation bottleneck
        mid_ch = ch
        H_mid = image_size // (2 ** n_down)   # spatial size at bottleneck
        spatial_dim = mid_ch * H_mid * H_mid
        self.mid_spatial = (mid_ch, H_mid, H_mid)
        self.bottleneck = ConstellationBottleneck(
            spatial_dim=spatial_dim,
            embed_dim=embed_dim,
            patch_dim=patch_dim,
            n_anchors=n_anchors,
            n_phases=n_phases,
            cond_dim=cond_dim,
            pw_hidden=pw_hidden,
            pw_depth=pw_depth,
        )

        # Decoder (deepest level first). Each level fuses the current
        # features with one encoder skip through a 1x1 conv.
        self.dec_up = nn.ModuleList()
        self.dec_skip_proj = nn.ModuleList()
        self.dec = nn.ModuleList()
        for i in range(len(ch_mults) - 1, -1, -1):
            ch_out = base_ch * ch_mults[i]
            skip_ch = enc_channels.pop()
            self.dec_skip_proj.append(nn.Conv2d(ch + skip_ch, ch_out, 1))
            self.dec.append(nn.ModuleList([
                ConvBlock(ch_out, cond_dim),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch = ch_out
            if i > 0:
                self.dec_up.append(Upsample(ch))

        self.out_norm = nn.GroupNorm(8, ch)
        self.out_conv = nn.Conv2d(ch, in_ch, 3, padding=1)
        # Zero-initialised head: the freshly built network outputs zeros.
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    @staticmethod
    def _apply_enc_block(block, h, cond):
        """Run one encoder entry; only ConvBlock instances receive ``cond``."""
        if isinstance(block, ConvBlock):
            return block(h, cond)
        if isinstance(block, nn.Sequential):
            # (1x1 conv, ConvBlock) pair built in __init__
            return block[1](block[0](h), cond)
        return h

    def forward(self, x, t, class_labels):
        """Predict an output with the same shape as ``x``.

        Args:
            x: (B, in_ch, image_size, image_size) input images.
            t: (B,) time values fed to the sinusoidal embedding.
            class_labels: (B,) integer class indices.

        Raises:
            ValueError: If the spatial size of ``x`` differs from
                ``image_size``.
        """
        if tuple(x.shape[-2:]) != (self.image_size, self.image_size):
            raise ValueError(
                f"expected {self.image_size}x{self.image_size} inputs, "
                f"got {tuple(x.shape[-2:])}")

        cond = self.time_emb(t) + self.class_emb(class_labels)
        h = self.in_conv(x)
        # skips[0] (the in_conv output) is never consumed: the decoder pops
        # exactly one skip per level.
        skips = [h]

        # Encoder
        for i, level in enumerate(self.enc):
            for block in level:
                h = self._apply_enc_block(block, h, cond)
            skips.append(h)
            if i < len(self.enc_down):
                h = self.enc_down[i](h)

        # Constellation bottleneck on the flattened feature map
        B = h.shape[0]
        h = self.bottleneck(h.reshape(B, -1), cond)
        h = h.reshape(B, *self.mid_spatial)

        # Decoder
        for i, (fuse, level) in enumerate(zip(self.dec_skip_proj, self.dec)):
            skip = skips.pop()
            if i > 0:
                h = self.dec_up[i - 1](h)
            h = fuse(torch.cat([h, skip], dim=1))
            for block in level:
                h = block(h, cond)

        return self.out_conv(F.silu(self.out_norm(h)))
# ══════════════════════════════════════════════════════════════════
# SAMPLING
# ══════════════════════════════════════════════════════════════════
@torch.no_grad()
def sample(model, n=64, steps=50, cls=None, n_cls=10, shape=(3, 32, 32)):
    """Generate images by Euler-integrating the model's velocity from t=1 to t=0.

    Args:
        model: Callable as ``model(x, t, labels)`` returning a tensor shaped
            like ``x``.
        n: Number of images to generate.
        steps: Number of Euler steps (must be >= 1).
        cls: If given, every sample uses this class index; otherwise labels
            are drawn uniformly from ``[0, n_cls)``.
        n_cls: Number of classes used when ``cls`` is None.
        shape: Per-image ``(C, H, W)`` shape of the initial noise.

    Returns:
        ``(images, labels)`` where images are clamped to [-1, 1].

    Raises:
        ValueError: If ``steps`` is smaller than 1.
    """
    from contextlib import nullcontext

    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    # Sample on the device the model lives on. Fall back to the module-level
    # DEVICE only for parameter-free models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device(DEVICE)

    was_training = model.training
    model.eval()
    try:
        x = torch.randn(n, *shape, device=device)
        if cls is not None:
            labels = torch.full((n,), cls, dtype=torch.long, device=device)
        else:
            labels = torch.randint(0, n_cls, (n,), device=device)

        dt = 1.0 / steps
        for s in range(steps):
            t = torch.full((n,), 1.0 - s * dt, device=device)
            # bf16 autocast only on CUDA. Elsewhere run in full precision
            # instead of requesting a CUDA context that is not available.
            amp_ctx = (torch.amp.autocast("cuda", dtype=torch.bfloat16)
                       if device.type == "cuda" else nullcontext())
            with amp_ctx:
                v = model(x, t, labels)
            x = x - v.float() * dt
    finally:
        # Leave the caller's train/eval mode as we found it.
        model.train(was_training)

    return x.clamp(-1, 1), labels
# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════
# Hyperparameters
BATCH = 128         # mini-batch size
EPOCHS = 80         # full passes over the training set
LR = 3e-4           # initial AdamW learning rate
SAMPLE_EVERY = 5    # epochs between saved sample grids

print("=" * 70)
print("CONSTELLATION DIFFUSION — PURE GEOMETRIC BOTTLENECK")
print(" No attention. No skip. Everything through S^15.")
print(f" Device: {DEVICE}")
print("=" * 70)

# CIFAR-10 with random horizontal flips, scaled from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,)*3, (0.5,)*3),
])
train_ds = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH, shuffle=True,
    num_workers=4, pin_memory=True, drop_last=True)

model = ConstellationDiffusionUNet(
    in_ch=3, base_ch=64, ch_mults=(1, 2, 4),
    n_classes=10, cond_dim=256, embed_dim=256,
    n_anchors=16, n_phases=3, pw_hidden=1024, pw_depth=4,
).to(DEVICE)
# Parameter accounting. Groups are matched on the top-level module name: a
# plain substring test for "out" would also match `bottleneck.proj_out` and
# count the bottleneck's reconstruction head as decoder weights.
n_params = sum(p.numel() for p in model.parameters())
n_bn = sum(p.numel() for p in model.bottleneck.parameters())
named_params = list(model.named_parameters())
n_enc = sum(p.numel() for name, p in named_params
            if name.startswith(('enc', 'in_conv')))
n_dec = sum(p.numel() for name, p in named_params
            if name.startswith(('dec', 'out_')))
n_anchor = sum(p.numel() for name, p in named_params if 'anchor' in name)
print(f" Total: {n_params:,}")
print(f" Encoder: {n_enc:,}")
print(f" Bottleneck: {n_bn:,} ({100*n_bn/n_params:.1f}%)")
print(f" Anchors: {n_anchor:,}")
print(f" Decoder: {n_dec:,}")
print(f" Train: {len(train_ds):,} images")

# Shape check: one dummy forward pass to confirm output matches input shape.
with torch.no_grad():
    d = torch.randn(2, 3, 32, 32, device=DEVICE)
    o = model(d, torch.rand(2, device=DEVICE), torch.randint(0, 10, (2,), device=DEVICE))
print(f" Shape: {d.shape} → {o.shape} ✓")

bn = model.bottleneck
tri_dims = bn.n_patches * bn.n_anchors * bn.n_phases
# Read the patchwork width from the first Linear of input_proj instead of
# hard-coding it, so the report stays right if pw_hidden changes.
pw_width = bn.input_proj[0].out_features
print(f" Bottleneck: {bn.spatial_dim}d → {bn.embed_dim}d sphere → "
      f"{bn.n_patches}p×{bn.patch_dim}d → "
      f"{tri_dims} tri dims")
print(f" Patchwork: {len(bn.pw_blocks)} residual blocks × {pw_width}d")
print(f" Compression: {bn.spatial_dim} → {tri_dims} "
      f"({bn.spatial_dim / tri_dims:.1f}× ratio)")
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
# Cosine schedule stepped once per batch, hence T_max counts batches.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=EPOCHS * len(train_loader), eta_min=1e-6)
# Only enable the scaler when actually running on CUDA; on CPU it is a no-op.
# NOTE(review): loss scaling is normally needed for float16 autocast only;
# with bfloat16 it is likely redundant. Confirm before removing.
scaler = torch.amp.GradScaler("cuda", enabled=(DEVICE == "cuda"))

os.makedirs("samples_cd", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)

print(f"\n{'='*70}")
print(f"TRAINING — {EPOCHS} epochs, pure constellation diffusion")
print(f"{'='*70}")
best_loss = float('inf')

for epoch in range(EPOCHS):
    model.train()
    t0 = time.time()
    total_loss = 0
    n = 0   # batches processed this epoch

    pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
    for images, labels in pbar:
        images = images.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)
        B = images.shape[0]

        # x_t interpolates linearly between data (t=0) and noise (t=1); the
        # regression target is its derivative with respect to t.
        t = torch.rand(B, device=DEVICE)
        eps = torch.randn_like(images)
        t_b = t.view(B, 1, 1, 1)
        x_t = (1 - t_b) * images + t_b * eps
        v_target = eps - images

        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            v_pred = model(x_t, t, labels)
            loss = F.mse_loss(v_pred, v_target)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to true gradients.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()
        n += 1
        if n % 20 == 0:
            pbar.set_postfix(loss=f"{total_loss/n:.4f}", lr=f"{scheduler.get_last_lr()[0]:.1e}")

    elapsed = time.time() - t0
    avg_loss = total_loss / n

    # Keep the checkpoint with the lowest mean training loss so far.
    mk = ""
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save({
            'state_dict': model.state_dict(),
            'epoch': epoch + 1,
            'loss': avg_loss,
        }, 'checkpoints/constellation_diffusion_best.pt')
        mk = " ★"

    print(f" E{epoch+1:3d}: loss={avg_loss:.4f} lr={scheduler.get_last_lr()[0]:.1e} "
          f"({elapsed:.0f}s){mk}")

    # Diagnostics
    if (epoch + 1) % 10 == 0:
        with torch.no_grad():
            drift = bn.drift().detach()
            # NOTE(review): 0.29154 rad is an unexplained reference angle;
            # this reports the share of anchors within 0.05 rad of it.
            near_029 = (drift - 0.29154).abs().lt(0.05).float().mean().item()
            print(f" ★ drift: mean={drift.mean():.4f}rad ({math.degrees(drift.mean().item()):.1f}°) "
                  f"max={drift.max():.4f}rad ({math.degrees(drift.max().item()):.1f}°) "
                  f"near_0.29: {near_029:.1%}")

            # Anchor utilization quick check: push random inputs through the
            # encoder and see which anchor each patch lands closest to.
            test_imgs = torch.randn(64, 3, 32, 32, device=DEVICE)
            t_test = torch.full((64,), 0.5, device=DEVICE)
            c_test = torch.randint(0, 10, (64,), device=DEVICE)
            cond = model.time_emb(t_test) + model.class_emb(c_test)
            h = model.in_conv(test_imgs)
            for i in range(len(model.ch_mults)):
                for block in model.enc[i]:
                    if isinstance(block, ConvBlock):
                        h = block(h, cond)
                    elif isinstance(block, nn.Sequential):
                        h = block[0](h)
                        h = block[1](h, cond)
                if i < len(model.enc_down):
                    h = model.enc_down[i](h)
            emb = bn.proj_in(h.reshape(64, -1))
            patches = F.normalize(emb.reshape(64, bn.n_patches, bn.patch_dim), dim=-1)
            anchors_n = F.normalize(bn.anchors, dim=-1)
            cos = torch.einsum('bpd,pad->bpa', patches, anchors_n)
            nearest = cos.argmax(dim=-1)    # (64, P), values in [0, A)

            # Anchor k of patch p is a different anchor from anchor k of
            # patch q, so map each hit to the global slot p * A + k before
            # counting. Counting raw indices could never exceed A of the
            # P * A slots.
            patch_offset = torch.arange(bn.n_patches, device=nearest.device) * bn.n_anchors
            unique = (nearest + patch_offset).unique().numel()
            total = bn.n_patches * bn.n_anchors
            print(f" ★ anchors: {unique}/{total} unique assignments "
                  f"({100*unique/total:.0f}% utilization)")

    # Sample
    if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == 0:
        imgs, _ = sample(model, 64, 50)
        # Map [-1, 1] back to [0, 1] for saving.
        save_image(make_grid((imgs + 1) / 2, nrow=8), f'samples_cd/epoch_{epoch+1:03d}.png')
        print(f" → samples_cd/epoch_{epoch+1:03d}.png")

    if (epoch + 1) % 20 == 0:
        names = ['plane', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
        for c, class_name in enumerate(names):
            cs, _ = sample(model, 8, 50, cls=c)
            save_image(make_grid((cs + 1) / 2, nrow=8),
                       f'samples_cd/epoch_{epoch+1:03d}_{class_name}.png')
        print(f" → per-class samples saved")
# Final summary: best training loss, parameter counts and anchor drift.
print(f"\n{'='*70}")
print("CONSTELLATION DIFFUSION — COMPLETE")
print(f" Best loss: {best_loss:.4f}")
print(f" Params: {n_params:,} (bottleneck: {n_bn:,})")
with torch.no_grad():
    drift = bn.drift().detach()
    # Share of anchors whose drift lies within 0.05 rad of 0.29154 rad.
    near_target = (drift - 0.29154).abs().lt(0.05).float().mean().item()
print(f" Final drift: mean={drift.mean():.4f} max={drift.max():.4f}")
print(f" Near 0.29154: {near_target:.1%}")
print(f"{'='*70}")