# Upstream: hanxiao — commit 43687e3 (verified)
# optimize: use mx.fast.scaled_dot_product_attention
"""
Jina Code Embeddings - MLX Implementation
MLX port of Jina AI's code embedding models.
Based on Qwen2.5-Coder with last-token pooling.
Features:
- Last-token pooling
- L2 normalization
- Task-specific instruction prefixes
- Matryoshka embedding dimensions
Usage:
import mlx.core as mx
from tokenizers import Tokenizer
from model import JinaCodeEmbeddingModel
import json
with open("config.json") as f:
config = json.load(f)
model = JinaCodeEmbeddingModel(config)
weights = mx.load("model.safetensors")
model.load_weights(list(weights.items()))
tokenizer = Tokenizer.from_file("tokenizer.json")
texts = ["Find the most relevant code snippet given the following query:\nprint hello world"]
embeddings = model.encode(texts, tokenizer)
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import mlx.core as mx
import mlx.nn as nn
# Task-specific instruction prefixes prepended to inputs before tokenization.
# Keyed by task name, then by prompt type ("query" vs "passage") — see
# JinaCodeEmbeddingModel.encode, which looks up
# INSTRUCTION_CONFIG[task][prompt_type] and falls back to "" when missing.
INSTRUCTION_CONFIG = {
    "nl2code": {
        "query": "Find the most relevant code snippet given the following query:\n",
        "passage": "Candidate code snippet:\n",
    },
    "qa": {
        "query": "Find the most relevant answer given the following question:\n",
        "passage": "Candidate answer:\n",
    },
    "code2code": {
        "query": "Find an equivalent code snippet given the following code snippet:\n",
        "passage": "Candidate code snippet:\n",
    },
    "code2nl": {
        "query": "Find the most relevant comment given the following code snippet:\n",
        "passage": "Candidate comment:\n",
    },
    "code2completion": {
        "query": "Find the most relevant completion given the following start of code snippet:\n",
        "passage": "Candidate completion:\n",
    },
}
@dataclass
class ModelArgs:
    """Architecture hyperparameters for the Qwen2 backbone.

    Populated from config.json by JinaCodeEmbeddingModel.__init__.
    """

    hidden_size: int            # model (embedding) dimension
    num_hidden_layers: int      # number of TransformerBlock layers
    intermediate_size: int      # MLP hidden dimension
    num_attention_heads: int    # query heads
    rms_norm_eps: float         # epsilon for all RMSNorm layers
    vocab_size: int             # token embedding rows
    num_key_value_heads: int    # KV heads (GQA when < num_attention_heads)
    max_position_embeddings: int  # from config; not otherwise read here
    rope_theta: float = 1000000.0  # RoPE frequency base
    tie_word_embeddings: bool = True  # from config; no LM head in this file
class Attention(nn.Module):
    """Multi-head self-attention with RoPE, Qwen2 layout (biased q/k/v, bias-free o).

    Supports grouped-query attention: n_kv_heads may be smaller than n_heads;
    mx.fast.scaled_dot_product_attention handles the KV-head broadcast.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.hidden_size
        self.n_heads = args.num_attention_heads
        self.n_kv_heads = args.num_key_value_heads
        self.head_dim = dim // self.n_heads
        self.scale = self.head_dim ** -0.5
        self.rope_theta = args.rope_theta
        # Qwen2 has bias on q/k/v but not o
        self.q_proj = nn.Linear(dim, self.n_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.n_heads * self.head_dim, dim, bias=False)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        """Apply self-attention; mask (if given) is an additive float mask."""
        batch, seq_len, _ = x.shape

        def to_heads(t: mx.array, n_heads: int) -> mx.array:
            # [B, L, H*Dh] -> [B, H, L, Dh]
            return t.reshape(batch, seq_len, n_heads, self.head_dim).transpose(0, 2, 1, 3)

        q = to_heads(self.q_proj(x), self.n_heads)
        k = to_heads(self.k_proj(x), self.n_kv_heads)
        v = to_heads(self.v_proj(x), self.n_kv_heads)

        # Rotary position embeddings via mx.fast (no KV cache, so offset=0).
        q = mx.fast.rope(q, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)
        k = mx.fast.rope(k, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)

        # Fused SDPA kernel; cast the mask to the query dtype it expects.
        additive_mask = None if mask is None else mask.astype(q.dtype)
        attn = mx.fast.scaled_dot_product_attention(q, k, v, mask=additive_mask, scale=self.scale)

        # [B, H, L, Dh] -> [B, L, H*Dh]
        attn = attn.transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return self.o_proj(attn)
class MLP(nn.Module):
    """Qwen2 gated feed-forward network: down(silu(gate(x)) * up(x))."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        # SwiGLU-style gating: SiLU-activated gate modulates the up projection.
        gated = nn.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention then MLP, each wrapped in a residual."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        # Residual around normalized attention, then around normalized MLP.
        x = x + self.self_attn(self.input_layernorm(x), mask)
        return x + self.mlp(self.post_attention_layernorm(x))
class Qwen2Model(nn.Module):
    """Qwen2 backbone: token embedding, transformer stack, final RMSNorm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None):
        """Return final-layer hidden states, [B, L, hidden_size]."""
        hidden = self.embed_tokens(inputs)
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)
class JinaCodeEmbeddingModel(nn.Module):
    """Jina Code Embedding model: Qwen2 backbone + last-token pooling + L2 norm."""

    def __init__(self, config: dict):
        super().__init__()
        args = ModelArgs(
            hidden_size=config["hidden_size"],
            num_hidden_layers=config["num_hidden_layers"],
            intermediate_size=config["intermediate_size"],
            num_attention_heads=config["num_attention_heads"],
            rms_norm_eps=config["rms_norm_eps"],
            vocab_size=config["vocab_size"],
            num_key_value_heads=config["num_key_value_heads"],
            max_position_embeddings=config["max_position_embeddings"],
            rope_theta=config.get("rope_theta", 1000000.0),
            tie_word_embeddings=config.get("tie_word_embeddings", True),
        )
        self.model = Qwen2Model(args)
        self.config = config

    def __call__(
        self,
        input_ids: mx.array,
        attention_mask: Optional[mx.array] = None,
    ):
        """Embed a tokenized batch.

        Args:
            input_ids: [B, L] token ids.
            attention_mask: optional [B, L] of 1 (real) / 0 (padding).

        Returns:
            [B, hidden_size] unit-norm embeddings.
        """
        batch_size, seq_len = input_ids.shape

        # Additive causal mask for SDPA, shape [1, 1, L, L].
        # -1e4 (not -inf) keeps fp16 arithmetic finite.
        causal_mask = mx.tril(mx.ones((seq_len, seq_len)))
        causal_mask = mx.where(causal_mask == 0, -1e4, 0.0)[None, None, :, :]

        if attention_mask is not None:
            # Block attention *to* pad positions: broadcast over query axis.
            padding_mask = mx.where(attention_mask == 0, -1e4, 0.0)[:, None, None, :]
            mask = causal_mask + padding_mask
        else:
            mask = causal_mask

        hidden_states = self.model(input_ids, mask)

        # Last-token pooling: take the hidden state of the final *non-pad*
        # token of each sequence (assumes right padding).
        if attention_mask is not None:
            sequence_lengths = mx.sum(attention_mask.astype(mx.int32), axis=1) - 1
            embeddings = hidden_states[mx.arange(batch_size), sequence_lengths]
        else:
            embeddings = hidden_states[:, -1, :]

        # L2 normalize
        norms = mx.linalg.norm(embeddings, axis=1, keepdims=True)
        return embeddings / norms

    def encode(
        self,
        texts: List[str],
        tokenizer,
        max_length: int = 8192,
        truncate_dim: Optional[int] = None,
        task: str = "nl2code",
        prompt_type: str = "query",
    ):
        """
        Encode texts to embeddings.

        Args:
            texts: List of input texts
            tokenizer: Tokenizer instance (from tokenizers library)
            max_length: Maximum sequence length
            truncate_dim: Optional Matryoshka dimension
            task: One of nl2code, qa, code2code, code2nl, code2completion
            prompt_type: "query" or "passage"

        Returns:
            [len(texts), D] unit-norm embeddings, D = truncate_dim or hidden_size.
        """
        # FIX: an empty batch previously crashed with ValueError from
        # max() over an empty sequence; return an empty result instead.
        if not texts:
            dim = truncate_dim if truncate_dim is not None else self.config["hidden_size"]
            return mx.zeros((0, dim))

        prefix = INSTRUCTION_CONFIG.get(task, {}).get(prompt_type, "")
        if prefix:
            texts = [prefix + t for t in texts]

        encodings = tokenizer.encode_batch(texts)
        # Pad to the longest sequence in the batch, capped at max_length.
        max_len = min(max_length, max(len(enc.ids) for enc in encodings))

        input_ids = []
        attention_mask = []
        for encoding in encodings:
            ids = encoding.ids[:max_len]
            mask = encoding.attention_mask[:max_len]
            pad_len = max_len - len(ids)
            if pad_len > 0:
                # NOTE(review): pad id 0 is masked out via attention_mask, but
                # confirm it matches the tokenizer's configured pad token.
                ids = ids + [0] * pad_len
                mask = mask + [0] * pad_len
            input_ids.append(ids)
            attention_mask.append(mask)

        embeddings = self(mx.array(input_ids), mx.array(attention_mask))

        # FIX: renormalize only after Matryoshka truncation — __call__ already
        # returns unit-norm vectors, so the untruncated path needs no second pass.
        if truncate_dim is not None:
            embeddings = embeddings[:, :truncate_dim]
            norms = mx.linalg.norm(embeddings, axis=1, keepdims=True)
            embeddings = embeddings / norms
        return embeddings