amewebstudio committed on
Commit
0367c8a
·
verified ·
1 Parent(s): 87ae044

Upload modeling_mnemosyne.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_mnemosyne.py +1 -1
modeling_mnemosyne.py CHANGED
@@ -105,7 +105,7 @@ class MnemosyneAttention(nn.Module):
105
  k=self.k_proj(x).view(B,L,self.nkv,self.hd).transpose(1,2)
106
  v=self.v_proj(x).view(B,L,self.nkv,self.hd).transpose(1,2)
107
  cos,sin=self.rotary(q,pos_ids)
108
- q=(q*cos)+(rotate_half(q)*sin); k=(k*cos)+(rotate_half(k)*sin)
109
  if past_kv: k,v=torch.cat([past_kv[0].to(dt),k],2),torch.cat([past_kv[1].to(dt),v],2)
110
  nkv=(k,v) if use_cache else None
111
  k,v=k.repeat_interleave(self.ng,1),v.repeat_interleave(self.ng,1)
 
105
  k=self.k_proj(x).view(B,L,self.nkv,self.hd).transpose(1,2)
106
  v=self.v_proj(x).view(B,L,self.nkv,self.hd).transpose(1,2)
107
  cos,sin=self.rotary(q,pos_ids)
108
+ q,k=(q*cos)+(rotate_half(q)*sin),(k*cos)+(rotate_half(k)*sin)
109
  if past_kv: k,v=torch.cat([past_kv[0].to(dt),k],2),torch.cat([past_kv[1].to(dt),v],2)
110
  nkv=(k,v) if use_cache else None
111
  k,v=k.repeat_interleave(self.ng,1),v.repeat_interleave(self.ng,1)