Add 18 prebuilt variants and unified eval harness
variants/ holds every (8|16|32)-bit x (none|registers|scratchpad|small|reduced|full)
build (~325 MB total) so users can pull weights from HF without running build.py.
eval_all.py is variant-agnostic: reads each safetensors' manifest, runs the
BatchedFitnessEvaluator, and with --cpu-program also runs an assembled program
through the threshold CPU sized to the variant plus a chained N-bit ALU test
for 16/32-bit data widths.
build.py: fix infer_combinational_inputs N-bit handling. The barrel shifter
case used 1 << (2 - layer), valid only for 3-layer (8-bit) shifters; 16/32-bit
versions have 4-5 layers and crashed at the .inputs step. Priority encoder also
hardcoded 8 inputs and the legacy any_ge naming. Both now parse the bit width
from the gate name and emit correct shift amounts and signal references.
build_all.py orchestrates building + evaluating every named profile.
play.py is a standalone demo (Boolean gates, 8-bit ALU, mod-5, threshold CPU).
- build.py +41 -21
- build_all.py +181 -0
- eval_all.py +613 -0
- play.py +484 -0
- variants/neural_alu16.safetensors +3 -0
- variants/neural_alu32.safetensors +3 -0
- variants/neural_alu8.safetensors +3 -0
- variants/neural_computer16.safetensors +3 -0
- variants/neural_computer16_reduced.safetensors +3 -0
- variants/neural_computer16_registers.safetensors +3 -0
- variants/neural_computer16_scratchpad.safetensors +3 -0
- variants/neural_computer16_small.safetensors +3 -0
- variants/neural_computer32.safetensors +3 -0
- variants/neural_computer32_reduced.safetensors +3 -0
- variants/neural_computer32_registers.safetensors +3 -0
- variants/neural_computer32_scratchpad.safetensors +3 -0
- variants/neural_computer32_small.safetensors +3 -0
- variants/neural_computer8.safetensors +3 -0
- variants/neural_computer8_reduced.safetensors +3 -0
- variants/neural_computer8_registers.safetensors +3 -0
- variants/neural_computer8_scratchpad.safetensors +3 -0
- variants/neural_computer8_small.safetensors +3 -0
|
@@ -2505,7 +2505,7 @@ def infer_error_detection_inputs(gate: str, reg: SignalRegistry) -> List[int]:
|
|
| 2505 |
return [reg.get_id(f"$x[{i}]") for i in range(8)]
|
| 2506 |
|
| 2507 |
|
| 2508 |
-
def infer_combinational_inputs(gate: str, reg: SignalRegistry) -> List[int]:
|
| 2509 |
if 'decoder3to8' in gate:
|
| 2510 |
for i in range(3):
|
| 2511 |
reg.register(f"$sel[{i}]")
|
|
@@ -2550,41 +2550,57 @@ def infer_combinational_inputs(gate: str, reg: SignalRegistry) -> List[int]:
|
|
| 2550 |
return [reg.register(f"combinational.regmux4to1.bit{bit}.and{i}") for i in range(4)]
|
| 2551 |
return []
|
| 2552 |
if 'barrelshifter' in gate:
|
| 2553 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2554 |
reg.register(f"$x[{i}]")
|
| 2555 |
-
for i in range(
|
| 2556 |
reg.register(f"$shift[{i}]")
|
| 2557 |
m = re.search(r'layer(\d+)\.bit(\d+)', gate)
|
| 2558 |
if m:
|
| 2559 |
layer, bit = int(m.group(1)), int(m.group(2))
|
| 2560 |
-
shift_amount = 1 << (
|
| 2561 |
-
prefix = f"
|
|
|
|
| 2562 |
if '.not_sel' in gate:
|
| 2563 |
-
return [reg.get_id(f"$shift[{
|
| 2564 |
if '.and_a' in gate:
|
| 2565 |
if layer == 0:
|
| 2566 |
return [reg.get_id(f"$x[{bit}]"), reg.register(f"{prefix}.not_sel")]
|
| 2567 |
else:
|
| 2568 |
-
prev_prefix = f"
|
| 2569 |
return [reg.register(f"{prev_prefix}.or"), reg.register(f"{prefix}.not_sel")]
|
| 2570 |
if '.and_b' in gate:
|
| 2571 |
-
src = (bit + shift_amount) %
|
| 2572 |
if layer == 0:
|
| 2573 |
-
return [reg.get_id(f"$x[{src}]"), reg.get_id(f"$shift[{
|
| 2574 |
else:
|
| 2575 |
-
prev_prefix = f"
|
| 2576 |
-
return [reg.register(f"{prev_prefix}.or"), reg.get_id(f"$shift[{
|
| 2577 |
if '.or' in gate:
|
| 2578 |
return [reg.register(f"{prefix}.and_a"), reg.register(f"{prefix}.and_b")]
|
| 2579 |
-
return [reg.get_id(f"$x[{i}]") for i in range(
|
| 2580 |
if 'priorityencoder' in gate:
|
| 2581 |
-
|
|
|
|
|
|
|
|
|
|
| 2582 |
reg.register(f"$x[{i}]")
|
|
|
|
| 2583 |
if '.any_ge' in gate:
|
| 2584 |
m = re.search(r'any_ge(\d+)', gate)
|
| 2585 |
if m:
|
| 2586 |
pos = int(m.group(1))
|
| 2587 |
-
return [reg.get_id(f"$x[{i}]") for i in range(pos,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2588 |
if '.is_highest' in gate:
|
| 2589 |
m = re.search(r'is_highest(\d+)', gate)
|
| 2590 |
if m:
|
|
@@ -2593,21 +2609,25 @@ def infer_combinational_inputs(gate: str, reg: SignalRegistry) -> List[int]:
|
|
| 2593 |
if pos == 0:
|
| 2594 |
return [reg.get_id("#0")]
|
| 2595 |
else:
|
| 2596 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2597 |
if '.and' in gate:
|
| 2598 |
-
return [reg.get_id(f"$x[{pos}]"), reg.register(f"
|
| 2599 |
if '.out' in gate:
|
| 2600 |
m = re.search(r'out(\d+)', gate)
|
| 2601 |
if m:
|
| 2602 |
out_bit = int(m.group(1))
|
| 2603 |
inputs = []
|
| 2604 |
-
for pos in range(
|
| 2605 |
if (pos >> out_bit) & 1:
|
| 2606 |
-
inputs.append(reg.register(f"
|
| 2607 |
return inputs
|
| 2608 |
if '.valid' in gate:
|
| 2609 |
-
return [reg.get_id(f"$x[{i}]") for i in range(
|
| 2610 |
-
return [reg.get_id(f"$x[{i}]") for i in range(
|
| 2611 |
return []
|
| 2612 |
|
| 2613 |
|
|
@@ -2706,7 +2726,7 @@ def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, tor
|
|
| 2706 |
if gate.startswith('error_detection.'):
|
| 2707 |
return infer_error_detection_inputs(gate, reg)
|
| 2708 |
if gate.startswith('combinational.'):
|
| 2709 |
-
return infer_combinational_inputs(gate, reg)
|
| 2710 |
weight_key = f"{gate}.weight"
|
| 2711 |
if weight_key in tensors:
|
| 2712 |
w = tensors[weight_key]
|
|
|
|
| 2505 |
return [reg.get_id(f"$x[{i}]") for i in range(8)]
|
| 2506 |
|
| 2507 |
|
| 2508 |
+
def infer_combinational_inputs(gate: str, reg: SignalRegistry, tensors: Dict[str, torch.Tensor] = None) -> List[int]:
|
| 2509 |
if 'decoder3to8' in gate:
|
| 2510 |
for i in range(3):
|
| 2511 |
reg.register(f"$sel[{i}]")
|
|
|
|
| 2550 |
return [reg.register(f"combinational.regmux4to1.bit{bit}.and{i}") for i in range(4)]
|
| 2551 |
return []
|
| 2552 |
if 'barrelshifter' in gate:
|
| 2553 |
+
import math as _math
|
| 2554 |
+
bs_match = re.search(r'barrelshifter(\d*)', gate)
|
| 2555 |
+
bits = int(bs_match.group(1)) if bs_match and bs_match.group(1) else 8
|
| 2556 |
+
bs_prefix = f"combinational.barrelshifter{bs_match.group(1) if bs_match else ''}"
|
| 2557 |
+
num_layers = max(1, _math.ceil(_math.log2(bits))) if bits > 1 else 1
|
| 2558 |
+
for i in range(bits):
|
| 2559 |
reg.register(f"$x[{i}]")
|
| 2560 |
+
for i in range(num_layers):
|
| 2561 |
reg.register(f"$shift[{i}]")
|
| 2562 |
m = re.search(r'layer(\d+)\.bit(\d+)', gate)
|
| 2563 |
if m:
|
| 2564 |
layer, bit = int(m.group(1)), int(m.group(2))
|
| 2565 |
+
shift_amount = 1 << (num_layers - 1 - layer)
|
| 2566 |
+
prefix = f"{bs_prefix}.layer{layer}.bit{bit}"
|
| 2567 |
+
sel_idx = num_layers - 1 - layer
|
| 2568 |
if '.not_sel' in gate:
|
| 2569 |
+
return [reg.get_id(f"$shift[{sel_idx}]")]
|
| 2570 |
if '.and_a' in gate:
|
| 2571 |
if layer == 0:
|
| 2572 |
return [reg.get_id(f"$x[{bit}]"), reg.register(f"{prefix}.not_sel")]
|
| 2573 |
else:
|
| 2574 |
+
prev_prefix = f"{bs_prefix}.layer{layer-1}.bit{bit}"
|
| 2575 |
return [reg.register(f"{prev_prefix}.or"), reg.register(f"{prefix}.not_sel")]
|
| 2576 |
if '.and_b' in gate:
|
| 2577 |
+
src = (bit + shift_amount) % bits
|
| 2578 |
if layer == 0:
|
| 2579 |
+
return [reg.get_id(f"$x[{src}]"), reg.get_id(f"$shift[{sel_idx}]")]
|
| 2580 |
else:
|
| 2581 |
+
prev_prefix = f"{bs_prefix}.layer{layer-1}.bit{src}"
|
| 2582 |
+
return [reg.register(f"{prev_prefix}.or"), reg.get_id(f"$shift[{sel_idx}]")]
|
| 2583 |
if '.or' in gate:
|
| 2584 |
return [reg.register(f"{prefix}.and_a"), reg.register(f"{prefix}.and_b")]
|
| 2585 |
+
return [reg.get_id(f"$x[{i}]") for i in range(bits)]
|
| 2586 |
if 'priorityencoder' in gate:
|
| 2587 |
+
pe_match = re.search(r'priorityencoder(\d*)', gate)
|
| 2588 |
+
bits = int(pe_match.group(1)) if pe_match and pe_match.group(1) else 8
|
| 2589 |
+
pe_prefix = f"combinational.priorityencoder{pe_match.group(1) if pe_match else ''}"
|
| 2590 |
+
for i in range(bits):
|
| 2591 |
reg.register(f"$x[{i}]")
|
| 2592 |
+
# Legacy 8-bit naming: any_ge{pos} = OR of bits at positions [pos..bits-1]
|
| 2593 |
if '.any_ge' in gate:
|
| 2594 |
m = re.search(r'any_ge(\d+)', gate)
|
| 2595 |
if m:
|
| 2596 |
pos = int(m.group(1))
|
| 2597 |
+
return [reg.get_id(f"$x[{i}]") for i in range(pos, bits)]
|
| 2598 |
+
# N-bit naming: any_higher{pos} = OR of bits 0..pos-1
|
| 2599 |
+
if '.any_higher' in gate:
|
| 2600 |
+
m = re.search(r'any_higher(\d+)', gate)
|
| 2601 |
+
if m:
|
| 2602 |
+
pos = int(m.group(1))
|
| 2603 |
+
return [reg.get_id(f"$x[{i}]") for i in range(pos)]
|
| 2604 |
if '.is_highest' in gate:
|
| 2605 |
m = re.search(r'is_highest(\d+)', gate)
|
| 2606 |
if m:
|
|
|
|
| 2609 |
if pos == 0:
|
| 2610 |
return [reg.get_id("#0")]
|
| 2611 |
else:
|
| 2612 |
+
# Try N-bit any_higher first, fall back to legacy any_ge
|
| 2613 |
+
ah_key = f"{pe_prefix}.any_higher{pos}"
|
| 2614 |
+
if tensors is not None and f"{ah_key}.weight" in tensors:
|
| 2615 |
+
return [reg.register(ah_key)]
|
| 2616 |
+
return [reg.register(f"{pe_prefix}.any_ge{pos-1}")]
|
| 2617 |
if '.and' in gate:
|
| 2618 |
+
return [reg.get_id(f"$x[{pos}]"), reg.register(f"{pe_prefix}.is_highest{pos}.not_higher")]
|
| 2619 |
if '.out' in gate:
|
| 2620 |
m = re.search(r'out(\d+)', gate)
|
| 2621 |
if m:
|
| 2622 |
out_bit = int(m.group(1))
|
| 2623 |
inputs = []
|
| 2624 |
+
for pos in range(bits):
|
| 2625 |
if (pos >> out_bit) & 1:
|
| 2626 |
+
inputs.append(reg.register(f"{pe_prefix}.is_highest{pos}.and"))
|
| 2627 |
return inputs
|
| 2628 |
if '.valid' in gate:
|
| 2629 |
+
return [reg.get_id(f"$x[{i}]") for i in range(bits)]
|
| 2630 |
+
return [reg.get_id(f"$x[{i}]") for i in range(bits)]
|
| 2631 |
return []
|
| 2632 |
|
| 2633 |
|
|
|
|
| 2726 |
if gate.startswith('error_detection.'):
|
| 2727 |
return infer_error_detection_inputs(gate, reg)
|
| 2728 |
if gate.startswith('combinational.'):
|
| 2729 |
+
return infer_combinational_inputs(gate, reg, tensors)
|
| 2730 |
weight_key = f"{gate}.weight"
|
| 2731 |
if weight_key in tensors:
|
| 2732 |
w = tensors[weight_key]
|
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Build and verify every named (bits, memory_profile) variant.
|
| 3 |
+
|
| 4 |
+
Outputs:
|
| 5 |
+
variants/neural_alu{8,16,32}.safetensors - no memory
|
| 6 |
+
variants/neural_computer{8,16,32}_registers.safetensors - 16 B
|
| 7 |
+
variants/neural_computer{8,16,32}_scratchpad.safetensors - 256 B
|
| 8 |
+
variants/neural_computer{8,16,32}_small.safetensors - 1 KB
|
| 9 |
+
variants/neural_computer{8,16,32}_reduced.safetensors - 4 KB
|
| 10 |
+
variants/neural_computer{8,16,32}.safetensors - 64 KB
|
| 11 |
+
|
| 12 |
+
For each, runs eval.py via the BatchedFitnessEvaluator and records
|
| 13 |
+
(tensor count, params, file size, fitness, total_tests, seconds).
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
import os
|
| 18 |
+
import shutil
|
| 19 |
+
import subprocess
|
| 20 |
+
import sys
|
| 21 |
+
import time
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
from safetensors import safe_open
|
| 26 |
+
|
| 27 |
+
ROOT = Path(__file__).resolve().parent
|
| 28 |
+
SEED = ROOT / "neural_computer.safetensors"
|
| 29 |
+
OUT_DIR = ROOT / "variants"
|
| 30 |
+
OUT_DIR.mkdir(exist_ok=True)
|
| 31 |
+
|
| 32 |
+
PROFILES = ["none", "registers", "scratchpad", "small", "reduced", "full"]
|
| 33 |
+
BITS = [8, 16, 32]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def variant_filename(bits: int, profile: str) -> str:
    """Return the safetensors file name for a (bits, profile) variant.

    "none" maps to a pure-ALU name, "full" to the unsuffixed computer name,
    and every other profile to a profile-suffixed computer name.
    """
    special = {
        "none": f"neural_alu{bits}.safetensors",
        "full": f"neural_computer{bits}.safetensors",
    }
    return special.get(profile, f"neural_computer{bits}_{profile}.safetensors")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def run(cmd: list[str], timeout: int = 600) -> tuple[int, str]:
    """Execute *cmd*, returning (exit code, combined stdout+stderr text)."""
    proc = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    combined = (proc.stdout or "") + (proc.stderr or "")
    return proc.returncode, combined
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def build_variant(bits: int, profile: str) -> Path:
    """Copy the seed model into variants/ and run build.py on it in place.

    Returns the path to the built variant. Raises RuntimeError (carrying the
    log tail) when build.py exits non-zero.
    """
    dest = OUT_DIR / variant_filename(bits, profile)
    # Start from the seed model; build.py rewrites it via --apply.
    shutil.copy2(SEED, dest)
    build_cmd = [
        sys.executable, str(ROOT / "build.py"),
        "--bits", str(bits),
        "-m", profile,
        "--apply",
        "--model", str(dest),
        "all",
    ]
    rc, log = run(build_cmd, timeout=900)
    if rc != 0:
        raise RuntimeError(f"build failed for bits={bits} profile={profile}:\n{log[-1500:]}")
    return dest
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def measure_variant(path: Path) -> dict:
    """Read tensor count, params, and scalar manifest values from a variant."""
    with safe_open(str(path), framework="pt") as handle:
        keys = list(handle.keys())
        params = 0
        manifest = {}
        for key in keys:
            tensor = handle.get_tensor(key)
            params += tensor.numel()
            # Manifest entries are single-element tensors named manifest.<field>.
            if key.startswith("manifest.") and tensor.numel() == 1:
                manifest[key.split(".", 1)[1]] = tensor.item()
    return {
        "tensors": len(keys),
        "params": params,
        "size_mb": path.stat().st_size / (1024 * 1024),
        "manifest": manifest,
    }
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def eval_variant(path: Path, device: str = "cpu", timeout: int = 600) -> dict:
    """Run eval.py against a variant and parse its summary output.

    Returns rc, fitness, total_tests, status, elapsed_s, plus the last 15
    log lines for post-mortem when a run fails.
    """
    eval_cmd = [
        sys.executable, str(ROOT / "eval.py"),
        "--model", str(path),
        "--device", device,
        "--quiet",
    ]
    started = time.time()
    rc, log = run(eval_cmd, timeout=timeout)
    elapsed = time.time() - started

    fitness = None
    total_tests = None
    status = "ERROR"  # overwritten only if a STATUS: line is found
    for raw in log.splitlines():
        stripped = raw.strip()
        if stripped.startswith("Fitness:"):
            try:
                fitness = float(stripped.split()[1])
            except Exception:
                pass
        elif stripped.startswith("Total tests:"):
            try:
                total_tests = int(stripped.split()[2])
            except Exception:
                pass
        elif stripped.startswith("STATUS:"):
            status = stripped.split()[1]
    return {
        "rc": rc,
        "fitness": fitness,
        "total_tests": total_tests,
        "status": status,
        "elapsed_s": elapsed,
        "log_tail": "\n".join(log.splitlines()[-15:]),
    }
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def main() -> None:
    """Build every (bits, profile) variant, evaluate each, print a summary.

    Iterates BITS x PROFILES, building and evaluating each combination;
    failures are captured per-row so one bad variant never aborts the sweep.
    """
    rows = []
    # Derive the count from the config lists instead of hard-coding 18,
    # so adding a width or profile keeps the banner accurate.
    total = len(BITS) * len(PROFILES)
    print(f"Building {total} variants into {OUT_DIR}\n")
    for bits in BITS:
        for profile in PROFILES:
            label = f"bits={bits} profile={profile}"
            print(f"=== {label} ===", flush=True)
            t0 = time.time()
            try:
                path = build_variant(bits, profile)
                bt = time.time() - t0
                meta = measure_variant(path)
                ev = eval_variant(path, device="cpu", timeout=900)
                rows.append({
                    "bits": bits, "profile": profile,
                    "filename": path.name,
                    "build_s": bt,
                    **meta,
                    **{k: ev[k] for k in ("fitness", "total_tests", "status", "elapsed_s")},
                    # Keep the log tail only on failure to keep rows small.
                    "log_tail": ev["log_tail"] if ev["status"] != "PASS" else "",
                })
                print(f"  built in {bt:.1f}s size={meta['size_mb']:.1f}MB"
                      f" params={meta['params']:,} tensors={meta['tensors']:,}")
                print(f"  eval: fitness={ev['fitness']} tests={ev['total_tests']}"
                      f" status={ev['status']} ({ev['elapsed_s']:.1f}s)")
                if ev["status"] != "PASS":
                    print("  --- failure tail ---")
                    print("  " + "\n  ".join(ev["log_tail"].splitlines()))
                    print("  --------------------")
            except Exception as e:
                # Record the failure and continue with the remaining variants.
                print(f"  EXCEPTION: {e}")
                rows.append({"bits": bits, "profile": profile, "error": str(e)})
            print()

    print("=" * 88)
    print(" SUMMARY")
    print("=" * 88)
    header = f"{'bits':>4} {'profile':<11} {'size_MB':>8} {'tensors':>8} {'params':>11} {'fitness':>9} {'tests':>6} {'status':>7}"
    print(header)
    print("-" * len(header))
    for r in rows:
        if "error" in r:
            print(f"{r['bits']:>4} {r['profile']:<11} ERROR: {r['error'][:60]}")
            continue
        fit = f"{r['fitness']:.4f}" if r['fitness'] is not None else "n/a"
        tests = r['total_tests'] if r['total_tests'] is not None else "?"
        print(f"{r['bits']:>4} {r['profile']:<11} {r['size_mb']:>8.1f} "
              f"{r['tensors']:>8,} {r['params']:>11,} "
              f"{fit:>9} {tests:>6} {r['status']:>7}")

    fail = [r for r in rows if r.get("status") != "PASS" or "error" in r]
    print()
    if fail:
        print(f"FAILURES: {len(fail)}/{len(rows)}")
    else:
        print(f"ALL {len(rows)} VARIANTS PASS")
| 178 |
+
|
| 179 |
+
|
| 180 |
+
if __name__ == "__main__":
|
| 181 |
+
main()
|
|
@@ -0,0 +1,613 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified evaluation harness for any threshold-computer variant.
|
| 3 |
+
|
| 4 |
+
Drops the `--cpu-test` smoke test (which was hardcoded to 16-bit/64KB) and
|
| 5 |
+
adds variant-aware sweep modes. The same harness handles every (data_bits,
|
| 6 |
+
addr_bits) configuration: it reads the manifest from each safetensors file,
|
| 7 |
+
runs the BatchedFitnessEvaluator at the right device, and reports per-file
|
| 8 |
+
plus per-category results.
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
python eval_all.py path/to/file.safetensors # one file
|
| 12 |
+
python eval_all.py variants/ # every .safetensors in dir
|
| 13 |
+
python eval_all.py --device cpu variants/ # CPU only (default)
|
| 14 |
+
python eval_all.py --pop_size 32 variants/ # batched pop eval
|
| 15 |
+
python eval_all.py --debug path/to/file.safetensors # per-circuit detail
|
| 16 |
+
python eval_all.py --cpu-program PATH # also run an assembled program
|
| 17 |
+
# through the threshold CPU
|
| 18 |
+
# sized to the file's manifest
|
| 19 |
+
|
| 20 |
+
Exit code:
|
| 21 |
+
0 if all files PASS (fitness >= 0.9999)
|
| 22 |
+
N where N is the number of FAILing files
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import argparse
|
| 28 |
+
import json
|
| 29 |
+
import os
|
| 30 |
+
import sys
|
| 31 |
+
import time
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
from typing import Dict, List, Optional, Tuple
|
| 34 |
+
|
| 35 |
+
import torch
|
| 36 |
+
from safetensors import safe_open
|
| 37 |
+
|
| 38 |
+
# Reuse eval.py's evaluator (variant-aware)
|
| 39 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 40 |
+
from eval import (
|
| 41 |
+
BatchedFitnessEvaluator,
|
| 42 |
+
create_population,
|
| 43 |
+
load_model,
|
| 44 |
+
get_manifest,
|
| 45 |
+
heaviside,
|
| 46 |
+
int_to_bits,
|
| 47 |
+
bits_to_int,
|
| 48 |
+
bits_msb_to_lsb,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
# Variant-aware threshold ALU + CPU
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
|
| 56 |
+
class GenericThresholdALU:
|
| 57 |
+
"""Variant-aware threshold ALU. Reads manifest, runs ADD/SUB/CMP/MUL etc.
|
| 58 |
+
|
| 59 |
+
Currently supports the 8-bit ALU primitives (ripplecarry8bit, sub8bit,
|
| 60 |
+
cmp8bit, mul/div). For wider data paths, use the BatchedFitnessEvaluator
|
| 61 |
+
which already handles 16/32-bit comparators, subtractors, etc.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
def __init__(self, tensors: Dict[str, torch.Tensor], data_bits: int):
|
| 65 |
+
self.T = tensors
|
| 66 |
+
self.data_bits = data_bits
|
| 67 |
+
|
| 68 |
+
def _g(self, name, inputs):
|
| 69 |
+
w = self.T[name + ".weight"].view(-1)
|
| 70 |
+
b = self.T[name + ".bias"].view(-1)
|
| 71 |
+
return int(heaviside((torch.tensor(inputs, dtype=torch.float32) * w).sum() + b).item())
|
| 72 |
+
|
| 73 |
+
def _xor_or_nand(self, prefix, inputs):
|
| 74 |
+
a, b_ = inputs
|
| 75 |
+
h_or = self._g(f"{prefix}.layer1.or", [a, b_])
|
| 76 |
+
h_nand = self._g(f"{prefix}.layer1.nand", [a, b_])
|
| 77 |
+
return self._g(f"{prefix}.layer2", [h_or, h_nand])
|
| 78 |
+
|
| 79 |
+
def _fa(self, prefix, a, b, cin):
|
| 80 |
+
s1 = self._xor_or_nand(f"{prefix}.ha1.sum", [a, b])
|
| 81 |
+
c1 = self._g(f"{prefix}.ha1.carry", [a, b])
|
| 82 |
+
s2 = self._xor_or_nand(f"{prefix}.ha2.sum", [s1, cin])
|
| 83 |
+
c2 = self._g(f"{prefix}.ha2.carry", [s1, cin])
|
| 84 |
+
cout = self._g(f"{prefix}.carry_or", [c1, c2])
|
| 85 |
+
return s2, cout
|
| 86 |
+
|
| 87 |
+
def add8(self, a, b):
|
| 88 |
+
a_lsb = list(reversed(int_to_bits(a, 8)))
|
| 89 |
+
b_lsb = list(reversed(int_to_bits(b, 8)))
|
| 90 |
+
carry = 0
|
| 91 |
+
s_lsb = []
|
| 92 |
+
for i in range(8):
|
| 93 |
+
s, carry = self._fa(f"arithmetic.ripplecarry8bit.fa{i}", a_lsb[i], b_lsb[i], carry)
|
| 94 |
+
s_lsb.append(s)
|
| 95 |
+
return bits_to_int(list(reversed(s_lsb))), carry
|
| 96 |
+
|
| 97 |
+
def sub8(self, a, b):
|
| 98 |
+
a_lsb = list(reversed(int_to_bits(a, 8)))
|
| 99 |
+
b_lsb = list(reversed(int_to_bits(b, 8)))
|
| 100 |
+
carry = 1
|
| 101 |
+
d_lsb = []
|
| 102 |
+
for i in range(8):
|
| 103 |
+
notb = self._g(f"arithmetic.sub8bit.notb{i}", [b_lsb[i]])
|
| 104 |
+
x1 = self._xor_or_nand(f"arithmetic.sub8bit.fa{i}.xor1", [a_lsb[i], notb])
|
| 105 |
+
x2 = self._xor_or_nand(f"arithmetic.sub8bit.fa{i}.xor2", [x1, carry])
|
| 106 |
+
and1 = self._g(f"arithmetic.sub8bit.fa{i}.and1", [a_lsb[i], notb])
|
| 107 |
+
and2 = self._g(f"arithmetic.sub8bit.fa{i}.and2", [x1, carry])
|
| 108 |
+
carry = self._g(f"arithmetic.sub8bit.fa{i}.or_carry", [and1, and2])
|
| 109 |
+
d_lsb.append(x2)
|
| 110 |
+
return bits_to_int(list(reversed(d_lsb))), carry
|
| 111 |
+
|
| 112 |
+
def cmp8(self, a, b, kind):
|
| 113 |
+
inp = int_to_bits(a, 8) + int_to_bits(b, 8)
|
| 114 |
+
if kind == "eq":
|
| 115 |
+
h_geq = self._g("arithmetic.equality8bit.layer1.geq", inp)
|
| 116 |
+
h_leq = self._g("arithmetic.equality8bit.layer1.leq", inp)
|
| 117 |
+
return self._g("arithmetic.equality8bit.layer2", [h_geq, h_leq])
|
| 118 |
+
return self._g(f"arithmetic.{kind}8bit", inp)
|
| 119 |
+
|
| 120 |
+
def mul8(self, a, b):
|
| 121 |
+
ab = int_to_bits(a, 8)
|
| 122 |
+
bb = int_to_bits(b, 8)
|
| 123 |
+
result = 0
|
| 124 |
+
for j in range(8):
|
| 125 |
+
if bb[j] == 0:
|
| 126 |
+
continue
|
| 127 |
+
row = 0
|
| 128 |
+
for i in range(8):
|
| 129 |
+
pp = self._g(f"alu.alu8bit.mul.pp.a{i}b{j}", [ab[i], bb[j]])
|
| 130 |
+
row |= (pp << (7 - i))
|
| 131 |
+
shift = 7 - j
|
| 132 |
+
result, _ = self.add8(result & 0xFF, (row << shift) & 0xFF)
|
| 133 |
+
return result & 0xFF
|
| 134 |
+
|
| 135 |
+
# ----- N-bit primitives (for 16-bit and 32-bit variants) ----------------
|
| 136 |
+
|
| 137 |
+
def add_n(self, a: int, b: int, bits: int):
|
| 138 |
+
"""Width-generic ripple-carry add via arithmetic.ripplecarry{N}bit."""
|
| 139 |
+
prefix = f"arithmetic.ripplecarry{bits}bit"
|
| 140 |
+
a_lsb = list(reversed(int_to_bits(a, bits)))
|
| 141 |
+
b_lsb = list(reversed(int_to_bits(b, bits)))
|
| 142 |
+
carry = 0
|
| 143 |
+
s_lsb = []
|
| 144 |
+
for i in range(bits):
|
| 145 |
+
s, carry = self._fa(f"{prefix}.fa{i}", a_lsb[i], b_lsb[i], carry)
|
| 146 |
+
s_lsb.append(s)
|
| 147 |
+
return bits_to_int(list(reversed(s_lsb))), carry
|
| 148 |
+
|
| 149 |
+
def sub_n(self, a: int, b: int, bits: int):
|
| 150 |
+
"""N-bit two's-complement subtract via arithmetic.sub{N}bit (N >= 16).
|
| 151 |
+
|
| 152 |
+
Structure (per build.add_sub_nbits): N NOT gates + N standard full adders.
|
| 153 |
+
"""
|
| 154 |
+
prefix = f"arithmetic.sub{bits}bit"
|
| 155 |
+
a_lsb = list(reversed(int_to_bits(a, bits)))
|
| 156 |
+
b_lsb = list(reversed(int_to_bits(b, bits)))
|
| 157 |
+
# NOT each B bit
|
| 158 |
+
notb = [self._g(f"{prefix}.not_b.bit{i}", [b_lsb[i]]) for i in range(bits)]
|
| 159 |
+
carry = 1 # carry-in = 1 for two's-complement
|
| 160 |
+
d_lsb = []
|
| 161 |
+
for i in range(bits):
|
| 162 |
+
s, carry = self._fa(f"{prefix}.fa{i}", a_lsb[i], notb[i], carry)
|
| 163 |
+
d_lsb.append(s)
|
| 164 |
+
return bits_to_int(list(reversed(d_lsb))), carry
|
| 165 |
+
|
| 166 |
+
def cmp_n(self, a: int, b: int, kind: str, bits: int):
|
| 167 |
+
"""N-bit comparator. For bits <= 16 single-layer; bits == 32 cascaded."""
|
| 168 |
+
a_bits = int_to_bits(a, bits)
|
| 169 |
+
b_bits = int_to_bits(b, bits)
|
| 170 |
+
if bits <= 16:
|
| 171 |
+
inp = a_bits + b_bits
|
| 172 |
+
if kind == "eq":
|
| 173 |
+
h_geq = self._g(f"arithmetic.equality{bits}bit.layer1.geq", inp)
|
| 174 |
+
h_leq = self._g(f"arithmetic.equality{bits}bit.layer1.leq", inp)
|
| 175 |
+
return self._g(f"arithmetic.equality{bits}bit.layer2", [h_geq, h_leq])
|
| 176 |
+
return self._g(f"arithmetic.{kind}{bits}bit", inp)
|
| 177 |
+
# 32-bit: cascaded byte-wise
|
| 178 |
+
prefix = f"arithmetic.cmp{bits}bit"
|
| 179 |
+
num_bytes = bits // 8
|
| 180 |
+
# per-byte gt/lt/eq
|
| 181 |
+
byte_gt, byte_lt, byte_eq = [], [], []
|
| 182 |
+
for bn in range(num_bytes):
|
| 183 |
+
ab = a_bits[bn*8:(bn+1)*8]
|
| 184 |
+
bb = b_bits[bn*8:(bn+1)*8]
|
| 185 |
+
byte_gt.append(self._g(f"{prefix}.byte{bn}.gt", ab + bb))
|
| 186 |
+
byte_lt.append(self._g(f"{prefix}.byte{bn}.lt", ab + bb))
|
| 187 |
+
geq = self._g(f"{prefix}.byte{bn}.eq.geq", ab + bb)
|
| 188 |
+
leq = self._g(f"{prefix}.byte{bn}.eq.leq", ab + bb)
|
| 189 |
+
byte_eq.append(self._g(f"{prefix}.byte{bn}.eq.and", [geq, leq]))
|
| 190 |
+
if kind == "equality":
|
| 191 |
+
# OR of all eq's, but the gate is `arithmetic.equality{bits}bit` with weight=[1,1,..,1]/bias=-num_bytes
|
| 192 |
+
return self._g(f"arithmetic.equality{bits}bit", byte_eq)
|
| 193 |
+
# cascade
|
| 194 |
+
cascade_gt = [byte_gt[0]]
|
| 195 |
+
cascade_lt = [byte_lt[0]]
|
| 196 |
+
for bn in range(1, num_bytes):
|
| 197 |
+
all_eq = self._g(f"{prefix}.cascade.gt.stage{bn}.all_eq", byte_eq[:bn])
|
| 198 |
+
cascade_gt.append(self._g(f"{prefix}.cascade.gt.stage{bn}.and", [all_eq, byte_gt[bn]]))
|
| 199 |
+
all_eq2 = self._g(f"{prefix}.cascade.lt.stage{bn}.all_eq", byte_eq[:bn])
|
| 200 |
+
cascade_lt.append(self._g(f"{prefix}.cascade.lt.stage{bn}.and", [all_eq2, byte_lt[bn]]))
|
| 201 |
+
if kind == "greaterthan":
|
| 202 |
+
return self._g(f"arithmetic.greaterthan{bits}bit", cascade_gt)
|
| 203 |
+
if kind == "lessthan":
|
| 204 |
+
return self._g(f"arithmetic.lessthan{bits}bit", cascade_lt)
|
| 205 |
+
raise ValueError(f"unsupported cmp kind {kind} for bits={bits}")
|
| 206 |
+
|
| 207 |
+
def mul_n(self, a: int, b: int, bits: int):
    """N-bit shift-add multiply (low N bits only)."""
    mask = (1 << bits) - 1
    a_bits = int_to_bits(a, bits)
    b_bits = int_to_bits(b, bits)
    acc = 0
    for j, bj in enumerate(b_bits):
        # Skip zero multiplier bits: their partial-product row is all zero.
        if bj == 0:
            continue
        row = 0
        for i, ai in enumerate(a_bits):
            # Partial product bit through the threshold AND gate.
            pp = self._g(f"alu.alu{bits}bit.mul.pp.a{i}b{j}", [ai, bj])
            row |= pp << (bits - 1 - i)
        # Bits are MSB-first: index j carries weight 2**(bits-1-j).
        acc, _ = self.add_n(acc & mask, (row << (bits - 1 - j)) & mask, bits)
    return acc & mask
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class GenericThresholdCPU:
    """Variant-aware CPU runtime. Sized from the variant's manifest.

    Instruction word is 16 bits, big-endian in memory:
    [15:12] opcode, [11:10] rd, [9:8] rs, [7:0] imm.
    Opcodes 0xA-0xE are followed by a 16-bit big-endian address word.
    Opcode map (see step()): 0x0 ADD, 0x1 SUB, 0x7 MUL, 0x9 CMP (flags only),
    0xA LOAD, 0xB STORE, 0xC JMP, 0xD conditional branch, 0xF HALT.
    """

    def __init__(self, tensors: Dict[str, torch.Tensor]):
        # Keep the raw tensor dict; all memory ops index into it by gate name.
        self.T = tensors
        m = get_manifest(tensors)
        self.data_bits = m["data_bits"]
        self.addr_bits = m["addr_bits"]
        self.mem_bytes = m["memory_bytes"]
        # 8-bit CPU primitives (ripplecarry8bit, sub8bit, alu.alu8bit.*, memory.*,
        # control.*) are present in every variant regardless of manifest data_bits.
        # Wider data widths simply add additional standalone ALU primitives.
        if self.mem_bytes == 0:
            raise NotImplementedError(
                "Pure-ALU variants have no memory; cannot run CPU programs"
            )
        self.alu = GenericThresholdALU(tensors, 8)

    def _addr_decode(self, addr):
        # One-hot select vector (one line per memory byte) from the threshold
        # decode layer; weight is broadcast against the MSB-first address bits.
        bits = torch.tensor(int_to_bits(addr, self.addr_bits), dtype=torch.float32)
        w = self.T["memory.addr_decode.weight"]
        b = self.T["memory.addr_decode.bias"]
        return heaviside((w * bits).sum(dim=1) + b)

    def mem_read(self, mem, addr):
        """Read one byte at `addr` through the threshold read mux (AND-OR per bit)."""
        sel = self._addr_decode(addr)
        mem_bits = torch.tensor(
            [int_to_bits(byte, 8) for byte in mem], dtype=torch.float32
        )
        and_w = self.T["memory.read.and.weight"]
        and_b = self.T["memory.read.and.bias"]
        or_w = self.T["memory.read.or.weight"]
        or_b = self.T["memory.read.or.bias"]
        out = []
        for bit in range(8):
            # AND each byte's bit with its select line, then OR across all bytes.
            inp = torch.stack([mem_bits[:, bit], sel], dim=1)
            and_out = heaviside((inp * and_w[bit]).sum(dim=1) + and_b[bit])
            out.append(int(heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item()))
        return bits_to_int(out)

    def mem_write(self, mem, addr, value):
        """Return a new memory list with `value` stored at `addr` (input not mutated)."""
        sel = self._addr_decode(addr)
        data_bits = torch.tensor(int_to_bits(value, 8), dtype=torch.float32)
        mem_bits = torch.tensor(
            [int_to_bits(byte, 8) for byte in mem], dtype=torch.float32
        )
        sel_w = self.T["memory.write.sel.weight"]
        sel_b = self.T["memory.write.sel.bias"]
        nsel_w = self.T["memory.write.nsel.weight"].squeeze(1)
        nsel_b = self.T["memory.write.nsel.bias"]
        and_old_w = self.T["memory.write.and_old.weight"]
        and_old_b = self.T["memory.write.and_old.bias"]
        and_new_w = self.T["memory.write.and_new.weight"]
        and_new_b = self.T["memory.write.and_new.bias"]
        or_w = self.T["memory.write.or.weight"]
        or_b = self.T["memory.write.or.bias"]
        we = torch.ones_like(sel)  # write-enable: always asserted by this runtime
        sel_inp = torch.stack([sel, we], dim=1)
        write_sel = heaviside((sel_inp * sel_w).sum(dim=1) + sel_b)
        nsel = heaviside(write_sel * nsel_w + nsel_b)  # inverted select line
        for bit in range(8):
            # new = (old AND not-selected) OR (data AND selected), per byte.
            old = mem_bits[:, bit]
            data_bit = data_bits[bit].expand(self.mem_bytes)
            inp_old = torch.stack([old, nsel], dim=1)
            inp_new = torch.stack([data_bit, write_sel], dim=1)
            and_old = heaviside((inp_old * and_old_w[:, bit]).sum(dim=1) + and_old_b[:, bit])
            and_new = heaviside((inp_new * and_new_w[:, bit]).sum(dim=1) + and_new_b[:, bit])
            or_inp = torch.stack([and_old, and_new], dim=1)
            new_bit = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit])
            mem_bits[:, bit] = new_bit
        return [bits_to_int([int(b) for b in mem_bits[i].tolist()]) for i in range(self.mem_bytes)]

    def step(self, state):
        """Execute one instruction and return a new state dict (input not mutated)."""
        if state["halted"]:
            return state
        # Shallow-copy the state plus the mutable lists we may rewrite.
        s = dict(state)
        s["mem"] = state["mem"][:]
        s["regs"] = state["regs"][:]
        s["flags"] = state["flags"][:]
        addr_mask = (1 << self.addr_bits) - 1
        pc = s["pc"]
        # Fetch the 16-bit big-endian instruction word through threshold reads.
        hi = self.mem_read(s["mem"], pc & addr_mask)
        lo = self.mem_read(s["mem"], (pc + 1) & addr_mask)
        ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
        opcode = (ir >> 12) & 0xF
        rd = (ir >> 10) & 0x3
        rs = (ir >> 8) & 0x3
        imm = ir & 0xFF
        next_pc = (pc + 2) & addr_mask
        addr_full = None
        # Memory/control opcodes carry a trailing 16-bit address word.
        if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
            ah = self.mem_read(s["mem"], next_pc)
            al = self.mem_read(s["mem"], (next_pc + 1) & addr_mask)
            addr_full = ((ah & 0xFF) << 8) | (al & 0xFF)
            next_pc = (next_pc + 2) & addr_mask
        addr = (addr_full & addr_mask) if addr_full is not None else None
        a = s["regs"][rd]
        b = s["regs"][rs]
        result = a
        carry = 0
        write_result = True
        if opcode == 0x0:    # ADD rd, rs
            result, carry = self.alu.add8(a, b)
        elif opcode == 0x1:  # SUB rd, rs
            result, carry = self.alu.sub8(a, b)
        elif opcode == 0x7:  # MUL rd, rs (low 8 bits)
            result = self.alu.mul8(a, b)
        elif opcode == 0x9:  # CMP rd, rs -- updates flags only
            r2, carry = self.alu.sub8(a, b)
            z = 1 if r2 == 0 else 0
            n = 1 if (r2 & 0x80) else 0
            s["flags"] = [z, n, carry, 0]
            write_result = False
        elif opcode == 0xA:  # LOAD rd, [addr]
            result = self.mem_read(s["mem"], addr)
        elif opcode == 0xB:  # STORE [addr], rs
            s["mem"] = self.mem_write(s["mem"], addr, b & 0xFF)
            write_result = False
        elif opcode == 0xC:  # JMP addr
            s["pc"] = addr
            return s
        elif opcode == 0xD:  # Bcc addr -- condition selected by imm[2:0]
            cond = imm & 0x7
            z, n, c, v = s["flags"]
            # cond: 0=Z 1=!Z 2=C 3=!C 4=N 5=!N 6=V 7=!V
            take = [z == 1, z == 0, c == 1, c == 0,
                    n == 1, n == 0, v == 1, v == 0][cond]
            s["pc"] = addr if take else next_pc
            return s
        elif opcode == 0xF:  # HALT
            s["halted"] = True
            return s

        if write_result and opcode != 0x9:
            s["regs"][rd] = result & 0xFF
        if opcode in (0x0, 0x1, 0x7):
            # Arithmetic updates Z/N/C; the V flag is never set by this core.
            z = 1 if (result & 0xFF) == 0 else 0
            n = 1 if (result & 0x80) else 0
            s["flags"] = [z, n, carry, 0]
        s["pc"] = next_pc
        return s

    def run(self, state, max_cycles=200):
        """Step until halted or `max_cycles`; returns (final_state, cycles_executed)."""
        s = state
        cycles = 0
        while not s["halted"] and cycles < max_cycles:
            s = self.step(s)
            cycles += 1
        return s, cycles
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def _encode_instr(opcode, rd, rs, imm):
|
| 376 |
+
return ((opcode & 0xF) << 12) | ((rd & 0x3) << 10) | ((rs & 0x3) << 8) | (imm & 0xFF)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def _w16(mem, addr, value):
|
| 380 |
+
mem[addr] = (value >> 8) & 0xFF
|
| 381 |
+
mem[addr + 1] = value & 0xFF
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
PROGRAM_MIN_BYTES = 0x84 # code 0x00..0x1F + data 0x80..0x83
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def builtin_program(addr_bits: int) -> Tuple[List[int], int]:
    """Sum 5+4+3+2+1 via a loop. Returns (mem, expected_result_at_0x83).

    Compact layout: code at 0x00..0x1F (32 bytes), data at 0x80..0x83 (4 bytes).
    Total footprint 132 bytes -- fits within scratchpad (256 B) and larger.
    Requires addr_bits >= 8.
    """
    if (1 << addr_bits) < PROGRAM_MIN_BYTES:
        raise ValueError(f"addr_bits={addr_bits} too small for builtin program")
    mem = [0] * (1 << addr_bits)
    mem[0x80] = 5 # initial counter
    mem[0x81] = 1 # decrement
    mem[0x82] = 0 # zero (for compare and accumulator init)
    # mem[0x83] is the output
    # Register use: r0 = accumulator, r1 = counter, r2 = decrement, r3 = zero.
    _w16(mem, 0x0000, _encode_instr(0xA, 1, 0, 0)); _w16(mem, 0x0002, 0x0080)    # LOAD r1, [0x80]
    _w16(mem, 0x0004, _encode_instr(0xA, 2, 0, 0)); _w16(mem, 0x0006, 0x0081)    # LOAD r2, [0x81]
    _w16(mem, 0x0008, _encode_instr(0xA, 3, 0, 0)); _w16(mem, 0x000A, 0x0082)    # LOAD r3, [0x82]
    _w16(mem, 0x000C, _encode_instr(0xA, 0, 0, 0)); _w16(mem, 0x000E, 0x0082)    # LOAD r0, [0x82]
    _w16(mem, 0x0010, _encode_instr(0x0, 0, 1, 0))                               # loop: ADD r0, r1
    _w16(mem, 0x0012, _encode_instr(0x1, 1, 2, 0))                               # SUB r1, r2
    _w16(mem, 0x0014, _encode_instr(0x9, 1, 3, 0))                               # CMP r1, r3
    _w16(mem, 0x0016, _encode_instr(0xD, 0, 0, 0x01)); _w16(mem, 0x0018, 0x0010) # branch if !Z -> loop
    _w16(mem, 0x001A, _encode_instr(0xB, 0, 0, 0)); _w16(mem, 0x001C, 0x0083)    # STORE [0x83], r0
    _w16(mem, 0x001E, _encode_instr(0xF, 0, 0, 0))                               # HALT
    return mem, 15
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
# ---------------------------------------------------------------------------
|
| 415 |
+
# Eval driver
|
| 416 |
+
# ---------------------------------------------------------------------------
|
| 417 |
+
|
| 418 |
+
def list_safetensors(path: Path) -> List[Path]:
    """Return the .safetensors files for `path`.

    A file path yields [path]; a directory yields its sorted *.safetensors
    entries; anything else (missing path) yields [].
    """
    if path.is_file():
        return [path]
    if path.is_dir():
        matches = [p for p in path.glob("*.safetensors") if p.is_file()]
        return sorted(matches)
    return []
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def evaluate_one(path: Path, device: str, pop_size: int, debug: bool, run_cpu_program: bool) -> Dict:
    """Evaluate one safetensors variant and return a result dict.

    The dict always carries "path", "filename" and (unless loading failed)
    "manifest"; "status" is "PASS", "FAIL" or "ERROR".  With
    run_cpu_program=True it additionally runs the builtin CPU program (when
    the variant has enough memory) and, for 16/32-bit variants, a chained
    N-bit ALU smoke test.
    """
    out: Dict = {"path": str(path), "filename": path.name}
    try:
        tensors = load_model(str(path))
    except Exception as e:
        out.update(error=f"load failed: {e}", status="ERROR")
        return out

    manifest = get_manifest(tensors)
    out.update(
        size_mb=path.stat().st_size / (1024 * 1024),
        tensors=len(tensors),
        params=sum(t.numel() for t in tensors.values()),
        manifest=manifest,
    )

    # Move to device
    tensors = {k: v.to(device) for k, v in tensors.items()}

    try:
        evaluator = BatchedFitnessEvaluator(device=device, model_path=str(path), tensors=tensors)
        population = create_population(tensors, pop_size=pop_size, device=device)
        t0 = time.perf_counter()
        fitness = evaluator.evaluate(population, debug=debug)
        elapsed = time.perf_counter() - t0
        # Single-member populations report member 0; otherwise the mean.
        f0 = float(fitness[0].item()) if pop_size == 1 else float(fitness.mean().item())
        out.update(
            fitness=f0,
            total_tests=evaluator.total_tests,
            elapsed_s=elapsed,
            categories={k: (float(v[0]), int(v[1])) for k, v in evaluator.category_scores.items()},
            status="PASS" if f0 >= 0.9999 else "FAIL",
        )
    except Exception as e:
        out.update(error=f"eval failed: {type(e).__name__}: {e}", status="ERROR")
        return out

    # Optional: CPU program test (8-bit CPU primitives are in every variant)
    if run_cpu_program:
        if manifest["memory_bytes"] >= PROGRAM_MIN_BYTES:
            try:
                # The CPU runtime runs on host tensors regardless of eval device.
                cpu_tensors = {k: v.cpu() for k, v in tensors.items()}
                cpu = GenericThresholdCPU(cpu_tensors)
                mem, expected = builtin_program(manifest["addr_bits"])
                state = {"pc": 0, "regs": [0] * 4, "flags": [0] * 4, "mem": mem, "halted": False}
                t0 = time.perf_counter()
                final, cycles = cpu.run(state, max_cycles=200)
                cpu_elapsed = time.perf_counter() - t0
                got = final["mem"][0x83]  # builtin program writes its result here
                out["cpu_program"] = {
                    "ok": got == expected,
                    "got": got,
                    "expected": expected,
                    "cycles": cycles,
                    "elapsed_s": cpu_elapsed,
                }
                if got != expected:
                    out["status"] = "FAIL"
            except Exception as e:
                out["cpu_program"] = {"error": str(e)}
        else:
            out["cpu_program"] = {"skipped": f"mem={manifest['memory_bytes']}B < {PROGRAM_MIN_BYTES}"}

    # Wider-ALU chain test for 16/32-bit variants
    bits = manifest["data_bits"]
    if bits in (16, 32):
        try:
            alu_tensors = {k: v.cpu() for k, v in tensors.items()}
            alu = GenericThresholdALU(alu_tensors, bits)
            t0 = time.perf_counter()
            if bits == 16:
                x, y = 1234, 5678
                z, _ = alu.add_n(x, y, 16); assert z == (x + y) & 0xFFFF
                w, _ = alu.sub_n(z, x, 16); assert w == (z - x) & 0xFFFF, (w, z - x)
                gt = alu.cmp_n(z, x, "greaterthan", 16); assert gt == 1
                lt = alu.cmp_n(x, z, "lessthan", 16); assert lt == 1
                # NOTE(review): 16-bit uses kind "eq" while 32-bit uses
                # "equality" -- confirm cmp_n accepts both spellings at 16 bits.
                eq = alu.cmp_n(w, y, "eq", 16); assert eq == 1
                p = alu.mul_n(123, 5, 16); assert p == (123 * 5) & 0xFFFF
            else:  # 32
                x, y = 1_000_000, 999_000
                z, _ = alu.sub_n(x, y, 32); assert z == 1_000
                s, _ = alu.add_n(z, x, 32); assert s == 1_001_000
                p = alu.mul_n(z, 100, 32); assert p == 100_000
                gt = alu.cmp_n(x, y, "greaterthan", 32); assert gt == 1
                lt = alu.cmp_n(y, x, "lessthan", 32); assert lt == 1
                eq = alu.cmp_n(p, 100_000, "equality", 32); assert eq == 1
            chain_dt = time.perf_counter() - t0
            out[f"alu_chain_{bits}"] = {"ok": True, "elapsed_s": chain_dt}
        except AssertionError as e:
            out[f"alu_chain_{bits}"] = {"ok": False, "error": f"chain mismatch: {e}"}
            out["status"] = "FAIL"
        except Exception as e:
            out[f"alu_chain_{bits}"] = {"ok": False, "error": f"{type(e).__name__}: {e}"}
            out["status"] = "FAIL"

    return out
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def print_row(r: Dict, show_cpu: bool) -> None:
    """Print one formatted summary line for a single evaluation result dict."""
    if "error" in r:
        print(f" {r['filename']:<48} ERROR: {r['error'][:80]}")
        return
    m = r["manifest"]
    fit = "n/a" if r.get("fitness") is None else f"{r['fitness']:.4f}"
    cpu_col = ""
    if show_cpu and "cpu_program" in r:
        cp = r["cpu_program"]
        if cp.get("ok"):
            cpu_col = f" CPU OK ({cp['cycles']}cyc/{cp['elapsed_s']:.1f}s)"
        elif "skipped" in cp:
            cpu_col = " CPU SKIP"
        elif "error" in cp:
            cpu_col = " CPU ERR"
        else:
            cpu_col = f" CPU FAIL ({cp.get('got')}!={cp.get('expected')})"
    chain_col = ""
    if show_cpu:
        # At most one of the two widths is present for a given variant.
        for width in (16, 32):
            ch = r.get(f"alu_chain_{width}")
            if ch is None:
                continue
            if ch.get("ok"):
                chain_col = f" ALU{width} OK ({ch['elapsed_s']:.2f}s)"
            else:
                chain_col = f" ALU{width} FAIL"
    print(
        f" {r['filename']:<48} d={m['data_bits']:>2}b a={m['addr_bits']:>2}b "
        f"mem={m['memory_bytes']:>6}B size={r['size_mb']:>6.1f}MB "
        f"params={r['params']:>10,} fit={fit:>6} tests={r['total_tests']:>5} "
        f"{r['status']:>5}{cpu_col}{chain_col}"
    )
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def main() -> int:
    """CLI entry point for the variant-agnostic eval harness.

    Returns the number of failing variants (0 means all passed) or 2 when no
    .safetensors files were found, so the return value doubles as exit code.
    """
    parser = argparse.ArgumentParser(description="Variant-agnostic eval harness")
    parser.add_argument("path", help="Path to .safetensors file or directory of files")
    parser.add_argument("--device", default="cpu", help="cpu (default) or cuda")
    parser.add_argument("--pop_size", type=int, default=1)
    parser.add_argument("--debug", action="store_true", help="Per-circuit detail per file")
    # Fixed help text: the CPU program runs on any variant whose memory holds
    # the builtin program (evaluate_one checks memory_bytes >= PROGRAM_MIN_BYTES,
    # i.e. 132 B), not only "8-bit variants with >= 512 B memory" -- the 8-bit
    # CPU primitives exist in every variant.
    parser.add_argument("--cpu-program", action="store_true",
                        help="Also run a small assembled program through the threshold CPU "
                             f"(any variant with >= {PROGRAM_MIN_BYTES} B of memory)")
    parser.add_argument("--json", action="store_true", help="Emit JSON results to stdout instead of a table")
    args = parser.parse_args()

    files = list_safetensors(Path(args.path))
    if not files:
        print(f"No .safetensors files found under {args.path}", file=sys.stderr)
        return 2

    print(f"Evaluating {len(files)} file(s) on {args.device}\n")
    results = []
    fail_count = 0
    for f in files:
        print(f"=== {f.name}")
        r = evaluate_one(f, device=args.device, pop_size=args.pop_size,
                         debug=args.debug, run_cpu_program=args.cpu_program)
        results.append(r)
        print_row(r, show_cpu=args.cpu_program)
        if r.get("status") != "PASS":
            fail_count += 1

    if args.json:
        # Make it serialisable: manifest values may be floats holding integers.
        for r in results:
            r["manifest"] = {k: (int(v) if isinstance(v, float) and v.is_integer() else v)
                             for k, v in r.get("manifest", {}).items()}
        print(json.dumps(results, indent=2, default=str))
        return fail_count

    # Summary table
    print()
    print("=" * 100)
    print(" SUMMARY")
    print("=" * 100)
    for r in results:
        print_row(r, show_cpu=args.cpu_program)

    print()
    if fail_count == 0:
        print(f"ALL {len(files)} variants PASS")
    else:
        print(f"{fail_count}/{len(files)} variants FAIL")
    return fail_count


if __name__ == "__main__":
    sys.exit(main())
|
|
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hands-on playground for the 8bit-threshold-computer.
|
| 3 |
+
|
| 4 |
+
Loads the bundled safetensors model, reads its manifest, and exercises
|
| 5 |
+
threshold circuits at every level: raw Boolean gates, ALU arithmetic,
|
| 6 |
+
comparators, then a CPU runtime sized to the actual manifest that runs
|
| 7 |
+
a small assembled program end-to-end through the threshold weights.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
import torch
|
| 14 |
+
from safetensors import safe_open
|
| 15 |
+
|
| 16 |
+
sys.path.insert(0, os.path.dirname(__file__))
|
| 17 |
+
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
# Load model + manifest
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
|
| 22 |
+
MODEL_PATH = os.path.join(os.path.dirname(__file__), "neural_computer.safetensors")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def heaviside(x):
    """Threshold activation: 1.0 where x >= 0, else 0.0."""
    mask = x >= 0
    return mask.float()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_tensors(path):
    """Load every tensor in a safetensors file, upcast to float32."""
    with safe_open(path, framework="pt") as f:
        return {name: f.get_tensor(name).float() for name in f.keys()}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Load the bundled weights once at import time and read the manifest scalars
# (stored as tensors inside the safetensors file) that size the demos below.
print("Loading", MODEL_PATH)
T = load_tensors(MODEL_PATH)

DATA_BITS = int(T["manifest.data_bits"].item())
ADDR_BITS = int(T["manifest.addr_bits"].item())
MEM_BYTES = int(T["manifest.memory_bytes"].item())
REGISTERS = int(T["manifest.registers"].item())
print(f"Manifest: data={DATA_BITS}-bit, addr={ADDR_BITS}-bit, mem={MEM_BYTES}B, regs={REGISTERS}")
print(f"Tensors: {len(T):,}")
print(f"Total params: {sum(t.numel() for t in T.values()):,}")
print()
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def gate(name, inputs):
    """Run one threshold gate identified by `name` (no .weight/.bias suffix)."""
    weight = T[name + ".weight"].view(-1)
    bias = T[name + ".bias"].view(-1)
    x = torch.tensor(inputs, dtype=torch.float32)
    activation = (x * weight).sum() + bias
    return int(heaviside(activation).item())
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def xor(prefix, inputs):
    """Run a 2-layer XOR-style gate (or/nand naming, e.g. ripple-carry adders)."""
    left, right = inputs
    hidden = [
        gate(f"{prefix}.layer1.or", [left, right]),
        gate(f"{prefix}.layer1.nand", [left, right]),
    ]
    return gate(f"{prefix}.layer2", hidden)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def xor_neuron(prefix, inputs):
    """Run a 2-layer XOR-style gate (neuron1/neuron2 naming, e.g. boolean.xor)."""
    left, right = inputs
    hidden = [
        gate(f"{prefix}.layer1.neuron1", [left, right]),
        gate(f"{prefix}.layer1.neuron2", [left, right]),
    ]
    return gate(f"{prefix}.layer2", hidden)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def int_to_bits_msb(v, n):
    """Return the n low bits of v as a list, most-significant bit first."""
    bits = []
    for shift in range(n - 1, -1, -1):
        bits.append((v >> shift) & 1)
    return bits
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def bits_msb_to_int(bits):
    """Fold an MSB-first bit list back into an integer (empty list -> 0)."""
    width = len(bits)
    return sum(int(b) << (width - 1 - i) for i, b in enumerate(bits))
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# ---------------------------------------------------------------------------
# Demo 1: Boolean gates (README Usage example)
# ---------------------------------------------------------------------------

print("=" * 64)
print(" Demo 1: Boolean threshold gates")
print("=" * 64)
# All four input combinations for the 2-input truth tables below.
truth_2 = [(0, 0), (0, 1), (1, 0), (1, 1)]
# Single-layer gates (linearly separable functions).
for gname in ["and", "or", "nand", "nor", "implies"]:
    row = " ".join(f"{a}{b}->{gate(f'boolean.{gname}', [a, b])}" for a, b in truth_2)
    print(f" {gname:8} {row}")
# 2-layer gates (boolean.* uses neuron1/neuron2 naming)
for gname in ["xor", "xnor", "biimplies"]:
    row = " ".join(f"{a}{b}->{xor_neuron(f'boolean.{gname}', [a, b])}" for a, b in truth_2)
    print(f" {gname:8} {row}")
# NOT (1-input)
print(f" not 0->{gate('boolean.not', [0])} 1->{gate('boolean.not', [1])}")
print()
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
# Demo 2: 8-bit ALU operations via threshold weights
|
| 107 |
+
# ---------------------------------------------------------------------------
|
| 108 |
+
|
| 109 |
+
print("=" * 64)
|
| 110 |
+
print(" Demo 2: 8-bit ALU arithmetic (every gate is threshold logic)")
|
| 111 |
+
print("=" * 64)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def fa(prefix, a, b, cin):
    """Full adder built from two threshold half-adders plus a carry OR."""
    sum1 = xor(f"{prefix}.ha1.sum", [a, b])
    carry1 = gate(f"{prefix}.ha1.carry", [a, b])
    total = xor(f"{prefix}.ha2.sum", [sum1, cin])
    carry2 = gate(f"{prefix}.ha2.carry", [sum1, cin])
    carry_out = gate(f"{prefix}.carry_or", [carry1, carry2])
    return total, carry_out
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def alu_add(a, b):
    """8-bit ripple carry add via threshold full-adders."""
    # fa{i} cells are wired LSB-first, so reverse the MSB-first bit lists.
    a_lsb = int_to_bits_msb(a, 8)[::-1]
    b_lsb = int_to_bits_msb(b, 8)[::-1]
    carry = 0
    out_lsb = []
    for i in range(8):
        bit, carry = fa(f"arithmetic.ripplecarry8bit.fa{i}", a_lsb[i], b_lsb[i], carry)
        out_lsb.append(bit)
    return bits_msb_to_int(out_lsb[::-1]), carry
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def alu_sub(a, b):
    """A - B via two's complement; uses sub8bit circuit family.

    Computes A + NOT(B) + 1 with the ripple carry seeded to 1; the final
    carry out is the "no borrow" flag (1 when A >= B).
    """
    a_lsb = list(reversed(int_to_bits_msb(a, 8)))
    b_lsb = list(reversed(int_to_bits_msb(b, 8)))
    carry = 1  # the +1 of the two's complement
    diff_lsb = []
    for i in range(8):
        # Invert the subtrahend bit, then a standard full-adder cell.
        notb = gate(f"arithmetic.sub8bit.notb{i}", [b_lsb[i]])
        x1 = xor(f"arithmetic.sub8bit.fa{i}.xor1", [a_lsb[i], notb])
        x2 = xor(f"arithmetic.sub8bit.fa{i}.xor2", [x1, carry])
        and1 = gate(f"arithmetic.sub8bit.fa{i}.and1", [a_lsb[i], notb])
        and2 = gate(f"arithmetic.sub8bit.fa{i}.and2", [x1, carry])
        carry = gate(f"arithmetic.sub8bit.fa{i}.or_carry", [and1, and2])
        diff_lsb.append(x2)
    return bits_msb_to_int(list(reversed(diff_lsb))), carry
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def alu_compare(a, b, kind):
    """8-bit comparators (single-layer GT/LT, two-layer EQ)."""
    inp = int_to_bits_msb(a, 8) + int_to_bits_msb(b, 8)
    if kind != "eq":
        # greaterthan / lessthan are single threshold gates.
        return gate(f"arithmetic.{kind}8bit", inp)
    # Equality = (a >= b) AND (a <= b) across two layers.
    geq = gate("arithmetic.equality8bit.layer1.geq", inp)
    leq = gate("arithmetic.equality8bit.layer1.leq", inp)
    return gate("arithmetic.equality8bit.layer2", [geq, leq])
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def alu_mul(a, b):
    """Shift-add multiply via partial-product threshold AND gates + repeated add."""
    a_bits = int_to_bits_msb(a, 8)
    b_bits = int_to_bits_msb(b, 8)
    # Partial-product matrix: pp[i][j] = a_bits[i] AND b_bits[j].
    pp = [[gate(f"alu.alu8bit.mul.pp.a{i}b{j}", [a_bits[i], b_bits[j]])
           for j in range(8)] for i in range(8)]
    # Accumulate weighted partial products in 8 bits (overflow above bit 7 drops).
    acc = 0
    for j in range(8):  # j=0 is MSB of b -> weight 7-j
        if b_bits[j] == 0:
            continue
        row = 0
        for i in range(8):
            row |= pp[i][j] << (7 - i)
        acc, _ = alu_add(acc & 0xFF, (row << (7 - j)) & 0xFF)
    return acc & 0xFF
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# Exercise the threshold ALU on a spread of operand pairs, including carry,
# borrow and wrap-around cases; each line self-checks against Python ints.
cases_arith = [(5, 3), (37, 100), (200, 99), (255, 1), (127, 128), (15, 17)]
print("ADD:")
for a, b in cases_arith:
    r, c = alu_add(a, b)
    expect = (a + b) & 0xFF
    ok = "OK" if r == expect else "FAIL"
    print(f" {a:3} + {b:3} = {r:3} (carry={c}) expected {expect:3} [{ok}]")

print("SUB:")
for a, b in cases_arith:
    r, c = alu_sub(a, b)
    expect = (a - b) & 0xFF
    ok = "OK" if r == expect else "FAIL"
    print(f" {a:3} - {b:3} = {r:3} (no_borrow={c}) expected {expect:3} [{ok}]")

print("CMP:")
cmp_cases = [(50, 30), (30, 50), (77, 77), (255, 0), (0, 255), (128, 127)]
for a, b in cmp_cases:
    gt = alu_compare(a, b, "greaterthan")
    lt = alu_compare(a, b, "lessthan")
    eq = alu_compare(a, b, "eq")
    print(f" {a:3} vs {b:3} -> GT={gt} LT={lt} EQ={eq}")

print("MUL (low 8 bits):")
for a, b in [(12, 11), (15, 17), (8, 32), (200, 3), (0, 99), (1, 255)]:
    r = alu_mul(a, b)
    expect = (a * b) & 0xFF
    ok = "OK" if r == expect else "FAIL"
    print(f" {a:3} * {b:3} = {r:3} expected {expect:3} [{ok}]")
print()
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# ---------------------------------------------------------------------------
|
| 220 |
+
# Demo 3: A 4-bit divisibility test (mod 5) - non-linearly-separable
|
| 221 |
+
# ---------------------------------------------------------------------------
|
| 222 |
+
|
| 223 |
+
print("=" * 64)
|
| 224 |
+
print(" Demo 3: mod-5 divisibility (multi-layer, hand-constructed)")
|
| 225 |
+
print("=" * 64)
|
| 226 |
+
# layer1: per-residue geq/leq -> layer2: eq -> layer3: OR all eq's
|
| 227 |
+
def mod5(v):
    """Return 1 when the mod-5 threshold network fires for 8-bit value v."""
    bits = int_to_bits_msb(v, 8)
    # Discover how many geq/leq neuron pairs the network has.
    count = 0
    while f"modular.mod5.layer1.geq{count}.weight" in T:
        count += 1
    eq_outputs = []
    for i in range(count):
        geq = gate(f"modular.mod5.layer1.geq{i}", bits)
        leq = gate(f"modular.mod5.layer1.leq{i}", bits)
        eq_outputs.append(gate(f"modular.mod5.layer2.eq{i}", [geq, leq]))
    return gate("modular.mod5.layer3.or", eq_outputs)
|
| 239 |
+
|
| 240 |
+
# Evaluate the circuit over the full 8-bit range; the sanity line should
# print all zeros if the network matches true divisibility by 5.
hits = [v for v in range(256) if mod5(v)]
print(f" v in [0,255] with mod5(v)==1: {len(hits)} hits, first 12: {hits[:12]}")
print(f" Sanity: {[h % 5 for h in hits[:12]]}")
print()
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# ---------------------------------------------------------------------------
|
| 247 |
+
# Demo 4: Manifest-aware threshold CPU - run a real program
|
| 248 |
+
# ---------------------------------------------------------------------------
|
| 249 |
+
|
| 250 |
+
print("=" * 64)
|
| 251 |
+
print(" Demo 4: Threshold CPU running an assembled program")
|
| 252 |
+
print("=" * 64)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class ThresholdCPU10:
    """CPU runtime matching the bundled small-profile manifest (10-bit addr).

    All memory accesses go through the threshold-gate weights stored in the
    module-level tensor dict ``T``; ALU operations reuse the threshold ALU
    helpers (``alu_add``/``alu_sub``/``alu_mul``) defined earlier.
    State is a dict with keys ``pc``, ``regs`` (4 bytes), ``flags``
    ([Z, N, C, V]), ``mem`` (list of ints), and ``halted``.
    """

    def __init__(self, addr_bits, mem_bytes):
        self.addr_bits = addr_bits  # address width in bits (10 -> 1 KiB space)
        self.mem_bytes = mem_bytes  # number of addressable bytes

    # --- memory primitives, fully through threshold weights ---
    def addr_decode(self, addr):
        """Return a one-hot [mem_bytes] selector vector for *addr*."""
        bits = torch.tensor(int_to_bits_msb(addr, self.addr_bits), dtype=torch.float32)
        w = T["memory.addr_decode.weight"]
        b = T["memory.addr_decode.bias"]
        return heaviside((w * bits).sum(dim=1) + b)  # [mem_bytes]

    def mem_read(self, mem, addr):
        """Read one byte from *mem* at *addr* via AND/OR threshold layers."""
        sel = self.addr_decode(addr)
        mem_bits = torch.tensor(
            [int_to_bits_msb(byte, 8) for byte in mem], dtype=torch.float32
        )
        and_w = T["memory.read.and.weight"]
        and_b = T["memory.read.and.bias"]
        or_w = T["memory.read.or.weight"]
        or_b = T["memory.read.or.bias"]
        out_bits = []
        for bit in range(8):
            # AND each cell's bit with the selector, then OR across all cells.
            inp = torch.stack([mem_bits[:, bit], sel], dim=1)
            and_out = heaviside((inp * and_w[bit]).sum(dim=1) + and_b[bit])
            out_bits.append(int(heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item()))
        return bits_msb_to_int(out_bits)

    def mem_write(self, mem, addr, value):
        """Return a new memory list with *value* written at *addr*.

        *mem* itself is not mutated; the per-bit mux is
        new = (old AND NOT sel) OR (data AND sel), all through thresholds.
        """
        sel = self.addr_decode(addr)
        data_bits = torch.tensor(int_to_bits_msb(value, 8), dtype=torch.float32)
        mem_bits = torch.tensor(
            [int_to_bits_msb(byte, 8) for byte in mem], dtype=torch.float32
        )
        sel_w = T["memory.write.sel.weight"]
        sel_b = T["memory.write.sel.bias"]
        nsel_w = T["memory.write.nsel.weight"].squeeze(1)
        nsel_b = T["memory.write.nsel.bias"]
        and_old_w = T["memory.write.and_old.weight"]
        and_old_b = T["memory.write.and_old.bias"]
        and_new_w = T["memory.write.and_new.weight"]
        and_new_b = T["memory.write.and_new.bias"]
        or_w = T["memory.write.or.weight"]
        or_b = T["memory.write.or.bias"]

        # write-enable is constant 1 here (every call is an actual write).
        we = torch.ones_like(sel)
        sel_inp = torch.stack([sel, we], dim=1)
        write_sel = heaviside((sel_inp * sel_w).sum(dim=1) + sel_b)
        nsel = heaviside(write_sel * nsel_w + nsel_b)

        # NOTE: the dead `new_mem = mem[:]` copy from the original was removed;
        # the result is rebuilt from mem_bits below.
        for bit in range(8):
            old = mem_bits[:, bit]
            data_bit = data_bits[bit].expand(self.mem_bytes)
            inp_old = torch.stack([old, nsel], dim=1)
            inp_new = torch.stack([data_bit, write_sel], dim=1)
            and_old = heaviside((inp_old * and_old_w[:, bit]).sum(dim=1) + and_old_b[:, bit])
            and_new = heaviside((inp_new * and_new_w[:, bit]).sum(dim=1) + and_new_b[:, bit])
            or_inp = torch.stack([and_old, and_new], dim=1)
            mem_bits[:, bit] = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit])
        return [bits_msb_to_int([int(b) for b in mem_bits[i].tolist()]) for i in range(self.mem_bytes)]

    # --- helper to use threshold ALU functions defined above ---
    def step(self, state):
        """Execute one instruction and return a NEW state dict.

        The input state is never mutated; a halted state is returned as-is.
        """
        if state["halted"]:
            return state
        s = dict(state)
        s["mem"] = state["mem"][:]
        s["regs"] = state["regs"][:]
        s["flags"] = state["flags"][:]

        pc = s["pc"]
        addr_mask = (1 << self.addr_bits) - 1
        # Fetch the 16-bit instruction word, big-endian.
        hi = self.mem_read(s["mem"], pc & addr_mask)
        lo = self.mem_read(s["mem"], (pc + 1) & addr_mask)
        ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
        opcode = (ir >> 12) & 0xF
        rd = (ir >> 10) & 0x3
        rs = (ir >> 8) & 0x3
        imm = ir & 0xFF

        next_pc = (pc + 2) & addr_mask
        addr16 = None
        # Memory/branch opcodes carry a 16-bit operand word after the instruction.
        if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
            ah = self.mem_read(s["mem"], next_pc)
            al = self.mem_read(s["mem"], (next_pc + 1) & addr_mask)
            addr16 = ((ah & 0xFF) << 8) | (al & 0xFF)
            next_pc = (next_pc + 2) & addr_mask
        addr10 = (addr16 & addr_mask) if addr16 is not None else None

        a = s["regs"][rd]
        b = s["regs"][rs]
        write = True
        result = a
        carry = 0

        if opcode == 0x0:  # ADD
            result, carry = alu_add(a, b)
        elif opcode == 0x1:  # SUB
            result, carry = alu_sub(a, b)
        elif opcode == 0x7:  # MUL
            result = alu_mul(a, b)
        elif opcode == 0x9:  # CMP: set flags from a-b, discard the result
            _r, carry = alu_sub(a, b)
            z = 1 if _r == 0 else 0
            n = 1 if (_r & 0x80) else 0
            s["flags"] = [z, n, carry, 0]  # V flag is never computed here
            write = False
        elif opcode == 0xA:  # LOAD
            result = self.mem_read(s["mem"], addr10)
        elif opcode == 0xB:  # STORE
            s["mem"] = self.mem_write(s["mem"], addr10, b & 0xFF)
            write = False
        elif opcode == 0xC:  # JMP
            s["pc"] = addr10
            return s
        elif opcode == 0xD:  # Jcc: condition selected by imm[2:0]
            cond = imm & 0x7
            z, n, c, v = s["flags"]
            # cond: 0=Z 1=NZ 2=C 3=NC 4=N 5=NN 6=V 7=NV
            take = [z == 1, z == 0, c == 1, c == 0, n == 1, n == 0, v == 1, v == 0][cond]
            s["pc"] = addr10 if take else next_pc
            return s
        elif opcode == 0xF:  # HALT
            s["halted"] = True
            return s

        # CMP and STORE clear `write`, so the original's extra `opcode != 0x9`
        # test was redundant; the dead `opcode_was_cmp` local was removed.
        if write:
            s["regs"][rd] = result & 0xFF
        if opcode in (0x0, 0x1, 0x7):
            z = 1 if (result & 0xFF) == 0 else 0
            n = 1 if (result & 0x80) else 0
            s["flags"] = [z, n, carry, 0]
        s["pc"] = next_pc
        return s

    def run(self, state, max_cycles=64):
        """Step until HALT or *max_cycles*; return (final_state, cycles_used)."""
        s = state
        cycles = 0
        while not s["halted"] and cycles < max_cycles:
            s = self.step(s)
            cycles += 1
        return s, cycles
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def encode_instr(opcode, rd, rs, imm):
    """Pack a 16-bit instruction word: [15:12]=opcode [11:10]=rd [9:8]=rs [7:0]=imm."""
    word = (opcode & 0xF) << 12
    word |= (rd & 0x3) << 10
    word |= (rs & 0x3) << 8
    return word | (imm & 0xFF)
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def write_word(mem, addr, value):
    """Store a 16-bit *value* big-endian into two consecutive bytes of *mem*."""
    hi, lo = (value >> 8) & 0xFF, value & 0xFF
    mem[addr] = hi
    mem[addr + 1] = lo
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
# Program: count down from 5 to 0 with a loop, accumulate sum into R0.
#
#   R1 = 5
#   R0 = 0
#   loop:
#     R0 = R0 + R1   ; ADD R0, R1
#     R1 = R1 - 1    ; no immediate decrement; SUB R1, R2 with R2 = 1
#     CMP R1, R3     ; R3 = 0
#     JNZ loop
#   HALT
#
# Memory layout (1KB): code from 0x0000, data constants at 0x0100-0x0102,
# result stored at 0x0103.

mem = [0] * 1024
mem[0x100] = 5  # loop counter initial value
mem[0x101] = 1  # decrement constant
mem[0x102] = 0  # zero constant (also initializes R0)

# (address, 16-bit word) image: instruction words plus their operand words.
_image = [
    (0x0000, encode_instr(0xA, 1, 0, 0)), (0x0002, 0x0100),    # LOAD R1 <- M[0x0100]
    (0x0004, encode_instr(0xA, 2, 0, 0)), (0x0006, 0x0101),    # LOAD R2 <- M[0x0101]
    (0x0008, encode_instr(0xA, 3, 0, 0)), (0x000A, 0x0102),    # LOAD R3 <- M[0x0102]
    (0x000C, encode_instr(0xA, 0, 0, 0)), (0x000E, 0x0102),    # LOAD R0 <- M[0x0102]
    (0x0010, encode_instr(0x0, 0, 1, 0)),                      # ADD R0, R1
    (0x0012, encode_instr(0x1, 1, 2, 0)),                      # SUB R1, R2
    (0x0014, encode_instr(0x9, 1, 3, 0)),                      # CMP R1, R3
    (0x0016, encode_instr(0xD, 0, 0, 0x01)), (0x0018, 0x0010), # JNZ 0x0010 (cond=1 = NZ)
    (0x001A, encode_instr(0xB, 0, 0, 0)), (0x001C, 0x0103),    # STORE R0 -> M[0x0103]
    (0x001E, encode_instr(0xF, 0, 0, 0)),                      # HALT
]
for _addr, _word in _image:
    write_word(mem, _addr, _word)
|
| 468 |
+
|
| 469 |
+
cpu = ThresholdCPU10(addr_bits=ADDR_BITS, mem_bytes=MEM_BYTES)
state = {
    "pc": 0,
    "regs": [0] * 4,
    "flags": [0] * 4,
    "mem": mem,
    "halted": False,
}
print(" Program: sum 5+4+3+2+1 via loop (uses ADD/SUB/CMP/Jcc/LOAD/STORE/HALT, all threshold-gated)")
print(" Running ...")
final, cycles = cpu.run(state, max_cycles=200)
r0, r1, r2, r3 = final["regs"]
print(f" Halted after {cycles} cycles")
print(f" R0={r0} R1={r1} R2={r2} R3={r3}")
print(f" M[0x0103] = {final['mem'][0x103]} (expected 15)")
print()
print("Done.")
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f702736cd85124aac22602bf44617698309c03739a254b338409df87e22344c9
|
| 3 |
+
size 12434484
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8c6761fa0366a19cdb9abb7c1c72f53b3a3a07032056b6d17dbed4131cc5e21d
|
| 3 |
+
size 14378864
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:246546bba4668a80a81e32b115d883d57b6b49bdfe8254034090089d5bf168cf
|
| 3 |
+
size 11561076
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d2daa9ab42ab63534010e363adbb3423502ebbe94a4b354797c25dece5eb5948
|
| 3 |
+
size 45730164
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a1808dd34084e68120bccd277310749e047c357274440901baf2b01ca64e9e41
|
| 3 |
+
size 14640476
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c7487bbfe4da343bb2072c190e33b7861b452c947efd424a068927c413595049
|
| 3 |
+
size 12534076
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:58dfbdf0a987c1675a68439d86a57a7631a7657ea60b6b0d3e568dfdeee88f2e
|
| 3 |
+
size 12704876
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:759920dcb38a340ee31f4d116df4983322258b796ac5d6021f7ca165986f5f5b
|
| 3 |
+
size 13104212
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:18f4f3420fb307d90ea7a8fe356c196a59d7a0f2ed4ec57679d87b209a7fec22
|
| 3 |
+
size 47693920
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:51e14c8819de3402881ce2ffe3cdd7e94a801c038c6ef8495110144e9348e2e7
|
| 3 |
+
size 16604104
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:13ac4eb1c793a6331a2ecfa13d3372edc9f4649163883244847ca6616062de05
|
| 3 |
+
size 14497800
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d8c389b1730cc297f40944815aebda1b1a71b79bff738e9da767c82609e9d9bd
|
| 3 |
+
size 14668512
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c3c67a2047c0cf7370e802727b9be51d8b7185dedfe108409968f0d838157e04
|
| 3 |
+
size 15067856
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:acde9e66a5bae870b5684ddc8592a206f00b518e088e90965a73bfa35274ba2a
|
| 3 |
+
size 44846164
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7e318727316bfb34f82cdc4a2b627d9f8475c3282cab67a6424ba642350dc823
|
| 3 |
+
size 13756476
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7b2c49af2b18786699351235d4d051afd7452e17616f0f06a87b3e5e9820da66
|
| 3 |
+
size 11649932
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:40fe6db0454dd6ba33072a18f6c81ed1463830b270b708b9ae45f976e32cfc50
|
| 3 |
+
size 11820860
|
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:547aef648729c49dc106c14d05bfcdf12a6f1aca5de5b7d1c475fce65aef1373
|
| 3 |
+
size 12220204
|