klemenk committed
Commit d09b97c · verified · 1 Parent(s): d2f818b

Update modeling_auristream.py

Files changed (1)
  1. modeling_auristream.py +63 -19
modeling_auristream.py CHANGED
@@ -495,8 +495,6 @@ class CausalSelfAttention(nn.Module):
         q = q.view(B, T, self.n_head, self.head_dim)
         v = v.view(B, T, self.n_head, self.head_dim)
 
-        k_orig = k.clone().transpose(1, 2)
-
         if self.rotary is not None:
             cos, sin = self.rotary(q)
             q = apply_rotary_emb(q, cos, sin)
@@ -531,40 +529,86 @@ class CausalSelfAttention(nn.Module):
 
         # return key and value caches if requested
         if return_kv:
-            return y, k_orig, v
+            return y, k, v
 
         return y
 
-    def kv_cache_forward(self, x, k_cache=None, v_cache=None):
-        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+    # def kv_cache_forward(self, x, k_cache=None, v_cache=None):
+    #     B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+
+    #     # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+    #     q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+    #     k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+    #     q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+    #     v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+    #     # append cached keys and values with new keys and values
+    #     if k_cache is not None:
+    #         k = torch.cat((k_cache, k), dim=2)
+    #     if v_cache is not None:
+    #         v = torch.cat((v_cache, v), dim=2)
+
+    #     if self.rotary is not None:
+    #         cos, sin = self.rotary(q)
+    #         q = apply_rotary_emb(q, cos, sin)
+    #         k = apply_rotary_emb(k, cos, sin)
+
+    #     # manual implementation of attention
+    #     att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+    #     att = F.softmax(att, dim=-1)
+    #     y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+
+    #     y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
 
+    #     # output projection
+    #     y = self.c_proj(y)
+
+    #     return y, k, v
+
+    def kv_cache_forward(self, x, k_cache=None, v_cache=None):
+        B, T, C = x.size() # T=1 for single new token
+
         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
-        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
-        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
-        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
-
-        # append cached keys and values with new keys and values
+        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, 1, hs)
+        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, 1, hs)
+        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, 1, hs)
+
+        # Apply RoPE BEFORE concatenation, using correct absolute position
+        if self.rotary is not None:
+            # Determine the position of the new token
+            cache_len = k_cache.shape[2] if k_cache is not None else 0
+
+            # Create a dummy tensor with the correct sequence position for rotary computation
+            # We need shape (B, cache_len + 1, nh, hs) but only use the last position
+            dummy = torch.zeros(B, cache_len + T, self.n_head, self.head_dim,
+                                device=q.device, dtype=q.dtype)
+            cos, sin = self.rotary(dummy)
+
+            # Extract rotary embeddings for only the new token position
+            cos = cos[:, cache_len:cache_len+T, :, :]
+            sin = sin[:, cache_len:cache_len+T, :, :]
+
+            # Apply rotary embeddings to new q and k only
+            q = apply_rotary_emb(q, cos, sin)
+            k = apply_rotary_emb(k, cos, sin)
+
+        # NOW concatenate with cache (cached keys already have correct RoPE applied)
         if k_cache is not None:
             k = torch.cat((k_cache, k), dim=2)
         if v_cache is not None:
             v = torch.cat((v_cache, v), dim=2)
-
-        if self.rotary is not None:
-            cos, sin = self.rotary(q)
-            q = apply_rotary_emb(q, cos, sin)
-            k = apply_rotary_emb(k, cos, sin)
-
+
         # manual implementation of attention
         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
         att = F.softmax(att, dim=-1)
         y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
-
+
         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
-
+
         # output projection
         y = self.c_proj(y)
-
+
         return y, k, v
 
 
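
The substance of the patch is the RoPE bookkeeping in the new kv_cache_forward: the incoming token's query and key are rotated at their absolute position (the current cache length) before being concatenated with cached keys that were rotated in earlier steps. Below is a minimal, self-contained sketch of that idea; the rope_cos_sin/apply_rope helpers, the channel-pairing convention, and the toy shapes are illustrative assumptions, not the Rotary/apply_rotary_emb implementation from modeling_auristream.py.

import math
import torch

def rope_cos_sin(positions, head_dim):
    # cos/sin tables for a set of ABSOLUTE positions; each returned tensor is (T, head_dim/2)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = positions.float()[:, None] * inv_freq[None, :]
    return angles.cos(), angles.sin()

def apply_rope(x, cos, sin):
    # x: (B, nh, T, hs); rotate even/odd channel pairs by the per-position angles
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def attend(q, k, v):
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    return torch.softmax(att, dim=-1) @ v

torch.manual_seed(0)
B, nh, T, hs = 1, 2, 5, 8
q = torch.randn(B, nh, T, hs)
k = torch.randn(B, nh, T, hs)
v = torch.randn(B, nh, T, hs)

# Reference: rotate q/k over absolute positions 0..T-1 and take the last token's
# output (the last token attends to every position, so no causal mask is needed here).
cos, sin = rope_cos_sin(torch.arange(T), hs)
y_ref = attend(apply_rope(q, cos, sin)[:, :, -1:], apply_rope(k, cos, sin), v)

# Incremental step: the cache holds keys for positions 0..T-2, already rotated.
cos_c, sin_c = rope_cos_sin(torch.arange(T - 1), hs)
k_cache = apply_rope(k[:, :, :-1], cos_c, sin_c)
v_cache = v[:, :, :-1]

# Rotate the NEW query/key at their absolute position (the cache length), as the
# patched kv_cache_forward does, then concatenate with the cache.
cache_len = k_cache.shape[2]
cos_n, sin_n = rope_cos_sin(torch.arange(cache_len, cache_len + 1), hs)
q_new = apply_rope(q[:, :, -1:], cos_n, sin_n)
k_new = apply_rope(k[:, :, -1:], cos_n, sin_n)
y_step = attend(q_new, torch.cat((k_cache, k_new), dim=2),
                torch.cat((v_cache, v[:, :, -1:]), dim=2))
print(torch.allclose(y_ref, y_step, atol=1e-6))   # True: positions line up with the full pass

# Rotating the new token as if it sat at position 0 (no cache offset) breaks the
# equivalence, because scores against cached keys then use the wrong relative offsets.
cos_0, sin_0 = rope_cos_sin(torch.arange(1), hs)
y_bad = attend(apply_rope(q[:, :, -1:], cos_0, sin_0),
               torch.cat((k_cache, apply_rope(k[:, :, -1:], cos_0, sin_0)), dim=2),
               torch.cat((v_cache, v[:, :, -1:]), dim=2))
print(torch.allclose(y_ref, y_bad, atol=1e-6))    # False (in general)

This is also why the first hunk drops k_orig: the return_kv path of forward now hands back the keys as used after rotary embedding, and the new kv_cache_forward, per its own comment, treats the cached keys as already carrying the correct RoPE and never re-rotates them. Assuming the module is used the way the two code paths suggest (one full forward with return_kv=True for the prompt, then one kv_cache_forward call per generated token), a per-layer decoding step would look roughly like the sketch below; attn, x_prompt, and new_token_embeddings are placeholders, and the real forward signature may take additional arguments.

# Hypothetical usage of the patched attention block (names are placeholders).
y, k_cache, v_cache = attn(x_prompt, return_kv=True)      # prefill: keys come back already rotated
for x_new in new_token_embeddings:                         # each x_new: (B, 1, n_embd)
    y, k_cache, v_cache = attn.kv_cache_forward(x_new, k_cache, v_cache)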