Fix the kv-cache dimensions
#47
by cchudant - opened
- modelling_RW.py +1 -1
modelling_RW.py
CHANGED
|
@@ -271,7 +271,7 @@ class Attention(nn.Module):
|
|
| 271 |
# concatenate along seq_length dimension:
|
| 272 |
# - key: [batch_size * self.num_heads, head_dim, kv_length]
|
| 273 |
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
| 274 |
-
key_layer = torch.cat((past_key, key_layer), dim=1)
|
| 275 |
value_layer = torch.cat((past_value, value_layer), dim=1)
|
| 276 |
|
| 277 |
_, kv_length, _ = key_layer.shape
|
|
|
|
| 271 |
# concatenate along seq_length dimension:
|
| 272 |
# - key: [batch_size * self.num_heads, head_dim, kv_length]
|
| 273 |
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
| 274 |
+
key_layer = torch.cat((past_key, key_layer), dim=2)
|
| 275 |
value_layer = torch.cat((past_value, value_layer), dim=1)
|
| 276 |
|
| 277 |
_, kv_length, _ = key_layer.shape
|