| |
| |
| |
| |
| @@ -678,9 +678,10 @@ def prepare_inputs_for_generation( |
| if encoder_attention_mask is not None: |
| model_inputs["attention_mask"] = encoder_attention_mask |
| |
| + # 7. Prepare kwargs for flash attention to avoid recomputations |
| if "flash" in self.config._attn_implementation and self._supports_attention_backend: |
| - cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k = prepare_fa_kwargs_from_position_ids( |
| - position_ids, is_packed_sequence=False |
| + (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids( |
| + model_inputs["position_ids"], is_packed_sequence=False |
| ) |
| model_inputs.update( |
| cu_seq_lens_q=cu_seq_lens_q.to(self.device), |
| @@ -689,12 +690,12 @@ def prepare_inputs_for_generation( |
| max_length_k=max_length_k, |
| ) |
| |
| - # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`). |
| + # 8. Forward ALL kwargs that are uninitialized (e.g. `use_cache`). |
| for key, value in kwargs.items(): |
| if key not in model_inputs: |
| model_inputs[key] = value |
| |
| - # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples) |
| + # 9. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples) |
| model_inputs.pop("labels", None) |
| return model_inputs |
| |
| |
| |
| |
| |
| @@ -10,20 +10,16 @@ |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| +import math |
| import os |
| |
| import torch |
| -import torch.nn.functional as F |
| |
| from ..utils.import_utils import is_torch_npu_available |
| |
| |
| if is_torch_npu_available(): |
| - import math |
| - |
| - import torch_npu |
| - from einops import rearrange, repeat |
| - from torch_npu import npu_rotary_mul |
| + from torch_npu import npu_fusion_attention, npu_rotary_mul |
| |
| |
| # FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default. |
| @@ -52,117 +48,6 @@ def is_npu_fa2_top_left_aligned_causal_mask(): |
| return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False |
| |
| |
| -# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py |
| -class IndexFirstAxis(torch.autograd.Function): |
| - @staticmethod |
| - def forward(ctx, input, indices): |
| - ctx.save_for_backward(indices) |
| - assert input.ndim >= 2 |
| - ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] |
| - second_dim = other_shape.numel() |
| - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. |
| - # return input[indices] |
| - return torch.gather( |
| - rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim) |
| - ).reshape(-1, *other_shape) |
| - |
| - @staticmethod |
| - def backward(ctx, grad_output): |
| - (indices,) = ctx.saved_tensors |
| - assert grad_output.ndim >= 2 |
| - other_shape = grad_output.shape[1:] |
| - grad_output = rearrange(grad_output, "b ... -> b (...)") |
| - grad_input = torch.zeros( |
| - [ctx.first_axis_dim, grad_output.shape[1]], |
| - device=grad_output.device, |
| - dtype=grad_output.dtype, |
| - ) |
| - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. |
| - # grad_input[indices] = grad_output |
| - grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) |
| - return grad_input.reshape(ctx.first_axis_dim, *other_shape), None |
| - |
| - |
| -index_first_axis = IndexFirstAxis.apply |
| - |
| - |
| -# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py |
| -class IndexPutFirstAxis(torch.autograd.Function): |
| - @staticmethod |
| - def forward(ctx, values, indices, first_axis_dim): |
| - ctx.save_for_backward(indices) |
| - assert indices.ndim == 1 |
| - assert values.ndim >= 2 |
| - output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) |
| - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. |
| - output[indices] = values |
| - # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) |
| - return output |
| - |
| - @staticmethod |
| - def backward(ctx, grad_output): |
| - (indices,) = ctx.saved_tensors |
| - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. |
| - grad_values = grad_output[indices] |
| - # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) |
| - return grad_values, None, None |
| - |
| - |
| -index_put_first_axis = IndexPutFirstAxis.apply |
| - |
| - |
| -# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py |
| -def pad_input(hidden_states, indices, batch, seqlen): |
| - """ |
| - Arguments: |
| - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. |
| - indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. |
| - batch: int, batch size for the padded sequence. |
| - seqlen: int, maximum sequence length for the padded sequence. |
| - Return: |
| - hidden_states: (batch, seqlen, ...) |
| - """ |
| - # dim = hidden_states.shape[-1] |
| - # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) |
| - # output[indices] = hidden_states |
| - output = index_put_first_axis(hidden_states, indices, batch * seqlen) |
| - return rearrange(output, "(b s) ... -> b s ...", b=batch) |
| - |
| - |
| -# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py |
| -def unpad_input(hidden_states, attention_mask, unused_mask=None): |
| - """ |
| - Arguments: |
| - hidden_states: (batch, seqlen, ...) |
| - attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. |
| - unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. |
| - Return: |
| - hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. |
| - indices: (total_nnz), the indices of masked tokens from the flattened input sequence. |
| - cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. |
| - max_seqlen_in_batch: int |
| - seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. |
| - """ |
| - all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask |
| - seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) |
| - used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
| - indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() |
| - max_seqlen_in_batch = seqlens_in_batch.max().item() |
| - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) |
| - # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the |
| - # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim |
| - # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to |
| - # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, |
| - # so we write custom forward and backward to make it a bit faster. |
| - return ( |
| - index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), |
| - indices, |
| - cu_seqlens, |
| - max_seqlen_in_batch, |
| - used_seqlens_in_batch, |
| - ) |
| - |
| - |
| def npu_flash_attn_func( |
| q, |
| k, |
| @@ -179,11 +64,11 @@ def npu_flash_attn_func( |
| |
| if not causal: |
| head_num = q.shape[2] |
| - output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0] |
| + output = npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0] |
| else: |
| attn_mask_npu = get_attn_mask_npu(q.device) |
| head_num = q.shape[2] |
| - output = torch_npu.npu_fusion_attention( |
| + output = npu_fusion_attention( |
| q, |
| k, |
| v, |
| @@ -218,7 +103,7 @@ def npu_flash_attn_varlen_func( |
| |
| if not causal: |
| head_num = q.shape[1] |
| - output = torch_npu.npu_fusion_attention( |
| + output = npu_fusion_attention( |
| q, |
| k, |
| v, |
| @@ -234,7 +119,7 @@ def npu_flash_attn_varlen_func( |
| else: |
| attn_mask_npu = get_attn_mask_npu(q.device) |
| head_num = q.shape[1] |
| - output = torch_npu.npu_fusion_attention( |
| + output = npu_fusion_attention( |
| q, |
| k, |
| v, |
| @@ -267,8 +152,3 @@ def npu_apply_rotary_emb(x, cos, sin, **kwargs): |
| sin = sin.unsqueeze(0).unsqueeze(2) |
| |
| return npu_rotary_mul(x, cos, sin) |
| - |
| - |
| -def get_npu_flash_attn_funcs(): |
| - # return flash attention related functions used for Ascend NPU in order |
| - return npu_flash_attn_func, npu_flash_attn_varlen_func, pad_input, unpad_input, False |
| |
| |
| |
| |
| @@ -1,4 +1,4 @@ |
| -# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. |
| +# Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| @@ -14,17 +14,15 @@ |
| import inspect |
| import os |
| import warnings |
| +from functools import partial |
| from typing import Optional, TypedDict |
| |
| import torch |
| import torch.nn.functional as F |
| |
| -from transformers.utils.import_utils import is_kernels_available |
| - |
| from .utils import ( |
| is_flash_attn_2_available, |
| is_flash_attn_3_available, |
| - is_flash_attn_greater_or_equal, |
| is_flash_attn_greater_or_equal_2_10, |
| is_torch_npu_available, |
| logging, |
| @@ -34,18 +32,135 @@ |
| logger = logging.get_logger(__name__) |
| |
| |
| -def _index_first_axis(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: |
| - reshaped = tensor.contiguous().reshape(-1, *tensor.shape[2:]) |
| - return reshaped[indices] |
| +# TODO Deprecate when all models have the attention interface |
| +def flash_attn_supports_top_left_mask(): |
| + if is_flash_attn_3_available(): |
| + return False |
| + if is_flash_attn_2_available(): |
| + return not is_flash_attn_greater_or_equal_2_10() |
| + |
| + from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask |
| + |
| + return is_npu_fa2_top_left_aligned_causal_mask() |
| + |
| + |
| +# TODO Deprecate when all models have the attention interface |
| +def is_flash_attn_available(): |
| + return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available() |
| + |
| + |
| +# `globals()` is not compatible with dynamo, hence we have do define them in global scope ourselves |
| +_flash_fn = None |
| +_flash_varlen_fn = None |
| +_pad_fn = None |
| +_unpad_fn = None |
| + |
| +# function that processes kwargs, generalized to handle any supported kwarg within the function |
| +_process_flash_kwargs_fn = None |
| +# exceptions where hf API doesn't match the original flash attention API |
| +_hf_api_to_flash_mapping = { |
| + "dropout": "dropout_p", |
| + "sliding_window": "window_size", |
| +} |
| + |
| + |
| +def _lazy_imports(implementation: Optional[str]): |
| + """ |
| + Lazy loads the respective flash attention implementations. |
| + |
| + Return: |
| + flash_attn_func: The base flash attention function. |
| + flash_attn_varlen_func: The flash attention function supporting variable sequence lengths, |
| + e.g. for padding-free training. |
| + pad_input: The function to pad inputs into one sequence and returning the respective kwargs. |
| + unpad_input: The function to unpad outputs based on the kwargs (from pad_input). |
| + """ |
| + is_fa2 = is_flash_attn_2_available() |
| + is_fa3 = is_flash_attn_3_available() |
| + if implementation == "flash_attention_2" or (implementation is None and is_fa2 and not is_fa3): |
| + from flash_attn import flash_attn_func, flash_attn_varlen_func |
| + from flash_attn.bert_padding import pad_input, unpad_input |
| + else: |
| + pad_input, unpad_input = _pad_input, _unpad_input |
| + if implementation == "flash_attention_3" or (implementation is None and is_fa3): |
| + from flash_attn_interface import flash_attn_func, flash_attn_varlen_func |
| + elif is_torch_npu_available(): |
| + from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func |
| + from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func |
| + # Kernels fallback |
| + else: |
| + flash_attn_func = getattr(implementation, "flash_attn_func", None) |
| + flash_attn_varlen_func = getattr(implementation, "flash_attn_varlen_func", None) |
| + if flash_attn_varlen_func is None or flash_attn_func is None: |
| + raise ValueError( |
| + f"Could not find the currently requested flash attention implementation at `{implementation}`." |
| + f"Make sure that you request a valid kernel from the hub, e.g. `kernels-community/flash-attn`." |
| + ) |
| + |
| + return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input |
| + |
| + |
| +def _lazy_define_process_function(flash_function): |
| + """ |
| + Depending on the version and kernel some features are not supported. Due to limitations in |
| + `torch.compile`, we opt to statically type which (optional) kwarg parameters are supported |
| + within `_process_flash_attention_kwargs`. |
| + |
| + NOTE: While all supported kwargs are marked as `True`, everything else is marked as `False`. |
| + This might be confusing for kwargs that we use in any case, e.g. `is_causal`. |
| + """ |
| + global _process_flash_kwargs_fn, _hf_api_to_flash_mapping |
| + |
| + flash_parameters = inspect.signature(flash_function).parameters |
| + process_parameters = inspect.signature(_process_flash_attention_kwargs).parameters |
| + |
| + supports_mapping = {} |
| + for param in process_parameters: |
| + fa_param = _hf_api_to_flash_mapping.get(param, param) |
| + supports_mapping[fa_param] = fa_param in flash_parameters |
| + |
| + return partial(_process_flash_attention_kwargs, supports_mapping=supports_mapping) |
| + |
| + |
| +def lazy_import_flash_attention(implementation: Optional[str]): |
| + """ |
| + Lazy loading flash attention and returning the respective functions + flags back |
| + |
| + NOTE: For fullgraph, this needs to be called before compile while no fullgraph can |
| + can work without preloading. See `_check_and_adjust_attn_implementation` in `modeling_utils`. |
| + """ |
| + global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn |
| + if any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]): |
| + _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation) |
| + |
| + global _process_flash_kwargs_fn |
| + if _process_flash_kwargs_fn is None: |
| + _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn) |
| |
| + return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn |
| |
| -def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None): |
| + |
| +def _index_first_axis(tensor, indices): |
| + """ |
| + A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis, |
| + after flattening the first two dimensions of the tensor. This is functionally equivalent to |
| + FA2's `index_first_axis` and replaces the need to import it. |
| """ |
| - FA3-compatible unpad_input function. |
| + # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first |
| + # two dimensions to get (total_tokens, ...) before indexing. |
| + reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) |
| + return reshaped_tensor[indices] |
| + |
| + |
| +def _unpad_input(hidden_states, attention_mask, unused_mask=None): |
| + """ |
| + unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3. |
| + |
| Arguments: |
| hidden_states: (batch, seqlen, ...) |
| attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. |
| unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. |
| + |
| Return: |
| hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. |
| indices: (total_nnz), the indices of masked tokens from the flattened input sequence. |
| @@ -69,14 +184,16 @@ def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None): |
| ) |
| |
| |
| -def _fa3_pad_input(hidden_states, indices, batch, seqlen): |
| +def _pad_input(hidden_states, indices, batch, seqlen): |
| """ |
| - FA3-compatible pad_input function. |
| + pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3. |
| + |
| Arguments: |
| hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. |
| indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. |
| batch: int, batch size for the padded sequence. |
| seqlen: int, maximum sequence length for the padded sequence. |
| + |
| Return: |
| hidden_states: (batch, seqlen, ...) |
| """ |
| @@ -89,9 +206,11 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen): |
| def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: |
| """ |
| Retrieves indexing data required to repad unpadded (ragged) tensors. |
| + |
| Arguments: |
| attention_mask (`torch.Tensor`): |
| Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. |
| + |
| Return: |
| indices (`torch.Tensor`): |
| The indices of non-masked tokens from the flattened input sequence. |
| @@ -125,6 +244,7 @@ def _upad_input( |
| Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. |
| This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary |
| tensors for query, key, value tensors. |
| + |
| Arguments: |
| query_layer (`torch.Tensor`): |
| Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). |
| @@ -138,6 +258,7 @@ def _upad_input( |
| Target length. |
| unpad_input_func: |
| The function to use for unpadding the input tensors. |
| + |
| Return: |
| query_layer (`torch.Tensor`): |
| Query state without padding. Shape: (total_target_length, num_heads, head_dim). |
| @@ -193,13 +314,15 @@ def _upad_input( |
| def prepare_fa_kwargs_from_position_ids(position_ids, is_packed_sequence: bool = True): |
| """ |
| This function returns all the necessary kwargs to call `flash_attn_varlen_func` |
| - extracted from position_ids.The `position_ids` can be either packed sequence or |
| - the usual padded position ids, for example in inference time.. |
| + extracted from position_ids. The `position_ids` can be either packed sequence or |
| + the usual padded position ids, for example in inference time. |
| + |
| Arguments: |
| position_ids (`torch.Tensor`): |
| Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. |
| is_packed_sequence (`bool`, *optional*, defaults to `True`): |
| Whether the input position ids are a packed sequence or not. |
| + |
| Return: |
| (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): |
| The cumulative sequence lengths for the target (query) and source (key, value), used to index into |
| @@ -212,19 +335,21 @@ def prepare_fa_kwargs_from_position_ids(position_ids, is_packed_sequence: bool = |
| # In that case the position ids will not always start with `0` and we need a better way to infer |
| # cumulative seq lengths. |
| if not is_packed_sequence: |
| - tensor_kws = {"dtype": torch.int32, "device": position_ids.device} |
| - last_position_ids = position_ids[:, -1] |
| + tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device} |
| |
| + last_position_ids = position_ids[:, -1] |
| + q_len = ( |
| + torch.ones(position_ids.size(0), **tensor_kwargs) |
| + if position_ids.shape[-1] == 1 |
| + else last_position_ids.add(1) |
| + ) |
| + cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kwargs), q_len.cumsum(0).to(torch.int32)], 0) |
| cu_seq_lens_k = torch.cat( |
| - [torch.zeros(1, **tensor_kws), last_position_ids.cumsum(0).add(1).to(torch.int32)], 0 |
| + [torch.zeros(1, **tensor_kwargs), last_position_ids.add(1).cumsum(0).to(torch.int32)], 0 |
| ) |
| - max_length_k = int(last_position_ids.max()) + 1 |
| |
| - q_len = ( |
| - torch.ones(position_ids.size(0), **tensor_kws) if position_ids.shape[-1] == 1 else last_position_ids.add(1) |
| - ) |
| - cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0).to(torch.int32)], 0) |
| max_length_q = int(q_len.max()) |
| + max_length_k = int(last_position_ids.max()) + 1 |
| else: |
| position_ids = position_ids.flatten() |
| indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) |
| @@ -237,16 +362,18 @@ def prepare_fa_kwargs_from_position_ids(position_ids, is_packed_sequence: bool = |
| ) |
| cu_seq_lens_k = cu_seq_lens_q |
| |
| + # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424 |
| + # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing |
| + # for some models (e.g. qwen2-vl). |
| + max_length_q = cu_seq_lens_q.diff().max() |
| # NOTE: With torch compile, this will cause a graph break if you don't set |
| # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call |
| # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass. |
| # This is a limitation of flash attention API, as the function `flash_attn_varlen_func` |
| # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`. |
| - # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424 |
| - # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing |
| - # for some models (e.g. qwen2-vl). |
| - max_length_q = cu_seq_lens_q.diff().max().item() |
| + max_length_q = max_length_q.item() |
| max_length_k = max_length_q |
| + |
| return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) |
| |
| |
| @@ -256,6 +383,7 @@ def _prepare_from_posids(query, key, value, position_ids, query_length): |
| All three query, key, value states will be flattened. |
| Cumulative lengths of each examples in the batch will be extracted from position_ids. |
| NOTE: ideally cumulative lengths should be prepared at the data collator stage |
| + |
| Arguments: |
| query (`torch.Tensor`): |
| Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). |
| @@ -267,6 +395,7 @@ def _prepare_from_posids(query, key, value, position_ids, query_length): |
| Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. |
| query_length (`int`): |
| Sequence length of the input queries. |
| + |
| Return: |
| query (`torch.Tensor`): |
| Query state without padding. Shape: (total_target_length, num_heads, head_dim). |
| @@ -275,121 +404,156 @@ def _prepare_from_posids(query, key, value, position_ids, query_length): |
| value (`torch.Tensor`): |
| Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). |
| (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): |
| - The cumulative sequence lengths for the target (query) and source (key, value), used to index into |
| - ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). |
| + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). |
| (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): |
| - Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, |
| - `max_seqlen_in_batch_k` for the source sequence i.e. key/value). |
| + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). |
| """ |
| kv_length = key.shape[1] |
| + is_packed_sequence = query_length == kv_length |
| + |
| query = query.contiguous().view(-1, query.size(-2), query.size(-1)) |
| key = key.contiguous().view(-1, key.size(-2), key.size(-1)) |
| value = value.contiguous().view(-1, value.size(-2), value.size(-1)) |
| - is_packed_sequence = query_length == kv_length |
| |
| - cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k = prepare_fa_kwargs_from_position_ids( |
| + (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids( |
| position_ids, is_packed_sequence=is_packed_sequence |
| ) |
| + |
| return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)) |
| |
| |
| def _prepare_flash_attention_from_position_ids(query, key, value, position_ids): |
| warnings.warn( |
| - "prepare_fa2_from_position_ids is deprecated, use _prepare_from_posids", |
| + "The function `_prepare_flash_attention_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_from_posids` instead.", |
| FutureWarning, |
| ) |
| return _prepare_from_posids(query, key, value, position_ids) |
| |
| |
| -def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = None): |
| +def _is_packed_sequence(position_ids, batch_size): |
| + """ |
| + Check the position ids whether packed sequences are indicated or not |
| + 1. Position ids exist |
| + 2. Flattened sequences only are supported |
| + 3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences |
| + """ |
| + if position_ids is None: |
| + return False |
| + |
| + increasing_position_sequences = ( |
| + torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min() |
| + ) |
| + return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool() |
| + |
| + |
| +def fa_peft_integration_check( |
| + q: torch.Tensor, |
| + k: torch.Tensor, |
| + v: torch.Tensor, |
| + target_dtype: Optional[torch.dtype] = None, |
| +): |
| + """ |
| + PEFT usually casts the layer norms in float32 for training stability reasons |
| + therefore the input hidden states gets silently casted in float32. Hence, we need |
| + cast them back in float16 / bfloat16 just to be sure everything works as expected. |
| + This might slowdown training & inference so it is recommended to not cast the LayerNorms! |
| + """ |
| if target_dtype and q.dtype == torch.float32: |
| logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.") |
| q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype) |
| return q, k, v |
| |
| |
| -def _lazy_imports(impl: Optional[str]): |
| - # returns funcs and pad/unpad based on impl |
| - is_fa2 = is_flash_attn_2_available() |
| - is_fa3 = is_flash_attn_3_available() |
| - if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3): |
| - try: |
| - from flash_attn import flash_attn_func, flash_attn_varlen_func |
| - from flash_attn.bert_padding import pad_input, unpad_input |
| - |
| - return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, False |
| - |
| - except ImportError as e: |
| - if not globals().get("use_remote_fa2", None): |
| - use_remote_fa2 = ( |
| - input( |
| - "Unable to import the official flash attention, do you want to try to use `kernels-community/flash-attn` (trust remote code) Yes or No? " |
| - ) |
| - .strip() |
| - .lower() |
| - ) |
| - globals()["use_remote_fa2"] = use_remote_fa2 in {"yes", "y", "1"} |
| - if globals()["use_remote_fa2"]: |
| - if not is_kernels_available(): |
| - raise ImportError("You need to install kernels: `pip install kernels`") |
| - from kernels import get_kernel |
| - |
| - impl = get_kernel("kernels-community/flash-attn") |
| - pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input |
| - return ( |
| - getattr(impl, "flash_attn_func", None), |
| - getattr(impl, "flash_attn_varlen_func"), |
| - pad_input, |
| - unpad_input, |
| - True, |
| - ) |
| - |
| - else: |
| - raise ImportError( |
| - "Failed to import flash attention 2, please install it or use another implementation." |
| - ) from e |
| - elif is_torch_npu_available(): |
| - # get flash attention related functions from `.integrations.npu_flash_attention` module for Ascend NPU |
| - from .integrations.npu_flash_attention import get_npu_flash_attn_funcs |
| - |
| - return get_npu_flash_attn_funcs() |
| - elif impl == "flash_attention_3" or (impl is None and is_fa3): |
| - from flash_attn_interface import flash_attn_func, flash_attn_varlen_func |
| - |
| - pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input |
| - return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input, True |
| - else: |
| - pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input |
| - return ( |
| - getattr(impl, "flash_attn_func", None), |
| - getattr(impl, "flash_attn_varlen_func"), |
| - pad_input, |
| - unpad_input, |
| - True, |
| - ) |
| +class FlashAttentionKwargs(TypedDict, total=False): |
| + """ |
| + Keyword arguments for Flash Attention with Compile. |
| + |
| + Attributes: |
| + cumulative_seqlens_q (`torch.LongTensor`, *optional*) |
| + Gets cumulative sequence length for query state. |
| + cumulative_seqlens_k (`torch.LongTensor`, *optional*) |
| + Gets cumulative sequence length for key state. |
| + max_length_q (`int`, *optional*): |
| + Maximum sequence length for query state. |
| + max_length_k (`int`, *optional*): |
| + Maximum sequence length for key state. |
| + """ |
| |
| + cumulative_seqlens_q: Optional[torch.LongTensor] |
| + cumulative_seqlens_k: Optional[torch.LongTensor] |
| + max_length_q: Optional[int] |
| + max_length_k: Optional[int] |
| |
| -_flash_supports_window = None |
| |
| +def _process_flash_attention_kwargs( |
| + query_length: int, |
| + key_length: int, |
| + is_causal: bool, |
| + dropout: float = 0.0, |
| + softmax_scale: Optional[float] = None, |
| + sliding_window: Optional[int] = None, |
| + use_top_left_mask: bool = False, |
| + softcap: Optional[float] = None, |
| + deterministic: Optional[bool] = None, |
| + s_aux: Optional[torch.Tensor] = None, |
| + supports_mapping: Optional[dict[str, bool]] = None, |
| + **kwargs, |
| +): |
| + """ |
| + Returns a set of kwargs that are passed down to the according flash attention function based on |
| + requested features and whether it is supported - depends on the version and kernel implementation |
| + which is dynamically configued at `lazy_import_flash_attention`. The (un)supported features can be |
| + inspected in `supports_mapping`, see `_lazy_define_process_function` for more details. |
| |
| -def is_flash_attn_available(): |
| - return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available() |
| + Args: |
| + query_length (`int`): |
| + Length of the query states |
| + key_length (`int`): |
| + Length of the key states |
| + is_causal (`bool`): |
| + Whether we perform causal (decoder) attention or full attention. |
| + dropout (`float`): |
| + Attention dropout. |
| + softmax_scale (`float`, *optional*): |
| + The scaling of QK^T before applying softmax. Default to `1 / sqrt(head_dim)`. |
| + sliding_window (`int`, *optional*): |
| + The size of the sliding window, i.e. we look at a max of `sliding_window` tokens back. |
| + use_top_left_mask (`bool`): |
| + Deprecated behavior of older versions of flash attention requiring different masking. |
| + softcap (`float`, *optional*): |
| + Softcap for the attention logits, used e.g. in gemma2. |
| + deterministic (`bool`, *optional*): |
| + Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. |
| + s_aux (`torch.Tensor`, *optional*): |
| + Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head. |
| + Return: |
| + flash_kwargs (`dict`): |
| + A dict of kwargs that are requested and supported. |
| + """ |
| + flash_kwargs = { |
| + "causal": is_causal and not (use_top_left_mask and query_length == 1), |
| + "softmax_scale": softmax_scale, |
| + } |
| |
| + if supports_mapping["dropout_p"]: |
| + flash_kwargs["dropout_p"] = dropout |
| |
| -def flash_attn_supports_top_left_mask(): |
| - if is_flash_attn_3_available(): |
| - return False |
| - if is_flash_attn_2_available(): |
| - return not is_flash_attn_greater_or_equal_2_10() |
| + if supports_mapping["window_size"] and sliding_window is not None and key_length > sliding_window: |
| + flash_kwargs["window_size"] = (sliding_window, sliding_window) |
| |
| - from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask |
| + if supports_mapping["deterministic"]: |
| + flash_kwargs["deterministic"] = ( |
| + deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" |
| + ) |
| |
| - return is_npu_fa2_top_left_aligned_causal_mask() |
| + if supports_mapping["softcap"] and softcap is not None: |
| + flash_kwargs["softcap"] = softcap |
| |
| + # Only within kernel implementation atm |
| + if supports_mapping["s_aux"] and s_aux is not None: |
| + flash_kwargs["s_aux"] = s_aux |
| |
| -class FlashAttentionKwargs(TypedDict, total=False): |
| - cumulative_seqlens_q: Optional[torch.LongTensor] |
| - cumulative_seqlens_k: Optional[torch.LongTensor] |
| + return flash_kwargs |
| |
| |
| def _flash_attention_forward( |
| @@ -414,100 +578,121 @@ def _flash_attention_forward( |
| implementation: Optional[str] = None, |
| **kwargs, |
| ): |
| - if not all(k in globals() for k in ("_flash_fn", "_flash_varlen_fn", "_pad_fn", "_unpad_fn", "_is_fa3")): |
| - flash_fn, flash_varlen_fn, pad_fn, unpad_fn, is_fa3 = _lazy_imports(implementation) |
| - globals()["_flash_fn"] = flash_fn |
| - globals()["_flash_varlen_fn"] = flash_varlen_fn |
| - globals()["_pad_fn"] = pad_fn |
| - globals()["_unpad_fn"] = unpad_fn |
| - globals()["_is_fa3"] = is_fa3 |
| - flash_supports_window = "window_size" in inspect.signature(flash_varlen_fn).parameters |
| - globals()["_flash_supports_window"] = flash_supports_window |
| - else: |
| - flash_fn = globals()["_flash_fn"] |
| - flash_varlen_fn = globals()["_flash_varlen_fn"] |
| - pad_fn = globals()["_pad_fn"] |
| - unpad_fn = globals()["_unpad_fn"] |
| - is_fa3 = globals()["_is_fa3"] |
| - flash_supports_window = globals()["_flash_supports_window"] |
| - |
| - causal = is_causal and not (use_top_left_mask and query_length == 1) |
| - use_sw = ( |
| - (_flash_supports_window or flash_supports_window) and sliding_window and key_states.shape[1] > sliding_window |
| + """ |
| + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token |
| + first unpad the input, then computes the attention scores and pad the final attention scores. |
| + |
| + (Optional) kwargs are described further in `_process_flash_attention_kwargs` and `FlashAttentionKwargs`. |
| + |
| + Args: |
| + query_states (`torch.Tensor`): |
| + Input query states to be passed to Flash Attention API |
| + key_states (`torch.Tensor`): |
| + Input key states to be passed to Flash Attention API |
| + value_states (`torch.Tensor`): |
| + Input value states to be passed to Flash Attention API |
| + attention_mask (`torch.Tensor`, *optional*): |
| + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the |
| + position of padding tokens and 1 for the position of non-padding tokens. |
| + implementation (`str`, *optional*): |
| + The attention implementation to use. If None, will default to the one based on the environment. |
| + """ |
| + (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_attention( |
| + implementation |
| ) |
| - flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sw else {} |
| - if not is_fa3: |
| - flash_kwargs["dropout_p"] = dropout |
| - if is_flash_attn_greater_or_equal("2.4.1"): |
| - det = deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" |
| - flash_kwargs["deterministic"] = det |
| - if softcap is not None: |
| - flash_kwargs["softcap"] = softcap |
| - if "s_aux" in kwargs: |
| - flash_kwargs["s_aux"] = kwargs.get("s_aux") |
| + |
| + # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op |
| query_states, key_states, value_states = fa_peft_integration_check( |
| query_states, key_states, value_states, target_dtype |
| ) |
| - use_mask = position_ids is not None or all( |
| - k is not None for k in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k] |
| + |
| + # Extract the flash attention kwargs that have been requested (and are supported by the implementation) |
| + flash_kwargs = process_flash_kwargs_fn( |
| + query_length=query_length, |
| + key_length=key_states.size(1), |
| + is_causal=is_causal, |
| + dropout=dropout, |
| + softmax_scale=softmax_scale, |
| + sliding_window=sliding_window, |
| + use_top_left_mask=use_top_left_mask, |
| + softcap=softcap, |
| + deterministic=deterministic, |
| + **kwargs, |
| + ) |
| + |
| + # We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases: |
| + # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`. |
| + # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to |
| + # use `flash_varlen_fn` knowing we already have all necessary the kwargs. |
| + # |
| + # NOTE: it is user's responsibility to take care of flattenning `position_ids` if that's needed by the model. |
| + # See #39121 for more information. |
| + is_fa_with_position_ids = _is_packed_sequence(position_ids, batch_size=query_states.size(0)) |
| + is_fa_with_varlen_kwargs = all( |
| + kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) |
| ) |
| + |
| + # Contains at least one padding token in the sequence |
| if attention_mask is not None: |
| - q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input( |
| + q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input( |
| query_states, key_states, value_states, attention_mask, query_length, unpad_fn |
| ) |
| - # TODO for now this is required to work with https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.p |
| + |
| + # TODO for now this is required to work with |
| + # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py |
| if "mps" in str(q.device): |
| - cu_k = cu_k.clone() |
| + cu_seq_lens_k = cu_seq_lens_k.clone() |
| + |
| out_unpad = flash_varlen_fn( |
| q, |
| k, |
| v, |
| - cu_seqlens_q=cu_q.to(torch.int32), |
| - cu_seqlens_k=cu_k.to(torch.int32), |
| - max_seqlen_q=mq, |
| - max_seqlen_k=mk, |
| - softmax_scale=softmax_scale, |
| - causal=causal, |
| + cu_seqlens_q=cu_seq_lens_q, |
| + cu_seqlens_k=cu_seq_lens_k, |
| + max_seqlen_q=max_length_q, |
| + max_seqlen_k=max_length_k, |
| **flash_kwargs, |
| ) |
| if isinstance(out_unpad, tuple): |
| out_unpad = out_unpad[0] |
| - out = pad_fn(out_unpad, idx, query_states.shape[0], query_length) |
| - elif use_mask: |
| + |
| + out = pad_fn(out_unpad, indices_q, query_states.size(0), query_length) |
| + |
| + # Padding free, i.e. sequences flattened into one total sequence |
| + elif is_fa_with_varlen_kwargs or is_fa_with_position_ids: |
| if cu_seq_lens_q is None or cu_seq_lens_k is None: |
| - if position_ids is None: |
| - raise ValueError( |
| - "Position ids should be passed if the attention mask is not passed and the cu_seq-lens are not passed." |
| - ) |
| - q, k, v, (cu_q, cu_k), (mq, mk) = _prepare_from_posids( |
| + q, k, v, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _prepare_from_posids( |
| query_states, key_states, value_states, position_ids, query_length=query_length |
| ) |
| else: |
| q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) |
| k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) |
| v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) |
| - mq, mk = max_length_q, max_length_k |
| - cu_q, cu_k = cu_seq_lens_q, cu_seq_lens_k |
| + |
| + # TODO for now this is required to work with |
| + # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py |
| if "mps" in str(q.device): |
| - cu_k = cu_k.clone() |
| + cu_seq_lens_k = cu_seq_lens_k.clone() |
| + |
| out = flash_varlen_fn( |
| q, |
| k, |
| v, |
| - cu_seqlens_q=cu_q.to(torch.int32), |
| - cu_seqlens_k=cu_k.to(torch.int32), |
| - max_seqlen_q=mq, |
| - max_seqlen_k=mk, |
| - softmax_scale=softmax_scale, |
| - causal=causal, |
| + cu_seqlens_q=cu_seq_lens_q, |
| + cu_seqlens_k=cu_seq_lens_k, |
| + max_seqlen_q=max_length_q, |
| + max_seqlen_k=max_length_k, |
| **flash_kwargs, |
| ) |
| if isinstance(out, tuple): |
| out = out[0] |
| - out = out.view(query_states.shape[0], -1, out.size(-2), out.size(-1)) |
| + |
| + out = out.view(query_states.size(0), -1, out.size(-2), out.size(-1)) |
| + |
| + # No padding |
| else: |
| - out = flash_fn( |
| - query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs |
| - ) |
| + out = flash_fn(query_states, key_states, value_states, **flash_kwargs) |
| + if isinstance(out, tuple): |
| + out = out[0] |
| |
| - return out[0] if isinstance(out, tuple) else out |
| + return out |
| |
| |
| |
| |
| @@ -74,6 +74,7 @@ |
| ) |
| from .loss.loss_utils import LOSS_MAPPING |
| from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS |
| +from .modeling_flash_attention_utils import lazy_import_flash_attention |
| from .pytorch_utils import ( # noqa: F401 |
| Conv1D, |
| apply_chunking_to_forward, |
| @@ -2126,7 +2127,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH |
| _pp_plan = None |
| |
| # This flag signal that the model can be used as an efficient backend in TGI and vLLM |
| - # In practice, it means that they support attention interface functions, fully pass the kwargs |
| + # In practice, it means that they support attention (mask) interface functions, fully pass the kwargs |
| # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan |
| _supports_attention_backend = False |
| _can_record_outputs = None |
| @@ -2748,6 +2749,7 @@ def _check_and_adjust_attn_implementation( |
| if attention_wrapper is None: |
| attention_wrapper = flash_attention_forward |
| kernel_function = partial(attention_wrapper, implementation=kernel) |
| + lazy_import_flash_attention(kernel) |
| elif kernel_name is not None: |
| kernel_function = getattr(kernel, kernel_name) |
| ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function) |
| @@ -2763,7 +2765,13 @@ def _check_and_adjust_attn_implementation( |
| attn_implementation = "sdpa" # Try to fallback to sdpa in this case |
| return attn_implementation |
| else: |
| - return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check) |
| + attn_implementation = self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check) |
| + |
| + # preload flash attention here to allow compile with fullgraph |
| + if applicable_attn_implementation.startswith("flash_attention"): |
| + lazy_import_flash_attention(applicable_attn_implementation) |
| + |
| + return attn_implementation |
| |
| def get_correct_attn_implementation(self, _requested_attention: str, is_init_check: bool = False) -> str: |
| requested_attention = "sdpa" if _requested_attention is None else _requested_attention |
| |
| |
| |
| |
| @@ -3483,92 +3483,107 @@ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_sid |
| for model_class in self.all_model_classes: |
| if not model_class._supports_flash_attn: |
| self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") |
| + # Custom kernel which needs the mask interface to be properly usable on these models |
| + if not model_class._supports_attention_backend and not attn_implementation.startswith("flash_attention"): |
| + self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") |
| |
| config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() |
| - config.head_dim = 64 # fa2 does not always support arbitrary headim |
| - model = model_class(config) |
| - |
| - model.to(torch_device) |
| - model.to(torch.bfloat16) |
| - dummy_input = inputs_dict[model.main_input_name][:1] |
| - if dummy_input.dtype in [torch.float32, torch.float16]: |
| - dummy_input = dummy_input.to(torch.bfloat16) |
| - |
| - dummy_attention_mask = inputs_dict.get("attention_mask", None) |
| |
| - if dummy_attention_mask is not None: |
| - dummy_attention_mask = dummy_attention_mask[:1] |
| - if padding_side == "left": |
| - dummy_attention_mask[:, 1:] = 1 |
| - dummy_attention_mask[:, :1] = 0 |
| - else: |
| - dummy_attention_mask[:, :-1] = 1 |
| - dummy_attention_mask[:, -1:] = 0 |
| - if model.config.is_encoder_decoder: |
| - decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] |
| + # flash attention variants does not always support arbitrary headim |
| + config = self._prepare_config_headdim(config, 16) |
| |
| - outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) |
| - model.set_attn_implementation(attn_implementation) |
| - outputs_fa = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) |
| - else: |
| - outputs = model(dummy_input, output_hidden_states=True) |
| - model.set_attn_implementation(attn_implementation) |
| - outputs_fa = model(dummy_input, output_hidden_states=True) |
| + # TODO it is unclear why saving and reloading with dtype works while |
| + # casting with `.to(dtype=..., device=...)` does not. |
| + # Discovered on tests with `Bart` models. |
| + model = model_class(config) |
| + with tempfile.TemporaryDirectory() as tmpdirname: |
| + model.save_pretrained(tmpdirname) |
| + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) |
| + model.to(torch_device) |
| |
| - model.set_attn_implementation("sdpa") |
| - logits = ( |
| - outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1] |
| - ) |
| - logits_fa = ( |
| - outputs_fa.hidden_states[-1] |
| - if not model.config.is_encoder_decoder |
| - else outputs_fa.decoder_hidden_states[-1] |
| - ) |
| + dummy_input = inputs_dict[model.main_input_name][:1] |
| + if dummy_input.dtype in [torch.float32, torch.float16]: |
| + dummy_input = dummy_input.to(torch.bfloat16) |
| |
| - assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) |
| + dummy_attention_mask = inputs_dict.get("attention_mask", None) |
| |
| - if model.config.is_encoder_decoder: |
| - other_inputs = { |
| - "decoder_input_ids": decoder_input_ids, |
| - "decoder_attention_mask": dummy_attention_mask, |
| - "output_hidden_states": True, |
| - } |
| if dummy_attention_mask is not None: |
| - other_inputs["attention_mask"] = dummy_attention_mask |
| + dummy_attention_mask = dummy_attention_mask[:1] |
| + if padding_side == "left": |
| + dummy_attention_mask[:, 1:] = 1 |
| + dummy_attention_mask[:, :1] = 0 |
| + else: |
| + dummy_attention_mask[:, :-1] = 1 |
| + dummy_attention_mask[:, -1:] = 0 |
| + if model.config.is_encoder_decoder: |
| + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] |
| |
| - outputs = model(dummy_input, **other_inputs) |
| - model.set_attn_implementation(attn_implementation) |
| - outputs_fa = model(dummy_input, **other_inputs) |
| - else: |
| - other_inputs = { |
| - "output_hidden_states": True, |
| - } |
| - if dummy_attention_mask is not None: |
| - other_inputs["attention_mask"] = dummy_attention_mask |
| + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) |
| + model.set_attn_implementation(attn_implementation) |
| + outputs_fa = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) |
| + else: |
| + outputs = model(dummy_input, output_hidden_states=True) |
| + model.set_attn_implementation(attn_implementation) |
| + outputs_fa = model(dummy_input, output_hidden_states=True) |
| + |
| + model.set_attn_implementation("sdpa") |
| + logits = ( |
| + outputs.hidden_states[-1] |
| + if not model.config.is_encoder_decoder |
| + else outputs.decoder_hidden_states[-1] |
| + ) |
| + logits_fa = ( |
| + outputs_fa.hidden_states[-1] |
| + if not model.config.is_encoder_decoder |
| + else outputs_fa.decoder_hidden_states[-1] |
| + ) |
| |
| - outputs = model(dummy_input, **other_inputs) |
| - model.set_attn_implementation(attn_implementation) |
| - outputs_fa = model(dummy_input, **other_inputs) |
| + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) |
| |
| - model.set_attn_implementation("sdpa") |
| - logits = ( |
| - outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1] |
| - ) |
| - logits_fa = ( |
| - outputs_fa.hidden_states[-1] |
| - if not model.config.is_encoder_decoder |
| - else outputs_fa.decoder_hidden_states[-1] |
| - ) |
| + if model.config.is_encoder_decoder: |
| + other_inputs = { |
| + "decoder_input_ids": decoder_input_ids, |
| + "decoder_attention_mask": dummy_attention_mask, |
| + "output_hidden_states": True, |
| + } |
| + if dummy_attention_mask is not None: |
| + other_inputs["attention_mask"] = dummy_attention_mask |
| |
| - if padding_side == "left": |
| - assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) |
| + outputs = model(dummy_input, **other_inputs) |
| + model.set_attn_implementation(attn_implementation) |
| + outputs_fa = model(dummy_input, **other_inputs) |
| + else: |
| + other_inputs = { |
| + "output_hidden_states": True, |
| + } |
| + if dummy_attention_mask is not None: |
| + other_inputs["attention_mask"] = dummy_attention_mask |
| + |
| + outputs = model(dummy_input, **other_inputs) |
| + model.set_attn_implementation(attn_implementation) |
| + outputs_fa = model(dummy_input, **other_inputs) |
| + |
| + model.set_attn_implementation("sdpa") |
| + logits = ( |
| + outputs.hidden_states[-1] |
| + if not model.config.is_encoder_decoder |
| + else outputs.decoder_hidden_states[-1] |
| + ) |
| + logits_fa = ( |
| + outputs_fa.hidden_states[-1] |
| + if not model.config.is_encoder_decoder |
| + else outputs_fa.decoder_hidden_states[-1] |
| + ) |
| |
| - # check with inference + dropout |
| - model.train() |
| - model.set_attn_implementation(attn_implementation) |
| - _ = model(dummy_input, **other_inputs) |
| - else: |
| - assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) |
| + if padding_side == "left": |
| + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) |
| + |
| + # check with inference + dropout |
| + model.train() |
| + model.set_attn_implementation(attn_implementation) |
| + _ = model(dummy_input, **other_inputs) |
| + else: |
| + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) |
| |
| @require_kernels |
| @require_torch_gpu |
| @@ -4698,6 +4713,70 @@ def recursively_check(eager_outputs, exported_outputs): |
| is_tested = recursively_check(eager_outputs, exported_outputs) |
| self.assertTrue(is_tested, msg=f"No outputs were compared for {model_class.__name__}") |
| |
| + @staticmethod |
| + def _prepare_config_headdim(config, requested_dim): |
| + """ |
| + This method allows to update the head dim for all model types including |
| + composite models and models that do not support head dim by themselves. |
| + |
| + Why? A lot of kernels including flex attention rely on triton for compilation. |
| + However, triton cannot handle hidden dimensions of less than 16 for example. |
| + (There are many more examples especially now that the `kernels` library is |
| + supported) |
| + """ |
| + |
| + def update_config_headdim(config, requested_dim): |
| + # Flex Attention cannot use dropout |
| + if hasattr(config, "attention_dropout"): |
| + config.attention_dropout = 0 |
| + if hasattr(config, "attention_probs_dropout_prob"): |
| + config.attention_probs_dropout_prob = 0 |
| + |
| + # Update the head dim and try to update hidden size as well if present in config |
| + # NOTE: some models may have none if the values in sub-config, thus we check for `Noneness` |
| + head_dim = None |
| + if hasattr(config, "head_dim") and config.head_dim is not None: |
| + head_dim = config.head_dim |
| + config.head_dim = max(requested_dim, config.head_dim) |
| + |
| + cross_head_dim = None |
| + if hasattr(config, "cross_head_dim") and config.cross_head_dim is not None: |
| + cross_head_dim = config.cross_head_dim |
| + config.cross_head_dim = max(requested_dim, config.cross_head_dim) |
| + |
| + if ( |
| + getattr(config, "hidden_size", None) is not None |
| + and getattr(config, "num_attention_heads", None) is not None |
| + ): |
| + head_dim = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads |
| + config.hidden_size *= max(requested_dim // head_dim, 1) |
| + |
| + if ( |
| + getattr(config, "decoder_hidden_size", None) is not None |
| + and getattr(config, "decoder_num_attention_heads", None) is not None |
| + ): |
| + decoder_head_dim = config.decoder_hidden_size // config.decoder_num_attention_heads |
| + config.decoder_hidden_size *= max(requested_dim // decoder_head_dim, 1) |
| + |
| + if ( |
| + getattr(config, "cross_hidden_size", None) is not None |
| + and getattr(config, "cross_num_attention_heads", None) is not None |
| + ): |
| + cross_head_dim = ( |
| + cross_head_dim |
| + if cross_head_dim is not None |
| + else config.cross_hidden_size // config.cross_num_attention_heads |
| + ) |
| + config.cross_hidden_size *= max(requested_dim // cross_head_dim, 1) |
| + |
| + # Update config values |
| + update_config_headdim(config, requested_dim) |
| + for key in config.sub_configs: |
| + sub_config = getattr(config, key) |
| + update_config_headdim(sub_config, requested_dim) |
| + |
| + return config |
| + |
| @require_torch_gpu |
| def test_flex_attention_with_grads(self): |
| for model_class in self.all_model_classes: |
| @@ -4711,59 +4790,8 @@ def test_flex_attention_with_grads(self): |
| ): |
| self.skipTest(reason="At least some parts of this model do not support flex attention") |
| |
| - def update_config_for_flex(config): |
| - # Flex Attention cannot use dropout |
| - if hasattr(config, "attention_dropout"): |
| - config.attention_dropout = 0 |
| - if hasattr(config, "attention_probs_dropout_prob"): |
| - config.attention_probs_dropout_prob = 0 |
| - |
| - # Flex attention relies on triton on compilation |
| - # However, triton cannot handle hidden dimensions of less than 16 |
| - # --> forcing at least a hidden dim of 16 |
| - |
| - # Update the head dim and try to update hidden size as well if present in config |
| - # NOTE: some models may have none if the values in sub-config, thus we check for `Noneness` |
| - head_dim = None |
| - if hasattr(config, "head_dim") and config.head_dim is not None: |
| - head_dim = config.head_dim |
| - config.head_dim = max(16, config.head_dim) |
| - |
| - cross_head_dim = None |
| - if hasattr(config, "cross_head_dim") and config.cross_head_dim is not None: |
| - cross_head_dim = config.cross_head_dim |
| - config.cross_head_dim = max(16, config.cross_head_dim) |
| - |
| - if ( |
| - getattr(config, "hidden_size", None) is not None |
| - and getattr(config, "num_attention_heads", None) is not None |
| - ): |
| - head_dim = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads |
| - config.hidden_size *= max(16 // head_dim, 1) |
| - |
| - if ( |
| - getattr(config, "decoder_hidden_size", None) is not None |
| - and getattr(config, "decoder_num_attention_heads", None) is not None |
| - ): |
| - decoder_head_dim = config.decoder_hidden_size // config.decoder_num_attention_heads |
| - config.decoder_hidden_size *= max(16 // decoder_head_dim, 1) |
| - |
| - if ( |
| - getattr(config, "cross_hidden_size", None) is not None |
| - and getattr(config, "cross_num_attention_heads", None) is not None |
| - ): |
| - cross_head_dim = ( |
| - cross_head_dim |
| - if cross_head_dim is not None |
| - else config.cross_hidden_size // config.cross_num_attention_heads |
| - ) |
| - config.cross_hidden_size *= max(16 // cross_head_dim, 1) |
| - |
| # Set default attention to flex and update config values |
| - update_config_for_flex(config) |
| - for key in config.sub_configs: |
| - sub_config = getattr(config, key) |
| - update_config_for_flex(sub_config) |
| + config = self._prepare_config_headdim(config, 16) # specific to triton |
| |
| if model_class._can_set_attn_implementation(): |
| model = model_class(config).to(device=torch_device) |
|
|