davda54 committed
Commit 16b2b5e · verified · Parent: 6cbbc37

fix attention mask

Files changed (1):
  modeling_gptbert.py (+2 −2)
modeling_gptbert.py CHANGED

```diff
@@ -615,9 +615,9 @@ class GptBertModel(GptBertPreTrainedModel):
             padding_info = (indices, cu_seqlens, max_seqlen_in_batch)
         else:
             if len(attention_mask.size()) == 2:
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-            elif len(attention_mask.size()) == 3:
                 attention_mask = attention_mask.unsqueeze(1)
+            if len(attention_mask.size()) != 3:
+                raise ValueError("Only an `attention_mask` with two or three dimensions is supported for SDPA.")
             padding_info = attention_mask
 
         static_embeddings = self.embedding(input_ids)
```
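
For context, here is a minimal, self-contained sketch of what the patched hunk does: a 2-D padding mask of shape (batch, seq) is lifted to 3-D with a single `unsqueeze(1)`, and anything that is not 3-D afterwards is rejected before it reaches PyTorch's `torch.nn.functional.scaled_dot_product_attention`. The `normalize_mask` helper and the tensor shapes below are illustrative assumptions, not code from modeling_gptbert.py.

```python
import torch
import torch.nn.functional as F

def normalize_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Mirror of the patched logic: lift a 2-D padding mask to 3-D, then
    # insist on exactly three dimensions before handing it to SDPA.
    # (Hypothetical helper, not part of the repository.)
    if len(attention_mask.size()) == 2:
        attention_mask = attention_mask.unsqueeze(1)  # (batch, seq) -> (batch, 1, seq)
    if len(attention_mask.size()) != 3:
        raise ValueError("Only an `attention_mask` with two or three dimensions is supported for SDPA.")
    return attention_mask

batch, heads, seq, head_dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

mask = torch.ones(batch, seq, dtype=torch.bool)  # True = attend, False = masked
mask[:, -2:] = False                             # pretend the last two tokens are padding

mask3d = normalize_mask(mask)                    # (batch, 1, seq)
# Adding the head dimension gives (batch, 1, 1, seq), which broadcasts
# over heads and query positions inside SDPA.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask3d.unsqueeze(1))
print(out.shape)  # torch.Size([2, 4, 8, 16])
```

Compared with the removed branch, which expanded a 2-D mask straight to 4-D with `unsqueeze(1).unsqueeze(2)`, the new code keeps the stored mask 3-D, presumably leaving the head dimension to the attention layer, and makes unsupported shapes fail with an explicit error instead of broadcasting silently.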