Update modeling_auristream.py
Browse files — modeling_auristream.py (+1 −14)
modeling_auristream.py
CHANGED
|
@@ -448,13 +448,6 @@ class CausalSelfAttention(nn.Module):
|
|
| 448 |
if hasattr(config, 'use_rope') and not config.use_rope:
|
| 449 |
self.rotary = None
|
| 450 |
|
| 451 |
-
# Check if we are running on TPU
|
| 452 |
-
try:
|
| 453 |
-
import torch_xla.core.xla_model as xm
|
| 454 |
-
self.tpu = True
|
| 455 |
-
except ImportError:
|
| 456 |
-
self.tpu = False
|
| 457 |
-
|
| 458 |
def forward(self, x, return_kv=False, return_attn_maps=False):
|
| 459 |
|
| 460 |
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
|
@@ -470,13 +463,7 @@ class CausalSelfAttention(nn.Module):
|
|
| 470 |
q = apply_rotary_emb(q, cos, sin)
|
| 471 |
k = apply_rotary_emb(k, cos, sin)
|
| 472 |
|
| 473 |
-
if self.tpu:
|
| 474 |
-
from torch_xla.experimental.custom_kernel import flash_attention
|
| 475 |
-
q_norm = q / math.sqrt(k.size(-1))
|
| 476 |
-
y = flash_attention(
|
| 477 |
-
q_norm.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
|
| 478 |
-
causal=True, partition_spec=('fsdp', None, None, None))
|
| 479 |
-
elif not return_kv and not return_attn_maps:
|
| 480 |
y = F.scaled_dot_product_attention(
|
| 481 |
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
|
| 482 |
is_causal=True)
|
|
|
|
| 448 |
if hasattr(config, 'use_rope') and not config.use_rope:
|
| 449 |
self.rotary = None
|
| 450 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
def forward(self, x, return_kv=False, return_attn_maps=False):
|
| 452 |
|
| 453 |
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
|
|
|
| 463 |
q = apply_rotary_emb(q, cos, sin)
|
| 464 |
k = apply_rotary_emb(k, cos, sin)
|
| 465 |
|
| 466 |
+
if not return_kv and not return_attn_maps:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 467 |
y = F.scaled_dot_product_attention(
|
| 468 |
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
|
| 469 |
is_causal=True)
|