Taykhoom committed on
Commit
f2409f7
·
verified ·
1 Parent(s): ea63a23

add attention return + support eager attention or triton FA2 via config.use_flash_attn

Files changed (3)
  1. README.md +24 -3
  2. bert_layers.py +92 -39
  3. flash_attn_triton.py +3 -3
README.md CHANGED
@@ -7,12 +7,33 @@ tags:
7
  - medical
8
  - genomics
9
  ---
10
  This is the official pre-trained model introduced in [DNABERT-2: Efficient Foundation Model and Benchmark For Multi-Species Genome
11
  ](https://arxiv.org/pdf/2306.15006.pdf).
12
 
13
- We sincerely appreciate the MosaicML team for the [MosaicBERT](https://openreview.net/forum?id=5zipcfLC2Z) implementation, which serves as the base of DNABERT-2 development.
14
 
15
- DNABERT-2 is a transformer-based genome foundation model trained on multi-species genome.
16
 
17
  To load the model from huggingface:
18
  ```
@@ -36,4 +57,4 @@ print(embedding_mean.shape) # expect to be 768
36
  # embedding with max pooling
37
  embedding_max = torch.max(hidden_states[0], dim=0)[0]
38
  print(embedding_max.shape) # expect to be 768
39
- ```
 
7
  - medical
8
  - genomics
9
  ---
10
+
11
+ # Note:
12
+ This model is a copy of DNABERT-2-117M that fixes the FlashAttention integration with Triton (specifically, it integrates the solution from https://github.com/Dao-AILab/flash-attention/issues/508) and corrects the return of attention weights and hidden states in the model's forward function. The original DNABERT-2-117M model can be found at https://huggingface.co/zhihan1996/DNABERT-2-117M. If you use this model, please credit the original authors of DNABERT-2 and the MosaicML team for their implementation.
13
+
14
+ The only changes made were in `flash_attn_triton.py` and `bert_layers.py`.
15
+
16
+ In `flash_attn_triton.py`, the change was:
17
+
18
+ 1. Changed ```qk += tl.dot(q, k, trans_b=True)``` to ```qk += tl.dot(q, tl.trans(k))```, following the solution provided in the flash-attention issue. The two other uses of the ```trans_b=True``` argument in the file were changed the same way.
19
+
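The fix relies on a simple identity: `tl.dot(q, k, trans_b=True)` computed `q @ kᵀ`, and newer Triton releases dropped the `trans_b` keyword, so the transpose must be made explicit with `tl.trans(k)`. A minimal sketch of the equivalence, with NumPy standing in for Triton (the real kernel requires a GPU):

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.standard_normal((16, 64)).astype(np.float32)  # (BLOCK_M, headdim)
k = rng.standard_normal((32, 64)).astype(np.float32)  # (BLOCK_N, headdim)

# Old Triton API: tl.dot(q, k, trans_b=True)  -> q @ k.T
# New Triton API: tl.dot(q, tl.trans(k))      -> same product, explicit transpose
qk_old_style = q @ k.T
qk_new_style = q @ np.transpose(k)

assert np.allclose(qk_old_style, qk_new_style)
print(qk_new_style.shape)  # (16, 32)
```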
20
+ In `bert_layers.py`, the changes were:
21
+
22
+ 1. **`use_flash_attn` config flag** (`BertUnpadSelfAttention`): Added `self.use_flash_attn = getattr(config, 'use_flash_attn', True)`. Setting `use_flash_attn: false` in the model config forces the PyTorch eager attention path, enabling attention weight extraction without requiring Triton.
23
+
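The gating pattern is ordinary `getattr` with a default, so configs that predate the flag still take the Triton path. A sketch using a hypothetical minimal config object (the real model reads the flag off its HuggingFace config the same way):

```python
# Hypothetical stand-in for a HuggingFace config object.
class DummyConfig:
    pass  # no use_flash_attn attribute set

cfg = DummyConfig()
# Default of True preserves the original behavior for old configs.
use_flash_attn = getattr(cfg, 'use_flash_attn', True)
print(use_flash_attn)  # True

cfg.use_flash_attn = False  # e.g. config.use_flash_attn = False
print(getattr(cfg, 'use_flash_attn', True))  # False -> forces eager attention
```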
24
+ 2. **Attention weight return** (`BertUnpadSelfAttention`, `BertUnpadAttention`, `BertLayer`): Added a `return_attn_weights: bool = False` parameter threaded through the call chain. When enabled, the eager path returns the `(B, H, T, T)` attention probability tensor alongside the hidden states.
25
+
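The eager path can expose its attention probabilities essentially for free, since they are materialized anyway. A minimal NumPy sketch of an attention function with this optional return (not the model's actual code, which also handles unpadding, dropout, and the ALiBi bias):

```python
import numpy as np

def eager_attention(q, k, v, return_attn_weights=False):
    """Plain softmax attention. q, k, v: (B, H, T, d)."""
    d = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d)  # (B, H, T, T)
    scores -= scores.max(axis=-1, keepdims=True)      # numerical stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    out = probs @ v                                   # (B, H, T, d)
    if return_attn_weights:
        return out, probs                             # probs is (B, H, T, T)
    return out

B, H, T, d = 2, 4, 8, 16
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((B, H, T, d)) for _ in range(3))
out, probs = eager_attention(q, k, v, return_attn_weights=True)
print(out.shape, probs.shape)  # (2, 4, 8, 16) (2, 4, 8, 8)
```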
26
+ 3. **HF-compatible encoder output** (`BertEncoder`): Added `output_attentions: bool = False`. When `output_all_encoded_layers=True`, each layer's hidden states are now padded back to `(B, T, D)` before collection (previously unpadded `(nnz, D)`), and the embedding output is prepended as index 0 to match the HuggingFace `hidden_states` convention.
27
+
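The shape change matters because MosaicBERT runs layers on unpadded `(nnz, D)` token tensors, while HuggingFace callers expect per-layer `(B, T, D)` tensors with the embedding output at index 0. A NumPy sketch of the padding and prepending, assuming a hypothetical `pad_back` helper in place of the repo's `pad_input`:

```python
import numpy as np

def pad_back(unpadded, indices, batch, seqlen):
    """Re-insert zero rows for padded tokens: (nnz, D) -> (B, T, D)."""
    dim = unpadded.shape[-1]
    flat = np.zeros((batch * seqlen, dim), dtype=unpadded.dtype)
    flat[indices] = unpadded  # scatter real tokens back to their positions
    return flat.reshape(batch, seqlen, dim)

B, T, D = 2, 4, 3
# Flat positions of non-padded tokens: seq 0 tokens 0-2, seq 1 token 0.
indices = np.array([0, 1, 2, 4])
embedding = np.ones((B, T, D))               # padded embedding output
layer_out = np.full((len(indices), D), 2.0)  # unpadded layer output, (nnz, D)

hidden_states = [embedding]  # HF convention: index 0 = embedding output
hidden_states.append(pad_back(layer_out, indices, B, T))
print(len(hidden_states), hidden_states[1].shape)  # 2 (2, 4, 3)
```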
28
+ 4. **Standard HuggingFace output objects** (`BertModel`, `BertForMaskedLM`, `BertForSequenceClassification`): `BertModel.forward` now accepts `output_hidden_states` and `output_attentions` keyword arguments and returns a `BaseModelOutputWithPooling` object with `.last_hidden_state`, `.pooler_output`, `.hidden_states`, and `.attentions` fields. `BertForMaskedLM` and `BertForSequenceClassification` were updated accordingly to read from these named fields.
29
+
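Downstream heads previously indexed a raw tuple (`outputs[0]`, `outputs[1]`); named fields make that access order-independent. The real class lives in `transformers.modeling_outputs`; the dataclass below is a stand-in so the migration can be shown without the library installed:

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class FakeOutput:  # stand-in for transformers' BaseModelOutputWithPooling
    last_hidden_state: object
    pooler_output: object
    hidden_states: Optional[Tuple] = None
    attentions: Optional[Tuple] = None

outputs = FakeOutput(last_hidden_state='seq', pooler_output='pooled')
# Before: sequence_output = outputs[0]; pooled_output = outputs[1]
# After: read the named fields instead of positional indices.
print(outputs.last_hidden_state, outputs.pooler_output)  # seq pooled
print(outputs.hidden_states, outputs.attentions)         # None None
```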
30
+ # Original README:
31
  This is the official pre-trained model introduced in [DNABERT-2: Efficient Foundation Model and Benchmark For Multi-Species Genome
32
  ](https://arxiv.org/pdf/2306.15006.pdf).
33
 
34
+ We sincerely appreciate the MosaicML team for the [MosaicBERT](https://openreview.net/forum?id=5zipcfLC2Z) implementation, which serves as the base of DNABERT-2 development.
35
 
36
+ DNABERT-2 is a transformer-based genome foundation model trained on multi-species genome.
37
 
38
  To load the model from huggingface:
39
  ```
 
57
  # embedding with max pooling
58
  embedding_max = torch.max(hidden_states[0], dim=0)[0]
59
  print(embedding_max.shape) # expect to be 768
60
+ ```
bert_layers.py CHANGED
@@ -16,7 +16,8 @@ import torch.nn as nn
16
  from einops import rearrange
17
  from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
18
  from transformers.activations import ACT2FN
19
- from transformers.modeling_outputs import (MaskedLMOutput,
 
20
  SequenceClassifierOutput)
21
  from transformers.models.bert.modeling_bert import BertPreTrainedModel
22
  from transformers.modeling_utils import PreTrainedModel
@@ -120,6 +121,7 @@ class BertUnpadSelfAttention(nn.Module):
120
  self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
121
  self.p_dropout = config.attention_probs_dropout_prob
122
  self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
 
123
 
124
  # Warn if defaulting to pytorch because of import issues
125
  if flash_attn_qkvpacked_func is None:
@@ -129,7 +131,8 @@ class BertUnpadSelfAttention(nn.Module):
129
 
130
  def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
131
  max_seqlen_in_batch: int, indices: torch.Tensor,
132
- attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
 
133
  """Perform self-attention.
134
 
135
  If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
@@ -158,7 +161,7 @@ class BertUnpadSelfAttention(nn.Module):
158
  'b s (t h d) -> b s t h d',
159
  t=3,
160
  h=self.num_attention_heads)
161
- if self.p_dropout or flash_attn_qkvpacked_func is None:
162
  # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
163
  q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
164
  k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
@@ -172,6 +175,7 @@ class BertUnpadSelfAttention(nn.Module):
172
  3) # b s h d
173
  else:
174
  # Triton implementation only supports 0 attention dropout
 
175
  convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
176
  if convert_dtype:
177
  # Triton implementation only supports fp16 and bf16
@@ -187,7 +191,10 @@ class BertUnpadSelfAttention(nn.Module):
187
 
188
  # attn_mask is 1 for attend and 0 for don't
189
  attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
190
- return rearrange(attention, 'nnz h d -> nnz (h d)')
 
 
 
191
 
192
 
193
  # Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
@@ -225,6 +232,7 @@ class BertUnpadAttention(nn.Module):
225
  indices: Optional[torch.Tensor] = None,
226
  attn_mask: Optional[torch.Tensor] = None,
227
  bias: Optional[torch.Tensor] = None,
 
228
  ) -> torch.Tensor:
229
  """Forward pass for scaled self-attention without padding.
230
 
@@ -237,14 +245,24 @@ class BertUnpadAttention(nn.Module):
237
  indices: None or (total_nnz,)
238
  attn_mask: None or (batch, max_seqlen_in_batch)
239
  bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
 
240
  """
241
- self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
242
- attn_mask, bias)
243
  if subset_idx is not None:
244
- return self.output(index_first_axis(self_output, subset_idx),
245
- index_first_axis(input_tensor, subset_idx))
246
  else:
247
- return self.output(self_output, input_tensor)
 
 
 
248
 
249
 
250
  class BertGatedLinearUnitMLP(nn.Module):
@@ -312,6 +330,7 @@ class BertLayer(nn.Module):
312
  indices: Optional[torch.Tensor] = None,
313
  attn_mask: Optional[torch.Tensor] = None,
314
  bias: Optional[torch.Tensor] = None,
 
315
  ) -> torch.Tensor:
316
  """Forward pass for a BERT layer, including both attention and MLP.
317
 
@@ -324,10 +343,19 @@ class BertLayer(nn.Module):
324
  indices: None or (total_nnz,)
325
  attn_mask: None or (batch, max_seqlen_in_batch)
326
  bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
 
327
  """
328
- attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
329
- subset_idx, indices, attn_mask, bias)
330
  layer_output = self.mlp(attention_output)
 
 
331
  return layer_output
332
 
333
 
@@ -410,7 +438,8 @@ class BertEncoder(nn.Module):
410
  attention_mask: torch.Tensor,
411
  output_all_encoded_layers: Optional[bool] = True,
412
  subset_mask: Optional[torch.Tensor] = None,
413
- ) -> List[torch.Tensor]:
 
414
 
415
  extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
416
  extended_attention_mask = extended_attention_mask.to(
@@ -419,6 +448,12 @@ class BertEncoder(nn.Module):
419
 
420
  attention_mask_bool = attention_mask.bool()
421
  batch, seqlen = hidden_states.shape[:2]
422
  # Unpad inputs and mask. It will remove tokens that are padded.
423
  # Assume ntokens is total number of tokens (padded and non-padded)
424
  # and ntokens_unpad is total number of non-padded tokens.
@@ -442,17 +477,27 @@ class BertEncoder(nn.Module):
442
  alibi_attn_mask = attn_bias + alibi_bias
443
 
444
  all_encoder_layers = []
 
445
  if subset_mask is None:
446
  for layer_module in self.layer:
447
- hidden_states = layer_module(hidden_states,
448
- cu_seqlens,
449
- seqlen,
450
- None,
451
- indices,
452
- attn_mask=attention_mask,
453
- bias=alibi_attn_mask)
454
  if output_all_encoded_layers:
455
- all_encoder_layers.append(hidden_states)
 
 
456
  # Pad inputs and mask. It will insert back zero-padded tokens.
457
  # Assume ntokens is total number of tokens (padded and non-padded)
458
  # and ntokens_unpad is total number of non-padded tokens.
@@ -483,7 +528,13 @@ class BertEncoder(nn.Module):
483
 
484
  if not output_all_encoded_layers:
485
  all_encoder_layers.append(hidden_states)
486
- return all_encoder_layers
487
 
488
 
489
  class BertPooler(nn.Module):
@@ -586,8 +637,10 @@ class BertModel(BertPreTrainedModel):
586
  position_ids: Optional[torch.Tensor] = None,
587
  output_all_encoded_layers: Optional[bool] = False,
588
  masked_tokens_mask: Optional[torch.Tensor] = None,
 
 
589
  **kwargs
590
- ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
591
  if attention_mask is None:
592
  attention_mask = torch.ones_like(input_ids)
593
  if token_type_ids is None:
@@ -606,11 +659,12 @@ class BertModel(BertPreTrainedModel):
606
  first_col_mask[:, 0] = True
607
  subset_mask = masked_tokens_mask | first_col_mask
608
 
609
- encoder_outputs = self.encoder(
610
  embedding_output,
611
  attention_mask,
612
- output_all_encoded_layers=output_all_encoded_layers,
613
- subset_mask=subset_mask)
 
614
 
615
  if masked_tokens_mask is None:
616
  sequence_output = encoder_outputs[-1]
@@ -629,13 +683,12 @@ class BertModel(BertPreTrainedModel):
629
  else:
630
  pooled_output = None
631
 
632
- if not output_all_encoded_layers:
633
- encoder_outputs = sequence_output
634
-
635
- if self.pooler is not None:
636
- return encoder_outputs, pooled_output
637
-
638
- return encoder_outputs, None
639
 
640
 
641
  ###################
@@ -755,8 +808,8 @@ class BertForMaskedLM(BertPreTrainedModel):
755
  return_dict=return_dict,
756
  masked_tokens_mask=masked_tokens_mask,
757
  )
758
-
759
- sequence_output = outputs[0]
760
  prediction_scores = self.cls(sequence_output)
761
 
762
  loss = None
@@ -782,8 +835,8 @@ class BertForMaskedLM(BertPreTrainedModel):
782
  return MaskedLMOutput(
783
  loss=loss,
784
  logits=prediction_scores,
785
- hidden_states=outputs[0],
786
- attentions=None,
787
  )
788
 
789
  def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
@@ -868,7 +921,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
868
  return_dict=return_dict,
869
  )
870
 
871
- pooled_output = outputs[1]
872
 
873
  pooled_output = self.dropout(pooled_output)
874
  logits = self.classifier(pooled_output)
@@ -906,7 +959,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
906
  return SequenceClassifierOutput(
907
  loss=loss,
908
  logits=logits,
909
- hidden_states=outputs[0],
910
- attentions=None,
911
  )
912
 
 
16
  from einops import rearrange
17
  from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
18
  from transformers.activations import ACT2FN
19
+ from transformers.modeling_outputs import (BaseModelOutputWithPooling,
20
+ MaskedLMOutput,
21
  SequenceClassifierOutput)
22
  from transformers.models.bert.modeling_bert import BertPreTrainedModel
23
  from transformers.modeling_utils import PreTrainedModel
 
121
  self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
122
  self.p_dropout = config.attention_probs_dropout_prob
123
  self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
124
+ self.use_flash_attn = getattr(config, 'use_flash_attn', True)
125
 
126
  # Warn if defaulting to pytorch because of import issues
127
  if flash_attn_qkvpacked_func is None:
 
131
 
132
  def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
133
  max_seqlen_in_batch: int, indices: torch.Tensor,
134
+ attn_mask: torch.Tensor, bias: torch.Tensor,
135
+ return_attn_weights: bool = False) -> torch.Tensor:
136
  """Perform self-attention.
137
 
138
  If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
 
161
  'b s (t h d) -> b s t h d',
162
  t=3,
163
  h=self.num_attention_heads)
164
+ if self.p_dropout or flash_attn_qkvpacked_func is None or not self.use_flash_attn or return_attn_weights:
165
  # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
166
  q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
167
  k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
 
175
  3) # b s h d
176
  else:
177
  # Triton implementation only supports 0 attention dropout
178
+ attention_probs = None
179
  convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
180
  if convert_dtype:
181
  # Triton implementation only supports fp16 and bf16
 
191
 
192
  # attn_mask is 1 for attend and 0 for don't
193
  attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
194
+ out = rearrange(attention, 'nnz h d -> nnz (h d)')
195
+ if return_attn_weights:
196
+ return out, attention_probs # (nnz, D), (B, H, T, T)
197
+ return out
198
 
199
 
200
  # Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
 
232
  indices: Optional[torch.Tensor] = None,
233
  attn_mask: Optional[torch.Tensor] = None,
234
  bias: Optional[torch.Tensor] = None,
235
+ return_attn_weights: bool = False,
236
  ) -> torch.Tensor:
237
  """Forward pass for scaled self-attention without padding.
238
 
 
245
  indices: None or (total_nnz,)
246
  attn_mask: None or (batch, max_seqlen_in_batch)
247
  bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
248
+ return_attn_weights: If True, return attention probabilities alongside output.
249
  """
250
+ if return_attn_weights:
251
+ self_output, attn_probs = self.self(
252
+ input_tensor, cu_seqlens, max_s, indices, attn_mask, bias,
253
+ return_attn_weights=True)
254
+ else:
255
+ self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
256
+ attn_mask, bias)
257
+ attn_probs = None
258
  if subset_idx is not None:
259
+ output = self.output(index_first_axis(self_output, subset_idx),
260
+ index_first_axis(input_tensor, subset_idx))
261
  else:
262
+ output = self.output(self_output, input_tensor)
263
+ if return_attn_weights:
264
+ return output, attn_probs
265
+ return output
266
 
267
 
268
  class BertGatedLinearUnitMLP(nn.Module):
 
330
  indices: Optional[torch.Tensor] = None,
331
  attn_mask: Optional[torch.Tensor] = None,
332
  bias: Optional[torch.Tensor] = None,
333
+ return_attn_weights: bool = False,
334
  ) -> torch.Tensor:
335
  """Forward pass for a BERT layer, including both attention and MLP.
336
 
 
343
  indices: None or (total_nnz,)
344
  attn_mask: None or (batch, max_seqlen_in_batch)
345
  bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
346
+ return_attn_weights: If True, return attention probabilities alongside output.
347
  """
348
+ if return_attn_weights:
349
+ attention_output, attn_probs = self.attention(
350
+ hidden_states, cu_seqlens, seqlen, subset_idx, indices,
351
+ attn_mask, bias, return_attn_weights=True)
352
+ else:
353
+ attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
354
+ subset_idx, indices, attn_mask, bias)
355
+ attn_probs = None
356
  layer_output = self.mlp(attention_output)
357
+ if return_attn_weights:
358
+ return layer_output, attn_probs
359
  return layer_output
360
 
361
 
 
438
  attention_mask: torch.Tensor,
439
  output_all_encoded_layers: Optional[bool] = True,
440
  subset_mask: Optional[torch.Tensor] = None,
441
+ output_attentions: bool = False,
442
+ ) -> Tuple[List[torch.Tensor], Optional[Tuple[torch.Tensor, ...]]]:
443
 
444
  extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
445
  extended_attention_mask = extended_attention_mask.to(
 
448
 
449
  attention_mask_bool = attention_mask.bool()
450
  batch, seqlen = hidden_states.shape[:2]
451
+
452
+ # Capture padded embedding output (B, T, D) before unpadding, so it
453
+ # can be prepended to all_encoder_layers as hidden_states index 0 in
454
+ # the HF convention (embedding = index 0, layer i = index i+1).
455
+ padded_embedding = hidden_states
456
+
457
  # Unpad inputs and mask. It will remove tokens that are padded.
458
  # Assume ntokens is total number of tokens (padded and non-padded)
459
  # and ntokens_unpad is total number of non-padded tokens.
 
477
  alibi_attn_mask = attn_bias + alibi_bias
478
 
479
  all_encoder_layers = []
480
+ all_attention_probs: List[torch.Tensor] = []
481
  if subset_mask is None:
482
  for layer_module in self.layer:
483
+ if output_attentions:
484
+ hidden_states, attn_probs = layer_module(
485
+ hidden_states, cu_seqlens, seqlen, None, indices,
486
+ attn_mask=attention_mask, bias=alibi_attn_mask,
487
+ return_attn_weights=True)
488
+ all_attention_probs.append(attn_probs)
489
+ else:
490
+ hidden_states = layer_module(hidden_states,
491
+ cu_seqlens,
492
+ seqlen,
493
+ None,
494
+ indices,
495
+ attn_mask=attention_mask,
496
+ bias=alibi_attn_mask)
497
  if output_all_encoded_layers:
498
+ # Pad back to (B, T, D) so callers get consistent shapes.
499
+ all_encoder_layers.append(
500
+ pad_input(hidden_states, indices, batch, seqlen))
501
  # Pad inputs and mask. It will insert back zero-padded tokens.
502
  # Assume ntokens is total number of tokens (padded and non-padded)
503
  # and ntokens_unpad is total number of non-padded tokens.
 
528
 
529
  if not output_all_encoded_layers:
530
  all_encoder_layers.append(hidden_states)
531
+ else:
532
+ # Prepend padded embedding as index 0 to match HF convention:
533
+ # hidden_states[0] = embedding, hidden_states[i+1] = layer i output.
534
+ all_encoder_layers.insert(0, padded_embedding)
535
+
536
+ attn_out = tuple(all_attention_probs) if output_attentions else None
537
+ return all_encoder_layers, attn_out
538
 
539
 
540
  class BertPooler(nn.Module):
 
637
  position_ids: Optional[torch.Tensor] = None,
638
  output_all_encoded_layers: Optional[bool] = False,
639
  masked_tokens_mask: Optional[torch.Tensor] = None,
640
+ output_hidden_states: bool = False,
641
+ output_attentions: bool = False,
642
  **kwargs
643
+ ) -> BaseModelOutputWithPooling:
644
  if attention_mask is None:
645
  attention_mask = torch.ones_like(input_ids)
646
  if token_type_ids is None:
 
659
  first_col_mask[:, 0] = True
660
  subset_mask = masked_tokens_mask | first_col_mask
661
 
662
+ encoder_outputs, all_attentions = self.encoder(
663
  embedding_output,
664
  attention_mask,
665
+ output_all_encoded_layers=output_hidden_states,
666
+ subset_mask=subset_mask,
667
+ output_attentions=output_attentions)
668
 
669
  if masked_tokens_mask is None:
670
  sequence_output = encoder_outputs[-1]
 
683
  else:
684
  pooled_output = None
685
 
686
+ return BaseModelOutputWithPooling(
687
+ last_hidden_state=sequence_output,
688
+ pooler_output=pooled_output,
689
+ hidden_states=tuple(encoder_outputs) if output_hidden_states else None,
690
+ attentions=all_attentions,
691
+ )
 
692
 
693
 
694
  ###################
 
808
  return_dict=return_dict,
809
  masked_tokens_mask=masked_tokens_mask,
810
  )
811
+
812
+ sequence_output = outputs.last_hidden_state
813
  prediction_scores = self.cls(sequence_output)
814
 
815
  loss = None
 
835
  return MaskedLMOutput(
836
  loss=loss,
837
  logits=prediction_scores,
838
+ hidden_states=outputs.hidden_states,
839
+ attentions=outputs.attentions,
840
  )
841
 
842
  def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
 
921
  return_dict=return_dict,
922
  )
923
 
924
+ pooled_output = outputs.pooler_output
925
 
926
  pooled_output = self.dropout(pooled_output)
927
  logits = self.classifier(pooled_output)
 
959
  return SequenceClassifierOutput(
960
  loss=loss,
961
  logits=logits,
962
+ hidden_states=outputs.hidden_states,
963
+ attentions=outputs.attentions,
964
  )
965
 
flash_attn_triton.py CHANGED
@@ -188,7 +188,7 @@ def _fwd_kernel(
188
  (offs_d[None, :] < headdim),
189
  other=0.0)
190
  qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
191
- qk += tl.dot(q, k, trans_b=True)
192
  # Trying to combine the two masks seem to make the result wrong
193
  if not EVEN_N: # Need to mask out otherwise the softmax is wrong
194
  qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0,
@@ -431,7 +431,7 @@ def _bwd_kernel_one_col_block(
431
  (offs_d[None, :] < headdim),
432
  other=0.0)
433
  # recompute p = softmax(qk, dim=-1).T
434
- qk = tl.dot(q, k, trans_b=True)
435
  # Trying to combine the two masks seem to make the result wrong
436
  if not EVEN_N: # Need to mask out otherwise the softmax is wrong
437
  qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf'))
@@ -498,7 +498,7 @@ def _bwd_kernel_one_col_block(
498
  # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
499
  if not (EVEN_M & EVEN_HEADDIM):
500
  tl.debug_barrier()
501
- dp = tl.dot(do, v, trans_b=True)
502
  # There's a race condition for headdim=48
503
  if not EVEN_HEADDIM:
504
  tl.debug_barrier()
 
188
  (offs_d[None, :] < headdim),
189
  other=0.0)
190
  qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
191
+ qk += tl.dot(q, tl.trans(k)) # see issue: https://github.com/Dao-AILab/flash-attention/issues/508
192
  # Trying to combine the two masks seem to make the result wrong
193
  if not EVEN_N: # Need to mask out otherwise the softmax is wrong
194
  qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0,
 
431
  (offs_d[None, :] < headdim),
432
  other=0.0)
433
  # recompute p = softmax(qk, dim=-1).T
434
+ qk = tl.dot(q, tl.trans(k)) # see issue: https://github.com/Dao-AILab/flash-attention/issues/508
435
  # Trying to combine the two masks seem to make the result wrong
436
  if not EVEN_N: # Need to mask out otherwise the softmax is wrong
437
  qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf'))
 
498
  # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
499
  if not (EVEN_M & EVEN_HEADDIM):
500
  tl.debug_barrier()
501
+ dp = tl.dot(do, tl.trans(v)) # see issue: https://github.com/Dao-AILab/flash-attention/issues/508
502
  # There's a race condition for headdim=48
503
  if not EVEN_HEADDIM:
504
  tl.debug_barrier()