gary-4 / chat.py
gary23w's picture
Upload chat.py with huggingface_hub
267684e verified
Raw
History Blame Contribute Delete
2.65 kB
#!/usr/bin/env python3
"""gary-4: a ~69KB chat model. Runs anywhere with just numpy.
Usage: python chat.py (interactive)
python chat.py "hi" (one-shot)"""
import json, sys, os
import numpy as np
# Directory holding this script; config.json and the weight archive sit next to it.
D = os.path.dirname(os.path.abspath(__file__))
# JSON is UTF-8 by spec: be explicit so non-ASCII vocab entries decode the same on
# every platform, and close the file instead of leaking the handle.
with open(f"{D}/config.json", encoding="utf-8") as f:
    C = json.load(f)
chars, E, H, L, BLK = C["chars"], C["n_embd"], C["n_head"], C["n_layer"], C["block_size"]
# Character -> token id lookup.
stoi = {c: i for i, c in enumerate(chars)}
# Each tensor "<name>" is stored with a companion "<name>.scale" entry (presumably
# int8 values, per the file name); dequantize as float32 * scale. The archive is
# closed once every tensor has been materialized into W.
with np.load(f"{D}/gary4.int8.npz") as z:
    W = {k: z[k].astype(np.float32) * z[k + ".scale"] for k in z.files if not k.endswith(".scale")}
def ln(x, w, b, eps=1e-5):
    """Layer-normalize ``x`` over its last axis, then apply scale ``w`` and shift ``b``.

    Computes (x - mean) / sqrt(var + eps) * w + b, with mean and (biased)
    variance taken along the final axis.
    """
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    normed = (x - mu) / np.sqrt(var + eps)
    return normed * w + b
def gelu(x):
    """Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))."""
    # 0.7978845608 is sqrt(2 / pi).
    inner = 0.7978845608 * (x + 0.044715 * x**3)
    return 0.5 * x * (1 + np.tanh(inner))
def forward(idx):
    """Run the transformer over a sequence of token ids; return next-token logits.

    idx: sequence of token ids; its length must not exceed the rows of
    ``pos.weight`` (callers slice to the last BLK tokens).
    Returns the logits for the LAST position only, computed by projecting
    onto the transposed token-embedding matrix (tied input/output weights).

    Each of the L blocks is pre-norm: layer norm -> causal multi-head
    self-attention -> residual add, then layer norm -> GELU MLP -> residual add.
    """
    n_tok = len(idx)
    x = W["tok.weight"][idx] + W["pos.weight"][:n_tok]
    # -1e9 strictly above the diagonal: position t may only attend to positions <= t.
    causal = np.triu(np.full((n_tok, n_tok), -1e9), 1)
    head_dim = E // H
    scale = np.sqrt(head_dim)
    for layer in range(L):
        pre = f"blocks.{layer}."
        # Attention sub-layer: q, k, v come from one packed input projection.
        h = ln(x, W[pre + "ln1.weight"], W[pre + "ln1.bias"])
        qkv = h @ W[pre + "attn.in_proj_weight"].T + W[pre + "attn.in_proj_bias"]
        q, k, v = np.split(qkv, 3, -1)
        heads = np.zeros_like(q)
        for head in range(H):
            cols = slice(head * head_dim, (head + 1) * head_dim)
            scores = q[:, cols] @ k[:, cols].T / scale + causal
            # Softmax over keys, shifted by the row max for numerical stability.
            scores = np.exp(scores - scores.max(-1, keepdims=True))
            scores /= scores.sum(-1, keepdims=True)
            heads[:, cols] = scores @ v[:, cols]
        x = x + heads @ W[pre + "attn.out_proj.weight"].T + W[pre + "attn.out_proj.bias"]
        # MLP sub-layer.
        h = ln(x, W[pre + "ln2.weight"], W[pre + "ln2.bias"])
        h = gelu(h @ W[pre + "mlp.0.weight"].T + W[pre + "mlp.0.bias"])
        x = x + h @ W[pre + "mlp.2.weight"].T + W[pre + "mlp.2.bias"]
    x = ln(x, W["lnf.weight"], W["lnf.bias"])
    return x[-1] @ W["tok.weight"].T
def sample_next(logits, temp, rng):
    """Choose the next token id from a 1-D array of logits.

    temp > 0: sample from softmax(logits / temp) using ``rng``
    (a ``numpy.random.Generator``).
    temp <= 0: greedy decoding; return the argmax and ignore ``rng``.
    Dividing by temp == 0 would yield inf/NaN probabilities, which
    ``rng.choice`` rejects, hence the greedy branch.
    """
    logits = np.asarray(logits)
    if temp <= 0:
        return int(np.argmax(logits))
    scaled = logits / temp
    # Subtract the max before exp so large logits cannot overflow.
    p = np.exp(scaled - scaled.max())
    p /= p.sum()
    return int(rng.choice(len(p), p=p))


def generate(prompt, n=120, temp=0.7, seed=None):
    """Continue ``prompt`` one character at a time and return the new text.

    prompt: seed text; characters missing from the vocabulary are mapped to " ".
    n: maximum number of characters to generate.
    temp: sampling temperature; values <= 0 select greedy (argmax) decoding.
    seed: seed for ``numpy.random.default_rng`` (None = nondeterministic).

    Stops early once the generated text ends with two newlines. Returns the
    generated text (prompt excluded) with surrounding whitespace stripped.
    Raises ValueError if ``prompt`` is empty (there is no position to predict from).
    """
    rng = np.random.default_rng(seed)
    # Look the " " fallback up only when a character is unknown: the previous
    # stoi.get(c, stoi[" "]) evaluated stoi[" "] for every character and so
    # crashed on any vocab without a space, even for in-vocab prompts.
    idx = [stoi[c] if c in stoi else stoi[" "] for c in prompt]
    if not idx:
        raise ValueError("prompt must contain at least one character")
    out = ""
    for _ in range(n):
        # Condition on at most the last BLK tokens (the model's context length).
        nxt = sample_next(forward(idx[-BLK:]), temp, rng)
        out += chars[nxt]
        idx.append(nxt)
        if out.endswith("\n\n"):
            break
    return out.strip()
def reply(msg):
    """Return gary-4's answer to a single user message.

    Wraps ``msg`` in the "User: ...\\ngary-4:" chat template, generates a
    continuation, and keeps only the text before any "\\nUser:" turn the
    model starts writing on its own.
    """
    prompt = f"User: {msg}\ngary-4:"
    answer, _, _ = generate(prompt).partition("\nUser:")
    return answer.strip()
def main(argv=None):
    """Command-line entry: one-shot reply if arguments are given, else interactive chat.

    argv: arguments excluding the program name; defaults to sys.argv[1:].
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if args:
        print(reply(" ".join(args)))
        return
    print("gary-4 (~69KB). ctrl-c to exit.")
    # Handle Ctrl-C / EOF around the whole loop, not just input(), so an
    # interrupt during generation also exits cleanly as the banner promises.
    try:
        while True:
            msg = input("you: ")
            print("gary-4:", reply(msg))
    except (EOFError, KeyboardInterrupt):
        print()  # terminate the dangling "you: " line before the shell prompt returns


if __name__ == "__main__":
    main()