Taykhoom committed on
Commit
e1354bd
·
verified ·
1 Parent(s): 62da139

add attention return + support eager attention or triton FA2 via config.use_flash_attn

Files changed (3)
  1. README.md +34 -13
  2. bert_layers.py +94 -44
  3. flash_attn_triton.py +3 -3
README.md CHANGED
@@ -9,7 +9,28 @@ tags:
9
  - mrna
10
  ---
11
 
12
- # mRNABERT
13
 
14
  A robust language model pre-trained on over 18 million high-quality mRNA sequences, incorporating contrastive learning to integrate the semantic features of amino acids.
15
 
@@ -18,14 +39,14 @@ This is the official pre-trained model introduced in [mRNABERT: advancing mRNA
18
  The repository of mRNABERT is at [yyly6/mRNABERT](https://github.com/yyly6/mRNABERT).
19
 
20
 
21
- ## Intended uses & limitations
22
  The model can be used for mRNA sequence feature extraction or fine-tuned on downstream tasks. **Before feeding sequences to the model, you need to preprocess the data: use single-letter separation for the UTR regions and three-character separation for the CDS regions.** For full examples, please see [our code on data processing](https://github.com/yyly6/mRNABERT).
23
 
24
- ## Training data
25
 
26
  The mRNABERT model was pretrained on [a comprehensive mRNA dataset](https://zenodo.org/records/12516160), which originally consisted of approximately 36 million complete CDS or mRNA sequences. After cleaning, this number was reduced to 18 million.
27
 
28
- ## Usage
29
  To load the model from Hugging Face:
30
  ```python
31
  import torch
@@ -40,14 +61,14 @@ model = AutoModel.from_pretrained("YYLY66/mRNABERT", trust_remote_code=True, con
40
  To extract the embeddings of mRNA sequences:
41
 
42
  ```python
43
- seq = ["A T C G G A GGG CCC TTT",
44
- "A T C G",
45
  "TTT CCC GAC ATG"] #Separate the sequences with spaces.
46
 
47
  encoding = tokenizer.batch_encode_plus(seq, add_special_tokens=True, padding='longest', return_tensors="pt")
48
 
49
  input_ids = encoding['input_ids']
50
- attention_mask = encoding['attention_mask']
51
 
52
  output = model(input_ids=input_ids, attention_mask=attention_mask)
53
  last_hidden_state = output[0]
@@ -55,13 +76,13 @@ last_hidden_state = output[0]
55
  attention_mask = attention_mask.unsqueeze(-1).expand_as(last_hidden_state) # Shape : [batch_size, seq_length, hidden_size]
56
 
57
  # Sum embeddings along the sequence dimension
58
- sum_embeddings = torch.sum(last_hidden_state * attention_mask, dim=1)
59
 
60
  # Also sum the masks along the sequence dimension
61
- sum_masks = attention_mask.sum(1)
62
 
63
  # Compute mean embedding.
64
- mean_embedding = sum_embeddings / sum_masks #Shape:[batch_size, hidden_size]
65
 
66
  ```
67
 
@@ -69,7 +90,7 @@ The extracted embeddings can be used for contrastive learning pretraining or as
69
 
70
 
71
 
72
- ## Citation
73
 
74
  **BibTeX**:
75
 
@@ -86,5 +107,5 @@ The extracted embeddings can be used for contrastive learning pretraining or as
86
  }
87
  ```
88
 
89
- ## Contact
90
- If you have any questions, please feel free to email us (xiongying@zju.edu.cn).
 
9
  - mrna
10
  ---
11
 
12
+ # Note:
13
+ This model is a copy of mRNABERT that fixes the FlashAttention integration with Triton (specifically by integrating the solution in https://github.com/Dao-AILab/flash-attention/issues/508) and fixes the return of attention weights and hidden states in the model's forward function. The original mRNABERT model can be found at https://huggingface.co/YYLY66/mRNABERT. If you use this model, please credit the original authors of mRNABERT and the MosaicML team for their implementation.
14
+
15
+ The only changes made were in `flash_attn_triton.py` and `bert_layers.py`.
16
+
17
+ In `flash_attn_triton.py`, the changes were:
18
+
19
+ 1. ```qk += tl.dot(q, k, trans_b=True)``` was changed to ```qk += tl.dot(q, tl.trans(k))```, following the solution in the FlashAttention issue linked above. The two other uses of the ```trans_b=True``` argument in the file were updated the same way.
20
+
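+ Both kernel forms compute the same scores matrix q·kᵀ; the fix only replaces the removed `trans_b=True` keyword with an explicit transpose. A plain-PyTorch sketch of the equivalence (illustrative only, not Triton code):

```python
import torch

# Illustrative check (plain PyTorch, not Triton): the old
# tl.dot(q, k, trans_b=True) and the new tl.dot(q, tl.trans(k))
# both compute the attention scores q @ k^T.
torch.manual_seed(0)
q = torch.randn(16, 64)   # (BLOCK_M, head_dim)
k = torch.randn(32, 64)   # (BLOCK_N, head_dim)

qk_old = q @ k.T           # trans_b=True form
qk_new = q @ torch.t(k)    # explicit-transpose form

assert torch.allclose(qk_old, qk_new)
assert qk_old.shape == (16, 32)
```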
21
+ In `bert_layers.py` the changes were:
22
+
23
+ 1. **`use_flash_attn` config flag** (`BertUnpadSelfAttention`): Added `self.use_flash_attn = getattr(config, 'use_flash_attn', True)`. Setting `use_flash_attn: false` in the model config forces the PyTorch eager attention path, enabling attention weight extraction without requiring Triton.
24
+
25
+ 2. **Attention weight return** (`BertUnpadSelfAttention`, `BertUnpadAttention`, `BertLayer`): Added a `return_attn_weights: bool = False` parameter threaded through the call chain. When enabled, the eager path returns the `(B, H, T, T)` attention probability tensor alongside the hidden states.
26
+
27
+ 3. **HF-compatible encoder output** (`BertEncoder`): Added `output_attentions: bool = False`. When `output_all_encoded_layers=True`, each layer's hidden states are now padded back to `(B, T, D)` before collection (previously unpadded `(nnz, D)`), and the embedding output is prepended as index 0 to match the HuggingFace `hidden_states` convention.
28
+
29
+ 4. **Standard HuggingFace output objects** (`BertModel`, `BertForMaskedLM`, `BertForSequenceClassification`): `BertModel.forward` now accepts `output_hidden_states` and `output_attentions` keyword arguments and returns a `BaseModelOutputWithPooling` object with `.last_hidden_state`, `.pooler_output`, `.hidden_states`, and `.attentions` fields. `BertForMaskedLM` and `BertForSequenceClassification` were updated accordingly to read from these named fields.
30
+
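+ For reference, the `(B, H, T, T)` tensor the eager path returns is the standard softmax attention-probability matrix. A minimal self-contained sketch of that computation (hypothetical helper, not the model's actual code):

```python
import torch
import torch.nn.functional as F

def eager_attention(q, k, v, bias):
    """Toy eager attention: returns output and (B, H, T, T) probabilities.

    q, k, v: (B, H, T, d); bias: additive mask (0 for attend, -10000 for not).
    """
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    probs = F.softmax(scores + bias, dim=-1)   # (B, H, T, T)
    return probs @ v, probs

B, H, T, d = 2, 4, 8, 16
q, k, v = (torch.randn(B, H, T, d) for _ in range(3))
out, probs = eager_attention(q, k, v, torch.zeros(B, H, T, T))
assert out.shape == (B, H, T, d)
assert probs.shape == (B, H, T, T)
assert torch.allclose(probs.sum(-1), torch.ones(B, H, T))  # rows sum to 1
```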
31
+ # Original README:
32
+
33
+ ## mRNABERT
34
 
35
  A robust language model pre-trained on over 18 million high-quality mRNA sequences, incorporating contrastive learning to integrate the semantic features of amino acids.
36
 
 
39
  The repository of mRNABERT is at [yyly6/mRNABERT](https://github.com/yyly6/mRNABERT).
40
 
41
 
42
+ ### Intended uses & limitations
43
  The model can be used for mRNA sequence feature extraction or fine-tuned on downstream tasks. **Before feeding sequences to the model, you need to preprocess the data: use single-letter separation for the UTR regions and three-character separation for the CDS regions.** For full examples, please see [our code on data processing](https://github.com/yyly6/mRNABERT).
44
 
45
+ ### Training data
46
 
47
  The mRNABERT model was pretrained on [a comprehensive mRNA dataset](https://zenodo.org/records/12516160), which originally consisted of approximately 36 million complete CDS or mRNA sequences. After cleaning, this number was reduced to 18 million.
48
 
49
+ ### Usage
50
  To load the model from Hugging Face:
51
  ```python
52
  import torch
 
61
  To extract the embeddings of mRNA sequences:
62
 
63
  ```python
64
+ seq = ["A T C G G A GGG CCC TTT",
65
+ "A T C G",
66
  "TTT CCC GAC ATG"] #Separate the sequences with spaces.
67
 
68
  encoding = tokenizer.batch_encode_plus(seq, add_special_tokens=True, padding='longest', return_tensors="pt")
69
 
70
  input_ids = encoding['input_ids']
71
+ attention_mask = encoding['attention_mask']
72
 
73
  output = model(input_ids=input_ids, attention_mask=attention_mask)
74
  last_hidden_state = output[0]
 
76
  attention_mask = attention_mask.unsqueeze(-1).expand_as(last_hidden_state) # Shape : [batch_size, seq_length, hidden_size]
77
 
78
  # Sum embeddings along the sequence dimension
79
+ sum_embeddings = torch.sum(last_hidden_state * attention_mask, dim=1)
80
 
81
  # Also sum the masks along the sequence dimension
82
+ sum_masks = attention_mask.sum(1)
83
 
84
  # Compute mean embedding.
85
+ mean_embedding = sum_embeddings / sum_masks #Shape:[batch_size, hidden_size]
86
 
87
  ```
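The pooling arithmetic above can be sanity-checked in isolation with dummy tensors (shapes chosen arbitrarily; these are stand-ins, not model outputs):

```python
import torch

# Dummy stand-ins for the model's outputs.
torch.manual_seed(0)
B, T, D = 2, 4, 8
last_hidden_state = torch.randn(B, T, D)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])  # 0 marks padding

mask = attention_mask.unsqueeze(-1).expand_as(last_hidden_state).float()
sum_embeddings = torch.sum(last_hidden_state * mask, dim=1)
sum_masks = mask.sum(1)
mean_embedding = sum_embeddings / sum_masks    # (B, D)

# Matches averaging only the real (unmasked) tokens directly.
assert torch.allclose(mean_embedding[0], last_hidden_state[0, :3].mean(0), atol=1e-6)
assert torch.allclose(mean_embedding[1], last_hidden_state[1, :2].mean(0), atol=1e-6)
assert mean_embedding.shape == (B, D)
```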
88
 
 
90
 
91
 
92
 
93
+ ### Citation
94
 
95
  **BibTeX**:
96
 
 
107
  }
108
  ```
109
 
110
+ ### Contact
111
+ If you have any questions, please feel free to email us (xiongying@zju.edu.cn).
bert_layers.py CHANGED
@@ -16,9 +16,11 @@ import torch.nn as nn
16
  from einops import rearrange
17
  from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
18
  from transformers.activations import ACT2FN
19
- from transformers.modeling_outputs import (MaskedLMOutput,
 
20
  SequenceClassifierOutput)
21
  from transformers.models.bert.modeling_bert import BertPreTrainedModel
 
22
 
23
  from .bert_padding import (index_first_axis,
24
  index_put_first_axis, pad_input,
@@ -119,6 +121,7 @@ class BertUnpadSelfAttention(nn.Module):
119
  self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
120
  self.p_dropout = config.attention_probs_dropout_prob
121
  self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
 
122
 
123
  # Warn if defaulting to pytorch because of import issues
124
  if flash_attn_qkvpacked_func is None:
@@ -128,7 +131,8 @@ class BertUnpadSelfAttention(nn.Module):
128
 
129
  def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
130
  max_seqlen_in_batch: int, indices: torch.Tensor,
131
- attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
 
132
  """Perform self-attention.
133
 
134
  If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
@@ -157,7 +161,7 @@ class BertUnpadSelfAttention(nn.Module):
157
  'b s (t h d) -> b s t h d',
158
  t=3,
159
  h=self.num_attention_heads)
160
- if self.p_dropout or flash_attn_qkvpacked_func is None:
161
  # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
162
  q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
163
  k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
@@ -171,6 +175,7 @@ class BertUnpadSelfAttention(nn.Module):
171
  3) # b s h d
172
  else:
173
  # Triton implementation only supports 0 attention dropout
 
174
  convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
175
  if convert_dtype:
176
  # Triton implementation only supports fp16 and bf16
@@ -186,7 +191,10 @@ class BertUnpadSelfAttention(nn.Module):
186
 
187
  # attn_mask is 1 for attend and 0 for don't
188
  attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
189
- return rearrange(attention, 'nnz h d -> nnz (h d)')
 
 
 
190
 
191
 
192
  # Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
@@ -224,6 +232,7 @@ class BertUnpadAttention(nn.Module):
224
  indices: Optional[torch.Tensor] = None,
225
  attn_mask: Optional[torch.Tensor] = None,
226
  bias: Optional[torch.Tensor] = None,
 
227
  ) -> torch.Tensor:
228
  """Forward pass for scaled self-attention without padding.
229
 
@@ -236,14 +245,24 @@ class BertUnpadAttention(nn.Module):
236
  indices: None or (total_nnz,)
237
  attn_mask: None or (batch, max_seqlen_in_batch)
238
  bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
 
239
  """
240
- self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
241
- attn_mask, bias)
242
  if subset_idx is not None:
243
- return self.output(index_first_axis(self_output, subset_idx),
244
- index_first_axis(input_tensor, subset_idx))
245
  else:
246
- return self.output(self_output, input_tensor)
 
 
 
247
 
248
 
249
  class BertGatedLinearUnitMLP(nn.Module):
@@ -311,6 +330,7 @@ class BertLayer(nn.Module):
311
  indices: Optional[torch.Tensor] = None,
312
  attn_mask: Optional[torch.Tensor] = None,
313
  bias: Optional[torch.Tensor] = None,
 
314
  ) -> torch.Tensor:
315
  """Forward pass for a BERT layer, including both attention and MLP.
316
 
@@ -323,10 +343,19 @@ class BertLayer(nn.Module):
323
  indices: None or (total_nnz,)
324
  attn_mask: None or (batch, max_seqlen_in_batch)
325
  bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
 
326
  """
327
- attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
328
- subset_idx, indices, attn_mask, bias)
329
  layer_output = self.mlp(attention_output)
 
 
330
  return layer_output
331
 
332
 
@@ -409,15 +438,22 @@ class BertEncoder(nn.Module):
409
  attention_mask: torch.Tensor,
410
  output_all_encoded_layers: Optional[bool] = True,
411
  subset_mask: Optional[torch.Tensor] = None,
412
- ) -> List[torch.Tensor]:
 
413
 
414
  extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
415
  extended_attention_mask = extended_attention_mask.to(
416
- dtype=next(self.parameters()).dtype) # fp16 compatibility
417
  extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
418
 
419
  attention_mask_bool = attention_mask.bool()
420
  batch, seqlen = hidden_states.shape[:2]
421
  # Unpad inputs and mask. It will remove tokens that are padded.
422
  # Assume ntokens is total number of tokens (padded and non-padded)
423
  # and ntokens_unpad is total number of non-padded tokens.
@@ -441,17 +477,27 @@ class BertEncoder(nn.Module):
441
  alibi_attn_mask = attn_bias + alibi_bias
442
 
443
  all_encoder_layers = []
 
444
  if subset_mask is None:
445
  for layer_module in self.layer:
446
- hidden_states = layer_module(hidden_states,
447
- cu_seqlens,
448
- seqlen,
449
- None,
450
- indices,
451
- attn_mask=attention_mask,
452
- bias=alibi_attn_mask)
 
453
  if output_all_encoded_layers:
454
- all_encoder_layers.append(hidden_states)
 
 
455
  # Pad inputs and mask. It will insert back zero-padded tokens.
456
  # Assume ntokens is total number of tokens (padded and non-padded)
457
  # and ntokens_unpad is total number of non-padded tokens.
@@ -482,7 +528,13 @@ class BertEncoder(nn.Module):
482
 
483
  if not output_all_encoded_layers:
484
  all_encoder_layers.append(hidden_states)
485
- return all_encoder_layers
486
 
487
 
488
  class BertPooler(nn.Module):
@@ -585,8 +637,10 @@ class BertModel(BertPreTrainedModel):
585
  position_ids: Optional[torch.Tensor] = None,
586
  output_all_encoded_layers: Optional[bool] = False,
587
  masked_tokens_mask: Optional[torch.Tensor] = None,
 
 
588
  **kwargs
589
- ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
590
  if attention_mask is None:
591
  attention_mask = torch.ones_like(input_ids)
592
  if token_type_ids is None:
@@ -605,11 +659,12 @@ class BertModel(BertPreTrainedModel):
605
  first_col_mask[:, 0] = True
606
  subset_mask = masked_tokens_mask | first_col_mask
607
 
608
- encoder_outputs = self.encoder(
609
  embedding_output,
610
  attention_mask,
611
- output_all_encoded_layers=output_all_encoded_layers,
612
- subset_mask=subset_mask)
 
613
 
614
  if masked_tokens_mask is None:
615
  sequence_output = encoder_outputs[-1]
@@ -628,13 +683,12 @@ class BertModel(BertPreTrainedModel):
628
  else:
629
  pooled_output = None
630
 
631
- if not output_all_encoded_layers:
632
- encoder_outputs = sequence_output
633
-
634
- if self.pooler is not None:
635
- return encoder_outputs, pooled_output
636
-
637
- return encoder_outputs, None
638
 
639
 
640
  ###################
@@ -754,8 +808,8 @@ class BertForMaskedLM(BertPreTrainedModel):
754
  return_dict=return_dict,
755
  masked_tokens_mask=masked_tokens_mask,
756
  )
757
-
758
- sequence_output = outputs[0]
759
  prediction_scores = self.cls(sequence_output)
760
 
761
  loss = None
@@ -781,8 +835,8 @@ class BertForMaskedLM(BertPreTrainedModel):
781
  return MaskedLMOutput(
782
  loss=loss,
783
  logits=prediction_scores,
784
- hidden_states=outputs[0],
785
- attentions=None,
786
  )
787
 
788
  def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
@@ -809,10 +863,6 @@ class BertForMaskedLM(BertPreTrainedModel):
809
  return {'input_ids': input_ids, 'attention_mask': attention_mask}
810
 
811
 
812
- class BertForNextSentencePrediction(BertPreTrainedModel):
813
- #TBD: Push in future commit
814
- pass
815
-
816
 
817
  class BertForSequenceClassification(BertPreTrainedModel):
818
  """Bert Model transformer with a sequence classification/regression head.
@@ -871,7 +921,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
871
  return_dict=return_dict,
872
  )
873
 
874
- pooled_output = outputs[1]
875
 
876
  pooled_output = self.dropout(pooled_output)
877
  logits = self.classifier(pooled_output)
@@ -909,7 +959,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
909
  return SequenceClassifierOutput(
910
  loss=loss,
911
  logits=logits,
912
- hidden_states=outputs[0],
913
- attentions=None,
914
  )
915
 
 
16
  from einops import rearrange
17
  from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
18
  from transformers.activations import ACT2FN
19
+ from transformers.modeling_outputs import (BaseModelOutputWithPooling,
20
+ MaskedLMOutput,
21
  SequenceClassifierOutput)
22
  from transformers.models.bert.modeling_bert import BertPreTrainedModel
23
+ from transformers.modeling_utils import PreTrainedModel
24
 
25
  from .bert_padding import (index_first_axis,
26
  index_put_first_axis, pad_input,
 
121
  self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
122
  self.p_dropout = config.attention_probs_dropout_prob
123
  self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
124
+ self.use_flash_attn = getattr(config, 'use_flash_attn', True)
125
 
126
  # Warn if defaulting to pytorch because of import issues
127
  if flash_attn_qkvpacked_func is None:
 
131
 
132
  def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
133
  max_seqlen_in_batch: int, indices: torch.Tensor,
134
+ attn_mask: torch.Tensor, bias: torch.Tensor,
135
+ return_attn_weights: bool = False) -> torch.Tensor:
136
  """Perform self-attention.
137
 
138
  If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
 
161
  'b s (t h d) -> b s t h d',
162
  t=3,
163
  h=self.num_attention_heads)
164
+ if self.p_dropout or flash_attn_qkvpacked_func is None or not self.use_flash_attn or return_attn_weights:
165
  # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
166
  q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
167
  k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
 
175
  3) # b s h d
176
  else:
177
  # Triton implementation only supports 0 attention dropout
178
+ attention_probs = None
179
  convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
180
  if convert_dtype:
181
  # Triton implementation only supports fp16 and bf16
 
191
 
192
  # attn_mask is 1 for attend and 0 for don't
193
  attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
194
+ out = rearrange(attention, 'nnz h d -> nnz (h d)')
195
+ if return_attn_weights:
196
+ return out, attention_probs # (nnz, D), (B, H, T, T)
197
+ return out
198
 
199
 
200
  # Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
 
232
  indices: Optional[torch.Tensor] = None,
233
  attn_mask: Optional[torch.Tensor] = None,
234
  bias: Optional[torch.Tensor] = None,
235
+ return_attn_weights: bool = False,
236
  ) -> torch.Tensor:
237
  """Forward pass for scaled self-attention without padding.
238
 
 
245
  indices: None or (total_nnz,)
246
  attn_mask: None or (batch, max_seqlen_in_batch)
247
  bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
248
+ return_attn_weights: If True, return attention probabilities alongside output.
249
  """
250
+ if return_attn_weights:
251
+ self_output, attn_probs = self.self(
252
+ input_tensor, cu_seqlens, max_s, indices, attn_mask, bias,
253
+ return_attn_weights=True)
254
+ else:
255
+ self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
256
+ attn_mask, bias)
257
+ attn_probs = None
258
  if subset_idx is not None:
259
+ output = self.output(index_first_axis(self_output, subset_idx),
260
+ index_first_axis(input_tensor, subset_idx))
261
  else:
262
+ output = self.output(self_output, input_tensor)
263
+ if return_attn_weights:
264
+ return output, attn_probs
265
+ return output
266
 
267
 
268
  class BertGatedLinearUnitMLP(nn.Module):
 
330
  indices: Optional[torch.Tensor] = None,
331
  attn_mask: Optional[torch.Tensor] = None,
332
  bias: Optional[torch.Tensor] = None,
333
+ return_attn_weights: bool = False,
334
  ) -> torch.Tensor:
335
  """Forward pass for a BERT layer, including both attention and MLP.
336
 
 
343
  indices: None or (total_nnz,)
344
  attn_mask: None or (batch, max_seqlen_in_batch)
345
  bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
346
+ return_attn_weights: If True, return attention probabilities alongside output.
347
  """
348
+ if return_attn_weights:
349
+ attention_output, attn_probs = self.attention(
350
+ hidden_states, cu_seqlens, seqlen, subset_idx, indices,
351
+ attn_mask, bias, return_attn_weights=True)
352
+ else:
353
+ attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
354
+ subset_idx, indices, attn_mask, bias)
355
+ attn_probs = None
356
  layer_output = self.mlp(attention_output)
357
+ if return_attn_weights:
358
+ return layer_output, attn_probs
359
  return layer_output
360
 
361
 
 
438
  attention_mask: torch.Tensor,
439
  output_all_encoded_layers: Optional[bool] = True,
440
  subset_mask: Optional[torch.Tensor] = None,
441
+ output_attentions: bool = False,
442
+ ) -> Tuple[List[torch.Tensor], Optional[Tuple[torch.Tensor, ...]]]:
443
 
444
  extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
445
  extended_attention_mask = extended_attention_mask.to(
446
+ dtype=torch.float32) # fp16 compatibility
447
  extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
448
 
449
  attention_mask_bool = attention_mask.bool()
450
  batch, seqlen = hidden_states.shape[:2]
451
+
452
+ # Capture padded embedding output (B, T, D) before unpadding, so it
453
+ # can be prepended to all_encoder_layers as hidden_states index 0 in
454
+ # the HF convention (embedding = index 0, layer i = index i+1).
455
+ padded_embedding = hidden_states
456
+
457
  # Unpad inputs and mask. It will remove tokens that are padded.
458
  # Assume ntokens is total number of tokens (padded and non-padded)
459
  # and ntokens_unpad is total number of non-padded tokens.
 
477
  alibi_attn_mask = attn_bias + alibi_bias
478
 
479
  all_encoder_layers = []
480
+ all_attention_probs: List[torch.Tensor] = []
481
  if subset_mask is None:
482
  for layer_module in self.layer:
483
+ if output_attentions:
484
+ hidden_states, attn_probs = layer_module(
485
+ hidden_states, cu_seqlens, seqlen, None, indices,
486
+ attn_mask=attention_mask, bias=alibi_attn_mask,
487
+ return_attn_weights=True)
488
+ all_attention_probs.append(attn_probs)
489
+ else:
490
+ hidden_states = layer_module(hidden_states,
491
+ cu_seqlens,
492
+ seqlen,
493
+ None,
494
+ indices,
495
+ attn_mask=attention_mask,
496
+ bias=alibi_attn_mask)
497
  if output_all_encoded_layers:
498
+ # Pad back to (B, T, D) so callers get consistent shapes.
499
+ all_encoder_layers.append(
500
+ pad_input(hidden_states, indices, batch, seqlen))
501
  # Pad inputs and mask. It will insert back zero-padded tokens.
502
  # Assume ntokens is total number of tokens (padded and non-padded)
503
  # and ntokens_unpad is total number of non-padded tokens.
 
528
 
529
  if not output_all_encoded_layers:
530
  all_encoder_layers.append(hidden_states)
531
+ else:
532
+ # Prepend padded embedding as index 0 to match HF convention:
533
+ # hidden_states[0] = embedding, hidden_states[i+1] = layer i output.
534
+ all_encoder_layers.insert(0, padded_embedding)
535
+
536
+ attn_out = tuple(all_attention_probs) if output_attentions else None
537
+ return all_encoder_layers, attn_out
538
 
539
 
540
  class BertPooler(nn.Module):
 
637
  position_ids: Optional[torch.Tensor] = None,
638
  output_all_encoded_layers: Optional[bool] = False,
639
  masked_tokens_mask: Optional[torch.Tensor] = None,
640
+ output_hidden_states: bool = False,
641
+ output_attentions: bool = False,
642
  **kwargs
643
+ ) -> BaseModelOutputWithPooling:
644
  if attention_mask is None:
645
  attention_mask = torch.ones_like(input_ids)
646
  if token_type_ids is None:
 
659
  first_col_mask[:, 0] = True
660
  subset_mask = masked_tokens_mask | first_col_mask
661
 
662
+ encoder_outputs, all_attentions = self.encoder(
663
  embedding_output,
664
  attention_mask,
665
+ output_all_encoded_layers=output_hidden_states,
666
+ subset_mask=subset_mask,
667
+ output_attentions=output_attentions)
668
 
669
  if masked_tokens_mask is None:
670
  sequence_output = encoder_outputs[-1]
 
683
  else:
684
  pooled_output = None
685
 
686
+ return BaseModelOutputWithPooling(
687
+ last_hidden_state=sequence_output,
688
+ pooler_output=pooled_output,
689
+ hidden_states=tuple(encoder_outputs) if output_hidden_states else None,
690
+ attentions=all_attentions,
691
+ )
 
692
 
693
 
694
  ###################
 
808
  return_dict=return_dict,
809
  masked_tokens_mask=masked_tokens_mask,
810
  )
811
+
812
+ sequence_output = outputs.last_hidden_state
813
  prediction_scores = self.cls(sequence_output)
814
 
815
  loss = None
 
835
  return MaskedLMOutput(
836
  loss=loss,
837
  logits=prediction_scores,
838
+ hidden_states=outputs.hidden_states,
839
+ attentions=outputs.attentions,
840
  )
841
 
842
  def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
 
863
  return {'input_ids': input_ids, 'attention_mask': attention_mask}
864
 
865
866
 
867
  class BertForSequenceClassification(BertPreTrainedModel):
868
  """Bert Model transformer with a sequence classification/regression head.
 
921
  return_dict=return_dict,
922
  )
923
 
924
+ pooled_output = outputs.pooler_output
925
 
926
  pooled_output = self.dropout(pooled_output)
927
  logits = self.classifier(pooled_output)
 
959
  return SequenceClassifierOutput(
960
  loss=loss,
961
  logits=logits,
962
+ hidden_states=outputs.hidden_states,
963
+ attentions=outputs.attentions,
964
  )
965
 
flash_attn_triton.py CHANGED
@@ -188,7 +188,7 @@ def _fwd_kernel(
188
  (offs_d[None, :] < headdim),
189
  other=0.0)
190
  qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
191
- qk += tl.dot(q, k, trans_b=True)
192
  # Trying to combine the two masks seem to make the result wrong
193
  if not EVEN_N: # Need to mask out otherwise the softmax is wrong
194
  qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0,
@@ -431,7 +431,7 @@ def _bwd_kernel_one_col_block(
431
  (offs_d[None, :] < headdim),
432
  other=0.0)
433
  # recompute p = softmax(qk, dim=-1).T
434
- qk = tl.dot(q, k, trans_b=True)
435
  # Trying to combine the two masks seem to make the result wrong
436
  if not EVEN_N: # Need to mask out otherwise the softmax is wrong
437
  qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf'))
@@ -498,7 +498,7 @@ def _bwd_kernel_one_col_block(
498
  # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
499
  if not (EVEN_M & EVEN_HEADDIM):
500
  tl.debug_barrier()
501
- dp = tl.dot(do, v, trans_b=True)
502
  # There's a race condition for headdim=48
503
  if not EVEN_HEADDIM:
504
  tl.debug_barrier()
 
188
  (offs_d[None, :] < headdim),
189
  other=0.0)
190
  qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
191
+ qk += tl.dot(q, tl.trans(k)) # see issue: https://github.com/Dao-AILab/flash-attention/issues/508
192
  # Trying to combine the two masks seem to make the result wrong
193
  if not EVEN_N: # Need to mask out otherwise the softmax is wrong
194
  qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0,
 
431
  (offs_d[None, :] < headdim),
432
  other=0.0)
433
  # recompute p = softmax(qk, dim=-1).T
434
+ qk = tl.dot(q, tl.trans(k)) # see issue: https://github.com/Dao-AILab/flash-attention/issues/508
435
  # Trying to combine the two masks seem to make the result wrong
436
  if not EVEN_N: # Need to mask out otherwise the softmax is wrong
437
  qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf'))
 
498
  # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
499
  if not (EVEN_M & EVEN_HEADDIM):
500
  tl.debug_barrier()
501
+ dp = tl.dot(do, tl.trans(v)) # see issue: https://github.com/Dao-AILab/flash-attention/issues/508
502
  # There's a race condition for headdim=48
503
  if not EVEN_HEADDIM:
504
  tl.debug_barrier()