---
language:
- rna
library_name: transformers
tags:
- RNA
- language-model
license: apache-2.0
---

# ERNIE-RNA-SS

ERNIE-RNA fine-tuned on RNA secondary structure (SS) prediction, backbone only. The prediction head has been discarded; only the encoder weights are included.

## Architecture

| Parameter | Value |
|---|---|
| Layers | 12 |
| Attention heads | 12 |
| Embedding dimension | 768 |
| FFN dimension | 3072 |
| Vocabulary size | 25 |
| Positional encoding | Sinusoidal (fairseq-style) |
| Architecture | Post-LN Transformer with recurrent 2D RNA pairing bias |
| Max sequence length | 1024 |

See [Taykhoom/ERNIE-RNA](https://huggingface.co/Taykhoom/ERNIE-RNA) for the vocabulary table and full architecture description.

## Pretraining + Fine-tuning

- **Pretraining objective:** Masked language modeling on RNAcentral
- **Fine-tuning task:** RNA secondary structure prediction (base-pair prediction)
- **Fine-tuning data:** RNA3DB
- **Source checkpoint:** `RNA3DB.pt`

### Checkpoint selection

Six SS fine-tuned checkpoints were available (bpRNA-1m, bpRNA-new, RIVAS, RNA3DB, RNAStralign, bpRNA-1m_RNAStralign). `bpRNA-new` was excluded: its backbone was frozen during SS fine-tuning, so it is identical to the pretrained ERNIE-RNA checkpoint and therefore equivalent to `Taykhoom/ERNIE-RNA`. The remaining five were evaluated via linear probing on three tasks from [mRNABench](https://huggingface.co/collections/morrislab/mrnabench): RNA subcellular localization (`rna-loc-fazal`), mRNA half-life prediction (`rnahl-human`), and variant effect prediction (`vep-traitgym-mendelian`). `RNA3DB` was the best-performing checkpoint on all three tasks and was selected for this release.
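Linear probing, as used for checkpoint selection above, means training only a linear model on top of frozen embeddings. The exact mRNABench probe (model type, regularization, data splits, metrics) is not described in this card, so the following is a hypothetical sketch for a regression task such as half-life prediction: it fits closed-form ridge regression with NumPy and scores it with Pearson correlation. The names `ridge_probe` and `pearson_r`, the `alpha` value, and the synthetic data are illustrative assumptions; in practice `X` would hold frozen CLS embeddings (`last_hidden_state[:, 0, :]`, shape `(n, 768)`) from this model.

```python
import numpy as np


def ridge_probe(X_train, y_train, X_test, alpha=1.0):
    """Hypothetical linear probe: closed-form ridge regression on frozen embeddings."""
    # Append a constant column so the probe also learns an intercept.
    Xb = np.hstack([X_train, np.ones((X_train.shape[0], 1))])
    reg = alpha * np.eye(Xb.shape[1])
    reg[-1, -1] = 0.0  # leave the intercept unpenalized
    # Solve (X^T X + alpha*I) w = X^T y for the probe weights.
    w = np.linalg.solve(Xb.T @ Xb + reg, Xb.T @ y_train)
    Xt = np.hstack([X_test, np.ones((X_test.shape[0], 1))])
    return Xt @ w


def pearson_r(a, b):
    """Pearson correlation between predictions and targets (illustrative metric)."""
    return float(np.corrcoef(a, b)[0, 1])


# Synthetic stand-in for frozen embeddings (n_samples, dim) and a regression target.
# Real CLS embeddings from this model would have dim = 768.
rng = np.random.default_rng(0)
X = rng.normal(size=(400, 32))
y = X @ rng.normal(size=32) + 0.1 * rng.normal(size=400)

preds = ridge_probe(X[:300], y[:300], X[300:], alpha=1.0)
r = pearson_r(preds, y[300:])
```

Because the backbone stays frozen, a probe score reflects how much task-relevant signal a checkpoint's embeddings already carry, which makes probing a cheap way to compare several backbones on equal footing.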
## Parity Verification

Backbone weights are extracted directly from the fine-tuned checkpoint using the same key mapping and architecture as the verified pretrained model. The underlying architecture is identical to [Taykhoom/ERNIE-RNA](https://huggingface.co/Taykhoom/ERNIE-RNA), which was verified against the original implementation with a maximum absolute difference of 1.82e-06 across all 13 representation levels. Only `attn_implementation="eager"` is supported (see Implementation Notes).

## Related Models

See the full [ERNIE-RNA collection](https://huggingface.co/collections/Taykhoom/ernie-rna-6a20c1a8ea56c00a74e2dd93).

| Model | Notes |
|---|---|
| [Taykhoom/ERNIE-RNA](https://huggingface.co/Taykhoom/ERNIE-RNA) | Pretrained model |
| **[Taykhoom/ERNIE-RNA-SS](https://huggingface.co/Taykhoom/ERNIE-RNA-SS)** | **This model (SS fine-tuned)** |
| [Taykhoom/ERNIE-RNA-MRL](https://huggingface.co/Taykhoom/ERNIE-RNA-MRL) | UTR MRL fine-tuned |

## Usage

### Embedding generation

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/ERNIE-RNA-SS", trust_remote_code=True)
model = AutoModel.from_pretrained("Taykhoom/ERNIE-RNA-SS", trust_remote_code=True)
model.eval()

sequences = ["AUGCAUGCAUGC", "GGGGCCCCGGGG"]
enc = tokenizer(sequences, return_tensors="pt", padding=True)

with torch.no_grad():
    out = model(**enc)

cls_emb = out.last_hidden_state[:, 0, :]  # (batch, 768) -- CLS token
token_emb = out.last_hidden_state          # (batch, seq_len, 768)

# Intermediate layers (also without gradient tracking)
with torch.no_grad():
    out_all = model(**enc, output_hidden_states=True)
layer6_emb = out_all.hidden_states[6]      # (batch, seq_len, 768)
```

### Fine-tuning

Use the CLS token embedding (`last_hidden_state[:, 0, :]`) as input to a prediction head for sequence-level tasks. For token-level tasks (e.g., base-pair prediction), use `last_hidden_state` directly.
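The two cases above can be sketched as small PyTorch heads on top of `last_hidden_state`. This is a hypothetical sketch, not the head used in the original SS fine-tuning (that head was discarded from this release): `SequenceHead` maps the CLS embedding to task outputs, and `PairHead` scores every token pair with a symmetric bilinear form, one simple option for base-pair prediction. The class names, the projection size, and the random stand-in tensor are assumptions.

```python
import torch
import torch.nn as nn

HIDDEN = 768  # embedding dimension from the Architecture table


class SequenceHead(nn.Module):
    """Hypothetical sequence-level head: CLS embedding -> task outputs."""

    def __init__(self, hidden=HIDDEN, n_out=1, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(nn.Dropout(dropout), nn.Linear(hidden, n_out))

    def forward(self, last_hidden_state):
        cls = last_hidden_state[:, 0, :]  # (batch, hidden), CLS at index 0
        return self.net(cls)              # (batch, n_out)


class PairHead(nn.Module):
    """Hypothetical token-level head: one logit per (i, j) token pair."""

    def __init__(self, hidden=HIDDEN, proj=64):
        super().__init__()
        self.q = nn.Linear(hidden, proj)
        self.k = nn.Linear(hidden, proj)

    def forward(self, token_states):
        q = self.q(token_states)  # (batch, L, proj)
        k = self.k(token_states)  # (batch, L, proj)
        logits = q @ k.transpose(1, 2) / q.shape[-1] ** 0.5  # (batch, L, L)
        # Symmetrize so that score(i, j) == score(j, i), as base pairing is undirected.
        return 0.5 * (logits + logits.transpose(1, 2))


# Random stand-in for out.last_hidden_state: (batch, seq_len, 768)
hidden = torch.randn(2, 14, HIDDEN)
seq_pred = SequenceHead(n_out=1)(hidden)     # (2, 1)
pair_logits = PairHead()(hidden[:, 1:, :])  # drop CLS at index 0 -> (2, 13, 13)
```

With real inputs, exclude padded positions (via `enc["attention_mask"]`) and any special tokens before computing a loss. Only the CLS token at index 0 is documented here, so check `enc["input_ids"]` to see whether an end token is also appended.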
## Implementation Notes

ERNIE-RNA's recurrent 2D bias is updated from the pre-softmax attention scores at every layer (the raw QK logits become the bias input for the next layer). Fused attention kernels (SDPA, FlashAttention) do not expose pre-softmax scores, so they cannot maintain this recurrent pathway. Only `attn_implementation="eager"` is supported; requesting `sdpa` or `flash_attention_2` raises a `ValueError`.

The `twod_proj` MLP is always run in float32 (matching the original) regardless of the model's compute dtype.

## Citation

```bibtex
@article{yin2025_ernierna,
  title   = {{ERNIE-RNA}: an {RNA} language model with structure-enhanced representations},
  author  = {Yin, Weijie and Zhang, Zhaoyu and He, Liang and Jiang, Rui and Zhang, Shuo and Liu, Gan and Zeng, Xuezhi and Zhao, Wen and Gao, Xiaowo},
  journal = {Nature Communications},
  volume  = {16},
  number  = {1},
  pages   = {8407},
  year    = {2025},
  doi     = {10.1038/s41467-025-64972-0}
}
```

## Credits

Original model and code by Yin et al. Source: [GitHub](https://github.com/Bruce-ywj/ERNIE-RNA). The HF conversion code was authored primarily by [Claude Code](https://claude.ai/code) and reviewed manually by Taykhoom Dalal.

## License

Apache 2.0, following the original repository.