| """Block Diffusion Transformer for clinical event trajectories. |
| |
| Architecture follows BD3-LMs (Kuleshov group, ICLR 2025 Oral, arXiv 2503.09573): |
| - Absorbing-state discrete diffusion (each clean token can become [MASK] |
| with probability t in [0,1]) |
| - Block-causal attention: token i attends to all tokens in its block + all |
| earlier blocks (block size B=16). Lets sampling proceed block-by-block |
| while still parallelizing within a block. |
| - Time conditioning via sinusoidal embedding of t. |
- Classifier-free guidance: each sequence has a (treatment, time-zero,
  cohort-key) condition. During training, each sequence independently has
  its condition replaced by the null token with probability 10%
  (cond_dropout); at inference we mix log-probs with guidance scale gamma.
| |
| Calibrated for M4 Pro 24GB / MPS: |
| - 80M params (d_model=640, n_heads=10, n_layers=16, ffn=2560) |
| - 384-token max sequence (covers ~24 events including BOS/EOS/SEP/YEAR) |
| - Block size 16, denoise steps 100 at inference (~2.5s per trajectory) |
| - 1.2h to train 100 epochs on the ~600-sequence DATASUS v1 cohort |
| """ |
| from __future__ import annotations |
| import math |
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
@dataclass
class CWMConfig:
    """Hyper-parameters for ``BlockDiffusionTransformer``.

    Defaults describe the ~80M-parameter model from the module docstring.
    ``__post_init__`` rejects inconsistent settings up front instead of letting
    them fail later as shape/index errors inside model construction or the
    first forward pass.

    Raises:
        ValueError: if the heads do not divide ``d_model``, a special token id
            is out of range, a size is non-positive, or ``cond_dropout`` is not
            a probability.
    """

    vocab_size: int = 217      # token vocabulary size (embedding rows / head outputs)
    mask_token: int = 216      # id written over corrupted positions; must be < vocab_size
    null_cond: int = 0         # condition id used as the CFG "no condition" token
    n_conditions: int = 32     # number of condition embeddings (including null_cond)
    max_seq_len: int = 384     # maximum tokens per sequence (position-embedding rows)
    block_size: int = 16       # tokens per attention block (block = position // block_size)
    d_model: int = 640         # transformer width
    n_heads: int = 10          # attention heads; must divide d_model
    n_layers: int = 16         # transformer encoder layers
    ffn: int = 2560            # feed-forward hidden width
    dropout: float = 0.1       # dropout inside the transformer layers
    n_diff_steps: int = 100    # not referenced in this module; presumably the sampler's step count
    cond_dropout: float = 0.10 # per-sequence probability of replacing cond with null_cond

    def __post_init__(self) -> None:
        if self.n_heads < 1 or self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"n_heads ({self.n_heads})"
            )
        if not 0 <= self.mask_token < self.vocab_size:
            raise ValueError(
                f"mask_token ({self.mask_token}) must be in "
                f"[0, vocab_size={self.vocab_size})"
            )
        if not 0 <= self.null_cond < self.n_conditions:
            raise ValueError(
                f"null_cond ({self.null_cond}) must be in "
                f"[0, n_conditions={self.n_conditions})"
            )
        if self.block_size < 1 or self.max_seq_len < 1:
            raise ValueError(
                f"block_size ({self.block_size}) and max_seq_len "
                f"({self.max_seq_len}) must be positive"
            )
        if not 0.0 <= self.cond_dropout <= 1.0:
            raise ValueError(
                f"cond_dropout ({self.cond_dropout}) must be in [0, 1]"
            )
|
|
|
|
class SinusoidalTimeEmbed(nn.Module):
    """Sinusoidal embedding of a continuous diffusion time ``t``, followed by an MLP.

    Features are ``[sin(t * f_k), cos(t * f_k)]`` for ``k = 0 .. d//2 - 1`` with
    ``f_k = max_period ** (-k / (d//2))``. When ``d`` is odd, a zero channel is
    appended so the feature vector always has ``d`` entries. The features then go
    through ``Linear(d, d) -> SiLU -> Linear(d, d)``.

    NOTE(review): every ``f_k <= 1``, so for ``t`` in [0, 1] all angles stay
    within 1 radian and most channels vary little with ``t``. Scaling ``t``
    (e.g. by 1000) would spread them out, but doing so invalidates trained
    checkpoints, so it is left unchanged here.

    Args:
        d: output (and feature) dimension.
        max_period: longest sinusoid period; 10000 reproduces the original code.
    """

    def __init__(self, d: int, max_period: float = 10000.0):
        super().__init__()
        self.d = d
        self.max_period = max_period
        self.proj = nn.Sequential(
            nn.Linear(d, d), nn.SiLU(), nn.Linear(d, d),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Map ``t`` of shape ``(...)`` to embeddings of shape ``(..., d)``."""
        half = self.d // 2
        freqs = torch.exp(
            -math.log(self.max_period) * torch.arange(half, device=t.device) / half
        )
        ang = t.float().unsqueeze(-1) * freqs
        emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
        if self.d % 2:
            # sin/cos give 2 * (d // 2) = d - 1 features for odd d; pad to d
            # so the Linear(d, d) projection accepts them.
            emb = F.pad(emb, (0, 1))
        return self.proj(emb)
|
|
|
|
class BlockDiffusionTransformer(nn.Module):
    """Block-causal Transformer that denoises absorbing-state corrupted sequences.

    Forward: ``x`` (B, T) token ids with ``[MASK]`` at some positions, ``t``
    (B,) diffusion time, ``cond`` (B,) treatment-condition id -> logits
    (B, T, vocab_size). A 0-d ``t`` or ``cond`` is broadcast to the batch.

    Attention is block-causal. With ``block = position // block_size``, a
    query in block ``k`` may attend to every key in blocks ``0..k`` and to no
    key in a later block.

    Loss (``diffusion_loss``) is cross-entropy over the positions that were
    corrupted to ``[MASK]`` only.
    """

    def __init__(self, cfg: CWMConfig | None = None):
        super().__init__()
        self.cfg = cfg if cfg is not None else CWMConfig()
        c = self.cfg
        self.tok_emb = nn.Embedding(c.vocab_size, c.d_model)
        self.pos_emb = nn.Embedding(c.max_seq_len, c.d_model)
        self.cond_emb = nn.Embedding(c.n_conditions, c.d_model)
        self.time_emb = SinusoidalTimeEmbed(c.d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=c.d_model, nhead=c.n_heads,
            dim_feedforward=c.ffn, dropout=c.dropout,
            batch_first=True, activation="gelu", norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=c.n_layers)
        # norm_first=True (pre-norm) layers leave the residual stream
        # un-normalised, so normalise once before the output head.
        self.norm = nn.LayerNorm(c.d_model)
        self.head = nn.Linear(c.d_model, c.vocab_size)
        self._register_block_mask()

    def _register_block_mask(self) -> None:
        """Register ``block_mask``: (max_seq_len, max_seq_len) bool, True = blocked.

        PyTorch boolean attention masks are indexed ``[query, key]``, and True
        means "not allowed to attend". Query ``i`` is therefore blocked from
        key ``j`` exactly when ``j`` lies in a later block than ``i``.
        """
        T = self.cfg.max_seq_len
        bs = self.cfg.block_size
        block_id = torch.arange(T) // bs
        query_block = block_id.unsqueeze(1)  # (T, 1): rows are queries
        key_block = block_id.unsqueeze(0)    # (1, T): columns are keys
        # Hide only *future* blocks. The reversed comparison would hide the
        # past and expose the future (anti-causal attention).
        mask = key_block > query_block
        # NOTE: this is a persistent buffer, so it is stored in state_dict.
        # Checkpoints saved with the old inverted mask will restore that mask
        # when loaded.
        self.register_buffer("block_mask", mask)

    def forward(self, x: torch.Tensor, t: torch.Tensor,
                cond: torch.Tensor) -> torch.Tensor:
        """Return logits (B, T, vocab_size) for tokens ``x`` at time ``t`` under ``cond``.

        Raises:
            ValueError: if ``T`` exceeds ``cfg.max_seq_len`` (the size of the
                position embedding and the precomputed block mask).
        """
        B, T = x.shape
        if T > self.cfg.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={self.cfg.max_seq_len}"
            )
        # Accept one scalar time / condition for the whole batch.
        if t.dim() == 0:
            t = t.expand(B)
        if cond.dim() == 0:
            cond = cond.expand(B)
        pos = torch.arange(T, device=x.device).unsqueeze(0).expand(B, T)
        h = self.tok_emb(x) + self.pos_emb(pos)
        # Time and condition are per sequence; broadcast them over all T positions.
        h = h + self.time_emb(t).unsqueeze(1) + self.cond_emb(cond).unsqueeze(1)
        mask = self.block_mask[:T, :T]
        h = self.transformer(h, mask=mask, is_causal=False)
        return self.head(self.norm(h))

    @torch.no_grad()
    def absorbing_corrupt(self, x_clean: torch.Tensor,
                          t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward diffusion: mask each token independently with probability ``t``.

        Args:
            x_clean: (B, T) clean token ids.
            t: (B,) per-sequence masking probability.

        Returns:
            ``(x_noisy, m)``. ``x_noisy`` is ``x_clean`` with the positions in
            ``m`` replaced by ``cfg.mask_token``, and ``m`` is the (B, T) bool
            mask of corrupted positions.
        """
        mp = t.unsqueeze(-1)  # (B, 1): one rate per sequence, broadcast over T
        m = torch.rand_like(x_clean, dtype=torch.float) < mp
        x_noisy = torch.where(m, self.cfg.mask_token, x_clean)
        return x_noisy, m

    def diffusion_loss(self, x_clean: torch.Tensor,
                       cond: torch.Tensor) -> torch.Tensor:
        """Absorbing-state diffusion training loss with CFG condition dropout.

        Steps:
        1. In training mode only, each sequence's condition is replaced by
           ``cfg.null_cond`` with probability ``cfg.cond_dropout``.
        2. One time ``t ~ U(0, 1)``, clamped to [1e-3, 1 - 1e-3], is drawn per
           sequence, and tokens are masked with probability ``t``.
        3. The result is the unweighted mean cross-entropy over all masked
           positions in the batch. There is no 1/t ELBO weighting. The result
           is 0 if nothing was masked.

        NOTE(review): no padding mask is applied. If batches are padded, pad
        positions are attended to and counted in the loss; confirm against the
        data pipeline.
        NOTE(review): earlier blocks are corrupted with the same ``t`` as the
        current one. If the sampler conditions on clean previous blocks (as
        BD3-LMs do), this is a train/inference mismatch; confirm.

        Args:
            x_clean: (B, T) clean token ids.
            cond: (B,) condition ids.

        Returns:
            Scalar loss tensor.
        """
        B = x_clean.size(0)
        device = x_clean.device

        # CFG dropout is a training-time regulariser; evaluation should score
        # the conditional model as given.
        if self.training:
            drop = torch.rand(B, device=device) < self.cfg.cond_dropout
            cond = cond.masked_fill(drop, self.cfg.null_cond)

        t = torch.rand(B, device=device).clamp(min=1e-3, max=1.0 - 1e-3)
        x_noisy, mask = self.absorbing_corrupt(x_clean, t)

        # Call the module (not .forward) so registered hooks still run.
        logits = self(x_noisy, t, cond)

        ce = F.cross_entropy(
            logits.permute(0, 2, 1), x_clean, reduction="none",
        )
        maskf = mask.float()
        n_masked = maskf.sum().clamp(min=1.0)
        return (ce * maskf).sum() / n_masked
|
|