Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
| # | |
| # This source code is licensed under the BSD license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import logging | |
| import math | |
| import random | |
| from functools import partial | |
| from typing import Any, List, Optional, Sequence, Tuple, Type, TypeVar, Union | |
| import pytest | |
| import torch | |
| import torch.nn.functional as F | |
| from scipy.stats import binomtest | |
| from torch.utils.checkpoint import checkpoint | |
| import xformers.ops | |
| from xformers.attn_bias_utils import create_attn_bias | |
| from xformers.ops import fmha | |
| from xformers.ops.fmha import ALL_BW_OPS, ALL_FW_OPS | |
| from xformers.ops.fmha.common import AttentionFwOpBase, AttentionOpBase | |
| from xformers.ops.fmha.dispatch import _dispatch_fw_priority_list | |
| from .utils import ( | |
| assert_allclose, | |
| cuda_only, | |
| disable_on_rocm, | |
| disable_tf32, | |
| pack_kv_cache, | |
| ref_attention_bmhk_for_test, | |
| ref_attention_for_test, | |
| rocm_only, | |
| ) | |
# (major, minor) compute capability of the default CUDA device. Without CUDA
# this is (0, 0), so every "requires smXY+" marker below skips.
compute_capability = (
    torch.cuda.get_device_capability("cuda") if torch.cuda.is_available() else (0, 0)
)


def _skip_below_sm(required: Tuple[int, int]):
    """Return a ``skipif`` marker for GPUs whose capability is below ``required``."""
    major, minor = required
    return pytest.mark.skipif(
        compute_capability < required, reason=f"requires sm{major}{minor}+"
    )


sm70_or_better_only = _skip_below_sm((7, 0))
sm75_or_better_only = _skip_below_sm((7, 5))
sm80_or_better_only = _skip_below_sm((8, 0))
skip_if_rocm = pytest.mark.skipif(
    torch.version.hip is not None, reason="not supported on ROCm"
)
skip_if_pt_cutlass = pytest.mark.skipif(
    fmha.cutlass.USE_TORCH_CUTLASS, reason="using PT cutlass"
)
# Devices the generated test cases iterate over ("cpu" always comes first).
_devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
# Either a forward or a backward attention-operator class.
T = TypeVar(
    "T", Type[fmha.common.AttentionFwOpBase], Type[fmha.common.AttentionBwOpBase]
)
logger = logging.getLogger("xformers")
def _filter_unsupported_ops(ops: Sequence[T]) -> List[T]:
    """Keep only the operators usable on this machine, preserving order.

    An op is kept when ``op.is_available()`` is true and it either supports
    the CPU or the local GPU meets ``op.CUDA_MINIMUM_COMPUTE_CAPABILITY``.
    """
    usable = []
    for op in ops:
        device_ok = (
            "cpu" in op.SUPPORTED_DEVICES
            or op.CUDA_MINIMUM_COMPUTE_CAPABILITY <= compute_capability
        )
        if device_ok and op.is_available():
            usable.append(op)
    return usable
# Forward ops usable here, without the optional Flash "unpadded LSE" variant.
ALL_FW_OPS_NO_UNPADDED_LSE = _filter_unsupported_ops(ALL_FW_OPS)
# The unpadded-LSE Flash op is only tested when the Flash build supports it.
_UNPADDED_LSE_FW_OPS = (
    [fmha.flash.FlashFwUnpaddedLSE] if fmha.flash.FLASH_SUPPORTS_UNPADDED_LSE else []
)
ALL_FW_OPS = _filter_unsupported_ops(ALL_FW_OPS + _UNPADDED_LSE_FW_OPS)
ALL_BW_OPS = _filter_unsupported_ops(ALL_BW_OPS)
def sample_random_supported_fw(
    inp: fmha.Inputs, seed: int
) -> Type[fmha.common.AttentionFwOpBase]:
    """Return a forward op that supports ``inp``.

    ``ALL_FW_OPS`` is shuffled with ``random.Random(seed)`` and the first op
    whose ``supports(inp)`` is true is returned, so the pick is reproducible.

    Raises:
        NotImplementedError: if no forward op supports ``inp``.
    """
    shuffled = list(ALL_FW_OPS)
    random.Random(seed).shuffle(shuffled)
    picked = next((op for op in shuffled if op.supports(inp)), None)
    if picked is None:
        raise NotImplementedError(f"Could not find a FW operator for: {inp}")
    return picked
def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
    """Return the ``(B, Mq, Mkv, H, K, Kv)`` shapes used to test ``op``.

    Combines a regular grid, odd/edge sizes, differing K/Kv (when
    ``op.SUPPORTS_DIFFERENT_VALUE_EMBED``), exotic lengths and several head
    counts for every ``B`` in ``op._TEST_BATCH_SIZES``. Shapes rejected by
    ``op.shape_not_supported_reasons`` are dropped. For a few ops, 200 seeded
    random supported shapes are then appended.
    """
    shapes = []
    for B in op._TEST_BATCH_SIZES:
        # Regular grid: single head, K == Kv.
        for Mq in [32, 256]:
            for Mkv in [32, 64, 256, 1024]:
                for K in op._TEST_K:
                    shapes.append((B, Mq, Mkv, 1, K, K))
        # Baseline sizes reused by the variations below.
        Mq = 256
        Mkv = 128
        K = 32
        H = 1
        # Weird values of parameters
        for M in [2, 3, 15, 31, 32, 34, 68, 72, 90, 132, 136]:
            shapes.append((B, M, Mkv, H, K, K))
            shapes.append((B, Mq, M, H, K, K))
        for _K in [1, 2, 3, 31, 34, 36, 38, 40, 64, 80, 160, 256 + 2, 256 + 8, 512]:
            if _K <= op.SUPPORTED_MAX_K:
                shapes.append((B, Mq, Mkv, H, _K, _K))
        # Different value for K / Kv
        if op.SUPPORTS_DIFFERENT_VALUE_EMBED:
            for _K in [32, 36, 64, 256 + 8]:
                shapes.append((B, Mq, Mkv, H, K, _K))
                shapes.append((B, Mq, Mkv, H, _K, K))
        # Exotic sizes
        # NOTE(review): this loop rebinds ``K``, so the head-count shapes below
        # use the last entry of ``op._TEST_K`` rather than 32 - confirm intended.
        for K in op._TEST_K:
            shapes.append((B, 16, 1024, H, K, K))
            shapes.append((B, 1024, 16, H, K, K))
        # Some number of heads (batch divided by H, but never below 1)
        for H in [3, 5, 12]:
            shapes.append((max(1, B // H), Mq, Mkv, H, K, K))
    # Filter-out not supported shapes
    shapes = [
        shape
        for shape in shapes
        if len(
            op.shape_not_supported_reasons(
                Mq=shape[1], Mkv=shape[2], K=shape[4], Kv=shape[5]
            )
        )
        == 0
    ]
    # Add some random shapes
    if op in [
        fmha.cutlass.FwOp,
        fmha.cutlass.BwOp,
        fmha.flash.BwOp,
        fmha.ck.FwOp,
    ]:
        # Multiples of 8 from 8 to 248.
        K_CHOICES = [8 * i for i in range(1, 256 // 8)]
        # Fixed seed: the same random shapes are produced on every run.
        r = random.Random(0)
        found_count = 0
        while found_count < 200:
            B = r.randint(1, 400)
            Mq = r.randint(1, 500)
            Mkv = r.randint(1, 500)
            H = r.randint(2, 11)
            B = max(B // H, 1)
            K = r.choice(K_CHOICES)
            Kv = r.choice(K_CHOICES)
            if not op.SUPPORTS_DIFFERENT_VALUE_EMBED:
                Kv = K
            # Only shapes the op accepts count toward the 200.
            if len(op.shape_not_supported_reasons(Mq, Mkv, K, Kv)):
                continue
            found_count += 1
            shapes.append((B, Mq, Mkv, H, K, Kv))
    return shapes
def make_id(op, device, dtype, bias_type, *shape):
    """Build a pytest id ``<op>-<device>-<dtype>-<BiasType>-<dim>-<dim>-...``.

    The shape part is always preceded by a dash, even when ``shape`` is empty.
    """
    dims = "-".join(map(str, shape))
    return "{}-{}-{}-{}-{}".format(
        op.NAME, device, str(dtype), bias_type.__name__, dims
    )
def _generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(
    ops_list: Sequence[Type[fmha.AttentionOpBase]], max_shapes_per_op: int = 65000
):
    """Build ``pytest.mark.parametrize`` kwargs covering every op in ``ops_list``.

    Each shape from ``generate_test_shapes_B_Mq_Mkv_H_K_Kv`` is combined with
    every supported device and dtype. Each combination gets one bias type
    drawn from a ``random.Random(0)`` stream, so the selection is reproducible.
    Biases other than none / ``LowerTriangularMask`` get a batch capped at 12.
    Some block-diagonal masks also get adjusted sequence lengths. Both
    adjustments are derived from the base shape for every combination.
    Generation for an op stops once more than ``max_shapes_per_op`` shapes
    produced at least one case. A few large no-bias shapes are always appended.

    Returns:
        ``{"argvalues": [...], "ids": [...]}`` with one ``make_id`` id per case.
    """
    r = random.Random(0)
    combination = []
    for op in ops_list:
        op_count = 0
        # Sort list of masks, so it's deterministic across runs
        LIST_MASKS = sorted(op.SUPPORTED_ATTN_BIAS_TYPES, key=str)
        for base_shape in generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
            has_one = False
            for device in _devices:
                if device not in op.SUPPORTED_DEVICES:
                    continue
                for dtype in op.SUPPORTED_DTYPES:
                    bias_type = r.choice(LIST_MASKS)
                    # Always start from the base shape. Rebinding the loop
                    # variable would leak (and accumulate, e.g. ``Mkv + 2``)
                    # one combination's adjustment into the next.
                    shape = base_shape
                    # Avoid using too much memory
                    if bias_type not in [
                        type(None),
                        fmha.attn_bias.LowerTriangularMask,
                    ]:
                        B, Mq, Mkv, H, K, Kv = base_shape
                        B = min(B, 12)
                        if bias_type in {
                            fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask,
                            fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask,
                        }:
                            Mq, Mkv = min(Mkv, Mq), max(Mkv, Mq) + 2
                        elif bias_type in {
                            fmha.attn_bias.BlockDiagonalCausalWithOffsetGappyKeysMask,
                            fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
                            fmha.attn_bias.BlockDiagonalPaddedKeysMask,
                            fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
                            fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
                        }:
                            Mq, Mkv = min(Mkv, Mq), max(Mkv, Mq)
                        shape = (B, Mq, Mkv, H, K, Kv)
                    combination.append((op, device, dtype, bias_type, *shape))
                    has_one = True
            if has_one:
                op_count += 1
            if op_count > max_shapes_per_op:
                break
        # Some specific shapes for which we want to run without any mask
        bias_type = type(None)
        for shape in (
            # Some strides/dims don't fit on an uint16
            (1, 128, 128, 300, 128, 128),
            (13, 1, 67, 200, 8, 8),
            (1, 1 + 2**16, 4, 1, 8, 8),
            (1, 4, 1 + 2**16, 1, 8, 8),
            # TODO: Some strides don't fit on an uint32
            # Crashes on Flash, Errors on Cutlass
            # (1, 1, 64000, 300, 128, 128)
        ):
            for device in _devices:
                if device not in op.SUPPORTED_DEVICES:
                    continue
                for dtype in op.SUPPORTED_DTYPES:
                    combination.append((op, device, dtype, bias_type, *shape))
    return {
        "argvalues": combination,
        "ids": [make_id(*c) for c in combination],
    }
def _parametrize_ops(argname, ops, **generator_kwargs):
    """Return a ``pytest.mark.parametrize`` mark over all generated cases for ``ops``."""
    return pytest.mark.parametrize(
        argname, **_generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(ops, **generator_kwargs)
    )


# Full forward sweep.
parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = _parametrize_ops(
    "opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv", ALL_FW_OPS
)
# Forward sweep excluding the Flash unpadded-LSE variant.
parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv_NO_UNPADDED_LSE = _parametrize_ops(
    "opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv", ALL_FW_OPS_NO_UNPADDED_LSE
)
# "xs" variants: only a couple of shapes per op (max_shapes_per_op=1).
parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs = _parametrize_ops(
    "opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv", ALL_FW_OPS, max_shapes_per_op=1
)
parametrize_opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = _parametrize_ops(
    "opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv", ALL_BW_OPS
)
parametrize_opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs = _parametrize_ops(
    "opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv", ALL_BW_OPS, max_shapes_per_op=1
)
| def _rand_partition(r: random.Random, total: int, n: int) -> List[int]: | |
| # returns list of n nonnegative integers summing to total | |
| idx = {0, total} | |
| while len(idx) < n + 1: | |
| idx.add(r.randint(1, total - 1)) | |
| s = sorted(idx) | |
| return [e - b for b, e in zip(s[:-1], s[1:])] | |
def get_bias_grad(attn_bias, clear: bool = False) -> Optional[torch.Tensor]:
    """Return ``attn_bias.grad`` if the bias is a plain tensor, else ``None``.

    With ``clear=True`` the tensor's ``.grad`` is reset to ``None`` after it
    has been read; the previously stored gradient is still returned.
    """
    if not isinstance(attn_bias, torch.Tensor):
        return None
    grad = attn_bias.grad
    if clear:
        attn_bias.grad = None
    return grad
def create_tensors(
    op: Optional[Type[AttentionOpBase]],
    device,
    dtype,
    attn_bias_type,
    B,
    q_len,
    kv_len,
    h,
    k,
    kv,
    *,
    attn_bias_requires_grad: bool = False,
    fmt: str = "BMK",
    g: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
    """Create seeded random query/key/value tensors and an attention bias.

    Args:
        op: operator under test. If given and it does not support the
            generated inputs, the test is skipped via ``pytest.skip``.
        device, dtype: placement and precision of the tensors.
        attn_bias_type: bias class forwarded to ``create_attn_bias`` (or None).
        B, q_len, kv_len, h, k, kv: batch, query/key lengths, heads, and the
            query/key and value embedding sizes.
        attn_bias_requires_grad: forwarded to ``create_attn_bias``.
        fmt: "BMK" (``[B*h, M, K]``), "BMHK" (``[B, M, h, K]``) or "BMGHK"
            (``[B, M, g, h, K]`` with keys/values broadcast over ``h``).
        g: number of head groups ("BMGHK" dimension; also forwarded to
            ``create_attn_bias``).

    Returns:
        ``(query, key, value, attn_bias)``. For block-diagonal biases, the
        tensors are reshaped to batch size 1.
    """
    seed = B * q_len + kv_len * k + kv
    torch.manual_seed(seed)
    mask_is_bottom_right = attn_bias_type is not None and issubclass(
        attn_bias_type,
        (
            fmha.attn_bias.LowerTriangularFromBottomRightMask,
            fmha.attn_bias.LowerTriangularFromBottomRightLocalAttentionMask,
            fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask,
            fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask,
            fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask,
            fmha.attn_bias.LocalAttentionFromBottomRightMask,
        ),
    )
    if mask_is_bottom_right and q_len > kv_len:
        # Bottom-right attention and local-attention masks require q_len <= kv_len
        kv_len = q_len
    if attn_bias_type is not None and issubclass(
        attn_bias_type,
        (
            fmha.attn_bias.PagedBlockDiagonalGappyKeysMask,
            fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
        ),
    ):
        page_size_choices = [256, 512]
        if op is not None and issubclass(op, fmha.triton_splitk.FwOp):
            # TODO: enable small pages for flash attention when that's implemented
            page_size_choices.extend([64, 128])
        # Seeded like the tensors, so a failing paged case reproduces with the
        # same page size. The global ``random`` state is not controlled here.
        page_size = random.Random(seed).choice(page_size_choices)
        # Round the key/value length up to a whole number of pages.
        kv_len_paged = (kv_len + page_size - 1) // page_size * page_size
    else:
        kv_len_paged = kv_len
        page_size = None
    scale = 3
    if fmt == "BMK":
        query = torch.randn((B * h, q_len, k), device=device, dtype=dtype)
        key = torch.randn((B * h, kv_len_paged, k), device=device, dtype=dtype)
        value = torch.randn((B * h, kv_len_paged, kv), device=device, dtype=dtype)
    elif fmt == "BMHK":
        query = torch.randn((B, q_len, h, k), device=device, dtype=dtype)
        key = torch.randn((B, kv_len_paged, h, k), device=device, dtype=dtype)
        value = torch.randn((B, kv_len_paged, h, kv), device=device, dtype=dtype)
    else:
        assert fmt == "BMGHK"
        query = torch.randn((B, q_len, g, h, k), device=device, dtype=dtype)
        key = torch.randn((B, kv_len_paged, g, 1, k), device=device, dtype=dtype)
        value = torch.randn((B, kv_len_paged, g, 1, kv), device=device, dtype=dtype)
    for x in [query, key, value]:
        x.mul_(scale)
    if fmt == "BMGHK":
        # Expand - after the in-place mul
        key = key.expand((B, kv_len_paged, g, h, k))
        # value's last dim is ``kv``; expanding to ``k`` fails whenever K != Kv.
        value = value.expand((B, kv_len_paged, g, h, kv))
    if fmt == "BMK" and not fmha.common._is_bias_type_supported_in_BMK(attn_bias_type):
        attn_bias_type = None
    attn_bias = None
    if attn_bias_type is not None:
        attn_bias = create_attn_bias(
            attn_bias_type,
            batch_size=B,
            num_heads=h,
            num_heads_groups=g,
            q_len=q_len,
            kv_len=kv_len,
            dtype=dtype,
            device=device,
            requires_grad=attn_bias_requires_grad,
            fmt=fmt,
            op=op,
            page_size=page_size,
        )
        if isinstance(
            attn_bias,
            (
                fmha.attn_bias.BlockDiagonalMask,
                fmha.attn_bias.BlockDiagonalGappyKeysMask,
                fmha.attn_bias.BlockDiagonalPaddedKeysMask,
                fmha.attn_bias.PagedBlockDiagonalGappyKeysMask,
                fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
            ),
        ):
            # Block-diagonal biases describe all sequences concatenated into a
            # single batch entry: merge the batch into the sequence dimension.
            query, key, value = [
                x.reshape([1, -1, *x.shape[2:]]) for x in [query, key, value]
            ]
    inputs = fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias)
    if op is not None:
        reasons = op.not_supported_reasons(inputs)
        if reasons:
            err_msg = f"{op.NAME}: unsupported ({'/'.join(reasons)})"
            # Ensure we free memory to avoid OOMs
            del query, key, value, attn_bias, inputs
            pytest.skip(err_msg)
    return query, key, value, attn_bias
def bmhk2bmk(tensor) -> torch.Tensor:
    """Fold heads into the batch: ``[B, M, H, K] -> [B * H, M, K]``.

    The result is a contiguous copy; row ``b * H + h`` holds head ``h`` of
    batch entry ``b``.
    """
    heads_first = tensor.permute((0, 2, 1, 3)).contiguous()  # [B, H, M, K]
    return heads_first.flatten(0, 1)
def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
    """Split heads out of the batch: ``[B * H, M, K] -> [B, M, H, K]``.

    Inverse of ``bmhk2bmk``. The result is a transposed (non-contiguous) view
    whenever ``reshape`` can avoid a copy.
    """
    seqlen, dim = tensor.shape[1], tensor.shape[2]
    per_head = tensor.reshape([-1, num_heads, seqlen, dim])  # [B, H, M, K]
    return per_head.transpose(1, 2)
def nanify_oob_seqlen(x: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``x`` whose storage continues with NaNs along dim 1.

    Dim 1 is padded with NaNs up to a multiple of 256, and the view is sliced
    back to the original length. Values and shape are unchanged, but any
    kernel that reads past the sequence end will pick up NaNs. If dim 1 is
    already a multiple of 256, ``x`` itself is returned.
    """
    align_to = 256
    seqlen = x.shape[1]
    overhang = seqlen % align_to
    if overhang == 0:
        return x
    # F.pad takes (left, right) pairs starting from the LAST dim, so index -3
    # is always the right-hand padding of dim 1, whatever the rank of ``x``.
    pad = [0] * (2 * x.ndim)
    pad[-3] = align_to - overhang
    padded = torch.nn.functional.pad(x, pad, value=math.nan)
    return padded[:, :seqlen]
def test_forward(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, packed, fmt, **kwargs):
    """Run one forward op on a parametrized case and compare it to the reference.

    Checks that:
    - the output has no NaNs;
    - rerunning on inputs whose storage is followed by NaNs (see
      ``nanify_oob_seqlen``) gives a bitwise-identical output, so there are no
      out-of-bounds reads;
    - the result matches ``ref_attention_for_test`` within the op's tolerances.

    With ``packed``, q/k/v are non-contiguous views into one stacked tensor.
    """
    (
        op,
        device,
        dtype,
        bias_type,
        batch_size,
        q_len,
        kv_len,
        h,
        k,
        kv,
    ) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv

    if packed:
        uses_paged_bias = issubclass(
            bias_type,
            (
                fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
                fmha.attn_bias.PagedBlockDiagonalGappyKeysMask,
            ),
        )
        if uses_paged_bias:
            pytest.skip(
                "packed doesn't make sense with paged attention, since q has different shape than k/v"
            )
        if k != kv or q_len != kv_len:
            pytest.skip(
                f"packed incompatible with `k ({k}) != kv ({kv})` or `q_len ({q_len}) != kv_len ({kv_len})`"
            )
    if fmt == "BMK" and not fmha.common._is_bias_type_supported_in_BMK(bias_type):
        pytest.skip("BMK incompatible with this bias")

    query, key, value, attn_bias = create_tensors(
        *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
        fmt="BMHK" if packed else fmt,
        **kwargs,
    )

    if packed:
        stacked = torch.stack([query, key, value], 2)
        if fmt == "BMHK":
            # bm3hk -> 3 x bmhk
            query, key, value = xformers.ops.unbind(stacked, 2)
        elif fmt == "BMK":
            # bm3hk -> 3bhmk -> 3Bmk
            stacked = stacked.permute(2, 0, 3, 1, 4).view([3, -1, q_len, k])
            query, key, value = stacked.unbind(0)
            # Re-create bias in the right format
            attn_bias = create_attn_bias(
                bias_type=bias_type,
                batch_size=batch_size,
                num_heads=h,
                num_heads_groups=1,
                q_len=q_len,
                kv_len=kv_len,
                device=device,
                dtype=dtype,
                requires_grad=False,
                fmt=fmt,
                op=op,
            )
        else:
            assert False, f"Unsupport fmt {fmt} with packing"
        assert not query.is_contiguous()

    def _run_forward(q_in, k_in, v_in):
        # Uses whatever ``attn_bias`` is at call time (after any re-creation).
        return xformers.ops.memory_efficient_attention_forward(
            q_in, k_in, v_in, attn_bias, op=op
        )

    out = _run_forward(query, key, value)
    assert not out.isnan().any(), ("Output has NaNs", attn_bias)

    out_nan_padded = _run_forward(
        nanify_oob_seqlen(query),
        nanify_oob_seqlen(key),
        nanify_oob_seqlen(value),
    )
    assert (
        not out_nan_padded.isnan().any()
    ), "Output has NaNs - most likely reading out-of-bounds"
    assert torch.allclose(out, out_nan_padded, atol=0.0, rtol=0.0), (
        "Non-deterministic behavior",
        attn_bias,
    )

    ref = ref_attention_for_test(query, key, value, attn_bias)
    assert out.shape == ref.shape, out.shape
    assert_allclose(
        out.float(),
        ref,
        atol=op.ERROR_ATOL[dtype],
        rtol=op.ERROR_RTOL.get(dtype, 1e-5),
    )
def test_key_query_all_ones(q_len, kv_len, batch_size, k_len):
    """With all-ones queries and keys every score is equal. The attention
    output must therefore be the mean of ``value`` over the key axis."""
    device = "cuda"
    scale = 3
    # composable kernel doesn't support fp32
    dtype = torch.float32 if not torch.version.hip else torch.float16
    query, key = (
        torch.ones((batch_size, seqlen, k_len), device=device, dtype=dtype)
        for seqlen in (q_len, kv_len)
    )
    value = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale

    out = xformers.ops.memory_efficient_attention(query, key, value)
    # this should be equivalent to the average over value
    expected = value.mean(1, keepdim=True).expand_as(query)
    assert_allclose(out, expected, atol=1e-5)
def _block_diag_reshape_lse(
    lse: torch.Tensor, q_seqinfo: fmha.attn_bias._SeqLenInfo
) -> torch.Tensor:
    """LSE can be padded, let's remove the padding.

    ``lse`` is split along dim 0, one entry per query sequence. Each entry is
    cut to its sequence length (``end - start`` from
    ``q_seqinfo.intervals()``) on its last axis. The pieces are concatenated
    and a leading dim of size 1 is added.
    """
    trimmed = [
        per_seq[:, : end - start]
        for per_seq, (start, end) in zip(lse.unbind(0), q_seqinfo.intervals())
    ]
    return torch.cat(trimmed, dim=1).unsqueeze(0)
def test_logsumexp(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv):
    """Check the log-sum-exp returned by a forward op against a dense reference.

    The reference computes ``logsumexp`` over the float32 scores
    ``Q @ K^T / sqrt(k)`` plus the (materialized) attention bias.
    """
    (
        op,
        device,
        dtype,
        bias_type,
        batch_size,
        q_len,
        kv_len,
        h,
        k,
        kv,
    ) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
    if op is fmha.ck.FwOp:
        pytest.skip("logsumexp is not yet supported by ck-tiled fmha!")
    query, key, value, attn_bias = create_tensors(
        *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
        fmt="BMHK",
    )
    _out, lse = xformers.ops.memory_efficient_attention_forward_requires_grad(
        query,
        key,
        value,
        op=op,
        attn_bias=attn_bias,
    )
    # BMHK -> BHMK so the matmul below contracts over the embedding dim.
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    attn = (query.float() / k**0.5) @ key.float().transpose(-2, -1)
    if attn_bias is not None:
        if isinstance(
            attn_bias,
            (fmha.attn_bias.AttentionBias, fmha.attn_bias.AttentionBiasSubTensor),
        ):
            # Dense [1, 1, Mq, Mkv] bias, broadcast over batch and heads.
            bias_shape = (1, 1, query.shape[2], key.shape[2])
            tensor_bias = attn_bias.materialize(
                bias_shape,
                device=query.device,
                dtype=torch.float32,
            )
        else:
            assert type(attn_bias) is torch.Tensor
            tensor_bias = attn_bias
        attn = attn + tensor_bias.float()
    ref_lse = attn.logsumexp(-1)
    if isinstance(
        attn_bias,
        (
            fmha.attn_bias.BlockDiagonalMask,
            fmha.attn_bias.BlockDiagonalGappyKeysMask,
            fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
            fmha.attn_bias.BlockDiagonalPaddedKeysMask,
        ),
    ) and issubclass(op, (fmha.flash.FwOp, fmha.cutlass.FwOp)):
        # Sometimes LSE is returned in padded format, i.e. (B, H, MAX_LEN) instead of (H, TOTAL_LEN).
        # Unpad to compare with the reference.
        # This is the case for Flash Attention when UNPADDED_LSE=False and for CUTLASS.
        if op.UNPADDED_LSE:
            lse = lse.unsqueeze(0)
        else:
            lse = _block_diag_reshape_lse(lse, attn_bias.q_seqinfo)
    if op is fmha.cutlass.FwOp:
        # CUTLASS kernel pads the last dimension of LSE to 32
        lse = lse[:, :, : ref_lse.shape[2]]
    assert_allclose(lse, ref_lse, atol=2e-4)
def test_logsumexp_mqa(op):
    """LSE must be right when keys/values are shared across heads via a
    stride-0 ``expand`` (multi-query attention layout)."""
    if not op.is_available():
        pytest.skip("not available")
    min_capability = op.CUDA_MINIMUM_COMPUTE_CAPABILITY
    if min_capability > compute_capability:
        pytest.skip(
            f"requires device with capability >= {min_capability} "
            f"but your GPU has capability {compute_capability} (too old)"
        )

    dtype = torch.float16
    s = 3
    query = torch.randn([1, 1, 32, 128], dtype=dtype, device="cuda") * s
    # One physical head for K and for V, broadcast to 32 heads (stride 0).
    key, value = (
        (torch.randn([1, 16, 1, 128], dtype=dtype, device="cuda") * s).expand(
            -1, -1, 32, -1
        )
        for _ in range(2)
    )
    assert key.stride(2) == 0

    _, lse = xformers.ops.memory_efficient_attention_forward_requires_grad(
        query,
        key,
        value,
        op=op,
    )
    # Drop the batch dim and move heads first: [M, H, K] -> [H, M, K].
    q_hmk, k_hmk = (t[0].transpose(0, 1) for t in (query, key))
    scores = (q_hmk.float() / q_hmk.shape[-1] ** 0.5) @ k_hmk.float().transpose(-2, -1)
    ref_lse = scores.logsumexp(-1)
    assert_allclose(lse[0, :, 0], ref_lse[:, 0], atol=2e-4)
def test_backward(
    opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
    grad_out_contiguous,
    fmt,
):
    """Check a backward op's gradients (and its paired forward) against the reference.

    For "BMHK" inputs where query and value have the same shape, q/k/v are
    views into one stacked ``qkv`` tensor, so the packed-gradient path is
    tested too. The bias gradient is also compared when it exists.
    """
    (
        op_bw,
        device,
        dtype,
        bias_type,
        batch_size,
        q_len,
        kv_len,
        h,
        k,
        kv,
    ) = opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
    attn_bias_requires_grad = (
        random.Random(q_len + kv_len * batch_size).randint(0, 1) > 0
    )
    query, key, value, attn_bias = create_tensors(
        *opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
        attn_bias_requires_grad=attn_bias_requires_grad,
        fmt=fmt,
    )

    # To understand why we do this, check the comment on the
    # `AttentionBwOpBase` class
    scale = None
    if op_bw.SUPPORTS_CUSTOM_SCALE and query.shape[-1] < 32:
        scale = (1 / 32) ** 0.5

    # CUTLASS and CK backward ops are paired with their own forward op. For the
    # rest, a supporting forward op is sampled. Sampling only happens when
    # needed, so a missing forward op cannot fail a CUTLASS/CK case.
    if op_bw == fmha.ck.BwOp:
        op_fw = fmha.ck.FwOp
    elif op_bw == fmha.cutlass.BwOp:
        op_fw = fmha.cutlass.FwOp
    else:
        op_fw = sample_random_supported_fw(
            fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias),
            seed=q_len * kv + kv_len * k,
        )

    if op_bw == fmha.ck.BwOp:
        if dtype == torch.bfloat16:
            pytest.skip(
                "CK Fmha backward for bfloat16 currently is not very accurate for some cases!"
            )
        if grad_out_contiguous is False:
            pytest.skip("CK Fmha does not support non-contiguous layout for grad_out!")
        if k % 2 != 0:
            pytest.skip(
                "CK Fmha currently requires the headdim size of query input be an even value!"
            )

    qkv = None
    if (
        fmt == "BMHK"
        and query.shape[3] == value.shape[3]
        and query.shape[1] == value.shape[1]
    ):
        qkv = torch.stack([query, key, value], 2)
        qkv.requires_grad_(True)
        # bm3hk -> 3 x bmhk
        query, key, value = xformers.ops.unbind(qkv, 2)
        assert not query.is_contiguous()

    query.requires_grad_(True)
    key.requires_grad_(True)
    value.requires_grad_(True)

    if not op_bw.supports(fmha.Inputs(query, key, value, attn_bias)):
        pytest.skip("inputs not supported")

    out = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias, scale=scale, op=(op_fw, op_bw)
    )

    grad_out = torch.randn_like(out)
    if grad_out_contiguous is False:
        # Stride-0 (broadcast) gradient to test non-contiguous grad_out.
        grad_out = torch.tensor([1.0], dtype=query.dtype, device=device)[
            None, None, :
        ].expand_as(out)

    out.backward(grad_out)

    if qkv is None and op_bw == fmha.cutlass.BwOp:
        assert query.stride() == query.grad.stride()

    # Grab the op's gradients and clear them so the reference backward below
    # accumulates into fresh ``.grad`` fields.
    if qkv is None:
        grads = [query.grad, key.grad, value.grad]
        query.grad = None
        key.grad = None
        value.grad = None
    else:
        grads = [qkv.grad]
        qkv.grad = None
    if attn_bias_requires_grad:
        attn_bias_grad = get_bias_grad(attn_bias, clear=True)
        if attn_bias_grad is not None:
            grads.append(attn_bias_grad)

    ref = ref_attention_for_test(query, key, value, attn_bias, scale=scale)
    ref.backward(grad_out)

    assert_allclose(
        out.float(),
        ref.float(),
        "fw pass",
        atol=op_fw.ERROR_ATOL[dtype],
        rtol=op_fw.ERROR_RTOL[dtype],
    )

    del out
    del grad_out
    del ref

    atol = op_bw.ERROR_ATOL[dtype]
    rtol = op_bw.ERROR_RTOL[dtype]

    if qkv is None:
        assert isinstance(query.grad, torch.Tensor)
        assert isinstance(key.grad, torch.Tensor)
        assert isinstance(value.grad, torch.Tensor)
        grads_ref = [query.grad, key.grad, value.grad]
        grads_name = ["query", "key", "value"]
    else:
        assert isinstance(qkv.grad, torch.Tensor)
        grads_ref = [qkv.grad]
        grads_name = ["qkv"]

    if attn_bias_requires_grad:
        attn_bias_grad = get_bias_grad(attn_bias)
        if attn_bias_grad is not None:
            grads_ref.append(attn_bias_grad)
            grads_name.append("bias")

    del query
    del key
    del value
    del qkv

    assert len(grads_ref) == len(
        grads
    ), "Wrong number of gradients (maybe bias grad didn't backprop?)"
    for name, calc_grad, ref_grad in zip(grads_name, grads, grads_ref):
        assert_allclose(
            calc_grad,
            ref_grad,
            msg=f"{op_fw.NAME}+{op_bw.NAME}:{name}",
            atol=atol,
            rtol=rtol,
        )
| def _vec_binom_test(x, n, p): | |
| """ | |
| vectorized implementation of scipy.stats.binom_test | |
| this makes our tests much faster | |
| reference: https://github.com/scipy/scipy/blob/v1.8.0/scipy/stats/_morestats.py#L2609-L2702 | |
| """ | |
| import numpy as np | |
| from scipy.stats import distributions | |
| x = np.atleast_1d(x) | |
| d = distributions.binom.pmf(x, n, p)[:, None] | |
| rerr = 1 + 1e-7 | |
| # x < p * n case | |
| i = np.arange(np.ceil(p * n), n + 1) | |
| y = np.sum(distributions.binom.pmf(i, n, p) <= d * rerr, axis=1) | |
| pval1 = distributions.binom.cdf(x, n, p) + distributions.binom.sf(n - y, n, p) | |
| # other case | |
| i = np.arange(np.floor(p * n) + 1) | |
| y = np.sum(distributions.binom.pmf(i, n, p) <= d * rerr, axis=1) | |
| pval2 = distributions.binom.cdf(y - 1, n, p) + distributions.binom.sf(x - 1, n, p) | |
| pval = np.where(x < p * n, pval1, pval2) | |
| pval = np.minimum(1.0, pval) | |
| return pval | |
def _get_drop_mask(op, batch_size, q_len, kv_len, p, device):
    """Regenerate the dropout keep-mask that ``op`` draws from the current RNG state.

    For the CUTLASS and CK forward ops the result is a float32 tensor of shape
    ``(batch_size, q_len, kv_len)`` (1.0 = kept). Other ops get whatever
    ``torch.ops.xformers._temp_dropout`` returns for a buffer of that shape.
    Callers reseed torch first so the draw matches the op's own.
    """
    if op == fmha.cutlass.FwOp:
        buf = torch.empty((batch_size, 1, q_len, kv_len), device=device)
        keep = torch.ops.xformers._cutlass_rand_uniform(p, buf) > p
    elif op == fmha.ck.FwOp:
        buf = torch.empty((batch_size, 1, q_len, kv_len), device=device)
        # rand_uniform is an int8_t tensor
        keep = torch.ops.xformers._ck_rand_uniform(p, buf) <= int((1.0 - p) * 255.0)
    else:
        buf = torch.empty((batch_size, q_len, kv_len), device=device)
        return torch.ops.xformers._temp_dropout(buf, p)
    return keep.to(torch.float32).reshape(batch_size, q_len, kv_len)
def test_dropout(op, q_len, kv_len, batch_size, k_len, p, seed, attn_bias):
    """Check dropout attention for reproducibility, correctness and statistics.

    1. Two runs with the same seed must match.
    2. The output must match the reference when the op's dropout mask is
       regenerated (``_get_drop_mask``) from the same seed.
    3. Over 1000 regenerated masks, the kept fraction (overall and per
       element) must be consistent with ``1 - p`` under a binomial test.
    """
    device = "cuda"
    scale = 3
    dtype = torch.float
    if torch.version.hip and op == fmha.ck.FwOp:
        dtype = torch.float16
    query = torch.randn((batch_size, q_len, k_len), device=device, dtype=dtype) * scale
    key = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale
    value = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale

    inputs_for_support_check = fmha.Inputs(query, key, value, attn_bias, p, None)
    if not op.supports(inputs_for_support_check):
        del query, key, value, attn_bias
        pytest.skip(f"{op.NAME}: unsupported input")

    torch.manual_seed(seed)
    out = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias, p, op=(op, None)
    )

    torch.manual_seed(seed)
    out2 = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias, p, op=(op, None)
    )

    assert_allclose(out, out2, "dropout reproducibility")

    torch.manual_seed(seed)
    mask = _get_drop_mask(op, batch_size, q_len, kv_len, p, device)
    ref = ref_attention_for_test(query, key, value, attn_bias, mask, p)

    # The diagnostic must be passed as ``msg``. A trailing ``, f"..."`` after
    # the call only builds a discarded tuple and never reaches the assertion.
    max_diff = (out.float() - ref.float()).abs().max()
    if dtype is torch.float:
        assert_allclose(out, ref, msg=f"max abs diff: {max_diff}", atol=2e-4)
    else:
        assert_allclose(out.float(), ref, msg=f"max abs diff: {max_diff}", atol=2.8e-2)

    num_trials = 1000
    p_val_tol = 1e-6
    keep_prob = 1 - p
    masks = torch.stack(
        [
            _get_drop_mask(op, batch_size, q_len, kv_len, p, device).clone().cpu()
            for _ in range(num_trials)
        ],
        dim=0,
    )
    # Overall keep rate across all trials and elements.
    p_value = binomtest(int(masks.sum()), masks.numel(), p=keep_prob).pvalue
    assert p_value > p_val_tol, p_value
    # Per-element keep counts across trials.
    masks = masks.sum(0).flatten()
    p_values = _vec_binom_test(masks, num_trials, p=keep_prob)
    assert all(p_values > p_val_tol)
def _test_dropout_backward(q_len, kv_len, batch_size, k, p, op, dtype):
    """Compare dropout-attention gradients of ``op`` with the reference.

    The reference replays the op's dropout mask by reseeding torch with the
    same seed before calling ``_get_drop_mask``.
    """
    if dtype is torch.bfloat16 and compute_capability < (8, 0):
        pytest.skip("bf16 requires Sm80")
    if not op.is_available():
        pytest.skip()

    scale = 3
    device = "cuda"
    query, key, value = (
        torch.randn((batch_size, seqlen, k), device=device, dtype=dtype) * scale
        for seqlen in (q_len, kv_len, kv_len)
    )
    for leaf in (query, key, value):
        leaf.requires_grad_(True)

    grad_out = torch.ones_like(query)
    assert op.supports(fmha.Inputs(query=query, key=key, value=value, p=p))

    seed = 42
    torch.manual_seed(seed)
    out = xformers.ops.memory_efficient_attention(query, key, value, p=p, op=(op, None))
    out.backward(grad_out)

    # Keep the op's gradients and clear them before the reference backward.
    grad_q, grad_k, grad_v = query.grad, key.grad, value.grad
    query.grad = key.grad = value.grad = None

    torch.manual_seed(seed)
    mask = _get_drop_mask(op, batch_size, q_len, kv_len, p, device)
    ref = ref_attention_for_test(query, key, value, None, mask, p)
    ref.backward(grad_out)

    atol = fmha.AttentionBwOpBase.ERROR_ATOL[dtype]
    rtol = fmha.AttentionBwOpBase.ERROR_RTOL[dtype]
    assert_allclose(grad_v, value.grad, "grad_v", atol=atol, rtol=rtol)

    # TODO: Investigate why precision is worse
    if dtype in [torch.float16, torch.bfloat16]:
        atol, rtol = atol * 2 + 0.15, rtol * 2
    for label, computed, expected in (
        ("grad_q", grad_q, query.grad),
        ("grad_k", grad_k, key.grad),
    ):
        assert_allclose(computed, expected, label, atol=atol, rtol=rtol)
def test_dropout_backward_small_k(q_len, kv_len, batch_size, k, p):
    """Dropout backward check for the small_k forward op (float32 only)."""
    _test_dropout_backward(
        q_len,
        kv_len,
        batch_size,
        k,
        p,
        dtype=torch.float32,
        op=fmha.small_k.FwOp,
    )
def test_dropout_backward_cutlass(dt, q_len, kv_len, batch_size, k, p):
    """Dropout backward check for the CUTLASS forward op.

    ``dt`` selects the dtype: "f16", "bf16" or "f32".
    """
    dtype = {"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}[dt]
    _test_dropout_backward(
        q_len, kv_len, batch_size, k, p, op=fmha.cutlass.FwOp, dtype=dtype
    )
def test_dropout_backward_ck(dt, q_len, kv_len, batch_size, k, p):
    """Dropout backward check for the CK forward op.

    ``dt`` selects the dtype: "f16", "bf16" or "f32".
    """
    dtype = {"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}[dt]
    _test_dropout_backward(
        q_len, kv_len, batch_size, k, p, op=fmha.ck.FwOp, dtype=dtype
    )
def test_memory_efficient_attention_full_block_masked(q_len, kv_len, batch_size, k_len):
    """Check small_k forward and backward when most key blocks of a row are masked.

    The bias keeps only the first 4 columns of the first two 32-column blocks
    (value 0); every other column is ``-inf``. That includes the whole third
    block. Both passes are pinned to ``small_k`` and compared against the
    reference implementation with small_k tolerances.
    """
    device = "cuda"
    op_fw = fmha.small_k.FwOp
    op_bw = fmha.small_k.BwOp

    scale = 3
    query = torch.randn((batch_size, q_len, k_len), device=device) * scale
    key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
    value = torch.randn((batch_size, kv_len, k_len), device=device) * scale

    # in this case, most of the blocks in a row get masked
    # The flattened bias has 3 * 32 = 96 columns, so this presumably expects
    # kv_len == 96 (TODO confirm against the parametrization).
    attn_bias = torch.full((3, 32), float("-inf"), device=device)
    attn_bias[:2, :4] = 0
    attn_bias = attn_bias.flatten()[None, None, :].expand(1, q_len, -1)

    out = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias, op=(op_fw, op_bw)
    )
    ref = ref_attention_for_test(query, key, value, attn_bias)
    assert_allclose(
        out, ref, atol=op_fw.ERROR_ATOL[query.dtype], rtol=op_fw.ERROR_RTOL[query.dtype]
    )

    query.requires_grad_(True)
    key.requires_grad_(True)
    value.requires_grad_(True)
    grad_out = torch.ones_like(query)

    # Pin the ops for the backward run too. Without `op=`, the dispatcher could
    # choose another backward kernel while we still compare with small_k's
    # tolerances below.
    out = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias, op=(op_fw, op_bw)
    )
    out.backward(grad_out)
    grad_q = query.grad
    grad_k = key.grad
    grad_v = value.grad
    # Reset so the reference backward starts from clean accumulators.
    query.grad = None
    key.grad = None
    value.grad = None

    ref = ref_attention_for_test(query, key, value, attn_bias)
    ref.backward(grad_out)

    atol = op_bw.ERROR_ATOL[query.dtype]
    rtol = op_bw.ERROR_RTOL[query.dtype]
    assert_allclose(grad_q, query.grad, "grad_q", atol=atol, rtol=rtol)
    assert_allclose(grad_k, key.grad, "grad_k", atol=atol, rtol=rtol)
    assert_allclose(grad_v, value.grad, "grad_v", atol=atol, rtol=rtol)
def test_lowlevel_api_shapes(opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt):
    """Check that the low-level forward/backward entry points return correctly shaped results."""
    query, key, value, attn_bias = create_tensors(
        *opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt=fmt
    )
    grad_out = torch.ones_like(query)
    inputs = (query, key, value)
    for tensor in inputs:
        tensor.requires_grad_(True)

    out, lse = xformers.ops.memory_efficient_attention_forward_requires_grad(
        query, key, value, attn_bias
    )
    assert out.ndim == query.ndim

    dq, dk, dv = xformers.ops.memory_efficient_attention_backward(
        grad_out, out, lse, query, key, value, attn_bias
    )
    for grad, tensor in zip((dq, dk, dv), inputs):
        assert grad.shape == tensor.shape
def test_cuda_streams(
    opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
):
    """Check that attention runs on the caller's current CUDA stream.

    A low-priority stream sleeps and then doubles ``query``. A high-priority
    stream waits on it and launches attention. If the kernel were scheduled on
    the main stream instead, it could read ``query`` before the doubling and
    disagree with the reference, which is computed after a full synchronize.
    """
    (
        op,
        device,
        dtype,
        bias_type,
        batch_size,
        q_len,
        kv_len,
        h,
        k,
        kv,
    ) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
    if device != "cuda":
        pytest.skip("Not CUDA")
    # Run without an attention bias: the attention call below passes no bias.
    bias_type = None
    opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = [
        op,
        device,
        dtype,
        bias_type,
        batch_size,
        q_len,
        kv_len,
        h,
        k,
        kv,
    ]
    s_hipri = torch.cuda.Stream(priority=-1)
    s_lopri = torch.cuda.Stream(priority=0)
    query, key, value, attn_bias = create_tensors(
        *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt="BMHK"
    )
    torch.cuda.synchronize()
    with torch.cuda.stream(s_lopri):
        torch.cuda._sleep(100_000_000)  # wait 100m cycles
        query *= 2
    # Make the high-priority stream depend on the sleep + in-place update.
    s_hipri.wait_stream(s_lopri)
    with torch.cuda.stream(s_hipri):
        # If the kernel is scheduled in the main stream
        # `query * 2` has not been executed yet
        out = xformers.ops.memory_efficient_attention(query, key, value, op=(op, None))
    # Test that `s_lopri` is still sleeping
    # and that `query *= 2` has not been executed yet
    query2_main_stream = query * 2
    torch.cuda.synchronize()
    # TODO: Figure out why this is failing sometimes
    # The sleep timer seems to be high enough already ...
    # assert torch.allclose(query2_main_stream, query), "Need to increase sleep time"
    del query2_main_stream

    # After synchronize(), `query` holds the doubled values on every stream.
    ref = ref_attention_for_test(query, key, value)
    assert out.shape == ref.shape, out.shape
    assert_allclose(
        out.float(),
        ref.float(),
        atol=op.ERROR_ATOL[dtype],
        rtol=op.ERROR_RTOL.get(dtype, 1e-5),
    )
def test_custom_scale(opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv):
    """Compare forward output and q/k/v gradients with an explicit ``scale`` against the reference.

    The backward op comes from the parametrization. A compatible forward op is
    sampled with a shape-derived seed. The test is skipped on non-CUDA devices
    and when either op reports the inputs as unsupported.
    """
    p = 0.0
    scale = 0.1

    (
        op_bw,
        device,
        dtype,
        _,
        B,
        q_len,
        kv_len,
        H,
        k,
        Kv,
    ) = opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
    torch.manual_seed(q_len + kv_len + k)
    if device != "cuda":
        pytest.skip("Not CUDA")

    query, key, value, attn_bias = create_tensors(
        *opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt="BMK"
    )
    inputs = fmha.Inputs(
        query=query, key=key, value=value, attn_bias=attn_bias, scale=scale
    )
    # Seed derived from the shape, presumably so the sampled op is reproducible.
    op_fw = sample_random_supported_fw(inputs, seed=q_len * k + kv_len * k)
    # BMK layout: the leading dim is B * H (heads presumably folded into batch).
    grad_out = query.new_ones(B * H, q_len, Kv)
    query.requires_grad_(True)
    key.requires_grad_(True)
    value.requires_grad_(True)

    reasons = op_fw.not_supported_reasons(inputs)
    if reasons:
        pytest.skip(f"{op_fw.NAME}: unsupported ({'/'.join(reasons)})")
    reasons = op_bw.not_supported_reasons(inputs)
    if reasons:
        pytest.skip(f"{op_bw.NAME}: unsupported ({'/'.join(reasons)})")

    # NOTE: we still need to scale the inputs to not blowup
    # the pre-softmax values (numerical stability)
    s = k**-0.5
    out = xformers.ops.memory_efficient_attention(
        query * s, key, value, attn_bias, p, scale, op=(op_fw, op_bw)
    )
    out.backward(grad_out)
    grad_q, grad_k, grad_v = query.grad, key.grad, value.grad
    # Clear accumulated grads before running the reference backward.
    query.grad = key.grad = value.grad = None

    ref = ref_attention_for_test(query * s, key, value, attn_bias, None, p, scale)
    ref.backward(grad_out)
    ref_grad_q, ref_grad_k, ref_grad_v = query.grad, key.grad, value.grad
    query.grad = key.grad = value.grad = None

    # Forward output uses the forward op's tolerances ...
    atol = op_fw.ERROR_ATOL[dtype]
    rtol = op_fw.ERROR_RTOL[dtype]
    assert_allclose(out.float(), ref.float(), "out", atol=atol, rtol=rtol)
    # ... and gradients use the backward op's tolerances.
    atol = op_bw.ERROR_ATOL[dtype]
    rtol = op_bw.ERROR_RTOL[dtype]
    assert_allclose(grad_q, ref_grad_q, "grad_q", atol=atol, rtol=rtol)
    assert_allclose(grad_k, ref_grad_k, "grad_k", atol=atol, rtol=rtol)
    assert_allclose(grad_v, ref_grad_v, "grad_v", atol=atol, rtol=rtol)
def apply_attention(query, key, value, attn_bias, op_fw, proj):
    """Run attention with a pinned forward op, then apply ``proj`` to the result."""
    attended = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias=attn_bias, op=(op_fw, None)
    )
    return proj(attended)
def test_grad_checkpointing(
    opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
    use_reentrant,
):
    """Smoke-test attention under ``torch.utils.checkpoint``.

    Five checkpointed attention + projection steps are chained and then
    backpropagated. No numerical comparison is made; the test only checks
    that forward/backward through checkpointing runs.
    """
    fmt = "BMHK"
    (
        op,
        device,
        dtype,
        bias_type,
        batch_size,
        q_len,
        kv_len,
        h,
        k,
        kv,
    ) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
    if op is fmha.triton_splitk.FwOp:
        pytest.skip("Triton Flash Decoding doesn't support backward pass yet")
    if op is fmha.ck.FwOp:
        pytest.skip("ck-tiled FMHA doesn't supported backward pass yet")
    # This test runs without an attention bias.
    bias_type = None
    opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = (
        op,
        device,
        dtype,
        bias_type,
        batch_size,
        q_len,
        kv_len,
        h,
        k,
        kv,
    )
    query, key, value, attn_bias = create_tensors(
        *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
        fmt=fmt,
    )
    qkv = None

    if (
        fmt == "BMHK"
        and query.shape[3] == value.shape[3]
        and query.shape[1] == value.shape[1]
    ):
        # Pack q/k/v into one tensor and unbind it, so the inputs are
        # non-contiguous slices of a single packed tensor (asserted below).
        qkv = torch.stack([query, key, value], 2)
        qkv.requires_grad_(True)
        # bm3hk -> 3 x bmhk
        query, key, value = xformers.ops.unbind(qkv, 2)
        assert not query.is_contiguous()

    query.requires_grad_(True)
    key.requires_grad_(True)
    value.requires_grad_(True)

    # Maps the value dim (kv) back to the query dim (k), so each step's output
    # can be fed back in as the next step's query.
    proj = torch.nn.Linear(kv, k, device=device, dtype=dtype)

    x = query
    for _ in range(5):
        x = checkpoint(
            apply_attention,
            x,
            key,
            value,
            attn_bias,
            op,
            proj,
            use_reentrant=use_reentrant,
        )
    x.mean().backward()
# Every registered forward op except the legacy small_k one.
ALL_FW_OPS_NO_SMALLK = list(
    filter(lambda fw_op: fw_op is not fmha.small_k.FwOp, ALL_FW_OPS)
)
def test_unsupported_cpu(op: Type[fmha.AttentionFwOpBase]):
    """Forcing ``op`` on CPU tensors must raise ValueError."""
    cpu_tensor = torch.empty([1, 1, 1, 32])
    with pytest.raises(ValueError):
        fmha.memory_efficient_attention(
            cpu_tensor, cpu_tensor, cpu_tensor, op=(op, None)
        )
def test_unsupported_stride_lastdim(op: Type[fmha.AttentionFwOpBase]):
    """A last dim with stride != 1 is either accepted, or rejected with
    ValueError and then accepted once the input is made contiguous."""
    base = torch.empty([1, 1, 32, 4], device="cuda", dtype=torch.float16)
    # Permuting moves the size-32 axis to the end, where its stride is 4.
    strided = base.permute(0, 3, 1, 2)
    try:
        fmha.memory_efficient_attention(strided, strided, strided, op=(op, None))
    except ValueError as err:
        if "Only work on pre-MLIR triton for now" in str(err):
            pytest.skip("Only work on pre-MLIR triton for now")
        dense = strided.contiguous()
        fmha.memory_efficient_attention(dense, dense, dense, op=(op, None))
def test_unsupported_stride_alignment(op: Type[fmha.AttentionFwOpBase]):
    """A misaligned row stride is either accepted, or rejected with
    ValueError and then accepted once the input is made contiguous."""
    # Slicing 33 -> 32 keeps the last dim dense but leaves a row stride of 33.
    padded = torch.empty([1, 2, 1, 33], device="cuda", dtype=torch.float16)
    misaligned = padded[:, :, :, :32]
    try:
        fmha.memory_efficient_attention(
            misaligned, misaligned, misaligned, op=(op, None)
        )
    except ValueError as err:
        if "Only work on pre-MLIR triton for now" in str(err):
            pytest.skip("Only work on pre-MLIR triton for now")
        dense = misaligned.contiguous()
        fmha.memory_efficient_attention(dense, dense, dense, op=(op, None))
def test_unsupported_dropout_combine_flash_cutlass() -> None:
    """Mixing Flash and CUTLASS forward/backward ops with dropout raises ValueError."""
    q = torch.empty(
        [1, 4, 1, 16], device="cuda", dtype=torch.float16, requires_grad=True
    )
    mixed_pairs = (
        (fmha.cutlass.FwOp, fmha.flash.BwOp),
        (fmha.flash.FwOp, fmha.cutlass.BwOp),
    )
    for fw_op, bw_op in mixed_pairs:
        with pytest.raises(ValueError):
            out = fmha.memory_efficient_attention(q, q, q, p=0.1, op=(fw_op, bw_op))
            out.backward(out)
def test_attn_bias_causal() -> None:
    """LowerTriangularMask materializes as a causal mask and composes with a tensor bias."""
    neg_inf = -math.inf
    expected_mask = torch.tensor([[0, neg_inf], [0, 0], [0, 0]])
    bias_values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    shape = expected_mask.shape

    causal = fmha.attn_bias.LowerTriangularMask()
    assert_allclose(causal.materialize(shape), expected_mask, "causal")

    combined = causal.add_bias(bias_values)
    assert_allclose(
        combined.materialize(shape),
        bias_values + expected_mask,
        "causal+tensor_bias",
    )
def test_attn_bias_torch_tensor() -> None:
    """LowerTriangularMaskWithTensorBias materializes as causal mask + tensor bias."""
    bias_values = torch.tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
    biased_causal = fmha.attn_bias.LowerTriangularMaskWithTensorBias(bias_values)
    neg_inf = -math.inf
    expected_causal = torch.tensor([[0, neg_inf, neg_inf], [0, 0, neg_inf]])
    materialized = biased_causal.materialize((2, 3))
    assert_allclose(
        materialized, expected_causal + bias_values, "tensor_bias+causal"
    )
def test_attn_bias_blockdiag() -> None:
    """BlockDiagonalMask.from_tensor_list makes one diagonal block per input and splits back."""
    seq_lens = [3, 2, 5]
    queries = [torch.randn([1, n, 1, 8]) for n in seq_lens]
    attn_bias, q = fmha.BlockDiagonalMask.from_tensor_list(queries)

    # Verify mask: only the diagonal blocks are unmasked.
    as_tensor = attn_bias.materialize((10, 10))
    unmasked = int((as_tensor != -math.inf).sum().item())
    assert unmasked == sum(n * n for n in seq_lens)
    start = 0
    for index, n in enumerate(seq_lens):
        block = as_tensor[start : start + n, start : start + n]
        assert_allclose(block, torch.zeros([n, n]), f"batch{index}")
        start += n

    # Verify we can split it back
    queries2 = attn_bias.split(q)
    assert len(queries) == len(queries2)
    for original, recovered in zip(queries, queries2):
        assert_allclose(original, recovered)
def test_attn_bias_blockdiag_batched() -> None:
    """Inputs with batch > 1 contribute one diagonal block per batch element."""
    queries = [torch.randn([b, n, 1, 8]) for b, n in ((1, 3), (3, 2), (1, 5))]
    attn_bias, q = fmha.BlockDiagonalMask.from_tensor_list(queries)

    # Verify mask: (start, size, label) of every expected diagonal block.
    expected_blocks = [
        (0, 3, "batch0"),
        (3, 2, "batch1.0"),
        (5, 2, "batch1.1"),
        (7, 2, "batch1.2"),
        (9, 5, "batch2"),
    ]
    as_tensor = attn_bias.materialize((14, 14))
    unmasked = int((as_tensor != -math.inf).sum().item())
    assert unmasked == sum(size * size for _start, size, _label in expected_blocks)
    for start, size, label in expected_blocks:
        block = as_tensor[start : start + size, start : start + size]
        assert_allclose(block, torch.zeros([size, size]), label)

    # Verify we can split it back
    queries2 = attn_bias.split(q)
    assert len(queries) == len(queries2)
    for original, recovered in zip(queries, queries2):
        assert_allclose(original, recovered)
def test_attn_bias_blockdiag_crossattn_causal() -> None:
    """Cross-attention BlockDiagonalMask (q/kv lengths differ), its causal form, and splitting."""
    # Q / KV have different seqlen
    list_q = [torch.randn(shape) for shape in ([1, 3, 1, 8], [2, 1, 1, 8])]
    list_k = [torch.randn(shape) for shape in ([1, 2, 1, 8], [2, 3, 1, 8])]
    attn_bias, q, k, _ = fmha.attn_bias.BlockDiagonalMask.from_tensor_lists_qkv(
        list_q, list_k
    )
    full_shape = (q.shape[1], k.shape[1])

    # Verify mask
    as_tensor = attn_bias.materialize(full_shape)
    assert int((as_tensor != -math.inf).sum().item()) == 3 * 2 + 2 * 3 * 1
    expected_blocks = (
        (slice(0, 3), slice(0, 2), [3, 2], "batch0"),
        (slice(3, 4), slice(2, 5), [1, 3], "batch1.0"),
        (slice(4, None), slice(5, None), [1, 3], "batch1.1"),
    )
    for rows, cols, zeros_shape, label in expected_blocks:
        assert_allclose(as_tensor[rows, cols], torch.zeros(zeros_shape), label)

    # Also test causal version
    causal_tensor = attn_bias.make_causal().materialize(full_shape)
    assert_allclose(
        causal_tensor[3:4, 2:5],
        fmha.attn_bias.LowerTriangularMask().materialize((1, 3)),
        "batch1.0[causal]",
    )

    # Verify we can split it back
    list_q2 = attn_bias.split_queries(q)
    assert len(list_q) == len(list_q2)
    for expected, recovered in zip(list_q, list_q2):
        assert_allclose(expected, recovered)
    with pytest.raises(ValueError):
        attn_bias.split_queries(k)
    list_k2 = attn_bias.split_kv(k)
    assert len(list_k) == len(list_k2)
    for expected, recovered in zip(list_k, list_k2):
        assert_allclose(expected, recovered)
def test_attn_bias_blockdiag_crossattn_causal_with_prefix_qk_cond() -> None:
    """make_causal_from_bottomright raises ValueError for a block with 3 queries but 2 keys."""
    query_block = torch.randn([1, 3, 1, 8])
    key_block = torch.randn([1, 2, 1, 8])
    attn_bias, q, k, _ = fmha.attn_bias.BlockDiagonalMask.from_tensor_lists_qkv(
        [query_block], [key_block]
    )
    with pytest.raises(ValueError):
        attn_bias.make_causal_from_bottomright()
def test_attn_bias_blockdiag_crossattn_causal_with_prefix() -> None:
    """Bottom-right-aligned causal masks when each block has more keys than queries."""
    # Q / KV have different seqlen
    list_q = [torch.randn(shape) for shape in ([1, 2, 1, 8], [2, 2, 1, 8])]
    list_k = [torch.randn(shape) for shape in ([1, 2, 1, 8], [2, 5, 1, 8])]
    attn_bias, q, k, _ = fmha.attn_bias.BlockDiagonalMask.from_tensor_lists_qkv(
        list_q, list_k
    )
    causal = attn_bias.make_causal_from_bottomright()
    as_tensor = causal.materialize((q.shape[1], k.shape[1]))

    m = -math.inf
    expectations = [
        (
            (slice(0, 2), slice(0, 2)),
            torch.tensor([[0, m], [0, 0]], dtype=torch.float32),
            "batch1.1[causal_with_prefix]",
        ),
        (
            (slice(2, 4), slice(2, 7)),
            torch.tensor([[0, 0, 0, 0, m], [0, 0, 0, 0, 0]], dtype=torch.float32),
            "batch2.1[causal_with_prefix]",
        ),
        (
            (slice(4, 6), slice(7, 12)),
            torch.tensor([[0, 0, 0, 0, m], [0, 0, 0, 0, 0]], dtype=torch.float32),
            "batch2.2[causal_with_prefix]",
        ),
    ]
    for index, expected, label in expectations:
        assert_allclose(as_tensor[index], expected, label)
def test_attn_bias_padded() -> None:
    """Compare BlockDiagonalCausalWithOffsetPaddedKeysMask against a hand-built reference.

    K/V are stored as ``bsize`` segments of ``padding`` slots, of which only the
    first ``k_seqlen[i]`` are valid. Queries are packed: ``n_q_first`` queries
    for segment 0 and one query for each of the remaining segments.
    """
    bsize, n_heads, d, padding = 8, 3, 8, 32
    torch.manual_seed(0)

    # Q / KV have different seqlen
    k = torch.randn((bsize, padding, n_heads, d), device="cuda", dtype=torch.float16)
    k_seqlen = [5, 8, 7, 1, 9, 3, 12, 32]
    other = bsize - 1
    v = torch.randn((bsize, padding, n_heads, d), device="cuda", dtype=torch.float16)
    n_q_first = 4
    q = [
        torch.randn((1, n_q_first, n_heads, d), device="cuda", dtype=torch.float16),
        torch.randn((1, other, n_heads, d), device="cuda", dtype=torch.float16),
    ]
    # 4 + 7 = 11 packed queries.
    q_cat = torch.cat([x.view(1, -1, n_heads, d) for x in q], dim=1)
    q_seqlen = [n_q_first] + [1] * other

    attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
        q_seqlen=q_seqlen,
        kv_seqlen=k_seqlen,
        kv_padding=padding,
    )

    # Reference: flatten the padded K/V segments into one sequence of
    # bsize * padding = 256 positions.
    v = v.view(1, -1, n_heads, d)
    k = k.view(1, -1, n_heads, d)

    # (1, 3, 11, 8) @ (1, 3, 8, 256) -> (1, 3, 11, 256); unscaled dot products.
    scores = (q_cat.transpose(1, 2) @ k.transpose(1, 2).transpose(2, 3)).float()
    assert not scores.isnan().any()
    mask = torch.full_like(scores, -float("inf"))
    for i, (slen, qlen) in enumerate(zip(k_seqlen, q_seqlen)):
        kseq_start = i * padding
        qstart = sum(q_seqlen[:i])
        # triu keeps -inf where (col - row) >= 1 + slen - qlen and zeroes the
        # rest. This is a causal mask aligned to the bottom-right of the
        # segment's valid keys. Padding slots (>= slen) remain -inf.
        mask[:, :, qstart : qstart + qlen, kseq_start : kseq_start + slen] = torch.triu(
            mask[:, :, qstart : qstart + qlen, kseq_start : kseq_start + slen].float(),
            diagonal=1 + slen - qlen,
        ).float()
    scores += mask
    assert not scores.isnan().any()
    # softmax over the 256 key positions -> (1, 3, 11, 256)
    scores = torch.nn.functional.softmax(scores, -1).half()
    # (1, 3, 11, 256) @ (1, 3, 256, 8) -> (1, 3, 11, 8)
    output = scores @ v.transpose(1, 2)
    # back to BMHK: (1, 11, 3, 8)
    output = output.transpose(1, 2).contiguous()

    # scale=1.0 because the reference above does not scale the scores.
    fmha_output = fmha.memory_efficient_attention_forward(
        q_cat, k, v, attn_bias, scale=1.0
    )

    # assert torch.allclose(output, fmha_output)
    assert_allclose(
        output,
        fmha_output,
        atol=fmha.cutlass.FwOp.ERROR_ATOL[torch.float16],
        rtol=fmha.cutlass.FwOp.ERROR_RTOL[torch.float16],
    )
def test_attn_bias_to_copy() -> None:
    """Attention-bias objects support ``.cuda()`` and ``.to(dtype)`` like tensors."""

    def _check_copies(bias) -> None:
        # Starts on CPU; .cuda() moves it; .to(dtype) keeps the device.
        assert bias.device.type == "cpu", f"{bias.device}"
        on_gpu = bias.cuda()
        assert on_gpu.device.type == "cuda", f"{on_gpu.device}"
        as_half = bias.to(torch.float16)
        assert as_half.device.type == "cpu", f"{as_half.device}"
        assert as_half.dtype == torch.float16, f"{as_half.dtype}"

    _check_copies(fmha.attn_bias.LowerTriangularMask().to("cpu"))

    tensor_bias = torch.tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
    _check_copies(
        fmha.attn_bias.LowerTriangularMaskWithTensorBias(tensor_bias).to("cpu")
    )
| def _kv_heads_label(kv_heads: Optional[int]) -> str: | |
| if kv_heads is None: | |
| return "" | |
| if kv_heads == 1: | |
| return "mq" | |
| return f"gqa{kv_heads}" | |
def _major_version(version: str) -> int:
    """Return the major component of a dotted version string as an int.

    Only the text before the first ``.`` is parsed, so suffixes such as
    ``"3.0.0+git1234"`` are ignored.
    """
    return int(version.split(".", 1)[0])


def test_decoder(
    op,
    n_heads: int,
    kv_heads: Optional[int],
    padding: int,
    bsz: int,
    dtype: str,
    dequant: bool = False,
    num_queries: int = 1,
    d: int = 128,
) -> None:
    """Compare a decoding-style forward op against the reference implementation.

    Each of ``bsz`` sequences has ``num_queries`` queries and a random number
    of valid keys (between ``num_queries`` and ``padding``). The keys sit in a
    fixed slot of ``padding`` positions, described by
    ``BlockDiagonalCausalWithOffsetPaddedKeysMask``.

    Args:
        op: forward op under test.
        n_heads: number of query heads (per KV group when ``kv_heads > 1``).
        kv_heads: ``None`` for plain MHA, ``1`` for multiquery, ``> 1`` for GQA
            in the BMGHK layout.
        padding: key slots reserved per sequence.
        bsz: number of sequences.
        dtype: "f16", "bf16" or "f32".
        dequant: if True, K/V are int32 words packing eight 4-bit values, after
            ``op.NUM_GROUPS`` leading int32 columns. Those columns are viewed as
            float16 and set to 1.0, presumably per-group quantization params.
        num_queries: queries per sequence.
        d: head dimension.
    """
    # kv_heads = 1: multiquery
    # kv_heads = None: neither MQA nor GQA
    # kv_heads > 1: BMGHK
    if dtype == "bf16" and compute_capability < (8, 0):
        pytest.skip("BF16 is only supported on SM80+")
    if dequant:
        # Only the quantized path needs triton. Importing it lazily keeps
        # non-triton callers (e.g. the CK decoder test) independent of it.
        import triton

        # Compare the numeric major version. A string comparison gets cases
        # like "10.0" < "3.0." wrong.
        if _major_version(triton.__version__) < 3:
            pytest.skip("dequant needs triton updates")

    dtype_ = {"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}[dtype]
    torch.manual_seed(1)
    if kv_heads is not None and kv_heads > 1:
        k_shape: Tuple[int, ...] = (1, bsz * padding, kv_heads, n_heads, d)
        q_shape: Tuple[int, ...] = (
            1,
            bsz * num_queries,
            kv_heads,
            n_heads,
            d,
        )
    else:
        k_shape = (1, bsz * padding, n_heads, d)
        q_shape = (1, bsz * num_queries, n_heads, d)

    # TODO: support 2 kv heads etc.
    k = torch.randn(k_shape, dtype=dtype_, device="cuda")
    k_seqlen = torch.randint(num_queries, padding + 1, (bsz,)).tolist()
    v = torch.randn(k_shape, dtype=dtype_, device="cuda")
    q = torch.randn(q_shape, dtype=dtype_, device="cuda")

    if dequant:
        # Last dim: op.NUM_GROUPS header columns, then d // 8 packed int32 words.
        k_shape = k_shape[:-1] + (d // 8 + op.NUM_GROUPS,)
        k = torch.zeros(k_shape, dtype=torch.int32, device="cuda")
        k.random_()
        k[..., : op.NUM_GROUPS].view(torch.float16).fill_(1.0)
        v = torch.zeros(k_shape, dtype=torch.int32, device="cuda")
        v.random_()
        v[..., : op.NUM_GROUPS].view(torch.float16).fill_(1.0)

    if kv_heads is not None:
        # Every head (within each group, for GQA) shares a single K/V head.
        k = k[..., :1, :].expand(k_shape)
        v = v[..., :1, :].expand(k_shape)

    if skip_reasons := op.not_supported_reasons(fmha.Inputs(q, k, v)):
        pytest.skip("; ".join(skip_reasons))

    attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
        q_seqlen=[num_queries] * bsz,
        kv_seqlen=k_seqlen,
        kv_padding=padding,
    )

    decoder_output = fmha.memory_efficient_attention_forward(
        q,
        k,
        v,
        attn_bias,
        op=op,
    )

    def dequant_cache(x):
        # Drop the header columns, split each int32 into 8 nibbles (0..15),
        # and add 1.0 to build the reference K/V values.
        x = x[..., op.NUM_GROUPS :, None].expand(k_shape[:-1] + (d // 8, 8))
        x = x // (2 ** (4 * torch.arange(8, device="cuda")))
        x = (x % 16).flatten(start_dim=-2)
        return x.to(dtype_) + 1.0

    if dequant:
        k = dequant_cache(k)
        v = dequant_cache(v)

    ref_output = ref_attention_for_test(q, k, v, attn_bias)

    assert_allclose(
        decoder_output.to(ref_output.dtype),
        ref_output,
        atol=op.ERROR_ATOL[dtype_] * 4,
        rtol=op.ERROR_RTOL[dtype_],
    )
def test_triton_splitk_decoder(
    op,
    dequant: bool,
    kv_heads: Optional[int],
    n_heads: int,
    padding: int,
    bsz: int,
    dtype: str,
) -> None:
    """Run the shared decoder check for the Triton split-K op, optionally quantized."""
    # We omit dequant with f16: it needs a very high tol
    decoder_kwargs = dict(
        kv_heads=kv_heads,
        n_heads=n_heads,
        padding=padding,
        bsz=bsz,
        dtype=dtype,
        dequant=dequant,
    )
    test_decoder(op, **decoder_kwargs)
def test_ck_splitk_decoder(
    op,
    kv_heads: Optional[int],
    n_heads: int,
    padding: int,
    bsz: int,
    dtype: str,
    d: int,
) -> None:
    """Run the shared decoder check for the CK split-K op with head dim ``d``."""
    # no quantized impl compared to cuda
    decoder_kwargs = dict(
        kv_heads=kv_heads,
        n_heads=n_heads,
        padding=padding,
        bsz=bsz,
        dtype=dtype,
        d=d,
    )
    test_decoder(op, **decoder_kwargs)
# n_heads=1 => it's ambiguous whether it counts as multiquery
def test_triton_splitk_decoder_manyqueries(
    op,
    multiquery: bool,
    n_heads: int,
    padding: int,
    bsz: int,
    dtype: str,
    num_queries: int,
) -> None:
    """Decoder check with several queries per sequence (no quantization)."""
    test_decoder(
        op,
        kv_heads=1 if multiquery else None,
        n_heads=n_heads,
        padding=padding,
        bsz=bsz,
        dtype=dtype,
        num_queries=num_queries,
        dequant=False,
    )
def test_attn_bias_from_seqlens() -> None:
    """BlockDiagonalMask.from_seqlens splits a packed tensor into per-sequence chunks."""
    seqlens = [3, 5, 1]
    bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(seqlens)
    chunks = bias.split(torch.randn([1, sum(seqlens), 16]))
    assert len(chunks) == len(seqlens)
    assert tuple(chunks[0].shape) == (1, seqlens[0], 16)
def test_attn_bias_blockdiag_doc() -> None:
    """IMPORTANT:
    This is the example in the doc for `BlockDiagonalMask`.
    If this example needs to be updated, please also update the doc

    The example packs three variable-length sequences, runs attention with a
    block-diagonal mask, and splits the output back per sequence.
    """
    # Local imports are kept, presumably so the snippet reads self-contained
    # exactly as in the documentation.
    import torch

    from xformers.ops import fmha

    if torch.version.hip:
        pytest.skip("backward pass/gradience is not yet supported by ck-tiled fmha!")

    K = 16
    dtype = torch.float16
    device = "cuda"
    list_x = [
        torch.randn([1, 3, 1, K], dtype=dtype, device=device),
        torch.randn([1, 6, 1, K], dtype=dtype, device=device),
        torch.randn([1, 2, 1, K], dtype=dtype, device=device),
    ]
    attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x)

    linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype)  # type: ignore

    # Project to 3*K features, then split them into q, k, v along a new axis.
    q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2)
    out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
    list_out = attn_bias.split(out)
    assert tuple(list_out[0].shape) == (1, 3, 1, K)
class TestAttnBias:
    """Validation of dense tensor attention biases: dtype mismatch, alignment, strides."""

    # `staticmethod` is required: the tests call `self.create_tensors(dtype)`
    # and `self.pad_bias(bias)`. Without it, the instance binds to `dtype` /
    # `bias` and every call is wrong.
    @staticmethod
    def create_tensors(
        dtype,
        B: int = 2,
        Mq: int = 32,
        Mkv: int = 32,
        H: int = 3,
        K: int = 16,
        Kv: int = 16,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return random CUDA ``(q, k, v, bias)`` tensors scaled by 3.

        Shapes: q ``[B, Mq, H, K]``, k ``[B, Mkv, H, K]``, v ``[B, Mkv, H, Kv]``,
        bias ``[B, H, Mq, Mkv]``.
        """
        return (
            torch.randn([B, Mq, H, K], device="cuda", dtype=dtype) * 3,
            torch.randn([B, Mkv, H, K], device="cuda", dtype=dtype) * 3,
            torch.randn([B, Mkv, H, Kv], device="cuda", dtype=dtype) * 3,
            torch.randn([B, H, Mq, Mkv], device="cuda", dtype=dtype) * 3,
        )

    @staticmethod
    def pad_bias(bias: torch.Tensor) -> torch.Tensor:
        """Return ``bias`` with the same values but a row stride that is a multiple of 16.

        The last dim is zero-padded up to a multiple of 16 and then sliced back,
        so only the storage layout changes. The slice indexes four dims, so a
        4-D bias is assumed.
        """
        align_to = 16
        if (bias.shape[-1] % align_to) == 0:
            return bias
        pad_count = align_to - (bias.shape[-1] % align_to)
        return torch.nn.functional.pad(bias, [0, pad_count])[:, :, :, : bias.shape[-1]]

    def test_f16_biasf32(self) -> None:
        """f16 inputs accept an f16 bias but reject an f32 bias."""
        q, k, v, bias = self.create_tensors(torch.float16)
        fmha.memory_efficient_attention(q, k, v, attn_bias=bias)
        bias = bias.to(torch.float32)
        with pytest.raises((ValueError, RuntimeError)):
            fmha.memory_efficient_attention(q, k, v, attn_bias=bias)

    def test_f32_biasf16(self) -> None:
        """f32 inputs accept an f32 bias but reject an f16 bias."""
        q, k, v, bias = self.create_tensors(torch.float32)
        fmha.memory_efficient_attention(q, k, v, attn_bias=bias)
        bias = bias.to(torch.float16)
        with pytest.raises((ValueError, RuntimeError)):
            fmha.memory_efficient_attention(q, k, v, attn_bias=bias)

    def test_wrong_alignment(self, dtype) -> None:
        """A bias with an unaligned last dim either works, or works after padding."""
        op = fmha.cutlass.FwOp if torch.version.cuda else fmha.ck.FwOp
        if dtype not in op.SUPPORTED_DTYPES:
            pytest.skip(
                f"{dtype=} is not supported by {op.__module__}.{op.__qualname__}"
            )

        q, k, v, bias = self.create_tensors(dtype, Mq=7, Mkv=5)
        try:
            fmha.memory_efficient_attention(q, k, v, attn_bias=bias, op=(op, None))
            return
        except (ValueError, RuntimeError):
            pass
        # This case is not supported, likely due to padding issues
        # Let's make sure it works with padding
        assert bias.ndim == 4, bias.shape
        bias_padded = self.pad_bias(bias)
        out = fmha.memory_efficient_attention(
            q, k, v, attn_bias=bias_padded, op=(op, None)
        ).float()
        ref_out = ref_attention_bmhk_for_test(q, k, v, bias)
        assert_allclose(
            out, ref_out, atol=op.ERROR_ATOL[dtype], rtol=op.ERROR_RTOL[dtype]
        )

    def test_permuted_attn_bias(self) -> None:
        """A transposed bias (last-dim stride != 1) never causes a CUDA error."""
        op = fmha.cutlass.FwOp
        dtype = torch.float16
        q, k, v, bias = self.create_tensors(dtype, Mq=7, Mkv=7)
        bias = bias.transpose(-1, -2)  # now `stride(-1) != 1`
        # Either it works, or it raises an exception
        # but we should never get a CUDA error
        try:
            out = fmha.memory_efficient_attention(
                q, k, v, attn_bias=bias, op=(op, None)
            ).float()
            ref_out = ref_attention_bmhk_for_test(q, k, v, bias)
            assert_allclose(
                out, ref_out, atol=op.ERROR_ATOL[dtype], rtol=op.ERROR_RTOL[dtype]
            )
        except (ValueError, RuntimeError):
            pass
# (SM version as major*10 + minor, shared-memory budget in KiB) pairs.
# Each pair is handed to the CUTLASS kernel-availability checks in
# `test_has_kernel_for`. Values come from the CUDA programming guide table
# linked below; sm90 is currently disabled (commented out).
SM_AND_SHMEM_KBYTES = [
    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications-technical-specifications-per-compute-capability
    (50, 64),
    (60, 64),
    (70, 96),
    (75, 64),
    (80, 163),
    (86, 99),
    (89, 99),
    # (90, 227),
]
def test_has_kernel_for(sm_shmem: Tuple[int, int], dtype_str: str) -> None:
    """CUTLASS forward and backward kernels exist for each head dim on the given SM/shmem pair."""
    dtype = {"f32": torch.float, "f16": torch.half, "bf16": torch.bfloat16}[dtype_str]
    sm, shmem_kbytes = sm_shmem
    # bf16 is not checked below sm80.
    if dtype_str == "bf16" and sm < 80:
        return
    shmem_bytes = shmem_kbytes * 1024
    for k in (16, 32, 64, 128, 256):
        has_fw = torch.ops.xformers._has_cutlassF_kernel_for(dtype, sm, shmem_bytes, k)
        assert has_fw, f"k={k}"
        has_bw = torch.ops.xformers._has_cutlassB_kernel_for(dtype, sm, shmem_bytes, k)
        assert has_bw, f"k={k}"
def test_window_size_materialize() -> None:
    """make_local_attention(2) lets each query see only itself and the previous key in its block."""
    seqlens = [4, 6]
    total = sum(seqlens)
    attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(
        q_seqlen=seqlens,
        kv_seqlen=seqlens,
    ).make_local_attention(2)
    mask = attn_bias.materialize(
        (1, 1, total, total),
        device="cpu",
        dtype=torch.float32,
    )

    # 1.0 = may attend, 0.0 = masked; log() maps these to 0 / -inf.
    allowed = torch.Tensor(
        [
            [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
        ]
    )
    expected = allowed.log()[None, None]
    assert torch.all(mask == expected)
def test_forward_gqa(opFW_biasT, Mq: int):
    """Forward GQA check (BMGHK layout, 2 groups) delegated to ``test_forward``."""
    opFW, biasT = opFW_biasT
    top_left_causal = (
        fmha.attn_bias.LowerTriangularMask,
        fmha.attn_bias.BlockDiagonalCausalMask,
    )
    if Mq < 512 and issubclass(biasT, top_left_causal):
        pytest.skip("undefined upper left")
    B, Mkv, H, K, Kv = 3, 512, 16, 128, 128
    test_forward(
        (opFW, "cuda", torch.float16, biasT, B, Mq, Mkv, H, K, Kv),
        packed=False,
        fmt="BMGHK",
        g=2,
    )
def test_backward_gqa(opBW):
    """Backward check with K/V shared across heads (expanded along H, BMHK layout).

    Forward runs with CUTLASS and backward with the parametrized ``opBW``. The
    output and the ``key`` gradient are compared against the reference.
    """
    H = 8
    B_Mq_Mkv_H_K_Kv = (3, 512, 512, H, 128, 128)
    dtype = torch.float16
    query, key, value, attn_bias = create_tensors(
        *(opBW, "cuda", dtype, type(None), *B_Mq_Mkv_H_K_Kv),
        attn_bias_requires_grad=False,
        fmt="BMHK",
    )
    op = (fmha.cutlass.FwOp, opBW)
    # One K/V head broadcast to all H heads (stride-0 expand).
    key = key[:, :, :1].expand(-1, -1, H, -1)
    value = value[:, :, :1].expand(-1, -1, H, -1)
    key.requires_grad_(True)
    # Pass `op` explicitly so the parametrized `opBW` is the backward actually
    # exercised. Its tolerances (op[1]) are used below. Without `op=`, the
    # dispatcher picked the kernels and `opBW` went untested.
    out = fmha.memory_efficient_attention(
        query, key, value, attn_bias=attn_bias, op=op
    )
    out_ref = ref_attention_bmhk_for_test(query, key, value, attn_bias=attn_bias)
    assert_allclose(
        out.float(),
        out_ref.float(),
        atol=op[0].ERROR_ATOL[dtype],
        rtol=op[0].ERROR_RTOL[dtype],
    )
    # `query` doubles as the upstream gradient for both backward passes.
    out.backward(query)
    dk = key.grad
    key.grad = None
    out_ref.backward(query)
    assert_allclose(
        dk.float(),
        key.grad.float(),
        atol=op[1].ERROR_ATOL[dtype],
        rtol=op[1].ERROR_RTOL[dtype],
    )
def test_forward_gqa_one_group(opFW):
    """BMGHK input with a single group (G=1) behaves like the equivalent BMHK input."""
    dtype = torch.float16
    B, Mq, Mkv, H, K = 3, 13, 16, 5, 128
    q, k, v = (
        torch.randn([B, M, 1, H, K], dtype=dtype, device="cuda") * 3
        for M in (Mq, Mkv, Mkv)
    )

    supported = opFW.supports(fmha.Inputs(q, k, v))
    if not supported:
        # Support for G=1 BMGHK must match support for the squeezed BMHK form.
        squeezed = [t[:, :, 0] for t in (q, k, v)]
        supported_bmhk = opFW.supports(fmha.Inputs(*squeezed))
        assert supported == supported_bmhk
        pytest.skip("not supported")

    out = fmha.memory_efficient_attention_forward(q, k, v, op=opFW)
    ref = ref_attention_for_test(q, k, v)
    assert_allclose(
        out.float(),
        ref,
        atol=opFW.ERROR_ATOL[dtype],
        rtol=opFW.ERROR_RTOL.get(dtype, 1e-5),
    )
def test_flash_gqa_wrong_strides() -> None:
    """Flash accepts some BMGHK K/V stride layouts and rejects others with ValueError."""
    op = (fmha.flash.FwOp, None)
    B, Mq, Mkv, G, H, K = 3, 1, 512, 2, 8, 128
    f16_cuda = dict(dtype=torch.float16, device="cuda")

    def run(kv):
        return fmha.memory_efficient_attention(q, kv, kv, op=op)

    q = torch.empty((B, Mq, G, H, K), **f16_cuda)

    # Contiguous BMGHK: accepted.
    run(torch.empty((B, Mkv, G, H, K), **f16_cuda))

    # Allocated as BMHGK and permuted to BMGHK (G/H strides swapped): rejected.
    kv_swapped = torch.empty((B, Mkv, H, G, K), **f16_cuda).permute(0, 1, 3, 2, 4)
    with pytest.raises(ValueError):
        run(kv_swapped)

    # H of size 1 that does not match q's H: rejected.
    kv_single_head = torch.empty((B, Mkv, G, 1, K), **f16_cuda)
    with pytest.raises(ValueError):
        run(kv_single_head)

    # Same tensor expanded along H (stride 0): accepted.
    run(kv_single_head.expand(-1, -1, -1, H, K))

    # Last dim sliced out of a 2*K-wide buffer: accepted.
    run(torch.empty((B, Mkv, G, H, 2 * K), **f16_cuda)[:, :, :, :, :K])
def _dispatches_to_splitK(q, kv):
    """Return True when forward dispatch ranks triton split-K first for (q, kv, kv)."""
    priority = _dispatch_fw_priority_list(fmha.Inputs(q, kv, kv), False)
    first_choice = priority[0]
    return first_choice is fmha.triton_splitk.FwOp
def _dispatches_to_flash_decoding(q, kv):
    """Return True when forward dispatch ranks the flash op first for (q, kv, kv)."""
    priority = _dispatch_fw_priority_list(fmha.Inputs(q, kv, kv), False)
    first_choice = priority[0]
    return first_choice is fmha.flash.FwOp
def test_dispatch_decoding_bmhk() -> None:
    """Dispatcher choices for decoding-shaped BMHK inputs (CPU tensors, dispatch only)."""

    def mqa_kv(batch):
        # One K/V head broadcast to 32 heads.
        return torch.empty([batch, 2048, 1, 128]).expand(-1, -1, 32, -1)

    assert not _dispatches_to_splitK(
        torch.empty([1, 8, 1, 128]), torch.empty([1, 2048, 1, 128])
    ), "Should not use SplitK with 1 head (no tensorcores)"

    assert _dispatches_to_flash_decoding(
        torch.empty([1, 8, 32, 128]), mqa_kv(1)
    ), "Should use Flash-Decoding with BMHK MQA"

    not_splitk_cases = [
        (
            torch.empty([1, 8, 32, 128]),
            torch.empty([1, 2048, 32, 128]),
            "Should not use SplitK when no TensorCores",
        ),
        (
            torch.empty([1, 128, 32, 128]),
            mqa_kv(1),
            "Should not use SplitK if q seqlen is long",
        ),
        (
            torch.empty([128, 8, 32, 128]),
            mqa_kv(128),
            "Should not use SplitK if B is big",
        ),
    ]
    for q, kv, reason in not_splitk_cases:
        assert not _dispatches_to_splitK(q, kv), reason
def test_dispatch_decoding_bmghk() -> None:
    """Dispatcher choices for decoding-shaped BMGHK inputs (CPU tensors, dispatch only)."""

    def shared_kv(batch, groups=1):
        # One K/V head per group, broadcast to 32 heads.
        return torch.empty([batch, 2048, groups, 1, 128]).expand(-1, -1, -1, 32, -1)

    assert not _dispatches_to_splitK(
        torch.empty([1, 8, 1, 1, 128]), torch.empty([1, 2048, 1, 1, 128])
    ), "Should not use SplitK with 1 head (no tensorcores)"

    flash_cases = [
        (torch.empty([1, 8, 1, 32, 128]), shared_kv(1), "Should use Flash-Decoding with MQA"),
        (
            torch.empty([1, 8, 4, 32, 128]),
            shared_kv(1, groups=4),
            "Should use Flash-Decoding with GQA",
        ),
    ]
    for q, kv, reason in flash_cases:
        assert _dispatches_to_flash_decoding(q, kv), reason

    not_splitk_cases = [
        (
            torch.empty([1, 8, 1, 32, 128]),
            torch.empty([1, 2048, 1, 32, 128]),
            "Should not use SplitK when no TensorCores",
        ),
        (
            torch.empty([1, 128, 1, 32, 128]),
            shared_kv(1),
            "Should not use SplitK if q seqlen is long",
        ),
        (
            torch.empty([128, 8, 1, 32, 128]),
            shared_kv(128),
            "Should not use SplitK if B is big",
        ),
    ]
    for q, kv, reason in not_splitk_cases:
        assert not _dispatches_to_splitK(q, kv), reason
# (B, Mq, Mkv, H, K, Kv) shapes exercised against the triton split-K op;
# they are unpacked after the bias type in the parametrization list below.
shapes_triton_splitk = [
    # B=1, Mkv = 2**16, varying Mq and head dim
    (1, 8, 2**16, 1, 128, 128),
    (1, 4, 2**16, 1, 128, 128),
    (1, 16, 2**16, 1, 128, 128),
    (1, 16, 2**16, 1, 32, 32),
    # shorter K/V sequences
    (1, 8, 1025, 1, 128, 128),
    (2, 8, 4096, 1, 128, 128),
    # larger batches
    (10, 8, 2**16, 1, 128, 128),
    (10, 15, 2**16, 1, 128, 128),
    # uneven Mq / Mkv values
    (1, 3, 2**16, 1, 128, 128),
    (1, 3, 2**16 - 10, 1, 128, 128),
    (2, 3, 73, 1, 128, 128),
    (2, 7, 7328, 1, 128, 128),
    # head dim 120
    (2, 7, 7328, 1, 120, 120),
    (2, 7, 63, 1, 120, 120),
]
# Every split-K shape, first in fp16 and then in bf16, run with the triton
# split-K forward op on CUDA and no attention bias.
op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv_splitk = [
    (fmha.triton_splitk.FwOp, "cuda", splitk_dtype, type(None), *shape)
    for splitk_dtype in (torch.float16, torch.bfloat16)
    for shape in shapes_triton_splitk
]
def test_forward_splitk(
    opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
    packed=False,
    fmt="BMHK",
):
    """Run the generic ``test_forward`` check on a triton split-K configuration."""
    config = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
    test_forward(config, packed=packed, fmt=fmt)
def test_mqa_decoding(op: Type[fmha.AttentionFwOpBase], dtype, B_Mkv_H_K):
    """Single-query MQA decoding: one K/V head is broadcast to all H query heads."""
    B, Mkv, H, K = B_Mkv_H_K
    q = torch.randn([B, 1, H, K], dtype=dtype, device="cuda") * 3
    # k is drawn before v; both are expanded from a single head to H heads.
    k, v = (
        (torch.randn([B, Mkv, 1, K], dtype=dtype, device="cuda") * 3).expand(
            -1, -1, H, -1
        )
        for _ in range(2)
    )

    skip_reasons = op.not_supported_reasons(fmha.Inputs(q, k, v))
    if skip_reasons:
        pytest.skip("; ".join(skip_reasons))

    expected = ref_attention_for_test(q, k, v)
    actual = fmha.memory_efficient_attention_forward(q, k, v, op=op)
    assert_allclose(
        actual.float(),
        expected,
        atol=op.ERROR_ATOL[dtype],
        rtol=op.ERROR_RTOL.get(dtype, 1e-5),
    )
def test_empty_tensors_empty_query(
    opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
):
    """A zero-length query sequence gives an empty output and all-zero dK/dV."""
    opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0]
    query, key, value, _ = create_tensors(
        *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
        fmt="BMHK",
    )
    if torch.version.hip:
        pytest.skip("backward pass/gradience is not yet supported by ck-tiled fmha!")

    query = query[:, :0]
    for tensor in (query, key, value):
        tensor.requires_grad_(True)

    out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None))
    assert out.shape[1] == 0
    out.backward(out)

    # dK/dV should be all zeros
    for grad, label in ((key.grad, "key.grad"), (value.grad, "value.grad")):
        assert_allclose(grad, torch.zeros_like(grad), label)
def test_empty_tensors_empty_kv(
    opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
):
    """With zero keys/values the output is all zeros, and so is dQ."""
    opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0]
    query, key, value, _ = create_tensors(
        *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
        fmt="BMHK",
    )
    if opFW == fmha.triton_splitk.FwOp:
        pytest.skip("triton_splitk doesn't support empty kv")
    if torch.version.hip:
        pytest.skip("backward pass/gradience is not yet supported by ck-tiled fmha!")

    key, value = key[:, :0], value[:, :0]
    for tensor in (query, key, value):
        tensor.requires_grad_(True)

    out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None))
    assert_allclose(out, torch.zeros_like(out), "out")
    out.backward(out)

    # dQ should be all zeros
    assert_allclose(query.grad, torch.zeros_like(query.grad), "query.grad")
def test_empty_tensors_empty_b(
    opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
):
    """A zero-sized batch must go through forward and backward without error."""
    opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0]
    tensors = create_tensors(
        *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
        fmt="BMHK",
    )
    if torch.version.hip:
        pytest.skip("backward pass/gradience is not yet supported by ck-tiled fmha!")

    query, key, value = (t[:0] for t in tensors[:3])
    for tensor in (query, key, value):
        tensor.requires_grad_(True)

    out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None))
    out.backward(out)
def test_local_attn_bias() -> None:
    """Materialized bottom-right local mask (window_left=1, window_right=2) on 4x4."""
    local_bias = fmha.attn_bias.LocalAttentionFromBottomRightMask(
        window_left=1, window_right=2
    )
    # After exp(), attended positions compare equal to 1 and masked ones to 0
    # (as spelled out by ``expected``).
    mask = local_bias.materialize(shape=(4, 4)).exp()
    expected = torch.tensor(
        [
            [1, 1, 1, 0],
            [1, 1, 1, 1],
            [0, 1, 1, 1],
            [0, 0, 1, 1],
        ],
        dtype=torch.float32,
    )
    matches = mask == expected
    assert matches.all().item()
def test_cutlassB_iter_order(
    dtype,
    cc: int,
    maxK: int,
    num_queries: int,
    num_keys: int,
    custom_mask_type,
    window_size,
) -> None:
    """
    This tests some internals of the cutlassB kernel.

    We test the iteration across blocks of [queries, keys] to ensure
    that we correctly:
    * Iterate over all the blocks that should be iterated
    * Do *not* iterate over blocks that are completely masked out
    * Correctly compute the number of parallel blocks that will compute
      the same block of dQ
    .. and we test this across variable causal masks+local attention combinations
    """
    if (
        window_size > 0
        and custom_mask_type == fmha.cutlass._CustomMaskType.NoCustomMask
    ):
        pytest.skip("LocalAttention is only supported for causal")
    get_iteration_data = partial(
        torch.ops.xformers._cutlassB_iteration_data,
        dtype=dtype,
        cc=cc,
        maxK=maxK,
        num_queries=num_queries,
        num_keys=num_keys,
        custom_mask_type=custom_mask_type,
        window_size=window_size,
    )
    bias = torch.zeros([num_queries, num_keys], dtype=torch.float32)
    if custom_mask_type != fmha.cutlass._CustomMaskType.NoCustomMask:
        bias = fmha.attn_bias._materialize_causal_mask(
            (num_queries, num_keys),
            dtype=torch.float32,
            device="cpu",
            window_size=None if window_size == 0 else window_size,
            from_bottomright=(
                custom_mask_type == fmha.cutlass._CustomMaskType.CausalFromBottomRight
            ),
        )

    block_queries, block_keys = get_iteration_data()[:2]
    # mask_pooled[qb, kb] == 1 iff the (qb, kb) tile's maximum bias is 0,
    # i.e. the tile has at least one unmasked entry and needs computing.
    mask_pooled = (
        F.max_pool2d(bias.unsqueeze(0), (block_queries, block_keys), ceil_mode=True)
        == 0
    ).int()[0]
    num_query_blocks = mask_pooled.shape[0]
    attn_computed = torch.zeros_like(mask_pooled)

    for key_start in range(0, num_keys, block_keys):
        key_block = key_start // block_keys
        it = 0
        new_key_start = key_start
        new_query_start = get_iteration_data(key_start=key_start)[2]
        try:
            first_needed_block = mask_pooled[:, key_block].tolist().index(1)
        except ValueError:  # Nothing to compute in this column
            pass
        else:
            expected_first_query = first_needed_block * block_queries
            assert (
                new_query_start == expected_first_query
            ), f"Wrong first query for K={key_start}: {new_query_start} (expected {expected_first_query})"

        while new_key_start == key_start and new_query_start < num_queries:
            query_start = new_query_start
            query_block = query_start // block_queries
            attn_computed[query_block, key_block] += 1

            # Is there something to compute here?
            assert mask_pooled[
                query_block, key_block
            ].item(), "Computing a block that is not needed!"
            new_query_start, new_key_start = get_iteration_data(
                key_start=key_start, query_start=query_start
            )[3:5]
            it += 1
            # A correct walk visits each needed query block once, so it can
            # never take more steps than there are query blocks. This also
            # guards against an iteration that stops making progress.
            assert it <= num_query_blocks, (
                f"Iteration for K={key_start} took {it} steps, "
                f"more than the {num_query_blocks} query blocks"
            )
        assert (attn_computed == mask_pooled)[
            :, key_block
        ].all(), "some blocks were not computed!"

    # Now check that the number returned by `getNumParallelBlocksForQuery` is correct
    for query_start in range(0, num_queries, block_queries):
        num_parallel_blocks = get_iteration_data(
            query_start=query_start, num_splits_key=num_keys
        )[5]
        num_actual = mask_pooled[query_start // block_queries].sum().item()
        assert num_parallel_blocks == num_actual
def test_paged_attention(
    B,
    MAX_T: int,
    num_quant_groups: int,
    page_size: int,
    op: Type[AttentionFwOpBase],
    gappy: bool,
):
    """Correctness-only (no benchmarking) run of the paged-attention comparison."""
    paged_attention_run_inner(
        B=B,
        MAX_T=MAX_T,
        num_quant_groups=num_quant_groups,
        page_size=page_size,
        op=op,
        bench=False,
        gappy=gappy,
    )
def test_paged_attention_flash(B, MAX_T: int, page_size: int):
    """Paged-attention comparison through the flash forward op, unquantized."""
    # TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
    flash_op = fmha.flash.FwOp
    paged_bias_type = fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask
    if paged_bias_type not in flash_op.SUPPORTED_ATTN_BIAS_TYPES:
        pytest.skip("Not supported bias")
    paged_attention_run_inner(
        B, MAX_T, num_quant_groups=0, page_size=page_size, op=flash_op, bench=False
    )
def paged_attention_run_inner(
    B: int,
    MAX_T: int,
    num_quant_groups: int,
    page_size: int,
    op: Type[AttentionFwOpBase],
    bench: bool,
    gappy: bool = False,
) -> None:
    """Check that paged K/V-cache attention matches non-paged attention.

    Decoding attention (one query per sequence) over ``B`` sequences with
    random K/V lengths in ``[1, MAX_T]`` is computed four ways, and the
    results are compared:

    1. non-paged K/V (``y_usual``);
    2. a "wasteful" paged cache where every logical page has a physical page;
    3. a "packed" paged cache built by ``pack_kv_cache``;
    4. the packed cache after swapping two physical pages and the matching
       block-table entries.

    Args:
        B: number of sequences.
        MAX_T: maximum K/V length per sequence.
        num_quant_groups: when non-zero, K/V are stored as int32 words and
            ``op`` is subclassed with ``NUM_GROUPS = num_quant_groups``.
        page_size: number of K/V rows per page.
        op: forward attention op under test.
        bench: when True, also time some variants via CUDA graphs and log it.
        gappy: use gappy-keys masks instead of causal padded-keys masks.
    """
    import re

    import triton

    torch.manual_seed(10)
    TEST_WARMUP_MS = 500
    TEST_RUN_MS = 5000

    N_H_L = 8
    N_KVH_L = 1
    D_H = 128
    D_H_KV = D_H // 8 + num_quant_groups if num_quant_groups else D_H
    kv_seqlens = torch.randint(low=1, high=MAX_T + 1, size=(B,)).tolist()

    # Paged attention requires k.shape[1] and v.shape[1] to be divisible by page_size, so pad
    padded_per_row_len = ((MAX_T + page_size - 1) // page_size) * page_size

    if gappy:
        make_paged_kwargs = {
            "paged_type": fmha.attn_bias.PagedBlockDiagonalGappyKeysMask,
            "notional_padding": MAX_T,
        }
        attn_bias = fmha.attn_bias.BlockDiagonalGappyKeysMask.from_seqlens(
            q_seqlen=[1] * B,
            kv_seqstarts=list(range(0, MAX_T * (B + 1), MAX_T)),
            kv_seqlen=kv_seqlens,
        )
    else:
        make_paged_kwargs = {
            "paged_type": fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
        }
        block_type = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask
        attn_bias = block_type.from_seqlens(  # type: ignore
            q_seqlen=[1] * B,
            kv_padding=MAX_T,
            kv_seqlen=kv_seqlens,
        )

    q = torch.randn((B, 1, N_H_L, D_H), dtype=torch.bfloat16, device="cuda")
    if num_quant_groups:
        # Compare (major, minor) numerically: slicing and comparing the
        # version string would order e.g. "10.0" before "3.0.".
        triton_version = tuple(
            int(part) for part in re.findall(r"\d+", triton.__version__)[:2]
        )
        if triton_version < (3, 0):
            pytest.skip("dequant needs triton updates")

        # Using high=64 below, because with 256 both paged and non-paged paths
        # will produce NaNs - probably some quantization coefficients are NaNs
        # after the bitwise cast.
        quant_shape = (B, MAX_T, N_KVH_L, D_H_KV * 4)
        cache_k = torch.randint(
            0, 64, quant_shape, dtype=torch.uint8, device="cuda"
        ).view(dtype=torch.int32)
        cache_v = torch.randint(
            0, 64, quant_shape, dtype=torch.uint8, device="cuda"
        ).view(dtype=torch.int32)

        op = type(
            f"{op.__name__}_{num_quant_groups}",
            (op,),
            {"NUM_GROUPS": num_quant_groups},
        )
    else:
        cache_k = torch.randn(
            (B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda"
        )
        cache_v = torch.randn_like(cache_k)

    axq = q.view(1, B * 1, N_H_L, D_H)
    axk = cache_k.view(1, B * MAX_T, N_KVH_L, D_H_KV).expand(
        1, B * MAX_T, N_H_L, D_H_KV
    )
    axv = cache_v.view(1, B * MAX_T, N_KVH_L, D_H_KV).expand(
        1, B * MAX_T, N_H_L, D_H_KV
    )
    k_cache_size_usual = axk.numel()

    def _forward(k, v, bias, bench_label=None):
        """Run the forward pass on ``axq``.

        If ``bench`` is set and ``bench_label`` is given, the pass is also
        captured in a CUDA graph, timed, and logged under ``bench_label``.
        The returned tensor is the graph output in that case.
        """
        y = fmha.memory_efficient_attention_forward(axq, k, v, bias, op=op)
        if bench and bench_label is not None:
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                y = fmha.memory_efficient_attention_forward(axq, k, v, bias, op=op)
            t_ms = triton.testing.do_bench(
                lambda g=g: g.replay(),
                warmup=TEST_WARMUP_MS,
                rep=TEST_RUN_MS,
            )
            logger.info(f"{bench_label} took {t_ms * 1e3:.2f}us")
        return y

    def _swap_pages(x, a, b):
        """Swap physical pages ``a`` and ``b`` of ``x`` in place along dim 1."""
        page_a = slice(a * page_size, (a + 1) * page_size)
        page_b = slice(b * page_size, (b + 1) * page_size)
        vals_a = x[:, page_a].clone()
        vals_b = x[:, page_b].clone()
        x[:, page_a] = vals_b
        x[:, page_b] = vals_a

    # First, create "wasteful" K/V cache, where every block in logical cache
    # has a physical representation, even if there's nothing stored there
    block_tables = torch.arange(
        B * padded_per_row_len // page_size, device="cuda", dtype=torch.int32
    ).reshape(B, -1)

    shape_padded = (B, padded_per_row_len, N_KVH_L, D_H_KV)
    axk_padded = torch.empty(shape_padded, device=axk.device, dtype=axk.dtype)
    axv_padded = torch.empty(shape_padded, device=axv.device, dtype=axv.dtype)
    axk_padded[:, :MAX_T] = axk.view(B, -1, N_H_L, D_H_KV)[:, :, :1, :]
    axv_padded[:, :MAX_T] = axv.view(B, -1, N_H_L, D_H_KV)[:, :, :1, :]
    axk_padded = axk_padded.view(1, B * padded_per_row_len, N_KVH_L, D_H_KV).expand(
        -1, -1, N_H_L, -1
    )
    axv_padded = axv_padded.view(1, B * padded_per_row_len, N_KVH_L, D_H_KV).expand(
        -1, -1, N_H_L, -1
    )

    attn_bias_paged = attn_bias.make_paged(
        block_tables=block_tables, page_size=page_size, **make_paged_kwargs  # type: ignore
    )

    y_usual = _forward(axk, axv, attn_bias, "Non-paged attention")
    y_wasteful = _forward(
        axk_padded,
        axv_padded,
        attn_bias_paged,
        "Paged attention with wasteful K/V-cache",
    )
    torch.testing.assert_close(
        y_wasteful,
        y_usual,
        atol=1.0e-2,
        rtol=1.0e-2,
    )

    # Now let's create a "packed" K/V cache, where only meaningful logical blocks are mapped to physical blocks
    block_tables, packed_cache_k, packed_cache_v = pack_kv_cache(
        cache_k,
        cache_v,
        kv_seqlens,
        page_size,
    )
    attn_bias_paged = attn_bias.make_paged(
        block_tables=block_tables, page_size=page_size, **make_paged_kwargs  # type: ignore
    )
    axk = packed_cache_k.view(1, -1, N_KVH_L, D_H_KV).expand(1, -1, N_H_L, D_H_KV)
    axv = packed_cache_v.view(1, -1, N_KVH_L, D_H_KV).expand(1, -1, N_H_L, D_H_KV)

    k_cache_size_packed = axk.numel()

    y_packed = _forward(axk, axv, attn_bias_paged)

    logger.info(
        f"KV-cache size reduced by {(100 * (1 - k_cache_size_packed/k_cache_size_usual)):.2f}%"
    )

    torch.testing.assert_close(y_wasteful, y_packed)

    # Let's swap two blocks, and adjust two corresponding entries in the block table. The result shouldn't change
    i, j = 0, axk.shape[1] // page_size - 1

    # Swap in the single physical head (writes through to the packed cache),
    # then broadcast back to all heads.
    axk = axk[:, :, :1, :]
    axv = axv[:, :, :1, :]
    _swap_pages(axk, i, j)
    _swap_pages(axv, i, j)
    axk = axk.expand(-1, -1, N_H_L, -1)
    axv = axv.expand(-1, -1, N_H_L, -1)

    # Both masks are computed before either fill, so the two fills do not
    # interfere with each other.
    where_i = block_tables == i
    where_j = block_tables == j
    block_tables.masked_fill_(where_i, j)
    block_tables.masked_fill_(where_j, i)

    y_swapped = _forward(
        axk, axv, attn_bias_paged, "Paged attention with packed K/V-cache"
    )
    torch.testing.assert_close(y_swapped, y_packed)
def test_merge_attentions_nobias(
    write_lse: bool,
    stack_inputs: bool,
    op: Type[AttentionFwOpBase],
    G: Optional[int],
    H: int,
):
    """
    Merging the same attention twice shouldn't change anything.

    This also tests the shape of the lse output of each permitted op,
    for both BMHK (``G is None``) and BMGHK inputs.
    """
    B, M, Mq, K = 13, 5, 3, 128
    if op is None or torch.bfloat16 in op.SUPPORTED_DTYPES:
        dtype = torch.bfloat16
    else:
        dtype = next(iter(op.SUPPORTED_DTYPES))
    if G is None:
        q = 3 * torch.rand(B, Mq, H, K, dtype=dtype, device="cuda")
        k = (3 * torch.rand(B, M, 1, K, dtype=dtype, device="cuda")).expand(B, M, H, K)
        v = (3 * torch.rand(B, M, 1, K, dtype=dtype, device="cuda")).expand(B, M, H, K)
    else:
        q = 3 * torch.rand(B, Mq, G, H, K, dtype=dtype, device="cuda")
        k = (3 * torch.rand(B, M, G, 1, K, dtype=dtype, device="cuda")).expand(
            B, M, G, H, K
        )
        v = (3 * torch.rand(B, M, G, 1, K, dtype=dtype, device="cuda")).expand(
            B, M, G, H, K
        )
    out1, lse1 = fmha.memory_efficient_attention_partial(q, k, v, op=op)
    assert out1.shape == q.shape
    # The op may pad the query axis of the LSE; only the first Mq entries matter.
    M_ceil = lse1.shape[-1]
    assert M_ceil >= Mq
    # Compute the expected shape on its own line. Written inline as
    # `assert shape == A if G is None else B`, the conditional binds looser
    # than `==`, so the assert always passed for BMGHK inputs.
    expected_lse_shape = (B, H, M_ceil) if G is None else (B, G, H, M_ceil)
    assert lse1.shape == expected_lse_shape, f"{lse1.shape=}, {expected_lse_shape=}"
    lse1 = lse1[..., :Mq]

    attn_chunks = [out1, out1]
    lse_chunks = [lse1, lse1]
    attn_chunks_ = torch.stack(attn_chunks) if stack_inputs else attn_chunks
    lse_chunks_ = torch.stack(lse_chunks) if stack_inputs else lse_chunks
    out, lse = fmha.merge_attentions(attn_chunks_, lse_chunks_, write_lse=write_lse)  # type: ignore
    assert out.shape == out1.shape
    assert_allclose(out1, out, rtol=1e-3, atol=1e-3, msg="out")
    if write_lse:
        assert lse is not None
        assert lse.shape[:-1] == lse1.shape[:-1]
        # Two identical chunks: logsumexp(x, x) == x + log(2).
        assert_allclose(
            lse1 + math.log(2), lse[..., :Mq], rtol=1e-3, atol=1e-3, msg="lse"
        )
    else:
        assert lse is None
def test_merge_attentions_nobias_bwd(
    op: Union[Type[AttentionFwOpBase], fmha.AttentionOp]
):
    """Gradients through ``merge_attentions`` of K/V-split partial attentions
    match gradients of a single attention over the concatenated K/V."""
    B, M, Mq, H, K = 13, 5, 5, 4, 128
    dtype = torch.bfloat16
    nparts = 3
    torch.manual_seed(1)

    def _rand(*shape):
        return 3 * torch.rand(*shape, dtype=dtype, device="cuda")

    # Draw order: q first, then (k, v) for each part.
    q = _rand(B, Mq, H, K)
    kv = [[_rand(B, M, H, K) for _ in range(2)] for _ in range(nparts)]
    q = q.requires_grad_(True)
    kv = [[t.requires_grad_(True) for t in pair] for pair in kv]

    partials = [fmha.memory_efficient_attention_partial(q, k, v, op=op) for k, v in kv]
    attn_split = [attn for attn, _ in partials]
    lse_split = [lse for _, lse in partials]
    out_merged = fmha.merge_attentions(attn_split, lse_split, write_lse=True)[0]

    grad_out = torch.rand_like(q)
    out_merged.backward(grad_out)
    assert q.grad is not None
    grad_q_out = q.grad
    grad_kv_out = [[t.grad for t in pair] for pair in kv]

    # Reference: fresh leaves with the same data, attention over all K/V at once.
    q = q.detach().requires_grad_(True)
    kv = [[t.detach().requires_grad_(True) for t in pair] for pair in kv]
    k_full = torch.cat([k for k, _ in kv], dim=1)
    v_full = torch.cat([v for _, v in kv], dim=1)
    full_op = op if op is None or isinstance(op, tuple) else (op, None)
    out_full = fmha.memory_efficient_attention(q, k_full, v_full, op=full_op)  # type: ignore
    out_full.backward(grad_out)

    assert_allclose(
        out_merged, out_full.to(out_merged.dtype), rtol=1e-2, atol=2e-2, msg="out"
    )
    atol = fmha.AttentionBwOpBase.ERROR_ATOL[dtype] * 1.5
    rtol = fmha.AttentionBwOpBase.ERROR_RTOL[dtype]
    assert_allclose(grad_q_out, q.grad, rtol=rtol, atol=atol, msg="qgrad")
    for i, pair in enumerate(kv):
        for j, tensor in enumerate(pair):
            assert_allclose(
                grad_kv_out[i][j],
                tensor.grad,
                rtol=rtol,
                atol=atol,
                msg=f"kvgrad {i} {j}",
            )
def test_partial_paged(
    dtype: torch.dtype, op: Type[AttentionFwOpBase], num_queries: int, bmghk: bool
):
    """Partial attention with a paged, causal, padded-keys bias.

    Checks the output and LSE shapes for BMHK and BMGHK layouts.
    """
    B = 128
    N_H_L = 8
    D_H = 128
    page_size = 256
    G = 2 if bmghk else 1
    # Every sequence maps its single logical page to physical page 0.
    block_tables = torch.zeros((B, 1), dtype=torch.int32, device="cuda")

    torch.manual_seed(1)
    output_dtype = torch.float32 if op.SUPPORTS_OUTPUT_DTYPE else None

    B_T = num_queries * B

    q = torch.randn((1, B_T, G, N_H_L, D_H), dtype=dtype, device="cuda")
    k = torch.randn((1, page_size, G, 1, D_H), dtype=dtype, device="cuda")
    v = torch.randn_like(k)
    k = k.expand(1, page_size, G, N_H_L, D_H)
    v = v.expand(1, page_size, G, N_H_L, D_H)
    if not bmghk:
        q = q[:, :, 0]
        k = k[:, :, 0]
        v = v[:, :, 0]

    attn_bias = (
        fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
            q_seqlen=[num_queries] * B,
            kv_seqlen=[1] + ([100] * (B - 1)),
            page_size=page_size,
            block_tables=block_tables,
        )
    )
    # SUPPORTED_ATTN_BIAS_TYPES lists bias *classes* (the flash paged test
    # checks a class against it). Checking the instance never matched, so the
    # test was skipped unconditionally.
    if type(attn_bias) not in op.SUPPORTED_ATTN_BIAS_TYPES:
        pytest.skip("Not supported bias")
    attn_chunk, lse_chunk = fmha.memory_efficient_attention_partial(
        q,
        k,
        v,
        attn_bias,
        op=op,
        output_dtype=output_dtype,
    )
    if bmghk:
        assert attn_chunk.shape == (1, B_T, G, N_H_L, D_H)
        assert lse_chunk.shape == (
            1,
            G,
            N_H_L,
            B_T,
        ), f"{lse_chunk.shape=}, {(1, G, N_H_L, B_T)=}"
    else:
        assert attn_chunk.shape == (1, B_T, N_H_L, D_H)
        assert lse_chunk.shape == (
            1,
            N_H_L,
            B_T,
        ), f"{lse_chunk.shape=}, {(1, N_H_L, B_T)=}"
def test_merge_attentions_decoding(
    dtype: torch.dtype,
    op: Type[AttentionFwOpBase],
    num_queries: int,
    bmghk: bool,
    stack_inputs: bool,
):
    """
    Split K/V into chunks, run partial decoding attention on each chunk, and
    merge the results; compare against attention over the whole K/V.
    """
    MAX_T = 8192
    B = 128
    N_H_L = 8
    D_H = 128
    G = 2 if bmghk else 1
    torch.manual_seed(1)
    output_dtype = torch.float32 if op.SUPPORTS_OUTPUT_DTYPE else None

    num_chunks = 10
    chunk_starts = sorted(
        torch.randint(low=1, high=MAX_T // 2, size=(num_chunks,)).tolist()
    )
    chunk_starts[0] = 0
    chunk_starts.append(MAX_T)

    # We construct sequences so that even the last chunk has a non-empty part of every sequence
    # as long as the number of queries.
    # Otherwise the corresponding LSE will be -inf and that'll propagate to the whole sum.
    # It is possible to teach the kernel to ignore infinite LSEs, but in practical use cases
    # of merging attention, e.g. a batch of sequences with a common prefix, this condition should be satisfied.
    k_lens = torch.randint(
        low=chunk_starts[-2] + num_queries, high=MAX_T, size=(B,)
    ).tolist()
    q_lens = [num_queries] * B
    B_T = num_queries * B

    q = torch.randn((1, B_T, G, N_H_L, D_H), dtype=dtype, device="cuda")
    k = torch.randn((B, MAX_T, G, 1, D_H), dtype=dtype, device="cuda")
    v = torch.randn_like(k)
    if not bmghk:
        q = q[:, :, 0]

    def _broadcast_heads(x):
        # Broadcast the single K/V head to N_H_L heads; drop G for BMHK.
        x = x.expand(1, -1, G, N_H_L, D_H)
        return x if bmghk else x[:, :, 0]

    if bmghk:
        attn_shape, lse_shape = (1, B_T, G, N_H_L, D_H), (1, G, N_H_L, B_T)
    else:
        attn_shape, lse_shape = (1, B_T, N_H_L, D_H), (1, N_H_L, B_T)

    # Compute per-chunk attention
    attn_split = []
    lse_split = []
    for i, (chunk_start, chunk_end) in enumerate(
        zip(chunk_starts[:-1], chunk_starts[1:])
    ):
        axk = _broadcast_heads(k[:, chunk_start:chunk_end].reshape(-1, G, 1, D_H))
        axv = _broadcast_heads(v[:, chunk_start:chunk_end].reshape(-1, G, 1, D_H))

        is_last_chunk = i + 1 == num_chunks
        bias_type = (
            fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask
            if is_last_chunk
            else fmha.attn_bias.BlockDiagonalPaddedKeysMask
        )
        chunk_bias = bias_type.from_seqlens(
            q_seqlen=q_lens,
            kv_padding=chunk_end - chunk_start,
            kv_seqlen=[max(min(x, chunk_end) - chunk_start, 0) for x in k_lens],
        )

        attn_chunk, lse_chunk = fmha.memory_efficient_attention_partial(
            q,
            axk,
            axv,
            chunk_bias,
            op=op,
            output_dtype=output_dtype,
        )
        assert attn_chunk.shape == attn_shape
        assert lse_chunk.shape == lse_shape
        attn_split.append(attn_chunk)
        lse_split.append(lse_chunk)

    # Merge attention from all chunks
    attn_split_ = torch.stack(attn_split) if stack_inputs else attn_split
    lse_split_ = torch.stack(lse_split) if stack_inputs else lse_split
    attn_out, lse_out = fmha.merge_attentions(
        attn_split_, lse_split_, output_dtype=dtype  # type: ignore
    )
    assert lse_out is not None

    # Compute attention on the full K/V
    full_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
        q_seqlen=q_lens,
        kv_padding=MAX_T,
        kv_seqlen=k_lens,
    )
    axk = _broadcast_heads(k.view(1, -1, G, 1, D_H))
    axv = _broadcast_heads(v.view(1, -1, G, 1, D_H))
    attn_full, lse_full = fmha.memory_efficient_attention_partial(
        q,
        axk,
        axv,
        full_bias,
        op=op,
        output_dtype=output_dtype,
    )

    assert_allclose(
        lse_out.to(lse_full.dtype), lse_full, rtol=1e-3, atol=1e-3, msg="lse"
    )
    assert_allclose(
        attn_out.to(attn_full.dtype), attn_full, rtol=1e-3, atol=1e-3, msg="out"
    )

    attn_full2 = fmha.memory_efficient_attention_forward(
        q,
        axk,
        axv,
        full_bias,
        op=op,
        output_dtype=output_dtype,
    )
    assert_allclose(attn_full2, attn_full, rtol=1e-3, atol=1e-3, msg="out2")
def test_merge_attentions_sharedinput(
    dtype: torch.dtype,
    op: Type[AttentionFwOpBase],
    gqa: bool,
):
    """
    Decode three sequences, of which the first two share a K/V prefix.

    Queries 0 and 1 attend to the shared prefix in one partial pass
    (``attn_bias1``). Every query attends to the remainder of its own
    sequence in a second pass (``attn_bias2``). Merging the two partial
    results must match attention over each full sequence (``attn_bias``).
    With ``gqa`` it also checks that materializing the broadcast K does not
    change the first partial result.
    """
    MAX_T = 8192
    N_H_L = 16
    D_H = 128
    G = 2
    torch.manual_seed(1)
    output_dtype = torch.float32 if op.SUPPORTS_OUTPUT_DTYPE else None
    shared_length = 20
    full_lengths = [30, 35, 40]

    # Reference: each query attends to its whole sequence.
    attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
        q_seqlen=[1, 1, 1],
        kv_padding=MAX_T,
        kv_seqlen=full_lengths,
    )
    # Queries 0-1 attend to the prefix stored in sequence 0; query 2 to nothing.
    attn_bias1 = fmha.attn_bias.BlockDiagonalPaddedKeysMask.from_seqlens(
        q_seqlen=[2, 1],
        kv_padding=MAX_T,
        kv_seqlen=[shared_length, 0],
    )
    # Each query attends to the part of its sequence not covered above.
    attn_bias2 = fmha.attn_bias.BlockDiagonalGappyKeysMask.from_seqlens(
        q_seqlen=[1, 1, 1],
        kv_seqstarts=[shared_length, MAX_T + shared_length, 2 * MAX_T, 3 * MAX_T],
        kv_seqlen=[
            full_lengths[0] - shared_length,
            full_lengths[1] - shared_length,
            full_lengths[2],
        ],
    )

    q = torch.randn((1, 3, G, N_H_L, D_H), dtype=dtype, device="cuda")
    k = torch.randn((3, MAX_T, G, 1 if gqa else N_H_L, D_H), dtype=dtype, device="cuda")
    v = torch.randn_like(k)
    # Sequence 1 reuses sequence 0's first ``shared_length`` keys/values.
    k[1, :shared_length] = k[0, :shared_length]
    v[1, :shared_length] = v[0, :shared_length]
    full_shape = (1, 3 * MAX_T, G, N_H_L, D_H)
    k = k.flatten(end_dim=1)[None].expand(full_shape)
    v = v.flatten(end_dim=1)[None].expand(full_shape)

    attn_shape = (1, 3, G, N_H_L, D_H)
    lse_shape = (1, G, N_H_L, 3)

    def _partial(bias, keys=None):
        return fmha.memory_efficient_attention_partial(
            q,
            k if keys is None else keys,
            v,
            bias,
            op=op,
            output_dtype=output_dtype,
        )

    attn_chunk1, lse_chunk1 = _partial(attn_bias1)
    assert attn_chunk1.shape == attn_shape
    assert lse_chunk1.shape == lse_shape

    if gqa:
        # Same computation with K materialized. NaN/inf entries (presumably
        # from query 2, which has no keys under attn_bias1) are zeroed first.
        attn_chunk1a, lse_chunk1a = _partial(attn_bias1, keys=k.contiguous())
        assert attn_chunk1a.shape == attn_shape
        assert lse_chunk1a.shape == lse_shape
        assert_allclose(
            attn_chunk1a.nan_to_num(0, 0, 0), attn_chunk1.nan_to_num(0, 0, 0)
        )
        assert_allclose(lse_chunk1a.nan_to_num(0, 0, 0), lse_chunk1.nan_to_num(0, 0, 0))

    attn_chunk2, lse_chunk2 = _partial(attn_bias2)
    assert attn_chunk2.shape == attn_shape
    assert lse_chunk2.shape == lse_shape

    # Merge attention from all chunks
    attn_out, lse_out = fmha.merge_attentions(
        [attn_chunk1, attn_chunk2], [lse_chunk1, lse_chunk2], output_dtype=dtype  # type: ignore
    )
    assert lse_out is not None

    # Compute attention on the full K/V
    attn_full, lse_full = _partial(attn_bias)
    assert_allclose(
        attn_out.to(attn_full.dtype), attn_full, rtol=1e-2, atol=2e-3, msg="out"
    )
    assert_allclose(
        lse_out.to(lse_full.dtype), lse_full, rtol=1e-3, atol=1e-3, msg="lse"
    )
| # Gradient with respect to attention, LSE, or neither | |
def test_merge_attentions_against_ref(
    bmghk: bool, stack_inputs: bool, grad_var: Optional[str]
):
    """
    Compare ``fmha.merge_attentions`` with ``_merge_attentions_ref`` on random
    split-K partial results, for the forward pass and optionally the backward.

    bmghk: if True the inputs keep an explicit G axis
        (attn [split_k, B, M, G, H, Kq], lse [split_k, B, G, H, M]);
        otherwise that axis is dropped.
    stack_inputs: if True, ``merge_attentions`` receives the partials as single
        stacked tensors; otherwise as sequences of per-split tensors.
    grad_var: "attn" or "lse" to back-propagate a random gradient through that
        output and compare the input gradients; None to check the forward only.
    """
    split_k = 16
    B = 12
    M = 137
    G = 2 if bmghk else 1
    N_H_L = 8
    D_H = 128
    dtype = torch.float32

    attn_split = torch.randn([split_k, B, M, G, N_H_L, D_H], dtype=dtype, device="cuda")
    lse_split = torch.randn([split_k, B, G, N_H_L, M], dtype=dtype, device="cuda")
    if not bmghk:
        attn_split = attn_split[:, :, :, 0]
        lse_split = lse_split[:, :, 0]

    with_grad = grad_var is not None
    if with_grad:
        attn_split.requires_grad_(True)
        lse_split.requires_grad_(True)

    attn_out_ref, lse_out_ref = _merge_attentions_ref(attn_split, lse_split)
    if with_grad:
        if grad_var == "attn":
            out_grad = torch.randn_like(attn_out_ref)
            attn_out_ref.backward(out_grad)
        else:
            out_grad = torch.randn_like(lse_out_ref)
            lse_out_ref.backward(out_grad)
        attn_grad_ref, lse_grad_ref = attn_split.grad, lse_split.grad

    # Fresh leaf tensors for the implementation under test, so its gradients
    # are accumulated separately from the reference's.
    attn_stacked = attn_split.detach()
    lse_stacked = lse_split.detach()
    attn_inputs: Union[torch.Tensor, Sequence[torch.Tensor]]
    lse_inputs: Union[torch.Tensor, Sequence[torch.Tensor]]
    if stack_inputs:
        attn_inputs, lse_inputs = attn_stacked, lse_stacked
        leaves = [attn_stacked, lse_stacked]
    else:
        attn_inputs, lse_inputs = attn_stacked.unbind(0), lse_stacked.unbind(0)
        leaves = [*attn_inputs, *lse_inputs]
    if with_grad:
        for x in leaves:
            x.requires_grad_(True)
            x.retain_grad()

    attn_out, lse_out = fmha.merge_attentions(attn_inputs, lse_inputs)
    torch.testing.assert_close(lse_out, lse_out_ref, rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(attn_out, attn_out_ref, rtol=1e-4, atol=1e-4)

    if not with_grad:
        return

    if grad_var == "attn":
        attn_out.backward(out_grad)
    else:
        assert lse_out is not None
        lse_out.backward(out_grad)

    if stack_inputs:
        attn_grad, lse_grad = attn_stacked.grad, lse_stacked.grad
    else:
        attn_grad = torch.stack([x.grad for x in attn_inputs], dim=0)
        lse_grad = torch.stack([x.grad for x in lse_inputs], dim=0)

    if grad_var == "lse":
        # LSE doesn't depend on attn_split, so when only a gradient with respect
        # to LSE is provided, the gradient with respect to attn_split is zero.
        # The reference implementation produces None instead of zero here; a
        # None gradient from the stacked leaf is likewise treated as zero.
        attn_grad_ref = torch.zeros_like(attn_stacked)
        if attn_grad is None:
            attn_grad = torch.zeros_like(attn_stacked)
    torch.testing.assert_close(lse_grad, lse_grad_ref, rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(attn_grad, attn_grad_ref, rtol=1e-4, atol=1e-4)
def _merge_attentions_ref(attn_split, lse_split):
    """
    Pure-PyTorch reference for merging split-K partial attention results.

    attn_split: [split_k, B, M, (G,) H, Kq]
    lse_split: [split_k, B, (G,) H, M]

    Returns ``(attn_out, lse_out)`` shaped [B, M, (G,) H, Kq] and [B, (G,) H, M].
    The G axis appears in the outputs only when the inputs had it
    (i.e. when ``attn_split`` is 6-D).
    """
    has_group_axis = attn_split.ndim == 6
    if not has_group_axis:
        # Insert a singleton G axis so both layouts share one code path.
        attn_split, lse_split = attn_split.unsqueeze(3), lse_split.unsqueeze(2)

    # [split_k, B, G, H, M] -> [split_k, B, M, G, H, 1], broadcastable with attn_split
    lse_aligned = lse_split.unsqueeze(-1).moveaxis(4, 2)
    lse_max = lse_aligned.max(dim=0).values  # [B, M, G, H, 1]
    # Per-split weights exp(lse - max), shifted by the max for numerical stability.
    weights = (lse_aligned - lse_max).exp()  # [split_k, B, M, G, H, 1]
    weight_sum = weights.sum(dim=0)  # [B, M, G, H, 1]

    attn_out = (weights * attn_split).sum(dim=0) / weight_sum  # [B, M, G, H, Kq]
    lse_out = (lse_max + weight_sum.log()).squeeze(4).permute(0, 2, 3, 1)  # [B, G, H, M]

    if has_group_axis:
        return attn_out, lse_out
    return attn_out.squeeze(2), lse_out.squeeze(1)
| # rocm doesn't support backward yet | |
def test_memeff_compile(bias_t, create_bias_inside_compiled: bool, op) -> None:
    """
    Check that ``torch.compile`` (fullgraph, static shapes) of a
    ``memory_efficient_attention`` call matches eager mode, both in the forward
    output and in the q/k/v gradients.

    bias_t: attention-bias type forwarded to ``create_tensors`` (may be None).
    create_bias_inside_compiled: if True, ``bias_t()`` is instantiated inside
        the compiled function instead of being passed in; the test is skipped
        unless ``bias_t`` is None or ``LowerTriangularMask``.
    op: forwarded to ``memory_efficient_attention``; when not None, ``op[0]``
        is what ``create_tensors`` receives (presumably the forward op -- confirm).
    """
    torch.manual_seed(0)
    torch._dynamo.reset_code_caches()  # avoids hitting recompilation limit
    B, M, H, K = 1, 256, 2, 64
    q, k, v, bias = create_tensors(
        op if op is None else op[0],
        "cuda",
        torch.float16,
        bias_t,
        B,
        M,
        M,
        H,
        K,
        K,
        fmt="BMHK",
    )
    grad = torch.randn_like(q)
    if create_bias_inside_compiled:
        bias = None
        if bias_t not in [None, fmha.attn_bias.LowerTriangularMask]:
            pytest.skip("Can't create this mask inside compile")
    if bias is not None:
        # `.to()` returns a new object rather than moving in place; rebind it,
        # otherwise this statement has no effect.
        bias = bias.to(q.device)
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)

    def fmha_fn(q, k, v, bias):
        if create_bias_inside_compiled and bias_t is not None:
            bias = bias_t()
        return fmha.memory_efficient_attention(q, k, v, attn_bias=bias, op=op)

    # Eager reference
    out_ref = fmha_fn(q, k, v, bias)
    out_ref.backward(grad)
    dq_ref, dk_ref, dv_ref = q.grad, k.grad, v.grad
    q.grad, k.grad, v.grad = None, None, None

    # Compiled version
    fmha_c = torch.compile(fmha_fn, fullgraph=True, dynamic=False)
    out = fmha_c(q, k, v, bias)
    out.backward(grad)

    assert_allclose(
        out,
        out_ref,
        "out",
        atol=fmha.flash.FwOp.ERROR_ATOL[q.dtype],
        rtol=fmha.flash.FwOp.ERROR_RTOL[q.dtype],
    )
    atol, rtol = (
        fmha.flash.BwOp.ERROR_ATOL[q.dtype],
        fmha.flash.BwOp.ERROR_RTOL[q.dtype],
    )
    assert_allclose(q.grad, dq_ref, "dq", atol=atol, rtol=rtol)
    assert_allclose(k.grad, dk_ref, "dk", atol=atol, rtol=rtol)
    assert_allclose(v.grad, dv_ref, "dv", atol=atol, rtol=rtol)
def test_bias_lower_triangular() -> None:
    """Smoke test: calling ``detach()`` on a bare LowerTriangularMask must not raise."""
    fmha.attn_bias.LowerTriangularMask().detach()
def test_bias_lower_triangular_with_bias() -> None:
    """
    A LowerTriangularMask with an added dense bias back-propagates the incoming
    gradient to that bias, while a detached copy carries no gradient.
    """
    bias_tensor = torch.randn([128, 128], dtype=torch.float16, requires_grad=True)
    upstream_grad = torch.randn_like(bias_tensor)
    biased = fmha.attn_bias.LowerTriangularMask().add_bias(bias_tensor)
    detached = biased.detach()
    biased.backward(upstream_grad)

    # The dense bias should receive the upstream gradient unchanged;
    # the detached copy must stay out of the autograd graph.
    assert bias_tensor.grad is not None
    assert detached.grad is None
    assert_allclose(bias_tensor.grad, upstream_grad, "dense.grad")
| # end of file | |