# CharlesCNorton
# Update play.py and eval.py to current bit-cascade comparator/mod-5 layout
# df99f2e
"""
Hands-on playground for the 8bit-threshold-computer.
Loads a safetensors model, reads its manifest, and exercises threshold
circuits at every level: raw Boolean gates, 8-bit ALU arithmetic and
comparators, multi-layer modular arithmetic, and a manifest-sized CPU
runtime running a small assembled program end-to-end through the
threshold weights.
The CPU demo defaults to the small (1 KB) profile so the run finishes in
a fraction of a second. Larger profiles (4 KB, 64 KB) take proportionally
longer because every memory access decodes against every address line.
Usage:
python play.py # fast 1KB demo
python play.py --model neural_computer.safetensors # full 64KB
python play.py --model variants/neural_alu8.safetensors --skip-cpu # ALU only
"""
from __future__ import annotations
import argparse
import os
import sys
import torch
from safetensors import safe_open
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# Reuse the variant-aware CPU runtime from eval_all.py
from eval_all import GenericThresholdCPU, builtin_program
def heaviside(x):
    """Unit step function on a tensor.

    Returns a float32 tensor holding 1.0 wherever ``x >= 0`` and 0.0
    elsewhere; this is the firing rule of every threshold gate below.
    """
    return torch.ge(x, 0).to(torch.float32)
def load_tensors(path):
    """Load every tensor stored in a safetensors file.

    Each tensor is converted with ``.float()`` (float32) so gate weights,
    biases and manifest entries can be mixed freely in arithmetic.

    Returns:
        dict mapping each tensor name in the file to its float tensor.
    """
    with safe_open(path, framework="pt") as handle:
        return {key: handle.get_tensor(key).float() for key in handle.keys()}
def main() -> int:
    """Load a threshold-logic model and run the playground demos.

    Demo 1 prints truth tables of the Boolean gates. Demo 2 runs 8-bit
    ADD/SUB/CMP/MUL through the threshold gates and checks every result
    (including the carry / no-borrow outputs) against Python integer
    arithmetic. Demo 3 evaluates the mod-5 circuit on every byte and checks
    it fires exactly on the multiples of 5. Demo 4, skipped with
    ``--skip-cpu`` or when the manifest reports less than 0x84 bytes of
    memory, runs ``builtin_program`` on ``GenericThresholdCPU`` and checks
    the byte left at address 0x83.

    Returns:
        Process exit status: 0 when every checked result is OK, 1 when any
        ALU, mod-5 or CPU check reports FAIL.
    """
    parser = argparse.ArgumentParser(description="Threshold computer playground")
    parser.add_argument(
        "--model", type=str,
        default=os.path.join(os.path.dirname(__file__),
                             "variants", "neural_computer8_small.safetensors"),
        help="Path to a .safetensors variant",
    )
    parser.add_argument("--skip-cpu", action="store_true",
                        help="Skip the CPU program demo (useful for pure-ALU files)")
    args = parser.parse_args()
    print("Loading", args.model)
    T = load_tensors(args.model)

    def manifest(key):
        """Return the integer entry ``manifest.<key>``, or 0 when absent."""
        # Defaulting instead of raising KeyError lets files without a full
        # manifest (presumably the pure-ALU variants -- TODO confirm) reach
        # the ALU demos; MEM_BYTES == 0 then skips Demo 4 below.
        tensor = T.get(f"manifest.{key}")
        return 0 if tensor is None else int(tensor.item())

    DATA_BITS = manifest("data_bits")
    ADDR_BITS = manifest("addr_bits")
    MEM_BYTES = manifest("memory_bytes")
    REGISTERS = manifest("registers")
    print(f"Manifest: data={DATA_BITS}-bit, addr={ADDR_BITS}-bit, mem={MEM_BYTES}B, regs={REGISTERS}")
    print(f"Tensors: {len(T):,}")
    print(f"Total params: {sum(t.numel() for t in T.values()):,}")
    print()

    failures = 0

    def verdict(ok):
        """Return the "OK"/"FAIL" tag for one check and count failures."""
        nonlocal failures
        if not ok:
            failures += 1
        return "OK" if ok else "FAIL"

    def gate(name, inputs):
        """Evaluate one threshold neuron: heaviside(w . inputs + b) as 0/1."""
        w = T[name + ".weight"].view(-1)
        b = T[name + ".bias"].view(-1)
        return int(heaviside((torch.tensor(inputs, dtype=torch.float32) * w).sum() + b).item())

    def xor(prefix, inputs):
        """Two-layer XOR with hidden units ``layer1.or`` and ``layer1.nand``."""
        a, b_ = inputs
        h_or = gate(f"{prefix}.layer1.or", [a, b_])
        h_nand = gate(f"{prefix}.layer1.nand", [a, b_])
        return gate(f"{prefix}.layer2", [h_or, h_nand])

    def xor_neuron(prefix, inputs):
        """Two-layer gate with hidden units ``layer1.neuron1``/``layer1.neuron2``."""
        a, b_ = inputs
        h1 = gate(f"{prefix}.layer1.neuron1", [a, b_])
        h2 = gate(f"{prefix}.layer1.neuron2", [a, b_])
        return gate(f"{prefix}.layer2", [h1, h2])

    def int_to_bits_msb(v, n):
        """Return the low ``n`` bits of ``v``, most significant bit first."""
        return [(v >> (n - 1 - i)) & 1 for i in range(n)]

    def bits_msb_to_int(bits):
        """Inverse of ``int_to_bits_msb``: fold MSB-first bits into an int."""
        out = 0
        for b in bits:
            out = (out << 1) | int(b)
        return out

    # ---------- Demo 1: Boolean gates ----------
    print("=" * 64)
    print(" Demo 1: Boolean threshold gates")
    print("=" * 64)
    truth_2 = [(0, 0), (0, 1), (1, 0), (1, 1)]
    for gname in ["and", "or", "nand", "nor", "implies"]:
        row = " ".join(f"{a}{b}->{gate(f'boolean.{gname}', [a, b])}" for a, b in truth_2)
        print(f" {gname:8} {row}")
    for gname in ["xor", "xnor", "biimplies"]:
        row = " ".join(f"{a}{b}->{xor_neuron(f'boolean.{gname}', [a, b])}" for a, b in truth_2)
        print(f" {gname:8} {row}")
    print(f" not 0->{gate('boolean.not', [0])} 1->{gate('boolean.not', [1])}")
    print()

    # ---------- Demo 2: 8-bit ALU arithmetic ----------
    print("=" * 64)
    print(" Demo 2: 8-bit ALU arithmetic (every gate is threshold logic)")
    print("=" * 64)

    def fa(prefix, a, b, cin):
        """Full adder from two half adders plus an OR on their carries.

        Returns (sum, carry_out).
        """
        s1 = xor(f"{prefix}.ha1.sum", [a, b])
        c1 = gate(f"{prefix}.ha1.carry", [a, b])
        s2 = xor(f"{prefix}.ha2.sum", [s1, cin])
        c2 = gate(f"{prefix}.ha2.carry", [s1, cin])
        return s2, gate(f"{prefix}.carry_or", [c1, c2])

    def alu_add(a, b):
        """8-bit ripple-carry add; returns (a + b) & 0xFF and the carry-out."""
        a_lsb = list(reversed(int_to_bits_msb(a, 8)))
        b_lsb = list(reversed(int_to_bits_msb(b, 8)))
        carry = 0
        sum_lsb = []
        for i in range(8):
            s, carry = fa(f"arithmetic.ripplecarry8bit.fa{i}", a_lsb[i], b_lsb[i], carry)
            sum_lsb.append(s)
        return bits_msb_to_int(list(reversed(sum_lsb))), carry

    def alu_sub(a, b):
        """8-bit subtract as a + ~b + 1; returns (a - b) & 0xFF and carry-out.

        The carry-out is 1 exactly when no borrow occurred (a >= b).
        """
        a_lsb = list(reversed(int_to_bits_msb(a, 8)))
        b_lsb = list(reversed(int_to_bits_msb(b, 8)))
        carry = 1  # the "+1" of two's-complement negation
        diff_lsb = []
        for i in range(8):
            notb = gate(f"arithmetic.sub8bit.notb{i}", [b_lsb[i]])
            x1 = xor(f"arithmetic.sub8bit.fa{i}.xor1", [a_lsb[i], notb])
            x2 = xor(f"arithmetic.sub8bit.fa{i}.xor2", [x1, carry])
            and1 = gate(f"arithmetic.sub8bit.fa{i}.and1", [a_lsb[i], notb])
            and2 = gate(f"arithmetic.sub8bit.fa{i}.and2", [x1, carry])
            carry = gate(f"arithmetic.sub8bit.fa{i}.or_carry", [and1, and2])
            diff_lsb.append(x2)
        return bits_msb_to_int(list(reversed(diff_lsb))), carry

    def alu_compare(a, b, kind):
        """Evaluate the bit-cascade comparator family; returns 0 or 1.

        ``kind`` is "greaterthan", "lessthan" or "eq"; anything else raises
        ValueError. Bit 0 is the MSB. Only the gates needed for the
        requested relation are evaluated.
        """
        a_msb = int_to_bits_msb(a, 8)
        b_msb = int_to_bits_msb(b, 8)
        bit_eq = []
        for i in range(8):
            eq_and = gate(f"arithmetic.cmp8bit.bit{i}.eq.layer1.and", [a_msb[i], b_msb[i]])
            eq_nor = gate(f"arithmetic.cmp8bit.bit{i}.eq.layer1.nor", [a_msb[i], b_msb[i]])
            bit_eq.append(gate(f"arithmetic.cmp8bit.bit{i}.eq", [eq_and, eq_nor]))
        if kind == "eq":
            return gate("arithmetic.equality8bit", bit_eq)
        if kind not in ("greaterthan", "lessthan"):
            raise ValueError(kind)
        side = "gt" if kind == "greaterthan" else "lt"
        bit_side = [gate(f"arithmetic.cmp8bit.bit{i}.{side}", [a_msb[i], b_msb[i]])
                    for i in range(8)]
        # Bit i only decides the result when all more-significant bits are
        # equal, which the eq_prefix gate over bit_eq[:i] detects.
        cascade = [bit_side[0]]
        for i in range(1, 8):
            eq_pref = gate(f"arithmetic.cmp8bit.cascade.eq_prefix.bit{i}", bit_eq[:i])
            cascade.append(gate(f"arithmetic.cmp8bit.cascade.{side}.bit{i}", [eq_pref, bit_side[i]]))
        # Final OR: arithmetic.greaterthan8bit / arithmetic.lessthan8bit.
        return gate(f"arithmetic.{kind}8bit", cascade)

    def alu_mul(a, b):
        """Low 8 bits of a * b via partial-product gates and shifted adds."""
        a_bits = int_to_bits_msb(a, 8)
        b_bits = int_to_bits_msb(b, 8)
        result = 0
        for j in range(8):
            if b_bits[j] == 0:
                continue  # every partial product in this row would be 0
            row = 0
            for i in range(8):
                pp = gate(f"alu.alu8bit.mul.pp.a{i}b{j}", [a_bits[i], b_bits[j]])
                row |= (pp << (7 - i))
            shift = 7 - j  # weight of MSB-first bit j
            result, _ = alu_add(result & 0xFF, (row << shift) & 0xFF)
        return result & 0xFF

    cases_arith = [(5, 3), (37, 100), (200, 99), (255, 1), (127, 128), (15, 17)]
    print("ADD:")
    for a, b in cases_arith:
        r, c = alu_add(a, b)
        e = (a + b) & 0xFF
        ok = r == e and c == (a + b) >> 8
        print(f" {a:3} + {b:3} = {r:3} (carry={c}) expected {e:3} [{verdict(ok)}]")
    print("SUB:")
    for a, b in cases_arith:
        r, c = alu_sub(a, b)
        e = (a - b) & 0xFF
        ok = r == e and c == int(a >= b)
        print(f" {a:3} - {b:3} = {r:3} (no_borrow={c}) expected {e:3} [{verdict(ok)}]")
    print("CMP:")
    for a, b in [(50, 30), (30, 50), (77, 77), (255, 0), (0, 255), (128, 127)]:
        gt = alu_compare(a, b, "greaterthan")
        lt = alu_compare(a, b, "lessthan")
        eq = alu_compare(a, b, "eq")
        ok = (gt, lt, eq) == (int(a > b), int(a < b), int(a == b))
        print(f" {a:3} vs {b:3} -> GT={gt} LT={lt} EQ={eq} [{verdict(ok)}]")
    print("MUL (low 8 bits):")
    for a, b in [(12, 11), (15, 17), (8, 32), (200, 3), (0, 99), (1, 255)]:
        r = alu_mul(a, b)
        e = (a * b) & 0xFF
        print(f" {a:3} * {b:3} = {r:3} expected {e:3} [{verdict(r == e)}]")
    print()

    # ---------- Demo 3: mod-5 divisibility ----------
    print("=" * 64)
    print(" Demo 3: mod-5 divisibility (multi-layer, hand-constructed)")
    print("=" * 64)
    # The 52 multiples of 5 in [0, 255]; also the expected hit list below.
    mod5_ks = [k for k in range(256) if k % 5 == 0]

    def mod5(v):
        """Return 1 when the mod-5 circuit says ``v`` is a multiple of 5.

        Per multiple k: 8 single-input "bit{i}.match" gates fire when bit i
        of v matches bit i of k, ANDed by ".all"; "modular.mod5" ORs all
        52 "all" outputs.
        """
        bits = int_to_bits_msb(v, 8)
        alls = []
        for k in mod5_ks:
            matches = [gate(f"modular.mod5.eq.k{k}.bit{i}.match", [bits[i]]) for i in range(8)]
            alls.append(gate(f"modular.mod5.eq.k{k}.all", matches))
        return gate("modular.mod5", alls)

    hits = [v for v in range(256) if mod5(v)]
    print(f" v in [0,255] with mod5(v)==1: {len(hits)} hits, first 12: {hits[:12]}")
    print(f" Sanity (each %5): {[h % 5 for h in hits[:12]]}")
    print(f" Hits == all multiples of 5 in [0,255]: [{verdict(hits == mod5_ks)}]")
    print()

    # ---------- Demo 4: CPU running an assembled program ----------
    # The program stores its result at 0x83, so at least 0x84 bytes are needed.
    if args.skip_cpu or MEM_BYTES < 0x84:
        if args.skip_cpu:
            print("Demo 4 skipped (--skip-cpu).")
        else:
            print(f"Demo 4 skipped (memory={MEM_BYTES}B too small for the demo program).")
        return 1 if failures else 0
    print("=" * 64)
    print(f" Demo 4: Threshold CPU running an assembled program ({MEM_BYTES} B memory)")
    print("=" * 64)
    print(" Program: sum 5+4+3+2+1 via loop")
    print(" uses LOAD/STORE/ADD/SUB/CMP/JNZ/HALT, all threshold-gated")
    print(" Running ... (larger memories take longer because every memory access")
    print(" decodes against every address line)")
    cpu = GenericThresholdCPU(dict(T))  # shallow copy, as before
    mem, expected = builtin_program(ADDR_BITS)
    # NOTE(review): the register file is fixed at 4 here (and R0-R3 are
    # printed below) even though the manifest reports REGISTERS -- confirm
    # GenericThresholdCPU always expects 4 registers.
    state = {"pc": 0, "regs": [0] * 4, "flags": [0] * 4, "mem": mem, "halted": False}
    final, cycles = cpu.run(state, max_cycles=200)
    got = final["mem"][0x83]
    print(f" Halted after {cycles} cycles")
    print(f" R0={final['regs'][0]} R1={final['regs'][1]} "
          f"R2={final['regs'][2]} R3={final['regs'][3]}")
    print(f" M[0x0083] = {got} (expected {expected}) [{verdict(got == expected)}]")
    return 1 if failures else 0
if __name__ == "__main__":
    # Hand main()'s status code (0 = all checks passed) to the shell.
    raise SystemExit(main())