---
language:
- rna
library_name: transformers
tags:
- RNA
- language-model
license: apache-2.0
---
# RNAErnie
RNAErnie is a BERT-based RNA language model pretrained on RNACentral using a
motif-aware masking strategy with type-guided fine-tuning. It uses a DNA-style
vocabulary (T instead of U) and extends the token vocabulary with 28 ncRNA
type labels to enable type-guided learning.
## Architecture
| Parameter | Value |
|---|---|
| Layers | 12 |
| Attention heads | 12 |
| Embedding dimension | 768 |
| Intermediate size | 3072 |
| Vocabulary size | 39 |
| Positional encoding | Absolute learned |
| Architecture | Post-LN BERT / ERNIE |
| Max sequence length | 512 |

**Vocabulary:** Special tokens `[PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3, [MASK]=4, [DEL]=5, [IND]=6`;
ncRNA type labels at indices 7-34 (RNaseMRPRNA, RNasePRNA, SRPRNA, YRNA, antisenseRNA,
autocatalyticallysplicedintron, guideRNA, hammerheadribozyme, lncRNA, miRNA, miscRNA,
ncRNA, other, piRNA, premiRNA, precursorRNA, rRNA, ribozyme, sRNA, scRNA, scaRNA,
siRNA, snRNA, snoRNA, tRNA, telomeraseRNA, tmRNA, vaultRNA);
nucleotides `A=35, T=36, C=37, G=38`.

**Tokenisation note:** Input U is silently converted to T. The model was pretrained
with DNA-style T notation.
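
The vocabulary layout and the U-to-T conversion described above can be sketched in plain Python. This is only an illustration of the index layout listed in this card, not the tokenizer shipped with the model: the names `VOCAB` and `nucleotides_to_ids` are hypothetical, and upper-casing the input and mapping unknown characters to `[UNK]` are assumptions. The sketch also does not add `[CLS]` or `[SEP]`.

```python
# Token order follows the index layout given in this model card.
SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "[DEL]", "[IND]"]
NCRNA_TYPES = [
    "RNaseMRPRNA", "RNasePRNA", "SRPRNA", "YRNA", "antisenseRNA",
    "autocatalyticallysplicedintron", "guideRNA", "hammerheadribozyme",
    "lncRNA", "miRNA", "miscRNA", "ncRNA", "other", "piRNA", "premiRNA",
    "precursorRNA", "rRNA", "ribozyme", "sRNA", "scRNA", "scaRNA", "siRNA",
    "snRNA", "snoRNA", "tRNA", "telomeraseRNA", "tmRNA", "vaultRNA",
]
NUCLEOTIDES = ["A", "T", "C", "G"]

# Hypothetical lookup table: 7 special tokens + 28 type labels + 4 nucleotides.
VOCAB = {tok: i for i, tok in enumerate(SPECIAL_TOKENS + NCRNA_TYPES + NUCLEOTIDES)}


def nucleotides_to_ids(sequence):
    """Map a raw RNA/DNA string to token ids, converting U to T first."""
    # Assumption: input is upper-cased and unknown characters become [UNK].
    normalised = sequence.upper().replace("U", "T")
    return [VOCAB.get(base, VOCAB["[UNK]"]) for base in normalised]


ids = nucleotides_to_ids("AUGC")
```

With this layout, an RNA string written with U and the same string written with T produce the same ids.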
## Pretraining
- **Objective:** Masked language modelling (MLM) with motif-aware masking
- **Data:** RNACentral (sequences with length <= 512)
- **Source checkpoint:** `model_state.pdparams` from the original PaddlePaddle repository
### Checkpoint selection
There is a single publicly released RNAErnie checkpoint
(`output/BERT,ERNIE,MOTIF,PROMPT/checkpoint_final/model_state.pdparams`),
corresponding to the `BERT,ERNIE,MOTIF,PROMPT` pretraining variant described in the
paper.
## Parity Verification
Hidden-state representations were verified to match (max abs diff < 7e-6) at all
13 representation levels (embedding + 12 layers) against a standalone
pure-PyTorch reference that implements the PaddlePaddle ERNIE forward pass
directly from the raw `.pdparams` weights, without running PaddlePaddle.
The reference uses PaddlePaddle's linear convention (`x @ W`, weight stored
`(in, out)`) and loads weights from the original checkpoint file identically to
the conversion script, so the comparison is mathematically equivalent to a live
PaddlePaddle run. Verified on GPU with PyTorch 2.7 / CUDA 12.
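
A per-level check of this kind can be sketched as follows. This is a minimal illustration using NumPy and synthetic arrays, not the verification script used for this port. The function name `max_abs_diff_per_level` is hypothetical, and the random arrays stand in for the 13 hidden-state tensors of each implementation.

```python
import numpy as np


def max_abs_diff_per_level(reference_states, ported_states):
    """Return the maximum absolute difference for each representation level."""
    if len(reference_states) != len(ported_states):
        raise ValueError("both models must expose the same number of levels")
    return [
        float(np.max(np.abs(np.asarray(ref, dtype=np.float64)
                            - np.asarray(port, dtype=np.float64))))
        for ref, port in zip(reference_states, ported_states)
    ]


# Synthetic stand-ins: 13 levels (embedding + 12 layers), shape (batch, seq_len, 768).
rng = np.random.default_rng(0)
reference = [rng.standard_normal((2, 14, 768)).astype(np.float32) for _ in range(13)]
ported = [ref + np.float32(1e-6) for ref in reference]  # tiny artificial drift

diffs = max_abs_diff_per_level(reference, ported)
parity_ok = all(d < 7e-6 for d in diffs)  # tolerance quoted in this card
```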
**Note on weight conversion:** PaddlePaddle stores `nn.Linear` weights as
`(in_features, out_features)`, the transpose of PyTorch's `(out_features, in_features)`.
All linear layer weights (attention projections, FFN, pooler, MLM transform) are
transposed during conversion; embedding tables and bias vectors are copied as-is.
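
The transpose rule can be illustrated with a small NumPy sketch. The parameter names (`ffn.weight`, `ffn.bias`, `word_embeddings.weight`) and the helper `convert_paddle_state` are hypothetical and do not reflect the key names in the real checkpoint or the actual conversion script. The sketch only shows that `x @ W` with `W` stored `(in, out)` equals `x @ W_t.T` once the weight is stored `(out, in)`.

```python
import numpy as np


def convert_paddle_state(paddle_state, linear_weight_keys):
    """Transpose linear weights; copy every other tensor unchanged."""
    converted = {}
    for name, tensor in paddle_state.items():
        if name in linear_weight_keys:
            # (in_features, out_features) -> (out_features, in_features)
            converted[name] = np.ascontiguousarray(tensor.T)
        else:
            # Embedding tables and bias vectors are copied as-is.
            converted[name] = tensor.copy()
    return converted


rng = np.random.default_rng(0)
paddle_state = {
    "ffn.weight": rng.standard_normal((768, 3072)),            # stored (in, out)
    "ffn.bias": rng.standard_normal(3072),
    "word_embeddings.weight": rng.standard_normal((39, 768)),  # embedding table
}
torch_style = convert_paddle_state(paddle_state, {"ffn.weight"})

x = rng.standard_normal((2, 768))
paddle_out = x @ paddle_state["ffn.weight"] + paddle_state["ffn.bias"]   # x @ W
torch_out = x @ torch_style["ffn.weight"].T + torch_style["ffn.bias"]    # x @ W_t.T
```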
## Implementation Notes
The original implementation uses PaddlePaddle's ERNIE/TransformerEncoderLayer
backbone. This HF port re-implements the identical Post-LN BERT architecture in
pure PyTorch and adds `attn_implementation="sdpa"` and
`attn_implementation="flash_attention_2"` support, which were not part of the
original codebase.
## Related Models
See the full [RNAErnie collection](https://huggingface.co/collections/Taykhoom/rnaernie-6a219927c11fdcccedb243db).
| Model | Context | Training data | Notes |
|---|---|---|---|
| **[RNAErnie](https://huggingface.co/Taykhoom/RNAErnie)** | **512** | **RNACentral (nts<=512)** | **This model; PaddlePaddle ERNIE backbone** |
| [RNAErnie2](https://huggingface.co/Taykhoom/RNAErnie2) | 2048 | RNACentral v22 (~31M seqs) | Retrained; PyTorch BERT |
## Usage
### Embedding generation
```python
import torch
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
model = AutoModel.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
model.eval()
sequences = ["AUGCAUGCAUGC", "GCUGCAUGCUAGC"]
enc = tokenizer(sequences, return_tensors="pt", padding=True)
with torch.no_grad():
    out = model(**enc)
cls_emb = out.last_hidden_state[:, 0, :]  # (batch, 768) -- CLS token
token_emb = out.last_hidden_state         # (batch, seq_len, 768)
# Intermediate layers: hidden_states holds the embedding output plus 12 layer outputs
with torch.no_grad():
    out_all = model(**enc, output_hidden_states=True)
layer6_emb = out_all.hidden_states[6]     # (batch, seq_len, 768)
```
### MLM logits
```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
model.eval()
enc = tokenizer(["ATG[MASK]ATG"], return_tensors="pt")
with torch.no_grad():
    logits = model(**enc).logits  # (1, seq_len, 39)
```
### SDPA / Flash Attention 2
```python
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "Taykhoom/RNAErnie",
    attn_implementation="sdpa",  # or "flash_attention_2"
    trust_remote_code=True,
)
```
### Fine-tuning
Fine-tuning follows standard Hugging Face conventions. For sequence-level tasks, use
the CLS token embedding (`last_hidden_state[:, 0, :]`) as input to a classification
head. For type-guided fine-tuning (as in the paper), prepend the ncRNA type label
token to the input.
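
A sequence-level head on top of the CLS embedding might look like the sketch below. It is a minimal illustration, not the fine-tuning code from the paper or the original repository: the class name `ClsHead`, the dropout rate, and the number of labels are hypothetical, and a random tensor stands in for the model's `last_hidden_state` so the snippet runs without downloading weights.

```python
import torch
from torch import nn


class ClsHead(nn.Module):
    """Hypothetical linear classification head over the CLS token embedding."""

    def __init__(self, hidden_size=768, num_labels=2, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, last_hidden_state):
        cls_emb = last_hidden_state[:, 0, :]  # (batch, hidden_size)
        return self.classifier(self.dropout(cls_emb))


head = ClsHead(num_labels=3)
# Stand-in for model(**enc).last_hidden_state with shape (batch, seq_len, 768).
fake_hidden = torch.randn(2, 14, 768)
logits = head(fake_hidden)  # (batch, num_labels)
```

In practice, `fake_hidden` would be replaced by the encoder output, and the head would be trained jointly with or on top of the encoder.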
## Citation
```bibtex
@article{wang2024_rnaernie,
title = {Multi-purpose {RNA} language modelling with motif-aware pretraining and type-guided fine-tuning},
author = {Wang, Ning and Bian, Jiang and Li, Yuchen and Li, Xuhong and Mumtaz, Shahid and Kong, Linghe and Xiong, Haoyi},
journal = {Nature Machine Intelligence},
volume = {6},
pages = {548--557},
year = {2024},
doi = {10.1038/s42256-024-00836-4}
}
```
## Credits
Original model and code by Wang et al. Source: [GitHub](https://github.com/CatIIIIIIII/RNAErnie).
The HF conversion code was authored primarily by [Claude Code](https://claude.ai/code)
and reviewed manually by Taykhoom Dalal.
## License
Apache 2.0, following the original repository.