Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange | |
| import warnings | |
| import torch.nn.functional as F | |
# Set to False to bypass FlashAttention-3 even when it is importable.
USE_FLASH_ATTENTION3 = True

# Optional FlashAttention-3 kernel. FA3_AVAILABLE records whether the import worked;
# a warning is emitted either way so the chosen backend shows up in logs.
try:
    from flash_attn_interface import flash_attn_func
except ImportError:
    FA3_AVAILABLE = False
    warnings.warn('flash attention 3 is not available (LVSM)')
else:
    FA3_AVAILABLE = True
    warnings.warn('flash attention 3 is available (LVSM)')

# Optional xformers memory-efficient attention (bound as `xops` when present).
try:
    import xformers.ops as xops
except ImportError:
    XFORMERS_AVAILABLE = False
    warnings.warn('xformers is not available (LVSM)')
else:
    XFORMERS_AVAILABLE = True
def init_weights(module, std=0.02):
    """Initialize weights for linear and embedding layers.

    Linear and embedding weights are drawn from N(0, std**2) and linear biases
    are zeroed. Any other module type is left untouched, so the function can be
    passed to ``nn.Module.apply``.

    For an ``nn.Embedding`` constructed with ``padding_idx``, the padding row is
    reset to zero after the normal draw. PyTorch zero-initializes that row and
    it receives no gradient, so leaving it random would give padding tokens a
    fixed, non-zero embedding for the whole of training.

    Args:
        module: Module to initialize
        std: Standard deviation for normal initialization
    """
    if not isinstance(module, (nn.Linear, nn.Embedding)):
        return
    torch.nn.init.normal_(module.weight, mean=0.0, std=std)
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif module.padding_idx is not None:
        # Restore the all-zero padding vector that nn.Embedding provides by default.
        with torch.no_grad():
            module.weight[module.padding_idx].zero_()
# src: https://github.com/pytorch/benchmark/blob/main/torchbenchmark/models/llama/model.py#L28
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Each vector along the last axis is multiplied by 1 / sqrt(mean(x**2) + eps)
    and then scaled by a learnable per-channel weight initialized to ones.
    The statistic is computed on a float32 copy of the input and the result is
    cast back to the input's dtype.

    Args:
        dim: Size of the last dimension (length of the learnable weight).
        eps: Constant added under the square root for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / RMS(x), expressed with rsqrt.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        gain = self.weight.type_as(x)
        normalized = self._norm(x.float()).type_as(x)
        return normalized * gain
class MLP(nn.Module):
    """
    Two-layer feed-forward block: Linear -> activation -> Linear -> Dropout.
    Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L49-L65
    """

    def __init__(
        self,
        dim,
        mlp_ratio=4,
        bias=False,
        dropout=0.0,
        activation=nn.GELU,
        mlp_dim=None,
    ):
        """
        Args:
            dim: Input and output feature dimension
            mlp_ratio: Hidden width as a multiple of ``dim`` (ignored when ``mlp_dim`` is set)
            bias: Whether both linear layers carry a bias
            dropout: Dropout probability applied to the block's output
            activation: Activation class, instantiated with no arguments
            mlp_dim: Explicit hidden width; overrides ``mlp_ratio`` when not None
        """
        super().__init__()
        if mlp_dim is None:
            hidden_dim = int(dim * mlp_ratio)
        else:
            hidden_dim = mlp_dim
        # Sequential indices define the state_dict keys (mlp.0.*, mlp.2.*),
        # so this layer order must stay fixed.
        layers = [
            nn.Linear(dim, hidden_dim, bias=bias),
            activation(),
            nn.Linear(hidden_dim, dim, bias=bias),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack along the last dimension of ``x``."""
        return self.mlp(x)
class QK_Norm_SelfAttention(nn.Module):
    """
    Multi-head self-attention with optional Q-K normalization.

    The attention kernel is selected per call:
      1. FlashAttention-3 ``flash_attn_func`` when ``USE_FLASH_ATTENTION3`` and
         ``FA3_AVAILABLE`` are true AND the call needs neither an attention bias
         nor attention dropout (FA3 is invoked without either);
      2. xformers ``memory_efficient_attention`` when ``XFORMERS_AVAILABLE``;
      3. PyTorch ``F.scaled_dot_product_attention`` otherwise.

    Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L68-L92
    """

    def __init__(
        self,
        dim,
        head_dim,
        qkv_bias=False,
        fc_bias=True,
        attn_dropout=0.0,
        fc_dropout=0.0,
        use_qk_norm=True,
    ):
        """
        Args:
            dim: Input dimension
            head_dim: Dimension of each attention head (must divide ``dim``)
            qkv_bias: Whether to use bias in QKV projection
            fc_bias: Whether to use bias in output projection
            attn_dropout: Dropout probability for attention weights (training only)
            fc_dropout: Dropout probability for output projection
            use_qk_norm: Whether to RMS-normalize queries and keys per head
        """
        super().__init__()
        assert dim % head_dim == 0, f"Token dimension {dim} should be divisible by head dimension {head_dim}"
        self.dim = dim
        self.head_dim = head_dim
        self.num_heads = dim // head_dim
        self.attn_dropout = attn_dropout
        self.use_qk_norm = use_qk_norm
        self.to_qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.fc = nn.Linear(dim, dim, bias=fc_bias)
        self.attn_fc_dropout = nn.Dropout(fc_dropout)
        # Optional Q-K normalization (over head_dim, shared across heads)
        if self.use_qk_norm:
            self.q_norm = RMSNorm(head_dim)
            self.k_norm = RMSNorm(head_dim)

    def _attend(self, q, k, v, attn_bias):
        """Run attention on (batch, seq_len, heads, head_dim) tensors; return the same layout."""
        dropout_p = self.attn_dropout if self.training else 0.0
        # FA3 is called without bias/dropout arguments, so only use it when neither is needed;
        # otherwise fall through instead of silently dropping them.
        if USE_FLASH_ATTENTION3 and FA3_AVAILABLE and attn_bias is None and dropout_p == 0.0:
            out = flash_attn_func(q, k, v)
            # NOTE(review): depending on the installed flash_attn_interface version this returns
            # either (out, softmax_lse) or a bare tensor; indexing a bare tensor with [0] would
            # slice off the batch dimension, so accept both forms.
            return out[0] if isinstance(out, tuple) else out
        if XFORMERS_AVAILABLE:
            return xops.memory_efficient_attention(
                q, k, v,
                attn_bias=attn_bias,
                p=dropout_p,
                op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
            )
        # PyTorch fallback: SDPA expects (batch, heads, seq_len, head_dim).
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            attn_mask=attn_bias,
            dropout_p=dropout_p,
        )
        return out.transpose(1, 2)

    def forward(self, x, attn_bias=None):
        """
        Args:
            x: Input tensor of shape (batch, seq_len, dim)
            attn_bias: Optional attention bias/mask. Passed to xformers as ``attn_bias``
                or to SDPA as ``attn_mask``; when given, the FA3 path is skipped.
        Returns:
            Output tensor of shape (batch, seq_len, dim)
        """
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (rearrange(t, "b l (nh dh) -> b l nh dh", dh=self.head_dim) for t in (q, k, v))
        if self.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)
        x = self._attend(q, k, v, attn_bias)
        x = rearrange(x, "b l nh dh -> b l (nh dh)")
        return self.attn_fc_dropout(self.fc(x))
class SubsetAttention(nn.Module):
    """Multi-head attention whose queries or keys/values can be restricted to a slice of the sequence.

    ``subset_kv_size`` keeps all queries but uses keys/values ``x[:, subset_kv_size:]``;
    ``subset_q_size`` keeps all keys/values but only queries ``x[:, :subset_q_size]``.
    At most one of the two may be given per call.
    """

    def __init__(
        self,
        dim,
        head_dim,
        qkv_bias=False,
        attn_dropout=0.0,
        fc_bias=False,
        fc_dropout=0.0,
        use_qk_norm=False
    ):
        """
        Args:
            dim: Input dimension
            head_dim: Dimension of each attention head (must divide ``dim``)
            qkv_bias: Whether to use bias in QKV projection
            attn_dropout: Dropout probability for attention weights (training only)
            fc_bias: Whether to use bias in output projection
            fc_dropout: Dropout probability for output projection
            use_qk_norm: Whether to RMS-normalize queries and keys per head

        Attention runs through xformers ``memory_efficient_attention`` when it is
        installed and through PyTorch ``scaled_dot_product_attention`` otherwise.
        """
        super().__init__()
        assert dim % head_dim == 0, f"Token dimension {dim} should be divisible by head dimension {head_dim}"
        self.dim = dim
        self.head_dim = head_dim
        self.num_heads = dim // head_dim
        self.attn_dropout = attn_dropout
        self.use_qk_norm = use_qk_norm
        # Projections
        self.to_qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.fc = nn.Linear(dim, dim, bias=fc_bias)
        self.attn_fc_dropout = nn.Dropout(fc_dropout)
        # Optional Q-K normalization
        if self.use_qk_norm:
            self.q_norm = RMSNorm(head_dim)
            self.k_norm = RMSNorm(head_dim)

    def _attend(self, q, k, v):
        """Attention on (batch, seq, heads, head_dim) tensors; q and k/v lengths may differ."""
        dropout_p = self.attn_dropout if self.training else 0.0
        if XFORMERS_AVAILABLE:
            return xops.memory_efficient_attention(
                q, k, v,
                attn_bias=None,
                p=dropout_p,
                op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
            )
        # Fallback so the module works without xformers; SDPA wants (batch, heads, seq, head_dim).
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            dropout_p=dropout_p,
        )
        return out.transpose(1, 2)

    def forward(self, x, subset_kv_size=None, subset_q_size=None):
        """
        Args:
            x: Input tensor of shape (batch, seq_len, dim)
            subset_kv_size: If provided, keys/values are taken from index ``subset_kv_size``
                onward (inclusive); ignored when it is >= seq_len
            subset_q_size: If provided, only the first ``subset_q_size`` queries are
                computed; ignored when it is >= seq_len
        Returns:
            Tensor of shape (batch, subset_q_size, dim) when a query subset is applied,
            otherwise (batch, seq_len, dim)
        """
        # Only one subset parameter can be provided
        assert not (subset_kv_size is not None and subset_q_size is not None), \
            "Only one of subset_kv_size or subset_q_size can be provided"
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (rearrange(t, "b l (nh dh) -> b l nh dh", dh=self.head_dim) for t in (q, k, v))
        if self.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)
        # Narrow the query or key/value set; otherwise attend over all tokens.
        if subset_kv_size is not None and subset_kv_size < k.shape[1]:
            k = k[:, subset_kv_size:, :, :].contiguous()
            v = v[:, subset_kv_size:, :, :].contiguous()
        elif subset_q_size is not None and subset_q_size < q.shape[1]:
            q = q[:, :subset_q_size, :, :].contiguous()
        x = self._attend(q, k, v)
        x = rearrange(x, "b l nh dh -> b l (nh dh)")
        # Final projection
        return self.attn_fc_dropout(self.fc(x))
class QK_Norm_TransformerBlock(nn.Module):
    """
    Pre-normalization transformer block: x + Attn(LN(x)), then x + MLP(LN(x)).
    Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L95-L113
    """

    def __init__(
        self,
        dim,
        head_dim,
        ln_bias=False,
        attn_qkv_bias=False,
        attn_dropout=0.0,
        attn_fc_bias=False,
        attn_fc_dropout=0.0,
        mlp_ratio=4,
        mlp_bias=False,
        mlp_dropout=0.0,
        use_qk_norm=True,
    ):
        """
        Args:
            dim: Token dimension
            head_dim: Per-head dimension of the attention layer (must divide ``dim``)
            ln_bias: Whether the two LayerNorms have a bias
            attn_qkv_bias: Bias in the attention QKV projection
            attn_dropout: Dropout probability on attention weights
            attn_fc_bias: Bias in the attention output projection
            attn_fc_dropout: Dropout after the attention output projection
            mlp_ratio: MLP hidden width as a multiple of ``dim``
            mlp_bias: Bias in the MLP linear layers
            mlp_dropout: Dropout at the MLP output
            use_qk_norm: Whether attention normalizes queries and keys
        """
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, bias=ln_bias)
        self.attn = QK_Norm_SelfAttention(
            dim=dim,
            head_dim=head_dim,
            qkv_bias=attn_qkv_bias,
            fc_bias=attn_fc_bias,
            attn_dropout=attn_dropout,
            fc_dropout=attn_fc_dropout,
            use_qk_norm=use_qk_norm,
        )
        self.norm2 = nn.LayerNorm(dim, bias=ln_bias)
        self.mlp = MLP(
            dim=dim,
            mlp_ratio=mlp_ratio,
            bias=mlp_bias,
            dropout=mlp_dropout,
        )

    def forward(self, x, attn_bias=None):
        """
        Args:
            x: Token tensor of shape (batch, seq_len, dim)
            attn_bias: Optional attention bias/mask forwarded unchanged to the
                attention layer (default None: no bias)
        Returns:
            Tensor with the same shape as ``x``
        """
        x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
        x = x + self.mlp(self.norm2(x))
        return x