# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import contextlib import functools import inspect import math from dataclasses import dataclass from enum import Enum from typing import TYPE_CHECKING, Any, Callable import torch import torch.distributed as dist import torch.nn.functional as F if torch.distributed.is_available(): import torch.distributed._functional_collectives as funcol from ..utils import ( get_logger, is_aiter_available, is_aiter_version, is_flash_attn_3_available, is_flash_attn_available, is_flash_attn_version, is_kernels_available, is_kernels_version, is_sageattention_available, is_sageattention_version, is_torch_npu_available, is_torch_version, is_torch_xla_available, is_torch_xla_version, is_xformers_available, is_xformers_version, ) from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS from ..utils.torch_utils import maybe_allow_in_graph from ._modeling_parallel import gather_size_by_comm if TYPE_CHECKING: from ._modeling_parallel import ParallelConfig _REQUIRED_FLASH_VERSION = "2.6.3" _REQUIRED_AITER_VERSION = "0.1.5" _REQUIRED_SAGE_VERSION = "2.1.1" _REQUIRED_FLEX_VERSION = "2.5.0" _REQUIRED_XLA_VERSION = "2.2" _REQUIRED_XFORMERS_VERSION = "0.0.29" logger = get_logger(__name__) # pylint: disable=invalid-name _CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION) _CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available() _CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION) _CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION) _CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION) _CAN_USE_NPU_ATTN = is_torch_npu_available() _CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION) _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION) if _CAN_USE_FLASH_ATTN: try: from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward except (ImportError, OSError, RuntimeError) as e: # Handle ABI mismatch or other import failures gracefully. # This can happen when flash_attn was compiled against a different PyTorch version. logger.warning(f"flash_attn is installed but failed to import: {e}. 
Falling back to native PyTorch attention.") _CAN_USE_FLASH_ATTN = False flash_attn_func = None flash_attn_varlen_func = None _wrapped_flash_attn_backward = None _wrapped_flash_attn_forward = None else: flash_attn_func = None flash_attn_varlen_func = None _wrapped_flash_attn_backward = None _wrapped_flash_attn_forward = None if _CAN_USE_FLASH_ATTN_3: try: from flash_attn_interface import flash_attn_func as flash_attn_3_func from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func except (ImportError, OSError, RuntimeError) as e: logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.") _CAN_USE_FLASH_ATTN_3 = False flash_attn_3_func = None flash_attn_3_varlen_func = None else: flash_attn_3_func = None flash_attn_3_varlen_func = None if _CAN_USE_AITER_ATTN: try: from aiter import flash_attn_func as aiter_flash_attn_func except (ImportError, OSError, RuntimeError) as e: logger.warning(f"aiter failed to import: {e}. Falling back to native attention.") _CAN_USE_AITER_ATTN = False aiter_flash_attn_func = None else: aiter_flash_attn_func = None if _CAN_USE_SAGE_ATTN: try: from sageattention import ( sageattn, sageattn_qk_int8_pv_fp8_cuda, sageattn_qk_int8_pv_fp8_cuda_sm90, sageattn_qk_int8_pv_fp16_cuda, sageattn_qk_int8_pv_fp16_triton, sageattn_varlen, ) except (ImportError, OSError, RuntimeError) as e: logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.") _CAN_USE_SAGE_ATTN = False sageattn = None sageattn_qk_int8_pv_fp8_cuda = None sageattn_qk_int8_pv_fp8_cuda_sm90 = None sageattn_qk_int8_pv_fp16_cuda = None sageattn_qk_int8_pv_fp16_triton = None sageattn_varlen = None else: sageattn = None sageattn_qk_int8_pv_fp16_cuda = None sageattn_qk_int8_pv_fp16_triton = None sageattn_qk_int8_pv_fp8_cuda = None sageattn_qk_int8_pv_fp8_cuda_sm90 = None sageattn_varlen = None if _CAN_USE_FLEX_ATTN: try: # We cannot import the flex_attention function from the package directly because it is expected (from the # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the # compiled function. import torch.nn.attention.flex_attention as flex_attention except (ImportError, OSError, RuntimeError) as e: logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.") _CAN_USE_FLEX_ATTN = False flex_attention = None else: flex_attention = None if _CAN_USE_NPU_ATTN: try: from torch_npu import npu_fusion_attention except (ImportError, OSError, RuntimeError) as e: logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.") _CAN_USE_NPU_ATTN = False npu_fusion_attention = None else: npu_fusion_attention = None if _CAN_USE_XLA_ATTN: try: from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention except (ImportError, OSError, RuntimeError) as e: logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.") _CAN_USE_XLA_ATTN = False xla_flash_attention = None else: xla_flash_attention = None if _CAN_USE_XFORMERS_ATTN: try: import xformers.ops as xops except (ImportError, OSError, RuntimeError) as e: logger.warning(f"xformers failed to import: {e}. 
Falling back to native attention.") _CAN_USE_XFORMERS_ATTN = False xops = None else: xops = None # Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4 if torch.__version__ >= "2.4.0": _custom_op = torch.library.custom_op _register_fake = torch.library.register_fake else: def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None): def wrap(func): return func return wrap if fn is None else fn def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1): def wrap(func): return func return wrap if fn is None else fn _custom_op = custom_op_no_op _register_fake = register_fake_no_op # TODO(aryan): Add support for the following: # - Sage Attention++ # - block sparse, radial and other attention methods # - CP with sage attention, flex, xformers, other missing backends # - Add support for normal and CP training with backends that don't support it yet class AttentionBackendName(str, Enum): # EAGER = "eager" # `flash-attn` FLASH = "flash" FLASH_HUB = "flash_hub" FLASH_VARLEN = "flash_varlen" FLASH_VARLEN_HUB = "flash_varlen_hub" _FLASH_3 = "_flash_3" _FLASH_VARLEN_3 = "_flash_varlen_3" _FLASH_3_HUB = "_flash_3_hub" _FLASH_3_VARLEN_HUB = "_flash_3_varlen_hub" # `aiter` AITER = "aiter" # PyTorch native FLEX = "flex" NATIVE = "native" _NATIVE_CUDNN = "_native_cudnn" _NATIVE_EFFICIENT = "_native_efficient" _NATIVE_FLASH = "_native_flash" _NATIVE_MATH = "_native_math" _NATIVE_NPU = "_native_npu" _NATIVE_XLA = "_native_xla" # `sageattention` SAGE = "sage" SAGE_HUB = "sage_hub" SAGE_VARLEN = "sage_varlen" _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda" _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90" _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda" _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton" # TODO: let's not add support for Sparge Attention now because it requires tuning per model # We can look into supporting something "autotune"-ing in the future # SPARGE = "sparge" # `xformers` XFORMERS = "xformers" class _AttentionBackendRegistry: _backends = {} _constraints = {} _supported_arg_names = {} _supports_context_parallel = set() _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND) _checks_enabled = DIFFUSERS_ATTN_CHECKS @classmethod def register( cls, backend: AttentionBackendName, constraints: list[Callable] | None = None, supports_context_parallel: bool = False, ): logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}") def decorator(func): cls._backends[backend] = func cls._constraints[backend] = constraints or [] cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys()) if supports_context_parallel: cls._supports_context_parallel.add(backend.value) return func return decorator @classmethod def get_active_backend(cls): return cls._active_backend, cls._backends[cls._active_backend] @classmethod def set_active_backend(cls, backend: str): cls._active_backend = backend @classmethod def list_backends(cls): return list(cls._backends.keys()) @classmethod def _is_context_parallel_available( cls, backend: AttentionBackendName, ) -> bool: supports_context_parallel = backend.value in cls._supports_context_parallel return supports_context_parallel @dataclass class _HubKernelConfig: """Configuration for downloading and using a hub-based attention kernel.""" repo_id: str function_attr: str revision: str | None = None version: int | None = None kernel_fn: Callable | None = None wrapped_forward_attr: str | None = None wrapped_backward_attr: 
str | None = None wrapped_forward_fn: Callable | None = None wrapped_backward_fn: Callable | None = None # Registry for hub-based attention kernels _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = { AttentionBackendName._FLASH_3_HUB: _HubKernelConfig( repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", wrapped_forward_attr="flash_attn_interface._flash_attn_forward", wrapped_backward_attr="flash_attn_interface._flash_attn_backward", version=1, ), AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig( repo_id="kernels-community/flash-attn3", function_attr="flash_attn_varlen_func", version=1, ), AttentionBackendName.FLASH_HUB: _HubKernelConfig( repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward", wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward", version=1, ), AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig( repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", version=1, ), AttentionBackendName.SAGE_HUB: _HubKernelConfig( repo_id="kernels-community/sage-attention", function_attr="sageattn", version=1, ), } @contextlib.contextmanager def attention_backend(backend: str | AttentionBackendName = AttentionBackendName.NATIVE): """ Context manager to set the active attention backend. """ if backend not in _AttentionBackendRegistry._backends: raise ValueError(f"Backend {backend} is not registered.") backend = AttentionBackendName(backend) _check_attention_backend_requirements(backend) _maybe_download_kernel_for_backend(backend) old_backend = _AttentionBackendRegistry._active_backend _AttentionBackendRegistry.set_active_backend(backend) try: yield finally: _AttentionBackendRegistry.set_active_backend(old_backend) def dispatch_attention_fn( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, attention_kwargs: dict[str, Any] | None = None, *, backend: AttentionBackendName | None = None, parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: attention_kwargs = attention_kwargs or {} if backend is None: # If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND environment # variable), or we use a custom backend based on whether user is using the `attention_backend` context manager backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend() else: backend_name = AttentionBackendName(backend) backend_fn = _AttentionBackendRegistry._backends.get(backend_name) kwargs = { "query": query, "key": key, "value": value, "attn_mask": attn_mask, "dropout_p": dropout_p, "is_causal": is_causal, "scale": scale, **attention_kwargs, "_parallel_config": parallel_config, } if is_torch_version(">=", "2.5.0"): kwargs["enable_gqa"] = enable_gqa if _AttentionBackendRegistry._checks_enabled: removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name]) if removed_kwargs: logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.") for check in _AttentionBackendRegistry._constraints.get(backend_name): check(**kwargs) kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]} return backend_fn(**kwargs) # ===== Checks ===== # A list of very simple functions to catch common errors 
quickly when debugging. def _check_attn_mask_or_causal(attn_mask: torch.Tensor | None, is_causal: bool, **kwargs) -> None: if attn_mask is not None and is_causal: raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.") def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: if query.device != key.device or query.device != value.device: raise ValueError("Query, key, and value must be on the same device.") if query.dtype != key.dtype or query.dtype != value.dtype: raise ValueError("Query, key, and value must have the same dtype.") def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: _check_device(query, key, value) if query.device.type != "cuda": raise ValueError("Query, key, and value must be on a CUDA device.") def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable: def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: _check_device_cuda(query, key, value) if torch.cuda.get_device_capability(query.device) < (major, minor): raise ValueError( f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}." ) return check_device_cuda def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: if query.dtype != key.dtype: raise ValueError("Query and key must have the same dtype.") if query.dtype != value.dtype: raise ValueError("Query and value must have the same dtype.") def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: _check_qkv_dtype_match(query, key, value) if query.dtype not in (torch.bfloat16, torch.float16): raise ValueError("Query, key, and value must be either bfloat16 or float16.") def _check_shape( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, **kwargs, ) -> None: # Expected shapes: # query: (batch_size, seq_len_q, num_heads, head_dim) # key: (batch_size, seq_len_kv, num_heads, head_dim) # value: (batch_size, seq_len_kv, num_heads, head_dim) # attn_mask: (seq_len_q, seq_len_kv) or (batch_size, seq_len_q, seq_len_kv) # or (batch_size, num_heads, seq_len_q, seq_len_kv) if query.shape[-1] != key.shape[-1]: raise ValueError("Query and key must have the same head dimension.") if key.shape[-3] != value.shape[-3]: raise ValueError("Key and value must have the same sequence length.") if attn_mask is not None and attn_mask.shape[-1] != key.shape[-3]: raise ValueError("Attention mask must match the key's sequence length.") # ===== Helper functions ===== def _check_attention_backend_requirements(backend: AttentionBackendName) -> None: if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]: if not _CAN_USE_FLASH_ATTN: raise RuntimeError( f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`." ) elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]: if not _CAN_USE_FLASH_ATTN_3: raise RuntimeError( f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source." 
) elif backend in [ AttentionBackendName.FLASH_HUB, AttentionBackendName.FLASH_VARLEN_HUB, AttentionBackendName._FLASH_3_HUB, AttentionBackendName._FLASH_3_VARLEN_HUB, AttentionBackendName.SAGE_HUB, ]: if not is_kernels_available(): raise RuntimeError( f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`." ) if not is_kernels_version(">=", "0.12"): raise RuntimeError( f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`." ) elif backend == AttentionBackendName.AITER: if not _CAN_USE_AITER_ATTN: raise RuntimeError( f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`." ) elif backend in [ AttentionBackendName.SAGE, AttentionBackendName.SAGE_VARLEN, AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, ]: if not _CAN_USE_SAGE_ATTN: raise RuntimeError( f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`." ) elif backend == AttentionBackendName.FLEX: if not _CAN_USE_FLEX_ATTN: raise RuntimeError( f"Flex Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch>=2.5.0`." ) elif backend == AttentionBackendName._NATIVE_NPU: if not _CAN_USE_NPU_ATTN: raise RuntimeError( f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`." ) elif backend == AttentionBackendName._NATIVE_XLA: if not _CAN_USE_XLA_ATTN: raise RuntimeError( f"XLA Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`." ) elif backend == AttentionBackendName.XFORMERS: if not _CAN_USE_XFORMERS_ATTN: raise RuntimeError( f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`." 
) @functools.lru_cache(maxsize=128) def _prepare_for_flash_attn_or_sage_varlen_without_mask( batch_size: int, seq_len_q: int, seq_len_kv: int, device: torch.device | None = None, ): seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device) cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) max_seqlen_q = seqlens_q.max().item() max_seqlen_k = seqlens_k.max().item() return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) def _prepare_for_flash_attn_or_sage_varlen_with_mask( batch_size: int, seq_len_q: int, attn_mask: torch.Tensor, device: torch.device | None = None, ): seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32) cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) max_seqlen_q = seqlens_q.max().item() max_seqlen_k = seqlens_k.max().item() return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) def _prepare_for_flash_attn_or_sage_varlen( batch_size: int, seq_len_q: int, seq_len_kv: int, attn_mask: torch.Tensor | None = None, device: torch.device | None = None, ) -> None: if attn_mask is None: return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device) return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device) def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor: """ Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in FlashAttention/Sage varlen. Supports 1D to 4D shapes and common broadcasting patterns. """ if attn_mask.dtype != torch.bool: raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.") if attn_mask.ndim == 1: # [seq_len_k] -> broadcast across batch attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k) elif attn_mask.ndim == 2: # [batch_size, seq_len_k]. Maybe broadcast across batch if attn_mask.size(0) not in [1, batch_size]: raise ValueError( f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask." ) attn_mask = attn_mask.expand(batch_size, seq_len_k) elif attn_mask.ndim == 3: # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension # We do this reduction because we know that arbitrary QK masks is not supported in Flash/Sage varlen. if attn_mask.size(0) not in [1, batch_size]: raise ValueError( f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask." ) attn_mask = attn_mask.any(dim=1) attn_mask = attn_mask.expand(batch_size, seq_len_k) elif attn_mask.ndim == 4: # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions if attn_mask.size(0) not in [1, batch_size]: raise ValueError( f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask." 
) attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K] attn_mask = attn_mask.any(dim=(1, 2)) # [B, K] else: raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}") if attn_mask.shape != (batch_size, seq_len_k): raise ValueError( f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})" ) return attn_mask def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): return q_idx >= kv_idx # ===== Helpers for downloading kernels ===== def _resolve_kernel_attr(module, attr_path: str): target = module for attr in attr_path.split("."): if not hasattr(target, attr): raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.") target = getattr(target, attr) return target def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None: if backend not in _HUB_KERNELS_REGISTRY: return config = _HUB_KERNELS_REGISTRY[backend] needs_kernel = config.kernel_fn is None needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward): return try: from kernels import get_kernel kernel_module = get_kernel(config.repo_id, revision=config.revision, version=config.version) if needs_kernel: config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr) if needs_wrapped_forward: config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr) if needs_wrapped_backward: config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr) except Exception as e: logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}") raise # ===== torch op registrations ===== # Registrations are required for fullgraph tracing compatibility # TODO: this is only required because the beta release FA3 does not have it. 
There is a PR adding # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590 @_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") def _wrapped_flash_attn_3( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, softmax_scale: float | None = None, causal: bool = False, qv: torch.Tensor | None = None, q_descale: torch.Tensor | None = None, k_descale: torch.Tensor | None = None, v_descale: torch.Tensor | None = None, attention_chunk: int = 0, softcap: float = 0.0, num_splits: int = 1, pack_gqa: bool | None = None, deterministic: bool = False, sm_margin: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: # Hardcoded for now because pytorch does not support tuple/int type hints window_size = (-1, -1) result = flash_attn_3_func( q=q, k=k, v=v, softmax_scale=softmax_scale, causal=causal, qv=qv, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, deterministic=deterministic, sm_margin=sm_margin, return_attn_probs=True, ) out, lse, *_ = result lse = lse.permute(0, 2, 1) return out, lse @_register_fake("_diffusers_flash_attn_3::_flash_attn_forward") def _( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, softmax_scale: float | None = None, causal: bool = False, qv: torch.Tensor | None = None, q_descale: torch.Tensor | None = None, k_descale: torch.Tensor | None = None, v_descale: torch.Tensor | None = None, attention_chunk: int = 0, softcap: float = 0.0, num_splits: int = 1, pack_gqa: bool | None = None, deterministic: bool = False, sm_margin: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: window_size = (-1, -1) # noqa: F841 # A lot of the parameters here are not yet used in any way within diffusers. # We can safely ignore for now and keep the fake op shape propagation simple. 
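    # Illustrative shape propagation (example values, not from the source): for a query of shape
    # (batch=2, seq_len=4096, heads=24, head_dim=128) the fake op reports out=(2, 4096, 24, 128)
    # and lse=(2, 4096, 24), i.e. (batch, seq_len, heads), matching the lse.permute(0, 2, 1)
    # applied to the real kernel's (batch, heads, seq_len) logsumexp above.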
batch_size, seq_len, num_heads, head_dim = q.shape lse_shape = (batch_size, seq_len, num_heads) return torch.empty_like(q), q.new_empty(lse_shape) # ===== Helper functions to use attention backends with templated CP autograd functions ===== def _native_attention_forward_op( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, return_lse: bool = False, _save_ctx: bool = True, _parallel_config: "ParallelConfig" | None = None, ): # Native attention does not return_lse if return_lse: raise ValueError("Native attention does not support return_lse=True") # used for backward pass if _save_ctx: ctx.save_for_backward(query, key, value) ctx.attn_mask = attn_mask ctx.dropout_p = dropout_p ctx.is_causal = is_causal ctx.scale = scale ctx.enable_gqa = enable_gqa query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) out = torch.nn.functional.scaled_dot_product_attention( query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa, ) out = out.permute(0, 2, 1, 3) return out def _native_attention_backward_op( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args, **kwargs, ): query, key, value = ctx.saved_tensors query.requires_grad_(True) key.requires_grad_(True) value.requires_grad_(True) query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value)) out = torch.nn.functional.scaled_dot_product_attention( query=query_t, key=key_t, value=value_t, attn_mask=ctx.attn_mask, dropout_p=ctx.dropout_p, is_causal=ctx.is_causal, scale=ctx.scale, enable_gqa=ctx.enable_gqa, ) out = out.permute(0, 2, 1, 3) grad_out_t = grad_out.permute(0, 2, 1, 3) grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad( outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False ) grad_query = grad_query_t.permute(0, 2, 1, 3) grad_key = grad_key_t.permute(0, 2, 1, 3) grad_value = grad_value_t.permute(0, 2, 1, 3) return grad_query, grad_key, grad_value # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958 # forward declaration: # aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) def _cudnn_attention_forward_op( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, return_lse: bool = False, _save_ctx: bool = True, _parallel_config: "ParallelConfig" | None = None, ): if enable_gqa: raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.") tensors_to_save = () # Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results # if the input tensors are not contiguous. 
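    # Layout note (illustrative example, not from the source): the dispatcher hands tensors over as
    # (batch, seq_len, heads, head_dim), while the aten cuDNN kernel expects (batch, heads, seq_len, head_dim),
    # so e.g. (2, 4096, 24, 128) becomes (2, 24, 4096, 128) after the transpose(1, 2) below.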
query = query.transpose(1, 2).contiguous() key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() tensors_to_save += (query, key, value) out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( torch.ops.aten._scaled_dot_product_cudnn_attention( query=query, key=key, value=value, attn_bias=attn_mask, compute_log_sumexp=return_lse, dropout_p=dropout_p, is_causal=is_causal, return_debug_mask=False, scale=scale, ) ) tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset) if _save_ctx: ctx.save_for_backward(*tensors_to_save) ctx.dropout_p = dropout_p ctx.is_causal = is_causal ctx.scale = scale ctx.attn_mask = attn_mask ctx.max_q = max_q ctx.max_k = max_k out = out.transpose(1, 2).contiguous() if lse is not None: lse = lse.transpose(1, 2).contiguous() return (out, lse) if return_lse else out # backward declaration: # aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) def _cudnn_attention_backward_op( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args, **kwargs, ): query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors grad_out = grad_out.transpose(1, 2).contiguous() key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() # Cannot pass first 5 arguments as kwargs because: https://github.com/pytorch/pytorch/blob/d26ca5de058dbcf56ac52bb43e84dd98df2ace97/torch/_dynamo/variables/torch.py#L1341 grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward( grad_out, query, key, value, out, logsumexp=lse, philox_seed=philox_seed, philox_offset=philox_offset, attn_bias=ctx.attn_mask, cum_seq_q=cum_seq_q, cum_seq_k=cum_seq_k, max_q=ctx.max_q, max_k=ctx.max_k, dropout_p=ctx.dropout_p, is_causal=ctx.is_causal, scale=ctx.scale, ) grad_query, grad_key, grad_value = (x.transpose(1, 2).contiguous() for x in (grad_query, grad_key, grad_value)) return grad_query, grad_key, grad_value # https://github.com/pytorch/pytorch/blob/e33fa0ece36a93dbc8ff19b0251b8d99f8ae8668/aten/src/ATen/native/native_functions.yaml#L15135 # forward declaration: # aten::_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? 
scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) def _native_flash_attention_forward_op( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, return_lse: bool = False, _save_ctx: bool = True, _parallel_config: "ParallelConfig" | None = None, ): if enable_gqa: raise ValueError("`enable_gqa` is not yet supported for native flash attention.") tensors_to_save = () query = query.transpose(1, 2).contiguous() key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() tensors_to_save += (query, key, value) out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( torch.ops.aten._scaled_dot_product_flash_attention( query=query, key=key, value=value, dropout_p=dropout_p, is_causal=is_causal, return_debug_mask=False, scale=scale, ) ) tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset) if _save_ctx: ctx.save_for_backward(*tensors_to_save) ctx.dropout_p = dropout_p ctx.is_causal = is_causal ctx.scale = scale ctx.max_q = max_q ctx.max_k = max_k out = out.transpose(1, 2).contiguous() if lse is not None: lse = lse.transpose(1, 2).contiguous() return (out, lse) if return_lse else out # https://github.com/pytorch/pytorch/blob/e33fa0ece36a93dbc8ff19b0251b8d99f8ae8668/aten/src/ATen/native/native_functions.yaml#L15153 # backward declaration: # aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? 
scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value) def _native_flash_attention_backward_op( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args, **kwargs, ): query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors grad_out = grad_out.transpose(1, 2).contiguous() key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_flash_attention_backward( grad_out, query, key, value, out, logsumexp=lse, philox_seed=philox_seed, philox_offset=philox_offset, cum_seq_q=cum_seq_q, cum_seq_k=cum_seq_k, max_q=ctx.max_q, max_k=ctx.max_k, dropout_p=ctx.dropout_p, is_causal=ctx.is_causal, scale=ctx.scale, ) grad_query, grad_key, grad_value = (x.transpose(1, 2).contiguous() for x in (grad_query, grad_key, grad_value)) return grad_query, grad_key, grad_value # Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807 def _flash_attention_forward_op( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, return_lse: bool = False, _save_ctx: bool = True, _parallel_config: "ParallelConfig" | None = None, ): if attn_mask is not None: raise ValueError("`attn_mask` is not yet supported for flash-attn 2.") if enable_gqa: raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.") # Hardcoded for now window_size = (-1, -1) softcap = 0.0 alibi_slopes = None deterministic = False grad_enabled = any(x.requires_grad for x in (query, key, value)) if scale is None: scale = query.shape[-1] ** (-0.5) # flash-attn only returns LSE if dropout_p > 0. So, we need to workaround. 
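    # The workaround: when the LSE is actually needed (training, or a context-parallel world size greater
    # than 1), nudge dropout_p from 0.0 to a negligible 1e-30 so the kernel takes the code path that
    # materializes the logsumexp while leaving the attention output effectively unchanged.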
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1): dropout_p = dropout_p if dropout_p > 0 else 1e-30 with torch.set_grad_enabled(grad_enabled): out, lse, S_dmask, rng_state = _wrapped_flash_attn_forward( query, key, value, dropout_p, scale, is_causal, window_size[0], window_size[1], softcap, alibi_slopes, return_lse, ) lse = lse.permute(0, 2, 1) if _save_ctx: ctx.save_for_backward(query, key, value, out, lse, rng_state) ctx.dropout_p = dropout_p ctx.scale = scale ctx.is_causal = is_causal ctx.window_size = window_size ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic return (out, lse) if return_lse else out def _flash_attention_backward_op( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args, **kwargs, ): query, key, value, out, lse, rng_state = ctx.saved_tensors grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) lse_d = _wrapped_flash_attn_backward( # noqa: F841 grad_out, query, key, value, out, lse, grad_query, grad_key, grad_value, ctx.dropout_p, ctx.scale, ctx.is_causal, ctx.window_size[0], ctx.window_size[1], ctx.softcap, ctx.alibi_slopes, ctx.deterministic, rng_state, ) # Head dimension may have been padded grad_query = grad_query[..., : grad_out.shape[-1]] grad_key = grad_key[..., : grad_out.shape[-1]] grad_value = grad_value[..., : grad_out.shape[-1]] return grad_query, grad_key, grad_value def _flash_attention_hub_forward_op( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, return_lse: bool = False, _save_ctx: bool = True, _parallel_config: "ParallelConfig" | None = None, ): if attn_mask is not None: raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.") if enable_gqa: raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.") config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB] wrapped_forward_fn = config.wrapped_forward_fn wrapped_backward_fn = config.wrapped_backward_fn if wrapped_forward_fn is None or wrapped_backward_fn is None: raise RuntimeError( "Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` " "for context parallel execution." 
) if scale is None: scale = query.shape[-1] ** (-0.5) window_size = (-1, -1) softcap = 0.0 alibi_slopes = None deterministic = False grad_enabled = any(x.requires_grad for x in (query, key, value)) if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1): dropout_p = dropout_p if dropout_p > 0 else 1e-30 with torch.set_grad_enabled(grad_enabled): out, lse, S_dmask, rng_state = wrapped_forward_fn( query, key, value, dropout_p, scale, is_causal, window_size[0], window_size[1], softcap, alibi_slopes, return_lse, ) lse = lse.permute(0, 2, 1).contiguous() if _save_ctx: ctx.save_for_backward(query, key, value, out, lse, rng_state) ctx.dropout_p = dropout_p ctx.scale = scale ctx.is_causal = is_causal ctx.window_size = window_size ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic return (out, lse) if return_lse else out def _flash_attention_hub_backward_op( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args, **kwargs, ): config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB] wrapped_backward_fn = config.wrapped_backward_fn if wrapped_backward_fn is None: raise RuntimeError( "Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution." ) query, key, value, out, lse, rng_state = ctx.saved_tensors grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) _ = wrapped_backward_fn( grad_out, query, key, value, out, lse, grad_query, grad_key, grad_value, ctx.dropout_p, ctx.scale, ctx.is_causal, ctx.window_size[0], ctx.window_size[1], ctx.softcap, ctx.alibi_slopes, ctx.deterministic, rng_state, ) grad_query = grad_query[..., : grad_out.shape[-1]] grad_key = grad_key[..., : grad_out.shape[-1]] grad_value = grad_value[..., : grad_out.shape[-1]] return grad_query, grad_key, grad_value def _flash_attention_3_hub_forward_op( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, return_lse: bool = False, _save_ctx: bool = True, _parallel_config: "ParallelConfig" | None = None, *, window_size: tuple[int, int] = (-1, -1), softcap: float = 0.0, num_splits: int = 1, pack_gqa: bool | None = None, deterministic: bool = False, sm_margin: int = 0, ): if attn_mask is not None: raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.") if dropout_p != 0.0: raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.") if enable_gqa: raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.") config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB] wrapped_forward_fn = config.wrapped_forward_fn if wrapped_forward_fn is None: raise RuntimeError( "Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_forward` " "for context parallel execution." 
) if scale is None: scale = query.shape[-1] ** (-0.5) out, softmax_lse, *_ = wrapped_forward_fn( query, key, value, None, None, # k_new, v_new None, # qv None, # out None, None, None, # cu_seqlens_q/k/k_new None, None, # seqused_q/k None, None, # max_seqlen_q/k None, None, None, # page_table, kv_batch_idx, leftpad_k None, None, None, # rotary_cos/sin, seqlens_rotary None, None, None, # q_descale, k_descale, v_descale scale, causal=is_causal, window_size_left=window_size[0], window_size_right=window_size[1], attention_chunk=0, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, sm_margin=sm_margin, ) lse = softmax_lse.permute(0, 2, 1).contiguous() if return_lse else None if _save_ctx: ctx.save_for_backward(query, key, value, out, softmax_lse) ctx.scale = scale ctx.is_causal = is_causal ctx.window_size = window_size ctx.softcap = softcap ctx.deterministic = deterministic ctx.sm_margin = sm_margin return (out, lse) if return_lse else out def _flash_attention_3_hub_backward_op( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args, **kwargs, ): config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB] wrapped_backward_fn = config.wrapped_backward_fn if wrapped_backward_fn is None: raise RuntimeError( "Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_backward` " "for context parallel execution." ) query, key, value, out, softmax_lse = ctx.saved_tensors grad_query = torch.empty_like(query) grad_key = torch.empty_like(key) grad_value = torch.empty_like(value) wrapped_backward_fn( grad_out, query, key, value, out, softmax_lse, None, None, # cu_seqlens_q, cu_seqlens_k None, None, # seqused_q, seqused_k None, None, # max_seqlen_q, max_seqlen_k grad_query, grad_key, grad_value, ctx.scale, ctx.is_causal, ctx.window_size[0], ctx.window_size[1], ctx.softcap, ctx.deterministic, ctx.sm_margin, ) grad_query = grad_query[..., : grad_out.shape[-1]] grad_key = grad_key[..., : grad_out.shape[-1]] grad_value = grad_value[..., : grad_out.shape[-1]] return grad_query, grad_key, grad_value def _sage_attention_forward_op( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, return_lse: bool = False, _save_ctx: bool = True, _parallel_config: "ParallelConfig" | None = None, ): if attn_mask is not None: raise ValueError("`attn_mask` is not yet supported for Sage attention.") if dropout_p > 0.0: raise ValueError("`dropout_p` is not yet supported for Sage attention.") if enable_gqa: raise ValueError("`enable_gqa` is not yet supported for Sage attention.") out = sageattn( q=query, k=key, v=value, tensor_layout="NHD", is_causal=is_causal, sm_scale=scale, return_lse=return_lse, ) lse = None if return_lse: out, lse, *_ = out lse = lse.permute(0, 2, 1) return (out, lse) if return_lse else out def _sage_attention_hub_forward_op( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, return_lse: bool = False, _save_ctx: bool = True, _parallel_config: "ParallelConfig" | None = None, ): if attn_mask is not None: raise ValueError("`attn_mask` is not yet supported for Sage attention.") if dropout_p > 0.0: raise ValueError("`dropout_p` is not yet supported for Sage attention.") if enable_gqa: raise 
ValueError("`enable_gqa` is not yet supported for Sage attention.") func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn out = func( q=query, k=key, v=value, tensor_layout="NHD", is_causal=is_causal, sm_scale=scale, return_lse=return_lse, ) lse = None if return_lse: out, lse, *_ = out lse = lse.permute(0, 2, 1).contiguous() return (out, lse) if return_lse else out def _sage_attention_backward_op( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args, ): raise NotImplementedError("Backward pass is not implemented for Sage attention.") def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mask: torch.Tensor | None = None): # Skip Attention Mask if all values are 1, `None` mask can speedup the computation if attn_mask is not None and torch.all(attn_mask != 0): attn_mask = None # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k] # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md if ( attn_mask is not None and attn_mask.ndim == 2 and attn_mask.shape[0] == query.shape[0] and attn_mask.shape[1] == key.shape[1] ): B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1] attn_mask = ~attn_mask.to(torch.bool) attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous() return attn_mask def _npu_attention_forward_op( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, return_lse: bool = False, _save_ctx: bool = True, _parallel_config: "ParallelConfig" | None = None, ): if return_lse: raise ValueError("NPU attention backend does not support setting `return_lse=True`.") attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask) out = npu_fusion_attention( query, key, value, query.size(2), # num_heads atten_mask=attn_mask, input_layout="BSND", pse=None, scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, pre_tockens=65536, next_tockens=65536, keep_prob=1.0 - dropout_p, sync=False, inner_precise=0, )[0] return out # Not implemented yet. 
def _npu_attention_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
    **kwargs,
):
    raise NotImplementedError("Backward pass is not implemented for NPU fusion attention.")


# ===== Context parallel =====

# Reference:
# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827
# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246
# For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method):
def _wait_tensor(tensor):
    if isinstance(tensor, funcol.AsyncCollectiveTensor):
        tensor = tensor.wait()
    return tensor


def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
    shape = x.shape
    # HACK: We need to flatten because despite making tensors contiguous, torch single-file-ization
    # to benchmark triton codegen fails somewhere:
    #     buf25 = torch.ops._c10d_functional.all_to_all_single.default(buf24, [1, 1], [1, 1], '3')
    #     ValueError: Tensors must be contiguous
    x = x.flatten()
    x = funcol.all_to_all_single(x, None, None, group)
    x = x.reshape(shape)
    x = _wait_tensor(x)
    return x


def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
    """
    Perform dimension sharding / reassembly across processes using `_all_to_all_single`.

    This utility reshapes and redistributes tensor `x` across the given process group, along either the sequence
    dimension or the head dimension, depending on `scatter_idx` and `gather_idx`.

    Args:
        x (torch.Tensor): Input tensor. Expected shapes:
            - When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim)
            - When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim)
        scatter_idx (int): Dimension along which the tensor is partitioned before all-to-all.
        gather_idx (int): Dimension along which the output is reassembled after all-to-all.
        group: Distributed process group for the Ulysses group.

    Returns:
        torch.Tensor: Tensor with globally exchanged dimensions.
            - For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim)
            - For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim)
    """
    group_world_size = torch.distributed.get_world_size(group)

    if scatter_idx == 2 and gather_idx == 1:
        # Used before Ulysses sequence parallel (SP) attention. Gathers the sequence dimension and
        # scatters the head dimension.
        batch_size, seq_len_local, num_heads, head_dim = x.shape
        seq_len = seq_len_local * group_world_size
        num_heads_local = num_heads // group_world_size

        # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
        x_temp = (
            x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
            .transpose(0, 2)
            .contiguous()
        )

        if group_world_size > 1:
            out = _all_to_all_single(x_temp, group=group)
        else:
            out = x_temp

        # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
        out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
        out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
        return out
    elif scatter_idx == 1 and gather_idx == 2:
        # Used after Ulysses sequence parallel (SP) attention in unified SP. Gathers the head dimension and
        # scatters back the sequence dimension.
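        # Worked example (illustrative sizes): with group_world_size=4, a local tensor of shape
        # (B=2, S=4096, H_LOCAL=6, D=128) comes out as (B=2, S_LOCAL=1024, H=24, D=128), i.e. the full
        # head dimension is gathered back while the sequence is re-sharded across the group.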
        batch_size, seq_len, num_heads_local, head_dim = x.shape
        num_heads = num_heads_local * group_world_size
        seq_len_local = seq_len // group_world_size

        # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
        x_temp = (
            x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
            .permute(1, 3, 2, 0, 4)
            .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
        )

        if group_world_size > 1:
            output = _all_to_all_single(x_temp, group)
        else:
            output = x_temp

        output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
        output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
        return output
    else:
        raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")


class SeqAllToAllDim(torch.autograd.Function):
    """
    All-to-all operation for unified sequence parallelism. Uses `_all_to_all_dim_exchange`; see that function for
    more info.
    """

    @staticmethod
    def forward(ctx, group, input, scatter_id=2, gather_id=1):
        ctx.group = group
        ctx.scatter_id = scatter_id
        ctx.gather_id = gather_id
        return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)

    @staticmethod
    def backward(ctx, grad_outputs):
        grad_input = SeqAllToAllDim.apply(
            ctx.group,
            grad_outputs,
            ctx.gather_id,  # reversed
            ctx.scatter_id,  # reversed
        )
        return (None, grad_input, None, None)


# Below are helper functions to handle arbitrary head num and arbitrary sequence length for Ulysses Anything
# Attention.
def _maybe_pad_qkv_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> tuple[torch.Tensor, int]:
    r"""Maybe pad the head dimension to be divisible by world_size.

    x: torch.Tensor, shape (B, S_LOCAL, H, D)
    H: int, original global head num
    return: tuple[torch.Tensor, int], padded tensor (B, S_LOCAL, H + H_PAD, D) and H_PAD
    """
    world_size = dist.get_world_size(group=group)
    H_PAD = 0
    if H % world_size != 0:
        H_PAD = world_size - (H % world_size)
        NEW_H_LOCAL = (H + H_PAD) // world_size
        # e.g., allowed: H=30, world_size=8 -> NEW_H_LOCAL=4, H_PAD=2.
        # NOT allowed: H=30, world_size=16 -> NEW_H_LOCAL=2, H_PAD=2 (assert fails since H_PAD == NEW_H_LOCAL).
        assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
        x = F.pad(x, (0, 0, 0, H_PAD)).contiguous()
    return x, H_PAD


def _maybe_unpad_qkv_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor:
    r"""Maybe unpad the head dimension.

    x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL + H_PAD, D)
    H_PAD: int, head padding num
    return: torch.Tensor, unpadded tensor (B, S_GLOBAL, H_LOCAL, D)
    """
    rank = dist.get_rank(group=group)
    world_size = dist.get_world_size(group=group)
    # Only the last rank may have padding
    if H_PAD > 0 and rank == world_size - 1:
        x = x[:, :, :-H_PAD, :]
    return x.contiguous()


def _maybe_pad_o_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> tuple[torch.Tensor, int]:
    r"""Maybe pad the head dimension to be divisible by world_size.

    x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D)
    H: int, original global head num
    return: tuple[torch.Tensor, int], padded tensor (B, S_GLOBAL, H_LOCAL + H_PAD, D) and H_PAD
    """
    if H is None:
        return x, 0
    rank = dist.get_rank(group=group)
    world_size = dist.get_world_size(group=group)
    H_PAD = 0
    # Only the last rank may need padding
    if H % world_size != 0:
        # H_PAD must be computed consistently on all ranks so that the later unpadding step stays in sync.
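        # e.g. (hypothetical sizes) H=30, world_size=8 -> H_PAD=2 is computed on every rank, but only the
        # last rank pads its local tensor below; after the all-to-all every rank carries the trailing padded
        # heads, which _maybe_unpad_o_head then strips again.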
H_PAD = world_size - (H % world_size) NEW_H_LOCAL = (H + H_PAD) // world_size assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}" if rank == world_size - 1: x = F.pad(x, (0, 0, 0, H_PAD)).contiguous() return x, H_PAD def _maybe_unpad_o_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor: r"""Maybe unpad the head dimension. x: torch.Tensor, shape (B, S_LOCAL, H_GLOBAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor, unpadded tensor (B, S_LOCAL, H_GLOBAL, D) """ if H_PAD > 0: x = x[:, :, :-H_PAD, :] return x.contiguous() def ulysses_anything_metadata(query: torch.Tensor, **kwargs) -> dict: # query: (B, S_LOCAL, H_GLOBAL, D) assert len(query.shape) == 4, "Query tensor must be 4-dimensional of shape (B, S_LOCAL, H_GLOBAL, D)" extra_kwargs = {} extra_kwargs["NUM_QO_HEAD"] = query.shape[2] extra_kwargs["Q_S_LOCAL"] = query.shape[1] # Add other kwargs if needed in future return extra_kwargs @maybe_allow_in_graph def all_to_all_single_any_qkv_async( x: torch.Tensor, group: dist.ProcessGroup, **kwargs ) -> Callable[..., torch.Tensor]: r""" x: torch.Tensor, shape (B, S_LOCAL, H, D) return: Callable that returns (B, S_GLOBAL, H_LOCAL, D) """ world_size = dist.get_world_size(group=group) B, S_LOCAL, H, D = x.shape x, H_PAD = _maybe_pad_qkv_head(x, H, group) H_LOCAL = (H + H_PAD) // world_size # (world_size, S_LOCAL, B, H_LOCAL, D) x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() input_split_sizes = [S_LOCAL] * world_size # S_LOCAL maybe not equal for all ranks in dynamic shape case, # since we don't know the actual shape before this timing, thus, # we have to use all gather to collect the S_LOCAL first. output_split_sizes = gather_size_by_comm(S_LOCAL, group) x = x.flatten(0, 1) # (world_size * S_LOCAL, B, H_LOCAL, D) x = funcol.all_to_all_single(x, output_split_sizes, input_split_sizes, group) def wait() -> torch.Tensor: nonlocal x, H_PAD x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D) # (S_GLOBAL, B, H_LOCAL, D) # -> (B, S_GLOBAL, H_LOCAL, D) x = x.permute(1, 0, 2, 3).contiguous() x = _maybe_unpad_qkv_head(x, H_PAD, group) return x return wait @maybe_allow_in_graph def all_to_all_single_any_o_async(x: torch.Tensor, group: dist.ProcessGroup, **kwargs) -> Callable[..., torch.Tensor]: r""" x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) return: Callable that returns (B, S_LOCAL, H_GLOBAL, D) """ # Assume H is provided in kwargs, since we can't infer H from x's shape. # The padding logic needs H to determine if padding is necessary. H = kwargs.get("NUM_QO_HEAD", None) world_size = dist.get_world_size(group=group) x, H_PAD = _maybe_pad_o_head(x, H, group) shape = x.shape # (B, S_GLOBAL, H_LOCAL, D) (B, S_GLOBAL, H_LOCAL, D) = shape # input_split: e.g, S_GLOBAL=9 input splits across ranks [[5,4], [5,4],..] # output_split: e.g, S_GLOBAL=9 output splits across ranks [[5,5], [4,4],..] 
# WARN: In some cases, e.g, joint attn in Qwen-Image, the S_LOCAL can not infer # from tensor split due to: if c = torch.cat((a, b)), world_size=4, then, # c.tensor_split(4)[0].shape[1] may != to (a.tensor_split(4)[0].shape[1] + # b.tensor_split(4)[0].shape[1]) S_LOCAL = kwargs.get("Q_S_LOCAL") input_split_sizes = gather_size_by_comm(S_LOCAL, group) x = x.permute(1, 0, 2, 3).contiguous() # (S_GLOBAL, B, H_LOCAL, D) output_split_sizes = [S_LOCAL] * world_size x = funcol.all_to_all_single(x, output_split_sizes, input_split_sizes, group) def wait() -> torch.Tensor: nonlocal x, H_PAD x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D) x = x.reshape(world_size, S_LOCAL, B, H_LOCAL, D) x = x.permute(2, 1, 0, 3, 4).contiguous() x = x.reshape(B, S_LOCAL, world_size * H_LOCAL, D) x = _maybe_unpad_o_head(x, H_PAD, group) return x return wait class TemplatedRingAttention(torch.autograd.Function): @staticmethod def forward( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None, dropout_p: float, is_causal: bool, scale: float | None, enable_gqa: bool, return_lse: bool, forward_op, backward_op, _parallel_config: "ParallelConfig" | None = None, ): ring_mesh = _parallel_config.context_parallel_config._ring_mesh rank = _parallel_config.context_parallel_config._ring_local_rank world_size = _parallel_config.context_parallel_config.ring_degree next_rank = (rank + 1) % world_size prev_out = prev_lse = None ctx.forward_op = forward_op ctx.backward_op = backward_op ctx.q_shape = query.shape ctx.kv_shape = key.shape ctx._parallel_config = _parallel_config kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous() kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group()) kv_buffer = kv_buffer.chunk(world_size) for i in range(world_size): if i > 0: kv = kv_buffer[next_rank] key_numel = key.numel() key = kv[:key_numel].reshape_as(key) value = kv[key_numel:].reshape_as(value) next_rank = (next_rank + 1) % world_size out, lse = forward_op( ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, True, _save_ctx=i == 0, _parallel_config=_parallel_config, ) if _parallel_config.context_parallel_config.convert_to_fp32: out = out.to(torch.float32) lse = lse.to(torch.float32) # Refer to: # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 if is_torch_version("<", "2.9.0"): lse = lse.unsqueeze(-1) if prev_out is not None: out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out) lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse) prev_out = out prev_lse = lse out = out.to(query.dtype) lse = lse.squeeze(-1) return (out, lse) if return_lse else out @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args, ): ring_mesh = ctx._parallel_config.context_parallel_config._ring_mesh rank = ctx._parallel_config.context_parallel_config._ring_local_rank world_size = ctx._parallel_config.context_parallel_config.ring_degree next_rank = (rank + 1) % world_size next_ranks = list(range(1, world_size)) + [0] accum_dtype = torch.float32 if ctx._parallel_config.context_parallel_config.convert_to_fp32 else grad_out.dtype grad_query = torch.zeros(ctx.q_shape, dtype=accum_dtype, device=grad_out.device) grad_key = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device) grad_value = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device) next_grad_kv = None query, key, value, *_ = 
ctx.saved_tensors kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous() kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group()) kv_buffer = kv_buffer.chunk(world_size) for i in range(world_size): if i > 0: kv = kv_buffer[next_rank] key_numel = key.numel() key = kv[:key_numel].reshape_as(key) value = kv[key_numel:].reshape_as(value) next_rank = (next_rank + 1) % world_size grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out) if i > 0: grad_kv_buffer = _wait_tensor(next_grad_kv) grad_key_numel = grad_key.numel() grad_key = grad_kv_buffer[:grad_key_numel].reshape_as(grad_key) grad_value = grad_kv_buffer[grad_key_numel:].reshape_as(grad_value) grad_query += grad_query_op grad_key += grad_key_op grad_value += grad_value_op if i < world_size - 1: grad_kv_buffer = torch.cat([grad_key.flatten(), grad_value.flatten()]).contiguous() next_grad_kv = funcol.permute_tensor(grad_kv_buffer, next_ranks, group=ring_mesh.get_group()) grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value)) return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None class TemplatedUlyssesAttention(torch.autograd.Function): @staticmethod def forward( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None, dropout_p: float, is_causal: bool, scale: float | None, enable_gqa: bool, return_lse: bool, forward_op, backward_op, _parallel_config: "ParallelConfig" | None = None, ): ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh world_size = _parallel_config.context_parallel_config.ulysses_degree group = ulysses_mesh.get_group() ctx.forward_op = forward_op ctx.backward_op = backward_op ctx._parallel_config = _parallel_config B, S_Q_LOCAL, H, D = query.shape _, S_KV_LOCAL, _, _ = key.shape H_LOCAL = H // world_size query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() query, key, value = (_all_to_all_single(x, group) for x in (query, key, value)) query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value)) out = forward_op( ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, _save_ctx=True, _parallel_config=_parallel_config, ) if return_lse: out, lse, *_ = out out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() out = _all_to_all_single(out, group) out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() if return_lse: lse = lse.reshape(B, world_size, S_Q_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous() lse = _all_to_all_single(lse, group) lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous() else: lse = None return (out, lse) if return_lse else out @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args, ): ulysses_mesh = ctx._parallel_config.context_parallel_config._ulysses_mesh world_size = ctx._parallel_config.context_parallel_config.ulysses_degree group = ulysses_mesh.get_group() B, S_LOCAL, H, D = grad_out.shape H_LOCAL = H // world_size grad_out = grad_out.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() grad_out = _all_to_all_single(grad_out, group) grad_out = 
grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous() grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out) grad_query, grad_key, grad_value = ( x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() for x in (grad_query_op, grad_key_op, grad_value_op) ) grad_query, grad_key, grad_value = (_all_to_all_single(x, group) for x in (grad_query, grad_key, grad_value)) grad_query, grad_key, grad_value = ( x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value) ) return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None class TemplatedUlyssesAnythingAttention(torch.autograd.Function): @staticmethod def forward( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor, dropout_p: float, is_causal: bool, scale: float, enable_gqa: bool, return_lse: bool, forward_op, backward_op, _parallel_config: "ParallelConfig" | None = None, **kwargs, ): ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh group = ulysses_mesh.get_group() ctx.forward_op = forward_op ctx.backward_op = backward_op ctx._parallel_config = _parallel_config metadata = ulysses_anything_metadata(query) query_wait = all_to_all_single_any_qkv_async(query, group, **metadata) key_wait = all_to_all_single_any_qkv_async(key, group, **metadata) value_wait = all_to_all_single_any_qkv_async(value, group, **metadata) query = query_wait() # type: torch.Tensor key = key_wait() # type: torch.Tensor value = value_wait() # type: torch.Tensor out = forward_op( ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, _save_ctx=False, # ulysses anything only support forward pass now. _parallel_config=_parallel_config, ) if return_lse: out, lse, *_ = out # out: (B, S_Q_GLOBAL, H_LOCAL, D) -> (B, S_Q_LOCAL, H_GLOBAL, D) out_wait = all_to_all_single_any_o_async(out, group, **metadata) if return_lse: # lse: (B, S_Q_GLOBAL, H_LOCAL) lse = lse.unsqueeze(-1) # (B, S_Q_GLOBAL, H_LOCAL, D=1) lse_wait = all_to_all_single_any_o_async(lse, group, **metadata) out = out_wait() # type: torch.Tensor lse = lse_wait() # type: torch.Tensor lse = lse.squeeze(-1).contiguous() # (B, S_Q_LOCAL, H_GLOBAL) else: out = out_wait() # type: torch.Tensor lse = None return (out, lse) if return_lse else out @staticmethod def backward( ctx: torch.autograd.function.FunctionCtx, grad_out: torch.Tensor, *args, ): raise NotImplementedError("Backward pass for Ulysses Anything Attention in diffusers is not implemented yet.") def _templated_unified_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor, dropout_p: float, is_causal: bool, scale: float, enable_gqa: bool, return_lse: bool, forward_op, backward_op, _parallel_config: "ParallelConfig" | None = None, scatter_idx: int = 2, gather_idx: int = 1, ): """ Unified Sequence Parallelism attention combining Ulysses and ring attention. 
See: https://arxiv.org/abs/2405.07719 """ ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh ulysses_group = ulysses_mesh.get_group() query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx) key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx) value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx) out = TemplatedRingAttention.apply( query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, forward_op, backward_op, _parallel_config, ) if return_lse: context_layer, lse, *_ = out else: context_layer = out # context_layer is of shape (B, S, H_LOCAL, D) output = SeqAllToAllDim.apply( ulysses_group, context_layer, gather_idx, scatter_idx, ) if return_lse: # lse is of shape (B, S, H_LOCAL, 1) # Refer to: # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 if is_torch_version("<", "2.9.0"): lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1) lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx) lse = lse.squeeze(-1) return (output, lse) return output def _templated_context_parallel_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, return_lse: bool = False, *, forward_op, backward_op, _parallel_config: "ParallelConfig" | None = None, ): if is_causal: raise ValueError("Causal attention is not yet supported for templated attention.") if enable_gqa: raise ValueError("GQA is not yet supported for templated attention.") # TODO: add support for unified attention with ring/ulysses degree both being > 1 if ( _parallel_config.context_parallel_config.ring_degree > 1 and _parallel_config.context_parallel_config.ulysses_degree > 1 ): return _templated_unified_attention( query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, forward_op, backward_op, _parallel_config, ) elif _parallel_config.context_parallel_config.ring_degree > 1: return TemplatedRingAttention.apply( query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, forward_op, backward_op, _parallel_config, ) elif _parallel_config.context_parallel_config.ulysses_degree > 1: if _parallel_config.context_parallel_config.ulysses_anything: # For Any sequence lengths and Any head num support return TemplatedUlyssesAnythingAttention.apply( query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, forward_op, backward_op, _parallel_config, ) else: return TemplatedUlyssesAttention.apply( query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, forward_op, backward_op, _parallel_config, ) else: raise ValueError("Reaching this branch of code is unexpected. 
Please report a bug.") # ===== Attention backends ===== @_AttentionBackendRegistry.register( AttentionBackendName.FLASH, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_context_parallel=True, ) def _flash_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: lse = None if attn_mask is not None: raise ValueError("`attn_mask` is not supported for flash-attn 2.") if _parallel_config is None: out = flash_attn_func( q=query, k=key, v=value, dropout_p=dropout_p, softmax_scale=scale, causal=is_causal, return_attn_probs=return_lse, ) if return_lse: out, lse, *_ = out else: out = _templated_context_parallel_attention( query, key, value, None, dropout_p, is_causal, scale, False, return_lse, forward_op=_flash_attention_forward_op, backward_op=_flash_attention_backward_op, _parallel_config=_parallel_config, ) if return_lse: out, lse = out return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( AttentionBackendName.FLASH_HUB, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_context_parallel=True, ) def _flash_attention_hub( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: lse = None if attn_mask is not None: raise ValueError("`attn_mask` is not supported for flash-attn 2.") func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn if _parallel_config is None: out = func( q=query, k=key, v=value, dropout_p=dropout_p, softmax_scale=scale, causal=is_causal, return_attn_probs=return_lse, ) if return_lse: out, lse, *_ = out else: out = _templated_context_parallel_attention( query, key, value, None, dropout_p, is_causal, scale, False, return_lse, forward_op=_flash_attention_hub_forward_op, backward_op=_flash_attention_hub_backward_op, _parallel_config=_parallel_config, ) if return_lse: out, lse = out return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( AttentionBackendName.FLASH_VARLEN_HUB, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_context_parallel=False, ) def _flash_varlen_attention_hub( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, scale: float | None = None, is_causal: bool = False, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( _prepare_for_flash_attn_or_sage_varlen( batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device ) ) key_valid, value_valid = [], [] for b in range(batch_size): valid_len = seqlens_k[b] key_valid.append(key[b, :valid_len]) value_valid.append(value[b, :valid_len]) query_packed = query.flatten(0, 1) key_packed = torch.cat(key_valid, dim=0) value_packed = torch.cat(value_valid, dim=0) func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn out = func( q=query_packed, 
k=key_packed, v=value_packed, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, dropout_p=dropout_p, softmax_scale=scale, causal=is_causal, return_attn_probs=return_lse, ) out = out.unflatten(0, (batch_size, -1)) return out @_AttentionBackendRegistry.register( AttentionBackendName.FLASH_VARLEN, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], ) def _flash_varlen_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, scale: float | None = None, is_causal: bool = False, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( _prepare_for_flash_attn_or_sage_varlen( batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device ) ) key_valid, value_valid = [], [] for b in range(batch_size): valid_len = seqlens_k[b] key_valid.append(key[b, :valid_len]) value_valid.append(value[b, :valid_len]) query_packed = query.flatten(0, 1) key_packed = torch.cat(key_valid, dim=0) value_packed = torch.cat(value_valid, dim=0) out = flash_attn_varlen_func( q=query_packed, k=key_packed, v=value_packed, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, dropout_p=dropout_p, softmax_scale=scale, causal=is_causal, return_attn_probs=return_lse, ) out = out.unflatten(0, (batch_size, -1)) return out @_AttentionBackendRegistry.register( AttentionBackendName._FLASH_3, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], ) def _flash_attention_3( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, scale: float | None = None, is_causal: bool = False, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if attn_mask is not None: raise ValueError("`attn_mask` is not supported for flash-attn 3.") out, lse = _wrapped_flash_attn_3( q=query, k=key, v=value, softmax_scale=scale, causal=is_causal, ) return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( AttentionBackendName._FLASH_3_HUB, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_context_parallel=True, ) def _flash_attention_3_hub( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, scale: float | None = None, is_causal: bool = False, window_size: tuple[int, int] = (-1, -1), softcap: float = 0.0, deterministic: bool = False, return_attn_probs: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if attn_mask is not None: raise ValueError("`attn_mask` is not supported for flash-attn 3.") func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn if _parallel_config is None: out = func( q=query, k=key, v=value, softmax_scale=scale, causal=is_causal, qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=window_size, softcap=softcap, num_splits=1, pack_gqa=None, deterministic=deterministic, sm_margin=0, return_attn_probs=return_attn_probs, ) return (out[0], out[1]) if return_attn_probs else out forward_op = functools.partial( _flash_attention_3_hub_forward_op, window_size=window_size, 
softcap=softcap, num_splits=1, pack_gqa=None, deterministic=deterministic, sm_margin=0, ) backward_op = functools.partial( _flash_attention_3_hub_backward_op, window_size=window_size, softcap=softcap, num_splits=1, pack_gqa=None, deterministic=deterministic, sm_margin=0, ) out = _templated_context_parallel_attention( query, key, value, None, 0.0, is_causal, scale, False, return_attn_probs, forward_op=forward_op, backward_op=backward_op, _parallel_config=_parallel_config, ) if return_attn_probs: out, lse = out return out, lse return out @_AttentionBackendRegistry.register( AttentionBackendName._FLASH_3_VARLEN_HUB, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_context_parallel=False, ) def _flash_attention_3_varlen_hub( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, scale: float | None = None, is_causal: bool = False, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( _prepare_for_flash_attn_or_sage_varlen( batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device ) ) key_valid, value_valid = [], [] for b in range(batch_size): valid_len = seqlens_k[b] key_valid.append(key[b, :valid_len]) value_valid.append(value[b, :valid_len]) query_packed = query.flatten(0, 1) key_packed = torch.cat(key_valid, dim=0) value_packed = torch.cat(value_valid, dim=0) func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn out, lse, *_ = func( q=query_packed, k=key_packed, v=value_packed, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, softmax_scale=scale, causal=is_causal, ) out = out.unflatten(0, (batch_size, -1)) return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( AttentionBackendName._FLASH_VARLEN_3, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], ) def _flash_varlen_attention_3( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, scale: float | None = None, is_causal: bool = False, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( _prepare_for_flash_attn_or_sage_varlen( batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device ) ) key_valid, value_valid = [], [] for b in range(batch_size): valid_len = seqlens_k[b] key_valid.append(key[b, :valid_len]) value_valid.append(value[b, :valid_len]) query_packed = query.flatten(0, 1) key_packed = torch.cat(key_valid, dim=0) value_packed = torch.cat(value_valid, dim=0) result = flash_attn_3_varlen_func( q=query_packed, k=key_packed, v=value_packed, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, softmax_scale=scale, causal=is_causal, return_attn_probs=return_lse, ) if isinstance(result, tuple): out, lse, *_ = result else: out = result lse = None out = out.unflatten(0, (batch_size, -1)) return (out, lse) if return_lse else out 
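# ---------------------------------------------------------------------------
# Illustrative sketch (not used by the dispatcher): the varlen backends above
# all follow the same packing pattern -- per-batch valid key/value lengths are
# derived from the attention mask, the valid tokens are concatenated into one
# packed sequence, and cumulative sequence lengths (cu_seqlens) tell the kernel
# where each batch entry starts. The hypothetical helper below re-implements
# only that bookkeeping for a (B, S, H, D) tensor; the real preparation is done
# by _prepare_for_flash_attn_or_sage_varlen.
# ---------------------------------------------------------------------------
def _example_varlen_packing(key: torch.Tensor, seqlens_k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Pack valid key tokens and build cu_seqlens, assuming key has shape (B, S, H, D)."""
    # cu_seqlens is the exclusive prefix sum of the per-batch lengths, e.g.
    # seqlens_k=[5, 4] -> cu_seqlens=[0, 5, 9]; batch entry b occupies rows
    # cu_seqlens[b]:cu_seqlens[b + 1] of the packed tensor.
    cu_seqlens = torch.zeros(key.shape[0] + 1, dtype=torch.int32, device=key.device)
    cu_seqlens[1:] = torch.cumsum(seqlens_k, dim=0)
    # Drop the padded tail of every batch entry and stack the valid tokens into
    # one (total_valid_tokens, H, D) tensor, matching the key_packed/value_packed
    # construction used by the varlen attention functions above.
    packed = torch.cat([key[b, : int(seqlens_k[b])] for b in range(key.shape[0])], dim=0)
    return packed, cu_seqlens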
@_AttentionBackendRegistry.register( AttentionBackendName.AITER, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], ) def _aiter_flash_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if attn_mask is not None: raise ValueError("`attn_mask` is not supported for aiter attention") if not return_lse and torch.is_grad_enabled(): # aiter requires return_lse=True by assertion when gradients are enabled. out, lse, *_ = aiter_flash_attn_func( q=query, k=key, v=value, dropout_p=dropout_p, softmax_scale=scale, causal=is_causal, return_lse=True, ) else: out = aiter_flash_attn_func( q=query, k=key, v=value, dropout_p=dropout_p, softmax_scale=scale, causal=is_causal, return_lse=return_lse, ) if return_lse: out, lse, *_ = out return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( AttentionBackendName.FLEX, constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], ) def _native_flex_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | "flex_attention.BlockMask" | None = None, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: # TODO: should we LRU cache the block mask creation? score_mod = None block_mask = None batch_size, seq_len_q, num_heads, _ = query.shape _, seq_len_kv, _, _ = key.shape if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask): block_mask = attn_mask elif is_causal: block_mask = flex_attention.create_block_mask( _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device ) elif torch.is_tensor(attn_mask): if attn_mask.ndim == 2: attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv) if attn_mask.dtype == torch.bool: # TODO: this probably does not work but verify! def mask_mod(batch_idx, head_idx, q_idx, kv_idx): return attn_mask[batch_idx, head_idx, q_idx, kv_idx] block_mask = flex_attention.create_block_mask( mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device ) else: def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx] else: raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.") query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) out = flex_attention.flex_attention( query=query, key=key, value=value, score_mod=score_mod, block_mask=block_mask, scale=scale, enable_gqa=enable_gqa, return_lse=return_lse, ) out = out.permute(0, 2, 1, 3) return out def _prepare_additive_attn_mask( attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True ) -> torch.Tensor: """ Convert a 2D attention mask to an additive mask, optionally reshaping to 4D for SDPA. This helper is used by both native SDPA and xformers backends to handle both boolean and additive masks. 
Args: attn_mask: 2D tensor [batch_size, seq_len_k] - Boolean: True means attend, False means mask out - Additive: 0.0 means attend, -inf means mask out target_dtype: The dtype to convert the mask to (usually query.dtype) reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting Returns: Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True. """ # Check if the mask is boolean or already additive if attn_mask.dtype == torch.bool: # Convert boolean to additive: True -> 0.0, False -> -inf attn_mask = torch.where(attn_mask, 0.0, float("-inf")) # Convert to target dtype attn_mask = attn_mask.to(dtype=target_dtype) else: # Already additive mask - just ensure correct dtype attn_mask = attn_mask.to(dtype=target_dtype) # Optionally reshape to 4D for broadcasting in attention mechanisms if reshape_4d: batch_size, seq_len_k = attn_mask.shape attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k) return attn_mask @_AttentionBackendRegistry.register( AttentionBackendName.NATIVE, constraints=[_check_device, _check_shape], supports_context_parallel=True, ) def _native_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if return_lse: raise ValueError("Native attention backend does not support setting `return_lse=True`.") # Reshape 2D mask to 4D for SDPA # SDPA accepts both boolean masks (torch.bool) and additive masks (float) if ( attn_mask is not None and attn_mask.ndim == 2 and attn_mask.shape[0] == query.shape[0] and attn_mask.shape[1] == key.shape[1] ): # Just reshape [batch_size, seq_len_k] -> [batch_size, 1, 1, seq_len_k] # SDPA handles both boolean and additive masks correctly attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) if _parallel_config is None: query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) out = torch.nn.functional.scaled_dot_product_attention( query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa, ) out = out.permute(0, 2, 1, 3) else: out = _templated_context_parallel_attention( query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, forward_op=_native_attention_forward_op, backward_op=_native_attention_backward_op, _parallel_config=_parallel_config, ) return out @_AttentionBackendRegistry.register( AttentionBackendName._NATIVE_CUDNN, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_context_parallel=True, ) def _native_cudnn_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: lse = None if _parallel_config is None and not return_lse: query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value)) with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION): out = torch.nn.functional.scaled_dot_product_attention( query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, 
scale=scale, enable_gqa=enable_gqa, ) out = out.permute(0, 2, 1, 3) else: out = _templated_context_parallel_attention( query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, forward_op=_cudnn_attention_forward_op, backward_op=_cudnn_attention_backward_op, _parallel_config=_parallel_config, ) if return_lse: out, lse = out return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( AttentionBackendName._NATIVE_EFFICIENT, constraints=[_check_device, _check_shape], ) def _native_efficient_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if return_lse: raise ValueError("Native efficient attention backend does not support setting `return_lse=True`.") query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): out = torch.nn.functional.scaled_dot_product_attention( query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa, ) out = out.permute(0, 2, 1, 3) return out @_AttentionBackendRegistry.register( AttentionBackendName._NATIVE_FLASH, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_context_parallel=True, ) def _native_flash_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if attn_mask is not None: raise ValueError("`attn_mask` is not supported for native flash attention") lse = None if _parallel_config is None and not return_lse: query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): out = torch.nn.functional.scaled_dot_product_attention( query=query, key=key, value=value, attn_mask=None, # not supported dropout_p=dropout_p, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa, ) out = out.permute(0, 2, 1, 3) else: out = _templated_context_parallel_attention( query, key, value, None, dropout_p, is_causal, scale, enable_gqa, return_lse, forward_op=_native_flash_attention_forward_op, backward_op=_native_flash_attention_backward_op, _parallel_config=_parallel_config, ) if return_lse: out, lse = out return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( AttentionBackendName._NATIVE_MATH, constraints=[_check_device, _check_shape], ) def _native_math_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if return_lse: raise ValueError("Native math attention backend does not support setting `return_lse=True`.") query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): out = torch.nn.functional.scaled_dot_product_attention( query=query, key=key, value=value, attn_mask=attn_mask,
dropout_p=dropout_p, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa, ) out = out.permute(0, 2, 1, 3) return out @_AttentionBackendRegistry.register( AttentionBackendName._NATIVE_NPU, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_context_parallel=True, ) def _native_npu_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, scale: float | None = None, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if return_lse: raise ValueError("NPU attention backend does not support setting `return_lse=True`.") if _parallel_config is None: attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask) out = npu_fusion_attention( query, key, value, query.size(2), # num_heads atten_mask=attn_mask, input_layout="BSND", pse=None, scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, pre_tockens=65536, next_tockens=65536, keep_prob=1.0 - dropout_p, sync=False, inner_precise=0, )[0] else: out = _templated_context_parallel_attention( query, key, value, attn_mask, dropout_p, None, scale, None, return_lse, forward_op=_npu_attention_forward_op, backward_op=_npu_attention_backward_op, _parallel_config=_parallel_config, ) return out # Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853 @_AttentionBackendRegistry.register( AttentionBackendName._NATIVE_XLA, constraints=[_check_device, _check_shape], ) def _native_xla_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, is_causal: bool = False, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if attn_mask is not None: raise ValueError("`attn_mask` is not supported for XLA attention") if return_lse: raise ValueError("XLA attention backend does not support setting `return_lse=True`.") query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) query = query / math.sqrt(query.shape[-1]) out = xla_flash_attention( q=query, k=key, v=value, causal=is_causal, ) out = out.permute(0, 2, 1, 3) return out @_AttentionBackendRegistry.register( AttentionBackendName.SAGE, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_context_parallel=True, ) def _sage_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, is_causal: bool = False, scale: float | None = None, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if attn_mask is not None: raise ValueError("`attn_mask` is not supported for sage attention") lse = None if _parallel_config is None: out = sageattn( q=query, k=key, v=value, tensor_layout="NHD", is_causal=is_causal, sm_scale=scale, return_lse=return_lse, ) if return_lse: out, lse, *_ = out else: out = _templated_context_parallel_attention( query, key, value, None, 0.0, is_causal, scale, False, return_lse, forward_op=_sage_attention_forward_op, backward_op=_sage_attention_backward_op, _parallel_config=_parallel_config, ) if return_lse: out, lse = out return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( AttentionBackendName.SAGE_HUB, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_context_parallel=True, ) def _sage_attention_hub( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: 
torch.Tensor | None = None, is_causal: bool = False, scale: float | None = None, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if attn_mask is not None: raise ValueError("`attn_mask` is not supported for sage attention") lse = None func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn if _parallel_config is None: out = func( q=query, k=key, v=value, tensor_layout="NHD", is_causal=is_causal, sm_scale=scale, return_lse=return_lse, ) if return_lse: out, lse, *_ = out else: out = _templated_context_parallel_attention( query, key, value, None, 0.0, is_causal, scale, False, return_lse, forward_op=_sage_attention_hub_forward_op, backward_op=_sage_attention_backward_op, _parallel_config=_parallel_config, ) if return_lse: out, lse = out return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( AttentionBackendName.SAGE_VARLEN, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], ) def _sage_varlen_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, is_causal: bool = False, scale: float | None = None, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if return_lse: raise ValueError("Sage varlen backend does not support setting `return_lse=True`.") batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape if attn_mask is not None: attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( _prepare_for_flash_attn_or_sage_varlen( batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device ) ) key_valid, value_valid = [], [] for b in range(batch_size): valid_len = seqlens_k[b] key_valid.append(key[b, :valid_len]) value_valid.append(value[b, :valid_len]) query_packed = query.flatten(0, 1) key_packed = torch.cat(key_valid, dim=0) value_packed = torch.cat(value_valid, dim=0) out = sageattn_varlen( q=query_packed, k=key_packed, v=value_packed, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, is_causal=is_causal, sm_scale=scale, ) out = out.unflatten(0, (batch_size, -1)) return out @_AttentionBackendRegistry.register( AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], ) def _sage_qk_int8_pv_fp8_cuda_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, is_causal: bool = False, scale: float | None = None, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if attn_mask is not None: raise ValueError("`attn_mask` is not supported for sage attention") return sageattn_qk_int8_pv_fp8_cuda( q=query, k=key, v=value, tensor_layout="NHD", is_causal=is_causal, sm_scale=scale, return_lse=return_lse, ) @_AttentionBackendRegistry.register( AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], ) def _sage_qk_int8_pv_fp8_cuda_sm90_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, is_causal: bool = False, scale: float | None = None, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if attn_mask is not None: raise ValueError("`attn_mask` is not supported for sage attention") return 
sageattn_qk_int8_pv_fp8_cuda_sm90( q=query, k=key, v=value, tensor_layout="NHD", is_causal=is_causal, sm_scale=scale, return_lse=return_lse, ) @_AttentionBackendRegistry.register( AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], ) def _sage_qk_int8_pv_fp16_cuda_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, is_causal: bool = False, scale: float | None = None, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if attn_mask is not None: raise ValueError("`attn_mask` is not supported for sage attention") return sageattn_qk_int8_pv_fp16_cuda( q=query, k=key, v=value, tensor_layout="NHD", is_causal=is_causal, sm_scale=scale, return_lse=return_lse, ) @_AttentionBackendRegistry.register( AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], ) def _sage_qk_int8_pv_fp16_triton_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, is_causal: bool = False, scale: float | None = None, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if attn_mask is not None: raise ValueError("`attn_mask` is not supported for sage attention") return sageattn_qk_int8_pv_fp16_triton( q=query, k=key, v=value, tensor_layout="NHD", is_causal=is_causal, sm_scale=scale, return_lse=return_lse, ) @_AttentionBackendRegistry.register( AttentionBackendName.XFORMERS, constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], ) def _xformers_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if return_lse: raise ValueError("xformers attention backend does not support setting `return_lse=True`.") batch_size, seq_len_q, num_heads_q, _ = query.shape _, seq_len_kv, num_heads_kv, _ = key.shape if is_causal: attn_mask = xops.LowerTriangularMask() elif attn_mask is not None: if attn_mask.ndim == 2: # Convert 2D mask to 4D for xformers # Mask can be boolean (True=attend, False=mask) or additive (0.0=attend, -inf=mask) # xformers requires 4D additive masks [batch, heads, seq_q, seq_k] # Need memory alignment - create larger tensor and slice for alignment original_seq_len = attn_mask.size(1) aligned_seq_len = ((original_seq_len + 7) // 8) * 8 # Round up to multiple of 8 # Create aligned 4D tensor and slice to ensure proper memory layout aligned_mask = torch.zeros( (batch_size, num_heads_q, seq_len_q, aligned_seq_len), dtype=query.dtype, device=query.device, ) # Convert to 4D additive mask (handles both boolean and additive inputs) mask_additive = _prepare_additive_attn_mask( attn_mask, target_dtype=query.dtype ) # [batch, 1, 1, seq_len_k] # Broadcast to [batch, heads, seq_q, seq_len_k] aligned_mask[:, :, :, :original_seq_len] = mask_additive # Mask out the padding (already -inf from zeros -> where with default) aligned_mask[:, :, :, original_seq_len:] = float("-inf") # Slice to actual size with proper alignment attn_mask = aligned_mask[:, :, :, :seq_len_kv] elif attn_mask.ndim != 4: raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") elif attn_mask.ndim == 4: attn_mask = attn_mask.expand(batch_size, 
num_heads_q, seq_len_q, seq_len_kv).type_as(query) if enable_gqa: if num_heads_q % num_heads_kv != 0: raise ValueError("Number of heads in query must be divisible by number of heads in key/value.") num_heads_per_group = num_heads_q // num_heads_kv query = query.unflatten(2, (num_heads_kv, -1)) key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale) if enable_gqa: out = out.flatten(2, 3) return out
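# ---------------------------------------------------------------------------
# Illustrative sketch (not registered as a backend): the GQA handling in
# _xformers_attention above works purely through shape manipulation. Query
# heads are grouped per key/value head by unflattening the head dimension into
# (num_heads_kv, heads_per_group), and key/value are expanded (broadcast, no
# copy) over the group dimension so every query head in a group attends with
# the same key/value head. The hypothetical helper below shows that expansion
# on its own, assuming (B, S, H, D) layouts.
# ---------------------------------------------------------------------------
def _example_gqa_expand(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
    """Return query/key/value reshaped to 5D (B, S, H_kv, G, D) for grouped-query attention."""
    num_heads_q, num_heads_kv = query.shape[2], key.shape[2]
    if num_heads_q % num_heads_kv != 0:
        raise ValueError("Number of query heads must be divisible by number of key/value heads.")
    heads_per_group = num_heads_q // num_heads_kv
    # (B, S, H_q, D) -> (B, S, H_kv, G, D); G query heads share one kv head.
    query = query.unflatten(2, (num_heads_kv, heads_per_group))
    # (B, S, H_kv, D) -> (B, S, H_kv, 1, D) -> broadcast to (B, S, H_kv, G, D).
    key = key.unflatten(2, (num_heads_kv, 1)).expand(-1, -1, -1, heads_per_group, -1)
    value = value.unflatten(2, (num_heads_kv, 1)).expand(-1, -1, -1, heads_per_group, -1)
    return query, key, value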