koichi12 commited on
Commit
640f355
·
verified ·
1 Parent(s): 5ef2986

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/xformers/_flash_attn/__init__.py +11 -0
  2. .venv/lib/python3.11/site-packages/xformers/_flash_attn/bert_padding.py +213 -0
  3. .venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_attn_interface.py +1286 -0
  4. .venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_attn_triton.py +1160 -0
  5. .venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_attn_triton_og.py +365 -0
  6. .venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_blocksparse_attention.py +197 -0
  7. .venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_blocksparse_attn_interface.py +200 -0
  8. .venv/lib/python3.11/site-packages/xformers/_flash_attn/fused_softmax.py +201 -0
  9. .venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__init__.py +0 -0
  10. .venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/block.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/embedding.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/mha.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/mlp.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/block.py +397 -0
  16. .venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/embedding.py +216 -0
  17. .venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/mha.py +1020 -0
  18. .venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/mlp.py +191 -0
  19. .venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/layer_norm.py +800 -0
  20. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__init__.py +4 -0
  21. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/__init__.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/batch_fetch_results.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/batch_submit.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/run_grid_search.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/run_tasks.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/run_with_submitit.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/batch_fetch_results.py +96 -0
  28. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/batch_submit.py +49 -0
  29. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__init__.py +4 -0
  30. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__pycache__/__init__.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__pycache__/dataset.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__pycache__/model_wrapper.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/dataset.py +46 -0
  34. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/model_wrapper.py +288 -0
  35. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/run_grid_search.py +148 -0
  36. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/run_tasks.py +302 -0
  37. .venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/run_with_submitit.py +153 -0
  38. .venv/lib/python3.11/site-packages/xformers/benchmarks/__init__.py +4 -0
  39. .venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/__init__.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_attn_decoding.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_core.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_indexing.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_mem_eff_attention.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_merge_attentions.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_multi_head_dispatch.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_nystrom_utils.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_revnet.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_sddmm.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_sequence_parallel_fused.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_sp24.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/xformers/_flash_attn/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "2.6.3"
2
+
3
+ from flash_attn.flash_attn_interface import (
4
+ flash_attn_func,
5
+ flash_attn_kvpacked_func,
6
+ flash_attn_qkvpacked_func,
7
+ flash_attn_varlen_func,
8
+ flash_attn_varlen_kvpacked_func,
9
+ flash_attn_varlen_qkvpacked_func,
10
+ flash_attn_with_kvcache,
11
+ )
.venv/lib/python3.11/site-packages/xformers/_flash_attn/bert_padding.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange, repeat
6
+
7
+
8
class IndexFirstAxis(torch.autograd.Function):
    """Differentiable row-gather along the first axis.

    Forward is equivalent to ``input[indices]`` (see the commented-out line
    below) but implemented via torch.gather; backward scatters gradients back
    into a zero tensor of the original first-axis size, so unselected rows
    receive zero gradient.
    """

    @staticmethod
    def forward(ctx, input, indices):
        # indices are reused in backward to scatter gradients to their rows.
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        # Remember the original first-axis size and trailing shape so backward
        # can rebuild a full-sized gradient.
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return input[indices]
        return torch.gather(
            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
        ).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        # Flatten trailing dims so a 2-D scatter can route the gradients.
        grad_output = rearrange(grad_output, "b ... -> b (...)")
        grad_input = torch.zeros(
            [ctx.first_axis_dim, grad_output.shape[1]],
            device=grad_output.device,
            dtype=grad_output.dtype,
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # grad_input[indices] = grad_output
        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
        # Second None: no gradient for the integer `indices` argument.
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


# Functional alias used by the unpad/pad helpers below.
index_first_axis = IndexFirstAxis.apply
39
+
40
+
41
class IndexPutFirstAxis(torch.autograd.Function):
    """Differentiable inverse of a first-axis gather.

    Writes `values` into the rows of a fresh zero tensor with
    `first_axis_dim` rows at positions `indices`; all other rows stay zero.
    """

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        assert indices.ndim == 1
        assert values.ndim >= 2
        ctx.save_for_backward(indices)
        result = torch.zeros(
            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
        )
        result[indices] = values
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # The gradient of a scatter is a gather at the same positions; the two
        # trailing Nones cover the non-differentiable indices / size arguments.
        return grad_output[indices], None, None


index_put_first_axis = IndexPutFirstAxis.apply
65
+
66
+
67
class IndexFirstAxisResidual(torch.autograd.Function):
    """First-axis gather that also returns the full input as a residual.

    Backward accumulates the gathered-output gradient into the residual
    gradient in place, avoiding an extra full-sized allocation.
    """

    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        output = input[indices]
        # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
        # memory format to channel_first. In other words, input might not be contiguous.
        # If we don't detach, Pytorch complains about output being a view and is being modified inplace
        return output, input.detach()

    @staticmethod
    def backward(ctx, grad_output, grad_output_residual):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        assert grad_residual.shape[1:] == other_shape
        # Reuse the residual gradient as the accumulator; the selected rows get
        # grad_output added on top of it (in place, via scatter_add_).
        grad_input = grad_residual
        # grad_input[indices] += grad_output
        indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
        indices = indices.expand_as(grad_output)
        grad_input.scatter_add_(0, indices, grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis_residual = IndexFirstAxisResidual.apply
96
+
97
+
98
def unpad_input(hidden_states, attention_mask):
    """Remove padding tokens and flatten the batch into one packed sequence.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # FIX: was `dtype=torch.torch.int32`, which only worked because
    # `torch.torch` happens to alias the module itself; use the canonical
    # spelling. Pad with a leading 0 so cu_seqlens[i]:cu_seqlens[i+1] slices
    # out sequence i.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
124
+
125
+
126
def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
    """
    Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
    The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).

    For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
    ```
    [
      [2, 3, 0, 0, 0, 0],
      [3, 2, 0, 0, 0, 0],
      [6, 0, 0, 0, 0, 0]
    ]
    ```
    , which refers to the 3D-attention mask:
    ```
    [
      [
        [1, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0],
        [0, 0, 1, 1, 0, 0],
        [0, 0, 1, 1, 1, 0],
        [0, 0, 0, 0, 0, 1]
      ],
      [
        [1, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 0, 0],
        [0, 0, 0, 1, 1, 0],
        [0, 0, 0, 0, 0, 1]
      ],
      [
        [1, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1]
      ]
    ]
    ```.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    # Total valid tokens per batch row (sum of the per-sample lengths).
    length = attention_mask_in_length.sum(dim=-1)
    seqlen = attention_mask_in_length.size(-1)
    # 2-D validity mask: position j in row b is valid iff j < length[b].
    attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)
    # Nonzero entries of attention_mask_in_length are the per-sample lengths.
    real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
    indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # FIX: was `dtype=torch.torch.int32` (worked only because `torch.torch`
    # aliases the module itself); use the canonical spelling.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
197
+
198
+
199
def pad_input(hidden_states, indices, batch, seqlen):
    """Inverse of unpad_input: scatter packed tokens back into a padded batch.

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    # FIX: dropped the unused local `dim` (a leftover from the commented-out
    # torch.zeros + index-assignment implementation that index_put_first_axis
    # replaced). Padded positions come back as zeros.
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, "(b s) ... -> b s ...", b=batch)
.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_attn_interface.py ADDED
@@ -0,0 +1,1286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ import flash_attn_2_cuda as flash_attn_cuda
11
+
12
+ # isort: on
13
+
14
def maybe_contiguous(x):
    """Return `x` unchanged when it is None or already has unit last-dim stride;
    otherwise return a contiguous copy (the CUDA kernels require stride(-1) == 1)."""
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
16
+
17
def _get_block_size_n(device, head_dim, is_dropout, is_causal):
    """Return the kernel's N-dimension block size for this head_dim / arch.

    Must stay in sync with the block sizes compiled into the CUDA kernel.
    FIX: removed the unused locals `is_sm80` / `is_sm90` (computed but never
    read) and collapsed the identical 192/224/256 branches; the returned
    values are unchanged for every input.
    """
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    # Only sm86 and sm89 (minor > 0), excluding sm80 (A100).
    is_sm8x = major == 8 and minor > 0
    if head_dim <= 32:
        return 128
    if head_dim <= 64:
        return 64 if is_dropout else 128
    if head_dim <= 96:
        return 64
    if head_dim <= 128:
        if is_sm8x:
            return 64 if is_causal and not is_dropout else 32
        return 32 if is_dropout else 64
    if head_dim <= 160:
        return 64 if is_sm8x else 32
    # 160 < head_dim <= 256 (the original 192/224/256 branches all return 64).
    return 64
46
+
47
+
48
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
):
    """Thin wrapper over the fused CUDA forward kernel (fixed-length batches).

    Returns the kernel outputs unchanged:
    (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state).
    """
    # Kernel requires unit stride on the last dimension.
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    # NOTE(review): the positional order below must match flash_attn_cuda.fwd
    # exactly; the two `None`s are optional arguments unused here — confirm
    # against the flash_attn_2_cuda binding before reordering anything.
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
        q,
        k,
        v,
        None,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        return_softmax,
        None,
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
68
+
69
+
70
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    return_softmax=False,
    block_table=None,
    leftpad_k=None,
    seqused_k=None,
):
    """Thin wrapper over the fused CUDA forward kernel for variable-length
    (packed, cu_seqlens-indexed) batches.

    Returns the kernel outputs unchanged:
    (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state).
    """
    # Kernel requires unit stride on the last dimension.
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    # NOTE(review): the positional order below must match
    # flash_attn_cuda.varlen_fwd exactly; the `None`s and the literal `False`
    # are optional/flag arguments unused here — confirm against the
    # flash_attn_2_cuda binding before reordering anything.
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        None,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_k,
        leftpad_k,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        return_softmax,
        None,
    )
    # if out.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
116
+
117
+
118
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Thin wrapper over the fused CUDA backward kernel (fixed-length batches).

    dq/dk/dv are caller-allocated output buffers that the kernel fills in
    place; `rng_state` replays the forward dropout mask. Returns
    (dq, dk, dv, softmax_d).
    """
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # NOTE(review): positional order must match flash_attn_cuda.bwd exactly;
    # the `None` is an optional argument unused here — confirm against the
    # flash_attn_2_cuda binding before reordering anything.
    (
        dq,
        dk,
        dv,
        softmax_d,
    ) = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        deterministic,
        None,
        rng_state,
    )
    return dq, dk, dv, softmax_d
166
+
167
+
168
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    """Thin wrapper over the fused CUDA backward kernel for variable-length
    (packed, cu_seqlens-indexed) batches.

    dq/dk/dv are caller-allocated output buffers filled in place; `rng_state`
    replays the forward dropout mask. Returns (dq, dk, dv, softmax_d).
    """
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # NOTE(review): positional order must match flash_attn_cuda.varlen_bwd
    # exactly; the `None` and the literal `False` are optional/flag arguments
    # unused here — confirm against the flash_attn_2_cuda binding before
    # reordering anything.
    (
        dq,
        dk,
        dv,
        softmax_d,
    ) = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        deterministic,
        None,
        rng_state,
    )
    # if dq.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dq, dk, dv, softmax_d
227
+
228
+
229
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """autograd.Function for attention on a qkv tensor packed along dim 2
    (accessed as qkv[:, :, 0/1/2] for q/k/v)."""

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            # Default scaling: 1/sqrt(head_dim).
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # S_dmask is only produced when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        # Save the kernel-returned (possibly head-padded) tensors plus the RNG
        # state needed to replay dropout in backward.
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        # Re-pack the gradient buffer in qkv layout so dq/dk/dv are written
        # directly into slices of one allocation.
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # One gradient slot per forward input: only qkv is differentiable.
        return dqkv, None, None, None, None, None, None, None, None
293
+
294
+
295
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """autograd.Function for variable-length attention on a qkv tensor packed
    along dim 1 (accessed as qkv[:, 0/1/2]); sequences are delimited by
    `cu_seqlens` (queries and keys share the same lengths here)."""

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            # Default scaling: 1/sqrt(head_dim).
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            # Same cu_seqlens / max_seqlen for Q and K: self-attention layout.
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # S_dmask is only produced when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        # Save the kernel-returned (possibly head-padded) tensors plus the RNG
        # state needed to replay dropout in backward.
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        # Gradient buffer in packed qkv layout so dq/dk/dv share one allocation.
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # One gradient slot per forward input: only qkv is differentiable.
        return dqkv, None, None, None, None, None, None, None, None, None, None
371
+
372
+
373
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """autograd.Function for attention with separate q and a kv tensor packed
    along dim 2 (accessed as kv[:, :, 0/1] for k/v)."""

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            # Default scaling: 1/sqrt(head_dim).
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # S_dmask is only produced when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        # Save the kernel-returned (possibly head-padded) tensors plus the RNG
        # state needed to replay dropout in backward.
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # Gradient buffer in packed kv layout so dk/dv share one allocation.
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # One gradient slot per forward input: only q and kv are differentiable.
        return dq, dkv, None, None, None, None, None, None, None, None
440
+
441
+
442
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """autograd.Function for variable-length attention with separate q and a
    kv tensor packed along dim 1 (kv[:, 0/1]); query and key sequences carry
    their own cu_seqlens / max_seqlen (cross-attention-capable layout)."""

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            # Default scaling: 1/sqrt(head_dim).
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # S_dmask is only produced when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        # Save the kernel-returned (possibly head-padded) tensors plus the RNG
        # state needed to replay dropout in backward.
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # Gradient buffer in packed kv layout so dk/dv share one allocation.
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # One gradient slot per forward input: only q and kv are differentiable.
        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None
526
+
527
+
528
class FlashAttnFunc(torch.autograd.Function):
    """autograd.Function for attention with separate q, k, v tensors
    (fixed-length batched layout)."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            # Default scaling: 1/sqrt(head_dim).
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # S_dmask is only produced when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        # Save the kernel-returned (possibly head-padded) tensors plus the RNG
        # state needed to replay dropout in backward.
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # One gradient slot per forward input: only q, k, v are differentiable.
        return dq, dk, dv, None, None, None, None, None, None, None, None
595
+
596
+
597
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for flash attention over variable-length (packed) sequences."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
    ):
        # Default QK^T scaling is 1/sqrt(headdim).
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        # The dropout-encoding probability tensor is only produced when dropout is active.
        want_probs = return_softmax and dropout_p > 0
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q, k, v,
            cu_seqlens_q, cu_seqlens_k,
            max_seqlen_q, max_seqlen_k,
            dropout_p, softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=want_probs,
            block_table=block_table,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return (out, softmax_lse, S_dmask) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = (torch.empty_like(t) for t in (q, k, v))
        _flash_attn_varlen_backward(
            dout, q, k, v, out, softmax_lse,
            dq, dk, dv,
            cu_seqlens_q, cu_seqlens_k,
            ctx.max_seqlen_q, ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # The forward pass may have padded the head dimension; trim back to dout's.
        head_dim = dout.shape[-1]
        dq, dk, dv = dq[..., :head_dim], dk[..., :head_dim], dv[..., :head_dim]
        return (dq, dk, dv) + (None,) * 13
682
+
683
+
684
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # <=0.0 means deactivate
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Flash attention when Q, K, V are stacked into one tensor.

    Faster than calling flash_attn_func on separate Q, K, V because the backward
    pass avoids explicitly concatenating the Q/K/V gradients. For MQA/GQA, see
    flash_attn_kvpacked_func and flash_attn_func. dropout_p should be 0.0 during
    evaluation. If window_size != (-1, -1), implements sliding-window local
    attention: query i attends only to keys in
    [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling of QK^T before softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for
            auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), sliding-window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i - j|) is added to the score of query i and key j.
        deterministic: bool. Use the deterministic backward implementation
            (slightly slower, more memory). The forward pass is always deterministic.
        return_attn_probs: bool. Also return the attention probabilities (testing
            only; their scaling is not guaranteed to be correct).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen).
            Row-wise logsumexp of QK^T * scaling (log of the softmax normalizer).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            Softmax output (possibly rescaled), also encoding the dropout pattern
            (negative means dropped, nonnegative means kept).
    """
    args = (
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnQKVPackedFunc.apply(*args)
740
+
741
+
742
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Flash attention when K and V are stacked into one tensor.

    Faster than calling flash_attn_func on separate K, V because the backward
    pass avoids explicitly concatenating the K/V gradients. dropout_p should be
    0.0 during evaluation. Supports MQA/GQA: kv may have fewer heads than q, as
    long as q's head count is divisible by kv's (e.g. with 6 query heads and
    2 kv heads, query heads 0-2 attend to kv head 0 and heads 3-5 to kv head 1).

    If causal=True, the causal mask is aligned to the BOTTOM RIGHT corner of the
    attention matrix. E.g. for seqlen_q = 2, seqlen_k = 5 (1 = keep, 0 = masked):
        1 1 1 1 0
        1 1 1 1 1
    and for seqlen_q = 5, seqlen_k = 2:
        0 0
        0 0
        0 0
        1 0
        1 1
    A fully-masked row produces a zero output row.

    If window_size != (-1, -1), implements sliding-window local attention:
    query i attends only to keys in
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]
    inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling of QK^T before softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask.
        window_size: (left, right). If not (-1, -1), sliding-window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the score
            of query i and key j.
        deterministic: bool. Use the deterministic backward implementation
            (slightly slower, more memory). The forward pass is always deterministic.
        return_attn_probs: bool. Also return the attention probabilities (testing
            only; their scaling is not guaranteed to be correct).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen).
            Row-wise logsumexp of QK^T * scaling (log of the softmax normalizer).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            Softmax output (possibly rescaled), also encoding the dropout pattern
            (negative means dropped, nonnegative means kept).
    """
    args = (
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnKVPackedFunc.apply(*args)
817
+
818
+
819
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Fix: the docstring previously omitted the `softcap` argument even though
    # the signature accepts it (all sibling wrappers document it).
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
893
+
894
+
895
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Variable-length flash attention when Q, K, V are stacked into one tensor.

    Faster than calling flash_attn_varlen_func on separate Q, K, V because the
    backward pass avoids explicitly concatenating the Q/K/V gradients. For
    MQA/GQA, see flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.
    dropout_p should be 0.0 during evaluation. If window_size != (-1, -1),
    implements sliding-window local attention: query i attends only to keys in
    [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens
            in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. Cumulative sequence
            lengths of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling of QK^T before softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask.
        window_size: (left, right). If not (-1, -1), sliding-window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i - j|) is added to the score of query i and key j.
        deterministic: bool. Use the deterministic backward implementation
            (slightly slower, more memory). The forward pass is always deterministic.
        return_attn_probs: bool. Also return the attention probabilities (testing
            only; their scaling is not guaranteed to be correct).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen).
            Row-wise logsumexp of QK^T * scaling (log of the softmax normalizer).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            Softmax output (possibly rescaled), also encoding the dropout pattern
            (negative means dropped, nonnegative means kept).
    """
    args = (
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnVarlenQKVPackedFunc.apply(*args)
958
+
959
+
960
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Variable-length flash attention when K and V are stacked into one tensor.

    Faster than calling flash_attn_func on separate K, V because the backward
    pass avoids explicitly concatenating the K/V gradients. dropout_p should be
    0.0 during evaluation. Supports MQA/GQA: kv may have fewer heads than q, as
    long as q's head count is divisible by kv's (e.g. with 6 query heads and
    2 kv heads, query heads 0-2 attend to kv head 0 and heads 3-5 to kv head 1).

    If causal=True, the causal mask is aligned to the BOTTOM RIGHT corner of the
    attention matrix. E.g. for seqlen_q = 2, seqlen_k = 5 (1 = keep, 0 = masked):
        1 1 1 1 0
        1 1 1 1 1
    and for seqlen_q = 5, seqlen_k = 2:
        0 0
        0 0
        0 0
        1 0
        1 1
    A fully-masked row produces a zero output row.

    If window_size != (-1, -1), implements sliding-window local attention:
    query i attends only to keys in
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]
    inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query
            tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key
            tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. Cumulative sequence
            lengths of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. Cumulative sequence
            lengths of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling of QK^T before softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask.
        window_size: (left, right). If not (-1, -1), sliding-window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the score
            of query i and key j.
        deterministic: bool. Use the deterministic backward implementation
            (slightly slower, more memory). The forward pass is always deterministic.
        return_attn_probs: bool. Also return the attention probabilities (testing
            only; their scaling is not guaranteed to be correct).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen).
            Row-wise logsumexp of QK^T * scaling (log of the softmax normalizer).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            Softmax output (possibly rescaled), also encoding the dropout pattern
            (negative means dropped, nonnegative means kept).
    """
    args = (
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnVarlenKVPackedFunc.apply(*args)
1049
+
1050
+
1051
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
            Block table for a paged KV cache; None (the default) means a regular,
            non-paged KV layout.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Fix: the docstring previously omitted the `block_table` argument even
    # though the signature accepts and forwards it.
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
    )
1142
+
1143
+
1144
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    # Fix: `Optional[Union[(int, torch.Tensor)]]` wrapped the members in a
    # redundant tuple; the idiomatic spelling below is equivalent at runtime.
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar sequence length to a per-batch int32 tensor.
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        cache_leftpad,
        block_table,
        alibi_slopes,
        None,  # out tensor (allocated by the kernel)
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        rotary_interleaved,
        num_splits,
    )
    return (out, softmax_lse) if return_softmax_lse else out
.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_attn_triton.py ADDED
@@ -0,0 +1,1160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ *Experimental* implementation of FlashAttention in Triton.
3
+ Tested with triton==2.0.0.dev20221202.
4
+ Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
5
+ other than 64:
6
+ https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
7
+ We'll update this implementation with the new Triton backend once this is fixed.
8
+
9
+ We use the FlashAttention implementation from Phil Tillet as a starting point.
10
+ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
11
+
12
+ Changes:
13
+ - Implement both causal and non-causal attention.
14
+ - Implement both self-attention and cross-attention.
15
+ - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
16
+ - Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
17
+ - Support attention bias.
18
+ - Speed up the forward pass a bit, and only store the LSE instead of m and l.
19
+ - Make the backward for d=128 much faster by reducing register spilling.
20
+ - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
21
+ small batch size * nheads.
22
+
23
+ Caution:
24
+ - This is an *experimental* implementation. The forward pass should be quite robust but
25
+ I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
26
+ - This implementation has only been tested on A100.
27
+ - If you plan to use headdim other than 64 and 128, you should test for race conditions
28
+ (due to the Triton compiler), as done in tests/test_flash_attn.py
29
+ "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
30
+ for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
31
+ that there are none left for other head dimensions.
32
+
33
+ Differences between this Triton version and the CUDA version:
34
+ - Triton version doesn't support dropout.
35
+ - Triton forward is generally faster than CUDA forward, while Triton backward is
36
+ generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
37
+ than CUDA forward + backward.
38
+ - Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
39
+ - Triton version supports attention bias, while CUDA version doesn't.
40
+ """
41
+
42
+ import math
43
+
44
+ import torch
45
+ import triton
46
+ import triton.language as tl
47
+
48
+
49
# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
# @triton.autotune(
#     configs=[
#         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
#         # This config has a race condition when EVEN_M == False, disabling it for now.
#         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
#     ],
#     key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
# )
@triton.heuristics(
    {
        # EVEN_* flags let the kernel skip bounds masking when a dimension divides evenly.
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    Bias,
    Out,
    Lse,
    TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    softmax_scale,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_kb,
    stride_kh,
    stride_kn,
    stride_vb,
    stride_vh,
    stride_vn,
    stride_bb,
    stride_bh,
    stride_bm,
    stride_ob,
    stride_oh,
    stride_om,
    nheads,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # FlashAttention forward kernel.
    # Grid: axis 0 = query blocks of BLOCK_M rows; axis 1 = batch * nheads
    # (decoded below via off_hb // nheads and off_hb % nheads).
    # Each program computes one BLOCK_M x headdim tile of Out plus the per-row
    # log-sum-exp written to Lse, streaming over K/V in BLOCK_N chunks with the
    # online-softmax recurrence (running max m_i, running logsumexp lse_i).
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # off_b = tl.program_id(1)
    # off_h = tl.program_id(2)
    # off_hb = off_b * nheads + off_h
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    # Adding parenthesis around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = (
        Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    )
    if BIAS_TYPE == "vector":
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == "matrix":
        b_ptrs = (
            Bias
            + off_b * stride_bb
            + off_h * stride_bh
            + (offs_m[:, None] * stride_bm + offs_n[None, :])
        )
    # initialize pointer to m and l
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(
                q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
            )
    # loop over k, v and update accumulator
    # Causal attention only needs key blocks up to (and including) this query block.
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        # Trying to combine the two masks seem to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
        if BIAS_TYPE != "none":
            if BIAS_TYPE == "vector":
                if EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(
                        b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
                    ).to(tl.float32)
                bias = bias[None, :]
            elif BIAS_TYPE == "matrix":
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(
                        b_ptrs + start_n,
                        mask=(offs_m[:, None] < seqlen_q)
                        & ((start_n + offs_n)[None, :] < seqlen_k),
                        other=0.0,
                    ).to(tl.float32)
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need to
            # to multiply with softmax_scale here.
            qk = qk * softmax_scale + bias
            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
            p = tl.exp(qk - m_ij[:, None])
        else:
            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
            p = tl.exp(qk * softmax_scale - m_ij[:, None])
        l_ij = tl.sum(p, 1)

        # scale acc_o
        acc_o_scale = tl.exp(m_i - m_ij)

        # # -- update output accumulator --
        # BUG: have to store and immediately load
        tl.store(t_ptrs, acc_o_scale)
        acc_o_scale = tl.load(t_ptrs)
        acc_o = acc_o * acc_o_scale[:, None]
        # update acc_o
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)

        # -- update statistics
        m_i = m_ij
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)

    # Final normalization: divide the accumulated output by the softmax denominator.
    o_scale = tl.exp(m_i - lse_i)
    # BUG: have to store and immediately load
    tl.store(t_ptrs, o_scale)
    o_scale = tl.load(t_ptrs)
    acc_o = acc_o * o_scale[:, None]
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    tl.store(lse_ptrs, lse_i)
    # initialize pointers to output
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = (
        Out
        + off_b * stride_ob
        + off_h * stride_oh
        + (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(
                out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )
285
+
286
+
287
@triton.jit
def _bwd_preprocess_do_o_dot(
    Out,
    DO,
    Delta,
    stride_ob,
    stride_oh,
    stride_om,
    stride_dob,
    stride_doh,
    stride_dom,
    nheads,
    seqlen_q,
    seqlen_q_rounded,
    headdim,
    BLOCK_M: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
):
    # Backward preprocessing kernel: for each query row, compute
    # delta = rowsum(O * dO) in fp32 and write it to Delta.
    # Grid: axis 0 = query blocks of BLOCK_M rows; axis 1 = batch * nheads.
    # Delta is laid out as (batch * nheads, seqlen_q_rounded), matching LSE.
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # load (out-of-range rows/columns read as 0.0 so they don't affect the sum)
    o = tl.load(
        Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
        mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
        other=0.0,
    ).to(tl.float32)
    do = tl.load(
        DO
        + off_b * stride_dob
        + off_h * stride_doh
        + offs_m[:, None] * stride_dom
        + offs_d[None, :],
        mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
        other=0.0,
    ).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
330
+
331
+
332
@triton.jit
def _bwd_store_dk_dv(
    dk_ptrs,
    dv_ptrs,
    dk,
    dv,
    offs_n,
    offs_d,
    seqlen_k,
    headdim,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
):
    # Store the accumulated dK/dV tiles, masking out-of-range key rows and
    # head-dim columns. The branch structure (keying the masked path on
    # EVEN_N & EVEN_M rather than just EVEN_N) is deliberate — see the note below.
    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.store(dv_ptrs), there's a race condition
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
362
+
363
+
364
@triton.jit
def _bwd_kernel_one_col_block(
    start_n,
    Q,
    K,
    V,
    Bias,
    DO,
    DQ,
    DK,
    DV,
    LSE,
    D,
    softmax_scale,
    stride_qm,
    stride_kn,
    stride_vn,
    stride_bm,
    stride_dom,
    stride_dqm,
    stride_dkn,
    stride_dvn,
    seqlen_q,
    seqlen_k,
    headdim,
    ATOMIC_ADD: tl.constexpr,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Backward pass for one column block (one BLOCK_N slice of keys/values).
    # Holds k/v for this block in SRAM, loops over all relevant query row blocks,
    # recomputes the attention probabilities from Q, K and the saved LSE, and
    # accumulates dK/dV locally while updating dQ either via load/add/store
    # (ATOMIC_ADD=False, one program owns all of dQ) or via tl.atomic_add
    # (ATOMIC_ADD=True, column blocks run in parallel over seqlen_k).
    # NOTE: the tl.debug_barrier() calls below are required workarounds for
    # Triton race conditions — do not remove or reorder them.
    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
    begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
    # initialize row/col offsets
    offs_qm = begin_m + tl.arange(0, BLOCK_M)
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # initialize pointers to value-like data
    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
    do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
    dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
    if BIAS_TYPE == "vector":
        b_ptrs = Bias + offs_n
    elif BIAS_TYPE == "matrix":
        b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
    # initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # There seems to be some problem with Triton pipelining that makes results wrong for
    # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
    # may have zero step, and pipelining with the bias matrix could screw it up.
    # So we just exit early.
    if begin_m >= seqlen_q:
        dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
        dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
        _bwd_store_dk_dv(
            dk_ptrs,
            dv_ptrs,
            dk,
            dv,
            offs_n,
            offs_d,
            seqlen_k,
            headdim,
            EVEN_M=EVEN_M,
            EVEN_N=EVEN_N,
            EVEN_HEADDIM=EVEN_HEADDIM,
        )
        return
    # k and v stay in SRAM throughout
    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.load(k_ptrs), we get the wrong output!
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs)
            v = tl.load(v_ptrs)
        else:
            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
            v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
        else:
            k = tl.load(
                k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
            )
            v = tl.load(
                v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
            )
    # loop over rows
    num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
    for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        offs_m_curr = start_m + offs_m
        # load q, k, v, do on-chip
        # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
        if EVEN_M & EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            if EVEN_HEADDIM:
                q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
            else:
                q = tl.load(
                    q_ptrs,
                    mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        # recompute p = softmax(qk, dim=-1).T
        qk = tl.dot(q, k, trans_b=True)
        # Trying to combine the two masks seem to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
        if IS_CAUSAL:
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
        if BIAS_TYPE != "none":
            tl.debug_barrier()  # Race condition otherwise
            if BIAS_TYPE == "vector":
                if EVEN_N:
                    bias = tl.load(b_ptrs).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
                bias = bias[None, :]
            elif BIAS_TYPE == "matrix":
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs).to(tl.float32)
                else:
                    bias = tl.load(
                        b_ptrs,
                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k),
                        other=0.0,
                    ).to(tl.float32)
            qk = qk * softmax_scale + bias
        # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
        # Also wrong for headdim=64.
        if not (EVEN_M & EVEN_HEADDIM):
            tl.debug_barrier()
        lse_i = tl.load(LSE + offs_m_curr)
        # If bias was added, softmax_scale is already folded into qk above.
        if BIAS_TYPE == "none":
            p = tl.exp(qk * softmax_scale - lse_i[:, None])
        else:
            p = tl.exp(qk - lse_i[:, None])
        # compute dv
        # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
        # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
        # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
        # the output is correct.
        if EVEN_M & EVEN_HEADDIM:
            do = tl.load(do_ptrs)
        else:
            # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
            do = tl.load(
                do_ptrs,
                mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                other=0.0,
            )
        # if EVEN_M:
        #     if EVEN_HEADDIM:
        #         do = tl.load(do_ptrs)
        #     else:
        #         do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
        # else:
        #     if EVEN_HEADDIM:
        #         do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
        #     else:
        #         do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
        #                      & (offs_d[None, :] < headdim), other=0.0)
        dv += tl.dot(p.to(do.dtype), do, trans_a=True)
        # compute dp = dot(v, do)
        # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
        # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
        # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
        if not (EVEN_M & EVEN_HEADDIM):
            tl.debug_barrier()
        dp = tl.dot(do, v, trans_b=True)
        # There's a race condition for headdim=48
        if not EVEN_HEADDIM:
            tl.debug_barrier()
        # compute ds = p * (dp - delta[:, None])
        # Putting the subtraction after the dp matmul (instead of before) is slightly faster
        Di = tl.load(D + offs_m_curr)
        # Converting ds to q.dtype here reduces register pressure and makes it much faster
        # for BLOCK_HEADDIM=128
        ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
        # compute dk = dot(ds.T, q)
        dk += tl.dot(ds, q, trans_a=True)
        # compute dq
        if not (
            EVEN_M & EVEN_HEADDIM
        ):  # Otherewise there's a race condition when BIAS_TYPE='matrix'
            tl.debug_barrier()
        if not ATOMIC_ADD:
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                dq = tl.load(dq_ptrs, eviction_policy="evict_last")
                dq += tl.dot(ds, k)
                tl.store(dq_ptrs, dq, eviction_policy="evict_last")
            else:
                if EVEN_HEADDIM:
                    dq = tl.load(
                        dq_ptrs,
                        mask=offs_m_curr[:, None] < seqlen_q,
                        other=0.0,
                        eviction_policy="evict_last",
                    )
                    dq += tl.dot(ds, k)
                    tl.store(
                        dq_ptrs,
                        dq,
                        mask=offs_m_curr[:, None] < seqlen_q,
                        eviction_policy="evict_last",
                    )
                else:
                    dq = tl.load(
                        dq_ptrs,
                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        other=0.0,
                        eviction_policy="evict_last",
                    )
                    dq += tl.dot(ds, k)
                    tl.store(
                        dq_ptrs,
                        dq,
                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        eviction_policy="evict_last",
                    )
        else:  # If we're parallelizing across the seqlen_k dimension
            dq = tl.dot(ds, k)
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                tl.atomic_add(dq_ptrs, dq)
            else:
                if EVEN_HEADDIM:
                    tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
                else:
                    tl.atomic_add(
                        dq_ptrs,
                        dq,
                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    )
        # increment pointers
        dq_ptrs += BLOCK_M * stride_dqm
        q_ptrs += BLOCK_M * stride_qm
        do_ptrs += BLOCK_M * stride_dom
        if BIAS_TYPE == "matrix":
            b_ptrs += BLOCK_M * stride_bm
    # write-back
    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
    _bwd_store_dk_dv(
        dk_ptrs,
        dv_ptrs,
        dk,
        dv,
        offs_n,
        offs_d,
        seqlen_k,
        headdim,
        EVEN_M=EVEN_M,
        EVEN_N=EVEN_N,
        EVEN_HEADDIM=EVEN_HEADDIM,
    )
631
+
632
+
633
def init_to_zero(name):
    """Build an autotune ``pre_hook`` that zeroes the kernel argument called *name*.

    The returned callable receives the autotuner's argument dict and calls
    ``.zero_()`` on the tensor stored under ``name`` (used to reset DQ before
    each timed run, since the backward kernel accumulates into it).
    """

    def zero_named_arg(nargs):
        return nargs[name].zero_()

    return zero_named_arg
635
+
636
+
637
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
            num_warps=8,
            num_stages=1,
            pre_hook=init_to_zero("DQ"),
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
            num_warps=8,
            num_stages=1,
            pre_hook=init_to_zero("DQ"),
        ),
        # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
        # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
    ],
    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"],
)
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_kernel(
    Q,
    K,
    V,
    Bias,
    DO,
    DQ,
    DK,
    DV,
    LSE,
    D,
    softmax_scale,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_kb,
    stride_kh,
    stride_kn,
    stride_vb,
    stride_vh,
    stride_vn,
    stride_bb,
    stride_bh,
    stride_bm,
    stride_dob,
    stride_doh,
    stride_dom,
    stride_dqb,
    stride_dqh,
    stride_dqm,
    stride_dkb,
    stride_dkh,
    stride_dkn,
    stride_dvb,
    stride_dvh,
    stride_dvn,
    nheads,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    SEQUENCE_PARALLEL: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # FlashAttention backward driver kernel.
    # Grid axis 1 = batch * nheads; this kernel offsets all base pointers for its
    # (batch, head) slice and dispatches to _bwd_kernel_one_col_block:
    #   - SEQUENCE_PARALLEL=False: one program loops over ALL key-column blocks,
    #     so dQ can be updated with plain load/add/store (ATOMIC_ADD=False).
    #   - SEQUENCE_PARALLEL=True: grid axis 0 indexes the key-column block, and
    #     concurrent programs accumulate dQ with atomics (ATOMIC_ADD=True);
    #     DQ must be pre-zeroed (the autotune pre_hook init_to_zero("DQ")).
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # offset pointers for batch/head
    Q += off_b * stride_qb + off_h * stride_qh
    K += off_b * stride_kb + off_h * stride_kh
    V += off_b * stride_vb + off_h * stride_vh
    DO += off_b * stride_dob + off_h * stride_doh
    DQ += off_b * stride_dqb + off_h * stride_dqh
    DK += off_b * stride_dkb + off_h * stride_dkh
    DV += off_b * stride_dvb + off_h * stride_dvh
    if BIAS_TYPE != "none":
        Bias += off_b * stride_bb + off_h * stride_bh
    # pointer to row-wise quantities in value-like data
    D += off_hb * seqlen_q_rounded
    LSE += off_hb * seqlen_q_rounded
    if not SEQUENCE_PARALLEL:
        num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
        for start_n in range(0, num_block_n):
            _bwd_kernel_one_col_block(
                start_n,
                Q,
                K,
                V,
                Bias,
                DO,
                DQ,
                DK,
                DV,
                LSE,
                D,
                softmax_scale,
                stride_qm,
                stride_kn,
                stride_vn,
                stride_bm,
                stride_dom,
                stride_dqm,
                stride_dkn,
                stride_dvn,
                seqlen_q,
                seqlen_k,
                headdim,
                ATOMIC_ADD=False,
                BIAS_TYPE=BIAS_TYPE,
                IS_CAUSAL=IS_CAUSAL,
                BLOCK_HEADDIM=BLOCK_HEADDIM,
                EVEN_M=EVEN_M,
                EVEN_N=EVEN_N,
                EVEN_HEADDIM=EVEN_HEADDIM,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )
    else:
        start_n = tl.program_id(0)
        _bwd_kernel_one_col_block(
            start_n,
            Q,
            K,
            V,
            Bias,
            DO,
            DQ,
            DK,
            DV,
            LSE,
            D,
            softmax_scale,
            stride_qm,
            stride_kn,
            stride_vn,
            stride_bm,
            stride_dom,
            stride_dqm,
            stride_dkn,
            stride_dvn,
            seqlen_q,
            seqlen_k,
            headdim,
            ATOMIC_ADD=True,
            BIAS_TYPE=BIAS_TYPE,
            IS_CAUSAL=IS_CAUSAL,
            BLOCK_HEADDIM=BLOCK_HEADDIM,
            EVEN_M=EVEN_M,
            EVEN_N=EVEN_N,
            EVEN_HEADDIM=EVEN_HEADDIM,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
        )
810
+
811
+
812
def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
    """Launch the Triton forward kernel.

    Arguments:
        q: (batch, seqlen_q, nheads, headdim), fp16/bf16 CUDA tensor.
        k, v: (batch, seqlen_k, nheads, headdim), same dtype/device as q.
        bias: optional 4-D tensor whose last two dims are (1, seqlen_k)
            ("vector" bias) or (seqlen_q, seqlen_k) ("matrix" bias); the
            leading dims are broadcast to (batch, nheads, ...) via expand.
        causal: apply causal masking in the kernel.
        softmax_scale: scaling of QK^T; defaults to 1/sqrt(headdim).

    Returns:
        (o, lse, softmax_scale) where o has q's shape and lse is
        (batch, nheads, seqlen_q_rounded) in fp32 (seqlen_q rounded up to a
        multiple of 128). softmax_scale is returned because it could have
        been updated (filled in from the default).
    """
    # shape constraints
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    assert k.shape == (batch, seqlen_k, nheads, d)
    assert v.shape == (batch, seqlen_k, nheads, d)
    assert d <= 128, "FlashAttention only support head dimensions up to 128"
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
    assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
    assert q.is_cuda and k.is_cuda and v.is_cuda
    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)

    has_bias = bias is not None
    bias_type = "none"
    if has_bias:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda
        assert bias.dim() == 4
        # The kernel indexes bias with unit stride along the last (seqlen_k) dim.
        if bias.stride(-1) != 1:
            bias = bias.contiguous()
        if bias.shape[2:] == (1, seqlen_k):
            bias_type = "vector"
        elif bias.shape[2:] == (seqlen_q, seqlen_k):
            bias_type = "matrix"
        else:
            raise RuntimeError(
                "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
            )
        # expand is a view: broadcast dims get stride 0, which the kernel handles.
        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)

    # LSE/TMP are padded to a multiple of 128 rows so every BLOCK_M tile is in-bounds.
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    o = torch.empty_like(q)

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    BLOCK = 128
    num_warps = 4 if d <= 64 else 8
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    # Strides are passed as (batch, head, seq) even though tensors are laid out
    # (batch, seq, head, dim) — hence the stride(0), stride(2), stride(1) order.
    _fwd_kernel[grid](
        q,
        k,
        v,
        bias,
        o,
        lse,
        tmp,
        softmax_scale,
        q.stride(0),
        q.stride(2),
        q.stride(1),
        k.stride(0),
        k.stride(2),
        k.stride(1),
        v.stride(0),
        v.stride(2),
        v.stride(1),
        *bias_strides,
        o.stride(0),
        o.stride(2),
        o.stride(1),
        nheads,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        d,
        seqlen_q // 32,
        seqlen_k // 32,  # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
        bias_type,
        causal,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return o, lse, softmax_scale  # softmax_scale could have been updated
892
+
893
+
894
def _flash_attn_backward(
    do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None
):
    """Triton backward pass: fill ``dq``, ``dk``, ``dv`` in place from the output grad ``do``.

    Arguments:
        do: (batch, seqlen_q, nheads, headdim) gradient w.r.t. the attention output.
        q, k, v, o: tensors saved from the forward pass (head dim must be contiguous).
        lse: (batch, nheads, seqlen_q_rounded) log-sum-exp saved by the forward pass.
        dq, dk, dv: pre-allocated gradient buffers, written in place.
        bias: optional 4-d bias; last two dims must be (1, seqlen_k) or (seqlen_q, seqlen_k).
        causal: whether causal masking was used in the forward pass.
        softmax_scale: defaults to 1/sqrt(headdim) when None/0.
    """
    # Make sure that the last dimension is contiguous
    if do.stride(-1) != 1:
        do = do.contiguous()
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    # assert d in {16, 32, 64, 128}
    assert d <= 128
    # lse was computed with seqlen_q rounded up to a multiple of 128 (forward pass does the same).
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    assert lse.shape == (batch, nheads, seqlen_q_rounded)
    assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
    assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
    # dq_accum = torch.zeros_like(q, dtype=torch.float32)
    # dq is accumulated in fp32 for accuracy and copied into the caller's dq at the end.
    dq_accum = torch.empty_like(q, dtype=torch.float32)
    delta = torch.empty_like(lse)
    # delta = torch.zeros_like(lse)

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    # Precompute delta = rowsum(o * do), consumed by the main backward kernel.
    _bwd_preprocess_do_o_dot[grid](
        o,
        do,
        delta,
        o.stride(0),
        o.stride(2),
        o.stride(1),
        do.stride(0),
        do.stride(2),
        do.stride(1),
        nheads,
        seqlen_q,
        seqlen_q_rounded,
        d,
        BLOCK_M=128,
        BLOCK_HEADDIM=BLOCK_HEADDIM,
    )

    has_bias = bias is not None
    bias_type = "none"
    if has_bias:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda
        assert bias.dim() == 4
        assert bias.stride(-1) == 1
        # Classify the bias layout; the kernel specializes on this string.
        if bias.shape[2:] == (1, seqlen_k):
            bias_type = "vector"
        elif bias.shape[2:] == (seqlen_q, seqlen_k):
            bias_type = "matrix"
        else:
            raise RuntimeError(
                "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
            )
        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)

    # BLOCK_M = 128
    # BLOCK_N = 64
    # num_warps = 4
    # One program per key block (when sequence-parallel) per (batch, head) pair.
    grid = lambda META: (
        triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
        batch * nheads,
    )
    _bwd_kernel[grid](
        q,
        k,
        v,
        bias,
        do,
        dq_accum,
        dk,
        dv,
        lse,
        delta,
        softmax_scale,
        q.stride(0),
        q.stride(2),
        q.stride(1),
        k.stride(0),
        k.stride(2),
        k.stride(1),
        v.stride(0),
        v.stride(2),
        v.stride(1),
        *bias_strides,
        do.stride(0),
        do.stride(2),
        do.stride(1),
        dq_accum.stride(0),
        dq_accum.stride(2),
        dq_accum.stride(1),
        dk.stride(0),
        dk.stride(2),
        dk.stride(1),
        dv.stride(0),
        dv.stride(2),
        dv.stride(1),
        nheads,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        d,
        seqlen_q // 32,
        seqlen_k // 32,  # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
        bias_type,
        causal,
        BLOCK_HEADDIM,
        # SEQUENCE_PARALLEL=False,
        # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        # num_warps=num_warps,
        # num_stages=1,
    )
    # Downcast the fp32 accumulator into the caller-provided dq (may be fp16/bf16).
    dq.copy_(dq_accum)
1011
+
1012
+
1013
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for the Triton FlashAttention kernels with packed QKV input."""

    @staticmethod
    def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
        """
        qkv: (batch, seqlen, 3, nheads, headdim)
        bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
            ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
        """
        # Make sure that the last dimension is contiguous
        if qkv.stride(-1) != 1:
            qkv = qkv.contiguous()
        # The forward may replace softmax_scale (when None) with 1/sqrt(headdim);
        # stash the value actually used so backward matches.
        o, lse, ctx.softmax_scale = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            bias=bias,
            causal=causal,
            softmax_scale=softmax_scale,
        )
        ctx.save_for_backward(qkv, o, lse, bias)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        # do: gradient of the output, same shape as o: (batch, seqlen, nheads, headdim)
        qkv, o, lse, bias = ctx.saved_tensors
        assert not ctx.needs_input_grad[1], "FlashAttention does not support bias gradient yet"
        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
        with torch.inference_mode():
            dqkv = torch.empty_like(qkv)
            # Gradients are written in place into the slices of dqkv.
            _flash_attn_backward(
                do,
                qkv[:, :, 0],
                qkv[:, :, 1],
                qkv[:, :, 2],
                o,
                lse,
                dqkv[:, :, 0],
                dqkv[:, :, 1],
                dqkv[:, :, 2],
                bias=bias,
                causal=ctx.causal,
                softmax_scale=ctx.softmax_scale,
            )
        return dqkv, None, None, None


flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
1063
+
1064
+
1065
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for the Triton FlashAttention kernels with packed KV input."""

    @staticmethod
    def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
        """
        q: (batch, seqlen_q, nheads, headdim)
        kv: (batch, seqlen_k, 2, nheads, headdim)
        bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
            ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
        """
        # Make sure that the last dimension is contiguous
        q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
        # softmax_scale may be filled in by the forward (1/sqrt(headdim) when None).
        o, lse, ctx.softmax_scale = _flash_attn_forward(
            q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(q, kv, o, lse, bias)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        q, kv, o, lse, bias = ctx.saved_tensors
        # Guarded lookup: only check the bias slot when needs_input_grad reports it.
        if len(ctx.needs_input_grad) >= 3:
            assert not ctx.needs_input_grad[2], "FlashAttention does not support bias gradient yet"
        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
        with torch.inference_mode():
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)
            # dk/dv are written in place into the two slices of dkv.
            _flash_attn_backward(
                do,
                q,
                kv[:, :, 0],
                kv[:, :, 1],
                o,
                lse,
                dq,
                dkv[:, :, 0],
                dkv[:, :, 1],
                bias=bias,
                causal=ctx.causal,
                softmax_scale=ctx.softmax_scale,
            )
        return dq, dkv, None, None, None


flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
1112
+
1113
+
1114
class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper for the Triton FlashAttention kernels with separate q, k, v."""

    @staticmethod
    def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
        """
        q: (batch_size, seqlen_q, nheads, headdim)
        k, v: (batch_size, seqlen_k, nheads, headdim)
        bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
            ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
        """
        # The kernels require the head dimension of each tensor to be contiguous.
        if q.stride(-1) != 1:
            q = q.contiguous()
        if k.stride(-1) != 1:
            k = k.contiguous()
        if v.stride(-1) != 1:
            v = v.contiguous()
        # The forward returns the scale it actually used (1/sqrt(headdim) when None was passed);
        # keep it on ctx so backward reuses the same value.
        out, lse, ctx.softmax_scale = _flash_attn_forward(
            q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(q, k, v, out, lse, bias)
        ctx.causal = causal
        return out

    @staticmethod
    def backward(ctx, do):
        q, k, v, out, lse, bias = ctx.saved_tensors
        assert not ctx.needs_input_grad[3], "FlashAttention does not support bias gradient yet"
        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
        with torch.inference_mode():
            grads = [torch.empty_like(t) for t in (q, k, v)]
            _flash_attn_backward(
                do,
                q,
                k,
                v,
                out,
                lse,
                grads[0],
                grads[1],
                grads[2],
                bias=bias,
                causal=ctx.causal,
                softmax_scale=ctx.softmax_scale,
            )
        dq, dk, dv = grads
        return dq, dk, dv, None, None, None


flash_attn_func = FlashAttnFunc.apply
.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_attn_triton_og.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
2
+ # for benchmarking.
3
+ # We fixed a few dtype cast to make it work for bf16
4
+
5
+ """
6
+ Fused Attention
7
+ ===============
8
+ This is a Triton implementation of the Flash Attention algorithm
9
+ (see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
10
+ """
11
+
12
+ import pytest
13
+ import torch
14
+ import triton
15
+ import triton.language as tl
16
+
17
+
18
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    TMP,
    L,
    M,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    Out,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vk,
    stride_vn,
    stride_oz,
    stride_oh,
    stride_om,
    stride_on,
    Z,
    H,
    N_CTX,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Causal fused-attention forward (OG Triton tutorial kernel).
    # Grid: axis 0 = query block, axis 1 = flattened (batch, head).
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    # NOTE(review): K's base offset reuses stride_qh, and V's element offsets reuse
    # stride_qm/stride_qk — this assumes Q, K, V share the same layout. Kept from the
    # original tutorial; confirm against callers before reusing with mixed layouts.
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointer to m and l
    t_ptrs = TMP + off_hz * N_CTX + offs_m
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs)
    # loop over k, v and update accumulator
    # The upper bound (start_m + 1) * BLOCK_M makes the loop implicitly causal.
    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs + start_n * stride_kn)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        qk *= sm_scale
        # Causal mask within the block: future positions get -inf.
        qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i (online-softmax running max / running sum)
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        tl.store(t_ptrs, acc_scale)
        acc_scale = tl.load(t_ptrs)  # BUG: have to store and immediately load
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(v_ptrs + start_n * stride_vk)
        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m (needed by the backward pass)
    l_ptrs = L + off_hz * N_CTX + offs_m
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, l_i)
    tl.store(m_ptrs, m_i)
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_DMODEL)
    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)
118
+
119
+
120
@triton.jit
def _bwd_preprocess(
    Out,
    DO,
    L,
    NewDO,
    Delta,
    BLOCK_M: tl.constexpr,
    D_HEAD: tl.constexpr,
):
    # Backward prologue: normalize dO by the softmax denominator L and compute
    # delta = rowsum(O * dO_normalized), both consumed by _bwd_kernel.
    # Assumes Out/DO/NewDO are contiguous with row stride D_HEAD.
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, D_HEAD)
    # load
    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    denom = tl.load(L + off_m).to(tl.float32)
    # compute
    do = do / denom[:, None]
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
    tl.store(Delta + off_m, delta)
142
+
143
+
144
@triton.jit
def _bwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    Out,
    DO,
    DQ,
    DK,
    DV,
    L,
    M,
    D,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vk,
    stride_vn,
    Z,
    H,
    N_CTX,
    num_block,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Backward pass (OG tutorial): one program per (batch, head); the outer loop walks
    # key blocks, the inner loop walks query blocks at or after the key block (causal).
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    off_h = off_hz % H
    # offset pointers for batch/head
    # NOTE(review): all tensors are offset with Q's strides — assumes Q/K/V/dO/dQ/dK/dV
    # share one layout (true for the contiguous tensors _attention allocates).
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_qz + off_h * stride_qh
    V += off_z * stride_qz + off_h * stride_qh
    DO += off_z * stride_qz + off_h * stride_qh
    DQ += off_z * stride_qz + off_h * stride_qh
    DK += off_z * stride_qz + off_h * stride_qh
    DV += off_z * stride_qz + off_h * stride_qh
    for start_n in range(0, num_block):
        lo = start_n * BLOCK_M
        # initialize row/col offsets
        offs_qm = lo + tl.arange(0, BLOCK_M)
        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_m = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_DMODEL)
        # initialize pointers to value-like data
        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        # pointer to row-wise quantities in value-like data
        D_ptrs = D + off_hz * N_CTX
        m_ptrs = M + off_hz * N_CTX
        # initialize dv amd dk
        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        # k and v stay in SRAM throughout
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
        # loop over rows (query blocks from `lo` onward; earlier rows are masked out by causality)
        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
            offs_m_curr = start_m + offs_m
            # load q, k, v, do on-chip
            q = tl.load(q_ptrs)
            # recompute p = softmax(qk, dim=-1).T
            # NOTE: `do` is pre-divided by `l`; no normalization here
            qk = tl.dot(q, k, trans_b=True)
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
            m = tl.load(m_ptrs + offs_m_curr)
            p = tl.exp(qk * sm_scale - m[:, None])
            # compute dv
            do = tl.load(do_ptrs)
            dv += tl.dot(p.to(do.dtype), do, trans_a=True)
            # compute dp = dot(v, do)
            Di = tl.load(D_ptrs + offs_m_curr)
            # Seed dp with -Di so the delta subtraction is fused into the accumulation.
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
            dp += tl.dot(do, v, trans_b=True)
            # compute ds = p * (dp - delta[:, None])
            ds = p * dp * sm_scale
            # compute dk = dot(ds.T, q)
            dk += tl.dot(ds.to(q.dtype), q, trans_a=True)
            # # compute dq (read-modify-write: dq accumulates across key blocks)
            dq = tl.load(dq_ptrs, eviction_policy="evict_last")
            dq += tl.dot(ds.to(k.dtype), k)
            tl.store(dq_ptrs, dq, eviction_policy="evict_last")
            # # increment pointers
            dq_ptrs += BLOCK_M * stride_qm
            q_ptrs += BLOCK_M * stride_qm
            do_ptrs += BLOCK_M * stride_qm
        # write-back
        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        tl.store(dv_ptrs, dv)
        tl.store(dk_ptrs, dk)
246
+
247
+
248
class _attention(torch.autograd.Function):
    """Autograd wrapper for the OG tutorial fused-attention kernels.

    Inputs q, k, v are (batch, nheads, seqlen, headdim); attention is always causal.
    """

    @staticmethod
    def forward(ctx, q, k, v, sm_scale):
        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q)
        # Grid: (query blocks, batch * nheads).
        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
        # tmp is the compiler-bug scratchpad; L and m hold the softmax denom / running max.
        tmp = torch.empty(
            (q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
        )
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8

        _fwd_kernel[grid](
            q,
            k,
            v,
            sm_scale,
            tmp,
            L,
            m,
            o,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            k.stride(3),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            v.stride(3),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),
            q.shape[0],
            q.shape[1],
            q.shape[2],
            BLOCK_M=BLOCK,
            BLOCK_N=BLOCK,
            BLOCK_DMODEL=Lk,
            num_warps=num_warps,
            num_stages=1,
        )
        ctx.save_for_backward(q, k, v, o, L, m)
        ctx.BLOCK = BLOCK
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, l, m = ctx.saved_tensors
        do = do.contiguous()
        # dq must start at zero: the kernel accumulates into it with read-modify-write.
        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        do_scaled = torch.empty_like(do)
        delta = torch.empty_like(l)
        # Prologue: do_scaled = do / l, delta = rowsum(o * do_scaled).
        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](
            o,
            do,
            l,
            do_scaled,
            delta,
            BLOCK_M=ctx.BLOCK,
            D_HEAD=ctx.BLOCK_DMODEL,
        )

        # NOTE: kernel currently buggy for other values of `num_warps`
        num_warps = 8
        _bwd_kernel[(ctx.grid[1],)](
            q,
            k,
            v,
            ctx.sm_scale,
            o,
            do_scaled,
            dq,
            dk,
            dv,
            l,
            m,
            delta,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            k.stride(3),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            v.stride(3),
            q.shape[0],
            q.shape[1],
            q.shape[2],
            ctx.grid[0],
            BLOCK_M=ctx.BLOCK,
            BLOCK_N=ctx.BLOCK,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL,
            num_warps=num_warps,
            num_stages=1,
        )
        # dq was accumulated in fp32; cast back to the input dtype.
        return dq.to(q.dtype), dk, dv, None


attention = _attention.apply
.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_blocksparse_attention.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import hydra
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+
8
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
9
+ from flash_attn.flash_blocksparse_attn_interface import (
10
+ convert_blockmask,
11
+ flash_blocksparse_attn_func,
12
+ )
13
+
14
+
15
class FlashBlocksparseAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
    """

    def __init__(
        self,
        sparsity_config,
        softmax_temp=None,
        attention_dropout=0.0,
        max_seq_length=2048,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.sparsity_config = hydra.utils.instantiate(sparsity_config)
        self.softmax_temp = softmax_temp
        self.dropout_p = attention_dropout

        # initialize sparse layout and register as buffer
        # (rounded up to a multiple of 256, the layout's column block size)
        max_seq_length = ((max_seq_length + 256 - 1) // 256) * 256
        layout = self.sparsity_config.make_layout(max_seq_length)
        self.register_buffer("layout", layout)
        # Pre-convert the mask once for the convert_mask=False fast path.
        blockmask_converted = convert_blockmask(self.layout, causal=False)
        self.register_buffer("blockmask_converted", blockmask_converted)
        # logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}')

    def forward(
        self,
        qkv,
        attn_mask=None,
        key_padding_mask=None,
        causal=False,
        cu_seqlens=None,
        max_s=None,
        need_weights=False,
        convert_mask=True,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            key_padding_mask: An implementation of BaseMask that encodes how
                              many query each sequence in the batch consists of
        """
        assert not need_weights
        assert attn_mask is None
        assert qkv.dtype == torch.float16
        assert qkv.is_cuda

        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            # Convert mask to take a subset
            seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
            # BUGFIX: previously written as `assert A, (B)`, which made the second
            # bound the assert *message* and never actually checked it.
            assert seqlen_rounded // 16 <= self.layout.shape[0]
            assert seqlen_rounded // 256 <= self.layout.shape[1]
            blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
            if key_padding_mask is None:
                # No padding: flatten batch into one packed sequence of equal lengths.
                qkv = rearrange(qkv, "b s ... -> (b s) ...")
                max_s = seqlen
                cu_seqlens = torch.arange(
                    0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device
                )
                output = flash_blocksparse_attn_func(
                    qkv,
                    cu_seqlens,
                    blockmask,
                    self.dropout_p if self.training else 0.0,
                    max_s,
                    softmax_scale=self.softmax_temp,
                    causal=causal,
                )
                output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
            else:
                # Padded batch: unpad, run on the packed tokens, then re-pad.
                key_padding_mask_bool = key_padding_mask.bool_matrix
                nheads = qkv.shape[-2]
                x = rearrange(qkv, "b s three h d -> b s (three h d)")
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
                x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
                output_unpad = flash_blocksparse_attn_func(
                    x_unpad,
                    cu_seqlens,
                    blockmask,
                    self.dropout_p if self.training else 0.0,
                    max_s,
                    softmax_scale=self.softmax_temp,
                    causal=causal,
                )
                output = rearrange(
                    pad_input(
                        rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
                    ),
                    "b s (h d) -> b s h d",
                    h=nheads,
                )
        else:
            # Caller already provides packed sequences (varlen path).
            assert max_s is not None
            seqlen = max_s
            # Convert mask to take a subset
            seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
            # BUGFIX: same as above — both bounds must be real assertions.
            assert seqlen_rounded // 16 <= self.layout.shape[0]
            assert seqlen_rounded // 256 <= self.layout.shape[1]
            blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
            if convert_mask:
                output = flash_blocksparse_attn_func(
                    qkv,
                    cu_seqlens,
                    blockmask,
                    self.dropout_p if self.training else 0.0,
                    max_s,
                    softmax_scale=self.softmax_temp,
                    causal=causal,
                )
            else:
                # Skip per-call conversion by using the pre-converted full mask.
                output = flash_blocksparse_attn_func(
                    qkv,
                    cu_seqlens,
                    self.blockmask_converted,
                    self.dropout_p if self.training else 0.0,
                    max_s,
                    softmax_scale=self.softmax_temp,
                    causal=causal,
                    convert_mask=False,
                )

        return output, None
152
+
153
+
154
class FlashBlocksparseMHA(nn.Module):
    """Multi-head attention built on FlashBlocksparseAttention: a fused QKV
    projection, the block-sparse attention core, and an output projection.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        sparsity_config,
        bias=True,
        batch_first=True,
        attention_dropout=0.0,
        causal=False,
        max_seq_length=2048,
        device=None,
        dtype=None,
        **kwargs,
    ) -> None:
        assert batch_first
        super().__init__()
        tensor_kwargs = {"device": device, "dtype": dtype}
        self.embed_dim = embed_dim
        self.causal = causal

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"

        # One fused projection produces q, k and v in a single matmul.
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **tensor_kwargs)
        self.inner_attn = FlashBlocksparseAttention(
            sparsity_config,
            attention_dropout=attention_dropout,
            max_seq_length=max_seq_length,
            **tensor_kwargs,
        )
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **tensor_kwargs)

    def forward(
        self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None, need_weights=False
    ):
        """x: (batch, seqlen, embed_dim). The x_ignored_* arguments are unused
        (kept for signature compatibility with MultiheadAttention-style callers).
        """
        packed = rearrange(
            self.Wqkv(x), "b s (three h d) -> b s three h d", three=3, h=self.num_heads
        )
        context, attn_weights = self.inner_attn(
            packed, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
        )
        merged = rearrange(context, "b s h d -> b s (h d)")
        return self.out_proj(merged), attn_weights
.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_blocksparse_attn_interface.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
2
+ import flash_attn_cuda
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
def convert_blockmask(blockmask, causal):
    """Convert from the 0-1 format to the format used by the CUDA code.
    0 means the block is skipped.
    nonzero means the block is not skipped.
    Argument:
        blockmask: (row, col): a 0-1 tensor
    Return:
        blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row
        indices of the nonzero blocks, padded with -1 to reach length @row.
        The indices are multiplied by 4, with the smallest bit used to encode whether
        it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is
        the last nonzero in its row..
    """
    assert not causal
    # TD [2022-05-13]: The indexing and sorting is very tricky
    nrow, ncol = blockmask.shape
    # Sort does not support bool on CUDA
    blockmask = blockmask.to(dtype=torch.uint8)
    # Stable descending sort pushes nonzero rows to the front of each column.
    nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True)
    # Inverse permutation: original row -> position after sorting.
    nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0)
    # Column index of the last nonzero in each row (ascending stable sort leaves it at [:, -1]).
    last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1]
    last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row
    ]
    # Column index of the first nonzero in each row.
    first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0]
    first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row
    ]
    # Encode index * 4, with bit 1 = "last nonzero in row", bit 0 = "first nonzero in row".
    nonzero_idx = nonzero_sorted_rowidx * 4
    nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2
    nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1
    # Skipped blocks are marked -1.
    nonzero_idx[nonzero_val == 0] = -1
    return nonzero_idx.T.contiguous().to(dtype=torch.int32)
40
+
41
+
42
def _flash_blocksparse_attn_forward(
    qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax
):
    """Thin wrapper over the CUDA block-sparse forward op.

    Returns (context, softmax_lse, S_dmask); S_dmask is None unless return_softmax.
    The trailing None passed to fwd_block is an optional output/generator slot.
    """
    context, softmax_lse, *rest = flash_attn_cuda.fwd_block(
        qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax, None
    )
    # if context.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    S_dmask = rest[0] if return_softmax else None
    return context, softmax_lse, S_dmask
52
+
53
+
54
def _flash_blocksparse_attn_backward(
    dout,
    qkv,
    out,
    S_dmask,
    softmax_lse,
    cu_seqlens,
    blockmask,
    dropout_p,
    max_s,
    softmax_scale,
    causal,
):
    """Thin wrapper over the CUDA block-sparse backward op; returns dqkv.

    Note: the CUDA op takes softmax_scale *before* max_s, unlike this function's
    own parameter order — keep the call below as-is.
    """
    dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(
        dout,
        qkv,
        out,
        S_dmask,
        softmax_lse,
        cu_seqlens,
        blockmask,
        dropout_p,
        softmax_scale,
        max_s,
        causal,
        None,
    )
    # if dqkv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dqkv
84
+
85
+
86
class FlashBlocksparseAttnFun(torch.autograd.Function):
    """Block-sparse FlashAttention autograd wrapper (output only, no softmax returned)."""

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
            qkv,
            cu_seqlens,
            blockmask,
            dropout_p,
            max_s,
            softmax_scale,
            causal=causal,
            return_softmax=False,
        )
        # S_dmask is None here (return_softmax=False); saved anyway to keep one layout.
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context

    @staticmethod
    def backward(ctx, dout):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
        # Restore the forward RNG state so dropout regenerates the identical mask.
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        # S_dmask is None, temporarily use another tensor just to get it running
        dqkv = _flash_blocksparse_attn_backward(
            dout,
            qkv,
            context,
            context,
            softmax_lse,
            cu_seqlens,
            blockmask,
            ctx.dropout_p,
            ctx.max_s,
            ctx.softmax_scale,
            ctx.causal,
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None, None
133
+
134
+
135
+ # We duplicate code to return both the output and the softmax for testing
136
+ # Returning both makes backward a bit slower, so we want to keep using the other version for speed.
137
class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
    """Variant of FlashBlocksparseAttnFun that also returns (S_dmask, softmax_lse).

    Used for testing; the extra outputs make backward slightly slower.
    """

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass is gonna regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
            qkv,
            cu_seqlens,
            blockmask,
            dropout_p,
            max_s,
            softmax_scale,
            causal=causal,
            return_softmax=True,
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context, S_dmask, softmax_lse

    @staticmethod
    def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
        # Restore the forward RNG state so dropout regenerates the identical mask.
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        dqkv = _flash_blocksparse_attn_backward(
            dout,
            qkv,
            context,
            S_dmask,
            softmax_lse,
            cu_seqlens,
            blockmask,
            ctx.dropout_p,
            ctx.max_s,
            ctx.softmax_scale,
            ctx.causal,
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None
183
+
184
+
185
def flash_blocksparse_attn_func(
    qkv,
    cu_seqlens,
    blockmask,
    dropout_p,
    max_s,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
    convert_mask=True,
):
    """dropout_p should be set to 0.0 during evaluation"""
    # Pick the autograd function: the WithS variant additionally returns the
    # attention probabilities (slower backward); the plain one returns only the output.
    if return_attn_probs:
        autograd_fn = FlashBlocksparseAttnFunWithS
    else:
        autograd_fn = FlashBlocksparseAttnFun
    mask = convert_blockmask(blockmask, causal=causal) if convert_mask else blockmask
    return autograd_fn.apply(qkv, cu_seqlens, mask, dropout_p, max_s, softmax_scale, causal)
.venv/lib/python3.11/site-packages/xformers/_flash_attn/fused_softmax.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # [2022-10-23] Copied from https://github.com/NVIDIA/apex/blob/master/apex/transformer/functional/fused_softmax.py
2
+ # for benchmarking.
3
+ # We added support for seqlen=2k and seqlen=4k
4
+
5
+ # coding=utf-8
6
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ import torch
20
+ from apex._autocast_utils import _cast_if_autocast_enabled
21
+ from apex.transformer.enums import AttnMaskType
22
+ from fused_softmax_lib import (
23
+ scaled_masked_softmax_backward,
24
+ scaled_masked_softmax_forward,
25
+ scaled_masked_softmax_get_batch_per_block,
26
+ scaled_upper_triang_masked_softmax_backward,
27
+ scaled_upper_triang_masked_softmax_forward,
28
+ )
29
+
30
+
31
+ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
32
+ """
33
+ Fused operation which performs following three operations in sequence
34
+ 1. Scale the tensor.
35
+ 2. Apply upper triangular mask (typically used in gpt models).
36
+ 3. Perform softmax.
37
+ """
38
+
39
+ @staticmethod
40
+ def forward(ctx, inputs, scale):
41
+ scale_t = torch.tensor([scale])
42
+ softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])
43
+ ctx.save_for_backward(softmax_results, scale_t)
44
+ return softmax_results
45
+
46
+ @staticmethod
47
+ def backward(ctx, output_grads):
48
+ softmax_results, scale_t = ctx.saved_tensors
49
+ input_grads = scaled_upper_triang_masked_softmax_backward(
50
+ output_grads, softmax_results, scale_t[0]
51
+ )
52
+ return input_grads, None
53
+
54
+
55
+ def scaled_upper_triang_masked_softmax(inputs, _, scale):
56
+ b, np, sq, sk = inputs.size()
57
+ assert sq == sk, "causal mask is only for self attention"
58
+ # Reshaping input to 3D tensor (attn_batches, sq, sk)
59
+ inputs = inputs.view(-1, sq, sk)
60
+ args = _cast_if_autocast_enabled(inputs, scale)
61
+ with torch.cuda.amp.autocast(enabled=False):
62
+ probs = ScaledUpperTriangMaskedSoftmax.apply(*args)
63
+ return probs.view(b, np, sq, sk)
64
+
65
+
66
+ # NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`.
67
+ # Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context.
68
+ # So I needed to manually write two `torch.autograd.Function` inheritances.
69
+ # Fused operation which performs following three operations in sequence
70
+ # 1. Scale the tensor.
71
+ # 2. Apply the mask.
72
+ # 3. Perform softmax.
73
+ class ScaledMaskedSoftmax(torch.autograd.Function):
74
+ @staticmethod
75
+ def forward(ctx, inputs, mask, scale):
76
+ scale_t = torch.tensor([scale])
77
+ softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0])
78
+ ctx.save_for_backward(softmax_results, scale_t)
79
+ return softmax_results
80
+
81
+ @staticmethod
82
+ def backward(ctx, output_grads):
83
+ softmax_results, scale_t = ctx.saved_tensors
84
+ input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])
85
+ return input_grads, None, None
86
+
87
+
88
+ def scaled_masked_softmax(inputs, mask, scale):
89
+ # input is 4D tensor (b, np, sq, sk)
90
+ args = _cast_if_autocast_enabled(inputs, mask, scale)
91
+ with torch.cuda.amp.autocast(enabled=False):
92
+ return ScaledMaskedSoftmax.apply(*args)
93
+
94
+
95
+ class FusedScaleMaskSoftmax(torch.nn.Module):
96
+ """
97
+ fused operation: scaling + mask + softmax
98
+
99
+ Arguments:
100
+ input_in_fp16: flag to indicate if input in fp16 data format.
101
+ input_in_bf16: flag to indicate if input in bf16 data format.
102
+ attn_mask_type: attention mask type (pad or causal)
103
+ scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
104
+ mask_func: mask function to be applied.
105
+ softmax_in_fp32: if true, softmax in performed at fp32 precision.
106
+ scale: scaling factor used in input tensor scaling.
107
+ """
108
+
109
+ def __init__(
110
+ self,
111
+ input_in_fp16,
112
+ input_in_bf16,
113
+ attn_mask_type,
114
+ scaled_masked_softmax_fusion,
115
+ mask_func,
116
+ softmax_in_fp32,
117
+ scale,
118
+ ):
119
+ super().__init__()
120
+ self.input_in_fp16 = input_in_fp16
121
+ self.input_in_bf16 = input_in_bf16
122
+ if self.input_in_fp16 and self.input_in_bf16:
123
+ raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.")
124
+ self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
125
+ self.attn_mask_type = attn_mask_type
126
+ self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
127
+ self.mask_func = mask_func
128
+ self.softmax_in_fp32 = softmax_in_fp32
129
+ self.scale = scale
130
+
131
+ if not (self.scale is None or softmax_in_fp32):
132
+ raise RuntimeError("softmax should be in fp32 when scaled")
133
+
134
+ if self.scaled_masked_softmax_fusion:
135
+ if self.attn_mask_type == AttnMaskType.causal:
136
+ self.fused_softmax_func = scaled_upper_triang_masked_softmax
137
+ elif self.attn_mask_type == AttnMaskType.padding:
138
+ self.fused_softmax_func = scaled_masked_softmax
139
+ else:
140
+ raise ValueError("Invalid attn_mask_type.")
141
+
142
+ def forward(self, input, mask):
143
+ # [b, np, sq, sk]
144
+ assert input.dim() == 4
145
+
146
+ if self.is_kernel_available(mask, *input.size()):
147
+ return self.forward_fused_softmax(input, mask)
148
+ else:
149
+ return self.forward_torch_softmax(input, mask)
150
+
151
+ def is_kernel_available(self, mask, b, np, sq, sk):
152
+ attn_batches = b * np
153
+
154
+ if (
155
+ self.scaled_masked_softmax_fusion # user want to fuse
156
+ and self.input_in_float16 # input must be fp16
157
+ and (
158
+ self.attn_mask_type == AttnMaskType.causal
159
+ or (self.attn_mask_type == AttnMaskType.padding and mask is not None)
160
+ )
161
+ and 16 < sk <= 8192 # sk must be 16 ~ 8192
162
+ and sq % 4 == 0 # sq must be divisor of 4
163
+ and sk % 4 == 0 # sk must be divisor of 4
164
+ and attn_batches % 4 == 0 # np * b must be divisor of 4
165
+ ):
166
+ if 0 <= sk <= 8192:
167
+ batch_per_block = self.get_batch_per_block(sq, sk, b, np)
168
+
169
+ if self.attn_mask_type == AttnMaskType.causal:
170
+ if attn_batches % batch_per_block == 0:
171
+ return True
172
+ else:
173
+ if sq % batch_per_block == 0:
174
+ return True
175
+ return False
176
+
177
+ def forward_fused_softmax(self, input, mask):
178
+ # input.shape = [b, np, sq, sk]
179
+ scale = self.scale if self.scale is not None else 1.0
180
+ return self.fused_softmax_func(input, mask, scale)
181
+
182
+ def forward_torch_softmax(self, input, mask):
183
+ if self.input_in_float16 and self.softmax_in_fp32:
184
+ input = input.float()
185
+
186
+ if self.scale is not None:
187
+ input = input * self.scale
188
+ mask_output = self.mask_func(input, mask) if mask is not None else input
189
+ probs = torch.nn.Softmax(dim=-1)(mask_output)
190
+
191
+ if self.input_in_float16 and self.softmax_in_fp32:
192
+ if self.input_in_fp16:
193
+ probs = probs.half()
194
+ else:
195
+ probs = probs.bfloat16()
196
+
197
+ return probs
198
+
199
+ @staticmethod
200
+ def get_batch_per_block(sq, sk, b, np):
201
+ return scaled_masked_softmax_get_batch_per_block(sq, sk, b, np)
.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (201 Bytes). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/block.cpython-311.pyc ADDED
Binary file (16.2 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/embedding.cpython-311.pyc ADDED
Binary file (10.3 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/mha.cpython-311.pyc ADDED
Binary file (41.9 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/mlp.cpython-311.pyc ADDED
Binary file (7.81 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/block.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ from functools import partial
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+ from torchvision.ops import StochasticDepth
11
+
12
+ from flash_attn.modules.mha import MHA
13
+ from flash_attn.modules.mlp import Mlp
14
+
15
+ try:
16
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
17
+ except ImportError:
18
+ layer_norm_fn, RMSNorm = None, None
19
+
20
+
21
+ class Block(nn.Module):
22
+ def __init__(
23
+ self,
24
+ dim,
25
+ mixer_cls=None,
26
+ mlp_cls=None,
27
+ norm_cls=nn.LayerNorm,
28
+ dropout_cls=nn.Dropout,
29
+ prenorm=True,
30
+ resid_dropout1=0.0,
31
+ resid_dropout2=0.0,
32
+ drop_path1=0.0,
33
+ drop_path2=0.0,
34
+ fused_dropout_add_ln=False,
35
+ return_residual=False,
36
+ residual_in_fp32=False,
37
+ sequence_parallel=False,
38
+ mark_shared_params=False,
39
+ ):
40
+ """
41
+ For prenorm=True, this Block has a slightly different structure compared to a regular
42
+ prenorm Transformer block.
43
+ The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
44
+ [Ref: https://arxiv.org/abs/2002.04745]
45
+ Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
46
+ the hidden_states (output of the MLP) and the residual.
47
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
48
+ The residual needs to be provided (except for the very first block).
49
+
50
+ For prenorm=False, this Block has the same structure as a regular postnorm Transformer
51
+ block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
52
+
53
+ return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
54
+ This is for performance reason: for post-norm architecture, returning the input allows us
55
+ to fuse the backward of nn.Linear with the residual connection.
56
+ """
57
+ super().__init__()
58
+ self.prenorm = prenorm
59
+ self.fused_dropout_add_ln = fused_dropout_add_ln
60
+ self.return_residual = return_residual
61
+ self.residual_in_fp32 = residual_in_fp32
62
+ if self.residual_in_fp32:
63
+ assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
64
+ if mixer_cls is None:
65
+ mixer_cls = partial(MHA, num_heads=dim // 64)
66
+ if mlp_cls is None:
67
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
68
+ self.mixer = mixer_cls(dim)
69
+ self.dropout1 = dropout_cls(resid_dropout1)
70
+ self.drop_path1 = StochasticDepth(drop_path1, mode="row")
71
+ self.norm1 = norm_cls(dim)
72
+ self.mlp = mlp_cls(dim)
73
+ if not isinstance(self.mlp, nn.Identity):
74
+ self.dropout2 = dropout_cls(resid_dropout2)
75
+ self.drop_path2 = StochasticDepth(drop_path2, mode="row")
76
+ self.norm2 = norm_cls(dim)
77
+
78
+ if self.fused_dropout_add_ln:
79
+ assert layer_norm_fn is not None, "Triton is not installed"
80
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
81
+ self.dropout1, nn.Dropout
82
+ )
83
+
84
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
85
+ # then the input to each worker in the tensor parallel group will be different.
86
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
87
+ # For now this is not an issue because we always use sequence_parallel=True during training
88
+ # and only use sequence_parallel=False during inference.
89
+
90
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
91
+ if sequence_parallel:
92
+ for p in self.norm1.parameters():
93
+ p._sequence_parallel = True
94
+ if hasattr(self, "norm2"):
95
+ for p in self.norm2.parameters():
96
+ p._sequence_parallel = True
97
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
98
+ if mark_shared_params:
99
+ for p in self.norm1.parameters():
100
+ p._shared_params = True
101
+ if hasattr(self, "norm2"):
102
+ for p in self.norm2.parameters():
103
+ p._shared_params = True
104
+
105
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
106
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
107
+
108
+ def forward(
109
+ self,
110
+ hidden_states: Tensor,
111
+ residual: Optional[Tensor] = None,
112
+ mixer_subset=None,
113
+ mixer_kwargs=None,
114
+ ):
115
+ r"""Pass the input through the encoder layer.
116
+
117
+ Args:
118
+ hidden_states: the sequence to the encoder layer (required).
119
+ residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
120
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
121
+ before applying the query projection. Useful for e.g., ViT where we only care
122
+ about the CLS token in the last layer.
123
+ """
124
+ if self.prenorm:
125
+ if not self.fused_dropout_add_ln:
126
+ dropped = self.drop_path1(self.dropout1(hidden_states))
127
+ residual = (dropped + residual) if residual is not None else dropped
128
+ hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
129
+ if self.residual_in_fp32:
130
+ residual = residual.to(torch.float32)
131
+ else:
132
+ if self.drop_path1.p == 0 or not self.training:
133
+ rowscale1 = None
134
+ else:
135
+ rowscale1 = self.drop_path1(
136
+ torch.ones(
137
+ hidden_states.shape[:-1],
138
+ device=hidden_states.device,
139
+ dtype=hidden_states.dtype,
140
+ )
141
+ )
142
+ hidden_states, residual = layer_norm_fn(
143
+ hidden_states,
144
+ self.norm1.weight,
145
+ self.norm1.bias,
146
+ residual=residual,
147
+ eps=self.norm1.eps,
148
+ dropout_p=self.dropout1.p if self.training else 0.0,
149
+ rowscale=rowscale1,
150
+ prenorm=True,
151
+ residual_in_fp32=self.residual_in_fp32,
152
+ is_rms_norm=isinstance(self.norm1, RMSNorm)
153
+ )
154
+ if mixer_kwargs is None:
155
+ mixer_kwargs = {}
156
+ if mixer_subset is not None:
157
+ mixer_kwargs["mixer_subset"] = mixer_subset
158
+ hidden_states = self.mixer(hidden_states, **mixer_kwargs)
159
+ if mixer_subset is not None:
160
+ residual = residual[:, mixer_subset]
161
+ if not isinstance(self.mlp, nn.Identity):
162
+ if not self.fused_dropout_add_ln:
163
+ dropped = self.drop_path2(self.dropout2(hidden_states))
164
+ residual = (dropped + residual) if residual is not None else dropped
165
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
166
+ if self.residual_in_fp32:
167
+ residual = residual.to(torch.float32)
168
+ else:
169
+ if self.drop_path2.p == 0 or not self.training:
170
+ rowscale2 = None
171
+ else:
172
+ rowscale2 = self.drop_path2(
173
+ torch.ones(
174
+ hidden_states.shape[:-1],
175
+ device=hidden_states.device,
176
+ dtype=hidden_states.dtype,
177
+ )
178
+ )
179
+ hidden_states, residual = layer_norm_fn(
180
+ hidden_states,
181
+ self.norm2.weight,
182
+ self.norm2.bias,
183
+ residual=residual,
184
+ eps=self.norm2.eps,
185
+ dropout_p=self.dropout2.p if self.training else 0.0,
186
+ rowscale=rowscale2,
187
+ prenorm=True,
188
+ residual_in_fp32=self.residual_in_fp32,
189
+ is_rms_norm=isinstance(self.norm2, RMSNorm)
190
+ )
191
+ hidden_states = self.mlp(hidden_states)
192
+ return hidden_states, residual
193
+ else:
194
+ assert residual is None
195
+ mixer_out = self.mixer(
196
+ hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
197
+ )
198
+ if self.return_residual: # mixer out is actually a pair here
199
+ mixer_out, hidden_states = mixer_out
200
+ if not self.fused_dropout_add_ln:
201
+ hidden_states = self.norm1(
202
+ (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
203
+ dtype=self.norm1.weight.dtype
204
+ )
205
+ )
206
+ else:
207
+ if self.drop_path1.p == 0 or not self.training:
208
+ rowscale1 = None
209
+ else:
210
+ rowscale1 = self.drop_path1(
211
+ torch.ones(
212
+ mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
213
+ )
214
+ )
215
+ hidden_states = layer_norm_fn(
216
+ mixer_out,
217
+ self.norm1.weight,
218
+ self.norm1.bias,
219
+ residual=hidden_states,
220
+ eps=self.norm1.eps,
221
+ dropout_p=self.dropout1.p if self.training else 0.0,
222
+ rowscale=rowscale1,
223
+ prenorm=False,
224
+ is_rms_norm=isinstance(self.norm1, RMSNorm)
225
+ )
226
+ if not isinstance(self.mlp, nn.Identity):
227
+ mlp_out = self.mlp(hidden_states)
228
+ if self.return_residual: # mlp out is actually a pair here
229
+ mlp_out, hidden_states = mlp_out
230
+ if not self.fused_dropout_add_ln:
231
+ hidden_states = self.norm2(
232
+ (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
233
+ dtype=self.norm2.weight.dtype
234
+ )
235
+ )
236
+ else:
237
+ if self.drop_path2.p == 0 or not self.training:
238
+ rowscale2 = None
239
+ else:
240
+ rowscale2 = self.drop_path2(
241
+ torch.ones(
242
+ mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
243
+ )
244
+ )
245
+ hidden_states = layer_norm_fn(
246
+ mlp_out,
247
+ self.norm2.weight,
248
+ self.norm2.bias,
249
+ residual=hidden_states,
250
+ eps=self.norm2.eps,
251
+ dropout_p=self.dropout2.p if self.training else 0.0,
252
+ rowscale=rowscale2,
253
+ prenorm=False,
254
+ is_rms_norm=isinstance(self.norm2, RMSNorm)
255
+ )
256
+ return hidden_states
257
+
258
+
259
+ class ParallelBlock(nn.Module):
260
+ """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
261
+ and PaLM.
262
+ """
263
+
264
+ def __init__(
265
+ self,
266
+ dim,
267
+ mixer_cls=None,
268
+ mlp_cls=None,
269
+ norm_cls=nn.LayerNorm,
270
+ dropout_cls=nn.Dropout,
271
+ resid_dropout1=0.0,
272
+ resid_dropout2=0.0,
273
+ tied_norm=False,
274
+ fused_dropout_add_ln=False,
275
+ residual_in_fp32=False,
276
+ sequence_parallel=False,
277
+ mark_shared_params=False,
278
+ ):
279
+ """
280
+ This Block has a slightly different structure compared to a regular
281
+ prenorm Transformer block.
282
+ The standard block is: LN -> MHA / MLP -> Dropout -> Add.
283
+ [Ref: https://arxiv.org/abs/2002.04745]
284
+ Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
285
+ the hidden_states (output1 of the MHA / MLP) and the residual.
286
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
287
+ The residual needs to be provided (except for the very first block).
288
+ """
289
+ super().__init__()
290
+ self.tied_norm = tied_norm
291
+ self.fused_dropout_add_ln = fused_dropout_add_ln
292
+ self.residual_in_fp32 = residual_in_fp32
293
+ if mixer_cls is None:
294
+ mixer_cls = partial(MHA, num_heads=dim // 64)
295
+ if mlp_cls is None:
296
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
297
+ self.mixer = mixer_cls(dim)
298
+ self.dropout1 = dropout_cls(resid_dropout1)
299
+ self.norm1 = norm_cls(dim)
300
+ self.mlp = mlp_cls(dim)
301
+ self.dropout2 = dropout_cls(resid_dropout2)
302
+ if not self.tied_norm:
303
+ self.norm2 = norm_cls(dim)
304
+
305
+ if self.fused_dropout_add_ln:
306
+ assert layer_norm_fn is not None, "Triton is not installed"
307
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
308
+ self.dropout1, nn.Dropout
309
+ )
310
+
311
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
312
+ # then the input to each worker in the tensor parallel group will be different.
313
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
314
+ # For now this is not an issue because we always use sequence_parallel=True during training
315
+ # and only use sequence_parallel=False during inference.
316
+
317
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
318
+ if sequence_parallel:
319
+ for p in self.norm1.parameters():
320
+ p._sequence_parallel = True
321
+ if hasattr(self, "norm2"):
322
+ for p in self.norm2.parameters():
323
+ p._sequence_parallel = True
324
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
325
+ if mark_shared_params:
326
+ for p in self.norm1.parameters():
327
+ p._shared_params = True
328
+ if hasattr(self, "norm2"):
329
+ for p in self.norm2.parameters():
330
+ p._shared_params = True
331
+
332
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
333
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
334
+
335
+ def forward(
336
+ self,
337
+ hidden_states1: Tensor,
338
+ hidden_states2: Optional[Tensor] = None,
339
+ residual: Optional[Tensor] = None,
340
+ mixer_kwargs=None,
341
+ ):
342
+ r"""Pass the input through the encoder layer.
343
+
344
+ Args:
345
+ hidden_states1: the output of the previous attention (mixer) or embedding layer.
346
+ hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
347
+ residual.
348
+ """
349
+ # TODO: Ideally we should only do the allgather / allreduce once for
350
+ # the Linear to MLP & Attention
351
+ if not self.fused_dropout_add_ln:
352
+ dropped1 = self.dropout1(hidden_states1)
353
+ # For the very 1st block, we only want 1 dropout, not two different dropouts
354
+ if hidden_states2 is not None:
355
+ dropped2 = self.dropout2(hidden_states2)
356
+ residual = (
357
+ (residual + dropped1 + dropped2)
358
+ if residual is not None
359
+ else dropped1 + dropped2
360
+ )
361
+ else:
362
+ residual = (residual + dropped1) if residual is not None else dropped1
363
+ hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
364
+ hidden_states2 = (
365
+ self.norm2(residual.to(dtype=self.norm2.weight.dtype))
366
+ if not self.tied_norm
367
+ else hidden_states1
368
+ )
369
+ if self.residual_in_fp32:
370
+ residual = residual.to(torch.float32)
371
+ else:
372
+ weight2, bias2 = (
373
+ (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
374
+ )
375
+ hidden_states1, *rest, residual = layer_norm_fn(
376
+ hidden_states1,
377
+ self.norm1.weight,
378
+ self.norm1.bias,
379
+ residual=residual,
380
+ x1=hidden_states2,
381
+ weight1=weight2,
382
+ bias1=bias2,
383
+ eps=self.norm1.eps,
384
+ dropout_p=self.dropout1.p if self.training else 0.0,
385
+ prenorm=True,
386
+ residual_in_fp32=self.residual_in_fp32,
387
+ is_rms_norm=isinstance(self.norm1, RMSNorm)
388
+ )
389
+ if self.tied_norm:
390
+ hidden_states2 = hidden_states1
391
+ else:
392
+ hidden_states2, = rest
393
+ if mixer_kwargs is None:
394
+ mixer_kwargs = {}
395
+ hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
396
+ hidden_states2 = self.mlp(hidden_states2)
397
+ return hidden_states1, hidden_states2, residual
.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/embedding.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, Tri Dao.
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from torch import Tensor
7
+
8
+ from flash_attn.utils.distributed import all_reduce, reduce_scatter
9
+
10
+
11
+ class GPT2Embeddings(nn.Module):
12
+ def __init__(
13
+ self,
14
+ embed_dim,
15
+ vocab_size,
16
+ max_position_embeddings,
17
+ padding_idx=None,
18
+ word_embed_proj_dim=None,
19
+ device=None,
20
+ dtype=None,
21
+ ):
22
+ """
23
+ If max_position_embeddings <= 0, there's no position embeddings
24
+ If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
25
+ the project up to embed_dim
26
+ """
27
+ factory_kwargs = {"device": device, "dtype": dtype}
28
+ super().__init__()
29
+ if word_embed_proj_dim is None:
30
+ self.word_embeddings = nn.Embedding(
31
+ vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
32
+ )
33
+ self.project_in = None
34
+ else:
35
+ self.word_embeddings = nn.Embedding(
36
+ vocab_size, word_embed_proj_dim, padding_idx=padding_idx, **factory_kwargs
37
+ )
38
+ self.project_in = nn.Linear(
39
+ word_embed_proj_dim, embed_dim, bias=False, **factory_kwargs
40
+ )
41
+ self.max_position_embeddings = max_position_embeddings
42
+ if self.max_position_embeddings > 0:
43
+ self.position_embeddings = nn.Embedding(
44
+ max_position_embeddings, embed_dim, **factory_kwargs
45
+ )
46
+
47
+ def forward(self, input_ids, position_ids=None):
48
+ """
49
+ input_ids: (batch, seqlen)
50
+ position_ids: (batch, seqlen)
51
+ """
52
+ batch_size, seqlen = input_ids.shape
53
+ embeddings = self.word_embeddings(input_ids)
54
+ if self.project_in is not None:
55
+ embeddings = self.project_in(embeddings)
56
+ if self.max_position_embeddings > 0:
57
+ if position_ids is None:
58
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
59
+ position_embeddings = self.position_embeddings(position_ids)
60
+ embeddings = embeddings + position_embeddings
61
+ return embeddings
62
+
63
+
64
+ class BertEmbeddings(nn.Module):
65
+ def __init__(
66
+ self,
67
+ embed_dim,
68
+ vocab_size,
69
+ max_position_embeddings,
70
+ type_vocab_size,
71
+ padding_idx=None,
72
+ device=None,
73
+ dtype=None,
74
+ ):
75
+ """
76
+ If max_position_embeddings <= 0, there's no position embeddings
77
+ If type_vocab_size <= 0, there's no token type embeddings
78
+ """
79
+ factory_kwargs = {"device": device, "dtype": dtype}
80
+ super().__init__()
81
+ self.word_embeddings = nn.Embedding(
82
+ vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
83
+ )
84
+ self.max_position_embeddings = max_position_embeddings
85
+ self.type_vocab_size = type_vocab_size
86
+ if self.max_position_embeddings > 0:
87
+ self.position_embeddings = nn.Embedding(
88
+ max_position_embeddings, embed_dim, **factory_kwargs
89
+ )
90
+ if self.type_vocab_size > 0:
91
+ self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
92
+
93
+ def forward(self, input_ids, position_ids=None, token_type_ids=None):
94
+ """
95
+ input_ids: (batch, seqlen)
96
+ position_ids: (batch, seqlen)
97
+ token_type_ids: (batch, seqlen)
98
+ """
99
+ batch_size, seqlen = input_ids.shape
100
+ embeddings = self.word_embeddings(input_ids)
101
+ if self.max_position_embeddings > 0:
102
+ if position_ids is None:
103
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
104
+ position_embeddings = self.position_embeddings(position_ids)
105
+ embeddings = embeddings + position_embeddings
106
+ if self.type_vocab_size > 0:
107
+ if token_type_ids is None:
108
+ token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
109
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
110
+ embeddings = embeddings + token_type_embeddings
111
+ return embeddings
112
+
113
+
114
+ class VocabParallelEmbedding(nn.Embedding):
115
+ def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
116
+ self.process_group = process_group
117
+ if process_group is not None:
118
+ world_size = torch.distributed.get_world_size(process_group)
119
+ if num_embeddings % world_size != 0:
120
+ raise ValueError(
121
+ f"num_embeddings ({num_embeddings}) must be divisible by "
122
+ f"world_size ({world_size})"
123
+ )
124
+ if world_size > 1 and padding_idx is not None:
125
+ raise RuntimeError("ParallelEmbedding does not support padding_idx")
126
+ else:
127
+ world_size = 1
128
+ super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
129
+
130
+ def forward(self, input: Tensor) -> Tensor:
131
+ if self.process_group is None:
132
+ return super().forward(input)
133
+ else:
134
+ rank = torch.distributed.get_rank(self.process_group)
135
+ vocab_size = self.num_embeddings
136
+ vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
137
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
138
+ input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
139
+ input = input - vocab_start_index
140
+ input[input_ids_mask] = 0
141
+ embeddings = super().forward(input)
142
+ embeddings[input_ids_mask] = 0.0
143
+ return embeddings
144
+
145
+
146
+ class ColumnParallelEmbedding(nn.Embedding):
147
+ def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
148
+ self.process_group = process_group
149
+ if process_group is not None:
150
+ world_size = torch.distributed.get_world_size(process_group)
151
+ if embedding_dim % world_size != 0:
152
+ raise ValueError(
153
+ f"embedding_dim ({embedding_dim}) must be divisible by "
154
+ f"world_size ({world_size})"
155
+ )
156
+ else:
157
+ world_size = 1
158
+ super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
159
+
160
+
161
+ class ParallelGPT2Embeddings(nn.Module):
162
+ def __init__(
163
+ self,
164
+ embed_dim,
165
+ vocab_size,
166
+ max_position_embeddings,
167
+ process_group,
168
+ padding_idx=None,
169
+ sequence_parallel=True,
170
+ device=None,
171
+ dtype=None,
172
+ ):
173
+ """
174
+ If max_position_embeddings <= 0, there's no position embeddings
175
+ """
176
+ factory_kwargs = {"device": device, "dtype": dtype}
177
+ super().__init__()
178
+ self.process_group = process_group
179
+ self.sequence_parallel = sequence_parallel
180
+ self.word_embeddings = VocabParallelEmbedding(
181
+ vocab_size,
182
+ embed_dim,
183
+ padding_idx=padding_idx,
184
+ process_group=process_group,
185
+ **factory_kwargs,
186
+ )
187
+ self.max_position_embeddings = max_position_embeddings
188
+ if self.max_position_embeddings > 0:
189
+ self.position_embeddings = ColumnParallelEmbedding(
190
+ max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
191
+ )
192
+
193
+ def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
194
+ """
195
+ input_ids: (batch, seqlen)
196
+ position_ids: (batch, seqlen)
197
+ """
198
+ batch_size, seqlen = input_ids.shape
199
+ world_size = torch.distributed.get_world_size(self.process_group)
200
+ embeddings = self.word_embeddings(input_ids)
201
+ if self.max_position_embeddings > 0:
202
+ if position_ids is None:
203
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
204
+ position_embeddings = self.position_embeddings(position_ids)
205
+ if world_size <= 1:
206
+ embeddings = embeddings + position_embeddings
207
+ else:
208
+ partition_dim = self.position_embeddings.embedding_dim
209
+ rank = torch.distributed.get_rank(self.process_group)
210
+ embeddings[
211
+ ..., rank * partition_dim : (rank + 1) * partition_dim
212
+ ] += position_embeddings
213
+ if combine_batch_seqlen_dim:
214
+ embeddings = rearrange(embeddings, "b s d -> (b s) d")
215
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
216
+ return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/mha.py ADDED
@@ -0,0 +1,1020 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ import math
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange, repeat
9
+
10
+ from flash_attn.utils.distributed import get_dim_for_local_rank
11
+
12
+ try:
13
+ from flash_attn import (
14
+ flash_attn_kvpacked_func,
15
+ flash_attn_qkvpacked_func,
16
+ flash_attn_varlen_kvpacked_func,
17
+ flash_attn_varlen_qkvpacked_func,
18
+ flash_attn_with_kvcache,
19
+ )
20
+ except ImportError:
21
+ flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
22
+ flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
23
+ flash_attn_with_kvcache = None
24
+
25
+ try:
26
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
27
+ except ImportError:
28
+ FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
29
+
30
+ try:
31
+ from flash_attn.layers.rotary import RotaryEmbedding
32
+ except ImportError:
33
+ RotaryEmbedding = None
34
+
35
+
36
+ # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
37
def get_alibi_slopes(nheads):
    """Return the ALiBi slope assigned to each of ``nheads`` attention heads.

    For a power-of-two head count the slopes form a geometric sequence starting at
    2^(-8/n). Otherwise, the slopes for the nearest lower power of two are extended
    with every other slope of the doubled head count, following the reference ALiBi
    implementation (Press et al.).
    """

    def _power_of_2_slopes(n):
        # First term of the geometric sequence; the ratio equals the first term.
        base = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [base * base**k for k in range(n)]

    if math.log2(nheads).is_integer():
        return _power_of_2_slopes(nheads)
    floor_pow2 = 2 ** math.floor(math.log2(nheads))
    # Interleave: take the "odd" slopes of the doubled power-of-two sequence to fill
    # the remaining heads.
    extra = get_alibi_slopes(2 * floor_pow2)[0::2][: nheads - floor_pow2]
    return _power_of_2_slopes(floor_pow2) + extra
51
+
52
+
53
class FlashSelfAttention(nn.Module):
    """Self-attention computed with the FlashAttention kernels (packed QKV input).

    Supports both padded (B, S, 3, H, D) and varlen/unpadded (total, 3, H, D)
    layouts, optional ALiBi biases, and sliding-window (local) attention.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at
                       runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        window_size: (left, right) sliding-window extents; (-1, -1) disables windowing.
        alibi_slopes: optional per-head ALiBi slopes tensor.
        deterministic: forwarded to the flash kernels' ``deterministic`` flag.
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        # Non-persistent buffer: follows .to()/.cuda() moves but stays out of state_dict.
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        # Flash kernels only support half precision on CUDA.
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            # presumably the kernels require fp32 slopes — cast in place; TODO confirm
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                # Dropout only applies in training mode.
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
131
+
132
+
133
class FlashCrossAttention(nn.Module):
    """Cross-attention computed with the FlashAttention kernels (packed KV input).

    Supports padded and varlen/unpadded layouts, optional ALiBi biases, and
    sliding-window (local) attention.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at
                       runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        window_size: (left, right) sliding-window extents; (-1, -1) disables windowing.
        alibi_slopes: optional per-head ALiBi slopes tensor.
        deterministic: forwarded to the flash kernels' ``deterministic`` flag.
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        alibi_slopes=None,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        # Non-persistent buffer: follows .to()/.cuda() moves but stays out of state_dict.
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        """
        # Flash kernels only support half precision on CUDA.
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            # presumably the kernels require fp32 slopes — cast in place; TODO confirm
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                # Dropout only applies in training mode.
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            # Same batch size and same head_dim; head counts may differ (MQA/GQA).
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
228
+
229
+
230
+ class SelfAttention(nn.Module):
231
+ """Implement the scaled dot product attention with softmax.
232
+ Arguments
233
+ ---------
234
+ softmax_scale: The temperature to use for the softmax attention.
235
+ (default: 1/sqrt(d_keys) where d_keys is computed at
236
+ runtime)
237
+ attention_dropout: The dropout rate to apply to the attention
238
+ (default: 0.0)
239
+ """
240
+
241
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
242
+ super().__init__()
243
+ self.causal = causal
244
+ self.softmax_scale = softmax_scale
245
+ self.drop = nn.Dropout(attention_dropout)
246
+
247
+ def forward(self, qkv, causal=None, key_padding_mask=None):
248
+ """Implements the multihead softmax attention.
249
+ Arguments
250
+ ---------
251
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
252
+ causal: if passed, will override self.causal
253
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
254
+ False means to mask out. (B, S)
255
+ """
256
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
257
+ causal = self.causal if causal is None else causal
258
+ q, k, v = qkv.unbind(dim=2)
259
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
260
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
261
+ if key_padding_mask is not None:
262
+ padding_mask = torch.full(
263
+ (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
264
+ )
265
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
266
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
267
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
268
+ if causal:
269
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
270
+ # So we have to construct the mask in float
271
+ causal_mask = torch.triu(
272
+ torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
273
+ )
274
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
275
+ scores = scores + causal_mask.to(dtype=scores.dtype)
276
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
277
+ attention_drop = self.drop(attention)
278
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
279
+ return output
280
+
281
+
282
class CrossAttention(nn.Module):
    """Reference (non-flash) softmax cross-attention with packed KV input.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at
                       runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)

        Returns
        -------
            (B, Sq, H, D) attention output.
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA: replicate kv heads to match q heads
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            # Additive bias: 0 where kept, -10000 where padded.
            padding_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # causal mask needs to take into account the difference between seqlen_q and seqlen_k:
            # alignment is bottom-right (query i attends to keys up to i + sk - seqlen_q),
            # where sk is the per-sample number of non-padded keys when a padding mask is given.
            row_idx = rearrange(
                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
            )
            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            sk = (
                seqlen_k
                if key_padding_mask is None
                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
            )
            causal_mask = col_idx > row_idx + sk - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
342
+
343
+
344
class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.

    Callers receive ``(output, input)`` so the residual connection can be fused
    with the backward of the linear layer (mirrors FusedDense(return_residual=True)).
    """

    # Fix: the original annotation claimed `-> torch.Tensor`, but the method returns
    # a (output, residual) pair. Behavior is unchanged.
    def forward(self, input: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        return super().forward(input), input
349
+
350
+
351
def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)

    Copies the new keys/values into this layer's pre-allocated cache (allocating it
    lazily on first use) and returns the cache slice covering all tokens written so
    far, i.e. past tokens plus the new ones.
    """
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        # First call for this layer: allocate a full (max_batch, max_seqlen) cache.
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_seqlen,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        kv_cache = inference_params.key_value_memory_dict[layer_idx]
    # Adjust key and value for inference.
    # NOTE(review): offsets presumably track how many samples/tokens were already
    # written for this generation step — confirm against the InferenceParams caller.
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
    assert kv_cache is not None
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    # Return everything up to sequence_end so attention sees past + new tokens.
    return kv_cache[batch_start:batch_end, :sequence_end, ...]
378
+
379
+
380
class MHA(nn.Module):
    """Multi-head self-attention and cross-attention.

    Supports MQA/GQA (``num_heads_kv``), rotary embeddings, ALiBi, sliding-window
    attention, optional depthwise convs on the projections, FlashAttention kernels,
    and KV-cached incremental decoding via ``inference_params``.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        # Packed projection sizes: q uses num_heads, k and v each use num_heads_kv.
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            # Depthwise causal-ish conv on the projections; padding=2 with [..., :-2]
            # trimming in forward keeps the sequence length unchanged.
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """Allocate an empty (batch, max_seqlen, 2, H_kv, D) KV cache on the module's device."""
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            # Make sure the cos/sin tables cover the whole generation length, then let
            # flash_attn_with_kvcache apply rotary inside the kernel.
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # Prefill / fallback path: update the cache in Python, then run attention.
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            # Decode path: flash_attn_with_kvcache writes the new kv into the cache itself.
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=alibi_slopes,
            )

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
                https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        # The three input modes (varlen, padded-mask, generation) are mutually exclusive.
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        # Route the mask/length arguments to whichever inner attention is in use.
        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        # NOTE(review): with cu_seqlens, x is 2D, so `seqlen` here is actually hidden_dim;
        # these values appear unused on that path — confirm.
        batch, seqlen = x.shape[:2]
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            # Standard self-attention: one packed QKV projection.
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                # Slow/general path: rotary applied in Python (if any), then attention.
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                # Fast decode path: fused rotary + cache update + attention.
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                # GQA/MQA self-attention: split the packed projection into q and kv.
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if self.dwconv:
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        # Merge heads back into the hidden dim and project out.
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)
732
+
733
+
734
class ParallelMHA(nn.Module):
    """Multi-head self-attention and cross-attention, tensor-parallel version.

    Heads are sharded across the ranks of ``process_group``: the QKV projection
    is column-parallel and the output projection is row-parallel, so each rank
    only holds and attends with its local slice of heads.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        process_group,
        num_heads_kv=None,  # < num_heads enables MQA/GQA; defaults to MHA
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,  # needed to index the KV cache during generation
        rotary_emb_dim=0,  # 0 disables rotary embeddings
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),  # (-1, -1) means no sliding-window restriction
        use_flash_attn=False,
        checkpointing=False,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing
        self.process_group = process_group
        self.world_size = process_group.size()
        self.local_rank = torch.distributed.get_rank(process_group)

        self.num_heads = num_heads
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"

        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"

        # Number of query / KV heads owned by this rank after sharding.
        self.num_heads_per_rank = get_dim_for_local_rank(
            self.num_heads, self.world_size, self.local_rank
        )
        self.num_heads_kv_per_rank = get_dim_for_local_rank(
            self.num_heads_kv, self.world_size, self.local_rank
        )
        self.head_dim = self.embed_dim // num_heads
        # Global (pre-shard) width of the fused QKV projection.
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            # Slopes are computed for all heads, then this rank takes its
            # contiguous slice (ceil-divided so all heads are covered).
            num_heads_local = math.ceil(self.num_heads / self.world_size)
            alibi_slopes = torch.tensor(
                get_alibi_slopes(num_heads)[
                    self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
                ],
                device=device,
            )
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.Wqkv = ColumnParallelLinear(
            embed_dim,
            qkv_dim,
            process_group,
            bias=qkv_proj_bias,
            sequence_parallel=sequence_parallel,
            # Keep each rank's shard a whole number of (q-group + 2 KV) head
            # bundles so q/k/v stay contiguous per rank.
            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
            **factory_kwargs,
        )
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            process_group,
            bias=out_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim,
            **factory_kwargs,
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """Allocate this rank's KV cache: (batch, max_seqlen, 2, nheads_kv_local, head_dim)."""
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv_per_rank,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        # Only valid during decoding (non-zero offset) and with flash-attn,
        # since flash_attn_with_kvcache fuses rotary + cache write + attention.
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        # Per-sample lengths when available, otherwise a single shared offset.
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            # Decoding with flash-attn: let the kernel append kv to the cache
            # and attend in one fused call (no rotary here; see the rotary
            # fast path above).
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            context = flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=alibi_slopes,
            )
            return context

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
        """
        qkv = self.Wqkv(x)
        if seqlen is not None:
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
        # Offset(s) used by rotary embeddings: 0 for training/prefill,
        # otherwise the current cache position(s).
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        if self.num_heads_kv == self.num_heads:
            # Standard MHA: fused qkv tensor (b, s, 3, h_local, d).
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                # The fused rotary+kvcache kernel needs rotary_dim % 16 == 0.
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            # MQA/GQA: split the projection into q (h_local heads) and packed
            # kv (2 x hkv_local heads).
            q = rearrange(
                qkv[..., : self.num_heads_per_rank * self.head_dim],
                "... (h d) -> ... h d",
                d=self.head_dim,
            )
            kv = rearrange(
                qkv[..., self.num_heads_per_rank * self.head_dim :],
                "... (two hkv d) -> ... two hkv d",
                two=2,
                d=self.head_dim,
            )
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        context = rearrange(context, "b s h d -> b s (h d)")
        if seqlen is not None:
            # Undo the initial reshape so the output matches the input layout.
            context = rearrange(context, "b s d -> (b s) d")
        # Row-parallel projection all-reduces the partial head outputs.
        out = self.out_proj(context)
        return out
.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/mlp.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.distributed import ProcessGroup
7
+
8
+
9
+ try:
10
+ from flash_attn.ops.activations import swiglu
11
+ except ImportError:
12
+ swiglu = None
13
+
14
+ try:
15
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
16
+ except ImportError:
17
+ ColumnParallelLinear, RowParallelLinear = None, None
18
+
19
+ try:
20
+ from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
21
+ except ImportError:
22
+ FusedMLP, ParallelFusedMLP = None, None
23
+
24
+
25
class Mlp(nn.Module):
    """Plain two-layer feed-forward block: ``fc2(activation(fc1(x)))``.

    When ``return_residual`` is True, ``forward`` returns ``(out, x)`` so the
    caller can fuse the residual addition.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        bias1=True,
        bias2=True,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        # Default: project back to the input width, with a 4x expansion inside.
        if out_features is None:
            out_features = in_features
        if hidden_features is None:
            hidden_features = 4 * in_features
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        hidden = self.activation(self.fc1(x))
        out = self.fc2(hidden)
        return (out, x) if self.return_residual else out
52
+
53
+
54
class ParallelMLP(nn.Module):
    """Tensor-parallel MLP: column-parallel fc1, activation, row-parallel fc2."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        process_group: ProcessGroup = None,
        sequence_parallel=True,
        bias1=True,
        bias2=True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        assert ColumnParallelLinear is not None, "Need to install fused_dense"
        assert RowParallelLinear is not None, "Need to install fused_dense"
        factory_kwargs = {"device": device, "dtype": dtype}
        # Default to an in -> 4*in -> in shape.
        if out_features is None:
            out_features = in_features
        if hidden_features is None:
            hidden_features = 4 * in_features
        # fc1 shards the hidden dimension across ranks; fc2 reduces it back.
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        return self.fc2(self.activation(self.fc1(x)))
97
+
98
+
99
class GatedMlp(nn.Module):
    """Gated MLP (GLU family): fc1 emits value and gate halves, the gate is
    passed through ``activation`` and multiplies the value, then fc2 projects out.

    ``activation=F.sigmoid`` gives plain GLU; ``F.silu`` gives SwiGLU (using the
    fused kernel when available). ``return_residual`` makes ``forward`` also
    return its input.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=128,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        if out_features is None:
            out_features = in_features
        if hidden_features is None:
            # 8/3 expansion keeps parameter count close to a 4x non-gated MLP.
            hidden_features = int(8 * in_features / 3)
        # Round the hidden size up to a multiple of `multiple_of`.
        hidden_features = -(-hidden_features // multiple_of) * multiple_of
        self.return_residual = return_residual
        # fc1 produces both halves (value and gate) in one projection.
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            gated = F.glu(y, dim=-1)
        elif self.activation == F.silu and swiglu is not None:  # Special case for SwiGLU
            value, gate = y.chunk(2, dim=-1)
            gated = swiglu(gate, value)
        else:
            value, gate = y.chunk(2, dim=-1)
            gated = value * self.activation(gate)
        out = self.fc2(gated)
        return (out, x) if self.return_residual else out
137
+
138
+
139
class ParallelGatedMlp(nn.Module):
    """Parallel GatedMlp"""

    def __init__(
        self,
        in_features,
        process_group,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=128,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        if out_features is None:
            out_features = in_features
        if hidden_features is None:
            # 8/3 expansion, matching the non-parallel GatedMlp default.
            hidden_features = int(8 * in_features / 3)
        # Round the hidden size up to a multiple of `multiple_of`.
        hidden_features = -(-hidden_features // multiple_of) * multiple_of
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        # fc1 is column-parallel (emits value+gate halves); fc2 is row-parallel.
        self.fc1 = ColumnParallelLinear(
            in_features,
            2 * hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        else:
            value, gate = y.chunk(2, dim=-1)
            y = value * self.activation(gate)
        return self.fc2(y)
.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/layer_norm.py ADDED
@@ -0,0 +1,800 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, Tri Dao.
2
+ # Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
3
+
4
+ import dropout_layer_norm
5
+ import torch
6
+ from torch.nn import init
7
+
8
+
9
def maybe_align(x, alignment_in_bytes=16):
    """Assume that x already has last dim divisible by alignment_in_bytes"""
    # TD [2023-07-04] I'm not 100% sure that clone will align the memory
    # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
    if x.data_ptr() % alignment_in_bytes == 0:
        return x
    return x.clone()
14
+
15
+
16
def _dropout_add_layer_norm_forward(
    x0,
    residual,
    gamma,
    beta,
    rowscale,
    colscale,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes

    Flattens the inputs to 2D and calls the fused CUDA kernel, which computes
    z = norm(dropout(x0, scaled by rowscale/colscale) + residual).
    Returns (zmat, xmat, dmask, mu, rsigma) where xmat is the pre-norm sum.
    """
    hidden_size = gamma.numel()
    # Collapse every leading dimension into rows of width hidden_size.
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        rowscale,
        colscale,
        None,  # x0_subset slot (used only by the "subset" variant)
        None,  # out_subset slot (used only by the "subset" variant)
        dropout_p,
        epsilon,
        1.0,  # rowscale_const slot (subset variant passes a real value here)
        0,  # out_numrows slot (subset variant passes a real value here)
        None,  # presumably an RNG generator slot -- TODO confirm against the extension
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
53
+
54
+
55
def _dropout_add_layer_norm_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    rowscale,
    colscale,
    dropout_p,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.

    Returns (dx0mat, dresidualmat, dgamma, dbeta) plus dcolscale when
    colscale was given.
    """
    hidden_size = gamma.numel()
    # Flatten to the same 2D row layout used in the forward pass.
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        rowscale,
        colscale,
        None,  # x0_subset slot (used only by the "subset" variant)
        None,  # out_subset slot (used only by the "subset" variant)
        dropout_p,
        1.0,  # rowscale_const slot
        0,  # x0_numrows slot
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        # The kernel appends dcolscale when colscale was supplied.
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
108
+
109
+
110
def _dropout_add_layer_norm_subset_forward(
    x0,
    residual,
    gamma,
    beta,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    epsilon,
    rowscale_const,
    out_numrows,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes

    Variant of the forward where only a subset of rows participate:
    x0_subset selects the rows of x0 to use, out_subset selects which rows
    are written to the output (out_numrows of them); rowscale_const replaces
    the per-row rowscale tensor of the plain variant.
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        None,  # rowscale slot (replaced by the scalar rowscale_const below)
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        None,  # presumably an RNG generator slot -- TODO confirm against the extension
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
151
+
152
+
153
def _dropout_add_layer_norm_subset_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    rowscale_const,
    x0_numrows,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.

    Backward of the row-subset variant; x0_numrows is the number of rows of
    the original x0 (the gradient dx0 is produced at that row count).
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    # dz may have fewer rows than x (out_subset), so reshape by width only.
    dzmat = dz.view(-1, hidden_size)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        None,  # rowscale slot (the subset variant uses rowscale_const instead)
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        rowscale_const,
        x0_numrows,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        # The kernel appends dcolscale when colscale was supplied.
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
210
+
211
+
212
def _dropout_add_layer_norm_parallel_residual_forward(
    x0,
    x1,
    residual,
    gamma0,
    beta0,
    gamma1,
    beta1,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes

    Parallel-residual variant (GPT-J style): two input streams x0 and x1 are
    dropped out, summed with the residual, and the shared sum is normalized
    twice with independent (gamma0, beta0) and (gamma1, beta1) parameters.
    Returns (z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma).
    """
    hidden_size = gamma0.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    (
        z0mat,
        z1mat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
        x0mat,
        x1mat,
        residualmat,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        None,  # presumably an RNG generator slot -- TODO confirm against the extension
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask0 and dmask1 are None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma
255
+
256
+
257
def _dropout_add_layer_norm_parallel_residual_backward(
    dz0,
    dz1,
    dx,
    x,
    dmask0,
    dmask1,
    mu,
    rsigma,
    gamma0,
    gamma1,
    dropout_p,
    has_x1,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).

    Backward of the parallel-residual variant; returns gradients for both
    input streams, the residual, and both normalization parameter sets.
    """
    hidden_size = gamma0.numel()
    xmat = x.view((-1, hidden_size))
    dz0mat = dz0.view(xmat.shape)
    dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
    dxmat = dx.view(xmat.shape) if dx is not None else None
    (
        dx0mat,
        dx1mat,
        dresidualmat,
        dgamma0,
        dbeta0,
        dgamma1,
        dbeta1,
        *rest,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
        dz0mat,
        dz1mat,
        dxmat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
        gamma0,
        gamma1,
        dropout_p,
        has_x1,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
309
+
310
+
311
class DropoutAddLayerNormFn(torch.autograd.Function):
    """Autograd wrapper for the fused dropout + residual-add + layer/RMS norm.

    Forward computes z = norm(dropout(x0 * rowscale * colscale) + residual);
    with ``prenorm`` it also returns the pre-norm sum, and with
    ``return_dmask`` the dropout mask.
    """

    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        rowscale,
        colscale,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        # The kernel requires contiguous, 16-byte-aligned inputs.
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0,
            residual,
            gamma,
            beta,
            rowscale,
            colscale,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(
            xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (
                zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
            )
        else:
            # With dropout disabled the kernel returns no mask; synthesize an
            # all-ones mask so the caller always gets one.
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (
                (zmat.view(x0.shape), dmask)
                if not prenorm
                else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
            )

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        # With prenorm, the extra output's gradient arrives in args[0].
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            rowscale,
            colscale,
            dropout_p,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        # One gradient per forward input, in order: x0, residual, gamma, beta,
        # rowscale (not differentiated), colscale, then None for every
        # non-tensor argument (dropout_p .. return_dmask).
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
        )
414
+
415
+
416
class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    """Autograd wrapper for the fused op when only a subset of rows is used:
    x0_subset selects input rows, out_subset selects output rows, and the
    per-row scale is the scalar ``rowscale_const``.
    """

    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        # The kernel requires contiguous, 16-byte-aligned inputs.
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0,
            residual,
            gamma,
            beta,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            epsilon,
            rowscale_const,
            out_numrows,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        # Row counts may differ from x0's, so keep the leading dim free.
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(
            xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape))
        else:
            z = zmat.view(z_shape)
            # With dropout disabled the kernel returns no mask; synthesize an
            # all-ones mask so the caller always gets one.
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        # With prenorm, the extra output's gradient arrives in args[0].
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            ctx.rowscale_const,
            ctx.x0_numrows,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        # One gradient per forward input, in order: x0, residual, gamma, beta,
        # colscale, then None for the index tensors and every non-tensor
        # argument (x0_subset .. return_dmask).
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
529
+
530
+
531
class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
    """Fused dropout + residual-add followed by TWO parallel LayerNorms.

    Used by "parallel block" transformer layouts: x0 (and optionally x1)
    are dropout-ed, summed with the residual, and the resulting sum is
    normalized twice with independent (gamma, beta) pairs.
    NOTE(review): relies on the CUDA kernels behind
    _dropout_add_layer_norm_parallel_residual_{forward,backward}; the exact
    fusion semantics are defined there — confirm against the extension.
    """

    @staticmethod
    def forward(
        ctx,
        x0,
        x1,
        residual,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        # The kernel presumably requires 16-byte-aligned contiguous inputs;
        # maybe_align enforces that for every tensor argument.
        x0 = maybe_align(x0.contiguous(), 16)
        x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma0 = maybe_align(gamma0.contiguous(), 16)
        beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
        gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
        beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
        (
            z0mat,
            z1mat,
            xmat,
            dmask0,
            dmask1,
            mu,
            rsigma,
        ) = _dropout_add_layer_norm_parallel_residual_forward(
            x0,
            x1,
            residual,
            gamma0,
            beta0,
            gamma1,
            beta1,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        # Save what backward needs; flags go on ctx since they are not tensors.
        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_x1 = x1 is not None
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta0 is not None
        # z1 is None when the second norm (gamma1) is not requested.
        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
        if not return_dmask:
            # prenorm additionally returns the pre-normalization sum.
            return z if not prenorm else (*z, xmat.view(x0.shape))
        else:
            # With dropout disabled the kernel produces no mask, so expose an
            # all-ones mask for API consistency.
            dmask0 = (
                dmask0.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            dmask1 = (
                dmask1.view(x0.shape)
                if dropout_p > 0.0 and x1 is not None
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            # Masks are byproducts, never differentiated through.
            ctx.mark_non_differentiable(dmask0)
            ctx.mark_non_differentiable(dmask1)
            return (
                (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
            )

    @staticmethod
    def backward(ctx, dz0, dz1, *args):
        # Incoming grads may be non-contiguous (e.g. from a slice).
        dz0 = maybe_align(dz0.contiguous(), 16)  # this happens!
        dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
        # In prenorm mode the extra forward output (the sum) also has a grad.
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_x1 = ctx.has_x1
        has_residual = ctx.has_residual
        (
            dx0mat,
            dx1mat,
            dresidualmat,
            dgamma0,
            dbeta0,
            dgamma1,
            dbeta1,
        ) = _dropout_add_layer_norm_parallel_residual_backward(
            dz0,
            dz1,
            dx,
            x,
            dmask0,
            dmask1,
            mu,
            rsigma,
            gamma0,
            gamma1,
            dropout_p,
            has_x1,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        # One grad slot per forward argument; trailing Nones cover the
        # non-tensor inputs (dropout_p, epsilon, flags).
        return (
            dx0,
            dx1,
            dresidual,
            dgamma0,
            dbeta0 if ctx.has_beta else None,
            dgamma1,
            dbeta1 if ctx.has_beta else None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
655
+
656
+
657
def layer_norm(x, weight, bias, epsilon):
    """Plain LayerNorm through the fused kernel (no dropout, no residual)."""
    fn_args = (x, None, weight, bias, None, None, 0.0, epsilon, False)
    return DropoutAddLayerNormFn.apply(*fn_args)
659
+
660
+
661
def dropout_add_layer_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """Fused dropout(x0) + residual add + LayerNorm.

    residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    # The hard-coded False selects LayerNorm (is_rms_norm=False).
    fn_args = (
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )
    return DropoutAddLayerNormFn.apply(*fn_args)
691
+
692
+
693
def dropout_add_layer_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """Subset variant of the fused dropout + add + LayerNorm.

    residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    # The hard-coded False selects LayerNorm (is_rms_norm=False).
    fn_args = (
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )
    return DropoutAddLayerNormSubsetFn.apply(*fn_args)
729
+
730
+
731
def dropout_add_layer_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """Parallel-residual variant: two inputs, two independent LayerNorms.

    residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    # The hard-coded False selects LayerNorm (is_rms_norm=False).
    fn_args = (
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )
    return DropoutAddLayerNormParallelResidualFn.apply(*fn_args)
763
+
764
+
765
class DropoutAddLayerNorm(torch.nn.Module):
    """Module wrapper around the fused dropout + residual-add + LayerNorm op.

    Dropout probability ``p`` is applied only in training mode.
    """

    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        tensor_kwargs = {"device": device, "dtype": dtype}
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **tensor_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **tensor_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        # Standard LayerNorm init: identity scale, zero shift.
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        # Dropout is disabled at eval time.
        drop_prob = self.p if self.training else 0.0
        return dropout_add_layer_norm(
            x0,
            residual,
            self.weight,
            self.bias,
            drop_prob,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (196 Bytes). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/batch_fetch_results.cpython-311.pyc ADDED
Binary file (6.26 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/batch_submit.cpython-311.pyc ADDED
Binary file (2.58 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/run_grid_search.cpython-311.pyc ADDED
Binary file (9.71 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/run_tasks.cpython-311.pyc ADDED
Binary file (14.7 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/run_with_submitit.cpython-311.pyc ADDED
Binary file (7.98 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/batch_fetch_results.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import argparse
8
+ import json
9
+ import logging
10
+ from pathlib import Path
11
+ from typing import Any, Dict
12
+
13
if __name__ == "__main__":
    # Walk a checkpoint tree of LRA runs (one directory per attention
    # mechanism, one sub-directory per task) and print a markdown-style
    # accuracy table, including a per-attention average.
    # Get the user requests
    parser = argparse.ArgumentParser(
        "Collect results from a given batch of distributed results"
    )
    parser.add_argument("-ck", "--checkpoint_path", required=True)
    args = parser.parse_args()

    logging.getLogger().setLevel(logging.INFO)

    # Go through all the data in the given repo, try to find the end results
    root = Path(args.checkpoint_path)

    # - list all the mechanisms being benchmarked
    results: Dict[str, Any] = {}

    for attention in filter(lambda x: x.is_dir(), root.iterdir()):
        logging.info(f"\nFound results for {attention.stem}")
        task_jsons = attention.glob("*/test_eval_summary.json")
        results[attention.stem] = {}

        for task in task_jsons:
            # Directory layout encodes the task name before "__".
            task_name = task.stem.split("__")[0]
            logging.info(f"Logs found for task: {task_name}")
            results[attention.stem][task_name] = -1
            found_result = False

            # - collect the individual results
            with open(task, "r") as result_file:
                dct = json.load(result_file)
                if "test_accu_mean" in dct:
                    found_result = True
                    results[attention.stem][task_name] = dct["test_accu_mean"]

                    logging.info(
                        f"Final result found for {task_name} at epoch {dct['train_step_idx']}: "
                        f"{results[attention.stem][task_name]}"
                    )
                else:
                    # NOTE(review): this break exits the whole task loop for
                    # this attention, skipping the remaining summaries AND the
                    # error-log report below — confirm it isn't meant to be a
                    # plain fall-through.
                    break

            # - report an error if no result was found
            if not found_result:
                ERR_TAIL = 30

                logging.warning(
                    f"No result found for {task_name}, showing the error log in {task.parent}"
                )
                err_log = Path(task.parent).glob("*.err")
                print("*****************************************************")
                # Print the last ERR_TAIL lines of the first *.err file found.
                with open(next(err_log), "r") as err_file:
                    for i, line in enumerate(reversed(err_file.readlines())):
                        print(line, end="")
                        if i > ERR_TAIL:
                            break
                print("*****************************************************")

    logging.info(f"\nCollected results: {json.dumps(results, indent=2)}")

    # - reduction: compute the average
    tasks = set(t for v in results.values() for t in v.keys())
    # -- fill in the possible gaps
    for att in results.keys():
        for t in tasks:
            if t not in results[att].keys():
                results[att][t] = 0.0

    # -- add the average value
    for att in results.keys():
        results[att]["AVG"] = round(sum(results[att][t] for t in tasks) / len(tasks), 2)

    # - Format as an array, markdown style
    tasks_sort = sorted(
        set(t for v in results.values() for t in v.keys()), reverse=True
    )
    print(
        "{0:<20}".format("") + "".join("{0:<20} ".format(t[:10]) for t in tasks_sort)
    )

    for att in results.keys():
        print(
            "{0:<20}".format(att)
            + "".join("{0:<20} ".format(results[att][t]) for t in tasks_sort)
        )
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/batch_submit.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import argparse
8
+ import os
9
+ from pathlib import Path
10
+
11
+ from xformers.benchmarks.LRA.run_tasks import Task
12
+ from xformers.components.attention import ATTENTION_REGISTRY
13
+
14
+
15
def get_default_shared_folder() -> str:
    """Return the first existing cluster checkpoint root, or "." as fallback."""
    candidates = ("/checkpoint", "/checkpoints")
    return next((c for c in candidates if Path(c).is_dir()), ".")
22
+
23
+
24
if __name__ == "__main__":
    # Fan out one `run_with_submitit.py` invocation per (attention, task)
    # pair requested on the command line.
    # NOTE(review): default_checkpoint_path is computed but never used below
    # (--checkpoint_path is required) — confirm whether it was meant to be
    # the default for that flag.
    default_checkpoint_path = get_default_shared_folder()

    # Get the user requests
    parser = argparse.ArgumentParser(
        "Benchmark different attention mechanisms on various sequence lengths"
    )
    parser.add_argument("-c", "--config_path", required=True)
    parser.add_argument("-ck", "--checkpoint_path", required=True)
    parser.add_argument(
        "-a", "--attentions", nargs="+", default=list(ATTENTION_REGISTRY.keys())
    )
    parser.add_argument("-t", "--tasks", nargs="+", default=[t.value for t in Task])
    parser.add_argument(
        "--partition", default="a100", type=str, help="Partition where to submit"
    )
    args = parser.parse_args()

    for attention in args.attentions:
        for task in args.tasks:
            # NOTE(review): command is built as an unquoted shell string;
            # arguments containing spaces or shell metacharacters would break
            # or be interpreted by the shell — fine for internal use, but
            # subprocess.run([...]) would be safer.
            os.system(
                "python3 run_with_submitit.py"
                + f" --attention {attention} --task {task} --config {args.config_path}"
                + f" --checkpoint_dir {args.checkpoint_path}/{attention}/{task}"
                + f" --partition {args.partition}"
            )
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (201 Bytes). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__pycache__/dataset.cpython-311.pyc ADDED
Binary file (2.72 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__pycache__/model_wrapper.cpython-311.pyc ADDED
Binary file (18 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/dataset.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ # CREDITS: Almost as-is from the Nystromformer repo
8
+ # https://github.com/mlpen/Nystromformer
9
+
10
+ import logging
11
+ import pickle
12
+
13
+ import torch
14
+ from torch.utils.data.dataset import Dataset
15
+
16
+ logging.getLogger().setLevel(logging.INFO)
17
+
18
+
19
class LRADataset(Dataset):
    """LRA benchmark dataset loaded from a pickled list of examples.

    Each item is a dict with long-tensor token ids truncated to ``seq_len``,
    float masks marking non-zero (non-padding) tokens, and the label.
    """

    def __init__(self, file_path, seq_len):
        with open(file_path, "rb") as f:
            self.examples = pickle.load(f)

        self.seq_len = seq_len
        logging.info(f"Loaded {file_path}... size={len(self.examples)}")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return self.create_inst(self.examples[i], self.seq_len)

    @staticmethod
    def create_inst(inst, seq_len):
        # Truncate to seq_len and derive the padding mask (token id 0 == pad).
        ids_0 = torch.tensor(inst["input_ids_0"], dtype=torch.long)[:seq_len]
        sample = {"input_ids_0": ids_0, "mask_0": (ids_0 != 0).float()}

        # Dual-input tasks (e.g. retrieval) carry a second sequence.
        if "input_ids_1" in inst:
            ids_1 = torch.tensor(inst["input_ids_1"], dtype=torch.long)[:seq_len]
            sample["input_ids_1"] = ids_1
            sample["mask_1"] = (ids_1 != 0).float()

        sample["label"] = torch.tensor(inst["label"], dtype=torch.long)
        return sample
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/model_wrapper.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ # CREDITS: adapted from the Nystromformer repo
8
+ # https://github.com/mlpen/Nystromformer
9
+
10
+ from enum import Enum
11
+ from typing import Dict, Union
12
+
13
+ import pytorch_lightning as pl
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from xformers.components import build_attention
18
+ from xformers.components.multi_head_dispatch import MultiHeadDispatchConfig
19
+ from xformers.factory import xFormer, xFormerConfig, xFormerEncoderConfig
20
+ from xformers.utils import generate_matching_config
21
+
22
+ PLOutput = Dict[str, Union[float, torch.Tensor]]
23
+
24
+
25
class Pooling(str, Enum):
    """Sequence-pooling strategies for the classification heads."""

    MEAN = "mean"
    CLS = "cls"


def pooling(mode: Pooling):
    """Return the pooling callable matching *mode*.

    CLS takes the first token's embedding; MEAN averages over the sequence.
    """
    if Pooling(mode) is Pooling.CLS:
        return lambda inp: inp[:, 0, :]
    return lambda inp: inp.mean(dim=1)
38
+
39
+
40
def append_cls(inp, mask, vocab_size):
    """Prepend a CLS token (id = vocab_size - 1) to every sequence.

    The last token of each sequence is dropped so the length is preserved,
    and the mask gets a matching leading 1.
    """
    batch = inp.size(0)
    cls_col = torch.full((batch, 1), vocab_size - 1, dtype=torch.long, device=inp.device)
    mask_col = torch.ones((batch, 1), dtype=torch.float, device=mask.device)
    new_inp = torch.cat([cls_col, inp[:, :-1]], dim=-1)
    new_mask = torch.cat([mask_col, mask[:, :-1]], dim=-1)
    return new_inp, new_mask
49
+
50
+
51
def patch_model_config(config, attention_name):
    """Specialize a generic LRA model config for one attention mechanism.

    Copies the "common" settings into every sub-config, injects the
    attention name and per-head dimension, applies any mechanism-specific
    overrides, and builds the attention module in place.
    """
    # Rebuild a specific config out of generic + extra params
    commons = config["common"]
    try:
        extra_attention_settings = config["extra_settings"]["attention"][attention_name]
    except KeyError:
        # No overrides registered for this mechanism.
        extra_attention_settings = None

    for bc in config["xformer"]:
        bc["dim_model"] = commons["dim_model"]
        bc["position_encoding_config"].update(commons)
        bc["feedforward_config"].update(commons)
        bc["multi_head_config"].update(commons)
        bc["multi_head_config"]["attention"].update(commons)
        bc["multi_head_config"]["attention"]["name"] = attention_name
        # NOTE(review): true division — dim_head ends up a float; confirm
        # downstream consumers accept that.
        bc["multi_head_config"]["attention"]["dim_head"] = (
            commons["dim_model"] / commons["num_heads"]
        )
        if extra_attention_settings is not None:
            bc["multi_head_config"]["attention"].update(extra_attention_settings)

        bc["multi_head_config"] = generate_matching_config(
            bc["multi_head_config"], MultiHeadDispatchConfig
        )
        bc["multi_head_config"].attention = build_attention(
            bc["multi_head_config"].attention
        )
        # NOTE(review): this rebinds the loop variable only — the converted
        # xFormerEncoderConfig is discarded and config["xformer"] keeps the
        # raw dicts. Confirm whether the result should be written back.
        bc = generate_matching_config(bc, xFormerEncoderConfig)

    return config
81
+
82
+
83
class SCHead(nn.Module):
    """Single-input sequence-classification head: pool -> MLP -> logits."""

    def __init__(self, config, dim_embedding, dim_mlp):
        super().__init__()
        self.pooling = pooling(Pooling(config["pooling_mode"]))
        num_classes = config["common"]["num_classes"]
        self.mlpblock = nn.Sequential(
            nn.Linear(dim_embedding, dim_mlp),
            nn.ReLU(),
            nn.Linear(dim_mlp, num_classes),
        )

    def forward(self, inp: torch.Tensor):
        # Pool the token embeddings, then classify the pooled vector.
        return self.mlpblock(self.pooling(inp))
97
+
98
+
99
class SCHeadDual(nn.Module):
    """Two-input head (retrieval task): pool both, combine features, classify."""

    def __init__(self, config, dim_embedding, dim_mlp):
        super().__init__()
        self.pooling = pooling(Pooling(config["pooling_mode"]))
        # 4x: [a, b, a*b, a-b] concatenation of the two pooled vectors.
        self.mlpblock = nn.Sequential(
            nn.Linear(dim_embedding * 4, dim_mlp),
            nn.ReLU(),
            nn.Linear(dim_mlp, config["common"]["num_classes"]),
        )

    def forward(self, inp_0: torch.Tensor, inp_1: torch.Tensor):
        pooled_0 = self.pooling(inp_0)
        pooled_1 = self.pooling(inp_1)
        features = torch.cat(
            [pooled_0, pooled_1, pooled_0 * pooled_1, pooled_0 - pooled_1], dim=-1
        )
        return self.mlpblock(features)
118
+
119
+
120
class ModelTrunk(pl.LightningModule):
    """Shared LRA Lightning trunk.

    Builds the xFormer encoder from a patched config, owns the optimizer /
    LR schedule, and implements train/val/test bookkeeping. Subclasses add
    the task head and implement ``forward`` returning a dict with "loss",
    "accu" and "count".
    """

    def __init__(self, config, model_name):
        super().__init__()

        config_model = config["model"]
        self.config_training = config["training"]

        self.enable_amp = config["training"]["mixed_precision"]
        self.pooling_mode = Pooling(config_model["pooling_mode"])
        self.vocab_size = config_model["common"]["vocab_size"]

        # Rebuild a specific config out of generic + extra params
        self.config_model = patch_model_config(config_model, model_name)
        self.model = xFormer.from_config(xFormerConfig(config_model["xformer"]))
        self.norm = nn.LayerNorm(self.config_model["common"]["dim_model"])

        # Classification-head MLP width, derived from the first block's
        # feed-forward multiplier.
        ff_config = self.config_model["xformer"][0]["feedforward_config"]
        self.dim_mlp = (
            self.config_model["common"]["dim_model"]
            * ff_config["hidden_layer_multiplier"]
        )

    def training_step(  # type: ignore
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> PLOutput:
        outputs = self(**batch)
        self.logger.log_metrics({f"train_{k}": v for k, v in outputs.items()})  # type: ignore
        self.log("train_accu", outputs["accu"], sync_dist=True)
        return outputs

    def training_epoch_end(self, outputs):
        # Reuse the eval reduction to get the count-weighted epoch mean.
        logs = self.eval_epoch_end(outputs)
        self.log("train_accu_mean", logs["accu"], sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.config_training["learning_rate"],
            betas=(0.9, 0.999),
            eps=1e-6,
            weight_decay=self.config_training["weight_decay"],
        )

        # One-cycle schedule; warmup fraction = warmup / num_train_steps.
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=self.config_training["learning_rate"],
            pct_start=self.config_training["warmup"]
            / self.config_training["num_train_steps"],
            anneal_strategy=self.config_training["lr_decay"],
            total_steps=self.config_training["num_train_steps"],
        )

        return [optimizer], [lr_scheduler]

    def eval_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> PLOutput:
        outputs = self(**batch)
        return outputs

    def eval_epoch_end(self, outputs, prefix: str = "train"):
        # Count-weighted mean of accuracy and loss across all steps.
        logs = {}
        counts = torch.tensor([x["count"] for x in outputs]).float()
        logs["count"] = counts.sum()
        for k in ("accu", "loss"):
            logs[k] = (torch.tensor([x[k] for x in outputs]) * counts).sum() / logs[
                "count"
            ]
            self.log(f"{prefix}_{k}_mean", logs[k], sync_dist=True)
        return logs

    def validation_step(  # type: ignore
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> PLOutput:
        outputs = self.eval_step(batch, batch_idx)
        self.logger.log_metrics({f"val_{k}": v for k, v in outputs.items()})  # type: ignore
        self.log("val_accu", outputs["accu"], sync_dist=True, prog_bar=True)
        return outputs

    def validation_epoch_end(self, outputs):
        self.eval_epoch_end(outputs, prefix="val")

    def test_step(  # type: ignore
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> PLOutput:
        return self.eval_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        self.eval_epoch_end(outputs, prefix="test")
207
+
208
+
209
class ModelForSC(ModelTrunk):
    """Single-input sequence classification (text / listops / image tasks)."""

    def __init__(self, config, model_name):
        # Setup trunk
        super().__init__(config, model_name)

        self.seq_classifer = SCHead(
            self.config_model,
            dim_embedding=self.config_model["common"]["dim_model"],
            dim_mlp=self.dim_mlp,
        )

    def forward(  # type: ignore
        self, input_ids_0: torch.Tensor, mask_0: torch.Tensor, label: torch.Tensor
    ):

        # CLS pooling needs the CLS token prepended to the sequence.
        if self.pooling_mode == Pooling.CLS:
            input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)

        # Zero out padded positions after the final norm.
        token_out = self.norm(
            self.model(input_ids_0, encoder_input_mask=mask_0)
        ) * mask_0.unsqueeze(-1)

        seq_scores = self.seq_classifer(token_out)

        # Per-sample loss/accuracy; the epoch reduction re-weights by count.
        seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
        seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
        outputs = {
            "loss": seq_loss.mean(),
            "accu": seq_accu.mean(),
            "count": label.size(0),
        }

        return outputs
242
+
243
+
244
class ModelForSCDual(ModelTrunk):
    """Dual-input sequence classification (retrieval task).

    Both sequences run through the same encoder in one concatenated batch,
    then a dual head combines the two pooled representations.
    """

    def __init__(self, config, model_name):
        # Setup trunk
        super().__init__(config, model_name)

        self.seq_classifer = SCHeadDual(
            self.config_model,
            dim_embedding=self.config_model["common"]["dim_model"],
            dim_mlp=self.dim_mlp,
        )

    def forward(  # type: ignore
        self,
        input_ids_0: torch.Tensor,
        input_ids_1: torch.Tensor,
        mask_0: torch.Tensor,
        mask_1: torch.Tensor,
        label: torch.Tensor,
    ):

        mask_0, mask_1 = mask_0.long(), mask_1.long()

        # CLS pooling needs the CLS token prepended to both sequences.
        if self.pooling_mode == Pooling.CLS:
            input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
            input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size)

        # Concatenate the two inputs into one batch
        input_ids = torch.cat([input_ids_0, input_ids_1], dim=0)
        masks = torch.cat([mask_0, mask_1], dim=0)

        # Zero out padded positions after the final norm.
        tokens_out = self.norm(
            self.model(input_ids, encoder_input_mask=masks)
        ) * masks.unsqueeze(-1)

        # Split the batch back into the two original halves for the head.
        seq_scores = self.seq_classifer(*torch.chunk(tokens_out, 2, dim=0))

        seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
        seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
        outputs = {
            "loss": seq_loss.mean(),
            "accu": seq_accu.mean(),
            "count": label.size(0),
        }

        return outputs
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/run_grid_search.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import itertools
8
+ import os
9
+ import uuid
10
+ from datetime import date
11
+ from pathlib import Path
12
+ from typing import Dict, Iterable
13
+
14
+ import submitit
15
+
16
+ from xformers.benchmarks.LRA.run_with_submitit import (
17
+ Trainer,
18
+ get_init_file,
19
+ get_shared_folder,
20
+ parse_args,
21
+ )
22
+
23
+
24
+ def grid_parameters(grid: Dict):
25
+ """
26
+ Yield all combinations of parameters in the grid (as a dict)
27
+ """
28
+ grid_copy = dict(grid)
29
+ # Turn single value in an Iterable
30
+ for k in grid_copy:
31
+ if not isinstance(grid_copy[k], Iterable):
32
+ grid_copy[k] = [grid_copy[k]]
33
+ for p in itertools.product(*grid_copy.values()):
34
+ yield dict(zip(grid.keys(), p))
35
+
36
+
37
+ def grid_search(args):
38
+ if args.checkpoint_dir == "":
39
+ args.checkpoint_dir = get_shared_folder() / "%j"
40
+
41
+ date_curr = date.today().strftime("%m-%d-%Y")
42
+ orig_check_dir = os.path.join(args.checkpoint_dir, date_curr)
43
+
44
+ # Create the executor
45
+ # Note that the folder will depend on the job_id, to easily track experiments
46
+ executor = submitit.AutoExecutor(
47
+ folder=get_shared_folder() / "%j", slurm_max_num_timeout=30
48
+ )
49
+ num_gpus_per_node = args.ngpus
50
+ nodes = args.nodes
51
+ args.world_size = args.nodes * args.ngpus
52
+ partition = args.partition
53
+
54
+ executor.update_parameters(
55
+ gpus_per_node=num_gpus_per_node,
56
+ tasks_per_node=num_gpus_per_node, # one task per GPU
57
+ cpus_per_task=10,
58
+ nodes=nodes,
59
+ timeout_min=60 * 72,
60
+ slurm_signal_delay_s=120,
61
+ slurm_partition=partition,
62
+ )
63
+ executor.update_parameters(name="lra")
64
+
65
+ if args.task == "text":
66
+ grid_meta = {
67
+ "training:learning_rate": (
68
+ [1e-4, 2e-4, 3e-4, 5e-5],
69
+ lambda val: f"lr{val}",
70
+ ),
71
+ "training:warmup": ([3000, 8000], lambda val: f"warmup{val}"),
72
+ "training:seed": ([1234, 32, 1994], lambda val: f"seed{val}"),
73
+ "training:weight_decay": ([0.02, 0.05, 0.01], lambda val: f"wd{val}"),
74
+ "model:pooling_model": (["cls"], lambda val: f"pool-{val}"),
75
+ "model:common:dropout": ([0, 0.05], lambda val: f"drop{val}"),
76
+ }
77
+ elif args.task == "retrieval":
78
+ grid_meta = {
79
+ "training:learning_rate": ([1e-4, 3e-4], lambda val: f"lr{val}"),
80
+ "training:warmup": ([2000, 8000], lambda val: f"warmup{val}"),
81
+ "training:seed": ([4096, 1234, 3, 15, 5], lambda val: f"seed{val}"),
82
+ "training:weight_decay": ([0.01, 0], lambda val: f"wd{val}"),
83
+ "model:pooling_model": (["cls"], lambda val: f"pool-{val}"),
84
+ "model:common:dropout": ([0], lambda val: f"drop{val}"),
85
+ }
86
+ elif args.task == "listops":
87
+ grid_meta = {
88
+ "training:learning_rate": (
89
+ [1e-4, 2e-4, 3e-4, 5e-5],
90
+ lambda val: f"lr{val}",
91
+ ),
92
+ "training:warmup": ([3000, 2000], lambda val: f"warmup{val}"),
93
+ "training:seed": (
94
+ [
95
+ 1234,
96
+ ],
97
+ lambda val: f"seed{val}",
98
+ ),
99
+ "training:weight_decay": ([0.02, 0.05, 0, 1], lambda val: f"wd{val}"),
100
+ "model:pooling_model": (["cls"], lambda val: f"pool-{val}"),
101
+ "model:common:dropout": ([0], lambda val: f"drop{val}"),
102
+ }
103
+ else:
104
+ grid_meta = {
105
+ "training:learning_rate": ([1e-4, 5e-5], lambda val: f"lr{val}"),
106
+ "training:warmup": ([8000], lambda val: f"warmup{val}"),
107
+ "training:seed": ([1234, 4321, 3], lambda val: f"seed{val}"),
108
+ "training:weight_decay": ([0.01], lambda val: f"wd{val}"),
109
+ "model:pooling_model": (["cls"], lambda val: f"pool-{val}"),
110
+ "model:common:dropout": ([0.1], lambda val: f"drop{val}"),
111
+ }
112
+
113
+ grid = {k: v[0] for k, v in grid_meta.items()}
114
+ save_key = {k: v[1] for k, v in grid_meta.items()}
115
+
116
+ hyper_parameters = list(grid_parameters(grid))
117
+ jobs = []
118
+
119
+ for i, grid_data in enumerate(hyper_parameters):
120
+
121
+ args.sweep_parameters = grid_data
122
+ run_name = f"{args.attention}"
123
+ # run_name = "paper_config"
124
+ for k, v in grid_data.items():
125
+ run_name += "prenorm-" + save_key[k](v)
126
+ args.checkpoint_dir = os.path.join(
127
+ orig_check_dir, f"{args.task}", "logs", run_name
128
+ )
129
+ Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
130
+ args.tb_dir = os.path.join(orig_check_dir, f"{args.task}", "tb", run_name)
131
+ Path(args.tb_dir).mkdir(parents=True, exist_ok=True)
132
+
133
+ # Chronos needs a different job name each time
134
+ executor.update_parameters(name=f"lra_{args.task}_{i:02d}_{uuid.uuid4().hex}")
135
+
136
+ args.dist_url = get_init_file().as_uri()
137
+ args.temp_file = str(get_init_file())
138
+
139
+ trainer = Trainer(args)
140
+ job = executor.submit(trainer)
141
+ jobs.append(job)
142
+ print(f"Run {i:02d} submitted with train cfg: {args}")
143
+ print(f"Submitted jobs ids: {','.join([str(job.job_id) for job in jobs])}")
144
+
145
+
146
if __name__ == "__main__":
    # CLI entry point: reuse run_with_submitit's argument parser, then sweep.
    args = parse_args()
    grid_search(args)
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/run_tasks.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import argparse
8
+ import json
9
+ import logging
10
+ import os
11
+ from enum import Enum
12
+ from pathlib import Path
13
+ from typing import Dict, Tuple, cast
14
+
15
+ import pytorch_lightning as pl
16
+ import torch
17
+ import torch.nn as nn
18
+ from fvcore.nn import FlopCountAnalysis, flop_count_str
19
+ from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
20
+ from pytorch_lightning.loggers import TensorBoardLogger
21
+ from pytorch_lightning.strategies import DDPStrategy
22
+ from torch.utils.data import DataLoader
23
+
24
+ from xformers.benchmarks.LRA.code.dataset import LRADataset
25
+ from xformers.benchmarks.LRA.code.model_wrapper import ModelForSC, ModelForSCDual
26
+ from xformers.components.attention import ATTENTION_REGISTRY
27
+
28
+
29
+ class Task(str, Enum):
30
+ Retrieval = "retrieval"
31
+ ListOps = "listops"
32
+ Image = "image"
33
+ PathfinderBaseline = "pathfinder32-curv_baseline"
34
+ PathfinderContour9 = "pathfinder32-curv_contour_length_9"
35
+ PathfinderContour14 = "pathfinder32-curv_contour_length_14"
36
+ Text = "text"
37
+
38
+
39
+ def load_config(path: str) -> Dict:
40
+ with open(Path(path).absolute(), "r") as fileio:
41
+ config = json.load(fileio)
42
+
43
+ # Duplicate the pathfinder configs
44
+ config["pathfinder32-curv_baseline"] = config["pathfinder32"]
45
+ config["pathfinder32-curv_contour_length_9"] = config["pathfinder32"]
46
+ config["pathfinder32-curv_contour_length_14"] = config["pathfinder32"]
47
+ return config
48
+
49
+
50
+ def build_model(args: argparse.Namespace, config: Dict) -> nn.Module:
51
+ task = args.task
52
+ attention_name = args.attention
53
+
54
+ model = cast(
55
+ pl.LightningModule,
56
+ (
57
+ ModelForSCDual(config[f"{task}"], attention_name)
58
+ if task == Task.Retrieval
59
+ else ModelForSC(config[f"{task}"], attention_name)
60
+ ),
61
+ )
62
+
63
+ logging.info(model)
64
+ summary = pl.utilities.model_summary.LayerSummary(model)
65
+ logging.info(f"num_parameter: {summary.num_parameters // 1e3 / 1e3}M")
66
+
67
+ with torch.no_grad():
68
+ # Check the flops
69
+ seq_len = config[f"{task}"]["model"]["common"]["seq_len"]
70
+ x = torch.rand(1, seq_len).long()
71
+ mask = torch.rand(1, seq_len).long()
72
+ indices = torch.rand(1, seq_len).long()
73
+ flops = FlopCountAnalysis(model.model, (x, mask, indices))
74
+ logging.info(f"complexity: {round(flops.total()/1e9, 3)} GFlops")
75
+ logging.info(flop_count_str(flops))
76
+
77
+ return model
78
+
79
+
80
+ def get_arg_parser():
81
+ parser = argparse.ArgumentParser()
82
+ parser.add_argument(
83
+ "--attention",
84
+ type=str,
85
+ help=f"Attention mechanism to chose, among {list(ATTENTION_REGISTRY.keys())}. \
86
+ A list can be passed to test several mechanisms in sequence",
87
+ dest="attention",
88
+ required=True,
89
+ )
90
+ parser.add_argument(
91
+ "--task",
92
+ type=Task,
93
+ help=f"Task to chose, among {[t.value for t in Task]}.",
94
+ dest="task",
95
+ required=True,
96
+ )
97
+ parser.add_argument(
98
+ "--skip_train",
99
+ type=bool,
100
+ help="Whether to skip training, and test an existing model",
101
+ dest="skip_train",
102
+ default=False,
103
+ )
104
+ parser.add_argument(
105
+ "--config",
106
+ type=str,
107
+ help="Path to the config being used",
108
+ dest="config",
109
+ default="./config.json",
110
+ )
111
+ parser.add_argument(
112
+ "--checkpoint_dir",
113
+ type=str,
114
+ help="Path to the checkpoint directory",
115
+ dest="checkpoint_dir",
116
+ default=f"/checkpoints/{os.getenv('USER')}/xformers",
117
+ )
118
+ parser.add_argument(
119
+ "--checkpoint_path",
120
+ type=str,
121
+ help="Path to checkpoint",
122
+ )
123
+ parser.add_argument(
124
+ "--debug",
125
+ help="Make it easier to debug a possible issue",
126
+ dest="debug",
127
+ default=False,
128
+ action="store_true",
129
+ )
130
+ parser.add_argument(
131
+ "--world_size",
132
+ help="Number of GPUs used",
133
+ dest="world_size",
134
+ type=int,
135
+ default=1,
136
+ )
137
+ parser.add_argument(
138
+ "--sweep_parameters",
139
+ help="Rewrite some hyperparameters in the config",
140
+ dest="sweep_parameters",
141
+ type=dict,
142
+ default=None,
143
+ )
144
+ return parser
145
+
146
+
147
+ def setup_log(args, attention_name, task) -> Tuple[str, TensorBoardLogger]:
148
+ experiment_name = f"{task}__{attention_name}"
149
+ logger = TensorBoardLogger(
150
+ save_dir=args.checkpoint_dir,
151
+ name="", # remove lightning_logs subdirectory
152
+ version=experiment_name,
153
+ )
154
+ log_dir = os.path.join(logger._save_dir, experiment_name)
155
+ return log_dir, logger
156
+
157
+
158
+ def rewrite_hyper(config, rewrites):
159
+ def replace(config_dict, k, v):
160
+ if len(k.split(":")) == 1:
161
+ config_dict[k] = v
162
+ return
163
+ first_key = k.split(":")[0]
164
+ assert first_key in config_dict, first_key
165
+ k = k[len(first_key) + 1 :]
166
+ replace(config_dict[first_key], k, v)
167
+
168
+ for k, v in rewrites.items():
169
+ replace(config, k, v)
170
+ return config
171
+
172
+
173
+ def build_dataloaders(
174
+ args: argparse.Namespace,
175
+ config_training: Dict,
176
+ num_workers: int = 4,
177
+ ) -> Dict[str, DataLoader]:
178
+ datasets = {}
179
+ for component in ("train", "dev", "test"):
180
+ datasets[component] = LRADataset(
181
+ file_path=f"datasets/{args.task}.{component}.pickle",
182
+ seq_len=config_training["seq_len"],
183
+ )
184
+
185
+ # Gradient accumulation
186
+ accumu_steps = config_training["gradient_accumulation"]
187
+ logging.info(f"accumu_steps={accumu_steps}")
188
+
189
+ # Batch size
190
+ per_gpu_batch_size = (
191
+ config_training["batch_size"] // args.world_size // accumu_steps
192
+ )
193
+ logging.warning(
194
+ f"Requested batch size: {config_training['batch_size']}. Given world\
195
+ size and grad accumulation, per-gpu batch is\
196
+ {per_gpu_batch_size}"
197
+ )
198
+
199
+ dataloaders = {
200
+ k: DataLoader(
201
+ v,
202
+ batch_size=per_gpu_batch_size,
203
+ shuffle=False,
204
+ pin_memory=True,
205
+ num_workers=num_workers,
206
+ )
207
+ for k, v in datasets.items()
208
+ }
209
+ return dataloaders
210
+
211
+
212
+ def get_eval_summary(trainer: pl.Trainer) -> Dict[str, float]:
213
+ eval_summary: Dict[str, float] = {"train_step_idx": trainer.global_step}
214
+ for k, v in trainer.callback_metrics.items():
215
+ eval_summary[k] = v.item()
216
+ return eval_summary
217
+
218
+
219
+ class BasicProgressBar(TQDMProgressBar):
220
+ def get_metrics(self, trainer, model):
221
+ items = super().get_metrics(trainer, model)
222
+ items.pop("v_num", None)
223
+ return items
224
+
225
+
226
+ def benchmark(args):
227
+ log_dir, logger = setup_log(args, f"{args.attention}", f"{args.task}")
228
+ args.logger = logger
229
+
230
+ config = load_config(args.config)
231
+
232
+ config_task = config[f"{args.task}"]
233
+ if args.sweep_parameters is not None:
234
+ logging.info("Replacing hyperparameters")
235
+ rewrite_hyper(config_task, args.sweep_parameters)
236
+
237
+ config_training = config_task["training"]
238
+ config_training["seq_len"] = config_task["model"]["common"]["seq_len"]
239
+ logging.info(f"Learning rate: {config_training['learning_rate']}")
240
+
241
+ pl.seed_everything(config_training.get("seed", 0))
242
+ dataloaders = build_dataloaders(args, config_training)
243
+
244
+ model = build_model(args, config)
245
+
246
+ progress_bar = BasicProgressBar()
247
+ checkpoint_callback = ModelCheckpoint(
248
+ monitor="val_accu",
249
+ mode="max",
250
+ dirpath=args.checkpoint_dir,
251
+ filename="{epoch}-{val_accu:.2f}",
252
+ every_n_train_steps=config_training["eval_frequency"],
253
+ )
254
+
255
+ trainer = pl.Trainer(
256
+ accelerator="gpu",
257
+ strategy=(
258
+ DDPStrategy(find_unused_parameters=args.debug)
259
+ if not args.skip_train
260
+ else None
261
+ ),
262
+ accumulate_grad_batches=config_training["gradient_accumulation"],
263
+ callbacks=[progress_bar, checkpoint_callback],
264
+ detect_anomaly=args.debug,
265
+ deterministic=True,
266
+ gpus=args.world_size,
267
+ limit_val_batches=config_training["num_eval_steps"],
268
+ logger=logger,
269
+ max_steps=config_training["num_train_steps"],
270
+ num_sanity_val_steps=int(not args.skip_train),
271
+ precision=16 if config_training["mixed_precision"] else 32,
272
+ val_check_interval=config_training["eval_frequency"]
273
+ / float(len(dataloaders["train"])),
274
+ )
275
+
276
+ if not args.skip_train:
277
+ trainer.fit(
278
+ model,
279
+ train_dataloaders=dataloaders["train"],
280
+ val_dataloaders=dataloaders["dev"],
281
+ )
282
+ ckpt_path = checkpoint_callback.best_model_path
283
+ else:
284
+ ckpt_path = args.checkpoint_path
285
+
286
+ trainer.test(
287
+ model,
288
+ dataloaders=dataloaders["test"],
289
+ ckpt_path=ckpt_path,
290
+ )
291
+ eval_summary = get_eval_summary(trainer)
292
+ with open(os.path.join(log_dir, "test_eval_summary.json"), "w") as f:
293
+ logging.info(f"Saving test results at {f.name}")
294
+ json.dump(eval_summary, f)
295
+
296
+
297
+ if __name__ == "__main__":
298
+ parser = get_arg_parser()
299
+ args = parser.parse_args()
300
+ if args.skip_train and args.checkpoint_path is None:
301
+ raise parser.error("Must provide --checkpoint_path if --skip_train=True")
302
+ benchmark(args)
.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/run_with_submitit.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ """
8
+ A script to run multinode training with submitit.
9
+ Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py
10
+ """
11
+
12
+ import argparse
13
+ import os
14
+ import uuid
15
+ from pathlib import Path
16
+
17
+ import submitit
18
+
19
+ from xformers.benchmarks.LRA.run_tasks import benchmark, get_arg_parser
20
+
21
+
22
+ def parse_args():
23
+ parser = argparse.ArgumentParser(
24
+ "Submitit for LRA", parents=[get_arg_parser()], add_help=False
25
+ )
26
+ parser.add_argument(
27
+ "--ngpus", default=1, type=int, help="Number of gpus to request on each node"
28
+ )
29
+ parser.add_argument(
30
+ "--nodes", default=1, type=int, help="Number of nodes to request"
31
+ )
32
+ parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job")
33
+
34
+ parser.add_argument(
35
+ "--partition", default="a100", type=str, help="Partition where to submit"
36
+ )
37
+ parser.add_argument(
38
+ "--use_volta32", action="store_true", help="Big models? Use this"
39
+ )
40
+ parser.add_argument(
41
+ "--enforce_host_memory", action="store_true", help="Use if the host OOMs"
42
+ )
43
+
44
+ parser.add_argument(
45
+ "--comment",
46
+ default="",
47
+ type=str,
48
+ help="Comment to pass to scheduler, e.g. priority message",
49
+ )
50
+ return parser.parse_args()
51
+
52
+
53
+ def get_shared_folder() -> Path:
54
+ user = os.getenv("USER")
55
+ checkpoint_paths = ["/checkpoint", "/checkpoints"]
56
+ for checkpoint_path in checkpoint_paths:
57
+ if Path(checkpoint_path).is_dir():
58
+ p = Path(f"{checkpoint_path}/{user}/xformers/submitit")
59
+ p.mkdir(exist_ok=True, parents=True)
60
+ return p
61
+ raise RuntimeError(f"No shared folder available - considering {checkpoint_paths}")
62
+
63
+
64
+ def get_init_file():
65
+ # Init file must not exist, but it's parent dir must exist.
66
+ os.makedirs(str(get_shared_folder()), exist_ok=True)
67
+ init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
68
+ if init_file.exists():
69
+ os.remove(str(init_file))
70
+ return init_file
71
+
72
+
73
+ class Trainer:
74
+ def __init__(self, args):
75
+ self.args = args
76
+
77
+ def __call__(self):
78
+ self._setup_gpu_args()
79
+ benchmark(self.args)
80
+
81
+ def checkpoint(self):
82
+ self.args.dist_url = get_init_file().as_uri()
83
+ print("Requeuing ", self.args)
84
+ empty_trainer = type(self)(self.args)
85
+ return submitit.helpers.DelayedSubmission(empty_trainer)
86
+
87
+ def _setup_gpu_args(self):
88
+ job_env = submitit.JobEnvironment()
89
+ self.args.checkpoint_dir = Path(
90
+ str(self.args.checkpoint_dir).replace("%j", str(job_env.job_id))
91
+ )
92
+ self.args.gpu = job_env.local_rank
93
+ self.args.rank = job_env.global_rank
94
+ self.args.world_size = job_env.num_tasks
95
+ print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
96
+
97
+
98
+ def main():
99
+ args = parse_args()
100
+ if args.checkpoint_dir == "":
101
+ args.checkpoint_dir = get_shared_folder() / "%j"
102
+ Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
103
+ executor = submitit.AutoExecutor(
104
+ folder=args.checkpoint_dir, slurm_max_num_timeout=30
105
+ )
106
+
107
+ num_gpus_per_node = args.ngpus
108
+ nodes = args.nodes
109
+ timeout_min = args.timeout
110
+ args.world_size = args.nodes * args.ngpus
111
+
112
+ partition = args.partition
113
+
114
+ kwargs = {
115
+ "gpus_per_node": num_gpus_per_node,
116
+ "tasks_per_node": num_gpus_per_node, # one task per GPU
117
+ "cpus_per_task": 10,
118
+ "nodes": nodes,
119
+ "timeout_min": timeout_min, # max is 60 * 72
120
+ # Below are cluster dependent parameters
121
+ "slurm_partition": partition,
122
+ "slurm_signal_delay_s": 120,
123
+ }
124
+
125
+ if args.enforce_host_memory:
126
+ kwargs["mem_gb"] = (40 * num_gpus_per_node,)
127
+
128
+ if args.use_volta32:
129
+ kwargs["slurm_constraint"] = "volta32gb"
130
+
131
+ if args.comment:
132
+ kwargs["slurm_comment"] = args.comment
133
+
134
+ executor.update_parameters(
135
+ **kwargs,
136
+ )
137
+
138
+ executor.update_parameters(name="lra")
139
+
140
+ args.dist_url = get_init_file().as_uri()
141
+ args.temp_file = str(get_init_file())
142
+
143
+ trainer = Trainer(args)
144
+ job = executor.submit(trainer)
145
+
146
+ print(f"Submitted job_id: {job.job_id}")
147
+ print(f"Logs and checkpoints will be saved at: {args.checkpoint_dir}")
148
+ with open(Path(f"{args.checkpoint_dir}") / Path("jobs.txt"), "a") as jobfile:
149
+ jobfile.write(f"{job.job_id}\n")
150
+
151
+
152
+ if __name__ == "__main__":
153
+ main()
.venv/lib/python3.11/site-packages/xformers/benchmarks/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (192 Bytes). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_attn_decoding.cpython-311.pyc ADDED
Binary file (20.8 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_core.cpython-311.pyc ADDED
Binary file (9.66 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_indexing.cpython-311.pyc ADDED
Binary file (8.68 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_mem_eff_attention.cpython-311.pyc ADDED
Binary file (12.4 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_merge_attentions.cpython-311.pyc ADDED
Binary file (5.45 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_multi_head_dispatch.cpython-311.pyc ADDED
Binary file (4.26 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_nystrom_utils.cpython-311.pyc ADDED
Binary file (4.61 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_revnet.cpython-311.pyc ADDED
Binary file (3.93 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_sddmm.cpython-311.pyc ADDED
Binary file (4.74 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_sequence_parallel_fused.cpython-311.pyc ADDED
Binary file (24.5 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_sp24.cpython-311.pyc ADDED
Binary file (9.33 kB). View file