geolip-spherical-diffusion-proto / GLFM_trainer_model.py
AbstractPhil's picture
Create GLFM_trainer_model.py
0a70478 verified
#!/usr/bin/env python3
"""
Geometric Lookup Flow Matching (GLFM)
========================================
A flow matching variant where velocity prediction is driven by
geometric address lookup on S^15.
Core insight (empirical):
The constellation bottleneck doesn't reconstruct encoder features.
It produces cos_sim ≈ 0 to its input. Instead, the triangulation
profile acts as a continuous ADDRESS on the unit hypersphere,
and the generator produces velocity fields from that address.
This is: v(x_t, t, c) = Generator(Address(x_t), t, c)
where Address(x) = triangulate(project_to_sphere(encode(x)))
GLFM formalizes this into three stages:
Stage 1 — GEOMETRIC ADDRESSING
Encoder maps x_t to multiple resolution embeddings on S^15.
Each resolution captures different spatial frequency information.
Triangulation against fixed anchors produces a structured address.
Stage 2 — ADDRESS CONDITIONING
The geometric address is concatenated with:
- Timestep embedding (sinusoidal)
- Class/text conditioning
- Noise level features
The conditioning modulates WHAT to generate at this address.
Stage 3 — VELOCITY GENERATION
A deep MLP generates the velocity field from the conditioned address.
This is NOT reconstruction — it's generation from a lookup.
The generator never sees the raw encoder features.
Key properties:
- Address space is geometrically structured (Voronoi cells on S^15)
- Anchors self-organize: <0.29 rad = frame holders, >0.29 = task encoders
- Precision-invariant (works at fp8)
- 21× compression with zero velocity quality loss
- Multi-scale addressing captures both coarse and fine structure
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import time
from tqdm import tqdm
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Allow TF32 for cuDNN convolutions and for CUDA matmuls.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# ══════════════════════════════════════════════════════════════════
# STAGE 1: GEOMETRIC ADDRESSING
# ══════════════════════════════════════════════════════════════════
class GeometricAddressEncoder(nn.Module):
    """Map an encoder feature map to a flat geometric address vector.

    Two addresses are computed from the same feature map and concatenated:

    - Coarse: the spatial mean is projected to ``embed_dim`` and
      triangulated once per sample.
    - Fine: every spatial position is projected and triangulated, then the
      per-position profiles are pooled over positions (mean and max,
      blended with fixed equal weights).

    An embedding is split into ``n_patches = embed_dim // patch_dim`` unit
    vectors of size ``patch_dim`` (points on S^15 for the default
    ``patch_dim=16``). Each unit vector is compared against ``n_anchors``
    anchors at ``n_phases`` interpolation points between the frozen
    ``home`` anchors and the learned ``anchors``.

    Args:
        spatial_channels: channel count C of the incoming feature map.
        spatial_size: spatial side length of the feature map. Only stored;
            ``forward`` reads H and W from its input.
        embed_dim: width of the projected embedding.
        patch_dim: size of each patch the embedding is split into.
        n_anchors: anchors per patch.
        n_phases: number of interpolation points in [0, 1].

    Raises:
        ValueError: if ``embed_dim`` is not a positive multiple of
            ``patch_dim`` or ``n_phases`` is smaller than 1.
    """

    def __init__(
        self,
        spatial_channels,
        spatial_size,
        embed_dim=256,
        patch_dim=16,
        n_anchors=16,
        n_phases=3,
    ):
        super().__init__()
        # Fail at construction instead of at the first reshape in forward().
        if patch_dim <= 0 or embed_dim <= 0 or embed_dim % patch_dim != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be a positive multiple of "
                f"patch_dim ({patch_dim})")
        if n_phases < 1:
            raise ValueError(f"n_phases must be >= 1, got {n_phases}")

        self.spatial_channels = spatial_channels
        self.spatial_size = spatial_size
        self.embed_dim = embed_dim
        self.patch_dim = patch_dim
        self.n_patches = embed_dim // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases
        P, A, d = self.n_patches, n_anchors, patch_dim

        # Coarse branch: projects the globally pooled feature vector.
        self.coarse_proj = nn.Sequential(
            nn.Linear(spatial_channels, embed_dim),
            nn.LayerNorm(embed_dim),
        )
        # Fine branch: projects each spatial position independently.
        self.fine_proj = nn.Sequential(
            nn.Linear(spatial_channels, embed_dim),
            nn.LayerNorm(embed_dim),
        )

        # One constellation shared by both branches. `home` is a frozen
        # buffer; `anchors` starts as a copy of it and is trainable.
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))  # fills `home` via the view
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        # Size of one triangulation profile (768 with the defaults).
        self.tri_dim = P * A * n_phases
        # Full address = coarse profile + pooled fine profile.
        self.address_dim = self.tri_dim * 2

    def drift(self):
        """Return the angle (radians) between each home anchor and its
        learned counterpart, shape (n_patches, n_anchors)."""
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        # Clamp keeps acos away from +/-1.
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))

    def at_phase(self, t):
        """Spherically interpolate from home (t=0) to the learned anchors (t=1)."""
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        so = omega.sin().clamp(min=1e-7)  # guards the division below
        return torch.sin((1 - t) * omega) / so * h + torch.sin(t * omega) / so * c

    def triangulate(self, patches_n):
        """Compute (1 - dot product) of every patch against every anchor
        at every phase.

        patches_n: (..., P, d) patches; ``forward`` passes them unit-normalized.
        Returns: (..., P * A * n_phases)
        """
        lead_shape = patches_n.shape[:-2]
        flat = patches_n.reshape(-1, self.n_patches, self.patch_dim)
        # Build the phase values on the CPU: they are consumed as Python
        # floats, so creating them on the GPU only added a device-to-host
        # copy per call.
        phases = torch.linspace(0, 1, self.n_phases).tolist()
        profiles = []
        for phase in phases:
            anchors_t = F.normalize(self.at_phase(phase), dim=-1)
            profiles.append(1.0 - torch.einsum('bpd,pad->bpa', flat, anchors_t))
        tri = torch.cat(profiles, dim=-1).reshape(flat.shape[0], -1)
        return tri.reshape(*lead_shape, -1)

    def forward(self, feature_map):
        """
        feature_map: (B, C, H, W) from the encoder
        Returns: (B, address_dim) geometric address
        """
        B, C, H, W = feature_map.shape

        # Coarse: global average pool -> one address per sample.
        coarse = feature_map.mean(dim=(-2, -1))                      # (B, C)
        coarse_emb = self.coarse_proj(coarse)                        # (B, embed_dim)
        coarse_patches = F.normalize(
            coarse_emb.reshape(B, self.n_patches, self.patch_dim), dim=-1)
        coarse_addr = self.triangulate(coarse_patches)               # (B, tri_dim)

        # Fine: one address per spatial position.
        fine = feature_map.permute(0, 2, 3, 1).reshape(B * H * W, C)  # (BHW, C)
        fine_emb = self.fine_proj(fine)                              # (BHW, embed_dim)
        fine_patches = F.normalize(
            fine_emb.reshape(B * H * W, self.n_patches, self.patch_dim), dim=-1)
        fine_addr = self.triangulate(fine_patches)                   # (BHW, tri_dim)

        # Pool the per-position addresses over the H*W positions.
        fine_addr = fine_addr.reshape(B, H * W, -1)
        fine_mean = fine_addr.mean(dim=1)                            # (B, tri_dim)
        fine_max = fine_addr.max(dim=1).values                       # (B, tri_dim)
        # Fixed equal-weight blend of mean and max (there is no learned gate).
        fine_combined = (fine_mean + fine_max) / 2                   # (B, tri_dim)

        # Full address = coarse followed by fine.
        return torch.cat([coarse_addr, fine_combined], dim=-1)       # (B, 2*tri_dim)
# ══════════════════════════════════════════════════════════════════
# STAGE 2: ADDRESS CONDITIONING
# ══════════════════════════════════════════════════════════════════
class AddressConditioner(nn.Module):
    """Fuse a geometric address with timestep and class conditioning.

    The address is concatenated with three ``cond_dim``-wide vectors (a
    sinusoidal-MLP time embedding, the caller's class embedding, and a
    learned embedding of the discretized timestep) and passed through a
    Linear -> GELU -> LayerNorm block.

    Args:
        address_dim: width of the incoming geometric address.
        cond_dim: width of each conditioning vector; ``class_emb`` passed
            to ``forward`` must have this width.
        output_dim: width of the fused output.
        n_noise_bins: number of rows in the discretized-timestep table.
            The default matches the previously hard-coded 64.

    Raises:
        ValueError: if ``n_noise_bins`` is smaller than 1.
    """

    def __init__(self, address_dim, cond_dim=256, output_dim=1024, n_noise_bins=64):
        super().__init__()
        if n_noise_bins < 1:
            raise ValueError(f"n_noise_bins must be >= 1, got {n_noise_bins}")

        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.GELU(),
            nn.Linear(cond_dim, cond_dim))
        # Learned embedding of the discretized timestep.
        self.noise_emb = nn.Embedding(n_noise_bins, cond_dim)
        self.fuse = nn.Sequential(
            nn.Linear(address_dim + cond_dim * 3, output_dim),
            nn.GELU(),
            nn.LayerNorm(output_dim),
        )

    def forward(self, address, t, class_emb):
        """
        address: (B, address_dim) from the geometric encoder
        t: (B,) timestep
        class_emb: (B, cond_dim) class embedding
        Returns: (B, output_dim) conditioned address
        """
        t_emb = self.time_emb(t)
        # Derive the top bin from the table itself so the index range can
        # never disagree with the embedding size.
        last_bin = self.noise_emb.num_embeddings - 1
        t_discrete = (t * last_bin).long().clamp(0, last_bin)
        n_emb = self.noise_emb(t_discrete)
        combined = torch.cat([address, t_emb, class_emb, n_emb], dim=-1)
        return self.fuse(combined)
# ══════════════════════════════════════════════════════════════════
# STAGE 3: VELOCITY GENERATOR
# ══════════════════════════════════════════════════════════════════
class VelocityGenerator(nn.Module):
    """Residual MLP that turns a conditioned address into flat velocity features.

    Layout: a Linear -> GELU -> LayerNorm stem, ``depth`` residual blocks
    of width ``hidden``, then a two-layer head that projects to
    ``spatial_dim``.
    """

    def __init__(self, cond_address_dim, spatial_dim, hidden=1024, depth=4):
        super().__init__()
        self.spatial_dim = spatial_dim

        stem = nn.Sequential(
            nn.Linear(cond_address_dim, hidden),
            nn.GELU(),
            nn.LayerNorm(hidden),
        )
        trunk = [stem]
        for _ in range(depth):
            trunk.append(ResBlock(hidden))
        # Index 0 is the stem; indices 1..depth are the residual blocks.
        self.blocks = nn.ModuleList(trunk)

        self.head = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Linear(hidden, spatial_dim),
        )

    def forward(self, cond_address):
        """
        cond_address: (B, cond_address_dim)
        Returns: (B, spatial_dim) generated velocity features
        """
        features = cond_address
        for layer in self.blocks:
            features = layer(features)
        return self.head(features)
class ResBlock(nn.Module):
    """Residual block: x + two stacked (Linear -> GELU -> LayerNorm) stages."""

    def __init__(self, dim):
        super().__init__()
        stages = []
        for _ in range(2):
            stages += [nn.Linear(dim, dim), nn.GELU(), nn.LayerNorm(dim)]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.net(x)
        return x + residual
# ══════════════════════════════════════════════════════════════════
# BUILDING BLOCKS
# ══════════════════════════════════════════════════════════════════
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a scalar timestep.

    For ``half = dim // 2`` geometrically spaced frequencies from 1 down to
    1/10000, the output is ``[sin(t * f_0..f_half-1), cos(t * f_0..f_half-1)]``.
    When ``dim`` is odd a single zero column is appended so the output
    always has exactly ``dim`` features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """
        t: (B,) timesteps
        Returns: (B, dim) embedding
        """
        half = self.dim // 2
        # max(..., 1): with dim 2 or 3 there is a single frequency and
        # `half - 1` would be zero, which used to raise ZeroDivisionError.
        log_step = math.log(10000) / max(half - 1, 1)
        freqs = torch.exp(
            torch.arange(half, device=t.device, dtype=t.dtype) * -log_step)
        angles = t.unsqueeze(-1) * freqs.unsqueeze(0)
        emb = torch.cat([angles.sin(), angles.cos()], dim=-1)
        if self.dim % 2 == 1:
            # sin/cos pairs only fill 2 * half columns; pad the last one.
            emb = F.pad(emb, (0, 1))
        return emb
class AdaGroupNorm(nn.Module):
    """GroupNorm without its own affine terms, modulated by a conditioning vector.

    A linear layer maps ``cond`` to a per-channel scale and shift. Its
    weight and bias start at zero, so the module initially returns the
    plain group-normalized input.
    """

    def __init__(self, ch, cond_dim, groups=8):
        super().__init__()
        # Never ask for more groups than there are channels.
        n_groups = groups if groups < ch else ch
        self.gn = nn.GroupNorm(n_groups, ch, affine=False)
        self.proj = nn.Linear(cond_dim, 2 * ch)
        for param in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(param)

    def forward(self, x, cond):
        normed = self.gn(x)
        # Two trailing singleton axes let the modulation broadcast over H, W.
        modulation = self.proj(cond)[..., None, None]
        scale, shift = modulation.chunk(2, dim=1)
        return normed * (1 + scale) + shift
class ConvBlock(nn.Module):
    """Residual conv block: 7x7 depthwise conv, conditioned norm, then a
    1x1 expand (4x) -> GELU -> 1x1 project, added back to the input."""

    def __init__(self, ch, cond_dim):
        super().__init__()
        expanded = 4 * ch
        self.dw = nn.Conv2d(ch, ch, kernel_size=7, padding=3, groups=ch)
        self.norm = AdaGroupNorm(ch, cond_dim)
        self.pw1 = nn.Conv2d(ch, expanded, kernel_size=1)
        self.pw2 = nn.Conv2d(expanded, ch, kernel_size=1)
        self.act = nn.GELU()

    def forward(self, x, cond):
        hidden = self.dw(x)
        hidden = self.norm(hidden, cond)
        hidden = self.pw1(hidden)
        hidden = self.act(hidden)
        return x + self.pw2(hidden)
class Downsample(nn.Module):
    """Reduce spatial resolution with a stride-2 3x3 convolution; channels unchanged."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        reduced = self.conv(x)
        return reduced
class Upsample(nn.Module):
    """Double spatial resolution by nearest-neighbour interpolation, then
    apply a 3x3 convolution; channels unchanged."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, kernel_size=3, padding=1)

    def forward(self, x):
        doubled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(doubled)
# ══════════════════════════════════════════════════════════════════
# GLFM UNET
# ══════════════════════════════════════════════════════════════════
class GLFMUNet(nn.Module):
    """Geometric Lookup Flow Matching UNet.

    Pipeline: conv encoder -> GeometricAddressEncoder -> AddressConditioner
    -> VelocityGenerator -> conv decoder. The bottleneck feature map is
    replaced by the generator's output, reshaped to the bottleneck shape.
    The decoder still receives the encoder's skip connections at every
    level, including the bottleneck resolution. No attention layers are
    used.

    Args:
        in_ch: channels of the input (and output) image.
        base_ch: channel width at the first level.
        ch_mults: per-level channel multipliers; resolution is halved
            between consecutive levels.
        n_classes: number of class labels.
        cond_dim: width of the time/class conditioning vectors.
        embed_dim, n_anchors, n_phases: forwarded to GeometricAddressEncoder.
        gen_hidden, gen_depth: width and depth of the VelocityGenerator.
        img_size: side length of the square input images. The default
            matches the previously hard-coded 32.

    Raises:
        ValueError: if ``ch_mults`` is empty or ``img_size`` is not
            divisible by the total downsampling factor.
    """

    def __init__(
        self,
        in_ch=3,
        base_ch=64,
        ch_mults=(1, 2, 4),
        n_classes=10,
        cond_dim=256,
        embed_dim=256,
        n_anchors=16,
        n_phases=3,
        gen_hidden=1024,
        gen_depth=4,
        img_size=32,
    ):
        super().__init__()
        n_levels = len(ch_mults)
        if n_levels == 0:
            raise ValueError("ch_mults must contain at least one level")
        total_stride = 2 ** (n_levels - 1)
        if img_size <= 0 or img_size % total_stride != 0:
            # Otherwise upsampled decoder maps would not match the skips.
            raise ValueError(
                f"img_size ({img_size}) must be a positive multiple of "
                f"{total_stride} for {n_levels} levels")
        self.ch_mults = ch_mults

        # Class embedding, shared by encoder, decoder and conditioner.
        self.class_emb = nn.Embedding(n_classes, cond_dim)

        # Time conditioning for the encoder's AdaGroupNorm layers.
        self.enc_time = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.GELU(),
            nn.Linear(cond_dim, cond_dim))

        self.in_conv = nn.Conv2d(in_ch, base_ch, 3, padding=1)

        # Encoder: two conv blocks per level, downsample between levels.
        self.enc = nn.ModuleList()
        self.enc_down = nn.ModuleList()
        ch = base_ch
        enc_channels = []  # channel width of each level's skip tensor
        for i, m in enumerate(ch_mults):
            ch_out = base_ch * m
            self.enc.append(nn.ModuleList([
                ConvBlock(ch, cond_dim) if ch == ch_out
                # A 1x1 conv changes the width before the first block.
                else nn.Sequential(nn.Conv2d(ch, ch_out, 1), ConvBlock(ch_out, cond_dim)),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch = ch_out
            enc_channels.append(ch)
            if i < n_levels - 1:
                self.enc_down.append(Downsample(ch))

        # GLFM pipeline at the bottleneck.
        mid_ch = ch
        H_mid = img_size // total_stride
        spatial_dim = mid_ch * H_mid * H_mid
        self.mid_spatial = (mid_ch, H_mid, H_mid)

        # Stage 1: geometric address encoder.
        self.geo_encoder = GeometricAddressEncoder(
            spatial_channels=mid_ch,
            spatial_size=H_mid,
            embed_dim=embed_dim,
            patch_dim=16,
            n_anchors=n_anchors,
            n_phases=n_phases,
        )
        # Stage 2: address conditioner.
        self.conditioner = AddressConditioner(
            address_dim=self.geo_encoder.address_dim,
            cond_dim=cond_dim,
            output_dim=gen_hidden,
        )
        # Stage 3: velocity generator.
        self.generator = VelocityGenerator(
            cond_address_dim=gen_hidden,
            spatial_dim=spatial_dim,
            hidden=gen_hidden,
            depth=gen_depth,
        )

        # Decoder: mirrors the encoder from the deepest level upwards.
        self.dec_up = nn.ModuleList()
        self.dec_skip = nn.ModuleList()
        self.dec = nn.ModuleList()

        # Time conditioning for the decoder's AdaGroupNorm layers.
        self.dec_time = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.GELU(),
            nn.Linear(cond_dim, cond_dim))

        for i in range(n_levels - 1, -1, -1):
            ch_out = base_ch * ch_mults[i]
            skip_ch = enc_channels.pop()
            # 1x1 conv merges [decoder features, skip] down to ch_out.
            self.dec_skip.append(nn.Conv2d(ch + skip_ch, ch_out, 1))
            self.dec.append(nn.ModuleList([
                ConvBlock(ch_out, cond_dim),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch = ch_out
            if i > 0:
                self.dec_up.append(Upsample(ch))

        self.out_norm = nn.GroupNorm(8, ch)
        self.out_conv = nn.Conv2d(ch, in_ch, 3, padding=1)
        # Zero-initialized output conv: the untrained model predicts zeros.
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x, t, class_labels):
        """
        x: (B, in_ch, img_size, img_size) noisy input
        t: (B,) timestep
        class_labels: (B,) integer class indices
        Returns: tensor with the same shape as x (predicted velocity)
        """
        # Look the class embedding up once and reuse it everywhere.
        cls_emb = self.class_emb(class_labels)
        enc_cond = self.enc_time(t) + cls_emb
        dec_cond = self.dec_time(t) + cls_emb

        h = self.in_conv(x)
        skips = []

        # Encoder
        for i in range(len(self.ch_mults)):
            for block in self.enc[i]:
                if isinstance(block, ConvBlock):
                    h = block(h, enc_cond)
                elif isinstance(block, nn.Sequential):
                    # (1x1 conv, ConvBlock) pair; only the ConvBlock takes cond.
                    h = block[0](h)
                    h = block[1](h, enc_cond)
            skips.append(h)
            if i < len(self.enc_down):
                h = self.enc_down[i](h)

        # GLFM: address -> condition -> generate
        B = h.shape[0]
        address = self.geo_encoder(h)                        # Stage 1
        cond_addr = self.conditioner(address, t, cls_emb)    # Stage 2
        h = self.generator(cond_addr)                        # Stage 3
        h = h.reshape(B, *self.mid_spatial)

        # Decoder
        for i in range(len(self.ch_mults)):
            skip = skips.pop()
            if i > 0:
                h = self.dec_up[i - 1](h)
            h = torch.cat([h, skip], dim=1)
            h = self.dec_skip[i](h)
            for block in self.dec[i]:
                h = block(h, dec_cond)

        return self.out_conv(F.silu(self.out_norm(h)))
# ══════════════════════════════════════════════════════════════════
# SAMPLING
# ══════════════════════════════════════════════════════════════════
@torch.no_grad()
def sample(model, n=64, steps=50, cls=None, n_cls=10):
model.eval()
x = torch.randn(n, 3, 32, 32, device=DEVICE)
labels = (torch.full((n,), cls, dtype=torch.long, device=DEVICE)
if cls is not None else torch.randint(0, n_cls, (n,), device=DEVICE))
dt = 1.0 / steps
for s in range(steps):
t = torch.full((n,), 1.0 - s * dt, device=DEVICE)
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
v = model(x, t, labels)
x = x - v.float() * dt
return x.clamp(-1, 1), labels
# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════
BATCH = 128
EPOCHS = 80
LR = 3e-4
SAMPLE_EVERY = 5
# Drift angle (radians) that the diagnostics compare anchors against.
DRIFT_THRESHOLD = 0.29154
# CIFAR-10 class names, used only to name the per-class sample files.
CLASS_NAMES = ['plane', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


def _drift_stats(encoder):
    """Return (mean, max, fraction within 0.05 of the threshold, fraction
    above the threshold) of the encoder's anchor drift."""
    with torch.no_grad():
        drift = encoder.drift().detach()
        near = (drift - DRIFT_THRESHOLD).abs().lt(0.05).float().mean().item()
        crossed = (drift > DRIFT_THRESHOLD).float().mean().item()
        return drift.mean().item(), drift.max().item(), near, crossed


def main():
    """Train GLFMUNet on CIFAR-10 with flow matching.

    Writes sample grids to ``samples_glfm/`` and the lowest-training-loss
    checkpoint to ``checkpoints/glfm_best.pt``.
    """
    print("=" * 70)
    print("GEOMETRIC LOOKUP FLOW MATCHING (GLFM)")
    print(" Three-stage: Address → Condition → Generate")
    print(" Multi-scale: coarse (global) + fine (per-position)")
    print(f" Device: {DEVICE}")
    print("=" * 70)

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5,) * 3, (0.5,) * 3),
    ])
    train_ds = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=BATCH, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)

    model = GLFMUNet(
        in_ch=3, base_ch=64, ch_mults=(1, 2, 4),
        n_classes=10, cond_dim=256, embed_dim=256,
        n_anchors=16, n_phases=3,
        gen_hidden=1024, gen_depth=4,
    ).to(DEVICE)

    n_params = sum(p.numel() for p in model.parameters())
    n_geo = sum(p.numel() for p in model.geo_encoder.parameters())
    n_cond = sum(p.numel() for p in model.conditioner.parameters())
    n_gen = sum(p.numel() for p in model.generator.parameters())
    n_anchor = sum(p.numel() for name, p in model.named_parameters() if 'anchor' in name)
    print(f" Total: {n_params:,}")
    print(f" Geo Encoder: {n_geo:,} (Stage 1 — address)")
    print(f" Conditioner: {n_cond:,} (Stage 2 — fuse)")
    print(f" Generator: {n_gen:,} (Stage 3 — velocity)")
    print(f" Anchors: {n_anchor:,}")
    print(f" Address dim: {model.geo_encoder.address_dim} "
          f"(coarse {model.geo_encoder.tri_dim} + fine {model.geo_encoder.tri_dim})")
    print(f" Compression: {model.generator.spatial_dim} → "
          f"{model.geo_encoder.address_dim} "
          f"({model.generator.spatial_dim / model.geo_encoder.address_dim:.1f}×)")

    # Shape check: one forward pass on a dummy batch.
    with torch.no_grad():
        d = torch.randn(2, 3, 32, 32, device=DEVICE)
        o = model(d, torch.rand(2, device=DEVICE), torch.randint(0, 10, (2,), device=DEVICE))
    print(f" Shape: {d.shape} → {o.shape} ✓")
    print(f" Train: {len(train_ds):,}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
    # The scheduler is stepped once per batch, hence T_max in batches.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=EPOCHS * len(train_loader), eta_min=1e-6)
    # Mixed precision is only requested on CUDA; on CPU both the scaler and
    # autocast are disabled explicitly instead of warning on every use.
    use_amp = DEVICE == "cuda"
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    os.makedirs("samples_glfm", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    print(f"\n{'='*70}")
    print(f"TRAINING — {EPOCHS} epochs")
    print(f"{'='*70}")

    best_loss = float('inf')
    for epoch in range(EPOCHS):
        model.train()
        t0 = time.time()
        total_loss = 0
        n_steps = 0
        pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
        for images, labels in pbar:
            images = images.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)
            B = images.shape[0]

            # Linear path between data (t=0) and noise (t=1); the target
            # velocity is its derivative with respect to t.
            t = torch.rand(B, device=DEVICE)
            eps = torch.randn_like(images)
            t_b = t.view(B, 1, 1, 1)
            x_t = (1 - t_b) * images + t_b * eps
            v_target = eps - images

            with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=use_amp):
                v_pred = model(x_t, t, labels)
                loss = F.mse_loss(v_pred, v_target)

            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true grads.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            total_loss += loss.item()
            n_steps += 1
            if n_steps % 20 == 0:
                pbar.set_postfix(loss=f"{total_loss/n_steps:.4f}",
                                 lr=f"{scheduler.get_last_lr()[0]:.1e}")

        elapsed = time.time() - t0
        avg_loss = total_loss / n_steps

        # Keep the checkpoint with the lowest mean training loss.
        mk = ""
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                'state_dict': model.state_dict(),
                'epoch': epoch + 1, 'loss': avg_loss,
            }, 'checkpoints/glfm_best.pt')
            mk = " ★"
        print(f" E{epoch+1:3d}: loss={avg_loss:.4f} lr={scheduler.get_last_lr()[0]:.1e} "
              f"({elapsed:.0f}s){mk}")

        # Diagnostics
        if (epoch + 1) % 10 == 0:
            d_mean, d_max, near, crossed = _drift_stats(model.geo_encoder)
            print(f" ★ drift: mean={d_mean:.4f} max={d_max:.4f} "
                  f"near_0.29={near:.1%} crossed={crossed:.1%}")

        # Sample
        if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == 0:
            imgs, _ = sample(model, 64, 50)
            save_image(make_grid((imgs + 1) / 2, nrow=8),
                       f'samples_glfm/epoch_{epoch+1:03d}.png')
            print(f" → samples_glfm/epoch_{epoch+1:03d}.png")

        if (epoch + 1) % 20 == 0:
            for c, name in enumerate(CLASS_NAMES):
                cs, _ = sample(model, 8, 50, cls=c)
                save_image(make_grid((cs + 1) / 2, nrow=8),
                           f'samples_glfm/epoch_{epoch+1:03d}_{name}.png')
            print(" → per-class samples")

    print(f"\n{'='*70}")
    print("GEOMETRIC LOOKUP FLOW MATCHING — COMPLETE")
    print(f" Best loss: {best_loss:.4f}")
    print(f" Total: {n_params:,}")
    d_mean, d_max, near, crossed = _drift_stats(model.geo_encoder)
    print(f" Final drift: mean={d_mean:.4f} max={d_max:.4f}")
    print(f" Near 0.29: {near:.1%} Crossed: {crossed:.1%}")
    print(f"{'='*70}")


# Guard the entry point: DataLoader workers started with the "spawn" method
# re-import this module, and importing the model classes must not start a
# training run.
if __name__ == "__main__":
    main()