CharlesCNorton committed on
Commit
a696964
·
1 Parent(s): 240c04c

Fix priority encoder circuit logic

Browse files

The priority encoder was using any_ge{pos} which ORs bits from position
pos to end, but the correct logic needs any_higher{pos} which ORs bits
from position 0 to pos-1 (all higher-priority positions).

Circuit structure (MSB-first, position 0 = highest priority):
- any_higher{pos}: OR of bits 0 to pos-1
- is_highest{0}: bit[0] directly (MSB always highest if set)
- is_highest{pos}: bit[pos] AND NOT(any_higher{pos}) for pos > 0
- out{bit}: OR of is_highest signals for matching positions

This fixes the 32-bit priority encoder which was failing 18/217 tests.
Both 8-bit and 32-bit models now pass all tests:
- neural_computer.safetensors: 6772/6772 (100%)
- neural_alu32.safetensors: 7256/7256 (100%)

Files changed (4) hide show
  1. build.py +20 -8
  2. eval.py +37 -9
  3. neural_alu32.safetensors +2 -2
  4. neural_computer.safetensors +2 -2
build.py CHANGED
@@ -724,8 +724,15 @@ def add_priority_encoder_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> N
724
  """Add N-bit priority encoder circuit.
725
 
726
  Finds the position of the highest set bit (0 to bits-1).
 
727
  Output is ceil(log2(bits))-bit index + valid flag.
728
 
 
 
 
 
 
 
729
  Args:
730
  bits: Input width (8, 16, 32, etc.)
731
  """
@@ -733,18 +740,23 @@ def add_priority_encoder_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> N
733
  out_bits = max(1, math.ceil(math.log2(bits)))
734
  prefix = f"combinational.priorityencoder{bits}"
735
 
736
- # Check each bit position (OR gates to detect any bit set at or above position)
737
- for pos in range(bits):
738
- num_inputs = bits - pos
739
- weights = [1.0] * num_inputs
740
- add_gate(tensors, f"{prefix}.any_ge{pos}", weights, [-1.0])
741
-
742
- # Priority logic: pos N is highest if bit N is set AND no higher bit is set
743
- for pos in range(bits):
 
 
 
 
744
  add_gate(tensors, f"{prefix}.is_highest{pos}.not_higher", [-1.0], [0.0])
745
  add_gate(tensors, f"{prefix}.is_highest{pos}.and", [1.0, 1.0], [-2.0])
746
 
747
  # Encode position to output bits
 
748
  for out_bit in range(out_bits):
749
  weights = []
750
  for pos in range(bits):
 
724
  """Add N-bit priority encoder circuit.
725
 
726
  Finds the position of the highest set bit (0 to bits-1).
727
+ Position 0 = MSB (highest priority), Position bits-1 = LSB (lowest priority).
728
  Output is ceil(log2(bits))-bit index + valid flag.
729
 
730
+ Circuit structure:
731
+ 1. any_higher{pos}: OR of bits 0 to pos-1 (all higher-priority positions)
732
+ 2. is_highest{pos}: bit[pos] AND NOT(any_higher{pos})
733
+ 3. out{bit}: OR of is_highest{pos} for positions where (pos >> bit) & 1
734
+ 4. valid: OR of all input bits
735
+
736
  Args:
737
  bits: Input width (8, 16, 32, etc.)
738
  """
 
740
  out_bits = max(1, math.ceil(math.log2(bits)))
741
  prefix = f"combinational.priorityencoder{bits}"
742
 
743
+ # any_higher{pos}: OR of all bits at positions 0 to pos-1 (higher priority)
744
+ # any_higher{0} not needed (no higher bits)
745
+ # any_higher{1} = bit[0]
746
+ # any_higher{N} = bit[0] OR bit[1] OR ... OR bit[N-1]
747
+ for pos in range(1, bits):
748
+ weights = [1.0] * pos
749
+ add_gate(tensors, f"{prefix}.any_higher{pos}", weights, [-1.0])
750
+
751
+ # is_highest{pos}: bit[pos] is set AND no higher-priority bit is set
752
+ # is_highest{0} = bit[0] (always highest if set)
753
+ # is_highest{pos} = bit[pos] AND NOT(any_higher{pos}) for pos > 0
754
+ for pos in range(1, bits):
755
  add_gate(tensors, f"{prefix}.is_highest{pos}.not_higher", [-1.0], [0.0])
756
  add_gate(tensors, f"{prefix}.is_highest{pos}.and", [1.0, 1.0], [-2.0])
757
 
758
  # Encode position to output bits
759
+ # out{bit} = OR of is_highest{pos} for all pos where (pos >> bit) & 1
760
  for out_bit in range(out_bits):
761
  weights = []
762
  for pos in range(bits):
eval.py CHANGED
@@ -2998,7 +2998,15 @@ class BatchedFitnessEvaluator:
2998
  return scores, total
2999
 
3000
  def _test_priority_encoder_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
3001
- """Test N-bit priority encoder (find highest set bit)."""
 
 
 
 
 
 
 
 
3002
  import math
3003
  pop_size = next(iter(pop.values())).shape[0]
3004
  scores = torch.zeros(pop_size, device=self.device)
@@ -3035,17 +3043,37 @@ class BatchedFitnessEvaluator:
3035
  total += 1
3036
 
3037
  if expected_valid == 1:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3038
  for idx_bit in range(out_bits):
3039
  try:
3040
- w_idx = pop[f'{prefix}.out{idx_bit}.weight']
3041
- num_weights = w_idx.numel() // pop_size
3042
- w_idx = w_idx.view(pop_size, num_weights)
3043
  b_idx = pop[f'{prefix}.out{idx_bit}.bias'].view(pop_size)
3044
- relevant_bits = torch.tensor([val_bits[i].item() for i in range(bits)
3045
- if (i >> idx_bit) & 1],
3046
- device=self.device, dtype=torch.float32)
3047
- if len(relevant_bits) > 0:
3048
- out_idx = heaviside((relevant_bits[:w_idx.shape[1]] * w_idx).sum(-1) + b_idx)
3049
  expected_bit = (expected_idx >> idx_bit) & 1
3050
  if int(out_idx[0].item()) == expected_bit:
3051
  scores += 1
 
2998
  return scores, total
2999
 
3000
  def _test_priority_encoder_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
3001
+ """Test N-bit priority encoder (find highest set bit).
3002
+
3003
+ The priority encoder is a multi-layer circuit:
3004
+ 1. any_higher{pos}: OR of bits 0 to pos-1 (all higher-priority positions)
3005
+ 2. is_highest{0}: bit[0] directly (MSB is always highest if set)
3006
+ 3. is_highest{pos}: bit[pos] AND NOT(any_higher{pos}) for pos > 0
3007
+ 4. out{bit}: OR of is_highest{pos} for all pos where (pos >> bit) & 1
3008
+ 5. valid: OR of all input bits
3009
+ """
3010
  import math
3011
  pop_size = next(iter(pop.values())).shape[0]
3012
  scores = torch.zeros(pop_size, device=self.device)
 
3043
  total += 1
3044
 
3045
  if expected_valid == 1:
3046
+ any_higher = [None]
3047
+ for pos in range(1, bits):
3048
+ w = pop[f'{prefix}.any_higher{pos}.weight'].view(pop_size, -1)
3049
+ b = pop[f'{prefix}.any_higher{pos}.bias'].view(pop_size)
3050
+ inp = val_bits[:pos]
3051
+ out = heaviside((inp * w[:, :len(inp)]).sum(-1) + b)
3052
+ any_higher.append(out)
3053
+
3054
+ is_highest = []
3055
+ for pos in range(bits):
3056
+ if pos == 0:
3057
+ is_high = val_bits[0].unsqueeze(0).expand(pop_size)
3058
+ else:
3059
+ w_not = pop[f'{prefix}.is_highest{pos}.not_higher.weight'].view(pop_size, -1)
3060
+ b_not = pop[f'{prefix}.is_highest{pos}.not_higher.bias'].view(pop_size)
3061
+ not_higher = heaviside(any_higher[pos].unsqueeze(-1) * w_not + b_not).squeeze(-1)
3062
+
3063
+ w_and = pop[f'{prefix}.is_highest{pos}.and.weight'].view(pop_size, -1)
3064
+ b_and = pop[f'{prefix}.is_highest{pos}.and.bias'].view(pop_size)
3065
+ inp = torch.stack([val_bits[pos].expand(pop_size), not_higher], dim=-1)
3066
+ is_high = heaviside((inp * w_and).sum(-1) + b_and)
3067
+ is_highest.append(is_high)
3068
+
3069
  for idx_bit in range(out_bits):
3070
  try:
3071
+ w_idx = pop[f'{prefix}.out{idx_bit}.weight'].view(pop_size, -1)
 
 
3072
  b_idx = pop[f'{prefix}.out{idx_bit}.bias'].view(pop_size)
3073
+ relevant = [is_highest[pos] for pos in range(bits) if (pos >> idx_bit) & 1]
3074
+ if len(relevant) > 0:
3075
+ inp = torch.stack(relevant[:w_idx.shape[1]], dim=-1)
3076
+ out_idx = heaviside((inp * w_idx).sum(-1) + b_idx)
 
3077
  expected_bit = (expected_idx >> idx_bit) & 1
3078
  if int(out_idx[0].item()) == expected_bit:
3079
  scores += 1
neural_alu32.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6efa5b719d55fa8e071c4dacc90bfe5bff7337c6fab952460f4ccdadf237facb
3
- size 10083624
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2277b9b7ca05aeca4b84da8f8cf48c8ceba9d81ea926a2a1f6be46462fdc9944
3
+ size 10082208
neural_computer.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:812d1833c915945eeb694bca530b075b3e08685bac8646f29e87d26a2d644b88
3
- size 8436636
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a76635005d7031f01492b1c5d6286dbede39c1ecf08ed4b08daf4e7c3c2fe097
3
+ size 8435820