klemenk committed on
Commit
3845e7a
·
verified ·
1 Parent(s): 12d67f7

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +32 -27
modeling_auristream.py CHANGED
@@ -141,7 +141,6 @@ class AuriStream(PreTrainedModel):
141
  top_k: int = 500, top_p: float = 0.5) -> torch.LongTensor:
142
  """
143
  Samples an integer from the distribution of logits
144
-
145
  Parameters:
146
  logits (torch.FloatTensor): The logits of the distribution
147
  temp (float): The temperature of the sampling, if 0.0, then argmax is used
@@ -403,29 +402,6 @@ class Block(nn.Module):
403
  return x
404
 
405
 
406
- class CausalSelfAttention(nn.Module):
407
-
408
- def __init__(self, config):
409
- super().__init__()
410
- self.n_head = config.n_head
411
- self.n_embd = config.n_embd
412
- self.head_dim = self.n_embd // self.n_head
413
- assert self.n_embd % self.n_head == 0
414
- # key, query, value projections for all heads, but in a batch
415
- self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
416
- # output projection
417
- self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
418
-
419
- rope_theta = 500000
420
- if hasattr(config, 'rope_theta') and config.rope_theta is not None:
421
- rope_theta = config.rope_theta
422
-
423
- self.rotary = Rotary(self.head_dim, base=rope_theta)
424
-
425
- if hasattr(config, 'use_rope') and not config.use_rope:
426
- self.rotary = None
427
-
428
-
429
  class CausalSelfAttention(nn.Module):
430
 
431
  def __init__(self, config):
@@ -469,20 +445,22 @@ class CausalSelfAttention(nn.Module):
469
  is_causal=True)
470
  else:
471
  # manual implementation of attention
 
 
 
472
  att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
473
  mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).to(x.device)
474
  mask = mask.view(1, 1, T, T)
475
  masked_att = att.masked_fill(mask, float('-inf'))
476
  # upcast to float32 for numerical stability, as per llama implementation
477
  masked_att = F.softmax(masked_att, dim=-1, dtype=torch.float32).to(q.dtype)
478
- masked_att = self.attn_dropout(masked_att)
479
  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
480
  y = torch.einsum('bnsk,bnkh->bnsh', masked_att, v)
481
 
482
  y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
483
 
484
  # output projection
485
- y = self.resid_dropout(self.c_proj(y))
486
 
487
  # return attention maps if requested
488
  if return_attn_maps:
@@ -494,6 +472,33 @@ class CausalSelfAttention(nn.Module):
494
 
495
  return y
496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497
 
498
  class MLP(nn.Module):
499
 
@@ -560,4 +565,4 @@ class RMSNorm(nn.Module):
560
  output = self._norm(x.float()).type_as(x)
561
  if self.weight is not None:
562
  return output * self.weight
563
- return output
 
141
  top_k: int = 500, top_p: float = 0.5) -> torch.LongTensor:
142
  """
143
  Samples an integer from the distribution of logits
 
144
  Parameters:
145
  logits (torch.FloatTensor): The logits of the distribution
146
  temp (float): The temperature of the sampling, if 0.0, then argmax is used
 
402
  return x
403
 
404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  class CausalSelfAttention(nn.Module):
406
 
407
  def __init__(self, config):
 
445
  is_causal=True)
446
  else:
447
  # manual implementation of attention
448
+ q = q.transpose(1, 2)
449
+ k = k.transpose(1, 2)
450
+ v = v.transpose(1, 2)
451
  att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
452
  mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).to(x.device)
453
  mask = mask.view(1, 1, T, T)
454
  masked_att = att.masked_fill(mask, float('-inf'))
455
  # upcast to float32 for numerical stability, as per llama implementation
456
  masked_att = F.softmax(masked_att, dim=-1, dtype=torch.float32).to(q.dtype)
 
457
  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
458
  y = torch.einsum('bnsk,bnkh->bnsh', masked_att, v)
459
 
460
  y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
461
 
462
  # output projection
463
+ y = self.c_proj(y)
464
 
465
  # return attention maps if requested
466
  if return_attn_maps:
 
472
 
473
  return y
474
 
475
def kv_cache_forward(self, x, k_cache=None, v_cache=None):
    """Attention forward pass with an explicit key/value cache.

    Parameters:
        x (torch.FloatTensor): Input activations of shape (B, T, C).
        k_cache (torch.FloatTensor | None): Cached keys of shape
            (B, n_head, S_past, head_dim), or None on the first call.
        v_cache (torch.FloatTensor | None): Cached values, same shape
            convention as ``k_cache``.

    Returns:
        tuple: ``(y, k, v)`` where ``y`` is the attended output of shape
        (B, T, C), and ``k``/``v`` are the updated caches of shape
        (B, n_head, S_past + T, head_dim) to feed into the next call.
    """
    B, T, C = x.size()  # batch size, sequence length, embedding dim (n_embd)

    # Project to query/key/value for all heads in one matmul, then split.
    q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
    q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
    k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
    v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

    # NOTE(review): no rotary embedding is applied here even though
    # __init__ builds self.rotary -- confirm callers only use this path
    # when self.rotary is None, or apply it to q/k before caching.

    # Prepend cached keys/values along the time dimension.
    if k_cache is not None:
        k = torch.cat((k_cache, k), dim=2)
    if v_cache is not None:
        v = torch.cat((v_cache, v), dim=2)

    # Manual attention. S = total key length (cache + new tokens).
    S = k.size(2)
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

    # FIX: apply a causal mask. The original attended to *all* key
    # positions, which is wrong whenever T > 1 (e.g. prompt prefill):
    # query i must not see keys past position (S - T) + i. For the common
    # single-token decode step (T == 1) the mask is empty, so behavior is
    # unchanged there.
    if T > 1:
        causal = torch.triu(
            torch.ones(T, S, dtype=torch.bool, device=x.device),
            diagonal=S - T + 1,
        ).view(1, 1, T, S)
        att = att.masked_fill(causal, float('-inf'))

    # Upcast to float32 for a numerically stable softmax, matching the
    # non-cached forward path.
    att = F.softmax(att, dim=-1, dtype=torch.float32).to(q.dtype)
    y = att @ v  # (B, nh, T, S) x (B, nh, S, hs) -> (B, nh, T, hs)

    y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble head outputs side by side

    # Output projection back to the residual stream.
    y = self.c_proj(y)

    return y, k, v
501
+
502
 
503
  class MLP(nn.Module):
504
 
 
565
  output = self._norm(x.float()).type_as(x)
566
  if self.weight is not None:
567
  return output * self.weight
568
+ return output