diff --git a/.venv/lib/python3.11/site-packages/xformers/components/__init__.py b/.venv/lib/python3.11/site-packages/xformers/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed312692841c5ec4a77717fbd48432c30ae519b2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/__init__.py @@ -0,0 +1,86 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import warnings +from dataclasses import fields +from pathlib import Path +from typing import Any, Dict, Union + +from xformers.utils import import_all_modules + +from .activations import Activation, build_activation # noqa +from .attention import Attention, build_attention # noqa +from .input_projection import InputProjection, InputProjectionConfig # noqa +from .multi_head_dispatch import MultiHeadDispatch # noqa +from .multi_head_dispatch import MultiHeadDispatchConfig +from .patch_embedding import PatchEmbeddingConfig # noqa +from .patch_embedding import build_patch_embedding # noqa +from .residual import NormalizationType # noqa +from .residual import PostNorm # noqa +from .residual import PreNorm # noqa +from .residual import RequiresWrappedInputs # noqa +from .residual import Residual # noqa +from .residual import ResidualNormStyle # noqa + +warnings.warn( + "xformers.components is deprecated and is not maintained anymore. " + "It might be removed in a future version of xFormers ", + FutureWarning, + stacklevel=2, +) + + +# automatically import any Python files in the directory +import_all_modules(str(Path(__file__).parent), "xformers.components") + + +def build_multi_head_attention( + multi_head_config: Union[MultiHeadDispatchConfig, Dict[str, Any]], +): + """Builds a multihead attention from a config. 
+ + This assumes a 'name' key in the config which is used to determine what + attention class to instantiate. For instance, a config `{"name": "my_attention", + "foo": "bar"}` will find a class that was registered as "my_attention" + (see :func:`register_attention`) and call .from_config on it.""" + + if not isinstance(multi_head_config, MultiHeadDispatchConfig): + # Extract the required fields + field_names = list(map(lambda x: x.name, fields(MultiHeadDispatchConfig))) + + # The missing fields get Noned + for k in field_names: + if k not in multi_head_config.keys(): + multi_head_config[k] = None + + # Could be that the attention needs to be instantiated + if not isinstance(multi_head_config["attention"], Attention): + # Convenience: fill in possible missing fields + if "num_heads" not in multi_head_config["attention"]: + multi_head_config["attention"]["num_heads"] = multi_head_config[ + "num_heads" + ] + + if "dim_model" not in multi_head_config["attention"]: + multi_head_config["attention"]["dim_model"] = multi_head_config[ + "dim_model" + ] + + if ( + "dim_features" not in multi_head_config["attention"] + or multi_head_config["attention"]["dim_features"] is None + ): + multi_head_config["attention"]["dim_features"] = ( + multi_head_config["dim_model"] // multi_head_config["num_heads"] + ) + + multi_head_config["attention"] = build_attention( + multi_head_config["attention"] + ) + + multi_head_config = MultiHeadDispatchConfig(**multi_head_config) + + return MultiHeadDispatch.from_config(multi_head_config) diff --git a/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b93e069ff5d47984e50131dca60548e351b5a0a3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/__init__.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/activations.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/activations.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74c4e36013e8df9d3e77da9bfaa393f121457c99 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/activations.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/input_projection.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/input_projection.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a948f250877e0fb44ed5a9920cd15e11df4ddfb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/input_projection.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/multi_head_dispatch.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/multi_head_dispatch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61f573c8845e92866f58038ccaaaebc3c5c98aea Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/multi_head_dispatch.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/patch_embedding.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/patch_embedding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b69601674014ac3a0abc5afe8e0133c8b66ea70 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/patch_embedding.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/residual.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/residual.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5eb72e167ef8536e7e74e415dd90edcb0cc4be1c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/residual.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/reversible.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/reversible.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86ad38ed976efa9897a1a042dbc9122bfae8b44c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/reversible.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/simplicial_embedding.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/simplicial_embedding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da0c3bb4461530fd8b3e2100d5a34bc289828d93 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/__pycache__/simplicial_embedding.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/activations.py b/.venv/lib/python3.11/site-packages/xformers/components/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..314a7962df220bc3ddb6f5060a0cea630505a806 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/activations.py @@ -0,0 +1,76 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +from enum import Enum +from typing import Optional + +import torch +from torch import nn + +from xformers._deprecation_warning import deprecated_function + + +class Activation(str, Enum): + SquaredReLU = "squared_relu" + GeLU = "gelu" + LeakyReLU = "leaky_relu" + ReLU = "relu" + SmeLU = "smelu" + StarReLU = "star_relu" + + +# For unit testing / parity comparisons, probably not the fastest way +class SquaredReLU(nn.Module): + def __init__(self) -> None: + super().__init__() + deprecated_function(self) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_ = torch.nn.functional.relu(x) + return x_ * x_ + + +class StarReLU(nn.Module): + def __init__(self) -> None: + super().__init__() + deprecated_function(self) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_ = torch.nn.functional.relu(x) + return 0.8944 * x_ * x_ - 0.4472 + + +class SmeLU(nn.Module): + def __init__(self, beta: float = 2.0) -> None: + super().__init__() + self.beta = beta + deprecated_function(self) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + relu = torch.where( + x >= self.beta, + x, + torch.tensor([0.0], device=x.device, dtype=x.dtype), + ) + return torch.where( + torch.abs(x) <= self.beta, + ((x + self.beta) ** 2).type_as(x) / (4.0 * self.beta), + relu, + ) + + +def build_activation(activation: Optional[Activation]): + if not activation: + return nn.Identity() + + return { + Activation.ReLU: nn.ReLU, + Activation.GeLU: nn.GELU, + Activation.LeakyReLU: nn.LeakyReLU, + Activation.SquaredReLU: SquaredReLU, + Activation.StarReLU: StarReLU, + Activation.SmeLU: SmeLU, + }[activation]() diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/attention_patterns.py b/.venv/lib/python3.11/site-packages/xformers/components/attention/attention_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..9c817debb926a27e32aedcaa728f777f812f313c --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/xformers/components/attention/attention_patterns.py @@ -0,0 +1,295 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import math +from typing import List + +import numpy as np +import torch + +from xformers.components.attention.sparsity_config import ( + BigBirdSparsityConfig, + BSLongformerSparsityConfig, + FixedSparsityConfig, + VariableSparsityConfig, +) + + +# generic nd cases +def _generate_nd_grid(*sizes): + coords = [torch.arange(s) for s in sizes] + return torch.meshgrid(*coords) + + +def local_nd_distance(*sizes, p=2.0, weights=None): + if weights is None: + weights = (1,) * len(sizes) + assert len(sizes) == len(weights) + grid = _generate_nd_grid(*sizes) + grid = [i.flatten() * w for i, w in zip(grid, weights)] + grid = torch.stack(grid, dim=1).float() + d = torch.cdist(grid, grid, p=p) + return d + + +def local_nd_gaussian_distribution(*sizes, sigma=1): + d = local_nd_distance(*sizes, p=2.0) ** 2 + d = torch.exp(-0.5 * sigma ** (-2.0) * d) + return d + + +def local_nd_pattern(*sizes, distance, p=2.0): + d = local_nd_distance(*sizes, p=p) + return d < distance + + +def axial_nd_pattern(*sizes): + # axial is a special case with p=0 and distance=2 + d = local_nd_distance(*sizes, p=0) + return d < 2 + + +def random_pattern_from_probability_matrix(dist_matrix, nnz): + att = torch.zeros_like(dist_matrix, dtype=torch.bool) + # PyTorch multinomial wrongly doesn't support sampling when number of categories + # is > 2^24, arguing that it's because it's the max representable consecutive element + # in fp32 and that the kernels use float32. This is actually not true, and the kernels + # should work fine if double tensor is passed on CPU. 
This is a bug that was introduced + # in https://github.com/pytorch/pytorch/commit/bf04c2ca2f591d98ce57816f0ef0cd20a21bbf66 + # when unifying the checks between CPU and CUDA. For now, just fall-back to numpy + if dist_matrix.numel() > 2**24: + dist_matrix = dist_matrix.double() + dist_matrix /= dist_matrix.sum() + idxs = np.random.choice( + dist_matrix.numel(), nnz, p=dist_matrix.flatten(), replace=False + ) + idxs = torch.as_tensor(idxs) + else: + idxs = torch.multinomial(dist_matrix.flatten(), nnz, replacement=False) + att.view(-1)[idxs] = True + return att + + +def global_token_pattern(attention_query_mask: torch.Tensor) -> torch.Tensor: + assert attention_query_mask.ndim == 1 + assert attention_query_mask.dtype == torch.bool + attention_query_mask = attention_query_mask[None, :] + mask = attention_query_mask | attention_query_mask.transpose(1, 0) + return mask + + +def random_pattern(attn_size: int, sparsity: float) -> torch.Tensor: + assert 0 < sparsity < 1 + mask = torch.rand(attn_size, attn_size) > sparsity + return mask + + +# 1d-specific cases +def local_1d_pattern(attn_size: int, window_size: int) -> torch.Tensor: + assert ( + window_size % 2 == 1 + ), "The window size is assumed to be odd (counts self-attention + 2 wings)" + h_win_size = window_size // 2 + 1 + return local_nd_pattern(attn_size, distance=h_win_size, p=1.0) + + +def causal_1d_pattern(attn_size: int) -> torch.Tensor: + mask = torch.tril(torch.ones(attn_size, attn_size, dtype=torch.bool)) + return mask + + +# 2d-specific cases +def horizontal_axial_2d_distance(H, W, p=2.0): + d = local_nd_distance(H, W, p=p, weights=(1, 0)) + return d + + +def vertical_axial_2d_distance(H, W, p=2.0): + d = local_nd_distance(H, W, p=p, weights=(0, 1)) + return d + + +def local_2d_distance(H, W, p=2.0): + return local_nd_distance(H, W, p=p) + + +def local_2d_gausian_distribution(H, W, sigma=1): + return local_nd_gaussian_distribution(H, W, sigma=sigma) + + +def local_2d_pattern(H, W, distance, p=2.0): + return 
local_nd_pattern(H, W, distance=distance, p=p) + + +def axial_2d_pattern(H, W): + return axial_nd_pattern(H, W) + + +def swin_attention_pattern(H, W, window_size, shift_size=0): + assert H % window_size == 0 + assert W % window_size == 0 + assert 0 <= shift_size < window_size, "shift_size must in 0-window_size" + + # input grid + i, j = _generate_nd_grid(H, W) + i, j = i + 0.5, j + 0.5 + + # anchors grid + # if shift is present, add extra element to the grid + # to account for the uneven partitioning + extra = int(shift_size % window_size != 0) + grid_h = H // window_size + extra + grid_w = W // window_size + extra + + ii, jj = _generate_nd_grid(grid_h, grid_w) + # convert shift to be compatible with the paper representation + s = (-shift_size) % window_size + offset = window_size / 2 - s + ii = ii * window_size + offset + jj = jj * window_size + offset + + input_coords = torch.stack([i.flatten(), j.flatten()], 1).float() + anchors_coords = torch.stack([ii.flatten(), jj.flatten()], 1).float() + + anchor_id = torch.cdist(input_coords, anchors_coords, p=2).argmin(1) + mask = anchor_id[:, None] == anchor_id[None, :] + return mask + + +def dilated_2d_pattern(H, W, k=2): + """ + Returns a 2d pattern that samples 1 every k elements in the attention mask. + Can be seen as a form of downsampling, where every pixel attends to a downsampled + version of the input. 
+ """ + d_h = local_nd_distance(H, W, p=1, weights=(1, 0)) + d_w = local_nd_distance(H, W, p=1, weights=(0, 1)) + d = (d_h.floor() % k == 0) & (d_w.floor() % k == 0) + return d + + +# Block sparse utils +def block_sparsify_tensor(x, mask, block_size): + """ + Block sparsify a tensor, given a mask and block size + """ + ret = torch.empty( + (x.size(0), mask.sum(), block_size, block_size), dtype=x.dtype, device=x.device + ) + + for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))): + ret[:, idx, :, :] = x[ + :, + h, + i * block_size : (i + 1) * block_size, + j * block_size : (j + 1) * block_size, + ] + return ret + + +def pattern_to_layout(mask: torch.Tensor, block_size: int) -> torch.Tensor: + r""" + Given a mask pattern and blocksize, return the corresponding layout + which makes sure that all the positives in the mask are covered + """ + assert mask.ndim >= 2, "We're expecting [Heads, Seq, Seq] or [Seq, Seq]" + _should_squeeze = False + + if mask.ndim == 2: + mask = mask.unsqueeze(0) + _should_squeeze = True + + assert ( + mask.shape[1] % block_size == 0 and mask.shape[2] % block_size == 0 + ), "We're only handling masks divisible by block_size" + + # Now mark the mask + layout = torch.nn.functional.max_pool2d( + mask.to(torch.float), kernel_size=block_size, stride=block_size + ) + layout = layout.to(torch.long) + + if _should_squeeze: + layout.squeeze_(0) + + return layout + + +def alibi_pattern(threshold: float, mask_shape: torch.Size) -> torch.Tensor: + r""" + Use the additive bias computation from ALiBi_ to generate a mask. + Note that this mask can in turn be used to generate a blocksparse attention computation layout + + .. note: mask_shape is expected to hold the [heads, seq, seq] dimensions + + .. 
_ALiBi: https://arxiv.org/pdf/2108.12409.pdf + """ + + # CREDITS: code snippet from Ofir Press, one of the authors + + def get_slopes(n: int): + def get_slopes_power_of_2(n: int) -> List[float]: + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + # In the paper, we only train models that have 2^a heads for some a. This function has + # some good properties that only occur when the input is a power of 2. To maintain that even + # when the number of heads is not a power of 2, we use this workaround. + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + maxpos = mask_shape[1] + attn_heads = mask_shape[0] + slopes = torch.Tensor(get_slopes(attn_heads)) + + # In the next line, the part after the * is what constructs the diagonal matrix + # (right matrix in Figure 3 in the paper). + # If you run it you'll see that it doesn't exactly print out the same matrix as we have in Figure 3, + # but one where all rows are identical. + # This works because the softmax operation is invariant to translation, + # and our bias functions are always linear. 
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(maxpos).unsqueeze( + 0 + ).unsqueeze(0).expand(attn_heads, -1, -1) + alibi = alibi.view(attn_heads, 1, maxpos) + + # Now threshold arbitrarily, report the mask + return alibi < threshold + + +def quick_fixed_layout(num_heads: int, block_size: int, seq_len: int): + config = FixedSparsityConfig(num_heads=num_heads, block_size=block_size) + return config.make_layout(seq_len) + + +def quick_variable_layout(num_heads: int, block_size: int, seq_len: int): + config = VariableSparsityConfig(num_heads=num_heads, block_size=block_size) + return config.make_layout(seq_len) + + +def quick_bigbird_layout(num_heads: int, block_size: int, seq_len: int): + config = BigBirdSparsityConfig(num_heads=num_heads, block_size=block_size) + return config.make_layout(seq_len) + + +def quick_bslongformer_layout(num_heads: int, block_size: int, seq_len: int): + config = BSLongformerSparsityConfig(num_heads=num_heads, block_size=block_size) + return config.make_layout(seq_len) + + +def layout_to_pattern(layout: torch.Tensor, block_size: int): + r""" + create a pattern of shape [heads, seq, seq] out of a blocksparse + layout of shape [heads, seq/block_size, seq/block_size] + """ + return torch.kron(layout, torch.ones(block_size, block_size)) diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/core.py b/.venv/lib/python3.11/site-packages/xformers/components/attention/core.py new file mode 100644 index 0000000000000000000000000000000000000000..3a201fb5124100425bd9a6ab8b9140a996092d76 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/attention/core.py @@ -0,0 +1,248 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import logging +import math +from contextlib import nullcontext +from typing import Optional, Union + +import torch + +from xformers import _has_cpp_library +from xformers.components.attention.attention_mask import AttentionMask + +if _has_cpp_library: + from ._sputnik_sparse import SparseCS + +logger = logging.getLogger("xformers") + + +def _create_random_sparsity(matrix, sparsity, divisible_by=4): + assert matrix.ndim == 3 + keep = torch.rand_like(matrix[0], dtype=torch.float32) > sparsity + nonzero = torch.nonzero(keep) + nnz = nonzero.shape[0] + # NOTE: need to make it a multiple of 4 for sputnik + nonzero = nonzero[: (nnz - nnz % divisible_by)] + i, j = nonzero.unbind(1) + output = torch.zeros_like(matrix) + bdim = torch.arange(matrix.shape[0], device=matrix.device)[:, None] + output[bdim, i, j] = matrix[bdim, i, j] + return output + + +def _broadcast_batch(mask, batch_size): + if mask.ndim == 3: + return mask + assert mask.ndim == 2 + + mask = mask.coalesce() + values = mask.values() + indices = mask.indices() + nnz = len(values) + # strategy: repeat the indices and append the extra batch dimension to the indices + indices = indices.repeat(1, batch_size) + # now create the batch indices + batch_indices = torch.arange(batch_size, device=indices.device) + batch_indices = batch_indices[:, None].expand(batch_size, nnz).flatten() + + # put them together + indices = torch.cat([batch_indices[None, :], indices], dim=0) + + # now repeat the values + values = values.repeat(batch_size) + + size = (batch_size,) + mask.shape + + return torch.sparse_coo_tensor(indices, values, size) + + +def _matmul_with_mask( + a: torch.Tensor, + b: torch.Tensor, + mask: Optional[Union[torch.Tensor, "SparseCS"]], +) -> torch.Tensor: + if mask is None: + return a @ b + + if _has_cpp_library and mask.dtype == torch.bool: + if isinstance(mask, SparseCS): + return mask.matmul_with_mask(a, b) + if mask.is_sparse: + # perform broadcasting if needed + mask = _broadcast_batch(mask, 
a.shape[0]) + + # coalesced is not implemented for bool tensors, so need to cast + mask = mask.to(dtype=a.dtype) # type: ignore # mypy is missing the catch above + + return torch.ops.xformers.matmul_with_mask(a, b, mask) + + # Non optimized codepath + if _has_cpp_library: + assert not isinstance(mask, SparseCS) + + att = a @ b + if mask.dtype == torch.bool: + assert not isinstance(mask, SparseCS) + if mask.ndim == 2: + mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1) + # mask is presumed false == ignore + att[~mask] = float("-inf") + else: + # mask is presumed additive + # repeat if batch sizes don't match + if ( + not isinstance(mask, SparseCS) + and mask.ndim == 3 + and mask.shape[0] != att.shape[0] + and (att.shape[0] % mask.shape[0]) == 0 + ): + repeat_factor = att.shape[0] // mask.shape[0] + mask = mask.repeat([repeat_factor, 1, 1]) + logger.info("Mismatched batch dimensions for mask, repeating mask.") + att += mask + return att + + +def _softmax(a: torch.Tensor, causal: bool = False) -> torch.Tensor: + if _has_cpp_library and isinstance(a, SparseCS): + return a.softmax() + + if a.is_sparse: + return torch.sparse.softmax(a, dim=a.ndim - 1) + + return torch.softmax(a, dim=a.ndim - 1) + + +if _has_cpp_library: + + class SparseBMM(torch.autograd.Function): + @staticmethod + def forward(ctx, a, b): + a = a.coalesce() + r = torch.bmm(a, b) + ctx.save_for_backward(a, b) + return r + + @staticmethod + def backward(ctx, grad): + a, b = ctx.saved_tensors + + # gradients w.r.t. a + ga = None + if ctx.needs_input_grad[0]: + ga = torch.ops.xformers.matmul_with_mask(grad, b.transpose(-2, -1), a) + + # gradients w.r.t. 
b + gb = None + if ctx.needs_input_grad[1]: + gb = a.transpose(1, 2).bmm(grad) + + return ga, gb + + def _sparse_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Batch matrix multiply between a sparse matrix and a dense matrix + """ + assert a.ndim == b.ndim == 3 + assert a.shape[0] == b.shape[0] + assert a.shape[2] == b.shape[1] + return SparseBMM.apply(a, b) + + +def bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + if _has_cpp_library: + if isinstance(a, SparseCS): + return a.spmm(b) + if a.is_sparse: + return _sparse_bmm(a, b) + return a @ b + + +def _apply_dropout(att, dropout): + if dropout is None: + return att + + # Dropout chokes on sparse tensors + if _has_cpp_library: + if isinstance(att, SparseCS): + values = att.values.clone() + values = dropout(values) + att = SparseCS.wrap( + att.shape, + values, + att.row_indices, + att.row_offsets, + att.column_indices, + att._transp_info, + ) + elif att.is_sparse: + att = att.coalesce() + values = att.values().clone() # protect against in-place dropout + values = dropout(values) + att = torch.sparse_coo_tensor(att.indices(), values, att.shape) + else: + # Simple dense case + att = dropout(att) + + return att + + # Non optimized vanilla dropout + att = dropout(att) + return att + + +def scaled_query_key_softmax( + q: torch.Tensor, + k: torch.Tensor, + att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]], +) -> torch.Tensor: + # TODO assume we have (N, S, hs) instead of (B, nh, S, hs), with N = B x nh + # this is needed due to limitations in sparse_bmm for now + + # Self-attend: (N, S, hs) x (N, hs, S) -> (N, S, S) + q = q / math.sqrt(k.size(-1)) + + # Matmul with mask + if att_mask is not None and isinstance(att_mask, AttentionMask): + # Additive mask + mask: Optional[Union[SparseCS, torch.Tensor]] = att_mask.values + else: + mask = att_mask + + att = _matmul_with_mask(q, k.transpose(-2, -1), mask) + + # Softmax to get the attention probabilities + is_causal = isinstance(att_mask, 
AttentionMask) and att_mask.is_causal + att = _softmax(att, causal=is_causal) + return att + + +def scaled_dot_product_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]], + dropout: Optional[torch.nn.Module] = None, +) -> torch.Tensor: + autocast_disabled = ( + _has_cpp_library + and isinstance(att_mask, SparseCS) + or (att_mask is not None and att_mask.is_sparse) + ) + with torch.amp.autocast("cuda", enabled=False) if autocast_disabled else nullcontext(): # type: ignore + if autocast_disabled: + q, k, v = q.float(), k.float(), v.float() + + att = scaled_query_key_softmax(q, k, att_mask=att_mask) + + # Optional dropout, could be part of the masking in the future + att = _apply_dropout(att, dropout) + + # Get to the predicted values, for all heads + # y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs) + y = bmm(att, v) + return y diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/favor.py b/.venv/lib/python3.11/site-packages/xformers/components/attention/favor.py new file mode 100644 index 0000000000000000000000000000000000000000..d7dfbc53ab0314b7a0aae93c632f8eacd82b9213 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/attention/favor.py @@ -0,0 +1,173 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from torch.amp import autocast + +from xformers.components.attention import Attention, AttentionConfig, register_attention +from xformers.components.attention.feature_maps import ( + FeatureMap, + FeatureMapType, + SMHyperbolic, + SMOrf, + SMReg, +) + +logger = logging.getLogger("xformers") + + +@dataclass +class FavorAttentionConfig(AttentionConfig): + causal: Optional[bool] + dim_features: Optional[int] = None # The dimensions of the random features + dim_head: Optional[ + int + ] = None # The embedding dimension of the inputs. Only useful to get a dim_features estimate + iter_before_redraw: Optional[ + int + ] = None # The number of iterations before the random features are re-drawn from scratch + feature_map: Optional[FeatureMapType] = None + + +@register_attention("favor", FavorAttentionConfig) +class FavorAttention(Attention): + def __init__( + self, + causal: bool = False, + dropout: float = 0.0, + dim_features: Optional[int] = None, + dim_head: Optional[int] = None, + iter_before_redraw: Optional[int] = None, + feature_map_type: FeatureMapType = FeatureMapType.SMReg, + normalize_inputs: bool = False, + *_, + **__, + ): + r""" + Kernelized attention, as proposed in Performers_ + ("Rethinking attention with performers." K. Choromanski et al. (2020).). + + FAVOR stands for "Fast Attention Via positive Orthogonal Random features" + + Args: + dropout (float): the probability of an output to be randomly dropped at training time + dim_features (int): the dimension of the random features space + iter_before_redraw (int): the number of steps (forward calls) before a redraw of the features + feature_map_type (FeatureMapType): the type of feature map being used, + for instance orthogonal random features. + + .. 
_Performers: https://arxiv.org/pdf/2009.14794v1.pdf + """ + super().__init__() + + self.causal = causal + self.iter_before_redraw = ( + (2 * iter_before_redraw) + if iter_before_redraw is not None + else iter_before_redraw + ) # This will be used for both key and query + self.normalize_inputs = normalize_inputs + self.feature_map_type = feature_map_type + self.attn_drop = nn.Dropout(dropout, inplace=True) + + # Setup dimension-dependent variables + # Reasonable dimension default + if dim_features is None: + assert dim_head is not None, "dim_features or dim_head needs to be passed" + self.dim_features = math.ceil(dim_head * (1 + math.log2(dim_head))) + self.dim_features = 2 * ( + self.dim_features // 2 + ) # needs to be even for some variants + logger.info( + f"FAVOR: Automatically setting the random mapping dimension to {self.dim_features} from {dim_head}" + ) + else: + self.dim_features = dim_features + + feature_map_constructor = { + FeatureMapType.SMHyp: SMHyperbolic, + FeatureMapType.SMReg: SMReg, + FeatureMapType.SMOrf: SMOrf, + }[self.feature_map_type] + + feature_settings = { + "dim_features": self.dim_features, + "iter_before_redraw": self.iter_before_redraw, + "normalize_inputs": self.normalize_inputs, + } + + self.feature_map: FeatureMap = feature_map_constructor(**feature_settings) # type: ignore + + # Properties specific to this attention mechanism + self.supports_attention_mask = False + self.supports_key_padding_mask = False + + @staticmethod + def _maybe_promote(x: torch.Tensor) -> torch.Tensor: + # Only promote fp16 buffers, bfloat16 would be fine for instance + return x.float() if x.dtype == torch.float16 else x + + @staticmethod + def _causal_attention( + k_prime: torch.Tensor, q_prime: torch.Tensor, v: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Algorithm 1 in the paper + ref_v = torch.ones_like(v.unsqueeze(2)) # BATCH x SEQ x 1 x EMB + Gps = k_prime.unsqueeze(3) * v.unsqueeze(2) + Grenorm = k_prime.unsqueeze(3) * ref_v + + # 
Consolidate against the feature dimension + att_raw = torch.einsum("bcfe,bcf->bce", Gps, q_prime) + att_norm = torch.einsum("bcfe,bcf->bce", Grenorm, q_prime) + + # Cumulative sum over the sequence + att_raw = att_raw.cumsum(2) + att_norm = att_norm.cumsum(2) + + return att_raw, att_norm + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *_, + **__, + ): + + # Project key and queries onto the feature map space + k_prime = self.feature_map(k) + q_prime = self.feature_map(q) + + with autocast("cuda", enabled=False): + # The softmax kernel approximation for Favor will easily overflow + # Force the computations here to stay in fp32 for numerical stability + # Note that the dimensions are vastly reduced when compared to scaled_dot_product + k_prime = self._maybe_promote(k_prime) + q_prime = self._maybe_promote(q_prime) + v = self._maybe_promote(v) + + if not self.causal: + att_normalization = q_prime @ ( + k_prime.transpose(-2, -1) @ torch.ones_like(v) + ) + att_raw = q_prime @ (k_prime.transpose(-2, -1) @ v) + else: + # Actually compute attention + att_raw, att_normalization = self._causal_attention(k_prime, q_prime, v) + + # Normalize + att = att_raw / att_normalization + + if self.attn_drop is not None: + att = self.attn_drop(att) + + return att diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/fourier_mix.py b/.venv/lib/python3.11/site-packages/xformers/components/attention/fourier_mix.py new file mode 100644 index 0000000000000000000000000000000000000000..ad32591f2b794c48ee7a0908685b309bf6cc977d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/attention/fourier_mix.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from torch.amp import autocast + +from xformers.components.attention import Attention, AttentionConfig, register_attention + + +@register_attention("fourier_mix", AttentionConfig) +class FourierMix(Attention): + def __init__(self, dropout: float, *_, **__): + """ + FFT-based pseudo-attention mechanism, from + " + "FNet: Mixing Tokens with Fourier Transforms" + Lee-Thorp et al., 2021, https://arxiv.org/pdf/2105.03824.pdf + """ + super().__init__() + self.attn_drop = torch.nn.Dropout(dropout, inplace=False) + + # Properties specific to this attention mechanism + self.supports_attention_mask = False + self.requires_input_projection = False + + def forward(self, q: torch.Tensor, *_, **__): + # Guard against autocast / fp16, not supported by torch.fft.fft2 + with autocast("cuda", enabled=False): + att = torch.fft.fft2(q).real + + att = self.attn_drop(att) + + return att diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/lambda_layer.py b/.venv/lib/python3.11/site-packages/xformers/components/attention/lambda_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..0002a20cbcb8e5c2e90470f9afe5b9655855e48d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/attention/lambda_layer.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +from dataclasses import dataclass + +import torch + +from xformers.components.attention import Attention, AttentionConfig, register_attention + + +def calc_rel_pos(n: int): + # Adapted from LucidRains + # https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py + rel_pos = torch.arange(n)[None, :] - torch.arange(n)[:, None] # [n, n] + rel_pos += n - 1 # shift value range from [-n+1, n-1] to [0, 2n-2] + return rel_pos + + +@dataclass +class LambdaLayerConfig(AttentionConfig): + seq_len: int # dimension of the input sequence + dim_head: int + + +@register_attention("lambda", LambdaLayerConfig) +class LambdaLayer(Attention): + def __init__(self, dropout: float, seq_len: int, dim_head: int, *_, **__): + """ + Attention approximation using Lambda layers, from + "Lambda networks: modeling long-range interactions without attention.", Bello, I. (2021). + """ + super().__init__() + + # Possible extensions: + # - support different dimensions for key and queries + # - support varying dimensions in between inputs and outputs + # - support u hyperparam + + self.rel_pos_emb = torch.nn.Parameter( + torch.randn(2 * seq_len - 1, int(dim_head)) + ) + self.rel_pos = calc_rel_pos(seq_len) + self.attn_drop = torch.nn.Dropout(dropout, inplace=True) + + # Properties specific to this attention mechanism + self.requires_same_k_q_dimensions = True + self.supports_attention_mask = False + self.supports_key_padding_mask = False + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs + ): + """..NOTE: We're reusing the einsum notation suggested by the paper, changed in that + heads are folded in the batch dimension""" + + content_lambda = torch.einsum("bnk,bnv->bkv", torch.softmax(k, dim=-1), v) + content_output = torch.einsum("bnk,bkv->bnv", q, content_lambda) + + rel_pos_emb = self.rel_pos_emb[self.rel_pos] + + # Handle real sequence length being possibly smaller + seq_len = q.shape[1] + rel_pos_emb = 
rel_pos_emb[:seq_len, :seq_len, :] + + # Compute the position lambda for every possible combination in one go, then compute the + # position related contribution + position_lambdas = torch.einsum( + "mnk,bnv->bnkv", rel_pos_emb, v + ) # one lambda per position + position_output = (q.unsqueeze(2) @ position_lambdas).squeeze() + att = content_output + position_output + + att = self.attn_drop(att) + + return att diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/local.py b/.venv/lib/python3.11/site-packages/xformers/components/attention/local.py new file mode 100644 index 0000000000000000000000000000000000000000..3220a8d401df65d28aaa93f57d9c917066464f17 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/attention/local.py @@ -0,0 +1,120 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn as nn + +from xformers.components.attention import ( + Attention, + AttentionConfig, + AttentionMask, + maybe_sparsify, + register_attention, + sparsify, +) +from xformers.components.attention.attention_patterns import ( + causal_1d_pattern, + local_1d_pattern, +) +from xformers.components.attention.core import scaled_dot_product_attention + + +@dataclass +class LocalAttentionConfig(AttentionConfig): + causal: Optional[bool] = None + window_size: Optional[int] = None + force_sparsity: Optional[bool] = None + + +@register_attention("local", LocalAttentionConfig) +class LocalAttention(Attention): + def __init__( + self, + dropout: float = 0.0, + causal: bool = False, + window_size: int = 5, + force_sparsity: bool = False, + *args, + **kwargs, + ): + + r""" + An implementation of a sliding window attention, as proposed in RoutingTransformer_, LongFormer_ or BigBird_ + + + Args: + 
dropout (float): the probability of an output to be randomly dropped at training time + causal (bool): apply a causal mask, in that the attention cannot be applied to the future + window_size (int): the overall window size for local attention. + Odd number is expected if the mask is not causal, as the window size will be evenly + distributed on both sides of each query + + + .. _RoutingTransformer: https://arxiv.org/pdf/2003.05997.pdf + + .. _BigBird: https://arxiv.org/pdf/2007.14062.pdf + + .. _Longformer: https://arxiv.org/pdf/2004.05150.pdf + + """ + super().__init__() + + self.attn_drop = nn.Dropout(dropout, inplace=False) + self.causal = causal + self.force_sparsity = force_sparsity + + if not self.causal: + assert ( + window_size % 2 == 1 + ), "The window size is assumed to be odd (counts self-attention + 2 wings)" + + self.window_size = window_size + self.attention_mask: Optional[torch.Tensor] = None + self.requires_same_k_q_dimensions = True + + # Properties specific to this attention mechanism + self.supports_attention_mask = True + self.supports_key_padding_mask = False + + def _get_local_mask(self, shape: torch.Size) -> torch.Tensor: + window_size = self.window_size * 2 + 1 if self.causal else self.window_size + mask = local_1d_pattern(shape[1], window_size) + + if self.causal: + mask &= causal_1d_pattern(shape[1]) + + mask = sparsify(mask) if self.force_sparsity else maybe_sparsify(mask) + + return mask + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None, + *args, + **kwargs, + ): + # Local window attention masking + if self.attention_mask is None or self.attention_mask.shape[1] != q.shape[1]: + self.attention_mask = self._get_local_mask(q.shape).to(q.device) + + # Take into account the optional user mask + if att_mask is None: + mask = self.attention_mask + else: + if isinstance(att_mask, AttentionMask): + # Needed because & op not defined for SparseCS with 
AttentionMask + att_mask = att_mask.to_bool() + mask = self.attention_mask & att_mask + + return scaled_dot_product_attention( + q=q, k=k, v=v, att_mask=mask, dropout=self.attn_drop + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/nystrom.py b/.venv/lib/python3.11/site-packages/xformers/components/attention/nystrom.py new file mode 100644 index 0000000000000000000000000000000000000000..93e40b74de3bae5171555b2ec217a81cdd99bb01 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/attention/nystrom.py @@ -0,0 +1,295 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn + +from xformers.components.attention import Attention, AttentionConfig, register_attention +from xformers.components.attention.core import ( + scaled_dot_product_attention, + scaled_query_key_softmax, +) +from xformers.components.attention.utils import ( + bool_mask_to_additive, + iterative_pinv, + reshape_key_padding_mask, +) + +logger = logging.getLogger("xformers") + + +@dataclass +class NystromSelfAttentionConfig(AttentionConfig): + """ + num_heads Number of heads. + num_landmarks Number of landmarks to use for softmax approximation. 64 often sufficient for a good + approximation according to https://arxiv.org/pdf/2102.03902.pdf. + causal Apply a causal mask, in that the attention cannot be applied to the future. + use_razavi_pinverse If true, use iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose + inverse, otherwise use standard torch inverse. + pinverse_original_init True if using original initialization when calculating Moore-Penrose pseudo inverse using + method from (Razavi et al. 2014). 
+ False if using exact coefficient computation (leads to faster convergence). + inv_iterations Number of iterations for calculating the Moore-Penrose pseudo inverse. + v_skip_connection A module that will take V as input and will be added as a skip connection to the + softmax approximation. A skip connection is added in the paper to help with training. + conv_kernel_size Kernel size for convolution optionally added to help in training. + If v_skip_connection is not specified, this will be used to define the default + depth wise convolution used as a skip connection. + If both conv_kernel_size and v_skip_connection are None, no skip connection will + be added. + landmark_pooling Which module to use when computing landmarks. Default is AdaptiveAvgPool2d. + """ + + num_heads: int + num_landmarks: Optional[int] + landmark_pooling: Optional[nn.Module] + causal: Optional[bool] + pinverse_original_init: Optional[bool] + inv_iterations: Optional[int] + v_skip_connection: Optional[nn.Module] + conv_kernel_size: Optional[int] + use_razavi_pinverse: Optional[bool] + + +class AvgPool(nn.Module): + def __init__(self, n: int): + super().__init__() + self.n = n + + def forward(self, x: torch.Tensor): + # Average independently for every segment in the sequence dimension + seq_len = x.shape[1] + head_dim = x.shape[2] + segments = seq_len // self.n + assert segments > 0, "num_landmarks should be smaller than the sequence length" + + # Dimensions are a match + if seq_len % self.n == 0: + return x.reshape( + -1, + self.n, + segments, + head_dim, + ).mean(dim=-2) + + # Handle the last segment boundary being off + n_round = self.n - seq_len % self.n + + x_avg_round = ( + x[:, : n_round * segments, :] + .reshape(-1, n_round, segments, head_dim) + .mean(dim=-2) + ) + x_avg_off = ( + x[:, n_round * segments :, :] + .reshape(-1, self.n - n_round, segments + 1, head_dim) + .mean(dim=-2) + ) + return torch.cat((x_avg_round, x_avg_off), dim=-2) + + +@register_attention("nystrom", 
NystromSelfAttentionConfig) +class NystromAttention(Attention): + # TODO: update defaults for use_razavi_pinverse and inv_iterations + def __init__( + self, + dropout: float, + num_heads: int, + num_landmarks: int = 64, + landmark_pooling: Optional[nn.Module] = None, + causal: bool = False, + use_razavi_pinverse: bool = True, + pinverse_original_init: bool = False, + inv_iterations: int = 6, # recommended default in paper was 6. + v_skip_connection: Optional[nn.Module] = None, + conv_kernel_size: Optional[int] = None, + *args, + **kwargs, + ): + """ + Nystrom attention mechanism, from Nystromformer_. + :: + + "A Nystrom-based Algorithm for Approximating Self-Attention." + Xiong, Y., Zeng, Z., Chakraborty, R., Tan, M., Fung, G., Li, Y., Singh, V. (2021) + + Reference codebase: https://github.com/mlpen/Nystromformer + + .. _Nystromformer: https://arxiv.org/pdf/2102.03902.pdf + + """ + super().__init__() + # merged key padding mask and attention mask is not accepted + self.requires_separate_masks = True + self.num_landmarks = num_landmarks + # TODO: should be able to not have to pass in num_heads + self.num_heads = num_heads + self.use_razavi_pinverse = use_razavi_pinverse + self.pinverse_original_init = pinverse_original_init + self.inv_iterations = inv_iterations + self.attn_drop = nn.Dropout(dropout) + self.skip_connection = v_skip_connection + self.causal = causal + + if self.skip_connection is None and conv_kernel_size is not None: + self.skip_connection = nn.Conv2d( + in_channels=self.num_heads, + out_channels=self.num_heads, + kernel_size=(conv_kernel_size, 1), + padding=(conv_kernel_size // 2, 0), + bias=False, + groups=self.num_heads, + ) + + if landmark_pooling is not None: + self.landmark_pooling = landmark_pooling + else: + self.landmark_pooling = AvgPool(n=self.num_landmarks) + + # Optional lower triangular masks for causal attention + self.causal_mask_1: Optional[torch.Tensor] = None + self.causal_mask_2: Optional[torch.Tensor] = None + 
self.causal_mask_3: Optional[torch.Tensor] = None + + # This attention does not support attention masks + self.supports_attention_mask = False + self.supports_key_padding_mask = True + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + key_padding_mask: Optional[torch.Tensor] = None, + *args, + **kwargs, + ): + r""" + key_padding_mask Only a key padding mask is accepted here. The size must be (batch size, sequence length) or + (batch size * num_heads, 1, sequence length). If dimensions are not correct, the mask will + be ignored. An additive mask is expected, meaning float values using "-inf" to mask values + """ + + batched_dim = k.size(0) + seq_len = k.size(-2) + tt = {"dtype": q.dtype, "device": q.device} + + if key_padding_mask is not None: + if key_padding_mask.dtype == torch.bool: + logger.warning( + "Bool mask found, but an additive mask is expected. Converting but this is slow" + ) + + key_padding_mask = bool_mask_to_additive(key_padding_mask) + + if key_padding_mask.ndim == 2: + key_padding_mask = reshape_key_padding_mask( + key_padding_mask, batched_dim + ) + + zeros = torch.zeros_like(key_padding_mask) + ones = torch.ones_like(key_padding_mask) + is_masked = torch.isinf(-key_padding_mask) + + # _mask takes 1 if the token is not padded, otherwise 0. + _mask = torch.where(is_masked, zeros, ones) + _mask = _mask.transpose(2, 1) + assert _mask.shape == (batched_dim, q.shape[1], 1) + + # Mask q and k before pooling + # https://github.com/mlpen/Nystromformer/blob/main/code/attention_nystrom.py#L31 + q = q * _mask + k = k * _mask + + assert key_padding_mask.size() == (batched_dim, 1, seq_len), ( + f"key_padding_mask has invalid dimensions {key_padding_mask.size()}." + f" Must have dimensions {batched_dim, 1, seq_len} or (batch_size, {seq_len})." 
+ ) + + if self.num_landmarks >= seq_len: + mask: Optional[torch.Tensor] = None + + if self.causal: + mask = self._triu_mask(batched_dim, seq_len, seq_len, **tt) + + if key_padding_mask is not None: + mask = key_padding_mask if mask is None else mask + key_padding_mask + + x = scaled_dot_product_attention(q=q, k=k, v=v, att_mask=mask) + + else: + q_landmarks = self.landmark_pooling(q) + k_landmarks = self.landmark_pooling(k) + + if self.causal and ( + self.causal_mask_1 is None + or (batched_dim, seq_len, self.num_landmarks) + != self.causal_mask_1.size() + ): + self.causal_mask_1 = self._triu_mask( + batched_dim, seq_len, self.num_landmarks, **tt + ) + self.causal_mask_2 = self._triu_mask( + batched_dim, self.num_landmarks, self.num_landmarks, **tt + ) + self.causal_mask_3 = self._triu_mask( + batched_dim, self.num_landmarks, seq_len, **tt + ) + + mask_3: Optional[torch.Tensor] = self.causal_mask_3 + if key_padding_mask is not None: + mask_3 = ( + key_padding_mask if mask_3 is None else mask_3 + key_padding_mask + ) + + kernel_1 = scaled_query_key_softmax(q=q, k=k_landmarks, att_mask=None) + kernel_2 = scaled_query_key_softmax( + q=q_landmarks, k=k_landmarks, att_mask=None + ) + kernel_3 = scaled_dot_product_attention( + q=q_landmarks, k=k, v=v, att_mask=mask_3 + ) + + kernel_2_inv = ( + iterative_pinv( + kernel_2, self.inv_iterations, self.pinverse_original_init + ) + if self.use_razavi_pinverse + else torch.linalg.pinv(kernel_2) + ) + + x = torch.matmul( + torch.matmul( + kernel_1, + kernel_2_inv, + ), + kernel_3, + ) + + if self.skip_connection: + # Assumption here is that v is 3D. 
+ v_conv = self.skip_connection( + v.reshape(-1, self.num_heads, v.size(-2), v.size(-1)) + ) + x += v_conv.reshape(-1, v_conv.size(-2), v_conv.size(-1)) + x = self.attn_drop(x) + return x + + def _triu_mask(self, dim_1: int, dim_2: int, dim_3: int, **kwargs) -> torch.Tensor: + device = kwargs["device"] + dtype = kwargs["dtype"] + + return torch.triu( + torch.ones(dim_2, dim_3, dtype=dtype, device=device) * float("-inf"), + diagonal=1, + ).expand( + dim_1, -1, -1 + ) # micro optim, save memory on the batch dimension diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/random.py b/.venv/lib/python3.11/site-packages/xformers/components/attention/random.py new file mode 100644 index 0000000000000000000000000000000000000000..e07e6c8679ba9f50a6e39b6feaaefd1de72f671d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/attention/random.py @@ -0,0 +1,126 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn as nn + +from xformers.components.attention import ( + Attention, + AttentionConfig, + AttentionMask, + maybe_sparsify, + register_attention, + sparsify, +) +from xformers.components.attention.attention_patterns import ( + causal_1d_pattern, + random_pattern, +) +from xformers.components.attention.core import scaled_dot_product_attention + + +@dataclass +class RandomAttentionConfig(AttentionConfig): + r: Optional[ + float + ] # the ratio of keys that the query can attend to. 
1.0 means dense attention + constant_masking: Optional[ + bool + ] # whether the randomness is per query or defined at construction time + force_sparsity: Optional[bool] # use sparsity in any case (potentially slower) + + +@register_attention("random", RandomAttentionConfig) +class RandomAttention(Attention): + def __init__( + self, + dropout: float, + causal: bool = False, + r: float = 0.01, + constant_masking: bool = True, + force_sparsity: bool = False, + *args, + **kwargs, + ): + """ + "Random" attention, as proposed for instance in BigBird_. + Random means in that case that each query can attend to a random set of keys. + This implementation is sparse-aware, meaning that the empty attention parts will not be represented in memory. + + Args: + r (float): the ratio in [0,1] of keys that the query can attend to + constant_masking (bool): if true, keep the same random set for all queries. + + .. _BigBird: https://arxiv.org/pdf/2007.14062.pdf + + """ + super().__init__() + + self.attn_drop = nn.Dropout(dropout, inplace=False) + self.causal = causal + self.r = r + self.rand_attention_mask: Optional[torch.Tensor] = None + self.constant_masking = constant_masking + self.force_sparsity = force_sparsity + + # Properties specific to this attention mechanism + self.supports_attention_mask = True + self.supports_key_padding_mask = False + + self.requires_same_k_q_dimensions = True + + def _get_rand_mask(self, shape: torch.Size) -> torch.Tensor: + sparsity = 1 - self.r + mask = random_pattern(shape[1], sparsity=sparsity) + + if self.causal: + mask &= causal_1d_pattern(shape[1]) + + mask = sparsify(mask) if self.force_sparsity else maybe_sparsify(mask) + + return mask + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None, + *args, + **kwargs, + ): + # Rand masking + if not self.constant_masking or self.rand_attention_mask is None: + self.rand_attention_mask = 
self._get_rand_mask(q.shape).to(q.device) + + # Mask-aware attention + if att_mask is not None: + if att_mask.dtype == torch.bool and isinstance( + self.rand_attention_mask, AttentionMask + ): + mask = self.rand_attention_mask + AttentionMask.from_bool(att_mask) + else: + if isinstance(att_mask, AttentionMask): + # Needed because & op not defined for SparseCS with AttentionMask + att_mask = att_mask.to_bool() + mask = self.rand_attention_mask & att_mask + else: + mask = self.rand_attention_mask + + # Handle q/k/v which would not fit the mask + seq_len = q.shape[-2] + q_, k_, v_ = map(lambda x: self._maybe_pad_sequence(x, mask), (q, k, v)) + + # Normal attention with the random mask + att = scaled_dot_product_attention( + q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop + ) + + # Take into account an hypothetical padding + return att[:, :seq_len, :] diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/scaled_dot_product.py b/.venv/lib/python3.11/site-packages/xformers/components/attention/scaled_dot_product.py new file mode 100644 index 0000000000000000000000000000000000000000..16fd32ab7f236673846dc2fafad9643265c6986d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/attention/scaled_dot_product.py @@ -0,0 +1,134 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from xformers.components.attention import ( + Attention, + AttentionConfig, + AttentionMask, + register_attention, +) +from xformers.components.attention.core import scaled_dot_product_attention + +logger = logging.getLogger("xformers") + + +@dataclass +class ScaledDotProductConfig(AttentionConfig): + causal: Optional[bool] + seq_len: Optional[int] + to_seq_len: Optional[int] + + +@register_attention("scaled_dot_product", ScaledDotProductConfig) +class ScaledDotProduct(Attention): + r""" + Implementing the Scaled Dot-Product attention proposed in + `Attention is all you need`_, Vaswani et al. + + .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762v5 + """ + + mask: Optional[AttentionMask] + + def __init__( + self, + dropout: float = 0.0, + causal: bool = False, + seq_len: Optional[int] = None, + to_seq_len: Optional[int] = None, + *args, + **kwargs, + ): + super().__init__() + + self.attn_drop = nn.Dropout(dropout, inplace=False) + self.causal = causal + self.seq_len = seq_len + + if causal and seq_len is not None: + self.mask = AttentionMask.make_causal(seq_len, to_seq_len) + else: + self.mask = None + + # Properties specific to this attention mechanism + self.supports_attention_mask = True + self.supports_key_padding_mask = False + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None, + *args, + **kwargs, + ) -> torch.Tensor: + r""" + att_mask A 2D or 3D mask which ignores attention at certain positions. + + - If the mask is boolean, a value of True will keep the value, + while a value of False will mask the value. + + Key padding masks (dimension: batch x sequence length) and attention masks + (dimension: sequence length x sequence length OR batch x sequence length x sequence length) + can be combined and passed in here. 
Method maybe_merge_masks provided in the utils can be + used for that merging. + + - If the mask has the float type, then an additive mask is expected (masked values are -inf) + + """ + + # Convenience, create an attention mask if a tensor was passed + if att_mask is not None and isinstance(att_mask, torch.Tensor): + # By default we don't know of the causality, and a check would be expensive + att_mask = ( + AttentionMask.from_bool(att_mask) + if att_mask.dtype == torch.bool + else AttentionMask(att_mask, is_causal=False) + ) + + # Handle a possibly deferred causal mask handling + mask = self.mask + if self.causal and self.mask is None: + mask = AttentionMask.make_causal( + seq_len=q.shape[-2], + to_seq_len=q.shape[-2], + device=q.device, + dtype=q.dtype, + ) + + # Merge the optional causal mask and the user-provided mask + if mask is not None: + mask = mask.to(dtype=q.dtype, device=q.device) + + att_mask = att_mask + mask if att_mask is not None else mask + + # Try to handle a case where the sequence is smaller than the mask + if ( + att_mask is not None + and q.shape[-2] == k.shape[-2] + and q.shape[-2] < att_mask.shape[1] + ): + if isinstance(att_mask, AttentionMask): + att_mask = att_mask.make_crop(seq_len=q.shape[-2]) + else: + logger.error( + "Mismatching sparse attention mask and sequence length." 
+ + " Please pad the inputs or adjust the attention mask" + ) + raise NotImplementedError + + # Attend: (B x nh, S, hs) x (B x nh, hs, S) -> (B x nh, S, S) + y = scaled_dot_product_attention( + q=q, k=k, v=v, att_mask=att_mask, dropout=self.attn_drop + ) + return y diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/visual.py b/.venv/lib/python3.11/site-packages/xformers/components/attention/visual.py new file mode 100644 index 0000000000000000000000000000000000000000..6ea81f41c2334388861eb865daf7b118d42fecd9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/attention/visual.py @@ -0,0 +1,96 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from xformers.components.attention import Attention, AttentionConfig, register_attention + + +@dataclass +class VisualAttentionConfig(AttentionConfig): + dim_model: int # dimension of the input sequence + + +class LKA(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) + self.conv_spatial = nn.Conv2d( + dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3 + ) + self.conv1 = nn.Conv2d(dim, dim, 1) + + def forward(self, x: torch.Tensor): + u = x.clone() + attn = self.conv0(x) + attn = self.conv_spatial(attn) + attn = self.conv1(attn) + + return u * attn + + +@register_attention("visual", VisualAttentionConfig) +class Visual(Attention): + def __init__( + self, + dim_model: int, + *_, + **__, + ): + """ + Large kernel attention mechanism, as proposed in `Visual Attention Network`_, Guo et al (2022). + The original notation is tentatively kept as is. See https://github.com/Visual-Attention-Network + for the reference implementation + + .. 
Note: compared to the paper, this block contains the LKA (Large Kernel Attention) + and the prior and posterior transformations (Conv2d and activation) + + .. _`Visual Attention Network` : https://arxiv.org/pdf/2202.09741.pdf + """ + super().__init__() + + self.block = nn.Sequential( + nn.Conv2d(dim_model, dim_model, 1), + nn.GELU(), + LKA(dim_model), + nn.Conv2d(dim_model, dim_model, 1), + ) + + # MHA related flags: + self.requires_same_k_q_dimensions = ( + True # This mechanism only really supports self attention + ) + self.supports_attention_mask = False + self.requires_skip_multi_head = ( + True # This mechanism skips the multihead attention altogether + ) + self.requires_squared_context = ( + True # Recovering the 2D structure from context assumes squared content + ) + + self.requires_input_projection = ( + False # This mechanism does not require that the MHA projects inputs + ) + + def forward(self, q: torch.Tensor, *_, **__): + # Expose the 2D token structure + B, HW, C = q.shape + H = int(math.sqrt(HW)) + assert H * H == HW + + x = q.transpose(-2, -1).reshape(B, C, H, H) + + # Large kernel attention + residual = x.clone() + x = self.block(x) + x = x + residual + + # Get back to B HW C + return x.flatten(2, 3).transpose(-2, -1) diff --git a/.venv/lib/python3.11/site-packages/xformers/components/input_projection.py b/.venv/lib/python3.11/site-packages/xformers/components/input_projection.py new file mode 100644 index 0000000000000000000000000000000000000000..a77d1d4eaf16f7cd42cfdd579f5b1837ae6dc425 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/input_projection.py @@ -0,0 +1,102 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +# CREDITS: Inspired by https://github.com/pytorch/text/blob/master/torchtext/nn/modules/multiheadattention.py +# and the MultiHeadAttention implementation from PyTorch + + +import logging +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn + +from xformers._deprecation_warning import deprecated_function + +logger = logging.getLogger("xformers") + + +@dataclass +class InputProjectionConfig: + in_features: int + out_features: int + bias: bool + + +class InputProjection(nn.Module): + """ + Handle all the input projections in one go, opportunistically fuse some operations. + """ + + def __init__( + self, + query_proj_params: InputProjectionConfig, + key_proj_params: Optional[InputProjectionConfig], + value_proj_params: Optional[InputProjectionConfig], + use_separate_proj_weight: bool = True, + ): + + super().__init__() + deprecated_function(self) + + self.out_features = query_proj_params.out_features + + # Each input gets a separate projection + self.q_proj = nn.Linear( + query_proj_params.in_features, + query_proj_params.out_features, + query_proj_params.bias, + ) + + if key_proj_params is not None: + self.k_proj = nn.Linear( + key_proj_params.in_features, + key_proj_params.out_features, + key_proj_params.bias, + ) + else: + logger.info( + "No Key projection parameters were passed, assuming that the weights" + + " are shared with the query projection" + ) + self.k_proj = self.q_proj + + if value_proj_params is not None: + self.v_proj = nn.Linear( + value_proj_params.in_features, + value_proj_params.out_features, + value_proj_params.bias, + ) + else: + logger.info( + "No Value projection parameters were passed, assuming that the weights" + + " are shared with the query projection" + ) + self.v_proj = self.q_proj + + if not use_separate_proj_weight: + # Compute optimization used at times, share the parameters in between Q/K/V + with torch.no_grad(): + self.k_proj.weight = self.q_proj.weight + self.v_proj.weight 
= self.q_proj.weight + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # One projection per input tensor + + # NOTE: Would it make sense to catch self attention + shared weights, to skip a projection step ? + + q, k, v = map( + lambda fn, x: fn(x), + [self.q_proj, self.k_proj, self.v_proj], + [query, key, value], + ) + + return q, k, v diff --git a/.venv/lib/python3.11/site-packages/xformers/components/multi_head_dispatch.py b/.venv/lib/python3.11/site-packages/xformers/components/multi_head_dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..d0f75b2645bbdc27ba59623a3fc67327517c0619 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/multi_head_dispatch.py @@ -0,0 +1,271 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
import logging
from dataclasses import asdict, dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn.init import constant_

from xformers._deprecation_warning import deprecated_function
from xformers.components.attention import Attention
from xformers.components.input_projection import InputProjection, InputProjectionConfig
from xformers.components.positional_embedding import RotaryEmbedding

logger = logging.getLogger("xformers")


@dataclass
class MultiHeadDispatchConfig:
    """Construction parameters for :class:`MultiHeadDispatch`.

    Optional fields may be None; ``MultiHeadDispatch.from_config`` drops the
    None entries so that the constructor defaults apply.
    """

    dim_model: int
    num_heads: int
    attention: Attention
    bias: bool
    residual_dropout: float
    dim_key: Optional[int]
    dim_value: Optional[int]
    in_proj_container: Optional[InputProjection]
    use_separate_proj_weight: Optional[bool]
    use_rotary_embeddings: Optional[bool]
    out_proj: Optional[nn.Module]

    def __getitem__(self, item):
        # Allow dict-style access (config["num_heads"]), used by the builders
        return getattr(self, item)


# Move head forward and fold into batch dim. dimensions become (B * nh, S, hs)
def _fold_heads(t: torch.Tensor, B: int, S: int, H: int, Hs: int):
    return t.view(B, S, H, Hs).transpose(1, 2).flatten(start_dim=0, end_dim=1)


# Move head forward, keeping it as its own axis. dimensions become (B, nh, S, hs)
def _split_heads(t: torch.Tensor, B: int, S: int, H: int, Hs: int):
    return t.view(B, S, H, Hs).transpose(1, 2)


class MultiHeadDispatch(nn.Module):
    """
    A multi-head masked self-attention dispatch mechanism, with a projection at the end,
    following the architecture proposed in `Attention is all you need`_, Vaswani et al.

    The actual attention mechanism can vary, as well as the projections.
    This can be used to wrap the proposed attention mechanisms and make them multi-head aware,
    but it is optional.

    Args:
        dim_model: The model/embedding dimension
        num_heads: The number of heads being used
        attention: The attention mechanism (needs to be registered to the xformers library)
        bias: Whether to use bias for the projections : (Q, K, V, Output)
        residual_dropout: Amount of dropout on the residual path
        use_separate_proj_weight: Use different weights for the Q, K, V projections
        dim_key: Optionally use a different dimension for the key
        dim_value: Optionally use a different dimension for the value
        in_proj_container: Optionally provide the input projection module
        use_rotary_embeddings: Use rotary embeddings
        out_proj: Optionally provide the output projection module


    .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762v5
    """

    def __init__(
        self,
        dim_model: int,
        num_heads: int,
        attention: Attention,
        bias: Tuple[bool, bool, bool, bool] = (True, True, True, True),
        residual_dropout: float = 0.0,
        use_separate_proj_weight: bool = True,
        dim_key: Optional[int] = None,
        dim_value: Optional[int] = None,
        in_proj_container: Optional[InputProjection] = None,
        use_rotary_embeddings: Optional[bool] = False,
        out_proj: Optional[nn.Module] = None,
        *args,
        **kwargs,
    ):
        super().__init__()
        deprecated_function(self)

        if isinstance(bias, bool):
            # Backward compatibility: a lone bool is broadcast to all four projections
            logger.warning(
                "Single bias value provided for the MHA projections."
                + f" Assuming the same parameter ({bias}) is to be used everywhere"
            )
            bias = (bias, bias, bias, bias)

        assert (
            dim_model % num_heads == 0
        )  # static preset for now, each head works on 1/d the embeddings, could be relaxed
        assert num_heads > 0

        # Popular default is that all latent dimensions are the same
        dim_key, dim_value = map(lambda x: x if x else dim_model, (dim_key, dim_value))

        self.num_heads = num_heads
        self.dim_key_head = dim_key // num_heads
        self.dim_value_head = dim_value // num_heads
        self.dim_model = dim_model
        self.attention = attention

        # key, query, value projections for all heads
        # critical options are
        # - are we sharing weights ?
        # - are we adding biases ?
        # NOTE: only built when the wrapped attention asks for input projection;
        # otherwise q/k/v are passed through as-is in forward()
        if attention.requires_input_projection:
            self.in_proj_container = (
                in_proj_container
                if in_proj_container is not None
                else InputProjection(
                    query_proj_params=InputProjectionConfig(
                        dim_model, dim_key, bias=bias[0]
                    ),
                    key_proj_params=InputProjectionConfig(
                        dim_model, dim_key, bias=bias[1]
                    ),
                    value_proj_params=InputProjectionConfig(
                        dim_model, dim_value, bias=bias[2]
                    ),
                    use_separate_proj_weight=use_separate_proj_weight,
                )
            )

        # Optional rotary embeddings (applied per-head on q and k in forward)
        self.rotary_embeddings = (
            RotaryEmbedding(self.dim_key_head) if use_rotary_embeddings else None
        )

        # Regularization
        self.resid_drop = nn.Dropout(residual_dropout, inplace=False)

        # Output projection
        self.proj = (
            out_proj if out_proj else nn.Linear(dim_model, dim_model, bias=bias[3])
        )
        if isinstance(self.proj, nn.Linear) and self.proj.bias is not None:
            constant_(self.proj.bias, 0.0)

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        att_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Expected input dimensions are [batch size, sequence length, embed dim]
        Output dimensions are [batch size, sequence length, embed dim]
        """

        # Self-attention shortcut: missing key/value default to the query
        if key is None:
            key = query
        if value is None:
            value = query

        # Broadcast mismatched batch sizes up to the largest one
        # (expand creates views, no copy)
        if query.shape[0] != key.shape[0] or query.shape[0] != value.shape[0]:
            max_batch = max((query.shape[0], key.shape[0], value.shape[0]))
            query, key, value = map(
                lambda x: x.expand(max_batch, -1, -1), [query, key, value]
            )

        B, S_Q, _ = query.size()  # Batch x Sequence x Embedding (latent)
        _, S_K, _ = key.size()  # K, Q's sequence length could differ

        # Catch different query and key length but a causal attention
        if S_Q != S_K:
            assert (
                not self.attention.requires_same_k_q_dimensions
            ), "This attention mechanism requires query and key to have the same sequence (context) lengths"

            if hasattr(self.attention, "causal"):
                assert not self.attention.causal, (
                    "Causal attention is not supported when key and query have different sequence lengths.\n"
                    + "In that case causality is ill-determined. Please pad your sequences accordingly"
                )

        # Only forward the masks the wrapped attention declares support for
        kw_mask_args = {}
        if att_mask is not None:
            assert (
                self.attention.supports_attention_mask
            ), "This attention does not support attention masks"
            kw_mask_args["att_mask"] = att_mask

        if key_padding_mask is not None:
            assert (
                self.attention.supports_key_padding_mask
            ), "This attention does not support key padding masks"
            kw_mask_args["key_padding_mask"] = key_padding_mask

        # Some attentions handle multi-head internally: hand off untouched inputs
        if self.attention.requires_skip_multi_head:
            return self.attention(query, key, value, **kw_mask_args)

        # Calculate query, key, values for all heads in batch
        if self.attention.requires_input_projection:
            q, k, v = self.in_proj_container(query=query, key=key, value=value)
        else:
            k, q, v = key, query, value

        # Check the dimensions properly
        def check(t, name):
            assert (
                t.shape[2] % self.num_heads == 0
            ), f"the {name} embeddings need to be divisible by the number of heads"

        check(q, "projected query")
        check(v, "projected value")
        check(k, "projected key")

        # Optional: rotary embedding, add relative positioning information
        if self.rotary_embeddings:
            # rotary requires the head dimension
            q = _split_heads(q, B, S_Q, self.num_heads, self.dim_key_head)
            k = _split_heads(k, B, S_K, self.num_heads, self.dim_key_head)
            v = _split_heads(v, B, S_K, self.num_heads, self.dim_value_head)

            q, k = self.rotary_embeddings(q=q, k=k)

            if not self.attention.requires_head_dimension:
                # Collapse (B, nh) back together once rotary has been applied
                q, k, v = q.flatten(0, 1), k.flatten(0, 1), v.flatten(0, 1)

        else:
            # Reshape k/q/v to either expose the heads, or fold the head dimension into the batch
            reshape_fn = (
                _split_heads if self.attention.requires_head_dimension else _fold_heads
            )

            q = reshape_fn(q, B, S_Q, self.num_heads, self.dim_key_head)
            k = reshape_fn(k, B, S_K, self.num_heads, self.dim_key_head)
            v = reshape_fn(v, B, S_K, self.num_heads, self.dim_value_head)

        # Self-attend
        y = self.attention(q, k, v, **kw_mask_args)

        # Re-assemble all head outputs side by side
        y = (
            y.view(B, self.num_heads, S_Q, self.dim_value_head)
            .transpose(1, 2)
            .flatten(start_dim=2, end_dim=3)
        )

        # Output projection, dropout and good to go
        y = self.resid_drop(self.proj(y))

        # Return the same sequence size as the input
        return y

    @classmethod
    def from_config(cls, config: MultiHeadDispatchConfig):
        """Build an instance from a config, letting defaults fill any None field."""
        # Generate the class inputs from the config
        fields = asdict(config)

        # Skip all Nones so that default values are used
        fields = {k: v for k, v in fields.items() if v is not None}

        return cls(**fields)
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass
from enum import Enum

import torch

from xformers._deprecation_warning import deprecated_function


class PoolType(str, Enum):
    # Pooling operator used to trade spatial resolution for channels
    Conv2D = "CONV_2D"
    # ...
    # TODO: Support more cases ?


@dataclass
class PatchEmbeddingConfig:
    """
    The configuration for the patch embedding layer, which takes the raw token passed in
    and returns a pooled representation along a given embedding dimension.

    This typically trades the spatial (context length) representation with the embedding size

    This is canonicaly used by ViT, but other papers (like MetaFormer or other hierarchical transformers)
    propose a more general use case for this
    """

    in_channels: int
    out_channels: int
    kernel_size: int
    stride: int
    padding: int = 0
    pool_type: PoolType = PoolType.Conv2D


class ConditionalReshape(torch.nn.Module):
    # Turn a (B, HW, C) sequence back into a (B, C, H, W) image when needed;
    # 4D inputs pass through unchanged.
    def __init__(self) -> None:
        super().__init__()
        deprecated_function(self)

    def forward(self, x):
        if x.ndim == 3:
            B, HW, C = x.shape
            # NOTE: We're assuming a square sample here
            H = int(math.sqrt(HW))
            assert H * H == HW, f"{H, HW}"
            x = x.transpose(1, 2).reshape(B, C, H, H)

        return x


class PatchToSequence(torch.nn.Module):
    # Flatten a (B, C, H, W) map into a (B, HW, C) token sequence
    def __init__(self) -> None:
        super().__init__()
        deprecated_function(self)

    def forward(self, x):
        return x.flatten(2, 3).transpose(1, 2).contiguous()  # B HW C


def build_patch_embedding(config: PatchEmbeddingConfig):
    """Build a reshape -> pool -> flatten pipeline from a config (or dict)."""
    if not isinstance(config, PatchEmbeddingConfig):
        config = PatchEmbeddingConfig(**config)

    if config.pool_type == PoolType.Conv2D:
        pool = torch.nn.Conv2d(
            config.in_channels,
            config.out_channels,
            kernel_size=config.kernel_size,
            stride=config.stride,
            padding=config.padding,
        )
    else:
        raise NotImplementedError

    # The patch embedding supposes that the input really is 2D in essence
    # If this block is in the middle of a stack, we need to reshape
    return torch.nn.Sequential(ConditionalReshape(), pool, PatchToSequence())


# --- file boundary: xformers/components/positional_embedding/__init__.py ---
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from pathlib import Path
from typing import Any, Callable, Dict, Set, Union

from xformers.utils import (
    generate_matching_config,
    get_registry_decorator,
    import_all_modules,
)

from .base import PositionEmbedding, PositionEmbeddingConfig  # noqa

# CREDITS: Classy Vision registry mechanism

# name -> registered (constructor, config) entry, filled by the decorator below
POSITION_EMBEDDING_REGISTRY: Dict[str, Any] = {}
POSITION_EMBEDDING_CLASS_NAMES: Set[str] = set()


def build_positional_embedding(config: Union[Dict[str, Any], PositionEmbeddingConfig]):
    """Builds a position encoding from a config.

    This assumes a 'name' key in the config which is used to determine what
    attention class to instantiate. For instance, a config `{"name": "my_position_encoding",
    "foo": "bar"}` will find a class that was registered as "my_position_encoding"
    (see :func:`register_positional_embedding`) and call .from_config on it."""

    if not isinstance(config, PositionEmbeddingConfig):
        # Dict input: coerce it into the registered config dataclass
        config_instance = generate_matching_config(
            config, POSITION_EMBEDDING_REGISTRY[config["name"]].config
        )
    else:
        config_instance = config

    return POSITION_EMBEDDING_REGISTRY[config_instance.name].constructor.from_config(
        config_instance
    )


"""Registers a PositionEncoding subclass.

    This decorator allows xFormers to instantiate a subclass of PositionEncoding
    from a configuration file, even if the class itself is not part of the
    xFormers framework. To use it, apply this decorator to a `PositionEncoding`
    subclass, like this:

    .. code-block:: python

        @dataclass
        class MyConfig:
            ...

        @register_positional_embedding('my_encoding', MyConfig)
        class MyEncoding(PositionEncoding):
            ...

    To instantiate a position encoding from a configuration file, see :func:`build_positional_embedding`."""
register_positional_embedding: Callable[
    [str, Any], Callable[[Any], Any]
] = get_registry_decorator(
    POSITION_EMBEDDING_REGISTRY,
    POSITION_EMBEDDING_CLASS_NAMES,
    PositionEmbedding,
    PositionEmbeddingConfig,
)


# NOTE: these imports run the @register_positional_embedding decorators,
# so they must come after the registry is defined above
from .rotary import RotaryEmbedding  # noqa
from .sine import SinePositionalEmbedding  # type: ignore # noqa
from .vocab import VocabEmbedding  # noqa

__all__ = [
    "RotaryEmbedding",
    "SinePositionalEmbedding",
    "VocabEmbedding",
    "build_positional_embedding",
    "register_positional_embedding",
]

# automatically import any Python files in the directory
import_all_modules(
    str(Path(__file__).parent), "xformers.components.positional_embedding"
)
b/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e05f177d4bf7eb6b87859769b131fc6335cc7315 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d3a81bdc9fbda067ba68f75ade3dfab95158a77 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/base.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/param.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/param.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74aa66a361eaba4de3664f45f2596ebb712a6f9c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/param.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/rotary.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/rotary.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a7a163067163ca1179544dd345508d695b332a4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/rotary.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/sine.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/sine.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57a563ca770cd82c2be6f69cff91bed6d5afef57 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/sine.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/vocab.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/vocab.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45b5c29c8160d1392877d78f4559d646fcdc832e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/vocab.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/base.py b/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c998487660d257e522e21ac3133fca5e9d8db104 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/base.py @@ -0,0 +1,38 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
from abc import ABCMeta, abstractmethod
from dataclasses import asdict, dataclass
from typing import Type, TypeVar

import torch.nn as nn

from xformers._deprecation_warning import deprecated_function

Self = TypeVar("Self", bound="PositionEmbedding")


@dataclass
class PositionEmbeddingConfig:
    """Common configuration shared by every registered position embedding."""

    name: str
    dim_model: int
    seq_len: int


class PositionEmbedding(nn.Module, metaclass=ABCMeta):
    """Abstract base for position embeddings; subclasses implement __init__."""

    @abstractmethod
    def __init__(self, *args, **kwargs) -> None:
        super().__init__()
        deprecated_function(self)

    @classmethod
    def from_config(cls: Type[Self], config: PositionEmbeddingConfig) -> Self:
        """Instantiate from a config, dropping None fields so defaults apply."""
        kwargs = {
            key: value for key, value in asdict(config).items() if value is not None
        }
        return cls(**kwargs)
from dataclasses import dataclass

import torch

from xformers.components.positional_embedding import (
    PositionEmbedding,
    PositionEmbeddingConfig,
    register_positional_embedding,
)


@dataclass
class LearnablePositionalEmbeddingConfig(PositionEmbeddingConfig):
    name: str
    seq_len: int
    dim_model: int
    add_class_token: bool


@register_positional_embedding("learnable", LearnablePositionalEmbeddingConfig)
class LearnablePositionalEmbedding(PositionEmbedding):
    """A learned absolute position embedding, optionally prepending a class token."""

    def __init__(
        self, seq_len: int, dim_model: int, add_class_token: bool = False, *_, **__
    ):
        super().__init__()

        # One extra slot is reserved for the class token when requested.
        # 0.02 is BERT initialization
        n_positions = seq_len + int(add_class_token)
        self.pos_emb = torch.nn.Parameter(
            torch.randn(1, n_positions, dim_model) * 0.02
        )

        if add_class_token:
            self.class_token = torch.nn.Parameter(torch.zeros(dim_model))
        else:
            self.class_token = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.class_token is not None:
            # Prepend the (learned) class token, one copy per batch entry
            batch, dim = x.shape[0], self.pos_emb.shape[-1]
            cls_slot = torch.ones(batch, 1, dim, device=x.device) * self.class_token
            x = torch.cat([cls_slot, x], dim=1)

        if x.ndim == 2:
            # 2D inputs get a trailing feature axis so broadcasting works
            x = x.unsqueeze(-1)

        return x + self.pos_emb
# CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox
# NOTE: Almost the same right now, moving parts to Triton is the next step

from typing import Tuple

import torch


def rotate_half(x):
    # Split the last dim in two halves and swap them, negating the second
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)


@torch.jit.script
def apply_rotary_pos_emb(x, cos, sin):
    # NOTE: This could probably be moved to Triton

    # Handle a possible sequence length mismatch in between q and k:
    # trim the cached tables down to this tensor's sequence length
    seq = x.shape[-2]
    cos_t = cos[:, :, :seq, :]
    sin_t = sin[:, :, :seq, :]

    return (x * cos_t) + (rotate_half(x) * sin_t)


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox


    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
        (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
    """

    def __init__(self, dim_model: int, *_, **__):
        super().__init__()

        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
        self.register_buffer("inv_freq", inv_freq)

        # Lazily-built cos/sin tables, refreshed on length/device/dtype change
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=1):
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance).
        # The or-chain short-circuits so the first call never touches None.
        if (
            seq_len != self._seq_len_cached
            or self._cos_cached.device != x.device
            or self._cos_cached.dtype != x.dtype
        ):
            self._seq_len_cached = seq_len
            positions = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            angles = torch.einsum("i,j->ij", positions, self.inv_freq.to(x.dtype))
            table = torch.cat((angles, angles), dim=-1).to(x.device)

            self._cos_cached = table.cos()[None, None, :, :].to(x.dtype)
            self._sin_cached = table.sin()[None, None, :, :].to(x.dtype)

        return self._cos_cached, self._sin_cached

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Tables are sized from k; apply_rotary_pos_emb trims them for q
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
            k, seq_dimension=-2
        )

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
# Silence Mypy errors in this file.
# type: ignore

import math

import torch

from xformers.components.positional_embedding import (
    PositionEmbedding,
    PositionEmbeddingConfig,
    register_positional_embedding,
)


@register_positional_embedding("sine", PositionEmbeddingConfig)
class SinePositionalEmbedding(PositionEmbedding):
    """Fixed sinusoidal position embedding ("Attention is all you need")."""

    def __init__(self, dim_model: int, *args, **kwargs):
        super().__init__()
        self.dim_model = dim_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[1]

        # (seq_len, 1) positions broadcast against a (1, dim_model) frequency row
        positions = torch.arange(
            0, seq_len, device=x.device, dtype=torch.float32
        ).unsqueeze(1)
        dims = torch.arange(
            0, self.dim_model, device=x.device, dtype=torch.float32
        ).unsqueeze(0)
        inv_freq = torch.exp(-math.log(10000) * (2 * (dims // 2) / self.dim_model))

        table = positions * inv_freq  # (seq_len, dim_model)
        # Interleave: sin on even columns, cos on odd columns
        table[:, 0::2] = torch.sin(table[:, 0::2])
        table[:, 1::2] = torch.cos(table[:, 1::2])

        # 2D inputs get a trailing feature axis so broadcasting works
        features = x.unsqueeze(-1) if x.ndim == 2 else x

        return features + table.unsqueeze(0)
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from xformers.components.positional_embedding import (
    PositionEmbedding,
    PositionEmbeddingConfig,
    register_positional_embedding,
)


@dataclass
class VocabEmbeddingConfig(PositionEmbeddingConfig):
    vocab_size: int
    dropout: float


@register_positional_embedding("vocab", VocabEmbeddingConfig)
class VocabEmbedding(PositionEmbedding):
    """Token embedding plus a learned absolute position embedding, with dropout."""

    def __init__(
        self,
        dim_model: int,
        seq_len: int,
        vocab_size: int,
        dropout: float = 0.0,
        *args,
        **kwargs
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim_model = dim_model

        self.dropout = torch.nn.Dropout(p=dropout)
        self.position_embeddings = nn.Embedding(seq_len, self.dim_model)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)

        self.position_ids: Optional[torch.Tensor] = None

        self.init_weights()

    def init_weights(self, gain: float = 1.0):
        """(Re)initialize both tables with a BERT-style std of 0.02 * gain."""
        for table in (self.position_embeddings, self.word_embeddings):
            torch.nn.init.normal_(table.weight, std=0.02 * gain)

    def forward(self, x: torch.Tensor):
        batch, seq = x.shape[0], x.shape[1]

        # One row of position indices, repeated for every batch entry
        positions = torch.arange(seq, dtype=torch.long, device=x.device)
        positions = positions[None, :].repeat(batch, 1)

        embedded = self.word_embeddings(x) + self.position_embeddings(positions)
        return self.dropout(embedded)
from collections import namedtuple
from enum import Enum
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from xformers._deprecation_warning import deprecated_function


class ResidualNormStyle(str, Enum):
    """Support different residual path and norm styles.
    See "On Layer Normalization in the Transformer Architecture",
    Xiong et al., https://arxiv.org/pdf/2002.04745v1.pdf
    """

    Pre = "pre"
    Post = "post"
    DeepNorm = "deepnorm"


class NormalizationType(str, Enum):
    # Which normalization layer PreNorm/PostNorm should instantiate
    LayerNorm = "layernorm"
    Skip = "skip"
    # TODO: BatchNorm = "batchnorm"
    # TODO: GroupNorm = "groupnorm"


def get_normalization_layer(normalization_type: NormalizationType):
    # Returns the layer CLASS (not an instance); the caller instantiates it
    # with the normalized dimension.
    class Skip(nn.Module):
        # Identity stand-in so "no normalization" keeps the same call signature
        def __init__(self, *_, **__) -> None:
            super().__init__()
            deprecated_function(self)

        def forward(self, x: torch.Tensor, **_):
            return x

    return {
        NormalizationType.LayerNorm: nn.LayerNorm,
        NormalizationType.Skip: Skip,
    }[normalization_type]


class RequiresWrappedInputs:
    """Used to mark, through inheritance,
    the fact that this class will require inputs to be passed as a single list"""

    pass


# CREDITS: the following is inspired by FastAI's Transformer implementation
class Residual(nn.Module, RequiresWrappedInputs):
    """
    Object-oriented handling of the residual path

    This supports scaling of the residual path, as proposed by DeepNet_
    .. _DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf

    .. Note: the wrapped layers must accept all the inputs as a single list
    """

    def __init__(self, layer: nn.Module, scale: Optional[float] = None):
        super().__init__()
        deprecated_function(self)
        self.layer = layer
        self.scale = scale

        # PreNorm and PostNorm require all the tensors to be passed as a list
        self.wrap_inputs = isinstance(layer, RequiresWrappedInputs)

    def forward(self, inputs: List[torch.Tensor], **kwargs):
        # The residue is the first input; optionally scaled (DeepNorm alpha)
        if self.scale is not None:
            residue = inputs[0] * self.scale
        else:
            residue = inputs[0]

        if self.wrap_inputs:
            return residue + self.layer(inputs=inputs, **kwargs)

        else:
            return residue + self.layer(*inputs, **kwargs)


class PreNorm(nn.Module, RequiresWrappedInputs):
    """Adds a normalization before computing attention

    ..Note: If a list of inputs is passed, all of them get normalized"""

    def __init__(
        self,
        d_norm: int,
        sublayer: nn.Module,
        normalization: NormalizationType,
        use_triton: bool = True,
    ):

        super().__init__()
        deprecated_function(self)
        self.norm = get_normalization_layer(normalization)(d_norm)

        self.sublayer = sublayer
        self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs)

    def forward(self, inputs: List[torch.Tensor], **kwargs):
        assert len(inputs) > 0

        # Perf improvement: if the inputs are all the same, only norm once.
        # Identity (id) comparison is deliberate: self-attention passes the
        # same tensor object as q/k/v.
        ids = [id(x) for x in inputs]
        if ids.count(ids[0]) == len(ids):
            # The same tensor is passed multiple times
            x_norm = self.norm(inputs[0])
            inputs_normed = [x_norm for _ in inputs]
        else:
            # The inputs differ, norm them all
            inputs_normed = [self.norm(x_) for x_ in inputs]

        if self.wrap_inputs:
            return self.sublayer(inputs=inputs_normed, **kwargs)
        else:
            return self.sublayer(*inputs_normed, **kwargs)


class PostNorm(nn.Module, RequiresWrappedInputs):
    """Adds LayerNorm after computing attention"""

    def __init__(
        self,
        d_norm: int,
        sublayer: nn.Module,
        normalization: NormalizationType,
        use_triton: bool = True,
    ):
        super().__init__()
        deprecated_function(self)
        self.norm = get_normalization_layer(normalization)(d_norm)

        self.sublayer = sublayer
        self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs)

    def forward(self, inputs: List[torch.Tensor], **kwargs):
        # Run the sublayer first, then normalize its output
        if self.wrap_inputs:
            x = self.sublayer(inputs=inputs, **kwargs)
        else:
            x = self.sublayer(*inputs, **kwargs)
        return self.norm(x)


# (alpha, beta) scaling pair from the DeepNet paper
DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])


def get_deepnorm_coefficients(
    encoder_layers: int, decoder_layers: int
) -> Tuple[Optional[DeepNormCoefficients], Optional[DeepNormCoefficients]]:
    """
    See DeepNet_.

    Returns alpha and beta depending on the number of encoder and decoder layers,
    first tuple is for the encoder and second for the decoder

    .. _DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf
    """

    N = encoder_layers
    M = decoder_layers

    if decoder_layers == 0:
        # Encoder only
        return (
            DeepNormCoefficients(alpha=(2 * N) ** 0.25, beta=(8 * N) ** -0.25),
            None,
        )

    elif encoder_layers == 0:
        # Decoder only
        return None, DeepNormCoefficients(alpha=(2 * M) ** 0.25, beta=(12 * M) ** -0.25) if False else DeepNormCoefficients(alpha=(2 * M) ** 0.25, beta=(8 * M) ** -0.25)
    else:
        # Encoder/decoder
        encoder_coeffs = DeepNormCoefficients(
            alpha=0.81 * ((N**4) * M) ** 0.0625, beta=0.87 * ((N**4) * M) ** -0.0625
        )

        decoder_coeffs = DeepNormCoefficients(
            alpha=(3 * M) ** 0.25, beta=(12 * M) ** -0.25
        )

        return (encoder_coeffs, decoder_coeffs)
# pyre-fixme[13]: `cpu_state` is not initialized in the constructor.
class Deterministic(nn.Module):
    """Wraps a module and records the RNG state when asked, so the exact same
    stochastic forward pass can be replayed later (same mechanism as
    activation checkpointing)."""

    def __init__(self, net: nn.Module):
        super().__init__()
        deprecated_function(self)
        self.net = net
        self.cpu_state: torch.Tensor = torch.get_rng_state()
        self.cuda_in_fwd: bool = False
        self.gpu_devices: List[int] = []
        self.gpu_states: List[torch.Tensor] = []
        self.wrap_inputs = isinstance(net, RequiresWrappedInputs)

    def record_rng(self, *args):
        # Snapshot the CPU RNG state, plus the per-device CUDA RNG states
        # when CUDA has been initialized.
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng: bool = False, set_rng: bool = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            # Normal FW run
            if self.wrap_inputs:
                return self.net(inputs=args, **kwargs)
            return self.net(*args, **kwargs)

        # pragma: no cover -- called in the backward pass, not picked up.
        # Replay: restore the RNG state recorded at forward time so dropout
        # etc. produce bit-identical results (analogous to checkpointing).
        rng_devices: List[int] = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)

            if self.wrap_inputs:
                return self.net(inputs=args, **kwargs)
            return self.net(*args, **kwargs)


class ReversibleBlock(nn.Module):
    """One reversible layer pair: y1 = x1 + f(x2); y2 = x2 + g(y1).

    The inputs can be reconstructed from the outputs, so activations need not
    be stored for the backward pass (RevNet-style)."""

    def __init__(self, f: nn.Module, g: nn.Module, split_dim: int = -1):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)
        self.split_dim = split_dim

    def forward(self, x: torch.Tensor, f_args=None, g_args=None):
        # FIX: avoid mutable default arguments; behavior is unchanged for
        # callers relying on the old `{}` defaults.
        f_args = {} if f_args is None else f_args
        g_args = {} if g_args is None else g_args

        # FIX: chunk along self.split_dim -- the original hard-coded dim=-1
        # here while concatenating along self.split_dim below, which was
        # inconsistent whenever split_dim != -1 (identical for the default).
        x1, x2 = torch.chunk(x, 2, dim=self.split_dim)

        with torch.no_grad():
            # Forward is run without grad; gradients are recomputed in
            # backward_pass from the outputs instead.
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=self.split_dim)

    def backward_pass(
        self, y: torch.Tensor, dy: torch.Tensor, f_args=None, g_args=None
    ):  # pragma: no cover # this is covered, but called directly from C++
        """Reconstructs the block inputs from (y, dy) and returns (x, dx)."""
        f_args = {} if f_args is None else f_args
        g_args = {} if g_args is None else g_args

        y1, y2 = torch.chunk(y, 2, dim=self.split_dim)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=self.split_dim)
        del dy

        # Recompute g(y1) with the recorded RNG state to get its gradient
        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            x2 = y2 - gy1  # invert: y2 = x2 + g(y1)
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        # Recompute f(x2) likewise for its gradient contribution
        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1)

        with torch.no_grad():
            x1 = y1 - fx2  # invert: y1 = x1 + f(x2)
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=self.split_dim)
            dx = torch.cat([dx1, dx2], dim=self.split_dim)

        return x, dx


class _ReversibleFunction(Function):
    """Custom autograd Function: memory-light forward through a stack of
    reversible blocks; backward reconstructs activations block by block."""

    @staticmethod
    def forward(ctx, x, blocks, kwargs):
        ctx.kwargs = kwargs
        for block in blocks:
            x = block(x, **kwargs)
        # Only the final output is kept; intermediates are re-derived later.
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    @staticmethod
    def backward(
        ctx, dy
    ):  # pragma: no cover # this is covered, but called directly from C++
        y = ctx.y
        kwargs = ctx.kwargs
        # Walk the blocks in reverse, inverting each one as we go
        for block in ctx.blocks[::-1]:
            y, dy = block.backward_pass(y, dy, **kwargs)
        return dy, None, None


class ReversibleSequence(nn.Module):
    """A sequence of ReversibleBlocks built from (f, g) module pairs."""

    def __init__(self, blocks: nn.ModuleList):
        super().__init__()
        deprecated_function(self)

        # pyre-fixme[23]: Unable to unpack `torch.nn.Module` into 2 values.
        self.blocks = nn.ModuleList([ReversibleBlock(f, g) for f, g in blocks])

    def forward(self, x, arg_route=(True, False), **kwargs):
        # arg_route selects whether kwargs are forwarded to f, g, or both
        f_args, g_args = map(lambda route: kwargs if route else {}, arg_route)
        block_kwargs = {"f_args": f_args, "g_args": g_args}

        return _ReversibleFunction.apply(x, self.blocks, block_kwargs)


Self = TypeVar("Self", bound="SimplicialEmbedding")


@dataclass
class SimplicialEmbeddingConfig:
    # L: number of embedding chunks; temperature: softmax scaling factor
    L: int
    temperature: float


class SimplicialEmbedding(torch.nn.Module):
    """
    An implementation of the "Simplicial Embeddings"_, as proposed by Lavoie et. al

    Arguments:
        - L: the number of embedding chunks
        - temperature: optional scaling parameter for the softmax operation.
          A small (<1.) temperature will lead to a sparse representation (up to one-hot),
          while a large (>1.) temperature will make the vector more uniform

    _"Simplicial Embeddings": https://arxiv.org/pdf/2204.00616.pdf
    """

    def __init__(self, L: int, temperature: Optional[float] = None) -> None:
        super().__init__()
        deprecated_function(self)
        self.L = L
        self.temperature = temperature

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert (
            x.shape[-1] % self.L == 0
        ), f"The embedding dimension {x.shape[-1]} is not divisible by the chosen L parameter {self.L}"

        # Separate the input tensor into L chunks of size V each
        B, C, E = x.shape
        V = E // self.L

        Vs = x.reshape(B, C, self.L, V)

        # Softmax-normalize each chunk with the proposed temperature.
        # FIX: use out-of-place division -- the original `Vs /= temperature`
        # wrote through the reshape view and silently mutated the caller's
        # input tensor whenever reshape returned a view.
        if self.temperature is not None:
            Vs = Vs / self.temperature

        Vs = torch.nn.functional.softmax(Vs, dim=-1)

        # Concatenate the chunks back and return
        return Vs.reshape(B, C, E)

    @classmethod
    def from_config(cls: Type[Self], config: SimplicialEmbeddingConfig) -> Self:
        # Generate the class inputs from the config
        fields = asdict(config)

        return cls(**fields)
def masked_matmul(a, b, mask=None):
    """Computes `a @ b`, optionally masking the result.

    A boolean mask marks positions to keep (False entries are set to -inf);
    any other dtype is treated as an additive mask.
    """
    # Honor __torch_function__ overrides (e.g. sparse / custom tensor types)
    if torch.overrides.has_torch_function((a, b, mask)):
        return torch.overrides.handle_torch_function(
            masked_matmul, (a, b, mask), a, b, mask
        )

    att = a @ b

    if mask is None:
        return att

    if mask.dtype == torch.bool:
        # Boolean mask: False == ignore this position
        if mask.ndim == 2:
            # Broadcast a single 2D mask over the batch dimension
            mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
        att[~mask] = float("-inf")
        return att

    # Additive (float) mask
    att += mask
    return att
@triton.jit
def index_select_cat_fwd_kernel(
    output_ptr,  # *Pointer* to output tensor.
    source_ptr,  # *Pointer* to source tensor.
    index_ptr,  # *Pointer* to index tensor.
    num_indices,
    num_cols,
    stride0,  # Stride information of source tensor.
    stride1,
    BLOCK_SIZE_INDEX: tl.constexpr,  # Number of indices each program should process.
    BLOCK_SIZE_COL: tl.constexpr,  # Number of cols each program should process.
):
    # Gathers rows `index` of `source` and packs them densely at the front
    # of `output`. 2D launch grid: axis 0 over indices, axis 1 over columns.
    pid0 = tl.program_id(axis=0)
    pid1 = tl.program_id(axis=1)

    indices = pid0 * BLOCK_SIZE_INDEX + tl.arange(0, BLOCK_SIZE_INDEX)
    rows = tl.load(index_ptr + indices, mask=(indices < num_indices))
    cols = pid1 * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)

    # Read from the selected source rows...
    source_offsets = source_ptr + rows[:, None] * stride0 + cols[None, :] * stride1
    mask = (indices[:, None] < num_indices) & (cols[None, :] < num_cols)
    output = tl.load(source_offsets, mask=mask)

    # ...and write at position `indices` (dense), not `rows`.
    output_offsets = output_ptr + indices[:, None] * stride0 + cols[None, :] * stride1
    tl.store(output_offsets, output, mask=mask)


def index_select_cat_fwd(
    output: torch.Tensor,
    source: torch.Tensor,
    index: torch.Tensor,
):
    """Writes rows `index` of `source` densely into `output` and returns it.

    Raises ValueError on non-CUDA tensors, wrong ranks, or too many indices.
    """
    if not (source.is_cuda and index.is_cuda):
        raise ValueError("The index tensor and the source tensor must be of type CUDA!")

    if not source.ndim == 2:
        raise ValueError(f"Expected 2-dimensional tensor, got {source.ndim}.")
    if not index.ndim == 1:
        raise ValueError(f"Expected 1-dimensional tensor, got {index.ndim}.")

    num_rows, num_cols = source.shape
    num_indices = index.shape[0]

    # FIX: the original used `if not num_indices < num_rows`, which rejected
    # the legal num_indices == num_rows case (selecting every row) even
    # though the error message only forbids *exceeding* the row count.
    if num_indices > num_rows:
        raise ValueError(
            "The number of indices cannot exceed the number of rows in the source matrix."
        )

    stride0, stride1 = source.stride(0), source.stride(1)

    def grid(meta):
        return (
            triton.cdiv(num_indices, meta["BLOCK_SIZE_INDEX"]),
            triton.cdiv(num_cols, meta["BLOCK_SIZE_COL"]),
        )

    index_select_cat_fwd_kernel[grid](
        output,
        source,
        index,
        num_indices,
        num_cols,
        stride0,
        stride1,
        BLOCK_SIZE_INDEX=1,
        BLOCK_SIZE_COL=512,
    )

    return output


@triton.jit
def index_select_cat_bwd_kernel(
    grad_source_ptr,  # *Pointer* to grad_source tensor.
    index_ptr,  # *Pointer* to index tensor.
    grad_output_ptr,  # *Pointer* to grad_output tensor.
    num_rows,
    num_indices,
    num_cols,
    stride0,  # Stride information of input and source tensor.
    stride1,
    BLOCK_SIZE_INDEX: tl.constexpr,  # Number of indices each program should process.
    BLOCK_SIZE_COL: tl.constexpr,  # Number of cols each program should process.
):
    # Scatters grad_output rows back to positions `index` of grad_source.
    pid0 = tl.program_id(axis=0)
    pid1 = tl.program_id(axis=1)

    cols = pid1 * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)

    # load grad_output (dense rows)
    grad_output_indices = pid0 * BLOCK_SIZE_INDEX + tl.arange(0, BLOCK_SIZE_INDEX)
    grad_output_offsets = (
        grad_output_ptr
        + grad_output_indices[:, None] * stride0
        + cols[None, :] * stride1
    )
    grad_output_mask = (grad_output_indices[:, None] < num_indices) & (
        cols[None, :] < num_cols
    )
    grad_output = tl.load(grad_output_offsets, mask=grad_output_mask).to(tl.float32)

    # translate dense row positions back through the index tensor
    grad_source_indices = tl.load(
        index_ptr + grad_output_indices, mask=(grad_output_indices < num_indices)
    )
    grad_source_offsets = (
        grad_source_ptr
        + grad_source_indices[:, None] * stride0
        + cols[None, :] * stride1
    )

    tl.store(grad_source_offsets, grad_output, mask=grad_output_mask)


def index_select_cat_bwd(
    grad_source: torch.Tensor,
    index: torch.Tensor,
    grad_output: torch.Tensor,
):
    """Scatters `grad_output` rows into `grad_source` at positions `index`.

    Raises ValueError on non-CUDA tensors or shape/stride mismatches.
    """
    if not (grad_source.is_cuda and grad_output.is_cuda):
        raise ValueError("The grad_source and grad_output tensor must be of type CUDA!")

    if not (grad_source.ndim == 2 and grad_output.ndim == 2):
        # FIX: the original message claimed "three-dimensional" although the
        # check (correctly) requires two-dimensional tensors.
        raise ValueError(
            f"The grad_source and grad_output must be two-dimensional "
            f"(got {grad_source.ndim} and {grad_output.ndim})!"
        )
    if not grad_source.shape[1] == grad_output.shape[1]:
        raise ValueError(
            f"The number of elements along dimension 1 of grad_source and grad_output must be the same "
            f"(got {grad_source.shape[1]} and {grad_output.shape[1]})"
        )

    num_rows, num_cols = grad_source.shape
    num_indices, num_cols = grad_output.shape
    if not num_rows >= num_indices:
        raise ValueError(
            f"The number of elements along dimension 0 of grad_source must be larger than that of grad_output "
            f"(got {num_rows} and {num_indices})!"
        )
    if not index.shape[0] == num_indices:
        raise ValueError(
            f"The number of indices and the number of elements along dimension 0 of grad_output must match "
            f"(got {index.shape[0]} and {num_indices})!"
        )

    stride0, stride1 = grad_source.stride(0), grad_source.stride(1)
    if not (grad_output.stride(0) == stride0 and grad_output.stride(1) == stride1):
        raise ValueError(
            f"The strides of the grad_source and grad_output tensors must match "
            f"(got {stride0} vs. {grad_output.stride(0)}, {stride1} vs. {grad_output.stride(1)})!"
        )

    def grid(meta):
        return (
            triton.cdiv(num_indices, meta["BLOCK_SIZE_INDEX"]),
            triton.cdiv(num_cols, meta["BLOCK_SIZE_COL"]),
        )

    index_select_cat_bwd_kernel[grid](
        grad_source,
        index,
        grad_output,
        num_rows,
        num_indices,
        num_cols,
        grad_source.stride(0),
        grad_source.stride(1),
        BLOCK_SIZE_INDEX=1,
        BLOCK_SIZE_COL=512,
    )

    return
@triton.jit
def scaled_index_add_fwd_kernel(
    input_ptr,  # *Pointer* to input tensor.
    index_ptr,  # *Pointer* to index tensor.
    source_ptr,  # *Pointer* to source tensor.
    scaling_ptr,  # *Pointer* to the scaling tensor.
    alpha,
    num_inp_indices,
    num_src_indices,
    num_rows,
    num_cols,
    stride0,  # Stride information of input and source tensor.
    stride1,
    stride2,
    BLOCK_SIZE_INDEX: tl.constexpr,  # Number of indices each program should process.
    BLOCK_SIZE_ROW: tl.constexpr,  # Number of rows each program should process.
    BLOCK_SIZE_COL: tl.constexpr,  # Number of cols each program should process.
    HAS_SCALING: tl.constexpr,  # Boolean indicating if the scaling factor is present.
):
    # In-place: x[index[i]] += alpha * (scaling *) source[i]. 3D launch grid.
    pid0 = tl.program_id(axis=0)
    pid1 = tl.program_id(axis=1)
    pid2 = tl.program_id(axis=2)

    rows = pid1 * BLOCK_SIZE_ROW + tl.arange(0, BLOCK_SIZE_ROW)
    cols = pid2 * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)

    # load source
    source_indices = pid0 * BLOCK_SIZE_INDEX + tl.arange(0, BLOCK_SIZE_INDEX)
    source_offsets = (
        source_ptr
        + source_indices[:, None, None] * stride0
        + rows[None, :, None] * stride1
        + cols[None, None, :] * stride2
    )
    source_mask = (
        (source_indices[:, None, None] < num_src_indices)
        & (rows[None, :, None] < num_rows)
        & (cols[None, None, :] < num_cols)
    )
    source = tl.load(source_offsets, mask=source_mask).to(tl.float32)

    # load input rows selected through the index tensor
    input_indices = tl.load(
        index_ptr + source_indices, mask=(source_indices < num_src_indices)
    )
    input_offsets = (
        input_ptr
        + input_indices[:, None, None] * stride0
        + rows[None, :, None] * stride1
        + cols[None, None, :] * stride2
    )
    x = tl.load(input_offsets, mask=source_mask).to(tl.float32)

    # compute scaled index add and save (accumulation done in float32)
    if HAS_SCALING:
        scaling = tl.load(
            scaling_ptr + cols[None, None, :] * stride2,
            mask=(cols[None, None, :] < num_cols),
        ).to(tl.float32)
        tl.store(input_offsets, x + alpha * scaling * source, mask=source_mask)
    else:
        tl.store(input_offsets, x + alpha * source, mask=source_mask)


def scaled_index_add_fwd(
    x: torch.Tensor,
    index: torch.Tensor,
    source: torch.Tensor,
    scaling: Optional[torch.Tensor],
    alpha: float,
):
    """In-place x[index] += alpha * (scaling *) source for 3D CUDA tensors.

    Validates devices, ranks, shapes and strides, then launches the Triton
    kernel. Raises ValueError on any mismatch.
    """
    if not (x.is_cuda and index.is_cuda and source.is_cuda):
        raise ValueError(
            "The input tensor, the index tensor and the source tensor must be of type CUDA!"
        )

    if not (x.ndim == 3 and source.ndim == 3):
        raise ValueError(
            f"The input and source must be three-dimensional (got {x.ndim} and {source.ndim})!"
        )
    if not x.shape[1] == source.shape[1]:
        raise ValueError(
            f"The number of elements along dimension 1 of the input and source must be the same "
            f"(got {x.shape[1], } and {source.shape[1], })!"
        )
    if not x.shape[2] == source.shape[2]:
        raise ValueError(
            f"The number of elements along dimension 2 of the input and source must be the same "
            f"(got {x.shape[2], } and {source.shape[2], })!"
        )

    num_inp_indices, num_rows, num_cols = x.shape
    num_src_indices, num_rows, num_cols = source.shape
    if not num_inp_indices >= num_src_indices:
        raise ValueError(
            f"The number of elements along dimension 0 of the input must be larger than that of source "
            f"(got {num_inp_indices} and {num_src_indices})!"
        )
    if not index.shape[0] == num_src_indices:
        raise ValueError(
            f"The number of indices and source tensors must match (got {len(index)} and {len(source)})!"
        )

    # The kernel indexes x and source with the same strides, so they must agree
    stride0, stride1, stride2 = x.stride(0), x.stride(1), x.stride(2)
    if not (
        source.stride(0) == stride0
        and source.stride(1) == stride1
        and source.stride(2) == stride2
    ):
        raise ValueError(
            f"The strides of the source and input tensors must match (got {source.stride(0)} vs. {stride0}, "
            f"{source.stride(1)} vs. {stride1}, {source.stride(2)} vs. {stride2})!"
        )

    if scaling is None:
        HAS_SCALING = False
    else:
        HAS_SCALING = True
        if not scaling.is_cuda:
            raise ValueError("The scaling tensor must be of type CUDA!")
        if not (scaling.ndim == 1 and scaling.numel() == num_cols):
            raise ValueError(
                f"The scaling tensor must be a 1-dimensional tensor (got {scaling.ndim}) and its size "
                f"must be equal to the size of dimension 2 of source (got {scaling.numel()} vs. {num_cols})."
            )
        if not scaling.stride(0) == stride2:
            raise ValueError(
                f"The stride of scaling must match the stride2 of input (got {scaling.stride(0)} vs. {stride2})"
            )

    if not index.ndim == 1:
        raise ValueError(f"The index must be one-dimensional (got {index.ndim})!")

    def grid(meta):
        return (
            triton.cdiv(num_src_indices, meta["BLOCK_SIZE_INDEX"]),
            triton.cdiv(num_rows, meta["BLOCK_SIZE_ROW"]),
            triton.cdiv(num_cols, meta["BLOCK_SIZE_COL"]),
        )

    scaled_index_add_fwd_kernel[grid](
        x,
        index,
        source,
        scaling,
        alpha,
        num_inp_indices,
        num_src_indices,
        num_rows,
        num_cols,
        x.stride(0),
        x.stride(1),
        x.stride(2),
        BLOCK_SIZE_INDEX=1,
        BLOCK_SIZE_ROW=1,
        BLOCK_SIZE_COL=512,
    )

    return


@triton.jit
def scaled_index_add_bwd_kernel(
    grad_output_ptr,  # *Pointer* to grad_output tensor.
    grad_source_ptr,  # *Pointer* to grad_source tensor.
    grad_scaling_ptr,  # *Pointer* to grad_scaling tensor.
    source_ptr,  # *Pointer* to the source tensor.
    scaling_ptr,  # *Pointer* to the scaling tensor.
    index_ptr,
    alpha,
    num_inp_indices,
    num_src_indices,
    num_rows,
    num_cols,
    stride0,  # Stride information of input and source tensor.
    stride1,
    stride2,
    BLOCK_SIZE_INDEX: tl.constexpr,  # Number of indices each program should process.
    BLOCK_SIZE_ROW: tl.constexpr,  # Number of rows each program should process.
    BLOCK_SIZE_COL: tl.constexpr,  # Number of cols each program should process.
    HAS_SCALING: tl.constexpr,  # Boolean indicating if the scaling factor is present.
):
    # Gradients of scaled_index_add: grad_source = alpha * scaling * dy[index],
    # grad_scaling (per element, reduced later) = alpha * dy[index] * source.
    pid0 = tl.program_id(axis=0)
    pid1 = tl.program_id(axis=1)
    pid2 = tl.program_id(axis=2)

    rows = pid1 * BLOCK_SIZE_ROW + tl.arange(0, BLOCK_SIZE_ROW)
    cols = pid2 * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)

    # load source
    source_indices = pid0 * BLOCK_SIZE_INDEX + tl.arange(0, BLOCK_SIZE_INDEX)
    source_offsets = (
        source_ptr
        + source_indices[:, None, None] * stride0
        + rows[None, :, None] * stride1
        + cols[None, None, :] * stride2
    )
    source_mask = (
        (source_indices[:, None, None] < num_src_indices)
        & (rows[None, :, None] < num_rows)
        & (cols[None, None, :] < num_cols)
    )
    source = tl.load(source_offsets, mask=source_mask).to(tl.float32)

    # load grad_output at the indexed rows
    grad_output_indices = tl.load(
        index_ptr + source_indices, mask=(source_indices < num_src_indices)
    )
    # FIX: broadcast the index term over (index, row, col) axes like every
    # other offset computation in this file. The original used the bare 1D
    # `grad_output_indices * stride0`, which only produced correct offsets
    # because BLOCK_SIZE_INDEX is always 1 at the launch sites below.
    grad_output_offsets = (
        grad_output_ptr
        + grad_output_indices[:, None, None] * stride0
        + rows[None, :, None] * stride1
        + cols[None, None, :] * stride2
    )
    grad_output = tl.load(grad_output_offsets, mask=source_mask).to(tl.float32)

    # compute gradient
    grad_source_offsets = (
        grad_source_ptr
        + source_indices[:, None, None] * stride0
        + rows[None, :, None] * stride1
        + cols[None, None, :] * stride2
    )
    if HAS_SCALING:
        scaling = tl.load(
            scaling_ptr + cols[None, None, :] * stride2,
            mask=(cols[None, None, :] < num_cols),
        ).to(tl.float32)

        tl.store(grad_source_offsets, alpha * grad_output * scaling, mask=source_mask)

        grad_scaling_offsets = (
            grad_scaling_ptr
            + source_indices[:, None, None] * stride0
            + rows[None, :, None] * stride1
            + cols[None, None, :] * stride2
        )
        tl.store(grad_scaling_offsets, alpha * grad_output * source, mask=source_mask)
    else:
        tl.store(grad_source_offsets, alpha * grad_output, mask=source_mask)


def scaled_index_add_bwd(
    grad_output: torch.Tensor,
    grad_source: torch.Tensor,
    grad_scaling: Optional[torch.Tensor],
    source: torch.Tensor,
    scaling: Optional[torch.Tensor],
    index: torch.Tensor,
    alpha: float,
):
    """Backward of scaled_index_add: fills grad_source (and grad_scaling).

    Validates devices, ranks, shapes and strides, then launches the Triton
    kernel. Raises ValueError on any mismatch.
    """
    if not (grad_output.is_cuda and grad_source.is_cuda):
        raise ValueError(
            "The grad_output tensor and grad_source tensor must be of type CUDA!"
        )

    if not (grad_output.ndim == 3 and source.ndim == 3):
        raise ValueError(
            f"The input and source must be three-dimensional (got {grad_output.ndim} and {source.ndim})!"
        )

    if not grad_output.shape[1] == source.shape[1]:
        raise ValueError(
            f"The number of elements along dimension 1 of the input and source must be the same "
            f"(got {grad_output.shape[1], } and {source.shape[1], })!"
        )
    if not grad_output.shape[2] == source.shape[2]:
        raise ValueError(
            f"The number of elements along dimension 2 of the input and source must be the same "
            f"(got {grad_output.shape[2], } and {source.shape[2], })!"
        )

    num_inp_indices, num_rows, num_cols = grad_output.shape
    num_src_indices, num_rows, num_cols = source.shape
    if not num_inp_indices >= num_src_indices:
        raise ValueError(
            f"The number of elements along dimension 0 of the input must be larger than that of source "
            f"(got {num_inp_indices} and {num_src_indices})!"
        )

    # All tensors are addressed with the strides of `source` in the kernel
    stride0, stride1, stride2 = source.stride(0), source.stride(1), source.stride(2)
    if not (
        grad_output.stride(0) == stride0
        and grad_output.stride(1) == stride1
        and grad_output.stride(2) == stride2
    ):
        raise ValueError(
            f"The strides of grad_output and source must match "
            f"(got {grad_output.stride(0)} vs {stride0}, {grad_output.stride(1)} vs {stride1}, "
            f"{grad_output.stride(2)} vs {stride2})!"
        )
    if not (
        grad_source.stride(0) == stride0
        and grad_source.stride(1) == stride1
        and grad_source.stride(2) == stride2
    ):
        raise ValueError(
            f"The strides of grad_source and source must match "
            f"(got {grad_source.stride(0)} vs {stride0}, {grad_source.stride(1)} vs {stride1}, "
            f"{grad_source.stride(2)} vs {stride2})!"
        )

    if scaling is not None and grad_scaling is not None:
        HAS_SCALING = True
        if not grad_scaling.is_cuda:
            raise ValueError("The scaling tensor must be of type CUDA!")
        if not (
            grad_scaling.stride(0) == stride0
            and grad_scaling.stride(1) == stride1
            and grad_scaling.stride(2) == stride2
        ):
            raise ValueError(
                f"The strides of grad_scaling and source must match "
                f"(got {grad_scaling.stride(0)} vs {stride0}, {grad_scaling.stride(1)} vs {stride1}, "
                f"{grad_scaling.stride(2)} vs {stride2})!"
            )
        if not scaling.stride(0) == stride2:
            raise ValueError(
                f"The stride of scaling must match stride2 of source (got {scaling.stride(0)} vs. {stride2})!"
            )
    else:
        HAS_SCALING = False

    def grid(meta):
        return (
            triton.cdiv(num_src_indices, meta["BLOCK_SIZE_INDEX"]),
            triton.cdiv(num_rows, meta["BLOCK_SIZE_ROW"]),
            triton.cdiv(num_cols, meta["BLOCK_SIZE_COL"]),
        )

    scaled_index_add_bwd_kernel[grid](
        grad_output,
        grad_source,
        grad_scaling,
        source,
        scaling,
        index,
        alpha,
        num_inp_indices,
        num_src_indices,
        num_rows,
        num_cols,
        stride0,
        stride1,
        stride2,
        BLOCK_SIZE_INDEX=1,
        BLOCK_SIZE_ROW=1,
        BLOCK_SIZE_COL=512,
    )

    return
@triton.jit
def _rms_norm_kernel(
    x_ptr,
    h1_ptr,
    w_ptr,
    eps,
    stride,
    N_COLS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    INCLUDE_WEIGHT: tl.constexpr,
):
    # One program per row: h1 = x * rsqrt(mean(x^2) + eps) (* w).
    # int64 row index so row * stride cannot overflow 32 bits on large inputs.
    row = tl.program_id(0).to(tl.int64)
    x_ptr += row * stride
    h1_ptr += row * stride

    # First pass: accumulate sum of squares in float32
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for offset in range(0, N_COLS, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        a = tl.load(
            x_ptr + cols, mask=cols < N_COLS, other=0.0, eviction_policy="evict_last"
        ).to(tl.float32)
        _mean += a * a
    rstd = rsqrt((tl.sum(_mean, axis=0) / N_COLS) + eps)
    # Second pass: scale and store
    for offset in range(0, N_COLS, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < N_COLS
        a = tl.load(
            x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
        ).to(tl.float32)
        if INCLUDE_WEIGHT:
            w = tl.load(w_ptr + cols, mask=mask)
            tl.store(h1_ptr + cols, a * rstd * w, mask=mask)
        else:
            tl.store(h1_ptr + cols, a * rstd, mask=mask)


@triton.jit
def _rms_norm_add_kernel(
    x_ptr,
    y_ptr,
    h1_ptr,
    w_ptr,
    eps,
    stride,
    N_COLS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    INCLUDE_WEIGHT: tl.constexpr,
):
    # Fused residual-add + RMS norm: x += y (in place), h1 = rmsnorm(x) (* w).
    # FIX: cast the row index to int64 like _rms_norm_kernel does -- without
    # it, row * stride is computed in 32 bits and can overflow for large
    # tensors.
    row = tl.program_id(0).to(tl.int64)
    x_ptr += row * stride
    y_ptr += row * stride
    h1_ptr += row * stride

    # First pass: a = x + y (stored back into x) while accumulating sum(a^2)
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for offset in range(0, N_COLS, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < N_COLS
        ax = tl.load(
            x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_last"
        ).to(tl.float32)
        ay = tl.load(
            y_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
        ).to(tl.float32)
        a = ax + ay
        tl.store(x_ptr + cols, a, mask=mask)
        _mean += a * a
    rstd = rsqrt((tl.sum(_mean, axis=0) / N_COLS) + eps)
    # Second pass: scale and store
    for offset in range(0, N_COLS, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < N_COLS
        a = tl.load(
            x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
        ).to(tl.float32)
        if INCLUDE_WEIGHT:
            w = tl.load(w_ptr + cols, mask=mask)
            tl.store(h1_ptr + cols, a * rstd * w, mask=mask)
        else:
            tl.store(h1_ptr + cols, a * rstd, mask=mask)


def _rms_norm_forward(x, attn_norm_weights, eps):
    """RMS-normalizes `x` over its last dimension, optionally scaling by
    `attn_norm_weights`. Returns a new tensor of the same shape."""
    if not x.is_contiguous():
        raise ValueError("data must be contiguous")
    if attn_norm_weights is not None:
        if not attn_norm_weights.is_contiguous():
            raise ValueError("weights must be contiguous")
    out = torch.empty_like(x)
    x_arg = x.reshape(-1, x.shape[-1])
    M, N = x_arg.shape
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    BLOCK_SIZE = max(BLOCK_SIZE, 128)
    BLOCK_SIZE = min(BLOCK_SIZE, 8192)
    # heuristics for number of warps
    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
    with torch.cuda.device(x.device):
        _rms_norm_kernel[(M,)](
            x_arg,
            out,
            attn_norm_weights,
            eps,
            x_arg.stride(0),
            N,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
            INCLUDE_WEIGHT=attn_norm_weights is not None,
        )
    return out


def _rms_norm_add_forward(x, y, attn_norm_weights, eps):
    """Fused `x += y` (in place) followed by RMS norm over the last dim.

    x, y: contiguous tensors of the same shape [..., n]. Returns the normed
    output of the same shape; note that x is mutated to hold x + y.
    """
    if not x.is_contiguous():
        raise ValueError("x must be contiguous")
    if not y.is_contiguous():
        raise ValueError("y must be contiguous")
    if attn_norm_weights is not None:
        if not attn_norm_weights.is_contiguous():
            raise ValueError("weights must be contiguous")
    out = torch.empty_like(x)
    x_arg = x.reshape(-1, x.shape[-1])
    y_arg = y.reshape(-1, x.shape[-1])
    M, N = x_arg.shape
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    BLOCK_SIZE = max(BLOCK_SIZE, 128)
    BLOCK_SIZE = min(BLOCK_SIZE, 8192)
    # heuristics for number of warps
    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
    with torch.cuda.device(x.device):
        _rms_norm_add_kernel[(M,)](
            x_arg,
            y_arg,
            out,
            attn_norm_weights,
            eps,
            x_arg.stride(0),
            N,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
            INCLUDE_WEIGHT=attn_norm_weights is not None,
        )
    return out
+import triton # type: ignore +import triton.language as tl # type: ignore + +try: + from triton.language.extra.cuda.libdevice import pow +except ImportError: + try: + from triton.language.math import pow + except ImportError: + from triton.language.libdevice import pow + + +@triton.jit +def _rope_padded_kernel( + xq, + xk, + xv, + out_q, + cache_k, + cache_v, + seqstartq, + seqstartk, + seqlenk, + theta, + linear_scale, + use_dynamic_scaling: tl.constexpr, + dynamic_old_context_len: tl.constexpr, + dynamic_scale_factor: tl.constexpr, + dynamic_low_freq_factor: tl.constexpr, + dynamic_high_freq_factor: tl.constexpr, + first_seqpos, + seqpos, + k_start: tl.constexpr, + v_start: tl.constexpr, + n_groups, + dim: tl.constexpr, # dimension of each head + stride_xqM, + stride_xqG, + stride_xqH, + stride_xkM, + stride_xkG, + stride_xkH, + stride_xvM, + stride_xvG, + stride_xvH, + stride_cachekM, + stride_cachekG, + stride_cachekH, + stride_cachevM, + stride_cachevG, + stride_cachevH, + stride_seqstartq, + stride_seqstartk, + stride_seqlenk, + stride_outqM, + stride_outqG, + stride_outqH, + stride_seqpos, + internal_dtype: tl.constexpr, + # If True, seqstartq and seqstartk are not used but rather we + # assume that every batch element has the same number of + # queries (i.e. num_queries := tl.num_programs(1) ) + # and the same cache space cache_padding_length. + # Always False when called below. + const_batch_strides: tl.constexpr, + # If const_batch_strides==True, the common cache length for each batch element. + # (Only the first seqlenk[i] elements are actually in use, and only the last + # num_queries of those are actually written to.) + cache_padding_length, + # offset added to all values in seqlenk before using them. + # Always 0 when called below. + seqlenk_shift: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + adjacents: tl.constexpr, +): + """ + Each letter in this diagram is a whole row of length dim. 
+ + INPUT xq xk xv + + head_dim ─► + + batch qqqqqq kk vv + │ qqqqqq kk vv + ▼ qqqqqq kk vv + + head_idx: (goes across all heads of all 3 inputs) + ▲ ▲ ▲ ▲ ▲ ▲ + │ │ │ │ │ │ + │ │ + 0 k_start │v_start │n_total_heads + │ │ + │ │ + k_start v_start + + Output is to out_q (same shape as xq), an xk-shaped part + of cache_k and an xv-shaped part of cache_v + """ + query_pos_in_batch_elt = tl.program_id(0) + batch_elt = tl.program_id(1) + group_head_idx = tl.program_id(2) + group_idx = group_head_idx % n_groups + head_idx = group_head_idx // n_groups + + if internal_dtype == "f32": + theta = theta.to(tl.float32) + elif internal_dtype == "f64": + theta = theta.to(tl.float64) + + if const_batch_strides: + query_pos = query_pos_in_batch_elt + tl.num_programs(1) * batch_elt + end_query_pos = tl.num_programs(1) * (batch_elt + 1) + else: + query_pos = query_pos_in_batch_elt + tl.load( + seqstartq + batch_elt * stride_seqstartq + ) + end_query_pos = tl.load(seqstartq + (batch_elt + 1) * stride_seqstartq) + if query_pos >= end_query_pos: + return + + is_q = head_idx < k_start + is_v = head_idx >= v_start + + xq += query_pos * stride_xqM + head_idx * stride_xqH + group_idx * stride_xqG + out_q += ( + query_pos * stride_outqM + head_idx * stride_outqH + group_idx * stride_outqG + ) + + if const_batch_strides: + cache_start = cache_padding_length * batch_elt + else: + cache_start = tl.load(seqstartk + batch_elt * stride_seqstartk) + end_of_batch_elt_cache = ( + cache_start + tl.load(seqlenk + batch_elt * stride_seqlenk) + seqlenk_shift + ) + + cache_pos = end_of_batch_elt_cache - (end_query_pos - query_pos) + if seqpos is not None: + seq_pos = tl.load(seqpos + query_pos * stride_seqpos) + else: + seq_pos = cache_pos - cache_start + if first_seqpos is not None: + seq_pos += tl.load(first_seqpos + batch_elt * stride_seqpos) + cache_k += ( + (head_idx - k_start) * stride_cachekH + + cache_pos * stride_cachekM + + group_idx * stride_cachekG + ) + xk += ( + query_pos * stride_xkM + + 
(head_idx - k_start) * stride_xkH + + group_idx * stride_xkG + ) + in_qk = tl.where(is_q, xq, xk) + out_qk = tl.where(is_q, out_q, cache_k) + + cache_v += ( + (head_idx - v_start) * stride_cachevH + + cache_pos * stride_cachevM + + group_idx * stride_cachevG + ) + xv += ( + query_pos * stride_xvM + + (head_idx - v_start) * stride_xvH + + group_idx * stride_xvG + ) + + out = tl.where(is_v, cache_v, out_qk) + x_in = tl.where(is_v, xv, in_qk) + + for offset in range(0, dim // 2, BLOCK_SIZE // 2): + c = tl.arange(0, BLOCK_SIZE // 2) + powers = (offset + c) * 2.0 + if adjacents: + cols_re = (offset + c) * 2 + cols_im = cols_re + 1 + else: + cols_re = offset + c + cols_im = cols_re + dim // 2 + + mask = cols_im < dim + + re_x = tl.load(x_in + cols_re, mask=mask) + im_x = tl.load(x_in + cols_im, mask=mask) + # freqs = seq_pos / (theta ** (powers / dim)) + freqs = pow(theta, powers / (-dim)) + + if use_dynamic_scaling: + lo_freq_wavelen = dynamic_old_context_len / dynamic_low_freq_factor + hi_freq_wavelen = dynamic_old_context_len / dynamic_high_freq_factor + + wavelens = 6.28318530718 / freqs # 2*pi + is_low_freq = wavelens > lo_freq_wavelen + freqs = tl.where(is_low_freq, freqs / dynamic_scale_factor, freqs) + + is_mid_freq = hi_freq_wavelen <= wavelens and wavelens <= lo_freq_wavelen + + smooth = (dynamic_old_context_len / wavelens - dynamic_low_freq_factor) / ( + dynamic_high_freq_factor - dynamic_low_freq_factor + ) + freqs = tl.where( + is_mid_freq, + (1 - smooth) * freqs / dynamic_scale_factor + smooth * freqs, + freqs, + ) + + freqs = seq_pos * freqs / linear_scale + sines = tl.sin(freqs) + cosines = tl.cos(freqs) + re_out = re_x * cosines - im_x * sines + im_out = im_x * cosines + re_x * sines + + re_out_ = tl.where(is_v, re_x, re_out) + im_out_ = tl.where(is_v, im_x, im_out) + if internal_dtype == "f64": + if re_x.dtype == tl.bfloat16: + # triton 2.0.0 crashes if you try to convert + # float64 directly to bfloat16, so make an intermediate step. 
+ re_out_ = re_out_.to(tl.float32) + im_out_ = im_out_.to(tl.float32) + tl.store(out + cols_re, re_out_, mask=mask) + tl.store(out + cols_im, im_out_, mask=mask) diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/_triton/tiled_matmul_kernels.py b/.venv/lib/python3.11/site-packages/xformers/ops/_triton/tiled_matmul_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..2453f141215afae6299e2bf75b811785775d1178 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/_triton/tiled_matmul_kernels.py @@ -0,0 +1,430 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import itertools +from typing import List, Tuple + +import torch +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + +def init_to_zero(*names): + def result(nargs): + for name in names: + nargs[name].zero_() + + return result + + +def gen_config( + block_m: int, + block_n: int, + block_k: int, + stages: int, + warps: int, + split_k: int = 1, + group_m: int = 8, +) -> triton.Config: + """A more compact way to define a triton.Config, so it fits on one line""" + + return triton.Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": split_k, + "GROUP_M": group_m, + }, + num_stages=stages, + num_warps=warps, + pre_hook=init_to_zero(*[f"C{i+1}{j+1}" for i in range(3) for j in range(3)]) + if split_k > 1 + else init_to_zero(), + ) + + +BASIC_MATMUL_CONFIGS = [ + gen_config(block_m=128, block_n=256, block_k=32, stages=3, warps=8), + gen_config(block_m=256, block_n=128, block_k=32, stages=3, warps=8), + gen_config(block_m=256, block_n=64, block_k=32, stages=4, warps=4), + gen_config(block_m=64, block_n=256, block_k=32, stages=4, warps=4), + gen_config(block_m=128, block_n=128, block_k=32, stages=4, 
warps=4), + gen_config(block_m=128, block_n=64, block_k=32, stages=4, warps=4), + gen_config(block_m=64, block_n=128, block_k=32, stages=4, warps=4), + gen_config(block_m=128, block_n=32, block_k=32, stages=4, warps=4), + gen_config(block_m=64, block_n=32, block_k=32, stages=5, warps=2), +] + + +INT8_MATMUL_CONFIGS = [ + gen_config(block_m=128, block_n=256, block_k=128, stages=3, warps=8), + gen_config(block_m=256, block_n=128, block_k=128, stages=3, warps=8), + gen_config(block_m=256, block_n=64, block_k=128, stages=4, warps=4), + gen_config(block_m=64, block_n=256, block_k=128, stages=4, warps=4), + gen_config(block_m=128, block_n=128, block_k=128, stages=4, warps=4), + gen_config(block_m=128, block_n=64, block_k=64, stages=4, warps=4), + gen_config(block_m=64, block_n=128, block_k=64, stages=4, warps=4), + gen_config(block_m=128, block_n=32, block_k=64, stages=4, warps=4), + gen_config(block_m=64, block_n=32, block_k=64, stages=5, warps=2), +] + + +IO_BOUND_MATMUL_CONFIGS_STAGES = [2, 3, 4, 5, 6] +IO_BOUND_MATMUL_CONFIGS_BLOCK_M = [16, 32] +IO_BOUND_MATMUL_CONFIGS_BLOCK_K = [32, 64] +IO_BOUND_MATMUL_CONFIGS_BLOCK_N = [32, 64, 128, 256] +IO_BOUND_MATMUL_CONFIGS_SPLIT_K = [1, 2, 4, 8, 16] + + +IO_BOUND_MATMUL_CONFIGS = [ + gen_config( + block_m=block_m, + block_n=block_n, + block_k=block_k, + stages=stages, + warps=2 if block_n <= 64 else 4, + split_k=split_k, + ) + for stages, block_m, block_k, block_n, split_k in itertools.product( + IO_BOUND_MATMUL_CONFIGS_STAGES, + IO_BOUND_MATMUL_CONFIGS_BLOCK_M, + IO_BOUND_MATMUL_CONFIGS_BLOCK_K, + IO_BOUND_MATMUL_CONFIGS_BLOCK_N, + IO_BOUND_MATMUL_CONFIGS_SPLIT_K, + ) +] + + +TRITON_CONFIGS = BASIC_MATMUL_CONFIGS + INT8_MATMUL_CONFIGS + IO_BOUND_MATMUL_CONFIGS + + +def our_estimate_matmul_time( + A11, B11, C11, M1, M2, M3, N1, N2, N3, K1, K2, K3, **kwargs +): + """Call into Triton's upstream cost model, with the right args + + The upstream function expects arguments to have certain names. 
Since we + renamed a few of them in our implementation, we rename them back. + + At the time of writing (July 2023) the arguments that Triton expects are: + M, N, K, A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages. + + """ + return estimate_matmul_time( + M=M1 + M2 + M3, N=N1 + N2 + N3, K=K1 + K2 + K3, A=A11, B=B11, C=C11, **kwargs + ) + + +def our_early_config_prune(config, named_args, **kwargs): + new_named_args = named_args.copy() + new_named_args["M"] = named_args["M1"] + named_args["M2"] + named_args["M3"] + new_named_args["N"] = named_args["N1"] + named_args["N2"] + named_args["N3"] + new_named_args["K"] = named_args["K1"] + named_args["K2"] + named_args["K3"] + new_named_args["A"] = named_args["A11"] + new_named_args["B"] = named_args["B11"] + new_named_args["C"] = named_args["C11"] + return early_config_prune(config, new_named_args, **kwargs) + + +@triton.autotune( + configs=TRITON_CONFIGS, + key=["M1", "M2", "M3", "N1", "N2", "N3", "K1", "K2", "K3"], + prune_configs_by={ + "early_config_prune": our_early_config_prune, + "perf_model": our_estimate_matmul_time, + "top_k": 10, + }, +) +@triton.heuristics( + { + "EVEN_K": lambda args: all( + k % (args["BLOCK_K"] * args["SPLIT_K"]) == 0 + for k in [args["K1"], args["K2"], args["K3"]] + ), + } +) +@triton.jit() +def _xformers_tiled_matmul_kernel( + A11, + A12, + A13, + A21, + A22, + A23, + A31, + A32, + A33, + B11, + B12, + B13, + B21, + B22, + B23, + B31, + B32, + B33, + C11, + C12, + C13, + C21, + C22, + C23, + C31, + C32, + C33, + M1, + M2, + M3, + N1, + N2, + N3, + K1, + K2, + K3, + stride_am1, + stride_am2, + stride_am3, + stride_ak1, + stride_ak2, + stride_ak3, + stride_bk1, + stride_bk2, + stride_bk3, + stride_bn1, + stride_bn2, + stride_bn3, + stride_cm1, + stride_cm2, + stride_cm3, + stride_cn1, + stride_cn2, + stride_cn3, + BLOCK_M: tl.constexpr, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL + BLOCK_N: tl.constexpr, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL + BLOCK_K: tl.constexpr, # 
DO NOT CHANGE NAME: MUST MATCH PERF MODEL + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL + EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr, +): + # matrix multiplication + pid = tl.program_id(0) + pid_k = tl.program_id(1) + grid_m1 = tl.cdiv(M1, BLOCK_M) + grid_m2 = tl.cdiv(M2, BLOCK_M) + grid_m3 = tl.cdiv(M3, BLOCK_M) + grid_n1 = tl.cdiv(N1, BLOCK_N) + grid_n2 = tl.cdiv(N2, BLOCK_N) + grid_n3 = tl.cdiv(N3, BLOCK_N) + grid_m = grid_m1 + grid_m2 + grid_m3 + grid_n = grid_n1 + grid_n2 + grid_n3 + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + # We use tl.where to circumvent a regression in alignment auto-detection: + # https://github.com/openai/triton/issues/1784 + + A1 = tl.where(pid_m < grid_m1, A11, tl.where(pid_m < grid_m1 + grid_m2, A21, A31)) + A2 = tl.where(pid_m < grid_m1, A12, tl.where(pid_m < grid_m1 + grid_m2, A22, A32)) + A3 = tl.where(pid_m < grid_m1, A13, tl.where(pid_m < grid_m1 + grid_m2, A23, A33)) + B1 = tl.where(pid_n < grid_n1, B11, tl.where(pid_n < grid_n1 + grid_n2, B12, B13)) + B2 = tl.where(pid_n < grid_n1, B21, tl.where(pid_n < grid_n1 + grid_n2, B22, B23)) + B3 = tl.where(pid_n < grid_n1, B31, tl.where(pid_n < grid_n1 + grid_n2, B32, B33)) + C = tl.where( + pid_m < grid_m1, + tl.where(pid_n < grid_n1, C11, tl.where(pid_n < grid_n1 + grid_n2, C12, C13)), + tl.where( + pid_m < grid_m1 + grid_m2, + tl.where( + pid_n < grid_n1, C21, tl.where(pid_n < grid_n1 + grid_n2, C22, C23) + ), + tl.where( + pid_n < grid_n1, C31, tl.where(pid_n < grid_n1 + grid_n2, C32, C33) + ), + ), + ) + M = tl.where(pid_m < grid_m1, M1, tl.where(pid_m < grid_m1 + grid_m2, M2, M3)) + N = tl.where(pid_n < grid_n1, N1, tl.where(pid_n < grid_n1 + grid_n2, N2, N3)) + stride_ak = tl.where( + pid_m < grid_m1, + stride_ak1, + 
tl.where(pid_m < grid_m1 + grid_m2, stride_ak2, stride_ak3), + ) + stride_bk = tl.where( + pid_n < grid_n1, + stride_bk1, + tl.where(pid_n < grid_n1 + grid_n2, stride_bk2, stride_bk3), + ) + stride_cn = tl.where( + pid_m < grid_m1, + stride_cn1, + tl.where(pid_m < grid_m1 + grid_m2, stride_cn2, stride_cn3), + ) + stride_cm = tl.where( + pid_n < grid_n1, + stride_cm1, + tl.where(pid_n < grid_n1 + grid_n2, stride_cm2, stride_cm3), + ) + pid_m = tl.where( + pid_m < grid_m1, + pid_m, + tl.where(pid_m < grid_m1 + grid_m2, pid_m - grid_m1, pid_m - grid_m1 - grid_m2), + ) + pid_n = tl.where( + pid_n < grid_n1, + pid_n, + tl.where(pid_n < grid_n1 + grid_n2, pid_n - grid_n1, pid_n - grid_n1 - grid_n2), + ) + + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + # pointers + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + grid_k1 = tl.cdiv(K1, BLOCK_K) + grid_k2 = tl.cdiv(K2, BLOCK_K) + grid_k3 = tl.cdiv(K3, BLOCK_K) + for tile in range(pid_k, grid_k1 + grid_k2 + grid_k3, SPLIT_K): + A = tl.where(tile < grid_k1, A1, tl.where(tile < grid_k1 + grid_k2, A2, A3)) + B = tl.where(tile < grid_k1, B1, tl.where(tile < grid_k1 + grid_k2, B2, B3)) + K = tl.where(tile < grid_k1, K1, tl.where(tile < grid_k1 + grid_k2, K2, K3)) + stride_am = tl.where( + tile < grid_k1, + stride_am1, + tl.where(tile < grid_k1 + grid_k2, stride_am2, stride_am3), + ) + stride_bn = tl.where( + tile < grid_k1, + stride_bn1, + tl.where(tile < grid_k1 + grid_k2, stride_bn2, stride_bn3), + ) + my_tile = tl.where( + tile < grid_k1, + tile, + tl.where( + tile < grid_k1 + grid_k2, tile - grid_k1, tile - grid_k1 - grid_k2 + ), + ) + rk = my_tile * BLOCK_K + tl.arange(0, BLOCK_K) + Ain = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + Bin = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + 
if EVEN_K: + a = tl.load(Ain) + b = tl.load(Bin) + else: + a = tl.load(Ain, mask=rk[None, :] < K, other=0.0) + b = tl.load(Bin, mask=rk[:, None] < K, other=0.0) + acc += tl.dot(a, b, allow_tf32=False) + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +def _check_row_or_column(row_or_col_type, row_or_col_idx, tensor_name, dim_name, vals): + assert len(vals) > 0 + for pos, val in enumerate(vals[1:]): + assert val == vals[0], ( + f"the tensors on {row_or_col_type} {row_or_col_idx} of the {tensor_name} " + f"must all have the same stride along the {dim_name} dimension, got " + f"{vals[0]} at position 0 and {val} at position {pos + 1}" + ) + return vals[0] + + +def _get_strides( + ts: List[List[torch.Tensor]], tensor_name, dim_0_name, dim_1_name +) -> Tuple[List[int], List[int]]: + strides_0 = [ + _check_row_or_column( + "column", idx, tensor_name, dim_0_name, [y.stride(0) for y in x] + ) + for idx, x in enumerate(zip(*ts)) + ] + strides_1 = [ + _check_row_or_column( + "row", idx, tensor_name, dim_1_name, [y.stride(1) for y in x] + ) + for idx, x in enumerate(ts) + ] + assert all(s == 1 for s in strides_0) or all(s == 1 for s in strides_1) + while len(strides_0) < 3: + strides_0.append(1 if strides_0[0] == 1 else 0) + while len(strides_1) < 3: + strides_1.append(1 if strides_1[0] == 1 else 0) + return strides_0, strides_1 + + +def _launch_triton_matmul( + a: List[List[torch.Tensor]], + b: List[List[torch.Tensor]], + c: List[List[torch.Tensor]], + ms: List[int], + ns: List[int], + ks: List[int], +) -> None: + strides_am, strides_ak = _get_strides(a, "first operand", "m", "k") + strides_bk, strides_bn 
= _get_strides(b, "second operand", "k", "n") + strides_cm, strides_cn = _get_strides(c, "output", "m", "n") + + # accumulator types + ACC_TYPE = ( + tl.float32 + if c[0][0].dtype in [torch.float16, torch.bfloat16, torch.float32] + else tl.int32 + ) + + # launch kernel + def grid(META): + return ( + sum(triton.cdiv(m, META["BLOCK_M"]) for m in ms) + * sum(triton.cdiv(n, META["BLOCK_N"]) for n in ns), + META["SPLIT_K"], + ) + + _xformers_tiled_matmul_kernel[grid]( + *[ + a[min(i, len(a) - 1)][min(j, len(a[0]) - 1)] + for i in range(3) + for j in range(3) + ], + *[ + b[min(i, len(b) - 1)][min(j, len(b[0]) - 1)] + for i in range(3) + for j in range(3) + ], + *[ + c[min(i, len(c) - 1)][min(j, len(c[0]) - 1)] + for i in range(3) + for j in range(3) + ], + *[ms[i] if len(ms) > i else 0 for i in range(3)], + *[ns[i] if len(ns) > i else 0 for i in range(3)], + *[ks[i] if len(ks) > i else 0 for i in range(3)], + *strides_am, + *strides_ak, + *strides_bk, + *strides_bn, + *strides_cm, + *strides_cn, + ACC_TYPE=ACC_TYPE, + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..630e7b9cad4a8ad46c944ef35450f20a1bf4a460 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py @@ -0,0 +1,893 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Optional, Sequence, Tuple, Type, Union, cast + +import torch + +from . 
import ( + attn_bias, + ck, + ck_decoder, + ck_splitk, + cutlass, + flash, + flash3, + triton_splitk, +) +from .attn_bias import ( + VARLEN_BIASES, + AttentionBias, + BlockDiagonalMask, + LowerTriangularMask, +) +from .common import ( + AttentionBwOpBase, + AttentionFwOpBase, + AttentionOp, + AttentionOpBase, + Context, + Gradients, + Inputs, + bmk2bmhk, +) +from .dispatch import ( + _dispatch_bw, + _dispatch_fw, + _ensure_op_supports_or_raise, + _get_use_fa3, + _set_use_fa3, +) + +MemoryEfficientAttentionCutlassOp = (cutlass.FwOp, cutlass.BwOp) +MemoryEfficientAttentionCutlassFwdFlashBwOp = (cutlass.FwOp, flash.BwOp) +MemoryEfficientAttentionFlashAttentionOp = (flash.FwOp, flash.BwOp) +MemoryEfficientAttentionCkOp = (ck.FwOp, ck.BwOp) +MemoryEfficientAttentionCkDecoderOp = (ck_decoder.FwOp, ck.BwOp) +MemoryEfficientAttentionSplitKCkOp = (ck_splitk.FwOp, ck.BwOp) + + +def _deserialize_bias(attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]) -> Any: + if attn_bias_tensor is None: + return attn_bias_ctx + return attn_bias_tensor + + +# Note: `torch.compile` only allows custom autograd functions +# to accept a subset of types. Therefore we serialize `op` objects +# to `str` before entering the function, and unserialize them inside. 
+# See also: https://github.com/pytorch/pytorch/issues/118395 +_OPS_LOOKUP = { + flash.FwOp.NAME: flash.FwOp, + flash.BwOp.NAME: flash.BwOp, +} + + +def _serialize_op(op): + if op is not None and op.NAME in _OPS_LOOKUP: + return op.NAME + return op + + +def _unserialize_op(op): + if isinstance(op, str): + return _OPS_LOOKUP[op] + return op + + +class _fMHA(torch.autograd.Function): + @staticmethod + # type: ignore + def forward(ctx, op_fw, op_bw, *args: Any) -> Any: + inp = Inputs(*args) + + op_fw = _unserialize_op(op_fw) + op_bw = _unserialize_op(op_bw) + + out, op_ctx = _memory_efficient_attention_forward_requires_grad( + inp=inp, op=op_fw + ) + + # Saving attn_bias is a bit complicated, as the + # torch part should go in `save_for_backward` + if isinstance(inp.attn_bias, torch.Tensor): + attn_bias_tensor = inp.attn_bias + attn_bias_ctx = None + else: + attn_bias_tensor = None + attn_bias_ctx = inp.attn_bias + + ctx.save_for_backward( + inp.query, + inp.key, + inp.value, + op_ctx.out, + op_ctx.lse, + ) + ctx.rng_state = op_ctx.rng_state + ctx.attn_bias_tensor = attn_bias_tensor + if op_ctx.op_bw is not None: + if op_bw is not None and op_bw is not op_ctx.op_bw: + raise ValueError( + f"Specified op_bw={op_bw.NAME}, but forward op " + f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None." + ) + op_bw = op_ctx.op_bw + if ( + op_bw is not None + and isinstance(inp.attn_bias, VARLEN_BIASES) + and inp.attn_bias.q_seqinfo.seqstart.shape[0] > 2 + and op_bw.VARLEN_LSE_PACKED != op_fw.VARLEN_LSE_PACKED + ): + raise ValueError( + f"Specified op_bw={op_bw.NAME} is not compatible with the " + f"op_fw={op_fw.NAME}, because they use different format of logsumexp. 
" + f"NOTE: This is new with xFormers 0.0.28" + ) + if op_bw is None and ( + inp.query.requires_grad or inp.key.requires_grad or inp.value.requires_grad + ): + varlen_lse_packed = _detect_lse_packed_or_raise(op_ctx.lse, inp) + if varlen_lse_packed is not None and op_fw is not None: + assert ( + op_fw.VARLEN_LSE_PACKED == varlen_lse_packed + ), f"{op_fw.NAME}: wrong value for `VARLEN_LSE_PACKED` ?" + # NOTE: We need to check tensor strides to decide which operator we run in the BW pass. + # Unfortunately, PyTorch only allows to call this function during the FW pass, so + # we decide the operator to use now. + op_bw = _dispatch_bw(inp, varlen_lse_packed=varlen_lse_packed) + ctx.op_fw = op_fw + ctx.op_bw = op_bw + ctx.p = inp.p + # This allows to create gradients from a single storage, + # to avoid a "cat" in the BW pass. + # The heuristic is approximative, but: + # (1) It's not a big issue to create a shared storage + # (2) The heuristic needs to pass `torch.compile` + # (this is also why we run it in the FW pass, the BW pass is stricter) + ctx.qkv_share_storage = ( + inp.query.shape[0] == inp.key.shape[0] + and inp.query.shape[-1] == inp.value.shape[-1] + and inp.query.stride(-2) + == (inp.key.shape[-1] + inp.query.shape[-1] + inp.value.shape[-1]) + ) + + ctx.scale = inp.scale + ctx.attn_bias_ctx = attn_bias_ctx + ctx.n_args = len(args) + return out, op_ctx.lse + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad, grad_lse): + # Re-create context + query, key, value, out, lse = ctx.saved_tensors + attn_bias_tensor = ctx.attn_bias_tensor + rng_state = ctx.rng_state + inp = Inputs( + query=query, + key=key, + value=value, + attn_bias=_deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor), + p=ctx.p, + scale=ctx.scale, + ) + op_ctx = Context( + lse=lse, + out=out, + rng_state=rng_state, + ) + grads = _memory_efficient_attention_backward( + ctx=op_ctx, + inp=inp, + grad=grad, + op=ctx.op_bw, + _skip_op_checks=True, + ) + return 
(None, None, grads.dq, grads.dk, grads.dv, grads.db) + (None,) * ( + ctx.n_args - 2 + ) + + +def memory_efficient_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None, + p: float = 0.0, + scale: Optional[float] = None, + *, + op: Optional[AttentionOp] = None, + output_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """Implements the memory-efficient attention mechanism following + `"Self-Attention Does Not Need O(n^2) Memory" `_. + + :Inputs shape: + + - Input tensors must be in format ``[B, M, H, K]``, where B is the batch size, M \ + the sequence length, H the number of heads, and K the embeding size per head + + - If inputs have dimension 3, it is assumed that the dimensions are ``[B, M, K]`` and ``H=1`` + + - Inputs can also be of dimension 5 with GQA - see note below + + - Inputs can be non-contiguous - we only require the last dimension's stride to be 1 + + + :Equivalent pytorch code: + + .. code-block:: python + + scale = 1.0 / query.shape[-1] ** 0.5 + query = query * scale + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + attn = query @ key.transpose(-2, -1) + if attn_bias is not None: + attn = attn + attn_bias + attn = attn.softmax(-1) + attn = F.dropout(attn, p) + attn = attn @ value + return attn.transpose(1, 2) + + :Examples: + + .. code-block:: python + + import xformers.ops as xops + + # Compute regular attention + y = xops.memory_efficient_attention(q, k, v) + + # With a dropout of 0.2 + y = xops.memory_efficient_attention(q, k, v, p=0.2) + + # Causal attention + y = xops.memory_efficient_attention( + q, k, v, + attn_bias=xops.LowerTriangularMask() + ) + + :Supported hardware: + + NVIDIA GPUs with compute capability above 6.0 (P100+), datatype ``f16``, ``bf16`` and ``f32``. 
+ + :EXPERIMENTAL: Using with Multi Query Attention (MQA) and Grouped Query Attention (GQA): + + MQA/GQA is an experimental feature supported only for the forward pass. + If you have 16 heads in query, and 2 in key/value, you can provide 5-dim tensors + in the ``[B, M, G, H, K]`` format, where ``G`` is the number of head groups (here 2), and + ``H`` is the number of heads per group (8 in the example). + + Please note that xFormers will not automatically broadcast the inputs, so you will need + to broadcast it manually before calling `memory_efficient_attention`. + + :GQA/MQA example: + + .. code-block:: python + + import torch + import xformers.ops as xops + + B, M, K = 3, 32, 128 + kwargs = dict(device="cuda", dtype=torch.float16) + q = torch.randn([B, M, 8, K], **kwargs) + k = torch.randn([B, M, 2, K], **kwargs) + v = torch.randn([B, M, 2, K], **kwargs) + out_gqa = xops.memory_efficient_attention( + q.reshape([B, M, 2, 4, K]), + k.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]), + v.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]), + ) + + Raises: + NotImplementedError: if there is no operator available to compute the MHA + ValueError: if inputs are invalid + + :parameter query: Tensor of shape ``[B, Mq, H, K]`` + :parameter key: Tensor of shape ``[B, Mkv, H, K]`` + :parameter value: Tensor of shape ``[B, Mkv, H, Kv]`` + :parameter attn_bias: Bias to apply to the attention matrix - defaults to no masking. \ + For common biases implemented efficiently in xFormers, see :attr:`xformers.ops.fmha.attn_bias.AttentionBias`. \ + This can also be a :attr:`torch.Tensor` for an arbitrary mask (slower). + :parameter p: Dropout probability. Disabled if set to ``0.0`` + :parameter scale: Scaling factor for ``Q @ K.transpose()``. If set to ``None``, the default \ + scale (q.shape[-1]**-0.5) will be used. + :parameter op: The operators to use - see :attr:`xformers.ops.AttentionOpBase`. 
\ + If set to ``None`` (recommended), xFormers \ + will dispatch to the best available operator, depending on the inputs \ + and options. + :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]`` + """ + return _memory_efficient_attention( + Inputs( + query=query, + key=key, + value=value, + p=p, + attn_bias=attn_bias, + scale=scale, + output_dtype=output_dtype, + ), + op=op, + ) + + +torch.library.define( + "xformer::memory_efficient_attention_forward", + "(Tensor q, Tensor k, Tensor v, Tensor? b = None, float? p = 0.0, float? scale = None) -> Tensor", +) + + +@torch.library.impl("xformer::memory_efficient_attention_forward", "Meta") +def memory_efficient_attention_forward_meta(q, k, v): + return q.new_empty(q.shape) + + +# torch.compile has issue when tracing through op dispatch and ensure_op_support +# so provide a wrapper to register it as a custom torch library op. +@torch.library.impl("xformer::memory_efficient_attention_forward", "CUDA") +def memory_efficient_attention_forward_torch_wrapper( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None, + p: float = 0.0, + scale: Optional[float] = None, +) -> torch.Tensor: + """ + This provides a torch-compilable wrapper op to + memory_efficient_attention_forward in certain special cases. + + Note that the following are not supported + - `op` input (?) + - certain attn_bias types (?) 
+ - output_dtype + - K != Kv + """ + return memory_efficient_attention_forward( + query, + key, + value, + attn_bias, + p, + scale, + ) + + +def memory_efficient_attention_forward( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None, + p: float = 0.0, + scale: Optional[float] = None, + *, + op: Optional[Type[AttentionFwOpBase]] = None, + output_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """ + Calculates the forward pass of :attr:`xformers.ops.memory_efficient_attention`. + """ + return _memory_efficient_attention_forward( + Inputs( + query=query, + key=key, + value=value, + p=p, + attn_bias=attn_bias, + scale=scale, + output_dtype=output_dtype, + ), + op=op, + ) + + +def memory_efficient_attention_forward_requires_grad( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None, + p: float = 0.0, + scale: Optional[float] = None, + *, + op: Optional[Type[AttentionFwOpBase]] = None, + output_dtype: Optional[torch.dtype] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Returns a tuple (output, lse), where `lse` can be used to compute the backward pass later. + See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments + See :attr:`xformers.ops.memory_efficient_attention_backward` for running the backward pass + """ + if p != 0.0: + raise NotImplementedError( + "dropout is not supported on the non-autograd API." 
+ " If you want to use dropout, please call `memory_efficient_attention` directly" + ) + out, ctx = _memory_efficient_attention_forward_requires_grad( + Inputs( + query=query, + key=key, + value=value, + p=p, + attn_bias=attn_bias, + scale=scale, + output_dtype=output_dtype, + ), + op=op, + ) + return out, ctx.lse + + +def memory_efficient_attention_backward( + grad: torch.Tensor, + output: torch.Tensor, + lse: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None, + p: float = 0.0, + scale: Optional[float] = None, + *, + op: Optional[Type[AttentionBwOpBase]] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes the gradient of the attention. + Returns a tuple (dq, dk, dv) + See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments. + `lse` is the tensor returned by + :attr:`xformers.ops.memory_efficient_attention_forward_requires_grad` + """ + if p != 0.0: + raise NotImplementedError( + "dropout is not supported on the non-autograd API." 
+ " If you want to use dropout, please call `memory_efficient_attention` directly" + ) + gradients = _memory_efficient_attention_backward( + Context(out=output, lse=lse), + Inputs( + query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale + ), + grad, + op=op, + ) + return (gradients.dq, gradients.dk, gradients.dv) + + +def _memory_efficient_attention( + inp: Inputs, op: Optional[AttentionOp] = None +) -> torch.Tensor: + # fast-path that doesn't require computing the logsumexp for backward computation + if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]): + return _memory_efficient_attention_forward( + inp, op=op[0] if op is not None else None + ) + + output_shape = inp.normalize_bmhk() + + op_fw = _serialize_op(op[0] if op is not None else None) + op_bw = _serialize_op(op[1] if op is not None else None) + return _fMHA.apply( + op_fw, op_bw, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale + )[0].reshape(output_shape) + + +def _memory_efficient_attention_forward( + inp: Inputs, op: Optional[Type[AttentionFwOpBase]] +) -> torch.Tensor: + inp.validate_inputs() + output_shape = inp.normalize_bmhk() + if op is None: + op = _dispatch_fw(inp, False) + else: + _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp) + + out, *_ = op.apply(inp, needs_gradient=False) + return out.reshape(output_shape) + + +def _memory_efficient_attention_forward_requires_grad( + inp: Inputs, op: Optional[Type[AttentionFwOpBase]] +) -> Tuple[torch.Tensor, Context]: + inp.validate_inputs() + output_shape = inp.normalize_bmhk() + if op is None: + op = _dispatch_fw(inp, True) + else: + _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp) + out = op.apply(inp, needs_gradient=True) + assert out[1] is not None + return (out[0].reshape(output_shape), out[1]) + + +def _detect_lse_packed_or_raise(lse: torch.Tensor, inp: Inputs) -> Optional[bool]: + """ + Detects the LSE format if we're in a varlen 
case. + Returns `None` if the format is not relevant (eg not varlen) + Raises an exception if the `lse` has the wrong shape + """ + shape_mismatch_err = ( + "Input tensors have incompatible shapes.\n" + f" lse.shape : {lse.shape}\n" + f" query.shape : {inp.query.shape}\n" + f" attn_bias : {type(inp.attn_bias)}" + ) + # 1. Check ndim & head dimensions + # In any case, LSE should be [*, *GH] + if lse.ndim != (inp.query.ndim - 1) or lse.shape[1:-1] != inp.query.shape[2:-1]: + raise ValueError(shape_mismatch_err) + lse_bm = [lse.shape[0], lse.shape[-1]] + lse_packed_shape = [inp.query.shape[0], inp.query.shape[1]] + lse_packed = lse_bm[0] == lse_packed_shape[0] and lse_bm >= lse_packed_shape + # 2. Check correctness for varlen biases with query.shape = [1, M, *GH, K] + # Either [1, *GH, M] (packed) + # Or [num_seq, *GH, Mq] .. with `Mq >= max_q` (padded) + if isinstance(inp.attn_bias, VARLEN_BIASES): + si = inp.attn_bias.q_seqinfo + lse_padded_shape = [si.seqstart.shape[0] - 1, si.max_seqlen] + lse_padded = lse_bm[0] == lse_padded_shape[0] and lse_bm >= lse_padded_shape + if lse_packed and lse_padded: + return None + elif lse_packed: + return True + elif lse_padded: + return False + raise ValueError(shape_mismatch_err) + # 3. For non-varlen, shape must be [B, *GH] with query.shape=[B, M, *GH, K] + if not lse_packed: + raise ValueError(shape_mismatch_err) + return None + + +def _memory_efficient_attention_backward( + ctx: Context, + inp: Inputs, + grad: torch.Tensor, + op: Optional[Type[AttentionBwOpBase]], + *, + _skip_op_checks: bool = False, +) -> Gradients: + """Warning: grad/ctx.out is potentially in BMK format""" + inp.validate_inputs() + if grad.ndim != inp.query.ndim or grad.ndim != ctx.out.ndim: + raise ValueError( + "All tensors should be either in BMK (ndim=3) or BMHK (ndim=4) format. 
\n" + f"grad.shape : {grad.shape} \n" + f"out.shape : {ctx.out.shape} \n" + f"query.shape: {inp.query.shape}" + ) + shape_dq, shape_dk, shape_dv = tuple( + x.shape for x in (inp.query, inp.key, inp.value) + ) + inp.normalize_bmhk() + varlen_lse_packed = _detect_lse_packed_or_raise(ctx.lse, inp) + grad = bmk2bmhk(grad, 1) + ctx.out = bmk2bmhk(ctx.out, 1) + + if op is None: + op = _dispatch_bw(inp, varlen_lse_packed=varlen_lse_packed) + elif not _skip_op_checks: + _ensure_op_supports_or_raise( + ValueError, "memory_efficient_attention_backward", op, inp + ) + if varlen_lse_packed is not None and varlen_lse_packed != op.VARLEN_LSE_PACKED: + raise ValueError( + f"Wrong LSE format for {op.NAME} in variable seqlen case. " + f"Double-check that the BW operator {op.NAME} is compatible " + f"with the operator used in the FW pass." + ) + + grads = op.apply(ctx, inp, grad) + grads.dq = grads.dq.reshape(shape_dq) + grads.dk = grads.dk.reshape(shape_dk) + grads.dv = grads.dv.reshape(shape_dv) + return grads + + +def memory_efficient_attention_partial( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None, + p: float = 0.0, + scale: Optional[float] = None, + *, + op: Optional[Union[AttentionOp, Type[AttentionFwOpBase]]] = None, + output_dtype: Optional[torch.dtype] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Returns a tuple (output, lse), where `output` is the attention in the style of + memory_efficient_attention, and `lse` is extra data, a log-sum-exp. + The outputs of calls to this with the same query and separate keys and values + can be merged with merge_attentions to obtain the attention of the queries + against the disjoint union of the keys and values. + + Warning: The backward pass of this function is quite restricted. 
In particular + we assume that in the forward pass the outputs were only used in merge_attention + calculations, and that LSEs weren't used anywhere except in merge attentions. + """ + if p != 0.0: + raise NotImplementedError("dropout is not supported.") + fwop: Optional[Type[AttentionFwOpBase]] = op[0] if isinstance(op, tuple) else op + inp = Inputs( + query=query, + key=key, + value=value, + p=p, + attn_bias=attn_bias, + scale=scale, + output_dtype=output_dtype, + is_partial=True, + ) + + is_grad = torch.is_grad_enabled() and any( + x.requires_grad for x in [query, key, value] + ) + + if not is_grad: + out, ctx = _memory_efficient_attention_forward_requires_grad( + inp, + op=fwop, + ) + return out, ctx.lse + + if query.ndim == 5: + raise ValueError("gradients not supported for 5D tensors") + if isinstance(op, tuple): + op_fw = _serialize_op(op[0]) + op_bw = _serialize_op(op[1]) + elif op is None: + op_fw = op_bw = None + else: + op_fw = _serialize_op(op) + op_bw = None + return _fMHA.apply( + op_fw, + op_bw, + inp.query, + inp.key, + inp.value, + inp.attn_bias, + inp.p, + inp.scale, + inp.output_dtype, + inp.is_partial, + ) + + +def merge_attentions( + attn_split: Union[torch.Tensor, Sequence[torch.Tensor]], + lse_split: Union[torch.Tensor, Sequence[torch.Tensor]], + write_lse: bool = True, + output_dtype: Optional[torch.dtype] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Combine attention output computed on different parts of K/V for the same + query to get attention on the whole K/V. See https://arxiv.org/abs/2402.05099 + The result is equal to + Out_full = (Out1 * exp(LSE1) + Out2 * exp(LSE2) + ...) / (exp(LSE1) + exp(LSE2) + ...) + LSE_full = log(exp(LSE1) + exp(LSE2) + ...) 
+ + Args: + attn_split: attention outputs for chunks, + either as a list of tensors of shapes [B, M, G, H, Kq] or [B, M, H, Kq] + or as a single tensor of shape [num_chunks, B, M, G, H, Kq] + or [num_chunks, B, M, H, Kq] + lse_split: LSE for chunks, + either as a list of tensors of shapes [B, G, H, M] or [B, H, M] + or as a single tensor of shape [num_chunks, B, G, H, M] or [num_chunks, B, H, M] + write_lse: whether to output LSE + output_dtype: dtype of attn_out + + Returns: + attn_out: [B, M, G, H, Kq] or [B, M, H, Kq] + lse_out: [B, G, H, M] or [B, H, M] if write_lse + or None otherwise + """ + + attn_is_concat = isinstance(attn_split, torch.Tensor) + lse_is_concat = isinstance(lse_split, torch.Tensor) + + attn_requires_grad = ( + attn_split.requires_grad # type: ignore + if attn_is_concat + else any(x.requires_grad for x in attn_split) + ) + lse_requires_grad = ( + lse_split.requires_grad # type: ignore + if lse_is_concat + else any(x.requires_grad for x in lse_split) + ) + requires_grad = torch.is_grad_enabled() and ( + attn_requires_grad or lse_requires_grad + ) + if requires_grad and not write_lse: + raise ValueError("write_lse should be true if inputs require gradients.") + + concat_path = attn_is_concat and lse_is_concat and not requires_grad + if concat_path: + attn_split = cast(torch.Tensor, attn_split) + lse_split = cast(torch.Tensor, lse_split) + if attn_split.ndim != lse_split.ndim + 1: + raise ValueError( + f"Incompatible input shapes: {attn_split.shape=}, {lse_split.shape=}" + ) + + is_bmhk = attn_split.ndim == 5 + if is_bmhk: + attn_split = attn_split.unsqueeze(3) + lse_split = lse_split.unsqueeze(2) + + num_chunks, B, M, G, H, Kq = attn_split.shape + num_chunks1, B1, G1, H1, M1 = lse_split.shape + if B != B1 or G != G1 or H != H1 or num_chunks != num_chunks1 or M != M: + raise ValueError( + f"Incompatible input shapes: {attn_split.shape=} {lse_split.shape=} " + f"{B}/{B1}, {G}/{G1}, {H}/{H1}, {num_chunks}/{num_chunks1}, {M}/{M}" + ) + + attn_split 
= attn_split.permute(1, 3, 4, 0, 2, 5) + lse_split = lse_split.permute(1, 2, 3, 0, 4) + + device = attn_split.device + attn_dtype = attn_split.dtype + lse_dtype = lse_split.dtype + else: + if attn_is_concat: + attn_split = attn_split.unbind(0) # type: ignore + if lse_is_concat: + lse_split = lse_split.unbind(0) # type: ignore + num_chunks = len(attn_split) + if len(lse_split) != num_chunks: + raise ValueError( + f"Incompatible number of LSE and attention chunks: {len(attn_split)=}, {len(lse_split)=}" + ) + + attn_unsqueezed = [] + lse_unsqueezed = [] + is_bmhk = False + for i in range(num_chunks): + if attn_split[i].ndim != lse_split[i].ndim + 1: + raise ValueError( + f"Incompatible input shapes for chunk {i}: {attn_split[i].shape=}, {lse_split[i].shape=}" + ) + + is_bmhk = attn_split[i].ndim == 4 + if is_bmhk: + attn_unsqueezed.append(attn_split[i].unsqueeze(2)) + lse_unsqueezed.append(lse_split[i].unsqueeze(1)) + else: + attn_unsqueezed.append(attn_split[i]) + lse_unsqueezed.append(lse_split[i]) + attn_split, lse_split = attn_unsqueezed, lse_unsqueezed + + B, M, G, H, Kq = attn_split[0].shape + B1, G1, H1, M1 = lse_split[0].shape + if B != B1 or G != G1 or H != H1 or M != M: + raise ValueError( + f"Incompatible input shapes: {attn_split[0].shape=}, {lse_split[0].shape=} " + f"{B}/{B1}, {G}/{G1}, {H}/{H1}, {M}/{M}" + ) + + for i in range(num_chunks): + if attn_split[i].shape != (B, M, G, H, Kq): + raise ValueError( + f"Incompatible input shapes for attention chunk {i}: " + f"{attn_split[i].shape=}, {(B, M, G, H, Kq)=}" + ) + if lse_split[i].shape != (B, G, H, M): + raise ValueError( + f"Incompatible input shapes for LSE chunk {i}: " + f"{lse_split[i].shape=}, {(B, G, H, M)=}" + ) + + attn_split[i] = attn_split[i].permute(0, 2, 3, 1, 4) # to (B, G, H, M, Kq) + + device = attn_split[0].device + attn_dtype = attn_split[0].dtype + lse_dtype = lse_split[0].dtype + + attn_out = torch.empty( + B, + M, + G, + H, + Kq, + device=device, + dtype=output_dtype or attn_dtype, + 
requires_grad=requires_grad, + ) + if write_lse: + lse_out = torch.empty( + B, G, H, M, device=device, dtype=lse_dtype, requires_grad=requires_grad + ) + else: + lse_out = None + + if concat_path: + triton_splitk.merge_attentions(attn_out, lse_out, attn_split, lse_split) # type: ignore + else: + attn_out, lse_out = _MergeAttentions.apply(attn_out, lse_out, *attn_split, *lse_split) # type: ignore + + if is_bmhk: + attn_out = attn_out[:, :, 0] + if lse_out is not None: + lse_out = lse_out[:, 0] + + return attn_out, lse_out + + +class _MergeAttentions(torch.autograd.Function): + @staticmethod + # type: ignore + def forward( + ctx, attn_out: torch.Tensor, lse_out: torch.Tensor, *inputs: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + num_chunks = len(inputs) // 2 + attn_split, lse_split = inputs[:num_chunks], inputs[num_chunks:] + + triton_splitk.merge_attentions_varargs(attn_out, lse_out, attn_split, lse_split) + + ctx.save_for_backward( + attn_out, + lse_out, + *inputs, + ) + return attn_out, lse_out + + @staticmethod + # type: ignore + def backward( + ctx, grad_attn: torch.Tensor, grad_lse: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], ...]: + out, lse, *inputs = ctx.saved_tensors + num_chunks = len(inputs) // 2 + attn_split, lse_split = inputs[:num_chunks], inputs[num_chunks:] + dattn, dlse = triton_splitk.merge_attentions_varargs_backward( + attn_split, + lse_split, + out, + lse, + grad_attn, + grad_lse, + ) + ret = [None, None] + dattn + dlse + return tuple(ret) + + +ALL_FW_OPS: List[Type[AttentionFwOpBase]] = [ + cutlass.FwOp if torch.version.cuda else ck.FwOp, + flash.FwOp, + flash3.FwOp, + triton_splitk.FwOp, +] + +ALL_BW_OPS: List[Type[AttentionBwOpBase]] = [ + cutlass.BwOp if torch.version.cuda else ck.BwOp, + flash.BwOp, + flash3.BwOp, +] + +__all__ = [ + "AttentionBias", + "AttentionOp", + "AttentionOpBase", + "LowerTriangularMask", + "MemoryEfficientAttentionCutlassFwdFlashBwOp", + "MemoryEfficientAttentionCutlassOp", + 
"MemoryEfficientAttentionFlashAttentionOp", + "memory_efficient_attention", + "MemoryEfficientAttentionCkOp", + "MemoryEfficientAttentionCkDecoderOp", + "ALL_FW_OPS", + "ALL_BW_OPS", + "attn_bias", + "_get_use_fa3", + "_set_use_fa3", + "BlockDiagonalMask", +] diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c2fdcfb4b21b902e65d4159a6654b0857c69ee8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/attn_bias.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/attn_bias.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0aa078c78be677be4c990efbf27c65c0d946ed64 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/attn_bias.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/ck.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/ck.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fabab9f9ccb34ffe8b8870201723eebdcddd72b2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/ck.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/ck_decoder.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/ck_decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4467f9ce0114bbbf15945a5a9c103af041d16c51 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/ck_decoder.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/ck_splitk.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/ck_splitk.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f15d38a9f8d2e3d1970fa019f2590cd1ef8f7273 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/ck_splitk.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/common.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0624b214440bdc97576bbb1cdaf357c29a45ab47 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/common.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/cutlass.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/cutlass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57aa14e5d140f973547004dece435d2053d5b9cf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/cutlass.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/dispatch.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/dispatch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6282f9034ccf29658ae8ff269eedc02e71bf7f9d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/dispatch.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/flash.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/flash.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..75f41bde952369ffccb70ae6d9aa61baeb49a7ab Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/flash.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/flash3.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/flash3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87fa567d999c968476f417ed9b0566a653cba637 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/flash3.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/torch_attention_compat.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/torch_attention_compat.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95ead30a3964261ea5c51a732f0df5df211f4665 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/torch_attention_compat.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/triton_splitk.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/triton_splitk.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da39b0d9c671e88c4a9a528043eaa5bb12cba639 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/triton_splitk.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/_triton/__init__.py b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/_triton/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6677db08c0dcd87f7570a3249f61cc3eabe0532a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/_triton/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/_triton/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/_triton/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fbcb84ef4198614ab1c71f55e6acdb6cacb1641 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/_triton/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/_triton/__pycache__/splitk_kernels.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/_triton/__pycache__/splitk_kernels.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91b2d55e52fcd3a44004c21dcb8705575b56b0a1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/_triton/__pycache__/splitk_kernels.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/_triton/splitk_kernels.py b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/_triton/splitk_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..6752cd126dcf4f93933ebb157863f8000e81bf91 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/_triton/splitk_kernels.py @@ -0,0 +1,1127 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import functools +import sys +from typing import Callable, Dict, Tuple, Union + +import torch +import triton +import triton.language as tl + +from xformers.triton.vararg_kernel import VAR_ARGS_ARRAY, unroll_varargs + +AUTOTUNER_KEY = [ + "Z", + "H", + "G", + "N_CTX_Q", + "N_CTX_K", + "BLOCK_DMODEL", + "PACKED_PER_VAL", + "N_GROUPS", + "BLOCK_N_PER_SPLIT", +] + + +@triton.jit +def _fwd_kernel_splitK( + Q, + K, + V, + sm_scale, + Out_splitK, # [B, H, split_k, Mq, K] + LSE_splitk, # [B, H, split_k, Mq] + block_tables, + Seq_len, + Seq_starts_k, + Seq_starts_q, + Seq_starts_q_multiplier, + additive_bias, + K_fp8_scale_shift, + V_fp8_scale_shift, + stride_qz, + stride_qm, + stride_qg, + stride_qh, + stride_qk, + stride_kz, + stride_kn, + stride_kg, + stride_kh, + stride_kk, + stride_vz, + stride_vn, + stride_vg, + stride_vh, + stride_vk, + stride_osk_z, + stride_osk_g, + stride_osk_h, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_lsek_z, + stride_lsek_g, + stride_lsek_h, + stride_lsek_s, + stride_lsek_m, + stride_blocktablesz, + stride_blocktablesl, + stride_bias_b, + stride_bias_g, + stride_bias_h, + stride_bias_qm, + stride_bias_km, + stride_k_fp8_scale_shift_z: tl.constexpr, + stride_k_fp8_scale_shift_n: tl.constexpr, + stride_k_fp8_scale_shift_g: tl.constexpr, + stride_k_fp8_scale_shift_h: tl.constexpr, + stride_v_fp8_scale_shift_z: tl.constexpr, + stride_v_fp8_scale_shift_n: tl.constexpr, + stride_v_fp8_scale_shift_g: tl.constexpr, + stride_v_fp8_scale_shift_h: tl.constexpr, + kv_cache_blocks_per_row: tl.constexpr, + Z: tl.constexpr, + N_CTX_Q: tl.constexpr, # The number of queries + N_CTX_K: tl.constexpr, + BLOCK_N_PER_SPLIT: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + USE_SEQ_LEN: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, + N_GROUPS: tl.constexpr, + # It's important that BOUNDS_CHECKS_N, BLOCK_M, BLOCK_N come at the end of + # the argument list, since they are provided by the heuristics/autotune decorator. 
+ # Otherwise Triton throws IndexError + BOUNDS_CHECKS_N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_SPLITK: tl.constexpr, + SPLIT_K_EARLY_EXIT: tl.constexpr, + IS_CAUSAL: tl.constexpr, + NUM_QUERIES_CAUSAL: tl.constexpr, # The N_CTX_Q queries are from this many sequence positions + USE_PAGED_ATTENTION: tl.constexpr, + PAGE_SIZE: tl.constexpr, + WRITE_LSE: tl.constexpr, + HAS_ADDITIVE_BIAS: tl.constexpr, +): + """This kernel can accept non-quantized or int4-quantized keys/values. + PACKED_PER_VAL determines the quantization type: + - PACKED_PER_VAL == 1 means no quantization + - PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32) + For the quantized case K/V should be int32 tensors. + Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8. + Quantization coefficients are stored at the beginning of the row along the last dimension of K/V + So K[B, H, M, :] has a form + [ quant_coef0, quant_coef1, ...| + group0_quant_value0, group0_quant_value1,... | + group1_quant_value0, group1_quant_value1,...] + where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset. + + Note: this kernel needs to be processed by xformers.triton.vararg_kernel.unroll_varargs + before compilation. That will unroll variables marked with "VAR_ARGS_ARRAY" into lists. + See how FwOp.apply does it below. + + Set IS_SPLITK=False to indicate the MHA result should be written directly. + No metadata will be written. 
+ """ + internal_dtype = ( + tl.float64 if Out_splitK.dtype.element_ty is tl.float64 else tl.float32 + ) + tl.static_assert( + (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32)) + or ( + (PACKED_PER_VAL == 4 or PACKED_PER_VAL == 8) + and tl.constexpr(K.dtype.element_ty == tl.int32) + ), + f"Only int4 and fp8 quantization is supported, K/V should have dtype int32 in " + f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}", + ) + tl.static_assert( + (((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8), + "Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.", + ) + tl.static_assert( + N_GROUPS == 1 or K_fp8_scale_shift is None, + f"Only row-wise fp8 quantization is supported, but got {N_GROUPS=} > 1.", + ) + FP8_QUANTIZED: tl.constexpr = K_fp8_scale_shift is not None + INT4_QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1 and not FP8_QUANTIZED + PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS + D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS + + start_m = tl.program_id(0) + off_zhg = tl.program_id(1) + off_z = off_zhg // (H * G) + off_hg = off_zhg % (H * G) + off_h = off_hg // G + off_g = off_hg % G + splitk_idx = tl.program_id(2) + + if USE_SEQ_LEN: + kv_len = tl.load(Seq_len + off_z) + if SPLIT_K_EARLY_EXIT and kv_len == 0: + return + else: + kv_len = N_CTX_K + + if Seq_starts_k is None: + start_kv_idx = 0 + else: + start_kv_idx = tl.load(Seq_starts_k + off_z) + + if Seq_starts_q is None: + q_len = N_CTX_Q + queries_use_batch_dim = 1 + off_m = 0 + else: + queries_use_batch_dim = 0 + off_m = tl.load(Seq_starts_q + off_z) * Seq_starts_q_multiplier + q_len = tl.load(Seq_starts_q + off_z + 1) * Seq_starts_q_multiplier - off_m + if q_len == 0: + return + + k_base = K + off_h * stride_kh + off_g * stride_kg + v_base = V + off_h * stride_vh + off_g * stride_vg + + if FP8_QUANTIZED: + k_fp8_scale_shift_base = ( + K_fp8_scale_shift + + 
off_h * stride_k_fp8_scale_shift_h + + off_g * stride_k_fp8_scale_shift_g + ) + v_fp8_scale_shift_base = ( + V_fp8_scale_shift + + off_h * stride_v_fp8_scale_shift_h + + off_g * stride_v_fp8_scale_shift_g + ) + else: + k_fp8_scale_shift_base = None + v_fp8_scale_shift_base = None + + # Boundaries of split-k chunk + chunk_hi = (splitk_idx + 1) * BLOCK_N_PER_SPLIT + chunk_lo = splitk_idx * BLOCK_N_PER_SPLIT + ignore_in_first_block = 0 + # For paged attention case K/V_block_ptr are defined inside the loop + # whereas for non-paged case they are defined before the loop. + if PAGE_SIZE > 0: + # Page contains several blocks + BLOCKS_IN_PAGE: tl.constexpr = PAGE_SIZE // BLOCK_N + # Align boundaries of split-k chunk to block boundaries + # In the last chunk, shift hi to the right, in the other chunks, shift it to the left + is_last_chunk = splitk_idx == tl.num_programs(2) - 1 + shift = BLOCK_N - 1 if is_last_chunk else 0 + lo = (tl.maximum(chunk_lo, start_kv_idx) // BLOCK_N) * BLOCK_N + ignore_in_first_block = tl.maximum(0, (start_kv_idx - lo)) + hi = ((chunk_hi + shift) // BLOCK_N) * BLOCK_N + hi = tl.minimum(hi, kv_len + start_kv_idx) + block_table = block_tables + stride_blocktablesz * off_z + # Offset in integer blocks + logical_block_idx = lo // BLOCK_N + else: + lo = chunk_lo + hi = tl.minimum(chunk_hi, kv_len) + if Seq_starts_k is not None: + k_base += start_kv_idx * stride_kn + v_base += start_kv_idx * stride_vn + else: + k_base += off_z * stride_kz + v_base += off_z * stride_vz + # Additional shift by 1 along the last dimension in the quantized case, since + # the first element along that dim contains packed quantization coefficients. 
+ K_block_ptr = tl.make_block_ptr( + base=k_base + stride_kk * INT4_QUANTIZED * N_GROUPS, + shape=(PACKED_D_PER_GROUP, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(PACKED_D_PER_GROUP, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=v_base + stride_vk * INT4_QUANTIZED * N_GROUPS, + shape=(hi, PACKED_D_PER_GROUP), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, PACKED_D_PER_GROUP), + order=(1, 0), + ) + + if INT4_QUANTIZED: + # Pointers to quantization coefficients. Even those they are 1D, + # we have to use block pointers, since usual pointers + # don't support boundary checks + K_scale_shift_block_ptr = tl.make_block_ptr( + base=k_base, + shape=(1, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(1, BLOCK_N), + order=(0, 1), + ) + V_scale_shift_block_ptr = tl.make_block_ptr( + base=v_base, + shape=(hi, 1), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, 1), + order=(1, 0), + ) + elif FP8_QUANTIZED: + if Seq_starts_k is not None: + k_fp8_scale_shift_base += start_kv_idx * stride_k_fp8_scale_shift_n + v_fp8_scale_shift_base += start_kv_idx * stride_v_fp8_scale_shift_n + else: + k_fp8_scale_shift_base += off_z * stride_k_fp8_scale_shift_z + v_fp8_scale_shift_base += off_z * stride_v_fp8_scale_shift_z + K_scale_shift_block_ptr = tl.make_block_ptr( + base=k_fp8_scale_shift_base, + shape=(1, hi), + strides=(1, stride_k_fp8_scale_shift_n), + offsets=(0, lo), + block_shape=(1, BLOCK_N), + order=(0, 1), + ) + V_scale_shift_block_ptr = tl.make_block_ptr( + base=v_fp8_scale_shift_base, + shape=(hi, 1), + strides=(stride_v_fp8_scale_shift_n, 1), + offsets=(lo, 0), + block_shape=(BLOCK_N, 1), + order=(1, 0), + ) + else: + K_scale_shift_block_ptr = None + V_scale_shift_block_ptr = None + + if HAS_ADDITIVE_BIAS: + additive_bias_block_ptr = tl.make_block_ptr( + base=additive_bias + + off_z * stride_bias_b + + off_g * stride_bias_g + + off_h * 
stride_bias_h, + shape=(N_CTX_Q, hi), + strides=(stride_bias_qm, stride_bias_km), + offsets=(start_m * BLOCK_M, lo), + block_shape=(BLOCK_M, BLOCK_N), + order=(0, 1), + ) + + if SPLIT_K_EARLY_EXIT and lo >= hi: + return + + Q_block_ptr = tl.make_block_ptr( + base=Q + + off_m * stride_qm + + off_h * stride_qh + + off_z * stride_qz * queries_use_batch_dim + + off_g * stride_qg, + shape=(q_len, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + + # Before compilation, this kernel will be processed by xformers.triton.vararg_kernel.unroll_varargs. + # That turns tensors annotated as the one below into lists of tensors of length N_GROUPS. + # This is a solution for Triton native lack of support for lists of tensors. + acc: "VAR_ARGS_ARRAY" # noqa: F821 + + for i in range(len(acc)): # noqa: F821 + acc[i] = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=internal_dtype) # noqa: F821 + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q: "VAR_ARGS_ARRAY" # noqa: F821 + for i in range(len(acc)): # noqa: F821 + q[i] = tl.load( # noqa: F821 + tl.advance(Q_block_ptr, (0, i * D_PER_GROUP)), boundary_check=(0,) + ) + + if IS_CAUSAL: + # Why does the masking conditon below work as a causal mask? 
+ # Assuming num_queries <= BLOCK_M: + # kv_pos = kv_start + range(0, BLOCK_N) + # q_offset = start_m * BLOCK_M + range(0, BLOCK_M) + # q_pos = kv_start + kv_len - num_queries + q_offset % num_queries + # mask = q_pos - kv_pos >= 0 + # So the final masking condition is: + # range(0, BLOCK_M) % num_queries - range(0, BLOCK_N) >= num_queries - kv_len + + q_offset = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + diag_idx = (q_offset[:, None] % NUM_QUERIES_CAUSAL) - tl.arange(0, BLOCK_N)[ + None, : + ] + diag_idx_shifted = tl.constexpr(diag_idx - NUM_QUERIES_CAUSAL + kv_len) + + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + if PAGE_SIZE > 0: + # Offset in integer blocks from the beginning of the page + block_offset_in_page = logical_block_idx % BLOCKS_IN_PAGE + # Offset in integer pages + logical_page_idx = logical_block_idx // BLOCKS_IN_PAGE + physical_page_idx = tl.load( + block_table + stride_blocktablesl * logical_page_idx + ).to(tl.int32) + offset = physical_page_idx * PAGE_SIZE + block_offset_in_page * BLOCK_N + + current_block_size = min(hi - start_n, BLOCK_N) + K_block_ptr = tl.make_block_ptr( + base=k_base + stride_kk * INT4_QUANTIZED * N_GROUPS, + shape=(PACKED_D_PER_GROUP, offset + current_block_size), + strides=(stride_kk, stride_kn), + offsets=(0, offset), + block_shape=(PACKED_D_PER_GROUP, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=v_base + stride_vk * INT4_QUANTIZED * N_GROUPS, + shape=(offset + current_block_size, PACKED_D_PER_GROUP), + strides=(stride_vn, stride_vk), + offsets=(offset, 0), + block_shape=(BLOCK_N, PACKED_D_PER_GROUP), + order=(1, 0), + ) + if INT4_QUANTIZED: + # Pointers to quantization coefficients. 
Even those they are 1D, + # we have to use block pointers, since usual pointers + # don't support boundary checks + K_scale_shift_block_ptr = tl.make_block_ptr( + base=k_base, + shape=(1, offset + current_block_size), + strides=(stride_kk, stride_kn), + offsets=(0, offset), + block_shape=(1, BLOCK_N), + order=(0, 1), + ) + V_scale_shift_block_ptr = tl.make_block_ptr( + base=v_base, + shape=(offset + current_block_size, 1), + strides=(stride_vn, stride_vk), + offsets=(offset, 0), + block_shape=(BLOCK_N, 1), + order=(1, 0), + ) + elif FP8_QUANTIZED: + K_scale_shift_block_ptr = tl.make_block_ptr( + base=k_fp8_scale_shift_base, + shape=(1, offset + current_block_size), + strides=(1, stride_k_fp8_scale_shift_n), + offsets=(0, offset), + block_shape=(1, BLOCK_N), + order=(0, 1), + ) + V_scale_shift_block_ptr = tl.make_block_ptr( + base=v_fp8_scale_shift_base, + shape=(offset + current_block_size, 1), + strides=(stride_v_fp8_scale_shift_n, 1), + offsets=(offset, 0), + block_shape=(BLOCK_N, 1), + order=(1, 0), + ) + else: + K_scale_shift_block_ptr = None + V_scale_shift_block_ptr = None + logical_block_idx += 1 + + k: "VAR_ARGS_ARRAY" # noqa: F821 + v: "VAR_ARGS_ARRAY" # noqa: F821 + for i in range(len(acc)): # noqa: F821 + k[i], v[i] = load_dequantize_k_v_group( # noqa: F821 + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N, + PACKED_PER_VAL, + PACKED_D_PER_GROUP, + FP8_QUANTIZED, + Q.dtype.element_ty, + i, + ) + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + for i in range(len(acc)): # noqa: F821 + qk += tl.dot(q[i], k[i]) # noqa: F821 + qk *= qk_scale + + if start_n == lo and ignore_in_first_block > 0: + qk = tl.where( + tl.arange(0, BLOCK_N) < ignore_in_first_block, float("-inf"), qk + ) + + if HAS_ADDITIVE_BIAS: + loaded_bias = tl.load( + additive_bias_block_ptr, + boundary_check=(0, 1) if BOUNDS_CHECKS_N else (0,), + ) + qk += loaded_bias * 1.44269504 + additive_bias_block_ptr = 
tl.advance(additive_bias_block_ptr, (0, BLOCK_N)) + + # TODO: This is slow, and only needed at the last iteration. + # Maybe we can unroll the last iteration instead? + if BOUNDS_CHECKS_N: + qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + if IS_CAUSAL: + # -- apply the causal mask -- + qk = tl.where(diag_idx_shifted >= start_n, qk, float("-inf")) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + if HAS_ADDITIVE_BIAS or IS_CAUSAL: + # NOTE: It's possible that an entire block is masked out. + # if this is the case, `m_i_new=nan` and everything becomes nan + alpha = tl.where(m_i_new == float("-inf"), 0, alpha) + p = tl.where(m_i_new[:, None] == float("-inf"), 0, p) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(Q.dtype.element_ty) + + # -- scale and update acc -- + for i in range(len(acc)): # noqa: F821 + acc[i] *= alpha[:, None] # noqa: F821 + acc[i] += tl.dot(p, v[i]) # noqa: F821 + + if not PAGE_SIZE: + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + if PACKED_PER_VAL > 1: + K_scale_shift_block_ptr = tl.advance( + K_scale_shift_block_ptr, (0, BLOCK_N) + ) + V_scale_shift_block_ptr = tl.advance( + V_scale_shift_block_ptr, (BLOCK_N, 0) + ) + + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out_splitK + + off_z.to(tl.int64) * stride_osk_z * queries_use_batch_dim + + off_m * stride_osk_m + + off_g * stride_osk_g + + off_h * stride_osk_h + + splitk_idx * stride_osk_s, + shape=(q_len, D_PER_GROUP), + strides=(stride_osk_m, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + for i in range(len(acc)): # noqa: F821 + # If for the current batch element there are no tokens in the current split-k chunk (because + # seqlen is too short), l_i will be 0, so we need to 
make sure attention is filled with zeros and not NaNs. + attn_out = tl.where(l_i[:, None] == 0, 0.0, acc[i] / l_i[:, None]) # noqa: F821 + tl.store( + tl.advance(O_block_ptr, (0, i * D_PER_GROUP)), + attn_out.to(Out_splitK.dtype.element_ty), # noqa: F821 + boundary_check=(0,), + ) + if WRITE_LSE: + LSE_splitk_ptr = ( + LSE_splitk + + off_z * stride_lsek_z * queries_use_batch_dim + + off_m * stride_lsek_m + + off_g * stride_lsek_g + + off_h * stride_lsek_h + + splitk_idx * stride_lsek_s + + (start_m * BLOCK_M + tl.arange(0, BLOCK_M)) * stride_lsek_m + ) + mask = start_m * BLOCK_M + tl.arange(0, BLOCK_M) < q_len + # Can be float64 to improve numerics + lse_dtype = LSE_splitk.dtype.element_ty + tl.store( + LSE_splitk_ptr, + (tl.math.log2(l_i.to(lse_dtype)) + m_i.to(lse_dtype)) / 1.44269504, + mask=mask, + ) + + +def gen_config( + block_m: int, + block_n: int, + stages: int, + warps: int, +) -> triton.Config: + """A more compact way to define a triton.Config, so it fits on one line""" + + return triton.Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + }, + num_stages=stages, + num_warps=warps, + ) + + +def _get_splitk_kernel(num_groups): + """ + Kernel _fwd_kernel_splitK needs to be post-processed by unroll_varargs + to specialize it for a given number of quantization groups N_GROUPS + before we can apply triton.heuristics and triton.autotune, so we + don't do them as decorators. 
+ """ + + _fwd_kernel_splitK_unrolled = unroll_varargs(_fwd_kernel_splitK, N=num_groups) + kernel = triton.heuristics( + { + "BOUNDS_CHECKS_N": lambda args: bool( + (args["BLOCK_N_PER_SPLIT"] % args["BLOCK_N"]) + or ( + args["BLOCK_N_PER_SPLIT"] > 0 + and args["N_CTX_K"] % args["BLOCK_N_PER_SPLIT"] + ) + or args["USE_SEQ_LEN"] + ) + } + )(_fwd_kernel_splitK_unrolled) + return kernel + + +@functools.lru_cache(None) +def autotune_kernel(kernel: Callable): + BLOCK_M_VALUES = [16, 32] + BLOCK_N_VALUES = [32, 64, 128] + # On AMD num_stages has to be 0 or 1, but 0 sometimes produces NaN or incorrect results. + STAGES_VALUES = [1] if torch.version.hip else [1, 2, 3] + WARPS_VALUES = [1, 2, 4] + + TRITON_CONFIGS = [ + gen_config(block_m, block_n, stages, warps) + for block_m in BLOCK_M_VALUES + for block_n in BLOCK_N_VALUES + for stages in STAGES_VALUES + for warps in WARPS_VALUES + ] + + kernel = triton.autotune( + configs=TRITON_CONFIGS, + key=AUTOTUNER_KEY, + use_cuda_graph=True, + )(kernel) + return kernel + + +# This object contains forward kernels wrapped into autotuner for different number +# of quantization groups. +_fwd_kernel_splitK_autotune: Dict[int, triton.runtime.Autotuner] = {} +# The loop below: +# - transforms the jitted kernel with unroll_varargs producing a new kernel of each value of num_groups +# - wraps the kernel into triton.heuristics +# - wraps kernel into Triton autotuner. 
Autotuning itself happens the first time the kernel is called +if sys.version_info >= (3, 9): + # unroll_varargs requires Python 3.9+ + for num_groups in [1, 2, 4, 8]: + _fwd_kernel_splitK_autotune[num_groups] = autotune_kernel( + _get_splitk_kernel(num_groups) + ) + + def get_autotuner_cache( + num_groups: int, + ) -> Dict[Tuple[Union[int, str]], triton.Config]: + """Returns a triton.runtime.autotuner.AutoTuner.cache object, which + represents mappings from kernel autotune keys (tuples describing kernel inputs) + to triton.Config + """ + return _fwd_kernel_splitK_autotune[num_groups].cache + + def set_autotuner_cache( + cache: Dict[Tuple[Union[int, str]], triton.Config], num_groups: int + ) -> None: + _fwd_kernel_splitK_autotune[num_groups].cache = cache + + +@triton.jit +def load_dequantize_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, + PACKED_D_PER_GROUP: tl.constexpr, + FP8_QUANTIZED: tl.constexpr, + dtype: tl.constexpr, + group_id: tl.constexpr, +): + """Load K/V for a given block. In case of int4/fp8-quantized K/V, dequantize them after loading. + If quantization is group-wise, use group_id to advance the pointers to the current group. + """ + # Advance to the current quantization group + K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0)) + V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id)) + + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()) + v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()) + + # If K/V are quantized, load quantization coefficients and dequantize. 
+ if FP8_QUANTIZED: + v_scale_shift = tl.load( + V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else () + ) + v_scale, v_shift = cast_uint32_to_half2(v_scale_shift) + v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype) + + k_scale_shift = tl.load( + K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else () + ) + k_scale, k_shift = cast_uint32_to_half2(k_scale_shift) + k_t = dequantize( + tl.trans(k), tl.trans(k_scale), tl.trans(k_shift), PACKED_PER_VAL + ).to(dtype) + k = tl.trans(k_t) + elif PACKED_PER_VAL > 1: + # Int4 quantization. + K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0)) + V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id)) + + k_scale_shift = tl.load( + K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else () + ) + v_scale_shift = tl.load( + V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else () + ) + + k_scale, k_shift = cast_uint32_to_half2(k_scale_shift) + v_scale, v_shift = cast_uint32_to_half2(v_scale_shift) + v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype) + k_t = dequantize( + tl.trans(k), + tl.trans(k_scale), + tl.trans(k_shift), + PACKED_PER_VAL, + ).to(dtype) + k = tl.trans(k_t) + return k, v + + +@triton.jit +def cast_uint32_to_half2(scale_shift): + """Extract two float16 packed into one int32""" + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr, +): + """PACKED_PER_VAL is the number of values packed into each element x_. + For example, for int4 quantization and x_ of type int32, PACKED_PER_VAL is 8. 
+ """ + # x_ : (BLOCK_N, D // PACKED_PER_VAL) + # scale: (BLOCK_N, 1) + # offsets: (PACKED_PER_VAL,) + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * (32 // PACKED_PER_VAL) + quant_offset = ( + x_[:, :, None, :] >> offsets + ) # (BLOCK_N, D // PACKED_PER_VAL, PACKED_PER_VAL) + + quant_offset = tl.reshape( + quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL) + ) + if PACKED_PER_VAL == 4: + # FP8 quantization. + fp8_type = tl.float8e4b8 if torch.version.hip is not None else tl.float8e4nv + dequant = ( + quant_offset.to(tl.uint8).to(fp8_type, bitcast=True).to(scale.dtype) * scale + + shift + ) + else: + # Int4 quantization. + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + + +@triton.jit +def _splitK_reduce( + Out_splitK, # [B, G, H, split_k, Mq, K] + LSE_splitK, # [B, G, H, split_k, Mq] + Out, # [B, H, M, K] + LSE, # [B, H, M] + split_k: tl.constexpr, + splitK_pow2: tl.constexpr, + stride_osk_z: tl.constexpr, + stride_osk_g: tl.constexpr, + stride_osk_h: tl.constexpr, + stride_osk_s: tl.constexpr, + stride_osk_m: tl.constexpr, + stride_osk_k: tl.constexpr, + stride_lsek_z: tl.constexpr, + stride_lsek_g: tl.constexpr, + stride_lsek_h: tl.constexpr, + stride_lsek_s: tl.constexpr, + stride_lsek_m: tl.constexpr, + stride_oz: tl.constexpr, + stride_og: tl.constexpr, + stride_oh: tl.constexpr, + stride_om: tl.constexpr, + stride_ok: tl.constexpr, + stride_lse_z: tl.constexpr, + stride_lse_g: tl.constexpr, + stride_lse_h: tl.constexpr, + stride_lse_m: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + WRITE_LSE: tl.constexpr, +): + # grid = (M, B * G * H, 1) + 
off_m = tl.program_id(0).to(tl.int64) + off_zhg = tl.program_id(1).to(tl.int64) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + + Out_splitK_ptr = ( + Out_splitK + + stride_osk_z * off_z + + stride_osk_g * off_g + + stride_osk_h * off_h + + stride_osk_m * off_m + + tl.arange(0, BLOCK_SIZE)[None, :] + + stride_osk_s * tl.arange(0, splitK_pow2)[:, None] + ) + + LSE_splitK_ptr0 = ( + LSE_splitK + + stride_lsek_z * off_z + + stride_lsek_g * off_g + + stride_lsek_h * off_h + + stride_lsek_m * off_m + + stride_lsek_s * tl.arange(0, splitK_pow2) + ) + + if splitK_pow2 > split_k: + mask_1d = tl.arange(0, splitK_pow2) < split_k + mask_2d = mask_1d[:, None] + lse_splitk = tl.load(LSE_splitK_ptr0, mask=mask_1d, other=float("-inf")) + lse_max = tl.max(lse_splitk) + out_splitk = tl.load( + Out_splitK_ptr, mask=mask_2d, other=0 + ) # (split_k, BLOCK_SIZE) + lse_splitk = tl.load( + LSE_splitK_ptr0, mask=mask_1d, other=float("-inf") + ) # (split_k,) + else: + lse_splitk = tl.load(LSE_splitK_ptr0) + lse_max = tl.max(lse_splitk) + out_splitk = tl.load(Out_splitK_ptr) + lse_splitk = tl.load(LSE_splitK_ptr0) + + sumexp_normalized_splitk = tl.math.exp2( + (lse_splitk - lse_max).to(tl.float32) * 1.44269504 + ) # (split_k,) + sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) # scalar + # Compute numerator + numerator_normalized = tl.sum( + out_splitk * sumexp_normalized_splitk[:, None], axis=0 + ) + acc = numerator_normalized / sumexp_normalized + acc = tl.where(lse_max == float("-inf"), 0.0, acc) + + Out_ptr = ( + Out + + stride_oz * off_z + + stride_oh * off_h + + stride_og * off_g + + stride_om * off_m + + tl.arange(0, BLOCK_SIZE) + ) + if acc.dtype is tl.float64 and Out.dtype.element_ty is not tl.float64: + # must avoid direct cast f64->f16 + acc = acc.to(tl.float32) + tl.store(Out_ptr, acc) + + if WRITE_LSE: + l_ptrs = ( + LSE + + off_z * stride_lse_z + + off_g * stride_lse_g + + off_h * stride_lse_h + + off_m * stride_lse_m + ) + 
to_store = lse_max + tl.math.log2(sumexp_normalized) / 1.44269504 + to_store = tl.where(lse_max == float("-inf"), lse_max, to_store) + tl.store(l_ptrs, to_store) + + +@triton.jit +def _splitK_reduce_varargs( + Out_splitK: "VAR_ARGS_ARRAY", # list of [B, G, H, Mq, K]; + LSE_splitK: "VAR_ARGS_ARRAY", # list of [B, G, H, Mq] + Out, # [B, G, H, M, K] + LSE, # [B, G, H, M] + stride_osk_z: "VAR_ARGS_ARRAY", + stride_osk_g: "VAR_ARGS_ARRAY", + stride_osk_h: "VAR_ARGS_ARRAY", + stride_osk_m: "VAR_ARGS_ARRAY", + stride_osk_k: "VAR_ARGS_ARRAY", + stride_lsek_z: "VAR_ARGS_ARRAY", + stride_lsek_g: "VAR_ARGS_ARRAY", + stride_lsek_h: "VAR_ARGS_ARRAY", + stride_lsek_m: "VAR_ARGS_ARRAY", + stride_oz, + stride_og, + stride_oh, + stride_om, + stride_ok, + stride_lse_z, + stride_lse_g, + stride_lse_h, + stride_lse_m, + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + WRITE_LSE: tl.constexpr, +): + """ + This version of reduce kernel takes attention and LSE of chunks as lists of tensors, + as opposed to _splitK_reduce, which takes each as a stacked tensor. 
+ """ + # grid = (M, B * G * H, 1) + off_m = tl.program_id(0).to(tl.int64) + off_zhg = tl.program_id(1).to(tl.int64) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + + out_splitk_offset: "VAR_ARGS_ARRAY" # noqa: F821 + for i in range(len(Out_splitK)): + out_splitk_offset[i] = ( # noqa: F821 + stride_osk_z[i] * off_z # type: ignore # noqa: F821 + + stride_osk_g[i] * off_g + + stride_osk_h[i] * off_h + + stride_osk_m[i] * off_m + + tl.arange(0, BLOCK_SIZE) + ) + lse_splitk_offset: "VAR_ARGS_ARRAY" # noqa: F821 + for i in range(len(Out_splitK)): + lse_splitk_offset[i] = ( # noqa: F821 + stride_lsek_z[i] * off_z # type: ignore # noqa: F821 + + stride_lsek_g[i] * off_g + + stride_lsek_h[i] * off_h + + stride_lsek_m[i] * off_m + ) + + lse_max = float("-inf") + for split_k_idx in range(len(Out_splitK)): # type: ignore # noqa: F821 + LSE_splitK_ptr = LSE_splitK[split_k_idx] + lse_splitk_offset[split_k_idx] # type: ignore # noqa: F821 + lse_splitk = tl.load(LSE_splitK_ptr) + lse_max = tl.maximum(lse_max, lse_splitk) + + sumexp_normalized = 0.0 + numerator_normalized = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + + for split_k_idx in range(len(Out_splitK)): # type: ignore # noqa: F821 + out_splitk = tl.load(Out_splitK[split_k_idx] + out_splitk_offset[split_k_idx]) # type: ignore # noqa: F821 + lse_splitk = tl.load(LSE_splitK[split_k_idx] + lse_splitk_offset[split_k_idx]) # type: ignore # noqa: F821 + # Compute denominator + sumexp_normalized_splitk = tl.math.exp2( + (lse_splitk - lse_max).to(tl.float32) * 1.44269504 + ) + sumexp_normalized += sumexp_normalized_splitk + + # Compute numerator + numerator_normalized += out_splitk * sumexp_normalized_splitk + + acc = numerator_normalized / sumexp_normalized + acc = tl.where(lse_max == float("-inf"), 0.0, acc) + + Out_ptr = ( + Out + + stride_oz * off_z + + stride_oh * off_h + + stride_og * off_g + + stride_om * off_m + + tl.arange(0, BLOCK_SIZE) + ) + if acc.dtype is tl.float64 and 
Out.dtype.element_ty is not tl.float64: + # must avoid direct cast f64->f16 + acc = acc.to(tl.float32) + tl.store(Out_ptr, acc) + + if WRITE_LSE: + l_ptrs = ( + LSE + + off_z * stride_lse_z + + off_g * stride_lse_g + + off_h * stride_lse_h + + off_m * stride_lse_m + ) + to_store = lse_max + tl.math.log2(sumexp_normalized) / 1.44269504 + to_store = tl.where(lse_max == float("-inf"), lse_max, to_store) + tl.store(l_ptrs, to_store) + + +@triton.jit +def _splitK_reduce_varargs_backward( + Out_splitK: "VAR_ARGS_ARRAY", # list of [B, G, H, Mq, K]; + LSE_splitK: "VAR_ARGS_ARRAY", # list of [B, G, H, Mq] + Dout_splitK: "VAR_ARGS_ARRAY", # gradients - same shape as the inputs themselves + DLSE_splitK: "VAR_ARGS_ARRAY", + Out, # [B, G, H, M, K] + LSE, # [B, G, H, M] + DOut, + DLSE, + # strides of chunked inputs: attention and LSE + stride_osk_z: "VAR_ARGS_ARRAY", + stride_osk_g: "VAR_ARGS_ARRAY", + stride_osk_h: "VAR_ARGS_ARRAY", + stride_osk_m: "VAR_ARGS_ARRAY", + stride_osk_k: "VAR_ARGS_ARRAY", + stride_lsek_z: "VAR_ARGS_ARRAY", + stride_lsek_g: "VAR_ARGS_ARRAY", + stride_lsek_h: "VAR_ARGS_ARRAY", + stride_lsek_m: "VAR_ARGS_ARRAY", + # strides of merged outputs: attention and LSE + stride_oz, + stride_og, + stride_oh, + stride_om, + stride_ok, + stride_lse_z, + stride_lse_g, + stride_lse_h, + stride_lse_m, + # strides of gradients + stride_doz, + stride_dog, + stride_doh, + stride_dom, + stride_dok, + stride_dlse_z, + stride_dlse_g, + stride_dlse_h, + stride_dlse_m, + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, +): + """ + Backward for _splitK_reduce_varargs. Similar to forward, it takes + attention and LSE of chunks as lists of tensors, + and outputs the corresponding gradients in the same format. + """ + + # grid = (M, B * G * H, 1) + off_m = tl.program_id(0).to(tl.int64) + off_zhg = tl.program_id(1).to(tl.int64) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + + # Compute offsets inside each attention/LSE chunk. 
+ # Note that each chunk can have different strides, so offsets can also be different. + out_splitk_offset: "VAR_ARGS_ARRAY" # noqa: F821 + for i in range(len(Out_splitK)): + out_splitk_offset[i] = ( # type: ignore # noqa: F821 + stride_osk_z[i] * off_z + + stride_osk_g[i] * off_g + + stride_osk_h[i] * off_h + + stride_osk_m[i] * off_m + + tl.arange(0, BLOCK_SIZE) + ) + lse_splitk_offset: "VAR_ARGS_ARRAY" # noqa: F821 + for i in range(len(Out_splitK)): + lse_splitk_offset[i] = ( # type: ignore # noqa: F821 + stride_lsek_z[i] * off_z + + stride_lsek_g[i] * off_g + + stride_lsek_h[i] * off_h + + stride_lsek_m[i] * off_m + ) + + lse_max = float("-inf") + for split_k_idx in range(len(Out_splitK)): # type: ignore # noqa: F821 + LSE_splitK_ptr = LSE_splitK[split_k_idx] + lse_splitk_offset[split_k_idx] # type: ignore # noqa: F821 + lse_splitk = tl.load(LSE_splitK_ptr) + lse_max = tl.maximum(lse_max, lse_splitk) + + # Load attention and the corresponding gradient + offset_out = ( + stride_oz * off_z + + stride_oh * off_h + + stride_og * off_g + + stride_om * off_m + + tl.arange(0, BLOCK_SIZE) + ) + offset_dout = ( + stride_doz * off_z + + stride_doh * off_h + + stride_dog * off_g + + stride_dom * off_m + + tl.arange(0, BLOCK_SIZE) + ) + out = tl.load(Out + offset_out) + dattn = tl.load(DOut + offset_dout) + + # Load LSE and the corresponding gradient + offset_lse = ( + stride_lse_z * off_z + + stride_lse_h * off_h + + stride_lse_g * off_g + + stride_lse_m * off_m + ) + offset_dlse = ( + stride_dlse_z * off_z + + stride_dlse_h * off_h + + stride_dlse_g * off_g + + stride_dlse_m * off_m + ) + lse = tl.load(LSE + offset_lse) + dlse = tl.load(DLSE + offset_dlse) + + for split_k_idx in range(len(Out_splitK)): # type: ignore # noqa: F821 + # Load attention and LSE of chunks + out_splitk = tl.load(Out_splitK[split_k_idx] + out_splitk_offset[split_k_idx]) # type: ignore # noqa: F821 + lse_splitk = tl.load(LSE_splitK[split_k_idx] + lse_splitk_offset[split_k_idx]) # type: ignore # 
noqa: F821 + + # Pointers to save gradients of attention and LSE of chunks + dout_splitk_ptr = Dout_splitK[split_k_idx] + out_splitk_offset[split_k_idx] # type: ignore # noqa: F821 + dlse_splitk_ptr = DLSE_splitK[split_k_idx] + lse_splitk_offset[split_k_idx] # type: ignore # noqa: F821 + + # dX/dattn_i = dX/dattn * dattn/dattn_i + dX/dlse * dlse/dattn_i, and dlse/dattn_i == 0 + dattn_dattn_i = tl.exp(lse_splitk - lse_max) / tl.exp(lse - lse_max) + dX_dattn_i = dattn_dattn_i * dattn + tl.store(dout_splitk_ptr, dX_dattn_i) + + dattn_dlse_i = (out_splitk - out) * dattn_dattn_i + + # dX/dlse_i = dX/dattn * dattn/dlse_i + dX/dlse * dlse/dlse_i + dlse_dlse_i = dattn_dattn_i + dX_dlse_i = dlse_dlse_i * dlse + tl.sum( + dattn_dlse_i * dattn + ) # Sum is over the hidden dimension + tl.store(dlse_splitk_ptr, dX_dlse_i) diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/attn_bias.py b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/attn_bias.py new file mode 100644 index 0000000000000000000000000000000000000000..ae1ff62fd402eb10345c03ff50e4ac9f323965a2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/attn_bias.py @@ -0,0 +1,1739 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +""" +This file contains biases that can be used as the `attn_bias` argument in +:attr:`xformers.ops.memory_efficient_attention`. +Essentially, a bias is a Tensor which will be added to the ``Q @ K.t`` before +computing the ``softmax``. + + +The goal of having custom made classes (instead of dense tensors) is that +we want to avoid having to load the biases from memory in the kernel, for +performance reasons. We also want to be able to know before-hand which +parts of the attention matrix we will need to compute (eg causal masks). + + +Some very common biases are LowerTriangularMask and BlockDiagonalMask. 
+""" + +import math +from dataclasses import dataclass +from typing import ( + Any, + ClassVar, + Iterable, + List, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) + +import torch + + +class AttentionBias: + """Base class for a custom bias that can be applied \ + as the attn_bias argument in + :attr:`xformers.ops.memory_efficient_attention`. + + That function has the ability to add a tensor, the + attention bias, to the QK^T matrix before it is used + in the softmax part of the attention calculation. + The attention bias tensor with shape + (B or 1, n_queries, number of keys) + can be given as the attn_bias input. + The most common use case is for an attention bias is + to contain only zeros and negative infinities, which forms + a mask so that some queries only attend to some keys. + + Children of this class define alternative things which can + be used as the attn_bias input to define an attention bias which + forms such a mask, for some common cases. + + When using an :attr:`xformers.ops.AttentionBias` + instead of a :attr:`torch.Tensor`, the mask matrix does + not need to be materialized, and can be + hardcoded into some kernels for better performance. + + See: + + - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMask` + - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask` + - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias` + - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask` + - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask` + + """ + + HOLDS_DENSE_TENSOR = False + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """ + Materializes the bias as a `torch.Tensor`. This is very slow + and we don't attempt to make it fast. Only use for debugging/testing. 
+ + Shape should be like `[*, q_seqlen, k_seqlen]` + """ + raise NotImplementedError() + + +def _get_default_bias_device(device: Optional[torch.device] = None) -> torch.device: + if device is None: + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + return device + + +def _materialize_causal_mask( + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + *, + window_size: Optional[int] = None, + from_bottomright: bool = False, +) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=1, + device=device, + ) + + num_queries, num_keys = shape[-2:] + shift = 0 + if from_bottomright: + shift = num_keys - num_queries + + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + return mask.to(dtype) + + +@dataclass +class LocalAttentionFromBottomRightMask(AttentionBias): + """ + A local attention mask + + The query at position :math:`q` can attend the key at position :math:`k` if + :math:`q - window\\_left <= k + s <= q + window\\_right` + + With :math:`s = num\\_queries - num\\_keys` + + :Example: + + .. code-block:: python + + import torch + from xformers.ops import fmha + + bias = fmha.attn_bias.LocalAttentionFromBottomRightMask(window_left=1, window_right=2) + print(bias.materialize(shape=(4, 4)).exp()) + print(bias.materialize(shape=(4, 5)).exp()) + + .. code-block:: text + + # 4x4 + tensor([[1., 1., 1., 0.], + [1., 1., 1., 1.], + [0., 1., 1., 1.], + [0., 0., 1., 1.]]) + + # 4x5 + tensor([[1., 1., 1., 1., 0.], + [0., 1., 1., 1., 1.], + [0., 0., 1., 1., 1.], + [0., 0., 0., 1., 1.]]) + + :Illustration: + + .. 
figure:: /_static/local_attn.png + :width: 240px + + The total window size is :math:`window\\_left + 1 + window\\_right` + """ + + window_left: int + window_right: int + + def __post_init__(self) -> None: + if self.window_left < 0: + raise ValueError( + "Invalid window value passed to " + "`LocalAttentionFromBottomRightMask`: expected" + f"`window_left > 0` but got window_left={self.window_left}" + ) + if self.window_right < 0: + raise ValueError( + "Invalid window value passed to " + "`LocalAttentionFromBottomRightMask`: expected" + f"`window_right > 0` but got window_right={self.window_right}" + ) + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + mask = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=1, + device=device, + ) + + num_queries, num_keys = shape[-2:] + shift = num_keys - num_queries + + mask = torch.triu(mask, diagonal=shift - self.window_left) + mask = torch.tril(mask, diagonal=shift + self.window_right) + mask = torch.log(mask) + return mask.to(dtype) + + +class LowerTriangularFromBottomRightMask(AttentionBias): + """ + A causal masking. + + This mask is exactly the same as :attr:`LowerTriangularMask` when there is + the same number of queries and keys. + When the number of queries is different from the number of keys, + it is a triangular mask shifted so that the last query can attend to + the last key. + In other words, a query Q cannot attend to a key which is nearer the + final key than Q is to the final query. + + + .. figure:: /_static/causal_bottom_right.png + + The difference between :attr:`LowerTriangularMask` (left) and + :attr:`LowerTriangularFromBottomRightMask` (right). They become + equivalent if the number of queries equals the number of keys. 
+ """ + + def to(self, device: torch.device) -> "LowerTriangularFromBottomRightMask": + return self + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return _materialize_causal_mask( + shape, dtype=dtype, device=device, from_bottomright=True + ) + + def make_local_attention( + self, window_size: int + ) -> "LowerTriangularFromBottomRightLocalAttentionMask": + """ + Create a new bias which combines local + causal attention. + + See :attr:`LowerTriangularFromBottomRightLocalAttentionMask` + """ + return LowerTriangularFromBottomRightLocalAttentionMask(window_size) + + +@dataclass +class LowerTriangularFromBottomRightLocalAttentionMask( + LowerTriangularFromBottomRightMask +): + """ + A mask that combines both :attr:`LowerTriangularFromBottomRightMask` and + local attention. + + A query whose distance from the final query is X cannot attend to a key + whose distance to the final key is either of: + + * less than X (i.e. "causal attention", same as :attr:`LowerTriangularFromBottomRightMask`) + * greater than X + window_size (i.e. "local attention") + + + .. figure:: /_static/causal_bottom_right_local.png + + The mask from :attr:`LowerTriangularFromBottomRightLocalAttentionMask`. + The green area is calculated, and the grey area is masked out. + """ + + _window_size: int + + def __post_init__(self) -> None: + if self._window_size <= 0: + raise ValueError( + f"Expected `window_size > 0`, but window_size={self._window_size}" + ) + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return _materialize_causal_mask( + shape, + dtype=dtype, + device=device, + window_size=self._window_size, + from_bottomright=True, + ) + + +@dataclass +class _SeqLenInfo: + """ + (Internal) Represents the division of a dimension into blocks. 
+ + For example, to represents a dimension of length 7 divided into + three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`. + The members will be: + max_seqlen: 3 + min_seqlen: 2 + seqstart_py: [0, 2, 5, 7] + seqstart: torch.IntTensor([0, 2, 5, 7]) + """ + + seqstart: torch.Tensor + max_seqlen: int + min_seqlen: int + seqstart_py: List[int] + + def to(self, device: torch.device) -> "_SeqLenInfo": + if self.seqstart.device == device: + return self + return _SeqLenInfo( + seqstart=self.seqstart.to(device), + max_seqlen=self.max_seqlen, + min_seqlen=self.min_seqlen, + seqstart_py=self.seqstart_py, + ) + + def intervals(self) -> Iterable[Tuple[int, int]]: + yield from zip(self.seqstart_py, self.seqstart_py[1:]) + + @classmethod + def _get_seqstart( + cls, seqlens: Iterable[int], *, device: torch.device + ) -> Tuple[int, int, List[int], torch.Tensor]: + """ + Given sequence lengths, returns the min/max value and the sequence start + positions (offsets), with first element being 0 (returned in list and Tensor). 
+ """ + + assert not isinstance(seqlens, torch.Tensor) + seqstart_py = [0] + max_seqlen = -1 + min_seqlen = -1 + for seqlen in seqlens: + min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen + max_seqlen = max(max_seqlen, seqlen) + seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen) + seqstart = torch.tensor(seqstart_py, dtype=torch.int32, device=device) + + return (min_seqlen, max_seqlen, seqstart_py, seqstart) + + @classmethod + def from_seqlens( + cls, seqlens: Iterable[int], *, device: Optional[torch.device] = None + ) -> "_SeqLenInfo": + """ + Input tensors are assumed to be in shape [B, M, *] + """ + device = _get_default_bias_device(device) + min_seqlen, max_seqlen, seqstart_py, seqstart = cls._get_seqstart( + seqlens, device=device + ) + + return cls( + max_seqlen=max_seqlen, + min_seqlen=min_seqlen, + seqstart=seqstart, + seqstart_py=seqstart_py, + ) + + def split( + self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None + ) -> List[torch.Tensor]: + if self.seqstart_py[-1] != x.shape[1] or x.shape[0] != 1: + raise ValueError( + f"Invalid `torch.Tensor` of shape {x.shape}, expected format " + f"(B, M, *) with B=1 and M={self.seqstart_py[-1]}\n" + f" seqstart: {self.seqstart_py}" + ) + if batch_sizes is None: + batch_sizes = [1] * (len(self.seqstart_py) - 1) + split_chunks = [] + it = 0 + for batch_size in batch_sizes: + split_chunks.append( + self.seqstart_py[it + batch_size] - self.seqstart_py[it] + ) + it += batch_size + return [ + tensor.reshape([bs, -1, *tensor.shape[2:]]) + for bs, tensor in zip(batch_sizes, x.split(split_chunks, dim=1)) + ] + + +@dataclass +class _PaddedSeqLenInfo(_SeqLenInfo): + """ + (Internal) Represents the division of a dimension into blocks which are + padded out to the same total length. + + For example, to represent a dimension of length 12 with space for + three blocks of length 4, but where the occupied lengths are + 2, 3 and 2, use `from_seqlens_padded([2, 3, 2], 4)`. 
+ + The layout along the dimension is + + 0 ─► block 0 + block 0 + + + 4 ─► block 1 + block 1 + block 1 + + 8 ─► block 2 + block 2 + + + 12 ─► + + The members will be: + max_seqlen: 3 + min_seqlen: 2 + seqstart_py: [0, 4, 8, 12] + seqstart: torch.IntTensor([0, 4, 8, 12]) + seqlen_py: [2, 3, 2] + seqlen: torch.IntTensor([2, 3, 2]) + padding: 4 + """ + + seqlen: torch.Tensor + seqlen_py: Sequence[int] + padding: int + # From parent: seqstart[i] contains the start position + # of the i-th sequence + # seqstart: torch.Tensor + + def __post_init__(self) -> None: + assert len(self.seqstart_py) == len(self.seqlen_py) + 1 + + def to(self, device: torch.device) -> "_PaddedSeqLenInfo": + if self.seqlen.device == device: + return self + return _PaddedSeqLenInfo( + # _SeqLenInfo + seqstart=self.seqstart.to(device), + max_seqlen=self.max_seqlen, + min_seqlen=self.min_seqlen, + seqstart_py=self.seqstart_py, + # _PaddedSeqLenInfo + seqlen=self.seqlen.to(device), + seqlen_py=self.seqlen_py, + padding=self.padding, + ) + + def intervals(self) -> Iterable[Tuple[int, int]]: + for (start, _), length in zip(super().intervals(), self.seqlen_py): + yield start, start + length + + @classmethod + def from_seqlens( + cls, seqlens: Iterable[int], *, device: Optional[torch.device] = None + ) -> "_SeqLenInfo": + raise RuntimeError( + "Use either `_SeqLenInfo.from_seqlens` or `_PaddedSeqLenInfo.from_seqlens_padded`" + ) + + @classmethod + def from_seqlens_padded( + cls, + seqlens: Sequence[int], + padding: int, + *, + device: Optional[torch.device] = None, + ) -> "_PaddedSeqLenInfo": + """ + Input tensors are assumed to be in shape [B, M, *] + seqstart = padding * torch.arange(batch_size) + """ + assert not isinstance(seqlens, torch.Tensor) + assert all( + seqlen <= padding for seqlen in seqlens + ), f"Seqlens {seqlens} Padding {padding}" + device = _get_default_bias_device(device) + seqstart_py = list(range(0, len(seqlens) * padding + 1, padding)) + seqlen = torch.tensor(seqlens, 
dtype=torch.int32, device=device) + return cls( + seqlen=seqlen, + seqlen_py=seqlens, + max_seqlen=max(seqlens), + min_seqlen=min(seqlens), + seqstart=torch.tensor(seqstart_py, dtype=torch.int32, device=device), + seqstart_py=seqstart_py, + padding=padding, + ) + + def split( + self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None + ) -> List[torch.Tensor]: + raise NotImplementedError("_PaddedSeqLenInfo.split") + + +@dataclass +class _GappySeqInfo(_SeqLenInfo): + """ + (Internal) Flexible equivalent of _PaddedSeqLenInfo. There are two + distinct semantics. + + (1) For non-paged masks: + Represents the division of a dimension into blocks which are + anywhere. Each just has a start and a length. The final start is the total + length of the dimension. + + For example, to represent a dimension of length 14 like follows with + three occupied lengths of + 6, 3 and 1, use `from_seqlens_padded([0, 7, 12, 14], [6, 3, 1])`. + + The layout along the dimension is + + 0 ─► block 0 + block 0 + block 0 + block 0 + 4 ─► block 0 + block 0 + + block 1 + 8 ─► block 1 + block 1 + + + 12 ─► block 2 + + + The members will be: + max_seqlen: 6 + min_seqlen: 1 + seqstart_py: [0, 7, 12, 14] + seqstart: torch.IntTensor([0, 7, 12, 14]) + seqlen_py: [6, 3 1] + seqlen: torch.IntTensor([6, 3, 1]) + + (2) For paged masks: + The notional space is divided into batch-size-many blocks. + seqstart and seqstart_py is an offset in the block, not in + the whole space, and doesn't have an extra last element. + Otherwise as above. 
+ """ + + seqlen: torch.Tensor + seqlen_py: Sequence[int] + # From parent: seqstart[i] contains the start position + # of the i-th sequence + # seqstart: torch.Tensor + + def to(self, device: torch.device) -> "_GappySeqInfo": + if self.seqlen.device == device: + return self + return _GappySeqInfo( + # _SeqLenInfo + seqstart=self.seqstart.to(device), + max_seqlen=self.max_seqlen, + min_seqlen=self.min_seqlen, + seqstart_py=self.seqstart_py, + # _GappySeqInfo + seqlen=self.seqlen.to(device), + seqlen_py=self.seqlen_py, + ) + + def intervals(self) -> Iterable[Tuple[int, int]]: + for (start, _), length in zip(super().intervals(), self.seqlen_py): + yield start, start + length + + @classmethod + def from_seqlens( + cls, seqlens: Iterable[int], *, device: Optional[torch.device] = None + ) -> "_SeqLenInfo": + raise NotImplementedError() + + @classmethod + def from_seqlens_gappy( + cls, + seqstarts: Sequence[int], + seqlens: Sequence[int], + paged: bool, + *, + device: torch.device, + ) -> "_GappySeqInfo": + assert not isinstance(seqlens, torch.Tensor) + seqstart_py = list(seqstarts) + if len(seqlens) == 0: + raise ValueError("No elements") + if len(seqstarts) - len(seqlens) != (0 if paged else 1): + extra = "" if paged else "1 + " + raise ValueError( + f"len(seqstarts)={seqstarts} should be {extra}len(seqlens)={seqlens}" + ) + seqlen = torch.tensor(seqlens, dtype=torch.int32, device=device) + return cls( + seqlen=seqlen, + seqlen_py=seqlens, + max_seqlen=max(seqlens), + min_seqlen=min(seqlens), + seqstart=torch.tensor(seqstart_py, dtype=torch.int32, device=device), + seqstart_py=seqstart_py, + ) + + def split( + self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None + ) -> List[torch.Tensor]: + raise NotImplementedError("_PaddedSeqLenInfo.split") + + +@dataclass +class BlockDiagonalMask(AttentionBias): + """ + A block-diagonal mask that can be passed as ``attn_bias`` + argument to :attr:`xformers.ops.memory_efficient_attention`. 
+ + Queries and Keys are each divided into the same number of blocks. + Queries in block i only attend to keys in block i. + + .. figure:: /_static/block_diag_bias.png + + This bias can be used to handle a batch of sequences of + different lengths, via :attr:`BlockDiagonalMask.from_tensor_list` + + :Example: + + .. code-block:: python + + import torch + from xformers.ops import fmha + + K = 16 + dtype = torch.float16 + device = "cuda" + list_x = [ + torch.randn([1, 3, 1, K], dtype=dtype, device=device), + torch.randn([1, 6, 1, K], dtype=dtype, device=device), + torch.randn([1, 2, 1, K], dtype=dtype, device=device), + ] + attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x) + linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype) + + q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2) + out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias) + list_out = attn_bias.split(out) + print(list_out[0].shape) # [1, 3, 1, K] + assert tuple(list_out[0].shape) == (1, 3, 1, K) + + """ + + q_seqinfo: _SeqLenInfo + k_seqinfo: _SeqLenInfo + _batch_sizes: Optional[Sequence[int]] = None + + def to(self, device) -> "BlockDiagonalMask": + return BlockDiagonalMask( + q_seqinfo=self.q_seqinfo.to(device), + k_seqinfo=self.k_seqinfo.to(device), + _batch_sizes=self._batch_sizes, + ) + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return torch.zeros( + shape, + dtype=dtype, + device=device, + ) + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """Materialize the attention bias - for debugging & testing""" + assert shape[-1] == self.k_seqinfo.seqstart_py[-1], ( + shape[-1], + self.k_seqinfo.seqstart_py[-1], + ) + assert shape[-2] == self.q_seqinfo.seqstart_py[-1], ( + shape[-2], + self.q_seqinfo.seqstart_py[-1], + ) + mask = 
torch.empty(shape[-2:], dtype=dtype, device=device) + mask.fill_(-math.inf) + for i, ((q_start, q_end), (k_start, k_end)) in enumerate( + zip( + self.q_seqinfo.intervals(), + self.k_seqinfo.intervals(), + ) + ): + mask[q_start:q_end, k_start:k_end] = self._create_block_mask( + (q_end - q_start, k_end - k_start), + dtype=dtype, + device=device, + ) + for _ in range(len(shape) - 2): + mask = mask.unsqueeze(0) + return mask.expand(shape) + + @classmethod + def from_seqlens( + cls, + q_seqlen: Sequence[int], + kv_seqlen: Optional[Sequence[int]] = None, + *, + device: Optional[torch.device] = None, + ) -> "BlockDiagonalMask": + """Creates a :attr:`BlockDiagonalMask` from a list of tensors lengths for query and key/value. + + Args: + q_seqlen (Union[Sequence[int], torch.Tensor]): List or tensor of sequence lengths for query tensors + kv_seqlen (Union[Sequence[int], torch.Tensor], optional): List or tensor of sequence lengths for key/value. + (Defaults to ``q_seqlen``.) + Returns: + BlockDiagonalMask + """ + device = _get_default_bias_device(device) + assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device) + if kv_seqlen is None or q_seqlen == kv_seqlen: + k_seqinfo = q_seqinfo + else: + k_seqinfo = _SeqLenInfo.from_seqlens(kv_seqlen, device=device) + return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo) + + @classmethod + def from_tensor_list( + cls, + tensors: Sequence[torch.Tensor], + ) -> Tuple["BlockDiagonalMask", torch.Tensor]: + """Creates a :attr:`BlockDiagonalMask` from a list of tensors, and returns the tensors + concatenated on the sequence length dimension + + .. figure:: /_static/block_diag_cat_split.png + + See also :attr:`BlockDiagonalMask.split` to split the returned + :attr:`torch.Tensor` back to a list of tensors of varying sequence length + + Args: + tensors (Sequence[torch.Tensor]): A list of tensors of shape ``[B, M_i, *]``. 
+ All tensors should have the same dimension and the same batch size ``B``, but + they can have different sequence length ``M``. + + Returns: + Tuple[BlockDiagonalMask, torch.Tensor]: The corresponding bias for the attention + along with `tensors` concatenated on the sequence length dimension, with shape ``[1, sum_i{M_i}, *]`` + """ + batch_sizes = [tensor.shape[0] for tensor in tensors] + seqlens = [] + for x in tensors: + for _ in range(x.shape[0]): + seqlens.append(x.shape[1]) + block_diag = cls.from_seqlens(seqlens) + block_diag._batch_sizes = batch_sizes + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in tensors) + concat_tensors = torch.cat(tensors_bs1, dim=1) + return block_diag, concat_tensors + + @classmethod + def from_tensor_lists_qkv( + cls, + tensors_q: Sequence[torch.Tensor], + tensors_k: Sequence[torch.Tensor], + tensors_v: Optional[Sequence[torch.Tensor]] = None, + ) -> Tuple["BlockDiagonalMask", torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert len(tensors_q) == len(tensors_k) + assert tensors_v is None or len(tensors_v) == len(tensors_q) + batch_sizes = [tensor.shape[0] for tensor in tensors_q] + q_seqlens, kv_seqlens = [], [] + for i, (q, k) in enumerate(zip(tensors_q, tensors_k)): + assert q.shape[0] == k.shape[0] + q_seqlens += [q.shape[1]] * q.shape[0] + kv_seqlens += [k.shape[1]] * k.shape[0] + assert tensors_v is None or tensors_v[i].shape[:2] == k.shape[:2] + block_diag = cls.from_seqlens(q_seqlens, kv_seqlens) + block_diag._batch_sizes = batch_sizes + return ( + block_diag, + torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_q], dim=1), + torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_k], dim=1), + torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_v], dim=1) + if tensors_v is not None + else None, + ) + + def split_queries(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]: + return self.q_seqinfo.split(tensor, self._batch_sizes) + + def split_kv(self, tensor: torch.Tensor) -> 
Sequence[torch.Tensor]: + return self.k_seqinfo.split(tensor, self._batch_sizes) + + def split(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]: + """The inverse operation of :attr:`BlockDiagonalCausalMask.from_tensor_list` + + Args: + tensor (torch.Tensor): Tensor of tokens of shape ``[1, sum_i{M_i}, *]`` + + Returns: + Sequence[torch.Tensor]: A list of tokens with possibly different sequence lengths + """ + assert self.q_seqinfo is self.k_seqinfo + return self.q_seqinfo.split(tensor, self._batch_sizes) + + def make_causal(self) -> "BlockDiagonalCausalMask": + """Makes each block causal""" + return BlockDiagonalCausalMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + ) + + def make_causal_from_bottomright(self) -> "BlockDiagonalCausalFromBottomRightMask": + """Makes each block causal with a possible non-causal prefix""" + return BlockDiagonalCausalFromBottomRightMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + ) + + def make_local_attention( + self, window_size: int + ) -> "BlockDiagonalCausalLocalAttentionMask": + """Experimental: Makes each block causal with local attention""" + return BlockDiagonalCausalLocalAttentionMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + _window_size=window_size, + ) + + def make_local_attention_from_bottomright( + self, window_size: int + ) -> "BlockDiagonalCausalLocalAttentionFromBottomRightMask": + """Experimental: Makes each block causal with local attention, start from bottom right""" + return BlockDiagonalCausalLocalAttentionFromBottomRightMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + _window_size=window_size, + ) + + +@dataclass +class BlockDiagonalCausalMask(BlockDiagonalMask): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal. 
+ + Queries and Keys are each divided into the same number of blocks. + A query Q in block i cannot attend to a key which is not in block i, + nor one which is farther from the initial key in block i than Q + is from the initial query in block i. + """ + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return LowerTriangularMask().materialize( + shape, + dtype=dtype, + device=device, + ) + + +@dataclass +class BlockDiagonalCausalFromBottomRightMask(BlockDiagonalMask): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal. + This mask allows for a non-causal prefix + NOTE: Each block should have `num_keys >= num_queries` otherwise the forward pass is not + defined (softmax of vector of `-inf` in the attention) + + Queries and keys are each divided into the same number of blocks. + A query Q in block i cannot attend to a key which is not in block i, + nor one which nearer the final key in block i than Q is to the + final query in block i. + """ + + def __post_init__(self) -> None: + for i, ((q_start, q_end), (k_start, k_end)) in enumerate( + zip( + self.q_seqinfo.intervals(), + self.k_seqinfo.intervals(), + ) + ): + num_queries = q_end - q_start + num_keys = k_end - k_start + if num_keys < num_queries: + raise ValueError( + f"Block #{i} has num_keys={num_keys} and num_queries={num_queries}." 
+ " Expected `num_keys >= num_queries`" + ) + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return LowerTriangularFromBottomRightMask().materialize( + shape=shape, dtype=dtype, device=device + ) + + +@dataclass +class BlockDiagonalPaddedKeysMask(AttentionBias): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, + except we support padding for k/v + + The keys and values are divided into blocks which are padded out to + the same total length. + For example, if there is space for 12 keys, for three blocks of + max length 4, but we only want to use the first 2, 3 and 2 + of each block, use `kv_padding=4` and `kv_seqlens=[2, 3, 2]`. + The queries are divided into blocks, without padding, of lengths given by + q_seqlen. + + A query Q in block i cannot attend to a key which is not in block i, + nor one which is not in use (i.e. in the padded area). + """ + + q_seqinfo: _SeqLenInfo + k_seqinfo: _PaddedSeqLenInfo + + def to(self, device) -> "BlockDiagonalPaddedKeysMask": + return BlockDiagonalPaddedKeysMask( + q_seqinfo=self.q_seqinfo.to(device), + k_seqinfo=self.k_seqinfo.to(device), + ) + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return torch.tensor(0.0, device=device, dtype=dtype) + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """Materialize the attention bias - for debugging & testing""" + if shape[-1] != self.k_seqinfo.seqstart_py[-1]: + raise ValueError("k shapes wrong") + if shape[-2] != self.q_seqinfo.seqstart_py[-1]: + raise ValueError("q shapes wrong") + mask = torch.empty(shape[-2:], dtype=dtype, device=device) + mask.fill_(-math.inf) + for i, ((q_start, q_end), (k_start, k_end)) in enumerate( 
+ zip( + self.q_seqinfo.intervals(), + self.k_seqinfo.intervals(), + ) + ): + mask[q_start:q_end, k_start:k_end] = self._create_block_mask( + (q_end - q_start, k_end - k_start), + dtype=dtype, + device=device, + ) + for _ in range(len(shape) - 2): + mask = mask.unsqueeze(0) + return mask.expand(shape) + + @classmethod + def from_seqlens( + cls, + q_seqlen: Sequence[int], + kv_padding: int, + kv_seqlen: Sequence[int], + causal_diagonal: Any = None, + *, + device: Optional[torch.device] = None, + ) -> "BlockDiagonalPaddedKeysMask": + """Creates a :attr:`BlockDiagonalPaddedKeysMask` from a list of tensor + lengths for query and key/value. + + Args: + q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors + kv_padding (int): Padding for k/v - also an upperbound on each individual key length + kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value. + causal_diagonal: unused, for BC only + Returns: + BlockDiagonalPaddedKeysMask + """ + device = _get_default_bias_device(device) + assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), ( + q_seqlen, + kv_seqlen, + ) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device) + k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded( + kv_seqlen, kv_padding, device=device + ) + return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo) + + def make_paged( + self, + block_tables: torch.Tensor, + page_size: int, + paged_type: Type["PagedBlockDiagonalPaddedKeysMask"], + ) -> AttentionBias: + paged_bias = paged_type( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + block_tables=block_tables, + page_size=page_size, + ) + paged_bias.k_seqinfo.padding = block_tables.shape[1] * page_size + return paged_bias + + +@dataclass +class BlockDiagonalCausalWithOffsetPaddedKeysMask(BlockDiagonalPaddedKeysMask): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`, + except an offset on causality is allowed for each block and we support padding for k/v + + The keys and 
values are divided into blocks which are padded out to + the same total length. + For example, if there is space for 12 keys, for three blocks of + max length 4, but we only want to use the first 2, 3 and 2 + of each block, use `kv_padding=4` and `kv_seqlens=[2, 3, 2]`. + The queries are divided into blocks, without padding, of lengths given by + q_seqlen. + + A query Q in block i cannot attend to a key which is not in block i, + nor one which is not in use (i.e. in the padded area), + nor one which is nearer to the final key in block i + than Q is to the final query in block i. + """ + + causal_diagonal: Any = None # unused. Exists for BC only. + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return LowerTriangularFromBottomRightMask().materialize( + shape=shape, dtype=dtype, device=device + ) + + @classmethod + def from_seqlens( + cls, + q_seqlen: Sequence[int], + kv_padding: int, + kv_seqlen: Sequence[int], + causal_diagonal: Any = None, + *, + device: Optional[torch.device] = None, + ) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask": + """Creates a :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask` from a list of tensor + lengths for query and key/value. + + Args: + q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors + kv_padding (int): Padding for k/v - also an upperbound on each individual key length + kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value. 
+ causal_diagonal: unused, for BC only + Returns: + BlockDiagonalCausalWithOffsetPaddedKeysMask + """ + assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), ( + q_seqlen, + kv_seqlen, + ) + device = _get_default_bias_device(device) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device) + k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded( + kv_seqlen, kv_padding, device=device + ) + return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo) + + +@dataclass +class BlockDiagonalCausalLocalAttentionPaddedKeysMask(BlockDiagonalPaddedKeysMask): + """ + Like :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask`, + except with a window size. + + A query Q in block i cannot attend to a key which is not in block i, + nor one which is not in use (i.e. in the padded area), + nor one which is nearer to the final key in block i + than Q is to the final query in block i, nor one that is more than + window_size further from the final key in block i than Q is + to the final query in block i. 
+ """ + + _window_size: int + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return _materialize_causal_mask( + shape=shape, + dtype=dtype, + device=device, + window_size=self._window_size, + from_bottomright=True, + ) + + @classmethod + def from_seqlens_local( + cls, + q_seqlen: Sequence[int], + kv_padding: int, + kv_seqlen: Sequence[int], + window_size: int, + ) -> "BlockDiagonalCausalLocalAttentionPaddedKeysMask": + assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), ( + q_seqlen, + kv_seqlen, + ) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen) + k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding) + return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo, _window_size=window_size) + + +@dataclass +class PagedBlockDiagonalPaddedKeysMask(AttentionBias): + """ + Same as BlockDiagonalPaddedKeysMask, but for paged attention. + block_tables has shape [batch_size, max_num_pages] and K/V have shape + [1, max_num_pages * page_size, num_heads, head_dim] + or [1, max_num_pages * page_size, num_groups, num_heads, head_dim] + """ + + q_seqinfo: _SeqLenInfo + k_seqinfo: _PaddedSeqLenInfo + block_tables: torch.Tensor + page_size: int + + _UNPAGED_TYPE: ClassVar[ + Type[BlockDiagonalPaddedKeysMask] + ] = BlockDiagonalPaddedKeysMask + + def to(self, device: torch.device) -> "PagedBlockDiagonalPaddedKeysMask": + return PagedBlockDiagonalPaddedKeysMask( + q_seqinfo=self.q_seqinfo.to(device), + k_seqinfo=self.k_seqinfo.to(device), + block_tables=self.block_tables.to(device), + page_size=self.page_size, + ) + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """Materialize the attention bias - for debugging & testing""" + # First create a non-paged mask, then cut individual pages and + # copy them to their places in the physical mask, 
using block tables + + max_row_len = self.block_tables.shape[1] * self.page_size + bias_nonpaged = self._UNPAGED_TYPE( + q_seqinfo=self.q_seqinfo, + k_seqinfo=_PaddedSeqLenInfo.from_seqlens_padded( + self.k_seqinfo.seqlen_py, max_row_len + ), + ) + mask_nonpaged = bias_nonpaged.materialize(shape, dtype, device) + + n_used_blocks = cast(int, self.block_tables.max().item() + 1) + max_physical_len = n_used_blocks * self.page_size + mask_paged = torch.empty( + mask_nonpaged.shape[:-1] + (max_physical_len,), dtype=dtype, device=device + ) + mask_paged.fill_(-math.inf) + for b, (q_start, q_end) in enumerate(self.q_seqinfo.intervals()): + for logical_page_idx in range(self.block_tables.shape[1]): + physical_page_idx = cast( + int, self.block_tables[b][logical_page_idx].item() + ) + k_logical_start = b * max_row_len + logical_page_idx * self.page_size + k_logical_end = k_logical_start + self.page_size + k_physical_start = physical_page_idx * self.page_size + k_physical_end = k_physical_start + self.page_size + mask_paged[ + ..., q_start:q_end, k_physical_start:k_physical_end + ] = mask_nonpaged[..., q_start:q_end, k_logical_start:k_logical_end] + return mask_paged + + @classmethod + def from_seqlens( + cls, + q_seqlen: Sequence[int], + kv_seqlen: Sequence[int], + block_tables: torch.Tensor, + page_size: int, + *, + device: Optional[torch.device] = None, + ) -> "PagedBlockDiagonalPaddedKeysMask": + """Creates a :attr:`PagedBlockDiagonalPaddedKeysMask` from a list of tensor + lengths for query and key/value. + + Args: + q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors + kv_padding (int): Padding for k/v - also an upperbound on each individual key length + kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value. 
+ causal_diagonal: unused, for BC only + Returns: + PagedBlockDiagonalPaddedKeysMask + """ + assert len(q_seqlen) == len(kv_seqlen), ( + q_seqlen, + kv_seqlen, + ) + device = _get_default_bias_device(device) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device) + k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded( + kv_seqlen, padding=block_tables.shape[1] * page_size, device=device + ) + return cls( + q_seqinfo=q_seqinfo, + k_seqinfo=k_seqinfo, + block_tables=block_tables, + page_size=page_size, + ) + + +@dataclass +class PagedBlockDiagonalCausalWithOffsetPaddedKeysMask( + PagedBlockDiagonalPaddedKeysMask +): + """ + Same as BlockDiagonalCausalWithOffsetPaddedKeysMask, but for paged attention. + block_tables has shape [batch_size, max_num_pages] and K/V have shape + [1, max_num_pages * page_size, num_heads, head_dim] + or [1, max_num_pages * page_size, num_groups, num_heads, head_dim] + """ + + _UNPAGED_TYPE = BlockDiagonalCausalWithOffsetPaddedKeysMask + + +@dataclass +class BlockDiagonalGappyKeysMask(AttentionBias): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, + except k/v is gappy. + + A query Q in block i only attends to a key which is in block i. 
+ """ + + q_seqinfo: _SeqLenInfo + k_seqinfo: _GappySeqInfo + + def to(self, device: torch.device) -> "BlockDiagonalGappyKeysMask": + return BlockDiagonalGappyKeysMask( + q_seqinfo=self.q_seqinfo.to(device), + k_seqinfo=self.k_seqinfo.to(device), + ) + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """Materialize the attention bias - for debugging & testing""" + if shape[-1] != self.k_seqinfo.seqstart_py[-1]: + raise ValueError("k shapes wrong", (shape, self.k_seqinfo)) + if shape[-2] != self.q_seqinfo.seqstart_py[-1]: + raise ValueError("q shapes wrong", (shape, self.q_seqinfo)) + mask = torch.empty(shape[-2:], dtype=dtype, device=device) + mask.fill_(-math.inf) + for i, ((q_start, q_end), (k_start, k_end)) in enumerate( + zip( + self.q_seqinfo.intervals(), + self.k_seqinfo.intervals(), + ) + ): + mask[q_start:q_end, k_start:k_end] = 0 + for _ in range(len(shape) - 2): + mask = mask.unsqueeze(0) + return mask.expand(shape) + + @classmethod + def from_seqlens( + cls, + q_seqlen: Sequence[int], + kv_seqstarts: Sequence[int], + kv_seqlen: Sequence[int], + *, + device: Optional[torch.device] = None, + ) -> "BlockDiagonalGappyKeysMask": + """Creates a :attr:`BlockDiagonalGappyKeysMask` from a list of tensor + lengths for query and key/value. 
+ """ + assert len(q_seqlen) == len(kv_seqlen), ( + q_seqlen, + kv_seqlen, + ) + device = _get_default_bias_device(device) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device) + k_seqinfo = _GappySeqInfo.from_seqlens_gappy( + kv_seqstarts, kv_seqlen, False, device=device + ) + return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo) + + def make_paged( + self, + block_tables: torch.Tensor, + page_size: int, + notional_padding: int, + paged_type: Type["PagedBlockDiagonalGappyKeysMask"], + ) -> AttentionBias: + """ + Assuming our keys actually live in separate blocks of length + notional_padding, convert to a Paged version. + """ + # Our child class does not yet have a paged version. + assert self.__class__ is BlockDiagonalGappyKeysMask + max_row_len = block_tables.shape[1] * page_size + new_seqstarts = [ + start - i * notional_padding + for i, start in enumerate(self.k_seqinfo.seqstart_py[:-1]) + ] + assert all(0 <= i < max_row_len for i in new_seqstarts) + k_seqinfo = _GappySeqInfo.from_seqlens_gappy( + new_seqstarts, self.k_seqinfo.seqlen_py, True, device=block_tables.device + ) + assert self.k_seqinfo.max_seqlen <= max_row_len + paged_bias = paged_type( + q_seqinfo=self.q_seqinfo, + k_seqinfo=k_seqinfo, + block_tables=block_tables, + page_size=page_size, + ) + return paged_bias + + +@dataclass +class BlockDiagonalCausalWithOffsetGappyKeysMask(BlockDiagonalGappyKeysMask): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`, + except k/v is gappy. + + A query Q in block i cannot attend to a key which is not in block i, + nor one which is nearer to the final key in block i + than Q is to the final query in block i. 
+ """ + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """Materialize the attention bias - for debugging & testing""" + if shape[-1] != self.k_seqinfo.seqstart_py[-1]: + raise ValueError("k shapes wrong") + if shape[-2] != self.q_seqinfo.seqstart_py[-1]: + raise ValueError("q shapes wrong") + mask = torch.empty(shape[-2:], dtype=dtype, device=device) + mask.fill_(-math.inf) + for i, ((q_start, q_end), (k_start, k_end)) in enumerate( + zip( + self.q_seqinfo.intervals(), + self.k_seqinfo.intervals(), + ) + ): + mask[ + q_start:q_end, k_start:k_end + ] = LowerTriangularFromBottomRightMask().materialize( + shape=(q_end - q_start, k_end - k_start), dtype=dtype, device=device + ) + + for _ in range(len(shape) - 2): + mask = mask.unsqueeze(0) + return mask.expand(shape) + + +@dataclass +class PagedBlockDiagonalGappyKeysMask(AttentionBias): + """ + Equivalent BlockDiagonalGappyKeysMask, but for paged attention. 
+ block_tables has shape [batch_size, max_num_pages] and K/V have shape + [1, max_num_pages * page_size, num_heads, head_dim] + or [1, max_num_pages * page_size, num_groups, num_heads, head_dim] + """ + + q_seqinfo: _SeqLenInfo + k_seqinfo: _GappySeqInfo + block_tables: torch.Tensor + page_size: int + + _UNPAGED_TYPE: ClassVar[ + Type[BlockDiagonalGappyKeysMask] + ] = BlockDiagonalGappyKeysMask + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """Materialize the attention bias - for debugging & testing""" + # First create a non-paged mask, then cut individual pages and + # copy them to their places in the physical mask, using block tables + + max_row_len = self.block_tables.shape[1] * self.page_size + new_seqstarts = [ + start + i * max_row_len + for i, start in enumerate(self.k_seqinfo.seqstart_py) + ] + [shape[-1]] + bias_nonpaged = self._UNPAGED_TYPE( + q_seqinfo=self.q_seqinfo, + k_seqinfo=_GappySeqInfo.from_seqlens_gappy( + new_seqstarts, + self.k_seqinfo.seqlen_py, + False, + device=torch.device(device), + ), + ) + mask_nonpaged = bias_nonpaged.materialize(shape, dtype, device) + + n_used_blocks = cast(int, self.block_tables.max().item() + 1) + max_physical_len = n_used_blocks * self.page_size + mask_paged = torch.empty( + mask_nonpaged.shape[:-1] + (max_physical_len,), dtype=dtype, device=device + ) + mask_paged.fill_(-math.inf) + for b, (q_start, q_end) in enumerate(self.q_seqinfo.intervals()): + for logical_page_idx in range(self.block_tables.shape[1]): + physical_page_idx = cast( + int, self.block_tables[b][logical_page_idx].item() + ) + k_logical_start = b * max_row_len + logical_page_idx * self.page_size + k_logical_end = k_logical_start + self.page_size + k_physical_start = physical_page_idx * self.page_size + k_physical_end = k_physical_start + self.page_size + mask_paged[ + ..., q_start:q_end, k_physical_start:k_physical_end + ] = 
mask_nonpaged[..., q_start:q_end, k_logical_start:k_logical_end] + return mask_paged + + @classmethod + def from_seqlens( + cls, + q_seqlen: Sequence[int], + kv_seqstarts: Sequence[int], + kv_seqlen: Sequence[int], + block_tables: torch.Tensor, + page_size: int, + *, + device: Optional[torch.device] = None, + ) -> "PagedBlockDiagonalGappyKeysMask": + """Creates a :attr:`PagedBlockDiagonalGappyKeysMask` from a list of tensor + lengths for query and key/value. + + Note that unlike :attr:`BlockDiagonalGappyKeysMask`, kv_seqstarts is + addressing in a different space for each batch element. For example + if you were doing a BlockDiagonalPaddedKeysMask with two batch + elements and padding=100, but wanted to change it so that the first + key is ignored, then you would use BlockDiagonalGappyKeysMask with kv_seqstarts + [1, 101, 200]. But if you were using PagedBlockDiagonalPaddedKeysMask + but wanted to ignore the first key, you would provide this function with + kv_seqstarts = [1, 1]. + """ + assert len(q_seqlen) == len(kv_seqlen) == len(kv_seqstarts), ( + q_seqlen, + kv_seqlen, + kv_seqstarts, + ) + device = block_tables.device if device is None else device + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device) + k_seqinfo = _GappySeqInfo.from_seqlens_gappy( + kv_seqstarts, kv_seqlen, True, device=device + ) + return cls( + q_seqinfo=q_seqinfo, + k_seqinfo=k_seqinfo, + block_tables=block_tables, + page_size=page_size, + ) + + +@dataclass +class BlockDiagonalCausalLocalAttentionMask(BlockDiagonalCausalMask): + """ + (Experimental feature) + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`. + This makes the mask "local" and the attention pattern banded. + + Query i only attends to keys in its block and cannot attend keys further than "window_size" + from it. 
+ """ + + _window_size: int = 0 # forced due to inheritance and default arguments + + def __post_init__(self): + if self._window_size <= 0: + raise ValueError( + f"Expected `window_size > 0`, but window_size={self._window_size}" + ) + q_seqlen = [ + y - x + for x, y in zip( + self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:] + ) + ] + kv_seqlen = [ + y - x + for x, y in zip( + self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:] + ) + ] + for q, k in zip(q_seqlen, kv_seqlen): + if q - self._window_size >= k: + # Each query only attends to keys no further than window_size back. + # When q > k + window_size, there will be a query for which the window doesn't reach any key. + raise RuntimeError( + f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}" + ) + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return _materialize_causal_mask( + shape, + dtype=dtype, + device=device, + window_size=self._window_size, + ) + + +@dataclass +class BlockDiagonalCausalLocalAttentionFromBottomRightMask( + BlockDiagonalCausalFromBottomRightMask +): + """ + (Experimental feature) + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`. + This makes the mask "local" and the attention pattern banded. + + Query i only attends to keys in its block and cannot attend keys further than "window_size" + from it. 
+ """ + + _window_size: int = 0 # forced due to inheritance and default arguments + + def __post_init__(self): + super().__post_init__() + if self._window_size <= 0: + raise ValueError( + f"Expected `window_size > 0`, but window_size={self._window_size}" + ) + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return _materialize_causal_mask( + shape, + dtype=dtype, + device=device, + window_size=self._window_size, + from_bottomright=True, + ) + + +class AttentionBiasSubTensor(torch.Tensor, AttentionBias): + HOLDS_DENSE_TENSOR = False + + _subtensor: torch.Tensor + + @staticmethod + def __new__(cls, *, _subtensor=None): + if _subtensor is None: + _subtensor = torch.empty((0,), device=_get_default_bias_device()) + tensor = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + [], + device=_subtensor.device, + dtype=_subtensor.dtype, + requires_grad=False, + ) + tensor._subtensor = _subtensor + return tensor + + def __init__(self, *args, **kwargs) -> None: + super().__init__() + + def __repr__(self): + return f"{self.__class__.__name__}" + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + if func._overloadpacket in [ + torch.ops.aten.clone, + torch.ops.aten.detach, + torch.ops.aten._to_copy, + torch.ops.aten.to, + ]: + return cls(_subtensor=func(args[0]._subtensor, *args[1:], **kwargs)) + return NotImplemented + + def __tensor_flatten__(self): + return ["_subtensor"], None + + @classmethod + def __tensor_unflatten__(cls, inner_tensors, meta, outer_size, outer_stride): + assert meta is None + return cls(_subtensor=inner_tensors["_subtensor"]) + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """ + Materializes the bias as a `torch.Tensor`. 
This is very slow + and we don't attempt to make it fast. Only use for debugging/testing. + + Shape should be like `[*, q_seqlen, k_seqlen]` + """ + raise NotImplementedError() + + +class _AddDenseBias(torch.autograd.Function): + @staticmethod + def forward(ctx, causal_bias, tensor): + assert type(causal_bias) is LowerTriangularMask + return LowerTriangularMaskWithTensorBias(tensor) + + @staticmethod + def backward(ctx, grad_out): + return None, grad_out + + +class LowerTriangularMask(AttentionBiasSubTensor): + """ + A lower-triangular (aka causal) mask + + A query Q cannot attend to a key which is farther from the + initial key than Q is from the initial query. + + See also :attr:`LowerTriangularFromBottomRightMask` if the number + of queries is not equal to the number of keys/values. + """ + + HOLDS_DENSE_TENSOR = False + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return _materialize_causal_mask(shape, dtype=dtype, device=device) + + def add_bias(self, bias: torch.Tensor) -> "LowerTriangularMaskWithTensorBias": + """ + Creates a new causal mask with an arbitrary ``torch.Tensor`` bias + """ + return _AddDenseBias.apply(self, bias) + + +class LowerTriangularMaskWithTensorBias(LowerTriangularMask): + """A lower-triangular (aka causal) mask with an additive bias""" + + HOLDS_DENSE_TENSOR = True + + @staticmethod + def __new__(cls, bias): + tensor = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + bias.shape, + device=bias.device, + dtype=bias.dtype, + requires_grad=bias.requires_grad, + ) + tensor._subtensor = bias + return tensor + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return super().materialize(shape, dtype=dtype, device=device) + self._subtensor + + @classmethod + def __torch_dispatch__(cls, func, types, 
args=(), kwargs=None): + kwargs = kwargs or {} + if func._overloadpacket in [ + torch.ops.aten.unsqueeze, + torch.ops.aten.select, + torch.ops.aten.slice, + torch.ops.aten.clone, + torch.ops.aten.detach, + torch.ops.aten._to_copy, + torch.ops.aten.to, + ]: + output = func( + *[a._subtensor if isinstance(a, cls) else a for a in args], + **kwargs, + ) + return cls(output) + return NotImplemented + + +torch._dynamo.allow_in_graph(LowerTriangularMask) +torch._dynamo.allow_in_graph(LowerTriangularMaskWithTensorBias) + +VARLEN_BIASES = ( + BlockDiagonalMask, + BlockDiagonalGappyKeysMask, + BlockDiagonalPaddedKeysMask, + PagedBlockDiagonalPaddedKeysMask, + PagedBlockDiagonalGappyKeysMask, +) diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/ck.py b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/ck.py new file mode 100644 index 0000000000000000000000000000000000000000..a4defb17c311cd2e07fc39d939ff17366698d811 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/ck.py @@ -0,0 +1,468 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import replace +from enum import Enum +from functools import partial +from typing import Any, Iterable, List, Mapping, Optional, Set, Tuple, Union + +import torch + +from ..common import get_operator, register_operator +from . 
import attn_bias +from .attn_bias import ( + AttentionBias, + AttentionBiasSubTensor, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalWithOffsetGappyKeysMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalGappyKeysMask, + BlockDiagonalMask, + BlockDiagonalPaddedKeysMask, + LowerTriangularFromBottomRightLocalAttentionMask, + LowerTriangularFromBottomRightMask, + LowerTriangularMask, + LowerTriangularMaskWithTensorBias, +) +from .common import ( + AttentionBwOpBase, + AttentionFwOpBase, + Context, + Gradients, + Inputs, + _attn_bias_apply, + check_lastdim_alignment_stride1, +) + + +def _minimum_gemm_alignment(inp: Inputs) -> int: + return 1 + + +def _get_seqlen_info( + inp: Inputs, +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], int, int]: + attn_bias = inp.attn_bias + if isinstance( + attn_bias, + (BlockDiagonalMask, BlockDiagonalPaddedKeysMask, BlockDiagonalGappyKeysMask), + ): + attn_bias.k_seqinfo.to(inp.query.device) + attn_bias.q_seqinfo.to(inp.query.device) + seqstart_k = attn_bias.k_seqinfo.seqstart + seqstart_q = attn_bias.q_seqinfo.seqstart + max_seqlen_q = attn_bias.q_seqinfo.max_seqlen + max_seqlen_k = attn_bias.k_seqinfo.max_seqlen + else: + seqstart_k = None + seqstart_q = None + max_seqlen_q = -1 + max_seqlen_k = -1 + + return seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k + + +def _get_tensor_bias( + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] +) -> Optional[torch.Tensor]: + if isinstance(attn_bias, AttentionBiasSubTensor): + if isinstance(attn_bias, LowerTriangularMaskWithTensorBias): + return attn_bias._subtensor + elif isinstance(attn_bias, torch.Tensor): + return attn_bias + return None + + +def _check_bias_alignment( + reasons: List[str], attn_bias: Optional[Union[torch.Tensor, AttentionBias]] +) -> None: + attn_bias_tensor = _get_tensor_bias(attn_bias) + if attn_bias_tensor is not None: + alignment = 128 // 
torch.finfo(attn_bias_tensor.dtype).bits + show_padding_hint = False + for d in range(attn_bias_tensor.ndim - 1): + if attn_bias_tensor.stride(d) % alignment != 0: + reasons.append( + f"attn_bias.stride(-2) % {alignment} != 0 (attn_bias.stride() = {attn_bias_tensor.stride()})" + ) + show_padding_hint = True + if show_padding_hint: + reasons.append( + """\ +HINT: To use an `attn_bias` with a sequence length that is not a multiple of 8, \ +you need to ensure memory is aligned by slicing a bigger tensor. \ +Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5]` instead of `torch.zeros([1, 1, 5, 5])`""" + ) + # We can have stride=0 sometimes if dimension=1 + if attn_bias_tensor.stride(-1) > 1: + reasons.append( + f"attn_bias.stride(-1) > 1 (attn_bias.stride() = {attn_bias_tensor.stride()}) - " + "you should call `.contiguous()` on the bias" + ) + + +class _CustomMaskType(int, Enum): + """ + (Matches CustomMaskType in C++.) + """ + + NoCustomMask = 0 + CausalFromTopLeft = 1 + CausalFromBottomRight = 2 + + +def _custom_mask_type(bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int: + if isinstance( + bias, + ( + LowerTriangularMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalLocalAttentionMask, + ), + ): + return int(_CustomMaskType.CausalFromTopLeft) + if isinstance( + bias, + ( + LowerTriangularFromBottomRightMask, + LowerTriangularFromBottomRightLocalAttentionMask, + attn_bias.BlockDiagonalCausalFromBottomRightMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + ), + ): + return int(_CustomMaskType.CausalFromBottomRight) + return int(_CustomMaskType.NoCustomMask) + + +@register_operator +class FwOp(AttentionFwOpBase): + """xFormers' MHA kernel based on Composable Kernel.""" + + OPERATOR = get_operator("xformers", "efficient_attention_forward_ck") + SUPPORTED_DEVICES: Set[str] = {"cuda"} + SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16} + SUPPORTED_MAX_K = 256 + + 
SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = ( + type(None), + torch.Tensor, + LowerTriangularMask, + LowerTriangularFromBottomRightMask, + LowerTriangularFromBottomRightLocalAttentionMask, + LowerTriangularMaskWithTensorBias, + BlockDiagonalMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalWithOffsetGappyKeysMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalGappyKeysMask, + BlockDiagonalPaddedKeysMask, + attn_bias.BlockDiagonalCausalFromBottomRightMask, + attn_bias.BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + ) + + SUPPORTS_DROPOUT = True + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_DIFFERENT_VALUE_EMBED = True + SUPPORTS_PARTIAL = True + SUPPORTS_BMGHK = True + NAME = "ckF" + + ERROR_ATOL: Mapping[torch.dtype, float] = { + torch.float: 3e-4, + torch.half: 6e-3, + torch.bfloat16: 2.8e-2, + } + ERROR_RTOL: Mapping[torch.dtype, float] = { + torch.float: 2e-5, + torch.half: 3e-3, + torch.bfloat16: 2e-2, + } + + _TEST_K: List[int] = [ + 32, # 64x64 kernel + 128, # 64x128 kernel + 256, # 64x128 with accumulation in gmem + ] + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES: + raise NotImplementedError("Unsupported attn_bias type") + if inp.query.ndim in [3, 4]: + return cls.apply_bmhk(inp, needs_gradient=needs_gradient) + assert inp.query.ndim == 5, f"query has shape {inp.query.shape}" + ctx: Optional[Context] = None + # XXX: Hackfix for BMGHK with H=1 + # In that case we don't want to run G different streams because it adds + # some overhead + if inp.query.ndim == 5 and inp.query.shape[3] == 1: + slice_op = partial(torch.squeeze, dim=3) + inp = replace( + inp, + query=slice_op(inp.query), + key=slice_op(inp.key), + value=slice_op(inp.value), + attn_bias=_attn_bias_apply( + inp.attn_bias, partial(torch.squeeze, dim=2) + ), + ) + out, ctx = cls.apply_bmhk(inp, 
needs_gradient=needs_gradient) + out = out.unsqueeze(3) + if ctx is not None: + ctx = replace(ctx, lse=ctx.lse.unsqueeze(1), out=out) + return out, ctx + + # Workaround until this is properly implemented in C++ + # run each head group in a different stream + n_groups = inp.key.shape[2] + main_stream = torch.cuda.current_stream() + streams = [main_stream] + [ + torch.cuda.Stream(device=inp.query.device) for _ in range(n_groups - 1) + ] + outs = [] + for group, stream in enumerate(streams): + stream.wait_stream(main_stream) + with torch.cuda.stream(stream): + query = inp.query[:, :, group] + key = inp.key[:, :, group] + value = inp.value[:, :, group] + bias = _attn_bias_apply( + inp.attn_bias, partial(torch.select, dim=1, index=group) + ) + outs.append( + cls.apply_bmhk( + replace(inp, query=query, key=key, value=value, attn_bias=bias), + needs_gradient=needs_gradient, + ) + ) + for s in streams[1:]: + main_stream.wait_stream(s) + out = torch.stack([o[0] for o in outs], dim=2) + if needs_gradient: + ctx = Context( + out=out, + lse=torch.stack([o[1].lse for o in outs], dim=1), # type: ignore + op_bw=outs[0][1].op_bw, # type: ignore + ) + return out, ctx + + @classmethod + def apply_bmhk( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES: + raise NotImplementedError("Unsupported attn_bias type") + seqstart_k, seqstart_q, max_seqlen_q, _ = _get_seqlen_info(inp) + out, lse, rng_seed, rng_offset = cls.OPERATOR( + query=inp.query, + key=inp.key, + value=inp.value, + attn_bias=_get_tensor_bias(inp.attn_bias), + seqstart_q=seqstart_q, + seqstart_k=seqstart_k, + max_seqlen_q=max_seqlen_q, + dropout_p=inp.p, + compute_logsumexp=needs_gradient, + custom_mask_type=_custom_mask_type(inp.attn_bias), + scale=inp.scale, + seqlen_k=( + inp.attn_bias.k_seqinfo.seqlen + if isinstance( + inp.attn_bias, + ( + BlockDiagonalGappyKeysMask, + BlockDiagonalPaddedKeysMask, + ), + ) + else None 
+ ), + window_size=( + inp.attn_bias._window_size + if isinstance( + inp.attn_bias, + ( + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + LowerTriangularFromBottomRightLocalAttentionMask, + ), + ) + else None + ), + ) + + ctx: Optional[Context] = None + if needs_gradient: + ctx = Context( + out=out, + lse=lse, + # cutlass forward is only compatible with cutlass backward if + # dropout is used (because of the way RNG states are passed and the + # way random numbers are generated during backward) + op_bw=BwOp if inp.p != 0 else None, + ) + if inp.p != 0: + ctx.rng_state = torch.tensor( + [rng_seed, rng_offset], dtype=torch.int64, device="cpu" + ) + return out, ctx + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(FwOp, cls).not_supported_reasons(d) + matmul_alignment_mn = _minimum_gemm_alignment(d) + check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn) + check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn) + _check_bias_alignment(reasons, d.attn_bias) + return reasons + + +@register_operator +class BwOp(AttentionBwOpBase): + __doc__ = FwOp.__doc__ + + OPERATOR = get_operator("xformers", "efficient_attention_backward_ck") + SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES + SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES + SUPPORTED_MAX_K = 256 + SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = ( + type(None), + torch.Tensor, + LowerTriangularMask, + LowerTriangularFromBottomRightMask, + LowerTriangularFromBottomRightLocalAttentionMask, + # TODO: Fix handling of gradient through the fMHA autograd function + # LowerTriangularMaskWithTensorBias, + BlockDiagonalMask, + BlockDiagonalCausalMask, + attn_bias.BlockDiagonalCausalFromBottomRightMask, + attn_bias.BlockDiagonalCausalLocalAttentionMask, + ) + SUPPORTS_ATTN_BIAS_GRAD = True + SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT + SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE + 
SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED + SUPPORTS_UNPADDED_LSE = True + NAME = "ckB" + + _TEST_K: List[int] = [ + 32, # 64x64 kernel + 64, + 128, # 64x128/128x128 kernel + 256, + ] + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(BwOp, cls).not_supported_reasons(d) + matmul_alignment_mn = _minimum_gemm_alignment(d) + + check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn) + check_lastdim_alignment_stride1(reasons, "key", d.key, matmul_alignment_mn) + check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn) + _check_bias_alignment(reasons, d.attn_bias) + attn_bias_tensor = _get_tensor_bias(d.attn_bias) + + # Backprop of gradient through broadcasted bias is not supported + if attn_bias_tensor is not None and attn_bias_tensor.requires_grad: + # Don't forget that inputs are either in BMK or BMHK! + if d.query.ndim == 3 and attn_bias_tensor.ndim == 3: + expected_bias_shape = (*d.query.shape[:2], d.key.shape[1]) + else: + # bias is B H Mq Mk + expected_bias_shape = ( + d.query.shape[0], + d.query.shape[2] if d.query.ndim == 4 else 1, + d.query.shape[1], + d.key.shape[1], + ) + if tuple(attn_bias_tensor.shape) != expected_bias_shape: + reasons.append( + "Broadcasting the `attn_bias` tensor is not supported " + f"(shape: {tuple(attn_bias_tensor.shape)}" + f"/ expected: {expected_bias_shape})" + ) + + return reasons + + @classmethod + def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: + if type(inp.attn_bias) not in BwOp.SUPPORTED_ATTN_BIAS_TYPES: + raise NotImplementedError("Unsupported attn_bias type") + + seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k = _get_seqlen_info(inp) + dtype = inp.query.dtype + + rng_seed = rng_offset = 0 + if inp.p != 0.0: + if ( + ctx.rng_state is None + or ctx.rng_state.dtype != torch.int64 + or ctx.rng_state.device.type != "cpu" + or ctx.rng_state.shape != (2,) + ): + raise 
NotImplementedError(f"Invalid rng_state: {ctx.rng_state}") + rng_seed, rng_offset = ctx.rng_state.tolist() + + (grad_q, grad_k, grad_v, grad_bias) = cls.OPERATOR( + grad.to(dtype), + inp.query, + inp.key, + inp.value, + attn_bias=_get_tensor_bias(inp.attn_bias), + seqstart_q=seqstart_q, + seqstart_k=seqstart_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqlen_k=( + inp.attn_bias.k_seqinfo.seqlen + if isinstance( + inp.attn_bias, + ( + BlockDiagonalGappyKeysMask, + BlockDiagonalPaddedKeysMask, + ), + ) + else None + ), + logsumexp=ctx.lse, + output=ctx.out.to(dtype), + dropout_p=inp.p, + # if not using dropout, seed and offset are irrelevant but still expected + # in function signature so just pass 0 + # seed and offset could be None if a different FW op other than cutlass + # was used. + rng_seed=rng_seed, + rng_offset=rng_offset, + custom_mask_type=_custom_mask_type(inp.attn_bias), + scale=inp.scale, + window_size=( + inp.attn_bias._window_size + if isinstance( + inp.attn_bias, + ( + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + LowerTriangularFromBottomRightLocalAttentionMask, + ), + ) + else None + ), + ) + + # c++/CUDA implementation returns an uninitialized tensor if bias doesn't + # require grad + if not ( + isinstance(inp.attn_bias, torch.Tensor) and inp.attn_bias.requires_grad + ): + grad_bias = None + + return Gradients(dq=grad_q, dk=grad_k, dv=grad_v, db=grad_bias) diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/ck_decoder.py b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/ck_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a5c820bfc7d7c33df239c51b47f0cc89e1b348db --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/ck_decoder.py @@ -0,0 +1,139 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Iterable, List, Optional, Set, Tuple + +import torch + +from ..common import get_operator, register_operator +from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask +from .common import AttentionFwOpBase, Context, Inputs + + +@register_operator +class FwOp(AttentionFwOpBase): + """ + An operator optimized for K=256 (so the contiguous dim fits into registers). + Tested to work on MI250x. + """ + + OPERATOR = get_operator("xformers", "efficient_attention_forward_decoder_ck") + SUPPORTED_DEVICES: Set[str] = {"cuda"} + SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16, torch.float} + SUPPORTED_MAX_K: int = 256 + SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = ( + type(None), + BlockDiagonalCausalWithOffsetPaddedKeysMask, + ) + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_BMGHK = True + NAME = "ck_decoderF" + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(FwOp, cls).not_supported_reasons(d) + + attn_bias = d.attn_bias + if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask): + if d.query.shape[0] != 1: + reasons.append( + f"One formal batch element expected; got {d.query.shape[0]}" + ) + + if d.query.shape[-1] > cls.SUPPORTED_MAX_K: + reasons.append( + f"Got head_dim={d.query.shape[-1]}; only head_dim<={cls.SUPPORTED_MAX_K} is supported for now." 
+ ) + + threads_per_warp = 64 # TODO: ideally query the platform here + required_alignment = 0 + head_dim = d.query.shape[-1] + for vec_size in (4, 2, 1): + if head_dim <= vec_size * threads_per_warp: + required_alignment = vec_size + + if not required_alignment: + reasons.append(f"Got head_dim={head_dim} which is too large") + + if head_dim % required_alignment != 0: + reasons.append( + f"Got head_dim={head_dim}; it needs to be divisible by {required_alignment}" + ) + + if d.key.stride(-1) != 1: + reasons.append("expect keys to have last dim contiguous") + + if d.value.stride(-1) != 1: + reasons.append("expect values to have last dim contiguous") + + q_starts = attn_bias.q_seqinfo.seqstart_py + padding = attn_bias.k_seqinfo.padding + bsz = d.key.shape[1] // padding + num_queries = d.query.shape[1] // bsz + + if q_starts != list(range(0, 1 + bsz, num_queries)): + reasons.append("expect to have same num_queries in each batch") + if bsz != len(q_starts) - 1: + reasons.append("empty lanes not supported yet") + + if attn_bias.k_seqinfo.padding > 8192: + reasons.append("key padding exceeds 8192") + + return reasons + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + if needs_gradient: + raise NotImplementedError("backward pass is not supported") + attn_bias = inp.attn_bias + q, k, v = inp.get_qkv_in_bmghk() + if attn_bias is not None: + assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask) + attn_bias.k_seqinfo.to(k.device) + attn_bias.q_seqinfo.to(q.device) + padding = attn_bias.k_seqinfo.padding + seq_positions_gpu = attn_bias.k_seqinfo.seqlen + else: + padding = k.shape[1] + seq_positions_gpu = None + + if attn_bias is not None: + # key: (1, B * padding, G, 1 if multiquery else Hkv, D) + # value: like key + # query: (1, B * q_seqlen, G, Hq, D) + multiquery = k.stride(3) == 0 + if multiquery: + key = k[0, :, :, :1].unflatten(0, (-1, padding)) + value = v[0, :, :, :1].unflatten(0, 
(-1, padding)) + else: + key = k[0].unflatten(0, (-1, padding)) + value = v[0].unflatten(0, (-1, padding)) + query = q[0].unflatten(0, (key.shape[0], -1)) + else: + # key: (B, padding, G, 1 if multiquery else Hkv, D) + # value: like key + # query: (B, q_seqlen, G, Hq, D) + key = k + query = q + value = v + + if inp.scale is not None: + qk_scale = inp.scale + else: + qk_scale = torch.rsqrt( + torch.tensor(key.shape[-1], dtype=torch.float32) + ).item() + + out = cls.OPERATOR( + query=query, + key=key, + value=value, + seq_positions=seq_positions_gpu, + scale=qk_scale, + ) + return out, None diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/ck_splitk.py b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/ck_splitk.py new file mode 100644 index 0000000000000000000000000000000000000000..4c7af07945ec88643c407a52910fec6868c183ee --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/ck_splitk.py @@ -0,0 +1,208 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Iterable, List, Optional, Tuple + +import torch + +from xformers.ops.common import get_operator, register_operator +from xformers.ops.fmha.attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask +from xformers.ops.fmha.common import ( + AttentionFwOpBase, + Context, + Inputs, + check_lastdim_alignment_stride1, +) + + +@register_operator +class FwOp(AttentionFwOpBase): + + OPERATOR = get_operator("xformers", "efficient_attention_forward_decoder_splitk_ck") + SUPPORTED_DEVICES = {"cuda"} + SUPPORTED_DTYPES = { + torch.half, + torch.bfloat16, + torch.float, + } # Those are dtypes of Q. 
In the quantized case K/V has dtype int32 + SUPPORTED_MAX_K = 256 + SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = ( + type(None), + BlockDiagonalCausalWithOffsetPaddedKeysMask, + ) + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_BMGHK = True + NAME = "ck_splitKF" + + SPLIT_K: Optional[int] = None + BLOCK_M = 16 + BLOCK_N = 64 + + NUM_GROUPS = 1 # Default quantization is row-wise + + @classmethod + def shape_not_supported_reasons( + cls, Mq: int, Mkv: int, K: int, Kv: int + ) -> List[str]: + reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv) + # if K not in {16, 32, 64, 128}: + # reasons.append(f"Embed dim {K} not supported") + return reasons + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(FwOp, cls).not_supported_reasons(d) + check_lastdim_alignment_stride1(reasons, "query", d.query, 8) + if d.key.dtype != torch.int32: + check_lastdim_alignment_stride1(reasons, "key", d.key, 8) + check_lastdim_alignment_stride1(reasons, "value", d.value, 8) + if cls.OPERATOR is None: + reasons.append("triton is not available") + if d.device.type == "cuda": + # Has only been tested on 8.0 / 9.0. + if torch.cuda.get_device_capability(d.device) < (7, 0): + reasons.append( + "requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4" + ) + + q_len = d.query.shape[1] + if isinstance(d.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask): + seqinfo = d.attn_bias.q_seqinfo + if q_len != seqinfo.seqstart_py[-1]: + reasons.append( + f"Expected total {seqinfo.seqstart_py[-1]} queries not {q_len}" + ) + q_len = seqinfo.min_seqlen + if q_len != seqinfo.max_seqlen: + reasons.append( + "Variable query len is not supported in the presence of causal mask." 
+ ) + + if d.key.ndim in [4, 5] and d.key.shape[-2] != 1: + if d.key.stride(-2) == 0 and d.value.stride(-2) == 0 and q_len > 1: + reasons.append("multiquery is only supported with query seqlen=1") + + if d.attn_bias is not None and q_len > 1: + reasons.append( + "query with seqlen > 1 is not supported in the presence of causal mask" + ) + return reasons + + @classmethod + def get_split_k(cls, B: int, H: int, Mk: int) -> int: + """Heuristic for the number of splits""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128 + while split_k > 0 and Mk / split_k < max_chunk_size: + split_k = split_k // 2 + split_k = min(split_k, 64) + split_k = max(split_k, 1) + return split_k + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + attn_bias = inp.attn_bias + q, k, v = inp.get_qkv_in_bmghk() + + if attn_bias is not None: + assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask) + attn_bias.k_seqinfo.to(k.device) + attn_bias.q_seqinfo.to(q.device) + padding = attn_bias.k_seqinfo.padding + seq_positions_gpu = attn_bias.k_seqinfo.seqlen + else: + padding = k.shape[1] + seq_positions_gpu = None + + if attn_bias is not None: + # key: (1, B * padding, G, 1 if multiquery else Hkv, D) + # value: like key + # query: (1, B * q_seqlen, G, Hq, D) + multiquery = k.stride(3) == 0 + if multiquery: + key = k[0, :, :, :1].unflatten(0, (-1, padding)) + value = v[0, :, :, :1].unflatten(0, (-1, padding)) + else: + key = k[0].unflatten(0, (-1, padding)) + value = v[0].unflatten(0, (-1, padding)) + query = q[0].unflatten(0, (key.shape[0], -1)) + else: + # key: (B, padding, G, 1 if multiquery else Hkv, D) + # value: like key + # query: (B, q_seqlen, G, Hq, D) + key = k + query = q + value = v + + B, _, _, H, _ = query.shape + _, Mk, _, _, _ = key.shape + + if cls.SPLIT_K is not None: + split_k = cls.SPLIT_K + else: + # Use 
heuristics + split_k = cls.get_split_k(B, H, Mk) + + if inp.scale is not None: + qk_scale = inp.scale + else: + qk_scale = torch.rsqrt( + torch.tensor(k.shape[-1], dtype=torch.float32) + ).item() + + out = cls.OPERATOR( + query=query, + key=key, + value=value, + seq_positions=seq_positions_gpu, + scale=qk_scale, + split_k=split_k, + ) + + return out, None + + +class FwOp_S1(FwOp): + SPLIT_K = 1 + NAME = "ck_splitK1" + + +class FwOp_S2(FwOp): + SPLIT_K = 2 + NAME = "ck_splitK2" + + +class FwOp_S4(FwOp): + SPLIT_K = 4 + NAME = "ck_splitK4" + + +class FwOp_S8(FwOp): + SPLIT_K = 8 + NAME = "ck_splitK8" + + +class FwOp_S16(FwOp): + SPLIT_K = 16 + NAME = "ck_splitK16" + + +class FwOp_S32(FwOp): + SPLIT_K = 32 + NAME = "ck_splitK32" + + +class FwOp_S64(FwOp): + SPLIT_K = 64 + NAME = "ck_splitK64" + + +class FwOp_S128(FwOp): + SPLIT_K = 128 + NAME = "ck_splitK128" diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/common.py b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/common.py new file mode 100644 index 0000000000000000000000000000000000000000..c862a50a99eba80963968f2ad2081020168fd71c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/common.py @@ -0,0 +1,495 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math +from dataclasses import dataclass +from functools import partial +from typing import ( + Any, + Callable, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + Union, +) + +import torch + +from ..._cpp_lib import _built_with_cuda +from ..common import BaseOperator +from .attn_bias import ( + AttentionBias, + AttentionBiasSubTensor, + BlockDiagonalGappyKeysMask, + BlockDiagonalMask, + BlockDiagonalPaddedKeysMask, + LowerTriangularMask, + PagedBlockDiagonalGappyKeysMask, + PagedBlockDiagonalPaddedKeysMask, +) + + +def _is_bias_type_supported_in_BMK(attn_bias_type: Any) -> bool: + # NoneType + if isinstance(None, attn_bias_type): + return True + if attn_bias_type in [LowerTriangularMask, torch.Tensor]: + return True + return False + + +def _attn_bias_apply( + attn_bias: Optional[Union[torch.Tensor, AttentionBias]], + op: Callable[[torch.Tensor], torch.Tensor], +) -> Optional[Union[torch.Tensor, AttentionBias]]: + if isinstance(attn_bias, torch.Tensor) and attn_bias.ndim != 0: + return op(attn_bias) + return attn_bias + + +@dataclass +class Inputs: + """ + Stores inputs to the `memory_efficient_attention` operators + """ + + query: torch.Tensor + key: torch.Tensor + value: torch.Tensor + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None + p: float = 0.0 + scale: Optional[float] = None + output_dtype: Optional[torch.dtype] = None + is_partial: bool = False + + @property + def device(self) -> torch.device: + return self.query.device + + @property + def scale_float(self) -> float: + return self.query.shape[-1] ** (-0.5) if self.scale is None else self.scale + + def get_qkv_in_bmghk(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if self.query.ndim == 5: + return self.query, self.key, self.value + if self.query.ndim == 4: + return ( + self.query.unsqueeze(2), + self.key.unsqueeze(2), + self.value.unsqueeze(2), + ) + if self.value.ndim == 3: + return ( + self.query[:, :, None, None], + self.key[:, :, None, None], + 
self.value[:, :, None, None], + ) + assert False + + def normalize_bmhk(self) -> Tuple[int, ...]: + if self.query.ndim not in [3, 4, 5]: + raise ValueError( + f"Invalid shape for query: {self.query.shape}. " + "Expected shape [batch, seqlen, head_groups, num_heads_per_group, K]" + ", [batch, seqlen, num_heads, K], or [batch, seqlen, K]." + ) + if self.value.dtype == torch.int32: + # Quantized K/V case, in which the last dims of Q and K are different. + # NB we currently don't have any implementations for quantized KV with + # SUPPORTS_DIFFERENT_VALUE_EMBED. + output_shape = tuple(self.query.shape) + else: + output_shape = (self.query.shape[:-1]) + (self.value.shape[-1],) + # Convert from legacy format + if self.query.ndim == 3: + self.query = self.query.unsqueeze(2) + self.key = self.key.unsqueeze(2) + self.value = self.value.unsqueeze(2) + self.attn_bias = _attn_bias_apply( + self.attn_bias, partial(torch.unsqueeze, dim=1) + ) + return output_shape + + def validate_inputs(self) -> None: + qkv = (self.query, self.key, self.value) + if self.query.ndim not in (3, 4, 5) or any( + x.ndim != self.query.ndim for x in qkv + ): + raise ValueError( + f"Query/Key/Value should all have BMGHK, BMHK or BMK shape.\n" + f" query.shape: {self.query.shape}\n" + f" key.shape : {self.key.shape}\n" + f" value.shape: {self.value.shape}" + ) + if any(x.device != self.query.device for x in qkv): + raise ValueError("Query/Key/Value should all be on the same device") + if isinstance( + self.attn_bias, + ( + BlockDiagonalMask, + BlockDiagonalPaddedKeysMask, + PagedBlockDiagonalPaddedKeysMask, + BlockDiagonalGappyKeysMask, + PagedBlockDiagonalGappyKeysMask, + ), + ): + bias_device = self.attn_bias.q_seqinfo.seqstart.device + if bias_device != self.query.device: + raise ValueError( + f"Attention bias and Query/Key/Value should be on the same device\n" + f" query.device: {self.query.device}\n" + f" attn_bias : {bias_device}\n" + ) + + quantized_dtypes = self.key.dtype == self.value.dtype == 
torch.int32 + non_quantized_dtypes = all(x.dtype == self.query.dtype for x in qkv) + if not (quantized_dtypes or non_quantized_dtypes): + raise ValueError( + "Query/Key/Value should either all have the same dtype, or " + "(in the quantized case) Key/Value should have dtype torch.int32\n" + f" query.dtype: {self.query.dtype}\n" + f" key.dtype : {self.key.dtype}\n" + f" value.dtype: {self.value.dtype}" + ) + # Biases with tensors attached are meant to be in BMHK format + # This would require to permute biases/gradients which can be expensive, + # so let's just forbid it - BMK is a legacy format anyway + if self.query.ndim == 3 and not _is_bias_type_supported_in_BMK( + type(self.attn_bias) + ): + raise ValueError( + f"Please provide inputs in BMHK format rather " + f"than BMK when using bias type `{type(self.attn_bias).__name__}`" + ) + attn_bias_t: Optional[torch.Tensor] = None + if isinstance(self.attn_bias, AttentionBiasSubTensor): + if self.attn_bias.HOLDS_DENSE_TENSOR: + attn_bias_t = self.attn_bias._subtensor + elif isinstance(self.attn_bias, torch.Tensor): + attn_bias_t = self.attn_bias + if self.query.ndim == 4 and attn_bias_t is not None: + expected_shape = ( + self.query.shape[0], + self.query.shape[2], + self.query.shape[1], + self.key.shape[1], + ) + if attn_bias_t.shape != expected_shape: + raise ValueError( + f"Invalid shape for attention bias: {attn_bias_t.shape} (expected {expected_shape})\n" + f" query.shape: {self.query.shape}\n" + f" key.shape : {self.key.shape}\n" + f" value.shape: {self.value.shape}" + ) + if isinstance(self.attn_bias, BlockDiagonalMask): + if any(x.shape[0] != 1 for x in qkv): + raise ValueError( + f"Expected batch_size=1 when using block-diagonal bias\n" + f" query.shape: {self.query.shape}\n" + f" key.shape : {self.key.shape}\n" + f" value.shape: {self.value.shape}" + ) + if self.p < 0.0 or self.p > 1.0: + raise ValueError(f"Invalid dropout probability: p={self.p}") + # Check that shapes match between inputs + B, Mq = 
self.query.shape[:2] + K = self.query.shape[-1] + B, Mkv = self.key.shape[:2] + Kv = self.value.shape[-1] + quantized_kv_cache = self.value.dtype == torch.int32 + key_embed_dim = Kv if quantized_kv_cache else K + + valid_shapes = True + if self.query.ndim == 3: # BMK + valid_shapes = ( + self.query.shape == (B, Mq, K) + and self.key.shape == (B, Mkv, K) + and self.value.shape == (B, Mkv, Kv) + ) + H = self.query.shape[-2] + if self.query.ndim == 4: # BMHK + valid_shapes = ( + self.query.shape == (B, Mq, H, K) + and self.key.shape == (B, Mkv, H, key_embed_dim) + and self.value.shape == (B, Mkv, H, Kv) + ) + G = self.query.shape[2] + if self.query.ndim == 5: # BMNHK + valid_shapes = ( + self.query.shape == (B, Mq, G, H, K) + and self.key.shape == (B, Mkv, G, H, key_embed_dim) + and self.value.shape == (B, Mkv, G, H, Kv) + ) + if not valid_shapes: + raise ValueError( + f"Incompatible shapes for attention inputs:\n" + f" query.shape: {self.query.shape}\n" + f" key.shape : {self.key.shape}\n" + f" value.shape: {self.value.shape}\n" + "HINT: We don't support broadcasting, please use `expand` " + "yourself before calling `memory_efficient_attention` if you need to" + ) + + def get_output_dtype(self) -> torch.dtype: + if self.output_dtype is None: + if self.is_partial and self.query.dtype is not torch.float64: + return torch.float32 + return self.query.dtype + return self.output_dtype + + @property + def nbytes(self) -> int: + """ + Number of bytes in the input, not counting the attention bias. 
+ """ + return sum( + x.untyped_storage().nbytes() for x in [self.query, self.key, self.value] + ) + + +@dataclass +class Context: + lse: torch.Tensor + out: torch.Tensor + # NOTE: If `rng_state` is set, `op_bw` should be set as well + # as the randomness is backend-dependant + op_bw: Optional[Type["AttentionBwOpBase"]] = None + rng_state: Optional[Any] = None + qkv_share_storage: bool = False + + def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> torch.Tensor: + pad_amount = (pad_to - (self.lse.shape[2] % pad_to)) % pad_to + lse = self.lse + if pad_amount > 0: + if force_pad_inf: + lse = lse[:, :, : self.out.shape[1]] + pad_amount = (pad_to - (lse.shape[2] % pad_to)) % pad_to + lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf) + elif force_pad_inf and self.out.shape[1] != lse.shape[2]: + lse[:, :, self.out.shape[1] :].fill_(math.inf) + return lse + + +@dataclass +class Gradients: + dq: torch.Tensor + dk: torch.Tensor + dv: torch.Tensor + # bias gradient. 
None if there is no tensor bias or if it doesn't require grad + db: Optional[torch.Tensor] = None + + +class AttentionOpBase(BaseOperator): + """Base class for any attention operator in xFormers + + See: + + - :attr:`xformers.ops.fmha.cutlass.FwOp` + - :attr:`xformers.ops.fmha.cutlass.BwOp` + - :attr:`xformers.ops.fmha.flash.FwOp` + - :attr:`xformers.ops.fmha.flash.BwOp` + - :attr:`xformers.ops.fmha.triton.FwOp` + - :attr:`xformers.ops.fmha.triton.BwOp` + """ + + OPERATOR: Any + SUPPORTED_DEVICES: Set[str] + CUDA_MINIMUM_COMPUTE_CAPABILITY: Tuple[int, int] = (5, 0) + SUPPORTED_DTYPES: Set[torch.dtype] + SUPPORTED_MAX_K: float + SUPPORTED_MIN_K: int = 0 + SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (type(None),) + SUPPORTS_DROPOUT: bool + SUPPORTS_CUSTOM_SCALE: bool = False + SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False + SUPPORTS_OUTPUT_DTYPE: bool = False + SUPPORTS_PARTIAL: bool = False + IS_DETERMINISTIC: bool = True + SUPPORTS_BMGHK: bool = False + NAME: str + OPERATOR_CATEGORY = "memory_efficient_attention" + # Format for the LSE computed in the FW pass, and accepted in the BW pass, + # for BlockDiagonalMask and children. 
+ # When using a varlen bias, both the FW and BW operators must have the + # same value for `VARLEN_LSE_PACKED` + VARLEN_LSE_PACKED: bool = True + + _TEST_BATCH_SIZES: List[int] = [1, 300] + _TEST_K: List[int] = [32, 128] + + @classmethod + def supports(cls, d: Inputs) -> bool: + return not cls.not_supported_reasons(d) + + @classmethod + def shape_not_supported_reasons( + cls, Mq: int, Mkv: int, K: int, Kv: int + ) -> List[str]: + reasons = [] + if not cls.SUPPORTS_DIFFERENT_VALUE_EMBED and K != Kv: + reasons.append("query.shape[-1] != value.shape[-1]") + if max(K, Kv) > cls.SUPPORTED_MAX_K: + reasons.append( + f"max(query.shape[-1], value.shape[-1]) > {cls.SUPPORTED_MAX_K}" + ) + if min(K, Kv) < cls.SUPPORTED_MIN_K: + reasons.append( + f"min(query.shape[-1], value.shape[-1]) < {cls.SUPPORTED_MIN_K}" + ) + return reasons + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + """ + Returns a list of reasons why this is not supported. + The kernel can run these inputs only if the returned list is empty + """ + query_shape = d.query.shape + reasons = cls.shape_not_supported_reasons( + Mq=query_shape[1], + Mkv=d.key.shape[1], + K=query_shape[-1], + Kv=query_shape[-1] if d.value.dtype == torch.int32 else d.value.shape[-1], + ) + device_type = d.query.device.type + dtype = d.query.dtype + if device_type not in cls.SUPPORTED_DEVICES: + reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})") + if ( + device_type == "cuda" + and not _built_with_cuda + and (torch.version.hip is None) + ): + reasons.append("xFormers wasn't build with CUDA support") + if device_type == "cuda" and (torch.version.hip is None): + device_capability = torch.cuda.get_device_capability(d.device) + if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY: + reasons.append( + f"requires device with capability > {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} " + f"but your GPU has capability {device_capability} (too old)" + ) + if dtype not in cls.SUPPORTED_DTYPES: + 
reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})") + if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES: + reasons.append(f"attn_bias type is {type(d.attn_bias)}") + if not cls.SUPPORTS_OUTPUT_DTYPE: + if d.output_dtype is not None and d.output_dtype is not dtype: + reasons.append("Custom output dtype not supported") + if d.is_partial and not cls.SUPPORTS_PARTIAL: + reasons.append("Partial attention not supported") + if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT: + reasons.append("dropout > 0.0") + if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE: + reasons.append("has custom scale") + # bfloat16 is only supported on A100+ + # ... although the kernels can still run and give the + # correct result + if dtype is torch.bfloat16 and ( + not device_type.startswith("cuda") + or torch.cuda.get_device_capability(d.query.device)[0] < 8 + ): + reasons.append("bf16 is only supported on A100+ GPUs") + if not cls.is_available(): + reasons.append( + "operator wasn't built - see `python -m xformers.info` for more info" + ) + if not cls.IS_DETERMINISTIC and torch.are_deterministic_algorithms_enabled(): + reasons.append( + "operator is non-deterministic, but `torch.use_deterministic_algorithms` is set" + ) + if not cls.SUPPORTS_BMGHK and d.query.ndim == 5: + reasons.append("operator does not support BMGHK format") + return reasons + + +class AttentionFwOpBase(AttentionOpBase): + ERROR_ATOL: Mapping[torch.dtype, float] = { + torch.float: 3e-4, + torch.half: 4e-3, + torch.bfloat16: 2e-2, + } + ERROR_RTOL: Mapping[torch.dtype, float] = { + torch.float: 2e-5, + torch.half: 4e-4, + torch.bfloat16: 5e-3, + } + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + raise NotImplementedError() + + +class AttentionBwOpBase(AttentionOpBase): + # NOTE on tolerances: These are tested for `scales => (1/32)**0.5` + # In the BW pass, imprecisions accumulate in the Q@K.T recalculation + # These 
imprecisions are multiplied by the `scale` and then exponentiated + # So if the scale is too high, we get a lot of errors + + ERROR_ATOL: Mapping[torch.dtype, float] = { + torch.float: 9e-4, + torch.half: 0.2, + torch.bfloat16: 0.9, + } + ERROR_RTOL: Mapping[torch.dtype, float] = { + torch.float: 1e-4, + torch.half: 2e-2, + torch.bfloat16: 0.1, + } + SUPPORTS_ATTN_BIAS_GRAD = False + SUPPORTS_PARTIAL = True + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(AttentionBwOpBase, cls).not_supported_reasons(d) + if ( + isinstance(d.attn_bias, torch.Tensor) + and d.attn_bias.requires_grad + and not cls.SUPPORTS_ATTN_BIAS_GRAD + ): + reasons.append( + "Computing the bias gradient is not supported (attn_bias.requires_grad = True)" + ) + + return reasons + + @classmethod + def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: + raise NotImplementedError() + + +AttentionOp = Tuple[ + Optional[Type[AttentionFwOpBase]], Optional[Type[AttentionBwOpBase]] +] + + +def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor: + if tensor.ndim == 4: + return tensor + return tensor.reshape( + [tensor.shape[0] // num_heads, num_heads, tensor.shape[1], tensor.shape[2]] + ).permute((0, 2, 1, 3)) + + +def check_lastdim_alignment_stride1( + reasons: List[str], name: str, x: torch.Tensor, alignment: int +) -> None: + if x.shape[-1] % alignment != 0: + reasons.append(f"{name}.shape[-1] % {alignment} != 0") + elif x.stride(-2) % alignment != 0: + reasons.append( + f"{name}.stride(-2) % {alignment} != 0 ({name}.stride() = {x.stride()})" + ) + # We can have stride=0 sometimes if dimension=1 + if x.stride(-1) > 1: + reasons.append( + f"{name}.stride(-1) > 1 ({name}.stride() = {x.stride()}) - you should call `.contiguous()` on the input" + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/cutlass.py b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/cutlass.py new file mode 100644 index 
0000000000000000000000000000000000000000..f26252340ad0add0737f67d5b8c22c1bb9fb3759 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/cutlass.py @@ -0,0 +1,470 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import replace +from enum import Enum +from functools import partial +from typing import Any, Iterable, List, Optional, Set, Tuple, Union + +import torch + +from ..common import get_operator, get_xformers_operator, register_operator +from . import attn_bias +from .attn_bias import ( + AttentionBias, + AttentionBiasSubTensor, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalMask, + LowerTriangularFromBottomRightLocalAttentionMask, + LowerTriangularFromBottomRightMask, + LowerTriangularMask, + LowerTriangularMaskWithTensorBias, +) +from .common import ( + AttentionBwOpBase, + AttentionFwOpBase, + Context, + Gradients, + Inputs, + _attn_bias_apply, + check_lastdim_alignment_stride1, +) +from .torch_attention_compat import is_pt_cutlass_compatible + + +def _uses_tensorcores(sm: int, is_half: bool) -> bool: + if sm >= 80: + return True + if sm >= 70: + return is_half + return False + + +def _minimum_gemm_alignment(inp: Inputs) -> int: + if inp.device.type != "cuda": + return 1 + cap = torch.cuda.get_device_capability(inp.device) + sm = cap[0] * 10 + cap[1] + bits_per_scalar = {torch.float: 32, torch.half: 16, torch.bfloat16: 16}[ + inp.query.dtype + ] + uses_tensorcores = _uses_tensorcores(sm, bits_per_scalar == 16) + matmul_alignment_mn = 1 + if sm >= 80: + matmul_alignment_mn = 4 + if uses_tensorcores: + matmul_alignment_mn = max(matmul_alignment_mn, 128 // bits_per_scalar) + return matmul_alignment_mn + + +def _get_seqlen_info( + 
inp: Inputs, +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], int, int]: + attn_bias = inp.attn_bias + if isinstance( + attn_bias, (BlockDiagonalMask, BlockDiagonalCausalWithOffsetPaddedKeysMask) + ): + assert attn_bias.k_seqinfo.seqstart.device == inp.query.device + seqstart_k = attn_bias.k_seqinfo.seqstart + seqstart_q = attn_bias.q_seqinfo.seqstart + max_seqlen_q = attn_bias.q_seqinfo.max_seqlen + max_seqlen_k = attn_bias.k_seqinfo.max_seqlen + else: + seqstart_k = None + seqstart_q = None + max_seqlen_q = -1 + max_seqlen_k = -1 + + return seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k + + +def _get_tensor_bias( + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] +) -> Optional[torch.Tensor]: + if isinstance(attn_bias, AttentionBiasSubTensor): + if isinstance(attn_bias, LowerTriangularMaskWithTensorBias): + return attn_bias._subtensor + elif isinstance(attn_bias, torch.Tensor): + return attn_bias + return None + + +def _check_bias_alignment( + reasons: List[str], attn_bias: Optional[Union[torch.Tensor, AttentionBias]] +) -> None: + attn_bias_tensor = _get_tensor_bias(attn_bias) + if attn_bias_tensor is not None: + alignment = 128 // torch.finfo(attn_bias_tensor.dtype).bits + show_padding_hint = False + for d in range(attn_bias_tensor.ndim - 1): + if attn_bias_tensor.stride(d) % alignment != 0: + reasons.append( + f"attn_bias.stride(-2) % {alignment} != 0 (attn_bias.stride() = {attn_bias_tensor.stride()})" + ) + show_padding_hint = True + if show_padding_hint: + reasons.append( + """\ +HINT: To use an `attn_bias` with a sequence length that is not a multiple of 8, \ +you need to ensure memory is aligned by slicing a bigger tensor. 
\ +Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5]` instead of `torch.zeros([1, 1, 5, 5])`""" + ) + # We can have stride=0 sometimes if dimension=1 + if attn_bias_tensor.stride(-1) > 1: + reasons.append( + f"attn_bias.stride(-1) > 1 (attn_bias.stride() = {attn_bias_tensor.stride()}) - " + "you should call `.contiguous()` on the bias" + ) + + +class _CustomMaskType(int, Enum): + """ + (Matches CustomMaskType in C++.) + """ + + NoCustomMask = 0 + CausalFromTopLeft = 1 + CausalFromBottomRight = 2 + + +def _custom_mask_type(bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int: + if isinstance( + bias, + ( + LowerTriangularMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalLocalAttentionMask, + ), + ): + return int(_CustomMaskType.CausalFromTopLeft) + if isinstance( + bias, + ( + LowerTriangularFromBottomRightMask, + LowerTriangularFromBottomRightLocalAttentionMask, + attn_bias.BlockDiagonalCausalFromBottomRightMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + ), + ): + return int(_CustomMaskType.CausalFromBottomRight) + return int(_CustomMaskType.NoCustomMask) + + +USE_TORCH_CUTLASS = ( + is_pt_cutlass_compatible(force=False) + and hasattr(torch.ops.xformers, "efficient_attention_forward_cutlass") + and not torch._C._dispatch_has_kernel_for_dispatch_key( + "xformers::efficient_attention_forward_cutlass", "CUDA" + ) +) + + +@register_operator +class FwOp(AttentionFwOpBase): + """xFormers' MHA kernel based on CUTLASS. + Supports a large number of settings (including without TensorCores, f32 ...) 
+ and GPUs as old as P100 (Sm60) + """ + + OPERATOR = ( + get_operator("aten", "_efficient_attention_forward") + if USE_TORCH_CUTLASS + else get_xformers_operator("efficient_attention_forward_cutlass") + ) + SUPPORTED_DEVICES: Set[str] = {"cuda"} + SUPPORTED_DTYPES: Set[torch.dtype] = {torch.float, torch.half, torch.bfloat16} + SUPPORTED_MAX_K = 65536 + SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = ( + type(None), + torch.Tensor, + LowerTriangularMask, + LowerTriangularFromBottomRightMask, + LowerTriangularFromBottomRightLocalAttentionMask, + LowerTriangularMaskWithTensorBias, + BlockDiagonalMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + attn_bias.BlockDiagonalCausalFromBottomRightMask, + attn_bias.BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + ) + SUPPORTS_DROPOUT = True + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_DIFFERENT_VALUE_EMBED = True + SUPPORTS_BMGHK = True + VARLEN_LSE_PACKED = False + NAME = "cutlassF-pt" if USE_TORCH_CUTLASS else "cutlassF" + + _TEST_K: List[int] = [ + 32, # 64x64 kernel + 128, # 64x128 kernel + 256, # 64x128 with accumulation in gmem + ] + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES: + raise NotImplementedError("Unsupported attn_bias type") + if inp.query.ndim in [3, 4]: + return cls.apply_bmhk(inp, needs_gradient=needs_gradient) + assert inp.query.ndim == 5, f"query has shape {inp.query.shape}" + ctx: Optional[Context] = None + # XXX: Hackfix for BMGHK with H=1 + # In that case we don't want to run G different streams because it adds + # some overhead + if inp.query.ndim == 5 and inp.query.shape[3] == 1: + slice_op = partial(torch.squeeze, dim=3) + inp = replace( + inp, + query=slice_op(inp.query), + key=slice_op(inp.key), + value=slice_op(inp.value), + attn_bias=_attn_bias_apply( + inp.attn_bias, partial(torch.squeeze, 
dim=2) + ), + ) + out, ctx = cls.apply_bmhk(inp, needs_gradient=needs_gradient) + out = out.unsqueeze(3) + if ctx is not None: + ctx = replace(ctx, lse=ctx.lse.unsqueeze(1), out=out) + return out, ctx + + # Workaround until this is properly implemented in C++ + # run each head group in a different stream + n_groups = inp.key.shape[2] + main_stream = torch.cuda.current_stream() + streams = [main_stream] + [ + torch.cuda.Stream(device=inp.query.device) for _ in range(n_groups - 1) + ] + outs = [] + for group, stream in enumerate(streams): + stream.wait_stream(main_stream) + with torch.cuda.stream(stream): + query = inp.query[:, :, group] + key = inp.key[:, :, group] + value = inp.value[:, :, group] + bias = _attn_bias_apply( + inp.attn_bias, partial(torch.select, dim=1, index=group) + ) + outs.append( + cls.apply_bmhk( + replace(inp, query=query, key=key, value=value, attn_bias=bias), + needs_gradient=needs_gradient, + ) + ) + for s in streams[1:]: + main_stream.wait_stream(s) + out = torch.stack([o[0] for o in outs], dim=2) + if needs_gradient: + ctx = Context( + out=out, + lse=torch.stack([o[1].lse for o in outs], dim=1), # type: ignore + op_bw=outs[0][1].op_bw, # type: ignore + ) + return out, ctx + + @classmethod + def apply_bmhk( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES: + raise NotImplementedError("Unsupported attn_bias type") + seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k = _get_seqlen_info(inp) + out, lse, rng_seed, rng_offset, _, _ = cls.OPERATOR( + query=inp.query, + key=inp.key, + value=inp.value, + bias=_get_tensor_bias(inp.attn_bias), + cu_seqlens_q=seqstart_q, + cu_seqlens_k=seqstart_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=inp.p, + compute_log_sumexp=needs_gradient, + custom_mask_type=_custom_mask_type(inp.attn_bias), + scale=inp.scale, + seqlen_k=( + inp.attn_bias.k_seqinfo.seqlen + if isinstance( + 
inp.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask + ) + else None + ), + window_size=( + inp.attn_bias._window_size + if isinstance( + inp.attn_bias, + ( + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + LowerTriangularFromBottomRightLocalAttentionMask, + ), + ) + else None + ), + ) + ctx: Optional[Context] = None + if needs_gradient: + ctx = Context(out=out, lse=lse) + if inp.p != 0: + # cutlass forward is only compatible with cutlass backward if + # dropout is used (because of the way RNG states are passed and the + # way random numbers are generated during backward) + ctx.rng_state = (rng_seed, rng_offset) + ctx.op_bw = BwOp + return out, ctx + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(FwOp, cls).not_supported_reasons(d) + matmul_alignment_mn = _minimum_gemm_alignment(d) + check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn) + check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn) + _check_bias_alignment(reasons, d.attn_bias) + return reasons + + +@register_operator +class BwOp(AttentionBwOpBase): + __doc__ = FwOp.__doc__ + + OPERATOR = ( + get_operator("aten", "_efficient_attention_backward") + if USE_TORCH_CUTLASS + else get_xformers_operator("efficient_attention_backward_cutlass") + ) + + SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES + SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES + SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K + SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = ( + type(None), + torch.Tensor, + LowerTriangularMask, + LowerTriangularFromBottomRightMask, + # TODO: Still some infs/nans in the BW pass for + # local + causal + # LowerTriangularFromBottomRightLocalAttentionMask, + # TODO: Fix handling of gradient through the fMHA autograd function + # LowerTriangularMaskWithTensorBias, + BlockDiagonalMask, + BlockDiagonalCausalMask, + attn_bias.BlockDiagonalCausalFromBottomRightMask, + 
attn_bias.BlockDiagonalCausalLocalAttentionMask, + ) + SUPPORTS_ATTN_BIAS_GRAD = True + SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT + SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE + SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED + VARLEN_LSE_PACKED = False + NAME = "cutlassB-pt" if USE_TORCH_CUTLASS else "cutlassB" + + _TEST_K: List[int] = [ + 32, # 64x64 kernel + 128, # 64x128/128x128 kernel + 256, # 64x128 with accumulation in gmem + ] + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(BwOp, cls).not_supported_reasons(d) + matmul_alignment_mn = _minimum_gemm_alignment(d) + + check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn) + check_lastdim_alignment_stride1(reasons, "key", d.key, matmul_alignment_mn) + check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn) + _check_bias_alignment(reasons, d.attn_bias) + attn_bias_tensor = _get_tensor_bias(d.attn_bias) + + # Backprop of gradient through broadcasted bias is not supported + if attn_bias_tensor is not None and attn_bias_tensor.requires_grad: + # Don't forget that inputs are either in BMK or BMHK! 
+ if d.query.ndim == 3 and attn_bias_tensor.ndim == 3: + expected_bias_shape = (*d.query.shape[:2], d.key.shape[1]) + else: + # bias is B H Mq Mk + expected_bias_shape = ( + d.query.shape[0], + d.query.shape[2] if d.query.ndim == 4 else 1, + d.query.shape[1], + d.key.shape[1], + ) + if tuple(attn_bias_tensor.shape) != expected_bias_shape: + reasons.append( + "Broadcasting the `attn_bias` tensor is not supported " + f"(shape: {tuple(attn_bias_tensor.shape)}" + f"/ expected: {expected_bias_shape})" + ) + return reasons + + @classmethod + def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: + if type(inp.attn_bias) not in BwOp.SUPPORTED_ATTN_BIAS_TYPES: + raise NotImplementedError("Unsupported attn_bias type") + + seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k = _get_seqlen_info(inp) + dtype = inp.query.dtype + + rng_seed = rng_offset = torch.Tensor() + if inp.p != 0.0: + assert ctx.rng_state is not None + rng_seed, rng_offset = ctx.rng_state + tensor_bias = _get_tensor_bias(inp.attn_bias) + + force_pad_inf = torch.cuda.get_device_capability(inp.query.device) == (7, 5) + (grad_q, grad_k, grad_v, grad_bias) = cls.OPERATOR( + grad.to(dtype), + inp.query, + inp.key, + inp.value, + bias=tensor_bias, + bias_requires_grad=( + tensor_bias.requires_grad if tensor_bias is not None else False + ), + cu_seqlens_q=seqstart_q, + cu_seqlens_k=seqstart_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + logsumexp=ctx.get_padded_lse(32, force_pad_inf=force_pad_inf), + out=ctx.out.to(dtype), + dropout_p=inp.p, + # if not using dropout, seed and offset are irrelevant but still expected + # in function signature so just pass 0 + # seed and offset could be None if a different FW op other than cutlass + # was used. 
+ philox_seed=rng_seed, + philox_offset=rng_offset, + custom_mask_type=_custom_mask_type(inp.attn_bias), + scale=inp.scale, + num_splits_key=-1, # Let C++ determine it + window_size=( + inp.attn_bias._window_size + if isinstance( + inp.attn_bias, + ( + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + LowerTriangularFromBottomRightLocalAttentionMask, + ), + ) + else None + ), + ) + + # c++/CUDA implementation returns an uninitialized tensor if bias doesn't + # require grad + if not ( + isinstance(inp.attn_bias, torch.Tensor) and inp.attn_bias.requires_grad + ): + grad_bias = None + + return Gradients(dq=grad_q, dk=grad_k, dv=grad_v, db=grad_bias) diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/dispatch.py b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..60252261b5453058162e902d7cb23e06948bc653 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/dispatch.py @@ -0,0 +1,186 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import textwrap +from collections import deque +from typing import Any, List, Optional, Sequence, Tuple, Type, TypeVar + +import torch + +from . 
import attn_bias, ck, cutlass, flash, flash3, triton_splitk +from .common import AttentionBwOpBase, AttentionFwOpBase, Inputs + +T = TypeVar("T", Type[AttentionFwOpBase], Type[AttentionBwOpBase]) + + +_USE_FLASH_ATTENTION_3 = False + + +def _set_use_fa3(use_flash_attention3: bool) -> None: + global _USE_FLASH_ATTENTION_3 + _USE_FLASH_ATTENTION_3 = use_flash_attention3 + + +def _get_use_fa3() -> bool: + global _USE_FLASH_ATTENTION_3 + return _USE_FLASH_ATTENTION_3 + + +def _format_inputs_description(inp: Inputs) -> str: + return f"""query : shape={tuple(inp.query.shape)} ({inp.query.dtype}) +key : shape={tuple(inp.key.shape)} ({inp.key.dtype}) +value : shape={tuple(inp.value.shape)} ({inp.value.dtype}) +attn_bias : {type(inp.attn_bias)} +p : {inp.p}""" + + +def _ensure_op_supports_or_raise(exc_type, name: str, op, inp: Inputs) -> None: + reasons = op.not_supported_reasons(inp) + if not reasons: + return + raise exc_type( + f"""Operator `{name}` does not support inputs: +{textwrap.indent(_format_inputs_description(inp), ' ')} +{_format_not_supported_reasons(op, reasons)}""" + ) + + +def _format_not_supported_reasons(op, reasons: List[str]) -> str: + return f"`{op.NAME}` is not supported because:\n " + "\n ".join(reasons) + + +def _run_priority_list( + name: str, + priority_list: Sequence[T], + inp: Inputs, + extra_op_reasons: Optional[List[Tuple[Any, List[str]]]] = None, +) -> T: + not_supported_reasons: List[List[str]] = [] + for op in priority_list: + not_supported = op.not_supported_reasons(inp) + if not not_supported: + return op + not_supported_reasons.append(not_supported) + + # Let's write a nice message explaining what we tried and why it's not supported + msg = f"""No operator found for `{name}` with inputs: +{textwrap.indent(_format_inputs_description(inp), ' ')}""" + for op, not_supported in zip(priority_list, not_supported_reasons): + msg += "\n" + _format_not_supported_reasons(op, not_supported) + if extra_op_reasons is not None: + for op, not_supported 
in extra_op_reasons: + msg += "\n" + _format_not_supported_reasons(op, not_supported) + raise NotImplementedError(msg) + + +def _dispatch_fw_priority_list( + inp: Inputs, needs_gradient: bool +) -> Sequence[Type[AttentionFwOpBase]]: + if torch.version.cuda: + flash3_op = [flash3.FwOp] if _get_use_fa3() else [] + priority_list_ops = deque( + flash3_op + + [ + flash.FwOp, + cutlass.FwOp, + ] + ) + else: + priority_list_ops = deque( + [ + ck.FwOp, + ] + ) + if not needs_gradient: + mqa_or_gqa = ( + inp.key.ndim > 3 and inp.key.stride(-2) == 0 and inp.key.shape[-2] > 1 + ) + # Split-KV is useful with MQA + # for short Q-seqlen / long K-seqlen + if mqa_or_gqa and inp.query.shape[1] <= 32 and inp.key.shape[1] >= 256: + parallelism_BH = 0 # BMK + if inp.query.ndim == 3: + parallelism_BH = inp.query.shape[0] + elif inp.query.ndim == 4: # BMHK + parallelism_BH = inp.query.shape[0] * inp.query.shape[2] + elif inp.query.ndim == 5: # BMGHK + parallelism_BH = inp.query.shape[0] * inp.query.shape[2] + if parallelism_BH > 0 and parallelism_BH < 64: + # priority_list_ops.appendleft(ck_splitk.FwOp) + priority_list_ops.appendleft(triton_splitk.FwOp) + # Without variable seqlen flash is fastest + if torch.version.cuda and not isinstance( + inp.attn_bias, attn_bias.BlockDiagonalMask + ): + if _get_use_fa3(): + priority_list_ops.remove(flash3.FwOp) + priority_list_ops.remove(flash.FwOp) + priority_list_ops.appendleft(flash.FwOp) + + return priority_list_ops + + +def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]: + """Computes the best operator for forward + + Raises: + NotImplementedError: if not operator was found + + Returns: + AttentionOp: The best operator for the configuration + """ + return _run_priority_list( + "memory_efficient_attention_forward", + _dispatch_fw_priority_list(inp, needs_gradient), + inp, + ) + + +def _is_cutlassB_faster_than_flash(inp: Inputs) -> bool: + return False + + +def _dispatch_bw( + inp: Inputs, varlen_lse_packed: 
Optional[bool] +) -> Type[AttentionBwOpBase]: + if torch.version.cuda: + priority_list_ops: List[Type[AttentionBwOpBase]] = [ + flash.BwOp, + cutlass.BwOp, + ] + else: + priority_list_ops = [ + ck.BwOp, + ] + + # NOTE: If we have a variable seqlen `attn_bias`, we need to get a BW pass + # that supports the LSE format + # *unless* we are in the case where both formats are the same (bs=1) + extra_op_reasons = [] + if ( + isinstance(inp.attn_bias, attn_bias.VARLEN_BIASES) + and inp.attn_bias.q_seqinfo.seqstart.shape[0] > 2 + ): + assert varlen_lse_packed is not None + for op in priority_list_ops: + if op.VARLEN_LSE_PACKED != varlen_lse_packed: + extra_op_reasons.append( + ( + op, + [ + f"LSE is in {'packed' if varlen_lse_packed else 'padded'} format" + ], + ) + ) + priority_list_ops = [ + op for op in priority_list_ops if op.VARLEN_LSE_PACKED == varlen_lse_packed + ] + if torch.version.cuda and _is_cutlassB_faster_than_flash(inp): + priority_list_ops.remove(cutlass.BwOp) + priority_list_ops.insert(0, cutlass.BwOp) + return _run_priority_list( + "memory_efficient_attention_backward", priority_list_ops, inp + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/flash.py b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/flash.py new file mode 100644 index 0000000000000000000000000000000000000000..f598dbb74d1439495c5cd3c0feef781193ae9767 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/flash.py @@ -0,0 +1,836 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import os +from itertools import zip_longest +from typing import Any, Iterable, List, Optional, Set, Tuple, Union + +import torch + +from ..common import get_operator, register_operator +from .attn_bias import ( + VARLEN_BIASES, + AttentionBias, + BlockDiagonalCausalFromBottomRightMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalLocalAttentionPaddedKeysMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalWithOffsetGappyKeysMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalGappyKeysMask, + BlockDiagonalMask, + BlockDiagonalPaddedKeysMask, + LocalAttentionFromBottomRightMask, + LowerTriangularFromBottomRightLocalAttentionMask, + LowerTriangularFromBottomRightMask, + LowerTriangularMask, + PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, + PagedBlockDiagonalPaddedKeysMask, +) +from .common import ( + AttentionBwOpBase, + AttentionFwOpBase, + Context, + Gradients, + Inputs, + check_lastdim_alignment_stride1, +) +from .torch_attention_compat import is_pt_flash_compatible + +FLASH_VERSION = "0.0.0" +VARLEN_LSE_PACKED = False +_TRY_PT_FLASH_ATTN = torch.version.hip is None +_USE_PT_FLASH_ATTN = False + +try: + try: + from ... 
import _C_flashattention # type: ignore[attr-defined] + from ..._cpp_lib import _build_metadata + + if _build_metadata is not None: + FLASH_VERSION = _build_metadata.flash_version + VARLEN_LSE_PACKED = True + except ImportError: + try: + import flash_attn + from flash_attn.flash_attn_interface import ( + flash_attn_cuda as _C_flashattention, + ) + + FLASH_VERSION = flash_attn.__version__ + FLASH_VER_MIN = (2, 6, 3) + FLASH_VER_LAST = (2, 6, 3) # last supported, inclusive + flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3]) + if ( + flash_ver_parsed < FLASH_VER_MIN or flash_ver_parsed > FLASH_VER_LAST + ) and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1": + raise ImportError( + f"Requires Flash-Attention version >={'.'.join([str(i) for i in FLASH_VER_MIN])}," + f"<={'.'.join([str(i) for i in FLASH_VER_LAST])} " + f"but got {FLASH_VERSION}." + ) + VARLEN_LSE_PACKED = True + except ImportError: + if not _TRY_PT_FLASH_ATTN: + raise + assert is_pt_flash_compatible(force=True) + FLASH_VERSION = torch.nn.attention._get_flash_version() # type: ignore + FLASH_VERSION = f"v{FLASH_VERSION}" + VARLEN_LSE_PACKED = False + _USE_PT_FLASH_ATTN = True + + @torch.library.custom_op( + "xformers_flash::flash_fwd", + mutates_args=(), + device_types=["cuda"], + ) + def _flash_fwd( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + p: float, + softmax_scale: float, + is_causal: bool, + window_left: int, + window_right: int, + return_softmax: bool, + block_tables: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + softcap = 0.0 + if _USE_PT_FLASH_ATTN: + ( + attention, + logsumexp, + philox_seed, + philox_offset, + _, + ) = torch.ops.aten._flash_attention_forward( + query, + key, + value, + cu_seqlens_q, # cum_seq_q + cu_seqlens_k, # cum_seq_k + 
max_seqlen_q, # max_q + max_seqlen_k, # max_k + p, # dropout_p + is_causal, + return_debug_mask=False, + scale=softmax_scale, + window_size_left=window_left, + window_size_right=window_right, + seqused_k=seqused_k, + alibi_slopes=None, # alibi_slopes + ) + rng_state = torch.stack([philox_seed, philox_offset]) + return attention, logsumexp, rng_state + else: + if cu_seqlens_q is None: + assert cu_seqlens_k is None + assert seqused_k is None + ( + out, + q_padded, + k_padded, + v_padded, + out_padded, + softmax_lse, + p, + rng_state, + ) = _C_flashattention.fwd( + query, + key, + value, + None, # out + None, # alibi_slopes + p, + softmax_scale, + is_causal, + window_left, # window_size_left + window_right, # window_size_right + softcap, + return_softmax, + None, # rng + ) + else: + ( + out, + q_padded, + k_padded, + v_padded, + out_padded, + softmax_lse, + p, + rng_state, + ) = _C_flashattention.varlen_fwd( + query, + key, + value, + None, # out + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + None, # leftpad_k_ + block_tables, + None, # alibi_slopes + max_seqlen_q, + max_seqlen_k, + p, + softmax_scale, + False, + is_causal, + window_left, + window_right, + softcap, + return_softmax, + None, # gen + ) + return out, softmax_lse, rng_state + + @torch.library.register_fake("xformers_flash::flash_fwd") + def _flash_fwd_abstract( + query, + key, + value, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + max_seqlen_q, + max_seqlen_k, + p, + softmax_scale, + is_causal, + window_left, + window_right, + return_softmax, + block_tables, + ): + out = torch.empty_like(query) + if cu_seqlens_q is None: + B, M, H, K = query.shape + lse_shape = [B, H, M] + else: + M, H, K = query.shape + B = cu_seqlens_q.shape[0] - 1 + if VARLEN_LSE_PACKED: + lse_shape = [H, M] + else: + lse_shape = [B, H, max_seqlen_q] + softmax_lse = torch.empty(lse_shape, device=query.device, dtype=torch.float32) + rng_state = torch.empty([2], device=query.device, dtype=torch.int64) + return out, softmax_lse, 
rng_state + + @torch.library.custom_op( + "xformers_flash::flash_bwd", + mutates_args=(), + device_types=["cuda"], + ) + def _flash_bwd( + grads_share_storage: bool, + grad: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + p: float, + softmax_scale: float, + is_causal: bool, + window_left: int, + window_right: int, + rng_state: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + softcap = 0.0 + if _USE_PT_FLASH_ATTN: + assert softcap == 0.0 + if rng_state is not None: + philox_seed = rng_state[0] + philox_offset = rng_state[1] + else: + philox_seed = philox_offset = None + dq, dk, dv = torch.ops.aten._flash_attention_backward( + grad, + query, + key, + value, + out, + lse, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + p, + is_causal, + philox_seed, + philox_offset, + scale=softmax_scale, + window_size_left=window_left, + window_size_right=window_right, + ) + else: + dq, dk, dv = _create_dq_dk_dv(grads_share_storage, query, key, value) + if cu_seqlens_k is None: + assert cu_seqlens_q is None + _C_flashattention.bwd( + grad, + query, + key, + value, + out, + lse, + dq, + dk, + dv, + None, # alibi_slopes + p, + softmax_scale, + is_causal, + window_left, + window_right, + softcap, + False, # deterministic + None, + rng_state, + ) + else: + _C_flashattention.varlen_bwd( + grad, + query, + key, + value, + out, + lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + None, # alibi_slopes + max_seqlen_q, + max_seqlen_k, + p, + softmax_scale, + False, # zero_tensors + is_causal, + window_left, + window_right, + softcap, + False, # deterministic + None, + rng_state, + ) + return dq, dk, dv + + @torch.library.register_fake("xformers_flash::flash_bwd") + def _flash_bwd_abstract( + grads_share_storage, + grad, + query, + key, + value, + *args, + **kwargs, + ): + 
return _create_dq_dk_dv(grads_share_storage, query, key, value) + + def _create_dq_dk_dv( + grads_share_storage: bool, query, key, value + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Create dq,dk,dv + # If Q/K/V come from a single QKV tensor, let's put the gradient in the + # right strides, so we can avoid a `cat` + if grads_share_storage: + chunk = torch.empty( + (*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]), + dtype=query.dtype, + device=query.device, + ) + return chunk.select(-3, 0), chunk.select(-3, 1), chunk.select(-3, 2) + return torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) + +except ImportError: + pass + + +def _convert_input_format( + inp: Inputs, + supports_mqa: bool, +) -> Tuple[ + Inputs, + Optional[torch.Tensor], + int, + Optional[torch.Tensor], + int, + Optional[torch.Tensor], +]: + assert inp.query.ndim in [4, 5] + query, key, value = inp.query, inp.key, inp.value + batch = query.shape[0] + seqlen_q = query.shape[1] + seqlen_kv = key.shape[1] + head_dim_q = query.shape[-1] + head_dim_v = value.shape[-1] + + attn_bias = inp.attn_bias + if isinstance(attn_bias, BlockDiagonalMask): + assert attn_bias.k_seqinfo.seqstart.device == inp.query.device + cu_seqlen_k = attn_bias.k_seqinfo.seqstart + cu_seqlen_q = attn_bias.q_seqinfo.seqstart + max_seqlen_q = attn_bias.q_seqinfo.max_seqlen + max_seqlen_k = attn_bias.k_seqinfo.max_seqlen + seqused_k = None + elif isinstance( + attn_bias, + ( + BlockDiagonalGappyKeysMask, + BlockDiagonalPaddedKeysMask, + PagedBlockDiagonalPaddedKeysMask, + ), + ): + assert attn_bias.k_seqinfo.seqstart.device == inp.query.device + cu_seqlen_k = attn_bias.k_seqinfo.seqstart + cu_seqlen_q = attn_bias.q_seqinfo.seqstart + max_seqlen_q = attn_bias.q_seqinfo.max_seqlen + max_seqlen_k = attn_bias.k_seqinfo.max_seqlen + seqused_k = attn_bias.k_seqinfo.seqlen + else: + cu_seqlen_k = None + cu_seqlen_q = None + seqused_k = None + max_seqlen_q = inp.query.shape[1] + max_seqlen_k = 
inp.key.shape[1] + + if query.ndim == 5: # GQA + assert supports_mqa + + # Fold the group/head_in_group dimensions together + def fold(x): + # Either the head is replicated + if x.stride(3) == 0: + return x[:, :, :, 0] + # Or we reshape + return x.reshape( + [ + x.shape[0], + x.shape[1], + -1, + x.shape[4], + ] + ) + + query = fold(query) + key = fold(key) + value = fold(value) + # Optimize for MHA + if supports_mqa and key.ndim == 4 and key.stride(2) == 0 and value.stride(2) == 0: + key = key[:, :, :1] + value = value[:, :, :1] + # Initially we have `query.shape = [batch, seqlen, num_heads, head_dim_q]` + # We want format `[batch * seqlen, num_heads, head_dim_q]` + if cu_seqlen_k is not None: + query = query.reshape([batch * seqlen_q, -1, head_dim_q]) + key = key.reshape([batch * seqlen_kv, -1, head_dim_q]) + value = value.reshape([batch * seqlen_kv, -1, head_dim_v]) + if isinstance(attn_bias, PagedBlockDiagonalPaddedKeysMask): + num_pages = value.shape[0] // attn_bias.page_size + key = key.view(num_pages, attn_bias.page_size, *key.shape[1:]) + value = value.view(num_pages, attn_bias.page_size, *value.shape[1:]) + + new_inp = Inputs( + query=query, + key=key, + value=value, + attn_bias=attn_bias, + p=inp.p, + scale=inp.scale, + output_dtype=inp.output_dtype, + is_partial=inp.is_partial, + ) + return new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k, seqused_k + + +def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool: + return isinstance( + attn_bias, + ( + LowerTriangularMask, + LowerTriangularFromBottomRightMask, + LowerTriangularFromBottomRightLocalAttentionMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalFromBottomRightMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + BlockDiagonalCausalLocalAttentionPaddedKeysMask, + BlockDiagonalCausalWithOffsetGappyKeysMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, + ), 
+ ) + + +def _is_paged_attention_supported(attn_bias_type) -> bool: + if issubclass(attn_bias_type, PagedBlockDiagonalPaddedKeysMask): + return FLASH_VERSION > "2.5.6" and not _USE_PT_FLASH_ATTN + + return True + + +def _window_size( + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] +) -> Tuple[int, int]: + win_left = -1 + win_right = -1 + if isinstance( + attn_bias, + ( + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + BlockDiagonalCausalLocalAttentionPaddedKeysMask, + LowerTriangularFromBottomRightLocalAttentionMask, + ), + ): + win_left = attn_bias._window_size - 1 + if isinstance(attn_bias, LocalAttentionFromBottomRightMask): + win_left = attn_bias.window_left + win_right = attn_bias.window_right + return (win_left, win_right) + + +def _check_needs_no_topleft(d: Inputs, reasons: List[str]) -> None: + # Flash does not support TopLeft, so only allow causal masks with TopLeft + # if each batch element has equal number of queries and keys. + if isinstance(d.attn_bias, BlockDiagonalCausalMask): + # Flash does not support TopLeft, so only allow BlockDiagonalCausalMask + # if each batch element has equal number of queries and keys. 
+ for k_start, q_start in zip_longest( + d.attn_bias.k_seqinfo.seqstart_py, d.attn_bias.q_seqinfo.seqstart_py + ): + if k_start != q_start: + reasons.append( + "Only support BlockDiagonalCausalMask if equal" + " numbers of keys and queries" + ) + break + elif isinstance(d.attn_bias, LowerTriangularMask): + if d.query.shape[1] != d.key.shape[1]: + reasons.append( + "Only support LowerTriangularMask if equal number of" "keys and queries" + ) + + +def _check_strides_for_bmghk(x: torch.Tensor, name: str, reasons: List[str]) -> None: + """ + We want to be able to collapse the G/H dimensions together + """ + if x.ndim == 5: + stride_g, stride_h = x.stride(2), x.stride(3) + if x.shape[2] == 1: + return + if x.shape[3] == 1 or stride_h == 0: + return + if stride_g != stride_h * x.shape[-2]: + reasons.append( + f"GQA is only supported when the G/H dimensions are contiguous\n" + f" {name}.stride: {x.stride()}\n" + f" {name}.shape : {list(x.shape)}" + ) + + +def _post_process_lse( + lse: torch.Tensor, + inp: Inputs, + original_query_shape: Tuple[int, ...], + varlen_lse_packed: bool = VARLEN_LSE_PACKED, +) -> torch.Tensor: + # Easy case: no varlen + if not isinstance(inp.attn_bias, VARLEN_BIASES): + if len(original_query_shape) == 5: + # [B, GH, M] => [B, G, H, M] + return lse.unflatten(1, original_query_shape[2:4]) + return lse + + # Already packed: just bring back the batch dimension + if varlen_lse_packed: + if len(original_query_shape) == 5: + # (1, G, H, total_q) + return lse.unflatten(0, original_query_shape[2:4]).unsqueeze(0) + # (1, H, total_q) + return lse.unsqueeze(0) + + if not inp.is_partial: + # (B, H, M) + return lse + + # reshape from (B, G*H, max_seqlen) to (1, G*H, B*max_seqlen) + # Unfortunately this flatten is not just a view. 
+ lse_hkm = lse.permute(1, 0, 2).flatten(start_dim=1)[None] + if len(original_query_shape) == 5: + return lse_hkm.unflatten(1, original_query_shape[2:4]) + return lse_hkm + + +@register_operator +class FwOp(AttentionFwOpBase): + """Operator that computes memory-efficient attention using \ + `Flash-Attention `_ \ + implementation. + """ + + OPERATOR = get_operator("xformers_flash", "flash_fwd") + SUPPORTED_DEVICES: Set[str] = {"cuda"} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) + SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16} + SUPPORTED_MAX_K = 256 + SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = ( + type(None), + LowerTriangularMask, + LowerTriangularFromBottomRightMask, + LowerTriangularFromBottomRightLocalAttentionMask, + BlockDiagonalMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + BlockDiagonalCausalLocalAttentionPaddedKeysMask, + BlockDiagonalCausalFromBottomRightMask, + BlockDiagonalCausalWithOffsetGappyKeysMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalGappyKeysMask, + BlockDiagonalPaddedKeysMask, + LocalAttentionFromBottomRightMask, + PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, + PagedBlockDiagonalPaddedKeysMask, + ) + + SUPPORTED_ATTN_BIAS_TYPES = [ + b for b in SUPPORTED_ATTN_BIAS_TYPES if _is_paged_attention_supported(b) + ] + + SUPPORTS_DROPOUT = True + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_DIFFERENT_VALUE_EMBED = False + SUPPORTS_BMGHK = True + SUPPORTS_PARTIAL = True + VARLEN_LSE_PACKED = VARLEN_LSE_PACKED + NAME = f"fa2F@{FLASH_VERSION}-pt" if _USE_PT_FLASH_ATTN else f"fa2F@{FLASH_VERSION}" + VERSION = FLASH_VERSION + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(FwOp, cls).not_supported_reasons(d) + check_lastdim_alignment_stride1(reasons, "query", d.query, 8) + _check_needs_no_topleft(d, reasons) + _check_strides_for_bmghk(d.query, "query", reasons) + 
_check_strides_for_bmghk(d.key, "key", reasons) + _check_strides_for_bmghk(d.value, "value", reasons) + + if ( + d.is_partial + and not VARLEN_LSE_PACKED + and isinstance(d.attn_bias, VARLEN_BIASES) + ): + q_seqinfo = d.attn_bias.q_seqinfo + if q_seqinfo.min_seqlen != q_seqinfo.max_seqlen: + # Flash provides padded LSE which we don't handle. + reasons.append("partial attention with heterogeneous queries") + return reasons + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + return_softmax = False + original_query_shape = inp.query.shape + + out_shape = [ + *inp.query.shape[:-1], + inp.value.shape[-1], + ] + # no cumulative seqlen + ( + inp, + cu_seqlens_q, + max_seqlen_q, + cu_seqlens_k, + max_seqlen_k, + seqused_k, + ) = _convert_input_format(inp, supports_mqa=True) + + if inp.query.numel() > 0 and inp.key.numel() > 0: + win_left, win_right = _window_size(inp.attn_bias) + block_tables = ( + inp.attn_bias.block_tables + if isinstance(inp.attn_bias, PagedBlockDiagonalPaddedKeysMask) + else None + ) + out, softmax_lse, rng_state = cls.OPERATOR( + inp.query, + inp.key, + inp.value, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + max_seqlen_q, + max_seqlen_k, + inp.p, + inp.scale_float, + _is_causal(inp.attn_bias), + window_left=win_left, + window_right=win_right, + return_softmax=return_softmax, + block_tables=block_tables, + ) + out = out.reshape(out_shape) + else: + out = torch.zeros(out_shape, device=inp.query.device, dtype=inp.query.dtype) + rng_state = None + softmax_lse = torch.empty( + ( + [inp.query.shape[2], inp.query.shape[0] * inp.query.shape[1]] + if VARLEN_LSE_PACKED and isinstance(inp.attn_bias, VARLEN_BIASES) + else [inp.query.shape[0], inp.query.shape[2], inp.query.shape[1]] + ), + device=inp.query.device, + dtype=torch.float32, + ) + if not needs_gradient: + return out, None + ctx = Context( + out=out, + lse=_post_process_lse(softmax_lse, inp, original_query_shape), + ) + if inp.p != 
0.0: + ctx.op_bw = BwOp + ctx.rng_state = rng_state + return (out, ctx) + + +@register_operator +class BwOp(AttentionBwOpBase): + __doc__ = FwOp.__doc__ + + OPERATOR = get_operator("xformers_flash", "flash_bwd") + SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES + CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY + SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES + SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K + SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = tuple( + set(FwOp.SUPPORTED_ATTN_BIAS_TYPES).difference( + { + BlockDiagonalCausalLocalAttentionPaddedKeysMask, + BlockDiagonalCausalWithOffsetGappyKeysMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalGappyKeysMask, + BlockDiagonalPaddedKeysMask, + PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, + PagedBlockDiagonalPaddedKeysMask, + } + ) + ) + SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT + SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE + SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED + IS_DETERMINISTIC = False + SUPPORTS_BMGHK = False # NOTE: Don't forget to update fmha doc when changing this! 
+ VARLEN_LSE_PACKED = VARLEN_LSE_PACKED + NAME = f"fa2B@{FLASH_VERSION}-pt" if _USE_PT_FLASH_ATTN else f"fa2B@{FLASH_VERSION}" + VERSION = FLASH_VERSION + + MAX_HEADDIM_DROPOUT_SM8x = 224 + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(BwOp, cls).not_supported_reasons(d) + check_lastdim_alignment_stride1(reasons, "query", d.query, 8) + _check_needs_no_topleft(d, reasons) + if d.device.type == "cuda": + # Due to limited shared-memory, some GPUs are limited in head dimension + device_capability = torch.cuda.get_device_capability(d.device) + is_sm80_or_sm90 = device_capability in [(8, 0), (9, 0)] + if ( + max(d.key.shape[-1], d.query.shape[-1]) > cls.MAX_HEADDIM_DROPOUT_SM8x + and not is_sm80_or_sm90 + and d.p != 0.0 + ): + reasons.append( + "requires a GPU with compute capability 8.0 " + f"(A100) or 9.0 (H100) for dropout when 'query.shape[-1] > {cls.MAX_HEADDIM_DROPOUT_SM8x}'" + ) + return reasons + + @classmethod + def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: + dq_shape, dk_shape, dv_shape = inp.query.shape, inp.key.shape, inp.value.shape + ( + inp, + cu_seqlens_q, + max_seqlen_q, + cu_seqlens_k, + max_seqlen_k, + seqused_k, + ) = _convert_input_format(inp, supports_mqa=False) + # assert ctx.lse.is_contiguous() + assert seqused_k is None + ctx_lse = ctx.lse + if isinstance(inp.attn_bias, VARLEN_BIASES) and VARLEN_LSE_PACKED: + assert ctx_lse.shape[0] == 1 + ctx_lse = ctx_lse[0] + else: + # NOTE: cutlass pads the last dimension, we need to slice it + assert ctx_lse.shape[2] >= max_seqlen_q + ctx_lse = ctx_lse[:, :, :max_seqlen_q].contiguous() + kernel_out_shape = [ + *inp.query.shape[:-1], + inp.value.shape[-1], + ] + assert grad.dtype in cls.SUPPORTED_DTYPES + + if inp.query.numel() and inp.key.numel(): + win_left, win_right = _window_size(inp.attn_bias) + grads = Gradients( + *cls.OPERATOR( + ctx.qkv_share_storage, + grad.reshape(kernel_out_shape).contiguous(), + inp.query, + inp.key, + 
inp.value, + ctx.out.reshape(kernel_out_shape), + ctx_lse, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + inp.p, + inp.scale_float, + _is_causal(inp.attn_bias), + window_left=win_left, + window_right=win_right, + rng_state=ctx.rng_state if inp.p > 0.0 else None, + ) + ) + else: + grads = Gradients( + dq=torch.zeros_like(inp.query), + dk=torch.zeros_like(inp.key), + dv=torch.zeros_like(inp.value), + ) + if grads.dq.numel() == 0: + grads.dk.zero_() + grads.dv.zero_() + if grads.dv.numel() == 0: + grads.dq.zero_() + grads.dq = grads.dq.reshape(dq_shape) + grads.dk = grads.dk.reshape(dk_shape) + grads.dv = grads.dv.reshape(dv_shape) + return grads diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/flash3.py b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/flash3.py new file mode 100644 index 0000000000000000000000000000000000000000..88e8e29fce3bdc0acd0996819eff6f61ff82678d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/flash3.py @@ -0,0 +1,421 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +from typing import Any, Iterable, List, Optional, Sequence, Set, Tuple + +import torch + +from ..common import get_operator, register_operator +from .attn_bias import ( + VARLEN_BIASES, + BlockDiagonalCausalFromBottomRightMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalWithOffsetGappyKeysMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalGappyKeysMask, + BlockDiagonalMask, + BlockDiagonalPaddedKeysMask, + LowerTriangularFromBottomRightMask, + LowerTriangularMask, +) +from .common import ( + AttentionBwOpBase, + AttentionFwOpBase, + Context, + Gradients, + Inputs, + check_lastdim_alignment_stride1, +) +from .flash import ( + _check_needs_no_topleft, + _convert_input_format, + _is_causal, + _post_process_lse, +) + +FLASH_VERSION = "0.0.0" +try: + from ... import _C_flashattention3 # type: ignore[attr-defined] + from ..._cpp_lib import _build_metadata + + if _build_metadata is not None: + FLASH_VERSION = _build_metadata.flash_version +except ImportError: + try: + from flash_attn_interface import flashattn_hopper_cuda as _C_flashattention3 + except ImportError: + # We end up here is arch is not 90a + _C_flashattention3 = None + +if _C_flashattention3 is not None: + # returns: out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p + @torch.library.custom_op( + "xformers_flash3::flash_fwd", mutates_args=(), device_types=["cuda"] + ) + def mha_fwd( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + p: float, + softmax_scale: float, + is_causal: bool, + ) -> Tuple[torch.Tensor, torch.Tensor,]: + if cu_seqlens_q is None: + assert cu_seqlens_k is None + assert seqused_k is None + ( + out, + q_padded, + k_padded, + v_padded, + out_padded, + softmax_lse, + p, + ) = _C_flashattention3.fwd( + query, key, value, None, softmax_scale, None, None, None, is_causal + ) + else: + out, q, k, v, out_padded, 
softmax_lse = _C_flashattention3.varlen_fwd( + query, + key, + value, + None, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + is_causal, + ) + return out, softmax_lse + + @torch.library.register_fake("xformers_flash3::flash_fwd") + def mha_fwd_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + p: float, + softmax_scale: float, + is_causal: bool, + ) -> Tuple[torch.Tensor, torch.Tensor,]: + query_shape = query.shape + out = query.new_empty(query_shape) + # Query is (B, M, H, K) or (total_M, H, K) + # LSE is (B, H, M) or (H, total_M) + lse_shape = ( + (query_shape[0], query_shape[2], query_shape[1]) + if cu_seqlens_q is None + else (query_shape[1], query_shape[0]) + ) + lse = query.new_empty(lse_shape, dtype=torch.float32) + return out, lse + + def _create_dq_dk_dv( + grads_share_storage: bool, query, key, value + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Create dq,dk,dv + # If Q/K/V come from a single QKV tensor, let's put the gradient in the + # right strides, so we can avoid a `cat` + if grads_share_storage: + chunk = torch.empty( + (*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]), + dtype=query.dtype, + device=query.device, + ) + return chunk.select(-3, 0), chunk.select(-3, 1), chunk.select(-3, 2) + return torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) + + @torch.library.custom_op( + "xformers_flash3::flash_bwd", mutates_args=(), device_types=["cuda"] + ) + def mha_bwd( + grads_share_storage: bool, + dout: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float, + is_causal: bool, + ) -> Tuple[torch.Tensor, torch.Tensor, 
torch.Tensor]: + dq, dk, dv = _create_dq_dk_dv(grads_share_storage, query, key, value) + is_deterministic = False + if cu_seqlens_q is None: + assert cu_seqlens_k is None + dq, dk, dv, softmax_d, *rest = _C_flashattention3.bwd( + dout, + query, + key, + value, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + is_causal, + is_deterministic, + ) + else: + dq, dk, dv, softmax_d, *rest = _C_flashattention3.varlen_bwd( + dout, + query, + key, + value, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + is_causal, + is_deterministic, + ) + return dq, dk, dv + + @torch.library.register_fake("xformers_flash3::flash_bwd") + def mha_bwd_fake( + grads_share_storage: bool, + dout: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float, + is_causal: bool, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + dq = torch.empty_like(query) + dk = torch.empty_like(key) + dv = torch.empty_like(value) + return dq, dk, dv + + +@register_operator +class FwOp(AttentionFwOpBase): + """Operator that computes memory-efficient attention using \ + `Flash-Attention `_ \ + implementation. 
+ """ + + OPERATOR = get_operator("xformers_flash3", "flash_fwd") + SUPPORTED_DEVICES: Set[str] = {"cuda"} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (9, 0) + SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16} + SUPPORTED_MAX_K = 256 + SUPPORTED_MIN_K = 64 + SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = ( + type(None), + LowerTriangularMask, + LowerTriangularFromBottomRightMask, + BlockDiagonalMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalFromBottomRightMask, + BlockDiagonalCausalWithOffsetGappyKeysMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalGappyKeysMask, + BlockDiagonalPaddedKeysMask, + ) + + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_DIFFERENT_VALUE_EMBED = False + SUPPORTS_BMGHK = True + SUPPORTS_PARTIAL = True + UNPADDED_LSE = True + NAME = f"fa3F@{FLASH_VERSION}" + VERSION = FLASH_VERSION + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(FwOp, cls).not_supported_reasons(d) + check_lastdim_alignment_stride1(reasons, "query", d.query, 8) + if d.query.shape[-1] not in [64, 128, 256]: + reasons.append("only head-dim 64,128,256 is supported") + + _check_needs_no_topleft(d, reasons) + + return reasons + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + + original_query_shape = inp.query.shape + out_shape = [ + *inp.query.shape[:-1], + inp.value.shape[-1], + ] + ( + inp, + cu_seqlens_q, + max_seqlen_q, + cu_seqlens_k, + max_seqlen_k, + seqused_k, + ) = _convert_input_format(inp, supports_mqa=True) + + if inp.query.numel() > 0 and inp.key.numel() > 0: + (out, softmax_lse,) = cls.OPERATOR( + inp.query, + inp.key, + inp.value, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + max_seqlen_q, + max_seqlen_k, + inp.p, + inp.scale_float, + _is_causal(inp.attn_bias), + ) + out = out.reshape(out_shape) + else: + out = torch.zeros( + inp.query.shape, device=inp.query.device, dtype=inp.query.dtype + ) + 
softmax_lse = torch.empty( + [inp.query.shape[0], inp.query.shape[2], inp.query.shape[1]], + device=inp.query.device, + dtype=torch.float32, + ) + ctx = Context( + out=out, + lse=softmax_lse, + ) + + if not needs_gradient: + return out, None + ctx = Context( + out=out, + lse=_post_process_lse( + softmax_lse, inp, tuple(original_query_shape), varlen_lse_packed=True + ), + ) + return (out, ctx) + + +@register_operator +class BwOp(AttentionBwOpBase): + __doc__ = FwOp.__doc__ + + OPERATOR = get_operator("xformers_flash3", "flash_bwd") + SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES + CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY + SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES + SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K + SUPPORTED_MIN_K = FwOp.SUPPORTED_MIN_K + SUPPORTED_ATTN_BIAS_TYPES = ( + # Exclude padded or gappy masks, since seqused_k is not supported by the kernel. + type(None), + LowerTriangularMask, + LowerTriangularFromBottomRightMask, + BlockDiagonalMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalFromBottomRightMask, + ) + + SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT + SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE + SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED + IS_DETERMINISTIC = False + SUPPORTS_BMGHK = False + SUPPORTS_LSE_FORMATS: Sequence[str] = ["", "varlen_flat"] + NAME = f"fa3B@{FLASH_VERSION}" + VERSION = FLASH_VERSION + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(BwOp, cls).not_supported_reasons(d) + check_lastdim_alignment_stride1(reasons, "query", d.query, 8) + _check_needs_no_topleft(d, reasons) + if d.query.shape[-1] not in [64, 128]: + reasons.append("only head-dim 64 or 128 is supported") + + _check_needs_no_topleft(d, reasons) + return reasons + + @classmethod + def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: + + dq_shape, dk_shape, dv_shape = inp.query.shape, inp.key.shape, inp.value.shape + ( + inp, + cu_seqlens_q, + 
max_seqlen_q, + cu_seqlens_k, + max_seqlen_k, + _, # seqused_k, + ) = _convert_input_format(inp, supports_mqa=False) + ctx_lse = ctx.lse + + if isinstance(inp.attn_bias, VARLEN_BIASES): + assert ctx_lse.shape[0] == 1 + ctx_lse = ctx_lse[0] + else: + # NOTE: cutlass pads the last dimension, we need to slice it + assert ctx_lse.shape[2] >= max_seqlen_q + ctx_lse = ctx_lse[:, :, :max_seqlen_q].contiguous() + + kernel_out_shape = [ + *inp.query.shape[:-1], + inp.value.shape[-1], + ] + assert grad.dtype in cls.SUPPORTED_DTYPES + + if inp.query.numel() and inp.key.numel(): + dq, dk, dv = cls.OPERATOR( + ctx.qkv_share_storage, + grad.reshape(kernel_out_shape).contiguous(), + inp.query, + inp.key, + inp.value, + ctx.out.reshape(kernel_out_shape), + ctx.lse, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale=inp.scale_float, + is_causal=_is_causal(inp.attn_bias), + ) + grads = Gradients(dq, dk, dv) + else: + grads = Gradients( + dq=torch.zeros_like(inp.query), + dk=torch.zeros_like(inp.key), + dv=torch.zeros_like(inp.value), + ) + + grads.dq = grads.dq.reshape(dq_shape) + grads.dk = grads.dk.reshape(dk_shape) + grads.dv = grads.dv.reshape(dv_shape) + return grads diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/torch_attention_compat.py b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/torch_attention_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e7704afc3bcbf251f1b84d8b2d1406197b2006 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/torch_attention_compat.py @@ -0,0 +1,133 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import sys + +import torch +from torch._C import parse_schema + +try: + # This function was added in https://github.com/pytorch/pytorch/pull/131894 + # (which hadn't landed yet at the time of writing), thus will only arrive in + # PyTorch 2.5+. In the meantime we need a fallback. + from torch.modules.cuda import is_flash_attention_available +except ImportError: + + def is_flash_attention_available(): + return sys.platform == "linux" + + +def is_pt_cutlass_compatible(force: bool) -> bool: + compatible = True + + fwd_schema_str = ( + "aten::_efficient_attention_forward(Tensor query, Tensor key, Tensor value, " + "Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, " + "SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, " + "float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> " + "(Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, " + "SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)" + ) + expected_fwd_schema = parse_schema(fwd_schema_str) + + current_schema = torch.ops.aten._efficient_attention_forward.default._schema + if not current_schema.is_backward_compatible_with(expected_fwd_schema): + compatible = False + + if force: + raise ImportError( + f"Current Torch CUTLASS doesnt have a compatible aten::_efficient_attention_forward schema\n" + f"EXPECTED:\n{expected_fwd_schema}\n" + f"but GOT:\n{current_schema}" + ) + + bwd_schema_str = ( + "aten::_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, " + "Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, " + "SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, " + "int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None, " + "int? 
window_size=None, bool shared_storage_dqdkdv=False) -> (Tensor, Tensor, Tensor, Tensor)" + ) + + expected_bwd_schema = parse_schema(bwd_schema_str) + + current_schema = torch.ops.aten._efficient_attention_backward.default._schema + if not current_schema.is_backward_compatible_with(expected_bwd_schema): + compatible = False + + if force: + raise ImportError( + f"Current Torch CUTLASS doesnt have a compatible aten::_efficient_attention_backward schema\n" + f"EXPECTED:\n{expected_bwd_schema}\n" + f"but GOT:\n{current_schema}" + ) + + return compatible + + +def is_pt_flash_compatible(force: bool) -> bool: + if not is_flash_attention_available(): + if force: + raise ImportError("Flash SDP backend is disabled") + return False + + if not hasattr(torch.nn, "attention") or not hasattr( + torch.nn.attention, "_get_flash_version" + ): + if force: + raise ImportError( + f"Current Torch {torch.__version__} doesnt implement " + "torch.nn.attention._get_flash_version()" + ) + return False + + FLASH_VERSION = torch.nn.attention._get_flash_version() + + compatible = True + + fwd_schema_str = ( + "aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, " + "Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, " + "bool is_causal, bool return_debug_mask, *, float? scale=None, " + "SymInt? window_size_left=None, SymInt? window_size_right=None, " + "Tensor? seqused_k=None, Tensor? 
alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, " + "Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)" + ) + expected_fwd_schema = parse_schema(fwd_schema_str) + + current_schema = torch.ops.aten._flash_attention_forward.default._schema + if not current_schema.is_backward_compatible_with(expected_fwd_schema): + compatible = False + + if force: + raise ImportError( + f"Current Torch with Flash-Attention {FLASH_VERSION} doesnt have " + "a compatible aten::_flash_attention_forward schema\n" + f"EXPECTED:\n{expected_fwd_schema}\n" + f"but GOT:\n{current_schema}" + ) + + bwd_schema_str = ( + "aten::_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, " + "Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, " + "float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None, " + "SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)" + ) + + expected_bwd_schema = parse_schema(bwd_schema_str) + + current_schema = torch.ops.aten._flash_attention_backward.default._schema + if not current_schema.is_backward_compatible_with(expected_bwd_schema): + compatible = False + + if force: + raise ImportError( + f"Current Torch with Flash-Attention {FLASH_VERSION} doesnt have " + "a compatible aten::_flash_attention_backward schema\n" + f"EXPECTED:\n{expected_bwd_schema}\n" + f"but GOT:\n{current_schema}" + ) + + return compatible diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/fmha/triton_splitk.py b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/triton_splitk.py new file mode 100644 index 0000000000000000000000000000000000000000..9261b38a05b416775c6b8a61ef8c3982d7b490fa --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/fmha/triton_splitk.py @@ -0,0 +1,1015 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import sys +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) + +import torch + +from ... import _is_triton_available +from ..common import register_operator +from .attn_bias import ( + BlockDiagonalCausalWithOffsetGappyKeysMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalGappyKeysMask, + BlockDiagonalPaddedKeysMask, + PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, + PagedBlockDiagonalGappyKeysMask, + PagedBlockDiagonalPaddedKeysMask, +) +from .common import AttentionFwOpBase, Context, Inputs, check_lastdim_alignment_stride1 + + +def _strides(x: Optional[torch.Tensor], *stride_names: str): + if x is None: + return {f"stride_{name}": None for name in stride_names} + assert x.ndim == len(stride_names) + return {f"stride_{name}": s for name, s in zip(stride_names, x.stride())} + + +def _is_supported_causal_bias(attn_bias: Any) -> bool: + return isinstance( + attn_bias, + ( + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalCausalWithOffsetGappyKeysMask, + PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, + ), + ) + + +def _is_supported_gappy_bias(attn_bias: Any) -> bool: + return isinstance( + attn_bias, + ( + BlockDiagonalGappyKeysMask, + PagedBlockDiagonalGappyKeysMask, + ), + ) + + +def _is_supported_paged_bias(attn_bias: Any) -> bool: + return isinstance( + attn_bias, + ( + PagedBlockDiagonalGappyKeysMask, + PagedBlockDiagonalPaddedKeysMask, + ), + ) + + +@dataclass +class InputsFp8(Inputs): + """ + Each of k/v_fp8_scales is an int32 tensor of shape (1, B * Mkv, Hq), + or (1, page_size * max_pages_per_lane, Hq) in the paged case. + Each int32 element contains two packed fp16 number + - scales and shifts for row-wise FP8 quantization. 
+ """ + + k_fp8_scale_shift: Optional[torch.Tensor] = None + v_fp8_scale_shift: Optional[torch.Tensor] = None + + @property + def nbytes(self) -> int: + """ + Number of bytes in the input, not counting the attention bias. + """ + return ( + super(InputsFp8, self).nbytes + + ( + self.k_fp8_scale_shift.untyped_storage().nbytes() + if self.k_fp8_scale_shift is not None + else 0 + ) + + ( + self.v_fp8_scale_shift.untyped_storage().nbytes() + if self.v_fp8_scale_shift is not None + else 0 + ) + ) + + +if TYPE_CHECKING or _is_triton_available(): + from ._triton.splitk_kernels import _fwd_kernel_splitK, _splitK_reduce +else: + _fwd_kernel_splitK = None + _splitK_reduce = None + + +def _is_cuda() -> bool: + return torch.version.cuda is not None + + +def _is_cuda_at_least_sm80(device: torch.device) -> bool: + return _is_cuda() and torch.cuda.get_device_capability(device) >= ( + 8, + 0, + ) + + +@register_operator +class FwOp(AttentionFwOpBase): + """Flash-Attention with Split-K. Supports fused int4 and fp8 K/V quantization. + Quantized path will be taken if input K/V have type int32. + + Int4 quantization can be row-wise or group-wise (when cls.NUM_GROUPS > 1) along + the last dimension of K and V. Currently 1, 2, 4, or 8 groups per row are supported. + Quantization coefficients (scale and shift) are represented as two + float16 constants per group, packed into int32. Quantization coefficients of + all groups are placed at the beginning of the row. So, if unquantized K/V have head + dimension D, the quantized versions have head dimension D // 8 + NUM_GROUPS + and dtype int32. + Pseudocode for dequantizing one row can look like: + group_size = D // 8 + for i in range(NUM_GROUPS): + group_start = NUM_GROUPS + i * group_size + group_quant = K[..., group_start: group_start + group_size] + scale, shift = unpack_int32_into_float16x2(group_quant[0]) + group_dequant = group_quant[..., 1:] * scale + shift + ... + + For fp8 only row-wise quantization is supported. 
To use it, provide input of type + xformers.ops.fmha.triton_splitk.InputsFp8 (instead of the usual xformers.ops.fmha.Inputs) to + xformers.ops.fmha.triton_splitk.FwOp.apply or xformers.ops.fmha._memory_efficient_attention_forward. + + This op uses Paged Attention when bias is one of the Paged* classes. + In this case bias has additional fields: + - block_tables of shape [batch_size, max_num_pages] + - K/V of shape [1, max_num_pages * page_size, num_heads, head_dim] + or [1, max_num_pages * page_size, num_groups, num_heads, head_dim] + + The shape which the kernel takes the queries and the output + is quite different from the user interface. There are three + types of input (a) no bias / tensor bias, (b) variable q_len + (which is only for non causal) and (c) other bias objects. + From the interface to the kernel the following changes happen. + + (0) In all cases, a group dimension may need to be added. + + (1) For (c), a batch dimension is created, reshaping from (1, B*Mq, G, Hq, K) + to (B, Mq, G, Hq, K) + + (2) For (a) and (c), in the case of multiquery (i.e. the head dimension + of keys and values is expanded), the head-swapping trick + reshaping from (B, Mq, G, Hq, K) to (B, M=Hq*Mq, G, H=1, K) + + (3) For (b), in the case of multiquery, the head-swapping trick + trick, reshaping from (1, Mq, G, Hq, K) to (1, Mq*Hq, G, H=1, K) + Note here that Mq is a single long dimension which spans all the queries + in the batch, unlike in case (C). Also that Hq has to run faster than + Mq in order that the queries in a batch element remain evenly spaced. + + In all cases, the shape as seen by the kernel is called (Bqq, Mqq, G, H, K). + The kernel operates on B batch elements and M queries per batch element. + """ + + OPERATOR = True + SUPPORTED_DEVICES = {"cuda"} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) + SUPPORTED_DTYPES = { + torch.half, + torch.bfloat16, + } # Those are dtypes of Q. 
In the quantized case K/V has dtype int32 + SUPPORTED_MAX_K = 512 + SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = ( + type(None), + torch.Tensor, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalGappyKeysMask, + BlockDiagonalCausalWithOffsetGappyKeysMask, + BlockDiagonalPaddedKeysMask, + PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, + PagedBlockDiagonalGappyKeysMask, + PagedBlockDiagonalPaddedKeysMask, + ) + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_BMGHK = True + SUPPORTS_OUTPUT_DTYPE = True + SUPPORTS_PARTIAL = True + NAME = "triton_splitKF" + + SPLIT_K: Optional[int] = None + MAX_BLOCK_M = 32 + + # Whether blocks attending to no part of a variable sequence length + # should exit early. This requires extra kernels to run beforehand + # to initialise the outputs. + # TODO: avoid these by making the reduce kernel work out it doesn't need + # to look at the irrelevant places. + SPLIT_K_EARLY_EXIT: bool = False + + # Perform kernel-level Triton autotune + AUTOTUNE = False + + NUM_GROUPS = 1 # Default quantization is row-wise + NUM_GROUPS_VALUES = [1, 2, 4, 8] + + # Values below are used when autotune=False. + # Note that under certain conditions different values might be used, see the code just before the kernel launch. + BLOCK_M: int = 16 # When M > 1, different BLOCK_M can be used. + BLOCK_N: int = 64 + # On AMD or for M > 1 different NUM_STAGES and NUM_WARPS can be used. 
+ NUM_STAGES: int = 1 + NUM_WARPS: int = 2 + + @classmethod + def shape_not_supported_reasons( + cls, Mq: int, Mkv: int, K: int, Kv: int + ) -> List[str]: + reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv) + if K not in {16, 32, 64, 128, 256, 512}: + reasons.append(f"Embed dim {K} not supported") + if Mkv == 0: + # Other ops support this; but here, triton compilation + # crashes on A100 + reasons.append("Query length is 0") + return reasons + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(FwOp, cls).not_supported_reasons(d) + if (sys.version_info.major, sys.version_info.minor) < (3, 9): + reasons.append("triton_splitk requires python 3.9 or above!") + check_lastdim_alignment_stride1(reasons, "query", d.query, 8) + if d.key.dtype != torch.int32: + check_lastdim_alignment_stride1(reasons, "key", d.key, 8) + check_lastdim_alignment_stride1(reasons, "value", d.value, 8) + if cls.OPERATOR is None: + reasons.append("triton is not available") + if d.device.type == "cuda": + # Has only been tested on 8.0 / 9.0. + if _is_cuda() and not _is_cuda_at_least_sm80(d.device): + reasons.append( + "requires NVidia GPU with sm80 minimum compute capacity, e.g., A100/H100/L4" + ) + # TODO: AMD GPU support matrix needs to be figured out. MI300X is tested to work. 
+ + q_len = d.query.shape[1] + is_block_diagonal = isinstance( + d.attn_bias, (BlockDiagonalPaddedKeysMask, BlockDiagonalGappyKeysMask) + ) + is_paged = _is_supported_paged_bias(d.attn_bias) + is_causal = _is_supported_causal_bias(d.attn_bias) + if is_block_diagonal or is_paged: + seqinfo = d.attn_bias.q_seqinfo # type: ignore + if q_len != seqinfo.seqstart_py[-1]: + reasons.append( + f"Expected total {seqinfo.seqstart_py[-1]} queries not {q_len}" + ) + q_len = seqinfo.max_seqlen + if is_causal and q_len != seqinfo.min_seqlen: + reasons.append("Variable query len is not supported for causal masks.") + if q_len > 16 and is_causal: + # 16 is the minimum BLOCK_M which gets used + # XXX I don't really understand why this is needed. + reasons.append( + "Query length should not be larger than 16 for causal attention biases" + ) + + if is_paged: + page_size = d.attn_bias.page_size # type: ignore + if d.key.shape[1] % page_size: + reasons.append( + "For paged attention, key.shape[1] should be divisible " + "by the page size, " + f"but got {d.key.shape[1]=}, {page_size=}." + ) + if cls.AUTOTUNE: + reasons.append("Paged attention doesn't support autotuning yet.") + if page_size % cls.BLOCK_N: + reasons.append( + "For paged attention, page size should be divisible " + "by the block size, " + f"but got {page_size=}, {cls.BLOCK_N=}." + ) + + if isinstance(d.attn_bias, torch.Tensor): + if d.attn_bias.ndim not in (4, 5): + reasons.append( + "Additive attention bias has to have shape (B, G, H, Mq, Mkv) " + f"or (B, H, Mq, Mkv), but got {d.attn_bias.shape}." + ) + if cls.SPLIT_K is not None and cls.SPLIT_K > 1: + reasons.append( + "Additive attention bias is not supported with split-k > 1." 
+ ) + + return reasons + + @classmethod + def get_split_k(cls, B: int, G: int, H: int, Mk: int, Mq: int) -> int: + """Heuristic for the number of splits""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + if torch.version.hip: + split_k = max(Mk + bh - 1, 1024) // bh + max_chunk_size = 64 + split_k_stop_val = 1024 / (B * G * H) + while split_k > 1 and Mk / (split_k - 1) < max_chunk_size: + split_k = split_k - 1 + + while split_k > split_k_stop_val: + split_k = split_k // 2 + + split_size = (Mk + split_k - 1) // split_k + + chunk_size = split_size // max_chunk_size * max_chunk_size + if chunk_size < split_size: + split_k += 1 + + split_k_upper_bound = 512 + else: + if Mq > 1 and B * G * H > 64: + return 1 + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128 + split_k_stop_val = Mk / max_chunk_size + split_k_upper_bound = 64 + + while split_k > split_k_stop_val: + split_k = split_k // 2 + + split_k = min(split_k, split_k_upper_bound) + split_k = max(split_k, 1) + + return split_k + + @classmethod + def get_kernel(cls): + from ._triton.splitk_kernels import ( + _fwd_kernel_splitK_autotune, + _get_splitk_kernel, + ) + + if cls.AUTOTUNE: + return _fwd_kernel_splitK_autotune[cls.NUM_GROUPS] + else: + return _get_splitk_kernel(cls.NUM_GROUPS) + + @classmethod + def get_fp8_scale_shift( + cls, inp: Inputs + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + if not hasattr(inp, "k_fp8_scale_shift"): + return None, None + inp_ = cast(InputsFp8, inp) + k_fp8_scale_shift = inp_.k_fp8_scale_shift + v_fp8_scale_shift = inp_.v_fp8_scale_shift + assert k_fp8_scale_shift is not None + assert v_fp8_scale_shift is not None + if k_fp8_scale_shift.ndim == 3: + return k_fp8_scale_shift.unsqueeze(2), v_fp8_scale_shift.unsqueeze(2) + if k_fp8_scale_shift.ndim == 4: + return k_fp8_scale_shift, v_fp8_scale_shift + raise ValueError( + "FP8 scales have to be provided in BMH or BMGH format, " + f"but got {k_fp8_scale_shift.shape=}" + ) + + 
@classmethod + def apply( + cls, + inp: Inputs, + needs_gradient: bool, + ) -> Tuple[torch.Tensor, Optional[Context]]: + """ + Note that inp can be of type InputsFp8, in which case K/V are assumed to be row-wise FP8-quantized. + This is different from int4 quantization, where coefficients are kept together with the quantized + values at the beginning of each row, and inp has type Inputs. + """ + + k_fp8_scale_shift, v_fp8_scale_shift = cls.get_fp8_scale_shift(inp) + + output_dtype = inp.get_output_dtype() + if not isinstance(inp.attn_bias, torch.Tensor): + attn_bias_tensor = None + attn_bias = cast( + Optional[ + Union[ + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalGappyKeysMask, + BlockDiagonalCausalWithOffsetGappyKeysMask, + BlockDiagonalPaddedKeysMask, + PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, + PagedBlockDiagonalGappyKeysMask, + PagedBlockDiagonalPaddedKeysMask, + ] + ], + inp.attn_bias, + ) + else: + attn_bias_tensor = inp.attn_bias + attn_bias = None + + seq_len = None + seq_starts_k = None + seq_starts_q = None + seq_starts_q_multiplier = None + q, k, v = inp.get_qkv_in_bmghk() + IS_CAUSAL = False + NUM_QUERIES_CAUSAL = 1 + variable_q = False + + is_block_diagonal = isinstance(attn_bias, BlockDiagonalPaddedKeysMask) + is_gappy = _is_supported_gappy_bias(attn_bias) + is_paged = _is_supported_paged_bias(attn_bias) + if attn_bias is not None: + assert is_paged or is_block_diagonal or is_gappy + assert attn_bias.k_seqinfo.seqlen.device == inp.query.device + seq_len = attn_bias.k_seqinfo.seqlen + assert seq_len.stride(0) == 1 + if is_gappy: + seq_starts_k = attn_bias.k_seqinfo.seqstart + assert seq_starts_k.stride(0) == 1 + assert q.shape[0] == 1 + B = len(seq_len) + G, Hq, Kq = q.shape[-3:] + # force a bool because triton cannot take np.bool_ + multiple_q = bool(attn_bias.q_seqinfo.max_seqlen > 1) + IS_CAUSAL = multiple_q and _is_supported_causal_bias(attn_bias) + variable_q = multiple_q and not IS_CAUSAL + Kkv = v.shape[-1] + + if 
variable_q: + seq_starts_q = attn_bias.q_seqinfo.seqstart + seq_starts_q_multiplier = 1 + assert seq_starts_q.stride(0) == 1 + else: + q = q.view(B, -1, G, Hq, Kq) + + kv_shape = (1 if is_paged or is_gappy else B, -1, G, Hq, Kkv) + k = k.view(kv_shape) + v = v.view(kv_shape) + if k_fp8_scale_shift is not None and v_fp8_scale_shift is not None: + k_fp8_scale_shift = k_fp8_scale_shift.view(kv_shape[:-1]) + v_fp8_scale_shift = v_fp8_scale_shift.view(kv_shape[:-1]) + + Mq = q.shape[1] + NUM_QUERIES_CAUSAL = Mq + else: + B, Mq, G, Hq, Kq = q.shape + + if attn_bias_tensor is not None and attn_bias_tensor.ndim == 4: + # (B, H, Mq, Mkv) -> (B, G, H, Mq, Mkv) + attn_bias_tensor = attn_bias_tensor.unsqueeze(1) + + # In the case of MQA/GQA, we make q have sequence length (H * Mq) and only one "head". + mqa_swap_seqlen_head = False + if ( + k.shape[3] > 1 + and k.stride(3) == 0 + and v.stride(3) == 0 + and attn_bias_tensor is None + ): + mqa_swap_seqlen_head = True + if variable_q: + seq_starts_q_multiplier = Hq + assert q.shape[0] == 1 + # The idea is Hq,Mq are reshaped to (M=Mq*Hq, H=1) + q = q.permute(0, 1, 3, 2, 4).reshape(1, -1, G, 1, Kq) + else: + # This is a copy iff Mq, G and H are all > 1. 
+ # The idea is Hq,Mq are reshaped to (M=Hq*Mq, H=1) + q = q.permute(0, 3, 1, 2, 4).reshape(q.shape[0], -1, G, 1, Kq) + k = k[:, :, :, :1] + v = v[:, :, :, :1] + if k_fp8_scale_shift is not None and v_fp8_scale_shift is not None: + k_fp8_scale_shift = k_fp8_scale_shift[:, :, :, :1] + v_fp8_scale_shift = v_fp8_scale_shift[:, :, :, :1] + + if k.dtype == torch.int32: + if k_fp8_scale_shift is not None: + Lk = k.shape[-1] * 4 + PACKED_PER_VAL = 4 + else: + # Quantized K/V + PACKED_PER_VAL = 8 + Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8 + else: + Lk = k.shape[-1] + PACKED_PER_VAL = 1 + assert cls.NUM_GROUPS == 1, f"{cls.NUM_GROUPS=}" + + _, Mk, G, H, Kkv = k.shape + Bqq, Mqq, G, H, Kq = q.shape + assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}" + if variable_q: + assert attn_bias is not None + assert seq_starts_q_multiplier is not None + M = attn_bias.q_seqinfo.max_seqlen * seq_starts_q_multiplier + else: + M = Mqq + page_size = inp.attn_bias.page_size if is_paged else 0 # type: ignore + block_tables = None + kv_cache_blocks_per_row = 0 + if is_paged: + block_tables = inp.attn_bias.block_tables # type: ignore + kv_cache_blocks_per_row = block_tables.shape[1] + Mk = block_tables.shape[1] * page_size + elif attn_bias is not None: + Mk = min(Mk, attn_bias.k_seqinfo.max_seqlen) + + if cls.SPLIT_K is not None: + split_k = cls.SPLIT_K + else: + # Use heuristics + split_k = ( + cls.get_split_k(B, G, H, Mk, Mq) if attn_bias_tensor is None else 1 + ) + + # M_ceil = Mqq rounded up to a multiple of MAX_BLOCK_M + M_ceil = (Mqq + cls.MAX_BLOCK_M - 1) // cls.MAX_BLOCK_M * cls.MAX_BLOCK_M + IS_SPLITK = split_k > 1 # or cls.autotune? 
+ output_shape = (Bqq, Mq, G, Hq, Kq) + if IS_SPLITK: + o_splitk_dtype = ( + torch.float64 if output_dtype == torch.float64 else torch.float32 + ) + if cls.SPLIT_K_EARLY_EXIT: + o_splitk = torch.zeros( + [Bqq, G, H, split_k, M_ceil, Kq], + dtype=o_splitk_dtype, + device=q.device, + ) + else: + o_splitk = torch.empty( + [Bqq, G, H, split_k, M_ceil, Kq], + dtype=o_splitk_dtype, + device=q.device, + ) + else: + o_splitk = torch.empty( + [Bqq, split_k, Mqq, G, H, Kq], + dtype=output_dtype, + device=q.device, + ).permute(0, 3, 4, 1, 2, 5) + lse, lse_splitk = None, None + # LSE may need higher precision than output + output_f64_lse = output_dtype in (torch.float32, torch.float64) + if IS_SPLITK or needs_gradient: + if cls.SPLIT_K_EARLY_EXIT: + lse_splitk = torch.full( + [Bqq, G, H, split_k, Mqq], + -float("inf"), + dtype=torch.float64 + if IS_SPLITK or output_f64_lse + else torch.float32, + device=q.device, + ) + else: + lse_splitk = torch.empty( + [Bqq, G, H, split_k, Mqq], + dtype=torch.float64 + if IS_SPLITK or output_f64_lse + else torch.float32, + device=q.device, + ) + + def grid(META): + import triton + + return triton.cdiv(M, META["BLOCK_M"]), B * G * H, split_k + + split_size = (Mk + split_k - 1) // split_k + use_seq_len = seq_len is not None + + kernel = cls.get_kernel() + BLOCK_M = cls.BLOCK_M + BLOCK_N = cls.BLOCK_N + if cls.AUTOTUNE: + extra_args = {} + else: + # TODO: remove this when autotuning on AMD is working + num_warps = cls.NUM_WARPS + num_stages = cls.NUM_STAGES + if torch.version.hip: + if B == 1: + num_warps = 4 + num_stages = 1 # TODO num_stages = 0 gives better perf on AMD, but sometimes produces NaNs + BLOCK_N = 32 + elif B <= 4 and split_k <= 128: + num_warps = 2 + num_stages = 1 + BLOCK_N = 32 + elif B <= 16: + if M < 16: + num_warps = 2 + num_stages = 1 + else: + num_warps = 1 + num_stages = 1 + BLOCK_N = 32 + else: + num_warps = 1 + num_stages = 1 + BLOCK_N = 64 + else: + should_modify_warp_and_block = ( + Kkv == 128 + and Kq == 128 + and 
torch.cuda.get_device_capability() >= (8, 9) + ) + if should_modify_warp_and_block: + if Mq > 1: + num_warps = 4 + # Choose minimal round block size which covers M. + if M > 16: + BLOCK_M = 32 + if M > 32: + BLOCK_M = 64 + if M > 64: + BLOCK_M = 128 + extra_args = { + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "num_warps": num_warps, + "num_stages": num_stages, + } + kernel[grid]( + Q=q, + K=k, + V=v, + sm_scale=inp.scale_float, + Out_splitK=o_splitk, + LSE_splitk=lse_splitk, + block_tables=block_tables, + Seq_len=seq_len, + Seq_starts_k=seq_starts_k, + Seq_starts_q=seq_starts_q, + Seq_starts_q_multiplier=seq_starts_q_multiplier, + additive_bias=attn_bias_tensor, + K_fp8_scale_shift=k_fp8_scale_shift, + V_fp8_scale_shift=v_fp8_scale_shift, + **_strides(q, "qz", "qm", "qg", "qh", "qk"), + **_strides(k, "kz", "kn", "kg", "kh", "kk"), + **_strides(v, "vz", "vn", "vg", "vh", "vk"), + **_strides(o_splitk, "osk_z", "osk_g", "osk_h", "osk_s", "osk_m", "osk_k"), + **_strides(lse_splitk, "lsek_z", "lsek_g", "lsek_h", "lsek_s", "lsek_m"), + **_strides(block_tables, "blocktablesz", "blocktablesl"), + **_strides( + attn_bias_tensor, "bias_b", "bias_g", "bias_h", "bias_qm", "bias_km" + ), + **_strides( + k_fp8_scale_shift, + "k_fp8_scale_shift_z", + "k_fp8_scale_shift_n", + "k_fp8_scale_shift_g", + "k_fp8_scale_shift_h", + ), + **_strides( + v_fp8_scale_shift, + "v_fp8_scale_shift_z", + "v_fp8_scale_shift_n", + "v_fp8_scale_shift_g", + "v_fp8_scale_shift_h", + ), + kv_cache_blocks_per_row=kv_cache_blocks_per_row, + Z=B, + H=H, + G=G, + N_CTX_Q=M, + N_CTX_K=Mk, + BLOCK_N_PER_SPLIT=split_size, + BLOCK_DMODEL=Lk, + USE_SEQ_LEN=use_seq_len, + PACKED_PER_VAL=PACKED_PER_VAL, + N_GROUPS=cls.NUM_GROUPS, + IS_CAUSAL=IS_CAUSAL, + NUM_QUERIES_CAUSAL=NUM_QUERIES_CAUSAL, + IS_SPLITK=IS_SPLITK, + SPLIT_K_EARLY_EXIT=cls.SPLIT_K_EARLY_EXIT, + USE_PAGED_ATTENTION=is_paged, + PAGE_SIZE=page_size, + WRITE_LSE=IS_SPLITK or needs_gradient, + HAS_ADDITIVE_BIAS=attn_bias_tensor is not None, + 
**extra_args, + ) + if not IS_SPLITK: + out = o_splitk[:, :, :, 0] # Bqq, G, H, Mqq, Kq + if variable_q and mqa_swap_seqlen_head: + out = out.view(1, G, Mq, Hq, Kq).permute(0, 2, 1, 3, 4).contiguous() + else: + out = out.view(Bqq, G, Hq, Mq, Kq) + # This is a copy iff mqa_swap_seqlen_head and Mq, G and Hq are all > 1. + out = out.permute(0, 3, 1, 2, 4).contiguous() + if needs_gradient: + assert lse_splitk is not None + lse = lse_splitk[:, :, :, 0] # Bqq, G, H, Mqq + if variable_q and mqa_swap_seqlen_head: + lse = lse.view(1, G, Mq, Hq).permute(0, 1, 3, 2) + else: + lse = lse.view(Bqq, G, Hq, Mq) + if attn_bias is not None and not variable_q: + lse = lse.permute(1, 2, 0, 3).reshape(1, G, Hq, B * Mq) + else: + lse = None + + if inp.query.ndim == 4: + # BMGHK -> BMHK + assert G == 1 + if lse is not None: + lse = lse[:, 0] + out = out[:, :, 0] + + if lse is None: + return out, None + return out, Context(out=out, lse=lse) + + out = torch.empty(output_shape, device=q.device, dtype=output_dtype) + + # Merge attention and LSE outputs from different split-k chunks + assert lse_splitk is not None + output_lse = None + if needs_gradient: + lse_dtype = torch.float64 if output_f64_lse else torch.float32 + if attn_bias is None or variable_q: + output_lse = torch.empty( + (Bqq, G, Hq, Mq), device=q.device, dtype=lse_dtype + ) + lse = output_lse + else: + output_lse = torch.empty( + (1, G, Hq, B * Mq), device=q.device, dtype=lse_dtype + ) + lse = output_lse.view(G, Hq, B, Mq).permute(2, 0, 1, 3) + + o_splitk = o_splitk[:, :, :, :, :Mqq] + + if mqa_swap_seqlen_head: + if variable_q: + o_splitk = o_splitk.view(Bqq, G, split_k, Mq, Hq, Kq).permute( + 0, 1, 4, 2, 3, 5 + ) + lse_splitk = lse_splitk.view(Bqq, G, split_k, Mq, Hq).permute( + 0, 1, 4, 2, 3 + ) + else: + o_splitk = o_splitk.view(Bqq, G, split_k, Hq, Mq, Kq).permute( + 0, 1, 3, 2, 4, 5 + ) + lse_splitk = lse_splitk.view(Bqq, G, split_k, Hq, Mq).permute( + 0, 1, 3, 2, 4 + ) + + merge_attentions(out, lse, o_splitk, lse_splitk) 
+ + if inp.query.ndim == 4: + # BMGHK -> BMHK + assert G == 1 + out = out[:, :, 0] + if output_lse is not None: + output_lse = output_lse[:, 0] + if Mk == 0: + out.zero_() + + if attn_bias is not None and not variable_q: + out = out.view(1, B * Mq, G, Hq, Kq) + + if output_lse is None: + return out, None + + return out, Context(out=out, lse=output_lse) + + @classmethod + @functools.lru_cache + def get_operator( + cls, + splitk: int, + *, + block_m: Optional[int] = None, + block_n: Optional[int] = None, + num_warps: Optional[int] = None, + num_stages: Optional[int] = None, + split_k_early_exit: Optional[bool] = None, + ) -> Type[AttentionFwOpBase]: + kwargs = { + "NAME": f"triton_splitK{splitk}", + "SPLIT_K": splitk, + } + if block_m is not None: + kwargs["BLOCK_M"] = block_m + if block_n is not None: + kwargs["BLOCK_N"] = block_n + if num_warps is not None: + kwargs["NUM_WARPS"] = num_warps + if num_stages is not None: + kwargs["NUM_STAGES"] = num_stages + if split_k_early_exit is not None: + kwargs["SPLIT_K_EARLY_EXIT"] = split_k_early_exit + return type( + f"FwOp_S{splitk}", + (cls,), + kwargs, + ) + + +def merge_attentions( + attn_out: torch.Tensor, + lse_out: Optional[torch.Tensor], + attn_split: torch.Tensor, + lse_split: torch.Tensor, +): + import triton + + from ._triton.splitk_kernels import _splitK_reduce + + B, M, G, H, Kq = attn_out.shape + B1, G1, H1, split_k, M1, Kq1 = attn_split.shape + B2, G2, H2, split_k1, M2 = lse_split.shape + + assert ( + B == B1 == B2 + and G == G1 == G2 + and H == H1 == H2 + and M == M1 == M2 + and Kq == Kq1 + ), f"Incompatible shapes: {attn_out.shape=}, {attn_split.shape=}, {lse_split.shape=}" + assert ( + split_k == split_k1 + ), f"Incompatible shapes: {attn_split.shape=}, {lse_split.shape=}" + if lse_out is not None: + B3, G3, H3, M3 = lse_out.shape + assert ( + B == B3 and G == G3 and H == H3 and M == M3 + ), f"Incompatible shapes: {attn_out.shape=}, {lse_out.shape=}" + + num_warps = 4 if B * G * H < 32 or torch.version.hip 
else 2 + splitK_pow2 = triton.next_power_of_2(split_k) + grid = (M, B * G * H, 1) + _splitK_reduce[grid]( + attn_split, + lse_split, + attn_out, + lse_out, + split_k=split_k, + splitK_pow2=splitK_pow2, + **_strides(attn_split, "osk_z", "osk_g", "osk_h", "osk_s", "osk_m", "osk_k"), + **_strides(lse_split, "lsek_z", "lsek_g", "lsek_h", "lsek_s", "lsek_m"), + **_strides(attn_out, "oz", "om", "og", "oh", "ok"), + **_strides(lse_out, "lse_z", "lse_g", "lse_h", "lse_m"), + BLOCK_SIZE=attn_out.shape[-1], + G=G, + H=H, + WRITE_LSE=lse_out is not None, + num_warps=num_warps, + ) + + +def merge_attentions_varargs( + attn_out: torch.Tensor, + lse_out: Optional[torch.Tensor], + attn_split: Sequence[torch.Tensor], + lse_split: Sequence[torch.Tensor], +): + from xformers.triton.vararg_kernel import unroll_varargs + + from ._triton.splitk_kernels import _splitK_reduce_varargs + + kernel_args, grid = _prepare_reduce_kernel_params( + attn_out, lse_out, attn_split, lse_split + ) + reduce_kernel = unroll_varargs(_splitK_reduce_varargs, N=len(attn_split)) + reduce_kernel[grid]( + *attn_split, + *lse_split, + Out=attn_out, + LSE=lse_out, + **kernel_args, + BLOCK_SIZE=attn_out.shape[-1], + WRITE_LSE=lse_out is not None, + ) + + +def merge_attentions_varargs_backward( + attn_split: List[torch.Tensor], + lse_split: List[torch.Tensor], + attn_out: torch.Tensor, + lse_out: torch.Tensor, + grad_attn: torch.Tensor, + grad_lse: torch.Tensor, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + from xformers.triton.vararg_kernel import unroll_varargs + + from ._triton.splitk_kernels import _splitK_reduce_varargs_backward + + dattn_splitk = [torch.empty_like(x) for x in attn_split] + dlse_splitk = [torch.empty_like(x) for x in lse_split] + + kernel_args, grid = _prepare_reduce_kernel_params( + attn_out, lse_out, attn_split, lse_split, grad_attn, grad_lse + ) + + reduce_kernel_backward = unroll_varargs( + _splitK_reduce_varargs_backward, N=len(attn_split) + ) + reduce_kernel_backward[grid]( + 
*attn_split, + *lse_split, + *dattn_splitk, + *dlse_splitk, + Out=attn_out, + LSE=lse_out, + DOut=grad_attn, + DLSE=grad_lse, + **kernel_args, + BLOCK_SIZE=attn_out.shape[-1], + ) + + return dattn_splitk, dlse_splitk + + +def _prepare_reduce_kernel_params( + attn_out: torch.Tensor, + lse_out: Optional[torch.Tensor], + attn_split: Sequence[torch.Tensor], + lse_split: Sequence[torch.Tensor], + grad_attn: Optional[torch.Tensor] = None, + grad_lse: Optional[torch.Tensor] = None, +) -> Tuple[Dict[str, int], Tuple[int, int, int]]: + + B, M, G, H, Kq = attn_out.shape + B1, G1, H1, M1, Kq1 = attn_split[0].shape + B2, G2, H2, M2 = lse_split[0].shape + + assert ( + B == B1 == B2 + and G == G1 == G2 + and H == H1 == H2 + and M == M1 == M2 + and Kq == Kq1 + ), f"Incompatible shapes: {attn_out.shape=}, {attn_split[0].shape=}, {lse_split[0].shape=}" + if lse_out is not None: + B3, G3, H3, M3 = lse_out.shape + assert ( + B == B3 and G == G3 and H == H3 and M == M3 + ), f"Incompatible shapes: {attn_out.shape=}, {lse_out.shape=}" + + attn_split_strides = {} + lse_split_strides = {} + for i in range(len(attn_split)): + attn_split_strides.update( + _strides( + attn_split[i], + "osk_z" + str(i), + "osk_g" + str(i), + "osk_h" + str(i), + "osk_m" + str(i), + "osk_k" + str(i), + ) + ) + lse_split_strides.update( + _strides( + lse_split[i], + "lsek_z" + str(i), + "lsek_g" + str(i), + "lsek_h" + str(i), + "lsek_m" + str(i), + ) + ) + + num_warps = 4 if B * G * H < 32 or torch.version.hip else 2 + grid = (M, B * G * H, 1) + + kernel_args = { + "G": G, + "H": H, + "num_warps": num_warps, + **attn_split_strides, + **lse_split_strides, + } + kernel_args.update(_strides(attn_out, "oz", "om", "og", "oh", "ok")) + kernel_args.update(_strides(lse_out, "lse_z", "lse_g", "lse_h", "lse_m")) + if grad_attn is not None: + kernel_args.update(_strides(grad_attn, "doz", "dom", "dog", "doh", "dok")) + kernel_args.update(_strides(grad_lse, "dlse_z", "dlse_g", "dlse_h", "dlse_m")) + return kernel_args, grid 
+ + +FwOp_Map = { + k: FwOp.get_operator(k) for k in [1, 2, 4, 8, 16, 32, 48, 64, 72, 80, 96, 112, 128] +} +FwOp_S1 = FwOp_Map[1] +FwOp_S2 = FwOp_Map[2] +FwOp_S4 = FwOp_Map[4] +FwOp_S8 = FwOp_Map[8] +FwOp_S16 = FwOp_Map[16] +FwOp_S32 = FwOp_Map[32] +FwOp_S64 = FwOp_Map[64] +FwOp_S128 = FwOp_Map[128] diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/seqpar.py b/.venv/lib/python3.11/site-packages/xformers/ops/seqpar.py new file mode 100644 index 0000000000000000000000000000000000000000..b734911fa0e07e88d55d2c739a2d2a717c1413a4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/seqpar.py @@ -0,0 +1,359 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Callable, List, Tuple + +import torch +from torch.distributed.distributed_c10d import _resolve_process_group + +from .differentiable_collectives import ( + gather_along_first_dim, + gather_along_first_dim_async, + reduce_scatter_along_first_dim, + reduce_scatter_along_first_dim_async, +) +from .sequence_parallel_fused_ops import ( + fused_allgather_and_anything, + fused_allgather_and_linear, + fused_anything_and_reducescatter, + fused_linear_and_reducescatter, +) +from .tiled_matmul import tiled_matmul, tiled_matmul_out + + +@torch.library.custom_op( + "xformers_python::sequence_parallel_leading_matmul_fwd", + mutates_args=(), + device_types="cuda", +) +def sequence_parallel_leading_matmul_fwd( + scattered_input: torch.Tensor, + weights: List[torch.Tensor], + fuse: bool, + process_group_name: str, +) -> List[torch.Tensor]: + process_group = _resolve_process_group(process_group_name) + + if fuse: + gathered_outputs = fused_allgather_and_linear( + scattered_input, [w.t() for w in weights], group=process_group + ) + else: + gathered_input = gather_along_first_dim( + scattered_input, process_group=process_group + ) + 
(gathered_outputs,) = tiled_matmul( + [[gathered_input]], + [[w for w in weights]], + ) + return gathered_outputs + + +@torch.library.register_fake("xformers_python::sequence_parallel_leading_matmul_fwd") +def sequence_parallel_leading_matmul_fwd_fake( + scattered_input: torch.Tensor, + weights: List[torch.Tensor], + fuse: bool, + process_group_name: str, +) -> List[torch.Tensor]: + mp_size = _resolve_process_group(process_group_name).size() + return [ + scattered_input.new_empty((scattered_input.shape[0] * mp_size, w.shape[1])) + for w in weights + ] + + +@torch.library.custom_op( + "xformers_python::sequence_parallel_leading_matmul_bwd", + mutates_args=(), + device_types="cuda", +) +def sequence_parallel_leading_matmul_bwd( + scattered_input: torch.Tensor, + weights: List[torch.Tensor], + grad_gathered_outputs: List[torch.Tensor], + fuse: bool, + process_group_name: str, +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + process_group = _resolve_process_group(process_group_name) + mp_size = process_group.size() + + # torch.library.opcheck gives us gradients whose strides are zero. + # See https://github.com/pytorch/pytorch/issues/132857. 
+ grad_gathered_outputs = [ + grad_go.clone() if any(s == 0 for s in grad_go.stride()) else grad_go + for grad_go in grad_gathered_outputs + ] + + if fuse: + grad_scattered_input = torch.empty_like(scattered_input) + grad_weights = [torch.zeros_like(w) for w in weights] + + grad_gathered_outputss = [ + grad_go.tensor_split(mp_size, dim=0) for grad_go in grad_gathered_outputs + ] + + def my_si_matmul( + grad_gathered_inputs: List[torch.Tensor], + dst_rank: int, + stream_factory: Callable[[], torch.cuda.Stream], + ) -> None: + (grad_gi,) = grad_gathered_inputs + with torch.cuda.stream(stream_factory()): + tiled_matmul_out( + [[grad_gos[dst_rank] for grad_gos in grad_gathered_outputss]], + [[w.t()] for w in weights], + out=[[grad_gi]], + ) + + fused_anything_and_reducescatter( + my_si_matmul, + [grad_scattered_input], + group=process_group, + ) + + # Each pair of shards of input and grad_output accumulates into the same + # grad_weight. Thus we need to make sure that the in-place addmms are + # sequenced correctly for each of the grad_weights. 
+ events = [torch.cuda.Event() for _ in weights] + + def my_w_matmul( + gathered_inputs_shard: List[torch.Tensor], + src_rank: int, + stream_factory: Callable[[], torch.cuda.Stream], + ) -> None: + (gi_shard,) = gathered_inputs_shard + for grad_gos, grad_w, event in zip( + grad_gathered_outputss, grad_weights, events + ): + with torch.cuda.stream(stream_factory()): + event.wait() + grad_w.t().addmm_(grad_gos[src_rank].t(), gi_shard) + event.record() + + fused_allgather_and_anything( + [scattered_input], + my_w_matmul, + group=process_group, + ) + else: + gathered_input, handle = gather_along_first_dim_async( + scattered_input, process_group=process_group + ) + ((grad_gathered_input,),) = tiled_matmul( + [[grad_go for grad_go in grad_gathered_outputs]], + [[w.t()] for w in weights], + ) + if handle is not None: + handle.wait() + + grad_scattered_input, handle = reduce_scatter_along_first_dim_async( + grad_gathered_input, process_group=process_group + ) + + grad_weights_tuples = tiled_matmul( + [[grad_go.t()] for grad_go in grad_gathered_outputs], + [[gathered_input]], + ) + if handle is not None: + handle.wait() + + grad_weights = [grad_w.t() for (grad_w,) in grad_weights_tuples] + + return grad_scattered_input, grad_weights + + +@torch.library.register_fake("xformers_python::sequence_parallel_leading_matmul_bwd") +def sequence_parallel_leading_matmul_bwd_fake( + scattered_input: torch.Tensor, + weights: List[torch.Tensor], + grad_gathered_outputs: List[torch.Tensor], + fuse: bool, + process_group_name: str, +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + return (torch.empty_like(scattered_input), [torch.empty_like(w) for w in weights]) + + +def sequence_parallel_leading_matmul_setup_context(ctx, inputs, output): + scattered_input, weights, fuse, process_group_name = inputs + ctx.save_for_backward(scattered_input, *weights) + ctx.fuse = fuse + ctx.process_group_name = process_group_name + + +def sequence_parallel_leading_matmul_bwd_bridge(ctx, 
grad_gathered_outputs): + scattered_input, *weights = ctx.saved_tensors + (grad_scattered_input, grad_weights,) = sequence_parallel_leading_matmul_bwd( + scattered_input, + list(weights), + list(grad_gathered_outputs), + ctx.fuse, + ctx.process_group_name, + ) + return grad_scattered_input, grad_weights, None, None + + +torch.library.register_autograd( + "xformers_python::sequence_parallel_leading_matmul_fwd", + sequence_parallel_leading_matmul_bwd_bridge, + setup_context=sequence_parallel_leading_matmul_setup_context, +) + + +def sequence_parallel_leading_matmul( + x: torch.Tensor, + ws: List[torch.Tensor], + *, + fuse: bool, + process_group: torch.distributed.ProcessGroup, +) -> List[torch.Tensor]: + os = sequence_parallel_leading_matmul_fwd( + x.flatten(0, -2), ws, fuse, process_group.group_name + ) + return [o.view(-1, *x.shape[1:-1], w.shape[1]) for o, w in zip(os, ws)] + + +@torch.library.custom_op( + "xformers_python::sequence_parallel_trailing_matmul_fwd", + mutates_args=(), + device_types="cuda", +) +def sequence_parallel_trailing_matmul_fwd( + gathered_input: torch.Tensor, + weight: torch.Tensor, + fuse: bool, + process_group_name: str, +) -> torch.Tensor: + process_group = _resolve_process_group(process_group_name) + + if fuse: + scattered_output = fused_linear_and_reducescatter( + gathered_input, weight.t(), group=process_group + ) + else: + gathered_output = torch.matmul(gathered_input, weight) + scattered_output = reduce_scatter_along_first_dim( + gathered_output, process_group=process_group + ) + return scattered_output + + +@torch.library.register_fake("xformers_python::sequence_parallel_trailing_matmul_fwd") +def sequence_parallel_trailing_matmul_fwd_fake( + gathered_input: torch.Tensor, + weight: torch.Tensor, + fuse: bool, + process_group_name: str, +) -> torch.Tensor: + mp_size = _resolve_process_group(process_group_name).size() + return gathered_input.new_empty( + (gathered_input.shape[0] // mp_size, weight.shape[1]) + ) + + 
+@torch.library.custom_op( + "xformers_python::sequence_parallel_trailing_matmul_bwd", + mutates_args=(), + device_types="cuda", +) +def sequence_parallel_trailing_matmul_bwd( + gathered_input: torch.Tensor, + weight: torch.Tensor, + grad_scattered_output: torch.Tensor, + fuse: bool, + process_group_name: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + process_group = _resolve_process_group(process_group_name) + mp_size = process_group.size() + + # torch.library.opcheck gives us gradients whose strides are zero. + # See https://github.com/pytorch/pytorch/issues/132857. + if any(s == 0 for s in grad_scattered_output.stride()): + grad_scattered_output = grad_scattered_output.clone() + + if fuse: + grad_gathered_input = torch.empty_like(gathered_input) + grad_weight = torch.zeros_like(weight) + + gathered_inputs = gathered_input.tensor_split(mp_size, dim=0) + grad_gathered_inputs = grad_gathered_input.tensor_split(mp_size, dim=0) + + def my_gi_and_w_matmul( + grad_gathered_outputs_shard: List[torch.Tensor], + src_rank: int, + stream_factory: Callable[[], torch.cuda.Stream], + ) -> None: + (grad_go_shard,) = grad_gathered_outputs_shard + with torch.cuda.stream(stream_factory()): + torch.matmul( + grad_go_shard, weight.t(), out=grad_gathered_inputs[src_rank] + ) + with torch.cuda.stream(stream_factory()): + grad_weight.t().addmm_(grad_go_shard.t(), gathered_inputs[src_rank]) + + fused_allgather_and_anything( + [grad_scattered_output], + my_gi_and_w_matmul, + group=process_group, + ) + else: + grad_gathered_output = gather_along_first_dim( + grad_scattered_output, process_group=process_group + ) + grad_gathered_input = torch.matmul(grad_gathered_output, weight.t()) + grad_weight = torch.matmul(grad_gathered_output.t(), gathered_input).t() + + return grad_gathered_input, grad_weight + + +@torch.library.register_fake("xformers_python::sequence_parallel_trailing_matmul_bwd") +def sequence_parallel_trailing_matmul_bwd_fake( + gathered_input: torch.Tensor, + weight: 
torch.Tensor, + grad_scattered_output: torch.Tensor, + fuse: bool, + process_group_name: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + return (torch.empty_like(gathered_input), torch.empty_like(weight)) + + +def sequence_parallel_trailing_matmul_setup_context(ctx, inputs, output): + gathered_input, weight, fuse, process_group_name = inputs + ctx.save_for_backward(gathered_input, weight) + ctx.fuse = fuse + ctx.process_group_name = process_group_name + + +def sequence_parallel_trailing_matmul_bwd_bridge(ctx, grad_scattered_output): + gathered_input, weight = ctx.saved_tensors + (grad_gathered_input, grad_weight,) = sequence_parallel_trailing_matmul_bwd( + gathered_input, + weight, + grad_scattered_output, + ctx.fuse, + ctx.process_group_name, + ) + return grad_gathered_input, grad_weight, None, None + + +torch.library.register_autograd( + "xformers_python::sequence_parallel_trailing_matmul_fwd", + sequence_parallel_trailing_matmul_bwd_bridge, + setup_context=sequence_parallel_trailing_matmul_setup_context, +) + + +def sequence_parallel_trailing_matmul( + x: torch.Tensor, + w: torch.Tensor, + *, + fuse: bool, + process_group: torch.distributed.ProcessGroup, +) -> torch.Tensor: + o = sequence_parallel_trailing_matmul_fwd( + x.flatten(0, -2), w, fuse, process_group.group_name + ) + return o.view(-1, *x.shape[1:-1], w.shape[1]) diff --git a/.venv/lib/python3.11/site-packages/xformers/ops/sp24.py b/.venv/lib/python3.11/site-packages/xformers/ops/sp24.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c3da7824e53cb0efda2cdc3bba2216272a10a0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/ops/sp24.py @@ -0,0 +1,848 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import contextlib +import ctypes +import glob +import os +import time +import warnings +from functools import partial +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast + +import torch + +from .common import BaseOperator, get_operator, get_xformers_operator, register_operator + + +@register_operator +class SparsifyBothWays(BaseOperator): + OPERATOR = get_xformers_operator("sparse24_sparsify_both_ways") + OPERATOR_CATEGORY = "sp24" + NAME = "sparse24_sparsify_both_ways" + + +@register_operator +class SparsifyApply(BaseOperator): + OPERATOR = get_xformers_operator("sparse24_apply") + OPERATOR_CATEGORY = "sp24" + NAME = "sparse24_apply" + + +@register_operator +class SparsifyApplyDenseOutput(BaseOperator): + OPERATOR = get_xformers_operator("sparse24_apply_dense_output") + OPERATOR_CATEGORY = "sp24" + NAME = "sparse24_apply_dense_output" + + +@register_operator +class Sp24Gemm(BaseOperator): + OPERATOR = get_xformers_operator("_sparse24_gemm") + OPERATOR_CATEGORY = "sp24" + NAME = "_sparse24_gemm" + + +def _get_cusparselt_lib() -> Optional[str]: + libs = glob.glob( + str(Path(torch._C.__file__).parent / "lib" / "libcusparseLt*.so.0") + ) + if len(libs) != 1: + return None + return libs[0] + + +def _get_cusparselt_torch_version() -> Tuple[int, int, int]: + """ + Returns the version of the cusparselt.so library that ships with pytorch 2.2+ + """ + lib_path = _get_cusparselt_lib() + if lib_path is None: + return (0, 0, 0) + lib = ctypes.CDLL(lib_path) + + def get_version_part(version_part: int) -> int: + value = ctypes.c_int() + ret = lib.cusparseLtGetProperty(version_part, ctypes.byref(value)) + if ret != 0: + return -1 + return value.value + + return (get_version_part(0), get_version_part(1), get_version_part(2)) + + +_cusplt_version = _get_cusparselt_torch_version() +_cusplt_version_str = ".".join(str(v) for v in _cusplt_version) + + +@register_operator +class Sp24GemmCuspltSearch(BaseOperator): + OPERATOR = 
get_operator("aten", "_cslt_sparse_mm_search") + OPERATOR_CATEGORY = "sp24" + NAME = f"_cslt_sparse_mm_search@{_cusplt_version_str}" + + +@register_operator +class Sp24GemmCusplt(BaseOperator): + OPERATOR = get_operator("aten", "_cslt_sparse_mm") + OPERATOR_CATEGORY = "sp24" + NAME = f"_cslt_sparse_mm@{_cusplt_version_str}" + + +def _has_cusparseLt() -> bool: + available = _cusplt_version >= (0, 4, 0) + if not available: + return False + if _cusplt_version < (0, 5, 0): + # Version 0.5.0 has much better perf because it can fuse the + # transpose within the GEMM epilogue + warnings.warn( + f"You have cusparseLt version {_cusplt_version_str} " + f"but you get better performance with v0.5.0+ if " + f"you replace the .so file ({_get_cusparselt_lib()})" + ) + + # Sm90 added in 6.0 + compute_capability = (0, 0) + if torch.cuda.is_available(): + compute_capability = torch.cuda.get_device_capability("cuda") + if _cusplt_version < (6, 0, 0): + if compute_capability >= (9, 0): + return False + return available + + +def sparse24_pointwise_op( + func, types, args=(), kwargs=None, allow_sparsify_args_list=() +): + self = None + for tensor in args: + if isinstance(tensor, Sparse24Tensor): + self = tensor + assert self is not None + args_updated = [] + for i, tensor in enumerate(args): + if isinstance(tensor, torch.Tensor): + if not isinstance(tensor, Sparse24Tensor): + if i in allow_sparsify_args_list: + tensor = sparsify24_like(tensor, self) + else: + raise ValueError( + f"Operation {func.__module__}.{func.__name__} on Sparse24Tensor requires all operands to " + f"be Sparse24Tensors, but operand {i} is a {type(tensor)}" + ) + if ( + tensor.threads_masks is None + or self.threads_masks is None + or tensor.threads_masks.data_ptr() != self.threads_masks.data_ptr() + or tensor.threads_masks.stride() != self.threads_masks.stride() + ): + raise ValueError( + f"Operation {func.__module__}.{func.__name__} on Sparse24Tensor requires all operands to be " + "Sparse24Tensors with the same 
sparsity pattern" + ) + args_updated.append(tensor) + assert isinstance( + self, Sparse24TensorCutlass + ), "Only implemented for CUTLASS tensors" + return Sparse24TensorCutlass( + self.shape, + func( + *[(x.packed if isinstance(x, Sparse24Tensor) else x) for x in args_updated] + ), + self.meta, + func( + *[ + (x.packed_t if isinstance(x, Sparse24Tensor) else x) + for x in args_updated + ] + ), + self.meta_t, + self.threads_masks, + ) + + +def sparse24_mm(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) == 2 + A, B = args + if A.ndim != 2 or B.ndim != 2: + raise NotImplementedError( + "`Sparse24Tensor` matmul: Broadcasting is not implemented" + ) + if isinstance(A, Sparse24Tensor): + return A._mm(B) + else: + B_t = B.t() + assert isinstance(B_t, Sparse24Tensor) + return B_t._mm(A.t(), prefer_col_major_output=True).t() + + +def sparse24_addmm(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) == 3 + bias, A, B = args + if A.ndim != 2 or B.ndim != 2: + raise NotImplementedError( + "`Sparse24Tensor` matmul: Broadcasting is not implemented" + ) + if bias.ndim != 1: + raise NotImplementedError( + f"`Sparse24Tensor` matmul: only bias dim=1 supported. 
Shape={bias.shape}" + ) + if isinstance(A, Sparse24Tensor): + raise NotImplementedError( + "`Sparse24Tensor` matmul: only operand B of `addmm` can be sparse" + ) + B_t = B.t() + assert isinstance(B_t, Sparse24Tensor) + return B_t._mm(A.t(), bias=bias, prefer_col_major_output=True).t() + + +def sparse24_linear(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) in [2, 3] + A, B = args[:2] + bias = args[2] if len(args) == 3 else None + if bias is None: + return A @ B.t() + return sparse24_addmm( + func=None, + types=None, + args=[bias, A, B.t()], + ) + + +def sparse24_t(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) == 1 + self = args[0] + assert isinstance(self, Sparse24Tensor) + assert len(self.shape) == 2 + return self.__class__( + (self.shape[-1], self.shape[0]), + packed=self.packed_t, + meta=self.meta_t, + packed_t=self.packed, + meta_t=self.meta, + threads_masks=self.threads_masks.transpose(0, 1), + ) + + +def sparse24_view(func, types, args=(), kwargs=None) -> torch.Tensor: + assert len(args) == 2 + self, shape = args + if tuple(shape) != self.shape: + raise NotImplementedError( + f"`view` is not implemented for Sparse24Tensor, except for the dummy case (shape={shape})" + ) + return self + + +def sparse24_detach(func, types, args, kwargs) -> torch.Tensor: + assert len(args) == 1 + self = args[0] + return self.__class__( + shape=self.shape, + packed=self.packed, + meta=self.meta, + packed_t=self.packed_t, + meta_t=self.meta_t, + threads_masks=self.threads_masks, + requires_grad=False, + ) + + +@contextlib.contextmanager +def no_dispatch(): + guard = torch._C._DisableTorchDispatch() + try: + yield + finally: + del guard + + +def fallback_dispatcher(func, types, args, kwargs): + with no_dispatch(): + return func(*args) + + +SPARSE24_DISPATCH_CUTLASS = { + torch.ops.aten.is_same_size: fallback_dispatcher, + torch.ops.aten.detach_: fallback_dispatcher, + torch.ops.aten.detach: sparse24_detach, + torch.ops.aten.relu: 
sparse24_pointwise_op, + torch.ops.aten.gelu: sparse24_pointwise_op, + torch.ops.aten.silu: sparse24_pointwise_op, + torch.ops.aten.mul: partial( + # `mul` BW in swiglu + sparse24_pointwise_op, + allow_sparsify_args_list=( + 0, + 1, + ), + ), + torch.ops.aten.add: sparse24_pointwise_op, + # Note: for these ops, we allow the gradient to come in as a `torch.Tensor` + # and we will run the sparsification right before calling the BW aten func + torch.ops.aten.gelu_backward: partial( + sparse24_pointwise_op, allow_sparsify_args_list=(0,) + ), + torch.ops.aten.silu_backward: partial( + sparse24_pointwise_op, allow_sparsify_args_list=(0, 1) + ), + torch.ops.aten.threshold_backward: partial( # relu BW + sparse24_pointwise_op, + allow_sparsify_args_list=(0,), + ), + torch.ops.aten.mm: sparse24_mm, + torch.ops.aten.matmul: sparse24_mm, + torch.ops.aten.t: sparse24_t, + torch.ops.aten.view: sparse24_view, + torch.ops.aten.linear: sparse24_linear, +} + +SPARSE24_DISPATCH_CUSPARSELT = { + torch.ops.aten.is_same_size: fallback_dispatcher, + torch.ops.aten.detach_: fallback_dispatcher, + torch.ops.aten.detach: sparse24_detach, + torch.ops.aten.t: sparse24_t, + torch.ops.aten.view: sparse24_view, + torch.ops.aten.mm: sparse24_mm, + torch.ops.aten.matmul: sparse24_mm, + torch.ops.aten.addmm: sparse24_addmm, + torch.ops.aten.linear: sparse24_linear, +} + + +class Sparse24Tensor(torch.Tensor): + packed: torch.Tensor + meta: torch.Tensor + packed_t: torch.Tensor + meta_t: torch.Tensor + threads_masks: torch.Tensor + __slots__ = ["packed", "meta", "packed_t", "meta_t", "threads_masks"] + + # We need to update the new method here to tell PyTorch what should be + # the Tensor corresponding to the wrapper object + @staticmethod + def __new__( + cls, + shape, + packed: torch.Tensor, + meta: torch.Tensor, + packed_t: torch.Tensor, + meta_t: torch.Tensor, + threads_masks: torch.Tensor, + *, + requires_grad=False, + ): + assert isinstance(packed, torch.Tensor) + tensor = 
torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + shape, + device=packed.device, + dtype=packed.dtype, + requires_grad=requires_grad, + ) + tensor.packed = packed + tensor.meta = meta + tensor.packed_t = packed_t + tensor.meta_t = meta_t + tensor.threads_masks = threads_masks + return tensor + + def __repr__(self): + return f"{self.__class__.__name__}(shape={self.shape})" + + def _sp24_to_dense(self) -> torch.Tensor: + # Multiply by identity + # WARN: This is not efficient at all + e = torch.eye( + self.shape[1], self.shape[1], device=self.device, dtype=self.dtype + ) + return self @ e + + def _mm( + self, + B: torch.Tensor, + *, + prefer_col_major_output: bool = False, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError() + + __torch_function__ = torch._C._disabled_torch_function_impl + + def __tensor_flatten__(self): + return self.__slots__, (self.shape, self.requires_grad) + + @classmethod + def __tensor_unflatten__( + cls, inner_tensors, flatten_spec, outer_size, outer_stride + ): + shape, requires_grad = flatten_spec + return cls( + shape, + **inner_tensors, + requires_grad=requires_grad, + ) + + +class Sparse24TensorCutlass(Sparse24Tensor): + def _mm( + self, + B: torch.Tensor, + *, + bias: Optional[torch.Tensor] = None, + prefer_col_major_output: bool = False, + ) -> torch.Tensor: + if isinstance(B, Sparse24Tensor): + raise ValueError( + "`Sparse24Tensor @ Sparse24Tensor` is not supported by the hardware" + ) + if bias is not None: + raise NotImplementedError( + f"`Sparse24Tensor` with backend='{BACKEND_CUTLASS}' does not support matmul with bias. 
" + f"Remove the bias, or use backend='{BACKEND_CUSPARSELT}'" + ) + if self.ndim != 2 or B.ndim != 2: + raise NotImplementedError( + f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented" + ) + if self.shape[1] != B.shape[0]: + raise NotImplementedError( + f"`{self.__class__.__name__}` matmul: invalid shapes \ + ({self.shape[0]}, {self.shape[1]}) @ ({B.shape[0]}, {B.shape[1]})" + ) + return Sp24Gemm.OPERATOR(self.packed, B, self.meta)[: self.shape[0]] + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + if func._overloadpacket not in SPARSE24_DISPATCH_CUTLASS: + raise NotImplementedError( + f"{cls.__name__} only supports a specific set of operations, " + f"can't perform requested op ({func.__name__})" + ) + return SPARSE24_DISPATCH_CUTLASS[func._overloadpacket]( + func, types, args, kwargs + ) + + +_CUSPLT_ALG_CACHE: Dict[Tuple[int, int, int, str, torch.dtype, bool], int] = {} +_CUSPLT_TUNE = os.environ.get("XFORMERS_CUSPARSELT_TUNE", "1") == "1" + + +def _cusplt_find_alg( + shape: List[int], + packed: torch.Tensor, + B: torch.Tensor, + bias: Optional[torch.Tensor], + transpose_result: bool, +) -> int: + """ + cuSPARSELt has multiple algorithms (that correspond to different kernels) + to run a given GEMM, because the optimal kernel depends on the GEMM dimensions. + This function attempts to find the most efficient one by benchmarking all + of them. + NOTE: cuSPARSELt also provides a function to search the best algorithm + (exposed via `aten:_cslt_sparse_mm_search`) but it often fails to find the best + algorithm, so we need this workaround. 
+ """ + if not _CUSPLT_TUNE: + return 0 + M, K = shape + N = B.shape[1] + fmt = "r" + fmt += "r" if B.stride(-1) <= 1 else "c" + fmt += "c" if transpose_result else "r" + h = (M, N, K, fmt, B.dtype, bias is not None) + if h in _CUSPLT_ALG_CACHE: + return _CUSPLT_ALG_CACHE[h] + + REPEAT = 10 + TIME_ALGO = [] + for algo in range(70): + has_error = False + for i in range(REPEAT): + try: + Sp24GemmCusplt.OPERATOR( + packed, B, bias=bias, transpose_result=transpose_result, alg_id=algo + ) + except RuntimeError: + has_error = True + break + if i == 1: # 1 iteration of warmup + torch.cuda.synchronize() + t = time.monotonic() + if has_error: + break + torch.cuda.synchronize() + dt = time.monotonic() - t + TIME_ALGO.append((dt, algo)) + TIME_ALGO.sort() + _CUSPLT_ALG_CACHE[h] = TIME_ALGO[0][1] + return _CUSPLT_ALG_CACHE[h] + + +@torch.library.custom_op("xformers::_cusplt_mm", mutates_args=(), device_types=["cuda"]) +def _cusplt_mm( + shape: List[int], + packed: torch.Tensor, + B: torch.Tensor, + bias: Optional[torch.Tensor], + transpose_result: bool, +) -> torch.Tensor: + """ + This operator wraps find_algo + gemm. This is because we don't want find_algo + to be visible by torch compile, otherwise it will remove it from the graph. 
+ """ + alg_id = _cusplt_find_alg( + shape, packed, B, bias=bias, transpose_result=transpose_result + ) + return Sp24GemmCusplt.OPERATOR( + packed, B, bias=bias, transpose_result=transpose_result, alg_id=alg_id + ) + + +@torch.library.register_fake("xformers::_cusplt_mm") +def _cusplt_mm_meta( + shape: List[int], + packed: torch.Tensor, + B: torch.Tensor, + bias: Optional[torch.Tensor], + transpose_result: bool, +) -> torch.Tensor: + M, K = shape + N = B.shape[1] + if transpose_result: + return torch.empty([N, M], dtype=B.dtype, device=B.device) + return torch.empty([M, N], dtype=B.dtype, device=B.device) + + +class Sparse24TensorCuSparseLt(Sparse24Tensor): + def _mm( + self, + B: torch.Tensor, + *, + prefer_col_major_output: bool = False, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if isinstance(B, Sparse24Tensor): + raise ValueError( + "`Sparse24Tensor @ Sparse24Tensor` is not supported by the hardware" + ) + if self.ndim != 2 or B.ndim != 2: + raise NotImplementedError( + f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented" + ) + if self.shape[1] != B.shape[0]: + raise NotImplementedError( + f"`{self.__class__.__name__}` matmul: invalid shapes \ + ({self.shape[0]}, {self.shape[1]}) @ ({B.shape[0]}, {B.shape[1]})" + ) + if B.shape[1] % 8 != 0: + raise NotImplementedError( + f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`. " + "The dense matrix B should have the second dimension aligned to 8." + ) + if B.dtype != self.dtype: + raise NotImplementedError( + f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, " + f"with A.dtype={self.dtype} and B.dtype={B.dtype}. " + "This operation is only supported when A and B have the same data type." 
+ ) + if bias is not None and bias.dtype != self.dtype: + raise NotImplementedError( + f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, " + "with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. " + "This operation is only supported when A, B and C have the same data type." + ) + assert _has_cusparseLt() + out = torch.ops.xformers._cusplt_mm( + self.shape, + self.packed, + B, + bias=bias, + transpose_result=prefer_col_major_output, + ) + if prefer_col_major_output: + out = out.t() + return out[: self.shape[0]] + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + if func._overloadpacket not in SPARSE24_DISPATCH_CUSPARSELT: + raise NotImplementedError( + f"{cls.__name__} only supports a specific set of operations, " + f"can't perform requested op ({func.__name__})" + ) + return SPARSE24_DISPATCH_CUSPARSELT[func._overloadpacket]( + func, types, args, kwargs + ) + + +torch._dynamo.allow_in_graph(Sparse24TensorCuSparseLt) +torch._dynamo.allow_in_graph(Sparse24TensorCutlass) + +GRADIENT_SP24 = "24sparse" +GRADIENT_DENSE = "24dense" +GRADIENT_STE = "ste" # Straight-Through Estimator + +BACKEND_CUTLASS = "cutlass" +BACKEND_CUSPARSELT = "cusparselt" +BACKEND_DENSE = "dense" + + +def _sparsify24_forward(x: torch.Tensor, *, algo: str, backend: str) -> Sparse24Tensor: + assert backend in [ + BACKEND_CUTLASS, + BACKEND_CUSPARSELT, + ], f"Invalid backend: {backend}" + if isinstance(x, Sparse24Tensor): + if x.threads_masks is None: + raise ValueError("Input to `sparsify24` is already sparse") + return x + + (packed, meta, packed_t, meta_t, threads_masks) = SparsifyBothWays.OPERATOR( + x, algorithm=algo, backend=backend + ) + cls = ( + Sparse24TensorCutlass + if backend == BACKEND_CUTLASS + else Sparse24TensorCuSparseLt + ) + return cls( + x.shape, + packed=packed, + meta=meta, + packed_t=packed_t, + meta_t=meta_t, + threads_masks=threads_masks, + requires_grad=False, + ) + + +class 
_Sparsify24Func(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, algo: str, gradient: str, backend: str): # type: ignore[override] + if gradient not in [GRADIENT_SP24, GRADIENT_DENSE, GRADIENT_STE]: + raise ValueError( + f"Invalid gradient type: '{gradient}'. " + f"Expected '{GRADIENT_SP24}' or '{GRADIENT_DENSE}' or '{GRADIENT_STE}" + ) + out = _sparsify24_forward(x, algo=algo, backend=backend) + ctx.threads_masks = out.threads_masks + ctx.meta = out.meta + ctx.meta_t = out.meta_t + ctx.dtype = out.dtype + ctx.gradient = gradient + return out + + @staticmethod + def backward(ctx, grad_out: torch.Tensor): # type: ignore[override] + if isinstance(grad_out, Sparse24Tensor) or ctx.gradient == GRADIENT_STE: + return grad_out, None, None, None + assert not isinstance(grad_out, Sparse24Tensor) + assert grad_out.dtype == ctx.dtype + if ctx.gradient == GRADIENT_SP24: + packed, _, packed_t, _ = SparsifyApply.OPERATOR(grad_out, ctx.threads_masks) + grad_in: torch.Tensor = Sparse24TensorCutlass( + grad_out.shape, + packed, + ctx.meta, + packed_t, + ctx.meta_t, + ctx.threads_masks, + requires_grad=grad_out.requires_grad, + ) + elif ctx.gradient == GRADIENT_DENSE: + assert ctx.threads_masks.is_contiguous() + grad_in = SparsifyApplyDenseOutput.OPERATOR(grad_out, ctx.threads_masks) + else: + assert False, f"Unsupported gradient type: {ctx.gradient}" + return ( + grad_in, + None, + None, + None, + ) + + +class _Sparsify24STEFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: torch.Tensor, + algo: str, + backend: str, + bw_mul0: float, + bw_mul1: float, + ): # type: ignore[override] + out = _sparsify24_forward(x, algo=algo, backend=backend) + ctx.threads_masks = out.threads_masks + ctx.bw_mul0 = bw_mul0 + ctx.bw_mul1 = bw_mul1 + return out + + @staticmethod + def backward(ctx, grad_out: torch.Tensor): # type: ignore[override] + assert not isinstance(grad_out, Sparse24Tensor) + if ctx.bw_mul0 == 1.0 and ctx.bw_mul1 == 1.0: + grad_in 
= grad_out + else: + grad_in = SparsifyApplyDenseOutput.OPERATOR( + grad_out, ctx.threads_masks, mul0=ctx.bw_mul0, mul1=ctx.bw_mul1 + ) + return ( + grad_in, + None, # algo + None, # backend + None, # bw_mul0 + None, # bw_mul1 + ) + + +class _Sparsify24LikeFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, pattern: Sparse24Tensor, gradient: str, backend: str): # type: ignore[override] + if not isinstance(pattern, Sparse24Tensor): + raise NotImplementedError( + "`sparsify24_like`: `pattern` must be a sparse tensor" + ) + if not pattern.threads_masks.is_contiguous(): + raise NotImplementedError( + "`sparsify24_like` is not implemented when `pattern` is transposed" + ) + if gradient not in [GRADIENT_DENSE, GRADIENT_SP24, GRADIENT_STE]: + raise ValueError(f'`sparsify24_like`: invalid gradient type "{gradient}"') + ctx.threads_masks = pattern.threads_masks + ctx.meta = pattern.meta + ctx.meta_t = pattern.meta_t + ctx.dtype = pattern.dtype + ctx.gradient = gradient + if backend == BACKEND_DENSE: + assert ctx.threads_masks.is_contiguous() + return SparsifyApplyDenseOutput.OPERATOR(x, ctx.threads_masks) + packed, meta, packed_t, meta_t = SparsifyApply.OPERATOR( + x, ctx.threads_masks, backend=backend + ) + if backend == BACKEND_CUTLASS: + return Sparse24TensorCutlass( + x.shape, + packed, + ctx.meta, + packed_t, + ctx.meta_t, + ctx.threads_masks, + requires_grad=x.requires_grad, + ) + assert backend == BACKEND_CUSPARSELT, f"Invalid backend: {backend}" + meta.copy_(pattern.meta) + meta_t.copy_(pattern.meta_t) + return Sparse24TensorCuSparseLt( + x.shape, + packed, + meta, + packed_t, + meta_t, + ctx.threads_masks, + requires_grad=x.requires_grad, + ) + + @staticmethod + def backward(ctx, grad_out: torch.Tensor): # type: ignore[override] + if ctx.gradient == GRADIENT_STE or isinstance(grad_out, Sparse24Tensor): + return grad_out, None, None, None + assert not isinstance(grad_out, Sparse24Tensor) + assert grad_out.dtype == ctx.dtype + + if 
ctx.gradient == GRADIENT_DENSE: + assert ctx.threads_masks.is_contiguous() + return ( + SparsifyApplyDenseOutput.OPERATOR(grad_out, ctx.threads_masks), + None, + None, + None, + ) + assert ctx.gradient == GRADIENT_SP24 + + packed, _, packed_t, _ = SparsifyApply.OPERATOR( + grad_out, ctx.threads_masks, backend=BACKEND_CUTLASS + ) + return ( + Sparse24TensorCutlass( + grad_out.shape, + packed, + ctx.meta, + packed_t, + ctx.meta_t, + ctx.threads_masks, + requires_grad=grad_out.requires_grad, + ), + None, + None, + None, + ) + + +# We want to use `torch._dynamo.allow_in_graph` as a decorator +# (see https://fburl.com/workplace/uimiz0mf) but it breaks mypy. +# This is a hack to work around this +F = TypeVar("F", bound=Callable[..., Any]) + + +def allow_in_graph(func: F) -> F: + return cast(F, torch._dynamo.allow_in_graph(func)) + + +@allow_in_graph +def sparsify24( + x: torch.Tensor, + algo: str = "", + gradient: str = GRADIENT_SP24, + backend: str = BACKEND_CUTLASS, +) -> Sparse24Tensor: + return _Sparsify24Func.apply(x, algo, gradient, backend) + + +@allow_in_graph +def sparsify24_ste( + x: torch.Tensor, + algo: str = "", + backend: str = BACKEND_CUTLASS, + bw_mul0: float = 1.0, + bw_mul1: float = 1.0, +) -> Sparse24Tensor: + """ + 2:4 sparsification, with Straight Through Estimator for the + backward pass (eg the gradient is *not* sparsified). + Optionally, `bw_mul[0-1]` provide the option to rescale the gradient + differently for pruned (`bw_mul0`) and kept values (`bw_mul1`). 
+ """ + return _Sparsify24STEFunc.apply(x, algo, backend, bw_mul0, bw_mul1) + + +@allow_in_graph +def sparsify24_like( + x: torch.Tensor, + pattern: torch.Tensor, + gradient: str = GRADIENT_SP24, + backend: str = "", + out_dense: Optional[bool] = None, # <-- TODO: Deprecate this in favor of "gradient" +) -> Sparse24Tensor: + if out_dense is not None and out_dense: + backend = BACKEND_DENSE + if backend == "": + backend = ( + BACKEND_CUSPARSELT + if isinstance(pattern, Sparse24TensorCuSparseLt) + else BACKEND_CUTLASS + ) + if not isinstance(pattern, Sparse24Tensor): + raise ValueError( + f"`pattern` must be a `Sparse24Tensor` but got a {type(pattern)}" + ) + # Handle transposed case + if not pattern.threads_masks.is_contiguous(): + return _Sparsify24LikeFunc.apply(x.t(), pattern.t(), gradient, backend).t() + return _Sparsify24LikeFunc.apply(x, pattern, gradient, backend)