---
language:
- rna
library_name: transformers
tags:
- RNA
- language-model
license: apache-2.0
---

# ERNIE-RNA

ERNIE-RNA is an RNA-specific large language model that incorporates RNA base-pairing potential as a recurrent 2D structural bias into each attention layer, enabling the model to capture secondary structure information during pretraining.

## Architecture

| Parameter | Value |
|---|---|
| Layers | 12 |
| Attention heads | 12 |
| Embedding dimension | 768 |
| FFN dimension | 3072 |
| Vocabulary size | 25 |
| Positional encoding | Sinusoidal (fairseq-style) |
| Architecture | Post-LN Transformer with recurrent 2D RNA pairing bias |
| Max sequence length | 1024 |

### Vocabulary

| Token | ID | Notes |
|---|---|---|
| `<cls>` | 0 | Prepended to every sequence |
| `<pad>` | 1 | Padding token |
| `<eos>` | 2 | Appended to every sequence |
| `<unk>` | 3 | Unknown token |
| G | 4 | |
| A | 5 | |
| U | 6 | T is silently mapped to U during tokenization |
| C | 7 | |
| N | 8 | Ambiguous nucleotide |
| Y-I | 9-20 | IUPAC ambiguity codes |
| madeupword0-2 | 21-23 | Filler tokens carried over from the original vocab (distinct from `<pad>`) |
| `<mask>` | 24 | MLM mask token |
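
To make the table concrete, here is a minimal sketch of how a sequence might be turned into token IDs. The `encode` helper and `VOCAB` dictionary are hypothetical and exist only for illustration; the real tokenizer is loaded with `AutoTokenizer` and may behave differently in details not described here (for example, handling of lowercase input). IUPAC codes (IDs 9-20) are omitted because the table does not list their individual IDs.

```python
from typing import List

# Token IDs copied from the vocabulary table; IUPAC codes (IDs 9-20) are
# left out because the table does not list them individually.
VOCAB = {
    "<cls>": 0, "<pad>": 1, "<eos>": 2, "<unk>": 3,
    "G": 4, "A": 5, "U": 6, "C": 7, "N": 8, "<mask>": 24,
}

def encode(seq: str) -> List[int]:
    """Hypothetical helper: map an RNA/DNA string to token IDs."""
    ids = [VOCAB["<cls>"]]                      # <cls> is prepended
    for base in seq:
        base = "U" if base == "T" else base     # T is mapped to U
        # Simplification: anything missing from VOCAB falls back to <unk>.
        ids.append(VOCAB.get(base, VOCAB["<unk>"]))
    ids.append(VOCAB["<eos>"])                  # <eos> is appended
    return ids

print(encode("AUGT"))  # [0, 5, 6, 4, 6, 2]
```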

### 2D RNA Pairing Bias

ERNIE-RNA computes a pairwise RNA base-pairing potential matrix from the input sequence at the start of each forward pass. This matrix (shape `[B, T, T, 1]`) is projected to `[B, H, T, T]` via a 2-layer MLP (1 -> 6 -> H, with GELU) and added to the attention logits in the first layer. The pre-softmax attention scores then become the updated 2D bias for the next layer, creating a recurrent structural information pathway across all 12 transformer layers.

Base-pairing scores: A-U = 2.0, G-C = 3.0, G-U wobble = 0.8.
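
As an illustration of the scores above, the following sketch fills a `T x T` matrix with the score for each pair of positions. It is a simplified reading of this section, not the model's actual code: the `pairing_matrix` helper is hypothetical, non-pairing combinations are assumed to score 0, and any extra handling in the real implementation (special tokens, padding, additional weighting) is not reproduced.

```python
# Pair scores as listed above; every other combination is assumed to score 0.
PAIR_SCORES = {
    ("A", "U"): 2.0, ("U", "A"): 2.0,
    ("G", "C"): 3.0, ("C", "G"): 3.0,
    ("G", "U"): 0.8, ("U", "G"): 0.8,   # wobble pair
}

def pairing_matrix(seq):
    """Hypothetical helper: T x T nested list of base-pairing scores."""
    seq = seq.replace("T", "U")          # T is treated as U
    return [[PAIR_SCORES.get((a, b), 0.0) for b in seq] for a in seq]

mat = pairing_matrix("GAUC")
for row in mat:
    print(row)
```

In the model, a matrix like this (with batch and channel dimensions added, giving `[B, T, T, 1]`) is what the MLP projects to one bias map per attention head.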

## Pretraining

- **Objective:** Masked language modeling (MLM) on RNA sequences
- **Data:** RNAcentral (non-redundant RNA sequences)
- **Source checkpoint:** `ERNIE-RNA_pretrain.pt`

### Checkpoint selection

This release uses the single pretrained checkpoint from the original repository as-is; no fine-tuned variants are included.

## Parity Verification

Hidden-state representations match the original implementation to within floating-point
tolerance (max abs diff = 1.82e-06) at all 13 representation levels (embedding + 12 transformer layers).
Verified on GPU with PyTorch 2.7 / CUDA 12.

Only `attn_implementation="eager"` is supported (see Implementation Notes).

## Related Models

See the full [ERNIE-RNA collection](https://huggingface.co/collections/Taykhoom/ernie-rna-6a20c1a8ea56c00a74e2dd93).

| Model | Notes |
|---|---|
| **[Taykhoom/ERNIE-RNA](https://huggingface.co/Taykhoom/ERNIE-RNA)** | **Pretrained model (this model)** |
| [Taykhoom/ERNIE-RNA-SS](https://huggingface.co/Taykhoom/ERNIE-RNA-SS) | SS fine-tuned (bpRNA-new), backbone only |
| [Taykhoom/ERNIE-RNA-MRL](https://huggingface.co/Taykhoom/ERNIE-RNA-MRL) | UTR MRL fine-tuned, backbone only |

## Usage

### Embedding generation

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/ERNIE-RNA", trust_remote_code=True)
model = AutoModel.from_pretrained("Taykhoom/ERNIE-RNA", trust_remote_code=True)
model.eval()

sequences = ["AUGCAUGCAUGC", "GGGGCCCCGGGG"]
enc = tokenizer(sequences, return_tensors="pt", padding=True)

with torch.no_grad():
    out = model(**enc)

cls_emb   = out.last_hidden_state[:, 0, :]   # (batch, 768) -- CLS token
token_emb = out.last_hidden_state             # (batch, seq_len, 768)

# Intermediate layers (also run under no_grad to avoid building an autograd graph)
with torch.no_grad():
    out_all = model(**enc, output_hidden_states=True)
layer6_emb = out_all.hidden_states[6]         # (batch, seq_len, 768)
```

### MLM logits

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/ERNIE-RNA", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained("Taykhoom/ERNIE-RNA", trust_remote_code=True)
model.eval()

enc = tokenizer(["AUG<mask>AUG"], return_tensors="pt")
with torch.no_grad():
    logits = model(**enc).logits   # (1, seq_len, 25)
```

### Fine-tuning

Use the CLS token embedding (`last_hidden_state[:, 0, :]`) as input to a prediction head for sequence-level tasks. For token-level tasks, use `last_hidden_state` directly.
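
A minimal sketch of such a sequence-level head is shown below. The `SequenceHead` class, its dropout rate, and the number of outputs are assumptions made for illustration, and a random tensor stands in for `model(**enc).last_hidden_state` so the snippet runs without downloading the model.

```python
import torch
from torch import nn

class SequenceHead(nn.Module):
    """Hypothetical prediction head that reads the CLS embedding."""

    def __init__(self, hidden_size: int = 768, num_outputs: int = 1, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_size, num_outputs)

    def forward(self, last_hidden_state: torch.Tensor) -> torch.Tensor:
        cls_emb = last_hidden_state[:, 0, :]        # (batch, hidden_size)
        return self.proj(self.dropout(cls_emb))     # (batch, num_outputs)

# Stand-in for model(**enc).last_hidden_state: (batch=2, seq_len=14, hidden=768)
hidden = torch.randn(2, 14, 768)
head = SequenceHead(num_outputs=3)
print(head(hidden).shape)  # torch.Size([2, 3])
```

For a token-level task, the same linear projection could be applied to every position of `last_hidden_state` instead of only position 0.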

## Implementation Notes

ERNIE-RNA's recurrent 2D bias is updated from the pre-softmax attention scores at every layer (the raw QK logits become the bias input for the next layer). Fused attention kernels (SDPA, FlashAttention) do not expose pre-softmax scores, so they cannot maintain this recurrent pathway. Only `attn_implementation="eager"` is supported; requesting `sdpa` or `flash_attention_2` raises a `ValueError`.

The `twod_proj` MLP is always run in float32 (matching the original) regardless of the model's compute dtype.
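
The recurrent pathway can be sketched in a few lines. This is a simplified illustration under stated assumptions, not the conversion code: the input tensor is used directly as Q, K, and V in place of learned projections; residuals, LayerNorm, FFN blocks, and padding masks are omitted; and the `1/sqrt(D)` scaling and the choice to carry the bias-inclusive logits forward are assumptions about details this card does not spell out.

```python
import math
import torch
from torch import nn

B, H, T, D = 1, 12, 6, 64     # batch, heads, tokens, per-head dim (768 / 12)
num_layers = 12

# 1 -> 6 -> H MLP with GELU, kept in float32 as noted above.
twod_proj = nn.Sequential(nn.Linear(1, 6), nn.GELU(), nn.Linear(6, H)).float()

with torch.no_grad():
    pair = torch.rand(B, T, T, 1)                  # stand-in pairing potentials
    bias = twod_proj(pair).permute(0, 3, 1, 2)     # (B, T, T, H) -> (B, H, T, T)

    x = torch.randn(B, H, T, D)                    # stand-in per-head activations
    for _ in range(num_layers):
        q, k, v = x, x, x                          # stand-in for learned Q/K/V projections
        # Pre-softmax logits: must be materialized so they can be reused below.
        scores = q @ k.transpose(-1, -2) / math.sqrt(D) + bias
        bias = scores                              # recurrent step: feeds the next layer
        x = torch.softmax(scores, dim=-1) @ v

print(bias.shape)  # torch.Size([1, 12, 6, 6])
```

A fused kernel such as `torch.nn.functional.scaled_dot_product_attention` returns only the attention output, so the `scores` tensor needed for the `bias = scores` step is never available to the caller.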

## Citation

```bibtex
@article{yin2025_ernierna,
  title   = {{ERNIE-RNA}: an {RNA} language model with structure-enhanced representations},
  author  = {Yin, Weijie and Zhang, Zhaoyu and He, Liang and Jiang, Rui and Zhang, Shuo and Liu, Gan and Zeng, Xuezhi and Zhao, Wen and Gao, Xiaowo},
  journal = {Nature Communications},
  volume  = {16},
  number  = {1},
  pages   = {8407},
  year    = {2025},
  doi     = {10.1038/s41467-025-64972-0}
}
```

## Credits

Original model and code by Yin et al. Source: [GitHub](https://github.com/Bruce-ywj/ERNIE-RNA).
The HF conversion code was authored primarily by [Claude Code](https://claude.ai/code)
and reviewed manually by Taykhoom Dalal.

## License

Apache 2.0, following the original repository.