Taykhoom committed 9e231ea · verified · 1 Parent(s): 4f5f306

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,150 @@
+ ---
+ language:
+ - rna
+ library_name: transformers
+ tags:
+ - RNA
+ - language-model
+ license: apache-2.0
+ ---
+
+ # RNAErnie2
+
+ RNAErnie2 is a BERT-based RNA language model trained from scratch on a large-scale RNA
+ sequence dataset with up to 2048-nucleotide context length. It is a retrained successor
+ to RNAErnie that replaces the PaddlePaddle-based ERNIE backbone with a standard PyTorch
+ BERT architecture, extends the pretraining corpus to RNACentral v22 (~31M sequences,
+ length <= 2048), and switches to an RNA-native vocabulary (U instead of T).
+
+ ## Architecture
+
+ | Parameter | Value |
+ |---|---|
+ | Layers | 12 |
+ | Attention heads | 12 |
+ | Embedding dimension | 768 |
+ | Intermediate size | 3072 |
+ | Vocabulary size | 11 |
+ | Positional encoding | Absolute learned |
+ | Architecture | Post-LN BERT / BertForMaskedLM |
+ | Max sequence length | 2048 |
+
+ **Vocabulary:** `[PAD]=0, [UNK]=1, [CLS]=2, [EOS]=3, [SEP]=4, [MASK]=5, A=6, U=7, C=8, G=9, N=10`
+
+ ## Pretraining
+
+ - **Objective:** Masked language modelling (MLM)
+ - **Data:** RNACentral v22, ~31 million RNA sequences with length <= 2048
+ - **Source checkpoint:** [`LLM-EDA/RNAErnie`](https://huggingface.co/LLM-EDA/RNAErnie) on HuggingFace Hub
+ - **Tokenisation note:** Sequences use U (not T). Input T is silently converted to U by the tokenizer.
+
+ ### Checkpoint selection
+
+ There is a single publicly released RNAErnie2 checkpoint. The weights are taken from
+ [`LLM-EDA/RNAErnie`](https://huggingface.co/LLM-EDA/RNAErnie) with one minor
+ adjustment: `cls.predictions.decoder.bias` is stored explicitly (it was implicitly
+ tied to `cls.predictions.bias` in the original save and was absent from the file).
+
+ ## Parity Verification
+
+ Hidden-state representations and MLM logits match those of the original
+ `BertForMaskedLM` to within a maximum absolute difference of 2e-5 at all 13
+ representation levels (embedding output + 12 layers).
+ Verified on GPU with PyTorch 2.7 / CUDA 12.
+
+ ## Implementation Notes
+
+ Custom BERT implementation (`modeling_rnaernie2.py`) with eager, SDPA, and Flash
+ Attention 2 backends, following the architecture of
+ [`Taykhoom/BERT-updated`](https://huggingface.co/Taykhoom/BERT-updated).
+ The original [`LLM-EDA/RNAErnie`](https://huggingface.co/LLM-EDA/RNAErnie) used
+ standard HF BERT with no custom attention backends.
+
+ ## Related Models
+
+ See the full [RNAErnie collection](<COLLECTION_URL>).
+
+ | Model | Context | Training data | Notes |
+ |---|---|---|---|
+ | [RNAErnie](../RNAErnie) | 512 | RNACentral (nts<=512) | Original; PaddlePaddle backbone |
+ | **[RNAErnie2](./)** | **2048** | **RNACentral v22 (~31M seqs)** | **This model; PyTorch BERT** |
+
+ ## Usage
+
+ ### Embedding generation
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModel
+
+ tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNAErnie2", trust_remote_code=True)
+ model = AutoModel.from_pretrained("Taykhoom/RNAErnie2", trust_remote_code=True)
+ model.eval()
+
+ sequences = ["AUGCAUGCAUGC", "GCUGCAUGCUAGC"]
+ enc = tokenizer(sequences, return_tensors="pt", padding=True)
+
+ with torch.no_grad():
+     out = model(**enc)
+
+ cls_emb = out.last_hidden_state[:, 0, :]  # (batch, 768) -- CLS token
+ token_emb = out.last_hidden_state         # (batch, seq_len, 768)
+
+ # Intermediate layers (index 0 is the embedding output, index k is encoder layer k)
+ with torch.no_grad():
+     out_all = model(**enc, output_hidden_states=True)
+ layer6_emb = out_all.hidden_states[6]     # (batch, seq_len, 768)
+ ```
+
+ ### MLM logits
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+ tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNAErnie2", trust_remote_code=True)
+ model = AutoModelForMaskedLM.from_pretrained("Taykhoom/RNAErnie2", trust_remote_code=True)
+ model.eval()
+
+ enc = tokenizer(["AUG[MASK]AUG"], return_tensors="pt")
+ with torch.no_grad():
+     logits = model(**enc).logits  # (1, seq_len, 11)
+ ```
+
+ ### SDPA / Flash Attention 2
+
+ ```python
+ from transformers import AutoModel
+
+ model = AutoModel.from_pretrained(
+     "Taykhoom/RNAErnie2",
+     attn_implementation="sdpa",  # or "flash_attention_2"
+     trust_remote_code=True,
+ )
+ ```
+
+ ### Fine-tuning
+
+ Fine-tuning follows standard HF conventions. For sequence-level tasks, use the CLS token
+ embedding (`last_hidden_state[:, 0, :]`) as input to a classification head.
+
+ ## Citation
+
+ ```bibtex
+ @article{wang2024_rnaernie,
+   title = {Multi-purpose {RNA} language modelling with motif-aware pretraining and type-guided fine-tuning},
+   author = {Wang, Ning and Bian, Jiang and Li, Yuchen and Li, Xuhong and Mumtaz, Shahid and Kong, Linghe and Xiong, Haoyi},
+   journal = {Nature Machine Intelligence},
+   volume = {6},
+   pages = {548--557},
+   year = {2024},
+   doi = {10.1038/s42256-024-00836-4}
+ }
+ ```
+
+ ## Credits
+
+ Original model and code by Wang et al. Source: [GitHub](https://github.com/CatIIIIIIII/RNAErnie) /
+ [HuggingFace](https://huggingface.co/LLM-EDA/RNAErnie).
+ The HF conversion code was authored primarily by [Claude Code](https://claude.ai/code)
+ and reviewed manually by Taykhoom Dalal.
+
+ ## License
+
+ Apache 2.0, following the original repository.
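Reviewer note: the parity criterion stated in the README (maximum absolute difference below 2e-5 at every one of the 13 representation levels) can be illustrated with a small library-free sketch. The helper names `max_abs_diff` and `check_parity` and the toy nested-list "hidden states" are assumptions made for illustration; the actual verification script is not part of this commit.

```python
TOLERANCE = 2e-5  # threshold quoted in the README


def max_abs_diff(a, b) -> float:
    """Largest element-wise absolute difference between two equally shaped nested lists."""
    if isinstance(a, (int, float)):
        return abs(a - b)
    if len(a) != len(b):
        raise ValueError("shape mismatch")
    return max((max_abs_diff(x, y) for x, y in zip(a, b)), default=0.0)


def check_parity(reference_levels, converted_levels, tol: float = TOLERANCE) -> bool:
    """True if every representation level differs by less than tol."""
    if len(reference_levels) != len(converted_levels):
        raise ValueError("different number of representation levels")
    return all(
        max_abs_diff(ref, conv) < tol
        for ref, conv in zip(reference_levels, converted_levels)
    )


# Toy data: 13 levels (embedding output + 12 layers), each a 2x2 "tensor".
reference = [[[0.10, -0.20], [0.30, 0.40]] for _ in range(13)]
# Hypothetical converted outputs that deviate by about 1e-6 everywhere.
converted = [[[v + 1e-6 for v in row] for row in level] for level in reference]

parity_ok = check_parity(reference, converted)
```

With real models the nested lists would be replaced by the tensors in `hidden_states` and `logits`, compared level by level in the same way.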
config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "architectures": [
+     "RNAErnie2ForMaskedLM"
+   ],
+   "model_type": "rnaernie2",
+   "auto_map": {
+     "AutoConfig": "configuration_rnaernie2.RNAErnie2Config",
+     "AutoModel": "modeling_rnaernie2.RNAErnie2Model",
+     "AutoModelForMaskedLM": "modeling_rnaernie2.RNAErnie2ForMaskedLM"
+   },
+   "vocab_size": 11,
+   "hidden_size": 768,
+   "num_hidden_layers": 12,
+   "num_attention_heads": 12,
+   "intermediate_size": 3072,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "attention_probs_dropout_prob": 0.1,
+   "max_position_embeddings": 2048,
+   "type_vocab_size": 2,
+   "layer_norm_eps": 1e-05,
+   "pad_token_id": 0,
+   "initializer_range": 0.02,
+   "transformers_version": "4.57.6"
+ }
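Reviewer note: as a sanity check on these config values, the parameter count of the modules defined in `modeling_rnaernie2.py` can be worked out by hand. The sketch below is an illustration, not part of the commit. The function name `count_parameters` is hypothetical, and it assumes the decoder weight is tied to the word embeddings and that the decoder bias is the single shared `cls.predictions.bias` vector, as the modeling code describes.

```python
def count_parameters(
    vocab_size: int = 11,
    hidden_size: int = 768,
    num_hidden_layers: int = 12,
    intermediate_size: int = 3072,
    max_position_embeddings: int = 2048,
    type_vocab_size: int = 2,
    with_pooler: bool = True,
) -> int:
    """Count trainable parameters of the BERT-style stack described by config.json."""
    h, i = hidden_size, intermediate_size
    layer_norm = 2 * h  # weight + bias

    # Word, position and token-type embedding tables plus the embedding LayerNorm.
    embeddings = (vocab_size + max_position_embeddings + type_vocab_size) * h + layer_norm

    # Query, key, value and attention-output projections, plus LayerNorm.
    attention = 4 * (h * h + h) + layer_norm
    # Intermediate and output dense layers, plus LayerNorm.
    feed_forward = (h * i + i) + (i * h + h) + layer_norm
    encoder = num_hidden_layers * (attention + feed_forward)

    pooler = (h * h + h) if with_pooler else 0

    # MLM head: transform dense + LayerNorm + one bias vector
    # (the decoder weight is assumed tied, so it is not counted again).
    mlm_head = (h * h + h) + layer_norm + vocab_size

    return embeddings + encoder + pooler + mlm_head


total = count_parameters()
approx_megabytes_fp32 = total * 4 / 1e6  # assuming 4-byte float32 storage
```

Under these assumptions the total is about 87.8M parameters, or roughly 351 MB in float32, which is the same order as the 348,947,640-byte `model.safetensors` file. The exact file size also depends on which tensors are actually stored and on the safetensors header, which this sketch does not model.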
configuration_rnaernie2.py ADDED
@@ -0,0 +1,40 @@
+ from transformers import PretrainedConfig
+
+
+ class RNAErnie2Config(PretrainedConfig):
+     model_type = "rnaernie2"
+
+     auto_map = {
+         "AutoConfig": "configuration_rnaernie2.RNAErnie2Config",
+         "AutoModel": "modeling_rnaernie2.RNAErnie2Model",
+         "AutoModelForMaskedLM": "modeling_rnaernie2.RNAErnie2ForMaskedLM",
+     }
+
+     def __init__(
+         self,
+         vocab_size: int = 11,
+         hidden_size: int = 768,
+         num_hidden_layers: int = 12,
+         num_attention_heads: int = 12,
+         intermediate_size: int = 3072,
+         hidden_act: str = "gelu",
+         hidden_dropout_prob: float = 0.1,
+         attention_probs_dropout_prob: float = 0.1,
+         max_position_embeddings: int = 2048,
+         type_vocab_size: int = 2,
+         layer_norm_eps: float = 1e-5,
+         pad_token_id: int = 0,
+         **kwargs,
+     ):
+         super().__init__(pad_token_id=pad_token_id, **kwargs)
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.max_position_embeddings = max_position_embeddings
+         self.type_vocab_size = type_vocab_size
+         self.layer_norm_eps = layer_norm_eps
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5642f83f12205dcd729145578b1ab0d78e3124335fe9d13450ab363295456b33
+ size 348947640
modeling_rnaernie2.py ADDED
@@ -0,0 +1,429 @@
+ import math
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutputWithPooling, MaskedLMOutput
+
+ try:
+     from .configuration_rnaernie2 import RNAErnie2Config
+ except ImportError:
+     from configuration_rnaernie2 import RNAErnie2Config
+
+
+ # ---------------------------------------------------------------------------
+ # Attention variants
+ # ---------------------------------------------------------------------------
+
+ class RNAErnie2SelfAttention(nn.Module):
+
+     def __init__(self, config: RNAErnie2Config):
+         super().__init__()
+         self.num_attention_heads = config.num_attention_heads
+         self.attention_head_size = config.hidden_size // config.num_attention_heads
+         self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+         self.query = nn.Linear(config.hidden_size, self.all_head_size)
+         self.key = nn.Linear(config.hidden_size, self.all_head_size)
+         self.value = nn.Linear(config.hidden_size, self.all_head_size)
+         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+     def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
+         B, T, _ = x.shape
+         return x.view(B, T, self.num_attention_heads, self.attention_head_size).permute(0, 2, 1, 3)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         key_padding_mask: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         q = self._split_heads(self.query(hidden_states))
+         k = self._split_heads(self.key(hidden_states))
+         v = self._split_heads(self.value(hidden_states))
+
+         scale = math.sqrt(self.attention_head_size)
+         scores = torch.matmul(q, k.transpose(-1, -2)) / scale
+         if key_padding_mask is not None:
+             scores = scores.masked_fill(key_padding_mask[:, None, None, :], float("-inf"))
+         probs = F.softmax(scores, dim=-1)
+         probs = self.dropout(probs)
+         context = torch.matmul(probs, v)
+
+         B, _, T, _ = context.shape
+         context = context.permute(0, 2, 1, 3).contiguous().view(B, T, self.all_head_size)
+
+         if output_attentions:
+             return context, probs
+         return context, None
+
+
+ class RNAErnie2SdpaSelfAttention(RNAErnie2SelfAttention):
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         key_padding_mask: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         # SDPA does not return attention weights; fall back to the eager path.
+         if output_attentions:
+             return super().forward(hidden_states, key_padding_mask, output_attentions=True)
+
+         B, T, _ = hidden_states.shape
+         q = self._split_heads(self.query(hidden_states))
+         k = self._split_heads(self.key(hidden_states))
+         v = self._split_heads(self.value(hidden_states))
+
+         attn_mask = None
+         if key_padding_mask is not None:
+             attn_mask = torch.zeros(B, 1, 1, T, dtype=q.dtype, device=q.device)
+             attn_mask = attn_mask.masked_fill(key_padding_mask[:, None, None, :], float("-inf"))
+
+         # Apply attention dropout only in training, matching the eager path.
+         dropout_p = self.dropout.p if self.training else 0.0
+         context = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p)
+         context = context.permute(0, 2, 1, 3).contiguous().view(B, T, self.all_head_size)
+         return context, None
+
+
+ class RNAErnie2FlashSelfAttention(RNAErnie2SelfAttention):
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         key_padding_mask: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         # Flash Attention does not return attention weights; fall back to the eager path.
+         if output_attentions:
+             return super().forward(hidden_states, key_padding_mask, output_attentions=True)
+
+         try:
+             from flash_attn import flash_attn_func, flash_attn_varlen_func
+             from flash_attn.bert_padding import pad_input, unpad_input
+         except ImportError as e:
+             raise ImportError(
+                 "flash_attn is required for attn_implementation='flash_attention_2'. "
+                 "Install with: pip install flash-attn --no-build-isolation"
+             ) from e
+
+         B, T, _ = hidden_states.shape
+         q = self._split_heads(self.query(hidden_states))
+         k = self._split_heads(self.key(hidden_states))
+         v = self._split_heads(self.value(hidden_states))
+
+         # flash_attn expects (batch, seq_len, heads, head_dim).
+         q = q.permute(0, 2, 1, 3)
+         k = k.permute(0, 2, 1, 3)
+         v = v.permute(0, 2, 1, 3)
+
+         # flash_attn kernels only support fp16/bf16 inputs.
+         orig_dtype = q.dtype
+         if orig_dtype not in (torch.float16, torch.bfloat16):
+             q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16)
+
+         # Apply attention dropout only in training, matching the eager path.
+         dropout_p = self.dropout.p if self.training else 0.0
+
+         if key_padding_mask is not None and key_padding_mask.any():
+             attend = ~key_padding_mask
+             # unpad_input returns 4 or 5 values depending on the flash_attn version.
+             q_u, indices, cu_seqlens, max_seqlen, *_ = unpad_input(q, attend)
+             k_u = unpad_input(k, attend)[0]
+             v_u = unpad_input(v, attend)[0]
+             out_u = flash_attn_varlen_func(
+                 q_u, k_u, v_u,
+                 cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
+                 max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
+                 dropout_p=dropout_p,
+                 causal=False,
+             )
+             out = pad_input(out_u, indices, B, T)
+         else:
+             out = flash_attn_func(q, k, v, dropout_p=dropout_p, causal=False)
+
+         out = out.to(orig_dtype).reshape(B, T, self.all_head_size)
+         return out, None
+
+
+ RNAERNIE2_SELF_ATTENTION_CLASSES = {
+     "eager": RNAErnie2SelfAttention,
+     "sdpa": RNAErnie2SdpaSelfAttention,
+     "flash_attention_2": RNAErnie2FlashSelfAttention,
+ }
+
+
+ # ---------------------------------------------------------------------------
+ # Layer components -- attribute names match BertForMaskedLM weight keys exactly
+ # ---------------------------------------------------------------------------
+
+ class RNAErnie2SelfOutput(nn.Module):
+     def __init__(self, config: RNAErnie2Config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+     def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dropout(self.dense(hidden_states))
+         return self.LayerNorm(hidden_states + input_tensor)
+
+
+ class RNAErnie2Attention(nn.Module):
+     def __init__(self, config: RNAErnie2Config):
+         super().__init__()
+         attn_cls = RNAERNIE2_SELF_ATTENTION_CLASSES[getattr(config, "_attn_implementation", "eager")]
+         self.self = attn_cls(config)
+         self.output = RNAErnie2SelfOutput(config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         key_padding_mask: Optional[torch.Tensor],
+         output_attentions: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         self_out, attn_weights = self.self(hidden_states, key_padding_mask, output_attentions)
+         return self.output(self_out, hidden_states), attn_weights
+
+
+ class RNAErnie2Intermediate(nn.Module):
+     def __init__(self, config: RNAErnie2Config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.act = nn.GELU() if config.hidden_act == "gelu" else nn.ReLU()
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         return self.act(self.dense(hidden_states))
+
+
+ class RNAErnie2Output(nn.Module):
+     def __init__(self, config: RNAErnie2Config):
+         super().__init__()
+         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+     def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dropout(self.dense(hidden_states))
+         return self.LayerNorm(hidden_states + input_tensor)
+
+
+ class RNAErnie2Layer(nn.Module):
+     def __init__(self, config: RNAErnie2Config):
+         super().__init__()
+         self.attention = RNAErnie2Attention(config)
+         self.intermediate = RNAErnie2Intermediate(config)
+         self.output = RNAErnie2Output(config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         key_padding_mask: Optional[torch.Tensor],
+         output_attentions: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         attn_out, attn_weights = self.attention(hidden_states, key_padding_mask, output_attentions)
+         return self.output(self.intermediate(attn_out), attn_out), attn_weights
+
+
+ class RNAErnie2Encoder(nn.Module):
+     def __init__(self, config: RNAErnie2Config):
+         super().__init__()
+         self.layer = nn.ModuleList([RNAErnie2Layer(config) for _ in range(config.num_hidden_layers)])
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         key_padding_mask: Optional[torch.Tensor],
+         output_hidden_states: bool = False,
+         output_attentions: bool = False,
+     ) -> Tuple:
+         all_hidden_states = (hidden_states,) if output_hidden_states else None
+         all_attentions = () if output_attentions else None
+
+         for layer in self.layer:
+             hidden_states, attn_weights = layer(hidden_states, key_padding_mask, output_attentions)
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+             if output_attentions:
+                 all_attentions = all_attentions + (attn_weights,)
+
+         return hidden_states, all_hidden_states, all_attentions
+
+
+ # ---------------------------------------------------------------------------
+ # Embeddings and pooler
+ # ---------------------------------------------------------------------------
+
+ class RNAErnie2Embeddings(nn.Module):
+     def __init__(self, config: RNAErnie2Config):
+         super().__init__()
+         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+         self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False)
+
+     def forward(self, input_ids: torch.LongTensor, token_type_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
+         B, T = input_ids.shape
+         if token_type_ids is None:
+             token_type_ids = torch.zeros_like(input_ids)
+         x = self.word_embeddings(input_ids)
+         x = x + self.position_embeddings(self.position_ids[:, :T])
+         x = x + self.token_type_embeddings(token_type_ids)
+         return self.dropout(self.LayerNorm(x))
+
+
+ class RNAErnie2Pooler(nn.Module):
+     def __init__(self, config: RNAErnie2Config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.activation = nn.Tanh()
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         return self.activation(self.dense(hidden_states[:, 0]))
+
+
+ # ---------------------------------------------------------------------------
+ # MLM prediction head -- key names match original BertForMaskedLM exactly:
+ #   cls.predictions.bias
+ #   cls.predictions.transform.dense.{weight,bias}
+ #   cls.predictions.transform.LayerNorm.{weight,bias}
+ #   cls.predictions.decoder.weight (tied to word_embeddings)
+ # ---------------------------------------------------------------------------
+
+ class RNAErnie2PredictionHeadTransform(nn.Module):
+     def __init__(self, config: RNAErnie2Config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.act = nn.GELU() if config.hidden_act == "gelu" else nn.ReLU()
+         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         return self.LayerNorm(self.act(self.dense(hidden_states)))
+
+
+ class RNAErnie2LMPredictionHead(nn.Module):
+     def __init__(self, config: RNAErnie2Config):
+         super().__init__()
+         self.transform = RNAErnie2PredictionHeadTransform(config)
+         self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+         self.decoder.bias = self.bias
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         return self.decoder(self.transform(hidden_states))
+
+
+ class RNAErnie2OnlyMLMHead(nn.Module):
+     def __init__(self, config: RNAErnie2Config):
+         super().__init__()
+         self.predictions = RNAErnie2LMPredictionHead(config)
+
+     def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
+         return self.predictions(sequence_output)
+
+
+ # ---------------------------------------------------------------------------
+ # Top-level models
+ # ---------------------------------------------------------------------------
+
+ class RNAErnie2Model(PreTrainedModel):
+     config_class = RNAErnie2Config
+     # Assumes the checkpoint keys carry the BertForMaskedLM "bert." prefix, so that
+     # AutoModel can strip it when loading the bare encoder.
+     base_model_prefix = "bert"
+     _supports_sdpa = True
+     _supports_flash_attn_2 = True
+
+     def __init__(self, config: RNAErnie2Config):
+         super().__init__(config)
+         self.embeddings = RNAErnie2Embeddings(config)
+         self.encoder = RNAErnie2Encoder(config)
+         self.pooler = RNAErnie2Pooler(config)
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embeddings.word_embeddings
+
+     def set_input_embeddings(self, value):
+         self.embeddings.word_embeddings = value
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # Convert the HF attention mask (1 = attend) to a key padding mask (True = ignore).
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids)
+         key_padding_mask = attention_mask.eq(0)
+         if not key_padding_mask.any():
+             key_padding_mask = None
+
+         x = self.embeddings(input_ids, token_type_ids)
+         last_hidden_state, all_hidden_states, all_attentions = self.encoder(
+             x, key_padding_mask,
+             output_hidden_states=output_hidden_states,
+             output_attentions=output_attentions,
+         )
+         pooled = self.pooler(last_hidden_state)
+
+         if not return_dict:
+             return tuple(v for v in [last_hidden_state, pooled, all_hidden_states, all_attentions] if v is not None)
+
+         return BaseModelOutputWithPooling(
+             last_hidden_state=last_hidden_state,
+             pooler_output=pooled,
+             hidden_states=all_hidden_states,
+             attentions=all_attentions,
+         )
+
+
+ class RNAErnie2ForMaskedLM(PreTrainedModel):
+     config_class = RNAErnie2Config
+     # Name of the attribute that holds the base model (self.bert).
+     base_model_prefix = "bert"
+     _supports_sdpa = True
+     _supports_flash_attn_2 = True
+
+     def __init__(self, config: RNAErnie2Config):
+         super().__init__(config)
+         self.bert = RNAErnie2Model(config)
+         self.cls = RNAErnie2OnlyMLMHead(config)
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.bert.embeddings.word_embeddings
+
+     def get_output_embeddings(self):
+         return self.cls.predictions.decoder
+
+     def set_output_embeddings(self, new_embeddings):
+         self.cls.predictions.decoder = new_embeddings
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, MaskedLMOutput]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(
+             input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
+             output_hidden_states=output_hidden_states, output_attentions=output_attentions,
+             return_dict=True,
+         )
+         logits = self.cls(outputs.last_hidden_state)
+
+         loss = None
+         if labels is not None:
+             # Positions labelled -100 (unmasked tokens) are excluded from the loss.
+             loss = F.cross_entropy(logits.view(-1, self.config.vocab_size), labels.view(-1), ignore_index=-100)
+
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return (loss,) + output if loss is not None else output
+
+         return MaskedLMOutput(
+             loss=loss, logits=logits,
+             hidden_states=outputs.hidden_states, attentions=outputs.attentions,
+         )
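Reviewer note: the eager attention path turns `attention_mask` into a key padding mask, fills padded key positions with `-inf`, and applies a softmax over scores scaled by `sqrt(head_size)`. The following pure-Python sketch reproduces that logic for a single query position so the effect is easy to inspect. The function name `masked_attention_weights` and the toy scores are assumptions for illustration, and the sketch assumes at least one non-padded key.

```python
import math


def masked_attention_weights(scores, attention_mask):
    """Softmax over one query's key scores, ignoring padded keys.

    scores: raw (already scaled) attention scores, one per key position.
    attention_mask: 1 for real tokens, 0 for padding (Hugging Face convention).
    """
    # Mirrors `attention_mask.eq(0)` in RNAErnie2Model.forward.
    key_padding_mask = [m == 0 for m in attention_mask]
    # Mirrors `scores.masked_fill(key_padding_mask, -inf)`.
    masked = [float("-inf") if pad else s for s, pad in zip(scores, key_padding_mask)]
    peak = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in masked]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


head_size = 768 // 12          # hidden_size // num_attention_heads
scale = math.sqrt(head_size)   # divisor applied to q @ k^T

raw_scores = [16.0, 8.0, 0.0, 24.0]  # toy dot products for four key positions
weights = masked_attention_weights(
    [s / scale for s in raw_scores],
    attention_mask=[1, 1, 1, 0],  # last key is padding
)
```

The padded key receives exactly zero weight even though its raw score is the largest, and the remaining weights still sum to one.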
special_tokens_map.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "bos_token": "[CLS]",
+   "cls_token": "[CLS]",
+   "eos_token": "[EOS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
tokenization_rnaernie2.py ADDED
@@ -0,0 +1,107 @@
+ import os
+ from transformers import PreTrainedTokenizer
+
+
+ _VOCAB = {
+     "[PAD]": 0,
+     "[UNK]": 1,
+     "[CLS]": 2,
+     "[EOS]": 3,
+     "[SEP]": 4,
+     "[MASK]": 5,
+     "A": 6,
+     "U": 7,
+     "C": 8,
+     "G": 9,
+     "N": 10,
+ }
+
+
+ class RNAErnie2Tokenizer(PreTrainedTokenizer):
+     """Character-level RNA tokenizer for RNAErnie2.
+
+     Vocab (11 tokens): [PAD]=0, [UNK]=1, [CLS]=2, [EOS]=3, [SEP]=4, [MASK]=5,
+     A=6, U=7, C=8, G=9, N=10.
+     Sequences are wrapped [CLS] + tokens + [SEP].
+     T is silently converted to U (RNA convention).
+     """
+
+     vocab_files_names = {"vocab_file": "vocab.txt"}
+     model_input_names = ["input_ids", "attention_mask"]
+
+     def __init__(
+         self,
+         vocab_file=None,
+         pad_token="[PAD]",
+         unk_token="[UNK]",
+         cls_token="[CLS]",
+         eos_token="[EOS]",
+         sep_token="[SEP]",
+         mask_token="[MASK]",
+         **kwargs,
+     ):
+         self._vocab = {}
+         if vocab_file and os.path.isfile(vocab_file):
+             with open(vocab_file, encoding="utf-8") as f:
+                 for idx, line in enumerate(f):
+                     token = line.rstrip("\n")
+                     self._vocab[token] = idx
+         else:
+             self._vocab = dict(_VOCAB)
+         self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
+
+         super().__init__(
+             pad_token=pad_token,
+             unk_token=unk_token,
+             cls_token=cls_token,
+             eos_token=eos_token,
+             sep_token=sep_token,
+             mask_token=mask_token,
+             **kwargs,
+         )
+
+     @property
+     def vocab_size(self):
+         return len(self._vocab)
+
+     def get_vocab(self):
+         return dict(self._vocab)
+
+     def _tokenize(self, text):
+         return list(text.upper().replace("T", "U"))
+
+     def _convert_token_to_id(self, token):
+         return self._vocab.get(token, self._vocab.get("[UNK]", 1))
+
+     def _convert_id_to_token(self, index):
+         return self._ids_to_tokens.get(index, "[UNK]")
+
+     def save_vocabulary(self, save_directory, filename_prefix=None):
+         os.makedirs(save_directory, exist_ok=True)
+         fname = (filename_prefix + "-" if filename_prefix else "") + "vocab.txt"
+         path = os.path.join(save_directory, fname)
+         with open(path, "w", encoding="utf-8") as f:
+             for token, _ in sorted(self._vocab.items(), key=lambda x: x[1]):
+                 f.write(token + "\n")
+         return (path,)
+
+     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+         cls = [self.cls_token_id]
+         sep = [self.sep_token_id]
+         if token_ids_1 is None:
+             return cls + token_ids_0 + sep
+         return cls + token_ids_0 + sep + token_ids_1 + sep
+
+     def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
+         if already_has_special_tokens:
+             return super().get_special_tokens_mask(token_ids_0, token_ids_1, True)
+         mask = [1] + [0] * len(token_ids_0) + [1]
+         if token_ids_1 is not None:
+             mask += [0] * len(token_ids_1) + [1]
+         return mask
+
+     def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
+         cls_sep = [0]
+         if token_ids_1 is None:
+             return cls_sep + [0] * len(token_ids_0) + cls_sep
+         return cls_sep + [0] * len(token_ids_0) + cls_sep + [0] * len(token_ids_1) + cls_sep
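Reviewer note: the tokenizer's core behaviour (upper-casing, T-to-U conversion, character-level lookup with an `[UNK]` fallback, `[CLS]`/`[SEP]` wrapping, right padding) can be checked without installing `transformers`. The sketch below is a simplified stand-in, not the class from this commit. The function names `encode` and `pad_batch` are hypothetical, and unlike the real tokenizer it does not split out special tokens such as `[MASK]` that appear inside the input string.

```python
from typing import Dict, List, Tuple

VOCAB: Dict[str, int] = {
    "[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[EOS]": 3, "[SEP]": 4, "[MASK]": 5,
    "A": 6, "U": 7, "C": 8, "G": 9, "N": 10,
}


def encode(sequence: str) -> List[int]:
    """Map a nucleotide string to ids, wrapped as [CLS] + tokens + [SEP]."""
    # Same normalisation as RNAErnie2Tokenizer._tokenize: upper-case, then T -> U.
    tokens = list(sequence.upper().replace("T", "U"))
    unk = VOCAB["[UNK]"]
    ids = [VOCAB.get(tok, unk) for tok in tokens]
    return [VOCAB["[CLS]"]] + ids + [VOCAB["[SEP]"]]


def pad_batch(batch: List[List[int]]) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad to the longest sequence and build the matching attention mask."""
    width = max(len(ids) for ids in batch)
    pad = VOCAB["[PAD]"]
    input_ids = [ids + [pad] * (width - len(ids)) for ids in batch]
    attention_mask = [[1] * len(ids) + [0] * (width - len(ids)) for ids in batch]
    return input_ids, attention_mask


dna_ids = encode("ATGC")       # T is converted to U before lookup
unknown_ids = encode("augx")   # lower-case input; X is not in the vocabulary
input_ids, attention_mask = pad_batch([encode("AUG"), encode("A")])
```

Here `dna_ids` and `encode("AUGC")` are identical, which is the "silent" T-to-U conversion the README mentions; callers who need to distinguish DNA from RNA input have to do so before tokenization.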
tokenizer_config.json ADDED
@@ -0,0 +1,17 @@
+ {
+   "auto_map": {
+     "AutoTokenizer": [
+       "tokenization_rnaernie2.RNAErnie2Tokenizer",
+       null
+     ]
+   },
+   "tokenizer_class": "RNAErnie2Tokenizer",
+   "model_max_length": 2048,
+   "pad_token": "[PAD]",
+   "unk_token": "[UNK]",
+   "cls_token": "[CLS]",
+   "eos_token": "[EOS]",
+   "sep_token": "[SEP]",
+   "mask_token": "[MASK]",
+   "padding_side": "right"
+ }
vocab.txt ADDED
@@ -0,0 +1,11 @@
+ [PAD]
+ [UNK]
+ [CLS]
+ [EOS]
+ [SEP]
+ [MASK]
+ A
+ U
+ C
+ G
+ N