| |
| |
| |
| |
| @@ -52,6 +52,7 @@ line-ending = "auto" |
| addopts = "--doctest-glob='**/*.md'" |
| doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS" |
| markers = [ |
| + "flash_attn_3_test: marks tests related to flash attention 3 (deselect with '-m \"not flash_attn_3_test\"')", |
| "flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')", |
| "bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests", |
| "generate: marks tests that use the GenerationTesterMixin" |
| |
| |
| |
| |
| @@ -75,6 +75,7 @@ def flash_attention_forward( |
| softcap=softcap, |
| use_top_left_mask=_use_top_left_mask, |
| target_dtype=target_dtype, |
| + attn_implementation=module.config._attn_implementation, |
| **kwargs, |
| ) |
| |
| |
| |
| |
| |
| @@ -14,6 +14,7 @@ |
| |
| import inspect |
| import os |
| +import warnings |
| from typing import Optional, TypedDict |
| |
| import torch |
| @@ -21,6 +22,7 @@ |
| |
| from .utils import ( |
| is_flash_attn_2_available, |
| + is_flash_attn_3_available, |
| is_flash_attn_greater_or_equal, |
| is_flash_attn_greater_or_equal_2_10, |
| is_torch_npu_available, |
| @@ -32,18 +34,123 @@ |
| flash_attn_func = None |
| |
| |
| -if is_flash_attn_2_available(): |
| - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa |
| - from flash_attn import flash_attn_func, flash_attn_varlen_func |
| - from flash_attn.layers.rotary import apply_rotary_emb # noqa |
| +def _index_first_axis(tensor, indices): |
| + """ |
| + Flatten the first two dimensions of `tensor` and select rows of the result with `indices`, |
| + i.e. return `tensor.reshape(-1, *tensor.shape[2:])[indices]`. Intended as a local stand-in for |
| + FA2's `index_first_axis` (the flattening is done here), so that helper no longer has to be imported. |
| + """ |
| + # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first |
| + # two dimensions to get (total_tokens, ...) before indexing. |
| + reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) |
| + return reshaped_tensor[indices] |
| + |
| + |
| +def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None): |
| + """ |
| + FA3-compatible unpad_input function. |
| |
| + Arguments: |
| + hidden_states: (batch, seqlen, ...) |
| + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. |
| + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. |
| + Return: |
| + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. |
| + indices: (total_nnz), the positions of the selected (non-zero mask) tokens in the flattened mask. |
| + cu_seqlens: (batch + 1), the cumulative sequence lengths with a leading 0, used to index into hidden_states. |
| + max_seqlen_in_batch: int, the largest per-sequence count of tokens selected in attention_mask + unused_mask. |
| + seqused: (batch), the number of tokens selected in attention_mask alone (unused_mask is not counted). |
| + """ |
| + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask |
| + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) |
| + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
| + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() |
| + max_seqlen_in_batch = seqlens_in_batch.max().item() |
| + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) |
| + |
| + return ( |
| + _index_first_axis(hidden_states, indices), |
| + indices, |
| + cu_seqlens, |
| + max_seqlen_in_batch, |
| + used_seqlens_in_batch, |
| + ) |
| + |
| + |
| +def _fa3_pad_input(hidden_states, indices, batch, seqlen): |
| + """ |
| + FA3-compatible pad_input function. |
| + |
| + Arguments: |
| + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. |
| + indices: (total_nnz), the positions of those tokens in the flattened (batch * seqlen) padded input sequence. |
| + batch: int, batch size for the padded sequence. |
| + seqlen: int, maximum sequence length for the padded sequence. |
| + Return: |
| + hidden_states: (batch, seqlen, ...), zero-filled at every position not listed in `indices`. |
| + """ |
| + dim = hidden_states.shape[1:] |
| + output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) |
| + output[indices] = hidden_states |
| + return output.view(batch, seqlen, *dim) |
| + |
| + |
| +FA_VERSION = None  # set to 2 and/or 3 below as each version is detected; stays None if neither is available |
| +if is_flash_attn_2_available(): |
| + from flash_attn import flash_attn_func as flash_attn_2_func |
| + from flash_attn import flash_attn_varlen_func as flash_attn_2_varlen_func |
| + from flash_attn.bert_padding import pad_input as pad_input_fa2 |
| + from flash_attn.bert_padding import unpad_input as unpad_input_fa2 |
| + from flash_attn.layers.rotary import apply_rotary_emb |
| + |
| + HAS_FA2 = True |
| + FA_VERSION = 2 |
| +else: |
| + flash_attn_2_func = None |
| + flash_attn_2_varlen_func = None |
| + pad_input_fa2 = None |
| + unpad_input_fa2 = None |
| + apply_rotary_emb = None |
| + HAS_FA2 = False |
| + |
| +if is_flash_attn_3_available(): |
| + from flash_attn_interface import flash_attn_func as flash_attn_3_func |
| + from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func |
| + |
| + pad_input_fa3 = _fa3_pad_input |
| + unpad_input_fa3 = _fa3_unpad_input |
| + HAS_FA3 = True |
| + FA_VERSION = 3  # overrides the 2 assigned above, so FA3 is the default when both versions are available |
| +else: |
| + flash_attn_3_func = None |
| + flash_attn_3_varlen_func = None |
| + pad_input_fa3 = None |
| + unpad_input_fa3 = None |
| + HAS_FA3 = False |
| + |
| + |
| +# Default Flash Attention entry points, looked up by name from the version-suffixed globals defined above |
| +if FA_VERSION: |
| + flash_attn_func = globals()[f"flash_attn_{FA_VERSION}_func"] |
| + flash_attn_varlen_func = globals()[f"flash_attn_{FA_VERSION}_varlen_func"] |
| + unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"] |
| + pad_input = globals()[f"pad_input_fa{FA_VERSION}"] |
| |
| # patch functions in package `flash-attn` when using flash-attention on Ascend NPU. |
| if is_torch_npu_available(): |
| - from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input |
| - from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb # noqa |
| - from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func |
| - from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func |
| + from .integrations.npu_flash_attention import ( |
| + npu_apply_rotary_emb as apply_rotary_emb, # noqa: F401 |
| + ) |
| + from .integrations.npu_flash_attention import ( |
| + npu_flash_attn_func as flash_attn_func, |
| + ) |
| + from .integrations.npu_flash_attention import ( |
| + npu_flash_attn_varlen_func as flash_attn_varlen_func, |
| + ) |
| + from .integrations.npu_flash_attention import ( |
| + pad_input, |
| + unpad_input, |
| + ) |
| |
| |
| _flash_supports_window_size = False |
| @@ -56,6 +163,9 @@ |
| def is_flash_attn_available(): |
| """Determine whether flash-attention can be used or not.""" |
| |
| + if is_flash_attn_3_available(): |
| + return True |
| + |
| # if package `flash-attn` is available, flash-attention can be used natively. |
| if is_flash_attn_2_available(): |
| return True |
| @@ -70,6 +180,9 @@ def is_flash_attn_available(): |
| def flash_attn_supports_top_left_mask(): |
| """Determine whether flash-attention uses top-left or down-right mask""" |
| |
| + if is_flash_attn_3_available(): |
| + return False |
| + |
| if is_flash_attn_2_available(): |
| # top-left mask is used in package `flash-attn` with version lower than 2.1.0 |
| return not is_flash_attn_greater_or_equal_2_10() |
| @@ -116,6 +229,7 @@ def _upad_input( |
| value_layer: torch.Tensor, |
| attention_mask: torch.Tensor, |
| query_length: int, |
| + unpad_input_func, |
| ): |
| """ |
| Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. |
| @@ -134,6 +248,8 @@ def _upad_input( |
| Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. |
| query_length (`int`): |
| Target length. |
| + unpad_input_func: |
| + The function to use for unpadding the input tensors. |
| |
| Return: |
| query_layer (`torch.Tensor`): |
| @@ -158,12 +274,10 @@ def _upad_input( |
| |
| batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape |
| |
| - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) |
| - value_layer = index_first_axis( |
| - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k |
| - ) |
| + key_layer = _index_first_axis(key_layer, indices_k) |
| + value_layer = _index_first_axis(value_layer, indices_k) |
| if query_length == kv_seq_len: |
| - query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) |
| + query_layer = _index_first_axis(query_layer, indices_k) |
| cu_seqlens_q = cu_seqlens_k |
| max_seqlen_in_batch_q = max_seqlen_in_batch_k |
| indices_q = indices_k |
| @@ -177,7 +291,7 @@ def _upad_input( |
| else: |
| # The -q_len: slice assumes left padding. |
| attention_mask = attention_mask[:, -query_length:] |
| - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input(query_layer, attention_mask) |
| + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) |
| |
| return ( |
| query_layer, |
| @@ -189,7 +303,7 @@ def _upad_input( |
| ) |
| |
| |
| -def prepare_fa2_from_position_ids(query, key, value, position_ids): |
| +def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): |
| """ |
| This function returns necessary arguments to call `flash_attn_varlen_func`. |
| All three query, key, value states will be flattened. |
| @@ -239,6 +353,14 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids): |
| return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) |
| |
| |
| +def prepare_fa2_from_position_ids(*args, **kwargs): |
| + """Deprecated alias: emits a `FutureWarning`, then forwards all arguments to `_prepare_flash_attention_from_position_ids`.""" |
| + warnings.warn( |
| + "The function `prepare_fa2_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_flash_attention_from_position_ids` instead.", |
| + FutureWarning, |
| + # Attribute the warning to the code that called this alias rather than to this wrapper. |
| + stacklevel=2, |
| + ) |
| + return _prepare_flash_attention_from_position_ids(*args, **kwargs) |
| + |
| + |
| def fa_peft_integration_check( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| @@ -303,6 +425,7 @@ def _flash_attention_forward( |
| max_length_q: Optional[int] = None, |
| max_length_k: Optional[int] = None, |
| target_dtype: Optional[torch.dtype] = None, |
| + attn_implementation: Optional[str] = None, |
| **kwargs, |
| ): |
| """ |
| @@ -329,7 +452,28 @@ def _flash_attention_forward( |
| Softcap for the attention logits, used e.g. in gemma2. |
| deterministic (`bool`, *optional*): |
| Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. |
| + attn_implementation (`str`, *optional*): |
| + The attention implementation to use. If None, will default to the one based on the environment. |
| """ |
| + if attn_implementation is None: |
| + _flash_attn_varlen_func = flash_attn_varlen_func |
| + _flash_attn_func = flash_attn_func |
| + _pad_input = pad_input |
| + _unpad_input = unpad_input |
| + _is_fa3 = HAS_FA3 |
| + elif attn_implementation == "flash_attention_3": |
| + _flash_attn_varlen_func = flash_attn_3_varlen_func |
| + _flash_attn_func = flash_attn_3_func |
| + _pad_input = pad_input_fa3 |
| + _unpad_input = unpad_input_fa3 |
| + _is_fa3 = True |
| + elif attn_implementation == "flash_attention_2": |
| + _flash_attn_varlen_func = flash_attn_2_varlen_func |
| + _flash_attn_func = flash_attn_2_func |
| + _pad_input = pad_input_fa2 |
| + _unpad_input = unpad_input_fa2 |
| + _is_fa3 = False |
| + |
| if not use_top_left_mask: |
| causal = is_causal |
| else: |
| @@ -342,6 +486,12 @@ def _flash_attention_forward( |
| ) |
| flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} |
| |
| + if _is_fa3: |
| + if dropout > 0.0: |
| + logger.warning_once("Flash Attention 3 does not support dropout. Setting dropout to 0.0.") |
| + else: |
| + flash_kwargs["dropout_p"] = dropout |
| + |
| if flash_241: |
| if deterministic is None: |
| global deterministic_g |
| @@ -362,12 +512,12 @@ def _flash_attention_forward( |
| if attention_mask is not None: |
| batch_size = query_states.shape[0] |
| query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input( |
| - query_states, key_states, value_states, attention_mask, query_length |
| + query_states, key_states, value_states, attention_mask, query_length, _unpad_input |
| ) |
| cu_seqlens_q, cu_seqlens_k = cu_seq_lens |
| max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens |
| |
| - attn_output_unpad = flash_attn_varlen_func( |
| + attn_output_unpad = _flash_attn_varlen_func( |
| query_states, |
| key_states, |
| value_states, |
| @@ -375,12 +525,11 @@ def _flash_attention_forward( |
| cu_seqlens_k=cu_seqlens_k, |
| max_seqlen_q=max_seqlen_in_batch_q, |
| max_seqlen_k=max_seqlen_in_batch_k, |
| - dropout_p=dropout, |
| softmax_scale=softmax_scale, |
| causal=causal, |
| **flash_kwargs, |
| ) |
| - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) |
| + attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length) |
| |
| # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing |
| # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. |
| @@ -394,7 +543,7 @@ def _flash_attention_forward( |
| |
| if cu_seq_lens_q is None or cu_seq_lens_k is None: |
| query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = ( |
| - prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids) |
| + _prepare_flash_attention_from_position_ids(query_states, key_states, value_states, position_ids) |
| ) |
| |
| cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens |
| @@ -405,7 +554,7 @@ def _flash_attention_forward( |
| key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) |
| value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) |
| |
| - attn_output = flash_attn_varlen_func( |
| + attn_output = _flash_attn_varlen_func( |
| query_states, |
| key_states, |
| value_states, |
| @@ -413,7 +562,6 @@ def _flash_attention_forward( |
| cu_seqlens_k=cu_seq_lens_k, |
| max_seqlen_q=max_length_q, |
| max_seqlen_k=max_length_k, |
| - dropout_p=dropout, |
| softmax_scale=softmax_scale, |
| causal=causal, |
| **flash_kwargs, |
| @@ -422,10 +570,12 @@ def _flash_attention_forward( |
| attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) |
| |
| else: |
| - attn_output = flash_attn_func( |
| - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs |
| + attn_output = _flash_attn_func( |
| + query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs |
| ) |
| |
| + if isinstance(attn_output, tuple): |
| + return attn_output[0] |
| return attn_output |
| |
| |
| |
| |
| |
| |
| @@ -105,6 +105,7 @@ |
| is_accelerate_available, |
| is_bitsandbytes_available, |
| is_flash_attn_2_available, |
| + is_flash_attn_3_available, |
| is_kernels_available, |
| is_offline_mode, |
| is_optimum_available, |
| @@ -1957,6 +1958,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi |
| # Flash Attention 2 support |
| _supports_flash_attn_2 = False |
| |
| + # Flash Attention 3 support |
| + _supports_flash_attn_3 = False |
| + |
| # SDPA support |
| _supports_sdpa = False |
| |
| @@ -2247,6 +2251,8 @@ def _autoset_attn_implementation( |
| and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys() |
| ): |
| message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' |
| + if cls._supports_flash_attn_3: |
| + message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)' |
| if cls._supports_flash_attn_2: |
| message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' |
| if cls._supports_sdpa: |
| @@ -2282,7 +2288,15 @@ def _autoset_attn_implementation( |
| ): |
| sub_config._attn_implementation_internal = curr_attn_implementation |
| |
| - if config._attn_implementation == "flash_attention_2": |
| + if config._attn_implementation == "flash_attention_3": |
| + cls._check_and_enable_flash_attn_3( |
| + config, |
| + torch_dtype=torch_dtype, |
| + device_map=device_map, |
| + hard_check_only=False, |
| + check_device_map=check_device_map, |
| + ) |
| + elif config._attn_implementation == "flash_attention_2": |
| cls._check_and_enable_flash_attn_2( |
| config, |
| torch_dtype=torch_dtype, |
| @@ -2498,6 +2512,94 @@ def _check_and_enable_flash_attn_2( |
| config._attn_implementation = "flash_attention_2" |
| return config |
| |
| + @classmethod |
| + def _check_and_enable_flash_attn_3( |
| + cls, |
| + config, |
| + torch_dtype: Optional[torch.dtype] = None, |
| + device_map: Optional[Union[str, dict[str, int]]] = None, |
| + check_device_map: bool = True, |
| + hard_check_only: bool = False, |
| + ) -> PretrainedConfig: |
| + """ |
| + Checks the availability of Flash Attention 3 and compatibility with the current model. |
| + |
| + If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module. |
| + |
| + Raises `ValueError` when the model class does not declare Flash Attention 3 support, when the GPU compute capability is below 9.0, when no CUDA device is usable, when the config enables ALiBi or a positive `attention_dropout`, or when the device map places weights on CPU or disk. |
| + Raises `ImportError` when the Flash Attention 3 package is not installed or is otherwise unavailable. |
| + A missing or non-half-precision `torch_dtype` only produces a warning. |
| + """ |
| + if not cls._supports_flash_attn_3: |
| + raise ValueError( |
| + f"{cls.__name__} does not support Flash Attention 3.0 yet. Please request to add support where" |
| + f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new" |
| + " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new" |
| + ) |
| + |
| + if not is_flash_attn_3_available(): |
| + preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:" |
| + |
| + if importlib.util.find_spec("flash_attn_3") is None: |
| + raise ImportError(f"{preface} the package flash_attn_3 seems to be not installed.") |
| + |
| + if torch.cuda.is_available(): |
| + major, minor = torch.cuda.get_device_capability() |
| + if major < 9: |
| + raise ValueError( |
| + f"{preface} Flash Attention 3 requires compute capability >= 9.0, but found a device with compute capability {major}.{minor}." |
| + ) |
| + else: |
| + raise ImportError(f"{preface} Flash Attention 3 is not available.") |
| + else: |
| + raise ValueError( |
| + f"{preface} Flash Attention 3 is not available on CPU. Please make sure torch can access a CUDA device." |
| + ) |
| + |
| + if torch_dtype is None: |
| + logger.warning_once( |
| + "You are attempting to use Flash Attention 3 without specifying a torch dtype. This might lead to unexpected behaviour" |
| + ) |
| + elif torch_dtype not in [torch.float16, torch.bfloat16]: |
| + logger.warning_once( |
| + "Flash Attention 3 only supports torch.float16 and torch.bfloat16 dtypes, but" |
| + f" the current dtype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator," |
| + ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`' |
| + ) |
| + |
| + if getattr(config, "alibi", False) or getattr(config, "use_alibi", False): |
| + raise ValueError("Model is configured to use ALiBi, which is not supported by Flash Attention 3.") |
| + |
| + # Check for attention dropout, which is incompatible with FA3 |
| + if hasattr(config, "attention_dropout") and config.attention_dropout > 0: |
| + raise ValueError( |
| + f"Model has attention_dropout={config.attention_dropout}, which is not supported by Flash Attention 3." |
| + ) |
| + |
| + # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called, |
| + # or the model may be initialized under the context manager `with torch.device("cuda"):`. |
| + if check_device_map and device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]: |
| + if torch.cuda.is_available(): |
| + logger.warning_once( |
| + "You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU" |
| + " after initializing it on CPU with `model.to('cuda')`." |
| + ) |
| + else: |
| + raise ValueError( |
| + "You are attempting to use Flash Attention 3 with a model not initialized on GPU and with no GPU available. " |
| + "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map " |
| + "or initialising the model on CPU and then moving it to GPU." |
| + ) |
| + elif ( |
| + check_device_map |
| + and device_map is not None |
| + and isinstance(device_map, dict) |
| + and ("cpu" in device_map.values() or "disk" in device_map.values()) |
| + ): |
| + raise ValueError( |
| + "You are attempting to use Flash Attention 3 with a model dispatched on CPU or disk. This is not supported. Please make sure to " |
| + "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys." |
| + ) |
| + if not hard_check_only: |
| + config._attn_implementation = "flash_attention_3" |
| + return config |
| + |
| @classmethod |
| def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig: |
| """ |
| @@ -4134,7 +4236,7 @@ def from_pretrained( |
| |
| </Tip> |
| attn_implementation (`str`, *optional*): |
| - The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation. |
| + The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)), or `"flash_attention_3"` (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation. |
| |
| > Parameters for big model inference |
| |
| @@ -5770,6 +5872,7 @@ class AttentionInterface(GeneralInterface): |
| # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if |
| # a new instance is created (in order to locally override a given function) |
| _global_mapping = { |
| + "flash_attention_3": flash_attention_forward, |
| "flash_attention_2": flash_attention_forward, |
| "flex_attention": flex_attention_forward, |
| "paged_attention": paged_attention_forward, |
| |
| |
| |
| |
| @@ -321,6 +321,7 @@ class ArceePreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["ArceeDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -667,6 +667,7 @@ class AriaPreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["AriaDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -318,6 +318,7 @@ class BitNetPreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["BitNetDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -355,6 +355,7 @@ class CoherePreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["CohereDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -334,6 +334,7 @@ class Cohere2PreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["Cohere2DecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -504,6 +504,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["DeepseekV3DecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -556,6 +556,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["DiffLlamaDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = False |
| |
| |
| |
| |
| @@ -424,6 +424,7 @@ class Dots1PreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["Dots1DecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -318,6 +318,7 @@ class GemmaPreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["GemmaDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -339,6 +339,7 @@ class Gemma2PreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["Gemma2DecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -422,6 +422,7 @@ class Gemma3PreTrainedModel(PreTrainedModel): |
| "SiglipMultiheadAttentionPoolingHead", |
| ] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -335,6 +335,7 @@ class GlmPreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["GlmDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -343,6 +343,7 @@ class Glm4PreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["Glm4DecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -292,6 +292,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["GPTNeoXLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -305,6 +305,7 @@ class GranitePreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["GraniteDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -320,6 +320,7 @@ class HeliumPreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["HeliumDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -320,6 +320,7 @@ class LlamaPreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["LlamaDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -590,6 +590,7 @@ class MiniMaxPreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["MiniMaxDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -262,6 +262,7 @@ class MistralPreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["MistralDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -417,6 +417,7 @@ class MixtralPreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["MixtralDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -301,6 +301,7 @@ class OlmoPreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["OlmoDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -305,6 +305,7 @@ class Olmo2PreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["Olmo2DecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -295,6 +295,7 @@ class PhiPreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["PhiDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -316,6 +316,7 @@ class Phi3PreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["Phi3DecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -1622,6 +1622,7 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["Phi4MultimodalDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -266,6 +266,7 @@ class Qwen2PreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["Qwen2DecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -292,6 +292,7 @@ class Qwen3PreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["Qwen3DecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -424,6 +424,7 @@ class Qwen3MoePreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["Qwen3MoeDecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -299,6 +299,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["Starcoder2DecoderLayer"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -561,6 +561,7 @@ class T5GemmaPreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["T5GemmaBlock"] |
| _skip_keys_device_placement = ["past_key_values"] |
| + _supports_flash_attn_3 = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -86,6 +86,7 @@ |
| is_faiss_available, |
| is_fbgemm_gpu_available, |
| is_flash_attn_2_available, |
| + is_flash_attn_3_available, |
| is_flax_available, |
| is_flute_available, |
| is_fsdp_available, |
| @@ -571,6 +572,15 @@ def require_flash_attn(test_case): |
| return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case) |
| |
| |
| +def require_flash_attn_3(test_case):
| +    """
| +    Decorator for tests that need Flash Attention 3; they are skipped when it isn't available.
| +    """
| +    fa3_available = is_flash_attn_3_available()
| +    skip_unless_fa3 = unittest.skipUnless(fa3_available, "test requires Flash Attention 3")
| +    return skip_unless_fa3(test_case)
| + |
| + |
| def require_torch_sdpa(test_case): |
| """ |
| Decorator marking a test that requires PyTorch's SDPA. |
| |
| |
| |
| |
| @@ -153,6 +153,7 @@ |
| is_faiss_available, |
| is_fbgemm_gpu_available, |
| is_flash_attn_2_available, |
| + is_flash_attn_3_available, |
| is_flash_attn_greater_or_equal, |
| is_flash_attn_greater_or_equal_2_10, |
| is_flax_available, |
| |
| |
| |
| |
| @@ -926,6 +926,9 @@ class ClassAttrs: |
| _skip_keys_device_placement = r""" |
| A list of keys to ignore when moving inputs or outputs between devices when using the `accelerate` library. |
| """ |
| + _supports_flash_attn_3 = r""" |
| + Whether the model's attention implementation supports FlashAttention 3.0. |
| + """ |
| _supports_flash_attn_2 = r""" |
| Whether the model's attention implementation supports FlashAttention 2.0. |
| """ |
| |
| |
| |
| |
| @@ -1120,6 +1120,25 @@ def is_flash_attn_2_available(): |
| return False |
| |
| |
| +@lru_cache()
| +def is_flash_attn_3_available():
| +    """
| +    Report whether Flash Attention 3 can be used in the current environment.
| +
| +    The result is `True` only when torch is available, the `flash_attn_3` package is found and
| +    `torch.cuda.is_available()` holds. The answer is memoized by `lru_cache` after the first call.
| +    """
| +    if not (is_torch_available() and _is_package_available("flash_attn_3")):
| +        return False
| +
| +    import torch
| +
| +    # TODO: Check for a minimum version when FA3 is stable
| +    # return version.parse(importlib.metadata.version("flash_attn_3")) >= version.parse("3.0.0")
| +
| +    return torch.cuda.is_available()
| + |
| + |
| @lru_cache |
| def is_flash_attn_greater_or_equal_2_10(): |
| if not _is_package_available("flash_attn"): |
| |
| new file mode 100644 |
| |
| |
| |
| @@ -0,0 +1,163 @@
| +# Copyright 2025 Eduard Durech and SGLang team.
| +#
| +# Licensed under the Apache License, Version 2.0 (the "License");
| +# you may not use this file except in compliance with the License.
| +# You may obtain a copy of the License at
| +#
| +# http://www.apache.org/licenses/LICENSE-2.0
| +#
| +# Unless required by applicable law or agreed to in writing, software
| +# distributed under the License is distributed on an "AS IS" BASIS,
| +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
| +# See the License for the specific language governing permissions and
| +# limitations under the License.
| +#
| +# Usage:
| +# RUN_SLOW=1 pytest -s tests/generation/test_flash_attention_parity.py
| +
| +import unittest
| +
| +import pytest
| +import torch
| +
| +from transformers import AutoModelForCausalLM, AutoTokenizer
| +from transformers.testing_utils import require_flash_attn, require_flash_attn_3, require_torch_gpu, slow
| +
| +
| +class FlashAttentionParityTest(unittest.TestCase):
| +    """Checks that Flash Attention 2 and Flash Attention 3 produce matching generation outputs."""
| +
| +    # From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
| +    def _lcs(self, X, Y):
| +        """Return the length of the longest common subsequence of the sequences `X` and `Y`."""
| +        m = len(X)
| +        n = len(Y)
| +        # Row 0 and column 0 stay 0: an empty prefix has no common subsequence with anything.
| +        L = [[0] * (n + 1) for _ in range(m + 1)]
| +
| +        for i in range(1, m + 1):
| +            for j in range(1, n + 1):
| +                if X[i - 1] == Y[j - 1]:
| +                    L[i][j] = L[i - 1][j - 1] + 1
| +                else:
| +                    L[i][j] = max(L[i - 1][j], L[i][j - 1])
| +
| +        return L[m][n]
| +
| +    # From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
| +    def _calculate_rouge_l(self, output_strs_list1, output_strs_list2):
| +        """
| +        Compute the ROUGE-L F-measure for each pair of entries taken from the two lists.
| +
| +        Entries are paired with `zip`, so surplus entries of the longer list are ignored. The longest common
| +        subsequence is computed element-wise, i.e. over characters when the entries are strings.
| +        """
| +        rouge_l_scores = []
| +
| +        for s1, s2 in zip(output_strs_list1, output_strs_list2):
| +            lcs_len = self._lcs(s1, s2)
| +            precision = lcs_len / len(s1) if len(s1) > 0 else 0
| +            recall = lcs_len / len(s2) if len(s2) > 0 else 0
| +            if precision + recall > 0:
| +                fmeasure = (2 * precision * recall) / (precision + recall)
| +            else:
| +                fmeasure = 0.0
| +            rouge_l_scores.append(fmeasure)
| +
| +        return rouge_l_scores
| +
| +    def _benchmark_generation(self, model, inputs, n_warmup=3, n_runs=5):
| +        """
| +        Return the mean latency, in milliseconds, of one greedy `generate` call producing 20 new tokens.
| +
| +        `n_warmup` untimed calls are made first; the following `n_runs` calls are timed with CUDA events.
| +        """
| +        for _ in range(n_warmup):
| +            model.generate(**inputs, max_new_tokens=20, do_sample=False)
| +        torch.cuda.synchronize()
| +
| +        start_time = torch.cuda.Event(enable_timing=True)
| +        end_time = torch.cuda.Event(enable_timing=True)
| +
| +        start_time.record()
| +        for _ in range(n_runs):
| +            model.generate(**inputs, max_new_tokens=20, do_sample=False)
| +        end_time.record()
| +        # Wait for the recorded events to complete before reading the elapsed time.
| +        torch.cuda.synchronize()
| +
| +        return start_time.elapsed_time(end_time) / n_runs
| +
| +    @pytest.mark.flash_attn_3_test
| +    @require_torch_gpu
| +    @require_flash_attn
| +    @require_flash_attn_3
| +    @slow
| +    def test_flash_attention_2_3_parity(self):
| +        """
| +        Generate greedily with FA2 and FA3 from the same checkpoint and prompt, then compare the outputs.
| +
| +        Asserts that the per-step scores and the decoded texts match; latencies are measured and printed only.
| +        """
| +        model_id = "meta-llama/Llama-3.2-1B-Instruct"
| +        prompt = "The ETH AI Center is"
| +
| +        # 1. Load FA2 model and tokenizer
| +        model_2 = AutoModelForCausalLM.from_pretrained(
| +            model_id,
| +            torch_dtype=torch.bfloat16,
| +            attn_implementation="flash_attention_2",
| +        ).to("cuda")
| +        tokenizer = AutoTokenizer.from_pretrained(model_id)
| +
| +        # 2. Load FA3 model
| +        try:
| +            model_3 = AutoModelForCausalLM.from_pretrained(
| +                model_id,
| +                torch_dtype=torch.bfloat16,
| +                attn_implementation="flash_attention_3",
| +            ).to("cuda")
| +        except (ValueError, ImportError) as e:
| +            pytest.skip(f"Could not load Flash Attention 3 model, skipping test. Error: {e}")
| +
| +        # 3. Generate with both models
| +        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
| +
| +        with torch.no_grad():
| +            output_2 = model_2.generate(
| +                **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
| +            )
| +            output_3 = model_3.generate(
| +                **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
| +            )
| +
| +        # 4. Correctness check
| +        # 4a. Logits
| +        logits_2 = torch.stack(output_2.scores)
| +        logits_3 = torch.stack(output_3.scores)
| +        torch.testing.assert_close(logits_2, logits_3, atol=1e-3, rtol=1e-3)
| +        logprobs_2 = torch.nn.functional.log_softmax(logits_2, dim=-1)
| +        logprobs_3 = torch.nn.functional.log_softmax(logits_3, dim=-1)
| +        max_logprob_diff = torch.max(torch.abs(logprobs_2 - logprobs_3)).item()
| +
| +        # 4b. Generated text
| +        text_2 = tokenizer.decode(output_2.sequences[0], skip_special_tokens=True)
| +        text_3 = tokenizer.decode(output_3.sequences[0], skip_special_tokens=True)
| +        rouge_score = self._calculate_rouge_l([text_2], [text_3])[0]
| +        self.assertGreater(rouge_score, 0.99, f"Generated texts do not match (ROUGE-L: {rouge_score})")
| +
| +        # 5. Performance check
| +        with torch.no_grad():
| +            time_2 = self._benchmark_generation(model_2, inputs)
| +            time_3 = self._benchmark_generation(model_3, inputs)
| +
| +        print(f"\n--- Flash Attention 2 vs. 3 Parity Test on {model_id} ---")
| +        print(f"Prompt: '{prompt}'")
| +        print(f"Generated text with Flash Attention 2: {text_2}")
| +        print(f"Generated text with Flash Attention 3: {text_3}")
| +        print(f"ROUGE-L: {rouge_score}")
| +        print(f"Max absolute difference in logprobs: {max_logprob_diff:.5e}")
| +        print(f"Flash Attention 2 latency: {time_2:.2f} ms")
| +        print(f"Flash Attention 3 latency: {time_3:.2f} ms")
| +        print(f"Speed-up: {time_2 / time_3:.2f}x")
| +        print("---")
| |
| |
| |
| |
| @@ -34,6 +34,7 @@ |
| is_flaky, |
| require_accelerate, |
| require_flash_attn, |
| + require_flash_attn_3, |
| require_optimum_quanto, |
| require_read_token, |
| require_torch, |
| @@ -2292,6 +2293,7 @@ def _test_attention_implementation(self, attn_implementation): |
| support_flag = { |
| "sdpa": "_supports_sdpa", |
| "flash_attention_2": "_supports_flash_attn_2", |
| + "flash_attention_3": "_supports_flash_attn_3", |
| } |
| |
| for model_class in self.all_generative_model_classes: |
| @@ -2369,6 +2371,14 @@ def test_eager_matches_fa2_generate(self): |
| """Tests that generate has equivalent outputs with FA2 and eager attention implementations.""" |
| self._test_attention_implementation("flash_attention_2") |
| |
| + @pytest.mark.flash_attn_3_test |
| + @require_flash_attn_3 |
| + @require_torch_gpu |
| + @slow |
| + def test_eager_matches_fa3_generate(self): |
| + """Tests that generate has equivalent outputs with FA3 and eager attention implementations.""" |
| + self._test_attention_implementation("flash_attention_3") |
| + |
| def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): |
| input_batch_size = int(output.sequences.shape[0] / num_return_sequences) |
| internal_batch_size = ( |
| |
| |
| |
| |
| @@ -84,6 +84,7 @@ |
| require_bitsandbytes, |
| require_deepspeed, |
| require_flash_attn, |
| + require_flash_attn_3, |
| require_non_hpu, |
| require_safetensors, |
| require_torch, |
| @@ -3129,18 +3130,19 @@ def test_model_is_small(self): |
| f"{model_class} is too big for the common tests ({num_params})! It should have 1M max." |
| ) |
| |
| - @require_flash_attn |
| - @require_torch_gpu |
| - @mark.flash_attn_test |
| - @slow |
| - @is_flaky() |
| - def test_flash_attn_2_inference_equivalence(self): |
| + def flash_attn_inference_equivalence(self, attn_implementation: str, padding_side: str): |
| + r""" |
| + Tests the equivalence between the eager and flash attention implementations. |
| + This test is only for inference and runs with `torch_dtype=torch.bfloat16`. |
| + """ |
| if not self.has_attentions: |
| self.skipTest(reason="Model architecture does not support attentions") |
| |
| for model_class in self.all_model_classes: |
| - if not model_class._supports_flash_attn_2: |
| - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") |
| + if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or ( |
| + attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3 |
| + ): |
| + self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") |
| |
| config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() |
| model = model_class(config) |
| @@ -3148,7 +3150,7 @@ def test_flash_attn_2_inference_equivalence(self): |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| model.save_pretrained(tmpdirname) |
| model_fa = model_class.from_pretrained( |
| - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" |
| + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation |
| ) |
| model_fa.to(torch_device) |
| |
| @@ -3163,9 +3165,12 @@ def test_flash_attn_2_inference_equivalence(self): |
| |
| if dummy_attention_mask is not None: |
| dummy_attention_mask = dummy_attention_mask[:1] |
| - dummy_attention_mask[:, 1:] = 1 |
| - dummy_attention_mask[:, :1] = 0 |
| - |
| + if padding_side == "left": |
| + dummy_attention_mask[:, 1:] = 1 |
| + dummy_attention_mask[:, :1] = 0 |
| + else: |
| + dummy_attention_mask[:, :-1] = 1 |
| + dummy_attention_mask[:, -1:] = 0 |
| if model.config.is_encoder_decoder: |
| decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] |
| |
| @@ -3220,104 +3225,46 @@ def test_flash_attn_2_inference_equivalence(self): |
| else outputs_fa.decoder_hidden_states[-1] |
| ) |
| |
| - assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) |
| + if padding_side == "left": |
| + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) |
| |
| - # check with inference + dropout |
| - model.train() |
| - _ = model_fa(dummy_input, **other_inputs) |
| + # check with inference + dropout |
| + model.train() |
| + _ = model_fa(dummy_input, **other_inputs) |
| + else: |
| + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) |
| |
| @require_flash_attn |
| @require_torch_gpu |
| @mark.flash_attn_test |
| @slow |
| @is_flaky() |
| - def test_flash_attn_2_inference_equivalence_right_padding(self): |
| - if not self.has_attentions: |
| - self.skipTest(reason="Model architecture does not support attentions") |
| - |
| - for model_class in self.all_model_classes: |
| - if not model_class._supports_flash_attn_2: |
| - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") |
| - |
| - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() |
| - model = model_class(config) |
| - |
| - with tempfile.TemporaryDirectory() as tmpdirname: |
| - model.save_pretrained(tmpdirname) |
| - model_fa = model_class.from_pretrained( |
| - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" |
| - ) |
| - model_fa.to(torch_device) |
| - |
| - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) |
| - model.to(torch_device) |
| - |
| - dummy_input = inputs_dict[model.main_input_name][:1] |
| - if dummy_input.dtype in [torch.float32, torch.float16]: |
| - dummy_input = dummy_input.to(torch.bfloat16) |
| - |
| - dummy_attention_mask = inputs_dict.get("attention_mask", None) |
| - |
| - if dummy_attention_mask is not None: |
| - dummy_attention_mask = dummy_attention_mask[:1] |
| - dummy_attention_mask[:, :-1] = 1 |
| - dummy_attention_mask[:, -1:] = 0 |
| - |
| - if model.config.is_encoder_decoder: |
| - decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] |
| - |
| - outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) |
| - outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) |
| - else: |
| - outputs = model(dummy_input, output_hidden_states=True) |
| - outputs_fa = model_fa(dummy_input, output_hidden_states=True) |
| - |
| - logits = ( |
| - outputs.hidden_states[-1] |
| - if not model.config.is_encoder_decoder |
| - else outputs.decoder_hidden_states[-1] |
| - ) |
| - logits_fa = ( |
| - outputs_fa.hidden_states[-1] |
| - if not model.config.is_encoder_decoder |
| - else outputs_fa.decoder_hidden_states[-1] |
| - ) |
| - |
| - assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) |
| - |
| - if model.config.is_encoder_decoder: |
| - other_inputs = { |
| - "decoder_input_ids": decoder_input_ids, |
| - "decoder_attention_mask": dummy_attention_mask, |
| - "output_hidden_states": True, |
| - } |
| - if dummy_attention_mask is not None: |
| - other_inputs["attention_mask"] = dummy_attention_mask |
| - |
| - outputs = model(dummy_input, **other_inputs) |
| - outputs_fa = model_fa(dummy_input, **other_inputs) |
| - else: |
| - other_inputs = { |
| - "output_hidden_states": True, |
| - } |
| - if dummy_attention_mask is not None: |
| - other_inputs["attention_mask"] = dummy_attention_mask |
| + def test_flash_attn_2_inference_equivalence(self): |
| + self.flash_attn_inference_equivalence(attn_implementation="flash_attention_2", padding_side="left") |
| |
| - outputs = model(dummy_input, **other_inputs) |
| - outputs_fa = model_fa(dummy_input, **other_inputs) |
| + @require_flash_attn |
| + @require_torch_gpu |
| + @mark.flash_attn_test |
| + @slow |
| + @is_flaky() |
| + def test_flash_attn_2_inference_equivalence_right_padding(self): |
| + self.flash_attn_inference_equivalence(attn_implementation="flash_attention_2", padding_side="right") |
| |
| - logits = ( |
| - outputs.hidden_states[-1] |
| - if not model.config.is_encoder_decoder |
| - else outputs.decoder_hidden_states[-1] |
| - ) |
| - logits_fa = ( |
| - outputs_fa.hidden_states[-1] |
| - if not model.config.is_encoder_decoder |
| - else outputs_fa.decoder_hidden_states[-1] |
| - ) |
| + @require_flash_attn_3 |
| + @require_torch_gpu |
| + @mark.flash_attn_3_test |
| + @slow |
| + @is_flaky() |
| + def test_flash_attn_3_inference_equivalence(self): |
| + self.flash_attn_inference_equivalence(attn_implementation="flash_attention_3", padding_side="left") |
| |
| - assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) |
| + @require_flash_attn_3 |
| + @require_torch_gpu |
| + @mark.flash_attn_3_test |
| + @slow |
| + @is_flaky() |
| + def test_flash_attn_3_inference_equivalence_right_padding(self): |
| + self.flash_attn_inference_equivalence(attn_implementation="flash_attention_3", padding_side="right") |
| |
| def test_attn_implementation_composite_models(self): |
| """ |
| @@ -3959,24 +3906,21 @@ def test_sdpa_matches_eager_sliding_window(self): |
| torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-4, atol=1e-4) |
| ) |
| |
| - @require_flash_attn |
| - @require_torch_gpu |
| - @mark.flash_attn_test |
| - def test_flash_attn_2_can_dispatch_composite_models(self): |
| + def flash_attn_can_dispatch_composite_models(self, attn_implementation: str): |
| """ |
| - Tests if composite models can dispatch on FA2 if the sub-models support FA2. |
| + Tests if composite models can dispatch on flash attention if the sub-models support it. |
| The tests is needed as we handle differently composite models and we cannot check them |
| - with above tests. If any of the sub-models does not support FA2, we'll raise an error when dispatching |
| + with above tests. If any of the sub-models does not support flash attention, we'll raise an error when dispatching |
| that particular sub-model. Otherwise we dispatch safely in all sub-models, where "sub-models" are specific |
| backbone models (LM/vision/audio/etc) |
| """ |
| if not self.has_attentions: |
| self.skipTest(reason="Model architecture does not support attentions") |
| |
| - if not is_torch_fp16_available_on_device(torch_device): |
| - self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") |
| + if not is_torch_bf16_available_on_device(torch_device): |
| + self.skipTest(f"bfloat16 not supported on {torch_device} (on the specific device currently used)") |
| |
| - torch_dtype = torch.float16 |
| + torch_dtype = torch.bfloat16 |
| for model_class in self.all_model_classes: |
| config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() |
| model = model_class(config) |
| @@ -3987,44 +3931,64 @@ def test_flash_attn_2_can_dispatch_composite_models(self): |
| model.save_pretrained(tmpdirname) |
| model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) |
| |
| - sub_models_supporting_fa2 = [ |
| - module._supports_flash_attn_2 |
| + sub_models_supporting_fa = [ |
| + ( |
| + module._supports_flash_attn_3 |
| + if attn_implementation == "flash_attention_3" |
| + else module._supports_flash_attn_2 |
| + ) |
| for name, module in model.named_modules() |
| if isinstance(module, PreTrainedModel) and name != "" |
| ] |
| - supports_fa2_all_modules = ( |
| - all(sub_models_supporting_fa2) |
| - if len(sub_models_supporting_fa2) > 0 |
| - else model._supports_flash_attn_2 |
| + supports_fa_all_modules = ( |
| + all(sub_models_supporting_fa) |
| + if len(sub_models_supporting_fa) > 0 |
| + else ( |
| + model._supports_flash_attn_3 |
| + if attn_implementation == "flash_attention_3" |
| + else model._supports_flash_attn_2 |
| + ) |
| ) |
| - if not supports_fa2_all_modules: |
| + if not supports_fa_all_modules: |
| with self.assertRaises(ValueError): |
| - model_fa2 = model_class.from_pretrained( |
| + model_fa = model_class.from_pretrained( |
| tmpdirname, |
| torch_dtype=torch_dtype, |
| - attn_implementation="flash_attention_2", |
| + attn_implementation=attn_implementation, |
| ) |
| else: |
| - model_fa2 = model_class.from_pretrained( |
| - tmpdirname, torch_dtype=torch_dtype, attn_implementation="flash_attention_2" |
| + model_fa = model_class.from_pretrained( |
| + tmpdirname, torch_dtype=torch_dtype, attn_implementation=attn_implementation |
| ) |
| - for key in model_fa2.config: |
| - if isinstance(getattr(model_fa2.config, key), PretrainedConfig): |
| - sub_config = getattr(model_fa2.config, key) |
| - self.assertTrue(sub_config._attn_implementation == "flash_attention_2") |
| + for key in model_fa.config: |
| + if isinstance(getattr(model_fa.config, key), PretrainedConfig): |
| + sub_config = getattr(model_fa.config, key) |
| + self.assertTrue(sub_config._attn_implementation == attn_implementation) |
| |
| - has_fa2 = False |
| - for name, submodule in model_fa2.named_modules(): |
| + has_fa = False |
| + for name, submodule in model_fa.named_modules(): |
| class_name = submodule.__class__.__name__ |
| if ( |
| "Attention" in class_name |
| and getattr(submodule, "config", None) |
| - and submodule.config._attn_implementation == "flash_attention_2" |
| + and submodule.config._attn_implementation == attn_implementation |
| ): |
| - has_fa2 = True |
| + has_fa = True |
| break |
| - if not has_fa2: |
| - raise ValueError("The FA2 model should have FA2 layers") |
| + if not has_fa: |
| + raise ValueError(f"The {attn_implementation} model should have {attn_implementation} layers") |
| + |
| + @require_flash_attn |
| + @require_torch_gpu |
| + @mark.flash_attn_test |
| + def test_flash_attn_2_can_dispatch_composite_models(self): |
| + self.flash_attn_can_dispatch_composite_models(attn_implementation="flash_attention_2") |
| + |
| + @require_flash_attn_3 |
| + @require_torch_gpu |
| + @mark.flash_attn_3_test |
| + def test_flash_attn_3_can_dispatch_composite_models(self): |
| + self.flash_attn_can_dispatch_composite_models(attn_implementation="flash_attention_3") |
| |
| @require_flash_attn |
| @require_torch_gpu |
| @@ -4121,27 +4085,29 @@ def test_flash_attn_2_can_compile_with_attention_mask_None_without_graph_break(s |
| |
| assert not loss.isnan().any() |
| |
| - @require_flash_attn |
| - @require_torch_gpu |
| - @mark.flash_attn_test |
| - @slow |
| - def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): |
| + def flash_attention_padding_matches_padding_free_with_position_ids( |
| + self, attn_implementation: str, fa_kwargs: bool = False |
| + ): |
| if not self.has_attentions: |
| self.skipTest(reason="Model architecture does not support attentions") |
| |
| max_new_tokens = 30 |
| |
| for model_class in self.all_generative_model_classes: |
| - if not model_class._supports_flash_attn_2: |
| - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") |
| + if not ( |
| + model_class._supports_flash_attn_2 |
| + if attn_implementation == "flash_attention_2" |
| + else model_class._supports_flash_attn_3 |
| + ): |
| + self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") |
| |
| config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() |
| if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict: |
| self.skipTest("Model dummy inputs should contain padding in their attention mask") |
| |
| dummy_input = inputs_dict[model_class.main_input_name] |
| - if dummy_input.dtype in [torch.float32, torch.bfloat16]: |
| - dummy_input = dummy_input.to(torch.float16) |
| + if dummy_input.dtype in [torch.float32, torch.float16]: |
| + dummy_input = dummy_input.to(torch.bfloat16) |
| |
| # make sure that all models have enough positions for generation |
| if hasattr(config, "max_position_embeddings"): |
| @@ -4151,7 +4117,7 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): |
| if "position_ids" not in inspect.signature(model.forward).parameters: |
| self.skipTest("Model does not support position_ids") |
| |
| - if "position_ids" not in inspect.signature(model.forward).parameters: |
| + if (not fa_kwargs) and "position_ids" not in inspect.signature(model.forward).parameters: |
| continue # this model doesn't accept position ids as input |
| |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| @@ -4166,26 +4132,40 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): |
| model = ( |
| model_class.from_pretrained( |
| tmpdirname, |
| - torch_dtype=torch.float16, |
| - attn_implementation="flash_attention_2", |
| + torch_dtype=torch.bfloat16, |
| + attn_implementation=attn_implementation, |
| ) |
| .to(torch_device) |
| .eval() |
| ) |
| |
| - # flatten |
| - padfree_inputs_dict = { |
| - k: v[dummy_attention_mask.bool()].unsqueeze(0) |
| - for k, v in inputs_dict.items() |
| - if not k == "attention_mask" |
| - } |
| - # add position_ids |
| - padfree_inputs_dict["position_ids"] = ( |
| - torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]) |
| - .long() |
| - .unsqueeze(0) |
| - .to(torch_device) |
| - ) |
| + if fa_kwargs: |
| + # flatten |
| + features = [ |
| + {"input_ids": i[a.bool()].tolist()} |
| + for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"]) |
| + ] |
| + |
| + # add position_ids + fa_kwargs |
| + data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True) |
| + batch = data_collator(features) |
| + padfree_inputs_dict = { |
| + k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items() |
| + } |
| + else: |
| + # flatten |
| + padfree_inputs_dict = { |
| + k: v[dummy_attention_mask.bool()].unsqueeze(0) |
| + for k, v in inputs_dict.items() |
| + if not k == "attention_mask" |
| + } |
| + # add position_ids |
| + padfree_inputs_dict["position_ids"] = ( |
| + torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]) |
| + .long() |
| + .unsqueeze(0) |
| + .to(torch_device) |
| + ) |
| |
| res_padded = model(**inputs_dict) |
| res_padfree = model(**padfree_inputs_dict) |
| @@ -4195,119 +4175,96 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): |
| |
| torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0) |
| # acceptable numerical instability |
| - tol = torch.finfo(torch.float16).eps |
| + tol = torch.finfo(torch.bfloat16).eps |
| torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol) |
| |
| @require_flash_attn |
| @require_torch_gpu |
| @mark.flash_attn_test |
| @slow |
| - def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self): |
| - if not self.has_attentions: |
| - self.skipTest(reason="Model architecture does not support attentions") |
| - |
| - max_new_tokens = 30 |
| - |
| - for model_class in self.all_generative_model_classes: |
| - if not model_class._supports_flash_attn_2: |
| - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") |
| - |
| - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() |
| - if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict: |
| - self.skipTest("Model dummy inputs should contain padding in their attention mask") |
| - |
| - dummy_input = inputs_dict[model_class.main_input_name] |
| - if dummy_input.dtype in [torch.float32, torch.bfloat16]: |
| - dummy_input = dummy_input.to(torch.float16) |
| - |
| - # make sure that all models have enough positions for generation |
| - if hasattr(config, "max_position_embeddings"): |
| - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 |
| - |
| - model = model_class(config) |
| - if "position_ids" not in inspect.signature(model.forward).parameters: |
| - self.skipTest("Model does not support position_ids") |
| - |
| - with tempfile.TemporaryDirectory() as tmpdirname: |
| - model.save_pretrained(tmpdirname) |
| - |
| - # ensure left padding, to adapt for some models |
| - if 0 in inputs_dict["attention_mask"][:, -1]: |
| - inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1) |
| - dummy_attention_mask = inputs_dict["attention_mask"] |
| - inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id |
| - |
| - model = ( |
| - model_class.from_pretrained( |
| - tmpdirname, |
| - torch_dtype=torch.float16, |
| - attn_implementation="flash_attention_2", |
| - ) |
| - .to(torch_device) |
| - .eval() |
| - ) |
| - |
| - # flatten |
| - features = [ |
| - {"input_ids": i[a.bool()].tolist()} |
| - for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"]) |
| - ] |
| - |
| - # add position_ids + fa_kwargs |
| - data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True) |
| - batch = data_collator(features) |
| - batch_accelerator = {k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()} |
| - |
| - res_padded = model(**inputs_dict) |
| - res_padfree = model(**batch_accelerator) |
| - |
| - logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()] |
| - logits_padfree = res_padfree.logits[0] |
| - |
| - torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0) |
| - # acceptable numerical instability |
| - tol = torch.finfo(torch.float16).eps |
| - torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol) |
| + def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): |
| + self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2") |
| |
| @require_flash_attn |
| @require_torch_gpu |
| @mark.flash_attn_test |
| @slow |
| - def test_flash_attn_2_from_config(self): |
| + def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self): |
| + self.flash_attention_padding_matches_padding_free_with_position_ids( |
| + attn_implementation="flash_attention_2", fa_kwargs=True |
| + ) |
| + |
| + @require_flash_attn_3 |
| + @require_torch_gpu |
| + @mark.flash_attn_3_test |
| + @slow |
| + def test_flash_attention_3_padding_matches_padding_free_with_position_ids(self): |
| + self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3") |
| + |
| + @require_flash_attn_3 |
| + @require_torch_gpu |
| + @mark.flash_attn_3_test |
| + @slow |
| + def test_flash_attention_3_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self): |
| + self.flash_attention_padding_matches_padding_free_with_position_ids( |
| + attn_implementation="flash_attention_3", fa_kwargs=True |
| + ) |
| + |
| + def flash_attn_from_config(self, attn_implementation: str): |
| + r""" |
| + Tests if the model can be loaded with `attn_implementation` from the config and if the |
| + weights are not randomly initialized. |
| + """ |
| if not self.has_attentions: |
| self.skipTest(reason="Model architecture does not support attentions") |
| |
| for model_class in self.all_generative_model_classes: |
| - if not model_class._supports_flash_attn_2: |
| - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") |
| + if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or ( |
| + attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3 |
| + ): |
| + self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") |
| |
| config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() |
| # TODO: to change it in the future with other relevant auto classes |
| - fa2_model = model_class._from_config( |
| - config, attn_implementation="flash_attention_2", torch_dtype=torch.float16 |
| + fa_model = model_class._from_config( |
| + config, attn_implementation=attn_implementation, torch_dtype=torch.bfloat16 |
| ).to(torch_device) |
| |
| - dummy_input = inputs_dict[fa2_model.main_input_name] |
| - if dummy_input.dtype in [torch.float32, torch.bfloat16]: |
| - dummy_input = dummy_input.to(torch.float16) |
| + dummy_input = inputs_dict[fa_model.main_input_name] |
| + if dummy_input.dtype in [torch.float32, torch.float16]: |
| + dummy_input = dummy_input.to(torch.bfloat16) |
| dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) |
| |
| - if fa2_model.config.is_encoder_decoder: |
| + if fa_model.config.is_encoder_decoder: |
| dummy_decoder_input_ids = inputs_dict["decoder_input_ids"] |
| dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"] |
| - _ = fa2_model( |
| + _ = fa_model( |
| dummy_input, |
| attention_mask=dummy_attention_mask, |
| decoder_input_ids=dummy_decoder_input_ids, |
| decoder_attention_mask=dummy_decoder_attention_mask, |
| ) |
| else: |
| - _ = fa2_model(dummy_input, attention_mask=dummy_attention_mask) |
| + _ = fa_model(dummy_input, attention_mask=dummy_attention_mask) |
| |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| - fa2_model.save_pretrained(tmpdirname) |
| + fa_model.save_pretrained(tmpdirname) |
| model_from_pretrained = model_class.from_pretrained(tmpdirname) |
| - self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2") |
| + self.assertTrue(model_from_pretrained.config._attn_implementation != attn_implementation) |
| + |
| + @require_flash_attn |
| + @require_torch_gpu |
| + @mark.flash_attn_test |
| + @slow |
| + def test_flash_attn_2_from_config(self): |
| + self.flash_attn_from_config(attn_implementation="flash_attention_2") |
| + |
| + @require_flash_attn_3 |
| + @require_torch_gpu |
| + @mark.flash_attn_3_test |
| + @slow |
| + def test_flash_attn_3_from_config(self): |
| + self.flash_attn_from_config(attn_implementation="flash_attention_3") |
| |
| def _get_custom_4d_mask_test_data(self): |
| # Sequence in which all but the last token is the same |
| |
| |
| |
| |
| @@ -77,6 +77,7 @@ |
| ) |
| from transformers.utils.import_utils import ( |
| is_flash_attn_2_available, |
| + is_flash_attn_3_available, |
| is_flax_available, |
| is_tf_available, |
| is_torch_npu_available, |
| @@ -676,6 +677,9 @@ def test_model_from_pretrained_attn_implementation(self): |
| if is_flash_attn_available(): |
| attn_implementation_available.append("flash_attention_2") |
| |
| + if is_flash_attn_3_available(): |
| + attn_implementation_available.append("flash_attention_3") |
| + |
| for requested_attn_implementation in attn_implementation_available: |
| model = AutoModelForCausalLM.from_pretrained( |
| TINY_MISTRAL, attn_implementation=requested_attn_implementation |
| @@ -700,6 +704,9 @@ def test_model_from_config_attn_implementation(self): |
| if is_flash_attn_available(): |
| attn_implementation_available.append("flash_attention_2") |
| |
| + if is_flash_attn_3_available(): |
| + attn_implementation_available.append("flash_attention_3") |
| + |
| for requested_attn_implementation in attn_implementation_available: |
| config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation) |
| # Ensure the config was set correctly |
|
|