Update modeling_auristream.py
Browse files
- modeling_auristream.py +3 -0
modeling_auristream.py
CHANGED
|
@@ -445,6 +445,9 @@ class CausalSelfAttention(nn.Module):
|
|
| 445 |
is_causal=True)
|
| 446 |
else:
|
| 447 |
# manual implementation of attention
|
|
|
|
|
|
|
|
|
|
| 448 |
att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
|
| 449 |
mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).to(x.device)
|
| 450 |
mask = mask.view(1, 1, T, T)
|
|
|
|
| 445 |
is_causal=True)
|
| 446 |
else:
|
| 447 |
# manual implementation of attention
|
| 448 |
+
q = q.transpose(1, 2)
|
| 449 |
+
k = k.transpose(1, 2)
|
| 450 |
+
v = v.transpose(1, 2)
|
| 451 |
att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
|
| 452 |
mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).to(x.device)
|
| 453 |
mask = mask.view(1, 1, T, T)
|