optimize: use mx.fast.scaled_dot_product_attention
model.py
CHANGED
@@ -25,7 +25,7 @@ Usage:
 
     tokenizer = Tokenizer.from_file("tokenizer.json")
 
-    texts = ["Find the most relevant code snippet given the following query:\
+    texts = ["Find the most relevant code snippet given the following query:\nprint hello world"]
     embeddings = model.encode(texts, tokenizer)
     """
 
@@ -101,22 +101,14 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
 
-        # RoPE
+        # RoPE via mx.fast
         queries = mx.fast.rope(queries, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)
         keys = mx.fast.rope(keys, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)
 
-        # GQA
-        n_rep = self.n_heads // self.n_kv_heads
-        if n_rep > 1:
-            keys = mx.repeat(keys, n_rep, axis=1)
-            values = mx.repeat(values, n_rep, axis=1)
-
-        # Compute in float32 to avoid fp16 overflow
-        scores = (queries.astype(mx.float32) @ keys.astype(mx.float32).transpose(0, 1, 3, 2)) * self.scale
-        if mask is not None:
-            scores = scores + mask.astype(mx.float32)
-        attn = mx.softmax(scores, axis=-1)
-        output = attn.astype(values.dtype) @ values
+        # Scaled dot-product attention (handles GQA, precision, and masking internally)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, mask=mask.astype(queries.dtype) if mask is not None else None, scale=self.scale
+        )
 
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)
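The fused kernel accepts keys and values with fewer heads than the queries, which is why the explicit `mx.repeat` GQA expansion and the float32 score computation above can be dropped. A minimal standalone sketch (not part of this diff; shapes and tolerances are made up for illustration) comparing the removed manual path against the fused call:

```python
import mlx.core as mx

# Hypothetical shapes: 8 query heads sharing 2 kv heads (GQA), fp16 inputs.
B, L, n_heads, n_kv_heads, head_dim = 2, 16, 8, 2, 64
scale = head_dim ** -0.5
n_rep = n_heads // n_kv_heads

q = mx.random.normal((B, n_heads, L, head_dim)).astype(mx.float16)
k = mx.random.normal((B, n_kv_heads, L, head_dim)).astype(mx.float16)
v = mx.random.normal((B, n_kv_heads, L, head_dim)).astype(mx.float16)
mask = mx.where(mx.tril(mx.ones((L, L))) == 0, -1e4, 0.0)[None, None, :, :]

# Removed path: repeat kv heads, score in float32, add mask, softmax, project back.
k_r = mx.repeat(k, n_rep, axis=1)
v_r = mx.repeat(v, n_rep, axis=1)
scores = (q.astype(mx.float32) @ k_r.astype(mx.float32).transpose(0, 1, 3, 2)) * scale
scores = scores + mask.astype(mx.float32)
manual = mx.softmax(scores, axis=-1).astype(v_r.dtype) @ v_r

# New path: the fused kernel broadcasts the kv heads and applies the additive mask itself.
fused = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask.astype(q.dtype))

# The two paths may accumulate in different precisions, so compare with loose fp16 tolerances.
print(mx.allclose(manual, fused, atol=1e-2, rtol=1e-2))  # should print True
```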
@@ -188,7 +180,7 @@ class JinaCodeEmbeddingModel(nn.Module):
     ):
         batch_size, seq_len = input_ids.shape
 
-        # Causal mask
+        # Causal mask for SDPA: [1, 1, seq_len, seq_len]
         causal_mask = mx.tril(mx.ones((seq_len, seq_len)))
         causal_mask = mx.where(causal_mask == 0, -1e4, 0.0)
         causal_mask = causal_mask[None, None, :, :]
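Since `mx.fast.scaled_dot_product_attention` accepts an additive mask, the existing mask construction can stay as-is: the `[1, 1, seq_len, seq_len]` shape broadcasts over batch and heads inside the kernel, and `-1e4` (rather than `-inf`) keeps the mask representable when it is cast to float16. A small standalone sketch (made-up shapes) of this mask feeding the fused call:

```python
import mlx.core as mx

seq_len, n_heads, head_dim = 8, 4, 32

# Same construction as in the diff: 0 on allowed positions, -1e4 on future positions.
causal_mask = mx.tril(mx.ones((seq_len, seq_len)))
causal_mask = mx.where(causal_mask == 0, -1e4, 0.0)
causal_mask = causal_mask[None, None, :, :]  # [1, 1, L, L] broadcasts over (batch, heads)

q = mx.random.normal((1, n_heads, seq_len, head_dim))
k = mx.random.normal((1, n_heads, seq_len, head_dim))
v = mx.random.normal((1, n_heads, seq_len, head_dim))

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=head_dim ** -0.5, mask=causal_mask)
print(out.shape)  # (1, 4, 8, 32)
```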