Taykhoom committed on
Commit 2154dc3 · verified · 1 Parent(s): a982f43

Upload folder using huggingface_hub

LICENSE ADDED
@@ -0,0 +1,437 @@
Attribution-NonCommercial-ShareAlike 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.

     Considerations for licensors: Our public licenses are
     intended for use by those authorized to give the public
     permission to use material in ways otherwise restricted by
     copyright and certain other rights. Our licenses are
     irrevocable. Licensors should read and understand the terms
     and conditions of the license they choose before applying it.
     Licensors should also secure all rights necessary before
     applying our licenses so that the public can reuse the
     material as expected. Licensors should clearly mark any
     material not subject to the license. This includes other CC-
     licensed material, or material used under an exception or
     limitation to copyright. More considerations for licensors:
     wiki.creativecommons.org/Considerations_for_licensors

     Considerations for the public: By using one of our public
     licenses, a licensor grants the public permission to use the
     licensed material under specified terms and conditions. If
     the licensor's permission is not necessary for any reason--for
     example, because of any applicable exception or limitation to
     copyright--then that use is not regulated by the license. Our
     licenses grant only permissions under copyright and certain
     other rights that a licensor has authority to grant. Use of
     the licensed material may still be restricted for other
     reasons, including because others have copyright or other
     rights in the material. A licensor may make special requests,
     such as asking that all changes be marked or described.
     Although not required by our licenses, you are encouraged to
     respect those requests where reasonable. More considerations
     for the public:
     wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
Public License

By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International Public License
("Public License"). To the extent this Public License may be
interpreted as a contract, You are granted the Licensed Rights in
consideration of Your acceptance of these terms and conditions, and the
Licensor grants You such rights in consideration of benefits the
Licensor receives from making the Licensed Material available under
these terms and conditions.


Section 1 -- Definitions.

  a. Adapted Material means material subject to Copyright and Similar
     Rights that is derived from or based upon the Licensed Material
     and in which the Licensed Material is translated, altered,
     arranged, transformed, or otherwise modified in a manner requiring
     permission under the Copyright and Similar Rights held by the
     Licensor. For purposes of this Public License, where the Licensed
     Material is a musical work, performance, or sound recording,
     Adapted Material is always produced where the Licensed Material is
     synched in timed relation with a moving image.

  b. Adapter's License means the license You apply to Your Copyright
     and Similar Rights in Your contributions to Adapted Material in
     accordance with the terms and conditions of this Public License.

  c. BY-NC-SA Compatible License means a license listed at
     creativecommons.org/compatiblelicenses, approved by Creative
     Commons as essentially the equivalent of this Public License.

  d. Copyright and Similar Rights means copyright and/or similar rights
     closely related to copyright including, without limitation,
     performance, broadcast, sound recording, and Sui Generis Database
     Rights, without regard to how the rights are labeled or
     categorized. For purposes of this Public License, the rights
     specified in Section 2(b)(1)-(2) are not Copyright and Similar
     Rights.

  e. Effective Technological Measures means those measures that, in the
     absence of proper authority, may not be circumvented under laws
     fulfilling obligations under Article 11 of the WIPO Copyright
     Treaty adopted on December 20, 1996, and/or similar international
     agreements.

  f. Exceptions and Limitations means fair use, fair dealing, and/or
     any other exception or limitation to Copyright and Similar Rights
     that applies to Your use of the Licensed Material.

  g. License Elements means the license attributes listed in the name
     of a Creative Commons Public License. The License Elements of this
     Public License are Attribution, NonCommercial, and ShareAlike.

  h. Licensed Material means the artistic or literary work, database,
     or other material to which the Licensor applied this Public
     License.

  i. Licensed Rights means the rights granted to You subject to the
     terms and conditions of this Public License, which are limited to
     all Copyright and Similar Rights that apply to Your use of the
     Licensed Material and that the Licensor has authority to license.

  j. Licensor means the individual(s) or entity(ies) granting rights
     under this Public License.

  k. NonCommercial means not primarily intended for or directed towards
     commercial advantage or monetary compensation. For purposes of
     this Public License, the exchange of the Licensed Material for
     other material subject to Copyright and Similar Rights by digital
     file-sharing or similar means is NonCommercial provided there is
     no payment of monetary compensation in connection with the
     exchange.

  l. Share means to provide material to the public by any means or
     process that requires permission under the Licensed Rights, such
     as reproduction, public display, public performance, distribution,
     dissemination, communication, or importation, and to make material
     available to the public including in ways that members of the
     public may access the material from a place and at a time
     individually chosen by them.

  m. Sui Generis Database Rights means rights other than copyright
     resulting from Directive 96/9/EC of the European Parliament and of
     the Council of 11 March 1996 on the legal protection of databases,
     as amended and/or succeeded, as well as other essentially
     equivalent rights anywhere in the world.

  n. You means the individual or entity exercising the Licensed Rights
     under this Public License. Your has a corresponding meaning.


Section 2 -- Scope.

  a. License grant.

       1. Subject to the terms and conditions of this Public License,
          the Licensor hereby grants You a worldwide, royalty-free,
          non-sublicensable, non-exclusive, irrevocable license to
          exercise the Licensed Rights in the Licensed Material to:

            a. reproduce and Share the Licensed Material, in whole or
               in part, for NonCommercial purposes only; and

            b. produce, reproduce, and Share Adapted Material for
               NonCommercial purposes only.

       2. Exceptions and Limitations. For the avoidance of doubt, where
          Exceptions and Limitations apply to Your use, this Public
          License does not apply, and You do not need to comply with
          its terms and conditions.

       3. Term. The term of this Public License is specified in Section
          6(a).

       4. Media and formats; technical modifications allowed. The
          Licensor authorizes You to exercise the Licensed Rights in
          all media and formats whether now known or hereafter created,
          and to make technical modifications necessary to do so. The
          Licensor waives and/or agrees not to assert any right or
          authority to forbid You from making technical modifications
          necessary to exercise the Licensed Rights, including
          technical modifications necessary to circumvent Effective
          Technological Measures. For purposes of this Public License,
          simply making modifications authorized by this Section 2(a)
          (4) never produces Adapted Material.

       5. Downstream recipients.

            a. Offer from the Licensor -- Licensed Material. Every
               recipient of the Licensed Material automatically
               receives an offer from the Licensor to exercise the
               Licensed Rights under the terms and conditions of this
               Public License.

            b. Additional offer from the Licensor -- Adapted Material.
               Every recipient of Adapted Material from You
               automatically receives an offer from the Licensor to
               exercise the Licensed Rights in the Adapted Material
               under the conditions of the Adapter's License You apply.

            c. No downstream restrictions. You may not offer or impose
               any additional or different terms or conditions on, or
               apply any Effective Technological Measures to, the
               Licensed Material if doing so restricts exercise of the
               Licensed Rights by any recipient of the Licensed
               Material.

       6. No endorsement. Nothing in this Public License constitutes or
          may be construed as permission to assert or imply that You
          are, or that Your use of the Licensed Material is, connected
          with, or sponsored, endorsed, or granted official status by,
          the Licensor or others designated to receive attribution as
          provided in Section 3(a)(1)(A)(i).

  b. Other rights.

       1. Moral rights, such as the right of integrity, are not
          licensed under this Public License, nor are publicity,
          privacy, and/or other similar personality rights; however, to
          the extent possible, the Licensor waives and/or agrees not to
          assert any such rights held by the Licensor to the limited
          extent necessary to allow You to exercise the Licensed
          Rights, but not otherwise.

       2. Patent and trademark rights are not licensed under this
          Public License.

       3. To the extent possible, the Licensor waives any right to
          collect royalties from You for the exercise of the Licensed
          Rights, whether directly or through a collecting society
          under any voluntary or waivable statutory or compulsory
          licensing scheme. In all other cases the Licensor expressly
          reserves any right to collect such royalties, including when
          the Licensed Material is used other than for NonCommercial
          purposes.


Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the
following conditions.

  a. Attribution.

       1. If You Share the Licensed Material (including in modified
          form), You must:

            a. retain the following if it is supplied by the Licensor
               with the Licensed Material:

                 i. identification of the creator(s) of the Licensed
                    Material and any others designated to receive
                    attribution, in any reasonable manner requested by
                    the Licensor (including by pseudonym if
                    designated);

                ii. a copyright notice;

               iii. a notice that refers to this Public License;

                iv. a notice that refers to the disclaimer of
                    warranties;

                 v. a URI or hyperlink to the Licensed Material to the
                    extent reasonably practicable;

            b. indicate if You modified the Licensed Material and
               retain an indication of any previous modifications; and

            c. indicate the Licensed Material is licensed under this
               Public License, and include the text of, or the URI or
               hyperlink to, this Public License.

       2. You may satisfy the conditions in Section 3(a)(1) in any
          reasonable manner based on the medium, means, and context in
          which You Share the Licensed Material. For example, it may be
          reasonable to satisfy the conditions by providing a URI or
          hyperlink to a resource that includes the required
          information.

       3. If requested by the Licensor, You must remove any of the
          information required by Section 3(a)(1)(A) to the extent
          reasonably practicable.

  b. ShareAlike.

     In addition to the conditions in Section 3(a), if You Share
     Adapted Material You produce, the following conditions also apply.

       1. The Adapter's License You apply must be a Creative Commons
          license with the same License Elements, this version or
          later, or a BY-NC-SA Compatible License.

       2. You must include the text of, or the URI or hyperlink to, the
          Adapter's License You apply. You may satisfy this condition
          in any reasonable manner based on the medium, means, and
          context in which You Share Adapted Material.

       3. You may not offer or impose any additional or different terms
          or conditions on, or apply any Effective Technological
          Measures to, Adapted Material that restrict exercise of the
          rights granted under the Adapter's License You apply.


Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:

  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
     to extract, reuse, reproduce, and Share all or a substantial
     portion of the contents of the database for NonCommercial purposes
     only;

  b. if You include all or a substantial portion of the database
     contents in a database in which You have Sui Generis Database
     Rights, then the database in which You have Sui Generis Database
     Rights (but not its individual contents) is Adapted Material,
     including for purposes of Section 3(b); and

  c. You must comply with the conditions in Section 3(a) if You Share
     all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.


Section 5 -- Disclaimer of Warranties and Limitation of Liability.

  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

  c. The disclaimer of warranties and limitation of liability provided
     above shall be interpreted in a manner that, to the extent
     possible, most closely approximates an absolute disclaimer and
     waiver of all liability.


Section 6 -- Term and Termination.

  a. This Public License applies for the term of the Copyright and
     Similar Rights licensed here. However, if You fail to comply with
     this Public License, then Your rights under this Public License
     terminate automatically.

  b. Where Your right to use the Licensed Material has terminated under
     Section 6(a), it reinstates:

       1. automatically as of the date the violation is cured, provided
          it is cured within 30 days of Your discovery of the
          violation; or

       2. upon express reinstatement by the Licensor.

     For the avoidance of doubt, this Section 6(b) does not affect any
     right the Licensor may have to seek remedies for Your violations
     of this Public License.

  c. For the avoidance of doubt, the Licensor may also offer the
     Licensed Material under separate terms or conditions or stop
     distributing the Licensed Material at any time; however, doing so
     will not terminate this Public License.

  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
     License.


Section 7 -- Other Terms and Conditions.

  a. The Licensor shall not be bound by any additional or different
     terms or conditions communicated by You unless expressly agreed.

  b. Any arrangements, understandings, or agreements regarding the
     Licensed Material not stated herein are separate from and
     independent of the terms and conditions of this Public License.


Section 8 -- Interpretation.

  a. For the avoidance of doubt, this Public License does not, and
     shall not be interpreted to, reduce, limit, restrict, or impose
     conditions on any use of the Licensed Material that could lawfully
     be made without permission under this Public License.

  b. To the extent possible, if any provision of this Public License is
     deemed unenforceable, it shall be automatically reformed to the
     minimum extent necessary to make it enforceable. If the provision
     cannot be reformed, it shall be severed from this Public License
     without affecting the enforceability of the remaining terms and
     conditions.

  c. No term or condition of this Public License will be waived and no
     failure to comply consented to unless expressly agreed to by the
     Licensor.

  d. Nothing in this Public License constitutes or may be interpreted
     as a limitation upon, or waiver of, any privileges and immunities
     that apply to the Licensor or You, including from the legal
     processes of any jurisdiction or authority.

=======================================================================

Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the "Licensor." The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.

Creative Commons may be contacted at creativecommons.org.
README.md ADDED
@@ -0,0 +1,131 @@
---
license: other
---

# How to use
```python
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "Taykhoom/Helix-mRNA-Wrapper",
    trust_remote_code=True,
)

model = AutoModel.from_pretrained(
    "Taykhoom/Helix-mRNA-Wrapper",
    trust_remote_code=True,
).eval()

rna = "ACGUAGCAUCGGAUCUAUCUAUCGACACUUGGUUAUCGAUCUACGAGCAUCUCGUUAGC"
inputs = tokenizer(
    rna,
    return_tensors="pt",
    truncation=True,
    padding="longest",
    max_length=tokenizer.model_max_length,
    return_special_tokens_mask=True,
)

# Attend only to sequence tokens by masking out special tokens.
special_tokens_mask = inputs["special_tokens_mask"]
attention_mask = 1 - special_tokens_mask

embedding = model(
    input_ids=inputs["input_ids"],
    attention_mask=attention_mask,
).last_hidden_state  # [1, sequence_length, 256]
```
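
The model returns per-token embeddings; for a fixed-size sequence embedding you typically pool over the length dimension. A minimal sketch continuing from the snippet above (the masked mean pooling here is an illustration, not part of the wrapper's API):
```python
import torch

# `embedding` is [1, seq_len, 256] and `attention_mask` marks non-special tokens.
mask = attention_mask.unsqueeze(-1).to(embedding.dtype)  # [1, seq_len, 1]

# Masked mean pooling: average only over real sequence positions.
pooled = (embedding * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
print(pooled.shape)  # torch.Size([1, 256])
```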

# Model Variants
The following `base_model` options are available for embedding generation. Either the short name (key) or the full model name (value) can be passed via the `base_model` argument.
```python
VARIANTS = {
    "aido_rna_1m_mars": "genbio-ai/AIDO.RNA-1M-MARS",
    "aido_rna_25m_mars": "genbio-ai/AIDO.RNA-25M-MARS",
    "aido_rna_300m_mars": "genbio-ai/AIDO.RNA-300M-MARS",
    "aido_rna_650m": "genbio-ai/AIDO.RNA-650M",
    "aido_rna_650m_cds": "genbio-ai/AIDO.RNA-650M-CDS",
    "aido_rna_1b600m": "genbio-ai/AIDO.RNA-1.6B",
    "aido_rna_1b600m_cds": "genbio-ai/AIDO.RNA-1.6B-CDS",
}
```

# Performance vs. Original Helix-mRNA Models

To verify that the wrapper produces the same embeddings as the original Helix-mRNA model, compare the outputs of the two snippets below.

Original Helix-mRNA code snippet:
```python
from helical.models.helix_mrna import HelixmRNA, HelixmRNAConfig
import torch

input_sequences = ["ACGUAGCAUCGGAUCUAUCUAUCGACACUUGGUUAUCGAUCUACGAGCAUCUCGUUAGC"]

helix_mrna_config = HelixmRNAConfig(batch_size=1)
helix_mrna = HelixmRNA(configurer=helix_mrna_config)

# Prepare data for input to the model
processed_input_data = helix_mrna.process_data(input_sequences)

# Generate the embeddings for the processed data
embedding = torch.Tensor(helix_mrna.get_embeddings(processed_input_data))

embedding_mean = torch.mean(embedding, dim=1)  # [1, 256]
print(torch.mean(embedding_mean))  # Outputs tensor(-0.0033)

embedding_max = torch.max(embedding, dim=1)[0]
print(torch.mean(embedding_max))  # Outputs tensor(0.0989)
```

Modified code snippet using the wrapper:
```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "Taykhoom/Helix-mRNA-Wrapper",
    trust_remote_code=True,
)

model = AutoModel.from_pretrained(
    "Taykhoom/Helix-mRNA-Wrapper",
    trust_remote_code=True,
).eval()

rna = "ACGUAGCAUCGGAUCUAUCUAUCGACACUUGGUUAUCGAUCUACGAGCAUCUCGUUAGC"
inputs = tokenizer(
    rna,
    return_tensors="pt",
    truncation=True,
    padding="longest",
    max_length=tokenizer.model_max_length,
    return_special_tokens_mask=True,
)

special_tokens_mask = inputs["special_tokens_mask"]
attention_mask = 1 - special_tokens_mask

embedding = model(
    input_ids=inputs["input_ids"],
    attention_mask=attention_mask,
).last_hidden_state  # [1, sequence_length, 256]

embedding_mean = torch.mean(embedding, dim=1)
print(torch.mean(embedding_mean))  # Outputs tensor(-0.0033, grad_fn=<MeanBackward0>)

embedding_max = torch.max(embedding, dim=1)[0]
print(torch.mean(embedding_max))  # Outputs tensor(0.0989, grad_fn=<MeanBackward0>)
```
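
Rather than eyeballing the printed means, the two outputs can also be compared tensor-by-tensor; a small sketch, assuming both snippets above were run in the same session and their results were saved under the hypothetical names `embedding_original` and `embedding_wrapper`:
```python
import torch

# True if the wrapper reproduces the original embeddings up to the tolerance.
print(torch.allclose(embedding_original, embedding_wrapper.detach(), atol=1e-5))
```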

# License Notice
This repository contains modified versions of Helical code. Modifications include:
- Removal of the dependency on the `helical` package
- Removal of some ease-of-use embedding-generation code (to standardize usage) and of other checks (see the original repository for details)

Not all of the original functionality may be preserved. These changes were made to integrate better with the mRNABench framework, which focuses on embedding generation for mRNA sequences. Most of the required code was copied directly from the original Helical repository with minimal changes, so please refer to the original repository for full implementation details.

When using this repository, please adhere to the original license terms of the Helical code, which can be found in this directory as `LICENSE`.

# Original Repository
The original Helical repository can be found at: https://github.com/helicalAI/helical
config.json ADDED
@@ -0,0 +1,8 @@
{
    "model_name": "helical-ai/Helix-mRNA",
    "max_length": 12288,
    "auto_map": {
        "AutoConfig": "configuration_helix_mrna.HelixmRNAConfig",
        "AutoModel": "modeling_helix_mrna.HelixmRNAModel"
    }
}
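
The `auto_map` entries are what let `AutoConfig`/`AutoModel` resolve to the classes bundled in this repo when `trust_remote_code=True` is passed; a minimal sketch of the lookup (assuming Hub access):
```python
from transformers import AutoConfig

# Resolves to HelixmRNAConfig from configuration_helix_mrna.py via auto_map.
cfg = AutoConfig.from_pretrained("Taykhoom/Helix-mRNA-Wrapper", trust_remote_code=True)
print(type(cfg).__name__)  # HelixmRNAConfig
```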
configuration_helix_mrna.py ADDED
@@ -0,0 +1,48 @@
from typing import Literal

from transformers import PretrainedConfig


class HelixmRNAConfig(PretrainedConfig):
    """HelixmRNAConfig class to store the configuration of the Helix-mRNA model.

    Parameters
    ----------
    max_length : int, optional, default=12288
        The maximum length of the input sequence.
    **kwargs
        Additional fields (e.g. architecture settings from the checkpoint's
        config.json) forwarded to `PretrainedConfig`, which stores them as
        attributes.
    """

    model_name: Literal["helical-ai/Helix-mRNA"] = "helical-ai/Helix-mRNA"

    def __init__(
        self,
        max_length: int = 12288,
        **kwargs,
    ):
        self.config = {
            "model_name": self.model_name,
            "max_length": max_length,
        }

        super().__init__(**kwargs)

    @property
    def layers_block_type(self):
        # `num_hidden_layers` and `layers_block_type_string` are not set above;
        # they arrive from the checkpoint's config.json via **kwargs.
        layers = []
        if self.num_hidden_layers != len(self.layers_block_type_string):
            raise ValueError(
                f"num_hidden_layers should be equal to the number of layers in layers_block_type_string, but got {self.num_hidden_layers} and {len(self.layers_block_type_string)}"
            )
        for layer in self.layers_block_type_string:
            if layer == "M":
                layers.append("mamba")
            elif layer == "*":
                layers.append("attention")
            elif layer == "+":
                layers.append("mlp")
        return layers
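
# Illustration (added; the pattern string below is hypothetical -- the real one
# ships in the checkpoint's config.json): the property decodes a compact
# per-layer pattern string into block names, one character per layer.
#
#   cfg = HelixmRNAConfig(num_hidden_layers=4, layers_block_type_string="MM*M")
#   cfg.layers_block_type  # -> ["mamba", "mamba", "attention", "mamba"]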
modeling_helix_mrna.py ADDED
@@ -0,0 +1,1701 @@
# coding=utf-8
# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Helix-mRNA model."""

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union, Dict, Any, List

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.cache_utils import DynamicCache
from transformers.activations import ACT2FN

from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput

from transformers.modeling_attn_mask_utils import (
    AttentionMaskConverter,
)
from .configuration_helix_mrna import HelixmRNAConfig

from transformers.utils.import_utils import (
    is_causal_conv1d_available,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    is_mamba_2_ssm_available,
)

if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward

if is_mamba_2_ssm_available():
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
    from mamba_ssm.ops.triton.ssd_combined import (
        mamba_chunk_scan_combined,
        mamba_split_conv1d_scan_combined,
    )
else:
    selective_state_update = None

if is_causal_conv1d_available():
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
    causal_conv1d_update, causal_conv1d_fn = None, None

is_fast_path_available = all(
    (selective_state_update, causal_conv1d_fn, causal_conv1d_update)
)


# copied from transformers.models.mistral.modeling_mistral.pad_tensor_by_size
def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
    """
    Padding x tensor with `pad_size` on the seq_len dim (dim=1)

    Assumes that we only have tensors of either size 4 or 3
    """
    pad_shape = (
        (0, 0, 0, 0, 0, pad_size, 0, 0)
        if len(input_tensor.shape) == 4
        else (0, 0, 0, pad_size, 0, 0)
    )

    return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)


def reshape_into_chunks(input_tensor, pad_size, chunk_size):
    """
    Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
    simultaneously splitting it into chunk sequences.

    Assumes that we only have tensors of either size 4 or 3
    """
    # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
    input_tensor = pad_tensor_by_size(input_tensor, pad_size)

    if len(input_tensor.shape) == 3:
        # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
        return input_tensor.reshape(
            input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]
        )
    else:
        # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] ->
        # [bsz, -1, chunk_size, num_heads, head_dim or state_size]
        return input_tensor.reshape(
            input_tensor.shape[0],
            -1,
            chunk_size,
            input_tensor.shape[2],
            input_tensor.shape[3],
        )


def segment_sum(input_tensor):
    """
    More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
    """
    chunk_size = input_tensor.size(-1)
    # 1. expand input tensor to have an additional dimension and repeat along that dimension
    # [..., chunk_size] -> [..., chunk_size, chunk_size]
    input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
    # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag
    mask = torch.tril(
        torch.ones(
            chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool
        ),
        diagonal=-1,
    )
    input_tensor = input_tensor.masked_fill(~mask, 0)
    # 3. compute actual cumsum
    tensor_segsum = torch.cumsum(input_tensor, dim=-2)

    # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
    mask = torch.tril(
        torch.ones(
            chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool
        ),
        diagonal=0,
    )
    tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
    return tensor_segsum

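
# Worked illustration (added comment; not in the original file): for a 1-D input
# [a0, a1, a2], segment_sum returns a 3x3 matrix whose entry (i, j) equals
# a[j+1] + ... + a[i] for j < i, 0 on the diagonal, and -inf above it, so that
# exp(segment_sum(A)) yields the within-chunk decay products used by the chunked
# selective scan.
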
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class HybridMambaAttentionDynamicCache(DynamicCache):
    """
    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
    (which has a constant shape regardless of seq_len).

    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for the attention cache and
    `conv_states` and `ssm_states` for the mamba cache. Each of these lists has `num_layers` tensors, with the
    following expected shapes.
    For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
    while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
    For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
    while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
    and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
    """

    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
        super().__init__()
        self.dtype = dtype
        self.layers_block_type = config.layers_block_type
        self.has_previous_state = False  # only used by mamba
        intermediate_size = config.expand * config.hidden_size
        ssm_state_size = config.state_size
        conv_kernel_size = config.conv_kernel
        self.seqlen_offset = 0
        self.conv_states = []
        self.ssm_states = []
        self.transformer_layers = []
        for i in range(config.num_hidden_layers):
            if self.layers_block_type[i] == "mamba":
                self.conv_states += [
                    torch.zeros(
                        batch_size,
                        intermediate_size,
                        conv_kernel_size,
                        device=device,
                        dtype=dtype,
                    )
                ]
                self.ssm_states += [
                    torch.zeros(
                        batch_size,
                        intermediate_size,
                        ssm_state_size,
                        device=device,
                        dtype=dtype,
                    )
                ]
            else:
                self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
                self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
                self.transformer_layers.append(i)

        self.key_cache = [
            torch.tensor([[]] * batch_size, device=device)
            for _ in range(config.num_hidden_layers)
        ]
        self.value_cache = [
            torch.tensor([[]] * batch_size, device=device)
            for _ in range(config.num_hidden_layers)
        ]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Update the cache
        if self.key_cache[layer_idx].shape[-1] == 0:
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=2
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=2
            )

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
                0, beam_idx.to(device)
            )
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
                0, beam_idx.to(device)
            )

            device = self.conv_states[layer_idx].device
            self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(
                0, beam_idx.to(device)
            )
            device = self.ssm_states[layer_idx].device
            self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(
                0, beam_idx.to(device)
            )

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # take any layer that contains cache and not empty tensor
        layer_idx = (
            self.transformer_layers[0]
            if layer_idx not in self.transformer_layers
            else layer_idx
        )
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        raise NotImplementedError(
            "HybridMambaAttentionDynamicCache does not have a legacy cache equivalent."
        )

    @classmethod
    def from_legacy_cache(
        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    ) -> "DynamicCache":
        raise NotImplementedError(
            "HybridMambaAttentionDynamicCache does not have a legacy cache equivalent."
        )

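
# Usage sketch (added comment; not in the original Helical file): callers that
# want caching are expected to construct this cache up front, e.g.
#   cache = HybridMambaAttentionDynamicCache(config, batch_size=1, dtype=torch.float32)
#   outputs = model(input_ids, past_key_values=cache, use_cache=True)
# Attention layers then grow `key_cache`/`value_cache` along seq_len, while the
# fixed-size `conv_states`/`ssm_states` are updated in place by the mamba layers.
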
280
+ # Adapted from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Jamba
281
+ class HelixmRNAAttention(nn.Module):
282
+ """
283
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
284
+ and "Generating Long Sequences with Sparse Transformers".
285
+ """
286
+
287
+ def __init__(
288
+ self, config: HelixmRNAConfig, layer_idx: Optional[int] = None
289
+ ):
290
+ super().__init__()
291
+ self.config = config
292
+ self.layer_idx = layer_idx
293
+ if layer_idx is None:
294
+ print(
295
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
296
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
297
+ "when creating this class."
298
+ )
299
+
300
+ self.hidden_size = config.hidden_size
301
+ self.num_heads = config.num_attention_heads
302
+ self.head_dim = self.hidden_size // self.num_heads
303
+ self.num_key_value_heads = config.num_key_value_heads
304
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
305
+ self.is_causal = True
306
+ self.attention_dropout = config.attention_dropout
307
+
308
+ if (self.head_dim * self.num_heads) != self.hidden_size:
309
+ raise ValueError(
310
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
311
+ f" and `num_heads`: {self.num_heads})."
312
+ )
313
+ self.q_proj = nn.Linear(
314
+ self.hidden_size, self.num_heads * self.head_dim, bias=False
315
+ )
316
+ self.k_proj = nn.Linear(
317
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
318
+ )
319
+ self.v_proj = nn.Linear(
320
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
321
+ )
322
+ self.o_proj = nn.Linear(
323
+ self.num_heads * self.head_dim, self.hidden_size, bias=False
324
+ )
325
+
326
+ def forward(
327
+ self,
328
+ hidden_states: torch.Tensor,
329
+ attention_mask: Optional[torch.Tensor] = None,
330
+ position_ids: Optional[torch.LongTensor] = None,
331
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
332
+ output_attentions: bool = False,
333
+ use_cache: bool = False,
334
+ cache_position: Optional[torch.LongTensor] = None,
335
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
336
+ bsz, q_len, _ = hidden_states.size()
337
+
338
+ query_states = self.q_proj(hidden_states)
339
+ key_states = self.k_proj(hidden_states)
340
+ value_states = self.v_proj(hidden_states)
341
+
342
+ query_states = query_states.view(
343
+ bsz, q_len, self.num_heads, self.head_dim
344
+ ).transpose(1, 2)
345
+ key_states = key_states.view(
346
+ bsz, q_len, self.num_key_value_heads, self.head_dim
347
+ ).transpose(1, 2)
348
+ value_states = value_states.view(
349
+ bsz, q_len, self.num_key_value_heads, self.head_dim
350
+ ).transpose(1, 2)
351
+
352
+ if past_key_value is not None:
353
+ key_states, value_states = past_key_value.update(
354
+ key_states, value_states, self.layer_idx
355
+ )
356
+
357
+ # repeat k/v heads if n_kv_heads < n_heads
358
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
359
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
360
+
361
+ attn_weights = torch.matmul(
362
+ query_states, key_states.transpose(2, 3)
363
+ ) / math.sqrt(self.head_dim)
364
+
365
+ if attention_mask is not None: # no matter the length, we just slice it
366
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
367
+ attn_weights = attn_weights + causal_mask
368
+
369
+ # upcast attention to fp32
370
+ attn_weights = nn.functional.softmax(
371
+ attn_weights, dim=-1, dtype=torch.float32
372
+ ).to(query_states.dtype)
373
+ attn_weights = nn.functional.dropout(
374
+ attn_weights, p=self.attention_dropout, training=self.training
375
+ )
376
+ attn_output = torch.matmul(attn_weights, value_states)
377
+
378
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
379
+ raise ValueError(
380
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
381
+ f" {attn_output.size()}"
382
+ )
383
+
384
+ attn_output = attn_output.transpose(1, 2).contiguous()
385
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
386
+
387
+ attn_output = self.o_proj(attn_output)
388
+
389
+ if not output_attentions:
390
+ attn_weights = None
391
+
392
+ return attn_output, attn_weights, past_key_value
393
+
394
+
395
+ # Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Jamba
396
+ class HelixmRNAFlashAttention2(HelixmRNAAttention):
397
+ """
398
+ Jamba flash attention module. This module inherits from `JambaAttention` as the weights of the module stays
399
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
400
+ flash attention and deal with padding tokens in case the input contains any of them.
401
+ """
402
+
403
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
404
+ def __init__(self, *args, **kwargs):
405
+ super().__init__(*args, **kwargs)
406
+
407
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
408
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
409
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
410
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
411
+
412
+ def forward(
413
+ self,
414
+ hidden_states: torch.Tensor,
415
+ attention_mask: Optional[torch.Tensor] = None,
416
+ position_ids: Optional[torch.LongTensor] = None,
417
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
418
+ output_attentions: bool = False,
419
+ use_cache: bool = False,
420
+ cache_position: Optional[torch.LongTensor] = None,
421
+ **kwargs,
422
+ ):
423
+ bsz, q_len, _ = hidden_states.size()
424
+
425
+ query_states = self.q_proj(hidden_states)
426
+ key_states = self.k_proj(hidden_states)
427
+ value_states = self.v_proj(hidden_states)
428
+
429
+ # Flash attention requires the input to have the shape
430
+ # batch_size x seq_length x head_dim x hidden_dim
431
+ # therefore we just need to keep the original shape
432
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
433
+ key_states = key_states.view(
434
+ bsz, q_len, self.num_key_value_heads, self.head_dim
435
+ ).transpose(1, 2)
436
+ value_states = value_states.view(
437
+ bsz, q_len, self.num_key_value_heads, self.head_dim
438
+ ).transpose(1, 2)
439
+
440
+ if past_key_value is not None:
441
+ key_states, value_states = past_key_value.update(
442
+ key_states, value_states, self.layer_idx
443
+ )
444
+
445
+ # repeat k/v heads if n_kv_heads < n_heads
446
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
447
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
448
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
449
+
450
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
451
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
452
+ # cast them back in float16 just to be sure everything works as expected.
453
+ input_dtype = query_states.dtype
454
+ if input_dtype == torch.float32:
455
+ if torch.is_autocast_enabled():
456
+ target_dtype = torch.get_autocast_gpu_dtype()
457
+ # Handle the case where the model is quantized
458
+ elif hasattr(self.config, "_pre_quantization_dtype"):
459
+ target_dtype = self.config._pre_quantization_dtype
460
+ else:
461
+ target_dtype = self.q_proj.weight.dtype
462
+
463
+ print(
464
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
465
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
466
+ f" {target_dtype}."
467
+ )
468
+
469
+ query_states = query_states.to(target_dtype)
470
+ key_states = key_states.to(target_dtype)
471
+ value_states = value_states.to(target_dtype)
472
+
473
+ # Reashape to the expected shape for Flash Attention
474
+ key_states = key_states.transpose(1, 2)
475
+ value_states = value_states.transpose(1, 2)
476
+
477
+ attn_output = _flash_attention_forward(
478
+ query_states,
479
+ key_states,
480
+ value_states,
481
+ attention_mask,
482
+ q_len,
483
+ dropout=dropout_rate,
484
+ sliding_window=getattr(self.config, "sliding_window", None),
485
+ is_causal=self.is_causal,
486
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
487
+ )
488
+
489
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
490
+ attn_output = self.o_proj(attn_output)
491
+
492
+ if not output_attentions:
493
+ attn_weights = None
494
+
495
+ return attn_output, attn_weights, past_key_value
496
+
497
+
498
+ # Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba
499
+ class HelixmRNASdpaAttention(HelixmRNAAttention):
500
+ """
501
+ Jamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
502
+ `JambaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
503
+ SDPA API.
504
+ """
505
+
506
+ # Adapted from JambaAttention.forward
507
+ def forward(
508
+ self,
509
+ hidden_states: torch.Tensor,
510
+ attention_mask: Optional[torch.Tensor] = None,
511
+ position_ids: Optional[torch.LongTensor] = None,
512
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
513
+ output_attentions: bool = False,
514
+ use_cache: bool = False,
515
+ cache_position: Optional[torch.LongTensor] = None,
516
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
517
+ if output_attentions:
518
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
519
+ print(
520
+ "JambaModel is using JambaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
521
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
522
+ )
523
+ return super().forward(
524
+ hidden_states=hidden_states,
525
+ attention_mask=attention_mask,
526
+ position_ids=position_ids,
527
+ past_key_value=past_key_value,
528
+ output_attentions=output_attentions,
529
+ use_cache=use_cache,
530
+ )
531
+
532
+ bsz, q_len, _ = hidden_states.size()
533
+
534
+ query_states = self.q_proj(hidden_states)
535
+ key_states = self.k_proj(hidden_states)
536
+ value_states = self.v_proj(hidden_states)
537
+
538
+ query_states = query_states.view(
539
+ bsz, q_len, self.num_heads, self.head_dim
540
+ ).transpose(1, 2)
541
+ key_states = key_states.view(
542
+ bsz, q_len, self.num_key_value_heads, self.head_dim
543
+ ).transpose(1, 2)
544
+ value_states = value_states.view(
545
+ bsz, q_len, self.num_key_value_heads, self.head_dim
546
+ ).transpose(1, 2)
547
+
548
+ if past_key_value is not None:
549
+ key_states, value_states = past_key_value.update(
550
+ key_states, value_states, self.layer_idx
551
+ )
552
+
553
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
554
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
555
+
556
+ causal_mask = attention_mask
557
+ if attention_mask is not None:
558
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
559
+
560
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
561
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
562
+ if query_states.device.type == "cuda" and attention_mask is not None:
563
+ query_states = query_states.contiguous()
564
+ key_states = key_states.contiguous()
565
+ value_states = value_states.contiguous()
566
+
567
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
568
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
569
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
570
+ is_causal = (
571
+ True if self.is_causal and causal_mask is None and q_len > 1 else False
572
+ )
573
+
574
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
575
+ query_states,
576
+ key_states,
577
+ value_states,
578
+ attn_mask=causal_mask,
579
+ dropout_p=self.attention_dropout if self.training else 0.0,
580
+ is_causal=is_causal,
581
+ )
582
+
583
+ attn_output = attn_output.transpose(1, 2).contiguous()
584
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
585
+
586
+ attn_output = self.o_proj(attn_output)
587
+
588
+ return attn_output, None, past_key_value
589
+
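For readers tracing the shapes above, here is a minimal, self-contained sketch of the grouped-query SDPA pattern (head counts and dimensions are illustrative, and repeat_kv_sketch re-implements the repeat_kv helper this file relies on, purely for illustration):

import torch

def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)
    b, kv, s, d = x.shape
    return x[:, :, None].expand(b, kv, n_rep, s, d).reshape(b, kv * n_rep, s, d)

bsz, q_len, num_heads, num_kv_heads, head_dim = 2, 8, 4, 2, 16
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_kv_heads, q_len, head_dim)
v = torch.randn(bsz, num_kv_heads, q_len, head_dim)

# Broadcast the 2 KV heads across the 4 query heads, then run fused SDPA causally.
k = repeat_kv_sketch(k, num_heads // num_kv_heads)
v = repeat_kv_sketch(v, num_heads // num_kv_heads)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 4, 8, 16])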
590
+
591
+ HelixmRNA_ATTENTION_CLASSES = {
592
+ "eager": HelixmRNAAttention,
593
+ "flash_attention_2": HelixmRNAFlashAttention2,
594
+ "sdpa": HelixmRNASdpaAttention,
595
+ }
596
+
597
+
598
+ class Mamba2Cache:
599
+ """
600
+ Arguments:
601
+ config: HelixmRNAConfig
602
+ batch_size: int
603
+ dtype: torch.dtype
604
+ device: torch.device
605
+
606
+ Attributes:
607
+ seqlen_offset: int
608
+ dtype: torch.dtype
609
+ conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size + 2 * n_groups * state_size, conv_kernel_size]
610
+ ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, num_heads, head_dim, state_size]
611
+ """
612
+
613
+ def __init__(
614
+ self,
615
+ config: HelixmRNAConfig,
616
+ batch_size: int,
617
+ dtype: torch.dtype = torch.float16,
618
+ device: Optional[str] = None,
619
+ ):
620
+ self.seqlen_offset = 0
621
+ self.dtype = dtype
622
+ self.conv_kernel_size = config.conv_kernel
623
+ self.intermediate_size = int(config.expand * config.hidden_size)
624
+
625
+ self.conv_states = {
626
+ i: torch.zeros(
627
+ batch_size,
628
+ self.intermediate_size + 2 * config.n_groups * config.state_size,
629
+ self.conv_kernel_size,
630
+ device=device,
631
+ dtype=dtype,
632
+ )
633
+ for i in range(config.num_hidden_layers)
634
+ }
635
+ self.ssm_states = {
636
+ i: torch.zeros(
637
+ batch_size,
638
+ config.num_heads,
639
+ config.head_dim,
640
+ config.state_size,
641
+ device=device,
642
+ dtype=dtype,
643
+ )
644
+ for i in range(config.num_hidden_layers)
645
+ }
646
+ self.activation = config.hidden_act
647
+ self.act = ACT2FN[config.hidden_act]
648
+
649
+ def update_conv_state(
650
+ self,
651
+ layer_idx: int,
652
+ new_conv_state: torch.Tensor,
653
+ cache_position: torch.LongTensor,
654
+ ) -> torch.Tensor:
655
+ conv_state = self.conv_states[layer_idx]
656
+ cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
657
+
658
+ conv_state = conv_state.roll(shifts=-1, dims=-1)
659
+ conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
660
+ self.conv_states[layer_idx].zero_()
661
+ self.conv_states[layer_idx] += conv_state
662
+ return self.conv_states[layer_idx]
663
+
664
+ def reset(self):
665
+ # conv_states and ssm_states are dicts of per-layer tensors, so zero each tensor in place
+ for layer_idx in self.conv_states:
666
+ self.conv_states[layer_idx].zero_()
+ self.ssm_states[layer_idx].zero_()
667
+
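As a sanity check on the cache layout, the following sketch reproduces the two buffer shapes with a stand-in config (all field values are illustrative, not the shipped HelixmRNA defaults):

import torch
from types import SimpleNamespace

# Stand-in config: field values are illustrative, not the shipped defaults.
cfg = SimpleNamespace(conv_kernel=4, expand=2, hidden_size=64, n_groups=1,
                      state_size=16, num_heads=8, head_dim=16)
batch_size = 3
intermediate_size = int(cfg.expand * cfg.hidden_size)  # 128

# Rolling buffer fed to the causal conv1d; the channel axis also carries the
# B and C projections, hence the + 2 * n_groups * state_size extra channels.
conv_state = torch.zeros(batch_size,
                         intermediate_size + 2 * cfg.n_groups * cfg.state_size,
                         cfg.conv_kernel)
# SSM state: one (head_dim x state_size) matrix per head.
ssm_state = torch.zeros(batch_size, cfg.num_heads, cfg.head_dim, cfg.state_size)
print(conv_state.shape, ssm_state.shape)  # (3, 160, 4) (3, 8, 16, 16)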
668
+
669
+ class MambaRMSNormGated(torch.nn.Module):
670
+ def __init__(self, hidden_size, eps=1e-6):
671
+ super().__init__()
672
+ self.weight = nn.Parameter(torch.ones(hidden_size))
673
+ self.variance_epsilon = eps
674
+
675
+ def forward(self, hidden_states, gate=None):
676
+ input_dtype = hidden_states.dtype
677
+ hidden_states = hidden_states.to(torch.float32)
678
+
679
+ if gate is not None:
680
+ hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
681
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
682
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
683
+
684
+ return self.weight * hidden_states.to(input_dtype)
685
+
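The gating order matters here: the SiLU'd gate multiplies the hidden states before the RMS statistics are computed, and with gate=None the module degenerates to a plain RMSNorm. A quick numerical check of that forward pass:

import torch

torch.manual_seed(0)
x, gate = torch.randn(2, 5, 8), torch.randn(2, 5, 8)
weight, eps = torch.ones(8), 1e-6

h = x * torch.nn.functional.silu(gate)                      # gate first
h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)  # then RMS-normalize
out = weight * h
print(out.pow(2).mean(-1).sqrt().mean().item())  # ~1.0: unit RMS per position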
686
+
687
+ class Mamba2Mixer(nn.Module):
688
+ """
689
+ Compute ∆, A, B, C, and D, the state-space parameters, and use them to compute the `contextualized_states`.
690
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
691
+ ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
692
+ and is why Mamba is called **selective** state spaces)
693
+ """
694
+
695
+ def __init__(self, config: HelixmRNAConfig, layer_idx: int):
696
+ super().__init__()
697
+ self.num_heads = config.num_heads
698
+ self.hidden_size = config.hidden_size
699
+ self.ssm_state_size = config.state_size
700
+ self.conv_kernel_size = config.conv_kernel
701
+ self.intermediate_size = int(config.expand * self.hidden_size)
702
+ self.time_step_rank = int(config.time_step_rank)
703
+ self.layer_idx = layer_idx
704
+ self.use_conv_bias = config.use_conv_bias
705
+ self.activation = config.hidden_act
706
+ self.act = ACT2FN[config.hidden_act]
707
+
708
+ self.layer_norm_epsilon = config.layer_norm_epsilon
709
+ self.rms_norm = config.rms_norm
710
+
711
+ self.n_groups = config.n_groups
712
+ self.head_dim = config.head_dim
713
+ self.chunk_size = config.chunk_size
714
+
715
+ self.time_step_limit = config.time_step_limit
716
+ self.time_step_min = config.time_step_min
717
+ self.time_step_max = config.time_step_max
718
+
719
+ self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
720
+ self.conv1d = nn.Conv1d(
721
+ in_channels=self.conv_dim,
722
+ out_channels=self.conv_dim,
723
+ bias=config.use_conv_bias,
724
+ kernel_size=config.conv_kernel,
725
+ groups=self.conv_dim,
726
+ padding=config.conv_kernel - 1,
727
+ )
728
+
729
+ # projection of the input hidden states
730
+ projection_size = self.intermediate_size + self.conv_dim + self.num_heads
731
+ self.in_proj = nn.Linear(
732
+ self.hidden_size,
733
+ projection_size,
734
+ bias=config.use_bias,
735
+ )
736
+ # selective projection used to make dt, B and C input-dependent
737
+
738
+ # time step projection (discretization)
739
+ # instantiate once and copy inv_dt in init_weights of PretrainedModel
740
+ self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
741
+
742
+ # S4D real initialization. These are not discretized!
743
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
744
+ A = torch.arange(1, self.num_heads + 1)
745
+ self.A_log = nn.Parameter(torch.log(A))
746
+ self.A_log._no_weight_decay = True
747
+ self.norm = MambaRMSNormGated(
748
+ self.intermediate_size, eps=self.layer_norm_epsilon
749
+ )
750
+ self.D = nn.Parameter(torch.ones(self.num_heads))
751
+ self.D._no_weight_decay = True
752
+
753
+ self.out_proj = nn.Linear(
754
+ self.intermediate_size, self.hidden_size, bias=config.use_bias
755
+ )
756
+ self.use_bias = config.use_bias
757
+
758
+ if not is_fast_path_available:
759
+ print(
760
+ "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
761
+ " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
762
+ " https://github.com/Dao-AILab/causal-conv1d"
763
+ )
764
+
765
+ def cuda_kernels_forward(
766
+ self,
767
+ hidden_states: torch.Tensor,
768
+ cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
769
+ cache_position: Optional[torch.LongTensor] = None,
770
+ attention_mask: Optional[torch.Tensor] = None,
771
+ ):
772
+ # set up dimensions for reshapes later
773
+
774
+ batch_size, seq_len, _ = hidden_states.shape
775
+ groups_time_state_size = self.n_groups * self.ssm_state_size
776
+ d_to_remove = (
777
+ 2 * self.intermediate_size
778
+ + 2 * self.n_groups * self.ssm_state_size
779
+ + self.num_heads
780
+ )
781
+
782
+ # getting projected states from cache if it exists
783
+ if cache_params is not None and cache_params.seqlen_offset > 0:
784
+ in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
785
+ d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2
786
+ split_projection_dim = [
787
+ d_mlp,
788
+ d_mlp,
789
+ self.intermediate_size,
790
+ self.conv_dim,
791
+ self.num_heads,
792
+ ]
793
+ _, _, gate, hidden_states_B_C, dt = torch.split(
794
+ in_projected_states, split_projection_dim, dim=-1
795
+ )
796
+
797
+ hidden_states_B_C = causal_conv1d_update(
798
+ hidden_states_B_C,
799
+ cache_params.conv_states[self.layer_idx],
800
+ self.conv1d.weight.squeeze(1),
801
+ self.conv1d.bias,
802
+ self.activation,
803
+ )
804
+
805
+ hidden_states, B, C = torch.split(
806
+ hidden_states_B_C,
807
+ [
808
+ self.intermediate_size,
809
+ groups_time_state_size,
810
+ groups_time_state_size,
811
+ ],
812
+ dim=-1,
813
+ )
814
+ A = -torch.exp(self.A_log.float()) # (nheads,)
815
+
816
+ A = (
817
+ A[:, None, ...][:, :, None]
818
+ .expand(-1, self.head_dim, self.ssm_state_size)
819
+ .to(dtype=torch.float32)
820
+ )
821
+ dt = dt[:, :, None].expand(-1, -1, self.head_dim)
822
+ dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
823
+ D = self.D[:, None, ...].expand(-1, self.head_dim)
824
+ B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
825
+ C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
826
+ hidden_states_reshaped = hidden_states.view(
827
+ batch_size, self.num_heads, self.head_dim
828
+ )
829
+ hidden_states = selective_state_update(
830
+ cache_params.ssm_states[self.layer_idx],
831
+ hidden_states_reshaped,
832
+ dt,
833
+ A,
834
+ B,
835
+ C,
836
+ D,
837
+ z=None,
838
+ dt_bias=dt_bias,
839
+ dt_softplus=True,
840
+ )
841
+ hidden_states = hidden_states.view(
842
+ batch_size, self.num_heads * self.head_dim
843
+ )
844
+ hidden_states = self.norm(hidden_states, gate)
845
+ out = self.out_proj(hidden_states)[:, None, ...]
846
+ # if no cache is found, call the fused kernels directly
847
+ else:
848
+ if (
849
+ attention_mask is not None
850
+ and attention_mask.shape[1] > 1
851
+ and attention_mask.shape[0] > 1
852
+ ):
853
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
854
+ dtype = hidden_states.dtype
855
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
856
+ # 1. Gated MLP's linear projection
857
+ projected_states = self.in_proj(hidden_states)
858
+ A = -torch.exp(
859
+ self.A_log.float()
860
+ ) # (num_heads) or (intermediate_size, state_size)
861
+ dt_limit_kwargs = (
862
+ {}
863
+ if self.time_step_limit == (0.0, float("inf"))
864
+ else {"dt_limit": self.time_step_limit}
865
+ )
866
+
867
+ if self.training and cache_params is None:
868
+ out, ssm_state = mamba_split_conv1d_scan_combined(
869
+ projected_states,
870
+ self.conv1d.weight.squeeze(1),
871
+ self.conv1d.bias,
872
+ self.dt_bias,
873
+ A,
874
+ D=self.D,
875
+ chunk_size=self.chunk_size,
876
+ seq_idx=None, # was seq_idx
877
+ activation=self.activation,
878
+ rmsnorm_weight=self.norm.weight,
879
+ rmsnorm_eps=self.norm.variance_epsilon,
880
+ outproj_weight=self.out_proj.weight,
881
+ outproj_bias=self.out_proj.bias,
882
+ headdim=self.head_dim,
883
+ ngroups=self.n_groups,
884
+ norm_before_gate=False,
885
+ return_final_states=True,
886
+ **dt_limit_kwargs,
887
+ )
888
+
889
+ else:
890
+ gate, hidden_states_B_C, time_step = torch.split(
891
+ projected_states,
892
+ [self.intermediate_size, self.conv_dim, self.num_heads],
893
+ dim=-1,
894
+ )
895
+
896
+ # 1D Convolution
897
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
898
+ hidden_states_B_C = self.act(
899
+ self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[
900
+ :, :seq_len
901
+ ]
902
+ ) # (B, L, self.d_inner + 2 * ngroups * d_state)
903
+ else:
904
+ hidden_states_B_C = causal_conv1d_fn(
905
+ x=hidden_states_B_C.transpose(1, 2),
906
+ weight=self.conv1d.weight.squeeze(1),
907
+ bias=self.conv1d.bias,
908
+ activation=self.activation,
909
+ ).transpose(1, 2)[:, :seq_len]
910
+ hidden_states, B, C = torch.split(
911
+ hidden_states_B_C,
912
+ [
913
+ self.intermediate_size,
914
+ groups_time_state_size,
915
+ groups_time_state_size,
916
+ ],
917
+ dim=-1,
918
+ )
919
+ if (
920
+ attention_mask is not None
921
+ and attention_mask.shape[1] > 1
922
+ and attention_mask.shape[0] > 1
923
+ ):
924
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
925
+ dtype = hidden_states.dtype
926
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(
927
+ dtype
928
+ )
929
+ scan_output, ssm_state = mamba_chunk_scan_combined(
930
+ hidden_states.view(batch_size, seq_len, -1, self.head_dim),
931
+ time_step,
932
+ A,
933
+ B.view(batch_size, seq_len, self.n_groups, -1),
934
+ C.view(batch_size, seq_len, self.n_groups, -1),
935
+ chunk_size=self.chunk_size,
936
+ D=self.D,
937
+ z=None,
938
+ seq_idx=None,
939
+ return_final_states=True,
940
+ dt_bias=self.dt_bias,
941
+ dt_softplus=True,
942
+ **dt_limit_kwargs,
943
+ )
944
+ if ssm_state is not None and cache_params is not None:
945
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
946
+ scan_output = scan_output.view(batch_size, seq_len, -1)
947
+ # Multiply "gate" branch and apply extra normalization layer
948
+ scan_output = self.norm(scan_output, gate)
949
+ out = self.out_proj(scan_output)
950
+ return out
951
+
952
+ # fmt: off
953
+ def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentionDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
954
+ batch_size, seq_len, _ = input_states.shape
955
+ dtype = input_states.dtype
956
+ # Gated MLP's linear projection
957
+ projected_states = self.in_proj(input_states.squeeze(1))
958
+ d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2
959
+ _, _, gate, hidden_states, dt = projected_states.split(
960
+ [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
961
+ )
962
+
963
+ # Convolution sequence transformation
964
+ if cache_params is not None:
965
+ ssm_state = cache_params.ssm_states[self.layer_idx].clone()
966
+ ssm_state = ssm_state.to(hidden_states.device)
967
+ if cache_params.seqlen_offset > 0:
968
+ conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
969
+ conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
970
+ # handle batched generation - states are copied through
971
+ conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states
972
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
973
+ hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1)
974
+ if self.use_conv_bias:
975
+ hidden_states += self.conv1d.bias
976
+ hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding
977
+ else:
978
+ hidden_states = hidden_states.transpose(1,2)
979
+ conv_state = nn.functional.pad(
980
+ hidden_states,
981
+ (self.conv_kernel_size - hidden_states.shape[-1], 0)
982
+ )
983
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
984
+ hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len]
985
+ if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
986
+ dtype = hidden_states.dtype
987
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
988
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
989
+ else:
990
+ ssm_state = torch.zeros(
991
+ (batch_size, self.num_heads, self.head_dim, self.ssm_state_size),
992
+ device=hidden_states.device, dtype=dtype
993
+ )
994
+ hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2))
995
+ hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1)
996
+ A = -torch.exp(self.A_log.float()) # [num_heads]
997
+ if cache_params is not None and cache_params.seqlen_offset > 0:
998
+ # Note: there is no need to pad parameter matrices here, as there is just one new token
999
+ # for batched generation
1000
+ dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...]
1001
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
1002
+ # [num_heads] -> [num_heads, head_dim]
1003
+ dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
1004
+
1005
+ dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
1006
+ dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max)
1007
+ A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
1008
+ # [bsz, num_heads, head_dim, state_size]
1009
+ dA = torch.exp(dt[..., None] * A)
1010
+
1011
+ # Discretize B
1012
+ # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
1013
+ # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
1014
+ B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
1015
+ B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
1016
+ B = B.reshape(batch_size, -1, B.shape[-1])
1017
+ # [bsz, num_heads, head_dim, state_size]
1018
+ dB = dt[..., None] * B[..., None, :]
1019
+
1020
+ # Discretize x into dB
1021
+ # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
1022
+ hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
1023
+ dBx = dB * hidden_states[..., None]
1024
+
1025
+ # State calculation
1026
+ cache_params.ssm_states[self.layer_idx].copy_(
1027
+ cache_params.ssm_states[self.layer_idx] * dA + dBx
1028
+ )
1029
+
1030
+ # Subsequent output
1031
+ # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
1032
+ C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
1033
+ C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
1034
+ C = C.reshape(batch_size, -1, C.shape[-1])
1035
+ # [bsz, num_heads, head_dim]
1036
+
1037
+ ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n]
1038
+ # Reshape ssm_states to merge the first two dimensions
1039
+ ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n]
1040
+ C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
1041
+ y = torch.bmm(ssm_states_reshaped, C_reshaped)
1042
+ y = y.view(batch_size, self.num_heads, self.head_dim)
1043
+
1044
+ # D skip connection
1045
+ # [num_heads] -> [num_heads, head_dim]
1046
+ D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
1047
+ y = (y + hidden_states * D).to(y.dtype)
1048
+
1049
+ # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
1050
+ y = y.reshape(batch_size, -1)[:, None, ...]
1051
+ else:
1052
+ # begin ssd naive implementation without einsums
1053
+ dt = nn.functional.softplus(dt + self.dt_bias)
1054
+ dt = torch.clamp(dt, self.time_step_min)
1055
+ hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
1056
+ B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
1057
+ C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
1058
+ B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
1059
+ C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
1060
+ pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
1061
+
1062
+ D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
1063
+
1064
+ # Discretize x and A
1065
+ hidden_states = hidden_states * dt[..., None]
1066
+ A = A.to(hidden_states.dtype) * dt
1067
+
1068
+ # Rearrange into blocks/chunks
1069
+ hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
1070
+
1071
+
1072
+ # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
1073
+ A = A.permute(0, 3, 1, 2)
1074
+ A_cumsum = torch.cumsum(A, dim=-1)
1075
+
1076
+ # 1. Compute the output for each intra-chunk (diagonal blocks)
1077
+ # This is the analog of a causal mask
1078
+ L = torch.exp(segment_sum(A))
1079
+
1080
+ # First, contraction of C and B to get G (attention-weights like)
1081
+ G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n)
1082
+ G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
1083
+
1084
+
1085
+ # Step 2: Compute M, equivalent to applying attention mask to weights
1086
+ M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
1087
+ M = M_intermediate.sum(dim=-1)
1088
+
1089
+ # Step 3: Compute Y_diag (apply to values)
1090
+ Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)
1091
+
1092
+ # (right term of low-rank factorization of off-diagonal blocks; B terms)
1093
+
1094
+ decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
1095
+ B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
1096
+ # permute back B * decay states
1097
+ states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
1098
+ if cache_params is not None and cache_params.seqlen_offset > 0:
1099
+ previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...]
1100
+ else:
1101
+ previous_states = torch.zeros_like(states[:, :1])
1102
+ states = torch.cat([previous_states, states], dim=1)
1103
+ decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
1104
+
1105
+ states_permuted = states.permute(0, 2, 1, 3, 4)
1106
+ result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
1107
+ new_states = result.permute(0, 2, 1, 3, 4)
1108
+ states, ssm_state = new_states[:, :-1], new_states[:, -1]
1109
+
1110
+ # Compute state -> output conversion per chunk
1111
+ # (left term of low-rank factorization of off-diagonal blocks; C terms)
1112
+ state_decay_out = torch.exp(A_cumsum)
1113
+ # compute Yoff
1114
+ C_times_states = (C[..., None, :] * states[:, :, None, ...])
1115
+ state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
1116
+ Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
1117
+ # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
1118
+
1119
+ y = Y_diag + Y_off
1120
+ # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
1121
+ y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
1122
+
1123
+ y = y + D_residual
1124
+ # Cutting off padded chunks
1125
+ if pad_size > 0:
1126
+ y = y[:, :seq_len, :, :]
1127
+ y = y.reshape(batch_size, seq_len, -1)
1128
+ if ssm_state is not None and cache_params is not None:
1129
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
1130
+
1131
+ scan_output = self.norm(y, gate)
1132
+
1133
+ # end ssd naive
1134
+
1135
+ # 4. Final linear projection
1136
+ contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size]
1137
+ return contextualized_states
1138
+ # fmt: on
1139
+
1140
+ def forward(
1141
+ self,
1142
+ hidden_states,
1143
+ cache_params: Optional[Mamba2Cache] = None,
1144
+ cache_position: Optional[torch.LongTensor] = None,
1145
+ attention_mask: Optional[torch.Tensor] = None,
1146
+ ):
1147
+ if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
1148
+ return self.cuda_kernels_forward(
1149
+ hidden_states, cache_params, cache_position, attention_mask
1150
+ )
1151
+ dtype = hidden_states.dtype
1152
+ if (
1153
+ attention_mask is not None
1154
+ and attention_mask.shape[1] > 1
1155
+ and attention_mask.shape[0] > 1
1156
+ ):
1157
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
1158
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
1159
+
1160
+ return self.torch_forward(
1161
+ hidden_states, cache_params, cache_position, attention_mask
1162
+ )
1163
+
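To make the recurrence in torch_forward concrete, here is a hedged toy of the single-token decoding branch (one batch element, n_groups=1, and the dt bias/clamp omitted): zero-order-hold discretization turns the continuous parameters into dA and dB, the state is updated once, and the output is read out through C with a D skip connection.

import torch

torch.manual_seed(0)
num_heads, head_dim, state_size = 4, 8, 16

h = torch.zeros(num_heads, head_dim, state_size)   # per-layer SSM state (batch of 1)
A = -torch.exp(torch.randn(num_heads))             # negative real eigenvalues, per head
x = torch.randn(num_heads, head_dim)               # conv output for the new token
B = torch.randn(state_size)                        # input projection (n_groups = 1)
C = torch.randn(state_size)                        # output projection
D = torch.ones(num_heads, 1)                       # skip-connection weights
dt = torch.nn.functional.softplus(torch.randn(num_heads, 1, 1))  # step size > 0

dA = torch.exp(dt * A[:, None, None])              # discretize A (zero-order hold)
dBx = dt * x[..., None] * B                        # discretize B and fold in the input
h = dA * h + dBx                                   # one recurrence step
y = (h * C).sum(-1) + D * x                        # readout through C, plus D * x skip
print(y.shape)  # torch.Size([4, 8])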
1164
+
1165
+ class Mamba2RMSNorm(nn.Module):
1166
+ def __init__(self, hidden_size, eps=1e-6):
1167
+ """
1168
+ Mamba2RMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
1169
+ """
1170
+ super().__init__()
1171
+ self.weight = nn.Parameter(torch.ones(hidden_size))
1172
+ self.variance_epsilon = eps
1173
+
1174
+ def forward(self, hidden_states):
1175
+ input_dtype = hidden_states.dtype
1176
+ hidden_states = hidden_states.to(torch.float32)
1177
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
1178
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
1179
+ return self.weight * hidden_states.to(input_dtype)
1180
+
1181
+
1182
+ class HelixmRNAMLP(nn.Module):
1183
+ def __init__(self, config, layer_idx=None):
1184
+ super().__init__()
1185
+ self.hidden_size = config.hidden_size
1186
+ self.intermediate_size = self.hidden_size * 4 # fixed 4x expansion; config.intermediate_size is not used here
1187
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
1188
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
1189
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
1190
+ self.act_fn = ACT2FN[config.hidden_act]
1191
+
1192
+ def forward(self, hidden_state, **kwargs):
1193
+ hidden_states = self.down_proj(
1194
+ self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)
1195
+ )
1196
+ return (hidden_states,)
1197
+
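The MLP above is the standard gated (SwiGLU-style) feed-forward; a minimal sketch, assuming a SiLU hidden_act and the hard-coded 4x expansion:

import torch
import torch.nn as nn

hidden, inter = 16, 64  # intermediate_size = 4 * hidden_size, as hard-coded above
gate_proj = nn.Linear(hidden, inter, bias=False)
up_proj = nn.Linear(hidden, inter, bias=False)
down_proj = nn.Linear(inter, hidden, bias=False)

x = torch.randn(2, 10, hidden)
# SiLU-gated branch multiplicatively modulates the up branch (SwiGLU-style).
y = down_proj(torch.nn.functional.silu(gate_proj(x)) * up_proj(x))
print(y.shape)  # torch.Size([2, 10, 16])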
1198
+
1199
+ class HelixmRNAMLPLayer(nn.Module):
1200
+ def __init__(self, config, layer_idx=None):
1201
+ super().__init__()
1202
+ ffn_layer_class = HelixmRNAMLP
1203
+ self.feed_forward = ffn_layer_class(config)
1204
+ self.input_layernorm = Mamba2RMSNorm(
1205
+ config.hidden_size, eps=config.layer_norm_epsilon
1206
+ )
1207
+
1208
+ def forward(
1209
+ self,
1210
+ hidden_states,
1211
+ use_cache=True,
1212
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
1213
+ **kwargs,
1214
+ ):
1215
+ residual = hidden_states
1216
+
1217
+ hidden_states = self.input_layernorm(hidden_states)
1218
+ ff_outputs = self.feed_forward(hidden_states)
1219
+
1220
+ hidden_states = ff_outputs[0]
1221
+ hidden_states = residual + hidden_states
1222
+
1223
+ outputs = (hidden_states,)
1224
+
1225
+ if use_cache:
1226
+ outputs += (past_key_value,)
1227
+
1228
+ return outputs
1229
+
1230
+
1231
+ class Mamba2Block(nn.Module):
1232
+ def __init__(self, config, layer_idx):
1233
+ super().__init__()
1234
+ self.config = config
1235
+ self.layer_idx = layer_idx
1236
+ self.residual_in_fp32 = config.residual_in_fp32
1237
+ self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
1238
+ self.mixer = Mamba2Mixer(config, layer_idx=layer_idx)
1239
+
1240
+ def forward(
1241
+ self,
1242
+ hidden_states,
1243
+ cache_position: Optional[torch.LongTensor] = None,
1244
+ attention_mask: Optional[torch.Tensor] = None,
1245
+ position_ids: Optional[torch.LongTensor] = None,
1246
+ output_attentions: Optional[bool] = False,
1247
+ use_cache: Optional[bool] = False,
1248
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
1249
+ ):
1250
+ residual = hidden_states
1251
+ hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
1252
+ if self.residual_in_fp32:
1253
+ residual = residual.to(torch.float32)
1254
+
1255
+ hidden_states = self.mixer(
1256
+ hidden_states,
1257
+ cache_params=past_key_value,
1258
+ cache_position=cache_position,
1259
+ attention_mask=attention_mask,
1260
+ )
1261
+ hidden_states = residual + hidden_states
1262
+
1263
+ hidden_states = (hidden_states,)
1264
+ if output_attentions:
1265
+ hidden_states += (None,)
1266
+
1267
+ if use_cache:
1268
+ hidden_states += (past_key_value,)
1269
+
1270
+ return hidden_states
1271
+
1272
+
1273
+ class HelixmRNAAttentionDecoderLayer(nn.Module):
1274
+ def __init__(self, config: HelixmRNAConfig, layer_idx: int):
1275
+ super().__init__()
1276
+ self.self_attn = HelixmRNA_ATTENTION_CLASSES[config._attn_implementation](
1277
+ config, layer_idx
1278
+ )
1279
+
1280
+ ffn_layer_class = HelixmRNAMLP
1281
+ self.feed_forward = ffn_layer_class(config)
1282
+ self.input_layernorm = Mamba2RMSNorm(
1283
+ config.hidden_size, eps=config.layer_norm_epsilon
1284
+ )
1285
+ self.pre_ff_layernorm = Mamba2RMSNorm(
1286
+ config.hidden_size, eps=config.layer_norm_epsilon
1287
+ )
1288
+
1289
+ def forward(
1290
+ self,
1291
+ hidden_states: torch.Tensor,
1292
+ attention_mask: Optional[torch.Tensor] = None,
1293
+ position_ids: Optional[torch.LongTensor] = None,
1294
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
1295
+ output_attentions: Optional[bool] = False,
1296
+ output_router_logits: Optional[bool] = False,
1297
+ use_cache: Optional[bool] = False,
1298
+ cache_position: Optional[torch.LongTensor] = None,
1299
+ ) -> Tuple[
1300
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
1301
+ ]:
1302
+ """
1303
+ Args:
1304
+ hidden_states (`torch.FloatTensor`):
1305
+ Input to the layer of shape `(batch, seq_len, embed_dim)`
1306
+ attention_mask (`torch.FloatTensor`, *optional*):
1307
+ Attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0.
1308
+ past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
1309
+ output_attentions (`bool`, *optional*):
1310
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1311
+ returned tensors for more detail.
1312
+ output_router_logits (`bool`, *optional*):
1313
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
1314
+ should not be returned during inference.
1315
+ use_cache (`bool`, *optional*):
1316
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1317
+ (see `past_key_values`).
1318
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1319
+ Indices depicting the position of the input sequence tokens in the sequence.
1320
+ """
1321
+
1322
+ residual = hidden_states
1323
+
1324
+ hidden_states = self.input_layernorm(hidden_states)
1325
+
1326
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1327
+ hidden_states=hidden_states,
1328
+ attention_mask=attention_mask,
1329
+ position_ids=position_ids,
1330
+ past_key_value=past_key_value,
1331
+ output_attentions=output_attentions,
1332
+ use_cache=use_cache,
1333
+ cache_position=cache_position,
1334
+ )
1335
+
1336
+ # residual connection after attention
1337
+ hidden_states = residual + hidden_states
1338
+
1339
+ # feed-forward (experts/MLP)
1340
+ residual = hidden_states
1341
+ hidden_states = self.pre_ff_layernorm(hidden_states)
1342
+ ff_outputs = self.feed_forward(hidden_states)
1343
+
1344
+ hidden_states = ff_outputs[0]
1345
+ hidden_states = residual + hidden_states
1346
+
1347
+ outputs = (hidden_states,)
1348
+
1349
+ if output_attentions:
1350
+ outputs += (self_attn_weights,)
1351
+
1352
+ if use_cache:
1353
+ outputs += (present_key_value,)
1354
+
1355
+ return outputs
1356
+
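Structurally, the decoder layer above is plain pre-norm residual wiring; a sketch with identity stand-ins for the attention and MLP sublayers:

import torch
import torch.nn as nn

d = 16
x = torch.randn(2, 5, d)
norm1, norm2 = nn.LayerNorm(d), nn.LayerNorm(d)  # stand-ins for Mamba2RMSNorm
attn, mlp = nn.Identity(), nn.Identity()         # stand-ins for self-attention / MLP

h = x + attn(norm1(x))   # pre-norm attention sublayer + residual
h = h + mlp(norm2(h))    # pre-norm feed-forward sublayer + residual
print(h.shape)           # torch.Size([2, 5, 16])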
1357
+
1358
+ class HelixmRNAPreTrainedModel(PreTrainedModel):
1359
+ """
1360
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
1361
+ models.
1362
+ """
1363
+
1364
+ config_class = HelixmRNAConfig
1365
+ base_model_prefix = "backbone"
1366
+ supports_gradient_checkpointing = True
1367
+ _is_stateful = True
1368
+ _no_split_modules = ["HelixmRNAAttentionDecoderLayer", "Mamba2Block"]
1369
+ _skip_keys_device_placement = "past_key_values"
1370
+ _supports_flash_attn_2 = True
1371
+ _supports_sdpa = True
1372
+ _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache
1373
+
1374
+ def _init_weights(self, module):
1375
+ """Initialize the weights."""
1376
+ if isinstance(module, Mamba2Mixer):
1377
+ module.A_log._no_weight_decay = True
1378
+ module.D._no_weight_decay = True
1379
+
1380
+ dt = torch.exp(
1381
+ torch.rand(self.config.num_heads)
1382
+ * (
1383
+ math.log(self.config.time_step_max)
1384
+ - math.log(self.config.time_step_min)
1385
+ )
1386
+ + math.log(self.config.time_step_min)
1387
+ ).clamp(min=self.config.time_step_floor)
1388
+
1389
+ # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
1390
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
1391
+ with torch.no_grad():
1392
+ module.dt_bias.copy_(inv_dt)
1393
+ module.dt_bias._no_reinit = True
1394
+
1395
+ if isinstance(module, nn.Linear):
1396
+ if module.bias is not None:
1397
+ if not getattr(module.bias, "_no_reinit", False):
1398
+ nn.init.zeros_(module.bias)
1399
+ elif isinstance(module, nn.Embedding):
1400
+ nn.init.normal_(module.weight, std=self.config.initializer_range)
1401
+
1402
+ if self.config.rescale_prenorm_residual:
1403
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
1404
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
1405
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
1406
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
1407
+ #
1408
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
1409
+ for name, p in module.named_parameters():
1410
+ if name in ["out_proj.weight"]:
1411
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
1412
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
1413
+ # We need to reinit p since this code could be called multiple times
1414
+ # Having just p *= scale would repeatedly scale it down
1415
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
1416
+ with torch.no_grad():
1417
+ p /= math.sqrt(self.config.num_hidden_layers)
1418
+
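The dt_bias initialization relies on the softplus inverse: a quick check that softplus(dt_bias) reproduces the log-uniform dt draw at init (the bounds below are illustrative, not the config defaults):

import math
import torch

time_step_min, time_step_max, time_step_floor = 1e-3, 1e-1, 1e-4  # illustrative bounds
dt = torch.exp(
    torch.rand(8) * (math.log(time_step_max) - math.log(time_step_min))
    + math.log(time_step_min)
).clamp(min=time_step_floor)

inv_dt = dt + torch.log(-torch.expm1(-dt))  # inverse of softplus
assert torch.allclose(torch.nn.functional.softplus(inv_dt), dt, atol=1e-6)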
1419
+
1420
+ @dataclass
1421
+ # Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2
1422
+ class HelixmRNAOutput(ModelOutput):
1423
+ """
1424
+ Class for the HelixmRNA model outputs.
1425
+
1426
+ Args:
1427
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1428
+ Sequence of hidden-states at the output of the last layer of the model.
1429
+ cache_params (`Mamba2Cache`):
1430
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
1431
+ avoid providing the old `input_ids`.
1432
+
1433
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states
1434
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
1435
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
1436
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
1437
+
1438
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
1439
+ """
1440
+
1441
+ last_hidden_state: Optional[torch.FloatTensor] = None
1442
+ cache_params: Optional[Mamba2Cache] = None
1443
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
1444
+
1445
+
1446
+ ALL_DECODER_LAYER_TYPES = {
1447
+ "attention": HelixmRNAAttentionDecoderLayer,
1448
+ "mamba": Mamba2Block,
1449
+ "mlp": HelixmRNAMLPLayer,
1450
+ }
1451
+
1452
+
1453
+ class HelixmRNAModel(HelixmRNAPreTrainedModel):
1454
+
1455
+ @classmethod
1456
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
1457
+ wrapper_config = kwargs.pop("config", None)
1458
+ if wrapper_config is None:
1459
+ raise ValueError("Config must be provided")
1460
+
1461
+ model_name = wrapper_config.model_name
1462
+ cfg = HelixmRNAConfig.from_pretrained(model_name, **kwargs)
1463
+ cfg.model_name = model_name
1464
+
1465
+ return super().from_pretrained(
1466
+ model_name,
1467
+ *model_args,
1468
+ config=cfg,
1469
+ **kwargs,
1470
+ )
1471
+
1472
+ def __init__(self, config):
1473
+ super().__init__(config)
1474
+
1475
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
1476
+ decoder_layers = []
1477
+ for i in range(config.num_hidden_layers):
1478
+ layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
1479
+ decoder_layers.append(layer_class(config, layer_idx=i))
1480
+ self.layers = nn.ModuleList(decoder_layers)
1481
+ self.gradient_checkpointing = False
1482
+ self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
1483
+ self._attn_implementation = config._attn_implementation
1484
+ # Initialize weights and apply final processing
1485
+ self._register_load_state_dict_pre_hook(self.load_hook)
1486
+ self.post_init()
1487
+
1488
+ def load_hook(self, state_dict, prefix, *args):
1489
+ for k in state_dict:
1490
+ if "embedding." in k:
1491
+ state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
1492
+ break
1493
+
1494
+ def get_input_embeddings(self):
1495
+ return self.embeddings
1496
+
1497
+ def set_input_embeddings(self, new_embeddings):
1498
+ self.embeddings = new_embeddings
1499
+
1500
+ def forward(
1501
+ self,
1502
+ input_ids: Optional[torch.LongTensor] = None,
1503
+ inputs_embeds: Optional[torch.LongTensor] = None,
1504
+ use_cache: Optional[bool] = None,
1505
+ output_hidden_states: Optional[bool] = None,
1506
+ return_dict: Optional[bool] = None,
1507
+ cache_position: Optional[torch.LongTensor] = None,
1508
+ attention_mask: Optional[torch.Tensor] = None,
1509
+ past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
1510
+ position_ids: Optional[torch.LongTensor] = None,
1511
+ output_attentions: Optional[bool] = None,
1512
+ **kwargs,
1513
+ ) -> Union[Tuple, HelixmRNAOutput]:
1514
+ output_hidden_states = (
1515
+ output_hidden_states
1516
+ if output_hidden_states is not None
1517
+ else self.config.output_hidden_states
1518
+ )
1519
+ use_cache = (
1520
+ use_cache
1521
+ if use_cache is not None
1522
+ else (self.config.use_cache if not self.training else False)
1523
+ )
1524
+ return_dict = (
1525
+ return_dict if return_dict is not None else self.config.use_return_dict
1526
+ )
1527
+
1528
+ if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
1529
+ raise ValueError(
1530
+ "You must specify exactly one of input_ids or inputs_embeds"
1531
+ )
1532
+
1533
+ if inputs_embeds is None:
1534
+ inputs_embeds = self.embeddings(input_ids)
1535
+
1536
+ if self.gradient_checkpointing and self.training and use_cache:
1537
+ use_cache = False
1538
+ cache_params = past_key_values
1539
+ if use_cache:
1540
+ if cache_params is None:
1541
+ cache_params = HybridMambaAttentionDynamicCache(
1542
+ self.config,
1543
+ inputs_embeds.size(0),
1544
+ device=inputs_embeds.device,
1545
+ dtype=inputs_embeds.dtype,
1546
+ )
1547
+ cache_position = torch.arange(
1548
+ 0, self.config.conv_kernel, device=inputs_embeds.device
1549
+ )
1550
+ elif cache_position is None:
1551
+ # This happens on a manual forward pass (i.e. not via `model.generate`, which
1552
+ # initializes `cache_position` itself), so raise here rather than guessing
1553
+ # the current cache position.
1554
+ raise ValueError(
1555
+ "You have to specify `cache_position` manually when `use_cache=True` and `cache_params` is passed; "
1556
+ "during prefill you can omit `cache_params`, in which case it is initialized for you automatically."
1558
+ )
1559
+ if use_cache and past_key_values is None:
1560
+ print(
1561
+ "No `HybridMambaAttentionDynamicCache` was provided, so a fresh one was initialized. "
1562
+ "Pass the returned `cache_params` back in on subsequent calls to continue decoding from it."
1563
+ )
1564
+ else:
1565
+ cache_params = None
1566
+
1567
+ hidden_states = inputs_embeds
1568
+ if cache_position is None:
1569
+ cache_position = torch.arange(
1570
+ hidden_states.shape[1], device=hidden_states.device
1571
+ )
1572
+ if position_ids is None:
1573
+ position_ids = cache_position.unsqueeze(0)
1574
+
1575
+ causal_mask = self._update_causal_mask(
1576
+ attention_mask, inputs_embeds, cache_position
1577
+ )
1578
+
1579
+ all_hidden_states = () if output_hidden_states else None
1580
+ all_self_attns = () if output_attentions else None
1581
+ for helix_block in self.layers:
1582
+
1583
+ layer_mask = (
1584
+ attention_mask if isinstance(helix_block, Mamba2Block) else causal_mask
1585
+ )
1586
+
1587
+ if self.gradient_checkpointing and self.training:
1588
+ layer_outputs = self._gradient_checkpointing_func(
1589
+ helix_block.__call__,
1590
+ hidden_states,
1591
+ layer_mask,
1592
+ position_ids,
1593
+ past_key_values,
1594
+ output_attentions,
1595
+ use_cache,
1596
+ cache_position,
1597
+ )
1598
+ else:
1599
+ layer_outputs = helix_block(
1600
+ hidden_states,
1601
+ attention_mask=layer_mask,
1602
+ position_ids=position_ids,
1603
+ past_key_value=past_key_values,
1604
+ output_attentions=output_attentions,
1605
+ use_cache=use_cache,
1606
+ cache_position=cache_position,
1607
+ )
1608
+
1609
+ if output_hidden_states:
1610
+ all_hidden_states += (layer_outputs[0],)
1611
+
1612
+ hidden_states = self.norm_f(layer_outputs[0])
1613
+
1614
+ if output_hidden_states:
1615
+ all_hidden_states = all_hidden_states + (hidden_states,)
1616
+
1617
+ if output_attentions:
1618
+ if layer_outputs[1] is not None:
1619
+ # append attentions only of attention layers. Mamba layers return `None` as the attention weights
1620
+ all_self_attns += (layer_outputs[1],)
1621
+
1622
+ if use_cache:
1623
+ cache_params.seqlen_offset += inputs_embeds.shape[1]
1624
+
1625
+ if not return_dict:
1626
+ return tuple(
1627
+ v
1628
+ for v in [hidden_states, cache_params, all_hidden_states]
1629
+ if v is not None
1630
+ )
1631
+
1632
+ return HelixmRNAOutput(
1633
+ last_hidden_state=hidden_states,
1634
+ cache_params=cache_params if use_cache else None,
1635
+ hidden_states=all_hidden_states,
1636
+ )
1637
+
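End to end, the model is meant to be loaded through the Auto classes with remote code enabled. A hedged usage sketch (the repo id below is a placeholder, and it assumes the repository's config.json registers these classes under auto_map):

import torch
from transformers import AutoModel, AutoTokenizer

repo = "helical-ai/helix-mRNA"  # placeholder repo id; substitute the actual checkpoint
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True)

inputs = tokenizer("ACGUACGU", return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=inputs["input_ids"])
print(out.last_hidden_state.shape)  # (1, seq_len + 1 appended [SEP], hidden_size)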
1638
+ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
1639
+ if self.config._attn_implementation == "flash_attention_2":
1640
+ if attention_mask is not None and 0.0 in attention_mask:
1641
+ return attention_mask
1642
+ return None
1643
+
1644
+ dtype, device = input_tensor.dtype, input_tensor.device
1645
+ min_dtype = torch.finfo(dtype).min
1646
+ sequence_length = input_tensor.shape[1]
1647
+ target_length = cache_position[-1] + 1
1648
+
1649
+ causal_mask = torch.full(
1650
+ (sequence_length, target_length),
1651
+ fill_value=min_dtype,
1652
+ dtype=dtype,
1653
+ device=device,
1654
+ )
1655
+ if sequence_length != 1:
1656
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1657
+ causal_mask *= torch.arange(
1658
+ target_length, device=device
1659
+ ) > cache_position.reshape(-1, 1)
1660
+ causal_mask = causal_mask[None, None, :, :].expand(
1661
+ input_tensor.shape[0], 1, -1, -1
1662
+ )
1663
+ if attention_mask is not None:
1664
+ causal_mask = (
1665
+ causal_mask.clone()
1666
+ ) # copy to contiguous memory for in-place edit
1667
+ if attention_mask.dim() == 2:
1668
+ mask_length = attention_mask.shape[-1]
1669
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
1670
+ :, None, None, :
1671
+ ].eq(0.0)
1672
+ causal_mask[..., :mask_length] = causal_mask[
1673
+ ..., :mask_length
1674
+ ].masked_fill(padding_mask, min_dtype)
1675
+
1676
+ if (
1677
+ self.config._attn_implementation == "sdpa"
1678
+ and attention_mask is not None
1679
+ and attention_mask.device.type == "cuda"
1680
+ ):
1681
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1682
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1683
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1684
+ causal_mask = AttentionMaskConverter._unmask_unattended(
1685
+ causal_mask, min_dtype
1686
+ )
1687
+
1688
+ return causal_mask
1689
+
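A standalone sketch of the additive mask built above, for a 4-token prefill: masked slots hold the dtype minimum, which SDPA adds to the attention logits before the softmax.

import torch

seq_len, dtype = 4, torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(seq_len)             # prefill: positions 0..3
target_length = int(cache_position[-1]) + 1

mask = torch.full((seq_len, target_length), fill_value=min_dtype, dtype=dtype)
mask = torch.triu(mask, diagonal=1)                # keep 0 on/below the diagonal
mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
print(mask)  # row i attends to columns <= i; blocked slots hold the dtype minimum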
1690
+ def _update_mamba_mask(self, attention_mask, cache_position):
1691
+ """
1692
+ No need for zeroing states when
1693
+ 1. Cached forward
1694
+ 2. Attending to all inputs
1695
+ """
1696
+ mamba_mask = attention_mask
1697
+ if cache_position[0] > 0 or (
1698
+ attention_mask is not None and torch.all(attention_mask == 1)
1699
+ ):
1700
+ mamba_mask = None
1701
+ return mamba_mask
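And the Mamba-mask short-circuit in isolation (a re-implementation for illustration, not the shipped method):

import torch

def update_mamba_mask_sketch(attention_mask, cache_position):
    # Skip masking when decoding from cache or when nothing is padded.
    if cache_position[0] > 0 or (
        attention_mask is not None and torch.all(attention_mask == 1)
    ):
        return None
    return attention_mask

mask = torch.ones(2, 6)
print(update_mamba_mask_sketch(mask, torch.arange(6)))          # None: no padding
mask[0, :2] = 0                                                  # left-pad one row
print(update_mamba_mask_sketch(mask, torch.arange(6)) is mask)  # True: mask kept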
special_tokens_map.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "bos_token": "[BOS]",
3
+ "cls_token": "[CLS]",
4
+ "eos_token": "[SEP]",
5
+ "mask_token": "[MASK]",
6
+ "pad_token": "[PAD]",
7
+ "sep_token": "[SEP]",
8
+ "unk_token": "[UNK]"
9
+ }
tokenization_helix_mrna.py ADDED
@@ -0,0 +1,127 @@
1
+ """Character tokenizer for Hugging Face."""
2
+
3
+ from typing import List, Optional, Dict, Sequence, Tuple
4
+
5
+ from transformers import PreTrainedTokenizer
6
+
7
+
8
+ class HelixmRNATokenizer(PreTrainedTokenizer):
9
+ model_input_names = ["input_ids"]
10
+
11
+ def __init__(
12
+ self,
13
+ model_max_length: int,
14
+ bos_token="[BOS]",
15
+ eos_token="[SEP]",
16
+ sep_token="[SEP]",
17
+ cls_token="[CLS]",
18
+ pad_token="[PAD]",
19
+ mask_token="[MASK]",
20
+ unk_token="[UNK]",
21
+ **kwargs,
22
+ ):
23
+ """Character tokenizer for Hugging Face transformers.
24
+ Adapted from https://huggingface.co/LongSafari/hyenadna-tiny-1k-seqlen-hf/blob/main/tokenization_hyena.py
25
+ Args:
26
+ model_max_length (int): Model maximum sequence length.
27
+ characters (Sequence[str]): List of desired characters. Any character which
28
+ is not included in this list will be replaced by a special token called
29
+ [UNK] with id=6. Following is a list of the special tokens with
30
+ their corresponding ids:
31
+ "[CLS]": 0
32
+ "[SEP]": 1
33
+ "[BOS]": 2
34
+ "[MASK]": 3
35
+ "[PAD]": 4
36
+ "[RESERVED]": 5
37
+ "[UNK]": 6
38
+ an id (starting at 7) will be assigned to each character.
39
+ """
40
+ self.characters = ("A", "C", "G", "U", "N", "E", "T")
41
+ self.model_max_length = model_max_length
42
+
43
+ self._vocab_str_to_int = {
44
+ "[CLS]": 2,
45
+ "[SEP]": 1,
46
+ "[BOS]": 0,
47
+ "[MASK]": 3,
48
+ "[PAD]": 1,
49
+ "[RESERVED]": 5,
50
+ "[UNK]": 6,
51
+ **{ch: i + 7 for i, ch in enumerate(self.characters)},
52
+ }
53
+
54
+ self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
55
+ add_prefix_space = kwargs.pop("add_prefix_space", False)
56
+ padding_side = kwargs.pop("padding_side", "left")
57
+
58
+ self._vocab_str_to_int["T"] = self._vocab_str_to_int["U"]
59
+
60
+ super().__init__(
61
+ bos_token=bos_token,
62
+ eos_token=eos_token,
63
+ sep_token=sep_token,
64
+ cls_token=cls_token,
65
+ pad_token=pad_token,
66
+ mask_token=mask_token,
67
+ unk_token=unk_token,
68
+ add_prefix_space=add_prefix_space,
69
+ model_max_length=model_max_length,
70
+ padding_side=padding_side,
71
+ **kwargs,
72
+ )
73
+
74
+ @property
75
+ def vocab_size(self) -> int:
76
+ return len(self._vocab_str_to_int)
77
+
78
+ def _tokenize(self, text: str, **kwargs) -> List[str]:
79
+ return list(text.upper()) # Convert all bases to uppercase
80
+
81
+ def _convert_token_to_id(self, token: str) -> int:
82
+ return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])
83
+
84
+ def _convert_id_to_token(self, index: int) -> str:
85
+ return self._vocab_int_to_str[index]
86
+
87
+ def convert_tokens_to_string(self, tokens):
88
+ return "".join(
89
+ tokens
90
+ ) # Note: this operation loses info about which bases were originally lowercase
91
+
92
+ def get_special_tokens_mask(
93
+ self,
94
+ token_ids_0: List[int],
95
+ token_ids_1: Optional[List[int]] = None,
96
+ already_has_special_tokens: bool = False,
97
+ ) -> List[int]:
98
+ if already_has_special_tokens:
99
+ return super().get_special_tokens_mask(
100
+ token_ids_0=token_ids_0,
101
+ token_ids_1=token_ids_1,
102
+ already_has_special_tokens=True,
103
+ )
104
+
105
+ result = ([0] * len(token_ids_0)) + [1]
106
+ if token_ids_1 is not None:
107
+ result += ([0] * len(token_ids_1)) + [1]
108
+ return result
109
+
110
+ def build_inputs_with_special_tokens(
111
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
112
+ ) -> List[int]:
113
+ sep = [self.sep_token_id]
114
+ # cls = [self.cls_token_id]
115
+ result = token_ids_0 + sep
116
+ if token_ids_1 is not None:
117
+ result += token_ids_1 + sep
118
+ return result
119
+
120
+ def get_vocab(self) -> Dict[str, int]:
121
+ return self._vocab_str_to_int
122
+
123
+ # Fixed vocabulary with no vocab file
124
+ def save_vocabulary(
125
+ self, save_directory: str, filename_prefix: Optional[str] = None
126
+ ) -> Tuple:
127
+ return ()
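A quick usage sketch of the tokenizer in isolation (assuming tokenization_helix_mrna.py is importable from the working directory):

from tokenization_helix_mrna import HelixmRNATokenizer

tok = HelixmRNATokenizer(model_max_length=12288)
print(tok("acgu")["input_ids"])  # [7, 8, 9, 10, 1]: A, C, G, U, then the appended [SEP]
print(tok("acgt")["input_ids"])  # same ids: lowercase is upcased and T is aliased to U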
tokenizer_config.json ADDED
@@ -0,0 +1,71 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "[BOS]",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "[PAD]",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "[CLS]",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "3": {
29
+ "content": "[MASK]",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "6": {
37
+ "content": "[UNK]",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ }
44
+ },
45
+ "bos_token": "[BOS]",
46
+ "characters": [
47
+ "A",
48
+ "C",
49
+ "G",
50
+ "U",
51
+ "N",
52
+ "E",
53
+ "T"
54
+ ],
55
+ "clean_up_tokenization_spaces": false,
56
+ "cls_token": "[CLS]",
57
+ "eos_token": "[SEP]",
58
+ "mask_token": "[MASK]",
59
+ "model_max_length": 12288,
60
+ "pad_token": "[PAD]",
61
+ "padding_side": "left",
62
+ "sep_token": "[SEP]",
63
+ "unk_token": "[UNK]",
64
+ "tokenizer_class": "HelixmRNATokenizer",
65
+ "auto_map": {
66
+ "AutoTokenizer": [
67
+ "tokenization_helix_mrna.HelixmRNATokenizer",
68
+ null
69
+ ]
70
+ }
71
+ }