---
language:
- rna
library_name: transformers
tags:
- RNA
- mRNA
- bert
- language-model
- flash-attention
license: apache-2.0
---

# mRNABERT

Weights and tokenizer for [mRNABERT](https://huggingface.co/YYLY66/mRNABERT) (Xiong et al., Nature Communications 2025), loaded with the bug-fixed model code from [Taykhoom/MosaicBERT-updated](https://huggingface.co/Taykhoom/MosaicBERT-updated).

mRNABERT is a language model pre-trained on 18 million mRNA sequences, using contrastive learning to integrate semantic features of amino acids.

**This repo contains only weights and tokenizer files.** The model code is loaded automatically from `Taykhoom/MosaicBERT-updated` via `trust_remote_code=True`. See that repo for the full list of bugs fixed relative to the original MosaicBERT implementation.

## Architecture

mRNABERT uses the MosaicBERT architecture with an mRNA-specific vocabulary.

| Parameter | Value |
|---|---|
| Layers | 12 |
| Attention heads | 12 |
| Embedding dimension | 768 |
| Vocabulary size | 74 (5 special + 5 single-nt + 64 codons) |
| Positional encoding | ALiBi (no position embeddings) |
| Attention | Flash Attention (packed QKV) |
| FFN | Gated Linear Units (GeGLU) |
| Padding | Unpadding (tokens concatenated, no padding overhead) |
| Max sequence length | ~10000 nt (practical; MosaicBERT uses ALiBi and extrapolates to longer sequences) |
| Parameters | ~114M |

### Vocabulary

The tokenizer uses `BertTokenizer` with a hybrid vocabulary. Sequences are encoded in the **DNA alphabet (T, not U)** even though the model is trained on mRNA.

| Range | Tokens | Use |
|---|---|---|
| 0-4 | `[PAD]` `[UNK]` `[CLS]` `[SEP]` `[MASK]` | Special tokens |
| 5-9 | `A` `T` `C` `G` `N` | Single nucleotides (UTR regions) |
| 10-73 | `AAA` ... `GGG` | All 64 codons (CDS regions) |

## Pretraining

- **Objective:** Masked Language Modeling + contrastive learning (amino-acid semantic features)
- **Data:** 18 million curated mRNA sequences
- **Source checkpoint:** `pytorch_model.bin` from [YYLY66/mRNABERT](https://huggingface.co/YYLY66/mRNABERT)

## Parity Verification

Hidden states were verified against the original implementation with max abs diff < 2.4e-05 at all 13 representation levels (embedding + 12 transformer layers). Both models use `flash_attn_varlen_qkvpacked_func`; the small numerical differences are flash attention rounding, not a correctness issue. SDPA vs eager max diff = 1.81e-05. Verified on GPU with PyTorch 2.7 / CUDA 12.9.

## Usage

mRNABERT requires CDS-aware preprocessing: UTR regions must be split into space-separated single nucleotides and CDS regions into space-separated codons. The tokenizer handles this automatically via `batch_encode_with_cds()` when a CDS track is available, or you can pass pre-formatted strings directly for simple use cases. Sequences use **T (not U)**.
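The formatting rule above can be sketched in plain Python for the pre-formatted-string path. This is a minimal illustration, not the tokenizer's own code: `format_mrna` is a hypothetical helper, and it assumes a single contiguous CDS given as a 0-based, half-open interval `[cds_start, cds_end)` whose length is a multiple of 3. Whether its output matches `batch_encode_with_cds()` token for token is not verified here.

```python
def format_mrna(seq: str, cds_start: int, cds_end: int) -> str:
    """Hypothetical helper: space-separate UTRs by nucleotide and the CDS by codon.

    Assumes one contiguous CDS covering seq[cds_start:cds_end] (0-based, end exclusive).
    """
    # The model expects the DNA alphabet, so map U to T.
    seq = seq.upper().replace("U", "T")
    if not 0 <= cds_start <= cds_end <= len(seq):
        raise ValueError("CDS bounds must lie within the sequence")
    if (cds_end - cds_start) % 3 != 0:
        raise ValueError("CDS length must be a multiple of 3")

    five_utr = list(seq[:cds_start])                                # single nucleotides
    codons = [seq[i:i + 3] for i in range(cds_start, cds_end, 3)]   # triplets
    three_utr = list(seq[cds_end:])                                 # single nucleotides
    return " ".join(five_utr + codons + three_utr)


print(format_mrna("AUCGAUGUUUCCCUAAGG", 4, 16))  # A T C G ATG TTT CCC TAA G G
```

The returned string can be passed to the tokenizer in the same way as the pre-formatted examples in the section on embedding generation without CDS tracks.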
### Embedding generation with CDS tracks (recommended)

```python
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True)
model = AutoModel.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True)
model.eval()

# Raw sequences (T not U) + per-nucleotide CDS track
# cds[i] != 0 marks the start of a codon at position i
sequences = ["ATCGATGTTTCCC", "AATGCCC"]
cds_tracks = [
    np.array([0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0]),  # CDS starts at pos 3
    np.array([0, 1, 0, 0, 1, 0, 0]),  # CDS starts at pos 1
]

enc, chunk_counts = tokenizer.batch_encode_with_cds(
    sequences, cds_tracks, return_tensors="pt", padding=True
)

with torch.no_grad():
    out = model(**enc)

# Mean-pool over non-padding tokens
mask = enc["attention_mask"].unsqueeze(-1).float()
mean_emb = (out.last_hidden_state * mask).sum(1) / mask.sum(1)  # (batch, 768)
```

### Embedding generation without CDS tracks

Pass pre-formatted space-separated strings directly when no CDS annotation is available:

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True)
model = AutoModel.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True)
model.eval()

# Space-separated: single nt for UTRs, codons for CDS; use T not U
sequences = [
    "A T C G G A GGG CCC TTT AAA",  # mixed UTR + CDS
    "ATG TTT CCC GAC TAA",  # CDS only
]

enc = tokenizer(sequences, return_tensors="pt", padding=True)

with torch.no_grad():
    out = model(**enc)

# Mean-pool over non-padding tokens
mask = enc["attention_mask"].unsqueeze(-1).float()
mean_emb = (out.last_hidden_state * mask).sum(1) / mask.sum(1)  # (batch, 768)
```

### MLM logits

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True)
model.eval()

enc = tokenizer(["A T C G [MASK] CCC TTT"], return_tensors="pt")

with torch.no_grad():
    logits = model(**enc).logits  # (1, seq_len, 74)
```

### Attention implementation

```python
from transformers import AutoModel

# SDPA (default on PyTorch >= 2.0)
model = AutoModel.from_pretrained(
    "Taykhoom/mRNABERT", trust_remote_code=True, attn_implementation="sdpa"
)

# Flash Attention 2 (requires: pip install flash-attn --no-build-isolation)
model = AutoModel.from_pretrained(
    "Taykhoom/mRNABERT", trust_remote_code=True, attn_implementation="flash_attention_2"
)
```

### Fine-tuning

```python
import torch.nn as nn
from transformers import AutoModel


class mRNABERTClassifier(nn.Module):
    def __init__(self, num_labels):
        super().__init__()
        self.encoder = AutoModel.from_pretrained("Taykhoom/mRNABERT", trust_remote_code=True)
        self.head = nn.Linear(768, num_labels)

    def forward(self, input_ids, attention_mask):
        out = self.encoder(input_ids, attention_mask=attention_mask)
        # Mean-pool over non-padding tokens before the classification head
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1)
        return self.head(pooled)
```

## Citation

```bibtex
@article{xiong2025_mrnabert,
  title   = {{mRNABERT}: advancing {mRNA} sequence design with a universal language model and comprehensive dataset},
  author  = {Xiong, Ying and Wang, Aowen and Kang, Yu and Shen, Chao and Hsieh, Chang-Yu and Hou, Tingjun},
  journal = {Nature Communications},
  volume  = {16},
  number  = {1},
  pages   = {10371},
  year    = {2025},
  doi     = {10.1038/s41467-025-65340-8}
}
```

## Credits

Original mRNABERT model and weights by Xiong et al. Source: [GitHub](https://github.com/yyly6/mRNABERT).

Bug-fixed model code by [Taykhoom/MosaicBERT-updated](https://huggingface.co/Taykhoom/MosaicBERT-updated), authored primarily by [Claude Code](https://claude.ai/code) and reviewed manually by Taykhoom Dalal.

## License

Apache 2.0, following the original repository.