klemenk committed on
Commit
40d7f7e
·
verified ·
1 Parent(s): 8d9662c

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +63 -5
modeling_auristream.py CHANGED
@@ -425,7 +425,37 @@ class CausalSelfAttention(nn.Module):
425
  if hasattr(config, 'use_rope') and not config.use_rope:
426
  self.rotary = None
427
 
428
- def forward(self, x, return_kv=False, ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
 
430
  B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
431
  # calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -440,13 +470,41 @@ class CausalSelfAttention(nn.Module):
440
  q = apply_rotary_emb(q, cos, sin)
441
  k = apply_rotary_emb(k, cos, sin)
442
 
443
- y = F.scaled_dot_product_attention(
444
- q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
445
- is_causal=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446
 
447
  y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
 
448
  # output projection
449
- y = self.c_proj(y)
 
 
 
 
 
 
 
 
 
450
  return y
451
 
452
 
 
425
  if hasattr(config, 'use_rope') and not config.use_rope:
426
  self.rotary = None
427
 
428
+
429
+ class CausalSelfAttention(nn.Module):
430
+
431
+ def __init__(self, config):
432
+ super().__init__()
433
+ self.n_head = config.n_head
434
+ self.n_embd = config.n_embd
435
+ self.head_dim = self.n_embd // self.n_head
436
+ assert self.n_embd % self.n_head == 0
437
+ # key, query, value projections for all heads, but in a batch
438
+ self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
439
+ # output projection
440
+ self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
441
+
442
+ rope_theta = 500000
443
+ if hasattr(config, 'rope_theta') and config.rope_theta is not None:
444
+ rope_theta = config.rope_theta
445
+
446
+ self.rotary = Rotary(self.head_dim, base=rope_theta)
447
+
448
+ if hasattr(config, 'use_rope') and not config.use_rope:
449
+ self.rotary = None
450
+
451
+ # Check if we are running on TPU
452
+ try:
453
+ import torch_xla.core.xla_model as xm
454
+ self.tpu = True
455
+ except ImportError:
456
+ self.tpu = False
457
+
458
+ def forward(self, x, return_kv=False, return_attn_maps=False):
459
 
460
  B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
461
  # calculate query, key, values for all heads in batch and move head forward to be the batch dim
 
470
  q = apply_rotary_emb(q, cos, sin)
471
  k = apply_rotary_emb(k, cos, sin)
472
 
473
+ if self.tpu and not return_kv and not return_attn_maps:
474
+ from torch_xla.experimental.custom_kernel import flash_attention
475
+ q_norm = q / math.sqrt(k.size(-1))
476
+ y = flash_attention(
477
+ q_norm.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
478
+ causal=True, partition_spec=('fsdp', None, None, None))
479
+ elif not return_kv and not return_attn_maps:
480
+ y = F.scaled_dot_product_attention(
481
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
482
+ is_causal=True)
483
+ else:
484
+ # manual implementation of attention
485
+ att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
486
+ mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).to(x.device)
487
+ mask = mask.view(1, 1, T, T)
488
+ masked_att = att.masked_fill(mask, float('-inf'))
489
+ # upcast to float32 for numerical stability, as per llama implementation
490
+ masked_att = F.softmax(masked_att, dim=-1, dtype=torch.float32).to(q.dtype)
491
+ masked_att = self.attn_dropout(masked_att)
492
+ # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
493
+ y = torch.einsum('bnsk,bnkh->bnsh', masked_att, v)
494
 
495
  y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
496
+
497
  # output projection
498
+ y = self.resid_dropout(self.c_proj(y))
499
+
500
+ # return attention maps if requested
501
+ if return_attn_maps:
502
+ return y, F.softmax(att, dim=-1)
503
+
504
+ # return key and value caches if requested
505
+ if return_kv:
506
+ return y, k, v
507
+
508
  return y
509
 
510