Note:
This model is a copy of DNABERT-S that fixes the FlashAttention integration with Triton (specifically, it integrates the solution from https://github.com/Dao-AILab/flash-attention/issues/508) and fixes the return of attention weights and hidden states in the model's forward function. The original DNABERT-S model can be found at https://huggingface.co/zhihan1996/DNABERT-S. If you use this model, please attribute the original authors of DNABERT-S and the MosaicML team for their implementation.
The only changes made were in flash_attn_triton.py and bert_layers.py.
In `flash_attn_triton.py`, the change was to replace

`qk += tl.dot(q, k, trans_b=True)`

with

`qk += tl.dot(q, tl.trans(k))`

according to the solution provided in the FlashAttention issue. The two other uses of the `trans_b=True` argument in the file were changed in the same way.
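As a sketch, the edit looks like the following (context abbreviated; the surrounding kernel code is unchanged). Newer Triton releases dropped the `trans_a`/`trans_b` keyword arguments from `tl.dot`, which is why the transpose is now done explicitly with `tl.trans`:

```diff
-        qk += tl.dot(q, k, trans_b=True)
+        qk += tl.dot(q, tl.trans(k))
```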
In bert_layers.py the changes were:
- **`use_flash_attn` config flag (`BertUnpadSelfAttention`):** Added `self.use_flash_attn = getattr(config, 'use_flash_attn', True)`. Setting `use_flash_attn: false` in the model config forces the PyTorch eager attention path, enabling attention weight extraction without requiring Triton.
- **Attention weight return (`BertUnpadSelfAttention`, `BertUnpadAttention`, `BertLayer`):** Added a `return_attn_weights: bool = False` parameter threaded through the call chain. When enabled, the eager path returns the `(B, H, T, T)` attention probability tensor alongside the hidden states.
- **HF-compatible encoder output (`BertEncoder`):** Added `output_attentions: bool = False`. When `output_all_encoded_layers=True`, each layer's hidden states are now padded back to `(B, T, D)` before collection (previously unpadded `(nnz, D)`), and the embedding output is prepended as index 0 to match the HuggingFace `hidden_states` convention.
- **Standard HuggingFace output objects (`BertModel`, `BertForMaskedLM`, `BertForSequenceClassification`):** `BertModel.forward` now accepts `output_hidden_states` and `output_attentions` keyword arguments and returns a `BaseModelOutputWithPooling` object with `.last_hidden_state`, `.pooler_output`, `.hidden_states`, and `.attentions` fields. `BertForMaskedLM` and `BertForSequenceClassification` were updated accordingly to read from these named fields.
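The pad-back step in the encoder can be illustrated with a minimal NumPy sketch (a stand-in for the actual PyTorch code in `bert_layers.py`; the function name here is illustrative, not from the repo). The unpadded representation stacks only the real tokens of the whole batch into an `(nnz, D)` matrix, and padding back scatters those rows into a zero-filled `(B, T, D)` tensor using the attention mask:

```python
import numpy as np

def pad_back(unpadded, attention_mask):
    """Scatter unpadded token rows (nnz, D) back into a padded (B, T, D) tensor.

    `attention_mask` is (B, T) with 1 for real tokens and 0 for padding;
    nnz is the total number of real tokens across the batch.
    """
    B, T = attention_mask.shape
    D = unpadded.shape[1]
    padded = np.zeros((B, T, D), dtype=unpadded.dtype)
    # Boolean mask indexing assigns the nnz rows in batch-major token order.
    padded[attention_mask.astype(bool)] = unpadded
    return padded

# Two sequences of lengths 3 and 2, hidden size 4 -> nnz = 5
mask = np.array([[1, 1, 1, 0],
                 [1, 1, 0, 0]])
unpadded = np.arange(5 * 4, dtype=np.float32).reshape(5, 4)
padded = pad_back(unpadded, mask)
print(padded.shape)  # (2, 4, 4)
```

Padding positions stay zero, so downstream code that expects the standard HuggingFace `(B, T, D)` layout for `hidden_states` works unchanged.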
If you use this model, please cite:
```bibtex
@misc{zhou2024dnaberts,
      title={DNABERT-S: Learning Species-Aware DNA Embedding with Genome Foundation Models},
      author={Zhihan Zhou and Weimin Wu and Harrison Ho and Jiayi Wang and Lizhen Shi and Ramana V Davuluri and Zhong Wang and Han Liu},
      year={2024},
      eprint={2402.08777},
      archivePrefix={arXiv},
      primaryClass={q-bio.GN}
}
```
- Original GitHub: https://github.com/MAGICS-LAB/DNABERT_S
- Original HF repo: https://huggingface.co/zhihan1996/DNABERT-S
- Original paper: https://arxiv.org/abs/2402.08777