klemenk committed on
Commit
b500e3b
·
verified ·
1 Parent(s): e44a01f

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +47 -6
modeling_auristream.py CHANGED
@@ -487,14 +487,55 @@ class CausalSelfAttention(nn.Module):
487
  if v_cache is not None:
488
  v = torch.cat((v_cache, v), dim=2)
489
 
490
- # manual implementation of attention
491
- att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
492
- att = F.softmax(att, dim=-1)
493
- y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
 
 
 
 
 
 
 
 
 
 
494
 
495
- y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
496
 
497
- # output projection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
  y = self.c_proj(y)
499
 
500
  return y, k, v
 
487
  if v_cache is not None:
488
  v = torch.cat((v_cache, v), dim=2)
489
 
490
+ if not return_kv and not return_attn_maps:
491
+ y = F.scaled_dot_product_attention(
492
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
493
+ is_causal=True)
494
+ else:
495
+ # manual implementation of attention
496
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
497
+ att = F.softmax(att, dim=-1)
498
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
499
+
500
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
501
+
502
+ # output projection
503
+ y = self.c_proj(y)
504
 
505
+ return y, k, v
506
 
507
+ def kv_cache_forward(
508
+ self,
509
+ x: torch.Tensor,
510
+ k_cache: torch.Tensor | None = None,
511
+ v_cache: torch.Tensor | None = None,
512
+ ):
513
+ B, T, C = x.size()
514
+
515
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
516
+ q = q.view(B, T, self.n_head, self.head_dim) # (B, T, n_head, d)
517
+ k = k.view(B, T, self.n_head, self.head_dim)
518
+ v = v.view(B, T, self.n_head, self.head_dim)
519
+
520
+ if self.rotary is not None:
521
+ cos, sin = self.rotary(q) # cos/sin match (B, T, n_head, d)
522
+ q = apply_rotary_emb(q, cos, sin)
523
+ k = apply_rotary_emb(k, cos, sin)
524
+
525
+ q = q.transpose(1, 2) # (B, n_head, T, d)
526
+ k = k.transpose(1, 2)
527
+ v = v.transpose(1, 2)
528
+
529
+ if k_cache is not None:
530
+ k = torch.cat([k_cache, k], dim=2) # time dim grows
531
+ if v_cache is not None:
532
+ v = torch.cat([v_cache, v], dim=2)
533
+
534
+ y = F.scaled_dot_product_attention(
535
+ q, k, v, is_causal=True # PyTorch ≥ 2.1
536
+ ) # (B, n_head, T, d)
537
+
538
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
539
  y = self.c_proj(y)
540
 
541
  return y, k, v