CharlesCNorton committed
Commit d15242d · 1 Parent(s): df088a9

Add order of operations circuit (arithmetic.expr_add_mul)

Computes A + (B × C) with correct precedence (multiply before add).

Circuit structure:
- 64 AND gates for B[bit] AND C[stage] masks
- 7 accumulator stages with shifted addition for shift-add multiply
- 8 full adders for final A + result

build.py: add_expr_add_mul(), infer_expr_add_mul_inputs()
eval.py: _test_expr_add_mul() with 73 test cases
Examples: 5 + 3 × 2 = 11, 10 + 4 × 3 = 22
Fitness 1.000000
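The shift-add scheme described above can be summarized as a plain-Python reference model (an illustrative sketch, not code from this repo; the function name is hypothetical):

```python
def expr_add_mul_ref(a: int, b: int, c: int, bits: int = 8) -> int:
    """Reference model for A + (B * C): shift-add multiply, then add, wrapping to 8 bits."""
    mask_all = (1 << bits) - 1
    acc = 0
    # Shift-add multiply: for each set bit of C, add B shifted by that position.
    # Only the low 8 bits are kept at each stage, as in the circuit's accumulators.
    for stage in range(bits):
        if (c >> stage) & 1:
            acc = (acc + (b << stage)) & mask_all
    # Final ripple-carry add of A, wrapping on overflow.
    return (a + acc) & mask_all

# Examples from the commit message:
# expr_add_mul_ref(5, 3, 2)  -> 11
# expr_add_mul_ref(10, 4, 3) -> 22
```

Because taking the low 8 bits at every stage commutes with addition mod 256, this matches `(a + b * c) & 0xFF` for all operands.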

Files changed (4)
  1. README.md +5 -5
  2. build.py +179 -1
  3. eval.py +154 -0
  4. neural_computer.safetensors +2 -2
README.md CHANGED
@@ -457,18 +457,18 @@ The interface generalizes to **all** 65,536 8-bit additions once trained—no me
 
 ### Extension Roadmap
 
-1. **Order of operations (5 + 3 × 2 = 11)** — Parse expression into tree, evaluate depth-first. MUL before ADD. Requires either: (a) expression parser producing evaluation order, or (b) learned routing that implicitly respects precedence.
+1. **Parenthetical expressions ((5 + 3) × 2 = 16)** — Explicit grouping overrides precedence. Parser must recognize parens and build correct tree. Evaluation proceeds innermost-out. Adds complexity to extraction layer.
 
-2. **Parenthetical expressions ((5 + 3) × 2 = 16)** — Explicit grouping overrides precedence. Parser must recognize parens and build correct tree. Evaluation proceeds innermost-out. Adds complexity to extraction layer.
+2. **16-bit operations (0-65535)** — Chain two 8-bit circuits with carry propagation. ADD16: low = ADD8(A_lo, B_lo), high = ADD8(A_hi, B_hi, carry_out). MUL16: four partial products + shift-add. Doubles operand extraction width.
 
-3. **16-bit operations (0-65535)** — Chain two 8-bit circuits with carry propagation. ADD16: low = ADD8(A_lo, B_lo), high = ADD8(A_hi, B_hi, carry_out). MUL16: four partial products + shift-add. Doubles operand extraction width.
-
-4. **Floating point arithmetic** — IEEE 754-style with separate circuits for mantissa and exponent. ADD: align exponents, add mantissas, renormalize. MUL: add exponents, multiply mantissas. Requires sign handling, overflow detection, and rounding logic.
+3. **Floating point arithmetic** — IEEE 754-style with separate circuits for mantissa and exponent. ADD: align exponents, add mantissas, renormalize. MUL: add exponents, multiply mantissas. Requires sign handling, overflow detection, and rounding logic.
 
 ### Completed Extensions
 
 - **3-operand addition (15 + 27 + 33 = 75)** — `arithmetic.add3_8bit` chains two 8-bit ripple carry stages. 16 full adders, 144 gates, 240 test cases verified.
 
+- **Order of operations (5 + 3 × 2 = 11)** — `arithmetic.expr_add_mul` computes A + (B × C) using shift-add multiplication then addition. 64 AND gates + 64 full adders, 73 test cases verified.
+
 ---
 
 ## Files
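Roadmap item 2's ADD16 decomposition (low = ADD8 of the low bytes, high = ADD8 of the high bytes plus the carry out) can be sketched as a minimal reference model (illustrative helper names, not repo code):

```python
def add8(a: int, b: int, carry_in: int = 0):
    """8-bit add returning (sum, carry_out); mirrors one ripple-carry block."""
    total = a + b + carry_in
    return total & 0xFF, total >> 8

def add16(a: int, b: int) -> int:
    """ADD16 via two chained 8-bit adds: low bytes first, carry into the high bytes."""
    lo, carry = add8(a & 0xFF, b & 0xFF)
    hi, _ = add8((a >> 8) & 0xFF, (b >> 8) & 0xFF, carry)  # final carry wraps
    return (hi << 8) | lo
```

The only new wiring beyond two ADD8 circuits is routing the low block's carry-out into the high block's carry-in.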
build.py CHANGED
@@ -259,6 +259,56 @@ def add_full_adder(tensors: Dict[str, torch.Tensor], prefix: str) -> None:
     add_gate(tensors, f"{prefix}.carry_or", [1.0, 1.0], [-1.0])
 
 
+def add_expr_add_mul(tensors: Dict[str, torch.Tensor]) -> None:
+    """Add expression circuit for A + B × C (order of operations).
+
+    Computes A + (B × C) where multiplication has higher precedence.
+
+    Structure:
+    - Stage 1: Multiply B × C using shift-add algorithm
+      - 8 mask stages: mask[i] = B AND C[i] (8 AND gates each, shifted)
+      - 7 accumulator adders to sum masked values
+    - Stage 2: Add A to multiplication result (8-bit ripple carry)
+
+    Inputs: $a[0-7], $b[0-7], $c[0-7] (MSB-first, 8-bit each)
+    Output: 8-bit result of A + (B × C), wrapping on overflow
+
+    Total: 64 AND gates + 7×8 full adders (mul) + 8 full adders (add) = ~640 gates
+    """
+    prefix = "arithmetic.expr_add_mul"
+
+    # Stage 1: Multiply B × C using shift-add.
+    # For each bit i of C, AND all bits of B with C[i]; this yields partial
+    # products that are shifted by i positions.
+
+    # Mask AND gates: mask[stage][bit] = B[bit] AND C[stage]
+    # These compute B & (C[i] ? 0xFF : 0x00) for each bit of C.
+    for stage in range(8):
+        for bit in range(8):
+            add_gate(tensors, f"{prefix}.mul.mask.s{stage}.b{bit}", [1.0, 1.0], [-2.0])
+
+    # Accumulator adders for shift-add multiplication:
+    #   stage 0:    acc = mask0 (no adder needed, just the masked value)
+    #   stages 1-7: acc = acc + (mask[i] << i)
+    # The shift is handled at wiring time by connecting different bit positions;
+    # only the low 8 bits are kept at each stage for the 8-bit result.
+    for stage in range(1, 8):  # 7 accumulator adders
+        for bit in range(8):
+            add_full_adder(tensors, f"{prefix}.mul.acc.s{stage}.fa{bit}")
+
+    # Stage 2: Add A to the multiplication result.
+    for bit in range(8):
+        add_full_adder(tensors, f"{prefix}.add.fa{bit}")
+
+
 def add_add3(tensors: Dict[str, torch.Tensor]) -> None:
     """Add 3-operand 8-bit adder circuit.
 
@@ -649,6 +699,126 @@ def infer_ripplecarry_inputs(gate: str, prefix: str, bits: int, reg: SignalRegis
     return []
 
 
+def infer_expr_add_mul_inputs(gate: str, reg: SignalRegistry) -> List[int]:
+    """Infer inputs for A + B × C expression circuit (order of operations).
+
+    Circuit structure:
+    - Mask stage: mask.s[stage].b[bit] = B[bit] AND C[stage]
+    - Accumulator stages 1-7: acc.s[stage] = acc.s[stage-1] + (mask.s[stage] << stage)
+    - Final add: result = A + acc.s7
+
+    Bit ordering: MSB-first externally, LSB-first internally (fa0 = LSB, fa7 = MSB),
+    so $x[7] = bit 0 (LSB) and $x[0] = bit 7 (MSB).
+    """
+    prefix = "arithmetic.expr_add_mul"
+
+    # Register all inputs
+    for i in range(8):
+        reg.register(f"$a[{i}]")
+        reg.register(f"$b[{i}]")
+        reg.register(f"$c[{i}]")
+
+    # Mask AND gates: mask.s[stage].b[bit] = B[bit] AND C[stage]
+    if '.mul.mask.' in gate:
+        m = re.search(r'\.s(\d+)\.b(\d+)', gate)
+        if m:
+            stage = int(m.group(1))
+            bit = int(m.group(2))
+            # MSB-first: $b[7-bit] is bit position 'bit', $c[7-stage] is stage position 'stage'
+            b_input = reg.get_id(f"$b[{7-bit}]")
+            c_input = reg.get_id(f"$c[{7-stage}]")
+            return [b_input, c_input]
+        return []
+
+    # Accumulator adders: acc.s[stage].fa[bit]
+    if '.mul.acc.' in gate:
+        m = re.search(r'\.s(\d+)\.fa(\d+)\.', gate)
+        if not m:
+            return []
+        stage = int(m.group(1))  # 1-7
+        bit = int(m.group(2))    # 0-7
+
+        # A input: previous stage output
+        if stage == 1:
+            # First accumulator: A = mask.s0.b[bit] (AND gate output)
+            a_input = reg.register(f"{prefix}.mul.mask.s0.b{bit}")
+        else:
+            # Later stages: A = previous accumulator sum
+            a_input = reg.register(f"{prefix}.mul.acc.s{stage-1}.fa{bit}.ha2.sum.layer2")
+
+        # B input: (mask.s[stage] << stage)[bit]
+        # A left shift by 'stage' positions means bit positions 0..stage-1 get 0
+        # and bit position 'bit' gets mask.s[stage].b[bit-stage].
+        if bit < stage:
+            b_input = reg.get_id("#0")
+        else:
+            b_input = reg.register(f"{prefix}.mul.mask.s{stage}.b{bit-stage}")
+
+        # Carry input
+        if bit == 0:
+            cin = reg.get_id("#0")
+        else:
+            cin = reg.register(f"{prefix}.mul.acc.s{stage}.fa{bit-1}.carry_or")
+
+        fa_prefix = f"{prefix}.mul.acc.s{stage}.fa{bit}"
+
+        if '.ha1.sum.layer1' in gate:
+            return [a_input, b_input]
+        if '.ha1.sum.layer2' in gate:
+            return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")]
+        if '.ha1.carry' in gate and '.layer' not in gate:
+            return [a_input, b_input]
+        if '.ha2.sum.layer1' in gate:
+            return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
+        if '.ha2.sum.layer2' in gate:
+            return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")]
+        if '.ha2.carry' in gate and '.layer' not in gate:
+            return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
+        if '.carry_or' in gate:
+            return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")]
+        return []
+
+    # Final add stage: A + mul_result
+    if '.add.fa' in gate:
+        m = re.search(r'\.fa(\d+)\.', gate)
+        if not m:
+            return []
+        bit = int(m.group(1))
+
+        # A input: $a[7-bit] (MSB-first to positional bit)
+        a_input = reg.get_id(f"$a[{7-bit}]")
+
+        # B input: multiplication result = acc.s7.fa[bit] sum output
+        b_input = reg.register(f"{prefix}.mul.acc.s7.fa{bit}.ha2.sum.layer2")
+
+        # Carry input
+        if bit == 0:
+            cin = reg.get_id("#0")
+        else:
+            cin = reg.register(f"{prefix}.add.fa{bit-1}.carry_or")
+
+        fa_prefix = f"{prefix}.add.fa{bit}"
+
+        if '.ha1.sum.layer1' in gate:
+            return [a_input, b_input]
+        if '.ha1.sum.layer2' in gate:
+            return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")]
+        if '.ha1.carry' in gate and '.layer' not in gate:
+            return [a_input, b_input]
+        if '.ha2.sum.layer1' in gate:
+            return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
+        if '.ha2.sum.layer2' in gate:
+            return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")]
+        if '.ha2.carry' in gate and '.layer' not in gate:
+            return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
+        if '.carry_or' in gate:
+            return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")]
+        return []
+
+    return []
+
+
 def infer_add3_inputs(gate: str, reg: SignalRegistry) -> List[int]:
     """Infer inputs for 3-operand adder: A + B + C."""
     prefix = "arithmetic.add3_8bit"
@@ -1179,6 +1349,8 @@ def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, tor
         return infer_ripplecarry_inputs(gate, "arithmetic.ripplecarry8bit", 8, reg)
     if 'add3_8bit' in gate:
         return infer_add3_inputs(gate, reg)
+    if 'expr_add_mul' in gate:
+        return infer_expr_add_mul_inputs(gate, reg)
     if 'adc8bit' in gate:
         return infer_adcsbc_inputs(gate, "arithmetic.adc8bit", False, reg)
     if 'sbc8bit' in gate:
@@ -1404,7 +1576,7 @@ def cmd_alu(args) -> None:
         "alu.alu8bit.neg.", "alu.alu8bit.rol.", "alu.alu8bit.ror.",
         "arithmetic.greaterthan8bit.", "arithmetic.lessthan8bit.",
         "arithmetic.greaterorequal8bit.", "arithmetic.lessorequal8bit.",
-        "arithmetic.equality8bit.", "arithmetic.add3_8bit.",
+        "arithmetic.equality8bit.", "arithmetic.add3_8bit.", "arithmetic.expr_add_mul.",
         "control.push.", "control.pop.", "control.ret.",
         "combinational.barrelshifter.", "combinational.priorityencoder.",
     ])
@@ -1475,6 +1647,12 @@ def cmd_alu(args) -> None:
        print("  Added ADD3 (16 full adders = 144 gates)")
     except ValueError as e:
        print(f"  ADD3 already exists: {e}")
+    print("\nGenerating expression A + B × C circuit...")
+    try:
+        add_expr_add_mul(tensors)
+        print("  Added EXPR_ADD_MUL (64 AND + 56 + 8 full adders = 640 gates)")
+    except ValueError as e:
+        print(f"  EXPR_ADD_MUL already exists: {e}")
     if args.apply:
         print(f"\nSaving: {args.model}")
         save_file(tensors, str(args.model))
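The bit-ordering and shift wiring conventions used by `infer_expr_add_mul_inputs` are easy to get backwards, so here is a small sketch that models just the index mapping (returning signal-name strings instead of registry ids; the helper names are hypothetical, for illustration only):

```python
def mask_inputs(stage: int, bit: int):
    """Which external signals feed mask.s{stage}.b{bit} under the MSB-first
    convention: $x[0] is the MSB and $x[7] the LSB, so positional bit k
    lives at index 7 - k."""
    return (f"$b[{7 - bit}]", f"$c[{7 - stage}]")

def acc_b_input(stage: int, bit: int):
    """B input of accumulator fa{bit} at a given stage: the stage-shifted
    partial product. Positions below the shift amount are wired to the
    constant-zero signal '#0'."""
    if bit < stage:
        return "#0"
    return f"mask.s{stage}.b{bit - stage}"
```

For example, mask stage 0 / bit 0 reads the LSBs `$b[7]` and `$c[7]`, and at accumulator stage 3, full adders fa0 through fa2 get constant zero while fa5 gets `mask.s3.b2`.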
eval.py CHANGED
@@ -631,6 +631,154 @@ class BatchedFitnessEvaluator:
 
         return correct, num_tests
 
+    # =========================================================================
+    # ORDER OF OPERATIONS (A + B × C)
+    # =========================================================================
+
+    def _test_expr_add_mul(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
+        """Test A + B × C expression circuit (order of operations)."""
+        pop_size = next(iter(pop.values())).shape[0]
+
+        if debug:
+            print("\n=== ORDER OF OPERATIONS (A + B × C) ===")
+
+        prefix = 'arithmetic.expr_add_mul'
+        bits = 8
+
+        # Test cases for order of operations
+        test_cases = []
+
+        # Specific examples from roadmap
+        test_cases.extend([
+            (5, 3, 2),     # 5 + 3 × 2 = 5 + 6 = 11
+            (10, 4, 3),    # 10 + 4 × 3 = 10 + 12 = 22
+            (1, 10, 10),   # 1 + 10 × 10 = 1 + 100 = 101
+            (0, 15, 17),   # 0 + 15 × 17 = 255
+            (1, 15, 17),   # 1 + 15 × 17 = 256 -> 0 (overflow)
+            (100, 5, 5),   # 100 + 5 × 5 = 100 + 25 = 125
+        ])
+
+        # Edge cases
+        test_cases.extend([
+            (0, 0, 0),     # 0 + 0 × 0 = 0
+            (255, 0, 0),   # 255 + 0 × 0 = 255
+            (0, 255, 1),   # 0 + 255 × 1 = 255
+            (0, 1, 255),   # 0 + 1 × 255 = 255
+            (1, 1, 1),     # 1 + 1 × 1 = 2
+            (0, 16, 16),   # 0 + 16 × 16 = 256 -> 0 (overflow)
+        ])
+
+        # Systematic small values
+        for a in [0, 1, 5, 10]:
+            for b in [0, 1, 2, 3]:
+                for c in [0, 1, 2, 3]:
+                    test_cases.append((a, b, c))
+
+        # Remove duplicates
+        test_cases = list(set(test_cases))
+
+        a_vals = torch.tensor([t[0] for t in test_cases], device=self.device)
+        b_vals = torch.tensor([t[1] for t in test_cases], device=self.device)
+        c_vals = torch.tensor([t[2] for t in test_cases], device=self.device)
+        num_tests = len(test_cases)
+
+        # Convert to bits [num_tests, bits], MSB-first
+        a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
+        b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
+        c_bits = torch.stack([((c_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
+
+        # Evaluate mask stage: mask[stage][bit] = B[bit] AND C[stage].
+        # The circuit's mask.s[stage].b[bit] operates on positional bits:
+        # stage 0 = LSB of C (c_bits[:, 7]), stage 7 = MSB of C (c_bits[:, 0]);
+        # bit 0 = LSB of B (b_bits[:, 7]), bit 7 = MSB of B (b_bits[:, 0]).
+        masks = torch.zeros(8, num_tests, pop_size, 8, device=self.device)  # [stage, tests, pop, bits]
+        for stage in range(8):
+            c_stage_bit = c_bits[:, 7 - stage].unsqueeze(1).expand(-1, pop_size)  # C[stage]
+            for bit in range(8):
+                b_bit_val = b_bits[:, 7 - bit].unsqueeze(1).expand(-1, pop_size)  # B[bit]
+                # AND gate
+                w = pop.get(f'{prefix}.mul.mask.s{stage}.b{bit}.weight')
+                bias = pop.get(f'{prefix}.mul.mask.s{stage}.b{bit}.bias')
+                if w is not None and bias is not None:
+                    w = w.squeeze(-1)            # [pop]
+                    b_tensor = bias.squeeze(-1)  # [pop]
+                    # Broadcast for batched evaluation
+                    inp = torch.stack([b_bit_val, c_stage_bit], dim=-1)  # [tests, pop, 2]
+                    out = heaviside(torch.einsum('tpi,pi->tp', inp, w) + b_tensor)
+                    masks[stage, :, :, bit] = out
+
+        # Accumulator stages:
+        #   acc[0] = mask[0] (no shift)
+        #   acc[1] = acc[0] + (mask[1] << 1)
+        #   ...
+        #   acc[7] = acc[6] + (mask[7] << 7)
+        acc = masks[0].clone()  # [tests, pop, 8] - start with mask[0]
+
+        for stage in range(1, 8):
+            # Shifted mask (mask[stage] << stage): bits 0..stage-1 become 0,
+            # bit k becomes mask[stage][k - stage].
+            shifted_mask = torch.zeros(num_tests, pop_size, 8, device=self.device)
+            for bit in range(8):
+                if bit >= stage:
+                    shifted_mask[:, :, bit] = masks[stage, :, :, bit - stage]
+                # else: remains 0
+
+            # Add acc + shifted_mask using full adders
+            carry = torch.zeros(num_tests, pop_size, device=self.device)
+            new_acc = torch.zeros(num_tests, pop_size, 8, device=self.device)
+            for bit in range(8):
+                s, carry = self._eval_single_fa(
+                    pop, f'{prefix}.mul.acc.s{stage}.fa{bit}',
+                    acc[:, :, bit],
+                    shifted_mask[:, :, bit],
+                    carry
+                )
+                new_acc[:, :, bit] = s
+            acc = new_acc
+
+        # Final add stage: A + acc (multiplication result)
+        carry = torch.zeros(num_tests, pop_size, device=self.device)
+        result_bits = []
+        for bit in range(8):
+            a_bit_val = a_bits[:, 7 - bit].unsqueeze(1).expand(-1, pop_size)
+            s, carry = self._eval_single_fa(
+                pop, f'{prefix}.add.fa{bit}',
+                a_bit_val,
+                acc[:, :, bit],
+                carry
+            )
+            result_bits.append(s)
+
+        # Reconstruct the result value (MSB first)
+        result_bits = torch.stack(result_bits[::-1], dim=-1)
+        result = torch.zeros(num_tests, pop_size, device=self.device)
+        for i in range(bits):
+            result += result_bits[:, :, i] * (1 << (bits - 1 - i))
+
+        # Expected: A + (B × C), with 8-bit wrap
+        expected = ((a_vals + b_vals * c_vals) & 0xFF).unsqueeze(1).expand(-1, pop_size).float()
+        correct = (result == expected).float().sum(0)
+
+        failures = []
+        if pop_size == 1:
+            for i in range(min(num_tests, 100)):
+                if result[i, 0].item() != expected[i, 0].item():
+                    failures.append((
+                        [int(a_vals[i].item()), int(b_vals[i].item()), int(c_vals[i].item())],
+                        int(expected[i, 0].item()),
+                        int(result[i, 0].item())
+                    ))
+
+        self._record(prefix, int(correct[0].item()), num_tests, failures)
+        if debug:
+            r = self.results[-1]
+            print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
+            if failures:
+                for inp, exp, got in failures[:5]:
+                    print(f"    FAIL: {inp[0]} + {inp[1]} × {inp[2]} = {exp}, got {got}")
+
+        return correct, num_tests
+
     # =========================================================================
     # COMPARATORS
     # =========================================================================
@@ -2450,6 +2598,12 @@
         total_tests += t
         self.category_scores['add3'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
 
+        # Order of operations (A + B × C)
+        s, t = self._test_expr_add_mul(population, debug)
+        scores += s
+        total_tests += t
+        self.category_scores['expr_add_mul'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
+
         # Comparators
         s, t = self._test_comparators(population, debug)
         scores += s
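The evaluator converts operands to MSB-first bit vectors and reconstructs the result with the mirrored weighting. That round trip can be sketched in isolation (illustrative helpers, not repo code):

```python
def to_bits_msb(v: int, bits: int = 8):
    """MSB-first bit list, matching a_bits/b_bits/c_bits in _test_expr_add_mul:
    index i holds bit (bits - 1 - i) of v."""
    return [(v >> (bits - 1 - i)) & 1 for i in range(bits)]

def from_bits_msb(bit_list):
    """Inverse mapping, matching the result reconstruction:
    result += bit[i] * (1 << (bits - 1 - i))."""
    bits = len(bit_list)
    return sum(b << (bits - 1 - i) for i, b in enumerate(bit_list))
```

Note the internal full adders are indexed LSB-first (fa0 = LSB), which is why the evaluator reads `x_bits[:, 7 - bit]` and reverses `result_bits` before stacking.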
neural_computer.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:270309b1ac57e808827cee555b6f6f9e3f14c37abe23fa21069db4ff251a0b72
-size 34552948
+oid sha256:eaabeed4fa50c13129fe4f83f6a8f31b6ccd41de12e83c62448460881373fc3e
+size 34838348