CharlesCNorton committed on
Commit
df088a9
·
1 Parent(s): a5b1777

Add 3-operand adder circuit (arithmetic.add3_8bit)

Browse files

- build.py: add_full_adder() and add_add3() functions
- build.py: infer_add3_inputs() for routing metadata
- eval.py: _test_add3() with 240 test cases including 15+27+33=75
- Fitness 1.000000, all tests pass

Files changed (4) hide show
  1. README.md +7 -5
  2. build.py +106 -1
  3. eval.py +110 -0
  4. neural_computer.safetensors +2 -2
README.md CHANGED
@@ -457,15 +457,17 @@ The interface generalizes to **all** 65,536 8-bit additions once trained—no me
457
 
458
  ### Extension Roadmap
459
 
460
- 1. **Multi-operand expressions (15 + 27 + 33)** — Accumulator pattern: result = 0; for each operand, result = ADD(result, operand). Router must fire multiple times per input sequence. Requires stateful dispatch or unrolled circuit.
461
 
462
- 2. **Order of operations (5 + 3 × 2 = 11)** — Parse expression into tree, evaluate depth-first. MUL before ADD. Requires either: (a) expression parser producing evaluation order, or (b) learned routing that implicitly respects precedence.
463
 
464
- 3. **Parenthetical expressions ((5 + 3) × 2 = 16)** Explicit grouping overrides precedence. Parser must recognize parens and build correct tree. Evaluation proceeds innermost-out. Adds complexity to extraction layer.
465
 
466
- 4. **16-bit operations (0-65535)** — Chain two 8-bit circuits with carry propagation. ADD16: low = ADD8(A_lo, B_lo), high = ADD8(A_hi, B_hi, carry_out). MUL16: four partial products + shift-add. Doubles operand extraction width.
467
 
468
- 5. **Floating point arithmetic** — IEEE 754-style with separate circuits for mantissa and exponent. ADD: align exponents, add mantissas, renormalize. MUL: add exponents, multiply mantissas. Requires sign handling, overflow detection, and rounding logic.
 
 
469
 
470
  ---
471
 
 
457
 
458
  ### Extension Roadmap
459
 
460
+ 1. **Order of operations (5 + 3 × 2 = 11)** — Parse expression into tree, evaluate depth-first. MUL before ADD. Requires either: (a) expression parser producing evaluation order, or (b) learned routing that implicitly respects precedence.
461
 
462
+ 2. **Parenthetical expressions ((5 + 3) × 2 = 16)** — Explicit grouping overrides precedence. Parser must recognize parens and build correct tree. Evaluation proceeds innermost-out. Adds complexity to extraction layer.
463
 
464
+ 3. **16-bit operations (0-65535)** — Chain two 8-bit circuits with carry propagation. ADD16: low = ADD8(A_lo, B_lo), high = ADD8(A_hi, B_hi, carry_out). MUL16: four partial products + shift-add. Doubles operand extraction width.
465
 
466
+ 4. **Floating point arithmetic** — IEEE 754-style with separate circuits for mantissa and exponent. ADD: align exponents, add mantissas, renormalize. MUL: add exponents, multiply mantissas. Requires sign handling, overflow detection, and rounding logic.
467
 
468
+ ### Completed Extensions
469
+
470
+ - **3-operand addition (15 + 27 + 33 = 75)** — `arithmetic.add3_8bit` chains two 8-bit ripple carry stages. 16 full adders, 144 gates, 240 test cases verified.
471
 
472
  ---
473
 
build.py CHANGED
@@ -235,6 +235,51 @@ def add_fetch_load_store_buffers(tensors: Dict[str, torch.Tensor], addr_bits: in
235
  add_gate(tensors, f"control.mem_addr.bit{bit}", [1.0], [-1.0])
236
 
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  def add_shl_shr(tensors: Dict[str, torch.Tensor]) -> None:
239
  """Add SHL (shift left) and SHR (shift right) circuits.
240
 
@@ -604,6 +649,58 @@ def infer_ripplecarry_inputs(gate: str, prefix: str, bits: int, reg: SignalRegis
604
  return []
605
 
606
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
607
  def infer_adcsbc_inputs(gate: str, prefix: str, is_sub: bool, reg: SignalRegistry) -> List[int]:
608
  for i in range(8):
609
  reg.register(f"{prefix}.$a[{i}]")
@@ -1080,6 +1177,8 @@ def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, tor
1080
  return infer_ripplecarry_inputs(gate, "arithmetic.ripplecarry4bit", 4, reg)
1081
  if 'ripplecarry8bit' in gate:
1082
  return infer_ripplecarry_inputs(gate, "arithmetic.ripplecarry8bit", 8, reg)
 
 
1083
  if 'adc8bit' in gate:
1084
  return infer_adcsbc_inputs(gate, "arithmetic.adc8bit", False, reg)
1085
  if 'sbc8bit' in gate:
@@ -1305,7 +1404,7 @@ def cmd_alu(args) -> None:
1305
  "alu.alu8bit.neg.", "alu.alu8bit.rol.", "alu.alu8bit.ror.",
1306
  "arithmetic.greaterthan8bit.", "arithmetic.lessthan8bit.",
1307
  "arithmetic.greaterorequal8bit.", "arithmetic.lessorequal8bit.",
1308
- "arithmetic.equality8bit.",
1309
  "control.push.", "control.pop.", "control.ret.",
1310
  "combinational.barrelshifter.", "combinational.priorityencoder.",
1311
  ])
@@ -1370,6 +1469,12 @@ def cmd_alu(args) -> None:
1370
  print(" Added GT, GE, LT, LE (single-layer), EQ (two-layer)")
1371
  except ValueError as e:
1372
  print(f" Comparators already exist: {e}")
 
 
 
 
 
 
1373
  if args.apply:
1374
  print(f"\nSaving: {args.model}")
1375
  save_file(tensors, str(args.model))
 
235
  add_gate(tensors, f"control.mem_addr.bit{bit}", [1.0], [-1.0])
236
 
237
 
238
def add_full_adder(tensors: Dict[str, torch.Tensor], prefix: str) -> None:
    """Add the nine gates of a single full adder under the given prefix.

    Structure (two half adders plus a carry OR):
      - ha1: sum = A XOR B,        carry = A AND B
      - ha2: sum = ha1.sum XOR Cin, carry = ha1.sum AND Cin
      - carry_or: Cout = ha1.carry OR ha2.carry

    Each XOR is built in two layers (an OR gate and a NAND gate feeding
    an AND gate), so one full adder costs 9 gates in total.
    """
    for ha in ("ha1", "ha2"):
        # Two-layer XOR for the half adder's sum: layer1 produces OR and
        # NAND of the inputs, layer2 ANDs them together.
        add_gate(tensors, f"{prefix}.{ha}.sum.layer1.or", [1.0, 1.0], [-1.0])
        add_gate(tensors, f"{prefix}.{ha}.sum.layer1.nand", [-1.0, -1.0], [1.0])
        add_gate(tensors, f"{prefix}.{ha}.sum.layer2", [1.0, 1.0], [-2.0])
        # AND gate for this half adder's carry.
        add_gate(tensors, f"{prefix}.{ha}.carry", [1.0, 1.0], [-2.0])
    # OR of the two half-adder carries gives the final carry out.
    add_gate(tensors, f"{prefix}.carry_or", [1.0, 1.0], [-1.0])
260
+
261
+
262
def add_add3(tensors: Dict[str, torch.Tensor]) -> None:
    """Add the 3-operand 8-bit adder circuit (``arithmetic.add3_8bit``).

    Computes A + B + C with two chained ripple-carry stages:
      - stage1: temp   = A + B      (8 full adders)
      - stage2: result = temp + C   (8 full adders)

    Inputs:  $a[0-7], $b[0-7], $c[0-7] (MSB-first)
    Outputs: stage2.fa0-7.ha2.sum.layer2 (result bits),
             stage2.fa7.carry_or (overflow)

    Total: 16 full adders = 144 gates.
    """
    base = "arithmetic.add3_8bit"
    # Stage 1 (A + B) first, then stage 2 (temp + C); fa0 is the LSB adder.
    for stage in ("stage1", "stage2"):
        for bit in range(8):
            add_full_adder(tensors, f"{base}.{stage}.fa{bit}")
281
+
282
+
283
  def add_shl_shr(tensors: Dict[str, torch.Tensor]) -> None:
284
  """Add SHL (shift left) and SHR (shift right) circuits.
285
 
 
649
  return []
650
 
651
 
652
def infer_add3_inputs(gate: str, reg: SignalRegistry) -> List[int]:
    """Infer input wiring for the 3-operand adder (A + B + C).

    Maps each gate of ``arithmetic.add3_8bit`` to the signal ids of its
    two inputs: operand bits, the previous adder's carry, or internal
    half-adder signals. Returns [] for gate names that don't match the
    circuit layout.

    Args:
        gate: Fully qualified gate name within arithmetic.add3_8bit.
        reg: Signal registry used to register/look up signal ids.

    Returns:
        List of input signal ids for the gate (always 2 ids, or empty).
    """
    prefix = "arithmetic.add3_8bit"
    # Register all operand inputs (MSB-first, indices 0-7).
    # NOTE(review): unlike infer_adcsbc_inputs, these are registered
    # WITHOUT the circuit prefix — confirm this matches the registry
    # convention expected by the evaluator/router.
    for i in range(8):
        reg.register(f"$a[{i}]")
        reg.register(f"$b[{i}]")
        reg.register(f"$c[{i}]")

    # Parse the full-adder index once; fa index counts from the LSB.
    m = re.search(r'\.fa(\d+)\.', gate)
    if not m:
        return []
    bit = int(m.group(1))

    if '.stage1.' in gate:
        stage = "stage1"
        # Stage 1 adds A and B; operand bits are MSB-first, so LSB is index 7.
        a_bit = reg.get_id(f"$a[{7 - bit}]")
        b_bit = reg.get_id(f"$b[{7 - bit}]")
    elif '.stage2.' in gate:
        stage = "stage2"
        # Stage 2 adds the stage-1 sum bit and operand C.
        a_bit = reg.register(f"{prefix}.stage1.fa{bit}.ha2.sum.layer2")
        b_bit = reg.get_id(f"$c[{7 - bit}]")
    else:
        return []

    # Carry-in: constant 0 for the LSB adder, otherwise the previous
    # full adder's carry out within the same stage.
    cin = reg.get_id("#0") if bit == 0 else reg.register(f"{prefix}.{stage}.fa{bit - 1}.carry_or")
    fa_prefix = f"{prefix}.{stage}.fa{bit}"

    if '.ha1.sum.layer1' in gate:
        return [a_bit, b_bit]
    if '.ha1.sum.layer2' in gate:
        return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"),
                reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")]
    if '.ha1.carry' in gate and '.layer' not in gate:
        return [a_bit, b_bit]
    if '.ha2.sum.layer1' in gate:
        return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
    if '.ha2.sum.layer2' in gate:
        return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"),
                reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")]
    if '.ha2.carry' in gate and '.layer' not in gate:
        return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
    if '.carry_or' in gate:
        return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")]
    return []
+
703
+
704
  def infer_adcsbc_inputs(gate: str, prefix: str, is_sub: bool, reg: SignalRegistry) -> List[int]:
705
  for i in range(8):
706
  reg.register(f"{prefix}.$a[{i}]")
 
1177
  return infer_ripplecarry_inputs(gate, "arithmetic.ripplecarry4bit", 4, reg)
1178
  if 'ripplecarry8bit' in gate:
1179
  return infer_ripplecarry_inputs(gate, "arithmetic.ripplecarry8bit", 8, reg)
1180
+ if 'add3_8bit' in gate:
1181
+ return infer_add3_inputs(gate, reg)
1182
  if 'adc8bit' in gate:
1183
  return infer_adcsbc_inputs(gate, "arithmetic.adc8bit", False, reg)
1184
  if 'sbc8bit' in gate:
 
1404
  "alu.alu8bit.neg.", "alu.alu8bit.rol.", "alu.alu8bit.ror.",
1405
  "arithmetic.greaterthan8bit.", "arithmetic.lessthan8bit.",
1406
  "arithmetic.greaterorequal8bit.", "arithmetic.lessorequal8bit.",
1407
+ "arithmetic.equality8bit.", "arithmetic.add3_8bit.",
1408
  "control.push.", "control.pop.", "control.ret.",
1409
  "combinational.barrelshifter.", "combinational.priorityencoder.",
1410
  ])
 
1469
  print(" Added GT, GE, LT, LE (single-layer), EQ (two-layer)")
1470
  except ValueError as e:
1471
  print(f" Comparators already exist: {e}")
1472
+ print("\nGenerating 3-operand adder circuit...")
1473
+ try:
1474
+ add_add3(tensors)
1475
+ print(" Added ADD3 (16 full adders = 144 gates)")
1476
+ except ValueError as e:
1477
+ print(f" ADD3 already exists: {e}")
1478
  if args.apply:
1479
  print(f"\nSaving: {args.model}")
1480
  save_file(tensors, str(args.model))
eval.py CHANGED
@@ -527,6 +527,110 @@ class BatchedFitnessEvaluator:
527
 
528
  return correct, num_tests
529
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
  # =========================================================================
531
  # COMPARATORS
532
  # =========================================================================
@@ -2340,6 +2444,12 @@ class BatchedFitnessEvaluator:
2340
  total_tests += t
2341
  self.category_scores[f'ripplecarry{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
2342
 
 
 
 
 
 
 
2343
  # Comparators
2344
  s, t = self._test_comparators(population, debug)
2345
  scores += s
 
527
 
528
  return correct, num_tests
529
 
530
+ # =========================================================================
531
+ # 3-OPERAND ADDER
532
+ # =========================================================================
533
+
534
+ def _test_add3(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
535
+ """Test 3-operand 8-bit adder (A + B + C)."""
536
+ pop_size = next(iter(pop.values())).shape[0]
537
+
538
+ if debug:
539
+ print(f"\n=== 3-OPERAND ADDER ===")
540
+
541
+ prefix = 'arithmetic.add3_8bit'
542
+ bits = 8
543
+
544
+ # Strategic test cases for 3-operand addition
545
+ # Include edge cases and overflow scenarios
546
+ test_cases = []
547
+ # Small values
548
+ for a in [0, 1, 2]:
549
+ for b in [0, 1, 2]:
550
+ for c in [0, 1, 2]:
551
+ test_cases.append((a, b, c))
552
+ # Edge values
553
+ edge = [0, 1, 127, 128, 254, 255]
554
+ for a in edge:
555
+ for b in edge:
556
+ for c in edge:
557
+ test_cases.append((a, b, c))
558
+ # Specific multi-operand expression tests
559
+ test_cases.extend([
560
+ (15, 27, 33), # Example from roadmap: 15 + 27 + 33 = 75
561
+ (100, 100, 55), # = 255 (exact fit)
562
+ (100, 100, 56), # = 256 -> 0 (overflow)
563
+ (85, 85, 85), # = 255 (exact fit)
564
+ (86, 85, 85), # = 256 -> 0 (overflow)
565
+ ])
566
+ test_cases = list(set(test_cases))
567
+
568
+ a_vals = torch.tensor([t[0] for t in test_cases], device=self.device)
569
+ b_vals = torch.tensor([t[1] for t in test_cases], device=self.device)
570
+ c_vals = torch.tensor([t[2] for t in test_cases], device=self.device)
571
+ num_tests = len(test_cases)
572
+
573
+ # Convert to bits [num_tests, bits] MSB-first
574
+ a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
575
+ b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
576
+ c_bits = torch.stack([((c_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
577
+
578
+ # Stage 1: A + B
579
+ carry1 = torch.zeros(num_tests, pop_size, device=self.device)
580
+ stage1_bits = []
581
+ for bit in range(bits):
582
+ bit_idx = bits - 1 - bit # LSB first
583
+ s, carry1 = self._eval_single_fa(
584
+ pop, f'{prefix}.stage1.fa{bit}',
585
+ a_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
586
+ b_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
587
+ carry1
588
+ )
589
+ stage1_bits.append(s)
590
+
591
+ # Stage 2: stage1_result + C
592
+ carry2 = torch.zeros(num_tests, pop_size, device=self.device)
593
+ result_bits = []
594
+ for bit in range(bits):
595
+ bit_idx = bits - 1 - bit # LSB first
596
+ s, carry2 = self._eval_single_fa(
597
+ pop, f'{prefix}.stage2.fa{bit}',
598
+ stage1_bits[bit], # Already [num_tests, pop_size]
599
+ c_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
600
+ carry2
601
+ )
602
+ result_bits.append(s)
603
+
604
+ # Reconstruct result (bits are in LSB-first order, need to reverse for MSB-first)
605
+ result_bits = torch.stack(result_bits[::-1], dim=-1) # MSB first
606
+ result = torch.zeros(num_tests, pop_size, device=self.device)
607
+ for i in range(bits):
608
+ result += result_bits[:, :, i] * (1 << (bits - 1 - i))
609
+
610
+ # Expected (8-bit wrap)
611
+ expected = ((a_vals + b_vals + c_vals) & 0xFF).unsqueeze(1).expand(-1, pop_size).float()
612
+ correct = (result == expected).float().sum(0)
613
+
614
+ failures = []
615
+ if pop_size == 1:
616
+ for i in range(min(num_tests, 100)):
617
+ if result[i, 0].item() != expected[i, 0].item():
618
+ failures.append((
619
+ [int(a_vals[i].item()), int(b_vals[i].item()), int(c_vals[i].item())],
620
+ int(expected[i, 0].item()),
621
+ int(result[i, 0].item())
622
+ ))
623
+
624
+ self._record(prefix, int(correct[0].item()), num_tests, failures)
625
+ if debug:
626
+ r = self.results[-1]
627
+ print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
628
+ if failures:
629
+ for inp, exp, got in failures[:5]:
630
+ print(f" FAIL: {inp[0]} + {inp[1]} + {inp[2]} = {exp}, got {got}")
631
+
632
+ return correct, num_tests
633
+
634
  # =========================================================================
635
  # COMPARATORS
636
  # =========================================================================
 
2444
  total_tests += t
2445
  self.category_scores[f'ripplecarry{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
2446
 
2447
+ # 3-operand adder
2448
+ s, t = self._test_add3(population, debug)
2449
+ scores += s
2450
+ total_tests += t
2451
+ self.category_scores['add3'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
2452
+
2453
  # Comparators
2454
  s, t = self._test_comparators(population, debug)
2455
  scores += s
neural_computer.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d8f97c127018647da3a788ee40cbe498ee583d2031bbec04e9347894b1fb5c19
3
- size 34491396
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:270309b1ac57e808827cee555b6f6f9e3f14c37abe23fa21069db4ff251a0b72
3
+ size 34552948