Update modeling_auristream.py
Browse files- modeling_auristream.py +27 -5
modeling_auristream.py
CHANGED
|
@@ -424,7 +424,7 @@ class CausalSelfAttention(nn.Module):
|
|
| 424 |
if hasattr(config, 'use_rope') and not config.use_rope:
|
| 425 |
self.rotary = None
|
| 426 |
|
| 427 |
-
def forward(self, x, return_kv=False, ):
|
| 428 |
|
| 429 |
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
| 430 |
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
|
@@ -439,13 +439,35 @@ class CausalSelfAttention(nn.Module):
|
|
| 439 |
q = apply_rotary_emb(q, cos, sin)
|
| 440 |
k = apply_rotary_emb(k, cos, sin)
|
| 441 |
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
|
| 446 |
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
|
|
|
| 447 |
# output projection
|
| 448 |
-
y = self.c_proj(y)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 449 |
return y
|
| 450 |
|
| 451 |
|
|
|
|
| 424 |
if hasattr(config, 'use_rope') and not config.use_rope:
|
| 425 |
self.rotary = None
|
| 426 |
|
| 427 |
+
def forward(self, x, return_kv=False, return_attn_maps=False):
|
| 428 |
|
| 429 |
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
| 430 |
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
|
|
|
| 439 |
q = apply_rotary_emb(q, cos, sin)
|
| 440 |
k = apply_rotary_emb(k, cos, sin)
|
| 441 |
|
| 442 |
+
if not return_kv and not return_attn_maps:
|
| 443 |
+
y = F.scaled_dot_product_attention(
|
| 444 |
+
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
|
| 445 |
+
is_causal=True)
|
| 446 |
+
else:
|
| 447 |
+
# manual implementation of attention
|
| 448 |
+
att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
|
| 449 |
+
mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).to(x.device)
|
| 450 |
+
mask = mask.view(1, 1, T, T)
|
| 451 |
+
masked_att = att.masked_fill(mask, float('-inf'))
|
| 452 |
+
# upcast to float32 for numerical stability, as per llama implementation
|
| 453 |
+
masked_att = F.softmax(masked_att, dim=-1, dtype=torch.float32).to(q.dtype)
|
| 454 |
+
masked_att = self.attn_dropout(masked_att)
|
| 455 |
+
# (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
| 456 |
+
y = torch.einsum('bnsk,bnkh->bnsh', masked_att, v)
|
| 457 |
|
| 458 |
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
| 459 |
+
|
| 460 |
# output projection
|
| 461 |
+
y = self.resid_dropout(self.c_proj(y))
|
| 462 |
+
|
| 463 |
+
# return attention maps if requested
|
| 464 |
+
if return_attn_maps:
|
| 465 |
+
return y, F.softmax(att, dim=-1)
|
| 466 |
+
|
| 467 |
+
# return key and value caches if requested
|
| 468 |
+
if return_kv:
|
| 469 |
+
return y, k, v
|
| 470 |
+
|
| 471 |
return y
|
| 472 |
|
| 473 |
|