CharlesCNorton commited on
Commit
6087b2e
·
1 Parent(s): 7c967c0

Add SHL, SHR, MUL, DIV, and comparator circuits

Browse files

- SHL/SHR: 8 identity gates each with zero injection
- Comparators: GT, LT, GE, LE (single-layer), EQ (two-layer AND)
- MUL: 64 partial product AND gates
- DIV: 8 stages with comparison + conditional mux

build.py: Added add_shl_shr(), add_mul(), add_div(), add_comparators(),
cmd_alu subcommand, input inference for new circuits

eval.py: Added tests for all new circuits (5282 -> 5884 tests)

threshold_cpu.py: Added shift_left(), shift_right(), multiply(), divide()
to ThresholdALU; updated ref_step() and step()

Tensors: 9429 -> 10399, Fitness: 1.000000

Files changed (4) hide show
  1. build.py +224 -1
  2. eval.py +191 -0
  3. neural_computer.safetensors +2 -2
  4. threshold_cpu.py +106 -8
build.py CHANGED
@@ -227,6 +227,117 @@ def add_fetch_load_store_buffers(tensors: Dict[str, torch.Tensor]) -> None:
227
  add_gate(tensors, f"control.mem_addr.bit{bit}", [1.0], [-1.0])
228
 
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  def update_manifest(tensors: Dict[str, torch.Tensor]) -> None:
231
  tensors["manifest.memory_bytes"] = torch.tensor([float(MEM_BYTES)], dtype=torch.float32)
232
  tensors["manifest.pc_width"] = torch.tensor([float(ADDR_BITS)], dtype=torch.float32)
@@ -493,6 +604,49 @@ def infer_alu_inputs(gate: str, reg: SignalRegistry) -> List[int]:
493
  return [reg.get_id(f"$opcode[{i}]") for i in range(4)]
494
  if 'aluflags' in gate:
495
  return [reg.register("$result"), reg.register("$carry"), reg.register("$overflow")]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
  if '.and' in gate or '.or' in gate or '.xor' in gate:
497
  m = re.search(r'bit(\d+)', gate)
498
  if m:
@@ -632,6 +786,20 @@ def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, tor
632
  return infer_adcsbc_inputs(gate, "arithmetic.sbc8bit", True, reg)
633
  if 'sub8bit' in gate:
634
  return infer_sub8bit_inputs(gate, reg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
635
  for i in range(8):
636
  reg.register(f"$a[{i}]")
637
  reg.register(f"$b[{i}]")
@@ -752,9 +920,61 @@ def cmd_inputs(args) -> None:
752
  print("=" * 60)
753
 
754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755
  def cmd_all(args) -> None:
756
  print("Running: memory")
757
  cmd_memory(args)
 
 
758
  print("\nRunning: inputs")
759
  cmd_inputs(args)
760
 
@@ -766,11 +986,14 @@ def main() -> None:
766
  parser.add_argument("--manifest", action="store_true", help="Write tensors.txt manifest (memory only)")
767
  subparsers = parser.add_subparsers(dest="command", help="Subcommands")
768
  subparsers.add_parser("memory", help="Generate 64KB memory circuits")
 
769
  subparsers.add_parser("inputs", help="Add .inputs metadata tensors")
770
- subparsers.add_parser("all", help="Run memory then inputs")
771
  args = parser.parse_args()
772
  if args.command == "memory":
773
  cmd_memory(args)
 
 
774
  elif args.command == "inputs":
775
  cmd_inputs(args)
776
  elif args.command == "all":
 
227
  add_gate(tensors, f"control.mem_addr.bit{bit}", [1.0], [-1.0])
228
 
229
 
230
+ def add_shl_shr(tensors: Dict[str, torch.Tensor]) -> None:
231
+ """Add SHL (shift left) and SHR (shift right) circuits.
232
+
233
+ Identity gate: w=2, b=-1 -> H(x*2 - 1) = x for x in {0,1}
234
+ Zero gate: w=0, b=-1 -> H(-1) = 0
235
+
236
+ SHL (MSB-first): out[i] = in[i+1] for i<7, out[7] = 0
237
+ SHR (MSB-first): out[0] = 0, out[i] = in[i-1] for i>0
238
+ """
239
+ for bit in range(8):
240
+ if bit < 7:
241
+ add_gate(tensors, f"alu.alu8bit.shl.bit{bit}", [2.0], [-1.0])
242
+ else:
243
+ add_gate(tensors, f"alu.alu8bit.shl.bit{bit}", [0.0], [-1.0])
244
+
245
+ for bit in range(8):
246
+ if bit > 0:
247
+ add_gate(tensors, f"alu.alu8bit.shr.bit{bit}", [2.0], [-1.0])
248
+ else:
249
+ add_gate(tensors, f"alu.alu8bit.shr.bit{bit}", [0.0], [-1.0])
250
+
251
+
252
+ def add_mul(tensors: Dict[str, torch.Tensor]) -> None:
253
+ """Add 8-bit multiplication circuit.
254
+
255
+ Produces low 8 bits of the 16-bit result.
256
+
257
+ Structure:
258
+ - 64 AND gates for partial products P[i][j] = A[i] AND B[j]
259
+ - Uses existing ripple-carry adder components for summation
260
+
261
+ The multiply method in ThresholdALU computes:
262
+ 1. Partial products via these AND gates
263
+ 2. Shift-add accumulation via existing 8-bit adder
264
+ """
265
+ # AND gates for partial products: P[i][j] = A[i] AND B[j]
266
+ # These compute whether bit i of A and bit j of B are both 1
267
+ for i in range(8):
268
+ for j in range(8):
269
+ add_gate(tensors, f"alu.alu8bit.mul.pp.a{i}b{j}", [1.0, 1.0], [-2.0])
270
+
271
+
272
+ def add_div(tensors: Dict[str, torch.Tensor]) -> None:
273
+ """Add 8-bit division circuit.
274
+
275
+ Produces quotient (8 bits) and remainder (8 bits).
276
+
277
+ Uses restoring division algorithm:
278
+ - 8 iterations, each producing one quotient bit
279
+ - Each iteration: compare, conditionally subtract, shift
280
+
281
+ Structure:
282
+ - 8 comparison gates (one per iteration)
283
+ - 8 conditional subtraction stages
284
+ - Uses existing comparator and subtractor components
285
+ """
286
+ # Comparison gates: check if (remainder << 1 | next_bit) >= divisor
287
+ for stage in range(8):
288
+ add_gate(tensors, f"alu.alu8bit.div.stage{stage}.cmp",
289
+ [128.0, 64.0, 32.0, 16.0, 8.0, 4.0, 2.0, 1.0,
290
+ -128.0, -64.0, -32.0, -16.0, -8.0, -4.0, -2.0, -1.0], [0.0])
291
+
292
+ # Conditional mux gates: select (rem - div) or rem based on comparison
293
+ for stage in range(8):
294
+ for bit in range(8):
295
+ # NOT for inverting comparison result
296
+ add_gate(tensors, f"alu.alu8bit.div.stage{stage}.mux.bit{bit}.not_sel", [-1.0], [0.0])
297
+ # AND gates for mux
298
+ add_gate(tensors, f"alu.alu8bit.div.stage{stage}.mux.bit{bit}.and_a", [1.0, 1.0], [-2.0])
299
+ add_gate(tensors, f"alu.alu8bit.div.stage{stage}.mux.bit{bit}.and_b", [1.0, 1.0], [-2.0])
300
+ # OR gate for mux output
301
+ add_gate(tensors, f"alu.alu8bit.div.stage{stage}.mux.bit{bit}.or", [1.0, 1.0], [-1.0])
302
+
303
+
304
+ def add_comparators(tensors: Dict[str, torch.Tensor]) -> None:
305
+ """Add 8-bit comparator circuits (GT, LT, GE, LE, EQ).
306
+
307
+ Each comparator takes 16 inputs (8 bits from A, 8 bits from B) in MSB-first order.
308
+ Uses weighted sum comparison on the binary representation.
309
+
310
+ For unsigned comparison of A vs B:
311
+ - Assign positional weights: bit i has weight 2^(7-i)
312
+ - A > B: sum(a_i * w_i) > sum(b_i * w_i)
313
+ - This becomes: sum(a_i * w_i - b_i * w_i) > 0
314
+ - Or: sum((a_i - b_i) * w_i) > 0
315
+
316
+ Threshold gate: H(sum(x_i * w_i) + b) = 1 if sum >= -b
317
+
318
+ For A > B: weights = [128, 64, 32, 16, 8, 4, 2, 1, -128, -64, -32, -16, -8, -4, -2, -1]
319
+ bias = -1 (strictly greater, so need sum >= 1)
320
+ For A >= B: bias = 0 (sum >= 0)
321
+ For A < B: flip weights, bias = -1
322
+ For A <= B: flip weights, bias = 0
323
+ For A == B: need A >= B AND A <= B (two-layer)
324
+ """
325
+ pos_weights = [128.0, 64.0, 32.0, 16.0, 8.0, 4.0, 2.0, 1.0]
326
+ neg_weights = [-128.0, -64.0, -32.0, -16.0, -8.0, -4.0, -2.0, -1.0]
327
+
328
+ gt_weights = pos_weights + neg_weights
329
+ lt_weights = neg_weights + pos_weights
330
+
331
+ add_gate(tensors, "arithmetic.greaterthan8bit", gt_weights, [-1.0])
332
+ add_gate(tensors, "arithmetic.greaterorequal8bit", gt_weights, [0.0])
333
+ add_gate(tensors, "arithmetic.lessthan8bit", lt_weights, [-1.0])
334
+ add_gate(tensors, "arithmetic.lessorequal8bit", lt_weights, [0.0])
335
+
336
+ add_gate(tensors, "arithmetic.equality8bit.layer1.geq", gt_weights, [0.0])
337
+ add_gate(tensors, "arithmetic.equality8bit.layer1.leq", lt_weights, [0.0])
338
+ add_gate(tensors, "arithmetic.equality8bit.layer2", [1.0, 1.0], [-2.0])
339
+
340
+
341
  def update_manifest(tensors: Dict[str, torch.Tensor]) -> None:
342
  tensors["manifest.memory_bytes"] = torch.tensor([float(MEM_BYTES)], dtype=torch.float32)
343
  tensors["manifest.pc_width"] = torch.tensor([float(ADDR_BITS)], dtype=torch.float32)
 
604
  return [reg.get_id(f"$opcode[{i}]") for i in range(4)]
605
  if 'aluflags' in gate:
606
  return [reg.register("$result"), reg.register("$carry"), reg.register("$overflow")]
607
+ if '.shl.bit' in gate:
608
+ m = re.search(r'bit(\d+)', gate)
609
+ if m:
610
+ bit = int(m.group(1))
611
+ if bit < 7:
612
+ return [reg.get_id(f"$a[{bit + 1}]")]
613
+ else:
614
+ return [reg.get_id("#0")]
615
+ return [reg.get_id(f"$a[{i}]") for i in range(8)]
616
+ if '.shr.bit' in gate:
617
+ m = re.search(r'bit(\d+)', gate)
618
+ if m:
619
+ bit = int(m.group(1))
620
+ if bit > 0:
621
+ return [reg.get_id(f"$a[{bit - 1}]")]
622
+ else:
623
+ return [reg.get_id("#0")]
624
+ return [reg.get_id(f"$a[{i}]") for i in range(8)]
625
+ if '.mul.pp.a' in gate:
626
+ m = re.search(r'a(\d+)b(\d+)', gate)
627
+ if m:
628
+ i, j = int(m.group(1)), int(m.group(2))
629
+ return [reg.get_id(f"$a[{i}]"), reg.get_id(f"$b[{j}]")]
630
+ return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)]
631
+ if '.mul.' in gate:
632
+ return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)]
633
+ if '.div.stage' in gate:
634
+ if '.cmp' in gate:
635
+ return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)]
636
+ if '.mux.bit' in gate:
637
+ m = re.search(r'stage(\d+)\.mux\.bit(\d+)', gate)
638
+ if m:
639
+ stage, bit = int(m.group(1)), int(m.group(2))
640
+ prefix = f"alu.alu8bit.div.stage{stage}"
641
+ if '.not_sel' in gate:
642
+ return [reg.register(f"{prefix}.cmp")]
643
+ if '.and_a' in gate:
644
+ return [reg.register(f"$rem[{bit}]"), reg.register(f"{prefix}.mux.bit{bit}.not_sel")]
645
+ if '.and_b' in gate:
646
+ return [reg.register(f"$sub[{bit}]"), reg.register(f"{prefix}.cmp")]
647
+ if '.or' in gate:
648
+ return [reg.register(f"{prefix}.mux.bit{bit}.and_a"), reg.register(f"{prefix}.mux.bit{bit}.and_b")]
649
+ return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)]
650
  if '.and' in gate or '.or' in gate or '.xor' in gate:
651
  m = re.search(r'bit(\d+)', gate)
652
  if m:
 
786
  return infer_adcsbc_inputs(gate, "arithmetic.sbc8bit", True, reg)
787
  if 'sub8bit' in gate:
788
  return infer_sub8bit_inputs(gate, reg)
789
+ if any(cmp in gate for cmp in ['greaterthan8bit', 'lessthan8bit', 'greaterorequal8bit', 'lessorequal8bit']):
790
+ for i in range(8):
791
+ reg.register(f"$a[{i}]")
792
+ reg.register(f"$b[{i}]")
793
+ return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)]
794
+ if 'equality8bit' in gate:
795
+ for i in range(8):
796
+ reg.register(f"$a[{i}]")
797
+ reg.register(f"$b[{i}]")
798
+ if 'layer1' in gate:
799
+ return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)]
800
+ if 'layer2' in gate:
801
+ return [reg.register("arithmetic.equality8bit.layer1.geq"), reg.register("arithmetic.equality8bit.layer1.leq")]
802
+ return [reg.get_id(f"$a[{i}]") for i in range(8)] + [reg.get_id(f"$b[{i}]") for i in range(8)]
803
  for i in range(8):
804
  reg.register(f"$a[{i}]")
805
  reg.register(f"$b[{i}]")
 
920
  print("=" * 60)
921
 
922
 
923
+ def cmd_alu(args) -> None:
924
+ print("=" * 60)
925
+ print(" BUILD ALU CIRCUITS")
926
+ print("=" * 60)
927
+ print(f"\nLoading: {args.model}")
928
+ tensors = load_tensors(args.model)
929
+ print(f" Loaded {len(tensors)} tensors")
930
+ print("\nDropping existing ALU extension tensors...")
931
+ drop_prefixes(tensors, [
932
+ "alu.alu8bit.shl.", "alu.alu8bit.shr.",
933
+ "alu.alu8bit.mul.", "alu.alu8bit.div.",
934
+ "arithmetic.greaterthan8bit.", "arithmetic.lessthan8bit.",
935
+ "arithmetic.greaterorequal8bit.", "arithmetic.lessorequal8bit.",
936
+ "arithmetic.equality8bit.",
937
+ ])
938
+ print(f" Now {len(tensors)} tensors")
939
+ print("\nGenerating SHL/SHR circuits...")
940
+ try:
941
+ add_shl_shr(tensors)
942
+ print(" Added SHL (8 gates), SHR (8 gates)")
943
+ except ValueError as e:
944
+ print(f" SHL/SHR already exist: {e}")
945
+ print("\nGenerating MUL circuit...")
946
+ try:
947
+ add_mul(tensors)
948
+ print(" Added MUL (64 partial product AND gates)")
949
+ except ValueError as e:
950
+ print(f" MUL already exists: {e}")
951
+ print("\nGenerating DIV circuit...")
952
+ try:
953
+ add_div(tensors)
954
+ print(" Added DIV (8 stages x comparison + mux)")
955
+ except ValueError as e:
956
+ print(f" DIV already exists: {e}")
957
+ print("\nGenerating comparator circuits...")
958
+ try:
959
+ add_comparators(tensors)
960
+ print(" Added GT, GE, LT, LE (single-layer), EQ (two-layer)")
961
+ except ValueError as e:
962
+ print(f" Comparators already exist: {e}")
963
+ if args.apply:
964
+ print(f"\nSaving: {args.model}")
965
+ save_file(tensors, str(args.model))
966
+ print(" Done.")
967
+ else:
968
+ print("\n[DRY-RUN] Use --apply to save.")
969
+ print(f"\nTotal: {len(tensors)} tensors")
970
+ print("=" * 60)
971
+
972
+
973
  def cmd_all(args) -> None:
974
  print("Running: memory")
975
  cmd_memory(args)
976
+ print("\nRunning: alu")
977
+ cmd_alu(args)
978
  print("\nRunning: inputs")
979
  cmd_inputs(args)
980
 
 
986
  parser.add_argument("--manifest", action="store_true", help="Write tensors.txt manifest (memory only)")
987
  subparsers = parser.add_subparsers(dest="command", help="Subcommands")
988
  subparsers.add_parser("memory", help="Generate 64KB memory circuits")
989
+ subparsers.add_parser("alu", help="Generate ALU extension circuits (SHL, SHR, comparators)")
990
  subparsers.add_parser("inputs", help="Add .inputs metadata tensors")
991
+ subparsers.add_parser("all", help="Run memory, alu, then inputs")
992
  args = parser.parse_args()
993
  if args.command == "memory":
994
  cmd_memory(args)
995
+ elif args.command == "alu":
996
+ cmd_alu(args)
997
  elif args.command == "inputs":
998
  cmd_inputs(args)
999
  elif args.command == "all":
eval.py CHANGED
@@ -588,6 +588,8 @@ class BatchedFitnessEvaluator:
588
  ]
589
 
590
  for name, op in comparators:
 
 
591
  try:
592
  s, t = self._test_comparator(pop, name, op, debug)
593
  scores += s
@@ -595,6 +597,53 @@ class BatchedFitnessEvaluator:
595
  except KeyError:
596
  pass # Circuit not present
597
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
598
  return scores, total
599
 
600
  # =========================================================================
@@ -1231,6 +1280,148 @@ class BatchedFitnessEvaluator:
1231
  except (KeyError, RuntimeError):
1232
  pass
1233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1234
  return scores, total
1235
 
1236
  # =========================================================================
 
588
  ]
589
 
590
  for name, op in comparators:
591
+ if name == 'equality8bit':
592
+ continue # Handle separately as two-layer
593
  try:
594
  s, t = self._test_comparator(pop, name, op, debug)
595
  scores += s
 
597
  except KeyError:
598
  pass # Circuit not present
599
 
600
+ # Two-layer equality circuit
601
+ try:
602
+ prefix = 'arithmetic.equality8bit'
603
+ expected = torch.tensor([1.0 if a.item() == b.item() else 0.0
604
+ for a, b in zip(self.comp_a, self.comp_b)],
605
+ device=self.device)
606
+
607
+ a_bits = torch.stack([((self.comp_a >> (7 - i)) & 1).float() for i in range(8)], dim=1)
608
+ b_bits = torch.stack([((self.comp_b >> (7 - i)) & 1).float() for i in range(8)], dim=1)
609
+ inputs = torch.cat([a_bits, b_bits], dim=1)
610
+
611
+ # Layer 1: geq and leq
612
+ w_geq = pop[f'{prefix}.layer1.geq.weight']
613
+ b_geq = pop[f'{prefix}.layer1.geq.bias']
614
+ w_leq = pop[f'{prefix}.layer1.leq.weight']
615
+ b_leq = pop[f'{prefix}.layer1.leq.bias']
616
+
617
+ h_geq = heaviside(inputs @ w_geq.view(pop_size, -1).T + b_geq.view(pop_size))
618
+ h_leq = heaviside(inputs @ w_leq.view(pop_size, -1).T + b_leq.view(pop_size))
619
+ hidden = torch.stack([h_geq, h_leq], dim=-1) # [num_tests, pop_size, 2]
620
+
621
+ # Layer 2: AND
622
+ w2 = pop[f'{prefix}.layer2.weight']
623
+ b2 = pop[f'{prefix}.layer2.bias']
624
+ out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size))
625
+
626
+ correct = (out == expected.unsqueeze(1)).float().sum(0)
627
+
628
+ failures = []
629
+ if pop_size == 1:
630
+ for i in range(len(self.comp_a)):
631
+ if out[i, 0].item() != expected[i].item():
632
+ failures.append((
633
+ [int(self.comp_a[i].item()), int(self.comp_b[i].item())],
634
+ expected[i].item(),
635
+ out[i, 0].item()
636
+ ))
637
+
638
+ self._record(prefix, int(correct[0].item()), len(self.comp_a), failures)
639
+ if debug:
640
+ r = self.results[-1]
641
+ print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
642
+ scores += correct
643
+ total += len(self.comp_a)
644
+ except KeyError:
645
+ pass
646
+
647
  return scores, total
648
 
649
  # =========================================================================
 
1280
  except (KeyError, RuntimeError):
1281
  pass
1282
 
1283
+ # SHL (shift left)
1284
+ try:
1285
+ op_scores = torch.zeros(pop_size, device=self.device)
1286
+ op_total = 0
1287
+
1288
+ for a_val, _ in test_vals:
1289
+ expected_val = (a_val << 1) & 0xFF
1290
+ a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)],
1291
+ device=self.device, dtype=torch.float32)
1292
+ out_bits = []
1293
+ for bit in range(8):
1294
+ w = pop[f'alu.alu8bit.shl.bit{bit}.weight'].view(pop_size)
1295
+ b = pop[f'alu.alu8bit.shl.bit{bit}.bias'].view(pop_size)
1296
+ if bit < 7:
1297
+ inp = a_bits[bit + 1].unsqueeze(0).expand(pop_size)
1298
+ else:
1299
+ inp = torch.zeros(pop_size, device=self.device)
1300
+ out = heaviside(inp * w + b)
1301
+ out_bits.append(out)
1302
+ out = torch.stack(out_bits, dim=-1) # [pop, 8]
1303
+ expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)],
1304
+ device=self.device, dtype=torch.float32)
1305
+ correct = (out == expected.unsqueeze(0)).float().sum(1)
1306
+ op_scores += correct
1307
+ op_total += 8
1308
+
1309
+ scores += op_scores
1310
+ total += op_total
1311
+ self._record('alu.alu8bit.shl', int(op_scores[0].item()), op_total, [])
1312
+ if debug:
1313
+ r = self.results[-1]
1314
+ print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
1315
+ except (KeyError, RuntimeError) as e:
1316
+ if debug:
1317
+ print(f" alu.alu8bit.shl: SKIP ({e})")
1318
+
1319
+ # SHR (shift right)
1320
+ try:
1321
+ op_scores = torch.zeros(pop_size, device=self.device)
1322
+ op_total = 0
1323
+
1324
+ for a_val, _ in test_vals:
1325
+ expected_val = (a_val >> 1) & 0xFF
1326
+ a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)],
1327
+ device=self.device, dtype=torch.float32)
1328
+ out_bits = []
1329
+ for bit in range(8):
1330
+ w = pop[f'alu.alu8bit.shr.bit{bit}.weight'].view(pop_size)
1331
+ b = pop[f'alu.alu8bit.shr.bit{bit}.bias'].view(pop_size)
1332
+ if bit > 0:
1333
+ inp = a_bits[bit - 1].unsqueeze(0).expand(pop_size)
1334
+ else:
1335
+ inp = torch.zeros(pop_size, device=self.device)
1336
+ out = heaviside(inp * w + b)
1337
+ out_bits.append(out)
1338
+ out = torch.stack(out_bits, dim=-1) # [pop, 8]
1339
+ expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)],
1340
+ device=self.device, dtype=torch.float32)
1341
+ correct = (out == expected.unsqueeze(0)).float().sum(1)
1342
+ op_scores += correct
1343
+ op_total += 8
1344
+
1345
+ scores += op_scores
1346
+ total += op_total
1347
+ self._record('alu.alu8bit.shr', int(op_scores[0].item()), op_total, [])
1348
+ if debug:
1349
+ r = self.results[-1]
1350
+ print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
1351
+ except (KeyError, RuntimeError) as e:
1352
+ if debug:
1353
+ print(f" alu.alu8bit.shr: SKIP ({e})")
1354
+
1355
+ # MUL (partial products only - just verify AND gates work)
1356
+ try:
1357
+ op_scores = torch.zeros(pop_size, device=self.device)
1358
+ op_total = 0
1359
+
1360
+ mul_tests = [(3, 4), (7, 8), (15, 17), (0, 255)]
1361
+ for a_val, b_val in mul_tests:
1362
+ a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)],
1363
+ device=self.device, dtype=torch.float32)
1364
+ b_bits = torch.tensor([((b_val >> (7 - i)) & 1) for i in range(8)],
1365
+ device=self.device, dtype=torch.float32)
1366
+
1367
+ # Test partial product AND gates
1368
+ for i in range(8):
1369
+ for j in range(8):
1370
+ w = pop[f'alu.alu8bit.mul.pp.a{i}b{j}.weight'].view(pop_size, 2)
1371
+ b = pop[f'alu.alu8bit.mul.pp.a{i}b{j}.bias'].view(pop_size)
1372
+ inp = torch.tensor([a_bits[i].item(), b_bits[j].item()], device=self.device)
1373
+ out = heaviside((inp * w).sum(-1) + b)
1374
+ expected = float(int(a_bits[i].item()) & int(b_bits[j].item()))
1375
+ correct = (out == expected).float()
1376
+ op_scores += correct
1377
+ op_total += 1
1378
+
1379
+ scores += op_scores
1380
+ total += op_total
1381
+ self._record('alu.alu8bit.mul', int(op_scores[0].item()), op_total, [])
1382
+ if debug:
1383
+ r = self.results[-1]
1384
+ print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
1385
+ except (KeyError, RuntimeError) as e:
1386
+ if debug:
1387
+ print(f" alu.alu8bit.mul: SKIP ({e})")
1388
+
1389
+ # DIV (comparison gates only)
1390
+ try:
1391
+ op_scores = torch.zeros(pop_size, device=self.device)
1392
+ op_total = 0
1393
+
1394
+ div_tests = [(100, 10), (255, 17), (50, 7), (128, 16)]
1395
+ for a_val, b_val in div_tests:
1396
+ # Test each stage's comparison gate
1397
+ for stage in range(8):
1398
+ w = pop[f'alu.alu8bit.div.stage{stage}.cmp.weight'].view(pop_size, 16)
1399
+ b = pop[f'alu.alu8bit.div.stage{stage}.cmp.bias'].view(pop_size)
1400
+
1401
+ # Create test inputs (simplified: just test that gate exists and has correct shape)
1402
+ test_rem = (a_val >> (7 - stage)) & 0xFF
1403
+ rem_bits = torch.tensor([((test_rem >> (7 - i)) & 1) for i in range(8)],
1404
+ device=self.device, dtype=torch.float32)
1405
+ div_bits = torch.tensor([((b_val >> (7 - i)) & 1) for i in range(8)],
1406
+ device=self.device, dtype=torch.float32)
1407
+ inp = torch.cat([rem_bits, div_bits])
1408
+
1409
+ out = heaviside((inp * w).sum(-1) + b)
1410
+ expected = float(test_rem >= b_val)
1411
+ correct = (out == expected).float()
1412
+ op_scores += correct
1413
+ op_total += 1
1414
+
1415
+ scores += op_scores
1416
+ total += op_total
1417
+ self._record('alu.alu8bit.div', int(op_scores[0].item()), op_total, [])
1418
+ if debug:
1419
+ r = self.results[-1]
1420
+ print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
1421
+ except (KeyError, RuntimeError) as e:
1422
+ if debug:
1423
+ print(f" alu.alu8bit.div: SKIP ({e})")
1424
+
1425
  return scores, total
1426
 
1427
  # =========================================================================
neural_computer.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:64bf038473b731ab149cfb74cf0f4aa65617b52d5f81f140c6ab3b763834f256
3
- size 34268956
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68c76f0ec6822e071d2532c4ca40a216d959d5344e617990371bbc856134c4a0
3
+ size 34342684
threshold_cpu.py CHANGED
@@ -193,13 +193,16 @@ def ref_step(state: CPUState) -> CPUState:
193
  elif opcode == 0x4:
194
  result = a ^ b
195
  elif opcode == 0x5:
196
- raise NotImplementedError("SHL: threshold circuit not implemented")
197
  elif opcode == 0x6:
198
- raise NotImplementedError("SHR: threshold circuit not implemented")
199
  elif opcode == 0x7:
200
- raise NotImplementedError("MUL: threshold circuit not implemented")
201
  elif opcode == 0x8:
202
- raise NotImplementedError("DIV: threshold circuit not implemented")
 
 
 
203
  elif opcode == 0x9:
204
  result, carry, overflow = alu_sub(a, b)
205
  write_result = False
@@ -431,6 +434,101 @@ class ThresholdALU:
431
 
432
  return bits_to_int(out_bits)
433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
 
435
  class ThresholdCPU:
436
  def __init__(self, model_path: str | Path = DEFAULT_MODEL_PATH, device: str = "cpu") -> None:
@@ -574,13 +672,13 @@ class ThresholdCPU:
574
  elif opcode == 0x4:
575
  result = self.alu.bitwise_xor(a, b)
576
  elif opcode == 0x5:
577
- raise NotImplementedError("SHL: threshold circuit not implemented")
578
  elif opcode == 0x6:
579
- raise NotImplementedError("SHR: threshold circuit not implemented")
580
  elif opcode == 0x7:
581
- raise NotImplementedError("MUL: threshold circuit not implemented")
582
  elif opcode == 0x8:
583
- raise NotImplementedError("DIV: threshold circuit not implemented")
584
  elif opcode == 0x9:
585
  result, carry, overflow = self.alu.sub(a, b)
586
  write_result = False
 
193
  elif opcode == 0x4:
194
  result = a ^ b
195
  elif opcode == 0x5:
196
+ result = (a << 1) & 0xFF
197
  elif opcode == 0x6:
198
+ result = (a >> 1) & 0xFF
199
  elif opcode == 0x7:
200
+ result = (a * b) & 0xFF
201
  elif opcode == 0x8:
202
+ if b == 0:
203
+ result = 0xFF
204
+ else:
205
+ result = a // b
206
  elif opcode == 0x9:
207
  result, carry, overflow = alu_sub(a, b)
208
  write_result = False
 
434
 
435
  return bits_to_int(out_bits)
436
 
437
+ def shift_left(self, a: int) -> int:
438
+ a_bits = int_to_bits(a, REG_BITS)
439
+ out_bits = []
440
+ for bit in range(REG_BITS):
441
+ w = self.alu._get(f"alu.alu8bit.shl.bit{bit}.weight")
442
+ bias = self.alu._get(f"alu.alu8bit.shl.bit{bit}.bias")
443
+ if bit < 7:
444
+ inp = torch.tensor([float(a_bits[bit + 1])], device=self.device)
445
+ else:
446
+ inp = torch.tensor([0.0], device=self.device)
447
+ out = heaviside((inp * w).sum() + bias).item()
448
+ out_bits.append(int(out))
449
+ return bits_to_int(out_bits)
450
+
451
+ def shift_right(self, a: int) -> int:
452
+ a_bits = int_to_bits(a, REG_BITS)
453
+ out_bits = []
454
+ for bit in range(REG_BITS):
455
+ w = self.alu._get(f"alu.alu8bit.shr.bit{bit}.weight")
456
+ bias = self.alu._get(f"alu.alu8bit.shr.bit{bit}.bias")
457
+ if bit > 0:
458
+ inp = torch.tensor([float(a_bits[bit - 1])], device=self.device)
459
+ else:
460
+ inp = torch.tensor([0.0], device=self.device)
461
+ out = heaviside((inp * w).sum() + bias).item()
462
+ out_bits.append(int(out))
463
+ return bits_to_int(out_bits)
464
+
465
+ def multiply(self, a: int, b: int) -> int:
466
+ """8-bit multiply using partial product AND gates + shift-add."""
467
+ a_bits = int_to_bits(a, REG_BITS)
468
+ b_bits = int_to_bits(b, REG_BITS)
469
+
470
+ # Compute all 64 partial products using AND gates
471
+ pp = [[0] * 8 for _ in range(8)]
472
+ for i in range(8):
473
+ for j in range(8):
474
+ w = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.weight")
475
+ bias = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.bias")
476
+ inp = torch.tensor([float(a_bits[i]), float(b_bits[j])], device=self.device)
477
+ pp[i][j] = int(heaviside((inp * w).sum() + bias).item())
478
+
479
+ # Shift-add accumulation using existing 8-bit adder
480
+ # Row j contributes A*B[j] shifted left by (7-j) positions
481
+ result = 0
482
+ for j in range(8):
483
+ if b_bits[j] == 0:
484
+ continue
485
+ # Construct the partial product row (A masked by B[j])
486
+ row = 0
487
+ for i in range(8):
488
+ row |= (pp[i][j] << (7 - i))
489
+ # Shift by position (7-j means B[7] is LSB, B[0] is MSB)
490
+ shifted = row << (7 - j)
491
+ # Add to result using threshold adder
492
+ result, _, _ = self.add(result & 0xFF, shifted & 0xFF)
493
+ # Handle overflow into high byte
494
+ if shifted > 255 or result > 255:
495
+ result = (result + (shifted >> 8)) & 0xFF
496
+
497
+ return result & 0xFF
498
+
499
+ def divide(self, a: int, b: int) -> Tuple[int, int]:
500
+ """8-bit divide using restoring division with threshold gates."""
501
+ if b == 0:
502
+ return 0xFF, a # Division by zero: return max quotient, original dividend
503
+
504
+ a_bits = int_to_bits(a, REG_BITS)
505
+
506
+ quotient = 0
507
+ remainder = 0
508
+
509
+ for stage in range(8):
510
+ # Shift remainder left and bring in next dividend bit
511
+ remainder = ((remainder << 1) | a_bits[stage]) & 0xFF
512
+
513
+ # Compare remainder >= divisor using threshold gate
514
+ rem_bits = int_to_bits(remainder, REG_BITS)
515
+ div_bits = int_to_bits(b, REG_BITS)
516
+
517
+ w = self._get(f"alu.alu8bit.div.stage{stage}.cmp.weight")
518
+ bias = self._get(f"alu.alu8bit.div.stage{stage}.cmp.bias")
519
+ inp = torch.tensor([float(rem_bits[i]) for i in range(8)] +
520
+ [float(div_bits[i]) for i in range(8)], device=self.device)
521
+ cmp_result = int(heaviside((inp * w).sum() + bias).item())
522
+
523
+ # If remainder >= divisor, subtract and set quotient bit
524
+ if cmp_result:
525
+ remainder, _, _ = self.sub(remainder, b)
526
+ quotient = (quotient << 1) | 1
527
+ else:
528
+ quotient = quotient << 1
529
+
530
+ return quotient & 0xFF, remainder & 0xFF
531
+
532
 
533
  class ThresholdCPU:
534
  def __init__(self, model_path: str | Path = DEFAULT_MODEL_PATH, device: str = "cpu") -> None:
 
672
  elif opcode == 0x4:
673
  result = self.alu.bitwise_xor(a, b)
674
  elif opcode == 0x5:
675
+ result = self.alu.shift_left(a)
676
  elif opcode == 0x6:
677
+ result = self.alu.shift_right(a)
678
  elif opcode == 0x7:
679
+ result = self.alu.multiply(a, b)
680
  elif opcode == 0x8:
681
+ result, _ = self.alu.divide(a, b)
682
  elif opcode == 0x9:
683
  result, carry, overflow = self.alu.sub(a, b)
684
  write_result = False