Note:

This model is a copy of mRNABERT that fixes the FlashAttention integration with Triton (applying the solution from https://github.com/Dao-AILab/flash-attention/issues/508) and fixes the return of attention weights and hidden states in the model's forward function. The original mRNABERT model can be found at https://huggingface.co/YYLY66/mRNABERT. If you use this model, please credit the original authors of mRNABERT and the MosaicML team for their implementation.

The only changes made were in flash_attn_triton.py and bert_layers.py.

In flash_attn_triton.py, the change was:

  1. qk += tl.dot(q, k, trans_b=True) was replaced with qk += tl.dot(q, tl.trans(k)), following the solution in the FlashAttention issue. The two other uses of the trans_b=True argument in the file were changed the same way.
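Both forms compute the attention score matrix q @ k^T; the old trans_b=True keyword only asked tl.dot to transpose its second operand. This pure-Python sketch mimics that with hypothetical dot and trans helpers (stand-ins for tl.dot and tl.trans, not Triton code) to show that the rewrite leaves the result unchanged.

```python
def trans(m):
    """Transpose a 2-D matrix stored as a list of rows (stand-in for tl.trans)."""
    return [list(col) for col in zip(*m)]


def dot(a, b, trans_b=False):
    """Matrix product a @ b (stand-in for tl.dot).

    trans_b=True mimics the removed keyword: b is transposed before multiplying.
    """
    if trans_b:
        b = trans(b)
    # zip(*b) iterates over the columns of b
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


q = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]                 # 3 queries, head dim 2
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]    # 4 keys, head dim 2

old_scores = dot(q, k, trans_b=True)   # old call style
new_scores = dot(q, trans(k))          # new call style
```

Both calls produce the same 3 x 4 score matrix, one row per query and one column per key.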

In bert_layers.py the changes were:

  1. use_flash_attn config flag (BertUnpadSelfAttention): Added self.use_flash_attn = getattr(config, 'use_flash_attn', True). Setting use_flash_attn: false in the model config forces the PyTorch eager attention path, enabling attention weight extraction without requiring Triton.

  2. Attention weight return (BertUnpadSelfAttention, BertUnpadAttention, BertLayer): Added a return_attn_weights: bool = False parameter threaded through the call chain. When enabled, the eager path returns the (B, H, T, T) attention probability tensor (batch, heads, sequence length, sequence length) alongside the hidden states.

  3. HF-compatible encoder output (BertEncoder): Added output_attentions: bool = False. When output_all_encoded_layers=True, each layer's hidden states are now padded back to (B, T, D) before collection (previously they were collected unpadded as (nnz, D), where nnz is the number of non-padding tokens in the batch), and the embedding output is prepended as index 0 to match the HuggingFace hidden_states convention.

  4. Standard HuggingFace output objects (BertModel, BertForMaskedLM, BertForSequenceClassification): BertModel.forward now accepts output_hidden_states and output_attentions keyword arguments and returns a BaseModelOutputWithPooling object with .last_hidden_state, .pooler_output, .hidden_states, and .attentions fields. BertForMaskedLM and BertForSequenceClassification were updated accordingly to read from these named fields.
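The re-padding and hidden_states layout described for BertEncoder can be pictured with a small PyTorch sketch. pad_back and collect_hidden_states are hypothetical names, not functions from bert_layers.py, and the sketch assumes the unpadded rows are stored in row-major (batch, then position) order, which is the order produced by indexing a (B, T, D) tensor with a boolean (B, T) mask.

```python
import torch


def pad_back(unpadded, attention_mask):
    """Scatter unpadded token states (nnz, D) into a zero-padded (B, T, D) tensor.

    attention_mask is (B, T) with 1 for real tokens, so nnz == attention_mask.sum().
    """
    batch, seqlen = attention_mask.shape
    padded = unpadded.new_zeros(batch, seqlen, unpadded.shape[-1])
    # Boolean-mask assignment fills the real-token positions in row-major order.
    padded[attention_mask.bool()] = unpadded
    return padded


def collect_hidden_states(embedding_output, layer_outputs_unpadded, attention_mask):
    """HF-style tuple: index 0 is the embedding output, index i is layer i's output."""
    return (embedding_output,) + tuple(
        pad_back(h, attention_mask) for h in layer_outputs_unpadded
    )


torch.manual_seed(0)
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])        # second sequence has one padding token
nnz = int(mask.sum())                              # 5 real tokens in the batch
emb = torch.randn(2, 3, 4)                         # (B, T, D) embedding output
layers = [torch.randn(nnz, 4) for _ in range(2)]   # two layers of unpadded (nnz, D) states
hs = collect_hidden_states(emb, layers, mask)
```

With two layers the tuple has three entries, all shaped (B, T, D), and padding positions hold zeros.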

Original README:

mRNABERT

A robust language model pre-trained on over 18 million high-quality mRNA sequences, incorporating contrastive learning to integrate the semantic features of amino acids.

This is the official pre-trained model introduced in mRNABERT: advancing mRNA sequence design with a universal language model and comprehensive dataset.

The mRNABERT code repository is at yyly6/mRNABERT.

Intended uses & limitations

The model can be used for mRNA sequence feature extraction or fine-tuned on downstream tasks. **Before passing sequences to the model, preprocess them: split the UTR regions into single nucleotides and the CDS regions into three-nucleotide codons, with tokens separated by spaces.** For full examples, please see our data-processing code.
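The preprocessing rule can be sketched in plain Python. format_mrna is a hypothetical helper, not part of the mRNABERT codebase; it assumes the 5' UTR, CDS and 3' UTR have already been split apart, and rejecting a CDS whose length is not a multiple of three is an assumption of this sketch. For the authoritative pipeline, use the data-processing code in the mRNABERT repository.

```python
def format_mrna(cds, utr5="", utr3=""):
    """Hypothetical helper: build the space-separated input string for mRNABERT.

    UTR nucleotides become single-letter tokens; the CDS is split into codons.
    """
    cds = cds.upper()
    # Assumption of this sketch: a CDS must consist of whole codons.
    if len(cds) % 3 != 0:
        raise ValueError(f"CDS length {len(cds)} is not a multiple of 3")
    utr5_tokens = list(utr5.upper())                          # one token per nucleotide
    codons = [cds[i:i + 3] for i in range(0, len(cds), 3)]   # one token per codon
    utr3_tokens = list(utr3.upper())
    return " ".join(utr5_tokens + codons + utr3_tokens)


print(format_mrna("GGGCCCTTT", utr5="ATCGGA"))  # A T C G G A GGG CCC TTT
```

The printed string matches the first example sequence used in the Usage section.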

Training data

The mRNABERT model was pretrained on a comprehensive mRNA dataset, which originally consisted of approximately 36 million complete CDS or mRNA sequences. After cleaning, this number was reduced to 18 million.

Usage

To load the model from the Hugging Face Hub:

```python
import torch
from transformers import AutoTokenizer, AutoModel
from transformers.models.bert.configuration_bert import BertConfig

config = BertConfig.from_pretrained("YYLY66/mRNABERT")
tokenizer = AutoTokenizer.from_pretrained("YYLY66/mRNABERT")
model = AutoModel.from_pretrained("YYLY66/mRNABERT", trust_remote_code=True, config=config)
```

To extract the embeddings of mRNA sequences:

```python
seq = ["A T C G G A GGG CCC TTT",
       "A T C G",
       "TTT CCC GAC ATG"]  # Tokens are space-separated: single nucleotides for UTRs, codons for CDS

encoding = tokenizer.batch_encode_plus(seq, add_special_tokens=True, padding='longest', return_tensors="pt")

input_ids = encoding['input_ids']
attention_mask = encoding['attention_mask']

output = model(input_ids=input_ids, attention_mask=attention_mask)
last_hidden_state = output[0]  # Shape: [batch_size, seq_length, hidden_size]

# Broadcast the mask over the hidden dimension
attention_mask = attention_mask.unsqueeze(-1).expand_as(last_hidden_state)  # Shape: [batch_size, seq_length, hidden_size]

# Sum the embeddings of non-padding tokens along the sequence dimension
sum_embeddings = torch.sum(last_hidden_state * attention_mask, dim=1)

# Count the non-padding tokens per sequence (repeated across the hidden dimension)
sum_masks = attention_mask.sum(1)

# Compute the mean embedding
mean_embedding = sum_embeddings / sum_masks  # Shape: [batch_size, hidden_size]
```

The extracted embeddings can be used for contrastive learning pretraining or as a feature extractor for protein-related downstream tasks.
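With this copy of the model, per-layer hidden states and attention maps can also be requested. Per the change notes, attention weights are only produced by the PyTorch eager path, so this sketch sets use_flash_attn to False before loading. MODEL_ID is a hypothetical placeholder for this repository's id or a local path, and load_eager_model and extract_layers are illustrative names, not part of the released code.

```python
import torch
from transformers import AutoModel, AutoTokenizer
from transformers.models.bert.configuration_bert import BertConfig

# Hypothetical placeholder: replace with this repository's id or a local path.
MODEL_ID = "path/to/this-model"


def load_eager_model(model_id=MODEL_ID):
    """Load tokenizer and model with the eager attention path (no Triton required)."""
    config = BertConfig.from_pretrained(model_id)
    # Read by BertUnpadSelfAttention via getattr(config, 'use_flash_attn', True).
    config.use_flash_attn = False
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModel.from_pretrained(model_id, trust_remote_code=True, config=config)
    model.eval()
    return tokenizer, model


@torch.no_grad()
def extract_layers(tokenizer, model, seqs):
    """Return (hidden_states, attentions) for a batch of space-separated sequences."""
    encoding = tokenizer(seqs, padding="longest", return_tensors="pt")
    output = model(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        output_hidden_states=True,
        output_attentions=True,
    )
    # hidden_states[0] is the embedding output; later entries are per-layer (B, T, D) outputs.
    # attentions holds one (B, H, T, T) attention probability tensor per layer.
    return output.hidden_states, output.attentions
```

A typical call is tokenizer, model = load_eager_model() followed by hidden_states, attentions = extract_layers(tokenizer, model, ["A T C G G A GGG CCC TTT"]).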

Citation

BibTeX:

```bibtex
@article{xiong2025mrnabert,
  title={mRNABERT: advancing mRNA sequence design with a universal language model and comprehensive dataset},
  author={Xiong, Ying and Wang, Aowen and Kang, Yu and Shen, Chao and Hsieh, Chang-Yu and Hou, Tingjun},
  journal={Nature Communications},
  volume={16},
  number={1},
  pages={10371},
  year={2025},
  publisher={Nature Publishing Group UK London},
}
```

Contact

If you have any questions, please feel free to email us (xiongying@zju.edu.cn).
