Instructions to use codelion/sprog-9m with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use codelion/sprog-9m with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir sprog-9m codelion/sprog-9m
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- LM Studio
| """SPROG-9M — standalone inference. | |
| A 9.37M-parameter from-scratch MLX seq2seq that solves grade-school math word | |
| problems WITHOUT an LLM at inference. It abstracts the numbers in a question to | |
| slots ([N0], [N1], ...) and predicts a postfix PROGRAM over those slots; the | |
| program is executed symbolically. Self-consistency (96 temperature samples) plus | |
| a free symbolic verifier (0 trainable params) select the final answer. | |
| Usage: | |
| pip install mlx numpy huggingface_hub | |
| huggingface-cli download codelion/sprog-9m --local-dir ./sprog-9m | |
| python sprog-9m/inference.py --question "A baker had 24 muffins, sold 3/4, then baked 10 more. How many now?" | |
| Python: | |
| from huggingface_hub import snapshot_download | |
| from pathlib import Path | |
| import sys | |
| p = snapshot_download("codelion/sprog-9m"); sys.path.insert(0, p) | |
| from inference import load_model, solve | |
| model, stoi, cfg = load_model(Path(p)) | |
| print(solve(model, stoi, "If 3 pens cost $6, how much do 5 pens cost?")) | |
| """ | |
from __future__ import annotations

import argparse
import json
import math
import re
from collections import defaultdict
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten
| # --------------------------------------------------------------------------- | |
| # 1. number-slot tokenization (numbers -> [Ni]; spelled numbers -> digits) | |
| # --------------------------------------------------------------------------- | |
| _UNITS = {"zero": 0, "one": 1, "two": 2, "three": 3, "four": 4, "five": 5, | |
| "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10, "eleven": 11, | |
| "twelve": 12, "thirteen": 13, "fourteen": 14, "fifteen": 15, | |
| "sixteen": 16, "seventeen": 17, "eighteen": 18, "nineteen": 19} | |
| _TENS = {"twenty": 20, "thirty": 30, "forty": 40, "fifty": 50, "sixty": 60, | |
| "seventy": 70, "eighty": 80, "ninety": 90} | |
| _SCALES = {"hundred": 100, "thousand": 1000, "million": 1_000_000} | |
| OPS = {"+", "-", "*", "/", "**"} | |
| PAD, BOS, EOS, UNK = "<pad>", "<bos>", "<eos>", "<unk>" | |
| SPECIAL = [PAD, BOS, EOS, UNK] | |
| MAX_SLOTS = 20 | |
| SLOT_TOKENS = [f"[N{i}]" for i in range(MAX_SLOTS)] | |
| CONST_VALUES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 15, 16, 20, 24, 25, | |
| 30, 50, 52, 60, 100, 365, 1000, | |
| 0.01, 0.05, 0.1, 0.125, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, | |
| 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, | |
| 1.2, 1.25, 1.5, 2.5] | |
| CONST_TOKENS = [f"[C{c}]" for c in CONST_VALUES] | |
| TGT_VOCAB = SPECIAL + sorted(OPS) + SLOT_TOKENS + CONST_TOKENS | |
| TGT_STOI = {t: i for i, t in enumerate(TGT_VOCAB)} | |
| BOS_ID, EOS_ID = TGT_STOI[BOS], TGT_STOI[EOS] | |
| _NUM = re.compile(r"\d[\d,]*(?:\.\d+)?") | |
| _TOKEN = re.compile(r"\[N\d+\]|[a-z]+|\d+\.?\d*|[^\sa-z\d]") | |
| _WORDTOK = re.compile(r"[A-Za-z]+|\d[\d,]*\.?\d*") | |
def to_digits(text: str) -> str:
    """Rewrite spelled-out numbers in *text* as digit strings.

    Hyphens are first replaced by spaces (so "twenty-one" scans as two
    words). Consecutive number words form a run: units and tens are added,
    "hundred" multiplies the pending amount by 100, and the other scale
    words commit the pending amount times the scale to the running total.
    The word "and" inside a run is skipped. A run ends at the first token
    that is none of these, and the whole run is replaced by its value.

    When no run is found the original *text* is returned untouched;
    otherwise the result is built from the hyphen-stripped copy.

    NOTE(review): a run is closed at the end of the last token that belonged
    to it, so a trailing "and" is absorbed into the replaced span, and
    adjacent number words ("two and three") are summed. This is kept as-is
    because it is model-input preprocessing.
    """
    normalized = text.replace("-", " ")
    replacements = []        # (start, end, digit string) for each closed run
    start = None             # offset of the open run; None when no run is open
    committed = pending = 0  # value already scaled / value still accumulating
    prev_end = 0             # end offset of the previously scanned token

    for match in _WORDTOK.finditer(normalized):
        word = match.group().lower()
        in_run = start is not None
        small = _UNITS.get(word, _TENS.get(word))

        if small is not None:
            if not in_run:
                start = match.start()
            pending += small
        elif in_run and word == "hundred":
            pending = (pending or 1) * 100
        elif in_run and word in _SCALES:
            committed += (pending or 1) * _SCALES[word]
            pending = 0
        elif in_run and word == "and":
            pass  # connective inside a number, e.g. "one hundred and five"
        else:
            # Any other token closes the open run (if there is one).
            if in_run:
                replacements.append((start, prev_end, str(committed + pending)))
            start, committed, pending = None, 0, 0
        prev_end = match.end()

    # A run may still be open when the text ends.
    if start is not None:
        replacements.append((start, prev_end, str(committed + pending)))

    if not replacements:
        return text

    parts, cursor = [], 0
    for begin, end, digits in replacements:
        parts.append(normalized[cursor:begin])
        parts.append(digits)
        cursor = end
    parts.append(normalized[cursor:])
    return "".join(parts)
def slot_encode(text: str):
    """Abstract the numbers of *text* into slot tokens.

    The text is passed through to_digits() and lower-cased; every number is
    then replaced, in order of appearance, by " [Ni] " and its value (with
    thousands separators removed) is recorded as a float.

    Returns:
        (tokens, values): the tokenized slot-abstracted text and the list of
        slot values, or None when the text holds more than MAX_SLOTS numbers.
    """
    lowered = to_digits(text).lower()
    slot_values, chunks, cursor = [], [], 0
    for slot, hit in enumerate(_NUM.finditer(lowered)):
        if slot >= MAX_SLOTS:
            return None
        slot_values.append(float(hit.group().replace(",", "")))
        chunks.extend((lowered[cursor:hit.start()], f" [N{slot}] "))
        cursor = hit.end()
    chunks.append(lowered[cursor:])
    return _TOKEN.findall("".join(chunks)), slot_values
| # --------------------------------------------------------------------------- | |
| # 2. symbolic execution + verifier (0 trainable params) | |
| # --------------------------------------------------------------------------- | |
def decode_postfix(seq, values):
    """Execute a postfix program and return its numeric result.

    Args:
        seq: iterable of program tokens -- operators from OPS, slot tokens
            "[Ni]" and constant tokens "[C<value>]".
        values: slot values; "[Ni]" pushes values[i].

    Returns:
        The single value left on the stack, or None when the program is
        invalid: stack underflow, leftover operands, unknown token, slot
        index out of range, division by zero, a rejected exponentiation, or
        any operand/result that is complex or not finite.
    """
    stack = []
    for tok in seq:
        if tok in OPS:
            if len(stack) < 2:
                return None
            b, a = stack.pop(), stack.pop()
            if tok == "+":
                result = a + b
            elif tok == "-":
                result = a - b
            elif tok == "*":
                result = a * b
            elif tok == "/":
                if b == 0:
                    return None
                result = a / b
            else:  # "**"
                # Bound base/exponent and reject undefined real powers.
                if (abs(a) > 1e4 or abs(b) > 6
                        or (a < 0 and b != int(b))
                        or (a == 0 and b < 0)):
                    return None
                try:
                    result = a ** b
                except (OverflowError, ValueError, ZeroDivisionError):
                    return None
        elif tok.startswith("[N"):
            index = int(tok[2:-1])
            if index >= len(values):
                return None
            result = values[index]
        elif tok.startswith("[C"):
            result = float(tok[2:-1])
        else:
            return None
        # Float * and / overflow to inf silently (and inf - inf gives nan);
        # such values would later break int()/round(), so stop here.
        if isinstance(result, complex) or not math.isfinite(result):
            return None
        stack.append(result)
    return stack[0] if len(stack) == 1 else None
def _intermediates(prog, vals):
    """Return the result of every operator application in a postfix program.

    Args:
        prog: iterable of program tokens (operators, "[Ni]", "[C<value>]").
        vals: slot values; "[Ni]" pushes vals[i].

    Returns:
        List of intermediate results in evaluation order, or None when the
        program cannot be evaluated (stack underflow, unknown token, slot
        index out of range, division by zero, undefined power, arithmetic
        error, or a complex / non-finite intermediate).
    """
    stack, inter = [], []
    for tok in prog:
        if tok in OPS:
            if len(stack) < 2:
                return None
            b, a = stack.pop(), stack.pop()
            try:
                if tok == "+":
                    result = a + b
                elif tok == "-":
                    result = a - b
                elif tok == "*":
                    result = a * b
                elif tok == "/":
                    result = a / b if b != 0 else None
                elif tok == "**":
                    undefined = (a < 0 and b != int(b)) or (a == 0 and b < 0)
                    result = None if undefined else a ** b
                else:
                    result = None
            except (ArithmeticError, ValueError):
                # Overflow / zero division in the operation, or int() applied
                # to a nan/inf exponent.
                return None
            if result is None or isinstance(result, complex):
                return None
            # Callers round() these values, which raises on inf and nan.
            if not math.isfinite(result):
                return None
            stack.append(result)
            inter.append(result)
        elif tok.startswith("[N"):
            index = int(tok[2:-1])
            if index >= len(vals):
                return None
            stack.append(vals[index])
        elif tok.startswith("[C"):
            stack.append(float(tok[2:-1]))
        else:
            return None
    return inter
| def _plausible(a): | |
| return a is not None and a > 0 and abs(a - round(a)) < 1e-6 | |
def verify_select(question, candidates, mag_k=100.0):
    """Choose the final answer among sampled (program, answer) candidates.

    Each candidate program receives a structural score that is the sum of:
      * coverage -- distinct slot tokens used / number of slots in the question;
      * magnitude sanity -- 1 when the answer is <= mag_k * largest number
        written in the question;
      * 1 when every intermediate result is non-negative;
      * 1 when every intermediate result is (near-)integral.
    An answer's score is the best score over its programs. Among plausible
    answers (all answers if none is plausible) the one with the highest
    (score, vote count) wins.

    Args:
        question: the raw question text.
        candidates: iterable of (program_tokens, answer) pairs.
        mag_k: multiplier for the magnitude-sanity bound.

    Returns:
        The selected answer, or None when there is no usable candidate.
        Candidates whose answer is not finite are ignored.
    """
    enc = slot_encode(question)
    vals = enc[1] if enc else []
    navail = max(len(vals), 1)

    # Parse with the same pattern slot_encode() uses so "1,000" counts as
    # 1000 rather than as the two numbers 1 and 000.
    nums = [float(m.group().replace(",", "")) for m in _NUM.finditer(question)]
    maxn = max(nums) if nums else 1.0

    def _whole(x):
        # round() raises on inf/nan, so test finiteness first.
        return math.isfinite(x) and abs(x - round(x)) < 1e-6

    votes, V = defaultdict(float), defaultdict(lambda: -1e9)
    for prog, a in candidates:
        if not math.isfinite(a):
            continue  # never a valid answer; would also break rounding below
        votes[a] += 1.0
        cov = len({t for t in prog if t.startswith("[N")}) / navail
        mag = 1.0 if a <= mag_k * maxn else 0.0
        inter = _intermediates(prog, vals)
        ip = 1.0 if (inter is not None and all(x >= 0 for x in inter)) else 0.0
        ii = 1.0 if (inter is not None and all(_whole(x) for x in inter)) else 0.0
        V[a] = max(V[a], cov + mag + ip + ii)
    if not votes:
        return None
    plau = [a for a in votes if _plausible(a)] or list(votes)
    return max(plau, key=lambda a: (V[a], votes[a]))
| # --------------------------------------------------------------------------- | |
| # 3. the MLX seq2seq model | |
| # --------------------------------------------------------------------------- | |
def _pad_mask(pad):
    """Turn a boolean padding mask into an additive float32 attention mask.

    Two singleton axes are inserted before the last axis of *pad*; positions
    where *pad* is true get -1e9 and all others 0.0.
    """
    expanded = pad[:, None, None, :]
    additive = mx.where(expanded, -1e9, 0.0)
    return additive.astype(mx.float32)
def _causal_mask(L):
    """Build a (1, 1, L, L) additive mask with -1e9 strictly above the diagonal."""
    blocked = mx.full((L, L), -1e9)
    future_only = mx.triu(blocked, k=1)
    return future_only[None, None]
class FFN(nn.Module):
    """Feed-forward block: Linear(d -> ff), GELU, Linear(ff -> d)."""

    def __init__(self, d, ff):
        super().__init__()
        self.l1 = nn.Linear(d, ff)
        self.l2 = nn.Linear(ff, d)

    def __call__(self, x):
        hidden = nn.gelu(self.l1(x))
        return self.l2(hidden)
class EncLayer(nn.Module):
    """Encoder layer: LayerNorm is applied before self-attention and before
    the feed-forward block, each wrapped in a residual connection."""

    def __init__(self, d, h, ff):
        super().__init__()
        self.n1 = nn.LayerNorm(d)
        self.attn = nn.MultiHeadAttention(d, h)
        self.n2 = nn.LayerNorm(d)
        self.ffn = FFN(d, ff)

    def __call__(self, x, m):
        normed = self.n1(x)
        x = x + self.attn(normed, normed, normed, m)
        normed = self.n2(x)
        return x + self.ffn(normed)
class DecLayer(nn.Module):
    """Decoder layer: self-attention, cross-attention over the encoder
    memory, then feed-forward -- each preceded by LayerNorm and wrapped in a
    residual connection."""

    def __init__(self, d, h, ff):
        super().__init__()
        self.n1 = nn.LayerNorm(d)
        self.sa = nn.MultiHeadAttention(d, h)
        self.n2 = nn.LayerNorm(d)
        self.ca = nn.MultiHeadAttention(d, h)
        self.n3 = nn.LayerNorm(d)
        self.ffn = FFN(d, ff)

    def __call__(self, x, mem, tm, mm):
        # tm masks the target self-attention, mm masks attention over mem.
        normed = self.n1(x)
        x = x + self.sa(normed, normed, normed, tm)
        normed = self.n2(x)
        x = x + self.ca(normed, mem, mem, mm)
        normed = self.n3(x)
        return x + self.ffn(normed)
class Seq2Seq(nn.Module):
    """Encoder-decoder transformer with learned position embeddings.

    Attribute names are part of the parameter tree that load_model() fills
    from the weight file, so they must not be renamed.
    """

    def __init__(self, src_v, tgt_v, d=304, h=4, ne=4, nd=4, ff=608,
                 max_src=220, max_tgt=64):
        super().__init__()
        # Token and position embeddings for both sides.
        self.src_emb = nn.Embedding(src_v, d)
        self.tgt_emb = nn.Embedding(tgt_v, d)
        self.src_pos = nn.Embedding(max_src, d)
        self.tgt_pos = nn.Embedding(max_tgt, d)
        # Layer stacks.
        self.enc = [EncLayer(d, h, ff) for _ in range(ne)]
        self.dec = [DecLayer(d, h, ff) for _ in range(nd)]
        # Final normalisation and projection onto the target vocabulary.
        self.norm = nn.LayerNorm(d)
        self.out = nn.Linear(d, tgt_v)

    def encode(self, src, src_pad):
        """Encode source ids; return (memory, additive source padding mask)."""
        length = src.shape[1]
        positions = mx.arange(length)
        x = self.src_emb(src) + self.src_pos(positions)
        mask = _pad_mask(src_pad)
        for layer in self.enc:
            x = layer(x, mask)
        return x, mask

    def decode(self, tgt, mem, mm, tgt_pad):
        """Return target-vocabulary logits for every position of *tgt*."""
        length = tgt.shape[1]
        positions = mx.arange(length)
        x = self.tgt_emb(tgt) + self.tgt_pos(positions)
        # Causal mask combined with the target padding mask.
        self_mask = _causal_mask(length) + _pad_mask(tgt_pad)
        for layer in self.dec:
            x = layer(x, mem, self_mask, mm)
        return self.out(self.norm(x))
def _sample_batch(model, mem, mm, S, T, maxlen=48):
    """Sample S programs in parallel from the decoder.

    Args:
        model: the Seq2Seq model (its decode() and tgt_pos table are used).
        mem: encoder memory for one question (leading batch axis of size 1).
        mm: additive source mask returned by encode().
        S: number of samples to draw.
        T: sampling temperature; T <= 0 selects greedy (argmax) decoding
            instead of dividing the logits by zero or a negative number.
        maxlen: maximum number of generated tokens; clamped to the size of
            the target position table.

    Returns:
        A list of S token lists, each cut before its first EOS token.
    """
    # Position ids beyond the embedding table would index out of bounds.
    maxlen = min(maxlen, model.tgt_pos.weight.shape[0])
    memS = mx.broadcast_to(mem, (S, mem.shape[1], mem.shape[2]))
    mmS = mx.broadcast_to(mm, (S, mm.shape[1], mm.shape[2], mm.shape[3]))
    seqs = mx.full((S, 1), BOS_ID, dtype=mx.int32)
    finished = mx.zeros((S,), dtype=mx.bool_)
    for _ in range(maxlen):
        no_pad = mx.zeros((S, seqs.shape[1]), dtype=mx.bool_)
        last = model.decode(seqs, memS, mmS, no_pad)[:, -1, :]
        if T > 0:
            nxt = mx.random.categorical(last / T)
        else:
            nxt = mx.argmax(last, axis=-1)
        nxt = nxt.astype(mx.int32)
        seqs = mx.concatenate([seqs, nxt[:, None]], axis=1)
        # Tokens after a row's first EOS are discarded below, so once every
        # row has produced EOS further decoding cannot change the output.
        finished = mx.logical_or(finished, nxt == EOS_ID)
        if mx.all(finished).item():
            break
    out = []
    for row in np.array(seqs):
        toks = []
        for t in row[1:]:  # skip the BOS column
            if t == EOS_ID:
                break
            toks.append(TGT_VOCAB[int(t)])
        out.append(toks)
    return out
| # --------------------------------------------------------------------------- | |
| # 4. load + solve | |
| # --------------------------------------------------------------------------- | |
def load_model(model_dir: Path):
    """Load the model, source vocabulary and config from *model_dir*.

    Reads config.json, src_vocab.json and model.npz from the directory,
    builds a Seq2Seq with the configured sizes, loads the weights and
    switches the model to eval mode. All files are closed before returning.

    Returns:
        (model, stoi, cfg): the model, the source token -> id mapping
        (ids are positions in src_vocab.json) and the parsed config dict.
    """
    model_dir = Path(model_dir)
    with open(model_dir / "config.json", encoding="utf-8") as fh:
        cfg = json.load(fh)
    with open(model_dir / "src_vocab.json", encoding="utf-8") as fh:
        src_vocab = json.load(fh)
    stoi = {t: i for i, t in enumerate(src_vocab)}
    model = Seq2Seq(len(src_vocab), len(TGT_VOCAB), d=cfg["d"], h=cfg["n_heads"],
                    ne=cfg["n_layers"], nd=cfg["n_layers"], ff=cfg["ff"],
                    max_src=cfg["max_src"], max_tgt=cfg["max_tgt"])
    # np.load returns an NpzFile that keeps the archive open; convert every
    # array while it is open and let the context manager close it.
    with np.load(model_dir / "model.npz") as w:
        params = [(k, mx.array(w[k])) for k in w.files]
    model.update(tree_unflatten(params))
    model.eval()
    return model, stoi, cfg
def solve(model, stoi, question, n_samples=96, temp=0.9, seed=0, return_details=False):
    """Solve one question.

    The question is slot-encoded, n_samples programs are sampled at
    temperature temp, each program is executed symbolically and
    verify_select() picks the final answer among the valid results.

    Args:
        model: the Seq2Seq model.
        stoi: source token -> id mapping; must contain the UNK token.
        question: the question text.
        n_samples: number of programs to sample.
        temp: sampling temperature.
        seed: value passed to mx.random.seed before sampling.
        return_details: when true, return a dict with "answer",
            "n_candidates", "n_distinct" and "top_votes" instead of the
            bare answer.

    Returns:
        The answer (float) or the details dict; None when the question
        cannot be encoded (too many numbers, or no tokens at all) or when
        no sampled program yields a finite real result.
    """
    mx.random.seed(seed)
    enc = slot_encode(question)
    if enc is None:
        return None
    toks, vals = enc
    # The encoder has one learned embedding per source position; positions
    # past the table would index out of bounds, so the input is truncated.
    # NOTE(review): truncation vs. rejecting over-long input -- confirm
    # which one matches training.
    toks = toks[:model.src_pos.weight.shape[0]]
    if not toks:
        return None  # empty question: nothing to encode
    unk = stoi[UNK]
    S = mx.array([[stoi.get(t, unk) for t in toks]])
    mem, mm = model.encode(S, S == 0)
    cands = []
    for prog in _sample_batch(model, mem, mm, n_samples, temp):
        a = decode_postfix(prog, vals)
        if a is None or isinstance(a, complex):
            continue
        if not math.isfinite(a):
            continue  # inf/nan is never an answer and breaks rounding later
        cands.append((prog, round(float(a), 4)))
    if not cands:
        return None
    answer = verify_select(question, cands)
    if return_details:
        votes = defaultdict(int)
        for _, a in cands:
            votes[a] += 1
        return {"answer": answer, "n_candidates": len(cands),
                "n_distinct": len(votes),
                "top_votes": sorted(votes.items(), key=lambda x: -x[1])[:5]}
    return answer
def main():
    """Command-line entry point: load the model, solve one question, print it."""
    parser = argparse.ArgumentParser(description="SPROG-9M: LLM-free GSM8K solver")
    parser.add_argument("--question", "-q", required=True)
    parser.add_argument("--model-dir", default=str(Path(__file__).parent))
    parser.add_argument("--samples", type=int, default=96)
    parser.add_argument("--temp", type=float, default=0.9)
    parser.add_argument("--details", action="store_true")
    args = parser.parse_args()

    model, stoi, _cfg = load_model(Path(args.model_dir))
    result = solve(
        model,
        stoi,
        args.question,
        n_samples=args.samples,
        temp=args.temp,
        return_details=args.details,
    )
    # --details prints the diagnostic dict as JSON, otherwise only the answer.
    if args.details:
        print(json.dumps(result, indent=2))
        return
    print(f"Answer: {result}")


if __name__ == "__main__":
    main()