#!/usr/bin/env python3 """MALM Inference Script - Run directly from Hugging Face. Usage: # Install dependencies pip install mlx huggingface_hub # Download and run huggingface-cli download codelion/malm-165m --local-dir ./malm-165m python malm-165m/inference.py --query "function that sorts a list" """ import mlx.core as mx import mlx.nn as nn import numpy as np import json import argparse from pathlib import Path from typing import List, Dict, Tuple import re class MALM(nn.Module): """Memory-Augmented Language Model.""" def __init__( self, vocab_size: int, d_model: int = 768, n_heads: int = 12, n_layers: int = 12, n_query_layers: int = 4, max_seq_len: int = 128, dropout: float = 0.0, ): super().__init__() self.vocab_size = vocab_size self.d_model = d_model self.n_heads = n_heads self.n_layers = n_layers self.n_query_layers = n_query_layers self.max_seq_len = max_seq_len # Embeddings self.embed = nn.Embedding(vocab_size, d_model) self.pos_embed = nn.Embedding(max_seq_len, d_model) self.embed_dropout = nn.Dropout(dropout) # Query encoder self.query_layers = [ nn.TransformerEncoderLayer(d_model, n_heads, d_model * 4) for _ in range(n_query_layers) ] self.query_ln = nn.LayerNorm(d_model) self.query_proj = nn.Linear(d_model, d_model) # Value encoder self.value_layers = [ nn.TransformerEncoderLayer(d_model, n_heads, d_model * 4) for _ in range(n_query_layers) ] self.value_ln = nn.LayerNorm(d_model) self.value_proj = nn.Linear(d_model, d_model) # Decoder layers self.decoder_layers = [ nn.TransformerEncoderLayer(d_model, n_heads, d_model * 4) for _ in range(n_layers) ] self.decoder_ln = nn.LayerNorm(d_model) # Output self.output = nn.Linear(d_model, vocab_size) # Temperature for retrieval self.log_temp = mx.array([0.0]) def encode_query(self, query_ids: mx.array) -> mx.array: """Encode query to single embedding.""" B, L = query_ids.shape h = self.embed(query_ids) pos = mx.arange(min(L, self.max_seq_len)) h = h + self.pos_embed(pos) h = self.embed_dropout(h) for layer in self.query_layers: h = layer(h, None) h = self.query_ln(h) mask = (query_ids != 0).astype(mx.float32)[:, :, None] h = h * mask query_emb = mx.sum(h, axis=1) / (mx.sum(mask, axis=1) + 1e-8) return self.query_proj(query_emb) def encode_value(self, value_ids: mx.array) -> mx.array: """Encode value to single embedding.""" B, L = value_ids.shape h = self.embed(value_ids) pos = mx.arange(min(L, self.max_seq_len)) h = h + self.pos_embed(pos) for layer in self.value_layers: h = layer(h, None) h = self.value_ln(h) mask = (value_ids != 0).astype(mx.float32)[:, :, None] h = h * mask val_emb = mx.sum(h, axis=1) / (mx.sum(mask, axis=1) + 1e-8) return self.value_proj(val_emb) def retrieve( self, query_emb: mx.array, key_emb: mx.array, val_emb: mx.array, ) -> Tuple[mx.array, mx.array, mx.array]: """Retrieve from memory.""" scale = self.d_model ** -0.5 temp = mx.exp(self.log_temp) + 0.1 scores = (query_emb @ key_emb.T) * scale / temp attn = mx.softmax(scores, axis=-1) retrieved = attn @ val_emb return retrieved, attn, scores class Tokenizer: """Simple tokenizer for MALM.""" def __init__(self, tokenizer_dict: Dict): self.token_to_id = tokenizer_dict.get("token_to_id", {}) self.id_to_token = {int(v): k for k, v in self.token_to_id.items()} self.special = {"": 0, "": 1, "": 2, "": 3} def encode(self, text: str) -> List[int]: """Tokenize text.""" tokens = re.findall(r"[a-zA-Z_][a-zA-Z0-9_]*|[0-9]+|[^\s]", text.lower()) return [self.token_to_id.get(t, self.special.get("", 1)) for t in tokens] def decode(self, ids: List[int]) -> str: """Decode token IDs to text.""" tokens = [self.id_to_token.get(i, "") for i in ids] return " ".join(tokens) def load_model(model_dir: Path): """Load MALM model from directory.""" import mlx.utils as mlx_utils # Load config with open(model_dir / "config.json") as f: config = json.load(f) # Create model model = MALM( vocab_size=config["vocab_size"], d_model=config["d_model"], n_heads=config["n_heads"], n_layers=config["n_layers"], n_query_layers=config["n_query_layers"], max_seq_len=config["max_seq_len"], ) # Load weights and convert to mlx arrays weights = dict(np.load(model_dir / "model.npz")) weights = {k: mx.array(v) for k, v in weights.items()} # Unflatten and load params = mlx_utils.tree_unflatten(list(weights.items())) model.update(params) mx.eval(model.parameters()) # Load tokenizer with open(model_dir / "tokenizer.json") as f: tokenizer_dict = json.load(f) tokenizer = Tokenizer(tokenizer_dict) # Load functions with open(model_dir / "functions.json") as f: functions = json.load(f) return model, tokenizer, functions, config def search_functions( model: MALM, tokenizer: Tokenizer, functions: List[Dict], query: str, top_k: int = 5, ) -> List[Tuple[str, str, float]]: """Search for functions matching a query. Uses the function name as key and signature+docstring as value for retrieval. """ # Encode query query_ids = tokenizer.encode(query) if not query_ids: query_ids = [1] # query_ids = mx.array([query_ids]) # Encode all function keys and values key_tokens = [] value_tokens = [] max_val_len = 64 for func in functions: name = func["name"] # Use signature + docstring as the "value" to search over sig = func.get("signature", name) doc = func.get("docstring", "") value_text = f"{sig} {doc}" key_id = tokenizer.token_to_id.get(name.lower(), 1) key_tokens.append(key_id) val_ids = tokenizer.encode(value_text)[:max_val_len] val_ids = val_ids + [0] * (max_val_len - len(val_ids)) value_tokens.append(val_ids) key_tokens = mx.array(key_tokens) value_tokens = mx.array(value_tokens) # Encode memory key_emb = model.embed(key_tokens) val_emb = model.encode_value(value_tokens) # Get query embedding and compute similarity query_emb = model.encode_query(query_ids) _, attn, scores = model.retrieve(query_emb, key_emb, val_emb) mx.eval(scores) # Get top-k scores_np = np.array(scores[0]) top_indices = np.argsort(scores_np)[::-1][:top_k] results = [] for idx in top_indices: func = functions[idx] score = float(scores_np[idx]) sig = func.get("signature", func["name"]) doc = func.get("docstring", "") results.append((func["name"], sig, doc, score)) return results def main(): parser = argparse.ArgumentParser(description="MALM Inference - Semantic Code Search") parser.add_argument("--query", type=str, required=True, help="Natural language query") parser.add_argument("--top-k", type=int, default=5, help="Number of results") parser.add_argument("--model-dir", type=str, default=None, help="Model directory") args = parser.parse_args() # Determine model directory if args.model_dir: model_dir = Path(args.model_dir) else: model_dir = Path(__file__).parent print(f"Loading model from {model_dir}...") model, tokenizer, functions, config = load_model(model_dir) print(f"Loaded {len(functions)} functions, {config['num_parameters']:,} parameters") # Search print(f"\nQuery: {args.query}") print("-" * 60) results = search_functions(model, tokenizer, functions, args.query, args.top_k) for i, (name, signature, docstring, score) in enumerate(results, 1): print(f"\n{i}. {name} (score: {score:.4f})") print(f" Signature: {signature}") if docstring: print(f" Docstring: {docstring[:100]}{'...' if len(docstring) > 100 else ''}") if __name__ == "__main__": main()