Taykhoom commited on
Commit
7b1fe62
·
verified ·
1 Parent(s): c383c58

Upload folder using huggingface_hub

Browse files
README.md ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - rna
4
+ library_name: transformers
5
+ tags:
6
+ - RNA
7
+ - language-model
8
+ license: apache-2.0
9
+ ---
10
+
11
+ # RNAErnie
12
+
13
+ RNAErnie is a BERT-based RNA language model pretrained on RNACentral using a
14
+ motif-aware masking strategy with type-guided fine-tuning. It uses a DNA-style
15
+ vocabulary (T instead of U) and extends the token vocabulary with 28 ncRNA
16
+ type labels to enable type-guided learning.
17
+
18
+ ## Architecture
19
+
20
+ | Parameter | Value |
21
+ |---|---|
22
+ | Layers | 12 |
23
+ | Attention heads | 12 |
24
+ | Embedding dimension | 768 |
25
+ | Intermediate size | 3072 |
26
+ | Vocabulary size | 39 |
27
+ | Positional encoding | Absolute learned |
28
+ | Architecture | Post-LN BERT / ERNIE |
29
+ | Max sequence length | 512 |
30
+
31
+ **Vocabulary:** Special tokens `[PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3, [MASK]=4, [DEL]=5, [IND]=6`;
32
+ ncRNA type labels at indices 7-34 (RNaseMRPRNA, RNasePRNA, SRPRNA, YRNA, antisenseRNA,
33
+ autocatalyticallysplicedintron, guideRNA, hammerheadribozyme, lncRNA, miRNA, miscRNA,
34
+ ncRNA, other, piRNA, premiRNA, precursorRNA, rRNA, ribozyme, sRNA, scRNA, scaRNA,
35
+ siRNA, snRNA, snoRNA, tRNA, telomeraseRNA, tmRNA, vaultRNA);
36
+ nucleotides `A=35, T=36, C=37, G=38`.
37
+
38
+ **Tokenisation note:** Input U is silently converted to T. The model was pretrained
39
+ with DNA-style T notation.
40
+
41
+ ## Pretraining
42
+
43
+ - **Objective:** Masked language modelling (MLM) with motif-aware masking
44
+ - **Data:** RNACentral (sequences with length <= 512)
45
+ - **Source checkpoint:** `model_state.pdparams` from the original PaddlePaddle repository
46
+
47
+ ### Checkpoint selection
48
+
49
+ There is a single publicly released RNAErnie checkpoint
50
+ (`output/BERT,ERNIE,MOTIF,PROMPT/checkpoint_final/model_state.pdparams`),
51
+ corresponding to the `BERT,ERNIE,MOTIF,PROMPT` pretraining variant described in the
52
+ paper.
53
+
54
+ ## Parity Verification
55
+
56
+ Hidden-state representations verified identical (max abs diff < 7e-6) to a
57
+ standalone PyTorch reference implementation built directly from the raw
58
+ PaddlePaddle weights at all 13 representation levels (embedding + 12 layers).
59
+ Verified on GPU with PyTorch 2.7 / CUDA 12.
60
+
61
+ **Note on weight conversion:** PaddlePaddle stores `nn.Linear` weights as
62
+ `(in_features, out_features)`, the transpose of PyTorch's `(out_features, in_features)`.
63
+ All linear layer weights (attention projections, FFN, pooler, MLM transform) are
64
+ transposed during conversion; embedding tables and bias vectors are copied as-is.
65
+
66
+ ## Implementation Notes
67
+
68
+ The original implementation uses PaddlePaddle's ERNIE/TransformerEncoderLayer
69
+ backbone. This HF port re-implements the identical Post-LN BERT architecture in
70
+ pure PyTorch and adds `attn_implementation="sdpa"` and
71
+ `attn_implementation="flash_attention_2"` support, which were not part of the
72
+ original codebase.
73
+
74
+ ## Related Models
75
+
76
+ See the full [RNAErnie collection](https://huggingface.co/collections/Taykhoom/rnaernie-6a219927c11fdcccedb243db).
77
+
78
+ | Model | Context | Training data | Notes |
79
+ |---|---|---|---|
80
+ | **[RNAErnie](./)** | **512** | **RNACentral (nts<=512)** | **This model; PaddlePaddle ERNIE backbone** |
81
+ | [RNAErnie2](../RNAErnie2) | 2048 | RNACentral v22 (~31M seqs) | Retrained; PyTorch BERT |
82
+
83
+ ## Usage
84
+
85
+ ### Embedding generation
86
+
87
+ ```python
88
+ import torch
89
+ from transformers import AutoTokenizer, AutoModel
90
+
91
+ tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
92
+ model = AutoModel.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
93
+ model.eval()
94
+
95
+ sequences = ["AUGCAUGCAUGC", "GCUGCAUGCUAGC"]
96
+ enc = tokenizer(sequences, return_tensors="pt", padding=True)
97
+
98
+ with torch.no_grad():
99
+ out = model(**enc)
100
+
101
+ cls_emb = out.last_hidden_state[:, 0, :] # (batch, 768) -- CLS token
102
+ token_emb = out.last_hidden_state # (batch, seq_len, 768)
103
+
104
+ # Intermediate layers
105
+ out_all = model(**enc, output_hidden_states=True)
106
+ layer6_emb = out_all.hidden_states[6] # (batch, seq_len, 768)
107
+ ```
108
+
109
+ ### MLM logits
110
+
111
+ ```python
112
+ import torch
113
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
114
+
115
+ tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
116
+ model = AutoModelForMaskedLM.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
117
+ model.eval()
118
+
119
+ enc = tokenizer(["ATG[MASK]ATG"], return_tensors="pt")
120
+ with torch.no_grad():
121
+ logits = model(**enc).logits # (1, seq_len, 39)
122
+ ```
123
+
124
+ ### SDPA / Flash Attention 2
125
+
126
+ ```python
127
+ model = AutoModel.from_pretrained(
128
+ "Taykhoom/RNAErnie",
129
+ attn_implementation="sdpa", # or "flash_attention_2"
130
+ trust_remote_code=True,
131
+ )
132
+ ```
133
+
134
+ ### Fine-tuning
135
+
136
+ Standard HF conventions. For sequence-level tasks, use the CLS token embedding
137
+ (`last_hidden_state[:, 0, :]`) as input to a classification head. For type-guided
138
+ fine-tuning (as in the paper), prepend the ncRNA type label token to the input.
139
+
140
+ ## Citation
141
+
142
+ ```bibtex
143
+ @article{wang2024_rnaernie,
144
+ title = {Multi-purpose {RNA} language modelling with motif-aware pretraining and type-guided fine-tuning},
145
+ author = {Wang, Ning and Bian, Jiang and Li, Yuchen and Li, Xuhong and Mumtaz, Shahid and Kong, Linghe and Xiong, Haoyi},
146
+ journal = {Nature Machine Intelligence},
147
+ volume = {6},
148
+ pages = {548--557},
149
+ year = {2024},
150
+ doi = {10.1038/s42256-024-00836-4}
151
+ }
152
+ ```
153
+
154
+ ## Credits
155
+
156
+ Original model and code by Wang et al. Source: [GitHub](https://github.com/CatIIIIIIII/RNAErnie).
157
+ The HF conversion code was authored primarily by [Claude Code](https://claude.ai/code)
158
+ and reviewed manually by Taykhoom Dalal.
159
+
160
+ ## License
161
+
162
+ Apache 2.0, following the original repository.
config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RNAErnieForMaskedLM"
4
+ ],
5
+ "model_type": "rnaernie",
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_rnaernie.RNAErnieConfig",
8
+ "AutoModel": "modeling_rnaernie.RNAErnieModel",
9
+ "AutoModelForMaskedLM": "modeling_rnaernie.RNAErnieForMaskedLM"
10
+ },
11
+ "vocab_size": 39,
12
+ "hidden_size": 768,
13
+ "num_hidden_layers": 12,
14
+ "num_attention_heads": 12,
15
+ "intermediate_size": 3072,
16
+ "hidden_act": "relu",
17
+ "hidden_dropout_prob": 0.1,
18
+ "attention_probs_dropout_prob": 0.1,
19
+ "max_position_embeddings": 513,
20
+ "type_vocab_size": 2,
21
+ "layer_norm_eps": 1e-12,
22
+ "pad_token_id": 0,
23
+ "initializer_range": 0.02,
24
+ "transformers_version": "4.57.6"
25
+ }
configuration_rnaernie.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class RNAErnieConfig(PretrainedConfig):
5
+ model_type = "rnaernie"
6
+
7
+ auto_map = {
8
+ "AutoConfig": "configuration_rnaernie.RNAErnieConfig",
9
+ "AutoModel": "modeling_rnaernie.RNAErnieModel",
10
+ "AutoModelForMaskedLM": "modeling_rnaernie.RNAErnieForMaskedLM",
11
+ }
12
+
13
+ def __init__(
14
+ self,
15
+ vocab_size: int = 39,
16
+ hidden_size: int = 768,
17
+ num_hidden_layers: int = 12,
18
+ num_attention_heads: int = 12,
19
+ intermediate_size: int = 3072,
20
+ hidden_act: str = "relu",
21
+ hidden_dropout_prob: float = 0.1,
22
+ attention_probs_dropout_prob: float = 0.1,
23
+ max_position_embeddings: int = 513,
24
+ type_vocab_size: int = 2,
25
+ layer_norm_eps: float = 1e-12,
26
+ pad_token_id: int = 0,
27
+ **kwargs,
28
+ ):
29
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
30
+ self.vocab_size = vocab_size
31
+ self.hidden_size = hidden_size
32
+ self.num_hidden_layers = num_hidden_layers
33
+ self.num_attention_heads = num_attention_heads
34
+ self.intermediate_size = intermediate_size
35
+ self.hidden_act = hidden_act
36
+ self.hidden_dropout_prob = hidden_dropout_prob
37
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
38
+ self.max_position_embeddings = max_position_embeddings
39
+ self.type_vocab_size = type_vocab_size
40
+ self.layer_norm_eps = layer_norm_eps
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fae5c3d4324f5e9992efd83908ac47f66569217a5e1dc87fb296f78ff5cea48
3
+ size 346800816
modeling_rnaernie.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from transformers import PreTrainedModel
8
+ from transformers.modeling_outputs import BaseModelOutputWithPooling, MaskedLMOutput
9
+
10
+ try:
11
+ from .configuration_rnaernie import RNAErnieConfig
12
+ except ImportError:
13
+ from configuration_rnaernie import RNAErnieConfig
14
+
15
+
16
+ class RNAErnieSelfAttention(nn.Module):
17
+
18
+ def __init__(self, config: RNAErnieConfig):
19
+ super().__init__()
20
+ self.num_attention_heads = config.num_attention_heads
21
+ self.attention_head_size = config.hidden_size // config.num_attention_heads
22
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
23
+
24
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
25
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
26
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
27
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
28
+
29
+ def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
30
+ B, T, _ = x.shape
31
+ return x.view(B, T, self.num_attention_heads, self.attention_head_size).permute(0, 2, 1, 3)
32
+
33
+ def forward(
34
+ self,
35
+ hidden_states: torch.Tensor,
36
+ key_padding_mask: Optional[torch.Tensor] = None,
37
+ output_attentions: bool = False,
38
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
39
+ q = self._split_heads(self.query(hidden_states))
40
+ k = self._split_heads(self.key(hidden_states))
41
+ v = self._split_heads(self.value(hidden_states))
42
+
43
+ scale = math.sqrt(self.attention_head_size)
44
+ scores = torch.matmul(q, k.transpose(-1, -2)) / scale
45
+ if key_padding_mask is not None:
46
+ scores = scores.masked_fill(key_padding_mask[:, None, None, :], float("-inf"))
47
+ probs = F.softmax(scores, dim=-1)
48
+ probs = self.dropout(probs)
49
+ context = torch.matmul(probs, v)
50
+
51
+ B, _, T, _ = context.shape
52
+ context = context.permute(0, 2, 1, 3).contiguous().view(B, T, self.all_head_size)
53
+
54
+ if output_attentions:
55
+ return context, probs
56
+ return context, None
57
+
58
+
59
+ class RNAErnieSdpaSelfAttention(RNAErnieSelfAttention):
60
+
61
+ def forward(
62
+ self,
63
+ hidden_states: torch.Tensor,
64
+ key_padding_mask: Optional[torch.Tensor] = None,
65
+ output_attentions: bool = False,
66
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
67
+ if output_attentions:
68
+ return super().forward(hidden_states, key_padding_mask, output_attentions=True)
69
+
70
+ B, T, _ = hidden_states.shape
71
+ q = self._split_heads(self.query(hidden_states))
72
+ k = self._split_heads(self.key(hidden_states))
73
+ v = self._split_heads(self.value(hidden_states))
74
+
75
+ attn_mask = None
76
+ if key_padding_mask is not None:
77
+ attn_mask = torch.zeros(B, 1, 1, T, dtype=q.dtype, device=q.device)
78
+ attn_mask = attn_mask.masked_fill(key_padding_mask[:, None, None, :], float("-inf"))
79
+
80
+ context = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
81
+ context = context.permute(0, 2, 1, 3).contiguous().view(B, T, self.all_head_size)
82
+ return context, None
83
+
84
+
85
+ class RNAErnieFlashSelfAttention(RNAErnieSelfAttention):
86
+
87
+ def forward(
88
+ self,
89
+ hidden_states: torch.Tensor,
90
+ key_padding_mask: Optional[torch.Tensor] = None,
91
+ output_attentions: bool = False,
92
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
93
+ if output_attentions:
94
+ return super().forward(hidden_states, key_padding_mask, output_attentions=True)
95
+
96
+ try:
97
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
98
+ from flash_attn.bert_padding import pad_input, unpad_input
99
+ except ImportError as e:
100
+ raise ImportError(
101
+ "flash_attn is required for attn_implementation='flash_attention_2'. "
102
+ "Install with: pip install flash-attn --no-build-isolation"
103
+ ) from e
104
+
105
+ B, T, _ = hidden_states.shape
106
+ q = self._split_heads(self.query(hidden_states))
107
+ k = self._split_heads(self.key(hidden_states))
108
+ v = self._split_heads(self.value(hidden_states))
109
+
110
+ q = q.permute(0, 2, 1, 3)
111
+ k = k.permute(0, 2, 1, 3)
112
+ v = v.permute(0, 2, 1, 3)
113
+
114
+ orig_dtype = q.dtype
115
+ if orig_dtype not in (torch.float16, torch.bfloat16):
116
+ q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16)
117
+
118
+ if key_padding_mask is not None and key_padding_mask.any():
119
+ attend = ~key_padding_mask
120
+ q_u, indices, cu_seqlens, max_seqlen, _ = unpad_input(q, attend)
121
+ k_u, _, _, _, _ = unpad_input(k, attend)
122
+ v_u, _, _, _, _ = unpad_input(v, attend)
123
+ out_u = flash_attn_varlen_func(
124
+ q_u, k_u, v_u,
125
+ cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
126
+ max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
127
+ causal=False,
128
+ )
129
+ out = pad_input(out_u, indices, B, T)
130
+ else:
131
+ out = flash_attn_func(q, k, v, causal=False)
132
+
133
+ out = out.to(orig_dtype).reshape(B, T, self.all_head_size)
134
+ return out, None
135
+
136
+
137
+ RNAERNIE_SELF_ATTENTION_CLASSES = {
138
+ "eager": RNAErnieSelfAttention,
139
+ "sdpa": RNAErnieSdpaSelfAttention,
140
+ "flash_attention_2": RNAErnieFlashSelfAttention,
141
+ }
142
+
143
+
144
+ class RNAErnieSelfOutput(nn.Module):
145
+ def __init__(self, config: RNAErnieConfig):
146
+ super().__init__()
147
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
148
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
149
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
150
+
151
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
152
+ hidden_states = self.dropout(self.dense(hidden_states))
153
+ return self.LayerNorm(hidden_states + input_tensor)
154
+
155
+
156
+ class RNAErnieAttention(nn.Module):
157
+ def __init__(self, config: RNAErnieConfig):
158
+ super().__init__()
159
+ attn_cls = RNAERNIE_SELF_ATTENTION_CLASSES[getattr(config, "_attn_implementation", "eager")]
160
+ self.self = attn_cls(config)
161
+ self.output = RNAErnieSelfOutput(config)
162
+
163
+ def forward(
164
+ self,
165
+ hidden_states: torch.Tensor,
166
+ key_padding_mask: Optional[torch.Tensor],
167
+ output_attentions: bool = False,
168
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
169
+ self_out, attn_weights = self.self(hidden_states, key_padding_mask, output_attentions)
170
+ return self.output(self_out, hidden_states), attn_weights
171
+
172
+
173
+ class RNAErnieIntermediate(nn.Module):
174
+ def __init__(self, config: RNAErnieConfig):
175
+ super().__init__()
176
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
177
+ self.act = nn.ReLU() if config.hidden_act == "relu" else nn.GELU()
178
+
179
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
180
+ return self.act(self.dense(hidden_states))
181
+
182
+
183
+ class RNAErnieOutput(nn.Module):
184
+ def __init__(self, config: RNAErnieConfig):
185
+ super().__init__()
186
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
187
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
188
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
189
+
190
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
191
+ hidden_states = self.dropout(self.dense(hidden_states))
192
+ return self.LayerNorm(hidden_states + input_tensor)
193
+
194
+
195
+ class RNAErnieLayer(nn.Module):
196
+ def __init__(self, config: RNAErnieConfig):
197
+ super().__init__()
198
+ self.attention = RNAErnieAttention(config)
199
+ self.intermediate = RNAErnieIntermediate(config)
200
+ self.output = RNAErnieOutput(config)
201
+
202
+ def forward(
203
+ self,
204
+ hidden_states: torch.Tensor,
205
+ key_padding_mask: Optional[torch.Tensor],
206
+ output_attentions: bool = False,
207
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
208
+ attn_out, attn_weights = self.attention(hidden_states, key_padding_mask, output_attentions)
209
+ return self.output(self.intermediate(attn_out), attn_out), attn_weights
210
+
211
+
212
+ class RNAErnieEncoder(nn.Module):
213
+ def __init__(self, config: RNAErnieConfig):
214
+ super().__init__()
215
+ self.layer = nn.ModuleList([RNAErnieLayer(config) for _ in range(config.num_hidden_layers)])
216
+
217
+ def forward(
218
+ self,
219
+ hidden_states: torch.Tensor,
220
+ key_padding_mask: Optional[torch.Tensor],
221
+ output_hidden_states: bool = False,
222
+ output_attentions: bool = False,
223
+ ) -> Tuple:
224
+ all_hidden_states = (hidden_states,) if output_hidden_states else None
225
+ all_attentions = () if output_attentions else None
226
+
227
+ for layer in self.layer:
228
+ hidden_states, attn_weights = layer(hidden_states, key_padding_mask, output_attentions)
229
+ if output_hidden_states:
230
+ all_hidden_states = all_hidden_states + (hidden_states,)
231
+ if output_attentions:
232
+ all_attentions = all_attentions + (attn_weights,)
233
+
234
+ return hidden_states, all_hidden_states, all_attentions
235
+
236
+
237
+ class RNAErnieEmbeddings(nn.Module):
238
+ def __init__(self, config: RNAErnieConfig):
239
+ super().__init__()
240
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
241
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
242
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
243
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
244
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
245
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False)
246
+
247
+ def forward(self, input_ids: torch.LongTensor, token_type_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
248
+ B, T = input_ids.shape
249
+ if token_type_ids is None:
250
+ token_type_ids = torch.zeros_like(input_ids)
251
+ x = self.word_embeddings(input_ids)
252
+ x = x + self.position_embeddings(self.position_ids[:, :T])
253
+ x = x + self.token_type_embeddings(token_type_ids)
254
+ return self.dropout(self.LayerNorm(x))
255
+
256
+
257
+ class RNAErniePooler(nn.Module):
258
+ def __init__(self, config: RNAErnieConfig):
259
+ super().__init__()
260
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
261
+ self.activation = nn.Tanh()
262
+
263
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
264
+ return self.activation(self.dense(hidden_states[:, 0]))
265
+
266
+
267
+ class RNAErniePredictionHeadTransform(nn.Module):
268
+ def __init__(self, config: RNAErnieConfig):
269
+ super().__init__()
270
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
271
+ self.act = nn.ReLU() if config.hidden_act == "relu" else nn.GELU()
272
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
273
+
274
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
275
+ return self.LayerNorm(self.act(self.dense(hidden_states)))
276
+
277
+
278
+ class RNAErnieLMPredictionHead(nn.Module):
279
+ def __init__(self, config: RNAErnieConfig):
280
+ super().__init__()
281
+ self.transform = RNAErniePredictionHeadTransform(config)
282
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
283
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
284
+ self.decoder.bias = self.bias
285
+
286
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
287
+ return self.decoder(self.transform(hidden_states))
288
+
289
+
290
+ class RNAErnieOnlyMLMHead(nn.Module):
291
+ def __init__(self, config: RNAErnieConfig):
292
+ super().__init__()
293
+ self.predictions = RNAErnieLMPredictionHead(config)
294
+
295
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
296
+ return self.predictions(sequence_output)
297
+
298
+
299
+ class RNAErnieModel(PreTrainedModel):
300
+ config_class = RNAErnieConfig
301
+ base_model_prefix = "bert"
302
+ _supports_sdpa = True
303
+ _supports_flash_attn_2 = True
304
+
305
+ def __init__(self, config: RNAErnieConfig):
306
+ super().__init__(config)
307
+ self.embeddings = RNAErnieEmbeddings(config)
308
+ self.encoder = RNAErnieEncoder(config)
309
+ self.pooler = RNAErniePooler(config)
310
+ self.post_init()
311
+
312
+ def get_input_embeddings(self):
313
+ return self.embeddings.word_embeddings
314
+
315
+ def set_input_embeddings(self, value):
316
+ self.embeddings.word_embeddings = value
317
+
318
+ def forward(
319
+ self,
320
+ input_ids: torch.LongTensor,
321
+ attention_mask: Optional[torch.Tensor] = None,
322
+ token_type_ids: Optional[torch.LongTensor] = None,
323
+ output_hidden_states: Optional[bool] = None,
324
+ output_attentions: Optional[bool] = None,
325
+ return_dict: Optional[bool] = None,
326
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
327
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
328
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
329
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
330
+
331
+ if attention_mask is None:
332
+ attention_mask = torch.ones_like(input_ids)
333
+ key_padding_mask = attention_mask.eq(0)
334
+ if not key_padding_mask.any():
335
+ key_padding_mask = None
336
+
337
+ x = self.embeddings(input_ids, token_type_ids)
338
+ last_hidden_state, all_hidden_states, all_attentions = self.encoder(
339
+ x, key_padding_mask,
340
+ output_hidden_states=output_hidden_states,
341
+ output_attentions=output_attentions,
342
+ )
343
+ pooled = self.pooler(last_hidden_state)
344
+
345
+ if not return_dict:
346
+ return tuple(v for v in [last_hidden_state, pooled, all_hidden_states, all_attentions] if v is not None)
347
+
348
+ return BaseModelOutputWithPooling(
349
+ last_hidden_state=last_hidden_state,
350
+ pooler_output=pooled,
351
+ hidden_states=all_hidden_states,
352
+ attentions=all_attentions,
353
+ )
354
+
355
+
356
+ class RNAErnieForMaskedLM(PreTrainedModel):
357
+ config_class = RNAErnieConfig
358
+ base_model_prefix = "bert"
359
+ _supports_sdpa = True
360
+ _supports_flash_attn_2 = True
361
+
362
+ def __init__(self, config: RNAErnieConfig):
363
+ super().__init__(config)
364
+ self.bert = RNAErnieModel(config)
365
+ self.cls = RNAErnieOnlyMLMHead(config)
366
+ self.post_init()
367
+
368
+ def get_input_embeddings(self):
369
+ return self.bert.embeddings.word_embeddings
370
+
371
+ def get_output_embeddings(self):
372
+ return self.cls.predictions.decoder
373
+
374
+ def set_output_embeddings(self, new_embeddings):
375
+ self.cls.predictions.decoder = new_embeddings
376
+
377
+ def forward(
378
+ self,
379
+ input_ids: torch.LongTensor,
380
+ attention_mask: Optional[torch.Tensor] = None,
381
+ token_type_ids: Optional[torch.LongTensor] = None,
382
+ labels: Optional[torch.LongTensor] = None,
383
+ output_hidden_states: Optional[bool] = None,
384
+ output_attentions: Optional[bool] = None,
385
+ return_dict: Optional[bool] = None,
386
+ ) -> Union[Tuple, MaskedLMOutput]:
387
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
388
+
389
+ outputs = self.bert(
390
+ input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
391
+ output_hidden_states=output_hidden_states, output_attentions=output_attentions,
392
+ return_dict=True,
393
+ )
394
+ logits = self.cls(outputs.last_hidden_state)
395
+
396
+ loss = None
397
+ if labels is not None:
398
+ loss = F.cross_entropy(logits.view(-1, self.config.vocab_size), labels.view(-1), ignore_index=-100)
399
+
400
+ if not return_dict:
401
+ output = (logits,) + outputs[2:]
402
+ return (loss,) + output if loss is not None else output
403
+
404
+ return MaskedLMOutput(
405
+ loss=loss, logits=logits,
406
+ hidden_states=outputs.hidden_states, attentions=outputs.attentions,
407
+ )
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "pad_token": "[PAD]",
3
+ "unk_token": "[UNK]",
4
+ "cls_token": "[CLS]",
5
+ "sep_token": "[SEP]",
6
+ "mask_token": "[MASK]"
7
+ }
tokenization_rnaernie.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List, Optional, Tuple
3
+
4
+ from transformers import PreTrainedTokenizer
5
+
6
+ _VOCAB = {
7
+ "[PAD]": 0,
8
+ "[UNK]": 1,
9
+ "[CLS]": 2,
10
+ "[SEP]": 3,
11
+ "[MASK]": 4,
12
+ "[DEL]": 5,
13
+ "[IND]": 6,
14
+ "RNaseMRPRNA": 7,
15
+ "RNasePRNA": 8,
16
+ "SRPRNA": 9,
17
+ "YRNA": 10,
18
+ "antisenseRNA": 11,
19
+ "autocatalyticallysplicedintron": 12,
20
+ "guideRNA": 13,
21
+ "hammerheadribozyme": 14,
22
+ "lncRNA": 15,
23
+ "miRNA": 16,
24
+ "miscRNA": 17,
25
+ "ncRNA": 18,
26
+ "other": 19,
27
+ "piRNA": 20,
28
+ "premiRNA": 21,
29
+ "precursorRNA": 22,
30
+ "rRNA": 23,
31
+ "ribozyme": 24,
32
+ "sRNA": 25,
33
+ "scRNA": 26,
34
+ "scaRNA": 27,
35
+ "siRNA": 28,
36
+ "snRNA": 29,
37
+ "snoRNA": 30,
38
+ "tRNA": 31,
39
+ "telomeraseRNA": 32,
40
+ "tmRNA": 33,
41
+ "vaultRNA": 34,
42
+ "A": 35,
43
+ "T": 36,
44
+ "C": 37,
45
+ "G": 38,
46
+ }
47
+
48
+
49
+ class RNAErnieTokenizer(PreTrainedTokenizer):
50
+ """Character-level RNA tokenizer for RNAErnie (original ERNIE/PaddlePaddle version).
51
+
52
+ Converts U to T before tokenisation (model was pretrained with DNA-style T).
53
+ Input sequences are uppercased and U->T substituted automatically.
54
+
55
+ Vocabulary (39 tokens):
56
+ - Special: [PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3, [MASK]=4, [DEL]=5, [IND]=6
57
+ - ncRNA type labels: indices 7-34 (28 labels)
58
+ - Nucleotides: A=35, T=36, C=37, G=38
59
+ """
60
+
61
+ vocab_files_names = {"vocab_file": "vocab.txt"}
62
+ model_input_names = ["input_ids", "attention_mask"]
63
+
64
+ def __init__(
65
+ self,
66
+ vocab_file=None,
67
+ pad_token="[PAD]",
68
+ unk_token="[UNK]",
69
+ cls_token="[CLS]",
70
+ sep_token="[SEP]",
71
+ mask_token="[MASK]",
72
+ **kwargs,
73
+ ):
74
+ if vocab_file and os.path.isfile(vocab_file):
75
+ self._vocab = {}
76
+ with open(vocab_file, encoding="utf-8") as f:
77
+ for idx, line in enumerate(f):
78
+ token = line.rstrip("\n")
79
+ self._vocab[token] = idx
80
+ else:
81
+ self._vocab = dict(_VOCAB)
82
+ self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
83
+ super().__init__(
84
+ pad_token=pad_token,
85
+ unk_token=unk_token,
86
+ cls_token=cls_token,
87
+ sep_token=sep_token,
88
+ mask_token=mask_token,
89
+ **kwargs,
90
+ )
91
+
92
+ @property
93
+ def vocab_size(self) -> int:
94
+ return len(self._vocab)
95
+
96
+ def get_vocab(self) -> Dict[str, int]:
97
+ return dict(self._vocab)
98
+
99
+ def _tokenize(self, text: str) -> List[str]:
100
+ return list(text.upper().replace("U", "T"))
101
+
102
+ def _convert_token_to_id(self, token: str) -> int:
103
+ return self._vocab.get(token, self._vocab["[UNK]"])
104
+
105
+ def _convert_id_to_token(self, index: int) -> str:
106
+ return self._ids_to_tokens.get(index, "[UNK]")
107
+
108
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
109
+ os.makedirs(save_directory, exist_ok=True)
110
+ fname = (filename_prefix + "-" if filename_prefix else "") + "vocab.txt"
111
+ path = os.path.join(save_directory, fname)
112
+ with open(path, "w", encoding="utf-8") as f:
113
+ for token, _ in sorted(self._vocab.items(), key=lambda x: x[1]):
114
+ f.write(token + "\n")
115
+ return (path,)
116
+
117
+ def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
118
+ cls = [self.cls_token_id]
119
+ sep = [self.sep_token_id]
120
+ if token_ids_1 is None:
121
+ return cls + token_ids_0 + sep
122
+ return cls + token_ids_0 + sep + token_ids_1 + sep
123
+
124
+ def get_special_tokens_mask(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False) -> List[int]:
125
+ if already_has_special_tokens:
126
+ return super().get_special_tokens_mask(token_ids_0, token_ids_1, already_has_special_tokens=True)
127
+ mask = [1] + [0] * len(token_ids_0) + [1]
128
+ if token_ids_1 is not None:
129
+ mask += [1] + [0] * len(token_ids_1) + [1]
130
+ return mask
131
+
132
+ def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
133
+ sep = [self.sep_token_id]
134
+ cls = [self.cls_token_id]
135
+ if token_ids_1 is None:
136
+ return [0] * len(cls + token_ids_0 + sep)
137
+ return [0] * len(cls + token_ids_0 + sep) + [1] * len(token_ids_1 + sep)
tokenizer_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoTokenizer": ["tokenization_rnaernie.RNAErnieTokenizer", null]
4
+ },
5
+ "tokenizer_class": "RNAErnieTokenizer",
6
+ "model_max_length": 512,
7
+ "pad_token": "[PAD]",
8
+ "unk_token": "[UNK]",
9
+ "cls_token": "[CLS]",
10
+ "sep_token": "[SEP]",
11
+ "mask_token": "[MASK]",
12
+ "padding_side": "right"
13
+ }
vocab.txt ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [PAD]
2
+ [UNK]
3
+ [CLS]
4
+ [SEP]
5
+ [MASK]
6
+ [DEL]
7
+ [IND]
8
+ RNaseMRPRNA
9
+ RNasePRNA
10
+ SRPRNA
11
+ YRNA
12
+ antisenseRNA
13
+ autocatalyticallysplicedintron
14
+ guideRNA
15
+ hammerheadribozyme
16
+ lncRNA
17
+ miRNA
18
+ miscRNA
19
+ ncRNA
20
+ other
21
+ piRNA
22
+ premiRNA
23
+ precursorRNA
24
+ rRNA
25
+ ribozyme
26
+ sRNA
27
+ scRNA
28
+ scaRNA
29
+ siRNA
30
+ snRNA
31
+ snoRNA
32
+ tRNA
33
+ telomeraseRNA
34
+ tmRNA
35
+ vaultRNA
36
+ A
37
+ T
38
+ C
39
+ G