File size: 10,696 Bytes
05b1aea
 
 
3615a51
 
 
 
 
 
 
 
 
 
 
 
 
 
05b1aea
 
 
3615a51
05b1aea
 
3615a51
05b1aea
 
 
3615a51
05b1aea
3615a51
 
05b1aea
 
 
 
 
 
 
 
 
 
 
 
 
 
3615a51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05b1aea
3615a51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df99f2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3615a51
df99f2e
 
3615a51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df99f2e
 
 
3615a51
df99f2e
 
 
 
 
 
3615a51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
"""
Hands-on playground for the 8bit-threshold-computer.

Loads a safetensors model, reads its manifest, and exercises threshold
circuits at every level: raw Boolean gates, 8-bit ALU arithmetic and
comparators, multi-layer modular arithmetic, and a manifest-sized CPU
runtime running a small assembled program end-to-end through the
threshold weights.

The CPU demo defaults to the small (1 KB) profile so the run finishes in
a fraction of a second. Larger profiles (4 KB, 64 KB) take proportionally
longer because every memory access decodes against every address line.

Usage:
    python play.py                                  # fast 1KB demo
    python play.py --model neural_computer.safetensors   # full 64KB
    python play.py --model variants/neural_alu8.safetensors --skip-cpu  # ALU only
"""

from __future__ import annotations
import argparse
import os
import sys

import torch
from safetensors import safe_open

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Reuse the variant-aware CPU runtime from eval_all.py
from eval_all import GenericThresholdCPU, builtin_program


def heaviside(x):
    """Unit step activation for threshold gates.

    Returns a float32 tensor holding 1.0 wherever ``x >= 0`` and 0.0 elsewhere
    (so an input of exactly zero fires).
    """
    return x.ge(0).to(torch.float32)


def load_tensors(path):
    """Read every tensor stored in the safetensors file at *path*.

    Returns a dict mapping each tensor name to that tensor converted with
    ``.float()`` (float32), regardless of its on-disk dtype.
    """
    with safe_open(path, framework="pt") as handle:
        return {key: handle.get_tensor(key).float() for key in handle.keys()}


def main() -> int:
    """Run the playground demos against a threshold-computer .safetensors file.

    Demo 1 prints Boolean gate truth tables.

    Demo 2 runs 8-bit ADD/SUB/CMP/MUL through the ALU gates. It checks each
    result against plain Python integer arithmetic and comparisons.

    Demo 3 evaluates the mod-5 circuit for every byte value. It checks that the
    set of hits equals ``{v : v % 5 == 0}``.

    Demo 4 runs ``builtin_program`` on ``GenericThresholdCPU`` and checks
    memory cell 0x83 against the expected value. It is skipped with
    ``--skip-cpu`` or when the manifest memory is below 0x84 bytes.

    Returns:
        0 if every checked result matched its reference, 1 otherwise.
        Failures from Demos 2 and 3 count even when Demo 4 is skipped.
    """
    parser = argparse.ArgumentParser(description="Threshold computer playground")
    parser.add_argument(
        "--model", type=str,
        default=os.path.join(os.path.dirname(__file__),
                             "variants", "neural_computer8_small.safetensors"),
        help="Path to a .safetensors variant"
    )
    parser.add_argument("--skip-cpu", action="store_true",
                        help="Skip the CPU program demo (useful for pure-ALU files)")
    args = parser.parse_args()

    print("Loading", args.model)
    T = load_tensors(args.model)

    DATA_BITS = int(T["manifest.data_bits"].item())
    ADDR_BITS = int(T["manifest.addr_bits"].item())
    MEM_BYTES = int(T["manifest.memory_bytes"].item())
    REGISTERS = int(T["manifest.registers"].item())
    print(f"Manifest: data={DATA_BITS}-bit, addr={ADDR_BITS}-bit, mem={MEM_BYTES}B, regs={REGISTERS}")
    print(f"Tensors: {len(T):,}")
    print(f"Total params: {sum(t.numel() for t in T.values()):,}")
    print()

    # Labels of every check whose threshold-circuit result disagreed with its
    # Python reference. A non-empty list makes main() return 1.
    failures = []

    def verdict(ok, label):
        """Record a failed check under *label*; return the "OK"/"FAIL" tag."""
        if not ok:
            failures.append(label)
        return "OK" if ok else "FAIL"

    def exit_status():
        """Print a summary of failed checks and return the process exit code."""
        if failures:
            print(f"{len(failures)} check(s) FAILED: {', '.join(failures)}")
            return 1
        return 0

    def gate(name, inputs):
        # One threshold neuron: output 1 iff sum(inputs * weight) + bias >= 0.
        w = T[name + ".weight"].view(-1)
        b = T[name + ".bias"].view(-1)
        return int(heaviside((torch.tensor(inputs, dtype=torch.float32) * w).sum() + b).item())

    def xor(prefix, inputs):
        # Two-layer block whose hidden units are named layer1.or / layer1.nand.
        a, b_ = inputs
        h_or = gate(f"{prefix}.layer1.or", [a, b_])
        h_nand = gate(f"{prefix}.layer1.nand", [a, b_])
        return gate(f"{prefix}.layer2", [h_or, h_nand])

    def xor_neuron(prefix, inputs):
        # Two-layer block whose hidden units are named layer1.neuron1 / neuron2.
        # It computes whatever 2-input function the stored weights encode.
        a, b_ = inputs
        h1 = gate(f"{prefix}.layer1.neuron1", [a, b_])
        h2 = gate(f"{prefix}.layer1.neuron2", [a, b_])
        return gate(f"{prefix}.layer2", [h1, h2])

    def int_to_bits_msb(v, n):
        # n-bit list, index 0 = most significant bit.
        return [(v >> (n - 1 - i)) & 1 for i in range(n)]

    def bits_msb_to_int(bits):
        out = 0
        for b in bits:
            out = (out << 1) | int(b)
        return out

    # ---------- Demo 1: Boolean gates ----------
    print("=" * 64)
    print(" Demo 1: Boolean threshold gates")
    print("=" * 64)
    truth_2 = [(0, 0), (0, 1), (1, 0), (1, 1)]
    for gname in ["and", "or", "nand", "nor", "implies"]:
        row = " ".join(f"{a}{b}->{gate(f'boolean.{gname}', [a, b])}" for a, b in truth_2)
        print(f"  {gname:8} {row}")
    for gname in ["xor", "xnor", "biimplies"]:
        row = " ".join(f"{a}{b}->{xor_neuron(f'boolean.{gname}', [a, b])}" for a, b in truth_2)
        print(f"  {gname:8} {row}")
    print(f"  not      0->{gate('boolean.not', [0])} 1->{gate('boolean.not', [1])}")
    print()

    # ---------- Demo 2: 8-bit ALU arithmetic ----------
    print("=" * 64)
    print(" Demo 2: 8-bit ALU arithmetic (every gate is threshold logic)")
    print("=" * 64)

    def fa(prefix, a, b, cin):
        # Full adder built from half adders ha1 and ha2 plus the carry_or gate.
        # Returns (sum_bit, carry_out).
        s1 = xor(f"{prefix}.ha1.sum", [a, b])
        c1 = gate(f"{prefix}.ha1.carry", [a, b])
        s2 = xor(f"{prefix}.ha2.sum", [s1, cin])
        c2 = gate(f"{prefix}.ha2.carry", [s1, cin])
        return s2, gate(f"{prefix}.carry_or", [c1, c2])

    def alu_add(a, b):
        # 8-bit ripple-carry add, LSB first. Returns (sum & 0xFF, carry_out).
        a_lsb = list(reversed(int_to_bits_msb(a, 8)))
        b_lsb = list(reversed(int_to_bits_msb(b, 8)))
        carry = 0
        sum_lsb = []
        for i in range(8):
            s, carry = fa(f"arithmetic.ripplecarry8bit.fa{i}", a_lsb[i], b_lsb[i], carry)
            sum_lsb.append(s)
        return bits_msb_to_int(list(reversed(sum_lsb))), carry

    def alu_sub(a, b):
        # a - b computed as a + ~b + 1: each b bit goes through notb{i} and the
        # initial carry-in is 1. Returns (diff & 0xFF, carry_out), where
        # carry_out == 1 means no borrow occurred.
        a_lsb = list(reversed(int_to_bits_msb(a, 8)))
        b_lsb = list(reversed(int_to_bits_msb(b, 8)))
        carry = 1
        diff_lsb = []
        for i in range(8):
            notb = gate(f"arithmetic.sub8bit.notb{i}", [b_lsb[i]])
            x1 = xor(f"arithmetic.sub8bit.fa{i}.xor1", [a_lsb[i], notb])
            x2 = xor(f"arithmetic.sub8bit.fa{i}.xor2", [x1, carry])
            and1 = gate(f"arithmetic.sub8bit.fa{i}.and1", [a_lsb[i], notb])
            and2 = gate(f"arithmetic.sub8bit.fa{i}.and2", [x1, carry])
            carry = gate(f"arithmetic.sub8bit.fa{i}.or_carry", [and1, and2])
            diff_lsb.append(x2)
        return bits_msb_to_int(list(reversed(diff_lsb))), carry

    def alu_compare(a, b, kind):
        # Walks the bit-cascade comparator family: per-bit gt/lt/eq, cascaded
        # eq_prefix, cascade.gt/lt, and the final output gates. Bit 0 is MSB.
        a_msb = int_to_bits_msb(a, 8)
        b_msb = int_to_bits_msb(b, 8)
        bit_gt = [gate(f"arithmetic.cmp8bit.bit{i}.gt", [a_msb[i], b_msb[i]]) for i in range(8)]
        bit_lt = [gate(f"arithmetic.cmp8bit.bit{i}.lt", [a_msb[i], b_msb[i]]) for i in range(8)]
        bit_eq = []
        for i in range(8):
            eq_and = gate(f"arithmetic.cmp8bit.bit{i}.eq.layer1.and", [a_msb[i], b_msb[i]])
            eq_nor = gate(f"arithmetic.cmp8bit.bit{i}.eq.layer1.nor", [a_msb[i], b_msb[i]])
            bit_eq.append(gate(f"arithmetic.cmp8bit.bit{i}.eq", [eq_and, eq_nor]))
        cas_gt = [bit_gt[0]]
        cas_lt = [bit_lt[0]]
        for i in range(1, 8):
            # eq_prefix for bit i sees the equality bits of all more-significant bits.
            eq_pref = gate(f"arithmetic.cmp8bit.cascade.eq_prefix.bit{i}", bit_eq[:i])
            cas_gt.append(gate(f"arithmetic.cmp8bit.cascade.gt.bit{i}", [eq_pref, bit_gt[i]]))
            cas_lt.append(gate(f"arithmetic.cmp8bit.cascade.lt.bit{i}", [eq_pref, bit_lt[i]]))
        if kind == "greaterthan":
            return gate("arithmetic.greaterthan8bit", cas_gt)
        if kind == "lessthan":
            return gate("arithmetic.lessthan8bit", cas_lt)
        if kind == "eq":
            return gate("arithmetic.equality8bit", bit_eq)
        raise ValueError(kind)

    def alu_mul(a, b):
        # Shift-and-add multiply keeping only the low 8 bits. Each row comes
        # from the mul.pp gates and is accumulated through alu_add. Rows for
        # zero bits of b are skipped without evaluating their pp gates.
        a_bits = int_to_bits_msb(a, 8)
        b_bits = int_to_bits_msb(b, 8)
        result = 0
        for j in range(8):
            if b_bits[j] == 0:
                continue
            row = 0
            for i in range(8):
                pp = gate(f"alu.alu8bit.mul.pp.a{i}b{j}", [a_bits[i], b_bits[j]])
                row |= (pp << (7 - i))
            shift = 7 - j  # b_bits is MSB-first, so index j has weight 2**(7-j)
            result, _ = alu_add(result & 0xFF, (row << shift) & 0xFF)
        return result & 0xFF

    cases_arith = [(5, 3), (37, 100), (200, 99), (255, 1), (127, 128), (15, 17)]
    print("ADD:")
    for a, b in cases_arith:
        r, c = alu_add(a, b)
        e = (a + b) & 0xFF
        status = verdict(r == e, f"ADD {a}+{b}")
        print(f"  {a:3} + {b:3} = {r:3} (carry={c})  expected {e:3}  [{status}]")
    print("SUB:")
    for a, b in cases_arith:
        r, c = alu_sub(a, b)
        e = (a - b) & 0xFF
        status = verdict(r == e, f"SUB {a}-{b}")
        print(f"  {a:3} - {b:3} = {r:3} (no_borrow={c})  expected {e:3}  [{status}]")
    print("CMP:")
    for a, b in [(50, 30), (30, 50), (77, 77), (255, 0), (0, 255), (128, 127)]:
        gt = alu_compare(a, b, "greaterthan")
        lt = alu_compare(a, b, "lessthan")
        eq = alu_compare(a, b, "eq")
        # Reference flags come from Python's own comparison operators.
        ok = (gt, lt, eq) == (int(a > b), int(a < b), int(a == b))
        status = verdict(ok, f"CMP {a},{b}")
        print(f"  {a:3} vs {b:3} -> GT={gt} LT={lt} EQ={eq}  [{status}]")
    print("MUL (low 8 bits):")
    for a, b in [(12, 11), (15, 17), (8, 32), (200, 3), (0, 99), (1, 255)]:
        r = alu_mul(a, b)
        e = (a * b) & 0xFF
        status = verdict(r == e, f"MUL {a}*{b}")
        print(f"  {a:3} * {b:3} = {r:3}  expected {e:3}  [{status}]")
    print()

    # ---------- Demo 3: mod-5 divisibility ----------
    print("=" * 64)
    print(" Demo 3: mod-5 divisibility (multi-layer, hand-constructed)")
    print("=" * 64)

    def mod5(v):
        # There is one matcher per multiple of 5 (k0, k5, ..., k255, i.e.
        # 52 of them). Each matcher has 8 single-input "bit{i}.match" gates
        # that fire when bit i of v matches bit i of k. The matcher's ".all"
        # gate ANDs those 8 outputs. The final "modular.mod5" gate ORs all
        # 52 ".all" outputs.
        bits = int_to_bits_msb(v, 8)
        ks = [k for k in range(256) if k % 5 == 0]
        alls = []
        for k in ks:
            matches = [gate(f"modular.mod5.eq.k{k}.bit{i}.match", [bits[i]]) for i in range(8)]
            alls.append(gate(f"modular.mod5.eq.k{k}.all", matches))
        return gate("modular.mod5", alls)

    hits = [v for v in range(256) if mod5(v)]
    expected_hits = [v for v in range(256) if v % 5 == 0]
    print(f"  v in [0,255] with mod5(v)==1: {len(hits)} hits, first 12: {hits[:12]}")
    print(f"  Sanity (each %5): {[h % 5 for h in hits[:12]]}")
    status = verdict(hits == expected_hits, "mod5")
    print(f"  Hits match v % 5 == 0 for all 256 inputs: [{status}]")
    print()

    # ---------- Demo 4: CPU running an assembled program ----------
    # The demo program stores its result at address 0x83, so memory must hold
    # at least 0x84 bytes.
    if args.skip_cpu or MEM_BYTES < 0x84:
        if args.skip_cpu:
            print("Demo 4 skipped (--skip-cpu).")
        else:
            print(f"Demo 4 skipped (memory={MEM_BYTES}B too small for the demo program).")
        return exit_status()

    print("=" * 64)
    print(f" Demo 4: Threshold CPU running an assembled program ({MEM_BYTES} B memory)")
    print("=" * 64)
    print("  Program: sum 5+4+3+2+1 via loop")
    print("           uses LOAD/STORE/ADD/SUB/CMP/JNZ/HALT, all threshold-gated")
    print("  Running ... (larger memories take longer because every memory access")
    print("              decodes against every address line)")
    # Shallow copy: a new mapping holding the same tensor objects.
    cpu = GenericThresholdCPU(dict(T))
    mem, expected = builtin_program(ADDR_BITS)
    # NOTE(review): the register/flag files are hard-coded to 4 entries even
    # though the manifest reports REGISTERS. Confirm that GenericThresholdCPU
    # expects exactly 4.
    state = {"pc": 0, "regs": [0] * 4, "flags": [0] * 4, "mem": mem, "halted": False}
    final, cycles = cpu.run(state, max_cycles=200)
    got = final["mem"][0x83]
    print(f"  Halted after {cycles} cycles")
    print(f"  R0={final['regs'][0]} R1={final['regs'][1]} "
          f"R2={final['regs'][2]} R3={final['regs'][3]}")
    status = verdict(got == expected, "CPU program")
    print(f"  M[0x0083] = {got}  (expected {expected})  [{status}]")
    return exit_status()


if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    raise SystemExit(main())