"""Jina Code Embeddings - MLX implementation.

MLX port of Jina AI's code embedding models, based on Qwen2.5-Coder with
last-token pooling.

Features:
    - Last-token pooling
    - L2 normalization
    - Task-specific instruction prefixes
    - Matryoshka embedding dimensions

Usage:
    import mlx.core as mx
    from tokenizers import Tokenizer
    from model import JinaCodeEmbeddingModel
    import json

    with open("config.json") as f:
        config = json.load(f)
    model = JinaCodeEmbeddingModel(config)
    weights = mx.load("model.safetensors")
    model.load_weights(list(weights.items()))
    tokenizer = Tokenizer.from_file("tokenizer.json")

    texts = ["Find the most relevant code snippet given the following query:\nprint hello world"]
    embeddings = model.encode(texts, tokenizer)
"""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import mlx.core as mx
import mlx.nn as nn

# Task-specific instruction prefixes prepended to the raw text before
# tokenization. Keys are task names; each task has a "query" and a
# "passage" variant.
INSTRUCTION_CONFIG = {
    "nl2code": {
        "query": "Find the most relevant code snippet given the following query:\n",
        "passage": "Candidate code snippet:\n",
    },
    "qa": {
        "query": "Find the most relevant answer given the following question:\n",
        "passage": "Candidate answer:\n",
    },
    "code2code": {
        "query": "Find an equivalent code snippet given the following code snippet:\n",
        "passage": "Candidate code snippet:\n",
    },
    "code2nl": {
        "query": "Find the most relevant comment given the following code snippet:\n",
        "passage": "Candidate comment:\n",
    },
    "code2completion": {
        "query": "Find the most relevant completion given the following start of code snippet:\n",
        "passage": "Candidate completion:\n",
    },
}


@dataclass
class ModelArgs:
    """Architecture hyperparameters for the Qwen2 backbone."""

    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: int
    max_position_embeddings: int
    rope_theta: float = 1000000.0
    tie_word_embeddings: bool = True


class Attention(nn.Module):
    """Multi-head attention with grouped-query KV heads and RoPE (Qwen2 style)."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.hidden_size
        self.n_heads = args.num_attention_heads
        self.n_kv_heads = args.num_key_value_heads
        self.head_dim = dim // self.n_heads
        self.scale = self.head_dim ** -0.5
        self.rope_theta = args.rope_theta

        # Qwen2 has bias on q/k/v but not o
        self.q_proj = nn.Linear(dim, self.n_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.n_heads * self.head_dim, dim, bias=False)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        """Apply self-attention over a full sequence (no KV cache).

        Args:
            x: Hidden states of shape (batch, seq_len, hidden_size).
            mask: Optional additive attention mask broadcastable to
                (batch, n_heads, seq_len, seq_len).

        Returns:
            Attention output of shape (batch, seq_len, hidden_size).
        """
        B, L, D = x.shape

        queries = self.q_proj(x)
        keys = self.k_proj(x)
        values = self.v_proj(x)

        # (B, L, H*Dh) -> (B, H, L, Dh)
        queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)

        # RoPE via mx.fast; offset=0 since the whole sequence is processed at once.
        queries = mx.fast.rope(
            queries, self.head_dim, traditional=False,
            base=self.rope_theta, scale=1.0, offset=0,
        )
        keys = mx.fast.rope(
            keys, self.head_dim, traditional=False,
            base=self.rope_theta, scale=1.0, offset=0,
        )

        # Scaled dot-product attention (handles GQA, precision, and masking internally)
        output = mx.fast.scaled_dot_product_attention(
            queries,
            keys,
            values,
            mask=mask.astype(queries.dtype) if mask is not None else None,
            scale=self.scale,
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    """SwiGLU feed-forward block: down(silu(gate(x)) * up(x))."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention + MLP, each with a residual."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        h = x + self.self_attn(self.input_layernorm(x), mask)
        out = h + self.mlp(self.post_attention_layernorm(h))
        return out


class Qwen2Model(nn.Module):
    """Qwen2 decoder stack: embedding, N transformer blocks, final RMSNorm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None):
        """Return normalized hidden states of shape (batch, seq_len, hidden_size)."""
        h = self.embed_tokens(inputs)
        for layer in self.layers:
            h = layer(h, mask)
        return self.norm(h)


class JinaCodeEmbeddingModel(nn.Module):
    """Jina Code Embedding model with last-token pooling."""

    def __init__(self, config: dict):
        """Build the backbone from a HuggingFace-style config dict."""
        super().__init__()
        args = ModelArgs(
            hidden_size=config["hidden_size"],
            num_hidden_layers=config["num_hidden_layers"],
            intermediate_size=config["intermediate_size"],
            num_attention_heads=config["num_attention_heads"],
            rms_norm_eps=config["rms_norm_eps"],
            vocab_size=config["vocab_size"],
            num_key_value_heads=config["num_key_value_heads"],
            max_position_embeddings=config["max_position_embeddings"],
            rope_theta=config.get("rope_theta", 1000000.0),
            tie_word_embeddings=config.get("tie_word_embeddings", True),
        )
        self.model = Qwen2Model(args)
        self.config = config

    def __call__(
        self,
        input_ids: mx.array,
        attention_mask: Optional[mx.array] = None,
    ):
        """Run the backbone and return L2-normalized last-token embeddings.

        Args:
            input_ids: Token ids of shape (batch, seq_len).
            attention_mask: Optional 0/1 mask of shape (batch, seq_len);
                1 marks real tokens, 0 marks (right) padding.

        Returns:
            Unit-norm embeddings of shape (batch, hidden_size).
        """
        batch_size, seq_len = input_ids.shape

        # Additive causal mask for SDPA: [1, 1, seq_len, seq_len].
        # -1e4 (not -inf) keeps fp16 arithmetic finite.
        causal_mask = mx.tril(mx.ones((seq_len, seq_len)))
        causal_mask = mx.where(causal_mask == 0, -1e4, 0.0)
        causal_mask = causal_mask[None, None, :, :]

        if attention_mask is not None:
            # Block attention *to* padded key positions.
            padding_mask = mx.where(attention_mask == 0, -1e4, 0.0)
            padding_mask = padding_mask[:, None, None, :]
            mask = causal_mask + padding_mask
        else:
            mask = causal_mask

        hidden_states = self.model(input_ids, mask)

        # Last-token pooling: pick the hidden state of the final *real* token.
        if attention_mask is not None:
            sequence_lengths = mx.sum(attention_mask.astype(mx.int32), axis=1) - 1
            batch_indices = mx.arange(batch_size)
            embeddings = hidden_states[batch_indices, sequence_lengths]
        else:
            embeddings = hidden_states[:, -1, :]

        # L2 normalize
        norms = mx.linalg.norm(embeddings, axis=1, keepdims=True)
        embeddings = embeddings / norms

        return embeddings

    def encode(
        self,
        texts: List[str],
        tokenizer,
        max_length: int = 8192,
        truncate_dim: Optional[int] = None,
        task: str = "nl2code",
        prompt_type: str = "query",
    ):
        """Encode texts to embeddings.

        Args:
            texts: List of input texts (must be non-empty).
            tokenizer: Tokenizer instance (from tokenizers library).
            max_length: Maximum sequence length.
            truncate_dim: Optional Matryoshka dimension.
            task: One of nl2code, qa, code2code, code2nl, code2completion.
            prompt_type: "query" or "passage".

        Returns:
            Unit-norm embeddings of shape (len(texts), dim) where dim is
            hidden_size or truncate_dim.

        Raises:
            ValueError: If `texts` is empty.
        """
        if not texts:
            # Explicit guard: max() below would otherwise raise an opaque
            # "arg is an empty sequence" error.
            raise ValueError("texts must be a non-empty list of strings")

        # Unknown task/prompt_type silently falls back to no prefix.
        prefix = INSTRUCTION_CONFIG.get(task, {}).get(prompt_type, "")
        if prefix:
            texts = [prefix + t for t in texts]

        encodings = tokenizer.encode_batch(texts)

        # Pad to the longest sequence in the batch, capped at max_length.
        max_len = min(max_length, max(len(enc.ids) for enc in encodings))

        input_ids = []
        attention_mask = []
        for encoding in encodings:
            ids = encoding.ids[:max_len]
            mask = encoding.attention_mask[:max_len]
            pad_len = max_len - len(ids)
            if pad_len > 0:
                # Pad id 0 is fine: padded positions are masked out and
                # never pooled.
                ids = ids + [0] * pad_len
                mask = mask + [0] * pad_len
            input_ids.append(ids)
            attention_mask.append(mask)

        input_ids = mx.array(input_ids)
        attention_mask = mx.array(attention_mask)

        embeddings = self(input_ids, attention_mask)

        if truncate_dim is not None:
            # Matryoshka truncation breaks unit norm, so re-normalize.
            # (When not truncating, __call__ already returned unit vectors.)
            embeddings = embeddings[:, :truncate_dim]
            norms = mx.linalg.norm(embeddings, axis=1, keepdims=True)
            embeddings = embeddings / norms

        return embeddings