| |
| """gary-4: a ~69KB chat model. Runs anywhere with just numpy. |
| Usage: python chat.py (interactive) |
| python chat.py "hi" (one-shot)""" |
| import json, sys, os |
| import numpy as np |
|
|
# Directory containing this script; data files are resolved relative to it
# so the script works regardless of the current working directory.
D = os.path.dirname(os.path.abspath(__file__))

# Model hyperparameters and the character vocabulary. JSON is UTF-8 by spec,
# so decode explicitly instead of relying on the platform's default encoding,
# and close the file as soon as it has been parsed.
with open(os.path.join(D, "config.json"), encoding="utf-8") as _cfg:
    C = json.load(_cfg)
chars, E, H, L, BLK = C["chars"], C["n_embd"], C["n_head"], C["n_layer"], C["block_size"]
# Character -> token id (its position in `chars`).
stoi = {c: i for i, c in enumerate(chars)}

# Quantized weights: every tensor `k` is stored next to a `k + ".scale"`
# entry; casting the stored values (int8, per the filename) to float32 and
# multiplying by that scale dequantizes them. The NpzFile keeps the archive
# open, so use it as a context manager to release the handle after loading.
with np.load(os.path.join(D, "gary4.int8.npz")) as z:
    W = {
        k: z[k].astype(np.float32) * z[k + ".scale"]
        for k in z.files
        if not k.endswith(".scale")
    }
|
|
def ln(x, w, b, eps=1e-5):
    """Layer normalization over the last axis.

    Shifts and scales ``x`` to zero mean and unit variance along its final
    axis (``eps`` guards against division by zero), then applies the
    elementwise affine transform ``* w + b``.
    """
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    centered = x - mu
    return centered / np.sqrt(var + eps) * w + b
|
|
def gelu(x):
    """GELU activation using the tanh approximation.

    The constant 0.7978845608 is sqrt(2 / pi) rounded to ten decimals.
    """
    inner = 0.7978845608 * (x + 0.044715 * x**3)
    return 0.5 * x * (1 + np.tanh(inner))
|
|
def forward(idx):
    """Run the transformer over a token sequence and return next-token logits.

    Args:
        idx: sequence of token ids (indices into ``chars``). Positions are
            read from ``W["pos.weight"][:T]``, so the length is presumably
            expected to be at most ``BLK`` (``generate`` trims to that) --
            TODO confirm for any other caller.

    Returns:
        1-D array with one logit per row of ``W["tok.weight"]`` (one per token
        id) for the position following the last token. The output projection
        reuses the input embedding matrix (tied embeddings).
    """
    T = len(idx)
    head_dim = E // H
    scale = np.sqrt(head_dim)
    # Causal mask: -1e9 strictly above the diagonal, so after softmax a
    # position gives ~zero weight to any later position.
    causal = np.triu(np.full((T, T), -1e9), 1)

    x = W["tok.weight"][idx] + W["pos.weight"][:T]
    for layer in range(L):
        prefix = f"blocks.{layer}."

        # Self-attention sub-layer: pre-norm, multi-head, residual add.
        normed = ln(x, W[prefix + "ln1.weight"], W[prefix + "ln1.bias"])
        qkv = normed @ W[prefix + "attn.in_proj_weight"].T + W[prefix + "attn.in_proj_bias"]
        q, k, v = np.split(qkv, 3, -1)
        # Heads write into a buffer shaped/typed like q; each head owns a
        # contiguous column slice of width head_dim.
        heads = np.zeros_like(q)
        for head in range(H):
            cols = slice(head * head_dim, (head + 1) * head_dim)
            scores = q[:, cols] @ k[:, cols].T / scale + causal
            # Numerically stable softmax over the key axis.
            probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
            probs /= probs.sum(axis=-1, keepdims=True)
            heads[:, cols] = probs @ v[:, cols]
        x = x + heads @ W[prefix + "attn.out_proj.weight"].T + W[prefix + "attn.out_proj.bias"]

        # Feed-forward sub-layer: pre-norm, GELU MLP, residual add.
        normed = ln(x, W[prefix + "ln2.weight"], W[prefix + "ln2.bias"])
        hidden = gelu(normed @ W[prefix + "mlp.0.weight"].T + W[prefix + "mlp.0.bias"])
        x = x + hidden @ W[prefix + "mlp.2.weight"].T + W[prefix + "mlp.2.bias"]

    x = ln(x, W["lnf.weight"], W["lnf.bias"])
    last = x[-1]
    return last @ W["tok.weight"].T
|
|
def generate(prompt, n=120, temp=0.7, seed=None):
    """Sample up to ``n`` characters that continue ``prompt``.

    Args:
        prompt: non-empty seed text. Characters missing from the vocabulary
            are mapped to the space token.
        n: maximum number of characters to generate.
        temp: softmax temperature. A value <= 0 selects the most likely
            character at every step (greedy decoding) instead of dividing
            by zero and producing NaN probabilities.
        seed: optional seed for ``np.random.default_rng`` to make sampling
            reproducible.

    Returns:
        The generated continuation (prompt excluded), whitespace-stripped.
        Generation stops early once the output ends with a blank line.

    Raises:
        ValueError: if ``prompt`` is empty; the model needs at least one
            token of context to predict from.
    """
    if not prompt:
        raise ValueError("prompt must be non-empty")
    rng = np.random.default_rng(seed)
    idx = [stoi.get(c, stoi[" "]) for c in prompt]
    out = ""
    for _ in range(n):
        # Only the last BLK tokens fit the model's positional embeddings.
        logits = forward(idx[-BLK:])
        if temp <= 0:
            nxt = int(np.argmax(logits))
        else:
            logits = logits / temp
            # Stable softmax: subtract the max before exponentiating.
            p = np.exp(logits - logits.max())
            p /= p.sum()
            nxt = int(rng.choice(len(p), p=p))
        out += chars[nxt]
        idx.append(nxt)
        # A blank line marks the end of the model's turn.
        if out.endswith("\n\n"):
            break
    return out.strip()
|
|
def reply(msg):
    """Return gary-4's answer to a single user message.

    Wraps ``msg`` in the ``User: ...`` / ``gary-4:`` chat template, generates
    a continuation, and keeps only the text before any ``\\nUser:`` marker in
    case the model starts writing the next user turn itself.
    """
    prompt = f"User: {msg}\ngary-4:"
    completion = generate(prompt)
    answer, _sep, _rest = completion.partition("\nUser:")
    return answer.strip()
|
|
if __name__ == "__main__":
    if len(sys.argv) > 1:
        # One-shot mode: all command-line arguments form a single message.
        print(reply(" ".join(sys.argv[1:])))
    else:
        print("gary-4 (~69KB). ctrl-c to exit.")
        # Guard the whole loop, not just input(): ctrl-c pressed while a
        # reply is being generated should also exit cleanly, as advertised,
        # rather than dumping a KeyboardInterrupt traceback.
        try:
            while True:
                msg = input("you: ")
                print("gary-4:", reply(msg))
        except (EOFError, KeyboardInterrupt):
            # End the interrupted line so the shell prompt starts fresh.
            print()
|
|