PortfolioAI commited on
Commit ·
6c34eb3
1
Parent(s): a6eff5f
Add float16.add circuit (93/125 tests passing)
Browse filesImplements IEEE 754 half-precision addition with:
- Special case detection (NaN, infinity, zero, subnormal)
- Exponent comparison and difference calculation
- Mantissa alignment via barrel shifter
- 12-bit mantissa adder/subtractor
- Result normalization with overflow/underflow handling
- Output assembly with special case multiplexing
~910 gates total. Remaining issues:
- Zero+zero produces incorrect result
- Subtraction (different signs) has bugs
- TODO.md +1 -1
- arithmetic.safetensors +2 -2
- convert_to_explicit_inputs.py +1843 -0
- eval.py +158 -0
TODO.md
CHANGED
|
@@ -7,7 +7,7 @@
|
|
| 7 |
- [x] `float16.pack` -- assemble from components (16 gates, 63/63 tests)
|
| 8 |
- [x] `float16.cmp` -- comparison a > b (14 gates, 113/113 tests)
|
| 9 |
- [x] `float16.normalize` -- CLZ-based shift calculator (51 gates, 14/14 tests)
|
| 10 |
-
- [
|
| 11 |
- [ ] `float16.sub` -- subtraction (add with negated operand)
|
| 12 |
- [ ] `float16.mul` -- multiplication
|
| 13 |
- [ ] `float16.div` -- division
|
|
|
|
| 7 |
- [x] `float16.pack` -- assemble from components (16 gates, 63/63 tests)
|
| 8 |
- [x] `float16.cmp` -- comparison a > b (14 gates, 113/113 tests)
|
| 9 |
- [x] `float16.normalize` -- CLZ-based shift calculator (51 gates, 14/14 tests)
|
| 10 |
+
- [~] `float16.add` -- IEEE 754 addition (~910 gates, 93/125 tests, needs zero+zero and subtraction fixes)
|
| 11 |
- [ ] `float16.sub` -- subtraction (add with negated operand)
|
| 12 |
- [ ] `float16.mul` -- multiplication
|
| 13 |
- [ ] `float16.div` -- division
|
arithmetic.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:098369950361600a735b8200b51642a6c11bed441619adc1c1dd609ce298af53
|
| 3 |
+
size 1471280
|
convert_to_explicit_inputs.py
CHANGED
|
@@ -1056,11 +1056,976 @@ def infer_inputs_for_gate(gate: str, registry: SignalRegistry, routing: dict) ->
|
|
| 1056 |
return infer_float16_neg_inputs(gate, registry)
|
| 1057 |
if gate.startswith('float16.abs'):
|
| 1058 |
return infer_float16_abs_inputs(gate, registry)
|
|
|
|
|
|
|
| 1059 |
|
| 1060 |
# Default: couldn't infer, return empty (will need manual fix or routing)
|
| 1061 |
return []
|
| 1062 |
|
| 1063 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1064 |
def infer_float16_neg_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
| 1065 |
"""Infer inputs for float16.neg circuit."""
|
| 1066 |
prefix = "float16.neg"
|
|
@@ -1726,6 +2691,874 @@ def build_clz16bit_tensors() -> Dict[str, torch.Tensor]:
|
|
| 1726 |
return tensors
|
| 1727 |
|
| 1728 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1729 |
def build_clz8bit_tensors() -> Dict[str, torch.Tensor]:
|
| 1730 |
"""Build tensors for arithmetic.clz8bit circuit.
|
| 1731 |
|
|
@@ -1795,6 +3628,12 @@ def main():
|
|
| 1795 |
|
| 1796 |
print(f"Loaded {len(tensors)} tensors")
|
| 1797 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1798 |
# Build new circuits
|
| 1799 |
print("Building new circuits...")
|
| 1800 |
clz_tensors = build_clz8bit_tensors()
|
|
@@ -1829,6 +3668,10 @@ def main():
|
|
| 1829 |
tensors.update(abs_tensors)
|
| 1830 |
print(f" float16.abs: {len(abs_tensors)} tensors")
|
| 1831 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1832 |
print(f"Total tensors: {len(tensors)}")
|
| 1833 |
|
| 1834 |
# Load routing for complex circuits
|
|
|
|
| 1056 |
return infer_float16_neg_inputs(gate, registry)
|
| 1057 |
if gate.startswith('float16.abs'):
|
| 1058 |
return infer_float16_abs_inputs(gate, registry)
|
| 1059 |
+
if gate.startswith('float16.add'):
|
| 1060 |
+
return infer_float16_add_inputs(gate, registry)
|
| 1061 |
|
| 1062 |
# Default: couldn't infer, return empty (will need manual fix or routing)
|
| 1063 |
return []
|
| 1064 |
|
| 1065 |
|
| 1066 |
+
def infer_float16_add_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
| 1067 |
+
"""Infer inputs for float16.add circuit."""
|
| 1068 |
+
prefix = "float16.add"
|
| 1069 |
+
|
| 1070 |
+
# Register 32 input bits (two 16-bit operands)
|
| 1071 |
+
for i in range(16):
|
| 1072 |
+
registry.register(f"{prefix}.$a[{i}]")
|
| 1073 |
+
registry.register(f"{prefix}.$b[{i}]")
|
| 1074 |
+
|
| 1075 |
+
# Extract exponent bits (10-14)
|
| 1076 |
+
exp_a_bits = [f"{prefix}.$a[{10+i}]" for i in range(5)]
|
| 1077 |
+
exp_b_bits = [f"{prefix}.$b[{10+i}]" for i in range(5)]
|
| 1078 |
+
mant_a_bits = [f"{prefix}.$a[{i}]" for i in range(10)]
|
| 1079 |
+
mant_b_bits = [f"{prefix}.$b[{i}]" for i in range(10)]
|
| 1080 |
+
|
| 1081 |
+
# Stage 0: Special case detection
|
| 1082 |
+
if '.exp_a_all_ones' in gate:
|
| 1083 |
+
return [registry.get_id(b) for b in exp_a_bits]
|
| 1084 |
+
if '.exp_b_all_ones' in gate:
|
| 1085 |
+
return [registry.get_id(b) for b in exp_b_bits]
|
| 1086 |
+
if '.exp_a_zero' in gate:
|
| 1087 |
+
return [registry.get_id(b) for b in exp_a_bits]
|
| 1088 |
+
if '.exp_b_zero' in gate:
|
| 1089 |
+
return [registry.get_id(b) for b in exp_b_bits]
|
| 1090 |
+
if '.mant_a_nonzero' in gate:
|
| 1091 |
+
return [registry.get_id(b) for b in mant_a_bits]
|
| 1092 |
+
if '.mant_b_nonzero' in gate:
|
| 1093 |
+
return [registry.get_id(b) for b in mant_b_bits]
|
| 1094 |
+
if '.mant_a_zero' in gate:
|
| 1095 |
+
return [registry.get_id(b) for b in mant_a_bits]
|
| 1096 |
+
if '.mant_b_zero' in gate:
|
| 1097 |
+
return [registry.get_id(b) for b in mant_b_bits]
|
| 1098 |
+
|
| 1099 |
+
registry.register(f"{prefix}.exp_a_all_ones")
|
| 1100 |
+
registry.register(f"{prefix}.exp_b_all_ones")
|
| 1101 |
+
registry.register(f"{prefix}.exp_a_zero")
|
| 1102 |
+
registry.register(f"{prefix}.exp_b_zero")
|
| 1103 |
+
registry.register(f"{prefix}.mant_a_nonzero")
|
| 1104 |
+
registry.register(f"{prefix}.mant_b_nonzero")
|
| 1105 |
+
registry.register(f"{prefix}.mant_a_zero")
|
| 1106 |
+
registry.register(f"{prefix}.mant_b_zero")
|
| 1107 |
+
|
| 1108 |
+
if '.a_is_nan' in gate:
|
| 1109 |
+
return [registry.get_id(f"{prefix}.exp_a_all_ones"),
|
| 1110 |
+
registry.get_id(f"{prefix}.mant_a_nonzero")]
|
| 1111 |
+
if '.b_is_nan' in gate:
|
| 1112 |
+
return [registry.get_id(f"{prefix}.exp_b_all_ones"),
|
| 1113 |
+
registry.get_id(f"{prefix}.mant_b_nonzero")]
|
| 1114 |
+
if '.a_is_inf' in gate:
|
| 1115 |
+
return [registry.get_id(f"{prefix}.exp_a_all_ones"),
|
| 1116 |
+
registry.get_id(f"{prefix}.mant_a_zero")]
|
| 1117 |
+
if '.b_is_inf' in gate:
|
| 1118 |
+
return [registry.get_id(f"{prefix}.exp_b_all_ones"),
|
| 1119 |
+
registry.get_id(f"{prefix}.mant_b_zero")]
|
| 1120 |
+
if '.a_is_zero' in gate:
|
| 1121 |
+
return [registry.get_id(f"{prefix}.exp_a_zero"),
|
| 1122 |
+
registry.get_id(f"{prefix}.mant_a_zero")]
|
| 1123 |
+
if '.b_is_zero' in gate:
|
| 1124 |
+
return [registry.get_id(f"{prefix}.exp_b_zero"),
|
| 1125 |
+
registry.get_id(f"{prefix}.mant_b_zero")]
|
| 1126 |
+
if '.a_is_subnormal' in gate:
|
| 1127 |
+
return [registry.get_id(f"{prefix}.exp_a_zero"),
|
| 1128 |
+
registry.get_id(f"{prefix}.mant_a_nonzero")]
|
| 1129 |
+
if '.b_is_subnormal' in gate:
|
| 1130 |
+
return [registry.get_id(f"{prefix}.exp_b_zero"),
|
| 1131 |
+
registry.get_id(f"{prefix}.mant_b_nonzero")]
|
| 1132 |
+
|
| 1133 |
+
registry.register(f"{prefix}.a_is_nan")
|
| 1134 |
+
registry.register(f"{prefix}.b_is_nan")
|
| 1135 |
+
registry.register(f"{prefix}.a_is_inf")
|
| 1136 |
+
registry.register(f"{prefix}.b_is_inf")
|
| 1137 |
+
|
| 1138 |
+
if '.either_is_nan' in gate:
|
| 1139 |
+
return [registry.get_id(f"{prefix}.a_is_nan"),
|
| 1140 |
+
registry.get_id(f"{prefix}.b_is_nan")]
|
| 1141 |
+
if '.both_are_inf' in gate:
|
| 1142 |
+
return [registry.get_id(f"{prefix}.a_is_inf"),
|
| 1143 |
+
registry.get_id(f"{prefix}.b_is_inf")]
|
| 1144 |
+
|
| 1145 |
+
# Sign extraction
|
| 1146 |
+
if gate == f"{prefix}.sign_a":
|
| 1147 |
+
return [registry.get_id(f"{prefix}.$a[15]")]
|
| 1148 |
+
if gate == f"{prefix}.sign_b":
|
| 1149 |
+
return [registry.get_id(f"{prefix}.$b[15]")]
|
| 1150 |
+
|
| 1151 |
+
registry.register(f"{prefix}.sign_a")
|
| 1152 |
+
registry.register(f"{prefix}.sign_b")
|
| 1153 |
+
|
| 1154 |
+
if '.signs_differ.layer1' in gate:
|
| 1155 |
+
return [registry.get_id(f"{prefix}.sign_a"),
|
| 1156 |
+
registry.get_id(f"{prefix}.sign_b")]
|
| 1157 |
+
if '.signs_differ.layer2' in gate:
|
| 1158 |
+
return [registry.register(f"{prefix}.signs_differ.layer1.or"),
|
| 1159 |
+
registry.register(f"{prefix}.signs_differ.layer1.nand")]
|
| 1160 |
+
|
| 1161 |
+
registry.register(f"{prefix}.signs_differ.layer2")
|
| 1162 |
+
registry.register(f"{prefix}.either_is_nan")
|
| 1163 |
+
registry.register(f"{prefix}.both_are_inf")
|
| 1164 |
+
|
| 1165 |
+
if '.inf_cancellation' in gate:
|
| 1166 |
+
return [registry.get_id(f"{prefix}.both_are_inf"),
|
| 1167 |
+
registry.get_id(f"{prefix}.signs_differ.layer2")]
|
| 1168 |
+
|
| 1169 |
+
registry.register(f"{prefix}.inf_cancellation")
|
| 1170 |
+
|
| 1171 |
+
if '.result_is_nan' in gate:
|
| 1172 |
+
return [registry.get_id(f"{prefix}.either_is_nan"),
|
| 1173 |
+
registry.get_id(f"{prefix}.inf_cancellation")]
|
| 1174 |
+
if '.either_is_inf' in gate:
|
| 1175 |
+
return [registry.get_id(f"{prefix}.a_is_inf"),
|
| 1176 |
+
registry.get_id(f"{prefix}.b_is_inf")]
|
| 1177 |
+
|
| 1178 |
+
registry.register(f"{prefix}.result_is_nan")
|
| 1179 |
+
registry.register(f"{prefix}.either_is_inf")
|
| 1180 |
+
|
| 1181 |
+
if '.not_result_is_nan' in gate:
|
| 1182 |
+
return [registry.get_id(f"{prefix}.result_is_nan")]
|
| 1183 |
+
|
| 1184 |
+
registry.register(f"{prefix}.not_result_is_nan")
|
| 1185 |
+
|
| 1186 |
+
if '.result_is_inf' in gate:
|
| 1187 |
+
return [registry.get_id(f"{prefix}.either_is_inf"),
|
| 1188 |
+
registry.get_id(f"{prefix}.not_result_is_nan")]
|
| 1189 |
+
|
| 1190 |
+
# Implicit bit
|
| 1191 |
+
if '.implicit_a' in gate:
|
| 1192 |
+
return [registry.get_id(f"{prefix}.exp_a_zero")]
|
| 1193 |
+
if '.implicit_b' in gate:
|
| 1194 |
+
return [registry.get_id(f"{prefix}.exp_b_zero")]
|
| 1195 |
+
|
| 1196 |
+
registry.register(f"{prefix}.implicit_a")
|
| 1197 |
+
registry.register(f"{prefix}.implicit_b")
|
| 1198 |
+
|
| 1199 |
+
# Exponent comparison
|
| 1200 |
+
if '.a_exp_ge_b' in gate or '.a_exp_gt_b' in gate:
|
| 1201 |
+
return [registry.get_id(b) for b in exp_a_bits] + \
|
| 1202 |
+
[registry.get_id(b) for b in exp_b_bits]
|
| 1203 |
+
if '.b_exp_gt_a' in gate and 'sel' not in gate:
|
| 1204 |
+
return [registry.get_id(b) for b in exp_b_bits] + \
|
| 1205 |
+
[registry.get_id(b) for b in exp_a_bits]
|
| 1206 |
+
|
| 1207 |
+
registry.register(f"{prefix}.a_exp_ge_b")
|
| 1208 |
+
registry.register(f"{prefix}.a_exp_gt_b")
|
| 1209 |
+
registry.register(f"{prefix}.b_exp_gt_a")
|
| 1210 |
+
|
| 1211 |
+
if '.b_exp_gt_a_sel' in gate:
|
| 1212 |
+
return [registry.get_id(f"{prefix}.a_exp_ge_b")]
|
| 1213 |
+
|
| 1214 |
+
registry.register(f"{prefix}.b_exp_gt_a_sel")
|
| 1215 |
+
|
| 1216 |
+
# NOT gates for exponent bits
|
| 1217 |
+
match = re.search(r'\.not_exp_b(\d+)', gate)
|
| 1218 |
+
if match:
|
| 1219 |
+
i = int(match.group(1))
|
| 1220 |
+
return [registry.get_id(f"{prefix}.$b[{10+i}]")]
|
| 1221 |
+
|
| 1222 |
+
match = re.search(r'\.not_exp_a(\d+)', gate)
|
| 1223 |
+
if match:
|
| 1224 |
+
i = int(match.group(1))
|
| 1225 |
+
return [registry.get_id(f"{prefix}.$a[{10+i}]")]
|
| 1226 |
+
|
| 1227 |
+
for i in range(5):
|
| 1228 |
+
registry.register(f"{prefix}.not_exp_b{i}")
|
| 1229 |
+
registry.register(f"{prefix}.not_exp_a{i}")
|
| 1230 |
+
|
| 1231 |
+
# Exp diff subtractors (diff_ab and diff_ba)
|
| 1232 |
+
if '.diff_ab.fa' in gate or '.diff_ba.fa' in gate:
|
| 1233 |
+
is_ab = '.diff_ab' in gate
|
| 1234 |
+
match = re.search(r'\.fa(\d+)\.', gate)
|
| 1235 |
+
if match:
|
| 1236 |
+
i = int(match.group(1))
|
| 1237 |
+
fa_prefix = f"{prefix}.diff_{'ab' if is_ab else 'ba'}.fa{i}"
|
| 1238 |
+
|
| 1239 |
+
if is_ab:
|
| 1240 |
+
a_bit = registry.get_id(f"{prefix}.$a[{10+i}]")
|
| 1241 |
+
not_b = registry.get_id(f"{prefix}.not_exp_b{i}")
|
| 1242 |
+
else:
|
| 1243 |
+
a_bit = registry.get_id(f"{prefix}.$b[{10+i}]")
|
| 1244 |
+
not_b = registry.get_id(f"{prefix}.not_exp_a{i}")
|
| 1245 |
+
|
| 1246 |
+
if i == 0:
|
| 1247 |
+
cin = registry.get_id("#1")
|
| 1248 |
+
else:
|
| 1249 |
+
cin = registry.register(f"{prefix}.diff_{'ab' if is_ab else 'ba'}.fa{i-1}.cout")
|
| 1250 |
+
|
| 1251 |
+
if '.xor1.layer1' in gate:
|
| 1252 |
+
return [a_bit, not_b]
|
| 1253 |
+
if '.xor1.layer2' in gate:
|
| 1254 |
+
return [registry.register(f"{fa_prefix}.xor1.layer1.or"),
|
| 1255 |
+
registry.register(f"{fa_prefix}.xor1.layer1.nand")]
|
| 1256 |
+
|
| 1257 |
+
xor1 = registry.register(f"{fa_prefix}.xor1.layer2")
|
| 1258 |
+
|
| 1259 |
+
if '.xor2.layer1' in gate:
|
| 1260 |
+
return [xor1, cin]
|
| 1261 |
+
if '.xor2.layer2' in gate:
|
| 1262 |
+
return [registry.register(f"{fa_prefix}.xor2.layer1.or"),
|
| 1263 |
+
registry.register(f"{fa_prefix}.xor2.layer1.nand")]
|
| 1264 |
+
|
| 1265 |
+
if '.and1' in gate:
|
| 1266 |
+
return [a_bit, not_b]
|
| 1267 |
+
if '.and2' in gate:
|
| 1268 |
+
return [xor1, cin]
|
| 1269 |
+
if '.cout' in gate:
|
| 1270 |
+
return [registry.register(f"{fa_prefix}.and1"),
|
| 1271 |
+
registry.register(f"{fa_prefix}.and2")]
|
| 1272 |
+
|
| 1273 |
+
# Register diff outputs
|
| 1274 |
+
for i in range(5):
|
| 1275 |
+
registry.register(f"{prefix}.diff_ab.fa{i}.xor2.layer2")
|
| 1276 |
+
registry.register(f"{prefix}.diff_ba.fa{i}.xor2.layer2")
|
| 1277 |
+
|
| 1278 |
+
# Exp diff mux
|
| 1279 |
+
match = re.search(r'\.exp_diff_mux(\d+)\.', gate)
|
| 1280 |
+
if match:
|
| 1281 |
+
i = int(match.group(1))
|
| 1282 |
+
if '.and_ab' in gate:
|
| 1283 |
+
return [registry.get_id(f"{prefix}.diff_ab.fa{i}.xor2.layer2"),
|
| 1284 |
+
registry.get_id(f"{prefix}.a_exp_ge_b")]
|
| 1285 |
+
if '.and_ba' in gate:
|
| 1286 |
+
return [registry.get_id(f"{prefix}.diff_ba.fa{i}.xor2.layer2"),
|
| 1287 |
+
registry.get_id(f"{prefix}.b_exp_gt_a_sel")]
|
| 1288 |
+
|
| 1289 |
+
match = re.search(r'\.exp_diff(\d+)$', gate)
|
| 1290 |
+
if match:
|
| 1291 |
+
i = int(match.group(1))
|
| 1292 |
+
return [registry.register(f"{prefix}.exp_diff_mux{i}.and_ab"),
|
| 1293 |
+
registry.register(f"{prefix}.exp_diff_mux{i}.and_ba")]
|
| 1294 |
+
|
| 1295 |
+
for i in range(5):
|
| 1296 |
+
registry.register(f"{prefix}.exp_diff{i}")
|
| 1297 |
+
|
| 1298 |
+
# Exp larger mux
|
| 1299 |
+
match = re.search(r'\.exp_larger_mux(\d+)\.', gate)
|
| 1300 |
+
if match:
|
| 1301 |
+
i = int(match.group(1))
|
| 1302 |
+
if '.and_a' in gate:
|
| 1303 |
+
return [registry.get_id(f"{prefix}.$a[{10+i}]"),
|
| 1304 |
+
registry.get_id(f"{prefix}.a_exp_ge_b")]
|
| 1305 |
+
if '.and_b' in gate:
|
| 1306 |
+
return [registry.get_id(f"{prefix}.$b[{10+i}]"),
|
| 1307 |
+
registry.get_id(f"{prefix}.b_exp_gt_a_sel")]
|
| 1308 |
+
|
| 1309 |
+
match = re.search(r'\.exp_larger(\d+)$', gate)
|
| 1310 |
+
if match:
|
| 1311 |
+
i = int(match.group(1))
|
| 1312 |
+
return [registry.register(f"{prefix}.exp_larger_mux{i}.and_a"),
|
| 1313 |
+
registry.register(f"{prefix}.exp_larger_mux{i}.and_b")]
|
| 1314 |
+
|
| 1315 |
+
for i in range(5):
|
| 1316 |
+
registry.register(f"{prefix}.exp_larger{i}")
|
| 1317 |
+
|
| 1318 |
+
# Mantissa source selection (which mantissa to shift)
|
| 1319 |
+
# mant_shift_src = a_exp_ge_b ? mant_b : mant_a
|
| 1320 |
+
# mant_larger = a_exp_ge_b ? mant_a : mant_b
|
| 1321 |
+
match = re.search(r'\.mant_shift_src(\d+)\.', gate)
|
| 1322 |
+
if match:
|
| 1323 |
+
i = int(match.group(1))
|
| 1324 |
+
if i < 10:
|
| 1325 |
+
mant_a = registry.get_id(f"{prefix}.$a[{i}]")
|
| 1326 |
+
mant_b = registry.get_id(f"{prefix}.$b[{i}]")
|
| 1327 |
+
else:
|
| 1328 |
+
mant_a = registry.get_id(f"{prefix}.implicit_a")
|
| 1329 |
+
mant_b = registry.get_id(f"{prefix}.implicit_b")
|
| 1330 |
+
if '.and_b' in gate:
|
| 1331 |
+
return [mant_b, registry.get_id(f"{prefix}.a_exp_ge_b")]
|
| 1332 |
+
if '.and_a' in gate:
|
| 1333 |
+
return [mant_a, registry.get_id(f"{prefix}.b_exp_gt_a_sel")]
|
| 1334 |
+
|
| 1335 |
+
match = re.search(r'\.mant_shift_src(\d+)$', gate)
|
| 1336 |
+
if match:
|
| 1337 |
+
i = int(match.group(1))
|
| 1338 |
+
return [registry.register(f"{prefix}.mant_shift_src{i}.and_b"),
|
| 1339 |
+
registry.register(f"{prefix}.mant_shift_src{i}.and_a")]
|
| 1340 |
+
|
| 1341 |
+
match = re.search(r'\.mant_larger(\d+)\.', gate)
|
| 1342 |
+
if match:
|
| 1343 |
+
i = int(match.group(1))
|
| 1344 |
+
if i < 10:
|
| 1345 |
+
mant_a = registry.get_id(f"{prefix}.$a[{i}]")
|
| 1346 |
+
mant_b = registry.get_id(f"{prefix}.$b[{i}]")
|
| 1347 |
+
else:
|
| 1348 |
+
mant_a = registry.get_id(f"{prefix}.implicit_a")
|
| 1349 |
+
mant_b = registry.get_id(f"{prefix}.implicit_b")
|
| 1350 |
+
if '.and_a' in gate:
|
| 1351 |
+
return [mant_a, registry.get_id(f"{prefix}.a_exp_ge_b")]
|
| 1352 |
+
if '.and_b' in gate:
|
| 1353 |
+
return [mant_b, registry.get_id(f"{prefix}.b_exp_gt_a_sel")]
|
| 1354 |
+
|
| 1355 |
+
match = re.search(r'\.mant_larger(\d+)$', gate)
|
| 1356 |
+
if match:
|
| 1357 |
+
i = int(match.group(1))
|
| 1358 |
+
return [registry.register(f"{prefix}.mant_larger{i}.and_a"),
|
| 1359 |
+
registry.register(f"{prefix}.mant_larger{i}.and_b")]
|
| 1360 |
+
|
| 1361 |
+
for i in range(11):
|
| 1362 |
+
registry.register(f"{prefix}.mant_shift_src{i}")
|
| 1363 |
+
registry.register(f"{prefix}.mant_larger{i}")
|
| 1364 |
+
|
| 1365 |
+
# NOT gates for exp_diff bits (barrel shifter control)
|
| 1366 |
+
for i in range(5):
|
| 1367 |
+
if f'.not_exp_diff{i}' in gate and f'.not_exp_diff{i}.' not in gate:
|
| 1368 |
+
return [registry.get_id(f"{prefix}.exp_diff{i}")]
|
| 1369 |
+
registry.register(f"{prefix}.not_exp_diff{i}")
|
| 1370 |
+
|
| 1371 |
+
# Barrel shifter stage 0 (shift by 1)
|
| 1372 |
+
match = re.search(r'\.shift_s0_(\d+)\.', gate)
|
| 1373 |
+
if match:
|
| 1374 |
+
i = int(match.group(1))
|
| 1375 |
+
if '.pass' in gate:
|
| 1376 |
+
return [registry.get_id(f"{prefix}.mant_shift_src{i}"),
|
| 1377 |
+
registry.get_id(f"{prefix}.not_exp_diff0")]
|
| 1378 |
+
if '.shift' in gate and i < 10:
|
| 1379 |
+
return [registry.get_id(f"{prefix}.mant_shift_src{i+1}"),
|
| 1380 |
+
registry.get_id(f"{prefix}.exp_diff0")]
|
| 1381 |
+
|
| 1382 |
+
match = re.search(r'\.shift_s0_(\d+)$', gate)
|
| 1383 |
+
if match:
|
| 1384 |
+
i = int(match.group(1))
|
| 1385 |
+
if i < 10:
|
| 1386 |
+
return [registry.register(f"{prefix}.shift_s0_{i}.pass"),
|
| 1387 |
+
registry.register(f"{prefix}.shift_s0_{i}.shift")]
|
| 1388 |
+
else:
|
| 1389 |
+
return [registry.register(f"{prefix}.shift_s0_{i}.pass")]
|
| 1390 |
+
|
| 1391 |
+
for i in range(11):
|
| 1392 |
+
registry.register(f"{prefix}.shift_s0_{i}")
|
| 1393 |
+
|
| 1394 |
+
# Barrel shifter stage 1 (shift by 2)
|
| 1395 |
+
match = re.search(r'\.shift_s1_(\d+)\.', gate)
|
| 1396 |
+
if match:
|
| 1397 |
+
i = int(match.group(1))
|
| 1398 |
+
if '.pass' in gate:
|
| 1399 |
+
return [registry.get_id(f"{prefix}.shift_s0_{i}"),
|
| 1400 |
+
registry.get_id(f"{prefix}.not_exp_diff1")]
|
| 1401 |
+
if '.shift' in gate and i < 9:
|
| 1402 |
+
return [registry.get_id(f"{prefix}.shift_s0_{i+2}"),
|
| 1403 |
+
registry.get_id(f"{prefix}.exp_diff1")]
|
| 1404 |
+
|
| 1405 |
+
match = re.search(r'\.shift_s1_(\d+)$', gate)
|
| 1406 |
+
if match:
|
| 1407 |
+
i = int(match.group(1))
|
| 1408 |
+
if i < 9:
|
| 1409 |
+
return [registry.register(f"{prefix}.shift_s1_{i}.pass"),
|
| 1410 |
+
registry.register(f"{prefix}.shift_s1_{i}.shift")]
|
| 1411 |
+
else:
|
| 1412 |
+
return [registry.register(f"{prefix}.shift_s1_{i}.pass")]
|
| 1413 |
+
|
| 1414 |
+
for i in range(11):
|
| 1415 |
+
registry.register(f"{prefix}.shift_s1_{i}")
|
| 1416 |
+
|
| 1417 |
+
# Barrel shifter stage 2 (shift by 4)
|
| 1418 |
+
match = re.search(r'\.shift_s2_(\d+)\.', gate)
|
| 1419 |
+
if match:
|
| 1420 |
+
i = int(match.group(1))
|
| 1421 |
+
if '.pass' in gate:
|
| 1422 |
+
return [registry.get_id(f"{prefix}.shift_s1_{i}"),
|
| 1423 |
+
registry.get_id(f"{prefix}.not_exp_diff2")]
|
| 1424 |
+
if '.shift' in gate and i < 7:
|
| 1425 |
+
return [registry.get_id(f"{prefix}.shift_s1_{i+4}"),
|
| 1426 |
+
registry.get_id(f"{prefix}.exp_diff2")]
|
| 1427 |
+
|
| 1428 |
+
match = re.search(r'\.shift_s2_(\d+)$', gate)
|
| 1429 |
+
if match:
|
| 1430 |
+
i = int(match.group(1))
|
| 1431 |
+
if i < 7:
|
| 1432 |
+
return [registry.register(f"{prefix}.shift_s2_{i}.pass"),
|
| 1433 |
+
registry.register(f"{prefix}.shift_s2_{i}.shift")]
|
| 1434 |
+
else:
|
| 1435 |
+
return [registry.register(f"{prefix}.shift_s2_{i}.pass")]
|
| 1436 |
+
|
| 1437 |
+
for i in range(11):
|
| 1438 |
+
registry.register(f"{prefix}.shift_s2_{i}")
|
| 1439 |
+
|
| 1440 |
+
# Barrel shifter stage 3 (shift by 8)
|
| 1441 |
+
match = re.search(r'\.shift_s3_(\d+)\.', gate)
|
| 1442 |
+
if match:
|
| 1443 |
+
i = int(match.group(1))
|
| 1444 |
+
if '.pass' in gate:
|
| 1445 |
+
return [registry.get_id(f"{prefix}.shift_s2_{i}"),
|
| 1446 |
+
registry.get_id(f"{prefix}.not_exp_diff3")]
|
| 1447 |
+
if '.shift' in gate and i < 3:
|
| 1448 |
+
return [registry.get_id(f"{prefix}.shift_s2_{i+8}"),
|
| 1449 |
+
registry.get_id(f"{prefix}.exp_diff3")]
|
| 1450 |
+
|
| 1451 |
+
match = re.search(r'\.shift_s3_(\d+)$', gate)
|
| 1452 |
+
if match:
|
| 1453 |
+
i = int(match.group(1))
|
| 1454 |
+
if i < 3:
|
| 1455 |
+
return [registry.register(f"{prefix}.shift_s3_{i}.pass"),
|
| 1456 |
+
registry.register(f"{prefix}.shift_s3_{i}.shift")]
|
| 1457 |
+
else:
|
| 1458 |
+
return [registry.register(f"{prefix}.shift_s3_{i}.pass")]
|
| 1459 |
+
|
| 1460 |
+
for i in range(11):
|
| 1461 |
+
registry.register(f"{prefix}.shift_s3_{i}")
|
| 1462 |
+
|
| 1463 |
+
# mant_aligned (masked by not_exp_diff4)
|
| 1464 |
+
match = re.search(r'\.mant_aligned(\d+)$', gate)
|
| 1465 |
+
if match:
|
| 1466 |
+
i = int(match.group(1))
|
| 1467 |
+
return [registry.get_id(f"{prefix}.shift_s3_{i}"),
|
| 1468 |
+
registry.get_id(f"{prefix}.not_exp_diff4")]
|
| 1469 |
+
|
| 1470 |
+
for i in range(11):
|
| 1471 |
+
registry.register(f"{prefix}.mant_aligned{i}")
|
| 1472 |
+
|
| 1473 |
+
# signs_same = NOT signs_differ
|
| 1474 |
+
if '.signs_same' in gate:
|
| 1475 |
+
return [registry.get_id(f"{prefix}.signs_differ.layer2")]
|
| 1476 |
+
|
| 1477 |
+
registry.register(f"{prefix}.signs_same")
|
| 1478 |
+
|
| 1479 |
+
# Mantissa comparison (for equal exponent case)
|
| 1480 |
+
if '.mant_a_ge_b' in gate:
|
| 1481 |
+
mant_a_full = [registry.get_id(f"{prefix}.$a[{i}]") for i in range(10)] + \
|
| 1482 |
+
[registry.get_id(f"{prefix}.implicit_a")]
|
| 1483 |
+
mant_b_full = [registry.get_id(f"{prefix}.$b[{i}]") for i in range(10)] + \
|
| 1484 |
+
[registry.get_id(f"{prefix}.implicit_b")]
|
| 1485 |
+
return mant_a_full + mant_b_full
|
| 1486 |
+
|
| 1487 |
+
registry.register(f"{prefix}.mant_a_ge_b")
|
| 1488 |
+
|
| 1489 |
+
# NOT gates for mant_aligned (for subtraction)
|
| 1490 |
+
match = re.search(r'\.not_mant_aligned(\d+)$', gate)
|
| 1491 |
+
if match:
|
| 1492 |
+
i = int(match.group(1))
|
| 1493 |
+
return [registry.get_id(f"{prefix}.mant_aligned{i}")]
|
| 1494 |
+
|
| 1495 |
+
for i in range(11):
|
| 1496 |
+
registry.register(f"{prefix}.not_mant_aligned{i}")
|
| 1497 |
+
|
| 1498 |
+
# sub_cin = signs_differ
|
| 1499 |
+
if '.sub_cin' in gate:
|
| 1500 |
+
return [registry.get_id(f"{prefix}.signs_differ.layer2")]
|
| 1501 |
+
|
| 1502 |
+
registry.register(f"{prefix}.sub_cin")
|
| 1503 |
+
|
| 1504 |
+
# addsub_b selection
|
| 1505 |
+
match = re.search(r'\.addsub_b(\d+)\.', gate)
|
| 1506 |
+
if match:
|
| 1507 |
+
i = int(match.group(1))
|
| 1508 |
+
if '.add' in gate:
|
| 1509 |
+
return [registry.get_id(f"{prefix}.mant_aligned{i}"),
|
| 1510 |
+
registry.get_id(f"{prefix}.signs_same")]
|
| 1511 |
+
if '.sub' in gate:
|
| 1512 |
+
return [registry.get_id(f"{prefix}.not_mant_aligned{i}"),
|
| 1513 |
+
registry.get_id(f"{prefix}.signs_differ.layer2")]
|
| 1514 |
+
|
| 1515 |
+
match = re.search(r'\.addsub_b(\d+)$', gate)
|
| 1516 |
+
if match:
|
| 1517 |
+
i = int(match.group(1))
|
| 1518 |
+
return [registry.register(f"{prefix}.addsub_b{i}.add"),
|
| 1519 |
+
registry.register(f"{prefix}.addsub_b{i}.sub")]
|
| 1520 |
+
|
| 1521 |
+
for i in range(11):
|
| 1522 |
+
registry.register(f"{prefix}.addsub_b{i}")
|
| 1523 |
+
|
| 1524 |
+
# 12-bit mantissa adder
|
| 1525 |
+
if '.mant_add.fa' in gate:
|
| 1526 |
+
match = re.search(r'\.mant_add\.fa(\d+)\.', gate)
|
| 1527 |
+
if match:
|
| 1528 |
+
i = int(match.group(1))
|
| 1529 |
+
fa_prefix = f"{prefix}.mant_add.fa{i}"
|
| 1530 |
+
|
| 1531 |
+
if i < 11:
|
| 1532 |
+
a_bit = registry.get_id(f"{prefix}.mant_larger{i}")
|
| 1533 |
+
b_bit = registry.get_id(f"{prefix}.addsub_b{i}")
|
| 1534 |
+
else:
|
| 1535 |
+
a_bit = registry.get_id("#0")
|
| 1536 |
+
b_bit = registry.get_id("#0")
|
| 1537 |
+
|
| 1538 |
+
if i == 0:
|
| 1539 |
+
cin = registry.get_id(f"{prefix}.sub_cin")
|
| 1540 |
+
else:
|
| 1541 |
+
cin = registry.register(f"{prefix}.mant_add.fa{i-1}.cout")
|
| 1542 |
+
|
| 1543 |
+
if '.xor1.layer1' in gate:
|
| 1544 |
+
return [a_bit, b_bit]
|
| 1545 |
+
if '.xor1.layer2' in gate:
|
| 1546 |
+
return [registry.register(f"{fa_prefix}.xor1.layer1.or"),
|
| 1547 |
+
registry.register(f"{fa_prefix}.xor1.layer1.nand")]
|
| 1548 |
+
|
| 1549 |
+
xor1 = registry.register(f"{fa_prefix}.xor1.layer2")
|
| 1550 |
+
|
| 1551 |
+
if '.xor2.layer1' in gate:
|
| 1552 |
+
return [xor1, cin]
|
| 1553 |
+
if '.xor2.layer2' in gate:
|
| 1554 |
+
return [registry.register(f"{fa_prefix}.xor2.layer1.or"),
|
| 1555 |
+
registry.register(f"{fa_prefix}.xor2.layer1.nand")]
|
| 1556 |
+
|
| 1557 |
+
if '.and1' in gate:
|
| 1558 |
+
return [a_bit, b_bit]
|
| 1559 |
+
if '.and2' in gate:
|
| 1560 |
+
return [xor1, cin]
|
| 1561 |
+
if '.cout' in gate:
|
| 1562 |
+
return [registry.register(f"{fa_prefix}.and1"),
|
| 1563 |
+
registry.register(f"{fa_prefix}.and2")]
|
| 1564 |
+
|
| 1565 |
+
for i in range(12):
|
| 1566 |
+
registry.register(f"{prefix}.mant_add.fa{i}.xor2.layer2")
|
| 1567 |
+
registry.register(f"{prefix}.mant_add.fa{i}.cout")
|
| 1568 |
+
|
| 1569 |
+
# Result sign determination
|
| 1570 |
+
if '.not_a_exp_gt_b' in gate:
|
| 1571 |
+
return [registry.get_id(f"{prefix}.a_exp_gt_b")]
|
| 1572 |
+
|
| 1573 |
+
registry.register(f"{prefix}.not_a_exp_gt_b")
|
| 1574 |
+
|
| 1575 |
+
if '.exp_a_eq_b' in gate:
|
| 1576 |
+
return [registry.get_id(f"{prefix}.not_a_exp_gt_b"),
|
| 1577 |
+
registry.get_id(f"{prefix}.b_exp_gt_a_sel")]
|
| 1578 |
+
|
| 1579 |
+
registry.register(f"{prefix}.exp_a_eq_b")
|
| 1580 |
+
|
| 1581 |
+
if '.exp_eq_and_mant_a_ge' in gate:
|
| 1582 |
+
return [registry.get_id(f"{prefix}.exp_a_eq_b"),
|
| 1583 |
+
registry.get_id(f"{prefix}.mant_a_ge_b")]
|
| 1584 |
+
|
| 1585 |
+
registry.register(f"{prefix}.exp_eq_and_mant_a_ge")
|
| 1586 |
+
|
| 1587 |
+
if '.a_magnitude_ge_b' in gate:
|
| 1588 |
+
return [registry.get_id(f"{prefix}.a_exp_gt_b"),
|
| 1589 |
+
registry.get_id(f"{prefix}.exp_eq_and_mant_a_ge")]
|
| 1590 |
+
|
| 1591 |
+
registry.register(f"{prefix}.a_magnitude_ge_b")
|
| 1592 |
+
|
| 1593 |
+
if '.not_a_mag_ge_b' in gate:
|
| 1594 |
+
return [registry.get_id(f"{prefix}.a_magnitude_ge_b")]
|
| 1595 |
+
|
| 1596 |
+
registry.register(f"{prefix}.not_a_mag_ge_b")
|
| 1597 |
+
|
| 1598 |
+
if '.diff_sign_sel_a' in gate:
|
| 1599 |
+
return [registry.get_id(f"{prefix}.sign_a"),
|
| 1600 |
+
registry.get_id(f"{prefix}.a_magnitude_ge_b")]
|
| 1601 |
+
|
| 1602 |
+
if '.diff_sign_sel_b' in gate:
|
| 1603 |
+
return [registry.get_id(f"{prefix}.sign_b"),
|
| 1604 |
+
registry.get_id(f"{prefix}.not_a_mag_ge_b")]
|
| 1605 |
+
|
| 1606 |
+
registry.register(f"{prefix}.diff_sign_sel_a")
|
| 1607 |
+
registry.register(f"{prefix}.diff_sign_sel_b")
|
| 1608 |
+
|
| 1609 |
+
if '.diff_result_sign' in gate:
|
| 1610 |
+
return [registry.get_id(f"{prefix}.diff_sign_sel_a"),
|
| 1611 |
+
registry.get_id(f"{prefix}.diff_sign_sel_b")]
|
| 1612 |
+
|
| 1613 |
+
registry.register(f"{prefix}.diff_result_sign")
|
| 1614 |
+
|
| 1615 |
+
if '.result_sign_same' in gate:
|
| 1616 |
+
return [registry.get_id(f"{prefix}.sign_a"),
|
| 1617 |
+
registry.get_id(f"{prefix}.signs_same")]
|
| 1618 |
+
|
| 1619 |
+
if '.result_sign_diff' in gate:
|
| 1620 |
+
return [registry.get_id(f"{prefix}.diff_result_sign"),
|
| 1621 |
+
registry.get_id(f"{prefix}.signs_differ.layer2")]
|
| 1622 |
+
|
| 1623 |
+
registry.register(f"{prefix}.result_sign_same")
|
| 1624 |
+
registry.register(f"{prefix}.result_sign_diff")
|
| 1625 |
+
|
| 1626 |
+
if gate == f"{prefix}.result_sign":
|
| 1627 |
+
return [registry.get_id(f"{prefix}.result_sign_same"),
|
| 1628 |
+
registry.get_id(f"{prefix}.result_sign_diff")]
|
| 1629 |
+
|
| 1630 |
+
registry.register(f"{prefix}.result_sign")
|
| 1631 |
+
|
| 1632 |
+
# Normalization - sum overflow (bit 11 of sum, not carry out)
|
| 1633 |
+
if '.sum_overflow' in gate:
|
| 1634 |
+
return [registry.get_id(f"{prefix}.mant_add.fa11.xor2.layer2")]
|
| 1635 |
+
|
| 1636 |
+
registry.register(f"{prefix}.sum_overflow")
|
| 1637 |
+
|
| 1638 |
+
# CLZ on bits 10:0 of sum for normalization (11 bits, not 12)
|
| 1639 |
+
sum_bits = [f"{prefix}.mant_add.fa{i}.xor2.layer2" for i in range(11)]
|
| 1640 |
+
|
| 1641 |
+
match = re.search(r'\.sum_pz(\d+)$', gate)
|
| 1642 |
+
if match:
|
| 1643 |
+
k = int(match.group(1))
|
| 1644 |
+
# Check bits 10, 9, 8, ... (from MSB to LSB of 11-bit sum)
|
| 1645 |
+
return [registry.get_id(sum_bits[10-i]) for i in range(k)]
|
| 1646 |
+
|
| 1647 |
+
for k in range(1, 12):
|
| 1648 |
+
registry.register(f"{prefix}.sum_pz{k}")
|
| 1649 |
+
|
| 1650 |
+
pz_ids = [registry.get_id(f"{prefix}.sum_pz{k}") for k in range(1, 12)]
|
| 1651 |
+
|
| 1652 |
+
match = re.search(r'\.sum_ge(\d+)$', gate)
|
| 1653 |
+
if match:
|
| 1654 |
+
return pz_ids
|
| 1655 |
+
|
| 1656 |
+
for k in range(1, 12):
|
| 1657 |
+
registry.register(f"{prefix}.sum_ge{k}")
|
| 1658 |
+
|
| 1659 |
+
match = re.search(r'\.sum_not_ge(\d+)$', gate)
|
| 1660 |
+
if match:
|
| 1661 |
+
k = int(match.group(1))
|
| 1662 |
+
return [registry.get_id(f"{prefix}.sum_ge{k}")]
|
| 1663 |
+
|
| 1664 |
+
for k in [2, 4, 6, 8, 10]:
|
| 1665 |
+
registry.register(f"{prefix}.sum_not_ge{k}")
|
| 1666 |
+
|
| 1667 |
+
if '.norm_shift3' in gate:
|
| 1668 |
+
return [registry.get_id(f"{prefix}.sum_ge8")]
|
| 1669 |
+
|
| 1670 |
+
if '.norm_and_4_7' in gate:
|
| 1671 |
+
return [registry.get_id(f"{prefix}.sum_ge4"),
|
| 1672 |
+
registry.get_id(f"{prefix}.sum_not_ge8")]
|
| 1673 |
+
|
| 1674 |
+
registry.register(f"{prefix}.norm_and_4_7")
|
| 1675 |
+
|
| 1676 |
+
# For 11-bit CLZ (max 11), shift2 = norm_and_4_7 only
|
| 1677 |
+
if '.norm_shift2' in gate:
|
| 1678 |
+
return [registry.get_id(f"{prefix}.norm_and_4_7")]
|
| 1679 |
+
|
| 1680 |
+
if '.norm_and_2_3' in gate:
|
| 1681 |
+
return [registry.get_id(f"{prefix}.sum_ge2"),
|
| 1682 |
+
registry.get_id(f"{prefix}.sum_not_ge4")]
|
| 1683 |
+
if '.norm_and_6_7' in gate:
|
| 1684 |
+
return [registry.get_id(f"{prefix}.sum_ge6"),
|
| 1685 |
+
registry.get_id(f"{prefix}.sum_not_ge8")]
|
| 1686 |
+
# For 11-bit CLZ (max 11), ge10 is sufficient (CLZ 10 or 11)
|
| 1687 |
+
if '.norm_and_10_11' in gate:
|
| 1688 |
+
return [registry.get_id(f"{prefix}.sum_ge10")]
|
| 1689 |
+
|
| 1690 |
+
registry.register(f"{prefix}.norm_and_2_3")
|
| 1691 |
+
registry.register(f"{prefix}.norm_and_6_7")
|
| 1692 |
+
registry.register(f"{prefix}.norm_and_10_11")
|
| 1693 |
+
|
| 1694 |
+
if '.norm_shift1' in gate:
|
| 1695 |
+
return [registry.get_id(f"{prefix}.norm_and_2_3"),
|
| 1696 |
+
registry.get_id(f"{prefix}.norm_and_6_7"),
|
| 1697 |
+
registry.get_id(f"{prefix}.norm_and_10_11")]
|
| 1698 |
+
|
| 1699 |
+
match = re.search(r'\.norm_and_(\d+)$', gate)
|
| 1700 |
+
if match:
|
| 1701 |
+
i = int(match.group(1))
|
| 1702 |
+
if i in [1, 3, 5, 7, 9]:
|
| 1703 |
+
return [registry.get_id(f"{prefix}.sum_ge{i}"),
|
| 1704 |
+
registry.get_id(f"{prefix}.sum_not_ge{i+1}")]
|
| 1705 |
+
|
| 1706 |
+
for i in [1, 3, 5, 7, 9]:
|
| 1707 |
+
registry.register(f"{prefix}.norm_and_{i}")
|
| 1708 |
+
|
| 1709 |
+
if '.norm_shift0' in gate:
|
| 1710 |
+
return [registry.get_id(f"{prefix}.norm_and_{i}") for i in [1, 3, 5, 7, 9]]
|
| 1711 |
+
|
| 1712 |
+
for i in range(4):
|
| 1713 |
+
registry.register(f"{prefix}.norm_shift{i}")
|
| 1714 |
+
|
| 1715 |
+
# Stage 10: Normalization application
|
| 1716 |
+
if '.not_sum_overflow' in gate:
|
| 1717 |
+
return [registry.get_id(f"{prefix}.sum_overflow")]
|
| 1718 |
+
|
| 1719 |
+
registry.register(f"{prefix}.not_sum_overflow")
|
| 1720 |
+
|
| 1721 |
+
# Overflow mantissa (right-shift by 1)
|
| 1722 |
+
match = re.search(r'\.norm_mant_overflow(\d+)$', gate)
|
| 1723 |
+
if match:
|
| 1724 |
+
i = int(match.group(1))
|
| 1725 |
+
return [registry.get_id(f"{prefix}.mant_add.fa{i+1}.xor2.layer2")]
|
| 1726 |
+
|
| 1727 |
+
for i in range(10):
|
| 1728 |
+
registry.register(f"{prefix}.norm_mant_overflow{i}")
|
| 1729 |
+
|
| 1730 |
+
# Left barrel shifter NOT gates
|
| 1731 |
+
for i in range(4):
|
| 1732 |
+
if f'.not_norm_shift{i}' in gate and '.not_norm_shift_sub' not in gate:
|
| 1733 |
+
return [registry.get_id(f"{prefix}.norm_shift{i}")]
|
| 1734 |
+
registry.register(f"{prefix}.not_norm_shift{i}")
|
| 1735 |
+
|
| 1736 |
+
# Left barrel shifter stage 0
|
| 1737 |
+
match = re.search(r'\.lshift_s0_(\d+)\.', gate)
|
| 1738 |
+
if match:
|
| 1739 |
+
i = int(match.group(1))
|
| 1740 |
+
if '.pass' in gate:
|
| 1741 |
+
return [registry.get_id(f"{prefix}.mant_add.fa{i}.xor2.layer2"),
|
| 1742 |
+
registry.get_id(f"{prefix}.not_norm_shift0")]
|
| 1743 |
+
if '.shift' in gate and i > 0:
|
| 1744 |
+
return [registry.get_id(f"{prefix}.mant_add.fa{i-1}.xor2.layer2"),
|
| 1745 |
+
registry.get_id(f"{prefix}.norm_shift0")]
|
| 1746 |
+
|
| 1747 |
+
match = re.search(r'\.lshift_s0_(\d+)$', gate)
|
| 1748 |
+
if match:
|
| 1749 |
+
i = int(match.group(1))
|
| 1750 |
+
if i > 0:
|
| 1751 |
+
return [registry.register(f"{prefix}.lshift_s0_{i}.pass"),
|
| 1752 |
+
registry.register(f"{prefix}.lshift_s0_{i}.shift")]
|
| 1753 |
+
else:
|
| 1754 |
+
return [registry.register(f"{prefix}.lshift_s0_{i}.pass")]
|
| 1755 |
+
|
| 1756 |
+
for i in range(11):
|
| 1757 |
+
registry.register(f"{prefix}.lshift_s0_{i}")
|
| 1758 |
+
|
| 1759 |
+
# Left barrel shifter stage 1
|
| 1760 |
+
match = re.search(r'\.lshift_s1_(\d+)\.', gate)
|
| 1761 |
+
if match:
|
| 1762 |
+
i = int(match.group(1))
|
| 1763 |
+
if '.pass' in gate:
|
| 1764 |
+
return [registry.get_id(f"{prefix}.lshift_s0_{i}"),
|
| 1765 |
+
registry.get_id(f"{prefix}.not_norm_shift1")]
|
| 1766 |
+
if '.shift' in gate and i > 1:
|
| 1767 |
+
return [registry.get_id(f"{prefix}.lshift_s0_{i-2}"),
|
| 1768 |
+
registry.get_id(f"{prefix}.norm_shift1")]
|
| 1769 |
+
|
| 1770 |
+
match = re.search(r'\.lshift_s1_(\d+)$', gate)
|
| 1771 |
+
if match:
|
| 1772 |
+
i = int(match.group(1))
|
| 1773 |
+
if i > 1:
|
| 1774 |
+
return [registry.register(f"{prefix}.lshift_s1_{i}.pass"),
|
| 1775 |
+
registry.register(f"{prefix}.lshift_s1_{i}.shift")]
|
| 1776 |
+
else:
|
| 1777 |
+
return [registry.register(f"{prefix}.lshift_s1_{i}.pass")]
|
| 1778 |
+
|
| 1779 |
+
for i in range(11):
|
| 1780 |
+
registry.register(f"{prefix}.lshift_s1_{i}")
|
| 1781 |
+
|
| 1782 |
+
# Left barrel shifter stage 2
|
| 1783 |
+
match = re.search(r'\.lshift_s2_(\d+)\.', gate)
|
| 1784 |
+
if match:
|
| 1785 |
+
i = int(match.group(1))
|
| 1786 |
+
if '.pass' in gate:
|
| 1787 |
+
return [registry.get_id(f"{prefix}.lshift_s1_{i}"),
|
| 1788 |
+
registry.get_id(f"{prefix}.not_norm_shift2")]
|
| 1789 |
+
if '.shift' in gate and i > 3:
|
| 1790 |
+
return [registry.get_id(f"{prefix}.lshift_s1_{i-4}"),
|
| 1791 |
+
registry.get_id(f"{prefix}.norm_shift2")]
|
| 1792 |
+
|
| 1793 |
+
match = re.search(r'\.lshift_s2_(\d+)$', gate)
|
| 1794 |
+
if match:
|
| 1795 |
+
i = int(match.group(1))
|
| 1796 |
+
if i > 3:
|
| 1797 |
+
return [registry.register(f"{prefix}.lshift_s2_{i}.pass"),
|
| 1798 |
+
registry.register(f"{prefix}.lshift_s2_{i}.shift")]
|
| 1799 |
+
else:
|
| 1800 |
+
return [registry.register(f"{prefix}.lshift_s2_{i}.pass")]
|
| 1801 |
+
|
| 1802 |
+
for i in range(11):
|
| 1803 |
+
registry.register(f"{prefix}.lshift_s2_{i}")
|
| 1804 |
+
|
| 1805 |
+
# Left barrel shifter stage 3
|
| 1806 |
+
match = re.search(r'\.lshift_s3_(\d+)\.', gate)
|
| 1807 |
+
if match:
|
| 1808 |
+
i = int(match.group(1))
|
| 1809 |
+
if '.pass' in gate:
|
| 1810 |
+
return [registry.get_id(f"{prefix}.lshift_s2_{i}"),
|
| 1811 |
+
registry.get_id(f"{prefix}.not_norm_shift3")]
|
| 1812 |
+
if '.shift' in gate and i > 7:
|
| 1813 |
+
return [registry.get_id(f"{prefix}.lshift_s2_{i-8}"),
|
| 1814 |
+
registry.get_id(f"{prefix}.norm_shift3")]
|
| 1815 |
+
|
| 1816 |
+
match = re.search(r'\.lshift_s3_(\d+)$', gate)
|
| 1817 |
+
if match:
|
| 1818 |
+
i = int(match.group(1))
|
| 1819 |
+
if i > 7:
|
| 1820 |
+
return [registry.register(f"{prefix}.lshift_s3_{i}.pass"),
|
| 1821 |
+
registry.register(f"{prefix}.lshift_s3_{i}.shift")]
|
| 1822 |
+
else:
|
| 1823 |
+
return [registry.register(f"{prefix}.lshift_s3_{i}.pass")]
|
| 1824 |
+
|
| 1825 |
+
for i in range(11):
|
| 1826 |
+
registry.register(f"{prefix}.lshift_s3_{i}")
|
| 1827 |
+
|
| 1828 |
+
# Normalized mantissa selection
|
| 1829 |
+
match = re.search(r'\.norm_mant(\d+)\.', gate)
|
| 1830 |
+
if match:
|
| 1831 |
+
i = int(match.group(1))
|
| 1832 |
+
if '.overflow_path' in gate:
|
| 1833 |
+
return [registry.get_id(f"{prefix}.norm_mant_overflow{i}"),
|
| 1834 |
+
registry.get_id(f"{prefix}.sum_overflow")]
|
| 1835 |
+
if '.normal_path' in gate:
|
| 1836 |
+
return [registry.get_id(f"{prefix}.lshift_s3_{i}"),
|
| 1837 |
+
registry.get_id(f"{prefix}.not_sum_overflow")]
|
| 1838 |
+
|
| 1839 |
+
match = re.search(r'\.norm_mant(\d+)$', gate)
|
| 1840 |
+
if match:
|
| 1841 |
+
i = int(match.group(1))
|
| 1842 |
+
return [registry.register(f"{prefix}.norm_mant{i}.overflow_path"),
|
| 1843 |
+
registry.register(f"{prefix}.norm_mant{i}.normal_path")]
|
| 1844 |
+
|
| 1845 |
+
for i in range(10):
|
| 1846 |
+
registry.register(f"{prefix}.norm_mant{i}")
|
| 1847 |
+
|
| 1848 |
+
# Exponent increment (for overflow)
|
| 1849 |
+
if '.exp_inc.ha0.sum' in gate:
|
| 1850 |
+
return [registry.get_id(f"{prefix}.exp_larger0")]
|
| 1851 |
+
if '.exp_inc.ha0.cout' in gate:
|
| 1852 |
+
return [registry.get_id(f"{prefix}.exp_larger0")]
|
| 1853 |
+
|
| 1854 |
+
registry.register(f"{prefix}.exp_inc.ha0.sum")
|
| 1855 |
+
registry.register(f"{prefix}.exp_inc.ha0.cout")
|
| 1856 |
+
|
| 1857 |
+
for i in range(1, 5):
|
| 1858 |
+
if f'.exp_inc.ha{i}.xor.layer1' in gate:
|
| 1859 |
+
return [registry.get_id(f"{prefix}.exp_larger{i}"),
|
| 1860 |
+
registry.get_id(f"{prefix}.exp_inc.ha{i-1}.cout")]
|
| 1861 |
+
if f'.exp_inc.ha{i}.sum' in gate:
|
| 1862 |
+
return [registry.register(f"{prefix}.exp_inc.ha{i}.xor.layer1.or"),
|
| 1863 |
+
registry.register(f"{prefix}.exp_inc.ha{i}.xor.layer1.nand")]
|
| 1864 |
+
if f'.exp_inc.ha{i}.cout' in gate:
|
| 1865 |
+
return [registry.get_id(f"{prefix}.exp_larger{i}"),
|
| 1866 |
+
registry.get_id(f"{prefix}.exp_inc.ha{i-1}.cout")]
|
| 1867 |
+
registry.register(f"{prefix}.exp_inc.ha{i}.sum")
|
| 1868 |
+
registry.register(f"{prefix}.exp_inc.ha{i}.cout")
|
| 1869 |
+
|
| 1870 |
+
# Exponent decrement NOT gates
|
| 1871 |
+
for i in range(4):
|
| 1872 |
+
if f'.not_norm_shift_sub{i}' in gate:
|
| 1873 |
+
return [registry.get_id(f"{prefix}.norm_shift{i}")]
|
| 1874 |
+
registry.register(f"{prefix}.not_norm_shift_sub{i}")
|
| 1875 |
+
|
| 1876 |
+
# Exponent decrement (for no overflow)
|
| 1877 |
+
if '.exp_dec.fa' in gate:
|
| 1878 |
+
match = re.search(r'\.exp_dec\.fa(\d+)\.', gate)
|
| 1879 |
+
if match:
|
| 1880 |
+
i = int(match.group(1))
|
| 1881 |
+
fa_prefix = f"{prefix}.exp_dec.fa{i}"
|
| 1882 |
+
|
| 1883 |
+
exp_bit = registry.get_id(f"{prefix}.exp_larger{i}")
|
| 1884 |
+
if i < 4:
|
| 1885 |
+
not_shift = registry.get_id(f"{prefix}.not_norm_shift_sub{i}")
|
| 1886 |
+
else:
|
| 1887 |
+
not_shift = registry.get_id("#1")
|
| 1888 |
+
|
| 1889 |
+
if i == 0:
|
| 1890 |
+
cin = registry.get_id("#1")
|
| 1891 |
+
else:
|
| 1892 |
+
cin = registry.register(f"{prefix}.exp_dec.fa{i-1}.cout")
|
| 1893 |
+
|
| 1894 |
+
if '.xor1.layer1' in gate:
|
| 1895 |
+
return [exp_bit, not_shift]
|
| 1896 |
+
if '.xor1.layer2' in gate:
|
| 1897 |
+
return [registry.register(f"{fa_prefix}.xor1.layer1.or"),
|
| 1898 |
+
registry.register(f"{fa_prefix}.xor1.layer1.nand")]
|
| 1899 |
+
|
| 1900 |
+
xor1 = registry.register(f"{fa_prefix}.xor1.layer2")
|
| 1901 |
+
|
| 1902 |
+
if '.xor2.layer1' in gate:
|
| 1903 |
+
return [xor1, cin]
|
| 1904 |
+
if '.xor2.layer2' in gate:
|
| 1905 |
+
return [registry.register(f"{fa_prefix}.xor2.layer1.or"),
|
| 1906 |
+
registry.register(f"{fa_prefix}.xor2.layer1.nand")]
|
| 1907 |
+
|
| 1908 |
+
if '.and1' in gate:
|
| 1909 |
+
return [exp_bit, not_shift]
|
| 1910 |
+
if '.and2' in gate:
|
| 1911 |
+
return [xor1, cin]
|
| 1912 |
+
if '.cout' in gate:
|
| 1913 |
+
return [registry.register(f"{fa_prefix}.and1"),
|
| 1914 |
+
registry.register(f"{fa_prefix}.and2")]
|
| 1915 |
+
|
| 1916 |
+
for i in range(5):
|
| 1917 |
+
registry.register(f"{prefix}.exp_dec.fa{i}.xor2.layer2")
|
| 1918 |
+
registry.register(f"{prefix}.exp_dec.fa{i}.cout")
|
| 1919 |
+
|
| 1920 |
+
# Result exponent selection
|
| 1921 |
+
match = re.search(r'\.result_exp(\d+)\.', gate)
|
| 1922 |
+
if match:
|
| 1923 |
+
i = int(match.group(1))
|
| 1924 |
+
if '.overflow_path' in gate:
|
| 1925 |
+
if i == 0:
|
| 1926 |
+
return [registry.get_id(f"{prefix}.exp_inc.ha0.sum"),
|
| 1927 |
+
registry.get_id(f"{prefix}.sum_overflow")]
|
| 1928 |
+
else:
|
| 1929 |
+
return [registry.get_id(f"{prefix}.exp_inc.ha{i}.sum"),
|
| 1930 |
+
registry.get_id(f"{prefix}.sum_overflow")]
|
| 1931 |
+
if '.normal_path' in gate:
|
| 1932 |
+
return [registry.get_id(f"{prefix}.exp_dec.fa{i}.xor2.layer2"),
|
| 1933 |
+
registry.get_id(f"{prefix}.not_sum_overflow")]
|
| 1934 |
+
|
| 1935 |
+
match = re.search(r'\.result_exp(\d+)$', gate)
|
| 1936 |
+
if match:
|
| 1937 |
+
i = int(match.group(1))
|
| 1938 |
+
return [registry.register(f"{prefix}.result_exp{i}.overflow_path"),
|
| 1939 |
+
registry.register(f"{prefix}.result_exp{i}.normal_path")]
|
| 1940 |
+
|
| 1941 |
+
for i in range(5):
|
| 1942 |
+
registry.register(f"{prefix}.result_exp{i}")
|
| 1943 |
+
|
| 1944 |
+
# Output assembly
|
| 1945 |
+
if '.not_result_is_inf' in gate:
|
| 1946 |
+
return [registry.get_id(f"{prefix}.result_is_inf")]
|
| 1947 |
+
|
| 1948 |
+
registry.register(f"{prefix}.not_result_is_inf")
|
| 1949 |
+
registry.register(f"{prefix}.result_is_inf")
|
| 1950 |
+
|
| 1951 |
+
if '.is_normal_result' in gate:
|
| 1952 |
+
return [registry.get_id(f"{prefix}.not_result_is_nan"),
|
| 1953 |
+
registry.get_id(f"{prefix}.not_result_is_inf")]
|
| 1954 |
+
|
| 1955 |
+
registry.register(f"{prefix}.is_normal_result")
|
| 1956 |
+
|
| 1957 |
+
# Inf sign selection
|
| 1958 |
+
if '.inf_sign_sel_a' in gate:
|
| 1959 |
+
return [registry.get_id(f"{prefix}.sign_a"),
|
| 1960 |
+
registry.get_id(f"{prefix}.a_is_inf")]
|
| 1961 |
+
if '.inf_sign_sel_b' in gate:
|
| 1962 |
+
return [registry.get_id(f"{prefix}.sign_b"),
|
| 1963 |
+
registry.get_id(f"{prefix}.b_is_inf")]
|
| 1964 |
+
|
| 1965 |
+
registry.register(f"{prefix}.inf_sign_sel_a")
|
| 1966 |
+
registry.register(f"{prefix}.inf_sign_sel_b")
|
| 1967 |
+
|
| 1968 |
+
if '.inf_sign' in gate and '.inf_sign_sel' not in gate:
|
| 1969 |
+
return [registry.get_id(f"{prefix}.inf_sign_sel_a"),
|
| 1970 |
+
registry.get_id(f"{prefix}.inf_sign_sel_b")]
|
| 1971 |
+
|
| 1972 |
+
registry.register(f"{prefix}.inf_sign")
|
| 1973 |
+
|
| 1974 |
+
# NaN bits
|
| 1975 |
+
nan_bits = [0]*9 + [1] + [1]*5 + [0]
|
| 1976 |
+
match = re.search(r'\.out_nan(\d+)$', gate)
|
| 1977 |
+
if match:
|
| 1978 |
+
return [registry.get_id(f"{prefix}.result_is_nan")]
|
| 1979 |
+
|
| 1980 |
+
# Inf bits
|
| 1981 |
+
match = re.search(r'\.out_inf(\d+)$', gate)
|
| 1982 |
+
if match:
|
| 1983 |
+
return [registry.get_id(f"{prefix}.result_is_inf")]
|
| 1984 |
+
|
| 1985 |
+
# Normal output path
|
| 1986 |
+
match = re.search(r'\.out_normal(\d+)$', gate)
|
| 1987 |
+
if match:
|
| 1988 |
+
i = int(match.group(1))
|
| 1989 |
+
if i == 15:
|
| 1990 |
+
return [registry.get_id(f"{prefix}.result_sign")]
|
| 1991 |
+
elif i >= 10:
|
| 1992 |
+
return [registry.get_id(f"{prefix}.result_exp{i-10}")]
|
| 1993 |
+
else:
|
| 1994 |
+
return [registry.get_id(f"{prefix}.norm_mant{i}")]
|
| 1995 |
+
|
| 1996 |
+
for i in range(16):
|
| 1997 |
+
registry.register(f"{prefix}.out_normal{i}")
|
| 1998 |
+
|
| 1999 |
+
# Final output gates
|
| 2000 |
+
match = re.search(r'\.out(\d+)\.(nan_gate|inf_gate|normal_gate)$', gate)
|
| 2001 |
+
if match:
|
| 2002 |
+
i = int(match.group(1))
|
| 2003 |
+
gate_type = match.group(2)
|
| 2004 |
+
if gate_type == 'nan_gate':
|
| 2005 |
+
nan_val = registry.register(f"{prefix}.out_nan{i}") if nan_bits[i] else registry.get_id("#0")
|
| 2006 |
+
return [nan_val, registry.get_id(f"{prefix}.result_is_nan")]
|
| 2007 |
+
elif gate_type == 'inf_gate':
|
| 2008 |
+
if i >= 10 and i < 15:
|
| 2009 |
+
inf_val = registry.register(f"{prefix}.out_inf{i}")
|
| 2010 |
+
elif i == 15:
|
| 2011 |
+
inf_val = registry.get_id(f"{prefix}.inf_sign")
|
| 2012 |
+
else:
|
| 2013 |
+
inf_val = registry.get_id("#0")
|
| 2014 |
+
return [inf_val, registry.get_id(f"{prefix}.result_is_inf")]
|
| 2015 |
+
elif gate_type == 'normal_gate':
|
| 2016 |
+
return [registry.get_id(f"{prefix}.out_normal{i}"),
|
| 2017 |
+
registry.get_id(f"{prefix}.is_normal_result")]
|
| 2018 |
+
|
| 2019 |
+
match = re.search(r'\.out(\d+)$', gate)
|
| 2020 |
+
if match:
|
| 2021 |
+
i = int(match.group(1))
|
| 2022 |
+
return [registry.register(f"{prefix}.out{i}.nan_gate"),
|
| 2023 |
+
registry.register(f"{prefix}.out{i}.inf_gate"),
|
| 2024 |
+
registry.register(f"{prefix}.out{i}.normal_gate")]
|
| 2025 |
+
|
| 2026 |
+
return []
|
| 2027 |
+
|
| 2028 |
+
|
| 2029 |
def infer_float16_neg_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
| 2030 |
"""Infer inputs for float16.neg circuit."""
|
| 2031 |
prefix = "float16.neg"
|
|
|
|
| 2691 |
return tensors
|
| 2692 |
|
| 2693 |
|
| 2694 |
+
def build_float16_add_tensors() -> Dict[str, torch.Tensor]:
|
| 2695 |
+
"""Build tensors for float16.add circuit.
|
| 2696 |
+
|
| 2697 |
+
IEEE 754 half-precision addition with full special case handling:
|
| 2698 |
+
1. Detect special cases (NaN, infinity, zero, subnormal)
|
| 2699 |
+
2. Extract sign, exponent, mantissa from both operands
|
| 2700 |
+
3. Add implicit bit (1 for normal, 0 for subnormal)
|
| 2701 |
+
4. Compare exponents to find which is larger
|
| 2702 |
+
5. Align mantissas by shifting smaller exponent's mantissa right
|
| 2703 |
+
6. Add or subtract mantissas based on signs
|
| 2704 |
+
7. Normalize result and adjust exponent
|
| 2705 |
+
8. Handle overflow (to infinity) and underflow (to zero/subnormal)
|
| 2706 |
+
9. Pack result with correct special case outputs
|
| 2707 |
+
|
| 2708 |
+
Inputs: $a[0:15], $b[0:15] (two float16 values)
|
| 2709 |
+
Outputs: out[0:15] (float16 result)
|
| 2710 |
+
"""
|
| 2711 |
+
tensors = {}
|
| 2712 |
+
prefix = "float16.add"
|
| 2713 |
+
|
| 2714 |
+
# =========================================================================
|
| 2715 |
+
# STAGE 0: SPECIAL CASE DETECTION
|
| 2716 |
+
# =========================================================================
|
| 2717 |
+
# Detect NaN, infinity, zero, and subnormal inputs.
|
| 2718 |
+
# float16 encoding:
|
| 2719 |
+
# - Zero: exp=0, mant=0
|
| 2720 |
+
# - Subnormal: exp=0, mant≠0
|
| 2721 |
+
# - Normal: 0 < exp < 31
|
| 2722 |
+
# - Infinity: exp=31, mant=0
|
| 2723 |
+
# - NaN: exp=31, mant≠0
|
| 2724 |
+
|
| 2725 |
+
# exp_a_all_ones: all 5 exponent bits are 1 (exp >= 31)
|
| 2726 |
+
# Threshold gate: sum of exp bits >= 5
|
| 2727 |
+
tensors[f"{prefix}.exp_a_all_ones.weight"] = torch.tensor([1.0] * 5)
|
| 2728 |
+
tensors[f"{prefix}.exp_a_all_ones.bias"] = torch.tensor([-5.0])
|
| 2729 |
+
|
| 2730 |
+
tensors[f"{prefix}.exp_b_all_ones.weight"] = torch.tensor([1.0] * 5)
|
| 2731 |
+
tensors[f"{prefix}.exp_b_all_ones.bias"] = torch.tensor([-5.0])
|
| 2732 |
+
|
| 2733 |
+
# exp_a_zero: all 5 exponent bits are 0 (NOR gate)
|
| 2734 |
+
tensors[f"{prefix}.exp_a_zero.weight"] = torch.tensor([-1.0] * 5)
|
| 2735 |
+
tensors[f"{prefix}.exp_a_zero.bias"] = torch.tensor([0.0])
|
| 2736 |
+
|
| 2737 |
+
tensors[f"{prefix}.exp_b_zero.weight"] = torch.tensor([-1.0] * 5)
|
| 2738 |
+
tensors[f"{prefix}.exp_b_zero.bias"] = torch.tensor([0.0])
|
| 2739 |
+
|
| 2740 |
+
# mant_a_nonzero: OR of all 10 mantissa bits
|
| 2741 |
+
tensors[f"{prefix}.mant_a_nonzero.weight"] = torch.tensor([1.0] * 10)
|
| 2742 |
+
tensors[f"{prefix}.mant_a_nonzero.bias"] = torch.tensor([-1.0])
|
| 2743 |
+
|
| 2744 |
+
tensors[f"{prefix}.mant_b_nonzero.weight"] = torch.tensor([1.0] * 10)
|
| 2745 |
+
tensors[f"{prefix}.mant_b_nonzero.bias"] = torch.tensor([-1.0])
|
| 2746 |
+
|
| 2747 |
+
# mant_a_zero: NOR of all mantissa bits
|
| 2748 |
+
tensors[f"{prefix}.mant_a_zero.weight"] = torch.tensor([-1.0] * 10)
|
| 2749 |
+
tensors[f"{prefix}.mant_a_zero.bias"] = torch.tensor([0.0])
|
| 2750 |
+
|
| 2751 |
+
tensors[f"{prefix}.mant_b_zero.weight"] = torch.tensor([-1.0] * 10)
|
| 2752 |
+
tensors[f"{prefix}.mant_b_zero.bias"] = torch.tensor([0.0])
|
| 2753 |
+
|
| 2754 |
+
# a_is_nan: exp_a_all_ones AND mant_a_nonzero
|
| 2755 |
+
tensors[f"{prefix}.a_is_nan.weight"] = torch.tensor([1.0, 1.0])
|
| 2756 |
+
tensors[f"{prefix}.a_is_nan.bias"] = torch.tensor([-2.0])
|
| 2757 |
+
|
| 2758 |
+
tensors[f"{prefix}.b_is_nan.weight"] = torch.tensor([1.0, 1.0])
|
| 2759 |
+
tensors[f"{prefix}.b_is_nan.bias"] = torch.tensor([-2.0])
|
| 2760 |
+
|
| 2761 |
+
# a_is_inf: exp_a_all_ones AND mant_a_zero
|
| 2762 |
+
tensors[f"{prefix}.a_is_inf.weight"] = torch.tensor([1.0, 1.0])
|
| 2763 |
+
tensors[f"{prefix}.a_is_inf.bias"] = torch.tensor([-2.0])
|
| 2764 |
+
|
| 2765 |
+
tensors[f"{prefix}.b_is_inf.weight"] = torch.tensor([1.0, 1.0])
|
| 2766 |
+
tensors[f"{prefix}.b_is_inf.bias"] = torch.tensor([-2.0])
|
| 2767 |
+
|
| 2768 |
+
# a_is_zero: exp_a_zero AND mant_a_zero
|
| 2769 |
+
tensors[f"{prefix}.a_is_zero.weight"] = torch.tensor([1.0, 1.0])
|
| 2770 |
+
tensors[f"{prefix}.a_is_zero.bias"] = torch.tensor([-2.0])
|
| 2771 |
+
|
| 2772 |
+
tensors[f"{prefix}.b_is_zero.weight"] = torch.tensor([1.0, 1.0])
|
| 2773 |
+
tensors[f"{prefix}.b_is_zero.bias"] = torch.tensor([-2.0])
|
| 2774 |
+
|
| 2775 |
+
# a_is_subnormal: exp_a_zero AND mant_a_nonzero
|
| 2776 |
+
tensors[f"{prefix}.a_is_subnormal.weight"] = torch.tensor([1.0, 1.0])
|
| 2777 |
+
tensors[f"{prefix}.a_is_subnormal.bias"] = torch.tensor([-2.0])
|
| 2778 |
+
|
| 2779 |
+
tensors[f"{prefix}.b_is_subnormal.weight"] = torch.tensor([1.0, 1.0])
|
| 2780 |
+
tensors[f"{prefix}.b_is_subnormal.bias"] = torch.tensor([-2.0])
|
| 2781 |
+
|
| 2782 |
+
# either_is_nan: a_is_nan OR b_is_nan
|
| 2783 |
+
tensors[f"{prefix}.either_is_nan.weight"] = torch.tensor([1.0, 1.0])
|
| 2784 |
+
tensors[f"{prefix}.either_is_nan.bias"] = torch.tensor([-1.0])
|
| 2785 |
+
|
| 2786 |
+
# both_are_inf: a_is_inf AND b_is_inf
|
| 2787 |
+
tensors[f"{prefix}.both_are_inf.weight"] = torch.tensor([1.0, 1.0])
|
| 2788 |
+
tensors[f"{prefix}.both_are_inf.bias"] = torch.tensor([-2.0])
|
| 2789 |
+
|
| 2790 |
+
# signs_differ: sign_a XOR sign_b (for inf + (-inf) = NaN case)
|
| 2791 |
+
# XOR layer 1
|
| 2792 |
+
tensors[f"{prefix}.signs_differ.layer1.or.weight"] = torch.tensor([1.0, 1.0])
|
| 2793 |
+
tensors[f"{prefix}.signs_differ.layer1.or.bias"] = torch.tensor([-1.0])
|
| 2794 |
+
tensors[f"{prefix}.signs_differ.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
|
| 2795 |
+
tensors[f"{prefix}.signs_differ.layer1.nand.bias"] = torch.tensor([1.0])
|
| 2796 |
+
tensors[f"{prefix}.signs_differ.layer2.weight"] = torch.tensor([1.0, 1.0])
|
| 2797 |
+
tensors[f"{prefix}.signs_differ.layer2.bias"] = torch.tensor([-2.0])
|
| 2798 |
+
|
| 2799 |
+
# inf_cancellation: both_are_inf AND signs_differ (produces NaN)
|
| 2800 |
+
tensors[f"{prefix}.inf_cancellation.weight"] = torch.tensor([1.0, 1.0])
|
| 2801 |
+
tensors[f"{prefix}.inf_cancellation.bias"] = torch.tensor([-2.0])
|
| 2802 |
+
|
| 2803 |
+
# result_is_nan: either_is_nan OR inf_cancellation
|
| 2804 |
+
tensors[f"{prefix}.result_is_nan.weight"] = torch.tensor([1.0, 1.0])
|
| 2805 |
+
tensors[f"{prefix}.result_is_nan.bias"] = torch.tensor([-1.0])
|
| 2806 |
+
|
| 2807 |
+
# either_is_inf: a_is_inf OR b_is_inf
|
| 2808 |
+
tensors[f"{prefix}.either_is_inf.weight"] = torch.tensor([1.0, 1.0])
|
| 2809 |
+
tensors[f"{prefix}.either_is_inf.bias"] = torch.tensor([-1.0])
|
| 2810 |
+
|
| 2811 |
+
# NOT result_is_nan (for masking inf result)
|
| 2812 |
+
tensors[f"{prefix}.not_result_is_nan.weight"] = torch.tensor([-1.0])
|
| 2813 |
+
tensors[f"{prefix}.not_result_is_nan.bias"] = torch.tensor([0.0])
|
| 2814 |
+
|
| 2815 |
+
# result_is_inf: either_is_inf AND NOT result_is_nan
|
| 2816 |
+
tensors[f"{prefix}.result_is_inf.weight"] = torch.tensor([1.0, 1.0])
|
| 2817 |
+
tensors[f"{prefix}.result_is_inf.bias"] = torch.tensor([-2.0])
|
| 2818 |
+
|
| 2819 |
+
# =========================================================================
|
| 2820 |
+
# STAGE 1: EXTRACT COMPONENTS
|
| 2821 |
+
# =========================================================================
|
| 2822 |
+
# sign_a = a[15], sign_b = b[15]
|
| 2823 |
+
# exp_a[0:4] = a[10:14], exp_b[0:4] = b[10:14]
|
| 2824 |
+
# mant_a[0:9] = a[0:9], mant_b[0:9] = b[0:9]
|
| 2825 |
+
|
| 2826 |
+
# Pass-through gates for sign extraction
|
| 2827 |
+
tensors[f"{prefix}.sign_a.weight"] = torch.tensor([1.0])
|
| 2828 |
+
tensors[f"{prefix}.sign_a.bias"] = torch.tensor([-0.5])
|
| 2829 |
+
|
| 2830 |
+
tensors[f"{prefix}.sign_b.weight"] = torch.tensor([1.0])
|
| 2831 |
+
tensors[f"{prefix}.sign_b.bias"] = torch.tensor([-0.5])
|
| 2832 |
+
|
| 2833 |
+
# Implicit bit calculation:
|
| 2834 |
+
# For normal numbers, implicit bit = 1
|
| 2835 |
+
# For subnormal numbers, implicit bit = 0
|
| 2836 |
+
# implicit_a = NOT a_is_subnormal AND NOT a_is_zero = NOT exp_a_zero
|
| 2837 |
+
# Actually simpler: implicit_a = NOT exp_a_zero (since exp=0 means no implicit 1)
|
| 2838 |
+
tensors[f"{prefix}.implicit_a.weight"] = torch.tensor([-1.0])
|
| 2839 |
+
tensors[f"{prefix}.implicit_a.bias"] = torch.tensor([0.0])
|
| 2840 |
+
|
| 2841 |
+
tensors[f"{prefix}.implicit_b.weight"] = torch.tensor([-1.0])
|
| 2842 |
+
tensors[f"{prefix}.implicit_b.bias"] = torch.tensor([0.0])
|
| 2843 |
+
|
| 2844 |
+
# =========================================================================
|
| 2845 |
+
# STAGE 2: EXPONENT COMPARISON
|
| 2846 |
+
# =========================================================================
|
| 2847 |
+
# Compare exp_a vs exp_b using weighted comparison
|
| 2848 |
+
# Weights: bit[i] contributes 2^i to the total
|
| 2849 |
+
# exp_a >= exp_b when weighted(exp_a) - weighted(exp_b) >= 0
|
| 2850 |
+
|
| 2851 |
+
weights_exp_a = [float(2**i) for i in range(5)] # +1, +2, +4, +8, +16
|
| 2852 |
+
weights_exp_b = [-float(2**i) for i in range(5)] # -1, -2, -4, -8, -16
|
| 2853 |
+
|
| 2854 |
+
# a_exp_ge_b: exp_a >= exp_b
|
| 2855 |
+
tensors[f"{prefix}.a_exp_ge_b.weight"] = torch.tensor(weights_exp_a + weights_exp_b)
|
| 2856 |
+
tensors[f"{prefix}.a_exp_ge_b.bias"] = torch.tensor([0.0]) # >= (not strict >)
|
| 2857 |
+
|
| 2858 |
+
# a_exp_gt_b: exp_a > exp_b (for strict comparison)
|
| 2859 |
+
tensors[f"{prefix}.a_exp_gt_b.weight"] = torch.tensor(weights_exp_a + weights_exp_b)
|
| 2860 |
+
tensors[f"{prefix}.a_exp_gt_b.bias"] = torch.tensor([-0.5]) # strict >
|
| 2861 |
+
|
| 2862 |
+
# b_exp_gt_a: exp_b > exp_a
|
| 2863 |
+
tensors[f"{prefix}.b_exp_gt_a.weight"] = torch.tensor(weights_exp_b[::-1] + weights_exp_a[::-1])
|
| 2864 |
+
# Actually, simpler: just swap the inputs conceptually
|
| 2865 |
+
# b > a means weights for b positive, weights for a negative
|
| 2866 |
+
tensors[f"{prefix}.b_exp_gt_a.weight"] = torch.tensor(weights_exp_a + weights_exp_b)
|
| 2867 |
+
tensors[f"{prefix}.b_exp_gt_a.bias"] = torch.tensor([-0.5])
|
| 2868 |
+
|
| 2869 |
+
# NOT of a_exp_ge_b (for selecting which path)
|
| 2870 |
+
tensors[f"{prefix}.b_exp_gt_a_sel.weight"] = torch.tensor([-1.0])
|
| 2871 |
+
tensors[f"{prefix}.b_exp_gt_a_sel.bias"] = torch.tensor([0.0])
|
| 2872 |
+
|
| 2873 |
+
# =========================================================================
|
| 2874 |
+
# STAGE 3: COMPUTE EXPONENT DIFFERENCE
|
| 2875 |
+
# =========================================================================
|
| 2876 |
+
# We need |exp_a - exp_b| for the shift amount.
|
| 2877 |
+
# Use 5-bit subtractors: exp_a - exp_b and exp_b - exp_a
|
| 2878 |
+
# Then select based on which exponent is larger.
|
| 2879 |
+
|
| 2880 |
+
# 5-bit subtractor for exp_a - exp_b (using two's complement)
|
| 2881 |
+
# NOT gates for exp_b
|
| 2882 |
+
for i in range(5):
|
| 2883 |
+
tensors[f"{prefix}.not_exp_b{i}.weight"] = torch.tensor([-1.0])
|
| 2884 |
+
tensors[f"{prefix}.not_exp_b{i}.bias"] = torch.tensor([0.0])
|
| 2885 |
+
|
| 2886 |
+
# Full adders for exp_a + NOT(exp_b) + 1 = exp_a - exp_b
|
| 2887 |
+
# FA0: bit 0
|
| 2888 |
+
# XOR1: exp_a[0] XOR not_exp_b[0]
|
| 2889 |
+
tensors[f"{prefix}.diff_ab.fa0.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
|
| 2890 |
+
tensors[f"{prefix}.diff_ab.fa0.xor1.layer1.or.bias"] = torch.tensor([-1.0])
|
| 2891 |
+
tensors[f"{prefix}.diff_ab.fa0.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
|
| 2892 |
+
tensors[f"{prefix}.diff_ab.fa0.xor1.layer1.nand.bias"] = torch.tensor([1.0])
|
| 2893 |
+
tensors[f"{prefix}.diff_ab.fa0.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
|
| 2894 |
+
tensors[f"{prefix}.diff_ab.fa0.xor1.layer2.bias"] = torch.tensor([-2.0])
|
| 2895 |
+
|
| 2896 |
+
# XOR2: xor1 XOR cin (cin=1 for subtraction)
|
| 2897 |
+
tensors[f"{prefix}.diff_ab.fa0.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0])
|
| 2898 |
+
tensors[f"{prefix}.diff_ab.fa0.xor2.layer1.or.bias"] = torch.tensor([-1.0])
|
| 2899 |
+
tensors[f"{prefix}.diff_ab.fa0.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
|
| 2900 |
+
tensors[f"{prefix}.diff_ab.fa0.xor2.layer1.nand.bias"] = torch.tensor([1.0])
|
| 2901 |
+
tensors[f"{prefix}.diff_ab.fa0.xor2.layer2.weight"] = torch.tensor([1.0, 1.0])
|
| 2902 |
+
tensors[f"{prefix}.diff_ab.fa0.xor2.layer2.bias"] = torch.tensor([-2.0])
|
| 2903 |
+
|
| 2904 |
+
# Carry: (a AND b) OR (xor1 AND cin)
|
| 2905 |
+
tensors[f"{prefix}.diff_ab.fa0.and1.weight"] = torch.tensor([1.0, 1.0])
|
| 2906 |
+
tensors[f"{prefix}.diff_ab.fa0.and1.bias"] = torch.tensor([-2.0])
|
| 2907 |
+
tensors[f"{prefix}.diff_ab.fa0.and2.weight"] = torch.tensor([1.0, 1.0])
|
| 2908 |
+
tensors[f"{prefix}.diff_ab.fa0.and2.bias"] = torch.tensor([-2.0])
|
| 2909 |
+
tensors[f"{prefix}.diff_ab.fa0.cout.weight"] = torch.tensor([1.0, 1.0])
|
| 2910 |
+
tensors[f"{prefix}.diff_ab.fa0.cout.bias"] = torch.tensor([-1.0])
|
| 2911 |
+
|
| 2912 |
+
# FA1-FA4: remaining bits (carry chain)
|
| 2913 |
+
for i in range(1, 5):
|
| 2914 |
+
p = f"{prefix}.diff_ab.fa{i}"
|
| 2915 |
+
|
| 2916 |
+
# XOR1: exp_a[i] XOR not_exp_b[i]
|
| 2917 |
+
tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
|
| 2918 |
+
tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0])
|
| 2919 |
+
tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
|
| 2920 |
+
tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0])
|
| 2921 |
+
tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
|
| 2922 |
+
tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0])
|
| 2923 |
+
|
| 2924 |
+
# XOR2: xor1 XOR carry_in
|
| 2925 |
+
tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0])
|
| 2926 |
+
tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0])
|
| 2927 |
+
tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
|
| 2928 |
+
tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0])
|
| 2929 |
+
tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0])
|
| 2930 |
+
tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0])
|
| 2931 |
+
|
| 2932 |
+
# Carry
|
| 2933 |
+
tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0])
|
| 2934 |
+
tensors[f"{p}.and1.bias"] = torch.tensor([-2.0])
|
| 2935 |
+
tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0])
|
| 2936 |
+
tensors[f"{p}.and2.bias"] = torch.tensor([-2.0])
|
| 2937 |
+
tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
|
| 2938 |
+
tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
|
| 2939 |
+
|
| 2940 |
+
# Similarly for exp_b - exp_a
|
| 2941 |
+
# NOT gates for exp_a
|
| 2942 |
+
for i in range(5):
|
| 2943 |
+
tensors[f"{prefix}.not_exp_a{i}.weight"] = torch.tensor([-1.0])
|
| 2944 |
+
tensors[f"{prefix}.not_exp_a{i}.bias"] = torch.tensor([0.0])
|
| 2945 |
+
|
| 2946 |
+
# Full adders for exp_b + NOT(exp_a) + 1 = exp_b - exp_a
|
| 2947 |
+
for i in range(5):
|
| 2948 |
+
p = f"{prefix}.diff_ba.fa{i}"
|
| 2949 |
+
|
| 2950 |
+
tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
|
| 2951 |
+
tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0])
|
| 2952 |
+
tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
|
| 2953 |
+
tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0])
|
| 2954 |
+
tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
|
| 2955 |
+
tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0])
|
| 2956 |
+
|
| 2957 |
+
tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0])
|
| 2958 |
+
tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0])
|
| 2959 |
+
tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
|
| 2960 |
+
tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0])
|
| 2961 |
+
tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0])
|
| 2962 |
+
tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0])
|
| 2963 |
+
|
| 2964 |
+
tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0])
|
| 2965 |
+
tensors[f"{p}.and1.bias"] = torch.tensor([-2.0])
|
| 2966 |
+
tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0])
|
| 2967 |
+
tensors[f"{p}.and2.bias"] = torch.tensor([-2.0])
|
| 2968 |
+
tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
|
| 2969 |
+
tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
|
| 2970 |
+
|
| 2971 |
+
# =========================================================================
|
| 2972 |
+
# STAGE 4: SELECT ABSOLUTE DIFFERENCE
|
| 2973 |
+
# =========================================================================
|
| 2974 |
+
# exp_diff = a_exp_ge_b ? (exp_a - exp_b) : (exp_b - exp_a)
|
| 2975 |
+
# Use 2-to-1 mux for each bit
|
| 2976 |
+
|
| 2977 |
+
for i in range(5):
|
| 2978 |
+
# Mux: out = (sel AND b) OR (NOT sel AND a)
|
| 2979 |
+
# sel = b_exp_gt_a_sel (1 if b > a, meaning we want diff_ba)
|
| 2980 |
+
# Actually: sel=0 (a>=b) -> use diff_ab, sel=1 (b>a) -> use diff_ba
|
| 2981 |
+
|
| 2982 |
+
# AND gate for diff_ab path (when a_exp_ge_b = 1)
|
| 2983 |
+
tensors[f"{prefix}.exp_diff_mux{i}.and_ab.weight"] = torch.tensor([1.0, 1.0])
|
| 2984 |
+
tensors[f"{prefix}.exp_diff_mux{i}.and_ab.bias"] = torch.tensor([-2.0])
|
| 2985 |
+
|
| 2986 |
+
# AND gate for diff_ba path (when b_exp_gt_a_sel = 1, i.e., a_exp_ge_b = 0)
|
| 2987 |
+
tensors[f"{prefix}.exp_diff_mux{i}.and_ba.weight"] = torch.tensor([1.0, 1.0])
|
| 2988 |
+
tensors[f"{prefix}.exp_diff_mux{i}.and_ba.bias"] = torch.tensor([-2.0])
|
| 2989 |
+
|
| 2990 |
+
# OR to combine
|
| 2991 |
+
tensors[f"{prefix}.exp_diff{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 2992 |
+
tensors[f"{prefix}.exp_diff{i}.bias"] = torch.tensor([-1.0])
|
| 2993 |
+
|
| 2994 |
+
# =========================================================================
|
| 2995 |
+
# STAGE 5: SELECT LARGER EXPONENT (for result)
|
| 2996 |
+
# =========================================================================
|
| 2997 |
+
# exp_larger = a_exp_ge_b ? exp_a : exp_b
|
| 2998 |
+
|
| 2999 |
+
for i in range(5):
|
| 3000 |
+
# AND gate for exp_a path
|
| 3001 |
+
tensors[f"{prefix}.exp_larger_mux{i}.and_a.weight"] = torch.tensor([1.0, 1.0])
|
| 3002 |
+
tensors[f"{prefix}.exp_larger_mux{i}.and_a.bias"] = torch.tensor([-2.0])
|
| 3003 |
+
|
| 3004 |
+
# AND gate for exp_b path
|
| 3005 |
+
tensors[f"{prefix}.exp_larger_mux{i}.and_b.weight"] = torch.tensor([1.0, 1.0])
|
| 3006 |
+
tensors[f"{prefix}.exp_larger_mux{i}.and_b.bias"] = torch.tensor([-2.0])
|
| 3007 |
+
|
| 3008 |
+
# OR to combine
|
| 3009 |
+
tensors[f"{prefix}.exp_larger{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 3010 |
+
tensors[f"{prefix}.exp_larger{i}.bias"] = torch.tensor([-1.0])
|
| 3011 |
+
|
| 3012 |
+
# =========================================================================
|
| 3013 |
+
# STAGE 6: MANTISSA ALIGNMENT (Barrel Shifter)
|
| 3014 |
+
# =========================================================================
|
| 3015 |
+
# The smaller exponent's mantissa needs to be shifted right by exp_diff.
|
| 3016 |
+
# Mantissa is 11 bits: implicit bit + 10 explicit mantissa bits.
|
| 3017 |
+
#
|
| 3018 |
+
# We need to:
|
| 3019 |
+
# 1. Select which mantissa to shift (the one with smaller exponent)
|
| 3020 |
+
# 2. Shift it right by exp_diff positions
|
| 3021 |
+
# 3. The larger mantissa passes through unchanged
|
| 3022 |
+
#
|
| 3023 |
+
# For the barrel shifter, we use cascaded 2-to-1 muxes:
|
| 3024 |
+
# - Stage 0: shift by 0 or 1 (controlled by exp_diff[0])
|
| 3025 |
+
# - Stage 1: shift by 0 or 2 (controlled by exp_diff[1])
|
| 3026 |
+
# - Stage 2: shift by 0 or 4 (controlled by exp_diff[2])
|
| 3027 |
+
# - Stage 3: shift by 0 or 8 (controlled by exp_diff[3])
|
| 3028 |
+
#
|
| 3029 |
+
# If exp_diff >= 11, the shifted mantissa becomes 0 (complete loss).
|
| 3030 |
+
|
| 3031 |
+
# First, select which mantissa gets shifted (the smaller exponent one)
|
| 3032 |
+
# mant_to_shift = a_exp_ge_b ? mant_b : mant_a (shift the smaller exp's mantissa)
|
| 3033 |
+
# mant_larger = a_exp_ge_b ? mant_a : mant_b
|
| 3034 |
+
|
| 3035 |
+
# Full mantissa with implicit bit: 11 bits (bit 10 = implicit, bits 9-0 = explicit)
|
| 3036 |
+
for i in range(11):
|
| 3037 |
+
# mant_shift_src[i] = mux(a_exp_ge_b, mant_b[i], mant_a[i])
|
| 3038 |
+
# When a_exp_ge_b=1, we shift b's mantissa (a has larger exp)
|
| 3039 |
+
# When a_exp_ge_b=0, we shift a's mantissa (b has larger exp)
|
| 3040 |
+
|
| 3041 |
+
tensors[f"{prefix}.mant_shift_src{i}.and_b.weight"] = torch.tensor([1.0, 1.0])
|
| 3042 |
+
tensors[f"{prefix}.mant_shift_src{i}.and_b.bias"] = torch.tensor([-2.0])
|
| 3043 |
+
|
| 3044 |
+
tensors[f"{prefix}.mant_shift_src{i}.and_a.weight"] = torch.tensor([1.0, 1.0])
|
| 3045 |
+
tensors[f"{prefix}.mant_shift_src{i}.and_a.bias"] = torch.tensor([-2.0])
|
| 3046 |
+
|
| 3047 |
+
tensors[f"{prefix}.mant_shift_src{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 3048 |
+
tensors[f"{prefix}.mant_shift_src{i}.bias"] = torch.tensor([-1.0])
|
| 3049 |
+
|
| 3050 |
+
# mant_larger[i] = mux(a_exp_ge_b, mant_a[i], mant_b[i])
|
| 3051 |
+
tensors[f"{prefix}.mant_larger{i}.and_a.weight"] = torch.tensor([1.0, 1.0])
|
| 3052 |
+
tensors[f"{prefix}.mant_larger{i}.and_a.bias"] = torch.tensor([-2.0])
|
| 3053 |
+
|
| 3054 |
+
tensors[f"{prefix}.mant_larger{i}.and_b.weight"] = torch.tensor([1.0, 1.0])
|
| 3055 |
+
tensors[f"{prefix}.mant_larger{i}.and_b.bias"] = torch.tensor([-2.0])
|
| 3056 |
+
|
| 3057 |
+
tensors[f"{prefix}.mant_larger{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 3058 |
+
tensors[f"{prefix}.mant_larger{i}.bias"] = torch.tensor([-1.0])
|
| 3059 |
+
|
| 3060 |
+
# Barrel shifter stages
|
| 3061 |
+
# Stage 0: shift by 1 if exp_diff[0]=1
|
| 3062 |
+
# NOT exp_diff[0] for pass-through path
|
| 3063 |
+
tensors[f"{prefix}.not_exp_diff0.weight"] = torch.tensor([-1.0])
|
| 3064 |
+
tensors[f"{prefix}.not_exp_diff0.bias"] = torch.tensor([0.0])
|
| 3065 |
+
|
| 3066 |
+
for i in range(11):
|
| 3067 |
+
# Output bit i comes from:
|
| 3068 |
+
# - bit i if not shifting (exp_diff[0]=0)
|
| 3069 |
+
# - bit i+1 if shifting (exp_diff[0]=1), or 0 if i+1 >= 11
|
| 3070 |
+
tensors[f"{prefix}.shift_s0_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
|
| 3071 |
+
tensors[f"{prefix}.shift_s0_{i}.pass.bias"] = torch.tensor([-2.0])
|
| 3072 |
+
|
| 3073 |
+
if i < 10:
|
| 3074 |
+
tensors[f"{prefix}.shift_s0_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
|
| 3075 |
+
tensors[f"{prefix}.shift_s0_{i}.shift.bias"] = torch.tensor([-2.0])
|
| 3076 |
+
tensors[f"{prefix}.shift_s0_{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 3077 |
+
else:
|
| 3078 |
+
# MSB: shift path is 0, so just pass-through when not shifting
|
| 3079 |
+
tensors[f"{prefix}.shift_s0_{i}.weight"] = torch.tensor([1.0])
|
| 3080 |
+
tensors[f"{prefix}.shift_s0_{i}.bias"] = torch.tensor([-1.0])
|
| 3081 |
+
|
| 3082 |
+
# Stage 1: shift by 2 if exp_diff[1]=1
|
| 3083 |
+
tensors[f"{prefix}.not_exp_diff1.weight"] = torch.tensor([-1.0])
|
| 3084 |
+
tensors[f"{prefix}.not_exp_diff1.bias"] = torch.tensor([0.0])
|
| 3085 |
+
|
| 3086 |
+
for i in range(11):
|
| 3087 |
+
tensors[f"{prefix}.shift_s1_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
|
| 3088 |
+
tensors[f"{prefix}.shift_s1_{i}.pass.bias"] = torch.tensor([-2.0])
|
| 3089 |
+
|
| 3090 |
+
if i < 9:
|
| 3091 |
+
tensors[f"{prefix}.shift_s1_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
|
| 3092 |
+
tensors[f"{prefix}.shift_s1_{i}.shift.bias"] = torch.tensor([-2.0])
|
| 3093 |
+
tensors[f"{prefix}.shift_s1_{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 3094 |
+
else:
|
| 3095 |
+
tensors[f"{prefix}.shift_s1_{i}.weight"] = torch.tensor([1.0])
|
| 3096 |
+
tensors[f"{prefix}.shift_s1_{i}.bias"] = torch.tensor([-1.0])
|
| 3097 |
+
|
| 3098 |
+
# Stage 2: shift by 4 if exp_diff[2]=1
|
| 3099 |
+
tensors[f"{prefix}.not_exp_diff2.weight"] = torch.tensor([-1.0])
|
| 3100 |
+
tensors[f"{prefix}.not_exp_diff2.bias"] = torch.tensor([0.0])
|
| 3101 |
+
|
| 3102 |
+
for i in range(11):
|
| 3103 |
+
tensors[f"{prefix}.shift_s2_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
|
| 3104 |
+
tensors[f"{prefix}.shift_s2_{i}.pass.bias"] = torch.tensor([-2.0])
|
| 3105 |
+
|
| 3106 |
+
if i < 7:
|
| 3107 |
+
tensors[f"{prefix}.shift_s2_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
|
| 3108 |
+
tensors[f"{prefix}.shift_s2_{i}.shift.bias"] = torch.tensor([-2.0])
|
| 3109 |
+
tensors[f"{prefix}.shift_s2_{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 3110 |
+
else:
|
| 3111 |
+
tensors[f"{prefix}.shift_s2_{i}.weight"] = torch.tensor([1.0])
|
| 3112 |
+
tensors[f"{prefix}.shift_s2_{i}.bias"] = torch.tensor([-1.0])
|
| 3113 |
+
|
| 3114 |
+
# Stage 3: shift by 8 if exp_diff[3]=1
|
| 3115 |
+
tensors[f"{prefix}.not_exp_diff3.weight"] = torch.tensor([-1.0])
|
| 3116 |
+
tensors[f"{prefix}.not_exp_diff3.bias"] = torch.tensor([0.0])
|
| 3117 |
+
|
| 3118 |
+
for i in range(11):
|
| 3119 |
+
tensors[f"{prefix}.shift_s3_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
|
| 3120 |
+
tensors[f"{prefix}.shift_s3_{i}.pass.bias"] = torch.tensor([-2.0])
|
| 3121 |
+
|
| 3122 |
+
if i < 3:
|
| 3123 |
+
tensors[f"{prefix}.shift_s3_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
|
| 3124 |
+
tensors[f"{prefix}.shift_s3_{i}.shift.bias"] = torch.tensor([-2.0])
|
| 3125 |
+
tensors[f"{prefix}.shift_s3_{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 3126 |
+
else:
|
| 3127 |
+
tensors[f"{prefix}.shift_s3_{i}.weight"] = torch.tensor([1.0])
|
| 3128 |
+
tensors[f"{prefix}.shift_s3_{i}.bias"] = torch.tensor([-1.0])
|
| 3129 |
+
|
| 3130 |
+
# If exp_diff[4]=1 (shift by 16 or more), result is 0
|
| 3131 |
+
# mant_aligned = exp_diff[4] ? 0 : shift_s3 result
|
| 3132 |
+
tensors[f"{prefix}.not_exp_diff4.weight"] = torch.tensor([-1.0])
|
| 3133 |
+
tensors[f"{prefix}.not_exp_diff4.bias"] = torch.tensor([0.0])
|
| 3134 |
+
|
| 3135 |
+
for i in range(11):
|
| 3136 |
+
# Only pass through if exp_diff[4]=0
|
| 3137 |
+
tensors[f"{prefix}.mant_aligned{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 3138 |
+
tensors[f"{prefix}.mant_aligned{i}.bias"] = torch.tensor([-2.0])
|
| 3139 |
+
|
| 3140 |
+
# =========================================================================
|
| 3141 |
+
# STAGE 7: MANTISSA ADDITION/SUBTRACTION
|
| 3142 |
+
# =========================================================================
|
| 3143 |
+
# If signs are the same: add mantissas
|
| 3144 |
+
# If signs differ: subtract smaller from larger
|
| 3145 |
+
#
|
| 3146 |
+
# We have:
|
| 3147 |
+
# - mant_larger[10:0]: mantissa of the larger exponent operand
|
| 3148 |
+
# - mant_aligned[10:0]: shifted mantissa of the smaller exponent operand
|
| 3149 |
+
#
|
| 3150 |
+
# For subtraction, we need to know which mantissa is larger.
|
| 3151 |
+
# If exp_a > exp_b, then mant_a is the reference (could be smaller mantissa value)
|
| 3152 |
+
# If exp_a == exp_b, we need to compare mantissas directly.
|
| 3153 |
+
#
|
| 3154 |
+
# signs_same: NOT signs_differ
|
| 3155 |
+
tensors[f"{prefix}.signs_same.weight"] = torch.tensor([-1.0])
|
| 3156 |
+
tensors[f"{prefix}.signs_same.bias"] = torch.tensor([0.0])
|
| 3157 |
+
|
| 3158 |
+
# For the result sign when signs differ:
|
| 3159 |
+
# If exp_a > exp_b: result sign = sign_a
|
| 3160 |
+
# If exp_b > exp_a: result sign = sign_b
|
| 3161 |
+
# If exp_a == exp_b: result sign = sign of larger mantissa
|
| 3162 |
+
|
| 3163 |
+
# Mantissa comparison (for equal exponent case)
|
| 3164 |
+
# Compare mant_a vs mant_b when exponents are equal
|
| 3165 |
+
weights_mant = [float(2**i) for i in range(11)]
|
| 3166 |
+
neg_weights_mant = [-float(2**i) for i in range(11)]
|
| 3167 |
+
|
| 3168 |
+
tensors[f"{prefix}.mant_a_ge_b.weight"] = torch.tensor(weights_mant + neg_weights_mant)
|
| 3169 |
+
tensors[f"{prefix}.mant_a_ge_b.bias"] = torch.tensor([0.0])
|
| 3170 |
+
|
| 3171 |
+
# 12-bit adder for mantissa sum (11 mantissa bits + 1 carry out)
|
| 3172 |
+
# We'll compute mant_larger + mant_aligned (for same sign)
|
| 3173 |
+
# or |mant_larger - mant_aligned| (for different signs)
|
| 3174 |
+
|
| 3175 |
+
# For subtraction, we need: larger_mant - smaller_mant
|
| 3176 |
+
# If exponents differ, larger exp means larger value, so:
|
| 3177 |
+
# result = mant_larger - mant_aligned
|
| 3178 |
+
# If exponents equal, compare mantissas:
|
| 3179 |
+
# result = |mant_a - mant_b|
|
| 3180 |
+
|
| 3181 |
+
# NOT gates for mant_aligned (for subtraction)
|
| 3182 |
+
for i in range(11):
|
| 3183 |
+
tensors[f"{prefix}.not_mant_aligned{i}.weight"] = torch.tensor([-1.0])
|
| 3184 |
+
tensors[f"{prefix}.not_mant_aligned{i}.bias"] = torch.tensor([0.0])
|
| 3185 |
+
|
| 3186 |
+
# 12-bit adder/subtractor
|
| 3187 |
+
# When signs_same=1: add (carry_in = 0)
|
| 3188 |
+
# When signs_same=0: subtract (use NOT mant_aligned, carry_in = 1)
|
| 3189 |
+
|
| 3190 |
+
# Carry input selection: signs_same ? 0 : 1
|
| 3191 |
+
# This is just NOT signs_same = signs_differ
|
| 3192 |
+
tensors[f"{prefix}.sub_cin.weight"] = torch.tensor([1.0])
|
| 3193 |
+
tensors[f"{prefix}.sub_cin.bias"] = torch.tensor([-0.5])
|
| 3194 |
+
|
| 3195 |
+
# Operand B selection: signs_same ? mant_aligned : NOT mant_aligned
|
| 3196 |
+
for i in range(11):
|
| 3197 |
+
# When adding (signs_same=1): use mant_aligned
|
| 3198 |
+
tensors[f"{prefix}.addsub_b{i}.add.weight"] = torch.tensor([1.0, 1.0])
|
| 3199 |
+
tensors[f"{prefix}.addsub_b{i}.add.bias"] = torch.tensor([-2.0])
|
| 3200 |
+
|
| 3201 |
+
# When subtracting (signs_same=0 = signs_differ=1): use NOT mant_aligned
|
| 3202 |
+
tensors[f"{prefix}.addsub_b{i}.sub.weight"] = torch.tensor([1.0, 1.0])
|
| 3203 |
+
tensors[f"{prefix}.addsub_b{i}.sub.bias"] = torch.tensor([-2.0])
|
| 3204 |
+
|
| 3205 |
+
tensors[f"{prefix}.addsub_b{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 3206 |
+
tensors[f"{prefix}.addsub_b{i}.bias"] = torch.tensor([-1.0])
|
| 3207 |
+
|
| 3208 |
+
# 12-bit ripple carry adder for mant_larger + addsub_b + sub_cin
|
| 3209 |
+
for i in range(12):
|
| 3210 |
+
p = f"{prefix}.mant_add.fa{i}"
|
| 3211 |
+
|
| 3212 |
+
tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
|
| 3213 |
+
tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0])
|
| 3214 |
+
tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
|
| 3215 |
+
tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0])
|
| 3216 |
+
tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
|
| 3217 |
+
tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0])
|
| 3218 |
+
|
| 3219 |
+
tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0])
|
| 3220 |
+
tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0])
|
| 3221 |
+
tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
|
| 3222 |
+
tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0])
|
| 3223 |
+
tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0])
|
| 3224 |
+
tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0])
|
| 3225 |
+
|
| 3226 |
+
tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0])
|
| 3227 |
+
tensors[f"{p}.and1.bias"] = torch.tensor([-2.0])
|
| 3228 |
+
tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0])
|
| 3229 |
+
tensors[f"{p}.and2.bias"] = torch.tensor([-2.0])
|
| 3230 |
+
tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
|
| 3231 |
+
tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
|
| 3232 |
+
|
| 3233 |
+
# =========================================================================
|
| 3234 |
+
# STAGE 8: RESULT SIGN DETERMINATION
|
| 3235 |
+
# =========================================================================
|
| 3236 |
+
# When signs_same: result_sign = sign_a (= sign_b)
|
| 3237 |
+
# When signs_differ:
|
| 3238 |
+
# If a has larger magnitude: result_sign = sign_a
|
| 3239 |
+
# If b has larger magnitude: result_sign = sign_b
|
| 3240 |
+
#
|
| 3241 |
+
# Magnitude comparison: consider both exponent and mantissa
|
| 3242 |
+
# a_magnitude_ge_b: (exp_a > exp_b) OR (exp_a == exp_b AND mant_a >= mant_b)
|
| 3243 |
+
|
| 3244 |
+
# exp_a_eq_b: NOT a_exp_gt_b AND NOT b_exp_gt_a
|
| 3245 |
+
tensors[f"{prefix}.not_a_exp_gt_b.weight"] = torch.tensor([-1.0])
|
| 3246 |
+
tensors[f"{prefix}.not_a_exp_gt_b.bias"] = torch.tensor([0.0])
|
| 3247 |
+
|
| 3248 |
+
tensors[f"{prefix}.exp_a_eq_b.weight"] = torch.tensor([1.0, 1.0])
|
| 3249 |
+
tensors[f"{prefix}.exp_a_eq_b.bias"] = torch.tensor([-2.0])
|
| 3250 |
+
|
| 3251 |
+
# exp_eq_and_mant_a_ge: exp_a_eq_b AND mant_a_ge_b
|
| 3252 |
+
tensors[f"{prefix}.exp_eq_and_mant_a_ge.weight"] = torch.tensor([1.0, 1.0])
|
| 3253 |
+
tensors[f"{prefix}.exp_eq_and_mant_a_ge.bias"] = torch.tensor([-2.0])
|
| 3254 |
+
|
| 3255 |
+
# a_magnitude_ge_b: a_exp_gt_b OR exp_eq_and_mant_a_ge
|
| 3256 |
+
tensors[f"{prefix}.a_magnitude_ge_b.weight"] = torch.tensor([1.0, 1.0])
|
| 3257 |
+
tensors[f"{prefix}.a_magnitude_ge_b.bias"] = torch.tensor([-1.0])
|
| 3258 |
+
|
| 3259 |
+
# result_sign when signs_differ:
|
| 3260 |
+
# = a_magnitude_ge_b ? sign_a : sign_b
|
| 3261 |
+
tensors[f"{prefix}.not_a_mag_ge_b.weight"] = torch.tensor([-1.0])
|
| 3262 |
+
tensors[f"{prefix}.not_a_mag_ge_b.bias"] = torch.tensor([0.0])
|
| 3263 |
+
|
| 3264 |
+
tensors[f"{prefix}.diff_sign_sel_a.weight"] = torch.tensor([1.0, 1.0])
|
| 3265 |
+
tensors[f"{prefix}.diff_sign_sel_a.bias"] = torch.tensor([-2.0])
|
| 3266 |
+
|
| 3267 |
+
tensors[f"{prefix}.diff_sign_sel_b.weight"] = torch.tensor([1.0, 1.0])
|
| 3268 |
+
tensors[f"{prefix}.diff_sign_sel_b.bias"] = torch.tensor([-2.0])
|
| 3269 |
+
|
| 3270 |
+
tensors[f"{prefix}.diff_result_sign.weight"] = torch.tensor([1.0, 1.0])
|
| 3271 |
+
tensors[f"{prefix}.diff_result_sign.bias"] = torch.tensor([-1.0])
|
| 3272 |
+
|
| 3273 |
+
# Final result sign: signs_same ? sign_a : diff_result_sign
|
| 3274 |
+
tensors[f"{prefix}.result_sign_same.weight"] = torch.tensor([1.0, 1.0])
|
| 3275 |
+
tensors[f"{prefix}.result_sign_same.bias"] = torch.tensor([-2.0])
|
| 3276 |
+
|
| 3277 |
+
tensors[f"{prefix}.result_sign_diff.weight"] = torch.tensor([1.0, 1.0])
|
| 3278 |
+
tensors[f"{prefix}.result_sign_diff.bias"] = torch.tensor([-2.0])
|
| 3279 |
+
|
| 3280 |
+
tensors[f"{prefix}.result_sign.weight"] = torch.tensor([1.0, 1.0])
|
| 3281 |
+
tensors[f"{prefix}.result_sign.bias"] = torch.tensor([-1.0])
|
| 3282 |
+
|
| 3283 |
+
# =========================================================================
|
| 3284 |
+
# STAGE 9: NORMALIZATION
|
| 3285 |
+
# =========================================================================
|
| 3286 |
+
# The mantissa sum may need normalization:
|
| 3287 |
+
# - If bit 12 (carry out) is set: right shift by 1, increment exponent
|
| 3288 |
+
# - If leading bit is 0: left shift until leading 1 found, decrement exponent
|
| 3289 |
+
#
|
| 3290 |
+
# Use CLZ to find shift amount for left shift case.
|
| 3291 |
+
# The sum is 12 bits (mant_add output).
|
| 3292 |
+
|
| 3293 |
+
# Overflow detection: mant_add.fa11 carry out
|
| 3294 |
+
tensors[f"{prefix}.sum_overflow.weight"] = torch.tensor([1.0])
|
| 3295 |
+
tensors[f"{prefix}.sum_overflow.bias"] = torch.tensor([-0.5])
|
| 3296 |
+
|
| 3297 |
+
# CLZ on 11-bit sum (bits 10:0) to find normalization shift
|
| 3298 |
+
# For non-overflow case, count leading zeros starting from bit 10
|
| 3299 |
+
# pz gates: prefix zero detectors on bits 10:0
|
| 3300 |
+
for k in range(1, 12):
|
| 3301 |
+
tensors[f"{prefix}.sum_pz{k}.weight"] = torch.tensor([-1.0] * k)
|
| 3302 |
+
tensors[f"{prefix}.sum_pz{k}.bias"] = torch.tensor([0.0])
|
| 3303 |
+
|
| 3304 |
+
# ge gates: sum of pz >= k (for 11-bit CLZ, max is 11)
|
| 3305 |
+
for k in range(1, 12):
|
| 3306 |
+
tensors[f"{prefix}.sum_ge{k}.weight"] = torch.tensor([1.0] * 11)
|
| 3307 |
+
tensors[f"{prefix}.sum_ge{k}.bias"] = torch.tensor([-float(k)])
|
| 3308 |
+
|
| 3309 |
+
# NOT gates for binary encoding
|
| 3310 |
+
for k in [2, 4, 6, 8, 10]:
|
| 3311 |
+
tensors[f"{prefix}.sum_not_ge{k}.weight"] = torch.tensor([-1.0])
|
| 3312 |
+
tensors[f"{prefix}.sum_not_ge{k}.bias"] = torch.tensor([0.0])
|
| 3313 |
+
|
| 3314 |
+
# Shift amount encoding (4 bits for 0-11)
|
| 3315 |
+
# CLZ of 11 bits can be 0-11
|
| 3316 |
+
tensors[f"{prefix}.norm_shift3.weight"] = torch.tensor([1.0])
|
| 3317 |
+
tensors[f"{prefix}.norm_shift3.bias"] = torch.tensor([-0.5]) # ge8
|
| 3318 |
+
|
| 3319 |
+
tensors[f"{prefix}.norm_and_4_7.weight"] = torch.tensor([1.0, 1.0])
|
| 3320 |
+
tensors[f"{prefix}.norm_and_4_7.bias"] = torch.tensor([-2.0])
|
| 3321 |
+
# For 11-bit CLZ (max 11), shift2 = ge4 AND NOT ge8 (no ge12 needed)
|
| 3322 |
+
tensors[f"{prefix}.norm_shift2.weight"] = torch.tensor([1.0])
|
| 3323 |
+
tensors[f"{prefix}.norm_shift2.bias"] = torch.tensor([-0.5])
|
| 3324 |
+
|
| 3325 |
+
tensors[f"{prefix}.norm_and_2_3.weight"] = torch.tensor([1.0, 1.0])
|
| 3326 |
+
tensors[f"{prefix}.norm_and_2_3.bias"] = torch.tensor([-2.0])
|
| 3327 |
+
tensors[f"{prefix}.norm_and_6_7.weight"] = torch.tensor([1.0, 1.0])
|
| 3328 |
+
tensors[f"{prefix}.norm_and_6_7.bias"] = torch.tensor([-2.0])
|
| 3329 |
+
# For 11-bit CLZ (max 11), ge10 means CLZ is 10 or 11, no need for NOT ge12
|
| 3330 |
+
tensors[f"{prefix}.norm_and_10_11.weight"] = torch.tensor([1.0])
|
| 3331 |
+
tensors[f"{prefix}.norm_and_10_11.bias"] = torch.tensor([-0.5])
|
| 3332 |
+
tensors[f"{prefix}.norm_shift1.weight"] = torch.tensor([1.0, 1.0, 1.0])
|
| 3333 |
+
tensors[f"{prefix}.norm_shift1.bias"] = torch.tensor([-1.0])
|
| 3334 |
+
|
| 3335 |
+
for i in [1, 3, 5, 7, 9]:
|
| 3336 |
+
tensors[f"{prefix}.norm_and_{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 3337 |
+
tensors[f"{prefix}.norm_and_{i}.bias"] = torch.tensor([-2.0])
|
| 3338 |
+
tensors[f"{prefix}.norm_shift0.weight"] = torch.tensor([1.0] * 5)
|
| 3339 |
+
tensors[f"{prefix}.norm_shift0.bias"] = torch.tensor([-1.0])
|
| 3340 |
+
|
| 3341 |
+
# =========================================================================
|
| 3342 |
+
# STAGE 10: APPLY NORMALIZATION TO MANTISSA
|
| 3343 |
+
# =========================================================================
|
| 3344 |
+
# Two cases:
|
| 3345 |
+
# 1. Overflow (sum bit 11 set): right-shift mantissa by 1, increment exponent
|
| 3346 |
+
# 2. No overflow: left-shift mantissa by norm_shift, decrement exponent
|
| 3347 |
+
|
| 3348 |
+
# NOT sum_overflow for non-overflow path
|
| 3349 |
+
tensors[f"{prefix}.not_sum_overflow.weight"] = torch.tensor([-1.0])
|
| 3350 |
+
tensors[f"{prefix}.not_sum_overflow.bias"] = torch.tensor([0.0])
|
| 3351 |
+
|
| 3352 |
+
# Overflow mantissa: bits 10:1 of adder_sum (right-shifted by 1)
|
| 3353 |
+
# norm_mant_overflow[i] = adder_sum[i+1] for i in 0..9
|
| 3354 |
+
for i in range(10):
|
| 3355 |
+
tensors[f"{prefix}.norm_mant_overflow{i}.weight"] = torch.tensor([1.0])
|
| 3356 |
+
tensors[f"{prefix}.norm_mant_overflow{i}.bias"] = torch.tensor([-0.5])
|
| 3357 |
+
|
| 3358 |
+
# Non-overflow mantissa: left-shift adder_sum[10:0] by norm_shift amount
|
| 3359 |
+
# This requires a left barrel shifter on the 11-bit sum (bits 10:0)
|
| 3360 |
+
|
| 3361 |
+
# Left barrel shifter stage 0: shift left by 1 if norm_shift[0]=1
|
| 3362 |
+
tensors[f"{prefix}.not_norm_shift0.weight"] = torch.tensor([-1.0])
|
| 3363 |
+
tensors[f"{prefix}.not_norm_shift0.bias"] = torch.tensor([0.0])
|
| 3364 |
+
|
| 3365 |
+
for i in range(11):
|
| 3366 |
+
tensors[f"{prefix}.lshift_s0_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
|
| 3367 |
+
tensors[f"{prefix}.lshift_s0_{i}.pass.bias"] = torch.tensor([-2.0])
|
| 3368 |
+
if i > 0:
|
| 3369 |
+
tensors[f"{prefix}.lshift_s0_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
|
| 3370 |
+
tensors[f"{prefix}.lshift_s0_{i}.shift.bias"] = torch.tensor([-2.0])
|
| 3371 |
+
tensors[f"{prefix}.lshift_s0_{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 3372 |
+
else:
|
| 3373 |
+
tensors[f"{prefix}.lshift_s0_{i}.weight"] = torch.tensor([1.0])
|
| 3374 |
+
tensors[f"{prefix}.lshift_s0_{i}.bias"] = torch.tensor([-1.0])
|
| 3375 |
+
|
| 3376 |
+
# Left barrel shifter stage 1: shift left by 2 if norm_shift[1]=1
|
| 3377 |
+
tensors[f"{prefix}.not_norm_shift1.weight"] = torch.tensor([-1.0])
|
| 3378 |
+
tensors[f"{prefix}.not_norm_shift1.bias"] = torch.tensor([0.0])
|
| 3379 |
+
|
| 3380 |
+
for i in range(11):
|
| 3381 |
+
tensors[f"{prefix}.lshift_s1_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
|
| 3382 |
+
tensors[f"{prefix}.lshift_s1_{i}.pass.bias"] = torch.tensor([-2.0])
|
| 3383 |
+
if i > 1:
|
| 3384 |
+
tensors[f"{prefix}.lshift_s1_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
|
| 3385 |
+
tensors[f"{prefix}.lshift_s1_{i}.shift.bias"] = torch.tensor([-2.0])
|
| 3386 |
+
tensors[f"{prefix}.lshift_s1_{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 3387 |
+
else:
|
| 3388 |
+
tensors[f"{prefix}.lshift_s1_{i}.weight"] = torch.tensor([1.0])
|
| 3389 |
+
tensors[f"{prefix}.lshift_s1_{i}.bias"] = torch.tensor([-1.0])
|
| 3390 |
+
|
| 3391 |
+
# Left barrel shifter stage 2: shift left by 4 if norm_shift[2]=1
|
| 3392 |
+
tensors[f"{prefix}.not_norm_shift2.weight"] = torch.tensor([-1.0])
|
| 3393 |
+
tensors[f"{prefix}.not_norm_shift2.bias"] = torch.tensor([0.0])
|
| 3394 |
+
|
| 3395 |
+
for i in range(11):
|
| 3396 |
+
tensors[f"{prefix}.lshift_s2_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
|
| 3397 |
+
tensors[f"{prefix}.lshift_s2_{i}.pass.bias"] = torch.tensor([-2.0])
|
| 3398 |
+
if i > 3:
|
| 3399 |
+
tensors[f"{prefix}.lshift_s2_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
|
| 3400 |
+
tensors[f"{prefix}.lshift_s2_{i}.shift.bias"] = torch.tensor([-2.0])
|
| 3401 |
+
tensors[f"{prefix}.lshift_s2_{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 3402 |
+
else:
|
| 3403 |
+
tensors[f"{prefix}.lshift_s2_{i}.weight"] = torch.tensor([1.0])
|
| 3404 |
+
tensors[f"{prefix}.lshift_s2_{i}.bias"] = torch.tensor([-1.0])
|
| 3405 |
+
|
| 3406 |
+
# Left barrel shifter stage 3: shift left by 8 if norm_shift[3]=1
|
| 3407 |
+
tensors[f"{prefix}.not_norm_shift3.weight"] = torch.tensor([-1.0])
|
| 3408 |
+
tensors[f"{prefix}.not_norm_shift3.bias"] = torch.tensor([0.0])
|
| 3409 |
+
|
| 3410 |
+
for i in range(11):
|
| 3411 |
+
tensors[f"{prefix}.lshift_s3_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
|
| 3412 |
+
tensors[f"{prefix}.lshift_s3_{i}.pass.bias"] = torch.tensor([-2.0])
|
| 3413 |
+
if i > 7:
|
| 3414 |
+
tensors[f"{prefix}.lshift_s3_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
|
| 3415 |
+
tensors[f"{prefix}.lshift_s3_{i}.shift.bias"] = torch.tensor([-2.0])
|
| 3416 |
+
tensors[f"{prefix}.lshift_s3_{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 3417 |
+
else:
|
| 3418 |
+
tensors[f"{prefix}.lshift_s3_{i}.weight"] = torch.tensor([1.0])
|
| 3419 |
+
tensors[f"{prefix}.lshift_s3_{i}.bias"] = torch.tensor([-1.0])
|
| 3420 |
+
|
| 3421 |
+
# Select normalized mantissa: overflow ? overflow_mant : lshift result
|
| 3422 |
+
# Take bits 9:0 for the output mantissa (bit 10 is implicit, dropped)
|
| 3423 |
+
for i in range(10):
|
| 3424 |
+
tensors[f"{prefix}.norm_mant{i}.overflow_path.weight"] = torch.tensor([1.0, 1.0])
|
| 3425 |
+
tensors[f"{prefix}.norm_mant{i}.overflow_path.bias"] = torch.tensor([-2.0])
|
| 3426 |
+
tensors[f"{prefix}.norm_mant{i}.normal_path.weight"] = torch.tensor([1.0, 1.0])
|
| 3427 |
+
tensors[f"{prefix}.norm_mant{i}.normal_path.bias"] = torch.tensor([-2.0])
|
| 3428 |
+
tensors[f"{prefix}.norm_mant{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 3429 |
+
tensors[f"{prefix}.norm_mant{i}.bias"] = torch.tensor([-1.0])
|
| 3430 |
+
|
| 3431 |
+
# =========================================================================
|
| 3432 |
+
# STAGE 11: ADJUST EXPONENT
|
| 3433 |
+
# =========================================================================
|
| 3434 |
+
# Overflow: exp_result = exp_larger + 1
|
| 3435 |
+
# No overflow: exp_result = exp_larger - norm_shift
|
| 3436 |
+
|
| 3437 |
+
# Increment exponent by 1 (for overflow case)
|
| 3438 |
+
# Half adder chain: exp_larger + 1
|
| 3439 |
+
tensors[f"{prefix}.exp_inc.ha0.sum.weight"] = torch.tensor([-1.0]) # NOT for XOR with 1
|
| 3440 |
+
tensors[f"{prefix}.exp_inc.ha0.sum.bias"] = torch.tensor([0.0])
|
| 3441 |
+
tensors[f"{prefix}.exp_inc.ha0.cout.weight"] = torch.tensor([1.0]) # AND with 1 = passthrough
|
| 3442 |
+
tensors[f"{prefix}.exp_inc.ha0.cout.bias"] = torch.tensor([-0.5])
|
| 3443 |
+
|
| 3444 |
+
for i in range(1, 5):
|
| 3445 |
+
# XOR: exp[i] XOR carry_in
|
| 3446 |
+
tensors[f"{prefix}.exp_inc.ha{i}.xor.layer1.or.weight"] = torch.tensor([1.0, 1.0])
|
| 3447 |
+
tensors[f"{prefix}.exp_inc.ha{i}.xor.layer1.or.bias"] = torch.tensor([-1.0])
|
| 3448 |
+
tensors[f"{prefix}.exp_inc.ha{i}.xor.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
|
| 3449 |
+
tensors[f"{prefix}.exp_inc.ha{i}.xor.layer1.nand.bias"] = torch.tensor([1.0])
|
| 3450 |
+
tensors[f"{prefix}.exp_inc.ha{i}.sum.weight"] = torch.tensor([1.0, 1.0])
|
| 3451 |
+
tensors[f"{prefix}.exp_inc.ha{i}.sum.bias"] = torch.tensor([-2.0])
|
| 3452 |
+
# Carry: exp[i] AND carry_in
|
| 3453 |
+
tensors[f"{prefix}.exp_inc.ha{i}.cout.weight"] = torch.tensor([1.0, 1.0])
|
| 3454 |
+
tensors[f"{prefix}.exp_inc.ha{i}.cout.bias"] = torch.tensor([-2.0])
|
| 3455 |
+
|
| 3456 |
+
# Decrement exponent by norm_shift (for non-overflow case)
|
| 3457 |
+
# 5-bit subtractor: exp_larger - norm_shift
|
| 3458 |
+
# NOT gates for norm_shift
|
| 3459 |
+
for i in range(4):
|
| 3460 |
+
tensors[f"{prefix}.not_norm_shift_sub{i}.weight"] = torch.tensor([-1.0])
|
| 3461 |
+
tensors[f"{prefix}.not_norm_shift_sub{i}.bias"] = torch.tensor([0.0])
|
| 3462 |
+
|
| 3463 |
+
# Full adders for exp_larger + NOT(norm_shift) + 1 = exp_larger - norm_shift
|
| 3464 |
+
for i in range(5):
|
| 3465 |
+
p = f"{prefix}.exp_dec.fa{i}"
|
| 3466 |
+
tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
|
| 3467 |
+
tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0])
|
| 3468 |
+
tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
|
| 3469 |
+
tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0])
|
| 3470 |
+
tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
|
| 3471 |
+
tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0])
|
| 3472 |
+
|
| 3473 |
+
tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0])
|
| 3474 |
+
tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0])
|
| 3475 |
+
tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
|
| 3476 |
+
tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0])
|
| 3477 |
+
tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0])
|
| 3478 |
+
tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0])
|
| 3479 |
+
|
| 3480 |
+
tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0])
|
| 3481 |
+
tensors[f"{p}.and1.bias"] = torch.tensor([-2.0])
|
| 3482 |
+
tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0])
|
| 3483 |
+
tensors[f"{p}.and2.bias"] = torch.tensor([-2.0])
|
| 3484 |
+
tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
|
| 3485 |
+
tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
|
| 3486 |
+
|
| 3487 |
+
# Select result exponent: overflow ? exp_inc : exp_dec
|
| 3488 |
+
for i in range(5):
|
| 3489 |
+
tensors[f"{prefix}.result_exp{i}.overflow_path.weight"] = torch.tensor([1.0, 1.0])
|
| 3490 |
+
tensors[f"{prefix}.result_exp{i}.overflow_path.bias"] = torch.tensor([-2.0])
|
| 3491 |
+
tensors[f"{prefix}.result_exp{i}.normal_path.weight"] = torch.tensor([1.0, 1.0])
|
| 3492 |
+
tensors[f"{prefix}.result_exp{i}.normal_path.bias"] = torch.tensor([-2.0])
|
| 3493 |
+
tensors[f"{prefix}.result_exp{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 3494 |
+
tensors[f"{prefix}.result_exp{i}.bias"] = torch.tensor([-1.0])
|
| 3495 |
+
|
| 3496 |
+
# =========================================================================
|
| 3497 |
+
# STAGE 12: OUTPUT ASSEMBLY
|
| 3498 |
+
# =========================================================================
|
| 3499 |
+
# Final output combines:
|
| 3500 |
+
# - Special cases (NaN, Inf) override normal computation
|
| 3501 |
+
# - For NaN: output canonical NaN (0x7E00)
|
| 3502 |
+
# - For Inf: output Inf with correct sign
|
| 3503 |
+
# - For normal: pack normalized result
|
| 3504 |
+
|
| 3505 |
+
# NaN output: 0x7E00 = 0111111000000000
|
| 3506 |
+
nan_bits = [0]*9 + [1] + [1]*5 + [0] # bits 0-15
|
| 3507 |
+
|
| 3508 |
+
# Final output mux: nan ? nan_val : (inf ? inf_val : normal_val)
|
| 3509 |
+
tensors[f"{prefix}.not_result_is_inf.weight"] = torch.tensor([-1.0])
|
| 3510 |
+
tensors[f"{prefix}.not_result_is_inf.bias"] = torch.tensor([0.0])
|
| 3511 |
+
|
| 3512 |
+
# Normal case selector: NOT nan AND NOT inf
|
| 3513 |
+
tensors[f"{prefix}.is_normal_result.weight"] = torch.tensor([1.0, 1.0])
|
| 3514 |
+
tensors[f"{prefix}.is_normal_result.bias"] = torch.tensor([-2.0])
|
| 3515 |
+
|
| 3516 |
+
# Inf sign selection
|
| 3517 |
+
tensors[f"{prefix}.inf_sign_sel_a.weight"] = torch.tensor([1.0, 1.0])
|
| 3518 |
+
tensors[f"{prefix}.inf_sign_sel_a.bias"] = torch.tensor([-2.0])
|
| 3519 |
+
tensors[f"{prefix}.inf_sign_sel_b.weight"] = torch.tensor([1.0, 1.0])
|
| 3520 |
+
tensors[f"{prefix}.inf_sign_sel_b.bias"] = torch.tensor([-2.0])
|
| 3521 |
+
tensors[f"{prefix}.inf_sign.weight"] = torch.tensor([1.0, 1.0])
|
| 3522 |
+
tensors[f"{prefix}.inf_sign.bias"] = torch.tensor([-1.0])
|
| 3523 |
+
|
| 3524 |
+
for i in range(16):
|
| 3525 |
+
# NaN path: output NaN bits gated by result_is_nan
|
| 3526 |
+
if nan_bits[i]:
|
| 3527 |
+
tensors[f"{prefix}.out_nan{i}.weight"] = torch.tensor([1.0])
|
| 3528 |
+
tensors[f"{prefix}.out_nan{i}.bias"] = torch.tensor([-0.5])
|
| 3529 |
+
|
| 3530 |
+
# Inf path: exponent bits = 1, mantissa = 0, sign from inf operand
|
| 3531 |
+
if i >= 10 and i < 15:
|
| 3532 |
+
tensors[f"{prefix}.out_inf{i}.weight"] = torch.tensor([1.0])
|
| 3533 |
+
tensors[f"{prefix}.out_inf{i}.bias"] = torch.tensor([-0.5])
|
| 3534 |
+
|
| 3535 |
+
# Normal path
|
| 3536 |
+
if i < 10:
|
| 3537 |
+
# Mantissa bits from norm_mant
|
| 3538 |
+
tensors[f"{prefix}.out_normal{i}.weight"] = torch.tensor([1.0])
|
| 3539 |
+
tensors[f"{prefix}.out_normal{i}.bias"] = torch.tensor([-0.5])
|
| 3540 |
+
elif i < 15:
|
| 3541 |
+
# Exponent bits from result_exp
|
| 3542 |
+
tensors[f"{prefix}.out_normal{i}.weight"] = torch.tensor([1.0])
|
| 3543 |
+
tensors[f"{prefix}.out_normal{i}.bias"] = torch.tensor([-0.5])
|
| 3544 |
+
else:
|
| 3545 |
+
# Sign bit from result_sign
|
| 3546 |
+
tensors[f"{prefix}.out_normal{i}.weight"] = torch.tensor([1.0])
|
| 3547 |
+
tensors[f"{prefix}.out_normal{i}.bias"] = torch.tensor([-0.5])
|
| 3548 |
+
|
| 3549 |
+
# Final output: 3-way mux (nan, inf, normal)
|
| 3550 |
+
tensors[f"{prefix}.out{i}.nan_gate.weight"] = torch.tensor([1.0, 1.0])
|
| 3551 |
+
tensors[f"{prefix}.out{i}.nan_gate.bias"] = torch.tensor([-2.0])
|
| 3552 |
+
tensors[f"{prefix}.out{i}.inf_gate.weight"] = torch.tensor([1.0, 1.0])
|
| 3553 |
+
tensors[f"{prefix}.out{i}.inf_gate.bias"] = torch.tensor([-2.0])
|
| 3554 |
+
tensors[f"{prefix}.out{i}.normal_gate.weight"] = torch.tensor([1.0, 1.0])
|
| 3555 |
+
tensors[f"{prefix}.out{i}.normal_gate.bias"] = torch.tensor([-2.0])
|
| 3556 |
+
tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0, 1.0, 1.0])
|
| 3557 |
+
tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-1.0])
|
| 3558 |
+
|
| 3559 |
+
return tensors
|
| 3560 |
+
|
| 3561 |
+
|
| 3562 |
def build_clz8bit_tensors() -> Dict[str, torch.Tensor]:
|
| 3563 |
"""Build tensors for arithmetic.clz8bit circuit.
|
| 3564 |
|
|
|
|
| 3628 |
|
| 3629 |
print(f"Loaded {len(tensors)} tensors")
|
| 3630 |
|
| 3631 |
+
# Remove old float16.add tensors (we're rebuilding from scratch)
|
| 3632 |
+
old_float16_add = [k for k in tensors.keys() if k.startswith('float16.add')]
|
| 3633 |
+
for k in old_float16_add:
|
| 3634 |
+
del tensors[k]
|
| 3635 |
+
print(f"Removed {len(old_float16_add)} old float16.add tensors")
|
| 3636 |
+
|
| 3637 |
# Build new circuits
|
| 3638 |
print("Building new circuits...")
|
| 3639 |
clz_tensors = build_clz8bit_tensors()
|
|
|
|
| 3668 |
tensors.update(abs_tensors)
|
| 3669 |
print(f" float16.abs: {len(abs_tensors)} tensors")
|
| 3670 |
|
| 3671 |
+
add_tensors = build_float16_add_tensors()
|
| 3672 |
+
tensors.update(add_tensors)
|
| 3673 |
+
print(f" float16.add: {len(add_tensors)} tensors")
|
| 3674 |
+
|
| 3675 |
print(f"Total tensors: {len(tensors)}")
|
| 3676 |
|
| 3677 |
# Load routing for complex circuits
|
eval.py
CHANGED
|
@@ -632,6 +632,159 @@ class CircuitEvaluator:
|
|
| 632 |
|
| 633 |
return TestResult('float16.abs', passed, len(test_values), failures)
|
| 634 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 635 |
# =========================================================================
|
| 636 |
# ARITHMETIC TESTS (DIRECT EVALUATION)
|
| 637 |
# =========================================================================
|
|
@@ -827,6 +980,11 @@ class Evaluator:
|
|
| 827 |
self.results.append(result)
|
| 828 |
if verbose:
|
| 829 |
self._print_result(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 830 |
|
| 831 |
# Comparators
|
| 832 |
if verbose:
|
|
|
|
| 632 |
|
| 633 |
return TestResult('float16.abs', passed, len(test_values), failures)
|
| 634 |
|
| 635 |
+
def test_float16_add(self) -> TestResult:
    """Test float16.add (IEEE 754 half-precision addition).

    Builds a fixed set of edge-case operand pairs (zeros, same/different
    sign, different exponents, infinities, NaNs) plus seeded random pairs,
    drives the circuit inputs bit by bit, and compares the 16-bit output
    word against a Python reference computed in double precision and
    rounded to binary16 via struct's 'e' format.

    Returns:
        TestResult with pass count, total case count, and up to 10
        (a_bits, b_bits, expected, result, a_float, b_float) failures.
    """
    prefix = 'float16.add'
    failures = []
    passed = 0

    import struct
    import math

    def float16_to_float(bits):
        """Decode a 16-bit pattern as an IEEE 754 half; NaN on bad input."""
        try:
            return struct.unpack('e', struct.pack('H', bits))[0]
        except (struct.error, OverflowError):
            return float('nan')

    def float_to_float16(f):
        """Encode a float as binary16 bits.

        struct raises OverflowError when |f| exceeds the half-precision
        range; IEEE 754 addition of finite operands overflows to +/-inf,
        so return signed infinity bits (not NaN) in that case.
        """
        try:
            return struct.unpack('H', struct.pack('e', f))[0]
        except OverflowError:
            return 0xFC00 if f < 0 else 0x7C00  # signed infinity
        except struct.error:
            return 0x7E00  # canonical NaN for unencodable input

    # Fixed edge-case operand pairs (a_bits, b_bits).
    test_cases = [
        # Zero cases
        (0x0000, 0x0000),  # +0 + +0 = +0
        (0x0000, 0x3C00),  # +0 + 1.0 = 1.0
        (0x3C00, 0x0000),  # 1.0 + +0 = 1.0

        # Same sign addition
        (0x3C00, 0x3C00),  # 1.0 + 1.0 = 2.0
        (0x4000, 0x3C00),  # 2.0 + 1.0 = 3.0
        (0x3800, 0x3800),  # 0.5 + 0.5 = 1.0
        (0x4200, 0x4000),  # 3.0 + 2.0 = 5.0

        # Different sign (subtraction)
        (0x4000, 0xBC00),  # 2.0 + (-1.0) = 1.0
        (0x3C00, 0xBC00),  # 1.0 + (-1.0) = 0.0
        (0xBC00, 0x4000),  # -1.0 + 2.0 = 1.0
        (0xC000, 0x3C00),  # -2.0 + 1.0 = -1.0

        # Negative + negative
        (0xBC00, 0xBC00),  # -1.0 + -1.0 = -2.0
        (0xC000, 0xBC00),  # -2.0 + -1.0 = -3.0

        # Different exponents
        (0x4400, 0x3C00),  # 4.0 + 1.0 = 5.0
        (0x4800, 0x3C00),  # 8.0 + 1.0 = 9.0
        (0x3C00, 0x3400),  # 1.0 + 0.25 = 1.25

        # Infinity cases
        (0x7C00, 0x3C00),  # +inf + 1.0 = +inf
        (0x3C00, 0x7C00),  # 1.0 + +inf = +inf
        (0xFC00, 0xBC00),  # -inf + -1.0 = -inf
        (0x7C00, 0xFC00),  # +inf + -inf = NaN

        # NaN cases
        (0x7E00, 0x3C00),  # NaN + 1.0 = NaN
        (0x3C00, 0x7E00),  # 1.0 + NaN = NaN
    ]

    # Seeded random cases so the suite is reproducible run to run.
    # NOTE: the randint/random call order must not change, or the
    # generated test set (and historical pass counts) shifts.
    import random
    random.seed(42)
    for _ in range(50):
        a = random.randint(0, 0x7BFF)  # positive finite range
        b = random.randint(0, 0x7BFF)
        test_cases.append((a, b))
        # Randomly mix in negated-operand variants.
        if random.random() > 0.5:
            test_cases.append((a | 0x8000, b))
        if random.random() > 0.5:
            test_cases.append((a, b | 0x8000))

    for a_bits, b_bits in test_cases:
        a_float = float16_to_float(a_bits)
        b_float = float16_to_float(b_bits)

        # Compute the expected result, classifying special cases first.
        if math.isnan(a_float) or math.isnan(b_float):
            expected_nan = True
            expected_inf = False
            expected = 0x7E00
        elif math.isinf(a_float) and math.isinf(b_float):
            if (a_float > 0) != (b_float > 0):
                # inf + -inf is invalid -> NaN
                expected_nan = True
                expected_inf = False
                expected = 0x7E00
            else:
                expected_nan = False
                expected_inf = True
                expected = 0x7C00 if a_float > 0 else 0xFC00
        elif math.isinf(a_float):
            expected_nan = False
            expected_inf = True
            expected = 0x7C00 if a_float > 0 else 0xFC00
        elif math.isinf(b_float):
            expected_nan = False
            expected_inf = True
            expected = 0x7C00 if b_float > 0 else 0xFC00
        else:
            expected_nan = False
            expected_inf = False
            expected = float_to_float16(a_float + b_float)

        # Drive the circuit: one external input per operand bit.
        ext = {}
        for i in range(16):
            ext[f'{prefix}.$a[{i}]'] = float((a_bits >> i) & 1)
            ext[f'{prefix}.$b[{i}]'] = float((b_bits >> i) & 1)

        values = self.eval_circuit(prefix, ext)

        # Reassemble the 16-bit output word from the out{i} gates.
        result = 0
        for i in range(16):
            result |= int(values.get(f'{prefix}.out{i}', 0)) << i

        # Special-case flags computed inside the circuit.
        result_is_nan = int(values.get(f'{prefix}.result_is_nan', 0))
        result_is_inf = int(values.get(f'{prefix}.result_is_inf', 0))

        if expected_nan:
            # Any NaN payload is acceptable; only the flag is required.
            if result_is_nan == 1:
                passed += 1
            elif len(failures) < 10:
                failures.append((a_bits, b_bits, 'expected NaN', result, a_float, b_float))
        elif expected_inf:
            # Require both the infinity flag and the correct sign bit.
            expected_sign = (expected >> 15) & 1
            result_sign = (result >> 15) & 1
            if result_is_inf == 1 and result_sign == expected_sign:
                passed += 1
            elif len(failures) < 10:
                failures.append((a_bits, b_bits, expected, result, a_float, b_float))
        else:
            # Normal results may differ by 1 ULP from round-to-nearest.
            if abs(result - expected) <= 1:
                passed += 1
            elif len(failures) < 10:
                failures.append((a_bits, b_bits, expected, result, a_float, b_float))

    return TestResult('float16.add', passed, len(test_cases), failures)
|
| 787 |
+
|
| 788 |
# =========================================================================
|
| 789 |
# ARITHMETIC TESTS (DIRECT EVALUATION)
|
| 790 |
# =========================================================================
|
|
|
|
| 980 |
self.results.append(result)
|
| 981 |
if verbose:
|
| 982 |
self._print_result(result)
|
| 983 |
+
if 'float16.add.sign_a.weight' in self.eval.tensors:
|
| 984 |
+
result = self.eval.test_float16_add()
|
| 985 |
+
self.results.append(result)
|
| 986 |
+
if verbose:
|
| 987 |
+
self._print_result(result)
|
| 988 |
|
| 989 |
# Comparators
|
| 990 |
if verbose:
|