codelion committed
Commit 54aa220 · verified · 1 Parent(s): b242434

Upload folder using huggingface_hub

Files changed (6):
  1. README.md +152 -0
  2. config.json +10 -0
  3. functions.json +0 -0
  4. inference.py +286 -0
  5. model.npz +3 -0
  6. tokenizer.json +0 -0
README.md ADDED
@@ -0,0 +1,152 @@
---
license: apache-2.0
library_name: mlx
tags:
- mlx
- memory-augmented
- code-generation
- retrieval-augmented
- python
- code-search
pipeline_tag: text-generation
---

# MALM-165M: Memory-Augmented Language Model

A 165M-parameter Memory-Augmented Language Model (MALM) for semantic code search, trained on CodeParrot.

## Quick Start

```bash
# Install dependencies
pip install mlx huggingface_hub numpy

# Download model
huggingface-cli download mlx-community/malm-165m --local-dir ./malm-165m

# Run semantic search
python malm-165m/inference.py --query "function that sorts a list"
```

**Example output:**
```
Query: function that sorts a list
------------------------------------------------------------

1. array_sort (score: 0.9526)
   Signature: array_sort(col)
   Docstring: Collection function: sorts the input array in ascending order...

2. sort_array (score: 0.7707)
   Signature: sort_array(col, asc)
   Docstring: Collection function: sorts the input array in ascending or descending order...
```

## Python API

```python
from huggingface_hub import snapshot_download
from pathlib import Path
import sys

# Download the model files and make inference.py importable
model_path = snapshot_download("mlx-community/malm-165m")
sys.path.insert(0, model_path)

from inference import load_model, search_functions

# Load model, tokenizer, memory bank, and config
model, tokenizer, functions, config = load_model(Path(model_path))
print(f"Loaded {len(functions)} functions")

# Search
results = search_functions(
    model, tokenizer, functions,
    query="connect to database",
    top_k=5,
)

for name, signature, docstring, score in results:
    print(f"{name}: {score:.4f}")
```

## Model Description

MALM combines a transformer with learned memory retrieval for semantic code search:

1. **Query encoder** - Encodes natural-language queries into embeddings
2. **Value encoder** - Encodes function signatures and docstrings
3. **Retrieval** - Attention-based lookup from query to memory
4. **Memory bank** - 2000 Python functions from CodeParrot

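The retrieval step (step 3 above) is plain scaled dot-product attention over the memory bank, softened by a learned temperature, matching `MALM.retrieve` in `inference.py`. A minimal NumPy sketch, with toy 4-dimensional embeddings standing in for the real encoder outputs:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

d_model = 4
query_emb = np.array([[1.0, 0.0, 0.0, 0.0]])   # (1, d) encoded query
key_emb = np.array([[1.0, 0.0, 0.0, 0.0],      # (2, d) memory keys
                    [0.0, 1.0, 0.0, 0.0]])
val_emb = np.array([[1.0, 1.0, 0.0, 0.0],      # (2, d) memory values
                    [0.0, 0.0, 1.0, 1.0]])

temp = np.exp(0.0) + 0.1                       # exp(log_temp) + 0.1, log_temp initialized to 0

scores = (query_emb @ key_emb.T) * d_model ** -0.5 / temp
attn = softmax(scores)                         # attention weights over memory slots
retrieved = attn @ val_emb                     # weighted mix of memory values
```

The first memory slot matches the query, so it receives the larger attention weight; `search_functions` ranks functions by the raw `scores` rather than the mixed value.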
### Why not mlx-lm?

MALM uses a **memory-augmented** architecture different from standard LLMs:
- Separate query and value encoders for retrieval
- Requires a memory bank of functions
- Inference is retrieval-based, not autoregressive generation

This architecture doesn't fit `mlx-lm generate`, so we provide a custom inference script.

## Architecture

| Component | Parameters |
|-----------|------------|
| Embedding | 11.1M |
| Position Embedding | 0.1M |
| Query Encoder (4 layers) | 28.4M |
| Value Encoder (4 layers) | 28.4M |
| Decoder (12 layers) | 85.1M |
| Output Projection | 11.1M |
| **Total** | **~165M** |

### Configuration

```json
{
  "vocab_size": 14407,
  "d_model": 768,
  "n_heads": 12,
  "n_layers": 12,
  "n_query_layers": 4,
  "max_seq_len": 128,
  "num_parameters": 165123656,
  "num_functions": 2000
}
```

## Files

| File | Description |
|------|-------------|
| `model.npz` | Model weights (MLX-compatible NumPy format) |
| `config.json` | Model configuration |
| `tokenizer.json` | Tokenizer vocabulary |
| `functions.json` | Memory bank of 2000 Python functions |
| `inference.py` | Standalone inference script |

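`tokenizer.json` stores a plain word-level vocabulary; the `Tokenizer` class in `inference.py` lowercases its input and splits it with a single regex into identifiers, digit runs, and individual punctuation characters:

```python
import re

# Same pattern as Tokenizer.encode in inference.py:
# identifiers, runs of digits, or any single non-whitespace character.
TOKEN_RE = r"[a-zA-Z_][a-zA-Z0-9_]*|[0-9]+|[^\s]"

def tokenize(text: str) -> list:
    return re.findall(TOKEN_RE, text.lower())

print(tokenize("sort_array(col, asc=True)"))
# → ['sort_array', '(', 'col', ',', 'asc', '=', 'true', ')']
```

Tokens not in the vocabulary map to `<UNK>` (id 1); padding uses id 0.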
## Training

Trained on CodeParrot with a focus on Python function retrieval:
- Encodes natural-language queries into embedding space
- Learns semantic similarity between queries and function signatures
- Uses attention-based retrieval over a memory bank

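Both encoders reduce a token sequence to a single embedding by masked mean pooling, as in `encode_query` and `encode_value`: hidden states at padding positions (token id 0) are zeroed before averaging. A NumPy sketch with made-up hidden states:

```python
import numpy as np

ids = np.array([[5, 9, 0, 0]])           # one sequence; trailing 0s are <PAD>
h = np.array([[[1.0, 2.0, 3.0],          # hidden state for token 5
               [3.0, 4.0, 5.0],          # hidden state for token 9
               [9.0, 9.0, 9.0],          # padding positions: values here
               [9.0, 9.0, 9.0]]])        # must not affect the pooled result

mask = (ids != 0).astype(np.float32)[:, :, None]
pooled = (h * mask).sum(axis=1) / (mask.sum(axis=1) + 1e-8)
print(pooled)                            # → [[2. 3. 4.]]
```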
## Related Work

Part of the [HashHop](https://github.com/codelion/hash-hop) project exploring long-context evaluation and memory-augmented architectures.

## Citation

```bibtex
@misc{malm2025,
  title={MALM: Memory-Augmented Language Model},
  author={HashHop Contributors},
  year={2025},
  url={https://github.com/codelion/hash-hop}
}
```

## License

Apache 2.0
config.json ADDED
@@ -0,0 +1,10 @@
{
  "vocab_size": 14407,
  "d_model": 768,
  "n_heads": 12,
  "n_layers": 12,
  "n_query_layers": 4,
  "max_seq_len": 128,
  "num_parameters": 165123656,
  "num_functions": 2000
}
functions.json ADDED
The diff for this file is too large to render. See raw diff
 
inference.py ADDED
@@ -0,0 +1,286 @@
#!/usr/bin/env python3
"""MALM Inference Script - Run directly from Hugging Face.

Usage:
    # Install dependencies
    pip install mlx huggingface_hub

    # Download and run
    huggingface-cli download mlx-community/malm-165m --local-dir ./malm-165m
    python malm-165m/inference.py --query "function that sorts a list"
"""

import mlx.core as mx
import mlx.nn as nn
import numpy as np
import json
import argparse
from pathlib import Path
from typing import List, Dict, Tuple
import re

class MALM(nn.Module):
    """Memory-Augmented Language Model."""

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        n_heads: int = 12,
        n_layers: int = 12,
        n_query_layers: int = 4,
        max_seq_len: int = 128,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.n_query_layers = n_query_layers
        self.max_seq_len = max_seq_len

        # Embeddings
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_seq_len, d_model)
        self.embed_dropout = nn.Dropout(dropout)

        # Query encoder
        self.query_layers = [
            nn.TransformerEncoderLayer(d_model, n_heads, d_model * 4)
            for _ in range(n_query_layers)
        ]
        self.query_ln = nn.LayerNorm(d_model)
        self.query_proj = nn.Linear(d_model, d_model)

        # Value encoder
        self.value_layers = [
            nn.TransformerEncoderLayer(d_model, n_heads, d_model * 4)
            for _ in range(n_query_layers)
        ]
        self.value_ln = nn.LayerNorm(d_model)
        self.value_proj = nn.Linear(d_model, d_model)

        # Decoder layers
        self.decoder_layers = [
            nn.TransformerEncoderLayer(d_model, n_heads, d_model * 4)
            for _ in range(n_layers)
        ]
        self.decoder_ln = nn.LayerNorm(d_model)

        # Output
        self.output = nn.Linear(d_model, vocab_size)

        # Temperature for retrieval
        self.log_temp = mx.array([0.0])

    def encode_query(self, query_ids: mx.array) -> mx.array:
        """Encode query to single embedding."""
        B, L = query_ids.shape

        h = self.embed(query_ids)
        pos = mx.arange(min(L, self.max_seq_len))
        h = h + self.pos_embed(pos)
        h = self.embed_dropout(h)

        for layer in self.query_layers:
            h = layer(h, None)

        h = self.query_ln(h)

        # Mean-pool over non-padding positions (token id 0 = <PAD>)
        mask = (query_ids != 0).astype(mx.float32)[:, :, None]
        h = h * mask
        query_emb = mx.sum(h, axis=1) / (mx.sum(mask, axis=1) + 1e-8)

        return self.query_proj(query_emb)

    def encode_value(self, value_ids: mx.array) -> mx.array:
        """Encode value to single embedding."""
        B, L = value_ids.shape

        h = self.embed(value_ids)
        pos = mx.arange(min(L, self.max_seq_len))
        h = h + self.pos_embed(pos)

        for layer in self.value_layers:
            h = layer(h, None)

        h = self.value_ln(h)

        mask = (value_ids != 0).astype(mx.float32)[:, :, None]
        h = h * mask
        val_emb = mx.sum(h, axis=1) / (mx.sum(mask, axis=1) + 1e-8)

        return self.value_proj(val_emb)

    def retrieve(
        self,
        query_emb: mx.array,
        key_emb: mx.array,
        val_emb: mx.array,
    ) -> Tuple[mx.array, mx.array, mx.array]:
        """Retrieve from memory via scaled dot-product attention."""
        scale = self.d_model ** -0.5
        temp = mx.exp(self.log_temp) + 0.1

        scores = (query_emb @ key_emb.T) * scale / temp
        attn = mx.softmax(scores, axis=-1)
        retrieved = attn @ val_emb

        return retrieved, attn, scores

class Tokenizer:
    """Simple tokenizer for MALM."""

    def __init__(self, tokenizer_dict: Dict):
        self.token_to_id = tokenizer_dict.get("token_to_id", {})
        self.id_to_token = {int(v): k for k, v in self.token_to_id.items()}
        self.special = {"<PAD>": 0, "<UNK>": 1, "<BOS>": 2, "<EOS>": 3}

    def encode(self, text: str) -> List[int]:
        """Tokenize text."""
        tokens = re.findall(r"[a-zA-Z_][a-zA-Z0-9_]*|[0-9]+|[^\s]", text.lower())
        return [self.token_to_id.get(t, self.special.get("<UNK>", 1)) for t in tokens]

    def decode(self, ids: List[int]) -> str:
        """Decode token IDs to text."""
        tokens = [self.id_to_token.get(i, "<UNK>") for i in ids]
        return " ".join(tokens)

def load_model(model_dir: Path):
    """Load MALM model from directory."""
    import mlx.utils as mlx_utils

    # Load config
    with open(model_dir / "config.json") as f:
        config = json.load(f)

    # Create model
    model = MALM(
        vocab_size=config["vocab_size"],
        d_model=config["d_model"],
        n_heads=config["n_heads"],
        n_layers=config["n_layers"],
        n_query_layers=config["n_query_layers"],
        max_seq_len=config["max_seq_len"],
    )

    # Load weights and convert to mlx arrays
    weights = dict(np.load(model_dir / "model.npz"))
    weights = {k: mx.array(v) for k, v in weights.items()}

    # Unflatten and load
    params = mlx_utils.tree_unflatten(list(weights.items()))
    model.update(params)
    mx.eval(model.parameters())

    # Load tokenizer
    with open(model_dir / "tokenizer.json") as f:
        tokenizer_dict = json.load(f)
    tokenizer = Tokenizer(tokenizer_dict)

    # Load functions
    with open(model_dir / "functions.json") as f:
        functions = json.load(f)

    return model, tokenizer, functions, config

def search_functions(
    model: MALM,
    tokenizer: Tokenizer,
    functions: List[Dict],
    query: str,
    top_k: int = 5,
) -> List[Tuple[str, str, str, float]]:
    """Search for functions matching a query.

    Returns (name, signature, docstring, score) tuples. Uses the function
    name as key and signature+docstring as value for retrieval.
    """
    # Encode query
    query_ids = tokenizer.encode(query)
    if not query_ids:
        query_ids = [1]  # <UNK>
    query_ids = mx.array([query_ids])

    # Encode all function keys and values
    key_tokens = []
    value_tokens = []
    max_val_len = 64

    for func in functions:
        name = func["name"]
        # Use signature + docstring as the "value" to search over
        sig = func.get("signature", name)
        doc = func.get("docstring", "")
        value_text = f"{sig} {doc}"

        key_id = tokenizer.token_to_id.get(name.lower(), 1)
        key_tokens.append(key_id)

        val_ids = tokenizer.encode(value_text)[:max_val_len]
        val_ids = val_ids + [0] * (max_val_len - len(val_ids))
        value_tokens.append(val_ids)

    key_tokens = mx.array(key_tokens)
    value_tokens = mx.array(value_tokens)

    # Encode memory
    key_emb = model.embed(key_tokens)
    val_emb = model.encode_value(value_tokens)

    # Get query embedding and compute similarity
    query_emb = model.encode_query(query_ids)
    _, attn, scores = model.retrieve(query_emb, key_emb, val_emb)
    mx.eval(scores)

    # Get top-k
    scores_np = np.array(scores[0])
    top_indices = np.argsort(scores_np)[::-1][:top_k]

    results = []
    for idx in top_indices:
        func = functions[idx]
        score = float(scores_np[idx])
        sig = func.get("signature", func["name"])
        doc = func.get("docstring", "")
        results.append((func["name"], sig, doc, score))

    return results

def main():
    parser = argparse.ArgumentParser(description="MALM Inference - Semantic Code Search")
    parser.add_argument("--query", type=str, required=True, help="Natural language query")
    parser.add_argument("--top-k", type=int, default=5, help="Number of results")
    parser.add_argument("--model-dir", type=str, default=None, help="Model directory")
    args = parser.parse_args()

    # Determine model directory
    if args.model_dir:
        model_dir = Path(args.model_dir)
    else:
        model_dir = Path(__file__).parent

    print(f"Loading model from {model_dir}...")
    model, tokenizer, functions, config = load_model(model_dir)
    print(f"Loaded {len(functions)} functions, {config['num_parameters']:,} parameters")

    # Search
    print(f"\nQuery: {args.query}")
    print("-" * 60)

    results = search_functions(model, tokenizer, functions, args.query, args.top_k)

    for i, (name, signature, docstring, score) in enumerate(results, 1):
        print(f"\n{i}. {name} (score: {score:.4f})")
        print(f"   Signature: {signature}")
        if docstring:
            print(f"   Docstring: {docstring[:100]}{'...' if len(docstring) > 100 else ''}")


if __name__ == "__main__":
    main()
model.npz ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4f7c9ba2754706b3e7c9c0b00b53226a6bafdcb2aaab5b47a46bdf5d21e4f14f
size 660569734
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff