Upload modeling_mnemosyne.py with huggingface_hub
Browse files- modeling_mnemosyne.py +1 -1
modeling_mnemosyne.py
CHANGED
|
@@ -105,7 +105,7 @@ class MnemosyneAttention(nn.Module):
|
|
| 105 |
k=self.k_proj(x).view(B,L,self.nkv,self.hd).transpose(1,2)
|
| 106 |
v=self.v_proj(x).view(B,L,self.nkv,self.hd).transpose(1,2)
|
| 107 |
cos,sin=self.rotary(q,pos_ids)
|
| 108 |
-
q=(q*cos)+(rotate_half(q)*sin)
|
| 109 |
if past_kv: k,v=torch.cat([past_kv[0].to(dt),k],2),torch.cat([past_kv[1].to(dt),v],2)
|
| 110 |
nkv=(k,v) if use_cache else None
|
| 111 |
k,v=k.repeat_interleave(self.ng,1),v.repeat_interleave(self.ng,1)
|
|
|
|
| 105 |
k=self.k_proj(x).view(B,L,self.nkv,self.hd).transpose(1,2)
|
| 106 |
v=self.v_proj(x).view(B,L,self.nkv,self.hd).transpose(1,2)
|
| 107 |
cos,sin=self.rotary(q,pos_ids)
|
| 108 |
+
q,k=(q*cos)+(rotate_half(q)*sin),(k*cos)+(rotate_half(k)*sin)
|
| 109 |
if past_kv: k,v=torch.cat([past_kv[0].to(dt),k],2),torch.cat([past_kv[1].to(dt),v],2)
|
| 110 |
nkv=(k,v) if use_cache else None
|
| 111 |
k,v=k.repeat_interleave(self.ng,1),v.repeat_interleave(self.ng,1)
|