Note:
This model is a copy of DNABERT-2-117M that fixes the FlashAttention integration with Triton (specifically, it integrates the solution from https://github.com/Dao-AILab/flash-attention/issues/508). It also fixes the return of attention weights and hidden states in the model's forward function. The original DNABERT-2-117M model can be found at https://huggingface.co/zhihan1996/DNABERT-2-117M. If you use this model, please credit the original authors of DNABERT-2 and the MosaicML team for their implementation.
The only files changed are `flash_attn_triton.py` and `bert_layers.py`.
In `flash_attn_triton.py`, the change was to replace `qk += tl.dot(q, k, trans_b=True)` with `qk += tl.dot(q, tl.trans(k))`, following the solution provided in the FlashAttention issue. The two other uses of the `trans_b=True` argument in the file were changed to the same `tl.trans` form.
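Both forms compute the attention score tile $QK^\top$. The linked issue concerns newer Triton versions no longer accepting the `trans_b` keyword of `tl.dot`, so the fix makes the transpose explicit with `tl.trans`.

The NumPy sketch below is a CPU stand-in, not Triton code, and its tile sizes are illustrative. It checks that multiplying `q` by an explicitly transposed `k` gives the same scores as contracting both over the head dimension, which is what the old `trans_b=True` call expressed:

```python
import numpy as np

rng = np.random.default_rng(0)
BLOCK_M, BLOCK_N, HEAD_DIM = 4, 8, 16  # illustrative tile sizes, not the kernel's
q = rng.standard_normal((BLOCK_M, HEAD_DIM))
k = rng.standard_normal((BLOCK_N, HEAD_DIM))

# Old semantics, tl.dot(q, k, trans_b=True): contract q's head dim with k's head dim.
qk_old = np.einsum("md,nd->mn", q, k)

# New form, tl.dot(q, tl.trans(k)): transpose k to (HEAD_DIM, BLOCK_N), then multiply.
qk_new = np.dot(q, k.T)

print(qk_new.shape)  # (4, 8)
```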
In `bert_layers.py`, the changes were:
- `use_flash_attn` config flag (`BertUnpadSelfAttention`): Added `self.use_flash_attn = getattr(config, 'use_flash_attn', True)`. Setting `use_flash_attn: false` in the model config forces the PyTorch eager attention path, which allows attention weights to be extracted without requiring Triton.
- Attention weight return (`BertUnpadSelfAttention`, `BertUnpadAttention`, `BertLayer`): Added a `return_attn_weights: bool = False` parameter that is threaded through the call chain. When it is enabled, the eager path returns the `(B, H, T, T)` attention probability tensor alongside the hidden states.
- HF-compatible encoder output (`BertEncoder`): Added `output_attentions: bool = False`. When `output_all_encoded_layers=True`, each layer's hidden states are now padded back to `(B, T, D)` before collection (previously they were the unpadded `(nnz, D)`). The embedding output is also prepended as index 0 to match the HuggingFace `hidden_states` convention.
- Standard HuggingFace output objects (`BertModel`, `BertForMaskedLM`, `BertForSequenceClassification`): `BertModel.forward` now accepts `output_hidden_states` and `output_attentions` keyword arguments. It returns a `BaseModelOutputWithPooling` object with `.last_hidden_state`, `.pooler_output`, `.hidden_states`, and `.attentions` fields. `BertForMaskedLM` and `BertForSequenceClassification` were updated to read from these named fields.
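The following simplified NumPy sketch shows two things from the changes above:

- the `getattr` default of the `use_flash_attn` flag;
- what an eager attention path returns when `return_attn_weights` is enabled.

It is not the repository's code. It works on already padded `(B, H, T, d)` arrays and omits the ALiBi bias that DNABERT-2 adds to the attention scores. The helper names `wants_flash_attn` and `eager_attention` are hypothetical, and `SimpleNamespace` stands in for the HuggingFace config object.

```python
import numpy as np
from types import SimpleNamespace


def wants_flash_attn(config) -> bool:
    """Hypothetical helper mirroring getattr(config, 'use_flash_attn', True):
    a config without the key keeps the Triton path."""
    return getattr(config, "use_flash_attn", True)


def eager_attention(q, k, v, key_mask, return_attn_weights=False):
    """Hypothetical eager scaled dot-product attention (no ALiBi bias).

    q, k, v: (B, H, T, d) arrays; key_mask: (B, T), 1 for tokens, 0 for padding.
    Returns the (B, H, T, d) context and, if requested, (B, H, T, T) probabilities.
    """
    d = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d)  # (B, H, T, T)
    # Padded keys get a large negative bias, so softmax gives them zero weight.
    scores = scores + (1.0 - key_mask[:, None, None, :]) * -1e9
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    context = probs @ v
    return (context, probs) if return_attn_weights else context


rng = np.random.default_rng(0)
B, H, T, d = 2, 3, 5, 4
q, k, v = (rng.standard_normal((B, H, T, d)) for _ in range(3))
mask = np.array([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]], dtype=float)

config = SimpleNamespace(use_flash_attn=False)  # stand-in for the model config
if not wants_flash_attn(config):
    context, attn = eager_attention(q, k, v, mask, return_attn_weights=True)
    print(attn.shape)  # (2, 3, 5, 5)
```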
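The `BertEncoder` change turns unpadded `(nnz, D)` layer outputs back into `(B, T, D)` tensors and puts the embedding output first. Below is a minimal NumPy sketch of that layout. It assumes tokens were unpadded in row-major order of the attention mask (so that `unpadded == padded[mask]`) and that the embedding output is already `(B, T, D)`. The helpers `pad_to_batch` and `collect_hidden_states` are hypothetical names, not functions from `bert_layers.py`.

```python
import numpy as np


def pad_to_batch(unpadded, attention_mask):
    """Scatter (nnz, D) token states into a zero-padded (B, T, D) array.

    Assumes unpadding followed row-major order of attention_mask.
    """
    B, T = attention_mask.shape
    padded = np.zeros((B, T, unpadded.shape[-1]), dtype=unpadded.dtype)
    padded[attention_mask.astype(bool)] = unpadded
    return padded


def collect_hidden_states(embedding_output, layer_outputs_unpadded, attention_mask):
    """HF-style tuple: index 0 is the embedding output, index i is layer i."""
    return (embedding_output,) + tuple(
        pad_to_batch(h, attention_mask) for h in layer_outputs_unpadded
    )


# Toy example: batch of 2, max length 4, hidden size 3, two layers.
mask = np.array([[1, 1, 1, 1], [1, 1, 0, 0]])
nnz = int(mask.sum())  # 6 real tokens
rng = np.random.default_rng(0)
embeddings = rng.standard_normal((2, 4, 3))
layers = [rng.standard_normal((nnz, 3)) for _ in range(2)]

hidden_states = collect_hidden_states(embeddings, layers, mask)
print(len(hidden_states), hidden_states[1].shape)  # 3 (2, 4, 3)
```

With this layout, `hidden_states` has one more entry than there are layers, and padding positions in the layer outputs are zeros.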
Original README:
This is the official pre-trained model introduced in DNABERT-2: Efficient Foundation Model and Benchmark For Multi-Species Genome.
We sincerely thank the MosaicML team for the MosaicBERT implementation, which serves as the basis of DNABERT-2's development.
DNABERT-2 is a transformer-based genome foundation model trained on multi-species genomes.
To load the model from Hugging Face:

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
model = AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
```
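The `use_flash_attn` flag added in this copy must be on the config before the model is built. Setting it to false forces the PyTorch eager path, for example to extract attention weights without Triton. Below is a minimal sketch of one way to do that, under these assumptions:

- `load_dnabert2` is a hypothetical helper.
- `model_id` must name this repository, not `zhihan1996/DNABERT-2-117M`, whose code predates the flag. This copy's ID is not given here.
- The checkpoint's `config.json` loads into transformers' `BertConfig`.

```python
def load_dnabert2(model_id: str, use_flash_attn: bool = True):
    """Hypothetical helper: load tokenizer and model, optionally forcing eager attention."""
    # Imported lazily so the helper can be defined without transformers installed.
    from transformers import AutoModel, AutoTokenizer, BertConfig

    # Read config.json first so the flag is set before the layers are constructed.
    config = BertConfig.from_pretrained(model_id)
    # BertUnpadSelfAttention reads this via getattr(config, 'use_flash_attn', True).
    config.use_flash_attn = use_flash_attn

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_id, config=config, trust_remote_code=True)
    return tokenizer, model
```

Calling `load_dnabert2(model_id, use_flash_attn=False)` returns a tokenizer and model that can be used like the ones loaded above.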
To calculate the embedding of a DNA sequence:

```python
dna = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC"
inputs = tokenizer(dna, return_tensors="pt")["input_ids"]
hidden_states = model(inputs)[0]  # [1, sequence_length, 768]

# embedding with mean pooling
embedding_mean = torch.mean(hidden_states[0], dim=0)
print(embedding_mean.shape)  # torch.Size([768])

# embedding with max pooling
embedding_max = torch.max(hidden_states[0], dim=0)[0]
print(embedding_max.shape)  # torch.Size([768])
```
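With the patched `BertModel.forward`, hidden states and attention weights can be requested through the standard keyword arguments. Below is a minimal sketch under these assumptions:

- `model` and `tokenizer` come from this repository, with `use_flash_attn` set to false, since only the eager path returns attention weights.
- `attentions` follows the usual one-entry-per-layer HuggingFace layout.
- `get_internals` is a hypothetical helper name.

```python
def get_internals(model, tokenizer, dna):
    """Hypothetical helper: return (hidden_states, attentions) for one DNA sequence.

    hidden_states[0] is the embedding output and hidden_states[i] is the
    (B, T, D) output of layer i; each attentions entry is assumed to hold one
    layer's (B, H, T, T) attention probabilities.
    """
    input_ids = tokenizer(dna, return_tensors="pt")["input_ids"]
    outputs = model(input_ids, output_hidden_states=True, output_attentions=True)
    # BaseModelOutputWithPooling exposes named fields instead of positional tuples.
    return outputs.hidden_states, outputs.attentions
```

For inference, the call can be wrapped in `torch.no_grad()` as usual.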