Rihong committed on
Commit
5cf25a4
·
verified ·
1 Parent(s): 32d787e

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. long_term_attention_gibbs.py +2 -2
long_term_attention_gibbs.py CHANGED
@@ -309,7 +309,7 @@ class LongTermAttention(nn.Module):
309
  self.queries = query.view(batch_size,qlen,self.n_head,self.d_head).transpose(1,2) # [B,h,q,d]
310
  self.keys = keys.view(batch_size,self.attn_num_basis,self.n_head,self.d_head).transpose(1,2) # [B,h,N,d]
311
  self.values = values.view(batch_size,self.attn_num_basis,self.n_head,self.d_head).transpose(1,2) # [B, h, q, N]
312
- context = self.expected_value(self.score) # Shape [1, 32, 768]
313
 
314
- return context.contiguous().transpose(1,2).reshape(1, qlen, -1)
315
 
 
309
  self.queries = query.view(batch_size,qlen,self.n_head,self.d_head).transpose(1,2) # [B,h,q,d]
310
  self.keys = keys.view(batch_size,self.attn_num_basis,self.n_head,self.d_head).transpose(1,2) # [B,h,N,d]
311
  self.values = values.view(batch_size,self.attn_num_basis,self.n_head,self.d_head).transpose(1,2) # [B, h, q, N]
312
+ context = self.expected_value(self.score) # Shape [B, n_head, qlen, d_head]
313
 
314
+ return context.contiguous().transpose(1,2).reshape(batch_size, qlen, -1)
315