"""Block Diffusion Transformer for clinical event trajectories.

Architecture follows BD3-LMs (Kuleshov group, ICLR 2025 Oral, arXiv 2503.09573):

- Absorbing-state discrete diffusion: each clean token independently becomes
  [MASK] with probability t in [0, 1].
- Block-causal attention: token i attends to every token in its own block and
  in all earlier blocks (block size B=16), never to later blocks. This lets
  sampling proceed block-by-block while still parallelizing within a block.
- Time conditioning via a sinusoidal embedding of t.
- Classifier-free guidance: each sequence has a (treatment, time-zero,
  cohort-key) condition. During training, each sequence independently has its
  condition replaced by the null token with probability ``cond_dropout``
  (10% by default); at inference the conditional and unconditional log-probs
  are mixed with guidance scale gamma.

Calibrated for M4 Pro 24GB / MPS:

- ~80M params (d_model=640, n_heads=10, n_layers=16, ffn=2560)
- 384-token max sequence (covers ~24 events including BOS/EOS/SEP/YEAR)
- Block size 16, 100 denoise steps at inference (~2.5 s per trajectory)
- ~1.2 h to train 100 epochs on the ~600-sequence DATASUS v1 cohort
"""

from __future__ import annotations

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class CWMConfig:
    """Hyperparameters for :class:`BlockDiffusionTransformer`."""

    vocab_size: int = 217  # 216 events + 1 MASK absorbing state
    mask_token: int = 216
    null_cond: int = 0  # CFG null condition (= "no treatment specified")
    n_conditions: int = 32  # treatment/intervention vocabulary
    max_seq_len: int = 384
    block_size: int = 16
    d_model: int = 640
    n_heads: int = 10
    n_layers: int = 16
    ffn: int = 2560
    dropout: float = 0.1
    n_diff_steps: int = 100  # inference denoise steps
    cond_dropout: float = 0.10  # CFG training: drop condition with this prob


class SinusoidalTimeEmbed(nn.Module):
    """Sinusoidal embedding of continuous time t in [0, 1], followed by an MLP.

    Args:
        d: Output embedding width (odd widths are supported).
    """

    def __init__(self, d: int):
        super().__init__()
        self.d = d
        self.proj = nn.Sequential(
            nn.Linear(d, d),
            nn.SiLU(),
            nn.Linear(d, d),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed ``t`` of shape ``S`` into a float tensor of shape ``(*S, d)``."""
        half = self.d // 2
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, device=t.device) / half
        )
        # NOTE(review): with t in [0, 1] and freqs <= 1, every angle lies in
        # [0, 1] rad, which gives coarse resolution in t. Left unchanged so
        # existing checkpoints keep their meaning.
        ang = t.float().unsqueeze(-1) * freqs
        emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
        if self.d % 2:
            # sin/cos halves give only 2 * (d // 2) features; pad odd widths
            # so the result matches the Linear(d, d) projection.
            emb = F.pad(emb, (0, 1))
        return self.proj(emb)


class BlockDiffusionTransformer(nn.Module):
    """Block-causal Transformer that denoises absorbing-state corrupted sequences.

    Forward: x (B, T) tokens with [MASK] at some positions, t (B,) diffusion
    time, cond (B,) treatment-condition id -> logits (B, T, V).

    Loss is masked cross-entropy: only positions that were corrupted to MASK
    contribute (absorbing-state diffusion objective).
    """

    def __init__(self, cfg: CWMConfig | None = None):
        super().__init__()
        self.cfg = cfg or CWMConfig()
        c = self.cfg
        self.tok_emb = nn.Embedding(c.vocab_size, c.d_model)
        self.pos_emb = nn.Embedding(c.max_seq_len, c.d_model)
        self.cond_emb = nn.Embedding(c.n_conditions, c.d_model)
        self.time_emb = SinusoidalTimeEmbed(c.d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=c.d_model,
            nhead=c.n_heads,
            dim_feedforward=c.ffn,
            dropout=c.dropout,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        # The nested-tensor fast path is unavailable with norm_first=True.
        # Disabling it explicitly avoids PyTorch's construction-time warning.
        self.transformer = nn.TransformerEncoder(
            layer, num_layers=c.n_layers, enable_nested_tensor=False
        )
        # Pre-norm encoder layers leave the residual stream unnormalized, and
        # the encoder has no final norm, so apply one before the head.
        self.norm = nn.LayerNorm(c.d_model)
        self.head = nn.Linear(c.d_model, c.vocab_size)
        self._register_block_mask()

    def _build_block_mask(self) -> torch.Tensor:
        """Return the (max_seq_len, max_seq_len) boolean block-causal mask.

        Uses the ``nn.TransformerEncoder`` bool-mask convention: ``True`` at
        ``[i, j]`` means query i may NOT attend to key j.
        """
        T = self.cfg.max_seq_len
        block_id = torch.arange(T) // self.cfg.block_size
        # Block-causal: i sees keys in its own block and all earlier blocks,
        # so only keys in strictly LATER blocks (block_id[j] > block_id[i])
        # are masked out.
        return block_id.unsqueeze(0) > block_id.unsqueeze(1)

    def _register_block_mask(self):
        """Register the block-causal mask buffer and keep it derived from cfg."""
        # Persistent so strict state_dict loading keeps working with existing
        # checkpoints that contain a "block_mask" entry.
        self.register_buffer("block_mask", self._build_block_mask())
        # The mask is a pure function of the config. Rebuild it after every
        # load so a stale (e.g. previously inverted) mask stored in a
        # checkpoint can never override it.
        self.register_load_state_dict_post_hook(self._refresh_block_mask)

    @staticmethod
    def _refresh_block_mask(module, incompatible_keys) -> None:
        """``load_state_dict`` post-hook: recompute ``block_mask`` from cfg."""
        with torch.no_grad():
            module.block_mask.copy_(module._build_block_mask())

    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Predict clean-token logits for a partially masked sequence.

        Args:
            x: (B, T) token ids, with ``cfg.mask_token`` at corrupted positions.
            t: (B,) diffusion time.
            cond: (B,) condition ids in ``[0, cfg.n_conditions)``.

        Returns:
            (B, T, vocab_size) logits.

        Raises:
            ValueError: if T exceeds ``cfg.max_seq_len``.
        """
        T = x.size(1)
        if T > self.cfg.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={self.cfg.max_seq_len}"
            )
        pos = torch.arange(T, device=x.device)
        # pos_emb(pos) is (T, d) and broadcasts over the batch dimension.
        h = self.tok_emb(x) + self.pos_emb(pos)
        # Time and condition embeddings are shared by every position.
        h = h + self.time_emb(t).unsqueeze(1) + self.cond_emb(cond).unsqueeze(1)
        mask = self.block_mask[:T, :T]
        h = self.transformer(h, mask=mask, is_causal=False)
        return self.head(self.norm(h))

    @torch.no_grad()
    def absorbing_corrupt(
        self, x_clean: torch.Tensor, t: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward diffusion: mask each token independently with probability t.

        Args:
            x_clean: (B, T) clean token ids.
            t: (B,) per-sequence masking probability.

        Returns:
            ``(x_noisy, mask)``: the corrupted ids, and the bool (B, T) tensor
            marking positions replaced by ``cfg.mask_token``.
        """
        m = torch.rand_like(x_clean, dtype=torch.float) < t.unsqueeze(-1)
        x_noisy = x_clean.masked_fill(m, self.cfg.mask_token)
        return x_noisy, m

    def diffusion_loss(self, x_clean: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Absorbing-state diffusion loss with CFG condition dropout.

        Draws one noise level per sequence, t ~ U(0, 1) clamped to
        [1e-3, 1 - 1e-3], and masks tokens with probability t. Returns the
        cross-entropy averaged over all masked positions in the batch.

        NOTE(review): this is the unweighted masked-token objective. The
        MDLM/BD3-LM ELBO additionally weights each sequence by 1/t under a
        linear schedule. Also, every block shares one t, and tokens attend to
        *noisy* earlier blocks, whereas block-by-block sampling presumably
        conditions on already-denoised blocks. Confirm this is intended.

        Args:
            x_clean: (B, T) clean token ids.
            cond: (B,) condition ids.

        Returns:
            Scalar loss tensor.
        """
        B = x_clean.size(0)
        device = x_clean.device
        # CFG: each sequence independently falls back to the configured null
        # condition with probability cond_dropout.
        drop = torch.rand(B, device=device) < self.cfg.cond_dropout
        cond = cond.masked_fill(drop, self.cfg.null_cond)
        t = torch.rand(B, device=device).clamp(min=1e-3, max=1.0 - 1e-3)
        x_noisy, mask = self.absorbing_corrupt(x_clean, t)
        logits = self.forward(x_noisy, t, cond)
        # Cross-entropy per position, then keep only the masked positions.
        ce = F.cross_entropy(logits.transpose(1, 2), x_clean, reduction="none")
        # The clamp keeps the division finite if no position happened to be masked.
        n_masked = mask.float().sum().clamp(min=1.0)
        return (ce * mask.float()).sum() / n_masked