# gemeo-twin-stack: src/gemeo/cwm/block_diffusion.py
# Commit 089d665 (verified) by timmers:
#   GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
"""Block Diffusion Transformer for clinical event trajectories.
Architecture follows BD3-LMs (Kuleshov group, ICLR 2025 Oral, arXiv 2503.09573):
- Absorbing-state discrete diffusion (each clean token can become [MASK]
with probability t in [0,1])
- Block-causal attention: token i attends to all tokens in its block + all
earlier blocks (block size B=16). Lets sampling proceed block-by-block
while still parallelizing within a block.
- Time conditioning via sinusoidal embedding of t.
- Classifier-free guidance: each sequence has a (treatment, time-zero,
cohort-key) condition. During training, each sequence independently has
its condition replaced by the null token with probability 0.10
(cond_dropout); at inference we mix log-probs with guidance scale gamma.
Calibrated for M4 Pro 24GB / MPS:
- 80M params (d_model=640, n_heads=10, n_layers=16, ffn=2560)
- 384-token max sequence (covers ~24 events including BOS/EOS/SEP/YEAR)
- Block size 16, denoise steps 100 at inference (~2.5s per trajectory)
- 1.2h to train 100 epochs on the ~600-sequence DATASUS v1 cohort
"""
from __future__ import annotations
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class CWMConfig:
    """Hyper-parameters for the block-diffusion clinical world model.

    Field order is part of the public interface (positional construction is
    possible), so new fields must only ever be appended at the end.
    """

    vocab_size: int = 217  # 216 events + 1 MASK absorbing state
    mask_token: int = 216  # id of the absorbing [MASK] state; must be < vocab_size
    null_cond: int = 0  # CFG null condition (= "no treatment specified")
    n_conditions: int = 32  # treatment/intervention vocabulary
    max_seq_len: int = 384  # rows of the learned positional-embedding table
    block_size: int = 16  # tokens per block of the block-causal attention mask
    d_model: int = 640
    n_heads: int = 10  # must divide d_model
    n_layers: int = 16
    ffn: int = 2560  # feed-forward hidden width
    dropout: float = 0.1
    n_diff_steps: int = 100  # inference denoise steps
    cond_dropout: float = 0.10  # CFG training: drop condition with this prob

    def __post_init__(self) -> None:
        """Reject inconsistent settings at construction time.

        Without these checks a bad value only surfaces later, e.g. as an
        embedding index error on the first forward pass (``mask_token`` /
        ``null_cond`` out of range), an attention assertion
        (``d_model % n_heads``), or a division by zero (``block_size=0``).

        Raises:
            ValueError: if any of the checked invariants does not hold.
        """
        if not 0 <= self.mask_token < self.vocab_size:
            raise ValueError(
                f"mask_token={self.mask_token} must be in "
                f"[0, vocab_size={self.vocab_size})"
            )
        if not 0 <= self.null_cond < self.n_conditions:
            raise ValueError(
                f"null_cond={self.null_cond} must be in "
                f"[0, n_conditions={self.n_conditions})"
            )
        if self.block_size < 1 or self.max_seq_len < 1 or self.n_heads < 1:
            raise ValueError(
                "block_size, max_seq_len and n_heads must all be positive"
            )
        if self.d_model % self.n_heads:
            raise ValueError(
                f"d_model={self.d_model} must be divisible by "
                f"n_heads={self.n_heads}"
            )
        if not 0.0 <= self.cond_dropout <= 1.0:
            raise ValueError(
                f"cond_dropout={self.cond_dropout} must be in [0, 1]"
            )
class SinusoidalTimeEmbed(nn.Module):
    """Sinusoidal embedding of continuous time t in [0,1], followed by an MLP.

    A time tensor of shape ``(*batch)`` is mapped to ``(*batch, d)``. The
    features are ``d // 2`` sines and ``d // 2`` cosines of
    ``t * time_scale * max_period ** (-k / (d // 2))``. They are then
    passed through ``Linear -> SiLU -> Linear``.

    Args:
        d: output width. An odd ``d`` gets one trailing zero feature so the
            projection input is always exactly ``d`` wide.
        max_period: base of the geometric frequency ladder (10000 by default).
        time_scale: multiplier applied to ``t`` before the sinusoids. The
            default of 1.0 reproduces the original behaviour and so stays
            compatible with existing checkpoints. With ``t`` in [0, 1] and
            frequencies <= 1, every angle stays <= 1 rad, so most channels
            vary almost linearly in ``t``. A larger scale spreads them out,
            but only for models trained with that scale.
    """

    def __init__(self, d: int, *, max_period: float = 10000.0,
                 time_scale: float = 1.0):
        super().__init__()
        self.d = d
        self.max_period = max_period
        self.time_scale = time_scale
        self.proj = nn.Sequential(
            nn.Linear(d, d), nn.SiLU(), nn.Linear(d, d),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.d // 2
        k = torch.arange(half, device=t.device, dtype=torch.float32)
        # Same evaluation order as the original, so the default settings are bit-identical.
        freqs = torch.exp(-math.log(self.max_period) * (k / half))
        ang = (t.float() * self.time_scale).unsqueeze(-1) * freqs
        emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
        if self.d % 2:
            # sin/cos give 2 * (d // 2) = d - 1 features; pad to match Linear(d, d).
            emb = F.pad(emb, (0, 1))
        return self.proj(emb)
class BlockDiffusionTransformer(nn.Module):
    """Block-causal Transformer that denoises absorbing-state corrupted sequences.

    Forward: x (B, T) tokens with [MASK] at some positions, t (B,) diffusion
    time, cond (B,) treatment-condition id -> logits (B, T, V).
    Loss is masked cross-entropy: only positions that were corrupted to MASK
    contribute (see ``diffusion_loss``).
    """

    def __init__(self, cfg: CWMConfig | None = None):
        super().__init__()
        self.cfg = cfg if cfg is not None else CWMConfig()
        c = self.cfg
        self.tok_emb = nn.Embedding(c.vocab_size, c.d_model)
        self.pos_emb = nn.Embedding(c.max_seq_len, c.d_model)
        self.cond_emb = nn.Embedding(c.n_conditions, c.d_model)
        self.time_emb = SinusoidalTimeEmbed(c.d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=c.d_model, nhead=c.n_heads,
            dim_feedforward=c.ffn, dropout=c.dropout,
            batch_first=True, activation="gelu", norm_first=True,
        )
        # The nested-tensor fast path is never taken with norm_first=True.
        # Disabling it explicitly only silences PyTorch's UserWarning.
        self.transformer = nn.TransformerEncoder(
            layer, num_layers=c.n_layers, enable_nested_tensor=False,
        )
        self.norm = nn.LayerNorm(c.d_model)
        self.head = nn.Linear(c.d_model, c.vocab_size)
        self._register_block_mask()

    def _register_block_mask(self) -> None:
        """Build the (max_seq_len, max_seq_len) block-causal attention mask.

        ``block_mask[i, j]`` is True when query ``i`` must NOT attend to key
        ``j`` (PyTorch boolean-mask convention). That holds exactly when
        ``j`` lies in a later block than ``i``. Each token therefore sees
        its own block and every earlier block.
        """
        T = self.cfg.max_seq_len
        block_id = torch.arange(T) // self.cfg.block_size
        # Rows index queries (i), columns index keys (j).
        # The initial release compared with `<`, which blocked the *earlier*
        # blocks and exposed the *later* ones (anti-causal attention).
        mask = block_id.unsqueeze(0) > block_id.unsqueeze(1)
        self.register_buffer("block_mask", mask)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # block_mask is derived entirely from the config, not learned. A
        # checkpoint saved by the initial release stores the inverted mask, so
        # overwrite any stored copy with the freshly built one. This keeps
        # state-dict keys unchanged, so strict loading still works.
        # NOTE(review): weights trained under the old mask saw anti-causal
        # attention and presumably need re-training.
        key = prefix + "block_mask"
        if key in state_dict:
            state_dict[key] = self.block_mask.clone()
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(self, x: torch.Tensor, t: torch.Tensor,
                cond: torch.Tensor) -> torch.Tensor:
        """Predict clean-token logits for every position.

        Args:
            x: (B, T) token ids, possibly containing ``cfg.mask_token``.
            t: (B,) diffusion time, fed to the sinusoidal time embedding.
            cond: (B,) condition ids in ``[0, cfg.n_conditions)``.

        Returns:
            (B, T, cfg.vocab_size) logits.

        Raises:
            ValueError: if ``T`` exceeds ``cfg.max_seq_len``.
        """
        B, T = x.shape
        if T > self.cfg.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len="
                f"{self.cfg.max_seq_len}"
            )
        pos = torch.arange(T, device=x.device).unsqueeze(0).expand(B, T)
        h = self.tok_emb(x) + self.pos_emb(pos)
        # Time and condition are per-sequence vectors broadcast over positions.
        h = h + self.time_emb(t).unsqueeze(1) + self.cond_emb(cond).unsqueeze(1)
        # Top-left slice: block boundaries are anchored at position 0.
        mask = self.block_mask[:T, :T]
        h = self.transformer(h, mask=mask, is_causal=False)
        return self.head(self.norm(h))

    @torch.no_grad()
    def absorbing_corrupt(self, x_clean: torch.Tensor,
                          t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward diffusion: each token independently masked with probability t.

        Args:
            x_clean: (B, T) clean token ids.
            t: (B,) per-sequence masking probability.

        Returns:
            ``(x_noisy, m)``. ``x_noisy`` is ``x_clean`` with masked positions
            set to ``cfg.mask_token``; ``m`` is the (B, T) boolean mask of
            corrupted positions.
        """
        mp = t.unsqueeze(-1)
        m = torch.rand_like(x_clean, dtype=torch.float) < mp
        x_noisy = torch.where(m, self.cfg.mask_token, x_clean)
        return x_noisy, m

    def diffusion_loss(self, x_clean: torch.Tensor,
                       cond: torch.Tensor) -> torch.Tensor:
        """Masked-token cross-entropy for one batch, with CFG condition dropout.

        Each sequence independently has its condition replaced by
        ``cfg.null_cond`` with probability ``cfg.cond_dropout``. This happens
        on every call, including in eval mode, and trains the unconditional
        branch used by classifier-free guidance. Each sequence then draws
        ``t ~ U(0, 1)``, clamped to ``[1e-3, 1 - 1e-3]``, and is corrupted by
        ``absorbing_corrupt``. Only masked positions are scored.

        Args:
            x_clean: (B, T) clean token ids (integer dtype, the CE target).
            cond: (B,) condition ids.

        Returns:
            Scalar: summed CE over masked positions divided by the number of
            masked positions in the batch (floored at 1).

        NOTE(review): this is an unweighted masked-token average, not the
        BD3-LM/MDLM NELBO, which weights by 1/t under a linear schedule.
        Also, earlier blocks are attended to in their *noisy* form rather
        than as clean context. Confirm both match the block-by-block sampler.
        """
        B = x_clean.size(0)
        device = x_clean.device
        # CFG: drop condition to the configured null id (not a hard-coded 0).
        drop = torch.rand(B, device=device) < self.cfg.cond_dropout
        cond = torch.where(drop, torch.full_like(cond, self.cfg.null_cond), cond)
        # Noise level per sequence; the clamp keeps it away from 0 and 1.
        t = torch.rand(B, device=device).clamp(min=1e-3, max=1.0 - 1e-3)
        x_noisy, mask = self.absorbing_corrupt(x_clean, t)
        logits = self.forward(x_noisy, t, cond)
        # CE loss only on masked positions
        ce = F.cross_entropy(
            logits.permute(0, 2, 1), x_clean, reduction="none",
        )
        n_masked = mask.float().sum().clamp(min=1.0)
        return (ce * mask.float()).sum() / n_masked