diff --git a/.venv/lib/python3.11/site-packages/xformers/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a082c929a3e699c8c078daad9a819175ac81c4e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/__pycache__/_cpp_lib.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/__pycache__/_cpp_lib.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2605b4aa8ed9767bde63614ed51365daa9774189 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/__pycache__/_cpp_lib.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/__pycache__/_deprecation_warning.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/__pycache__/_deprecation_warning.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d67392847052cfd9306dcabdbafdcee57c7dcde Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/__pycache__/_deprecation_warning.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/__pycache__/attn_bias_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/__pycache__/attn_bias_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cee4763790d8f9432f2541f943b3d3e47e80295 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/__pycache__/attn_bias_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/__pycache__/checkpoint.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/__pycache__/checkpoint.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27108d38ec9c30aff8d3490bf1f1ca99db9c76b4 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/xformers/__pycache__/checkpoint.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/__pycache__/info.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/__pycache__/info.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cae094c2ff1ba15a170a15cbe6acd1ec3fd86a0f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/__pycache__/info.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/__pycache__/test.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/__pycache__/test.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46b515a926b804eaabe23d0540a78ec6a4df80d5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/__pycache__/test.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..515005abbbdf30bd657cc11afba24ec39787c3f3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/__pycache__/version.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/__pycache__/version.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3978240860ca0ec57fbeca17ad0f0a0aa96c0528 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/__pycache__/version.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_mem_eff_attention.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_mem_eff_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..2c7418e9b7de71a63afb6ea1a8b2ead75df31d4e --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_mem_eff_attention.py @@ -0,0 +1,373 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import itertools +import random +from functools import partial + +import torch +from torch.utils import benchmark + +import xformers.ops +import xformers.ops.fmha as fmha +from xformers.attn_bias_utils import create_attn_bias, ref_attention +from xformers.benchmarks.utils import benchmark_main_helper, create_argparser + +torch.backends.cuda.matmul.allow_tf32 = False + +min_run_time = 0.5 +device = torch.device("cuda") + +NUM_THREADS = [1] if device.type == "cuda" else [1, 40] +VISION_SHAPES = [ + # ViT + (384, 197, 1, 88), + (384, 197, 1, 80), + (384, 197, 1, 64), + (1024, 197, 1, 88), + (1024, 197, 1, 80), + (1024, 197, 1, 64), + # ViT-Huge + (32 * 16, 197, 1, 80), + (32, 197, 16, 80), + (32, 197, 16, 64), + (32, 197, 16, 128), + # ViT-Giant + (16 * 16, 197, 1, 88), + (16, 197, 16, 88), + (16, 197, 16, 64), + (16, 197, 16, 128), + # FB models + (1024, 82, 8, 64), + (150, 256, 16, 64), + (64, 256, 12, 64), + # Stable diffusion (https://github.com/huggingface/diffusers/pull/532) + (1, 4096, 16, 40), # 512x512 + (1, 16384, 16, 40), # 1024x1024 + (1, 4096, 16, 80), + (1, 16384, 16, 80), + # + bs4 + (4, 4096, 16, 40), + (4, 16384, 16, 40), + (4, 4096, 16, 80), + (4, 16384, 16, 80), + # ParlAI model + (256, 4096, 16, 64), + # Zetta B M H K + (8, 2048, 20, 128), +] + +LLM_SHAPES = [ + # LLaMa 70b - mp=8/16 + *sorted(itertools.product([1, 2], [2048, 4096, 8192], [4, 8], [128])), + *sorted( + itertools.product([16], [128, 512, 1024], [16], [16, 32, 64, 128, 160, 256]) + ), +] + + +OPS = [ + (xformers.ops.fmha.cutlass.FwOp, xformers.ops.fmha.cutlass.BwOp), + (xformers.ops.fmha.flash.FwOp, xformers.ops.fmha.flash.BwOp), + (xformers.ops.fmha.flash3.FwOp, 
xformers.ops.fmha.flash3.BwOp), + (xformers.ops.fmha.ck.FwOp, xformers.ops.fmha.ck.BwOp), +] + + +def product_dict(**kwargs): + keys = kwargs.keys() + vals = kwargs.values() + for instance in itertools.product(*vals): + yield dict(zip(keys, instance)) + + +VISION_CASES, LLM_CASES = [ + list( + product_dict( + shape_q=SHAPES, + num_threads=NUM_THREADS, + dropout_p=[0.0], + attn_bias_cfg=[(type(None), False)], + dtype=[torch.half], + ) + ) + for SHAPES in (VISION_SHAPES, LLM_SHAPES) +] + +# Add more cases with some variations +for c in VISION_CASES.copy(): + c = c.copy() + c.update( + random.Random(str(c["shape_q"])).choice( + [ + {"dropout_p": 0.3}, + {"attn_bias_cfg": (torch.Tensor, False)}, + {"attn_bias_cfg": (torch.Tensor, True)}, + {"dtype": torch.bfloat16}, + {"dtype": torch.float}, + ] + ) + ) + VISION_CASES.append(c) + + +LLM_CASE_UPDATES = [ + {"attn_bias_cfg": (torch.Tensor, True)}, + {"attn_bias_cfg": (xformers.ops.LowerTriangularMask, False)}, + *[ + { + "attn_bias_cfg": ( + xformers.ops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask, + False, + ), + "Hkv": Hkv, + "dtype": torch.bfloat16, + } + for Hkv in [1, 2] + ], +] + +for c in LLM_CASES.copy(): + for update in LLM_CASE_UPDATES: + c = c.copy() + c.update(update) + LLM_CASES.append(c) + +CASES = VISION_CASES + LLM_CASES + + +def create_tensors(shape_q, Hkv, dtype, requires_grad=False, packed=True): + stacked_shape = list(shape_q) # B, M, H, K + Hq = shape_q[2] + stacked_dim = 2 if packed else 0 + stacked_shape.insert(stacked_dim, 3) + qkv = torch.rand( + stacked_shape, device=device, dtype=dtype, requires_grad=requires_grad + ) + q = torch.rand(shape_q, device=device, dtype=dtype, requires_grad=requires_grad) + shape_kv = (shape_q[0], shape_q[1], Hkv, shape_q[3]) + k = ( + torch.rand(shape_kv, device=device, dtype=dtype, requires_grad=requires_grad) + .reshape(shape_q[0], shape_q[1], 1, Hkv, shape_q[3]) + .expand(shape_q[0], shape_q[1], Hq // Hkv, Hkv, shape_q[3]) + .reshape(shape_q) + ) + 
v = ( + torch.rand(shape_kv, device=device, dtype=dtype, requires_grad=requires_grad) + .reshape(shape_q[0], shape_q[1], 1, Hkv, shape_q[3]) + .expand(shape_q[0], shape_q[1], Hq // Hkv, Hkv, shape_q[3]) + .reshape(shape_q) + ) + + return qkv, q, k, v + + +def mem_eff_attention_fw( + shape_q, + num_threads: int, + attn_bias_cfg, + dropout_p, + dtype, + packed=True, + Hkv=None, +): + B, M, Hq, K = shape_q + Hkv = Hkv or Hq + _, q, k, v = create_tensors( + shape_q, + Hkv, + dtype, + requires_grad=False, + packed=packed, + ) + attn_bias_type, attn_bias_requires_grad = attn_bias_cfg + if attn_bias_requires_grad: + return + + dtype_str = { + torch.bfloat16: "b16", + torch.half: "f16", + torch.float: "f32", + }[dtype] + sub_label = ( + f"{dtype_str} {B}-{M}-{Hq}-{Hkv}-{K}, p={dropout_p}, " + f"BiasT={attn_bias_type.__name__}" + ) + + has_run = False + for fw_op, bw_op in OPS: + bias = create_attn_bias( + attn_bias_type, + batch_size=B, + num_heads=Hq, + num_heads_groups=Hq // Hkv, + q_len=M, + kv_len=M, + dtype=dtype, + device=device, + requires_grad=attn_bias_requires_grad, + fmt="BMHK", + op=fw_op, + ) + inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias, p=dropout_p) + if isinstance( + bias, + ( + fmha.attn_bias.BlockDiagonalMask, + fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask, + ), + ): + q, k, v = [x.reshape([1, -1, *x.shape[2:]]) for x in [q, k, v]] + if not fw_op.supports(inp): + continue + + yield benchmark.Timer( + stmt="fn(q, k, v, attn_bias, p)", + globals={ + "q": q, + "k": k, + "v": v, + "attn_bias": inp.attn_bias, + "p": dropout_p, + "fn": partial( + xformers.ops.memory_efficient_attention, op=(fw_op, bw_op) + ), + }, + label=f"attention (attn_bias={attn_bias_type})", + description=fw_op.NAME, + sub_label=sub_label, + num_threads=num_threads, + ) + has_run = True + + if not has_run: + return + + yield benchmark.Timer( + stmt="fn(q, k, v, attn_bias, p)", + globals={ + "q": q, + "k": k, + "v": v, + "attn_bias": inp.attn_bias, + "p": 
dropout_p, + "fn": ref_attention, + }, + label=f"attention (attn_bias={attn_bias_type})", + description="eager", + sub_label=sub_label, + num_threads=num_threads, + ) + + +def mem_eff_attention_bw( + shape_q, num_threads: int, attn_bias_cfg, dropout_p, dtype, Hkv=None +): + B, M, Hq, K = shape_q + Hkv = Hkv or Hq + _, q, k, v = create_tensors( + shape_q, + Hkv, + dtype, + requires_grad=True, + ) + + attn_bias_type, attn_bias_requires_grad = attn_bias_cfg + + dtype_str = { + torch.bfloat16: "b16", + torch.half: "f16", + torch.float: "f32", + }[dtype] + sub_label = ( + f"{dtype_str} {B}-{M}-{Hq}-{Hkv}-{K}, p={dropout_p}, " + f"BiasT={attn_bias_type.__name__}, BiasGrad={attn_bias_requires_grad}" + ) + + has_run = False + for fw_op, bw_op in OPS: + bias = create_attn_bias( + attn_bias_type, + batch_size=B, + num_heads=Hq, + num_heads_groups=Hq // Hkv, + q_len=M, + kv_len=M, + dtype=dtype, + device=device, + requires_grad=attn_bias_requires_grad, + fmt="BMHK", + op=bw_op, + ) + inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias, p=dropout_p) + + if not fw_op.supports(inp) or not bw_op.supports(inp): + continue + has_run = True + out = xformers.ops.memory_efficient_attention( + inp.query, inp.key, inp.value, inp.attn_bias, inp.p, op=(fw_op, bw_op) + ) + grad_benchmark = torch.ones_like(q) + + yield benchmark.Timer( + stmt="out.backward(grad, retain_graph=True)", + globals={ + "out": out, + "grad": grad_benchmark, + }, + label=f"attention backward (attn_bias={attn_bias_type})", + description=bw_op.NAME, + sub_label=sub_label, + num_threads=num_threads, + ) + del out + + if not has_run: + return + yield benchmark.Timer( + stmt="out.backward(grad, retain_graph=True)", + globals={ + "out": ref_attention(q, k, v, inp.attn_bias, dropout_p), + "grad": grad_benchmark, + }, + label=f"attention backward (attn_bias={attn_bias_type})", + description="vanilla", + sub_label=sub_label, + num_threads=num_threads, + ) + + +def main(): + arg_parser = create_argparser() + 
arg_parser.add_argument( + "--omit-forward", + action="store_true", + help="Do not run forward benchmarks", + ) + arg_parser.add_argument( + "--omit-backward", + action="store_true", + help="Do not run backward benchmarks", + ) + args = arg_parser.parse_args() + if not args.omit_forward: + benchmark_main_helper( + mem_eff_attention_fw, + CASES, + arg_parser=arg_parser, + min_run_time=min_run_time, + ) + if not args.omit_backward: + benchmark_main_helper( + mem_eff_attention_bw, + CASES, + arg_parser=arg_parser, + min_run_time=min_run_time, + ) + + +if __name__ == "__main__": + main() diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_sp24.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_sp24.py new file mode 100644 index 0000000000000000000000000000000000000000..4e891161696ecf8c63441dac2b39a5371188634c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_sp24.py @@ -0,0 +1,178 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from utils import DTYPE2STR, benchmark_main_helper2, product_dict + +import xformers.ops as xops + +min_run_time = 0.5 +device = torch.device("cuda") + +CASES = list( + product_dict( + B_in_hidden_out_ft=[ + (2048 * 8, 2048, 2048 * 3, 2048), + (2048, 5120, 5120 * 3, 5120), # 13b + (1024, 8192, 8192 * 3, 8192), # 30b + (2048, 8192, 8192 * 3, 8192), # 30b + (2048 * 2, 8192, 8192 * 3, 8192), # 30b + # DINO ViT-L: lg + sm crops (patch16) + (64 * 2 * (14 * 14 + 1) + 64 * 8 * (6 * 6 + 1), 1024, 1024 * 4, 1024), + # DINO ViT-g: lg + sm crops (patch16) + ( + 12 * 2 * (16 * 16 + 1 + 11) + 12 * 8 * (7 * 7 + 1 + 11), + 1536, + 1536 * 4, + 1536, + ), + ], + dtype=[torch.half], + bias=[False], + ) +) + + +class Mlp(nn.Module): + LINEAR_CLS = nn.Linear + + def __init__( + self, B_in_hidden_out_ft: Tuple[int, int, int, int], dtype, bias: bool, bw: bool + ) -> None: + B, in_ft, hid_ft, out_ft = B_in_hidden_out_ft + super().__init__() + self.label = "mlp" + self.sub_label = ( + f"{DTYPE2STR[dtype]} ({B},{in_ft},{hid_ft},{out_ft}){' b' if bias else ''}" + ) + self.fc1 = self.LINEAR_CLS(in_ft, hid_ft, bias=bias) + self.act = nn.GELU() + self.fc2 = self.LINEAR_CLS(hid_ft, out_ft, bias=bias) + self.grad = torch.randn([B, out_ft], device="cuda", dtype=dtype) + self.input = torch.randn( + [B, in_ft], device="cuda", dtype=dtype, requires_grad=True + ) + self.out = self.input + self.to("cuda").to(dtype) + + def fw(self): + x = self.input + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + self.out = x + + def bw(self): + self.out.backward(self.grad, retain_graph=True) + + +class MlpDenseMask(Mlp): + def fw(self): + x = self.input + x = self.fc1(x) + + mask = torch.ops.xformers.sparse24_largest_mask_2d(x) + x = mask * x + + x = self.act(x) + x = self.fc2(x) + self.out = x + + +class MlpAct24(Mlp): + def fw(self): + x = self.input + x = self.fc1(x) + + x = xops.sparsify24(x) + + x = 
self.act(x) + x = self.fc2(x) + self.out = x + + +class LinearW24(torch.nn.Linear): + def forward(self, input: torch.Tensor) -> torch.Tensor: + w_sparse = xops.sparsify24( + self.weight, + gradient="24dense", + backend="cusparselt", + ) + return F.linear(input, w_sparse, self.bias) + + +class MlpW24(Mlp): + LINEAR_CLS = LinearW24 + + +class MicrobenchmarkBase: + def __init__( + self, B_in_hidden_out_ft: Tuple[int, int, int, int], dtype, bias: bool, bw: bool + ) -> None: + B, in_ft, hid_ft, out_ft = B_in_hidden_out_ft + super().__init__() + self.label = "mlp" + self.sub_label = ( + f"{DTYPE2STR[dtype]} ({B},{in_ft},{hid_ft},{out_ft}){' b' if bias else ''}" + ) + self.input = torch.randn( + [B, in_ft], device="cuda", dtype=dtype, requires_grad=True + ) + self.input_colMajor = self.input.t().contiguous().t() + self.input_sp = xops.sparsify24(self.input) + + def bw(self) -> None: + return None + + +class MicrobenchmarkSparsify24(MicrobenchmarkBase): + def fw(self) -> torch.Tensor: + xops.sparsify24(self.input) + return self.input + + +class MicrobenchmarkSp24ApplyDense(MicrobenchmarkBase): + def fw(self) -> torch.Tensor: + xops.sparsify24_like(self.input, pattern=self.input_sp, out_dense=True) + return self.input + + +class MicrobenchmarkSp24ApplyDenseT(MicrobenchmarkBase): + def fw(self) -> torch.Tensor: + xops.sparsify24_like(self.input_colMajor, pattern=self.input_sp, out_dense=True) + return self.input + + +class MicrobenchmarkInputClone(MicrobenchmarkBase): + def fw(self) -> torch.Tensor: + self.input.clone() + return self.input + + +functions = { + "act24": MlpAct24, + "dense": Mlp, + "w24": MlpW24, + "s24_inp_sparsify24": MicrobenchmarkSparsify24, + "s24_inp_apply_dense": MicrobenchmarkSp24ApplyDense, + "s24_inp_apply_dense_t": MicrobenchmarkSp24ApplyDenseT, + "s24_inp_clone": MicrobenchmarkInputClone, +} +benchmark_main_helper2( + "sp24_fw", fw=True, cases=CASES, functions=functions, min_run_time=min_run_time +) +benchmark_main_helper2( + "sp24_fwbw", + 
fw=True, + bw=True, + cases=CASES, + functions=functions, + min_run_time=min_run_time, +) diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__init__.py b/.venv/lib/python3.11/site-packages/xformers/components/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b19d99efb5e86088690c5785e0858ecda29c57b8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/attention/__init__.py @@ -0,0 +1,124 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from pathlib import Path +from typing import Any, Callable, Dict, Set, Union + +import torch + +from xformers.utils import ( + generate_matching_config, + get_registry_decorator, + import_all_modules, +) + +from ._sputnik_sparse import SparseCS +from .attention_mask import AttentionMask +from .base import Attention, AttentionConfig # noqa + +logger = logging.getLogger("xformers") + + +# CREDITS: Classy Vision registry mechanism + +ATTENTION_REGISTRY: Dict[str, Any] = {} +ATTENTION_CLASS_NAMES: Set[str] = set() + +# Arbitrary threshold for now, +# in between dense and sparse matrix algorithms for the attention mechanism +_DENSITY_THRESHOLD = 0.30 # noqa # from the sputnik paper, vs. +_USE_SPUTNIK = True + + +def build_attention(config: Union[Dict[str, Any], AttentionConfig]): + """Builds an attention from a config. + + This assumes a 'name' key in the config which is used to determine what + attention class to instantiate. 
For instance, a config `{"name": "my_attention", + "foo": "bar"}` will find a class that was registered as "my_attention" + (see :func:`register_attention`) and call .from_config on it.""" + + if not isinstance(config, AttentionConfig): + try: + config_instance = generate_matching_config( + config, ATTENTION_REGISTRY[config["name"]].config + ) + except KeyError as e: + name = config["name"] + logger.warning(f"{name} not available among {ATTENTION_REGISTRY.keys()}") + raise e + else: + config_instance = config + + return ATTENTION_REGISTRY[config_instance.name].constructor.from_config( + config_instance + ) + + +"""Registers an Attention subclass. + + This decorator allows xFormers to instantiate a subclass of Attention + from a configuration file, even if the class itself is not part of the + xFormers library. To use it, apply this decorator to an Attention + subclass, like this: + + .. code-block:: python + + @dataclass + class MyConfig: + ... + + @register_attention('my_attention', MyConfig) + class MyAttention(Attention): + ... 
+ + To instantiate an attention from a configuration file, see :func:`build_attention`.""" +register_attention: Callable[[str, Any], Callable[[Any], Any]] = get_registry_decorator( + ATTENTION_REGISTRY, ATTENTION_CLASS_NAMES, Attention, AttentionConfig +) + + +def maybe_sparsify(matrix) -> Any: + # Sparsify if that makes sense + if torch.count_nonzero(matrix).item() / matrix.numel() > _DENSITY_THRESHOLD: + # If not sparse, then AttentionMask is the reference type + return AttentionMask.from_bool(matrix) + + return sparsify(matrix) + + +def sparsify(matrix): + if _USE_SPUTNIK: + return SparseCS(matrix) + return matrix.to_sparse() + + +from .favor import FavorAttention # noqa +from .global_tokens import GlobalAttention # noqa +from .linformer import LinformerAttention # noqa +from .local import LocalAttention # noqa +from .nystrom import NystromAttention # noqa +from .ortho import OrthoFormerAttention # noqa +from .random import RandomAttention # noqa +from .scaled_dot_product import ScaledDotProduct # noqa + +__all__ = [ + "ScaledDotProduct", + "LocalAttention", + "LinformerAttention", + "NystromAttention", + "RandomAttention", + "OrthoFormerAttention", + "GlobalAttention", + "FavorAttention", + "Attention", + "AttentionMask", + "build_attention", + "register_attention", +] + +# automatically import any Python files in the directory +import_all_modules(str(Path(__file__).parent), "xformers.components.attention") diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72fff0b3bd4cf30992046ba417a136183667a1a6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/__init__.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/_sputnik_sparse.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/_sputnik_sparse.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02d7cb8ee2f0117d959e4ea719cf6790d971d407 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/_sputnik_sparse.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/attention_mask.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/attention_mask.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ca42e520f84028596d942135987401be9d58c95 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/attention_mask.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/attention_patterns.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/attention_patterns.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb490af28888551b7f0f8d55f967ba7772b61102 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/attention_patterns.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc34956ce6fd4163310c519e3fe23379e2ab5145 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/base.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/compositional.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/compositional.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d188532cf96f9e3782ef9c9555238bedd6f3c459 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/compositional.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/core.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a8d48648133c17c5acde9dc98ab3e30aa143ed4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/core.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/favor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/favor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da37e29a64b1b47341ee3430ac3f8debf1cfe8d8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/favor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/fourier_mix.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/fourier_mix.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7360e6ca61250dbcce60dd180d2edf150bb1be1d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/fourier_mix.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/global_tokens.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/global_tokens.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dcb07ff4c1fb2d53bf245d22c3eac12c2c943b7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/global_tokens.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/lambda_layer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/lambda_layer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..279db90fe9446051ed1178c4ac5839a034707ed8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/lambda_layer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/linformer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/linformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddf275384a51d9d62f0e41db3fc8e2128efa1b58 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/linformer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/local.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/local.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..652e89f54a92c7600641a7ab7ff954f4e97860f4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/local.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/nystrom.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/nystrom.cpython-311.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..10a19c3b678a057f41b35b20aa116f27d91d30e0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/nystrom.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/ortho.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/ortho.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71d437681e57baa913b21698106abee938bfdb2b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/ortho.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/pooling.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/pooling.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0800eb7b97e1d16e1c5bd1c9b3eb0bdd86b1906c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/pooling.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/random.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/random.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eaedf31d8f56108bcf9946092fcefb712cf44ac3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/random.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/scaled_dot_product.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/scaled_dot_product.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9df16d355b9af8c8b49bcfa08137511062956e3b Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/scaled_dot_product.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/sparsity_config.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/sparsity_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d712a503fb0f78fe01c11d086edfccaa567135ac Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/sparsity_config.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..735d817f2f12b894ab998666df9639c7875b4054 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/visual.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/visual.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..166fddd9925c06605fcbbe84d57fb7317985c615 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/visual.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/_sputnik_sparse.py b/.venv/lib/python3.11/site-packages/xformers/components/attention/_sputnik_sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..d6b92d75cd81511256377dc493215b66acb25364 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/attention/_sputnik_sparse.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import torch + +from xformers.ops import masked_matmul +from xformers.sparse import SparseCSRTensor + +# TODO: this is here for BC +from xformers.sparse.utils import _csr_to_coo, _dense_to_sparse # noqa: F401 + + +class SparseCS: + def __init__(self, matrix, device=None): + if device is None: + device = torch.device("cpu") + if matrix.ndim == 2: + matrix = matrix[None] + assert matrix.ndim == 3 + self._mat = SparseCSRTensor.from_dense(matrix).to(device) + + @property + def device(self): + return self._mat.device + + @property + def ndim(self): + return self._mat.ndim + + @property + def dtype(self): + return self._mat.dtype + + @property + def is_sparse(self): + return True + + @property + def shape(self): + return self._mat.shape[1:] + + @property + def values(self): + return self._mat.values() + + @property + def row_indices(self): + return self._mat._csr_row_indices + + @property + def column_indices(self): + return self._mat._csr_column_indices + + @property + def row_offsets(self): + return self._mat._csr_row_offsets + + @property + def _transp_info(self): + return self._mat._csr_transp_info + + @classmethod + def wrap( + cls, shape, values, row_indices, row_offsets, column_indices, _transp_info + ): + matrix = cls.__new__(cls) + _shape = (values.shape[0],) + shape + csr_matrix = SparseCSRTensor._wrap( + _shape, values, row_indices, row_offsets, column_indices, _transp_info + ) + matrix._mat = csr_matrix + return matrix + + @classmethod + def _wrap(cls, csr_matrix): + assert isinstance(csr_matrix, SparseCSRTensor) + matrix = cls.__new__(cls) + matrix._mat = csr_matrix + return matrix + + def __mul__(self, other): + assert isinstance(other, (int, float)) + return type(self)._wrap(self._mat * other) + + def __add__(self, other): + assert isinstance(other, type(self)) + return type(self)._wrap(self._mat + 
Self = TypeVar("Self", bound="AttentionMask")


class AttentionMask:
    """
    Holds an attention mask, along with a couple of helpers and attributes.

    .. note: this is an additive mask, meaning that coefficients which should be computed hold the '0.' value,
        and coefficients which should be skipped hold the '-inf' value. Any other value is possible if the purpose
        is to bias the attention computation for instance

    .. note: the attention mask dimensions are expected to be `[batch, to_sequence, from_sequence]`,
        `[to_sequence, from_sequence]`, or anything broadcastable in between

    Attributes:
        values: the additive mask tensor; 2-D inputs are stored with a leading singleton dimension
        is_causal: flag supplied by the caller (or set by `make_causal`), it is not checked against the values
        seq_len: size of dimension 1 of `values`
        to_seq_len: size of dimension 2 of `values`
    """

    def __init__(self, additive_mask: torch.Tensor, is_causal: bool = False):
        assert additive_mask.is_floating_point(), additive_mask.dtype
        assert not additive_mask.requires_grad

        if additive_mask.ndim == 2:
            additive_mask = additive_mask.unsqueeze(0)

        self.values = additive_mask
        self.is_causal = is_causal
        self.seq_len = additive_mask.shape[1]
        # Dimension 0 is the (possibly singleton) batch dimension. The second sequence
        # length lives in dimension 2, which is what `make_causal` and `make_crop` use.
        self.to_seq_len = additive_mask.shape[2]

    def to_bool(self) -> torch.Tensor:
        """
        Return a boolean tensor which is True wherever the mask is not '-inf'.

        .. warning: we assume here that True implies that the value should be computed
        """
        return self.values != float("-inf")

    @staticmethod
    def _additive_from_pattern(keep: torch.Tensor) -> torch.Tensor:
        """Build a float additive mask: 0. where `keep` is True, '-inf' elsewhere."""
        additive_mask = torch.zeros_like(keep, dtype=torch.float)
        additive_mask.masked_fill_(~keep, float("-inf"))
        return additive_mask

    @classmethod
    def from_bool(cls: Type[Self], x: torch.Tensor) -> Self:
        """
        Create an AttentionMask given a boolean pattern.

        .. warning: we assume here that True implies that the value should be computed
        """
        assert x.dtype == torch.bool

        return cls(cls._additive_from_pattern(x))

    @classmethod
    def from_multiplicative(cls: Type[Self], x: torch.Tensor) -> Self:
        """
        Create an AttentionMask given a multiplicative attention mask.

        Non-zero coefficients are kept (additive value 0.), zero coefficients are
        masked out (additive value '-inf'). Boolean inputs must go through `from_bool`.
        """
        assert not x.dtype == torch.bool

        return cls(cls._additive_from_pattern(x.bool()))

    @classmethod
    def make_causal(
        cls: Type[Self],
        seq_len: int,
        to_seq_len: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Self:
        """
        Create a causal mask of shape `[1, seq_len, to_seq_len]`.

        Entries strictly above the main diagonal hold '-inf', all the others hold 0.
        `to_seq_len` defaults to `seq_len` when it is None (or 0).
        """
        if not to_seq_len:
            to_seq_len = seq_len

        # triu zeroes everything on and below the diagonal, and keeps the '-inf' above it
        additive_mask = torch.triu(
            torch.ones(seq_len, to_seq_len, device=device, dtype=dtype) * float("-inf"),
            diagonal=1,
        )
        return cls(additive_mask=additive_mask, is_causal=True)

    def make_crop(
        self, seq_len: int, to_seq_len: Optional[int] = None
    ) -> "AttentionMask":
        """
        Return a cropped attention mask, whose underlying tensor is a view of this one
        """

        if not to_seq_len:
            to_seq_len = seq_len

        return AttentionMask(
            self.values[:, :seq_len, :to_seq_len], is_causal=self.is_causal
        )

    def __repr__(self):
        return f"AttentionMask - causal {self.is_causal} - mask " + str(self.values)

    @property
    def device(self):
        return self.values.device

    @property
    def is_sparse(self):
        return False

    @property
    def ndim(self):
        return len(self.values.shape)

    @property
    def dtype(self):
        return self.values.dtype

    @property
    def shape(self):
        return self.values.shape

    def __add__(self, other):
        # The sum of two masks is never flagged as causal, even if both operands are
        return AttentionMask(self.values + other.values, is_causal=False)

    def to(
        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
    ) -> "AttentionMask":
        """
        Return this mask on the requested device and/or with the requested dtype.

        At least one of `device` / `dtype` must be given. `self` is returned when
        nothing needs to change, otherwise a new AttentionMask is created.
        """
        assert device is None or isinstance(device, torch.device)
        assert dtype is None or isinstance(dtype, torch.dtype)
        assert device is not None or dtype is not None

        # Noop if we don't need to create another instance
        same_device = device is None or device == self.device
        same_dtype = dtype is None or dtype == self.dtype
        if same_device and same_dtype:
            return self

        return AttentionMask(self.values.to(device=device, dtype=dtype), self.is_causal)
# Define the common interface, every attention block needs to derive from it
class Attention(nn.Module, metaclass=ABCMeta):
    r"""Abstract base shared by all attention mechanisms.

    It only carries a set of capability flags, a config-based constructor and a
    padding helper; the actual computation is left to the subclass `forward`.
    """

    # Class-level default, nothing is built here
    _causal_mask: Optional[AttentionMask] = None

    @abstractmethod
    def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
        super().__init__()
        deprecated_function(self)

        # --- What this attention expects from its caller ---
        # Inputs need to be projected beforehand
        self.requires_input_projection = True
        # The head dimension must be explicit (False: it can be folded into the batch)
        self.requires_head_dimension = False
        # Key padding mask and attention mask are passed separately, not merged
        self.requires_separate_masks = False
        # K and Q must share the same sequence length
        self.requires_same_k_q_dimensions = False
        # The attention handles single/multi head itself, the MHA wrapper should skip it
        self.requires_skip_multi_head = False
        # The context length needs to be a square, often due to 2D pooling
        self.requires_squared_context = False

        # --- Which masks this attention can consume ---
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

    @classmethod
    def from_config(cls: Type[Self], config: AttentionConfig) -> Self:
        """Instantiate the attention from a config dataclass.

        Fields whose value is None are dropped so that the constructor defaults apply.
        """
        provided = {
            key: value for key, value in asdict(config).items() if value is not None
        }
        return cls(**provided)

    @abstractmethod
    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        raise NotImplementedError

    @staticmethod
    def _maybe_pad_sequence(x: torch.Tensor, mask: torch.Tensor):
        """
        Zero-pad the end of the second-to-last dimension of `x` so that it matches
        the last dimension of `mask`. `x` is returned untouched if they already agree.
        """
        seq_len = x.shape[-2]
        mask_len = mask.shape[-1]

        if seq_len == mask_len:
            return x

        assert seq_len < mask_len, (
            "Sequence is bigger than the provided mask, cannot infer what to do with it."
            " Please update your attention mask"
        )

        # Pairs are (before, after), listed from the last dimension backwards:
        # only the end of the second-to-last dimension grows.
        padding = (0, 0, 0, mask_len - seq_len, 0, 0)
        return torch.nn.functional.pad(x, padding, mode="constant", value=0.0)
def _either_or(a: Optional[int], b: int) -> int:
    """Return `a`, unless it is None in which case `b` is used as the fallback.

    Only None triggers the fallback: a falsy value such as 0 is returned as-is.
    """
    if a is None:
        return b
    return a
+ + Contrary to the original attention proposal, which does not consider interactions in between heads, + the compositional attention will consider all possible interactions and softmax over that dimension, + so that the information retrieved covers the most relevant dimensions. The number of heads and rules to + use is thus typically smaller than for a comparable traditional Transformer, and asking for the same number of heads + may not fit in memory. + + Args: + dim_model: dimension of the incoming latent space + num_heads: number of heads *for the search operation* + dim_attn: dimension (embedding) of the attention + num_rules: number of rules to consider *for the retrieval operation* + dim_selection: dimension of the scoring/selection space for the retrievals + dim_key, dim_value: dimensions of K and V, if different from Q + dropout: attention dropout probability + qk_rule: QK product will drive the retrieval process + nonlinear: use a non linear method to score the retrievals + bias: use bias in the initial projection step + causal: causal computations (attend to the past only) + + _"Compositional Attention: Disentangling search and retrieval": https://arxiv.org/pdf/2110.09419v1.pdf + """ + + def __init__( + self, + dim_model: int, + num_heads: int, + dim_attn: Optional[int] = None, + num_rules: Optional[int] = None, + dim_selection: Optional[int] = None, + dim_key: Optional[int] = None, + dim_value: Optional[int] = None, + dropout=0.0, + qk_rule=False, + nonlinear=False, + q_compose=False, + in_proj_container: Optional[InputProjection] = None, + use_separate_proj_weight: Optional[bool] = False, + bias=True, + causal=False, + *_, + **__, + ): + super().__init__() + + # Define the inherited flags + self.requires_skip_multi_head = ( + True # This attention owns the multi-head mechanism + ) + + # Handle defaults / undefined values + self.dim_model = dim_model + num_rules = _either_or(num_rules, num_heads) + dim_selection = _either_or(dim_selection, dim_model // 
num_heads) + + # All the initial definition plumbing + dim_attn = _either_or(dim_attn, dim_model) + dim_key = _either_or(dim_key, dim_model) + dim_value = _either_or(dim_value, dim_model) + + self.in_proj_container = ( + in_proj_container + if in_proj_container is not None + else InputProjection( + query_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias), + key_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias) + if use_separate_proj_weight + else None, + value_proj_params=InputProjectionConfig(dim_model, dim_value, bias=bias) + if use_separate_proj_weight + else None, + ) + ) + + self.num_heads = num_heads + self.num_rules = num_rules + self.qk_rule = qk_rule + self.dim_selection = dim_selection + self.nonlinear = nonlinear + self.q_compose = q_compose + + self.dropout_module = nn.Dropout(dropout) + self.dim_head = dim_model // num_heads + self.value_dim = dim_attn // num_rules + + assert ( + self.value_dim * num_rules == dim_attn + ), "value_dim must be divisible by num_rules" + + self.scaling = self.dim_head**-0.5 + self.scaling_values = self.dim_selection**-0.5 + + self.out_proj = nn.Linear(self.num_heads * self.value_dim, dim_model, bias=bias) + + if self.qk_rule: + self.value_k = nn.Linear(self.value_dim, self.dim_selection, bias=bias) + if self.q_compose: + self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias) + else: + self.value_q = nn.Linear( + dim_model, self.dim_selection * self.num_heads, bias=bias + ) + else: + if self.q_compose: + self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias) + else: + self.value_q = nn.Linear( + dim_model, self.dim_selection * self.num_heads, bias=bias + ) + if self.nonlinear: + self.score_network: nn.Module = nn.Sequential( + nn.Linear( + self.dim_selection + self.value_dim, + self.dim_selection, + bias=bias, + ), + nn.ReLU(), + nn.Linear(self.dim_selection, 1, bias=bias), + ) + else: + self.score_network = nn.Linear( + self.dim_selection + self.value_dim, 1, 
def forward(
    self,
    q: Tensor,
    k: Tensor,
    v: Tensor,
    att_mask: Optional[Tensor] = None,
    *args,
    **kwargs,
) -> Tensor:
    """
    Compute the compositional attention: a per-head search (Softmax(QKt)), followed by
    a retrieval which is scored over all the rules for every head.

    Input shape: Batch x Sequence x Channel, the channel dimension of `q` must be `dim_model`.

    Args:
        q, k, v: the un-projected query / key / value, they go through `in_proj_container` first
        att_mask (Tensor, optional): typically used to
            implement causal attention, where the mask prevents the
            attention from looking forward in time (default: None).
            A boolean tensor is converted with `AttentionMask.from_bool` (True means "compute"),
            any other tensor is used as an additive mask.

    Returns:
        a tensor of shape [Batch, Sq, dim_model] (output of `out_proj`)
    """

    B, Sq, E = q.shape
    _, Sk, _ = k.shape

    assert E == self.dim_model

    # First define projected query/key/values
    # We keep the projected and original tensors in flight,
    # depending on the options the original values could be reused
    q_unprojected = q
    q, k, v = self.in_proj_container(query=q, key=k, value=v)
    q *= self.scaling

    # Init causal mask if needed, now that we know the context length.
    # The mask is stored by AttentionMask with a leading singleton dimension, so the
    # sequence length is dimension 1: checking dimension 0 would rebuild the mask on
    # every call, and reuse a stale one for a single-token context.
    # The device is checked too, since the cached mask is now actually reused.
    if self.causal and (
        self._causal_mask is None
        or self._causal_mask.shape[1] != Sq
        or self._causal_mask.device != q.device
    ):
        # NOTE(review): the mask is square (Sq x Sq), this looks like it assumes Sq == Sk
        # whenever `causal` is set -- confirm
        self._causal_mask = AttentionMask.make_causal(Sq, Sq, device=q.device)

    # Convenience, create an attention mask if a tensor was passed
    # This sanitizes different mask types being passed, from now on it's additive
    if isinstance(att_mask, torch.Tensor):
        # By default we don't know of the causality, and a check would be expensive
        att_mask_additive: Optional[AttentionMask] = (
            AttentionMask.from_bool(att_mask)
            if att_mask.dtype == torch.bool
            else AttentionMask(att_mask, is_causal=False)
        )
    else:
        att_mask_additive = None

    # Handle the attention and key padding masks
    if self._causal_mask is not None:
        # Optionally add the causal mask
        if att_mask_additive is not None:
            # AttentionMask only defines __add__, this creates a new mask and
            # leaves the cached causal mask untouched
            att_mask_additive += self._causal_mask
        else:
            att_mask_additive = self._causal_mask

    # Flatten the heads or the rules
    q = (
        q.view(B, Sq, self.num_heads, self.dim_head)
        .movedim(2, 1)
        .flatten(0, 1)  # [B * num_heads, Sq, dim_head]
    )
    k = (
        k.view(B, Sk, self.num_heads, self.dim_head).movedim(2, 1).flatten(0, 1)
    )  # [B * num_heads, Sk, dim_head]
    v = v.view(B, -1, self.num_rules, self.value_dim).movedim(2, 1).flatten(0, 1)

    # Compute the search: Softmax(QKt)
    attn_weights = torch.bmm(q, k.transpose(1, 2))  # [B * self.num_heads, Sq, Sk]

    if att_mask_additive is not None:
        attn_weights += att_mask_additive.values

    attn_weights = _softmax(attn_weights, causal=self.causal)

    attn_weights = attn_weights.view(B, self.num_heads, Sq, Sk)
    attn_probs = self.dropout_module(attn_weights)

    # Now compute the information retrieval
    # keep all the heads in flight, we'll score the different possibilities
    # - compute all the possible retrievals
    v = v.view(B, 1, self.num_rules, Sk, self.value_dim)
    attn_probs = attn_probs.unsqueeze(2)
    attn = torch.matmul(attn_probs, v).view(
        B, self.num_heads, self.num_rules, Sq, self.value_dim
    )

    attn = attn.movedim(3, 1)  # [B, Sq, H, Rules, Values]

    # - search the most appropriate retrieval among all the values
    if self.q_compose:
        # NOTE(review): `q` is [B * num_heads, Sq, dim_head] at this point, so the transpose
        # puts the sequence first while the view below expects the batch first -- confirm
        # that this ordering is intended
        v_q = self.value_q(q.transpose(0, 1)).view(
            B, Sq, self.num_heads, 1, self.dim_selection
        )
    else:
        v_q = self.value_q(q_unprojected).view(
            B, Sq, self.num_heads, 1, self.dim_selection
        )

    if self.qk_rule:
        # Dot-product scoring in between the query and each retrieved value
        v_q *= self.scaling_values
        v_k = (
            self.value_k(attn)
            .view(B, Sq, self.num_heads, self.num_rules, self.dim_selection)
            .transpose(4, 3)
            .contiguous()
        )
        v_score = torch.matmul(v_q, v_k).view(
            B, Sq, self.num_heads, self.num_rules, 1
        )
    else:
        # Learned scoring: the query is repeated for each rule and concatenated to the retrieval
        v_q = v_q.expand(-1, -1, -1, self.num_rules, -1)
        v_in = torch.cat([attn, v_q], dim=-1)
        v_score = self.score_network(v_in).view(
            B, Sq, self.num_heads, self.num_rules, 1
        )

    # Normalize the scores over the rules dimension
    v_score = F.softmax(v_score, dim=3)

    # - extracted values are the original attention (inc. all the values) weighted by value score
    attn = (attn * v_score).sum(dim=3).view(B, Sq, self.num_heads * self.value_dim)

    # Final attention projection, same as other mechanisms
    attn = self.out_proj(attn)

    return attn
class FeatureMapType(str, Enum):
    """String identifiers for the available softmax feature maps.

    The `str` mixin makes each member compare equal to its plain string value.
    """

    # NOTE(review): the member names mirror the SMOrf / SMHyperbolic / SMReg classes imported
    # by this module; the code resolving a member to its class is not part of this file -- confirm
    SMOrf = "sm_orf"
    SMHyp = "sm_hyp"
    SMReg = "sm_reg"  # regularized softmax kernel
class FeatureMap(torch.nn.Module):
    """
    Base class for the feature maps, which hold the shared settings and the redraw state.

    Subclasses are expected to implement `_get_feature_map`.

    Args:
        dim_features: requested dimension of the features
        iter_before_redraw: number of iterations after which the features are drawn again,
            None to keep the same features
        normalize_inputs: stored as-is, used by the subclasses
        epsilon: stored as-is, used by the subclasses
    """

    def __init__(
        self,
        dim_features: int,
        iter_before_redraw: Optional[int] = None,
        normalize_inputs: bool = False,
        epsilon: float = 1e-6,
    ):
        super().__init__()

        self.dim_features = dim_features
        # Dimension of the map which is actually drawn, some subclasses override it
        self.dim_feature_map = dim_features

        self.iter_before_redraw = iter_before_redraw
        # Filled lazily by the subclasses
        self.features: Optional[torch.Tensor] = None
        self.epsilon = epsilon
        self.normalize_inputs = normalize_inputs

        self._iter_counter = 0

    # NOTE: this class does not use ABCMeta, so the decorator below does not prevent
    # instantiation; the base implementation simply raises when called.
    @abstractmethod
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        raise NotImplementedError()

    @classmethod
    def from_config(cls: "Type[Self]", config: "FeatureMapConfig") -> "Self":
        """
        Instantiate the feature map from a config dataclass.

        Fields set to None are skipped so that the constructor defaults are used.
        Fields which the constructor does not accept (typically the registry `name`)
        are skipped as well, unless the constructor takes `**kwargs`: forwarding them
        would otherwise always raise a TypeError.
        """
        import inspect

        # Generate the class inputs from the config
        # Skip all Nones so that default values are used
        fields = {k: v for k, v in asdict(config).items() if v is not None}

        # Only forward what the constructor can take
        params = inspect.signature(cls.__init__).parameters
        takes_var_kwargs = any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
        if not takes_var_kwargs:
            fields = {k: v for k, v in fields.items() if k in params}

        return cls(**fields)
def pre_scale(self, x: torch.Tensor) -> torch.Tensor:
    """
    Scale the inputs by the softmax temperature, and prepare the state shared by all kernels.

    This method has side effects on the instance:
        - `self.features` is (re)drawn if needed, and cast to the dtype of `x`
        - `self._iter_counter` is updated
        - `self.softmax_temp` is set from the last dimension of `x` if it was negative
        - `self.offset` is overwritten with the log-space offset matching `x`

    Args:
        x: the inputs, the features are along the last dimension

    Returns:
        `x` multiplied by the softmax temperature
    """
    with record_function("feature_map::pre_scale"):
        # Re-draw counting logic: draw new features if the redraw period is over,
        # if nothing was drawn yet or if the inputs moved to another device
        if (
            (
                self.iter_before_redraw is not None
                and self._iter_counter > self.iter_before_redraw
            )
            or self.features is None
            or self.features.device != x.device
        ):
            # `dim_feature_map` can be half of `dim_features`,
            # some kernels concatenate the + and - features
            self._iter_counter = 1
            self.features = self._get_feature_map(
                x.shape[-1], self.dim_feature_map, x.device
            )

        features = self.features
        assert features is not None

        # Keep the cached features in the same dtype as the inputs
        if features.dtype != x.dtype:
            self.features = features.to(x.dtype)

        self._iter_counter += 1

        # Normalization / softmax
        if self.softmax_temp < 0:
            # A = exp(QK.t/√d), so each input is multiplied by d ** -0.25
            # The value is computed once and kept for the following calls
            self.softmax_temp = x.shape[-1] ** -0.25

        x_scaled = x * self.softmax_temp

        # Compute the scaling factors in logspace, applied from within the exponential
        # - diminish possible exponential overflow
        # - remove a multiply across the batch, replace by an addition
        # norm_x_2 is the squared norm over the last dimension, kept as a trailing dimension of size 1
        norm_x_2 = torch.einsum("...d,...d->...", x_scaled, x_scaled).unsqueeze(-1)
        self.offset = -0.5 * norm_x_2 - self.h_scale + self.epsilon

        if self.normalize_inputs:
            # L0 normalize the exponential term, can be useful for numerical stability
            # This ensures that features +- offset is below 1
            # The max is taken over dimension 1
            self.offset -= norm_x_2.max(1, keepdim=True)[0]

        # Return the scaled inputs, the rest depends on the kernel being used
        return x_scaled
class SMOrf(SoftMaxPositiveEstimators):
    """
    "Positive random orthogonal features" softmax estimator,
    SM_ort^m+, as proposed in the Performers_ paper, Lemma 1.

    _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
        https://arxiv.org/pdf/2009.14794v1.pdf
    """

    @torch.no_grad()
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        """
        Draw the projection matrix onto the random features.

        Square blocks of size `dim_features` are drawn with Xi-distributed norms,
        stacked, then cut down to `dim_input` rows.

        .. note: The heads dimension needs to be taken into account, hence the per-block random matrix
            and not uniformly random.
        """

        # Enough blocks to cover the whole input dimension,
        # whatever the requested feature dimension is
        num_blocks = math.ceil(dim_input / dim_features)

        ortho_blocks = self._get_random_ortho_matrix(
            num_blocks,
            dim_features,
            device=device,
            norm_distribution=NormDistribution.Xi,
        )

        stacked = ortho_blocks.flatten(0, 1)
        return stacked[:dim_input]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shared step: temperature scaling, and refresh of self.features / self.offset
        scaled = super().pre_scale(x)
        assert self.features is not None

        # Random feature projection, the offset is applied from within the exponential
        projected = scaled @ self.features
        return torch.exp(projected + self.offset)
class SMReg(SoftMaxPositiveEstimators):
    """
    "Regularized softmax kernel" estimator, SMREG+, as proposed in the Performers_ paper.

    _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
        https://arxiv.org/pdf/2009.14794v1.pdf
    """

    def __init__(
        self,
        dim_features: int,
        iter_before_redraw: Optional[int],
        normalize_inputs: bool = False,
        epsilon: float = 1e-6,
        softmax_temp: float = -1,
    ):
        super().__init__(
            dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp
        )

        assert (
            dim_features % 2 == 0
        ), "The feature dimension needs to be even with this kernel"
        # Only half of the features are drawn, `forward` concatenates the + and - results
        self.dim_feature_map = self.dim_features // 2

    @torch.no_grad()
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        """
        Generate the projection matrix onto the random features

        .. note: The heads dimension needs to be taken into account, hence the per-block random matrix
            and not uniformly random.
        """

        # Get per block random unitary matrices.
        # We need enough of them to project the whole input dimension, regardless of the
        # requested dimension of the features
        features = self._get_random_ortho_matrix(
            math.ceil(dim_input / dim_features),
            dim_features,
            norm_distribution=NormDistribution.Uniform,
            device=device,
        ).flatten(0, 1)

        # Every row gets the same norm, sqrt(dim_input). This is a plain scalar multiply,
        # there is no need to build a dense diagonal matrix and go through a matmul.
        return math.sqrt(dim_input) * features[:dim_input]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Softmax-dimension related scaling, shared for all kernels
        x_scaled = super().pre_scale(x)
        assert self.features is not None

        # Project onto the random feature map, concatenate both + and - results
        # This follows Lemma 1 in the original Performers Paper to best approximate a
        # softmax kernel (cosh representation + sample regularization)
        x_scaled = x_scaled @ self.features
        return torch.cat(
            [torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)],
            dim=-1,
        )
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn

from xformers.components.attention import (
    Attention,
    AttentionConfig,
    AttentionMask,
    maybe_sparsify,
    register_attention,
    sparsify,
)
from xformers.components.attention.attention_patterns import (
    causal_1d_pattern,
    global_token_pattern,
)
from xformers.components.attention.core import scaled_dot_product_attention


@dataclass
class GlobalAttentionConfig(AttentionConfig):
    attention_query_mask: torch.Tensor  # Mark the queries which have global attention
    causal: Optional[bool]
    force_sparsity: Optional[bool]


@register_attention("global", GlobalAttentionConfig)
class GlobalAttention(Attention):
    def __init__(
        self,
        dropout: float,
        attention_query_mask: torch.Tensor,
        causal: bool = False,
        force_sparsity: bool = False,
        *_,
        **__,
    ):
        r"""
        Global attention, as proposed for instance in BigBird_ or Longformer_.

        Global means in that case that the queries positively labelled in the ```attention_query_mask``` can attend
        to all the other queries. The queries negatively labelled in the ```attention_query_mask``` cannot attend to
        any other query.

        This implementation is sparse-aware, meaning that the empty attention parts will not be represented in memory.

        Args:
            dropout (float): probability of an element to be zeroed
            attention_query_mask (torch.Tensor): boolean N x 1 mask; if true, this query can attend to all the others
            causal (bool): if true, the global pattern is additionally combined with a causal pattern
            force_sparsity (bool): if true the mask is always sparsified, otherwise ``maybe_sparsify`` decides

        """
        super().__init__()

        assert attention_query_mask.dtype == torch.bool, "A boolean mask is expected"
        assert (
            attention_query_mask.shape[1] == 1
            and attention_query_mask.shape[0] > attention_query_mask.shape[1]
        ), "A N x 1 query mask is expected"

        self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.attention_mask = global_token_pattern(attention_query_mask[:, 0])
        self.force_sparsity = force_sparsity

        if causal:
            # The causal pattern must span the sequence dimension (shape[0]).
            # shape[1] is asserted to be 1 above, which would yield a 1x1 pattern
            # and silently leave the mask non-causal.
            self.attention_mask &= causal_1d_pattern(attention_query_mask.shape[0])

        self.attention_mask = (
            sparsify(self.attention_mask)
            if self.force_sparsity
            else maybe_sparsify(self.attention_mask)
        )

        # Properties specific to this attention mechanism
        self.requires_same_k_q_dimensions = True
        self.supports_attention_mask = False
        self.supports_key_padding_mask = False

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        *_,
        **__,
    ):
        """
        Attend with the precomputed global-token mask, optionally combined with ``att_mask``.

        The inputs are passed through ``_maybe_pad_sequence`` so that they fit the mask,
        and the result is cropped back to the original query length.
        """
        # Make sure that the mask is on the right device
        if self.attention_mask.device != q.device:
            self.attention_mask = self.attention_mask.to(q.device)

        # Mask-aware attention
        if att_mask is not None:
            if att_mask.dtype == torch.bool and isinstance(
                self.attention_mask, AttentionMask
            ):
                if not isinstance(att_mask, AttentionMask):
                    att_mask = AttentionMask.from_bool(att_mask)
                mask = self.attention_mask + att_mask
            else:
                mask = self.attention_mask & att_mask
        else:
            mask = self.attention_mask

        # Handle q/k/v which would not fit the mask
        seq_len = q.shape[-2]
        q_, k_, v_ = map(lambda x: self._maybe_pad_sequence(x, mask), (q, k, v))

        # Normal attention with the global tokens mask
        att = scaled_dot_product_attention(
            q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop
        )

        # Take into account an hypothetical padding
        return att[:, :seq_len, :]
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from xformers.components.attention import Attention, AttentionConfig, register_attention
from xformers.components.attention.core import scaled_dot_product_attention


@dataclass
class LinformerSelfAttentionConfig(AttentionConfig):
    seq_len: int  # dimension of the input sequence
    k: Optional[int]  # dimension of the internal space


@register_attention("linformer", LinformerSelfAttentionConfig)
class LinformerAttention(Attention):
    def __init__(
        self, dropout: float, seq_len: int, k: Optional[int] = None, *args, **kwargs
    ):
        """
        Linformer attention mechanism,
        from `Linformer: Self-Attention with Linear Complexity`_, Wang et al (2020).
        The original notation is kept as is.

        Keys and values are projected along the sequence dimension, from ``seq_len``
        down to ``k`` (a quarter of ``seq_len`` when ``k`` is not given).

        .. _`Linformer: Self-Attention with Linear Complexity` : https://arxiv.org/abs/2006.04768v2
        """
        super().__init__()

        self.k = seq_len // 4 if k is None else k
        self.seq_len = seq_len

        # Sequence-wise projections for the keys (E) and the values (F)
        self.E = nn.Linear(seq_len, self.k, bias=False)
        self.F = nn.Linear(seq_len, self.k, bias=False)
        self.attn_drop = nn.Dropout(dropout, inplace=False)

        # MHA related flags:
        # kq need to have the same dimension
        self.requires_same_k_q_dimensions = True

        # This attention does not support attention masks
        self.supports_attention_mask = False

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
    ):
        # Inputs shorter than the expected sequence are zero-padded at the end of dim 1
        missing = self.seq_len - q.shape[1]
        padding = missing if missing > 0 else 0
        if padding > 0:
            pad_dims = (0, 0, 0, padding)
            q, k, v = (torch.nn.functional.pad(t, pad_dims) for t in (q, k, v))

        # The linear layers act on the last dimension, so swap sequence and
        # feature axes around the projection
        k_projected = self.E(k.transpose(-2, -1)).transpose(-2, -1)
        v_projected = self.F(v.transpose(-2, -1)).transpose(-2, -1)

        y = scaled_dot_product_attention(
            q=q, k=k_projected, v=v_projected, att_mask=None, dropout=self.attn_drop
        )

        # NOTE(review): dropout was already handed to the attention call above;
        # this second application is kept as-is to preserve behaviour.
        y = self.attn_drop(y)

        # Remove the rows which correspond to the padding, if any
        if padding > 0:
            return y[:, :-padding, :]
        return y
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

import torch
import torch.autograd.profiler as profiler
import torch.nn as nn
import torch.nn.functional as Fn

from xformers.components.attention import (
    Attention,
    AttentionConfig,
    AttentionMask,
    register_attention,
)
from xformers.components.attention.core import (
    scaled_dot_product_attention,
    scaled_query_key_softmax,
)

logger = logging.getLogger("xformers")


class LandmarkSelection(str, Enum):
    Orthogonal = "orthogonal"
    KMeans = "kmeans"
    KMeans_Spherical = "kmeans_spherical"
    Random = "random"


@dataclass
class OrthoformerAttentionConfig(AttentionConfig):
    """
    num_landmarks           Number of landmarks to use for softmax approximation.
    subsample_fraction      Percentage of q_samples matrix to sample per iteration
    landmark_selection      Landmark selection strategy
    """

    num_landmarks: Optional[int]
    subsample_fraction: Optional[float]
    landmark_selection: Optional[LandmarkSelection]


@register_attention("orthoformer", OrthoformerAttentionConfig)
class OrthoFormerAttention(Attention):
    def __init__(
        self,
        dropout: float,
        num_landmarks: int = 32,
        subsample_fraction: float = 1.0,
        landmark_selection: LandmarkSelection = LandmarkSelection.Orthogonal,
        *args,
        **kwargs,
    ):
        """
        Orthoformer_ attention mechanism.
        ::

            "Keeping Your Eye on the Ball: Trajectory Attention in Video Transformers"
            Patrick, M., Campbell, D., Asano, Y., Misra, I., Metze, F., Feichtenhofer,
            C., Vedaldi, A., Henriques, J. (2021)

            Reference codebase: https://github.com/facebookresearch/Motionformer

        .. _Orthoformer: https://arxiv.org/abs/2106.05392

        """
        super().__init__()

        self.num_landmarks = num_landmarks
        self.attn_drop = nn.Dropout(dropout)
        self.subsample_fraction = subsample_fraction
        self.landmark_selection = landmark_selection

        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
        *args,
        **kwargs,
    ):
        """
        Approximate the attention through a set of landmarks.

        When the number of landmarks equals the number of keys, the exact scaled
        dot-product attention is used instead. Otherwise the attention mask, if any,
        is dropped with a warning.

        Raises:
            ValueError: if ``landmark_selection`` is not a known strategy.
        """
        N = k.shape[1]

        if self.num_landmarks == N:
            # Default attention
            x = scaled_dot_product_attention(q, k, v, att_mask)
        else:
            with torch.no_grad(), profiler.record_function("select landmarks"):
                if self.landmark_selection == LandmarkSelection.Orthogonal:
                    landmarks = self._compute_orthogonal_landmarks(q)
                elif self.landmark_selection == LandmarkSelection.Random:
                    half_L = self.num_landmarks // 2
                    landmarks_q = q[:, torch.randint(q.size(1), (half_L,)), :]
                    landmarks_k = k[:, torch.randint(k.size(1), (half_L,)), :]
                    landmarks = torch.cat((landmarks_q, landmarks_k), dim=-2)
                elif self.landmark_selection == LandmarkSelection.KMeans:
                    landmarks = self._cluster_landmarks(q)
                elif self.landmark_selection == LandmarkSelection.KMeans_Spherical:
                    landmarks = self._cluster_landmarks(q, spherical=True)
                else:
                    # Fail here with a clear message instead of hitting an unbound
                    # `landmarks` a few lines below
                    raise ValueError(
                        f"Unsupported landmark selection strategy: {self.landmark_selection!r}"
                    )

            if att_mask is not None:
                logger.warning(
                    "Orthoformer: attention mask passed alongside with using landmarks to reduce dimensions. \
                    The two are typically not compatible"
                )
                # FIXME: Should we still accept a mask in that case ?
                att_mask = None

            kernel_1 = scaled_query_key_softmax(q, landmarks, att_mask)
            kernel_2 = scaled_query_key_softmax(landmarks, k, att_mask)
            x = torch.matmul(kernel_1, torch.matmul(kernel_2, v))
        x = self.attn_drop(x)
        return x

    def _cluster_landmarks(
        self,
        q: torch.Tensor,
        spherical: bool = False,
        num_iters: int = 6,
    ) -> torch.Tensor:
        """
        Construct the set of landmarks by clustering the (optionally subsampled)
        queries with k-means; the cluster centroids are the landmarks.
        With ``spherical=True`` the queries are L2-normalized first and the
        spherical k-means variant is used.
        Returns landmarks with shape (B, M, D).
        """

        num_landmarks = min(self.num_landmarks, q.shape[1])

        if self.subsample_fraction < 1.0:
            num_samples = max(
                int(self.subsample_fraction * q.size(-2)), num_landmarks
            )  # Need at least M/2 samples of queries and keys
            q_samples = q[:, torch.randint(q.size(-2), (num_samples,)), :]  # (B, N, D)
        else:
            q_samples = q  # (B, N, D)

        if spherical:
            q_samples_normalized = Fn.normalize(
                q_samples, p=2, dim=-1
            )  # may need to change default eps to eps=1e-8 for mixed precision compatibility
            landmarks = self._kmeans_spherical(
                q_samples_normalized, num_landmarks, num_iters
            )
        else:
            landmarks = self._kmeans(q_samples, num_landmarks, num_iters)
        return landmarks  # (B, M, D)

    def _kmeans(self, x: torch.Tensor, K: int, num_iters: int = 10):
        """
        Arguments:
            x: (B, N, D)
            K: number of clusters
            num_iters: the number of kmeans updates
        Returns:
            the (B, K, D) centroids
        """

        B, N, D = x.size()
        assert K <= N, f"{K} > {N}"

        c = x[
            :, torch.randperm(N, device=x.device)[:K], :
        ].clone()  # initialisation for the centroids

        with profiler.record_function("kmeans"):
            x_i = x.view(B, N, 1, D)
            # c_j is a view on c, so the in-place centroid updates below are visible through it
            c_j = c.view(B, 1, K, D)
            counts = c.new_zeros(B, K)
            ones = x.new_ones((B, N))

            for _ in range(num_iters):
                # E step: assign points to the nearest cluster
                D_ij = ((x_i - c_j) ** 2).sum(-1)  # (B, N, K) squared distances
                cl = D_ij.argmin(
                    dim=-1, keepdim=True
                ).long()  # (B, N, 1) index of point to nearest cluster

                # M step: update the centroids
                c.zero_()
                c.scatter_add_(-2, cl.repeat(1, 1, D), x)  # sum of points per cluster
                counts.fill_(1e-6)  # avoid div0
                counts.scatter_add_(
                    -1, cl.squeeze(-1), ones
                )  # number of points per cluster
                c.divide_(counts.unsqueeze(-1))  # compute the average

        return c

    def _kmeans_spherical(self, x: torch.Tensor, K: int, num_iters=10):
        """
        Arguments:
            x: (B, N, D), expected to be L2-normalized along the last dimension
            K: number of clusters
            num_iters: the number of kmeans updates
        Returns:
            the (B, K, D) centroids, L2-normalized
        """
        B, N, D = x.size()
        assert K <= N, f"{K} > {N}"

        # initialisation for the centroids
        c = x[:, torch.randperm(N, device=x.device)[:K], :].clone()

        with profiler.record_function("kmeans_spherical"):
            counts = c.new_zeros(B, K)
            ones = x.new_ones((B, N))

            for _ in range(num_iters):
                # E step: assign points to the nearest cluster
                D_ij = torch.matmul(
                    x, c.transpose(-2, -1)
                )  # (B, N, K) cosine similarity
                cl = D_ij.argmax(
                    dim=-1, keepdim=True
                ).long()  # (B, N, 1) index of point to nearest cluster

                # M step: update the centroids
                c.zero_()
                c.scatter_add_(-2, cl.repeat(1, 1, D), x)  # sum of points per cluster
                counts.fill_(1e-6)  # avoid div0
                counts.scatter_add_(
                    -1, cl.squeeze(-1), ones
                )  # number of points per cluster
                c.divide_(counts.unsqueeze(-1))  # compute the average
                c = Fn.normalize(c, p=2, dim=-1)  # renormalise
        return c

    def _compute_orthogonal_landmarks(self, q: torch.Tensor) -> torch.Tensor:
        """
        Construct set of landmarks by recursively selecting new landmarks
        that are maximally orthogonal to the existing set.
        Returns near orthogonal landmarks with shape (B, M, D).
        """

        if self.subsample_fraction < 1.0:
            # Need at least M samples of queries
            num_samples = max(
                int(self.subsample_fraction * q.size(-2)), self.num_landmarks
            )
            q_samples = q[
                :, torch.randint(q.size(-2), (num_samples,), device=q.device), :
            ]
        else:
            # (B, N, D)
            q_samples = q

        # may need to change default eps to eps=1e-8 for mixed precision compatibility
        q_samples_normalized = Fn.normalize(q_samples, p=2, dim=-1)
        B, N, D = q_samples_normalized.shape

        selected_mask = torch.zeros((B, N, 1), device=q_samples_normalized.device)
        landmark_mask = torch.ones(
            (B, 1, 1), dtype=selected_mask.dtype, device=q_samples_normalized.device
        )

        # Get initial random landmark
        random_idx = torch.randint(
            q_samples_normalized.size(-2), (B, 1, 1), device=q_samples_normalized.device
        )
        selected_mask.scatter_(-2, random_idx, landmark_mask)

        # Selected landmarks
        selected_landmarks = torch.empty(
            (B, self.num_landmarks, D),
            device=q_samples_normalized.device,
            dtype=q_samples_normalized.dtype,
        )
        selected_landmarks[:, 0, :] = q_samples_normalized[
            torch.arange(q_samples_normalized.size(0)), random_idx.view(-1), :
        ].view(B, D)

        # Store computed cosine similarities
        cos_sims = torch.empty(
            (B, N, self.num_landmarks),
            device=q_samples_normalized.device,
            dtype=q_samples_normalized.dtype,
        )

        for M in range(1, self.num_landmarks):
            with profiler.record_function("find new landmark"):
                # Calculate absolute cosine similarity between selected and unselected landmarks
                # (B, N, D) * (B, D) -> (B, N)
                cos_sims[:, :, M - 1] = torch.einsum(
                    "b n d, b d -> b n",
                    q_samples_normalized,
                    selected_landmarks[:, M - 1, :],
                ).abs()

                # (B, N, M) cosine similarities of current set of landmarks wrt all queries and keys
                cos_sim_set = cos_sims[:, :, :M]

                # Get orthogonal landmark: landmark with smallest absolute cosine similarity:
                # set cosine similarity for already selected landmarks to > 1
                cos_sim_set.view(-1, M)[selected_mask.flatten().bool(), :] = 10

                # (B,) - want max for non
                selected_landmark_idx = cos_sim_set.amax(-1).argmin(-1)

                # Add most orthogonal landmark to selected landmarks:
                selected_landmarks[:, M, :] = q_samples_normalized[
                    torch.arange(q_samples_normalized.size(0)), selected_landmark_idx, :
                ].view(B, D)

                # Removed selected indices from non-selected mask:
                selected_mask.scatter_(
                    -2, selected_landmark_idx.unsqueeze(-1).unsqueeze(-1), landmark_mask
                )

        # (B, M, D)
        landmarks = torch.masked_select(q_samples, selected_mask.bool()).reshape(
            B, -1, D
        )
        return landmarks  # (B, M, D)
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from xformers.components.attention import Attention, AttentionConfig, register_attention


@dataclass
class PoolingAttentionConfig(AttentionConfig):
    pool_size: int  # size of the pooling window
    stride: Optional[int]  # stride of the pooling window
    padding: Optional[int]  # padding of the pooling; pool_size // 2 when not set


@register_attention("pooling", PoolingAttentionConfig)
class Pooling(Attention):
    def __init__(
        self,
        pool_size: int = 3,
        stride: int = 1,
        padding: Optional[int] = None,
        *_,
        **__,
    ):
        """
        Pooling token mixing mechanism, as proposed in
        `Metaformer is actually what you need for vision`_, Yu et al (2021).

        The original notation is kept as is.

        Args:
            pool_size: size of the average pooling window
            stride: stride of the pooling window
            padding: padding handed to the pooling layer, ``pool_size // 2`` if None.
                ``forward`` subtracts the input from the pooled output, so the
                pooling settings need to preserve the spatial size.

        .. _`Metaformer is actually what you need for vision` : https://arxiv.org/pdf/2111.11418v1.pdf
        """
        super().__init__()

        padding = padding if padding is not None else pool_size // 2
        self.pool = nn.AvgPool2d(
            pool_size,
            stride=stride,
            # Honour the resolved padding; it used to be computed and then ignored
            padding=padding,
            count_include_pad=False,
        )

        # This attention does not support attention masks
        self.supports_attention_mask = False

        # This "attention" (token mixing) skips the multihead attention altogether
        self.requires_skip_multi_head = True
        self.requires_input_projection = False

        # MHA related flags:
        # This operator does not really handle q,k,v
        self.requires_same_k_q_dimensions = True

        # This attention requires the 2d structure out of the context,
        # implicitly assumed to be a squared length
        self.requires_squared_context = True

    def forward(self, q: torch.Tensor, *_, **__):
        """
        Mix the tokens with a 2D average pooling.

        ``q`` is (B, HW, C) with HW a perfect square; the result has the same layout.
        """
        # Expose the 2D token structure
        B, HW, C = q.shape
        H = int(math.sqrt(HW))
        assert H * H == HW

        q = q.transpose(-2, -1).reshape(B, C, H, H)

        # 2D pool
        x_pool = self.pool(q) - q  # compensate for the residual path

        # Get back to B HW C
        return x_pool.flatten(2, 3).transpose(-2, -1)
"""
The code has been adopted from DeepSpeed
(https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/sparse_attention/sparsity_config.py)
"""

import random

import torch


class SparsityConfig:
    """Base class holding the `sparsity configuration of a self attention layer`.
    It stores the properties shared by the block-sparse patterns; each concrete
    pattern extends it with its own properties and layout construction.
    """

    def __init__(self, num_heads, block_size=16, different_layout_per_head=False):
        """Initialize the Sparsity Pattern Config.
        Arguments:
             num_heads: required: an integer determining number of attention heads of the layer.
             block_size: optional: an integer determining the block size. Sparse self-attention
                works on blocked sparse matrices, this is the side of such `Block X Block` blocks.
             different_layout_per_head: optional: a boolean determining if each head should be
                assigned a different sparsity layout; default is false.
        """

        self.num_heads = num_heads
        self.block_size = block_size
        self.different_layout_per_head = different_layout_per_head
        # With a shared layout, only the first head is computed and later propagated
        self.num_layout_heads = num_heads if different_layout_per_head else 1

    def setup_layout(self, seq_len):
        """Create the (empty) layout tensor for the given sequence length
        Arguments:
             seq_len: required: an integer determining the sequence length;
                must be a multiple of the block size.
        Return:
             layout: an int64 tensor of dimension (num_heads, num_blocks, num_blocks) for the
                sparsity layout of all heads; initialized with zeros
        Raises:
             ValueError: if seq_len is not a multiple of the block size
        """

        num_blocks, remainder = divmod(seq_len, self.block_size)
        if remainder != 0:
            raise ValueError(
                f"Sequence Length, {seq_len}, needs to be dividable by Block size {self.block_size}!"
            )
        # TODO Currently we allocate layout per head; needs to be updated if heads share a single layout.
        shape = (self.num_heads, num_blocks, num_blocks)
        return torch.zeros(shape, dtype=torch.int64)

    def check_and_propagate_first_head_layout(self, layout):
        """If all heads share the same sparsity layout, copy the first head layout to all heads
        Arguments:
             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step
        Return:
             layout: the same tensor, updated in place when the layout is shared
        """

        if self.different_layout_per_head:
            return layout

        layout[1 : self.num_heads] = layout[0]
        return layout
class DenseSparsityConfig(SparsityConfig):
    """Configuration class for the `Dense` pattern.
    This is not sparse, all the blocks are used. It is kept for the sake of comparison
    and comprehension.
    """

    def __init__(self, num_heads, block_size=16, different_layout_per_head=False):
        """Initialize the Dense Sparsity Pattern Config.
        Arguments:
             num_heads: required: an integer determining number of attention heads of the layer.
             block_size: optional: an integer determining the block size. Sparse self-attention
                works on blocked sparse matrices, this is the side of such `Block X Block` blocks.
             different_layout_per_head: optional: only there for consistency with the other
                sparsity formats; can be ignored for DenseSparsityConfig
        """

        super().__init__(num_heads, block_size, different_layout_per_head)

    def make_layout(self, seq_len):
        """Set all the blocks of the layout to 1, meaning that the pattern is dense; not sparse.
        Arguments:
             seq_len: required: an integer determining the underlying sequence length;
                must be <= max sequence length
        Return:
             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head; for dense everything is 1
        """

        dense_layout = self.setup_layout(seq_len)
        dense_layout.fill_(1)
        return dense_layout
class FixedSparsityConfig(SparsityConfig):
    """Configuration class for the `Fixed` sparsity pattern.
    For more details about this sparsity config, please see `Generative Modeling with
    Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
    This class extends `SparsityConfig` and customizes it for `Fixed` sparsity.
    """

    def __init__(
        self,
        num_heads,
        block_size=16,
        different_layout_per_head=False,
        num_local_blocks=4,
        num_global_blocks=1,
        attention="bidirectional",
        horizontal_global_attention=False,
        num_different_global_patterns=1,
    ):
        """Initialize `Fixed` Sparsity Pattern Config.
        For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
        Arguments:
             num_heads: required: an integer determining number of attention heads of the layer.
             block_size: optional: an integer determining the block size. Sparse self-attention
                works on blocked sparse matrices, this is the side of such `Block X Block` blocks.
             different_layout_per_head: optional: a boolean determining if each head should be
                assigned a different sparsity layout; default is false.
             num_local_blocks: optional: an integer determining the number of blocks in a local
                attention window.
             num_global_blocks: optional: an integer determining how many consecutive blocks of a
                local window are used as the representative of the window for global attention.
                Must divide num_local_blocks.
             attention: optional: a string determining attention type. `unidirectional`, as in
                autoregressive models, lets tokens attend only to the tokens which appear before
                them in the context (the upper triangular part of the layout stays empty).
                `bidirectional`, as in BERT, lets tokens attend to tokens before or after them.
             horizontal_global_attention: optional: a boolean determining if the blocks which are
                the global representative of a local window also attend to all the other blocks,
                i.e. the global attention covers horizontal blocks on top of the vertical ones.
                Only valid with `bidirectional` attention.
             num_different_global_patterns: optional: an integer determining the number of different
                global attention layouts. Each head can use a different global representative:
                with a local window of 4 blocks and 1 global block, the first, second, third or
                fourth block of each window can be the representative, so up to 4 patterns exist.
                Requires different_layout_per_head when larger than 1, and cannot exceed
                num_local_blocks // num_global_blocks.
        Raises:
             ValueError: if the arguments are inconsistent with each other
             NotImplementedError: if the attention type is unknown
        """

        super().__init__(num_heads, block_size, different_layout_per_head)

        # Validate everything first, in the same order the arguments depend on each other
        if num_local_blocks % num_global_blocks != 0:
            raise ValueError(
                f"""Number of blocks in a local window, {num_local_blocks},
                must be dividable by number of global blocks, {num_global_blocks}!"""
            )

        if attention not in ("unidirectional", "bidirectional"):
            raise NotImplementedError(
                'only "uni/bi-directional" attentions are supported for now!'
            )

        if horizontal_global_attention and attention != "bidirectional":
            raise ValueError(
                'only "bi-directional" attentions can support horizontal global attention!'
            )

        if num_different_global_patterns > 1 and not different_layout_per_head:
            raise ValueError(
                """Number of different layouts cannot be more than one when you have set a single layout
                for all heads! Set different_layout_per_head to True."""
            )
        if num_different_global_patterns > (num_local_blocks // num_global_blocks):
            raise ValueError(
                f"""Number of layout versions (num_different_global_patterns), {num_different_global_patterns},
                cannot be larger than number of local window blocks divided by number of global blocks,
                {num_local_blocks} / {num_global_blocks} = {num_local_blocks//num_global_blocks}!"""
            )

        self.num_local_blocks = num_local_blocks
        self.num_global_blocks = num_global_blocks
        self.attention = attention
        self.horizontal_global_attention = horizontal_global_attention
        self.num_different_global_patterns = num_different_global_patterns

    def set_local_layout(self, h, layout):
        """Sets local attention layout used by the given head in the sparse attention.
        Arguments:
             h: required: an integer determining head index
             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step
        Return:
             layout: the same tensor, in which the local layout of head `h` is set
        """

        num_blocks = layout.shape[1]
        causal = self.attention == "unidirectional"

        for window_start in range(0, num_blocks, self.num_local_blocks):
            # The last window may be shorter than the others
            window_end = min(window_start + self.num_local_blocks, num_blocks)
            for row in range(window_start, window_end):
                # Causal rows stop at the diagonal, the others span the whole window
                last_col = row + 1 if causal else window_end
                layout[h, row, window_start:last_col] = 1
        return layout

    def set_global_layout(self, h, layout):
        """Sets global attention layout used by the given head in the sparse attention.
        Global blocks are picked starting from the last block of a local window. With a window
        of 4 blocks and a global attention size of one block, block #4 of each window is global
        for the first pattern; with different layouts per head, the other patterns use #3, #2
        and #1. With more heads than patterns, several heads share the same global blocks.
        Note) if horizontal_global_attention is set, global blocks will be set both horizontally and
        vertically.
        Arguments:
             h: required: an integer determining head index
             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step
        Return:
             layout: the same tensor, in which the global layout of head `h` is set
        """

        num_blocks = layout.shape[1]
        width = self.num_global_blocks
        bidirectional = self.attention == "bidirectional"

        # Offset of the global blocks within a local window, for the pattern of this head
        pattern = h % self.num_different_global_patterns
        offset = self.num_local_blocks - (1 + pattern) * width

        def mark_global(start, stop):
            # vertical global attention
            first_row = 0 if bidirectional else start
            layout[h, first_row:, start:stop] = 1

            # horizontal global attention; only in bidirectional attention
            if self.horizontal_global_attention:
                layout[h, start:stop, :] = 1

        # all the complete local windows
        complete_end = num_blocks - (num_blocks % self.num_local_blocks)
        for start in range(offset, complete_end, self.num_local_blocks):
            mark_global(start, start + width)

        # possible short last local window: keep the global blocks within bounds
        if complete_end < num_blocks:
            start = min(complete_end + offset, num_blocks - width)
            mark_global(start, start + width)
        return layout

    def make_layout(self, seq_len):
        """Generates `Fixed` sparsity layout used by each head in the sparse attention.
        Arguments:
             seq_len: required: an integer determining the sequence length;
                must be a multiple of the block size.
        Return:
             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Fixed`
                sparsity layout of all head
        """

        layout = self.setup_layout(seq_len)
        for head in range(self.num_layout_heads):
            layout = self.set_local_layout(head, layout)
            layout = self.set_global_layout(head, layout)

        return self.check_and_propagate_first_head_layout(layout)
+ Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Fixed` + sparsity layout of all head + """ + + layout = self.setup_layout(seq_len) + for h in range(0, self.num_layout_heads): + layout = self.set_local_layout(h, layout) + layout = self.set_global_layout(h, layout) + + layout = self.check_and_propagate_first_head_layout(layout) + return layout + + +class VariableSparsityConfig(SparsityConfig): + """Configuration class to store `Variable` sparsity configuration. + This layout is an extension of FixedSparsityConfig in which: + - user can set random layout; default value is zero means no random block + - user can provide a list of local block sizes + - user can provide a list of global block indices. + For more details about `Fixed` sparsity config, please see `Generative Modeling with + Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized. + This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity. + """ + + def __init__( + self, + num_heads, + block_size=16, + different_layout_per_head=False, + num_random_blocks=0, + local_window_blocks=[4], + global_block_indices=[0], + global_block_end_indices=None, + attention="bidirectional", + horizontal_global_attention=False, + ): + """Initialize `Variable` Sparsity Pattern Config. + For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial + Arguments: + num_heads: required: an integer determining number of attention heads of the layer. + block_size: optional: an integer determining the block size. Current implementation of sparse + self-attention is based on blocked sparse matrices. In which this parameter defines + size of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be assigned a + different sparsity layout; default is false and this will be satisfied based on + availability. 
Currently this sparsity config can only assign single layout to all heads; + needs to be extended for different layout per head. + num_random_blocks: optional: an integer determining the number of random blocks in each block row. + local_window_blocks: optional: a list of integers determining the number of blocks in each + local attention window. It assumes first number determines # of blocks in the first local + window, second the second window, ..., and the last number determines the number of blocks + in the remaining local windows. + global_block_indices: optional: a list of integers determining which blocks are considered + as global attention. Given indices, determine the blocks that all other token blocks + attend to and they attend to all other token blocks. Default value is only index 0. + Notice that if global_block_end_indices parameter is set, this parameter is used as + starting index of each global window. + global_block_end_indices: optional: a list of integers determining end indices of global + window blocks. By default this is not used. But if it is set, it must have the same size + of global_block_indices parameter, and combining this two parameters, for each index i, + blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are + considered as global attention. + attention: optional: a string determining attention type. Attention can be `unidirectional`, + such as autoregressive models, in which tokens attend only to tokens appear before them + in the context. Considering that, the upper triangular of attention matrix is empty as + above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to + any other tokens before or after them. Then, the upper triangular part of the attention + matrix is mirror of the lower triangular in the above figure. 
+ horizontal_global_attention: optional: a boolean determining if blocks that are global + representative of a local window, also attend to all other blocks. This is valid only if + attention type is `bidirectional`. Looking at the attention matrix, that means global + attention not only includes the vertical blocks, but also horizontal blocks. + """ + + super().__init__(num_heads, block_size, different_layout_per_head) + + self.num_random_blocks = num_random_blocks + self.local_window_blocks = local_window_blocks + self.global_block_indices = global_block_indices + + if global_block_end_indices is not None: + if len(global_block_indices) != len(global_block_end_indices): + raise ValueError( + f"""Global block start indices length, {len(global_block_indices)}, must be same as + global block end indices length, {len(global_block_end_indices)}!""" + ) + for _, (start_idx, end_idx) in enumerate( + zip(global_block_indices, global_block_end_indices) + ): + if start_idx >= end_idx: + raise ValueError( + f"""Global block start index, {start_idx}, must be smaller than global block end + index, {end_idx}!""" + ) + self.global_block_end_indices = global_block_end_indices + + if attention != "unidirectional" and attention != "bidirectional": + raise NotImplementedError( + 'only "uni/bi-directional" attentions are supported for now!' + ) + self.attention = attention + + if attention != "bidirectional" and horizontal_global_attention: + raise ValueError( + 'only "bi-directional" attentions can support horizontal global attention!' + ) + self.horizontal_global_attention = horizontal_global_attention + + def set_random_layout(self, h, layout): + """Sets random attention layout used by the given head in the sparse attention. + Note) By default, it assumes there will be a unique random block layout for all heads; unless + `different_layout_per_head` parameter is set in which each head can have a different random + layout. 
+ Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which random layout is set + """ + + num_blocks = layout.shape[1] + if num_blocks < self.num_random_blocks: + raise ValueError( + f"""Number of random blocks, {self.num_random_blocks}, must be smaller than overall number + of blocks in a row, {num_blocks}!""" + ) + for row in range(0, num_blocks): + rnd_cols = random.sample(range(0, num_blocks), self.num_random_blocks) + layout[h, row, rnd_cols] = 1 + return layout + + def set_local_layout(self, h, layout): + """Sets local attention layout used by the given head in the sparse attention. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which local layout is set + """ + + num_blocks = layout.shape[1] + start_block_idx = 0 + end_block_idx = 0 + for block_size in self.local_window_blocks: + end_block_idx += block_size + end_block_idx = min(end_block_idx, num_blocks) + for row in range(start_block_idx, end_block_idx): + for col in range( + start_block_idx, + (row + 1 if self.attention == "unidirectional" else end_block_idx), + ): + layout[h, row, col] = 1 + start_block_idx += block_size + + # if there is any remaining not attended part, use the lats local window block size as local + # window for the remaining applicable local windows + for i in range(start_block_idx, num_blocks, block_size): + end_block_idx = min(i + block_size, num_blocks) + for row in range(i, end_block_idx): + for col in 
range( + i, + (row + 1 if self.attention == "unidirectional" else end_block_idx), + ): + layout[h, row, col] = 1 + return layout + + def set_global_layout(self, h, layout): + """Sets global attention layout used by the given head in the sparse attention. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which global layout is set + """ + + num_blocks = layout.shape[1] + if self.global_block_end_indices is None: + for idx in self.global_block_indices: + # if global block idx is in the range of the sequence blocks + if idx < num_blocks: + # global rows + if self.horizontal_global_attention: + layout[h, idx, :] = 1 + + # global columns + first_row = 0 if self.attention == "bidirectional" else idx + layout[h, first_row:, idx] = 1 + else: + for _, (start_idx, end_idx) in enumerate( + zip(self.global_block_indices, self.global_block_end_indices) + ): + # if global block idx is in the range of the sequence blocks + if start_idx < num_blocks: + end_idx = min(end_idx, num_blocks) + # global rows + if self.horizontal_global_attention: + layout[h, start_idx:end_idx, :] = 1 + + # global columns + first_row = 0 if self.attention == "bidirectional" else start_idx + layout[h, first_row:, start_idx:end_idx] = 1 + return layout + + def make_layout(self, seq_len): + """Generates `Variable` sparsity layout used by each head in the sparse attention. + Arguments: + seq_len: required: an integer determining number of attention heads of the layer. 
+ Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Variable` + sparsity layout of all head + """ + + layout = self.setup_layout(seq_len) + for h in range(0, self.num_layout_heads): + layout = self.set_random_layout(h, layout) + layout = self.set_local_layout(h, layout) + layout = self.set_global_layout(h, layout) + + layout = self.check_and_propagate_first_head_layout(layout) + return layout + + +class BigBirdSparsityConfig(SparsityConfig): + """Configuration class to store `BigBird` sparsity configuration. + For more details about this sparsity config, please see `Big Bird: Transformers for + Longer Sequences`: https://arxiv.org/pdf/2007.14062.pdf + This class extends parent class of `SparsityConfig` and customizes it for `BigBird` sparsity. + """ + + def __init__( + self, + num_heads, + block_size=16, + different_layout_per_head=False, + num_random_blocks=1, + num_sliding_window_blocks=3, + num_global_blocks=1, + attention="bidirectional", + ): + """Initialize the BigBird Sparsity Pattern Config. + For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial + Arguments: + num_heads: required: an integer determining number of attention heads of the layer. + block_size: optional: an integer determining the block size. Current implementation of + sparse self-attention is based on blocked sparse matrices. In which this parameter + defines size of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be assigned + a different sparsity layout; default is false and this will be satisfied based on + availability. + num_random_blocks: optional: an integer determining the number of random blocks in each + block row. + num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding + local attention window. 
+ num_global_blocks: optional: an integer determining how many consecutive blocks, starting + from index 0, are considered as global attention. Global block tokens will be attended + by all other block tokens and will attend to all other block tokens as well. + attention: optional: a string determining attention type. Attention can be `unidirectional`, + such as autoregressive models, in which tokens attend only to tokens appear before them + in the context. Considering that, the upper triangular of attention matrix is empty as + above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to + any other tokens before or after them. Then, the upper triangular part of the attention + matrix is mirror of the lower triangular in the above figure. + """ + + super().__init__(num_heads, block_size, different_layout_per_head) + + self.num_random_blocks = num_random_blocks + self.num_sliding_window_blocks = num_sliding_window_blocks + self.num_global_blocks = num_global_blocks + + if attention != "unidirectional" and attention != "bidirectional": + raise NotImplementedError( + 'only "uni/bi-directional" attentions are supported for now!' + ) + self.attention = attention + + def set_random_layout(self, h, layout): + """Sets random attention layout used by the given head in the sparse attention. + Note) By default, it assumes there will be a unique random block layout for all heads; unless + `different_layout_per_head` parameter is set in which each head can have a different random layout. 
+ Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which random layout is set + """ + + num_blocks = layout.shape[1] + if num_blocks < self.num_random_blocks: + raise ValueError( + f"""Number of random blocks, {self.num_random_blocks}, must be smaller than overall number + of blocks in a row, {num_blocks}!""" + ) + + for row in range(0, num_blocks): + sample_range = ( + range(0, num_blocks) + if self.attention == "bidirectional" + else range(0, row + 1) + ) + rnd_cols = random.sample(sample_range, self.num_random_blocks) + layout[h, row, rnd_cols] = 1 + return layout + + def set_sliding_window_layout(self, h, layout): + """Sets sliding local attention layout used by the given head in the sparse attention. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which local sliding window layout is set + """ + + num_blocks = layout.shape[1] + if num_blocks < self.num_sliding_window_blocks: + raise ValueError( + f"""Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than + overall number of blocks in a row, {num_blocks}!""" + ) + + w = self.num_sliding_window_blocks // 2 + for row in range(0, num_blocks): + start = max(0, row - w) + end = min(row + w + 1, num_blocks) + layout[h, row, start:end] = 1 + return layout + + def set_global_layout_itc(self, h, layout): + """Sets global attention layout used by the given head in the sparse attention. 
+ Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout + of all head in which global layout is set + """ + + num_blocks = layout.shape[1] + if num_blocks < self.num_global_blocks: + raise ValueError( + f"""Number of global blocks, {self.num_global_blocks}, must be smaller than overall number + of blocks in a row, {num_blocks}!""" + ) + + # global rows + layout[h, 0 : self.num_global_blocks, :] = 1 + + # global columns + layout[h, :, 0 : self.num_global_blocks] = 1 + + if self.attention == "unidirectional": + # zero out anything attending to the future + layout = torch.tril(layout) + + return layout + + def make_layout(self, seq_len): + """Generates `BigBird` sparsity layout used by each head in the sparse attention. + Arguments: + seq_len: required: an integer determining number of attention heads of the layer. + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BigBird` + sparsity layout of all head + """ + + layout = self.setup_layout(seq_len) + for h in range(0, self.num_layout_heads): + layout = self.set_random_layout(h, layout) + layout = self.set_sliding_window_layout(h, layout) + layout = self.set_global_layout_itc(h, layout) + + layout = self.check_and_propagate_first_head_layout(layout) + return layout + + +class BSLongformerSparsityConfig(SparsityConfig): + """Configuration class to store edited `Longformer` sparsity configuration. + Note) this is a block-sparse version of the Longformer which is slightly different than original + Longformer; which is element-wise sparsity. 
+ For more details about this sparsity config, please see `Longformer: + The Long-Document Transformer`: https://arxiv.org/pdf/2004.05150.pdf + This class extends parent class of `SparsityConfig` and customizes it for `Longformer` sparsity. + """ + + def __init__( + self, + num_heads, + block_size=16, + different_layout_per_head=False, + num_sliding_window_blocks=3, + global_block_indices=[0], + global_block_end_indices=None, + attention="bidirectional", + ): + """Initialize the edited `Longformer` Sparsity Pattern Config. + For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial + Arguments: + num_heads: required: an integer determining number of attention heads of the layer. + block_size: optional: an integer determining the block size. Current implementation of sparse + self-attention is based on blocked sparse matrices. In which this parameter defines size + of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be assigned a + different sparsity layout; default is false and this will be satisfied based on + availability. + num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding + local attention window. + global_block_indices: optional: a list of integers determining which blocks are considered + as global attention. Given indices, determine the blocks that all other token blocks + attend to and they attend to all other token blocks. Default value is only index 0. + Notice that if global_block_end_indices parameter is set, this parameter is used as + starting index of each global window. + global_block_end_indices: optional: a list of integers determining end indices of global + window blocks. By default this is not used. 
But if it is set, it must have the same size + of global_block_indices parameter, and combining this two parameters, for each index i, + blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are + considered as global attention. + attention: optional: a string determining attention type. Attention can be `unidirectional`, + such as autoregressive models, in which tokens attend only to tokens appear before them + in the context. Considering that, the upper triangular of attention matrix is empty as + above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to + any other tokens before or after them. Then, the upper triangular part of the attention + matrix is mirror of the lower triangular in the above figure. + """ + + super().__init__(num_heads, block_size, different_layout_per_head) + + self.num_sliding_window_blocks = num_sliding_window_blocks + self.global_block_indices = global_block_indices + self.attention = attention + + if global_block_end_indices is not None: + if len(global_block_indices) != len(global_block_end_indices): + raise ValueError( + f"""Global block start indices length, {len(global_block_indices)}, must be same as + global block end indices length, {len(global_block_end_indices)}!""" + ) + for _, (start_idx, end_idx) in enumerate( + zip(global_block_indices, global_block_end_indices) + ): + if start_idx >= end_idx: + raise ValueError( + f"""Global block start index, {start_idx}, must be smaller than global block end + index, {end_idx}!""" + ) + self.global_block_end_indices = global_block_end_indices + + def set_sliding_window_layout(self, h, layout): + """Sets sliding local attention layout used by the given head in the sparse attention. 
+ Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout + of all head in which local sliding window layout is set + """ + + num_blocks = layout.shape[1] + if num_blocks < self.num_sliding_window_blocks: + raise ValueError( + f"""Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller + than overall number of blocks in a row, {num_blocks}!""" + ) + + w = self.num_sliding_window_blocks // 2 + for row in range(0, num_blocks): + start = max(0, row - w) + end = min(row + w + 1, num_blocks) + layout[h, row, start:end] = 1 + return layout + + def set_global_layout(self, h, layout): + """Sets global attention layout used by the given head in the sparse attention. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which global layout is set + """ + + num_blocks = layout.shape[1] + if self.global_block_end_indices is None: + for idx in self.global_block_indices: + # if global block idx is in the range of the sequence blocks + if idx < num_blocks: + # global rows + layout[h, idx, :] = 1 + + # global columns + layout[h, :, idx] = 1 + else: + for _, (start_idx, end_idx) in enumerate( + zip(self.global_block_indices, self.global_block_end_indices) + ): + # if global block idx is in the range of the sequence blocks + if start_idx < num_blocks: + end_idx = min(end_idx, num_blocks) + # global rows + layout[h, start_idx:end_idx, :] = 1 + + # global columns + layout[h, :, start_idx:end_idx] = 1 + if 
self.attention == "unidirectional": + layout = torch.tril(layout) + return layout + + def make_layout(self, seq_len): + """Generates edited `Longformer` sparsity layout used by each head in the sparse attention. + Arguments: + seq_len: required: an integer determining number of attention heads of the layer. + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BSLongformer` + sparsity layout of all head + """ + + layout = self.setup_layout(seq_len) + for h in range(0, self.num_layout_heads): + layout = self.set_sliding_window_layout(h, layout) + layout = self.set_global_layout(h, layout) + + layout = self.check_and_propagate_first_head_layout(layout) + return layout diff --git a/.venv/lib/python3.11/site-packages/xformers/components/attention/utils.py b/.venv/lib/python3.11/site-packages/xformers/components/attention/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d6bb06a1acf4f1884442e6dd5a6c238e0a6f641e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/attention/utils.py @@ -0,0 +1,108 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +from typing import Optional + +import torch + + +# Reshapes key padding mask from (batch_size, src_len) -> (batch_size * num_heads 1, src_len) +def reshape_key_padding_mask( + key_padding_mask: torch.Tensor, batched_dim: int +) -> torch.Tensor: + assert key_padding_mask.ndim == 2 + batch_size, src_len = key_padding_mask.size() + num_heads = batched_dim // batch_size + return _reshape_key_padding_mask(key_padding_mask, batch_size, src_len, num_heads) + + +def _reshape_key_padding_mask( + key_padding_mask: torch.Tensor, batch_size: int, src_len: int, num_heads: int +) -> torch.Tensor: + assert key_padding_mask.shape == (batch_size, src_len) + key_padding_mask = ( + key_padding_mask.view(batch_size, 1, 1, src_len) + .expand(-1, num_heads, -1, -1) + .reshape(batch_size * num_heads, 1, src_len) + ) + return key_padding_mask + + +# Combine the attention mask and key padding mask into a single mask +# Taken from https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py +# Additive masking not yet supported +def maybe_merge_masks( + att_mask: Optional[torch.Tensor], + key_padding_mask: Optional[torch.Tensor], + batch_size: int, + src_len: int, + num_heads: int, + tgt_len: Optional[int] = None, +) -> Optional[torch.Tensor]: + if tgt_len is None: + tgt_len = src_len + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, src_len) + key_padding_mask = _reshape_key_padding_mask( + key_padding_mask, batch_size, src_len, num_heads + ) + if att_mask is None: + # make sure dimensions of key padding mask are the same as those expected for att_mask + att_mask = key_padding_mask.expand(-1, tgt_len, -1) + # Assumption is that False means to mask. + elif att_mask.dtype == torch.bool: + att_mask = att_mask.logical_and(key_padding_mask) + else: + att_mask = att_mask.masked_fill(~key_padding_mask, float("-inf")) + + return att_mask + + +# Assumes that matrix passed in has had softmax applied to it. 
+def iterative_pinv(softmax_mat: torch.Tensor, n_iter=6, pinverse_original_init=False): + """ + Computing the Moore-Penrose inverse. + Use an iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose inverse via efficient + matrix-matrix multiplications. + """ + + i = torch.eye( + softmax_mat.size(-1), device=softmax_mat.device, dtype=softmax_mat.dtype + ) + k = softmax_mat + + # The entries of K are positive and ||K||_{\infty} = 1 due to softmax + if pinverse_original_init: + # This original implementation is more conservative to compute coefficient of Z_0. + v = 1 / torch.max(torch.sum(k, dim=-2)) * k.transpose(-1, -2) + else: + # This is the exact coefficient computation, 1 / ||K||_1, of initialization of Z_0, leading to faster + # convergence. + v = ( + 1 + / torch.max(torch.sum(k, dim=-2), dim=-1).values[:, None, None] + * k.transpose(-1, -2) + ) + + for _ in range(n_iter): + kv = torch.matmul(k, v) + v = torch.matmul( + 0.25 * v, + 13 * i - torch.matmul(kv, 15 * i - torch.matmul(kv, 7 * i - kv)), + ) + return v + + +def bool_mask_to_additive( + mask: torch.Tensor, dtype: Optional[torch.dtype] = torch.float32 +) -> torch.Tensor: + assert ( + mask.dtype == torch.bool + ), "This util is meant to convert in between bool masks and additive ones" + + mask_ = torch.zeros_like(mask, dtype=dtype) + mask_[~mask] = float("-inf") + return mask_ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__init__.py b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4df0a5ce9153d7f451280336e217ea3d4cf9a1ba --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__init__.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +from pathlib import Path +from typing import Any, Callable, Dict, Set, Union + +from xformers.utils import ( + generate_matching_config, + get_registry_decorator, + import_all_modules, +) + +from .base import Feedforward, FeedforwardConfig # noqa + +# CREDITS: Classy Vision registry mechanism + +FEEDFORWARD_REGISTRY: Dict[str, Any] = {} +FEEDFORWARD_CLASS_NAMES: Set[str] = set() + + +def build_feedforward(config: Union[Dict[str, Any], FeedforwardConfig]): + """Builds a feedforward from a config. + + This assumes a 'name' key in the config which is used to determine what + attention class to instantiate. For instance, a config `{"name": "my_feedforward", + "foo": "bar"}` will find a class that was registered as "my_feedforward" + (see :func:`register_feedforward`) and call .from_config on it.""" + + if not isinstance(config, FeedforwardConfig): + config_instance = generate_matching_config( + config, FEEDFORWARD_REGISTRY[config["name"]].config + ) + else: + config_instance = config + + return FEEDFORWARD_REGISTRY[config_instance.name].constructor.from_config( + config_instance + ) + + +"""Registers a Feedforward subclass. + + This decorator allows xFormers to instantiate a subclass of Feedforward + from a configuration file, even if the class itself is not part of the + xFormers framework. To use it, apply this decorator to a Feedforward + subclass, like this: + + .. code-block:: python + + @dataclass + class MyConfig: + ... + + @register_feedforward('my_ff', MyConfig) + class MyFeedforward(Feedforward): + ... 
+ + To instantiate a feedforward from a configuration file, see :func:`build_feedforward`.""" +register_feedforward: Callable[ + [str, Any], Callable[[Any], Any] +] = get_registry_decorator( + FEEDFORWARD_REGISTRY, FEEDFORWARD_CLASS_NAMES, Feedforward, FeedforwardConfig +) + +from .mlp import MLP # noqa + +__all__ = [ + "MLP", + "Feedforward", + "build_feedforward", + "register_feedforward", +] + +# automatically import any Python files in the directory +import_all_modules(str(Path(__file__).parent), "xformers.components.feedforward") diff --git a/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd556bbf19fab0d2be22f2049b00c9ee6701b1a8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__pycache__/base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07559e682c98a01a39d7c43a0cdc17a9b13f1159 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__pycache__/base.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__pycache__/conv_mlp.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__pycache__/conv_mlp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a20e88c77c371620a14632a56f944eb9e4e5af0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__pycache__/conv_mlp.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__pycache__/mixture_of_experts.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__pycache__/mixture_of_experts.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..490a6c7b4d7aa1c87e498f61beae0115608b0b49 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__pycache__/mixture_of_experts.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__pycache__/mlp.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__pycache__/mlp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16572d42a432e23a3f215605193178d5a378708a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/__pycache__/mlp.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/components/feedforward/base.py b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/base.py new file mode 100644 index 0000000000000000000000000000000000000000..76a357cfb7469dee12a1dd70363ea3a90e28412e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/base.py @@ -0,0 +1,55 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +from abc import ABCMeta, abstractmethod +from dataclasses import asdict, dataclass +from typing import Optional, Type, TypeVar + +import torch.nn as nn + +from xformers._deprecation_warning import deprecated_function +from xformers.components import Activation + +Self = TypeVar("Self", bound="Feedforward") + + +@dataclass +class FeedforwardConfig: + name: str + dim_model: int + dropout: float + activation: Activation + + +# Define the common interface, every feedforward block needs to derive from it +class Feedforward(nn.Module, metaclass=ABCMeta): + @abstractmethod + def __init__( + self, + dim_model: Optional[int] = None, + dropout: Optional[float] = None, + activation: Optional[Activation] = None, + *args, + **kwargs, + ): + super().__init__() + deprecated_function(self) + + # This feedforward requires a CUDA accelerator + self.requires_cuda = False + + # This feedforward requires a context length which is squared, often due to 2D pooling + self.requires_squared_context = False + + @classmethod + def from_config(cls: Type[Self], config: FeedforwardConfig) -> Self: + # Generate the class inputs from the config + fields = asdict(config) + + # Skip all Nones so that default values are used + fields = {k: v for k, v in fields.items() if v is not None} + + return cls(**fields) diff --git a/.venv/lib/python3.11/site-packages/xformers/components/feedforward/conv_mlp.py b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/conv_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..895211977d92fcd55c582667475c67f6b3c0b4a4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/conv_mlp.py @@ -0,0 +1,97 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +# CREDITS: Largely reusing the code from the reference VAN implementation +# see https://github.com/Visual-Attention-Network + +import math +from dataclasses import dataclass +from typing import Optional + +import torch.nn as nn + +from xformers.components import Activation, build_activation +from xformers.components.feedforward import Feedforward, FeedforwardConfig + +from . import register_feedforward + + +@dataclass +class ConvMlpConfig(FeedforwardConfig): + hidden_layer_multiplier: int + dim_model: int + dim_model_out: Optional[int] + act_layer: Activation + dropout: float + + +@register_feedforward("Conv2DFeedforward", ConvMlpConfig) +class Conv2DFeedforward(Feedforward): + """ + A Convolutional feed-forward network, as proposed in VAN_ (Vision Attention Network, Guo et al.) + + .. _VAN: https://arxiv.org/pdf/2202.09741.pdf + """ + + def __init__( + self, + dim_model: int, + hidden_layer_multiplier: int = 1, + dim_model_out: Optional[int] = None, + activation: Activation = Activation.GeLU, + dropout=0.0, + *args, + **kwargs, + ): + super().__init__() + out_features = dim_model_out or dim_model + hidden_features = hidden_layer_multiplier * dim_model + + self.conv_mlp = nn.Sequential( + nn.Conv2d(dim_model, hidden_features, 1), + nn.Conv2d( + hidden_features, + hidden_features, + 3, + 1, + 1, + bias=True, + groups=hidden_features, + ), + build_activation(activation), + nn.Conv2d(hidden_features, out_features, 1), + nn.Dropout(dropout), + ) + + # This feedforward requires a context length which is squared, often due to 2D pooling + self.requires_squared_context = True + + def init_weights(self, **kwargs): + # Follow the original init, but also make it possible to initialize from the outside + def init_module(m: nn.Module): + if isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + 
self.apply(init_module) + + def forward(self, x): + # The conv layers expect NCHW, we have NLC by default + B, L, C = x.shape + HW = int(math.sqrt(x.shape[-2])) + assert HW**2 == L, "Conv2DFeedforward requires squared context lengths" + + x = x.reshape((B, HW, HW, C)).swapdims(1, -1) + + # The actual FW, including the 2d convolutions + x = self.conv_mlp(x) + + # back to NLC + x = x.transpose(1, -1) + return x.flatten(1, 2) diff --git a/.venv/lib/python3.11/site-packages/xformers/components/feedforward/mixture_of_experts.py b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/mixture_of_experts.py new file mode 100644 index 0000000000000000000000000000000000000000..b6ab1841f4a69f8500c4e0ac5c55f661e5dee88b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/mixture_of_experts.py @@ -0,0 +1,153 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Optional, Union + +import torch + +from xformers.components import Activation +from xformers.components.feedforward import ( + Feedforward, + FeedforwardConfig, + register_feedforward, +) + +logger = logging.getLogger("xformers") + + +_is_fairscale_available = True + +try: + import torch.distributed as dist + from fairscale.nn import MOELayer, Top2Gate # type: ignore + + from xformers.components.feedforward import MLP + +except ImportError: + logger.warning( + "Either FairScale or torch distributed is not available, MixtureOfExperts will not be exposed." 
+ " Please install them if you would like to use MoE" + ) + _is_fairscale_available = False + + +if _is_fairscale_available: + + # Credits: initially implemented in FairScale for sanity checking + class RoundRobinGate(torch.nn.Module): + def __init__(self, model_dim, num_experts): + super().__init__() + self.model_dim = model_dim + self.num_experts = num_experts + + def forward(self, input): + s = input.shape[0] + assert s % self.num_experts == 0, f"{s} % {self.num_experts} != 0" + capacity = 2 * s // self.num_experts + output = torch.zeros( + s, self.num_experts, capacity, dtype=input.dtype, device=input.device + ) + for i in range(s): + output[i, i % self.num_experts, i // self.num_experts] = 1.0 + return 0.0, output, output.bool() + + class GateConfig(str, Enum): + RoundRobin = "round_robin" + Top2 = "top_2" + # Other gating techniques could be exposed here + + @dataclass + class MoEConfig(FeedforwardConfig): + number_of_experts: int + gate: GateConfig + number_of_local_experts: Optional[int] = None + expert_constructor: Optional[Any] = None + hidden_layer_multiplier: Optional[int] = None + group: Optional[Any] = None + + @register_feedforward("MixtureOfExperts", MoEConfig) + class MixtureOfExperts(Feedforward): + """ + A MLP variant which uses the "Mixture of Experts" paradigm, as described in Gshard_. + xFormers uses the FairScale_ implementation under the hood. + + .. warning: Please note that most of the benefits of MoE are present in a distributed training environmentt + + .. _Gshard: https://arxiv.org/pdf/2006.16668.pdf + .. 
_FairScale: https://github.com/facebookresearch/fairscale/ + """ + + def __init__( + self, + dim_model: int, + dropout: float, + activation: Activation, + number_of_experts: int, + gate: Union[GateConfig, torch.nn.Module], + number_of_local_experts: Optional[int] = None, + expert_constructor: Optional[Callable[[], torch.nn.Module]] = None, + hidden_layer_multiplier: Optional[int] = None, + group: Optional[Any] = None, + *_, + **__, + ): + super().__init__() + + # Handle a possibly uninitialized process group + assert ( + dist.is_initialized() + ), "Mixture of Experts require torch distributed to be initialized" + + if number_of_local_experts is not None: + assert number_of_experts >= number_of_local_experts + else: + if dist.get_world_size() == 1: + logger.warning("Local experts no specified but world size of 1") + logger.warning("Assuming that all experts are local") + number_of_local_experts = number_of_experts + else: + number_of_local_experts = 1 + + # Programatically handle the gating technique + if not isinstance(gate, torch.nn.Module): + gate_constructor = { + GateConfig.RoundRobin: RoundRobinGate, + GateConfig.Top2: Top2Gate, + }[gate] + + self.gate = gate_constructor(dim_model, number_of_experts) + else: + self.gate = gate + + # Programatically handle the experts + if expert_constructor is None: + + multiplier = ( + hidden_layer_multiplier + if hidden_layer_multiplier is not None + else 4 + ) + + def expert_constructor() -> torch.nn.Module: + return MLP(dim_model, dropout, activation, multiplier) + + assert expert_constructor is not None + + local_experts = torch.nn.ModuleList( + [expert_constructor() for _ in range(number_of_local_experts)] + ) + + self.moe = MOELayer(gate=self.gate, experts=local_experts, group=group) + + self.requires_cuda = True + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + # FairScale MoE assumes that the dimensions are [S, B, E] + # xFormers assumes [B, S, E] + return self.moe(inputs.movedim(0, 1)).movedim(0, 1) diff 
--git a/.venv/lib/python3.11/site-packages/xformers/components/feedforward/mlp.py b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..fefb328682919ddff5bfd3293a411769763e2d6b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/components/feedforward/mlp.py @@ -0,0 +1,47 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from xformers.components import Activation, build_activation +from xformers.components.feedforward import Feedforward, FeedforwardConfig + +from . import register_feedforward + + +@dataclass +class MlpConfig(FeedforwardConfig): + hidden_layer_multiplier: int + bias: bool + + +@register_feedforward("MLP", MlpConfig) +class MLP(Feedforward): + def __init__( + self, + dim_model: int, + dropout: float, + activation: Activation, + hidden_layer_multiplier: int, + bias: bool = True, + *args, + **kwargs, + ): + super().__init__() + dim_mlp = hidden_layer_multiplier * dim_model + self.mlp = nn.Sequential( + nn.Linear(in_features=dim_model, out_features=dim_mlp, bias=bias), + build_activation(activation), + nn.Dropout(dropout), + nn.Linear(in_features=dim_mlp, out_features=dim_model, bias=bias), + nn.Dropout(dropout), + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return self.mlp(inputs) diff --git a/.venv/lib/python3.11/site-packages/xformers/profiler/__init__.py b/.venv/lib/python3.11/site-packages/xformers/profiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fe9d0a492624968258c8d5b9180658e3b2fb113b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/profiler/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from .api import profile, step +from .profiler import MemSnapshotsProfiler, NsightProfiler, PyTorchProfiler + +__all__ = [ + "profile", + "step", + "MemSnapshotsProfiler", + "PyTorchProfiler", + "NsightProfiler", +] diff --git a/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a045cb79d72960d4f6a1194be1f8dca6195da6c8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/api.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/api.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4db034b3422e1d44f5d007b2d655bc3718bcc083 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/api.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/device_limits.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/device_limits.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4befcb07fe669e2770f3fe24769d485e73c47080 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/device_limits.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/find_slowest.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/find_slowest.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2c2602e0caf3cf4b14a8cf41083bc25b4e36723 Binary files /dev/null 
and b/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/find_slowest.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/profile_analyzer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/profile_analyzer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d014481b5ecf05f6c44b4de7f5c15e5986482fb9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/profile_analyzer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/profiler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/profiler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be74f7a24718667a009baef9ffab6d083071366c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/profiler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/profiler_dcgm.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/profiler_dcgm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca62f158a35acd0c6ce16914591c10bf16fa1eda Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/profiler_dcgm.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/profiler_dcgm_impl.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/profiler_dcgm_impl.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e84c96edf2c1b1002a60efbf25d7a07b1dd4f773 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/profiler/__pycache__/profiler_dcgm_impl.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/profiler/api.py 
b/.venv/lib/python3.11/site-packages/xformers/profiler/api.py new file mode 100644 index 0000000000000000000000000000000000000000..02722dc4ac6f6d1db1121e0334a664881b9c70a2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/profiler/api.py @@ -0,0 +1,92 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Optional, Sequence, Tuple + +import torch.nn as nn + +from .profiler import ( + MemSnapshotsProfiler, + NsightProfiler, + PyTorchProfiler, + PyTorchProfiler_CUDAOnly, + _Profiler, +) +from .profiler_dcgm import DCGMProfiler # noqa: F401 + +DEFAULT_SCHEDULE = ( + (MemSnapshotsProfiler, 0, 2), + (NsightProfiler, 4, 6), + (PyTorchProfiler, 6, 7), + (PyTorchProfiler_CUDAOnly, 7, 8), + # TODO: Found issues where this can take minutes to + # start, as it flushes previous values + # (DCGMProfiler, 9, 11), +) + + +def profile( + output_dir: str, + module: Optional[nn.Module] = None, + schedule: Sequence[Tuple[Any, int, int]] = DEFAULT_SCHEDULE, +): + """ + A pre-configured profiler that will run on the first ~20 steps of the training + It will provide multiple traces that can be exploited later. + Use it in a context manager around your training loop, and call `xformers.profiler.step` + before starting the next iteration. + + :Examples: + + .. 
code-block:: python + + import torch + import timm.models + import xformers.profiler + + dtype = torch.bfloat16 + device = "cuda" + model = timm.models.vit_large_patch16_224().to(device).to(dtype) + inp = torch.zeros([64, 3, 224, 224], device=device, dtype=dtype) + optim = torch.optim.Adam(model.parameters()) + + with xformers.profiler.profile( + output_dir="profile_data", + module=model, + schedule=[ + (MemSnapshotsProfiler, 0, 2), + (NsightProfiler, 4, 6), + (PyTorchProfiler, 6, 20), + ] + ): + for i in range(20): + model(inp).sum().backward() + optim.step() + optim.zero_grad() + xformers.profiler.step() + + # alternatively, use the profiler without context and with ``.start()`` / `.stop()` + # calls. + + xprofiler = xformers.profiler.profile(...) + xprofiler.start() + + for i in range(20): + model(inp).sum().backward() + optim.step() + optim.zero_grad() + xprofiler.step() + + xprofiler.stop() + """ + return _Profiler(output_dir=output_dir, schedule=schedule, module=module) + + +def step() -> None: + """See `xformers.profiler.profile`""" + # Silently return if no profiler is enabled + if _Profiler._CURRENT_PROFILER is None: + return + _Profiler._CURRENT_PROFILER.step() diff --git a/.venv/lib/python3.11/site-packages/xformers/profiler/device_limits.py b/.venv/lib/python3.11/site-packages/xformers/profiler/device_limits.py new file mode 100644 index 0000000000000000000000000000000000000000..7d35022303733889fb4280cfae1f2f493559d5a2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/profiler/device_limits.py @@ -0,0 +1,113 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math +from dataclasses import dataclass, field +from typing import Mapping, Optional, Tuple + +import torch + + +@dataclass +class DeviceLimit: + name: str = "default" # pattern to match from `torch.cuda.get_device_name()` + source: str = "" + sm: Tuple[int, int] = (0, 0) + # bytes/s + gmem_bandwidth: float = math.inf + # dtype -> TFlop/s + gemm_tflops: Mapping[torch.dtype, float] = field(default_factory=dict) + + +# For f32, we assume we can use tf32 +DEVICE_LIMITS: Tuple[DeviceLimit, ...] = ( + DeviceLimit( + "H100", + "https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet", # noqa: E501 + sm=(9, 0), + gmem_bandwidth=3.35 * (1024**4), # NOTE: PCIe is 2 TB/s + gemm_tflops={ + torch.float64: 67, + # NOTE: NVIDIA gives all numbers "with 2:4 sparsity" + # but we want the full GEMM numbers + torch.float32: 989 // 2, + torch.float16: 1979 // 2, + torch.bfloat16: 1979 // 2, + torch.int8: 3958 // 2, + }, + ), + DeviceLimit( + "A100", + "https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf", # noqa: E501 + sm=(8, 0), + gmem_bandwidth=2 * (1024**4), # NOTE: PCIe is 1.5 TB/s + gemm_tflops={ + torch.float64: 19.5, + torch.float32: 156, + torch.float16: 312, + torch.bfloat16: 312, + torch.int8: 624, + }, + ), + DeviceLimit( + "A30", + "https://www.nvidia.com/content/dam/en-zz/Solutions/data-center/products/a30-gpu/pdf/a30-datasheet.pdf", + sm=(8, 0), + gmem_bandwidth=933 * (1024**3), + gemm_tflops={ + torch.float64: 10.3, + torch.float32: 82, + torch.float16: 165, + torch.bfloat16: 165, + torch.int8: 330, + }, + ), + DeviceLimit( + "T4", + "https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf", + sm=(7, 5), + gmem_bandwidth=300 * (1024**3), + gemm_tflops={ + torch.float32: 8.1, + torch.float16: 65, + torch.int8: 130, + }, + ), + # Assuming SXM2 + DeviceLimit( + "V100", + 
"https://images.nvidia.com/content/technologies/volta/pdf/tesla-volta-v100-datasheet-letter-fnl-web.pdf", + sm=(7, 0), + gmem_bandwidth=900 * (1024**3), + gemm_tflops={ + torch.float64: 7.8, + torch.float32: 15.7, + torch.float16: 125, + }, + ), + DeviceLimit( + "P100", + "https://images.nvidia.com/content/tesla/pdf/nvidia-tesla-p100-datasheet.pdf", + sm=(6, 0), + gmem_bandwidth=732 * (1024**3), + gemm_tflops={ + torch.float64: 5.3, + torch.float32: 10.6, + torch.float16: 21.2, + }, + ), +) + + +def get_device_limits(device) -> Optional[DeviceLimit]: + """Currently only implemented for GPUs""" + if device is not None and device.type == "cuda": + device_sm = torch.cuda.get_device_capability(device) + device_name = torch.cuda.get_device_name(device) + for lim in DEVICE_LIMITS: + if lim.sm == device_sm: + if lim.name in device_name: + return lim + return None diff --git a/.venv/lib/python3.11/site-packages/xformers/profiler/find_slowest.py b/.venv/lib/python3.11/site-packages/xformers/profiler/find_slowest.py new file mode 100644 index 0000000000000000000000000000000000000000..e1eee735f4a8d8e56cac92d50c942f43e0b99873 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/profiler/find_slowest.py @@ -0,0 +1,180 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import glob +import gzip +import json +import os +import sys +from collections import defaultdict +from typing import Dict, List + +import numpy as np + + +def read_gzipped_json(file_path): + with gzip.open(file_path, "rt") as f: + return json.load(f) + + +# Extract detailed event durations +def extract_detailed_info(log_data): + events = [] + rank = log_data["distributedInfo"]["rank"] + + for event in log_data["traceEvents"]: + if "name" in event and "dur" in event: + events.append( + { + "name": event["name"], + "start_time": event["ts"], + "duration_ms": (event["dur"] / 1000), # convert to milliseconds + "rank": f"GPU {rank}", + "trace_name": log_data["traceName"], + "cat": event["cat"], + "log_name": log_data["log_name"], + } + ) + return events + + +def print_json_as_dataframe(json_list): + if not json_list: + print("Empty list") + return + + # Extract the headers from the keys of the first dictionary + headers = list(json_list[0].keys()) + + # Determine the width of each column + col_widths = {header: max(len(header), 10) for header in headers} + for row in json_list: + for header in headers: + col_widths[header] = max(col_widths[header], len(str(row[header]))) + + # Create the header row + header_row = " ".join(f"{header:<{col_widths[header]}}" for header in headers) + print(header_row) + print("-" * len(header_row)) + + # Create each data row + for row in json_list: + data_row = " ".join( + f"{str(row[header]):<{col_widths[header]}}" for header in headers + ) + print(data_row) + + +def compute_std_dev_of_event_durations_over_ranks(events, top=5): + # Step 1: Group by 'rank' and 'kernel' and sum the 'duration_ms' + grouped_data: defaultdict[str, defaultdict[str, float]] = defaultdict( + lambda: defaultdict(float) + ) + for event in events: + grouped_data[event["name"]][event["rank"]] += event["duration_ms"] + + # Step 2: Calculate the standard deviation across ranks for each kernel + std_devs = [] + for name, ranks in 
grouped_data.items(): + durations = np.array(list(ranks.values())) + std_dev = np.std(durations, ddof=1) + std_devs.append({"name": name, "std_dev": std_dev}) + + # Step 3: Sort by standard deviation in descending order + std_devs.sort(key=lambda x: x["std_dev"], reverse=True) + for r in std_devs: + r["std_dev"] = f"{r['std_dev']:.2f} ms" + + return std_devs[:top] + + +def sort_nccl_events( + nccl_events, top_k: int = 3, last_k: int = 3 +) -> List[Dict[str, str]]: + # Step 1: Group by 'log_name' and sum the 'duration_ms' + grouped_data: Dict[str, float] = defaultdict(float) + for event in nccl_events: + key = event["log_name"] + grouped_data[key] += event["duration_ms"] + + # Step 2: Create a sorted list of tuples by 'duration_ms' in descending order + sorted_list = sorted(grouped_data.items(), key=lambda x: x[1], reverse=True) + + # Step 3: Format the sorted list + formatted_list: List[Dict[str, str]] = [ + {"log_name": log_name, "nccl_ms": f"{duration:.2f} ms"} + for log_name, duration in sorted_list + ] + + # Step 4: Get top_k and last_k items + top_k_list = formatted_list[:top_k] + last_k_list = formatted_list[-last_k:] + + return top_k_list + last_k_list + + +def print_profiling_info(cuda_profile_dir: str): + has_json_gz_files = None + + cuda_profile_path_name = f"{cuda_profile_dir}/*trace.json.gz" + + profile_files = glob.glob(cuda_profile_path_name) + + if len(profile_files) == 0: + cuda_profile_path_name = f"{cuda_profile_dir}/*.json" + + profile_files = glob.glob(cuda_profile_path_name) + has_json_gz_files = False + else: + has_json_gz_files = True + + if len(profile_files) == 0: + raise Exception( + f"Couldnt find any profiling trace in the specified directory: {cuda_profile_dir}" + ) + + # Extract detailed NCCL event durations for all logs + events_details = [] + total_files = len(profile_files) + for index, profile_trace_path in enumerate(profile_files): + print(f"Processing file {index + 1}/{total_files}", end="\r") + sys.stdout.flush() + + if 
has_json_gz_files: + log_data = read_gzipped_json(profile_trace_path) + else: + with open(profile_trace_path, "r") as f: + log_data = json.loads(f.read()) + + log_data["log_name"] = os.path.basename(profile_trace_path) + events_details.extend(extract_detailed_info(log_data)) + print() + + kernel_events = [e for e in events_details if e["cat"] == "kernel"] + communication_kernels = [e for e in kernel_events if "nccl" in e["name"]] + computation_kernels = [e for e in kernel_events if "nccl" not in e["name"]] + + print("The longest and shortest communication_kernels:") + print_json_as_dataframe(sort_nccl_events(communication_kernels)) + print("\n\n") + + std_df = compute_std_dev_of_event_durations_over_ranks(communication_kernels) + print("The standard deviation of nccl kernels durations across ranks:") + print_json_as_dataframe(std_df) + print("\n\n") + + std_df = compute_std_dev_of_event_durations_over_ranks(computation_kernels) + print("The standard deviation of computation kernels durations across ranks:") + print_json_as_dataframe(std_df) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Process CUDA profile directory.") + parser.add_argument("cuda_profile_dir", type=str, help="The CUDA profile directory") + + args = parser.parse_args() + + print_profiling_info(args.cuda_profile_dir) diff --git a/.venv/lib/python3.11/site-packages/xformers/profiler/profile_analyzer.py b/.venv/lib/python3.11/site-packages/xformers/profiler/profile_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..baf4e5f4580402cedec70118f6988c7ac169f6d5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/profiler/profile_analyzer.py @@ -0,0 +1,235 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Sequence, cast + +import torch + + +class FakeKinetoEvent: + def __init__(self, e: torch._C._autograd._KinetoEvent) -> None: + for attr in dir(e): + if attr.startswith("_"): + continue + setattr(self, attr, getattr(e, attr)) + self._kineto_event = e + + +def _attention_flops(queries, values, causal: bool, fmt: str = "BHMK") -> int: + assert isinstance(causal, bool) + assert fmt in ["BMHK", "BHMK"] + if fmt == "BMHK": + queries, values = [[x[0], x[2], x[1], x[3]] for x in [queries, values]] + *B, N, K = queries + *B, Nv, Kv = values + if causal: # NOTE: Causal from bottom right + # non-causal part + flops = 2 * N * max(Nv - N, 0) * K + 2 * max(Nv - N, 0) * max(Nv - N, 0) * Kv + # causal part + flops += ( + 2 * min(N, Nv) * min(N, Nv) * K + 2 * min(N, Nv) * min(N, Nv) * Kv + ) // 2 + else: + flops = 2 * N * Nv * K + 2 * N * Nv * Kv + for b in B: + flops *= b + return int(flops) + + +def _get_arg_idx(op, *arg_names: str) -> int: + for i, arg in enumerate(op.default._schema.arguments): + if arg.name in arg_names: + return i + raise ValueError(f"No such argument {arg_names} found in {op.default._schema}") + + +def _replace_if_needed( + e: torch._C._autograd._KinetoEvent, +) -> torch._C._autograd._KinetoEvent: + """ + Adds a flops amount for operators that don't have this information in Kineto already + This mostly applies for the attention for now, as GEMMs are already calculated by Kineto + and other operations are negligible. 
+ """ + if e.device_type().name != "CPU": + return e + op_name = e.name() + flops = None + + FMT_BMHK = dict(fmt="BMHK") + ATTN_OPS = { + getattr(lib, op).default.name(): (getattr(lib, op), is_bwd, kwargs) + for lib, op, is_bwd, kwargs in [ + (torch.ops.aten, "scaled_dot_product_attention", False, {}), + (torch.ops.xformers_flash, "flash_fwd", False, FMT_BMHK), + ( + torch.ops.xformers, + "efficient_attention_forward_cutlass", + False, + FMT_BMHK, + ), + (torch.ops.aten, "_efficient_attention_forward", False, FMT_BMHK), + (torch.ops.aten, "_scaled_dot_product_flash_attention_backward", True, {}), + ( + torch.ops.aten, + "_scaled_dot_product_efficient_attention_backward", + True, + {}, + ), + (torch.ops.xformers_flash, "flash_bwd", True, FMT_BMHK), + ( + torch.ops.xformers, + "efficient_attention_backward_cutlass", + True, + FMT_BMHK, + ), + (torch.ops.aten, "_efficient_attention_backward", True, FMT_BMHK), + (torch.ops.aten, "_scaled_dot_product_cudnn_attention_backward", True, {}), + ] + if hasattr(lib, op) + } + if op_name in ATTN_OPS.keys(): + op, is_bwd, kwargs = ATTN_OPS[op_name] + shapes = e.shapes() + concrete_inputs = e.concrete_inputs() + try: + is_causal = concrete_inputs[_get_arg_idx(op, "causal", "is_causal")] + except ValueError: + is_causal = concrete_inputs[_get_arg_idx(op, "custom_mask_type")] != 0 + flops = _attention_flops( + shapes[_get_arg_idx(op, "query")], + shapes[_get_arg_idx(op, "value")], + is_causal, + **kwargs, + ) + if is_bwd: + flops = flops * 5 // 2 + if flops is not None: + new_e = FakeKinetoEvent(e) + new_e.flops = lambda: flops # type: ignore + e = cast(torch._C._autograd._KinetoEvent, new_e) + return e + + +@dataclass +class AnalyzedTrace: + operations_per_dtype_fw: Dict[torch.dtype, float] + operations_per_dtype_bw: Dict[torch.dtype, float] + total_time_s: float + + def compute_num_ops( + self, dtype: torch.dtype, fw: bool = True, bw: bool = True + ) -> float: + ops = 0.0 + if fw: + ops += self.operations_per_dtype_fw.get(dtype, 
0.0) + if bw: + ops += self.operations_per_dtype_bw.get(dtype, 0.0) + return ops + + def compute_hfu(self, hardware_flops: Dict[torch.dtype, float]) -> float: + hfu_seconds = 0.0 + for dtype, hw_flops in hardware_flops.items(): + hfu_seconds += self.compute_num_ops(dtype) / hw_flops + return hfu_seconds / self.total_time_s + + def compute_mfu(self, hardware_flops: Dict[torch.dtype, float]) -> float: + # Estimated by considering the bw flops should be exactly 2x the fw flops + # The reason MFU!=HFU is because of recomputation in the BW pass + hfu_seconds = 0.0 + for dtype, hw_flops in hardware_flops.items(): + hfu_seconds += ( + min( + 3 * self.compute_num_ops(dtype, bw=False), + self.compute_num_ops(dtype), + ) + / hw_flops + ) + return hfu_seconds / self.total_time_s + + @staticmethod + def _find_all_root_events_with_flops( + all_events: Sequence[torch._C._autograd._KinetoEvent], + ) -> Sequence[torch._C._autograd._KinetoEvent]: + # Filters-out non-dispatch ops + # Or operations without flop counted + all_ops_with_flops = [ + e + for e in all_events + if ( + e.device_type().name == "CPU" + and (e.dtypes() or e.shapes()) + and e.flops() > 0 + ) + ] + events_per_group: Dict[ + Any, List[torch._C._autograd._KinetoEvent] + ] = defaultdict(list) + for e in all_ops_with_flops: + events_per_group[(e.start_thread_id(), e.device_type())].append(e) + root_events: List[torch._C._autograd._KinetoEvent] = [] + for events in events_per_group.values(): + # We assume that 2 events are either non-overlapping, + # or one is contained entirely within the other + events.sort(key=lambda e: (e.start_ns(), -e.duration_ns())) + current_root: Optional[torch._C._autograd._KinetoEvent] = None + for e in events: + if ( + current_root is None + or e.start_ns() + > current_root.start_ns() + current_root.duration_ns() + ): + current_root = e + root_events.append(e) + return root_events + + @staticmethod + def from_profile( + events: Sequence[torch._C._autograd._KinetoEvent], + ) -> 
"AnalyzedTrace": + events = [_replace_if_needed(e) for e in events] + root_ops = AnalyzedTrace._find_all_root_events_with_flops(events) + + operations_per_dtype_fw: Dict[torch.dtype, float] = defaultdict(float) + operations_per_dtype_bw: Dict[torch.dtype, float] = defaultdict(float) + # We detect BW pass ops based on their thread id + all_bw_threads = {e.start_thread_id() for e in events if e.fwd_thread_id() > 0} + # Find total dt + ATEN_DTYPES = [ + # NOTE: A single torch.dtype per number of bits + # (eg so we map bf16 --> b16) + ("double", torch.float64), + ("float", torch.float), + ("c10::Half", torch.float16), + ("c10::BFloat16", torch.float16), + ("c10::Int8", torch.int8), + ] + begin_ns, end_ns = math.inf, 0 + for op in root_ops: + dtype = None + for aten_dtype, torch_dtype in ATEN_DTYPES: + if aten_dtype in op.dtypes(): + dtype = torch_dtype + break + if dtype is None: # ??? + continue + if op.start_thread_id() in all_bw_threads: + operations_per_dtype_bw[dtype] += op.flops() + else: + operations_per_dtype_fw[dtype] += op.flops() + for op in events: + if op.device_type().name != "CUDA": + continue + begin_ns = min(begin_ns, op.start_ns()) + end_ns = max(end_ns, op.start_ns() + op.duration_ns()) + + return AnalyzedTrace( + operations_per_dtype_fw=operations_per_dtype_fw, + operations_per_dtype_bw=operations_per_dtype_bw, + total_time_s=(end_ns - begin_ns) / (10**9), + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/profiler/profiler.py b/.venv/lib/python3.11/site-packages/xformers/profiler/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..e50ab9284da1e7678109ac087e12e26118e40e20 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/profiler/profiler.py @@ -0,0 +1,350 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
import logging
import os
import queue
import socket
import time
import weakref
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

import torch.cuda.memory
import torch.cuda.nvtx
import torch.nn as nn
import torch.profiler

from .device_limits import get_device_limits
from .profile_analyzer import AnalyzedTrace

logger = logging.getLogger(__name__)


class NsightProfiler:
    """Profiler that triggers start of NSight profiler.

    NOTE: you need to ensure that the script running this code actually is running with
    ``nsys profile`` and also has a flag ``--capture-range=cudaProfilerApi`` so the
    capturing is performed by this profiler during certain steps.
    """

    def __init__(self, main_profiler: "_Profiler") -> None:
        self.main_profiler = main_profiler
        # TODO figure out if there is a way to know if nsys is launched at this point

    def __enter__(self):
        torch.cuda.profiler.start()

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.cuda.profiler.stop()

    def step(self) -> None:
        """No per-step work: capture is delimited by ``__enter__``/``__exit__``."""
        pass


class PyTorchProfiler:
    """Profiler which relies on native Pytorch profiling. Current setting of the profiler
    captures traces, memory footprint and other info that could be read via TensorBoard.
    """

    ACTIVITIES = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]

    def __init__(self, main_profiler: "_Profiler") -> None:
        self.main_profiler = main_profiler
        # Number of `step()` calls seen so far; used to normalize per-step metrics.
        self.num_steps = 0
        self.pytorch_profiler = torch.profiler.profile(
            on_trace_ready=self._on_trace,
            profile_memory=True,
            record_shapes=True,
            with_stack=True,
            with_flops=True,
            activities=self.ACTIVITIES,
        )

    def _on_trace(self, prof: torch.profiler.profiler.profile) -> None:
        """Export the chrome trace to disk, then try to summarize it.

        The trace is written to
        ``<output_dir>/profile_<activities>_<done_steps>/<worker>.<time_ns>.pt.trace.json.gz``.
        Any failure of the analysis step is logged and recorded in the summary
        instead of being raised: the exported trace is the primary output.
        """
        activities_str = "_".join(a.name for a in self.ACTIVITIES)
        dir_name = str(
            self.main_profiler.output_dir
            / f"profile_{activities_str}_{self.main_profiler.done_steps:06}"
        )
        worker_name = self.main_profiler.worker_name
        if worker_name == "":
            worker_name = f"{socket.gethostname()}_{os.getpid()}"
        os.makedirs(dir_name, exist_ok=True)
        file_name = f"{worker_name}.{time.time_ns()}.pt.trace.json.gz"
        prof.export_chrome_trace(os.path.join(dir_name, file_name))
        try:
            self._analyze_trace(prof)
        except Exception as exc:
            self.main_profiler.summary.append(("TraceAnalysis", "Error"))
            # `Logger.warn` is a deprecated alias of `Logger.warning`.
            logger.warning("Exception analyzing kineto trace", exc_info=exc)

    def _analyze_trace(self, prof: torch.profiler.profiler.profile) -> None:
        """Append step time, flop counts, HFU and MFU to the main profiler summary.

        Does nothing when the profiler holds no kineto results.
        """
        if prof.profiler is None or prof.profiler.kineto_results is None:
            return
        results = AnalyzedTrace.from_profile(prof.profiler.kineto_results.events())
        limits = get_device_limits(torch.device("cuda"))
        hw_flops: Dict[torch.dtype, float] = {}
        if limits is not None:
            # Device limits are expressed in TFlops; convert to flop/s.
            for dtype, tflops in limits.gemm_tflops.items():
                hw_flops[dtype] = tflops * (1000**4)
        total_hfu = results.compute_hfu(hw_flops)
        total_mfu = results.compute_mfu(hw_flops)
        # Count every dtype seen in either pass: a dtype that only shows up in
        # the backward pass must not be dropped from the total.
        all_dtypes = set(results.operations_per_dtype_fw) | set(
            results.operations_per_dtype_bw
        )
        total_flop = sum(results.compute_num_ops(dtype) for dtype in all_dtypes)
        # The trace can be finalized before any `step()` call (e.g. the outer
        # profiler exits early); treat it as a single step rather than dividing
        # by zero.
        num_steps = max(self.num_steps, 1)
        s = self.main_profiler.summary
        s.append(("Step time (ms)", f"{int(results.total_time_s * 1000 / num_steps)}"))
        s.append(("TFlop/step", f"{total_flop / (num_steps * 1000**4):0.1f}"))
        s.append(("TFlops", f"{total_flop / (results.total_time_s * 1000**4):0.1f}"))
        s.append(("HFU", f"{total_hfu:0.3f}"))
        s.append(("MFU", f"{total_mfu:0.3f}"))

    def __enter__(self):
        # Synchronize so that work queued before profiling is not attributed to it.
        torch.cuda.synchronize()
        self.pytorch_profiler.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Synchronize so that all profiled kernels have completed before stopping.
        torch.cuda.synchronize()
        self.pytorch_profiler.__exit__(exc_type, exc_val, exc_tb)

    def step(self) -> None:
        self.pytorch_profiler.step()
        self.num_steps += 1


class PyTorchProfiler_CUDAOnly(PyTorchProfiler):
    # This profiler does not profile the CPU-side of things
    # so we expect it to have almost no overhead
    ACTIVITIES = [torch.profiler.ProfilerActivity.CUDA]

    def _analyze_trace(self, prof: torch.profiler.profiler.profile) -> None:
        # Can't analyze trace without CPU trace for operator shapes etc...
        pass


class MemSnapshotsProfiler:
    """Profiler that captures memory traces for allocation and deallocation of memory for
    tensors.
    """

    def __init__(self, main_profiler: "_Profiler") -> None:
        self.main_profiler = main_profiler
        self.enabled = False

    @property
    def _has_trace_plot(self) -> bool:
        """Whether this PyTorch build exposes ``torch.cuda._memory_viz.trace_plot``."""
        return hasattr(torch.cuda._memory_viz, "trace_plot")

    def __enter__(self):
        if not self._has_trace_plot:
            return
        self.enabled = True
        # TODO: This does not show the previous memory allocations
        # We could at least have a placeholder with how much
        # memory was allocated before
        torch.cuda.memory._record_memory_history(
            True,
            # keep 100,000 alloc/free events from before the snapshot
            trace_alloc_max_entries=100000,
            # record stack information for the trace events
            trace_alloc_record_context=True,
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self._has_trace_plot:
            self.main_profiler.summary.append(
                ("MemTrace", "(not available with your Pytorch version)")
            )
            return
        assert self.enabled
        snapshot = torch.cuda.memory._snapshot()
        torch.cuda.memory._record_memory_history(False)
        # No data was recorded - avoids a `ValueError` in `trace_plot`
        if all(len(t) == 0 for t in snapshot["device_traces"]):
            self.main_profiler.summary.append(("MemTrace", "(no allocation recorded)"))
            return
        # Dump to disk
        filename = self.main_profiler._create_output_filename("memory_trace_plot.html")
        self.main_profiler.summary.append(("MemTrace", filename))
        with open(filename, "w+") as fd:
            fd.write(
                torch.cuda._memory_viz.trace_plot(
                    snapshot, device=None, plot_segments=False
                )
            )

    def step(self) -> None:
        pass


@dataclass
class _ProfilerState:
    # Profiler class to instantiate (called as `cls(main_profiler)`).
    cls: Any
    # First step (inclusive) at which the profiler is active.
    iter_begin: int
    # Step (exclusive) at which the profiler is shut down.
    iter_end: int
    # Live profiler instance while active, `None` otherwise.
    object: Any = None


class _Profiler:
    """Runs a schedule of child profilers, each active on a range of steps.

    Only one instance can be active (entered) at a time.
    """

    _CURRENT_PROFILER = None

    def __init__(
        self,
        output_dir: str,
        schedule: Sequence[Tuple[Any, int, int]],
        module: Optional[nn.Module],
    ) -> None:
        """
        Args:
            output_dir: Directory where profilers write their outputs (created
                if missing).
            schedule: ``(profiler_class, begin_step, end_step)`` entries; the
                ranges must not overlap.
            module: Optional module being profiled; only a weak reference is kept.
        """
        self.check_schedule(schedule)
        self.schedule = schedule
        self.done_steps = 0
        self.output_dir = Path(output_dir).absolute()
        self.output_dir.mkdir(exist_ok=True, parents=True)
        self.worker_name = ""
        if torch.distributed.is_initialized():
            self.worker_name = "{}_{}".format(socket.gethostname(), str(os.getpid()))

        self.module = weakref.ref(module if module is not None else nn.Module())
        self.init_schedule()

    def init_schedule(self, offset: int = 0) -> None:
        """(Re)build the profiler states from the schedule, shifted by ``offset``
        steps, and reset the summary."""
        self.profilers: List[_ProfilerState] = sorted(
            [
                _ProfilerState(cls, begin + offset, end + offset)
                for cls, begin, end in self.schedule
            ],
            key=lambda x: x.iter_begin,
        )
        self.last_step = self.profilers[-1].iter_end if self.profilers else 0
        self.summary: List[Tuple[str, str]] = []

    def check_schedule(self, schedule: Sequence[Tuple[Any, int, int]]) -> None:
        """Validate a schedule.

        Logs a warning when the schedule is empty.

        Raises:
            AssertionError: if an entry has a negative begin, a non-positive
                end, ``begin >= end``, or if two entries overlap.
        """
        if len(schedule) == 0:
            logger.warning(
                "You specified empty schedule for profiling. No data will be captured."
            )

        intervals: List[Tuple[int, int]] = []
        for cls, begin, end in schedule:
            assert (
                begin >= 0
            ), f"Begin step of profiler must be non-negative, found: {begin}"
            assert end > 0, f"End step of profiler must be positive, found: {end}"
            assert (
                begin < end
            ), f"Start must be before the end, found: begin={begin} and end={end}"

            intervals.append((begin, end))

        # The overlap check needs the intervals in increasing order of `begin`.
        # Iterating a PriorityQueue's internal list only yields heap order, which
        # can wrongly flag valid schedules, hence the explicit sort.
        prev_end = -1
        for begin, end in sorted(intervals):
            assert begin >= prev_end, (
                "There is some overlapping in profiler scheduling. Please do not"
                + " overlap profilers by step as they may affect each other. Schedule:"
                + f" {schedule}"
            )
            prev_end = end

    def update_profilers_on_step(self) -> None:
        """Start, step or shut down each child profiler for the current step."""
        for p in self.profilers:
            if p.iter_begin <= self.done_steps and self.done_steps < p.iter_end:
                if p.object is None:
                    o = p.cls(self)
                    logger.info("Starting %s profiler...", p.cls.__name__)
                    o.__enter__()
                    p.object = o
                else:
                    p.object.step()
            else:
                if p.object is not None:
                    o = p.object
                    p.object = None
                    logger.info("Shutting down %s profiler...", p.cls.__name__)
                    # Make sure the profiler's `step` function is called
                    # $N times when we do $N steps with this profiler.
                    o.step()
                    o.__exit__(None, None, None)

    def _create_output_filename(self, filename: str) -> Path:
        """
        Returns where to write a file with desired filename.
        Handles the case where we are in distributed settings, or when
        we need to output the same file multiple times (eg if a profiler
        runs for several steps)
        """
        if self.worker_name != "":
            file = Path(filename)
            folder = self.output_dir / file.stem
            folder.mkdir(parents=True, exist_ok=True)
            return folder / f"{self.done_steps:06}_{self.worker_name}{file.suffix}"
        return self.output_dir / f"{self.done_steps:06}_{filename}"

    def start(self):
        self.__enter__()

    def stop(self, exc_type=None, exc_val=None, exc_tb=None):
        self.__exit__(exc_type, exc_val, exc_tb)

    def __enter__(self):
        if _Profiler._CURRENT_PROFILER is not None:
            raise ValueError("Only one xformers profiler can be active at a time")
        _Profiler._CURRENT_PROFILER = self
        self.update_profilers_on_step()

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _Profiler._CURRENT_PROFILER = None

        for p in self.profilers:
            if p.object is not None:
                # Clear the slot first so an already-closed profiler is never
                # stepped or exited a second time if this profiler is re-entered.
                o, p.object = p.object, None
                o.__exit__(exc_type, exc_val, exc_tb)

    def step(self) -> None:
        """Signals the profiler that the next profiling step has started."""
        self.done_steps += 1

        if self.done_steps <= self.last_step:
            self.update_profilers_on_step()
        if self.done_steps == self.last_step:
            logger.info("xFormers profiler done. %s", self.format_summary())

        # Check if we triggered a manual profile step
        CHECK_TRIGGER_EVERY = 10
        if (
            self.done_steps > self.last_step
            and (self.done_steps % CHECK_TRIGGER_EVERY) == 0
        ):
            try:
                # A `trigger` file requests a new profiling round: re-arm it as
                # `trigger.<step>` for the next check point.
                (self.output_dir / "trigger").unlink()
                (
                    self.output_dir
                    / f"trigger.{self.done_steps + CHECK_TRIGGER_EVERY:09}"
                ).write_text(self.worker_name)
            except FileNotFoundError:
                pass
            # NOTE(review): nesting of this check inside the periodic branch is
            # reconstructed from a whitespace-collapsed source - confirm against
            # upstream. Re-armed triggers always target a check-point step.
            step_trigger = self.output_dir / f"trigger.{self.done_steps:09}"
            if step_trigger.exists():
                logger.info(
                    "xFormers profiler manually triggered at step %d", self.done_steps
                )
                self.init_schedule(offset=self.done_steps + 1)

    def format_summary(self) -> str:
        """Return the summary as an aligned ``title: value`` listing ("" if empty)."""
        if len(self.summary) == 0:
            return ""
        pad_titles = max(len(title) for title, value in self.summary)
        return "summary:\n" + "\n".join(
            [f"  {title.ljust(pad_titles)}: {value}" for title, value in self.summary]
        )


# Header of the next file in this diff, preserved from the collapsed source line:
# diff --git a/.venv/lib/python3.11/site-packages/xformers/profiler/profiler_dcgm.py b/.venv/lib/python3.11/site-packages/xformers/profiler/profiler_dcgm.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..389cadd2ca13b8f102b22d35310bdfbebd64dbe5
# --- /dev/null
# +++ b/.venv/lib/python3.11/site-packages/xformers/profiler/profiler_dcgm.py
# @@ -0,0 +1,45 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import sys

from .profiler import _Profiler, logger

# Location where the DCGM python bindings are expected to be installed.
DCGM_PYTHON_PATH: str = "/usr/local/dcgm/bindings/python3"
DCGM_PROFILER_AVAILABLE = False

# Temporarily expose the DCGM bindings so the real implementation can import them.
sys.path.insert(0, DCGM_PYTHON_PATH)
try:
    from .profiler_dcgm_impl import DCGMProfiler

    DCGM_PROFILER_AVAILABLE = True
except ModuleNotFoundError:

    class DCGMProfiler:  # type: ignore
        """The dummy DCGM Profiler."""

        def __init__(
            self,
            main_profiler: "_Profiler",
            gpus_to_profile=None,
            field_ids_to_profile=None,
            updateFreq=None,
        ) -> None:
            pass

        def __enter__(self) -> None:
            logger.warning(
                f"Unable to find python bindings at {DCGM_PYTHON_PATH}. "
                "No data will be captured."
            )

        def __exit__(self, exc_type, exc_val, exc_tb) -> None:
            pass

        def step(self) -> None:
            pass

finally:
    # Undo the `sys.path` change whatever the import outcome. Remove by value:
    # the import may itself have touched `sys.path`, so index 0 is not
    # guaranteed to still be our entry.
    try:
        sys.path.remove(DCGM_PYTHON_PATH)
    except ValueError:
        pass


# Header of the next file in this diff, preserved from the collapsed source line:
# diff --git a/.venv/lib/python3.11/site-packages/xformers/profiler/profiler_dcgm_impl.py b/.venv/lib/python3.11/site-packages/xformers/profiler/profiler_dcgm_impl.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..41bb32767dedef9975ff5b1826bd25f1aacb708e
# --- /dev/null
# +++ b/.venv/lib/python3.11/site-packages/xformers/profiler/profiler_dcgm_impl.py
# @@ -0,0 +1,216 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional, Set, Tuple, Union

import dcgm_fields
import torch
from dcgm_fields import DcgmFieldGetById
from dcgm_structs import DCGM_GROUP_EMPTY, DCGM_OPERATION_MODE_AUTO
from pydcgm import DcgmFieldGroup, DcgmGroup, DcgmHandle

from .profiler import _Profiler, logger


class DCGMProfiler:
    """Profiler that triggers start of DCGM profiler."""

    def __init__(
        self,
        main_profiler: "_Profiler",
        gpus_to_profile: Optional[Tuple[int, ...]] = None,
        field_ids_to_profile=(
            dcgm_fields.DCGM_FI_PROF_SM_ACTIVE,
            dcgm_fields.DCGM_FI_PROF_SM_OCCUPANCY,
            dcgm_fields.DCGM_FI_PROF_PIPE_TENSOR_ACTIVE,
            dcgm_fields.DCGM_FI_PROF_DRAM_ACTIVE,
            dcgm_fields.DCGM_FI_PROF_PCIE_TX_BYTES,
            dcgm_fields.DCGM_FI_PROF_PCIE_RX_BYTES,
            dcgm_fields.DCGM_FI_PROF_NVLINK_TX_BYTES,
            dcgm_fields.DCGM_FI_PROF_NVLINK_RX_BYTES,
        ),
        updateFreq: int = 5000,  # in microseconds
    ) -> None:
        """
        Args:
            main_profiler: The main profiler object.
            gpus_to_profile: A tuple of integers representing the GPUs to profile. If `None`,
                then the "default" GPU is used.
            field_ids_to_profile:
                See https://github.com/NVIDIA/DCGM/blob/master/testing/python3/dcgm_fields.py#L436
                for a full list of available fields. Note that not all fields are profilable.
            updateFreq: The interval of two consecutive updates of each field. Defaults to 5000 microseconds.
                This is a good tradeoff between performance and accuracy.
                An even smaller updateFreq is not supported well by A100.
                If the step to profile takes more than 5000 microseconds, then a larger updateFreq could also be used.
        """
        self.main_profiler = main_profiler
        self.updateFreq = updateFreq

        self.dcgmHandle = DcgmHandle(
            ipAddress="127.0.0.1", opMode=DCGM_OPERATION_MODE_AUTO
        )

        if gpus_to_profile is None:
            # Use the index of the device on which a default CUDA tensor is allocated.
            default_gpu: int = torch.empty([], device="cuda").device.index
            self.dcgmGroup = self.create_dcgm_group((default_gpu,))
        else:
            self.dcgmGroup = self.create_dcgm_group(gpus_to_profile)

        self.dcgmFieldGroup = self.create_profiling_field_group(field_ids_to_profile)

    def create_dcgm_group(
        self, gpus_to_profile: Union[Tuple[int], Tuple[int, ...]]
    ) -> Optional[DcgmGroup]:
        """Create a DCGM group holding the requested GPUs that DCGM supports.

        Returns `None` (after logging a warning) when there is no handle or
        when none of the requested GPUs is supported.
        """
        if self.dcgmHandle is None:
            return None

        dcgmSystem = self.dcgmHandle.GetSystem()
        supportedGPUs = dcgmSystem.discovery.GetAllSupportedGpuIds()

        valid_gpus_to_profile: List[int] = [
            gpu for gpu in gpus_to_profile if gpu in supportedGPUs
        ]
        if len(valid_gpus_to_profile) < 1:
            logger.warning(
                f"The provided GPUs are not supported on this system: "
                f"provided {gpus_to_profile}, supported {supportedGPUs}. "
                f"No data will be captured."
            )
            return None

        dcgmGroup = DcgmGroup(
            self.dcgmHandle,
            groupName="DCGMProfiler",
            groupType=DCGM_GROUP_EMPTY,
        )

        for gpu in valid_gpus_to_profile:
            dcgmGroup.AddGpu(gpu)

        return dcgmGroup

    def get_profilable_fields(self) -> Set[int]:
        """Return every field id listed in the group's supported metric groups."""
        assert self.dcgmGroup is not None

        dcgmMetricGroups = self.dcgmGroup.profiling.GetSupportedMetricGroups()
        profilableFieldIds = set()
        for group_idx in range(dcgmMetricGroups.numMetricGroups):
            metric_group = dcgmMetricGroups.metricGroups[group_idx]
            # Only the first `numFieldIds` entries of `fieldIds` are read.
            for field_id in metric_group.fieldIds[: metric_group.numFieldIds]:
                profilableFieldIds.add(field_id)
        return profilableFieldIds

    def create_profiling_field_group(
        self,
        fieldIdsToProfile: Optional[Tuple[int, ...]],
    ) -> Optional[DcgmFieldGroup]:
        """Create the field group to watch.

        `None` as `fieldIdsToProfile` selects every profilable field. Returns
        `None` when there is no GPU group or no requested field is profilable.
        """
        if self.dcgmGroup is None:
            return None

        # Get all field ids that can be profiled.
        profilableFieldIds = self.get_profilable_fields()

        # Check which of the provided field ids are valid and invalid.
        if fieldIdsToProfile is None:
            validFieldIds = list(profilableFieldIds)
            invalidFieldIds = []
        else:
            validFieldIds = [
                field_id
                for field_id in fieldIdsToProfile
                if field_id in profilableFieldIds
            ]
            invalidFieldIds = [
                field_id
                for field_id in fieldIdsToProfile
                if field_id not in profilableFieldIds
            ]

        if not validFieldIds:
            logger.warning(
                "None of the provided field ids could be profiled.\n"
                f"  Provided: {fieldIdsToProfile}\n"
                f"  Supported: {profilableFieldIds}\n"
                "No data will be captured."
            )
            return None

        if invalidFieldIds:
            logger.warning(
                f"The following field ids cannot be profiled: {invalidFieldIds}. "
                f"Profiling {validFieldIds} only."
            )
        dcgmFieldGroup = DcgmFieldGroup(
            self.dcgmHandle, name="Profiling", fieldIds=validFieldIds
        )
        return dcgmFieldGroup

    def __enter__(self) -> None:
        if self.dcgmGroup is not None and self.dcgmFieldGroup is not None:
            self.dcgmGroup.samples.WatchFields(
                self.dcgmFieldGroup, self.updateFreq, 3600, 0
            )

            # Start collecting the profiling results run in background.
            self.profiling_results = self.dcgmGroup.samples.GetAllSinceLastCall(
                None, self.dcgmFieldGroup
            )

            # It is necessary to call GetAllSinceLastCall and EmptyValues twice
            # to clear old data from previous profilings
            # (otherwise the new profiling data is appended to the old data from previous profiling).
            self.dcgmGroup.samples.GetAllSinceLastCall(
                self.profiling_results, self.dcgmFieldGroup
            )
            self.profiling_results.EmptyValues()

            self.dcgmGroup.samples.GetAllSinceLastCall(
                self.profiling_results, self.dcgmFieldGroup
            )
            self.profiling_results.EmptyValues()

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        if self.dcgmGroup is not None:
            # Fields are only watched when a field group was created.
            if self.dcgmFieldGroup is not None:
                self.dcgmGroup.samples.UnwatchFields(self.dcgmFieldGroup)

            # Delete the group, even when no field group could be created, so
            # the group created in `__init__` is not left behind.
            self.dcgmGroup.Delete()
            self.dcgmGroup = None

        # Disconnect from the hostengine by dropping our reference to the
        # DcgmHandle object.
        self.dcgmHandle = None

    def step(self) -> None:
        """Collect samples since the last call and append per-GPU, per-field
        averages to the main profiler summary."""
        if self.dcgmGroup is not None and self.dcgmFieldGroup is not None:
            # Collect the profiling results.
            self.dcgmGroup.samples.GetAllSinceLastCall(
                self.profiling_results, self.dcgmFieldGroup
            )

            # Save profiling results to log.
            for gpu_id in self.profiling_results.values.keys():
                for field_id in self.profiling_results.values[gpu_id].keys():
                    field_name = DcgmFieldGetById(field_id).tag

                    # Running mean over the samples that carry a value.
                    field_avg_val = 0.0
                    num_vals = 0
                    for gpu_field_time in self.profiling_results.values[gpu_id][
                        field_id
                    ]:
                        if gpu_field_time.value is not None:
                            field_avg_val = (
                                field_avg_val * num_vals + gpu_field_time.value
                            ) / (num_vals + 1)
                            num_vals += 1
                    self.main_profiler.summary.append(
                        (f"GPU {gpu_id}, {field_name}({field_id})", f"{field_avg_val}")
                    )

            # Clear the profiling results to get ready for the next collection.
            self.profiling_results.EmptyValues()