PortfolioAI committed
Commit 32f7c0d
Parent(s): fd6cf07
Add multi-bit carry infrastructure for float16.mul/div
- Add col_bit2 (floor(sum/4) mod 2) and col_bit3 (floor(sum/8) mod 2) gates
- Add carry accumulator gates for positions receiving multiple carries
- Update TODO.md with detailed remaining work documentation
- Move completed float16 circuits to Completed section
Mul/div still failing due to carry_acc_carry propagation issue.
Proper fix requires Wallace/Dadda tree or secondary carry chain.
- TODO.md +73 -21
- arithmetic.safetensors +2 -2
- convert_to_explicit_inputs.py +356 -34
TODO.md
CHANGED
|
@@ -2,23 +2,60 @@
|
|
| 2 |
|
| 3 |
## High Priority
|
| 4 |
|
| 5 |
-
### Floating Point Circuits
|
| 6 |
-
- [x] `float16.unpack` -- extract sign, exponent, mantissa (16 gates, 63/63 tests)
|
| 7 |
-
- [x] `float16.pack` -- assemble from components (16 gates, 63/63 tests)
|
| 8 |
-
- [x] `float16.cmp` -- comparison a > b (14 gates, 113/113 tests)
|
| 9 |
-
- [x] `float16.normalize` -- CLZ-based shift calculator (51 gates, 14/14 tests)
|
| 10 |
-
- [x] `float16.add` -- IEEE 754 addition (~998 gates, 125/125 tests)
|
| 11 |
-
- [x] `float16.sub` -- IEEE 754 subtraction (via add with -b, 115/115 tests)
|
| 12 |
-
- [ ] `float16.mul` -- IEEE 754 multiplication (766 gates, 13/84 tests, col_sum precision)
|
| 13 |
-
- [ ] `float16.div` -- IEEE 754 division (1854 gates, 5/53 tests, col_sum precision)
|
| 14 |
-
- [x] `float16.toint` -- float16 to int16 (401 gates, 93/93 tests)
|
| 15 |
-
- [x] `float16.fromint` -- int16 to float16 (478 gates, 53/53 tests)
|
| 16 |
-
- [x] `float16.neg` -- sign flip (16 gates, 58/58 tests)
|
| 17 |
-
- [x] `float16.abs` -- clear sign bit (16 gates, 58/58 tests)
|
| 18 |
|
| 19 |
-
###
|
| 20 |
-
|
| 21 |
-
|
| 22 |
|
| 23 |
## Medium Priority
|
| 24 |
|
|
@@ -30,11 +67,6 @@
|
|
| 30 |
- [ ] `arithmetic.gcd8bit` -- greatest common divisor
|
| 31 |
- [ ] `arithmetic.lcm8bit` -- least common multiple
|
| 32 |
|
| 33 |
-
### Evaluator Improvements
|
| 34 |
-
- [x] Full circuit evaluation using .inputs topology
|
| 35 |
-
- [x] Exhaustive testing for boolean, threshold, CLZ, float16, comparator circuits
|
| 36 |
-
- [x] Automatic topological sort from signal registry
|
| 37 |
-
|
| 38 |
## Low Priority
|
| 39 |
|
| 40 |
### Transcendental Approximations
|
|
@@ -56,6 +88,26 @@
|
|
| 56 |
|
| 57 |
## Completed
|
| 58 |
|
| 59 |
- [x] Boolean gates (AND, OR, NOT, NAND, NOR, XOR, XNOR, IMPLIES, BIIMPLIES)
|
| 60 |
- [x] Arithmetic adders (half, full, ripple carry 2/4/8 bit)
|
| 61 |
- [x] Arithmetic subtraction (SUB, SBC, NEG)
|
|
|
|
| 2 |
|
| 3 |
## High Priority
|
| 4 |
|
| 5 |
+
### Floating Point Circuits - Remaining Work
|
| 6 |
|
| 7 |
+
#### `float16.mul` -- IEEE 754 multiplication (~800 gates, ~55/84 tests)
|
| 8 |
+
|
| 9 |
+
**Problem**: Multi-bit carry propagation in 11x11 mantissa multiplier.
|
| 10 |
+
|
| 11 |
+
**Background**: The mantissa multiplier produces a 22-bit product from two 11-bit mantissas (including implicit leading 1). Each column `i` has `min(i+1, 21-i)` partial products (AND gates). Column 10 has the maximum of 11 partial products.
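The column-height formula can be checked with a few lines of Python. This is a minimal sketch, not code from the repository: `pp_count` mirrors the `get_pp_count` helper in the diff, but the `width` parameter and the brute-force comparison are assumptions added for illustration.

```python
def pp_count(col: int, width: int = 11) -> int:
    """Number of partial products a_i AND b_j that land in column col (i + j == col)."""
    if col < 0 or col > 2 * width - 2:
        return 0  # column 21 of the 22-bit product only ever receives carries
    return min(col + 1, 2 * width - 1 - col)


def pp_count_brute(col: int, width: int = 11) -> int:
    """Brute-force count of (i, j) index pairs, used only to check the formula."""
    return sum(1 for i in range(width) for j in range(width) if i + j == col)


counts = [pp_count(c) for c in range(21)]
```

For 11-bit mantissas the heights rise from 1 at column 0 to 11 at column 10 and fall back to 1 at column 20, for 121 AND gates in total.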
|
| 12 |
+
|
| 13 |
+
**Current Implementation**:
|
| 14 |
+
- Column sums computed via threshold gates: `col_sum = parity(PP_0, PP_1, ..., PP_n)`
|
| 15 |
+
- Parity computed as `(ge1 AND NOT ge2) OR (ge3 AND NOT ge4) OR ...`
|
| 16 |
+
- `col_bit1` = floor(sum/2) mod 2 (carry to next position)
|
| 17 |
+
- `col_bit2` = floor(sum/4) mod 2 (carry to position i+2)
|
| 18 |
+
- `col_bit3` = floor(sum/8) mod 2 (carry to position i+3)
|
| 19 |
+
- Carry accumulator gates sum incoming carries from multiple columns
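The threshold decomposition listed above can be modeled at the integer level. The sketch below assumes a step activation that fires when the pre-activation is `>= 0`, which is consistent with the biases in the diff but is not confirmed by it; `step`, `ge`, and `sum_bit` are hypothetical names, not functions from `convert_to_explicit_inputs.py`.

```python
def step(x: float) -> int:
    """Assumed threshold activation: output 1 when the pre-activation is >= 0."""
    return 1 if x >= 0 else 0


def ge(total: int, t: int) -> int:
    """Model of a ge{t} gate: unit weight on every partial product, bias -t."""
    return step(total - t)


def sum_bit(total: int, count: int, k: int) -> int:
    """Bit k of the column sum, built only from ge{t} outputs.

    Bit k is set exactly when total lies in a window [t, t + 2**k) with
    t = 2**k, 3 * 2**k, 5 * 2**k, ...  Each window is ge{t} AND NOT ge{t + 2**k};
    the upper bound is dropped when it exceeds the number of inputs.
    """
    width = 1 << k
    out = 0
    for t in range(width, count + 1, 2 * width):
        window = ge(total, t)
        if t + width <= count:
            window &= 1 - ge(total, t + width)
        out |= window
    return out
```

With `k = 0` this reproduces the parity expression `(ge1 AND NOT ge2) OR (ge3 AND NOT ge4) OR ...`; `k = 1, 2, 3` correspond to `col_bit1`, `col_bit2`, and `col_bit3`.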
|
| 20 |
+
|
| 21 |
+
**Remaining Issue**: The carry accumulator can itself produce a carry (`carry_acc_carry`) when the sum of incoming carry bits is >= 2. This secondary carry needs to propagate to position i+1, creating a cascading effect that requires either:
|
| 22 |
+
1. A proper CSA (Carry Save Adder) tree structure, or
|
| 23 |
+
2. A secondary FA chain for accumulated carries, or
|
| 24 |
+
3. Iterating until carry stabilization
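The effect of the missing secondary carry can be illustrated with an integer-level model. This is a sketch under stated assumptions: it treats each `col_bit{k}` signal as one bit of a "bit plane" word, and it models the current circuit as if `carry_acc_carry` were simply dropped, which is an interpretation of the issue described above rather than a trace of the actual gates. All function names are hypothetical.

```python
def column_counts(a: int, b: int, width: int = 11) -> list:
    """Number of set partial products in each column of the a * b array."""
    counts = [0] * (2 * width - 1)
    for i in range(width):
        for j in range(width):
            counts[i + j] += (a >> i) & (b >> j) & 1
    return counts


def bit_planes(counts: list, planes: int = 4) -> list:
    """words[k] has bit (col + k) set when bit k of counts[col] is set."""
    words = [0] * planes
    for col, n in enumerate(counts):
        for k in range(planes):
            words[k] |= ((n >> k) & 1) << (col + k)
    return words


def exact_product(a: int, b: int) -> int:
    """Adding all four planes with real adders recovers a * b exactly."""
    return sum(bit_planes(column_counts(a, b)))


def lossy_product(a: int, b: int) -> int:
    """Assumed model of the failure: incoming carries are reduced to their
    parity (carry_acc_sum) and the accumulator carry is discarded."""
    p0, p1, p2, p3 = bit_planes(column_counts(a, b))
    return p0 + (p1 ^ p2 ^ p3)
```

In this model small operands multiply correctly because no position receives two carries at once, while wide operands such as `0x7FF * 0x7FF` lose weight wherever two carry planes overlap. Any of the three options above amounts to replacing the XOR with a true multi-operand addition.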
|
| 25 |
+
|
| 26 |
+
**Files**: `convert_to_explicit_inputs.py` lines 5350-5650 (build), lines 2400-2700 (infer)
|
| 27 |
+
|
| 28 |
+
---
|
| 29 |
+
|
| 30 |
+
#### `float16.div` -- IEEE 754 division (~1900 gates, ~5/53 tests)
|
| 31 |
+
|
| 32 |
+
**Problem**: Same multi-bit carry issue as multiplication, plus potential issues in the non-restoring division algorithm.
|
| 33 |
+
|
| 34 |
+
**Background**: Division uses non-restoring algorithm with 11-bit dividend and divisor. The quotient mantissa is computed iteratively, and similar column reduction issues arise.
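For reference, the textbook non-restoring algorithm on unsigned integers can be sketched as follows. This is not the circuit from the repository, whose division logic is not shown in this commit; the function name and the `bits` parameter are assumptions for illustration.

```python
def nonrestoring_divide(n: int, d: int, bits: int = 11):
    """Unsigned non-restoring division of a bits-wide dividend n by d > 0.

    The partial remainder r is allowed to go negative. Each step shifts in
    the next dividend bit, then subtracts d if r was non-negative or adds d
    if r was negative. The quotient bit is 1 when the new r is non-negative.
    """
    r = 0
    q = 0
    for i in range(bits - 1, -1, -1):
        bit = (n >> i) & 1
        if r >= 0:
            r = 2 * r + bit - d
        else:
            r = 2 * r + bit + d
        q = (q << 1) | (1 if r >= 0 else 0)
    if r < 0:
        r += d  # single final correction step
    return q, r
```

A reference model like this can help separate the two suspected failure sources: if the gate-level quotient bits diverge from the model before any multi-carry column is involved, the problem lies in the division-specific logic rather than in carry propagation.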
|
| 35 |
+
|
| 36 |
+
**Current Implementation**:
|
| 37 |
+
- NaN output bit 9 fixed (canonical NaN = 0x7E00)
|
| 38 |
+
- Column sum parity gates similar to multiplication
|
| 39 |
+
|
| 40 |
+
**Remaining Issues**:
|
| 41 |
+
1. Same multi-bit carry propagation problem as multiplication
|
| 42 |
+
2. May have additional issues in division-specific logic (partial remainder computation)
|
| 43 |
+
|
| 44 |
+
**Files**: `convert_to_explicit_inputs.py` lines 5700-6200 (build), lines 2700-3100 (infer)
|
| 45 |
+
|
| 46 |
+
---
|
| 47 |
+
|
| 48 |
+
### Potential Solutions for Carry Propagation
|
| 49 |
+
|
| 50 |
+
1. **Wallace Tree**: Replace column reduction with proper Wallace tree structure. More gates but handles arbitrary partial product counts correctly.
|
| 51 |
+
|
| 52 |
+
2. **Dadda Tree**: Similar to Wallace, but performs only the minimum reduction needed at each stage, which typically requires fewer adders.
|
| 53 |
+
|
| 54 |
+
3. **Iterative Carry Resolution**: After initial FA chain, detect remaining carries and iterate until stable. Simple but slow.
|
| 55 |
+
|
| 56 |
+
4. **Hybrid Approach**: Use threshold gates for small columns (2-3 PPs) and proper tree reduction for larger columns.
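As a rough illustration of option 1, column reduction with 3:2 compressors can be simulated in Python. This is a simplified Wallace-style sketch that uses only full adders (a real Wallace or Dadda tree also places half adders and schedules the stages more carefully); every name here is hypothetical and none of it is taken from the repository.

```python
def partial_product_columns(a: int, b: int, width: int = 11) -> list:
    """cols[c] lists the partial-product bits a_i AND b_j with i + j == c."""
    cols = [[] for _ in range(2 * width)]
    for i in range(width):
        for j in range(width):
            cols[i + j].append((a >> i) & (b >> j) & 1)
    return cols


def full_adder(x: int, y: int, z: int):
    """3:2 compressor: returns (sum bit, carry bit)."""
    return x ^ y ^ z, (x & y) | (x & z) | (y & z)


def reduce_columns(cols: list) -> list:
    """Apply full adders stage by stage until every column holds at most 2 bits."""
    cols = [list(c) for c in cols]
    while any(len(c) > 2 for c in cols):
        nxt = [[] for _ in range(len(cols) + 1)]
        for idx, bits in enumerate(cols):
            while len(bits) >= 3:
                s, c = full_adder(bits.pop(), bits.pop(), bits.pop())
                nxt[idx].append(s)      # sum stays in the same column
                nxt[idx + 1].append(c)  # carry moves one column left
            nxt[idx].extend(bits)       # leftover bits pass through unchanged
        cols = nxt
    return cols


def wallace_multiply(a: int, b: int, width: int = 11) -> int:
    """Reduce to two rows, then finish with one carry-propagate addition."""
    cols = reduce_columns(partial_product_columns(a, b, width))
    row0 = sum(bits[0] << i for i, bits in enumerate(cols) if len(bits) > 0)
    row1 = sum(bits[1] << i for i, bits in enumerate(cols) if len(bits) > 1)
    return row0 + row1
```

Because every carry is an explicit bit in the next column, no position ever needs a multi-input accumulator, which is the property the threshold-based column sums currently lack.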
|
| 57 |
+
|
| 58 |
+
---
|
| 59 |
|
| 60 |
## Medium Priority
|
| 61 |
|
|
|
|
| 67 |
- [ ] `arithmetic.gcd8bit` -- greatest common divisor
|
| 68 |
- [ ] `arithmetic.lcm8bit` -- least common multiple
|
| 69 |
|
| 70 |
## Low Priority
|
| 71 |
|
| 72 |
### Transcendental Approximations
|
|
|
|
| 88 |
|
| 89 |
## Completed
|
| 90 |
|
| 91 |
+
### Floating Point Circuits
|
| 92 |
+
- [x] `float16.unpack` -- extract sign, exponent, mantissa (16 gates, 63/63 tests)
|
| 93 |
+
- [x] `float16.pack` -- assemble from components (16 gates, 63/63 tests)
|
| 94 |
+
- [x] `float16.cmp` -- comparison a > b (14 gates, 113/113 tests)
|
| 95 |
+
- [x] `float16.normalize` -- CLZ-based shift calculator (51 gates, 14/14 tests)
|
| 96 |
+
- [x] `float16.add` -- IEEE 754 addition (~998 gates, 125/125 tests)
|
| 97 |
+
- [x] `float16.sub` -- IEEE 754 subtraction (via add with -b, 115/115 tests)
|
| 98 |
+
- [x] `float16.toint` -- float16 to int16 (401 gates, 93/93 tests)
|
| 99 |
+
- [x] `float16.fromint` -- int16 to float16 (478 gates, 53/53 tests)
|
| 100 |
+
- [x] `float16.neg` -- sign flip (16 gates, 58/58 tests)
|
| 101 |
+
- [x] `float16.abs` -- clear sign bit (16 gates, 58/58 tests)
|
| 102 |
+
|
| 103 |
+
### Supporting Infrastructure
|
| 104 |
+
- [x] `arithmetic.clz8bit` -- 8-bit count leading zeros (30 gates, 256/256 tests)
|
| 105 |
+
- [x] `arithmetic.clz16bit` -- 16-bit count leading zeros (63 gates, 217/217 tests)
|
| 106 |
+
- [x] Full circuit evaluation using .inputs topology
|
| 107 |
+
- [x] Exhaustive testing for boolean, threshold, CLZ, float16, comparator circuits
|
| 108 |
+
- [x] Automatic topological sort from signal registry
|
| 109 |
+
|
| 110 |
+
### Core Circuits
|
| 111 |
- [x] Boolean gates (AND, OR, NOT, NAND, NOR, XOR, XNOR, IMPLIES, BIIMPLIES)
|
| 112 |
- [x] Arithmetic adders (half, full, ripple carry 2/4/8 bit)
|
| 113 |
- [x] Arithmetic subtraction (SUB, SBC, NEG)
|
arithmetic.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ce88f97552b5471d1c2adb4a88ea64ddaf8e05537884a12624a455da32531910
|
| 3 |
+
size 2980552
|
convert_to_explicit_inputs.py
CHANGED
|
@@ -2410,28 +2410,192 @@ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
|
| 2410 |
if 0 <= j < 11:
|
| 2411 |
pps.append(f"{prefix}.pp{i}_{j}")
|
| 2412 |
|
| 2413 |
-
|
|
|
|
| 2414 |
if f'.col{col}' in gate and f'.col{col}_' not in gate:
|
| 2415 |
return [registry.get_id(pps[0])]
|
| 2416 |
registry.register(f"{prefix}.col{col}")
|
| 2417 |
-
elif
|
| 2418 |
-
|
| 2419 |
-
return [registry.get_id(pp) for pp in pps]
|
| 2420 |
-
registry.register(f"{prefix}.col{col}_sum")
|
| 2421 |
-
|
| 2422 |
match = re.search(rf'\.col{col}_ge(\d+)$', gate)
|
| 2423 |
if match:
|
| 2424 |
return [registry.get_id(pp) for pp in pps]
|
| 2425 |
|
| 2426 |
-
for t in range(1,
|
| 2427 |
registry.register(f"{prefix}.col{col}_ge{t}")
|
| 2428 |
|
| 2429 |
if '.prod_fa' in gate:
|
| 2430 |
match = re.search(r'\.prod_fa(\d+)\.', gate)
|
| 2431 |
if match:
|
| 2432 |
i = int(match.group(1))
|
| 2433 |
fa_prefix = f"{prefix}.prod_fa{i}"
|
| 2434 |
|
| 2435 |
# Count partial products in each column to determine signal names
|
| 2436 |
# col 0 and col 20 have 1 PP each, others have more
|
| 2437 |
def get_col_sum(col):
|
|
@@ -2441,26 +2605,43 @@ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
|
| 2441 |
return registry.get_id(f"{prefix}.col{col}_sum")
|
| 2442 |
return registry.get_id("#0")
|
| 2443 |
|
| 2444 |
-
def
|
| 2445 |
-
#
|
| 2446 |
-
|
| 2447 |
-
|
| 2448 |
-
|
| 2449 |
-
if
|
| 2450 |
-
|
| 2451 |
-
|
| 2452 |
-
|
| 2453 |
-
|
| 2454 |
-
|
| 2455 |
-
|
| 2456 |
-
|
| 2457 |
-
|
| 2458 |
-
if
|
| 2459 |
-
return
|
| 2460 |
else:
|
| 2461 |
-
|
| 2462 |
-
|
| 2463 |
-
|
| 2464 |
|
| 2465 |
if i == 0:
|
| 2466 |
a_bit = get_col_sum(0)
|
|
@@ -2468,7 +2649,7 @@ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
|
| 2468 |
cin = registry.get_id("#0")
|
| 2469 |
else:
|
| 2470 |
a_bit = get_col_sum(i) if i < 21 else registry.get_id("#0")
|
| 2471 |
-
b_bit =
|
| 2472 |
cin = registry.register(f"{prefix}.prod_fa{i-1}.cout")
|
| 2473 |
|
| 2474 |
if '.xor1.layer1' in gate:
|
|
@@ -5360,17 +5541,158 @@ def build_float16_mul_tensors() -> Dict[str, torch.Tensor]:
|
|
| 5360 |
tensors[f"{prefix}.col{col}.weight"] = torch.tensor([1.0])
|
| 5361 |
tensors[f"{prefix}.col{col}.bias"] = torch.tensor([-0.5])
|
| 5362 |
else:
|
| 5363 |
-
# Multi-bit column:
|
| 5364 |
-
#
|
| 5365 |
-
#
|
| 5366 |
-
tensors[f"{prefix}.col{col}_sum.weight"] = torch.tensor([1.0] * count)
|
| 5367 |
-
tensors[f"{prefix}.col{col}_sum.bias"] = torch.tensor([-0.5])
|
| 5368 |
|
| 5369 |
-
#
|
| 5370 |
-
for t in range(1, count):
|
| 5371 |
tensors[f"{prefix}.col{col}_ge{t}.weight"] = torch.tensor([1.0] * count)
|
| 5372 |
tensors[f"{prefix}.col{col}_ge{t}.bias"] = torch.tensor([-float(t)])
|
| 5373 |
| 5374 |
# Final product assembly using ripple carry
|
| 5375 |
for i in range(22):
|
| 5376 |
p = f"{prefix}.prod_fa{i}"
|
|
|
|
| 2410 |
if 0 <= j < 11:
|
| 2411 |
pps.append(f"{prefix}.pp{i}_{j}")
|
| 2412 |
|
| 2413 |
+
count = len(pps)
|
| 2414 |
+
if count == 1:
|
| 2415 |
if f'.col{col}' in gate and f'.col{col}_' not in gate:
|
| 2416 |
return [registry.get_id(pps[0])]
|
| 2417 |
registry.register(f"{prefix}.col{col}")
|
| 2418 |
+
elif count > 1:
|
| 2419 |
+
# ge{t} gates: threshold >= t
|
| 2420 |
match = re.search(rf'\.col{col}_ge(\d+)$', gate)
|
| 2421 |
if match:
|
| 2422 |
return [registry.get_id(pp) for pp in pps]
|
| 2423 |
|
| 2424 |
+
for t in range(1, count + 1):
|
| 2425 |
registry.register(f"{prefix}.col{col}_ge{t}")
|
| 2426 |
|
| 2427 |
+
# not_ge{t} for even t
|
| 2428 |
+
match = re.search(rf'\.col{col}_not_ge(\d+)$', gate)
|
| 2429 |
+
if match:
|
| 2430 |
+
t = int(match.group(1))
|
| 2431 |
+
return [registry.get_id(f"{prefix}.col{col}_ge{t}")]
|
| 2432 |
+
|
| 2433 |
+
for t in range(2, count + 1, 2):
|
| 2434 |
+
registry.register(f"{prefix}.col{col}_not_ge{t}")
|
| 2435 |
+
|
| 2436 |
+
# odd{t} gates: ge{t} AND NOT ge{t+1}, or just ge{t} when t+1 > count
|
| 2437 |
+
match = re.search(rf'\.col{col}_odd(\d+)$', gate)
|
| 2438 |
+
if match:
|
| 2439 |
+
t = int(match.group(1))
|
| 2440 |
+
if t + 1 <= count:
|
| 2441 |
+
return [registry.get_id(f"{prefix}.col{col}_ge{t}"),
|
| 2442 |
+
registry.get_id(f"{prefix}.col{col}_not_ge{t+1}")]
|
| 2443 |
+
else:
|
| 2444 |
+
return [registry.get_id(f"{prefix}.col{col}_ge{t}")]
|
| 2445 |
+
|
| 2446 |
+
odd_ranges = []
|
| 2447 |
+
for t in range(1, count + 1, 2):
|
| 2448 |
+
registry.register(f"{prefix}.col{col}_odd{t}")
|
| 2449 |
+
odd_ranges.append(f"{prefix}.col{col}_odd{t}")
|
| 2450 |
+
|
| 2451 |
+
# col_sum = OR of all odd gates (parity)
|
| 2452 |
+
if f'.col{col}_sum' in gate:
|
| 2453 |
+
return [registry.get_id(r) for r in odd_ranges]
|
| 2454 |
+
registry.register(f"{prefix}.col{col}_sum")
|
| 2455 |
+
|
| 2456 |
+
# col_bit1 gates (floor(sum/2) mod 2)
|
| 2457 |
+
if count >= 2:
|
| 2458 |
+
match = re.search(rf'\.col{col}_bit1_(\d+)$', gate)
|
| 2459 |
+
if match:
|
| 2460 |
+
t = int(match.group(1))
|
| 2461 |
+
upper = t + 2
|
| 2462 |
+
if upper <= count:
|
| 2463 |
+
return [registry.get_id(f"{prefix}.col{col}_ge{t}"),
|
| 2464 |
+
registry.get_id(f"{prefix}.col{col}_not_ge{upper}")]
|
| 2465 |
+
else:
|
| 2466 |
+
return [registry.get_id(f"{prefix}.col{col}_ge{t}")]
|
| 2467 |
+
|
| 2468 |
+
bit1_ranges = []
|
| 2469 |
+
for t in range(2, count + 1, 4):
|
| 2470 |
+
registry.register(f"{prefix}.col{col}_bit1_{t}")
|
| 2471 |
+
bit1_ranges.append(f"{prefix}.col{col}_bit1_{t}")
|
| 2472 |
+
|
| 2473 |
+
if f'.col{col}_bit1' in gate and f'.col{col}_bit1_' not in gate:
|
| 2474 |
+
return [registry.get_id(r) for r in bit1_ranges]
|
| 2475 |
+
if bit1_ranges:
|
| 2476 |
+
registry.register(f"{prefix}.col{col}_bit1")
|
| 2477 |
+
|
| 2478 |
+
# col_bit2 gates (floor(sum/4) mod 2)
|
| 2479 |
+
if count >= 4:
|
| 2480 |
+
match = re.search(rf'\.col{col}_bit2_(\d+)$', gate)
|
| 2481 |
+
if match:
|
| 2482 |
+
t = int(match.group(1))
|
| 2483 |
+
upper = t + 4
|
| 2484 |
+
if upper <= count:
|
| 2485 |
+
return [registry.get_id(f"{prefix}.col{col}_ge{t}"),
|
| 2486 |
+
registry.get_id(f"{prefix}.col{col}_not_ge{upper}")]
|
| 2487 |
+
else:
|
| 2488 |
+
return [registry.get_id(f"{prefix}.col{col}_ge{t}")]
|
| 2489 |
+
|
| 2490 |
+
bit2_ranges = []
|
| 2491 |
+
for t in range(4, count + 1, 8):
|
| 2492 |
+
registry.register(f"{prefix}.col{col}_bit2_{t}")
|
| 2493 |
+
bit2_ranges.append(f"{prefix}.col{col}_bit2_{t}")
|
| 2494 |
+
|
| 2495 |
+
if f'.col{col}_bit2' in gate and f'.col{col}_bit2_' not in gate:
|
| 2496 |
+
return [registry.get_id(r) for r in bit2_ranges]
|
| 2497 |
+
if bit2_ranges:
|
| 2498 |
+
registry.register(f"{prefix}.col{col}_bit2")
|
| 2499 |
+
|
| 2500 |
+
# col_bit3 gates (floor(sum/8) mod 2)
|
| 2501 |
+
if count >= 8:
|
| 2502 |
+
match = re.search(rf'\.col{col}_bit3_(\d+)$', gate)
|
| 2503 |
+
if match:
|
| 2504 |
+
t = int(match.group(1))
|
| 2505 |
+
upper = t + 8
|
| 2506 |
+
if upper <= count:
|
| 2507 |
+
return [registry.get_id(f"{prefix}.col{col}_ge{t}"),
|
| 2508 |
+
registry.get_id(f"{prefix}.col{col}_not_ge{upper}")]
|
| 2509 |
+
else:
|
| 2510 |
+
return [registry.get_id(f"{prefix}.col{col}_ge{t}")]
|
| 2511 |
+
|
| 2512 |
+
bit3_ranges = []
|
| 2513 |
+
for t in range(8, count + 1, 16):
|
| 2514 |
+
registry.register(f"{prefix}.col{col}_bit3_{t}")
|
| 2515 |
+
bit3_ranges.append(f"{prefix}.col{col}_bit3_{t}")
|
| 2516 |
+
|
| 2517 |
+
if f'.col{col}_bit3' in gate and f'.col{col}_bit3_' not in gate:
|
| 2518 |
+
return [registry.get_id(r) for r in bit3_ranges]
|
| 2519 |
+
if bit3_ranges:
|
| 2520 |
+
registry.register(f"{prefix}.col{col}_bit3")
|
| 2521 |
+
|
| 2522 |
+
# Handle carry accumulator gates
|
| 2523 |
+
if '.carry_acc' in gate:
|
| 2524 |
+
match = re.search(r'\.carry_acc(\d+)_', gate)
|
| 2525 |
+
if match:
|
| 2526 |
+
i = int(match.group(1))
|
| 2527 |
+
|
| 2528 |
+
def get_pp_count(col):
|
| 2529 |
+
if col < 0 or col > 20:
|
| 2530 |
+
return 0
|
| 2531 |
+
return min(col + 1, 21 - col)
|
| 2532 |
+
|
| 2533 |
+
# Determine which carry bits come into position i
|
| 2534 |
+
carry_inputs = []
|
| 2535 |
+
if i >= 1 and get_pp_count(i-1) >= 2:
|
| 2536 |
+
carry_inputs.append(registry.get_id(f"{prefix}.col{i-1}_bit1"))
|
| 2537 |
+
if i >= 2 and get_pp_count(i-2) >= 4:
|
| 2538 |
+
carry_inputs.append(registry.get_id(f"{prefix}.col{i-2}_bit2"))
|
| 2539 |
+
if i >= 3 and get_pp_count(i-3) >= 8:
|
| 2540 |
+
carry_inputs.append(registry.get_id(f"{prefix}.col{i-3}_bit3"))
|
| 2541 |
+
|
| 2542 |
+
n = len(carry_inputs)
|
| 2543 |
+
|
| 2544 |
+
# ge{t} gates
|
| 2545 |
+
match_ge = re.search(rf'\.carry_acc{i}_ge(\d+)$', gate)
|
| 2546 |
+
if match_ge:
|
| 2547 |
+
return carry_inputs
|
| 2548 |
+
|
| 2549 |
+
# not_ge{t} gates
|
| 2550 |
+
match_not = re.search(rf'\.carry_acc{i}_not_ge(\d+)$', gate)
|
| 2551 |
+
if match_not:
|
| 2552 |
+
t = int(match_not.group(1))
|
| 2553 |
+
return [registry.get_id(f"{prefix}.carry_acc{i}_ge{t}")]
|
| 2554 |
+
|
| 2555 |
+
# Register ge gates
|
| 2556 |
+
for t in range(1, n + 1):
|
| 2557 |
+
registry.register(f"{prefix}.carry_acc{i}_ge{t}")
|
| 2558 |
+
for t in range(2, n + 1, 2):
|
| 2559 |
+
registry.register(f"{prefix}.carry_acc{i}_not_ge{t}")
|
| 2560 |
+
|
| 2561 |
+
# odd{t} gates
|
| 2562 |
+
match_odd = re.search(rf'\.carry_acc{i}_odd(\d+)$', gate)
|
| 2563 |
+
if match_odd:
|
| 2564 |
+
t = int(match_odd.group(1))
|
| 2565 |
+
if t + 1 <= n:
|
| 2566 |
+
return [registry.get_id(f"{prefix}.carry_acc{i}_ge{t}"),
|
| 2567 |
+
registry.get_id(f"{prefix}.carry_acc{i}_not_ge{t+1}")]
|
| 2568 |
+
else:
|
| 2569 |
+
return [registry.get_id(f"{prefix}.carry_acc{i}_ge{t}")]
|
| 2570 |
+
|
| 2571 |
+
# Register odd gates
|
| 2572 |
+
odd_ranges = []
|
| 2573 |
+
for t in range(1, n + 1, 2):
|
| 2574 |
+
registry.register(f"{prefix}.carry_acc{i}_odd{t}")
|
| 2575 |
+
odd_ranges.append(f"{prefix}.carry_acc{i}_odd{t}")
|
| 2576 |
+
|
| 2577 |
+
# carry_acc_sum = OR of odd gates
|
| 2578 |
+
if f'.carry_acc{i}_sum' in gate:
|
| 2579 |
+
return [registry.get_id(r) for r in odd_ranges]
|
| 2580 |
+
registry.register(f"{prefix}.carry_acc{i}_sum")
|
| 2581 |
+
|
| 2582 |
+
# carry_acc_carry = ge2
|
| 2583 |
+
if f'.carry_acc{i}_carry' in gate:
|
| 2584 |
+
return carry_inputs
|
| 2585 |
+
if n >= 2:
|
| 2586 |
+
registry.register(f"{prefix}.carry_acc{i}_carry")
|
| 2587 |
+
|
| 2588 |
if '.prod_fa' in gate:
|
| 2589 |
match = re.search(r'\.prod_fa(\d+)\.', gate)
|
| 2590 |
if match:
|
| 2591 |
i = int(match.group(1))
|
| 2592 |
fa_prefix = f"{prefix}.prod_fa{i}"
|
| 2593 |
|
| 2594 |
+
def get_pp_count(col):
|
| 2595 |
+
if col < 0 or col > 20:
|
| 2596 |
+
return 0
|
| 2597 |
+
return min(col + 1, 21 - col)
|
| 2598 |
+
|
| 2599 |
# Count partial products in each column to determine signal names
|
| 2600 |
# col 0 and col 20 have 1 PP each, others have more
|
| 2601 |
def get_col_sum(col):
|
|
|
|
| 2605 |
return registry.get_id(f"{prefix}.col{col}_sum")
|
| 2606 |
return registry.get_id("#0")
|
| 2607 |
|
| 2608 |
+
def get_b_bit(pos):
|
| 2609 |
+
# Determine incoming carries for position pos
|
| 2610 |
+
carry_inputs = []
|
| 2611 |
+
if pos >= 1 and get_pp_count(pos-1) >= 2:
|
| 2612 |
+
carry_inputs.append("bit1")
|
| 2613 |
+
if pos >= 2 and get_pp_count(pos-2) >= 4:
|
| 2614 |
+
carry_inputs.append("bit2")
|
| 2615 |
+
if pos >= 3 and get_pp_count(pos-3) >= 8:
|
| 2616 |
+
carry_inputs.append("bit3")
|
| 2617 |
+
|
| 2618 |
+
if len(carry_inputs) == 0:
|
| 2619 |
+
return registry.get_id("#0")
|
| 2620 |
+
elif len(carry_inputs) == 1:
|
| 2621 |
+
# Single carry, use it directly
|
| 2622 |
+
if carry_inputs[0] == "bit1":
|
| 2623 |
+
return registry.get_id(f"{prefix}.col{pos-1}_bit1")
|
| 2624 |
+
elif carry_inputs[0] == "bit2":
|
| 2625 |
+
return registry.get_id(f"{prefix}.col{pos-2}_bit2")
|
| 2626 |
else:
|
| 2627 |
+
return registry.get_id(f"{prefix}.col{pos-3}_bit3")
|
| 2628 |
+
else:
|
| 2629 |
+
# Multiple carries, use accumulator sum
|
| 2630 |
+
return registry.register(f"{prefix}.carry_acc{pos}_sum")
|
| 2631 |
+
|
| 2632 |
+
def get_extra_cin(pos):
|
| 2633 |
+
# Extra carry from accumulator (when sum of carries >= 2)
|
| 2634 |
+
carry_inputs = []
|
| 2635 |
+
if pos >= 1 and get_pp_count(pos-1) >= 2:
|
| 2636 |
+
carry_inputs.append("bit1")
|
| 2637 |
+
if pos >= 2 and get_pp_count(pos-2) >= 4:
|
| 2638 |
+
carry_inputs.append("bit2")
|
| 2639 |
+
if pos >= 3 and get_pp_count(pos-3) >= 8:
|
| 2640 |
+
carry_inputs.append("bit3")
|
| 2641 |
+
|
| 2642 |
+
if len(carry_inputs) >= 2:
|
| 2643 |
+
return registry.register(f"{prefix}.carry_acc{pos}_carry")
|
| 2644 |
+
return None
|
| 2645 |
|
| 2646 |
if i == 0:
|
| 2647 |
a_bit = get_col_sum(0)
|
|
|
|
| 2649 |
cin = registry.get_id("#0")
|
| 2650 |
else:
|
| 2651 |
a_bit = get_col_sum(i) if i < 21 else registry.get_id("#0")
|
| 2652 |
+
b_bit = get_b_bit(i)
|
| 2653 |
cin = registry.register(f"{prefix}.prod_fa{i-1}.cout")
|
| 2654 |
|
| 2655 |
if '.xor1.layer1' in gate:
|
|
|
|
| 5541 |
tensors[f"{prefix}.col{col}.weight"] = torch.tensor([1.0])
|
| 5542 |
tensors[f"{prefix}.col{col}.bias"] = torch.tensor([-0.5])
|
| 5543 |
else:
|
| 5544 |
+
# Multi-bit column: compute parity (sum mod 2) using threshold gates
|
| 5545 |
+
# parity = (ge1 AND NOT ge2) OR (ge3 AND NOT ge4) OR ...
|
| 5546 |
+
# This captures: sum is odd when it lies in [1,2), [3,4), [5,6), etc.
|
| 5547 |
|
| 5548 |
+
# Threshold gates: ge{t} = 1 if sum >= t
|
| 5549 |
+
for t in range(1, count + 1):
|
| 5550 |
tensors[f"{prefix}.col{col}_ge{t}.weight"] = torch.tensor([1.0] * count)
|
| 5551 |
tensors[f"{prefix}.col{col}_ge{t}.bias"] = torch.tensor([-float(t)])
|
| 5552 |
|
| 5553 |
+
# NOT gates for even thresholds
|
| 5554 |
+
for t in range(2, count + 1, 2):
|
| 5555 |
+
tensors[f"{prefix}.col{col}_not_ge{t}.weight"] = torch.tensor([-1.0])
|
| 5556 |
+
tensors[f"{prefix}.col{col}_not_ge{t}.bias"] = torch.tensor([0.0])
|
| 5557 |
+
|
| 5558 |
+
# AND gates for odd ranges: (ge1 AND NOT ge2), (ge3 AND NOT ge4), ...
|
| 5559 |
+
odd_ranges = []
|
| 5560 |
+
for t in range(1, count + 1, 2):
|
| 5561 |
+
if t + 1 <= count:
|
| 5562 |
+
# ge{t} AND NOT ge{t+1}
|
| 5563 |
+
tensors[f"{prefix}.col{col}_odd{t}.weight"] = torch.tensor([1.0, 1.0])
|
| 5564 |
+
tensors[f"{prefix}.col{col}_odd{t}.bias"] = torch.tensor([-2.0])
|
| 5565 |
+
odd_ranges.append(t)
|
| 5566 |
+
else:
|
| 5567 |
+
# ge{t} only (no upper bound needed if t is max)
|
| 5568 |
+
tensors[f"{prefix}.col{col}_odd{t}.weight"] = torch.tensor([1.0])
|
| 5569 |
+
tensors[f"{prefix}.col{col}_odd{t}.bias"] = torch.tensor([-0.5])
|
| 5570 |
+
odd_ranges.append(t)
|
| 5571 |
+
|
| 5572 |
+
# col_sum = OR of all odd ranges (parity = bit 0)
|
| 5573 |
+
num_odd = len(odd_ranges)
|
| 5574 |
+
tensors[f"{prefix}.col{col}_sum.weight"] = torch.tensor([1.0] * num_odd)
|
| 5575 |
+
tensors[f"{prefix}.col{col}_sum.bias"] = torch.tensor([-0.5])
|
| 5576 |
+
|
| 5577 |
+
# col_bit1 = floor(sum/2) mod 2 = 1 when sum is in [2,3], [6,7], [10,11], ...
|
| 5578 |
+
# This is (ge2 AND NOT ge4) OR (ge6 AND NOT ge8) OR ...
|
| 5579 |
+
if count >= 2:
|
| 5580 |
+
bit1_ranges = []
|
| 5581 |
+
for t in range(2, count + 1, 4):
|
| 5582 |
+
upper = t + 2
|
| 5583 |
+
if upper <= count:
|
| 5584 |
+
tensors[f"{prefix}.col{col}_bit1_{t}.weight"] = torch.tensor([1.0, 1.0])
|
| 5585 |
+
tensors[f"{prefix}.col{col}_bit1_{t}.bias"] = torch.tensor([-2.0])
|
| 5586 |
+
if f"{prefix}.col{col}_not_ge{upper}.weight" not in tensors:
|
| 5587 |
+
tensors[f"{prefix}.col{col}_not_ge{upper}.weight"] = torch.tensor([-1.0])
|
| 5588 |
+
tensors[f"{prefix}.col{col}_not_ge{upper}.bias"] = torch.tensor([0.0])
|
| 5589 |
+
else:
|
| 5590 |
+
tensors[f"{prefix}.col{col}_bit1_{t}.weight"] = torch.tensor([1.0])
|
| 5591 |
+
tensors[f"{prefix}.col{col}_bit1_{t}.bias"] = torch.tensor([-0.5])
|
| 5592 |
+
bit1_ranges.append(t)
|
| 5593 |
+
|
| 5594 |
+
if bit1_ranges:
|
| 5595 |
+
tensors[f"{prefix}.col{col}_bit1.weight"] = torch.tensor([1.0] * len(bit1_ranges))
|
| 5596 |
+
tensors[f"{prefix}.col{col}_bit1.bias"] = torch.tensor([-0.5])
|
| 5597 |
+
|
| 5598 |
+
# col_bit2 = floor(sum/4) mod 2 = 1 when sum is in [4,7], [12,15], ...
|
| 5599 |
+
# This is (ge4 AND NOT ge8) OR (ge12 AND NOT ge16) OR ...
|
| 5600 |
+
if count >= 4:
|
| 5601 |
+
bit2_ranges = []
|
| 5602 |
+
for t in range(4, count + 1, 8):
|
| 5603 |
+
upper = t + 4
|
| 5604 |
+
if upper <= count:
|
| 5605 |
+
tensors[f"{prefix}.col{col}_bit2_{t}.weight"] = torch.tensor([1.0, 1.0])
|
| 5606 |
+
tensors[f"{prefix}.col{col}_bit2_{t}.bias"] = torch.tensor([-2.0])
|
| 5607 |
+
if f"{prefix}.col{col}_not_ge{upper}.weight" not in tensors:
|
| 5608 |
+
tensors[f"{prefix}.col{col}_not_ge{upper}.weight"] = torch.tensor([-1.0])
|
| 5609 |
+
tensors[f"{prefix}.col{col}_not_ge{upper}.bias"] = torch.tensor([0.0])
|
| 5610 |
+
else:
|
| 5611 |
+
tensors[f"{prefix}.col{col}_bit2_{t}.weight"] = torch.tensor([1.0])
|
| 5612 |
+
tensors[f"{prefix}.col{col}_bit2_{t}.bias"] = torch.tensor([-0.5])
|
| 5613 |
+
bit2_ranges.append(t)
|
| 5614 |
+
|
| 5615 |
+
if bit2_ranges:
|
| 5616 |
+
tensors[f"{prefix}.col{col}_bit2.weight"] = torch.tensor([1.0] * len(bit2_ranges))
|
| 5617 |
+
tensors[f"{prefix}.col{col}_bit2.bias"] = torch.tensor([-0.5])
|
| 5618 |
+
|
| 5619 |
+
# col_bit3 = floor(sum/8) mod 2 (for col10 with 11 PPs)
|
| 5620 |
+
if count >= 8:
|
| 5621 |
+
bit3_ranges = []
|
| 5622 |
+
for t in range(8, count + 1, 16):
|
| 5623 |
+
upper = t + 8
|
| 5624 |
+
if upper <= count:
|
| 5625 |
+
tensors[f"{prefix}.col{col}_bit3_{t}.weight"] = torch.tensor([1.0, 1.0])
|
| 5626 |
+
tensors[f"{prefix}.col{col}_bit3_{t}.bias"] = torch.tensor([-2.0])
|
| 5627 |
+
if f"{prefix}.col{col}_not_ge{upper}.weight" not in tensors:
|
| 5628 |
+
tensors[f"{prefix}.col{col}_not_ge{upper}.weight"] = torch.tensor([-1.0])
|
| 5629 |
+
tensors[f"{prefix}.col{col}_not_ge{upper}.bias"] = torch.tensor([0.0])
|
| 5630 |
+
else:
|
| 5631 |
+
tensors[f"{prefix}.col{col}_bit3_{t}.weight"] = torch.tensor([1.0])
|
| 5632 |
+
tensors[f"{prefix}.col{col}_bit3_{t}.bias"] = torch.tensor([-0.5])
|
| 5633 |
+
bit3_ranges.append(t)
|
| 5634 |
+
|
| 5635 |
+
if bit3_ranges:
|
| 5636 |
+
tensors[f"{prefix}.col{col}_bit3.weight"] = torch.tensor([1.0] * len(bit3_ranges))
|
| 5637 |
+
tensors[f"{prefix}.col{col}_bit3.bias"] = torch.tensor([-0.5])
|
| 5638 |
+
|
| 5639 |
+
# Carry accumulator for multi-bit carries
|
| 5640 |
+
# For position i, incoming carries are: bit1[i-1], bit2[i-2], bit3[i-3]
|
| 5641 |
+
# We need to sum these and produce: carry_acc_sum (parity), carry_acc_carry (sum >= 2)
|
| 5642 |
+
def get_pp_count(col):
|
| 5643 |
+
if col < 0 or col > 20:
|
| 5644 |
+
return 0
|
| 5645 |
+
return min(col + 1, 21 - col)
|
| 5646 |
+
|
| 5647 |
+
for i in range(22):
|
| 5648 |
+
# Determine which carry bits come into position i
|
| 5649 |
+
carry_inputs = []
|
| 5650 |
+
# bit1 from col[i-1]
|
| 5651 |
+
if i >= 1 and get_pp_count(i-1) >= 2:
|
| 5652 |
+
carry_inputs.append(f"bit1_{i-1}")
|
| 5653 |
+
# bit2 from col[i-2]
|
| 5654 |
+
if i >= 2 and get_pp_count(i-2) >= 4:
|
| 5655 |
+
carry_inputs.append(f"bit2_{i-2}")
|
| 5656 |
+
# bit3 from col[i-3]
|
| 5657 |
+
if i >= 3 and get_pp_count(i-3) >= 8:
|
| 5658 |
+
carry_inputs.append(f"bit3_{i-3}")
|
| 5659 |
+
|
| 5660 |
+
if len(carry_inputs) == 0:
|
| 5661 |
+
# No carries, use #0
|
| 5662 |
+
pass
|
| 5663 |
+
elif len(carry_inputs) == 1:
|
| 5664 |
+
# Single carry, no accumulator needed
|
| 5665 |
+
pass
|
| 5666 |
+
else:
|
| 5667 |
+
# Multiple carries, need accumulator
|
| 5668 |
+
n = len(carry_inputs)
|
| 5669 |
+
# Parity (sum mod 2) using threshold gates
|
| 5670 |
+
# ge{t} = sum >= t
|
| 5671 |
+
for t in range(1, n + 1):
|
| 5672 |
+
tensors[f"{prefix}.carry_acc{i}_ge{t}.weight"] = torch.tensor([1.0] * n)
|
| 5673 |
+
tensors[f"{prefix}.carry_acc{i}_ge{t}.bias"] = torch.tensor([-float(t) + 0.5])
|
| 5674 |
+
# NOT gates for even thresholds
|
| 5675 |
+
for t in range(2, n + 1, 2):
|
| 5676 |
+
tensors[f"{prefix}.carry_acc{i}_not_ge{t}.weight"] = torch.tensor([-1.0])
|
| 5677 |
+
tensors[f"{prefix}.carry_acc{i}_not_ge{t}.bias"] = torch.tensor([0.0])
|
| 5678 |
+
# AND gates for odd ranges: (ge1 AND NOT ge2), (ge3 AND NOT ge4), ...
|
| 5679 |
+
odd_ranges = []
|
| 5680 |
+
for t in range(1, n + 1, 2):
|
| 5681 |
+
if t + 1 <= n:
|
| 5682 |
+
tensors[f"{prefix}.carry_acc{i}_odd{t}.weight"] = torch.tensor([1.0, 1.0])
|
| 5683 |
+
tensors[f"{prefix}.carry_acc{i}_odd{t}.bias"] = torch.tensor([-2.0])
|
| 5684 |
+
else:
|
| 5685 |
+
tensors[f"{prefix}.carry_acc{i}_odd{t}.weight"] = torch.tensor([1.0])
|
| 5686 |
+
tensors[f"{prefix}.carry_acc{i}_odd{t}.bias"] = torch.tensor([-0.5])
|
| 5687 |
+
odd_ranges.append(t)
|
| 5688 |
+
# carry_acc_sum = OR of odd ranges
|
| 5689 |
+
tensors[f"{prefix}.carry_acc{i}_sum.weight"] = torch.tensor([1.0] * len(odd_ranges))
|
| 5690 |
+
tensors[f"{prefix}.carry_acc{i}_sum.bias"] = torch.tensor([-0.5])
|
| 5691 |
+
# carry_acc_carry = ge2 (sum >= 2)
|
| 5692 |
+
if n >= 2:
|
| 5693 |
+
tensors[f"{prefix}.carry_acc{i}_carry.weight"] = torch.tensor([1.0] * n)
|
| 5694 |
+
tensors[f"{prefix}.carry_acc{i}_carry.bias"] = torch.tensor([-1.5])
|
| 5695 |
+
|
| 5696 |
# Final product assembly using ripple carry
|
| 5697 |
for i in range(22):
|
| 5698 |
p = f"{prefix}.prod_fa{i}"
|