Upload folder using huggingface_hub
Browse files
long_term_attention_gibbs.py
CHANGED
|
@@ -309,7 +309,7 @@ class LongTermAttention(nn.Module):
|
|
| 309 |
self.queries = query.view(batch_size,qlen,self.n_head,self.d_head).transpose(1,2) # [B,h,q,d]
|
| 310 |
self.keys = keys.view(batch_size,self.attn_num_basis,self.n_head,self.d_head).transpose(1,2) # [B,h,N,d]
|
| 311 |
self.values = values.view(batch_size,self.attn_num_basis,self.n_head,self.d_head).transpose(1,2) # [B,h,N,d]
|
| 312 |
-
context = self.expected_value(self.score) # Shape [
|
| 313 |
|
| 314 |
-
return context.contiguous().transpose(1,2).reshape(
|
| 315 |
|
|
|
|
| 309 |
self.queries = query.view(batch_size,qlen,self.n_head,self.d_head).transpose(1,2) # [B,h,q,d]
|
| 310 |
self.keys = keys.view(batch_size,self.attn_num_basis,self.n_head,self.d_head).transpose(1,2) # [B,h,N,d]
|
| 311 |
self.values = values.view(batch_size,self.attn_num_basis,self.n_head,self.d_head).transpose(1,2) # [B,h,N,d]
|
| 312 |
+
context = self.expected_value(self.score) # Shape [B, n_head, qlen, d_head]
|
| 313 |
|
| 314 |
+
return context.contiguous().transpose(1,2).reshape(batch_size, qlen, -1)
|
| 315 |
|