Update modeling_auristream.py
Browse files- modeling_auristream.py +3 -1
modeling_auristream.py
CHANGED
|
@@ -495,6 +495,8 @@ class CausalSelfAttention(nn.Module):
|
|
| 495 |
q = q.view(B, T, self.n_head, self.head_dim)
|
| 496 |
v = v.view(B, T, self.n_head, self.head_dim)
|
| 497 |
|
|
|
|
|
|
|
| 498 |
if self.rotary is not None:
|
| 499 |
cos, sin = self.rotary(q)
|
| 500 |
q = apply_rotary_emb(q, cos, sin)
|
|
@@ -529,7 +531,7 @@ class CausalSelfAttention(nn.Module):
|
|
| 529 |
|
| 530 |
# return key and value caches if requested
|
| 531 |
if return_kv:
|
| 532 |
-
return y,
|
| 533 |
|
| 534 |
return y
|
| 535 |
|
|
|
|
| 495 |
q = q.view(B, T, self.n_head, self.head_dim)
|
| 496 |
v = v.view(B, T, self.n_head, self.head_dim)
|
| 497 |
|
| 498 |
+
k_orig = k.clone()
|
| 499 |
+
|
| 500 |
if self.rotary is not None:
|
| 501 |
cos, sin = self.rotary(q)
|
| 502 |
q = apply_rotary_emb(q, cos, sin)
|
|
|
|
| 531 |
|
| 532 |
# return key and value caches if requested
|
| 533 |
if return_kv:
|
| 534 |
+
return y, k_orig, v
|
| 535 |
|
| 536 |
return y
|
| 537 |
|