klemenk committed
Commit 12946c0 · verified · 1 Parent(s): eb83fd9

Update modeling_auristream.py

Files changed (1):
  1. modeling_auristream.py +27 -5
modeling_auristream.py CHANGED
@@ -424,7 +424,7 @@ class CausalSelfAttention(nn.Module):
         if hasattr(config, 'use_rope') and not config.use_rope:
             self.rotary = None
 
-    def forward(self, x, return_kv=False, ):
+    def forward(self, x, return_kv=False, return_attn_maps=False):
 
         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -439,13 +439,35 @@ class CausalSelfAttention(nn.Module):
             q = apply_rotary_emb(q, cos, sin)
             k = apply_rotary_emb(k, cos, sin)
 
-        y = F.scaled_dot_product_attention(
-            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
-            is_causal=True)
+        if not return_kv and not return_attn_maps:
+            y = F.scaled_dot_product_attention(
+                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
+                is_causal=True)
+        else:
+            # manual implementation of attention
+            att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
+            mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).to(x.device)
+            mask = mask.view(1, 1, T, T)
+            masked_att = att.masked_fill(mask, float('-inf'))
+            # upcast to float32 for numerical stability, as per llama implementation
+            masked_att = F.softmax(masked_att, dim=-1, dtype=torch.float32).to(q.dtype)
+            masked_att = self.attn_dropout(masked_att)
+            # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+            y = torch.einsum('bnsk,bnkh->bnsh', masked_att, v)
 
         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+
         # output projection
-        y = self.c_proj(y)
+        y = self.resid_dropout(self.c_proj(y))
+
+        # return attention maps if requested
+        if return_attn_maps:
+            return y, F.softmax(att, dim=-1)
+
+        # return key and value caches if requested
+        if return_kv:
+            return y, k, v
+
         return y
 
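For context, the else-branch added here computes the same causal attention as the fused F.scaled_dot_product_attention call; it just materializes the (B, nh, T, T) score matrix so the attention maps can be returned. A minimal standalone sketch of that equivalence (plain PyTorch with made-up shapes, dropout omitted; this is not the repository's module):

import math
import torch
import torch.nn.functional as F

# Hypothetical shapes for illustration only: batch, heads, sequence length, head size.
B, nh, T, hs = 2, 4, 16, 32
q = torch.randn(B, nh, T, hs)
k = torch.randn(B, nh, T, hs)
v = torch.randn(B, nh, T, hs)

# Fused-kernel path (taken when neither return_kv nor return_attn_maps is set).
y_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Manual path mirroring the new else-branch: raw scores, causal mask, float32 softmax, value mix.
att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).view(1, 1, T, T)
masked_att = att.masked_fill(mask, float('-inf'))
masked_att = F.softmax(masked_att, dim=-1, dtype=torch.float32).to(q.dtype)
y_manual = torch.einsum('bnsk,bnkh->bnsh', masked_att, v)

print(torch.allclose(y_sdpa, y_manual, atol=1e-5))  # True up to floating-point tolerance

Call sites would then use something like y, attn = attn_layer(x, return_attn_maps=True) or y, k, v = attn_layer(x, return_kv=True) (the variable name attn_layer is assumed); per the diff, the returned attention maps are the softmax of the raw, unmasked scores.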