| """Dense SDPA attention blocks used by the DINAC-AE export.""" |
|
|
| from __future__ import annotations |
|
|
| from collections.abc import Callable |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import Tensor, nn |
|
|
| from common.norms import RMSNorm |
| from common.rope import rotate_half, rotate_half_adjacent |
| from dit.position_encoding import DiTPositionEncoding |
|
|
|
|
def _axial_rope_rotate_fn(
    position_encoding: DiTPositionEncoding,
) -> Callable[[Tensor], Tensor]:
    """Select the head-dimension rotation that pairs with a RoPE layout.

    The supported encodings fall into two groups:

    * dilated / normalized / frequency-aware 2D axial RoPE and 1D RoPE use
      ``rotate_half``;
    * unnormalized (plain and dilated), beta-warped and alpha-warped 2D axial
      RoPE and the 3D Z-Image RoPE use ``rotate_half_adjacent``.

    Args:
        position_encoding: The configured RoPE position encoding.

    Returns:
        The rotation callable to apply to the last (head) dimension of Q/K.

    Raises:
        ValueError: If ``position_encoding`` is in neither group (for
            example a non-RoPE encoding).
    """
    uses_rotate_half = (
        DiTPositionEncoding.ROPE_2D_AXIAL_DILATED,
        DiTPositionEncoding.ROPE_2D_AXIAL_NORMALIZED,
        DiTPositionEncoding.ROPE_2D_AXIAL_FREQ_AWARE,
        DiTPositionEncoding.ROPE_1D,
    )
    if position_encoding in uses_rotate_half:
        return rotate_half

    uses_rotate_half_adjacent = (
        DiTPositionEncoding.ROPE_2D_AXIAL_UNNORMALIZED,
        DiTPositionEncoding.ROPE_2D_AXIAL_UNNORMALIZED_DILATED,
        DiTPositionEncoding.ROPE_2D_AXIAL_BETA_WARP,
        DiTPositionEncoding.ROPE_2D_AXIAL_ALPHA_WARP,
        DiTPositionEncoding.ROPE_3D_ZIMAGE,
    )
    if position_encoding in uses_rotate_half_adjacent:
        return rotate_half_adjacent

    raise ValueError(f"Unsupported RoPE position encoding: {position_encoding}")
|
|
|
|
class DitSelfAttentionCore(nn.Module):
    """Dense multi-head self-attention with optional RoPE on Q/K.

    ``forward`` works in these steps:

    1. Project ``[B, N, d_model]`` tokens to Q/K/V with one fused bias-free
       linear layer.
    2. Split Q/K/V into ``n_heads`` heads of width ``d_model // n_heads``.
    3. RMS-normalise Q and K per head.
    4. Optionally rotate a leading span of the sequence in Q/K with RoPE.
    5. Run non-causal :func:`torch.nn.functional.scaled_dot_product_attention`
       with no mask and no dropout.
    6. Merge the heads and apply a bias-free output projection.
    """

    d_model: int
    n_heads: int
    head_dim: int
    position_encoding: DiTPositionEncoding
    qkv: nn.Linear
    proj_out: nn.Linear
    q_norm: RMSNorm
    k_norm: RMSNorm

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        *,
        position_encoding: DiTPositionEncoding,
    ) -> None:
        """Create the fused QKV projection, output projection and Q/K norms.

        Args:
            d_model: Token width; must be divisible by ``n_heads``.
            n_heads: Number of attention heads; must be positive.
            position_encoding: RoPE layout; selects the rotation used when
                ``forward`` is given RoPE tables.

        Raises:
            ValueError: If ``n_heads`` is not positive or does not divide
                ``d_model``.
        """
        super().__init__()
        # Without this guard, 0 raised ZeroDivisionError and negative values
        # slipped through to a negative ``head_dim``.
        if n_heads <= 0:
            raise ValueError("n_heads must be positive")
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        self.d_model = int(d_model)
        self.n_heads = int(n_heads)
        self.head_dim = int(self.d_model // self.n_heads)
        self.position_encoding = position_encoding
        # Submodule creation order is kept fixed: it determines parameter
        # order and the RNG draws made by the default initialisers.
        self.qkv = nn.Linear(self.d_model, 3 * self.d_model, bias=False)
        self.proj_out = nn.Linear(self.d_model, self.d_model, bias=False)
        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)

    def reset_parameters(self) -> None:
        """Re-initialise the QKV and output projection weights (Xavier uniform).

        The Q/K ``RMSNorm`` modules are not touched here.
        """
        nn.init.xavier_uniform_(self.qkv.weight)
        nn.init.xavier_uniform_(self.proj_out.weight)

    def forward(
        self, tokens: Tensor, *, rope_sincos: tuple[Tensor, Tensor] | None
    ) -> Tensor:
        """Apply dense self-attention to ``[B, N, d_model]`` tokens.

        Args:
            tokens: Input tokens of shape ``[B, N, d_model]``.
            rope_sincos: Optional ``(sin, cos)`` RoPE tables forwarded to
                :meth:`_apply_axial_rope_dense`; ``None`` disables RoPE.

        Returns:
            Attended tokens of shape ``[B, N, d_model]``.
        """
        batch, sequence_length, _width = tokens.shape
        qkv = self.qkv(tokens)
        q, k, v = qkv.chunk(3, dim=-1)
        # [B, N, d_model] -> [B, H, N, head_dim]
        q = q.view(batch, sequence_length, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch, sequence_length, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, sequence_length, self.n_heads, self.head_dim).transpose(1, 2)
        q = self.q_norm(q.contiguous())
        k = self.k_norm(k.contiguous())
        q, k = self._apply_axial_rope_dense(q, k, rope_sincos=rope_sincos)
        attn = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
        # [B, H, N, head_dim] -> [B, N, d_model]
        attn = (
            attn.transpose(1, 2).contiguous().view(batch, sequence_length, self.d_model)
        )
        return self.proj_out(attn)

    def _apply_axial_rope_dense(
        self,
        q: Tensor,
        k: Tensor,
        *,
        rope_sincos: tuple[Tensor, Tensor] | None,
    ) -> tuple[Tensor, Tensor]:
        """Rotate the first ``rope_len`` sequence positions of Q and K.

        Args:
            q: Query heads, ``[B, H, N, head_dim]``.
            k: Key heads, ``[B, H, N, head_dim]``.
            rope_sincos: ``(sin, cos)`` tables of identical shape, either
                ``[rope_len, head_dim]`` (shared by every sample) or
                ``[B', rope_len, head_dim]`` where ``B'`` is ``B`` or ``1``
                (the batch dimension broadcasts). ``None`` returns Q/K
                unchanged. Positions at index ``>= rope_len`` are passed
                through without rotation.

        Returns:
            ``(q, k)`` after rotation, cast back to their original dtypes. The
            rotation itself is computed in the dtype of the tables.

        Raises:
            ValueError: If the tables are not rank 2 or 3, if ``sin`` and
                ``cos`` shapes differ, or if their last dimension is not
                ``head_dim``.
        """
        if rope_sincos is None:
            return q, k
        sin, cos = rope_sincos
        rank = sin.dim()
        if rank not in (2, 3):
            raise ValueError(f"Unsupported RoPE tensor rank: {int(rank)}")
        if cos.shape != sin.shape:
            raise ValueError(
                f"RoPE sin/cos shapes differ: {tuple(sin.shape)} vs {tuple(cos.shape)}"
            )
        if int(sin.shape[-1]) != self.head_dim:
            raise ValueError(
                f"RoPE tables have last dim {int(sin.shape[-1])}, "
                f"expected head_dim={self.head_dim}"
            )
        rope_len = int(sin.shape[-2])
        # Insert the head axis so the tables broadcast against [B, H, L, D]:
        # [L, D] -> [1, L, D] (leading batch axis added by broadcasting) and
        # [B', L, D] -> [B', 1, L, D]. Unlike a ``view`` with q's batch size,
        # this accepts a batch-1 table and non-contiguous (e.g. expanded)
        # tables, and it does not read q's batch size as a Python int.
        sin_b = sin.unsqueeze(-3)
        cos_b = cos.unsqueeze(-3)
        rotate = _axial_rope_rotate_fn(self.position_encoding)

        def rotate_span(x: Tensor) -> Tensor:
            # Rotate in the table dtype, then restore the caller's dtype.
            x_rope = x.to(dtype=sin.dtype)
            head = x_rope[:, :, :rope_len, :]
            head = (head * cos_b) + (rotate(head) * sin_b)
            tail = x_rope[:, :, rope_len:, :]
            return torch.cat([head, tail], dim=2).to(dtype=x.dtype)

        return rotate_span(q), rotate_span(k)
|
|
|
|
class CrossAttentionCore(nn.Module):
    """Dense multi-head cross-attention used by the class-token readout.

    Only the context side is projected here: ``kv_proj`` maps context tokens
    to K and V, and ``k_extra_proj`` optionally adds a projected
    ``key_extra`` term to K. The query tokens are split into heads as-is, so
    their last dimension must equal ``n_heads * head_dim``. Q and K are
    RMS-normalised per head before SDPA. The merged heads are then mapped to
    ``query_dim`` by a bias-free output projection.
    """

    query_dim: int
    context_dim: int
    context_extra_dim: int
    key_extra_dim: int
    n_heads: int
    head_dim: int
    attn_dim: int
    context_in_dim: int
    attn_dropout: float
    kv_proj: nn.Linear
    k_extra_proj: nn.Linear | None
    out_proj: nn.Linear
    q_norm_heads: RMSNorm
    k_norm_heads: RMSNorm

    def __init__(
        self,
        *,
        query_dim: int,
        context_dim: int,
        n_heads: int,
        head_dim: int,
        context_extra_dim: int = 0,
        key_extra_dim: int = 0,
        attn_dropout: float = 0.0,
    ) -> None:
        """Create the K/V, optional extra-key, and output projections.

        Args:
            query_dim: Output width of ``out_proj``.
            context_dim: Width of the base context features.
            n_heads: Number of attention heads.
            head_dim: Width of each head.
            context_extra_dim: Additional context width; ``kv_proj`` consumes
                ``context_dim + context_extra_dim`` features.
            key_extra_dim: Width of ``key_extra``; ``0`` disables
                ``k_extra_proj``.
            attn_dropout: SDPA dropout probability used when ``forward`` is
                called with ``training=True``.
        """
        super().__init__()
        # Scalar configuration (stored as plain Python numbers).
        self.query_dim = int(query_dim)
        self.context_dim = int(context_dim)
        self.context_extra_dim = int(context_extra_dim)
        self.key_extra_dim = int(key_extra_dim)
        self.n_heads = int(n_heads)
        self.head_dim = int(head_dim)
        self.attn_dim = int(self.n_heads * self.head_dim)
        self.context_in_dim = int(self.context_dim + self.context_extra_dim)
        self.attn_dropout = float(attn_dropout)

        # Submodules; keep this creation order (parameter order and the RNG
        # draws made by the default initialisers depend on it).
        self.kv_proj = nn.Linear(self.context_in_dim, 2 * self.attn_dim, bias=False)
        self.k_extra_proj = (
            nn.Linear(self.key_extra_dim, self.attn_dim, bias=False)
            if self.key_extra_dim != 0
            else None
        )
        self.out_proj = nn.Linear(self.attn_dim, self.query_dim, bias=False)
        self.q_norm_heads = RMSNorm(self.head_dim)
        self.k_norm_heads = RMSNorm(self.head_dim)

    def reset_parameters(self) -> None:
        """Re-initialise every linear projection weight (Xavier uniform).

        The per-head ``RMSNorm`` modules are not touched here.
        """
        linears = [self.kv_proj]
        if self.k_extra_proj is not None:
            linears.append(self.k_extra_proj)
        linears.append(self.out_proj)
        for layer in linears:
            nn.init.xavier_uniform_(layer.weight)

    def _split_heads(self, x: Tensor) -> Tensor:
        """Reshape ``[B, S, n_heads * head_dim]`` to ``[B, n_heads, S, head_dim]``."""
        batch, length, _ = x.shape
        per_head = x.view(batch, length, self.n_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def _merge_heads(self, x: Tensor) -> Tensor:
        """Reshape ``[B, n_heads, S, head_dim]`` back to ``[B, S, attn_dim]``."""
        batch, _, length, _ = x.shape
        token_major = x.transpose(1, 2).contiguous()
        return token_major.view(batch, length, self.attn_dim)

    @staticmethod
    def _additive_key_mask(key_padding_mask: Tensor, dtype: torch.dtype) -> Tensor:
        """Build the additive ``[B, 1, 1, Nk]`` SDPA mask from a key mask.

        ``~key_padding_mask`` is cast to ``dtype``. Entries greater than zero
        are replaced with ``-inf``; every other entry keeps its cast value.
        For a boolean mask, this gives ``-inf`` where ``key_padding_mask`` is
        False and ``0`` where it is True.
        """
        blocked = (~key_padding_mask).to(dtype=dtype)
        blocked = blocked.view(key_padding_mask.shape[0], 1, 1, key_padding_mask.shape[1])
        return blocked.masked_fill(blocked > 0, float("-inf"))

    def forward(
        self,
        q_tokens: Tensor,
        kv_tokens: Tensor,
        *,
        training: bool,
        key_extra: Tensor | None = None,
        key_padding_mask: Tensor | None = None,
    ) -> Tensor:
        """Attend from ``q_tokens`` onto ``kv_tokens``.

        Args:
            q_tokens: ``[B, Nq, n_heads * head_dim]`` queries. They are split
                into heads directly, without a projection.
            kv_tokens: ``[B, Nk, context_dim + context_extra_dim]`` context.
            training: When True, SDPA uses ``attn_dropout``; otherwise 0.
            key_extra: Optional input to ``k_extra_proj``, added to K. It is
                ignored when the module was built with ``key_extra_dim == 0``.
            key_padding_mask: Optional ``[B, Nk]`` mask. Keys where it is
                False get ``-inf`` and are excluded, so True marks keys that
                are attended. NOTE(review): this is the opposite of
                ``torch.nn.MultiheadAttention``'s ``key_padding_mask``
                convention; confirm callers pass a "keep" mask. It also
                presumably must be boolean, since ``~`` on integer tensors is
                a bitwise NOT. A row with every key excluded is likely to
                produce NaN from softmax over all ``-inf``.

        Returns:
            ``[B, Nq, query_dim]`` attended features.
        """
        keys, values = self.kv_proj(kv_tokens).chunk(2, dim=-1)
        if key_extra is not None and self.k_extra_proj is not None:
            keys = keys + self.k_extra_proj(key_extra)

        q = self.q_norm_heads(self._split_heads(q_tokens).contiguous())
        k = self.k_norm_heads(self._split_heads(keys).contiguous())
        v = self._split_heads(values).contiguous()

        mask = (
            None
            if key_padding_mask is None
            else self._additive_key_mask(key_padding_mask, q.dtype)
        )
        dropout = self.attn_dropout if training else 0.0
        attended = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=dropout,
            is_causal=False,
        )
        return self.out_proj(self._merge_heads(attended))
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ["CrossAttentionCore", "DitSelfAttentionCore"]
|
|