CharlesCNorton
committed on
Commit
·
6087b2e
1
Parent(s):
7c967c0
Add SHL, SHR, MUL, DIV, and comparator circuits
Browse files- SHL/SHR: 8 identity gates each with zero injection
- Comparators: GT, LT, GE, LE (single-layer), EQ (two-layer AND)
- MUL: 64 partial product AND gates
- DIV: 8 stages with comparison + conditional mux
build.py: Added add_shl_shr(), add_mul(), add_div(), add_comparators(),
cmd_alu subcommand, input inference for new circuits
eval.py: Added tests for all new circuits (5282 -> 5884 tests)
threshold_cpu.py: Added shift_left(), shift_right(), multiply(), divide()
to ThresholdALU; updated ref_step() and step()
Tensors: 9429 -> 10399, Fitness: 1.000000
- build.py +224 -1
- eval.py +191 -0
- neural_computer.safetensors +2 -2
- threshold_cpu.py +106 -8
build.py
CHANGED
|
@@ -227,6 +227,117 @@ def add_fetch_load_store_buffers(tensors: Dict[str, torch.Tensor]) -> None:
|
|
| 227 |
add_gate(tensors, f"control.mem_addr.bit{bit}", [1.0], [-1.0])
|
| 228 |
|
| 229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
def update_manifest(tensors: Dict[str, torch.Tensor]) -> None:
|
| 231 |
tensors["manifest.memory_bytes"] = torch.tensor([float(MEM_BYTES)], dtype=torch.float32)
|
| 232 |
tensors["manifest.pc_width"] = torch.tensor([float(ADDR_BITS)], dtype=torch.float32)
|
|
@@ -493,6 +604,49 @@ def infer_alu_inputs(gate: str, reg: SignalRegistry) -> List[int]:
|
|
| 493 |
return [reg.get_id(f"$opcode[{i}]") for i in range(4)]
|
| 494 |
if 'aluflags' in gate:
|
| 495 |
return [reg.register("$result"), reg.register("$carry"), reg.register("$overflow")]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 496 |
if '.and' in gate or '.or' in gate or '.xor' in gate:
|
| 497 |
m = re.search(r'bit(\d+)', gate)
|
| 498 |
if m:
|
|
@@ -632,6 +786,20 @@ def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, tor
|
|
| 632 |
return infer_adcsbc_inputs(gate, "arithmetic.sbc8bit", True, reg)
|
| 633 |
if 'sub8bit' in gate:
|
| 634 |
return infer_sub8bit_inputs(gate, reg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 635 |
for i in range(8):
|
| 636 |
reg.register(f"$a[{i}]")
|
| 637 |
reg.register(f"$b[{i}]")
|
|
@@ -752,9 +920,61 @@ def cmd_inputs(args) -> None:
|
|
| 752 |
print("=" * 60)
|
| 753 |
|
| 754 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 755 |
def cmd_all(args) -> None:
|
| 756 |
print("Running: memory")
|
| 757 |
cmd_memory(args)
|
|
|
|
|
|
|
| 758 |
print("\nRunning: inputs")
|
| 759 |
cmd_inputs(args)
|
| 760 |
|
|
@@ -766,11 +986,14 @@ def main() -> None:
|
|
| 766 |
parser.add_argument("--manifest", action="store_true", help="Write tensors.txt manifest (memory only)")
|
| 767 |
subparsers = parser.add_subparsers(dest="command", help="Subcommands")
|
| 768 |
subparsers.add_parser("memory", help="Generate 64KB memory circuits")
|
|
|
|
| 769 |
subparsers.add_parser("inputs", help="Add .inputs metadata tensors")
|
| 770 |
-
subparsers.add_parser("all", help="Run memory then inputs")
|
| 771 |
args = parser.parse_args()
|
| 772 |
if args.command == "memory":
|
| 773 |
cmd_memory(args)
|
|
|
|
|
|
|
| 774 |
elif args.command == "inputs":
|
| 775 |
cmd_inputs(args)
|
| 776 |
elif args.command == "all":
|
|
|
|
| 227 |
add_gate(tensors, f"control.mem_addr.bit{bit}", [1.0], [-1.0])
|
| 228 |
|
| 229 |
|
| 230 |
+
def add_shl_shr(tensors: Dict[str, torch.Tensor]) -> None:
|
| 231 |
+
"""Add SHL (shift left) and SHR (shift right) circuits.
|
| 232 |
+
|
| 233 |
+
Identity gate: w=2, b=-1 -> H(x*2 - 1) = x for x in {0,1}
|
| 234 |
+
Zero gate: w=0, b=-1 -> H(-1) = 0
|
| 235 |
+
|
| 236 |
+
SHL (MSB-first): out[i] = in[i+1] for i<7, out[7] = 0
|
| 237 |
+
SHR (MSB-first): out[0] = 0, out[i] = in[i-1] for i>0
|
| 238 |
+
"""
|
| 239 |
+
for bit in range(8):
|
| 240 |
+
if bit < 7:
|
| 241 |
+
add_gate(tensors, f"alu.alu8bit.shl.bit{bit}", [2.0], [-1.0])
|
| 242 |
+
else:
|
| 243 |
+
add_gate(tensors, f"alu.alu8bit.shl.bit{bit}", [0.0], [-1.0])
|
| 244 |
+
|
| 245 |
+
for bit in range(8):
|
| 246 |
+
if bit > 0:
|
| 247 |
+
add_gate(tensors, f"alu.alu8bit.shr.bit{bit}", [2.0], [-1.0])
|
| 248 |
+
else:
|
| 249 |
+
add_gate(tensors, f"alu.alu8bit.shr.bit{bit}", [0.0], [-1.0])
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def add_mul(tensors: Dict[str, torch.Tensor]) -> None:
|
| 253 |
+
"""Add 8-bit multiplication circuit.
|
| 254 |
+
|
| 255 |
+
Produces low 8 bits of the 16-bit result.
|
| 256 |
+
|
| 257 |
+
Structure:
|
| 258 |
+
- 64 AND gates for partial products P[i][j] = A[i] AND B[j]
|
| 259 |
+
- Uses existing ripple-carry adder components for summation
|
| 260 |
+
|
| 261 |
+
The multiply method in ThresholdALU computes:
|
| 262 |
+
1. Partial products via these AND gates
|
| 263 |
+
2. Shift-add accumulation via existing 8-bit adder
|
| 264 |
+
"""
|
| 265 |
+
# AND gates for partial products: P[i][j] = A[i] AND B[j]
|
| 266 |
+
# These compute whether bit i of A and bit j of B are both 1
|
| 267 |
+
for i in range(8):
|
| 268 |
+
for j in range(8):
|
| 269 |
+
add_gate(tensors, f"alu.alu8bit.mul.pp.a{i}b{j}", [1.0, 1.0], [-2.0])
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def add_div(tensors: Dict[str, torch.Tensor]) -> None:
|
| 273 |
+
"""Add 8-bit division circuit.
|
| 274 |
+
|
| 275 |
+
Produces quotient (8 bits) and remainder (8 bits).
|
| 276 |
+
|
| 277 |
+
Uses restoring division algorithm:
|
| 278 |
+
- 8 iterations, each producing one quotient bit
|
| 279 |
+
- Each iteration: compare, conditionally subtract, shift
|
| 280 |
+
|
| 281 |
+
Structure:
|
| 282 |
+
- 8 comparison gates (one per iteration)
|
| 283 |
+
- 8 conditional subtraction stages
|
| 284 |
+
- Uses existing comparator and subtractor components
|
| 285 |
+
"""
|
| 286 |
+
# Comparison gates: check if (remainder << 1 | next_bit) >= divisor
|
| 287 |
+
for stage in range(8):
|
| 288 |
+
add_gate(tensors, f"alu.alu8bit.div.stage{stage}.cmp",
|
| 289 |
+
[128.0, 64.0, 32.0, 16.0, 8.0, 4.0, 2.0, 1.0,
|
| 290 |
+
-128.0, -64.0, -32.0, -16.0, -8.0, -4.0, -2.0, -1.0], [0.0])
|
| 291 |
+
|
| 292 |
+
# Conditional mux gates: select (rem - div) or rem based on comparison
|
| 293 |
+
for stage in range(8):
|
| 294 |
+
for bit in range(8):
|
| 295 |
+
# NOT for inverting comparison result
|
| 296 |
+
add_gate(tensors, f"alu.alu8bit.div.stage{stage}.mux.bit{bit}.not_sel", [-1.0], [0.0])
|
| 297 |
+
# AND gates for mux
|
| 298 |
+
add_gate(tensors, f"alu.alu8bit.div.stage{stage}.mux.bit{bit}.and_a", [1.0, 1.0], [-2.0])
|
| 299 |
+
add_gate(tensors, f"alu.alu8bit.div.stage{stage}.mux.bit{bit}.and_b", [1.0, 1.0], [-2.0])
|
| 300 |
+
# OR gate for mux output
|
| 301 |
+
add_gate(tensors, f"alu.alu8bit.div.stage{stage}.mux.bit{bit}.or", [1.0, 1.0], [-1.0])
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def add_comparators(tensors: Dict[str, torch.Tensor]) -> None:
|
| 305 |
+
"""Add 8-bit comparator circuits (GT, LT, GE, LE, EQ).
|
| 306 |
+
|
| 307 |
+
Each comparator takes 16 inputs (8 bits from A, 8 bits from B) in MSB-first order.
|
| 308 |
+
Uses weighted sum comparison on the binary representation.
|
| 309 |
+
|
| 310 |
+
For unsigned comparison of A vs B:
|
| 311 |
+
- Assign positional weights: bit i has weight 2^(7-i)
|
| 312 |
+
- A > B: sum(a_i * w_i) > sum(b_i * w_i)
|
| 313 |
+
- This becomes: sum(a_i * w_i - b_i * w_i) > 0
|
| 314 |
+
- Or: sum((a_i - b_i) * w_i) > 0
|
| 315 |
+
|
| 316 |
+
Threshold gate: H(sum(x_i * w_i) + b) = 1 if sum >= -b
|
| 317 |
+
|
| 318 |
+
For A > B: weights = [128, 64, 32, 16, 8, 4, 2, 1, -128, -64, -32, -16, -8, -4, -2, -1]
|
| 319 |
+
bias = -1 (strictly greater, so need sum >= 1)
|
| 320 |
+
For A >= B: bias = 0 (sum >= 0)
|
| 321 |
+
For A < B: flip weights, bias = -1
|
| 322 |
+
For A <= B: flip weights, bias = 0
|
| 323 |
+
For A == B: need A >= B AND A <= B (two-layer)
|
| 324 |
+
"""
|
| 325 |
+
pos_weights = [128.0, 64.0, 32.0, 16.0, 8.0, 4.0, 2.0, 1.0]
|
| 326 |
+
neg_weights = [-128.0, -64.0, -32.0, -16.0, -8.0, -4.0, -2.0, -1.0]
|
| 327 |
+
|
| 328 |
+
gt_weights = pos_weights + neg_weights
|
| 329 |
+
lt_weights = neg_weights + pos_weights
|
| 330 |
+
|
| 331 |
+
add_gate(tensors, "arithmetic.greaterthan8bit", gt_weights, [-1.0])
|
| 332 |
+
add_gate(tensors, "arithmetic.greaterorequal8bit", gt_weights, [0.0])
|
| 333 |
+
add_gate(tensors, "arithmetic.lessthan8bit", lt_weights, [-1.0])
|
| 334 |
+
add_gate(tensors, "arithmetic.lessorequal8bit", lt_weights, [0.0])
|
| 335 |
+
|
| 336 |
+
add_gate(tensors, "arithmetic.equality8bit.layer1.geq", gt_weights, [0.0])
|
| 337 |
+
add_gate(tensors, "arithmetic.equality8bit.layer1.leq", lt_weights, [0.0])
|
| 338 |
+
add_gate(tensors, "arithmetic.equality8bit.layer2", [1.0, 1.0], [-2.0])
|
| 339 |
+
|
| 340 |
+
|
| 341 |
def update_manifest(tensors: Dict[str, torch.Tensor]) -> None:
|
| 342 |
tensors["manifest.memory_bytes"] = torch.tensor([float(MEM_BYTES)], dtype=torch.float32)
|
| 343 |
tensors["manifest.pc_width"] = torch.tensor([float(ADDR_BITS)], dtype=torch.float32)
|
|
|
|
| 604 |
return [reg.get_id(f"$opcode[{i}]") for i in range(4)]
|
| 605 |
if 'aluflags' in gate:
|
| 606 |
return [reg.register("$result"), reg.register("$carry"), reg.register("$overflow")]
|
| 607 |
+
if '.shl.bit' in gate:
|
| 608 |
+
m = re.search(r'bit(\d+)', gate)
|
| 609 |
+
if m:
|
| 610 |
+
bit = int(m.group(1))
|
| 611 |
+
if bit < 7:
|
| 612 |
+
return [reg.get_id(f"$a[{bit + 1}]")]
|
| 613 |
+
else:
|
| 614 |
+
return [reg.get_id("#0")]
|
| 615 |
+
return [reg.get_id(f"$a[{i}]") for i in range(8)]
|
| 616 |
+
if '.shr.bit' in gate:
|
| 617 |
+
m = re.search(r'bit(\d+)', gate)
|
| 618 |
+
if m:
|
| 619 |
+
bit = int(m.group(1))
|
| 620 |
+
if bit > 0:
|
| 621 |
+
return [reg.get_id(f"$a[{bit - 1}]")]
|
| 622 |
+
else:
|
| 623 |
+
return [reg.get_id("#0")]
|
| 624 |
+
return [reg.get_id(f"$a[{i}]") for i in range(8)]
|
| 625 |
+
if '.mul.pp.a' in gate:
|
| 626 |
+
m = re.search(r'a(\d+)b(\d+)', gate)
|
| 627 |
+
if m:
|
| 628 |
+
i, j = int(m.group(1)), int(m.group(2))
|
| 629 |
+
return [reg.get_id(f"$a[{i}]"), reg.get_id(f"$b[{j}]")]
|
| 630 |
+
return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)]
|
| 631 |
+
if '.mul.' in gate:
|
| 632 |
+
return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)]
|
| 633 |
+
if '.div.stage' in gate:
|
| 634 |
+
if '.cmp' in gate:
|
| 635 |
+
return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)]
|
| 636 |
+
if '.mux.bit' in gate:
|
| 637 |
+
m = re.search(r'stage(\d+)\.mux\.bit(\d+)', gate)
|
| 638 |
+
if m:
|
| 639 |
+
stage, bit = int(m.group(1)), int(m.group(2))
|
| 640 |
+
prefix = f"alu.alu8bit.div.stage{stage}"
|
| 641 |
+
if '.not_sel' in gate:
|
| 642 |
+
return [reg.register(f"{prefix}.cmp")]
|
| 643 |
+
if '.and_a' in gate:
|
| 644 |
+
return [reg.register(f"$rem[{bit}]"), reg.register(f"{prefix}.mux.bit{bit}.not_sel")]
|
| 645 |
+
if '.and_b' in gate:
|
| 646 |
+
return [reg.register(f"$sub[{bit}]"), reg.register(f"{prefix}.cmp")]
|
| 647 |
+
if '.or' in gate:
|
| 648 |
+
return [reg.register(f"{prefix}.mux.bit{bit}.and_a"), reg.register(f"{prefix}.mux.bit{bit}.and_b")]
|
| 649 |
+
return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)]
|
| 650 |
if '.and' in gate or '.or' in gate or '.xor' in gate:
|
| 651 |
m = re.search(r'bit(\d+)', gate)
|
| 652 |
if m:
|
|
|
|
| 786 |
return infer_adcsbc_inputs(gate, "arithmetic.sbc8bit", True, reg)
|
| 787 |
if 'sub8bit' in gate:
|
| 788 |
return infer_sub8bit_inputs(gate, reg)
|
| 789 |
+
if any(cmp in gate for cmp in ['greaterthan8bit', 'lessthan8bit', 'greaterorequal8bit', 'lessorequal8bit']):
|
| 790 |
+
for i in range(8):
|
| 791 |
+
reg.register(f"$a[{i}]")
|
| 792 |
+
reg.register(f"$b[{i}]")
|
| 793 |
+
return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)]
|
| 794 |
+
if 'equality8bit' in gate:
|
| 795 |
+
for i in range(8):
|
| 796 |
+
reg.register(f"$a[{i}]")
|
| 797 |
+
reg.register(f"$b[{i}]")
|
| 798 |
+
if 'layer1' in gate:
|
| 799 |
+
return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)]
|
| 800 |
+
if 'layer2' in gate:
|
| 801 |
+
return [reg.register("arithmetic.equality8bit.layer1.geq"), reg.register("arithmetic.equality8bit.layer1.leq")]
|
| 802 |
+
return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)]
|
| 803 |
for i in range(8):
|
| 804 |
reg.register(f"$a[{i}]")
|
| 805 |
reg.register(f"$b[{i}]")
|
|
|
|
| 920 |
print("=" * 60)
|
| 921 |
|
| 922 |
|
| 923 |
+
def cmd_alu(args) -> None:
|
| 924 |
+
print("=" * 60)
|
| 925 |
+
print(" BUILD ALU CIRCUITS")
|
| 926 |
+
print("=" * 60)
|
| 927 |
+
print(f"\nLoading: {args.model}")
|
| 928 |
+
tensors = load_tensors(args.model)
|
| 929 |
+
print(f" Loaded {len(tensors)} tensors")
|
| 930 |
+
print("\nDropping existing ALU extension tensors...")
|
| 931 |
+
drop_prefixes(tensors, [
|
| 932 |
+
"alu.alu8bit.shl.", "alu.alu8bit.shr.",
|
| 933 |
+
"alu.alu8bit.mul.", "alu.alu8bit.div.",
|
| 934 |
+
"arithmetic.greaterthan8bit.", "arithmetic.lessthan8bit.",
|
| 935 |
+
"arithmetic.greaterorequal8bit.", "arithmetic.lessorequal8bit.",
|
| 936 |
+
"arithmetic.equality8bit.",
|
| 937 |
+
])
|
| 938 |
+
print(f" Now {len(tensors)} tensors")
|
| 939 |
+
print("\nGenerating SHL/SHR circuits...")
|
| 940 |
+
try:
|
| 941 |
+
add_shl_shr(tensors)
|
| 942 |
+
print(" Added SHL (8 gates), SHR (8 gates)")
|
| 943 |
+
except ValueError as e:
|
| 944 |
+
print(f" SHL/SHR already exist: {e}")
|
| 945 |
+
print("\nGenerating MUL circuit...")
|
| 946 |
+
try:
|
| 947 |
+
add_mul(tensors)
|
| 948 |
+
print(" Added MUL (64 partial product AND gates)")
|
| 949 |
+
except ValueError as e:
|
| 950 |
+
print(f" MUL already exists: {e}")
|
| 951 |
+
print("\nGenerating DIV circuit...")
|
| 952 |
+
try:
|
| 953 |
+
add_div(tensors)
|
| 954 |
+
print(" Added DIV (8 stages x comparison + mux)")
|
| 955 |
+
except ValueError as e:
|
| 956 |
+
print(f" DIV already exists: {e}")
|
| 957 |
+
print("\nGenerating comparator circuits...")
|
| 958 |
+
try:
|
| 959 |
+
add_comparators(tensors)
|
| 960 |
+
print(" Added GT, GE, LT, LE (single-layer), EQ (two-layer)")
|
| 961 |
+
except ValueError as e:
|
| 962 |
+
print(f" Comparators already exist: {e}")
|
| 963 |
+
if args.apply:
|
| 964 |
+
print(f"\nSaving: {args.model}")
|
| 965 |
+
save_file(tensors, str(args.model))
|
| 966 |
+
print(" Done.")
|
| 967 |
+
else:
|
| 968 |
+
print("\n[DRY-RUN] Use --apply to save.")
|
| 969 |
+
print(f"\nTotal: {len(tensors)} tensors")
|
| 970 |
+
print("=" * 60)
|
| 971 |
+
|
| 972 |
+
|
| 973 |
def cmd_all(args) -> None:
|
| 974 |
print("Running: memory")
|
| 975 |
cmd_memory(args)
|
| 976 |
+
print("\nRunning: alu")
|
| 977 |
+
cmd_alu(args)
|
| 978 |
print("\nRunning: inputs")
|
| 979 |
cmd_inputs(args)
|
| 980 |
|
|
|
|
| 986 |
parser.add_argument("--manifest", action="store_true", help="Write tensors.txt manifest (memory only)")
|
| 987 |
subparsers = parser.add_subparsers(dest="command", help="Subcommands")
|
| 988 |
subparsers.add_parser("memory", help="Generate 64KB memory circuits")
|
| 989 |
+
subparsers.add_parser("alu", help="Generate ALU extension circuits (SHL, SHR, comparators)")
|
| 990 |
subparsers.add_parser("inputs", help="Add .inputs metadata tensors")
|
| 991 |
+
subparsers.add_parser("all", help="Run memory, alu, then inputs")
|
| 992 |
args = parser.parse_args()
|
| 993 |
if args.command == "memory":
|
| 994 |
cmd_memory(args)
|
| 995 |
+
elif args.command == "alu":
|
| 996 |
+
cmd_alu(args)
|
| 997 |
elif args.command == "inputs":
|
| 998 |
cmd_inputs(args)
|
| 999 |
elif args.command == "all":
|
eval.py
CHANGED
|
@@ -588,6 +588,8 @@ class BatchedFitnessEvaluator:
|
|
| 588 |
]
|
| 589 |
|
| 590 |
for name, op in comparators:
|
|
|
|
|
|
|
| 591 |
try:
|
| 592 |
s, t = self._test_comparator(pop, name, op, debug)
|
| 593 |
scores += s
|
|
@@ -595,6 +597,53 @@ class BatchedFitnessEvaluator:
|
|
| 595 |
except KeyError:
|
| 596 |
pass # Circuit not present
|
| 597 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 598 |
return scores, total
|
| 599 |
|
| 600 |
# =========================================================================
|
|
@@ -1231,6 +1280,148 @@ class BatchedFitnessEvaluator:
|
|
| 1231 |
except (KeyError, RuntimeError):
|
| 1232 |
pass
|
| 1233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1234 |
return scores, total
|
| 1235 |
|
| 1236 |
# =========================================================================
|
|
|
|
| 588 |
]
|
| 589 |
|
| 590 |
for name, op in comparators:
|
| 591 |
+
if name == 'equality8bit':
|
| 592 |
+
continue # Handle separately as two-layer
|
| 593 |
try:
|
| 594 |
s, t = self._test_comparator(pop, name, op, debug)
|
| 595 |
scores += s
|
|
|
|
| 597 |
except KeyError:
|
| 598 |
pass # Circuit not present
|
| 599 |
|
| 600 |
+
# Two-layer equality circuit
|
| 601 |
+
try:
|
| 602 |
+
prefix = 'arithmetic.equality8bit'
|
| 603 |
+
expected = torch.tensor([1.0 if a.item() == b.item() else 0.0
|
| 604 |
+
for a, b in zip(self.comp_a, self.comp_b)],
|
| 605 |
+
device=self.device)
|
| 606 |
+
|
| 607 |
+
a_bits = torch.stack([((self.comp_a >> (7 - i)) & 1).float() for i in range(8)], dim=1)
|
| 608 |
+
b_bits = torch.stack([((self.comp_b >> (7 - i)) & 1).float() for i in range(8)], dim=1)
|
| 609 |
+
inputs = torch.cat([a_bits, b_bits], dim=1)
|
| 610 |
+
|
| 611 |
+
# Layer 1: geq and leq
|
| 612 |
+
w_geq = pop[f'{prefix}.layer1.geq.weight']
|
| 613 |
+
b_geq = pop[f'{prefix}.layer1.geq.bias']
|
| 614 |
+
w_leq = pop[f'{prefix}.layer1.leq.weight']
|
| 615 |
+
b_leq = pop[f'{prefix}.layer1.leq.bias']
|
| 616 |
+
|
| 617 |
+
h_geq = heaviside(inputs @ w_geq.view(pop_size, -1).T + b_geq.view(pop_size))
|
| 618 |
+
h_leq = heaviside(inputs @ w_leq.view(pop_size, -1).T + b_leq.view(pop_size))
|
| 619 |
+
hidden = torch.stack([h_geq, h_leq], dim=-1) # [num_tests, pop_size, 2]
|
| 620 |
+
|
| 621 |
+
# Layer 2: AND
|
| 622 |
+
w2 = pop[f'{prefix}.layer2.weight']
|
| 623 |
+
b2 = pop[f'{prefix}.layer2.bias']
|
| 624 |
+
out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size))
|
| 625 |
+
|
| 626 |
+
correct = (out == expected.unsqueeze(1)).float().sum(0)
|
| 627 |
+
|
| 628 |
+
failures = []
|
| 629 |
+
if pop_size == 1:
|
| 630 |
+
for i in range(len(self.comp_a)):
|
| 631 |
+
if out[i, 0].item() != expected[i].item():
|
| 632 |
+
failures.append((
|
| 633 |
+
[int(self.comp_a[i].item()), int(self.comp_b[i].item())],
|
| 634 |
+
expected[i].item(),
|
| 635 |
+
out[i, 0].item()
|
| 636 |
+
))
|
| 637 |
+
|
| 638 |
+
self._record(prefix, int(correct[0].item()), len(self.comp_a), failures)
|
| 639 |
+
if debug:
|
| 640 |
+
r = self.results[-1]
|
| 641 |
+
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
|
| 642 |
+
scores += correct
|
| 643 |
+
total += len(self.comp_a)
|
| 644 |
+
except KeyError:
|
| 645 |
+
pass
|
| 646 |
+
|
| 647 |
return scores, total
|
| 648 |
|
| 649 |
# =========================================================================
|
|
|
|
| 1280 |
except (KeyError, RuntimeError):
|
| 1281 |
pass
|
| 1282 |
|
| 1283 |
+
# SHL (shift left)
|
| 1284 |
+
try:
|
| 1285 |
+
op_scores = torch.zeros(pop_size, device=self.device)
|
| 1286 |
+
op_total = 0
|
| 1287 |
+
|
| 1288 |
+
for a_val, _ in test_vals:
|
| 1289 |
+
expected_val = (a_val << 1) & 0xFF
|
| 1290 |
+
a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)],
|
| 1291 |
+
device=self.device, dtype=torch.float32)
|
| 1292 |
+
out_bits = []
|
| 1293 |
+
for bit in range(8):
|
| 1294 |
+
w = pop[f'alu.alu8bit.shl.bit{bit}.weight'].view(pop_size)
|
| 1295 |
+
b = pop[f'alu.alu8bit.shl.bit{bit}.bias'].view(pop_size)
|
| 1296 |
+
if bit < 7:
|
| 1297 |
+
inp = a_bits[bit + 1].unsqueeze(0).expand(pop_size)
|
| 1298 |
+
else:
|
| 1299 |
+
inp = torch.zeros(pop_size, device=self.device)
|
| 1300 |
+
out = heaviside(inp * w + b)
|
| 1301 |
+
out_bits.append(out)
|
| 1302 |
+
out = torch.stack(out_bits, dim=-1) # [pop, 8]
|
| 1303 |
+
expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)],
|
| 1304 |
+
device=self.device, dtype=torch.float32)
|
| 1305 |
+
correct = (out == expected.unsqueeze(0)).float().sum(1)
|
| 1306 |
+
op_scores += correct
|
| 1307 |
+
op_total += 8
|
| 1308 |
+
|
| 1309 |
+
scores += op_scores
|
| 1310 |
+
total += op_total
|
| 1311 |
+
self._record('alu.alu8bit.shl', int(op_scores[0].item()), op_total, [])
|
| 1312 |
+
if debug:
|
| 1313 |
+
r = self.results[-1]
|
| 1314 |
+
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
|
| 1315 |
+
except (KeyError, RuntimeError) as e:
|
| 1316 |
+
if debug:
|
| 1317 |
+
print(f" alu.alu8bit.shl: SKIP ({e})")
|
| 1318 |
+
|
| 1319 |
+
# SHR (shift right)
|
| 1320 |
+
try:
|
| 1321 |
+
op_scores = torch.zeros(pop_size, device=self.device)
|
| 1322 |
+
op_total = 0
|
| 1323 |
+
|
| 1324 |
+
for a_val, _ in test_vals:
|
| 1325 |
+
expected_val = (a_val >> 1) & 0xFF
|
| 1326 |
+
a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)],
|
| 1327 |
+
device=self.device, dtype=torch.float32)
|
| 1328 |
+
out_bits = []
|
| 1329 |
+
for bit in range(8):
|
| 1330 |
+
w = pop[f'alu.alu8bit.shr.bit{bit}.weight'].view(pop_size)
|
| 1331 |
+
b = pop[f'alu.alu8bit.shr.bit{bit}.bias'].view(pop_size)
|
| 1332 |
+
if bit > 0:
|
| 1333 |
+
inp = a_bits[bit - 1].unsqueeze(0).expand(pop_size)
|
| 1334 |
+
else:
|
| 1335 |
+
inp = torch.zeros(pop_size, device=self.device)
|
| 1336 |
+
out = heaviside(inp * w + b)
|
| 1337 |
+
out_bits.append(out)
|
| 1338 |
+
out = torch.stack(out_bits, dim=-1) # [pop, 8]
|
| 1339 |
+
expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)],
|
| 1340 |
+
device=self.device, dtype=torch.float32)
|
| 1341 |
+
correct = (out == expected.unsqueeze(0)).float().sum(1)
|
| 1342 |
+
op_scores += correct
|
| 1343 |
+
op_total += 8
|
| 1344 |
+
|
| 1345 |
+
scores += op_scores
|
| 1346 |
+
total += op_total
|
| 1347 |
+
self._record('alu.alu8bit.shr', int(op_scores[0].item()), op_total, [])
|
| 1348 |
+
if debug:
|
| 1349 |
+
r = self.results[-1]
|
| 1350 |
+
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
|
| 1351 |
+
except (KeyError, RuntimeError) as e:
|
| 1352 |
+
if debug:
|
| 1353 |
+
print(f" alu.alu8bit.shr: SKIP ({e})")
|
| 1354 |
+
|
| 1355 |
+
# MUL (partial products only - just verify AND gates work)
|
| 1356 |
+
try:
|
| 1357 |
+
op_scores = torch.zeros(pop_size, device=self.device)
|
| 1358 |
+
op_total = 0
|
| 1359 |
+
|
| 1360 |
+
mul_tests = [(3, 4), (7, 8), (15, 17), (0, 255)]
|
| 1361 |
+
for a_val, b_val in mul_tests:
|
| 1362 |
+
a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)],
|
| 1363 |
+
device=self.device, dtype=torch.float32)
|
| 1364 |
+
b_bits = torch.tensor([((b_val >> (7 - i)) & 1) for i in range(8)],
|
| 1365 |
+
device=self.device, dtype=torch.float32)
|
| 1366 |
+
|
| 1367 |
+
# Test partial product AND gates
|
| 1368 |
+
for i in range(8):
|
| 1369 |
+
for j in range(8):
|
| 1370 |
+
w = pop[f'alu.alu8bit.mul.pp.a{i}b{j}.weight'].view(pop_size, 2)
|
| 1371 |
+
b = pop[f'alu.alu8bit.mul.pp.a{i}b{j}.bias'].view(pop_size)
|
| 1372 |
+
inp = torch.tensor([a_bits[i].item(), b_bits[j].item()], device=self.device)
|
| 1373 |
+
out = heaviside((inp * w).sum(-1) + b)
|
| 1374 |
+
expected = float(int(a_bits[i].item()) & int(b_bits[j].item()))
|
| 1375 |
+
correct = (out == expected).float()
|
| 1376 |
+
op_scores += correct
|
| 1377 |
+
op_total += 1
|
| 1378 |
+
|
| 1379 |
+
scores += op_scores
|
| 1380 |
+
total += op_total
|
| 1381 |
+
self._record('alu.alu8bit.mul', int(op_scores[0].item()), op_total, [])
|
| 1382 |
+
if debug:
|
| 1383 |
+
r = self.results[-1]
|
| 1384 |
+
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
|
| 1385 |
+
except (KeyError, RuntimeError) as e:
|
| 1386 |
+
if debug:
|
| 1387 |
+
print(f" alu.alu8bit.mul: SKIP ({e})")
|
| 1388 |
+
|
| 1389 |
+
# DIV (comparison gates only)
|
| 1390 |
+
try:
|
| 1391 |
+
op_scores = torch.zeros(pop_size, device=self.device)
|
| 1392 |
+
op_total = 0
|
| 1393 |
+
|
| 1394 |
+
div_tests = [(100, 10), (255, 17), (50, 7), (128, 16)]
|
| 1395 |
+
for a_val, b_val in div_tests:
|
| 1396 |
+
# Test each stage's comparison gate
|
| 1397 |
+
for stage in range(8):
|
| 1398 |
+
w = pop[f'alu.alu8bit.div.stage{stage}.cmp.weight'].view(pop_size, 16)
|
| 1399 |
+
b = pop[f'alu.alu8bit.div.stage{stage}.cmp.bias'].view(pop_size)
|
| 1400 |
+
|
| 1401 |
+
# Create test inputs (simplified: just test that gate exists and has correct shape)
|
| 1402 |
+
test_rem = (a_val >> (7 - stage)) & 0xFF
|
| 1403 |
+
rem_bits = torch.tensor([((test_rem >> (7 - i)) & 1) for i in range(8)],
|
| 1404 |
+
device=self.device, dtype=torch.float32)
|
| 1405 |
+
div_bits = torch.tensor([((b_val >> (7 - i)) & 1) for i in range(8)],
|
| 1406 |
+
device=self.device, dtype=torch.float32)
|
| 1407 |
+
inp = torch.cat([rem_bits, div_bits])
|
| 1408 |
+
|
| 1409 |
+
out = heaviside((inp * w).sum(-1) + b)
|
| 1410 |
+
expected = float(test_rem >= b_val)
|
| 1411 |
+
correct = (out == expected).float()
|
| 1412 |
+
op_scores += correct
|
| 1413 |
+
op_total += 1
|
| 1414 |
+
|
| 1415 |
+
scores += op_scores
|
| 1416 |
+
total += op_total
|
| 1417 |
+
self._record('alu.alu8bit.div', int(op_scores[0].item()), op_total, [])
|
| 1418 |
+
if debug:
|
| 1419 |
+
r = self.results[-1]
|
| 1420 |
+
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
|
| 1421 |
+
except (KeyError, RuntimeError) as e:
|
| 1422 |
+
if debug:
|
| 1423 |
+
print(f" alu.alu8bit.div: SKIP ({e})")
|
| 1424 |
+
|
| 1425 |
return scores, total
|
| 1426 |
|
| 1427 |
# =========================================================================
|
neural_computer.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:68c76f0ec6822e071d2532c4ca40a216d959d5344e617990371bbc856134c4a0
|
| 3 |
+
size 34342684
|
threshold_cpu.py
CHANGED
|
@@ -193,13 +193,16 @@ def ref_step(state: CPUState) -> CPUState:
|
|
| 193 |
elif opcode == 0x4:
|
| 194 |
result = a ^ b
|
| 195 |
elif opcode == 0x5:
|
| 196 |
-
|
| 197 |
elif opcode == 0x6:
|
| 198 |
-
|
| 199 |
elif opcode == 0x7:
|
| 200 |
-
|
| 201 |
elif opcode == 0x8:
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
| 203 |
elif opcode == 0x9:
|
| 204 |
result, carry, overflow = alu_sub(a, b)
|
| 205 |
write_result = False
|
|
@@ -431,6 +434,101 @@ class ThresholdALU:
|
|
| 431 |
|
| 432 |
return bits_to_int(out_bits)
|
| 433 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
|
| 435 |
class ThresholdCPU:
|
| 436 |
def __init__(self, model_path: str | Path = DEFAULT_MODEL_PATH, device: str = "cpu") -> None:
|
|
@@ -574,13 +672,13 @@ class ThresholdCPU:
|
|
| 574 |
elif opcode == 0x4:
|
| 575 |
result = self.alu.bitwise_xor(a, b)
|
| 576 |
elif opcode == 0x5:
|
| 577 |
-
|
| 578 |
elif opcode == 0x6:
|
| 579 |
-
|
| 580 |
elif opcode == 0x7:
|
| 581 |
-
|
| 582 |
elif opcode == 0x8:
|
| 583 |
-
|
| 584 |
elif opcode == 0x9:
|
| 585 |
result, carry, overflow = self.alu.sub(a, b)
|
| 586 |
write_result = False
|
|
|
|
| 193 |
elif opcode == 0x4:
|
| 194 |
result = a ^ b
|
| 195 |
elif opcode == 0x5:
|
| 196 |
+
result = (a << 1) & 0xFF
|
| 197 |
elif opcode == 0x6:
|
| 198 |
+
result = (a >> 1) & 0xFF
|
| 199 |
elif opcode == 0x7:
|
| 200 |
+
result = (a * b) & 0xFF
|
| 201 |
elif opcode == 0x8:
|
| 202 |
+
if b == 0:
|
| 203 |
+
result = 0xFF
|
| 204 |
+
else:
|
| 205 |
+
result = a // b
|
| 206 |
elif opcode == 0x9:
|
| 207 |
result, carry, overflow = alu_sub(a, b)
|
| 208 |
write_result = False
|
|
|
|
| 434 |
|
| 435 |
return bits_to_int(out_bits)
|
| 436 |
|
| 437 |
+
def shift_left(self, a: int) -> int:
|
| 438 |
+
a_bits = int_to_bits(a, REG_BITS)
|
| 439 |
+
out_bits = []
|
| 440 |
+
for bit in range(REG_BITS):
|
| 441 |
+
w = self.alu._get(f"alu.alu8bit.shl.bit{bit}.weight")
|
| 442 |
+
bias = self.alu._get(f"alu.alu8bit.shl.bit{bit}.bias")
|
| 443 |
+
if bit < 7:
|
| 444 |
+
inp = torch.tensor([float(a_bits[bit + 1])], device=self.device)
|
| 445 |
+
else:
|
| 446 |
+
inp = torch.tensor([0.0], device=self.device)
|
| 447 |
+
out = heaviside((inp * w).sum() + bias).item()
|
| 448 |
+
out_bits.append(int(out))
|
| 449 |
+
return bits_to_int(out_bits)
|
| 450 |
+
|
| 451 |
+
def shift_right(self, a: int) -> int:
|
| 452 |
+
a_bits = int_to_bits(a, REG_BITS)
|
| 453 |
+
out_bits = []
|
| 454 |
+
for bit in range(REG_BITS):
|
| 455 |
+
w = self.alu._get(f"alu.alu8bit.shr.bit{bit}.weight")
|
| 456 |
+
bias = self.alu._get(f"alu.alu8bit.shr.bit{bit}.bias")
|
| 457 |
+
if bit > 0:
|
| 458 |
+
inp = torch.tensor([float(a_bits[bit - 1])], device=self.device)
|
| 459 |
+
else:
|
| 460 |
+
inp = torch.tensor([0.0], device=self.device)
|
| 461 |
+
out = heaviside((inp * w).sum() + bias).item()
|
| 462 |
+
out_bits.append(int(out))
|
| 463 |
+
return bits_to_int(out_bits)
|
| 464 |
+
|
| 465 |
+
def multiply(self, a: int, b: int) -> int:
|
| 466 |
+
"""8-bit multiply using partial product AND gates + shift-add."""
|
| 467 |
+
a_bits = int_to_bits(a, REG_BITS)
|
| 468 |
+
b_bits = int_to_bits(b, REG_BITS)
|
| 469 |
+
|
| 470 |
+
# Compute all 64 partial products using AND gates
|
| 471 |
+
pp = [[0] * 8 for _ in range(8)]
|
| 472 |
+
for i in range(8):
|
| 473 |
+
for j in range(8):
|
| 474 |
+
w = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.weight")
|
| 475 |
+
bias = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.bias")
|
| 476 |
+
inp = torch.tensor([float(a_bits[i]), float(b_bits[j])], device=self.device)
|
| 477 |
+
pp[i][j] = int(heaviside((inp * w).sum() + bias).item())
|
| 478 |
+
|
| 479 |
+
# Shift-add accumulation using existing 8-bit adder
|
| 480 |
+
# Row j contributes A*B[j] shifted left by (7-j) positions
|
| 481 |
+
result = 0
|
| 482 |
+
for j in range(8):
|
| 483 |
+
if b_bits[j] == 0:
|
| 484 |
+
continue
|
| 485 |
+
# Construct the partial product row (A masked by B[j])
|
| 486 |
+
row = 0
|
| 487 |
+
for i in range(8):
|
| 488 |
+
row |= (pp[i][j] << (7 - i))
|
| 489 |
+
# Shift by position (7-j means B[7] is LSB, B[0] is MSB)
|
| 490 |
+
shifted = row << (7 - j)
|
| 491 |
+
# Add to result using threshold adder
|
| 492 |
+
result, _, _ = self.add(result & 0xFF, shifted & 0xFF)
|
| 493 |
+
# Handle overflow into high byte
|
| 494 |
+
if shifted > 255 or result > 255:
|
| 495 |
+
result = (result + (shifted >> 8)) & 0xFF
|
| 496 |
+
|
| 497 |
+
return result & 0xFF
|
| 498 |
+
|
| 499 |
+
def divide(self, a: int, b: int) -> Tuple[int, int]:
|
| 500 |
+
"""8-bit divide using restoring division with threshold gates."""
|
| 501 |
+
if b == 0:
|
| 502 |
+
return 0xFF, a # Division by zero: return max quotient, original dividend
|
| 503 |
+
|
| 504 |
+
a_bits = int_to_bits(a, REG_BITS)
|
| 505 |
+
|
| 506 |
+
quotient = 0
|
| 507 |
+
remainder = 0
|
| 508 |
+
|
| 509 |
+
for stage in range(8):
|
| 510 |
+
# Shift remainder left and bring in next dividend bit
|
| 511 |
+
remainder = ((remainder << 1) | a_bits[stage]) & 0xFF
|
| 512 |
+
|
| 513 |
+
# Compare remainder >= divisor using threshold gate
|
| 514 |
+
rem_bits = int_to_bits(remainder, REG_BITS)
|
| 515 |
+
div_bits = int_to_bits(b, REG_BITS)
|
| 516 |
+
|
| 517 |
+
w = self._get(f"alu.alu8bit.div.stage{stage}.cmp.weight")
|
| 518 |
+
bias = self._get(f"alu.alu8bit.div.stage{stage}.cmp.bias")
|
| 519 |
+
inp = torch.tensor([float(rem_bits[i]) for i in range(8)] +
|
| 520 |
+
[float(div_bits[i]) for i in range(8)], device=self.device)
|
| 521 |
+
cmp_result = int(heaviside((inp * w).sum() + bias).item())
|
| 522 |
+
|
| 523 |
+
# If remainder >= divisor, subtract and set quotient bit
|
| 524 |
+
if cmp_result:
|
| 525 |
+
remainder, _, _ = self.sub(remainder, b)
|
| 526 |
+
quotient = (quotient << 1) | 1
|
| 527 |
+
else:
|
| 528 |
+
quotient = quotient << 1
|
| 529 |
+
|
| 530 |
+
return quotient & 0xFF, remainder & 0xFF
|
| 531 |
+
|
| 532 |
|
| 533 |
class ThresholdCPU:
|
| 534 |
def __init__(self, model_path: str | Path = DEFAULT_MODEL_PATH, device: str = "cpu") -> None:
|
|
|
|
| 672 |
elif opcode == 0x4:
|
| 673 |
result = self.alu.bitwise_xor(a, b)
|
| 674 |
elif opcode == 0x5:
|
| 675 |
+
result = self.alu.shift_left(a)
|
| 676 |
elif opcode == 0x6:
|
| 677 |
+
result = self.alu.shift_right(a)
|
| 678 |
elif opcode == 0x7:
|
| 679 |
+
result = self.alu.multiply(a, b)
|
| 680 |
elif opcode == 0x8:
|
| 681 |
+
result, _ = self.alu.divide(a, b)
|
| 682 |
elif opcode == 0x9:
|
| 683 |
result, carry, overflow = self.alu.sub(a, b)
|
| 684 |
write_result = False
|