hanxiao committed on
Commit 43687e3 · verified · 1 Parent(s): 643acb3

optimize: use mx.fast.scaled_dot_product_attention

Files changed (1)
  1. model.py +7 -15
model.py CHANGED
@@ -25,7 +25,7 @@ Usage:
 
     tokenizer = Tokenizer.from_file("tokenizer.json")
 
-    texts = ["Find the most relevant code snippet given the following query:\\nprint hello world"]
+    texts = ["Find the most relevant code snippet given the following query:\nprint hello world"]
     embeddings = model.encode(texts, tokenizer)
 """
 
@@ -101,22 +101,14 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
 
-        # RoPE - rotate_half convention (traditional=False)
+        # RoPE via mx.fast
         queries = mx.fast.rope(queries, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)
         keys = mx.fast.rope(keys, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)
 
-        # GQA: repeat KV heads
-        if self.n_kv_heads != self.n_heads:
-            n_rep = self.n_heads // self.n_kv_heads
-            keys = mx.repeat(keys, n_rep, axis=1)
-            values = mx.repeat(values, n_rep, axis=1)
-
-        # Compute in float32 to avoid fp16 overflow
-        scores = (queries.astype(mx.float32) @ keys.astype(mx.float32).transpose(0, 1, 3, 2)) * self.scale
-        if mask is not None:
-            scores = scores + mask.astype(mx.float32)
-        attn = mx.softmax(scores, axis=-1)
-        output = attn.astype(values.dtype) @ values
+        # Scaled dot-product attention (handles GQA, precision, and masking internally)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, mask=mask.astype(queries.dtype) if mask is not None else None, scale=self.scale
+        )
 
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)
@@ -188,7 +180,7 @@ class JinaCodeEmbeddingModel(nn.Module):
     ):
         batch_size, seq_len = input_ids.shape
 
-        # Causal mask
+        # Causal mask for SDPA: [1, 1, seq_len, seq_len]
         causal_mask = mx.tril(mx.ones((seq_len, seq_len)))
         causal_mask = mx.where(causal_mask == 0, -1e4, 0.0)
         causal_mask = causal_mask[None, None, :, :]
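
In effect, the commit replaces the hand-rolled attention path (repeat KV heads for GQA, float32 score computation, explicit softmax, additive mask) with the fused mx.fast.scaled_dot_product_attention kernel. Below is a minimal standalone sketch of that equivalence, not part of the commit, using made-up shapes (batch 2, 8 query heads, 2 KV heads, seq_len 16, head_dim 64) rather than the model's real configuration:

import mlx.core as mx

# Hypothetical toy shapes, not the model's actual config
B, n_heads, n_kv_heads, L, D = 2, 8, 2, 16, 64
scale = D ** -0.5

q = mx.random.normal((B, n_heads, L, D))
k = mx.random.normal((B, n_kv_heads, L, D))
v = mx.random.normal((B, n_kv_heads, L, D))

# Additive causal mask built the same way as in the diff: [1, 1, L, L]
mask = mx.where(mx.tril(mx.ones((L, L))) == 0, -1e4, 0.0)[None, None, :, :]

# Old path: repeat KV heads for GQA, compute scores in float32, explicit softmax
n_rep = n_heads // n_kv_heads
k_rep = mx.repeat(k, n_rep, axis=1)
v_rep = mx.repeat(v, n_rep, axis=1)
scores = (q.astype(mx.float32) @ k_rep.astype(mx.float32).transpose(0, 1, 3, 2)) * scale
scores = scores + mask.astype(mx.float32)
manual = mx.softmax(scores, axis=-1).astype(v.dtype) @ v_rep

# New path: one fused kernel call; GQA head broadcasting and masking are handled internally
fused = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask.astype(q.dtype))

print(mx.abs(manual - fused).max())  # should be ~0 up to numerical precision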