| """ |
| Hands-on playground for the 8bit-threshold-computer. |
| |
| Loads a safetensors model, reads its manifest, and exercises threshold |
| circuits at every level: raw Boolean gates, 8-bit ALU arithmetic and |
| comparators, multi-layer modular arithmetic, and a manifest-sized CPU |
| runtime running a small assembled program end-to-end through the |
| threshold weights. |
| |
| The CPU demo defaults to the small (1 KB) profile so the run finishes in |
| a fraction of a second. Larger profiles (4 KB, 64 KB) take proportionally |
| longer because every memory access decodes against every address line. |
| |
| Usage: |
| python play.py # fast 1KB demo |
| python play.py --model neural_computer.safetensors # full 64KB |
| python play.py --model variants/neural_alu8.safetensors --skip-cpu # ALU only |
| """ |
|
|
| from __future__ import annotations |
| import argparse |
| import os |
| import sys |
|
|
| import torch |
| from safetensors import safe_open |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
| |
| from eval_all import GenericThresholdCPU, builtin_program |
|
|
|
|
def heaviside(x):
    """Unit-step activation: 1.0 wherever ``x >= 0``, 0.0 elsewhere.

    ``x`` is a torch tensor; the result is a float32 tensor of the same shape.
    """
    step_mask = x >= 0
    return step_mask.to(torch.float32)
|
|
|
|
def load_tensors(path):
    """Read every tensor from the safetensors file at ``path``.

    Returns a dict mapping each tensor name to its value converted with
    ``.float()``.
    """
    with safe_open(path, framework="pt") as handle:
        return {key: handle.get_tensor(key).float() for key in handle.keys()}
|
|
|
|
# Memory address main() reads the CPU demo's result from. The CPU demo is
# skipped unless the manifest's memory extends past this address
# (equivalent to the original ``MEM_BYTES < 0x84`` guard).
_CPU_RESULT_ADDR = 0x83

# Every multiple of 5 that fits in 8 bits. Hoisted to module level so that
# _mod5 does not rebuild the list on each of its 256 calls.
_MOD5_MULTIPLES = tuple(range(0, 256, 5))

# All input combinations for a two-input gate, in truth-table order.
_TRUTH_2 = ((0, 0), (0, 1), (1, 0), (1, 1))


def _banner(title):
    """Print ``title`` framed by two 64-character rules."""
    print("=" * 64)
    print(title)
    print("=" * 64)


def _verdict(ok):
    """Return the tag printed next to a checked result."""
    return "OK" if ok else "FAIL"


def _gate(tensors, name, inputs):
    """Evaluate the threshold neuron stored under ``name``.

    Computes ``heaviside(sum(inputs * weight) + bias)`` from the
    ``<name>.weight`` and ``<name>.bias`` tensors and returns it as an int.
    Raises ``KeyError`` if the model file has no such gate.
    """
    weight = tensors[name + ".weight"].view(-1)
    bias = tensors[name + ".bias"].view(-1)
    x = torch.tensor(inputs, dtype=torch.float32)
    return int(heaviside((x * weight).sum() + bias).item())


def _two_layer(tensors, prefix, hidden_names, inputs):
    """Evaluate a two-layer circuit: two hidden gates feeding ``layer2``."""
    a, b = inputs  # these circuits take exactly two inputs
    hidden = [_gate(tensors, f"{prefix}.layer1.{h}", [a, b]) for h in hidden_names]
    return _gate(tensors, f"{prefix}.layer2", hidden)


def _xor(tensors, prefix, inputs):
    """Two-layer circuit whose hidden gates are named ``or`` / ``nand``."""
    return _two_layer(tensors, prefix, ("or", "nand"), inputs)


def _xor_neuron(tensors, prefix, inputs):
    """Two-layer circuit whose hidden gates are named ``neuron1`` / ``neuron2``."""
    return _two_layer(tensors, prefix, ("neuron1", "neuron2"), inputs)


def _int_to_bits_msb(v, n):
    """Return the low ``n`` bits of ``v`` as a list, most significant first."""
    return [(v >> (n - 1 - i)) & 1 for i in range(n)]


def _bits_msb_to_int(bits):
    """Inverse of ``_int_to_bits_msb``: fold MSB-first bits into an int."""
    out = 0
    for bit in bits:
        out = (out << 1) | int(bit)
    return out


def _full_adder(tensors, prefix, a, b, cin):
    """One full adder built from two half adders; returns ``(sum, carry)``."""
    s1 = _xor(tensors, f"{prefix}.ha1.sum", [a, b])
    c1 = _gate(tensors, f"{prefix}.ha1.carry", [a, b])
    s2 = _xor(tensors, f"{prefix}.ha2.sum", [s1, cin])
    c2 = _gate(tensors, f"{prefix}.ha2.carry", [s1, cin])
    return s2, _gate(tensors, f"{prefix}.carry_or", [c1, c2])


def _alu_add(tensors, a, b):
    """8-bit ripple-carry add; returns ``(sum mod 256, carry_out)``."""
    a_lsb = _int_to_bits_msb(a, 8)[::-1]
    b_lsb = _int_to_bits_msb(b, 8)[::-1]
    carry = 0
    sum_lsb = []
    for i, (x, y) in enumerate(zip(a_lsb, b_lsb)):
        s, carry = _full_adder(
            tensors, f"arithmetic.ripplecarry8bit.fa{i}", x, y, carry
        )
        sum_lsb.append(s)
    return _bits_msb_to_int(sum_lsb[::-1]), carry


def _alu_sub(tensors, a, b):
    """8-bit subtract; returns ``(difference mod 256, final carry)``.

    Each bit of ``b`` passes through its ``notb{i}`` gate and the chain
    starts with carry-in 1; the demo output labels the final carry
    ``no_borrow``.
    """
    a_lsb = _int_to_bits_msb(a, 8)[::-1]
    b_lsb = _int_to_bits_msb(b, 8)[::-1]
    carry = 1
    diff_lsb = []
    for i, (x, y) in enumerate(zip(a_lsb, b_lsb)):
        stage = f"arithmetic.sub8bit.fa{i}"
        notb = _gate(tensors, f"arithmetic.sub8bit.notb{i}", [y])
        x1 = _xor(tensors, f"{stage}.xor1", [x, notb])
        x2 = _xor(tensors, f"{stage}.xor2", [x1, carry])
        and1 = _gate(tensors, f"{stage}.and1", [x, notb])
        and2 = _gate(tensors, f"{stage}.and2", [x1, carry])
        carry = _gate(tensors, f"{stage}.or_carry", [and1, and2])
        diff_lsb.append(x2)
    return _bits_msb_to_int(diff_lsb[::-1]), carry


def _alu_compare(tensors, a, b):
    """Run the 8-bit comparator once and return ``(gt, lt, eq)``.

    Bits are processed MSB first. For each position after the first, the
    ``eq_prefix`` gate receives the per-bit equality outputs of all more
    significant positions, and its output is combined with that position's
    gt / lt bit by the cascade gates.
    """
    pairs = list(zip(_int_to_bits_msb(a, 8), _int_to_bits_msb(b, 8)))
    bit_gt = [
        _gate(tensors, f"arithmetic.cmp8bit.bit{i}.gt", [x, y])
        for i, (x, y) in enumerate(pairs)
    ]
    bit_lt = [
        _gate(tensors, f"arithmetic.cmp8bit.bit{i}.lt", [x, y])
        for i, (x, y) in enumerate(pairs)
    ]
    bit_eq = []
    for i, (x, y) in enumerate(pairs):
        eq_and = _gate(tensors, f"arithmetic.cmp8bit.bit{i}.eq.layer1.and", [x, y])
        eq_nor = _gate(tensors, f"arithmetic.cmp8bit.bit{i}.eq.layer1.nor", [x, y])
        bit_eq.append(
            _gate(tensors, f"arithmetic.cmp8bit.bit{i}.eq", [eq_and, eq_nor])
        )
    cas_gt = [bit_gt[0]]
    cas_lt = [bit_lt[0]]
    for i in range(1, 8):
        eq_pref = _gate(
            tensors, f"arithmetic.cmp8bit.cascade.eq_prefix.bit{i}", bit_eq[:i]
        )
        cas_gt.append(
            _gate(tensors, f"arithmetic.cmp8bit.cascade.gt.bit{i}", [eq_pref, bit_gt[i]])
        )
        cas_lt.append(
            _gate(tensors, f"arithmetic.cmp8bit.cascade.lt.bit{i}", [eq_pref, bit_lt[i]])
        )
    return (
        _gate(tensors, "arithmetic.greaterthan8bit", cas_gt),
        _gate(tensors, "arithmetic.lessthan8bit", cas_lt),
        _gate(tensors, "arithmetic.equality8bit", bit_eq),
    )


def _alu_mul(tensors, a, b):
    """Shift-and-add multiply; returns the low 8 bits of ``a * b``.

    Partial-product bits come from the ``mul.pp`` gates and the rows are
    accumulated through ``_alu_add``. Rows whose multiplier bit is 0 are
    skipped without evaluating their gates.
    """
    a_bits = _int_to_bits_msb(a, 8)
    b_bits = _int_to_bits_msb(b, 8)
    result = 0
    for j, b_bit in enumerate(b_bits):
        if b_bit == 0:
            continue
        row = 0
        for i, a_bit in enumerate(a_bits):
            pp = _gate(tensors, f"alu.alu8bit.mul.pp.a{i}b{j}", [a_bit, b_bit])
            row |= pp << (7 - i)
        # Index j is MSB-first, so this row carries weight 2**(7 - j).
        result, _ = _alu_add(tensors, result & 0xFF, (row << (7 - j)) & 0xFF)
    return result & 0xFF


def _mod5(tensors, v):
    """Evaluate the ``modular.mod5`` circuit on the 8-bit value ``v``.

    For every multiple of 5 in [0, 255], eight per-bit ``match`` gates feed
    an ``all`` gate; the 52 ``all`` outputs feed the final gate.
    """
    bits = _int_to_bits_msb(v, 8)
    alls = []
    for k in _MOD5_MULTIPLES:
        matches = [
            _gate(tensors, f"modular.mod5.eq.k{k}.bit{i}.match", [bit])
            for i, bit in enumerate(bits)
        ]
        alls.append(_gate(tensors, f"modular.mod5.eq.k{k}.all", matches))
    return _gate(tensors, "modular.mod5", alls)


def _truth_row(evaluate, tensors, name):
    """Format the four-entry truth table of a two-input circuit."""
    return " ".join(
        f"{a}{b}->{evaluate(tensors, name, [a, b])}" for a, b in _TRUTH_2
    )


def _demo_boolean(tensors):
    """Demo 1: print truth tables for the ``boolean.*`` gates."""
    _banner(" Demo 1: Boolean threshold gates")
    for gname in ["and", "or", "nand", "nor", "implies"]:
        row = _truth_row(_gate, tensors, f"boolean.{gname}")
        print(f" {gname:8} {row}")
    for gname in ["xor", "xnor", "biimplies"]:
        row = _truth_row(_xor_neuron, tensors, f"boolean.{gname}")
        print(f" {gname:8} {row}")
    not0 = _gate(tensors, "boolean.not", [0])
    not1 = _gate(tensors, "boolean.not", [1])
    print(f" not 0->{not0} 1->{not1}")
    print()


def _demo_alu(tensors):
    """Demo 2: run ADD / SUB / CMP / MUL cases and check them against Python."""
    _banner(" Demo 2: 8-bit ALU arithmetic (every gate is threshold logic)")
    cases_arith = [(5, 3), (37, 100), (200, 99), (255, 1), (127, 128), (15, 17)]
    print("ADD:")
    for a, b in cases_arith:
        r, c = _alu_add(tensors, a, b)
        e = (a + b) & 0xFF
        print(f" {a:3} + {b:3} = {r:3} (carry={c}) expected {e:3} [{_verdict(r == e)}]")
    print("SUB:")
    for a, b in cases_arith:
        r, c = _alu_sub(tensors, a, b)
        e = (a - b) & 0xFF
        print(f" {a:3} - {b:3} = {r:3} (no_borrow={c}) expected {e:3} [{_verdict(r == e)}]")
    print("CMP:")
    for a, b in [(50, 30), (30, 50), (77, 77), (255, 0), (0, 255), (128, 127)]:
        gt, lt, eq = _alu_compare(tensors, a, b)
        print(f" {a:3} vs {b:3} -> GT={gt} LT={lt} EQ={eq}")
    print("MUL (low 8 bits):")
    for a, b in [(12, 11), (15, 17), (8, 32), (200, 3), (0, 99), (1, 255)]:
        r = _alu_mul(tensors, a, b)
        e = (a * b) & 0xFF
        print(f" {a:3} * {b:3} = {r:3} expected {e:3} [{_verdict(r == e)}]")
    print()


def _demo_mod5(tensors):
    """Demo 3: list which values in [0, 255] the mod-5 circuit accepts."""
    _banner(" Demo 3: mod-5 divisibility (multi-layer, hand-constructed)")
    hits = [v for v in range(256) if _mod5(tensors, v)]
    print(f" v in [0,255] with mod5(v)==1: {len(hits)} hits, first 12: {hits[:12]}")
    print(f" Sanity (each %5): {[h % 5 for h in hits[:12]]}")
    print()


def _demo_cpu(tensors, addr_bits, mem_bytes):
    """Demo 4: run the built-in program on the threshold CPU.

    Returns 0 when the value at ``_CPU_RESULT_ADDR`` matches the expected
    result supplied by ``builtin_program``, otherwise 1.
    """
    _banner(f" Demo 4: Threshold CPU running an assembled program ({mem_bytes} B memory)")
    print(" Program: sum 5+4+3+2+1 via loop")
    print(" uses LOAD/STORE/ADD/SUB/CMP/JNZ/HALT, all threshold-gated")
    print(" Running ... (larger memories take longer because every memory access")
    print(" decodes against every address line)")
    # The CPU gets its own shallow copy of the tensor mapping.
    cpu = GenericThresholdCPU(dict(tensors))
    mem, expected = builtin_program(addr_bits)
    state = {"pc": 0, "regs": [0] * 4, "flags": [0] * 4, "mem": mem, "halted": False}
    final, cycles = cpu.run(state, max_cycles=200)
    got = final["mem"][_CPU_RESULT_ADDR]
    regs = final["regs"]
    print(f" Halted after {cycles} cycles")
    print(f" R0={regs[0]} R1={regs[1]} "
          f"R2={regs[2]} R3={regs[3]}")
    print(f" M[0x0083] = {got} (expected {expected}) [{_verdict(got == expected)}]")
    return 0 if got == expected else 1


def main() -> int:
    """Load a model, print its manifest, and run the four demos in order.

    Command-line flags:
        --model     path to a .safetensors variant (defaults to the small
                    variant under ``variants/`` next to this script)
        --skip-cpu  skip the CPU program demo

    Returns the process exit code: 1 only when the CPU demo ran and produced
    the wrong result, otherwise 0. Mismatches in the ALU demo are printed as
    FAIL but do not affect the exit code.
    """
    parser = argparse.ArgumentParser(description="Threshold computer playground")
    parser.add_argument(
        "--model", type=str,
        default=os.path.join(os.path.dirname(__file__),
                             "variants", "neural_computer8_small.safetensors"),
        help="Path to a .safetensors variant"
    )
    parser.add_argument("--skip-cpu", action="store_true",
                        help="Skip the CPU program demo (useful for pure-ALU files)")
    args = parser.parse_args()

    print("Loading", args.model)
    tensors = load_tensors(args.model)

    # The manifest is stored as scalar tensors alongside the gate weights.
    data_bits = int(tensors["manifest.data_bits"].item())
    addr_bits = int(tensors["manifest.addr_bits"].item())
    mem_bytes = int(tensors["manifest.memory_bytes"].item())
    registers = int(tensors["manifest.registers"].item())
    print(f"Manifest: data={data_bits}-bit, addr={addr_bits}-bit, mem={mem_bytes}B, regs={registers}")
    print(f"Tensors: {len(tensors):,}")
    print(f"Total params: {sum(t.numel() for t in tensors.values()):,}")
    print()

    _demo_boolean(tensors)
    _demo_alu(tensors)
    _demo_mod5(tensors)

    if args.skip_cpu:
        print("Demo 4 skipped (--skip-cpu).")
        return 0
    if mem_bytes <= _CPU_RESULT_ADDR:
        print(f"Demo 4 skipped (memory={mem_bytes}B too small for the demo program).")
        return 0
    return _demo_cpu(tensors, addr_bits, mem_bytes)
|
|
|
|
# Script entry point: hand main()'s return value to the interpreter as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
|