klemenk commited on
Commit
d47c2bd
·
verified ·
1 Parent(s): 9c1f3a9

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +3 -1
modeling_auristream.py CHANGED
@@ -495,6 +495,8 @@ class CausalSelfAttention(nn.Module):
495
  q = q.view(B, T, self.n_head, self.head_dim)
496
  v = v.view(B, T, self.n_head, self.head_dim)
497
 
 
 
498
  if self.rotary is not None:
499
  cos, sin = self.rotary(q)
500
  q = apply_rotary_emb(q, cos, sin)
@@ -529,7 +531,7 @@ class CausalSelfAttention(nn.Module):
529
 
530
  # return key and value caches if requested
531
  if return_kv:
532
- return y, k, v
533
 
534
  return y
535
 
 
495
  q = q.view(B, T, self.n_head, self.head_dim)
496
  v = v.view(B, T, self.n_head, self.head_dim)
497
 
498
+ k_orig = k.clone()
499
+
500
  if self.rotary is not None:
501
  cos, sin = self.rotary(q)
502
  q = apply_rotary_emb(q, cos, sin)
 
531
 
532
  # return key and value caches if requested
533
  if return_kv:
534
+ return y, k_orig, v
535
 
536
  return y
537