Commit ·
2bf3643
1
Parent(s): cceda44
fix bug when num_kv > 1
Browse files — modeling_RW.py (+2 −2)
modeling_RW.py
CHANGED
|
@@ -290,8 +290,8 @@ class Attention(nn.Module):
|
|
| 290 |
|
| 291 |
if alibi is None:
|
| 292 |
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
| 293 |
-
key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
|
| 294 |
-
value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
|
| 295 |
|
| 296 |
attn_output = F.scaled_dot_product_attention(
|
| 297 |
query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
|
|
|
|
| 290 |
|
| 291 |
if alibi is None:
|
| 292 |
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
| 293 |
+
key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
| 294 |
+
value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
| 295 |
|
| 296 |
attn_output = F.scaled_dot_product_attention(
|
| 297 |
query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
|