PortfolioAI committed on
Commit
32f7c0d
·
1 Parent(s): fd6cf07

Add multi-bit carry infrastructure for float16.mul/div

- Add col_bit2 (floor/4 mod 2) and col_bit3 (floor/8 mod 2) gates
- Add carry accumulator gates for positions receiving multiple carries
- Update TODO.md with detailed remaining work documentation
- Move completed float16 circuits to Completed section

Mul/div still failing due to carry_acc_carry propagation issue.
Proper fix requires Wallace/Dadda tree or secondary carry chain.
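
The carry fan-in behind this note can be tabulated with a short sketch. It assumes the 11x11 column layout described in the TODO.md diff (columns 0 to 20, each emitting `bit1`, `bit2`, `bit3` only when it has at least 2, 4, or 8 partial products). The helper names `pp_count` and `incoming_carries` are illustrative and not taken from the repository.

```python
from typing import List


def pp_count(col: int, bits: int = 11) -> int:
    """Number of partial products in column `col` of a bits x bits multiplier."""
    if col < 0 or col > 2 * bits - 2:
        return 0
    return min(col + 1, 2 * bits - 1 - col)


def incoming_carries(pos: int, bits: int = 11) -> List[str]:
    """Names of the carry bits that land on product position `pos`.

    Column c contributes bit k (weight 2**k) to position c + k, but only
    if the column can reach a sum of 2**k, i.e. it has >= 2**k partial products.
    """
    carries = []
    for k in (1, 2, 3):
        if pos >= k and pp_count(pos - k, bits) >= 2 ** k:
            carries.append(f"col{pos - k}_bit{k}")
    return carries


# Fan-in per product position (0..21) and the positions that need an accumulator.
fan_in = {pos: len(incoming_carries(pos)) for pos in range(22)}
needs_accumulator = [pos for pos, n in fan_in.items() if n >= 2]

if __name__ == "__main__":
    for pos in range(22):
        print(pos, incoming_carries(pos))
```

Under these assumptions every position with two or more incoming carries can see an accumulated value of 2 or more, which is exactly the case where a one-bit accumulator output is not enough.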

Files changed (3)
  1. TODO.md +73 -21
  2. arithmetic.safetensors +2 -2
  3. convert_to_explicit_inputs.py +356 -34
TODO.md CHANGED
@@ -2,23 +2,60 @@
2
 
3
  ## High Priority
4
 
5
- ### Floating Point Circuits
6
- - [x] `float16.unpack` -- extract sign, exponent, mantissa (16 gates, 63/63 tests)
7
- - [x] `float16.pack` -- assemble from components (16 gates, 63/63 tests)
8
- - [x] `float16.cmp` -- comparison a > b (14 gates, 113/113 tests)
9
- - [x] `float16.normalize` -- CLZ-based shift calculator (51 gates, 14/14 tests)
10
- - [x] `float16.add` -- IEEE 754 addition (~998 gates, 125/125 tests)
11
- - [x] `float16.sub` -- IEEE 754 subtraction (via add with -b, 115/115 tests)
12
- - [ ] `float16.mul` -- IEEE 754 multiplication (766 gates, 13/84 tests, col_sum precision)
13
- - [ ] `float16.div` -- IEEE 754 division (1854 gates, 5/53 tests, col_sum precision)
14
- - [x] `float16.toint` -- float16 to int16 (401 gates, 93/93 tests)
15
- - [x] `float16.fromint` -- int16 to float16 (478 gates, 53/53 tests)
16
- - [x] `float16.neg` -- sign flip (16 gates, 58/58 tests)
17
- - [x] `float16.abs` -- clear sign bit (16 gates, 58/58 tests)
18
 
19
- ### Supporting Infrastructure
20
- - [x] `arithmetic.clz8bit` -- 8-bit count leading zeros (30 gates, 256/256 tests)
21
- - [x] `arithmetic.clz16bit` -- 16-bit count leading zeros (63 gates, 217/217 tests)
22
 
23
  ## Medium Priority
24
 
@@ -30,11 +67,6 @@
30
  - [ ] `arithmetic.gcd8bit` -- greatest common divisor
31
  - [ ] `arithmetic.lcm8bit` -- least common multiple
32
 
33
- ### Evaluator Improvements
34
- - [x] Full circuit evaluation using .inputs topology
35
- - [x] Exhaustive testing for boolean, threshold, CLZ, float16, comparator circuits
36
- - [x] Automatic topological sort from signal registry
37
-
38
  ## Low Priority
39
 
40
  ### Transcendental Approximations
@@ -56,6 +88,26 @@
56
 
57
  ## Completed
58
 
59
  - [x] Boolean gates (AND, OR, NOT, NAND, NOR, XOR, XNOR, IMPLIES, BIIMPLIES)
60
  - [x] Arithmetic adders (half, full, ripple carry 2/4/8 bit)
61
  - [x] Arithmetic subtraction (SUB, SBC, NEG)
 
2
 
3
  ## High Priority
4
 
5
+ ### Floating Point Circuits - Remaining Work
6
 
7
+ #### `float16.mul` -- IEEE 754 multiplication (~800 gates, ~55/84 tests)
8
+
9
+ **Problem**: Multi-bit carry propagation in 11x11 mantissa multiplier.
10
+
11
+ **Background**: The mantissa multiplier produces a 22-bit product from two 11-bit mantissas (including implicit leading 1). Each column `i` has `min(i+1, 21-i)` partial products (AND gates). Column 10 has the maximum of 11 partial products.
12
+
13
+ **Current Implementation**:
14
+ - Column sums computed via threshold gates: `col_sum = parity(PP_0, PP_1, ..., PP_n)`
15
+ - Parity computed as `(ge1 AND NOT ge2) OR (ge3 AND NOT ge4) OR ...`
16
+ - `col_bit1` = floor(sum/2) mod 2 (carry to next position)
17
+ - `col_bit2` = floor(sum/4) mod 2 (carry to position i+2)
18
+ - `col_bit3` = floor(sum/8) mod 2 (carry to position i+3)
19
+ - Carry accumulator gates sum incoming carries from multiple columns
20
+
21
+ **Remaining Issue**: The carry accumulator can itself produce a carry (`carry_acc_carry`) when the sum of incoming carry bits is >= 2. This secondary carry needs to propagate to position i+1, creating a cascading effect that requires either:
22
+ 1. A proper CSA (Carry Save Adder) tree structure, or
23
+ 2. A secondary FA chain for accumulated carries, or
24
+ 3. Iterating until carry stabilization
25
+
26
+ **Files**: `convert_to_explicit_inputs.py` lines 5350-5650 (build), lines 2400-2700 (infer)
27
+
28
+ ---
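
The remaining issue described above can be reproduced with a small integer model. This is a simplified sketch, not the repository's gate-level code: it assumes the final adder sums the column parities (`col_sum`) with the parity of the carries arriving at each position, and it compares dropping the accumulator carry against feeding it into position i+1 as a third operand. All function names are hypothetical.

```python
from typing import List


def column_sums(x: int, y: int, bits: int = 11) -> List[int]:
    """Popcount of the partial products in each column of x * y."""
    sums = [0] * (2 * bits)
    for i in range(bits):
        for j in range(bits):
            sums[i + j] += ((x >> i) & 1) & ((y >> j) & 1)
    return sums


def _operands(x: int, y: int, bits: int):
    """Return (a, b, c): column parities, accumulator parities, accumulator carries."""
    width = 2 * bits
    sums = column_sums(x, y, bits)
    incoming = [0] * width
    for col, s in enumerate(sums):
        for k in (1, 2, 3):  # bit k of a column sum lands on position col + k
            if col + k < width:
                incoming[col + k] += (s >> k) & 1
    a = sum((s & 1) << i for i, s in enumerate(sums))
    b = sum((n & 1) << i for i, n in enumerate(incoming))
    # At most three carries arrive, so the accumulator carry is a single bit
    # with weight 2**(i + 1).
    c = sum((n >> 1) << (i + 1) for i, n in enumerate(incoming))
    return a, b, c


def product_dropping_acc_carry(x: int, y: int, bits: int = 11) -> int:
    """Model of the failing case: the accumulator carry is ignored."""
    a, b, _ = _operands(x, y, bits)
    return (a + b) % (1 << (2 * bits))


def product_with_acc_carry(x: int, y: int, bits: int = 11) -> int:
    """Model of the fix: the accumulator carry is added as a third operand."""
    a, b, c = _operands(x, y, bits)
    return (a + b + c) % (1 << (2 * bits))
```

In this model, small operands such as 3 x 3 still multiply correctly because no position receives two carries at once, while 2047 x 2047 comes out too small when the accumulator carry is dropped. Adding the third operand is the "secondary FA chain" option from the list above.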
29
+
30
+ #### `float16.div` -- IEEE 754 division (~1900 gates, ~5/53 tests)
31
+
32
+ **Problem**: Same multi-bit carry issue as multiplication, plus potential issues in the non-restoring division algorithm.
33
+
34
+ **Background**: Division uses non-restoring algorithm with 11-bit dividend and divisor. The quotient mantissa is computed iteratively, and similar column reduction issues arise.
35
+
36
+ **Current Implementation**:
37
+ - NaN output bit 9 fixed (canonical NaN = 0x7E00)
38
+ - Column sum parity gates similar to multiplication
39
+
40
+ **Remaining Issues**:
41
+ 1. Same multi-bit carry propagation problem as multiplication
42
+ 2. May have additional issues in division-specific logic (partial remainder computation)
43
+
44
+ **Files**: `convert_to_explicit_inputs.py` lines 5700-6200 (build), lines 2700-3100 (infer)
45
+
46
+ ---
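
When debugging the division-specific logic, a plain integer reference for non-restoring division may help isolate whether a failure comes from the partial remainder computation or from carry handling. The sketch below is a textbook formulation under the assumption of unsigned operands that fit in `bits` bits. It is not the circuit's implementation, and the function name is illustrative.

```python
from typing import Tuple


def nonrestoring_divide(dividend: int, divisor: int, bits: int = 11) -> Tuple[int, int]:
    """Unsigned non-restoring division; returns (quotient, remainder).

    Each step shifts in one dividend bit, then subtracts the divisor if the
    partial remainder is non-negative or adds it back if it is negative.
    The quotient bit is 1 when the new partial remainder is non-negative.
    """
    if divisor <= 0:
        raise ValueError("divisor must be positive")
    remainder = 0
    quotient = 0
    for i in range(bits - 1, -1, -1):
        bit = (dividend >> i) & 1
        if remainder >= 0:
            remainder = 2 * remainder + bit - divisor
        else:
            remainder = 2 * remainder + bit + divisor
        quotient = (quotient << 1) | (1 if remainder >= 0 else 0)
    # Final correction: a negative remainder needs one restoring addition.
    if remainder < 0:
        remainder += divisor
    return quotient, remainder
```

Comparing each iteration's partial remainder against this reference is one way to separate the two remaining issues listed for `float16.div`.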
47
+
48
+ ### Potential Solutions for Carry Propagation
49
+
50
+ 1. **Wallace Tree**: Replace column reduction with proper Wallace tree structure. More gates but handles arbitrary partial product counts correctly.
51
+
52
+ 2. **Dadda Tree**: Similar to Wallace but performs only the minimum reduction needed at each stage, which typically requires fewer adder gates.
53
+
54
+ 3. **Iterative Carry Resolution**: After initial FA chain, detect remaining carries and iterate until stable. Simple but slow.
55
+
56
+ 4. **Hybrid Approach**: Use threshold gates for small columns (2-3 PPs) and proper tree reduction for larger columns.
57
+
58
+ ---
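
As a rough illustration of the tree-reduction options above, the sketch below reduces the partial-product columns with 3:2 compressors (full adders) until every column holds at most two bits, then finishes with a ripple-carry add. It is a Wallace-style behavioural model in plain Python, assuming 11-bit operands. It does not use the repository's threshold-gate primitives, and every name in it is hypothetical.

```python
from typing import List, Tuple


def full_adder(a: int, b: int, c: int) -> Tuple[int, int]:
    """3:2 compressor: returns (sum bit, carry bit)."""
    return a ^ b ^ c, (a & b) | (a & c) | (b & c)


def multiply_by_column_reduction(x: int, y: int, bits: int = 11) -> int:
    ncols = 2 * bits
    cols: List[List[int]] = [[] for _ in range(ncols)]
    for i in range(bits):
        for j in range(bits):
            cols[i + j].append(((x >> i) & 1) & ((y >> j) & 1))

    # Reduction stages: every group of three bits in a column becomes one sum
    # bit in the same column and one carry bit in the next column.
    while max(len(col) for col in cols) > 2:
        nxt: List[List[int]] = [[] for _ in range(ncols)]
        for k, col in enumerate(cols):
            idx = 0
            while len(col) - idx >= 3:
                s, c = full_adder(col[idx], col[idx + 1], col[idx + 2])
                nxt[k].append(s)
                if k + 1 < ncols:
                    nxt[k + 1].append(c)
                idx += 3
            nxt[k].extend(col[idx:])  # leftover bits pass through unchanged
        cols = nxt

    # Final carry-propagate add over the two remaining rows.
    result, carry = 0, 0
    for k, col in enumerate(cols):
        a = col[0] if len(col) > 0 else 0
        b = col[1] if len(col) > 1 else 0
        s, carry = full_adder(a, b, carry)
        result |= s << k
    return result
```

Because every carry is an ordinary bit placed in the next column, no position ever has to represent a multi-bit carry, which is the property the threshold-based column counters lack.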
59
 
60
  ## Medium Priority
61
 
 
67
  - [ ] `arithmetic.gcd8bit` -- greatest common divisor
68
  - [ ] `arithmetic.lcm8bit` -- least common multiple
69
 
 
 
 
 
 
70
  ## Low Priority
71
 
72
  ### Transcendental Approximations
 
88
 
89
  ## Completed
90
 
91
+ ### Floating Point Circuits
92
+ - [x] `float16.unpack` -- extract sign, exponent, mantissa (16 gates, 63/63 tests)
93
+ - [x] `float16.pack` -- assemble from components (16 gates, 63/63 tests)
94
+ - [x] `float16.cmp` -- comparison a > b (14 gates, 113/113 tests)
95
+ - [x] `float16.normalize` -- CLZ-based shift calculator (51 gates, 14/14 tests)
96
+ - [x] `float16.add` -- IEEE 754 addition (~998 gates, 125/125 tests)
97
+ - [x] `float16.sub` -- IEEE 754 subtraction (via add with -b, 115/115 tests)
98
+ - [x] `float16.toint` -- float16 to int16 (401 gates, 93/93 tests)
99
+ - [x] `float16.fromint` -- int16 to float16 (478 gates, 53/53 tests)
100
+ - [x] `float16.neg` -- sign flip (16 gates, 58/58 tests)
101
+ - [x] `float16.abs` -- clear sign bit (16 gates, 58/58 tests)
102
+
103
+ ### Supporting Infrastructure
104
+ - [x] `arithmetic.clz8bit` -- 8-bit count leading zeros (30 gates, 256/256 tests)
105
+ - [x] `arithmetic.clz16bit` -- 16-bit count leading zeros (63 gates, 217/217 tests)
106
+ - [x] Full circuit evaluation using .inputs topology
107
+ - [x] Exhaustive testing for boolean, threshold, CLZ, float16, comparator circuits
108
+ - [x] Automatic topological sort from signal registry
109
+
110
+ ### Core Circuits
111
  - [x] Boolean gates (AND, OR, NOT, NAND, NOR, XOR, XNOR, IMPLIES, BIIMPLIES)
112
  - [x] Arithmetic adders (half, full, ripple carry 2/4/8 bit)
113
  - [x] Arithmetic subtraction (SUB, SBC, NEG)
arithmetic.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3bb8fad90726b27a8bd9502c5cb4154242a5f8a6d046c4ba69470be55bb98624
3
- size 2865388
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce88f97552b5471d1c2adb4a88ea64ddaf8e05537884a12624a455da32531910
3
+ size 2980552
convert_to_explicit_inputs.py CHANGED
@@ -2410,28 +2410,192 @@ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
2410
  if 0 <= j < 11:
2411
  pps.append(f"{prefix}.pp{i}_{j}")
2412
 
2413
- if len(pps) == 1:
 
2414
  if f'.col{col}' in gate and f'.col{col}_' not in gate:
2415
  return [registry.get_id(pps[0])]
2416
  registry.register(f"{prefix}.col{col}")
2417
- elif len(pps) > 1:
2418
- if f'.col{col}_sum' in gate:
2419
- return [registry.get_id(pp) for pp in pps]
2420
- registry.register(f"{prefix}.col{col}_sum")
2421
-
2422
  match = re.search(rf'\.col{col}_ge(\d+)$', gate)
2423
  if match:
2424
  return [registry.get_id(pp) for pp in pps]
2425
 
2426
- for t in range(1, len(pps)):
2427
  registry.register(f"{prefix}.col{col}_ge{t}")
2428
 
2429
  if '.prod_fa' in gate:
2430
  match = re.search(r'\.prod_fa(\d+)\.', gate)
2431
  if match:
2432
  i = int(match.group(1))
2433
  fa_prefix = f"{prefix}.prod_fa{i}"
2434
 
 
 
 
 
 
2435
  # Count partial products in each column to determine signal names
2436
  # col 0 and col 20 have 1 PP each, others have more
2437
  def get_col_sum(col):
@@ -2441,26 +2605,43 @@ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
2441
  return registry.get_id(f"{prefix}.col{col}_sum")
2442
  return registry.get_id("#0")
2443
 
2444
- def get_col_carry(col):
2445
- # Carry from column = sum >= 2, which is ge2
2446
- # For columns with count >= 3, ge2 exists
2447
- # For columns with count == 2, no ge2 (but carry is rare anyway)
2448
- # For single-bit columns, no carry possible
2449
- if col == 0 or col == 20:
2450
- return registry.get_id("#0") # No carry from single PP columns
2451
- elif col < 21:
2452
- # ge2 exists for columns with 3+ PPs
2453
- # For 2-PP columns (col 1 and col 19), ge2 doesn't exist
2454
- # but those columns can only produce carry if both PPs are 1,
2455
- # which is relatively rare. For now, we use #0 for 2-PP columns.
2456
- ge2_name = f"{prefix}.col{col}_ge2"
2457
- ge2_id = registry.get_id(ge2_name)
2458
- if ge2_id != -1:
2459
- return ge2_id
 
 
2460
  else:
2461
- # 2-PP columns: no ge2, return 0 (imprecise but safe)
2462
- return registry.get_id("#0")
2463
- return registry.get_id("#0")
2464
 
2465
  if i == 0:
2466
  a_bit = get_col_sum(0)
@@ -2468,7 +2649,7 @@ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
2468
  cin = registry.get_id("#0")
2469
  else:
2470
  a_bit = get_col_sum(i) if i < 21 else registry.get_id("#0")
2471
- b_bit = get_col_carry(i - 1) if i < 22 else registry.get_id("#0")
2472
  cin = registry.register(f"{prefix}.prod_fa{i-1}.cout")
2473
 
2474
  if '.xor1.layer1' in gate:
@@ -5360,17 +5541,158 @@ def build_float16_mul_tensors() -> Dict[str, torch.Tensor]:
5360
  tensors[f"{prefix}.col{col}.weight"] = torch.tensor([1.0])
5361
  tensors[f"{prefix}.col{col}.bias"] = torch.tensor([-0.5])
5362
  else:
5363
- # Multi-bit column: use threshold gate to count
5364
- # For proper addition we need full adder trees
5365
- # Simplified: use weighted sum approach
5366
- tensors[f"{prefix}.col{col}_sum.weight"] = torch.tensor([1.0] * count)
5367
- tensors[f"{prefix}.col{col}_sum.bias"] = torch.tensor([-0.5])
5368
 
5369
- # Carry generation for each threshold level
5370
- for t in range(1, count):
5371
  tensors[f"{prefix}.col{col}_ge{t}.weight"] = torch.tensor([1.0] * count)
5372
  tensors[f"{prefix}.col{col}_ge{t}.bias"] = torch.tensor([-float(t)])
5373
 
5374
  # Final product assembly using ripple carry
5375
  for i in range(22):
5376
  p = f"{prefix}.prod_fa{i}"
 
2410
  if 0 <= j < 11:
2411
  pps.append(f"{prefix}.pp{i}_{j}")
2412
 
2413
+ count = len(pps)
2414
+ if count == 1:
2415
  if f'.col{col}' in gate and f'.col{col}_' not in gate:
2416
  return [registry.get_id(pps[0])]
2417
  registry.register(f"{prefix}.col{col}")
2418
+ elif count > 1:
2419
+ # ge{t} gates: threshold >= t
 
 
 
2420
  match = re.search(rf'\.col{col}_ge(\d+)$', gate)
2421
  if match:
2422
  return [registry.get_id(pp) for pp in pps]
2423
 
2424
+ for t in range(1, count + 1):
2425
  registry.register(f"{prefix}.col{col}_ge{t}")
2426
 
2427
+ # not_ge{t} for even t
2428
+ match = re.search(rf'\.col{col}_not_ge(\d+)$', gate)
2429
+ if match:
2430
+ t = int(match.group(1))
2431
+ return [registry.get_id(f"{prefix}.col{col}_ge{t}")]
2432
+
2433
+ for t in range(2, count + 1, 2):
2434
+ registry.register(f"{prefix}.col{col}_not_ge{t}")
2435
+
2436
+ # odd{t} gates: ge{t} AND (NOT ge{t+1} or just ge{t} if t+1 > count)
2437
+ match = re.search(rf'\.col{col}_odd(\d+)$', gate)
2438
+ if match:
2439
+ t = int(match.group(1))
2440
+ if t + 1 <= count:
2441
+ return [registry.get_id(f"{prefix}.col{col}_ge{t}"),
2442
+ registry.get_id(f"{prefix}.col{col}_not_ge{t+1}")]
2443
+ else:
2444
+ return [registry.get_id(f"{prefix}.col{col}_ge{t}")]
2445
+
2446
+ odd_ranges = []
2447
+ for t in range(1, count + 1, 2):
2448
+ registry.register(f"{prefix}.col{col}_odd{t}")
2449
+ odd_ranges.append(f"{prefix}.col{col}_odd{t}")
2450
+
2451
+ # col_sum = OR of all odd gates (parity)
2452
+ if f'.col{col}_sum' in gate:
2453
+ return [registry.get_id(r) for r in odd_ranges]
2454
+ registry.register(f"{prefix}.col{col}_sum")
2455
+
2456
+ # col_bit1 gates (floor(sum/2) mod 2)
2457
+ if count >= 2:
2458
+ match = re.search(rf'\.col{col}_bit1_(\d+)$', gate)
2459
+ if match:
2460
+ t = int(match.group(1))
2461
+ upper = t + 2
2462
+ if upper <= count:
2463
+ return [registry.get_id(f"{prefix}.col{col}_ge{t}"),
2464
+ registry.get_id(f"{prefix}.col{col}_not_ge{upper}")]
2465
+ else:
2466
+ return [registry.get_id(f"{prefix}.col{col}_ge{t}")]
2467
+
2468
+ bit1_ranges = []
2469
+ for t in range(2, count + 1, 4):
2470
+ registry.register(f"{prefix}.col{col}_bit1_{t}")
2471
+ bit1_ranges.append(f"{prefix}.col{col}_bit1_{t}")
2472
+
2473
+ if f'.col{col}_bit1' in gate and f'.col{col}_bit1_' not in gate:
2474
+ return [registry.get_id(r) for r in bit1_ranges]
2475
+ if bit1_ranges:
2476
+ registry.register(f"{prefix}.col{col}_bit1")
2477
+
2478
+ # col_bit2 gates (floor(sum/4) mod 2)
2479
+ if count >= 4:
2480
+ match = re.search(rf'\.col{col}_bit2_(\d+)$', gate)
2481
+ if match:
2482
+ t = int(match.group(1))
2483
+ upper = t + 4
2484
+ if upper <= count:
2485
+ return [registry.get_id(f"{prefix}.col{col}_ge{t}"),
2486
+ registry.get_id(f"{prefix}.col{col}_not_ge{upper}")]
2487
+ else:
2488
+ return [registry.get_id(f"{prefix}.col{col}_ge{t}")]
2489
+
2490
+ bit2_ranges = []
2491
+ for t in range(4, count + 1, 8):
2492
+ registry.register(f"{prefix}.col{col}_bit2_{t}")
2493
+ bit2_ranges.append(f"{prefix}.col{col}_bit2_{t}")
2494
+
2495
+ if f'.col{col}_bit2' in gate and f'.col{col}_bit2_' not in gate:
2496
+ return [registry.get_id(r) for r in bit2_ranges]
2497
+ if bit2_ranges:
2498
+ registry.register(f"{prefix}.col{col}_bit2")
2499
+
2500
+ # col_bit3 gates (floor(sum/8) mod 2)
2501
+ if count >= 8:
2502
+ match = re.search(rf'\.col{col}_bit3_(\d+)$', gate)
2503
+ if match:
2504
+ t = int(match.group(1))
2505
+ upper = t + 8
2506
+ if upper <= count:
2507
+ return [registry.get_id(f"{prefix}.col{col}_ge{t}"),
2508
+ registry.get_id(f"{prefix}.col{col}_not_ge{upper}")]
2509
+ else:
2510
+ return [registry.get_id(f"{prefix}.col{col}_ge{t}")]
2511
+
2512
+ bit3_ranges = []
2513
+ for t in range(8, count + 1, 16):
2514
+ registry.register(f"{prefix}.col{col}_bit3_{t}")
2515
+ bit3_ranges.append(f"{prefix}.col{col}_bit3_{t}")
2516
+
2517
+ if f'.col{col}_bit3' in gate and f'.col{col}_bit3_' not in gate:
2518
+ return [registry.get_id(r) for r in bit3_ranges]
2519
+ if bit3_ranges:
2520
+ registry.register(f"{prefix}.col{col}_bit3")
2521
+
2522
+ # Handle carry accumulator gates
2523
+ if '.carry_acc' in gate:
2524
+ match = re.search(r'\.carry_acc(\d+)_', gate)
2525
+ if match:
2526
+ i = int(match.group(1))
2527
+
2528
+ def get_pp_count(col):
2529
+ if col < 0 or col > 20:
2530
+ return 0
2531
+ return min(col + 1, 21 - col)
2532
+
2533
+ # Determine which carry bits come into position i
2534
+ carry_inputs = []
2535
+ if i >= 1 and get_pp_count(i-1) >= 2:
2536
+ carry_inputs.append(registry.get_id(f"{prefix}.col{i-1}_bit1"))
2537
+ if i >= 2 and get_pp_count(i-2) >= 4:
2538
+ carry_inputs.append(registry.get_id(f"{prefix}.col{i-2}_bit2"))
2539
+ if i >= 3 and get_pp_count(i-3) >= 8:
2540
+ carry_inputs.append(registry.get_id(f"{prefix}.col{i-3}_bit3"))
2541
+
2542
+ n = len(carry_inputs)
2543
+
2544
+ # ge{t} gates
2545
+ match_ge = re.search(rf'\.carry_acc{i}_ge(\d+)$', gate)
2546
+ if match_ge:
2547
+ return carry_inputs
2548
+
2549
+ # not_ge{t} gates
2550
+ match_not = re.search(rf'\.carry_acc{i}_not_ge(\d+)$', gate)
2551
+ if match_not:
2552
+ t = int(match_not.group(1))
2553
+ return [registry.get_id(f"{prefix}.carry_acc{i}_ge{t}")]
2554
+
2555
+ # Register ge gates
2556
+ for t in range(1, n + 1):
2557
+ registry.register(f"{prefix}.carry_acc{i}_ge{t}")
2558
+ for t in range(2, n + 1, 2):
2559
+ registry.register(f"{prefix}.carry_acc{i}_not_ge{t}")
2560
+
2561
+ # odd{t} gates
2562
+ match_odd = re.search(rf'\.carry_acc{i}_odd(\d+)$', gate)
2563
+ if match_odd:
2564
+ t = int(match_odd.group(1))
2565
+ if t + 1 <= n:
2566
+ return [registry.get_id(f"{prefix}.carry_acc{i}_ge{t}"),
2567
+ registry.get_id(f"{prefix}.carry_acc{i}_not_ge{t+1}")]
2568
+ else:
2569
+ return [registry.get_id(f"{prefix}.carry_acc{i}_ge{t}")]
2570
+
2571
+ # Register odd gates
2572
+ odd_ranges = []
2573
+ for t in range(1, n + 1, 2):
2574
+ registry.register(f"{prefix}.carry_acc{i}_odd{t}")
2575
+ odd_ranges.append(f"{prefix}.carry_acc{i}_odd{t}")
2576
+
2577
+ # carry_acc_sum = OR of odd gates
2578
+ if f'.carry_acc{i}_sum' in gate:
2579
+ return [registry.get_id(r) for r in odd_ranges]
2580
+ registry.register(f"{prefix}.carry_acc{i}_sum")
2581
+
2582
+ # carry_acc_carry = ge2
2583
+ if f'.carry_acc{i}_carry' in gate:
2584
+ return carry_inputs
2585
+ if n >= 2:
2586
+ registry.register(f"{prefix}.carry_acc{i}_carry")
2587
+
2588
  if '.prod_fa' in gate:
2589
  match = re.search(r'\.prod_fa(\d+)\.', gate)
2590
  if match:
2591
  i = int(match.group(1))
2592
  fa_prefix = f"{prefix}.prod_fa{i}"
2593
 
2594
+ def get_pp_count(col):
2595
+ if col < 0 or col > 20:
2596
+ return 0
2597
+ return min(col + 1, 21 - col)
2598
+
2599
  # Count partial products in each column to determine signal names
2600
  # col 0 and col 20 have 1 PP each, others have more
2601
  def get_col_sum(col):
 
2605
  return registry.get_id(f"{prefix}.col{col}_sum")
2606
  return registry.get_id("#0")
2607
 
2608
+ def get_b_bit(pos):
2609
+ # Determine incoming carries for position pos
2610
+ carry_inputs = []
2611
+ if pos >= 1 and get_pp_count(pos-1) >= 2:
2612
+ carry_inputs.append("bit1")
2613
+ if pos >= 2 and get_pp_count(pos-2) >= 4:
2614
+ carry_inputs.append("bit2")
2615
+ if pos >= 3 and get_pp_count(pos-3) >= 8:
2616
+ carry_inputs.append("bit3")
2617
+
2618
+ if len(carry_inputs) == 0:
2619
+ return registry.get_id("#0")
2620
+ elif len(carry_inputs) == 1:
2621
+ # Single carry, use it directly
2622
+ if carry_inputs[0] == "bit1":
2623
+ return registry.get_id(f"{prefix}.col{pos-1}_bit1")
2624
+ elif carry_inputs[0] == "bit2":
2625
+ return registry.get_id(f"{prefix}.col{pos-2}_bit2")
2626
  else:
2627
+ return registry.get_id(f"{prefix}.col{pos-3}_bit3")
2628
+ else:
2629
+ # Multiple carries, use accumulator sum
2630
+ return registry.register(f"{prefix}.carry_acc{pos}_sum")
2631
+
2632
+ def get_extra_cin(pos):
2633
+ # Extra carry from accumulator (when sum of carries >= 2)
2634
+ carry_inputs = []
2635
+ if pos >= 1 and get_pp_count(pos-1) >= 2:
2636
+ carry_inputs.append("bit1")
2637
+ if pos >= 2 and get_pp_count(pos-2) >= 4:
2638
+ carry_inputs.append("bit2")
2639
+ if pos >= 3 and get_pp_count(pos-3) >= 8:
2640
+ carry_inputs.append("bit3")
2641
+
2642
+ if len(carry_inputs) >= 2:
2643
+ return registry.register(f"{prefix}.carry_acc{pos}_carry")
2644
+ return None
2645
 
2646
  if i == 0:
2647
  a_bit = get_col_sum(0)
 
2649
  cin = registry.get_id("#0")
2650
  else:
2651
  a_bit = get_col_sum(i) if i < 21 else registry.get_id("#0")
2652
+ b_bit = get_b_bit(i)
2653
  cin = registry.register(f"{prefix}.prod_fa{i-1}.cout")
2654
 
2655
  if '.xor1.layer1' in gate:
 
5541
  tensors[f"{prefix}.col{col}.weight"] = torch.tensor([1.0])
5542
  tensors[f"{prefix}.col{col}.bias"] = torch.tensor([-0.5])
5543
  else:
5544
+ # Multi-bit column: compute parity (sum mod 2) using threshold gates
5545
+ # parity = (ge1 AND NOT ge2) OR (ge3 AND NOT ge4) OR ...
5546
+ # This captures: sum is odd when it lies in [1,2), [3,4), [5,6), etc.
 
 
5547
 
5548
+ # Threshold gates: ge{t} = 1 if sum >= t
5549
+ for t in range(1, count + 1):
5550
  tensors[f"{prefix}.col{col}_ge{t}.weight"] = torch.tensor([1.0] * count)
5551
  tensors[f"{prefix}.col{col}_ge{t}.bias"] = torch.tensor([-float(t)])
5552
 
5553
+ # NOT gates for even thresholds
5554
+ for t in range(2, count + 1, 2):
5555
+ tensors[f"{prefix}.col{col}_not_ge{t}.weight"] = torch.tensor([-1.0])
5556
+ tensors[f"{prefix}.col{col}_not_ge{t}.bias"] = torch.tensor([0.0])
5557
+
5558
+ # AND gates for odd ranges: (ge1 AND NOT ge2), (ge3 AND NOT ge4), ...
5559
+ odd_ranges = []
5560
+ for t in range(1, count + 1, 2):
5561
+ if t + 1 <= count:
5562
+ # ge{t} AND NOT ge{t+1}
5563
+ tensors[f"{prefix}.col{col}_odd{t}.weight"] = torch.tensor([1.0, 1.0])
5564
+ tensors[f"{prefix}.col{col}_odd{t}.bias"] = torch.tensor([-2.0])
5565
+ odd_ranges.append(t)
5566
+ else:
5567
+ # ge{t} only (no upper bound needed if t is max)
5568
+ tensors[f"{prefix}.col{col}_odd{t}.weight"] = torch.tensor([1.0])
5569
+ tensors[f"{prefix}.col{col}_odd{t}.bias"] = torch.tensor([-0.5])
5570
+ odd_ranges.append(t)
5571
+
5572
+ # col_sum = OR of all odd ranges (parity = bit 0)
5573
+ num_odd = len(odd_ranges)
5574
+ tensors[f"{prefix}.col{col}_sum.weight"] = torch.tensor([1.0] * num_odd)
5575
+ tensors[f"{prefix}.col{col}_sum.bias"] = torch.tensor([-0.5])
5576
+
5577
+ # col_bit1 = floor(sum/2) mod 2 = 1 when sum is in [2,3], [6,7], [10,11], ...
5578
+ # This is (ge2 AND NOT ge4) OR (ge6 AND NOT ge8) OR ...
5579
+ if count >= 2:
5580
+ bit1_ranges = []
5581
+ for t in range(2, count + 1, 4):
5582
+ upper = t + 2
5583
+ if upper <= count:
5584
+ tensors[f"{prefix}.col{col}_bit1_{t}.weight"] = torch.tensor([1.0, 1.0])
5585
+ tensors[f"{prefix}.col{col}_bit1_{t}.bias"] = torch.tensor([-2.0])
5586
+ if f"{prefix}.col{col}_not_ge{upper}.weight" not in tensors:
5587
+ tensors[f"{prefix}.col{col}_not_ge{upper}.weight"] = torch.tensor([-1.0])
5588
+ tensors[f"{prefix}.col{col}_not_ge{upper}.bias"] = torch.tensor([0.0])
5589
+ else:
5590
+ tensors[f"{prefix}.col{col}_bit1_{t}.weight"] = torch.tensor([1.0])
5591
+ tensors[f"{prefix}.col{col}_bit1_{t}.bias"] = torch.tensor([-0.5])
5592
+ bit1_ranges.append(t)
5593
+
5594
+ if bit1_ranges:
5595
+ tensors[f"{prefix}.col{col}_bit1.weight"] = torch.tensor([1.0] * len(bit1_ranges))
5596
+ tensors[f"{prefix}.col{col}_bit1.bias"] = torch.tensor([-0.5])
5597
+
5598
+ # col_bit2 = floor(sum/4) mod 2 = 1 when sum is in [4,7], [12,15], ...
5599
+ # This is (ge4 AND NOT ge8) OR (ge12 AND NOT ge16) OR ...
5600
+ if count >= 4:
5601
+ bit2_ranges = []
5602
+ for t in range(4, count + 1, 8):
5603
+ upper = t + 4
5604
+ if upper <= count:
5605
+ tensors[f"{prefix}.col{col}_bit2_{t}.weight"] = torch.tensor([1.0, 1.0])
5606
+ tensors[f"{prefix}.col{col}_bit2_{t}.bias"] = torch.tensor([-2.0])
5607
+ if f"{prefix}.col{col}_not_ge{upper}.weight" not in tensors:
5608
+ tensors[f"{prefix}.col{col}_not_ge{upper}.weight"] = torch.tensor([-1.0])
5609
+ tensors[f"{prefix}.col{col}_not_ge{upper}.bias"] = torch.tensor([0.0])
5610
+ else:
5611
+ tensors[f"{prefix}.col{col}_bit2_{t}.weight"] = torch.tensor([1.0])
5612
+ tensors[f"{prefix}.col{col}_bit2_{t}.bias"] = torch.tensor([-0.5])
5613
+ bit2_ranges.append(t)
5614
+
5615
+ if bit2_ranges:
5616
+ tensors[f"{prefix}.col{col}_bit2.weight"] = torch.tensor([1.0] * len(bit2_ranges))
5617
+ tensors[f"{prefix}.col{col}_bit2.bias"] = torch.tensor([-0.5])
5618
+
5619
+ # col_bit3 = floor(sum/8) mod 2 (for col10 with 11 PPs)
5620
+ if count >= 8:
5621
+ bit3_ranges = []
5622
+ for t in range(8, count + 1, 16):
5623
+ upper = t + 8
5624
+ if upper <= count:
5625
+ tensors[f"{prefix}.col{col}_bit3_{t}.weight"] = torch.tensor([1.0, 1.0])
5626
+ tensors[f"{prefix}.col{col}_bit3_{t}.bias"] = torch.tensor([-2.0])
5627
+ if f"{prefix}.col{col}_not_ge{upper}.weight" not in tensors:
5628
+ tensors[f"{prefix}.col{col}_not_ge{upper}.weight"] = torch.tensor([-1.0])
5629
+ tensors[f"{prefix}.col{col}_not_ge{upper}.bias"] = torch.tensor([0.0])
5630
+ else:
5631
+ tensors[f"{prefix}.col{col}_bit3_{t}.weight"] = torch.tensor([1.0])
5632
+ tensors[f"{prefix}.col{col}_bit3_{t}.bias"] = torch.tensor([-0.5])
5633
+ bit3_ranges.append(t)
5634
+
5635
+ if bit3_ranges:
5636
+ tensors[f"{prefix}.col{col}_bit3.weight"] = torch.tensor([1.0] * len(bit3_ranges))
5637
+ tensors[f"{prefix}.col{col}_bit3.bias"] = torch.tensor([-0.5])
5638
+
5639
+ # Carry accumulator for multi-bit carries
5640
+ # For position i, incoming carries are: bit1[i-1], bit2[i-2], bit3[i-3]
5641
+ # We need to sum these and produce: carry_acc_sum (parity), carry_acc_carry (sum >= 2)
5642
+ def get_pp_count(col):
5643
+ if col < 0 or col > 20:
5644
+ return 0
5645
+ return min(col + 1, 21 - col)
5646
+
5647
+ for i in range(22):
5648
+ # Determine which carry bits come into position i
5649
+ carry_inputs = []
5650
+ # bit1 from col[i-1]
5651
+ if i >= 1 and get_pp_count(i-1) >= 2:
5652
+ carry_inputs.append(f"bit1_{i-1}")
5653
+ # bit2 from col[i-2]
5654
+ if i >= 2 and get_pp_count(i-2) >= 4:
5655
+ carry_inputs.append(f"bit2_{i-2}")
5656
+ # bit3 from col[i-3]
5657
+ if i >= 3 and get_pp_count(i-3) >= 8:
5658
+ carry_inputs.append(f"bit3_{i-3}")
5659
+
5660
+ if len(carry_inputs) == 0:
5661
+ # No carries, use #0
5662
+ pass
5663
+ elif len(carry_inputs) == 1:
5664
+ # Single carry, no accumulator needed
5665
+ pass
5666
+ else:
5667
+ # Multiple carries, need accumulator
5668
+ n = len(carry_inputs)
5669
+ # Parity (sum mod 2) using threshold gates
5670
+ # ge{t} = sum >= t
5671
+ for t in range(1, n + 1):
5672
+ tensors[f"{prefix}.carry_acc{i}_ge{t}.weight"] = torch.tensor([1.0] * n)
5673
+ tensors[f"{prefix}.carry_acc{i}_ge{t}.bias"] = torch.tensor([-float(t) + 0.5])
5674
+ # NOT gates for even thresholds
5675
+ for t in range(2, n + 1, 2):
5676
+ tensors[f"{prefix}.carry_acc{i}_not_ge{t}.weight"] = torch.tensor([-1.0])
5677
+ tensors[f"{prefix}.carry_acc{i}_not_ge{t}.bias"] = torch.tensor([0.0])
5678
+ # AND gates for odd ranges: (ge1 AND NOT ge2), (ge3 AND NOT ge4), ...
5679
+ odd_ranges = []
5680
+ for t in range(1, n + 1, 2):
5681
+ if t + 1 <= n:
5682
+ tensors[f"{prefix}.carry_acc{i}_odd{t}.weight"] = torch.tensor([1.0, 1.0])
5683
+ tensors[f"{prefix}.carry_acc{i}_odd{t}.bias"] = torch.tensor([-2.0])
5684
+ else:
5685
+ tensors[f"{prefix}.carry_acc{i}_odd{t}.weight"] = torch.tensor([1.0])
5686
+ tensors[f"{prefix}.carry_acc{i}_odd{t}.bias"] = torch.tensor([-0.5])
5687
+ odd_ranges.append(t)
5688
+ # carry_acc_sum = OR of odd ranges
5689
+ tensors[f"{prefix}.carry_acc{i}_sum.weight"] = torch.tensor([1.0] * len(odd_ranges))
5690
+ tensors[f"{prefix}.carry_acc{i}_sum.bias"] = torch.tensor([-0.5])
5691
+ # carry_acc_carry = ge2 (sum >= 2)
5692
+ if n >= 2:
5693
+ tensors[f"{prefix}.carry_acc{i}_carry.weight"] = torch.tensor([1.0] * n)
5694
+ tensors[f"{prefix}.carry_acc{i}_carry.bias"] = torch.tensor([-1.5])
5695
+
5696
  # Final product assembly using ripple carry
5697
  for i in range(22):
5698
  p = f"{prefix}.prod_fa{i}"
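
The weight and bias pairs in the hunk above can be sanity-checked without torch. The sketch below assumes a gate outputs 1 when its weighted input sum plus bias is non-negative. That activation convention is inferred from the biases used in the diff (`-t` for `ge{t}`, `-2.0` for AND, `0.0` for NOT) and is not confirmed by the source. `threshold_gate` and `column_bit` are illustrative names.

```python
from typing import List, Sequence


def threshold_gate(inputs: Sequence[int], weights: Sequence[float], bias: float) -> int:
    """ASSUMPTION: the gate fires when sum(w * x) + bias >= 0."""
    total = sum(w * x for w, x in zip(weights, inputs)) + bias
    return 1 if total >= 0 else 0


def column_bit(pps: Sequence[int], k: int) -> int:
    """Bit k of popcount(pps), built only from ge / NOT / AND / OR gates.

    Bit k is 1 exactly when the sum lies in [t, t + 2**k) for
    t = 2**k, 3 * 2**k, 5 * 2**k, ... which mirrors the odd{t},
    bit1_{t}, bit2_{t} and bit3_{t} range gates in the diff.
    """
    count = len(pps)
    step = 1 << k
    ones = [1.0] * count
    terms: List[int] = []
    for t in range(step, count + 1, 2 * step):
        ge_t = threshold_gate(pps, ones, -float(t))
        upper = t + step
        if upper <= count:
            ge_upper = threshold_gate(pps, ones, -float(upper))
            not_upper = threshold_gate([ge_upper], [-1.0], 0.0)               # NOT
            terms.append(threshold_gate([ge_t, not_upper], [1.0, 1.0], -2.0))  # AND
        else:
            # No upper bound is needed when the range runs past the column height.
            terms.append(threshold_gate([ge_t], [1.0], -0.5))
    if not terms:
        return 0
    return threshold_gate(terms, [1.0] * len(terms), -0.5)                     # OR
```

Under this convention the column gates' bias of `-t` and the accumulator gates' bias of `-t + 0.5` select the same integer sums, so the two styles in the diff agree even though they look different.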