| import os |
| import math |
| import itertools |
|
|
| import pytest |
| import torch |
| import torch.nn.functional as F |
| from torch._C import parse_schema |
|
|
| from einops import rearrange, repeat |
# Placeholder for a rotary-embedding reference implementation, which is not imported here.
# While this is None, test_flash_attn_kvcache parametrizes rotary_fraction to [0.0] only,
# so rotary_dim is 0 and its rotary code path is never taken.
apply_rotary_emb = None
|
|
| from padding import pad_input, unpad_input |
| from test_util import ( |
| attention_ref, |
| generate_qkv, |
| generate_random_padding_mask, |
| ) |
|
|
| import kernels |
|
|
# Load the prebuilt FlashAttention-3 kernel package and expose its op helpers.
flash_attn3 = kernels.get_kernel("kernels-community/flash-attn3")
ops = flash_attn3._ops.ops
add_op_namespace_prefix = flash_attn3._ops.add_op_namespace_prefix


def _env_flag(name):
    """Return True only when environment variable ``name`` equals the string "TRUE"."""
    return os.getenv(name, "FALSE") == "TRUE"


# Feature switches: each one trims the corresponding parametrization below.
DISABLE_BACKWARD = _env_flag("FLASH_ATTENTION_DISABLE_BACKWARD")
DISABLE_SPLIT = _env_flag("FLASH_ATTENTION_DISABLE_SPLIT")
DISABLE_PAGEDKV = _env_flag("FLASH_ATTENTION_DISABLE_PAGEDKV")
DISABLE_APPENDKV = _env_flag("FLASH_ATTENTION_DISABLE_APPENDKV")
DISABLE_LOCAL = _env_flag("FLASH_ATTENTION_DISABLE_LOCAL")
DISABLE_SOFTCAP = _env_flag("FLASH_ATTENTION_DISABLE_SOFTCAP")
DISABLE_PACKGQA = _env_flag("FLASH_ATTENTION_DISABLE_PACKGQA")
DISABLE_FP16 = _env_flag("FLASH_ATTENTION_DISABLE_FP16")
# FP8 is also disabled on devices whose CUDA compute capability major version is below 9.
DISABLE_FP8 = _env_flag("FLASH_ATTENTION_DISABLE_FP8") or torch.cuda.get_device_capability("cuda")[0] < 9
DISABLE_HDIM64 = _env_flag("FLASH_ATTENTION_DISABLE_HDIM64")
DISABLE_HDIM96 = _env_flag("FLASH_ATTENTION_DISABLE_HDIM96")
DISABLE_HDIM128 = _env_flag("FLASH_ATTENTION_DISABLE_HDIM128")
DISABLE_HDIM192 = _env_flag("FLASH_ATTENTION_DISABLE_HDIM192")
DISABLE_HDIM256 = _env_flag("FLASH_ATTENTION_DISABLE_HDIM256")

# Head dimensions exercised by the tests, in ascending order, minus any disabled ones.
COMPILED_HDIMS = [
    hdim
    for hdim, disabled in (
        (64, DISABLE_HDIM64),
        (96, DISABLE_HDIM96),
        (128, DISABLE_HDIM128),
        (192, DISABLE_HDIM192),
        (256, DISABLE_HDIM256),
    )
    if not disabled
]
|
|
|
|
| |
@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("has_qv", [False])
@pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))
@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else []))
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("V_colmajor", [False])
@pytest.mark.parametrize("d", COMPILED_HDIMS)
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 1),
        (64, 128),
        (128, 192),
        (256, 256),
        (239, 1),
        (799, 3),
        (113, 203),
        (113, 128),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (384, 256),
        (640, 128),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (4096, 4096),
        (4224, 4224),
    ],
)
def test_flash_attn_output(
    seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype
):
    """Compare ``flash_attn3.flash_attn_func`` against ``attention_ref`` on padded batches.

    For every (value head dim ``dv``, ``attention_chunk``) combination, the kernel is
    run for each ``pack_gqa`` / ``num_splits`` setting. Its max abs error against the
    reference must be at most ``rtol`` times the error of a low-precision PyTorch run
    (``attention_ref(..., upcast=False, reorder_ops=True)``) plus ``fwd_atol``. When
    backward is enabled and the configuration supports it, dQ/dK/dV are checked the
    same way.
    """
    if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn):
        pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn")
    device = "cuda"
    torch.random.manual_seed(0)
    batch_size = 9 if seqlen_k <= 2048 else 2
    nheads = 6
    nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
    # FP8 runs are compared against a bf16 reference.
    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
    # Value head dims to test; these may differ from the QK head dim d.
    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
    if dtype == torch.float8_e4m3fn:
        dv_vals = [d]
    attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0]
    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):
        q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)
        if softcap > 0.0:
            # Scale q up so the QK logits are large enough for softcap to matter.
            q_ref = (q_ref * softcap / 4)
        # Round-trip through `dtype` so the reference sees exactly the kernel's input values.
        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
        k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
        v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
        if has_qv:
            qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
        else:
            qv_ref = None
        # Drawn after the QKV randn calls; moving it would change the random inputs.
        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist()
        if dtype == torch.float8_e4m3fn:
            q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)]
        else:
            q_descale, k_descale, v_descale = None, None, None
        q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)]
        qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None
        if V_colmajor:
            v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_()
        out_ref, _ = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            None,
            None,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
            window_size=window_size,
            attention_chunk=attention_chunk,
            softcap=softcap
        )
        out_pt, _ = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            None,
            None,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
            window_size=window_size,
            attention_chunk=attention_chunk,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
        )

        # Rounding noise from any arithmetic on out_ref, used as absolute slack.
        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
        rtol = 2 if softcap == 0.0 else 3

        print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
        print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
        pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
        num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
            out, _ = flash_attn3.flash_attn_func(
                q,
                k,
                v,
                causal=causal,
                qv=qv,
                q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
                window_size=window_size,
                attention_chunk=attention_chunk,
                softcap=softcap,
                pack_gqa=pack_gqa,
                num_splits=num_splits
            )
            print(f"Output max diff: {(out - out_ref).abs().max().item()}")
            print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
            # Kernel error must stay within rtol x the low-precision PyTorch error (+ slack).
            assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol

        if (
            not DISABLE_BACKWARD
            and dtype != torch.float8_e4m3fn
            and not V_colmajor
            and not has_qv
            and dv <= 256
            and attention_chunk == 0
        ):
            # Gradients use the output of the last pack_gqa/num_splits run above.
            g = torch.randn_like(out)
            # `dv_grad`, not `dv`, so the value-head-dim loop variable is not rebound.
            dq, dk, dv_grad = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
            print(f"dV max diff: {(dv_grad - dv_ref).abs().max().item()}")
            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
            print(f"dV mean diff: {(dv_grad - dv_ref).abs().mean().item()}")
            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
            assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol
            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
            assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol
            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
            assert (dv_grad - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol
|
|
|
|
| |
@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("has_qv", [False])
@pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))
@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else []))
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("add_unused_qkv", [False, True])
@pytest.mark.parametrize("d", COMPILED_HDIMS)
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 1),
        (1, 3),
        (2, 1),
        (511, 1),
        (3, 513),
        (64, 128),
        (128, 128),
        (256, 256),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (307, 256),
        (640, 128),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
def test_flash_attn_varlen_output(
    seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype
):
    """Compare ``flash_attn3.flash_attn_varlen_func`` against ``attention_ref``.

    Random query/key padding masks are built; with ``add_unused_qkv`` an extra
    random mask marks some positions as unused. Q/K/V are unpacked through
    ``generate_qkv`` and the kernel is run for each ``pack_gqa`` / ``num_splits``
    setting. Outputs (re-padded, with unused query rows zeroed) must be within
    ``rtol`` x the low-precision PyTorch error plus ``fwd_atol``. When backward is
    enabled and supported, dQ/dK/dV are checked the same way.
    """

    def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
        # Without add_unused: return the mask unchanged and no unused mask.
        # With add_unused: AND a second random mask into padding_mask; positions valid in
        # exactly one of the two masks ((A or B) xor (A and B)) are reported as unused.
        if add_unused:
            another_mask = generate_random_padding_mask(max_seq_len, bs, device)
            attn_mask = torch.logical_and(padding_mask, another_mask)
            unused_mask = torch.logical_xor(
                torch.logical_or(padding_mask, another_mask), attn_mask
            )
        else:
            attn_mask = padding_mask
            unused_mask = None
        return attn_mask, unused_mask

    device = "cuda"
    # Seed depends on the parameters so different cases draw different data.
    torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))
    batch_size = 9 if seqlen_q <= 2048 else 2
    nheads = 6
    nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
    # FP8 runs are compared against a bf16 reference.
    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
    if dtype == torch.float8_e4m3fn:
        dv_vals = [d]
    attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0]
    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):
        q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)
        if softcap > 0.0:
            # Scale q up so the QK logits are large enough for softcap to matter.
            q_ref = (q_ref * softcap / 4).detach().requires_grad_()
        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
        k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
        v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
        if has_qv:
            qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
        else:
            qv_ref = None
        # Drawn after the QKV randn calls; plain ints, as in the padded test.
        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist()
        if dtype == torch.float8_e4m3fn:
            q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)]
        else:
            q_descale, k_descale, v_descale = None, None, None
        q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]
        qv = qv_ref.detach() if has_qv else None
        query_padding_mask = generate_random_padding_mask(
            seqlen_q, batch_size, device, mode="random", zero_lengths=False
        )
        key_padding_mask = generate_random_padding_mask(
            seqlen_k, batch_size, device, mode="random", zero_lengths=True
        )
        query_padding_mask, query_unused_mask = _gen_unused_masks(
            query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device
        )
        key_padding_mask, key_unused_mask = _gen_unused_masks(
            key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device
        )

        (
            q_unpad,
            k_unpad,
            v_unpad,
            qv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            k,
            v,
            qv,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False,
                         query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask)
        q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)]
        out_ref, _ = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
            window_size=window_size,
            attention_chunk=attention_chunk,
            softcap=softcap
        )
        out_pt, _ = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
            window_size=window_size,
            attention_chunk=attention_chunk,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
        )

        print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
        print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

        if query_unused_mask is not None:
            q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")

        # Rounding noise from any arithmetic on out_ref, used as absolute slack.
        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
        rtol = 2 if softcap == 0.0 else 3

        pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
        num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
            # Forward the loop settings so each iteration runs a distinct kernel configuration
            # (previously the same call was repeated for every combination).
            out_unpad, _ = flash_attn3.flash_attn_varlen_func(
                q_unpad,
                k_unpad,
                v_unpad,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                seqused_q=seqused_q,
                seqused_k=seqused_k,
                causal=causal,
                qv=qv_unpad,
                q_descale=q_descale,
                k_descale=k_descale, v_descale=v_descale,
                window_size=window_size,
                attention_chunk=attention_chunk,
                softcap=softcap,
                pack_gqa=pack_gqa,
                num_splits=num_splits,
            )
            out = output_pad_fn(out_unpad)
            if query_unused_mask is not None:
                out.masked_fill_(q_zero_masking, 0.0)
            print(f"Output max diff: {(out - out_ref).abs().max().item()}")
            print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
            # Kernel error must stay within rtol x the low-precision PyTorch error (+ slack).
            assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol

        if (
            not DISABLE_BACKWARD
            and dtype != torch.float8_e4m3fn
            and not has_qv
            and dv <= 256
            and attention_chunk == 0
        ):
            # Gradients use the output of the last pack_gqa/num_splits run above.
            g_unpad = torch.randn_like(out_unpad)
            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad)
            dq = dq_pad_fn(dq_unpad)
            dk = dk_pad_fn(dk_unpad)
            # V was unpadded with the same key_padding_mask as K (generate_qkv call above), so
            # K's pad function is reused for dV - presumably; confirm against generate_qkv.
            # Named dv_grad so the value-head-dim loop variable `dv` is not rebound.
            dv_grad = dk_pad_fn(dv_unpad)
            if key_unused_mask is not None:
                k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
                dk.masked_fill_(k_zero_masking, 0.0)
                dv_grad.masked_fill_(k_zero_masking, 0.0)
            if query_unused_mask is not None:
                dq.masked_fill_(q_zero_masking, 0.0)
            g = output_pad_fn(g_unpad)

            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
            print(f"dV max diff: {(dv_grad - dv_ref).abs().max().item()}")
            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
            print(f"dV mean diff: {(dv_grad - dv_ref).abs().mean().item()}")
            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
            assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol
            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
            assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol
            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
            assert (dv_grad - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol
|
|
|
|
| |
@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else []))
@pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else []))
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True])
@pytest.mark.parametrize("has_rotary_seqlens", [False, True])
@pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0])
@pytest.mark.parametrize("page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else []))
@pytest.mark.parametrize("has_leftpad", [False, True])
@pytest.mark.parametrize("has_batch_idx", [False, True])
@pytest.mark.parametrize("varlen_q", [False, True])
@pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 128),
        (1, 339),
        (3, 1024),
        (64, 800),
        (64, 256),
        (3, 799),
        (64, 2048),
        (16, 20000),
        (128, 128),
        (256, 512),
        (2048, 3577),
    ],
)
def test_flash_attn_kvcache(
    seqlen_q,
    seqlen_k,
    d,
    varlen_q,
    has_batch_idx,
    has_leftpad,
    page_size,
    rotary_fraction,
    rotary_interleaved,
    has_rotary_seqlens,
    seqlen_new_eq_seqlen_q,
    causal,
    local,
    new_kv,
    mha_type,
    dtype,
):
    """Check ``flash_attn3.flash_attn_with_kvcache`` against ``attention_ref``.

    A dense or paged KV cache is filled with random data, optionally with new K/V
    appended after ``cache_seqlens``. A reference cache with the expected update is
    built in PyTorch. For each ``num_splits`` / ``precompute_metadata`` combination
    the kernel is run on a freshly restored cache. When ``new_kv`` is set, the
    test checks that the appended K/V landed in the cache. It always checks that
    the output's max and mean errors stay within a multiple of the low-precision
    PyTorch errors.
    """
    if page_size is not None and seqlen_k % page_size != 0:
        pytest.skip()
    if seqlen_q > seqlen_k and new_kv:
        pytest.skip()
    if not new_kv and rotary_fraction > 0.0:
        pytest.skip()
    if rotary_fraction == 0.0 and has_rotary_seqlens:
        pytest.skip()
    device = "cuda"
    torch.random.manual_seed(0)
    batch_size = 5
    # With has_batch_idx the cache holds more rows than the batch; cache_batch_idx picks them.
    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
    nheads = 6
    # rotary_fraction * d rounded down to a multiple of 16.
    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    # FP8 runs are compared against a bf16 reference.
    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
    if dtype == torch.float8_e4m3fn:
        dv_vals = [d]
    attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0]
    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):
        has_qv = d == 64 and dv >= 256
        q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
        if has_qv:
            qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
        else:
            qv = None
        if varlen_q:
            query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
            q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask)

            def output_pad_fn(output_unpad):
                # Re-pad kernel output to (batch_size, seqlen_q, ...) using this iteration's indices_q.
                return pad_input(output_unpad, indices_q, batch_size, seqlen_q)

            qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None
        else:
            query_padding_mask = None
            q_unpad = q
            qv_unpad = qv
            cu_seqlens_q, max_seqlen_q = None, None
        # Drawn after the Q randn calls; plain ints so every consumer gets Python scalars.
        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist()

        seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
        cu_seqlens_k_new = None
        key_new_padding_mask = None
        if new_kv:
            k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
            v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
            if varlen_q:
                # The new K/V are unpadded too when the queries are variable-length.
                key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random")
                k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask)
                v_unpad, *rest = unpad_input(v, key_new_padding_mask)
            else:
                k_unpad, v_unpad = k, v
        else:
            k, v, k_unpad, v_unpad = None, None, None, None
        if page_size is None:
            k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
            v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
            page_table = None
        else:
            (
                k_cache,
                v_cache,
                page_table,
                k_cache_paged,
                v_cache_paged,
                num_blocks,
            ) = _generate_block_kvcache(
                seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref
            )
        cache_seqlens = torch.randint(
            0 if new_kv else 1,
            # Leave room for the appended tokens: seqlen_q of them when rotary is applied
            # under causal/local (presumably so cos/sin cover every rotated position),
            # otherwise seqlen_new.
            (
                (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
                if new_kv
                else (seqlen_k + 1)
            ),
            (batch_size,),
            dtype=torch.int32,
            device=device,
        )
        if has_leftpad:
            cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
                                       if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)
                                       for i in range(batch_size)])
        else:
            cache_leftpad = None
        if has_batch_idx:
            cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
                :batch_size
            ]
        else:
            cache_batch_idx = None
        arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
        cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
        if not new_kv:
            key_padding_mask = arange < cache_seqlens_expanded
        else:
            k_new_seqlens = key_new_padding_mask.sum(-1, keepdim=True) if varlen_q else seqlen_new
            key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens
        if has_leftpad:
            key_padding_mask = torch.logical_and(
                key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
            )
        rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2
        if rotary_dim > 0:
            angle = (
                torch.rand(
                    seqlen_k if page_size is None else num_blocks * page_size,
                    rotary_dim // 2,
                    device=device,
                )
                * 2
                * math.pi
            )
            cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)
            sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)
            if causal or local:
                q_ro = apply_rotary_emb(
                    q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved
                )
            else:
                q_ro = rearrange(
                    apply_rotary_emb(
                        rearrange(q, "b s h d -> b 1 (s h) d"),
                        cos,
                        sin,
                        seqlen_offsets=rotary_seqlens,
                        interleaved=rotary_interleaved,
                    ),
                    "b 1 (s h) d -> b s h d",
                    s=seqlen_q,
                )
            k_ro = apply_rotary_emb(
                k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved
            )
        else:
            cos, sin = None, None
            q_ro, k_ro = q, k

        # Reference cache: the original contents with the new K/V written right after cache_seqlens.
        k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
        v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
        if new_kv:
            update_mask = torch.logical_and(
                cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens
            )
            k_to_update = rearrange(k_ro, "b s ... -> (b s) ...")
            v_to_update = rearrange(v, "b s ... -> (b s) ...")
            if varlen_q:
                k_to_update = k_to_update[indices_k]
                v_to_update = v_to_update[indices_k]
            k_cache_ref[update_mask] = k_to_update
            v_cache_ref[update_mask] = v_to_update
        k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        out_ref, _ = attention_ref(
            q_ro,
            k_cache_rep,
            v_cache_rep,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv,
            window_size=window_size,
            attention_chunk=attention_chunk,
            key_leftpad=cache_leftpad,
        )
        out_pt, _ = attention_ref(
            q_ro,
            k_cache_rep,
            v_cache_rep,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv,
            window_size=window_size,
            attention_chunk=attention_chunk,
            upcast=False,
            reorder_ops=True,
            key_leftpad=cache_leftpad,
            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None
        )
        # Cast the kernel inputs to the tested dtype.
        q = q.to(dtype)
        q_unpad = q_unpad.to(dtype) if varlen_q else None
        k_cache = k_cache.to(dtype)
        v_cache = v_cache.to(dtype)
        k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None
        v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None
        k = k.to(dtype) if k is not None else None
        v = v.to(dtype) if v is not None else None
        k_unpad = k_unpad.to(dtype) if k_unpad is not None else None
        v_unpad = v_unpad.to(dtype) if v_unpad is not None else None
        qv = qv.to(dtype) if qv is not None else None
        qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None
        cos = cos.to(dtype) if cos is not None else None
        sin = sin.to(dtype) if sin is not None else None
        # Snapshot the cache so each kernel run starts from identical contents.
        k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone()
        v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone()
        num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1]
        precompute_metadata_vals = [False, True]
        for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals):
            if precompute_metadata:
                scheduler_metadata = flash_attn3.get_scheduler_metadata(
                    batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d,
                    cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad,
                    max_seqlen_k_new=seqlen_new, page_size=page_size,
                    causal=causal, window_size=window_size, attention_chunk=attention_chunk,
                    num_splits=num_splits
                )
            else:
                scheduler_metadata = None
            # With precomputed metadata the kernel runs twice, reusing the same scheduler_metadata.
            for _ in range(1 if not precompute_metadata else 2):
                if page_size is None:
                    k_cache.copy_(k_cache_saved)
                    v_cache.copy_(v_cache_saved)
                else:
                    k_cache_paged.copy_(k_cache_saved)
                    v_cache_paged.copy_(v_cache_saved)
                out, lse, *rest = flash_attn3.flash_attn_with_kvcache(
                    q if not varlen_q else q_unpad,
                    k_cache if page_size is None else k_cache_paged,
                    v_cache if page_size is None else v_cache_paged,
                    k if not new_kv or not varlen_q else k_unpad,
                    v if not new_kv or not varlen_q else v_unpad,
                    qv=qv if not varlen_q else qv_unpad,
                    rotary_cos=cos,
                    rotary_sin=sin,
                    cache_seqlens=cache_seqlens,
                    cache_batch_idx=cache_batch_idx,
                    cache_leftpad=cache_leftpad,
                    page_table=page_table,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k_new=cu_seqlens_k_new,
                    max_seqlen_q=max_seqlen_q,
                    rotary_seqlens=rotary_seqlens,
                    causal=causal,
                    window_size=window_size,
                    attention_chunk=attention_chunk,
                    rotary_interleaved=rotary_interleaved,
                    scheduler_metadata=scheduler_metadata,
                    num_splits=num_splits,
                    return_softmax_lse=True
                )
                if varlen_q:
                    out = output_pad_fn(out)
                print(f"Output max diff: {(out - out_ref).abs().max().item()}")
                print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
                print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
                print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

                if new_kv:
                    # Read back the kernel-updated cache rows (gathering pages when paged).
                    if page_size is None:
                        k_cache_select = (
                            k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx]
                        )
                        v_cache_select = (
                            v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx]
                        )
                    else:
                        k_cache_select = rearrange(
                            k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()],
                            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                            b=batch_size,
                        )[:, :seqlen_k].to(dtype_ref)
                        v_cache_select = rearrange(
                            v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()],
                            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                            b=batch_size,
                        )[:, :seqlen_k].to(dtype_ref)
                    k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref)
                    v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref)
                    if dtype != torch.float8_e4m3fn:
                        assert torch.equal(v_cache_select, v_cache_ref)
                    else:
                        assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3)
                    if rotary_dim == 0:
                        assert torch.equal(k_cache_select, k_cache_ref)
                    else:
                        # Rotary-transformed K is compared with a tolerance instead of exactly.
                        if dtype != torch.float8_e4m3fn:
                            assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
                        else:
                            assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1)
                mult = 4 if dtype == torch.float8_e4m3fn else 2
                assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5
                mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5
                assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item()
|
|
|
|
def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref):
    """Create a randomly paged KV cache together with its dense equivalent.

    Page pools are drawn in ``dtype_ref`` and round-tripped through ``dtype``, so the
    stored values are ones ``dtype`` can represent (assuming ``dtype_ref`` is at least
    as precise). Three times as many pages as strictly needed are allocated. A random
    permutation of all page indices is split evenly across the batch to form the page
    table.

    Returns:
        A 6-tuple ``(k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks)``:

        * ``k_cache`` / ``v_cache``: dense caches of shape
          ``(batch_size, seqlen_k, nheads_k, d)`` / ``(batch_size, seqlen_k, nheads_k, dv)``,
          built by gathering each batch row's pages in page-table order.
        * ``page_table``: int32 tensor of shape ``(batch_size, num_blocks // batch_size)``.
        * ``k_cache_paged`` / ``v_cache_paged``: page pools of shape
          ``(num_blocks, page_size, nheads_k, d)`` / ``(num_blocks, page_size, nheads_k, dv)``.
        * ``num_blocks``: total number of pages in each pool.
    """
    pages_per_row = 3 * math.ceil(seqlen_k / page_size)
    num_blocks = pages_per_row * batch_size

    def _random_pool(head_dim):
        # Quantize through `dtype` so the reference copy matches what the kernel stores.
        pool = torch.randn(num_blocks, page_size, nheads_k, head_dim, device=device, dtype=dtype_ref)
        return pool.to(dtype).to(dtype_ref)

    # K is drawn before V and both before the permutation, keeping the RNG order fixed.
    k_cache_paged = _random_pool(d)
    v_cache_paged = _random_pool(dv)

    shuffled = torch.randperm(num_blocks, dtype=torch.int32, device=device)
    page_table = rearrange(shuffled, "(b nblocks) -> b nblocks", b=batch_size)
    flat_pages = page_table.flatten()

    def _to_dense(pool):
        # Concatenate each row's pages into one sequence, then trim the tail padding.
        gathered = pool[flat_pages]
        dense = rearrange(
            gathered,
            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
            b=batch_size,
        )
        return dense[:, :seqlen_k]

    k_cache = _to_dense(k_cache_paged)
    v_cache = _to_dense(v_cache_paged)
    return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks
|
|
|
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (64, 8192),
    ],
)
def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype):
    """Smoke-test 100 back-to-back forward launches on a short-query / long-key GQA problem.

    Per its name this presumably targets the thread-block-cluster code path (TODO confirm).
    It only checks that the launches complete without a device error and that the final
    output is finite. Numerical accuracy is covered by the other tests.
    """
    device = "cuda"
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 16
    nheads_kv = 4  # fewer KV heads than query heads (grouped-query attention)
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype)
    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype)
    for _ in range(100):
        out, _lse = flash_attn3.flash_attn_func(q, k, v, causal=causal)
    # Kernel launches are asynchronous: synchronize so that an illegal memory access or
    # launch failure is attributed to this test instead of surfacing in a later one.
    torch.cuda.synchronize()
    assert torch.isfinite(out).all()
|
|
|
|
| |
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (239, 1),
        (3, 799),
        (799, 3),
        (1024, 128),
        (97, 97),
        (128, 128),
        (200, 200),
        (256, 256),
        (257, 257),
        (384, 384),
        (512, 512),
        (768, 768),
        (1024, 1024),
        (2048, 2048),
    ],
)
def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
    """Check that repeated forward/backward passes are reproducible.

    A reference forward + backward pass is computed once and then recomputed 1000
    times. ``out``, ``lse``, ``dk`` and ``dv`` must match the reference bit-for-bit.
    ``dq`` is compared with a tiny tolerance (``dq_atol``). Any mismatch suggests a
    race condition in the kernels.
    """
    # This test needs the backward kernels; skip on builds compiled without them
    # instead of failing inside torch.autograd.grad.
    if DISABLE_BACKWARD:
        pytest.skip()
    device = "cuda"
    torch.random.manual_seed(0)
    # Keep ~70 GiB of device memory allocated for the whole test, presumably to run the
    # kernels under memory pressure (TODO confirm intent). This requires a GPU with
    # more than 70 GiB of memory.
    dummy = torch.empty(70 * 1024 ** 3, dtype=torch.uint8, device=device)  # noqa: F841
    batch_size = 60
    nheads = 4
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    # Re-seed identically before every forward call so each run starts from the same RNG state.
    torch.random.manual_seed(42)
    out0, lse0 = flash_attn3.flash_attn_func(q, k, v, causal=causal)
    g = torch.randn_like(out0)
    dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g)
    # dQ tolerance: twice the largest rounding error from a trivial +0.3/-0.3 round trip
    # in dq0's dtype, i.e. roughly the noise of a single extra arithmetic step.
    dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()

    for i in range(1000):
        torch.random.manual_seed(42)
        out, lse = flash_attn3.flash_attn_func(q, k, v, causal=causal)
        assert torch.equal(out, out0)
        assert torch.equal(lse, lse0)

        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
        dq_equal = torch.allclose(dq, dq0, atol=dq_atol)
        if not dq_equal:
            print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}")
        assert torch.equal(dv, dv0)
        assert torch.equal(dk, dk0)
        assert dq_equal
|
|
|
|
def attention_combine_ref(out_partial, lse_partial):
    """Reference combination of split-KV partial attention results.

    Each split ``s`` contributes a partial output and its log-sum-exp. The combined
    output re-weights split ``s`` by ``exp(lse_partial[s] - lse)``, where ``lse`` is the
    log-sum-exp of ``lse_partial`` over the split axis.

    Args:
        out_partial: (num_splits, batch_size, seqlen, nheads, d)
        lse_partial: (num_splits, batch_size, seqlen, nheads). This is the layout of
            ``out_partial`` without the trailing ``d`` axis, so that
            ``scale.unsqueeze(-1)`` broadcasts against ``out_partial``.

    Returns:
        ``(out, lse)`` with shapes ``(batch_size, seqlen, nheads, d)`` and
        ``(batch_size, seqlen, nheads)``. Where every split has ``lse_partial == -inf``,
        ``lse`` is ``-inf`` and (for finite ``out_partial``) the output is zero.
    """
    lse = torch.logsumexp(lse_partial, dim=0)
    scale = torch.exp(lse_partial - lse)
    # Fully masked positions give (-inf) - (-inf) = nan; drop any non-finite weight to 0.
    scale = torch.where(torch.isfinite(scale), scale, torch.zeros_like(scale))
    out = (scale.unsqueeze(-1) * out_partial).sum(0)
    return out, lse
|
|
|
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512])
@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024])
@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133])
def test_flash_attn_combine(num_splits, seqlen, d, dtype):
    """Compare ``flash_attn_combine`` against the pure-PyTorch reference combine.

    Inputs are strided views (transposed and sliced from larger buffers). For the first
    ``batch_size // 3`` batch entries, every split from ``num_splits // 2`` onward is
    fully masked (``lse = -inf``).
    """
    if DISABLE_SPLIT:
        pytest.skip()
    device = "cuda"
    torch.random.manual_seed(1)
    batch_size, nheads = 5, 16

    # Over-allocate, then transpose + slice so the kernel sees strided (for most
    # shapes non-contiguous) inputs of the expected logical layout.
    out_storage = torch.randn(num_splits * 2, batch_size, nheads, seqlen, d, device=device, dtype=torch.float32)
    out_partial = out_storage.transpose(2, 3)[:num_splits]
    lse_storage = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32)
    lse_partial = lse_storage.transpose(-1, -2)[:, :, :, :nheads]
    lse_partial[num_splits // 2:, :batch_size // 3] = -float("inf")

    out, lse = flash_attn3.flash_attn_combine(out_partial, lse_partial, out_dtype=dtype)
    out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial)
    # The fp32 reference rounded to the output dtype sets the error floor for the kernel.
    out_pt = out_ref.to(dtype)

    lse_err = (lse - lse_ref).abs()
    print(f"LSE max diff: {lse_err.max().item()}")
    print(f"LSE mean diff: {lse_err.mean().item()}")
    out_err = (out - out_ref).abs()
    print(f"Output max diff: {out_err.max().item()}")
    print(f"Output mean diff: {out_err.mean().item()}")
    pt_err = (out_pt - out_ref).abs()
    print(f"Pytorch max diff: {pt_err.max().item()}")
    print(f"Pytorch mean diff: {pt_err.mean().item()}")

    assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5)
    tolerance_factor = 2
    within_budget = out_err.max().item() <= tolerance_factor * pt_err.max().item()
    assert within_budget or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5)
|
|
| |
| |
| |
| |
|
|
def test_flash3_bw_compatibility() -> None:
    """Guard the registered op schemas against backward-incompatible changes.

    Each signature below is a schema that existing callers were written against. The
    currently registered op schema must be backward compatible with it, as decided by
    ``FunctionSchema.is_backward_compatible_with``.
    """

    def _check(op_overload, signature: str) -> None:
        expected = parse_schema(add_op_namespace_prefix(signature))
        assert op_overload._schema.is_backward_compatible_with(expected)

    _check(
        ops.fwd.default,
        "fwd(Tensor q, Tensor k, Tensor v, Tensor(k_new!)? k_new=None, "
        "Tensor(v_new!)? v_new=None, Tensor? q_v=None, Tensor(out!)? out=None, "
        "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, "
        "Tensor? cu_seqlens_k_new=None, Tensor? seqused_q=None, Tensor? seqused_k=None, "
        "int? max_seqlen_q=None, int? max_seqlen_k=None, Tensor? page_table=None, "
        "Tensor? kv_batch_idx=None, Tensor? leftpad_k=None, Tensor? rotary_cos=None, Tensor? rotary_sin=None, "
        "Tensor? seqlens_rotary=None, Tensor? q_descale=None, Tensor? k_descale=None, Tensor? v_descale=None, "
        "float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, "
        "int attention_chunk=0, float softcap=0., bool is_rotary_interleaved=False, "
        "Tensor? scheduler_metadata=None, int num_splits=0, bool? pack_gqa=None, int sm_margin=0) "
        "-> (Tensor(out!), Tensor, Tensor, Tensor)",
    )
    _check(
        ops.bwd.default,
        "bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, "
        "Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, "
        "Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, "
        "int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, "
        "int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) "
        "-> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)",
    )
    _check(
        ops.fwd_combine.default,
        "fwd_combine(Tensor out_partial, Tensor lse_partial, Tensor(out!)? out=None, "
        "ScalarType? out_dtype=None) -> (Tensor(out!), Tensor)",
    )
    _check(
        ops.get_scheduler_metadata.default,
        "get_scheduler_metadata(int batch_size, int max_seqlen_q, int max_seqlen_k, "
        "int num_heads, int num_heads_k, int headdim, int headdim_v, ScalarType qkv_dtype, Tensor seqused_k, "
        "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? cu_seqlens_k_new=None, "
        "Tensor? seqused_q=None, Tensor? leftpad_k=None, int? page_size=None, int max_seqlen_k_new=0, "
        "bool is_causal=False, int window_size_left=-1, int window_size_right=-1, "
        "int attention_chunk=0, bool has_softcap=False, int num_splits=0, bool? pack_gqa=None, "
        "int sm_margin=0) -> Tensor",
    )
|
|