lgcharpe committed
Commit 4abdbe5 · verified · Parent: f5214f6

Update modeling_gptbert.py

Files changed (1):
    modeling_gptbert.py  +4 -0
modeling_gptbert.py CHANGED
@@ -23,6 +23,9 @@ from transformers.modeling_outputs import (
 import math
 from typing import TYPE_CHECKING, Optional, Union, Tuple, List
 
+def is_flash_attn_2_available():
+    return False
+
 if is_flash_attn_2_available():
     from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
     from flash_attn.layers.rotary import RotaryEmbedding
@@ -1036,6 +1039,7 @@ class GptBertForSequenceClassification(GptBertModel):
 
         sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
         logits = self.head(sequence_output[:, 0, :])
+        print(logits)
 
         loss = None
         if labels is not None:
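
Reading of the change (the commit message does not say): the locally defined is_flash_attn_2_available() always returns False and presumably shadows the availability helper imported earlier from transformers (not visible in these hunks), so the flash_attn imports below it are skipped and the model falls back to its non-flash attention path. The print(logits) added in GptBertForSequenceClassification looks like a temporary debug statement. A minimal standalone sketch of the shadowing effect follows; it is the patched import block in isolation, not the actual file.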
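
    def is_flash_attn_2_available():
        # Always report flash-attn 2 as unavailable, as in the patch,
        # forcing the fallback (non-flash) attention implementation.
        return False

    if is_flash_attn_2_available():
        # Never reached with the override above, so flash_attn does not
        # need to be installed for the model to load.
        from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
        from flash_attn.layers.rotary import RotaryEmbedding
    else:
        print("flash-attn 2 disabled; using fallback attention")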