File size: 5,863 Bytes
"""Block Diffusion Transformer for clinical event trajectories.
Architecture follows BD3-LMs (Kuleshov group, ICLR 2025 Oral, arXiv 2503.09573):
- Absorbing-state discrete diffusion (each clean token can become [MASK]
with probability t in [0,1])
- Block-causal attention: token i attends to all tokens in its block + all
earlier blocks (block size B=16). Lets sampling proceed block-by-block
while still parallelizing within a block.
- Time conditioning via sinusoidal embedding of t.
- Classifier-free guidance: each sequence has a (treatment, time-zero,
cohort-key) condition. During training, each sequence's condition is
independently replaced by the null token with probability 0.10; at inference
we mix log-probs with guidance scale gamma.
Calibrated for M4 Pro 24GB / MPS:
- 80M params (d_model=640, n_heads=10, n_layers=16, ffn=2560)
- 384-token max sequence (covers ~24 events including BOS/EOS/SEP/YEAR)
- Block size 16, denoise steps 100 at inference (~2.5s per trajectory)
- 1.2h to train 100 epochs on the ~600-sequence DATASUS v1 cohort
"""
from __future__ import annotations
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class CWMConfig:
    """Hyper-parameters for the block diffusion transformer.

    Field order is part of the interface (positional construction), so the
    fields are only grouped by comments, never reordered.
    """

    # --- token vocabulary ---------------------------------------------
    # 216 events + 1 MASK absorbing state
    vocab_size: int = 217
    mask_token: int = 216

    # --- conditioning (classifier-free guidance) ----------------------
    # CFG null condition (= "no treatment specified")
    null_cond: int = 0
    # treatment/intervention vocabulary
    n_conditions: int = 32

    # --- sequence layout ----------------------------------------------
    max_seq_len: int = 384
    block_size: int = 16

    # --- transformer size ---------------------------------------------
    d_model: int = 640
    n_heads: int = 10
    n_layers: int = 16
    ffn: int = 2560
    dropout: float = 0.1

    # --- diffusion / training -----------------------------------------
    # inference denoise steps
    n_diff_steps: int = 100
    # CFG training: drop condition with this prob
    cond_dropout: float = 0.10
class SinusoidalTimeEmbed(nn.Module):
    """Sinusoidal embedding of continuous time t in [0,1], refined by an MLP.

    The raw features are the sine and cosine of ``t`` multiplied by
    ``d // 2`` geometrically spaced frequencies (starting at 1 and decaying
    towards 1/10000). They are then passed through a two-layer SiLU MLP of
    width ``d``.

    Args:
        d: Width of the returned embedding. An odd ``d`` is supported: the
            sin/cos features only fill ``2 * (d // 2)`` slots, so the last
            slot is zero-padded before the MLP.
    """

    def __init__(self, d: int):
        super().__init__()
        self.d = d
        self.proj = nn.Sequential(
            nn.Linear(d, d), nn.SiLU(), nn.Linear(d, d),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed ``t`` of shape ``(...)`` into a tensor of shape ``(..., d)``."""
        half = self.d // 2
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, device=t.device) / half
        )
        ang = t.float().unsqueeze(-1) * freqs
        emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
        # sin/cos yields 2 * half features; for odd d that is d - 1, which
        # would not match the input width of self.proj. Pad with zeros.
        missing = self.d - emb.size(-1)
        if missing > 0:
            emb = F.pad(emb, (0, missing))
        return self.proj(emb)
class BlockDiffusionTransformer(nn.Module):
    """Block-causal Transformer that denoises absorbing-state corrupted sequences.

    Forward: x (B, T) tokens with [MASK] at some positions, t (B,) diffusion
    time, cond (B,) treatment-condition id -> logits (B, T, V).

    Loss is masked cross-entropy: only positions that were corrupted to MASK
    contribute, and the result is the mean over those positions (no
    per-timestep weighting is applied).
    """

    def __init__(self, cfg: CWMConfig | None = None):
        """Build embeddings, the pre-norm GELU encoder stack and the output head.

        Args:
            cfg: Model hyper-parameters; a default ``CWMConfig()`` is used
                when omitted.
        """
        super().__init__()
        self.cfg = cfg or CWMConfig()
        c = self.cfg
        self.tok_emb = nn.Embedding(c.vocab_size, c.d_model)
        self.pos_emb = nn.Embedding(c.max_seq_len, c.d_model)
        self.cond_emb = nn.Embedding(c.n_conditions, c.d_model)
        self.time_emb = SinusoidalTimeEmbed(c.d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=c.d_model, nhead=c.n_heads,
            dim_feedforward=c.ffn, dropout=c.dropout,
            batch_first=True, activation="gelu", norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=c.n_layers)
        self.norm = nn.LayerNorm(c.d_model)
        self.head = nn.Linear(c.d_model, c.vocab_size)
        self._register_block_mask()

    def _register_block_mask(self):
        """Precompute the (max_seq_len, max_seq_len) block-causal mask buffer.

        ``block_mask[i, j]`` is True when query position i must NOT attend to
        key position j (the convention PyTorch uses for boolean attention
        masks). Block-causal means a token sees its own block and all earlier
        blocks, so only keys in a strictly later block are masked out.
        """
        T = self.cfg.max_seq_len
        bs = self.cfg.block_size
        block_id = torch.arange(T) // bs
        query_block = block_id.unsqueeze(1)  # (T, 1): block of row i
        key_block = block_id.unsqueeze(0)    # (1, T): block of column j
        # The comparison direction matters: `key < query` would hide the
        # past and expose the future.
        mask = key_block > query_block
        # NOTE: this buffer is persistent, so it is saved in state_dict();
        # loading a checkpoint written with a different mask overwrites it.
        self.register_buffer("block_mask", mask)

    def forward(self, x: torch.Tensor, t: torch.Tensor,
                cond: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (B, T, vocab_size) for noisy tokens ``x``.

        Args:
            x: (B, T) token ids, possibly containing the mask token.
            t: (B,) diffusion time per sequence.
            cond: (B,) condition id per sequence.

        Raises:
            ValueError: if T exceeds ``cfg.max_seq_len`` (there are no
                positional embeddings or mask rows beyond that length).
        """
        B, T = x.shape
        if T > self.cfg.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len "
                f"{self.cfg.max_seq_len}"
            )
        pos = torch.arange(T, device=x.device).unsqueeze(0).expand(B, T)
        h = self.tok_emb(x) + self.pos_emb(pos)
        # Time and condition embeddings are per-sequence; broadcast over T.
        h = h + self.time_emb(t).unsqueeze(1) + self.cond_emb(cond).unsqueeze(1)
        mask = self.block_mask[:T, :T]
        h = self.transformer(h, mask=mask, is_causal=False)
        return self.head(self.norm(h))

    @torch.no_grad()
    def absorbing_corrupt(self, x_clean: torch.Tensor,
                          t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward diffusion: each token independently masked with probability t.

        Args:
            x_clean: (B, T) clean token ids.
            t: (B,) masking probability per sequence.

        Returns:
            ``(x_noisy, m)`` where ``m`` is the boolean (B, T) mask of
            corrupted positions and ``x_noisy`` equals ``x_clean`` with those
            positions replaced by ``cfg.mask_token``.
        """
        mp = t.unsqueeze(-1)
        m = torch.rand_like(x_clean, dtype=torch.float) < mp
        x_noisy = x_clean.masked_fill(m, self.cfg.mask_token)
        return x_noisy, m

    def diffusion_loss(self, x_clean: torch.Tensor,
                       cond: torch.Tensor) -> torch.Tensor:
        """Absorbing-state diffusion loss with CFG cond-dropout.

        Args:
            x_clean: (B, T) clean token ids.
            cond: (B,) condition id per sequence.

        Returns:
            Scalar mean cross-entropy over the positions that were masked.
        """
        B = x_clean.size(0)
        device = x_clean.device
        # CFG: replace each sequence's condition by the configured null
        # condition with prob cond_dropout.
        drop = torch.rand(B, device=device) < self.cfg.cond_dropout
        cond = torch.where(
            drop, torch.full_like(cond, self.cfg.null_cond), cond,
        )
        # Sample noise level uniformly, clamped away from exactly 0 and 1.
        t = torch.rand(B, device=device).clamp(min=1e-3, max=1.0 - 1e-3)
        x_noisy, mask = self.absorbing_corrupt(x_clean, t)
        logits = self.forward(x_noisy, t, cond)
        # CE loss only on masked positions; cross_entropy wants (B, V, T).
        ce = F.cross_entropy(
            logits.permute(0, 2, 1), x_clean, reduction="none",
        )
        # Clamp so a batch with no masked position yields 0 instead of NaN.
        n_masked = mask.float().sum().clamp(min=1.0)
        return (ce * mask.float()).sum() / n_masked
|