metadata
language:
  - rna
library_name: transformers
tags:
  - RNA
  - language-model
license: apache-2.0

RNAErnie

RNAErnie is a BERT-based RNA language model pretrained on RNACentral with a motif-aware masking strategy and designed for type-guided fine-tuning. It uses a DNA-style vocabulary (T instead of U) and extends the token vocabulary with 28 ncRNA type labels to enable type-guided learning.

Architecture

| Parameter | Value |
| --- | --- |
| Layers | 12 |
| Attention heads | 12 |
| Embedding dimension | 768 |
| Intermediate size | 3072 |
| Vocabulary size | 39 |
| Positional encoding | Absolute learned |
| Architecture | Post-LN BERT / ERNIE |
| Max sequence length | 512 |

Vocabulary: Special tokens [PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3, [MASK]=4, [DEL]=5, [IND]=6; ncRNA type labels at indices 7-34 (RNaseMRPRNA, RNasePRNA, SRPRNA, YRNA, antisenseRNA, autocatalyticallysplicedintron, guideRNA, hammerheadribozyme, lncRNA, miRNA, miscRNA, ncRNA, other, piRNA, premiRNA, precursorRNA, rRNA, ribozyme, sRNA, scRNA, scaRNA, siRNA, snRNA, snoRNA, tRNA, telomeraseRNA, tmRNA, vaultRNA); nucleotides A=35, T=36, C=37, G=38.

Tokenisation note: Input U is silently converted to T. The model was pretrained with DNA-style T notation.
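The fixed indices above are enough to sketch how a sequence maps to input IDs. The following is a minimal illustration, not the tokenizer shipped with the model: `encode` is a hypothetical helper, and both the [CLS] ... [SEP] wrapping and the [UNK] fallback for unrecognised characters are assumptions.

```python
# Minimal sketch of the ID mapping described above; NOT the shipped tokenizer.
SPECIAL = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "[MASK]": 4}
NUCLEOTIDES = {"A": 35, "T": 36, "C": 37, "G": 38}


def encode(sequence: str) -> list[int]:
    """Hypothetical helper: convert an RNA/DNA string to input IDs."""
    # U is converted to T because the model was pretrained with T notation.
    normalised = sequence.replace("U", "T")
    # Assumption: characters outside A/T/C/G fall back to [UNK].
    body = [NUCLEOTIDES.get(base, SPECIAL["[UNK]"]) for base in normalised]
    # Assumption: the sequence is wrapped as [CLS] ... [SEP].
    return [SPECIAL["[CLS]"], *body, SPECIAL["[SEP]"]]


ids = encode("AUGC")
print(ids)
```

Under these assumptions, `encode("AUGC")` and `encode("ATGC")` produce the same IDs, which is the practical consequence of the U-to-T conversion.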

Pretraining

  • Objective: Masked language modelling (MLM) with motif-aware masking
  • Data: RNACentral (sequences with length <= 512)
  • Source checkpoint: model_state.pdparams from the original PaddlePaddle repository

Checkpoint selection

There is a single publicly released RNAErnie checkpoint (output/BERT,ERNIE,MOTIF,PROMPT/checkpoint_final/model_state.pdparams), corresponding to the BERT,ERNIE,MOTIF,PROMPT pretraining variant described in the paper.

Parity Verification

Hidden-state representations match to within a maximum absolute difference of 7e-6 at all 13 representation levels (embedding output + 12 layers) when compared against a standalone pure-PyTorch reference. The reference implements the PaddlePaddle ERNIE forward pass directly from the raw .pdparams weights, without running PaddlePaddle. It uses PaddlePaddle's linear convention (x @ W, with weights stored as (in, out)) and loads weights from the original checkpoint file in the same way as the conversion script, so the comparison is mathematically equivalent to one against a live PaddlePaddle run. Verified on GPU with PyTorch 2.7 / CUDA 12.
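A per-level check of this kind can be sketched as follows. This is an illustration on toy nested lists, not the verification script that was actually used; `max_abs_diff` and `check_parity` are hypothetical names, and the 7e-6 tolerance is taken from the figure quoted above.

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equally shaped nested lists."""
    if isinstance(a, (int, float)):
        return abs(a - b)
    if len(a) != len(b):
        raise ValueError("shape mismatch")
    return max((max_abs_diff(x, y) for x, y in zip(a, b)), default=0.0)


def check_parity(reference_levels, ported_levels, tol=7e-6):
    """Return the max abs diff per level; raise if any level reaches the tolerance."""
    if len(reference_levels) != len(ported_levels):
        raise ValueError("different number of representation levels")
    diffs = [max_abs_diff(r, p) for r, p in zip(reference_levels, ported_levels)]
    for level, diff in enumerate(diffs):
        if diff >= tol:
            raise AssertionError(f"level {level}: max abs diff {diff:.2e} >= {tol:.0e}")
    return diffs


# Toy data: 13 levels (embedding + 12 layers), each 2 tokens x 3 features.
reference = [[[0.1 * level, 0.2, 0.3], [0.4, 0.5, 0.6]] for level in range(13)]
# Simulate a port whose outputs differ from the reference by about 1e-6.
ported = [[[v + 1e-6 for v in row] for row in level] for level in reference]

diffs = check_parity(reference, ported)
print(len(diffs), max(diffs) < 7e-6)
```

With real models, each level would be one entry of `hidden_states` converted to nested lists (or compared directly with tensor operations).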

Note on weight conversion: PaddlePaddle stores nn.Linear weights as (in_features, out_features), the transpose of PyTorch's (out_features, in_features). All linear layer weights (attention projections, FFN, pooler, MLM transform) are transposed during conversion; embedding tables and bias vectors are copied as-is.
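The reason for the transpose can be shown with a small pure-Python sketch. The helpers below (`paddle_linear`, `torch_linear`, `transpose`) are hypothetical stand-ins written with nested lists. They only mimic the two storage conventions and are not code from the conversion script.

```python
def transpose(w):
    """Swap the two axes of a nested-list matrix."""
    return [list(col) for col in zip(*w)]


def paddle_linear(x, w_in_out, bias):
    """PaddlePaddle convention: y = x @ W + b, with W stored as (in, out)."""
    return [
        [sum(xi * wi for xi, wi in zip(row, col)) + b
         for col, b in zip(zip(*w_in_out), bias)]
        for row in x
    ]


def torch_linear(x, w_out_in, bias):
    """PyTorch convention: y = x @ W^T + b, with W stored as (out, in)."""
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) + b
         for w_row, b in zip(w_out_in, bias)]
        for row in x
    ]


x = [[1.0, 2.0, 3.0]]                              # (batch=1, in=3)
w_paddle = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0]]   # (in=3, out=2)
bias = [0.5, -0.5]                                 # copied as-is

w_torch = transpose(w_paddle)                      # (out=2, in=3): the conversion step
y_paddle = paddle_linear(x, w_paddle, bias)
y_torch = torch_linear(x, w_torch, bias)
print(y_paddle == y_torch)
```

Skipping the transpose would either fail on a shape mismatch (non-square layers such as the FFN) or silently compute the wrong projection (square layers such as the attention projections).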

Implementation Notes

The original implementation uses PaddlePaddle's ERNIE/TransformerEncoderLayer backbone. This HF port re-implements the identical Post-LN BERT architecture in pure PyTorch and adds attn_implementation="sdpa" and attn_implementation="flash_attention_2" support, which were not part of the original codebase.

Related Models

See the full RNAErnie collection.

| Model | Context | Training data | Notes |
| --- | --- | --- | --- |
| RNAErnie | 512 | RNACentral (sequences <= 512 nt) | This model; PaddlePaddle ERNIE backbone |
| RNAErnie2 | 2048 | RNACentral v22 (~31M seqs) | Retrained; PyTorch BERT |

Usage

Embedding generation

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
model = AutoModel.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
model.eval()

sequences = ["AUGCAUGCAUGC", "GCUGCAUGCUAGC"]
enc = tokenizer(sequences, return_tensors="pt", padding=True)

with torch.no_grad():
    out = model(**enc)

cls_emb   = out.last_hidden_state[:, 0, :]  # (batch, 768) -- CLS token
token_emb = out.last_hidden_state           # (batch, seq_len, 768)

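# Intermediate layers: hidden_states[0] is the embedding output,
# hidden_states[1..12] are the outputs of the 12 encoder layers.
with torch.no_grad():
    out_all = model(**enc, output_hidden_states=True)
layer6_emb = out_all.hidden_states[6]       # (batch, seq_len, 768)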

MLM logits

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
model.eval()

enc = tokenizer(["ATG[MASK]ATG"], return_tensors="pt")
with torch.no_grad():
    logits = model(**enc).logits  # (1, seq_len, 39)
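The logits span all 39 vocabulary entries, including special tokens and ncRNA type labels, so a plain argmax can in principle return a non-nucleotide token. One way to handle this is to restrict the prediction to the four nucleotide indices. The sketch below does so for a single position using a plain list of logits; `predict_nucleotide` is a hypothetical helper, not part of the model's API.

```python
NUCLEOTIDE_IDS = {"A": 35, "T": 36, "C": 37, "G": 38}


def predict_nucleotide(position_logits):
    """Return the nucleotide with the highest logit, ignoring all other tokens."""
    if len(position_logits) != 39:
        raise ValueError("expected one logit per vocabulary entry (39)")
    return max(NUCLEOTIDE_IDS, key=lambda base: position_logits[NUCLEOTIDE_IDS[base]])


# Toy logits: a special token scores highest overall but is not a nucleotide.
toy = [0.0] * 39
toy[4] = 9.0    # [MASK]
toy[38] = 2.5   # G
toy[35] = 1.0   # A
print(predict_nucleotide(toy))
```

With real model output, `position_logits` would be `logits[0, i].tolist()` for the masked position `i`.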

SDPA / Flash Attention 2

from transformers import AutoModel

model = AutoModel.from_pretrained(
    "Taykhoom/RNAErnie",
    attn_implementation="sdpa",   # or "flash_attention_2"
    trust_remote_code=True,
)

Fine-tuning

Fine-tuning follows standard Hugging Face conventions. For sequence-level tasks, use the CLS token embedding (last_hidden_state[:, 0, :]) as input to a classification head. For type-guided fine-tuning (as in the paper), prepend the ncRNA type label token to the input.
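At the level of input IDs, prepending a type label could look like the sketch below. It assumes the labels occupy indices 7-34 in the order listed in the Vocabulary section, and that the type token is placed directly after [CLS]. Both are assumptions rather than confirmed behaviour of the tokenizer, and `type_token_id` and `prepend_type` are hypothetical helpers.

```python
# Type labels in the order given in the Vocabulary section
# (assumption: this order matches indices 7-34).
TYPE_LABELS = [
    "RNaseMRPRNA", "RNasePRNA", "SRPRNA", "YRNA", "antisenseRNA",
    "autocatalyticallysplicedintron", "guideRNA", "hammerheadribozyme",
    "lncRNA", "miRNA", "miscRNA", "ncRNA", "other", "piRNA", "premiRNA",
    "precursorRNA", "rRNA", "ribozyme", "sRNA", "scRNA", "scaRNA", "siRNA",
    "snRNA", "snoRNA", "tRNA", "telomeraseRNA", "tmRNA", "vaultRNA",
]
TYPE_OFFSET = 7   # first type-label index
CLS_ID = 2


def type_token_id(label: str) -> int:
    """Map an ncRNA type label to its vocabulary index."""
    return TYPE_OFFSET + TYPE_LABELS.index(label)


def prepend_type(input_ids: list[int], label: str) -> list[int]:
    """Insert the type token right after [CLS] (placement is an assumption)."""
    if not input_ids or input_ids[0] != CLS_ID:
        raise ValueError("expected input_ids to start with [CLS]")
    return [input_ids[0], type_token_id(label), *input_ids[1:]]


ids = [2, 35, 36, 38, 37, 3]   # [CLS] A T G C [SEP]
print(prepend_type(ids, "tRNA"))
```

The extra token counts towards the 512-token maximum sequence length, so very long inputs may need to be truncated by one position.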

Citation

@article{wang2024_rnaernie,
  title   = {Multi-purpose {RNA} language modelling with motif-aware pretraining and type-guided fine-tuning},
  author  = {Wang, Ning and Bian, Jiang and Li, Yuchen and Li, Xuhong and Mumtaz, Shahid and Kong, Linghe and Xiong, Haoyi},
  journal = {Nature Machine Intelligence},
  volume  = {6},
  pages   = {548--557},
  year    = {2024},
  doi     = {10.1038/s42256-024-00836-4}
}

Credits

Original model and code by Wang et al. Source: GitHub. The HF conversion code was authored primarily by Claude Code and reviewed manually by Taykhoom Dalal.

License

Apache 2.0, following the original repository.