fix attention mask
modeling_gptbert.py  CHANGED  +2 -2
@@ -615,9 +615,9 @@ class GptBertModel(GptBertPreTrainedModel):
             padding_info = (indices, cu_seqlens, max_seqlen_in_batch)
         else:
             if len(attention_mask.size()) == 2:
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-            elif len(attention_mask.size()) == 3:
                 attention_mask = attention_mask.unsqueeze(1)
+            if len(attention_mask.size()) != 3:
+                raise ValueError("Only an `attention_mask` with two or three dimensions is now supported for SDPA.")
             padding_info = attention_mask
 
         static_embeddings = self.embedding(input_ids)
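For context, the patched branch can be read as the following standalone sketch (the helper name `normalize_attention_mask` and the example shapes are illustrative, not part of the repository; in the model this logic runs inline in the forward pass): a 2D padding mask of shape (batch, seq_len) gains a broadcast dimension, a mask that is already 3D passes through unchanged, and any other rank, which previously fell through both branches unmodified, now raises.

import torch

def normalize_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # 2D padding mask (batch, seq_len) -> (batch, 1, seq_len),
    # so it can broadcast over the query positions.
    if len(attention_mask.size()) == 2:
        attention_mask = attention_mask.unsqueeze(1)
    # Anything that is not 3D at this point is rejected explicitly,
    # instead of being passed through silently as before the fix.
    if len(attention_mask.size()) != 3:
        raise ValueError("Only an `attention_mask` with two or three dimensions is now supported for SDPA.")
    return attention_mask

# Per-token padding mask: one keep/ignore flag per token.
padding_mask = torch.ones(4, 128, dtype=torch.bool)
assert normalize_attention_mask(padding_mask).shape == (4, 1, 128)

# Full pairwise mask: one flag per (query, key) pair.
pairwise_mask = torch.ones(4, 128, 128, dtype=torch.bool)
assert normalize_attention_mask(pairwise_mask).shape == (4, 128, 128)

# A 4D mask (e.g. one already expanded per head) now fails loudly.
try:
    normalize_attention_mask(torch.ones(4, 1, 128, 128, dtype=torch.bool))
except ValueError as err:
    print(err)

Note that `torch.nn.functional.scaled_dot_product_attention` expects `attn_mask` to be broadcastable to the attention-weight shape (batch, num_heads, q_len, k_len), so the 3D mask built here presumably has a head dimension added further down in the attention modules; that code is outside this hunk.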