Update modeling_auristream.py
Browse files- modeling_auristream.py +0 -1
modeling_auristream.py
CHANGED
|
@@ -454,7 +454,6 @@ class CausalSelfAttention(nn.Module):
|
|
| 454 |
masked_att = att.masked_fill(mask, float('-inf'))
|
| 455 |
# upcast to float32 for numerical stability, as per llama implementation
|
| 456 |
masked_att = F.softmax(masked_att, dim=-1, dtype=torch.float32).to(q.dtype)
|
| 457 |
-
masked_att = self.attn_dropout(masked_att)
|
| 458 |
# (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
| 459 |
y = torch.einsum('bnsk,bnkh->bnsh', masked_att, v)
|
| 460 |
|
|
|
|
| 454 |
masked_att = att.masked_fill(mask, float('-inf'))
|
| 455 |
# upcast to float32 for numerical stability, as per llama implementation
|
| 456 |
masked_att = F.softmax(masked_att, dim=-1, dtype=torch.float32).to(q.dtype)
|
|
|
|
| 457 |
# (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
| 458 |
y = torch.einsum('bnsk,bnkh->bnsh', masked_att, v)
|
| 459 |
|