Commit: Update modeling_auristream.py

Changed file: modeling_auristream.py (+6 −1)
@@ -495,7 +495,7 @@ class CausalSelfAttention(nn.Module):
         q = q.view(B, T, self.n_head, self.head_dim)
         v = v.view(B, T, self.n_head, self.head_dim)

-        k_orig = k.clone()
+        k_orig = k.clone().transpose(1, 2)

         if self.rotary is not None:
             cos, sin = self.rotary(q)
@@ -550,6 +550,11 @@ class CausalSelfAttention(nn.Module):
         if v_cache is not None:
             v = torch.cat((v_cache, v), dim=2)

+        if self.rotary is not None:
+            cos, sin = self.rotary(q)
+            q = apply_rotary_emb(q.transpose(1, 2), cos, sin).transpose(1, 2)
+            k = apply_rotary_emb(k.transpose(1, 2), cos, sin).transpose(1, 2)
+
         # manual implementation of attention
         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
         att = F.softmax(att, dim=-1)