"""AdaLN-Zero conditioning module (DiT-style; Peebles & Xie, ICCV 2023).

AdaLN-Zero is the conditioning scheme introduced by DiT. Earlier notes list it
as also used in Stable Diffusion 3 (Esser 2024), Sora, Lumina-Next and
PixArt-Sigma, and as the standard for diffusion conditioning in 2025-2026.

Why it suits Diffusion Forcing on EHR:
  - Per-token sigma + global cond/action -> per-token (scale, shift, gate).
  - The modulation projection is zero-initialised, so every gate starts at 0
    and each gated residual branch starts as the identity -> no catastrophic
    init.
  - Classifier-free guidance: a dropped condition only reaches the residual
    stream through the learned (initially zero) gates instead of being added
    to it directly.
  - DFoT (Diffusion Forcing Transformer; cited in earlier notes as
    ICLR 2026) reportedly confirms a +3-8% win. This figure is carried over
    and has not been verified here.

Three conditioning signals are fused:
  - sigma  (B, T) per-token noise level      -> time_emb   (B, T, D)
  - cond   (B,)   cohort-level treatment id  -> cond_emb   (B, D), broadcast over T
  - action (B, T) per-token latent action id -> action_emb (B, T, D)

Their sum is mapped to c_t (B, T, D). A per-block AdaLNZeroModulator turns
c_t into 6 modulation tensors, which each block uses as:

    h = x + gate_msa * Attn(Norm(x) * (1 + scale_msa) + shift_msa)
    h = h + gate_mlp * MLP(Norm(h) * (1 + scale_mlp) + shift_mlp)

The "1 +" makes a zero scale an identity scaling, so zero-initialised
modulation leaves the normalised activations untouched.
"""

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaLNZeroModulator(nn.Module):
    """Produce the per-token (scale, shift, gate) tensors for one AdaLN-Zero block.

    Input:  fused conditioning vector ``c`` of shape (B, T, d_model).
    Output: 6 tensors, each of shape (B, T, d_model), in the order
        (scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp).
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.modulator = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, 6 * d_model, bias=True),
        )
        # AdaLN-Zero trick: zero-init the WHOLE projection, not just the gate
        # rows. At init all six outputs are exactly 0:
        #   - gate == 0 makes every gated branch a no-op;
        #   - scale/shift == 0 leaves the normalised input unchanged.
        # The weight still receives gradients (upstream grad x SiLU(c)), so the
        # layer trains away from zero.
        nn.init.zeros_(self.modulator[-1].weight)
        nn.init.zeros_(self.modulator[-1].bias)

    def forward(self, c: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split the projected conditioning into six equal chunks on the last dim."""
        out = self.modulator(c)  # (..., 6 * d_model)
        return out.chunk(6, dim=-1)


class AdaLNZeroBlock(nn.Module):
    """Transformer block with AdaLN-Zero modulation.

    Drop-in replacement for the standard pre-norm block. The block does not
    own a modulator. It reads pre-computed modulation tensors (e.g. from
    ``AdaLNZeroModulator``) and applies them around attention and the MLP.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        ffn: hidden width of the MLP.
        dropout: dropout probability. It is used for the attention weights
            (training only) and on both residual branches.
        rope: optional rotary-embedding callable. It is invoked as
            ``rope(q, k, T) -> (q, k)`` on (B, n_heads, T, head_dim) tensors.
        kg_xattn: optional cross-attention module, invoked as
            ``kg_xattn(x, kg_ctx) -> x`` between the MSA and MLP branches.
            NOTE: its output is NOT gated. The block is therefore an identity
            at init only when this module is absent (or is itself an identity
            at init).

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model: int, n_heads: int, ffn: int, dropout: float,
                 rope=None, kg_xattn=None):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.rope = rope
        self.kg_xattn = kg_xattn

        # No elementwise affine: the per-token scale/shift come from AdaLN.
        self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ffn, bias=False),
            nn.GELU(),
            nn.Linear(ffn, d_model, bias=False),
        )
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _sdpa_mask(attn_mask: torch.Tensor | None) -> torch.Tensor | None:
        """Convert a "True = blocked" mask into SDPA's "True = attend" boolean form.

        Returns a boolean mask in the form SDPA accepts natively. Unlike a float
        mask, which SDPA documents as needing the query's dtype, it does not
        depend on the query's dtype.
        Accepts (T, T), shared across the batch, or (B, T, T), per sample.
        Both are reshaped to broadcast over heads. ``None`` means full attention.
        """
        if attn_mask is None:
            return None
        allowed = ~attn_mask.bool()
        if allowed.dim() == 2:  # (T, T) -> (1, 1, T, T)
            return allowed[None, None]
        if allowed.dim() == 3:  # (B, T, T) -> (B, 1, T, T)
            return allowed[:, None]
        return allowed  # assume already broadcastable to (B, H, T, T)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None,
                scale_msa, shift_msa, gate_msa,
                scale_mlp, shift_mlp, gate_mlp,
                kg_ctx: torch.Tensor | None = None) -> torch.Tensor:
        """Apply the modulated self-attention and MLP branches.

        Args:
            x: input activations of shape (B, T, d_model).
            attn_mask: boolean mask where ``True`` marks a BLOCKED query/key
                pair. Its shape is (T, T) or (B, T, T); ``None`` means full
                attention.
            scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp:
                modulation tensors broadcastable to ``x``, typically
                (B, T, d_model) from ``AdaLNZeroModulator``.
            kg_ctx: context for ``kg_xattn``. The cross-attention step is
                skipped when either ``kg_ctx`` or ``kg_xattn`` is ``None``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, T, D = x.shape

        # ---- MSA branch ----------------------------------------------------
        h = self.norm1(x) * (1 + scale_msa) + shift_msa
        qkv = self.qkv(h).reshape(B, T, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, H, T, head_dim)
        if self.rope is not None:
            q, k = self.rope(q, k, T)
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=self._sdpa_mask(attn_mask),
            dropout_p=self.dropout.p if self.training else 0.0,
        )
        out = out.transpose(1, 2).reshape(B, T, D)
        x = x + gate_msa * self.dropout(self.proj(out))

        # ---- KG cross-attention (between MSA and MLP; ungated) --------------
        if self.kg_xattn is not None and kg_ctx is not None:
            x = self.kg_xattn(x, kg_ctx)

        # ---- MLP branch ----------------------------------------------------
        h = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        x = x + gate_mlp * self.dropout(self.mlp(h))
        return x


class FusedConditioner(nn.Module):
    """Fuse (sigma, cond, action) into one per-token conditioning vector.

    The output, of shape (B, T, d_model), is consumed by an
    ``AdaLNZeroModulator`` in each layer.

    Args:
        d_model: embedding width.
        n_conditions: number of cohort-level condition ids for ``cond_emb``.
        n_actions: number of action ids. The action table has
            ``n_actions + 1`` rows. NOTE(review): the extra id is presumably a
            null/dropped action for CFG; confirm against the data pipeline.
        use_action: whether to build and use the action embedding.
    """

    def __init__(self, d_model: int, n_conditions: int, n_actions: int,
                 use_action: bool = True):
        super().__init__()
        self.d_model = d_model
        self.use_action = use_action

        # MLP applied on top of the sinusoidal sigma features (see sinusoidal()).
        self.sigma_proj = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        self.cond_emb = nn.Embedding(n_conditions, d_model)
        if use_action:
            self.action_emb = nn.Embedding(n_actions + 1, d_model)
        self.fuse = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

    def sinusoidal(self, sigma: torch.Tensor) -> torch.Tensor:
        """Embed ``sigma`` with sin/cos features, then project with ``sigma_proj``.

        The features use frequencies ``10000 ** (-i / half)`` for
        ``i in [0, half)``, where ``half = d_model // 2``. They are computed in
        float32 and cast to the projection's parameter dtype. For odd
        ``d_model`` a zero column is appended so the width matches
        ``sigma_proj``.

        NOTE(review): DiT feeds integer timesteps (up to ~1000) into this
        frequency range. If sigma lies in [0, 1], most features are nearly
        constant; confirm the expected sigma scale.
        """
        half = self.d_model // 2
        freqs = torch.exp(
            -math.log(10000.0)
            * torch.arange(half, device=sigma.device, dtype=torch.float32)
            / half
        )
        ang = sigma.float().unsqueeze(-1) * freqs
        emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
        if self.d_model % 2:
            # Pad the last dim by one column so emb has exactly d_model features.
            emb = F.pad(emb, (0, 1))
        # Match the projection's dtype (e.g. after .half()/.double()).
        emb = emb.to(self.sigma_proj[0].weight.dtype)
        return self.sigma_proj(emb)

    def forward(self, sigma: torch.Tensor, cond: torch.Tensor,
                action: torch.Tensor | None = None) -> torch.Tensor:
        """Return the fused conditioning vector c_t.

        Args:
            sigma: per-token noise level, shape (B, T).
            cond: cohort-level condition id per sample, shape (B,).
            action: optional per-token action id, shape (B, T). It is ignored
                when ``use_action`` is False.

        Returns:
            Tensor of shape (B, T, d_model).
        """
        time_emb = self.sinusoidal(sigma)  # (B, T, D)
        # (B,) -> (B, 1, D) -> broadcast over the T tokens.
        cond_emb = self.cond_emb(cond).unsqueeze(1).expand_as(time_emb)
        fused = time_emb + cond_emb
        if self.use_action and action is not None:
            fused = fused + self.action_emb(action)
        return self.fuse(fused)