# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

import pytest
import torch

import xformers.ops
from xformers.ops import fmha

from .utils import assert_allclose, disable_tf32, ref_attention_for_test


def ref_attention_splitk_bmhk(
    q, k, v, attn_bias, scale=None, split_k=None, dtype=None
) -> torch.Tensor:
    assert q.ndim == 4

    def T(t):
        return t.permute((0, 2, 1, 3)).reshape(
            [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
        )

    if isinstance(attn_bias, xformers.ops.AttentionBias):
        attn_bias = attn_bias.materialize(
            (q.shape[0], q.shape[2], q.shape[1], k.shape[1]),
            device=q.device,
            dtype=torch.float32,
        ).reshape([q.shape[0] * q.shape[2], q.shape[1], k.shape[1]])
    out = ref_attention_splitk(
        T(q), T(k), T(v), attn_bias, scale=scale, split_k=split_k, dtype=dtype
    )
    out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
    return out.permute((0, 2, 1, 3))
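

# Illustrative sketch (not part of the original test suite): how the T() helper above maps
# the BMHK layout (batch, seq, heads, head_dim) to the flattened (B*H, M, K) layout that
# ref_attention_splitk consumes, and how the result is mapped back. The sizes below are
# arbitrary assumptions made for this example.
def _demo_bmhk_flattening() -> None:
    B, M, H, K = 2, 5, 3, 8  # hypothetical batch, sequence length, heads, head dim
    q = torch.randn(B, M, H, K)
    # Same permute/reshape as T(): BMHK -> BHMK -> (B*H, M, K)
    flat = q.permute(0, 2, 1, 3).reshape(B * H, M, K)
    assert flat.shape == (B * H, M, K)
    # Inverse mapping used at the end of ref_attention_splitk_bmhk
    restored = flat.reshape(B, H, M, K).permute(0, 2, 1, 3)
    assert torch.equal(restored, q)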


def ref_attention_splitk(
    q, k, v, attn_bias, scale=None, split_k=2, dtype=None
) -> torch.Tensor:
    if q.ndim == 5:

        def attn_bias_group(group: int):
            if getattr(attn_bias, "HOLDS_DENSE_TENSOR", True):
                return attn_bias[:, group]
            return attn_bias

        return torch.stack(
            [
                ref_attention_splitk_bmhk(
                    q[:, :, g],
                    k[:, :, g],
                    v[:, :, g],
                    attn_bias=attn_bias_group(g),
                    split_k=split_k,
                    dtype=dtype,
                )
                for g in range(q.shape[2])
            ],
            dim=2,
        )
    if q.ndim == 4:
        return ref_attention_splitk_bmhk(
            q, k, v, attn_bias=attn_bias, split_k=split_k, dtype=dtype
        )
    assert q.ndim == 3
    if dtype is None:
        dtype = torch.float32
    q = q.to(dtype=dtype)
    k = k.to(dtype=dtype)
    v = v.to(dtype=dtype)

    if scale is None:
        scale = q.shape[-1] ** -0.5
    assert not q.isnan().any()
    q = q * scale
    assert not q.isnan().any()
    if attn_bias is not None:
        if isinstance(attn_bias, xformers.ops.AttentionBias):
            # Always create in B,H,Mq,Mk format
            attn_bias_tensor = attn_bias.materialize(
                (q.shape[0], 1, q.shape[1], k.shape[1]),
                device=q.device,
                dtype=torch.float32,
            )
        else:
            attn_bias_tensor = attn_bias
        if attn_bias_tensor.ndim == 4:
            assert q.shape[0] == attn_bias_tensor.shape[0] * attn_bias_tensor.shape[1]
            attn_bias_tensor = attn_bias_tensor.reshape(
                [-1, *attn_bias_tensor.shape[2:]]
            )
    else:
        # No bias provided: fall back to an explicit zero bias so the chunked
        # computation below (which always adds a bias slice) still works.
        attn_bias_tensor = torch.zeros(
            (q.shape[0], q.shape[1], k.shape[1]), device=q.device, dtype=q.dtype
        )

    # Split keys, values and the matching bias columns into split_k chunks along the
    # key/sequence axis.
    split_size = k.size(-2) // split_k
    split_config = {"dim": -2, "split_size_or_sections": split_size}
    k_split = torch.split(k, **split_config)
    v_split = torch.split(v, **split_config)
    attn_bias_split = torch.split(
        attn_bias_tensor, dim=-1, split_size_or_sections=split_size
    )
    def compute_attention_split(q_whole, k_slice, v_slice, attn_bias_slice):
        # Unnormalized attention for one key/value chunk, plus the per-row statistics
        # (running max and sum of exponentials) needed to merge chunks afterwards.
        p_slice = q_whole @ k_slice.transpose(-2, -1)
        p_slice += attn_bias_slice
        row_max = torch.max(p_slice, dim=-1, keepdim=True).values
        p_slice_scaled = p_slice - row_max
        # Fully masked rows produce NaN (-inf minus -inf); treat them as -inf so they
        # contribute nothing to the sum of exponentials.
        p_slice_scaled[p_slice_scaled.isnan()] = float("-inf")
        s = torch.exp(p_slice_scaled)
        row_sumexp = torch.sum(s, dim=-1, keepdim=True)
        attn_slice = s @ v_slice
        return {
            "attn_slice": attn_slice,
            "row_max": row_max,
            "row_sumexp": row_sumexp,
        }
    splits = list(zip(k_split, v_split, attn_bias_split))

    slices = list(map(lambda s: compute_attention_split(q, s[0], s[1], s[2]), splits))
    out = torch.zeros_like(q)

    # reduce out over split-k slices
    global_max = torch.zeros_like(slices[0]["row_max"]).fill_(float("-inf"))
    global_sumexp = torch.zeros_like(slices[0]["row_sumexp"])

    for s in slices:
        local_out = s["attn_slice"]
        local_max = s["row_max"]
        local_sumexp = s["row_sumexp"]

        log_alpha = -torch.abs(local_max - global_max)
        alpha = torch.exp(log_alpha)
        alpha.nan_to_num_(1.0)

        # Rescale whichever side has the smaller running max, so that the accumulated
        # output and the incoming chunk are both expressed relative to the larger max.
        pick_new = local_max < global_max
        new_coef = torch.where(pick_new, alpha, 1.0)
        curr_coef = torch.where(pick_new, 1.0, alpha)

        out = out * curr_coef + local_out * new_coef
        global_sumexp = global_sumexp * curr_coef + local_sumexp * new_coef
        global_max = torch.max(local_max, global_max)
    out /= global_sumexp
    return out
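

# Illustrative sketch (not part of the original test suite): on a small random problem the
# split-k reference above should agree with single-pass softmax attention. The sizes and
# the explicit zero bias are assumptions made for this example only.
def _demo_splitk_matches_softmax() -> None:
    torch.manual_seed(0)
    B, Mq, Mk, K = 2, 3, 8, 4  # hypothetical sizes; Mk must be divisible by split_k
    q = torch.randn(B, Mq, K)
    k = torch.randn(B, Mk, K)
    v = torch.randn(B, Mk, K)
    bias = torch.zeros(B, Mq, Mk)
    out_splitk = ref_attention_splitk(q, k, v, attn_bias=bias, split_k=4)
    # Single-pass reference: softmax(q @ k^T * scale + bias) @ v
    scale = K**-0.5
    attn = torch.softmax(q @ k.transpose(-2, -1) * scale + bias, dim=-1)
    out_ref = attn @ v
    assert torch.allclose(out_splitk, out_ref, atol=1e-5)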


def _kv_heads_label(kv_heads: Optional[int]) -> str:
    if kv_heads is None:
        return ""
    if kv_heads == 1:
        return "mq"
    return f"gqa{kv_heads}"


# NOTE: this test expects its arguments (kv_heads, n_heads, padding, bsz, dtype, device,
# split_k) to be supplied by pytest, e.g. via parametrization; _kv_heads_label above is
# intended as the id generator for the kv_heads values.
def test_splitk_reference(
    kv_heads: int,
    n_heads: int,
    padding: int,
    bsz: int,
    dtype: str,
    device: str,
    split_k: int,
):
    dtype_ = {"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}[dtype]
    torch.manual_seed(1)
    d = 256
    num_queries = 1
    if kv_heads is not None and kv_heads > 1:
        k_shape: Tuple[int, ...] = (1, bsz * padding, kv_heads, n_heads, d)
        q_shape: Tuple[int, ...] = (
            1,
            bsz * num_queries,
            kv_heads,
            n_heads,
            d,
        )
    else:
        k_shape = (1, bsz * padding, n_heads, d)
        q_shape = (1, bsz * num_queries, n_heads, d)
    k = torch.rand(k_shape, dtype=dtype_, device=device)
    k_seqlen = torch.randint(1, padding + 1, (bsz,)).tolist()
    v = torch.rand_like(k)
    q = torch.rand(q_shape, dtype=dtype_, device=device)
    causal_diagonal = torch.tensor(  # TODO: make unnecessary
        [i - 1 for i in k_seqlen], dtype=torch.int32, device=device
    )

    if kv_heads is not None:
        k = k[..., :1, :].expand(k_shape)
        v = v[..., :1, :].expand(k_shape)

    attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
        q_seqlen=[1] * bsz,
        kv_seqlen=k_seqlen,
        causal_diagonal=causal_diagonal,
        kv_padding=padding,
    )
    ref_out = ref_attention_for_test(q, k, v, attn_bias)
    splitk_out = ref_attention_splitk(q, k, v, attn_bias, None, split_k=split_k)
    assert_allclose(
        ref_out,
        splitk_out,
        atol=fmha.ck.FwOp.ERROR_ATOL[dtype_],
        rtol=fmha.ck.FwOp.ERROR_RTOL[dtype_],
    )
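

# Illustrative sketch (not part of the original test suite): exercising the 5-dim grouped
# (B, M, G, H, K) path of ref_attention_splitk against a per-group softmax reference.
# All shapes and the dense zero bias are assumptions made for this example.
def _demo_splitk_grouped_path() -> None:
    torch.manual_seed(0)
    B, Mq, Mk, G, H, K = 2, 2, 6, 2, 3, 4
    q = torch.randn(B, Mq, G, H, K)
    k = torch.randn(B, Mk, G, H, K)
    v = torch.randn(B, Mk, G, H, K)
    bias = torch.zeros(B, G, H, Mq, Mk)  # one dense (H, Mq, Mk) bias slice per group
    out = ref_attention_splitk(q, k, v, attn_bias=bias, split_k=2)
    assert out.shape == (B, Mq, G, H, K)
    scale = K**-0.5
    for g in range(G):
        qg = q[:, :, g].permute(0, 2, 1, 3)  # (B, H, Mq, K)
        kg = k[:, :, g].permute(0, 2, 1, 3)  # (B, H, Mk, K)
        vg = v[:, :, g].permute(0, 2, 1, 3)  # (B, H, Mk, K)
        attn = torch.softmax(qg @ kg.transpose(-2, -1) * scale + bias[:, g], dim=-1)
        ref = (attn @ vg).permute(0, 2, 1, 3)  # back to (B, Mq, H, K)
        assert torch.allclose(out[:, :, g], ref, atol=1e-5)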