koichi12 committed
Commit c1012a5 · verified · 1 parent: 1e8ae5e

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full list.
Files changed (50)
  1. .venv/lib/python3.11/site-packages/xformers/components/__init__.py +86 -0
  2. .venv/lib/python3.11/site-packages/xformers/components/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/xformers/components/__pycache__/activations.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/xformers/components/__pycache__/input_projection.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/xformers/components/__pycache__/multi_head_dispatch.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/xformers/components/__pycache__/patch_embedding.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/xformers/components/__pycache__/residual.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/xformers/components/__pycache__/reversible.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/xformers/components/__pycache__/simplicial_embedding.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/xformers/components/activations.py +76 -0
  11. .venv/lib/python3.11/site-packages/xformers/components/attention/attention_patterns.py +295 -0
  12. .venv/lib/python3.11/site-packages/xformers/components/attention/core.py +248 -0
  13. .venv/lib/python3.11/site-packages/xformers/components/attention/favor.py +173 -0
  14. .venv/lib/python3.11/site-packages/xformers/components/attention/fourier_mix.py +35 -0
  15. .venv/lib/python3.11/site-packages/xformers/components/attention/lambda_layer.py +78 -0
  16. .venv/lib/python3.11/site-packages/xformers/components/attention/local.py +120 -0
  17. .venv/lib/python3.11/site-packages/xformers/components/attention/nystrom.py +295 -0
  18. .venv/lib/python3.11/site-packages/xformers/components/attention/random.py +126 -0
  19. .venv/lib/python3.11/site-packages/xformers/components/attention/scaled_dot_product.py +134 -0
  20. .venv/lib/python3.11/site-packages/xformers/components/attention/visual.py +96 -0
  21. .venv/lib/python3.11/site-packages/xformers/components/input_projection.py +102 -0
  22. .venv/lib/python3.11/site-packages/xformers/components/multi_head_dispatch.py +271 -0
  23. .venv/lib/python3.11/site-packages/xformers/components/patch_embedding.py +83 -0
  24. .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__init__.py +87 -0
  25. .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/__init__.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/base.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/param.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/rotary.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/sine.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/vocab.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/base.py +38 -0
  32. .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/param.py +54 -0
  33. .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/rotary.py +91 -0
  34. .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/sine.py +46 -0
  35. .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/vocab.py +65 -0
  36. .venv/lib/python3.11/site-packages/xformers/components/residual.py +192 -0
  37. .venv/lib/python3.11/site-packages/xformers/components/reversible.py +160 -0
  38. .venv/lib/python3.11/site-packages/xformers/components/simplicial_embedding.py +67 -0
  39. .venv/lib/python3.11/site-packages/xformers/ops/__init__.py +130 -0
  40. .venv/lib/python3.11/site-packages/xformers/ops/__pycache__/__init__.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/xformers/ops/_triton/k_index_select_cat.py +184 -0
  42. .venv/lib/python3.11/site-packages/xformers/ops/_triton/k_scaled_index_add.py +365 -0
  43. .venv/lib/python3.11/site-packages/xformers/ops/_triton/rmsnorm_kernels.py +163 -0
  44. .venv/lib/python3.11/site-packages/xformers/ops/_triton/rope_padded_kernels.py +226 -0
  45. .venv/lib/python3.11/site-packages/xformers/ops/_triton/tiled_matmul_kernels.py +430 -0
  46. .venv/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py +893 -0
  47. .venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/__init__.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/attn_bias.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/ck.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/ck_decoder.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/xformers/components/__init__.py ADDED
@@ -0,0 +1,86 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ import warnings
+ from dataclasses import fields
+ from pathlib import Path
+ from typing import Any, Dict, Union
+
+ from xformers.utils import import_all_modules
+
+ from .activations import Activation, build_activation  # noqa
+ from .attention import Attention, build_attention  # noqa
+ from .input_projection import InputProjection, InputProjectionConfig  # noqa
+ from .multi_head_dispatch import MultiHeadDispatch  # noqa
+ from .multi_head_dispatch import MultiHeadDispatchConfig
+ from .patch_embedding import PatchEmbeddingConfig  # noqa
+ from .patch_embedding import build_patch_embedding  # noqa
+ from .residual import NormalizationType  # noqa
+ from .residual import PostNorm  # noqa
+ from .residual import PreNorm  # noqa
+ from .residual import RequiresWrappedInputs  # noqa
+ from .residual import Residual  # noqa
+ from .residual import ResidualNormStyle  # noqa
+
+ warnings.warn(
+     "xformers.components is deprecated and is not maintained anymore. "
+     "It might be removed in a future version of xFormers ",
+     FutureWarning,
+     stacklevel=2,
+ )
+
+
+ # automatically import any Python files in the directory
+ import_all_modules(str(Path(__file__).parent), "xformers.components")
+
+
+ def build_multi_head_attention(
+     multi_head_config: Union[MultiHeadDispatchConfig, Dict[str, Any]],
+ ):
+     """Builds a multihead attention from a config.
+
+     This assumes a 'name' key in the config which is used to determine what
+     attention class to instantiate. For instance, a config `{"name": "my_attention",
+     "foo": "bar"}` will find a class that was registered as "my_attention"
+     (see :func:`register_attention`) and call .from_config on it."""
+
+     if not isinstance(multi_head_config, MultiHeadDispatchConfig):
+         # Extract the required fields
+         field_names = list(map(lambda x: x.name, fields(MultiHeadDispatchConfig)))
+
+         # The missing fields get Noned
+         for k in field_names:
+             if k not in multi_head_config.keys():
+                 multi_head_config[k] = None
+
+         # Could be that the attention needs to be instantiated
+         if not isinstance(multi_head_config["attention"], Attention):
+             # Convenience: fill in possible missing fields
+             if "num_heads" not in multi_head_config["attention"]:
+                 multi_head_config["attention"]["num_heads"] = multi_head_config[
+                     "num_heads"
+                 ]
+
+             if "dim_model" not in multi_head_config["attention"]:
+                 multi_head_config["attention"]["dim_model"] = multi_head_config[
+                     "dim_model"
+                 ]
+
+             if (
+                 "dim_features" not in multi_head_config["attention"]
+                 or multi_head_config["attention"]["dim_features"] is None
+             ):
+                 multi_head_config["attention"]["dim_features"] = (
+                     multi_head_config["dim_model"] // multi_head_config["num_heads"]
+                 )
+
+             multi_head_config["attention"] = build_attention(
+                 multi_head_config["attention"]
+             )
+
+         multi_head_config = MultiHeadDispatchConfig(**multi_head_config)
+
+     return MultiHeadDispatch.from_config(multi_head_config)
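As a quick sanity check of the builder above (not part of the commit): a minimal sketch of calling build_multi_head_attention with a plain dict. The "scaled_dot_product" name and the sizes are illustrative; the field names follow MultiHeadDispatchConfig, and residual_dropout is an assumed field beyond what this diff shows.

import torch
from xformers.components import build_multi_head_attention

config = {
    "dim_model": 384,       # model width; propagated into the attention sub-config
    "num_heads": 6,         # also propagated, see the convenience fill-ins above
    "residual_dropout": 0.0,  # assumed MultiHeadDispatchConfig field
    "attention": {
        "name": "scaled_dot_product",  # a registered attention name
        "dropout": 0.1,
        "causal": False,
    },
}
mha = build_multi_head_attention(config)  # a MultiHeadDispatch instance
y = mha(torch.randn(2, 16, 384))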
.venv/lib/python3.11/site-packages/xformers/components/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (3.57 kB).
.venv/lib/python3.11/site-packages/xformers/components/__pycache__/activations.cpython-311.pyc ADDED
Binary file (4.59 kB).
.venv/lib/python3.11/site-packages/xformers/components/__pycache__/input_projection.cpython-311.pyc ADDED
Binary file (3.97 kB).
.venv/lib/python3.11/site-packages/xformers/components/__pycache__/multi_head_dispatch.cpython-311.pyc ADDED
Binary file (12.6 kB).
.venv/lib/python3.11/site-packages/xformers/components/__pycache__/patch_embedding.cpython-311.pyc ADDED
Binary file (4.53 kB).
.venv/lib/python3.11/site-packages/xformers/components/__pycache__/residual.cpython-311.pyc ADDED
Binary file (9.56 kB).
.venv/lib/python3.11/site-packages/xformers/components/__pycache__/reversible.cpython-311.pyc ADDED
Binary file (9.78 kB).
.venv/lib/python3.11/site-packages/xformers/components/__pycache__/simplicial_embedding.cpython-311.pyc ADDED
Binary file (3.5 kB).
.venv/lib/python3.11/site-packages/xformers/components/activations.py ADDED
@@ -0,0 +1,76 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ from enum import Enum
+ from typing import Optional
+
+ import torch
+ from torch import nn
+
+ from xformers._deprecation_warning import deprecated_function
+
+
+ class Activation(str, Enum):
+     SquaredReLU = "squared_relu"
+     GeLU = "gelu"
+     LeakyReLU = "leaky_relu"
+     ReLU = "relu"
+     SmeLU = "smelu"
+     StarReLU = "star_relu"
+
+
+ # For unit testing / parity comparisons, probably not the fastest way
+ class SquaredReLU(nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+         deprecated_function(self)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x_ = torch.nn.functional.relu(x)
+         return x_ * x_
+
+
+ class StarReLU(nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+         deprecated_function(self)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x_ = torch.nn.functional.relu(x)
+         return 0.8944 * x_ * x_ - 0.4472
+
+
+ class SmeLU(nn.Module):
+     def __init__(self, beta: float = 2.0) -> None:
+         super().__init__()
+         self.beta = beta
+         deprecated_function(self)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         relu = torch.where(
+             x >= self.beta,
+             x,
+             torch.tensor([0.0], device=x.device, dtype=x.dtype),
+         )
+         return torch.where(
+             torch.abs(x) <= self.beta,
+             ((x + self.beta) ** 2).type_as(x) / (4.0 * self.beta),
+             relu,
+         )
+
+
+ def build_activation(activation: Optional[Activation]):
+     if not activation:
+         return nn.Identity()
+
+     return {
+         Activation.ReLU: nn.ReLU,
+         Activation.GeLU: nn.GELU,
+         Activation.LeakyReLU: nn.LeakyReLU,
+         Activation.SquaredReLU: SquaredReLU,
+         Activation.StarReLU: StarReLU,
+         Activation.SmeLU: SmeLU,
+     }[activation]()
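A quick sketch of how build_activation maps the enum onto a module (not part of the commit; the input sizes are illustrative):

import torch
from xformers.components.activations import Activation, build_activation

act = build_activation(Activation.SquaredReLU)
y = act(torch.randn(2, 4))         # relu(x) ** 2, element-wise
identity = build_activation(None)  # falsy input returns nn.Identity()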
.venv/lib/python3.11/site-packages/xformers/components/attention/attention_patterns.py ADDED
@@ -0,0 +1,295 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ import math
+ from typing import List
+
+ import numpy as np
+ import torch
+
+ from xformers.components.attention.sparsity_config import (
+     BigBirdSparsityConfig,
+     BSLongformerSparsityConfig,
+     FixedSparsityConfig,
+     VariableSparsityConfig,
+ )
+
+
+ # generic nd cases
+ def _generate_nd_grid(*sizes):
+     coords = [torch.arange(s) for s in sizes]
+     return torch.meshgrid(*coords)
+
+
+ def local_nd_distance(*sizes, p=2.0, weights=None):
+     if weights is None:
+         weights = (1,) * len(sizes)
+     assert len(sizes) == len(weights)
+     grid = _generate_nd_grid(*sizes)
+     grid = [i.flatten() * w for i, w in zip(grid, weights)]
+     grid = torch.stack(grid, dim=1).float()
+     d = torch.cdist(grid, grid, p=p)
+     return d
+
+
+ def local_nd_gaussian_distribution(*sizes, sigma=1):
+     d = local_nd_distance(*sizes, p=2.0) ** 2
+     d = torch.exp(-0.5 * sigma ** (-2.0) * d)
+     return d
+
+
+ def local_nd_pattern(*sizes, distance, p=2.0):
+     d = local_nd_distance(*sizes, p=p)
+     return d < distance
+
+
+ def axial_nd_pattern(*sizes):
+     # axial is a special case with p=0 and distance=2
+     d = local_nd_distance(*sizes, p=0)
+     return d < 2
+
+
+ def random_pattern_from_probability_matrix(dist_matrix, nnz):
+     att = torch.zeros_like(dist_matrix, dtype=torch.bool)
+     # PyTorch multinomial wrongly doesn't support sampling when number of categories
+     # is > 2^24, arguing that it's because it's the max representable consecutive element
+     # in fp32 and that the kernels use float32. This is actually not true, and the kernels
+     # should work fine if double tensor is passed on CPU. This is a bug that was introduced
+     # in https://github.com/pytorch/pytorch/commit/bf04c2ca2f591d98ce57816f0ef0cd20a21bbf66
+     # when unifying the checks between CPU and CUDA. For now, just fall-back to numpy
+     if dist_matrix.numel() > 2**24:
+         dist_matrix = dist_matrix.double()
+         dist_matrix /= dist_matrix.sum()
+         idxs = np.random.choice(
+             dist_matrix.numel(), nnz, p=dist_matrix.flatten(), replace=False
+         )
+         idxs = torch.as_tensor(idxs)
+     else:
+         idxs = torch.multinomial(dist_matrix.flatten(), nnz, replacement=False)
+     att.view(-1)[idxs] = True
+     return att
+
+
+ def global_token_pattern(attention_query_mask: torch.Tensor) -> torch.Tensor:
+     assert attention_query_mask.ndim == 1
+     assert attention_query_mask.dtype == torch.bool
+     attention_query_mask = attention_query_mask[None, :]
+     mask = attention_query_mask | attention_query_mask.transpose(1, 0)
+     return mask
+
+
+ def random_pattern(attn_size: int, sparsity: float) -> torch.Tensor:
+     assert 0 < sparsity < 1
+     mask = torch.rand(attn_size, attn_size) > sparsity
+     return mask
+
+
+ # 1d-specific cases
+ def local_1d_pattern(attn_size: int, window_size: int) -> torch.Tensor:
+     assert (
+         window_size % 2 == 1
+     ), "The window size is assumed to be odd (counts self-attention + 2 wings)"
+     h_win_size = window_size // 2 + 1
+     return local_nd_pattern(attn_size, distance=h_win_size, p=1.0)
+
+
+ def causal_1d_pattern(attn_size: int) -> torch.Tensor:
+     mask = torch.tril(torch.ones(attn_size, attn_size, dtype=torch.bool))
+     return mask
+
+
+ # 2d-specific cases
+ def horizontal_axial_2d_distance(H, W, p=2.0):
+     d = local_nd_distance(H, W, p=p, weights=(1, 0))
+     return d
+
+
+ def vertical_axial_2d_distance(H, W, p=2.0):
+     d = local_nd_distance(H, W, p=p, weights=(0, 1))
+     return d
+
+
+ def local_2d_distance(H, W, p=2.0):
+     return local_nd_distance(H, W, p=p)
+
+
+ def local_2d_gausian_distribution(H, W, sigma=1):
+     return local_nd_gaussian_distribution(H, W, sigma=sigma)
+
+
+ def local_2d_pattern(H, W, distance, p=2.0):
+     return local_nd_pattern(H, W, distance=distance, p=p)
+
+
+ def axial_2d_pattern(H, W):
+     return axial_nd_pattern(H, W)
+
+
+ def swin_attention_pattern(H, W, window_size, shift_size=0):
+     assert H % window_size == 0
+     assert W % window_size == 0
+     assert 0 <= shift_size < window_size, "shift_size must in 0-window_size"
+
+     # input grid
+     i, j = _generate_nd_grid(H, W)
+     i, j = i + 0.5, j + 0.5
+
+     # anchors grid
+     # if shift is present, add extra element to the grid
+     # to account for the uneven partitioning
+     extra = int(shift_size % window_size != 0)
+     grid_h = H // window_size + extra
+     grid_w = W // window_size + extra
+
+     ii, jj = _generate_nd_grid(grid_h, grid_w)
+     # convert shift to be compatible with the paper representation
+     s = (-shift_size) % window_size
+     offset = window_size / 2 - s
+     ii = ii * window_size + offset
+     jj = jj * window_size + offset
+
+     input_coords = torch.stack([i.flatten(), j.flatten()], 1).float()
+     anchors_coords = torch.stack([ii.flatten(), jj.flatten()], 1).float()
+
+     anchor_id = torch.cdist(input_coords, anchors_coords, p=2).argmin(1)
+     mask = anchor_id[:, None] == anchor_id[None, :]
+     return mask
+
+
+ def dilated_2d_pattern(H, W, k=2):
+     """
+     Returns a 2d pattern that samples 1 every k elements in the attention mask.
+     Can be seen as a form of downsampling, where every pixel attends to a downsampled
+     version of the input.
+     """
+     d_h = local_nd_distance(H, W, p=1, weights=(1, 0))
+     d_w = local_nd_distance(H, W, p=1, weights=(0, 1))
+     d = (d_h.floor() % k == 0) & (d_w.floor() % k == 0)
+     return d
+
+
+ # Block sparse utils
+ def block_sparsify_tensor(x, mask, block_size):
+     """
+     Block sparsify a tensor, given a mask and block size
+     """
+     ret = torch.empty(
+         (x.size(0), mask.sum(), block_size, block_size), dtype=x.dtype, device=x.device
+     )
+
+     for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
+         ret[:, idx, :, :] = x[
+             :,
+             h,
+             i * block_size : (i + 1) * block_size,
+             j * block_size : (j + 1) * block_size,
+         ]
+     return ret
+
+
+ def pattern_to_layout(mask: torch.Tensor, block_size: int) -> torch.Tensor:
+     r"""
+     Given a mask pattern and blocksize, return the corresponding layout
+     which makes sure that all the positives in the mask are covered
+     """
+     assert mask.ndim >= 2, "We're expecting [Heads, Seq, Seq] or [Seq, Seq]"
+     _should_squeeze = False
+
+     if mask.ndim == 2:
+         mask = mask.unsqueeze(0)
+         _should_squeeze = True
+
+     assert (
+         mask.shape[1] % block_size == 0 and mask.shape[2] % block_size == 0
+     ), "We're only handling masks divisible by block_size"
+
+     # Now mark the mask
+     layout = torch.nn.functional.max_pool2d(
+         mask.to(torch.float), kernel_size=block_size, stride=block_size
+     )
+     layout = layout.to(torch.long)
+
+     if _should_squeeze:
+         layout.squeeze_(0)
+
+     return layout
+
+
+ def alibi_pattern(threshold: float, mask_shape: torch.Size) -> torch.Tensor:
+     r"""
+     Use the additive bias computation from ALiBi_ to generate a mask.
+     Note that this mask can in turn be used to generate a blocksparse attention computation layout
+
+     .. note: mask_shape is expected to hold the [heads, seq, seq] dimensions
+
+     .. _ALiBi: https://arxiv.org/pdf/2108.12409.pdf
+     """
+
+     # CREDITS: code snippet from Ofir Press, one of the authors
+
+     def get_slopes(n: int):
+         def get_slopes_power_of_2(n: int) -> List[float]:
+             start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+             ratio = start
+             return [start * ratio**i for i in range(n)]
+
+         # In the paper, we only train models that have 2^a heads for some a. This function has
+         # some good properties that only occur when the input is a power of 2. To maintain that even
+         # when the number of heads is not a power of 2, we use this workaround.
+         if math.log2(n).is_integer():
+             return get_slopes_power_of_2(n)
+         else:
+             closest_power_of_2 = 2 ** math.floor(math.log2(n))
+             return (
+                 get_slopes_power_of_2(closest_power_of_2)
+                 + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+             )
+
+     maxpos = mask_shape[1]
+     attn_heads = mask_shape[0]
+     slopes = torch.Tensor(get_slopes(attn_heads))
+
+     # In the next line, the part after the * is what constructs the diagonal matrix
+     # (right matrix in Figure 3 in the paper).
+     # If you run it you'll see that it doesn't exactly print out the same matrix as we have in Figure 3,
+     # but one where all rows are identical.
+     # This works because the softmax operation is invariant to translation,
+     # and our bias functions are always linear.
+     alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(maxpos).unsqueeze(
+         0
+     ).unsqueeze(0).expand(attn_heads, -1, -1)
+     alibi = alibi.view(attn_heads, 1, maxpos)
+
+     # Now threshold arbitrarily, report the mask
+     return alibi < threshold
+
+
+ def quick_fixed_layout(num_heads: int, block_size: int, seq_len: int):
+     config = FixedSparsityConfig(num_heads=num_heads, block_size=block_size)
+     return config.make_layout(seq_len)
+
+
+ def quick_variable_layout(num_heads: int, block_size: int, seq_len: int):
+     config = VariableSparsityConfig(num_heads=num_heads, block_size=block_size)
+     return config.make_layout(seq_len)
+
+
+ def quick_bigbird_layout(num_heads: int, block_size: int, seq_len: int):
+     config = BigBirdSparsityConfig(num_heads=num_heads, block_size=block_size)
+     return config.make_layout(seq_len)
+
+
+ def quick_bslongformer_layout(num_heads: int, block_size: int, seq_len: int):
+     config = BSLongformerSparsityConfig(num_heads=num_heads, block_size=block_size)
+     return config.make_layout(seq_len)
+
+
+ def layout_to_pattern(layout: torch.Tensor, block_size: int):
+     r"""
+     create a pattern of shape [heads, seq, seq] out of a blocksparse
+     layout of shape [heads, seq/block_size, seq/block_size]
+     """
+     return torch.kron(layout, torch.ones(block_size, block_size))
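The pattern helpers compose; a small sketch (not part of the commit, sizes illustrative) of going from a dense boolean pattern to a blocksparse layout and back:

import torch
from xformers.components.attention.attention_patterns import (
    layout_to_pattern,
    local_1d_pattern,
    pattern_to_layout,
)

mask = local_1d_pattern(attn_size=64, window_size=5)  # [64, 64] bool, 2-token wings
layout = pattern_to_layout(mask, block_size=16)       # [4, 4] long, 1 where a block has any True
coarse = layout_to_pattern(layout, block_size=16)     # [64, 64] block-granular cover of mask
assert bool(coarse.bool()[mask].all())                # every True position is covered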
.venv/lib/python3.11/site-packages/xformers/components/attention/core.py ADDED
@@ -0,0 +1,248 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ import logging
+ import math
+ from contextlib import nullcontext
+ from typing import Optional, Union
+
+ import torch
+
+ from xformers import _has_cpp_library
+ from xformers.components.attention.attention_mask import AttentionMask
+
+ if _has_cpp_library:
+     from ._sputnik_sparse import SparseCS
+
+ logger = logging.getLogger("xformers")
+
+
+ def _create_random_sparsity(matrix, sparsity, divisible_by=4):
+     assert matrix.ndim == 3
+     keep = torch.rand_like(matrix[0], dtype=torch.float32) > sparsity
+     nonzero = torch.nonzero(keep)
+     nnz = nonzero.shape[0]
+     # NOTE: need to make it a multiple of 4 for sputnik
+     nonzero = nonzero[: (nnz - nnz % divisible_by)]
+     i, j = nonzero.unbind(1)
+     output = torch.zeros_like(matrix)
+     bdim = torch.arange(matrix.shape[0], device=matrix.device)[:, None]
+     output[bdim, i, j] = matrix[bdim, i, j]
+     return output
+
+
+ def _broadcast_batch(mask, batch_size):
+     if mask.ndim == 3:
+         return mask
+     assert mask.ndim == 2
+
+     mask = mask.coalesce()
+     values = mask.values()
+     indices = mask.indices()
+     nnz = len(values)
+     # strategy: repeat the indices and append the extra batch dimension to the indices
+     indices = indices.repeat(1, batch_size)
+     # now create the batch indices
+     batch_indices = torch.arange(batch_size, device=indices.device)
+     batch_indices = batch_indices[:, None].expand(batch_size, nnz).flatten()
+
+     # put them together
+     indices = torch.cat([batch_indices[None, :], indices], dim=0)
+
+     # now repeat the values
+     values = values.repeat(batch_size)
+
+     size = (batch_size,) + mask.shape
+
+     return torch.sparse_coo_tensor(indices, values, size)
+
+
+ def _matmul_with_mask(
+     a: torch.Tensor,
+     b: torch.Tensor,
+     mask: Optional[Union[torch.Tensor, "SparseCS"]],
+ ) -> torch.Tensor:
+     if mask is None:
+         return a @ b
+
+     if _has_cpp_library and mask.dtype == torch.bool:
+         if isinstance(mask, SparseCS):
+             return mask.matmul_with_mask(a, b)
+         if mask.is_sparse:
+             # perform broadcasting if needed
+             mask = _broadcast_batch(mask, a.shape[0])
+
+             # coalesced is not implemented for bool tensors, so need to cast
+             mask = mask.to(dtype=a.dtype)  # type: ignore  # mypy is missing the catch above
+
+         return torch.ops.xformers.matmul_with_mask(a, b, mask)
+
+     # Non optimized codepath
+     if _has_cpp_library:
+         assert not isinstance(mask, SparseCS)
+
+     att = a @ b
+     if mask.dtype == torch.bool:
+         assert not isinstance(mask, SparseCS)
+         if mask.ndim == 2:
+             mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
+         # mask is presumed false == ignore
+         att[~mask] = float("-inf")
+     else:
+         # mask is presumed additive
+         # repeat if batch sizes don't match
+         if (
+             not isinstance(mask, SparseCS)
+             and mask.ndim == 3
+             and mask.shape[0] != att.shape[0]
+             and (att.shape[0] % mask.shape[0]) == 0
+         ):
+             repeat_factor = att.shape[0] // mask.shape[0]
+             mask = mask.repeat([repeat_factor, 1, 1])
+             logger.info("Mismatched batch dimensions for mask, repeating mask.")
+         att += mask
+     return att
+
+
+ def _softmax(a: torch.Tensor, causal: bool = False) -> torch.Tensor:
+     if _has_cpp_library and isinstance(a, SparseCS):
+         return a.softmax()
+
+     if a.is_sparse:
+         return torch.sparse.softmax(a, dim=a.ndim - 1)
+
+     return torch.softmax(a, dim=a.ndim - 1)
+
+
+ if _has_cpp_library:
+
+     class SparseBMM(torch.autograd.Function):
+         @staticmethod
+         def forward(ctx, a, b):
+             a = a.coalesce()
+             r = torch.bmm(a, b)
+             ctx.save_for_backward(a, b)
+             return r
+
+         @staticmethod
+         def backward(ctx, grad):
+             a, b = ctx.saved_tensors
+
+             # gradients w.r.t. a
+             ga = None
+             if ctx.needs_input_grad[0]:
+                 ga = torch.ops.xformers.matmul_with_mask(grad, b.transpose(-2, -1), a)
+
+             # gradients w.r.t. b
+             gb = None
+             if ctx.needs_input_grad[1]:
+                 gb = a.transpose(1, 2).bmm(grad)
+
+             return ga, gb
+
+     def _sparse_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+         """
+         Batch matrix multiply between a sparse matrix and a dense matrix
+         """
+         assert a.ndim == b.ndim == 3
+         assert a.shape[0] == b.shape[0]
+         assert a.shape[2] == b.shape[1]
+         return SparseBMM.apply(a, b)
+
+
+ def bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+     if _has_cpp_library:
+         if isinstance(a, SparseCS):
+             return a.spmm(b)
+         if a.is_sparse:
+             return _sparse_bmm(a, b)
+     return a @ b
+
+
+ def _apply_dropout(att, dropout):
+     if dropout is None:
+         return att
+
+     # Dropout chokes on sparse tensors
+     if _has_cpp_library:
+         if isinstance(att, SparseCS):
+             values = att.values.clone()
+             values = dropout(values)
+             att = SparseCS.wrap(
+                 att.shape,
+                 values,
+                 att.row_indices,
+                 att.row_offsets,
+                 att.column_indices,
+                 att._transp_info,
+             )
+         elif att.is_sparse:
+             att = att.coalesce()
+             values = att.values().clone()  # protect against in-place dropout
+             values = dropout(values)
+             att = torch.sparse_coo_tensor(att.indices(), values, att.shape)
+         else:
+             # Simple dense case
+             att = dropout(att)
+
+         return att
+
+     # Non optimized vanilla dropout
+     att = dropout(att)
+     return att
+
+
+ def scaled_query_key_softmax(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]],
+ ) -> torch.Tensor:
+     # TODO assume we have (N, S, hs) instead of (B, nh, S, hs), with N = B x nh
+     # this is needed due to limitations in sparse_bmm for now
+
+     # Self-attend: (N, S, hs) x (N, hs, S) -> (N, S, S)
+     q = q / math.sqrt(k.size(-1))
+
+     # Matmul with mask
+     if att_mask is not None and isinstance(att_mask, AttentionMask):
+         # Additive mask
+         mask: Optional[Union[SparseCS, torch.Tensor]] = att_mask.values
+     else:
+         mask = att_mask
+
+     att = _matmul_with_mask(q, k.transpose(-2, -1), mask)
+
+     # Softmax to get the attention probabilities
+     is_causal = isinstance(att_mask, AttentionMask) and att_mask.is_causal
+     att = _softmax(att, causal=is_causal)
+     return att
+
+
+ def scaled_dot_product_attention(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]],
+     dropout: Optional[torch.nn.Module] = None,
+ ) -> torch.Tensor:
+     autocast_disabled = (
+         _has_cpp_library
+         and isinstance(att_mask, SparseCS)
+         or (att_mask is not None and att_mask.is_sparse)
+     )
+     with torch.amp.autocast("cuda", enabled=False) if autocast_disabled else nullcontext():  # type: ignore
+         if autocast_disabled:
+             q, k, v = q.float(), k.float(), v.float()
+
+         att = scaled_query_key_softmax(q, k, att_mask=att_mask)
+
+         # Optional dropout, could be part of the masking in the future
+         att = _apply_dropout(att, dropout)
+
+         # Get to the predicted values, for all heads
+         # y = att @ v  # (N, S, S) x (N, S, hs) -> (N, S, hs)
+         y = bmm(att, v)
+     return y
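Putting core.py together: per the TODO above, scaled_dot_product_attention expects inputs already folded to (batch x heads, seq, head_dim). A dense-path sketch (not part of the commit; sizes illustrative):

import torch
from xformers.components.attention.core import scaled_dot_product_attention

N, S, hs = 2, 16, 32  # (batch x heads, seq, head_dim)
q, k, v = (torch.randn(N, S, hs) for _ in range(3))
y = scaled_dot_product_attention(q, k, v, att_mask=None)  # (2, 16, 32)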
.venv/lib/python3.11/site-packages/xformers/components/attention/favor.py ADDED
@@ -0,0 +1,173 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import logging
+ import math
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ from torch.amp import autocast
+
+ from xformers.components.attention import Attention, AttentionConfig, register_attention
+ from xformers.components.attention.feature_maps import (
+     FeatureMap,
+     FeatureMapType,
+     SMHyperbolic,
+     SMOrf,
+     SMReg,
+ )
+
+ logger = logging.getLogger("xformers")
+
+
+ @dataclass
+ class FavorAttentionConfig(AttentionConfig):
+     causal: Optional[bool]
+     dim_features: Optional[int] = None  # The dimensions of the random features
+     dim_head: Optional[
+         int
+     ] = None  # The embedding dimension of the inputs. Only useful to get a dim_features estimate
+     iter_before_redraw: Optional[
+         int
+     ] = None  # The number of iterations before the random features are re-drawn from scratch
+     feature_map: Optional[FeatureMapType] = None
+
+
+ @register_attention("favor", FavorAttentionConfig)
+ class FavorAttention(Attention):
+     def __init__(
+         self,
+         causal: bool = False,
+         dropout: float = 0.0,
+         dim_features: Optional[int] = None,
+         dim_head: Optional[int] = None,
+         iter_before_redraw: Optional[int] = None,
+         feature_map_type: FeatureMapType = FeatureMapType.SMReg,
+         normalize_inputs: bool = False,
+         *_,
+         **__,
+     ):
+         r"""
+         Kernelized attention, as proposed in Performers_
+         ("Rethinking attention with performers." K. Choromanski et al. (2020).).
+
+         FAVOR stands for "Fast Attention Via positive Orthogonal Random features"
+
+         Args:
+             dropout (float): the probability of an output to be randomly dropped at training time
+             dim_features (int): the dimension of the random features space
+             iter_before_redraw (int): the number of steps (forward calls) before a redraw of the features
+             feature_map_type (FeatureMapType): the type of feature map being used,
+                 for instance orthogonal random features.
+
+         .. _Performers: https://arxiv.org/pdf/2009.14794v1.pdf
+         """
+         super().__init__()
+
+         self.causal = causal
+         self.iter_before_redraw = (
+             (2 * iter_before_redraw)
+             if iter_before_redraw is not None
+             else iter_before_redraw
+         )  # This will be used for both key and query
+         self.normalize_inputs = normalize_inputs
+         self.feature_map_type = feature_map_type
+         self.attn_drop = nn.Dropout(dropout, inplace=True)
+
+         # Setup dimension-dependent variables
+         # Reasonable dimension default
+         if dim_features is None:
+             assert dim_head is not None, "dim_features or dim_head needs to be passed"
+             self.dim_features = math.ceil(dim_head * (1 + math.log2(dim_head)))
+             self.dim_features = 2 * (
+                 self.dim_features // 2
+             )  # needs to be even for some variants
+             logger.info(
+                 f"FAVOR: Automatically setting the random mapping dimension to {self.dim_features} from {dim_head}"
+             )
+         else:
+             self.dim_features = dim_features
+
+         feature_map_constructor = {
+             FeatureMapType.SMHyp: SMHyperbolic,
+             FeatureMapType.SMReg: SMReg,
+             FeatureMapType.SMOrf: SMOrf,
+         }[self.feature_map_type]
+
+         feature_settings = {
+             "dim_features": self.dim_features,
+             "iter_before_redraw": self.iter_before_redraw,
+             "normalize_inputs": self.normalize_inputs,
+         }
+
+         self.feature_map: FeatureMap = feature_map_constructor(**feature_settings)  # type: ignore
+
+         # Properties specific to this attention mechanism
+         self.supports_attention_mask = False
+         self.supports_key_padding_mask = False
+
+     @staticmethod
+     def _maybe_promote(x: torch.Tensor) -> torch.Tensor:
+         # Only promote fp16 buffers, bfloat16 would be fine for instance
+         return x.float() if x.dtype == torch.float16 else x
+
+     @staticmethod
+     def _causal_attention(
+         k_prime: torch.Tensor, q_prime: torch.Tensor, v: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Algorithm 1 in the paper
+         ref_v = torch.ones_like(v.unsqueeze(2))  # BATCH x SEQ x 1 x EMB
+         Gps = k_prime.unsqueeze(3) * v.unsqueeze(2)
+         Grenorm = k_prime.unsqueeze(3) * ref_v
+
+         # Consolidate against the feature dimension
+         att_raw = torch.einsum("bcfe,bcf->bce", Gps, q_prime)
+         att_norm = torch.einsum("bcfe,bcf->bce", Grenorm, q_prime)
+
+         # Cumulative sum over the sequence
+         att_raw = att_raw.cumsum(2)
+         att_norm = att_norm.cumsum(2)
+
+         return att_raw, att_norm
+
+     def forward(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         *_,
+         **__,
+     ):
+
+         # Project key and queries onto the feature map space
+         k_prime = self.feature_map(k)
+         q_prime = self.feature_map(q)
+
+         with autocast("cuda", enabled=False):
+             # The softmax kernel approximation for Favor will easily overflow
+             # Force the computations here to stay in fp32 for numerical stability
+             # Note that the dimensions are vastly reduced when compared to scaled_dot_product
+             k_prime = self._maybe_promote(k_prime)
+             q_prime = self._maybe_promote(q_prime)
+             v = self._maybe_promote(v)
+
+             if not self.causal:
+                 att_normalization = q_prime @ (
+                     k_prime.transpose(-2, -1) @ torch.ones_like(v)
+                 )
+                 att_raw = q_prime @ (k_prime.transpose(-2, -1) @ v)
+             else:
+                 # Actually compute attention
+                 att_raw, att_normalization = self._causal_attention(k_prime, q_prime, v)
+
+             # Normalize
+             att = att_raw / att_normalization
+
+         if self.attn_drop is not None:
+             att = self.attn_drop(att)
+
+         return att
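A sketch of FavorAttention in isolation (not part of the commit; shapes illustrative). With dim_features unset, dim_head drives the automatic feature dimension computed in __init__; the feature map is drawn internally, so outputs are stochastic across redraws:

import torch
from xformers.components.attention.favor import FavorAttention

attn = FavorAttention(causal=False, dim_head=32)  # dim_features derived from dim_head
q = k = v = torch.randn(2, 16, 32)                # (batch x heads, seq, head_dim)
out = attn(q, k, v)                               # same shape as v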
.venv/lib/python3.11/site-packages/xformers/components/attention/fourier_mix.py ADDED
@@ -0,0 +1,35 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ from torch.amp import autocast
+
+ from xformers.components.attention import Attention, AttentionConfig, register_attention
+
+
+ @register_attention("fourier_mix", AttentionConfig)
+ class FourierMix(Attention):
+     def __init__(self, dropout: float, *_, **__):
+         """
+         FFT-based pseudo-attention mechanism, from
+         "
+         "FNet: Mixing Tokens with Fourier Transforms"
+         Lee-Thorp et al., 2021, https://arxiv.org/pdf/2105.03824.pdf
+         """
+         super().__init__()
+         self.attn_drop = torch.nn.Dropout(dropout, inplace=False)
+
+         # Properties specific to this attention mechanism
+         self.supports_attention_mask = False
+         self.requires_input_projection = False
+
+     def forward(self, q: torch.Tensor, *_, **__):
+         # Guard against autocast / fp16, not supported by torch.fft.fft2
+         with autocast("cuda", enabled=False):
+             att = torch.fft.fft2(q).real
+
+         att = self.attn_drop(att)
+
+         return att
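FourierMix ignores k and v entirely (requires_input_projection is False); token mixing is just the real part of a 2D FFT over the last two dimensions. A sketch with illustrative sizes (not part of the commit):

import torch
from xformers.components.attention.fourier_mix import FourierMix

mixer = FourierMix(dropout=0.0)
x = torch.randn(2, 16, 32)  # (batch, seq, emb)
out = mixer(x)              # torch.fft.fft2(x).real, then dropout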
.venv/lib/python3.11/site-packages/xformers/components/attention/lambda_layer.py ADDED
@@ -0,0 +1,78 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ from dataclasses import dataclass
+
+ import torch
+
+ from xformers.components.attention import Attention, AttentionConfig, register_attention
+
+
+ def calc_rel_pos(n: int):
+     # Adapted from LucidRains
+     # https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py
+     rel_pos = torch.arange(n)[None, :] - torch.arange(n)[:, None]  # [n, n]
+     rel_pos += n - 1  # shift value range from [-n+1, n-1] to [0, 2n-2]
+     return rel_pos
+
+
+ @dataclass
+ class LambdaLayerConfig(AttentionConfig):
+     seq_len: int  # dimension of the input sequence
+     dim_head: int
+
+
+ @register_attention("lambda", LambdaLayerConfig)
+ class LambdaLayer(Attention):
+     def __init__(self, dropout: float, seq_len: int, dim_head: int, *_, **__):
+         """
+         Attention approximation using Lambda layers, from
+         "Lambda networks: modeling long-range interactions without attention.", Bello, I. (2021).
+         """
+         super().__init__()
+
+         # Possible extensions:
+         # - support different dimensions for key and queries
+         # - support varying dimensions in between inputs and outputs
+         # - support u hyperparam
+
+         self.rel_pos_emb = torch.nn.Parameter(
+             torch.randn(2 * seq_len - 1, int(dim_head))
+         )
+         self.rel_pos = calc_rel_pos(seq_len)
+         self.attn_drop = torch.nn.Dropout(dropout, inplace=True)
+
+         # Properties specific to this attention mechanism
+         self.requires_same_k_q_dimensions = True
+         self.supports_attention_mask = False
+         self.supports_key_padding_mask = False
+
+     def forward(
+         self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
+     ):
+         """..NOTE: We're reusing the einsum notation suggested by the paper, changed in that
+         heads are folded in the batch dimension"""
+
+         content_lambda = torch.einsum("bnk,bnv->bkv", torch.softmax(k, dim=-1), v)
+         content_output = torch.einsum("bnk,bkv->bnv", q, content_lambda)
+
+         rel_pos_emb = self.rel_pos_emb[self.rel_pos]
+
+         # Handle real sequence length being possibly smaller
+         seq_len = q.shape[1]
+         rel_pos_emb = rel_pos_emb[:seq_len, :seq_len, :]
+
+         # Compute the position lambda for every possible combination in one go, then compute the
+         # position related contribution
+         position_lambdas = torch.einsum(
+             "mnk,bnv->bnkv", rel_pos_emb, v
+         )  # one lambda per position
+         position_output = (q.unsqueeze(2) @ position_lambdas).squeeze()
+         att = content_output + position_output
+
+         att = self.attn_drop(att)
+
+         return att
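LambdaLayer fixes the maximum sequence length at construction time via the relative positional embedding; shorter inputs are handled by the slicing in forward. A sketch (not part of the commit; sizes illustrative, heads folded into the batch dimension per the docstring):

import torch
from xformers.components.attention.lambda_layer import LambdaLayer

attn = LambdaLayer(dropout=0.0, seq_len=16, dim_head=32)
q = torch.randn(2, 16, 32)
k = torch.randn(2, 16, 32)
v = torch.randn(2, 16, 32)
out = attn(q, k, v)  # (2, 16, 32)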
.venv/lib/python3.11/site-packages/xformers/components/attention/local.py ADDED
@@ -0,0 +1,120 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ from dataclasses import dataclass
+ from typing import Optional, Union
+
+ import torch
+ import torch.nn as nn
+
+ from xformers.components.attention import (
+     Attention,
+     AttentionConfig,
+     AttentionMask,
+     maybe_sparsify,
+     register_attention,
+     sparsify,
+ )
+ from xformers.components.attention.attention_patterns import (
+     causal_1d_pattern,
+     local_1d_pattern,
+ )
+ from xformers.components.attention.core import scaled_dot_product_attention
+
+
+ @dataclass
+ class LocalAttentionConfig(AttentionConfig):
+     causal: Optional[bool] = None
+     window_size: Optional[int] = None
+     force_sparsity: Optional[bool] = None
+
+
+ @register_attention("local", LocalAttentionConfig)
+ class LocalAttention(Attention):
+     def __init__(
+         self,
+         dropout: float = 0.0,
+         causal: bool = False,
+         window_size: int = 5,
+         force_sparsity: bool = False,
+         *args,
+         **kwargs,
+     ):
+
+         r"""
+         An implementation of a sliding window attention, as proposed in RoutingTransformer_, LongFormer_ or BigBird_
+
+
+         Args:
+             dropout (float): the probability of an output to be randomly dropped at training time
+             causal (bool): apply a causal mask, in that the attention cannot be applied to the future
+             window_size (int): the overall window size for local attention.
+                 Odd number is expected if the mask is not causal, as the window size will be evenly
+                 distributed on both sides of each query
+
+
+         .. _RoutingTransformer: https://arxiv.org/pdf/2003.05997.pdf
+
+         .. _BigBird: https://arxiv.org/pdf/2007.14062.pdf
+
+         .. _Longformer: https://arxiv.org/pdf/2004.05150.pdf
+
+         """
+         super().__init__()
+
+         self.attn_drop = nn.Dropout(dropout, inplace=False)
+         self.causal = causal
+         self.force_sparsity = force_sparsity
+
+         if not self.causal:
+             assert (
+                 window_size % 2 == 1
+             ), "The window size is assumed to be odd (counts self-attention + 2 wings)"
+
+         self.window_size = window_size
+         self.attention_mask: Optional[torch.Tensor] = None
+         self.requires_same_k_q_dimensions = True
+
+         # Properties specific to this attention mechanism
+         self.supports_attention_mask = True
+         self.supports_key_padding_mask = False
+
+     def _get_local_mask(self, shape: torch.Size) -> torch.Tensor:
+         window_size = self.window_size * 2 + 1 if self.causal else self.window_size
+         mask = local_1d_pattern(shape[1], window_size)
+
+         if self.causal:
+             mask &= causal_1d_pattern(shape[1])
+
+         mask = sparsify(mask) if self.force_sparsity else maybe_sparsify(mask)
+
+         return mask
+
+     def forward(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
+         *args,
+         **kwargs,
+     ):
+         # Local window attention masking
+         if self.attention_mask is None or self.attention_mask.shape[1] != q.shape[1]:
+             self.attention_mask = self._get_local_mask(q.shape).to(q.device)
+
+         # Take into account the optional user mask
+         if att_mask is None:
+             mask = self.attention_mask
+         else:
+             if isinstance(att_mask, AttentionMask):
+                 # Needed because & op not defined for SparseCS with AttentionMask
+                 att_mask = att_mask.to_bool()
+             mask = self.attention_mask & att_mask
+
+         return scaled_dot_product_attention(
+             q=q, k=k, v=v, att_mask=mask, dropout=self.attn_drop
+         )
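A sketch of LocalAttention (not part of the commit; sizes illustrative). The local mask is built lazily on the first forward and rebuilt when the sequence length changes; window_size=5 means each query sees itself plus two tokens per side:

import torch
from xformers.components.attention.local import LocalAttention

attn = LocalAttention(dropout=0.0, causal=False, window_size=5)
q = k = v = torch.randn(2, 16, 32)  # (batch x heads, seq, head_dim)
out = attn(q, k, v)                 # sliding-window attention, (2, 16, 32)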
.venv/lib/python3.11/site-packages/xformers/components/attention/nystrom.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import logging
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from xformers.components.attention import Attention, AttentionConfig, register_attention
15
+ from xformers.components.attention.core import (
16
+ scaled_dot_product_attention,
17
+ scaled_query_key_softmax,
18
+ )
19
+ from xformers.components.attention.utils import (
20
+ bool_mask_to_additive,
21
+ iterative_pinv,
22
+ reshape_key_padding_mask,
23
+ )
24
+
25
+ logger = logging.getLogger("xformers")
26
+
27
+
28
+ @dataclass
29
+ class NystromSelfAttentionConfig(AttentionConfig):
30
+ """
31
+ num_heads Number of heads.
32
+ num_landmarks Number of landmarks to use for softmax approximation. 64 often sufficient for a good
33
+ approximation according to https://arxiv.org/pdf/2102.03902.pdf.
34
+ causal Apply a causal mask, in that the attention cannot be applied to the future.
35
+ use_razavi_pinverse If true, use iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose
36
+ inverse, otherwise use standard torch inverse.
37
+ pinverse_original_init True if using original initialization when calculating Moore-Penrose pseudo inverse using
38
+ method from (Razavi et al. 2014).
39
+ False if using exact coefficient computation (leads to faster convergence).
40
+ inv_iterations Number of iterations for calculating the Moore-Penrose pseudo inverse.
41
+ v_skip_connection A module that will take V as input and will be added as a skip connection to the
42
+ softmax approximation. A skip connection is added in the paper to help with training.
43
+ conv_kernel_size Kernel size for convolution optionally added to help in training.
44
+ If v_skip_connection is not specified, this will be used to define the default
45
+ depth wise convolution used as a skip connection.
46
+ If both conv_kernel_size and v_skip_connection are None, no skip connection will
47
+ be added.
48
+ landmark_pooling Which module to use when computing landmarks. Default is AdaptiveAvgPool2d.
49
+ """
50
+
51
+ num_heads: int
52
+ num_landmarks: Optional[int]
53
+ landmark_pooling: Optional[nn.Module]
54
+ causal: Optional[bool]
55
+ pinverse_original_init: Optional[bool]
56
+ inv_iterations: Optional[int]
57
+ v_skip_connection: Optional[nn.Module]
58
+ conv_kernel_size: Optional[int]
59
+ use_razavi_pinverse: Optional[bool]
60
+
61
+
62
+ class AvgPool(nn.Module):
63
+ def __init__(self, n: int):
64
+ super().__init__()
65
+ self.n = n
66
+
67
+ def forward(self, x: torch.Tensor):
68
+ # Average independently for every segment in the sequence dimension
69
+ seq_len = x.shape[1]
70
+ head_dim = x.shape[2]
71
+ segments = seq_len // self.n
72
+ assert segments > 0, "num_landmarks should be smaller than the sequence length"
73
+
74
+ # Dimensions are a match
75
+ if seq_len % self.n == 0:
76
+ return x.reshape(
77
+ -1,
78
+ self.n,
79
+ segments,
80
+ head_dim,
81
+ ).mean(dim=-2)
82
+
83
+ # Handle the last segment boundary being off
84
+ n_round = self.n - seq_len % self.n
85
+
86
+ x_avg_round = (
87
+ x[:, : n_round * segments, :]
88
+ .reshape(-1, n_round, segments, head_dim)
89
+ .mean(dim=-2)
90
+ )
91
+ x_avg_off = (
92
+ x[:, n_round * segments :, :]
93
+ .reshape(-1, self.n - n_round, segments + 1, head_dim)
94
+ .mean(dim=-2)
95
+ )
96
+ return torch.cat((x_avg_round, x_avg_off), dim=-2)
97
+
98
+
99
+ @register_attention("nystrom", NystromSelfAttentionConfig)
100
+ class NystromAttention(Attention):
101
+ # TODO: update defaults for use_razavi_pinverse and inv_iterations
102
+ def __init__(
103
+ self,
104
+ dropout: float,
105
+ num_heads: int,
106
+ num_landmarks: int = 64,
107
+ landmark_pooling: Optional[nn.Module] = None,
108
+ causal: bool = False,
109
+ use_razavi_pinverse: bool = True,
110
+ pinverse_original_init: bool = False,
111
+ inv_iterations: int = 6, # recommended default in paper was 6.
112
+ v_skip_connection: Optional[nn.Module] = None,
113
+ conv_kernel_size: Optional[int] = None,
114
+ *args,
115
+ **kwargs,
116
+ ):
117
+ """
118
+ Nystrom attention mechanism, from Nystromformer_.
119
+ ::
120
+
121
+ "A Nystrom-based Algorithm for Approximating Self-Attention."
122
+ Xiong, Y., Zeng, Z., Chakraborty, R., Tan, M., Fung, G., Li, Y., Singh, V. (2021)
123
+
124
+ Reference codebase: https://github.com/mlpen/Nystromformer
125
+
126
+ .. _Nystromformer: https://arxiv.org/pdf/2102.03902.pdf
127
+
128
+ """
129
+ super().__init__()
130
+ # merged key padding mask and attention mask is not accepted
131
+ self.requires_separate_masks = True
132
+ self.num_landmarks = num_landmarks
133
+ # TODO: should be able to not have to pass in num_heads
134
+ self.num_heads = num_heads
135
+ self.use_razavi_pinverse = use_razavi_pinverse
136
+ self.pinverse_original_init = pinverse_original_init
137
+ self.inv_iterations = inv_iterations
138
+ self.attn_drop = nn.Dropout(dropout)
139
+ self.skip_connection = v_skip_connection
140
+ self.causal = causal
141
+
142
+ if self.skip_connection is None and conv_kernel_size is not None:
143
+ self.skip_connection = nn.Conv2d(
144
+ in_channels=self.num_heads,
145
+ out_channels=self.num_heads,
146
+ kernel_size=(conv_kernel_size, 1),
147
+ padding=(conv_kernel_size // 2, 0),
148
+ bias=False,
149
+ groups=self.num_heads,
150
+ )
151
+
152
+ if landmark_pooling is not None:
153
+ self.landmark_pooling = landmark_pooling
154
+ else:
155
+ self.landmark_pooling = AvgPool(n=self.num_landmarks)
156
+
157
+ # Optional lower triangular masks for causal attention
158
+ self.causal_mask_1: Optional[torch.Tensor] = None
159
+ self.causal_mask_2: Optional[torch.Tensor] = None
160
+ self.causal_mask_3: Optional[torch.Tensor] = None
161
+
+        # This attention does not support attention masks
+        self.supports_attention_mask = False
+        self.supports_key_padding_mask = True
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
+    ):
+        r"""
+        key_padding_mask    Only a key padding mask is accepted here. The size must be (batch size, sequence length) or
+                            (batch size * num_heads, 1, sequence length). If dimensions are not correct, the mask will
+                            be ignored. An additive mask is expected, meaning float values using "-inf" to mask values
+        """
+
+        batched_dim = k.size(0)
+        seq_len = k.size(-2)
+        tt = {"dtype": q.dtype, "device": q.device}
+
+        if key_padding_mask is not None:
+            if key_padding_mask.dtype == torch.bool:
+                logger.warning(
+                    "Bool mask found, but an additive mask is expected. Converting but this is slow"
+                )
+
+                key_padding_mask = bool_mask_to_additive(key_padding_mask)
+
+            if key_padding_mask.ndim == 2:
+                key_padding_mask = reshape_key_padding_mask(
+                    key_padding_mask, batched_dim
+                )
+
+            zeros = torch.zeros_like(key_padding_mask)
+            ones = torch.ones_like(key_padding_mask)
+            is_masked = torch.isinf(-key_padding_mask)
+
+            # _mask takes 1 if the token is not padded, otherwise 0.
+            _mask = torch.where(is_masked, zeros, ones)
+            _mask = _mask.transpose(2, 1)
+            assert _mask.shape == (batched_dim, q.shape[1], 1)
+
+            # Mask q and k before pooling
+            # https://github.com/mlpen/Nystromformer/blob/main/code/attention_nystrom.py#L31
+            q = q * _mask
+            k = k * _mask
+
+            assert key_padding_mask.size() == (batched_dim, 1, seq_len), (
+                f"key_padding_mask has invalid dimensions {key_padding_mask.size()}."
+                f" Must have dimensions {batched_dim, 1, seq_len} or (batch_size, {seq_len})."
+            )
+
+        if self.num_landmarks >= seq_len:
+            mask: Optional[torch.Tensor] = None
+
+            if self.causal:
+                mask = self._triu_mask(batched_dim, seq_len, seq_len, **tt)
+
+            if key_padding_mask is not None:
+                mask = key_padding_mask if mask is None else mask + key_padding_mask
+
+            x = scaled_dot_product_attention(q=q, k=k, v=v, att_mask=mask)
+
+        else:
+            q_landmarks = self.landmark_pooling(q)
+            k_landmarks = self.landmark_pooling(k)
+
+            if self.causal and (
+                self.causal_mask_1 is None
+                or (batched_dim, seq_len, self.num_landmarks)
+                != self.causal_mask_1.size()
+            ):
+                self.causal_mask_1 = self._triu_mask(
+                    batched_dim, seq_len, self.num_landmarks, **tt
+                )
+                self.causal_mask_2 = self._triu_mask(
+                    batched_dim, self.num_landmarks, self.num_landmarks, **tt
+                )
+                self.causal_mask_3 = self._triu_mask(
+                    batched_dim, self.num_landmarks, seq_len, **tt
+                )
+
+            mask_3: Optional[torch.Tensor] = self.causal_mask_3
+            if key_padding_mask is not None:
+                mask_3 = (
+                    key_padding_mask if mask_3 is None else mask_3 + key_padding_mask
+                )
+
+            kernel_1 = scaled_query_key_softmax(q=q, k=k_landmarks, att_mask=None)
+            kernel_2 = scaled_query_key_softmax(
+                q=q_landmarks, k=k_landmarks, att_mask=None
+            )
+            kernel_3 = scaled_dot_product_attention(
+                q=q_landmarks, k=k, v=v, att_mask=mask_3
+            )
+
+            kernel_2_inv = (
+                iterative_pinv(
+                    kernel_2, self.inv_iterations, self.pinverse_original_init
+                )
+                if self.use_razavi_pinverse
+                else torch.linalg.pinv(kernel_2)
+            )
+
+            x = torch.matmul(
+                torch.matmul(
+                    kernel_1,
+                    kernel_2_inv,
+                ),
+                kernel_3,
+            )
+
+        if self.skip_connection:
+            # Assumption here is that v is 3D.
+            v_conv = self.skip_connection(
+                v.reshape(-1, self.num_heads, v.size(-2), v.size(-1))
+            )
+            x += v_conv.reshape(-1, v_conv.size(-2), v_conv.size(-1))
+        x = self.attn_drop(x)
+        return x
+
+    def _triu_mask(self, dim_1: int, dim_2: int, dim_3: int, **kwargs) -> torch.Tensor:
+        device = kwargs["device"]
+        dtype = kwargs["dtype"]
+
+        return torch.triu(
+            torch.ones(dim_2, dim_3, dtype=dtype, device=device) * float("-inf"),
+            diagonal=1,
+        ).expand(
+            dim_1, -1, -1
+        )  # micro optim, save memory on the batch dimension
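For reference, the three-kernel product computed above amounts to softmax(q k_l^T) * pinv(softmax(q_l k_l^T)) * softmax(q_l k^T) v. A minimal sketch of that approximation, assuming 3D (batch * heads, seq, head_dim) inputs, a sequence length divisible by the landmark count, and segment-wise mean pooling as a stand-in for self.landmark_pooling:

import torch

def nystrom_approximation(q, k, v, num_landmarks: int = 16):
    # Landmarks through segment-wise mean pooling (a stand-in, not the exact library op)
    b, s, d = q.shape
    q_l = q.reshape(b, num_landmarks, s // num_landmarks, d).mean(dim=-2)
    k_l = k.reshape(b, num_landmarks, s // num_landmarks, d).mean(dim=-2)

    scale = d ** -0.5
    kernel_1 = torch.softmax(scale * q @ k_l.transpose(-2, -1), dim=-1)
    kernel_2 = torch.softmax(scale * q_l @ k_l.transpose(-2, -1), dim=-1)
    kernel_3 = torch.softmax(scale * q_l @ k.transpose(-2, -1), dim=-1) @ v

    # torch.linalg.pinv stands in for the iterative Razavi pseudo-inverse used above
    return kernel_1 @ torch.linalg.pinv(kernel_2) @ kernel_3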
.venv/lib/python3.11/site-packages/xformers/components/attention/random.py ADDED
@@ -0,0 +1,126 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+
+from xformers.components.attention import (
+    Attention,
+    AttentionConfig,
+    AttentionMask,
+    maybe_sparsify,
+    register_attention,
+    sparsify,
+)
+from xformers.components.attention.attention_patterns import (
+    causal_1d_pattern,
+    random_pattern,
+)
+from xformers.components.attention.core import scaled_dot_product_attention
+
+
+@dataclass
+class RandomAttentionConfig(AttentionConfig):
+    r: Optional[
+        float
+    ]  # the ratio of keys that the query can attend to. 1.0 means dense attention
+    constant_masking: Optional[
+        bool
+    ]  # whether the randomness is per query or defined at construction time
+    force_sparsity: Optional[bool]  # use sparsity in any case (potentially slower)
+
+
+@register_attention("random", RandomAttentionConfig)
+class RandomAttention(Attention):
+    def __init__(
+        self,
+        dropout: float,
+        causal: bool = False,
+        r: float = 0.01,
+        constant_masking: bool = True,
+        force_sparsity: bool = False,
+        *args,
+        **kwargs,
+    ):
+        """
+        "Random" attention, as proposed for instance in BigBird_.
+        "Random" means here that each query can attend to a random set of keys.
+        This implementation is sparse-aware, meaning that the empty attention parts will not be represented in memory.
+
+        Args:
+            r (float): the ratio in [0,1] of keys that the query can attend to
+            constant_masking (bool): if true, keep the same random set for all queries.
+
+        .. _BigBird: https://arxiv.org/pdf/2007.14062.pdf
+
+        """
+        super().__init__()
+
+        self.attn_drop = nn.Dropout(dropout, inplace=False)
+        self.causal = causal
+        self.r = r
+        self.rand_attention_mask: Optional[torch.Tensor] = None
+        self.constant_masking = constant_masking
+        self.force_sparsity = force_sparsity
+
+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = True
+        self.supports_key_padding_mask = False
+
+        self.requires_same_k_q_dimensions = True
+
+    def _get_rand_mask(self, shape: torch.Size) -> torch.Tensor:
+        sparsity = 1 - self.r
+        mask = random_pattern(shape[1], sparsity=sparsity)
+
+        if self.causal:
+            mask &= causal_1d_pattern(shape[1])
+
+        mask = sparsify(mask) if self.force_sparsity else maybe_sparsify(mask)
+
+        return mask
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
+        *args,
+        **kwargs,
+    ):
+        # Random masking
+        if not self.constant_masking or self.rand_attention_mask is None:
+            self.rand_attention_mask = self._get_rand_mask(q.shape).to(q.device)
+
+        # Mask-aware attention
+        if att_mask is not None:
+            if att_mask.dtype == torch.bool and isinstance(
+                self.rand_attention_mask, AttentionMask
+            ):
+                mask = self.rand_attention_mask + AttentionMask.from_bool(att_mask)
+            else:
+                if isinstance(att_mask, AttentionMask):
+                    # Needed because the & operator is not defined between SparseCS and AttentionMask
+                    att_mask = att_mask.to_bool()
+                mask = self.rand_attention_mask & att_mask
+        else:
+            mask = self.rand_attention_mask
+
+        # Handle q/k/v which would not fit the mask
+        seq_len = q.shape[-2]
+        q_, k_, v_ = map(lambda x: self._maybe_pad_sequence(x, mask), (q, k, v))
+
+        # Normal attention with the random mask
+        att = scaled_dot_product_attention(
+            q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop
+        )
+
+        # Take into account a hypothetical padding
+        return att[:, :seq_len, :]
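A minimal usage sketch, assuming inputs already folded to (batch * heads, seq, head_dim); with r=0.05 each query attends to roughly 5% of the keys:

import torch
from xformers.components.attention.random import RandomAttention

attn = RandomAttention(dropout=0.0, causal=False, r=0.05)
q = k = v = torch.randn(2, 128, 64)  # (batch * heads, seq, head_dim)
out = attn(q, k, v)
assert out.shape == q.shape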
.venv/lib/python3.11/site-packages/xformers/components/attention/scaled_dot_product.py ADDED
@@ -0,0 +1,134 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from xformers.components.attention import (
+    Attention,
+    AttentionConfig,
+    AttentionMask,
+    register_attention,
+)
+from xformers.components.attention.core import scaled_dot_product_attention
+
+logger = logging.getLogger("xformers")
+
+
+@dataclass
+class ScaledDotProductConfig(AttentionConfig):
+    causal: Optional[bool]
+    seq_len: Optional[int]
+    to_seq_len: Optional[int]
+
+
+@register_attention("scaled_dot_product", ScaledDotProductConfig)
+class ScaledDotProduct(Attention):
+    r"""
+    Implementing the Scaled Dot-Product attention proposed in
+    `Attention is all you need`_, Vaswani et al.
+
+    .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762v5
+    """
+
+    mask: Optional[AttentionMask]
+
+    def __init__(
+        self,
+        dropout: float = 0.0,
+        causal: bool = False,
+        seq_len: Optional[int] = None,
+        to_seq_len: Optional[int] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.attn_drop = nn.Dropout(dropout, inplace=False)
+        self.causal = causal
+        self.seq_len = seq_len
+
+        if causal and seq_len is not None:
+            self.mask = AttentionMask.make_causal(seq_len, to_seq_len)
+        else:
+            self.mask = None
+
+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = True
+        self.supports_key_padding_mask = False
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        r"""
+        att_mask    A 2D or 3D mask which ignores attention at certain positions.
+
+                    - If the mask is boolean, a value of True will keep the value,
+                      while a value of False will mask the value.
+
+                      Key padding masks (dimension: batch x sequence length) and attention masks
+                      (dimension: sequence length x sequence length OR batch x sequence length x sequence length)
+                      can be combined and passed in here. The method maybe_merge_masks provided in the utils can be
+                      used for that merging.
+
+                    - If the mask has the float type, then an additive mask is expected (masked values are -inf)
+
+        """
+
+        # Convenience, create an attention mask if a tensor was passed
+        if att_mask is not None and isinstance(att_mask, torch.Tensor):
+            # By default we don't know of the causality, and a check would be expensive
+            att_mask = (
+                AttentionMask.from_bool(att_mask)
+                if att_mask.dtype == torch.bool
+                else AttentionMask(att_mask, is_causal=False)
+            )
+
+        # Handle a possibly deferred causal mask creation
+        mask = self.mask
+        if self.causal and self.mask is None:
+            mask = AttentionMask.make_causal(
+                seq_len=q.shape[-2],
+                to_seq_len=q.shape[-2],
+                device=q.device,
+                dtype=q.dtype,
+            )
+
+        # Merge the optional causal mask and the user-provided mask
+        if mask is not None:
+            mask = mask.to(dtype=q.dtype, device=q.device)
+
+            att_mask = att_mask + mask if att_mask is not None else mask
+
+        # Try to handle a case where the sequence is smaller than the mask
+        if (
+            att_mask is not None
+            and q.shape[-2] == k.shape[-2]
+            and q.shape[-2] < att_mask.shape[1]
+        ):
+            if isinstance(att_mask, AttentionMask):
+                att_mask = att_mask.make_crop(seq_len=q.shape[-2])
+            else:
+                logger.error(
+                    "Mismatching sparse attention mask and sequence length."
+                    + " Please pad the inputs or adjust the attention mask"
+                )
+                raise NotImplementedError
+
+        # Attend: (B x nh, S, hs) x (B x nh, hs, S) -> (B x nh, S, S)
+        y = scaled_dot_product_attention(
+            q=q, k=k, v=v, att_mask=att_mask, dropout=self.attn_drop
+        )
+        return y
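A minimal usage sketch; with causal=True and no seq_len given at construction, the causal mask is built lazily from the query's sequence length inside forward:

import torch
from xformers.components.attention.scaled_dot_product import ScaledDotProduct

attn = ScaledDotProduct(dropout=0.1, causal=True)
q = k = v = torch.randn(4, 16, 32)  # (batch * heads, seq, head_dim)
y = attn(q, k, v)                   # the causal mask is created on the fly
assert y.shape == (4, 16, 32)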
.venv/lib/python3.11/site-packages/xformers/components/attention/visual.py ADDED
@@ -0,0 +1,96 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+
+from xformers.components.attention import Attention, AttentionConfig, register_attention
+
+
+@dataclass
+class VisualAttentionConfig(AttentionConfig):
+    dim_model: int  # dimension of the input sequence
+
+
+class LKA(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
+        self.conv_spatial = nn.Conv2d(
+            dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3
+        )
+        self.conv1 = nn.Conv2d(dim, dim, 1)
+
+    def forward(self, x: torch.Tensor):
+        u = x.clone()
+        attn = self.conv0(x)
+        attn = self.conv_spatial(attn)
+        attn = self.conv1(attn)
+
+        return u * attn
+
+
+@register_attention("visual", VisualAttentionConfig)
+class Visual(Attention):
+    def __init__(
+        self,
+        dim_model: int,
+        *_,
+        **__,
+    ):
+        """
+        Large kernel attention mechanism, as proposed in `Visual Attention Network`_, Guo et al (2022).
+        The original notation is tentatively kept as is. See https://github.com/Visual-Attention-Network
+        for the reference implementation
+
+        .. Note: compared to the paper, this block contains the LKA (Large Kernel Attention)
+            and the prior and posterior transformations (Conv2d and activation)
+
+        .. _`Visual Attention Network` : https://arxiv.org/pdf/2202.09741.pdf
+        """
+        super().__init__()
+
+        self.block = nn.Sequential(
+            nn.Conv2d(dim_model, dim_model, 1),
+            nn.GELU(),
+            LKA(dim_model),
+            nn.Conv2d(dim_model, dim_model, 1),
+        )
+
+        # MHA related flags:
+        self.requires_same_k_q_dimensions = (
+            True  # This mechanism only really supports self attention
+        )
+        self.supports_attention_mask = False
+        self.requires_skip_multi_head = (
+            True  # This mechanism skips the multihead attention altogether
+        )
+        self.requires_squared_context = (
+            True  # Recovering the 2D structure from the context assumes a squared context
+        )
+
+        self.requires_input_projection = (
+            False  # This mechanism does not require that the MHA projects inputs
+        )
+
+    def forward(self, q: torch.Tensor, *_, **__):
+        # Expose the 2D token structure
+        B, HW, C = q.shape
+        H = int(math.sqrt(HW))
+        assert H * H == HW
+
+        x = q.transpose(-2, -1).reshape(B, C, H, H)
+
+        # Large kernel attention
+        residual = x.clone()
+        x = self.block(x)
+        x = x + residual
+
+        # Get back to B HW C
+        return x.flatten(2, 3).transpose(-2, -1)
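The LKA stack (5x5 depthwise, then 7x7 depthwise with dilation 3, then 1x1) approximates the 21x21 large kernel decomposition from the paper. A minimal usage sketch, assuming a square token grid:

import torch
from xformers.components.attention.visual import Visual

va = Visual(dim_model=64)
tokens = torch.randn(2, 14 * 14, 64)  # a squared context: 14 x 14 patches
out = va(tokens)                      # reshaped to 2D, attended, flattened back
assert out.shape == tokens.shape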
.venv/lib/python3.11/site-packages/xformers/components/input_projection.py ADDED
@@ -0,0 +1,102 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+# CREDITS: Inspired by https://github.com/pytorch/text/blob/master/torchtext/nn/modules/multiheadattention.py
+# and the MultiHeadAttention implementation from PyTorch
+
+
+import logging
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+from xformers._deprecation_warning import deprecated_function
+
+logger = logging.getLogger("xformers")
+
+
+@dataclass
+class InputProjectionConfig:
+    in_features: int
+    out_features: int
+    bias: bool
+
+
+class InputProjection(nn.Module):
+    """
+    Handle all the input projections in one go, opportunistically fuse some operations.
+    """
+
+    def __init__(
+        self,
+        query_proj_params: InputProjectionConfig,
+        key_proj_params: Optional[InputProjectionConfig],
+        value_proj_params: Optional[InputProjectionConfig],
+        use_separate_proj_weight: bool = True,
+    ):
+        super().__init__()
+        deprecated_function(self)
+
+        self.out_features = query_proj_params.out_features
+
+        # Each input gets a separate projection
+        self.q_proj = nn.Linear(
+            query_proj_params.in_features,
+            query_proj_params.out_features,
+            query_proj_params.bias,
+        )
+
+        if key_proj_params is not None:
+            self.k_proj = nn.Linear(
+                key_proj_params.in_features,
+                key_proj_params.out_features,
+                key_proj_params.bias,
+            )
+        else:
+            logger.info(
+                "No Key projection parameters were passed, assuming that the weights"
+                + " are shared with the query projection"
+            )
+            self.k_proj = self.q_proj
+
+        if value_proj_params is not None:
+            self.v_proj = nn.Linear(
+                value_proj_params.in_features,
+                value_proj_params.out_features,
+                value_proj_params.bias,
+            )
+        else:
+            logger.info(
+                "No Value projection parameters were passed, assuming that the weights"
+                + " are shared with the query projection"
+            )
+            self.v_proj = self.q_proj
+
+        if not use_separate_proj_weight:
+            # Compute optimization used at times, share the parameters in between Q/K/V
+            with torch.no_grad():
+                self.k_proj.weight = self.q_proj.weight
+                self.v_proj.weight = self.q_proj.weight
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # One projection per input tensor
+
+        # NOTE: Would it make sense to catch self attention + shared weights, to skip a projection step?
+        q, k, v = map(
+            lambda fn, x: fn(x),
+            [self.q_proj, self.k_proj, self.v_proj],
+            [query, key, value],
+        )
+
+        return q, k, v
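A minimal usage sketch; passing None for the key and value configs shares the query projection, as logged above:

import torch
from xformers.components.input_projection import InputProjection, InputProjectionConfig

proj = InputProjection(
    query_proj_params=InputProjectionConfig(in_features=64, out_features=64, bias=True),
    key_proj_params=None,    # shares q_proj
    value_proj_params=None,  # shares q_proj
)
x = torch.randn(2, 16, 64)
q, k, v = proj(x, x, x)      # here q, k and v are identical by construction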
.venv/lib/python3.11/site-packages/xformers/components/multi_head_dispatch.py ADDED
@@ -0,0 +1,271 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+from dataclasses import asdict, dataclass
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.nn.init import constant_
+
+from xformers._deprecation_warning import deprecated_function
+from xformers.components.attention import Attention
+from xformers.components.input_projection import InputProjection, InputProjectionConfig
+from xformers.components.positional_embedding import RotaryEmbedding
+
+logger = logging.getLogger("xformers")
+
+
+@dataclass
+class MultiHeadDispatchConfig:
+    dim_model: int
+    num_heads: int
+    attention: Attention
+    bias: bool
+    residual_dropout: float
+    dim_key: Optional[int]
+    dim_value: Optional[int]
+    in_proj_container: Optional[InputProjection]
+    use_separate_proj_weight: Optional[bool]
+    use_rotary_embeddings: Optional[bool]
+    out_proj: Optional[nn.Module]
+
+    def __getitem__(self, item):
+        return getattr(self, item)
+
+
+# Move the head dimension forward and fold it into the batch dimension. Dimensions become (B * nh, S, hs)
+def _fold_heads(t: torch.Tensor, B: int, S: int, H: int, Hs: int):
+    return t.view(B, S, H, Hs).transpose(1, 2).flatten(start_dim=0, end_dim=1)
+
+
+# Move the head dimension forward and keep it exposed. Dimensions become (B, nh, S, hs)
+def _split_heads(t: torch.Tensor, B: int, S: int, H: int, Hs: int):
+    return t.view(B, S, H, Hs).transpose(1, 2)
+
+
+class MultiHeadDispatch(nn.Module):
+    """
+    A multi-head masked self-attention dispatch mechanism, with a projection at the end,
+    following the architecture proposed in `Attention is all you need`_, Vaswani et al.
+
+    The actual attention mechanism can vary, as well as the projections.
+    This can be used to wrap the proposed attention mechanisms and make them multi-head aware,
+    but it is optional.
+
+    Args:
+        dim_model: The model/embedding dimension
+        num_heads: The number of heads being used
+        attention: The attention mechanism (needs to be registered to the xformers library)
+        bias: Whether to use bias for the projections: (Q, K, V, Output)
+        residual_dropout: Amount of dropout on the residual path
+        use_separate_proj_weight: Use different weights for the Q, K, V projections
+        dim_key: Optionally use a different dimension for the key
+        dim_value: Optionally use a different dimension for the value
+        in_proj_container: Optionally provide the input projection module
+        use_rotary_embeddings: Use rotary embeddings
+        out_proj: Optionally provide the output projection module
+
+
+    .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762v5
+    """
+
+    def __init__(
+        self,
+        dim_model: int,
+        num_heads: int,
+        attention: Attention,
+        bias: Tuple[bool, bool, bool, bool] = (True, True, True, True),
+        residual_dropout: float = 0.0,
+        use_separate_proj_weight: bool = True,
+        dim_key: Optional[int] = None,
+        dim_value: Optional[int] = None,
+        in_proj_container: Optional[InputProjection] = None,
+        use_rotary_embeddings: Optional[bool] = False,
+        out_proj: Optional[nn.Module] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__()
+        deprecated_function(self)
+
+        if isinstance(bias, bool):
+            logger.warning(
+                "Single bias value provided for the MHA projections."
+                + f" Assuming the same parameter ({bias}) is to be used everywhere"
+            )
+            bias = (bias, bias, bias, bias)
+
+        assert (
+            dim_model % num_heads == 0
+        )  # static preset for now, each head works on 1/d of the embeddings, could be relaxed
+        assert num_heads > 0
+
+        # Popular default is that all latent dimensions are the same
+        dim_key, dim_value = map(lambda x: x if x else dim_model, (dim_key, dim_value))
+
+        self.num_heads = num_heads
+        self.dim_key_head = dim_key // num_heads
+        self.dim_value_head = dim_value // num_heads
+        self.dim_model = dim_model
+        self.attention = attention
+
+        # key, query, value projections for all heads
+        # critical options are
+        # - are we sharing weights?
+        # - are we adding biases?
+        if attention.requires_input_projection:
+            self.in_proj_container = (
+                in_proj_container
+                if in_proj_container is not None
+                else InputProjection(
+                    query_proj_params=InputProjectionConfig(
+                        dim_model, dim_key, bias=bias[0]
+                    ),
+                    key_proj_params=InputProjectionConfig(
+                        dim_model, dim_key, bias=bias[1]
+                    ),
+                    value_proj_params=InputProjectionConfig(
+                        dim_model, dim_value, bias=bias[2]
+                    ),
+                    use_separate_proj_weight=use_separate_proj_weight,
+                )
+            )
+
+        # Optional rotary embeddings
+        self.rotary_embeddings = (
+            RotaryEmbedding(self.dim_key_head) if use_rotary_embeddings else None
+        )
+
+        # Regularization
+        self.resid_drop = nn.Dropout(residual_dropout, inplace=False)
+
+        # Output projection
+        self.proj = (
+            out_proj if out_proj else nn.Linear(dim_model, dim_model, bias=bias[3])
+        )
+        if isinstance(self.proj, nn.Linear) and self.proj.bias is not None:
+            constant_(self.proj.bias, 0.0)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+        value: Optional[torch.Tensor] = None,
+        att_mask: Optional[torch.Tensor] = None,
+        key_padding_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Expected input dimensions are [batch size, sequence length, embed dim]
+        Output dimensions are [batch size, sequence length, embed dim]
+        """
+
+        if key is None:
+            key = query
+        if value is None:
+            value = query
+
+        if query.shape[0] != key.shape[0] or query.shape[0] != value.shape[0]:
+            max_batch = max((query.shape[0], key.shape[0], value.shape[0]))
+            query, key, value = map(
+                lambda x: x.expand(max_batch, -1, -1), [query, key, value]
+            )
+
+        B, S_Q, _ = query.size()  # Batch x Sequence x Embedding (latent)
+        _, S_K, _ = key.size()  # K's and Q's sequence lengths could differ
+
+        # Catch different query and key lengths but a causal attention
+        if S_Q != S_K:
+            assert (
+                not self.attention.requires_same_k_q_dimensions
+            ), "This attention mechanism requires query and key to have the same sequence (context) lengths"
+
+            if hasattr(self.attention, "causal"):
+                assert not self.attention.causal, (
+                    "Causal attention is not supported when key and query have different sequence lengths.\n"
+                    + "In that case causality is ill-determined. Please pad your sequences accordingly"
+                )
+
+        kw_mask_args = {}
+        if att_mask is not None:
+            assert (
+                self.attention.supports_attention_mask
+            ), "This attention does not support attention masks"
+            kw_mask_args["att_mask"] = att_mask
+
+        if key_padding_mask is not None:
+            assert (
+                self.attention.supports_key_padding_mask
+            ), "This attention does not support key padding masks"
+            kw_mask_args["key_padding_mask"] = key_padding_mask
+
+        if self.attention.requires_skip_multi_head:
+            return self.attention(query, key, value, **kw_mask_args)
+
+        # Calculate query, key, values for all heads in batch
+        if self.attention.requires_input_projection:
+            q, k, v = self.in_proj_container(query=query, key=key, value=value)
+        else:
+            k, q, v = key, query, value
+
+        # Check the dimensions properly
+        def check(t, name):
+            assert (
+                t.shape[2] % self.num_heads == 0
+            ), f"the {name} embeddings need to be divisible by the number of heads"
+
+        check(q, "projected query")
+        check(v, "projected value")
+        check(k, "projected key")
+
+        # Optional: rotary embedding, add relative positioning information
+        if self.rotary_embeddings:
+            # rotary requires the head dimension
+            q = _split_heads(q, B, S_Q, self.num_heads, self.dim_key_head)
+            k = _split_heads(k, B, S_K, self.num_heads, self.dim_key_head)
+            v = _split_heads(v, B, S_K, self.num_heads, self.dim_value_head)
+
+            q, k = self.rotary_embeddings(q=q, k=k)
+
+            if not self.attention.requires_head_dimension:
+                q, k, v = q.flatten(0, 1), k.flatten(0, 1), v.flatten(0, 1)
+
+        else:
+            # Reshape k/q/v to either expose the heads, or fold the head dimension into the batch
+            reshape_fn = (
+                _split_heads if self.attention.requires_head_dimension else _fold_heads
+            )
+
+            q = reshape_fn(q, B, S_Q, self.num_heads, self.dim_key_head)
+            k = reshape_fn(k, B, S_K, self.num_heads, self.dim_key_head)
+            v = reshape_fn(v, B, S_K, self.num_heads, self.dim_value_head)
+
+        # Self-attend
+        y = self.attention(q, k, v, **kw_mask_args)
+
+        # Re-assemble all head outputs side by side
+        y = (
+            y.view(B, self.num_heads, S_Q, self.dim_value_head)
+            .transpose(1, 2)
+            .flatten(start_dim=2, end_dim=3)
+        )
+
+        # Output projection, dropout and good to go
+        y = self.resid_drop(self.proj(y))
+
+        # Return the same sequence size as the input
+        return y
+
+    @classmethod
+    def from_config(cls, config: MultiHeadDispatchConfig):
+        # Generate the class inputs from the config
+        fields = asdict(config)
+
+        # Skip all Nones so that default values are used
+        fields = {k: v for k, v in fields.items() if v is not None}
+
+        return cls(**fields)
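A minimal self-attention sketch wiring the dispatcher to the ScaledDotProduct mechanism defined earlier in this diff:

import torch
from xformers.components.attention.scaled_dot_product import ScaledDotProduct
from xformers.components.multi_head_dispatch import MultiHeadDispatch

mha = MultiHeadDispatch(
    dim_model=64,
    num_heads=4,
    attention=ScaledDotProduct(dropout=0.0, causal=False),
)
x = torch.randn(2, 16, 64)  # (batch, sequence, embedding)
y = mha(query=x)            # key and value default to the query (self-attention)
assert y.shape == x.shape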
.venv/lib/python3.11/site-packages/xformers/components/patch_embedding.py ADDED
@@ -0,0 +1,83 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+
+from xformers._deprecation_warning import deprecated_function
+
+
+class PoolType(str, Enum):
+    Conv2D = "CONV_2D"
+    # ...
+    # TODO: Support more cases?
+
+
+@dataclass
+class PatchEmbeddingConfig:
+    """
+    The configuration for the patch embedding layer, which takes the raw tokens passed in
+    and returns a pooled representation along a given embedding dimension.
+
+    This typically trades the spatial (context length) representation for the embedding size.
+
+    This is canonically used by ViT, but other papers (like MetaFormer or other hierarchical transformers)
+    propose a more general use case for this
+    """
+
+    in_channels: int
+    out_channels: int
+    kernel_size: int
+    stride: int
+    padding: int = 0
+    pool_type: PoolType = PoolType.Conv2D
+
+
+class ConditionalReshape(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        deprecated_function(self)
+
+    def forward(self, x):
+        if x.ndim == 3:
+            B, HW, C = x.shape
+            # NOTE: We're assuming a square sample here
+            H = int(math.sqrt(HW))
+            assert H * H == HW, f"{H, HW}"
+            x = x.transpose(1, 2).reshape(B, C, H, H)
+
+        return x
+
+
+class PatchToSequence(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        deprecated_function(self)
+
+    def forward(self, x):
+        return x.flatten(2, 3).transpose(1, 2).contiguous()  # B HW C
+
+
+def build_patch_embedding(config: PatchEmbeddingConfig):
+    if not isinstance(config, PatchEmbeddingConfig):
+        config = PatchEmbeddingConfig(**config)
+
+    if config.pool_type == PoolType.Conv2D:
+        pool = torch.nn.Conv2d(
+            config.in_channels,
+            config.out_channels,
+            kernel_size=config.kernel_size,
+            stride=config.stride,
+            padding=config.padding,
+        )
+    else:
+        raise NotImplementedError
+
+    # The patch embedding supposes that the input really is 2D in essence
+    # If this block is in the middle of a stack, we need to reshape
+    return torch.nn.Sequential(ConditionalReshape(), pool, PatchToSequence())
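A minimal ViT-style usage sketch; 224x224 images with 16x16 non-overlapping patches yield 14 x 14 = 196 tokens:

import torch
from xformers.components.patch_embedding import PatchEmbeddingConfig, build_patch_embedding

patchify = build_patch_embedding(
    PatchEmbeddingConfig(in_channels=3, out_channels=96, kernel_size=16, stride=16)
)
images = torch.randn(2, 3, 224, 224)
tokens = patchify(images)  # (2, 196, 96): B HW C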
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__init__.py ADDED
@@ -0,0 +1,87 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from pathlib import Path
+from typing import Any, Callable, Dict, Set, Union
+
+from xformers.utils import (
+    generate_matching_config,
+    get_registry_decorator,
+    import_all_modules,
+)
+
+from .base import PositionEmbedding, PositionEmbeddingConfig  # noqa
+
+# CREDITS: Classy Vision registry mechanism
+
+POSITION_EMBEDDING_REGISTRY: Dict[str, Any] = {}
+POSITION_EMBEDDING_CLASS_NAMES: Set[str] = set()
+
+
+def build_positional_embedding(config: Union[Dict[str, Any], PositionEmbeddingConfig]):
+    """Builds a position encoding from a config.
+
+    This assumes a 'name' key in the config which is used to determine which
+    position encoding class to instantiate. For instance, a config `{"name": "my_position_encoding",
+    "foo": "bar"}` will find a class that was registered as "my_position_encoding"
+    (see :func:`register_positional_embedding`) and call .from_config on it."""
+
+    if not isinstance(config, PositionEmbeddingConfig):
+        config_instance = generate_matching_config(
+            config, POSITION_EMBEDDING_REGISTRY[config["name"]].config
+        )
+    else:
+        config_instance = config
+
+    return POSITION_EMBEDDING_REGISTRY[config_instance.name].constructor.from_config(
+        config_instance
+    )
+
+
+"""Registers a PositionEncoding subclass.
+
+This decorator allows xFormers to instantiate a subclass of PositionEncoding
+from a configuration file, even if the class itself is not part of the
+xFormers framework. To use it, apply this decorator to a `PositionEncoding`
+subclass, like this:
+
+.. code-block:: python
+
+    @dataclass
+    class MyConfig:
+        ...
+
+    @register_positional_embedding('my_encoding', MyConfig)
+    class MyEncoding(PositionEncoding):
+        ...
+
+To instantiate a position encoding from a configuration file, see :func:`build_positional_embedding`."""
+register_positional_embedding: Callable[
+    [str, Any], Callable[[Any], Any]
+] = get_registry_decorator(
+    POSITION_EMBEDDING_REGISTRY,
+    POSITION_EMBEDDING_CLASS_NAMES,
+    PositionEmbedding,
+    PositionEmbeddingConfig,
+)
+
+
+from .rotary import RotaryEmbedding  # noqa
+from .sine import SinePositionalEmbedding  # type: ignore  # noqa
+from .vocab import VocabEmbedding  # noqa
+
+__all__ = [
+    "RotaryEmbedding",
+    "SinePositionalEmbedding",
+    "VocabEmbedding",
+    "build_positional_embedding",
+    "register_positional_embedding",
+]
+
+# automatically import any Python files in the directory
+import_all_modules(
+    str(Path(__file__).parent), "xformers.components.positional_embedding"
+)
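A minimal registry usage sketch; "sine" is registered by the sine.py module added below, and the dict is matched against its config dataclass before .from_config is called:

from xformers.components.positional_embedding import build_positional_embedding

pos_emb = build_positional_embedding(
    {"name": "sine", "dim_model": 128, "seq_len": 256}
)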
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.53 kB)
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/base.cpython-311.pyc ADDED
Binary file (2.38 kB)
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/param.cpython-311.pyc ADDED
Binary file (2.87 kB)
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/rotary.cpython-311.pyc ADDED
Binary file (4.85 kB)
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/sine.cpython-311.pyc ADDED
Binary file (2.67 kB)
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/vocab.cpython-311.pyc ADDED
Binary file (3.52 kB)
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/base.py ADDED
@@ -0,0 +1,38 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from abc import ABCMeta, abstractmethod
+from dataclasses import asdict, dataclass
+from typing import Type, TypeVar
+
+import torch.nn as nn
+
+from xformers._deprecation_warning import deprecated_function
+
+Self = TypeVar("Self", bound="PositionEmbedding")
+
+
+@dataclass
+class PositionEmbeddingConfig:
+    name: str
+    dim_model: int
+    seq_len: int
+
+
+class PositionEmbedding(nn.Module, metaclass=ABCMeta):
+    @abstractmethod
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__()
+        deprecated_function(self)
+
+    @classmethod
+    def from_config(cls: Type[Self], config: PositionEmbeddingConfig) -> Self:
+        # Generate the class inputs from the config
+        fields = asdict(config)
+
+        # Skip all Nones so that default values are used
+        fields = {k: v for k, v in fields.items() if v is not None}
+        return cls(**fields)
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/param.py ADDED
@@ -0,0 +1,54 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from dataclasses import dataclass
+
+import torch
+
+from xformers.components.positional_embedding import (
+    PositionEmbedding,
+    PositionEmbeddingConfig,
+    register_positional_embedding,
+)
+
+
+@dataclass
+class LearnablePositionalEmbeddingConfig(PositionEmbeddingConfig):
+    name: str
+    seq_len: int
+    dim_model: int
+    add_class_token: bool
+
+
+@register_positional_embedding("learnable", LearnablePositionalEmbeddingConfig)
+class LearnablePositionalEmbedding(PositionEmbedding):
+    def __init__(
+        self, seq_len: int, dim_model: int, add_class_token: bool = False, *_, **__
+    ):
+        super().__init__()
+
+        # 0.02 is BERT initialization
+        self.pos_emb = torch.nn.Parameter(
+            torch.randn(1, seq_len + int(add_class_token), dim_model) * 0.02
+        )
+
+        self.class_token = (
+            torch.nn.Parameter(torch.zeros(dim_model)) if add_class_token else None
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.class_token is not None:
+            # Prepend the class token
+            clf_token = (
+                torch.ones(x.shape[0], 1, self.pos_emb.shape[-1], device=x.device)
+                * self.class_token
+            )
+            x = torch.cat([clf_token, x], dim=1)
+
+        if x.ndim == 2:
+            x = x.unsqueeze(-1)
+
+        return x + self.pos_emb
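A minimal usage sketch; with add_class_token=True the output grows by one position:

import torch
from xformers.components.positional_embedding.param import LearnablePositionalEmbedding

emb = LearnablePositionalEmbedding(seq_len=16, dim_model=32, add_class_token=True)
x = torch.randn(2, 16, 32)
y = emb(x)  # class token prepended: (2, 17, 32)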
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/rotary.py ADDED
@@ -0,0 +1,91 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox
+# NOTE: Almost the same right now, moving parts to Triton is the next step
+
+from typing import Tuple
+
+import torch
+
+
+def rotate_half(x):
+    x1, x2 = x.chunk(2, dim=-1)
+    return torch.cat((-x2, x1), dim=-1)
+
+
+@torch.jit.script
+def apply_rotary_pos_emb(x, cos, sin):
+    # NOTE: This could probably be moved to Triton
+
+    # Handle a possible sequence length mismatch in between q and k
+    cos = cos[:, :, : x.shape[-2], :]
+    sin = sin[:, :, : x.shape[-2], :]
+
+    return (x * cos) + (rotate_half(x) * sin)
+
+
+class RotaryEmbedding(torch.nn.Module):
+    """
+    The rotary position embeddings from RoFormer_ (Su et al.).
+    A crucial insight from the method is that the queries and keys are
+    transformed by rotation matrices which depend on the relative positions.
+
+    Other implementations are available in the Rotary Transformer repo_ and in
+    GPT-NeoX_, which inspired this one.
+
+    .. _RoFormer: https://arxiv.org/abs/2104.09864
+    .. _repo: https://github.com/ZhuiyiTechnology/roformer
+    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
+
+
+    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
+        (it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis
+    """
+
+    def __init__(self, dim_model: int, *_, **__):
+        super().__init__()
+        # Generate and save the inverse frequency buffer (non trainable)
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
+        self.register_buffer("inv_freq", inv_freq)
+
+        self._seq_len_cached = None
+        self._cos_cached = None
+        self._sin_cached = None
+
+    def _update_cos_sin_tables(self, x, seq_dimension=1):
+        seq_len = x.shape[seq_dimension]
+
+        # Reset the tables if the sequence length has changed,
+        # or if we're on a new device (possibly due to tracing for instance)
+        if (
+            seq_len != self._seq_len_cached
+            or self._cos_cached.device != x.device
+            or self._cos_cached.dtype != x.dtype
+        ):
+            self._seq_len_cached = seq_len
+            t = torch.arange(
+                x.shape[seq_dimension], device=x.device, dtype=torch.float32
+            )
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
+            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
+
+        return self._cos_cached, self._sin_cached
+
+    def forward(
+        self, q: torch.Tensor, k: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
+            k, seq_dimension=-2
+        )
+
+        return (
+            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
+            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
+        )
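A minimal usage sketch on (batch, heads, seq, head_dim) tensors; since the two rotated halves share the same frequencies, the transform is norm-preserving per position:

import torch
from xformers.components.positional_embedding import RotaryEmbedding

rope = RotaryEmbedding(dim_model=32)  # per-head dimension
q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 16, 32)
q_rot, k_rot = rope(q, k)
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)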
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/sine.py ADDED
@@ -0,0 +1,46 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Silence Mypy errors in this file.
+# type: ignore
+
+import math
+
+import torch
+
+from xformers.components.positional_embedding import (
+    PositionEmbedding,
+    PositionEmbeddingConfig,
+    register_positional_embedding,
+)
+
+
+@register_positional_embedding("sine", PositionEmbeddingConfig)
+class SinePositionalEmbedding(PositionEmbedding):
+    def __init__(self, dim_model: int, *args, **kwargs):
+        super().__init__()
+        self.dim_model = dim_model
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        seq_len = x.shape[1]
+        pos = (
+            torch.arange(0, seq_len, device=x.device, dtype=torch.float32)
+            .unsqueeze(1)
+            .repeat(1, self.dim_model)
+        )
+        dim = (
+            torch.arange(0, self.dim_model, device=x.device, dtype=torch.float32)
+            .unsqueeze(0)
+            .repeat(seq_len, 1)
+        )
+        div = torch.exp(-math.log(10000) * (2 * (dim // 2) / self.dim_model))
+        pos *= div
+        pos[:, 0::2] = torch.sin(pos[:, 0::2])
+        pos[:, 1::2] = torch.cos(pos[:, 1::2])
+
+        output = x.unsqueeze(-1) if x.ndim == 2 else x
+
+        return output + pos.unsqueeze(0)
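This computes the classic sinusoidal encoding PE[pos, 2i] = sin(pos / 10000^(2i/d)) and PE[pos, 2i+1] = cos(pos / 10000^(2i/d)). A quick check of the closed form, using a zero input so that only the encoding remains:

import math
import torch
from xformers.components.positional_embedding.sine import SinePositionalEmbedding

emb = SinePositionalEmbedding(dim_model=4)
pe = emb(torch.zeros(1, 8, 4))
assert math.isclose(pe[0, 1, 0].item(), math.sin(1.0), rel_tol=1e-5)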
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/vocab.py ADDED
@@ -0,0 +1,65 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from xformers.components.positional_embedding import (
+    PositionEmbedding,
+    PositionEmbeddingConfig,
+    register_positional_embedding,
+)
+
+
+@dataclass
+class VocabEmbeddingConfig(PositionEmbeddingConfig):
+    vocab_size: int
+    dropout: float
+
+
+@register_positional_embedding("vocab", VocabEmbeddingConfig)
+class VocabEmbedding(PositionEmbedding):
+    def __init__(
+        self,
+        dim_model: int,
+        seq_len: int,
+        vocab_size: int,
+        dropout: float = 0.0,
+        *args,
+        **kwargs
+    ):
+        super().__init__()
+
+        self.vocab_size = vocab_size
+        self.dim_model = dim_model
+
+        self.dropout = torch.nn.Dropout(p=dropout)
+        self.position_embeddings = nn.Embedding(seq_len, self.dim_model)
+        self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
+
+        self.position_ids: Optional[torch.Tensor] = None
+
+        self.init_weights()
+
+    def init_weights(self, gain: float = 1.0):
+        torch.nn.init.normal_(self.position_embeddings.weight, std=0.02 * gain)
+        torch.nn.init.normal_(self.word_embeddings.weight, std=0.02 * gain)
+
+    def forward(self, x: torch.Tensor):
+        position_ids = torch.arange(x.shape[1], dtype=torch.long, device=x.device)[
+            None, :
+        ].repeat(x.shape[0], 1)
+
+        X_token = self.word_embeddings(x)
+        X_pos = self.position_embeddings(position_ids)
+
+        X = X_token + X_pos
+        X = self.dropout(X)
+
+        return X
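A minimal usage sketch on integer token ids:

import torch
from xformers.components.positional_embedding.vocab import VocabEmbedding

emb = VocabEmbedding(dim_model=64, seq_len=128, vocab_size=1000, dropout=0.1)
ids = torch.randint(0, 1000, (2, 32))
h = emb(ids)  # token + position embeddings, then dropout: (2, 32, 64)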
.venv/lib/python3.11/site-packages/xformers/components/residual.py ADDED
@@ -0,0 +1,192 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from collections import namedtuple
+from enum import Enum
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from xformers._deprecation_warning import deprecated_function
+
+
+class ResidualNormStyle(str, Enum):
+    """Support different residual path and norm styles.
+    See "On Layer Normalization in the Transformer Architecture",
+    Xiong et al., https://arxiv.org/pdf/2002.04745v1.pdf
+    """
+
+    Pre = "pre"
+    Post = "post"
+    DeepNorm = "deepnorm"
+
+
+class NormalizationType(str, Enum):
+    LayerNorm = "layernorm"
+    Skip = "skip"
+    # TODO: BatchNorm = "batchnorm"
+    # TODO: GroupNorm = "groupnorm"
+
+
+def get_normalization_layer(normalization_type: NormalizationType):
+    class Skip(nn.Module):
+        def __init__(self, *_, **__) -> None:
+            super().__init__()
+            deprecated_function(self)
+
+        def forward(self, x: torch.Tensor, **_):
+            return x
+
+    return {
+        NormalizationType.LayerNorm: nn.LayerNorm,
+        NormalizationType.Skip: Skip,
+    }[normalization_type]
+
+
+class RequiresWrappedInputs:
+    """Used to mark, through inheritance,
+    the fact that this class will require inputs to be passed as a single list"""
+
+    pass
+
+
+# CREDITS: the following is inspired by FastAI's Transformer implementation
+class Residual(nn.Module, RequiresWrappedInputs):
+    """
+    Object-oriented handling of the residual path
+
+    This supports scaling of the residual path, as proposed by DeepNet_
+
+    .. _DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf
+
+    .. Note: the wrapped layers must accept all the inputs as a single list
+    """
+
+    def __init__(self, layer: nn.Module, scale: Optional[float] = None):
+        super().__init__()
+        deprecated_function(self)
+        self.layer = layer
+        self.scale = scale
+
+        # PreNorm and PostNorm require all the tensors to be passed as a list
+        self.wrap_inputs = isinstance(layer, RequiresWrappedInputs)
+
+    def forward(self, inputs: List[torch.Tensor], **kwargs):
+        if self.scale is not None:
+            residue = inputs[0] * self.scale
+        else:
+            residue = inputs[0]
+
+        if self.wrap_inputs:
+            return residue + self.layer(inputs=inputs, **kwargs)
+        else:
+            return residue + self.layer(*inputs, **kwargs)
+
+
+class PreNorm(nn.Module, RequiresWrappedInputs):
+    """Adds a normalization before computing attention
+
+    ..Note: If a list of inputs is passed, all of them get normalized"""
+
+    def __init__(
+        self,
+        d_norm: int,
+        sublayer: nn.Module,
+        normalization: NormalizationType,
+        use_triton: bool = True,
+    ):
+        super().__init__()
+        deprecated_function(self)
+        self.norm = get_normalization_layer(normalization)(d_norm)
+
+        self.sublayer = sublayer
+        self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs)
+
+    def forward(self, inputs: List[torch.Tensor], **kwargs):
+        assert len(inputs) > 0
+
+        # Perf improvement: if the inputs are all the same, only norm once
+        ids = [id(x) for x in inputs]
+        if ids.count(ids[0]) == len(ids):
+            # The same tensor is passed multiple times
+            x_norm = self.norm(inputs[0])
+            inputs_normed = [x_norm for _ in inputs]
+        else:
+            # The inputs differ, norm them all
+            inputs_normed = [self.norm(x_) for x_ in inputs]
+
+        if self.wrap_inputs:
+            return self.sublayer(inputs=inputs_normed, **kwargs)
+        else:
+            return self.sublayer(*inputs_normed, **kwargs)
+
+
+class PostNorm(nn.Module, RequiresWrappedInputs):
+    """Adds a normalization after computing attention"""
+
+    def __init__(
+        self,
+        d_norm: int,
+        sublayer: nn.Module,
+        normalization: NormalizationType,
+        use_triton: bool = True,
+    ):
+        super().__init__()
+        deprecated_function(self)
+        self.norm = get_normalization_layer(normalization)(d_norm)
+
+        self.sublayer = sublayer
+        self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs)
+
+    def forward(self, inputs: List[torch.Tensor], **kwargs):
+        if self.wrap_inputs:
+            x = self.sublayer(inputs=inputs, **kwargs)
+        else:
+            x = self.sublayer(*inputs, **kwargs)
+        return self.norm(x)
+
+
+DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])
+
+
+def get_deepnorm_coefficients(
+    encoder_layers: int, decoder_layers: int
+) -> Tuple[Optional[DeepNormCoefficients], Optional[DeepNormCoefficients]]:
+    """
+    See DeepNet_.
+
+    Returns alpha and beta depending on the number of encoder and decoder layers,
+    the first tuple entry is for the encoder and the second for the decoder
+
+    .. _DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf
+    """
+
+    N = encoder_layers
+    M = decoder_layers
+
+    if decoder_layers == 0:
+        # Encoder only
+        return (
+            DeepNormCoefficients(alpha=(2 * N) ** 0.25, beta=(8 * N) ** -0.25),
+            None,
+        )
+    elif encoder_layers == 0:
+        # Decoder only
+        return None, DeepNormCoefficients(alpha=(2 * M) ** 0.25, beta=(12 * M) ** -0.25 if False else (8 * M) ** -0.25)
+    else:
+        # Encoder/decoder
+        encoder_coeffs = DeepNormCoefficients(
+            alpha=0.81 * ((N**4) * M) ** 0.0625, beta=0.87 * ((N**4) * M) ** -0.0625
+        )
+
+        decoder_coeffs = DeepNormCoefficients(
+            alpha=(3 * M) ** 0.25, beta=(12 * M) ** -0.25
+        )
+
+        return (encoder_coeffs, decoder_coeffs)
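As a quick worked example of the encoder-only branch above, alpha = (2N)^0.25 and beta = (8N)^-0.25:

from xformers.components.residual import get_deepnorm_coefficients

enc, dec = get_deepnorm_coefficients(encoder_layers=12, decoder_layers=0)
assert dec is None
# alpha = 24 ** 0.25 ~= 2.21, beta = 96 ** -0.25 ~= 0.32
print(enc.alpha, enc.beta)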
.venv/lib/python3.11/site-packages/xformers/components/reversible.py ADDED
@@ -0,0 +1,160 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import List
+
+import torch
+import torch.nn as nn
+from torch.autograd.function import Function
+from torch.utils.checkpoint import get_device_states, set_device_states
+
+from xformers._deprecation_warning import deprecated_function
+from xformers.components import RequiresWrappedInputs
+
+# CREDITS: Code adapted from
+# https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py
+# https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py,
+# https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
+
+
+# pyre-fixme[13]: `cpu_state` is not initialized in the constructor.
+class Deterministic(nn.Module):
+    def __init__(self, net: nn.Module):
+        super().__init__()
+        deprecated_function(self)
+        self.net = net
+        self.cpu_state: torch.Tensor = torch.get_rng_state()
+        self.cuda_in_fwd: bool = False
+        self.gpu_devices: List[int] = []
+        self.gpu_states: List[torch.Tensor] = []
+        self.wrap_inputs = isinstance(net, RequiresWrappedInputs)
+
+    def record_rng(self, *args):
+        self.cpu_state = torch.get_rng_state()
+        if torch.cuda._initialized:
+            self.cuda_in_fwd = True
+            self.gpu_devices, self.gpu_states = get_device_states(*args)
+
+    def forward(self, *args, record_rng: bool = False, set_rng: bool = False, **kwargs):
+        if record_rng:
+            self.record_rng(*args)
+
+        if not set_rng:
+            # Normal FW run
+            if self.wrap_inputs:
+                return self.net(inputs=args, **kwargs)
+            else:
+                return self.net(*args, **kwargs)
+        else:  # pragma: no cover  # this is called in the backward pass, not picked up
+            # This is analogous to checkpointing, reset the original random state
+            rng_devices: List[int] = []
+            if self.cuda_in_fwd:
+                rng_devices = self.gpu_devices
+
+            with torch.random.fork_rng(devices=rng_devices, enabled=True):
+                torch.set_rng_state(self.cpu_state)
+                if self.cuda_in_fwd:
+                    set_device_states(self.gpu_devices, self.gpu_states)
+
+                if self.wrap_inputs:
+                    return self.net(inputs=args, **kwargs)
+                else:
+                    return self.net(*args, **kwargs)
+
+
+class ReversibleBlock(nn.Module):
+    def __init__(self, f: nn.Module, g: nn.Module, split_dim: int = -1):
+        super().__init__()
+        self.f = Deterministic(f)
+        self.g = Deterministic(g)
+        self.split_dim = split_dim
+
+    def forward(self, x: torch.Tensor, f_args={}, g_args={}):
+        x1, x2 = torch.chunk(x, 2, dim=self.split_dim)
+        y1, y2 = None, None
+
+        with torch.no_grad():
+            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
+            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
+
+        return torch.cat([y1, y2], dim=self.split_dim)
+
+    def backward_pass(
+        self, y: torch.Tensor, dy: torch.Tensor, f_args={}, g_args={}
+    ):  # pragma: no cover  # this is covered, but called directly from C++
+        y1, y2 = torch.chunk(y, 2, dim=self.split_dim)
+        del y
+
+        dy1, dy2 = torch.chunk(dy, 2, dim=self.split_dim)
+        del dy
+
+        with torch.enable_grad():
+            y1.requires_grad = True
+            gy1 = self.g(y1, set_rng=True, **g_args)
+            torch.autograd.backward(gy1, dy2)
+
+        with torch.no_grad():
+            x2 = y2 - gy1
+            del y2, gy1
+
+            dx1 = dy1 + y1.grad
+            del dy1
+            y1.grad = None
+
+        with torch.enable_grad():
+            x2.requires_grad = True
+            fx2 = self.f(x2, set_rng=True, **f_args)
+            torch.autograd.backward(fx2, dx1)
+
+        with torch.no_grad():
+            x1 = y1 - fx2
+            del y1, fx2
+
+            dx2 = dy2 + x2.grad
+            del dy2
+            x2.grad = None
+
+            x = torch.cat([x1, x2.detach()], dim=self.split_dim)
+            dx = torch.cat([dx1, dx2], dim=self.split_dim)
+
+        return x, dx
+
+
+class _ReversibleFunction(Function):
+    @staticmethod
+    def forward(ctx, x, blocks, kwargs):
+        ctx.kwargs = kwargs
+        for block in blocks:
+            x = block(x, **kwargs)
+        ctx.y = x.detach()
+        ctx.blocks = blocks
+        return x
+
+    @staticmethod
+    def backward(
+        ctx, dy
+    ):  # pragma: no cover  # this is covered, but called directly from C++
+        y = ctx.y
+        kwargs = ctx.kwargs
+        for block in ctx.blocks[::-1]:
+            y, dy = block.backward_pass(y, dy, **kwargs)
+        return dy, None, None
+
+
+class ReversibleSequence(nn.Module):
+    def __init__(self, blocks: nn.ModuleList):
+        super().__init__()
+        deprecated_function(self)
+
+        # pyre-fixme[23]: Unable to unpack `torch.nn.Module` into 2 values.
+        self.blocks = nn.ModuleList([ReversibleBlock(f, g) for f, g in blocks])
+
+    def forward(self, x, arg_route=(True, False), **kwargs):
+        f_args, g_args = map(lambda route: kwargs if route else {}, arg_route)
+        block_kwargs = {"f_args": f_args, "g_args": g_args}
+
+        return _ReversibleFunction.apply(x, self.blocks, block_kwargs)
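A minimal usage sketch; the input channels are split in two halves, so each f/g sub-block sees half of the model dimension:

import torch
import torch.nn as nn
from xformers.components.reversible import ReversibleSequence

blocks = nn.ModuleList(
    [nn.ModuleList([nn.Linear(64, 64), nn.Linear(64, 64)]) for _ in range(3)]
)
seq = ReversibleSequence(blocks)
x = torch.randn(2, 16, 128)  # 128 channels -> two halves of 64
y = seq(x)
assert y.shape == x.shape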
.venv/lib/python3.11/site-packages/xformers/components/simplicial_embedding.py ADDED
@@ -0,0 +1,67 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from dataclasses import asdict, dataclass
+ from typing import Optional, Type, TypeVar
+
+ import torch
+
+ from xformers._deprecation_warning import deprecated_function
+
+ Self = TypeVar("Self", bound="SimplicialEmbedding")
+
+
+ @dataclass
+ class SimplicialEmbeddingConfig:
+     L: int
+     temperature: float
+
+
+ class SimplicialEmbedding(torch.nn.Module):
+     """
+     An implementation of the "Simplicial Embeddings"_, as proposed by Lavoie et al.
+
+     Arguments:
+         - L: the number of embedding chunks
+         - temperature: optional scaling parameter for the softmax operation.
+           A small (<1.) temperature will lead to a sparse representation (up to one-hot),
+           while a large (>1.) temperature will make the vector more uniform
+
+     _"Simplicial Embeddings": https://arxiv.org/pdf/2204.00616.pdf
+     """
+
+     def __init__(self, L: int, temperature: Optional[float] = None) -> None:
+         super().__init__()
+         deprecated_function(self)
+         self.L = L
+         self.temperature = temperature
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         assert (
+             x.shape[-1] % self.L == 0
+         ), f"The embedding dimension {x.shape[-1]} is not divisible by the chosen L parameter {self.L}"
+
+         # Separate the input tensor into L chunks of size V
+         B, C, E = x.shape
+         V = E // self.L
+
+         Vs = x.reshape(B, C, self.L, V)
+
+         # Softmax-normalize them with the proposed temperature.
+         # This is done over the last dimension, so only within each chunk.
+         if self.temperature is not None:
+             Vs = Vs / self.temperature
+
+         Vs = torch.nn.functional.softmax(Vs, dim=-1)
+
+         # Concatenate back and return
+         return Vs.reshape(B, C, E)
+
+     @classmethod
+     def from_config(cls: Type[Self], config: SimplicialEmbeddingConfig) -> Self:
+         # Generate the class inputs from the config
+         fields = asdict(config)
+
+         return cls(**fields)
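A quick shape check on `SimplicialEmbedding`, assuming an illustrative (batch, context, embedding) input whose embedding dimension is divisible by L: each of the L chunks of size V = E // L becomes a probability vector after the softmax.

import torch

se = SimplicialEmbedding(L=4, temperature=0.5)
x = torch.randn(2, 10, 64)  # E = 64, so V = 16
y = se(x)
assert y.shape == x.shape

# every length-16 chunk now sums to 1
chunks = y.reshape(2, 10, 4, 16)
print(chunks.sum(dim=-1))  # ~1.0 everywhere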
.venv/lib/python3.11/site-packages/xformers/ops/__init__.py ADDED
@@ -0,0 +1,130 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+
+ from .fmha import (
+     AttentionBias,
+     AttentionOp,
+     AttentionOpBase,
+     LowerTriangularMask,
+     MemoryEfficientAttentionCkOp,
+     MemoryEfficientAttentionCutlassFwdFlashBwOp,
+     MemoryEfficientAttentionCutlassOp,
+     MemoryEfficientAttentionFlashAttentionOp,
+     MemoryEfficientAttentionSplitKCkOp,
+     memory_efficient_attention,
+     memory_efficient_attention_backward,
+     memory_efficient_attention_forward,
+     memory_efficient_attention_forward_requires_grad,
+ )
+ from .indexing import index_select_cat, scaled_index_add
+ from .ipc import init_ipc
+ from .modpar_layers import ColumnParallelLinear, RowParallelLinear
+ from .rmsnorm import RMSNorm
+ from .rope_padded import rope_padded
+ from .seqpar import sequence_parallel_leading_matmul, sequence_parallel_trailing_matmul
+ from .sequence_parallel_fused_ops import (
+     fused_allgather_and_anything,
+     fused_allgather_and_linear,
+     fused_anything_and_reducescatter,
+     fused_linear_and_reducescatter,
+ )
+ from .sp24 import Sparse24Tensor, sparsify24, sparsify24_like
+ from .swiglu_op import (
+     SwiGLU,
+     SwiGLUEagerOp,
+     SwiGLUFusedOp,
+     SwiGLUOp,
+     SwiGLUOpDispatch,
+     SwiGLUPackedFusedOp,
+     swiglu,
+ )
+ from .tiled_matmul import tiled_matmul
+ from .unbind import get_stack_strides, stack_or_none, unbind
+
+ # BW compatibility
+ AttentionMask = AttentionBias
+
+
+ def masked_matmul(a, b, mask=None):
+     if torch.overrides.has_torch_function((a, b, mask)):
+         return torch.overrides.handle_torch_function(
+             masked_matmul, (a, b, mask), a, b, mask
+         )
+
+     att = a @ b
+
+     if mask is None:
+         return att
+
+     if mask.dtype == torch.bool:
+         if mask.ndim == 2:
+             mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
+         # mask is presumed false == ignore
+         att[~mask] = float("-inf")
+     else:
+         # mask is presumed additive
+         att += mask
+     return att
+
+
+ __all__ = [
+     # fmha
+     "AttentionBias",
+     "AttentionMask",
+     "AttentionOp",
+     "AttentionOpBase",
+     "LowerTriangularMask",
+     "MemoryEfficientAttentionCutlassFwdFlashBwOp",
+     "MemoryEfficientAttentionCutlassOp",
+     "MemoryEfficientAttentionFlashAttentionOp",
+     "MemoryEfficientAttentionCkOp",
+     "MemoryEfficientAttentionSplitKCkOp",
+     "memory_efficient_attention",
+     "memory_efficient_attention_backward",
+     "memory_efficient_attention_forward",
+     "memory_efficient_attention_forward_requires_grad",
+     # indexing
+     "index_select_cat",
+     "scaled_index_add",
+     # ipc
+     "init_ipc",
+     # modpar_layers
+     "ColumnParallelLinear",
+     "RowParallelLinear",
+     # rmsnorm
+     "RMSNorm",
+     # rope_padded
+     "rope_padded",
+     # seqpar
+     "sequence_parallel_leading_matmul",
+     "sequence_parallel_trailing_matmul",
+     # sequence_parallel_fused_ops
+     "fused_allgather_and_anything",
+     "fused_allgather_and_linear",
+     "fused_anything_and_reducescatter",
+     "fused_linear_and_reducescatter",
+     # swiglu_op
+     "SwiGLU",
+     "SwiGLUEagerOp",
+     "SwiGLUFusedOp",
+     "SwiGLUOp",
+     "SwiGLUOpDispatch",
+     "SwiGLUPackedFusedOp",
+     "swiglu",
+     # tiled_matmul
+     "tiled_matmul",
+     # unbind
+     "get_stack_strides",
+     "stack_or_none",
+     "unbind",
+     # sp24
+     "sparsify24",
+     "sparsify24_like",
+     "Sparse24Tensor",
+     # .
+     "masked_matmul",
+ ]
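`masked_matmul` treats a boolean mask as "False means drop this entry" (written as -inf, e.g. ahead of a softmax) and any other dtype as an additive bias. A small illustration with assumed shapes:

import torch

a = torch.randn(2, 4, 8)
b = torch.randn(2, 8, 4)

bool_mask = torch.rand(4, 4) > 0.5    # 2D boolean masks broadcast over the batch
att = masked_matmul(a, b, bool_mask)  # masked-out entries become -inf

bias = torch.zeros(2, 4, 4)
bias[:, :, 0] = -1e4                  # non-bool masks are simply added
att2 = masked_matmul(a, b, bias)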
.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (3.34 kB).
.venv/lib/python3.11/site-packages/xformers/ops/_triton/k_index_select_cat.py ADDED
@@ -0,0 +1,184 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def index_select_cat_fwd_kernel(
+     output_ptr,  # *Pointer* to output tensor.
+     source_ptr,  # *Pointer* to source tensor.
+     index_ptr,  # *Pointer* to index tensor.
+     num_indices,
+     num_cols,
+     stride0,  # Stride information of source tensor.
+     stride1,
+     BLOCK_SIZE_INDEX: tl.constexpr,  # Number of indices each program should process.
+     BLOCK_SIZE_COL: tl.constexpr,  # Number of cols each program should process.
+ ):
+     pid0 = tl.program_id(axis=0)  # We use a 2D launch grid
+     pid1 = tl.program_id(axis=1)
+
+     indices = pid0 * BLOCK_SIZE_INDEX + tl.arange(0, BLOCK_SIZE_INDEX)
+     rows = tl.load(index_ptr + indices, mask=(indices < num_indices))
+     cols = pid1 * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)
+
+     source_offsets = source_ptr + rows[:, None] * stride0 + cols[None, :] * stride1
+     mask = (indices[:, None] < num_indices) & (cols[None, :] < num_cols)
+     output = tl.load(source_offsets, mask=mask)
+
+     output_offsets = output_ptr + indices[:, None] * stride0 + cols[None, :] * stride1
+     tl.store(output_offsets, output, mask=mask)
+
+
+ def index_select_cat_fwd(
+     output: torch.Tensor,
+     source: torch.Tensor,
+     index: torch.Tensor,
+ ):
+     if not (source.is_cuda and index.is_cuda):
+         raise ValueError("The index tensor and the source tensor must be of type CUDA!")
+
+     if not source.ndim == 2:
+         raise ValueError(f"Expected 2-dimensional tensor, got {source.ndim}.")
+     if not index.ndim == 1:
+         raise ValueError(f"Expected 1-dimensional tensor, got {index.ndim}.")
+
+     num_rows, num_cols = source.shape
+     num_indices = index.shape[0]
+
+     if not num_indices < num_rows:
+         raise ValueError(
+             "The number of indices must be smaller than the number of rows in the source matrix."
+         )
+
+     stride0, stride1 = source.stride(0), source.stride(1)
+
+     def grid(meta):
+         return (
+             triton.cdiv(num_indices, meta["BLOCK_SIZE_INDEX"]),
+             triton.cdiv(num_cols, meta["BLOCK_SIZE_COL"]),
+         )
+
+     index_select_cat_fwd_kernel[grid](
+         output,
+         source,
+         index,
+         num_indices,
+         num_cols,
+         stride0,
+         stride1,
+         BLOCK_SIZE_INDEX=1,
+         BLOCK_SIZE_COL=512,
+     )
+
+     return output
+
+
+ @triton.jit
+ def index_select_cat_bwd_kernel(
+     grad_source_ptr,  # *Pointer* to grad_source tensor.
+     index_ptr,  # *Pointer* to index tensor.
+     grad_output_ptr,  # *Pointer* to grad_output tensor.
+     num_rows,
+     num_indices,
+     num_cols,
+     stride0,  # Stride information of input and source tensor.
+     stride1,
+     BLOCK_SIZE_INDEX: tl.constexpr,  # Number of indices each program should process.
+     BLOCK_SIZE_COL: tl.constexpr,  # Number of cols each program should process.
+ ):
+     pid0 = tl.program_id(axis=0)  # We use a 2D launch grid
+     pid1 = tl.program_id(axis=1)
+
+     cols = pid1 * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)
+
+     # load grad_output
+     grad_output_indices = pid0 * BLOCK_SIZE_INDEX + tl.arange(0, BLOCK_SIZE_INDEX)
+     grad_output_offsets = (
+         grad_output_ptr
+         + grad_output_indices[:, None] * stride0
+         + cols[None, :] * stride1
+     )
+     grad_output_mask = (grad_output_indices[:, None] < num_indices) & (
+         cols[None, :] < num_cols
+     )
+     grad_output = tl.load(grad_output_offsets, mask=grad_output_mask).to(tl.float32)
+
+     # select indices from grad_source
+     grad_source_indices = tl.load(
+         index_ptr + grad_output_indices, mask=(grad_output_indices < num_indices)
+     )
+     grad_source_offsets = (
+         grad_source_ptr
+         + grad_source_indices[:, None] * stride0
+         + cols[None, :] * stride1
+     )
+
+     # scatter the grad_output rows into grad_source at the selected indices
+     tl.store(grad_source_offsets, grad_output, mask=grad_output_mask)
+
+
+ def index_select_cat_bwd(
+     grad_source: torch.Tensor,
+     index: torch.Tensor,
+     grad_output: torch.Tensor,
+ ):
+     if not (grad_source.is_cuda and grad_output.is_cuda):
+         raise ValueError("The grad_source and grad_output tensor must be of type CUDA!")
+
+     if not (grad_source.ndim == 2 and grad_output.ndim == 2):
+         raise ValueError(
+             f"The grad_source and grad_output must be two-dimensional "
+             f"(got {grad_source.ndim} and {grad_output.ndim})!"
+         )
+     if not grad_source.shape[1] == grad_output.shape[1]:
+         raise ValueError(
+             f"The number of elements along dimension 1 of grad_source and grad_output must be the same "
+             f"(got {grad_source.shape[1]} and {grad_output.shape[1]})"
+         )
+
+     num_rows, num_cols = grad_source.shape
+     num_indices, num_cols = grad_output.shape
+     if not num_rows >= num_indices:
+         raise ValueError(
+             f"The number of elements along dimension 0 of grad_source must be larger than that of grad_output "
+             f"(got {num_rows} and {num_indices})!"
+         )
+     if not index.shape[0] == num_indices:
+         raise ValueError(
+             f"The number of indices and the number of elements along dimension 0 of grad_output must match "
+             f"(got {index.shape[0]} and {num_indices})!"
+         )
+
+     stride0, stride1 = grad_source.stride(0), grad_source.stride(1)
+     if not (grad_output.stride(0) == stride0 and grad_output.stride(1) == stride1):
+         raise ValueError(
+             f"The strides of the grad_source and grad_output tensors must match "
+             f"(got {stride0} vs. {grad_output.stride(0)}, {stride1} vs. {grad_output.stride(1)})!"
+         )
+
+     def grid(meta):
+         return (
+             triton.cdiv(num_indices, meta["BLOCK_SIZE_INDEX"]),
+             triton.cdiv(num_cols, meta["BLOCK_SIZE_COL"]),
+         )
+
+     index_select_cat_bwd_kernel[grid](
+         grad_source,
+         index,
+         grad_output,
+         num_rows,
+         num_indices,
+         num_cols,
+         grad_source.stride(0),
+         grad_source.stride(1),
+         BLOCK_SIZE_INDEX=1,
+         BLOCK_SIZE_COL=512,
+     )
+
+     return
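Functionally, the forward kernel above gathers rows: `output[i] = source[index[i]]`. A hedged sanity check against eager PyTorch (requires a CUDA device and Triton; sizes are arbitrary, and the index count must stay below the row count as enforced above):

import torch

source = torch.randn(100, 512, device="cuda")
index = torch.arange(30, device="cuda")
output = torch.empty(30, 512, device="cuda")

index_select_cat_fwd(output, source, index)           # Triton path
torch.testing.assert_close(output, source.index_select(0, index))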
.venv/lib/python3.11/site-packages/xformers/ops/_triton/k_scaled_index_add.py ADDED
@@ -0,0 +1,365 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def scaled_index_add_fwd_kernel(
+     input_ptr,  # *Pointer* to input tensor.
+     index_ptr,  # *Pointer* to index tensor.
+     source_ptr,  # *Pointer* to source tensor.
+     scaling_ptr,  # *Pointer* to the scaling tensor.
+     alpha,
+     num_inp_indices,
+     num_src_indices,
+     num_rows,
+     num_cols,
+     stride0,  # Stride information of input and source tensor.
+     stride1,
+     stride2,
+     BLOCK_SIZE_INDEX: tl.constexpr,  # Number of indices each program should process.
+     BLOCK_SIZE_ROW: tl.constexpr,  # Number of rows each program should process.
+     BLOCK_SIZE_COL: tl.constexpr,  # Number of cols each program should process.
+     HAS_SCALING: tl.constexpr,  # Boolean indicating if the scaling factor is present.
+ ):
+     pid0 = tl.program_id(axis=0)  # We use a 3D launch grid
+     pid1 = tl.program_id(axis=1)
+     pid2 = tl.program_id(axis=2)
+
+     rows = pid1 * BLOCK_SIZE_ROW + tl.arange(0, BLOCK_SIZE_ROW)
+     cols = pid2 * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)
+
+     # load source
+     source_indices = pid0 * BLOCK_SIZE_INDEX + tl.arange(0, BLOCK_SIZE_INDEX)
+     source_offsets = (
+         source_ptr
+         + source_indices[:, None, None] * stride0
+         + rows[None, :, None] * stride1
+         + cols[None, None, :] * stride2
+     )
+     source_mask = (
+         (source_indices[:, None, None] < num_src_indices)
+         & (rows[None, :, None] < num_rows)
+         & (cols[None, None, :] < num_cols)
+     )
+     source = tl.load(source_offsets, mask=source_mask).to(tl.float32)
+
+     # load input
+     input_indices = tl.load(
+         index_ptr + source_indices, mask=(source_indices < num_src_indices)
+     )
+     input_offsets = (
+         input_ptr
+         + input_indices[:, None, None] * stride0
+         + rows[None, :, None] * stride1
+         + cols[None, None, :] * stride2
+     )
+     x = tl.load(input_offsets, mask=source_mask).to(tl.float32)
+
+     # compute scaled index add and save
+     if HAS_SCALING:
+         scaling = tl.load(
+             scaling_ptr + cols[None, None, :] * stride2,
+             mask=(cols[None, None, :] < num_cols),
+         ).to(tl.float32)
+         tl.store(input_offsets, x + alpha * scaling * source, mask=source_mask)
+     else:
+         tl.store(input_offsets, x + alpha * source, mask=source_mask)
+
+
+ def scaled_index_add_fwd(
+     x: torch.Tensor,
+     index: torch.Tensor,
+     source: torch.Tensor,
+     scaling: Optional[torch.Tensor],
+     alpha: float,
+ ):
+     if not (x.is_cuda and index.is_cuda and source.is_cuda):
+         raise ValueError(
+             "The input tensor, the index tensor and the source tensor must be of type CUDA!"
+         )
+
+     if not (x.ndim == 3 and source.ndim == 3):
+         raise ValueError(
+             f"The input and source must be three-dimensional (got {x.ndim} and {source.ndim})!"
+         )
+     if not x.shape[1] == source.shape[1]:
+         raise ValueError(
+             f"The number of elements along dimension 1 of the input and source must be the same "
+             f"(got {x.shape[1]} and {source.shape[1]})!"
+         )
+     if not x.shape[2] == source.shape[2]:
+         raise ValueError(
+             f"The number of elements along dimension 2 of the input and source must be the same "
+             f"(got {x.shape[2]} and {source.shape[2]})!"
+         )
+
+     num_inp_indices, num_rows, num_cols = x.shape
+     num_src_indices, num_rows, num_cols = source.shape
+     if not num_inp_indices >= num_src_indices:
+         raise ValueError(
+             f"The number of elements along dimension 0 of the input must be larger than that of source "
+             f"(got {num_inp_indices} and {num_src_indices})!"
+         )
+     if not index.shape[0] == num_src_indices:
+         raise ValueError(
+             f"The number of indices and source tensors must match (got {len(index)} and {len(source)})!"
+         )
+
+     stride0, stride1, stride2 = x.stride(0), x.stride(1), x.stride(2)
+     if not (
+         source.stride(0) == stride0
+         and source.stride(1) == stride1
+         and source.stride(2) == stride2
+     ):
+         raise ValueError(
+             f"The strides of the source and input tensors must match (got {source.stride(0)} vs. {stride0}, "
+             f"{source.stride(1)} vs. {stride1}, {source.stride(2)} vs. {stride2})!"
+         )
+
+     if scaling is None:
+         HAS_SCALING = False
+     else:
+         HAS_SCALING = True
+         if not scaling.is_cuda:
+             raise ValueError("The scaling tensor must be of type CUDA!")
+         if not (scaling.ndim == 1 and scaling.numel() == num_cols):
+             raise ValueError(
+                 f"The scaling tensor must be a 1-dimensional tensor (got {scaling.ndim}) and its size "
+                 f"must be equal to the size of dimension 2 of source (got {scaling.numel()} vs. {num_cols})."
+             )
+         if not scaling.stride(0) == stride2:
+             raise ValueError(
+                 f"The stride of scaling must match the stride2 of input (got {scaling.stride(0)} vs. {stride2})"
+             )
+
+     if not index.ndim == 1:
+         raise ValueError(f"The index must be one-dimensional (got {index.ndim})!")
+
+     def grid(meta):
+         return (
+             triton.cdiv(num_src_indices, meta["BLOCK_SIZE_INDEX"]),
+             triton.cdiv(num_rows, meta["BLOCK_SIZE_ROW"]),
+             triton.cdiv(num_cols, meta["BLOCK_SIZE_COL"]),
+         )
+
+     scaled_index_add_fwd_kernel[grid](
+         x,
+         index,
+         source,
+         scaling,
+         alpha,
+         num_inp_indices,
+         num_src_indices,
+         num_rows,
+         num_cols,
+         x.stride(0),
+         x.stride(1),
+         x.stride(2),
+         BLOCK_SIZE_INDEX=1,
+         BLOCK_SIZE_ROW=1,
+         BLOCK_SIZE_COL=512,
+         HAS_SCALING=HAS_SCALING,
+     )
+
+     return
+
+
+ @triton.jit
+ def scaled_index_add_bwd_kernel(
+     grad_output_ptr,  # *Pointer* to grad_output tensor.
+     grad_source_ptr,  # *Pointer* to grad_source tensor.
+     grad_scaling_ptr,  # *Pointer* to grad_scaling tensor.
+     source_ptr,  # *Pointer* to the source tensor.
+     scaling_ptr,  # *Pointer* to the scaling tensor.
+     index_ptr,
+     alpha,
+     num_inp_indices,
+     num_src_indices,
+     num_rows,
+     num_cols,
+     stride0,  # Stride information of input and source tensor.
+     stride1,
+     stride2,
+     BLOCK_SIZE_INDEX: tl.constexpr,  # Number of indices each program should process.
+     BLOCK_SIZE_ROW: tl.constexpr,  # Number of rows each program should process.
+     BLOCK_SIZE_COL: tl.constexpr,  # Number of cols each program should process.
+     HAS_SCALING: tl.constexpr,  # Boolean indicating if the scaling factor is present.
+ ):
+     pid0 = tl.program_id(axis=0)  # We use a 3D launch grid
+     pid1 = tl.program_id(axis=1)
+     pid2 = tl.program_id(axis=2)
+
+     rows = pid1 * BLOCK_SIZE_ROW + tl.arange(0, BLOCK_SIZE_ROW)
+     cols = pid2 * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)
+
+     # load source
+     source_indices = pid0 * BLOCK_SIZE_INDEX + tl.arange(0, BLOCK_SIZE_INDEX)
+     source_offsets = (
+         source_ptr
+         + source_indices[:, None, None] * stride0
+         + rows[None, :, None] * stride1
+         + cols[None, None, :] * stride2
+     )
+     source_mask = (
+         (source_indices[:, None, None] < num_src_indices)
+         & (rows[None, :, None] < num_rows)
+         & (cols[None, None, :] < num_cols)
+     )
+     source = tl.load(source_offsets, mask=source_mask).to(tl.float32)
+
+     # load grad_output
+     grad_output_indices = tl.load(
+         index_ptr + source_indices, mask=(source_indices < num_src_indices)
+     )
+     grad_output_offsets = (
+         grad_output_ptr
+         + grad_output_indices[:, None, None] * stride0
+         + rows[None, :, None] * stride1
+         + cols[None, None, :] * stride2
+     )
+     grad_output = tl.load(grad_output_offsets, mask=source_mask).to(tl.float32)
+
+     # compute gradient
+     grad_source_offsets = (
+         grad_source_ptr
+         + source_indices[:, None, None] * stride0
+         + rows[None, :, None] * stride1
+         + cols[None, None, :] * stride2
+     )
+     if HAS_SCALING:
+         scaling = tl.load(
+             scaling_ptr + cols[None, None, :] * stride2,
+             mask=(cols[None, None, :] < num_cols),
+         ).to(tl.float32)
+
+         tl.store(grad_source_offsets, alpha * grad_output * scaling, mask=source_mask)
+
+         grad_scaling_offsets = (
+             grad_scaling_ptr
+             + source_indices[:, None, None] * stride0
+             + rows[None, :, None] * stride1
+             + cols[None, None, :] * stride2
+         )
+         tl.store(grad_scaling_offsets, alpha * grad_output * source, mask=source_mask)
+     else:
+         tl.store(grad_source_offsets, alpha * grad_output, mask=source_mask)
+
+
+ def scaled_index_add_bwd(
+     grad_output: torch.Tensor,
+     grad_source: torch.Tensor,
+     grad_scaling: Optional[torch.Tensor],
+     source: torch.Tensor,
+     scaling: Optional[torch.Tensor],
+     index: torch.Tensor,
+     alpha: float,
+ ):
+     if not (grad_output.is_cuda and grad_source.is_cuda):
+         raise ValueError(
+             "The grad_output tensor and grad_source tensor must be of type CUDA!"
+         )
+
+     if not (grad_output.ndim == 3 and source.ndim == 3):
+         raise ValueError(
+             f"The input and source must be three-dimensional (got {grad_output.ndim} and {source.ndim})!"
+         )
+
+     if not grad_output.shape[1] == source.shape[1]:
+         raise ValueError(
+             f"The number of elements along dimension 1 of the input and source must be the same "
+             f"(got {grad_output.shape[1]} and {source.shape[1]})!"
+         )
+     if not grad_output.shape[2] == source.shape[2]:
+         raise ValueError(
+             f"The number of elements along dimension 2 of the input and source must be the same "
+             f"(got {grad_output.shape[2]} and {source.shape[2]})!"
+         )
+
+     num_inp_indices, num_rows, num_cols = grad_output.shape
+     num_src_indices, num_rows, num_cols = source.shape
+     if not num_inp_indices >= num_src_indices:
+         raise ValueError(
+             f"The number of elements along dimension 0 of the input must be larger than that of source "
+             f"(got {num_inp_indices} and {num_src_indices})!"
+         )
+
+     stride0, stride1, stride2 = source.stride(0), source.stride(1), source.stride(2)
+     if not (
+         grad_output.stride(0) == stride0
+         and grad_output.stride(1) == stride1
+         and grad_output.stride(2) == stride2
+     ):
+         raise ValueError(
+             f"The strides of grad_output and source must match "
+             f"(got {grad_output.stride(0)} vs {stride0}, {grad_output.stride(1)} vs {stride1}, "
+             f"{grad_output.stride(2)} vs {stride2})!"
+         )
+     if not (
+         grad_source.stride(0) == stride0
+         and grad_source.stride(1) == stride1
+         and grad_source.stride(2) == stride2
+     ):
+         raise ValueError(
+             f"The strides of grad_source and source must match "
+             f"(got {grad_source.stride(0)} vs {stride0}, {grad_source.stride(1)} vs {stride1}, "
+             f"{grad_source.stride(2)} vs {stride2})!"
+         )
+
+     if scaling is not None and grad_scaling is not None:
+         HAS_SCALING = True
+         if not grad_scaling.is_cuda:
+             raise ValueError("The scaling tensor must be of type CUDA!")
+         if not (
+             grad_scaling.stride(0) == stride0
+             and grad_scaling.stride(1) == stride1
+             and grad_scaling.stride(2) == stride2
+         ):
+             raise ValueError(
+                 f"The strides of grad_scaling and source must match "
+                 f"(got {grad_scaling.stride(0)} vs {stride0}, {grad_scaling.stride(1)} vs {stride1}, "
+                 f"{grad_scaling.stride(2)} vs {stride2})!"
+             )
+         if not scaling.stride(0) == stride2:
+             raise ValueError(
+                 f"The stride of scaling must match stride2 of source (got {scaling.stride(0)} vs. {stride2})!"
+             )
+     else:
+         HAS_SCALING = False
+
+     def grid(meta):
+         return (
+             triton.cdiv(num_src_indices, meta["BLOCK_SIZE_INDEX"]),
+             triton.cdiv(num_rows, meta["BLOCK_SIZE_ROW"]),
+             triton.cdiv(num_cols, meta["BLOCK_SIZE_COL"]),
+         )
+
+     scaled_index_add_bwd_kernel[grid](
+         grad_output,
+         grad_source,
+         grad_scaling,
+         source,
+         scaling,
+         index,
+         alpha,
+         num_inp_indices,
+         num_src_indices,
+         num_rows,
+         num_cols,
+         stride0,
+         stride1,
+         stride2,
+         BLOCK_SIZE_INDEX=1,
+         BLOCK_SIZE_ROW=1,
+         BLOCK_SIZE_COL=512,
+         HAS_SCALING=HAS_SCALING,
+     )
+
+     return
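In eager terms the forward path performs `x[index] += alpha * scaling * source` in place; a hedged reference check under the kernel's contiguity/stride assumptions (CUDA plus Triton required, sizes arbitrary):

import torch

x = torch.randn(16, 8, 512, device="cuda")
source = torch.randn(4, 8, 512, device="cuda")
index = torch.tensor([0, 2, 5, 7], device="cuda")
scaling = torch.randn(512, device="cuda")

expected = x.clone()
expected[index] += 0.1 * scaling * source  # eager semantics

scaled_index_add_fwd(x, index, source, scaling, 0.1)  # in-place Triton path
torch.testing.assert_close(x, expected)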
.venv/lib/python3.11/site-packages/xformers/ops/_triton/rmsnorm_kernels.py ADDED
@@ -0,0 +1,163 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+ import torch
+ import triton
+ import triton.language as tl
+
+ try:
+     from triton.language.extra.cuda.libdevice import rsqrt
+ except ImportError:
+     try:
+         from triton.language.math import rsqrt
+     except ImportError:
+         from triton.language.libdevice import rsqrt
+
+
+ @triton.jit
+ def _rms_norm_kernel(
+     x_ptr,
+     h1_ptr,
+     w_ptr,
+     eps,
+     stride,
+     N_COLS: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+     INCLUDE_WEIGHT: tl.constexpr,
+ ):
+     row = tl.program_id(0).to(tl.int64)
+     x_ptr += row * stride
+     h1_ptr += row * stride
+
+     _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+     for offset in range(0, N_COLS, BLOCK_SIZE):
+         cols = offset + tl.arange(0, BLOCK_SIZE)
+         a = tl.load(
+             x_ptr + cols, mask=cols < N_COLS, other=0.0, eviction_policy="evict_last"
+         ).to(tl.float32)
+         _mean += a * a
+     rstd = rsqrt((tl.sum(_mean, axis=0) / N_COLS) + eps)
+     for offset in range(0, N_COLS, BLOCK_SIZE):
+         cols = offset + tl.arange(0, BLOCK_SIZE)
+         mask = cols < N_COLS
+         a = tl.load(
+             x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
+         ).to(tl.float32)
+         if INCLUDE_WEIGHT:
+             w = tl.load(w_ptr + cols, mask=mask)
+             tl.store(h1_ptr + cols, a * rstd * w, mask=mask)
+         else:
+             tl.store(h1_ptr + cols, a * rstd, mask=mask)
+
+
+ @triton.jit
+ def _rms_norm_add_kernel(
+     x_ptr,
+     y_ptr,
+     h1_ptr,
+     w_ptr,
+     eps,
+     stride,
+     N_COLS: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+     INCLUDE_WEIGHT: tl.constexpr,
+ ):
+     row = tl.program_id(0)
+     x_ptr += row * stride
+     y_ptr += row * stride
+     h1_ptr += row * stride
+
+     _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+     for offset in range(0, N_COLS, BLOCK_SIZE):
+         cols = offset + tl.arange(0, BLOCK_SIZE)
+         mask = cols < N_COLS
+         ax = tl.load(
+             x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_last"
+         ).to(tl.float32)
+         ay = tl.load(
+             y_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
+         ).to(tl.float32)
+         a = ax + ay
+         tl.store(x_ptr + cols, a, mask=mask)
+         _mean += a * a
+     rstd = rsqrt((tl.sum(_mean, axis=0) / N_COLS) + eps)
+     for offset in range(0, N_COLS, BLOCK_SIZE):
+         cols = offset + tl.arange(0, BLOCK_SIZE)
+         mask = cols < N_COLS
+         a = tl.load(
+             x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
+         ).to(tl.float32)
+         if INCLUDE_WEIGHT:
+             w = tl.load(w_ptr + cols, mask=mask)
+             tl.store(h1_ptr + cols, a * rstd * w, mask=mask)
+         else:
+             tl.store(h1_ptr + cols, a * rstd, mask=mask)
+
+
+ def _rms_norm_forward(x, attn_norm_weights, eps):
+     if not x.is_contiguous():
+         raise ValueError("data must be contiguous")
+     if attn_norm_weights is not None:
+         if not attn_norm_weights.is_contiguous():
+             raise ValueError("weights must be contiguous")
+     out = torch.empty_like(x)
+     x_arg = x.reshape(-1, x.shape[-1])
+     M, N = x_arg.shape
+     # Less than 64KB per feature: enqueue fused kernel
+     MAX_FUSED_SIZE = 65536 // x.element_size()
+     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+     BLOCK_SIZE = max(BLOCK_SIZE, 128)
+     BLOCK_SIZE = min(BLOCK_SIZE, 8192)
+     # heuristics for number of warps
+     num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+     with torch.cuda.device(x.device):
+         _rms_norm_kernel[(M,)](
+             x_arg,
+             out,
+             attn_norm_weights,
+             eps,
+             x_arg.stride(0),
+             N,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             INCLUDE_WEIGHT=attn_norm_weights is not None,
+         )
+     return out
+
+
+ def _rms_norm_add_forward(x, y, attn_norm_weights, eps):
+     # x, y contiguous of same shape [..., n]
+     # output of same shape, normed over the last dim.
+     if not x.is_contiguous():
+         raise ValueError("x must be contiguous")
+     if not y.is_contiguous():
+         raise ValueError("y must be contiguous")
+     if attn_norm_weights is not None:
+         if not attn_norm_weights.is_contiguous():
+             raise ValueError("weights must be contiguous")
+     out = torch.empty_like(x)
+     x_arg = x.reshape(-1, x.shape[-1])
+     y_arg = y.reshape(-1, x.shape[-1])
+     M, N = x_arg.shape
+     # Less than 64KB per feature: enqueue fused kernel
+     MAX_FUSED_SIZE = 65536 // x.element_size()
+     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+     BLOCK_SIZE = max(BLOCK_SIZE, 128)
+     BLOCK_SIZE = min(BLOCK_SIZE, 8192)
+     # heuristics for number of warps
+     num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+     with torch.cuda.device(x.device):
+         _rms_norm_add_kernel[(M,)](
+             x_arg,
+             y_arg,
+             out,
+             attn_norm_weights,
+             eps,
+             x_arg.stride(0),
+             N,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             INCLUDE_WEIGHT=attn_norm_weights is not None,
+         )
+     return out
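The kernels normalize each row by the reciprocal root of its mean square plus `eps`, optionally scaled by a weight. For reference, the same computation in eager PyTorch (a sketch; the tolerances account for the kernel's float32 accumulation):

import torch

def rms_norm_reference(x, weight, eps):
    # mean of squares over the last dim, matching the kernel's accumulation
    rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    out = x.float() * rstd
    if weight is not None:
        out = out * weight.float()
    return out.to(x.dtype)

x = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
w = torch.randn(1024, device="cuda", dtype=torch.float16)
torch.testing.assert_close(
    _rms_norm_forward(x, w, 1e-6),
    rms_norm_reference(x, w, 1e-6),
    rtol=1e-2,
    atol=1e-2,
)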
.venv/lib/python3.11/site-packages/xformers/ops/_triton/rope_padded_kernels.py ADDED
@@ -0,0 +1,226 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+ import triton  # type: ignore
+ import triton.language as tl  # type: ignore
+
+ try:
+     from triton.language.extra.cuda.libdevice import pow
+ except ImportError:
+     try:
+         from triton.language.math import pow
+     except ImportError:
+         from triton.language.libdevice import pow
+
+
+ @triton.jit
+ def _rope_padded_kernel(
+     xq,
+     xk,
+     xv,
+     out_q,
+     cache_k,
+     cache_v,
+     seqstartq,
+     seqstartk,
+     seqlenk,
+     theta,
+     linear_scale,
+     use_dynamic_scaling: tl.constexpr,
+     dynamic_old_context_len: tl.constexpr,
+     dynamic_scale_factor: tl.constexpr,
+     dynamic_low_freq_factor: tl.constexpr,
+     dynamic_high_freq_factor: tl.constexpr,
+     first_seqpos,
+     seqpos,
+     k_start: tl.constexpr,
+     v_start: tl.constexpr,
+     n_groups,
+     dim: tl.constexpr,  # dimension of each head
+     stride_xqM,
+     stride_xqG,
+     stride_xqH,
+     stride_xkM,
+     stride_xkG,
+     stride_xkH,
+     stride_xvM,
+     stride_xvG,
+     stride_xvH,
+     stride_cachekM,
+     stride_cachekG,
+     stride_cachekH,
+     stride_cachevM,
+     stride_cachevG,
+     stride_cachevH,
+     stride_seqstartq,
+     stride_seqstartk,
+     stride_seqlenk,
+     stride_outqM,
+     stride_outqG,
+     stride_outqH,
+     stride_seqpos,
+     internal_dtype: tl.constexpr,
+     # If True, seqstartq and seqstartk are not used but rather we
+     # assume that every batch element has the same number of
+     # queries (i.e. num_queries := tl.num_programs(1) )
+     # and the same cache space cache_padding_length.
+     # Always False when called below.
+     const_batch_strides: tl.constexpr,
+     # If const_batch_strides==True, the common cache length for each batch element.
+     # (Only the first seqlenk[i] elements are actually in use, and only the last
+     # num_queries of those are actually written to.)
+     cache_padding_length,
+     # offset added to all values in seqlenk before using them.
+     # Always 0 when called below.
+     seqlenk_shift: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+     adjacents: tl.constexpr,
+ ):
+     """
+     Each letter in this diagram is a whole row of length dim.
+
+     INPUT      xq          xk   xv
+
+             head_dim ─►
+
+     batch     qqqqqq       kk   vv
+       │       qqqqqq       kk   vv
+       ▼       qqqqqq       kk   vv
+
+     head_idx (goes across all heads of all 3 inputs):
+       ▲           ▲            ▲            ▲
+       │           │            │            │
+       0        k_start      v_start    n_total_heads
+
+     Output is to out_q (same shape as xq), an xk-shaped part
+     of cache_k and an xv-shaped part of cache_v
+     """
+     query_pos_in_batch_elt = tl.program_id(0)
+     batch_elt = tl.program_id(1)
+     group_head_idx = tl.program_id(2)
+     group_idx = group_head_idx % n_groups
+     head_idx = group_head_idx // n_groups
+
+     if internal_dtype == "f32":
+         theta = theta.to(tl.float32)
+     elif internal_dtype == "f64":
+         theta = theta.to(tl.float64)
+
+     if const_batch_strides:
+         query_pos = query_pos_in_batch_elt + tl.num_programs(1) * batch_elt
+         end_query_pos = tl.num_programs(1) * (batch_elt + 1)
+     else:
+         query_pos = query_pos_in_batch_elt + tl.load(
+             seqstartq + batch_elt * stride_seqstartq
+         )
+         end_query_pos = tl.load(seqstartq + (batch_elt + 1) * stride_seqstartq)
+         if query_pos >= end_query_pos:
+             return
+
+     is_q = head_idx < k_start
+     is_v = head_idx >= v_start
+
+     xq += query_pos * stride_xqM + head_idx * stride_xqH + group_idx * stride_xqG
+     out_q += (
+         query_pos * stride_outqM + head_idx * stride_outqH + group_idx * stride_outqG
+     )
+
+     if const_batch_strides:
+         cache_start = cache_padding_length * batch_elt
+     else:
+         cache_start = tl.load(seqstartk + batch_elt * stride_seqstartk)
+     end_of_batch_elt_cache = (
+         cache_start + tl.load(seqlenk + batch_elt * stride_seqlenk) + seqlenk_shift
+     )
+
+     cache_pos = end_of_batch_elt_cache - (end_query_pos - query_pos)
+     if seqpos is not None:
+         seq_pos = tl.load(seqpos + query_pos * stride_seqpos)
+     else:
+         seq_pos = cache_pos - cache_start
+         if first_seqpos is not None:
+             seq_pos += tl.load(first_seqpos + batch_elt * stride_seqpos)
+     cache_k += (
+         (head_idx - k_start) * stride_cachekH
+         + cache_pos * stride_cachekM
+         + group_idx * stride_cachekG
+     )
+     xk += (
+         query_pos * stride_xkM
+         + (head_idx - k_start) * stride_xkH
+         + group_idx * stride_xkG
+     )
+     in_qk = tl.where(is_q, xq, xk)
+     out_qk = tl.where(is_q, out_q, cache_k)
+
+     cache_v += (
+         (head_idx - v_start) * stride_cachevH
+         + cache_pos * stride_cachevM
+         + group_idx * stride_cachevG
+     )
+     xv += (
+         query_pos * stride_xvM
+         + (head_idx - v_start) * stride_xvH
+         + group_idx * stride_xvG
+     )
+
+     out = tl.where(is_v, cache_v, out_qk)
+     x_in = tl.where(is_v, xv, in_qk)
+
+     for offset in range(0, dim // 2, BLOCK_SIZE // 2):
+         c = tl.arange(0, BLOCK_SIZE // 2)
+         powers = (offset + c) * 2.0
+         if adjacents:
+             cols_re = (offset + c) * 2
+             cols_im = cols_re + 1
+         else:
+             cols_re = offset + c
+             cols_im = cols_re + dim // 2
+
+         mask = cols_im < dim
+
+         re_x = tl.load(x_in + cols_re, mask=mask)
+         im_x = tl.load(x_in + cols_im, mask=mask)
+         # freqs = seq_pos / (theta ** (powers / dim))
+         freqs = pow(theta, powers / (-dim))
+
+         if use_dynamic_scaling:
+             lo_freq_wavelen = dynamic_old_context_len / dynamic_low_freq_factor
+             hi_freq_wavelen = dynamic_old_context_len / dynamic_high_freq_factor
+
+             wavelens = 6.28318530718 / freqs  # 2*pi
+             is_low_freq = wavelens > lo_freq_wavelen
+             freqs = tl.where(is_low_freq, freqs / dynamic_scale_factor, freqs)
+
+             is_mid_freq = hi_freq_wavelen <= wavelens and wavelens <= lo_freq_wavelen
+
+             smooth = (dynamic_old_context_len / wavelens - dynamic_low_freq_factor) / (
+                 dynamic_high_freq_factor - dynamic_low_freq_factor
+             )
+             freqs = tl.where(
+                 is_mid_freq,
+                 (1 - smooth) * freqs / dynamic_scale_factor + smooth * freqs,
+                 freqs,
+             )
+
+         freqs = seq_pos * freqs / linear_scale
+         sines = tl.sin(freqs)
+         cosines = tl.cos(freqs)
+         re_out = re_x * cosines - im_x * sines
+         im_out = im_x * cosines + re_x * sines
+
+         re_out_ = tl.where(is_v, re_x, re_out)
+         im_out_ = tl.where(is_v, im_x, im_out)
+         if internal_dtype == "f64":
+             if re_x.dtype == tl.bfloat16:
+                 # triton 2.0.0 crashes if you try to convert
+                 # float64 directly to bfloat16, so make an intermediate step.
+                 re_out_ = re_out_.to(tl.float32)
+                 im_out_ = im_out_.to(tl.float32)
+         tl.store(out + cols_re, re_out_, mask=mask)
+         tl.store(out + cols_im, im_out_, mask=mask)
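Stripping away the cache bookkeeping, the rotation itself is standard RoPE: pair i is rotated by the angle seq_pos * theta ** (-2i / dim). A hedged eager sketch of the adjacent-pairs case (all names below are illustrative, not part of this file):

import torch

def rope_reference(x, seq_pos, theta=10000.0):
    # x: (seq, dim) with dim even; pairs are (x[:, 2i], x[:, 2i + 1])
    dim = x.shape[-1]
    freqs = theta ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    angles = seq_pos[:, None].float() * freqs[None, :]  # (seq, dim // 2)
    cos, sin = angles.cos(), angles.sin()
    re, im = x[:, 0::2], x[:, 1::2]
    out = torch.empty_like(x)
    out[:, 0::2] = re * cos - im * sin
    out[:, 1::2] = im * cos + re * sin
    return out

y = rope_reference(torch.randn(5, 64), torch.arange(5))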
.venv/lib/python3.11/site-packages/xformers/ops/_triton/tiled_matmul_kernels.py ADDED
@@ -0,0 +1,430 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ import itertools
+ from typing import List, Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+ from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+
+
+ def init_to_zero(*names):
+     def result(nargs):
+         for name in names:
+             nargs[name].zero_()
+
+     return result
+
+
+ def gen_config(
+     block_m: int,
+     block_n: int,
+     block_k: int,
+     stages: int,
+     warps: int,
+     split_k: int = 1,
+     group_m: int = 8,
+ ) -> triton.Config:
+     """A more compact way to define a triton.Config, so it fits on one line"""
+
+     return triton.Config(
+         {
+             "BLOCK_M": block_m,
+             "BLOCK_N": block_n,
+             "BLOCK_K": block_k,
+             "SPLIT_K": split_k,
+             "GROUP_M": group_m,
+         },
+         num_stages=stages,
+         num_warps=warps,
+         pre_hook=init_to_zero(*[f"C{i+1}{j+1}" for i in range(3) for j in range(3)])
+         if split_k > 1
+         else init_to_zero(),
+     )
+
+
+ BASIC_MATMUL_CONFIGS = [
+     gen_config(block_m=128, block_n=256, block_k=32, stages=3, warps=8),
+     gen_config(block_m=256, block_n=128, block_k=32, stages=3, warps=8),
+     gen_config(block_m=256, block_n=64, block_k=32, stages=4, warps=4),
+     gen_config(block_m=64, block_n=256, block_k=32, stages=4, warps=4),
+     gen_config(block_m=128, block_n=128, block_k=32, stages=4, warps=4),
+     gen_config(block_m=128, block_n=64, block_k=32, stages=4, warps=4),
+     gen_config(block_m=64, block_n=128, block_k=32, stages=4, warps=4),
+     gen_config(block_m=128, block_n=32, block_k=32, stages=4, warps=4),
+     gen_config(block_m=64, block_n=32, block_k=32, stages=5, warps=2),
+ ]
+
+
+ INT8_MATMUL_CONFIGS = [
+     gen_config(block_m=128, block_n=256, block_k=128, stages=3, warps=8),
+     gen_config(block_m=256, block_n=128, block_k=128, stages=3, warps=8),
+     gen_config(block_m=256, block_n=64, block_k=128, stages=4, warps=4),
+     gen_config(block_m=64, block_n=256, block_k=128, stages=4, warps=4),
+     gen_config(block_m=128, block_n=128, block_k=128, stages=4, warps=4),
+     gen_config(block_m=128, block_n=64, block_k=64, stages=4, warps=4),
+     gen_config(block_m=64, block_n=128, block_k=64, stages=4, warps=4),
+     gen_config(block_m=128, block_n=32, block_k=64, stages=4, warps=4),
+     gen_config(block_m=64, block_n=32, block_k=64, stages=5, warps=2),
+ ]
+
+
+ IO_BOUND_MATMUL_CONFIGS_STAGES = [2, 3, 4, 5, 6]
+ IO_BOUND_MATMUL_CONFIGS_BLOCK_M = [16, 32]
+ IO_BOUND_MATMUL_CONFIGS_BLOCK_K = [32, 64]
+ IO_BOUND_MATMUL_CONFIGS_BLOCK_N = [32, 64, 128, 256]
+ IO_BOUND_MATMUL_CONFIGS_SPLIT_K = [1, 2, 4, 8, 16]
+
+
+ IO_BOUND_MATMUL_CONFIGS = [
+     gen_config(
+         block_m=block_m,
+         block_n=block_n,
+         block_k=block_k,
+         stages=stages,
+         warps=2 if block_n <= 64 else 4,
+         split_k=split_k,
+     )
+     for stages, block_m, block_k, block_n, split_k in itertools.product(
+         IO_BOUND_MATMUL_CONFIGS_STAGES,
+         IO_BOUND_MATMUL_CONFIGS_BLOCK_M,
+         IO_BOUND_MATMUL_CONFIGS_BLOCK_K,
+         IO_BOUND_MATMUL_CONFIGS_BLOCK_N,
+         IO_BOUND_MATMUL_CONFIGS_SPLIT_K,
+     )
+ ]
+
+
+ TRITON_CONFIGS = BASIC_MATMUL_CONFIGS + INT8_MATMUL_CONFIGS + IO_BOUND_MATMUL_CONFIGS
+
+
+ def our_estimate_matmul_time(
+     A11, B11, C11, M1, M2, M3, N1, N2, N3, K1, K2, K3, **kwargs
+ ):
+     """Call into Triton's upstream cost model, with the right args
+
+     The upstream function expects arguments to have certain names. Since we
+     renamed a few of them in our implementation, we rename them back.
+
+     At the time of writing (July 2023) the arguments that Triton expects are:
+     M, N, K, A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages.
+     """
+     return estimate_matmul_time(
+         M=M1 + M2 + M3, N=N1 + N2 + N3, K=K1 + K2 + K3, A=A11, B=B11, C=C11, **kwargs
+     )
+
+
+ def our_early_config_prune(config, named_args, **kwargs):
+     new_named_args = named_args.copy()
+     new_named_args["M"] = named_args["M1"] + named_args["M2"] + named_args["M3"]
+     new_named_args["N"] = named_args["N1"] + named_args["N2"] + named_args["N3"]
+     new_named_args["K"] = named_args["K1"] + named_args["K2"] + named_args["K3"]
+     new_named_args["A"] = named_args["A11"]
+     new_named_args["B"] = named_args["B11"]
+     new_named_args["C"] = named_args["C11"]
+     return early_config_prune(config, new_named_args, **kwargs)
+
+
+ @triton.autotune(
+     configs=TRITON_CONFIGS,
+     key=["M1", "M2", "M3", "N1", "N2", "N3", "K1", "K2", "K3"],
+     prune_configs_by={
+         "early_config_prune": our_early_config_prune,
+         "perf_model": our_estimate_matmul_time,
+         "top_k": 10,
+     },
+ )
+ @triton.heuristics(
+     {
+         "EVEN_K": lambda args: all(
+             k % (args["BLOCK_K"] * args["SPLIT_K"]) == 0
+             for k in [args["K1"], args["K2"], args["K3"]]
+         ),
+     }
+ )
+ @triton.jit()
+ def _xformers_tiled_matmul_kernel(
+     A11,
+     A12,
+     A13,
+     A21,
+     A22,
+     A23,
+     A31,
+     A32,
+     A33,
+     B11,
+     B12,
+     B13,
+     B21,
+     B22,
+     B23,
+     B31,
+     B32,
+     B33,
+     C11,
+     C12,
+     C13,
+     C21,
+     C22,
+     C23,
+     C31,
+     C32,
+     C33,
+     M1,
+     M2,
+     M3,
+     N1,
+     N2,
+     N3,
+     K1,
+     K2,
+     K3,
+     stride_am1,
+     stride_am2,
+     stride_am3,
+     stride_ak1,
+     stride_ak2,
+     stride_ak3,
+     stride_bk1,
+     stride_bk2,
+     stride_bk3,
+     stride_bn1,
+     stride_bn2,
+     stride_bn3,
+     stride_cm1,
+     stride_cm2,
+     stride_cm3,
+     stride_cn1,
+     stride_cn2,
+     stride_cn3,
+     BLOCK_M: tl.constexpr,  # DO NOT CHANGE NAME: MUST MATCH PERF MODEL
+     BLOCK_N: tl.constexpr,  # DO NOT CHANGE NAME: MUST MATCH PERF MODEL
+     BLOCK_K: tl.constexpr,  # DO NOT CHANGE NAME: MUST MATCH PERF MODEL
+     GROUP_M: tl.constexpr,
+     SPLIT_K: tl.constexpr,  # DO NOT CHANGE NAME: MUST MATCH PERF MODEL
+     EVEN_K: tl.constexpr,
+     ACC_TYPE: tl.constexpr,
+ ):
+     # matrix multiplication
+     pid = tl.program_id(0)
+     pid_k = tl.program_id(1)
+     grid_m1 = tl.cdiv(M1, BLOCK_M)
+     grid_m2 = tl.cdiv(M2, BLOCK_M)
+     grid_m3 = tl.cdiv(M3, BLOCK_M)
+     grid_n1 = tl.cdiv(N1, BLOCK_N)
+     grid_n2 = tl.cdiv(N2, BLOCK_N)
+     grid_n3 = tl.cdiv(N3, BLOCK_N)
+     grid_m = grid_m1 + grid_m2 + grid_m3
+     grid_n = grid_n1 + grid_n2 + grid_n3
+
+     # re-order program ID for better L2 performance
+     width = GROUP_M * grid_n
+     group_id = pid // width
+     group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+     pid_m = group_id * GROUP_M + (pid % group_size)
+     pid_n = (pid % width) // (group_size)
+
+     # We use tl.where to circumvent a regression in alignment auto-detection:
+     # https://github.com/openai/triton/issues/1784
+
+     A1 = tl.where(pid_m < grid_m1, A11, tl.where(pid_m < grid_m1 + grid_m2, A21, A31))
+     A2 = tl.where(pid_m < grid_m1, A12, tl.where(pid_m < grid_m1 + grid_m2, A22, A32))
+     A3 = tl.where(pid_m < grid_m1, A13, tl.where(pid_m < grid_m1 + grid_m2, A23, A33))
+     B1 = tl.where(pid_n < grid_n1, B11, tl.where(pid_n < grid_n1 + grid_n2, B12, B13))
+     B2 = tl.where(pid_n < grid_n1, B21, tl.where(pid_n < grid_n1 + grid_n2, B22, B23))
+     B3 = tl.where(pid_n < grid_n1, B31, tl.where(pid_n < grid_n1 + grid_n2, B32, B33))
+     C = tl.where(
+         pid_m < grid_m1,
+         tl.where(pid_n < grid_n1, C11, tl.where(pid_n < grid_n1 + grid_n2, C12, C13)),
+         tl.where(
+             pid_m < grid_m1 + grid_m2,
+             tl.where(
+                 pid_n < grid_n1, C21, tl.where(pid_n < grid_n1 + grid_n2, C22, C23)
+             ),
+             tl.where(
+                 pid_n < grid_n1, C31, tl.where(pid_n < grid_n1 + grid_n2, C32, C33)
+             ),
+         ),
+     )
+     M = tl.where(pid_m < grid_m1, M1, tl.where(pid_m < grid_m1 + grid_m2, M2, M3))
+     N = tl.where(pid_n < grid_n1, N1, tl.where(pid_n < grid_n1 + grid_n2, N2, N3))
+     stride_ak = tl.where(
+         pid_m < grid_m1,
+         stride_ak1,
+         tl.where(pid_m < grid_m1 + grid_m2, stride_ak2, stride_ak3),
+     )
+     stride_bk = tl.where(
+         pid_n < grid_n1,
+         stride_bk1,
+         tl.where(pid_n < grid_n1 + grid_n2, stride_bk2, stride_bk3),
+     )
+     stride_cn = tl.where(
+         pid_m < grid_m1,
+         stride_cn1,
+         tl.where(pid_m < grid_m1 + grid_m2, stride_cn2, stride_cn3),
+     )
+     stride_cm = tl.where(
+         pid_n < grid_n1,
+         stride_cm1,
+         tl.where(pid_n < grid_n1 + grid_n2, stride_cm2, stride_cm3),
+     )
+     pid_m = tl.where(
+         pid_m < grid_m1,
+         pid_m,
+         tl.where(pid_m < grid_m1 + grid_m2, pid_m - grid_m1, pid_m - grid_m1 - grid_m2),
+     )
+     pid_n = tl.where(
+         pid_n < grid_n1,
+         pid_n,
+         tl.where(pid_n < grid_n1 + grid_n2, pid_n - grid_n1, pid_n - grid_n1 - grid_n2),
+     )
+
+     # do matrix multiplication
+     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+     rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+     ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+     rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+     # pointers
+     acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+     grid_k1 = tl.cdiv(K1, BLOCK_K)
+     grid_k2 = tl.cdiv(K2, BLOCK_K)
+     grid_k3 = tl.cdiv(K3, BLOCK_K)
+     for tile in range(pid_k, grid_k1 + grid_k2 + grid_k3, SPLIT_K):
+         A = tl.where(tile < grid_k1, A1, tl.where(tile < grid_k1 + grid_k2, A2, A3))
+         B = tl.where(tile < grid_k1, B1, tl.where(tile < grid_k1 + grid_k2, B2, B3))
+         K = tl.where(tile < grid_k1, K1, tl.where(tile < grid_k1 + grid_k2, K2, K3))
+         stride_am = tl.where(
+             tile < grid_k1,
+             stride_am1,
+             tl.where(tile < grid_k1 + grid_k2, stride_am2, stride_am3),
+         )
+         stride_bn = tl.where(
+             tile < grid_k1,
+             stride_bn1,
+             tl.where(tile < grid_k1 + grid_k2, stride_bn2, stride_bn3),
+         )
+         my_tile = tl.where(
+             tile < grid_k1,
+             tile,
+             tl.where(
+                 tile < grid_k1 + grid_k2, tile - grid_k1, tile - grid_k1 - grid_k2
+             ),
+         )
+         rk = my_tile * BLOCK_K + tl.arange(0, BLOCK_K)
+         Ain = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+         Bin = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+         if EVEN_K:
+             a = tl.load(Ain)
+             b = tl.load(Bin)
+         else:
+             a = tl.load(Ain, mask=rk[None, :] < K, other=0.0)
+             b = tl.load(Bin, mask=rk[:, None] < K, other=0.0)
+         acc += tl.dot(a, b, allow_tf32=False)
+     acc = acc.to(C.dtype.element_ty)
+     # rematerialize rm and rn to save registers
+     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+     rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+     C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+     mask = (rm < M)[:, None] & (rn < N)[None, :]
+     # handles write-back with reduction-splitting
+     if SPLIT_K == 1:
+         tl.store(C, acc, mask=mask)
+     else:
+         tl.atomic_add(C, acc, mask=mask)
+
+
+ def _check_row_or_column(row_or_col_type, row_or_col_idx, tensor_name, dim_name, vals):
+     assert len(vals) > 0
+     for pos, val in enumerate(vals[1:]):
+         assert val == vals[0], (
+             f"the tensors on {row_or_col_type} {row_or_col_idx} of the {tensor_name} "
+             f"must all have the same stride along the {dim_name} dimension, got "
+             f"{vals[0]} at position 0 and {val} at position {pos + 1}"
+         )
+     return vals[0]
+
+
+ def _get_strides(
+     ts: List[List[torch.Tensor]], tensor_name, dim_0_name, dim_1_name
+ ) -> Tuple[List[int], List[int]]:
+     strides_0 = [
+         _check_row_or_column(
+             "column", idx, tensor_name, dim_0_name, [y.stride(0) for y in x]
+         )
+         for idx, x in enumerate(zip(*ts))
+     ]
+     strides_1 = [
+         _check_row_or_column(
+             "row", idx, tensor_name, dim_1_name, [y.stride(1) for y in x]
+         )
+         for idx, x in enumerate(ts)
+     ]
+     assert all(s == 1 for s in strides_0) or all(s == 1 for s in strides_1)
+     while len(strides_0) < 3:
+         strides_0.append(1 if strides_0[0] == 1 else 0)
+     while len(strides_1) < 3:
+         strides_1.append(1 if strides_1[0] == 1 else 0)
+     return strides_0, strides_1
+
+
+ def _launch_triton_matmul(
+     a: List[List[torch.Tensor]],
+     b: List[List[torch.Tensor]],
+     c: List[List[torch.Tensor]],
+     ms: List[int],
+     ns: List[int],
+     ks: List[int],
+ ) -> None:
+     strides_am, strides_ak = _get_strides(a, "first operand", "m", "k")
+     strides_bk, strides_bn = _get_strides(b, "second operand", "k", "n")
+     strides_cm, strides_cn = _get_strides(c, "output", "m", "n")
+
+     # accumulator types
+     ACC_TYPE = (
+         tl.float32
+         if c[0][0].dtype in [torch.float16, torch.bfloat16, torch.float32]
+         else tl.int32
+     )
+
+     # launch kernel
+     def grid(META):
+         return (
+             sum(triton.cdiv(m, META["BLOCK_M"]) for m in ms)
+             * sum(triton.cdiv(n, META["BLOCK_N"]) for n in ns),
+             META["SPLIT_K"],
+         )
+
+     _xformers_tiled_matmul_kernel[grid](
+         *[
+             a[min(i, len(a) - 1)][min(j, len(a[0]) - 1)]
+             for i in range(3)
+             for j in range(3)
+         ],
+         *[
+             b[min(i, len(b) - 1)][min(j, len(b[0]) - 1)]
+             for i in range(3)
+             for j in range(3)
+         ],
+         *[
+             c[min(i, len(c) - 1)][min(j, len(c[0]) - 1)]
+             for i in range(3)
+             for j in range(3)
+         ],
+         *[ms[i] if len(ms) > i else 0 for i in range(3)],
+         *[ns[i] if len(ns) > i else 0 for i in range(3)],
+         *[ks[i] if len(ks) > i else 0 for i in range(3)],
+         *strides_am,
+         *strides_ak,
+         *strides_bk,
+         *strides_bn,
+         *strides_cm,
+         *strides_cn,
+         ACC_TYPE=ACC_TYPE,
+     )
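The kernel walks a 3x3 grid of operand tiles and accumulates across the K tiles, so its result equals concatenating the tiles first and doing one matmul. A 2x2 toy version of that equivalence (shapes are arbitrary assumptions; only the inner dimensions must line up):

import torch

a = [[torch.randn(16, 32), torch.randn(16, 8)],
     [torch.randn(4, 32), torch.randn(4, 8)]]
b = [[torch.randn(32, 10), torch.randn(32, 6)],
     [torch.randn(8, 10), torch.randn(8, 6)]]

A = torch.cat([torch.cat(row, dim=1) for row in a], dim=0)  # (20, 40)
B = torch.cat([torch.cat(row, dim=1) for row in b], dim=0)  # (40, 16)
C = A @ B  # what the tiled kernel writes back block by block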
.venv/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py ADDED
@@ -0,0 +1,893 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Any, List, Optional, Sequence, Tuple, Type, Union, cast
+
+ import torch
+
+ from . import (
+     attn_bias,
+     ck,
+     ck_decoder,
+     ck_splitk,
+     cutlass,
+     flash,
+     flash3,
+     triton_splitk,
+ )
+ from .attn_bias import (
+     VARLEN_BIASES,
+     AttentionBias,
+     BlockDiagonalMask,
+     LowerTriangularMask,
+ )
+ from .common import (
+     AttentionBwOpBase,
+     AttentionFwOpBase,
+     AttentionOp,
+     AttentionOpBase,
+     Context,
+     Gradients,
+     Inputs,
+     bmk2bmhk,
+ )
+ from .dispatch import (
+     _dispatch_bw,
+     _dispatch_fw,
+     _ensure_op_supports_or_raise,
+     _get_use_fa3,
+     _set_use_fa3,
+ )
+
+ MemoryEfficientAttentionCutlassOp = (cutlass.FwOp, cutlass.BwOp)
+ MemoryEfficientAttentionCutlassFwdFlashBwOp = (cutlass.FwOp, flash.BwOp)
+ MemoryEfficientAttentionFlashAttentionOp = (flash.FwOp, flash.BwOp)
+ MemoryEfficientAttentionCkOp = (ck.FwOp, ck.BwOp)
+ MemoryEfficientAttentionCkDecoderOp = (ck_decoder.FwOp, ck.BwOp)
+ MemoryEfficientAttentionSplitKCkOp = (ck_splitk.FwOp, ck.BwOp)
+
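These tuples are ready-made (forward, backward) pairs for the `op=` argument of `memory_efficient_attention`, pinning both implementations instead of relying on automatic dispatch. A hedged usage sketch (the tensor shapes are assumptions, and the chosen ops must support the inputs, otherwise a ValueError is raised at call time):

import torch
import xformers.ops as xops

q, k, v = (
    torch.randn(2, 512, 8, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)
# Force the FlashAttention forward/backward pair:
y = xops.memory_efficient_attention(
    q, k, v, op=xops.MemoryEfficientAttentionFlashAttentionOp
)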
+
+ def _deserialize_bias(attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]) -> Any:
+     if attn_bias_tensor is None:
+         return attn_bias_ctx
+     return attn_bias_tensor
+
+
+ # Note: `torch.compile` only allows custom autograd functions
+ # to accept a subset of types. Therefore we serialize `op` objects
+ # to `str` before entering the function, and unserialize them inside.
+ # See also: https://github.com/pytorch/pytorch/issues/118395
+ _OPS_LOOKUP = {
+     flash.FwOp.NAME: flash.FwOp,
+     flash.BwOp.NAME: flash.BwOp,
+ }
+
+
+ def _serialize_op(op):
+     if op is not None and op.NAME in _OPS_LOOKUP:
+         return op.NAME
+     return op
+
+
+ def _unserialize_op(op):
+     if isinstance(op, str):
+         return _OPS_LOOKUP[op]
+     return op
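Only ops registered in `_OPS_LOOKUP` round-trip through strings; everything else passes through unchanged. An illustrative check of that contract (assuming the flash extension is available):

# A registered op crosses the autograd.Function boundary as a
# compile-friendly string and comes back as the same class:
serialized = _serialize_op(flash.FwOp)
assert isinstance(serialized, str)
assert _unserialize_op(serialized) is flash.FwOp
# `None` and unregistered ops are passed through as-is:
assert _serialize_op(None) is None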
+
+
+ class _fMHA(torch.autograd.Function):
+     @staticmethod
+     # type: ignore
+     def forward(ctx, op_fw, op_bw, *args: Any) -> Any:
+         inp = Inputs(*args)
+
+         op_fw = _unserialize_op(op_fw)
+         op_bw = _unserialize_op(op_bw)
+
+         out, op_ctx = _memory_efficient_attention_forward_requires_grad(
+             inp=inp, op=op_fw
+         )
+
+         # Saving attn_bias is a bit complicated, as the
+         # torch part should go in `save_for_backward`
+         if isinstance(inp.attn_bias, torch.Tensor):
+             attn_bias_tensor = inp.attn_bias
+             attn_bias_ctx = None
+         else:
+             attn_bias_tensor = None
+             attn_bias_ctx = inp.attn_bias
+
+         ctx.save_for_backward(
+             inp.query,
+             inp.key,
+             inp.value,
+             op_ctx.out,
+             op_ctx.lse,
+         )
+         ctx.rng_state = op_ctx.rng_state
+         ctx.attn_bias_tensor = attn_bias_tensor
+         if op_ctx.op_bw is not None:
+             if op_bw is not None and op_bw is not op_ctx.op_bw:
+                 raise ValueError(
+                     f"Specified op_bw={op_bw.NAME}, but forward op "
+                     f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
+                 )
+             op_bw = op_ctx.op_bw
+         if (
+             op_bw is not None
+             and isinstance(inp.attn_bias, VARLEN_BIASES)
+             and inp.attn_bias.q_seqinfo.seqstart.shape[0] > 2
+             and op_bw.VARLEN_LSE_PACKED != op_fw.VARLEN_LSE_PACKED
+         ):
+             raise ValueError(
+                 f"Specified op_bw={op_bw.NAME} is not compatible with the "
+                 f"op_fw={op_fw.NAME}, because they use different formats of logsumexp. "
+                 f"NOTE: This is new with xFormers 0.0.28"
+             )
+         if op_bw is None and (
+             inp.query.requires_grad or inp.key.requires_grad or inp.value.requires_grad
+         ):
+             varlen_lse_packed = _detect_lse_packed_or_raise(op_ctx.lse, inp)
+             if varlen_lse_packed is not None and op_fw is not None:
+                 assert (
+                     op_fw.VARLEN_LSE_PACKED == varlen_lse_packed
+                 ), f"{op_fw.NAME}: wrong value for `VARLEN_LSE_PACKED`?"
+             # NOTE: We need to check tensor strides to decide which operator we run in the BW pass.
+             # Unfortunately, PyTorch only allows us to call this function during the FW pass, so
+             # we decide on the operator to use now.
+             op_bw = _dispatch_bw(inp, varlen_lse_packed=varlen_lse_packed)
+         ctx.op_fw = op_fw
+         ctx.op_bw = op_bw
+         ctx.p = inp.p
+         # This allows gradients to be created from a single storage,
+         # to avoid a "cat" in the BW pass.
+         # The heuristic is approximate, but:
+         # (1) It's not a big issue to create a shared storage
+         # (2) The heuristic needs to pass `torch.compile`
+         #     (this is also why we run it in the FW pass; the BW pass is stricter)
+         ctx.qkv_share_storage = (
+             inp.query.shape[0] == inp.key.shape[0]
+             and inp.query.shape[-1] == inp.value.shape[-1]
+             and inp.query.stride(-2)
+             == (inp.key.shape[-1] + inp.query.shape[-1] + inp.value.shape[-1])
+         )
+
+         ctx.scale = inp.scale
+         ctx.attn_bias_ctx = attn_bias_ctx
+         ctx.n_args = len(args)
+         return out, op_ctx.lse
+
+     @staticmethod
+     @torch.autograd.function.once_differentiable
+     def backward(ctx, grad, grad_lse):
+         # Re-create context
+         query, key, value, out, lse = ctx.saved_tensors
+         attn_bias_tensor = ctx.attn_bias_tensor
+         rng_state = ctx.rng_state
+         inp = Inputs(
+             query=query,
+             key=key,
+             value=value,
+             attn_bias=_deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
+             p=ctx.p,
+             scale=ctx.scale,
+         )
+         op_ctx = Context(
+             lse=lse,
+             out=out,
+             rng_state=rng_state,
+         )
+         grads = _memory_efficient_attention_backward(
+             ctx=op_ctx,
+             inp=inp,
+             grad=grad,
+             op=ctx.op_bw,
+             _skip_op_checks=True,
+         )
+         return (None, None, grads.dq, grads.dk, grads.dv, grads.db) + (None,) * (
+             ctx.n_args - 2
+         )
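The `qkv_share_storage` heuristic above tests whether q/k/v look like views into one packed buffer by comparing `query.stride(-2)` with the summed last dimensions. A toy CPU illustration of the layout it detects (shapes are assumptions):

import torch

# One packed [B, M, 3*K] buffer, split into q/k/v views along the last dim:
qkv = torch.randn(2, 16, 3 * 8)
q, k, v = qkv.chunk(3, dim=-1)
# Each row of q strides across all three chunks, so the heuristic fires:
assert q.stride(-2) == q.shape[-1] + k.shape[-1] + v.shape[-1]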
+
+
+ def memory_efficient_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
+     p: float = 0.0,
+     scale: Optional[float] = None,
+     *,
+     op: Optional[AttentionOp] = None,
+     output_dtype: Optional[torch.dtype] = None,
+ ) -> torch.Tensor:
+     """Implements the memory-efficient attention mechanism following
+     `"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.
+
+     :Inputs shape:
+
+     - Input tensors must be in format ``[B, M, H, K]``, where B is the batch size, M \
+         the sequence length, H the number of heads, and K the embedding size per head
+
+     - If inputs have dimension 3, it is assumed that the dimensions are ``[B, M, K]`` and ``H=1``
+
+     - Inputs can also be of dimension 5 with GQA - see note below
+
+     - Inputs can be non-contiguous - we only require the last dimension's stride to be 1
+
+
+     :Equivalent pytorch code:
+
+     .. code-block:: python
+
+         scale = 1.0 / query.shape[-1] ** 0.5
+         query = query * scale
+         query = query.transpose(1, 2)
+         key = key.transpose(1, 2)
+         value = value.transpose(1, 2)
+         attn = query @ key.transpose(-2, -1)
+         if attn_bias is not None:
+             attn = attn + attn_bias
+         attn = attn.softmax(-1)
+         attn = F.dropout(attn, p)
+         attn = attn @ value
+         return attn.transpose(1, 2)
+
+     :Examples:
+
+     .. code-block:: python
+
+         import xformers.ops as xops
+
+         # Compute regular attention
+         y = xops.memory_efficient_attention(q, k, v)
+
+         # With a dropout of 0.2
+         y = xops.memory_efficient_attention(q, k, v, p=0.2)
+
+         # Causal attention
+         y = xops.memory_efficient_attention(
+             q, k, v,
+             attn_bias=xops.LowerTriangularMask()
+         )
+
+     :Supported hardware:
+
+         NVIDIA GPUs with compute capability above 6.0 (P100+), datatypes ``f16``, ``bf16`` and ``f32``.
+
+     :EXPERIMENTAL: Using with Multi Query Attention (MQA) and Grouped Query Attention (GQA):
+
+         MQA/GQA is an experimental feature supported only for the forward pass.
+         If you have 16 heads in query, and 2 in key/value, you can provide 5-dim tensors
+         in the ``[B, M, G, H, K]`` format, where ``G`` is the number of head groups (here 2), and
+         ``H`` is the number of heads per group (8 in the example).
+
+         Please note that xFormers will not automatically broadcast the inputs, so you will need
+         to broadcast them manually before calling `memory_efficient_attention`.
+
+     :GQA/MQA example:
+
+     .. code-block:: python
+
+         import torch
+         import xformers.ops as xops
+
+         B, M, K = 3, 32, 128
+         kwargs = dict(device="cuda", dtype=torch.float16)
+         q = torch.randn([B, M, 8, K], **kwargs)
+         k = torch.randn([B, M, 2, K], **kwargs)
+         v = torch.randn([B, M, 2, K], **kwargs)
+         out_gqa = xops.memory_efficient_attention(
+             q.reshape([B, M, 2, 4, K]),
+             k.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
+             v.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
+         )
+
+     Raises:
+         NotImplementedError: if there is no operator available to compute the MHA
+         ValueError: if inputs are invalid
+
+     :parameter query: Tensor of shape ``[B, Mq, H, K]``
+     :parameter key: Tensor of shape ``[B, Mkv, H, K]``
+     :parameter value: Tensor of shape ``[B, Mkv, H, Kv]``
+     :parameter attn_bias: Bias to apply to the attention matrix - defaults to no masking. \
+         For common biases implemented efficiently in xFormers, see :attr:`xformers.ops.fmha.attn_bias.AttentionBias`. \
+         This can also be a :attr:`torch.Tensor` for an arbitrary mask (slower).
+     :parameter p: Dropout probability. Disabled if set to ``0.0``
+     :parameter scale: Scaling factor for ``Q @ K.transpose()``. If set to ``None``, the default \
+         scale (q.shape[-1]**-0.5) will be used.
+     :parameter op: The operators to use - see :attr:`xformers.ops.AttentionOpBase`. \
+         If set to ``None`` (recommended), xFormers \
+         will dispatch to the best available operator, depending on the inputs \
+         and options.
+     :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``
+     """
+     return _memory_efficient_attention(
+         Inputs(
+             query=query,
+             key=key,
+             value=value,
+             p=p,
+             attn_bias=attn_bias,
+             scale=scale,
+             output_dtype=output_dtype,
+         ),
+         op=op,
+     )
+
+
+ torch.library.define(
+     "xformer::memory_efficient_attention_forward",
+     "(Tensor q, Tensor k, Tensor v, Tensor? b = None, float? p = 0.0, float? scale = None) -> Tensor",
+ )
+
+
+ @torch.library.impl("xformer::memory_efficient_attention_forward", "Meta")
+ def memory_efficient_attention_forward_meta(q, k, v):
+     return q.new_empty(q.shape)
+
+
+ # torch.compile has issues when tracing through op dispatch and ensure_op_support,
+ # so provide a wrapper to register it as a custom torch library op.
+ @torch.library.impl("xformer::memory_efficient_attention_forward", "CUDA")
+ def memory_efficient_attention_forward_torch_wrapper(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
+     p: float = 0.0,
+     scale: Optional[float] = None,
+ ) -> torch.Tensor:
+     """
+     This provides a torch-compilable wrapper op for
+     memory_efficient_attention_forward in certain special cases.
+
+     Note that the following are not supported:
+     - `op` input (?)
+     - certain attn_bias types (?)
+     - output_dtype
+     - K != Kv
+     """
+     return memory_efficient_attention_forward(
+         query,
+         key,
+         value,
+         attn_bias,
+         p,
+         scale,
+     )
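Once defined, the op is also reachable through the `torch.ops` namespace, which is what makes it visible to `torch.compile`. A hedged call under the wrapper's stated restrictions (CUDA tensors only, bias as a plain tensor or None; shapes are illustrative):

import torch

q = k = v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
out = torch.ops.xformer.memory_efficient_attention_forward(q, k, v)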
+
+
+ def memory_efficient_attention_forward(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
+     p: float = 0.0,
+     scale: Optional[float] = None,
+     *,
+     op: Optional[Type[AttentionFwOpBase]] = None,
+     output_dtype: Optional[torch.dtype] = None,
+ ) -> torch.Tensor:
+     """
+     Calculates the forward pass of :attr:`xformers.ops.memory_efficient_attention`.
+     """
+     return _memory_efficient_attention_forward(
+         Inputs(
+             query=query,
+             key=key,
+             value=value,
+             p=p,
+             attn_bias=attn_bias,
+             scale=scale,
+             output_dtype=output_dtype,
+         ),
+         op=op,
+     )
+
+
+ def memory_efficient_attention_forward_requires_grad(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
+     p: float = 0.0,
+     scale: Optional[float] = None,
+     *,
+     op: Optional[Type[AttentionFwOpBase]] = None,
+     output_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Returns a tuple (output, lse), where `lse` can be used to compute the backward pass later.
+     See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments.
+     See :attr:`xformers.ops.memory_efficient_attention_backward` for running the backward pass.
+     """
+     if p != 0.0:
+         raise NotImplementedError(
+             "dropout is not supported on the non-autograd API."
+             " If you want to use dropout, please call `memory_efficient_attention` directly"
+         )
+     out, ctx = _memory_efficient_attention_forward_requires_grad(
+         Inputs(
+             query=query,
+             key=key,
+             value=value,
+             p=p,
+             attn_bias=attn_bias,
+             scale=scale,
+             output_dtype=output_dtype,
+         ),
+         op=op,
+     )
+     return out, ctx.lse
+
+
+ def memory_efficient_attention_backward(
+     grad: torch.Tensor,
+     output: torch.Tensor,
+     lse: torch.Tensor,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
+     p: float = 0.0,
+     scale: Optional[float] = None,
+     *,
+     op: Optional[Type[AttentionBwOpBase]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     """
+     Computes the gradient of the attention.
+     Returns a tuple (dq, dk, dv).
+     See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments.
+     `lse` is the tensor returned by
+     :attr:`xformers.ops.memory_efficient_attention_forward_requires_grad`.
+     """
+     if p != 0.0:
+         raise NotImplementedError(
+             "dropout is not supported on the non-autograd API."
+             " If you want to use dropout, please call `memory_efficient_attention` directly"
+         )
+     gradients = _memory_efficient_attention_backward(
+         Context(out=output, lse=lse),
+         Inputs(
+             query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale
+         ),
+         grad,
+         op=op,
+     )
+     return (gradients.dq, gradients.dk, gradients.dv)
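Together, the two functions above form the manual, non-autograd API: keep the `lse` returned by the forward pass and feed it back when computing gradients yourself. A hedged end-to-end sketch (shapes and the stand-in upstream gradient are assumptions):

import torch
import xformers.ops.fmha as fmha

q, k, v = (
    torch.randn(2, 64, 4, 32, device="cuda", dtype=torch.float16) for _ in range(3)
)
out, lse = fmha.memory_efficient_attention_forward_requires_grad(q, k, v)
grad_out = torch.randn_like(out)  # stand-in for the real d(loss)/d(out)
dq, dk, dv = fmha.memory_efficient_attention_backward(grad_out, out, lse, q, k, v)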
+
+
+ def _memory_efficient_attention(
+     inp: Inputs, op: Optional[AttentionOp] = None
+ ) -> torch.Tensor:
+     # fast-path that doesn't require computing the logsumexp for backward computation
+     if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):
+         return _memory_efficient_attention_forward(
+             inp, op=op[0] if op is not None else None
+         )
+
+     output_shape = inp.normalize_bmhk()
+
+     op_fw = _serialize_op(op[0] if op is not None else None)
+     op_bw = _serialize_op(op[1] if op is not None else None)
+     return _fMHA.apply(
+         op_fw, op_bw, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale
+     )[0].reshape(output_shape)
+
+
+ def _memory_efficient_attention_forward(
+     inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
+ ) -> torch.Tensor:
+     inp.validate_inputs()
+     output_shape = inp.normalize_bmhk()
+     if op is None:
+         op = _dispatch_fw(inp, False)
+     else:
+         _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
+
+     out, *_ = op.apply(inp, needs_gradient=False)
+     return out.reshape(output_shape)
+
+
+ def _memory_efficient_attention_forward_requires_grad(
+     inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
+ ) -> Tuple[torch.Tensor, Context]:
+     inp.validate_inputs()
+     output_shape = inp.normalize_bmhk()
+     if op is None:
+         op = _dispatch_fw(inp, True)
+     else:
+         _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
+     out = op.apply(inp, needs_gradient=True)
+     assert out[1] is not None
+     return (out[0].reshape(output_shape), out[1])
+
+
+ def _detect_lse_packed_or_raise(lse: torch.Tensor, inp: Inputs) -> Optional[bool]:
+     """
+     Detects the LSE format if we're in a varlen case.
+     Returns `None` if the format is not relevant (e.g. not varlen).
+     Raises an exception if the `lse` has the wrong shape.
+     """
+     shape_mismatch_err = (
+         "Input tensors have incompatible shapes.\n"
+         f"    lse.shape  : {lse.shape}\n"
+         f"    query.shape: {inp.query.shape}\n"
+         f"    attn_bias  : {type(inp.attn_bias)}"
+     )
+     # 1. Check ndim & head dimensions
+     # In any case, LSE should be [*, *GH]
+     if lse.ndim != (inp.query.ndim - 1) or lse.shape[1:-1] != inp.query.shape[2:-1]:
+         raise ValueError(shape_mismatch_err)
+     lse_bm = [lse.shape[0], lse.shape[-1]]
+     lse_packed_shape = [inp.query.shape[0], inp.query.shape[1]]
+     lse_packed = lse_bm[0] == lse_packed_shape[0] and lse_bm >= lse_packed_shape
+     # 2. Check correctness for varlen biases, with query.shape = [1, M, *GH, K]
+     # Either [1, *GH, M] (packed)
+     # Or [num_seq, *GH, Mq] .. with `Mq >= max_q` (padded)
+     if isinstance(inp.attn_bias, VARLEN_BIASES):
+         si = inp.attn_bias.q_seqinfo
+         lse_padded_shape = [si.seqstart.shape[0] - 1, si.max_seqlen]
+         lse_padded = lse_bm[0] == lse_padded_shape[0] and lse_bm >= lse_padded_shape
+         if lse_packed and lse_padded:
+             return None
+         elif lse_packed:
+             return True
+         elif lse_padded:
+             return False
+         raise ValueError(shape_mismatch_err)
+     # 3. For non-varlen, shape must be [B, *GH, M] with query.shape=[B, M, *GH, K]
+     if not lse_packed:
+         raise ValueError(shape_mismatch_err)
+     return None
+
+
+ def _memory_efficient_attention_backward(
+     ctx: Context,
+     inp: Inputs,
+     grad: torch.Tensor,
+     op: Optional[Type[AttentionBwOpBase]],
+     *,
+     _skip_op_checks: bool = False,
+ ) -> Gradients:
+     """Warning: grad/ctx.out is potentially in BMK format"""
+     inp.validate_inputs()
+     if grad.ndim != inp.query.ndim or grad.ndim != ctx.out.ndim:
+         raise ValueError(
+             "All tensors should be either in BMK (ndim=3) or BMHK (ndim=4) format. \n"
+             f"grad.shape : {grad.shape} \n"
+             f"out.shape  : {ctx.out.shape} \n"
+             f"query.shape: {inp.query.shape}"
+         )
+     shape_dq, shape_dk, shape_dv = tuple(
+         x.shape for x in (inp.query, inp.key, inp.value)
+     )
+     inp.normalize_bmhk()
+     varlen_lse_packed = _detect_lse_packed_or_raise(ctx.lse, inp)
+     grad = bmk2bmhk(grad, 1)
+     ctx.out = bmk2bmhk(ctx.out, 1)
+
+     if op is None:
+         op = _dispatch_bw(inp, varlen_lse_packed=varlen_lse_packed)
+     elif not _skip_op_checks:
+         _ensure_op_supports_or_raise(
+             ValueError, "memory_efficient_attention_backward", op, inp
+         )
+         if varlen_lse_packed is not None and varlen_lse_packed != op.VARLEN_LSE_PACKED:
+             raise ValueError(
+                 f"Wrong LSE format for {op.NAME} in the variable-seqlen case. "
+                 f"Double-check that the BW operator {op.NAME} is compatible "
+                 f"with the operator used in the FW pass."
+             )
+
+     grads = op.apply(ctx, inp, grad)
+     grads.dq = grads.dq.reshape(shape_dq)
+     grads.dk = grads.dk.reshape(shape_dk)
+     grads.dv = grads.dv.reshape(shape_dv)
+     return grads
+
+
+ def memory_efficient_attention_partial(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
+     p: float = 0.0,
+     scale: Optional[float] = None,
+     *,
+     op: Optional[Union[AttentionOp, Type[AttentionFwOpBase]]] = None,
+     output_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Returns a tuple (output, lse), where `output` is the attention in the style of
+     memory_efficient_attention, and `lse` is extra data, a log-sum-exp.
+     The outputs of calls to this with the same query and separate keys and values
+     can be merged with merge_attentions to obtain the attention of the queries
+     against the disjoint union of the keys and values.
+
+     Warning: The backward pass of this function is quite restricted. In particular,
+     we assume that in the forward pass the outputs were only used in merge_attentions
+     calculations, and that LSEs weren't used anywhere except in merge_attentions.
+     """
+     if p != 0.0:
+         raise NotImplementedError("dropout is not supported.")
+     fwop: Optional[Type[AttentionFwOpBase]] = op[0] if isinstance(op, tuple) else op
+     inp = Inputs(
+         query=query,
+         key=key,
+         value=value,
+         p=p,
+         attn_bias=attn_bias,
+         scale=scale,
+         output_dtype=output_dtype,
+         is_partial=True,
+     )
+
+     is_grad = torch.is_grad_enabled() and any(
+         x.requires_grad for x in [query, key, value]
+     )
+
+     if not is_grad:
+         out, ctx = _memory_efficient_attention_forward_requires_grad(
+             inp,
+             op=fwop,
+         )
+         return out, ctx.lse
+
+     if query.ndim == 5:
+         raise ValueError("gradients not supported for 5D tensors")
+     if isinstance(op, tuple):
+         op_fw = _serialize_op(op[0])
+         op_bw = _serialize_op(op[1])
+     elif op is None:
+         op_fw = op_bw = None
+     else:
+         op_fw = _serialize_op(op)
+         op_bw = None
+     return _fMHA.apply(
+         op_fw,
+         op_bw,
+         inp.query,
+         inp.key,
+         inp.value,
+         inp.attn_bias,
+         inp.p,
+         inp.scale,
+         inp.output_dtype,
+         inp.is_partial,
+     )
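A hedged sketch of the intended chunked-K/V workflow: run the same queries against disjoint K/V chunks, then combine the partial results with `merge_attentions` (shapes are assumptions):

import torch
import xformers.ops.fmha as fmha

q = torch.randn(1, 32, 8, 64, device="cuda", dtype=torch.bfloat16)
k1, v1 = (torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16) for _ in range(2))
k2, v2 = (torch.randn(1, 96, 8, 64, device="cuda", dtype=torch.bfloat16) for _ in range(2))

out1, lse1 = fmha.memory_efficient_attention_partial(q, k1, v1)
out2, lse2 = fmha.memory_efficient_attention_partial(q, k2, v2)
# Equivalent to attending q against the concatenated (k1|k2, v1|v2):
out, _ = fmha.merge_attentions([out1, out2], [lse1, lse2], write_lse=False)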
+
+
+ def merge_attentions(
+     attn_split: Union[torch.Tensor, Sequence[torch.Tensor]],
+     lse_split: Union[torch.Tensor, Sequence[torch.Tensor]],
+     write_lse: bool = True,
+     output_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+     """
+     Combine attention output computed on different parts of K/V for the same
+     query to get attention on the whole K/V. See https://arxiv.org/abs/2402.05099
+     The result is equal to
+         Out_full = (Out1 * exp(LSE1) + Out2 * exp(LSE2) + ...) / (exp(LSE1) + exp(LSE2) + ...)
+         LSE_full = log(exp(LSE1) + exp(LSE2) + ...)
+
+     Args:
+         attn_split: attention outputs for chunks,
+             either as a list of tensors of shapes [B, M, G, H, Kq] or [B, M, H, Kq]
+             or as a single tensor of shape [num_chunks, B, M, G, H, Kq]
+             or [num_chunks, B, M, H, Kq]
+         lse_split: LSE for chunks,
+             either as a list of tensors of shapes [B, G, H, M] or [B, H, M]
+             or as a single tensor of shape [num_chunks, B, G, H, M] or [num_chunks, B, H, M]
+         write_lse: whether to output LSE
+         output_dtype: dtype of attn_out
+
+     Returns:
+         attn_out: [B, M, G, H, Kq] or [B, M, H, Kq]
+         lse_out: [B, G, H, M] or [B, H, M] if write_lse,
+             or None otherwise
+     """
+
+     attn_is_concat = isinstance(attn_split, torch.Tensor)
+     lse_is_concat = isinstance(lse_split, torch.Tensor)
+
+     attn_requires_grad = (
+         attn_split.requires_grad  # type: ignore
+         if attn_is_concat
+         else any(x.requires_grad for x in attn_split)
+     )
+     lse_requires_grad = (
+         lse_split.requires_grad  # type: ignore
+         if lse_is_concat
+         else any(x.requires_grad for x in lse_split)
+     )
+     requires_grad = torch.is_grad_enabled() and (
+         attn_requires_grad or lse_requires_grad
+     )
+     if requires_grad and not write_lse:
+         raise ValueError("write_lse should be true if inputs require gradients.")
+
+     concat_path = attn_is_concat and lse_is_concat and not requires_grad
+     if concat_path:
+         attn_split = cast(torch.Tensor, attn_split)
+         lse_split = cast(torch.Tensor, lse_split)
+         if attn_split.ndim != lse_split.ndim + 1:
+             raise ValueError(
+                 f"Incompatible input shapes: {attn_split.shape=}, {lse_split.shape=}"
+             )
+
+         is_bmhk = attn_split.ndim == 5
+         if is_bmhk:
+             attn_split = attn_split.unsqueeze(3)
+             lse_split = lse_split.unsqueeze(2)
+
+         num_chunks, B, M, G, H, Kq = attn_split.shape
+         num_chunks1, B1, G1, H1, M1 = lse_split.shape
+         if B != B1 or G != G1 or H != H1 or num_chunks != num_chunks1 or M != M1:
+             raise ValueError(
+                 f"Incompatible input shapes: {attn_split.shape=} {lse_split.shape=} "
+                 f"{B}/{B1}, {G}/{G1}, {H}/{H1}, {num_chunks}/{num_chunks1}, {M}/{M1}"
+             )
+
+         attn_split = attn_split.permute(1, 3, 4, 0, 2, 5)
+         lse_split = lse_split.permute(1, 2, 3, 0, 4)
+
+         device = attn_split.device
+         attn_dtype = attn_split.dtype
+         lse_dtype = lse_split.dtype
+     else:
+         if attn_is_concat:
+             attn_split = attn_split.unbind(0)  # type: ignore
+         if lse_is_concat:
+             lse_split = lse_split.unbind(0)  # type: ignore
+         num_chunks = len(attn_split)
+         if len(lse_split) != num_chunks:
+             raise ValueError(
+                 f"Incompatible number of LSE and attention chunks: {len(attn_split)=}, {len(lse_split)=}"
+             )
+
+         attn_unsqueezed = []
+         lse_unsqueezed = []
+         is_bmhk = False
+         for i in range(num_chunks):
+             if attn_split[i].ndim != lse_split[i].ndim + 1:
+                 raise ValueError(
+                     f"Incompatible input shapes for chunk {i}: {attn_split[i].shape=}, {lse_split[i].shape=}"
+                 )
+
+             is_bmhk = attn_split[i].ndim == 4
+             if is_bmhk:
+                 attn_unsqueezed.append(attn_split[i].unsqueeze(2))
+                 lse_unsqueezed.append(lse_split[i].unsqueeze(1))
+             else:
+                 attn_unsqueezed.append(attn_split[i])
+                 lse_unsqueezed.append(lse_split[i])
+         attn_split, lse_split = attn_unsqueezed, lse_unsqueezed
+
+         B, M, G, H, Kq = attn_split[0].shape
+         B1, G1, H1, M1 = lse_split[0].shape
+         if B != B1 or G != G1 or H != H1 or M != M1:
+             raise ValueError(
+                 f"Incompatible input shapes: {attn_split[0].shape=}, {lse_split[0].shape=} "
+                 f"{B}/{B1}, {G}/{G1}, {H}/{H1}, {M}/{M1}"
+             )
+
+         for i in range(num_chunks):
+             if attn_split[i].shape != (B, M, G, H, Kq):
+                 raise ValueError(
+                     f"Incompatible input shapes for attention chunk {i}: "
+                     f"{attn_split[i].shape=}, {(B, M, G, H, Kq)=}"
+                 )
+             if lse_split[i].shape != (B, G, H, M):
+                 raise ValueError(
+                     f"Incompatible input shapes for LSE chunk {i}: "
+                     f"{lse_split[i].shape=}, {(B, G, H, M)=}"
+                 )
+
+             attn_split[i] = attn_split[i].permute(0, 2, 3, 1, 4)  # to (B, G, H, M, Kq)
+
+         device = attn_split[0].device
+         attn_dtype = attn_split[0].dtype
+         lse_dtype = lse_split[0].dtype
+
+     attn_out = torch.empty(
+         B,
+         M,
+         G,
+         H,
+         Kq,
+         device=device,
+         dtype=output_dtype or attn_dtype,
+         requires_grad=requires_grad,
+     )
+     if write_lse:
+         lse_out = torch.empty(
+             B, G, H, M, device=device, dtype=lse_dtype, requires_grad=requires_grad
+         )
+     else:
+         lse_out = None
+
+     if concat_path:
+         triton_splitk.merge_attentions(attn_out, lse_out, attn_split, lse_split)  # type: ignore
+     else:
+         attn_out, lse_out = _MergeAttentions.apply(attn_out, lse_out, *attn_split, *lse_split)  # type: ignore
+
+     if is_bmhk:
+         attn_out = attn_out[:, :, 0]
+         if lse_out is not None:
+             lse_out = lse_out[:, 0]
+
+     return attn_out, lse_out
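The docstring formula can be spelled out as a naive pure-PyTorch reference, useful for checking correctness against the fused Triton path (a sketch for the no-group [num_chunks, B, M, H, K] layout, not the library kernel):

import torch

def merge_reference(attn_split: torch.Tensor, lse_split: torch.Tensor):
    # attn_split: [num_chunks, B, M, H, K]; lse_split: [num_chunks, B, H, M]
    lse_full = torch.logsumexp(lse_split, dim=0)          # [B, H, M]
    weights = torch.exp(lse_split - lse_full)             # [num_chunks, B, H, M]
    weights = weights.permute(0, 1, 3, 2).unsqueeze(-1)   # [num_chunks, B, M, H, 1]
    attn_full = (attn_split * weights).sum(dim=0)         # [B, M, H, K]
    return attn_full, lse_full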
+
+
+ class _MergeAttentions(torch.autograd.Function):
+     @staticmethod
+     # type: ignore
+     def forward(
+         ctx, attn_out: torch.Tensor, lse_out: torch.Tensor, *inputs: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         num_chunks = len(inputs) // 2
+         attn_split, lse_split = inputs[:num_chunks], inputs[num_chunks:]
+
+         triton_splitk.merge_attentions_varargs(attn_out, lse_out, attn_split, lse_split)
+
+         ctx.save_for_backward(
+             attn_out,
+             lse_out,
+             *inputs,
+         )
+         return attn_out, lse_out
+
+     @staticmethod
+     # type: ignore
+     def backward(
+         ctx, grad_attn: torch.Tensor, grad_lse: torch.Tensor
+     ) -> Tuple[Optional[torch.Tensor], ...]:
+         out, lse, *inputs = ctx.saved_tensors
+         num_chunks = len(inputs) // 2
+         attn_split, lse_split = inputs[:num_chunks], inputs[num_chunks:]
+         dattn, dlse = triton_splitk.merge_attentions_varargs_backward(
+             attn_split,
+             lse_split,
+             out,
+             lse,
+             grad_attn,
+             grad_lse,
+         )
+         ret = [None, None] + dattn + dlse
+         return tuple(ret)
+
+
+ ALL_FW_OPS: List[Type[AttentionFwOpBase]] = [
+     cutlass.FwOp if torch.version.cuda else ck.FwOp,
+     flash.FwOp,
+     flash3.FwOp,
+     triton_splitk.FwOp,
+ ]
+
+ ALL_BW_OPS: List[Type[AttentionBwOpBase]] = [
+     cutlass.BwOp if torch.version.cuda else ck.BwOp,
+     flash.BwOp,
+     flash3.BwOp,
+ ]
+
+ __all__ = [
+     "AttentionBias",
+     "AttentionOp",
+     "AttentionOpBase",
+     "LowerTriangularMask",
+     "MemoryEfficientAttentionCutlassFwdFlashBwOp",
+     "MemoryEfficientAttentionCutlassOp",
+     "MemoryEfficientAttentionFlashAttentionOp",
+     "memory_efficient_attention",
+     "MemoryEfficientAttentionCkOp",
+     "MemoryEfficientAttentionCkDecoderOp",
+     "ALL_FW_OPS",
+     "ALL_BW_OPS",
+     "attn_bias",
+     "_get_use_fa3",
+     "_set_use_fa3",
+     "BlockDiagonalMask",
+ ]
.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (36.5 kB)

.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/attn_bias.cpython-311.pyc ADDED
Binary file (84.4 kB)

.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/ck.cpython-311.pyc ADDED
Binary file (19.6 kB)

.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/ck_decoder.cpython-311.pyc ADDED
Binary file (6.87 kB)