import torch import torch.nn as nn from einops import rearrange import warnings import torch.nn.functional as F USE_FLASH_ATTENTION3 = True try: from flash_attn_interface import flash_attn_func FA3_AVAILABLE = True warnings.warn('flash attention 3 is available (LVSM)') except ImportError: FA3_AVAILABLE = False warnings.warn('flash attention 3 is not available (LVSM)') try: import xformers.ops as xops XFORMERS_AVAILABLE = True except ImportError: XFORMERS_AVAILABLE = False warnings.warn('xformers is not available (LVSM)') # raise ImportError("Please install xformers to use flashatt v2") def init_weights(module, std=0.02): """Initialize weights for linear and embedding layers. Args: module: Module to initialize std: Standard deviation for normal initialization """ if isinstance(module, (nn.Linear, nn.Embedding)): torch.nn.init.normal_(module.weight, mean=0.0, std=std) if isinstance(module, nn.Linear) and module.bias is not None: torch.nn.init.zeros_(module.bias) # src: https://github.com/pytorch/benchmark/blob/main/torchbenchmark/models/llama/model.py#L28 class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): output = self._norm(x.float()).type_as(x) return output * self.weight.type_as(x) class MLP(nn.Module): """ Multi-Layer Perceptron block. Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L49-L65 """ def __init__( self, dim, mlp_ratio=4, bias=False, dropout=0.0, activation=nn.GELU, mlp_dim=None, ): """ Args: dim: Input dimension mlp_ratio: Multiplier for hidden dimension bias: Whether to use bias in linear layers dropout: Dropout probability activation: Activation function mlp_dim: Optional explicit hidden dimension (overrides mlp_ratio) """ super().__init__() hidden_dim = mlp_dim if mlp_dim is not None else int(dim * mlp_ratio) self.mlp = nn.Sequential( nn.Linear(dim, hidden_dim, bias=bias), activation(), nn.Linear(hidden_dim, dim, bias=bias), nn.Dropout(dropout), ) def forward(self, x): return self.mlp(x) class QK_Norm_SelfAttention(nn.Module): """ Self-attention with optional Q-K normalization. Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L68-L92 """ def __init__( self, dim, head_dim, qkv_bias=False, fc_bias=True, attn_dropout=0.0, fc_dropout=0.0, use_qk_norm=True, ): """ Args: dim: Input dimension head_dim: Dimension of each attention head qkv_bias: Whether to use bias in QKV projection fc_bias: Whether to use bias in output projection attn_dropout: Dropout probability for attention weights fc_dropout: Dropout probability for output projection use_qk_norm: Whether to use Q-K normalization We use flash attention V2 for efficiency. """ super().__init__() assert dim % head_dim == 0, f"Token dimension {dim} should be divisible by head dimension {head_dim}" self.dim = dim self.head_dim = head_dim self.num_heads = dim // head_dim self.attn_dropout = attn_dropout self.use_qk_norm = use_qk_norm self.to_qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias) self.fc = nn.Linear(dim, dim, bias=fc_bias) self.attn_fc_dropout = nn.Dropout(fc_dropout) # Optional Q-K normalization if self.use_qk_norm: self.q_norm = RMSNorm(head_dim) self.k_norm = RMSNorm(head_dim) def forward(self, x, attn_bias=None): """ Args: x: Input tensor of shape (batch, seq_len, dim) attn_bias: Optional attention bias mask Returns: Output tensor of shape (batch, seq_len, dim) """ q, k, v = self.to_qkv(x).chunk(3, dim=-1) q, k, v = (rearrange(t, "b l (nh dh) -> b l nh dh", dh=self.head_dim) for t in (q, k, v)) # Apply qk normalization if enabled if self.use_qk_norm: q = self.q_norm(q) k = self.k_norm(k) if USE_FLASH_ATTENTION3 and FA3_AVAILABLE: x = flash_attn_func(q, k, v)[0] elif XFORMERS_AVAILABLE: x = xops.memory_efficient_attention( q, k, v, attn_bias=attn_bias, p=self.attn_dropout if self.training else 0.0, op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp), ) else: # use pytorch's built-in attention q = q.permute(0, 2, 1, 3).contiguous() # [B, H, L, C] k = k.permute(0, 2, 1, 3).contiguous() v = v.permute(0, 2, 1, 3).contiguous() x = F.scaled_dot_product_attention(q, k, v) x = x.permute(0, 2, 1, 3).contiguous() # [B, L, H, C] x = rearrange(x, "b l nh dh -> b l (nh dh)") x = self.attn_fc_dropout(self.fc(x)) return x class SubsetAttention(nn.Module): """Attention that can attend to subsets of queries or keys/values.""" def __init__( self, dim, head_dim, qkv_bias=False, attn_dropout=0.0, fc_bias=False, fc_dropout=0.0, use_qk_norm=False ): """ Args: dim: Input dimension head_dim: Dimension of each attention head qkv_bias: Whether to use bias in QKV projection attn_dropout: Dropout probability for attention weights fc_bias: Whether to use bias in output projection fc_dropout: Dropout probability for output projection use_qk_norm: Whether to use Q-K normalization We use flash attention V2 for efficiency. """ super().__init__() assert dim % head_dim == 0, f"Token dimension {dim} should be divisible by head dimension {head_dim}" self.dim = dim self.head_dim = head_dim self.num_heads = dim // head_dim self.attn_dropout = attn_dropout self.use_qk_norm = use_qk_norm # Projections self.to_qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias) self.fc = nn.Linear(dim, dim, bias=fc_bias) self.attn_fc_dropout = nn.Dropout(fc_dropout) # Optional Q-K normalization if self.use_qk_norm: self.q_norm = RMSNorm(head_dim) self.k_norm = RMSNorm(head_dim) def forward(self, x, subset_kv_size=None, subset_q_size=None): """ Args: x: Input tensor of shape (batch, seq_len, dim) subset_kv_size: If provided, only attend to tokens after this index in KV subset_q_size: If provided, only compute attention for queries up to this index Returns: Output tensor of shape (batch, seq_len, dim) """ # Only one subset parameter can be provided assert not (subset_kv_size is not None and subset_q_size is not None), \ "Only one of subset_kv_size or subset_q_size can be provided" q, k, v = self.to_qkv(x).chunk(3, dim=-1) q, k, v = (rearrange(t, "b l (nh dh) -> b l nh dh", dh=self.head_dim) for t in (q, k, v)) if self.use_qk_norm: q = self.q_norm(q) k = self.k_norm(k) # Handle subset attention cases if subset_kv_size is not None and subset_kv_size < k.shape[1]: # Attend to subset of key/value tokens k_subset = k[:, subset_kv_size:, :, :].contiguous() v_subset = v[:, subset_kv_size:, :, :].contiguous() x = xops.memory_efficient_attention( q, k_subset, v_subset, attn_bias=None, p=self.attn_dropout if self.training else 0.0, op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp), ) elif subset_q_size is not None and subset_q_size < q.shape[1]: # Only compute attention for subset of query tokens q_subset = q[:, :subset_q_size, :, :].contiguous() x = xops.memory_efficient_attention( q_subset, k, v, attn_bias=None, p=self.attn_dropout if self.training else 0.0, op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp), ) else: # Regular attention for all tokens x = xops.memory_efficient_attention( q, k, v, attn_bias=None, p=self.attn_dropout if self.training else 0.0, op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp), ) x = rearrange(x, "b l nh dh -> b l (nh dh)") # Final projection x = self.attn_fc_dropout(self.fc(x)) return x class QK_Norm_TransformerBlock(nn.Module): """ Standard transformer block with pre-normalization architecture. Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L95-L113 """ def __init__( self, dim, head_dim, ln_bias=False, attn_qkv_bias=False, attn_dropout=0.0, attn_fc_bias=False, attn_fc_dropout=0.0, mlp_ratio=4, mlp_bias=False, mlp_dropout=0.0, use_qk_norm=True, ): super().__init__() self.norm1 = nn.LayerNorm(dim, bias=ln_bias) self.attn = QK_Norm_SelfAttention( dim=dim, head_dim=head_dim, qkv_bias=attn_qkv_bias, fc_bias=attn_fc_bias, attn_dropout=attn_dropout, fc_dropout=attn_fc_dropout, use_qk_norm=use_qk_norm, ) self.norm2 = nn.LayerNorm(dim, bias=ln_bias) self.mlp = MLP( dim=dim, mlp_ratio=mlp_ratio, bias=mlp_bias, dropout=mlp_dropout, ) def forward(self, x): x = x + self.attn(self.norm1(x)) x = x + self.mlp(self.norm2(x)) return x