CharlesCNorton committed on
Commit ·
a696964
1
Parent(s): 240c04c
Fix priority encoder circuit logic
Browse files
The priority encoder was using any_ge{pos}, which ORs bits from position
pos to the end, but the correct logic needs any_higher{pos}, which ORs bits
from position 0 to pos-1 (all higher-priority positions).
Circuit structure (MSB-first, position 0 = highest priority):
- any_higher{pos}: OR of bits 0 to pos-1
- is_highest{0}: bit[0] directly (MSB always highest if set)
- is_highest{pos}: bit[pos] AND NOT(any_higher{pos}) for pos > 0
- out{bit}: OR of is_highest signals for matching positions
This fixes the 32-bit priority encoder, which was failing 18/217 tests.
Both 8-bit and 32-bit models now pass all tests:
- neural_computer.safetensors: 6772/6772 (100%)
- neural_alu32.safetensors: 7256/7256 (100%)
- build.py +20 -8
- eval.py +37 -9
- neural_alu32.safetensors +2 -2
- neural_computer.safetensors +2 -2
build.py
CHANGED
|
@@ -724,8 +724,15 @@ def add_priority_encoder_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> N
|
|
| 724 |
"""Add N-bit priority encoder circuit.
|
| 725 |
|
| 726 |
Finds the position of the highest set bit (0 to bits-1).
|
|
|
|
| 727 |
Output is ceil(log2(bits))-bit index + valid flag.
|
| 728 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 729 |
Args:
|
| 730 |
bits: Input width (8, 16, 32, etc.)
|
| 731 |
"""
|
|
@@ -733,18 +740,23 @@ def add_priority_encoder_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> N
|
|
| 733 |
out_bits = max(1, math.ceil(math.log2(bits)))
|
| 734 |
prefix = f"combinational.priorityencoder{bits}"
|
| 735 |
|
| 736 |
-
#
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 744 |
add_gate(tensors, f"{prefix}.is_highest{pos}.not_higher", [-1.0], [0.0])
|
| 745 |
add_gate(tensors, f"{prefix}.is_highest{pos}.and", [1.0, 1.0], [-2.0])
|
| 746 |
|
| 747 |
# Encode position to output bits
|
|
|
|
| 748 |
for out_bit in range(out_bits):
|
| 749 |
weights = []
|
| 750 |
for pos in range(bits):
|
|
|
|
| 724 |
"""Add N-bit priority encoder circuit.
|
| 725 |
|
| 726 |
Finds the position of the highest set bit (0 to bits-1).
|
| 727 |
+
Position 0 = MSB (highest priority), Position bits-1 = LSB (lowest priority).
|
| 728 |
Output is ceil(log2(bits))-bit index + valid flag.
|
| 729 |
|
| 730 |
+
Circuit structure:
|
| 731 |
+
1. any_higher{pos}: OR of bits 0 to pos-1 (all higher-priority positions)
|
| 732 |
+
2. is_highest{pos}: bit[pos] AND NOT(any_higher{pos})
|
| 733 |
+
3. out{bit}: OR of is_highest{pos} for positions where (pos >> bit) & 1
|
| 734 |
+
4. valid: OR of all input bits
|
| 735 |
+
|
| 736 |
Args:
|
| 737 |
bits: Input width (8, 16, 32, etc.)
|
| 738 |
"""
|
|
|
|
| 740 |
out_bits = max(1, math.ceil(math.log2(bits)))
|
| 741 |
prefix = f"combinational.priorityencoder{bits}"
|
| 742 |
|
| 743 |
+
# any_higher{pos}: OR of all bits at positions 0 to pos-1 (higher priority)
|
| 744 |
+
# any_higher{0} not needed (no higher bits)
|
| 745 |
+
# any_higher{1} = bit[0]
|
| 746 |
+
# any_higher{N} = bit[0] OR bit[1] OR ... OR bit[N-1]
|
| 747 |
+
for pos in range(1, bits):
|
| 748 |
+
weights = [1.0] * pos
|
| 749 |
+
add_gate(tensors, f"{prefix}.any_higher{pos}", weights, [-1.0])
|
| 750 |
+
|
| 751 |
+
# is_highest{pos}: bit[pos] is set AND no higher-priority bit is set
|
| 752 |
+
# is_highest{0} = bit[0] (always highest if set)
|
| 753 |
+
# is_highest{pos} = bit[pos] AND NOT(any_higher{pos}) for pos > 0
|
| 754 |
+
for pos in range(1, bits):
|
| 755 |
add_gate(tensors, f"{prefix}.is_highest{pos}.not_higher", [-1.0], [0.0])
|
| 756 |
add_gate(tensors, f"{prefix}.is_highest{pos}.and", [1.0, 1.0], [-2.0])
|
| 757 |
|
| 758 |
# Encode position to output bits
|
| 759 |
+
# out{bit} = OR of is_highest{pos} for all pos where (pos >> bit) & 1
|
| 760 |
for out_bit in range(out_bits):
|
| 761 |
weights = []
|
| 762 |
for pos in range(bits):
|
eval.py
CHANGED
|
@@ -2998,7 +2998,15 @@ class BatchedFitnessEvaluator:
|
|
| 2998 |
return scores, total
|
| 2999 |
|
| 3000 |
def _test_priority_encoder_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
|
| 3001 |
-
"""Test N-bit priority encoder (find highest set bit).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3002 |
import math
|
| 3003 |
pop_size = next(iter(pop.values())).shape[0]
|
| 3004 |
scores = torch.zeros(pop_size, device=self.device)
|
|
@@ -3035,17 +3043,37 @@ class BatchedFitnessEvaluator:
|
|
| 3035 |
total += 1
|
| 3036 |
|
| 3037 |
if expected_valid == 1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3038 |
for idx_bit in range(out_bits):
|
| 3039 |
try:
|
| 3040 |
-
w_idx = pop[f'{prefix}.out{idx_bit}.weight']
|
| 3041 |
-
num_weights = w_idx.numel() // pop_size
|
| 3042 |
-
w_idx = w_idx.view(pop_size, num_weights)
|
| 3043 |
b_idx = pop[f'{prefix}.out{idx_bit}.bias'].view(pop_size)
|
| 3044 |
-
|
| 3045 |
-
|
| 3046 |
-
|
| 3047 |
-
|
| 3048 |
-
out_idx = heaviside((relevant_bits[:w_idx.shape[1]] * w_idx).sum(-1) + b_idx)
|
| 3049 |
expected_bit = (expected_idx >> idx_bit) & 1
|
| 3050 |
if int(out_idx[0].item()) == expected_bit:
|
| 3051 |
scores += 1
|
|
|
|
| 2998 |
return scores, total
|
| 2999 |
|
| 3000 |
def _test_priority_encoder_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
|
| 3001 |
+
"""Test N-bit priority encoder (find highest set bit).
|
| 3002 |
+
|
| 3003 |
+
The priority encoder is a multi-layer circuit:
|
| 3004 |
+
1. any_higher{pos}: OR of bits 0 to pos-1 (all higher-priority positions)
|
| 3005 |
+
2. is_highest{0}: bit[0] directly (MSB is always highest if set)
|
| 3006 |
+
3. is_highest{pos}: bit[pos] AND NOT(any_higher{pos}) for pos > 0
|
| 3007 |
+
4. out{bit}: OR of is_highest{pos} for all pos where (pos >> bit) & 1
|
| 3008 |
+
5. valid: OR of all input bits
|
| 3009 |
+
"""
|
| 3010 |
import math
|
| 3011 |
pop_size = next(iter(pop.values())).shape[0]
|
| 3012 |
scores = torch.zeros(pop_size, device=self.device)
|
|
|
|
| 3043 |
total += 1
|
| 3044 |
|
| 3045 |
if expected_valid == 1:
|
| 3046 |
+
any_higher = [None]
|
| 3047 |
+
for pos in range(1, bits):
|
| 3048 |
+
w = pop[f'{prefix}.any_higher{pos}.weight'].view(pop_size, -1)
|
| 3049 |
+
b = pop[f'{prefix}.any_higher{pos}.bias'].view(pop_size)
|
| 3050 |
+
inp = val_bits[:pos]
|
| 3051 |
+
out = heaviside((inp * w[:, :len(inp)]).sum(-1) + b)
|
| 3052 |
+
any_higher.append(out)
|
| 3053 |
+
|
| 3054 |
+
is_highest = []
|
| 3055 |
+
for pos in range(bits):
|
| 3056 |
+
if pos == 0:
|
| 3057 |
+
is_high = val_bits[0].unsqueeze(0).expand(pop_size)
|
| 3058 |
+
else:
|
| 3059 |
+
w_not = pop[f'{prefix}.is_highest{pos}.not_higher.weight'].view(pop_size, -1)
|
| 3060 |
+
b_not = pop[f'{prefix}.is_highest{pos}.not_higher.bias'].view(pop_size)
|
| 3061 |
+
not_higher = heaviside(any_higher[pos].unsqueeze(-1) * w_not + b_not).squeeze(-1)
|
| 3062 |
+
|
| 3063 |
+
w_and = pop[f'{prefix}.is_highest{pos}.and.weight'].view(pop_size, -1)
|
| 3064 |
+
b_and = pop[f'{prefix}.is_highest{pos}.and.bias'].view(pop_size)
|
| 3065 |
+
inp = torch.stack([val_bits[pos].expand(pop_size), not_higher], dim=-1)
|
| 3066 |
+
is_high = heaviside((inp * w_and).sum(-1) + b_and)
|
| 3067 |
+
is_highest.append(is_high)
|
| 3068 |
+
|
| 3069 |
for idx_bit in range(out_bits):
|
| 3070 |
try:
|
| 3071 |
+
w_idx = pop[f'{prefix}.out{idx_bit}.weight'].view(pop_size, -1)
|
|
|
|
|
|
|
| 3072 |
b_idx = pop[f'{prefix}.out{idx_bit}.bias'].view(pop_size)
|
| 3073 |
+
relevant = [is_highest[pos] for pos in range(bits) if (pos >> idx_bit) & 1]
|
| 3074 |
+
if len(relevant) > 0:
|
| 3075 |
+
inp = torch.stack(relevant[:w_idx.shape[1]], dim=-1)
|
| 3076 |
+
out_idx = heaviside((inp * w_idx).sum(-1) + b_idx)
|
|
|
|
| 3077 |
expected_bit = (expected_idx >> idx_bit) & 1
|
| 3078 |
if int(out_idx[0].item()) == expected_bit:
|
| 3079 |
scores += 1
|
neural_alu32.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2277b9b7ca05aeca4b84da8f8cf48c8ceba9d81ea926a2a1f6be46462fdc9944
|
| 3 |
+
size 10082208
|
neural_computer.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a76635005d7031f01492b1c5d6286dbede39c1ecf08ed4b08daf4e7c3c2fe097
|
| 3 |
+
size 8435820
|