CharlesCNorton commited on
Commit
5d9ffc5
·
1 Parent(s): ce3e896

Add barrel shifter and priority encoder tests

Browse files

- Barrel shifter: 40 tests (left shift by 0-7 positions)
- Priority encoder: 13 tests (find highest set bit)
- Tests: 6,116 -> 6,169
- Fitness: 1.000000

Files changed (2) hide show
  1. README.md +1 -1
  2. eval.py +150 -0
README.md CHANGED
@@ -479,7 +479,7 @@ The interface generalizes to **all** 65,536 8-bit additions once trained—no me
479
  |------|-------------|
480
  | `neural_computer.safetensors` | 11,581 tensors, 8,290,134 parameters |
481
  | `threshold_cpu.py` | CPU state, reference cycle, threshold runtime |
482
- | `eval.py` | Unified evaluation suite (6,116 tests, GPU-batched) |
483
  | `build.py` | Build tools for memory, ALU, and .inputs tensors |
484
  | `prune_weights.py` | Weight magnitude pruning |
485
 
 
479
  |------|-------------|
480
  | `neural_computer.safetensors` | 11,581 tensors, 8,290,134 parameters |
481
  | `threshold_cpu.py` | CPU state, reference cycle, threshold runtime |
482
+ | `eval.py` | Unified evaluation suite (6,169 tests, GPU-batched) |
483
  | `build.py` | Build tools for memory, ALU, and .inputs tensors |
484
  | `prune_weights.py` | Weight magnitude pruning |
485
 
eval.py CHANGED
@@ -1096,6 +1096,156 @@ class BatchedFitnessEvaluator:
1096
  scores += s
1097
  total += t
1098
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1099
  return scores, total
1100
 
1101
  # =========================================================================
 
1096
  scores += s
1097
  total += t
1098
 
1099
+ s, t = self._test_barrel_shifter(pop, debug)
1100
+ scores += s
1101
+ total += t
1102
+
1103
+ s, t = self._test_priority_encoder(pop, debug)
1104
+ scores += s
1105
+ total += t
1106
+
1107
+ return scores, total
1108
+
1109
+ def _test_barrel_shifter(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
1110
+ """Test barrel shifter (shift by 0-7 positions)."""
1111
+ pop_size = next(iter(pop.values())).shape[0]
1112
+ scores = torch.zeros(pop_size, device=self.device)
1113
+ total = 0
1114
+
1115
+ if debug:
1116
+ print("\n=== BARREL SHIFTER ===")
1117
+
1118
+ try:
1119
+ # Test all shift amounts 0-7 with various input patterns
1120
+ test_vals = [0b10000001, 0b11110000, 0b00001111, 0b10101010, 0xFF]
1121
+
1122
+ for val in test_vals:
1123
+ for shift in range(8):
1124
+ expected_val = (val << shift) & 0xFF # Left shift
1125
+ val_bits = [float((val >> (7 - i)) & 1) for i in range(8)]
1126
+ shift_bits = [float((shift >> (2 - i)) & 1) for i in range(3)]
1127
+
1128
+ # Process through 3 layers
1129
+ layer_in = val_bits[:]
1130
+ for layer in range(3):
1131
+ shift_amount = 1 << (2 - layer) # 4, 2, 1
1132
+ sel = shift_bits[layer]
1133
+ layer_out = []
1134
+
1135
+ for bit in range(8):
1136
+ prefix = f'combinational.barrelshifter.layer{layer}.bit{bit}'
1137
+
1138
+ # NOT sel
1139
+ w_not = pop[f'{prefix}.not_sel.weight'].view(pop_size)
1140
+ b_not = pop[f'{prefix}.not_sel.bias'].view(pop_size)
1141
+ not_sel = heaviside(sel * w_not + b_not)
1142
+
1143
+ # Source for shifted value
1144
+ shifted_src = bit + shift_amount
1145
+ if shifted_src < 8:
1146
+ shifted_val = layer_in[shifted_src]
1147
+ else:
1148
+ shifted_val = 0.0
1149
+
1150
+ # AND a: original AND NOT sel
1151
+ w_and_a = pop[f'{prefix}.and_a.weight'].view(pop_size, 2)
1152
+ b_and_a = pop[f'{prefix}.and_a.bias'].view(pop_size)
1153
+ inp_a = torch.tensor([layer_in[bit], not_sel[0].item()], device=self.device)
1154
+ and_a = heaviside((inp_a * w_and_a).sum(-1) + b_and_a)
1155
+
1156
+ # AND b: shifted AND sel
1157
+ w_and_b = pop[f'{prefix}.and_b.weight'].view(pop_size, 2)
1158
+ b_and_b = pop[f'{prefix}.and_b.bias'].view(pop_size)
1159
+ inp_b = torch.tensor([shifted_val, sel], device=self.device)
1160
+ and_b = heaviside((inp_b * w_and_b).sum(-1) + b_and_b)
1161
+
1162
+ # OR
1163
+ w_or = pop[f'{prefix}.or.weight'].view(pop_size, 2)
1164
+ b_or = pop[f'{prefix}.or.bias'].view(pop_size)
1165
+ inp_or = torch.tensor([and_a[0].item(), and_b[0].item()], device=self.device)
1166
+ out = heaviside((inp_or * w_or).sum(-1) + b_or)
1167
+ layer_out.append(out[0].item())
1168
+
1169
+ layer_in = layer_out
1170
+
1171
+ # Check result
1172
+ result = sum(int(layer_in[i]) << (7 - i) for i in range(8))
1173
+ if result == expected_val:
1174
+ scores += 1
1175
+ total += 1
1176
+
1177
+ self._record('combinational.barrelshifter', int(scores[0].item()), total, [])
1178
+ if debug:
1179
+ r = self.results[-1]
1180
+ print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
1181
+ except (KeyError, RuntimeError) as e:
1182
+ if debug:
1183
+ print(f" combinational.barrelshifter: SKIP ({e})")
1184
+
1185
+ return scores, total
1186
+
1187
+ def _test_priority_encoder(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
1188
+ """Test priority encoder (find highest set bit)."""
1189
+ pop_size = next(iter(pop.values())).shape[0]
1190
+ scores = torch.zeros(pop_size, device=self.device)
1191
+ total = 0
1192
+
1193
+ if debug:
1194
+ print("\n=== PRIORITY ENCODER ===")
1195
+
1196
+ try:
1197
+ # Test cases: input -> (valid, index of highest bit)
1198
+ test_cases = [
1199
+ (0b00000000, 0, 0), # No bits set, valid=0
1200
+ (0b00000001, 1, 7), # Bit 7 (LSB)
1201
+ (0b00000010, 1, 6),
1202
+ (0b00000100, 1, 5),
1203
+ (0b00001000, 1, 4),
1204
+ (0b00010000, 1, 3),
1205
+ (0b00100000, 1, 2),
1206
+ (0b01000000, 1, 1),
1207
+ (0b10000000, 1, 0), # Bit 0 (MSB)
1208
+ (0b10000001, 1, 0), # Multiple bits, highest wins
1209
+ (0b01010101, 1, 1),
1210
+ (0b00001111, 1, 4),
1211
+ (0b11111111, 1, 0),
1212
+ ]
1213
+
1214
+ for val, expected_valid, expected_idx in test_cases:
1215
+ val_bits = torch.tensor([float((val >> (7 - i)) & 1) for i in range(8)],
1216
+ device=self.device, dtype=torch.float32)
1217
+
1218
+ # Valid output: OR of all input bits
1219
+ w_valid = pop['combinational.priorityencoder.valid.weight'].view(pop_size, 8)
1220
+ b_valid = pop['combinational.priorityencoder.valid.bias'].view(pop_size)
1221
+ out_valid = heaviside((val_bits * w_valid).sum(-1) + b_valid)
1222
+
1223
+ if int(out_valid[0].item()) == expected_valid:
1224
+ scores += 1
1225
+ total += 1
1226
+
1227
+ # Index outputs (3 bits)
1228
+ if expected_valid == 1:
1229
+ for idx_bit in range(3):
1230
+ try:
1231
+ w_idx = pop[f'combinational.priorityencoder.idx{idx_bit}.weight'].view(pop_size, 8)
1232
+ b_idx = pop[f'combinational.priorityencoder.idx{idx_bit}.bias'].view(pop_size)
1233
+ out_idx = heaviside((val_bits * w_idx).sum(-1) + b_idx)
1234
+ expected_bit = (expected_idx >> (2 - idx_bit)) & 1
1235
+ if int(out_idx[0].item()) == expected_bit:
1236
+ scores += 1
1237
+ total += 1
1238
+ except KeyError:
1239
+ pass
1240
+
1241
+ self._record('combinational.priorityencoder', int(scores[0].item()), total, [])
1242
+ if debug:
1243
+ r = self.results[-1]
1244
+ print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
1245
+ except (KeyError, RuntimeError) as e:
1246
+ if debug:
1247
+ print(f" combinational.priorityencoder: SKIP ({e})")
1248
+
1249
  return scores, total
1250
 
1251
  # =========================================================================