File size: 5,863 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""Block Diffusion Transformer for clinical event trajectories.

Architecture follows BD3-LMs (Kuleshov group, ICLR 2025 Oral, arXiv 2503.09573):
  - Absorbing-state discrete diffusion (each clean token can become [MASK]
    with probability t in [0,1])
  - Block-causal attention: token i attends to all tokens in its block + all
    earlier blocks (block size B=16). Lets sampling proceed block-by-block
    while still parallelizing within a block.
  - Time conditioning via sinusoidal embedding of t.
  - Classifier-free guidance: each sequence has a (treatment, time-zero,
    cohort-key) condition. During training, each sequence independently has
    its condition replaced by the null token with probability 10%
    (CWMConfig.cond_dropout); at inference we mix log-probs with guidance
    scale gamma.

Calibrated for M4 Pro 24GB / MPS:
  - 80M params (d_model=640, n_heads=10, n_layers=16, ffn=2560)
  - 384-token max sequence (covers ~24 events including BOS/EOS/SEP/YEAR)
  - Block size 16, denoise steps 100 at inference (~2.5s per trajectory)
  - 1.2h to train 100 epochs on the ~600-sequence DATASUS v1 cohort
"""
from __future__ import annotations
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class CWMConfig:
    """Hyperparameters for the block-diffusion transformer.

    Cross-field constraints are checked in ``__post_init__`` so an
    inconsistent config fails at construction time with a clear message,
    instead of later as an opaque embedding-index or attention error.
    """
    vocab_size: int = 217          # 216 events + 1 MASK absorbing state
    mask_token: int = 216          # id of the absorbing [MASK] state; must be < vocab_size
    null_cond: int = 0             # CFG null condition (= "no treatment specified")
    n_conditions: int = 32         # treatment/intervention vocabulary
    max_seq_len: int = 384         # size of the learned positional-embedding table
    block_size: int = 16           # tokens per block in the block-causal mask
    d_model: int = 640
    n_heads: int = 10              # must divide d_model
    n_layers: int = 16
    ffn: int = 2560                # feed-forward hidden width
    dropout: float = 0.1
    n_diff_steps: int = 100        # inference denoise steps
    cond_dropout: float = 0.10     # CFG training: drop condition with this prob

    def __post_init__(self) -> None:
        """Validate cross-field invariants.

        Raises:
            ValueError: if a token/condition id falls outside its embedding
                table, ``n_heads`` does not divide ``d_model``, a size is
                non-positive, or a probability lies outside [0, 1].
        """
        if not 0 <= self.mask_token < self.vocab_size:
            raise ValueError(
                f"mask_token={self.mask_token} must be in "
                f"[0, vocab_size={self.vocab_size})"
            )
        if not 0 <= self.null_cond < self.n_conditions:
            raise ValueError(
                f"null_cond={self.null_cond} must be in "
                f"[0, n_conditions={self.n_conditions})"
            )
        # Multi-head attention splits d_model evenly across heads.
        if self.n_heads < 1 or self.d_model % self.n_heads != 0:
            raise ValueError(
                f"n_heads={self.n_heads} must be >= 1 and divide "
                f"d_model={self.d_model}"
            )
        if self.block_size < 1:
            raise ValueError(f"block_size={self.block_size} must be >= 1")
        if self.max_seq_len < 1:
            raise ValueError(f"max_seq_len={self.max_seq_len} must be >= 1")
        for name in ("dropout", "cond_dropout"):
            p = getattr(self, name)
            if not 0.0 <= p <= 1.0:
                raise ValueError(f"{name}={p} must be a probability in [0, 1]")


class SinusoidalTimeEmbed(nn.Module):
    """Sinusoidal embedding of continuous time t in [0,1], followed by an MLP.

    ``forward`` maps t of shape (...,) to (..., d). It builds d//2 sine and
    d//2 cosine features at frequencies ``10000 ** (-k / (d//2))`` for
    k = 0..d//2-1. It appends one zero column when d is odd, then applies
    Linear(d, d) -> SiLU -> Linear(d, d).
    """
    def __init__(self, d: int):
        super().__init__()
        self.d = d
        self.proj = nn.Sequential(
            nn.Linear(d, d), nn.SiLU(), nn.Linear(d, d),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.d // 2
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, device=t.device) / half
        )
        ang = t.float().unsqueeze(-1) * freqs
        emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
        if self.d % 2:
            # Odd d: sin/cos only supply 2*half = d-1 features. Pad one zero
            # column so the width matches the Linear(d, d) input. new_zeros
            # (rather than slicing emb) also works when d == 1 and emb is empty.
            emb = torch.cat([emb, emb.new_zeros(emb.shape[:-1] + (1,))], dim=-1)
        return self.proj(emb)


class BlockDiffusionTransformer(nn.Module):
    """Block-causal Transformer that denoises absorbing-state corrupted sequences.

    Forward: x (B, T) tokens with [MASK] at some positions, t (B,) diffusion
    time, cond (B,) treatment-condition id -> logits (B, T, V).

    Loss is masked cross-entropy: only positions that were corrupted to MASK
    contribute. It is an unweighted mean over all masked tokens in the batch.

    NOTE(review): the MDLM / BD3-LM objective weights each sequence's masked
    CE by 1/t under a linear schedule; confirm the unweighted form is intended.
    NOTE(review): during training every block is corrupted at the same t, so
    earlier blocks are attended to while noisy. In block-by-block sampling,
    earlier blocks are presumably already clean. TODO confirm the sampler's
    conditioning matches training.
    """
    def __init__(self, cfg: CWMConfig | None = None):
        super().__init__()
        self.cfg = cfg or CWMConfig()
        c = self.cfg
        self.tok_emb = nn.Embedding(c.vocab_size, c.d_model)
        self.pos_emb = nn.Embedding(c.max_seq_len, c.d_model)
        self.cond_emb = nn.Embedding(c.n_conditions, c.d_model)
        self.time_emb = SinusoidalTimeEmbed(c.d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=c.d_model, nhead=c.n_heads,
            dim_feedforward=c.ffn, dropout=c.dropout,
            batch_first=True, activation="gelu", norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=c.n_layers)
        self.norm = nn.LayerNorm(c.d_model)
        self.head = nn.Linear(c.d_model, c.vocab_size)
        self._register_block_mask()

    def _register_block_mask(self):
        """Register the (max_seq_len, max_seq_len) boolean block-causal mask.

        Follows the ``nn.TransformerEncoder`` bool-mask convention:
        ``mask[i, j] = True`` means query token i may NOT attend to key
        token j. A token is blocked only from keys in strictly later blocks,
        so it sees its own block and every earlier block.
        """
        T = self.cfg.max_seq_len
        bs = self.cfg.block_size
        block_id = torch.arange(T) // bs
        # Rows index the query i, columns the key j:
        # mask[i, j] = block_id[j] > block_id[i]  (key lies in a later block).
        mask = block_id.unsqueeze(0) > block_id.unsqueeze(1)
        # NOTE(review): this is a persistent buffer, so it is saved in
        # state_dict. Checkpoints saved with the earlier inverted (`<`) mask
        # overwrite this one on load_state_dict; rebuild it after loading
        # such checkpoints.
        self.register_buffer("block_mask", mask)

    def forward(self, x: torch.Tensor, t: torch.Tensor,
                cond: torch.Tensor) -> torch.Tensor:
        """Predict clean-token logits for every position.

        Args:
            x: (B, T) token ids, with ``cfg.mask_token`` at corrupted positions.
            t: (B,) diffusion time, fed to the sinusoidal time embedding.
            cond: (B,) condition ids (``cfg.null_cond`` = unconditional).

        Returns:
            (B, T, vocab_size) logits.

        Raises:
            ValueError: if T exceeds ``cfg.max_seq_len`` (positional table
                size); otherwise the embedding lookup would fail opaquely.
        """
        B, T = x.shape
        if T > self.cfg.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={self.cfg.max_seq_len}"
            )
        pos = torch.arange(T, device=x.device).unsqueeze(0).expand(B, T)
        h = self.tok_emb(x) + self.pos_emb(pos)
        # Time and condition embeddings are broadcast-added to every position.
        h = h + self.time_emb(t).unsqueeze(1) + self.cond_emb(cond).unsqueeze(1)
        mask = self.block_mask[:T, :T]
        h = self.transformer(h, mask=mask, is_causal=False)
        return self.head(self.norm(h))

    @torch.no_grad()
    def absorbing_corrupt(self, x_clean: torch.Tensor,
                          t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward diffusion: mask each token independently with probability t.

        Args:
            x_clean: (B, T) clean token ids.
            t: (B,) per-sequence masking probability.

        Returns:
            ``(x_noisy, m)``. ``x_noisy`` is ``x_clean`` with corrupted
            positions replaced by ``cfg.mask_token``. ``m`` is the boolean
            (B, T) mask of corrupted positions.
        """
        mp = t.unsqueeze(-1)  # (B, 1) so it broadcasts across positions
        m = torch.rand_like(x_clean, dtype=torch.float) < mp
        x_noisy = torch.where(m, self.cfg.mask_token, x_clean)
        return x_noisy, m

    def diffusion_loss(self, x_clean: torch.Tensor,
                       cond: torch.Tensor) -> torch.Tensor:
        """Masked-token cross-entropy with CFG condition dropout.

        Each sequence has its condition replaced by ``cfg.null_cond`` with
        probability ``cfg.cond_dropout``. It then gets a noise level
        t ~ U(1e-3, 1 - 1e-3) and is corrupted with ``absorbing_corrupt``.
        Returns the mean CE over all corrupted positions in the batch; the
        denominator is clamped to 1 so a batch with no masked tokens gives 0.
        """
        B = x_clean.size(0)
        device = x_clean.device

        # CFG: replace the condition by the configured null id with prob cond_dropout.
        drop = torch.rand(B, device=device) < self.cfg.cond_dropout
        cond = torch.where(drop, torch.full_like(cond, self.cfg.null_cond), cond)

        # Sample noise level uniformly in (0,1), kept away from the endpoints.
        t = torch.rand(B, device=device).clamp(min=1e-3, max=1.0 - 1e-3)
        x_noisy, is_masked = self.absorbing_corrupt(x_clean, t)

        logits = self.forward(x_noisy, t, cond)
        # cross_entropy expects class logits on dim 1: (B, V, T).
        ce = F.cross_entropy(
            logits.permute(0, 2, 1), x_clean, reduction="none",
        )
        weights = is_masked.float()
        n_masked = weights.sum().clamp(min=1.0)
        return (ce * weights).sum() / n_masked