Note:

This model is a copy of DNABERT-2-117M that fixes the FlashAttention integration with Triton (applying the solution from https://github.com/Dao-AILab/flash-attention/issues/508) and fixes the return of attention weights and hidden states from the model's forward function. The original DNABERT-2-117M model can be found at https://huggingface.co/zhihan1996/DNABERT-2-117M. If you use this model, please credit the original authors of DNABERT-2 and the MosaicML team for their implementation.

The only changes made were in flash_attn_triton.py and bert_layers.py.

In flash_attn_triton.py, the change was:

  1. qk += tl.dot(q, k, trans_b=True) became qk += tl.dot(q, tl.trans(k)), following the solution provided in the FlashAttention issue. The two other uses of the trans_b=True argument in the file were changed in the same way.
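
The rewrite described above is mechanical, so it can be illustrated as a small text substitution. This is only a sketch: patch_trans_b is a hypothetical helper that is not part of this repository, it handles only calls whose two operands are simple identifiers, and nothing here claims the actual change was made with a script.

```python
import re

# Matches tl.dot(<a>, <b>, trans_b=True) where <a> and <b> are plain identifiers.
# Assumption: more complex argument expressions are out of scope for this sketch.
TRANS_B_CALL = re.compile(
    r"tl\.dot\(\s*(\w+)\s*,\s*(\w+)\s*,\s*trans_b\s*=\s*True\s*\)"
)


def patch_trans_b(source: str) -> str:
    """Rewrite tl.dot(a, b, trans_b=True) as tl.dot(a, tl.trans(b))."""
    return TRANS_B_CALL.sub(r"tl.dot(\1, tl.trans(\2))", source)


before = "qk += tl.dot(q, k, trans_b=True)"
print(patch_trans_b(before))  # qk += tl.dot(q, tl.trans(k))
```

Calls that do not pass trans_b=True are left unchanged by the pattern.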

In bert_layers.py the changes were:

  1. use_flash_attn config flag (BertUnpadSelfAttention): Added self.use_flash_attn = getattr(config, 'use_flash_attn', True). Setting use_flash_attn: false in the model config forces the PyTorch eager attention path, enabling attention weight extraction without requiring Triton.

  2. Attention weight return (BertUnpadSelfAttention, BertUnpadAttention, BertLayer): Added a return_attn_weights: bool = False parameter threaded through the call chain. When enabled, the eager path returns the (B, H, T, T) attention probability tensor alongside the hidden states.

  3. HF-compatible encoder output (BertEncoder): Added output_attentions: bool = False. When output_all_encoded_layers=True, each layer's hidden states are now padded back to (B, T, D) before collection (previously unpadded (nnz, D)), and the embedding output is prepended as index 0 to match the HuggingFace hidden_states convention.

  4. Standard HuggingFace output objects (BertModel, BertForMaskedLM, BertForSequenceClassification): BertModel.forward now accepts output_hidden_states and output_attentions keyword arguments and returns a BaseModelOutputWithPooling object with .last_hidden_state, .pooler_output, .hidden_states, and .attentions fields. BertForMaskedLM and BertForSequenceClassification were updated accordingly to read from these named fields.
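
To make the shapes in the list above concrete, here is a minimal PyTorch sketch of the two ideas: an eager attention step that also returns the (B, H, T, T) probabilities, and padding unpadded (nnz, D) hidden states back to (B, T, D). It is not the code in bert_layers.py. The function names eager_attention and pad_back are hypothetical, and the sketch leaves out details of the real model such as the attention bias and dropout.

```python
import math

import torch


def eager_attention(q, k, v, attention_mask):
    """Plain PyTorch attention that also returns the attention probabilities.

    q, k, v: (B, H, T, d); attention_mask: (B, T) with 1 for real tokens.
    """
    scores = q @ k.transpose(-1, -2) / math.sqrt(q.size(-1))  # (B, H, T, T)
    # Padded key positions get -inf so they receive zero probability.
    key_is_pad = ~attention_mask[:, None, None, :].bool()
    probs = torch.softmax(scores.masked_fill(key_is_pad, float("-inf")), dim=-1)
    return probs @ v, probs


def pad_back(unpadded, attention_mask):
    """Scatter (nnz, D) token rows into a zero-padded (B, T, D) tensor."""
    B, T = attention_mask.shape
    padded = unpadded.new_zeros(B * T, unpadded.size(-1))
    indices = attention_mask.flatten().nonzero(as_tuple=True)[0]
    padded[indices] = unpadded
    return padded.view(B, T, -1)


torch.manual_seed(0)
B, H, T, d = 2, 4, 5, 8
mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
q, k, v = (torch.randn(B, H, T, d) for _ in range(3))

context, probs = eager_attention(q, k, v, mask)
print(probs.shape)  # torch.Size([2, 4, 5, 5])

unpadded = torch.randn(int(mask.sum()), 16)  # (nnz, D) with nnz = 8
padded = pad_back(unpadded, mask)
print(padded.shape)  # torch.Size([2, 5, 16])
```

A fused FlashAttention kernel does not materialize the (B, H, T, T) matrix, which is why the attention weights are described above as coming from the eager path.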

Original README:

This is the official pre-trained model introduced in DNABERT-2: Efficient Foundation Model and Benchmark For Multi-Species Genome.

We sincerely appreciate the MosaicML team for the MosaicBERT implementation, which serves as the base of DNABERT-2 development.

DNABERT-2 is a transformer-based genome foundation model trained on multi-species genomes.

To load the model from Hugging Face:

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
model = AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)

To calculate the embedding of a DNA sequence:

dna = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC"
inputs = tokenizer(dna, return_tensors="pt")["input_ids"]
hidden_states = model(inputs)[0] # [1, sequence_length, 768]

# embedding with mean pooling
embedding_mean = torch.mean(hidden_states[0], dim=0)
print(embedding_mean.shape) # expected: torch.Size([768])

# embedding with max pooling
embedding_max = torch.max(hidden_states[0], dim=0)[0]
print(embedding_max.shape) # expected: torch.Size([768])
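
The pooling above works on a single sequence. When several sequences are tokenized together with padding, a plain mean would also average the padded positions. The following is a minimal sketch of mean pooling that ignores padding, using random tensors in place of real model output. masked_mean_pool is a hypothetical helper, not part of DNABERT-2, and like the snippet above it still averages over any special tokens the tokenizer adds.

```python
import torch


def masked_mean_pool(hidden_states, attention_mask):
    """Mean over real tokens only.

    hidden_states: (B, T, D); attention_mask: (B, T) with 1 for real tokens.
    """
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (B, T, 1)
    summed = (hidden_states * mask).sum(dim=1)                   # (B, D)
    counts = mask.sum(dim=1).clamp(min=1.0)                      # avoid divide-by-zero
    return summed / counts


torch.manual_seed(0)
# Stand-in for model output: 2 sequences, 6 positions, hidden size 768.
dummy_hidden = torch.randn(2, 6, 768)
dummy_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]])

pooled = masked_mean_pool(dummy_hidden, dummy_mask)
print(pooled.shape)  # torch.Size([2, 768])
```

With a real batch, the attention_mask returned by the tokenizer would take the place of dummy_mask.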
