| """ |
| Jina Code Embeddings - MLX Implementation |
| |
| MLX port of Jina AI's code embedding models. |
| Based on Qwen2.5-Coder with last-token pooling. |
| |
| Features: |
| - Last-token pooling |
| - L2 normalization |
| - Task-specific instruction prefixes |
| - Matryoshka embedding dimensions |
| |
| Usage: |
| import mlx.core as mx |
| from tokenizers import Tokenizer |
| from model import JinaCodeEmbeddingModel |
| import json |
| |
| with open("config.json") as f: |
| config = json.load(f) |
| |
| model = JinaCodeEmbeddingModel(config) |
| weights = mx.load("model.safetensors") |
| model.load_weights(list(weights.items())) |
| |
| tokenizer = Tokenizer.from_file("tokenizer.json") |
| |
| texts = ["Find the most relevant code snippet given the following query:\nprint hello world"] |
| embeddings = model.encode(texts, tokenizer) |
| """ |
|
|
| from dataclasses import dataclass |
| from typing import Any, Dict, List, Optional, Union |
|
|
| import mlx.core as mx |
| import mlx.nn as nn |
|
|
|
|
# Task-specific instruction prefixes, prepended to input texts before
# tokenization (see JinaCodeEmbeddingModel.encode). Each retrieval task maps
# the two prompt roles ("query" / "passage") to its instruction string.
# NOTE: these strings are part of the model's expected input format — do not
# edit them without retraining/re-evaluating.
INSTRUCTION_CONFIG = {
    "nl2code": {
        "query": "Find the most relevant code snippet given the following query:\n",
        "passage": "Candidate code snippet:\n",
    },
    "qa": {
        "query": "Find the most relevant answer given the following question:\n",
        "passage": "Candidate answer:\n",
    },
    "code2code": {
        "query": "Find an equivalent code snippet given the following code snippet:\n",
        "passage": "Candidate code snippet:\n",
    },
    "code2nl": {
        "query": "Find the most relevant comment given the following code snippet:\n",
        "passage": "Candidate comment:\n",
    },
    "code2completion": {
        "query": "Find the most relevant completion given the following start of code snippet:\n",
        "passage": "Candidate completion:\n",
    },
}
|
|
|
|
@dataclass
class ModelArgs:
    """Backbone hyperparameters, mirroring the keys of a HF-style config.json."""

    hidden_size: int              # model width (embedding / residual stream dim)
    num_hidden_layers: int        # number of TransformerBlock layers
    intermediate_size: int        # MLP hidden dimension
    num_attention_heads: int      # query heads
    rms_norm_eps: float           # epsilon for RMSNorm layers
    vocab_size: int               # token embedding table size
    num_key_value_heads: int      # KV heads (< num_attention_heads => GQA)
    max_position_embeddings: int  # maximum sequence length the model supports
    rope_theta: float = 1000000.0  # RoPE frequency base
    tie_word_embeddings: bool = True  # unused here (no LM head in this embedding model)
|
|
|
|
class Attention(nn.Module):
    """Grouped-query self-attention with rotary position embeddings (Qwen2 style)."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        hidden = args.hidden_size
        self.n_heads = args.num_attention_heads
        self.n_kv_heads = args.num_key_value_heads
        self.head_dim = hidden // self.n_heads
        self.scale = self.head_dim ** -0.5
        self.rope_theta = args.rope_theta

        # Qwen2 convention: biased q/k/v projections, unbiased output projection.
        self.q_proj = nn.Linear(hidden, self.n_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(hidden, self.n_kv_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(hidden, self.n_kv_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.n_heads * self.head_dim, hidden, bias=False)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        batch, seq_len, _ = x.shape

        # Project, then split into per-head views: (B, heads, L, head_dim).
        def split_heads(t: mx.array, n_heads: int) -> mx.array:
            return t.reshape(batch, seq_len, n_heads, self.head_dim).transpose(0, 2, 1, 3)

        q = split_heads(self.q_proj(x), self.n_heads)
        k = split_heads(self.k_proj(x), self.n_kv_heads)
        v = split_heads(self.v_proj(x), self.n_kv_heads)

        # Rotary position embeddings on q/k; no KV cache here, so offset is 0.
        q = mx.fast.rope(q, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)
        k = mx.fast.rope(k, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)

        attn_mask = mask.astype(q.dtype) if mask is not None else None
        ctx = mx.fast.scaled_dot_product_attention(q, k, v, mask=attn_mask, scale=self.scale)

        # Merge heads back to (B, L, heads * head_dim) and project out.
        ctx = ctx.transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return self.o_proj(ctx)
|
|
|
|
class MLP(nn.Module):
    """SwiGLU feed-forward block: down(silu(gate(x)) * up(x))."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        gated = nn.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention then MLP, each behind a residual."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        # Attention sub-layer with residual connection.
        x = x + self.self_attn(self.input_layernorm(x), mask)
        # Feed-forward sub-layer with residual connection.
        return x + self.mlp(self.post_attention_layernorm(x))
|
|
|
|
class Qwen2Model(nn.Module):
    """Decoder-only Qwen2 backbone: token embeddings, N layers, final RMSNorm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None):
        hidden = self.embed_tokens(inputs)
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)
|
|
|
|
class JinaCodeEmbeddingModel(nn.Module):
    """Jina Code Embedding model with last-token pooling.

    Wraps a Qwen2 backbone and produces L2-normalized embeddings by taking
    the hidden state of the last non-padding token of each sequence.
    """

    def __init__(self, config: dict):
        """Build the backbone from an HF-style config dict (see ModelArgs)."""
        super().__init__()
        args = ModelArgs(
            hidden_size=config["hidden_size"],
            num_hidden_layers=config["num_hidden_layers"],
            intermediate_size=config["intermediate_size"],
            num_attention_heads=config["num_attention_heads"],
            rms_norm_eps=config["rms_norm_eps"],
            vocab_size=config["vocab_size"],
            num_key_value_heads=config["num_key_value_heads"],
            max_position_embeddings=config["max_position_embeddings"],
            rope_theta=config.get("rope_theta", 1000000.0),
            tie_word_embeddings=config.get("tie_word_embeddings", True),
        )
        self.model = Qwen2Model(args)
        self.config = config

    def __call__(
        self,
        input_ids: mx.array,
        attention_mask: Optional[mx.array] = None,
    ):
        """Run the backbone and return L2-normalized (batch, hidden) embeddings.

        Args:
            input_ids: (batch, seq_len) token ids, right-padded.
            attention_mask: (batch, seq_len) with 1 for real tokens, 0 for padding.
        """
        batch_size, seq_len = input_ids.shape

        # Additive causal mask: 0 where attendable, -1e4 where blocked.
        causal_mask = mx.tril(mx.ones((seq_len, seq_len)))
        causal_mask = mx.where(causal_mask == 0, -1e4, 0.0)
        causal_mask = causal_mask[None, None, :, :]

        if attention_mask is not None:
            # Additionally block attention to padding positions.
            padding_mask = mx.where(attention_mask == 0, -1e4, 0.0)
            padding_mask = padding_mask[:, None, None, :]
            mask = causal_mask + padding_mask
        else:
            mask = causal_mask

        hidden_states = self.model(input_ids, mask)

        # Last-token pooling: take the hidden state of the final non-padding
        # token of each sequence (position = number of real tokens - 1).
        if attention_mask is not None:
            sequence_lengths = mx.sum(attention_mask.astype(mx.int32), axis=1) - 1
            batch_indices = mx.arange(batch_size)
            embeddings = hidden_states[batch_indices, sequence_lengths]
        else:
            embeddings = hidden_states[:, -1, :]

        # L2-normalize so cosine similarity reduces to a dot product.
        norms = mx.linalg.norm(embeddings, axis=1, keepdims=True)
        return embeddings / norms

    def encode(
        self,
        texts: List[str],
        tokenizer,
        max_length: int = 8192,
        truncate_dim: Optional[int] = None,
        task: str = "nl2code",
        prompt_type: str = "query",
    ):
        """
        Encode texts to embeddings.

        Args:
            texts: List of input texts
            tokenizer: Tokenizer instance (from tokenizers library)
            max_length: Maximum sequence length
            truncate_dim: Optional Matryoshka dimension
            task: One of nl2code, qa, code2code, code2nl, code2completion
            prompt_type: "query" or "passage"

        Returns:
            (len(texts), dim) array of unit-norm embeddings, where dim is
            truncate_dim if given, else the model hidden size.
        """
        if not texts:
            # max() over token lengths below would raise ValueError on an
            # empty batch; return an empty (0, dim) result instead.
            dim = truncate_dim if truncate_dim is not None else self.config["hidden_size"]
            return mx.zeros((0, dim))

        # Unknown task/prompt_type silently yields no prefix (raw text).
        prefix = INSTRUCTION_CONFIG.get(task, {}).get(prompt_type, "")
        if prefix:
            texts = [prefix + t for t in texts]

        encodings = tokenizer.encode_batch(texts)
        max_len = min(max_length, max(len(enc.ids) for enc in encodings))

        # Right-pad every sequence to the common length with token id 0;
        # padding positions carry attention_mask == 0 and are masked out.
        input_ids = []
        attention_mask = []
        for encoding in encodings:
            ids = encoding.ids[:max_len]
            mask = encoding.attention_mask[:max_len]
            pad_len = max_len - len(ids)
            if pad_len > 0:
                ids = ids + [0] * pad_len
                mask = mask + [0] * pad_len
            input_ids.append(ids)
            attention_mask.append(mask)

        embeddings = self(mx.array(input_ids), mx.array(attention_mask))

        if truncate_dim is not None:
            # Matryoshka truncation: keep the leading dims, then re-normalize
            # so the shortened vectors remain unit-length.
            embeddings = embeddings[:, :truncate_dim]
            norms = mx.linalg.norm(embeddings, axis=1, keepdims=True)
            embeddings = embeddings / norms

        return embeddings
|
|