""" Hands-on playground for the 8bit-threshold-computer. Loads a safetensors model, reads its manifest, and exercises threshold circuits at every level: raw Boolean gates, 8-bit ALU arithmetic and comparators, multi-layer modular arithmetic, and a manifest-sized CPU runtime running a small assembled program end-to-end through the threshold weights. The CPU demo defaults to the small (1 KB) profile so the run finishes in a fraction of a second. Larger profiles (4 KB, 64 KB) take proportionally longer because every memory access decodes against every address line. Usage: python play.py # fast 1KB demo python play.py --model neural_computer.safetensors # full 64KB python play.py --model variants/neural_alu8.safetensors --skip-cpu # ALU only """ from __future__ import annotations import argparse import os import sys import torch from safetensors import safe_open sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) # Reuse the variant-aware CPU runtime from eval_all.py from eval_all import GenericThresholdCPU, builtin_program def heaviside(x): return (x >= 0).float() def load_tensors(path): out = {} with safe_open(path, framework="pt") as f: for name in f.keys(): out[name] = f.get_tensor(name).float() return out def main() -> int: parser = argparse.ArgumentParser(description="Threshold computer playground") parser.add_argument( "--model", type=str, default=os.path.join(os.path.dirname(__file__), "variants", "neural_computer8_small.safetensors"), help="Path to a .safetensors variant" ) parser.add_argument("--skip-cpu", action="store_true", help="Skip the CPU program demo (useful for pure-ALU files)") args = parser.parse_args() print("Loading", args.model) T = load_tensors(args.model) DATA_BITS = int(T["manifest.data_bits"].item()) ADDR_BITS = int(T["manifest.addr_bits"].item()) MEM_BYTES = int(T["manifest.memory_bytes"].item()) REGISTERS = int(T["manifest.registers"].item()) print(f"Manifest: data={DATA_BITS}-bit, addr={ADDR_BITS}-bit, mem={MEM_BYTES}B, regs={REGISTERS}") print(f"Tensors: {len(T):,}") print(f"Total params: {sum(t.numel() for t in T.values()):,}") print() def gate(name, inputs): w = T[name + ".weight"].view(-1) b = T[name + ".bias"].view(-1) return int(heaviside((torch.tensor(inputs, dtype=torch.float32) * w).sum() + b).item()) def xor(prefix, inputs): a, b_ = inputs h_or = gate(f"{prefix}.layer1.or", [a, b_]) h_nand = gate(f"{prefix}.layer1.nand", [a, b_]) return gate(f"{prefix}.layer2", [h_or, h_nand]) def xor_neuron(prefix, inputs): a, b_ = inputs h1 = gate(f"{prefix}.layer1.neuron1", [a, b_]) h2 = gate(f"{prefix}.layer1.neuron2", [a, b_]) return gate(f"{prefix}.layer2", [h1, h2]) def int_to_bits_msb(v, n): return [(v >> (n - 1 - i)) & 1 for i in range(n)] def bits_msb_to_int(bits): out = 0 for b in bits: out = (out << 1) | int(b) return out # ---------- Demo 1: Boolean gates ---------- print("=" * 64) print(" Demo 1: Boolean threshold gates") print("=" * 64) truth_2 = [(0, 0), (0, 1), (1, 0), (1, 1)] for gname in ["and", "or", "nand", "nor", "implies"]: row = " ".join(f"{a}{b}->{gate(f'boolean.{gname}', [a, b])}" for a, b in truth_2) print(f" {gname:8} {row}") for gname in ["xor", "xnor", "biimplies"]: row = " ".join(f"{a}{b}->{xor_neuron(f'boolean.{gname}', [a, b])}" for a, b in truth_2) print(f" {gname:8} {row}") print(f" not 0->{gate('boolean.not', [0])} 1->{gate('boolean.not', [1])}") print() # ---------- Demo 2: 8-bit ALU arithmetic ---------- print("=" * 64) print(" Demo 2: 8-bit ALU arithmetic (every gate is threshold logic)") print("=" * 64) def fa(prefix, a, b, cin): s1 = xor(f"{prefix}.ha1.sum", [a, b]) c1 = gate(f"{prefix}.ha1.carry", [a, b]) s2 = xor(f"{prefix}.ha2.sum", [s1, cin]) c2 = gate(f"{prefix}.ha2.carry", [s1, cin]) return s2, gate(f"{prefix}.carry_or", [c1, c2]) def alu_add(a, b): a_lsb = list(reversed(int_to_bits_msb(a, 8))) b_lsb = list(reversed(int_to_bits_msb(b, 8))) carry = 0 sum_lsb = [] for i in range(8): s, carry = fa(f"arithmetic.ripplecarry8bit.fa{i}", a_lsb[i], b_lsb[i], carry) sum_lsb.append(s) return bits_msb_to_int(list(reversed(sum_lsb))), carry def alu_sub(a, b): a_lsb = list(reversed(int_to_bits_msb(a, 8))) b_lsb = list(reversed(int_to_bits_msb(b, 8))) carry = 1 diff_lsb = [] for i in range(8): notb = gate(f"arithmetic.sub8bit.notb{i}", [b_lsb[i]]) x1 = xor(f"arithmetic.sub8bit.fa{i}.xor1", [a_lsb[i], notb]) x2 = xor(f"arithmetic.sub8bit.fa{i}.xor2", [x1, carry]) and1 = gate(f"arithmetic.sub8bit.fa{i}.and1", [a_lsb[i], notb]) and2 = gate(f"arithmetic.sub8bit.fa{i}.and2", [x1, carry]) carry = gate(f"arithmetic.sub8bit.fa{i}.or_carry", [and1, and2]) diff_lsb.append(x2) return bits_msb_to_int(list(reversed(diff_lsb))), carry def alu_compare(a, b, kind): # Walks the bit-cascade comparator family: per-bit gt/lt/eq, cascaded # eq_prefix, cascade.gt/lt, and the final OR/AND gates. Bit 0 is MSB. a_msb = int_to_bits_msb(a, 8) b_msb = int_to_bits_msb(b, 8) bit_gt = [gate(f"arithmetic.cmp8bit.bit{i}.gt", [a_msb[i], b_msb[i]]) for i in range(8)] bit_lt = [gate(f"arithmetic.cmp8bit.bit{i}.lt", [a_msb[i], b_msb[i]]) for i in range(8)] bit_eq = [] for i in range(8): eq_and = gate(f"arithmetic.cmp8bit.bit{i}.eq.layer1.and", [a_msb[i], b_msb[i]]) eq_nor = gate(f"arithmetic.cmp8bit.bit{i}.eq.layer1.nor", [a_msb[i], b_msb[i]]) bit_eq.append(gate(f"arithmetic.cmp8bit.bit{i}.eq", [eq_and, eq_nor])) cas_gt = [bit_gt[0]] cas_lt = [bit_lt[0]] for i in range(1, 8): eq_pref = gate(f"arithmetic.cmp8bit.cascade.eq_prefix.bit{i}", bit_eq[:i]) cas_gt.append(gate(f"arithmetic.cmp8bit.cascade.gt.bit{i}", [eq_pref, bit_gt[i]])) cas_lt.append(gate(f"arithmetic.cmp8bit.cascade.lt.bit{i}", [eq_pref, bit_lt[i]])) if kind == "greaterthan": return gate("arithmetic.greaterthan8bit", cas_gt) if kind == "lessthan": return gate("arithmetic.lessthan8bit", cas_lt) if kind == "eq": return gate("arithmetic.equality8bit", bit_eq) raise ValueError(kind) def alu_mul(a, b): a_bits = int_to_bits_msb(a, 8) b_bits = int_to_bits_msb(b, 8) result = 0 for j in range(8): if b_bits[j] == 0: continue row = 0 for i in range(8): pp = gate(f"alu.alu8bit.mul.pp.a{i}b{j}", [a_bits[i], b_bits[j]]) row |= (pp << (7 - i)) shift = 7 - j result, _ = alu_add(result & 0xFF, (row << shift) & 0xFF) return result & 0xFF cases_arith = [(5, 3), (37, 100), (200, 99), (255, 1), (127, 128), (15, 17)] print("ADD:") for a, b in cases_arith: r, c = alu_add(a, b) e = (a + b) & 0xFF print(f" {a:3} + {b:3} = {r:3} (carry={c}) expected {e:3} [{'OK' if r == e else 'FAIL'}]") print("SUB:") for a, b in cases_arith: r, c = alu_sub(a, b) e = (a - b) & 0xFF print(f" {a:3} - {b:3} = {r:3} (no_borrow={c}) expected {e:3} [{'OK' if r == e else 'FAIL'}]") print("CMP:") for a, b in [(50, 30), (30, 50), (77, 77), (255, 0), (0, 255), (128, 127)]: gt = alu_compare(a, b, "greaterthan") lt = alu_compare(a, b, "lessthan") eq = alu_compare(a, b, "eq") print(f" {a:3} vs {b:3} -> GT={gt} LT={lt} EQ={eq}") print("MUL (low 8 bits):") for a, b in [(12, 11), (15, 17), (8, 32), (200, 3), (0, 99), (1, 255)]: r = alu_mul(a, b) e = (a * b) & 0xFF print(f" {a:3} * {b:3} = {r:3} expected {e:3} [{'OK' if r == e else 'FAIL'}]") print() # ---------- Demo 3: mod-5 divisibility ---------- print("=" * 64) print(" Demo 3: mod-5 divisibility (multi-layer, hand-constructed)") print("=" * 64) def mod5(v): # Per-multiple-of-5 match (k0, k5, ..., k255): each k has 8 single-input # "bit{i}.match" gates that fire when bit i of v matches bit i of k, # ANDed by ".all". Final ".weight" ORs all 52 "all" outputs. bits = int_to_bits_msb(v, 8) ks = [k for k in range(256) if k % 5 == 0] alls = [] for k in ks: matches = [gate(f"modular.mod5.eq.k{k}.bit{i}.match", [bits[i]]) for i in range(8)] alls.append(gate(f"modular.mod5.eq.k{k}.all", matches)) return gate("modular.mod5", alls) hits = [v for v in range(256) if mod5(v)] print(f" v in [0,255] with mod5(v)==1: {len(hits)} hits, first 12: {hits[:12]}") print(f" Sanity (each %5): {[h % 5 for h in hits[:12]]}") print() # ---------- Demo 4: CPU running an assembled program ---------- if args.skip_cpu or MEM_BYTES < 0x84: if args.skip_cpu: print("Demo 4 skipped (--skip-cpu).") else: print(f"Demo 4 skipped (memory={MEM_BYTES}B too small for the demo program).") return 0 print("=" * 64) print(f" Demo 4: Threshold CPU running an assembled program ({MEM_BYTES} B memory)") print("=" * 64) print(" Program: sum 5+4+3+2+1 via loop") print(" uses LOAD/STORE/ADD/SUB/CMP/JNZ/HALT, all threshold-gated") print(" Running ... (larger memories take longer because every memory access") print(" decodes against every address line)") cpu = GenericThresholdCPU({k: v for k, v in T.items()}) mem, expected = builtin_program(ADDR_BITS) state = {"pc": 0, "regs": [0] * 4, "flags": [0] * 4, "mem": mem, "halted": False} final, cycles = cpu.run(state, max_cycles=200) got = final["mem"][0x83] print(f" Halted after {cycles} cycles") print(f" R0={final['regs'][0]} R1={final['regs'][1]} " f"R2={final['regs'][2]} R3={final['regs'][3]}") print(f" M[0x0083] = {got} (expected {expected}) [{'OK' if got == expected else 'FAIL'}]") return 0 if got == expected else 1 if __name__ == "__main__": sys.exit(main())