Update modeling_t2.py
Browse files- modeling_t2.py +1 -1
modeling_t2.py
CHANGED
|
@@ -235,7 +235,7 @@ class TransformerAttention(nn.Module):
|
|
| 235 |
k = torch.cat((past_key, k), dim=-2)
|
| 236 |
v = torch.cat((past_value, v), dim=-2)
|
| 237 |
|
| 238 |
-
cos, sin = self.rotary_emb(v,
|
| 239 |
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
|
| 240 |
|
| 241 |
if use_cache is True:
|
|
|
|
| 235 |
k = torch.cat((past_key, k), dim=-2)
|
| 236 |
v = torch.cat((past_value, v), dim=-2)
|
| 237 |
|
| 238 |
+
cos, sin = self.rotary_emb(v, position_ids)
|
| 239 |
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
|
| 240 |
|
| 241 |
if use_cache is True:
|