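# Source path: sad/src/models/dit_components.py
# NOTE: the original first lines were hosting-page residue (uploader avatar
# caption, commit "Add files using upload-large-folder tool" 922bb4b,
# "Raw / History / Blame" links, size 13.3 kB), not part of the module.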
"""
dit_components.py
Self-contained DiT (Diffusion Transformer) building blocks.
Adapted from the MDLM / HDLM open-source codebase; kept here so that the
SAD project has zero dependency on any external local directory.
References:
- https://github.com/kuleshov-group/mdlm
- https://github.com/kuleshov-group/gidd
"""
import math
import typing
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
# Optional dependency: use the flash-attn package when it imports cleanly.
# `_has_flash_attn` is the module-level switch the rest of this file reads to
# choose between the flash-attn kernels and the pure-PyTorch code paths.
try:
    import flash_attn
    import flash_attn.layers.rotary
    _has_flash_attn = True
except ImportError:
    # flash-attn is not importable: enable PyTorch's own flash SDPA backend
    # instead and record that the flash-attn paths must not be taken.
    torch.backends.cuda.enable_flash_sdp(enabled=True)
    _has_flash_attn = False
# Pre-compile flex_attention - the real speed win for block-sparse masks.
# Compilation is lazy (happens on first call); re-used across all blocks and steps.
try:
    from torch.nn.attention.flex_attention import flex_attention as _flex_attention_raw
    _flex_attention_compiled = torch.compile(_flex_attention_raw, dynamic=False)
    _has_flex_attention = True
except ImportError:
    # This PyTorch build does not ship flex_attention: leave the compiled
    # handle unset (None) and expose the availability flag as False.
    _flex_attention_compiled = None
    _has_flex_attention = False
# JIT fusion flags (same as original)
# These run at import time and go through private `torch._C` hooks: the first
# two switch the JIT profiling mode / profiling executor off, the last two
# override the "can fuse" setting to True for CPU and GPU.
# NOTE(review): these look like process-wide settings, so importing this
# module presumably affects every other TorchScript user in the process.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
# ------------------------------------------------------------------------------
# Low-level helpers
# ------------------------------------------------------------------------------
def bias_dropout_add_scale(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
    training: bool,
) -> torch.Tensor:
    """Return ``residual + scale * dropout(x + bias)``.

    Args:
        x: Branch output to be gated.
        bias: Optional tensor added to ``x`` before dropout; skipped when None.
        scale: Multiplicative gate applied after dropout.
        residual: Optional skip tensor added last; skipped when None.
        prob: Dropout probability.
        training: Whether dropout is active.

    Returns:
        The gated (and, if ``residual`` is given, residual-added) tensor.
    """
    biased = x if bias is None else x + bias
    gated = scale * F.dropout(biased, p=prob, training=training)
    if residual is None:
        return gated
    return residual + gated
def bias_dropout_add_scale_fused_train(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
) -> torch.Tensor:
    """Training-mode wrapper: ``bias_dropout_add_scale`` with dropout enabled."""
    return bias_dropout_add_scale(
        x=x,
        bias=bias,
        scale=scale,
        residual=residual,
        prob=prob,
        training=True,
    )
def bias_dropout_add_scale_fused_inference(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
) -> torch.Tensor:
    """Inference-mode wrapper: ``bias_dropout_add_scale`` with dropout disabled."""
    return bias_dropout_add_scale(
        x=x,
        bias=bias,
        scale=scale,
        residual=residual,
        prob=prob,
        training=False,
    )
def modulate_fused(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply the affine modulation ``x * (1 + scale) + shift``.

    A zero ``scale`` and zero ``shift`` leave ``x`` unchanged.
    """
    gain = 1 + scale
    modulated = x * gain
    return modulated + shift
# ------------------------------------------------------------------------------
# Rotary position embedding
# ------------------------------------------------------------------------------
class Rotary(nn.Module):
    """Pre-computed rotary position-embedding tables for a packed QKV tensor.

    The cos/sin tables are built once, in the constructor, for positions
    ``0 .. max_seq_len - 1`` and registered as the buffers ``cos_cached`` and
    ``sin_cached`` with layout ``[1, max_seq_len, 3, 1, dim]``
    (batch, seq_len, qkv, head, dim).  The third ``qkv`` slot is overwritten
    with ``cos = 1`` / ``sin = 0`` so that the rotation formula
    ``t * cos + rotate_half(t) * sin`` leaves that slot unchanged.

    Args:
        dim: Size of the last axis of the tables.
        base: Base of the geometric frequency progression.
        max_seq_len: Number of positions covered by the cached tables.
    """

    def __init__(self, dim: int, base: int = 10_000, max_seq_len: int = 512):
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_seq_len = max_seq_len
        self._precompute()

    def _precompute(self):
        """Build the cos/sin tables and register them as module buffers."""
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        t = torch.arange(self.max_seq_len).type_as(inv_freq)
        # Outer product position x frequency -> [max_seq_len, len(inv_freq)].
        freqs = t[:, None] * inv_freq[None, :]
        # Duplicate the angles so the table spans both halves used by rotate_half.
        emb = torch.cat((freqs, freqs), dim=-1)
        # dims: batch, seq_len, qkv, head, dim
        cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
        sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
        # Identity rotation for the third (value) slot.
        cos_cached[:, :, 2, :, :].fill_(1.0)
        sin_cached[:, :, 2, :, :].fill_(0.0)
        self.register_buffer("cos_cached", cos_cached)
        self.register_buffer("sin_cached", sin_cached)

    def forward(self, x: torch.Tensor, seq_dim: int = 1, position_ids: typing.Optional[torch.Tensor] = None):
        """Return the ``(cos, sin)`` tables for the positions of ``x``.

        Args:
            x: Tensor whose size along ``seq_dim`` gives the sequence length.
                Only its shape is read, and only when ``position_ids`` is None.
            seq_dim: Axis of ``x`` holding the sequence length.
            position_ids: Optional 1-D tensor of integer positions; when given,
                the tables are gathered at exactly these positions.

        Returns:
            Tuple ``(cos, sin)``, each of shape ``[1, L, 3, 1, dim]`` where
            ``L`` is ``len(position_ids)`` or the sequence length of ``x``.

        Raises:
            ValueError: If the sequence length of ``x`` exceeds ``max_seq_len``.
                Slicing the cache would otherwise silently return a shorter
                table that only fails later with a shape-mismatch error.
        """
        if position_ids is not None:
            # position_ids: [seq_len] 1-D tensor of integer positions
            return self.cos_cached[:, position_ids], self.sin_cached[:, position_ids]

        seq_len = x.shape[seq_dim]
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the rotary cache size "
                f"max_seq_len={self.max_seq_len}; construct Rotary with a larger max_seq_len."
            )
        return self.cos_cached[:, :seq_len], self.sin_cached[:, :seq_len]
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last axis at its midpoint and return ``cat(-second, first)``.

    The split index is ``x.shape[-1] // 2``; for an odd-sized last axis the
    second part is one element longer than the first.
    """
    midpoint = x.shape[-1] // 2
    first_part = x[..., :midpoint]
    second_part = x[..., midpoint:]
    return torch.cat((-second_part, first_part), dim=-1)
def apply_rotary_pos_emb(qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate a packed QKV tensor with the given cos/sin tables.

    Args:
        qkv: Packed query/key/value tensor.
        cos: Cosine table, indexed here as ``[batch, seq, qkv, head, dim]``.
        sin: Sine table with the same layout as ``cos``.

    Returns:
        The rotated tensor.  With flash-attn available the work is delegated
        to ``flash_attn.layers.rotary.apply_rotary_emb_qkv_`` (in-place, judging
        by its trailing-underscore name); otherwise it is computed as
        ``qkv * cos + rotate_half(qkv) * sin``.
    """
    if not _has_flash_attn:
        # Pure-PyTorch fallback.
        rotated = rotate_half(qkv)
        return (qkv * cos) + (rotated * sin)

    # The flash-attn kernel is handed 2-D [seq, dim // 2] tables: drop the
    # batch/qkv/head axes and keep only the first half of the last axis.
    cos_half = cos[0, :, 0, 0, : cos.shape[-1] // 2]
    sin_half = sin[0, :, 0, 0, : sin.shape[-1] // 2]
    return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos_half, sin_half)
# ------------------------------------------------------------------------------
# Layers
# ------------------------------------------------------------------------------
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with a learnable scale and no bias.

    The input is cast to float32 before normalising, so the result is float32
    regardless of the input dtype.  The scale is broadcast as ``[1, 1, dim]``.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last axis (of size ``dim``) and apply the scale."""
        normalised = F.layer_norm(x.float(), [self.dim])
        gain = self.weight.view(1, 1, -1)
        return normalised * gain
class EmbeddingLayer(nn.Module):
    """Token embedding table (parameter, not nn.Embedding, for easy weight sharing).

    The table has shape ``[vocab_dim, dim]`` and is initialised with
    ``kaiming_uniform_(a=sqrt(5))``.
    """

    def __init__(self, dim: int, vocab_dim: int):
        super().__init__()
        table = torch.empty(vocab_dim, dim)
        self.embedding = nn.Parameter(table)
        nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Look up rows of the table by indexing it with ``x``."""
        rows = self.embedding[x]
        return rows
# ------------------------------------------------------------------------------
# Transformer block (bidirectional, with attention mask support)
# ------------------------------------------------------------------------------
class DDiTBlockWithMask(nn.Module):
    """
    DiT block with adaLN-Zero conditioning and optional attention mask.

    Three attention back-ends are supported, chosen per call in this order:
    compiled FlexAttention (when ``flex_block_mask`` is given), flash-attention
    (if installed and no ``attention_mask`` is given), and standard SDPA.

    Args:
        dim: Model width.
        n_heads: Number of attention heads the QKV projection is split into.
        cond_dim: Width of the conditioning vector fed to the adaLN projection.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        dropout: Dropout probability applied to both residual branches.
    """

    def __init__(self, dim: int, n_heads: int, cond_dim: int,
                 mlp_ratio: int = 4, dropout: float = 0.1):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.dropout = dropout

        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)

        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio * dim, dim, bias=True),
        )

        # adaLN-Zero modulation: zero-initialised, so every shift/scale/gate
        # starts at 0 and both gated branches initially contribute nothing.
        self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def _bias_dropout_scale_fn(self):
        """Pick the train or inference residual helper from ``self.training``."""
        return bias_dropout_add_scale_fused_train if self.training \
            else bias_dropout_add_scale_fused_inference

    @staticmethod
    def _additive_attention_bias(
        attention_mask: torch.Tensor,
        seq_len: int,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Turn ``attention_mask`` into a 4-D additive bias for SDPA.

        Accepted inputs:
          * floating-point mask: treated as a ready-made additive bias and
            passed through unchanged apart from shape padding.  A 2-D mask is
            read as ``[B, S]`` and becomes ``[B, 1, 1, S]``; other ranks below
            4 get leading singleton axes.
          * bool/integer mask (non-zero = may attend):
              - 2-D of shape ``(seq_len, seq_len)``: ``[S, S]`` sparse mask,
                broadcast over batch and heads as ``[1, 1, S, S]``;
              - any other 2-D shape: ``[B, S]`` key-padding mask, broadcast as
                ``[B, 1, 1, S]`` (same layout as the floating-point case);
              - other ranks below 4 get leading singleton axes.
            NOTE: a 2-D mask whose shape is exactly ``(S, S)`` is always read
            as ``[S, S]``, so a ``[B, S]`` padding mask with ``B == S`` is
            ambiguous - pass a 4-D mask in that case.

        Masked positions receive ``max(-1e9, finfo(dtype).min)``: ``-1e9`` for
        float32/bfloat16, and the most negative finite value for float16,
        where ``-1e9`` is not representable.

        Args:
            attention_mask: The mask as supplied by the caller.
            seq_len: Sequence length ``S`` of the current batch.
            dtype: Dtype of the query tensor; used for bool/integer masks only.

        Returns:
            A tensor with at least 4 dimensions suitable as ``attn_mask``.
        """
        if attention_mask.is_floating_point():
            bias = attention_mask
            if bias.dim() == 2:
                # [B, S] padding mask -> [B, 1, 1, S]
                bias = bias.unsqueeze(1).unsqueeze(1)
            while bias.dim() < 4:
                bias = bias.unsqueeze(0)
            return bias

        keep = attention_mask.bool()
        if keep.dim() == 2:
            if tuple(keep.shape) == (seq_len, seq_len):
                # [S, S] sparse mask (e.g. block-diff mask) -> [1, 1, S, S]
                keep = keep[None, None]
            else:
                # [B, S] key-padding mask -> [B, 1, 1, S]
                keep = keep[:, None, None, :]
        else:
            while keep.dim() < 4:
                keep = keep.unsqueeze(0)

        fill_value = max(-1e9, torch.finfo(dtype).min)
        bias = torch.zeros(keep.shape, dtype=dtype, device=keep.device)
        return bias.masked_fill_(~keep, fill_value)

    def forward(
        self,
        x: torch.Tensor,
        rotary_cos_sin: typing.Tuple[torch.Tensor, torch.Tensor],
        c: torch.Tensor,
        attention_mask: typing.Optional[torch.Tensor] = None,
        seqlens: typing.Optional[torch.Tensor] = None,
        flex_block_mask=None,
    ) -> torch.Tensor:
        """Run one adaLN-modulated attention + MLP block.

        Args:
            x: Hidden states ``[B, S, dim]``.
            rotary_cos_sin: ``(cos, sin)`` rotary tables; they are cast to the
                QKV dtype before being applied.
            c: Conditioning ``[B, cond_dim]``; projected to six ``[B, 1, dim]``
                shift/scale/gate tensors.
            attention_mask: Optional mask; see ``_additive_attention_bias`` for
                the accepted layouts.  Ignored when ``flex_block_mask`` is set.
            seqlens: Optional tensor used only on the flash-attn path, where
                its cumulative sum is passed as the cumulative sequence
                lengths.  (Expected layout is defined by the caller - TODO
                confirm.)
            flex_block_mask: Optional FlexAttention block mask; selects the
                compiled FlexAttention path.

        Returns:
            Updated hidden states with the same leading shape as ``x``.

        Raises:
            RuntimeError: If ``flex_block_mask`` is given but flex_attention
                could not be imported in this environment.
        """
        B, S = x.shape[:2]
        bds_fn = self._bias_dropout_scale_fn()

        (shift_msa, scale_msa, gate_msa,
         shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)

        # -- Attention --
        x_skip = x
        x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
        qkv = self.attn_qkv(x)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.n_heads)
        cos, sin = rotary_cos_sin
        qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))

        if flex_block_mask is not None:
            if _flex_attention_compiled is None:
                raise RuntimeError(
                    "flex_block_mask was provided but "
                    "torch.nn.attention.flex_attention is not available in this PyTorch build."
                )
            # Compiled FlexAttention path - skips masked kv-blocks.
            q = qkv[:, :, 0].transpose(1, 2)  # [B, H, S, d]
            k = qkv[:, :, 1].transpose(1, 2)
            v = qkv[:, :, 2].transpose(1, 2)
            x = _flex_attention_compiled(q, k, v, block_mask=flex_block_mask)
            x = rearrange(x, "b h s d -> b s (h d)", b=B)
        elif _has_flash_attn and attention_mask is None:
            # flash-attn varlen path: flatten batch and sequence into one axis.
            qkv = rearrange(qkv, "b s ... -> (b s) ...")
            cu = seqlens.cumsum(-1) if seqlens is not None else torch.arange(
                0, (B + 1) * S, step=S, dtype=torch.int32, device=qkv.device)
            x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
                qkv, cu, S, 0.0, causal=False)
            x = rearrange(x, "(b s) h d -> b s (h d)", b=B)
        else:
            # Standard SDPA path (also used whenever a mask is supplied).
            q = qkv[:, :, 0].transpose(1, 2)
            k = qkv[:, :, 1].transpose(1, 2)
            v = qkv[:, :, 2].transpose(1, 2)
            if attention_mask is not None:
                float_mask = self._additive_attention_bias(attention_mask, S, q.dtype)
                x = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
            else:
                x = F.scaled_dot_product_attention(q, k, v)
            x = rearrange(x, "b h s d -> b s (h d)", b=B)

        x = bds_fn(self.attn_out(x), None, gate_msa, x_skip, self.dropout)

        # -- MLP --
        x = bds_fn(
            self.mlp(modulate_fused(self.norm2(x), shift_mlp, scale_mlp)),
            None, gate_mlp, x, self.dropout,
        )
        return x
# ------------------------------------------------------------------------------
# Final layer
# ------------------------------------------------------------------------------
class DDitFinalLayer(nn.Module):
    """Output head: adaLN-modulated LayerNorm followed by a linear projection.

    Both the output projection and the adaLN projection are zero-initialised,
    so a freshly constructed layer returns all zeros.

    Args:
        hidden_size: Width of the incoming hidden states.
        out_channels: Width of the projected output.
        cond_dim: Width of the conditioning vector.
    """

    def __init__(self, hidden_size: int, out_channels: int, cond_dim: int):
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)
        # Zero every weight and bias of both projections.
        for projection in (self.linear, self.adaLN_modulation):
            projection.weight.data.zero_()
            projection.bias.data.zero_()

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Normalise ``x``, modulate it with shift/scale derived from ``c``, and project."""
        modulation = self.adaLN_modulation(c).unsqueeze(1)
        shift, scale = modulation.chunk(2, dim=2)
        normalised = self.norm_final(x)
        return self.linear(modulate_fused(normalised, shift, scale))