klemenk committed
Commit 75ed49a · verified · 1 Parent(s): e0d0c9f

Update modeling_auristream.py

Files changed (1)
  modeling_auristream.py  +27 -0
modeling_auristream.py CHANGED
@@ -472,6 +472,33 @@ class CausalSelfAttention(nn.Module):
 
         return y
 
+    def kv_cache_forward(self, x, k_cache=None, v_cache=None):
+        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+
+        # append the new keys and values to the cached ones along the time dimension
+        if k_cache is not None:
+            k = torch.cat((k_cache, k), dim=2)
+        if v_cache is not None:
+            v = torch.cat((v_cache, v), dim=2)
+
+        # manual implementation of attention over the full (cached + new) sequence
+        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+        att = F.softmax(att, dim=-1)
+        y = att @ v  # (B, nh, T, T_kv) x (B, nh, T_kv, hs) -> (B, nh, T, hs)
+
+        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+
+        # output projection
+        y = self.c_proj(y)
+
+        return y, k, v
+
 
 class MLP(nn.Module):
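For context, the added kv_cache_forward method processes only the newly supplied positions and returns the updated key and value tensors alongside the attention output, so a caller can carry the cache from step to step. Below is a minimal usage sketch under that reading of the commit; the incremental_forward function, the attn variable (assumed to be a CausalSelfAttention instance from modeling_auristream.py), and the hidden_states tensor are illustrative names, not part of this commit.

import torch

@torch.no_grad()
def incremental_forward(attn, hidden_states):
    # attn: assumed CausalSelfAttention instance exposing kv_cache_forward (hypothetical setup)
    # hidden_states: (B, T, n_embd) activations fed one position at a time
    k_cache, v_cache = None, None          # empty cache before the first step
    outputs = []
    for t in range(hidden_states.size(1)):
        x_t = hidden_states[:, t : t + 1, :]  # one new position per step
        y_t, k_cache, v_cache = attn.kv_cache_forward(x_t, k_cache, v_cache)
        outputs.append(y_t)
    return torch.cat(outputs, dim=1)       # (B, T, n_embd)

Feeding a single new position per step is what keeps the plain softmax in the committed method correct: each query may attend to every cached key, so no causal mask is required; passing several new positions at once through kv_cache_forward would additionally need masking among those positions.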