---
language:
- rna
library_name: transformers
tags:
- RNA
- language-model
license: apache-2.0
---
# RNAErnie2
RNAErnie2 is a BERT-based RNA language model trained from scratch on a large-scale RNA
sequence dataset with up to 2048-nucleotide context length. It is a retrained successor
to RNAErnie that replaces the PaddlePaddle-based ERNIE backbone with a standard PyTorch
BERT architecture, extends the pretraining corpus to RNACentral v22 (~31M sequences,
length <= 2048), and switches to an RNA-native vocabulary (U instead of T).
## Architecture
| Parameter | Value |
|---|---|
| Layers | 12 |
| Attention heads | 12 |
| Embedding dimension | 768 |
| Intermediate size | 3072 |
| Vocabulary size | 11 |
| Positional encoding | Absolute learned |
| Architecture | Post-LN BERT / BertForMaskedLM |
| Max sequence length | 2048 |
**Vocabulary:** `[PAD]=0, [UNK]=1, [CLS]=2, [EOS]=3, [SEP]=4, [MASK]=5, A=6, U=7, C=8, G=9, N=10`
## Pretraining
- **Objective:** Masked language modelling (MLM)
- **Data:** RNACentral v22, ~31 million RNA sequences with length <= 2048
- **Source checkpoint:** [`LLM-EDA/RNAErnie`](https://huggingface.co/LLM-EDA/RNAErnie) on HuggingFace Hub
- **Tokenisation note:** Sequences use U (not T). Input T is silently converted to U by the tokenizer.
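As a rough illustration of the vocabulary and the T-to-U note above, the following sketch maps a raw sequence to token IDs in plain Python. `nucleotides_to_ids` is a hypothetical helper, not part of the released tokenizer. It adds no special tokens, and sending unrecognised characters to `[UNK]` is an assumption of the sketch.

```python
# Token IDs copied from the vocabulary listed in this model card.
VOCAB = {
    "[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[EOS]": 3, "[SEP]": 4, "[MASK]": 5,
    "A": 6, "U": 7, "C": 8, "G": 9, "N": 10,
}


def nucleotides_to_ids(sequence):
    """Hypothetical helper: map nucleotides to IDs, treating T as U.

    Assumption: characters outside the vocabulary map to [UNK].
    No [CLS]/[SEP]/[EOS] framing is added here.
    """
    rna = sequence.replace("T", "U")
    return [VOCAB.get(base, VOCAB["[UNK]"]) for base in rna]


print(nucleotides_to_ids("AUGC"))
print(nucleotides_to_ids("ATGC"))  # same IDs as above, since T is read as U
```

For real inputs, the released tokenizer remains the reference; this sketch only makes the ID mapping concrete.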
### Checkpoint selection
There is a single publicly released RNAErnie2 checkpoint. The weights are taken from
[`LLM-EDA/RNAErnie`](https://huggingface.co/LLM-EDA/RNAErnie) with one minor
adjustment: `cls.predictions.decoder.bias` is stored explicitly (it was implicitly
tied to `cls.predictions.bias` in the original save and was absent from the file).
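The adjustment described above can be pictured as adding one missing key to the checkpoint's state dict. The sketch below is not the conversion script used for this repository; `add_explicit_decoder_bias` is a hypothetical helper, and a plain dictionary of lists stands in for the real tensors.

```python
import copy

TIED_BIAS_KEY = "cls.predictions.bias"
DECODER_BIAS_KEY = "cls.predictions.decoder.bias"


def add_explicit_decoder_bias(state_dict):
    """Return a copy of state_dict that also stores the decoder bias.

    Hypothetical helper; works on any mapping of parameter names to values.
    """
    patched = dict(state_dict)
    if DECODER_BIAS_KEY not in patched:
        # Copy the tied value instead of aliasing it.
        patched[DECODER_BIAS_KEY] = copy.deepcopy(patched[TIED_BIAS_KEY])
    return patched


# Toy stand-in for a checkpoint: one bias value per vocabulary entry (11).
original = {TIED_BIAS_KEY: [0.0] * 11}
patched = add_explicit_decoder_bias(original)
print(sorted(patched))
```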
## Parity Verification
Hidden-state representations and MLM logits were verified to match the original
`BertForMaskedLM` (max abs diff < 2e-5) at all 13 representation levels
(embedding output + 12 layers). Verification was run on GPU with PyTorch 2.7 / CUDA 12.
## Implementation Notes
The model uses a custom BERT implementation (`modeling_rnaernie2.py`) with eager,
SDPA, and Flash Attention 2 attention backends, following the architecture of
[`Taykhoom/BERT-updated`](https://huggingface.co/Taykhoom/BERT-updated).
The original [`LLM-EDA/RNAErnie`](https://huggingface.co/LLM-EDA/RNAErnie) used
the standard HF BERT implementation with no custom attention backends.
## Related Models
See the full [RNAErnie collection](https://huggingface.co/collections/Taykhoom/rnaernie-6a219927c11fdcccedb243db).
| Model | Context | Training data | Notes |
|---|---|---|---|
| [RNAErnie](https://huggingface.co/Taykhoom/RNAErnie) | 512 | RNACentral (length <= 512 nt) | Original; PaddlePaddle backbone |
| **[RNAErnie2](https://huggingface.co/Taykhoom/RNAErnie2)** | **2048** | **RNACentral v22 (~31M seqs)** | **This model; PyTorch BERT** |
## Usage
### Embedding generation
```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNAErnie2", trust_remote_code=True)
model = AutoModel.from_pretrained("Taykhoom/RNAErnie2", trust_remote_code=True)
model.eval()

sequences = ["AUGCAUGCAUGC", "GCUGCAUGCUAGC"]
enc = tokenizer(sequences, return_tensors="pt", padding=True)

with torch.no_grad():
    out = model(**enc)
cls_emb = out.last_hidden_state[:, 0, :]  # (batch, 768) -- CLS token
token_emb = out.last_hidden_state         # (batch, seq_len, 768)

# Intermediate layers: hidden_states[0] is the embedding output,
# so index 6 is the output of layer 6.
with torch.no_grad():
    out_all = model(**enc, output_hidden_states=True)
layer6_emb = out_all.hidden_states[6]     # (batch, seq_len, 768)
```
### MLM logits
```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNAErnie2", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained("Taykhoom/RNAErnie2", trust_remote_code=True)
model.eval()

enc = tokenizer(["AUG[MASK]AUG"], return_tensors="pt")
with torch.no_grad():
    logits = model(**enc).logits  # (1, seq_len, 11)
```
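To read a prediction off the logits, one option is to take the highest-scoring nucleotide at a masked position. The sketch below does this on a plain list of 11 scores (for example `logits[0, pos].tolist()`). `best_nucleotide` is a hypothetical helper, and restricting the choice to A, U, C, G is an assumption of the sketch rather than something the model enforces.

```python
# ID -> token order copied from the vocabulary in this model card.
ID_TO_TOKEN = ["[PAD]", "[UNK]", "[CLS]", "[EOS]", "[SEP]", "[MASK]",
               "A", "U", "C", "G", "N"]
NUCLEOTIDE_IDS = [6, 7, 8, 9]  # A, U, C, G


def best_nucleotide(position_logits):
    """Hypothetical helper: pick the top-scoring nucleotide from 11 logits."""
    if len(position_logits) != len(ID_TO_TOKEN):
        raise ValueError("expected one logit per vocabulary entry")
    best_id = max(NUCLEOTIDE_IDS, key=lambda i: position_logits[i])
    return ID_TO_TOKEN[best_id]


# Made-up scores for illustration only; not real model output.
example_logits = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.3, 2.5, 0.4, 0.0]
print(best_nucleotide(example_logits))
```

With the real model, masked positions can be located by comparing `enc["input_ids"]` with the `[MASK]` ID (5 in the vocabulary above).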
### SDPA / Flash Attention 2
```python
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "Taykhoom/RNAErnie2",
    attn_implementation="sdpa",  # or "flash_attention_2"
    trust_remote_code=True,
)
```
### Fine-tuning
The model follows standard HF fine-tuning conventions. For sequence-level tasks, use
the CLS token embedding (`last_hidden_state[:, 0, :]`) as input to a classification head.
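A minimal sketch of such a head is shown below, assuming a dropout layer followed by a single linear layer. `SequenceClassificationHead`, `num_labels=2`, and `dropout=0.1` are illustrative choices, not part of this repository, and a random tensor stands in for the encoder output.

```python
import torch
from torch import nn


class SequenceClassificationHead(nn.Module):
    """Hypothetical head: dropout + linear layer on the CLS embedding."""

    def __init__(self, hidden_size=768, num_labels=2, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, last_hidden_state):
        cls_emb = last_hidden_state[:, 0, :]           # (batch, hidden_size)
        return self.classifier(self.dropout(cls_emb))  # (batch, num_labels)


# Random tensor standing in for model(**enc).last_hidden_state.
fake_hidden = torch.randn(2, 15, 768)
head = SequenceClassificationHead()
logits = head(fake_hidden)
print(logits.shape)
```

In a real fine-tuning loop, the head's output would be passed to a loss such as `nn.CrossEntropyLoss` and trained jointly with, or on top of, the encoder.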
## Citation
```bibtex
@article{wang2024_rnaernie,
title = {Multi-purpose {RNA} language modelling with motif-aware pretraining and type-guided fine-tuning},
author = {Wang, Ning and Bian, Jiang and Li, Yuchen and Li, Xuhong and Mumtaz, Shahid and Kong, Linghe and Xiong, Haoyi},
journal = {Nature Machine Intelligence},
volume = {6},
pages = {548--557},
year = {2024},
doi = {10.1038/s42256-024-00836-4}
}
```
## Credits
Original model and code by Wang et al. Source: [GitHub](https://github.com/CatIIIIIIII/RNAErnie) /
[HuggingFace](https://huggingface.co/LLM-EDA/RNAErnie).
The HF conversion code was authored primarily by [Claude Code](https://claude.ai/code)
and reviewed manually by Taykhoom Dalal.
## License
Apache 2.0, following the original repository.