CharlesCNorton committed on
Commit
05b1aea
·
1 Parent(s): 1e992a1

Add 18 prebuilt variants and unified eval harness

Browse files

variants/ holds every (8|16|32)-bit x (none|registers|scratchpad|small|reduced|full)
build (~325 MB total) so users can pull weights from HF without running build.py.

eval_all.py is variant-agnostic: reads each safetensors' manifest, runs the
BatchedFitnessEvaluator, and with --cpu-program also runs an assembled program
through the threshold CPU sized to the variant plus a chained N-bit ALU test
for 16/32-bit data widths.

build.py: fix infer_combinational_inputs N-bit handling. The barrel shifter
case used 1 << (2 - layer), valid only for 3-layer (8-bit) shifters; 16/32-bit
versions have 4-5 layers and crashed at the .inputs step. Priority encoder also
hardcoded 8 inputs and the legacy any_ge naming. Both now parse the bit width
from the gate name and emit correct shift amounts and signal references.

build_all.py orchestrates building + evaluating every named profile.
play.py is a standalone demo (Boolean gates, 8-bit ALU, mod-5, threshold CPU).

build.py CHANGED
@@ -2505,7 +2505,7 @@ def infer_error_detection_inputs(gate: str, reg: SignalRegistry) -> List[int]:
2505
  return [reg.get_id(f"$x[{i}]") for i in range(8)]
2506
 
2507
 
2508
- def infer_combinational_inputs(gate: str, reg: SignalRegistry) -> List[int]:
2509
  if 'decoder3to8' in gate:
2510
  for i in range(3):
2511
  reg.register(f"$sel[{i}]")
@@ -2550,41 +2550,57 @@ def infer_combinational_inputs(gate: str, reg: SignalRegistry) -> List[int]:
2550
  return [reg.register(f"combinational.regmux4to1.bit{bit}.and{i}") for i in range(4)]
2551
  return []
2552
  if 'barrelshifter' in gate:
2553
- for i in range(8):
 
 
 
 
 
2554
  reg.register(f"$x[{i}]")
2555
- for i in range(3):
2556
  reg.register(f"$shift[{i}]")
2557
  m = re.search(r'layer(\d+)\.bit(\d+)', gate)
2558
  if m:
2559
  layer, bit = int(m.group(1)), int(m.group(2))
2560
- shift_amount = 1 << (2 - layer)
2561
- prefix = f"combinational.barrelshifter.layer{layer}.bit{bit}"
 
2562
  if '.not_sel' in gate:
2563
- return [reg.get_id(f"$shift[{2 - layer}]")]
2564
  if '.and_a' in gate:
2565
  if layer == 0:
2566
  return [reg.get_id(f"$x[{bit}]"), reg.register(f"{prefix}.not_sel")]
2567
  else:
2568
- prev_prefix = f"combinational.barrelshifter.layer{layer-1}.bit{bit}"
2569
  return [reg.register(f"{prev_prefix}.or"), reg.register(f"{prefix}.not_sel")]
2570
  if '.and_b' in gate:
2571
- src = (bit + shift_amount) % 8
2572
  if layer == 0:
2573
- return [reg.get_id(f"$x[{src}]"), reg.get_id(f"$shift[{2 - layer}]")]
2574
  else:
2575
- prev_prefix = f"combinational.barrelshifter.layer{layer-1}.bit{src}"
2576
- return [reg.register(f"{prev_prefix}.or"), reg.get_id(f"$shift[{2 - layer}]")]
2577
  if '.or' in gate:
2578
  return [reg.register(f"{prefix}.and_a"), reg.register(f"{prefix}.and_b")]
2579
- return [reg.get_id(f"$x[{i}]") for i in range(8)]
2580
  if 'priorityencoder' in gate:
2581
- for i in range(8):
 
 
 
2582
  reg.register(f"$x[{i}]")
 
2583
  if '.any_ge' in gate:
2584
  m = re.search(r'any_ge(\d+)', gate)
2585
  if m:
2586
  pos = int(m.group(1))
2587
- return [reg.get_id(f"$x[{i}]") for i in range(pos, 8)]
 
 
 
 
 
 
2588
  if '.is_highest' in gate:
2589
  m = re.search(r'is_highest(\d+)', gate)
2590
  if m:
@@ -2593,21 +2609,25 @@ def infer_combinational_inputs(gate: str, reg: SignalRegistry) -> List[int]:
2593
  if pos == 0:
2594
  return [reg.get_id("#0")]
2595
  else:
2596
- return [reg.register(f"combinational.priorityencoder.any_ge{pos-1}")]
 
 
 
 
2597
  if '.and' in gate:
2598
- return [reg.get_id(f"$x[{pos}]"), reg.register(f"combinational.priorityencoder.is_highest{pos}.not_higher")]
2599
  if '.out' in gate:
2600
  m = re.search(r'out(\d+)', gate)
2601
  if m:
2602
  out_bit = int(m.group(1))
2603
  inputs = []
2604
- for pos in range(8):
2605
  if (pos >> out_bit) & 1:
2606
- inputs.append(reg.register(f"combinational.priorityencoder.is_highest{pos}.and"))
2607
  return inputs
2608
  if '.valid' in gate:
2609
- return [reg.get_id(f"$x[{i}]") for i in range(8)]
2610
- return [reg.get_id(f"$x[{i}]") for i in range(8)]
2611
  return []
2612
 
2613
 
@@ -2706,7 +2726,7 @@ def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, tor
2706
  if gate.startswith('error_detection.'):
2707
  return infer_error_detection_inputs(gate, reg)
2708
  if gate.startswith('combinational.'):
2709
- return infer_combinational_inputs(gate, reg)
2710
  weight_key = f"{gate}.weight"
2711
  if weight_key in tensors:
2712
  w = tensors[weight_key]
 
2505
  return [reg.get_id(f"$x[{i}]") for i in range(8)]
2506
 
2507
 
2508
+ def infer_combinational_inputs(gate: str, reg: SignalRegistry, tensors: Dict[str, torch.Tensor] = None) -> List[int]:
2509
  if 'decoder3to8' in gate:
2510
  for i in range(3):
2511
  reg.register(f"$sel[{i}]")
 
2550
  return [reg.register(f"combinational.regmux4to1.bit{bit}.and{i}") for i in range(4)]
2551
  return []
2552
  if 'barrelshifter' in gate:
2553
+ import math as _math
2554
+ bs_match = re.search(r'barrelshifter(\d*)', gate)
2555
+ bits = int(bs_match.group(1)) if bs_match and bs_match.group(1) else 8
2556
+ bs_prefix = f"combinational.barrelshifter{bs_match.group(1) if bs_match else ''}"
2557
+ num_layers = max(1, _math.ceil(_math.log2(bits))) if bits > 1 else 1
2558
+ for i in range(bits):
2559
  reg.register(f"$x[{i}]")
2560
+ for i in range(num_layers):
2561
  reg.register(f"$shift[{i}]")
2562
  m = re.search(r'layer(\d+)\.bit(\d+)', gate)
2563
  if m:
2564
  layer, bit = int(m.group(1)), int(m.group(2))
2565
+ shift_amount = 1 << (num_layers - 1 - layer)
2566
+ prefix = f"{bs_prefix}.layer{layer}.bit{bit}"
2567
+ sel_idx = num_layers - 1 - layer
2568
  if '.not_sel' in gate:
2569
+ return [reg.get_id(f"$shift[{sel_idx}]")]
2570
  if '.and_a' in gate:
2571
  if layer == 0:
2572
  return [reg.get_id(f"$x[{bit}]"), reg.register(f"{prefix}.not_sel")]
2573
  else:
2574
+ prev_prefix = f"{bs_prefix}.layer{layer-1}.bit{bit}"
2575
  return [reg.register(f"{prev_prefix}.or"), reg.register(f"{prefix}.not_sel")]
2576
  if '.and_b' in gate:
2577
+ src = (bit + shift_amount) % bits
2578
  if layer == 0:
2579
+ return [reg.get_id(f"$x[{src}]"), reg.get_id(f"$shift[{sel_idx}]")]
2580
  else:
2581
+ prev_prefix = f"{bs_prefix}.layer{layer-1}.bit{src}"
2582
+ return [reg.register(f"{prev_prefix}.or"), reg.get_id(f"$shift[{sel_idx}]")]
2583
  if '.or' in gate:
2584
  return [reg.register(f"{prefix}.and_a"), reg.register(f"{prefix}.and_b")]
2585
+ return [reg.get_id(f"$x[{i}]") for i in range(bits)]
2586
  if 'priorityencoder' in gate:
2587
+ pe_match = re.search(r'priorityencoder(\d*)', gate)
2588
+ bits = int(pe_match.group(1)) if pe_match and pe_match.group(1) else 8
2589
+ pe_prefix = f"combinational.priorityencoder{pe_match.group(1) if pe_match else ''}"
2590
+ for i in range(bits):
2591
  reg.register(f"$x[{i}]")
2592
+ # Legacy 8-bit naming: any_ge{pos} = OR of bits at positions [pos..bits-1]
2593
  if '.any_ge' in gate:
2594
  m = re.search(r'any_ge(\d+)', gate)
2595
  if m:
2596
  pos = int(m.group(1))
2597
+ return [reg.get_id(f"$x[{i}]") for i in range(pos, bits)]
2598
+ # N-bit naming: any_higher{pos} = OR of bits 0..pos-1
2599
+ if '.any_higher' in gate:
2600
+ m = re.search(r'any_higher(\d+)', gate)
2601
+ if m:
2602
+ pos = int(m.group(1))
2603
+ return [reg.get_id(f"$x[{i}]") for i in range(pos)]
2604
  if '.is_highest' in gate:
2605
  m = re.search(r'is_highest(\d+)', gate)
2606
  if m:
 
2609
  if pos == 0:
2610
  return [reg.get_id("#0")]
2611
  else:
2612
+ # Try N-bit any_higher first, fall back to legacy any_ge
2613
+ ah_key = f"{pe_prefix}.any_higher{pos}"
2614
+ if tensors is not None and f"{ah_key}.weight" in tensors:
2615
+ return [reg.register(ah_key)]
2616
+ return [reg.register(f"{pe_prefix}.any_ge{pos-1}")]
2617
  if '.and' in gate:
2618
+ return [reg.get_id(f"$x[{pos}]"), reg.register(f"{pe_prefix}.is_highest{pos}.not_higher")]
2619
  if '.out' in gate:
2620
  m = re.search(r'out(\d+)', gate)
2621
  if m:
2622
  out_bit = int(m.group(1))
2623
  inputs = []
2624
+ for pos in range(bits):
2625
  if (pos >> out_bit) & 1:
2626
+ inputs.append(reg.register(f"{pe_prefix}.is_highest{pos}.and"))
2627
  return inputs
2628
  if '.valid' in gate:
2629
+ return [reg.get_id(f"$x[{i}]") for i in range(bits)]
2630
+ return [reg.get_id(f"$x[{i}]") for i in range(bits)]
2631
  return []
2632
 
2633
 
 
2726
  if gate.startswith('error_detection.'):
2727
  return infer_error_detection_inputs(gate, reg)
2728
  if gate.startswith('combinational.'):
2729
+ return infer_combinational_inputs(gate, reg, tensors)
2730
  weight_key = f"{gate}.weight"
2731
  if weight_key in tensors:
2732
  w = tensors[weight_key]
build_all.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Build and verify every named (bits, memory_profile) variant.
3
+
4
+ Outputs:
5
+ variants/neural_alu{8,16,32}.safetensors - no memory
6
+ variants/neural_computer{8,16,32}_registers.safetensors - 16 B
7
+ variants/neural_computer{8,16,32}_scratchpad.safetensors - 256 B
8
+ variants/neural_computer{8,16,32}_small.safetensors - 1 KB
9
+ variants/neural_computer{8,16,32}_reduced.safetensors - 4 KB
10
+ variants/neural_computer{8,16,32}.safetensors - 64 KB
11
+
12
+ For each, runs eval.py via the BatchedFitnessEvaluator and records
13
+ (tensor count, params, file size, fitness, total_tests, seconds).
14
+ """
15
+
16
+ from __future__ import annotations
17
+ import os
18
+ import shutil
19
+ import subprocess
20
+ import sys
21
+ import time
22
+ from pathlib import Path
23
+
24
+ import torch
25
+ from safetensors import safe_open
26
+
27
+ ROOT = Path(__file__).resolve().parent
28
+ SEED = ROOT / "neural_computer.safetensors"
29
+ OUT_DIR = ROOT / "variants"
30
+ OUT_DIR.mkdir(exist_ok=True)
31
+
32
+ PROFILES = ["none", "registers", "scratchpad", "small", "reduced", "full"]
33
+ BITS = [8, 16, 32]
34
+
35
+
36
+ def variant_filename(bits: int, profile: str) -> str:
37
+ if profile == "none":
38
+ return f"neural_alu{bits}.safetensors"
39
+ if profile == "full":
40
+ return f"neural_computer{bits}.safetensors"
41
+ return f"neural_computer{bits}_{profile}.safetensors"
42
+
43
+
44
+ def run(cmd: list[str], timeout: int = 600) -> tuple[int, str]:
45
+ p = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
46
+ return p.returncode, (p.stdout or "") + (p.stderr or "")
47
+
48
+
49
+ def build_variant(bits: int, profile: str) -> Path:
50
+ out = OUT_DIR / variant_filename(bits, profile)
51
+ shutil.copy2(SEED, out)
52
+ cmd = [
53
+ sys.executable, str(ROOT / "build.py"),
54
+ "--bits", str(bits),
55
+ "-m", profile,
56
+ "--apply",
57
+ "--model", str(out),
58
+ "all",
59
+ ]
60
+ rc, log = run(cmd, timeout=900)
61
+ if rc != 0:
62
+ raise RuntimeError(f"build failed for bits={bits} profile={profile}:\n{log[-1500:]}")
63
+ return out
64
+
65
+
66
+ def measure_variant(path: Path) -> dict:
67
+ """Read tensor count, params, manifest values from the variant."""
68
+ with safe_open(str(path), framework="pt") as f:
69
+ keys = list(f.keys())
70
+ params = sum(f.get_tensor(k).numel() for k in keys)
71
+ manifest = {
72
+ k.split(".", 1)[1]: f.get_tensor(k).item()
73
+ for k in keys if k.startswith("manifest.") and f.get_tensor(k).numel() == 1
74
+ }
75
+ return {
76
+ "tensors": len(keys),
77
+ "params": params,
78
+ "size_mb": path.stat().st_size / (1024 * 1024),
79
+ "manifest": manifest,
80
+ }
81
+
82
+
83
+ def eval_variant(path: Path, device: str = "cpu", timeout: int = 600) -> dict:
84
+ """Run eval.py against a variant and parse fitness."""
85
+ cmd = [
86
+ sys.executable, str(ROOT / "eval.py"),
87
+ "--model", str(path),
88
+ "--device", device,
89
+ "--quiet",
90
+ ]
91
+ t0 = time.time()
92
+ rc, log = run(cmd, timeout=timeout)
93
+ dt = time.time() - t0
94
+
95
+ fitness = None
96
+ total_tests = None
97
+ status = "ERROR"
98
+ for line in log.splitlines():
99
+ line = line.strip()
100
+ if line.startswith("Fitness:"):
101
+ try:
102
+ fitness = float(line.split()[1])
103
+ except Exception:
104
+ pass
105
+ elif line.startswith("Total tests:"):
106
+ try:
107
+ total_tests = int(line.split()[2])
108
+ except Exception:
109
+ pass
110
+ elif line.startswith("STATUS:"):
111
+ status = line.split()[1]
112
+ return {
113
+ "rc": rc,
114
+ "fitness": fitness,
115
+ "total_tests": total_tests,
116
+ "status": status,
117
+ "elapsed_s": dt,
118
+ "log_tail": "\n".join(log.splitlines()[-15:]),
119
+ }
120
+
121
+
122
+ def main() -> None:
123
+ rows = []
124
+ print(f"Building 18 variants into {OUT_DIR}\n")
125
+ for bits in BITS:
126
+ for profile in PROFILES:
127
+ label = f"bits={bits} profile={profile}"
128
+ print(f"=== {label} ===", flush=True)
129
+ t0 = time.time()
130
+ try:
131
+ path = build_variant(bits, profile)
132
+ bt = time.time() - t0
133
+ meta = measure_variant(path)
134
+ ev = eval_variant(path, device="cpu", timeout=900)
135
+ rows.append({
136
+ "bits": bits, "profile": profile,
137
+ "filename": path.name,
138
+ "build_s": bt,
139
+ **meta,
140
+ **{k: ev[k] for k in ("fitness", "total_tests", "status", "elapsed_s")},
141
+ "log_tail": ev["log_tail"] if ev["status"] != "PASS" else "",
142
+ })
143
+ print(f" built in {bt:.1f}s size={meta['size_mb']:.1f}MB"
144
+ f" params={meta['params']:,} tensors={meta['tensors']:,}")
145
+ print(f" eval: fitness={ev['fitness']} tests={ev['total_tests']}"
146
+ f" status={ev['status']} ({ev['elapsed_s']:.1f}s)")
147
+ if ev["status"] != "PASS":
148
+ print(" --- failure tail ---")
149
+ print(" " + "\n ".join(ev["log_tail"].splitlines()))
150
+ print(" --------------------")
151
+ except Exception as e:
152
+ print(f" EXCEPTION: {e}")
153
+ rows.append({"bits": bits, "profile": profile, "error": str(e)})
154
+ print()
155
+
156
+ print("=" * 88)
157
+ print(" SUMMARY")
158
+ print("=" * 88)
159
+ header = f"{'bits':>4} {'profile':<11} {'size_MB':>8} {'tensors':>8} {'params':>11} {'fitness':>9} {'tests':>6} {'status':>7}"
160
+ print(header)
161
+ print("-" * len(header))
162
+ for r in rows:
163
+ if "error" in r:
164
+ print(f"{r['bits']:>4} {r['profile']:<11} ERROR: {r['error'][:60]}")
165
+ continue
166
+ fit = f"{r['fitness']:.4f}" if r['fitness'] is not None else "n/a"
167
+ tests = r['total_tests'] if r['total_tests'] is not None else "?"
168
+ print(f"{r['bits']:>4} {r['profile']:<11} {r['size_mb']:>8.1f} "
169
+ f"{r['tensors']:>8,} {r['params']:>11,} "
170
+ f"{fit:>9} {tests:>6} {r['status']:>7}")
171
+
172
+ fail = [r for r in rows if r.get("status") != "PASS" or "error" in r]
173
+ print()
174
+ if fail:
175
+ print(f"FAILURES: {len(fail)}/{len(rows)}")
176
+ else:
177
+ print(f"ALL {len(rows)} VARIANTS PASS")
178
+
179
+
180
+ if __name__ == "__main__":
181
+ main()
eval_all.py ADDED
@@ -0,0 +1,613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified evaluation harness for any threshold-computer variant.
3
+
4
+ Drops the `--cpu-test` smoke test (which was hardcoded to 16-bit/64KB) and
5
+ adds variant-aware sweep modes. The same harness handles every (data_bits,
6
+ addr_bits) configuration: it reads the manifest from each safetensors file,
7
+ runs the BatchedFitnessEvaluator at the right device, and reports per-file
8
+ plus per-category results.
9
+
10
+ Usage:
11
+ python eval_all.py path/to/file.safetensors # one file
12
+ python eval_all.py variants/ # every .safetensors in dir
13
+ python eval_all.py --device cpu variants/ # CPU only (default)
14
+ python eval_all.py --pop_size 32 variants/ # batched pop eval
15
+ python eval_all.py --debug path/to/file.safetensors # per-circuit detail
16
+ python eval_all.py --cpu-program PATH # also run an assembled program
17
+ # through the threshold CPU
18
+ # sized to the file's manifest
19
+
20
+ Exit code:
21
+ 0 if all files PASS (fitness >= 0.9999)
22
+ N where N is the number of FAILing files
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import argparse
28
+ import json
29
+ import os
30
+ import sys
31
+ import time
32
+ from pathlib import Path
33
+ from typing import Dict, List, Optional, Tuple
34
+
35
+ import torch
36
+ from safetensors import safe_open
37
+
38
+ # Reuse eval.py's evaluator (variant-aware)
39
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
40
+ from eval import (
41
+ BatchedFitnessEvaluator,
42
+ create_population,
43
+ load_model,
44
+ get_manifest,
45
+ heaviside,
46
+ int_to_bits,
47
+ bits_to_int,
48
+ bits_msb_to_lsb,
49
+ )
50
+
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # Variant-aware threshold ALU + CPU
54
+ # ---------------------------------------------------------------------------
55
+
56
+ class GenericThresholdALU:
57
+ """Variant-aware threshold ALU. Reads manifest, runs ADD/SUB/CMP/MUL etc.
58
+
59
+ Currently supports the 8-bit ALU primitives (ripplecarry8bit, sub8bit,
60
+ cmp8bit, mul/div). For wider data paths, use the BatchedFitnessEvaluator
61
+ which already handles 16/32-bit comparators, subtractors, etc.
62
+ """
63
+
64
+ def __init__(self, tensors: Dict[str, torch.Tensor], data_bits: int):
65
+ self.T = tensors
66
+ self.data_bits = data_bits
67
+
68
+ def _g(self, name, inputs):
69
+ w = self.T[name + ".weight"].view(-1)
70
+ b = self.T[name + ".bias"].view(-1)
71
+ return int(heaviside((torch.tensor(inputs, dtype=torch.float32) * w).sum() + b).item())
72
+
73
+ def _xor_or_nand(self, prefix, inputs):
74
+ a, b_ = inputs
75
+ h_or = self._g(f"{prefix}.layer1.or", [a, b_])
76
+ h_nand = self._g(f"{prefix}.layer1.nand", [a, b_])
77
+ return self._g(f"{prefix}.layer2", [h_or, h_nand])
78
+
79
+ def _fa(self, prefix, a, b, cin):
80
+ s1 = self._xor_or_nand(f"{prefix}.ha1.sum", [a, b])
81
+ c1 = self._g(f"{prefix}.ha1.carry", [a, b])
82
+ s2 = self._xor_or_nand(f"{prefix}.ha2.sum", [s1, cin])
83
+ c2 = self._g(f"{prefix}.ha2.carry", [s1, cin])
84
+ cout = self._g(f"{prefix}.carry_or", [c1, c2])
85
+ return s2, cout
86
+
87
+ def add8(self, a, b):
88
+ a_lsb = list(reversed(int_to_bits(a, 8)))
89
+ b_lsb = list(reversed(int_to_bits(b, 8)))
90
+ carry = 0
91
+ s_lsb = []
92
+ for i in range(8):
93
+ s, carry = self._fa(f"arithmetic.ripplecarry8bit.fa{i}", a_lsb[i], b_lsb[i], carry)
94
+ s_lsb.append(s)
95
+ return bits_to_int(list(reversed(s_lsb))), carry
96
+
97
+ def sub8(self, a, b):
98
+ a_lsb = list(reversed(int_to_bits(a, 8)))
99
+ b_lsb = list(reversed(int_to_bits(b, 8)))
100
+ carry = 1
101
+ d_lsb = []
102
+ for i in range(8):
103
+ notb = self._g(f"arithmetic.sub8bit.notb{i}", [b_lsb[i]])
104
+ x1 = self._xor_or_nand(f"arithmetic.sub8bit.fa{i}.xor1", [a_lsb[i], notb])
105
+ x2 = self._xor_or_nand(f"arithmetic.sub8bit.fa{i}.xor2", [x1, carry])
106
+ and1 = self._g(f"arithmetic.sub8bit.fa{i}.and1", [a_lsb[i], notb])
107
+ and2 = self._g(f"arithmetic.sub8bit.fa{i}.and2", [x1, carry])
108
+ carry = self._g(f"arithmetic.sub8bit.fa{i}.or_carry", [and1, and2])
109
+ d_lsb.append(x2)
110
+ return bits_to_int(list(reversed(d_lsb))), carry
111
+
112
+ def cmp8(self, a, b, kind):
113
+ inp = int_to_bits(a, 8) + int_to_bits(b, 8)
114
+ if kind == "eq":
115
+ h_geq = self._g("arithmetic.equality8bit.layer1.geq", inp)
116
+ h_leq = self._g("arithmetic.equality8bit.layer1.leq", inp)
117
+ return self._g("arithmetic.equality8bit.layer2", [h_geq, h_leq])
118
+ return self._g(f"arithmetic.{kind}8bit", inp)
119
+
120
+ def mul8(self, a, b):
121
+ ab = int_to_bits(a, 8)
122
+ bb = int_to_bits(b, 8)
123
+ result = 0
124
+ for j in range(8):
125
+ if bb[j] == 0:
126
+ continue
127
+ row = 0
128
+ for i in range(8):
129
+ pp = self._g(f"alu.alu8bit.mul.pp.a{i}b{j}", [ab[i], bb[j]])
130
+ row |= (pp << (7 - i))
131
+ shift = 7 - j
132
+ result, _ = self.add8(result & 0xFF, (row << shift) & 0xFF)
133
+ return result & 0xFF
134
+
135
+ # ----- N-bit primitives (for 16-bit and 32-bit variants) ----------------
136
+
137
+ def add_n(self, a: int, b: int, bits: int):
138
+ """Width-generic ripple-carry add via arithmetic.ripplecarry{N}bit."""
139
+ prefix = f"arithmetic.ripplecarry{bits}bit"
140
+ a_lsb = list(reversed(int_to_bits(a, bits)))
141
+ b_lsb = list(reversed(int_to_bits(b, bits)))
142
+ carry = 0
143
+ s_lsb = []
144
+ for i in range(bits):
145
+ s, carry = self._fa(f"{prefix}.fa{i}", a_lsb[i], b_lsb[i], carry)
146
+ s_lsb.append(s)
147
+ return bits_to_int(list(reversed(s_lsb))), carry
148
+
149
+ def sub_n(self, a: int, b: int, bits: int):
150
+ """N-bit two's-complement subtract via arithmetic.sub{N}bit (N >= 16).
151
+
152
+ Structure (per build.add_sub_nbits): N NOT gates + N standard full adders.
153
+ """
154
+ prefix = f"arithmetic.sub{bits}bit"
155
+ a_lsb = list(reversed(int_to_bits(a, bits)))
156
+ b_lsb = list(reversed(int_to_bits(b, bits)))
157
+ # NOT each B bit
158
+ notb = [self._g(f"{prefix}.not_b.bit{i}", [b_lsb[i]]) for i in range(bits)]
159
+ carry = 1 # carry-in = 1 for two's-complement
160
+ d_lsb = []
161
+ for i in range(bits):
162
+ s, carry = self._fa(f"{prefix}.fa{i}", a_lsb[i], notb[i], carry)
163
+ d_lsb.append(s)
164
+ return bits_to_int(list(reversed(d_lsb))), carry
165
+
166
+ def cmp_n(self, a: int, b: int, kind: str, bits: int):
167
+ """N-bit comparator. For bits <= 16 single-layer; bits == 32 cascaded."""
168
+ a_bits = int_to_bits(a, bits)
169
+ b_bits = int_to_bits(b, bits)
170
+ if bits <= 16:
171
+ inp = a_bits + b_bits
172
+ if kind == "eq":
173
+ h_geq = self._g(f"arithmetic.equality{bits}bit.layer1.geq", inp)
174
+ h_leq = self._g(f"arithmetic.equality{bits}bit.layer1.leq", inp)
175
+ return self._g(f"arithmetic.equality{bits}bit.layer2", [h_geq, h_leq])
176
+ return self._g(f"arithmetic.{kind}{bits}bit", inp)
177
+ # 32-bit: cascaded byte-wise
178
+ prefix = f"arithmetic.cmp{bits}bit"
179
+ num_bytes = bits // 8
180
+ # per-byte gt/lt/eq
181
+ byte_gt, byte_lt, byte_eq = [], [], []
182
+ for bn in range(num_bytes):
183
+ ab = a_bits[bn*8:(bn+1)*8]
184
+ bb = b_bits[bn*8:(bn+1)*8]
185
+ byte_gt.append(self._g(f"{prefix}.byte{bn}.gt", ab + bb))
186
+ byte_lt.append(self._g(f"{prefix}.byte{bn}.lt", ab + bb))
187
+ geq = self._g(f"{prefix}.byte{bn}.eq.geq", ab + bb)
188
+ leq = self._g(f"{prefix}.byte{bn}.eq.leq", ab + bb)
189
+ byte_eq.append(self._g(f"{prefix}.byte{bn}.eq.and", [geq, leq]))
190
+ if kind == "equality":
191
+ # OR of all eq's, but the gate is `arithmetic.equality{bits}bit` with weight=[1,1,..,1]/bias=-num_bytes
192
+ return self._g(f"arithmetic.equality{bits}bit", byte_eq)
193
+ # cascade
194
+ cascade_gt = [byte_gt[0]]
195
+ cascade_lt = [byte_lt[0]]
196
+ for bn in range(1, num_bytes):
197
+ all_eq = self._g(f"{prefix}.cascade.gt.stage{bn}.all_eq", byte_eq[:bn])
198
+ cascade_gt.append(self._g(f"{prefix}.cascade.gt.stage{bn}.and", [all_eq, byte_gt[bn]]))
199
+ all_eq2 = self._g(f"{prefix}.cascade.lt.stage{bn}.all_eq", byte_eq[:bn])
200
+ cascade_lt.append(self._g(f"{prefix}.cascade.lt.stage{bn}.and", [all_eq2, byte_lt[bn]]))
201
+ if kind == "greaterthan":
202
+ return self._g(f"arithmetic.greaterthan{bits}bit", cascade_gt)
203
+ if kind == "lessthan":
204
+ return self._g(f"arithmetic.lessthan{bits}bit", cascade_lt)
205
+ raise ValueError(f"unsupported cmp kind {kind} for bits={bits}")
206
+
207
+ def mul_n(self, a: int, b: int, bits: int):
208
+ """N-bit shift-add multiply (low N bits only)."""
209
+ ab = int_to_bits(a, bits)
210
+ bb = int_to_bits(b, bits)
211
+ mask = (1 << bits) - 1
212
+ result = 0
213
+ for j in range(bits):
214
+ if bb[j] == 0:
215
+ continue
216
+ row = 0
217
+ for i in range(bits):
218
+ pp = self._g(f"alu.alu{bits}bit.mul.pp.a{i}b{j}", [ab[i], bb[j]])
219
+ row |= (pp << (bits - 1 - i))
220
+ shift = (bits - 1) - j
221
+ result, _ = self.add_n(result & mask, (row << shift) & mask, bits)
222
+ return result & mask
223
+
224
+
225
+ class GenericThresholdCPU:
226
+ """Variant-aware CPU runtime. Sized from the variant's manifest."""
227
+
228
+ def __init__(self, tensors: Dict[str, torch.Tensor]):
229
+ self.T = tensors
230
+ m = get_manifest(tensors)
231
+ self.data_bits = m["data_bits"]
232
+ self.addr_bits = m["addr_bits"]
233
+ self.mem_bytes = m["memory_bytes"]
234
+ # 8-bit CPU primitives (ripplecarry8bit, sub8bit, alu.alu8bit.*, memory.*,
235
+ # control.*) are present in every variant regardless of manifest data_bits.
236
+ # Wider data widths simply add additional standalone ALU primitives.
237
+ if self.mem_bytes == 0:
238
+ raise NotImplementedError(
239
+ "Pure-ALU variants have no memory; cannot run CPU programs"
240
+ )
241
+ self.alu = GenericThresholdALU(tensors, 8)
242
+
243
+ def _addr_decode(self, addr):
244
+ bits = torch.tensor(int_to_bits(addr, self.addr_bits), dtype=torch.float32)
245
+ w = self.T["memory.addr_decode.weight"]
246
+ b = self.T["memory.addr_decode.bias"]
247
+ return heaviside((w * bits).sum(dim=1) + b)
248
+
249
+ def mem_read(self, mem, addr):
250
+ sel = self._addr_decode(addr)
251
+ mem_bits = torch.tensor(
252
+ [int_to_bits(byte, 8) for byte in mem], dtype=torch.float32
253
+ )
254
+ and_w = self.T["memory.read.and.weight"]
255
+ and_b = self.T["memory.read.and.bias"]
256
+ or_w = self.T["memory.read.or.weight"]
257
+ or_b = self.T["memory.read.or.bias"]
258
+ out = []
259
+ for bit in range(8):
260
+ inp = torch.stack([mem_bits[:, bit], sel], dim=1)
261
+ and_out = heaviside((inp * and_w[bit]).sum(dim=1) + and_b[bit])
262
+ out.append(int(heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item()))
263
+ return bits_to_int(out)
264
+
265
+ def mem_write(self, mem, addr, value):
266
+ sel = self._addr_decode(addr)
267
+ data_bits = torch.tensor(int_to_bits(value, 8), dtype=torch.float32)
268
+ mem_bits = torch.tensor(
269
+ [int_to_bits(byte, 8) for byte in mem], dtype=torch.float32
270
+ )
271
+ sel_w = self.T["memory.write.sel.weight"]
272
+ sel_b = self.T["memory.write.sel.bias"]
273
+ nsel_w = self.T["memory.write.nsel.weight"].squeeze(1)
274
+ nsel_b = self.T["memory.write.nsel.bias"]
275
+ and_old_w = self.T["memory.write.and_old.weight"]
276
+ and_old_b = self.T["memory.write.and_old.bias"]
277
+ and_new_w = self.T["memory.write.and_new.weight"]
278
+ and_new_b = self.T["memory.write.and_new.bias"]
279
+ or_w = self.T["memory.write.or.weight"]
280
+ or_b = self.T["memory.write.or.bias"]
281
+ we = torch.ones_like(sel)
282
+ sel_inp = torch.stack([sel, we], dim=1)
283
+ write_sel = heaviside((sel_inp * sel_w).sum(dim=1) + sel_b)
284
+ nsel = heaviside(write_sel * nsel_w + nsel_b)
285
+ for bit in range(8):
286
+ old = mem_bits[:, bit]
287
+ data_bit = data_bits[bit].expand(self.mem_bytes)
288
+ inp_old = torch.stack([old, nsel], dim=1)
289
+ inp_new = torch.stack([data_bit, write_sel], dim=1)
290
+ and_old = heaviside((inp_old * and_old_w[:, bit]).sum(dim=1) + and_old_b[:, bit])
291
+ and_new = heaviside((inp_new * and_new_w[:, bit]).sum(dim=1) + and_new_b[:, bit])
292
+ or_inp = torch.stack([and_old, and_new], dim=1)
293
+ new_bit = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit])
294
+ mem_bits[:, bit] = new_bit
295
+ return [bits_to_int([int(b) for b in mem_bits[i].tolist()]) for i in range(self.mem_bytes)]
296
+
297
+ def step(self, state):
298
+ if state["halted"]:
299
+ return state
300
+ s = dict(state)
301
+ s["mem"] = state["mem"][:]
302
+ s["regs"] = state["regs"][:]
303
+ s["flags"] = state["flags"][:]
304
+ addr_mask = (1 << self.addr_bits) - 1
305
+ pc = s["pc"]
306
+ hi = self.mem_read(s["mem"], pc & addr_mask)
307
+ lo = self.mem_read(s["mem"], (pc + 1) & addr_mask)
308
+ ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
309
+ opcode = (ir >> 12) & 0xF
310
+ rd = (ir >> 10) & 0x3
311
+ rs = (ir >> 8) & 0x3
312
+ imm = ir & 0xFF
313
+ next_pc = (pc + 2) & addr_mask
314
+ addr_full = None
315
+ if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
316
+ ah = self.mem_read(s["mem"], next_pc)
317
+ al = self.mem_read(s["mem"], (next_pc + 1) & addr_mask)
318
+ addr_full = ((ah & 0xFF) << 8) | (al & 0xFF)
319
+ next_pc = (next_pc + 2) & addr_mask
320
+ addr = (addr_full & addr_mask) if addr_full is not None else None
321
+ a = s["regs"][rd]
322
+ b = s["regs"][rs]
323
+ result = a
324
+ carry = 0
325
+ write_result = True
326
+ if opcode == 0x0:
327
+ result, carry = self.alu.add8(a, b)
328
+ elif opcode == 0x1:
329
+ result, carry = self.alu.sub8(a, b)
330
+ elif opcode == 0x7:
331
+ result = self.alu.mul8(a, b)
332
+ elif opcode == 0x9:
333
+ r2, carry = self.alu.sub8(a, b)
334
+ z = 1 if r2 == 0 else 0
335
+ n = 1 if (r2 & 0x80) else 0
336
+ s["flags"] = [z, n, carry, 0]
337
+ write_result = False
338
+ elif opcode == 0xA:
339
+ result = self.mem_read(s["mem"], addr)
340
+ elif opcode == 0xB:
341
+ s["mem"] = self.mem_write(s["mem"], addr, b & 0xFF)
342
+ write_result = False
343
+ elif opcode == 0xC:
344
+ s["pc"] = addr
345
+ return s
346
+ elif opcode == 0xD:
347
+ cond = imm & 0x7
348
+ z, n, c, v = s["flags"]
349
+ take = [z == 1, z == 0, c == 1, c == 0,
350
+ n == 1, n == 0, v == 1, v == 0][cond]
351
+ s["pc"] = addr if take else next_pc
352
+ return s
353
+ elif opcode == 0xF:
354
+ s["halted"] = True
355
+ return s
356
+
357
+ if write_result and opcode != 0x9:
358
+ s["regs"][rd] = result & 0xFF
359
+ if opcode in (0x0, 0x1, 0x7):
360
+ z = 1 if (result & 0xFF) == 0 else 0
361
+ n = 1 if (result & 0x80) else 0
362
+ s["flags"] = [z, n, carry, 0]
363
+ s["pc"] = next_pc
364
+ return s
365
+
366
+ def run(self, state, max_cycles=200):
367
+ s = state
368
+ cycles = 0
369
+ while not s["halted"] and cycles < max_cycles:
370
+ s = self.step(s)
371
+ cycles += 1
372
+ return s, cycles
373
+
374
+
375
+ def _encode_instr(opcode, rd, rs, imm):
376
+ return ((opcode & 0xF) << 12) | ((rd & 0x3) << 10) | ((rs & 0x3) << 8) | (imm & 0xFF)
377
+
378
+
379
+ def _w16(mem, addr, value):
380
+ mem[addr] = (value >> 8) & 0xFF
381
+ mem[addr + 1] = value & 0xFF
382
+
383
+
384
+ PROGRAM_MIN_BYTES = 0x84 # code 0x00..0x1F + data 0x80..0x83
385
+
386
+
387
+ def builtin_program(addr_bits: int) -> Tuple[List[int], int]:
388
+ """Sum 5+4+3+2+1 via a loop. Returns (mem, expected_result_at_0x83).
389
+
390
+ Compact layout: code at 0x00..0x1F (32 bytes), data at 0x80..0x83 (4 bytes).
391
+ Total footprint 132 bytes -- fits within scratchpad (256 B) and larger.
392
+ Requires addr_bits >= 8.
393
+ """
394
+ if (1 << addr_bits) < PROGRAM_MIN_BYTES:
395
+ raise ValueError(f"addr_bits={addr_bits} too small for builtin program")
396
+ mem = [0] * (1 << addr_bits)
397
+ mem[0x80] = 5 # initial counter
398
+ mem[0x81] = 1 # decrement
399
+ mem[0x82] = 0 # zero (for compare and accumulator init)
400
+ # mem[0x83] is the output
401
+ _w16(mem, 0x0000, _encode_instr(0xA, 1, 0, 0)); _w16(mem, 0x0002, 0x0080)
402
+ _w16(mem, 0x0004, _encode_instr(0xA, 2, 0, 0)); _w16(mem, 0x0006, 0x0081)
403
+ _w16(mem, 0x0008, _encode_instr(0xA, 3, 0, 0)); _w16(mem, 0x000A, 0x0082)
404
+ _w16(mem, 0x000C, _encode_instr(0xA, 0, 0, 0)); _w16(mem, 0x000E, 0x0082)
405
+ _w16(mem, 0x0010, _encode_instr(0x0, 0, 1, 0))
406
+ _w16(mem, 0x0012, _encode_instr(0x1, 1, 2, 0))
407
+ _w16(mem, 0x0014, _encode_instr(0x9, 1, 3, 0))
408
+ _w16(mem, 0x0016, _encode_instr(0xD, 0, 0, 0x01)); _w16(mem, 0x0018, 0x0010)
409
+ _w16(mem, 0x001A, _encode_instr(0xB, 0, 0, 0)); _w16(mem, 0x001C, 0x0083)
410
+ _w16(mem, 0x001E, _encode_instr(0xF, 0, 0, 0))
411
+ return mem, 15
412
+
413
+
414
+ # ---------------------------------------------------------------------------
415
+ # Eval driver
416
+ # ---------------------------------------------------------------------------
417
+
418
def list_safetensors(path: Path) -> List[Path]:
    """Resolve *path* to the safetensors files to evaluate.

    A file path is returned as a single-item list; a directory yields its
    ``*.safetensors`` regular files sorted by name; anything else yields [].
    """
    if path.is_file():
        return [path]
    if not path.is_dir():
        return []
    candidates = (entry for entry in path.glob("*.safetensors") if entry.is_file())
    return sorted(candidates)
424
+
425
+
426
def evaluate_one(path: Path, device: str, pop_size: int, debug: bool, run_cpu_program: bool) -> Dict:
    """Load one safetensors variant and score it.

    Runs the BatchedFitnessEvaluator on the file's weights. When
    *run_cpu_program* is set, additionally executes the built-in assembled
    program on the threshold CPU (any variant whose memory can hold
    PROGRAM_MIN_BYTES) and, for 16/32-bit data widths, a chained N-bit ALU
    test. Returns a result dict whose "status" is PASS, FAIL, or ERROR.
    """
    out: Dict = {"path": str(path), "filename": path.name}
    try:
        tensors = load_model(str(path))
    except Exception as e:
        out.update(error=f"load failed: {e}", status="ERROR")
        return out

    manifest = get_manifest(tensors)
    out.update(
        size_mb=path.stat().st_size / (1024 * 1024),
        tensors=len(tensors),
        params=sum(t.numel() for t in tensors.values()),
        manifest=manifest,
    )

    # Move to device for the batched evaluation.
    tensors = {k: v.to(device) for k, v in tensors.items()}

    try:
        evaluator = BatchedFitnessEvaluator(device=device, model_path=str(path), tensors=tensors)
        population = create_population(tensors, pop_size=pop_size, device=device)
        t0 = time.perf_counter()
        fitness = evaluator.evaluate(population, debug=debug)
        elapsed = time.perf_counter() - t0
        f0 = float(fitness[0].item()) if pop_size == 1 else float(fitness.mean().item())
        out.update(
            fitness=f0,
            total_tests=evaluator.total_tests,
            elapsed_s=elapsed,
            categories={k: (float(v[0]), int(v[1])) for k, v in evaluator.category_scores.items()},
            status="PASS" if f0 >= 0.9999 else "FAIL",
        )
    except Exception as e:
        out.update(error=f"eval failed: {type(e).__name__}: {e}", status="ERROR")
        return out

    if run_cpu_program:
        # CPU program test (8-bit CPU primitives are in every variant).
        _run_cpu_program_test(out, tensors, manifest)
        # Wider-ALU chain test only fires for 16/32-bit data widths.
        _run_alu_chain_test(out, tensors, manifest["data_bits"])

    return out


def _run_cpu_program_test(out: Dict, tensors: Dict, manifest: Dict) -> None:
    """Run the built-in countdown program through the threshold CPU; record result into *out*."""
    if manifest["memory_bytes"] < PROGRAM_MIN_BYTES:
        out["cpu_program"] = {"skipped": f"mem={manifest['memory_bytes']}B < {PROGRAM_MIN_BYTES}"}
        return
    try:
        # The threshold CPU runtime works on CPU tensors regardless of eval device.
        cpu_tensors = {k: v.cpu() for k, v in tensors.items()}
        cpu = GenericThresholdCPU(cpu_tensors)
        mem, expected = builtin_program(manifest["addr_bits"])
        state = {"pc": 0, "regs": [0] * 4, "flags": [0] * 4, "mem": mem, "halted": False}
        t0 = time.perf_counter()
        final, cycles = cpu.run(state, max_cycles=200)
        cpu_elapsed = time.perf_counter() - t0
        got = final["mem"][0x83]
        out["cpu_program"] = {
            "ok": got == expected,
            "got": got,
            "expected": expected,
            "cycles": cycles,
            "elapsed_s": cpu_elapsed,
        }
        if got != expected:
            out["status"] = "FAIL"
    except Exception as e:
        # Best-effort: a broken CPU run is reported but does not flip status here,
        # matching the original behaviour.
        out["cpu_program"] = {"error": str(e)}


def _run_alu_chain_test(out: Dict, tensors: Dict, bits: int) -> None:
    """Chained N-bit ALU sanity checks for 16/32-bit variants; record result into *out*."""
    if bits not in (16, 32):
        return
    try:
        alu_tensors = {k: v.cpu() for k, v in tensors.items()}
        alu = GenericThresholdALU(alu_tensors, bits)
        t0 = time.perf_counter()
        if bits == 16:
            x, y = 1234, 5678
            z, _ = alu.add_n(x, y, 16); assert z == (x + y) & 0xFFFF
            w, _ = alu.sub_n(z, x, 16); assert w == (z - x) & 0xFFFF, (w, z - x)
            gt = alu.cmp_n(z, x, "greaterthan", 16); assert gt == 1
            lt = alu.cmp_n(x, z, "lessthan", 16); assert lt == 1
            eq = alu.cmp_n(w, y, "eq", 16); assert eq == 1
            p = alu.mul_n(123, 5, 16); assert p == (123 * 5) & 0xFFFF
        else:  # 32
            x, y = 1_000_000, 999_000
            z, _ = alu.sub_n(x, y, 32); assert z == 1_000
            s, _ = alu.add_n(z, x, 32); assert s == 1_001_000
            p = alu.mul_n(z, 100, 32); assert p == 100_000
            gt = alu.cmp_n(x, y, "greaterthan", 32); assert gt == 1
            lt = alu.cmp_n(y, x, "lessthan", 32); assert lt == 1
            eq = alu.cmp_n(p, 100_000, "equality", 32); assert eq == 1
        chain_dt = time.perf_counter() - t0
        out[f"alu_chain_{bits}"] = {"ok": True, "elapsed_s": chain_dt}
    except AssertionError as e:
        out[f"alu_chain_{bits}"] = {"ok": False, "error": f"chain mismatch: {e}"}
        out["status"] = "FAIL"
    except Exception as e:
        out[f"alu_chain_{bits}"] = {"ok": False, "error": f"{type(e).__name__}: {e}"}
        out["status"] = "FAIL"
522
+
523
+
524
def print_row(r: Dict, show_cpu: bool) -> None:
    """Render one variant-result dict as a fixed-width summary line.

    A dict carrying "error" prints a short error line; otherwise manifest,
    size, fitness and status columns are emitted, plus CPU-program and
    ALU-chain columns when *show_cpu* is set.
    """
    if "error" in r:
        print(f" {r['filename']:<48} ERROR: {r['error'][:80]}")
        return
    m = r["manifest"]
    fit = f"{r['fitness']:.4f}" if r.get("fitness") is not None else "n/a"
    cpu_col = ""
    if show_cpu and "cpu_program" in r:
        cp = r["cpu_program"]
        if cp.get("ok"):
            cpu_col = f" CPU OK ({cp['cycles']}cyc/{cp['elapsed_s']:.1f}s)"
        elif "skipped" in cp:
            cpu_col = " CPU SKIP"  # plain literal: no placeholders needed
        elif "error" in cp:
            cpu_col = " CPU ERR"   # plain literal: no placeholders needed
        else:
            cpu_col = f" CPU FAIL ({cp.get('got')}!={cp.get('expected')})"
    chain_col = ""
    if show_cpu:
        # At most one of alu_chain_16 / alu_chain_32 is present per variant.
        for bits in (16, 32):
            key = f"alu_chain_{bits}"
            if key in r:
                ch = r[key]
                if ch.get("ok"):
                    chain_col = f" ALU{bits} OK ({ch['elapsed_s']:.2f}s)"
                else:
                    chain_col = f" ALU{bits} FAIL"
    print(
        f" {r['filename']:<48} d={m['data_bits']:>2}b a={m['addr_bits']:>2}b "
        f"mem={m['memory_bytes']:>6}B size={r['size_mb']:>6.1f}MB "
        f"params={r['params']:>10,} fit={fit:>6} tests={r['total_tests']:>5} "
        f"{r['status']:>5}{cpu_col}{chain_col}"
    )
557
+
558
+
559
def main() -> int:
    """CLI entry point. Returns the number of failing variants (0 == all pass, 2 == no files)."""
    parser = argparse.ArgumentParser(description="Variant-agnostic eval harness")
    parser.add_argument("path", help="Path to .safetensors file or directory of files")
    parser.add_argument("--device", default="cpu", help="cpu (default) or cuda")
    parser.add_argument("--pop_size", type=int, default=1)
    parser.add_argument("--debug", action="store_true", help="Per-circuit detail per file")
    # Help text fixed: the program runs on ANY variant whose memory can hold the
    # 132-byte footprint (PROGRAM_MIN_BYTES), not only 8-bit/>=512B builds, and
    # 16/32-bit variants additionally get a chained ALU test.
    parser.add_argument("--cpu-program", action="store_true",
                        help="Also run a small assembled program through the threshold CPU "
                             "(any variant with >= 132 B memory; 16/32-bit variants also "
                             "get a chained N-bit ALU test)")
    parser.add_argument("--json", action="store_true", help="Emit JSON results to stdout instead of a table")
    args = parser.parse_args()

    files = list_safetensors(Path(args.path))
    if not files:
        print(f"No .safetensors files found under {args.path}", file=sys.stderr)
        return 2

    print(f"Evaluating {len(files)} file(s) on {args.device}\n")
    results = []
    fail_count = 0
    for f in files:
        print(f"=== {f.name}")
        r = evaluate_one(f, device=args.device, pop_size=args.pop_size,
                         debug=args.debug, run_cpu_program=args.cpu_program)
        results.append(r)
        print_row(r, show_cpu=args.cpu_program)
        if r.get("status") != "PASS":
            fail_count += 1

    if args.json:
        # Make it serialisable: collapse integer-valued floats in the manifest.
        for r in results:
            r["manifest"] = {k: (int(v) if isinstance(v, float) and v.is_integer() else v)
                             for k, v in r.get("manifest", {}).items()}
        print(json.dumps(results, indent=2, default=str))
        return fail_count

    # Summary
    print()
    print("=" * 100)
    print(" SUMMARY")
    print("=" * 100)
    for r in results:
        print_row(r, show_cpu=args.cpu_program)

    print()
    if fail_count == 0:
        print(f"ALL {len(files)} variants PASS")
    else:
        print(f"{fail_count}/{len(files)} variants FAIL")
    return fail_count
610
+
611
+
612
+ if __name__ == "__main__":
613
+ sys.exit(main())
play.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hands-on playground for the 8bit-threshold-computer.
3
+
4
+ Loads the bundled safetensors model, reads its manifest, and exercises
5
+ threshold circuits at every level: raw Boolean gates, ALU arithmetic,
6
+ comparators, then a CPU runtime sized to the actual manifest that runs
7
+ a small assembled program end-to-end through the threshold weights.
8
+ """
9
+
10
+ from __future__ import annotations
11
+ import os
12
+ import sys
13
+ import torch
14
+ from safetensors import safe_open
15
+
16
+ sys.path.insert(0, os.path.dirname(__file__))
17
+
18
+ # ---------------------------------------------------------------------------
19
+ # Load model + manifest
20
+ # ---------------------------------------------------------------------------
21
+
22
+ MODEL_PATH = os.path.join(os.path.dirname(__file__), "neural_computer.safetensors")
23
+
24
+
25
def heaviside(x):
    """Threshold activation: elementwise 1.0 where x >= 0, else 0.0 (float32)."""
    mask = x >= 0
    return mask.to(torch.float32)
27
+
28
+
29
def load_tensors(path):
    """Read every tensor from a safetensors file as float32, keyed by tensor name."""
    with safe_open(path, framework="pt") as f:
        return {name: f.get_tensor(name).float() for name in f.keys()}
35
+
36
+
37
+ print("Loading", MODEL_PATH)
38
+ T = load_tensors(MODEL_PATH)
39
+
40
+ DATA_BITS = int(T["manifest.data_bits"].item())
41
+ ADDR_BITS = int(T["manifest.addr_bits"].item())
42
+ MEM_BYTES = int(T["manifest.memory_bytes"].item())
43
+ REGISTERS = int(T["manifest.registers"].item())
44
+ print(f"Manifest: data={DATA_BITS}-bit, addr={ADDR_BITS}-bit, mem={MEM_BYTES}B, regs={REGISTERS}")
45
+ print(f"Tensors: {len(T):,}")
46
+ print(f"Total params: {sum(t.numel() for t in T.values()):,}")
47
+ print()
48
+
49
+
50
def gate(name, inputs):
    """Evaluate a single threshold gate `name` (no .weight/.bias suffix) on `inputs`."""
    weight = T[name + ".weight"].view(-1)
    bias = T[name + ".bias"].view(-1)
    x = torch.tensor(inputs, dtype=torch.float32)
    activation = torch.dot(x, weight) + bias
    return int(heaviside(activation).item())
56
+
57
+
58
def xor(prefix, inputs):
    """Two-layer XOR built from or/nand hidden units (ripple-carry adder naming)."""
    a, b = inputs
    hidden = [gate(f"{prefix}.layer1.or", [a, b]),
              gate(f"{prefix}.layer1.nand", [a, b])]
    return gate(f"{prefix}.layer2", hidden)
64
+
65
+
66
def xor_neuron(prefix, inputs):
    """Two-layer XOR built from neuron1/neuron2 hidden units (boolean.* naming)."""
    a, b = inputs
    hidden = [gate(f"{prefix}.layer1.neuron{i}", [a, b]) for i in (1, 2)]
    return gate(f"{prefix}.layer2", hidden)
72
+
73
+
74
def int_to_bits_msb(v, n):
    """Return the low n bits of v as a list, most-significant bit first."""
    return [(v >> shift) & 1 for shift in range(n - 1, -1, -1)]
76
+
77
+
78
def bits_msb_to_int(bits):
    """Fold an MSB-first bit list back into an integer (empty list -> 0)."""
    width = len(bits)
    return sum(int(bit) << (width - 1 - idx) for idx, bit in enumerate(bits))
83
+
84
+
85
+ # ---------------------------------------------------------------------------
86
+ # Demo 1: Boolean gates (README Usage example)
87
+ # ---------------------------------------------------------------------------
88
+
89
+ print("=" * 64)
90
+ print(" Demo 1: Boolean threshold gates")
91
+ print("=" * 64)
92
+ truth_2 = [(0, 0), (0, 1), (1, 0), (1, 1)]
93
+ for gname in ["and", "or", "nand", "nor", "implies"]:
94
+ row = " ".join(f"{a}{b}->{gate(f'boolean.{gname}', [a, b])}" for a, b in truth_2)
95
+ print(f" {gname:8} {row}")
96
+ # 2-layer gates (boolean.* uses neuron1/neuron2 naming)
97
+ for gname in ["xor", "xnor", "biimplies"]:
98
+ row = " ".join(f"{a}{b}->{xor_neuron(f'boolean.{gname}', [a, b])}" for a, b in truth_2)
99
+ print(f" {gname:8} {row}")
100
+ # NOT (1-input)
101
+ print(f" not 0->{gate('boolean.not', [0])} 1->{gate('boolean.not', [1])}")
102
+ print()
103
+
104
+
105
+ # ---------------------------------------------------------------------------
106
+ # Demo 2: 8-bit ALU operations via threshold weights
107
+ # ---------------------------------------------------------------------------
108
+
109
+ print("=" * 64)
110
+ print(" Demo 2: 8-bit ALU arithmetic (every gate is threshold logic)")
111
+ print("=" * 64)
112
+
113
+
114
def fa(prefix, a, b, cin):
    """Full adder assembled from two threshold half-adders plus a carry OR."""
    sum1 = xor(f"{prefix}.ha1.sum", [a, b])
    carry1 = gate(f"{prefix}.ha1.carry", [a, b])
    total = xor(f"{prefix}.ha2.sum", [sum1, cin])
    carry2 = gate(f"{prefix}.ha2.carry", [sum1, cin])
    carry_out = gate(f"{prefix}.carry_or", [carry1, carry2])
    return total, carry_out
121
+
122
+
123
def alu_add(a, b):
    """8-bit ripple-carry addition through threshold full-adders; returns (sum, carry_out)."""
    a_lsb = list(reversed(int_to_bits_msb(a, 8)))
    b_lsb = list(reversed(int_to_bits_msb(b, 8)))
    carry = 0
    out_msb = []
    for i in range(8):
        s, carry = fa(f"arithmetic.ripplecarry8bit.fa{i}", a_lsb[i], b_lsb[i], carry)
        out_msb.insert(0, s)  # prepend: building MSB-first directly
    return bits_msb_to_int(out_msb), carry
135
+
136
+
137
def alu_sub(a, b):
    """8-bit A - B via two's complement (invert B, carry-in 1); returns (diff, no_borrow)."""
    a_lsb = list(reversed(int_to_bits_msb(a, 8)))
    b_lsb = list(reversed(int_to_bits_msb(b, 8)))
    carry = 1  # the +1 that completes the two's complement of B
    diff_msb = []
    for i in range(8):
        b_inv = gate(f"arithmetic.sub8bit.notb{i}", [b_lsb[i]])
        partial = xor(f"arithmetic.sub8bit.fa{i}.xor1", [a_lsb[i], b_inv])
        bit_out = xor(f"arithmetic.sub8bit.fa{i}.xor2", [partial, carry])
        gen = gate(f"arithmetic.sub8bit.fa{i}.and1", [a_lsb[i], b_inv])
        prop = gate(f"arithmetic.sub8bit.fa{i}.and2", [partial, carry])
        carry = gate(f"arithmetic.sub8bit.fa{i}.or_carry", [gen, prop])
        diff_msb.insert(0, bit_out)
    return bits_msb_to_int(diff_msb), carry
152
+
153
+
154
def alu_compare(a, b, kind):
    """8-bit comparator: kind in {"greaterthan", "lessthan", "eq"}.

    GT/LT are single-layer threshold gates; EQ is two-layer (geq AND leq).
    """
    inp = int_to_bits_msb(a, 8) + int_to_bits_msb(b, 8)
    if kind != "eq":
        return gate(f"arithmetic.{kind}8bit", inp)
    hidden = [
        gate("arithmetic.equality8bit.layer1.geq", inp),
        gate("arithmetic.equality8bit.layer1.leq", inp),
    ]
    return gate("arithmetic.equality8bit.layer2", hidden)
164
+
165
+
166
def alu_mul(a, b):
    """Shift-add multiply: threshold AND partial products accumulated with alu_add (low 8 bits)."""
    a_bits = int_to_bits_msb(a, 8)
    b_bits = int_to_bits_msb(b, 8)
    # All 64 partial products go through the threshold AND gates.
    pp = [
        [gate(f"alu.alu8bit.mul.pp.a{i}b{j}", [a_bits[i], b_bits[j]]) for j in range(8)]
        for i in range(8)
    ]
    result = 0
    for j in range(8):  # j=0 is the MSB of b, so its shift weight is 7-j
        if not b_bits[j]:
            continue  # whole row would be zero; skip the add
        row = 0
        for i in range(8):
            row |= pp[i][j] << (7 - i)
        # Accumulate in 8 bits, dropping overflow above bit 7.
        result, _ = alu_add(result & 0xFF, (row << (7 - j)) & 0xFF)
    return result & 0xFF
185
+
186
+
187
+ cases_arith = [(5, 3), (37, 100), (200, 99), (255, 1), (127, 128), (15, 17)]
188
+ print("ADD:")
189
+ for a, b in cases_arith:
190
+ r, c = alu_add(a, b)
191
+ expect = (a + b) & 0xFF
192
+ ok = "OK" if r == expect else "FAIL"
193
+ print(f" {a:3} + {b:3} = {r:3} (carry={c}) expected {expect:3} [{ok}]")
194
+
195
+ print("SUB:")
196
+ for a, b in cases_arith:
197
+ r, c = alu_sub(a, b)
198
+ expect = (a - b) & 0xFF
199
+ ok = "OK" if r == expect else "FAIL"
200
+ print(f" {a:3} - {b:3} = {r:3} (no_borrow={c}) expected {expect:3} [{ok}]")
201
+
202
+ print("CMP:")
203
+ cmp_cases = [(50, 30), (30, 50), (77, 77), (255, 0), (0, 255), (128, 127)]
204
+ for a, b in cmp_cases:
205
+ gt = alu_compare(a, b, "greaterthan")
206
+ lt = alu_compare(a, b, "lessthan")
207
+ eq = alu_compare(a, b, "eq")
208
+ print(f" {a:3} vs {b:3} -> GT={gt} LT={lt} EQ={eq}")
209
+
210
+ print("MUL (low 8 bits):")
211
+ for a, b in [(12, 11), (15, 17), (8, 32), (200, 3), (0, 99), (1, 255)]:
212
+ r = alu_mul(a, b)
213
+ expect = (a * b) & 0xFF
214
+ ok = "OK" if r == expect else "FAIL"
215
+ print(f" {a:3} * {b:3} = {r:3} expected {expect:3} [{ok}]")
216
+ print()
217
+
218
+
219
+ # ---------------------------------------------------------------------------
220
+ # Demo 3: A 4-bit divisibility test (mod 5) - non-linearly-separable
221
+ # ---------------------------------------------------------------------------
222
+
223
+ print("=" * 64)
224
+ print(" Demo 3: mod-5 divisibility (multi-layer, hand-constructed)")
225
+ print("=" * 64)
226
+ # layer1: per-residue geq/leq -> layer2: eq -> layer3: OR all eq's
227
def mod5(v):
    """Return 1 iff v is divisible by 5 (per-residue geq/leq -> eq -> OR threshold layers)."""
    bits = int_to_bits_msb(v, 8)
    # Discover how many geq/leq hidden pairs this model ships with.
    count = 0
    while f"modular.mod5.layer1.geq{count}.weight" in T:
        count += 1
    eq_outputs = []
    for i in range(count):
        geq = gate(f"modular.mod5.layer1.geq{i}", bits)
        leq = gate(f"modular.mod5.layer1.leq{i}", bits)
        eq_outputs.append(gate(f"modular.mod5.layer2.eq{i}", [geq, leq]))
    return gate("modular.mod5.layer3.or", eq_outputs)
239
+
240
+ hits = [v for v in range(256) if mod5(v)]
241
+ print(f" v in [0,255] with mod5(v)==1: {len(hits)} hits, first 12: {hits[:12]}")
242
+ print(f" Sanity: {[h % 5 for h in hits[:12]]}")
243
+ print()
244
+
245
+
246
+ # ---------------------------------------------------------------------------
247
+ # Demo 4: Manifest-aware threshold CPU - run a real program
248
+ # ---------------------------------------------------------------------------
249
+
250
+ print("=" * 64)
251
+ print(" Demo 4: Threshold CPU running an assembled program")
252
+ print("=" * 64)
253
+
254
+
255
class ThresholdCPU10:
    """CPU runtime matching the bundled small-profile manifest (10-bit addr).

    Memory reads/writes go entirely through the threshold weights in the
    global tensor dict ``T``; arithmetic uses the module-level threshold
    ALU helpers (alu_add / alu_sub / alu_mul).
    """

    def __init__(self, addr_bits, mem_bytes):
        self.addr_bits = addr_bits
        self.mem_bytes = mem_bytes

    # --- memory primitives, fully through threshold weights ---
    def addr_decode(self, addr):
        """One-hot decode *addr* into a [mem_bytes] selection vector.

        NOTE(review): assumes memory.addr_decode.weight is [mem_bytes, addr_bits]
        so the broadcast sum over dim=1 yields one activation per byte — confirm
        against build.py.
        """
        bits = torch.tensor(int_to_bits_msb(addr, self.addr_bits), dtype=torch.float32)
        w = T["memory.addr_decode.weight"]
        b = T["memory.addr_decode.bias"]
        return heaviside((w * bits).sum(dim=1) + b)  # [mem_bytes]

    def mem_read(self, mem, addr):
        """Read one byte: AND each stored bit with the select line, OR across bytes."""
        sel = self.addr_decode(addr)
        mem_bits = torch.tensor(
            [int_to_bits_msb(byte, 8) for byte in mem], dtype=torch.float32
        )
        and_w = T["memory.read.and.weight"]
        and_b = T["memory.read.and.bias"]
        or_w = T["memory.read.or.weight"]
        or_b = T["memory.read.or.bias"]
        out_bits = []
        for bit in range(8):
            inp = torch.stack([mem_bits[:, bit], sel], dim=1)
            and_out = heaviside((inp * and_w[bit]).sum(dim=1) + and_b[bit])
            out_bits.append(int(heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item()))
        return bits_msb_to_int(out_bits)

    def mem_write(self, mem, addr, value):
        """Write one byte: new_bit = (old AND not-sel) OR (data AND sel), per bit.

        Returns the full rebuilt memory as a new list of ints; *mem* is not
        mutated.
        """
        sel = self.addr_decode(addr)
        data_bits = torch.tensor(int_to_bits_msb(value, 8), dtype=torch.float32)
        mem_bits = torch.tensor(
            [int_to_bits_msb(byte, 8) for byte in mem], dtype=torch.float32
        )
        sel_w = T["memory.write.sel.weight"]
        sel_b = T["memory.write.sel.bias"]
        nsel_w = T["memory.write.nsel.weight"].squeeze(1)
        nsel_b = T["memory.write.nsel.bias"]
        and_old_w = T["memory.write.and_old.weight"]
        and_old_b = T["memory.write.and_old.bias"]
        and_new_w = T["memory.write.and_new.weight"]
        and_new_b = T["memory.write.and_new.bias"]
        or_w = T["memory.write.or.weight"]
        or_b = T["memory.write.or.bias"]

        # Gate the select line with a (constant-1) write-enable, then invert it.
        we = torch.ones_like(sel)
        sel_inp = torch.stack([sel, we], dim=1)
        write_sel = heaviside((sel_inp * sel_w).sum(dim=1) + sel_b)
        nsel = heaviside(write_sel * nsel_w + nsel_b)

        # (Removed dead `new_mem = mem[:]` — the return value is rebuilt
        # entirely from mem_bits below.)
        for bit in range(8):
            old = mem_bits[:, bit]
            data_bit = data_bits[bit].expand(self.mem_bytes)
            inp_old = torch.stack([old, nsel], dim=1)
            inp_new = torch.stack([data_bit, write_sel], dim=1)
            and_old = heaviside((inp_old * and_old_w[:, bit]).sum(dim=1) + and_old_b[:, bit])
            and_new = heaviside((inp_new * and_new_w[:, bit]).sum(dim=1) + and_new_b[:, bit])
            or_inp = torch.stack([and_old, and_new], dim=1)
            new_bit = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit])
            mem_bits[:, bit] = new_bit
        return [bits_msb_to_int([int(b) for b in mem_bits[i].tolist()]) for i in range(self.mem_bytes)]

    # --- helper to use threshold ALU functions defined above ---
    def step(self, state):
        """Fetch, decode, and execute one instruction; returns a fresh state dict.

        Instruction format: [15:12]=opcode, [11:10]=rd, [9:8]=rs, [7:0]=imm;
        memory/branch opcodes (0xA..0xE) carry a 16-bit operand word after the
        instruction word.
        """
        if state["halted"]:
            return state
        s = dict(state)
        s["mem"] = state["mem"][:]
        s["regs"] = state["regs"][:]
        s["flags"] = state["flags"][:]

        pc = s["pc"]
        addr_mask = (1 << self.addr_bits) - 1
        hi = self.mem_read(s["mem"], pc & addr_mask)
        lo = self.mem_read(s["mem"], (pc + 1) & addr_mask)
        ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
        opcode = (ir >> 12) & 0xF
        rd = (ir >> 10) & 0x3
        rs = (ir >> 8) & 0x3
        imm = ir & 0xFF

        next_pc = (pc + 2) & addr_mask
        addr16 = None
        if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
            ah = self.mem_read(s["mem"], next_pc)
            al = self.mem_read(s["mem"], (next_pc + 1) & addr_mask)
            addr16 = ((ah & 0xFF) << 8) | (al & 0xFF)
            next_pc = (next_pc + 2) & addr_mask
        addr10 = (addr16 & addr_mask) if addr16 is not None else None

        a = s["regs"][rd]
        b = s["regs"][rs]
        write = True
        result = a
        carry = 0

        if opcode == 0x0:  # ADD
            result, carry = alu_add(a, b)
        elif opcode == 0x1:  # SUB
            result, carry = alu_sub(a, b)
        elif opcode == 0x7:  # MUL
            result = alu_mul(a, b)
        elif opcode == 0x9:  # CMP (flags only; removed dead `opcode_was_cmp` local)
            _r, carry = alu_sub(a, b)
            z = 1 if _r == 0 else 0
            n = 1 if (_r & 0x80) else 0
            s["flags"] = [z, n, carry, 0]
            write = False
        elif opcode == 0xA:  # LOAD
            result = self.mem_read(s["mem"], addr10)
        elif opcode == 0xB:  # STORE (stores rs, i.e. register b)
            s["mem"] = self.mem_write(s["mem"], addr10, b & 0xFF)
            write = False
        elif opcode == 0xC:  # JMP
            s["pc"] = addr10
            return s
        elif opcode == 0xD:  # Jcc: cond selects Z/NZ/C/NC/N/NN/V/NV
            cond = imm & 0x7
            z, n, c, v = s["flags"]
            take = [z == 1, z == 0, c == 1, c == 0,
                    n == 1, n == 0, v == 1, v == 0][cond]
            s["pc"] = addr10 if take else next_pc
            return s
        elif opcode == 0xF:  # HALT
            s["halted"] = True
            return s
        # NOTE(review): opcode 0xE consumes an operand word but has no handler;
        # it falls through as a no-op register write — confirm intended.

        if write and opcode != 0x9:
            s["regs"][rd] = result & 0xFF
        if opcode in (0x0, 0x1, 0x7):
            z = 1 if (result & 0xFF) == 0 else 0
            n = 1 if (result & 0x80) else 0
            s["flags"] = [z, n, carry, 0]
        s["pc"] = next_pc
        return s

    def run(self, state, max_cycles=64):
        """Step until HALT or *max_cycles*; returns (final_state, cycles_executed)."""
        s = state
        cycles = 0
        while not s["halted"] and cycles < max_cycles:
            s = self.step(s)
            cycles += 1
        return s, cycles
409
+
410
+
411
def encode_instr(opcode, rd, rs, imm):
    """Pack one 16-bit instruction word: [15:12]=opcode, [11:10]=rd, [9:8]=rs, [7:0]=imm."""
    word = (opcode & 0xF) << 12
    word |= (rd & 0x3) << 10
    word |= (rs & 0x3) << 8
    return word | (imm & 0xFF)
413
+
414
+
415
def write_word(mem, addr, value):
    """Store a 16-bit word big-endian: high byte at mem[addr], low byte at mem[addr+1]."""
    hi, lo = (value >> 8) & 0xFF, value & 0xFF
    mem[addr] = hi
    mem[addr + 1] = lo
418
+
419
+
420
+ # Program: count down from 5 to 0 with a loop, accumulate sum into R0.
421
+ #
422
+ # R1 = 5
423
+ # R0 = 0
424
+ # loop:
425
+ # R0 = R0 + R1 ; ADD R0, R1
426
+ # R1 = R1 - 1 ; we need an immediate decrement; use SUB R1, R2 with R2=1
427
+ # CMP R1, R3 ; R3=0
428
+ # JNZ loop
429
+ # HALT
430
+ #
431
+ # Memory layout (1KB):
432
+ # 0x0000: LOAD R1 <- M[0x0100] (5)
433
+ # 0x0004: LOAD R2 <- M[0x0101] (1)
434
+ # 0x0008: LOAD R3 <- M[0x0102] (0)
435
+ # 0x000C: LOAD R0 <- M[0x0102] (0)
436
+ # 0x0010: ADD R0, R1
437
+ # 0x0012: SUB R1, R2
438
+ # 0x0014: CMP R1, R3
439
+ # 0x0016: JNZ 0x0010
440
+ # 0x001A: STORE R0 -> M[0x0103]
441
+ # 0x001E: HALT
442
+
443
+ mem = [0] * 1024
444
+ mem[0x100] = 5
445
+ mem[0x101] = 1
446
+ mem[0x102] = 0
447
+
448
+ # LOAD R1 <- M[0x0100]
449
+ write_word(mem, 0x0000, encode_instr(0xA, 1, 0, 0)); write_word(mem, 0x0002, 0x0100)
450
+ # LOAD R2 <- M[0x0101]
451
+ write_word(mem, 0x0004, encode_instr(0xA, 2, 0, 0)); write_word(mem, 0x0006, 0x0101)
452
+ # LOAD R3 <- M[0x0102]
453
+ write_word(mem, 0x0008, encode_instr(0xA, 3, 0, 0)); write_word(mem, 0x000A, 0x0102)
454
+ # LOAD R0 <- M[0x0102]
455
+ write_word(mem, 0x000C, encode_instr(0xA, 0, 0, 0)); write_word(mem, 0x000E, 0x0102)
456
+ # ADD R0, R1
457
+ write_word(mem, 0x0010, encode_instr(0x0, 0, 1, 0))
458
+ # SUB R1, R2
459
+ write_word(mem, 0x0012, encode_instr(0x1, 1, 2, 0))
460
+ # CMP R1, R3
461
+ write_word(mem, 0x0014, encode_instr(0x9, 1, 3, 0))
462
+ # JNZ 0x0010 (cond=1 = NZ)
463
+ write_word(mem, 0x0016, encode_instr(0xD, 0, 0, 0x01)); write_word(mem, 0x0018, 0x0010)
464
+ # STORE R0 -> M[0x0103]
465
+ write_word(mem, 0x001A, encode_instr(0xB, 0, 0, 0)); write_word(mem, 0x001C, 0x0103)
466
+ # HALT
467
+ write_word(mem, 0x001E, encode_instr(0xF, 0, 0, 0))
468
+
469
+ cpu = ThresholdCPU10(addr_bits=ADDR_BITS, mem_bytes=MEM_BYTES)
470
+ state = {
471
+ "pc": 0,
472
+ "regs": [0, 0, 0, 0],
473
+ "flags": [0, 0, 0, 0],
474
+ "mem": mem,
475
+ "halted": False,
476
+ }
477
+ print(" Program: sum 5+4+3+2+1 via loop (uses ADD/SUB/CMP/Jcc/LOAD/STORE/HALT, all threshold-gated)")
478
+ print(" Running ...")
479
+ final, cycles = cpu.run(state, max_cycles=200)
480
+ print(f" Halted after {cycles} cycles")
481
+ print(f" R0={final['regs'][0]} R1={final['regs'][1]} R2={final['regs'][2]} R3={final['regs'][3]}")
482
+ print(f" M[0x0103] = {final['mem'][0x103]} (expected 15)")
483
+ print()
484
+ print("Done.")
variants/neural_alu16.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f702736cd85124aac22602bf44617698309c03739a254b338409df87e22344c9
3
+ size 12434484
variants/neural_alu32.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c6761fa0366a19cdb9abb7c1c72f53b3a3a07032056b6d17dbed4131cc5e21d
3
+ size 14378864
variants/neural_alu8.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:246546bba4668a80a81e32b115d883d57b6b49bdfe8254034090089d5bf168cf
3
+ size 11561076
variants/neural_computer16.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2daa9ab42ab63534010e363adbb3423502ebbe94a4b354797c25dece5eb5948
3
+ size 45730164
variants/neural_computer16_reduced.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1808dd34084e68120bccd277310749e047c357274440901baf2b01ca64e9e41
3
+ size 14640476
variants/neural_computer16_registers.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7487bbfe4da343bb2072c190e33b7861b452c947efd424a068927c413595049
3
+ size 12534076
variants/neural_computer16_scratchpad.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58dfbdf0a987c1675a68439d86a57a7631a7657ea60b6b0d3e568dfdeee88f2e
3
+ size 12704876
variants/neural_computer16_small.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:759920dcb38a340ee31f4d116df4983322258b796ac5d6021f7ca165986f5f5b
3
+ size 13104212
variants/neural_computer32.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18f4f3420fb307d90ea7a8fe356c196a59d7a0f2ed4ec57679d87b209a7fec22
3
+ size 47693920
variants/neural_computer32_reduced.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51e14c8819de3402881ce2ffe3cdd7e94a801c038c6ef8495110144e9348e2e7
3
+ size 16604104
variants/neural_computer32_registers.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13ac4eb1c793a6331a2ecfa13d3372edc9f4649163883244847ca6616062de05
3
+ size 14497800
variants/neural_computer32_scratchpad.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d8c389b1730cc297f40944815aebda1b1a71b79bff738e9da767c82609e9d9bd
3
+ size 14668512
variants/neural_computer32_small.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3c67a2047c0cf7370e802727b9be51d8b7185dedfe108409968f0d838157e04
3
+ size 15067856
variants/neural_computer8.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:acde9e66a5bae870b5684ddc8592a206f00b518e088e90965a73bfa35274ba2a
3
+ size 44846164
variants/neural_computer8_reduced.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e318727316bfb34f82cdc4a2b627d9f8475c3282cab67a6424ba642350dc823
3
+ size 13756476
variants/neural_computer8_registers.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b2c49af2b18786699351235d4d051afd7452e17616f0f06a87b3e5e9820da66
3
+ size 11649932
variants/neural_computer8_scratchpad.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40fe6db0454dd6ba33072a18f6c81ed1463830b270b708b9ae45f976e32cfc50
3
+ size 11820860
variants/neural_computer8_small.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:547aef648729c49dc106c14d05bfcdf12a6f1aca5de5b7d1c475fce65aef1373
3
+ size 12220204