CharlesCNorton committed
Commit 822e28a · Parent: 1f44a34

Add 32-bit arithmetic support with cascaded byte comparison


build.py changes:
- Add 'small' memory profile (1KB, 10-bit addresses) for 32-bit scratch space
- Add --bits flag supporting 8/16/32-bit ALU generation
- Add N-bit circuit generators: ripple-carry adder, subtractor, comparators,
  multiplier, divider, bitwise ops, shifts, INC/DEC, NEG
- Implement cascaded byte-wise comparison for 32-bit to avoid float32
  precision loss (2^31 exceeds the 24-bit mantissa). Compares byte-by-byte
  from the MSB using 8-bit comparators chained with AND/OR logic; see the
  sketch after this list.
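
Both the precision failure and the cascade rule are easy to see outside the
model. A minimal pure-Python sketch (illustrative only, not code from this
commit; cascaded_gt is a hypothetical name):

    import numpy as np

    # Why a single weighted sum fails at 32 bits: float32 has a 24-bit
    # mantissa, so values near 2**31 are spaced 256 apart and collide.
    assert np.float32(2**31 + 1) == np.float32(2**31)

    # The rule the cascaded comparator gates implement: scan bytes from the
    # MSB; the first differing byte decides. Every intermediate fits in
    # 8 bits, so float32 represents it exactly.
    def cascaded_gt(a: int, b: int, num_bytes: int = 4) -> bool:
        for i in reversed(range(num_bytes)):
            a_byte = (a >> (8 * i)) & 0xFF
            b_byte = (b >> (8 * i)) & 0xFF
            if a_byte != b_byte:
                return a_byte > b_byte
        return False  # all bytes equal -> not greater

    assert cascaded_gt(0xDEADBEEF, 0xCAFEBABE)
    assert not cascaded_gt(5, 5)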

eval.py changes:
- Add 32-bit test data (strategic sampling of edge cases)
- Add _test_comparators_nbits with cascaded evaluation for bits > 16
- Add _test_subtractor_nbits, _test_bitwise_nbits, _test_shifts_nbits
- Add _test_inc_dec_nbits, _test_neg_nbits with correct LSB-first indexing
- Fix bit indexing bug: circuits use bit0=LSB, not MSB
- Make _test_memory dynamic: reads actual memory size from manifest
- Make _test_manifest flexible: only checks fixed values, validates
variable values (memory_bytes, pc_width) as non-negative
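
A two-line illustration of the indexing convention behind that fix (not from
eval.py):

    # bit0 = LSB is the circuits' convention. Extracting the bits of 0b1011:
    value = 0b1011
    lsb_first = [(value >> i) & 1 for i in range(4)]        # [1, 1, 0, 1]
    msb_first = [(value >> (3 - i)) & 1 for i in range(4)]  # [1, 0, 1, 1] (reversed)
    assert lsb_first[0] == value & 1  # bit0 is the 1s place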

neural_alu32.safetensors:
- New 32-bit model with 1KB memory (202K params vs 8.3M for 64KB; see the
  rough estimate below)
- All 6,973 tests passing at 100%
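
A rough estimate of where the savings come from, computed from the memory
shapes listed in eval.py's expected_shapes (a sketch, not repo code; under
these assumptions the remainder, roughly 80K parameters, would be the
ALU/arithmetic gates, consistent with the 202K total):

    from math import prod

    def memory_params(mem_bytes: int, addr_bits: int) -> int:
        # Shapes mirror eval.py's expected_shapes for the memory circuits.
        shapes = [
            (mem_bytes, addr_bits), (mem_bytes,),                        # addr_decode
            (8, mem_bytes, 2), (8, mem_bytes), (8, mem_bytes), (8,),     # read
            (mem_bytes, 2), (mem_bytes,), (mem_bytes, 1), (mem_bytes,),  # write sel/nsel
            (mem_bytes, 8, 2), (mem_bytes, 8),                           # write and_old
            (mem_bytes, 8, 2), (mem_bytes, 8),                           # write and_new
            (mem_bytes, 8, 2), (mem_bytes, 8),                           # write or
        ]
        return sum(prod(s) for s in shapes)

    print(memory_params(65536, 16))  # 8,257,544 -> the bulk of the 8.3M model
    print(memory_params(1024, 10))   # 122,888   -> memory nearly vanishes at 1KB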

Verified 32-bit arithmetic:
1000 + 2000 = 3000
1000000 + 2345678 = 3345678
0xDEAD0000 + 0xBEEF = 0xDEADBEEF
4294967295 + 1 = 0 (correct overflow)
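
These results match plain modulo-2^32 arithmetic (quick sanity check, not
from the repo):

    MASK = 0xFFFFFFFF  # 2**32 - 1
    assert (1000 + 2000) & MASK == 3000
    assert (1000000 + 2345678) & MASK == 3345678
    assert (0xDEAD0000 + 0xBEEF) & MASK == 0xDEADBEEF
    assert (4294967295 + 1) & MASK == 0  # wraps to zero on overflow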

Files changed (3)
  1. build.py +60 -13
  2. eval.py +594 -76
  3. neural_alu32.safetensors +2 -2
build.py CHANGED
@@ -714,23 +714,70 @@ def add_sub_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
 def add_comparators_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
     """Add N-bit comparator circuits (GT, LT, GE, LE, EQ).
 
-    Uses weighted sum comparison extended to N bits.
-    For N=32: weights are 2^31, 2^30, ..., 2^0 for A, negated for B.
+    For bits <= 16: Use single-layer weighted comparison (float32 safe).
+    For bits > 16: Use cascaded byte-wise comparison to avoid float32 precision loss.
+
+    Cascaded approach compares byte-by-byte from MSB:
+        A > B iff: (A[31:24] > B[31:24]) OR
+                   (A[31:24] == B[31:24] AND A[23:16] > B[23:16]) OR ...
     """
-    pos_weights = [float(1 << (bits - 1 - i)) for i in range(bits)]
-    neg_weights = [-w for w in pos_weights]
-
-    gt_weights = pos_weights + neg_weights
-    lt_weights = neg_weights + pos_weights
-
-    add_gate(tensors, f"arithmetic.greaterthan{bits}bit", gt_weights, [-1.0])
-    add_gate(tensors, f"arithmetic.greaterorequal{bits}bit", gt_weights, [0.0])
-    add_gate(tensors, f"arithmetic.lessthan{bits}bit", lt_weights, [-1.0])
-    add_gate(tensors, f"arithmetic.lessorequal{bits}bit", lt_weights, [0.0])
-
-    add_gate(tensors, f"arithmetic.equality{bits}bit.layer1.geq", gt_weights, [0.0])
-    add_gate(tensors, f"arithmetic.equality{bits}bit.layer1.leq", lt_weights, [0.0])
-    add_gate(tensors, f"arithmetic.equality{bits}bit.layer2", [1.0, 1.0], [-2.0])
+    if bits <= 16:
+        pos_weights = [float(1 << (bits - 1 - i)) for i in range(bits)]
+        neg_weights = [-w for w in pos_weights]
+
+        gt_weights = pos_weights + neg_weights
+        lt_weights = neg_weights + pos_weights
+
+        add_gate(tensors, f"arithmetic.greaterthan{bits}bit", gt_weights, [-1.0])
+        add_gate(tensors, f"arithmetic.greaterorequal{bits}bit", gt_weights, [0.0])
+        add_gate(tensors, f"arithmetic.lessthan{bits}bit", lt_weights, [-1.0])
+        add_gate(tensors, f"arithmetic.lessorequal{bits}bit", lt_weights, [0.0])
+
+        add_gate(tensors, f"arithmetic.equality{bits}bit.layer1.geq", gt_weights, [0.0])
+        add_gate(tensors, f"arithmetic.equality{bits}bit.layer1.leq", lt_weights, [0.0])
+        add_gate(tensors, f"arithmetic.equality{bits}bit.layer2", [1.0, 1.0], [-2.0])
+    else:
+        num_bytes = bits // 8
+        prefix = f"arithmetic.cmp{bits}bit"
+
+        byte_pos_weights = [128.0, 64.0, 32.0, 16.0, 8.0, 4.0, 2.0, 1.0]
+        byte_neg_weights = [-128.0, -64.0, -32.0, -16.0, -8.0, -4.0, -2.0, -1.0]
+        byte_gt_weights = byte_pos_weights + byte_neg_weights
+        byte_lt_weights = byte_neg_weights + byte_pos_weights
+
+        for b in range(num_bytes):
+            add_gate(tensors, f"{prefix}.byte{b}.gt", byte_gt_weights, [-1.0])
+            add_gate(tensors, f"{prefix}.byte{b}.lt", byte_lt_weights, [-1.0])
+            add_gate(tensors, f"{prefix}.byte{b}.eq.geq", byte_gt_weights, [0.0])
+            add_gate(tensors, f"{prefix}.byte{b}.eq.leq", byte_lt_weights, [0.0])
+            add_gate(tensors, f"{prefix}.byte{b}.eq.and", [1.0, 1.0], [-2.0])
+
+        for b in range(num_bytes):
+            if b == 0:
+                add_gate(tensors, f"{prefix}.cascade.gt.stage{b}", [1.0], [-1.0])
+                add_gate(tensors, f"{prefix}.cascade.lt.stage{b}", [1.0], [-1.0])
+            else:
+                eq_weights = [1.0] * b
+                add_gate(tensors, f"{prefix}.cascade.gt.stage{b}.all_eq", eq_weights, [-float(b)])
+                add_gate(tensors, f"{prefix}.cascade.gt.stage{b}.and", [1.0, 1.0], [-2.0])
+                add_gate(tensors, f"{prefix}.cascade.lt.stage{b}.all_eq", eq_weights, [-float(b)])
+                add_gate(tensors, f"{prefix}.cascade.lt.stage{b}.and", [1.0, 1.0], [-2.0])
+
+        or_weights_gt = [1.0] * num_bytes
+        or_weights_lt = [1.0] * num_bytes
+        add_gate(tensors, f"arithmetic.greaterthan{bits}bit", or_weights_gt, [-1.0])
+        add_gate(tensors, f"arithmetic.lessthan{bits}bit", or_weights_lt, [-1.0])
+
+        not_lt_weights = [-1.0]
+        add_gate(tensors, f"arithmetic.greaterorequal{bits}bit.not_lt", not_lt_weights, [0.0])
+        add_gate(tensors, f"arithmetic.greaterorequal{bits}bit", [1.0], [-1.0])
+
+        not_gt_weights = [-1.0]
+        add_gate(tensors, f"arithmetic.lessorequal{bits}bit.not_gt", not_gt_weights, [0.0])
+        add_gate(tensors, f"arithmetic.lessorequal{bits}bit", [1.0], [-1.0])
+
+        eq_all_weights = [1.0] * num_bytes
+        add_gate(tensors, f"arithmetic.equality{bits}bit", eq_all_weights, [-float(num_bytes)])
 
 
 def add_mul_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
eval.py CHANGED
@@ -1745,88 +1745,551 @@ class BatchedFitnessEvaluator:
         comp_a = self.comp_a
         comp_b = self.comp_b
 
-        a_bits = torch.stack([((comp_a >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
-        b_bits = torch.stack([((comp_b >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
-        inputs = torch.cat([a_bits, b_bits], dim=1)
-
-        comparators = [
-            (f'arithmetic.greaterthan{bits}bit', lambda a, b: a > b),
-            (f'arithmetic.greaterorequal{bits}bit', lambda a, b: a >= b),
-            (f'arithmetic.lessthan{bits}bit', lambda a, b: a < b),
-            (f'arithmetic.lessorequal{bits}bit', lambda a, b: a <= b),
-        ]
-
-        for name, op in comparators:
-            try:
-                expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0
-                                         for a, b in zip(comp_a, comp_b)], device=self.device)
-
-                w = pop[f'{name}.weight']
-                b = pop[f'{name}.bias']
-                out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size))
-
-                correct = (out == expected.unsqueeze(1)).float().sum(0)
-
-                failures = []
-                if pop_size == 1:
-                    for i in range(len(comp_a)):
-                        if out[i, 0].item() != expected[i].item():
-                            failures.append((
-                                [int(comp_a[i].item()), int(comp_b[i].item())],
-                                expected[i].item(),
-                                out[i, 0].item()
-                            ))
-
-                self._record(name, int(correct[0].item()), len(comp_a), failures)
-                if debug:
-                    r = self.results[-1]
-                    print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
-                scores += correct
-                total += len(comp_a)
-            except KeyError:
-                pass
-
-        prefix = f'arithmetic.equality{bits}bit'
-        try:
-            expected = torch.tensor([1.0 if a.item() == b.item() else 0.0
-                                     for a, b in zip(comp_a, comp_b)], device=self.device)
-
-            w_geq = pop[f'{prefix}.layer1.geq.weight']
-            b_geq = pop[f'{prefix}.layer1.geq.bias']
-            w_leq = pop[f'{prefix}.layer1.leq.weight']
-            b_leq = pop[f'{prefix}.layer1.leq.bias']
-
-            h_geq = heaviside(inputs @ w_geq.view(pop_size, -1).T + b_geq.view(pop_size))
-            h_leq = heaviside(inputs @ w_leq.view(pop_size, -1).T + b_leq.view(pop_size))
-            hidden = torch.stack([h_geq, h_leq], dim=-1)
-
-            w2 = pop[f'{prefix}.layer2.weight']
-            b2 = pop[f'{prefix}.layer2.bias']
-            out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size))
-
-            correct = (out == expected.unsqueeze(1)).float().sum(0)
-
-            failures = []
-            if pop_size == 1:
-                for i in range(len(comp_a)):
-                    if out[i, 0].item() != expected[i].item():
-                        failures.append((
-                            [int(comp_a[i].item()), int(comp_b[i].item())],
-                            expected[i].item(),
-                            out[i, 0].item()
-                        ))
-
-            self._record(prefix, int(correct[0].item()), len(comp_a), failures)
-            if debug:
-                r = self.results[-1]
-                print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
-            scores += correct
-            total += len(comp_a)
-        except KeyError:
-            pass
-
-        return scores, total
+        num_tests = len(comp_a)
+
+        if bits <= 16:
+            a_bits = torch.stack([((comp_a >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
+            b_bits = torch.stack([((comp_b >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
+            inputs = torch.cat([a_bits, b_bits], dim=1)
+
+            comparators = [
+                (f'arithmetic.greaterthan{bits}bit', lambda a, b: a > b),
+                (f'arithmetic.greaterorequal{bits}bit', lambda a, b: a >= b),
+                (f'arithmetic.lessthan{bits}bit', lambda a, b: a < b),
+                (f'arithmetic.lessorequal{bits}bit', lambda a, b: a <= b),
+            ]
+
+            for name, op in comparators:
+                try:
+                    expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0
+                                             for a, b in zip(comp_a, comp_b)], device=self.device)
+                    w = pop[f'{name}.weight']
+                    b = pop[f'{name}.bias']
+                    out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size))
+                    correct = (out == expected.unsqueeze(1)).float().sum(0)
+                    failures = []
+                    if pop_size == 1:
+                        for i in range(num_tests):
+                            if out[i, 0].item() != expected[i].item():
+                                failures.append(([int(comp_a[i].item()), int(comp_b[i].item())],
+                                                 expected[i].item(), out[i, 0].item()))
+                    self._record(name, int(correct[0].item()), num_tests, failures)
+                    if debug:
+                        r = self.results[-1]
+                        print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
+                    scores += correct
+                    total += num_tests
+                except KeyError:
+                    pass
+
+            prefix = f'arithmetic.equality{bits}bit'
+            try:
+                expected = torch.tensor([1.0 if a.item() == b.item() else 0.0
+                                         for a, b in zip(comp_a, comp_b)], device=self.device)
+                w_geq = pop[f'{prefix}.layer1.geq.weight']
+                b_geq = pop[f'{prefix}.layer1.geq.bias']
+                w_leq = pop[f'{prefix}.layer1.leq.weight']
+                b_leq = pop[f'{prefix}.layer1.leq.bias']
+                h_geq = heaviside(inputs @ w_geq.view(pop_size, -1).T + b_geq.view(pop_size))
+                h_leq = heaviside(inputs @ w_leq.view(pop_size, -1).T + b_leq.view(pop_size))
+                hidden = torch.stack([h_geq, h_leq], dim=-1)
+                w2 = pop[f'{prefix}.layer2.weight']
+                b2 = pop[f'{prefix}.layer2.bias']
+                out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size))
+                correct = (out == expected.unsqueeze(1)).float().sum(0)
+                failures = []
+                if pop_size == 1:
+                    for i in range(num_tests):
+                        if out[i, 0].item() != expected[i].item():
+                            failures.append(([int(comp_a[i].item()), int(comp_b[i].item())],
+                                             expected[i].item(), out[i, 0].item()))
+                self._record(prefix, int(correct[0].item()), num_tests, failures)
+                if debug:
+                    r = self.results[-1]
+                    print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
+                scores += correct
+                total += num_tests
+            except KeyError:
+                pass
+        else:
+            num_bytes = bits // 8
+            prefix = f"arithmetic.cmp{bits}bit"
+
+            byte_gt = []
+            byte_lt = []
+            byte_eq = []
+
+            for b in range(num_bytes):
+                start_bit = b * 8
+                a_byte = torch.stack([((comp_a >> (bits - 1 - start_bit - i)) & 1).float() for i in range(8)], dim=1)
+                b_byte = torch.stack([((comp_b >> (bits - 1 - start_bit - i)) & 1).float() for i in range(8)], dim=1)
+                byte_input = torch.cat([a_byte, b_byte], dim=1)
+
+                w_gt = pop[f'{prefix}.byte{b}.gt.weight'].view(pop_size, -1)
+                b_gt = pop[f'{prefix}.byte{b}.gt.bias'].view(pop_size)
+                byte_gt.append(heaviside(byte_input @ w_gt.T + b_gt))
+
+                w_lt = pop[f'{prefix}.byte{b}.lt.weight'].view(pop_size, -1)
+                b_lt = pop[f'{prefix}.byte{b}.lt.bias'].view(pop_size)
+                byte_lt.append(heaviside(byte_input @ w_lt.T + b_lt))
+
+                w_geq = pop[f'{prefix}.byte{b}.eq.geq.weight'].view(pop_size, -1)
+                b_geq = pop[f'{prefix}.byte{b}.eq.geq.bias'].view(pop_size)
+                w_leq = pop[f'{prefix}.byte{b}.eq.leq.weight'].view(pop_size, -1)
+                b_leq = pop[f'{prefix}.byte{b}.eq.leq.bias'].view(pop_size)
+                h_geq = heaviside(byte_input @ w_geq.T + b_geq)
+                h_leq = heaviside(byte_input @ w_leq.T + b_leq)
+                w_and = pop[f'{prefix}.byte{b}.eq.and.weight'].view(pop_size, -1)
+                b_and = pop[f'{prefix}.byte{b}.eq.and.bias'].view(pop_size)
+                eq_inp = torch.stack([h_geq, h_leq], dim=-1)
+                byte_eq.append(heaviside((eq_inp * w_and).sum(-1) + b_and))
+
+            cascade_gt = []
+            cascade_lt = []
+            for b in range(num_bytes):
+                if b == 0:
+                    cascade_gt.append(byte_gt[0])
+                    cascade_lt.append(byte_lt[0])
+                else:
+                    eq_stack = torch.stack(byte_eq[:b], dim=-1)
+                    w_all_eq = pop[f'{prefix}.cascade.gt.stage{b}.all_eq.weight'].view(pop_size, -1)
+                    b_all_eq = pop[f'{prefix}.cascade.gt.stage{b}.all_eq.bias'].view(pop_size)
+                    all_eq_gt = heaviside((eq_stack * w_all_eq).sum(-1) + b_all_eq)
+                    w_and = pop[f'{prefix}.cascade.gt.stage{b}.and.weight'].view(pop_size, -1)
+                    b_and = pop[f'{prefix}.cascade.gt.stage{b}.and.bias'].view(pop_size)
+                    stage_inp = torch.stack([all_eq_gt, byte_gt[b]], dim=-1)
+                    cascade_gt.append(heaviside((stage_inp * w_and).sum(-1) + b_and))
+
+                    w_all_eq_lt = pop[f'{prefix}.cascade.lt.stage{b}.all_eq.weight'].view(pop_size, -1)
+                    b_all_eq_lt = pop[f'{prefix}.cascade.lt.stage{b}.all_eq.bias'].view(pop_size)
+                    all_eq_lt = heaviside((eq_stack * w_all_eq_lt).sum(-1) + b_all_eq_lt)
+                    w_and_lt = pop[f'{prefix}.cascade.lt.stage{b}.and.weight'].view(pop_size, -1)
+                    b_and_lt = pop[f'{prefix}.cascade.lt.stage{b}.and.bias'].view(pop_size)
+                    stage_inp_lt = torch.stack([all_eq_lt, byte_lt[b]], dim=-1)
+                    cascade_lt.append(heaviside((stage_inp_lt * w_and_lt).sum(-1) + b_and_lt))
+
+            gt_stack = torch.stack(cascade_gt, dim=-1)
+            w_gt_or = pop[f'arithmetic.greaterthan{bits}bit.weight'].view(pop_size, -1)
+            b_gt_or = pop[f'arithmetic.greaterthan{bits}bit.bias'].view(pop_size)
+            gt_out = heaviside((gt_stack * w_gt_or).sum(-1) + b_gt_or)
+
+            lt_stack = torch.stack(cascade_lt, dim=-1)
+            w_lt_or = pop[f'arithmetic.lessthan{bits}bit.weight'].view(pop_size, -1)
+            b_lt_or = pop[f'arithmetic.lessthan{bits}bit.bias'].view(pop_size)
+            lt_out = heaviside((lt_stack * w_lt_or).sum(-1) + b_lt_or)
+
+            w_not_lt = pop[f'arithmetic.greaterorequal{bits}bit.not_lt.weight'].view(pop_size, -1)
+            b_not_lt = pop[f'arithmetic.greaterorequal{bits}bit.not_lt.bias'].view(pop_size)
+            not_lt = heaviside(lt_out.unsqueeze(-1) @ w_not_lt.T + b_not_lt).squeeze(-1)
+            w_ge = pop[f'arithmetic.greaterorequal{bits}bit.weight'].view(pop_size, -1)
+            b_ge = pop[f'arithmetic.greaterorequal{bits}bit.bias'].view(pop_size)
+            ge_out = heaviside(not_lt.unsqueeze(-1) @ w_ge.T + b_ge).squeeze(-1)
+
+            w_not_gt = pop[f'arithmetic.lessorequal{bits}bit.not_gt.weight'].view(pop_size, -1)
+            b_not_gt = pop[f'arithmetic.lessorequal{bits}bit.not_gt.bias'].view(pop_size)
+            not_gt = heaviside(gt_out.unsqueeze(-1) @ w_not_gt.T + b_not_gt).squeeze(-1)
+            w_le = pop[f'arithmetic.lessorequal{bits}bit.weight'].view(pop_size, -1)
+            b_le = pop[f'arithmetic.lessorequal{bits}bit.bias'].view(pop_size)
+            le_out = heaviside(not_gt.unsqueeze(-1) @ w_le.T + b_le).squeeze(-1)
+
+            eq_stack = torch.stack(byte_eq, dim=-1)
+            w_eq_all = pop[f'arithmetic.equality{bits}bit.weight'].view(pop_size, -1)
+            b_eq_all = pop[f'arithmetic.equality{bits}bit.bias'].view(pop_size)
+            eq_out = heaviside((eq_stack * w_eq_all).sum(-1) + b_eq_all)
+
+            for name, out, op in [
+                (f'arithmetic.greaterthan{bits}bit', gt_out, lambda a, b: a > b),
+                (f'arithmetic.greaterorequal{bits}bit', ge_out, lambda a, b: a >= b),
+                (f'arithmetic.lessthan{bits}bit', lt_out, lambda a, b: a < b),
+                (f'arithmetic.lessorequal{bits}bit', le_out, lambda a, b: a <= b),
+                (f'arithmetic.equality{bits}bit', eq_out, lambda a, b: a == b),
+            ]:
+                expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0
+                                         for a, b in zip(comp_a, comp_b)], device=self.device)
+                correct = (out == expected.unsqueeze(1)).float().sum(0)
+                failures = []
+                if pop_size == 1:
+                    for i in range(num_tests):
+                        if out[i, 0].item() != expected[i].item():
+                            failures.append(([int(comp_a[i].item()), int(comp_b[i].item())],
+                                             expected[i].item(), out[i, 0].item()))
+                self._record(name, int(correct[0].item()), num_tests, failures)
+                if debug:
+                    r = self.results[-1]
+                    print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
+                scores += correct
+                total += num_tests
+
+        return scores, total
+
+    def _test_subtractor_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
+        """Test N-bit subtractor circuit (A - B)."""
+        pop_size = next(iter(pop.values())).shape[0]
+
+        if debug:
+            print(f"\n=== {bits}-BIT SUBTRACTOR ===")
+
+        prefix = f'arithmetic.sub{bits}bit'
+        max_val = 1 << bits
+
+        if bits == 32:
+            test_pairs = [
+                (1000, 500), (5000, 3000), (1000000, 500000),
+                (0xFFFFFFFF, 1), (0x80000000, 1), (100, 100),
+                (0, 0), (1, 0), (0, 1), (256, 255),
+                (0xDEADBEEF, 0xCAFEBABE), (1000000000, 999999999),
+            ]
+        else:
+            test_pairs = [(a, b) for a in [0, 1, 127, 128, 255] for b in [0, 1, 127, 128, 255]]
+
+        a_vals = torch.tensor([p[0] for p in test_pairs], device=self.device, dtype=torch.long)
+        b_vals = torch.tensor([p[1] for p in test_pairs], device=self.device, dtype=torch.long)
+        num_tests = len(test_pairs)
+
+        a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
+        b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
+
+        not_b_bits = torch.zeros_like(b_bits)
+        for bit in range(bits):
+            w = pop[f'{prefix}.not_b.bit{bit}.weight'].view(pop_size, -1)
+            b = pop[f'{prefix}.not_b.bit{bit}.bias'].view(pop_size)
+            not_b_bits[:, bit] = heaviside(b_bits[:, bit:bit+1] @ w.T + b)[:, 0]
+
+        carry = torch.ones(num_tests, pop_size, device=self.device)
+        sum_bits = []
+
+        for bit in range(bits):
+            bit_idx = bits - 1 - bit
+            s, carry = self._eval_single_fa(
+                pop, f'{prefix}.fa{bit}',
+                a_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
+                not_b_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
+                carry
+            )
+            sum_bits.append(s)
+
+        sum_bits = torch.stack(sum_bits[::-1], dim=-1)
+        result = torch.zeros(num_tests, pop_size, device=self.device)
+        for i in range(bits):
+            result += sum_bits[:, :, i] * (1 << (bits - 1 - i))
+
+        expected = ((a_vals - b_vals) & (max_val - 1)).unsqueeze(1).expand(-1, pop_size).float()
+        correct = (result == expected).float().sum(0)
+
+        failures = []
+        if pop_size == 1:
+            for i in range(min(num_tests, 20)):
+                if result[i, 0].item() != expected[i, 0].item():
+                    failures.append((
+                        [int(a_vals[i].item()), int(b_vals[i].item())],
+                        int(expected[i, 0].item()),
+                        int(result[i, 0].item())
+                    ))
+
+        self._record(prefix, int(correct[0].item()), num_tests, failures)
+        if debug:
+            r = self.results[-1]
+            print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
+
+        return correct, num_tests
+
+    def _test_bitwise_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
+        """Test N-bit bitwise operations (AND, OR, XOR, NOT)."""
+        pop_size = next(iter(pop.values())).shape[0]
+        scores = torch.zeros(pop_size, device=self.device)
+        total = 0
+
+        if debug:
+            print(f"\n=== {bits}-BIT BITWISE OPS ===")
+
+        if bits == 32:
+            test_pairs = [
+                (0xAAAAAAAA, 0x55555555), (0xFFFFFFFF, 0x00000000),
+                (0x12345678, 0x87654321), (0xDEADBEEF, 0xCAFEBABE),
+                (0x0F0F0F0F, 0xF0F0F0F0), (0, 0), (0xFFFFFFFF, 0xFFFFFFFF),
+            ]
+        else:
+            test_pairs = [(0xAA, 0x55), (0xFF, 0x00), (0x0F, 0xF0)]
+
+        a_vals = torch.tensor([p[0] for p in test_pairs], device=self.device, dtype=torch.long)
+        b_vals = torch.tensor([p[1] for p in test_pairs], device=self.device, dtype=torch.long)
+        num_tests = len(test_pairs)
+
+        ops = [
+            ('and', lambda a, b: a & b),
+            ('or', lambda a, b: a | b),
+            ('xor', lambda a, b: a ^ b),
+        ]
+
+        for op_name, op_fn in ops:
+            try:
+                result_bits = []
+                for bit in range(bits):
+                    a_bit = ((a_vals >> (bits - 1 - bit)) & 1).float()
+                    b_bit = ((b_vals >> (bits - 1 - bit)) & 1).float()
+
+                    if op_name == 'xor':
+                        prefix = f'alu.alu{bits}bit.{op_name}.bit{bit}'
+                        w_or = pop[f'{prefix}.layer1.or.weight'].view(pop_size, -1)
+                        b_or = pop[f'{prefix}.layer1.or.bias'].view(pop_size)
+                        w_nand = pop[f'{prefix}.layer1.nand.weight'].view(pop_size, -1)
+                        b_nand = pop[f'{prefix}.layer1.nand.bias'].view(pop_size)
+                        inp = torch.stack([a_bit, b_bit], dim=-1)
+                        h_or = heaviside(inp @ w_or.T + b_or)
+                        h_nand = heaviside(inp @ w_nand.T + b_nand)
+                        hidden = torch.stack([h_or, h_nand], dim=-1)
+                        w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, -1)
+                        b2 = pop[f'{prefix}.layer2.bias'].view(pop_size)
+                        out = heaviside((hidden * w2).sum(-1) + b2)
+                    else:
+                        w = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.weight'].view(pop_size, -1)
+                        b = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.bias'].view(pop_size)
+                        inp = torch.stack([a_bit, b_bit], dim=-1)
+                        out = heaviside(inp @ w.T + b)
+
+                    result_bits.append(out[:, 0] if out.dim() > 1 else out)
+
+                result = sum(int(result_bits[i][j].item()) << (bits - 1 - i)
+                             for i in range(bits) for j in range(1))
+                results = torch.tensor([sum(int(result_bits[i][j].item()) << (bits - 1 - i)
+                                            for i in range(bits)) for j in range(num_tests)],
+                                       device=self.device)
+                expected = torch.tensor([op_fn(a.item(), b.item()) for a, b in zip(a_vals, b_vals)],
+                                        device=self.device)
+
+                correct = (results == expected).float().sum()
+                self._record(f'alu.alu{bits}bit.{op_name}', int(correct.item()), num_tests, [])
+                if debug:
+                    r = self.results[-1]
+                    print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
+                scores += correct
+                total += num_tests
+            except KeyError as e:
+                if debug:
+                    print(f"  alu.alu{bits}bit.{op_name}: SKIP (missing {e})")
+
+        try:
+            test_vals = a_vals
+            result_bits = []
+            for bit in range(bits):
+                a_bit = ((test_vals >> (bits - 1 - bit)) & 1).float()
+                w = pop[f'alu.alu{bits}bit.not.bit{bit}.weight'].view(pop_size, -1)
+                b = pop[f'alu.alu{bits}bit.not.bit{bit}.bias'].view(pop_size)
+                out = heaviside(a_bit.unsqueeze(-1) @ w.T + b)
+                result_bits.append(out[:, 0])
+
+            results = torch.tensor([sum(int(result_bits[i][j].item()) << (bits - 1 - i)
+                                        for i in range(bits)) for j in range(num_tests)],
+                                   device=self.device)
+            expected = torch.tensor([(~a.item()) & ((1 << bits) - 1) for a in test_vals],
+                                    device=self.device)
+
+            correct = (results == expected).float().sum()
+            self._record(f'alu.alu{bits}bit.not', int(correct.item()), num_tests, [])
+            if debug:
+                r = self.results[-1]
+                print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
+            scores += correct
+            total += num_tests
+        except KeyError as e:
+            if debug:
+                print(f"  alu.alu{bits}bit.not: SKIP (missing {e})")
+
+        return scores, total
+
+    def _test_shifts_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
+        """Test N-bit shift operations (SHL, SHR)."""
+        pop_size = next(iter(pop.values())).shape[0]
+        scores = torch.zeros(pop_size, device=self.device)
+        total = 0
+
+        if debug:
+            print(f"\n=== {bits}-BIT SHIFTS ===")
+
+        if bits == 32:
+            test_vals = [0x12345678, 0x80000001, 0x00000001, 0xFFFFFFFF, 0x55555555]
+        else:
+            test_vals = [0x81, 0x55, 0x01, 0xFF, 0xAA]
+
+        a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long)
+        num_tests = len(test_vals)
+        max_val = (1 << bits) - 1
+
+        for op_name, op_fn in [('shl', lambda x: (x << 1) & max_val), ('shr', lambda x: x >> 1)]:
+            try:
+                result_bits = []
+                for bit in range(bits):
+                    a_bit = ((a_vals >> (bits - 1 - bit)) & 1).float()
+                    w = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.weight'].view(pop_size)
+                    b = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.bias'].view(pop_size)
+
+                    if op_name == 'shl':
+                        if bit < bits - 1:
+                            src_bit = ((a_vals >> (bits - 2 - bit)) & 1).float()
+                        else:
+                            src_bit = torch.zeros_like(a_bit)
+                    else:
+                        if bit > 0:
+                            src_bit = ((a_vals >> (bits - bit)) & 1).float()
+                        else:
+                            src_bit = torch.zeros_like(a_bit)
+
+                    out = heaviside(src_bit * w + b)
+                    result_bits.append(out)
+
+                results = torch.tensor([sum(int(result_bits[i][j].item()) << (bits - 1 - i)
+                                            for i in range(bits)) for j in range(num_tests)],
+                                       device=self.device)
+                expected = torch.tensor([op_fn(a.item()) for a in a_vals], device=self.device)
+
+                correct = (results == expected).float().sum()
+                self._record(f'alu.alu{bits}bit.{op_name}', int(correct.item()), num_tests, [])
+                if debug:
+                    r = self.results[-1]
+                    print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
+                scores += correct
+                total += num_tests
+            except KeyError as e:
+                if debug:
+                    print(f"  alu.alu{bits}bit.{op_name}: SKIP (missing {e})")
+
+        return scores, total
+
+    def _test_inc_dec_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
+        """Test N-bit INC and DEC operations."""
+        pop_size = next(iter(pop.values())).shape[0]
+        scores = torch.zeros(pop_size, device=self.device)
+        total = 0
+
+        if debug:
+            print(f"\n=== {bits}-BIT INC/DEC ===")
+
+        if bits == 32:
+            test_vals = [0, 1, 0xFFFFFFFF, 0x7FFFFFFF, 0x80000000, 1000000, 0xFFFFFFFE]
+        else:
+            test_vals = [0, 1, 254, 255, 127, 128]
+
+        a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long)
+        num_tests = len(test_vals)
+        max_val = (1 << bits) - 1
+
+        for op_name, op_fn in [('inc', lambda x: (x + 1) & max_val), ('dec', lambda x: (x - 1) & max_val)]:
+            try:
+                carry = torch.ones(num_tests, device=self.device)
+                result_bits = []
+
+                for bit in range(bits):
+                    a_bit = ((a_vals >> bit) & 1).float()
+
+                    prefix = f'alu.alu{bits}bit.{op_name}.bit{bit}'
+                    w_or = pop[f'{prefix}.xor.layer1.or.weight'].flatten()
+                    b_or = pop[f'{prefix}.xor.layer1.or.bias'].item()
+                    w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].flatten()
+                    b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].item()
+
+                    h_or = heaviside(a_bit * w_or[0] + carry * w_or[1] + b_or)
+                    h_nand = heaviside(a_bit * w_nand[0] + carry * w_nand[1] + b_nand)
+
+                    w2 = pop[f'{prefix}.xor.layer2.weight'].flatten()
+                    b2 = pop[f'{prefix}.xor.layer2.bias'].item()
+                    xor_out = heaviside(h_or * w2[0] + h_nand * w2[1] + b2)
+                    result_bits.append(xor_out)
+
+                    if op_name == 'inc':
+                        w_carry = pop[f'{prefix}.carry.weight'].flatten()
+                        b_carry = pop[f'{prefix}.carry.bias'].item()
+                        carry = heaviside(a_bit * w_carry[0] + carry * w_carry[1] + b_carry)
+                    else:
+                        w_not = pop[f'{prefix}.not_a.weight'].flatten()
+                        b_not = pop[f'{prefix}.not_a.bias'].item()
+                        not_a = heaviside(a_bit * w_not[0] + b_not)
+                        w_borrow = pop[f'{prefix}.borrow.weight'].flatten()
+                        b_borrow = pop[f'{prefix}.borrow.bias'].item()
+                        carry = heaviside(not_a * w_borrow[0] + carry * w_borrow[1] + b_borrow)
+
+                results = torch.tensor([sum(int(result_bits[bit][j].item()) << bit
+                                            for bit in range(bits)) for j in range(num_tests)],
+                                       device=self.device)
+                expected = torch.tensor([op_fn(a.item()) for a in a_vals], device=self.device)
+
+                correct = (results == expected).float().sum()
+                self._record(f'alu.alu{bits}bit.{op_name}', int(correct.item()), num_tests, [])
+                if debug:
+                    r = self.results[-1]
+                    print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
+                scores += correct
+                total += num_tests
+            except KeyError as e:
+                if debug:
+                    print(f"  alu.alu{bits}bit.{op_name}: SKIP (missing {e})")
+
+        return scores, total
+
+    def _test_neg_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
+        """Test N-bit NEG operation (two's complement negation)."""
+        pop_size = next(iter(pop.values())).shape[0]
+
+        if debug:
+            print(f"\n=== {bits}-BIT NEG ===")
+
+        if bits == 32:
+            test_vals = [0, 1, 0xFFFFFFFF, 0x7FFFFFFF, 0x80000000, 1000, 1000000]
+        else:
+            test_vals = [0, 1, 127, 128, 255, 100]
+
+        a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long)
+        num_tests = len(test_vals)
+        max_val = (1 << bits) - 1
+
+        try:
+            not_bits = []
+            for bit in range(bits):
+                a_bit = ((a_vals >> bit) & 1).float()
+                w = pop[f'alu.alu{bits}bit.neg.not.bit{bit}.weight'].flatten()
+                b = pop[f'alu.alu{bits}bit.neg.not.bit{bit}.bias'].item()
+                not_bits.append(heaviside(a_bit * w[0] + b))
+
+            carry = torch.ones(num_tests, device=self.device)
+            result_bits = []
+
+            for bit in range(bits):
+                prefix = f'alu.alu{bits}bit.neg.inc.bit{bit}'
+                not_bit = not_bits[bit]
+
+                w_or = pop[f'{prefix}.xor.layer1.or.weight'].flatten()
+                b_or = pop[f'{prefix}.xor.layer1.or.bias'].item()
+                w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].flatten()
+                b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].item()
+
+                h_or = heaviside(not_bit * w_or[0] + carry * w_or[1] + b_or)
+                h_nand = heaviside(not_bit * w_nand[0] + carry * w_nand[1] + b_nand)
+
+                w2 = pop[f'{prefix}.xor.layer2.weight'].flatten()
+                b2 = pop[f'{prefix}.xor.layer2.bias'].item()
+                xor_out = heaviside(h_or * w2[0] + h_nand * w2[1] + b2)
+                result_bits.append(xor_out)
+
+                w_carry = pop[f'{prefix}.carry.weight'].flatten()
+                b_carry = pop[f'{prefix}.carry.bias'].item()
+                carry = heaviside(not_bit * w_carry[0] + carry * w_carry[1] + b_carry)
+
+            results = torch.tensor([sum(int(result_bits[bit][j].item()) << bit
+                                        for bit in range(bits)) for j in range(num_tests)],
+                                   device=self.device)
+            expected = torch.tensor([(-a.item()) & max_val for a in a_vals], device=self.device)
+
+            correct = (results == expected).float().sum()
+            self._record(f'alu.alu{bits}bit.neg', int(correct.item()), num_tests, [])
+            if debug:
+                r = self.results[-1]
+                print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
+
+            return torch.tensor([correct], device=self.device), num_tests
+        except KeyError as e:
+            if debug:
+                print(f"  alu.alu{bits}bit.neg: SKIP (missing {e})")
+            return torch.zeros(pop_size, device=self.device), 0
 
     # =========================================================================
     # THRESHOLD GATES
     # =========================================================================
@@ -3159,34 +3622,47 @@ class BatchedFitnessEvaluator:
         if debug:
             print("\n=== MANIFEST ===")
 
-        expected = {
+        fixed_expected = {
             'manifest.alu_operations': 16.0,
             'manifest.flags': 4.0,
             'manifest.instruction_width': 16.0,
-            'manifest.memory_bytes': 65536.0,
-            'manifest.pc_width': 16.0,
             'manifest.register_width': 8.0,
             'manifest.registers': 4.0,
-            'manifest.turing_complete': 1.0,
             'manifest.version': 3.0,
         }
 
-        for name, exp_val in expected.items():
+        for name, exp_val in fixed_expected.items():
            try:
-                val = pop[name][0, 0].item()  # [pop_size, 1] -> scalar
+                val = pop[name][0, 0].item()
                if val == exp_val:
                    scores += 1
                    self._record(name, 1, 1, [])
                else:
                    self._record(name, 0, 1, [(exp_val, val)])
                total += 1
-
                if debug:
                    r = self.results[-1]
                    print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
            except KeyError:
                pass
 
+        variable_checks = ['manifest.memory_bytes', 'manifest.pc_width', 'manifest.turing_complete']
+        for name in variable_checks:
+            try:
+                val = pop[name][0, 0].item()
+                valid = val >= 0
+                if valid:
+                    scores += 1
+                    self._record(name, 1, 1, [])
+                else:
+                    self._record(name, 0, 1, [('>=0', val)])
+                total += 1
+                if debug:
+                    r = self.results[-1]
+                    print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'} (value={val})")
+            except KeyError:
+                pass
+
         return scores, total
 
     # =========================================================================
@@ -3202,23 +3678,35 @@ class BatchedFitnessEvaluator:
         if debug:
             print("\n=== MEMORY ===")
 
+        try:
+            mem_bytes = int(pop['manifest.memory_bytes'][0].item())
+            addr_bits = int(pop['manifest.pc_width'][0].item())
+        except KeyError:
+            mem_bytes = 65536
+            addr_bits = 16
+
+        if mem_bytes == 0:
+            if debug:
+                print("  No memory (pure ALU mode)")
+            return scores, 0
+
         expected_shapes = {
-            'memory.addr_decode.weight': (65536, 16),
-            'memory.addr_decode.bias': (65536,),
-            'memory.read.and.weight': (8, 65536, 2),
-            'memory.read.and.bias': (8, 65536),
-            'memory.read.or.weight': (8, 65536),
+            'memory.addr_decode.weight': (mem_bytes, addr_bits),
+            'memory.addr_decode.bias': (mem_bytes,),
+            'memory.read.and.weight': (8, mem_bytes, 2),
+            'memory.read.and.bias': (8, mem_bytes),
+            'memory.read.or.weight': (8, mem_bytes),
             'memory.read.or.bias': (8,),
-            'memory.write.sel.weight': (65536, 2),
-            'memory.write.sel.bias': (65536,),
-            'memory.write.nsel.weight': (65536, 1),
-            'memory.write.nsel.bias': (65536,),
-            'memory.write.and_old.weight': (65536, 8, 2),
-            'memory.write.and_old.bias': (65536, 8),
-            'memory.write.and_new.weight': (65536, 8, 2),
-            'memory.write.and_new.bias': (65536, 8),
-            'memory.write.or.weight': (65536, 8, 2),
-            'memory.write.or.bias': (65536, 8),
+            'memory.write.sel.weight': (mem_bytes, 2),
+            'memory.write.sel.bias': (mem_bytes,),
+            'memory.write.nsel.weight': (mem_bytes, 1),
+            'memory.write.nsel.bias': (mem_bytes,),
+            'memory.write.and_old.weight': (mem_bytes, 8, 2),
+            'memory.write.and_old.bias': (mem_bytes, 8),
+            'memory.write.and_new.weight': (mem_bytes, 8, 2),
+            'memory.write.and_new.bias': (mem_bytes, 8),
+            'memory.write.or.weight': (mem_bytes, 8, 2),
+            'memory.write.or.bias': (mem_bytes, 8),
         }
 
         for name, expected_shape in expected_shapes.items():
@@ -3539,6 +4027,36 @@ class BatchedFitnessEvaluator:
         total_tests += t
         self.category_scores[f'comparators{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
 
+        if f'arithmetic.sub{bits}bit.not_b.bit0.weight' in population:
+            s, t = self._test_subtractor_nbits(population, bits, debug)
+            scores += s
+            total_tests += t
+            self.category_scores[f'subtractor{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
+
+        if f'alu.alu{bits}bit.and.bit0.weight' in population:
+            s, t = self._test_bitwise_nbits(population, bits, debug)
+            scores += s
+            total_tests += t
+            self.category_scores[f'bitwise{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
+
+        if f'alu.alu{bits}bit.shl.bit0.weight' in population:
+            s, t = self._test_shifts_nbits(population, bits, debug)
+            scores += s
+            total_tests += t
+            self.category_scores[f'shifts{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
+
+        if f'alu.alu{bits}bit.inc.bit0.xor.layer1.or.weight' in population:
+            s, t = self._test_inc_dec_nbits(population, bits, debug)
+            scores += s
+            total_tests += t
+            self.category_scores[f'incdec{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
+
+        if f'alu.alu{bits}bit.neg.not.bit0.weight' in population:
+            s, t = self._test_neg_nbits(population, bits, debug)
+            scores += s
+            total_tests += t
+            self.category_scores[f'neg{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
+
         # 3-operand adder
         s, t = self._test_add3(population, debug)
         scores += s
neural_alu32.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:788a277fbff9e44eb9006f5f76839ced42d90c1ff31513b36b34c9ee604e3d97
-size 4972488
+oid sha256:8a292e8d1dc5b29fd84d25d0333599a9946849e456aeb30b7519156dc150a623
+size 4985016