Update modeling_gptbert.py
Browse files — modeling_gptbert.py (+4 −0)
modeling_gptbert.py
CHANGED
|
@@ -23,6 +23,9 @@ from transformers.modeling_outputs import (
 23  import math
 24  from typing import TYPE_CHECKING, Optional, Union, Tuple, List
 25
 26  if is_flash_attn_2_available():
 27      from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
 28      from flash_attn.layers.rotary import RotaryEmbedding
|
@@ -1036,6 +1039,7 @@ class GptBertForSequenceClassification(GptBertModel):
 1036
 1037  sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
 1038  logits = self.head(sequence_output[:, 0, :])
 1039
 1040  loss = None
 1041  if labels is not None:
|
|
|
 23   import math
 24   from typing import TYPE_CHECKING, Optional, Union, Tuple, List
 25
 26 + def is_flash_attn_2_available():
 27 +     return False
 28 +
 29   if is_flash_attn_2_available():
 30       from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
 31       from flash_attn.layers.rotary import RotaryEmbedding
|
|
|
|
 1039
 1040   sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
 1041   logits = self.head(sequence_output[:, 0, :])
 1042 + print(logits)
 1043
 1044   loss = None
 1045   if labels is not None:
|