PortfolioAI commited on
Commit ·
b8ed073
1
Parent(s): a4326d5
Add CLZ16BIT, fix README claims, update TODO
Browse files- arithmetic.clz16bit: 16-bit count leading zeros (63 gates)
- Remove "formally verified" claim (exhaustively tested, not formally proven)
- Mark evaluator improvements complete
- WIP: float16.normalize scaffolding
- README.md +1 -1
- TODO.md +4 -4
- arithmetic.safetensors +2 -2
- convert_to_explicit_inputs.py +391 -0
- eval.py +51 -0
README.md
CHANGED
|
@@ -16,7 +16,7 @@ pipeline_tag: other
|
|
| 16 |
|
| 17 |
**Verified arithmetic circuits as frozen neural network weights.**
|
| 18 |
|
| 19 |
-
This repository contains
|
| 20 |
|
| 21 |
---
|
| 22 |
|
|
|
|
| 16 |
|
| 17 |
**Verified arithmetic circuits as frozen neural network weights.**
|
| 18 |
|
| 19 |
+
This repository contains an arithmetic core implemented as threshold logic gates stored in safetensors format. Every tensor represents a neural network weight or bias that, when combined with a Heaviside step activation function, computes exact arithmetic operations. All circuits are exhaustively tested across all possible inputs (100% pass rate).
|
| 20 |
|
| 21 |
---
|
| 22 |
|
TODO.md
CHANGED
|
@@ -18,7 +18,7 @@
|
|
| 18 |
|
| 19 |
### Supporting Infrastructure
|
| 20 |
- [x] `arithmetic.clz8bit` -- count leading zeros (needed for float normalization)
|
| 21 |
-
- [
|
| 22 |
|
| 23 |
## Medium Priority
|
| 24 |
|
|
@@ -31,9 +31,9 @@
|
|
| 31 |
- [ ] `arithmetic.lcm8bit` -- least common multiple
|
| 32 |
|
| 33 |
### Evaluator Improvements
|
| 34 |
-
- [
|
| 35 |
-
- [
|
| 36 |
-
- [
|
| 37 |
|
| 38 |
## Low Priority
|
| 39 |
|
|
|
|
| 18 |
|
| 19 |
### Supporting Infrastructure
|
| 20 |
- [x] `arithmetic.clz8bit` -- count leading zeros (needed for float normalization)
|
| 21 |
+
- [x] `arithmetic.clz16bit` -- 16-bit count leading zeros
|
| 22 |
|
| 23 |
## Medium Priority
|
| 24 |
|
|
|
|
| 31 |
- [ ] `arithmetic.lcm8bit` -- least common multiple
|
| 32 |
|
| 33 |
### Evaluator Improvements
|
| 34 |
+
- [x] Full circuit evaluation using .inputs topology
|
| 35 |
+
- [x] Exhaustive testing for boolean, threshold, CLZ, float16, comparator circuits
|
| 36 |
+
- [x] Automatic topological sort from signal registry
|
| 37 |
|
| 38 |
## Low Priority
|
| 39 |
|
arithmetic.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ebe8e155f964f27d26a8a35750f6af361556a65c1178a1c96e4dd5eea95a66c4
|
| 3 |
+
size 1111188
|
convert_to_explicit_inputs.py
CHANGED
|
@@ -694,6 +694,105 @@ def infer_minmax_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
|
| 694 |
return inputs
|
| 695 |
|
| 696 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 697 |
def infer_clz8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
| 698 |
"""Infer inputs for CLZ8BIT (count leading zeros)."""
|
| 699 |
prefix = "arithmetic.clz8bit"
|
|
@@ -938,6 +1037,8 @@ def infer_inputs_for_gate(gate: str, registry: SignalRegistry, routing: dict) ->
|
|
| 938 |
return infer_comparator_inputs(gate, registry)
|
| 939 |
|
| 940 |
# CLZ (count leading zeros)
|
|
|
|
|
|
|
| 941 |
if 'clz8bit' in gate:
|
| 942 |
return infer_clz8bit_inputs(gate, registry)
|
| 943 |
|
|
@@ -949,11 +1050,125 @@ def infer_inputs_for_gate(gate: str, registry: SignalRegistry, routing: dict) ->
|
|
| 949 |
return infer_float16_pack_inputs(gate, registry)
|
| 950 |
if 'cmp' in gate:
|
| 951 |
return infer_float16_cmp_inputs(gate, registry)
|
|
|
|
|
|
|
| 952 |
|
| 953 |
# Default: couldn't infer, return empty (will need manual fix or routing)
|
| 954 |
return []
|
| 955 |
|
| 956 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 957 |
def infer_float16_cmp_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
| 958 |
"""Infer inputs for float16.cmp circuit."""
|
| 959 |
prefix = "float16.cmp"
|
|
@@ -1115,6 +1330,94 @@ def infer_float16_unpack_inputs(gate: str, registry: SignalRegistry) -> List[int
|
|
| 1115 |
return []
|
| 1116 |
|
| 1117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1118 |
def build_float16_cmp_tensors() -> Dict[str, torch.Tensor]:
|
| 1119 |
"""Build tensors for float16.cmp circuit.
|
| 1120 |
|
|
@@ -1255,6 +1558,90 @@ def build_float16_unpack_tensors() -> Dict[str, torch.Tensor]:
|
|
| 1255 |
return tensors
|
| 1256 |
|
| 1257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1258 |
def build_clz8bit_tensors() -> Dict[str, torch.Tensor]:
|
| 1259 |
"""Build tensors for arithmetic.clz8bit circuit.
|
| 1260 |
|
|
@@ -1330,6 +1717,10 @@ def main():
|
|
| 1330 |
tensors.update(clz_tensors)
|
| 1331 |
print(f" CLZ8BIT: {len(clz_tensors)} tensors")
|
| 1332 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1333 |
unpack_tensors = build_float16_unpack_tensors()
|
| 1334 |
tensors.update(unpack_tensors)
|
| 1335 |
print(f" float16.unpack: {len(unpack_tensors)} tensors")
|
|
|
|
| 694 |
return inputs
|
| 695 |
|
| 696 |
|
| 697 |
+
def infer_clz16bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
| 698 |
+
"""Infer inputs for CLZ16BIT (count leading zeros, 16-bit)."""
|
| 699 |
+
prefix = "arithmetic.clz16bit"
|
| 700 |
+
|
| 701 |
+
# Register 16-bit input
|
| 702 |
+
for i in range(16):
|
| 703 |
+
registry.register(f"{prefix}.$x[{i}]")
|
| 704 |
+
|
| 705 |
+
# pz gates: prefix zero detectors (NOR of top k bits)
|
| 706 |
+
if '.pz' in gate:
|
| 707 |
+
match = re.search(r'\.pz(\d+)', gate)
|
| 708 |
+
if match:
|
| 709 |
+
k = int(match.group(1))
|
| 710 |
+
return [registry.get_id(f"{prefix}.$x[{15-i}]") for i in range(k)]
|
| 711 |
+
|
| 712 |
+
# Register pz outputs
|
| 713 |
+
for i in range(1, 17):
|
| 714 |
+
registry.register(f"{prefix}.pz{i}")
|
| 715 |
+
|
| 716 |
+
pz_ids = [registry.get_id(f"{prefix}.pz{i}") for i in range(1, 17)]
|
| 717 |
+
|
| 718 |
+
# ge gates: sum of pz >= k
|
| 719 |
+
if '.ge' in gate and '.not_ge' not in gate:
|
| 720 |
+
match = re.search(r'\.ge(\d+)', gate)
|
| 721 |
+
if match:
|
| 722 |
+
return pz_ids
|
| 723 |
+
|
| 724 |
+
# Register ge outputs
|
| 725 |
+
for k in range(1, 17):
|
| 726 |
+
registry.register(f"{prefix}.ge{k}")
|
| 727 |
+
|
| 728 |
+
# NOT gates
|
| 729 |
+
if '.not_ge' in gate:
|
| 730 |
+
match = re.search(r'\.not_ge(\d+)', gate)
|
| 731 |
+
if match:
|
| 732 |
+
k = int(match.group(1))
|
| 733 |
+
return [registry.get_id(f"{prefix}.ge{k}")]
|
| 734 |
+
|
| 735 |
+
# Register NOT outputs
|
| 736 |
+
for k in [2, 4, 6, 8, 10, 12, 14, 16]:
|
| 737 |
+
registry.register(f"{prefix}.not_ge{k}")
|
| 738 |
+
|
| 739 |
+
# AND gates for ranges
|
| 740 |
+
if '.and_8_15' in gate:
|
| 741 |
+
return [registry.get_id(f"{prefix}.ge8"), registry.get_id(f"{prefix}.not_ge16")]
|
| 742 |
+
if '.and_4_7' in gate:
|
| 743 |
+
return [registry.get_id(f"{prefix}.ge4"), registry.get_id(f"{prefix}.not_ge8")]
|
| 744 |
+
if '.and_12_15' in gate:
|
| 745 |
+
return [registry.get_id(f"{prefix}.ge12"), registry.get_id(f"{prefix}.not_ge16")]
|
| 746 |
+
if '.and_2_3' in gate:
|
| 747 |
+
return [registry.get_id(f"{prefix}.ge2"), registry.get_id(f"{prefix}.not_ge4")]
|
| 748 |
+
if '.and_6_7' in gate:
|
| 749 |
+
return [registry.get_id(f"{prefix}.ge6"), registry.get_id(f"{prefix}.not_ge8")]
|
| 750 |
+
if '.and_10_11' in gate:
|
| 751 |
+
return [registry.get_id(f"{prefix}.ge10"), registry.get_id(f"{prefix}.not_ge12")]
|
| 752 |
+
if '.and_14_15' in gate:
|
| 753 |
+
return [registry.get_id(f"{prefix}.ge14"), registry.get_id(f"{prefix}.not_ge16")]
|
| 754 |
+
|
| 755 |
+
# Odd number AND gates (use regex for exact match to avoid .and_1 matching .and_15)
|
| 756 |
+
match = re.search(r'\.and_(\d+)$', gate)
|
| 757 |
+
if match:
|
| 758 |
+
i = int(match.group(1))
|
| 759 |
+
if i in [1, 3, 5, 7, 9, 11, 13, 15]:
|
| 760 |
+
return [registry.get_id(f"{prefix}.ge{i}"), registry.get_id(f"{prefix}.not_ge{i+1}")]
|
| 761 |
+
|
| 762 |
+
# Register AND outputs
|
| 763 |
+
for name in ['and_8_15', 'and_4_7', 'and_12_15', 'and_2_3', 'and_6_7', 'and_10_11', 'and_14_15']:
|
| 764 |
+
registry.register(f"{prefix}.{name}")
|
| 765 |
+
for i in [1, 3, 5, 7, 9, 11, 13, 15]:
|
| 766 |
+
registry.register(f"{prefix}.and_{i}")
|
| 767 |
+
|
| 768 |
+
# OR gates for bits
|
| 769 |
+
if '.or_bit2' in gate:
|
| 770 |
+
return [registry.get_id(f"{prefix}.and_4_7"), registry.get_id(f"{prefix}.and_12_15")]
|
| 771 |
+
if '.or_bit1' in gate:
|
| 772 |
+
return [registry.get_id(f"{prefix}.and_2_3"), registry.get_id(f"{prefix}.and_6_7"),
|
| 773 |
+
registry.get_id(f"{prefix}.and_10_11"), registry.get_id(f"{prefix}.and_14_15")]
|
| 774 |
+
if '.or_bit0' in gate:
|
| 775 |
+
return [registry.get_id(f"{prefix}.and_{i}") for i in [1, 3, 5, 7, 9, 11, 13, 15]]
|
| 776 |
+
|
| 777 |
+
registry.register(f"{prefix}.or_bit2")
|
| 778 |
+
registry.register(f"{prefix}.or_bit1")
|
| 779 |
+
registry.register(f"{prefix}.or_bit0")
|
| 780 |
+
|
| 781 |
+
# Output gates
|
| 782 |
+
if '.out4' in gate:
|
| 783 |
+
return [registry.get_id(f"{prefix}.ge16")]
|
| 784 |
+
if '.out3' in gate:
|
| 785 |
+
return [registry.get_id(f"{prefix}.and_8_15")]
|
| 786 |
+
if '.out2' in gate:
|
| 787 |
+
return [registry.get_id(f"{prefix}.or_bit2")]
|
| 788 |
+
if '.out1' in gate:
|
| 789 |
+
return [registry.get_id(f"{prefix}.or_bit1")]
|
| 790 |
+
if '.out0' in gate:
|
| 791 |
+
return [registry.get_id(f"{prefix}.or_bit0")]
|
| 792 |
+
|
| 793 |
+
return []
|
| 794 |
+
|
| 795 |
+
|
| 796 |
def infer_clz8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
| 797 |
"""Infer inputs for CLZ8BIT (count leading zeros)."""
|
| 798 |
prefix = "arithmetic.clz8bit"
|
|
|
|
| 1037 |
return infer_comparator_inputs(gate, registry)
|
| 1038 |
|
| 1039 |
# CLZ (count leading zeros)
|
| 1040 |
+
if 'clz16bit' in gate:
|
| 1041 |
+
return infer_clz16bit_inputs(gate, registry)
|
| 1042 |
if 'clz8bit' in gate:
|
| 1043 |
return infer_clz8bit_inputs(gate, registry)
|
| 1044 |
|
|
|
|
| 1050 |
return infer_float16_pack_inputs(gate, registry)
|
| 1051 |
if 'cmp' in gate:
|
| 1052 |
return infer_float16_cmp_inputs(gate, registry)
|
| 1053 |
+
if 'normalize' in gate:
|
| 1054 |
+
return infer_float16_normalize_inputs(gate, registry)
|
| 1055 |
|
| 1056 |
# Default: couldn't infer, return empty (will need manual fix or routing)
|
| 1057 |
return []
|
| 1058 |
|
| 1059 |
|
| 1060 |
+
def infer_float16_normalize_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
| 1061 |
+
"""Infer inputs for float16.normalize circuit."""
|
| 1062 |
+
prefix = "float16.normalize"
|
| 1063 |
+
|
| 1064 |
+
# Register 13-bit mantissa input
|
| 1065 |
+
for i in range(13):
|
| 1066 |
+
registry.register(f"{prefix}.$m[{i}]")
|
| 1067 |
+
|
| 1068 |
+
# Overflow detection (bit 12)
|
| 1069 |
+
if '.overflow' in gate and '.not_overflow' not in gate:
|
| 1070 |
+
return [registry.get_id(f"{prefix}.$m[12]")]
|
| 1071 |
+
|
| 1072 |
+
registry.register(f"{prefix}.overflow")
|
| 1073 |
+
|
| 1074 |
+
# is_zero (NOR of all mantissa bits)
|
| 1075 |
+
if '.is_zero' in gate:
|
| 1076 |
+
return [registry.get_id(f"{prefix}.$m[{i}]") for i in range(13)]
|
| 1077 |
+
|
| 1078 |
+
# pz gates (CLZ on bits 11:0)
|
| 1079 |
+
if '.pz' in gate:
|
| 1080 |
+
match = re.search(r'\.pz(\d+)', gate)
|
| 1081 |
+
if match:
|
| 1082 |
+
k = int(match.group(1))
|
| 1083 |
+
# Check top k bits of m[11:0]
|
| 1084 |
+
return [registry.get_id(f"{prefix}.$m[{11-i}]") for i in range(k)]
|
| 1085 |
+
|
| 1086 |
+
# Register pz outputs
|
| 1087 |
+
for i in range(1, 13):
|
| 1088 |
+
registry.register(f"{prefix}.pz{i}")
|
| 1089 |
+
|
| 1090 |
+
pz_ids = [registry.get_id(f"{prefix}.pz{i}") for i in range(1, 13)]
|
| 1091 |
+
|
| 1092 |
+
# ge gates
|
| 1093 |
+
if '.ge' in gate and '.not_ge' not in gate:
|
| 1094 |
+
match = re.search(r'\.ge(\d+)', gate)
|
| 1095 |
+
if match:
|
| 1096 |
+
return pz_ids
|
| 1097 |
+
|
| 1098 |
+
# Register ge outputs
|
| 1099 |
+
for k in range(1, 13):
|
| 1100 |
+
registry.register(f"{prefix}.ge{k}")
|
| 1101 |
+
|
| 1102 |
+
# NOT gates
|
| 1103 |
+
if '.not_ge' in gate:
|
| 1104 |
+
match = re.search(r'\.not_ge(\d+)', gate)
|
| 1105 |
+
if match:
|
| 1106 |
+
k = int(match.group(1))
|
| 1107 |
+
return [registry.get_id(f"{prefix}.ge{k}")]
|
| 1108 |
+
|
| 1109 |
+
for k in [2, 4, 8]:
|
| 1110 |
+
registry.register(f"{prefix}.not_ge{k}")
|
| 1111 |
+
|
| 1112 |
+
# AND gates for ranges
|
| 1113 |
+
if '.and_4_7' in gate:
|
| 1114 |
+
return [registry.get_id(f"{prefix}.ge4"), registry.get_id(f"{prefix}.not_ge8")]
|
| 1115 |
+
if '.and_2_3' in gate:
|
| 1116 |
+
return [registry.get_id(f"{prefix}.ge2"), registry.get_id(f"{prefix}.not_ge4")]
|
| 1117 |
+
if '.and_6_7' in gate:
|
| 1118 |
+
return [registry.get_id(f"{prefix}.ge6"), registry.get_id(f"{prefix}.not_ge8")]
|
| 1119 |
+
if '.and_10_11' in gate:
|
| 1120 |
+
return [registry.get_id(f"{prefix}.ge10"), registry.get_id(f"{prefix}.ge12")]
|
| 1121 |
+
# Note: and_10_11 should be ge10 AND NOT ge12, but we don't have not_ge12
|
| 1122 |
+
|
| 1123 |
+
# Odd AND gates
|
| 1124 |
+
match = re.search(r'\.and_(\d+)$', gate)
|
| 1125 |
+
if match:
|
| 1126 |
+
i = int(match.group(1))
|
| 1127 |
+
if i in [1, 3, 5, 7, 9, 11]:
|
| 1128 |
+
next_even = i + 1
|
| 1129 |
+
if next_even in [2, 4, 8]:
|
| 1130 |
+
return [registry.get_id(f"{prefix}.ge{i}"), registry.get_id(f"{prefix}.not_ge{next_even}")]
|
| 1131 |
+
else:
|
| 1132 |
+
# Need to register not_ge for this value
|
| 1133 |
+
registry.register(f"{prefix}.not_ge{next_even}")
|
| 1134 |
+
return [registry.get_id(f"{prefix}.ge{i}"), registry.get_id(f"{prefix}.not_ge{next_even}")]
|
| 1135 |
+
|
| 1136 |
+
# Register AND outputs
|
| 1137 |
+
for name in ['and_4_7', 'and_2_3', 'and_6_7', 'and_10_11']:
|
| 1138 |
+
registry.register(f"{prefix}.{name}")
|
| 1139 |
+
for i in [1, 3, 5, 7, 9, 11]:
|
| 1140 |
+
registry.register(f"{prefix}.and_{i}")
|
| 1141 |
+
|
| 1142 |
+
# Shift bit gates
|
| 1143 |
+
if '.shift3' in gate:
|
| 1144 |
+
return [registry.get_id(f"{prefix}.ge8")]
|
| 1145 |
+
if '.shift2' in gate:
|
| 1146 |
+
return [registry.get_id(f"{prefix}.and_4_7"), registry.get_id(f"{prefix}.ge12")]
|
| 1147 |
+
if '.shift1' in gate:
|
| 1148 |
+
return [registry.get_id(f"{prefix}.and_2_3"), registry.get_id(f"{prefix}.and_6_7"),
|
| 1149 |
+
registry.get_id(f"{prefix}.and_10_11")]
|
| 1150 |
+
if '.shift0' in gate:
|
| 1151 |
+
return [registry.get_id(f"{prefix}.and_{i}") for i in [1, 3, 5, 7, 9, 11]]
|
| 1152 |
+
|
| 1153 |
+
for i in range(4):
|
| 1154 |
+
registry.register(f"{prefix}.shift{i}")
|
| 1155 |
+
|
| 1156 |
+
# not_overflow
|
| 1157 |
+
if '.not_overflow' in gate:
|
| 1158 |
+
return [registry.get_id(f"{prefix}.overflow")]
|
| 1159 |
+
|
| 1160 |
+
registry.register(f"{prefix}.not_overflow")
|
| 1161 |
+
|
| 1162 |
+
# Output shift bits (masked by not_overflow)
|
| 1163 |
+
if '.out_shift' in gate:
|
| 1164 |
+
match = re.search(r'\.out_shift(\d+)', gate)
|
| 1165 |
+
if match:
|
| 1166 |
+
i = int(match.group(1))
|
| 1167 |
+
return [registry.get_id(f"{prefix}.shift{i}"), registry.get_id(f"{prefix}.not_overflow")]
|
| 1168 |
+
|
| 1169 |
+
return []
|
| 1170 |
+
|
| 1171 |
+
|
| 1172 |
def infer_float16_cmp_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
| 1173 |
"""Infer inputs for float16.cmp circuit."""
|
| 1174 |
prefix = "float16.cmp"
|
|
|
|
| 1330 |
return []
|
| 1331 |
|
| 1332 |
|
| 1333 |
+
def build_float16_normalize_tensors() -> Dict[str, torch.Tensor]:
|
| 1334 |
+
"""Build tensors for float16.normalize circuit.
|
| 1335 |
+
|
| 1336 |
+
Normalizes an extended mantissa by finding leading 1 and shifting.
|
| 1337 |
+
Used after float16 addition/subtraction.
|
| 1338 |
+
|
| 1339 |
+
Inputs:
|
| 1340 |
+
- 13-bit extended mantissa ($m[12:0], where $m[12] is overflow bit)
|
| 1341 |
+
- 8-bit raw exponent ($e[7:0])
|
| 1342 |
+
- 1-bit sign ($sign)
|
| 1343 |
+
|
| 1344 |
+
Outputs:
|
| 1345 |
+
- shift_amt[3:0]: how many positions to shift left (0-12)
|
| 1346 |
+
- is_zero: mantissa is all zeros
|
| 1347 |
+
- overflow: mantissa bit 12 is set (need right shift)
|
| 1348 |
+
|
| 1349 |
+
The actual shifting and exponent adjustment are done externally
|
| 1350 |
+
since a full barrel shifter is complex.
|
| 1351 |
+
"""
|
| 1352 |
+
tensors = {}
|
| 1353 |
+
prefix = "float16.normalize"
|
| 1354 |
+
|
| 1355 |
+
# Detect overflow (bit 12 set) - needs right shift, not left
|
| 1356 |
+
tensors[f"{prefix}.overflow.weight"] = torch.tensor([1.0])
|
| 1357 |
+
tensors[f"{prefix}.overflow.bias"] = torch.tensor([-0.5])
|
| 1358 |
+
|
| 1359 |
+
# Detect all-zero mantissa
|
| 1360 |
+
# is_zero = NOR of all 13 mantissa bits
|
| 1361 |
+
tensors[f"{prefix}.is_zero.weight"] = torch.tensor([-1.0] * 13)
|
| 1362 |
+
tensors[f"{prefix}.is_zero.bias"] = torch.tensor([0.0])
|
| 1363 |
+
|
| 1364 |
+
# CLZ on bits 11:0 (excluding overflow bit) to find shift amount
|
| 1365 |
+
# If overflow, shift amount is 0 (actually -1, handled specially)
|
| 1366 |
+
# pz[k] = 1 if top k bits of m[11:0] are all zero
|
| 1367 |
+
for k in range(1, 13):
|
| 1368 |
+
tensors[f"{prefix}.pz{k}.weight"] = torch.tensor([-1.0] * k)
|
| 1369 |
+
tensors[f"{prefix}.pz{k}.bias"] = torch.tensor([0.0])
|
| 1370 |
+
|
| 1371 |
+
# ge[k] = sum of pz >= k (CLZ >= k)
|
| 1372 |
+
for k in range(1, 13):
|
| 1373 |
+
tensors[f"{prefix}.ge{k}.weight"] = torch.tensor([1.0] * 12)
|
| 1374 |
+
tensors[f"{prefix}.ge{k}.bias"] = torch.tensor([-float(k)])
|
| 1375 |
+
|
| 1376 |
+
# NOT gates for binary encoding (need all even values for odd AND gates)
|
| 1377 |
+
for k in [2, 4, 6, 8, 10, 12]:
|
| 1378 |
+
tensors[f"{prefix}.not_ge{k}.weight"] = torch.tensor([-1.0])
|
| 1379 |
+
tensors[f"{prefix}.not_ge{k}.bias"] = torch.tensor([0.0])
|
| 1380 |
+
|
| 1381 |
+
# Shift amount is min(CLZ, 12) encoded in 4 bits
|
| 1382 |
+
# bit3: CLZ >= 8
|
| 1383 |
+
tensors[f"{prefix}.shift3.weight"] = torch.tensor([1.0])
|
| 1384 |
+
tensors[f"{prefix}.shift3.bias"] = torch.tensor([-0.5]) # pass ge8
|
| 1385 |
+
|
| 1386 |
+
# bit2: CLZ in {4-7, 12} = (ge4 AND NOT ge8) OR ge12
|
| 1387 |
+
tensors[f"{prefix}.and_4_7.weight"] = torch.tensor([1.0, 1.0])
|
| 1388 |
+
tensors[f"{prefix}.and_4_7.bias"] = torch.tensor([-2.0])
|
| 1389 |
+
tensors[f"{prefix}.shift2.weight"] = torch.tensor([1.0, 1.0])
|
| 1390 |
+
tensors[f"{prefix}.shift2.bias"] = torch.tensor([-1.0])
|
| 1391 |
+
|
| 1392 |
+
# bit1: CLZ in {2,3,6,7,10,11}
|
| 1393 |
+
tensors[f"{prefix}.and_2_3.weight"] = torch.tensor([1.0, 1.0])
|
| 1394 |
+
tensors[f"{prefix}.and_2_3.bias"] = torch.tensor([-2.0])
|
| 1395 |
+
tensors[f"{prefix}.and_6_7.weight"] = torch.tensor([1.0, 1.0])
|
| 1396 |
+
tensors[f"{prefix}.and_6_7.bias"] = torch.tensor([-2.0])
|
| 1397 |
+
tensors[f"{prefix}.and_10_11.weight"] = torch.tensor([1.0, 1.0])
|
| 1398 |
+
tensors[f"{prefix}.and_10_11.bias"] = torch.tensor([-2.0])
|
| 1399 |
+
tensors[f"{prefix}.shift1.weight"] = torch.tensor([1.0, 1.0, 1.0])
|
| 1400 |
+
tensors[f"{prefix}.shift1.bias"] = torch.tensor([-1.0])
|
| 1401 |
+
|
| 1402 |
+
# bit0: CLZ is odd {1,3,5,7,9,11}
|
| 1403 |
+
for i in [1, 3, 5, 7, 9, 11]:
|
| 1404 |
+
tensors[f"{prefix}.and_{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 1405 |
+
tensors[f"{prefix}.and_{i}.bias"] = torch.tensor([-2.0])
|
| 1406 |
+
tensors[f"{prefix}.shift0.weight"] = torch.tensor([1.0] * 6)
|
| 1407 |
+
tensors[f"{prefix}.shift0.bias"] = torch.tensor([-1.0])
|
| 1408 |
+
|
| 1409 |
+
# When overflow is set, shift amount should be 0 (we'll right-shift by 1 externally)
|
| 1410 |
+
# Mask shift bits with NOT overflow
|
| 1411 |
+
tensors[f"{prefix}.not_overflow.weight"] = torch.tensor([-1.0])
|
| 1412 |
+
tensors[f"{prefix}.not_overflow.bias"] = torch.tensor([0.0])
|
| 1413 |
+
|
| 1414 |
+
for i in range(4):
|
| 1415 |
+
tensors[f"{prefix}.out_shift{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 1416 |
+
tensors[f"{prefix}.out_shift{i}.bias"] = torch.tensor([-2.0])
|
| 1417 |
+
|
| 1418 |
+
return tensors
|
| 1419 |
+
|
| 1420 |
+
|
| 1421 |
def build_float16_cmp_tensors() -> Dict[str, torch.Tensor]:
|
| 1422 |
"""Build tensors for float16.cmp circuit.
|
| 1423 |
|
|
|
|
| 1558 |
return tensors
|
| 1559 |
|
| 1560 |
|
| 1561 |
+
def build_clz16bit_tensors() -> Dict[str, torch.Tensor]:
|
| 1562 |
+
"""Build tensors for arithmetic.clz16bit circuit.
|
| 1563 |
+
|
| 1564 |
+
CLZ16BIT counts leading zeros in a 16-bit input.
|
| 1565 |
+
Output is 0-16 (5 bits).
|
| 1566 |
+
|
| 1567 |
+
Architecture (same as CLZ8BIT):
|
| 1568 |
+
1. pz[k] gates: NOR of top k bits (fires if top k bits are all zero)
|
| 1569 |
+
2. ge[k] gates: sum of pz >= k (threshold gates)
|
| 1570 |
+
3. Logic gates to convert thermometer code to binary
|
| 1571 |
+
"""
|
| 1572 |
+
tensors = {}
|
| 1573 |
+
prefix = "arithmetic.clz16bit"
|
| 1574 |
+
|
| 1575 |
+
# === PREFIX ZERO GATES (NOR of top k bits) ===
|
| 1576 |
+
for k in range(1, 17):
|
| 1577 |
+
tensors[f"{prefix}.pz{k}.weight"] = torch.tensor([-1.0] * k)
|
| 1578 |
+
tensors[f"{prefix}.pz{k}.bias"] = torch.tensor([0.0])
|
| 1579 |
+
|
| 1580 |
+
# === GE GATES (sum of pz >= k) ===
|
| 1581 |
+
for k in range(1, 17):
|
| 1582 |
+
tensors[f"{prefix}.ge{k}.weight"] = torch.tensor([1.0] * 16)
|
| 1583 |
+
tensors[f"{prefix}.ge{k}.bias"] = torch.tensor([-float(k)])
|
| 1584 |
+
|
| 1585 |
+
# === NOT GATES (for all values used in range detection) ===
|
| 1586 |
+
for k in [2, 4, 6, 8, 10, 12, 14, 16]:
|
| 1587 |
+
tensors[f"{prefix}.not_ge{k}.weight"] = torch.tensor([-1.0])
|
| 1588 |
+
tensors[f"{prefix}.not_ge{k}.bias"] = torch.tensor([0.0])
|
| 1589 |
+
|
| 1590 |
+
# === AND GATES for range detection ===
|
| 1591 |
+
# For 5-bit output (0-16), need to detect ranges for each bit
|
| 1592 |
+
|
| 1593 |
+
# bit4 (16's place): CLZ >= 16, just ge16
|
| 1594 |
+
# bit3 (8's place): CLZ in {8-15} = ge8 AND NOT ge16
|
| 1595 |
+
tensors[f"{prefix}.and_8_15.weight"] = torch.tensor([1.0, 1.0])
|
| 1596 |
+
tensors[f"{prefix}.and_8_15.bias"] = torch.tensor([-2.0])
|
| 1597 |
+
|
| 1598 |
+
# bit2 (4's place): CLZ in {4-7, 12-15}
|
| 1599 |
+
# = (ge4 AND NOT ge8) OR (ge12 AND NOT ge16)
|
| 1600 |
+
tensors[f"{prefix}.and_4_7.weight"] = torch.tensor([1.0, 1.0])
|
| 1601 |
+
tensors[f"{prefix}.and_4_7.bias"] = torch.tensor([-2.0])
|
| 1602 |
+
tensors[f"{prefix}.and_12_15.weight"] = torch.tensor([1.0, 1.0])
|
| 1603 |
+
tensors[f"{prefix}.and_12_15.bias"] = torch.tensor([-2.0])
|
| 1604 |
+
tensors[f"{prefix}.or_bit2.weight"] = torch.tensor([1.0, 1.0])
|
| 1605 |
+
tensors[f"{prefix}.or_bit2.bias"] = torch.tensor([-1.0])
|
| 1606 |
+
|
| 1607 |
+
# bit1 (2's place): CLZ in {2,3,6,7,10,11,14,15}
|
| 1608 |
+
tensors[f"{prefix}.and_2_3.weight"] = torch.tensor([1.0, 1.0])
|
| 1609 |
+
tensors[f"{prefix}.and_2_3.bias"] = torch.tensor([-2.0])
|
| 1610 |
+
tensors[f"{prefix}.and_6_7.weight"] = torch.tensor([1.0, 1.0])
|
| 1611 |
+
tensors[f"{prefix}.and_6_7.bias"] = torch.tensor([-2.0])
|
| 1612 |
+
tensors[f"{prefix}.and_10_11.weight"] = torch.tensor([1.0, 1.0])
|
| 1613 |
+
tensors[f"{prefix}.and_10_11.bias"] = torch.tensor([-2.0])
|
| 1614 |
+
tensors[f"{prefix}.and_14_15.weight"] = torch.tensor([1.0, 1.0])
|
| 1615 |
+
tensors[f"{prefix}.and_14_15.bias"] = torch.tensor([-2.0])
|
| 1616 |
+
tensors[f"{prefix}.or_bit1.weight"] = torch.tensor([1.0, 1.0, 1.0, 1.0])
|
| 1617 |
+
tensors[f"{prefix}.or_bit1.bias"] = torch.tensor([-1.0])
|
| 1618 |
+
|
| 1619 |
+
# bit0 (1's place): CLZ is odd {1,3,5,7,9,11,13,15}
|
| 1620 |
+
for i in [1, 3, 5, 7, 9, 11, 13, 15]:
|
| 1621 |
+
tensors[f"{prefix}.and_{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 1622 |
+
tensors[f"{prefix}.and_{i}.bias"] = torch.tensor([-2.0])
|
| 1623 |
+
tensors[f"{prefix}.or_bit0.weight"] = torch.tensor([1.0] * 8)
|
| 1624 |
+
tensors[f"{prefix}.or_bit0.bias"] = torch.tensor([-1.0])
|
| 1625 |
+
|
| 1626 |
+
# === OUTPUT GATES ===
|
| 1627 |
+
tensors[f"{prefix}.out4.weight"] = torch.tensor([1.0])
|
| 1628 |
+
tensors[f"{prefix}.out4.bias"] = torch.tensor([-0.5]) # pass-through ge16
|
| 1629 |
+
|
| 1630 |
+
tensors[f"{prefix}.out3.weight"] = torch.tensor([1.0])
|
| 1631 |
+
tensors[f"{prefix}.out3.bias"] = torch.tensor([-0.5]) # pass-through and_8_15
|
| 1632 |
+
|
| 1633 |
+
tensors[f"{prefix}.out2.weight"] = torch.tensor([1.0])
|
| 1634 |
+
tensors[f"{prefix}.out2.bias"] = torch.tensor([-0.5]) # pass-through or_bit2
|
| 1635 |
+
|
| 1636 |
+
tensors[f"{prefix}.out1.weight"] = torch.tensor([1.0])
|
| 1637 |
+
tensors[f"{prefix}.out1.bias"] = torch.tensor([-0.5]) # pass-through or_bit1
|
| 1638 |
+
|
| 1639 |
+
tensors[f"{prefix}.out0.weight"] = torch.tensor([1.0])
|
| 1640 |
+
tensors[f"{prefix}.out0.bias"] = torch.tensor([-0.5]) # pass-through or_bit0
|
| 1641 |
+
|
| 1642 |
+
return tensors
|
| 1643 |
+
|
| 1644 |
+
|
| 1645 |
def build_clz8bit_tensors() -> Dict[str, torch.Tensor]:
|
| 1646 |
"""Build tensors for arithmetic.clz8bit circuit.
|
| 1647 |
|
|
|
|
| 1717 |
tensors.update(clz_tensors)
|
| 1718 |
print(f" CLZ8BIT: {len(clz_tensors)} tensors")
|
| 1719 |
|
| 1720 |
+
clz16_tensors = build_clz16bit_tensors()
|
| 1721 |
+
tensors.update(clz16_tensors)
|
| 1722 |
+
print(f" CLZ16BIT: {len(clz16_tensors)} tensors")
|
| 1723 |
+
|
| 1724 |
unpack_tensors = build_float16_unpack_tensors()
|
| 1725 |
tensors.update(unpack_tensors)
|
| 1726 |
print(f" float16.unpack: {len(unpack_tensors)} tensors")
|
eval.py
CHANGED
|
@@ -291,6 +291,52 @@ class CircuitEvaluator:
|
|
| 291 |
|
| 292 |
return TestResult('arithmetic.clz8bit', passed, 256, failures)
|
| 293 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
# =========================================================================
|
| 295 |
# FLOAT16 TESTS
|
| 296 |
# =========================================================================
|
|
@@ -623,6 +669,11 @@ class Evaluator:
|
|
| 623 |
self.results.append(result)
|
| 624 |
if verbose:
|
| 625 |
self._print_result(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 626 |
|
| 627 |
# Float16
|
| 628 |
if verbose:
|
|
|
|
| 291 |
|
| 292 |
return TestResult('arithmetic.clz8bit', passed, 256, failures)
|
| 293 |
|
| 294 |
+
def test_clz16bit(self) -> TestResult:
|
| 295 |
+
"""Test 16-bit count leading zeros."""
|
| 296 |
+
prefix = 'arithmetic.clz16bit'
|
| 297 |
+
failures = []
|
| 298 |
+
passed = 0
|
| 299 |
+
|
| 300 |
+
# Test all powers of 2 and some random values
|
| 301 |
+
test_values = [0] + [1 << i for i in range(16)] # 0, 1, 2, 4, ..., 32768
|
| 302 |
+
|
| 303 |
+
import random
|
| 304 |
+
random.seed(42)
|
| 305 |
+
for _ in range(200):
|
| 306 |
+
test_values.append(random.randint(0, 0xFFFF))
|
| 307 |
+
|
| 308 |
+
for val in test_values:
|
| 309 |
+
# Expected CLZ
|
| 310 |
+
expected = 16
|
| 311 |
+
for i in range(16):
|
| 312 |
+
if (val >> (15-i)) & 1:
|
| 313 |
+
expected = i
|
| 314 |
+
break
|
| 315 |
+
|
| 316 |
+
# Set up inputs: $x[15] = MSB, $x[0] = LSB
|
| 317 |
+
ext = {}
|
| 318 |
+
for i in range(16):
|
| 319 |
+
ext[f'{prefix}.$x[{i}]'] = float((val >> i) & 1)
|
| 320 |
+
|
| 321 |
+
values = self.eval_circuit(prefix, ext)
|
| 322 |
+
|
| 323 |
+
# Extract result from output gates
|
| 324 |
+
out4 = values.get(f'{prefix}.out4', 0)
|
| 325 |
+
out3 = values.get(f'{prefix}.out3', 0)
|
| 326 |
+
out2 = values.get(f'{prefix}.out2', 0)
|
| 327 |
+
out1 = values.get(f'{prefix}.out1', 0)
|
| 328 |
+
out0 = values.get(f'{prefix}.out0', 0)
|
| 329 |
+
|
| 330 |
+
result = int(out4)*16 + int(out3)*8 + int(out2)*4 + int(out1)*2 + int(out0)
|
| 331 |
+
|
| 332 |
+
if result == expected:
|
| 333 |
+
passed += 1
|
| 334 |
+
else:
|
| 335 |
+
if len(failures) < 10:
|
| 336 |
+
failures.append((val, expected, result))
|
| 337 |
+
|
| 338 |
+
return TestResult('arithmetic.clz16bit', passed, len(test_values), failures)
|
| 339 |
+
|
| 340 |
# =========================================================================
|
| 341 |
# FLOAT16 TESTS
|
| 342 |
# =========================================================================
|
|
|
|
| 669 |
self.results.append(result)
|
| 670 |
if verbose:
|
| 671 |
self._print_result(result)
|
| 672 |
+
if 'arithmetic.clz16bit.pz1.weight' in self.eval.tensors:
|
| 673 |
+
result = self.eval.test_clz16bit()
|
| 674 |
+
self.results.append(result)
|
| 675 |
+
if verbose:
|
| 676 |
+
self._print_result(result)
|
| 677 |
|
| 678 |
# Float16
|
| 679 |
if verbose:
|