"""
dit_components.py

Self-contained DiT (Diffusion Transformer) building blocks.

Adapted from the MDLM / HDLM open-source codebase and kept here so that the
SAD project has no dependency on any external local directory.

References:
  - https://github.com/kuleshov-group/mdlm
  - https://github.com/kuleshov-group/gidd
"""

import math
import typing

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

try:
    import flash_attn
    import flash_attn.layers.rotary

    _has_flash_attn = True
except ImportError:
    torch.backends.cuda.enable_flash_sdp(enabled=True)
    _has_flash_attn = False

# Pre-compile flex_attention — the real speed win for block-sparse masks.
# torch.compile is lazy: compilation happens on the first call, and the
# compiled function is then re-used across all blocks and steps.
try:
    from torch.nn.attention.flex_attention import flex_attention as _flex_attention_raw

    _flex_attention_compiled = torch.compile(_flex_attention_raw, dynamic=False)
    _has_flex_attention = True
except ImportError:
    _flex_attention_compiled = None
    _has_flex_attention = False

# JIT fusion flags (same as the original MDLM code).
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

# Additive bias written into masked attention logits. It is clamped per dtype
# by _mask_fill_value so it never overflows float16 (max magnitude ~6.5e4).
_MASK_FILL = -1e9


# ──────────────────────────────────────────────────────────────────────────────
# Low-level helpers
# ──────────────────────────────────────────────────────────────────────────────

def bias_dropout_add_scale(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
    training: bool,
) -> torch.Tensor:
    """Return ``residual + scale * dropout(x + bias)``.

    ``bias`` is skipped when None. When ``residual`` is None, the scaled
    dropout output is returned without a residual add.
    """
    out = scale * F.dropout(x + bias if bias is not None else x, p=prob, training=training)
    if residual is not None:
        out = residual + out
    return out


def bias_dropout_add_scale_fused_train(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
) -> torch.Tensor:
    """:func:`bias_dropout_add_scale` with dropout active (training mode)."""
    return bias_dropout_add_scale(x, bias, scale, residual, prob, True)


def bias_dropout_add_scale_fused_inference(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
) -> torch.Tensor:
    """:func:`bias_dropout_add_scale` with dropout disabled (eval mode)."""
    return bias_dropout_add_scale(x, bias, scale, residual, prob, False)


def modulate_fused(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """adaLN modulation: ``x * (1 + scale) + shift``."""
    return x * (1 + scale) + shift


def _mask_fill_value(dtype: torch.dtype) -> float:
    """Return the masked-logit bias for ``dtype``.

    This is ``_MASK_FILL`` (-1e9) whenever the dtype can represent it, which
    covers float32, float64 and bfloat16, so results there are unchanged.
    For float16 it is clamped to the most negative finite value, because
    filling a half tensor with -1e9 raises an overflow error.
    """
    return max(_MASK_FILL, torch.finfo(dtype).min)


def _build_sdpa_mask(
    attention_mask: torch.Tensor,
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Convert a caller-supplied attention mask into an SDPA ``attn_mask``.

    Accepted inputs:
      * Floating-point tensor: treated as a pre-built additive bias and passed
        through. A 2-D ``[B, S]`` tensor becomes ``[B, 1, 1, S]``. Any other
        rank below 4 gets leading singleton dims prepended.
      * 2-D bool/int tensor of shape ``[S, S]``: structured mask, where
        True/nonzero means "may attend". It is broadcast over batch and heads
        as ``[1, 1, S, S]``. When B == S, this interpretation wins.
      * 2-D bool/int tensor of shape ``[B, S]``: padding mask, where
        True/nonzero means "real token". Padded keys are masked for every
        query, giving ``[B, 1, 1, S]``. Query rows are deliberately left
        unmasked so no row is ever fully masked, which could produce NaNs in
        low precision. Outputs at real positions are the same as a
        query-and-key mask would give.

    Masked entries receive ``_mask_fill_value(dtype)``.

    Args:
        attention_mask: The mask to convert.
        batch_size: B, the batch size of the attention inputs.
        seq_len: S, the sequence length of the attention inputs.
        dtype: Dtype of the query tensor; the built bias uses it.
        device: Device of the query tensor; bool/int masks are moved there.

    Raises:
        ValueError: For bool/int masks that are neither ``[S, S]`` nor ``[B, S]``.
    """
    if attention_mask.is_floating_point():
        float_mask = attention_mask
        if float_mask.dim() == 2:
            # [B, S] padding bias -> [B, 1, 1, S]
            float_mask = float_mask.unsqueeze(1).unsqueeze(1)
        while float_mask.dim() < 4:
            float_mask = float_mask.unsqueeze(0)
        return float_mask

    keep = attention_mask.to(device=device).bool()
    fill = _mask_fill_value(dtype)

    if keep.dim() == 2 and tuple(keep.shape) == (seq_len, seq_len):
        # [S, S] structured (e.g. block-diffusion) mask; shared by all batches/heads.
        bias = torch.zeros_like(keep, dtype=dtype).masked_fill(~keep, fill)
        return bias[None, None]  # [1, 1, S, S]

    if keep.dim() == 2 and tuple(keep.shape) == (batch_size, seq_len):
        # [B, S] key-padding mask -> additive bias [B, 1, 1, S].
        bias = torch.zeros((batch_size, 1, 1, seq_len), dtype=dtype, device=device)
        return bias.masked_fill(~keep[:, None, None, :], fill)

    raise ValueError(
        f"unsupported attention_mask shape {tuple(attention_mask.shape)} for "
        f"batch_size={batch_size}, seq_len={seq_len}; expected a float bias, "
        f"a [S, S] structured mask or a [B, S] padding mask"
    )


# ──────────────────────────────────────────────────────────────────────────────
# Rotary position embedding
# ──────────────────────────────────────────────────────────────────────────────

class Rotary(nn.Module):
    """Precomputed rotary position embedding (RoPE) tables.

    ``cos_cached`` / ``sin_cached`` have shape ``[1, max_seq_len, 3, 1, dim]``
    so they broadcast against packed qkv tensors laid out as
    ``[batch, seq_len, qkv, head, dim]``. The v slot (index 2 on the qkv axis)
    has cos=1 and sin=0, so the non-flash rotation leaves v unchanged.
    """

    def __init__(self, dim: int, base: int = 10_000, max_seq_len: int = 512):
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_seq_len = max_seq_len
        self._precompute()

    def _precompute(self):
        """Build and register the cos/sin caches for ``max_seq_len`` positions."""
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        t = torch.arange(self.max_seq_len).type_as(inv_freq)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # dims: batch, seq_len, qkv, head, dim
        cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
        sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
        # Identity rotation for the v slot.
        cos_cached[:, :, 2, :, :].fill_(1.0)
        sin_cached[:, :, 2, :, :].fill_(0.0)
        self.register_buffer("cos_cached", cos_cached)
        self.register_buffer("sin_cached", sin_cached)

    def forward(
        self,
        x: torch.Tensor,
        seq_dim: int = 1,
        position_ids: typing.Optional[torch.Tensor] = None,
    ):
        """Return the ``(cos, sin)`` tables for the positions of ``x``.

        Args:
            x: Only its size along ``seq_dim`` is read.
            seq_dim: Axis of ``x`` that holds the sequence.
            position_ids: Optional 1-D tensor of integer positions. When given,
                the tables are gathered at these positions and ``x`` is not
                used.

        Raises:
            ValueError: If the sequence is longer than the cached tables.
                Slicing would otherwise silently return a table that is too
                short.
        """
        if position_ids is not None:
            # position_ids: [seq_len] 1-D tensor of integer positions
            cos = self.cos_cached[:, position_ids]
            sin = self.sin_cached[:, position_ids]
            return cos, sin
        seq_len = x.shape[seq_dim]
        cached_len = self.cos_cached.shape[1]
        if seq_len > cached_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the rotary cache size "
                f"{cached_len}; construct Rotary with a larger max_seq_len"
            )
        return self.cos_cached[:, :seq_len], self.sin_cached[:, :seq_len]


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the halves ``(x1, x2)`` of the last axis to ``(-x2, x1)``."""
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary embeddings to a packed qkv tensor.

    With flash-attn installed, the ``[1, S, 3, 1, dim]`` tables are reduced to
    ``[S, dim/2]``. The call then goes to flash-attn's in-place
    ``apply_rotary_emb_qkv_``. Otherwise the rotation is computed
    out-of-place by broadcasting the tables against ``qkv``.
    """
    if _has_flash_attn:
        cos = cos[0, :, 0, 0, : cos.shape[-1] // 2]
        sin = sin[0, :, 0, 0, : sin.shape[-1] // 2]
        return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin)
    return (qkv * cos) + (rotate_half(qkv) * sin)


# ──────────────────────────────────────────────────────────────────────────────
# Layers
# ──────────────────────────────────────────────────────────────────────────────

class LayerNorm(nn.Module):
    """Bias-free LayerNorm with a learned scale, computed on ``x.float()``."""

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(review): weight is indexed as [None, None, :], which presumes a
        # 3-D [batch, seq, dim] input; a 2-D input gains a leading axis.
        return F.layer_norm(x.float(), [self.dim]) * self.weight[None, None, :]


class EmbeddingLayer(nn.Module):
    """Token embedding table (parameter, not nn.Embedding, for easy weight sharing)."""

    def __init__(self, dim: int, vocab_dim: int):
        super().__init__()
        self.embedding = nn.Parameter(torch.empty(vocab_dim, dim))
        nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Look up the rows of the table at the integer ids in ``x``."""
        return self.embedding[x]


# ──────────────────────────────────────────────────────────────────────────────
# Transformer block (bidirectional, with attention mask support)
# ──────────────────────────────────────────────────────────────────────────────

class DDiTBlockWithMask(nn.Module):
    """
    DiT block with adaLN-Zero conditioning and optional attention mask.

    Attention backend, in priority order:
      1. compiled FlexAttention when ``flex_block_mask`` is given;
      2. flash-attn varlen kernel when installed and no ``attention_mask``;
      3. PyTorch ``scaled_dot_product_attention`` (with or without a mask).
    """

    def __init__(self, dim: int, n_heads: int, cond_dim: int, mlp_ratio: int = 4, dropout: float = 0.1):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.dropout = dropout

        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)

        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio * dim, dim, bias=True),
        )

        # adaLN-Zero modulation: zero-initialised, so the gates start at 0.
        self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def _bias_dropout_scale_fn(self):
        """Pick the train/eval variant of bias_dropout_add_scale for this mode."""
        return bias_dropout_add_scale_fused_train if self.training \
            else bias_dropout_add_scale_fused_inference

    def forward(
        self,
        x: torch.Tensor,
        rotary_cos_sin: typing.Tuple[torch.Tensor, torch.Tensor],
        c: torch.Tensor,
        attention_mask: typing.Optional[torch.Tensor] = None,
        seqlens: typing.Optional[torch.Tensor] = None,
        flex_block_mask=None,
    ) -> torch.Tensor:
        """Run one adaLN-modulated attention + MLP block.

        Args:
            x: Hidden states; ``x.shape[:2]`` is read as ``(B, S)`` and the last
                axis is ``dim``.
            rotary_cos_sin: ``(cos, sin)`` tables, e.g. from :class:`Rotary`.
            c: Conditioning tensor fed to ``adaLN_modulation``. The ``[:, None]``
                indexing below implies a 2-D ``[B, cond_dim]`` layout.
            attention_mask: Optional mask; see :func:`_build_sdpa_mask` for the
                accepted layouts. Forces the SDPA backend when given (unless
                ``flex_block_mask`` is also given).
            seqlens: Optional per-sequence lengths for the flash-attn varlen
                path.
            flex_block_mask: Optional FlexAttention block mask; selects the
                compiled FlexAttention backend.

        Returns:
            Tensor with the same shape as ``x``.
        """
        bds_fn = self._bias_dropout_scale_fn()

        (shift_msa, scale_msa, gate_msa,
         shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)

        # ── Attention ──
        x_skip = x
        x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
        qkv = self.attn_qkv(x)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.n_heads)
        cos, sin = rotary_cos_sin
        qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
        x = self._attention(qkv, attention_mask, seqlens, flex_block_mask)
        x = bds_fn(self.attn_out(x), None, gate_msa, x_skip, self.dropout)

        # ── MLP ──
        x = bds_fn(
            self.mlp(modulate_fused(self.norm2(x), shift_mlp, scale_mlp)),
            None,
            gate_mlp,
            x,
            self.dropout,
        )
        return x

    def _attention(
        self,
        qkv: torch.Tensor,
        attention_mask: typing.Optional[torch.Tensor],
        seqlens: typing.Optional[torch.Tensor],
        flex_block_mask,
    ) -> torch.Tensor:
        """Multi-head attention over packed ``qkv`` ``[B, S, 3, H, d]`` -> ``[B, S, H*d]``.

        Raises:
            RuntimeError: If ``flex_block_mask`` is given but FlexAttention
                could not be imported.
        """
        B, S = qkv.shape[:2]

        if flex_block_mask is not None:
            if not _has_flex_attention:
                raise RuntimeError(
                    "flex_block_mask was given but torch.nn.attention.flex_attention "
                    "is not available in this PyTorch build"
                )
            # Compiled FlexAttention path — skips masked kv-blocks.
            q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))  # each [B, H, S, d]
            out = _flex_attention_compiled(q, k, v, block_mask=flex_block_mask)
            return rearrange(out, "b h s d -> b s (h d)", b=B)

        if _has_flash_attn and attention_mask is None:
            packed = rearrange(qkv, "b s ... -> (b s) ...")
            if seqlens is not None:
                # NOTE(review): cumsum neither prepends the leading 0 nor casts
                # to int32, as flash-attn's cu_seqlens usually requires —
                # confirm what callers pass here.
                cu = seqlens.cumsum(-1)
            else:
                # Equal-length sequences: boundaries at 0, S, 2S, ..., B*S.
                cu = torch.arange(0, (B + 1) * S, step=S, dtype=torch.int32, device=qkv.device)
            out = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
                packed, cu, S, 0.0, causal=False)
            return rearrange(out, "(b s) h d -> b s (h d)", b=B)

        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))  # each [B, H, S, d]
        if attention_mask is None:
            out = F.scaled_dot_product_attention(q, k, v)
        else:
            bias = _build_sdpa_mask(attention_mask, B, S, q.dtype, q.device)
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
        return rearrange(out, "b h s d -> b s (h d)", b=B)


# ──────────────────────────────────────────────────────────────────────────────
# Final layer
# ──────────────────────────────────────────────────────────────────────────────

class DDitFinalLayer(nn.Module):
    """adaLN-modulated LayerNorm followed by a zero-initialised output projection."""

    def __init__(self, hidden_size: int, out_channels: int, cond_dim: int):
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        self.linear.weight.data.zero_()
        self.linear.bias.data.zero_()

        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Project modulated hidden states ``x`` to ``out_channels`` using conditioning ``c``."""
        shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
        x = modulate_fused(self.norm_final(x), shift, scale)
        return self.linear(x)