| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
|
|
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import warnings |
| from typing import Optional |
| import importlib.metadata |
| import logging |
| import math |
| from .bert_padding import pad_input, unpad_input_only, index_first_axis |
| from .configuration_bert import FlexBertConfig, maybe_add_padding |
| from .normalization import get_norm_layer |
| from .initialization import ModuleType, init_weights |
|
|
# Feature flags describing which optional attention kernels could be imported.
IMPL_USE_FLASH3 = False
IMPL_USE_FLASH2 = False


def _version_tuple(version_string):
    """Return the leading numeric release components of a version string as a tuple of ints.

    Parsing stops at the first component that is not purely numeric, and any local version
    suffix after ``+`` is ignored, e.g. ``"2.6.0.post1"`` -> ``(2, 6, 0)`` and
    ``"2.5.7+cu122"`` -> ``(2, 5, 7)``.
    """
    release = []
    for piece in version_string.split("+", 1)[0].split("."):
        digits = ""
        for char in piece:
            if char not in "0123456789":
                break
            digits += char
        if not digits:
            break
        release.append(int(digits))
        if digits != piece:
            # Something like "7rc1": keep the 7 and stop at the suffix.
            break
    return tuple(release)


# Flash Attention 3 exposes its kernels through a separate `flash_attn_interface` module.
try:
    from flash_attn_interface import flash_attn_varlen_func

    IMPL_USE_FLASH3 = True
except ImportError:
    pass

# Flash Attention 2: require at least 2.5.7.
try:
    from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func

    installed_version = importlib.metadata.version("flash_attn")
    # Compare numerically: a plain string comparison would rank "2.10.0" below "2.5.7".
    if _version_tuple(installed_version) < _version_tuple("2.5.7"):
        raise ImportError("newer version of flash_attn required (>= 2.5.7)")
    IMPL_USE_FLASH2 = True
except ImportError:
    pass

# Rotary embeddings depend on flash_attn's rotary layer; expose None placeholders when unavailable.
try:
    from flash_attn.layers.rotary import RotaryEmbedding
    from .rotary import UnpaddedRotaryEmbedding
except ImportError:
    RotaryEmbedding = None
    UnpaddedRotaryEmbedding = None

logger = logging.getLogger(__name__)
|
|
|
|
class BertAlibiUnpadSelfAttention(nn.Module):
    """Performs multi-headed self attention with ALiBi on a batch of unpadded sequences.

    When a suitable Flash Attention 2 build was imported (``IMPL_USE_FLASH2``), attention is
    computed by ``flash_attn_varlen_qkvpacked_func`` with the ALiBi slopes handed to the kernel.
    Otherwise ``forward`` enters the vanilla PyTorch branch, which is currently short-circuited
    by an unconditional assertion.

    See `forward` method for additional details.
    """

    def __init__(self, config):
        """Build the fused QKV projection and read attention settings from ``config``.

        Raises:
            ValueError: if ``hidden_size`` is not a multiple of ``num_attention_heads`` and the
                config has no ``embedding_size`` attribute.
        """
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.is_causal = config.causal_mask
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.p_dropout = config.attention_probs_dropout_prob
        self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
        self.deterministic_fa2 = getattr(config, "deterministic_fa2", False)

        if not IMPL_USE_FLASH2:
            warnings.warn(
                "Unable to import flash_attn; defaulting MosaicBERT attention implementation to "
                "vanilla PyTorch (this will reduce throughput when using this model)."
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        indices: torch.Tensor,
        attn_mask: torch.Tensor,
        bias: torch.Tensor,
        slopes: torch.Tensor,
    ) -> torch.Tensor:
        """Perform self-attention.

        There are two attention implementations: vanilla attention with ALiBi, and Flash Attention 2
        with ALiBi. The arguments are unpadded. The vanilla implementation pads its input with
        `pad_input` and re-unpads the result; the Flash Attention implementation works on the
        unpadded tensors directly.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_seqlen: int
            indices: (total_nnz,)
            attn_mask: (batch, max_seqlen)
            bias: (batch, heads, max_seqlen, max_seqlen)
            slopes: (heads) or (batch, heads)

        Returns:
            attention: (total_nnz, dim)

        Raises:
            AssertionError: if ``slopes`` has an unexpected shape, or if the vanilla PyTorch
                branch is reached (it is currently disabled).
        """
        bs, dim = hidden_states.shape
        qkv = self.Wqkv(hidden_states)

        if IMPL_USE_FLASH2:
            qkv = qkv.view(-1, 3, self.num_attention_heads, self.attention_head_size)
            assert 1 <= len(slopes.shape) <= 2, f"{slopes=}"
            assert slopes.shape[-1] == self.num_attention_heads, f"{slopes=}"

            # The kernel is only ever fed fp16/bf16 here: other dtypes take a round trip through bf16.
            orig_dtype = qkv.dtype
            needs_cast = orig_dtype not in (torch.float16, torch.bfloat16)
            if needs_cast:
                qkv = qkv.to(torch.bfloat16)

            attention = flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                dropout_p=self.p_dropout,
                deterministic=self.deterministic_fa2,
                alibi_slopes=slopes,
                causal=self.is_causal,
            )
            if needs_cast:
                attention = attention.to(orig_dtype)
        else:
            assert not self.is_causal, "causal mask not implemented here yet"
            # NOTE(review): this branch was switched off with an unconditional assertion; the guard
            # is preserved. The code below now calls the helpers that are actually imported from
            # `.bert_padding` (the module name `bert_padding` itself is never bound in this file).
            assert False, "the vanilla PyTorch attention path is currently disabled"
            qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)
            unpad_bs, *_ = qkv.shape
            qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attention_head_size)

            # Split the packed projection and arrange q/v as (b, h, s, d) and k as (b, h, d, s).
            q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)
            k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)
            v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3)
            attention_scores = torch.matmul(q, k) / math.sqrt(self.attention_head_size)
            attention_scores = attention_scores + bias
            attention_probs = nn.functional.softmax(attention_scores, dim=-1)
            attention_probs = self.dropout(attention_probs)
            attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3)

            attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)

        return attention.view(bs, dim)
|
|
|
|
| |
class BertSelfOutput(nn.Module):
    """Post-attention projection block: dense layer, dropout, then a residual norm.

    Modeled after the Hugging Face BERT
    :class:`~transformers.model.bert.modeling_bert.BertSelfOutput`. It is re-implemented
    here rather than imported so that Composer surgery algorithms that modify Hugging Face
    BERT modules do not affect Mosaic BERT's modules.
    """

    def __init__(self, config):
        """Create the projection, the norm layer chosen by ``get_norm_layer``, and the dropout."""
        super().__init__()
        hidden = config.hidden_size
        self.dense = nn.Linear(hidden, hidden)
        self.LayerNorm = get_norm_layer(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project ``hidden_states``, apply dropout, add ``input_tensor`` and normalize the sum."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
|
|
|
class BertAlibiUnpadAttention(nn.Module):
    """Chains attention, Dropout, and LayerNorm for Mosaic BERT."""

    def __init__(self, config):
        """Create the self-attention module and the output (dense + dropout + norm) module."""
        super().__init__()
        self.self = BertAlibiUnpadSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(
        self,
        input_tensor: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_s: int,
        subset_idx: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for scaled self-attention without padding.

        Arguments:
            input_tensor: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_s: int
            subset_idx: () set of indices whose values we care about at the end of the layer
                        (e.g., the masked tokens, if this is the final layer).
            indices: None or (total_nnz,)
            attn_mask: None or (batch, max_seqlen)
            bias: None or (batch, heads, max_seqlen, max_seqlen)
            slopes: None or (batch, heads) or (heads,)

        Raises:
            AssertionError: if exactly one of ``bias`` / ``slopes`` is None, and currently also
                unconditionally, because this forward pass is disabled by an explicit assertion.
        """
        assert (bias is None) == (slopes is None), f"{bias=}, {slopes=}"
        # NOTE(review): the original author disabled this forward pass with an unconditional
        # assertion; the guard is preserved. The code below now uses `index_first_axis` as it is
        # actually imported from `.bert_padding` (the name `bert_padding` is never bound here).
        assert False, "BertAlibiUnpadAttention.forward is currently disabled"
        self_output = self.self(input_tensor, cu_seqlens, max_s, indices, attn_mask, bias, slopes)
        if subset_idx is None:
            return self.output(self_output, input_tensor)
        # Only the requested rows are passed through the output block.
        return self.output(
            index_first_axis(self_output, subset_idx),
            index_first_axis(input_tensor, subset_idx),
        )
|
|
|
|
class FlexBertAttentionBase(nn.Module):
    """A FlexBERT attention base class for type hints.

    Stores the config and layer index; subclasses must implement `_init_weights` and `forward`.
    """

    def __init__(self, config: "FlexBertConfig", layer_id: Optional[int] = None):
        """Record ``config`` and ``layer_id`` on the module."""
        super().__init__()
        self.config = config
        self.layer_id = layer_id

    def _init_weights(self, reset_params: bool = False):
        """Initialize the layer's parameters; must be overridden by subclasses."""
        raise NotImplementedError("This is a base class and should not be used directly.")

    def forward(self, hidden_states: torch.Tensor, attn_mask: torch.Tensor, **kwargs) -> torch.Tensor:
        """Compute attention; must be overridden by subclasses."""
        raise NotImplementedError("This is a base class and should not be used directly.")

    def extra_repr(self) -> str:
        """Summarize whichever attention settings the subclass has defined.

        Only attributes that exist on the instance are listed, joined by ", ". A
        ``sliding_window`` of ``(-1, -1)`` is rendered as ``False``.
        """
        # Collect fields and join them so the output never starts with a stray ", "
        # when `num_attention_heads` is missing.
        fields = []
        if hasattr(self, "num_attention_heads"):
            fields.append(f"num_attention_heads={self.num_attention_heads}")
        if hasattr(self, "attn_head_size"):
            fields.append(f"attn_head_size={self.attn_head_size}")
        if hasattr(self, "sliding_window"):
            window = self.sliding_window if self.sliding_window != (-1, -1) else "False"
            fields.append(f"sliding_window={window}")
        if hasattr(self, "use_fa2"):
            fields.append(f"use_fa2={self.use_fa2}")
        if hasattr(self, "deterministic_fa2"):
            fields.append(f"deterministic_fa2={self.deterministic_fa2}")
        return ", ".join(fields)
|
|
|
|
class FlexBertUnpadAttention(FlexBertAttentionBase):
    """Performs multi-headed self attention on a batch of unpadded sequences.

    With ``use_fa2`` enabled and Flash Attention 2 importable, attention runs through
    ``flash_attn_varlen_qkvpacked_func``. Otherwise ``forward`` enters the PyTorch SDPA branch
    (pad, attend, unpad), which is currently short-circuited by an unconditional assertion.

    See `forward` method for additional detail.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        """Build the QKV/output projections and resolve the attention backend settings.

        Raises:
            ValueError: if ``hidden_size`` is not a multiple of ``num_attention_heads`` (and the
                config has no ``embedding_size``), if ``global_attn_every_n_layers`` is set
                without ``sliding_window``, or if a sliding window is requested on the SDPA path.
        """
        super().__init__(config=config, layer_id=layer_id)
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.is_causal = config.causal_mask
        self.num_attention_heads = config.num_attention_heads
        self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attn_head_size
        self.p_dropout = config.attention_probs_dropout_prob
        self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attn_qkv_bias)
        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
        self.out_drop = (
            nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
        )
        self.use_fa2 = config.use_fa2
        self.deterministic_fa2 = config.deterministic_fa2
        self.use_sdpa_attn_mask = config.use_sdpa_attn_mask

        # Every `global_attn_every_n_layers`-th layer attends globally, i.e. window (-1, -1);
        # the others use half of `sliding_window` on each side. Since -1 // 2 == -1, an unset
        # window also ends up as (-1, -1) when no global/local alternation is configured.
        if config.global_attn_every_n_layers > 0:
            if config.sliding_window == -1:
                raise ValueError("global_attn_every_n_layers` requires `sliding_window` to be set")
            if layer_id % config.global_attn_every_n_layers != 0:
                self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
            else:
                self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)

        # A stdlib `logging.Logger` has no `warn_once`; use it only when the host framework has
        # patched one in, otherwise fall back to `warnings.warn` so construction does not fail
        # with an AttributeError.
        warn_once = getattr(logger, "warn_once", warnings.warn)

        if not IMPL_USE_FLASH2 and self.use_fa2:
            warn_once(
                "Unable to import flash_attn; defaulting FlexBERT attention implementation to PyTorch's"
                " SDPA kernel. This requires padding and unpadding inputs, which will add some overhead."
            )
            self.use_fa2 = False
        if not self.use_fa2:
            if not self.use_sdpa_attn_mask:
                warn_once(
                    "SDPA attention is being used without an attention mask. Including padding in the "
                    " attention calculation may cause differences from the Flash Attention implementation."
                )
            else:
                warn_once(
                    "SDPA attention with an attention mask doesn't use the Flash Attention kernel and will"
                    " use more memory during the backward pass. Use the FA2 backend for linear memory scaling"
                    " with sequence length."
                )
            if self.sliding_window[0] > 0:
                raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")

    def _init_weights(self, reset_params: bool = False):
        """Initialize ``Wqkv`` as an input module and ``Wo`` as an output module via `init_weights`."""
        init_weights(
            self.config,
            self.Wqkv,
            layer_dim=self.config.hidden_size,
            layer_id=None,
            type_of_module=ModuleType.in_module,
        )
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.hidden_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        indices: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Perform self-attention.

        There are two attention implementations: PyTorch's SDPA attention and Flash Attention 2.
        The arguments are unpadded. The SDPA implementation pads its input with `pad_input` and
        re-unpads the result; the Flash Attention implementation works on the unpadded tensors.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_seqlen: int
            indices: (total_nnz,)
            attn_mask: (batch, max_seqlen)

        Returns:
            attention: (total_nnz, dim)

        Raises:
            AssertionError: if the SDPA branch is reached (it is currently disabled).
        """
        bs, dim = hidden_states.shape
        qkv = self.Wqkv(hidden_states)

        if self.use_fa2:
            qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)

            # The kernel is only ever fed fp16/bf16 here: other dtypes take a round trip through bf16.
            orig_dtype = qkv.dtype
            needs_cast = orig_dtype not in (torch.float16, torch.bfloat16)
            if needs_cast:
                qkv = qkv.to(torch.bfloat16)

            attn = flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                dropout_p=self.p_dropout,
                deterministic=self.deterministic_fa2,
                window_size=self.sliding_window,
                causal=self.is_causal,
            )
            if needs_cast:
                attn = attn.to(orig_dtype)
            attn = attn.view(bs, dim)
        else:
            assert not self.is_causal, "causal mask not implemented here yet"
            # NOTE(review): this branch was switched off with an unconditional assertion; the guard
            # is preserved. The code below now calls the helpers that are actually imported from
            # `.bert_padding` (the module name `bert_padding` itself is never bound in this file).
            assert False, "the PyTorch SDPA attention path is currently disabled"
            qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)
            unpad_bs, seqlen, _ = qkv.shape

            qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attn_head_size)
            # Swap the sequence and head axes, then split the packed projection into q, k, v.
            q, k, v = qkv.transpose(3, 1).unbind(dim=2)

            sdpa_mask = None
            if self.use_sdpa_attn_mask:
                sdpa_mask = attn_mask[:, None, None, :seqlen].to(torch.bool).expand(unpad_bs, 1, seqlen, seqlen)
            attn = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.p_dropout,
                attn_mask=sdpa_mask,
            )
            attn = attn.transpose(1, 2).view(unpad_bs, -1, dim)
            attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)

        return self.out_drop(self.Wo(attn))
|
|
|
|
class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
    """Computes the output of the multi-headed self parallel attention on a batch of unpadded sequences.

    This module has no QKV projection of its own: ``forward`` receives the already-projected,
    packed ``qkv`` tensor. With ``use_fa2`` enabled and Flash Attention 2 importable, attention
    runs through ``flash_attn_varlen_qkvpacked_func``. Otherwise ``forward`` enters the PyTorch
    SDPA branch, which is currently short-circuited by an unconditional assertion.

    See `forward` method for additional detail.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        """Build the output projection and resolve the attention backend settings.

        Raises:
            ValueError: if ``hidden_size`` is not a multiple of ``num_attention_heads`` (and the
                config has no ``embedding_size``), if ``global_attn_every_n_layers`` is set
                without ``sliding_window``, or if a sliding window is requested on the SDPA path.
        """
        super().__init__(config=config, layer_id=layer_id)
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.is_causal = config.causal_mask
        self.num_attention_heads = config.num_attention_heads
        self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
        self.hidden_size = config.hidden_size
        self.p_dropout = config.attention_probs_dropout_prob
        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
        self.out_drop = (
            nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
        )
        self.use_fa2 = config.use_fa2
        self.deterministic_fa2 = config.deterministic_fa2
        self.use_sdpa_attn_mask = config.use_sdpa_attn_mask

        # Every `global_attn_every_n_layers`-th layer attends globally, i.e. window (-1, -1);
        # the others use half of `sliding_window` on each side. Since -1 // 2 == -1, an unset
        # window also ends up as (-1, -1) when no global/local alternation is configured.
        if config.global_attn_every_n_layers > 0:
            if config.sliding_window == -1:
                raise ValueError("global_attn_every_n_layers` requires `sliding_window` to be set")
            if layer_id % config.global_attn_every_n_layers != 0:
                self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
            else:
                self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)

        # A stdlib `logging.Logger` has no `warn_once`; use it only when the host framework has
        # patched one in, otherwise fall back to `warnings.warn` so construction does not fail
        # with an AttributeError.
        warn_once = getattr(logger, "warn_once", warnings.warn)

        if not IMPL_USE_FLASH2 and self.use_fa2:
            warn_once(
                "Unable to import flash_attn; defaulting FlexBERT attention implementation to PyTorch's"
                " SDPA kernel. This requires padding and unpadding inputs, which will add some overhead."
            )
            self.use_fa2 = False
        if not self.use_fa2:
            if not self.use_sdpa_attn_mask:
                warn_once(
                    "SDPA attention is being used without an attention mask. Including padding in the "
                    " attention calculation may cause differences from the Flash Attention implementation."
                )
            else:
                warn_once(
                    "SDPA attention with an attention mask doesn't use the Flash Attention kernel and will"
                    " use more memory during the backward pass. Use the FA2 backend for linear memory scaling"
                    " with sequence length."
                )
            if self.sliding_window[0] > 0:
                raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")

    def _init_weights(self, reset_params: bool = False):
        """Initialize ``Wo`` as an output module via `init_weights`."""
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.hidden_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(
        self,
        qkv: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        indices: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Perform self-attention.

        There are two attention implementations: PyTorch's SDPA attention and Flash Attention 2.
        The arguments are unpadded. The SDPA implementation pads its input with `pad_input` and
        re-unpads the result; the Flash Attention implementation works on the unpadded tensors.

        Args:
            qkv: (total_nnz, 3 * dim)
            cu_seqlens: (batch + 1,)
            max_seqlen: int
            indices: (total_nnz,)
            attn_mask: (batch, max_seqlen)

        Returns:
            attention: (total_nnz, dim)

        Raises:
            AssertionError: if the SDPA branch is reached (it is currently disabled).
        """
        bs = qkv.shape[0]
        dim = self.hidden_size

        if self.use_fa2:
            qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)

            # The kernel is only ever fed fp16/bf16 here: other dtypes take a round trip through bf16.
            orig_dtype = qkv.dtype
            needs_cast = orig_dtype not in (torch.float16, torch.bfloat16)
            if needs_cast:
                qkv = qkv.to(torch.bfloat16)

            attn = flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                dropout_p=self.p_dropout,
                deterministic=self.deterministic_fa2,
                window_size=self.sliding_window,
                causal=self.is_causal,
            )
            if needs_cast:
                attn = attn.to(orig_dtype)
            attn = attn.view(bs, dim)
        else:
            assert not self.is_causal, "causal mask not implemented here yet"
            # NOTE(review): this branch was switched off with an unconditional assertion; the guard
            # is preserved. The code below now calls the helpers that are actually imported from
            # `.bert_padding` (the module name `bert_padding` itself is never bound in this file).
            assert False, "the PyTorch SDPA attention path is currently disabled"
            qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)
            unpad_bs, seqlen, _ = qkv.shape

            qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attn_head_size)
            # Swap the sequence and head axes, then split the packed projection into q, k, v.
            q, k, v = qkv.transpose(3, 1).unbind(dim=2)

            sdpa_mask = None
            if self.use_sdpa_attn_mask:
                sdpa_mask = attn_mask[:, None, None, :seqlen].to(torch.bool).expand(unpad_bs, 1, seqlen, seqlen)
            attn = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.p_dropout,
                attn_mask=sdpa_mask,
            )
            attn = attn.transpose(1, 2).view(unpad_bs, -1, dim)
            attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)

        return self.out_drop(self.Wo(attn.view(bs, dim)))
|
|
|
|
class FlexBertPaddedAttention(FlexBertAttentionBase):
    """Performs multi-headed self attention on a batch of padded sequences.

    This module has two attention branches:
    1. Flash Attention 2 (``flash_attn_qkvpacked_func``), used when ``use_fa2`` is set and the
       library could be imported.
    2. PyTorch's scaled_dot_product_attention, which is currently short-circuited by an
       unconditional assertion.

    See `forward` method for additional detail.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        """Build the QKV/output projections and resolve the attention backend settings.

        Raises:
            ValueError: if ``hidden_size`` is not a multiple of ``num_attention_heads`` (and the
                config has no ``embedding_size``), if ``global_attn_every_n_layers`` is set
                without ``sliding_window``, or if a sliding window is requested on the SDPA path.
        """
        super().__init__(config=config, layer_id=layer_id)
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.is_causal = config.causal_mask
        self.num_attention_heads = config.num_attention_heads
        self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attn_head_size
        self.p_dropout = config.attention_probs_dropout_prob
        self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attn_qkv_bias)
        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
        self.out_drop = (
            nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
        )
        self.use_fa2 = config.use_fa2
        self.deterministic_fa2 = config.deterministic_fa2
        self.use_sdpa_attn_mask = config.use_sdpa_attn_mask

        # Every `global_attn_every_n_layers`-th layer attends globally, i.e. window (-1, -1);
        # the others use half of `sliding_window` on each side. Since -1 // 2 == -1, an unset
        # window also ends up as (-1, -1) when no global/local alternation is configured.
        if config.global_attn_every_n_layers > 0:
            if config.sliding_window == -1:
                raise ValueError("global_attn_every_n_layers` requires `sliding_window` to be set")
            if layer_id % config.global_attn_every_n_layers != 0:
                self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
            else:
                self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)

        # Silently drop back to SDPA when FA2 was requested but could not be imported.
        if not IMPL_USE_FLASH2 and self.use_fa2:
            self.use_fa2 = False
        if self.use_fa2 and self.use_sdpa_attn_mask:
            # A stdlib `logging.Logger` has no `warn_once`; use it only when the host framework
            # has patched one in, otherwise fall back to `warnings.warn` so construction does
            # not fail with an AttributeError.
            warn_once = getattr(logger, "warn_once", warnings.warn)
            warn_once(
                "Flash Attention 2 does not support attention masks. Use unpadded attention "
                "the equivalent functionality of masking out padding tokens."
            )
        if not self.use_fa2 and self.sliding_window[0] > 0:
            raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")

    def _init_weights(self, reset_params: bool = False):
        """Initialize ``Wqkv`` as an input module and ``Wo`` as an output module via `init_weights`."""
        init_weights(
            self.config,
            self.Wqkv,
            layer_dim=self.config.hidden_size,
            layer_id=None,
            type_of_module=ModuleType.in_module,
        )
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.hidden_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Perform self-attention.

        There are two attention implementations:
        Flash Attention 2 and PyTorch's scaled_dot_product_attention.

        Args:
            hidden_states: (batch, seqlen, dim)
            attn_mask: (batch, seqlen)

        Returns:
            attention: (batch, seqlen, dim)

        Raises:
            AssertionError: if the SDPA branch is reached (it is currently disabled).
        """
        bs, seqlen, dim = hidden_states.shape
        qkv = self.Wqkv(hidden_states)

        if self.use_fa2:
            qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)

            # The kernel is only ever fed fp16/bf16 here: other dtypes take a round trip through bf16.
            orig_dtype = qkv.dtype
            needs_cast = orig_dtype not in (torch.float16, torch.bfloat16)
            if needs_cast:
                qkv = qkv.to(torch.bfloat16)

            # Note that `attn_mask` is not passed to the kernel on this path.
            attn = flash_attn_qkvpacked_func(
                qkv,
                dropout_p=self.p_dropout,
                deterministic=self.deterministic_fa2,
                window_size=self.sliding_window,
                causal=self.is_causal,
            )
            if needs_cast:
                attn = attn.to(orig_dtype)
        else:
            assert not self.is_causal, "causal mask not implemented here yet"
            # NOTE(review): this branch was switched off with an unconditional assertion; the
            # guard is preserved as-is.
            assert False, "the PyTorch SDPA attention path is currently disabled"
            qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)

            # Swap the sequence and head axes, then split the packed projection into q, k, v.
            q, k, v = qkv.transpose(3, 1).unbind(dim=2)

            sdpa_mask = None
            if self.use_sdpa_attn_mask:
                sdpa_mask = attn_mask[:, None, None, :seqlen].to(torch.bool).expand(bs, 1, seqlen, seqlen)
            attn = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.p_dropout,
                attn_mask=sdpa_mask,
            ).transpose(1, 2)

        attn = attn.view(bs, seqlen, dim)
        return self.out_drop(self.Wo(attn))
|
|
|
|
class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
    """Performs multi-headed self attention with rotary embeddings on a batch of unpadded sequences.

    Three branches exist in ``forward``: Flash Attention 3 (``flash_attn_varlen_func``, only for
    layers without a sliding window), Flash Attention 2 (``flash_attn_varlen_qkvpacked_func``),
    and a PyTorch SDPA branch that is currently short-circuited by an unconditional assertion.

    See `forward` method for additional details.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        """Build the projections and rotary embedding and resolve the attention backend settings.

        Note that ``config.rotary_emb_dim`` is filled in with the head size when it is None.

        Raises:
            ValueError: if ``hidden_size`` is not a multiple of ``num_attention_heads`` (and the
                config has no ``embedding_size``), if ``global_attn_every_n_layers`` is set
                without ``sliding_window``, or if a sliding window is requested on the SDPA path.
            AssertionError: if ``UnpaddedRotaryEmbedding`` could not be imported.
        """
        super().__init__(config=config, layer_id=layer_id)
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.is_causal = config.causal_mask
        self.num_attention_heads = config.num_attention_heads
        self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attn_head_size
        self.p_dropout = config.attention_probs_dropout_prob
        self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attn_qkv_bias)
        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
        self.out_drop = (
            nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
        )

        # Every `global_attn_every_n_layers`-th layer attends globally, i.e. window (-1, -1);
        # the others use half of `sliding_window` on each side. Since -1 // 2 == -1, an unset
        # window also ends up as (-1, -1) when no global/local alternation is configured.
        if config.global_attn_every_n_layers > 0:
            if config.sliding_window == -1:
                raise ValueError("global_attn_every_n_layers` requires `sliding_window` to be set")
            if layer_id % config.global_attn_every_n_layers != 0:
                self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
            else:
                self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)

        if config.rotary_emb_dim is None:
            config.rotary_emb_dim = self.attn_head_size

        # Layers with a sliding window may override the rotary base and dimension.
        rotary_base = config.rotary_emb_base
        rotary_dim = config.rotary_emb_dim
        if self.sliding_window != (-1, -1):
            if config.local_attn_rotary_emb_base != -1:
                rotary_base = config.local_attn_rotary_emb_base
            if config.local_attn_rotary_emb_dim is not None:
                rotary_dim = config.local_attn_rotary_emb_dim

        assert UnpaddedRotaryEmbedding is not None, "rotary_emb is not installed"
        self.rotary_emb = UnpaddedRotaryEmbedding(
            dim=rotary_dim,
            base=rotary_base,
            scale_base=config.rotary_emb_scale_base,
            interleaved=config.rotary_emb_interleaved,
        )

        self.use_fa2 = config.use_fa2
        # FA3 is used only for layers without a sliding window, and only if it could be imported.
        self.use_fa3 = config.use_fa2 and self.sliding_window == (-1, -1) and IMPL_USE_FLASH3
        self.deterministic_fa2 = config.deterministic_fa2
        self.use_sdpa_attn_mask = config.use_sdpa_attn_mask

        # A stdlib `logging.Logger` has no `warn_once`; use it only when the host framework has
        # patched one in, otherwise fall back to `warnings.warn` so construction does not fail
        # with an AttributeError.
        warn_once = getattr(logger, "warn_once", warnings.warn)

        if not IMPL_USE_FLASH2 and self.use_fa2:
            warn_once(
                "Unable to import flash_attn; defaulting FlexBERT attention implementation to PyTorch's"
                " SDPA kernel. This requires padding and unpadding inputs, which will add some overhead."
            )
            self.use_fa2 = False
        if not self.use_fa2:
            if not self.use_sdpa_attn_mask:
                warn_once(
                    "SDPA attention is being used without an attention mask. Including padding in the "
                    " attention calculation may cause differences from the Flash Attention implementation."
                )
            else:
                warn_once(
                    "SDPA attention with an attention mask doesn't use the Flash Attention kernel and will"
                    " use more memory during the backward pass. Use the FA2 backend for linear memory scaling"
                    " with sequence length."
                )
            if self.sliding_window[0] > 0:
                raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")

    def _init_weights(self, reset_params: bool = False):
        """Initialize ``Wqkv`` as an input module and ``Wo`` as an output module via `init_weights`."""
        init_weights(
            self.config,
            self.Wqkv,
            layer_dim=self.config.hidden_size,
            layer_id=None,
            type_of_module=ModuleType.in_module,
        )
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.hidden_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        indices: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Perform self-attention.

        The rotary embedding is applied to the packed QKV projection first; attention then runs
        through Flash Attention 3, Flash Attention 2, or PyTorch's SDPA. The arguments are
        unpadded. The SDPA implementation pads its input with `pad_input` and re-unpads the
        result; the Flash Attention implementations work on the unpadded tensors.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_seqlen: int
            indices: (total_nnz,)
            attn_mask: (batch, max_seqlen)

        Returns:
            attention: (total_nnz, dim)

        Raises:
            AssertionError: if the SDPA branch is reached (it is currently disabled).
        """
        bs, dim = hidden_states.shape
        qkv = self.Wqkv(hidden_states)

        # Rotary positions always start at offset 0 here.
        seqlen_offset = 0

        qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
        qkv = self.rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, seqlen_offset=seqlen_offset)

        if self.use_fa3:
            # The kernel is only ever fed fp16/bf16 here: other dtypes take a round trip through bf16.
            orig_dtype = qkv.dtype
            needs_cast = orig_dtype not in (torch.float16, torch.bfloat16)
            if needs_cast:
                qkv = qkv.to(torch.bfloat16)
            q, k, v = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size).unbind(dim=1)

            # NOTE(review): no `dropout_p` is passed on this path, so `self.p_dropout` is not
            # applied when FA3 is used — confirm this is intended.
            attn, _ = flash_attn_varlen_func(
                q=q,
                k=k,
                v=v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                deterministic=self.deterministic_fa2,
                causal=self.is_causal,
            )
            if needs_cast:
                attn = attn.to(orig_dtype)
            attn = attn.view(bs, dim)
        elif self.use_fa2:
            # Same fp16/bf16 round trip as above.
            orig_dtype = qkv.dtype
            needs_cast = orig_dtype not in (torch.float16, torch.bfloat16)
            if needs_cast:
                qkv = qkv.to(torch.bfloat16)

            attn = flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                dropout_p=self.p_dropout,
                deterministic=self.deterministic_fa2,
                window_size=self.sliding_window,
                causal=self.is_causal,
            )
            if needs_cast:
                attn = attn.to(orig_dtype)
            attn = attn.view(bs, dim)
        else:
            assert not self.is_causal, "causal mask not implemented here yet"
            # NOTE(review): this branch was switched off with an unconditional assertion; the guard
            # is preserved. The code below now calls the helpers that are actually imported from
            # `.bert_padding` (the module name `bert_padding` itself is never bound in this file).
            assert False, "the PyTorch SDPA attention path is currently disabled"
            qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1])
            unpad_bs, seqlen, *_ = qkv.shape

            # Swap the sequence and head axes, then split the packed projection into q, k, v.
            q, k, v = qkv.transpose(3, 1).unbind(dim=2)

            sdpa_mask = None
            if self.use_sdpa_attn_mask:
                sdpa_mask = attn_mask[:, None, None, :seqlen].to(torch.bool).expand(unpad_bs, 1, seqlen, seqlen)
            attn = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.p_dropout,
                attn_mask=sdpa_mask,
            )
            attn = attn.transpose(1, 2).view(unpad_bs, -1, dim)
            attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)

        return self.out_drop(self.Wo(attn))
|
|
|
|
class FlexBertPaddedRopeAttention(FlexBertAttentionBase):
    """Multi-headed self-attention with rotary embeddings over a batch of padded sequences.

    The module owns the fused QKV projection (`Wqkv`), the rotary embedding and the output
    projection (`Wo`). `forward` contains two implementations:
    1. Flash Attention 2 (if installed), which improves throughput.
    2. PyTorch's scaled_dot_product_attention. This path is currently guarded by an
       unconditional assertion, so in practice it raises instead of running.

    See `forward` method for additional details.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        """Build the projections and rotary embedding and resolve this layer's attention window.

        Args:
            config: model configuration. If `config.rotary_emb_dim` is None it is set (on the
                config object itself) to the attention head size.
            layer_id: index of this layer. It is forwarded to the base class and used to pick
                local vs. global attention when `config.global_attn_every_n_layers > 0`.

        Raises:
            ValueError: if the hidden size is not a multiple of the number of heads (and the
                config has no `embedding_size`), if `global_attn_every_n_layers` is set without
                a sliding window, or if a sliding window is requested without Flash Attention 2.
        """
        super().__init__(config=config, layer_id=layer_id)
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.is_causal = config.causal_mask
        self.num_attention_heads = config.num_attention_heads
        self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attn_head_size
        self.p_dropout = config.attention_probs_dropout_prob
        self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attn_qkv_bias)
        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
        self.out_drop = (
            nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
        )

        self.use_fa2 = config.use_fa2
        self.deterministic_fa2 = config.deterministic_fa2
        self.use_sdpa_attn_mask = config.use_sdpa_attn_mask

        # (-1, -1) means global attention; otherwise the window is split evenly left/right.
        half_window = config.sliding_window // 2
        if config.global_attn_every_n_layers > 0:
            if config.sliding_window == -1:
                raise ValueError("global_attn_every_n_layers` requires `sliding_window` to be set")
            if layer_id % config.global_attn_every_n_layers != 0:
                self.sliding_window = (half_window, half_window)
            else:
                self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (half_window, half_window)

        if config.rotary_emb_dim is None:
            config.rotary_emb_dim = self.attn_head_size

        # Layers with a local window may use their own rotary base / dim.
        rotary_base = config.rotary_emb_base
        rotary_dim = config.rotary_emb_dim
        if self.sliding_window != (-1, -1):
            if config.local_attn_rotary_emb_base != -1:
                rotary_base = config.local_attn_rotary_emb_base
            if config.local_attn_rotary_emb_dim is not None:
                rotary_dim = config.local_attn_rotary_emb_dim

        assert RotaryEmbedding is not None, "rotary_emb is not installed"
        self.rotary_emb = RotaryEmbedding(
            dim=rotary_dim,
            base=rotary_base,
            scale_base=config.rotary_emb_scale_base,
            interleaved=config.rotary_emb_interleaved,
        )

        if not IMPL_USE_FLASH2 and self.use_fa2:
            self.use_fa2 = False
        if self.use_fa2 and self.use_sdpa_attn_mask:
            # The stdlib `logging.Logger` has no `warn_once`; use it only if the logger
            # provides one, otherwise fall back to a regular warning instead of crashing.
            warn_once = getattr(logger, "warn_once", logger.warning)
            warn_once(
                "Flash Attention 2 does not support attention masks. Use unpadded attention "
                "the equivalent functionality of masking out padding tokens."
            )
        if not self.use_fa2 and self.sliding_window[0] > 0:
            raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")

    def _init_weights(self, reset_params: bool = False):
        """Initialize `Wqkv` as an input module and `Wo` as an output module via `init_weights`."""
        init_weights(
            self.config,
            self.Wqkv,
            layer_dim=self.config.hidden_size,
            layer_id=None,
            type_of_module=ModuleType.in_module,
        )
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.hidden_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Perform self-attention.

        There are two attention implementations in this method: Flash Attention 2 and
        PyTorch's scaled_dot_product_attention (the latter currently disabled by an assertion).
        On the flash path, QKV that is not fp16/bf16 is cast to bf16 for the kernel and the
        output is cast back. Attention dropout is only applied in training mode.

        Args:
            hidden_states: (batch, seqlen, dim)
            attn_mask: (batch, seqlen)

        Returns:
            attention: (batch, seqlen, dim)

        Raises:
            AssertionError: if the Flash Attention 2 path is not taken (SDPA fallback disabled).
        """
        bs, seqlen, dim = hidden_states.shape
        qkv = self.Wqkv(hidden_states)

        # No positional offset is applied to the rotary embedding here.
        seqlen_offset = 0

        # (batch, seqlen, 3, nheads, headdim)
        qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)

        # The attention kernels apply dropout whenever dropout_p > 0, so disable it in eval mode.
        dropout_p = self.p_dropout if self.training else 0.0

        # NOTE(review): this gates on the import flag rather than `self.use_fa2` (which the
        # sibling classes use), so `config.use_fa2=False` still takes the flash path when
        # flash_attn is installed -- confirm whether that is intended.
        if IMPL_USE_FLASH2:
            qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)

            # Inputs outside fp16/bf16 are run through the kernel in bf16 and cast back afterwards.
            orig_dtype = qkv.dtype
            needs_cast = orig_dtype not in (torch.float16, torch.bfloat16)
            if needs_cast:
                qkv = qkv.to(torch.bfloat16)

            attn = flash_attn_qkvpacked_func(
                qkv,
                dropout_p=dropout_p,
                deterministic=self.deterministic_fa2,
                window_size=self.sliding_window,
                causal=self.is_causal,
            )
            if needs_cast:
                attn = attn.to(orig_dtype)
        else:
            assert not self.is_causal, "causal mask not implemented here yet"
            # The SDPA fallback is switched off by this unconditional assertion; the code below
            # only runs when assertions are stripped (python -O).
            assert False
            qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
            # (batch, nheads, seqlen, headdim) each
            q, k, v = qkv.transpose(3, 1).unbind(dim=2)
            sdpa_mask = None
            if self.use_sdpa_attn_mask:
                # Broadcast the (batch, seqlen) key mask to every query position.
                sdpa_mask = attn_mask[:, None, None, :seqlen].to(torch.bool).expand(bs, 1, seqlen, seqlen)
            attn = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, attn_mask=sdpa_mask).transpose(1, 2)

        attn = attn.view(bs, seqlen, dim)
        return self.out_drop(self.Wo(attn))
|
|
|
|
class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
    """Multi-headed self-attention with rotary embeddings over a batch of unpadded sequences.

    This variant receives an already-projected QKV tensor and owns only the rotary embedding
    and the output projection (`Wo`).

    If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
    Otherwise `forward` has a PyTorch SDPA path which requires padding and unpadding inputs;
    that path is currently guarded by an unconditional assertion, so in practice it raises.

    See `forward` method for additional details.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        """Build the rotary embedding and output projection and resolve this layer's window.

        Args:
            config: model configuration. If `config.rotary_emb_dim` is None it is set (on the
                config object itself) to the attention head size.
            layer_id: index of this layer. It is forwarded to the base class and used to pick
                local vs. global attention when `config.global_attn_every_n_layers > 0`.

        Raises:
            ValueError: if the hidden size is not a multiple of the number of heads (and the
                config has no `embedding_size`), if `global_attn_every_n_layers` is set without
                a sliding window, or if a sliding window is requested without Flash Attention 2.
        """
        super().__init__(config=config, layer_id=layer_id)
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.is_causal = config.causal_mask
        self.num_attention_heads = config.num_attention_heads
        self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
        self.hidden_size = config.hidden_size
        self.p_dropout = config.attention_probs_dropout_prob
        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
        self.out_drop = (
            nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
        )

        # (-1, -1) means global attention; otherwise the window is split evenly left/right.
        half_window = config.sliding_window // 2
        if config.global_attn_every_n_layers > 0:
            if config.sliding_window == -1:
                raise ValueError("global_attn_every_n_layers` requires `sliding_window` to be set")
            if layer_id % config.global_attn_every_n_layers != 0:
                self.sliding_window = (half_window, half_window)
            else:
                self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (half_window, half_window)

        if config.rotary_emb_dim is None:
            config.rotary_emb_dim = self.attn_head_size

        # Layers with a local window may use their own rotary base / dim.
        rotary_base = config.rotary_emb_base
        rotary_dim = config.rotary_emb_dim
        if self.sliding_window != (-1, -1):
            if config.local_attn_rotary_emb_base != -1:
                rotary_base = config.local_attn_rotary_emb_base
            if config.local_attn_rotary_emb_dim is not None:
                rotary_dim = config.local_attn_rotary_emb_dim

        assert UnpaddedRotaryEmbedding is not None, "rotary_emb is not installed"
        self.rotary_emb = UnpaddedRotaryEmbedding(
            dim=rotary_dim,
            base=rotary_base,
            scale_base=config.rotary_emb_scale_base,
            interleaved=config.rotary_emb_interleaved,
        )

        self.use_fa2 = config.use_fa2
        self.deterministic_fa2 = config.deterministic_fa2
        self.use_sdpa_attn_mask = config.use_sdpa_attn_mask

        # The stdlib `logging.Logger` has no `warn_once`; use it only if the logger provides
        # one, otherwise fall back to a regular warning instead of crashing.
        warn_once = getattr(logger, "warn_once", logger.warning)

        # Flash Attention support can be toggled with `config.use_fa2`.
        if not IMPL_USE_FLASH2 and self.use_fa2:
            warn_once(
                "Unable to import flash_attn; defaulting FlexBERT attention implementation to PyTorch's"
                " SDPA kernel. This requires padding and unpadding inputs, which will add some overhead."
            )
            self.use_fa2 = False
        if not self.use_fa2:
            if not self.use_sdpa_attn_mask:
                warn_once(
                    "SDPA attention is being used without an attention mask. Including padding in the "
                    " attention calculation may cause differences from the Flash Attention implementation."
                )
            else:
                warn_once(
                    "SDPA attention with an attention mask doesn't use the Flash Attention kernel and will"
                    " use more memory during the backward pass. Use the FA2 backend for linear memory scaling"
                    " with sequence length."
                )
            if self.sliding_window[0] > 0:
                raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")

    def _init_weights(self, reset_params: bool = False):
        """Initialize `Wo` as an output module via `init_weights`."""
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.hidden_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(
        self,
        qkv: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        indices: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Perform self-attention on an already-projected, unpadded QKV tensor.

        With `self.use_fa2` the packed QKV goes through `flash_attn_varlen_qkvpacked_func`;
        QKV that is not fp16/bf16 is cast to bf16 for the kernel and the output is cast back.
        Otherwise the SDPA fallback pads the input, runs `F.scaled_dot_product_attention` and
        re-unpads the result; this fallback is currently disabled by an unconditional assertion.
        Attention dropout is only applied in training mode.

        Args:
            qkv: (total_nnz, 3 * dim)
            cu_seqlens: (batch + 1,)
            max_seqlen: int
            indices: (total_nnz,)
            attn_mask: (batch, max_seqlen)

        Returns:
            attention: (total_nnz, dim)

        Raises:
            AssertionError: if Flash Attention 2 is not in use (SDPA fallback disabled).
        """
        bs = qkv.shape[0]
        dim = self.hidden_size

        # No positional offset is applied to the rotary embedding here.
        seqlen_offset = 0

        # (total_nnz, 3, nheads, headdim)
        qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
        qkv = self.rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, seqlen_offset=seqlen_offset)

        # The attention kernels apply dropout whenever dropout_p > 0, so disable it in eval mode.
        dropout_p = self.p_dropout if self.training else 0.0

        if self.use_fa2:
            # Inputs outside fp16/bf16 are run through the kernel in bf16 and cast back afterwards.
            orig_dtype = qkv.dtype
            needs_cast = orig_dtype not in (torch.float16, torch.bfloat16)
            if needs_cast:
                qkv = qkv.to(torch.bfloat16)

            attn = flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                dropout_p=dropout_p,
                deterministic=self.deterministic_fa2,
                window_size=self.sliding_window,
                causal=self.is_causal,
            )
            if needs_cast:
                attn = attn.to(orig_dtype)
            attn = attn.view(bs, dim)
        else:
            assert not self.is_causal, "causal mask not implemented here yet"
            # The SDPA fallback is switched off by this unconditional assertion; the code below
            # only runs when assertions are stripped (python -O).
            assert False
            qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1])
            unpad_bs, seqlen, *_ = qkv.shape

            q, k, v = qkv.transpose(3, 1).unbind(dim=2)
            sdpa_mask = None
            if self.use_sdpa_attn_mask:
                # Broadcast the (batch, seqlen) key mask to every query position.
                sdpa_mask = attn_mask[:, None, None, :seqlen].to(torch.bool).expand(unpad_bs, 1, seqlen, seqlen)
            attn = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, attn_mask=sdpa_mask)
            attn = attn.transpose(1, 2).view(unpad_bs, -1, dim)
            attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)

        return self.out_drop(self.Wo(attn))
|
|
|
|
class FlexBertPaddedRopeParallelAttention(FlexBertAttentionBase):
    """Multi-headed self-attention with rotary embeddings over a batch of padded sequences.

    This variant receives an already-projected QKV tensor and owns only the rotary embedding
    and the output projection (`Wo`). `forward` contains two implementations:
    1. Flash Attention 2 (if installed), which improves throughput.
    2. PyTorch's scaled_dot_product_attention. This path is currently guarded by an
       unconditional assertion, so in practice it raises instead of running.

    See `forward` method for additional details.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        """Build the rotary embedding and output projection and resolve this layer's window.

        Args:
            config: model configuration. If `config.rotary_emb_dim` is None it is set (on the
                config object itself) to the attention head size.
            layer_id: index of this layer. It is forwarded to the base class and used to pick
                local vs. global attention when `config.global_attn_every_n_layers > 0`.

        Raises:
            ValueError: if the hidden size is not a multiple of the number of heads (and the
                config has no `embedding_size`), if `global_attn_every_n_layers` is set without
                a sliding window, or if a sliding window is requested without Flash Attention 2.
        """
        super().__init__(config=config, layer_id=layer_id)
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.is_causal = config.causal_mask
        self.num_attention_heads = config.num_attention_heads
        self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
        self.hidden_size = config.hidden_size
        self.p_dropout = config.attention_probs_dropout_prob
        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
        self.out_drop = (
            nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
        )

        self.deterministic_fa2 = config.deterministic_fa2
        self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
        # Flash Attention 2 is only used when requested AND importable.
        self.use_fa2 = config.use_fa2
        if not IMPL_USE_FLASH2 and self.use_fa2:
            self.use_fa2 = False

        # (-1, -1) means global attention; otherwise the window is split evenly left/right.
        half_window = config.sliding_window // 2
        if config.global_attn_every_n_layers > 0:
            if config.sliding_window == -1:
                raise ValueError("global_attn_every_n_layers` requires `sliding_window` to be set")
            if layer_id % config.global_attn_every_n_layers != 0:
                self.sliding_window = (half_window, half_window)
            else:
                self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (half_window, half_window)

        if config.rotary_emb_dim is None:
            config.rotary_emb_dim = self.attn_head_size

        # Layers with a local window may use their own rotary base / dim.
        rotary_base = config.rotary_emb_base
        rotary_dim = config.rotary_emb_dim
        if self.sliding_window != (-1, -1):
            if config.local_attn_rotary_emb_base != -1:
                rotary_base = config.local_attn_rotary_emb_base
            if config.local_attn_rotary_emb_dim is not None:
                rotary_dim = config.local_attn_rotary_emb_dim

        assert RotaryEmbedding is not None, "rotary_emb is not installed"
        self.rotary_emb = RotaryEmbedding(
            dim=rotary_dim,
            base=rotary_base,
            scale_base=config.rotary_emb_scale_base,
            interleaved=config.rotary_emb_interleaved,
        )

        if self.use_fa2 and self.use_sdpa_attn_mask:
            # The stdlib `logging.Logger` has no `warn_once`; use it only if the logger
            # provides one, otherwise fall back to a regular warning instead of crashing.
            warn_once = getattr(logger, "warn_once", logger.warning)
            warn_once(
                "Flash Attention 2 does not support attention masks. Use unpadded attention "
                "the equivalent functionality of masking out padding tokens."
            )
        if not self.use_fa2 and self.sliding_window[0] > 0:
            raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")

    def _init_weights(self, reset_params: bool = False):
        """Initialize `Wo` as an output module via `init_weights`."""
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.hidden_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(
        self,
        qkv: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Perform self-attention on an already-projected, padded QKV tensor.

        There are two attention implementations in this method: Flash Attention 2 and
        PyTorch's scaled_dot_product_attention (the latter currently disabled by an assertion).
        On the flash path, QKV that is not fp16/bf16 is cast to bf16 for the kernel and the
        output is cast back. Attention dropout is only applied in training mode.

        Args:
            qkv: (batch, seqlen, 3 * dim)
            attn_mask: (batch, seqlen)

        Returns:
            attention: (batch, seqlen, dim)

        Raises:
            AssertionError: if Flash Attention 2 is not in use (SDPA fallback disabled).
        """
        bs, seqlen, _ = qkv.shape
        dim = self.hidden_size

        # No positional offset is applied to the rotary embedding here.
        seqlen_offset = 0

        # (batch, seqlen, 3, nheads, headdim)
        qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)

        # The attention kernels apply dropout whenever dropout_p > 0, so disable it in eval mode.
        dropout_p = self.p_dropout if self.training else 0.0

        if self.use_fa2:
            qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)

            # Inputs outside fp16/bf16 are run through the kernel in bf16 and cast back afterwards.
            orig_dtype = qkv.dtype
            needs_cast = orig_dtype not in (torch.float16, torch.bfloat16)
            if needs_cast:
                qkv = qkv.to(torch.bfloat16)

            attn = flash_attn_qkvpacked_func(
                qkv,
                dropout_p=dropout_p,
                deterministic=self.deterministic_fa2,
                window_size=self.sliding_window,
                causal=self.is_causal,
            )
            if needs_cast:
                attn = attn.to(orig_dtype)
        else:
            assert not self.is_causal, "causal mask not implemented here yet"
            # The SDPA fallback is switched off by this unconditional assertion; the code below
            # only runs when assertions are stripped (python -O).
            assert False
            qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
            # (batch, nheads, seqlen, headdim) each
            q, k, v = qkv.transpose(3, 1).unbind(dim=2)
            sdpa_mask = None
            if self.use_sdpa_attn_mask:
                # Broadcast the (batch, seqlen) key mask to every query position.
                sdpa_mask = attn_mask[:, None, None, :seqlen].to(torch.bool).expand(bs, 1, seqlen, seqlen)
            attn = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, attn_mask=sdpa_mask).transpose(1, 2)

        attn = attn.view(bs, seqlen, dim)
        return self.out_drop(self.Wo(attn))
|
|
|
|
class FlexBertPaddedParallelAttention(FlexBertAttentionBase):
    """Multi-headed self-attention over a batch of padded sequences (no rotary embedding).

    This variant receives an already-projected QKV tensor and owns only the output
    projection (`Wo`). `forward` contains two implementations:
    1. Flash Attention 2 (if installed), which improves throughput.
    2. PyTorch's scaled_dot_product_attention. This path is currently guarded by an
       unconditional assertion, so in practice it raises instead of running.

    See `forward` method for additional detail.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        """Build the output projection and resolve this layer's attention window.

        Args:
            config: model configuration.
            layer_id: index of this layer. It is forwarded to the base class and used to pick
                local vs. global attention when `config.global_attn_every_n_layers > 0`.

        Raises:
            ValueError: if the hidden size is not a multiple of the number of heads (and the
                config has no `embedding_size`), if `global_attn_every_n_layers` is set without
                a sliding window, or if a sliding window is requested without Flash Attention 2.
        """
        super().__init__(config=config, layer_id=layer_id)
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.is_causal = config.causal_mask
        self.num_attention_heads = config.num_attention_heads
        self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
        self.hidden_size = config.hidden_size
        self.p_dropout = config.attention_probs_dropout_prob
        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
        self.out_drop = (
            nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
        )
        self.use_fa2 = config.use_fa2
        self.deterministic_fa2 = config.deterministic_fa2
        self.use_sdpa_attn_mask = config.use_sdpa_attn_mask

        # (-1, -1) means global attention; otherwise the window is split evenly left/right.
        half_window = config.sliding_window // 2
        if config.global_attn_every_n_layers > 0:
            if config.sliding_window == -1:
                raise ValueError("global_attn_every_n_layers` requires `sliding_window` to be set")
            if layer_id % config.global_attn_every_n_layers != 0:
                self.sliding_window = (half_window, half_window)
            else:
                self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (half_window, half_window)

        # Flash Attention 2 is only used when requested AND importable.
        if not IMPL_USE_FLASH2 and self.use_fa2:
            self.use_fa2 = False
        if self.use_fa2 and self.use_sdpa_attn_mask:
            # The stdlib `logging.Logger` has no `warn_once`; use it only if the logger
            # provides one, otherwise fall back to a regular warning instead of crashing.
            warn_once = getattr(logger, "warn_once", logger.warning)
            warn_once(
                "Flash Attention 2 does not support attention masks. Use unpadded attention "
                "the equivalent functionality of masking out padding tokens."
            )
        if not self.use_fa2 and self.sliding_window[0] > 0:
            raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")

    def _init_weights(self, reset_params: bool = False):
        """Initialize `Wo` as an output module via `init_weights`."""
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.hidden_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(
        self,
        qkv: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Perform self-attention on an already-projected, padded QKV tensor.

        There are two attention implementations in this method: Flash Attention 2 and
        PyTorch's scaled_dot_product_attention (the latter currently disabled by an assertion).
        On the flash path, QKV that is not fp16/bf16 is cast to bf16 for the kernel and the
        output is cast back. Attention dropout is only applied in training mode.

        Args:
            qkv: (batch, seqlen, 3 * dim)
            attn_mask: (batch, seqlen)

        Returns:
            attention: (batch, seqlen, dim)

        Raises:
            AssertionError: if Flash Attention 2 is not in use (SDPA fallback disabled).
        """
        bs, seqlen, _ = qkv.shape
        dim = self.hidden_size

        # The attention kernels apply dropout whenever dropout_p > 0, so disable it in eval mode.
        dropout_p = self.p_dropout if self.training else 0.0

        if self.use_fa2:
            # (batch, seqlen, 3, nheads, headdim)
            qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)

            # Inputs outside fp16/bf16 are run through the kernel in bf16 and cast back afterwards.
            orig_dtype = qkv.dtype
            needs_cast = orig_dtype not in (torch.float16, torch.bfloat16)
            if needs_cast:
                qkv = qkv.to(torch.bfloat16)

            attn = flash_attn_qkvpacked_func(
                qkv,
                dropout_p=dropout_p,
                deterministic=self.deterministic_fa2,
                window_size=self.sliding_window,
                causal=self.is_causal,
            )
            if needs_cast:
                attn = attn.to(orig_dtype)
        else:
            assert not self.is_causal, "causal attention mask not yet implemented here"
            # The SDPA fallback is switched off by this unconditional assertion; the code below
            # only runs when assertions are stripped (python -O).
            assert False
            qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
            # (batch, nheads, seqlen, headdim) each
            q, k, v = qkv.transpose(3, 1).unbind(dim=2)
            sdpa_mask = None
            if self.use_sdpa_attn_mask:
                # Broadcast the (batch, seqlen) key mask to every query position.
                sdpa_mask = attn_mask[:, None, None, :seqlen].to(torch.bool).expand(bs, 1, seqlen, seqlen)
            attn = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, attn_mask=sdpa_mask).transpose(1, 2)

        attn = attn.view(bs, seqlen, dim)
        return self.out_drop(self.Wo(attn))
|
|
|
|
# Registry mapping an attention-layer name (as looked up by `get_attention_layer`) to the
# class implementing it. Entry order is kept stable because the keys are printed in error
# messages.
ATTN2CLS = {
    # QKV projected inside the attention module
    "unpadded_base": FlexBertUnpadAttention,
    "padded_base": FlexBertPaddedAttention,
    # QKV passed in already projected ("parallel" variants)
    "unpadded_parallel": FlexBertUnpadParallelAttention,
    "padded_parallel": FlexBertPaddedParallelAttention,
    # rotary-embedding variants
    "unpadded_rope": FlexBertUnpadRopeAttention,
    "padded_rope": FlexBertPaddedRopeAttention,
    "unpadded_rope_parallel": FlexBertUnpadRopeParallelAttention,
    "padded_rope_parallel": FlexBertPaddedRopeParallelAttention,
}
|
|
|
|
def get_attention_layer(config: FlexBertConfig, layer_id: Optional[int] = None) -> FlexBertAttentionBase:
    """Instantiate the attention module configured for the given layer.

    The layer type is `config.initial_attention_layer` when that attribute is set and
    `layer_id` is below `config.num_initial_layers`; otherwise it is `config.attention_layer`.
    The name is passed through `maybe_add_padding` and looked up in `ATTN2CLS`.

    Args:
        config: model configuration.
        layer_id: index of the layer being built. When None, the layer is never treated as an
            initial layer.

    Returns:
        The attention module, constructed as `cls(config, layer_id=layer_id)`.

    Raises:
        ValueError: if the resolved layer name is not a key of `ATTN2CLS`.
    """
    # Check `layer_id` and the optional attribute first so that the default `layer_id=None`
    # and configs without initial-layer settings do not fail on the comparison.
    use_initial = (
        layer_id is not None
        and getattr(config, "initial_attention_layer", None) is not None
        and layer_id < config.num_initial_layers
    )
    attention_layer = config.initial_attention_layer if use_initial else config.attention_layer

    # Only the registry lookup is guarded: a KeyError raised while constructing the layer
    # must not be reported as an unknown layer type.
    try:
        attention_cls = ATTN2CLS[maybe_add_padding(config, attention_layer)]
    except KeyError as err:
        if use_initial:
            raise ValueError(
                f"Invalid attention layer type: {config.initial_attention_layer=}, must be one of {ATTN2CLS.keys()}. "
                f"{config.padding=} will be automatically prepended to `config.attention_layer` if unspecified."
            ) from err
        raise ValueError(
            f"Invalid attention layer type: {config.attention_layer=}, must be one of {ATTN2CLS.keys()}. "
            f"{config.padding=} will be automatically prepended to `config.attention_layer` if unspecified."
        ) from err

    return attention_cls(config, layer_id=layer_id)
|
|