| | import math |
| |
|
| | import pytest |
| | import torch |
| | import torch.nn.functional as F |
| | from einops import rearrange, repeat |
| | from flash_attn import ( |
| | flash_attn_func, |
| | flash_attn_kvpacked_func, |
| | flash_attn_qkvpacked_func, |
| | flash_attn_varlen_func, |
| | flash_attn_varlen_kvpacked_func, |
| | flash_attn_varlen_qkvpacked_func, |
| | flash_attn_with_kvcache, |
| | ) |
| | from flash_attn.bert_padding import pad_input, unpad_input |
| | from flash_attn.flash_attn_interface import _get_block_size_n |
| | from flash_attn.layers.rotary import apply_rotary_emb |
| |
|
# Head-dimension threshold used by the tests in this file: above it, the
# backward pass with dropout is only exercised on sm80 / sm90 devices.
MAX_HEADDIM_SM8x = 192

# Read the compute capability of the CUDA device once and derive every
# architecture flag from that single (major, minor) tuple.
_cuda_capability = torch.cuda.get_device_capability("cuda")
is_sm75 = _cuda_capability == (7, 5)
is_sm8x = _cuda_capability[0] == 8
is_sm80 = _cuda_capability == (8, 0)
is_sm90 = _cuda_capability == (9, 0)
| |
|
| |
|
def attn_bias_from_alibi_slopes(
    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None
):
    """Turn per-head ALiBi slopes into an additive attention bias.

    Arguments:
        slopes: (batch, nheads) tensor of slopes.
        seqlen_q, seqlen_k: padded query / key lengths.
        query_padding_mask, key_padding_mask: optional (batch, seqlen) bool masks; their
            row sums are used as the per-sequence lengths in the non-causal branch.
        causal: when True the bias depends only on the key position and the masks and
            key_leftpad are ignored.
        key_leftpad: optional (batch,) tensor of left-padding amounts for the keys.
    Output:
        causal: (batch, nheads, 1, seqlen_k) float32 bias.
        otherwise: bias broadcastable to (batch, nheads, seqlen_q, seqlen_k) in slopes' dtype.
    """
    # Unpacking doubles as a check that slopes is 2-D.
    _batch, _nheads = slopes.shape
    dev = slopes.device
    slopes_4d = rearrange(slopes, "b h -> b h 1 1")

    if causal:
        # Distances -(seqlen_k - 1) ... 0, scaled by each head's slope.
        key_offsets = torch.arange(-seqlen_k + 1, 1, device=device_or(dev), dtype=torch.float32) if False else torch.arange(
            -seqlen_k + 1, 1, device=dev, dtype=torch.float32
        )
        return key_offsets * slopes_4d

    rows = rearrange(torch.arange(seqlen_q, device=dev, dtype=torch.long), "s -> s 1")
    cols = torch.arange(seqlen_k, device=dev, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        cols = repeat(cols, "s -> b 1 1 s", b=key_leftpad.shape[0])
        # Columns that fall inside the left padding get a huge sentinel index.
        cols = torch.where(cols >= key_leftpad, cols - key_leftpad, 2**32)

    if key_padding_mask is None:
        sk = seqlen_k
    else:
        sk = rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    if query_padding_mask is None:
        sq = seqlen_q
    else:
        sq = rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")

    # Distance of each key from the bottom-right-aligned diagonal.
    distance = torch.abs(rows + sk - sq - cols)
    return -slopes_4d * distance.to(dtype=slopes_4d.dtype)
| |
|
| |
|
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
    """Create a (batch_size, max_seqlen) bool mask that is True on the first `length` positions.

    mode selects how each row's length is chosen:
        "full":   every row has length max_seqlen.
        "random": lengths drawn from [max(1, max_seqlen - 20), max_seqlen].
        "third":  lengths drawn from [max_seqlen // 3, max_seqlen].
    """
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    else:
        shortest = max(1, max_seqlen - 20) if mode == "random" else max_seqlen // 3
        # randint's upper bound is exclusive, hence the + 1.
        lengths = torch.randint(shortest, max_seqlen + 1, (batch_size, 1), device=device)
    positions = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size)
    return positions < lengths
| |
|
| |
|
def generate_qkv(
    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
):
    """Build padded and unpadded (variable-length) versions of the attention inputs.

    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
        kvpacked: stack k and v into a single kv tensor.
        qkvpacked: stack q, k and v into a single qkv tensor. Requires nheads == nheads_k
            and identical query / key padding masks (or no masks at all).
    Output (all returned tensors are detached leaves with requires_grad set):
        qkvpacked: (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn, dqkv_pad_fn)
        kvpacked:  (q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                    q, kv, output_pad_fn, dq_pad_fn, dkv_pad_fn)
        otherwise: (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
                    max_seqlen_k, q, k, v, output_pad_fn, dq_pad_fn, dk_pad_fn)
        The *_pad_fn callables map an unpadded tensor back to the padded batch layout.
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, query_padding_mask)

        def output_pad_fn(output_unpad):
            return pad_input(output_unpad, indices_q, batch_size, seqlen_q)

    else:
        # No padding: every sequence has the full length, so boundaries are multiples of it.
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        max_seqlen_q = seqlen_q

        def output_pad_fn(output_unpad):
            return rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size)

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        max_seqlen_k = seqlen_k

    if qkvpacked:
        # Comparing None with None yields a plain bool, which has no .all(); handle the
        # mask-free case explicitly so the unpadded qkvpacked path is actually usable.
        if query_padding_mask is None or key_padding_mask is None:
            assert (
                query_padding_mask is None and key_padding_mask is None
            ), "qkvpacked requires identical query and key padding masks"
        else:
            assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:

            def dqkv_pad_fn(dqkv_unpad):
                return pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)

        else:

            def dqkv_pad_fn(dqkv_unpad):
                return rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)

        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )

    # dq has the same layout as the output, so the same padding function applies.
    dq_pad_fn = output_pad_fn

    if kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        if key_padding_mask is not None:

            def dkv_pad_fn(dkv_unpad):
                return pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)

        else:

            def dkv_pad_fn(dkv_unpad):
                return rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)

        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )

    if key_padding_mask is not None:

        def dk_pad_fn(dk_unpad):
            return pad_input(dk_unpad, indices_k, batch_size, seqlen_k)

    else:

        def dk_pad_fn(dk_unpad):
            return rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)

    return (
        q_unpad.detach().requires_grad_(),
        k_unpad.detach().requires_grad_(),
        v_unpad.detach().requires_grad_(),
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q.detach().requires_grad_(),
        k.detach().requires_grad_(),
        v.detach().requires_grad_(),
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    )
| |
|
| |
|
def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),
    query_padding_mask=None,
    key_padding_mask=None,
    device=None,
    key_leftpad=None,
):
    """Return a bool mask that is True where a key lies OUTSIDE the local attention window.

    Arguments:
        seqlen_q, seqlen_k: padded query / key lengths.
        window_size: (left, right); a negative left means the window is unbounded on the left.
        query_padding_mask, key_padding_mask: optional (batch, seqlen) bool masks whose row
            sums give the per-sequence lengths used to align the window.
        device: device on which the index tensors are created.
        key_leftpad: optional (batch,) tensor of left-padding amounts for the keys.
    """
    q_pos = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    k_pos = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        k_pos = repeat(k_pos, "s -> b 1 1 s", b=key_leftpad.shape[0])
        # Keys inside the left padding get a huge sentinel index.
        k_pos = torch.where(k_pos >= key_leftpad, k_pos - key_leftpad, 2**32)

    if key_padding_mask is None:
        sk = seqlen_k
    else:
        sk = rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    if query_padding_mask is None:
        sq = seqlen_q
    else:
        sq = rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")

    left, right = window_size[0], window_size[1]
    if left < 0:
        # Only the right edge of the window constrains the keys.
        return k_pos > q_pos + sk - sq + right

    if key_padding_mask is None:
        # torch.minimum needs a tensor operand, so materialise the scalar length.
        sk = torch.full_like(k_pos, seqlen_k)
    beyond_right = k_pos > torch.minimum(q_pos + sk - sq + right, sk)
    before_left = k_pos < q_pos + sk - sq - left
    return torch.logical_or(beyond_right, before_left)
| |
|
| |
|
def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    """Plain PyTorch attention used as the reference implementation.

    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        window_size: (int, int), left and right window size
        softcap: when > 0, scores are replaced by softcap * tanh(scores / softcap)
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
        key_leftpad: forwarded to construct_local_mask; only consulted when a window
            (or causal) mask is applied.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
    """
    if causal:
        # Causal attention is a window with no look-ahead.
        window_size = (window_size[0], 0)
    use_local_mask = window_size[0] >= 0 or window_size[1] >= 0

    input_dtype = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]

    # Expand the key/value heads so that MQA / GQA inputs line up with the query heads.
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])

    sqrt_d = math.sqrt(q.shape[-1])
    if reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q, k / sqrt_d)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q / sqrt_d, k)

    if softcap > 0:
        scores = (scores / softcap).tanh() * softcap

    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if use_local_mask:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
            key_leftpad=key_leftpad,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias

    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    if use_local_mask:
        # Rows whose every key is outside the window are zeroed after the softmax.
        fully_masked_rows = torch.all(local_mask, dim=-1, keepdim=True)
        attention = attention.masked_fill(fully_masked_rows, 0.0)
    if query_padding_mask is not None:
        # Padded query rows carry no attention weights.
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)

    dropout_scaling = 1.0 / (1 - dropout_p)
    if dropout_mask is None:
        attention_drop = attention
    else:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)

    # The dropout rescaling is applied to v rather than to the attention weights.
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=input_dtype), attention.to(dtype=input_dtype)
| |
|
| |
|
def attention_kvpacked_ref(
    q,
    kv,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    """Reference attention for a packed kv tensor.

    kv is indexed as kv[:, :, 0] for keys and kv[:, :, 1] for values; every other
    argument is handed to attention_ref unchanged.
    """
    keys = kv[:, :, 0]
    values = kv[:, :, 1]
    return attention_ref(
        q,
        keys,
        values,
        query_padding_mask=query_padding_mask,
        key_padding_mask=key_padding_mask,
        attn_bias=attn_bias,
        dropout_p=dropout_p,
        dropout_mask=dropout_mask,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        upcast=upcast,
        reorder_ops=reorder_ops,
        key_leftpad=key_leftpad,
    )
| |
|
| |
|
def attention_qkvpacked_ref(
    qkv,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
):
    """Reference attention for a packed qkv tensor.

    qkv is indexed as qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] for queries, keys and
    values. key_padding_mask is used as both the query and the key padding mask.
    """
    queries = qkv[:, :, 0]
    keys = qkv[:, :, 1]
    values = qkv[:, :, 2]
    return attention_ref(
        queries,
        keys,
        values,
        query_padding_mask=key_padding_mask,
        key_padding_mask=key_padding_mask,
        attn_bias=attn_bias,
        dropout_p=dropout_p,
        dropout_mask=dropout_mask,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        upcast=upcast,
        reorder_ops=reorder_ops,
    )
| |
|
| |
|
def generate_sparsity_mask(seqlen, sparsity=0.3, device="cuda"):
    """Draw a random block mask of shape (seqlen // 16, seqlen // 256).

    Arguments:
        seqlen: sequence length the block mask is sized for.
        sparsity: each entry is True with probability `sparsity` (uniform sample < sparsity).
        device: device on which the mask is created; defaults to "cuda" as before.
    Output:
        bool tensor of shape (seqlen // 16, seqlen // 256).
    """
    nrow, ncol = seqlen // 16, seqlen // 256
    return torch.rand(nrow, ncol, device=device) < sparsity
| |
|
| |
|
def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask):
    """Reference block-sparse attention computed in fp32.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        blockmask: (seqlen / 16, seqlen / 256)
        attn_mask: (batch_size, seqlen)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen, seqlen)
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
        attention: softmax after dropout
    """
    q, k, v = qkv.float().unbind(dim=2)
    head_dim = qkv.shape[-1]
    seqlen = qkv.shape[1]

    # Expand every block-mask entry to a 16 x 256 tile, then crop to the sequence length.
    dense_mask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
    dense_mask = dense_mask[:seqlen, :seqlen]
    blocked = rearrange(~dense_mask, "t s -> 1 1 t s")

    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
    scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
    scores.masked_fill_(blocked, float("-inf"))

    attention = torch.softmax(scores, dim=-1)
    # Zero padded query rows (out of place), then the blocked entries (in place on the copy).
    attention = attention.masked_fill(rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0)
    attention.masked_fill_(blocked, 0.0)

    attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
    output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0)
    return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)
| |
|
| |
|
def convert_flash_attn_S_to_softmax(
    S,
    seqlen_q,
    seqlen_k,
    query_padding_mask,
    key_padding_mask,
    head_dim,
    is_dropout,
    causal=False,
    window_size=(-1, -1),
):
    """Zero the masked entries of FlashAttention's S matrix and crop it to the real lengths.

    Arguments:
        S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
        query_padding_mask: (batch_size, seqlen_q_rounded)
        key_padding_mask: (batch_size, seqlen_k_rounded)
        head_dim, is_dropout: accepted for interface compatibility; not read here.
        causal, window_size: select the local mask whose excluded entries are zeroed.
    Output:
        (batch_size, nheads, seqlen_q, seqlen_k) tensor.
    """
    if causal:
        window_size = (window_size[0], 0)
    rounded_q, rounded_k = S.shape[-2:]
    converted = S

    if window_size[0] >= 0 or window_size[1] >= 0:
        outside_window = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            S.device,
        )
        # Everything in the rounded-up margin counts as masked.
        outside_window = F.pad(
            outside_window,
            (0, rounded_k - seqlen_k, 0, rounded_q - seqlen_q),
            value=True,
        )
        converted = converted.masked_fill(outside_window, 0.0)

    if query_padding_mask is None:
        seqlen_q_og = rounded_q
    else:
        seqlen_q_og = query_padding_mask.shape[-1]
        query_padding_mask = F.pad(query_padding_mask, (0, rounded_q - seqlen_q_og))
        converted = converted.masked_fill(
            rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0
        )

    if key_padding_mask is None:
        seqlen_k_og = seqlen_k
    else:
        seqlen_k_og = key_padding_mask.shape[-1]
        key_padding_mask = F.pad(key_padding_mask, (0, rounded_k - seqlen_k_og))
        converted = converted.masked_fill(
            rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0
        )

    # Non-positive pad amounts: these calls trim the rounded-up rows / columns.
    converted = F.pad(converted, (0, 0, 0, seqlen_q_og - rounded_q))
    converted = F.pad(converted, (0, seqlen_k_og - rounded_k))
    return converted[:, :, :seqlen_q, :seqlen_k]
| |
|
| |
|
def normalize_flash_attn_S(
    attn_unnorm,
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    is_dropout=False,
    causal=False,
    window_size=(-1, -1),
):
    """Rescale block-wise unnormalized attention values into normalized probabilities.

    Arguments:
        attn_unnorm: unnormalized attention, split into key blocks along its last dim.
        q: (batch_size, seqlen_q, nheads, head_dim)
        k, v: (batch_size, seqlen_k, nheads, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        is_dropout, causal: forwarded to _get_block_size_n to pick the key block size.
        window_size: (int, int), left and right window size
    Output:
        attn_norm: same shape and dtype as attn_unnorm. Each key block is multiplied by
            exp(m - lse), where lse is the row-wise log-sum-exp of the scores and m is the
            maximum score over that block and all blocks after it.
    """
    if causal:
        window_size = (window_size[0], 0)
    # v is converted alongside q and k but is not used afterwards.
    q, k, v = q.float(), k.float(), v.float()
    _, seqlen_q, _, head_dim = q.shape
    seqlen_k = k.shape[1]

    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        outside_window = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
        )
        scores.masked_fill_(outside_window, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias.to(dtype=scores.dtype)

    block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal)
    score_blocks = scores.split(block_size_n, dim=-1)

    # Row-wise log-sum-exp, assembled from the per-block values.
    per_block_lse = [torch.logsumexp(blk, dim=-1) for blk in score_blocks]
    lse = torch.logsumexp(torch.stack(per_block_lse, dim=-1), dim=-1)
    # Fully masked rows have lse == -inf; flip to +inf so exp(m - lse) below is 0.
    lse[lse == float("-inf")] = float("inf")

    # Suffix maximum over blocks: max of this block and every block after it.
    per_block_max = torch.stack([torch.amax(blk, dim=-1) for blk in score_blocks], dim=-1)
    suffix_max = torch.cummax(per_block_max.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)

    unnorm_blocks = attn_unnorm.split(block_size_n, dim=-1)
    rescaled = []
    for unnorm_blk, blk_max in zip(unnorm_blocks, suffix_max):
        factor = rearrange(torch.exp(blk_max - lse), "b h s -> b h s 1")
        rescaled.append(unnorm_blk * factor)
    attn_norm = torch.cat(rescaled, dim=-1)

    if query_padding_mask is not None:
        attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    return attn_norm.to(dtype=attn_unnorm.dtype)
| |
|
| |
|
def get_dropout_fraction(
    dropout_mask,
    query_padding_mask=None,
    key_padding_mask=None,
    causal=False,
    window_size=(-1, -1),
):
    """Fraction of valid attention entries that were dropped.

    Arguments:
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        causal: whether causal masking was applied
        window_size: (int, int), left and right window size
    Output:
        0-dim tensor: dropped entries / valid entries, where padded positions and positions
        outside the local window are excluded from both counts. dropout_mask is not modified.
    """
    if causal:
        window_size = (window_size[0], 0)
    # Unpacking also checks that the mask is 4-D.
    _, _, seqlen_q, seqlen_k = dropout_mask.shape

    # Collect every "does not count" mask once, then apply each to both tallies.
    excluded_masks = []
    if query_padding_mask is not None:
        excluded_masks.append(rearrange(~query_padding_mask, "b s -> b 1 s 1"))
    if key_padding_mask is not None:
        excluded_masks.append(rearrange(~key_padding_mask, "b s -> b 1 1 s"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        excluded_masks.append(
            construct_local_mask(
                seqlen_q,
                seqlen_k,
                window_size,
                query_padding_mask,
                key_padding_mask,
                dropout_mask.device,
            )
        )

    dropped = ~dropout_mask  # fresh tensor, safe to modify in place
    valid = torch.ones_like(dropout_mask)
    for excluded in excluded_masks:
        dropped.masked_fill_(excluded, False)
        valid.masked_fill_(excluded, False)
    return dropped.sum() / valid.sum()
| |
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
    """Compare flash_attn_qkvpacked_func against the PyTorch reference (forward and backward).

    The kernel's error relative to the fp32 reference must be at most twice the error of
    the same-precision PyTorch implementation with reordered operations.
    """
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()
    device = "cuda"
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    if local:
        window_size = torch.randint(0, seqlen, (2,))
    else:
        window_size = (-1, -1)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None

    out, _lse, S_dmask = flash_attn_qkvpacked_func(
        qkv,
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    use_dropout = dropout_p > 0.0
    if use_dropout:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            None,
            None,
            d,
            use_dropout,
            causal=causal,
            window_size=window_size,
        )
        # The sign of each entry encodes the dropout decision, the magnitude the value.
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        attn = normalize_flash_attn_S(
            attn_unnorm,
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            None,
            None,
            attn_bias,
            use_dropout,
            causal=causal,
            window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, None, None, causal=causal, window_size=window_size
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    # fp32 reference, and a same-precision PyTorch run with reordered ops as the error yardstick.
    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        None,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if use_dropout:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    # Backward with dropout above MAX_HEADDIM_SM8x is only exercised on sm80 / sm90.
    check_backward = (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90)
    if check_backward:
        (dqkv,) = torch.autograd.grad(out, qkv, g)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")

    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if use_dropout:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if check_backward:
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
| |
|
| |
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_varlen_qkvpacked(
    seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype
):
    """Compare flash_attn_varlen_qkvpacked_func against the PyTorch reference.

    Inputs are padded with a random key padding mask, unpadded via generate_qkv, and the
    kernel's error must be at most twice that of the same-precision PyTorch implementation.
    """
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()
    device = "cuda"
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 6
    if local:
        window_size = torch.randint(0, seqlen, (2,))
    else:
        window_size = (-1, -1)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )

    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None

    # qkv is rebound to the detached leaf returned by generate_qkv.
    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
    )

    out_unpad, _sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
        qkv_unpad,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    out = output_pad_fn(out_unpad)

    use_dropout = dropout_p > 0.0
    if use_dropout:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            key_padding_mask,
            key_padding_mask,
            d,
            use_dropout,
            causal=causal,
            window_size=window_size,
        )
        # The sign of each entry encodes the dropout decision, the magnitude the value.
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        attn = normalize_flash_attn_S(
            attn_unnorm,
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            key_padding_mask,
            key_padding_mask,
            attn_bias,
            use_dropout,
            causal=causal,
            window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, key_padding_mask, key_padding_mask, causal=causal, window_size=window_size
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    # fp32 reference, and a same-precision PyTorch run with reordered ops as the error yardstick.
    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if use_dropout:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    # Backward with dropout above MAX_HEADDIM_SM8x is only exercised on sm80 / sm90.
    check_backward = (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90)
    if check_backward:
        (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
        dqkv = dqkv_pad_fn(dqkv_unpad)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")

    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if use_dropout:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if check_backward:
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
| |
|
| |
|
@pytest.mark.parametrize("kvpacked", [True, False])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize("softcap", [0.0, 50.0])
def test_flash_attn_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
    """Check flash_attn_func / flash_attn_kvpacked_func against the reference attention.

    The flash output (and, when the backward pass is exercised, dQ/dK/dV) is
    compared with ``attention_ref`` / ``attention_kvpacked_ref``.  The allowed
    error is a multiple of the error of a second reference call made with
    ``upcast=False, reorder_ops=True`` (labelled "Pytorch" in the printed
    diagnostics).  With dropout enabled the returned attention probabilities
    and the measured dropout fraction are checked as well.
    """
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()
    if softcap > 0.0 and dropout_p > 0.0:
        pytest.skip("Softcap and dropout not supported together")
    device = "cuda"
    # Gradients are skipped only for d > MAX_HEADDIM_SM8x combined with dropout
    # on GPUs other than sm80 / sm90.  Computed once and reused for both the
    # gradient computation and the gradient assertions.
    test_backward = d <= MAX_HEADDIM_SM8x or dropout_p == 0 or is_sm80 or is_sm90

    def _max_diff(a, b):
        # Largest absolute element-wise difference, as a Python float.
        return (a - b).abs().max().item()

    def _mean_diff(a, b):
        # Mean absolute element-wise difference, as a Python float.
        return (a - b).abs().mean().item()

    # NOTE: the order of the random draws below is part of the test's
    # determinism (fixed seed); do not reorder them.
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 6 if softcap == 0.0 else 4
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if softcap > 0:
        # Scale q by softcap (presumably so the logits are large enough to
        # actually hit the cap -- confirm against the kernel's softcap docs).
        q = q * softcap
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None

    # Forward pass through the kernel under test.
    flash_kwargs = dict(
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    if kvpacked:
        out, _, S_dmask = flash_attn_kvpacked_func(q, kv, dropout_p, **flash_kwargs)
    else:
        out, _, S_dmask = flash_attn_func(q, k, v, dropout_p, **flash_kwargs)

    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen_q,
            seqlen_k,
            None,
            None,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        # The sign of the converted tensor carries the dropout decision and
        # its magnitude the (unnormalized) attention value.
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        # Expand K/V heads so that every query head has a matching K/V head.
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(
            attn_unnorm,
            q,
            k_rep,
            v_rep,
            None,
            None,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, None, None, causal=causal, window_size=window_size
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    # Reference result, plus a second reference run whose error sets the tolerance.
    if kvpacked:
        ref_fn, ref_inputs = attention_kvpacked_ref, (q, kv)
    else:
        ref_fn, ref_inputs = attention_ref, (q, k, v)
    ref_args = (*ref_inputs, None, None, attn_bias, dropout_p, dropout_mask)
    ref_kwargs = dict(causal=causal, window_size=window_size, softcap=softcap)
    out_ref, attn_ref = ref_fn(*ref_args, **ref_kwargs)
    out_pt, attn_pt = ref_fn(*ref_args, **ref_kwargs, upcast=False, reorder_ops=True)

    print(f"Output max diff: {_max_diff(out, out_ref)}")
    print(f"Output mean diff: {_mean_diff(out, out_ref)}")
    print(f"Pytorch max diff: {_max_diff(out_pt, out_ref)}")
    print(f"Pytorch mean diff: {_mean_diff(out_pt, out_ref)}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {_max_diff(attn, attn_ref)}")
        print(f"Attention Pytorch max diff: {_max_diff(attn_pt, attn_ref)}")

    g = torch.randn_like(out)
    if test_backward:

        def _grads(output):
            # Return (dq, dk, dv) of `output` w.r.t. the inputs for upstream gradient g.
            if kvpacked:
                dq_, dkv_ = torch.autograd.grad(output, (q, kv), g)
                return (dq_, *dkv_.unbind(2))
            return torch.autograd.grad(output, (q, k, v), g)

        dq, dk, dv = _grads(out)
        dq_ref, dk_ref, dv_ref = _grads(out_ref)
        dq_pt, dk_pt, dv_pt = _grads(out_pt)
        print(f"dQ max diff: {_max_diff(dq, dq_ref)}")
        print(f"dK max diff: {_max_diff(dk, dk_ref)}")
        print(f"dV max diff: {_max_diff(dv, dv_ref)}")
        print(f"dQ mean diff: {_mean_diff(dq, dq_ref)}")
        print(f"dK mean diff: {_mean_diff(dk, dk_ref)}")
        print(f"dV mean diff: {_mean_diff(dv, dv_ref)}")
        print(f"dQ Pytorch max diff: {_max_diff(dq_pt, dq_ref)}")
        print(f"dK Pytorch max diff: {_max_diff(dk_pt, dk_ref)}")
        print(f"dV Pytorch max diff: {_max_diff(dv_pt, dv_ref)}")
        print(f"dQ Pytorch mean diff: {_mean_diff(dq_pt, dq_ref)}")
        print(f"dK Pytorch mean diff: {_mean_diff(dk_pt, dk_ref)}")
        print(f"dV Pytorch mean diff: {_mean_diff(dv_pt, dv_ref)}")

    # The flash error must stay within 2x the error of the second reference run.
    assert _max_diff(out, out_ref) <= 2 * _max_diff(out_pt, out_ref)

    if dropout_p > 0.0:
        assert _max_diff(attn, attn_ref) <= 2 * _max_diff(attn_pt, attn_ref)
        # The dropout-fraction check is skipped with alibi (presumably the
        # measured fraction is not reliable there -- TODO confirm).
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if test_backward:
        assert _max_diff(dq, dq_ref) <= 3 * _max_diff(dq_pt, dq_ref)
        assert _max_diff(dk, dk_ref) <= 3 * _max_diff(dk_pt, dk_ref)
        assert _max_diff(dv, dv_ref) <= 3 * _max_diff(dv_pt, dv_ref)
| |
|
| |
|
@pytest.mark.parametrize("kvpacked", [True, False])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 147),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize("softcap", [0.0, 50.0])
def test_flash_attn_varlen_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
    """Check the variable-length flash attention entry points against the reference.

    Random query/key padding masks are generated, the inputs are unpadded via
    ``generate_qkv`` and fed to ``flash_attn_varlen_func`` /
    ``flash_attn_varlen_kvpacked_func``.  The re-padded output and gradients
    are compared with ``attention_ref`` / ``attention_kvpacked_ref``; the
    tolerance is a multiple of the error of a second reference call made with
    ``upcast=False, reorder_ops=True``.
    """
    too_big_for_gpu = (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    )
    if too_big_for_gpu:
        pytest.skip()
    if softcap > 0.0 and dropout_p > 0.0:
        pytest.skip("Softcap and dropout not supported together")
    device = "cuda"
    # Whether gradients are computed and asserted for this parameter combination.
    check_grads = (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90)

    def _max_err(a, b):
        # Largest absolute element-wise difference.
        return (a - b).abs().max().item()

    def _mean_err(a, b):
        # Mean absolute element-wise difference.
        return (a - b).abs().mean().item()

    # Random draws happen in a fixed order after seeding; keep that order.
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 4 if softcap != 0.0 else 6
    nheads_k = {"mha": nheads, "mqa": 1}.get(mha_type, 2)
    assert nheads % nheads_k == 0
    window_size = torch.randint(0, seqlen_k, (2,)) if local else (-1, -1)
    leaf_opts = dict(device=device, dtype=dtype, requires_grad=True)
    q = torch.randn(batch_size, seqlen_q, nheads, d, **leaf_opts)
    if softcap > 0:
        q = q * softcap

    if kvpacked:
        kv = torch.randn(batch_size, seqlen_k, 2, nheads_k, d, **leaf_opts)
    else:
        k = torch.randn(batch_size, seqlen_k, nheads_k, d, **leaf_opts)
        v = torch.randn(batch_size, seqlen_k, nheads_k, d, **leaf_opts)

    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes = None
        attn_bias = None

    # Keyword arguments shared by both varlen kernel entry points.
    kernel_kwargs = dict(
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    if kvpacked:
        (
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            kv,
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        ) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)
        out_unpad, _, S_dmask = flash_attn_varlen_kvpacked_func(
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            **kernel_kwargs,
        )
    else:
        (
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            k,
            v,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
        out_unpad, _, S_dmask = flash_attn_varlen_func(
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            **kernel_kwargs,
        )
    out = output_pad_fn(out_unpad)

    dropout_mask = None
    if dropout_p > 0.0:
        mask_kwargs = dict(causal=causal, window_size=window_size)
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen_q,
            seqlen_k,
            query_padding_mask,
            key_padding_mask,
            d,
            dropout_p > 0.0,
            **mask_kwargs,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        # Replicate K/V heads so each query head has its own K/V head.
        heads_per_group = nheads // nheads_k
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=heads_per_group)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=heads_per_group)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=heads_per_group)
        attn = normalize_flash_attn_S(
            attn_unnorm,
            q,
            k_rep,
            v_rep,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p > 0.0,
            **mask_kwargs,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, query_padding_mask, key_padding_mask, **mask_kwargs
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")

    # Reference output and the second reference run that sets the tolerance.
    if kvpacked:
        reference, tensors = attention_kvpacked_ref, (q, kv)
    else:
        reference, tensors = attention_ref, (q, k, v)
    positional = (
        *tensors,
        query_padding_mask,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
    )
    shared = dict(causal=causal, window_size=window_size, softcap=softcap)
    out_ref, attn_ref = reference(*positional, **shared)
    out_pt, attn_pt = reference(*positional, **shared, upcast=False, reorder_ops=True)

    print(f"Output max diff: {_max_err(out, out_ref)}")
    print(f"Output mean diff: {_mean_err(out, out_ref)}")
    print(f"Pytorch max diff: {_max_err(out_pt, out_ref)}")
    print(f"Pytorch mean diff: {_mean_err(out_pt, out_ref)}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {_max_err(attn, attn_ref)}")
        print(f"Attention Pytorch max diff: {_max_err(attn_pt, attn_ref)}")

    g = torch.randn_like(out)
    if check_grads:

        def _padded_grads(output):
            # (dq, dk, dv) of a reference output w.r.t. the padded inputs.
            if kvpacked:
                dq_, dkv_ = torch.autograd.grad(output, (q, kv), g)
                return (dq_, *dkv_.unbind(2))
            return torch.autograd.grad(output, (q, k, v), g)

        if kvpacked:
            dq_unpad, dkv_unpad = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
            dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
        else:
            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(
                out, (q_unpad, k_unpad, v_unpad), g
            )
            dk = dk_pad_fn(dk_unpad)
            # dV is re-padded with the same function as dK.
            dv = dk_pad_fn(dv_unpad)
        dq_ref, dk_ref, dv_ref = _padded_grads(out_ref)
        dq_pt, dk_pt, dv_pt = _padded_grads(out_pt)
        dq = dq_pad_fn(dq_unpad)
        print(f"dQ max diff: {_max_err(dq, dq_ref)}")
        print(f"dK max diff: {_max_err(dk, dk_ref)}")
        print(f"dV max diff: {_max_err(dv, dv_ref)}")
        print(f"dQ mean diff: {_mean_err(dq, dq_ref)}")
        print(f"dK mean diff: {_mean_err(dk, dk_ref)}")
        print(f"dV mean diff: {_mean_err(dv, dv_ref)}")
        print(f"dQ Pytorch max diff: {_max_err(dq_pt, dq_ref)}")
        print(f"dK Pytorch max diff: {_max_err(dk_pt, dk_ref)}")
        print(f"dV Pytorch max diff: {_max_err(dv_pt, dv_ref)}")
        print(f"dQ Pytorch mean diff: {_mean_err(dq_pt, dq_ref)}")
        print(f"dK Pytorch mean diff: {_mean_err(dk_pt, dk_ref)}")
        print(f"dV Pytorch mean diff: {_mean_err(dv_pt, dv_ref)}")

    # Output error is bounded by twice the error of the second reference run.
    assert _max_err(out, out_ref) <= 2 * _max_err(out_pt, out_ref)

    if dropout_p > 0.0:
        assert _max_err(attn, attn_ref) <= 2 * _max_err(attn_pt, attn_ref)
        # The dropout-fraction check only runs without alibi.
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)

    if check_grads:
        assert _max_err(dq, dq_ref) <= 3 * _max_err(dq_pt, dq_ref)
        assert _max_err(dk, dk_ref) <= 3 * _max_err(dk_pt, dk_ref)
        assert _max_err(dv, dv_ref) <= 3 * _max_err(dv_pt, dv_ref)
| |
|
| |
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
    """Check causal flash_attn_func (output and gradients) when seqlen_q != seqlen_k.

    ``swap_sq_sk`` exchanges the two sequence lengths so both the
    "fewer queries than keys" and "more queries than keys" shapes are covered.
    Results are compared with ``attention_ref``; the tolerance is twice the
    error of a second ``attention_ref`` call made with
    ``upcast=False, reorder_ops=True``, plus a small absolute slack.
    """
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True

    def _max_diff(a, b):
        # Largest absolute element-wise difference, as a Python float.
        return (a - b).abs().max().item()

    def _mean_diff(a, b):
        # Mean absolute element-wise difference, as a Python float.
        return (a - b).abs().mean().item()

    # Fixed seed; the order of the random draws below must not change.
    torch.random.manual_seed(0)
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)
    # Only the outputs of the reference calls are used; the attention
    # matrices they also return are discarded.
    out_ref, _ = attention_ref(
        q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size
    )
    out_pt, _ = attention_ref(
        q,
        k,
        v,
        None,
        None,
        None,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )

    print(f"Output max diff: {_max_diff(out, out_ref)}")
    print(f"Output mean diff: {_mean_diff(out, out_ref)}")
    print(f"Pytorch max diff: {_max_diff(out_pt, out_ref)}")
    print(f"Pytorch mean diff: {_mean_diff(out_pt, out_ref)}")

    # Backpropagate the same random upstream gradient through all three outputs.
    g = torch.randn_like(out)
    dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
    dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
    dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
    print(f"dQ max diff: {_max_diff(dq, dq_ref)}")
    print(f"dK max diff: {_max_diff(dk, dk_ref)}")
    print(f"dV max diff: {_max_diff(dv, dv_ref)}")
    print(f"dQ mean diff: {_mean_diff(dq, dq_ref)}")
    print(f"dK mean diff: {_mean_diff(dk, dk_ref)}")
    print(f"dV mean diff: {_mean_diff(dv, dv_ref)}")
    print(f"dQ Pytorch max diff: {_max_diff(dq_pt, dq_ref)}")
    print(f"dK Pytorch max diff: {_max_diff(dk_pt, dk_ref)}")
    print(f"dV Pytorch max diff: {_max_diff(dv_pt, dv_ref)}")
    print(f"dQ Pytorch mean diff: {_mean_diff(dq_pt, dq_ref)}")
    print(f"dK Pytorch mean diff: {_mean_diff(dk_pt, dk_ref)}")
    print(f"dV Pytorch mean diff: {_mean_diff(dv_pt, dv_ref)}")

    # The 1e-5 slack keeps the bound meaningful when the second reference
    # run's error is (near) zero.
    assert _max_diff(out, out_ref) <= 2 * _max_diff(out_pt, out_ref) + 1e-5

    assert _max_diff(dq, dq_ref) <= 2 * _max_diff(dq_pt, dq_ref) + 1e-5
    assert _max_diff(dk, dk_ref) <= 2 * _max_diff(dk_pt, dk_ref) + 1e-5
    assert _max_diff(dv, dv_ref) <= 2 * _max_diff(dv_pt, dv_ref) + 1e-5
| |
|
| |
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
def test_flash_attn_varlen_causal(
    seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
):
    """Check causal flash_attn_varlen_func, optionally with a paged KV cache.

    With ``paged_kv_block_size`` set, K/V come from ``_generate_block_kvcache``
    and the kernel receives the paged cache tensors plus a ``block_table``; in
    that mode only the forward output is checked.  Otherwise K/V are ordinary
    random tensors and gradients are checked as well.  Results are compared
    with ``attention_ref``; the tolerance is twice the error of a second
    ``attention_ref`` call made with ``upcast=False, reorder_ops=True``, plus
    a small absolute slack.
    """
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True

    def _max_diff(a, b):
        # Largest absolute element-wise difference, as a Python float.
        return (a - b).abs().max().item()

    def _mean_diff(a, b):
        # Mean absolute element-wise difference, as a Python float.
        return (a - b).abs().mean().item()

    # Fixed seed; the order of the random draws below must not change.
    torch.random.manual_seed(0)
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)

    if paged_kv_block_size is None:
        k = torch.randn(
            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
        )
        block_table = None
    else:
        # The last returned value (number of blocks) is not needed here.
        k, v, block_table, k_cache_paged, v_cache_paged, _ = _generate_block_kvcache(
            seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype
        )
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    # In paged mode the kernel reads K/V from the paged cache tensors instead
    # of the unpadded K/V.
    out_unpad = flash_attn_varlen_func(
        q_unpad,
        k_unpad if paged_kv_block_size is None else k_cache_paged,
        v_unpad if paged_kv_block_size is None else v_cache_paged,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
        window_size=window_size,
        block_table=block_table,
    )
    out = output_pad_fn(out_unpad)
    # Only the outputs of the reference calls are used; the attention
    # matrices they also return are discarded.
    out_ref, _ = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        None,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
    )
    out_pt, _ = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        None,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )

    print(f"Output max diff: {_max_diff(out, out_ref)}")
    print(f"Output mean diff: {_mean_diff(out, out_ref)}")
    print(f"Pytorch max diff: {_max_diff(out_pt, out_ref)}")
    print(f"Pytorch mean diff: {_mean_diff(out_pt, out_ref)}")

    g = torch.randn_like(out)
    # Gradients are only checked when no block table (paged KV cache) is used.
    test_backward = block_table is None
    if test_backward:
        dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
        dq = dq_pad_fn(dq_unpad)
        dk = dk_pad_fn(dk_unpad)
        # dV is re-padded with the same function as dK.
        dv = dk_pad_fn(dv_unpad)
        dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
        dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {_max_diff(dq, dq_ref)}")
        print(f"dK max diff: {_max_diff(dk, dk_ref)}")
        print(f"dV max diff: {_max_diff(dv, dv_ref)}")
        print(f"dQ mean diff: {_mean_diff(dq, dq_ref)}")
        print(f"dK mean diff: {_mean_diff(dk, dk_ref)}")
        print(f"dV mean diff: {_mean_diff(dv, dv_ref)}")
        print(f"dQ Pytorch max diff: {_max_diff(dq_pt, dq_ref)}")
        print(f"dK Pytorch max diff: {_max_diff(dk_pt, dk_ref)}")
        print(f"dV Pytorch max diff: {_max_diff(dv_pt, dv_ref)}")
        print(f"dQ Pytorch mean diff: {_mean_diff(dq_pt, dq_ref)}")
        print(f"dK Pytorch mean diff: {_mean_diff(dk_pt, dk_ref)}")
        print(f"dV Pytorch mean diff: {_mean_diff(dv_pt, dv_ref)}")

    # The 1e-5 slack keeps the bound meaningful when the second reference
    # run's error is (near) zero.
    assert _max_diff(out, out_ref) <= 2 * _max_diff(out_pt, out_ref) + 1e-5

    if test_backward:
        assert _max_diff(dq, dq_ref) <= 2 * _max_diff(dq_pt, dq_ref) + 1e-5
        assert _max_diff(dk, dk_ref) <= 2 * _max_diff(dk_pt, dk_ref) + 1e-5
        assert _max_diff(dv, dv_ref) <= 2 * _max_diff(dv_pt, dv_ref) + 1e-5
| |
|
| |
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (3, 1024),
        (1, 339),
        (64, 800),
        (3, 799),
        (64, 2048),
        (16, 20000),
        (16, 100000),
        (128, 128),
        (256, 256),
    ],
)
def test_flash_attn_splitkv(
    seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype
):
    """Check flash_attn_func with batch size 1 and very unbalanced sequence lengths.

    The shapes pair a handful of queries with up to 100000 keys (or the
    reverse when ``swap_sq_sk`` is set) -- presumably the regime in which the
    kernel splits work along the key dimension, as the test name suggests;
    that is not verifiable from this test alone.  Output and gradients are
    compared with ``attention_ref``; the tolerance is a multiple of the error
    of a second ``attention_ref`` call made with
    ``upcast=False, reorder_ops=True``, plus an absolute slack.
    """
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"

    def _max_diff(a, b):
        # Largest absolute element-wise difference, as a Python float.
        return (a - b).abs().max().item()

    def _mean_diff(a, b):
        # Mean absolute element-wise difference, as a Python float.
        return (a - b).abs().mean().item()

    # Fixed seed; the order of the random draws below must not change.
    torch.random.manual_seed(0)
    batch_size = 1
    nheads = 12
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    # return_attn_probs=True makes the call return a 3-tuple; only the
    # attention output is needed here.
    out, _, _ = flash_attn_func(
        q,
        k,
        v,
        0.0,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    # Only the outputs of the reference calls are used; the attention
    # matrices they also return are discarded.
    out_ref, _ = attention_ref(
        q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size
    )
    out_pt, _ = attention_ref(
        q,
        k,
        v,
        None,
        None,
        attn_bias,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )

    print(f"Output max diff: {_max_diff(out, out_ref)}")
    print(f"Output mean diff: {_mean_diff(out, out_ref)}")
    print(f"Pytorch max diff: {_max_diff(out_pt, out_ref)}")
    print(f"Pytorch mean diff: {_mean_diff(out_pt, out_ref)}")

    # Backpropagate the same random upstream gradient through all three outputs.
    g = torch.randn_like(out)
    dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
    dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
    dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
    print(f"dQ max diff: {_max_diff(dq, dq_ref)}")
    print(f"dK max diff: {_max_diff(dk, dk_ref)}")
    print(f"dV max diff: {_max_diff(dv, dv_ref)}")
    print(f"dQ mean diff: {_mean_diff(dq, dq_ref)}")
    print(f"dK mean diff: {_mean_diff(dk, dk_ref)}")
    print(f"dV mean diff: {_mean_diff(dv, dv_ref)}")
    print(f"dQ Pytorch max diff: {_max_diff(dq_pt, dq_ref)}")
    print(f"dK Pytorch max diff: {_max_diff(dk_pt, dk_ref)}")
    print(f"dV Pytorch max diff: {_max_diff(dv_pt, dv_ref)}")
    print(f"dQ Pytorch mean diff: {_mean_diff(dq_pt, dq_ref)}")
    print(f"dK Pytorch mean diff: {_mean_diff(dk_pt, dk_ref)}")
    print(f"dV Pytorch mean diff: {_mean_diff(dv_pt, dv_ref)}")

    # The 1e-5 slack keeps the bound meaningful when the second reference
    # run's error is (near) zero.
    assert _max_diff(out, out_ref) <= 2 * _max_diff(out_pt, out_ref) + 1e-5

    # Gradient tolerance is looser when alibi is enabled.
    mult = 2 if not alibi else 8
    assert _max_diff(dq, dq_ref) <= mult * _max_diff(dq_pt, dq_ref) + 2e-4
    assert _max_diff(dk, dk_ref) <= mult * _max_diff(dk_pt, dk_ref) + 2e-4
    assert _max_diff(dv, dv_ref) <= mult * _max_diff(dv_pt, dv_ref) + 2e-4
| |
|
| |
|
| | |
| | @pytest.mark.parametrize("dtype", [torch.float16]) |
| | @pytest.mark.parametrize("num_splits", [1, 0]) |
| | |
| | @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) |
| | |
| | @pytest.mark.parametrize("new_kv", [False, True]) |
| | |
| | @pytest.mark.parametrize("alibi", [False, True]) |
| | |
| | @pytest.mark.parametrize("local", [False, True]) |
| | |
| | @pytest.mark.parametrize("causal", [False, True]) |
| | |
| | @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) |
| | |
| | @pytest.mark.parametrize("rotary_interleaved", [False, True]) |
| | |
| | @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) |
| | |
| | @pytest.mark.parametrize("paged_kv_block_size", [None, 256]) |
| | |
| | |
| | @pytest.mark.parametrize("has_leftpad", [False, True]) |
| | |
| | |
| | @pytest.mark.parametrize("has_batch_idx", [False]) |
| | @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) |
| | |
| | |
| | |
| | |
| | @pytest.mark.parametrize( |
| | "seqlen_q,seqlen_k", |
| | [ |
| | (1, 128), |
| | (1, 339), |
| | (3, 1024), |
| | (64, 800), |
| | (64, 256), |
| | (3, 799), |
| | (64, 2048), |
| | (16, 20000), |
| | (1, 128 * 1024), |
| | (16, 128 * 1024), |
| | (128, 128), |
| | ], |
| | ) |
| | |
def test_flash_attn_kvcache(
    seqlen_q,
    seqlen_k,
    d,
    has_batch_idx,
    has_leftpad,
    paged_kv_block_size,
    rotary_fraction,
    rotary_interleaved,
    seqlen_new_eq_seqlen_q,
    causal,
    local,
    alibi,
    new_kv,
    mha_type,
    num_splits,
    dtype,
):
    """Compare ``flash_attn_with_kvcache`` against ``attention_ref`` on a random KV cache.

    A reference cache is built by cloning the (optionally batch-indexed or paged)
    cache and, when ``new_kv`` is set, writing the rotary-embedded new keys and the
    new values at positions ``[cache_seqlens, cache_seqlens + seqlen_new)``. The
    kernel output must stay within ``mult`` times the error of the non-upcast
    reference (plus 1e-5), and with ``new_kv`` the cache seen after the call must
    match the reference cache.
    """
    paged = paged_kv_block_size is not None
    # Parameter combinations that this test does not exercise.
    if new_kv and seqlen_q > seqlen_k:
        pytest.skip()
    if rotary_fraction > 0.0 and not new_kv:
        pytest.skip()
    if paged and (has_batch_idx or has_leftpad):
        pytest.skip()

    device = "cuda"
    torch.random.manual_seed(0)
    batch_size = 2
    batch_size_cache = batch_size * 2 if has_batch_idx else batch_size
    nheads = 6
    # Round the rotary dimension down to a multiple of 16.
    rotary_dim = (int(rotary_fraction * d) // 16) * 16
    # "mha" keeps every head, "mqa" uses a single KV head, anything else uses 3.
    nheads_k = {"mha": nheads, "mqa": 1}.get(mha_type, 3)
    assert nheads % nheads_k == 0
    window_size = torch.randint(0, seqlen_k, (2,)) if local else (-1, -1)
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    if seqlen_new_eq_seqlen_q:
        seqlen_new = seqlen_q
    else:
        seqlen_new = torch.randint(1, seqlen_q + 1, (1,)).item()

    k = v = None
    if new_kv:
        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)

    if not paged:
        k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
        v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
        block_table = None
    else:
        k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks = (
            _generate_block_kvcache(
                seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
            )
        )

    # Draw the already-cached length per sequence. With new_kv, room is left for
    # the appended tokens (seqlen_q of them when causal/local and rotary is on --
    # presumably so the cos/sin tables cover every query position; confirm).
    if new_kv:
        reserved = seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new
        seqlens_low, seqlens_high = 0, seqlen_k - reserved + 1
    else:
        seqlens_low, seqlens_high = 1, seqlen_k + 1
    cache_seqlens = torch.randint(
        seqlens_low, seqlens_high, (batch_size,), dtype=torch.int32, device=device
    )

    cache_leftpad = None
    if has_leftpad:
        leftpads = []
        for i in range(batch_size):
            cached_len = cache_seqlens[i].item()
            if cached_len > 0:
                leftpads.append(
                    torch.randint(0, cached_len, (1,), dtype=torch.int32, device=device)
                )
            else:
                # randint needs a non-empty range, so an empty cache gets zero padding.
                leftpads.append(torch.zeros(1, dtype=torch.int32, device=device))
        cache_leftpad = torch.cat(leftpads)

    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    appended = seqlen_new if new_kv else 0
    key_padding_mask = arange < cache_seqlens_expanded + appended
    if has_leftpad:
        past_leftpad = arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
        key_padding_mask = torch.logical_and(key_padding_mask, past_leftpad)

    cache_batch_idx = None
    if has_batch_idx:
        shuffled = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)
        cache_batch_idx = shuffled[:batch_size]

    alibi_slopes = attn_bias = None
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes,
            seqlen_q,
            seqlen_k,
            None,
            key_padding_mask,
            causal=causal,
            key_leftpad=cache_leftpad,
        )

    cos = sin = None
    q_ro, k_ro = q, k
    if rotary_dim > 0:
        table_len = num_blocks * paged_kv_block_size if paged else seqlen_k
        angle = torch.rand(table_len, rotary_dim // 2, device=device) * 2 * math.pi
        cos = torch.cos(angle).to(dtype=dtype)
        sin = torch.sin(angle).to(dtype=dtype)
        rotary_kwargs = dict(seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
        if causal or local:
            q_ro = apply_rotary_emb(q, cos, sin, **rotary_kwargs)
        else:
            # Fold the query positions into the head axis so the rotary call sees a
            # single sequence position per batch entry, then unfold again.
            q_folded = rearrange(q, "b s h d -> b 1 (s h) d")
            q_ro = rearrange(
                apply_rotary_emb(q_folded, cos, sin, **rotary_kwargs),
                "b 1 (s h) d -> b s h d",
                s=seqlen_q,
            )
        k_ro = apply_rotary_emb(k, cos, sin, **rotary_kwargs)

    def pick_cache_rows(cache):
        # Rows of the cache addressed by this batch (all rows without cache_batch_idx).
        if not has_batch_idx:
            return cache
        return cache[cache_batch_idx.to(dtype=torch.long)]

    # Clone before the kernel call so the reference is built from the untouched cache.
    k_cache_ref = pick_cache_rows(k_cache).clone()
    v_cache_ref = pick_cache_rows(v_cache).clone()
    if new_kv:
        update_mask = torch.logical_and(
            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
        )
        k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
    group_size = nheads // nheads_k
    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=group_size)
    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=group_size)

    out = flash_attn_with_kvcache(
        q,
        k_cache_paged if paged else k_cache,
        v_cache_paged if paged else v_cache,
        k,
        v,
        rotary_cos=cos,
        rotary_sin=sin,
        cache_seqlens=cache_seqlens,
        cache_batch_idx=cache_batch_idx,
        cache_leftpad=cache_leftpad,
        block_table=block_table,
        causal=causal,
        window_size=window_size,
        rotary_interleaved=rotary_interleaved,
        alibi_slopes=alibi_slopes,
        num_splits=num_splits,
    )

    # Both references share inputs; the second disables upcasting and reorders ops.
    ref_args = (q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, attn_bias, 0.0, None)
    ref_kwargs = dict(causal=causal, window_size=window_size, key_leftpad=cache_leftpad)
    out_ref, _ = attention_ref(*ref_args, **ref_kwargs)
    out_pt, _ = attention_ref(*ref_args, upcast=False, reorder_ops=True, **ref_kwargs)

    out_err = (out - out_ref).abs()
    pt_err = (out_pt - out_ref).abs()
    print(f"Output max diff: {out_err.max().item()}")
    print(f"Output mean diff: {out_err.mean().item()}")
    print(f"Pytorch max diff: {pt_err.max().item()}")
    print(f"Pytorch mean diff: {pt_err.mean().item()}")

    if new_kv:
        # Read the cache back after the call and compare it with the reference cache.
        if not paged:
            k_cache_select = pick_cache_rows(k_cache)
            v_cache_select = pick_cache_rows(v_cache)
        else:
            flat_blocks = block_table.to(dtype=torch.long).flatten()

            def gather_paged(pool):
                return rearrange(
                    pool[flat_blocks],
                    "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                    b=batch_size,
                )[:, :seqlen_k]

            k_cache_select = gather_paged(k_cache_paged)
            v_cache_select = gather_paged(v_cache_paged)
        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
        assert torch.equal(v_cache_select, v_cache_ref)

    mult = 5 if alibi else 3
    assert out_err.max().item() <= mult * pt_err.max().item() + 1e-5
| |
|
| |
|
| | def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype): |
| | num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3 |
| | k_cache_paged = torch.randn( |
| | num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype |
| | ) |
| | v_cache_paged = torch.randn( |
| | num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype |
| | ) |
| | block_table = rearrange( |
| | torch.randperm(num_blocks, dtype=torch.int32, device=device), |
| | "(b nblocks) -> b nblocks", |
| | b=batch_size, |
| | ) |
| | k_cache = rearrange( |
| | |
| | k_cache_paged[block_table.to(dtype=torch.long).flatten()], |
| | "(b nblocks) block_size ... -> b (nblocks block_size) ...", |
| | b=batch_size, |
| | )[:, :seqlen_k] |
| | v_cache = rearrange( |
| | v_cache_paged[block_table.to(dtype=torch.long).flatten()], |
| | "(b nblocks) block_size ... -> b (nblocks block_size) ...", |
| | b=batch_size, |
| | )[:, :seqlen_k] |
| | return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks |
| |
|
| |
|
| | |
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (239, 1),
        (3, 799),
        (799, 3),
        (1024, 128),
        (97, 97),
        (128, 128),
        (200, 200),
        (256, 256),
        (257, 257),
        (384, 384),
        (512, 512),
        (768, 768),
        (1024, 1024),
    ],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
    """Re-run the same seeded forward/backward 250 times and require stable results.

    The output, the log-sum-exp, dK and dV must be bitwise identical to the first
    run; dQ only has to match within ``dq_atol``.
    """
    device = "cuda"
    torch.random.manual_seed(0)
    batch_size = 60
    nheads = 4
    tensor_opts = dict(device=device, dtype=dtype, requires_grad=True)
    q = torch.randn(batch_size, seqlen_q, nheads, d, **tensor_opts)
    k = torch.randn(batch_size, seqlen_k, nheads, d, **tensor_opts)
    v = torch.randn(batch_size, seqlen_k, nheads, d, **tensor_opts)

    # The backward pass is only checked for this set of configurations
    # (presumably the others are unsupported by the kernel -- confirm).
    check_backward = d <= MAX_HEADDIM_SM8x or dropout_p == 0 or is_sm80 or is_sm90

    def seeded_forward():
        # Reseed before every call so each run starts from the same RNG state.
        torch.random.manual_seed(42)
        return flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)

    out0, lse0, _ = seeded_forward()
    g = torch.randn_like(out0)
    if check_backward:
        dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g)
        # Tolerance derived from the rounding error of simple arithmetic on dq0.
        dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()

    for i in range(250):
        out, lse, _ = seeded_forward()
        assert torch.equal(out, out0)
        assert torch.equal(lse, lse0)
        if not check_backward:
            continue
        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
        dq_equal = torch.allclose(dq, dq0, atol=dq_atol)
        if not dq_equal:
            print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}")
        assert torch.equal(dv, dv0)
        assert torch.equal(dk, dk0)
        assert dq_equal
| |
|
| |
|
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [16, 32, 64])
@pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128])
def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
    """Regression test: elements beyond seqlen_k used to go unmasked, which produced
    NaN in dQ whenever seqlen % 128 != 0.
    """
    device = "cuda"
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 5
    shape = [batch_size, seqlen, nheads, d]
    # Inputs are scaled up (x5 for q, x3 for k and v) before gradients are enabled.
    q = torch.randn(shape, dtype=dtype, device=device) * 5
    k = torch.randn(shape, dtype=dtype, device=device) * 3
    v = torch.randn(shape, dtype=dtype, device=device) * 3
    for tensor in (q, k, v):
        tensor.requires_grad_(True)

    out = flash_attn_func(q, k, v, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)

    def fresh_leaves():
        # Independent leaf copies so each implementation accumulates its own grads.
        return [t.detach().clone().requires_grad_(True) for t in (q, k, v)]

    def max_abs_diff(a, b):
        return (a - b).abs().max().item()

    q_pt, k_pt, v_pt = fresh_leaves()
    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt.backward(g)

    q_ref, k_ref, v_ref = fresh_leaves()
    out_ref, _ = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref.backward(g)

    dq_err = max_abs_diff(q.grad, q_ref.grad)
    dk_err = max_abs_diff(k.grad, k_ref.grad)
    dv_err = max_abs_diff(v.grad, v_ref.grad)
    dq_pt_err = max_abs_diff(q_pt.grad, q_ref.grad)
    dk_pt_err = max_abs_diff(k_pt.grad, k_ref.grad)
    dv_pt_err = max_abs_diff(v_pt.grad, v_ref.grad)
    print(f"dQ max diff: {dq_err}")
    print(f"dK max diff: {dk_err}")
    print(f"dV max diff: {dv_err}")
    print(f"dQ Pytorch max diff: {dq_pt_err}")
    print(f"dK Pytorch max diff: {dk_pt_err}")
    print(f"dV Pytorch max diff: {dv_pt_err}")

    # The kernel error is bounded by a multiple of the non-upcast reference error.
    assert max_abs_diff(out, out_ref) <= 2 * max_abs_diff(out_pt, out_ref)
    assert dq_err <= 5 * dq_pt_err + 1e-3
    assert dk_err <= 5 * dk_pt_err + 1e-3
    assert dv_err <= 5 * dv_pt_err + 1e-3
| |
|
| |
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [64, 128])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256])
def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
    """Regression test: the backward pass used to read dout with the wrong strides,
    which only shows up when dout is not contiguous.
    """
    device = "cuda"
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 2
    shape = [batch_size, seqlen, nheads, d]
    q = torch.randn(shape, dtype=dtype, device=device, requires_grad=True)
    k = torch.randn(shape, dtype=dtype, device=device, requires_grad=True)
    v = torch.randn(shape, dtype=dtype, device=device, requires_grad=True)

    def seq_first(t):
        # Swap the batch and sequence axes of an attention output.
        return rearrange(t, "b s ... -> s b ...")

    def fresh_leaves():
        return [t.detach().clone().requires_grad_(True) for t in (q, k, v)]

    def max_abs_diff(a, b):
        return (a - b).abs().max().item()

    out = seq_first(flash_attn_func(q, k, v, causal=causal))
    # Keep every other batch entry of a double-width tensor: a non-contiguous dout.
    g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device=device)[:, ::2]
    out.backward(g)

    q_pt, k_pt, v_pt = fresh_leaves()
    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt = seq_first(out_pt)
    out_pt.backward(g)

    q_ref, k_ref, v_ref = fresh_leaves()
    out_ref, _ = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref = seq_first(out_ref)
    out_ref.backward(g)

    dq_err = max_abs_diff(q.grad, q_ref.grad)
    dk_err = max_abs_diff(k.grad, k_ref.grad)
    dv_err = max_abs_diff(v.grad, v_ref.grad)
    dq_pt_err = max_abs_diff(q_pt.grad, q_ref.grad)
    dk_pt_err = max_abs_diff(k_pt.grad, k_ref.grad)
    dv_pt_err = max_abs_diff(v_pt.grad, v_ref.grad)
    print(f"dQ max diff: {dq_err}")
    print(f"dK max diff: {dk_err}")
    print(f"dV max diff: {dv_err}")
    print(f"dQ Pytorch max diff: {dq_pt_err}")
    print(f"dK Pytorch max diff: {dk_pt_err}")
    print(f"dV Pytorch max diff: {dv_pt_err}")

    # The kernel error may be at most twice the non-upcast reference error.
    assert max_abs_diff(out, out_ref) <= 2 * max_abs_diff(out_pt, out_ref)
    assert dq_err <= 2 * dq_pt_err
    assert dk_err <= 2 * dk_pt_err
    assert dv_err <= 2 * dv_pt_err
| |
|
| |
|
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [16, 32, 64])
def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
    """Regression test: elements beyond seqlen_k used to go unmasked, which produced
    NaN in dQ when seqlen % 128 != 0 or with variable-length inputs.
    """
    device = "cuda"
    torch.random.manual_seed(0)
    nheads = 5
    # Cumulative sequence boundaries: three query sequences covering 256 rows and
    # three single-row key sequences.
    q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32)
    k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32)
    total_q = 256
    total_k = 3

    q = torch.randn([total_q, nheads, d], dtype=dtype, device=device) * 3
    k = torch.randn([total_k, nheads, d], dtype=dtype, device=device) * 3
    v = torch.randn([total_k, nheads, d], dtype=dtype, device=device) * 3
    for tensor in (q, k, v):
        tensor.requires_grad_(True)

    # The row totals are passed in the two arguments that follow the cu_seqlens.
    out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, total_q, total_k, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)

    for tensor in (q, k, v):
        assert not tensor.grad.isnan().any()
| |
|
| |
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
    """With ``deterministic=True`` the backward pass must give bitwise-identical
    dQ, dK and dV on every one of 50 repeated runs.
    """
    # Long sequences are skipped on devices with at most 16 GiB of memory.
    if max(seqlen_q, seqlen_k) >= 2048:
        total_memory = torch.cuda.get_device_properties("cuda").total_memory
        if total_memory <= 16 * 2**30:
            pytest.skip()
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q

    device = "cuda"
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    window_size = torch.randint(0, seqlen_k, (2,)) if local else (-1, -1)
    tensor_opts = dict(device=device, dtype=dtype, requires_grad=True)
    q = torch.randn(batch_size, seqlen_q, nheads, d, **tensor_opts)
    k = torch.randn(batch_size, seqlen_k, nheads, d, **tensor_opts)
    v = torch.randn(batch_size, seqlen_k, nheads, d, **tensor_opts)
    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)

    g = torch.randn_like(out)

    def backward():
        # retain_graph lets the same graph be differentiated repeatedly.
        return torch.autograd.grad(out, (q, k, v), g, retain_graph=True)

    dq0, dk0, dv0 = backward()
    for _ in range(50):
        dq, dk, dv = backward()
        assert torch.equal(dv, dv0)
        assert torch.equal(dk, dk0)
        assert torch.equal(dq, dq0)
| |
|
| |
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
    """Variable-length counterpart of the determinism check: with
    ``deterministic=True`` the gradients w.r.t. the unpadded q, k and v must be
    bitwise identical on every one of 50 repeated backward runs.
    """
    # Long sequences are skipped on devices with at most 16 GiB of memory.
    if max(seqlen_q, seqlen_k) >= 2048:
        total_memory = torch.cuda.get_device_properties("cuda").total_memory
        if total_memory <= 16 * 2**30:
            pytest.skip()
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q

    device = "cuda"
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 9
    window_size = torch.randint(0, seqlen_k, (2,)) if local else (-1, -1)
    tensor_opts = dict(device=device, dtype=dtype, requires_grad=True)
    q = torch.randn(batch_size, seqlen_q, nheads, d, **tensor_opts)
    k = torch.randn(batch_size, seqlen_k, nheads, d, **tensor_opts)
    v = torch.randn(batch_size, seqlen_k, nheads, d, **tensor_opts)
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")

    # Only the unpadded tensors and the sequence-length metadata are used below;
    # the remaining return values are unpacked to keep the 13-item arity check.
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        _output_pad_fn,
        _dq_pad_fn,
        _dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)

    out = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
        window_size=window_size,
        deterministic=True,
    )

    g = torch.randn_like(out)
    grad_inputs = (q_unpad, k_unpad, v_unpad)

    def backward():
        # retain_graph lets the same graph be differentiated repeatedly.
        return torch.autograd.grad(out, grad_inputs, g, retain_graph=True)

    dq0, dk0, dv0 = backward()
    for _ in range(50):
        dq, dk, dv = backward()
        assert torch.equal(dv, dv0)
        assert torch.equal(dk, dk0)
        assert torch.equal(dq, dq0)
| |
|