# mathstral-nano-sat / modeling_mathstral_nano.py
# Uploaded by jvaldi (commit 65c3c06, verified)
"""
MathstralNano: tiny 4-layer GPT-style transformer trained on SAT math problems.
Load and run inference entirely in NumPy β€” no PyTorch required.
Usage:
from modeling_mathstral_nano import MathstralNano
model = MathstralNano.from_pretrained(".") # path to repo folder
print(model.generate("Problem: If 2x+5=13, find x. Solution:"))
"""
import numpy as np, math, json, os
from scipy.special import softmax as sp_softmax # pip install scipy
class MathstralNano:
    """Tiny causal transformer for SAT-level math problem solving.

    Pure-NumPy inference — no PyTorch required. Weights live in a flat
    ``{name: np.ndarray}`` dict (loaded from safetensors) and the
    hyper-parameters (``d_model``, ``n_head``, ``n_layer``, ``seq_len``,
    ``vocab_size``, ``total_params``) come from ``config.json``.
    Tokenization is byte-level: UTF-8 bytes are the token ids, so the
    model assumes ``vocab_size <= 256`` (sampled ids are decoded with
    ``bytes``) — TODO confirm against the shipped config.
    """

    # Prompt template used by `solve`.
    SYSTEM = "Problem: {question} Solution:"

    def __init__(self, params: dict, config: dict):
        self.P = params    # flat weight dict: name -> np.ndarray
        self.cfg = config  # hyper-parameter dict from config.json

    # ── class method: load from a local directory ─────────────────────────────
    @classmethod
    def from_pretrained(cls, model_dir: str) -> "MathstralNano":
        """Load weights and config from a local folder or HF repo path.

        Raises:
            ImportError: if the optional ``safetensors`` package is missing.
        """
        # Keep the try body minimal: only the import can legitimately raise
        # ImportError here; a failure inside load_file must not be masked.
        try:
            from safetensors.numpy import load_file
        except ImportError as err:
            raise ImportError("pip install safetensors") from err
        params = load_file(os.path.join(model_dir, "model.safetensors"))
        with open(os.path.join(model_dir, "config.json")) as f:
            config = json.load(f)
        return cls(params, config)

    # ── internal helpers ──────────────────────────────────────────────────────
    @staticmethod
    def _encode(text: str, length: int) -> list:
        """UTF-8 byte-encode `text`, truncated then zero-padded to `length`."""
        ids = list(text.encode("utf-8", errors="replace"))[:length]
        ids += [0] * (length - len(ids))
        return ids

    @staticmethod
    def _layer_norm(x, w, b, eps=1e-5):
        """Layer norm over the last axis.

        NOTE: eps is added to the *std* (not the variance) — matches how
        the weights were presumably trained, so keep it this way.
        """
        mu = x.mean(-1, keepdims=True)
        std = x.std(-1, keepdims=True)
        return w * (x - mu) / (std + eps) + b

    @staticmethod
    def _gelu(x):
        """GELU activation, tanh approximation (as in GPT-2)."""
        c = math.sqrt(2 / math.pi)
        return 0.5 * x * (1 + np.tanh(c * (x + 0.044715 * x ** 3)))

    def _forward(self, x_ids: np.ndarray):
        """Run the transformer stack: x_ids (B, T) -> logits (B, T, vocab).

        T may be any length up to ``seq_len`` (pos_emb is sliced to T).
        """
        P = self.P
        B, T = x_ids.shape
        D = self.cfg["d_model"]
        H = self.cfg["n_head"]
        DH = D // H  # per-head dimension
        x = P["tok_emb"][x_ids] + P["pos_emb"][:T]
        # Additive causal mask: -1e9 above the diagonal blocks future tokens.
        causal = np.triu(np.full((T, T), -1e9), k=1)
        for i in range(self.cfg["n_layer"]):
            n = f"L{i}"
            # Pre-norm multi-head self-attention with residual connection.
            x_ln = self._layer_norm(x, P[f"{n}_ln1_w"], P[f"{n}_ln1_b"])
            qkv = x_ln @ P[f"{n}_qkv"] + P[f"{n}_qkv_b"]
            q, k, v = np.split(qkv, 3, axis=-1)
            q = q.reshape(B, T, H, DH).transpose(0, 2, 1, 3)
            k = k.reshape(B, T, H, DH).transpose(0, 2, 1, 3)
            v = v.reshape(B, T, H, DH).transpose(0, 2, 1, 3)
            sc = q @ k.transpose(0, 1, 3, 2) / math.sqrt(DH) + causal
            attn = sp_softmax(sc, axis=-1)
            ctx = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, D)
            x = x + ctx @ P[f"{n}_proj"] + P[f"{n}_proj_b"]
            # Pre-norm MLP with residual connection.
            x_ln2 = self._layer_norm(x, P[f"{n}_ln2_w"], P[f"{n}_ln2_b"])
            h1 = self._gelu(x_ln2 @ P[f"{n}_fc1"] + P[f"{n}_fc1_b"])
            x = x + h1 @ P[f"{n}_fc2"] + P[f"{n}_fc2_b"]
        x_out = self._layer_norm(x, P["ln_f_w"], P["ln_f_b"])
        return x_out @ P["head"]

    # ── public inference API ──────────────────────────────────────────────────
    def generate(
        self,
        prompt: str,
        max_new: int = 80,
        temperature: float = 0.8,
        seed: "int | None" = None,
    ) -> str:
        """Generate a completion for the given prompt string.

        Args:
            prompt: input text; conditioned on its last ``seq_len`` bytes.
            max_new: number of byte tokens to sample.
            temperature: softmax temperature (clamped to >= 1e-6).
            seed: optional RNG seed for reproducible sampling.

        Returns:
            The generated continuation (prompt not included).
        """
        rng = np.random.default_rng(seed)
        SEQ = self.cfg["seq_len"]
        # BUG FIX: keep the *unpadded* token ids and forward only the real
        # context. The old code zero-padded to SEQ and then sampled from
        # logits at the final (padding) position instead of the last real
        # prompt token. The causal mask makes trailing positions invisible
        # to earlier ones, so forwarding the unpadded context gives the
        # correct last-token logits directly. Long prompts now keep their
        # most recent SEQ bytes rather than the first ones.
        ids = list(prompt.encode("utf-8", errors="replace"))[-SEQ:] or [0]
        out = []
        for _ in range(max_new):
            ctx = ids[-SEQ:]  # slide the window only once it is full
            logits = self._forward(np.array([ctx]))
            last = logits[0, -1, :].astype(np.float64)
            # Temperature-scaled, numerically stable softmax.
            last = (last - last.max()) / max(temperature, 1e-6)
            probs = np.exp(last)
            probs /= probs.sum()
            tok = int(rng.choice(self.cfg["vocab_size"], p=probs))
            out.append(tok)
            ids.append(tok)
        return bytes(out).decode("utf-8", errors="replace")

    def solve(self, question: str, **kwargs) -> str:
        """Convenience wrapper: formats the SAT-style prompt automatically."""
        prompt = self.SYSTEM.format(question=question.strip())
        return prompt + self.generate(prompt, **kwargs)

    def __repr__(self):
        c = self.cfg
        return (
            f"MathstralNano("
            f"{c['n_layer']}L {c['n_head']}H {c['d_model']}d "
            f"params={c['total_params']:,})"
        )