---
language:
- rna
library_name: transformers
tags:
- RNA
- language-model
license: apache-2.0
---

# RNAErnie

RNAErnie is a BERT-based RNA language model pretrained on RNACentral using a motif-aware masking strategy with type-guided fine-tuning. It uses a DNA-style vocabulary (T instead of U) and extends the token vocabulary with 28 ncRNA type labels to enable type-guided learning.

## Architecture

| Parameter | Value |
|---|---|
| Layers | 12 |
| Attention heads | 12 |
| Embedding dimension | 768 |
| Intermediate size | 3072 |
| Vocabulary size | 39 |
| Positional encoding | Absolute learned |
| Architecture | Post-LN BERT / ERNIE |
| Max sequence length | 512 |

**Vocabulary:** Special tokens `[PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3, [MASK]=4, [DEL]=5, [IND]=6`; ncRNA type labels at indices 7-34 (RNaseMRPRNA, RNasePRNA, SRPRNA, YRNA, antisenseRNA, autocatalyticallysplicedintron, guideRNA, hammerheadribozyme, lncRNA, miRNA, miscRNA, ncRNA, other, piRNA, premiRNA, precursorRNA, rRNA, ribozyme, sRNA, scRNA, scaRNA, siRNA, snRNA, snoRNA, tRNA, telomeraseRNA, tmRNA, vaultRNA); nucleotides `A=35, T=36, C=37, G=38`.

**Tokenisation note:** Input U is silently converted to T. The model was pretrained with DNA-style T notation.

## Pretraining

- **Objective:** Masked language modelling (MLM) with motif-aware masking
- **Data:** RNACentral (sequences with length <= 512)
- **Source checkpoint:** `model_state.pdparams` from the original PaddlePaddle repository

### Checkpoint selection

There is a single publicly released RNAErnie checkpoint (`output/BERT,ERNIE,MOTIF,PROMPT/checkpoint_final/model_state.pdparams`), corresponding to the `BERT,ERNIE,MOTIF,PROMPT` pretraining variant described in the paper.

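
Returning to the vocabulary layout given in the Architecture section, the sketch below rebuilds the 39-entry token table in plain Python and shows the U-to-T conversion on a short sequence. It is illustrative only and is not the tokenizer shipped with this repository: it assumes the ncRNA type labels are indexed in the order listed, and `encode_nucleotides` is a hypothetical helper that uppercases input and maps unknown characters to `[UNK]` (both assumptions) without adding `[CLS]` or `[SEP]`.

```python
# Illustrative reconstruction of the vocabulary layout described in this card.
# Assumption: type labels occupy indices 7-34 in the order they are listed.
SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "[DEL]", "[IND]"]
NCRNA_TYPES = [
    "RNaseMRPRNA", "RNasePRNA", "SRPRNA", "YRNA", "antisenseRNA",
    "autocatalyticallysplicedintron", "guideRNA", "hammerheadribozyme",
    "lncRNA", "miRNA", "miscRNA", "ncRNA", "other", "piRNA", "premiRNA",
    "precursorRNA", "rRNA", "ribozyme", "sRNA", "scRNA", "scaRNA", "siRNA",
    "snRNA", "snoRNA", "tRNA", "telomeraseRNA", "tmRNA", "vaultRNA",
]
NUCLEOTIDES = ["A", "T", "C", "G"]

# Token -> id, built by concatenating the three groups in order.
VOCAB = {tok: i for i, tok in enumerate(SPECIAL_TOKENS + NCRNA_TYPES + NUCLEOTIDES)}


def encode_nucleotides(seq):
    """Hypothetical helper: convert U to T, then map each base to its id.

    Uppercasing and the [UNK] fallback are assumptions of this sketch.
    """
    dna_style = seq.upper().replace("U", "T")
    return [VOCAB.get(base, VOCAB["[UNK]"]) for base in dna_style]


ids = encode_nucleotides("AUGC")
print(len(VOCAB), ids)
```

For real inputs, use the tokenizer loaded with `AutoTokenizer.from_pretrained`, which also handles special tokens and padding.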
## Parity Verification

Hidden-state representations were verified to match (max abs diff < 7e-6) at all 13 representation levels (embedding + 12 layers) against a standalone pure-PyTorch reference that implements the PaddlePaddle ERNIE forward pass directly from the raw `.pdparams` weights, without running PaddlePaddle. The reference uses PaddlePaddle's linear convention (`x @ W`, weight stored `(in, out)`) and loads weights from the original checkpoint file in the same way as the conversion script, so the comparison is mathematically equivalent to a live PaddlePaddle run. Verified on GPU with PyTorch 2.7 / CUDA 12.

**Note on weight conversion:** PaddlePaddle stores `nn.Linear` weights as `(in_features, out_features)`, the transpose of PyTorch's `(out_features, in_features)`. All linear layer weights (attention projections, FFN, pooler, MLM transform) are transposed during conversion; embedding tables and bias vectors are copied as-is.

## Implementation Notes

The original implementation uses PaddlePaddle's ERNIE/TransformerEncoderLayer backbone. This HF port re-implements the identical Post-LN BERT architecture in pure PyTorch and adds `attn_implementation="sdpa"` and `attn_implementation="flash_attention_2"` support, which were not part of the original codebase.

## Related Models

See the full [RNAErnie collection](https://huggingface.co/collections/Taykhoom/rnaernie-6a219927c11fdcccedb243db).


| Model | Context | Training data | Notes |
|---|---|---|---|
| **[RNAErnie](https://huggingface.co/Taykhoom/RNAErnie)** | **512** | **RNACentral (nts<=512)** | **This model; PaddlePaddle ERNIE backbone** |
| [RNAErnie2](https://huggingface.co/Taykhoom/RNAErnie2) | 2048 | RNACentral v22 (~31M seqs) | Retrained; PyTorch BERT |

## Usage

### Embedding generation

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
model = AutoModel.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
model.eval()

# U is converted to T by the tokenizer, so RNA-style input is accepted.
sequences = ["AUGCAUGCAUGC", "GCUGCAUGCUAGC"]
enc = tokenizer(sequences, return_tensors="pt", padding=True)

with torch.no_grad():
    out = model(**enc)
    cls_emb = out.last_hidden_state[:, 0, :]  # (batch, 768) -- CLS token
    token_emb = out.last_hidden_state         # (batch, seq_len, 768)

    # Intermediate layers
    out_all = model(**enc, output_hidden_states=True)
    layer6_emb = out_all.hidden_states[6]     # (batch, seq_len, 768)
```

### MLM logits

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained("Taykhoom/RNAErnie", trust_remote_code=True)
model.eval()

enc = tokenizer(["ATG[MASK]ATG"], return_tensors="pt")
with torch.no_grad():
    logits = model(**enc).logits  # (1, seq_len, 39)
```

### SDPA / Flash Attention 2

```python
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "Taykhoom/RNAErnie",
    attn_implementation="sdpa",  # or "flash_attention_2"
    trust_remote_code=True,
)
```

### Fine-tuning

Fine-tuning follows standard HF conventions. For sequence-level tasks, use the CLS token embedding (`last_hidden_state[:, 0, :]`) as input to a classification head. For type-guided fine-tuning (as in the paper), prepend the ncRNA type label token to the input.

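
As a rough illustration of the sequence-level setup described above, the sketch below puts a linear classification head on the CLS embedding and runs one loss and backward step. `ClsClassifier`, the dropout rate, the number of labels, and the random `hidden` tensor are all hypothetical choices for this example and do not come from the original repository; the random tensor stands in for `model(**enc).last_hidden_state` so the block runs without downloading weights.

```python
import torch
from torch import nn


class ClsClassifier(nn.Module):
    """Hypothetical classification head over the CLS embedding (illustration only)."""

    def __init__(self, hidden_size=768, num_labels=2, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, num_labels)

    def forward(self, last_hidden_state):
        cls_emb = last_hidden_state[:, 0, :]        # (batch, hidden_size)
        return self.linear(self.dropout(cls_emb))   # (batch, num_labels)


# Stand-in for `model(**enc).last_hidden_state`: (batch=4, seq_len=16, hidden=768).
hidden = torch.randn(4, 16, 768)
labels = torch.tensor([0, 1, 1, 0])  # made-up labels for the example

head = ClsClassifier(hidden_size=768, num_labels=2)
logits = head(hidden)                               # (4, 2)
loss = nn.functional.cross_entropy(logits, labels)  # scalar loss
loss.backward()                                     # populates gradients on the head
```

In an actual fine-tuning run, `hidden` would be the encoder output computed with gradients enabled, and the encoder parameters could be either updated together with the head or frozen, depending on the task and compute budget.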

## Citation

```bibtex
@article{wang2024_rnaernie,
  title   = {Multi-purpose {RNA} language modelling with motif-aware pretraining and type-guided fine-tuning},
  author  = {Wang, Ning and Bian, Jiang and Li, Yuchen and Li, Xuhong and Mumtaz, Shahid and Kong, Linghe and Xiong, Haoyi},
  journal = {Nature Machine Intelligence},
  volume  = {6},
  pages   = {548--557},
  year    = {2024},
  doi     = {10.1038/s42256-024-00836-4}
}
```

## Credits

Original model and code by Wang et al. Source: [GitHub](https://github.com/CatIIIIIIII/RNAErnie).

The HF conversion code was authored primarily by [Claude Code](https://claude.ai/code) and reviewed manually by Taykhoom Dalal.

## License

Apache 2.0, following the original repository.