gemeo-sus / src /adaln_zero.py
timmers's picture
GEMEO/SUS v6 recurrence-aware (RAVEN) — new-onset Top-1 60.1% vs baseline 38.2%, defeats autocorrelation trap. GEMEO Arch v2.0 Principle 7 proven.
908ea05 verified
"""AdaLN-Zero conditioning module (DiT-style, Peebles 2023).
Used in: DiT (ICCV 2023), Stable Diffusion 3 (Esser 2024), Sora, Lumina-Next,
PixArt-Sigma. Standard for diffusion conditioning in 2025-2026.
Why for Diffusion Forcing on EHR:
- Per-token sigma + global cond/action → per-token (scale, shift, gate)
- Gates init to zero ⇒ block starts as identity ⇒ no catastrophic init
- Much better CFG (dropped condition path goes through zero gates,
not corrupting residual stream)
- DFoT (Diffusion Forcing Transformer 2, ICLR 2026) confirms +3-8% win
We fuse THREE conditioning signals:
- sigma (B, T) per-token noise level → time_emb (B, T, D)
- cond (B,) cohort-level treatment id → cond_emb (B, D) → broadcast
- action (B, T) per-token latent action id → action_emb (B, T, D)
Combined into c_t (B, T, D) → AdaLNZeroModulator (conditioning MLP) → 6 modulation tensors
per block. Each block uses them as:
h = x + gate_msa * Attn((1 + scale_msa) * Norm(x) + shift_msa)
h = h + gate_mlp * MLP((1 + scale_mlp) * Norm(h) + shift_mlp)
"""
from __future__ import annotations
import torch
import torch.nn as nn
class AdaLNZeroModulator(nn.Module):
    """Map a fused conditioning vector to the six AdaLN-Zero modulation tensors.

    The module is ``SiLU -> Linear(d_model, 6 * d_model)``. The linear layer's
    weight and bias are both zero-initialised, so every returned tensor is
    exactly zero right after construction.

    Args:
        d_model: Channel width of the conditioning vector and of each of the
            six returned tensors.
    """

    def __init__(self, d_model: int):
        super().__init__()
        out_proj = nn.Linear(d_model, 6 * d_model, bias=True)
        # AdaLN-Zero: the whole projection (not only the gate rows) starts at
        # zero, which makes the gate chunks zero and the modulated residual
        # branches a no-op at initialisation.
        for param in (out_proj.weight, out_proj.bias):
            nn.init.zeros_(param)
        self.modulator = nn.Sequential(nn.SiLU(), out_proj)

    def forward(self, c: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Return ``(scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp)``.

        Args:
            c: Conditioning tensor whose last dimension is ``d_model``
                (documented by the caller as ``(B, T, d_model)``).

        Returns:
            Six equally sized slices of the projection along the last
            dimension, each with last dimension ``d_model``.
        """
        modulation = self.modulator(c)  # (..., 6 * d_model)
        return torch.chunk(modulation, 6, dim=-1)
class AdaLNZeroBlock(nn.Module):
    """Transformer block with AdaLN-Zero modulation.

    Drop-in replacement for the standard pre-norm block. Reads pre-computed
    modulation tensors and applies them around the attention and MLP branches:

        x = x + gate_msa * Attn(Norm(x) * (1 + scale_msa) + shift_msa)
        x = x + gate_mlp * MLP(Norm(x) * (1 + scale_mlp) + shift_mlp)

    Both LayerNorms are created without learnable affine parameters; scale and
    shift come exclusively from the modulation tensors.

    Args:
        d_model: Channel width of the residual stream.
        n_heads: Number of attention heads; must divide ``d_model``.
        ffn: Hidden width of the two-layer GELU MLP.
        dropout: Dropout probability, used both inside attention (training
            only) and on each branch output.
        rope: Optional callable invoked as ``rope(q, k, T)`` that returns the
            transformed ``(q, k)``.
        kg_xattn: Optional callable invoked as ``kg_xattn(x, kg_ctx)`` between
            the attention and MLP branches.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model: int, n_heads: int, ffn: int, dropout: float,
                 rope=None, kg_xattn=None):
        super().__init__()
        # Fail at construction rather than with an opaque reshape error in
        # forward(): the qkv projection is split into n_heads equal slices.
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.rope = rope
        self.kg_xattn = kg_xattn
        self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ffn, bias=False),
            nn.GELU(),
            nn.Linear(ffn, d_model, bias=False),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor,
                scale_msa: torch.Tensor, shift_msa: torch.Tensor,
                gate_msa: torch.Tensor,
                scale_mlp: torch.Tensor, shift_mlp: torch.Tensor,
                gate_mlp: torch.Tensor,
                kg_ctx: torch.Tensor | None = None) -> torch.Tensor:
        """Apply modulated self-attention, optional KG cross-attention, and MLP.

        Args:
            x: Input of shape ``(B, T, D)``.
            attn_mask: Boolean tensor in which ``True`` marks positions that
                must NOT be attended to. Two leading singleton dimensions are
                prepended before it reaches the attention kernel
                (NOTE(review): so a ``(T, T)`` mask is presumably expected;
                confirm with the caller).
            scale_msa, shift_msa, gate_msa: Modulation for the attention
                branch; must broadcast against ``(B, T, D)``.
            scale_mlp, shift_mlp, gate_mlp: Modulation for the MLP branch;
                must broadcast against ``(B, T, D)``.
            kg_ctx: Optional context forwarded to ``kg_xattn``; the
                cross-attention step runs only when both are not ``None``.

        Returns:
            Tensor of shape ``(B, T, D)``.
        """
        import torch.nn.functional as F

        B, T, D = x.shape

        # MSA branch
        h = self.norm1(x) * (1 + scale_msa) + shift_msa
        qkv = self.qkv(h).reshape(B, T, 3, self.n_heads, self.head_dim)
        # (3, B, n_heads, T, head_dim) -> three (B, n_heads, T, head_dim) tensors
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        if self.rope is not None:
            q, k = self.rope(q, k, T)
        # SDPA's boolean convention is the inverse of ours (True = may attend).
        # Passing the boolean keep-mask avoids a float32 additive bias, which
        # previously also carried a spurious +1.0 on every allowed logit.
        keep_mask = ~attn_mask
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=keep_mask[None, None],
            dropout_p=self.dropout.p if self.training else 0.0,
        )
        out = out.transpose(1, 2).reshape(B, T, D)
        x = x + gate_msa * self.dropout(self.proj(out))

        # KG cross-attention (between MSA and MLP). Its output replaces x
        # directly and is not gated here.
        # NOTE(review): any residual connection must live inside kg_xattn; confirm.
        if self.kg_xattn is not None and kg_ctx is not None:
            x = self.kg_xattn(x, kg_ctx)

        # MLP branch
        h = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        x = x + gate_mlp * self.dropout(self.mlp(h))
        return x
class FusedConditioner(nn.Module):
    """Fuse (sigma, cond, action) into one per-token conditioning vector.

    Each signal is embedded to ``d_model`` channels, the embeddings are
    summed, and the sum goes through ``SiLU -> Linear``. Output
    (B, T, d_model) consumed by AdaLNZeroModulator per layer.

    Args:
        d_model: Width of every embedding and of the fused output. Odd values
            are supported (the sinusoidal features are zero-padded by one
            channel).
        n_conditions: Number of rows in the condition embedding table.
        n_actions: The action embedding table holds ``n_actions + 1`` rows.
            NOTE(review): the extra row is presumably a null / "no action"
            id; confirm with the caller.
        use_action: When False, no action table is created and any ``action``
            passed to ``forward`` is ignored.
    """

    def __init__(self, d_model: int, n_conditions: int, n_actions: int,
                 use_action: bool = True):
        super().__init__()
        self.d_model = d_model
        self.use_action = use_action
        # Sigma → sinusoidal features → 2-layer MLP
        self.sigma_proj = nn.Sequential(
            nn.Linear(d_model, d_model), nn.SiLU(), nn.Linear(d_model, d_model),
        )
        self.cond_emb = nn.Embedding(n_conditions, d_model)
        if use_action:
            self.action_emb = nn.Embedding(n_actions + 1, d_model)
        self.fuse = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

    def sinusoidal(self, sigma: torch.Tensor) -> torch.Tensor:
        """Embed noise levels with sin/cos features, then apply ``sigma_proj``.

        Args:
            sigma: Noise levels of any shape ``(...)``; cast to float32.

        Returns:
            Tensor of shape ``(..., d_model)``.
        """
        import math

        half = self.d_model // 2
        # Geometric frequency ladder from 1 down towards 1/10000.
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, device=sigma.device) / half
        )
        ang = sigma.float().unsqueeze(-1) * freqs
        emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
        if self.d_model % 2:
            # sin/cos yields 2 * half == d_model - 1 channels for odd d_model;
            # append one zero channel so the width matches sigma_proj's input.
            pad = emb.new_zeros((*emb.shape[:-1], 1))
            emb = torch.cat([emb, pad], dim=-1)
        return self.sigma_proj(emb)

    def forward(self, sigma: torch.Tensor, cond: torch.Tensor,
                action: torch.Tensor | None = None) -> torch.Tensor:
        """Build the fused conditioning tensor.

        Args:
            sigma: Per-token noise levels, ``(B, T)``.
            cond: Integer condition ids, ``(B,)``, indexing ``cond_emb``.
            action: Optional integer action ids, ``(B, T)``, indexing
                ``action_emb``. Used only when the module was built with
                ``use_action=True`` and this argument is not ``None``.

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        # sigma (B, T) → time_emb (B, T, D)
        time_emb = self.sinusoidal(sigma)
        # cond (B,) → (B, D) → broadcast to (B, T, D)
        cond_emb = self.cond_emb(cond).unsqueeze(1).expand_as(time_emb)
        fused = time_emb + cond_emb
        if self.use_action and action is not None:
            fused = fused + self.action_emb(action)
        return self.fuse(fused)