harness / diffs /38972.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/pyproject.toml b/pyproject.toml
index af22cfe9c623..4e7a0c62d0fc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,6 +52,7 @@ line-ending = "auto"
addopts = "--doctest-glob='**/*.md'"
doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
markers = [
+ "flash_attn_3_test: marks tests related to flash attention 3 (deselect with '-m \"not flash_attn_3_test\"')",
"flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')",
"bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
"generate: marks tests that use the GenerationTesterMixin"
diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py
index 16fcc909817a..00df0ef0fd66 100644
--- a/src/transformers/integrations/flash_attention.py
+++ b/src/transformers/integrations/flash_attention.py
@@ -75,6 +75,7 @@ def flash_attention_forward(
softcap=softcap,
use_top_left_mask=_use_top_left_mask,
target_dtype=target_dtype,
+ attn_implementation=module.config._attn_implementation,
**kwargs,
)
diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index 7f3df3294320..649447ca8f7b 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -14,6 +14,7 @@
import inspect
import os
+import warnings
from typing import Optional, TypedDict
import torch
@@ -21,6 +22,7 @@
from .utils import (
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_torch_npu_available,
@@ -32,18 +34,123 @@
flash_attn_func = None
-if is_flash_attn_2_available():
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
- from flash_attn import flash_attn_func, flash_attn_varlen_func
- from flash_attn.layers.rotary import apply_rotary_emb # noqa
+def _index_first_axis(tensor, indices):
+ """
+ A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
+ after flattening the first two dimensions of the tensor. This is functionally equivalent to
+ FA2's `index_first_axis` and replaces the need to import it.
+ """
+ # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
+ # two dimensions to get (total_tokens, ...) before indexing.
+ reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
+ return reshaped_tensor[indices]
+
+
+def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
+ """
+ FA3-compatible unpad_input function.
+ Arguments:
+ hidden_states: (batch, seqlen, ...)
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+ unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
+ Return:
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
+ indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+ max_seqlen_in_batch: int
+ seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
+ """
+ all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+ seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
+ used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+
+ return (
+ _index_first_axis(hidden_states, indices),
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ used_seqlens_in_batch,
+ )
+
+
+def _fa3_pad_input(hidden_states, indices, batch, seqlen):
+ """
+ FA3-compatible pad_input function.
+
+ Arguments:
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
+ indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
+ batch: int, batch size for the padded sequence.
+ seqlen: int, maximum sequence length for the padded sequence.
+ Return:
+ hidden_states: (batch, seqlen, ...)
+ """
+ dim = hidden_states.shape[1:]
+ output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
+ output[indices] = hidden_states
+ return output.view(batch, seqlen, *dim)
+
+
+FA_VERSION = None
+if is_flash_attn_2_available():
+ from flash_attn import flash_attn_func as flash_attn_2_func
+ from flash_attn import flash_attn_varlen_func as flash_attn_2_varlen_func
+ from flash_attn.bert_padding import pad_input as pad_input_fa2
+ from flash_attn.bert_padding import unpad_input as unpad_input_fa2
+ from flash_attn.layers.rotary import apply_rotary_emb
+
+ HAS_FA2 = True
+ FA_VERSION = 2
+else:
+ flash_attn_2_func = None
+ flash_attn_2_varlen_func = None
+ pad_input_fa2 = None
+ unpad_input_fa2 = None
+ apply_rotary_emb = None
+ HAS_FA2 = False
+
+if is_flash_attn_3_available():
+ from flash_attn_interface import flash_attn_func as flash_attn_3_func
+ from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+
+ pad_input_fa3 = _fa3_pad_input
+ unpad_input_fa3 = _fa3_unpad_input
+ HAS_FA3 = True
+ FA_VERSION = 3
+else:
+ flash_attn_3_func = None
+ flash_attn_3_varlen_func = None
+ pad_input_fa3 = None
+ unpad_input_fa3 = None
+ HAS_FA3 = False
+
+
+# Current Flash Attention implementations
+if FA_VERSION:
+ flash_attn_func = globals()[f"flash_attn_{FA_VERSION}_func"]
+ flash_attn_varlen_func = globals()[f"flash_attn_{FA_VERSION}_varlen_func"]
+ unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"]
+ pad_input = globals()[f"pad_input_fa{FA_VERSION}"]
# patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
if is_torch_npu_available():
- from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input
- from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb # noqa
- from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
- from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
+ from .integrations.npu_flash_attention import (
+ npu_apply_rotary_emb as apply_rotary_emb, # noqa: F401
+ )
+ from .integrations.npu_flash_attention import (
+ npu_flash_attn_func as flash_attn_func,
+ )
+ from .integrations.npu_flash_attention import (
+ npu_flash_attn_varlen_func as flash_attn_varlen_func,
+ )
+ from .integrations.npu_flash_attention import (
+ pad_input,
+ unpad_input,
+ )
_flash_supports_window_size = False
@@ -56,6 +163,9 @@
def is_flash_attn_available():
"""Determine whether flash-attention can be used or not."""
+ if is_flash_attn_3_available():
+ return True
+
# if package `flash-attn` is available, flash-attention can be used natively.
if is_flash_attn_2_available():
return True
@@ -70,6 +180,9 @@ def is_flash_attn_available():
def flash_attn_supports_top_left_mask():
"""Determine whether flash-attention uses top-left or down-right mask"""
+ if is_flash_attn_3_available():
+ return False
+
if is_flash_attn_2_available():
# top-left mask is used in package `flash-attn` with version lower than 2.1.0
return not is_flash_attn_greater_or_equal_2_10()
@@ -116,6 +229,7 @@ def _upad_input(
value_layer: torch.Tensor,
attention_mask: torch.Tensor,
query_length: int,
+ unpad_input_func,
):
"""
Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
@@ -134,6 +248,8 @@ def _upad_input(
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
query_length (`int`):
Target length.
+ unpad_input_func:
+ The function to use for unpadding the input tensors.
Return:
query_layer (`torch.Tensor`):
@@ -158,12 +274,10 @@ def _upad_input(
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
- key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
- value_layer = index_first_axis(
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
- )
+ key_layer = _index_first_axis(key_layer, indices_k)
+ value_layer = _index_first_axis(value_layer, indices_k)
if query_length == kv_seq_len:
- query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
+ query_layer = _index_first_axis(query_layer, indices_k)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
@@ -177,7 +291,7 @@ def _upad_input(
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input(query_layer, attention_mask)
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask)
return (
query_layer,
@@ -189,7 +303,7 @@ def _upad_input(
)
-def prepare_fa2_from_position_ids(query, key, value, position_ids):
+def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
"""
This function returns necessary arguments to call `flash_attn_varlen_func`.
All three query, key, value states will be flattened.
@@ -239,6 +353,14 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
+def prepare_fa2_from_position_ids(*args, **kwargs):
+ warnings.warn(
+ "The function `prepare_fa2_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_flash_attention_from_position_ids` instead.",
+ FutureWarning,
+ )
+ return _prepare_flash_attention_from_position_ids(*args, **kwargs)
+
+
def fa_peft_integration_check(
query: torch.Tensor,
key: torch.Tensor,
@@ -303,6 +425,7 @@ def _flash_attention_forward(
max_length_q: Optional[int] = None,
max_length_k: Optional[int] = None,
target_dtype: Optional[torch.dtype] = None,
+ attn_implementation: Optional[str] = None,
**kwargs,
):
"""
@@ -329,7 +452,28 @@ def _flash_attention_forward(
Softcap for the attention logits, used e.g. in gemma2.
deterministic (`bool`, *optional*):
Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
+ attn_implementation (`str`, *optional*):
+ The attention implementation to use. If None, will default to the one based on the environment.
"""
+ if attn_implementation is None:
+ _flash_attn_varlen_func = flash_attn_varlen_func
+ _flash_attn_func = flash_attn_func
+ _pad_input = pad_input
+ _unpad_input = unpad_input
+ _is_fa3 = HAS_FA3
+ elif attn_implementation == "flash_attention_3":
+ _flash_attn_varlen_func = flash_attn_3_varlen_func
+ _flash_attn_func = flash_attn_3_func
+ _pad_input = pad_input_fa3
+ _unpad_input = unpad_input_fa3
+ _is_fa3 = True
+ elif attn_implementation == "flash_attention_2":
+ _flash_attn_varlen_func = flash_attn_2_varlen_func
+ _flash_attn_func = flash_attn_2_func
+ _pad_input = pad_input_fa2
+ _unpad_input = unpad_input_fa2
+ _is_fa3 = False
+
if not use_top_left_mask:
causal = is_causal
else:
@@ -342,6 +486,12 @@ def _flash_attention_forward(
)
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
+ if _is_fa3:
+ if dropout > 0.0:
+ logger.warning_once("Flash Attention 3 does not support dropout. Setting dropout to 0.0.")
+ else:
+ flash_kwargs["dropout_p"] = dropout
+
if flash_241:
if deterministic is None:
global deterministic_g
@@ -362,12 +512,12 @@ def _flash_attention_forward(
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
- query_states, key_states, value_states, attention_mask, query_length
+ query_states, key_states, value_states, attention_mask, query_length, _unpad_input
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
- attn_output_unpad = flash_attn_varlen_func(
+ attn_output_unpad = _flash_attn_varlen_func(
query_states,
key_states,
value_states,
@@ -375,12 +525,11 @@ def _flash_attention_forward(
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
- dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
)
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length)
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
@@ -394,7 +543,7 @@ def _flash_attention_forward(
if cu_seq_lens_q is None or cu_seq_lens_k is None:
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
- prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids)
+ _prepare_flash_attention_from_position_ids(query_states, key_states, value_states, position_ids)
)
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
@@ -405,7 +554,7 @@ def _flash_attention_forward(
key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
- attn_output = flash_attn_varlen_func(
+ attn_output = _flash_attn_varlen_func(
query_states,
key_states,
value_states,
@@ -413,7 +562,6 @@ def _flash_attention_forward(
cu_seqlens_k=cu_seq_lens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
- dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
@@ -422,10 +570,12 @@ def _flash_attention_forward(
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
else:
- attn_output = flash_attn_func(
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
+ attn_output = _flash_attn_func(
+ query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
)
+ if isinstance(attn_output, tuple):
+ return attn_output[0]
return attn_output
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 4f6095a3eddc..a5d1be345d10 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -105,6 +105,7 @@
is_accelerate_available,
is_bitsandbytes_available,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_kernels_available,
is_offline_mode,
is_optimum_available,
@@ -1957,6 +1958,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# Flash Attention 2 support
_supports_flash_attn_2 = False
+ # Flash Attention 3 support
+ _supports_flash_attn_3 = False
+
# SDPA support
_supports_sdpa = False
@@ -2247,6 +2251,8 @@ def _autoset_attn_implementation(
and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
):
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
+ if cls._supports_flash_attn_3:
+ message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
if cls._supports_flash_attn_2:
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
if cls._supports_sdpa:
@@ -2282,7 +2288,15 @@ def _autoset_attn_implementation(
):
sub_config._attn_implementation_internal = curr_attn_implementation
- if config._attn_implementation == "flash_attention_2":
+ if config._attn_implementation == "flash_attention_3":
+ cls._check_and_enable_flash_attn_3(
+ config,
+ torch_dtype=torch_dtype,
+ device_map=device_map,
+ hard_check_only=False,
+ check_device_map=check_device_map,
+ )
+ elif config._attn_implementation == "flash_attention_2":
cls._check_and_enable_flash_attn_2(
config,
torch_dtype=torch_dtype,
@@ -2498,6 +2512,94 @@ def _check_and_enable_flash_attn_2(
config._attn_implementation = "flash_attention_2"
return config
+ @classmethod
+ def _check_and_enable_flash_attn_3(
+ cls,
+ config,
+ torch_dtype: Optional[torch.dtype] = None,
+ device_map: Optional[Union[str, dict[str, int]]] = None,
+ check_device_map: bool = True,
+ hard_check_only: bool = False,
+ ) -> PretrainedConfig:
+ """
+ Checks the availability of Flash Attention 3 and compatibility with the current model.
+
+ If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module.
+ """
+ if not cls._supports_flash_attn_3:
+ raise ValueError(
+ f"{cls.__name__} does not support Flash Attention 3.0 yet. Please request to add support where"
+ f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
+ " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
+ )
+
+ if not is_flash_attn_3_available():
+ preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:"
+
+ if importlib.util.find_spec("flash_attn_3") is None:
+ raise ImportError(f"{preface} the package flash_attn_3 seems to be not installed.")
+
+ if torch.cuda.is_available():
+ major, _ = torch.cuda.get_device_capability()
+ if major < 9:
+ raise ValueError(
+ f"{preface} Flash Attention 3 requires compute capability >= 9.0, but found {torch.cuda.get_device_capability()} with compute capability {major}.0."
+ )
+ else:
+ raise ImportError(f"{preface} Flash Attention 3 is not available.")
+ else:
+ raise ValueError(
+ f"{preface} Flash Attention 3 is not available on CPU. Please make sure torch can access a CUDA device."
+ )
+
+ if torch_dtype is None:
+ logger.warning_once(
+ "You are attempting to use Flash Attention 3 without specifying a torch dtype. This might lead to unexpected behaviour"
+ )
+ elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
+ logger.warning_once(
+ "Flash Attention 3 only supports torch.float16 and torch.bfloat16 dtypes, but"
+ f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
+ ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`'
+ )
+
+ if getattr(config, "alibi", False) or getattr(config, "use_alibi", False):
+ raise ValueError("Model is configured to use ALiBi, which is not supported by Flash Attention 3.")
+
+ # Check for attention dropout, which is incompatible with FA3
+ if hasattr(config, "attention_dropout") and config.attention_dropout > 0:
+ raise ValueError(
+ f"Model has attention_dropout={config.attention_dropout}, which is not supported by Flash Attention 3."
+ )
+
+ # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
+ # or the model may be initialized under the context manager `with torch.device("cuda"):`.
+ if check_device_map and device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
+ if torch.cuda.is_available():
+ logger.warning_once(
+ "You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU"
+ " after initializing it on CPU with `model.to('cuda')`."
+ )
+ else:
+ raise ValueError(
+ "You are attempting to use Flash Attention 3 with a model not initialized on GPU and with no GPU available. "
+ "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
+ "or initialising the model on CPU and then moving it to GPU."
+ )
+ elif (
+ check_device_map
+ and device_map is not None
+ and isinstance(device_map, dict)
+ and ("cpu" in device_map.values() or "disk" in device_map.values())
+ ):
+ raise ValueError(
+ "You are attempting to use Flash Attention 3 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
+ "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
+ )
+ if not hard_check_only:
+ config._attn_implementation = "flash_attention_3"
+ return config
+
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
"""
@@ -4134,7 +4236,7 @@ def from_pretrained(
</Tip>
attn_implementation (`str`, *optional*):
- The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
+ The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)), or `"flash_attention_3"` (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
> Parameters for big model inference
@@ -5770,6 +5872,7 @@ class AttentionInterface(GeneralInterface):
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
# a new instance is created (in order to locally override a given function)
_global_mapping = {
+ "flash_attention_3": flash_attention_forward,
"flash_attention_2": flash_attention_forward,
"flex_attention": flex_attention_forward,
"paged_attention": paged_attention_forward,
diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py
index dc8b7880c418..c224c4300eb4 100644
--- a/src/transformers/models/arcee/modeling_arcee.py
+++ b/src/transformers/models/arcee/modeling_arcee.py
@@ -321,6 +321,7 @@ class ArceePreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["ArceeDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py
index f62069a09f4c..87f11d192693 100644
--- a/src/transformers/models/aria/modeling_aria.py
+++ b/src/transformers/models/aria/modeling_aria.py
@@ -667,6 +667,7 @@ class AriaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["AriaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py
index f526802bfca9..afafd3f91188 100644
--- a/src/transformers/models/bitnet/modeling_bitnet.py
+++ b/src/transformers/models/bitnet/modeling_bitnet.py
@@ -318,6 +318,7 @@ class BitNetPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["BitNetDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 88ca4e31de10..ad1604bed4a3 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -355,6 +355,7 @@ class CoherePreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["CohereDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py
index 6999f1632f95..3fec29e97609 100644
--- a/src/transformers/models/cohere2/modeling_cohere2.py
+++ b/src/transformers/models/cohere2/modeling_cohere2.py
@@ -334,6 +334,7 @@ class Cohere2PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Cohere2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
index 6eb506218918..541ae6669e92 100644
--- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
+++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
@@ -504,6 +504,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["DeepseekV3DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py
index fae9f2dbb95c..383c329c9909 100644
--- a/src/transformers/models/diffllama/modeling_diffllama.py
+++ b/src/transformers/models/diffllama/modeling_diffllama.py
@@ -556,6 +556,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["DiffLlamaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = False
diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py
index b10fae6dbc8d..58b805cca613 100644
--- a/src/transformers/models/dots1/modeling_dots1.py
+++ b/src/transformers/models/dots1/modeling_dots1.py
@@ -424,6 +424,7 @@ class Dots1PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Dots1DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 1f8da9ed0ece..04b438c5ab4f 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -318,6 +318,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["GemmaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 7008538c7ab0..bfd3317946be 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -339,6 +339,7 @@ class Gemma2PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Gemma2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index db15678c25c9..084ef0893a7c 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -422,6 +422,7 @@ class Gemma3PreTrainedModel(PreTrainedModel):
"SiglipMultiheadAttentionPoolingHead",
]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index 2ee6273c00d4..86538fc25e58 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -335,6 +335,7 @@ class GlmPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["GlmDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py
index 75487c5fccff..55cc8869d952 100644
--- a/src/transformers/models/glm4/modeling_glm4.py
+++ b/src/transformers/models/glm4/modeling_glm4.py
@@ -343,6 +343,7 @@ class Glm4PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Glm4DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index d3c5141371b2..2e563e401f23 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -292,6 +292,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["GPTNeoXLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py
index d1d69f9579c0..b65530c40613 100644
--- a/src/transformers/models/granite/modeling_granite.py
+++ b/src/transformers/models/granite/modeling_granite.py
@@ -305,6 +305,7 @@ class GranitePreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["GraniteDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py
index 31d9f963049f..3a48d931ca1e 100644
--- a/src/transformers/models/helium/modeling_helium.py
+++ b/src/transformers/models/helium/modeling_helium.py
@@ -320,6 +320,7 @@ class HeliumPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["HeliumDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 3a200ad988b8..e79a76976028 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -320,6 +320,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py
index 0709d31f558b..66ed4adcea4c 100644
--- a/src/transformers/models/minimax/modeling_minimax.py
+++ b/src/transformers/models/minimax/modeling_minimax.py
@@ -590,6 +590,7 @@ class MiniMaxPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["MiniMaxDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 2576c85a785a..4b222eabe237 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -262,6 +262,7 @@ class MistralPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["MistralDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index ae0fd74e5665..526bf2bbd756 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -417,6 +417,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["MixtralDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index c35988e2b8d2..fc6a7188623a 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -301,6 +301,7 @@ class OlmoPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["OlmoDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py
index 8e69f43d3ebc..84f5e5ad4e8a 100644
--- a/src/transformers/models/olmo2/modeling_olmo2.py
+++ b/src/transformers/models/olmo2/modeling_olmo2.py
@@ -305,6 +305,7 @@ class Olmo2PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Olmo2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index 95164a5f5dbd..1c5136044065 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -295,6 +295,7 @@ class PhiPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["PhiDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index 79703927021f..54fd3d1caf73 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -316,6 +316,7 @@ class Phi3PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Phi3DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py
index a9a902598c1f..27c199bf50a0 100644
--- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py
+++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py
@@ -1622,6 +1622,7 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Phi4MultimodalDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index aaebc3c82bd1..4ba0b43e134f 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -266,6 +266,7 @@ class Qwen2PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Qwen2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py
index 6da044857043..e64f96675977 100644
--- a/src/transformers/models/qwen3/modeling_qwen3.py
+++ b/src/transformers/models/qwen3/modeling_qwen3.py
@@ -292,6 +292,7 @@ class Qwen3PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Qwen3DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
index 329da67a1e64..47ec0d10ab12 100644
--- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
+++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -424,6 +424,7 @@ class Qwen3MoePreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Qwen3MoeDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index b0179a518bbf..1e1d9c643632 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -299,6 +299,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Starcoder2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py
index 7f3ce0927a50..a6cec1c09972 100644
--- a/src/transformers/models/t5gemma/modeling_t5gemma.py
+++ b/src/transformers/models/t5gemma/modeling_t5gemma.py
@@ -561,6 +561,7 @@ class T5GemmaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["T5GemmaBlock"]
_skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 1a4232adc8c5..2ddbd51d4140 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -86,6 +86,7 @@
is_faiss_available,
is_fbgemm_gpu_available,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flax_available,
is_flute_available,
is_fsdp_available,
@@ -571,6 +572,15 @@ def require_flash_attn(test_case):
return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
+def require_flash_attn_3(test_case):
+ """
+ Decorator marking a test that requires Flash Attention 3.
+
+ These tests are skipped when Flash Attention 3 isn't installed.
+ """
+ return unittest.skipUnless(is_flash_attn_3_available(), "test requires Flash Attention 3")(test_case)
+
+
def require_torch_sdpa(test_case):
"""
Decorator marking a test that requires PyTorch's SDPA.
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 6d73b8d0325b..7ca4c3552808 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -153,6 +153,7 @@
is_faiss_available,
is_fbgemm_gpu_available,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_flax_available,
diff --git a/src/transformers/utils/args_doc.py b/src/transformers/utils/args_doc.py
index 00cf4009fa55..61f947516ff7 100644
--- a/src/transformers/utils/args_doc.py
+++ b/src/transformers/utils/args_doc.py
@@ -926,6 +926,9 @@ class ClassAttrs:
_skip_keys_device_placement = r"""
A list of keys to ignore when moving inputs or outputs between devices when using the `accelerate` library.
"""
+ _supports_flash_attn_3 = r"""
+ Whether the model's attention implementation supports FlashAttention 3.0.
+ """
_supports_flash_attn_2 = r"""
Whether the model's attention implementation supports FlashAttention 2.0.
"""
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 7956f1b22d41..014366cc9778 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -1120,6 +1120,25 @@ def is_flash_attn_2_available():
return False
+@lru_cache()
+def is_flash_attn_3_available():
+ if not is_torch_available():
+ return False
+
+ if not _is_package_available("flash_attn_3"):
+ return False
+
+ import torch
+
+ if not torch.cuda.is_available():
+ return False
+
+ # TODO: Check for a minimum version when FA3 is stable
+ # return version.parse(importlib.metadata.version("flash_attn_3")) >= version.parse("3.0.0")
+
+ return True
+
+
@lru_cache
def is_flash_attn_greater_or_equal_2_10():
if not _is_package_available("flash_attn"):
diff --git a/tests/generation/test_flash_attention_parity.py b/tests/generation/test_flash_attention_parity.py
new file mode 100644
index 000000000000..187bdfe24cd9
--- /dev/null
+++ b/tests/generation/test_flash_attention_parity.py
@@ -0,0 +1,144 @@
+# Copyright 2025 Eduard Durech and SGLang team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Usage:
+# RUN_SLOW=1 pytest -s tests/generation/test_flash_attention_parity.py
+
+import unittest
+
+import pytest
+import torch
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.testing_utils import require_flash_attn, require_flash_attn_3, require_torch_gpu, slow
+
+
+class FlashAttentionParityTest(unittest.TestCase):
+ # From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
+ def _lcs(self, X, Y):
+ m = len(X)
+ n = len(Y)
+ L = [[0] * (n + 1) for _ in range(m + 1)]
+
+ for i in range(m + 1):
+ for j in range(n + 1):
+ if i == 0 or j == 0:
+ L[i][j] = 0
+ elif X[i - 1] == Y[j - 1]:
+ L[i][j] = L[i - 1][j - 1] + 1
+ else:
+ L[i][j] = max(L[i - 1][j], L[i][j - 1])
+
+ return L[m][n]
+
+ # From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
+ def _calculate_rouge_l(self, output_strs_list1, output_strs_list2):
+ rouge_l_scores = []
+
+ for s1, s2 in zip(output_strs_list1, output_strs_list2):
+ lcs_len = self._lcs(s1, s2)
+ precision = lcs_len / len(s1) if len(s1) > 0 else 0
+ recall = lcs_len / len(s2) if len(s2) > 0 else 0
+ if precision + recall > 0:
+ fmeasure = (2 * precision * recall) / (precision + recall)
+ else:
+ fmeasure = 0.0
+ rouge_l_scores.append(fmeasure)
+
+ return rouge_l_scores
+
+ def _benchmark_generation(self, model, inputs, n_warmup=3, n_runs=5):
+ for _ in range(n_warmup):
+ model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ torch.cuda.synchronize()
+
+ start_time = torch.cuda.Event(enable_timing=True)
+ end_time = torch.cuda.Event(enable_timing=True)
+
+ start_time.record()
+ for _ in range(n_runs):
+ model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ end_time.record()
+ torch.cuda.synchronize()
+
+ return start_time.elapsed_time(end_time) / n_runs
+
+ @pytest.mark.flash_attn_3_test
+ @require_torch_gpu
+ @require_flash_attn
+ @require_flash_attn_3
+ @slow
+ def test_flash_attention_2_3_parity(self):
+ model_id = "meta-llama/Llama-3.2-1B-Instruct"
+ prompt = "The ETH AI Center is"
+
+ # 1. Load FA2 model and tokenizer
+ model_2 = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ torch_dtype=torch.bfloat16,
+ attn_implementation="flash_attention_2",
+ ).to("cuda")
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+ # 2. Load FA3 model
+ try:
+ model_3 = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ torch_dtype=torch.bfloat16,
+ attn_implementation="flash_attention_3",
+ ).to("cuda")
+ except (ValueError, ImportError) as e:
+ pytest.skip(f"Could not load Flash Attention 3 model, skipping test. Error: {e}")
+
+ # 3. Generate with both models
+ inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+
+ with torch.no_grad():
+ output_2 = model_2.generate(
+ **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
+ )
+ output_3 = model_3.generate(
+ **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
+ )
+
+ # 4. Correctness check
+ # 4a. Logits
+ logits_2 = torch.stack(output_2.scores)
+ logits_3 = torch.stack(output_3.scores)
+ torch.testing.assert_close(logits_2, logits_3, atol=1e-3, rtol=1e-3)
+ logprobs_2 = torch.nn.functional.log_softmax(logits_2, dim=-1)
+ logprobs_3 = torch.nn.functional.log_softmax(logits_3, dim=-1)
+ max_logprob_diff = torch.max(torch.abs(logprobs_2 - logprobs_3)).item()
+
+ # 4b. Generated text
+ text_2 = tokenizer.decode(output_2.sequences[0], skip_special_tokens=True)
+ text_3 = tokenizer.decode(output_3.sequences[0], skip_special_tokens=True)
+ rouge_score = self._calculate_rouge_l([text_2], [text_3])[0]
+ assert rouge_score > 0.99, f"Generated texts do not match (ROUGE-L: {rouge_score})"
+
+ # 5. Performance check
+ with torch.no_grad():
+ time_2 = self._benchmark_generation(model_2, inputs)
+ time_3 = self._benchmark_generation(model_3, inputs)
+
+ print(f"\n--- Flash Attention {2, 3} Parity Test on {model_id} ---")
+ print(f"Prompt: '{prompt}'")
+ print(f"Generated text with Flash Attention 2: {text_2}")
+ print(f"Generated text with Flash Attention 3: {text_3}")
+ print(f"ROUGE-L: {rouge_score}")
+ print(f"Max absolute difference in logprobs: {max_logprob_diff:.5e}")
+ print(f"Flash Attention 2 latency: {time_2:.2f} ms")
+ print(f"Flash Attention 3 latency: {time_3:.2f} ms")
+ print(f"Speed-up: {time_2 / time_3:.2f}x")
+ print("---")
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index e92d1e1ec77a..840d2e66e753 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -34,6 +34,7 @@
is_flaky,
require_accelerate,
require_flash_attn,
+ require_flash_attn_3,
require_optimum_quanto,
require_read_token,
require_torch,
@@ -2292,6 +2293,7 @@ def _test_attention_implementation(self, attn_implementation):
support_flag = {
"sdpa": "_supports_sdpa",
"flash_attention_2": "_supports_flash_attn_2",
+ "flash_attention_3": "_supports_flash_attn_3",
}
for model_class in self.all_generative_model_classes:
@@ -2369,6 +2371,14 @@ def test_eager_matches_fa2_generate(self):
"""Tests that generate has equivalent outputs with FA2 and eager attention implementations."""
self._test_attention_implementation("flash_attention_2")
+ @pytest.mark.flash_attn_3_test
+ @require_flash_attn_3
+ @require_torch_gpu
+ @slow
+ def test_eager_matches_fa3_generate(self):
+ """Tests that generate has equivalent outputs with FA3 and eager attention implementations."""
+ self._test_attention_implementation("flash_attention_3")
+
def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
internal_batch_size = (
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index f7183089044e..a5d9c9006809 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -84,6 +84,7 @@
require_bitsandbytes,
require_deepspeed,
require_flash_attn,
+ require_flash_attn_3,
require_non_hpu,
require_safetensors,
require_torch,
@@ -3129,18 +3130,19 @@ def test_model_is_small(self):
f"{model_class} is too big for the common tests ({num_params})! It should have 1M max."
)
- @require_flash_attn
- @require_torch_gpu
- @mark.flash_attn_test
- @slow
- @is_flaky()
- def test_flash_attn_2_inference_equivalence(self):
+ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_side: str):
+ r"""
+ Tests the equivalence between the eager and flash attention implementations.
+ This test is only for inference and runs with `torch_dtype=torch.bfloat16`.
+ """
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
for model_class in self.all_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+ if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
+ attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
+ ):
+ self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
@@ -3148,7 +3150,7 @@ def test_flash_attn_2_inference_equivalence(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation
)
model_fa.to(torch_device)
@@ -3163,9 +3165,12 @@ def test_flash_attn_2_inference_equivalence(self):
if dummy_attention_mask is not None:
dummy_attention_mask = dummy_attention_mask[:1]
- dummy_attention_mask[:, 1:] = 1
- dummy_attention_mask[:, :1] = 0
-
+ if padding_side == "left":
+ dummy_attention_mask[:, 1:] = 1
+ dummy_attention_mask[:, :1] = 0
+ else:
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
if model.config.is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
@@ -3220,104 +3225,46 @@ def test_flash_attn_2_inference_equivalence(self):
else outputs_fa.decoder_hidden_states[-1]
)
- assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
+ if padding_side == "left":
+ assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
- # check with inference + dropout
- model.train()
- _ = model_fa(dummy_input, **other_inputs)
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(dummy_input, **other_inputs)
+ else:
+ assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
@is_flaky()
- def test_flash_attn_2_inference_equivalence_right_padding(self):
- if not self.has_attentions:
- self.skipTest(reason="Model architecture does not support attentions")
-
- for model_class in self.all_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
- )
- model_fa.to(torch_device)
-
- model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
- model.to(torch_device)
-
- dummy_input = inputs_dict[model.main_input_name][:1]
- if dummy_input.dtype in [torch.float32, torch.float16]:
- dummy_input = dummy_input.to(torch.bfloat16)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", None)
-
- if dummy_attention_mask is not None:
- dummy_attention_mask = dummy_attention_mask[:1]
- dummy_attention_mask[:, :-1] = 1
- dummy_attention_mask[:, -1:] = 0
-
- if model.config.is_encoder_decoder:
- decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
-
- outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
- outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
- else:
- outputs = model(dummy_input, output_hidden_states=True)
- outputs_fa = model_fa(dummy_input, output_hidden_states=True)
-
- logits = (
- outputs.hidden_states[-1]
- if not model.config.is_encoder_decoder
- else outputs.decoder_hidden_states[-1]
- )
- logits_fa = (
- outputs_fa.hidden_states[-1]
- if not model.config.is_encoder_decoder
- else outputs_fa.decoder_hidden_states[-1]
- )
-
- assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
-
- if model.config.is_encoder_decoder:
- other_inputs = {
- "decoder_input_ids": decoder_input_ids,
- "decoder_attention_mask": dummy_attention_mask,
- "output_hidden_states": True,
- }
- if dummy_attention_mask is not None:
- other_inputs["attention_mask"] = dummy_attention_mask
-
- outputs = model(dummy_input, **other_inputs)
- outputs_fa = model_fa(dummy_input, **other_inputs)
- else:
- other_inputs = {
- "output_hidden_states": True,
- }
- if dummy_attention_mask is not None:
- other_inputs["attention_mask"] = dummy_attention_mask
+ def test_flash_attn_2_inference_equivalence(self):
+ self.flash_attn_inference_equivalence(attn_implementation="flash_attention_2", padding_side="left")
- outputs = model(dummy_input, **other_inputs)
- outputs_fa = model_fa(dummy_input, **other_inputs)
+ @require_flash_attn
+ @require_torch_gpu
+ @mark.flash_attn_test
+ @slow
+ @is_flaky()
+ def test_flash_attn_2_inference_equivalence_right_padding(self):
+ self.flash_attn_inference_equivalence(attn_implementation="flash_attention_2", padding_side="right")
- logits = (
- outputs.hidden_states[-1]
- if not model.config.is_encoder_decoder
- else outputs.decoder_hidden_states[-1]
- )
- logits_fa = (
- outputs_fa.hidden_states[-1]
- if not model.config.is_encoder_decoder
- else outputs_fa.decoder_hidden_states[-1]
- )
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ @is_flaky()
+ def test_flash_attn_3_inference_equivalence(self):
+ self.flash_attn_inference_equivalence(attn_implementation="flash_attention_3", padding_side="left")
- assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ @is_flaky()
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ self.flash_attn_inference_equivalence(attn_implementation="flash_attention_3", padding_side="right")
def test_attn_implementation_composite_models(self):
"""
@@ -3959,24 +3906,21 @@ def test_sdpa_matches_eager_sliding_window(self):
torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-4, atol=1e-4)
)
- @require_flash_attn
- @require_torch_gpu
- @mark.flash_attn_test
- def test_flash_attn_2_can_dispatch_composite_models(self):
+ def flash_attn_can_dispatch_composite_models(self, attn_implementation: str):
"""
- Tests if composite models can dispatch on FA2 if the sub-models support FA2.
+ Tests if composite models can dispatch on flash attention if the sub-models support it.
The tests is needed as we handle differently composite models and we cannot check them
- with above tests. If any of the sub-models does not support FA2, we'll raise an error when dispatching
+ with above tests. If any of the sub-models does not support flash attention, we'll raise an error when dispatching
that particular sub-model. Otherwise we dispatch safely in all sub-models, where "sub-models" are specific
backbone models (LM/vision/audio/etc)
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
- if not is_torch_fp16_available_on_device(torch_device):
- self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
+ if not is_torch_bf16_available_on_device(torch_device):
+ self.skipTest(f"bfloat16 not supported on {torch_device} (on the specific device currently used)")
- torch_dtype = torch.float16
+ torch_dtype = torch.bfloat16
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
@@ -3987,44 +3931,64 @@ def test_flash_attn_2_can_dispatch_composite_models(self):
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
- sub_models_supporting_fa2 = [
- module._supports_flash_attn_2
+ sub_models_supporting_fa = [
+ (
+ module._supports_flash_attn_3
+ if attn_implementation == "flash_attention_3"
+ else module._supports_flash_attn_2
+ )
for name, module in model.named_modules()
if isinstance(module, PreTrainedModel) and name != ""
]
- supports_fa2_all_modules = (
- all(sub_models_supporting_fa2)
- if len(sub_models_supporting_fa2) > 0
- else model._supports_flash_attn_2
+ supports_fa_all_modules = (
+ all(sub_models_supporting_fa)
+ if len(sub_models_supporting_fa) > 0
+ else (
+ model._supports_flash_attn_3
+ if attn_implementation == "flash_attention_3"
+ else model._supports_flash_attn_2
+ )
)
- if not supports_fa2_all_modules:
+ if not supports_fa_all_modules:
with self.assertRaises(ValueError):
- model_fa2 = model_class.from_pretrained(
+ model_fa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
- attn_implementation="flash_attention_2",
+ attn_implementation=attn_implementation,
)
else:
- model_fa2 = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch_dtype, attn_implementation="flash_attention_2"
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch_dtype, attn_implementation=attn_implementation
)
- for key in model_fa2.config:
- if isinstance(getattr(model_fa2.config, key), PretrainedConfig):
- sub_config = getattr(model_fa2.config, key)
- self.assertTrue(sub_config._attn_implementation == "flash_attention_2")
+ for key in model_fa.config:
+ if isinstance(getattr(model_fa.config, key), PretrainedConfig):
+ sub_config = getattr(model_fa.config, key)
+ self.assertTrue(sub_config._attn_implementation == attn_implementation)
- has_fa2 = False
- for name, submodule in model_fa2.named_modules():
+ has_fa = False
+ for name, submodule in model_fa.named_modules():
class_name = submodule.__class__.__name__
if (
"Attention" in class_name
and getattr(submodule, "config", None)
- and submodule.config._attn_implementation == "flash_attention_2"
+ and submodule.config._attn_implementation == attn_implementation
):
- has_fa2 = True
+ has_fa = True
break
- if not has_fa2:
- raise ValueError("The FA2 model should have FA2 layers")
+ if not has_fa:
+ raise ValueError(f"The {attn_implementation} model should have {attn_implementation} layers")
+
+ @require_flash_attn
+ @require_torch_gpu
+ @mark.flash_attn_test
+ def test_flash_attn_2_can_dispatch_composite_models(self):
+ self.flash_attn_can_dispatch_composite_models(attn_implementation="flash_attention_2")
+
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ def test_flash_attn_3_can_dispatch_composite_models(self):
+ self.flash_attn_can_dispatch_composite_models(attn_implementation="flash_attention_3")
@require_flash_attn
@require_torch_gpu
@@ -4121,27 +4085,29 @@ def test_flash_attn_2_can_compile_with_attention_mask_None_without_graph_break(s
assert not loss.isnan().any()
- @require_flash_attn
- @require_torch_gpu
- @mark.flash_attn_test
- @slow
- def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
+ def flash_attention_padding_matches_padding_free_with_position_ids(
+ self, attn_implementation: str, fa_kwargs: bool = False
+ ):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+ if not (
+ model_class._supports_flash_attn_2
+ if attn_implementation == "flash_attention_2"
+ else model_class._supports_flash_attn_3
+ ):
+ self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
self.skipTest("Model dummy inputs should contain padding in their attention mask")
dummy_input = inputs_dict[model_class.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
@@ -4151,7 +4117,7 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
if "position_ids" not in inspect.signature(model.forward).parameters:
self.skipTest("Model does not support position_ids")
- if "position_ids" not in inspect.signature(model.forward).parameters:
+ if (not fa_kwargs) and "position_ids" not in inspect.signature(model.forward).parameters:
continue # this model doesn't accept position ids as input
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -4166,26 +4132,40 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
model = (
model_class.from_pretrained(
tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
+ torch_dtype=torch.bfloat16,
+ attn_implementation=attn_implementation,
)
.to(torch_device)
.eval()
)
- # flatten
- padfree_inputs_dict = {
- k: v[dummy_attention_mask.bool()].unsqueeze(0)
- for k, v in inputs_dict.items()
- if not k == "attention_mask"
- }
- # add position_ids
- padfree_inputs_dict["position_ids"] = (
- torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
- .long()
- .unsqueeze(0)
- .to(torch_device)
- )
+ if fa_kwargs:
+ # flatten
+ features = [
+ {"input_ids": i[a.bool()].tolist()}
+ for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
+ ]
+
+ # add position_ids + fa_kwargs
+ data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
+ batch = data_collator(features)
+ padfree_inputs_dict = {
+ k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()
+ }
+ else:
+ # flatten
+ padfree_inputs_dict = {
+ k: v[dummy_attention_mask.bool()].unsqueeze(0)
+ for k, v in inputs_dict.items()
+ if not k == "attention_mask"
+ }
+ # add position_ids
+ padfree_inputs_dict["position_ids"] = (
+ torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
+ .long()
+ .unsqueeze(0)
+ .to(torch_device)
+ )
res_padded = model(**inputs_dict)
res_padfree = model(**padfree_inputs_dict)
@@ -4195,119 +4175,96 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
# acceptable numerical instability
- tol = torch.finfo(torch.float16).eps
+ tol = torch.finfo(torch.bfloat16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
- def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
- if not self.has_attentions:
- self.skipTest(reason="Model architecture does not support attentions")
-
- max_new_tokens = 30
-
- for model_class in self.all_generative_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
- self.skipTest("Model dummy inputs should contain padding in their attention mask")
-
- dummy_input = inputs_dict[model_class.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
-
- # make sure that all models have enough positions for generation
- if hasattr(config, "max_position_embeddings"):
- config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
- model = model_class(config)
- if "position_ids" not in inspect.signature(model.forward).parameters:
- self.skipTest("Model does not support position_ids")
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
-
- # ensure left padding, to adapt for some models
- if 0 in inputs_dict["attention_mask"][:, -1]:
- inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
- dummy_attention_mask = inputs_dict["attention_mask"]
- inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
-
- model = (
- model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
- )
- .to(torch_device)
- .eval()
- )
-
- # flatten
- features = [
- {"input_ids": i[a.bool()].tolist()}
- for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
- ]
-
- # add position_ids + fa_kwargs
- data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
- batch = data_collator(features)
- batch_accelerator = {k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()}
-
- res_padded = model(**inputs_dict)
- res_padfree = model(**batch_accelerator)
-
- logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
- logits_padfree = res_padfree.logits[0]
-
- torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
- # acceptable numerical instability
- tol = torch.finfo(torch.float16).eps
- torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
+ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
+ self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_2")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
- def test_flash_attn_2_from_config(self):
+ def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
+ self.flash_attention_padding_matches_padding_free_with_position_ids(
+ attn_implementation="flash_attention_2", fa_kwargs=True
+ )
+
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attention_3_padding_matches_padding_free_with_position_ids(self):
+ self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_3")
+
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attention_3_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
+ self.flash_attention_padding_matches_padding_free_with_position_ids(
+ attn_implementation="flash_attention_3", fa_kwargs=True
+ )
+
+ def flash_attn_from_config(self, attn_implementation: str):
+ r"""
+ Tests if the model can be loaded with `attn_implementation` from the config and if the
+ weights are not randomly initialized.
+ """
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
for model_class in self.all_generative_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+ if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
+ attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
+ ):
+ self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# TODO: to change it in the future with other relevant auto classes
- fa2_model = model_class._from_config(
- config, attn_implementation="flash_attention_2", torch_dtype=torch.float16
+ fa_model = model_class._from_config(
+ config, attn_implementation=attn_implementation, torch_dtype=torch.bfloat16
).to(torch_device)
- dummy_input = inputs_dict[fa2_model.main_input_name]
- if dummy_input.dtype in [torch.float32, torch.bfloat16]:
- dummy_input = dummy_input.to(torch.float16)
+ dummy_input = inputs_dict[fa_model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
- if fa2_model.config.is_encoder_decoder:
+ if fa_model.config.is_encoder_decoder:
dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
- _ = fa2_model(
+ _ = fa_model(
dummy_input,
attention_mask=dummy_attention_mask,
decoder_input_ids=dummy_decoder_input_ids,
decoder_attention_mask=dummy_decoder_attention_mask,
)
else:
- _ = fa2_model(dummy_input, attention_mask=dummy_attention_mask)
+ _ = fa_model(dummy_input, attention_mask=dummy_attention_mask)
with tempfile.TemporaryDirectory() as tmpdirname:
- fa2_model.save_pretrained(tmpdirname)
+ fa_model.save_pretrained(tmpdirname)
model_from_pretrained = model_class.from_pretrained(tmpdirname)
- self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")
+ self.assertTrue(model_from_pretrained.config._attn_implementation != attn_implementation)
+
+ @require_flash_attn
+ @require_torch_gpu
+ @mark.flash_attn_test
+ @slow
+ def test_flash_attn_2_from_config(self):
+ self.flash_attn_from_config(attn_implementation="flash_attention_2")
+
+ @require_flash_attn
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_from_config(self):
+ self.flash_attn_from_config(attn_implementation="flash_attention_3")
def _get_custom_4d_mask_test_data(self):
# Sequence in which all but the last token is the same
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 903283dd4a9d..7df23e029591 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -77,6 +77,7 @@
)
from transformers.utils.import_utils import (
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flax_available,
is_tf_available,
is_torch_npu_available,
@@ -676,6 +677,9 @@ def test_model_from_pretrained_attn_implementation(self):
if is_flash_attn_available():
attn_implementation_available.append("flash_attention_2")
+ if is_flash_attn_3_available():
+ attn_implementation_available.append("flash_attention_3")
+
for requested_attn_implementation in attn_implementation_available:
model = AutoModelForCausalLM.from_pretrained(
TINY_MISTRAL, attn_implementation=requested_attn_implementation
@@ -700,6 +704,9 @@ def test_model_from_config_attn_implementation(self):
if is_flash_attn_available():
attn_implementation_available.append("flash_attention_2")
+ if is_flash_attn_3_available():
+ attn_implementation_available.append("flash_attention_3")
+
for requested_attn_implementation in attn_implementation_available:
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
# Ensure the config was set correctly