Taykhoom committed on
Commit d411fc4 · verified · 1 Parent(s): 8cc833c

Upload folder using huggingface_hub

LICENSE ADDED
@@ -0,0 +1,50 @@
1
+ GENBIO AI COMMUNITY LICENSE AGREEMENT
2
+
3
+ This GenBio AI Community License Agreement (the “License”) constitutes an agreement between you or the legal entity you represent (“you” or “your”) and GENBIO.AI, INC. (“GenBio”), governing your use of the GenBio Materials. If you are using the GenBio Materials on behalf of a legal entity, you represent and warrant to GenBio that you have full legal authority to act on behalf of that legal entity as applicable under the License. If you do not have the authority to accept this License or if you disagree with any or all of the License, you shall not use the GenBio Materials in any manner. By using or distributing any portion or element of the GenBio Materials, you imply your agreement to be bound by the License.
4
+
5
+ “GenBio Materials” means any datasets, code, model weights or any other materials provided by GenBio at the following GitHub Page https://github.com/genbio-ai or Hugging Face Page https://huggingface.co/genbio-ai, including any updates or modifications made from time to time, whether in Source or Object form, and is made available to you under this License.
6
+
7
+
8
+ 1. License Grant.
9
+ 1.1 License Scope. Subject to the terms of this License, GenBio grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under GenBio’s intellectual property or other rights owned by GenBio embodied in the GenBio Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the GenBio Materials for any Non-Commercial Purposes.
10
+ 1.2 Use Restrictions. Restricted activities in relation to the License or use of GenBio Materials include:
11
+ 1.2.1 You shall use the GenBio Materials, Contributions, Derivative Works, Outputs and Output Derivatives (as defined below) solely for Non-Commercial Purposes;
12
+ 1.2.2 You shall not, directly or indirectly: (a) use or provide access to any Outputs or Output Derivatives to train, optimize, improve, or otherwise enhance the functionality or performance of any machine learning models or related technologies that are similar to the GenBio Materials; (b) engage in any form of model distillation or other methods that would achieve the purposes described in subsection (a) above. Notwithstanding the foregoing, you may use Outputs and Output Derivatives to train, optimize, improve, or enhance the functionality or performance of: (i) The GenBio Materials itself; and (ii) downstream Derivative Works of the GenBio Materials;
13
+ 1.2.3 Your use of the GenBio Materials shall be subject to any additional terms and conditions that: (a) GenBio provides to you separately; or (b) GenBio otherwise makes available to you.
14
+
15
+ 2. Sharing and Distribution.
16
+ 2.1 Subject to Section 1, if you distribute or make available the GenBio Materials or a Derivative Work to a third party for your Non-Commercial Purposes, in Source or Object form, you shall:
17
+ 2.1.1 provide a copy of this License to that third party;
18
+ 2.1.2 retain the following attribution notice within a “Notice” text file distributed as a part of such copies: “This is licensed under the GenBio AI Community License Agreement, Copyright © GENBIO.AI, INC. All Rights Reserved”; and
19
+ 2.1.3 prominently display “Powered by GenBio AI” on a related website, user interface, blogpost, about page, or product documentation.
20
+ 2.2 If You create a Derivative Work, you may add your own attribution notice(s) to the “Notice” text file included with that Derivative Work, provided that you clearly indicate which attributions apply to the GenBio Materials and state in the “Notice” text file that you changed the GenBio Materials and how it was modified.
21
+
22
+ 3. Submission of Contribution.
23
+ Unless you explicitly state otherwise, any Contribution intentionally submitted for inclusion in the GenBio Materials by you to GenBio shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with GenBio regarding such Contributions.
24
+
25
+ 4. Export Control.
26
+ You shall comply with the applicable U.S. Foreign Corrupt Practices Act and all applicable export laws, restrictions and regulations of the U.S. Department of Commerce, and any other applicable U.S. and foreign authority.
27
+
28
+ 5. Disclaimer of Warranty.
29
+ GENBIO MATERIALS PROVIDED BY GENBIO OR ANY OUTPUT YOU RECEIVED ARE PROVIDED “AS IS.” EXCEPT TO THE EXTENT PROHIBITED BY LAW, GENBIO MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND, WHETHER EXPRESS, IMPLIED OR OTHERWISE, REGARDING THE ACCURACY, COMPLETENESS OR PERFORMANCE OF THE SERVICES AND YOUR OUTPUT, OR WITH RESPECT TO SATISFACTORY QUALITY, FITNESS FOR A PARTICULAR PURPOSE OR NON-INFRINGEMENT.
30
+
31
+ 6. Limitation of Liability.
32
+ In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the GenBio Materials (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
33
+
34
+ 7. General Terms.
35
+ 7.1 Relationship of Parties. You and GenBio are independent contractors, and nothing herein shall be deemed to constitute either party as the agent or representative of the other or both parties as joint venturers or partners for any purpose.
36
+ 7.2 Assignment. This License and the rights and obligations herein may not be assigned or transferred, in whole or in part, by You without the prior written consent of GenBio. Any assignment in violation of this provision is void. GenBio may freely assign or transfer this License, in whole or in part. This License shall be binding upon, and inure to the benefit of, the successors and permitted assigns of the parties.
37
+ 7.3 Governing Law. This License shall be governed, construed and interpreted in accordance with the laws of the State of California, without giving effect to principles of conflicts of law. Each of the parties to this License consents to the exclusive jurisdiction and venue of the courts of the state and federal courts of California.
38
+ 7.4 Severability. If any provision of this License is held to be invalid, illegal or unenforceable in any respect, that provision shall be limited or eliminated to the minimum extent necessary so that this License otherwise remains in full force and effect and enforceable.
39
+
40
+ 8. Definitions.
41
+ 8.1 “Commercial Entity” means any entity engaged in any activity intended for or directed toward commercial advantage or monetary compensation, including, without limitation, the development of any product or service intended to be sold or made available for a fee. For the purpose of this License, references to a Commercial Entity expressly exclude any universities, non-profit organizations, not-for-profit entities, research institutes and educational and government bodies.
42
+ 8.2 “Contribution” means any work of authorship, including the original version of the GenBio Materials and any modifications or additions to that GenBio Materials or Derivative Works thereof, that is intentionally submitted to GenBio for inclusion in the GenBio Materials by the copyright owner or by an individual or legal entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to GenBio or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, GenBio for the purpose of discussing and improving the GenBio Materials, but excluding Outputs and all communications that are conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution”.
43
+ 8.3 “Contributor” means GenBio and any individual or legal entity on behalf of whom a Contribution has been received by GenBio and subsequently incorporated within the GenBio Materials.
44
+ 8.4 “Derivative Work” means any work, whether in Source or Object form, that is based on (or derived from) the GenBio Materials and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the GenBio Materials and Derivative Works thereof.
45
+ 8.5 “Non-Commercial Purposes” means uses not intended for or directed toward commercial advantage or monetary compensation, or the facilitation of development of any product or service to be sold or made available for a fee. For the avoidance of doubt, the provision of Outputs as a service is not a Non-Commercial Purpose.
46
+ 8.6 “Object” means any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
47
+ 8.7 “Output” means any output, including any protein sequence, structure prediction, functional annotation, molecule, descriptions of a molecule, model, sequence, text, and/or image that is elicited directly or indirectly by, or otherwise made available to, you in connection with your use of the GenBio Materials, including, but not limited to, the use of AI-Powered Technology. For the avoidance of doubt, it includes any intermediate results, such as activations across model layers, intermediate outputs from model layers (e.g., attention maps), as well as gradients and embeddings produced by the GenBio Materials.
48
+ 8.8 “Output Derivatives” means any enhancements, modifications and derivative works of Outputs (including, but not limited to, any derivative sequences or molecules).
49
+ 8.9 “Source” means the preferred form for making modifications, including but not limited to GenBio Materials source code, documentation source, and configuration files.
50
+
README.md ADDED
@@ -0,0 +1,91 @@
1
+ # How to use
2
+ ```python
3
+ from transformers import AutoModel, AutoTokenizer
4
+
5
+ tokenizer = AutoTokenizer.from_pretrained(
6
+ "Taykhoom/AIDO-RNA-Wrapper",
7
+ trust_remote_code=True,
8
+ )
9
+
10
+ model = AutoModel.from_pretrained(
11
+ "Taykhoom/AIDO-RNA-Wrapper",
12
+ trust_remote_code=True,
13
+ base_model="genbio-ai/AIDO.RNA-650M-CDS",
14
+ )
15
+ ```
16
+
17
+ # Model Variants
18
+ The following `base_model` options are available for embedding generation. Either a short name (a key below) or a full model name (a value) can be passed as the `base_model` argument.
19
+ ```python
20
+ VARIANTS = {
21
+ "aido_rna_1m_mars": "genbio-ai/AIDO.RNA-1M-MARS",
22
+ "aido_rna_25m_mars": "genbio-ai/AIDO.RNA-25M-MARS",
23
+ "aido_rna_300m_mars": "genbio-ai/AIDO.RNA-300M-MARS",
24
+ "aido_rna_650m": "genbio-ai/AIDO.RNA-650M",
25
+ "aido_rna_650m_cds": "genbio-ai/AIDO.RNA-650M-CDS",
26
+ "aido_rna_1b600m": "genbio-ai/AIDO.RNA-1.6B",
27
+ "aido_rna_1b600m_cds": "genbio-ai/AIDO.RNA-1.6B-CDS",
28
+ }
29
+ ```
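Since either form of the name is accepted, the wrapper presumably normalizes short names to full model names before loading. A minimal, illustrative sketch of that resolution (the helper `resolve_base_model` is hypothetical, not part of the wrapper's API):

```python
# Hypothetical helper showing how a short variant name or a full model name
# could be normalized to a full Hugging Face model id before loading.
VARIANTS = {
    "aido_rna_1m_mars": "genbio-ai/AIDO.RNA-1M-MARS",
    "aido_rna_25m_mars": "genbio-ai/AIDO.RNA-25M-MARS",
    "aido_rna_300m_mars": "genbio-ai/AIDO.RNA-300M-MARS",
    "aido_rna_650m": "genbio-ai/AIDO.RNA-650M",
    "aido_rna_650m_cds": "genbio-ai/AIDO.RNA-650M-CDS",
    "aido_rna_1b600m": "genbio-ai/AIDO.RNA-1.6B",
    "aido_rna_1b600m_cds": "genbio-ai/AIDO.RNA-1.6B-CDS",
}

def resolve_base_model(name: str) -> str:
    """Map a short variant name to its full model name; pass full names through."""
    if name in VARIANTS:
        return VARIANTS[name]
    if name in VARIANTS.values():
        return name
    raise ValueError(f"Unknown base_model: {name!r}")

print(resolve_base_model("aido_rna_650m_cds"))     # genbio-ai/AIDO.RNA-650M-CDS
print(resolve_base_model("genbio-ai/AIDO.RNA-1.6B"))  # genbio-ai/AIDO.RNA-1.6B
```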
30
+
31
+ # Performance vs. Original AIDO.RNA Models
32
+
33
+ The snippets below verify that the modified code produces the same embeddings as the original AIDO.RNA models.
34
+
35
+ Original AIDO.RNA code snippet:
36
+ ```python
37
+ from modelgenerator.tasks import Embed
38
+ import torch
39
+
40
+ model = Embed.from_config({"model.backbone": "aido_rna_650m"}).eval()
41
+ dna = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC"
42
+ transformed_batch = model.transform({"sequences": [dna]})
43
+ embedding = model(transformed_batch) # [1, sequence_length, 1280]
44
+
45
+ embedding_mean = torch.mean(embedding, dim=1)
46
+ print(torch.mean(embedding_mean)) # Outputs tensor(0.0005, grad_fn=<MeanBackward0>)
47
+
48
+ embedding_max = torch.max(embedding, dim=1)[0]
49
+ print(torch.mean(embedding_max)) # Outputs tensor(1.5583, grad_fn=<MeanBackward0>)
50
+ ```
51
+
52
+ Modified code snippet using the wrapper:
53
+ ```python
54
+ import torch
55
+ from transformers import AutoTokenizer, AutoModel
56
+
57
+ tokenizer = AutoTokenizer.from_pretrained(
58
+ "Taykhoom/AIDO-RNA-Wrapper",
59
+ trust_remote_code=True,
60
+ )
61
+
62
+ model = AutoModel.from_pretrained(
63
+ "Taykhoom/AIDO-RNA-Wrapper",
64
+ trust_remote_code=True,
65
+ base_model="genbio-ai/AIDO.RNA-650M",
66
+ )
67
+
68
+ dna = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC"
69
+ inputs = tokenizer(dna, add_special_tokens=True, return_special_tokens_mask=True, return_tensors="pt")
70
+
71
+ embedding = model(
72
+ input_ids=inputs["input_ids"],
73
+ attention_mask=inputs["attention_mask"],
74
+ ).last_hidden_state # [1, sequence_length, 1280]
75
+
76
+ embedding_mean = torch.mean(embedding, dim=1)
77
+ print(torch.mean(embedding_mean)) # Outputs tensor(0.0005, grad_fn=<MeanBackward0>)
78
+
79
+ embedding_max = torch.max(embedding, dim=1)[0]
80
+ print(torch.mean(embedding_max)) # Outputs tensor(1.5583, grad_fn=<MeanBackward0>)
81
+ ```
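Both snippets mean-pool over the full sequence dimension, which is fine for a single unpadded sequence. For batched inputs with padding, the pooling would typically exclude padded positions using the attention mask. A minimal sketch in plain PyTorch (not taken from the original repository):

```python
import torch

def masked_mean(embedding: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool token embeddings while ignoring padded positions.

    embedding: [batch, seq_len, hidden]; attention_mask: [batch, seq_len] of 0/1.
    """
    mask = attention_mask.unsqueeze(-1).to(embedding.dtype)  # [batch, seq_len, 1]
    summed = (embedding * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)  # avoid division by zero for empty rows
    return summed / counts

# Tiny example: the second sequence has one padded position.
emb = torch.tensor([[[1.0, 2.0], [3.0, 4.0]],
                    [[2.0, 2.0], [0.0, 0.0]]])
mask = torch.tensor([[1, 1], [1, 0]])
print(masked_mean(emb, mask))  # tensor([[2., 3.], [2., 2.]])
```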
82
+
83
+ # License Notice
84
+ This repository contains modified versions of GenBio AI code.
85
+ Modifications include:
86
+ - Removal of the dependency on the `modelgenerator` package
87
+ - Support for loading specific AIDO.RNA models via the `base_model` argument
88
+
89
+ Not all of the original functionality may be preserved. These changes were made to better integrate with the mRNABench framework, which focuses on embedding generation for mRNA sequences. Most of the required code was copied directly from the original GenBio AI repository with minimal changes, so please refer to the original repository for full details on the implementation.
90
+
91
+ When using this repository, please adhere to the original license terms of the GenBio AI code. That license is included in this repository as `LICENSE`.
config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "model_type": "rnabert",
3
+ "auto_map": {
4
+ "AutoConfig": "configuration_rnabert.RNABertConfig",
5
+ "AutoModel": "modeling_rnabert.RNABertModel"
6
+ }
7
+ }
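The `auto_map` entries tell `transformers` (when loading with `trust_remote_code=True`) which classes in this repository back `AutoConfig` and `AutoModel`. Each entry is a `module.ClassName` reference into the repository's Python files. A simplified, purely illustrative resolver (not how `transformers` actually implements dynamic loading):

```python
# Illustrative only: split an auto_map entry into its module and class parts,
# mirroring the "module.ClassName" convention used by trust_remote_code loading.
auto_map = {
    "AutoConfig": "configuration_rnabert.RNABertConfig",
    "AutoModel": "modeling_rnabert.RNABertModel",
}

def split_ref(entry: str) -> tuple:
    """Split a 'module.ClassName' auto_map entry into (module, class name)."""
    module_name, class_name = auto_map[entry].rsplit(".", 1)
    return module_name, class_name

print(split_ref("AutoModel"))  # ('modeling_rnabert', 'RNABertModel')
```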
configuration_rnabert.py ADDED
@@ -0,0 +1,8 @@
1
+ from transformers import PretrainedConfig
2
+
3
+ class RNABertConfig(PretrainedConfig):
4
+ model_type = "rnabert"
5
+
6
+ def __init__(self, base_model=None, **kwargs):
7
+ self.base_model = base_model
8
+ super().__init__(**kwargs)
modeling_rnabert.py ADDED
@@ -0,0 +1,1204 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ PyTorch MegatronBERT model."""
17
+
18
+
19
+ import math
20
+ from typing import Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+
27
+ from transformers.activations import ACT2FN
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ MaskedLMOutput,
32
+ )
33
+ from transformers import PreTrainedModel
34
+ from transformers.pytorch_utils import (
35
+ apply_chunking_to_forward,
36
+ find_pruneable_heads_and_indices,
37
+ prune_linear_layer,
38
+ )
39
+ from .configuration_rnabert import RNABertConfig
40
+
41
+ VARIANTS = {
42
+ "aido_rna_1m_mars": "genbio-ai/AIDO.RNA-1M-MARS",
43
+ "aido_rna_25m_mars": "genbio-ai/AIDO.RNA-25M-MARS",
44
+ "aido_rna_300m_mars": "genbio-ai/AIDO.RNA-300M-MARS",
45
+ "aido_rna_650m": "genbio-ai/AIDO.RNA-650M",
46
+ "aido_rna_650m_cds": "genbio-ai/AIDO.RNA-650M-CDS",
47
+ "aido_rna_1b600m": "genbio-ai/AIDO.RNA-1.6B",
48
+ "aido_rna_1b600m_cds": "genbio-ai/AIDO.RNA-1.6B-CDS",
49
+ }
50
+
51
+ class RNABertEmbeddings(nn.Module):
52
+ """Construct the embeddings from word, position and token_type embeddings."""
53
+
54
+ def __init__(self, config):
55
+ super().__init__()
56
+ self.word_embeddings = nn.Embedding(
57
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
58
+ )
59
+ if config.position_embedding_type != "rope":
60
+ self.position_embeddings = nn.Embedding(
61
+ config.max_position_embeddings, config.hidden_size
62
+ )
63
+ # self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
64
+
65
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
66
+ # any TensorFlow checkpoint file
67
+
68
+ # In Megatron, layer-norm is applied after the 1st dropout.
69
+ # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
70
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
71
+
72
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
73
+ self.register_buffer(
74
+ "position_ids",
75
+ torch.arange(config.max_position_embeddings).expand((1, -1)),
76
+ persistent=False,
77
+ )
78
+ self.position_embedding_type = getattr(
79
+ config, "position_embedding_type", "rope"
80
+ )
81
+
82
+ def forward(
83
+ self,
84
+ input_ids: Optional[torch.LongTensor] = None,
85
+ token_type_ids: Optional[torch.LongTensor] = None,
86
+ position_ids: Optional[torch.LongTensor] = None,
87
+ inputs_embeds: Optional[torch.LongTensor] = None,
88
+ past_key_values_length: int = 0,
89
+ ) -> torch.Tensor:
90
+ if input_ids is not None:
91
+ input_shape = input_ids.size()
92
+ else:
93
+ input_shape = inputs_embeds.size()[:-1]
94
+
95
+ seq_length = input_shape[1]
96
+
97
+ if position_ids is None:
98
+ position_ids = self.position_ids[
99
+ :, past_key_values_length : seq_length + past_key_values_length
100
+ ]
101
+
102
+ if token_type_ids is None:
103
+ token_type_ids = torch.zeros(
104
+ input_shape, dtype=torch.long, device=self.position_ids.device
105
+ )
106
+
107
+ if inputs_embeds is None:
108
+ inputs_embeds = self.word_embeddings(input_ids)
109
+ # token_type_embeddings = self.token_type_embeddings(token_type_ids)
110
+
111
+ # embeddings = inputs_embeds + token_type_embeddings
112
+ embeddings = inputs_embeds
113
+ if self.position_embedding_type == "absolute":
114
+ position_embeddings = self.position_embeddings(position_ids)
115
+ embeddings += position_embeddings
116
+
117
+ # Megatron BERT moves that layer norm after the drop-out (and to each layer).
118
+ # embeddings = self.LayerNorm(embeddings)
119
+ embeddings = self.dropout(embeddings)
120
+ return embeddings
121
+
122
+
123
+ # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->RNABert
124
+ class RNABertSelfAttention(nn.Module):
125
+ def __init__(self, config, position_embedding_type=None):
126
+ super().__init__()
127
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
128
+ config, "embedding_size"
129
+ ):
130
+ raise ValueError(
131
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
132
+ f"heads ({config.num_attention_heads})"
133
+ )
134
+
135
+ self.num_attention_heads = config.num_attention_heads
136
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
137
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
138
+
139
+ self.query = nn.Linear(
140
+ config.hidden_size, self.all_head_size, bias=config.add_linear_bias
141
+ )
142
+ self.key = nn.Linear(
143
+ config.hidden_size, self.all_head_size, bias=config.add_linear_bias
144
+ )
145
+ self.value = nn.Linear(
146
+ config.hidden_size, self.all_head_size, bias=config.add_linear_bias
147
+ )
148
+
149
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
150
+ self.position_embedding_type = position_embedding_type or getattr(
151
+ config, "position_embedding_type", "absolute"
152
+ )
153
+ if (
154
+ self.position_embedding_type == "relative_key"
155
+ or self.position_embedding_type == "relative_key_query"
156
+ ):
157
+ self.max_position_embeddings = config.max_position_embeddings
158
+ self.distance_embedding = nn.Embedding(
159
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
160
+ )
161
+
162
+ self.is_decoder = config.is_decoder
163
+
164
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
165
+ new_x_shape = x.size()[:-1] + (
166
+ self.num_attention_heads,
167
+ self.attention_head_size,
168
+ )
169
+ x = x.view(new_x_shape)
170
+ return x.permute(0, 2, 1, 3)
171
+
172
+ def forward(
173
+ self,
174
+ hidden_states: torch.Tensor,
175
+ attention_mask: Optional[torch.FloatTensor] = None,
176
+ head_mask: Optional[torch.FloatTensor] = None,
177
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
178
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
179
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
180
+ output_attentions: Optional[bool] = False,
181
+ rotary_pos_emb=None,
182
+ ) -> Tuple[torch.Tensor]:
183
+ mixed_query_layer = self.query(hidden_states)
184
+
185
+ # If this is instantiated as a cross-attention module, the keys
186
+ # and values come from an encoder; the attention mask needs to be
187
+ # such that the encoder's padding tokens are not attended to.
188
+ is_cross_attention = encoder_hidden_states is not None
189
+
190
+ if is_cross_attention and past_key_value is not None:
191
+ # reuse k,v, cross_attentions
192
+ key_layer = past_key_value[0]
193
+ value_layer = past_key_value[1]
194
+ attention_mask = encoder_attention_mask
195
+ elif is_cross_attention:
196
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
197
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
198
+ attention_mask = encoder_attention_mask
199
+ elif past_key_value is not None:
200
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
201
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
202
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
203
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
204
+ else:
205
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
206
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
207
+
208
+ # [b, hn, sq, c]
209
+ query_layer = self.transpose_for_scores(mixed_query_layer)
210
+
211
+ if rotary_pos_emb is not None:
212
+ if isinstance(rotary_pos_emb, tuple):
213
+ rotary_pos_emb = rotary_pos_emb
214
+ else:
215
+ rotary_pos_emb = (rotary_pos_emb,) * 2
216
+
217
+ q_pos_emb, k_pos_emb = rotary_pos_emb
218
+
219
+ # [b, hn, sq, c] --> [sq, b, hn, c]
220
+ query_layer = query_layer.permute(2, 0, 1, 3).contiguous()
221
+ key_layer = key_layer.permute(2, 0, 1, 3).contiguous()
222
+
223
+ query_layer = apply_rotary_pos_emb(
224
+ query_layer, q_pos_emb
225
+ )
226
+ key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
227
+
228
+ # [sq, b, hn, c] --> [b, hn, sq, c]
229
+ query_layer = query_layer.permute(1, 2, 0, 3).contiguous()
230
+ key_layer = key_layer.permute(1, 2, 0, 3).contiguous()
231
+
232
+ use_cache = past_key_value is not None
233
+ if self.is_decoder:
234
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
235
+ # Further calls to cross_attention layer can then reuse all cross-attention
236
+ # key/value_states (first "if" case)
237
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
238
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
239
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
240
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
241
+ past_key_value = (key_layer, value_layer)
242
+
243
+ # Take the dot product between "query" and "key" to get the raw attention scores.
244
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
245
+
246
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
247
+ if attention_mask is not None:
248
+ # Apply the attention mask is (precomputed for all layers in RNABertModel forward() function)
249
+ attention_scores = attention_scores + attention_mask.to(
250
+ attention_scores.dtype
251
+ )
252
+
253
+ # Normalize the attention scores to probabilities.
254
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
255
+
256
+ if attention_mask is not None:
+ no_prob_mask = attention_mask < -1e-5
257
+ attention_probs = attention_probs.masked_fill(no_prob_mask, 0.0)
258
+ # This is actually dropping out entire tokens to attend to, which might
259
+ # seem a bit unusual, but is taken from the original Transformer paper.
260
+ attention_probs = self.dropout(attention_probs)
261
+
262
+ # Mask heads if we want to
263
+ if head_mask is not None:
264
+ attention_probs = attention_probs * head_mask
265
+
266
+ context_layer = torch.matmul(attention_probs, value_layer)
267
+
268
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
269
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
270
+ context_layer = context_layer.view(new_context_layer_shape)
271
+
272
+ outputs = (
273
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
274
+ )
275
+
276
+ if self.is_decoder:
277
+ outputs = outputs + (past_key_value,)
278
+ return outputs
279
+
280
+
281
+ # Based on transformers.models.bert.modeling_bert.BertSelfOutput. Moved LayerNorm to RNABertAttention below.
282
+ class RNABertSelfOutput(nn.Module):
283
+ def __init__(self, config):
284
+ super().__init__()
285
+ self.dense = nn.Linear(
286
+ config.hidden_size, config.hidden_size, bias=config.add_linear_bias
287
+ )
288
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
289
+
290
+ def forward(
291
+ self, hidden_states: torch.Tensor, residual: torch.Tensor
292
+ ) -> torch.Tensor:
293
+ hidden_states = self.dense(hidden_states)
294
+ hidden_states = self.dropout(hidden_states)
295
+ return residual + hidden_states
296
+
297
+
298
+ # Based on transformers.models.bert.modeling_bert.BertAttention. Added LayerNorm.
299
+ class RNABertAttention(nn.Module):
300
+ def __init__(self, config):
301
+ super().__init__()
302
+ self.ln = config.norm_cls(config.hidden_size, eps=config.layer_norm_eps)
303
+
304
+ self.self = RNABertSelfAttention(config)
305
+ self.output = RNABertSelfOutput(config)
306
+ self.pruned_heads = set()
307
+
308
+ def prune_heads(self, heads):
309
+ if len(heads) == 0:
310
+ return
311
+ heads, index = find_pruneable_heads_and_indices(
312
+ heads,
313
+ self.self.num_attention_heads,
314
+ self.self.attention_head_size,
315
+ self.pruned_heads,
316
+ )
317
+
318
+ # Prune linear layers
319
+ self.self.query = prune_linear_layer(self.self.query, index)
320
+ self.self.key = prune_linear_layer(self.self.key, index)
321
+ self.self.value = prune_linear_layer(self.self.value, index)
322
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
323
+
324
+ # Update hyper params and store pruned heads
325
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
326
+ self.self.all_head_size = (
327
+ self.self.attention_head_size * self.self.num_attention_heads
328
+ )
329
+ self.pruned_heads = self.pruned_heads.union(heads)
330
+
331
+ def forward(
332
+ self,
333
+ hidden_states: torch.Tensor,
334
+ attention_mask: Optional[torch.FloatTensor] = None,
335
+ head_mask: Optional[torch.FloatTensor] = None,
336
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
337
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
338
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
339
+ output_attentions: Optional[bool] = False,
340
+ rotary_pos_emb=None,
341
+ ) -> Tuple[torch.Tensor]:
343
+ ln_outputs = self.ln(hidden_states)
344
+ self_outputs = self.self(
345
+ ln_outputs,
346
+ attention_mask,
347
+ head_mask,
348
+ encoder_hidden_states,
349
+ encoder_attention_mask,
350
+ past_key_value,
351
+ output_attentions,
352
+ rotary_pos_emb,
353
+ )
354
+ attention_output = self.output(self_outputs[0], hidden_states)
355
+ outputs = (attention_output,) + self_outputs[
356
+ 1:
357
+ ] # add attentions if we output them
358
+ return outputs
359
+
360
+
361
+ # Based on transformers.models.bert.modeling_bert.BertIntermediate with Bert->RNABert. Uses a SwiGLU MLP.
+ class RNABertMLP(nn.Module):
+     def __init__(self, config: RNABertConfig):
+         super().__init__()
+         assert config.hidden_act == "swiglu", "Only swiglu is supported."
+         self.up_proj = nn.Linear(
+             config.hidden_size, config.intermediate_size, bias=config.add_linear_bias
+         )
+         self.down_proj = nn.Linear(
+             config.intermediate_size, config.hidden_size, bias=config.add_linear_bias
+         )
+         self.gate_proj = nn.Linear(
+             config.hidden_size, config.intermediate_size, bias=config.add_linear_bias
+         )
+         # SwiGLU uses SiLU as part of its activation function
+         self.intermediate_act_fn = ACT2FN["silu"]
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         down_proj = self.down_proj(
+             self.intermediate_act_fn(self.gate_proj(hidden_states))
+             * self.up_proj(hidden_states)
+         )
+         return down_proj
+
+
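The MLP above implements SwiGLU: the gate projection passed through SiLU, multiplied elementwise with the up projection, then projected back down. A minimal standalone sketch of the same computation (the dimensions and weight names here are illustrative, not taken from the config):

```python
import torch
import torch.nn.functional as F

def swiglu_mlp(x, w_gate, w_up, w_down):
    # SwiGLU: silu(x @ W_gate) * (x @ W_up), then project back to hidden size
    return (F.silu(x @ w_gate) * (x @ w_up)) @ w_down

torch.manual_seed(0)
hidden, inter = 8, 16
x = torch.randn(2, 4, hidden)
w_gate = torch.randn(hidden, inter)
w_up = torch.randn(hidden, inter)
w_down = torch.randn(inter, hidden)
out = swiglu_mlp(x, w_gate, w_up, w_down)
print(out.shape)  # torch.Size([2, 4, 8])
```

Note the three projections mirror `gate_proj`, `up_proj`, and `down_proj` above; the bias terms are omitted for brevity.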
+ # Based on transformers.models.bert.modeling_bert.BertOutput. Moved LayerNorm to RNABertLayer below.
+ class RNABertOutput(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         # self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+     def forward(
+         self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
+     ) -> torch.Tensor:
+         # hidden_states = self.dense(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         return input_tensor + hidden_states
+
+
+ # Based on transformers.models.bert.modeling_bert.BertLayer. Added LayerNorm.
+ class RNABertLayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.chunk_size_feed_forward = config.chunk_size_feed_forward
+         self.seq_len_dim = 1
+         self.attention = RNABertAttention(config)
+         self.is_decoder = config.is_decoder
+         self.add_cross_attention = config.add_cross_attention
+         if self.add_cross_attention:
+             if not self.is_decoder:
+                 raise TypeError(
+                     f"{self} should be used as a decoder model if cross attention is added"
+                 )
+             self.crossattention = RNABertAttention(config)
+         self.ln = config.norm_cls(config.hidden_size, eps=config.layer_norm_eps)
+         self.mlp = RNABertMLP(config)
+         self.output = RNABertOutput(config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         output_attentions: Optional[bool] = False,
+         rotary_pos_emb=None,
+     ) -> Tuple[torch.Tensor]:
+         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+         self_attn_past_key_value = (
+             past_key_value[:2] if past_key_value is not None else None
+         )
+         self_attention_outputs = self.attention(
+             hidden_states,
+             attention_mask,
+             head_mask,
+             output_attentions=output_attentions,
+             past_key_value=self_attn_past_key_value,
+             rotary_pos_emb=rotary_pos_emb,
+         )
+         attention_output = self_attention_outputs[0]
+
+         # if decoder, the last output is a tuple of self-attn cache
+         if self.is_decoder:
+             outputs = self_attention_outputs[1:-1]
+             present_key_value = self_attention_outputs[-1]
+         else:
+             outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+         cross_attn_present_key_value = None
+         if self.is_decoder and encoder_hidden_states is not None:
+             if not hasattr(self, "crossattention"):
+                 raise AttributeError(
+                     f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+                     " by setting `config.add_cross_attention=True`"
+                 )
+
+             # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+             cross_attn_past_key_value = (
+                 past_key_value[-2:] if past_key_value is not None else None
+             )
+             cross_attention_outputs = self.crossattention(
+                 attention_output,
+                 attention_mask,
+                 head_mask,
+                 encoder_hidden_states,
+                 encoder_attention_mask,
+                 cross_attn_past_key_value,
+                 output_attentions,
+             )
+             attention_output = cross_attention_outputs[0]
+             outputs = (
+                 outputs + cross_attention_outputs[1:-1]
+             )  # add cross attentions if we output attention weights
+
+             # add cross-attn cache to positions 3,4 of present_key_value tuple
+             cross_attn_present_key_value = cross_attention_outputs[-1]
+             present_key_value = present_key_value + cross_attn_present_key_value
+
+         layer_output = apply_chunking_to_forward(
+             self.feed_forward_chunk,
+             self.chunk_size_feed_forward,
+             self.seq_len_dim,
+             attention_output,
+         )
+         outputs = (layer_output,) + outputs
+
+         # if decoder, return the attn key/values as the last output
+         if self.is_decoder:
+             outputs = outputs + (present_key_value,)
+
+         return outputs
+
+     def feed_forward_chunk(self, attention_output):
+         ln_output = self.ln(attention_output)
+         mlp_output = self.mlp(ln_output)
+         layer_output = self.output(mlp_output, attention_output)
+         return layer_output
+
+
+ class RnaRMSNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-6):
+         """Same as LlamaRMSNorm."""
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         input_dtype = hidden_states.dtype
+         hidden_states = hidden_states.to(torch.float32)
+         variance = hidden_states.pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+         return self.weight * hidden_states.to(input_dtype)
+
+
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+
+ ALL_LAYERNORM_LAYERS.append(RnaRMSNorm)
+
+
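RnaRMSNorm normalizes by the root-mean-square of the hidden dimension only: no mean subtraction and no bias, computed in float32 and cast back. A standalone sketch of the same computation (tensor shapes are illustrative):

```python
import torch

def rms_norm(x, weight, eps=1e-6):
    # compute in float32 for numerical stability, cast back to the input dtype
    dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(dtype)

torch.manual_seed(0)
x = torch.randn(2, 5, 8)
w = torch.ones(8)
y = rms_norm(x, w)
# after normalization, the per-position RMS over the hidden dimension is ~1
print(y.pow(2).mean(-1).sqrt())
```

With `weight` all ones this makes every position's hidden vector have unit RMS, which is the property the encoder relies on.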
+ class RNABertEncoder(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.layer = nn.ModuleList(
+             [RNABertLayer(config) for _ in range(config.num_hidden_layers)]
+         )
+
+         # The final layer norm. We removed the 1st LN, moved LN into each hidden layer, and this one
+         # is simply the final LN (Transformers' BERT has it attached to each hidden layer).
+         self.ln = config.norm_cls(config.hidden_size, eps=config.layer_norm_eps)
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = False,
+         output_hidden_states: Optional[bool] = False,
+         return_dict: Optional[bool] = True,
+         rotary_pos_emb: Optional[torch.FloatTensor] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 print(
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                 )
+                 use_cache = False
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attentions = () if output_attentions else None
+         all_cross_attentions = (
+             () if output_attentions and self.config.add_cross_attention else None
+         )
+
+         next_decoder_cache = () if use_cache else None
+         for i, layer_module in enumerate(self.layer):
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+
+             layer_head_mask = head_mask[i] if head_mask is not None else None
+             past_key_value = past_key_values[i] if past_key_values is not None else None
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     layer_module.__call__,
+                     hidden_states,
+                     attention_mask,
+                     layer_head_mask,
+                     encoder_hidden_states,
+                     encoder_attention_mask,
+                     past_key_value,
+                     output_attentions,
+                     rotary_pos_emb,
+                 )
+             else:
+                 layer_outputs = layer_module(
+                     hidden_states,
+                     attention_mask,
+                     layer_head_mask,
+                     encoder_hidden_states,
+                     encoder_attention_mask,
+                     past_key_value,
+                     output_attentions,
+                     rotary_pos_emb,
+                 )
+
+             # Because we moved the layer norm to the end of each hidden layer, the data here is
+             # non-normalized. If normalized states are really needed here, apply LN to match Transformers' BERT.
+             hidden_states = layer_outputs[0]
+             if use_cache:
+                 next_decoder_cache += (layer_outputs[-1],)
+             if output_attentions:
+                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                 if self.config.add_cross_attention:
+                     all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+         # Finalize the hidden states.
+         hidden_states = self.ln(hidden_states)
+
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [
+                     hidden_states,
+                     next_decoder_cache,
+                     all_hidden_states,
+                     all_self_attentions,
+                     all_cross_attentions,
+                 ]
+                 if v is not None
+             )
+         return BaseModelOutputWithPastAndCrossAttentions(
+             last_hidden_state=hidden_states,
+             past_key_values=next_decoder_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attentions,
+             cross_attentions=all_cross_attentions,
+         )
+
+
+ # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->RNABert
+ class RNABertPooler(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(
+             config.hidden_size, config.hidden_size, bias=config.add_linear_bias
+         )
+         self.activation = nn.Tanh()
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         # We "pool" the model by simply taking the hidden state corresponding
+         # to the first token.
+         first_token_tensor = hidden_states[:, 0]
+         pooled_output = self.dense(first_token_tensor)
+         pooled_output = self.activation(pooled_output)
+         return pooled_output
+
+
+ # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->RNABert
+ class RNABertPredictionHeadTransform(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         # in Megatron, this layer always has a bias
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+
+         self.transform_act_fn = ACT2FN["gelu"]
+
+         if config.normalization_type == "RMSNorm":
+             self.LayerNorm = RnaRMSNorm(config.hidden_size, eps=config.layer_norm_eps)
+         else:
+             self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.transform_act_fn(hidden_states)
+         hidden_states = self.LayerNorm(hidden_states)
+         return hidden_states
+
+
+ # Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->RNABert
+ class RNABertLMPredictionHead(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.transform = RNABertPredictionHeadTransform(config)
+
+         # The output weights are the same as the input embeddings, but there is
+         # an output-only bias for each token.
+         self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+         self.decoder.bias = self.bias
+
+     def forward(self, hidden_states):
+         hidden_states = self.transform(hidden_states)
+         hidden_states = self.decoder(hidden_states)
+         return hidden_states
+
+
+ # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->RNABert
+ class RNABertOnlyMLMHead(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.predictions = RNABertLMPredictionHead(config)
+
+     def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
+         prediction_scores = self.predictions(sequence_output)
+         return prediction_scores
+
+
+ class RNABertPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     config_class = RNABertConfig
+     # load_tf_weights = load_tf_weights_in_rnabert
+     base_model_prefix = "bert"
+     supports_gradient_checkpointing = True
+
+     def _init_weights(self, module):
+         """Initialize the weights."""
+         if isinstance(module, (nn.Linear, nn.Embedding)):
+             # Slightly different from the TF version, which uses truncated_normal for initialization
+             # cf https://github.com/pytorch/pytorch/pull/5617
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+         elif isinstance(module, RnaRMSNorm):
+             module.weight.data.fill_(1.0)  # RMSNorm has no bias
+         if isinstance(module, nn.Linear) and module.bias is not None:
+             module.bias.data.zero_()
+
+
+ class RNABertModel(RNABertPreTrainedModel):
+     """
+     The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+     cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+     all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+     Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+     To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+     to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
+     `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
+     """
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+         wrapper_config = kwargs.pop("config", None)
+         if wrapper_config is None:
+             raise ValueError("Config must be provided")
+
+         base_model = VARIANTS.get(wrapper_config.base_model, wrapper_config.base_model)
+
+         # load base model config
+         base_config = RNABertConfig.from_pretrained(base_model, **kwargs)
+
+         # keep routing info
+         base_config.base_model = wrapper_config.base_model
+
+         return super().from_pretrained(
+             base_model,
+             *model_args,
+             config=base_config,
+             **kwargs,
+         )
+
+     def __init__(self, config, add_pooling_layer=False):
+         super().__init__(config)
+         self.config = config
+         if config.normalization_type == "RMSNorm":
+             self.config.norm_cls = RnaRMSNorm
+         else:
+             assert config.normalization_type == "LayerNorm"
+             self.config.norm_cls = nn.LayerNorm
+         self.embeddings = RNABertEmbeddings(config)
+         self.encoder = RNABertEncoder(config)
+
+         self.pooler = RNABertPooler(config) if add_pooling_layer else None
+
+         # rotary position embeddings
+         if config.position_embedding_type == "rope":
+             rotary_dim = config.hidden_size // config.num_attention_heads
+
+             # partial rotary embeddings, which work better than full rotary
+             # Wang and Komatsuzaki et al.
+             # https://github.com/kingoflolz/mesh-transformer-jax/
+             self.rotary_pos_emb = RotaryEmbedding(rotary_dim, config.rotary_percent)
+
+         # delete this from config so the config can be successfully saved
+         del self.config.norm_cls
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embeddings.word_embeddings
+
+     def set_input_embeddings(self, value):
+         self.embeddings.word_embeddings = value
+
+     def _prune_heads(self, heads_to_prune):
+         """
+         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
+         class PreTrainedModel.
+         """
+         for layer, heads in heads_to_prune.items():
+             self.encoder.layer[layer].attention.prune_heads(heads)
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
+         r"""
+         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+             the model is configured as a decoder.
+         encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+             the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+         past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+             Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+             If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+             `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+         use_cache (`bool`, *optional*):
+             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+             `past_key_values`).
+         """
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         if self.config.is_decoder:
+             use_cache = use_cache if use_cache is not None else self.config.use_cache
+         else:
+             use_cache = False
+
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError(
+                 "You cannot specify both input_ids and inputs_embeds at the same time"
+             )
+         elif input_ids is not None:
+             self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+             input_shape = input_ids.size()
+         elif inputs_embeds is not None:
+             input_shape = inputs_embeds.size()[:-1]
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         batch_size, seq_length = input_shape
+         device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+         # past_key_values_length
+         past_key_values_length = (
+             past_key_values[0][0].shape[2] if past_key_values is not None else 0
+         )
+
+         if attention_mask is None:
+             attention_mask = torch.ones(
+                 (batch_size, seq_length + past_key_values_length), device=device
+             )
+         if token_type_ids is None:
+             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+         # ourselves, in which case we just need to make it broadcastable to all heads.
+         # extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+         extended_attention_mask = bert_extended_attention_mask(
+             attention_mask
+         )  # True for pad, False for non-pad
+         extended_attention_mask = extended_attention_mask * torch.finfo(torch.float).min
+
+         # If a 2D or 3D attention mask is provided for the cross-attention,
+         # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+         if self.config.is_decoder and encoder_hidden_states is not None:
+             encoder_batch_size, encoder_sequence_length, _ = (
+                 encoder_hidden_states.size()
+             )
+             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+             if encoder_attention_mask is None:
+                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+             encoder_extended_attention_mask = self.invert_attention_mask(
+                 encoder_attention_mask
+             )
+         else:
+             encoder_extended_attention_mask = None
+
+         # Prepare head mask if needed.
+         # 1.0 in head_mask indicates we keep the head.
+         # attention_probs has shape bsz x n_heads x N x N
+         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+         # Rotary positional embeddings
+         rotary_pos_emb = None
+         if self.config.position_embedding_type == "rope":
+             rotary_pos_emb = self.rotary_pos_emb(input_ids.size(1))
+
+         embedding_output = self.embeddings(
+             input_ids=input_ids,
+             position_ids=position_ids,
+             token_type_ids=token_type_ids,
+             inputs_embeds=inputs_embeds,
+             past_key_values_length=past_key_values_length,
+         )
+         encoder_outputs = self.encoder(
+             embedding_output,
+             attention_mask=extended_attention_mask,
+             head_mask=head_mask,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_extended_attention_mask,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             rotary_pos_emb=rotary_pos_emb,
+         )
+         sequence_output = encoder_outputs[0]
+         pooled_output = (
+             self.pooler(sequence_output) if self.pooler is not None else None
+         )
+
+         if not return_dict:
+             return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+         return BaseModelOutputWithPoolingAndCrossAttentions(
+             last_hidden_state=sequence_output,
+             pooler_output=pooled_output,
+             past_key_values=encoder_outputs.past_key_values,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+             cross_attentions=encoder_outputs.cross_attentions,
+         )
+
+
+ class RNABertForMaskedLM(RNABertPreTrainedModel):
+     _tied_weights_keys = ["cls.predictions.decoder"]
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         if config.is_decoder:
+             print(
+                 "If you want to use `RNABertForMaskedLM` make sure `config.is_decoder=False` for "
+                 "bi-directional self-attention."
+             )
+
+         self.bert = RNABertModel(config, add_pooling_layer=False)
+         self.cls = RNABertOnlyMLMHead(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_output_embeddings(self):
+         return self.cls.predictions.decoder
+
+     def set_output_embeddings(self, new_embeddings):
+         self.cls.predictions.decoder = new_embeddings
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, MaskedLMOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+             config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the
+             loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+         """
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = outputs[0]
+         prediction_scores = self.cls(sequence_output)
+
+         masked_lm_loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()  # -100 index = padding token
+             masked_lm_loss = loss_fct(
+                 prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
+             )
+
+         if not return_dict:
+             output = (prediction_scores,) + outputs[2:]
+             return (
+                 ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+             )
+
+         return MaskedLMOutput(
+             loss=masked_lm_loss,
+             logits=prediction_scores,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self, input_ids, attention_mask=None, **model_kwargs
+     ):
+         input_shape = input_ids.shape
+         effective_batch_size = input_shape[0]
+
+         # add a dummy token
+         if self.config.pad_token_id is None:
+             raise ValueError("The PAD token should be defined for generation")
+         attention_mask = torch.cat(
+             [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],
+             dim=-1,
+         )
+         dummy_token = torch.full(
+             (effective_batch_size, 1),
+             self.config.pad_token_id,
+             dtype=torch.long,
+             device=input_ids.device,
+         )
+         input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+         return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
+ from torch import Tensor, nn
+
+
+ class RotaryEmbedding(nn.Module):
+     """Rotary embedding for the language model.
+
+     Args:
+         kv_channels (int): Projection weights dimension in multi-head attention. Obtained from the transformer config.
+         rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
+         seq_len_interpolation_factor (float, optional): Scale for linearly interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None.
+         rotary_base (int, optional): Base period for rotary position embeddings. Defaults to 10000.
+     """
+
+     def __init__(
+         self,
+         kv_channels: int,
+         rotary_percent: float,
+         seq_len_interpolation_factor: float = None,
+         rotary_base: int = 10000,
+     ) -> None:
+         super().__init__()
+
+         dim = kv_channels
+         if rotary_percent < 1.0:
+             dim = int(dim * rotary_percent)
+
+         self.seq_len_interpolation_factor = seq_len_interpolation_factor
+         device = (
+             torch.cuda.current_device()
+             if torch.cuda.is_available()
+             else torch.device("cpu")
+         )
+         self.inv_freq = 1.0 / (
+             rotary_base
+             ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
+         )
+
+     def forward(self, max_seq_len: int, offset: int = 0) -> Tensor:
+         """Forward pass of RoPE embedding.
+
+         Args:
+             max_seq_len (int): Maximum size of the sequence.
+             offset (int, optional): Offset added to the position indices. Defaults to 0.
+
+         Returns:
+             Tensor: Embeddings after applying RoPE.
+         """
+         seq = (
+             torch.arange(
+                 max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+             )
+             + offset
+         )
+
+         if self.seq_len_interpolation_factor is not None:
+             seq *= 1 / self.seq_len_interpolation_factor
+
+         freqs = torch.outer(seq, self.inv_freq)
+         # first part even vector components, second part odd vector components,
+         # 2 * dim in dimension size
+         emb = torch.cat((freqs, freqs), dim=-1)
+         # emb [seq_length, ..., dim]
+         emb = emb[:, None, None, :]
+
+         return emb
+
+     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+         state_dict.pop(f"{prefix}inv_freq", None)
+         return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+
+ def _rotate_half(x: Tensor) -> Tensor:
+     """Change sign so the last dimension becomes [-odd, +even].
+
+     Args:
+         x (Tensor): Input tensor
+
+     Returns:
+         Tensor: Tensor rotated half
+     """
+     x1, x2 = torch.chunk(x, 2, dim=-1)
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor:
+     """Apply rotary positional embedding to the input tensor t.
+
+     Check https://kexue.fm/archives/8265 for detailed formulas.
+
+     Args:
+         t (Tensor): Input tensor of shape [seq_length, ..., dim]
+         freqs (Tensor): Rotary positional embedding tensor of shape [seq_length, ..., dim]
+
+     Returns:
+         Tensor: The input tensor after applying RoPE
+     """
+     rot_dim = freqs.shape[-1]
+
+     # ideally t_pass is empty so rotary pos embedding is applied to the whole tensor t
+     t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
+
+     # first part is the cosine component;
+     # second part is the sine component, whose signs are changed via _rotate_half
+     cos_ = torch.cos(freqs).to(t.dtype).to(t.device)
+     sin_ = torch.sin(freqs).to(t.dtype).to(t.device)
+
+     t = (t * cos_) + (_rotate_half(t) * sin_)
+     return torch.cat((t, t_pass), dim=-1)
+
+
+ def bert_extended_attention_mask(attention_mask):
1191
+ # We create a 3D attention mask from a 2D tensor mask.
1192
+ # [b, 1, s]
1193
+ attention_mask_b1s = attention_mask.unsqueeze(1)
1194
+ # [b, s, 1]
1195
+ attention_mask_bs1 = attention_mask.unsqueeze(2)
1196
+ # [b, s, s]
1197
+ attention_mask_bss = attention_mask_b1s * attention_mask_bs1
1198
+ # [b, 1, s, s]
1199
+ extended_attention_mask = attention_mask_bss.unsqueeze(1)
1200
+
1201
+ # Convert attention mask to binary:
1202
+ extended_attention_mask = extended_attention_mask < 0.5
1203
+
1204
+ return extended_attention_mask
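
The rotation applied by `apply_rotary_pos_emb` can be checked with a minimal numpy sketch (not the torch code above; the helper names here are for illustration only): `_rotate_half` applied twice negates the input, and because RoPE rotates coordinate pairs by an angle, it preserves per-position vector norms.

```python
import numpy as np

def rotate_half(x):
    # split the last dimension in two and recombine with a sign flip:
    # [x1, x2] -> [-x2, x1]
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate((-x2, x1), axis=-1)

def apply_rope(t, freqs):
    # same formula as apply_rotary_pos_emb without the t_pass split:
    # t * cos(freqs) + rotate_half(t) * sin(freqs)
    return t * np.cos(freqs) + rotate_half(t) * np.sin(freqs)

# toy example: seq_length=4, dim=8
rng = np.random.default_rng(0)
t = rng.standard_normal((4, 8))
inv_freq = 1.0 / (10000 ** (np.arange(0, 4) / 4))
freqs = np.outer(np.arange(4), inv_freq)
freqs = np.concatenate((freqs, freqs), axis=-1)  # duplicated halves, as in forward()

out = apply_rope(t, freqs)

# rotating half twice negates the input
assert np.allclose(rotate_half(rotate_half(t)), -t)
# RoPE is a rotation, so per-position norms are preserved
assert np.allclose(np.linalg.norm(out, axis=-1), np.linalg.norm(t, axis=-1))
```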
tokenization_rnabert.py ADDED
@@ -0,0 +1,141 @@
+ import os
+ from typing import List, Optional
+
+ from transformers import PreTrainedTokenizer
+
+
+ class RNABertTokenizer(PreTrainedTokenizer):
+     """
+     Constructs an RNABert tokenizer.
+     """
+
+     vocab_files_names = {"vocab_file": "vocab.txt"}
+     model_input_names = ["input_ids", "attention_mask"]
+
+     def __init__(
+         self,
+         vocab_file,
+         unk_token="[UNK]",
+         cls_token="[CLS]",
+         pad_token="[PAD]",
+         mask_token="[MASK]",
+         sep_token="[SEP]",
+         bos_token="[BOS]",
+         eos_token="[EOS]",
+         version="v2",  # v2 by default because the DNA is processed to only have [CLS] and [SEP]
+         **kwargs,
+     ):
+         """
+         Args:
+             version: for v1, the input looks like [CLS] [BOS] ... [EOS] [SEP];
+                 for v2, the input looks like [CLS] ... [SEP].
+         """
+         with open(vocab_file, "r") as f:
+             lines = f.read().splitlines()
+         self.all_tokens = [l.strip() for l in lines]
+         self._id_to_token = dict(enumerate(self.all_tokens))
+         self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)}
+         super().__init__(
+             unk_token=unk_token,
+             cls_token=cls_token,
+             pad_token=pad_token,
+             mask_token=mask_token,
+             sep_token=sep_token,
+             bos_token=bos_token,
+             eos_token=eos_token,
+             **kwargs,
+         )
+
+         # TODO: all the tokens are added, yet they are also part of the vocab, which is a bit strange.
+         # None of them are special, but they all need special splitting.
+         self.unique_no_split_tokens = self.all_tokens
+         self._update_trie(self.unique_no_split_tokens)
+         self.version = version
+
+     def _convert_id_to_token(self, index: int) -> str:
+         return self._id_to_token.get(index, self.unk_token)
+
+     def _convert_token_to_id(self, token: str) -> int:
+         return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
+
+     def _tokenize(self, text, **kwargs):
+         return text.split()
+
+     def get_vocab(self):
+         base_vocab = self._token_to_id.copy()
+         base_vocab.update(self.added_tokens_encoder)
+         return base_vocab
+
+     def token_to_id(self, token: str) -> int:
+         return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
+
+     def id_to_token(self, index: int) -> str:
+         return self._id_to_token.get(index, self.unk_token)
+
+     def build_inputs_with_special_tokens(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         cls = [self.cls_token_id]
+         bos = [self.bos_token_id]
+         eos = [self.eos_token_id]
+         sep = [self.sep_token_id]
+
+         if token_ids_1 is None:
+             if self.version == "v1":
+                 return cls + bos + token_ids_0 + eos + sep
+             else:
+                 return cls + token_ids_0 + sep
+         else:
+             if self.version == "v1":
+                 return (
+                     cls + bos + token_ids_0 + eos + sep + bos + token_ids_1 + eos + sep
+                 )
+             else:
+                 return cls + token_ids_0 + sep + cls + token_ids_1 + sep
+
+     def get_special_tokens_mask(
+         self,
+         token_ids_0: List,
+         token_ids_1: Optional[List] = None,
+         already_has_special_tokens: bool = False,
+     ) -> List[int]:
+         """
+         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+         special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of ids of the first sequence.
+             token_ids_1 (`List[int]`, *optional*):
+                 List of ids of the second sequence.
+             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not the token list is already formatted with special tokens for the model.
+
+         Returns:
+             A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+         """
+         if already_has_special_tokens:
+             if token_ids_1 is not None:
+                 raise ValueError(
+                     "You should not supply a second sequence if the provided sequence of "
+                     "ids is already formatted with special tokens for the model."
+                 )
+             return [1 if token in self.all_special_ids else 0 for token in token_ids_0]
+         mask = [1] + ([0] * len(token_ids_0)) + [1]
+         if token_ids_1 is not None:
+             mask += [0] * len(token_ids_1) + [1]
+         return mask
+
+     def save_vocabulary(self, save_directory, filename_prefix=None):
+         vocab_file = os.path.join(
+             save_directory,
+             (filename_prefix + "-" if filename_prefix else "") + "vocab.txt",
+         )
+         with open(vocab_file, "w") as f:
+             f.write("\n".join(self.all_tokens))
+         return (vocab_file,)
+
+     @property
+     def vocab_size(self) -> int:
+         return len(self.all_tokens)
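
The v1/v2 assembly performed by `build_inputs_with_special_tokens` can be sketched without the `transformers` dependency. A plain-Python version follows; the hard-coded ids match the `vocab.txt` uploaded in this commit ([CLS]=2, [SEP]=3, [BOS]=11, [EOS]=12), and the helper name `build_inputs` is for illustration only.

```python
CLS, SEP, BOS, EOS = 2, 3, 11, 12  # ids taken from this repo's vocab.txt order

def build_inputs(token_ids_0, token_ids_1=None, version="v2"):
    # mirrors RNABertTokenizer.build_inputs_with_special_tokens
    if token_ids_1 is None:
        if version == "v1":
            return [CLS, BOS] + token_ids_0 + [EOS, SEP]
        return [CLS] + token_ids_0 + [SEP]
    if version == "v1":
        return [CLS, BOS] + token_ids_0 + [EOS, SEP, BOS] + token_ids_1 + [EOS, SEP]
    return [CLS] + token_ids_0 + [SEP, CLS] + token_ids_1 + [SEP]

# "A G C" tokenizes to ids 5, 6, 7 under the uploaded vocab
ids = [5, 6, 7]
print(build_inputs(ids))                # [2, 5, 6, 7, 3]
print(build_inputs(ids, version="v1"))  # [2, 11, 5, 6, 7, 12, 3]
```

Note that the v2 pair case re-inserts `[CLS]` before the second sequence, unlike the usual BERT `[CLS] A [SEP] B [SEP]` layout.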
tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+     "tokenizer_class": "RNABertTokenizer",
+     "auto_map": {
+         "AutoTokenizer": [
+             "tokenization_rnabert.RNABertTokenizer",
+             null
+         ]
+     }
+ }
vocab.txt ADDED
@@ -0,0 +1,16 @@
+ [PAD]
+ [MASK]
+ [CLS]
+ [SEP]
+ [UNK]
+ A
+ G
+ C
+ T
+ U
+ N
+ [BOS]
+ [EOS]
+ [UNUSED1]
+ [UNUSED2]
+ [UNUSED3]
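
Loaded the way the tokenizer's `__init__` does (enumerating over stripped lines), this vocab yields a fixed token-to-id mapping; a quick sketch with the file contents inlined:

```python
# vocab.txt contents inlined, in the order uploaded in this commit
vocab_txt = """[PAD]
[MASK]
[CLS]
[SEP]
[UNK]
A
G
C
T
U
N
[BOS]
[EOS]
[UNUSED1]
[UNUSED2]
[UNUSED3]"""

# same construction as RNABertTokenizer.__init__
all_tokens = [l.strip() for l in vocab_txt.splitlines()]
token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}

print(token_to_id["[PAD]"])   # 0
print(token_to_id["A"])       # 5
print(token_to_id["[EOS]"])   # 12
print(len(all_tokens))        # 16
```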