ThingsAI commited on
Commit
c58e770
·
verified ·
1 Parent(s): 960cc50

Upload modeling_quark.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_quark.py +1 -1
modeling_quark.py CHANGED
@@ -39,7 +39,7 @@ class QuarkAttention(nn.Module):
39
  q=self.q_proj(x).view(B,T,self.nh,self.hd).transpose(1,2)
40
  k=self.k_proj(x).view(B,T,self.nkv,self.hd).transpose(1,2)
41
  v=self.v_proj(x).view(B,T,self.nkv,self.hd).transpose(1,2)
42
- q,k=self.rope(q,k)
43
  if self.ng>1: k=k.repeat_interleave(self.ng,1); v=v.repeat_interleave(self.ng,1)
44
  return self.o_proj(F.scaled_dot_product_attention(q,k,v,is_causal=True).transpose(1,2).contiguous().view(B,T,-1))
45
 
 
39
  q=self.q_proj(x).view(B,T,self.nh,self.hd).transpose(1,2)
40
  k=self.k_proj(x).view(B,T,self.nkv,self.hd).transpose(1,2)
41
  v=self.v_proj(x).view(B,T,self.nkv,self.hd).transpose(1,2)
42
+ q,k=self.rope(q,k); q,k=q.to(v.dtype),k.to(v.dtype)
43
  if self.ng>1: k=k.repeat_interleave(self.ng,1); v=v.repeat_interleave(self.ng,1)
44
  return self.o_proj(F.scaled_dot_product_attention(q,k,v,is_causal=True).transpose(1,2).contiguous().view(B,T,-1))
45