"""AdaLN-Zero conditioning module (DiT-style, Peebles 2023).

Used in: DiT (ICCV 2023), Stable Diffusion 3 (Esser 2024), Sora, Lumina-Next,
PixArt-Sigma. Standard for diffusion conditioning in 2025-2026.

Why for Diffusion Forcing on EHR:
- Per-token sigma + global cond/action → per-token (scale, shift, gate)
- Gates init to zero → block starts as identity → no catastrophic init
- Much better CFG (dropped condition path goes through zero gates,
  not corrupting residual stream)
- DFoT (Diffusion Forcing Transformer 2, ICLR 2026) confirms +3-8% win

We fuse THREE conditioning signals:
- sigma  (B, T) per-token noise level       → time_emb   (B, T, D)
- cond   (B,)   cohort-level treatment id   → cond_emb   (B, D) → broadcast
- action (B, T) per-token latent action id  → action_emb (B, T, D)

Combined into c_t (B, T, D) → AdaLNZeroModulator (conditioning MLP) → 6
modulation tensors per block. Each block uses them as:
    h = x + gate_msa * Attn(Norm(x) * (1 + scale_msa) + shift_msa)
    h = h + gate_mlp * MLP(Norm(h) * (1 + scale_mlp) + shift_mlp)
"""
from __future__ import annotations
import torch
import torch.nn as nn
class AdaLNZeroModulator(nn.Module):
    """Project a fused conditioning vector into six AdaLN-Zero modulation tensors.

    Input:
        c: fused conditioning vector of shape (B, T, d_model).

    Output:
        A tuple of 6 tensors, each (B, T, d_model), in the order
        (scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp).
    """

    def __init__(self, d_model: int):
        super().__init__()
        projection = nn.Linear(d_model, 6 * d_model, bias=True)
        # AdaLN-Zero trick: every output (scale, shift AND gate) starts at
        # exactly zero, so gate == 0 and the consuming block is the identity.
        nn.init.zeros_(projection.weight)
        nn.init.zeros_(projection.bias)
        self.modulator = nn.Sequential(nn.SiLU(), projection)

    def forward(self, c: torch.Tensor) -> tuple[torch.Tensor, ...]:
        params = self.modulator(c)  # (B, T, 6 * d_model)
        return torch.chunk(params, 6, dim=-1)
class AdaLNZeroBlock(nn.Module):
"""Transformer block with AdaLN-Zero modulation.
Drop-in replacement for the standard pre-norm block. Reads
pre-computed modulation tensors and applies them around Attn + MLP.
"""
def __init__(self, d_model: int, n_heads: int, ffn: int, dropout: float,
rope=None, kg_xattn=None):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.rope = rope
self.kg_xattn = kg_xattn
self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
self.proj = nn.Linear(d_model, d_model, bias=False)
self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
self.mlp = nn.Sequential(
nn.Linear(d_model, ffn, bias=False),
nn.GELU(),
nn.Linear(ffn, d_model, bias=False),
)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor,
scale_msa, shift_msa, gate_msa,
scale_mlp, shift_mlp, gate_mlp,
kg_ctx: torch.Tensor | None = None) -> torch.Tensor:
import torch.nn.functional as F
B, T, D = x.shape
# MSA branch
h = self.norm1(x) * (1 + scale_msa) + shift_msa
qkv = self.qkv(h).reshape(B, T, 3, self.n_heads, self.head_dim)
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
if self.rope is not None:
q, k = self.rope(q, k, T)
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask=(~attn_mask).float().masked_fill(attn_mask, float("-inf"))[None, None],
dropout_p=self.dropout.p if self.training else 0.0,
)
out = out.transpose(1, 2).reshape(B, T, D)
x = x + gate_msa * self.dropout(self.proj(out))
# KG cross-attention (between MSA and MLP)
if self.kg_xattn is not None and kg_ctx is not None:
x = self.kg_xattn(x, kg_ctx)
# MLP branch
h = self.norm2(x) * (1 + scale_mlp) + shift_mlp
x = x + gate_mlp * self.dropout(self.mlp(h))
return x
class FusedConditioner(nn.Module):
"""Fuse (sigma, cond, action) into one per-token conditioning vector.
Output (B, T, d_model) consumed by AdaLNZeroModulator per layer.
"""
def __init__(self, d_model: int, n_conditions: int, n_actions: int,
use_action: bool = True):
super().__init__()
self.d_model = d_model
self.use_action = use_action
# Sigma β sinusoidal embedding
self.sigma_proj = nn.Sequential(
nn.Linear(d_model, d_model), nn.SiLU(), nn.Linear(d_model, d_model),
)
self.cond_emb = nn.Embedding(n_conditions, d_model)
if use_action:
self.action_emb = nn.Embedding(n_actions + 1, d_model)
self.fuse = nn.Sequential(
nn.SiLU(),
nn.Linear(d_model, d_model),
)
def sinusoidal(self, sigma: torch.Tensor) -> torch.Tensor:
import math
half = self.d_model // 2
freqs = torch.exp(
-math.log(10000.0) * torch.arange(half, device=sigma.device) / half
)
ang = sigma.float().unsqueeze(-1) * freqs
emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
return self.sigma_proj(emb)
def forward(self, sigma: torch.Tensor, cond: torch.Tensor,
action: torch.Tensor | None = None) -> torch.Tensor:
# sigma (B, T) β time_emb (B, T, D)
time_emb = self.sinusoidal(sigma)
# cond (B,) β (B, D) β broadcast to (B, T, D)
cond_emb = self.cond_emb(cond).unsqueeze(1).expand_as(time_emb)
fused = time_emb + cond_emb
if self.use_action and action is not None:
fused = fused + self.action_emb(action)
return self.fuse(fused)
|