| """ |
| dit_components.py |
| |
| Self-contained DiT (Diffusion Transformer) building blocks. |
| Adapted from the MDLM / HDLM open-source codebase; kept here so that the |
| SAD project has zero dependency on any external local directory. |
| |
| References: |
| - https://github.com/kuleshov-group/mdlm |
| - https://github.com/kuleshov-group/gidd |
| """ |
|
|
| import math |
| import typing |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
|
|
# flash-attn is an optional dependency.  Record whether it could be imported
# so the rest of the module can choose between its kernels and PyTorch's own.
try:
    import flash_attn
    import flash_attn.layers.rotary
    _has_flash_attn = True
except ImportError:
    # Package not installed: make sure PyTorch's built-in flash
    # scaled-dot-product-attention backend is enabled instead.
    torch.backends.cuda.enable_flash_sdp(enabled=True)
    _has_flash_attn = False
|
|
| |
| |
# FlexAttention is optional: it only exists in recent PyTorch releases and the
# compiled wrapper is built once here, at import time.
try:
    from torch.nn.attention.flex_attention import flex_attention as _flex_attention_raw
    _flex_attention_compiled = torch.compile(_flex_attention_raw, dynamic=False)
    _has_flex_attention = True
except (ImportError, RuntimeError):
    # ImportError:  this PyTorch build ships no flex_attention module.
    # RuntimeError: torch.compile refuses to run on this platform / Python
    #               version.  Either way the module must stay importable;
    #               callers simply cannot use the flex-attention path.
    _flex_attention_compiled = None
    _has_flex_attention = False
|
|
| |
# Configure the TorchScript JIT through PyTorch's private ``torch._C`` hooks,
# applied once when this module is imported: switch off the profiling mode and
# the profiling executor, and allow operator fusion on both CPU and GPU.
# NOTE(review): presumably kept for TorchScript-fused helper functions —
# confirm these private switches are still required.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
|
|
|
|
| |
| |
| |
|
|
def bias_dropout_add_scale(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
    training: bool,
) -> torch.Tensor:
    """Compute ``residual + scale * dropout(x + bias)``.

    Args:
        x: Input tensor.
        bias: Optional bias added to ``x`` before dropout; skipped when None.
        scale: Multiplier applied to the dropout output.
        residual: Optional tensor added to the scaled result; skipped when None.
        prob: Dropout probability.
        training: Whether dropout is active.

    Returns:
        The scaled (and, if given, residual-added) tensor.
    """
    pre_dropout = x if bias is None else x + bias
    dropped = F.dropout(pre_dropout, p=prob, training=training)
    scaled = scale * dropped
    if residual is None:
        return scaled
    return residual + scaled
|
|
|
|
def bias_dropout_add_scale_fused_train(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
) -> torch.Tensor:
    """Training-mode variant of ``bias_dropout_add_scale`` (dropout enabled)."""
    return bias_dropout_add_scale(
        x=x,
        bias=bias,
        scale=scale,
        residual=residual,
        prob=prob,
        training=True,
    )
|
|
|
|
def bias_dropout_add_scale_fused_inference(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
) -> torch.Tensor:
    """Inference-mode variant of ``bias_dropout_add_scale`` (dropout disabled)."""
    return bias_dropout_add_scale(
        x=x,
        bias=bias,
        scale=scale,
        residual=residual,
        prob=prob,
        training=False,
    )
|
|
|
|
def modulate_fused(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply the affine modulation ``x * (1 + scale) + shift``."""
    gain = 1 + scale
    modulated = x * gain
    return modulated + shift
|
|
|
|
| |
| |
| |
|
|
class Rotary(nn.Module):
    """Precomputed rotary position-embedding (RoPE) tables for packed QKV.

    The cos/sin tables are built once for ``max_seq_len`` positions and stored
    as buffers of shape ``(1, max_seq_len, 3, 1, dim)``.  The size-3 axis has
    one slot each for q, k and v; the third slot is filled with ``cos = 1`` and
    ``sin = 0`` so that applying the rotation leaves that slot unchanged.

    Args:
        dim: Size of the feature axis covered by the tables.
        base: Base of the geometric frequency progression.
        max_seq_len: Number of positions to precompute.
    """

    def __init__(self, dim: int, base: int = 10_000, max_seq_len: int = 512):
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_seq_len = max_seq_len
        self._precompute()

    def _precompute(self):
        """Build the cos/sin tables and register them as module buffers."""
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        t = torch.arange(self.max_seq_len).type_as(inv_freq)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Duplicate the frequencies so both halves of the feature axis share
        # the same angles (the layout expected by ``rotate_half``-style RoPE).
        emb = torch.cat((freqs, freqs), dim=-1)

        cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
        sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
        # Third slot: identity rotation (cos = 1, sin = 0).
        cos_cached[:, :, 2, :, :].fill_(1.0)
        sin_cached[:, :, 2, :, :].fill_(0.0)
        self.register_buffer("cos_cached", cos_cached)
        self.register_buffer("sin_cached", sin_cached)

    def forward(self, x: torch.Tensor, seq_dim: int = 1, position_ids: typing.Optional[torch.Tensor] = None):
        """Return the ``(cos, sin)`` tables for the requested positions.

        Args:
            x: Tensor whose size along ``seq_dim`` gives the sequence length;
                only its shape is read, and only when ``position_ids`` is None.
            seq_dim: Axis of ``x`` holding the sequence length.
            position_ids: Optional integer tensor of explicit positions.  The
                position axis of the tables is indexed with it, so that axis is
                replaced by the shape of ``position_ids``.

        Returns:
            Tuple ``(cos, sin)``; without ``position_ids`` each has shape
            ``(1, seq_len, 3, 1, dim)``.

        Raises:
            ValueError: If the sequence is longer than the precomputed tables.
        """
        if position_ids is not None:
            cos = self.cos_cached[:, position_ids]
            sin = self.sin_cached[:, position_ids]
            return cos, sin

        seq_len = x.shape[seq_dim]
        cached_len = self.cos_cached.shape[1]
        if seq_len > cached_len:
            # Slicing past the end would silently hand back a shorter table
            # and only fail later with an unrelated-looking broadcast error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds the {cached_len} positions "
                "precomputed by Rotary; construct it with a larger max_seq_len."
            )
        return self.cos_cached[:, :seq_len], self.sin_cached[:, :seq_len]
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last axis, negating the second half.

    ``[a, b] -> [-b, a]`` where ``a`` is the first ``size // 2`` features and
    ``b`` is the remainder.
    """
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings to a packed QKV tensor.

    Args:
        qkv: Packed projections; the block in this file passes shape
            ``(batch, seq, 3, heads, head_dim)``.
        cos: Cosine table as produced by ``Rotary``.
        sin: Sine table as produced by ``Rotary``.

    Returns:
        The rotated QKV tensor.
    """
    # The fused flash-attn kernel is GPU-only; for CPU tensors fall through
    # to the equivalent pure-PyTorch formula even when flash-attn is installed.
    if _has_flash_attn and qkv.is_cuda:
        # flash-attn takes 2-D (seq, rot_dim / 2) tables: use the q slot and
        # the first half of the features (the second half is a duplicate).
        cos = cos[0, :, 0, 0, : cos.shape[-1] // 2]
        sin = sin[0, :, 0, 0, : sin.shape[-1] // 2]
        # NOTE(review): the trailing underscore suggests this operates on
        # ``qkv`` in place — confirm before reusing the input afterwards.
        return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin)
    return (qkv * cos) + (rotate_half(qkv) * sin)
|
|
|
|
| |
| |
| |
|
|
class LayerNorm(nn.Module):
    """Bias-free layer normalisation over the last axis with a learned gain.

    The input is cast to float32 before normalising, so the output is float32.

    Args:
        dim: Size of the normalised (last) axis.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last axis and scale by ``weight``."""
        normalised = F.layer_norm(x.float(), [self.dim])
        # Give the gain two leading singleton axes so it broadcasts.
        gain = self.weight.unsqueeze(0).unsqueeze(0)
        return normalised * gain
|
|
|
|
class EmbeddingLayer(nn.Module):
    """Token embedding table (parameter, not nn.Embedding, for easy weight sharing).

    Args:
        dim: Width of each embedding vector.
        vocab_dim: Number of rows (tokens) in the table.
    """

    def __init__(self, dim: int, vocab_dim: int):
        super().__init__()
        table = torch.empty(vocab_dim, dim)
        nn.init.kaiming_uniform_(table, a=math.sqrt(5))
        self.embedding = nn.Parameter(table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Index the embedding table with ``x`` and return the selected rows."""
        table = self.embedding
        return table[x]
|
|
|
|
| |
| |
| |
|
|
class DDiTBlockWithMask(nn.Module):
    """
    DiT block with adaLN-Zero conditioning and optional attention mask.
    Supports both flash-attention (if installed) and standard SDPA.

    The block is a pre-norm transformer layer (self-attention followed by an
    MLP).  Shift, scale and gate for both sub-layers are predicted from the
    conditioning vector by a zero-initialised linear layer, so both gates start
    at zero and the block initially returns its input unchanged.

    Args:
        dim: Model width.
        n_heads: Number of attention heads (``dim`` must be divisible by it).
        cond_dim: Width of the conditioning vector ``c``.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        dropout: Dropout probability applied to each sub-layer's output.
    """

    def __init__(self, dim: int, n_heads: int, cond_dim: int,
                 mlp_ratio: int = 4, dropout: float = 0.1):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.dropout = dropout

        # Attention sub-layer.
        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)

        # MLP sub-layer.
        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio * dim, dim, bias=True),
        )

        # adaLN-Zero: six modulation vectors, zero-initialised.
        self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def _bias_dropout_scale_fn(self):
        """Pick the train or inference bias/dropout/scale helper."""
        return bias_dropout_add_scale_fused_train if self.training \
            else bias_dropout_add_scale_fused_inference

    @staticmethod
    def _additive_attn_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Turn a user-supplied mask into a 4-D additive mask for SDPA.

        * Floating-point masks are taken as already additive and returned with
          their values untouched.  A 2-D mask ``(d0, d1)`` becomes
          ``(d0, 1, 1, d1)``, i.e. one row over key positions per batch entry.
        * Any other dtype is read as a keep-mask (truthy = may attend) and is
          converted to ``0`` / large-negative values in ``dtype``.  A 2-D mask
          ``(S_q, S_k)`` becomes ``(1, 1, S_q, S_k)``, shared by all batch
          entries and heads.
        * In both cases, masks with fewer than four dimensions after the step
          above get leading singleton axes, matching ordinary broadcasting
          against ``(batch, heads, S_q, S_k)``.

        Args:
            attention_mask: The mask passed to ``forward``.
            dtype: Floating dtype of the queries; used for converted masks.

        Returns:
            A mask with at least four dimensions.
        """
        if attention_mask.is_floating_point():
            float_mask = attention_mask
            if float_mask.dim() == 2:
                float_mask = float_mask.unsqueeze(1).unsqueeze(1)
        else:
            keep = attention_mask.bool()
            # -1e9 is not representable in float16 and would make masked_fill
            # raise; clamp to the dtype's most negative finite value.  For
            # float32 / bfloat16 this is still exactly -1e9.
            fill_value = max(-1e9, torch.finfo(dtype).min)
            float_mask = torch.zeros_like(keep, dtype=dtype).masked_fill(~keep, fill_value)
            if float_mask.dim() == 2:
                float_mask = float_mask.unsqueeze(0).unsqueeze(0)
        while float_mask.dim() < 4:
            float_mask = float_mask.unsqueeze(0)
        return float_mask

    def forward(
        self,
        x: torch.Tensor,
        rotary_cos_sin: typing.Tuple[torch.Tensor, torch.Tensor],
        c: torch.Tensor,
        attention_mask: typing.Optional[torch.Tensor] = None,
        seqlens: typing.Optional[torch.Tensor] = None,
        flex_block_mask=None,
    ) -> torch.Tensor:
        """Run the block.

        Args:
            x: Hidden states of shape ``(batch, seq, dim)``.
            rotary_cos_sin: ``(cos, sin)`` tables as returned by ``Rotary``.
            c: Conditioning tensor of shape ``(batch, cond_dim)``.
            attention_mask: Optional mask; see ``_additive_attn_mask`` for the
                accepted forms.  Supplying it selects the SDPA path.
            seqlens: Optional sequence lengths for the flash-attn varlen path;
                ignored on the other paths.
            flex_block_mask: Optional FlexAttention block mask.  When given it
                takes precedence over every other attention path.

        Returns:
            Tensor of shape ``(batch, seq, dim)``.

        Raises:
            RuntimeError: If ``flex_block_mask`` is given but FlexAttention
                could not be set up at import time.
        """
        B, S = x.shape[:2]
        bds_fn = self._bias_dropout_scale_fn()

        (shift_msa, scale_msa, gate_msa,
         shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)

        # ---- attention sub-layer ------------------------------------------
        x_skip = x
        x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
        qkv = self.attn_qkv(x)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.n_heads)
        cos, sin = rotary_cos_sin
        qkv = apply_rotary_pos_emb(qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))

        if flex_block_mask is not None:
            if _flex_attention_compiled is None:
                # Previously this path ended in "'NoneType' is not callable".
                raise RuntimeError(
                    "flex_block_mask was provided but FlexAttention is not "
                    "available in this PyTorch installation."
                )
            # FlexAttention expects (batch, heads, seq, head_dim).
            q = qkv[:, :, 0].transpose(1, 2)
            k = qkv[:, :, 1].transpose(1, 2)
            v = qkv[:, :, 2].transpose(1, 2)
            x = _flex_attention_compiled(q, k, v, block_mask=flex_block_mask)
            x = rearrange(x, "b h s d -> b s (h d)", b=B)
        elif _has_flash_attn and attention_mask is None and qkv.is_cuda:
            # flash-attn kernels are GPU-only, hence the is_cuda guard.
            qkv = rearrange(qkv, "b s ... -> (b s) ...")
            if seqlens is not None:
                # cumsum promotes integer input to int64, but flash-attn
                # requires int32 cumulative lengths.
                # NOTE(review): flash-attn also expects the offsets to start
                # at 0 — confirm callers pass ``seqlens`` in a form whose
                # cumulative sum satisfies that.
                cu = seqlens.cumsum(-1).to(torch.int32)
            else:
                # Every sequence has the full length S.
                cu = torch.arange(
                    0, (B + 1) * S, step=S, dtype=torch.int32, device=qkv.device)
            x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
                qkv, cu, S, 0.0, causal=False)
            x = rearrange(x, "(b s) h d -> b s (h d)", b=B)
        else:
            # PyTorch SDPA expects (batch, heads, seq, head_dim).
            q = qkv[:, :, 0].transpose(1, 2)
            k = qkv[:, :, 1].transpose(1, 2)
            v = qkv[:, :, 2].transpose(1, 2)
            if attention_mask is not None:
                float_mask = self._additive_attn_mask(attention_mask, q.dtype)
                x = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
            else:
                x = F.scaled_dot_product_attention(q, k, v)
            x = rearrange(x, "b h s d -> b s (h d)", b=B)

        # Gated residual connection around attention.
        x = bds_fn(self.attn_out(x), None, gate_msa, x_skip, self.dropout)

        # ---- MLP sub-layer with its own gated residual --------------------
        x = bds_fn(
            self.mlp(modulate_fused(self.norm2(x), shift_mlp, scale_mlp)),
            None, gate_mlp, x, self.dropout,
        )
        return x
|
|
|
|
| |
| |
| |
|
|
class DDitFinalLayer(nn.Module):
    """Final adaLN-modulated projection from hidden states to outputs.

    Both the output projection and the modulation layer are zero-initialised,
    so a freshly constructed layer outputs zeros.

    Args:
        hidden_size: Width of the incoming hidden states.
        out_channels: Width of the projected output.
        cond_dim: Width of the conditioning vector.
    """

    def __init__(self, hidden_size: int, out_channels: int, cond_dim: int):
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)
        # Zero every weight and bias of the two linear layers.
        for zeroed in (self.linear, self.adaLN_modulation):
            nn.init.zeros_(zeroed.weight)
            nn.init.zeros_(zeroed.bias)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Normalise ``x``, modulate it with shift/scale from ``c``, and project."""
        modulation = self.adaLN_modulation(c).unsqueeze(1)
        shift, scale = modulation.chunk(2, dim=2)
        modulated = modulate_fused(self.norm_final(x), shift, scale)
        return self.linear(modulated)
|
|