"""Dense cross-attention block used by the DINAC-AE class-token head.""" from __future__ import annotations from dataclasses import dataclass from torch import Tensor, nn from common.norms import RMSNorm from dit.attention_blocks import CrossAttentionCore from dit.mlp import build_dit_mlp, reset_module_parameters from dit.mlp_types import MLPType @dataclass class CrossAttentionConfig: """Configuration for the exported dense cross-attention block.""" n_heads: int = 16 head_dim: int | None = None query_extra_dim: int = 0 context_extra_dim: int = 0 key_extra_dim: int = 0 mlp_ratio: float = 2.0 attn_dropout: float = 0.0 mlp_type: MLPType = MLPType.GELU activation_config: object | None = None use_norms: bool = True block_index: int = 0 use_attn_residual: bool = True class CrossAttentionBlock(nn.Module): """Dense pre-norm cross-attention plus residual MLP.""" query_dim: int context_dim: int query_extra_dim: int context_extra_dim: int key_extra_dim: int n_heads: int head_dim: int attn_dim: int use_norms: bool attn_dropout: float use_attn_residual: bool query_norm: RMSNorm | None context_norm: RMSNorm | None mlp_norm: RMSNorm | None q_proj: nn.Linear attn_core: CrossAttentionCore kv_proj: nn.Linear out_proj: nn.Linear mlp: nn.Module def __init__( self, *, query_dim: int, context_dim: int, cfg: CrossAttentionConfig, ) -> None: super().__init__() n_heads = int(cfg.n_heads) if cfg.head_dim is None: if query_dim % n_heads != 0: raise ValueError("query_dim must be divisible by n_heads") head_dim = query_dim // n_heads else: head_dim = int(cfg.head_dim) self.query_dim = int(query_dim) self.context_dim = int(context_dim) self.query_extra_dim = int(cfg.query_extra_dim) self.context_extra_dim = int(cfg.context_extra_dim) self.key_extra_dim = int(cfg.key_extra_dim) self.n_heads = n_heads self.head_dim = int(head_dim) self.attn_dim = int(self.n_heads * self.head_dim) self.use_norms = bool(cfg.use_norms) self.attn_dropout = float(cfg.attn_dropout) if not cfg.use_attn_residual: raise ValueError("DINAC-AE export requires attention residuals") self.use_attn_residual = True self.query_norm = RMSNorm(self.query_dim) if self.use_norms else None self.context_norm = RMSNorm(self.context_dim) if self.use_norms else None self.mlp_norm = RMSNorm(query_dim) if self.use_norms else None self.q_proj = nn.Linear( self.query_dim + self.query_extra_dim, self.attn_dim, bias=False ) self.attn_core = CrossAttentionCore( query_dim=query_dim, context_dim=context_dim, context_extra_dim=self.context_extra_dim, key_extra_dim=self.key_extra_dim, n_heads=self.n_heads, head_dim=self.head_dim, attn_dropout=self.attn_dropout, ) self.kv_proj = self.attn_core.kv_proj self.out_proj = self.attn_core.out_proj hidden = int(round(cfg.mlp_ratio * query_dim)) self.mlp = build_dit_mlp( mlp_type=cfg.mlp_type, in_features=query_dim, hidden_budget=hidden, activation_config=cfg.activation_config, block_index=int(cfg.block_index), bias_up=False, bias_down=False, ) self.reset_parameters() def reset_parameters(self) -> None: """Reset projections and MLP parameters.""" nn.init.xavier_uniform_(self.q_proj.weight) self.attn_core.reset_parameters() reset_module_parameters(self.mlp) def forward( self, query: Tensor, context: Tensor, *, query_extra: Tensor | None = None, context_extra: Tensor | None = None, key_extra: Tensor | None = None, key_padding_mask: Tensor | None = None, ) -> Tensor: # type: ignore[override] """Run dense cross-attention followed by the residual MLP.""" query_tokens = self.query_norm(query) if self.query_norm is not None else query if query_extra is not None: q_in = query_tokens.new_empty( *query_tokens.shape[:-1], int(query_tokens.shape[-1]) + int(query_extra.shape[-1]), ) q_in[..., : int(query_tokens.shape[-1])] = query_tokens q_in[..., int(query_tokens.shape[-1]) :] = query_extra else: q_in = query_tokens context_tokens = ( self.context_norm(context) if self.context_norm is not None else context ) if context_extra is not None: kv_tokens = context_tokens.new_empty( *context_tokens.shape[:-1], int(context_tokens.shape[-1]) + int(context_extra.shape[-1]), ) kv_tokens[..., : int(context_tokens.shape[-1])] = context_tokens kv_tokens[..., int(context_tokens.shape[-1]) :] = context_extra else: kv_tokens = context_tokens q_attn_tokens = self.q_proj(q_in) attn_out = self.attn_core( q_attn_tokens, kv_tokens, training=self.training, key_extra=key_extra, key_padding_mask=key_padding_mask, ) tokens = query + attn_out mlp_in = self.mlp_norm(tokens) if self.mlp_norm is not None else tokens return tokens + self.mlp(mlp_in) def compile_for_training(self, *, fullgraph: bool, dynamic: bool) -> None: """No-op hook kept for the token-alignment head API.""" _ = fullgraph, dynamic def compile_for_eval(self, *, fullgraph: bool, dynamic: bool) -> None: """No-op hook kept for the token-alignment head API.""" _ = fullgraph, dynamic __all__ = ["CrossAttentionBlock", "CrossAttentionConfig"]