File size: 6,014 Bytes
908ea05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""AdaLN-Zero conditioning module (DiT-style, Peebles 2023).

Used in: DiT (ICCV 2023), Stable Diffusion 3 (Esser 2024), Sora, Lumina-Next,
PixArt-Sigma. Standard for diffusion conditioning in 2025-2026.

Why for Diffusion Forcing on EHR:
  - Per-token sigma + global cond/action → per-token (scale, shift, gate)
  - Gates init to zero ⇒ block starts as identity ⇒ no catastrophic init
  - Much better CFG (dropped condition path goes through zero gates,
    not corrupting residual stream)
  - DFoT (Diffusion Forcing Transformer 2, ICLR 2026) confirms +3-8% win

We fuse THREE conditioning signals:
  - sigma (B, T) per-token noise level → time_emb (B, T, D)
  - cond  (B,)   cohort-level treatment id → cond_emb (B, D) → broadcast
  - action(B, T) per-token latent action id → action_emb (B, T, D)

Combined into c_t (B, T, D) → AdaLNZeroModulator → 6 modulation tensors
per block. Each block uses them as:

    h = x + gate_msa * Attn((1 + scale_msa) * Norm(x) + shift_msa)
    h = h + gate_mlp * MLP((1 + scale_mlp) * Norm(h) + shift_mlp)
"""
from __future__ import annotations
import torch
import torch.nn as nn


class AdaLNZeroModulator(nn.Module):
    """Map a conditioning vector to the six AdaLN-Zero modulation tensors.

    The module is ``SiLU -> Linear(d_model, 6 * d_model)``. The linear
    layer's weight and bias both start at zero, so every modulation tensor
    (gates included) is exactly zero until training updates the layer.

    Input:  fused conditioning ``c`` with last dimension ``d_model``
            (documented shape ``(B, T, d_model)``).
    Output: tuple of 6 tensors, each with the same shape as ``c``, ordered
            ``(scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp)``.
    """
    def __init__(self, d_model: int):
        super().__init__()
        to_modulation = nn.Linear(d_model, 6 * d_model, bias=True)
        # AdaLN-Zero: all six outputs start at zero, so the gates are zero
        # and a block driven by them leaves its input unchanged at init.
        for param in (to_modulation.weight, to_modulation.bias):
            nn.init.zeros_(param)
        self.modulator = nn.Sequential(nn.SiLU(), to_modulation)

    def forward(self, c: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Return the six modulation tensors computed from ``c``."""
        stacked = self.modulator(c)  # (..., 6 * d_model)
        # Six equal slices along the feature axis, in the documented order.
        return torch.chunk(stacked, 6, dim=-1)


class AdaLNZeroBlock(nn.Module):
    """Transformer block with AdaLN-Zero modulation.

    Drop-in replacement for the standard pre-norm block. Reads
    pre-computed modulation tensors and applies them around Attn + MLP:

        x = x + gate_msa * Attn(Norm(x) * (1 + scale_msa) + shift_msa)
        x = kg_xattn(x, kg_ctx)            # only if both are provided
        x = x + gate_mlp * MLP(Norm(x) * (1 + scale_mlp) + shift_mlp)

    Both LayerNorms are created without learnable affine parameters; the
    scale/shift come entirely from the modulation tensors.

    Args:
        d_model: Feature width of the residual stream.
        n_heads: Number of attention heads; must divide ``d_model``.
        ffn: Hidden width of the MLP.
        dropout: Dropout probability used both inside attention (training
            only) and on each branch output.
        rope: Optional callable invoked as ``rope(q, k, T)`` that must return
            the pair ``(q, k)``; ``q``/``k`` are ``(B, n_heads, T, head_dim)``.
        kg_xattn: Optional callable invoked as ``kg_xattn(x, kg_ctx)``. Its
            return value replaces ``x``, so any residual connection is that
            callable's responsibility.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """
    def __init__(self, d_model: int, n_heads: int, ffn: int, dropout: float,
                 rope=None, kg_xattn=None):
        super().__init__()
        # Fail at construction: otherwise head_dim is silently truncated and
        # the problem only surfaces later as an opaque reshape error.
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.rope = rope
        self.kg_xattn = kg_xattn

        self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ffn, bias=False),
            nn.GELU(),
            nn.Linear(ffn, d_model, bias=False),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor,
                scale_msa, shift_msa, gate_msa,
                scale_mlp, shift_mlp, gate_mlp,
                kg_ctx: torch.Tensor | None = None) -> torch.Tensor:
        """Apply the modulated attention and MLP branches to ``x``.

        Args:
            x: Input of shape ``(B, T, D)``.
            attn_mask: Boolean tensor where ``True`` marks a (query, key)
                pair that must NOT attend. Two leading singleton dims are
                prepended before use, so it is presumably ``(T, T)`` --
                confirm against the caller.
            scale_msa, shift_msa, gate_msa: Modulation for the attention
                branch; must broadcast against ``(B, T, D)``.
            scale_mlp, shift_mlp, gate_mlp: Modulation for the MLP branch;
                must broadcast against ``(B, T, D)``.
            kg_ctx: Optional context forwarded to ``kg_xattn``; the
                cross-attention step is skipped when it is ``None``.

        Returns:
            Tensor of shape ``(B, T, D)``.
        """
        B, T, D = x.shape
        # MSA branch
        h = self.norm1(x) * (1 + scale_msa) + shift_msa
        qkv = self.qkv(h).reshape(B, T, 3, self.n_heads, self.head_dim)
        # (3, B, n_heads, T, head_dim) -> three (B, n_heads, T, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        if self.rope is not None:
            q, k = self.rope(q, k, T)
        # SDPA's boolean convention is the inverse of ours (True = attend).
        # Passing the boolean mask directly avoids allocating a float bias on
        # every call and adds no constant offset to the unmasked scores.
        keep = (~attn_mask)[None, None]
        out = nn.functional.scaled_dot_product_attention(
            q, k, v,
            attn_mask=keep,
            dropout_p=self.dropout.p if self.training else 0.0,
        )
        out = out.transpose(1, 2).reshape(B, T, D)
        x = x + gate_msa * self.dropout(self.proj(out))
        # KG cross-attention (between MSA and MLP)
        if self.kg_xattn is not None and kg_ctx is not None:
            x = self.kg_xattn(x, kg_ctx)
        # MLP branch
        h = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        x = x + gate_mlp * self.dropout(self.mlp(h))
        return x


class FusedConditioner(nn.Module):
    """Fuse (sigma, cond, action) into one per-token conditioning vector.

    The three embeddings are summed and passed through ``SiLU -> Linear``.
    Output (B, T, d_model) consumed by AdaLNZeroModulator per layer.

    Args:
        d_model: Width of every embedding and of the output. Odd values are
            supported (the sinusoidal features are zero-padded by one column).
        n_conditions: Number of rows in the condition embedding table.
        n_actions: The action embedding table has ``n_actions + 1`` rows, so
            ids in ``[0, n_actions]`` are valid. The extra row is presumably
            a null/padding action -- confirm against the caller.
        use_action: When False no action embedding table is created and any
            ``action`` passed to ``forward`` is ignored.
    """
    def __init__(self, d_model: int, n_conditions: int, n_actions: int,
                 use_action: bool = True):
        super().__init__()
        self.d_model = d_model
        self.use_action = use_action
        # Sigma → sinusoidal features → 2-layer MLP
        self.sigma_proj = nn.Sequential(
            nn.Linear(d_model, d_model), nn.SiLU(), nn.Linear(d_model, d_model),
        )
        self.cond_emb = nn.Embedding(n_conditions, d_model)
        if use_action:
            self.action_emb = nn.Embedding(n_actions + 1, d_model)
        self.fuse = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

    def sinusoidal(self, sigma: torch.Tensor) -> torch.Tensor:
        """Embed noise levels as projected sinusoidal features.

        Builds ``d_model // 2`` geometrically spaced frequencies
        ``exp(-ln(10000) * i / half)``, concatenates ``sin`` and ``cos`` of
        ``sigma * freq`` and feeds the result through ``sigma_proj``.

        Args:
            sigma: Noise levels of any shape (documented as ``(B, T)``).

        Returns:
            Tensor of shape ``sigma.shape + (d_model,)``.
        """
        import math
        half = self.d_model // 2
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, device=sigma.device) / half
        )
        ang = sigma.float().unsqueeze(-1) * freqs
        emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
        if self.d_model % 2:
            # sin/cos only supply 2 * half == d_model - 1 features here; pad
            # one zero column so the width matches sigma_proj's input size.
            emb = torch.cat([emb, emb.new_zeros(*emb.shape[:-1], 1)], dim=-1)
        return self.sigma_proj(emb)

    def forward(self, sigma: torch.Tensor, cond: torch.Tensor,
                action: torch.Tensor | None = None) -> torch.Tensor:
        """Return the fused conditioning tensor.

        Args:
            sigma: Per-token noise level, ``(B, T)``.
            cond: Integer condition ids, ``(B,)``; broadcast over ``T``.
            action: Optional integer action ids, ``(B, T)``. Only used when
                the module was built with ``use_action=True``.

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        # sigma (B, T) → time_emb (B, T, D)
        time_emb = self.sinusoidal(sigma)
        # cond (B,) → (B, D) → broadcast to (B, T, D)
        cond_emb = self.cond_emb(cond).unsqueeze(1).expand_as(time_emb)
        fused = time_emb + cond_emb
        if self.use_action and action is not None:
            fused = fused + self.action_emb(action)
        return self.fuse(fused)