Update modeling_gpt2_mq.py
Browse files- modeling_gpt2_mq.py +1 -1
modeling_gpt2_mq.py
CHANGED
|
@@ -244,7 +244,7 @@ class GPT2MQAttention(nn.Module):
|
|
| 244 |
attention_mask = encoder_attention_mask
|
| 245 |
else:
|
| 246 |
query = self.q_attn(hidden_states)
|
| 247 |
-
key, value = self.kv_attn(hidden_states).split(self.
|
| 248 |
|
| 249 |
|
| 250 |
batch_size, seq_length = query.shape[:2]
|
|
|
|
| 244 |
attention_mask = encoder_attention_mask
|
| 245 |
else:
|
| 246 |
query = self.q_attn(hidden_states)
|
| 247 |
+
key, value = self.kv_attn(hidden_states).split(self.head_dim, dim=2)
|
| 248 |
|
| 249 |
|
| 250 |
batch_size, seq_length = query.shape[:2]
|