Add 32-bit arithmetic support with cascaded byte comparison
build.py changes:
- Add 'small' memory profile (1KB, 10-bit addresses) for 32-bit scratch space
- Add --bits flag supporting 8/16/32-bit ALU generation
- Add N-bit circuit generators: ripple carry adder, subtractor, comparators,
multiplier, divider, bitwise ops, shifts, INC/DEC, NEG
- Implement cascaded byte-wise comparison for 32-bit to avoid float32
precision loss (2^31 exceeds 24-bit mantissa). Compares byte-by-byte
from MSB using 8-bit comparators chained with AND/OR logic.
eval.py changes:
- Add 32-bit test data (strategic sampling of edge cases)
- Add _test_comparators_nbits with cascaded evaluation for bits > 16
- Add _test_subtractor_nbits, _test_bitwise_nbits, _test_shifts_nbits
- Add _test_inc_dec_nbits, _test_neg_nbits with correct LSB-first indexing
- Fix bit indexing bug: circuits use bit0=LSB, not MSB
- Make _test_memory dynamic: reads actual memory size from manifest
- Make _test_manifest flexible: only checks fixed values, validates
variable values (memory_bytes, pc_width) as non-negative
neural_alu32.safetensors:
- New 32-bit model with 1KB memory (202K params vs 8.3M for 64KB)
- All 6,973 tests passing at 100%
Verified 32-bit arithmetic:
1000 + 2000 = 3000
1000000 + 2345678 = 3345678
0xDEAD0000 + 0xBEEF = 0xDEADBEEF
4294967295 + 1 = 0 (correct overflow)
- build.py +60 -13
- eval.py +594 -76
- neural_alu32.safetensors +2 -2
|
@@ -714,23 +714,70 @@ def add_sub_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
|
|
| 714 |
def add_comparators_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
|
| 715 |
"""Add N-bit comparator circuits (GT, LT, GE, LE, EQ).
|
| 716 |
|
| 717 |
-
|
| 718 |
-
For
|
|
|
|
|
|
|
|
|
|
|
|
|
| 719 |
"""
|
| 720 |
-
|
| 721 |
-
|
|
|
|
| 722 |
|
| 723 |
-
|
| 724 |
-
|
| 725 |
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
|
| 730 |
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 734 |
|
| 735 |
|
| 736 |
def add_mul_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
|
|
|
|
| 714 |
def add_comparators_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
    """Add N-bit comparator circuits (GT, LT, GE, LE, EQ).

    For bits <= 16 a single-layer weighted comparison is used: each bit of A
    gets weight +2^(bits-1-i) and the matching bit of B the negated weight,
    so the sign of the weighted sum decides the ordering.  This is safe in
    float32 because every weight fits in the 24-bit mantissa.

    For bits > 16 a single weighted sum would lose precision (e.g. 2^31
    exceeds the float32 mantissa), so the comparison is cascaded byte-wise
    from the MSB using 8-bit comparators chained with AND/OR logic:

        A > B  iff  (A[31:24] > B[31:24]) OR
                    (A[31:24] == B[31:24] AND A[23:16] > B[23:16]) OR ...
    """
    if bits <= 16:
        # Positional weights for A's bits; B's bits get the negated weights.
        pos_weights = [float(1 << (bits - 1 - i)) for i in range(bits)]
        neg_weights = [-w for w in pos_weights]

        gt_weights = pos_weights + neg_weights
        lt_weights = neg_weights + pos_weights

        # Threshold -1 requires a strictly positive difference (strict
        # compare); threshold 0 also accepts equality (non-strict compare).
        add_gate(tensors, f"arithmetic.greaterthan{bits}bit", gt_weights, [-1.0])
        add_gate(tensors, f"arithmetic.greaterorequal{bits}bit", gt_weights, [0.0])
        add_gate(tensors, f"arithmetic.lessthan{bits}bit", lt_weights, [-1.0])
        add_gate(tensors, f"arithmetic.lessorequal{bits}bit", lt_weights, [0.0])

        # Equality is AND(A >= B, A <= B), realised as a two-layer gate.
        add_gate(tensors, f"arithmetic.equality{bits}bit.layer1.geq", gt_weights, [0.0])
        add_gate(tensors, f"arithmetic.equality{bits}bit.layer1.leq", lt_weights, [0.0])
        add_gate(tensors, f"arithmetic.equality{bits}bit.layer2", [1.0, 1.0], [-2.0])
    else:
        num_bytes = bits // 8
        prefix = f"arithmetic.cmp{bits}bit"

        # 8-bit comparator weights, shared by every byte lane.
        byte_pos_weights = [float(1 << (7 - i)) for i in range(8)]
        byte_neg_weights = [-w for w in byte_pos_weights]
        byte_gt_weights = byte_pos_weights + byte_neg_weights
        byte_lt_weights = byte_neg_weights + byte_pos_weights

        # Per-byte GT / LT / EQ primitives (byte 0 = most significant byte).
        for b in range(num_bytes):
            add_gate(tensors, f"{prefix}.byte{b}.gt", byte_gt_weights, [-1.0])
            add_gate(tensors, f"{prefix}.byte{b}.lt", byte_lt_weights, [-1.0])
            add_gate(tensors, f"{prefix}.byte{b}.eq.geq", byte_gt_weights, [0.0])
            add_gate(tensors, f"{prefix}.byte{b}.eq.leq", byte_lt_weights, [0.0])
            add_gate(tensors, f"{prefix}.byte{b}.eq.and", [1.0, 1.0], [-2.0])

        # Cascade stages: stage b fires when all more-significant bytes are
        # equal AND byte b itself decides the comparison.  Stage 0 is just a
        # pass-through of byte 0's result.
        for b in range(num_bytes):
            if b == 0:
                add_gate(tensors, f"{prefix}.cascade.gt.stage{b}", [1.0], [-1.0])
                add_gate(tensors, f"{prefix}.cascade.lt.stage{b}", [1.0], [-1.0])
            else:
                eq_weights = [1.0] * b
                add_gate(tensors, f"{prefix}.cascade.gt.stage{b}.all_eq", eq_weights, [-float(b)])
                add_gate(tensors, f"{prefix}.cascade.gt.stage{b}.and", [1.0, 1.0], [-2.0])
                add_gate(tensors, f"{prefix}.cascade.lt.stage{b}.all_eq", eq_weights, [-float(b)])
                add_gate(tensors, f"{prefix}.cascade.lt.stage{b}.and", [1.0, 1.0], [-2.0])

        # Final GT / LT: OR across all cascade stages (at most one fires).
        add_gate(tensors, f"arithmetic.greaterthan{bits}bit", [1.0] * num_bytes, [-1.0])
        add_gate(tensors, f"arithmetic.lessthan{bits}bit", [1.0] * num_bytes, [-1.0])

        # GE = NOT(LT) and LE = NOT(GT), each a NOT gate followed by an
        # identity gate to keep the two-tensor layout used elsewhere.
        add_gate(tensors, f"arithmetic.greaterorequal{bits}bit.not_lt", [-1.0], [0.0])
        add_gate(tensors, f"arithmetic.greaterorequal{bits}bit", [1.0], [-1.0])

        add_gate(tensors, f"arithmetic.lessorequal{bits}bit.not_gt", [-1.0], [0.0])
        add_gate(tensors, f"arithmetic.lessorequal{bits}bit", [1.0], [-1.0])

        # EQ = AND over the per-byte equality outputs.
        add_gate(tensors, f"arithmetic.equality{bits}bit", [1.0] * num_bytes, [-float(num_bytes)])
|
| 781 |
|
| 782 |
|
| 783 |
def add_mul_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
|
|
@@ -1745,88 +1745,551 @@ class BatchedFitnessEvaluator:
|
|
| 1745 |
comp_a = self.comp_a
|
| 1746 |
comp_b = self.comp_b
|
| 1747 |
|
| 1748 |
-
|
| 1749 |
-
b_bits = torch.stack([((comp_b >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
|
| 1750 |
-
inputs = torch.cat([a_bits, b_bits], dim=1)
|
| 1751 |
|
| 1752 |
-
|
| 1753 |
-
(
|
| 1754 |
-
(
|
| 1755 |
-
(
|
| 1756 |
-
(f'arithmetic.lessorequal{bits}bit', lambda a, b: a <= b),
|
| 1757 |
-
]
|
| 1758 |
|
| 1759 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1760 |
try:
|
| 1761 |
-
expected = torch.tensor([1.0 if
|
| 1762 |
for a, b in zip(comp_a, comp_b)], device=self.device)
|
| 1763 |
-
|
| 1764 |
-
|
| 1765 |
-
|
| 1766 |
-
|
| 1767 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1768 |
correct = (out == expected.unsqueeze(1)).float().sum(0)
|
| 1769 |
-
|
| 1770 |
failures = []
|
| 1771 |
if pop_size == 1:
|
| 1772 |
-
for i in range(
|
| 1773 |
if out[i, 0].item() != expected[i].item():
|
| 1774 |
-
failures.append((
|
| 1775 |
-
|
| 1776 |
-
|
| 1777 |
-
out[i, 0].item()
|
| 1778 |
-
))
|
| 1779 |
-
|
| 1780 |
-
self._record(name, int(correct[0].item()), len(comp_a), failures)
|
| 1781 |
if debug:
|
| 1782 |
r = self.results[-1]
|
| 1783 |
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
|
| 1784 |
scores += correct
|
| 1785 |
-
total +=
|
| 1786 |
except KeyError:
|
| 1787 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1788 |
|
| 1789 |
-
|
| 1790 |
-
try:
|
| 1791 |
-
expected = torch.tensor([1.0 if a.item() == b.item() else 0.0
|
| 1792 |
-
for a, b in zip(comp_a, comp_b)], device=self.device)
|
| 1793 |
|
| 1794 |
-
|
| 1795 |
-
|
| 1796 |
-
|
| 1797 |
-
b_leq = pop[f'{prefix}.layer1.leq.bias']
|
| 1798 |
|
| 1799 |
-
|
| 1800 |
-
|
| 1801 |
-
hidden = torch.stack([h_geq, h_leq], dim=-1)
|
| 1802 |
|
| 1803 |
-
|
| 1804 |
-
|
| 1805 |
-
out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size))
|
| 1806 |
|
| 1807 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1808 |
|
| 1809 |
-
|
| 1810 |
-
|
| 1811 |
-
|
| 1812 |
-
|
| 1813 |
-
|
| 1814 |
-
|
| 1815 |
-
|
| 1816 |
-
|
| 1817 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1818 |
|
| 1819 |
-
|
|
|
|
| 1820 |
if debug:
|
| 1821 |
r = self.results[-1]
|
| 1822 |
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
|
| 1823 |
scores += correct
|
| 1824 |
-
total +=
|
| 1825 |
-
except KeyError:
|
| 1826 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1827 |
|
| 1828 |
return scores, total
|
| 1829 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1830 |
# =========================================================================
|
| 1831 |
# THRESHOLD GATES
|
| 1832 |
# =========================================================================
|
|
@@ -3159,34 +3622,47 @@ class BatchedFitnessEvaluator:
|
|
| 3159 |
if debug:
|
| 3160 |
print("\n=== MANIFEST ===")
|
| 3161 |
|
| 3162 |
-
|
| 3163 |
'manifest.alu_operations': 16.0,
|
| 3164 |
'manifest.flags': 4.0,
|
| 3165 |
'manifest.instruction_width': 16.0,
|
| 3166 |
-
'manifest.memory_bytes': 65536.0,
|
| 3167 |
-
'manifest.pc_width': 16.0,
|
| 3168 |
'manifest.register_width': 8.0,
|
| 3169 |
'manifest.registers': 4.0,
|
| 3170 |
-
'manifest.turing_complete': 1.0,
|
| 3171 |
'manifest.version': 3.0,
|
| 3172 |
}
|
| 3173 |
|
| 3174 |
-
for name, exp_val in
|
| 3175 |
try:
|
| 3176 |
-
val = pop[name][0, 0].item()
|
| 3177 |
if val == exp_val:
|
| 3178 |
scores += 1
|
| 3179 |
self._record(name, 1, 1, [])
|
| 3180 |
else:
|
| 3181 |
self._record(name, 0, 1, [(exp_val, val)])
|
| 3182 |
total += 1
|
| 3183 |
-
|
| 3184 |
if debug:
|
| 3185 |
r = self.results[-1]
|
| 3186 |
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
|
| 3187 |
except KeyError:
|
| 3188 |
pass
|
| 3189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3190 |
return scores, total
|
| 3191 |
|
| 3192 |
# =========================================================================
|
|
@@ -3202,23 +3678,35 @@ class BatchedFitnessEvaluator:
|
|
| 3202 |
if debug:
|
| 3203 |
print("\n=== MEMORY ===")
|
| 3204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3205 |
expected_shapes = {
|
| 3206 |
-
'memory.addr_decode.weight': (
|
| 3207 |
-
'memory.addr_decode.bias': (
|
| 3208 |
-
'memory.read.and.weight': (8,
|
| 3209 |
-
'memory.read.and.bias': (8,
|
| 3210 |
-
'memory.read.or.weight': (8,
|
| 3211 |
'memory.read.or.bias': (8,),
|
| 3212 |
-
'memory.write.sel.weight': (
|
| 3213 |
-
'memory.write.sel.bias': (
|
| 3214 |
-
'memory.write.nsel.weight': (
|
| 3215 |
-
'memory.write.nsel.bias': (
|
| 3216 |
-
'memory.write.and_old.weight': (
|
| 3217 |
-
'memory.write.and_old.bias': (
|
| 3218 |
-
'memory.write.and_new.weight': (
|
| 3219 |
-
'memory.write.and_new.bias': (
|
| 3220 |
-
'memory.write.or.weight': (
|
| 3221 |
-
'memory.write.or.bias': (
|
| 3222 |
}
|
| 3223 |
|
| 3224 |
for name, expected_shape in expected_shapes.items():
|
|
@@ -3539,6 +4027,36 @@ class BatchedFitnessEvaluator:
|
|
| 3539 |
total_tests += t
|
| 3540 |
self.category_scores[f'comparators{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
|
| 3541 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3542 |
# 3-operand adder
|
| 3543 |
s, t = self._test_add3(population, debug)
|
| 3544 |
scores += s
|
|
|
|
| 1745 |
comp_a = self.comp_a
|
| 1746 |
comp_b = self.comp_b
|
| 1747 |
|
| 1748 |
+
num_tests = len(comp_a)
|
|
|
|
|
|
|
| 1749 |
|
| 1750 |
+
if bits <= 16:
|
| 1751 |
+
a_bits = torch.stack([((comp_a >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
|
| 1752 |
+
b_bits = torch.stack([((comp_b >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
|
| 1753 |
+
inputs = torch.cat([a_bits, b_bits], dim=1)
|
|
|
|
|
|
|
| 1754 |
|
| 1755 |
+
comparators = [
|
| 1756 |
+
(f'arithmetic.greaterthan{bits}bit', lambda a, b: a > b),
|
| 1757 |
+
(f'arithmetic.greaterorequal{bits}bit', lambda a, b: a >= b),
|
| 1758 |
+
(f'arithmetic.lessthan{bits}bit', lambda a, b: a < b),
|
| 1759 |
+
(f'arithmetic.lessorequal{bits}bit', lambda a, b: a <= b),
|
| 1760 |
+
]
|
| 1761 |
+
|
| 1762 |
+
for name, op in comparators:
|
| 1763 |
+
try:
|
| 1764 |
+
expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0
|
| 1765 |
+
for a, b in zip(comp_a, comp_b)], device=self.device)
|
| 1766 |
+
w = pop[f'{name}.weight']
|
| 1767 |
+
b = pop[f'{name}.bias']
|
| 1768 |
+
out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size))
|
| 1769 |
+
correct = (out == expected.unsqueeze(1)).float().sum(0)
|
| 1770 |
+
failures = []
|
| 1771 |
+
if pop_size == 1:
|
| 1772 |
+
for i in range(num_tests):
|
| 1773 |
+
if out[i, 0].item() != expected[i].item():
|
| 1774 |
+
failures.append(([int(comp_a[i].item()), int(comp_b[i].item())],
|
| 1775 |
+
expected[i].item(), out[i, 0].item()))
|
| 1776 |
+
self._record(name, int(correct[0].item()), num_tests, failures)
|
| 1777 |
+
if debug:
|
| 1778 |
+
r = self.results[-1]
|
| 1779 |
+
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
|
| 1780 |
+
scores += correct
|
| 1781 |
+
total += num_tests
|
| 1782 |
+
except KeyError:
|
| 1783 |
+
pass
|
| 1784 |
+
|
| 1785 |
+
prefix = f'arithmetic.equality{bits}bit'
|
| 1786 |
try:
|
| 1787 |
+
expected = torch.tensor([1.0 if a.item() == b.item() else 0.0
|
| 1788 |
for a, b in zip(comp_a, comp_b)], device=self.device)
|
| 1789 |
+
w_geq = pop[f'{prefix}.layer1.geq.weight']
|
| 1790 |
+
b_geq = pop[f'{prefix}.layer1.geq.bias']
|
| 1791 |
+
w_leq = pop[f'{prefix}.layer1.leq.weight']
|
| 1792 |
+
b_leq = pop[f'{prefix}.layer1.leq.bias']
|
| 1793 |
+
h_geq = heaviside(inputs @ w_geq.view(pop_size, -1).T + b_geq.view(pop_size))
|
| 1794 |
+
h_leq = heaviside(inputs @ w_leq.view(pop_size, -1).T + b_leq.view(pop_size))
|
| 1795 |
+
hidden = torch.stack([h_geq, h_leq], dim=-1)
|
| 1796 |
+
w2 = pop[f'{prefix}.layer2.weight']
|
| 1797 |
+
b2 = pop[f'{prefix}.layer2.bias']
|
| 1798 |
+
out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size))
|
| 1799 |
correct = (out == expected.unsqueeze(1)).float().sum(0)
|
|
|
|
| 1800 |
failures = []
|
| 1801 |
if pop_size == 1:
|
| 1802 |
+
for i in range(num_tests):
|
| 1803 |
if out[i, 0].item() != expected[i].item():
|
| 1804 |
+
failures.append(([int(comp_a[i].item()), int(comp_b[i].item())],
|
| 1805 |
+
expected[i].item(), out[i, 0].item()))
|
| 1806 |
+
self._record(prefix, int(correct[0].item()), num_tests, failures)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1807 |
if debug:
|
| 1808 |
r = self.results[-1]
|
| 1809 |
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
|
| 1810 |
scores += correct
|
| 1811 |
+
total += num_tests
|
| 1812 |
except KeyError:
|
| 1813 |
pass
|
| 1814 |
+
else:
|
| 1815 |
+
num_bytes = bits // 8
|
| 1816 |
+
prefix = f"arithmetic.cmp{bits}bit"
|
| 1817 |
+
|
| 1818 |
+
byte_gt = []
|
| 1819 |
+
byte_lt = []
|
| 1820 |
+
byte_eq = []
|
| 1821 |
+
|
| 1822 |
+
for b in range(num_bytes):
|
| 1823 |
+
start_bit = b * 8
|
| 1824 |
+
a_byte = torch.stack([((comp_a >> (bits - 1 - start_bit - i)) & 1).float() for i in range(8)], dim=1)
|
| 1825 |
+
b_byte = torch.stack([((comp_b >> (bits - 1 - start_bit - i)) & 1).float() for i in range(8)], dim=1)
|
| 1826 |
+
byte_input = torch.cat([a_byte, b_byte], dim=1)
|
| 1827 |
+
|
| 1828 |
+
w_gt = pop[f'{prefix}.byte{b}.gt.weight'].view(pop_size, -1)
|
| 1829 |
+
b_gt = pop[f'{prefix}.byte{b}.gt.bias'].view(pop_size)
|
| 1830 |
+
byte_gt.append(heaviside(byte_input @ w_gt.T + b_gt))
|
| 1831 |
+
|
| 1832 |
+
w_lt = pop[f'{prefix}.byte{b}.lt.weight'].view(pop_size, -1)
|
| 1833 |
+
b_lt = pop[f'{prefix}.byte{b}.lt.bias'].view(pop_size)
|
| 1834 |
+
byte_lt.append(heaviside(byte_input @ w_lt.T + b_lt))
|
| 1835 |
+
|
| 1836 |
+
w_geq = pop[f'{prefix}.byte{b}.eq.geq.weight'].view(pop_size, -1)
|
| 1837 |
+
b_geq = pop[f'{prefix}.byte{b}.eq.geq.bias'].view(pop_size)
|
| 1838 |
+
w_leq = pop[f'{prefix}.byte{b}.eq.leq.weight'].view(pop_size, -1)
|
| 1839 |
+
b_leq = pop[f'{prefix}.byte{b}.eq.leq.bias'].view(pop_size)
|
| 1840 |
+
h_geq = heaviside(byte_input @ w_geq.T + b_geq)
|
| 1841 |
+
h_leq = heaviside(byte_input @ w_leq.T + b_leq)
|
| 1842 |
+
w_and = pop[f'{prefix}.byte{b}.eq.and.weight'].view(pop_size, -1)
|
| 1843 |
+
b_and = pop[f'{prefix}.byte{b}.eq.and.bias'].view(pop_size)
|
| 1844 |
+
eq_inp = torch.stack([h_geq, h_leq], dim=-1)
|
| 1845 |
+
byte_eq.append(heaviside((eq_inp * w_and).sum(-1) + b_and))
|
| 1846 |
+
|
| 1847 |
+
cascade_gt = []
|
| 1848 |
+
cascade_lt = []
|
| 1849 |
+
for b in range(num_bytes):
|
| 1850 |
+
if b == 0:
|
| 1851 |
+
cascade_gt.append(byte_gt[0])
|
| 1852 |
+
cascade_lt.append(byte_lt[0])
|
| 1853 |
+
else:
|
| 1854 |
+
eq_stack = torch.stack(byte_eq[:b], dim=-1)
|
| 1855 |
+
w_all_eq = pop[f'{prefix}.cascade.gt.stage{b}.all_eq.weight'].view(pop_size, -1)
|
| 1856 |
+
b_all_eq = pop[f'{prefix}.cascade.gt.stage{b}.all_eq.bias'].view(pop_size)
|
| 1857 |
+
all_eq_gt = heaviside((eq_stack * w_all_eq).sum(-1) + b_all_eq)
|
| 1858 |
+
w_and = pop[f'{prefix}.cascade.gt.stage{b}.and.weight'].view(pop_size, -1)
|
| 1859 |
+
b_and = pop[f'{prefix}.cascade.gt.stage{b}.and.bias'].view(pop_size)
|
| 1860 |
+
stage_inp = torch.stack([all_eq_gt, byte_gt[b]], dim=-1)
|
| 1861 |
+
cascade_gt.append(heaviside((stage_inp * w_and).sum(-1) + b_and))
|
| 1862 |
+
|
| 1863 |
+
w_all_eq_lt = pop[f'{prefix}.cascade.lt.stage{b}.all_eq.weight'].view(pop_size, -1)
|
| 1864 |
+
b_all_eq_lt = pop[f'{prefix}.cascade.lt.stage{b}.all_eq.bias'].view(pop_size)
|
| 1865 |
+
all_eq_lt = heaviside((eq_stack * w_all_eq_lt).sum(-1) + b_all_eq_lt)
|
| 1866 |
+
w_and_lt = pop[f'{prefix}.cascade.lt.stage{b}.and.weight'].view(pop_size, -1)
|
| 1867 |
+
b_and_lt = pop[f'{prefix}.cascade.lt.stage{b}.and.bias'].view(pop_size)
|
| 1868 |
+
stage_inp_lt = torch.stack([all_eq_lt, byte_lt[b]], dim=-1)
|
| 1869 |
+
cascade_lt.append(heaviside((stage_inp_lt * w_and_lt).sum(-1) + b_and_lt))
|
| 1870 |
+
|
| 1871 |
+
gt_stack = torch.stack(cascade_gt, dim=-1)
|
| 1872 |
+
w_gt_or = pop[f'arithmetic.greaterthan{bits}bit.weight'].view(pop_size, -1)
|
| 1873 |
+
b_gt_or = pop[f'arithmetic.greaterthan{bits}bit.bias'].view(pop_size)
|
| 1874 |
+
gt_out = heaviside((gt_stack * w_gt_or).sum(-1) + b_gt_or)
|
| 1875 |
+
|
| 1876 |
+
lt_stack = torch.stack(cascade_lt, dim=-1)
|
| 1877 |
+
w_lt_or = pop[f'arithmetic.lessthan{bits}bit.weight'].view(pop_size, -1)
|
| 1878 |
+
b_lt_or = pop[f'arithmetic.lessthan{bits}bit.bias'].view(pop_size)
|
| 1879 |
+
lt_out = heaviside((lt_stack * w_lt_or).sum(-1) + b_lt_or)
|
| 1880 |
+
|
| 1881 |
+
w_not_lt = pop[f'arithmetic.greaterorequal{bits}bit.not_lt.weight'].view(pop_size, -1)
|
| 1882 |
+
b_not_lt = pop[f'arithmetic.greaterorequal{bits}bit.not_lt.bias'].view(pop_size)
|
| 1883 |
+
not_lt = heaviside(lt_out.unsqueeze(-1) @ w_not_lt.T + b_not_lt).squeeze(-1)
|
| 1884 |
+
w_ge = pop[f'arithmetic.greaterorequal{bits}bit.weight'].view(pop_size, -1)
|
| 1885 |
+
b_ge = pop[f'arithmetic.greaterorequal{bits}bit.bias'].view(pop_size)
|
| 1886 |
+
ge_out = heaviside(not_lt.unsqueeze(-1) @ w_ge.T + b_ge).squeeze(-1)
|
| 1887 |
+
|
| 1888 |
+
w_not_gt = pop[f'arithmetic.lessorequal{bits}bit.not_gt.weight'].view(pop_size, -1)
|
| 1889 |
+
b_not_gt = pop[f'arithmetic.lessorequal{bits}bit.not_gt.bias'].view(pop_size)
|
| 1890 |
+
not_gt = heaviside(gt_out.unsqueeze(-1) @ w_not_gt.T + b_not_gt).squeeze(-1)
|
| 1891 |
+
w_le = pop[f'arithmetic.lessorequal{bits}bit.weight'].view(pop_size, -1)
|
| 1892 |
+
b_le = pop[f'arithmetic.lessorequal{bits}bit.bias'].view(pop_size)
|
| 1893 |
+
le_out = heaviside(not_gt.unsqueeze(-1) @ w_le.T + b_le).squeeze(-1)
|
| 1894 |
+
|
| 1895 |
+
eq_stack = torch.stack(byte_eq, dim=-1)
|
| 1896 |
+
w_eq_all = pop[f'arithmetic.equality{bits}bit.weight'].view(pop_size, -1)
|
| 1897 |
+
b_eq_all = pop[f'arithmetic.equality{bits}bit.bias'].view(pop_size)
|
| 1898 |
+
eq_out = heaviside((eq_stack * w_eq_all).sum(-1) + b_eq_all)
|
| 1899 |
+
|
| 1900 |
+
for name, out, op in [
|
| 1901 |
+
(f'arithmetic.greaterthan{bits}bit', gt_out, lambda a, b: a > b),
|
| 1902 |
+
(f'arithmetic.greaterorequal{bits}bit', ge_out, lambda a, b: a >= b),
|
| 1903 |
+
(f'arithmetic.lessthan{bits}bit', lt_out, lambda a, b: a < b),
|
| 1904 |
+
(f'arithmetic.lessorequal{bits}bit', le_out, lambda a, b: a <= b),
|
| 1905 |
+
(f'arithmetic.equality{bits}bit', eq_out, lambda a, b: a == b),
|
| 1906 |
+
]:
|
| 1907 |
+
expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0
|
| 1908 |
+
for a, b in zip(comp_a, comp_b)], device=self.device)
|
| 1909 |
+
correct = (out == expected.unsqueeze(1)).float().sum(0)
|
| 1910 |
+
failures = []
|
| 1911 |
+
if pop_size == 1:
|
| 1912 |
+
for i in range(num_tests):
|
| 1913 |
+
if out[i, 0].item() != expected[i].item():
|
| 1914 |
+
failures.append(([int(comp_a[i].item()), int(comp_b[i].item())],
|
| 1915 |
+
expected[i].item(), out[i, 0].item()))
|
| 1916 |
+
self._record(name, int(correct[0].item()), num_tests, failures)
|
| 1917 |
+
if debug:
|
| 1918 |
+
r = self.results[-1]
|
| 1919 |
+
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
|
| 1920 |
+
scores += correct
|
| 1921 |
+
total += num_tests
|
| 1922 |
|
| 1923 |
+
return scores, total
|
|
|
|
|
|
|
|
|
|
| 1924 |
|
| 1925 |
+
def _test_subtractor_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test the N-bit subtractor circuit (A - B, modulo 2^bits).

    Subtraction is evaluated as A + NOT(B) + 1: the per-bit NOT gates invert
    B, then a ripple-carry chain of full adders (LSB first, fa0 = LSB) with
    an initial carry-in of 1 produces the two's-complement difference.
    """
    pop_size = next(iter(pop.values())).shape[0]

    if debug:
        print(f"\n=== {bits}-BIT SUBTRACTOR ===")

    prefix = f'arithmetic.sub{bits}bit'
    max_val = 1 << bits

    # 32-bit uses a strategic sample of edge cases; smaller widths sweep a
    # small cross-product of boundary values.
    if bits == 32:
        test_pairs = [
            (1000, 500), (5000, 3000), (1000000, 500000),
            (0xFFFFFFFF, 1), (0x80000000, 1), (100, 100),
            (0, 0), (1, 0), (0, 1), (256, 255),
            (0xDEADBEEF, 0xCAFEBABE), (1000000000, 999999999),
        ]
    else:
        test_pairs = [(a, b) for a in [0, 1, 127, 128, 255] for b in [0, 1, 127, 128, 255]]

    a_vals = torch.tensor([p[0] for p in test_pairs], device=self.device, dtype=torch.long)
    b_vals = torch.tensor([p[1] for p in test_pairs], device=self.device, dtype=torch.long)
    num_tests = len(test_pairs)

    # Bit-plane layout: column i holds bit (bits-1-i), i.e. MSB first.
    a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
    b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)

    # Invert B through the circuit's per-bit NOT gates.
    not_b_bits = torch.zeros_like(b_bits)
    for bit in range(bits):
        weight = pop[f'{prefix}.not_b.bit{bit}.weight'].view(pop_size, -1)
        bias = pop[f'{prefix}.not_b.bit{bit}.bias'].view(pop_size)
        not_b_bits[:, bit] = heaviside(b_bits[:, bit:bit + 1] @ weight.T + bias)[:, 0]

    # Ripple-carry: carry-in of 1 completes the two's complement (+1).
    carry = torch.ones(num_tests, pop_size, device=self.device)
    sum_bits = []

    for bit in range(bits):
        # fa{bit} is the LSB-first adder stage; bit_idx maps it back into
        # the MSB-first column layout above.
        bit_idx = bits - 1 - bit
        s, carry = self._eval_single_fa(
            pop, f'{prefix}.fa{bit}',
            a_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
            not_b_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
            carry
        )
        sum_bits.append(s)

    # Reverse to MSB-first, then fold the bit planes back into integers.
    sum_bits = torch.stack(sum_bits[::-1], dim=-1)
    result = torch.zeros(num_tests, pop_size, device=self.device)
    for i in range(bits):
        result += sum_bits[:, :, i] * (1 << (bits - 1 - i))

    expected = ((a_vals - b_vals) & (max_val - 1)).unsqueeze(1).expand(-1, pop_size).float()
    correct = (result == expected).float().sum(0)

    # Collect (inputs, expected, got) details, capped at 20 failures.
    failures = []
    if pop_size == 1:
        for i in range(min(num_tests, 20)):
            if result[i, 0].item() != expected[i, 0].item():
                failures.append((
                    [int(a_vals[i].item()), int(b_vals[i].item())],
                    int(expected[i, 0].item()),
                    int(result[i, 0].item())
                ))

    self._record(prefix, int(correct[0].item()), num_tests, failures)
    if debug:
        r = self.results[-1]
        print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")

    return correct, num_tests
|
| 1995 |
+
|
| 1996 |
+
def _test_bitwise_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test N-bit bitwise operations (AND, OR, XOR, NOT).

    AND/OR are single threshold gates per bit; XOR is a two-layer gate
    (OR + NAND feeding an AND).  Each op is evaluated bit-by-bit and the
    outputs are folded back into integers for comparison against the
    Python-level expected values.

    NOTE(review): result bits are taken from population member 0 only
    (`out[:, 0]`), so for pop_size > 1 every member is scored with member
    0's outputs — confirm this is intentional for the nbits tests.
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0

    if debug:
        print(f"\n=== {bits}-BIT BITWISE OPS ===")

    if bits == 32:
        test_pairs = [
            (0xAAAAAAAA, 0x55555555), (0xFFFFFFFF, 0x00000000),
            (0x12345678, 0x87654321), (0xDEADBEEF, 0xCAFEBABE),
            (0x0F0F0F0F, 0xF0F0F0F0), (0, 0), (0xFFFFFFFF, 0xFFFFFFFF),
        ]
    else:
        test_pairs = [(0xAA, 0x55), (0xFF, 0x00), (0x0F, 0xF0)]

    a_vals = torch.tensor([p[0] for p in test_pairs], device=self.device, dtype=torch.long)
    b_vals = torch.tensor([p[1] for p in test_pairs], device=self.device, dtype=torch.long)
    num_tests = len(test_pairs)

    ops = [
        ('and', lambda a, b: a & b),
        ('or', lambda a, b: a | b),
        ('xor', lambda a, b: a ^ b),
    ]

    for op_name, op_fn in ops:
        try:
            result_bits = []
            for bit in range(bits):
                # bit 0 here is the MSB column of the test vector.
                a_bit = ((a_vals >> (bits - 1 - bit)) & 1).float()
                b_bit = ((b_vals >> (bits - 1 - bit)) & 1).float()

                if op_name == 'xor':
                    # XOR = AND(OR(a, b), NAND(a, b)) — two-layer gate.
                    prefix = f'alu.alu{bits}bit.{op_name}.bit{bit}'
                    w_or = pop[f'{prefix}.layer1.or.weight'].view(pop_size, -1)
                    b_or = pop[f'{prefix}.layer1.or.bias'].view(pop_size)
                    w_nand = pop[f'{prefix}.layer1.nand.weight'].view(pop_size, -1)
                    b_nand = pop[f'{prefix}.layer1.nand.bias'].view(pop_size)
                    inp = torch.stack([a_bit, b_bit], dim=-1)
                    h_or = heaviside(inp @ w_or.T + b_or)
                    h_nand = heaviside(inp @ w_nand.T + b_nand)
                    hidden = torch.stack([h_or, h_nand], dim=-1)
                    w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, -1)
                    b2 = pop[f'{prefix}.layer2.bias'].view(pop_size)
                    out = heaviside((hidden * w2).sum(-1) + b2)
                else:
                    w = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.weight'].view(pop_size, -1)
                    b = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.bias'].view(pop_size)
                    inp = torch.stack([a_bit, b_bit], dim=-1)
                    out = heaviside(inp @ w.T + b)

                result_bits.append(out[:, 0] if out.dim() > 1 else out)

            # (Removed a dead `result = sum(... for j in range(1))` statement
            # that was computed and never used; `results` below is the real
            # per-test reconstruction.)
            results = torch.tensor([sum(int(result_bits[i][j].item()) << (bits - 1 - i)
                                        for i in range(bits)) for j in range(num_tests)],
                                   device=self.device)
            expected = torch.tensor([op_fn(a.item(), b.item()) for a, b in zip(a_vals, b_vals)],
                                    device=self.device)

            correct = (results == expected).float().sum()
            self._record(f'alu.alu{bits}bit.{op_name}', int(correct.item()), num_tests, [])
            if debug:
                r = self.results[-1]
                print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
            scores += correct
            total += num_tests
        except KeyError as e:
            if debug:
                print(f"  alu.alu{bits}bit.{op_name}: SKIP (missing {e})")

    # NOT: unary gate per bit, tested on the same A values.
    try:
        test_vals = a_vals
        result_bits = []
        for bit in range(bits):
            a_bit = ((test_vals >> (bits - 1 - bit)) & 1).float()
            w = pop[f'alu.alu{bits}bit.not.bit{bit}.weight'].view(pop_size, -1)
            b = pop[f'alu.alu{bits}bit.not.bit{bit}.bias'].view(pop_size)
            out = heaviside(a_bit.unsqueeze(-1) @ w.T + b)
            result_bits.append(out[:, 0])

        results = torch.tensor([sum(int(result_bits[i][j].item()) << (bits - 1 - i)
                                    for i in range(bits)) for j in range(num_tests)],
                               device=self.device)
        # Mask to the operand width: ~a in Python is unbounded.
        expected = torch.tensor([(~a.item()) & ((1 << bits) - 1) for a in test_vals],
                                device=self.device)

        correct = (results == expected).float().sum()
        self._record(f'alu.alu{bits}bit.not', int(correct.item()), num_tests, [])
        if debug:
            r = self.results[-1]
            print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        scores += correct
        total += num_tests
    except KeyError as e:
        if debug:
            print(f"  alu.alu{bits}bit.not: SKIP (missing {e})")

    return scores, total
|
| 2099 |
+
|
| 2100 |
+
def _test_shifts_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test N-bit shift operations (SHL, SHR).

    Each output bit of a shift circuit is a single weighted pass-through of
    one source bit (the vacated end position is driven by a constant zero).
    Bit position `bit` in the weight keys is MSB-first here.  Missing weight
    tensors cause the op to be skipped, not failed.
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0

    if debug:
        print(f"\n=== {bits}-BIT SHIFTS ===")

    # Strategic patterns: boundary bits set, alternating bits, all-ones.
    test_vals = ([0x12345678, 0x80000001, 0x00000001, 0xFFFFFFFF, 0x55555555]
                 if bits == 32
                 else [0x81, 0x55, 0x01, 0xFF, 0xAA])

    a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long)
    num_tests = len(test_vals)
    max_val = (1 << bits) - 1

    ops = [('shl', lambda x: (x << 1) & max_val), ('shr', lambda x: x >> 1)]
    for op_name, op_fn in ops:
        try:
            out_bits = []
            for bit in range(bits):
                w = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.weight'].view(pop_size)
                b = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.bias'].view(pop_size)

                # Which input value-bit feeds this output position; None
                # marks the vacated position (shifted-in zero).
                if op_name == 'shl':
                    src_pos = (bits - 2 - bit) if bit < bits - 1 else None
                else:
                    src_pos = (bits - bit) if bit > 0 else None

                if src_pos is None:
                    src_bit = torch.zeros(num_tests, device=self.device)
                else:
                    src_bit = ((a_vals >> src_pos) & 1).float()

                out_bits.append(heaviside(src_bit * w + b))

            # Reassemble integers MSB-first and compare to the reference op.
            results = torch.tensor([sum(int(out_bits[i][j].item()) << (bits - 1 - i)
                                        for i in range(bits)) for j in range(num_tests)],
                                   device=self.device)
            expected = torch.tensor([op_fn(a.item()) for a in a_vals], device=self.device)

            correct = (results == expected).float().sum()
            self._record(f'alu.alu{bits}bit.{op_name}', int(correct.item()), num_tests, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
            scores += correct
            total += num_tests
        except KeyError as e:
            if debug:
                print(f" alu.alu{bits}bit.{op_name}: SKIP (missing {e})")

    return scores, total
|
| 2157 |
+
|
| 2158 |
+
def _test_inc_dec_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test N-bit INC and DEC operations.

    Evaluates the ripple chain exactly as built: for each bit (LSB-first),
    the sum bit is XOR(a_bit, carry) realized as an OR gate and a NAND gate
    feeding a second layer (AND), and the chain state is then updated —
    INC propagates carry = AND(a, carry); DEC propagates
    borrow = AND(NOT a, borrow).  The initial carry/borrow is seeded to 1.

    NOTE(review): weights are read with .flatten()[k] / .item(), i.e. only
    the first population member's gates are evaluated, while the returned
    score tensor has shape (pop_size,) — the scalar `correct` is broadcast
    to every member.  Confirm this is intentional for pop_size > 1.
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0

    if debug:
        print(f"\n=== {bits}-BIT INC/DEC ===")

    # Edge cases around wraparound (0, all-ones) and the sign boundary.
    if bits == 32:
        test_vals = [0, 1, 0xFFFFFFFF, 0x7FFFFFFF, 0x80000000, 1000000, 0xFFFFFFFE]
    else:
        test_vals = [0, 1, 254, 255, 127, 128]

    a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long)
    num_tests = len(test_vals)
    max_val = (1 << bits) - 1

    # Reference semantics: modular +1 / -1 in N bits.
    for op_name, op_fn in [('inc', lambda x: (x + 1) & max_val), ('dec', lambda x: (x - 1) & max_val)]:
        try:
            # Seed of the ripple chain: +1 for INC, initial borrow for DEC.
            carry = torch.ones(num_tests, device=self.device)
            result_bits = []

            for bit in range(bits):
                # Circuit bit indexing is LSB-first: bit0 is the 1s place.
                a_bit = ((a_vals >> bit) & 1).float()

                prefix = f'alu.alu{bits}bit.{op_name}.bit{bit}'
                # XOR layer 1: OR(a, carry) and NAND(a, carry).
                w_or = pop[f'{prefix}.xor.layer1.or.weight'].flatten()
                b_or = pop[f'{prefix}.xor.layer1.or.bias'].item()
                w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].flatten()
                b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].item()

                h_or = heaviside(a_bit * w_or[0] + carry * w_or[1] + b_or)
                h_nand = heaviside(a_bit * w_nand[0] + carry * w_nand[1] + b_nand)

                # XOR layer 2 combines the two layer-1 outputs into the sum bit.
                w2 = pop[f'{prefix}.xor.layer2.weight'].flatten()
                b2 = pop[f'{prefix}.xor.layer2.bias'].item()
                xor_out = heaviside(h_or * w2[0] + h_nand * w2[1] + b2)
                result_bits.append(xor_out)

                if op_name == 'inc':
                    # Next carry = AND(a_bit, carry).
                    w_carry = pop[f'{prefix}.carry.weight'].flatten()
                    b_carry = pop[f'{prefix}.carry.bias'].item()
                    carry = heaviside(a_bit * w_carry[0] + carry * w_carry[1] + b_carry)
                else:
                    # Next borrow = AND(NOT a_bit, borrow).
                    w_not = pop[f'{prefix}.not_a.weight'].flatten()
                    b_not = pop[f'{prefix}.not_a.bias'].item()
                    not_a = heaviside(a_bit * w_not[0] + b_not)
                    w_borrow = pop[f'{prefix}.borrow.weight'].flatten()
                    b_borrow = pop[f'{prefix}.borrow.bias'].item()
                    carry = heaviside(not_a * w_borrow[0] + carry * w_borrow[1] + b_borrow)

            # Reassemble integers LSB-first (result_bits[bit] is value bit `bit`).
            results = torch.tensor([sum(int(result_bits[bit][j].item()) << bit
                                        for bit in range(bits)) for j in range(num_tests)],
                                   device=self.device)
            expected = torch.tensor([op_fn(a.item()) for a in a_vals], device=self.device)

            correct = (results == expected).float().sum()
            self._record(f'alu.alu{bits}bit.{op_name}', int(correct.item()), num_tests, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
            scores += correct
            total += num_tests
        except KeyError as e:
            # Missing weight tensors mean the circuit was not built at this
            # width; skip rather than fail.
            if debug:
                print(f" alu.alu{bits}bit.{op_name}: SKIP (missing {e})")

    return scores, total
|
| 2227 |
|
| 2228 |
+
def _test_neg_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test N-bit NEG operation (two's complement negation).

    Evaluates the circuit as built: invert every input bit (NOT stage), then
    add 1 with a ripple INC stage (two-layer XOR plus a carry AND per bit),
    LSB-first.  Weights are read with .flatten()[k] / .item(), so only the
    first population member's gates are evaluated; the scalar score is
    broadcast across the population.

    Returns:
        (scores, total): score tensor of shape (pop_size,) and the number of
        tests run (0 when the NEG weight tensors are absent).
    """
    pop_size = next(iter(pop.values())).shape[0]

    if debug:
        print(f"\n=== {bits}-BIT NEG ===")

    # Edge cases: 0 and all-ones (self-inverse pairs via wraparound), the
    # sign boundary, and a couple of mid-range values.
    if bits == 32:
        test_vals = [0, 1, 0xFFFFFFFF, 0x7FFFFFFF, 0x80000000, 1000, 1000000]
    else:
        test_vals = [0, 1, 127, 128, 255, 100]

    a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long)
    num_tests = len(test_vals)
    max_val = (1 << bits) - 1

    try:
        # Stage 1: bitwise NOT of the input (bit index is LSB-first).
        not_bits = []
        for bit in range(bits):
            a_bit = ((a_vals >> bit) & 1).float()
            w = pop[f'alu.alu{bits}bit.neg.not.bit{bit}.weight'].flatten()
            b = pop[f'alu.alu{bits}bit.neg.not.bit{bit}.bias'].item()
            not_bits.append(heaviside(a_bit * w[0] + b))

        # Stage 2: increment ~a, with the ripple carry seeded to 1.
        carry = torch.ones(num_tests, device=self.device)
        result_bits = []

        for bit in range(bits):
            prefix = f'alu.alu{bits}bit.neg.inc.bit{bit}'
            not_bit = not_bits[bit]

            # XOR(not_bit, carry) built from OR and NAND feeding layer 2.
            w_or = pop[f'{prefix}.xor.layer1.or.weight'].flatten()
            b_or = pop[f'{prefix}.xor.layer1.or.bias'].item()
            w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].flatten()
            b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].item()

            h_or = heaviside(not_bit * w_or[0] + carry * w_or[1] + b_or)
            h_nand = heaviside(not_bit * w_nand[0] + carry * w_nand[1] + b_nand)

            w2 = pop[f'{prefix}.xor.layer2.weight'].flatten()
            b2 = pop[f'{prefix}.xor.layer2.bias'].item()
            xor_out = heaviside(h_or * w2[0] + h_nand * w2[1] + b2)
            result_bits.append(xor_out)

            # Next carry = AND(not_bit, carry).
            w_carry = pop[f'{prefix}.carry.weight'].flatten()
            b_carry = pop[f'{prefix}.carry.bias'].item()
            carry = heaviside(not_bit * w_carry[0] + carry * w_carry[1] + b_carry)

        # Reassemble integers LSB-first; expected is (-a) mod 2^bits.
        results = torch.tensor([sum(int(result_bits[bit][j].item()) << bit
                                    for bit in range(bits)) for j in range(num_tests)],
                               device=self.device)
        expected = torch.tensor([(-a.item()) & max_val for a in a_vals], device=self.device)

        correct = (results == expected).float().sum()
        self._record(f'alu.alu{bits}bit.neg', int(correct.item()), num_tests, [])
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")

        # BUGFIX: return a (pop_size,)-shaped score tensor, matching the
        # KeyError path below and every sibling test.  The previous
        # `torch.tensor([correct], ...)` copy-constructed from a 0-dim
        # tensor (UserWarning in torch) and returned shape (1,), which only
        # worked upstream through accidental broadcasting.
        return torch.full((pop_size,), correct.item(), device=self.device), num_tests
    except KeyError as e:
        if debug:
            print(f" alu.alu{bits}bit.neg: SKIP (missing {e})")
        return torch.zeros(pop_size, device=self.device), 0
|
| 2292 |
+
|
| 2293 |
# =========================================================================
|
| 2294 |
# THRESHOLD GATES
|
| 2295 |
# =========================================================================
|
|
|
|
| 3622 |
if debug:
|
| 3623 |
print("\n=== MANIFEST ===")
|
| 3624 |
|
| 3625 |
+
fixed_expected = {
|
| 3626 |
'manifest.alu_operations': 16.0,
|
| 3627 |
'manifest.flags': 4.0,
|
| 3628 |
'manifest.instruction_width': 16.0,
|
|
|
|
|
|
|
| 3629 |
'manifest.register_width': 8.0,
|
| 3630 |
'manifest.registers': 4.0,
|
|
|
|
| 3631 |
'manifest.version': 3.0,
|
| 3632 |
}
|
| 3633 |
|
| 3634 |
+
for name, exp_val in fixed_expected.items():
|
| 3635 |
try:
|
| 3636 |
+
val = pop[name][0, 0].item()
|
| 3637 |
if val == exp_val:
|
| 3638 |
scores += 1
|
| 3639 |
self._record(name, 1, 1, [])
|
| 3640 |
else:
|
| 3641 |
self._record(name, 0, 1, [(exp_val, val)])
|
| 3642 |
total += 1
|
|
|
|
| 3643 |
if debug:
|
| 3644 |
r = self.results[-1]
|
| 3645 |
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
|
| 3646 |
except KeyError:
|
| 3647 |
pass
|
| 3648 |
|
| 3649 |
+
variable_checks = ['manifest.memory_bytes', 'manifest.pc_width', 'manifest.turing_complete']
|
| 3650 |
+
for name in variable_checks:
|
| 3651 |
+
try:
|
| 3652 |
+
val = pop[name][0, 0].item()
|
| 3653 |
+
valid = val >= 0
|
| 3654 |
+
if valid:
|
| 3655 |
+
scores += 1
|
| 3656 |
+
self._record(name, 1, 1, [])
|
| 3657 |
+
else:
|
| 3658 |
+
self._record(name, 0, 1, [('>=0', val)])
|
| 3659 |
+
total += 1
|
| 3660 |
+
if debug:
|
| 3661 |
+
r = self.results[-1]
|
| 3662 |
+
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'} (value={val})")
|
| 3663 |
+
except KeyError:
|
| 3664 |
+
pass
|
| 3665 |
+
|
| 3666 |
return scores, total
|
| 3667 |
|
| 3668 |
# =========================================================================
|
|
|
|
| 3678 |
if debug:
|
| 3679 |
print("\n=== MEMORY ===")
|
| 3680 |
|
| 3681 |
+
try:
|
| 3682 |
+
mem_bytes = int(pop['manifest.memory_bytes'][0].item())
|
| 3683 |
+
addr_bits = int(pop['manifest.pc_width'][0].item())
|
| 3684 |
+
except KeyError:
|
| 3685 |
+
mem_bytes = 65536
|
| 3686 |
+
addr_bits = 16
|
| 3687 |
+
|
| 3688 |
+
if mem_bytes == 0:
|
| 3689 |
+
if debug:
|
| 3690 |
+
print(" No memory (pure ALU mode)")
|
| 3691 |
+
return scores, 0
|
| 3692 |
+
|
| 3693 |
expected_shapes = {
|
| 3694 |
+
'memory.addr_decode.weight': (mem_bytes, addr_bits),
|
| 3695 |
+
'memory.addr_decode.bias': (mem_bytes,),
|
| 3696 |
+
'memory.read.and.weight': (8, mem_bytes, 2),
|
| 3697 |
+
'memory.read.and.bias': (8, mem_bytes),
|
| 3698 |
+
'memory.read.or.weight': (8, mem_bytes),
|
| 3699 |
'memory.read.or.bias': (8,),
|
| 3700 |
+
'memory.write.sel.weight': (mem_bytes, 2),
|
| 3701 |
+
'memory.write.sel.bias': (mem_bytes,),
|
| 3702 |
+
'memory.write.nsel.weight': (mem_bytes, 1),
|
| 3703 |
+
'memory.write.nsel.bias': (mem_bytes,),
|
| 3704 |
+
'memory.write.and_old.weight': (mem_bytes, 8, 2),
|
| 3705 |
+
'memory.write.and_old.bias': (mem_bytes, 8),
|
| 3706 |
+
'memory.write.and_new.weight': (mem_bytes, 8, 2),
|
| 3707 |
+
'memory.write.and_new.bias': (mem_bytes, 8),
|
| 3708 |
+
'memory.write.or.weight': (mem_bytes, 8, 2),
|
| 3709 |
+
'memory.write.or.bias': (mem_bytes, 8),
|
| 3710 |
}
|
| 3711 |
|
| 3712 |
for name, expected_shape in expected_shapes.items():
|
|
|
|
| 4027 |
total_tests += t
|
| 4028 |
self.category_scores[f'comparators{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
|
| 4029 |
|
| 4030 |
+
if f'arithmetic.sub{bits}bit.not_b.bit0.weight' in population:
|
| 4031 |
+
s, t = self._test_subtractor_nbits(population, bits, debug)
|
| 4032 |
+
scores += s
|
| 4033 |
+
total_tests += t
|
| 4034 |
+
self.category_scores[f'subtractor{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
|
| 4035 |
+
|
| 4036 |
+
if f'alu.alu{bits}bit.and.bit0.weight' in population:
|
| 4037 |
+
s, t = self._test_bitwise_nbits(population, bits, debug)
|
| 4038 |
+
scores += s
|
| 4039 |
+
total_tests += t
|
| 4040 |
+
self.category_scores[f'bitwise{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
|
| 4041 |
+
|
| 4042 |
+
if f'alu.alu{bits}bit.shl.bit0.weight' in population:
|
| 4043 |
+
s, t = self._test_shifts_nbits(population, bits, debug)
|
| 4044 |
+
scores += s
|
| 4045 |
+
total_tests += t
|
| 4046 |
+
self.category_scores[f'shifts{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
|
| 4047 |
+
|
| 4048 |
+
if f'alu.alu{bits}bit.inc.bit0.xor.layer1.or.weight' in population:
|
| 4049 |
+
s, t = self._test_inc_dec_nbits(population, bits, debug)
|
| 4050 |
+
scores += s
|
| 4051 |
+
total_tests += t
|
| 4052 |
+
self.category_scores[f'incdec{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
|
| 4053 |
+
|
| 4054 |
+
if f'alu.alu{bits}bit.neg.not.bit0.weight' in population:
|
| 4055 |
+
s, t = self._test_neg_nbits(population, bits, debug)
|
| 4056 |
+
scores += s
|
| 4057 |
+
total_tests += t
|
| 4058 |
+
self.category_scores[f'neg{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
|
| 4059 |
+
|
| 4060 |
# 3-operand adder
|
| 4061 |
s, t = self._test_add3(population, debug)
|
| 4062 |
scores += s
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8a292e8d1dc5b29fd84d25d0333599a9946849e456aeb30b7519156dc150a623
|
| 3 |
+
size 4985016
|