PortfolioAI commited on
Commit
b8ed073
·
1 Parent(s): a4326d5

Add CLZ16BIT, fix README claims, update TODO

Browse files

- arithmetic.clz16bit: 16-bit count leading zeros (63 gates)
- Remove "formally verified" claim (exhaustively tested, not formally proven)
- Mark evaluator improvements complete
- WIP: float16.normalize scaffolding

Files changed (5) hide show
  1. README.md +1 -1
  2. TODO.md +4 -4
  3. arithmetic.safetensors +2 -2
  4. convert_to_explicit_inputs.py +391 -0
  5. eval.py +51 -0
README.md CHANGED
@@ -16,7 +16,7 @@ pipeline_tag: other
16
 
17
  **Verified arithmetic circuits as frozen neural network weights.**
18
 
19
- This repository contains a complete, formally verified arithmetic core implemented as threshold logic gates stored in safetensors format. Every tensor in this model represents a neural network weight or bias that, when combined with a Heaviside step activation function, computes exact arithmetic operations with 100% correctness across all possible inputs.
20
 
21
  ---
22
 
 
16
 
17
  **Verified arithmetic circuits as frozen neural network weights.**
18
 
19
+ This repository contains an arithmetic core implemented as threshold logic gates stored in safetensors format. Every tensor represents a neural network weight or bias that, when combined with a Heaviside step activation function, computes exact arithmetic operations. Circuits are tested against reference implementations — exhaustively over all possible inputs where the input space permits, and with seeded random sampling for wider inputs (100% pass rate).
20
 
21
  ---
22
 
TODO.md CHANGED
@@ -18,7 +18,7 @@
18
 
19
  ### Supporting Infrastructure
20
  - [x] `arithmetic.clz8bit` -- count leading zeros (needed for float normalization)
21
- - [ ] `arithmetic.clz16bit` -- 16-bit count leading zeros
22
 
23
  ## Medium Priority
24
 
@@ -31,9 +31,9 @@
31
  - [ ] `arithmetic.lcm8bit` -- least common multiple
32
 
33
  ### Evaluator Improvements
34
- - [ ] Full circuit evaluation using .inputs topology
35
- - [ ] Exhaustive testing for all circuits (not just comparators/thresholds)
36
- - [ ] Automatic topological sort from signal registry
37
 
38
  ## Low Priority
39
 
 
18
 
19
  ### Supporting Infrastructure
20
  - [x] `arithmetic.clz8bit` -- count leading zeros (needed for float normalization)
21
+ - [x] `arithmetic.clz16bit` -- 16-bit count leading zeros
22
 
23
  ## Medium Priority
24
 
 
31
  - [ ] `arithmetic.lcm8bit` -- least common multiple
32
 
33
  ### Evaluator Improvements
34
+ - [x] Full circuit evaluation using .inputs topology
35
+ - [x] Exhaustive testing for boolean, threshold, CLZ, float16, comparator circuits
36
+ - [x] Automatic topological sort from signal registry
37
 
38
  ## Low Priority
39
 
arithmetic.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4272c22035d7c264fd8f6bcb22c129f01cd033fb4061b77f94b4f93555a2e823
3
- size 1084844
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ebe8e155f964f27d26a8a35750f6af361556a65c1178a1c96e4dd5eea95a66c4
3
+ size 1111188
convert_to_explicit_inputs.py CHANGED
@@ -694,6 +694,105 @@ def infer_minmax_inputs(gate: str, registry: SignalRegistry) -> List[int]:
694
  return inputs
695
 
696
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
697
  def infer_clz8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
698
  """Infer inputs for CLZ8BIT (count leading zeros)."""
699
  prefix = "arithmetic.clz8bit"
@@ -938,6 +1037,8 @@ def infer_inputs_for_gate(gate: str, registry: SignalRegistry, routing: dict) ->
938
  return infer_comparator_inputs(gate, registry)
939
 
940
  # CLZ (count leading zeros)
 
 
941
  if 'clz8bit' in gate:
942
  return infer_clz8bit_inputs(gate, registry)
943
 
@@ -949,11 +1050,125 @@ def infer_inputs_for_gate(gate: str, registry: SignalRegistry, routing: dict) ->
949
  return infer_float16_pack_inputs(gate, registry)
950
  if 'cmp' in gate:
951
  return infer_float16_cmp_inputs(gate, registry)
 
 
952
 
953
  # Default: couldn't infer, return empty (will need manual fix or routing)
954
  return []
955
 
956
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
957
  def infer_float16_cmp_inputs(gate: str, registry: SignalRegistry) -> List[int]:
958
  """Infer inputs for float16.cmp circuit."""
959
  prefix = "float16.cmp"
@@ -1115,6 +1330,94 @@ def infer_float16_unpack_inputs(gate: str, registry: SignalRegistry) -> List[int
1115
  return []
1116
 
1117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1118
  def build_float16_cmp_tensors() -> Dict[str, torch.Tensor]:
1119
  """Build tensors for float16.cmp circuit.
1120
 
@@ -1255,6 +1558,90 @@ def build_float16_unpack_tensors() -> Dict[str, torch.Tensor]:
1255
  return tensors
1256
 
1257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1258
  def build_clz8bit_tensors() -> Dict[str, torch.Tensor]:
1259
  """Build tensors for arithmetic.clz8bit circuit.
1260
 
@@ -1330,6 +1717,10 @@ def main():
1330
  tensors.update(clz_tensors)
1331
  print(f" CLZ8BIT: {len(clz_tensors)} tensors")
1332
 
 
 
 
 
1333
  unpack_tensors = build_float16_unpack_tensors()
1334
  tensors.update(unpack_tensors)
1335
  print(f" float16.unpack: {len(unpack_tensors)} tensors")
 
694
  return inputs
695
 
696
 
697
def infer_clz16bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for CLZ16BIT (count leading zeros, 16-bit).

    Mirrors the CLZ8BIT topology: pz gates NOR the top-k bits, ge gates
    threshold the pz layer, and AND/OR gates convert the thermometer code
    into a 5-bit binary count.  Signals are registered in dependency order
    before any consumer asks for their ids.
    """
    prefix = "arithmetic.clz16bit"
    sid = registry.get_id

    # Primary input: the 16-bit word being counted.
    for bit in range(16):
        registry.register(f"{prefix}.$x[{bit}]")

    # pz{k}: NOR over the top k bits of x (x[15] is the MSB).
    if '.pz' in gate:
        m = re.search(r'\.pz(\d+)', gate)
        if m:
            width = int(m.group(1))
            return [sid(f"{prefix}.$x[{15 - j}]") for j in range(width)]

    for idx in range(1, 17):
        registry.register(f"{prefix}.pz{idx}")
    pz_layer = [sid(f"{prefix}.pz{idx}") for idx in range(1, 17)]

    # ge{k}: threshold over the entire pz layer (every ge gate sees all 16).
    if '.ge' in gate and '.not_ge' not in gate:
        if re.search(r'\.ge(\d+)', gate):
            return pz_layer

    for idx in range(1, 17):
        registry.register(f"{prefix}.ge{idx}")

    # not_ge{k}: inverter fed by the matching ge gate.
    if '.not_ge' in gate:
        m = re.search(r'\.not_ge(\d+)', gate)
        if m:
            return [sid(f"{prefix}.ge{int(m.group(1))}")]

    for idx in (2, 4, 6, 8, 10, 12, 14, 16):
        registry.register(f"{prefix}.not_ge{idx}")

    # Range detectors: each AND pairs a lower ge bound with an upper not_ge bound.
    range_pairs = (
        ('and_8_15', 'ge8', 'not_ge16'),
        ('and_4_7', 'ge4', 'not_ge8'),
        ('and_12_15', 'ge12', 'not_ge16'),
        ('and_2_3', 'ge2', 'not_ge4'),
        ('and_6_7', 'ge6', 'not_ge8'),
        ('and_10_11', 'ge10', 'not_ge12'),
        ('and_14_15', 'ge14', 'not_ge16'),
    )
    for name, lo, hi in range_pairs:
        if f'.{name}' in gate:
            return [sid(f"{prefix}.{lo}"), sid(f"{prefix}.{hi}")]

    # Odd single-value ANDs; the $-anchored regex keeps .and_1 from
    # matching .and_15 (and never fires for the and_a_b range names).
    m = re.search(r'\.and_(\d+)$', gate)
    if m:
        val = int(m.group(1))
        if val in (1, 3, 5, 7, 9, 11, 13, 15):
            return [sid(f"{prefix}.ge{val}"), sid(f"{prefix}.not_ge{val + 1}")]

    for name, _, _ in range_pairs:
        registry.register(f"{prefix}.{name}")
    for val in (1, 3, 5, 7, 9, 11, 13, 15):
        registry.register(f"{prefix}.and_{val}")

    # OR layers produce the low three bits of the binary count.
    if '.or_bit2' in gate:
        return [sid(f"{prefix}.and_4_7"), sid(f"{prefix}.and_12_15")]
    if '.or_bit1' in gate:
        return [sid(f"{prefix}.{n}") for n in ('and_2_3', 'and_6_7', 'and_10_11', 'and_14_15')]
    if '.or_bit0' in gate:
        return [sid(f"{prefix}.and_{val}") for val in (1, 3, 5, 7, 9, 11, 13, 15)]

    for n in ('or_bit2', 'or_bit1', 'or_bit0'):
        registry.register(f"{prefix}.{n}")

    # Output pass-throughs for the five result bits.
    out_sources = (
        ('out4', 'ge16'),
        ('out3', 'and_8_15'),
        ('out2', 'or_bit2'),
        ('out1', 'or_bit1'),
        ('out0', 'or_bit0'),
    )
    for name, src in out_sources:
        if f'.{name}' in gate:
            return [sid(f"{prefix}.{src}")]

    return []
796
  def infer_clz8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
797
  """Infer inputs for CLZ8BIT (count leading zeros)."""
798
  prefix = "arithmetic.clz8bit"
 
1037
  return infer_comparator_inputs(gate, registry)
1038
 
1039
  # CLZ (count leading zeros)
1040
+ if 'clz16bit' in gate:
1041
+ return infer_clz16bit_inputs(gate, registry)
1042
  if 'clz8bit' in gate:
1043
  return infer_clz8bit_inputs(gate, registry)
1044
 
 
1050
  return infer_float16_pack_inputs(gate, registry)
1051
  if 'cmp' in gate:
1052
  return infer_float16_cmp_inputs(gate, registry)
1053
+ if 'normalize' in gate:
1054
+ return infer_float16_normalize_inputs(gate, registry)
1055
 
1056
  # Default: couldn't infer, return empty (will need manual fix or routing)
1057
  return []
1058
 
1059
 
1060
def infer_float16_normalize_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for float16.normalize circuit.

    Given a 13-bit extended mantissa ($m[12:0]) the circuit exposes:
    - overflow / not_overflow: $m[12] set means a right shift is needed
    - is_zero: NOR of all 13 mantissa bits
    - shift{3..0}: binary encoding of min(CLZ(m[11:0]), 12)
    - out_shift{3..0}: shift bits masked by not_overflow

    Signals are registered in dependency order; registration must happen
    before any gate that consumes the signal asks for its id.
    """
    prefix = "float16.normalize"

    # Register 13-bit mantissa input
    for i in range(13):
        registry.register(f"{prefix}.$m[{i}]")

    # Overflow detection (bit 12)
    if '.overflow' in gate and '.not_overflow' not in gate:
        return [registry.get_id(f"{prefix}.$m[12]")]

    registry.register(f"{prefix}.overflow")

    # is_zero (NOR of all mantissa bits)
    if '.is_zero' in gate:
        return [registry.get_id(f"{prefix}.$m[{i}]") for i in range(13)]

    # pz gates (CLZ on bits 11:0)
    if '.pz' in gate:
        match = re.search(r'\.pz(\d+)', gate)
        if match:
            k = int(match.group(1))
            # Check top k bits of m[11:0]
            return [registry.get_id(f"{prefix}.$m[{11-i}]") for i in range(k)]

    # Register pz outputs
    for i in range(1, 13):
        registry.register(f"{prefix}.pz{i}")

    pz_ids = [registry.get_id(f"{prefix}.pz{i}") for i in range(1, 13)]

    # ge gates: every ge gate thresholds the whole pz layer
    if '.ge' in gate and '.not_ge' not in gate:
        match = re.search(r'\.ge(\d+)', gate)
        if match:
            return pz_ids

    # Register ge outputs
    for k in range(1, 13):
        registry.register(f"{prefix}.ge{k}")

    # NOT gates
    if '.not_ge' in gate:
        match = re.search(r'\.not_ge(\d+)', gate)
        if match:
            k = int(match.group(1))
            return [registry.get_id(f"{prefix}.ge{k}")]

    # Register every even not_ge used by the range/odd AND gates below.
    # (Fixed: previously only [2, 4, 8] were registered here and the rest
    # were registered lazily inside the odd-AND branch, which left
    # not_ge12 unavailable to and_10_11.  build_float16_normalize_tensors
    # emits not_ge for this same set.)
    for k in [2, 4, 6, 8, 10, 12]:
        registry.register(f"{prefix}.not_ge{k}")

    # AND gates for ranges
    if '.and_4_7' in gate:
        return [registry.get_id(f"{prefix}.ge4"), registry.get_id(f"{prefix}.not_ge8")]
    if '.and_2_3' in gate:
        return [registry.get_id(f"{prefix}.ge2"), registry.get_id(f"{prefix}.not_ge4")]
    if '.and_6_7' in gate:
        return [registry.get_id(f"{prefix}.ge6"), registry.get_id(f"{prefix}.not_ge8")]
    if '.and_10_11' in gate:
        # CLZ in {10, 11} = ge10 AND NOT ge12.  (Fixed: previously wired
        # to [ge10, ge12], which made the AND fire only for CLZ >= 12.)
        return [registry.get_id(f"{prefix}.ge10"), registry.get_id(f"{prefix}.not_ge12")]

    # Odd AND gates: CLZ == i  <=>  ge{i} AND NOT ge{i+1}
    match = re.search(r'\.and_(\d+)$', gate)
    if match:
        i = int(match.group(1))
        if i in [1, 3, 5, 7, 9, 11]:
            return [registry.get_id(f"{prefix}.ge{i}"), registry.get_id(f"{prefix}.not_ge{i + 1}")]

    # Register AND outputs
    for name in ['and_4_7', 'and_2_3', 'and_6_7', 'and_10_11']:
        registry.register(f"{prefix}.{name}")
    for i in [1, 3, 5, 7, 9, 11]:
        registry.register(f"{prefix}.and_{i}")

    # Shift bit gates
    if '.shift3' in gate:
        return [registry.get_id(f"{prefix}.ge8")]
    if '.shift2' in gate:
        # (ge4 AND NOT ge8) OR ge12 -- CLZ in {4-7, 12}
        return [registry.get_id(f"{prefix}.and_4_7"), registry.get_id(f"{prefix}.ge12")]
    if '.shift1' in gate:
        return [registry.get_id(f"{prefix}.and_2_3"), registry.get_id(f"{prefix}.and_6_7"),
                registry.get_id(f"{prefix}.and_10_11")]
    if '.shift0' in gate:
        return [registry.get_id(f"{prefix}.and_{i}") for i in [1, 3, 5, 7, 9, 11]]

    for i in range(4):
        registry.register(f"{prefix}.shift{i}")

    # not_overflow
    if '.not_overflow' in gate:
        return [registry.get_id(f"{prefix}.overflow")]

    registry.register(f"{prefix}.not_overflow")

    # Output shift bits (masked by not_overflow)
    if '.out_shift' in gate:
        match = re.search(r'\.out_shift(\d+)', gate)
        if match:
            i = int(match.group(1))
            return [registry.get_id(f"{prefix}.shift{i}"), registry.get_id(f"{prefix}.not_overflow")]

    return []
1172
  def infer_float16_cmp_inputs(gate: str, registry: SignalRegistry) -> List[int]:
1173
  """Infer inputs for float16.cmp circuit."""
1174
  prefix = "float16.cmp"
 
1330
  return []
1331
 
1332
 
1333
def build_float16_normalize_tensors() -> Dict[str, torch.Tensor]:
    """Build tensors for float16.normalize circuit.

    Normalizes an extended mantissa by locating the leading 1 so the
    caller can shift it into place (used after float16 add/subtract).

    Inputs:
    - 13-bit extended mantissa ($m[12:0], where $m[12] is overflow bit)
    - 8-bit raw exponent ($e[7:0])
    - 1-bit sign ($sign)

    Outputs:
    - shift_amt[3:0]: how many positions to shift left (0-12)
    - is_zero: mantissa is all zeros
    - overflow: mantissa bit 12 is set (need right shift)

    The actual shifting and exponent adjustment are done externally
    since a full barrel shifter is complex.
    """
    p = "float16.normalize"
    t: Dict[str, torch.Tensor] = {}

    def emit(name, weights, bias):
        # One threshold neuron: Heaviside(w . x + b).
        t[f"{p}.{name}.weight"] = torch.tensor(weights)
        t[f"{p}.{name}.bias"] = torch.tensor([bias])

    # overflow: pass-through of $m[12] (needs right shift, not left)
    emit("overflow", [1.0], -0.5)

    # is_zero: NOR of all 13 mantissa bits
    emit("is_zero", [-1.0] * 13, 0.0)

    # CLZ ladder over m[11:0] (overflow bit excluded):
    # pz{k} fires when the top k bits are all zero ...
    for k in range(1, 13):
        emit(f"pz{k}", [-1.0] * k, 0.0)
    # ... and ge{k} fires when at least k pz gates fire (CLZ >= k).
    for k in range(1, 13):
        emit(f"ge{k}", [1.0] * 12, -float(k))

    # Inverters for the even upper bounds used by the AND gates below.
    for k in (2, 4, 6, 8, 10, 12):
        emit(f"not_ge{k}", [-1.0], 0.0)

    # Shift amount is min(CLZ, 12) in 4 bits.
    # bit3: CLZ >= 8 (pass-through of ge8)
    emit("shift3", [1.0], -0.5)

    # bit2: CLZ in {4-7, 12} = (ge4 AND NOT ge8) OR ge12
    emit("and_4_7", [1.0, 1.0], -2.0)
    emit("shift2", [1.0, 1.0], -1.0)

    # bit1: CLZ in {2,3,6,7,10,11}
    for name in ("and_2_3", "and_6_7", "and_10_11"):
        emit(name, [1.0, 1.0], -2.0)
    emit("shift1", [1.0, 1.0, 1.0], -1.0)

    # bit0: CLZ is odd {1,3,5,7,9,11}
    for i in (1, 3, 5, 7, 9, 11):
        emit(f"and_{i}", [1.0, 1.0], -2.0)
    emit("shift0", [1.0] * 6, -1.0)

    # When overflow is set, the shift amount must read as 0 (the caller
    # right-shifts by 1 externally), so each shift bit is ANDed with
    # NOT overflow.
    emit("not_overflow", [-1.0], 0.0)
    for i in range(4):
        emit(f"out_shift{i}", [1.0, 1.0], -2.0)

    return t
1421
  def build_float16_cmp_tensors() -> Dict[str, torch.Tensor]:
1422
  """Build tensors for float16.cmp circuit.
1423
 
 
1558
  return tensors
1559
 
1560
 
1561
def build_clz16bit_tensors() -> Dict[str, torch.Tensor]:
    """Build tensors for arithmetic.clz16bit circuit.

    Counts leading zeros of a 16-bit word; the result (0-16) needs five
    output bits.  Same construction as CLZ8BIT:
      1. pz{k}: NOR over the top k bits (fires when they are all zero)
      2. ge{k}: threshold over the pz layer (sum of pz >= k)
      3. AND/OR gates turning the thermometer code into binary
    """
    p = "arithmetic.clz16bit"
    t: Dict[str, torch.Tensor] = {}

    def emit(name, weights, bias):
        # One threshold neuron: Heaviside(w . x + b).
        t[f"{p}.{name}.weight"] = torch.tensor(weights)
        t[f"{p}.{name}.bias"] = torch.tensor([bias])

    # Prefix-zero detectors (NOR of top k bits).
    for k in range(1, 17):
        emit(f"pz{k}", [-1.0] * k, 0.0)

    # Thermometer thresholds: ge{k} fires when at least k pz gates fire.
    for k in range(1, 17):
        emit(f"ge{k}", [1.0] * 16, -float(k))

    # Inverters for the upper bounds of the range detectors.
    for k in (2, 4, 6, 8, 10, 12, 14, 16):
        emit(f"not_ge{k}", [-1.0], 0.0)

    # bit3 (8's place): CLZ in {8-15} = ge8 AND NOT ge16.
    # (bit4, the 16's place, is just ge16 -- see out4 below.)
    emit("and_8_15", [1.0, 1.0], -2.0)

    # bit2 (4's place): CLZ in {4-7, 12-15}
    # = (ge4 AND NOT ge8) OR (ge12 AND NOT ge16)
    emit("and_4_7", [1.0, 1.0], -2.0)
    emit("and_12_15", [1.0, 1.0], -2.0)
    emit("or_bit2", [1.0, 1.0], -1.0)

    # bit1 (2's place): CLZ in {2,3,6,7,10,11,14,15}
    for name in ("and_2_3", "and_6_7", "and_10_11", "and_14_15"):
        emit(name, [1.0, 1.0], -2.0)
    emit("or_bit1", [1.0] * 4, -1.0)

    # bit0 (1's place): CLZ is odd {1,3,5,7,9,11,13,15}
    for i in (1, 3, 5, 7, 9, 11, 13, 15):
        emit(f"and_{i}", [1.0, 1.0], -2.0)
    emit("or_bit0", [1.0] * 8, -1.0)

    # Output pass-throughs (bias -0.5 makes each an identity gate on
    # its single 0/1 input): out4=ge16, out3=and_8_15, out2..0=or_bit2..0.
    for name in ("out4", "out3", "out2", "out1", "out0"):
        emit(name, [1.0], -0.5)

    return t
1645
  def build_clz8bit_tensors() -> Dict[str, torch.Tensor]:
1646
  """Build tensors for arithmetic.clz8bit circuit.
1647
 
 
1717
  tensors.update(clz_tensors)
1718
  print(f" CLZ8BIT: {len(clz_tensors)} tensors")
1719
 
1720
+ clz16_tensors = build_clz16bit_tensors()
1721
+ tensors.update(clz16_tensors)
1722
+ print(f" CLZ16BIT: {len(clz16_tensors)} tensors")
1723
+
1724
  unpack_tensors = build_float16_unpack_tensors()
1725
  tensors.update(unpack_tensors)
1726
  print(f" float16.unpack: {len(unpack_tensors)} tensors")
eval.py CHANGED
@@ -291,6 +291,52 @@ class CircuitEvaluator:
291
 
292
  return TestResult('arithmetic.clz8bit', passed, 256, failures)
293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  # =========================================================================
295
  # FLOAT16 TESTS
296
  # =========================================================================
@@ -623,6 +669,11 @@ class Evaluator:
623
  self.results.append(result)
624
  if verbose:
625
  self._print_result(result)
 
 
 
 
 
626
 
627
  # Float16
628
  if verbose:
 
291
 
292
  return TestResult('arithmetic.clz8bit', passed, 256, failures)
293
 
294
def test_clz16bit(self) -> TestResult:
    """Test 16-bit count leading zeros.

    Checks zero, every power of two, and 200 seeded random 16-bit
    values.  Sampled rather than exhaustive: 65536 full-circuit
    evaluations would dominate the suite's runtime.
    """
    prefix = 'arithmetic.clz16bit'
    failures = []
    passed = 0

    # Edge cases first: 0 (CLZ == 16) and each single-bit word.
    test_values = [0] + [1 << i for i in range(16)]

    # Use a local RNG instance: calling random.seed(42) here would
    # clobber the process-wide RNG state for every later consumer.
    # (Random(42).randint yields the same sequence as the old code.)
    import random
    rng = random.Random(42)
    for _ in range(200):
        test_values.append(rng.randint(0, 0xFFFF))

    for val in test_values:
        # Reference model: CLZ == 16 - bit_length (bit_length(0) == 0,
        # so val == 0 correctly yields 16).
        expected = 16 - val.bit_length()

        # Drive inputs: $x[15] = MSB, $x[0] = LSB.
        ext = {f'{prefix}.$x[{i}]': float((val >> i) & 1) for i in range(16)}

        values = self.eval_circuit(prefix, ext)

        # Reassemble the 5-bit result from the output gates
        # (missing outputs read as 0).
        result = 0
        for bit, name in enumerate(['out0', 'out1', 'out2', 'out3', 'out4']):
            result += int(values.get(f'{prefix}.{name}', 0)) << bit

        if result == expected:
            passed += 1
        elif len(failures) < 10:
            # Cap recorded failures so a broken circuit doesn't flood the report.
            failures.append((val, expected, result))

    return TestResult('arithmetic.clz16bit', passed, len(test_values), failures)
340
  # =========================================================================
341
  # FLOAT16 TESTS
342
  # =========================================================================
 
669
  self.results.append(result)
670
  if verbose:
671
  self._print_result(result)
672
+ if 'arithmetic.clz16bit.pz1.weight' in self.eval.tensors:
673
+ result = self.eval.test_clz16bit()
674
+ self.results.append(result)
675
+ if verbose:
676
+ self._print_result(result)
677
 
678
  # Float16
679
  if verbose: