| """Quantize the trained checkpoint to per-tensor symmetric int8 (gary-4 format: |
| each weight stored as int8 array + a float32 '.scale' scalar). Writes the release |
| dir: int8 + fp32 weights and config.json. Tokenizer-free — gary-neuron reads |
| digits directly, so there is no vocab to ship.""" |
| import os, json, numpy as np |
|
|
# Resolve the locations this script works with. Both paths can be overridden
# through the environment: CKPT is the checkpoint that gets read, OUT is the
# release directory that gets written.
D = os.path.dirname(os.path.abspath(__file__))  # directory holding this script
_default_ckpt = f"{D}/final.npz"
_default_out = os.path.abspath(f"{D}/../")  # parent of the script directory
CKPT = os.getenv("CKPT", _default_ckpt)
OUT = os.getenv("OUT", _default_out)
os.makedirs(OUT, exist_ok=True)  # no-op when the directory already exists
|
|
# Load the trained checkpoint. Three kinds of entries are read from the
# archive: parameter tensors stored under keys prefixed with "P/", a
# JSON-encoded "cfg" entry, and a "step" counter.
#
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code -- only point CKPT at checkpoints you trust. It is kept
# because the "cfg"/"step" entries may have been saved as object arrays (cannot
# be confirmed from this script alone).
with np.load(CKPT, allow_pickle=True) as z:
    # The archive is read member by member on access, so everything needed is
    # materialised before the file is closed. astype() returns fresh float32
    # arrays that do not depend on the open archive.
    P = {
        k[len("P/"):]: z[k].astype(np.float32)
        for k in z.files
        if k.startswith("P/")
    }
    cfg = json.loads(str(z["cfg"]))
    step = int(z["step"])
|
|
def _quantize_int8(w, name=""):
    """Quantize one float tensor to per-tensor symmetric int8.

    Args:
        w: numpy array of weights.
        name: tensor name, used only in the error message.

    Returns:
        ``(q, scale)`` where ``q`` is an int8 array with values in
        [-127, 127] and ``scale`` is a float32 scalar such that
        ``q * scale`` approximates ``w``.

    Raises:
        ValueError: if ``w`` contains NaN or infinite values, which cannot be
            represented in int8 and would otherwise be cast to garbage.
    """
    # A zero-size tensor has no maximum (np.max would raise), so use 0 instead.
    peak = float(np.abs(w).max()) if w.size else 0.0
    if not np.isfinite(peak):
        raise ValueError(f"tensor {name!r} contains NaN or Inf; refusing to quantize")
    # For an all-zero (or empty) tensor the quotient is 0.0, which is falsy, so
    # the scale falls back to 1e-8 and the division below stays well defined.
    scale = peak / 127.0 or 1e-8
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, np.float32(scale)


# Build the int8 archive contents: each tensor is stored under its own name,
# with its scale stored next to it under "<name>.scale".
store = {}
total_int8 = 0  # raw (uncompressed) size of all int8 tensors, in bytes
for name, weights in P.items():
    q, scale = _quantize_int8(weights, name)
    store[name] = q
    store[name + ".scale"] = scale
    total_int8 += q.nbytes

# np.savez_compressed appends ".npz" when the file name does not already end
# with it, so these land on disk as gary-neuron.int8.npz / gary-neuron.fp32.npz.
np.savez_compressed(f"{OUT}/gary-neuron.int8", **store)
np.savez_compressed(f"{OUT}/gary-neuron.fp32", **P)
|
|
# Total number of scalar weights across all parameter tensors.
nparams = int(sum(v.size for v in P.values()))

# Release metadata written next to the weights. Architecture fields are copied
# from the checkpoint's cfg; the recommended inference settings and the
# exact-match figures are hard-coded constants, not computed by this script.
config = {
    "model_type": "gary-neuron",
    "architecture": (f"asynchronous Neural Cellular Automaton (1-D strip, {cfg['S']} cells) "
                     f"with a top-{cfg['topk']} Mixture-of-Experts (K={cfg['K']}) per-cell update rule"),
    "task": "reversed-digit integer addition (Lee et al. 2023 format), up to 7-digit operands",
    "S": cfg["S"], "state_dim": cfg["d"], "expert_hidden": cfg["he"],
    "n_experts": cfg["K"], "topk": cfg["topk"],
    "train_steps": cfg["steps"], "p_update": cfg["p_update"],
    "recommended_inference_steps": 24, "recommended_vote": 9,
    "n_params": nparams, "trained_step": step,
    "exact_match_heldout_singleorder": 0.9997,
    "exact_match_heldout_vote9": 1.0,
    "dependencies": "numpy",
}

# Write through a context manager so the file is flushed and closed before
# anything else inspects the output directory, instead of relying on the
# garbage collector to close an anonymous file object.
with open(f"{OUT}/config.json", "w") as fh:
    json.dump(config, fh, indent=1)
# Final report: headline numbers first, then the size of every regular file
# in the release directory, in sorted name order.
print("int8 raw bytes:", total_int8, "| params:", nparams, "| step:", step)
for entry in sorted(os.listdir(OUT)):
    entry_path = f"{OUT}/{entry}"
    if not os.path.isfile(entry_path):
        continue  # skip anything that is not a regular file
    size = os.path.getsize(entry_path)
    print(f" {entry}: {size} bytes")
|
|