klemenk committed on
Commit
1688ee1
·
verified ·
1 Parent(s): 12946c0

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +3 -0
modeling_auristream.py CHANGED
@@ -445,6 +445,9 @@ class CausalSelfAttention(nn.Module):
445
  is_causal=True)
446
  else:
447
  # manual implementation of attention
 
 
 
448
  att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
449
  mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).to(x.device)
450
  mask = mask.view(1, 1, T, T)
 
445
  is_causal=True)
446
  else:
447
  # manual implementation of attention
448
+ q = q.transpose(1, 2)
449
+ k = k.transpose(1, 2)
450
+ v = v.transpose(1, 2)
451
  att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
452
  mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).to(x.device)
453
  mask = mask.view(1, 1, T, T)