yangheng committed on
Commit 50b3e1a · verified · 1 parent: bcd9992

Upload 13 files

.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,36 @@
1
  ---
2
- license: mit
3
  ---
1
  ---
2
+ metrics:
3
+ - matthews_correlation
4
+ - f1
5
+ tags:
6
+ - biology
7
+ - medical
8
  ---
9
+ This is the official pre-trained model introduced in [DNABERT-2: Efficient Foundation Model and Benchmark For Multi-Species Genome
10
+ ](https://arxiv.org/pdf/2306.15006.pdf).
11
+
12
+ DNABERT-2 is a transformer-based genome foundation model trained on multi-species genomes.
13
+
14
+ To load the model from Hugging Face:
15
+ ```python
16
+ import torch
17
+ from transformers import AutoTokenizer, AutoModel
18
+
19
+ tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
20
+ model = AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
21
+ ```
22
+
23
+ To compute the embedding of a DNA sequence:
24
+ ```python
25
+ dna = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC"
26
+ inputs = tokenizer(dna, return_tensors='pt')["input_ids"]
27
+ hidden_states = model(inputs)[0] # [1, sequence_length, 768]
28
+
29
+ # embedding with mean pooling
30
+ embedding_mean = torch.mean(hidden_states[0], dim=0)
31
+ print(embedding_mean.shape) # expected: torch.Size([768])
32
+
33
+ # embedding with max pooling
34
+ embedding_max = torch.max(hidden_states[0], dim=0)[0]
35
+ print(embedding_max.shape) # expected: torch.Size([768])
36
+ ```
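To embed a batch of sequences at once, a minimal sketch along the same lines (variable names are illustrative; the mask-aware mean pooling keeps padding tokens from diluting the average):

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
model = AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)

sequences = ["ACGTAGCATCGGATCTATCTATCGACACTTGG", "GATTACAGATTACA"]
batch = tokenizer(sequences, return_tensors="pt", padding=True)

with torch.no_grad():
    hidden_states = model(batch["input_ids"], attention_mask=batch["attention_mask"])[0]

# mask-aware mean pooling: zero out padded positions, then divide by each true length
mask = batch["attention_mask"].unsqueeze(-1).float()   # [batch, seq_len, 1]
embeddings = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
print(embeddings.shape)                                # torch.Size([2, 768])
```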
__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # -*- coding: utf-8 -*-
2
+ # file: __init__.py
3
+ # time: 11:52 27/04/2024
4
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
5
+ # github: https://github.com/yangheng95
6
+ # huggingface: https://huggingface.co/yangheng
7
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
8
+ # Copyright (C) 2019-2024. All Rights Reserved.
bert_layers.py ADDED
@@ -0,0 +1,942 @@
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
5
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
6
+ # Copyright (c) 2022, Tri Dao.
7
+
8
+ import copy
9
+ import logging
10
+ import math
11
+ import warnings
12
+ from typing import List, Optional, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from einops import rearrange
17
+ from transformers.activations import ACT2FN
18
+ from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
19
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel
20
+
21
+ from .bert_padding import (
22
+ index_first_axis,
23
+ index_put_first_axis,
24
+ pad_input,
25
+ unpad_input,
26
+ unpad_input_only,
27
+ )
28
+
29
+ # try:
30
+ # from .flash_attn_triton import flash_attn_qkvpacked_func
31
+ # except ImportError as e:
32
+ # flash_attn_qkvpacked_func = None
33
+ flash_attn_qkvpacked_func = None
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ class BertEmbeddings(nn.Module):
38
+ def __init__(self, config):
39
+ super().__init__()
40
+ self.word_embeddings = nn.Embedding(
41
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
42
+ )
43
+ # ALiBi doesn't use position embeddings
44
+ self.token_type_embeddings = nn.Embedding(
45
+ config.type_vocab_size, config.hidden_size
46
+ )
47
+
48
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model
49
+ # variable name and be able to load any TensorFlow checkpoint file
50
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
51
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
52
+ self.register_buffer(
53
+ "token_type_ids",
54
+ torch.zeros(config.max_position_embeddings, dtype=torch.long),
55
+ persistent=False,
56
+ )
57
+
58
+ def forward(
59
+ self,
60
+ input_ids: Optional[torch.LongTensor] = None,
61
+ token_type_ids: Optional[torch.LongTensor] = None,
62
+ position_ids: Optional[torch.LongTensor] = None,
63
+ inputs_embeds: Optional[torch.FloatTensor] = None,
64
+ past_key_values_length: int = 0,
65
+ ) -> torch.Tensor:
66
+ if (input_ids is not None) == (inputs_embeds is not None):
67
+ raise ValueError("Must specify either input_ids or inputs_embeds!")
68
+ if input_ids is not None:
69
+ input_shape = input_ids.size()
70
+ else:
71
+ assert inputs_embeds is not None # just for type checking
72
+ input_shape = inputs_embeds.size()[:-1]
73
+
74
+ seq_length = input_shape[1]
75
+
76
+ if position_ids is None:
77
+ # great! ALiBi
78
+ pass
79
+
80
+ # Setting the token_type_ids to the registered buffer in constructor
81
+ # where it is all zeros, which usually occurs when it's auto-generated;
82
+ # registered buffer helps users when tracing the model without passing
83
+ # token_type_ids, solves issue #5664
84
+ if token_type_ids is None:
85
+ if hasattr(self, "token_type_ids"):
86
+ assert isinstance(self.token_type_ids, torch.LongTensor)
87
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
88
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
89
+ input_shape[0], seq_length
90
+ )
91
+ token_type_ids = buffered_token_type_ids_expanded # type: ignore
92
+ else:
93
+ token_type_ids = torch.zeros(
94
+ input_shape, # type: ignore
95
+ dtype=torch.long,
96
+ device=self.word_embeddings.device,
97
+ ) # type: ignore # yapf: disable
98
+
99
+ if inputs_embeds is None:
100
+ inputs_embeds = self.word_embeddings(input_ids)
101
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
102
+
103
+ embeddings = inputs_embeds + token_type_embeddings
104
+ # no position embeddings! ALiBi
105
+ embeddings = self.LayerNorm(embeddings)
106
+ embeddings = self.dropout(embeddings)
107
+ return embeddings
108
+
109
+
110
+ class BertUnpadSelfAttention(nn.Module):
111
+ def __init__(self, config):
112
+ super().__init__()
113
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
114
+ config, "embedding_size"
115
+ ):
116
+ raise ValueError(
117
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
118
+ f"heads ({config.num_attention_heads})"
119
+ )
120
+
121
+ self.num_attention_heads = config.num_attention_heads
122
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
123
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
124
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
125
+ self.p_dropout = config.attention_probs_dropout_prob
126
+ self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
127
+
128
+ # Warn if defaulting to pytorch because of import issues
129
+ if flash_attn_qkvpacked_func is None:
130
+ warnings.warn(
131
+ "Unable to import Triton; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model)."
132
+ )
133
+
134
+ def forward(
135
+ self,
136
+ hidden_states: torch.Tensor,
137
+ cu_seqlens: torch.Tensor,
138
+ max_seqlen_in_batch: int,
139
+ indices: torch.Tensor,
140
+ attn_mask: torch.Tensor,
141
+ bias: torch.Tensor,
142
+ ) -> torch.Tensor:
143
+ """Perform self-attention.
144
+
145
+ If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
146
+ implementation of self-attention.
147
+
148
+ The arguments are unpadded, and our implementations of attention require padded arguments,
149
+ so we first call `pad_input`. Once we compute attention, we re-unpad our outputs for the other layers.
150
+ The pad/unpad operations add overhead, but skipping the pad tokens in the feed-forward layers saves more compute than that overhead costs.
151
+ It is possible to write an unpadded implementation of attention (in Triton and PyTorch), which we will eventually do.
152
+
153
+ Args:
154
+ hidden_states: (total_nnz, dim)
155
+ cu_seqlens: (batch + 1,)
156
+ max_seqlen_in_batch: int
157
+ indices: (total_nnz,)
158
+ attn_mask: (batch, max_seqlen_in_batch)
159
+ bias: (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
160
+
161
+ Returns:
162
+ attention: (total_nnz, dim)
163
+ """
164
+ qkv = self.Wqkv(hidden_states)
165
+ qkv = pad_input(
166
+ qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen_in_batch
167
+ ) # batch, max_seqlen_in_batch, thd
168
+ qkv = rearrange(
169
+ qkv, "b s (t h d) -> b s t h d", t=3, h=self.num_attention_heads
170
+ )
171
+ if self.p_dropout or flash_attn_qkvpacked_func is None:
172
+ # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
173
+ q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
174
+ k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
175
+ v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3) # b h s d
176
+ attention_scores = torch.matmul(q, k) / math.sqrt(self.attention_head_size)
177
+ attention_scores = attention_scores + bias
178
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
179
+ attention_probs = self.dropout(attention_probs)
180
+ attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3) # b s h d
181
+ else:
182
+ # Triton implementation only supports 0 attention dropout
183
+ convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
184
+ if convert_dtype:
185
+ # Triton implementation only supports fp16 and bf16
186
+ orig_dtype = qkv.dtype
187
+ qkv = qkv.to(torch.float16)
188
+ bias_dtype = bias.dtype
189
+ bias = bias.to(torch.float16)
190
+ attention = flash_attn_qkvpacked_func(qkv, bias)
191
+ attention = attention.to(orig_dtype)
192
+ bias = bias.to(bias_dtype)
193
+ else:
194
+ attention = flash_attn_qkvpacked_func(qkv, bias)
195
+
196
+ # attn_mask is 1 for attend and 0 for don't
197
+ attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
198
+ return rearrange(attention, "nnz h d -> nnz (h d)")
199
+
200
+
201
+ # Copy of the transformers library's BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
202
+ class BertSelfOutput(nn.Module):
203
+ def __init__(self, config):
204
+ super().__init__()
205
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
206
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
207
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
208
+
209
+ def forward(
210
+ self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
211
+ ) -> torch.Tensor:
212
+ hidden_states = self.dense(hidden_states)
213
+ hidden_states = self.dropout(hidden_states)
214
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
215
+ return hidden_states
216
+
217
+
218
+ class BertUnpadAttention(nn.Module):
219
+ """Chains attention, Dropout, and LayerNorm for Mosaic BERT."""
220
+
221
+ def __init__(self, config):
222
+ super().__init__()
223
+ self.self = BertUnpadSelfAttention(config)
224
+ self.output = BertSelfOutput(config)
225
+
226
+ def forward(
227
+ self,
228
+ input_tensor: torch.Tensor,
229
+ cu_seqlens: torch.Tensor,
230
+ max_s: int,
231
+ subset_idx: Optional[torch.Tensor] = None,
232
+ indices: Optional[torch.Tensor] = None,
233
+ attn_mask: Optional[torch.Tensor] = None,
234
+ bias: Optional[torch.Tensor] = None,
235
+ ) -> torch.Tensor:
236
+ """Forward pass for scaled self-attention without padding.
237
+
238
+ Arguments:
239
+ input_tensor: (total_nnz, dim)
240
+ cu_seqlens: (batch + 1,)
241
+ max_s: int
242
+ subset_idx: () set of indices whose values we care about at the end of the layer
243
+ (e.g., the masked tokens, if this is the final layer).
244
+ indices: None or (total_nnz,)
245
+ attn_mask: None or (batch, max_seqlen_in_batch)
246
+ bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
247
+ """
248
+ self_output = self.self(
249
+ input_tensor, cu_seqlens, max_s, indices, attn_mask, bias
250
+ )
251
+ if subset_idx is not None:
252
+ return self.output(
253
+ index_first_axis(self_output, subset_idx),
254
+ index_first_axis(input_tensor, subset_idx),
255
+ )
256
+ else:
257
+ return self.output(self_output, input_tensor)
258
+
259
+
260
+ class BertGatedLinearUnitMLP(nn.Module):
261
+ """Applies the FFN at the end of each Mosaic BERT layer.
262
+
263
+ Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
264
+ and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality, but
265
+ introduces Gated Linear Units.
266
+
267
+ Note: Mosaic BERT adds parameters in order to implement Gated Linear Units. To keep parameter count consistent with that of a
268
+ standard Hugging Face BERT, scale down `config.intermediate_size` by 2/3. For example, a Mosaic BERT constructed with
269
+ `config.intermediate_size=2048` will have the same parameter footprint as its Hugging Face BERT counterpart constructed
270
+ with `config.intermediate_size=3072`.
271
+ However, in most cases it will not be necessary to adjust `config.intermediate_size` since, despite the increased
272
+ parameter size, Mosaic BERT typically offers a net higher throughput than a Hugging Face BERT built from the same `config`.
273
+ """
274
+
275
+ def __init__(self, config):
276
+ super().__init__()
277
+ self.config = config
278
+ self.gated_layers = nn.Linear(
279
+ config.hidden_size, config.intermediate_size * 2, bias=False
280
+ )
281
+ self.act = nn.GELU(approximate="none")
282
+ self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
283
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
284
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
285
+
286
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
287
+ """Compute new hidden states from current hidden states.
288
+
289
+ Args:
290
+ hidden_states (torch.Tensor): The (unpadded) hidden states from
291
+ the attention layer [nnz, dim].
292
+ """
293
+ residual_connection = hidden_states
294
+ # compute the activation
295
+ hidden_states = self.gated_layers(hidden_states)
296
+ gated = hidden_states[:, : self.config.intermediate_size]
297
+ non_gated = hidden_states[:, self.config.intermediate_size :]
298
+ hidden_states = self.act(gated) * non_gated
299
+ hidden_states = self.dropout(hidden_states)
300
+ # multiply by the second matrix
301
+ hidden_states = self.wo(hidden_states)
302
+ # add the residual connection and post-LN
303
+ hidden_states = self.layernorm(hidden_states + residual_connection)
304
+ return hidden_states
305
+
306
+
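As a quick check of the 2/3 rule described in `BertGatedLinearUnitMLP`'s docstring above, a worked count of the FFN weight matrices (illustrative hidden size; weights only, biases ignored):

```python
# GLU FFN: gated_layers maps hidden -> 2*intermediate, wo maps intermediate -> hidden
hidden = 768
glu_ffn = hidden * (2 * 2048) + 2048 * hidden   # intermediate_size = 2048
std_ffn = hidden * 3072 + 3072 * hidden         # standard BERT at intermediate_size = 3072
assert glu_ffn == std_ffn == 4_718_592
```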
307
+ class BertLayer(nn.Module):
308
+ """Composes the Mosaic BERT attention and FFN blocks into a single layer."""
309
+
310
+ def __init__(self, config):
311
+ super(BertLayer, self).__init__()
312
+ self.attention = BertUnpadAttention(config)
313
+ self.mlp = BertGatedLinearUnitMLP(config)
314
+
315
+ def forward(
316
+ self,
317
+ hidden_states: torch.Tensor,
318
+ cu_seqlens: torch.Tensor,
319
+ seqlen: int,
320
+ subset_idx: Optional[torch.Tensor] = None,
321
+ indices: Optional[torch.Tensor] = None,
322
+ attn_mask: Optional[torch.Tensor] = None,
323
+ bias: Optional[torch.Tensor] = None,
324
+ ) -> torch.Tensor:
325
+ """Forward pass for a BERT layer, including both attention and MLP.
326
+
327
+ Args:
328
+ hidden_states: (total_nnz, dim)
329
+ cu_seqlens: (batch + 1,)
330
+ seqlen: int
331
+ subset_idx: () set of indices whose values we care about at the end of the layer
332
+ (e.g., the masked tokens, if this is the final layer).
333
+ indices: None or (total_nnz,)
334
+ attn_mask: None or (batch, max_seqlen_in_batch)
335
+ bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
336
+ """
337
+ attention_output = self.attention(
338
+ hidden_states, cu_seqlens, seqlen, subset_idx, indices, attn_mask, bias
339
+ )
340
+ layer_output = self.mlp(attention_output)
341
+ return layer_output
342
+
343
+
344
+ class BertEncoder(nn.Module):
345
+ """A stack of BERT layers providing the backbone of Mosaic BERT.
346
+
347
+ This module is modeled after the Hugging Face BERT's :class:`~transformers.model.bert.modeling_bert.BertEncoder`,
348
+ but with substantial modifications to implement unpadding and ALiBi.
349
+
350
+ Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary computation
351
+ at padded tokens, and pre-computes attention biases to implement ALiBi.
352
+ """
353
+
354
+ def __init__(self, config):
355
+ super().__init__()
356
+ layer = BertLayer(config)
357
+ self.layer = nn.ModuleList(
358
+ [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]
359
+ )
360
+
361
+ self.num_attention_heads = config.num_attention_heads
362
+
363
+ # The alibi mask will be dynamically expanded if it is too small for
364
+ # the input the model receives. But it generally helps to initialize it
365
+ # to a reasonably large size to help pre-allocate CUDA memory.
366
+ # The default `alibi_starting_size` is 512.
367
+ self._current_alibi_size = int(config.alibi_starting_size)
368
+ self.alibi = torch.zeros(
369
+ (
370
+ 1,
371
+ self.num_attention_heads,
372
+ self._current_alibi_size,
373
+ self._current_alibi_size,
374
+ )
375
+ )
376
+ self.rebuild_alibi_tensor(size=config.alibi_starting_size)
377
+
378
+ def rebuild_alibi_tensor(
379
+ self, size: int, device: Optional[Union[torch.device, str]] = None
380
+ ):
381
+ # Alibi
382
+ # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
383
+ # In the causal case, you can exploit the fact that softmax is invariant to a uniform translation
384
+ # of the logits, which makes the math work out *after* applying causal masking. If no causal masking
385
+ # will be applied, it is necessary to construct the diagonal mask.
386
+ n_heads = self.num_attention_heads
387
+
388
+ def _get_alibi_head_slopes(n_heads: int) -> List[float]:
389
+ def get_slopes_power_of_2(n_heads: int) -> List[float]:
390
+ start = 2 ** (-(2 ** -(math.log2(n_heads) - 3)))
391
+ ratio = start
392
+ return [start * ratio**i for i in range(n_heads)]
393
+
394
+ # In the paper, they only train models that have 2^a heads for some a. This function
395
+ # has some good properties that only occur when the input is a power of 2. To
396
+ # maintain that even when the number of heads is not a power of 2, we use a
397
+ # workaround.
398
+ if math.log2(n_heads).is_integer():
399
+ return get_slopes_power_of_2(n_heads)
400
+
401
+ closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
402
+ slopes_a = get_slopes_power_of_2(closest_power_of_2)
403
+ slopes_b = _get_alibi_head_slopes(2 * closest_power_of_2)
404
+ slopes_b = slopes_b[0::2][: n_heads - closest_power_of_2]
405
+ return slopes_a + slopes_b
406
+
407
+ context_position = torch.arange(size, device=device)[:, None]
408
+ memory_position = torch.arange(size, device=device)[None, :]
409
+ relative_position = torch.abs(memory_position - context_position)
410
+ # [n_heads, max_token_length, max_token_length]
411
+ relative_position = relative_position.unsqueeze(0).expand(n_heads, -1, -1)
412
+ slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device)
413
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position
414
+ # [1, n_heads, max_token_length, max_token_length]
415
+ alibi = alibi.unsqueeze(0)
416
+ assert alibi.shape == torch.Size([1, n_heads, size, size])
417
+
418
+ self._current_alibi_size = size
419
+ self.alibi = alibi
420
+
421
+ def forward(
422
+ self,
423
+ hidden_states: torch.Tensor,
424
+ attention_mask: torch.Tensor,
425
+ output_all_encoded_layers: Optional[bool] = True,
426
+ subset_mask: Optional[torch.Tensor] = None,
427
+ ) -> List[torch.Tensor]:
428
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
429
+ extended_attention_mask = extended_attention_mask.to(
430
+ dtype=torch.float32
431
+ ) # fp16 compatibility
432
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
433
+
434
+ attention_mask_bool = attention_mask.bool()
435
+ batch, seqlen = hidden_states.shape[:2]
436
+ # Unpad inputs and mask. It will remove tokens that are padded.
437
+ # Assume ntokens is total number of tokens (padded and non-padded)
438
+ # and ntokens_unpad is total number of non-padded tokens.
439
+ # Then unpadding performs the following compression of the inputs:
440
+ # hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden]
441
+ hidden_states, indices, cu_seqlens, _ = unpad_input(
442
+ hidden_states, attention_mask_bool
443
+ )
444
+
445
+ # Add alibi matrix to extended_attention_mask
446
+ if self._current_alibi_size < seqlen:
447
+ # Rebuild the alibi tensor when needed
448
+ warnings.warn(
449
+ f"Increasing alibi size from {self._current_alibi_size} to {seqlen}"
450
+ )
451
+ self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device)
452
+ elif self.alibi.device != hidden_states.device:
453
+ # Device catch-up
454
+ self.alibi = self.alibi.to(hidden_states.device)
455
+ alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
456
+ attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
457
+ alibi_attn_mask = attn_bias + alibi_bias
458
+
459
+ all_encoder_layers = []
460
+ if subset_mask is None:
461
+ for layer_module in self.layer:
462
+ hidden_states = layer_module(
463
+ hidden_states,
464
+ cu_seqlens,
465
+ seqlen,
466
+ None,
467
+ indices,
468
+ attn_mask=attention_mask,
469
+ bias=alibi_attn_mask,
470
+ )
471
+ if output_all_encoded_layers:
472
+ all_encoder_layers.append(hidden_states)
473
+ # Pad inputs and mask. It will insert back zero-padded tokens.
474
+ # Assume ntokens is total number of tokens (padded and non-padded)
475
+ # and ntokens_unpad is total number of non-padded tokens.
476
+ # Then padding performs the following de-compression:
477
+ # hidden_states[ntokens_unpad,hidden] -> hidden_states[ntokens,hidden]
478
+ hidden_states = pad_input(hidden_states, indices, batch, seqlen)
479
+ else:
480
+ for i in range(len(self.layer) - 1):
481
+ layer_module = self.layer[i]
482
+ hidden_states = layer_module(
483
+ hidden_states,
484
+ cu_seqlens,
485
+ seqlen,
486
+ None,
487
+ indices,
488
+ attn_mask=attention_mask,
489
+ bias=alibi_attn_mask,
490
+ )
491
+ if output_all_encoded_layers:
492
+ all_encoder_layers.append(hidden_states)
493
+ subset_idx = torch.nonzero(
494
+ subset_mask[attention_mask_bool], as_tuple=False
495
+ ).flatten()
496
+ hidden_states = self.layer[-1](
497
+ hidden_states,
498
+ cu_seqlens,
499
+ seqlen,
500
+ subset_idx=subset_idx,
501
+ indices=indices,
502
+ attn_mask=attention_mask,
503
+ bias=alibi_attn_mask,
504
+ )
505
+
506
+ if not output_all_encoded_layers:
507
+ all_encoder_layers.append(hidden_states)
508
+ return all_encoder_layers
509
+
510
+
511
+ class BertPooler(nn.Module):
512
+ def __init__(self, config):
513
+ super(BertPooler, self).__init__()
514
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
515
+ self.activation = nn.Tanh()
516
+
517
+ def forward(
518
+ self, hidden_states: torch.Tensor, pool: Optional[bool] = True
519
+ ) -> torch.Tensor:
520
+ # We "pool" the model by simply taking the hidden state corresponding
521
+ # to the first token.
522
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
523
+ pooled_output = self.dense(first_token_tensor)
524
+ pooled_output = self.activation(pooled_output)
525
+ return pooled_output
526
+
527
+
528
+ class BertPredictionHeadTransform(nn.Module):
529
+ def __init__(self, config):
530
+ super().__init__()
531
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
532
+ if isinstance(config.hidden_act, str):
533
+ self.transform_act_fn = ACT2FN[config.hidden_act]
534
+ else:
535
+ self.transform_act_fn = config.hidden_act
536
+ self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12)
537
+
538
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
539
+ hidden_states = self.dense(hidden_states)
540
+ hidden_states = self.transform_act_fn(hidden_states)
541
+ hidden_states = self.LayerNorm(hidden_states)
542
+ return hidden_states
543
+
544
+
545
+ class BertModel(BertPreTrainedModel):
546
+ """Overall BERT model.
547
+
548
+ Args:
549
+ config: a BertConfig class instance with the configuration to build a new model
550
+
551
+ Inputs:
552
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
553
+ with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
554
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
555
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
556
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
557
+ a `sentence B` token (see BERT paper for more details).
558
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
559
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
560
+ input sequence length in the current batch. It's the mask that we typically use for attention when
561
+ a batch has varying length sentences.
562
+ `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `False`.
563
+
564
+ Outputs: Tuple of (encoded_layers, pooled_output)
565
+ `encoded_layers`: controlled by `output_all_encoded_layers` argument:
566
+ - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
567
+ of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
568
+ encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
569
+ - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
570
+ to the last attention block of shape [batch_size, sequence_length, hidden_size],
571
+ `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
572
+ classifier pretrained on top of the hidden state associated with the first token of the
573
+ input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
574
+
575
+ Example usage:
576
+ ```python
577
+ # Already been converted into WordPiece token ids
578
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
579
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
580
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
581
+ config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
582
+ num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
583
+ model = BertModel(config=config)
584
+ all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
585
+ ```
586
+ """
587
+
588
+ def __init__(self, config, add_pooling_layer=True):
589
+ super(BertModel, self).__init__(config)
590
+ self.embeddings = BertEmbeddings(config)
591
+ self.encoder = BertEncoder(config)
592
+ self.pooler = BertPooler(config) if add_pooling_layer else None
593
+ self.post_init()
594
+
595
+ def get_input_embeddings(self):
596
+ return self.embeddings.word_embeddings
597
+
598
+ def set_input_embeddings(self, value):
599
+ self.embeddings.word_embeddings = value
600
+
601
+ def forward(
602
+ self,
603
+ input_ids: torch.Tensor,
604
+ token_type_ids: Optional[torch.Tensor] = None,
605
+ attention_mask: Optional[torch.Tensor] = None,
606
+ position_ids: Optional[torch.Tensor] = None,
607
+ output_all_encoded_layers: Optional[bool] = False,
608
+ masked_tokens_mask: Optional[torch.Tensor] = None,
609
+ **kwargs,
610
+ ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
611
+ if attention_mask is None:
612
+ attention_mask = torch.ones_like(input_ids)
613
+ if token_type_ids is None:
614
+ token_type_ids = torch.zeros_like(input_ids)
615
+
616
+ embedding_output = self.embeddings(input_ids, token_type_ids, position_ids)
617
+
618
+ subset_mask = []
619
+ first_col_mask = []
620
+
621
+ if masked_tokens_mask is None:
622
+ subset_mask = None
623
+ else:
624
+ first_col_mask = torch.zeros_like(masked_tokens_mask)
625
+ first_col_mask[:, 0] = True
626
+ subset_mask = masked_tokens_mask | first_col_mask
627
+
628
+ encoder_outputs = self.encoder(
629
+ embedding_output,
630
+ attention_mask,
631
+ output_all_encoded_layers=output_all_encoded_layers,
632
+ subset_mask=subset_mask,
633
+ )
634
+
635
+ if masked_tokens_mask is None:
636
+ sequence_output = encoder_outputs[-1]
637
+ pooled_output = (
638
+ self.pooler(sequence_output) if self.pooler is not None else None
639
+ )
640
+ else:
641
+ # TD [2022-03-01]: the indexing here is very tricky.
642
+ attention_mask_bool = attention_mask.bool()
643
+ subset_idx = subset_mask[attention_mask_bool] # type: ignore
644
+ sequence_output = encoder_outputs[-1][
645
+ masked_tokens_mask[attention_mask_bool][subset_idx]
646
+ ]
647
+ if self.pooler is not None:
648
+ pool_input = encoder_outputs[-1][
649
+ first_col_mask[attention_mask_bool][subset_idx]
650
+ ]
651
+ pooled_output = self.pooler(pool_input, pool=False)
652
+ else:
653
+ pooled_output = None
654
+
655
+ if not output_all_encoded_layers:
656
+ encoder_outputs = sequence_output
657
+
658
+ if self.pooler is not None:
659
+ return encoder_outputs, pooled_output
660
+
661
+ return encoder_outputs, None
662
+
663
+
664
+ ###################
665
+ # Bert Heads
666
+ ###################
667
+ class BertLMPredictionHead(nn.Module):
668
+ def __init__(self, config, bert_model_embedding_weights):
669
+ super().__init__()
670
+ self.transform = BertPredictionHeadTransform(config)
671
+ # The output weights are the same as the input embeddings, but there is
672
+ # an output-only bias for each token.
673
+ self.decoder = nn.Linear(
674
+ bert_model_embedding_weights.size(1), bert_model_embedding_weights.size(0)
675
+ )
676
+ self.decoder.weight = bert_model_embedding_weights
677
+
678
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
679
+ hidden_states = self.transform(hidden_states)
680
+ hidden_states = self.decoder(hidden_states)
681
+ return hidden_states
682
+
683
+
684
+ class BertOnlyMLMHead(nn.Module):
685
+ def __init__(self, config, bert_model_embedding_weights):
686
+ super().__init__()
687
+ self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
688
+
689
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
690
+ prediction_scores = self.predictions(sequence_output)
691
+ return prediction_scores
692
+
693
+
694
+ class BertOnlyNSPHead(nn.Module):
695
+ def __init__(self, config):
696
+ super().__init__()
697
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
698
+
699
+ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
700
+ seq_relationship_score = self.seq_relationship(pooled_output)
701
+ return seq_relationship_score
702
+
703
+
704
+ class BertForMaskedLM(BertPreTrainedModel):
705
+ def __init__(self, config):
706
+ super().__init__(config)
707
+
708
+ if config.is_decoder:
709
+ warnings.warn(
710
+ "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
711
+ "bi-directional self-attention."
712
+ )
713
+
714
+ self.bert = BertModel(config, add_pooling_layer=False)
715
+ self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
716
+
717
+ # Initialize weights and apply final processing
718
+ self.post_init()
719
+
720
+ def get_output_embeddings(self):
721
+ return self.cls.predictions.decoder
722
+
723
+ def set_output_embeddings(self, new_embeddings):
724
+ self.cls.predictions.decoder = new_embeddings
725
+
726
+ def forward(
727
+ self,
728
+ input_ids: Optional[torch.Tensor] = None,
729
+ attention_mask: Optional[torch.Tensor] = None,
730
+ token_type_ids: Optional[torch.Tensor] = None,
731
+ position_ids: Optional[torch.Tensor] = None,
732
+ head_mask: Optional[torch.Tensor] = None,
733
+ inputs_embeds: Optional[torch.Tensor] = None,
734
+ encoder_hidden_states: Optional[torch.Tensor] = None,
735
+ encoder_attention_mask: Optional[torch.Tensor] = None,
736
+ labels: Optional[torch.Tensor] = None,
737
+ output_attentions: Optional[bool] = None,
738
+ output_hidden_states: Optional[bool] = None,
739
+ return_dict: Optional[bool] = None,
740
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
741
+ # labels should be a `torch.LongTensor` of shape
742
+ # `(batch_size, sequence_length)`. These are used for computing the
743
+ # masked language modeling loss.
744
+ #
745
+ # Indices should be in `[-100, 0, ..., config.vocab_size]` (see
746
+ # `input_ids` docstring) Tokens with indices set to `-100` are ignored
747
+ # (masked), the loss is only computed for the tokens with labels in `[0,
748
+ # ..., config.vocab_size]`
749
+ #
750
+ # Prediction scores are only computed for masked tokens and the (bs,
751
+ # seqlen) dimensions are flattened
752
+ if (input_ids is not None) == (inputs_embeds is not None):
753
+ raise ValueError("Must specify either input_ids or inputs_embeds!")
754
+
755
+ if labels is None:
756
+ masked_tokens_mask = None
757
+ else:
758
+ masked_tokens_mask = labels > 0
759
+
760
+ return_dict = (
761
+ return_dict if return_dict is not None else self.config.use_return_dict
762
+ )
763
+
764
+ outputs = self.bert(
765
+ input_ids,
766
+ attention_mask=attention_mask,
767
+ token_type_ids=token_type_ids,
768
+ position_ids=position_ids,
769
+ head_mask=head_mask,
770
+ inputs_embeds=inputs_embeds,
771
+ encoder_hidden_states=encoder_hidden_states,
772
+ encoder_attention_mask=encoder_attention_mask,
773
+ output_attentions=output_attentions,
774
+ output_hidden_states=output_hidden_states,
775
+ return_dict=return_dict,
776
+ masked_tokens_mask=masked_tokens_mask,
777
+ )
778
+
779
+ sequence_output = outputs[0]
780
+ prediction_scores = self.cls(sequence_output)
781
+
782
+ loss = None
783
+ if labels is not None:
784
+ # Compute loss
785
+ loss_fct = nn.CrossEntropyLoss()
786
+ masked_token_idx = torch.nonzero(
787
+ labels.flatten() > 0, as_tuple=False
788
+ ).flatten()
789
+ loss = loss_fct(prediction_scores, labels.flatten()[masked_token_idx])
790
+
791
+ assert input_ids is not None, "Coding error; please open an issue"
792
+ batch, seqlen = input_ids.shape[:2]
793
+ prediction_scores = rearrange(
794
+ index_put_first_axis(
795
+ prediction_scores, masked_token_idx, batch * seqlen
796
+ ),
797
+ "(b s) d -> b s d",
798
+ b=batch,
799
+ )
800
+
801
+ if not return_dict:
802
+ output = (prediction_scores,) + outputs[2:]
803
+ return ((loss,) + output) if loss is not None else output
804
+
805
+ return MaskedLMOutput(
806
+ loss=loss,
807
+ logits=prediction_scores,
808
+ hidden_states=outputs[0],
809
+ attentions=None,
810
+ )
811
+
812
+ def prepare_inputs_for_generation(
813
+ self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **model_kwargs
814
+ ):
815
+ input_shape = input_ids.shape
816
+ effective_batch_size = input_shape[0]
817
+
818
+ # add a dummy token
819
+ if self.config.pad_token_id is None:
820
+ raise ValueError("The PAD token should be defined for generation")
821
+
822
+ attention_mask = torch.cat(
823
+ [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],
824
+ dim=-1,
825
+ )
826
+ dummy_token = torch.full(
827
+ (effective_batch_size, 1),
828
+ self.config.pad_token_id,
829
+ dtype=torch.long,
830
+ device=input_ids.device,
831
+ )
832
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
833
+
834
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
835
+
836
+
837
+ class BertForNextSentencePrediction(BertPreTrainedModel):
838
+ # TBD: Push in future commit
839
+ pass
840
+
841
+
842
+ class BertForSequenceClassification(BertPreTrainedModel):
843
+ """Bert Model transformer with a sequence classification/regression head.
844
+
845
+ This head is just a linear layer on top of the pooled output. Used for,
846
+ e.g., GLUE tasks.
847
+ """
848
+
849
+ def __init__(self, config):
850
+ super().__init__(config)
851
+ self.num_labels = config.num_labels
852
+ self.config = config
853
+
854
+ self.bert = BertModel(config)
855
+ classifier_dropout = (
856
+ config.classifier_dropout
857
+ if config.classifier_dropout is not None
858
+ else config.hidden_dropout_prob
859
+ )
860
+ self.dropout = nn.Dropout(classifier_dropout)
861
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
862
+
863
+ # Initialize weights and apply final processing
864
+ self.post_init()
865
+
866
+ def forward(
867
+ self,
868
+ input_ids: Optional[torch.Tensor] = None,
869
+ attention_mask: Optional[torch.Tensor] = None,
870
+ token_type_ids: Optional[torch.Tensor] = None,
871
+ position_ids: Optional[torch.Tensor] = None,
872
+ head_mask: Optional[torch.Tensor] = None,
873
+ inputs_embeds: Optional[torch.Tensor] = None,
874
+ labels: Optional[torch.Tensor] = None,
875
+ output_attentions: Optional[bool] = None,
876
+ output_hidden_states: Optional[bool] = None,
877
+ return_dict: Optional[bool] = None,
878
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
879
+ # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
880
+ # Labels for computing the sequence classification/regression loss.
881
+ # Indices should be in `[0, ..., config.num_labels - 1]`.
882
+ # If `config.num_labels == 1` a regression loss is computed
883
+ # (mean-square loss). If `config.num_labels > 1` a classification loss
884
+ # is computed (cross-entropy).
885
+
886
+ return_dict = (
887
+ return_dict if return_dict is not None else self.config.use_return_dict
888
+ )
889
+
890
+ outputs = self.bert(
891
+ input_ids,
892
+ attention_mask=attention_mask,
893
+ token_type_ids=token_type_ids,
894
+ position_ids=position_ids,
895
+ head_mask=head_mask,
896
+ inputs_embeds=inputs_embeds,
897
+ output_attentions=output_attentions,
898
+ output_hidden_states=output_hidden_states,
899
+ return_dict=return_dict,
900
+ )
901
+
902
+ pooled_output = outputs[1]
903
+
904
+ pooled_output = self.dropout(pooled_output)
905
+ logits = self.classifier(pooled_output)
906
+
907
+ loss = None
908
+ if labels is not None:
909
+ # Compute loss
910
+ if self.config.problem_type is None:
911
+ if self.num_labels == 1:
912
+ self.config.problem_type = "regression"
913
+ elif self.num_labels > 1 and (
914
+ labels.dtype == torch.long or labels.dtype == torch.int
915
+ ):
916
+ self.config.problem_type = "single_label_classification"
917
+ else:
918
+ self.config.problem_type = "multi_label_classification"
919
+
920
+ if self.config.problem_type == "regression":
921
+ loss_fct = nn.MSELoss()
922
+ if self.num_labels == 1:
923
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
924
+ else:
925
+ loss = loss_fct(logits, labels)
926
+ elif self.config.problem_type == "single_label_classification":
927
+ loss_fct = nn.CrossEntropyLoss()
928
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
929
+ elif self.config.problem_type == "multi_label_classification":
930
+ loss_fct = nn.BCEWithLogitsLoss()
931
+ loss = loss_fct(logits, labels)
932
+
933
+ if not return_dict:
934
+ output = (logits,) + outputs[2:]
935
+ return ((loss,) + output) if loss is not None else output
936
+
937
+ return SequenceClassifierOutput(
938
+ loss=loss,
939
+ logits=logits,
940
+ hidden_states=outputs[0],
941
+ attentions=None,
942
+ )
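As a standalone sketch of the bias that `BertEncoder.rebuild_alibi_tensor` pre-computes (head count and length are illustrative, and only the power-of-two slope case is shown; the file above also handles non-power-of-two head counts):

```python
import math
import torch

def alibi_slopes_pow2(n_heads: int):
    # geometric sequence; for n_heads = 8 this yields 1/2, 1/4, ..., 1/256
    start = 2 ** (-(2 ** -(math.log2(n_heads) - 3)))
    return [start ** (i + 1) for i in range(n_heads)]

n_heads, size = 8, 4
pos = torch.arange(size)
distance = (pos[None, :] - pos[:, None]).abs()   # symmetric relative distances
slopes = torch.tensor(alibi_slopes_pow2(n_heads))
bias = -slopes[:, None, None] * distance         # [n_heads, size, size]
print(bias.shape)  # torch.Size([8, 4, 4])
print(bias[0])     # head 0: zeros on the diagonal, -0.5 per token of distance
```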
bert_padding.py ADDED
@@ -0,0 +1,155 @@
+ # Copyright 2022 MosaicML Examples authors
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
+ # Which was adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
+
+
+ from typing import Tuple, cast
+
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange, repeat
+
+
+ class IndexFirstAxis(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, input: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
+         """Get just the values of `input` which are at `indices`.
+
+         Arguments:
+             ctx: the autograd context object
+             input: (b, ...) 2+ dimensional tensor
+             indices: (num_idx) 1D tensor
+         """
+         ctx.save_for_backward(indices)
+         assert input.ndim >= 2
+         ctx.first_axis_dim, other_shape = (
+             input.shape[0],
+             input.shape[1:],
+         )  # type: ignore
+         second_dim = other_shape.numel()  # product of sizes of all but first dimension
+         # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
+         return torch.gather(
+             rearrange(input, "b ... -> b (...)"),  # (b, ...) -> (b, second_dim)
+             0,
+             repeat(
+                 indices, "z -> z d", d=second_dim
+             ),  # (indices,) -> (indices, second_dim)
+         ).reshape(
+             -1, *other_shape
+         )  # (num_idx, ...)
+
+     @staticmethod
+     def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
+         (indices,) = ctx.saved_tensors
+         assert grad_output.ndim >= 2
+         other_shape = grad_output.shape[1:]
+         grad_output = rearrange(grad_output, "b ... -> b (...)")
+         grad_input = torch.zeros(
+             [ctx.first_axis_dim, grad_output.shape[1]],
+             device=grad_output.device,
+             dtype=grad_output.dtype,
+         )
+         # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
+         # grad_input[indices] = grad_output
+         grad_input.scatter_(
+             0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output
+         )
+         return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
+
+
+ index_first_axis = IndexFirstAxis.apply
+
+
+ class IndexPutFirstAxis(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         ctx, values: torch.Tensor, indices: torch.Tensor, first_axis_dim
+     ) -> torch.Tensor:
+         ctx.save_for_backward(indices)
+         assert indices.ndim == 1
+         assert values.ndim >= 2
+         output = torch.zeros(
+             first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
+         )
+         output[indices] = values
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
+         (indices,) = ctx.saved_tensors
+         grad_values = grad_output[indices]
+         return grad_values, None, None
+
+
+ index_put_first_axis = IndexPutFirstAxis.apply
+
+
+ def unpad_input(
+     hidden_states: torch.Tensor,
+     attention_mask: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
+     """Remove padding from input sequences.
+
+     Arguments:
+         hidden_states: (batch, seqlen, ...)
+         attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+
+     Returns:
+         hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected by attention_mask.
+         indices: (total_nnz), flat indices of the selected tokens.
+         cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+         max_seqlen_in_batch: int, the longest sequence length in the batch.
+     """
+     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+     max_seqlen_in_batch = int(seqlens_in_batch.max().item())
+     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+     # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
+     # bool mask, then call nonzero to get the indices, then index with those. The resulting
+     # indices tensor is the hidden dimension times larger than it needs to be, wasting memory.
+     # It's faster and more memory-efficient to index with integer indices. Moreover, torch's
+     # index is a bit slower than it needs to be, so we write custom forward and backward to
+     # make it a bit faster.
+     hidden_states = cast(
+         torch.Tensor,
+         index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
+     )
+     return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
+
+
+ def unpad_input_only(
+     hidden_states: torch.Tensor,
+     attention_mask: torch.Tensor,
+ ) -> torch.Tensor:
+     """Like unpad_input, but only return the unpadded first tensor.
+
+     Saves a small amount of overhead.
+
+     Arguments:
+         hidden_states: (batch, seqlen, ...)
+         attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+
+     Returns:
+         hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected by attention_mask.
+     """
+     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+     return index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices)
+
+
+ def pad_input(
+     hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int
+ ) -> torch.Tensor:
+     """Add padding back to unpadded sequences.
+
+     Arguments:
+         hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected by attention_mask.
+         indices: (total_nnz), flat indices originally returned by unpad_input.
+         batch: int, batch size
+         seqlen: int, max sequence length
+
+     Returns:
+         hidden_states: (batch, seqlen, ...)
+     """
+     output = index_put_first_axis(hidden_states, indices, batch * seqlen)
+     return rearrange(output, "(b s) ... -> b s ...", b=batch)
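
A quick way to see what these helpers do is to run the unpad/pad round trip on a toy batch. The sketch below assumes only the functions defined above; shapes and mask values are illustrative:

```python
import torch

# Hypothetical batch: 2 sequences with hidden size 4, padded to length 5.
hidden_states = torch.randn(2, 5, 4)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

unpadded, indices, cu_seqlens, max_len = unpad_input(hidden_states, attention_mask)
print(unpadded.shape)  # torch.Size([8, 4])  -> 3 + 5 valid tokens
print(cu_seqlens)      # tensor([0, 3, 8], dtype=torch.int32)
print(max_len)         # 5

repadded = pad_input(unpadded, indices, batch=2, seqlen=5)
assert torch.equal(repadded[attention_mask.bool()], unpadded)
assert (repadded[~attention_mask.bool()] == 0).all()  # padding comes back as zeros
```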
config.json ADDED
@@ -0,0 +1,32 @@
+ {
+   "_name_or_path": "",
+   "alibi_starting_size": 512,
+   "architectures": [
+     "BertForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_bert.BertConfig",
+     "AutoModel": "bert_layers.BertModel",
+     "AutoModelForMaskedLM": "bert_layers.BertForMaskedLM",
+     "AutoModelForSequenceClassification": "bert_layers.BertForSequenceClassification"
+   },
+   "classifier_dropout": null,
+   "gradient_checkpointing": false,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "position_embedding_type": "absolute",
+   "torch_dtype": "float32",
+   "transformers_version": "4.28.0",
+   "type_vocab_size": 2,
+   "use_cache": true,
+   "vocab_size": 4096
+ }
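
The `auto_map` block is what lets the `transformers` Auto classes resolve to the custom modules shipped alongside this config (configuration_bert.py and bert_layers.py). A typical loading pattern would look like the sketch below; the repo id is a placeholder, not a claim about where this checkpoint is published:

```python
from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer

repo_id = "your-namespace/your-mosaic-bert-checkpoint"  # placeholder repo id

# trust_remote_code is required so AutoConfig/AutoModel follow auto_map into
# configuration_bert.py and bert_layers.py instead of the stock BERT classes.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
```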
configuration_bert.py ADDED
@@ -0,0 +1,26 @@
+ # Copyright 2022 MosaicML Examples authors
+ # SPDX-License-Identifier: Apache-2.0
+
+ from transformers import BertConfig as TransformersBertConfig
+
+
+ class BertConfig(TransformersBertConfig):
+     def __init__(
+         self,
+         alibi_starting_size: int = 512,
+         attention_probs_dropout_prob: float = 0.0,
+         **kwargs,
+     ):
+         """Configuration class for MosaicBert.
+
+         Args:
+             alibi_starting_size (int): Use `alibi_starting_size` to determine how large of an alibi tensor to
+                 create when initializing the model. You should be able to ignore this parameter in most cases.
+                 Defaults to 512.
+             attention_probs_dropout_prob (float): By default, turn off attention dropout in Mosaic BERT
+                 (otherwise, Flash Attention will be off by default). Defaults to 0.0.
+         """
+         super().__init__(
+             attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs
+         )
+         self.alibi_starting_size = alibi_starting_size
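
Since the subclass only adds `alibi_starting_size` and changes the attention-dropout default, it constructs like any `BertConfig`. An illustrative instantiation, with values mirroring config.json above:

```python
config = BertConfig(
    vocab_size=4096,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    alibi_starting_size=512,           # size of the pre-built ALiBi tensor
    attention_probs_dropout_prob=0.0,  # keep 0.0 so Flash Attention stays usable
)
print(config.alibi_starting_size)  # 512
```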
flash_attn_triton.py ADDED
@@ -0,0 +1,1181 @@
+ # Copyright 2022 MosaicML Examples authors
+ # SPDX-License-Identifier: Apache-2.0
+
+ """Triton implementation of Flash Attention.
+
+ # Copyright (c) 2022, Tri Dao.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ *Experimental* implementation of FlashAttention in Triton.
+ We use the FlashAttention implementation from Phil Tillet as a starting point.
+ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
+
+ Changes:
+ - Implement both causal and non-causal attention.
+ - Implement both self-attention and cross-attention.
+ - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
+ - Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
+ - Support attention bias.
+ - Speed up the forward pass a bit, and only store the LSE instead of m and l.
+ - Make the backward for d=128 much faster by reducing register spilling.
+ - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
+   small batch size * nheads.
+
+ Caution:
+ - If you plan to use headdim other than 64 and 128, you should test for race conditions
+   (due to the Triton compiler), as done in tests/test_flash_attn.py
+   "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
+   for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
+   that there are none left for other head dimensions.
+ Differences between this Triton version and the CUDA version:
+ - Triton version doesn't support dropout.
+ - Triton forward is generally faster than CUDA forward.
+ - Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64.
+   It is slightly slower when headdim=128 and batch * nheads is large.
+ - Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
+ """
+
+ import math
+
+ import torch
+ import triton  # type: ignore (reportMissingImports)
+ import triton.language as tl  # type: ignore (reportMissingImports)
+ from einops import repeat
+
+
+ @triton.autotune(
+     configs=[
+         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
+         # This config has a race condition when EVEN_M == False, disabling it for now.
+         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
+     ],
+     key=[
+         "CACHE_KEY_SEQLEN_Q",
+         "CACHE_KEY_SEQLEN_K",
+         "BIAS_TYPE",
+         "IS_CAUSAL",
+         "BLOCK_HEADDIM",
+     ],
+ )
+ @triton.heuristics(
+     {
+         "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
+         "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
+         "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
+     }
+ )
+ @triton.jit
+ def _fwd_kernel(
+     Q,
+     K,
+     V,
+     Bias,
+     Out,
+     Lse,
+     TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
+     softmax_scale,
+     stride_qb,
+     stride_qh,
+     stride_qm,
+     stride_kb,
+     stride_kh,
+     stride_kn,
+     stride_vb,
+     stride_vh,
+     stride_vn,
+     stride_bb,
+     stride_bh,
+     stride_bm,
+     stride_ob,
+     stride_oh,
+     stride_om,
+     nheads,
+     seqlen_q,
+     seqlen_k,
+     seqlen_q_rounded,
+     headdim,
+     CACHE_KEY_SEQLEN_Q,
+     CACHE_KEY_SEQLEN_K,
+     BIAS_TYPE: tl.constexpr,
+     IS_CAUSAL: tl.constexpr,
+     BLOCK_HEADDIM: tl.constexpr,
+     EVEN_M: tl.constexpr,
+     EVEN_N: tl.constexpr,
+     EVEN_HEADDIM: tl.constexpr,
+     BLOCK_M: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+ ):
+     start_m = tl.program_id(0)
+     off_hb = tl.program_id(1)
+     off_b = off_hb // nheads
+     off_h = off_hb % nheads
+     # off_b = tl.program_id(1)
+     # off_h = tl.program_id(2)
+     # off_hb = off_b * nheads + off_h
+     # initialize offsets
+     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+     offs_n = tl.arange(0, BLOCK_N)
+     offs_d = tl.arange(0, BLOCK_HEADDIM)
+     # Initialize pointers to Q, K, V
+     # Adding parenthesis around indexing might use int32 math instead of int64 math?
+     # https://github.com/openai/triton/issues/741
+     # I'm seeing a tiny bit of difference (5-7us)
+     q_ptrs = (
+         Q
+         + off_b * stride_qb
+         + off_h * stride_qh
+         + (offs_m[:, None] * stride_qm + offs_d[None, :])
+     )
+     k_ptrs = (
+         K
+         + off_b * stride_kb
+         + off_h * stride_kh
+         + (offs_n[:, None] * stride_kn + offs_d[None, :])
+     )
+     v_ptrs = (
+         V
+         + off_b * stride_vb
+         + off_h * stride_vh
+         + (offs_n[:, None] * stride_vn + offs_d[None, :])
+     )
+     if BIAS_TYPE == "vector":
+         b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
+     elif BIAS_TYPE == "matrix":
+         b_ptrs = (
+             Bias
+             + off_b * stride_bb
+             + off_h * stride_bh
+             + (offs_m[:, None] * stride_bm + offs_n[None, :])
+         )
+     else:
+         raise ValueError("BIAS_TYPE must be one of {'vector', 'matrix'}")
+     # initialize pointer to m and l
+     t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
+     lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+     acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
+     # load q: it will stay in SRAM throughout
+     # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
+     # tl.load(q_ptrs), we get the wrong output!
+     if EVEN_M & EVEN_N:
+         if EVEN_HEADDIM:
+             q = tl.load(q_ptrs)
+         else:
+             q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
+     else:
+         if EVEN_HEADDIM:
+             q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
+         else:
+             q = tl.load(
+                 q_ptrs,
+                 mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+                 other=0.0,
+             )
+     # loop over k, v and update accumulator
+     end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
+     for start_n in range(0, end_n, BLOCK_N):
+         start_n = tl.multiple_of(start_n, BLOCK_N)
+         # -- compute qk ----
+         if (
+             EVEN_N & EVEN_M
+         ):  # If we just do "if EVEN_N", there seems to be some race condition
+             if EVEN_HEADDIM:
+                 k = tl.load(k_ptrs + start_n * stride_kn)
+             else:
+                 k = tl.load(
+                     k_ptrs + start_n * stride_kn,
+                     mask=offs_d[None, :] < headdim,
+                     other=0.0,
+                 )
+         else:
+             if EVEN_HEADDIM:
+                 k = tl.load(
+                     k_ptrs + start_n * stride_kn,
+                     mask=(start_n + offs_n)[:, None] < seqlen_k,
+                     other=0.0,
+                 )
+             else:
+                 k = tl.load(
+                     k_ptrs + start_n * stride_kn,
+                     mask=((start_n + offs_n)[:, None] < seqlen_k)
+                     & (offs_d[None, :] < headdim),
+                     other=0.0,
+                 )
+         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+         qk += tl.dot(q, k, trans_b=True)
+         # Trying to combine the two masks seems to make the result wrong
+         if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
+             qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
+         if IS_CAUSAL:
+             qk += tl.where(
+                 offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")
+             )
+         if BIAS_TYPE != "none":
+             if BIAS_TYPE == "vector":
+                 if EVEN_N:
+                     bias = tl.load(b_ptrs + start_n).to(tl.float32)
+                 else:
+                     bias = tl.load(
+                         b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
+                     ).to(tl.float32)
+                 bias = bias[None, :]
+             elif BIAS_TYPE == "matrix":
+                 if EVEN_M & EVEN_N:
+                     bias = tl.load(b_ptrs + start_n).to(tl.float32)
+                 else:
+                     bias = tl.load(
+                         b_ptrs + start_n,
+                         mask=(offs_m[:, None] < seqlen_q)
+                         & ((start_n + offs_n)[None, :] < seqlen_k),
+                         other=0.0,
+                     ).to(tl.float32)
+             else:
+                 raise ValueError("BIAS_TYPE must be one of {'vector', 'matrix'}")
+             # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
+             # can then fuse the mult and add into an fma instruction. But if we have bias we need
+             # to multiply with softmax_scale here.
+             qk = qk * softmax_scale + bias
+             m_ij = tl.maximum(tl.max(qk, 1), lse_i)
+             p = tl.exp(qk - m_ij[:, None])
+         else:
+             m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
+             p = tl.exp(qk * softmax_scale - m_ij[:, None])
+         l_ij = tl.sum(p, 1)
+
+         # scale acc_o
+         acc_o_scale = tl.exp(m_i - m_ij)
+
+         # # -- update output accumulator --
+         # BUG: have to store and immediately load
+         tl.store(t_ptrs, acc_o_scale)
+         acc_o_scale = tl.load(t_ptrs)
+         acc_o = acc_o * acc_o_scale[:, None]
+         # update acc_o
+         if (
+             EVEN_N & EVEN_M
+         ):  # If we just do "if EVEN_N", there seems to be some race condition
+             if EVEN_HEADDIM:
+                 v = tl.load(v_ptrs + start_n * stride_vn)
+             else:
+                 v = tl.load(
+                     v_ptrs + start_n * stride_vn,
+                     mask=offs_d[None, :] < headdim,
+                     other=0.0,
+                 )
+         else:
+             if EVEN_HEADDIM:
+                 v = tl.load(
+                     v_ptrs + start_n * stride_vn,
+                     mask=(start_n + offs_n)[:, None] < seqlen_k,
+                     other=0.0,
+                 )
+             else:
+                 v = tl.load(
+                     v_ptrs + start_n * stride_vn,
+                     mask=((start_n + offs_n)[:, None] < seqlen_k)
+                     & (offs_d[None, :] < headdim),
+                     other=0.0,
+                 )
+         p = p.to(v.dtype)
+         acc_o += tl.dot(p, v)
+
+         # -- update statistics
+         m_i = m_ij
+         l_i_new = tl.exp(lse_i - m_ij) + l_ij
+         lse_i = m_ij + tl.log(l_i_new)
+
+     o_scale = tl.exp(m_i - lse_i)
+     # BUG: have to store and immediately load
+     tl.store(t_ptrs, o_scale)
+     o_scale = tl.load(t_ptrs)
+     acc_o = acc_o * o_scale[:, None]
+     # rematerialize offsets to save registers
+     start_m = tl.program_id(0)
+     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+     # write back l and m
+     lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
+     tl.store(lse_ptrs, lse_i)
+     # initialize pointers to output
+     offs_n = tl.arange(0, BLOCK_HEADDIM)
+     out_ptrs = (
+         Out
+         + off_b * stride_ob
+         + off_h * stride_oh
+         + (offs_m[:, None] * stride_om + offs_n[None, :])
+     )
+     if EVEN_M:
+         if EVEN_HEADDIM:
+             tl.store(out_ptrs, acc_o)
+         else:
+             tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
+     else:
+         if EVEN_HEADDIM:
+             tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
+         else:
+             tl.store(
+                 out_ptrs,
+                 acc_o,
+                 mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+             )
+
+
+ @triton.jit
+ def _bwd_preprocess_do_o_dot(
+     Out,
+     DO,
+     Delta,
+     stride_ob,
+     stride_oh,
+     stride_om,
+     stride_dob,
+     stride_doh,
+     stride_dom,
+     nheads,
+     seqlen_q,
+     seqlen_q_rounded,
+     headdim,
+     BLOCK_M: tl.constexpr,
+     BLOCK_HEADDIM: tl.constexpr,
+ ):
+     start_m = tl.program_id(0)
+     off_hb = tl.program_id(1)
+     off_b = off_hb // nheads
+     off_h = off_hb % nheads
+     # initialize offsets
+     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+     offs_d = tl.arange(0, BLOCK_HEADDIM)
+     # load
+     o = tl.load(
+         Out
+         + off_b * stride_ob
+         + off_h * stride_oh
+         + offs_m[:, None] * stride_om
+         + offs_d[None, :],
+         mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+         other=0.0,
+     ).to(tl.float32)
+     do = tl.load(
+         DO
+         + off_b * stride_dob
+         + off_h * stride_doh
+         + offs_m[:, None] * stride_dom
+         + offs_d[None, :],
+         mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+         other=0.0,
+     ).to(tl.float32)
+     delta = tl.sum(o * do, axis=1)
+     # write-back
+     tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
+
+
+ @triton.jit
+ def _bwd_kernel_one_col_block(
+     start_n,
+     Q,
+     K,
+     V,
+     Bias,
+     DO,
+     DQ,
+     DK,
+     DV,
+     LSE,
+     D,
+     softmax_scale,
+     stride_qm,
+     stride_kn,
+     stride_vn,
+     stride_bm,
+     stride_dom,
+     stride_dqm,
+     stride_dkn,
+     stride_dvn,
+     seqlen_q,
+     seqlen_k,
+     headdim,
+     ATOMIC_ADD: tl.constexpr,
+     BIAS_TYPE: tl.constexpr,
+     IS_CAUSAL: tl.constexpr,
+     BLOCK_HEADDIM: tl.constexpr,
+     EVEN_M: tl.constexpr,
+     EVEN_N: tl.constexpr,
+     EVEN_HEADDIM: tl.constexpr,
+     BLOCK_M: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+ ):
+     # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
+     begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
+     # initialize row/col offsets
+     offs_qm = begin_m + tl.arange(0, BLOCK_M)
+     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+     offs_m = tl.arange(0, BLOCK_M)
+     offs_d = tl.arange(0, BLOCK_HEADDIM)
+     # initialize pointers to value-like data
+     q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
+     k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
+     v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
+     do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
+     dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
+     if BIAS_TYPE == "vector":
+         b_ptrs = Bias + offs_n
+     elif BIAS_TYPE == "matrix":
+         b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
+     else:
+         raise ValueError("BIAS_TYPE must be one of {'vector', 'matrix'}")
+     # initialize dv and dk
+     dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
+     dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
+     # k and v stay in SRAM throughout
+     # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
+     # if we just call tl.load(k_ptrs), we get the wrong output!
+     if EVEN_N & EVEN_M:
+         if EVEN_HEADDIM:
+             k = tl.load(k_ptrs)
+             v = tl.load(v_ptrs)
+         else:
+             k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
+             v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
+     else:
+         if EVEN_HEADDIM:
+             k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
+             v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
+         else:
+             k = tl.load(
+                 k_ptrs,
+                 mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                 other=0.0,
+             )
+             v = tl.load(
+                 v_ptrs,
+                 mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                 other=0.0,
+             )
+     # loop over rows
+     num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
+     for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
+         start_m = tl.multiple_of(start_m, BLOCK_M)
+         offs_m_curr = start_m + offs_m
+         # load q, k, v, do on-chip
+         # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
+         if EVEN_M & EVEN_HEADDIM:
+             q = tl.load(q_ptrs)
+         else:
+             if EVEN_HEADDIM:
+                 q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
+             else:
+                 q = tl.load(
+                     q_ptrs,
+                     mask=(offs_m_curr[:, None] < seqlen_q)
+                     & (offs_d[None, :] < headdim),
+                     other=0.0,
+                 )
+         # recompute p = softmax(qk, dim=-1).T
+         qk = tl.dot(q, k, trans_b=True)
+         # Trying to combine the two masks seems to make the result wrong
+         if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
+             qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
+         if IS_CAUSAL:
+             qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
+         if BIAS_TYPE != "none":
+             if BIAS_TYPE == "vector":
+                 if EVEN_N:
+                     bias = tl.load(b_ptrs).to(tl.float32)
+                 else:
+                     bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(
+                         tl.float32
+                     )
+                 bias = bias[None, :]
+             elif BIAS_TYPE == "matrix":
+                 if EVEN_M & EVEN_N:
+                     bias = tl.load(b_ptrs).to(tl.float32)
+                 else:
+                     bias = tl.load(
+                         b_ptrs,
+                         mask=(offs_m_curr[:, None] < seqlen_q)
+                         & (offs_n[None, :] < seqlen_k),
+                         other=0.0,
+                     ).to(tl.float32)
+             else:
+                 raise ValueError("BIAS_TYPE must be one of {'vector', 'matrix'}")
+             qk = qk * softmax_scale + bias
+         # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
+         # Also wrong for headdim=64.
+         if not (EVEN_M & EVEN_HEADDIM):
+             tl.debug_barrier()
+         lse_i = tl.load(LSE + offs_m_curr)
+         if BIAS_TYPE == "none":
+             p = tl.exp(qk * softmax_scale - lse_i[:, None])
+         else:
+             p = tl.exp(qk - lse_i[:, None])
+         # compute dv
+         # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
+         # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
+         # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
+         # the output is correct.
+         if EVEN_M & EVEN_HEADDIM:
+             do = tl.load(do_ptrs)
+         else:
+             # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
+             do = tl.load(
+                 do_ptrs,
+                 mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+                 other=0.0,
+             )
+         # if EVEN_M:
+         #     if EVEN_HEADDIM:
+         #         do = tl.load(do_ptrs)
+         #     else:
+         #         do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
+         # else:
+         #     if EVEN_HEADDIM:
+         #         do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
+         #     else:
+         #         do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
+         #                      & (offs_d[None, :] < headdim), other=0.0)
+         dv += tl.dot(p.to(do.dtype), do, trans_a=True)
+         # compute dp = dot(v, do)
+         # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
+         # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
+         # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
+         if not (EVEN_M & EVEN_HEADDIM):
+             tl.debug_barrier()
+         dp = tl.dot(do, v, trans_b=True)
+         # There's a race condition for headdim=48
+         if not EVEN_HEADDIM:
+             tl.debug_barrier()
+         # compute ds = p * (dp - delta[:, None])
+         # Putting the subtraction after the dp matmul (instead of before) is slightly faster
+         Di = tl.load(D + offs_m_curr)
+         # Converting ds to q.dtype here reduces register pressure and makes it much faster
+         # for BLOCK_HEADDIM=128
+         ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
+         # compute dk = dot(ds.T, q)
+         dk += tl.dot(ds, q, trans_a=True)
+         # compute dq
+         if not ATOMIC_ADD:
+             if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
+                 dq = tl.load(dq_ptrs, eviction_policy="evict_last")
+                 dq += tl.dot(ds, k)
+                 tl.store(dq_ptrs, dq, eviction_policy="evict_last")
+             else:
+                 if EVEN_HEADDIM:
+                     dq = tl.load(
+                         dq_ptrs,
+                         mask=offs_m_curr[:, None] < seqlen_q,
+                         other=0.0,
+                         eviction_policy="evict_last",
+                     )
+                     dq += tl.dot(ds, k)
+                     tl.store(
+                         dq_ptrs,
+                         dq,
+                         mask=offs_m_curr[:, None] < seqlen_q,
+                         eviction_policy="evict_last",
+                     )
+                 else:
+                     dq = tl.load(
+                         dq_ptrs,
+                         mask=(offs_m_curr[:, None] < seqlen_q)
+                         & (offs_d[None, :] < headdim),
+                         other=0.0,
+                         eviction_policy="evict_last",
+                     )
+                     dq += tl.dot(ds, k)
+                     tl.store(
+                         dq_ptrs,
+                         dq,
+                         mask=(offs_m_curr[:, None] < seqlen_q)
+                         & (offs_d[None, :] < headdim),
+                         eviction_policy="evict_last",
+                     )
+         else:  # If we're parallelizing across the seqlen_k dimension
+             dq = tl.dot(ds, k)
+             if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
+                 tl.atomic_add(dq_ptrs, dq)
+             else:
+                 if EVEN_HEADDIM:
+                     tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
+                 else:
+                     tl.atomic_add(
+                         dq_ptrs,
+                         dq,
+                         mask=(offs_m_curr[:, None] < seqlen_q)
+                         & (offs_d[None, :] < headdim),
+                     )
+         # increment pointers
+         dq_ptrs += BLOCK_M * stride_dqm
+         q_ptrs += BLOCK_M * stride_qm
+         do_ptrs += BLOCK_M * stride_dom
+         if BIAS_TYPE == "matrix":
+             b_ptrs += BLOCK_M * stride_bm
+     # write-back
+     dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
+     dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
+     # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
+     # if we just call tl.store(dv_ptrs), there's a race condition
+     if EVEN_N & EVEN_M:
+         if EVEN_HEADDIM:
+             tl.store(dv_ptrs, dv)
+             tl.store(dk_ptrs, dk)
+         else:
+             tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
+             tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
+     else:
+         if EVEN_HEADDIM:
+             tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
+             tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
+         else:
+             tl.store(
+                 dv_ptrs,
+                 dv,
+                 mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+             )
+             tl.store(
+                 dk_ptrs,
+                 dk,
+                 mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+             )
+
+
+ def init_to_zero(name):
+     return lambda nargs: nargs[name].zero_()
+
+
+ @triton.autotune(
+     configs=[
+         triton.Config(
+             {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
+             num_warps=8,
+             num_stages=1,
+             pre_hook=init_to_zero("DQ"),
+         ),
+         triton.Config(
+             {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
+             num_warps=8,
+             num_stages=1,
+             pre_hook=init_to_zero("DQ"),
+         ),
+         # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
+         # # Kernel is buggy (gives wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
+         # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
+         # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
+         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
+         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
+     ],
+     key=[
+         "CACHE_KEY_SEQLEN_Q",
+         "CACHE_KEY_SEQLEN_K",
+         "BIAS_TYPE",
+         "IS_CAUSAL",
+         "BLOCK_HEADDIM",
+     ],
+ )
+ @triton.heuristics(
+     {
+         "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
+         "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
+         "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
+     }
+ )
+ @triton.jit
+ def _bwd_kernel(
+     Q,
+     K,
+     V,
+     Bias,
+     DO,
+     DQ,
+     DK,
+     DV,
+     LSE,
+     D,
+     softmax_scale,
+     stride_qb,
+     stride_qh,
+     stride_qm,
+     stride_kb,
+     stride_kh,
+     stride_kn,
+     stride_vb,
+     stride_vh,
+     stride_vn,
+     stride_bb,
+     stride_bh,
+     stride_bm,
+     stride_dob,
+     stride_doh,
+     stride_dom,
+     stride_dqb,
+     stride_dqh,
+     stride_dqm,
+     stride_dkb,
+     stride_dkh,
+     stride_dkn,
+     stride_dvb,
+     stride_dvh,
+     stride_dvn,
+     nheads,
+     seqlen_q,
+     seqlen_k,
+     seqlen_q_rounded,
+     headdim,
+     CACHE_KEY_SEQLEN_Q,
+     CACHE_KEY_SEQLEN_K,
+     BIAS_TYPE: tl.constexpr,
+     IS_CAUSAL: tl.constexpr,
+     BLOCK_HEADDIM: tl.constexpr,
+     SEQUENCE_PARALLEL: tl.constexpr,
+     EVEN_M: tl.constexpr,
+     EVEN_N: tl.constexpr,
+     EVEN_HEADDIM: tl.constexpr,
+     BLOCK_M: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+ ):
+     off_hb = tl.program_id(1)
+     off_b = off_hb // nheads
+     off_h = off_hb % nheads
+     # offset pointers for batch/head
+     Q += off_b * stride_qb + off_h * stride_qh
+     K += off_b * stride_kb + off_h * stride_kh
+     V += off_b * stride_vb + off_h * stride_vh
+     DO += off_b * stride_dob + off_h * stride_doh
+     DQ += off_b * stride_dqb + off_h * stride_dqh
+     DK += off_b * stride_dkb + off_h * stride_dkh
+     DV += off_b * stride_dvb + off_h * stride_dvh
+     if BIAS_TYPE != "none":
+         Bias += off_b * stride_bb + off_h * stride_bh
+     # pointer to row-wise quantities in value-like data
+     D += off_hb * seqlen_q_rounded
+     LSE += off_hb * seqlen_q_rounded
+     if not SEQUENCE_PARALLEL:
+         num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
+         for start_n in range(0, num_block_n):
+             _bwd_kernel_one_col_block(
+                 start_n,
+                 Q,
+                 K,
+                 V,
+                 Bias,
+                 DO,
+                 DQ,
+                 DK,
+                 DV,
+                 LSE,
+                 D,
+                 softmax_scale,
+                 stride_qm,
+                 stride_kn,
+                 stride_vn,
+                 stride_bm,
+                 stride_dom,
+                 stride_dqm,
+                 stride_dkn,
+                 stride_dvn,
+                 seqlen_q,
+                 seqlen_k,
+                 headdim,
+                 ATOMIC_ADD=False,
+                 BIAS_TYPE=BIAS_TYPE,
+                 IS_CAUSAL=IS_CAUSAL,
+                 BLOCK_HEADDIM=BLOCK_HEADDIM,
+                 EVEN_M=EVEN_M,
+                 EVEN_N=EVEN_N,
+                 EVEN_HEADDIM=EVEN_HEADDIM,
+                 BLOCK_M=BLOCK_M,
+                 BLOCK_N=BLOCK_N,
+             )
+     else:
+         start_n = tl.program_id(0)
+         _bwd_kernel_one_col_block(
+             start_n,
+             Q,
+             K,
+             V,
+             Bias,
+             DO,
+             DQ,
+             DK,
+             DV,
+             LSE,
+             D,
+             softmax_scale,
+             stride_qm,
+             stride_kn,
+             stride_vn,
+             stride_bm,
+             stride_dom,
+             stride_dqm,
+             stride_dkn,
+             stride_dvn,
+             seqlen_q,
+             seqlen_k,
+             headdim,
+             ATOMIC_ADD=True,
+             BIAS_TYPE=BIAS_TYPE,
+             IS_CAUSAL=IS_CAUSAL,
+             BLOCK_HEADDIM=BLOCK_HEADDIM,
+             EVEN_M=EVEN_M,
+             EVEN_N=EVEN_N,
+             EVEN_HEADDIM=EVEN_HEADDIM,
+             BLOCK_M=BLOCK_M,
+             BLOCK_N=BLOCK_N,
+         )
+
+
+ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
+     # shape constraints
+     batch, seqlen_q, nheads, d = q.shape
+     _, seqlen_k, _, _ = k.shape
+     assert k.shape == (batch, seqlen_k, nheads, d)
+     assert v.shape == (batch, seqlen_k, nheads, d)
+     assert d <= 128, "FlashAttention only supports head dimensions up to 128"
+     assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
+     assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
+     assert q.is_cuda and k.is_cuda and v.is_cuda
+     softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
+
+     has_bias = bias is not None
+     bias_type = "none"
+     if has_bias:
+         assert bias.dtype in [q.dtype, torch.float]
+         assert bias.is_cuda
+         assert bias.dim() == 4
+         if bias.stride(-1) != 1:
+             bias = bias.contiguous()
+         if bias.shape[2:] == (1, seqlen_k):
+             bias_type = "vector"
+         elif bias.shape[2:] == (seqlen_q, seqlen_k):
+             bias_type = "matrix"
+         else:
+             raise RuntimeError(
+                 "Last 2 dimensions of bias must be (1, seqlen_k)"
+                 " or (seqlen_q, seqlen_k)"
+             )
+         if bias.shape[:2] == (1, nheads):
+             bias = repeat(bias, "1 h ... -> b h ...", b=batch)
+         elif bias.shape[:2] == (batch, 1):
+             bias = repeat(bias, "b 1 ... -> b h ...", h=nheads)
+         elif bias.shape[:2] == (1, 1):
+             bias = repeat(bias, "1 h ... -> b h ...", b=batch)
+             bias = repeat(bias, "b 1 ... -> b h ...", h=nheads)
+         assert bias.shape[:2] == (
+             batch,
+             nheads,
+         ), f"First 2 dimensions of bias must be broadcastable to (batch, nheads) = ({batch}, {nheads}). Bias has shape: {bias.shape}"
+         assert bias is not None  # for type checking
+     bias_strides = (
+         (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
+     )
+
+     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
+     lse = torch.empty(
+         (batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32
+     )
+     tmp = torch.empty(
+         (batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32
+     )
+     o = torch.empty_like(q)
+
+     BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
+     # BLOCK = 128
+     # num_warps = 4 if d <= 64 else 8
+     grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
+     _fwd_kernel[grid](  # type: ignore
+         q,
+         k,
+         v,
+         bias,
+         o,
+         lse,
+         tmp,
+         softmax_scale,
+         q.stride(0),
+         q.stride(2),
+         q.stride(1),
+         k.stride(0),
+         k.stride(2),
+         k.stride(1),
+         v.stride(0),
+         v.stride(2),
+         v.stride(1),
+         *bias_strides,
+         o.stride(0),
+         o.stride(2),
+         o.stride(1),
+         nheads,
+         seqlen_q,
+         seqlen_k,
+         seqlen_q_rounded,
+         d,
+         seqlen_q // 32,
+         seqlen_k // 32,  # key for triton cache (limit number of compilations)
+         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
+         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
+         bias_type,
+         causal,
+         BLOCK_HEADDIM,
+         # BLOCK_M=BLOCK, BLOCK_N=BLOCK,
+         # num_warps=num_warps,
+         # num_stages=1,
+     )
+     return o, lse, softmax_scale  # softmax_scale could have been updated
+
+
+ def _flash_attn_backward(
+     do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None
+ ):
+     # Make sure that the last dimension is contiguous
+     if do.stride(-1) != 1:
+         do = do.contiguous()
+     batch, seqlen_q, nheads, d = q.shape
+     _, seqlen_k, _, _ = k.shape
+     # assert d in {16, 32, 64, 128}
+     assert d <= 128
+     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
+     assert lse.shape == (batch, nheads, seqlen_q_rounded)
+     assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
+     assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
+     softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
+     # dq_accum = torch.zeros_like(q, dtype=torch.float32)
+     dq_accum = torch.empty_like(q, dtype=torch.float32)
+     delta = torch.empty_like(lse)
+     # delta = torch.zeros_like(lse)
+
+     BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
+     grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
+     _bwd_preprocess_do_o_dot[grid](  # type: ignore
+         o,
+         do,
+         delta,
+         o.stride(0),
+         o.stride(2),
+         o.stride(1),
+         do.stride(0),
+         do.stride(2),
+         do.stride(1),
+         nheads,
+         seqlen_q,
+         seqlen_q_rounded,
+         d,
+         BLOCK_M=128,
+         BLOCK_HEADDIM=BLOCK_HEADDIM,
+     )
+
+     has_bias = bias is not None
+     bias_type = "none"
+     if has_bias:
+         assert bias.dtype in [q.dtype, torch.float]
+         assert bias.is_cuda
+         assert bias.dim() == 4
+         assert bias.stride(-1) == 1
+         if bias.shape[2:] == (1, seqlen_k):
+             bias_type = "vector"
+         elif bias.shape[2:] == (seqlen_q, seqlen_k):
+             bias_type = "matrix"
+         else:
+             raise RuntimeError(
+                 "Last 2 dimensions of bias must be (1, seqlen_k)"
+                 " or (seqlen_q, seqlen_k)"
+             )
+         if bias.shape[:2] == (1, nheads):
+             bias = repeat(bias, "1 h ... -> b h ...", b=batch)
+         elif bias.shape[:2] == (batch, 1):
+             bias = repeat(bias, "b 1 ... -> b h ...", h=nheads)
+         elif bias.shape[:2] == (1, 1):
+             bias = repeat(bias, "1 h ... -> b h ...", b=batch)
+             bias = repeat(bias, "b 1 ... -> b h ...", h=nheads)
+         assert bias.shape[:2] == (
+             batch,
+             nheads,
+         ), f"First 2 dimensions of bias must be broadcastable to (batch, nheads) = ({batch}, {nheads}). Bias has shape: {bias.shape}"
+         assert bias is not None  # type checking
+     bias_strides = (
+         (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
+     )
+
+     # BLOCK_M = 128
+     # BLOCK_N = 64
+     # num_warps = 4
+     grid = lambda META: (
+         triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
+         batch * nheads,
+     )
+     _bwd_kernel[grid](  # type: ignore
+         q,
+         k,
+         v,
+         bias,
+         do,
+         dq_accum,
+         dk,
+         dv,
+         lse,
+         delta,
+         softmax_scale,
+         q.stride(0),
+         q.stride(2),
+         q.stride(1),
+         k.stride(0),
+         k.stride(2),
+         k.stride(1),
+         v.stride(0),
+         v.stride(2),
+         v.stride(1),
+         *bias_strides,
+         do.stride(0),
+         do.stride(2),
+         do.stride(1),
+         dq_accum.stride(0),
+         dq_accum.stride(2),
+         dq_accum.stride(1),
+         dk.stride(0),
+         dk.stride(2),
+         dk.stride(1),
+         dv.stride(0),
+         dv.stride(2),
+         dv.stride(1),
+         nheads,
+         seqlen_q,
+         seqlen_k,
+         seqlen_q_rounded,
+         d,
+         seqlen_q // 32,
+         seqlen_k // 32,  # key for triton cache (limit number of compilations)
+         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
+         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
+         bias_type,
+         causal,
+         BLOCK_HEADDIM,
+         # SEQUENCE_PARALLEL=False,
+         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
+         # num_warps=num_warps,
+         # num_stages=1,
+     )
+     dq.copy_(dq_accum)
+
+
+ class _FlashAttnQKVPackedFunc(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
+         """Forward pass for packed FlashAttention.
+
+         Args:
+             ctx: autograd context
+             qkv: (batch, seqlen, 3, nheads, headdim)
+             bias: optional, shape broadcastable to (batch, nheads, seqlen, seqlen).
+                 For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
+                 ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
+             causal (bool): whether to incorporate causal attention masking
+             softmax_scale (float, optional): scale factor for softmax
+         """
+         # Make sure that the last dimension is contiguous
+         if qkv.stride(-1) != 1:
+             qkv = qkv.contiguous()
+         o, lse, ctx.softmax_scale = _flash_attn_forward(
+             qkv[:, :, 0],
+             qkv[:, :, 1],
+             qkv[:, :, 2],
+             bias=bias,
+             causal=causal,
+             softmax_scale=softmax_scale,
+         )
+         ctx.save_for_backward(qkv, o, lse, bias)
+         ctx.causal = causal
+         return o
+
+     @staticmethod
+     def backward(ctx, do):
+         qkv, o, lse, bias = ctx.saved_tensors
+         assert not ctx.needs_input_grad[
+             1
+         ], "FlashAttention does not support bias gradient yet"
+         # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+         # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+         with torch.inference_mode():
+             dqkv = torch.empty_like(qkv)
+             _flash_attn_backward(
+                 do,
+                 qkv[:, :, 0],
+                 qkv[:, :, 1],
+                 qkv[:, :, 2],
+                 o,
+                 lse,
+                 dqkv[:, :, 0],
+                 dqkv[:, :, 1],
+                 dqkv[:, :, 2],
+                 bias=bias,
+                 causal=ctx.causal,
+                 softmax_scale=ctx.softmax_scale,
+             )
+         return dqkv, None, None, None
+
+
+ flash_attn_qkvpacked_func = _FlashAttnQKVPackedFunc.apply
+
+
+ class _FlashAttnFunc(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
+         """Forward pass for FlashAttention.
+
+         Args:
+             ctx: autograd context
+             q: (batch_size, seqlen_q, nheads, headdim)
+             k: (batch_size, seqlen_k, nheads, headdim)
+             v: (batch_size, seqlen_k, nheads, headdim)
+             bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
+                 For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
+                 ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
+             causal (bool): whether to incorporate causal attention masking
+             softmax_scale (float, optional): scale factor for softmax
+         """
+         # Make sure that the last dimension is contiguous
+         q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
+         o, lse, ctx.softmax_scale = _flash_attn_forward(
+             q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
+         )
+         ctx.save_for_backward(q, k, v, o, lse, bias)
+         ctx.causal = causal
+         return o
+
+     @staticmethod
+     def backward(ctx, do):
+         q, k, v, o, lse, bias = ctx.saved_tensors
+         assert not ctx.needs_input_grad[
+             3
+         ], "FlashAttention does not support bias gradient yet"
+         # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+         # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+         with torch.inference_mode():
+             dq = torch.empty_like(q)
+             dk = torch.empty_like(k)
+             dv = torch.empty_like(v)
+             _flash_attn_backward(
+                 do,
+                 q,
+                 k,
+                 v,
+                 o,
+                 lse,
+                 dq,
+                 dk,
+                 dv,
+                 bias=bias,
+                 causal=ctx.causal,
+                 softmax_scale=ctx.softmax_scale,
+             )
+         return dq, dk, dv, None, None, None
+
+
+ flash_attn_func = _FlashAttnFunc.apply
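
To make the public entry points concrete, here is a hedged usage sketch of `flash_attn_func` (positional arguments only, since it is an `autograd.Function.apply`; it requires a CUDA device and fp16/bf16 tensors per the asserts in `_flash_attn_forward`; shapes below are illustrative):

```python
import math
import torch

# Illustrative shapes: batch=2, seqlen=256, nheads=12, headdim=64.
q = torch.randn(2, 256, 12, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# bias=None, causal=False; softmax_scale defaults to 1/sqrt(headdim) if None.
out = flash_attn_func(q, k, v, None, False, 1.0 / math.sqrt(64))
print(out.shape)  # torch.Size([2, 256, 12, 64])

out.sum().backward()  # exercises the Triton backward kernels
print(q.grad.shape)   # torch.Size([2, 256, 12, 64])
```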
generation_config.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "_from_model_config": true,
+   "pad_token_id": 0,
+   "transformers_version": "4.28.0"
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7ff39ec77a484dd01070a41bfd6e95cdd7247bec80fe357ab43a4be33687aeba
+ size 468354983
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"tokenizer_class": "PreTrainedTokenizerFast", "unk_token": "[UNK]", "cls_token": "[CLS]", "sep_token": "[SEP]", "pad_token": "[PAD]", "mask_token": "[MASK]"}