harness / diffs /40002.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index eabc6f2926d3..b83d5a973398 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -678,9 +678,10 @@ def prepare_inputs_for_generation(
if encoder_attention_mask is not None:
model_inputs["attention_mask"] = encoder_attention_mask
+ # 7. Prepare kwargs for flash attention to avoid recomputations
if "flash" in self.config._attn_implementation and self._supports_attention_backend:
- cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k = prepare_fa_kwargs_from_position_ids(
- position_ids, is_packed_sequence=False
+ (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids(
+ model_inputs["position_ids"], is_packed_sequence=False
)
model_inputs.update(
cu_seq_lens_q=cu_seq_lens_q.to(self.device),
@@ -689,12 +690,12 @@ def prepare_inputs_for_generation(
max_length_k=max_length_k,
)
- # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+ # 8. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
for key, value in kwargs.items():
if key not in model_inputs:
model_inputs[key] = value
- # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
+ # 9. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
model_inputs.pop("labels", None)
return model_inputs
diff --git a/src/transformers/integrations/npu_flash_attention.py b/src/transformers/integrations/npu_flash_attention.py
index ed1b30d9a6b0..716a3481a82a 100644
--- a/src/transformers/integrations/npu_flash_attention.py
+++ b/src/transformers/integrations/npu_flash_attention.py
@@ -10,20 +10,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import math
import os
import torch
-import torch.nn.functional as F
from ..utils.import_utils import is_torch_npu_available
if is_torch_npu_available():
- import math
-
- import torch_npu
- from einops import rearrange, repeat
- from torch_npu import npu_rotary_mul
+ from torch_npu import npu_fusion_attention, npu_rotary_mul
# FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
@@ -52,117 +48,6 @@ def is_npu_fa2_top_left_aligned_causal_mask():
return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False
-# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-class IndexFirstAxis(torch.autograd.Function):
- @staticmethod
- def forward(ctx, input, indices):
- ctx.save_for_backward(indices)
- assert input.ndim >= 2
- ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
- second_dim = other_shape.numel()
- # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
- # return input[indices]
- return torch.gather(
- rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
- ).reshape(-1, *other_shape)
-
- @staticmethod
- def backward(ctx, grad_output):
- (indices,) = ctx.saved_tensors
- assert grad_output.ndim >= 2
- other_shape = grad_output.shape[1:]
- grad_output = rearrange(grad_output, "b ... -> b (...)")
- grad_input = torch.zeros(
- [ctx.first_axis_dim, grad_output.shape[1]],
- device=grad_output.device,
- dtype=grad_output.dtype,
- )
- # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
- # grad_input[indices] = grad_output
- grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
- return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
-
-
-index_first_axis = IndexFirstAxis.apply
-
-
-# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-class IndexPutFirstAxis(torch.autograd.Function):
- @staticmethod
- def forward(ctx, values, indices, first_axis_dim):
- ctx.save_for_backward(indices)
- assert indices.ndim == 1
- assert values.ndim >= 2
- output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
- # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
- output[indices] = values
- # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
- return output
-
- @staticmethod
- def backward(ctx, grad_output):
- (indices,) = ctx.saved_tensors
- # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
- grad_values = grad_output[indices]
- # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
- return grad_values, None, None
-
-
-index_put_first_axis = IndexPutFirstAxis.apply
-
-
-# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-def pad_input(hidden_states, indices, batch, seqlen):
- """
- Arguments:
- hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
- indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
- batch: int, batch size for the padded sequence.
- seqlen: int, maximum sequence length for the padded sequence.
- Return:
- hidden_states: (batch, seqlen, ...)
- """
- # dim = hidden_states.shape[-1]
- # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
- # output[indices] = hidden_states
- output = index_put_first_axis(hidden_states, indices, batch * seqlen)
- return rearrange(output, "(b s) ... -> b s ...", b=batch)
-
-
-# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-def unpad_input(hidden_states, attention_mask, unused_mask=None):
- """
- Arguments:
- hidden_states: (batch, seqlen, ...)
- attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
- unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
- Return:
- hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
- indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
- cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
- max_seqlen_in_batch: int
- seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
- """
- all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
- seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
- used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
- indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
- max_seqlen_in_batch = seqlens_in_batch.max().item()
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
- # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
- # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
- # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
- # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
- # so we write custom forward and backward to make it a bit faster.
- return (
- index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
- indices,
- cu_seqlens,
- max_seqlen_in_batch,
- used_seqlens_in_batch,
- )
-
-
def npu_flash_attn_func(
q,
k,
@@ -179,11 +64,11 @@ def npu_flash_attn_func(
if not causal:
head_num = q.shape[2]
- output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
+ output = npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
else:
attn_mask_npu = get_attn_mask_npu(q.device)
head_num = q.shape[2]
- output = torch_npu.npu_fusion_attention(
+ output = npu_fusion_attention(
q,
k,
v,
@@ -218,7 +103,7 @@ def npu_flash_attn_varlen_func(
if not causal:
head_num = q.shape[1]
- output = torch_npu.npu_fusion_attention(
+ output = npu_fusion_attention(
q,
k,
v,
@@ -234,7 +119,7 @@ def npu_flash_attn_varlen_func(
else:
attn_mask_npu = get_attn_mask_npu(q.device)
head_num = q.shape[1]
- output = torch_npu.npu_fusion_attention(
+ output = npu_fusion_attention(
q,
k,
v,
@@ -267,8 +152,3 @@ def npu_apply_rotary_emb(x, cos, sin, **kwargs):
sin = sin.unsqueeze(0).unsqueeze(2)
return npu_rotary_mul(x, cos, sin)
-
-
-def get_npu_flash_attn_funcs():
- # return flash attention related functions used for Ascend NPU in order
- return npu_flash_attn_func, npu_flash_attn_varlen_func, pad_input, unpad_input, False
diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index e845e0cbc4a4..0d8906076829 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,17 +14,15 @@
import inspect
import os
import warnings
+from functools import partial
from typing import Optional, TypedDict
import torch
import torch.nn.functional as F
-from transformers.utils.import_utils import is_kernels_available
-
from .utils import (
is_flash_attn_2_available,
is_flash_attn_3_available,
- is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_torch_npu_available,
logging,
@@ -34,18 +32,135 @@
logger = logging.get_logger(__name__)
-def _index_first_axis(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
- reshaped = tensor.contiguous().reshape(-1, *tensor.shape[2:])
- return reshaped[indices]
+# TODO Deprecate when all models have the attention interface
+def flash_attn_supports_top_left_mask():
+ if is_flash_attn_3_available():
+ return False
+ if is_flash_attn_2_available():
+ return not is_flash_attn_greater_or_equal_2_10()
+
+ from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask
+
+ return is_npu_fa2_top_left_aligned_causal_mask()
+
+
+# TODO Deprecate when all models have the attention interface
+def is_flash_attn_available():
+ return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()
+
+
+# `globals()` is not compatible with dynamo, hence we have to define them in global scope ourselves
+_flash_fn = None
+_flash_varlen_fn = None
+_pad_fn = None
+_unpad_fn = None
+
+# function that processes kwargs, generalized to handle any supported kwarg within the function
+_process_flash_kwargs_fn = None
+# exceptions where hf API doesn't match the original flash attention API
+_hf_api_to_flash_mapping = {
+ "dropout": "dropout_p",
+ "sliding_window": "window_size",
+}
+
+
+def _lazy_imports(implementation: Optional[str]):
+    """
+    Lazy loads the flash attention functions for `implementation` (if `None`, picks FA2/FA3 based on what is available).
+
+    Return:
+        flash_attn_func: The base flash attention function.
+        flash_attn_varlen_func: The flash attention function supporting variable sequence lengths,
+                                e.g. for padding-free training.
+        pad_input: The function to re-pad unpadded outputs back to `(batch, seqlen, ...)` using the unpad indices.
+        unpad_input: The function to unpad inputs into one flat sequence, returning the indices/lengths to re-pad.
+    """
+    is_fa2 = is_flash_attn_2_available()
+    is_fa3 = is_flash_attn_3_available()
+    if implementation == "flash_attention_2" or (implementation is None and is_fa2 and not is_fa3):
+        from flash_attn import flash_attn_func, flash_attn_varlen_func
+        from flash_attn.bert_padding import pad_input, unpad_input
+    else:
+        pad_input, unpad_input = _pad_input, _unpad_input
+        if implementation == "flash_attention_3" or (implementation is None and is_fa3):
+            from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
+        elif is_torch_npu_available():
+            from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
+            from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
+        # Kernels fallback: `implementation` must expose both functions as attributes, else a `ValueError` is raised
+        else:
+            flash_attn_func = getattr(implementation, "flash_attn_func", None)
+            flash_attn_varlen_func = getattr(implementation, "flash_attn_varlen_func", None)
+            if flash_attn_varlen_func is None or flash_attn_func is None:
+                raise ValueError(
+                    f"Could not find the currently requested flash attention implementation at `{implementation}`. "
+                    "Make sure that you request a valid kernel from the hub, e.g. `kernels-community/flash-attn`."
+                )
+
+    return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input
+
+
+def _lazy_define_process_function(flash_function):
+ """
+ Depending on the version and kernel some features are not supported. Due to limitations in
+ `torch.compile`, we opt to statically type which (optional) kwarg parameters are supported
+ within `_process_flash_attention_kwargs`.
+
+ NOTE: While all supported kwargs are marked as `True`, everything else is marked as `False`.
+ This might be confusing for kwargs that we use in any case, e.g. `is_causal`.
+ """
+ global _process_flash_kwargs_fn, _hf_api_to_flash_mapping
+
+ flash_parameters = inspect.signature(flash_function).parameters
+ process_parameters = inspect.signature(_process_flash_attention_kwargs).parameters
+
+ supports_mapping = {}
+ for param in process_parameters:
+ fa_param = _hf_api_to_flash_mapping.get(param, param)
+ supports_mapping[fa_param] = fa_param in flash_parameters
+
+ return partial(_process_flash_attention_kwargs, supports_mapping=supports_mapping)
+
+
+def lazy_import_flash_attention(implementation: Optional[str]):
+ """
+ Lazily loads flash attention; returns `(flash_fn, flash_varlen_fn, pad_fn, unpad_fn)` and the kwargs-processing function.
+
+ NOTE: For fullgraph, this needs to be called before compile, while compiling without fullgraph
+ can work without preloading. See `_check_and_adjust_attn_implementation` in `modeling_utils`.
+ """
+ global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn  # cached: only imported while any is still `None`
+ if any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
+ _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation)
+
+ global _process_flash_kwargs_fn
+ if _process_flash_kwargs_fn is None:
+ _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn)
+ return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn
-def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
+
+def _index_first_axis(tensor, indices):
+ """
+ A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
+ after flattening the first two dimensions of the tensor. This is functionally equivalent to
+ FA2's `index_first_axis` and replaces the need to import it.
"""
- FA3-compatible unpad_input function.
+ # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
+ # two dimensions to get (total_tokens, ...) before indexing.
+ reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
+ return reshaped_tensor[indices]
+
+
+def _unpad_input(hidden_states, attention_mask, unused_mask=None):
+ """
+ unpad_input function for flash attention variants that do not ship one in their own package, e.g. fa3.
+
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
+
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
@@ -69,14 +184,16 @@ def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
)
-def _fa3_pad_input(hidden_states, indices, batch, seqlen):
+def _pad_input(hidden_states, indices, batch, seqlen):
"""
- FA3-compatible pad_input function.
+ pad_input function for flash attention variants that do not ship one in their own package, e.g. fa3.
+
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
batch: int, batch size for the padded sequence.
seqlen: int, maximum sequence length for the padded sequence.
+
Return:
hidden_states: (batch, seqlen, ...)
"""
@@ -89,9 +206,11 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
"""
Retrieves indexing data required to repad unpadded (ragged) tensors.
+
Arguments:
attention_mask (`torch.Tensor`):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
+
Return:
indices (`torch.Tensor`):
The indices of non-masked tokens from the flattened input sequence.
@@ -125,6 +244,7 @@ def _upad_input(
Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
tensors for query, key, value tensors.
+
Arguments:
query_layer (`torch.Tensor`):
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
@@ -138,6 +258,7 @@ def _upad_input(
Target length.
unpad_input_func:
The function to use for unpadding the input tensors.
+
Return:
query_layer (`torch.Tensor`):
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
@@ -193,13 +314,15 @@ def _upad_input(
def prepare_fa_kwargs_from_position_ids(position_ids, is_packed_sequence: bool = True):
"""
This function returns all the necessary kwargs to call `flash_attn_varlen_func`
- extracted from position_ids.The `position_ids` can be either packed sequence or
- the usual padded position ids, for example in inference time..
+ extracted from position_ids. The `position_ids` can be either packed sequence or
+ the usual padded position ids, for example in inference time.
+
Arguments:
position_ids (`torch.Tensor`):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
is_packed_sequence (`bool`, *optional*, defaults to `True`):
Whether the input position ids are a packed sequence or not.
+
Return:
(cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
The cumulative sequence lengths for the target (query) and source (key, value), used to index into
@@ -212,19 +335,21 @@ def prepare_fa_kwargs_from_position_ids(position_ids, is_packed_sequence: bool =
# In that case the position ids will not always start with `0` and we need a better way to infer
# cumulative seq lengths.
if not is_packed_sequence:
- tensor_kws = {"dtype": torch.int32, "device": position_ids.device}
- last_position_ids = position_ids[:, -1]
+ tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
+ last_position_ids = position_ids[:, -1]
+ q_len = (
+ torch.ones(position_ids.size(0), **tensor_kwargs)
+ if position_ids.shape[-1] == 1
+ else last_position_ids.add(1)
+ )
+ cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kwargs), q_len.cumsum(0).to(torch.int32)], 0)
cu_seq_lens_k = torch.cat(
- [torch.zeros(1, **tensor_kws), last_position_ids.cumsum(0).add(1).to(torch.int32)], 0
+ [torch.zeros(1, **tensor_kwargs), last_position_ids.add(1).cumsum(0).to(torch.int32)], 0
)
- max_length_k = int(last_position_ids.max()) + 1
- q_len = (
- torch.ones(position_ids.size(0), **tensor_kws) if position_ids.shape[-1] == 1 else last_position_ids.add(1)
- )
- cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0).to(torch.int32)], 0)
max_length_q = int(q_len.max())
+ max_length_k = int(last_position_ids.max()) + 1
else:
position_ids = position_ids.flatten()
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
@@ -237,16 +362,18 @@ def prepare_fa_kwargs_from_position_ids(position_ids, is_packed_sequence: bool =
)
cu_seq_lens_k = cu_seq_lens_q
+ # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
+ # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
+ # for some models (e.g. qwen2-vl).
+ max_length_q = cu_seq_lens_q.diff().max()
# NOTE: With torch compile, this will cause a graph break if you don't set
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
- # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
- # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
- # for some models (e.g. qwen2-vl).
- max_length_q = cu_seq_lens_q.diff().max().item()
+ max_length_q = max_length_q.item()
max_length_k = max_length_q
+
return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)
@@ -256,6 +383,7 @@ def _prepare_from_posids(query, key, value, position_ids, query_length):
All three query, key, value states will be flattened.
Cumulative lengths of each examples in the batch will be extracted from position_ids.
NOTE: ideally cumulative lengths should be prepared at the data collator stage
+
Arguments:
query (`torch.Tensor`):
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
@@ -267,6 +395,7 @@ def _prepare_from_posids(query, key, value, position_ids, query_length):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
query_length (`int`):
Sequence length of the input queries.
+
Return:
query (`torch.Tensor`):
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
@@ -275,121 +404,156 @@ def _prepare_from_posids(query, key, value, position_ids, query_length):
value (`torch.Tensor`):
Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
(cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
- The cumulative sequence lengths for the target (query) and source (key, value), used to index into
- ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+ The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
- Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
- `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
+ Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
"""
kv_length = key.shape[1]
+ is_packed_sequence = query_length == kv_length
+
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
- is_packed_sequence = query_length == kv_length
- cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k = prepare_fa_kwargs_from_position_ids(
+ (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids(
position_ids, is_packed_sequence=is_packed_sequence
)
+
return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k))
def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
warnings.warn(
- "prepare_fa2_from_position_ids is deprecated, use _prepare_from_posids",
+ "The function `_prepare_flash_attention_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_from_posids` instead.",
FutureWarning,
)
return _prepare_from_posids(query, key, value, position_ids)
-def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = None):
+def _is_packed_sequence(position_ids, batch_size):
+ """
+ Check whether the position ids indicate packed sequences. All of the following must hold:
+ 1. Position ids exist (`False` is returned otherwise)
+ 2. Only flattened sequences are supported, i.e. `batch_size == 1`
+ 3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences
+ """
+ if position_ids is None:
+ return False
+
+ increasing_position_sequences = (
+ torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min()
+ )
+ return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool()
+
+
+def fa_peft_integration_check(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ target_dtype: Optional[torch.dtype] = None,
+):
+ """
+ PEFT usually casts the layer norms to float32 for training stability reasons,
+ therefore the input hidden states get silently cast to float32. Hence, we need to
+ cast them back to float16 / bfloat16 just to be sure everything works as expected.
+ This might slow down training & inference, so it is recommended to not cast the LayerNorms!
+ """
if target_dtype and q.dtype == torch.float32:
logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.")
q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
return q, k, v
-def _lazy_imports(impl: Optional[str]):
- # returns funcs and pad/unpad based on impl
- is_fa2 = is_flash_attn_2_available()
- is_fa3 = is_flash_attn_3_available()
- if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3):
- try:
- from flash_attn import flash_attn_func, flash_attn_varlen_func
- from flash_attn.bert_padding import pad_input, unpad_input
-
- return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False
-
- except ImportError as e:
- if not globals().get("use_remote_fa2", None):
- use_remote_fa2 = (
- input(
- "Unable to import the official flash attention, do you want to try to use `kernels-community/flash-attn` (trust remote code) Yes or No? "
- )
- .strip()
- .lower()
- )
- globals()["use_remote_fa2"] = use_remote_fa2 in {"yes", "y", "1"}
- if globals()["use_remote_fa2"]:
- if not is_kernels_available():
- raise ImportError("You need to install kernels: `pip install kernels`")
- from kernels import get_kernel
-
- impl = get_kernel("kernels-community/flash-attn")
- pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
- return (
- getattr(impl, "flash_attn_func", None),
- getattr(impl, "flash_attn_varlen_func"),
- pad_input,
- unpad_input,
- True,
- )
-
- else:
- raise ImportError(
- "Failed to import flash attention 2, please install it or use another implementation."
- ) from e
- elif is_torch_npu_available():
- # get flash attention related functions from `.integrations.npu_flash_attention` module for Ascend NPU
- from .integrations.npu_flash_attention import get_npu_flash_attn_funcs
-
- return get_npu_flash_attn_funcs()
- elif impl == "flash_attention_3" or (impl is None and is_fa3):
- from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
-
- pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
- return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, True
- else:
- pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
- return (
- getattr(impl, "flash_attn_func", None),
- getattr(impl, "flash_attn_varlen_func"),
- pad_input,
- unpad_input,
- True,
- )
+class FlashAttentionKwargs(TypedDict, total=False):
+ """
+ Keyword arguments for Flash Attention with Compile.
+
+ Attributes:
+ cumulative_seqlens_q (`torch.LongTensor`, *optional*):
+ Cumulative sequence lengths for the query state.
+ cumulative_seqlens_k (`torch.LongTensor`, *optional*):
+ Cumulative sequence lengths for the key state.
+ max_length_q (`int`, *optional*):
+ Maximum sequence length for query state.
+ max_length_k (`int`, *optional*):
+ Maximum sequence length for key state.
+ """
+ cumulative_seqlens_q: Optional[torch.LongTensor]
+ cumulative_seqlens_k: Optional[torch.LongTensor]
+ max_length_q: Optional[int]
+ max_length_k: Optional[int]
-_flash_supports_window = None
+def _process_flash_attention_kwargs(
+ query_length: int,
+ key_length: int,
+ is_causal: bool,
+ dropout: float = 0.0,
+ softmax_scale: Optional[float] = None,
+ sliding_window: Optional[int] = None,
+ use_top_left_mask: bool = False,
+ softcap: Optional[float] = None,
+ deterministic: Optional[bool] = None,
+ s_aux: Optional[torch.Tensor] = None,
+ supports_mapping: Optional[dict[str, bool]] = None,
+ **kwargs,
+):
+ """
+ Returns a set of kwargs that are passed down to the according flash attention function based on
+ requested features and whether it is supported - depends on the version and kernel implementation
+ which is dynamically configured at `lazy_import_flash_attention`. The (un)supported features can be
+ inspected in `supports_mapping`, see `_lazy_define_process_function` for more details.
-def is_flash_attn_available():
- return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()
+ Args:
+ query_length (`int`):
+ Length of the query states
+ key_length (`int`):
+ Length of the key states
+ is_causal (`bool`):
+ Whether we perform causal (decoder) attention or full attention.
+ dropout (`float`):
+ Attention dropout.
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to `1 / sqrt(head_dim)`.
+ sliding_window (`int`, *optional*):
+ The size of the sliding window, i.e. we look at a max of `sliding_window` tokens back.
+ use_top_left_mask (`bool`):
+ Deprecated behavior of older versions of flash attention requiring different masking.
+ softcap (`float`, *optional*):
+ Softcap for the attention logits, used e.g. in gemma2.
+ deterministic (`bool`, *optional*):
+ Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
+ s_aux (`torch.Tensor`, *optional*):
+ Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head.
+ Return:
+ flash_kwargs (`dict`):
+ A dict of kwargs that are requested and supported.
+ """
+ flash_kwargs = {
+ "causal": is_causal and not (use_top_left_mask and query_length == 1),
+ "softmax_scale": softmax_scale,
+ }
+ if supports_mapping["dropout_p"]:
+ flash_kwargs["dropout_p"] = dropout
-def flash_attn_supports_top_left_mask():
- if is_flash_attn_3_available():
- return False
- if is_flash_attn_2_available():
- return not is_flash_attn_greater_or_equal_2_10()
+ if supports_mapping["window_size"] and sliding_window is not None and key_length > sliding_window:
+ flash_kwargs["window_size"] = (sliding_window, sliding_window)
- from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask
+ if supports_mapping["deterministic"]:
+ flash_kwargs["deterministic"] = (
+ deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+ )
- return is_npu_fa2_top_left_aligned_causal_mask()
+ if supports_mapping["softcap"] and softcap is not None:
+ flash_kwargs["softcap"] = softcap
+ # Only within kernel implementation atm
+ if supports_mapping["s_aux"] and s_aux is not None:
+ flash_kwargs["s_aux"] = s_aux
-class FlashAttentionKwargs(TypedDict, total=False):
- cumulative_seqlens_q: Optional[torch.LongTensor]
- cumulative_seqlens_k: Optional[torch.LongTensor]
+ return flash_kwargs
def _flash_attention_forward(
@@ -414,100 +578,121 @@ def _flash_attention_forward(
implementation: Optional[str] = None,
**kwargs,
):
- if not all(k in globals() for k in ("_flash_fn", "_flash_varlen_fn", "_pad_fn", "_unpad_fn", "_is_fa3")):
- flash_fn, flash_varlen_fn, pad_fn, unpad_fn, is_fa3 = _lazy_imports(implementation)
- globals()["_flash_fn"] = flash_fn
- globals()["_flash_varlen_fn"] = flash_varlen_fn
- globals()["_pad_fn"] = pad_fn
- globals()["_unpad_fn"] = unpad_fn
- globals()["_is_fa3"] = is_fa3
- flash_supports_window = "window_size" in inspect.signature(flash_varlen_fn).parameters
- globals()["_flash_supports_window"] = flash_supports_window
- else:
- flash_fn = globals()["_flash_fn"]
- flash_varlen_fn = globals()["_flash_varlen_fn"]
- pad_fn = globals()["_pad_fn"]
- unpad_fn = globals()["_unpad_fn"]
- is_fa3 = globals()["_is_fa3"]
- flash_supports_window = globals()["_flash_supports_window"]
-
- causal = is_causal and not (use_top_left_mask and query_length == 1)
- use_sw = (
- (_flash_supports_window or flash_supports_window) and sliding_window and key_states.shape[1] > sliding_window
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
+ first unpads the input, then computes the attention scores and pads the final attention scores.
+
+ (Optional) kwargs are described further in `_process_flash_attention_kwargs` and `FlashAttentionKwargs`.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`, *optional*):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ implementation (`str`, *optional*):
+ The attention implementation to use. If None, will default to the one based on the environment.
+ """
+ (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_attention(
+ implementation
)
- flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sw else {}
- if not is_fa3:
- flash_kwargs["dropout_p"] = dropout
- if is_flash_attn_greater_or_equal("2.4.1"):
- det = deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
- flash_kwargs["deterministic"] = det
- if softcap is not None:
- flash_kwargs["softcap"] = softcap
- if "s_aux" in kwargs:
- flash_kwargs["s_aux"] = kwargs.get("s_aux")
+
+ # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
query_states, key_states, value_states = fa_peft_integration_check(
query_states, key_states, value_states, target_dtype
)
- use_mask = position_ids is not None or all(
- k is not None for k in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]
+
+ # Extract the flash attention kwargs that have been requested (and are supported by the implementation)
+ flash_kwargs = process_flash_kwargs_fn(
+ query_length=query_length,
+ key_length=key_states.size(1),
+ is_causal=is_causal,
+ dropout=dropout,
+ softmax_scale=softmax_scale,
+ sliding_window=sliding_window,
+ use_top_left_mask=use_top_left_mask,
+ softcap=softcap,
+ deterministic=deterministic,
+ **kwargs,
+ )
+
+ # We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
+ # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`.
+ # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
+ # use `flash_varlen_fn` knowing we already have all the necessary kwargs.
+ #
+ # NOTE: it is the user's responsibility to take care of flattening `position_ids` if that's needed by the model.
+ # See #39121 for more information.
+ is_fa_with_position_ids = _is_packed_sequence(position_ids, batch_size=query_states.size(0))
+ is_fa_with_varlen_kwargs = all(
+ kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
)
+
+ # Contains at least one padding token in the sequence
if attention_mask is not None:
- q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input(
+ q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input(
query_states, key_states, value_states, attention_mask, query_length, unpad_fn
)
- # TODO for now this is required to work with https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.p
+
+ # TODO for now this is required to work with
+ # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
if "mps" in str(q.device):
- cu_k = cu_k.clone()
+ cu_seq_lens_k = cu_seq_lens_k.clone()
+
out_unpad = flash_varlen_fn(
q,
k,
v,
- cu_seqlens_q=cu_q.to(torch.int32),
- cu_seqlens_k=cu_k.to(torch.int32),
- max_seqlen_q=mq,
- max_seqlen_k=mk,
- softmax_scale=softmax_scale,
- causal=causal,
+ cu_seqlens_q=cu_seq_lens_q,
+ cu_seqlens_k=cu_seq_lens_k,
+ max_seqlen_q=max_length_q,
+ max_seqlen_k=max_length_k,
**flash_kwargs,
)
if isinstance(out_unpad, tuple):
out_unpad = out_unpad[0]
- out = pad_fn(out_unpad, idx, query_states.shape[0], query_length)
- elif use_mask:
+
+ out = pad_fn(out_unpad, indices_q, query_states.size(0), query_length)
+
+ # Padding free, i.e. sequences flattened into one total sequence
+ elif is_fa_with_varlen_kwargs or is_fa_with_position_ids:
if cu_seq_lens_q is None or cu_seq_lens_k is None:
- if position_ids is None:
- raise ValueError(
- "Position ids should be passed if the attention mask is not passed and the cu_seq-lens are not passed."
- )
- q, k, v, (cu_q, cu_k), (mq, mk) = _prepare_from_posids(
+ q, k, v, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _prepare_from_posids(
query_states, key_states, value_states, position_ids, query_length=query_length
)
else:
q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
- mq, mk = max_length_q, max_length_k
- cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k
+
+ # TODO for now this is required to work with
+ # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
if "mps" in str(q.device):
- cu_k = cu_k.clone()
+ cu_seq_lens_k = cu_seq_lens_k.clone()
+
out = flash_varlen_fn(
q,
k,
v,
- cu_seqlens_q=cu_q.to(torch.int32),
- cu_seqlens_k=cu_k.to(torch.int32),
- max_seqlen_q=mq,
- max_seqlen_k=mk,
- softmax_scale=softmax_scale,
- causal=causal,
+ cu_seqlens_q=cu_seq_lens_q,
+ cu_seqlens_k=cu_seq_lens_k,
+ max_seqlen_q=max_length_q,
+ max_seqlen_k=max_length_k,
**flash_kwargs,
)
if isinstance(out, tuple):
out = out[0]
- out = out.view(query_states.shape[0], -1, out.size(-2), out.size(-1))
+
+ out = out.view(query_states.size(0), -1, out.size(-2), out.size(-1))
+
+ # No padding
else:
- out = flash_fn(
- query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
- )
+ out = flash_fn(query_states, key_states, value_states, **flash_kwargs)
+ if isinstance(out, tuple):
+ out = out[0]
- return out[0] if isinstance(out, tuple) else out
+ return out
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index b15183b4821e..b8a7d6a44024 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -74,6 +74,7 @@
)
from .loss.loss_utils import LOSS_MAPPING
from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
+from .modeling_flash_attention_utils import lazy_import_flash_attention
from .pytorch_utils import ( # noqa: F401
Conv1D,
apply_chunking_to_forward,
@@ -2126,7 +2127,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
_pp_plan = None
# This flag signal that the model can be used as an efficient backend in TGI and vLLM
- # In practice, it means that they support attention interface functions, fully pass the kwargs
+ # In practice, it means that they support attention (mask) interface functions, fully pass the kwargs
# through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
_supports_attention_backend = False
_can_record_outputs = None
@@ -2748,6 +2749,7 @@ def _check_and_adjust_attn_implementation(
if attention_wrapper is None:
attention_wrapper = flash_attention_forward
kernel_function = partial(attention_wrapper, implementation=kernel)
+ lazy_import_flash_attention(kernel)
elif kernel_name is not None:
kernel_function = getattr(kernel, kernel_name)
ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
@@ -2763,7 +2765,13 @@ def _check_and_adjust_attn_implementation(
attn_implementation = "sdpa" # Try to fallback to sdpa in this case
return attn_implementation
else:
- return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)
+ attn_implementation = self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)
+
+ # preload flash attention here to allow compile with fullgraph
+ if applicable_attn_implementation.startswith("flash_attention"):
+ lazy_import_flash_attention(applicable_attn_implementation)
+
+ return attn_implementation
def get_correct_attn_implementation(self, _requested_attention: str, is_init_check: bool = False) -> str:
requested_attention = "sdpa" if _requested_attention is None else _requested_attention
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 750c5c22324d..b7ca0e2d9b42 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -3483,92 +3483,107 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn:
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
+ # Custom kernel which needs the mask interface to be properly usable on these models
+ if not model_class._supports_attention_backend and not attn_implementation.startswith("flash_attention"):
+ self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- config.head_dim = 64 # fa2 does not always support arbitrary headim
- model = model_class(config)
-
- model.to(torch_device)
- model.to(torch.bfloat16)
- dummy_input = inputs_dict[model.main_input_name][:1]
- if dummy_input.dtype in [torch.float32, torch.float16]:
- dummy_input = dummy_input.to(torch.bfloat16)
-
- dummy_attention_mask = inputs_dict.get("attention_mask", None)
- if dummy_attention_mask is not None:
- dummy_attention_mask = dummy_attention_mask[:1]
- if padding_side == "left":
- dummy_attention_mask[:, 1:] = 1
- dummy_attention_mask[:, :1] = 0
- else:
- dummy_attention_mask[:, :-1] = 1
- dummy_attention_mask[:, -1:] = 0
- if model.config.is_encoder_decoder:
- decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
+ # flash attention variants do not always support arbitrary head dims
+ config = self._prepare_config_headdim(config, 16)
- outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
- model.set_attn_implementation(attn_implementation)
- outputs_fa = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
- else:
- outputs = model(dummy_input, output_hidden_states=True)
- model.set_attn_implementation(attn_implementation)
- outputs_fa = model(dummy_input, output_hidden_states=True)
+ # TODO it is unclear why saving and reloading with dtype works while
+ # casting with `.to(dtype=..., device=...)` does not.
+ # Discovered on tests with `Bart` models.
+ model = model_class(config)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
- model.set_attn_implementation("sdpa")
- logits = (
- outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1]
- )
- logits_fa = (
- outputs_fa.hidden_states[-1]
- if not model.config.is_encoder_decoder
- else outputs_fa.decoder_hidden_states[-1]
- )
+ dummy_input = inputs_dict[model.main_input_name][:1]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
- assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
- if model.config.is_encoder_decoder:
- other_inputs = {
- "decoder_input_ids": decoder_input_ids,
- "decoder_attention_mask": dummy_attention_mask,
- "output_hidden_states": True,
- }
if dummy_attention_mask is not None:
- other_inputs["attention_mask"] = dummy_attention_mask
+ dummy_attention_mask = dummy_attention_mask[:1]
+ if padding_side == "left":
+ dummy_attention_mask[:, 1:] = 1
+ dummy_attention_mask[:, :1] = 0
+ else:
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+ if model.config.is_encoder_decoder:
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
- outputs = model(dummy_input, **other_inputs)
- model.set_attn_implementation(attn_implementation)
- outputs_fa = model(dummy_input, **other_inputs)
- else:
- other_inputs = {
- "output_hidden_states": True,
- }
- if dummy_attention_mask is not None:
- other_inputs["attention_mask"] = dummy_attention_mask
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ model.set_attn_implementation(attn_implementation)
+ outputs_fa = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ else:
+ outputs = model(dummy_input, output_hidden_states=True)
+ model.set_attn_implementation(attn_implementation)
+ outputs_fa = model(dummy_input, output_hidden_states=True)
+
+ model.set_attn_implementation("sdpa")
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
- outputs = model(dummy_input, **other_inputs)
- model.set_attn_implementation(attn_implementation)
- outputs_fa = model(dummy_input, **other_inputs)
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
- model.set_attn_implementation("sdpa")
- logits = (
- outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1]
- )
- logits_fa = (
- outputs_fa.hidden_states[-1]
- if not model.config.is_encoder_decoder
- else outputs_fa.decoder_hidden_states[-1]
- )
+ if model.config.is_encoder_decoder:
+ other_inputs = {
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": dummy_attention_mask,
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
- if padding_side == "left":
- assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
+ outputs = model(dummy_input, **other_inputs)
+ model.set_attn_implementation(attn_implementation)
+ outputs_fa = model(dummy_input, **other_inputs)
+ else:
+ other_inputs = {
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ model.set_attn_implementation(attn_implementation)
+ outputs_fa = model(dummy_input, **other_inputs)
+
+ model.set_attn_implementation("sdpa")
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
- # check with inference + dropout
- model.train()
- model.set_attn_implementation(attn_implementation)
- _ = model(dummy_input, **other_inputs)
- else:
- assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+ if padding_side == "left":
+ assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
+
+ # check with inference + dropout
+ model.train()
+ model.set_attn_implementation(attn_implementation)
+ _ = model(dummy_input, **other_inputs)
+ else:
+ assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@require_kernels
@require_torch_gpu
@@ -4698,6 +4713,70 @@ def recursively_check(eager_outputs, exported_outputs):
is_tested = recursively_check(eager_outputs, exported_outputs)
self.assertTrue(is_tested, msg=f"No outputs were compared for {model_class.__name__}")
+ @staticmethod
+ def _prepare_config_headdim(config, requested_dim):
+ """
+ This method allows updating the head dim for all model types including
+ composite models and models that do not support head dim by themselves.
+
+ Why? A lot of kernels including flex attention rely on triton for compilation.
+ However, triton cannot handle hidden dimensions of less than 16 for example.
+ (There are many more examples especially now that the `kernels` library is
+ supported)
+ """
+
+ def update_config_headdim(config, requested_dim):
+ # Flex Attention cannot use dropout
+ if hasattr(config, "attention_dropout"):
+ config.attention_dropout = 0
+ if hasattr(config, "attention_probs_dropout_prob"):
+ config.attention_probs_dropout_prob = 0
+
+ # Update the head dim and try to update hidden size as well if present in config
+ # NOTE: some models may have `None` for these values in a sub-config, thus we check for `Noneness`
+ head_dim = None
+ if hasattr(config, "head_dim") and config.head_dim is not None:
+ head_dim = config.head_dim
+ config.head_dim = max(requested_dim, config.head_dim)
+
+ cross_head_dim = None
+ if hasattr(config, "cross_head_dim") and config.cross_head_dim is not None:
+ cross_head_dim = config.cross_head_dim
+ config.cross_head_dim = max(requested_dim, config.cross_head_dim)
+
+ if (
+ getattr(config, "hidden_size", None) is not None
+ and getattr(config, "num_attention_heads", None) is not None
+ ):
+ head_dim = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads
+ config.hidden_size *= max(requested_dim // head_dim, 1)
+
+ if (
+ getattr(config, "decoder_hidden_size", None) is not None
+ and getattr(config, "decoder_num_attention_heads", None) is not None
+ ):
+ decoder_head_dim = config.decoder_hidden_size // config.decoder_num_attention_heads
+ config.decoder_hidden_size *= max(requested_dim // decoder_head_dim, 1)
+
+ if (
+ getattr(config, "cross_hidden_size", None) is not None
+ and getattr(config, "cross_num_attention_heads", None) is not None
+ ):
+ cross_head_dim = (
+ cross_head_dim
+ if cross_head_dim is not None
+ else config.cross_hidden_size // config.cross_num_attention_heads
+ )
+ config.cross_hidden_size *= max(requested_dim // cross_head_dim, 1)
+
+ # Update config values
+ update_config_headdim(config, requested_dim)
+ for key in config.sub_configs:
+ sub_config = getattr(config, key)
+ update_config_headdim(sub_config, requested_dim)
+
+ return config
+
@require_torch_gpu
def test_flex_attention_with_grads(self):
for model_class in self.all_model_classes:
@@ -4711,59 +4790,8 @@ def test_flex_attention_with_grads(self):
):
self.skipTest(reason="At least some parts of this model do not support flex attention")
- def update_config_for_flex(config):
- # Flex Attention cannot use dropout
- if hasattr(config, "attention_dropout"):
- config.attention_dropout = 0
- if hasattr(config, "attention_probs_dropout_prob"):
- config.attention_probs_dropout_prob = 0
-
- # Flex attention relies on triton on compilation
- # However, triton cannot handle hidden dimensions of less than 16
- # --> forcing at least a hidden dim of 16
-
- # Update the head dim and try to update hidden size as well if present in config
- # NOTE: some models may have none if the values in sub-config, thus we check for `Noneness`
- head_dim = None
- if hasattr(config, "head_dim") and config.head_dim is not None:
- head_dim = config.head_dim
- config.head_dim = max(16, config.head_dim)
-
- cross_head_dim = None
- if hasattr(config, "cross_head_dim") and config.cross_head_dim is not None:
- cross_head_dim = config.cross_head_dim
- config.cross_head_dim = max(16, config.cross_head_dim)
-
- if (
- getattr(config, "hidden_size", None) is not None
- and getattr(config, "num_attention_heads", None) is not None
- ):
- head_dim = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads
- config.hidden_size *= max(16 // head_dim, 1)
-
- if (
- getattr(config, "decoder_hidden_size", None) is not None
- and getattr(config, "decoder_num_attention_heads", None) is not None
- ):
- decoder_head_dim = config.decoder_hidden_size // config.decoder_num_attention_heads
- config.decoder_hidden_size *= max(16 // decoder_head_dim, 1)
-
- if (
- getattr(config, "cross_hidden_size", None) is not None
- and getattr(config, "cross_num_attention_heads", None) is not None
- ):
- cross_head_dim = (
- cross_head_dim
- if cross_head_dim is not None
- else config.cross_hidden_size // config.cross_num_attention_heads
- )
- config.cross_hidden_size *= max(16 // cross_head_dim, 1)
-
# Set default attention to flex and update config values
- update_config_for_flex(config)
- for key in config.sub_configs:
- sub_config = getattr(config, key)
- update_config_for_flex(sub_config)
+ config = self._prepare_config_headdim(config, 16) # specific to triton
if model_class._can_set_attn_implementation():
model = model_class(config).to(device=torch_device)