#!/usr/bin/env python3
"""gary-4: a ~69KB chat model. Runs anywhere with just numpy.

Usage:
    python chat.py          (interactive)
    python chat.py "hi"     (one-shot)
"""
import json
import os
import sys

import numpy as np

# Directory holding this script; the config and weight files are read from it.
D = os.path.dirname(os.path.abspath(__file__))

# Model hyper-parameters and vocabulary.  The handle is closed as soon as the
# JSON is parsed, and the encoding is pinned so the vocabulary decodes the
# same way on every platform.
with open(f"{D}/config.json", encoding="utf-8") as _config_file:
    C = json.load(_config_file)

chars, E, H, L, BLK = (
    C["chars"], C["n_embd"], C["n_head"], C["n_layer"], C["block_size"]
)
stoi = {c: i for i, c in enumerate(chars)}

# Every array in the archive that is not itself a "<name>.scale" entry is cast
# to float32 and multiplied by its "<name>.scale" companion.  The archive is
# closed once all arrays have been materialised, so `z` must not be indexed
# after this block.
with np.load(f"{D}/gary4.int8.npz") as z:
    W = {
        k: z[k].astype(np.float32) * z[k + ".scale"]
        for k in z.files
        if not k.endswith(".scale")
    }


def ln(x, w, b, eps=1e-5):
    """Layer-normalise `x` over its last axis, then scale by `w` and add `b`."""
    m, v = x.mean(-1, keepdims=True), x.var(-1, keepdims=True)
    return (x - m) / np.sqrt(v + eps) * w + b


def gelu(x):
    """Return the tanh approximation of GELU applied elementwise to `x`."""
    # 0.7978845608 is sqrt(2 / pi).
    return 0.5 * x * (1 + np.tanh(0.7978845608 * (x + 0.044715 * x**3)))


def forward(idx):
    """Run the transformer over token ids `idx`; return next-token logits.

    Args:
        idx: non-empty sequence of integer token ids.  Only the first
            ``len(idx)`` rows of the position table are used, so callers are
            expected to pass at most ``BLK`` ids (``generate`` crops for this).

    Returns:
        1-D array of logits for the token following the last position, one
        entry per row of ``W["tok.weight"]``.
    """
    T = len(idx)
    x = W["tok.weight"][idx] + W["pos.weight"][:T]
    # Large negative values strictly above the diagonal stop each position
    # from attending to later positions.
    mask = np.triu(np.full((T, T), -1e9), 1)
    hd = E // H  # width of one attention head
    for i in range(L):
        p = f"blocks.{i}."

        # Self-attention sub-layer (pre-norm, residual).
        h = ln(x, W[p + "ln1.weight"], W[p + "ln1.bias"])
        qkv = h @ W[p + "attn.in_proj_weight"].T + W[p + "attn.in_proj_bias"]
        q, k, v = np.split(qkv, 3, -1)
        o = np.zeros_like(q)
        for j in range(H):
            s = slice(j * hd, (j + 1) * hd)
            att = q[:, s] @ k[:, s].T / np.sqrt(hd) + mask
            # Softmax, shifted by the row maximum for numerical stability.
            att = np.exp(att - att.max(-1, keepdims=True))
            att /= att.sum(-1, keepdims=True)
            o[:, s] = att @ v[:, s]
        x = x + o @ W[p + "attn.out_proj.weight"].T + W[p + "attn.out_proj.bias"]

        # Feed-forward sub-layer (pre-norm, residual).
        h = ln(x, W[p + "ln2.weight"], W[p + "ln2.bias"])
        h = gelu(h @ W[p + "mlp.0.weight"].T + W[p + "mlp.0.bias"])
        x = x + h @ W[p + "mlp.2.weight"].T + W[p + "mlp.2.bias"]

    x = ln(x, W["lnf.weight"], W["lnf.bias"])
    # The token-embedding matrix doubles as the output projection.
    return x[-1] @ W["tok.weight"].T


def generate(prompt: str, n: int = 120, temp: float = 0.7, seed=None) -> str:
    """Continue `prompt` one character at a time and return the new text.

    Args:
        prompt: conditioning text.  Characters missing from the vocabulary
            are replaced by the space character's id.
        n: maximum number of characters to generate.
        temp: softmax temperature.  ``0`` selects the most likely character
            at every step instead of sampling.
        seed: forwarded to ``np.random.default_rng`` for reproducible sampling.

    Returns:
        The generated continuation (prompt excluded), stripped of surrounding
        whitespace.  Generation stops early once the output ends with a blank
        line.

    Raises:
        ValueError: if `prompt` is empty; the model has nothing to condition on.
    """
    if not prompt:
        raise ValueError("prompt must not be empty")

    rng = np.random.default_rng(seed)
    unknown = stoi[" "]  # looked up once instead of once per character
    idx = [stoi.get(c, unknown) for c in prompt]
    out = ""
    for _ in range(n):
        # Only the most recent BLK tokens fit in the position table.
        logits = forward(idx[-BLK:])
        if temp == 0:
            # Dividing by zero would yield NaN probabilities; decode greedily.
            nxt = int(np.argmax(logits))
        else:
            logits = logits / temp
            p = np.exp(logits - logits.max())
            p /= p.sum()
            nxt = int(rng.choice(len(p), p=p))
        out += chars[nxt]
        idx.append(nxt)
        if out.endswith("\n\n"):
            break
    return out.strip()


def reply(msg: str) -> str:
    """Return the model's answer to a single user message `msg`.

    The message is wrapped in the ``User: ... / gary-4:`` prompt format, and
    anything the model generates from a following ``User:`` turn onward is
    discarded.
    """
    return generate(f"User: {msg}\ngary-4:").split("\nUser:")[0].strip()


def main() -> None:
    """Answer the command-line arguments once, or start an interactive chat."""
    if len(sys.argv) > 1:
        print(reply(" ".join(sys.argv[1:])))
        return

    print("gary-4 (~69KB). ctrl-c to exit.")
    while True:
        try:
            msg = input("you: ")
        except (EOFError, KeyboardInterrupt):
            break
        print("gary-4:", reply(msg))


if __name__ == "__main__":
    main()