from typing import Optional

import torch

from ..modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
from ..utils import logging


logger = logging.get_logger(__name__)
|
|
# flash-attn < 2.1 generates a top-left aligned causal mask, while newer versions
# align it bottom-right; probe the installed version once at import time.
_use_top_left_mask = flash_attn_supports_top_left_mask()


def flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
        logger.warning_once(
            "`flash_attention_2` does not support `output_attentions=True` or `head_mask`."
            " Please set your attention to `eager` if you want any of these features."
        )

    # The query still has (batch, num_heads, seq_len, head_dim) layout at this
    # point, so the sequence length is dimension 2.
    seq_len = query.shape[2]

    if any(dim == 0 for dim in query.shape):
        raise ValueError(
            "Tensor query has shape with a zero dimension.\n"
            "FlashAttention does not support inputs with dim=0.\n"
            "Please check your input shapes or use SDPA instead."
        )

    # FlashAttention expects inputs in (batch, seq_len, num_heads, head_dim) layout.
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    # FlashAttention kernels only run in fp16/bf16. Inputs may arrive in float32
    # when, e.g., layer norms are kept in float32 for training stability (as PEFT
    # commonly does) and silently upcast the hidden states; work out the dtype to
    # cast back to before calling the kernel.
    target_dtype = None
    if query.dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(module.config, "_pre_quantization_dtype"):
            target_dtype = module.config._pre_quantization_dtype
        # Otherwise fall back to the weight dtype of the module's first Linear layer
        else:
            target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
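    # For example, under `torch.autocast("cuda", dtype=torch.bfloat16)` the branch
    # above resolves `target_dtype` to torch.bfloat16; `_flash_attention_forward`
    # then casts query/key/value back to that dtype before calling the kernel.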
|
|
    # An explicit `is_causal` kwarg takes precedence; otherwise defer to the module.
    is_causal = kwargs.pop("is_causal", None)
    if is_causal is None:
        is_causal = module.is_causal

    # `_flash_attention_forward` does the heavy lifting: it unpads the inputs when
    # `attention_mask` contains padding, dispatches to the varlen kernels, and
    # re-pads the output.
    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        attention_mask,
        query_length=seq_len,
        is_causal=is_causal,
        dropout=dropout,
        softmax_scale=scaling,
        sliding_window=sliding_window,
        softcap=softcap,
        use_top_left_mask=_use_top_left_mask,
        target_dtype=target_dtype,
        attn_implementation=module.config._attn_implementation,
        layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
        **kwargs,
    )

    # FlashAttention never materializes the attention weights, so the second
    # element of the returned tuple is always None.
    return attn_output, None
|
|
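# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library). It shows
# the contract the `module` argument has to satisfy: `is_causal`, a `config`
# carrying `_attn_implementation`, optionally `layer_idx`, and, for the fp32
# fallback above, at least one `torch.nn.Linear` submodule. `DummyAttention`
# and its attributes are hypothetical names. Running it needs a CUDA device
# with flash-attn installed; assuming this file lives at
# transformers/integrations/flash_attention.py, invoke it as
# `python -m transformers.integrations.flash_attention` so the relative
# imports resolve.
if __name__ == "__main__":
    from types import SimpleNamespace

    class DummyAttention(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.is_causal = True
            self.config = SimpleNamespace(_attn_implementation="flash_attention_2")
            # Gives the float32 fallback a Linear layer whose dtype it can inspect.
            self.o_proj = torch.nn.Linear(8 * 64, 8 * 64)

    attn = DummyAttention().to("cuda", dtype=torch.float16)
    # Inputs use the pre-transpose layout: (batch, num_heads, seq_len, head_dim).
    q, k, v = (torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))

    out, _ = flash_attention_forward(attn, q, k, v, attention_mask=None, scaling=64**-0.5)
    print(out.shape)  # (batch, seq_len, num_heads, head_dim) -> torch.Size([1, 128, 8, 64])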