Taykhoom commited on
Commit
00e6e55
·
verified ·
1 Parent(s): a58e15f

Upload folder using huggingface_hub

Browse files
README.md ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - rna
4
+ library_name: transformers
5
+ tags:
6
+ - RNA
7
+ - language-model
8
+ - MSA
9
+ license: mit
10
+ ---
11
+
12
+ # RNA-MSM
13
+
14
+ Multiple sequence alignment-based RNA language model trained on homologous RNA
15
+ sequence alignments from the RNAcmap pipeline.
16
+
17
+ ## Architecture
18
+
19
+ | Parameter | Value |
20
+ |---|---|
21
+ | Layers | 10 |
22
+ | Attention heads | 12 |
23
+ | Embedding dimension | 768 |
24
+ | FFN dimension | 3072 |
25
+ | Vocabulary size | 12 |
26
+ | Positional encoding | Learned (sequence) + learned scalar (alignment row) |
27
+ | Architecture | Axial MSA Transformer (row + column self-attention) |
28
+ | Max sequence length | 1024 |
29
+ | Max alignment depth | 1024 |
30
+
31
+ **Input format:** RNA-MSM takes 3D input `(batch, num_alignments, seqlen)`. Each
32
+ alignment is a set of homologous RNA sequences of equal length (an MSA). The model
33
+ applies row self-attention (across sequence positions) and column self-attention
34
+ (across alignment rows) at each of the 10 transformer layers.
35
+
36
+ ### Vocabulary
37
+
38
+ | Token | ID | Token | ID |
39
+ |---|---|---|---|
40
+ | `<cls>` | 0 | `U` | 7 |
41
+ | `<pad>` | 1 | `X` | 8 |
42
+ | `<eos>` | 2 | `N` | 9 |
43
+ | `<unk>` | 3 | `-` | 10 |
44
+ | `A` | 4 | `<mask>` | 11 |
45
+ | `G` | 5 | | |
46
+ | `C` | 6 | | |
47
+
48
+ Each sequence is prepended with `<cls>` (id 0). No `<eos>` token is appended.
49
+
50
+ ## Pretraining
51
+
52
+ - **Objective:** Masked language modeling on RNA MSAs (masking ~15% of tokens)
53
+ - **Data:** RNA homologous sequences searched by RNAcmap from non-redundant RNA
54
+ databases
55
+ - **Source checkpoint:** `RNA_MSM_pretrained.ckpt`
56
+ ([original Google Drive link](https://drive.google.com/file/d/11A-S13qAb5wiBi1YLs3EOrnixSDq7Q0q/view))
57
+
58
+ ### Checkpoint selection
59
+
60
+ There is one publicly released RNA-MSM pretrained checkpoint. This is that checkpoint,
61
+ converted from the original PyTorch Lightning `.ckpt` format.
62
+
63
+ ## Parity Verification
64
+
65
+ Hidden-state representations verified identical (max abs diff = 0.00, exact match) to
66
+ the reference implementation at all 11 representation levels (embedding + 10 transformer
67
+ layers), both on padded and unpadded batches. Verified on GPU with PyTorch 2.7 /
68
+ CUDA 12.6.
69
+
70
+ ## Related Models
71
+
72
+ See the full [RNA-MSM collection](https://huggingface.co/collections/Taykhoom/rna-msm).
73
+
74
+ ## Usage
75
+
76
+ RNA-MSM is an **MSA model** -- it performs best when given multiple homologous
77
+ sequences as input. For single-sequence embedding, each sequence is treated as a
78
+ 1-row MSA.
79
+
80
+ ### Single-sequence embedding
81
+
82
+ ```python
83
+ import torch
84
+ from transformers import AutoTokenizer, AutoModel
85
+
86
+ tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNA-MSM", trust_remote_code=True)
87
+ model = AutoModel.from_pretrained("Taykhoom/RNA-MSM", trust_remote_code=True)
88
+ model.eval()
89
+
90
+ sequences = ["AGCUAGCUAGCU", "GCUAGCUA"]
91
+ enc = tokenizer(sequences, return_tensors="pt", padding=True)
92
+ # enc["input_ids"]: (2, 1, seqlen) -- each sequence treated as 1-row MSA
93
+
94
+ with torch.no_grad():
95
+ out = model(**enc)
96
+
97
+ # last_hidden_state: (batch, num_alignments, seqlen, 768)
98
+ lhs = out.last_hidden_state # (2, 1, seqlen, 768)
99
+
100
+ # Per-token embeddings for the query sequence (row 0), excluding CLS
101
+ token_emb = lhs[:, 0, 1:, :] # (2, seqlen-1, 768)
102
+
103
+ # Mean-pool over non-padding positions for sequence-level embedding
104
+ mask = enc["attention_mask"][:, 0, 1:].unsqueeze(-1).float() # (2, seqlen-1, 1)
105
+ seq_emb = (token_emb * mask).sum(1) / mask.sum(1).clamp(min=1) # (2, 768)
106
+ ```
107
+
108
+ ### MSA embedding
109
+
110
+ ```python
111
+ import torch
112
+ from transformers import AutoTokenizer, AutoModel
113
+
114
+ tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNA-MSM", trust_remote_code=True)
115
+ model = AutoModel.from_pretrained("Taykhoom/RNA-MSM", trust_remote_code=True)
116
+ model.eval()
117
+
118
+ # One MSA: 3 aligned homologous sequences of equal length
119
+ msa = [
120
+ "AGCUAGCUAGCU",
121
+ "AGCUAGCUAGC-",
122
+ "AGCU--CUAGCU",
123
+ ]
124
+ enc = tokenizer.encode_msa([msa], return_tensors="pt", padding=True)
125
+ # enc["input_ids"]: (1, 3, seqlen)
126
+
127
+ with torch.no_grad():
128
+ out = model(**enc)
129
+
130
+ # last_hidden_state: (1, 3, seqlen, 768)
131
+ # Use row 0 (query sequence) for downstream tasks
132
+ query_emb = out.last_hidden_state[:, 0, 1:, :] # (1, seqlen-1, 768)
133
+ ```
134
+
135
+ ### Intermediate layers
136
+
137
+ ```python
138
+ with torch.no_grad():
139
+ out = model(**enc, output_hidden_states=True)
140
+
141
+ # hidden_states: tuple of 11 tensors, each (batch, num_alignments, seqlen, 768)
142
+ # Index 0 = embedding, 1..10 = transformer layer outputs
143
+ layer5_emb = out.hidden_states[5][:, 0, :, :] # (batch, seqlen, 768)
144
+ ```
145
+
146
+ ### MLM logits
147
+
148
+ ```python
149
+ from transformers import AutoModelForMaskedLM
150
+
151
+ mlm = AutoModelForMaskedLM.from_pretrained("Taykhoom/RNA-MSM", trust_remote_code=True)
152
+ mlm.eval()
153
+
154
+ enc = tokenizer(["AGCU<mask>AGCU"], return_tensors="pt", padding=True)
155
+ with torch.no_grad():
156
+ logits = mlm(**enc).logits # (1, 1, seqlen, 12)
157
+ ```
158
+
159
+ ### Fine-tuning
160
+
161
+ For sequence-level downstream tasks (e.g., solvent accessibility), extract the
162
+ embedding from the query row (row 0) of the last hidden state, then apply a
163
+ prediction head. The model's attention maps (row attention) are also useful for
164
+ 2D structural tasks (e.g., secondary structure prediction).
165
+
166
+ ## Implementation Notes
167
+
168
+ RNA-MSM uses **axial attention**: each transformer layer applies row self-attention
169
+ (attending across sequence positions, summed over alignment rows) followed by column
170
+ self-attention (attending across alignment rows per position). This custom attention
171
+ pattern is not compatible with `attn_implementation="sdpa"` or
172
+ `attn_implementation="flash_attention_2"` -- only `"eager"` is supported.
173
+
174
+ `last_hidden_state` has shape `(batch, num_alignments, seqlen, embed_dim)` -- note
175
+ the 4D output, reflecting the MSA structure. For single-sequence use (1-row MSA),
176
+ this is `(batch, 1, seqlen, embed_dim)`.
177
+
178
+ ## Citation
179
+
180
+ ```bibtex
181
+ @article{zhang2024rnamsm,
182
+ author = {Zhang, Yikun and Lang, Mei and Jiang, Jiuhong and Gao, Zhiqiang
183
+ and Xu, Fan and Litfin, Thomas and Chen, Ke and Singh, Jaswinder
184
+ and Huang, Xiansong and Song, Guoli and Tian, Yonghong and Zhan, Jian
185
+ and Chen, Jie and Zhou, Yaoqi},
186
+ title = {Multiple sequence alignment-based RNA language model and its application
187
+ to structural inference},
188
+ journal = {Nucleic Acids Research},
189
+ volume = {52},
190
+ number = {1},
191
+ pages = {e3},
192
+ year = {2024},
193
+ doi = {10.1093/nar/gkad1031},
194
+ pmid = {37941140},
195
+ }
196
+ ```
197
+
198
+ ## Credits
199
+
200
+ Original model and code by Zhang et al. Source: [GitHub](https://github.com/yikunpku/RNA-MSM).
201
+ The HF conversion code was authored primarily by [Claude Code](https://claude.ai/code)
202
+ and reviewed manually by Taykhoom Dalal.
203
+
204
+ ## License
205
+
206
+ MIT, following the original repository.
__pycache__/configuration_rnamsm.cpython-39.pyc ADDED
Binary file (1.4 kB). View file
 
__pycache__/modeling_rnamsm.cpython-39.pyc ADDED
Binary file (14.9 kB). View file
 
__pycache__/tokenization_rnamsm.cpython-39.pyc ADDED
Binary file (8.51 kB). View file
 
config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoConfig": "configuration_rnamsm.RNAMSMConfig",
4
+ "AutoModel": "modeling_rnamsm.RNAMSMModel",
5
+ "AutoModelForMaskedLM": "modeling_rnamsm.RNAMSMForMaskedLM"
6
+ },
7
+ "activation_dropout": 0.1,
8
+ "architectures": [
9
+ "RNAMSMForMaskedLM"
10
+ ],
11
+ "attention_dropout": 0.1,
12
+ "cls_idx": 0,
13
+ "dropout": 0.1,
14
+ "embed_dim": 768,
15
+ "embed_positions_msa": true,
16
+ "eos_idx": 2,
17
+ "ffn_embed_dim": 3072,
18
+ "mask_idx": 11,
19
+ "max_alignments": 1024,
20
+ "max_positions": 1024,
21
+ "max_tokens_per_msa": 16384,
22
+ "model_type": "rnamsm",
23
+ "num_attention_heads": 12,
24
+ "num_layers": 10,
25
+ "padding_idx": 1,
26
+ "transformers_version": "4.57.6",
27
+ "vocab_size": 12
28
+ }
configuration_rnamsm.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class RNAMSMConfig(PretrainedConfig):
5
+ model_type = "rnamsm"
6
+
7
+ auto_map = {
8
+ "AutoConfig": "configuration_rnamsm.RNAMSMConfig",
9
+ "AutoModel": "modeling_rnamsm.RNAMSMModel",
10
+ "AutoModelForMaskedLM": "modeling_rnamsm.RNAMSMForMaskedLM",
11
+ }
12
+
13
+ def __init__(
14
+ self,
15
+ vocab_size=12,
16
+ num_layers=10,
17
+ embed_dim=768,
18
+ num_attention_heads=12,
19
+ ffn_embed_dim=3072,
20
+ padding_idx=1,
21
+ mask_idx=11,
22
+ cls_idx=0,
23
+ eos_idx=2,
24
+ dropout=0.1,
25
+ attention_dropout=0.1,
26
+ activation_dropout=0.1,
27
+ max_positions=1024,
28
+ max_alignments=1024,
29
+ max_tokens_per_msa=16384,
30
+ embed_positions_msa=True,
31
+ **kwargs,
32
+ ):
33
+ super().__init__(padding_idx=padding_idx, **kwargs)
34
+ self.vocab_size = vocab_size
35
+ self.num_layers = num_layers
36
+ self.embed_dim = embed_dim
37
+ self.num_attention_heads = num_attention_heads
38
+ self.ffn_embed_dim = ffn_embed_dim
39
+ self.mask_idx = mask_idx
40
+ self.cls_idx = cls_idx
41
+ self.eos_idx = eos_idx
42
+ self.dropout = dropout
43
+ self.attention_dropout = attention_dropout
44
+ self.activation_dropout = activation_dropout
45
+ self.max_positions = max_positions
46
+ self.max_alignments = max_alignments
47
+ self.max_tokens_per_msa = max_tokens_per_msa
48
+ self.embed_positions_msa = embed_positions_msa
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3998fb8289d98cf53944fc14da4157a40c03dffecf0efefd7e76044ed16a0095
3
+ size 383678288
modeling_rnamsm.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from transformers import PreTrainedModel
8
+ from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput
9
+
10
+ try:
11
+ from .configuration_rnamsm import RNAMSMConfig
12
+ except ImportError:
13
+ from configuration_rnamsm import RNAMSMConfig
14
+
15
+
16
+ def gelu(x):
17
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
18
+
19
+
20
+ class RNAMSMLMHead(nn.Module):
21
+ def __init__(self, config: RNAMSMConfig, embed_tokens_weight: nn.Parameter):
22
+ super().__init__()
23
+ self.dense = nn.Linear(config.embed_dim, config.embed_dim)
24
+ self.layer_norm = nn.LayerNorm(config.embed_dim)
25
+ self.weight = embed_tokens_weight
26
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
27
+
28
+ def forward(self, x):
29
+ x = self.dense(x)
30
+ x = gelu(x)
31
+ x = self.layer_norm(x)
32
+ return F.linear(x, self.weight) + self.bias
33
+
34
+
35
+ class LearnedPositionalEmbedding(nn.Embedding):
36
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
37
+ num_embeddings_ = num_embeddings + padding_idx + 1
38
+ super().__init__(num_embeddings_, embedding_dim, padding_idx)
39
+ self.max_positions = num_embeddings
40
+
41
+ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
42
+ mask = tokens.ne(self.padding_idx).int()
43
+ positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
44
+ return F.embedding(positions, self.weight, self.padding_idx,
45
+ self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
46
+
47
+
48
+ class NormalizedResidualBlock(nn.Module):
49
+ def __init__(self, layer: nn.Module, embedding_dim: int, dropout: float):
50
+ super().__init__()
51
+ self.layer = layer
52
+ self.layer_norm = nn.LayerNorm(embedding_dim)
53
+ self.dropout_module = nn.Dropout(dropout)
54
+
55
+ def forward(self, x, *args, **kwargs):
56
+ residual = x
57
+ x = self.layer_norm(x)
58
+ outputs = self.layer(x, *args, **kwargs)
59
+ if isinstance(outputs, tuple):
60
+ x, *out = outputs
61
+ else:
62
+ x, out = outputs, None
63
+ x = self.dropout_module(x)
64
+ x = residual + x
65
+ if out is not None:
66
+ return (x,) + tuple(out)
67
+ return x
68
+
69
+
70
+ class FeedForwardNetwork(nn.Module):
71
+ def __init__(self, embedding_dim: int, ffn_embedding_dim: int,
72
+ activation_dropout: float, max_tokens_per_msa: int):
73
+ super().__init__()
74
+ self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
75
+ self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)
76
+ self.activation_fn = nn.GELU()
77
+ self.activation_dropout_module = nn.Dropout(activation_dropout)
78
+ self.max_tokens_per_msa = max_tokens_per_msa
79
+
80
+ def forward(self, x):
81
+ x = self.activation_fn(self.fc1(x))
82
+ x = self.activation_dropout_module(x)
83
+ return self.fc2(x)
84
+
85
+
86
+ class RowSelfAttention(nn.Module):
87
+ """Self-attention across columns (sequence positions), summed over MSA rows."""
88
+
89
+ def __init__(self, embed_dim: int, num_heads: int, dropout: float, max_tokens_per_msa: int):
90
+ super().__init__()
91
+ self.num_heads = num_heads
92
+ self.dropout = dropout
93
+ self.head_dim = embed_dim // num_heads
94
+ self.scaling = self.head_dim ** -0.5
95
+ self.max_tokens_per_msa = max_tokens_per_msa
96
+ self.attn_shape = "hnij"
97
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
98
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
99
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
100
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
101
+ self.dropout_module = nn.Dropout(dropout)
102
+
103
+ def align_scaling(self, q):
104
+ return self.scaling / math.sqrt(q.size(0))
105
+
106
+ def compute_attention_weights(self, x, scaling, padding_mask=None):
107
+ num_rows, num_cols, batch_size, embed_dim = x.size()
108
+ q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
109
+ k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
110
+ q = q * scaling
111
+ if padding_mask is not None:
112
+ q = q * (1 - padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q))
113
+ attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)
114
+ if padding_mask is not None:
115
+ attn_weights = attn_weights.masked_fill(
116
+ padding_mask[:, 0].unsqueeze(0).unsqueeze(2), -10000.0)
117
+ return attn_weights
118
+
119
+ def compute_attention_update(self, x, attn_probs):
120
+ num_rows, num_cols, batch_size, embed_dim = x.size()
121
+ v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
122
+ context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
123
+ context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
124
+ return self.out_proj(context)
125
+
126
+ def _batched_forward(self, x, padding_mask=None):
127
+ num_rows, num_cols, batch_size, _ = x.size()
128
+ max_rows = max(1, self.max_tokens_per_msa // num_cols)
129
+ scaling = self.align_scaling(x)
130
+ attns = 0
131
+ for start in range(0, num_rows, max_rows):
132
+ pm = padding_mask[:, start:start + max_rows] if padding_mask is not None else None
133
+ attns = attns + self.compute_attention_weights(x[start:start + max_rows], scaling, pm)
134
+ attn_probs = attns.softmax(-1)
135
+ attn_probs = self.dropout_module(attn_probs)
136
+ outputs = [self.compute_attention_update(x[start:start + max_rows], attn_probs)
137
+ for start in range(0, num_rows, max_rows)]
138
+ return torch.cat(outputs, 0), attn_probs
139
+
140
+ def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None):
141
+ num_rows, num_cols, batch_size, _ = x.size()
142
+ if num_rows * num_cols > self.max_tokens_per_msa and not torch.is_grad_enabled():
143
+ return self._batched_forward(x, self_attn_padding_mask)
144
+ scaling = self.align_scaling(x)
145
+ attn_weights = self.compute_attention_weights(x, scaling, self_attn_padding_mask)
146
+ attn_probs = attn_weights.softmax(-1)
147
+ attn_probs = self.dropout_module(attn_probs)
148
+ output = self.compute_attention_update(x, attn_probs)
149
+ return output, attn_probs
150
+
151
+
152
+ class ColumnSelfAttention(nn.Module):
153
+ """Self-attention across MSA rows (alignment depth) per sequence position."""
154
+
155
+ def __init__(self, embed_dim: int, num_heads: int, dropout: float, max_tokens_per_msa: int):
156
+ super().__init__()
157
+ self.num_heads = num_heads
158
+ self.dropout = dropout
159
+ self.head_dim = embed_dim // num_heads
160
+ self.scaling = self.head_dim ** -0.5
161
+ self.max_tokens_per_msa = max_tokens_per_msa
162
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
163
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
164
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
165
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
166
+ self.dropout_module = nn.Dropout(dropout)
167
+
168
+ def compute_attention_update(self, x, self_attn_padding_mask=None):
169
+ num_rows, num_cols, batch_size, embed_dim = x.size()
170
+ if num_rows == 1:
171
+ attn_probs = torch.ones(self.num_heads, num_cols, batch_size, 1, 1,
172
+ device=x.device, dtype=x.dtype)
173
+ output = self.out_proj(self.v_proj(x))
174
+ else:
175
+ q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
176
+ k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
177
+ v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
178
+ q = q * self.scaling
179
+ attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)
180
+ if self_attn_padding_mask is not None:
181
+ attn_weights = attn_weights.masked_fill(
182
+ self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3), -10000.0)
183
+ attn_probs = attn_weights.softmax(-1)
184
+ attn_probs = self.dropout_module(attn_probs)
185
+ context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
186
+ context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
187
+ output = self.out_proj(context)
188
+ return output, attn_probs
189
+
190
+ def _batched_forward(self, x, self_attn_padding_mask=None):
191
+ num_rows, num_cols, batch_size, _ = x.size()
192
+ max_cols = max(1, self.max_tokens_per_msa // num_rows)
193
+ outputs, attns = [], []
194
+ for start in range(0, num_cols, max_cols):
195
+ pm = (self_attn_padding_mask[:, :, start:start + max_cols]
196
+ if self_attn_padding_mask is not None else None)
197
+ out, attn = self.compute_attention_update(x[:, start:start + max_cols], pm)
198
+ outputs.append(out)
199
+ attns.append(attn)
200
+ return torch.cat(outputs, 1), torch.cat(attns, 1)
201
+
202
+ def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None):
203
+ num_rows, num_cols, batch_size, _ = x.size()
204
+ if num_rows * num_cols > self.max_tokens_per_msa and not torch.is_grad_enabled():
205
+ return self._batched_forward(x, self_attn_padding_mask)
206
+ return self.compute_attention_update(x, self_attn_padding_mask)
207
+
208
+
209
+ class AxialTransformerLayer(nn.Module):
210
+ def __init__(self, config: RNAMSMConfig):
211
+ super().__init__()
212
+ self.row_self_attention = NormalizedResidualBlock(
213
+ RowSelfAttention(config.embed_dim, config.num_attention_heads,
214
+ config.attention_dropout, config.max_tokens_per_msa),
215
+ config.embed_dim, config.dropout,
216
+ )
217
+ self.column_self_attention = NormalizedResidualBlock(
218
+ ColumnSelfAttention(config.embed_dim, config.num_attention_heads,
219
+ config.attention_dropout, config.max_tokens_per_msa),
220
+ config.embed_dim, config.dropout,
221
+ )
222
+ self.feed_forward_layer = NormalizedResidualBlock(
223
+ FeedForwardNetwork(config.embed_dim, config.ffn_embed_dim,
224
+ config.activation_dropout, config.max_tokens_per_msa),
225
+ config.embed_dim, config.dropout,
226
+ )
227
+
228
+ def forward(self, x, padding_mask=None, output_attentions=False):
229
+ x, row_attn = self.row_self_attention(x, self_attn_padding_mask=padding_mask)
230
+ x, col_attn = self.column_self_attention(x, self_attn_padding_mask=padding_mask)
231
+ x = self.feed_forward_layer(x)
232
+ return x, row_attn, col_attn
233
+
234
+
235
+ class RNAMSMPreTrainedModel(PreTrainedModel):
236
+ config_class = RNAMSMConfig
237
+ base_model_prefix = "rnamsm"
238
+
239
+ def _init_weights(self, module):
240
+ if isinstance(module, nn.Linear):
241
+ nn.init.normal_(module.weight, std=0.02)
242
+ if module.bias is not None:
243
+ nn.init.zeros_(module.bias)
244
+ elif isinstance(module, nn.Embedding):
245
+ nn.init.normal_(module.weight, std=0.02)
246
+ if module.padding_idx is not None:
247
+ module.weight.data[module.padding_idx].zero_()
248
+ elif isinstance(module, nn.LayerNorm):
249
+ nn.init.ones_(module.weight)
250
+ nn.init.zeros_(module.bias)
251
+
252
+
253
+ class RNAMSMModel(RNAMSMPreTrainedModel):
254
+ """
255
+ RNA-MSM backbone: MSA Transformer that processes multiple-sequence-aligned RNA
256
+ sequences and produces per-position embeddings for each alignment row.
257
+
258
+ Input: input_ids of shape (batch, num_alignments, seqlen)
259
+ Output: last_hidden_state of shape (batch, num_alignments, seqlen, embed_dim)
260
+ """
261
+
262
+ def __init__(self, config: RNAMSMConfig):
263
+ super().__init__(config)
264
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim,
265
+ padding_idx=config.padding_idx)
266
+ self.embed_positions = LearnedPositionalEmbedding(
267
+ config.max_positions, config.embed_dim, config.padding_idx)
268
+ if config.embed_positions_msa:
269
+ self.msa_position_embedding = nn.Parameter(
270
+ 0.01 * torch.randn(1, config.max_alignments, 1, 1))
271
+ else:
272
+ self.register_parameter("msa_position_embedding", None)
273
+ self.dropout_module = nn.Dropout(config.dropout)
274
+ self.emb_layer_norm_before = nn.LayerNorm(config.embed_dim)
275
+ self.emb_layer_norm_after = nn.LayerNorm(config.embed_dim)
276
+ self.layers = nn.ModuleList([AxialTransformerLayer(config)
277
+ for _ in range(config.num_layers)])
278
+ self.post_init()
279
+
280
+ def forward(
281
+ self,
282
+ input_ids: torch.Tensor,
283
+ attention_mask: Optional[torch.Tensor] = None,
284
+ output_hidden_states: Optional[bool] = None,
285
+ output_attentions: Optional[bool] = None,
286
+ return_dict: Optional[bool] = None,
287
+ ):
288
+ output_hidden_states = (output_hidden_states if output_hidden_states is not None
289
+ else self.config.output_hidden_states)
290
+ output_attentions = (output_attentions if output_attentions is not None
291
+ else self.config.output_attentions)
292
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
293
+
294
+ assert input_ids.ndim == 3, (
295
+ "RNA-MSM expects 3D input_ids of shape (batch, num_alignments, seqlen). "
296
+ "For single sequences, use tokenizer which produces (batch, 1, seqlen).")
297
+
298
+ batch_size, num_alignments, seqlen = input_ids.size()
299
+
300
+ # HF convention: attention_mask 1=attend, 0=pad -> padding_mask True=padding
301
+ if attention_mask is not None:
302
+ padding_mask = attention_mask.eq(0)
303
+ else:
304
+ padding_mask = input_ids.eq(self.config.padding_idx)
305
+
306
+ if not padding_mask.any():
307
+ padding_mask = None
308
+
309
+ # (B, R, C) -> embed: (B, R, C, D)
310
+ x = self.embed_tokens(input_ids)
311
+ x = x + self.embed_positions(
312
+ input_ids.view(batch_size * num_alignments, seqlen)
313
+ ).view(batch_size, num_alignments, seqlen, self.config.embed_dim)
314
+
315
+ if self.msa_position_embedding is not None:
316
+ if num_alignments > self.config.max_alignments:
317
+ raise RuntimeError(
318
+ f"MSA depth {num_alignments} exceeds max_alignments "
319
+ f"{self.config.max_alignments}.")
320
+ x = x + self.msa_position_embedding[:, :num_alignments]
321
+
322
+ x = self.emb_layer_norm_before(x)
323
+ x = self.dropout_module(x)
324
+
325
+ if padding_mask is not None:
326
+ x = x * (1 - padding_mask.unsqueeze(-1).to(x))
327
+
328
+ all_hidden_states = []
329
+ all_row_attentions = []
330
+ all_col_attentions = []
331
+
332
+ if output_hidden_states:
333
+ all_hidden_states.append(x)
334
+
335
+ # (B, R, C, D) -> (R, C, B, D) for axial attention
336
+ x = x.permute(1, 2, 0, 3)
337
+
338
+ for layer in self.layers:
339
+ x, row_attn, col_attn = layer(x, padding_mask=padding_mask,
340
+ output_attentions=output_attentions)
341
+ if output_hidden_states:
342
+ all_hidden_states.append(x.permute(2, 0, 1, 3))
343
+ if output_attentions:
344
+ all_row_attentions.append(row_attn)
345
+ all_col_attentions.append(col_attn)
346
+
347
+ x = self.emb_layer_norm_after(x)
348
+ x = x.permute(2, 0, 1, 3) # (R, C, B, D) -> (B, R, C, D)
349
+
350
+ if output_hidden_states:
351
+ all_hidden_states[-1] = x
352
+
353
+ if not return_dict:
354
+ return tuple(v for v in [
355
+ x,
356
+ tuple(all_hidden_states) if output_hidden_states else None,
357
+ tuple(all_row_attentions) if output_attentions else None,
358
+ ] if v is not None)
359
+
360
+ return BaseModelOutput(
361
+ last_hidden_state=x,
362
+ hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
363
+ attentions=tuple(all_row_attentions) if output_attentions else None,
364
+ )
365
+
366
+
367
+ class RNAMSMForMaskedLM(RNAMSMPreTrainedModel):
368
+ _tied_weights_keys = ["lm_head.weight"]
369
+
370
+ def __init__(self, config: RNAMSMConfig):
371
+ super().__init__(config)
372
+ self.rnamsm = RNAMSMModel(config)
373
+ self.lm_head = RNAMSMLMHead(config, self.rnamsm.embed_tokens.weight)
374
+ self.post_init()
375
+
376
+ def get_output_embeddings(self):
377
+ return self.lm_head
378
+
379
+ def set_output_embeddings(self, new_embeddings):
380
+ self.lm_head = new_embeddings
381
+
382
+ def forward(
383
+ self,
384
+ input_ids: torch.Tensor,
385
+ attention_mask: Optional[torch.Tensor] = None,
386
+ labels: Optional[torch.Tensor] = None,
387
+ output_hidden_states: Optional[bool] = None,
388
+ output_attentions: Optional[bool] = None,
389
+ return_dict: Optional[bool] = None,
390
+ ):
391
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
392
+
393
+ out = self.rnamsm(
394
+ input_ids,
395
+ attention_mask=attention_mask,
396
+ output_hidden_states=output_hidden_states,
397
+ output_attentions=output_attentions,
398
+ return_dict=return_dict,
399
+ )
400
+
401
+ logits = self.lm_head(out[0] if not return_dict else out.last_hidden_state)
402
+
403
+ loss = None
404
+ if labels is not None:
405
+ loss = F.cross_entropy(
406
+ logits.view(-1, self.config.vocab_size),
407
+ labels.view(-1),
408
+ ignore_index=-100,
409
+ )
410
+
411
+ if not return_dict:
412
+ output = (logits,) + out[1:]
413
+ return ((loss,) + output) if loss is not None else output
414
+
415
+ return MaskedLMOutput(
416
+ loss=loss,
417
+ logits=logits,
418
+ hidden_states=out.hidden_states,
419
+ attentions=out.attentions,
420
+ )
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "<cls>",
3
+ "eos_token": "<eos>",
4
+ "mask_token": "<mask>",
5
+ "pad_token": "<pad>",
6
+ "unk_token": "<unk>"
7
+ }
tokenization_rnamsm.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import Dict, List, Optional, Union
4
+
5
+ import torch
6
+ from transformers import PreTrainedTokenizer
7
+
8
+
9
+ _VOCAB = {
10
+ "<cls>": 0,
11
+ "<pad>": 1,
12
+ "<eos>": 2,
13
+ "<unk>": 3,
14
+ "A": 4,
15
+ "G": 5,
16
+ "C": 6,
17
+ "U": 7,
18
+ "X": 8,
19
+ "N": 9,
20
+ "-": 10,
21
+ "<mask>": 11,
22
+ }
23
+
24
+
25
+ class RNAMSMTokenizer(PreTrainedTokenizer):
26
+ """
27
+ Tokenizer for RNA-MSM.
28
+
29
+ Vocabulary: <cls>(0) <pad>(1) <eos>(2) <unk>(3) A(4) G(5) C(6) U(7) X(8) N(9) -(10) <mask>(11)
30
+
31
+ RNA-MSM is an MSA Transformer: it always expects 3D input
32
+ (batch, num_alignments, seqlen). This tokenizer treats each input string
33
+ as a single-sequence MSA (1 alignment row), so the standard __call__ API:
34
+
35
+ enc = tokenizer(["AGCU", "GAUC"], return_tensors="pt", padding=True)
36
+ # enc.input_ids: (2, 1, T) -- batch of 2 single-sequence MSAs
37
+
38
+ For real MSAs (multiple aligned sequences), use encode_msa():
39
+
40
+ enc = tokenizer.encode_msa([["AGCU--", "AGCUUU"]], return_tensors="pt")
41
+ # enc["input_ids"]: (1, 2, T) -- 1 MSA with 2 alignment rows
42
+ """
43
+
44
+ vocab_files_names = {"vocab_file": "vocab.json"}
45
+ model_input_names = ["input_ids", "attention_mask"]
46
+
47
+ def __init__(
48
+ self,
49
+ vocab_file=None,
50
+ cls_token="<cls>",
51
+ pad_token="<pad>",
52
+ eos_token="<eos>",
53
+ unk_token="<unk>",
54
+ mask_token="<mask>",
55
+ **kwargs,
56
+ ):
57
+ if vocab_file and os.path.isfile(vocab_file):
58
+ with open(vocab_file) as f:
59
+ self._vocab = json.load(f)
60
+ else:
61
+ self._vocab = dict(_VOCAB)
62
+ self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
63
+ super().__init__(
64
+ cls_token=cls_token,
65
+ pad_token=pad_token,
66
+ eos_token=eos_token,
67
+ unk_token=unk_token,
68
+ mask_token=mask_token,
69
+ **kwargs,
70
+ )
71
+
72
+ @property
73
+ def vocab_size(self):
74
+ return len(self._vocab)
75
+
76
+ def get_vocab(self):
77
+ return dict(self._vocab)
78
+
79
+ def _tokenize(self, text):
80
+ return list(text)
81
+
82
+ def _convert_token_to_id(self, token):
83
+ return self._vocab.get(token, self._vocab["<unk>"])
84
+
85
+ def _convert_id_to_token(self, index):
86
+ return self._ids_to_tokens.get(index, "<unk>")
87
+
88
+ def save_vocabulary(self, save_directory, filename_prefix=None):
89
+ os.makedirs(save_directory, exist_ok=True)
90
+ fname = (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
91
+ path = os.path.join(save_directory, fname)
92
+ with open(path, "w") as f:
93
+ json.dump(self._vocab, f, indent=2)
94
+ return (path,)
95
+
96
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
97
+ cls = [self.cls_token_id]
98
+ if token_ids_1 is None:
99
+ return cls + token_ids_0
100
+ return cls + token_ids_0 + cls + token_ids_1
101
+
102
+ def get_special_tokens_mask(self, token_ids_0, token_ids_1=None,
103
+ already_has_special_tokens=False):
104
+ if already_has_special_tokens:
105
+ return super().get_special_tokens_mask(
106
+ token_ids_0, token_ids_1, already_has_special_tokens=True)
107
+ mask = [1] + [0] * len(token_ids_0)
108
+ if token_ids_1 is not None:
109
+ mask += [1] + [0] * len(token_ids_1)
110
+ return mask
111
+
112
+ def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
113
+ if token_ids_1 is None:
114
+ return [0] + token_ids_0
115
+ return [0] + token_ids_0 + [0] + token_ids_1
116
+
117
+ def __call__(
118
+ self,
119
+ text,
120
+ text_pair=None,
121
+ add_special_tokens=True,
122
+ padding=False,
123
+ truncation=False,
124
+ max_length=None,
125
+ return_tensors=None,
126
+ **kwargs,
127
+ ):
128
+ """
129
+ Tokenize one or more sequences, each treated as a 1-row MSA.
130
+
131
+ text: str or List[str]
132
+ Returns dict with input_ids of shape (batch, 1, seqlen) and
133
+ attention_mask of shape (batch, 1, seqlen).
134
+ """
135
+ if isinstance(text, str):
136
+ sequences = [text]
137
+ else:
138
+ sequences = list(text)
139
+
140
+ encoded = []
141
+ for seq in sequences:
142
+ ids = self._tokenize_single(seq, add_special_tokens)
143
+ encoded.append(ids)
144
+
145
+ if padding and len(encoded) > 1:
146
+ max_len = max(len(ids) for ids in encoded)
147
+ pad_id = self.pad_token_id
148
+ encoded = [ids + [pad_id] * (max_len - len(ids)) for ids in encoded]
149
+
150
+ input_ids = [[ids] for ids in encoded]
151
+ attention_mask = [[[1 if t != self.pad_token_id else 0 for t in ids]]
152
+ for ids in encoded]
153
+
154
+ if return_tensors == "pt":
155
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
156
+ attention_mask = torch.tensor(attention_mask, dtype=torch.long)
157
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
158
+
159
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
160
+
161
+ def _tokenize_single(self, sequence, add_special_tokens=True):
162
+ tokens = list(sequence)
163
+ ids = [self._convert_token_to_id(t) for t in tokens]
164
+ if add_special_tokens:
165
+ ids = [self.cls_token_id] + ids
166
+ return ids
167
+
168
+ def encode_msa(
169
+ self,
170
+ msas,
171
+ add_special_tokens=True,
172
+ padding=False,
173
+ return_tensors=None,
174
+ ):
175
+ """
176
+ Tokenize a batch of MSAs.
177
+
178
+ msas: List[List[str]]
179
+ Each inner list is one MSA (multiple aligned sequences of equal length).
180
+ All sequences within an MSA must have the same length.
181
+
182
+ Returns dict with:
183
+ input_ids: (batch, max_alignments, max_seqlen)
184
+ attention_mask: (batch, max_alignments, max_seqlen)
185
+ """
186
+ if isinstance(msas[0], str):
187
+ msas = [msas]
188
+
189
+ max_rows = max(len(msa) for msa in msas)
190
+ max_seqlen = max(
191
+ len(self._tokenize_single(seq, add_special_tokens))
192
+ for msa in msas for seq in msa
193
+ )
194
+
195
+ pad_id = self.pad_token_id
196
+ batch_ids = []
197
+ batch_mask = []
198
+
199
+ for msa in msas:
200
+ msa_ids = []
201
+ msa_mask = []
202
+ for seq in msa:
203
+ ids = self._tokenize_single(seq, add_special_tokens)
204
+ if padding:
205
+ pad_len = max_seqlen - len(ids)
206
+ mask = [1] * len(ids) + [0] * pad_len
207
+ ids = ids + [pad_id] * pad_len
208
+ else:
209
+ mask = [1] * len(ids)
210
+ msa_ids.append(ids)
211
+ msa_mask.append(mask)
212
+
213
+ if padding:
214
+ pad_row = [pad_id] * max_seqlen
215
+ pad_mask_row = [0] * max_seqlen
216
+ while len(msa_ids) < max_rows:
217
+ msa_ids.append(pad_row)
218
+ msa_mask.append(pad_mask_row)
219
+
220
+ batch_ids.append(msa_ids)
221
+ batch_mask.append(msa_mask)
222
+
223
+ if return_tensors == "pt":
224
+ batch_ids = torch.tensor(batch_ids, dtype=torch.long)
225
+ batch_mask = torch.tensor(batch_mask, dtype=torch.long)
226
+ return {"input_ids": batch_ids, "attention_mask": batch_mask}
227
+
228
+ return {"input_ids": batch_ids, "attention_mask": batch_mask}
229
+
230
+ def decode(self, token_ids, skip_special_tokens=False, **kwargs):
231
+ if isinstance(token_ids, torch.Tensor):
232
+ token_ids = token_ids.tolist()
233
+ tokens = [self._convert_id_to_token(i) for i in token_ids]
234
+ if skip_special_tokens:
235
+ special = {self.cls_token, self.pad_token, self.eos_token,
236
+ self.unk_token, self.mask_token}
237
+ tokens = [t for t in tokens if t not in special]
238
+ return "".join(tokens)
239
+
240
+ def num_special_tokens_to_add(self, pair=False):
241
+ return 1
tokenizer_config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<cls>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "<eos>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "11": {
36
+ "content": "<mask>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "auto_map": {
45
+ "AutoTokenizer": ["tokenization_rnamsm.RNAMSMTokenizer", null]
46
+ },
47
+ "clean_up_tokenization_spaces": false,
48
+ "cls_token": "<cls>",
49
+ "eos_token": "<eos>",
50
+ "extra_special_tokens": {},
51
+ "mask_token": "<mask>",
52
+ "model_max_length": 1024,
53
+ "pad_token": "<pad>",
54
+ "tokenizer_class": "RNAMSMTokenizer",
55
+ "unk_token": "<unk>"
56
+ }
vocab.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "<cls>": 0,
3
+ "<pad>": 1,
4
+ "<eos>": 2,
5
+ "<unk>": 3,
6
+ "A": 4,
7
+ "G": 5,
8
+ "C": 6,
9
+ "U": 7,
10
+ "X": 8,
11
+ "N": 9,
12
+ "-": 10,
13
+ "<mask>": 11
14
+ }