"""Dense SDPA attention blocks used by the DINAC-AE export.""" from __future__ import annotations from collections.abc import Callable import torch import torch.nn.functional as F from torch import Tensor, nn from common.norms import RMSNorm from common.rope import rotate_half, rotate_half_adjacent from dit.position_encoding import DiTPositionEncoding def _axial_rope_rotate_fn( position_encoding: DiTPositionEncoding, ) -> Callable[[Tensor], Tensor]: """Return the head-dimension rotation matching the configured RoPE layout.""" match position_encoding: case ( DiTPositionEncoding.ROPE_2D_AXIAL_DILATED | DiTPositionEncoding.ROPE_2D_AXIAL_NORMALIZED | DiTPositionEncoding.ROPE_2D_AXIAL_FREQ_AWARE | DiTPositionEncoding.ROPE_1D ): return rotate_half case ( DiTPositionEncoding.ROPE_2D_AXIAL_UNNORMALIZED | DiTPositionEncoding.ROPE_2D_AXIAL_UNNORMALIZED_DILATED | DiTPositionEncoding.ROPE_2D_AXIAL_BETA_WARP | DiTPositionEncoding.ROPE_2D_AXIAL_ALPHA_WARP | DiTPositionEncoding.ROPE_3D_ZIMAGE ): return rotate_half_adjacent case _ as unreachable: raise ValueError(f"Unsupported RoPE position encoding: {unreachable}") class DitSelfAttentionCore(nn.Module): """Dense self-attention core with optional axial RoPE on Q/K.""" d_model: int n_heads: int head_dim: int position_encoding: DiTPositionEncoding qkv: nn.Linear proj_out: nn.Linear q_norm: RMSNorm k_norm: RMSNorm def __init__( self, d_model: int, n_heads: int, *, position_encoding: DiTPositionEncoding, ) -> None: super().__init__() if d_model % n_heads != 0: raise ValueError("d_model must be divisible by n_heads") self.d_model = int(d_model) self.n_heads = int(n_heads) self.head_dim = int(self.d_model // self.n_heads) self.position_encoding = position_encoding self.qkv = nn.Linear(self.d_model, 3 * self.d_model, bias=False) self.proj_out = nn.Linear(self.d_model, self.d_model, bias=False) self.q_norm = RMSNorm(self.head_dim) self.k_norm = RMSNorm(self.head_dim) def reset_parameters(self) -> None: """Reset projections to their initialization.""" nn.init.xavier_uniform_(self.qkv.weight) nn.init.xavier_uniform_(self.proj_out.weight) def forward( self, tokens: Tensor, *, rope_sincos: tuple[Tensor, Tensor] | None ) -> Tensor: """Apply dense self-attention to ``[B, N, D]`` tokens.""" batch, sequence_length, _width = tokens.shape qkv = self.qkv(tokens) q, k, v = qkv.chunk(3, dim=-1) q = q.view(batch, sequence_length, self.n_heads, self.head_dim).transpose(1, 2) k = k.view(batch, sequence_length, self.n_heads, self.head_dim).transpose(1, 2) v = v.view(batch, sequence_length, self.n_heads, self.head_dim).transpose(1, 2) q = self.q_norm(q.contiguous()) k = self.k_norm(k.contiguous()) q, k = self._apply_axial_rope_dense(q, k, rope_sincos=rope_sincos) attn = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False) attn = ( attn.transpose(1, 2).contiguous().view(batch, sequence_length, self.d_model) ) return self.proj_out(attn) def _apply_axial_rope_dense( self, q: Tensor, k: Tensor, *, rope_sincos: tuple[Tensor, Tensor] | None, ) -> tuple[Tensor, Tensor]: """Apply axial RoPE to dense Q/K tensors.""" if rope_sincos is None: return q, k sin, cos = rope_sincos rope_len = int(sin.shape[-2]) rope_dtype = sin.dtype q_dtype = q.dtype k_dtype = k.dtype q_rope = q.to(dtype=rope_dtype) k_rope = k.to(dtype=rope_dtype) match sin.dim(): case 2: sin_b = sin.view(1, 1, rope_len, self.head_dim) cos_b = cos.view(1, 1, rope_len, self.head_dim) case 3: sin_b = sin.view(int(q.shape[0]), 1, rope_len, self.head_dim) cos_b = cos.view(int(q.shape[0]), 1, rope_len, self.head_dim) case _ as unreachable: raise ValueError(f"Unsupported RoPE tensor rank: {int(unreachable)}") rotate = _axial_rope_rotate_fn(self.position_encoding) q_span = q_rope[:, :, :rope_len, :] k_span = k_rope[:, :, :rope_len, :] q_head = (q_span * cos_b) + (rotate(q_span) * sin_b) k_head = (k_span * cos_b) + (rotate(k_span) * sin_b) q_rope = torch.cat([q_head, q_rope[:, :, rope_len:, :]], dim=2) k_rope = torch.cat([k_head, k_rope[:, :, rope_len:, :]], dim=2) return q_rope.to(dtype=q_dtype), k_rope.to(dtype=k_dtype) class CrossAttentionCore(nn.Module): """Dense cross-attention core used by the class-token readout.""" query_dim: int context_dim: int context_extra_dim: int key_extra_dim: int n_heads: int head_dim: int attn_dim: int context_in_dim: int attn_dropout: float kv_proj: nn.Linear k_extra_proj: nn.Linear | None out_proj: nn.Linear q_norm_heads: RMSNorm k_norm_heads: RMSNorm def __init__( self, *, query_dim: int, context_dim: int, n_heads: int, head_dim: int, context_extra_dim: int = 0, key_extra_dim: int = 0, attn_dropout: float = 0.0, ) -> None: super().__init__() self.query_dim = int(query_dim) self.context_dim = int(context_dim) self.context_extra_dim = int(context_extra_dim) self.key_extra_dim = int(key_extra_dim) self.n_heads = int(n_heads) self.head_dim = int(head_dim) self.attn_dim = int(self.n_heads * self.head_dim) self.context_in_dim = int(self.context_dim + self.context_extra_dim) self.attn_dropout = float(attn_dropout) self.kv_proj = nn.Linear(self.context_in_dim, 2 * self.attn_dim, bias=False) if self.key_extra_dim == 0: self.k_extra_proj = None else: self.k_extra_proj = nn.Linear(self.key_extra_dim, self.attn_dim, bias=False) self.out_proj = nn.Linear(self.attn_dim, self.query_dim, bias=False) self.q_norm_heads = RMSNorm(self.head_dim) self.k_norm_heads = RMSNorm(self.head_dim) def reset_parameters(self) -> None: """Reset projections to their initialization.""" nn.init.xavier_uniform_(self.kv_proj.weight) if self.k_extra_proj is not None: nn.init.xavier_uniform_(self.k_extra_proj.weight) nn.init.xavier_uniform_(self.out_proj.weight) def _split_heads(self, x: Tensor) -> Tensor: batch, sequence_length, _width = x.shape return x.view(batch, sequence_length, self.n_heads, self.head_dim).transpose( 1, 2 ) def _merge_heads(self, x: Tensor) -> Tensor: batch, _heads, sequence_length, _head_dim = x.shape return ( x.transpose(1, 2).contiguous().view(batch, sequence_length, self.attn_dim) ) def forward( self, q_tokens: Tensor, kv_tokens: Tensor, *, training: bool, key_extra: Tensor | None = None, key_padding_mask: Tensor | None = None, ) -> Tensor: """Apply dense cross-attention to query and context tokens.""" kv = self.kv_proj(kv_tokens) k, v = kv.chunk(2, dim=-1) if self.k_extra_proj is not None and key_extra is not None: k = k + self.k_extra_proj(key_extra) q = self.q_norm_heads(self._split_heads(q_tokens).contiguous()) k = self.k_norm_heads(self._split_heads(k).contiguous()) v = self._split_heads(v).contiguous() if key_padding_mask is None: attn_mask = None else: attn_mask = (~key_padding_mask).to(dtype=q.dtype) attn_mask = attn_mask.view( key_padding_mask.shape[0], 1, 1, key_padding_mask.shape[1] ) attn_mask = attn_mask.masked_fill(attn_mask > 0, float("-inf")) attn = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, dropout_p=self.attn_dropout if training else 0.0, is_causal=False, ) return self.out_proj(self._merge_heads(attn)) __all__ = ["CrossAttentionCore", "DitSelfAttentionCore"]