davda54 committed
Commit 87e0acb · verified · parent: 16b2b5e
Files changed (1):
  1. modeling_gptbert.py (+6 -5)
modeling_gptbert.py CHANGED
@@ -439,8 +439,8 @@ class SelfAttention(nn.Module):
 
         else:
             # Standard attention path
-            query_length = hidden_layer.size(0)
-            key_length = hidden_layer.size(0)
+            query_length = query.size(1)
+            key_length = key.size(1)
 
             query = query.reshape(batch_size, query_length, self.num_attention_heads, self.d_qk).transpose(1, 2)
             key = key.reshape(batch_size, key_length, self.num_kv_heads, self.d_qk).transpose(1, 2)
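Note: immediately after the changed lines, `query` is reshaped from `(batch_size, query_length, ...)`, so dimension 1 of the projected `query`/`key` tensors is the sequence length the reshape needs; the layout of `hidden_layer` on this path is not visible in the diff, so taking its `size(0)` presumably picked up the wrong axis. A minimal shape sketch with made-up sizes (not taken from the model config):

import torch

# Made-up sizes, only to illustrate the reshape shown in the hunk above.
batch_size, seq_len = 2, 7
num_attention_heads, d_qk = 4, 16

query = torch.randn(batch_size, seq_len, num_attention_heads * d_qk)

query_length = query.size(1)    # 7 -- the sequence length, as in the fixed code
not_the_length = query.size(0)  # 2 -- the batch axis; feeding it to reshape() below would fail

query = query.reshape(batch_size, query_length, num_attention_heads, d_qk).transpose(1, 2)
print(query.shape)              # torch.Size([2, 4, 7, 16]): (batch, heads, seq, head_dim)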
@@ -451,7 +451,8 @@ class SelfAttention(nn.Module):
 
         if v1 is None:
             v1 = value
-        value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
+        else:
+            value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
 
         # Apply rotary embeddings
         query = self.rope_embedding(query)
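Note: this hunk makes the value-residual blend conditional. When `v1` is `None` (no earlier value to mix in), the old code set `v1 = value` and then blended `value` with itself, which is numerically a no-op, so the added `else:` only skips redundant work; what `self.lambdas[0]` is (assumed here to be a learned mixing scalar) is not shown in the diff. A small sketch of the guarded update, with made-up tensors:

import torch

value = torch.randn(2, 4, 7, 16)  # current layer's value states (made-up shape)
lam = torch.tensor(0.3)           # stand-in for self.lambdas[0]; assumed to be a scalar weight

v1 = None                         # pretend no earlier value has been recorded yet
if v1 is None:
    v1 = value                    # first time: just remember the value states
else:
    value = (1 - lam) * value + lam * v1  # later: mix the remembered states back in

# Blending a tensor with itself gives back (numerically) the same tensor,
# so the old unconditional blend changed nothing when v1 had just been set.
assert torch.allclose((1 - lam) * value + lam * value, value)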
@@ -615,9 +616,9 @@ class GptBertModel(GptBertPreTrainedModel):
             padding_info = (indices, cu_seqlens, max_seqlen_in_batch)
         else:
             if len(attention_mask.size()) == 2:
+                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            elif len(attention_mask.size()) == 3:
                 attention_mask = attention_mask.unsqueeze(1)
-            if len(attention_mask.size()) != 3:
-                raise ValueError("Only an `attention_mask` with two or three dimensions is currently supported for SDPA.")
             padding_info = attention_mask
 
         static_embeddings = self.embedding(input_ids)
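Note: the new branches reshape the mask so it can broadcast against the attention scores that PyTorch's SDPA computes with shape `(batch, heads, query_len, key_len)`: a 2D padding mask gains two singleton dimensions and a 3D mask gains one. A minimal sketch of the 2D case using `torch.nn.functional.scaled_dot_product_attention` with a made-up boolean padding mask (how the surrounding model code actually consumes `padding_info` is not shown in the diff):

import torch
import torch.nn.functional as F

batch, heads, seq, d = 2, 4, 7, 16
q = torch.randn(batch, heads, seq, d)
k = torch.randn(batch, heads, seq, d)
v = torch.randn(batch, heads, seq, d)

# 2D padding mask: True marks tokens that may be attended to.
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
attention_mask[1, 5:] = False  # pretend the last tokens of the second sequence are padding

# (batch, seq) -> (batch, 1, 1, seq): broadcastable over heads and query positions.
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
print(out.shape)  # torch.Size([2, 4, 7, 16])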
 