Update modeling_auristream.py
Browse files- modeling_auristream.py +63 -5
modeling_auristream.py
CHANGED
|
@@ -425,7 +425,37 @@ class CausalSelfAttention(nn.Module):
|
|
| 425 |
if hasattr(config, 'use_rope') and not config.use_rope:
|
| 426 |
self.rotary = None
|
| 427 |
|
| 428 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
|
| 430 |
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
| 431 |
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
|
@@ -440,13 +470,41 @@ class CausalSelfAttention(nn.Module):
|
|
| 440 |
q = apply_rotary_emb(q, cos, sin)
|
| 441 |
k = apply_rotary_emb(k, cos, sin)
|
| 442 |
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
|
| 447 |
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
|
|
|
| 448 |
# output projection
|
| 449 |
-
y = self.c_proj(y)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
return y
|
| 451 |
|
| 452 |
|
|
|
|
| 425 |
if hasattr(config, 'use_rope') and not config.use_rope:
|
| 426 |
self.rotary = None
|
| 427 |
|
| 428 |
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional rotary position embeddings.

    Query, key and value projections for all heads are fused into a single
    bias-free linear layer (``c_attn``); ``c_proj`` maps the concatenated
    head outputs back to the embedding dimension.  A flag (``self.tpu``)
    records whether ``torch_xla`` is importable so ``forward`` can pick a
    TPU flash-attention kernel.
    """

    def __init__(self, config):
        """Build the attention layer from *config*.

        Args:
            config: object exposing ``n_head`` and ``n_embd``; optionally
                ``rope_theta`` (RoPE base frequency, default ``500000``) and
                ``use_rope`` (set to a falsy value to disable rotary
                embeddings entirely).

        Raises:
            ValueError: if ``n_embd`` is not divisible by ``n_head``.
        """
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Validate divisibility BEFORE deriving the per-head width, and raise
        # instead of assert (asserts are stripped under `python -O`, which
        # would let `//` silently truncate a bad config).
        if self.n_embd % self.n_head != 0:
            raise ValueError(
                f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
            )
        self.head_dim = self.n_embd // self.n_head

        # key, query, value projections for all heads, fused in one matmul
        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

        # Rotary position embeddings.  When RoPE is explicitly disabled we
        # skip construction entirely (the previous code built a Rotary object
        # and then immediately discarded it).
        if getattr(config, 'use_rope', True):
            rope_theta = getattr(config, 'rope_theta', None)
            if rope_theta is None:
                rope_theta = 500000  # Llama-3-style default base frequency
            self.rotary = Rotary(self.head_dim, base=rope_theta)
        else:
            self.rotary = None

        # Detect TPU availability: torch_xla importing cleanly is the signal;
        # the imported module itself is not used here.
        try:
            import torch_xla.core.xla_model  # noqa: F401
            self.tpu = True
        except ImportError:
            self.tpu = False
|
| 458 |
+
def forward(self, x, return_kv=False, return_attn_maps=False):
|
| 459 |
|
| 460 |
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
| 461 |
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
|
|
|
| 470 |
q = apply_rotary_emb(q, cos, sin)
|
| 471 |
k = apply_rotary_emb(k, cos, sin)
|
| 472 |
|
| 473 |
+
if self.tpu and not return_kv and not return_attn_maps:
|
| 474 |
+
from torch_xla.experimental.custom_kernel import flash_attention
|
| 475 |
+
q_norm = q / math.sqrt(k.size(-1))
|
| 476 |
+
y = flash_attention(
|
| 477 |
+
q_norm.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
|
| 478 |
+
causal=True, partition_spec=('fsdp', None, None, None))
|
| 479 |
+
elif not return_kv and not return_attn_maps:
|
| 480 |
+
y = F.scaled_dot_product_attention(
|
| 481 |
+
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
|
| 482 |
+
is_causal=True)
|
| 483 |
+
else:
|
| 484 |
+
# manual implementation of attention
|
| 485 |
+
att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
|
| 486 |
+
mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).to(x.device)
|
| 487 |
+
mask = mask.view(1, 1, T, T)
|
| 488 |
+
masked_att = att.masked_fill(mask, float('-inf'))
|
| 489 |
+
# upcast to float32 for numerical stability, as per llama implementation
|
| 490 |
+
masked_att = F.softmax(masked_att, dim=-1, dtype=torch.float32).to(q.dtype)
|
| 491 |
+
masked_att = self.attn_dropout(masked_att)
|
| 492 |
+
# (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
| 493 |
+
y = torch.einsum('bnsk,bnkh->bnsh', masked_att, v)
|
| 494 |
|
| 495 |
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
| 496 |
+
|
| 497 |
# output projection
|
| 498 |
+
y = self.resid_dropout(self.c_proj(y))
|
| 499 |
+
|
| 500 |
+
# return attention maps if requested
|
| 501 |
+
if return_attn_maps:
|
| 502 |
+
return y, F.softmax(att, dim=-1)
|
| 503 |
+
|
| 504 |
+
# return key and value caches if requested
|
| 505 |
+
if return_kv:
|
| 506 |
+
return y, k, v
|
| 507 |
+
|
| 508 |
return y
|
| 509 |
|
| 510 |
|