Note:

This model is a copy of DNABERT-S that fixes the FlashAttention integration with Triton (specifically by applying the solution in https://github.com/Dao-AILab/flash-attention/issues/508) and corrects the return of attention weights and hidden states in the model's forward function. The original DNABERT-S model can be found at https://huggingface.co/zhihan1996/DNABERT-S. If you use this model, please provide attribution to the original authors of DNABERT-S and to the MosaicML team for their implementation.

The only changes made were in flash_attn_triton.py and bert_layers.py.

In flash_attn_triton.py, the change was:

  1. qk += tl.dot(q, k, trans_b=True) was replaced with qk += tl.dot(q, tl.trans(k)), following the solution in the FlashAttention issue linked above. The two other uses of the trans_b=True argument in the file were changed in the same way.
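The substitution is safe because both expressions compute the same matrix product. A minimal NumPy sketch of the equivalence (tile sizes are illustrative, not the kernel's actual block sizes):

```python
import numpy as np

# Both forms compute q @ k^T: trans_b=True transposed k inside tl.dot,
# while the patched code transposes k explicitly before the dot product.
rng = np.random.default_rng(0)
q = rng.standard_normal((16, 64))    # (BLOCK_M, HEAD_DIM) tile
k = rng.standard_normal((32, 64))    # (BLOCK_N, HEAD_DIM) tile

qk_old = np.dot(q, k.T)              # what tl.dot(q, k, trans_b=True) computed
qk_new = np.dot(q, np.transpose(k))  # the tl.dot(q, tl.trans(k)) replacement
assert np.allclose(qk_old, qk_new)   # identical (BLOCK_M, BLOCK_N) result
```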

In bert_layers.py the changes were:

  1. use_flash_attn config flag (BertUnpadSelfAttention): Added self.use_flash_attn = getattr(config, 'use_flash_attn', True). Setting use_flash_attn: false in the model config forces the PyTorch eager attention path, enabling attention weight extraction without requiring Triton.

  2. Attention weight return (BertUnpadSelfAttention, BertUnpadAttention, BertLayer): Added a return_attn_weights: bool = False parameter threaded through the call chain. When enabled, the eager path returns the (B, H, T, T) attention probability tensor alongside the hidden states.

  3. HF-compatible encoder output (BertEncoder): Added output_attentions: bool = False. When output_all_encoded_layers=True, each layer's hidden states are now padded back to (B, T, D) before collection (previously unpadded (nnz, D)), and the embedding output is prepended as index 0 to match the HuggingFace hidden_states convention.

  4. Standard HuggingFace output objects (BertModel, BertForMaskedLM, BertForSequenceClassification): BertModel.forward now accepts output_hidden_states and output_attentions keyword arguments and returns a BaseModelOutputWithPooling object with .last_hidden_state, .pooler_output, .hidden_states, and .attentions fields. BertForMaskedLM and BertForSequenceClassification were updated accordingly to read from these named fields.
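The flag-reading pattern in item 1 can be sketched as follows (the class body is an illustrative fragment, not the actual module):

```python
from types import SimpleNamespace

class BertUnpadSelfAttentionSketch:
    """Illustrative fragment: reads the use_flash_attn config flag
    with a default of True, as described in item 1 above."""
    def __init__(self, config):
        self.use_flash_attn = getattr(config, 'use_flash_attn', True)

    def forward_path(self):
        # Flag set to False -> PyTorch eager attention, no Triton needed.
        return 'triton-flash' if self.use_flash_attn else 'eager'

# Flag absent from the config -> defaults to the FlashAttention path.
assert BertUnpadSelfAttentionSketch(SimpleNamespace()).forward_path() == 'triton-flash'
# use_flash_attn=False -> eager path, attention weights extractable.
assert BertUnpadSelfAttentionSketch(SimpleNamespace(use_flash_attn=False)).forward_path() == 'eager'
```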
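As a rough sketch of what the eager path in item 2 returns, here is a minimal scaled-dot-product attention in NumPy that yields the (B, H, T, T) probability tensor alongside the output (no masking, dropout, or unpadding, which the real module handles):

```python
import numpy as np

def eager_attention(q, k, v):
    """Minimal eager attention: returns (output, attention probabilities),
    mirroring the return_attn_weights=True behavior described above."""
    d = q.shape[-1]
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d)  # (B, H, T, T)
    scores -= scores.max(axis=-1, keepdims=True)       # numerical stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)         # softmax over keys
    return probs @ v, probs

B, H, T, D = 2, 4, 8, 16
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((B, H, T, D)) for _ in range(3))
out, attn = eager_attention(q, k, v)
assert out.shape == (B, H, T, D)
assert attn.shape == (B, H, T, T)
```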
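The pad-back step in item 3 amounts to scattering the unpadded (nnz, D) rows into a zero-initialized (B, T, D) tensor using each token's position in the flattened (B*T) layout. A sketch (function and variable names are illustrative, not the repo's actual helpers):

```python
import numpy as np

def pad_back(unpadded, indices, batch, seqlen):
    """Scatter (nnz, D) unpadded hidden states back to (B, T, D).
    indices holds each real token's position in the flattened (B*T) layout;
    padding positions stay zero."""
    dim = unpadded.shape[-1]
    out = np.zeros((batch * seqlen, dim), dtype=unpadded.dtype)
    out[indices] = unpadded
    return out.reshape(batch, seqlen, dim)

# Two sequences of lengths 3 and 2 packed into nnz = 5 rows (seqlen = 4).
hidden = np.arange(5 * 4, dtype=np.float32).reshape(5, 4)  # (nnz, D)
indices = np.array([0, 1, 2, 4, 5])  # flat positions of the real tokens
padded = pad_back(hidden, indices, batch=2, seqlen=4)
assert padded.shape == (2, 4, 4)
assert np.all(padded[0, 3] == 0)     # padding position stays zero
```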

If you use this model, please provide attribution to:

@misc{zhou2024dnaberts,
      title={DNABERT-S: Learning Species-Aware DNA Embedding with Genome Foundation Models},
      author={Zhihan Zhou and Weimin Wu and Harrison Ho and Jiayi Wang and Lizhen Shi and Ramana V Davuluri and Zhong Wang and Han Liu},
      year={2024},
      eprint={2402.08777},
      archivePrefix={arXiv},
      primaryClass={q-bio.GN}
}

Original GitHub: https://github.com/MAGICS-LAB/DNABERT_S

Original HF repo: https://huggingface.co/zhihan1996/DNABERT-S

Original paper: https://arxiv.org/abs/2402.08777
