klemenk committed
Commit f32a41a · verified · 1 Parent(s): d47c2bd

Update modeling_auristream.py

Files changed (1)
modeling_auristream.py +6 -1
modeling_auristream.py CHANGED
@@ -495,7 +495,7 @@ class CausalSelfAttention(nn.Module):
         q = q.view(B, T, self.n_head, self.head_dim)
         v = v.view(B, T, self.n_head, self.head_dim)
 
-        k_orig = k.clone()
+        k_orig = k.clone().transpose(1, 2)
 
         if self.rotary is not None:
             cos, sin = self.rotary(q)
@@ -550,6 +550,11 @@ class CausalSelfAttention(nn.Module):
         if v_cache is not None:
             v = torch.cat((v_cache, v), dim=2)
 
+        if self.rotary is not None:
+            cos, sin = self.rotary(q)
+            q = apply_rotary_emb(q.transpose(1, 2), cos, sin).transpose(1, 2)
+            k = apply_rotary_emb(k.transpose(1, 2), cos, sin).transpose(1, 2)
+
         # manual implementation of attention
         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
         att = F.softmax(att, dim=-1)
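
For context: the added lines apply the rotary embedding to q and to the full (cache-concatenated) k just before the manual attention step, working in the (B, T, n_head, head_dim) layout that apply_rotary_emb evidently expects, hence the transpose(1, 2) round-trips. Below is a minimal, self-contained sketch of how those lines behave; the apply_rotary_emb body and the cos/sin table construction here are illustrative assumptions, not the actual helpers from modeling_auristream.py.

    import torch

    B, T, n_head, head_dim = 2, 8, 4, 16

    def apply_rotary_emb(x, cos, sin):
        # Assumed RoPE helper: x is (B, T, n_head, head_dim); cos/sin
        # broadcast over batch and heads. Rotates each channel pair
        # (x1, x2) by a position-dependent angle.
        x1, x2 = x.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        return x * cos + rotated * sin

    # cos/sin tables shaped (1, T, 1, head_dim) so they broadcast in the
    # (B, T, n_head, head_dim) layout used above.
    theta = 10000.0 ** (-torch.arange(0, head_dim, 2).float() / head_dim)
    angles = torch.arange(T).float()[:, None] * theta[None, :]
    cos = torch.cat((angles.cos(), angles.cos()), dim=-1)[None, :, None, :]
    sin = torch.cat((angles.sin(), angles.sin()), dim=-1)[None, :, None, :]

    # q and k in the (B, n_head, T, head_dim) attention layout, as after
    # the view + cache concatenation in the model code.
    q = torch.randn(B, n_head, T, head_dim)
    k = torch.randn(B, n_head, T, head_dim)

    # Mirrors the added diff lines: transpose into (B, T, n_head, head_dim),
    # rotate, transpose back for the attention matmul.
    q = apply_rotary_emb(q.transpose(1, 2), cos, sin).transpose(1, 2)
    k = apply_rotary_emb(k.transpose(1, 2), cos, sin).transpose(1, 2)

    att = (q @ k.transpose(-2, -1)) * (1.0 / k.size(-1) ** 0.5)
    print(att.shape)  # torch.Size([2, 4, 8, 8])

Rotating after the cache concatenation is consistent with storing un-rotated keys (the k_orig change in the first hunk transposes that clone into the same cache layout) and re-applying absolute positions over the full sequence each step, though the exact caching strategy depends on surrounding code not shown in this diff.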