add attention return + support eager attention or triton FA2 via config.use_flash_attn
Browse files- README.md +24 -3
- bert_layers.py +92 -39
- flash_attn_triton.py +3 -3
README.md
CHANGED
|
@@ -7,12 +7,33 @@ tags:
|
|
| 7 |
- medical
|
| 8 |
- genomics
|
| 9 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
This is the official pre-trained model introduced in [DNABERT-2: Efficient Foundation Model and Benchmark For Multi-Species Genome
|
| 11 |
](https://arxiv.org/pdf/2306.15006.pdf).
|
| 12 |
|
| 13 |
-
We sincerely appreciate the MosaicML team for the [MosaicBERT](https://openreview.net/forum?id=5zipcfLC2Z) implementation, which serves as the base of DNABERT-2 development.
|
| 14 |
|
| 15 |
-
DNABERT-2 is a transformer-based genome foundation model trained on multi-species genome.
|
| 16 |
|
| 17 |
To load the model from huggingface:
|
| 18 |
```
|
|
@@ -36,4 +57,4 @@ print(embedding_mean.shape) # expect to be 768
|
|
| 36 |
# embedding with max pooling
|
| 37 |
embedding_max = torch.max(hidden_states[0], dim=0)[0]
|
| 38 |
print(embedding_max.shape) # expect to be 768
|
| 39 |
-
```
|
|
|
|
| 7 |
- medical
|
| 8 |
- genomics
|
| 9 |
---
|
| 10 |
+
|
| 11 |
+
# Note:
|
| 12 |
+
This model is a copied version of DNABERT-2-117M which fixes the FlashAttention integration with Trition (specifically integrating the solution in: https://github.com/Dao-AILab/flash-attention/issues/508) as well as fixes the return of attention weights and hidden states in the forward function of the model. The original DNABERT-2-117M model can be found at https://huggingface.co/zhihan1996/DNABERT-2-117M. If you use this model please provide attribution to the original authors of DNABERT-2 and the MosaicML team for their implementation.
|
| 13 |
+
|
| 14 |
+
The only changes made were in `flash_attn_triton.py` and `bert_layers.py`.
|
| 15 |
+
|
| 16 |
+
In `flash_attn_triton.py`, the change was to alter:
|
| 17 |
+
|
| 18 |
+
1. ```qk += tl.dot(q, k, trans_b=True)``` to ```qk += tl.dot(q, tl.trans(k))``` according to the solution provided in the flash attention issue. There were 2 other instances of the use of this ```trans_b=True``` argument in the file which were also changed to use the same solution.
|
| 19 |
+
|
| 20 |
+
In `bert_layers.py` the changes were:
|
| 21 |
+
|
| 22 |
+
1. **`use_flash_attn` config flag** (`BertUnpadSelfAttention`): Added `self.use_flash_attn = getattr(config, 'use_flash_attn', True)`. Setting `use_flash_attn: false` in the model config forces the PyTorch eager attention path, enabling attention weight extraction without requiring Triton.
|
| 23 |
+
|
| 24 |
+
2. **Attention weight return** (`BertUnpadSelfAttention`, `BertUnpadAttention`, `BertLayer`): Added a `return_attn_weights: bool = False` parameter threaded through the call chain. When enabled, the eager path returns the `(B, H, T, T)` attention probability tensor alongside the hidden states.
|
| 25 |
+
|
| 26 |
+
3. **HF-compatible encoder output** (`BertEncoder`): Added `output_attentions: bool = False`. When `output_all_encoded_layers=True`, each layer's hidden states are now padded back to `(B, T, D)` before collection (previously unpadded `(nnz, D)`), and the embedding output is prepended as index 0 to match the HuggingFace `hidden_states` convention.
|
| 27 |
+
|
| 28 |
+
4. **Standard HuggingFace output objects** (`BertModel`, `BertForMaskedLM`, `BertForSequenceClassification`): `BertModel.forward` now accepts `output_hidden_states` and `output_attentions` keyword arguments and returns a `BaseModelOutputWithPooling` object with `.last_hidden_state`, `.pooler_output`, `.hidden_states`, and `.attentions` fields. `BertForMaskedLM` and `BertForSequenceClassification` were updated accordingly to read from these named fields.
|
| 29 |
+
|
| 30 |
+
# Original README:
|
| 31 |
This is the official pre-trained model introduced in [DNABERT-2: Efficient Foundation Model and Benchmark For Multi-Species Genome
|
| 32 |
](https://arxiv.org/pdf/2306.15006.pdf).
|
| 33 |
|
| 34 |
+
We sincerely appreciate the MosaicML team for the [MosaicBERT](https://openreview.net/forum?id=5zipcfLC2Z) implementation, which serves as the base of DNABERT-2 development.
|
| 35 |
|
| 36 |
+
DNABERT-2 is a transformer-based genome foundation model trained on multi-species genome.
|
| 37 |
|
| 38 |
To load the model from huggingface:
|
| 39 |
```
|
|
|
|
| 57 |
# embedding with max pooling
|
| 58 |
embedding_max = torch.max(hidden_states[0], dim=0)[0]
|
| 59 |
print(embedding_max.shape) # expect to be 768
|
| 60 |
+
```
|
bert_layers.py
CHANGED
|
@@ -16,7 +16,8 @@ import torch.nn as nn
|
|
| 16 |
from einops import rearrange
|
| 17 |
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
| 18 |
from transformers.activations import ACT2FN
|
| 19 |
-
from transformers.modeling_outputs import (
|
|
|
|
| 20 |
SequenceClassifierOutput)
|
| 21 |
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
| 22 |
from transformers.modeling_utils import PreTrainedModel
|
|
@@ -120,6 +121,7 @@ class BertUnpadSelfAttention(nn.Module):
|
|
| 120 |
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 121 |
self.p_dropout = config.attention_probs_dropout_prob
|
| 122 |
self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
|
|
|
|
| 123 |
|
| 124 |
# Warn if defaulting to pytorch because of import issues
|
| 125 |
if flash_attn_qkvpacked_func is None:
|
|
@@ -129,7 +131,8 @@ class BertUnpadSelfAttention(nn.Module):
|
|
| 129 |
|
| 130 |
def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
|
| 131 |
max_seqlen_in_batch: int, indices: torch.Tensor,
|
| 132 |
-
attn_mask: torch.Tensor, bias: torch.Tensor
|
|
|
|
| 133 |
"""Perform self-attention.
|
| 134 |
|
| 135 |
If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
|
|
@@ -158,7 +161,7 @@ class BertUnpadSelfAttention(nn.Module):
|
|
| 158 |
'b s (t h d) -> b s t h d',
|
| 159 |
t=3,
|
| 160 |
h=self.num_attention_heads)
|
| 161 |
-
if self.p_dropout or flash_attn_qkvpacked_func is None:
|
| 162 |
# if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
|
| 163 |
q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
|
| 164 |
k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
|
|
@@ -172,6 +175,7 @@ class BertUnpadSelfAttention(nn.Module):
|
|
| 172 |
3) # b s h d
|
| 173 |
else:
|
| 174 |
# Triton implementation only supports 0 attention dropout
|
|
|
|
| 175 |
convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
|
| 176 |
if convert_dtype:
|
| 177 |
# Triton implementation only supports fp16 and bf16
|
|
@@ -187,7 +191,10 @@ class BertUnpadSelfAttention(nn.Module):
|
|
| 187 |
|
| 188 |
# attn_mask is 1 for attend and 0 for don't
|
| 189 |
attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
|
| 193 |
# Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
|
|
@@ -225,6 +232,7 @@ class BertUnpadAttention(nn.Module):
|
|
| 225 |
indices: Optional[torch.Tensor] = None,
|
| 226 |
attn_mask: Optional[torch.Tensor] = None,
|
| 227 |
bias: Optional[torch.Tensor] = None,
|
|
|
|
| 228 |
) -> torch.Tensor:
|
| 229 |
"""Forward pass for scaled self-attention without padding.
|
| 230 |
|
|
@@ -237,14 +245,24 @@ class BertUnpadAttention(nn.Module):
|
|
| 237 |
indices: None or (total_nnz,)
|
| 238 |
attn_mask: None or (batch, max_seqlen_in_batch)
|
| 239 |
bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
|
|
|
|
| 240 |
"""
|
| 241 |
-
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
if subset_idx is not None:
|
| 244 |
-
|
| 245 |
-
|
| 246 |
else:
|
| 247 |
-
|
|
|
|
|
|
|
|
|
|
| 248 |
|
| 249 |
|
| 250 |
class BertGatedLinearUnitMLP(nn.Module):
|
|
@@ -312,6 +330,7 @@ class BertLayer(nn.Module):
|
|
| 312 |
indices: Optional[torch.Tensor] = None,
|
| 313 |
attn_mask: Optional[torch.Tensor] = None,
|
| 314 |
bias: Optional[torch.Tensor] = None,
|
|
|
|
| 315 |
) -> torch.Tensor:
|
| 316 |
"""Forward pass for a BERT layer, including both attention and MLP.
|
| 317 |
|
|
@@ -324,10 +343,19 @@ class BertLayer(nn.Module):
|
|
| 324 |
indices: None or (total_nnz,)
|
| 325 |
attn_mask: None or (batch, max_seqlen_in_batch)
|
| 326 |
bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
|
|
|
|
| 327 |
"""
|
| 328 |
-
|
| 329 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
layer_output = self.mlp(attention_output)
|
|
|
|
|
|
|
| 331 |
return layer_output
|
| 332 |
|
| 333 |
|
|
@@ -410,7 +438,8 @@ class BertEncoder(nn.Module):
|
|
| 410 |
attention_mask: torch.Tensor,
|
| 411 |
output_all_encoded_layers: Optional[bool] = True,
|
| 412 |
subset_mask: Optional[torch.Tensor] = None,
|
| 413 |
-
|
|
|
|
| 414 |
|
| 415 |
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
| 416 |
extended_attention_mask = extended_attention_mask.to(
|
|
@@ -419,6 +448,12 @@ class BertEncoder(nn.Module):
|
|
| 419 |
|
| 420 |
attention_mask_bool = attention_mask.bool()
|
| 421 |
batch, seqlen = hidden_states.shape[:2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
# Unpad inputs and mask. It will remove tokens that are padded.
|
| 423 |
# Assume ntokens is total number of tokens (padded and non-padded)
|
| 424 |
# and ntokens_unpad is total number of non-padded tokens.
|
|
@@ -442,17 +477,27 @@ class BertEncoder(nn.Module):
|
|
| 442 |
alibi_attn_mask = attn_bias + alibi_bias
|
| 443 |
|
| 444 |
all_encoder_layers = []
|
|
|
|
| 445 |
if subset_mask is None:
|
| 446 |
for layer_module in self.layer:
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
if output_all_encoded_layers:
|
| 455 |
-
|
|
|
|
|
|
|
| 456 |
# Pad inputs and mask. It will insert back zero-padded tokens.
|
| 457 |
# Assume ntokens is total number of tokens (padded and non-padded)
|
| 458 |
# and ntokens_unpad is total number of non-padded tokens.
|
|
@@ -483,7 +528,13 @@ class BertEncoder(nn.Module):
|
|
| 483 |
|
| 484 |
if not output_all_encoded_layers:
|
| 485 |
all_encoder_layers.append(hidden_states)
|
| 486 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
|
| 488 |
|
| 489 |
class BertPooler(nn.Module):
|
|
@@ -586,8 +637,10 @@ class BertModel(BertPreTrainedModel):
|
|
| 586 |
position_ids: Optional[torch.Tensor] = None,
|
| 587 |
output_all_encoded_layers: Optional[bool] = False,
|
| 588 |
masked_tokens_mask: Optional[torch.Tensor] = None,
|
|
|
|
|
|
|
| 589 |
**kwargs
|
| 590 |
-
) ->
|
| 591 |
if attention_mask is None:
|
| 592 |
attention_mask = torch.ones_like(input_ids)
|
| 593 |
if token_type_ids is None:
|
|
@@ -606,11 +659,12 @@ class BertModel(BertPreTrainedModel):
|
|
| 606 |
first_col_mask[:, 0] = True
|
| 607 |
subset_mask = masked_tokens_mask | first_col_mask
|
| 608 |
|
| 609 |
-
encoder_outputs = self.encoder(
|
| 610 |
embedding_output,
|
| 611 |
attention_mask,
|
| 612 |
-
output_all_encoded_layers=
|
| 613 |
-
subset_mask=subset_mask
|
|
|
|
| 614 |
|
| 615 |
if masked_tokens_mask is None:
|
| 616 |
sequence_output = encoder_outputs[-1]
|
|
@@ -629,13 +683,12 @@ class BertModel(BertPreTrainedModel):
|
|
| 629 |
else:
|
| 630 |
pooled_output = None
|
| 631 |
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
return encoder_outputs, None
|
| 639 |
|
| 640 |
|
| 641 |
###################
|
|
@@ -755,8 +808,8 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|
| 755 |
return_dict=return_dict,
|
| 756 |
masked_tokens_mask=masked_tokens_mask,
|
| 757 |
)
|
| 758 |
-
|
| 759 |
-
sequence_output = outputs
|
| 760 |
prediction_scores = self.cls(sequence_output)
|
| 761 |
|
| 762 |
loss = None
|
|
@@ -782,8 +835,8 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|
| 782 |
return MaskedLMOutput(
|
| 783 |
loss=loss,
|
| 784 |
logits=prediction_scores,
|
| 785 |
-
hidden_states=outputs
|
| 786 |
-
attentions=
|
| 787 |
)
|
| 788 |
|
| 789 |
def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
|
|
@@ -868,7 +921,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
| 868 |
return_dict=return_dict,
|
| 869 |
)
|
| 870 |
|
| 871 |
-
pooled_output = outputs
|
| 872 |
|
| 873 |
pooled_output = self.dropout(pooled_output)
|
| 874 |
logits = self.classifier(pooled_output)
|
|
@@ -906,7 +959,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
| 906 |
return SequenceClassifierOutput(
|
| 907 |
loss=loss,
|
| 908 |
logits=logits,
|
| 909 |
-
hidden_states=outputs
|
| 910 |
-
attentions=
|
| 911 |
)
|
| 912 |
|
|
|
|
| 16 |
from einops import rearrange
|
| 17 |
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
| 18 |
from transformers.activations import ACT2FN
|
| 19 |
+
from transformers.modeling_outputs import (BaseModelOutputWithPooling,
|
| 20 |
+
MaskedLMOutput,
|
| 21 |
SequenceClassifierOutput)
|
| 22 |
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
| 23 |
from transformers.modeling_utils import PreTrainedModel
|
|
|
|
| 121 |
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 122 |
self.p_dropout = config.attention_probs_dropout_prob
|
| 123 |
self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
|
| 124 |
+
self.use_flash_attn = getattr(config, 'use_flash_attn', True)
|
| 125 |
|
| 126 |
# Warn if defaulting to pytorch because of import issues
|
| 127 |
if flash_attn_qkvpacked_func is None:
|
|
|
|
| 131 |
|
| 132 |
def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
|
| 133 |
max_seqlen_in_batch: int, indices: torch.Tensor,
|
| 134 |
+
attn_mask: torch.Tensor, bias: torch.Tensor,
|
| 135 |
+
return_attn_weights: bool = False) -> torch.Tensor:
|
| 136 |
"""Perform self-attention.
|
| 137 |
|
| 138 |
If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
|
|
|
|
| 161 |
'b s (t h d) -> b s t h d',
|
| 162 |
t=3,
|
| 163 |
h=self.num_attention_heads)
|
| 164 |
+
if self.p_dropout or flash_attn_qkvpacked_func is None or not self.use_flash_attn or return_attn_weights:
|
| 165 |
# if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
|
| 166 |
q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
|
| 167 |
k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
|
|
|
|
| 175 |
3) # b s h d
|
| 176 |
else:
|
| 177 |
# Triton implementation only supports 0 attention dropout
|
| 178 |
+
attention_probs = None
|
| 179 |
convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
|
| 180 |
if convert_dtype:
|
| 181 |
# Triton implementation only supports fp16 and bf16
|
|
|
|
| 191 |
|
| 192 |
# attn_mask is 1 for attend and 0 for don't
|
| 193 |
attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
|
| 194 |
+
out = rearrange(attention, 'nnz h d -> nnz (h d)')
|
| 195 |
+
if return_attn_weights:
|
| 196 |
+
return out, attention_probs # (nnz, D), (B, H, T, T)
|
| 197 |
+
return out
|
| 198 |
|
| 199 |
|
| 200 |
# Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
|
|
|
|
| 232 |
indices: Optional[torch.Tensor] = None,
|
| 233 |
attn_mask: Optional[torch.Tensor] = None,
|
| 234 |
bias: Optional[torch.Tensor] = None,
|
| 235 |
+
return_attn_weights: bool = False,
|
| 236 |
) -> torch.Tensor:
|
| 237 |
"""Forward pass for scaled self-attention without padding.
|
| 238 |
|
|
|
|
| 245 |
indices: None or (total_nnz,)
|
| 246 |
attn_mask: None or (batch, max_seqlen_in_batch)
|
| 247 |
bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
|
| 248 |
+
return_attn_weights: If True, return attention probabilities alongside output.
|
| 249 |
"""
|
| 250 |
+
if return_attn_weights:
|
| 251 |
+
self_output, attn_probs = self.self(
|
| 252 |
+
input_tensor, cu_seqlens, max_s, indices, attn_mask, bias,
|
| 253 |
+
return_attn_weights=True)
|
| 254 |
+
else:
|
| 255 |
+
self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
|
| 256 |
+
attn_mask, bias)
|
| 257 |
+
attn_probs = None
|
| 258 |
if subset_idx is not None:
|
| 259 |
+
output = self.output(index_first_axis(self_output, subset_idx),
|
| 260 |
+
index_first_axis(input_tensor, subset_idx))
|
| 261 |
else:
|
| 262 |
+
output = self.output(self_output, input_tensor)
|
| 263 |
+
if return_attn_weights:
|
| 264 |
+
return output, attn_probs
|
| 265 |
+
return output
|
| 266 |
|
| 267 |
|
| 268 |
class BertGatedLinearUnitMLP(nn.Module):
|
|
|
|
| 330 |
indices: Optional[torch.Tensor] = None,
|
| 331 |
attn_mask: Optional[torch.Tensor] = None,
|
| 332 |
bias: Optional[torch.Tensor] = None,
|
| 333 |
+
return_attn_weights: bool = False,
|
| 334 |
) -> torch.Tensor:
|
| 335 |
"""Forward pass for a BERT layer, including both attention and MLP.
|
| 336 |
|
|
|
|
| 343 |
indices: None or (total_nnz,)
|
| 344 |
attn_mask: None or (batch, max_seqlen_in_batch)
|
| 345 |
bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
|
| 346 |
+
return_attn_weights: If True, return attention probabilities alongside output.
|
| 347 |
"""
|
| 348 |
+
if return_attn_weights:
|
| 349 |
+
attention_output, attn_probs = self.attention(
|
| 350 |
+
hidden_states, cu_seqlens, seqlen, subset_idx, indices,
|
| 351 |
+
attn_mask, bias, return_attn_weights=True)
|
| 352 |
+
else:
|
| 353 |
+
attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
|
| 354 |
+
subset_idx, indices, attn_mask, bias)
|
| 355 |
+
attn_probs = None
|
| 356 |
layer_output = self.mlp(attention_output)
|
| 357 |
+
if return_attn_weights:
|
| 358 |
+
return layer_output, attn_probs
|
| 359 |
return layer_output
|
| 360 |
|
| 361 |
|
|
|
|
| 438 |
attention_mask: torch.Tensor,
|
| 439 |
output_all_encoded_layers: Optional[bool] = True,
|
| 440 |
subset_mask: Optional[torch.Tensor] = None,
|
| 441 |
+
output_attentions: bool = False,
|
| 442 |
+
) -> Tuple[List[torch.Tensor], Optional[Tuple[torch.Tensor, ...]]]:
|
| 443 |
|
| 444 |
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
| 445 |
extended_attention_mask = extended_attention_mask.to(
|
|
|
|
| 448 |
|
| 449 |
attention_mask_bool = attention_mask.bool()
|
| 450 |
batch, seqlen = hidden_states.shape[:2]
|
| 451 |
+
|
| 452 |
+
# Capture padded embedding output (B, T, D) before unpadding, so it
|
| 453 |
+
# can be prepended to all_encoder_layers as hidden_states index 0 in
|
| 454 |
+
# the HF convention (embedding = index 0, layer i = index i+1).
|
| 455 |
+
padded_embedding = hidden_states
|
| 456 |
+
|
| 457 |
# Unpad inputs and mask. It will remove tokens that are padded.
|
| 458 |
# Assume ntokens is total number of tokens (padded and non-padded)
|
| 459 |
# and ntokens_unpad is total number of non-padded tokens.
|
|
|
|
| 477 |
alibi_attn_mask = attn_bias + alibi_bias
|
| 478 |
|
| 479 |
all_encoder_layers = []
|
| 480 |
+
all_attention_probs: List[torch.Tensor] = []
|
| 481 |
if subset_mask is None:
|
| 482 |
for layer_module in self.layer:
|
| 483 |
+
if output_attentions:
|
| 484 |
+
hidden_states, attn_probs = layer_module(
|
| 485 |
+
hidden_states, cu_seqlens, seqlen, None, indices,
|
| 486 |
+
attn_mask=attention_mask, bias=alibi_attn_mask,
|
| 487 |
+
return_attn_weights=True)
|
| 488 |
+
all_attention_probs.append(attn_probs)
|
| 489 |
+
else:
|
| 490 |
+
hidden_states = layer_module(hidden_states,
|
| 491 |
+
cu_seqlens,
|
| 492 |
+
seqlen,
|
| 493 |
+
None,
|
| 494 |
+
indices,
|
| 495 |
+
attn_mask=attention_mask,
|
| 496 |
+
bias=alibi_attn_mask)
|
| 497 |
if output_all_encoded_layers:
|
| 498 |
+
# Pad back to (B, T, D) so callers get consistent shapes.
|
| 499 |
+
all_encoder_layers.append(
|
| 500 |
+
pad_input(hidden_states, indices, batch, seqlen))
|
| 501 |
# Pad inputs and mask. It will insert back zero-padded tokens.
|
| 502 |
# Assume ntokens is total number of tokens (padded and non-padded)
|
| 503 |
# and ntokens_unpad is total number of non-padded tokens.
|
|
|
|
| 528 |
|
| 529 |
if not output_all_encoded_layers:
|
| 530 |
all_encoder_layers.append(hidden_states)
|
| 531 |
+
else:
|
| 532 |
+
# Prepend padded embedding as index 0 to match HF convention:
|
| 533 |
+
# hidden_states[0] = embedding, hidden_states[i+1] = layer i output.
|
| 534 |
+
all_encoder_layers.insert(0, padded_embedding)
|
| 535 |
+
|
| 536 |
+
attn_out = tuple(all_attention_probs) if output_attentions else None
|
| 537 |
+
return all_encoder_layers, attn_out
|
| 538 |
|
| 539 |
|
| 540 |
class BertPooler(nn.Module):
|
|
|
|
| 637 |
position_ids: Optional[torch.Tensor] = None,
|
| 638 |
output_all_encoded_layers: Optional[bool] = False,
|
| 639 |
masked_tokens_mask: Optional[torch.Tensor] = None,
|
| 640 |
+
output_hidden_states: bool = False,
|
| 641 |
+
output_attentions: bool = False,
|
| 642 |
**kwargs
|
| 643 |
+
) -> BaseModelOutputWithPooling:
|
| 644 |
if attention_mask is None:
|
| 645 |
attention_mask = torch.ones_like(input_ids)
|
| 646 |
if token_type_ids is None:
|
|
|
|
| 659 |
first_col_mask[:, 0] = True
|
| 660 |
subset_mask = masked_tokens_mask | first_col_mask
|
| 661 |
|
| 662 |
+
encoder_outputs, all_attentions = self.encoder(
|
| 663 |
embedding_output,
|
| 664 |
attention_mask,
|
| 665 |
+
output_all_encoded_layers=output_hidden_states,
|
| 666 |
+
subset_mask=subset_mask,
|
| 667 |
+
output_attentions=output_attentions)
|
| 668 |
|
| 669 |
if masked_tokens_mask is None:
|
| 670 |
sequence_output = encoder_outputs[-1]
|
|
|
|
| 683 |
else:
|
| 684 |
pooled_output = None
|
| 685 |
|
| 686 |
+
return BaseModelOutputWithPooling(
|
| 687 |
+
last_hidden_state=sequence_output,
|
| 688 |
+
pooler_output=pooled_output,
|
| 689 |
+
hidden_states=tuple(encoder_outputs) if output_hidden_states else None,
|
| 690 |
+
attentions=all_attentions,
|
| 691 |
+
)
|
|
|
|
| 692 |
|
| 693 |
|
| 694 |
###################
|
|
|
|
| 808 |
return_dict=return_dict,
|
| 809 |
masked_tokens_mask=masked_tokens_mask,
|
| 810 |
)
|
| 811 |
+
|
| 812 |
+
sequence_output = outputs.last_hidden_state
|
| 813 |
prediction_scores = self.cls(sequence_output)
|
| 814 |
|
| 815 |
loss = None
|
|
|
|
| 835 |
return MaskedLMOutput(
|
| 836 |
loss=loss,
|
| 837 |
logits=prediction_scores,
|
| 838 |
+
hidden_states=outputs.hidden_states,
|
| 839 |
+
attentions=outputs.attentions,
|
| 840 |
)
|
| 841 |
|
| 842 |
def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
|
|
|
|
| 921 |
return_dict=return_dict,
|
| 922 |
)
|
| 923 |
|
| 924 |
+
pooled_output = outputs.pooler_output
|
| 925 |
|
| 926 |
pooled_output = self.dropout(pooled_output)
|
| 927 |
logits = self.classifier(pooled_output)
|
|
|
|
| 959 |
return SequenceClassifierOutput(
|
| 960 |
loss=loss,
|
| 961 |
logits=logits,
|
| 962 |
+
hidden_states=outputs.hidden_states,
|
| 963 |
+
attentions=outputs.attentions,
|
| 964 |
)
|
| 965 |
|
flash_attn_triton.py
CHANGED
|
@@ -188,7 +188,7 @@ def _fwd_kernel(
|
|
| 188 |
(offs_d[None, :] < headdim),
|
| 189 |
other=0.0)
|
| 190 |
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 191 |
-
qk += tl.dot(q, k
|
| 192 |
# Trying to combine the two masks seem to make the result wrong
|
| 193 |
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 194 |
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0,
|
|
@@ -431,7 +431,7 @@ def _bwd_kernel_one_col_block(
|
|
| 431 |
(offs_d[None, :] < headdim),
|
| 432 |
other=0.0)
|
| 433 |
# recompute p = softmax(qk, dim=-1).T
|
| 434 |
-
qk = tl.dot(q, k
|
| 435 |
# Trying to combine the two masks seem to make the result wrong
|
| 436 |
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 437 |
qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf'))
|
|
@@ -498,7 +498,7 @@ def _bwd_kernel_one_col_block(
|
|
| 498 |
# Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
|
| 499 |
if not (EVEN_M & EVEN_HEADDIM):
|
| 500 |
tl.debug_barrier()
|
| 501 |
-
dp = tl.dot(do, v
|
| 502 |
# There's a race condition for headdim=48
|
| 503 |
if not EVEN_HEADDIM:
|
| 504 |
tl.debug_barrier()
|
|
|
|
| 188 |
(offs_d[None, :] < headdim),
|
| 189 |
other=0.0)
|
| 190 |
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 191 |
+
qk += tl.dot(q, tl.trans(k)) # see issue: https://github.com/Dao-AILab/flash-attention/issues/508
|
| 192 |
# Trying to combine the two masks seem to make the result wrong
|
| 193 |
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 194 |
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0,
|
|
|
|
| 431 |
(offs_d[None, :] < headdim),
|
| 432 |
other=0.0)
|
| 433 |
# recompute p = softmax(qk, dim=-1).T
|
| 434 |
+
qk = tl.dot(q, tl.trans(k)) # see issue: https://github.com/Dao-AILab/flash-attention/issues/508
|
| 435 |
# Trying to combine the two masks seem to make the result wrong
|
| 436 |
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 437 |
qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf'))
|
|
|
|
| 498 |
# Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
|
| 499 |
if not (EVEN_M & EVEN_HEADDIM):
|
| 500 |
tl.debug_barrier()
|
| 501 |
+
dp = tl.dot(do, tl.trans(v)) # see issue: https://github.com/Dao-AILab/flash-attention/issues/508
|
| 502 |
# There's a race condition for headdim=48
|
| 503 |
if not EVEN_HEADDIM:
|
| 504 |
tl.debug_barrier()
|