File size: 2,133 Bytes
57f9808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""Quantize the trained checkpoint to per-tensor symmetric int8 (gary-4 format:
each weight stored as int8 array + a float32 '.scale' scalar). Writes the release
dir: int8 + fp32 weights and config.json. Tokenizer-free — gary-neuron reads
digits directly, so there is no vocab to ship."""
import os, json, numpy as np

# Directory that holds this script; the default checkpoint lives next to it.
D = os.path.dirname(os.path.abspath(__file__))

# Input checkpoint path. Override with the CKPT environment variable.
_default_ckpt = f"{D}/final.npz"
CKPT = os.getenv("CKPT", _default_ckpt)

# Output directory. Override with the OUT environment variable; the default is
# the parent of this script's directory (the release root).
_default_out = os.path.abspath(f"{D}/../")
OUT = os.getenv("OUT", _default_out)
os.makedirs(OUT, exist_ok=True)

# Load the checkpoint and extract everything we need while the archive is open,
# so its underlying file handle is closed as soon as the `with` block exits
# (the original left the NpzFile open for the life of the process).
#
# NOTE(review): allow_pickle=True lets a crafted .npz execute arbitrary code on
# load. Only point CKPT at checkpoints you produced yourself. If "cfg" is saved
# as a plain string array, this flag can probably be dropped — confirm against
# the code that writes the checkpoint.
with np.load(CKPT, allow_pickle=True) as z:
    # Weights are the entries whose key starts with "P/"; strip that prefix and
    # cast to float32. astype() copies, so the arrays outlive the archive.
    P = {k[2:]: z[k].astype(np.float32) for k in z.files if k.startswith("P/")}
    # "cfg" is parsed as a JSON document, "step" as an integer.
    cfg = json.loads(str(z["cfg"]))
    step = int(z["step"])

# Per-tensor symmetric int8 quantization. For each weight tensor W:
#   scale = max|W| / 127   (falls back to 1e-8 when that is exactly 0)
#   q     = clip(round(W / scale), -127, 127) as int8
# `store` maps "<name>" -> int8 array and "<name>.scale" -> float32 scalar.
store = {}
total_int8 = 0  # raw (uncompressed) byte count of all int8 tensors
for k, Wt in P.items():
    # NaN/Inf would make the scale NaN/Inf and the int8 cast meaningless;
    # fail loudly instead of shipping garbage weights.
    if not np.all(np.isfinite(Wt)):
        raise ValueError(f"tensor {k!r} contains NaN/Inf; refusing to quantize")
    # max() of a zero-size array raises ValueError, so treat an empty tensor
    # as having peak 0 (which selects the 1e-8 fallback scale below).
    peak = float(np.abs(Wt).max()) if Wt.size else 0.0
    scale = peak / 127.0 or 1e-8
    q = np.clip(np.round(Wt / scale), -127, 127).astype(np.int8)
    store[k] = q
    store[k + ".scale"] = np.float32(scale)
    total_int8 += q.nbytes

# Write both the quantized and the float32 weights as compressed archives.
np.savez_compressed(f"{OUT}/gary-neuron.int8", **store)
np.savez_compressed(f"{OUT}/gary-neuron.fp32", **P)

# Total number of scalar parameters across all weight tensors.
nparams = int(sum(v.size for v in P.values()))

# Release metadata. Structural fields are copied from the checkpoint's cfg;
# the recommended_* and exact_match_* values are hard-coded here, not computed
# by this script.
config = {
    "model_type": "gary-neuron",
    "architecture": (f"asynchronous Neural Cellular Automaton (1-D strip, {cfg['S']} cells) "
                     f"with a top-{cfg['topk']} Mixture-of-Experts (K={cfg['K']}) per-cell update rule"),
    "task": "reversed-digit integer addition (Lee et al. 2023 format), up to 7-digit operands",
    "S": cfg["S"], "state_dim": cfg["d"], "expert_hidden": cfg["he"],
    "n_experts": cfg["K"], "topk": cfg["topk"],
    "train_steps": cfg["steps"], "p_update": cfg["p_update"],
    "recommended_inference_steps": 24, "recommended_vote": 9,
    "n_params": nparams, "trained_step": step,
    "exact_match_heldout_singleorder": 0.9997,
    "exact_match_heldout_vote9": 1.0,
    "dependencies": "numpy",
}

# Use a context manager so config.json is flushed and closed before anything
# later inspects it (the original passed an unclosed open() handle to json.dump).
with open(f"{OUT}/config.json", "w", encoding="utf-8") as fh:
    json.dump(config, fh, indent=1)
# Summary line, then the size of every regular file in the output directory,
# listed in sorted name order.
print("int8 raw bytes:", total_int8, "| params:", nparams, "| step:", step)
names = os.listdir(OUT)
names.sort()
for name in names:
    path = f"{OUT}/{name}"
    if not os.path.isfile(path):
        continue  # skip sub-directories and other non-regular entries
    size = os.path.getsize(path)
    print(f"  {name}: {size} bytes")