import itertools
from typing import Callable, Optional, Union

import torch
import torch.nn.functional as F

from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
from transformers.utils.import_utils import (
    is_torch_flex_attn_available,
    is_torch_greater_or_equal,
    is_torchdynamo_compiling,
)


if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask, create_block_mask
else:
    # Fallback alias so that annotations and `isinstance` checks on `BlockMask` still work
    # when flex attention is not available
    BlockMask = torch.Tensor

_is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5")
_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6")

if _is_torch_greater_or_equal_than_2_6:
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex


def and_masks(*mask_functions: Callable) -> Callable:
    """Returns a mask function that is the intersection of the provided mask functions."""
    if not all(callable(arg) for arg in mask_functions):
        raise RuntimeError(f"All inputs should be callable mask_functions: {mask_functions}")

    def and_mask(batch_idx, head_idx, q_idx, kv_idx):
        result = q_idx.new_ones((), dtype=torch.bool)
        for mask in mask_functions:
            result = result & mask(batch_idx, head_idx, q_idx, kv_idx)
        return result

    return and_mask


def or_masks(*mask_functions: Callable) -> Callable:
    """Returns a mask function that is the union of the provided mask functions."""
    if not all(callable(arg) for arg in mask_functions):
        raise RuntimeError(f"All inputs should be callable mask_functions: {mask_functions}")

    def or_mask(batch_idx, head_idx, q_idx, kv_idx):
        result = q_idx.new_zeros((), dtype=torch.bool)
        for mask in mask_functions:
            result = result | mask(batch_idx, head_idx, q_idx, kv_idx)
        return result

    return or_mask


def causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
    """
    This creates a basic lower-diagonal causal mask.
    """
    return kv_idx <= q_idx


def sliding_window_overlay(sliding_window: int) -> Callable:
    """
    This is an overlay depicting a sliding window pattern. Add it on top of a causal mask for a proper sliding
    window mask.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        return kv_idx > q_idx - sliding_window

    return inner_mask


def chunked_overlay(chunk_size: int) -> Callable:
    """
    This is an overlay depicting a chunked attention pattern. Add it on top of a causal mask for a proper chunked
    attention mask.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        return kv_idx // chunk_size == q_idx // chunk_size

    return inner_mask


def sliding_window_causal_mask_function(sliding_window: int) -> Callable:
    """
    This returns the mask function used to create a sliding window mask.
    """
    return and_masks(sliding_window_overlay(sliding_window), causal_mask_function)


def chunked_causal_mask_function(chunk_size: int) -> Callable:
    """
    This returns the mask function used to create a chunked attention mask.
    """
    return and_masks(chunked_overlay(chunk_size), causal_mask_function)


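# Illustrative sketch (not part of the public API): evaluating the composed mask functions
# at scalar tensor indices. With a window of 3, query 4 sees keys 2..4; with chunks of 3,
# query 4 only sees keys 3..4.
def _example_composed_mask_functions():
    sliding = sliding_window_causal_mask_function(3)
    chunked = chunked_causal_mask_function(3)
    q_idx, kv_idx = torch.tensor(4), torch.tensor(2)
    assert bool(sliding(0, 0, q_idx, kv_idx))  # key 2 is inside the window [2, 4]
    assert not bool(chunked(0, 0, q_idx, kv_idx))  # key 2 belongs to the previous chunk

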
def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
    """
    This returns the mask function corresponding to a 2D boolean padding mask.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        # The padding mask must already cover the largest `kv_idx` in its last dimension, as it
        # cannot be padded here (this function has to stay vmap- and compile-friendly)
        return padding_mask[batch_idx, kv_idx]

    return inner_mask


def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offset: int) -> Callable:
    """
    This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
    not start and end indices.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        return mask_function(batch_idx, head_idx, q_idx + q_offset, kv_idx + kv_offset)

    return inner_mask


def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
    """
    Used to vmap our mask_functions over the q_idx and kv_idx dimensions of the inputs. Optionally, vmap over
    the batch and head indices as well if `bh_indices=True`.
    Using vmap here allows us to keep the performance of vectorized ops, while having a single set of primitive
    functions between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different).

    Args:
        mask_function (`Callable`):
            The mask_function to vmap.
        bh_indices (`bool`, optional):
            Whether to vmap over the batch and head indices as well, or only q and kv indices.

    Returns:
        Callable: The vmapped function.
    """
    # We vmap the function 2 times, broadcasting the [q_idx, kv_idx] dimensions
    dimensions = [(None, None, None, 0), (None, None, 0, None)]
    if bh_indices:
        # We extend broadcasting over the [batch_idx, head_idx] dimensions as well
        dimensions.extend([(None, 0, None, None), (0, None, None, None)])

    for dims in dimensions:
        mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
    return mask_function


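# Illustrative sketch (not part of the public API): materializing a small dense causal mask
# with the vmapped helper, skipping the batch and head dimensions.
def _example_vmap_for_bhqkv():
    mask_fn = _vmap_for_bhqkv(causal_mask_function, bh_indices=False)
    dense = mask_fn(None, None, torch.arange(5), torch.arange(5))
    assert dense.shape == (5, 5)
    assert bool(dense[1, 0]) and not bool(dense[0, 1])  # lower-triangular

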
def prepare_padding_mask(
    attention_mask: Optional[torch.Tensor], kv_length: int, kv_offset: int, _slice: bool = True
) -> Optional[torch.Tensor]:
    """
    From the 2D attention mask, prepare the correct padding mask to use, by potentially padding it, and by
    slicing according to the `kv_offset` if `_slice` is `True`.
    """
    local_padding_mask = attention_mask
    if attention_mask is not None:
        # Pad it if necessary
        if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0:
            local_padding_mask = torch.nn.functional.pad(attention_mask, (0, padding_length))
        if _slice:
            # Equivalent to `local_padding_mask[:, kv_offset : kv_offset + kv_length]`, but without
            # data-dependent slicing (i.e. torch.compile friendly)
            mask_indices = torch.arange(kv_length, device=local_padding_mask.device)
            mask_indices += kv_offset
            local_padding_mask = local_padding_mask[:, mask_indices]
    return local_padding_mask


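# Illustrative sketch (not part of the public API): padding then slicing a 2D mask for a
# window of 3 key-value positions starting at offset 2.
def _example_prepare_padding_mask():
    attention_mask = torch.tensor([[True, True, True, False]])
    sliced = prepare_padding_mask(attention_mask, kv_length=3, kv_offset=2)
    # Position 2 is a real token, position 3 is padding, and position 4 comes from
    # right-padding the too-short 2D mask with False
    assert sliced.tolist() == [[True, False, False]]

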
def _ignore_causal_mask_sdpa(
    padding_mask: Optional[torch.Tensor],
    query_length: int,
    kv_length: int,
    kv_offset: int,
    local_attention_size: Optional[int] = None,
) -> bool:
    """
    Detects whether the causal mask can be ignored when PyTorch's SDPA is used, relying instead on SDPA's
    `is_causal` argument.

    In case no token is masked in the 2D `padding_mask` argument, if `query_length == 1` or
    `kv_length == query_length`, we rely instead on SDPA's `is_causal` argument to use causal/non-causal masks,
    allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
    passed).
    """
    is_tracing = torch.jit.is_tracing() or isinstance(padding_mask, torch.fx.Proxy) or is_torchdynamo_compiling()
    if padding_mask is not None and padding_mask.shape[-1] > kv_length:
        mask_indices = torch.arange(kv_length, device=padding_mask.device)
        mask_indices += kv_offset
        padding_mask = padding_mask[:, mask_indices]

    # While tracing, the decision to skip would be baked into the traced graph, which is in general
    # incorrect for other inputs, so we never skip in that case
    if (
        not is_tracing
        # Only the cases where the lower and upper diagonals coincide can use `is_causal`
        and (query_length == 1 or kv_length == query_length)
        # If local (sliding/chunked) attention kicks in, the pattern is not purely causal, so it cannot be skipped
        and (local_attention_size is None or kv_length < local_attention_size)
        # If any token is padded, the padding pattern must be added to the mask, so it cannot be skipped
        and (padding_mask is None or padding_mask.all())
    ):
        return True

    return False


def sdpa_mask_recent_torch(
    batch_size: int,
    cache_position: torch.Tensor,
    kv_length: int,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: Optional[torch.Tensor] = None,
    local_size: Optional[int] = None,
    allow_is_causal_skip: bool = True,
    **kwargs,
) -> Optional[torch.Tensor]:
    """
    Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
    the element should take part in the attention computation, and False that it should not.
    This function can only be used with torch>=2.6, as the `TransformGetItemToIndex` context manager is otherwise
    not available.

    Args:
        batch_size (`int`):
            The batch size of the input sequence.
        cache_position (`torch.Tensor`):
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        kv_offset (`int`, optional):
            An optional offset to indicate at which first position the key and value states will refer to.
        mask_function (`Callable`):
            The mask factory function describing the mask pattern.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
        local_size (`int`, optional):
            The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
            to try to skip mask creation if possible.
        allow_is_causal_skip (`bool`, optional):
            Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
            `torch.sdpa` instead. Defaults to `True`.


    ## Creating a simple causal mask:

    To create the following causal mask:

        0 ■ ⬚ ⬚ ⬚ ⬚
        1 ■ ■ ⬚ ⬚ ⬚
        2 ■ ■ ■ ⬚ ⬚
        3 ■ ■ ■ ■ ⬚
        4 ■ ■ ■ ■ ■

    You can do

    ```python
    >>> sdpa_mask_recent_torch(batch_size=1, cache_position=torch.arange(5), kv_length=5, allow_is_causal_skip=False)
    tensor([[[[ True, False, False, False, False],
              [ True,  True, False, False, False],
              [ True,  True,  True, False, False],
              [ True,  True,  True,  True, False],
              [ True,  True,  True,  True,  True]]]])
    ```

    ## Creating a sliding window mask:

    To create the following sliding window mask (`sliding_window=3`):

        0 ■ ⬚ ⬚ ⬚ ⬚
        1 ■ ■ ⬚ ⬚ ⬚
        2 ■ ■ ■ ⬚ ⬚
        3 ⬚ ■ ■ ■ ⬚
        4 ⬚ ⬚ ■ ■ ■

    You can do

    ```python
    >>> sdpa_mask_recent_torch(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=sliding_window_causal_mask_function(3), allow_is_causal_skip=False)
    tensor([[[[ True, False, False, False, False],
              [ True,  True, False, False, False],
              [ True,  True,  True, False, False],
              [False,  True,  True,  True, False],
              [False, False,  True,  True,  True]]]])
    ```

    ## Creating a chunked attention mask

    To create the following chunked attention mask (`chunk_size=3`):

        0 ■ ⬚ ⬚ ⬚ ⬚
        1 ■ ■ ⬚ ⬚ ⬚
        2 ■ ■ ■ ⬚ ⬚
        3 ⬚ ⬚ ⬚ ■ ⬚
        4 ⬚ ⬚ ⬚ ■ ■

    You can do

    ```python
    >>> sdpa_mask_recent_torch(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=chunked_causal_mask_function(3), allow_is_causal_skip=False)
    tensor([[[[ True, False, False, False, False],
              [ True,  True, False, False, False],
              [ True,  True,  True, False, False],
              [False, False, False,  True, False],
              [False, False, False,  True,  True]]]])
    ```

    """
    q_length = cache_position.shape[0]
    # Potentially pad the 2D mask (do not slice it, the offset is handled by the mask function instead)
    padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)

    # Under specific conditions, we can avoid materializing the mask, instead relying on SDPA's `is_causal` argument
    if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size):
        return None

    # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`,
    # but without data-dependent slicing (i.e. torch.compile friendly)
    kv_arange = torch.arange(kv_length, device=cache_position.device)
    kv_arange += kv_offset

    # Potentially add the padding 2D mask
    if padding_mask is not None:
        mask_function = and_masks(mask_function, padding_mask_function(padding_mask))

    batch_arange = torch.arange(batch_size, device=cache_position.device)
    # A size-1 head dimension is enough: the mask is broadcast over heads during attention
    head_arange = torch.arange(1, device=cache_position.device)
    # Indexing a tensor from a scalar tensor inside vmap used to call `.item()`, which vmap does not support.
    # Recent torch versions provide the `TransformGetItemToIndex` context manager precisely to handle this pattern
    with TransformGetItemToIndex():
        causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)

    return causal_mask


def sdpa_mask_older_torch(
    batch_size: int,
    cache_position: torch.Tensor,
    kv_length: int,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: Optional[torch.Tensor] = None,
    local_size: Optional[int] = None,
    allow_is_causal_skip: bool = True,
    allow_torch_fix: bool = True,
    **kwargs,
) -> Optional[torch.Tensor]:
    """
    NOTE: This function is only used when torch version is torch<2.6 - see `sdpa_mask_recent_torch` otherwise.

    Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
    the element should take part in the attention computation, and False that it should not.
    If `allow_torch_fix=True` (the default), rows corresponding to query tokens that do not attend
    to any other tokens (due to padding) will be fully attended to instead, in order to avoid `nan` propagation (this does
    not change the final result).

    Args:
        batch_size (`int`):
            The batch size of the input sequence.
        cache_position (`torch.Tensor`):
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        kv_offset (`int`, optional):
            An optional offset to indicate at which first position the key and value states will refer to.
        mask_function (`Callable`):
            The mask factory function describing the mask pattern.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
        local_size (`int`, optional):
            The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
            to try to skip mask creation if possible.
        allow_is_causal_skip (`bool`, optional):
            Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
            `torch.sdpa` instead. Defaults to `True`.
        allow_torch_fix (`bool`, optional):
            Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
            versions. We need an arg to skip it when using eager. By default `True`.
    """
    q_length = cache_position.shape[0]
    # Potentially pad the 2D mask, and slice it correctly
    padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)

    # Under specific conditions, we can avoid materializing the mask, instead relying on SDPA's `is_causal` argument
    if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size):
        return None

    # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`,
    # but without data-dependent slicing (i.e. torch.compile friendly)
    kv_arange = torch.arange(kv_length, device=cache_position.device)
    kv_arange += kv_offset

    # This older path cannot vmap mask functions that index tensors (such as the padding mask function), so the
    # pattern mask is created first, then expanded and combined with the padding mask by broadcasting
    causal_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
    if padding_mask is not None:
        causal_mask = causal_mask * padding_mask[:, None, None, :]

    # Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any
    # tokens (due to padding), otherwise SDPA would propagate `nan` values
    if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
        causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
    return causal_mask


# The recent version is preferred whenever possible, as it can handle arbitrary mask functions, including ones
# that index a tensor (such as the padding mask function)
sdpa_mask = sdpa_mask_recent_torch if _is_torch_greater_or_equal_than_2_6 else sdpa_mask_older_torch


def eager_mask(
    batch_size: int,
    cache_position: torch.Tensor,
    kv_length: int,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: Optional[torch.Tensor] = None,
    dtype: torch.dtype = torch.float32,
    **kwargs,
) -> torch.Tensor:
    """
    Create a 4D float mask of shape `(batch_size, 1, query_length, kv_length)` where a value of 0 indicates that
    the element should take part in the attention computation, and -inf (minimum value for the given `dtype`) that
    it should not.

    Args:
        batch_size (`int`):
            The batch size of the input sequence.
        cache_position (`torch.Tensor`):
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        kv_offset (`int`, optional):
            An optional offset to indicate at which first position the key and value states will refer to.
        mask_function (`Callable`):
            The mask factory function describing the mask pattern.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
        dtype (`torch.dtype`, optional):
            The dtype to use for the mask. By default, `torch.float32`.
    """
    # The mask for eager attention is simply the boolean mask from sdpa, cast to 0 and -inf
    _ = kwargs.pop("allow_is_causal_skip", None)
    mask = sdpa_mask(
        batch_size=batch_size,
        cache_position=cache_position,
        kv_length=kv_length,
        kv_offset=kv_offset,
        mask_function=mask_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=False,
        allow_torch_fix=False,
        **kwargs,
    )
    min_dtype = torch.finfo(dtype).min
    # We need 0s where the tokens should be taken into account, and -inf otherwise (the mask is of boolean type here)
    mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
    return mask


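# Illustrative sketch (not part of the public API): the eager mask is additive,
# 0.0 on allowed positions and the dtype minimum elsewhere.
def _example_eager_mask():
    mask = eager_mask(batch_size=1, cache_position=torch.arange(3), kv_length=3)
    assert mask[0, 0, 1, 0] == 0.0
    assert mask[0, 0, 0, 1] == torch.finfo(torch.float32).min

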
def flash_attention_mask(
    batch_size: int,
    cache_position: torch.Tensor,
    kv_length: int,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    """
    Create the attention mask necessary to use FA2. Since FA2 is un-padded by definition, here we simply return
    `None` if the mask is fully causal, or we return the 2D mask which will then be used to extract the seq_lens.
    We just slice it in case of sliding window.

    Args:
        batch_size (`int`):
            The batch size of the input sequence.
        cache_position (`torch.Tensor`):
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        kv_offset (`int`, optional):
            An optional offset to indicate at which first position the key and value states will refer to.
        mask_function (`Callable`):
            The mask factory function describing the mask pattern.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
    """
    if attention_mask is not None:
        # Slice from the right, keeping at most `kv_length` positions (a no-op for full attention)
        attention_mask = attention_mask[:, -kv_length:]
        # We only return an actual mask if there is at least 1 padding token, otherwise we return `None`
        # and rely on `is_causal` inside FA2 (note that the attention_mask is a boolean dtype here)
        if attention_mask.all():
            attention_mask = None

    return attention_mask


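# Illustrative sketch (not part of the public API): without any padding token the function
# returns `None`, letting FA2 rely on its internal causal handling.
def _example_flash_attention_mask():
    attention_mask = torch.ones(2, 5, dtype=torch.bool)
    out = flash_attention_mask(batch_size=2, cache_position=torch.arange(5), kv_length=5, attention_mask=attention_mask)
    assert out is None

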
def flex_attention_mask(
    batch_size: int,
    cache_position: torch.Tensor,
    kv_length: int,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> BlockMask:
    """
    Create a 4D block mask which is a compressed representation of the full 4D block causal mask. BlockMask is essential
    for performant computation of flex attention. See: https://pytorch.org/blog/flexattention/

    Args:
        batch_size (`int`):
            The batch size of the input sequence.
        cache_position (`torch.Tensor`):
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        kv_offset (`int`, optional):
            An optional offset to indicate at which first position the key and value states will refer to.
        mask_function (`Callable`):
            The mask factory function describing the mask pattern.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
    """
    q_length, q_offset = cache_position.shape[0], cache_position[0]

    # Potentially add the padding 2D mask
    if attention_mask is not None:
        padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
        mask_function = and_masks(mask_function, padding_mask_function(padding_mask))

    # Add the offsets on top (the flex API only accepts lengths, not start and end indices)
    mask_function = add_offsets_to_mask_function(mask_function, q_offset, kv_offset)

    # Finally create the block mask
    block_mask = create_block_mask(
        mask_mod=mask_function,
        B=batch_size,
        H=None,
        Q_LEN=q_length,
        KV_LEN=kv_length,
        device=cache_position.device,
        _compile=True,
    )
    return block_mask


class AttentionMaskInterface:
    # Dict-like registry mapping an `attn_implementation` name to its mask-creation function. The mapping is
    # a class attribute, so that registering a new function is reflected everywhere, even if a new instance is created
    _global_mapping = {
        "sdpa": sdpa_mask,
        "eager": eager_mask,
        "flash_attention_2": flash_attention_mask,
        "flex_attention": flex_attention_mask,
    }

    def __getitem__(self, key: str) -> Callable:
        return self._global_mapping[key]

    def __contains__(self, key: str) -> bool:
        return key in self._global_mapping


# Global AttentionMaskInterface shared by all models which do not need to overwrite any of the existing ones
ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()


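# Illustrative sketch (not part of the public API): the registry is indexed by the
# `attn_implementation` string stored in the model config.
def _example_mask_registry():
    assert ALL_MASK_ATTENTION_FUNCTIONS["eager"] is eager_mask
    assert "sdpa" in ALL_MASK_ATTENTION_FUNCTIONS

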
def _preprocess_mask_arguments(
    config: PretrainedConfig,
    input_embeds: torch.Tensor,
    attention_mask: Optional[Union[torch.Tensor, BlockMask]],
    cache_position: torch.Tensor,
    past_key_values: Optional[Cache],
    layer_idx: Optional[int],
) -> tuple[bool, Optional[Union[torch.Tensor, BlockMask]], int, int]:
    """
    Perform some common pre-processing of the mask arguments we get from the modeling code. Mostly determine the
    key-value length and offsets, and if we should early exit or not.

    Args:
        config (`PretrainedConfig`):
            The model config.
        input_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length and dtype.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        cache_position (`torch.Tensor`):
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        layer_idx (`int`, optional):
            If `past_key_values` is not None, this is the layer index of the cache from which to get the key-value
            length and offset. Indeed, for hybrid caches, different layers may return different lengths.

    Returns:
        early_exit (`bool`):
            Whether we should early exit mask creation, and return the mask as-is.
        attention_mask (`torch.Tensor` or `BlockMask` or `None`):
            The attention mask to either return immediately, or to use in downstream mask creation.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        kv_offset (`int`):
            An offset to indicate at which first position the key and values states will refer to.
    """
    # If the mask is already 4D (or a BlockMask), it was already prepared upstream: return it as-is
    if isinstance(attention_mask, (torch.Tensor, BlockMask)) and len(attention_mask.shape) == 4:
        return True, attention_mask, None, None

    # If the attention implementation is not registered (e.g. a custom kernel handling the mask internally),
    # there is no mask to prepare here. Note: checking the `_global_mapping` dict directly, rather than `in` on
    # the instance, keeps this check friendly to full-graph dynamo tracing on older Python versions
    if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping:
        return True, None, None, None

    # Move the mask to the correct device, and use boolean dtype for efficiency
    if attention_mask is not None and attention_mask.ndim == 2:
        attention_mask = attention_mask.to(device=cache_position.device, dtype=torch.bool)

    # If using a cache, it can give all the information about mask sizes based on seen tokens
    if past_key_values is not None:
        kv_length, kv_offset = past_key_values.get_mask_sizes(cache_position, layer_idx)
    # Otherwise, the sizes are simply the input sizes
    else:
        kv_length, kv_offset = input_embeds.shape[1], 0

    return False, attention_mask, kv_length, kv_offset


def create_causal_mask(
    config: PretrainedConfig,
    input_embeds: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    cache_position: torch.Tensor,
    past_key_values: Optional[Cache],
    or_mask_function: Optional[Callable] = None,
    and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
    """
    Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
    has a HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
    to what is needed in the `modeling_xxx.py` files).

    Args:
        config (`PretrainedConfig`):
            The model config.
        input_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length and dtype.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        cache_position (`torch.Tensor`):
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the causal mask function (by doing the union of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
    """
    # If we have a HybridCache structure, here we want to create the mask for the full layers
    if hasattr(past_key_values, "is_sliding") and False in past_key_values.is_sliding:
        layer_idx = past_key_values.is_sliding.index(False)
    else:
        layer_idx = 0

    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
    )
    if early_exit:
        return attention_mask

    batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
    mask_factory_function = causal_mask_function
    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]

    # If the cache is compileable, do not skip mask creation: returning `None` based on input values
    # would change the compiled graph
    allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True

    # Allow slight deviations from the causal mask
    if or_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError("Using `or_mask_function` or `and_mask_function` arguments requires torch>=2.6")
        mask_factory_function = or_masks(mask_factory_function, or_mask_function)
        allow_is_causal_skip = False
    if and_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError("Using `or_mask_function` or `and_mask_function` arguments requires torch>=2.6")
        mask_factory_function = and_masks(mask_factory_function, and_mask_function)
        allow_is_causal_skip = False

    # We now create the mask
    causal_mask = mask_interface(
        batch_size=batch_size,
        cache_position=cache_position,
        kv_length=kv_length,
        kv_offset=kv_offset,
        mask_function=mask_factory_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=allow_is_causal_skip,  # additional kwarg for sdpa
        dtype=dtype,  # additional kwarg for eager
        config=config,  # forwarded for mask interfaces that may need it
    )
    return causal_mask


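# Illustrative sketch (not part of the public API): with no cache, no padding and
# q_length == kv_length, the SDPA backend skips mask creation entirely.
def _example_create_causal_mask():
    config = PretrainedConfig()
    config._attn_implementation = "sdpa"
    input_embeds = torch.zeros(1, 4, 8)
    mask = create_causal_mask(config, input_embeds, None, torch.arange(4), None)
    assert mask is None  # SDPA will use `is_causal=True` instead

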
def create_sliding_window_causal_mask(
    config: PretrainedConfig,
    input_embeds: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    cache_position: torch.Tensor,
    past_key_values: Optional[Cache],
    or_mask_function: Optional[Callable] = None,
    and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
    """
    Create a sliding window causal mask based on the attention implementation used (stored in the config). This type
    of attention pattern was mostly democratized by Mistral. If `past_key_values` has a HybridCache structure, this
    function will return the mask corresponding to one of the "sliding_attention" layers (to align to what is needed in the
    `modeling_xxx.py` files).

    Args:
        config (`PretrainedConfig`):
            The model config.
        input_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length and dtype.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        cache_position (`torch.Tensor`):
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is
            useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function to combine with the sliding causal mask function (by doing the intersection of both). This is
            useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
    """
    # If we have a HybridCache structure, here we want to create the mask for the sliding layers
    if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
        layer_idx = past_key_values.is_sliding.index(True)
    else:
        layer_idx = 0

    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
    )
    if early_exit:
        return attention_mask

    sliding_window = getattr(config, "sliding_window", None)
    if sliding_window is None:
        raise ValueError("Could not find a `sliding_window` argument in the config, or it is not set")

    batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
    mask_factory_function = sliding_window_causal_mask_function(sliding_window)
    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]

    # If the cache is compileable, do not skip mask creation: returning `None` based on input values
    # would change the compiled graph
    allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True

    # Allow slight deviations from the sliding causal mask
    if or_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError("Using `or_mask_function` or `and_mask_function` arguments requires torch>=2.6")
        mask_factory_function = or_masks(mask_factory_function, or_mask_function)
        allow_is_causal_skip = False
    if and_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError("Using `or_mask_function` or `and_mask_function` arguments requires torch>=2.6")
        mask_factory_function = and_masks(mask_factory_function, and_mask_function)
        allow_is_causal_skip = False

    # We now create the mask
    causal_mask = mask_interface(
        batch_size=batch_size,
        cache_position=cache_position,
        kv_length=kv_length,
        kv_offset=kv_offset,
        mask_function=mask_factory_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=allow_is_causal_skip,  # additional kwarg for sdpa
        local_size=sliding_window,  # additional kwarg for sdpa, to potentially skip mask creation
        dtype=dtype,  # additional kwarg for eager
        config=config,  # forwarded for mask interfaces that may need it
    )
    return causal_mask


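# Illustrative sketch (not part of the public API), assuming a config carrying a
# `sliding_window` attribute and using the eager implementation.
def _example_create_sliding_window_causal_mask():
    config = PretrainedConfig(sliding_window=2)
    config._attn_implementation = "eager"
    mask = create_sliding_window_causal_mask(config, torch.zeros(1, 4, 8), None, torch.arange(4), None)
    # Query 3 only attends to keys 2 and 3 with a window of 2
    assert (mask[0, 0, 3] == 0).tolist() == [False, False, True, True]

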
def create_chunked_causal_mask(
    config: PretrainedConfig,
    input_embeds: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    cache_position: torch.Tensor,
    past_key_values: Optional[Cache],
    or_mask_function: Optional[Callable] = None,
    and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
    """
    Create a chunked attention causal mask based on the attention implementation used (stored in the config). This type
    of attention pattern was mostly democratized by Llama4. If `past_key_values` has a HybridCache structure, this
    function will return the mask corresponding to one of the "chunked_attention" layers (to align to what is needed in the
    `modeling_xxx.py` files).

    Args:
        config (`PretrainedConfig`):
            The model config.
        input_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length and dtype.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        cache_position (`torch.Tensor`):
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is
            useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function to combine with the chunked causal mask function (by doing the intersection of both). This is
            useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
    """
    # If we have a HybridCache structure, here we want to create the mask for the sliding (chunked) layers
    if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
        layer_idx = past_key_values.is_sliding.index(True)
    else:
        layer_idx = 0

    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
    )
    if early_exit:
        return attention_mask

    chunk_size = getattr(config, "attention_chunk_size", None)
    if chunk_size is None:
        raise ValueError("Could not find an `attention_chunk_size` argument in the config, or it is not set")

    # Raise a clear error early if FA2 cannot respect the chunked pattern
    if config._attn_implementation == "flash_attention_2" and kv_length + kv_offset > chunk_size:
        raise ValueError(
            "Flash attention 2 cannot handle chunked attention, and the key-value length is larger than the chunk size so the "
            "chunked pattern cannot be respected. You should use another `attn_implementation` when instantiating the model"
        )

    batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
    mask_factory_function = chunked_causal_mask_function(chunk_size)
    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]

    # If the cache is compileable, do not skip mask creation: returning `None` based on input values
    # would change the compiled graph
    allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True

    # Allow slight deviations from the chunked causal mask
    if or_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError("Using `or_mask_function` or `and_mask_function` arguments requires torch>=2.6")
        mask_factory_function = or_masks(mask_factory_function, or_mask_function)
        allow_is_causal_skip = False
    if and_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError("Using `or_mask_function` or `and_mask_function` arguments requires torch>=2.6")
        mask_factory_function = and_masks(mask_factory_function, and_mask_function)
        allow_is_causal_skip = False

    # We now create the mask
    causal_mask = mask_interface(
        batch_size=batch_size,
        cache_position=cache_position,
        kv_length=kv_length,
        kv_offset=kv_offset,
        mask_function=mask_factory_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=allow_is_causal_skip,  # additional kwarg for sdpa
        local_size=chunk_size,  # additional kwarg for sdpa, to potentially skip mask creation
        dtype=dtype,  # additional kwarg for eager
        config=config,  # forwarded for mask interfaces that may need it
    )
    return causal_mask


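# Illustrative sketch (not part of the public API), assuming a config carrying an
# `attention_chunk_size` attribute.
def _example_create_chunked_causal_mask():
    config = PretrainedConfig(attention_chunk_size=2)
    config._attn_implementation = "eager"
    mask = create_chunked_causal_mask(config, torch.zeros(1, 4, 8), None, torch.arange(4), None)
    # Query 2 opens a new chunk, so it does not see queries 0 and 1
    assert (mask[0, 0, 2] == 0).tolist() == [False, False, True, False]

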
LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING = {
    "full_attention": create_causal_mask,
    "sliding_attention": create_sliding_window_causal_mask,
    "chunked_attention": create_chunked_causal_mask,
}


def create_masks_for_generate(
    config: PretrainedConfig,
    input_embeds: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    cache_position: torch.Tensor,
    past_key_values: Optional[Cache],
    or_mask_function: Optional[Callable] = None,
    and_mask_function: Optional[Callable] = None,
    **kwargs,
):
    """
    This function mimics how we create the masks in the `modeling_xxx.py` files, and is used in `generate` in order
    to easily create the masks in advance, when we compile the forwards with Static caches.

    Args:
        config (`PretrainedConfig`):
            The model config.
        input_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length and dtype.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        cache_position (`torch.Tensor`):
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the other mask function (by doing the union of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function to combine with the other mask function (by doing the intersection of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
    """
    # If the config is composite, the masks are created based on the text sub-config
    effective_config = config.get_text_config()
    # Prepare the mask arguments shared by all mask-creation functions
    mask_kwargs = {
        "config": effective_config,
        "input_embeds": input_embeds,
        "attention_mask": attention_mask,
        "cache_position": cache_position,
        "past_key_values": past_key_values,
        "or_mask_function": or_mask_function,
        "and_mask_function": and_mask_function,
    }

    # If the model uses a different mask pattern per layer, create one mask per distinct pattern
    if hasattr(effective_config, "layer_types"):
        causal_masks = {}
        for layer_pattern in set(effective_config.layer_types):
            causal_masks[layer_pattern] = LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING[layer_pattern](**mask_kwargs)
        return causal_masks
    # Sliding window is used by all layers
    elif getattr(effective_config, "sliding_window", None) is not None:
        return create_sliding_window_causal_mask(**mask_kwargs)
    # Chunked attention is used by all layers
    elif getattr(effective_config, "attention_chunk_size", None) is not None:
        return create_chunked_causal_mask(**mask_kwargs)
    # Full attention everywhere
    return create_causal_mask(**mask_kwargs)


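# Illustrative sketch (not part of the public API): a config with `layer_types` gets one
# mask per distinct layer pattern.
def _example_create_masks_for_generate():
    config = PretrainedConfig(layer_types=["full_attention", "sliding_attention"], sliding_window=2)
    config._attn_implementation = "sdpa"
    masks = create_masks_for_generate(config, torch.zeros(1, 4, 8), None, torch.arange(4), None)
    assert set(masks) == {"full_attention", "sliding_attention"}

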
# Below: small utilities to visualize attention masks as text
GREEN = "\033[92m"
YELLOW = "\033[93m"
RESET = "\033[0m"
BLACK_SQUARE = "■"
WHITE_SQUARE = "⬚"
GREY_SQUARE = "∙"
LOW_TRIANGLE = "⬕"
UPPER_TRIANGLE = "⬔"


def get_style(style):
    if style == "majong":
        BLACK_SQUARE = "🀙"
        WHITE_SQUARE = "🀆"
        # The same tile is used for both partial-coverage glyphs in this style
        LOW_TRIANGLE = "🀛"
        UPPER_TRIANGLE = "🀛"
    else:
        BLACK_SQUARE = "█"
        WHITE_SQUARE = "░"
        LOW_TRIANGLE = "▙"
        UPPER_TRIANGLE = "▜"

    return BLACK_SQUARE, WHITE_SQUARE, LOW_TRIANGLE, UPPER_TRIANGLE


# Colored variants of the default square glyph
YELLOW_SQUARE = f"{YELLOW}{BLACK_SQUARE}{RESET}"
GREEN_SQUARE = f"{GREEN}{BLACK_SQUARE}{RESET}"


def tensor_to_mask_visual(original_tensor: torch.Tensor, grid_size=(20, 40), style="majong") -> str:
    BLACK_SQUARE, WHITE_SQUARE, LOW_TRIANGLE, UPPER_TRIANGLE = get_style(style)
    h, w = original_tensor.shape
    max_h, max_w = grid_size
    if not (h < max_h and w < max_w):
        # Preserve the aspect ratio; the factor 2 compensates for characters being taller than wide
        aspect_ratio = 2 * w / h
        if aspect_ratio > 1:
            w = max_w
            h = min(max_h, max(1, round(max_w / aspect_ratio)))
        else:
            h = max_h
            w = max(1, round(max_h * aspect_ratio))

        # Downsample to the target grid (as float, which pooling requires); fractional values
        # then mark partially-attended regions
        tensor = original_tensor.float().unsqueeze(0).unsqueeze(0)
        tensor = F.adaptive_avg_pool2d(tensor, output_size=(h, w))[0, 0]
    else:
        tensor = original_tensor

    # Render row by row: full squares for 0/1 cells, triangles at the boundaries of fractional regions
    result = []
    for i in range(h):
        row = ""
        for j in range(w):
            if tensor[i, j] == 1:
                row += BLACK_SQUARE
            elif tensor[i, j] == 0:
                row += WHITE_SQUARE
            else:
                if j > 0:
                    if tensor[i, j - 1] == 1:
                        row += LOW_TRIANGLE
                    elif tensor[i, j - 1] == 0:
                        row += UPPER_TRIANGLE
                    else:
                        row += BLACK_SQUARE if tensor[i, j] == 1 else WHITE_SQUARE
                else:
                    row += (
                        BLACK_SQUARE
                        if tensor[i, j] == 1
                        else (
                            WHITE_SQUARE
                            if tensor[i, j] == 0
                            else (UPPER_TRIANGLE if j + 1 < w and tensor[i, j + 1] == 1 else LOW_TRIANGLE)
                        )
                    )
        result.append(row)

    return "\n".join(result)


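# Illustrative sketch (not part of the public API): rendering a small causal mask with the
# block-character style (any style other than "majong" selects it).
def _example_tensor_to_mask_visual():
    mask = torch.tril(torch.ones(4, 4))
    print(tensor_to_mask_visual(mask, grid_size=(8, 16), style="blocks"))

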
class AttentionMask(torch.Tensor):
    def __new__(cls, data, style=None):
        # Create the subclass tensor without gradient tracking, and store the style on the instance
        # (not on the class, which would leak the style between instances)
        res = torch.Tensor._make_subclass(cls, data, require_grad=False)
        res.style = style
        return res

    def __init__(self, data, style=None):
        # All the work is done in `__new__`
        pass

    def to_string(self, grid_size=(20, 40), limit=4):
        """Returns a string representation of the mask."""
        dense_mask = self
        *batch_dims, num_rows, num_cols = dense_mask.shape
        total_vis = []

        for idx, batch_idx in enumerate(itertools.product(*[range(i) for i in batch_dims])):
            if idx == limit:
                total_vis.append("...")
                total_vis.append("To print out more, call AttentionMask.to_string(limit=N)")
                total_vis.append("You can also index (AttentionMask[batch, head]) to choose a specific batch or head")
                break
            block_vis = tensor_to_mask_visual(dense_mask[batch_idx], grid_size=grid_size, style=self.style)
            total_vis.append(block_vis)

        total_vis.append(f"torch.Tensor(shape={tuple(self.shape)}, dtype={self.dtype})")
        return "\n".join(total_vis)

    def __repr__(self):
        return self.to_string()

    def __str__(self):
        return self.to_string()

    @classmethod
    def from_tensor(cls, tensor: torch.Tensor, style: Optional[str] = None) -> "AttentionMask":
        return cls(tensor, style=style)


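# Illustrative sketch (not part of the public API): wrapping a dense mask to get a
# human-readable visualization from `print`.
def _example_attention_mask_wrapper():
    dense = torch.tril(torch.ones(1, 1, 6, 6))
    print(AttentionMask.from_tensor(dense, style="majong"))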