Bit-cascade integer comparators (8/16/32-bit) with ternary weights
Browse filesbuild.add_bit_cascade_compare emits an N-bit comparator family using
only weights in {-1, 0, 1}:
per-bit: gt = A AND NOT B, lt = NOT A AND B, eq = XNOR (2 layers)
cascade prefix: eq_prefix[i] = AND of eq[0..i-1]
cascade per-bit: cascade.gt[i] = eq_prefix[i] AND gt[i] (LT analogous)
final OR for GT and LT, AND of all eq for EQ
GE = NOT(LT), LE = NOT(GT) via NOT + identity buffer
add_comparators (8-bit) and add_comparators_nbits (16/32-bit) now both
delegate to this builder. The single-layer 8-bit case (16 weights at
+/-128, +1 bias) and the 32-bit byte-cascade structure (byte-level
+/-128 weights) are removed; all comparator weights across the library
are now ternary by construction.
eval._eval_bit_cascade_compare walks the new structure. _test_comparators
and _test_comparators_nbits were rewritten to use it; legacy paths kept
under _test_comparators_nbits_legacy for old files.
cmd_alu now also drops arithmetic.cmp{N}bit. on rebuild so stale
byte-cascade gates from old seed files don't leak through.
22 of 183 non-ternary weight tensors eliminated by this change. The
remaining 161 are positional comparators inside division stages
(integer 8/16/32-bit and float16/32 mantissa div), float magnitude
comparators, modular arithmetic detection gates, and pattern_recognition
priority encoders -- subsequent passes will bit-cascade those too.
All 18 variants rebuilt; eval_all.py reports 100% fitness on every one.
Test counts increased across the board (e.g. 8-bit small: 6772 -> 7171)
because the bit-cascade exposes more sub-gates for individual checking.
- build.py +96 -92
- eval.py +221 -92
- neural_computer.safetensors +2 -2
- variants/neural_alu16.safetensors +2 -2
- variants/neural_alu32.safetensors +2 -2
- variants/neural_alu8.safetensors +2 -2
- variants/neural_computer16.safetensors +2 -2
- variants/neural_computer16_reduced.safetensors +2 -2
- variants/neural_computer16_registers.safetensors +2 -2
- variants/neural_computer16_scratchpad.safetensors +2 -2
- variants/neural_computer16_small.safetensors +2 -2
- variants/neural_computer32.safetensors +2 -2
- variants/neural_computer32_reduced.safetensors +2 -2
- variants/neural_computer32_registers.safetensors +2 -2
- variants/neural_computer32_scratchpad.safetensors +2 -2
- variants/neural_computer32_small.safetensors +2 -2
- variants/neural_computer8.safetensors +2 -2
- variants/neural_computer8_reduced.safetensors +2 -2
- variants/neural_computer8_registers.safetensors +2 -2
- variants/neural_computer8_scratchpad.safetensors +2 -2
- variants/neural_computer8_small.safetensors +2 -2
|
@@ -769,41 +769,95 @@ def add_priority_encoder_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> N
|
|
| 769 |
add_gate(tensors, f"{prefix}.valid", [1.0] * bits, [-1.0])
|
| 770 |
|
| 771 |
|
| 772 |
-
def
|
| 773 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 774 |
|
| 775 |
-
|
| 776 |
-
|
|
|
|
|
|
|
| 777 |
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
- This becomes: sum(a_i * w_i - b_i * w_i) > 0
|
| 782 |
-
- Or: sum((a_i - b_i) * w_i) > 0
|
| 783 |
|
| 784 |
-
|
|
|
|
| 785 |
|
| 786 |
-
|
| 787 |
-
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
For A == B: need A >= B AND A <= B (two-layer)
|
| 792 |
-
"""
|
| 793 |
-
pos_weights = [128.0, 64.0, 32.0, 16.0, 8.0, 4.0, 2.0, 1.0]
|
| 794 |
-
neg_weights = [-128.0, -64.0, -32.0, -16.0, -8.0, -4.0, -2.0, -1.0]
|
| 795 |
|
| 796 |
-
gt_weights = pos_weights + neg_weights
|
| 797 |
-
lt_weights = neg_weights + pos_weights
|
| 798 |
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
add_gate(tensors, "arithmetic.lessthan8bit", lt_weights, [-1.0])
|
| 802 |
-
add_gate(tensors, "arithmetic.lessorequal8bit", lt_weights, [0.0])
|
| 803 |
|
| 804 |
-
|
| 805 |
-
|
| 806 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 807 |
|
| 808 |
|
| 809 |
def add_ripple_carry_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
|
|
@@ -841,72 +895,20 @@ def add_sub_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
|
|
| 841 |
|
| 842 |
|
| 843 |
def add_comparators_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
|
| 844 |
-
"""Add N-bit comparator circuits (GT, LT, GE, LE, EQ).
|
| 845 |
|
| 846 |
-
|
| 847 |
-
For bits > 16: Use cascaded byte-wise comparison to avoid float32 precision loss.
|
| 848 |
-
|
| 849 |
-
Cascaded approach compares byte-by-byte from MSB:
|
| 850 |
-
A > B iff: (A[31:24] > B[31:24]) OR
|
| 851 |
-
(A[31:24] == B[31:24] AND A[23:16] > B[23:16]) OR ...
|
| 852 |
"""
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
|
| 863 |
-
add_gate(tensors, f"arithmetic.lessorequal{bits}bit", lt_weights, [0.0])
|
| 864 |
-
|
| 865 |
-
add_gate(tensors, f"arithmetic.equality{bits}bit.layer1.geq", gt_weights, [0.0])
|
| 866 |
-
add_gate(tensors, f"arithmetic.equality{bits}bit.layer1.leq", lt_weights, [0.0])
|
| 867 |
-
add_gate(tensors, f"arithmetic.equality{bits}bit.layer2", [1.0, 1.0], [-2.0])
|
| 868 |
-
else:
|
| 869 |
-
num_bytes = bits // 8
|
| 870 |
-
prefix = f"arithmetic.cmp{bits}bit"
|
| 871 |
-
|
| 872 |
-
byte_pos_weights = [128.0, 64.0, 32.0, 16.0, 8.0, 4.0, 2.0, 1.0]
|
| 873 |
-
byte_neg_weights = [-128.0, -64.0, -32.0, -16.0, -8.0, -4.0, -2.0, -1.0]
|
| 874 |
-
byte_gt_weights = byte_pos_weights + byte_neg_weights
|
| 875 |
-
byte_lt_weights = byte_neg_weights + byte_pos_weights
|
| 876 |
-
|
| 877 |
-
for b in range(num_bytes):
|
| 878 |
-
add_gate(tensors, f"{prefix}.byte{b}.gt", byte_gt_weights, [-1.0])
|
| 879 |
-
add_gate(tensors, f"{prefix}.byte{b}.lt", byte_lt_weights, [-1.0])
|
| 880 |
-
add_gate(tensors, f"{prefix}.byte{b}.eq.geq", byte_gt_weights, [0.0])
|
| 881 |
-
add_gate(tensors, f"{prefix}.byte{b}.eq.leq", byte_lt_weights, [0.0])
|
| 882 |
-
add_gate(tensors, f"{prefix}.byte{b}.eq.and", [1.0, 1.0], [-2.0])
|
| 883 |
-
|
| 884 |
-
for b in range(num_bytes):
|
| 885 |
-
if b == 0:
|
| 886 |
-
add_gate(tensors, f"{prefix}.cascade.gt.stage{b}", [1.0], [-1.0])
|
| 887 |
-
add_gate(tensors, f"{prefix}.cascade.lt.stage{b}", [1.0], [-1.0])
|
| 888 |
-
else:
|
| 889 |
-
eq_weights = [1.0] * b
|
| 890 |
-
add_gate(tensors, f"{prefix}.cascade.gt.stage{b}.all_eq", eq_weights, [-float(b)])
|
| 891 |
-
add_gate(tensors, f"{prefix}.cascade.gt.stage{b}.and", [1.0, 1.0], [-2.0])
|
| 892 |
-
add_gate(tensors, f"{prefix}.cascade.lt.stage{b}.all_eq", eq_weights, [-float(b)])
|
| 893 |
-
add_gate(tensors, f"{prefix}.cascade.lt.stage{b}.and", [1.0, 1.0], [-2.0])
|
| 894 |
-
|
| 895 |
-
or_weights_gt = [1.0] * num_bytes
|
| 896 |
-
or_weights_lt = [1.0] * num_bytes
|
| 897 |
-
add_gate(tensors, f"arithmetic.greaterthan{bits}bit", or_weights_gt, [-1.0])
|
| 898 |
-
add_gate(tensors, f"arithmetic.lessthan{bits}bit", or_weights_lt, [-1.0])
|
| 899 |
-
|
| 900 |
-
not_lt_weights = [-1.0]
|
| 901 |
-
add_gate(tensors, f"arithmetic.greaterorequal{bits}bit.not_lt", not_lt_weights, [0.0])
|
| 902 |
-
add_gate(tensors, f"arithmetic.greaterorequal{bits}bit", [1.0], [-1.0])
|
| 903 |
-
|
| 904 |
-
not_gt_weights = [-1.0]
|
| 905 |
-
add_gate(tensors, f"arithmetic.lessorequal{bits}bit.not_gt", not_gt_weights, [0.0])
|
| 906 |
-
add_gate(tensors, f"arithmetic.lessorequal{bits}bit", [1.0], [-1.0])
|
| 907 |
-
|
| 908 |
-
eq_all_weights = [1.0] * num_bytes
|
| 909 |
-
add_gate(tensors, f"arithmetic.equality{bits}bit", eq_all_weights, [-float(num_bytes)])
|
| 910 |
|
| 911 |
|
| 912 |
def add_mul_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
|
|
@@ -2910,6 +2912,7 @@ def cmd_alu(args) -> None:
|
|
| 2910 |
"arithmetic.greaterthan8bit.", "arithmetic.lessthan8bit.",
|
| 2911 |
"arithmetic.greaterorequal8bit.", "arithmetic.lessorequal8bit.",
|
| 2912 |
"arithmetic.equality8bit.", "arithmetic.add3_8bit.", "arithmetic.expr_add_mul.", "arithmetic.expr_paren.",
|
|
|
|
| 2913 |
"combinational.barrelshifter.", "combinational.priorityencoder.",
|
| 2914 |
"float16.", "float32.",
|
| 2915 |
]
|
|
@@ -2920,6 +2923,7 @@ def cmd_alu(args) -> None:
|
|
| 2920 |
f"arithmetic.sub{bits}bit.", f"arithmetic.greaterthan{bits}bit.",
|
| 2921 |
f"arithmetic.lessthan{bits}bit.", f"arithmetic.greaterorequal{bits}bit.",
|
| 2922 |
f"arithmetic.lessorequal{bits}bit.", f"arithmetic.equality{bits}bit.",
|
|
|
|
| 2923 |
])
|
| 2924 |
|
| 2925 |
print("\nDropping existing ALU extension tensors...")
|
|
|
|
| 769 |
add_gate(tensors, f"{prefix}.valid", [1.0] * bits, [-1.0])
|
| 770 |
|
| 771 |
|
| 772 |
+
def add_bit_cascade_compare(
|
| 773 |
+
tensors: Dict[str, torch.Tensor],
|
| 774 |
+
cmp_prefix: str,
|
| 775 |
+
bits: int,
|
| 776 |
+
out_gt: str,
|
| 777 |
+
out_lt: str,
|
| 778 |
+
out_ge: str,
|
| 779 |
+
out_le: str,
|
| 780 |
+
out_eq: str,
|
| 781 |
+
) -> None:
|
| 782 |
+
"""Generic ternary-only N-bit comparator.
|
| 783 |
+
|
| 784 |
+
Inputs are two N-bit values A and B in MSB-first order. The structure
|
| 785 |
+
produces unsigned-magnitude GT, LT, GE, LE, EQ outputs using only
|
| 786 |
+
weights in {-1, 0, 1} and integer biases.
|
| 787 |
+
|
| 788 |
+
Per-bit primitives (i = 0 is the MSB):
|
| 789 |
+
{cmp_prefix}.bit{i}.gt A[i] AND NOT B[i] weights [1, -1], bias -1
|
| 790 |
+
{cmp_prefix}.bit{i}.lt NOT A[i] AND B[i] weights [-1, 1], bias -1
|
| 791 |
+
{cmp_prefix}.bit{i}.eq.layer1.and A[i] AND B[i]
|
| 792 |
+
{cmp_prefix}.bit{i}.eq.layer1.nor NOR(A[i], B[i])
|
| 793 |
+
{cmp_prefix}.bit{i}.eq XNOR via OR of layer1 outputs
|
| 794 |
+
|
| 795 |
+
Cascade (linear chain from MSB to LSB):
|
| 796 |
+
{cmp_prefix}.cascade.eq_prefix.bit{i} AND of eq[0..i-1] (i in 1..N-1)
|
| 797 |
+
{cmp_prefix}.cascade.gt.bit{i} eq_prefix[i] AND gt[i]
|
| 798 |
+
{cmp_prefix}.cascade.lt.bit{i} eq_prefix[i] AND lt[i]
|
| 799 |
+
|
| 800 |
+
Final outputs:
|
| 801 |
+
out_gt = OR of (gt[0], cascade.gt.bit{1..N-1})
|
| 802 |
+
out_lt = OR of (lt[0], cascade.lt.bit{1..N-1})
|
| 803 |
+
out_eq = AND of all eq[i]
|
| 804 |
+
out_ge = NOT(out_lt)
|
| 805 |
+
out_le = NOT(out_gt)
|
| 806 |
+
"""
|
| 807 |
+
for i in range(bits):
|
| 808 |
+
# per-bit GT: A[i] AND NOT B[i] -> H(A - B - 1)
|
| 809 |
+
add_gate(tensors, f"{cmp_prefix}.bit{i}.gt", [1.0, -1.0], [-1.0])
|
| 810 |
+
# per-bit LT: NOT A[i] AND B[i] -> H(-A + B - 1)
|
| 811 |
+
add_gate(tensors, f"{cmp_prefix}.bit{i}.lt", [-1.0, 1.0], [-1.0])
|
| 812 |
+
# per-bit EQ via XNOR = (A AND B) OR (NOR A B)
|
| 813 |
+
add_gate(tensors, f"{cmp_prefix}.bit{i}.eq.layer1.and", [1.0, 1.0], [-2.0])
|
| 814 |
+
add_gate(tensors, f"{cmp_prefix}.bit{i}.eq.layer1.nor", [-1.0, -1.0], [0.0])
|
| 815 |
+
add_gate(tensors, f"{cmp_prefix}.bit{i}.eq", [1.0, 1.0], [-1.0])
|
| 816 |
+
|
| 817 |
+
# eq_prefix[i] = AND of eq[0..i-1], i in 1..N-1
|
| 818 |
+
for i in range(1, bits):
|
| 819 |
+
add_gate(
|
| 820 |
+
tensors,
|
| 821 |
+
f"{cmp_prefix}.cascade.eq_prefix.bit{i}",
|
| 822 |
+
[1.0] * i,
|
| 823 |
+
[-float(i)],
|
| 824 |
+
)
|
| 825 |
|
| 826 |
+
# cascade.gt[i], cascade.lt[i] for i in 1..N-1
|
| 827 |
+
for i in range(1, bits):
|
| 828 |
+
add_gate(tensors, f"{cmp_prefix}.cascade.gt.bit{i}", [1.0, 1.0], [-2.0])
|
| 829 |
+
add_gate(tensors, f"{cmp_prefix}.cascade.lt.bit{i}", [1.0, 1.0], [-2.0])
|
| 830 |
|
| 831 |
+
# Final OR for GT and LT (N inputs each)
|
| 832 |
+
add_gate(tensors, out_gt, [1.0] * bits, [-1.0])
|
| 833 |
+
add_gate(tensors, out_lt, [1.0] * bits, [-1.0])
|
|
|
|
|
|
|
| 834 |
|
| 835 |
+
# AND of all eq's for EQ
|
| 836 |
+
add_gate(tensors, out_eq, [1.0] * bits, [-float(bits)])
|
| 837 |
|
| 838 |
+
# GE = NOT(LT), LE = NOT(GT) -- single-input NOT then identity buffer
|
| 839 |
+
add_gate(tensors, f"{out_ge}.not_lt", [-1.0], [0.0])
|
| 840 |
+
add_gate(tensors, out_ge, [1.0], [-1.0])
|
| 841 |
+
add_gate(tensors, f"{out_le}.not_gt", [-1.0], [0.0])
|
| 842 |
+
add_gate(tensors, out_le, [1.0], [-1.0])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 843 |
|
|
|
|
|
|
|
| 844 |
|
| 845 |
+
def add_comparators(tensors: Dict[str, torch.Tensor]) -> None:
|
| 846 |
+
"""Add 8-bit comparator circuits (GT, LT, GE, LE, EQ) using bit-cascade.
|
|
|
|
|
|
|
| 847 |
|
| 848 |
+
Inputs are 8 bits of A then 8 bits of B in MSB-first order. The
|
| 849 |
+
underlying bit-cascade produces only ternary {-1, 0, 1} weights.
|
| 850 |
+
"""
|
| 851 |
+
add_bit_cascade_compare(
|
| 852 |
+
tensors,
|
| 853 |
+
cmp_prefix="arithmetic.cmp8bit",
|
| 854 |
+
bits=8,
|
| 855 |
+
out_gt="arithmetic.greaterthan8bit",
|
| 856 |
+
out_lt="arithmetic.lessthan8bit",
|
| 857 |
+
out_ge="arithmetic.greaterorequal8bit",
|
| 858 |
+
out_le="arithmetic.lessorequal8bit",
|
| 859 |
+
out_eq="arithmetic.equality8bit",
|
| 860 |
+
)
|
| 861 |
|
| 862 |
|
| 863 |
def add_ripple_carry_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
|
|
|
|
| 895 |
|
| 896 |
|
| 897 |
def add_comparators_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
|
| 898 |
+
"""Add N-bit comparator circuits (GT, LT, GE, LE, EQ) via bit-cascade.
|
| 899 |
|
| 900 |
+
All weights are in {-1, 0, 1}. Inputs are A bits then B bits, MSB-first.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 901 |
"""
|
| 902 |
+
add_bit_cascade_compare(
|
| 903 |
+
tensors,
|
| 904 |
+
cmp_prefix=f"arithmetic.cmp{bits}bit",
|
| 905 |
+
bits=bits,
|
| 906 |
+
out_gt=f"arithmetic.greaterthan{bits}bit",
|
| 907 |
+
out_lt=f"arithmetic.lessthan{bits}bit",
|
| 908 |
+
out_ge=f"arithmetic.greaterorequal{bits}bit",
|
| 909 |
+
out_le=f"arithmetic.lessorequal{bits}bit",
|
| 910 |
+
out_eq=f"arithmetic.equality{bits}bit",
|
| 911 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 912 |
|
| 913 |
|
| 914 |
def add_mul_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
|
|
|
|
| 2912 |
"arithmetic.greaterthan8bit.", "arithmetic.lessthan8bit.",
|
| 2913 |
"arithmetic.greaterorequal8bit.", "arithmetic.lessorequal8bit.",
|
| 2914 |
"arithmetic.equality8bit.", "arithmetic.add3_8bit.", "arithmetic.expr_add_mul.", "arithmetic.expr_paren.",
|
| 2915 |
+
"arithmetic.cmp8bit.", # bit-cascade internals (replaces single-layer)
|
| 2916 |
"combinational.barrelshifter.", "combinational.priorityencoder.",
|
| 2917 |
"float16.", "float32.",
|
| 2918 |
]
|
|
|
|
| 2923 |
f"arithmetic.sub{bits}bit.", f"arithmetic.greaterthan{bits}bit.",
|
| 2924 |
f"arithmetic.lessthan{bits}bit.", f"arithmetic.greaterorequal{bits}bit.",
|
| 2925 |
f"arithmetic.lessorequal{bits}bit.", f"arithmetic.equality{bits}bit.",
|
| 2926 |
+
f"arithmetic.cmp{bits}bit.", # legacy byte-cascade (32-bit) and new bit-cascade
|
| 2927 |
])
|
| 2928 |
|
| 2929 |
print("\nDropping existing ALU extension tensors...")
|
|
@@ -1636,100 +1636,162 @@ class BatchedFitnessEvaluator:
|
|
| 1636 |
# COMPARATORS
|
| 1637 |
# =========================================================================
|
| 1638 |
|
| 1639 |
-
def
|
| 1640 |
-
|
| 1641 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1642 |
pop_size = next(iter(pop.values())).shape[0]
|
| 1643 |
-
prefix = f'arithmetic.{name}'
|
| 1644 |
-
|
| 1645 |
-
# Use pre-computed test pairs
|
| 1646 |
-
expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0
|
| 1647 |
-
for a, b in zip(self.comp_a, self.comp_b)],
|
| 1648 |
-
device=self.device)
|
| 1649 |
-
|
| 1650 |
-
# Convert to bits
|
| 1651 |
-
a_bits = torch.stack([((self.comp_a >> (7 - i)) & 1).float() for i in range(8)], dim=1)
|
| 1652 |
-
b_bits = torch.stack([((self.comp_b >> (7 - i)) & 1).float() for i in range(8)], dim=1)
|
| 1653 |
-
inputs = torch.cat([a_bits, b_bits], dim=1)
|
| 1654 |
-
|
| 1655 |
-
w = pop[f'{prefix}.weight']
|
| 1656 |
-
b = pop[f'{prefix}.bias']
|
| 1657 |
-
out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size))
|
| 1658 |
-
|
| 1659 |
-
correct = (out == expected.unsqueeze(1)).float().sum(0)
|
| 1660 |
-
|
| 1661 |
-
failures = []
|
| 1662 |
-
if pop_size == 1:
|
| 1663 |
-
for i in range(len(self.comp_a)):
|
| 1664 |
-
if out[i, 0].item() != expected[i].item():
|
| 1665 |
-
failures.append((
|
| 1666 |
-
[int(self.comp_a[i].item()), int(self.comp_b[i].item())],
|
| 1667 |
-
expected[i].item(),
|
| 1668 |
-
out[i, 0].item()
|
| 1669 |
-
))
|
| 1670 |
|
| 1671 |
-
|
| 1672 |
-
|
| 1673 |
-
|
| 1674 |
-
|
| 1675 |
-
|
| 1676 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1677 |
|
| 1678 |
def _test_comparators(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
|
| 1679 |
-
"""Test
|
| 1680 |
pop_size = next(iter(pop.values())).shape[0]
|
| 1681 |
scores = torch.zeros(pop_size, device=self.device)
|
| 1682 |
total = 0
|
| 1683 |
|
| 1684 |
if debug:
|
| 1685 |
-
print("\n=== COMPARATORS ===")
|
| 1686 |
-
|
| 1687 |
-
comparators = [
|
| 1688 |
-
('greaterthan8bit', lambda a, b: a > b),
|
| 1689 |
-
('lessthan8bit', lambda a, b: a < b),
|
| 1690 |
-
('greaterorequal8bit', lambda a, b: a >= b),
|
| 1691 |
-
('lessorequal8bit', lambda a, b: a <= b),
|
| 1692 |
-
('equality8bit', lambda a, b: a == b),
|
| 1693 |
-
]
|
| 1694 |
-
|
| 1695 |
-
for name, op in comparators:
|
| 1696 |
-
if name == 'equality8bit':
|
| 1697 |
-
continue # Handle separately as two-layer
|
| 1698 |
-
try:
|
| 1699 |
-
s, t = self._test_comparator(pop, name, op, debug)
|
| 1700 |
-
scores += s
|
| 1701 |
-
total += t
|
| 1702 |
-
except KeyError:
|
| 1703 |
-
pass # Circuit not present
|
| 1704 |
|
| 1705 |
-
|
|
|
|
|
|
|
| 1706 |
try:
|
| 1707 |
-
|
| 1708 |
-
|
| 1709 |
-
|
| 1710 |
-
|
| 1711 |
-
|
| 1712 |
-
|
| 1713 |
-
|
| 1714 |
-
|
| 1715 |
-
|
| 1716 |
-
|
| 1717 |
-
|
| 1718 |
-
|
| 1719 |
-
|
| 1720 |
-
|
| 1721 |
-
|
| 1722 |
-
|
| 1723 |
-
|
| 1724 |
-
|
| 1725 |
-
|
| 1726 |
-
|
| 1727 |
-
|
| 1728 |
-
|
| 1729 |
-
|
| 1730 |
-
|
|
|
|
|
|
|
|
|
|
| 1731 |
correct = (out == expected.unsqueeze(1)).float().sum(0)
|
| 1732 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1733 |
failures = []
|
| 1734 |
if pop_size == 1:
|
| 1735 |
for i in range(len(self.comp_a)):
|
|
@@ -1737,28 +1799,22 @@ class BatchedFitnessEvaluator:
|
|
| 1737 |
failures.append((
|
| 1738 |
[int(self.comp_a[i].item()), int(self.comp_b[i].item())],
|
| 1739 |
expected[i].item(),
|
| 1740 |
-
out[i, 0].item()
|
| 1741 |
))
|
| 1742 |
-
|
| 1743 |
-
self._record(prefix, int(correct[0].item()), len(self.comp_a), failures)
|
| 1744 |
if debug:
|
| 1745 |
r = self.results[-1]
|
| 1746 |
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
|
| 1747 |
-
scores += correct
|
| 1748 |
-
total += len(self.comp_a)
|
| 1749 |
-
except KeyError:
|
| 1750 |
-
pass
|
| 1751 |
-
|
| 1752 |
return scores, total
|
| 1753 |
|
| 1754 |
def _test_comparators_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
|
| 1755 |
-
"""Test N-bit comparator circuits (GT, LT, GE, LE, EQ)."""
|
| 1756 |
pop_size = next(iter(pop.values())).shape[0]
|
| 1757 |
scores = torch.zeros(pop_size, device=self.device)
|
| 1758 |
total = 0
|
| 1759 |
|
| 1760 |
if debug:
|
| 1761 |
-
print(f"\n=== {bits}-BIT COMPARATORS ===")
|
| 1762 |
|
| 1763 |
if bits == 32:
|
| 1764 |
comp_a = self.comp32_a
|
|
@@ -1771,7 +1827,80 @@ class BatchedFitnessEvaluator:
|
|
| 1771 |
comp_b = self.comp_b
|
| 1772 |
|
| 1773 |
num_tests = len(comp_a)
|
|
|
|
|
|
|
| 1774 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1775 |
if bits <= 16:
|
| 1776 |
a_bits = torch.stack([((comp_a >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
|
| 1777 |
b_bits = torch.stack([((comp_b >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
|
|
|
|
| 1636 |
# COMPARATORS
|
| 1637 |
# =========================================================================
|
| 1638 |
|
| 1639 |
+
def _eval_bit_cascade_compare(
|
| 1640 |
+
self,
|
| 1641 |
+
pop: Dict,
|
| 1642 |
+
cmp_prefix: str,
|
| 1643 |
+
out_gt: str,
|
| 1644 |
+
out_lt: str,
|
| 1645 |
+
out_ge: str,
|
| 1646 |
+
out_le: str,
|
| 1647 |
+
out_eq: str,
|
| 1648 |
+
bits: int,
|
| 1649 |
+
a_bits_2d: torch.Tensor,
|
| 1650 |
+
b_bits_2d: torch.Tensor,
|
| 1651 |
+
) -> Dict[str, torch.Tensor]:
|
| 1652 |
+
"""Walk the ternary bit-cascade comparator generated by
|
| 1653 |
+
build.add_bit_cascade_compare. Returns a dict with gt/lt/ge/le/eq each of
|
| 1654 |
+
shape [num_tests, pop_size]. a_bits_2d, b_bits_2d are [num_tests, bits]
|
| 1655 |
+
MSB-first.
|
| 1656 |
+
"""
|
| 1657 |
pop_size = next(iter(pop.values())).shape[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1658 |
|
| 1659 |
+
# Per-bit gt, lt, eq
|
| 1660 |
+
gt_b: List[torch.Tensor] = []
|
| 1661 |
+
lt_b: List[torch.Tensor] = []
|
| 1662 |
+
eq_b: List[torch.Tensor] = []
|
| 1663 |
+
for i in range(bits):
|
| 1664 |
+
a_i = a_bits_2d[:, i].unsqueeze(1).expand(-1, pop_size)
|
| 1665 |
+
b_i = b_bits_2d[:, i].unsqueeze(1).expand(-1, pop_size)
|
| 1666 |
+
ab = torch.stack([a_i, b_i], dim=-1)
|
| 1667 |
+
|
| 1668 |
+
w = pop[f'{cmp_prefix}.bit{i}.gt.weight'].view(pop_size, 2)
|
| 1669 |
+
bb = pop[f'{cmp_prefix}.bit{i}.gt.bias'].view(pop_size)
|
| 1670 |
+
gt_b.append(heaviside((ab * w).sum(-1) + bb))
|
| 1671 |
+
|
| 1672 |
+
w = pop[f'{cmp_prefix}.bit{i}.lt.weight'].view(pop_size, 2)
|
| 1673 |
+
bb = pop[f'{cmp_prefix}.bit{i}.lt.bias'].view(pop_size)
|
| 1674 |
+
lt_b.append(heaviside((ab * w).sum(-1) + bb))
|
| 1675 |
+
|
| 1676 |
+
w = pop[f'{cmp_prefix}.bit{i}.eq.layer1.and.weight'].view(pop_size, 2)
|
| 1677 |
+
bb = pop[f'{cmp_prefix}.bit{i}.eq.layer1.and.bias'].view(pop_size)
|
| 1678 |
+
h_and = heaviside((ab * w).sum(-1) + bb)
|
| 1679 |
+
w = pop[f'{cmp_prefix}.bit{i}.eq.layer1.nor.weight'].view(pop_size, 2)
|
| 1680 |
+
bb = pop[f'{cmp_prefix}.bit{i}.eq.layer1.nor.bias'].view(pop_size)
|
| 1681 |
+
h_nor = heaviside((ab * w).sum(-1) + bb)
|
| 1682 |
+
hidden = torch.stack([h_and, h_nor], dim=-1)
|
| 1683 |
+
w = pop[f'{cmp_prefix}.bit{i}.eq.weight'].view(pop_size, 2)
|
| 1684 |
+
bb = pop[f'{cmp_prefix}.bit{i}.eq.bias'].view(pop_size)
|
| 1685 |
+
eq_b.append(heaviside((hidden * w).sum(-1) + bb))
|
| 1686 |
+
|
| 1687 |
+
# eq_prefix[i] = AND of eq[0..i-1]
|
| 1688 |
+
eq_pref: List[Optional[torch.Tensor]] = [None]
|
| 1689 |
+
for i in range(1, bits):
|
| 1690 |
+
eq_stack = torch.stack(eq_b[:i], dim=-1)
|
| 1691 |
+
w = pop[f'{cmp_prefix}.cascade.eq_prefix.bit{i}.weight'].view(pop_size, i)
|
| 1692 |
+
bb = pop[f'{cmp_prefix}.cascade.eq_prefix.bit{i}.bias'].view(pop_size)
|
| 1693 |
+
eq_pref.append(heaviside((eq_stack * w).sum(-1) + bb))
|
| 1694 |
+
|
| 1695 |
+
# cascade gt[i], lt[i] = eq_prefix[i] AND gt_b[i] / lt_b[i]
|
| 1696 |
+
casc_gt = [gt_b[0]]
|
| 1697 |
+
casc_lt = [lt_b[0]]
|
| 1698 |
+
for i in range(1, bits):
|
| 1699 |
+
inp = torch.stack([eq_pref[i], gt_b[i]], dim=-1)
|
| 1700 |
+
w = pop[f'{cmp_prefix}.cascade.gt.bit{i}.weight'].view(pop_size, 2)
|
| 1701 |
+
bb = pop[f'{cmp_prefix}.cascade.gt.bit{i}.bias'].view(pop_size)
|
| 1702 |
+
casc_gt.append(heaviside((inp * w).sum(-1) + bb))
|
| 1703 |
+
inp = torch.stack([eq_pref[i], lt_b[i]], dim=-1)
|
| 1704 |
+
w = pop[f'{cmp_prefix}.cascade.lt.bit{i}.weight'].view(pop_size, 2)
|
| 1705 |
+
bb = pop[f'{cmp_prefix}.cascade.lt.bit{i}.bias'].view(pop_size)
|
| 1706 |
+
casc_lt.append(heaviside((inp * w).sum(-1) + bb))
|
| 1707 |
+
|
| 1708 |
+
# Final OR for GT / LT
|
| 1709 |
+
gt_stack = torch.stack(casc_gt, dim=-1)
|
| 1710 |
+
w = pop[f'{out_gt}.weight'].view(pop_size, bits)
|
| 1711 |
+
bb = pop[f'{out_gt}.bias'].view(pop_size)
|
| 1712 |
+
final_gt = heaviside((gt_stack * w).sum(-1) + bb)
|
| 1713 |
+
|
| 1714 |
+
lt_stack = torch.stack(casc_lt, dim=-1)
|
| 1715 |
+
w = pop[f'{out_lt}.weight'].view(pop_size, bits)
|
| 1716 |
+
bb = pop[f'{out_lt}.bias'].view(pop_size)
|
| 1717 |
+
final_lt = heaviside((lt_stack * w).sum(-1) + bb)
|
| 1718 |
+
|
| 1719 |
+
# Final AND for EQ
|
| 1720 |
+
eq_stack = torch.stack(eq_b, dim=-1)
|
| 1721 |
+
w = pop[f'{out_eq}.weight'].view(pop_size, bits)
|
| 1722 |
+
bb = pop[f'{out_eq}.bias'].view(pop_size)
|
| 1723 |
+
final_eq = heaviside((eq_stack * w).sum(-1) + bb)
|
| 1724 |
+
|
| 1725 |
+
# GE = NOT(LT) buffer pair, LE = NOT(GT) buffer pair
|
| 1726 |
+
w = pop[f'{out_ge}.not_lt.weight'].view(pop_size)
|
| 1727 |
+
bb = pop[f'{out_ge}.not_lt.bias'].view(pop_size)
|
| 1728 |
+
not_lt = heaviside(final_lt * w + bb)
|
| 1729 |
+
w = pop[f'{out_ge}.weight'].view(pop_size)
|
| 1730 |
+
bb = pop[f'{out_ge}.bias'].view(pop_size)
|
| 1731 |
+
final_ge = heaviside(not_lt * w + bb)
|
| 1732 |
+
|
| 1733 |
+
w = pop[f'{out_le}.not_gt.weight'].view(pop_size)
|
| 1734 |
+
bb = pop[f'{out_le}.not_gt.bias'].view(pop_size)
|
| 1735 |
+
not_gt = heaviside(final_gt * w + bb)
|
| 1736 |
+
w = pop[f'{out_le}.weight'].view(pop_size)
|
| 1737 |
+
bb = pop[f'{out_le}.bias'].view(pop_size)
|
| 1738 |
+
final_le = heaviside(not_gt * w + bb)
|
| 1739 |
+
|
| 1740 |
+
return {
|
| 1741 |
+
"gt": final_gt, "lt": final_lt, "eq": final_eq,
|
| 1742 |
+
"ge": final_ge, "le": final_le,
|
| 1743 |
+
}
|
| 1744 |
|
| 1745 |
def _test_comparators(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
|
| 1746 |
+
"""Test 8-bit comparators (bit-cascade)."""
|
| 1747 |
pop_size = next(iter(pop.values())).shape[0]
|
| 1748 |
scores = torch.zeros(pop_size, device=self.device)
|
| 1749 |
total = 0
|
| 1750 |
|
| 1751 |
if debug:
|
| 1752 |
+
print("\n=== COMPARATORS (8-bit bit-cascade) ===")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1753 |
|
| 1754 |
+
bits = 8
|
| 1755 |
+
a_bits = torch.stack([((self.comp_a >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
|
| 1756 |
+
b_bits = torch.stack([((self.comp_b >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
|
| 1757 |
try:
|
| 1758 |
+
outs = self._eval_bit_cascade_compare(
|
| 1759 |
+
pop,
|
| 1760 |
+
f"arithmetic.cmp{bits}bit",
|
| 1761 |
+
f"arithmetic.greaterthan{bits}bit",
|
| 1762 |
+
f"arithmetic.lessthan{bits}bit",
|
| 1763 |
+
f"arithmetic.greaterorequal{bits}bit",
|
| 1764 |
+
f"arithmetic.lessorequal{bits}bit",
|
| 1765 |
+
f"arithmetic.equality{bits}bit",
|
| 1766 |
+
bits,
|
| 1767 |
+
a_bits,
|
| 1768 |
+
b_bits,
|
| 1769 |
+
)
|
| 1770 |
+
except KeyError:
|
| 1771 |
+
return scores, total
|
| 1772 |
+
|
| 1773 |
+
for kind, op in [
|
| 1774 |
+
("gt", lambda a, b: a > b),
|
| 1775 |
+
("lt", lambda a, b: a < b),
|
| 1776 |
+
("ge", lambda a, b: a >= b),
|
| 1777 |
+
("le", lambda a, b: a <= b),
|
| 1778 |
+
("eq", lambda a, b: a == b),
|
| 1779 |
+
]:
|
| 1780 |
+
expected = torch.tensor(
|
| 1781 |
+
[1.0 if op(a.item(), b.item()) else 0.0 for a, b in zip(self.comp_a, self.comp_b)],
|
| 1782 |
+
device=self.device,
|
| 1783 |
+
)
|
| 1784 |
+
out = outs[kind]
|
| 1785 |
correct = (out == expected.unsqueeze(1)).float().sum(0)
|
| 1786 |
+
scores += correct
|
| 1787 |
+
total += len(self.comp_a)
|
| 1788 |
+
name_map = {
|
| 1789 |
+
"gt": f"arithmetic.greaterthan{bits}bit",
|
| 1790 |
+
"lt": f"arithmetic.lessthan{bits}bit",
|
| 1791 |
+
"ge": f"arithmetic.greaterorequal{bits}bit",
|
| 1792 |
+
"le": f"arithmetic.lessorequal{bits}bit",
|
| 1793 |
+
"eq": f"arithmetic.equality{bits}bit",
|
| 1794 |
+
}
|
| 1795 |
failures = []
|
| 1796 |
if pop_size == 1:
|
| 1797 |
for i in range(len(self.comp_a)):
|
|
|
|
| 1799 |
failures.append((
|
| 1800 |
[int(self.comp_a[i].item()), int(self.comp_b[i].item())],
|
| 1801 |
expected[i].item(),
|
| 1802 |
+
out[i, 0].item(),
|
| 1803 |
))
|
| 1804 |
+
self._record(name_map[kind], int(correct[0].item()), len(self.comp_a), failures)
|
|
|
|
| 1805 |
if debug:
|
| 1806 |
r = self.results[-1]
|
| 1807 |
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1808 |
return scores, total
|
| 1809 |
|
| 1810 |
def _test_comparators_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
|
| 1811 |
+
"""Test N-bit comparator circuits (GT, LT, GE, LE, EQ) via bit-cascade."""
|
| 1812 |
pop_size = next(iter(pop.values())).shape[0]
|
| 1813 |
scores = torch.zeros(pop_size, device=self.device)
|
| 1814 |
total = 0
|
| 1815 |
|
| 1816 |
if debug:
|
| 1817 |
+
print(f"\n=== {bits}-BIT COMPARATORS (bit-cascade) ===")
|
| 1818 |
|
| 1819 |
if bits == 32:
|
| 1820 |
comp_a = self.comp32_a
|
|
|
|
| 1827 |
comp_b = self.comp_b
|
| 1828 |
|
| 1829 |
num_tests = len(comp_a)
|
| 1830 |
+
a_bits = torch.stack([((comp_a >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
|
| 1831 |
+
b_bits = torch.stack([((comp_b >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
|
| 1832 |
|
| 1833 |
+
try:
|
| 1834 |
+
outs = self._eval_bit_cascade_compare(
|
| 1835 |
+
pop,
|
| 1836 |
+
f"arithmetic.cmp{bits}bit",
|
| 1837 |
+
f"arithmetic.greaterthan{bits}bit",
|
| 1838 |
+
f"arithmetic.lessthan{bits}bit",
|
| 1839 |
+
f"arithmetic.greaterorequal{bits}bit",
|
| 1840 |
+
f"arithmetic.lessorequal{bits}bit",
|
| 1841 |
+
f"arithmetic.equality{bits}bit",
|
| 1842 |
+
bits,
|
| 1843 |
+
a_bits,
|
| 1844 |
+
b_bits,
|
| 1845 |
+
)
|
| 1846 |
+
except KeyError:
|
| 1847 |
+
return scores, total
|
| 1848 |
+
|
| 1849 |
+
for kind, op in [
|
| 1850 |
+
("gt", lambda a, b: a > b),
|
| 1851 |
+
("lt", lambda a, b: a < b),
|
| 1852 |
+
("ge", lambda a, b: a >= b),
|
| 1853 |
+
("le", lambda a, b: a <= b),
|
| 1854 |
+
("eq", lambda a, b: a == b),
|
| 1855 |
+
]:
|
| 1856 |
+
expected = torch.tensor(
|
| 1857 |
+
[1.0 if op(a.item(), b.item()) else 0.0 for a, b in zip(comp_a, comp_b)],
|
| 1858 |
+
device=self.device,
|
| 1859 |
+
)
|
| 1860 |
+
out = outs[kind]
|
| 1861 |
+
correct = (out == expected.unsqueeze(1)).float().sum(0)
|
| 1862 |
+
scores += correct
|
| 1863 |
+
total += num_tests
|
| 1864 |
+
name_map = {
|
| 1865 |
+
"gt": f"arithmetic.greaterthan{bits}bit",
|
| 1866 |
+
"lt": f"arithmetic.lessthan{bits}bit",
|
| 1867 |
+
"ge": f"arithmetic.greaterorequal{bits}bit",
|
| 1868 |
+
"le": f"arithmetic.lessorequal{bits}bit",
|
| 1869 |
+
"eq": f"arithmetic.equality{bits}bit",
|
| 1870 |
+
}
|
| 1871 |
+
failures = []
|
| 1872 |
+
if pop_size == 1:
|
| 1873 |
+
for i in range(num_tests):
|
| 1874 |
+
if out[i, 0].item() != expected[i].item():
|
| 1875 |
+
failures.append((
|
| 1876 |
+
[int(comp_a[i].item()), int(comp_b[i].item())],
|
| 1877 |
+
expected[i].item(),
|
| 1878 |
+
out[i, 0].item(),
|
| 1879 |
+
))
|
| 1880 |
+
self._record(name_map[kind], int(correct[0].item()), num_tests, failures)
|
| 1881 |
+
if debug:
|
| 1882 |
+
r = self.results[-1]
|
| 1883 |
+
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
|
| 1884 |
+
|
| 1885 |
+
return scores, total
|
| 1886 |
+
|
| 1887 |
+
# Legacy single-layer/byte-cascade path retained for backwards-compat with
|
| 1888 |
+
# variants built before the bit-cascade migration. Unused on freshly-built
|
| 1889 |
+
# variants but kept to avoid surprises if someone loads an older file.
|
| 1890 |
+
def _test_comparators_nbits_legacy(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
|
| 1891 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 1892 |
+
scores = torch.zeros(pop_size, device=self.device)
|
| 1893 |
+
total = 0
|
| 1894 |
+
if bits == 32:
|
| 1895 |
+
comp_a = self.comp32_a
|
| 1896 |
+
comp_b = self.comp32_b
|
| 1897 |
+
elif bits == 16:
|
| 1898 |
+
comp_a = self.comp_a.clamp(0, 65535)
|
| 1899 |
+
comp_b = self.comp_b.clamp(0, 65535)
|
| 1900 |
+
else:
|
| 1901 |
+
comp_a = self.comp_a
|
| 1902 |
+
comp_b = self.comp_b
|
| 1903 |
+
num_tests = len(comp_a)
|
| 1904 |
if bits <= 16:
|
| 1905 |
a_bits = torch.stack([((comp_a >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
|
| 1906 |
b_bits = torch.stack([((comp_b >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:67c1d45eebfde84f4a82fc272ca94b80f23007a69f6d26c120fce62b86eb8b3c
|
| 3 |
+
size 21787023
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c2456dd03f42ff0f8a268ad865e326ee5ea0506987a21e08b77d2bc3fafde970
|
| 3 |
+
size 13852411
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5470a8ea7b8d0fe2cfb416f8f12b357b59d4cc6f439219080089db14b18d8ea0
|
| 3 |
+
size 13267721
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:634a5e8bcb0d4daef76bafc7337187ec7de0e4fd75095020d06fcd6381d78180
|
| 3 |
+
size 13029504
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c33ef0e17aee7ceb7f19e5675a4f3f873e48e94a67b35fcb549ddcec60f04bdf
|
| 3 |
+
size 22353297
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:34127e23065b24b0c27b0ae485531f4c174dedebfe7cf399dc9fd2ef765b03a1
|
| 3 |
+
size 14542033
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:58058a73c9b0082e8bfbbeef816f971181936e7992ec2213d4657e54e863f967
|
| 3 |
+
size 13939193
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cb1a44b57d7b78f5df0f12d48343514eb480f625bcad90bb73b757df62157c72
|
| 3 |
+
size 14019889
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:de94a6123e7c785adf69054da8d3d54e90cd9c76034a7439244b868c72485c61
|
| 3 |
+
size 14138753
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:67c1d45eebfde84f4a82fc272ca94b80f23007a69f6d26c120fce62b86eb8b3c
|
| 3 |
+
size 21787023
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cd1294e256be584333fb1969b8130e7a0613b15352ef082be9fa318048548b09
|
| 3 |
+
size 13975735
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e6d3191aee7c7420989701f80d3a4ab675655790e8ce05229d1a6526a09887e7
|
| 3 |
+
size 13372895
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d6d56828f2884b6fe510c4a53030c934e0d1097fa8092c7e15a775b0676c5d55
|
| 3 |
+
size 13453599
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:729c94a95ead7b34529281072f290bbd289d2b4cdae4e48bbe694f881729dbed
|
| 3 |
+
size 13572439
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cfb9886cf5aa965c83c694cd6f9190ccbd33556a31e67af2ddff6adee39330ab
|
| 3 |
+
size 21521230
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b64964bb91266f3a3bb5b8033e69d2a9fabc97b7800865143432fa473f89d3f7
|
| 3 |
+
size 13709942
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:50c7526bf822e84c6a4bd1bd46ea221bb7fc91c027f109d8cd53ebad2a1c9385
|
| 3 |
+
size 13107102
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0d8dedea26432f64b614854b7b8df3da01b756936dde10cff7eaca25b589e9a5
|
| 3 |
+
size 13187790
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3911afed70ae0f6d1c47c9ccd0aaf9e5a70dc79031157aba0469e4c4481af9f8
|
| 3 |
+
size 13306646
|