RNAErnie

RNAErnie is a BERT-based RNA language model pretrained on RNACentral using a motif-aware masking strategy with type-guided fine-tuning. It uses a DNA-style vocabulary (T instead of U) and extends the token vocabulary with 28 ncRNA type labels to enable type-guided learning.

Architecture

| Parameter | Value |
|---|---|
| Layers | 12 |
| Attention heads | 12 |
| Embedding dimension | 768 |
| Intermediate size | 3072 |
| Vocabulary size | 39 |
| Positional encoding | Absolute learned |
| Architecture | Post-LN BERT / ERNIE |
| Max sequence length | 512 |
| Parameters | 86.7M (F32) |

Vocabulary: Special tokens [PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3, [MASK]=4, [DEL]=5, [IND]=6; ncRNA type labels at indices 7-34 (RNaseMRPRNA, RNasePRNA, SRPRNA, YRNA, antisenseRNA, autocatalyticallysplicedintron, guideRNA, hammerheadribozyme, lncRNA, miRNA, miscRNA, ncRNA, other, piRNA, premiRNA, precursorRNA, rRNA, ribozyme, sRNA, scRNA, scaRNA, siRNA, snRNA, snoRNA, tRNA, telomeraseRNA, tmRNA, vaultRNA); nucleotides A=35, T=36, C=37, G=38.

Tokenisation note: Input U is silently converted to T. The model was pretrained with DNA-style T notation.
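
The vocabulary layout and the U-to-T conversion can be illustrated in a few lines. The following is a minimal sketch, not the tokenizer shipped with the checkpoint: it rebuilds the 39-entry vocabulary from the indices listed above and maps U to T before lookup. The function name `encode`, the upper-casing of input, and the wrapping of each sequence in [CLS] ... [SEP] are assumptions based on the usual BERT convention.

```python
# Token order follows the indices documented above (0-6, 7-34, 35-38).
SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "[DEL]", "[IND]"]
NCRNA_TYPES = [
    "RNaseMRPRNA", "RNasePRNA", "SRPRNA", "YRNA", "antisenseRNA",
    "autocatalyticallysplicedintron", "guideRNA", "hammerheadribozyme",
    "lncRNA", "miRNA", "miscRNA", "ncRNA", "other", "piRNA", "premiRNA",
    "precursorRNA", "rRNA", "ribozyme", "sRNA", "scRNA", "scaRNA", "siRNA",
    "snRNA", "snoRNA", "tRNA", "telomeraseRNA", "tmRNA", "vaultRNA",
]
NUCLEOTIDES = ["A", "T", "C", "G"]

VOCAB = {tok: i for i, tok in enumerate(SPECIAL_TOKENS + NCRNA_TYPES + NUCLEOTIDES)}


def encode(sequence):
    """Map an RNA or DNA string to token ids (hypothetical helper)."""
    # Assumption: input is upper-cased; U is replaced by T as described above.
    normalised = sequence.upper().replace("U", "T")
    ids = [VOCAB["[CLS]"]]  # assumption: BERT-style [CLS] ... [SEP] wrapping
    ids += [VOCAB.get(base, VOCAB["[UNK]"]) for base in normalised]
    ids.append(VOCAB["[SEP]"])
    return ids


print(len(VOCAB))      # 39
print(encode("AUGC"))  # [2, 35, 36, 38, 37, 3]
```

Because of the conversion, `encode("AUGC")` and `encode("ATGC")` produce the same ids in this sketch.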

Pretraining

  • Objective: Masked language modelling (MLM) with motif-aware masking
  • Data: RNACentral (sequences with length <= 512)
  • Source checkpoint: model_state.pdparams from the original PaddlePaddle repository

Checkpoint selection

There is a single publicly released RNAErnie checkpoint (output/BERT,ERNIE,MOTIF,PROMPT/checkpoint_final/model_state.pdparams), corresponding to the BERT,ERNIE,MOTIF,PROMPT pretraining variant described in the paper.

Parity Verification

Hidden-state representations match to within a maximum absolute difference of 7e-6 at all 13 representation levels (embedding output plus 12 layers) when compared against a standalone pure-PyTorch reference. The reference implements the PaddlePaddle ERNIE forward pass directly from the raw .pdparams weights, without running PaddlePaddle. It uses PaddlePaddle's linear convention (x @ W, with weights stored as (in, out)) and loads weights from the original checkpoint file in the same way as the conversion script, so the comparison should be mathematically equivalent to a live PaddlePaddle run, although PaddlePaddle itself was not executed. Verified on GPU with PyTorch 2.7 / CUDA 12.
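
A level-by-level comparison of this kind can be sketched as follows. This is an illustration only, not the verification script used for the port: the helper name `max_abs_diff_per_level` is hypothetical, and the arrays are synthetic stand-ins for the 13 hidden-state tensors produced by the reference and by the ported model.

```python
import numpy as np


def max_abs_diff_per_level(reference_states, ported_states):
    """Return the max absolute difference for each representation level."""
    if len(reference_states) != len(ported_states):
        raise ValueError("both models must expose the same number of levels")
    return [
        float(np.max(np.abs(np.asarray(ref) - np.asarray(port))))
        for ref, port in zip(reference_states, ported_states)
    ]


# Synthetic stand-in for 13 levels (embedding output + 12 layers).
rng = np.random.default_rng(0)
reference = [rng.standard_normal((2, 16, 768)).astype(np.float32) for _ in range(13)]
ported = [h + np.float32(1e-6) for h in reference]  # simulated float32 noise

diffs = max_abs_diff_per_level(reference, ported)
assert len(diffs) == 13 and max(diffs) < 7e-6  # tolerance quoted above
```

With real models, the two lists would come from the `hidden_states` outputs of each implementation, converted to NumPy arrays.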

Note on weight conversion: PaddlePaddle stores nn.Linear weights as (in_features, out_features), the transpose of PyTorch's (out_features, in_features). All linear layer weights (attention projections, FFN, pooler, MLM transform) are transposed during conversion; embedding tables and bias vectors are copied as-is.
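
The transposition rule can be sketched with NumPy. This is a minimal illustration under stated assumptions rather than the actual conversion script: the key names below are hypothetical, and the rule used to detect linear weights (2-D and not an embedding table) is a simplification chosen for the example.

```python
import numpy as np


def convert_paddle_state(paddle_state):
    """Transpose 2-D linear weights; copy embeddings and 1-D parameters unchanged."""
    converted = {}
    for name, array in paddle_state.items():
        array = np.asarray(array)
        # Assumption: every 2-D tensor that is not an embedding table is a linear weight.
        is_linear_weight = array.ndim == 2 and "embeddings" not in name
        converted[name] = array.T.copy() if is_linear_weight else array.copy()
    return converted


# Hypothetical key names, for illustration only.
rng = np.random.default_rng(0)
paddle_state = {
    "embeddings.word_embeddings.weight": rng.standard_normal((39, 768)),
    "encoder.layers.0.linear1.weight": rng.standard_normal((768, 3072)),  # (in, out)
    "encoder.layers.0.linear1.bias": rng.standard_normal(3072),
}
torch_style = convert_paddle_state(paddle_state)

x = rng.standard_normal((4, 768))
paddle_out = x @ paddle_state["encoder.layers.0.linear1.weight"]  # Paddle: x @ W
torch_out = x @ torch_style["encoder.layers.0.linear1.weight"].T  # PyTorch: x @ W.T
assert np.allclose(paddle_out, torch_out)
```

Both conventions compute the same product; only the stored layout of the weight matrix differs.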

Implementation Notes

The original implementation uses PaddlePaddle's ERNIE/TransformerEncoderLayer backbone. This HF port re-implements the identical Post-LN BERT architecture in pure PyTorch and adds attn_implementation="sdpa" and attn_implementation="flash_attention_2" support, which were not part of the original codebase.

Related Models

See the full RNAErnie collection.

| Model | Context | Training data | Notes |
|---|---|---|---|
| RNAErnie | 512 | RNACentral (<= 512 nt) | This model; PaddlePaddle ERNIE backbone |
| RNAErnie2 | 2048 | RNACentral v22 (~31M seqs) | Retrained; PyTorch BERT |

Usage

Embedding generation

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
model = AutoModel.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
model.eval()

sequences = ["AUGCAUGCAUGC", "GCUGCAUGCUAGC"]
enc = tokenizer(sequences, return_tensors="pt", padding=True)

with torch.no_grad():
    out = model(**enc)

cls_emb   = out.last_hidden_state[:, 0, :]  # (batch, 768) -- CLS token
token_emb = out.last_hidden_state           # (batch, seq_len, 768)

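# Intermediate layers: hidden_states[0] is the embedding output, [1..12] are the layer outputs
with torch.no_grad():
    out_all = model(**enc, output_hidden_states=True)
layer6_emb = out_all.hidden_states[6]       # (batch, seq_len, 768) -- output of layer 6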

MLM logits

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
model.eval()

enc = tokenizer(["ATG[MASK]ATG"], return_tensors="pt")
with torch.no_grad():
    logits = model(**enc).logits  # (1, seq_len, 39)
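
To turn the logits into a predicted base at each masked position, one option is to take the argmax over the four nucleotide columns only, so that special tokens and ncRNA type labels cannot be selected. The sketch below assumes the token ids listed in the Vocabulary section; the helper name `predict_masked_bases` is hypothetical, and restricting the argmax to nucleotides is a design choice of this example, not something prescribed by the model.

```python
import numpy as np

MASK_ID = 4
NUCLEOTIDE_IDS = {"A": 35, "T": 36, "C": 37, "G": 38}


def predict_masked_bases(logits, input_ids):
    """Return {position: base} for every [MASK] position in one sequence.

    logits: array of shape (seq_len, 39); input_ids: array of shape (seq_len,).
    """
    logits = np.asarray(logits)
    input_ids = np.asarray(input_ids)
    bases = list(NUCLEOTIDE_IDS)
    columns = list(NUCLEOTIDE_IDS.values())
    predictions = {}
    for pos in np.flatnonzero(input_ids == MASK_ID):
        # Argmax over the nucleotide columns only.
        best = int(np.argmax(logits[pos, columns]))
        predictions[int(pos)] = bases[best]
    return predictions
```

With the snippet above, this could be called as `predict_masked_bases(logits[0].numpy(), enc["input_ids"][0].numpy())`.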

SDPA / Flash Attention 2

from transformers import AutoModel

model = AutoModel.from_pretrained(
    "Taykhoom/RNAErnie",
    attn_implementation="sdpa",   # or "flash_attention_2"
    trust_remote_code=True,
)

Fine-tuning

Fine-tuning follows standard Hugging Face conventions. For sequence-level tasks, use the CLS token embedding (last_hidden_state[:, 0, :]) as input to a classification head. For type-guided fine-tuning (as in the paper), prepend the ncRNA type label token to the input.
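
A sequence-level classification head on top of the CLS embedding might look like the sketch below. It is an illustration under assumptions, not code from the original repository: the class name `ClsClassifier`, the dropout rate of 0.1, and the single linear layer are choices made for this example, and `encoder` is expected to be any module whose output exposes `last_hidden_state`, such as the `AutoModel` loaded above.

```python
import torch
from torch import nn


class ClsClassifier(nn.Module):
    """Linear classification head over the CLS embedding (hypothetical example)."""

    def __init__(self, encoder, hidden_size=768, num_labels=2, dropout=0.1):
        super().__init__()
        self.encoder = encoder
        self.dropout = nn.Dropout(dropout)  # assumption: 0.1 is illustrative
        self.head = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls_emb = out.last_hidden_state[:, 0, :]  # (batch, hidden_size)
        return self.head(self.dropout(cls_emb))   # (batch, num_labels)
```

The returned logits can be passed to `torch.nn.functional.cross_entropy` together with integer class labels during training.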

Citation

@article{wang2024_rnaernie,
  title   = {Multi-purpose {RNA} language modelling with motif-aware pretraining and type-guided fine-tuning},
  author  = {Wang, Ning and Bian, Jiang and Li, Yuchen and Li, Xuhong and Mumtaz, Shahid and Kong, Linghe and Xiong, Haoyi},
  journal = {Nature Machine Intelligence},
  volume  = {6},
  pages   = {548--557},
  year    = {2024},
  doi     = {10.1038/s42256-024-00836-4}
}

Credits

Original model and code by Wang et al. Source: GitHub. The HF conversion code was authored primarily by Claude Code and reviewed manually by Taykhoom Dalal.

License

Apache 2.0, following the original repository.
