fix

modeling_gptbert.py  (+6 -5)  CHANGED
@@ -439,8 +439,8 @@ class SelfAttention(nn.Module):

         else:
             # Standard attention path
-            query_length =
-            key_length =
+            query_length = query.size(1)
+            key_length = key.size(1)

             query = query.reshape(batch_size, query_length, self.num_attention_heads, self.d_qk).transpose(1, 2)
             key = key.reshape(batch_size, key_length, self.num_kv_heads, self.d_qk).transpose(1, 2)

@@ -451,7 +451,8 @@ class SelfAttention(nn.Module):

         if v1 is None:
             v1 = value
-
+        else:
+            value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1

         # Apply rotary embeddings
         query = self.rope_embedding(query)
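Note on the second hunk: when a first-layer value tensor `v1` is passed in, the new `else` branch mixes it into the current layer's `value` using the learned weight `self.lambdas[0]` (a value-residual style blend). Below is a minimal standalone sketch of that blend; the shapes and the constant standing in for the learned lambda are made up for illustration, not taken from this model.

import torch

# Illustrative shapes only: (batch, heads, seq_len, head_dim).
batch_size, num_heads, seq_len, d_v = 2, 4, 8, 16
value = torch.randn(batch_size, num_heads, seq_len, d_v)  # current layer's value states
v1 = torch.randn(batch_size, num_heads, seq_len, d_v)     # value states carried over from the first layer

lambda_0 = torch.tensor(0.3)  # stand-in for self.lambdas[0], a learned scalar in the model

# The blend added in this commit: keep (1 - lambda) of the current value,
# mix in lambda of the first-layer value.
mixed_value = (1 - lambda_0) * value + lambda_0 * v1

assert mixed_value.shape == value.shape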
@@ -615,9 +616,9 @@ class GptBertModel(GptBertPreTrainedModel):
             padding_info = (indices, cu_seqlens, max_seqlen_in_batch)
         else:
             if len(attention_mask.size()) == 2:
+                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            elif len(attention_mask.size()) == 3:
                 attention_mask = attention_mask.unsqueeze(1)
-            if len(attention_mask.size()) != 3:
-                raise ValueError("Only an `attention_mask` with two or three dimensions is currently supported for SDPA.")
             padding_info = attention_mask

         static_embeddings = self.embedding(input_ids)
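Note on the last hunk: instead of raising on unexpected mask ranks, the model now expands a 2-D padding mask with `unsqueeze(1).unsqueeze(2)` and a 3-D mask with `unsqueeze(1)`, so either shape broadcasts against the per-head attention scores that `torch.nn.functional.scaled_dot_product_attention` expects. A minimal sketch of that broadcasting follows; the sizes are made up, and a boolean mask is used here purely for the example (the model may pass an additive float mask instead).

import torch
import torch.nn.functional as F

batch_size, num_heads, seq_len, head_dim = 2, 4, 8, 16
query = torch.randn(batch_size, num_heads, seq_len, head_dim)
key = torch.randn(batch_size, num_heads, seq_len, head_dim)
value = torch.randn(batch_size, num_heads, seq_len, head_dim)

# A 2-D padding mask (batch, key_len): True = attend, False = masked out.
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
attention_mask[:, -2:] = False  # pretend the last two positions are padding

if len(attention_mask.size()) == 2:
    # (batch, key_len) -> (batch, 1, 1, key_len): broadcasts over heads and query positions.
    attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
elif len(attention_mask.size()) == 3:
    # (batch, query_len, key_len) -> (batch, 1, query_len, key_len): broadcasts over heads.
    attention_mask = attention_mask.unsqueeze(1)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
print(out.shape)  # torch.Size([2, 4, 8, 16])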
|