---
language:
- rna
library_name: transformers
tags:
- RNA
- language-model
license: apache-2.0
---

# ERNIE-RNA

ERNIE-RNA is an RNA-specific large language model that incorporates RNA base-pairing potential as a recurrent 2D structural bias into each attention layer, enabling the model to capture secondary structure information during pretraining.

## Architecture

| Parameter | Value |
|---|---|
| Layers | 12 |
| Attention heads | 12 |
| Embedding dimension | 768 |
| FFN dimension | 3072 |
| Vocabulary size | 25 |
| Positional encoding | Sinusoidal (fairseq-style) |
| Architecture | Post-LN Transformer with recurrent 2D RNA pairing bias |
| Max sequence length | 1024 |

### Vocabulary

| Token | ID | Notes |
|---|---|---|
| `` | 0 | Prepended to every sequence |
| `` | 1 | Padding token |
| `` | 2 | Appended to every sequence |
| `` | 3 | Unknown token |
| G | 4 | |
| A | 5 | |
| U | 6 | T is silently mapped to U during tokenization |
| C | 7 | |
| N | 8 | Ambiguous nucleotide |
| Y-I | 9-20 | IUPAC ambiguity codes |
| madeupword0-2 | 21-23 | Padding tokens from original vocab |
| `` | 24 | MLM mask token |

### 2D RNA Pairing Bias

ERNIE-RNA computes a pairwise RNA base-pairing potential matrix from the input sequence at the start of each forward pass. This matrix (shape `[B, T, T, 1]`) is projected to `[B, H, T, T]` via a 2-layer MLP (1 -> 6 -> H, with GELU) and added to the attention logits in the first layer. The pre-softmax attention scores then become the updated 2D bias for the next layer, creating a recurrent structural information pathway across all 12 transformer layers.

Base-pairing scores: A-U = 2.0, G-C = 3.0, G-U wobble = 0.8.

## Pretraining

- **Objective:** Masked language modeling (MLM) on RNA sequences
- **Data:** RNAcentral (non-redundant RNA sequences)
- **Source checkpoint:** `ERNIE-RNA_pretrain.pt`

### Checkpoint selection

Single pretrained checkpoint from the original repository. Used as-is; no fine-tuned variants are included in this release.

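
The pairing-potential matrix described under "2D RNA Pairing Bias" can be illustrated with a minimal sketch. It assumes only the three scores listed there (A-U = 2.0, G-C = 3.0, G-U = 0.8) and a score of 0 for every other pair. `PAIR_SCORES` and `pairing_matrix` are hypothetical names, and the original implementation may apply additional processing that is not reproduced here. Special tokens, batching, and the trailing channel dimension are omitted.

```python
# Pairing scores as stated in this card; every other pair is assumed to score 0.
PAIR_SCORES = {
    frozenset("AU"): 2.0,
    frozenset("GC"): 3.0,
    frozenset("GU"): 0.8,  # wobble pair
}


def pairing_matrix(seq: str) -> list[list[float]]:
    """Return a symmetric T x T matrix of base-pairing scores for one sequence.

    Hypothetical helper for illustration only; not the model's own code.
    """
    # Treat T as U, as the tokenizer is described as doing.
    seq = seq.upper().replace("T", "U")
    n = len(seq)
    mat = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            # Unordered lookup, so (G, C) and (C, G) receive the same score.
            score = PAIR_SCORES.get(frozenset((seq[i], seq[j])), 0.0)
            mat[i][j] = mat[j][i] = score
    return mat


m = pairing_matrix("GAUC")
for row in m:
    print(row)
```

For `GAUC`, the G-C, A-U, and G-U entries are 3.0, 2.0, and 0.8 respectively, and all remaining entries are 0. In the model, a batched matrix of this kind with shape `[B, T, T, 1]` is what the MLP projects to per-head biases.
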
## Parity Verification

Hidden-state representations were verified to match the original implementation (max abs diff = 1.82e-06) at all 13 representation levels (embedding + 12 transformer layers). Verification was run on GPU with PyTorch 2.7 / CUDA 12. Only `attn_implementation="eager"` is supported (see Implementation Notes).

## Related Models

See the full [ERNIE-RNA collection](https://huggingface.co/collections/Taykhoom/ernie-rna-6a20c1a8ea56c00a74e2dd93).

| Model | Notes |
|---|---|
| **[Taykhoom/ERNIE-RNA](https://huggingface.co/Taykhoom/ERNIE-RNA)** | **Pretrained model (this model)** |
| [Taykhoom/ERNIE-RNA-SS](https://huggingface.co/Taykhoom/ERNIE-RNA-SS) | SS fine-tuned (bpRNA-new), backbone only |
| [Taykhoom/ERNIE-RNA-MRL](https://huggingface.co/Taykhoom/ERNIE-RNA-MRL) | UTR MRL fine-tuned, backbone only |

## Usage

### Embedding generation

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/ERNIE-RNA", trust_remote_code=True)
model = AutoModel.from_pretrained("Taykhoom/ERNIE-RNA", trust_remote_code=True)
model.eval()

sequences = ["AUGCAUGCAUGC", "GGGGCCCCGGGG"]
enc = tokenizer(sequences, return_tensors="pt", padding=True)

with torch.no_grad():
    out = model(**enc)

cls_emb = out.last_hidden_state[:, 0, :]  # (batch, 768) -- CLS token
token_emb = out.last_hidden_state         # (batch, seq_len, 768)

# Intermediate layers (gradients disabled, as above)
with torch.no_grad():
    out_all = model(**enc, output_hidden_states=True)
layer6_emb = out_all.hidden_states[6]     # (batch, seq_len, 768)
```

### MLM logits

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/ERNIE-RNA", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained("Taykhoom/ERNIE-RNA", trust_remote_code=True)
model.eval()

enc = tokenizer(["AUGAUG"], return_tensors="pt")
with torch.no_grad():
    logits = model(**enc).logits  # (1, seq_len, 25)
```

### Fine-tuning

Use the CLS token embedding (`last_hidden_state[:, 0, :]`) as input to a prediction head for sequence-level tasks. For token-level tasks, use `last_hidden_state` directly.

## Implementation Notes

ERNIE-RNA's recurrent 2D bias is updated from the pre-softmax attention scores at every layer (the raw QK logits become the bias input for the next layer). Fused attention kernels (SDPA, FlashAttention) do not expose pre-softmax scores, so they cannot maintain this recurrent pathway. Only `attn_implementation="eager"` is supported; requesting `sdpa` or `flash_attention_2` raises a `ValueError`.

The `twod_proj` MLP is always run in float32 (matching the original) regardless of the model's compute dtype.

## Citation

```bibtex
@article{yin2025_ernierna,
  title   = {{ERNIE-RNA}: an {RNA} language model with structure-enhanced representations},
  author  = {Yin, Weijie and Zhang, Zhaoyu and He, Liang and Jiang, Rui and Zhang, Shuo and Liu, Gan and Zeng, Xuezhi and Zhao, Wen and Gao, Xiaowo},
  journal = {Nature Communications},
  volume  = {16},
  number  = {1},
  pages   = {8407},
  year    = {2025},
  doi     = {10.1038/s41467-025-64972-0}
}
```

## Credits

Original model and code by Yin et al. Source: [GitHub](https://github.com/Bruce-ywj/ERNIE-RNA).

The HF conversion code was authored primarily by [Claude Code](https://claude.ai/code) and reviewed manually by Taykhoom Dalal.

## License

Apache 2.0, following the original repository.