"""AdaLN-Zero conditioning module (DiT-style, Peebles 2023).

Used in: DiT (ICCV 2023), Stable Diffusion 3 (Esser 2024), Sora, Lumina-Next,
PixArt-Sigma. Standard for diffusion conditioning in 2025-2026.

Why for Diffusion Forcing on EHR:
- Per-token sigma and action plus global cond → per-token (scale, shift, gate)
- Gates are initialised to zero → each block starts as the identity → no
  catastrophic initialisation
- Much better CFG (the dropped-condition path goes through zero gates
  instead of corrupting the residual stream)
- DFoT (Diffusion Forcing Transformer 2, ICLR 2026) confirms a +3-8% win

We fuse THREE conditioning signals:
- sigma  (B, T) per-token noise level      → time_emb   (B, T, D)
- cond   (B,)   cohort-level treatment id  → cond_emb   (B, D) → broadcast to (B, T, D)
- action (B, T) per-token latent action id → action_emb (B, T, D)

FusedConditioner combines them into c_t (B, T, D); each block's
AdaLNZeroModulator maps c_t to 6 modulation tensors, which the block uses as
(optionally with a KG cross-attention step between the two lines):

    h = x + gate_msa * Attn((1 + scale_msa) * Norm(x) + shift_msa)
    h = h + gate_mlp * MLP((1 + scale_mlp) * Norm(h) + shift_mlp)
"""
from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaLNZeroModulator(nn.Module):
    """Map a fused conditioning vector to the six AdaLN-Zero modulations.

    Args:
        d_model: Width of the conditioning vector and of every modulation.

    ``forward(c)`` applies ``Linear(SiLU(c))`` and splits the last dimension
    into six equal parts, returned in this order::

        (scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp)

    For ``c`` of shape (B, T, d_model), each part has shape (B, T, d_model).
    The output projection's weight and bias are zero-initialised, so every
    modulation (including both gates) is exactly zero at initialisation.
    """

    def __init__(self, d_model: int):
        super().__init__()
        out_proj = nn.Linear(d_model, 6 * d_model, bias=True)
        # The "-Zero" in AdaLN-Zero: start with all scales/shifts/gates at 0.
        nn.init.zeros_(out_proj.weight)
        nn.init.zeros_(out_proj.bias)
        self.modulator = nn.Sequential(nn.SiLU(), out_proj)

    def forward(self, c: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Return the six modulation tensors derived from ``c``."""
        return torch.chunk(self.modulator(c), chunks=6, dim=-1)




class AdaLNZeroBlock(nn.Module):
    """Transformer block with AdaLN-Zero modulation.

    Pre-norm block whose LayerNorms have no learned affine parameters;
    instead it reads six pre-computed modulation tensors and applies::

        x = x + gate_msa * Dropout(Proj(Attn((1 + scale_msa) * Norm1(x) + shift_msa)))
        x = kg_xattn(x, kg_ctx)          # only when both are provided
        x = x + gate_mlp * Dropout(MLP((1 + scale_mlp) * Norm2(x) + shift_mlp))

    With both gates at zero and no KG cross-attention, the block returns
    ``x`` unchanged.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        ffn: Hidden width of the feed-forward MLP.
        dropout: Probability used both for attention dropout (training only)
            and for the dropout on the attention and MLP residual branches.
        rope: Optional callable invoked as ``rope(q, k, T)`` on per-head
            queries/keys of shape (B, n_heads, T, head_dim); must return the
            new ``(q, k)``.
        kg_xattn: Optional module invoked as ``kg_xattn(x, kg_ctx)`` between
            the two sub-blocks; its output replaces ``x``
            (NOTE(review): presumably it applies its own residual — confirm).

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model: int, n_heads: int, ffn: int, dropout: float,
                 rope=None, kg_xattn=None):
        super().__init__()
        # Fail at construction instead of with an opaque reshape error later.
        if n_heads <= 0 or d_model % n_heads:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.rope = rope
        self.kg_xattn = kg_xattn

        # No affine params: scale/shift come from the AdaLN modulation instead.
        self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ffn, bias=False),
            nn.GELU(),
            nn.Linear(ffn, d_model, bias=False),
        )
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _to_sdpa_mask(attn_mask: torch.Tensor | None) -> torch.Tensor | None:
        """Turn a boolean "True = blocked" mask into SDPA's "True = attend" mask.

        Accepts (T, T) (shared by the batch) or (B, T, T) (per sample) and
        returns a mask broadcastable to (B, n_heads, T, T), or ``None``.
        """
        if attn_mask is None:
            return None
        allowed = ~attn_mask
        if allowed.dim() == 2:
            return allowed[None, None]
        if allowed.dim() == 3:
            return allowed[:, None]
        raise ValueError(
            f"attn_mask must have shape (T, T) or (B, T, T), got {tuple(attn_mask.shape)}"
        )

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None,
                scale_msa: torch.Tensor, shift_msa: torch.Tensor,
                gate_msa: torch.Tensor,
                scale_mlp: torch.Tensor, shift_mlp: torch.Tensor,
                gate_mlp: torch.Tensor,
                kg_ctx: torch.Tensor | None = None) -> torch.Tensor:
        """Apply the modulated attention and MLP sub-blocks.

        Args:
            x: Token states, shape (B, T, d_model).
            attn_mask: Boolean mask where ``True`` marks query/key pairs that
                must NOT be attended; shape (T, T) or (B, T, T). ``None``
                disables masking. NOTE(review): a query row whose keys are
                all masked can yield NaN in SDPA — keep at least one
                allowed key per row (e.g. the diagonal).
            scale_msa, shift_msa, gate_msa: Attention-branch modulations,
                broadcastable against (B, T, d_model).
            scale_mlp, shift_mlp, gate_mlp: MLP-branch modulations, same
                broadcasting rule.
            kg_ctx: Optional context for ``kg_xattn``; the cross-attention
                step is skipped if either is ``None``.

        Returns:
            Updated token states, shape (B, T, d_model).
        """
        B, T, D = x.shape

        h = self.norm1(x) * (1 + scale_msa) + shift_msa
        qkv = self.qkv(h).reshape(B, T, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, H, T, head_dim)
        if self.rope is not None:
            q, k = self.rope(q, k, T)
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=self._to_sdpa_mask(attn_mask),
            dropout_p=self.dropout.p if self.training else 0.0,
        )
        out = out.transpose(1, 2).reshape(B, T, D)
        x = x + gate_msa * self.dropout(self.proj(out))

        if self.kg_xattn is not None and kg_ctx is not None:
            x = self.kg_xattn(x, kg_ctx)

        h = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        x = x + gate_mlp * self.dropout(self.mlp(h))
        return x




class FusedConditioner(nn.Module):
    """Fuse (sigma, cond, action) into one per-token conditioning vector.

    Computes::

        time_emb   = sigma_proj(sinusoid(sigma))       # (B, T, d_model)
        cond_emb   = cond_emb(cond), broadcast over T  # (B, T, d_model)
        action_emb = action_emb(action)                # (B, T, d_model), optional
        out        = Linear(SiLU(time_emb + cond_emb [+ action_emb]))

    The output (B, T, d_model) is consumed by AdaLNZeroModulator per layer.

    Args:
        d_model: Width of every embedding and of the output.
        n_conditions: Size of the condition-id table (``cond`` values index it).
        n_actions: Number of action ids. The action table has
            ``n_actions + 1`` rows (NOTE(review): the extra row is presumably
            a null / dropped-action id — confirm with the caller).
        use_action: If False, no action table is created and ``action`` is
            ignored by ``forward``.
        max_period: Longest period of the sinusoidal sigma features
            (keyword-only; the default 10000.0 preserves earlier behaviour).
            NOTE(review): if sigma lives in [0, 1], most frequencies see tiny
            angles — consider rescaling sigma upstream; confirm its range.
    """

    def __init__(self, d_model: int, n_conditions: int, n_actions: int,
                 use_action: bool = True, *, max_period: float = 10000.0):
        super().__init__()
        self.d_model = d_model
        self.use_action = use_action
        self.max_period = max_period

        self.sigma_proj = nn.Sequential(
            nn.Linear(d_model, d_model), nn.SiLU(), nn.Linear(d_model, d_model),
        )
        self.cond_emb = nn.Embedding(n_conditions, d_model)
        if use_action:
            self.action_emb = nn.Embedding(n_actions + 1, d_model)
        self.fuse = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

    def _sinusoidal_features(self, sigma: torch.Tensor) -> torch.Tensor:
        """Return raw float32 ``[sin | cos]`` features, shape (*sigma.shape, d_model).

        Frequencies are ``exp(-log(max_period) * i / half)`` for
        ``i in [0, half)`` with ``half = d_model // 2``. For odd ``d_model``
        a zero column is appended so the width always equals ``d_model``
        (previously it was ``d_model - 1`` and ``sigma_proj`` crashed).
        """
        half = self.d_model // 2
        freqs = torch.exp(
            -math.log(self.max_period) * torch.arange(half, device=sigma.device) / half
        )
        ang = sigma.float().unsqueeze(-1) * freqs
        emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
        if self.d_model % 2:
            emb = F.pad(emb, (0, 1))
        return emb

    def sinusoidal(self, sigma: torch.Tensor) -> torch.Tensor:
        """Embed per-token noise levels: (B, T) -> (B, T, d_model)."""
        emb = self._sinusoidal_features(sigma)
        # Match the projection's parameter dtype: the features are always
        # float32, which a half/bf16/float64 Linear rejects outside autocast.
        return self.sigma_proj(emb.to(self.sigma_proj[0].weight.dtype))

    def forward(self, sigma: torch.Tensor, cond: torch.Tensor,
                action: torch.Tensor | None = None) -> torch.Tensor:
        """Build the fused conditioning vector.

        Args:
            sigma: Per-token noise levels, shape (B, T).
            cond: Cohort-level condition ids, shape (B,); broadcast over T.
            action: Optional per-token action ids, shape (B, T). Ignored when
                ``use_action`` is False or ``action`` is None.

        Returns:
            Conditioning tensor of shape (B, T, d_model).
        """
        time_emb = self.sinusoidal(sigma)
        # (B,) -> (B, d_model) -> (B, 1, d_model) -> broadcast view (B, T, d_model)
        cond_emb = self.cond_emb(cond).unsqueeze(1).expand_as(time_emb)
        fused = time_emb + cond_emb
        if self.use_action and action is not None:
            fused = fused + self.action_emb(action)
        return self.fuse(fused)