klemenk committed on
Commit
12d67f7
·
verified ·
1 Parent(s): 40d7f7e

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +1 -14
modeling_auristream.py CHANGED
@@ -448,13 +448,6 @@ class CausalSelfAttention(nn.Module):
448
  if hasattr(config, 'use_rope') and not config.use_rope:
449
  self.rotary = None
450
 
451
- # Check if we are running on TPU
452
- try:
453
- import torch_xla.core.xla_model as xm
454
- self.tpu = True
455
- except ImportError:
456
- self.tpu = False
457
-
458
  def forward(self, x, return_kv=False, return_attn_maps=False):
459
 
460
  B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
@@ -470,13 +463,7 @@ class CausalSelfAttention(nn.Module):
470
  q = apply_rotary_emb(q, cos, sin)
471
  k = apply_rotary_emb(k, cos, sin)
472
 
473
- if self.tpu and not return_kv and not return_attn_maps:
474
- from torch_xla.experimental.custom_kernel import flash_attention
475
- q_norm = q / math.sqrt(k.size(-1))
476
- y = flash_attention(
477
- q_norm.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
478
- causal=True, partition_spec=('fsdp', None, None, None))
479
- elif not return_kv and not return_attn_maps:
480
  y = F.scaled_dot_product_attention(
481
  q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
482
  is_causal=True)
 
448
  if hasattr(config, 'use_rope') and not config.use_rope:
449
  self.rotary = None
450
 
 
 
 
 
 
 
 
451
  def forward(self, x, return_kv=False, return_attn_maps=False):
452
 
453
  B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
 
463
  q = apply_rotary_emb(q, cos, sin)
464
  k = apply_rotary_emb(k, cos, sin)
465
 
466
+ if not return_kv and not return_attn_maps:
 
 
 
 
 
 
467
  y = F.scaled_dot_product_attention(
468
  q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
469
  is_causal=True)