hanxiao committed
Commit 643acb3 · verified · 1 parent: 14e8fb7

Add MLX fp16 port of jina-code-embeddings-0.5b

Files changed (9)
  1. .gitattributes +1 -0
  2. README.md +98 -0
  3. config.json +27 -0
  4. merges.txt +0 -0
  5. model.py +267 -0
  6. model.safetensors +3 -0
  7. tokenizer.json +3 -0
  8. tokenizer_config.json +11 -0
  9. vocab.json +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,98 @@
+ # Jina Code Embeddings 0.5B - MLX
+
+ MLX port (fp16) of [jinaai/jina-code-embeddings-0.5b](https://huggingface.co/jinaai/jina-code-embeddings-0.5b) for Apple Silicon.
+
+ ## Model Details
+
+ - **Architecture**: Qwen2.5-Coder-0.5B
+ - **Embedding dimension**: 896
+ - **Matryoshka dimensions**: 64, 128, 256, 512, 896
+ - **Max sequence length**: 32768 tokens
+ - **Pooling**: Last-token pooling
+ - **Supported tasks**: nl2code, code2code, code2nl, code2completion, qa
+
+ ## Installation
+
+ ```bash
+ pip install mlx tokenizers
+ ```
+
+ ## Usage
+
+ ```bash
+ git clone https://huggingface.co/jinaai/jina-code-embeddings-0.5b-mlx
+ cd jina-code-embeddings-0.5b-mlx
+ ```
+
+ ```python
+ import mlx.core as mx
+ from tokenizers import Tokenizer
+ from model import JinaCodeEmbeddingModel
+ import json
+
+ with open("config.json") as f:
+     config = json.load(f)
+
+ model = JinaCodeEmbeddingModel(config)
+ weights = mx.load("model.safetensors")
+ model.load_weights(list(weights.items()))
+
+ tokenizer = Tokenizer.from_file("tokenizer.json")
+
+ # Encode a natural language query for code search
+ embeddings = model.encode(
+     ["print hello world in python"],
+     tokenizer,
+     task="nl2code",
+     prompt_type="query",
+ )
+
+ # Encode code passages
+ code_embeddings = model.encode(
+     ["print('Hello World!')"],
+     tokenizer,
+     task="nl2code",
+     prompt_type="passage",
+ )
+ ```
+
+ ## Task-Specific Instruction Prefixes
+
+ Each task uses specific prefixes for queries and passages:
+
+ | Task | Query Prefix | Passage Prefix |
+ |------|--------------|----------------|
+ | nl2code | Find the most relevant code snippet given the following query: | Candidate code snippet: |
+ | qa | Find the most relevant answer given the following question: | Candidate answer: |
+ | code2code | Find an equivalent code snippet given the following code snippet: | Candidate code snippet: |
+ | code2nl | Find the most relevant comment given the following code snippet: | Candidate comment: |
+ | code2completion | Find the most relevant completion given the following start of code snippet: | Candidate completion: |
+
+ ## Matryoshka Embeddings
+
+ Truncate to smaller dimensions with minimal quality loss:
+
+ ```python
+ embeddings = model.encode(texts, tokenizer, task="nl2code", prompt_type="query", truncate_dim=256)
+ ```
+
+ ## License
+
+ CC-BY-NC-4.0
+
+ ## Citation
+
+ ```bibtex
+ @article{jina-code-embeddings,
+   title={Jina Code Embeddings},
+   author={Jina AI},
+   year={2025},
+   url={https://huggingface.co/jinaai/jina-code-embeddings-0.5b}
+ }
+ ```
+
+ ## Links
+
+ - Original model: [jinaai/jina-code-embeddings-0.5b](https://huggingface.co/jinaai/jina-code-embeddings-0.5b)
+ - 1.5B MLX variant: [jinaai/jina-code-embeddings-1.5b-mlx](https://huggingface.co/jinaai/jina-code-embeddings-1.5b-mlx)
+ - Jina AI: [https://jina.ai](https://jina.ai)
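Since `encode` returns L2-normalized embeddings, a plain dot product between a query and a passage embedding is exactly their cosine similarity. A minimal NumPy sketch of the scoring step (toy 4-d unit vectors standing in for real 896-d model output, since running the model itself requires Apple Silicon):

```python
import numpy as np

def score(query_emb, passage_embs):
    # Embeddings from model.encode are already L2-normalized,
    # so a dot product equals cosine similarity.
    return passage_embs @ query_emb

# Toy unit vectors standing in for real embeddings
query = np.array([1.0, 0.0, 0.0, 0.0])
passages = np.array([
    [0.6, 0.8, 0.0, 0.0],   # partially aligned with the query
    [1.0, 0.0, 0.0, 0.0],   # identical direction
])

scores = score(query, passages)
best = int(np.argmax(scores))  # index of the best-matching passage
```

The same ranking logic applies unchanged to real `model.encode` output, whatever the embedding dimension.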
config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "model_type": "qwen2",
+   "hidden_size": 896,
+   "num_hidden_layers": 24,
+   "intermediate_size": 4864,
+   "num_attention_heads": 14,
+   "num_key_value_heads": 2,
+   "rms_norm_eps": 1e-06,
+   "vocab_size": 151936,
+   "max_position_embeddings": 32768,
+   "rope_theta": 1000000.0,
+   "tie_word_embeddings": true,
+   "matryoshka_dims": [
+     64,
+     128,
+     256,
+     512,
+     896
+   ],
+   "task_names": [
+     "nl2code",
+     "qa",
+     "code2code",
+     "code2nl",
+     "code2completion"
+   ]
+ }
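The attention geometry that `model.py` derives from this config can be checked with quick arithmetic (values copied from config.json above):

```python
# Values copied from config.json
config = {
    "hidden_size": 896,
    "num_attention_heads": 14,
    "num_key_value_heads": 2,
}

# Per-head dimension used by the q/k/v projections
head_dim = config["hidden_size"] // config["num_attention_heads"]  # 896 / 14 = 64

# GQA repeat factor: each KV head is shared by this many query heads
n_rep = config["num_attention_heads"] // config["num_key_value_heads"]  # 14 / 2 = 7
```

These are the same `head_dim` and `n_rep` computed inside the `Attention` module below.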
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.py ADDED
@@ -0,0 +1,267 @@
+ """
+ Jina Code Embeddings - MLX Implementation
+
+ MLX port of Jina AI's code embedding models.
+ Based on Qwen2.5-Coder with last-token pooling.
+
+ Features:
+ - Last-token pooling
+ - L2 normalization
+ - Task-specific instruction prefixes
+ - Matryoshka embedding dimensions
+
+ Usage:
+     import mlx.core as mx
+     from tokenizers import Tokenizer
+     from model import JinaCodeEmbeddingModel
+     import json
+
+     with open("config.json") as f:
+         config = json.load(f)
+
+     model = JinaCodeEmbeddingModel(config)
+     weights = mx.load("model.safetensors")
+     model.load_weights(list(weights.items()))
+
+     tokenizer = Tokenizer.from_file("tokenizer.json")
+
+     texts = ["Find the most relevant code snippet given the following query:\\nprint hello world"]
+     embeddings = model.encode(texts, tokenizer)
+ """
+
+ from dataclasses import dataclass
+ from typing import List, Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+
+ INSTRUCTION_CONFIG = {
+     "nl2code": {
+         "query": "Find the most relevant code snippet given the following query:\n",
+         "passage": "Candidate code snippet:\n",
+     },
+     "qa": {
+         "query": "Find the most relevant answer given the following question:\n",
+         "passage": "Candidate answer:\n",
+     },
+     "code2code": {
+         "query": "Find an equivalent code snippet given the following code snippet:\n",
+         "passage": "Candidate code snippet:\n",
+     },
+     "code2nl": {
+         "query": "Find the most relevant comment given the following code snippet:\n",
+         "passage": "Candidate comment:\n",
+     },
+     "code2completion": {
+         "query": "Find the most relevant completion given the following start of code snippet:\n",
+         "passage": "Candidate completion:\n",
+     },
+ }
+
+
+ @dataclass
+ class ModelArgs:
+     hidden_size: int
+     num_hidden_layers: int
+     intermediate_size: int
+     num_attention_heads: int
+     rms_norm_eps: float
+     vocab_size: int
+     num_key_value_heads: int
+     max_position_embeddings: int
+     rope_theta: float = 1000000.0
+     tie_word_embeddings: bool = True
+
+
+ class Attention(nn.Module):
+     def __init__(self, args: ModelArgs):
+         super().__init__()
+         dim = args.hidden_size
+         self.n_heads = args.num_attention_heads
+         self.n_kv_heads = args.num_key_value_heads
+         self.head_dim = dim // self.n_heads
+         self.scale = self.head_dim ** -0.5
+         self.rope_theta = args.rope_theta
+
+         # Qwen2 has bias on q/k/v projections but not on o_proj
+         self.q_proj = nn.Linear(dim, self.n_heads * self.head_dim, bias=True)
+         self.k_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=True)
+         self.v_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=True)
+         self.o_proj = nn.Linear(self.n_heads * self.head_dim, dim, bias=False)
+
+     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+         B, L, D = x.shape
+
+         queries = self.q_proj(x)
+         keys = self.k_proj(x)
+         values = self.v_proj(x)
+
+         queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
+         keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+
+         # RoPE - rotate_half convention (traditional=False)
+         queries = mx.fast.rope(queries, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)
+         keys = mx.fast.rope(keys, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)
+
+         # GQA: repeat KV heads to match the number of query heads
+         if self.n_kv_heads != self.n_heads:
+             n_rep = self.n_heads // self.n_kv_heads
+             keys = mx.repeat(keys, n_rep, axis=1)
+             values = mx.repeat(values, n_rep, axis=1)
+
+         # Compute scores in float32 to avoid fp16 overflow
+         scores = (queries.astype(mx.float32) @ keys.astype(mx.float32).transpose(0, 1, 3, 2)) * self.scale
+         if mask is not None:
+             scores = scores + mask.astype(mx.float32)
+         attn = mx.softmax(scores, axis=-1)
+         output = attn.astype(values.dtype) @ values
+
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+     def __call__(self, x) -> mx.array:
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, args: ModelArgs):
+         super().__init__()
+         self.self_attn = Attention(args)
+         self.mlp = MLP(args.hidden_size, args.intermediate_size)
+         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+         h = x + self.self_attn(self.input_layernorm(x), mask)
+         out = h + self.mlp(self.post_attention_layernorm(h))
+         return out
+
+
+ class Qwen2Model(nn.Module):
+     def __init__(self, args: ModelArgs):
+         super().__init__()
+         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+         self.layers = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
+         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+     def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None):
+         h = self.embed_tokens(inputs)
+         for layer in self.layers:
+             h = layer(h, mask)
+         return self.norm(h)
+
+
+ class JinaCodeEmbeddingModel(nn.Module):
+     """Jina Code Embedding model with last-token pooling."""
+
+     def __init__(self, config: dict):
+         super().__init__()
+         args = ModelArgs(
+             hidden_size=config["hidden_size"],
+             num_hidden_layers=config["num_hidden_layers"],
+             intermediate_size=config["intermediate_size"],
+             num_attention_heads=config["num_attention_heads"],
+             rms_norm_eps=config["rms_norm_eps"],
+             vocab_size=config["vocab_size"],
+             num_key_value_heads=config["num_key_value_heads"],
+             max_position_embeddings=config["max_position_embeddings"],
+             rope_theta=config.get("rope_theta", 1000000.0),
+             tie_word_embeddings=config.get("tie_word_embeddings", True),
+         )
+         self.model = Qwen2Model(args)
+         self.config = config
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         attention_mask: Optional[mx.array] = None,
+     ):
+         batch_size, seq_len = input_ids.shape
+
+         # Causal mask
+         causal_mask = mx.tril(mx.ones((seq_len, seq_len)))
+         causal_mask = mx.where(causal_mask == 0, -1e4, 0.0)
+         causal_mask = causal_mask[None, None, :, :]
+
+         if attention_mask is not None:
+             padding_mask = mx.where(attention_mask == 0, -1e4, 0.0)
+             padding_mask = padding_mask[:, None, None, :]
+             mask = causal_mask + padding_mask
+         else:
+             mask = causal_mask
+
+         hidden_states = self.model(input_ids, mask)
+
+         # Last-token pooling
+         if attention_mask is not None:
+             sequence_lengths = mx.sum(attention_mask.astype(mx.int32), axis=1) - 1
+             batch_indices = mx.arange(batch_size)
+             embeddings = hidden_states[batch_indices, sequence_lengths]
+         else:
+             embeddings = hidden_states[:, -1, :]
+
+         # L2 normalize
+         norms = mx.linalg.norm(embeddings, axis=1, keepdims=True)
+         embeddings = embeddings / norms
+         return embeddings
+
+     def encode(
+         self,
+         texts: List[str],
+         tokenizer,
+         max_length: int = 8192,
+         truncate_dim: Optional[int] = None,
+         task: str = "nl2code",
+         prompt_type: str = "query",
+     ):
+         """
+         Encode texts to embeddings.
+
+         Args:
+             texts: List of input texts
+             tokenizer: Tokenizer instance (from the tokenizers library)
+             max_length: Maximum sequence length
+             truncate_dim: Optional Matryoshka dimension
+             task: One of nl2code, qa, code2code, code2nl, code2completion
+             prompt_type: "query" or "passage"
+         """
+         prefix = INSTRUCTION_CONFIG.get(task, {}).get(prompt_type, "")
+         if prefix:
+             texts = [prefix + t for t in texts]
+
+         encodings = tokenizer.encode_batch(texts)
+         max_len = min(max_length, max(len(enc.ids) for enc in encodings))
+
+         input_ids = []
+         attention_mask = []
+         for encoding in encodings:
+             ids = encoding.ids[:max_len]
+             mask = encoding.attention_mask[:max_len]
+             pad_len = max_len - len(ids)
+             if pad_len > 0:
+                 ids = ids + [0] * pad_len
+                 mask = mask + [0] * pad_len
+             input_ids.append(ids)
+             attention_mask.append(mask)
+
+         input_ids = mx.array(input_ids)
+         attention_mask = mx.array(attention_mask)
+
+         embeddings = self(input_ids, attention_mask)
+
+         if truncate_dim is not None:
+             embeddings = embeddings[:, :truncate_dim]
+             norms = mx.linalg.norm(embeddings, axis=1, keepdims=True)
+             embeddings = embeddings / norms
+
+         return embeddings
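The pooling and Matryoshka steps above can be verified with a small NumPy sketch that mirrors the same indexing and normalization (toy numbers, not model weights):

```python
import numpy as np

# Toy "hidden states": batch of 2, seq_len 4, hidden dim 4
hidden = np.arange(32, dtype=np.float64).reshape(2, 4, 4)
attention_mask = np.array([[1, 1, 0, 0],    # length-2 sequence, right-padded
                           [1, 1, 1, 1]])   # full-length sequence

# Last-token pooling: hidden state at the last non-padding position
lengths = attention_mask.sum(axis=1) - 1
pooled = hidden[np.arange(2), lengths]      # shape (2, 4)

# L2-normalize, as __call__ does
emb = pooled / np.linalg.norm(pooled, axis=1, keepdims=True)

# Matryoshka: truncate then renormalize, as encode(truncate_dim=...) does
trunc = emb[:, :2]
trunc = trunc / np.linalg.norm(trunc, axis=1, keepdims=True)
```

Note the renormalization after truncation: slicing a unit vector shortens it, so the truncated embedding must be rescaled back to unit length before cosine scoring.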
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:23b4ea15302a07120c52e33f864d99f6d6ed09a26e6bc599179ea7dc36e4e4a9
+ size 988097412
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:47737715ec4d1b7194dab061db65f1ab1e2f1d664c489b0c72e7256ff033baf7
+ size 11418160
tokenizer_config.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "add_prefix_space": null,
+   "backend": "tokenizers",
+   "bos_token": null,
+   "eos_token": "<|endoftext|>",
+   "is_local": false,
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "<|endoftext|>",
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": "<|endoftext|>"
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff