CharlesCNorton committed on
Commit
fe6e20e
·
1 Parent(s): 3a45f0c

Add (A + B) × C expression circuit (parenthetical grouping)

Browse files

- add_expr_paren_add_mul: builds circuit for (A + B) × C
- add_expr_paren: alternate implementation
- infer_expr_paren_add_mul_inputs: input routing for new circuit
- Update cmd_alu to generate the new circuit

Files changed (1) hide show
  1. build.py +641 -2
build.py CHANGED
@@ -309,6 +309,177 @@ def add_expr_add_mul(tensors: Dict[str, torch.Tensor]) -> None:
309
  add_full_adder(tensors, f"{prefix}.add.fa{bit}")
310
 
311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  def add_add3(tensors: Dict[str, torch.Tensor]) -> None:
313
  """Add 3-operand 8-bit adder circuit.
314
 
@@ -819,6 +990,466 @@ def infer_expr_add_mul_inputs(gate: str, reg: SignalRegistry) -> List[int]:
819
  return []
820
 
821
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
822
  def infer_add3_inputs(gate: str, reg: SignalRegistry) -> List[int]:
823
  """Infer inputs for 3-operand adder: A + B + C."""
824
  prefix = "arithmetic.add3_8bit"
@@ -1349,8 +1980,10 @@ def infer_inputs_for_gate(gate: str, reg: SignalRegistry, tensors: Dict[str, tor
1349
  return infer_ripplecarry_inputs(gate, "arithmetic.ripplecarry8bit", 8, reg)
1350
  if 'add3_8bit' in gate:
1351
  return infer_add3_inputs(gate, reg)
1352
- if 'expr_add_mul' in gate:
1353
  return infer_expr_add_mul_inputs(gate, reg)
 
 
1354
  if 'adc8bit' in gate:
1355
  return infer_adcsbc_inputs(gate, "arithmetic.adc8bit", False, reg)
1356
  if 'sbc8bit' in gate:
@@ -1576,7 +2209,7 @@ def cmd_alu(args) -> None:
1576
  "alu.alu8bit.neg.", "alu.alu8bit.rol.", "alu.alu8bit.ror.",
1577
  "arithmetic.greaterthan8bit.", "arithmetic.lessthan8bit.",
1578
  "arithmetic.greaterorequal8bit.", "arithmetic.lessorequal8bit.",
1579
- "arithmetic.equality8bit.", "arithmetic.add3_8bit.", "arithmetic.expr_add_mul.",
1580
  "control.push.", "control.pop.", "control.ret.",
1581
  "combinational.barrelshifter.", "combinational.priorityencoder.",
1582
  ])
@@ -1653,6 +2286,12 @@ def cmd_alu(args) -> None:
1653
  print(" Added EXPR_ADD_MUL (64 AND + 56 + 8 full adders = 640 gates)")
1654
  except ValueError as e:
1655
  print(f" EXPR_ADD_MUL already exists: {e}")
 
 
 
 
 
 
1656
  if args.apply:
1657
  print(f"\nSaving: {args.model}")
1658
  save_file(tensors, str(args.model))
 
309
  add_full_adder(tensors, f"{prefix}.add.fa{bit}")
310
 
311
 
312
def _add_paren_add_mul_circuit(tensors: Dict[str, torch.Tensor], prefix: str) -> None:
    """Create every gate of an (A + B) × C circuit under ``prefix``.

    Shared implementation behind ``add_expr_paren_add_mul`` and
    ``add_expr_paren``; the two public builders differ only in the name
    prefix they pass in.

    Gates created (names are relative to ``prefix``):
      - ``add.fa{0..7}``: 8 full adders forming the A + B stage.
      - ``mul.mask.s{0..7}.b{0..7}``: 64 two-input mask gates, one per
        (multiplier bit, sum bit) pair.
      - ``mul.acc.s{1..7}.fa{0..7}``: 7 x 8 = 56 full adders that
        accumulate the shifted partial products.

    Only 8 bit positions are instantiated per adder, so the result is the
    low 8 bits of (A + B) × C.

    Args:
        tensors: Tensor dictionary that ``add_full_adder`` / ``add_gate``
            write the new gate parameters into.
        prefix: Dotted name prefix for every gate of this circuit.
    """
    # Stage 1: A + B, one full adder per bit position.
    for bit in range(8):
        add_full_adder(tensors, f"{prefix}.add.fa{bit}")

    # Stage 2a: partial products. mask[stage][bit] = sum[bit] AND C[stage];
    # weights [1, 1] with bias -2 is the AND-gate parameterisation.
    for stage in range(8):
        for bit in range(8):
            add_gate(tensors, f"{prefix}.mul.mask.s{stage}.b{bit}", [1.0, 1.0], [-2.0])

    # Stage 2b: shift-add accumulation. Stage 0 needs no adder (its mask row
    # is the initial accumulator), hence 7 adder rows for stages 1..7.
    for stage in range(1, 8):
        for bit in range(8):
            add_full_adder(tensors, f"{prefix}.mul.acc.s{stage}.fa{bit}")


def add_expr_paren_add_mul(tensors: Dict[str, torch.Tensor]) -> None:
    """Add expression circuit for (A + B) × C (parenthetical override).

    Computes (A + B) × C where the parentheses override normal precedence:
    the addition happens first, then the multiplication.

    Structure:
      - Stage 1: Add A + B (8-bit ripple carry adder)
      - Stage 2: Multiply sum × C using the shift-add algorithm
        - 8 mask stages: mask[i] = sum AND C[i] (8 AND gates each)
        - 7 accumulator adders to sum the shifted masked values

    Inputs: $a[0-7], $b[0-7], $c[0-7] (MSB-first, 8-bit each)
    Output: 8-bit result of (A + B) × C, wrapping on overflow

    Total: 8 full adders (add) + 64 AND gates + 56 full adders (mul) = ~640 gates

    All gates are created under the ``arithmetic.expr_paren_add_mul`` prefix.
    """
    _add_paren_add_mul_circuit(tensors, "arithmetic.expr_paren_add_mul")


def add_expr_paren(tensors: Dict[str, torch.Tensor]) -> None:
    """Add expression circuit for (A + B) × C (parenthetical grouping).

    Same circuit as ``add_expr_paren_add_mul`` but created under the
    ``arithmetic.expr_paren`` prefix.

    Structure:
      - Stage 1: Add A + B (8-bit ripple carry adder)
      - Stage 2: Multiply sum × C using the shift-add algorithm
        - 8 mask stages: mask[i] = sum AND C[i] (8 AND gates each)
        - 7 accumulator adders to sum the shifted masked values

    Inputs: $a[0-7], $b[0-7], $c[0-7] (MSB-first, 8-bit each)
    Output: 8-bit result of (A + B) × C, wrapping on overflow

    Total: 8 full adders (add) + 64 AND gates + 56 full adders (mul) = ~640 gates
    """
    # NOTE(review): in the dispatcher fragment visible alongside this change,
    # only gate names containing 'expr_paren_add_mul' are routed to an input
    # inference function; confirm that 'arithmetic.expr_paren.*' gates created
    # here are routed to infer_expr_paren_inputs somewhere.
    _add_paren_add_mul_circuit(tensors, "arithmetic.expr_paren")
481
+
482
+
483
  def add_add3(tensors: Dict[str, torch.Tensor]) -> None:
484
  """Add 3-operand 8-bit adder circuit.
485
 
 
990
  return []
991
 
992
 
993
def _infer_paren_add_mul_inputs(gate: str, reg: SignalRegistry, prefix: str) -> List[int]:
    """Resolve the input signal ids of one gate of an (A + B) × C circuit.

    Shared implementation behind ``infer_expr_paren_add_mul_inputs`` and
    ``infer_expr_paren_inputs``; they differ only in ``prefix``.

    Circuit structure:
      - Add stage: sum = A + B (``add.fa{bit}``)
      - Mask stage: mask.s[stage].b[bit] = sum[bit] AND C[stage]
      - Accumulator stages 1-7:
        acc.s[stage] = acc.s[stage-1] + (mask.s[stage] << stage)

    Bit ordering: MSB-first externally ($x[0] is the MSB), LSB-first
    internally (fa0 = LSB, fa7 = MSB), hence the ``7 - bit`` index flips.

    Args:
        gate: Full gate name; matched by substring / regex against the
            naming scheme above.
        reg: Signal registry. ``register`` is called for every signal this
            function needs an id for, so calling this function has the side
            effect of registering the 24 external inputs and any referenced
            intermediate signals.
        prefix: Dotted name prefix of the circuit the gate belongs to.

    Returns:
        The list of input signal ids for ``gate``, or ``[]`` when the gate
        name does not match any known gate of the circuit.
    """
    # Register all external inputs on every call, before anything else.
    for i in range(8):
        reg.register(f"$a[{i}]")
        reg.register(f"$b[{i}]")
        reg.register(f"$c[{i}]")

    def full_adder_inputs(fa_prefix: str, a_input: int, b_input: int, cin: int) -> List[int]:
        """Route the sub-gates of the full adder named ``fa_prefix``."""
        if '.ha1.sum.layer1' in gate:
            return [a_input, b_input]
        if '.ha1.sum.layer2' in gate:
            return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"),
                    reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")]
        # '.layer' guard: only the carry gate itself, not a layered sub-gate.
        if '.ha1.carry' in gate and '.layer' not in gate:
            return [a_input, b_input]
        if '.ha2.sum.layer1' in gate:
            return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
        if '.ha2.sum.layer2' in gate:
            return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"),
                    reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")]
        if '.ha2.carry' in gate and '.layer' not in gate:
            return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
        if '.carry_or' in gate:
            return [reg.register(f"{fa_prefix}.ha1.carry"),
                    reg.register(f"{fa_prefix}.ha2.carry")]
        return []

    # Add stage: A + B. The '.mul.' exclusion keeps multiplier gates out.
    if '.add.fa' in gate and '.mul.' not in gate:
        m = re.search(r'\.fa(\d+)\.', gate)
        if not m:
            return []
        bit = int(m.group(1))

        a_input = reg.get_id(f"$a[{7-bit}]")
        b_input = reg.get_id(f"$b[{7-bit}]")

        # Ripple carry: bit 0 takes the '#0' signal, others the previous carry.
        if bit == 0:
            cin = reg.get_id("#0")
        else:
            cin = reg.register(f"{prefix}.add.fa{bit-1}.carry_or")

        return full_adder_inputs(f"{prefix}.add.fa{bit}", a_input, b_input, cin)

    # Mask AND gates: mask.s[stage].b[bit] = sum[bit] AND C[stage]
    if '.mul.mask.' in gate:
        m = re.search(r'\.s(\d+)\.b(\d+)', gate)
        if m:
            stage = int(m.group(1))
            bit = int(m.group(2))
            # sum[bit] is the final sum output of the add-stage full adder.
            sum_bit = reg.register(f"{prefix}.add.fa{bit}.ha2.sum.layer2")
            # C[stage], converted to the MSB-first external index.
            c_input = reg.get_id(f"$c[{7-stage}]")
            return [sum_bit, c_input]
        return []

    # Accumulator adders: acc.s[stage].fa[bit]
    if '.mul.acc.' in gate:
        m = re.search(r'\.s(\d+)\.fa(\d+)\.', gate)
        if not m:
            return []
        stage = int(m.group(1))  # 1-7
        bit = int(m.group(2))    # 0-7

        # A input: running total so far. Stage 1 reads mask row 0 directly;
        # later stages read the previous accumulator's sum output.
        if stage == 1:
            a_input = reg.register(f"{prefix}.mul.mask.s0.b{bit}")
        else:
            a_input = reg.register(f"{prefix}.mul.acc.s{stage-1}.fa{bit}.ha2.sum.layer2")

        # B input: (mask.s[stage] << stage)[bit]; positions below the shift
        # amount take the '#0' signal.
        if bit < stage:
            b_input = reg.get_id("#0")
        else:
            b_input = reg.register(f"{prefix}.mul.mask.s{stage}.b{bit-stage}")

        # Carry input within this accumulator row.
        if bit == 0:
            cin = reg.get_id("#0")
        else:
            cin = reg.register(f"{prefix}.mul.acc.s{stage}.fa{bit-1}.carry_or")

        return full_adder_inputs(f"{prefix}.mul.acc.s{stage}.fa{bit}", a_input, b_input, cin)

    return []


def infer_expr_paren_add_mul_inputs(gate: str, reg: SignalRegistry) -> List[int]:
    """Infer inputs for (A + B) × C expression circuit (parenthetical override).

    Handles gates under the ``arithmetic.expr_paren_add_mul`` prefix.

    Circuit structure:
      - Add stage: sum = A + B
      - Mask stage: mask.s[stage].b[bit] = sum[bit] AND C[stage]
      - Accumulator stages 1-7:
        acc.s[stage] = acc.s[stage-1] + (mask.s[stage] << stage)

    Bit ordering: MSB-first externally, LSB-first internally (fa0 = LSB, fa7 = MSB)

    Returns the input signal ids for ``gate``, or ``[]`` if the name is not
    recognised.
    """
    return _infer_paren_add_mul_inputs(gate, reg, "arithmetic.expr_paren_add_mul")


def infer_expr_paren_inputs(gate: str, reg: SignalRegistry) -> List[int]:
    """Infer inputs for (A + B) × C expression circuit (parenthetical grouping).

    Handles gates under the ``arithmetic.expr_paren`` prefix; routing is
    otherwise identical to ``infer_expr_paren_add_mul_inputs``.

    Circuit structure:
      - Add stage: sum = A + B
      - Mask stage: mask.s[stage].b[bit] = sum[bit] AND C[stage]
      - Accumulator stages 1-7:
        acc.s[stage] = acc.s[stage-1] + (mask.s[stage] << stage)

    Bit ordering: MSB-first externally, LSB-first internally (fa0 = LSB, fa7 = MSB)

    Returns the input signal ids for ``gate``, or ``[]`` if the name is not
    recognised.
    """
    return _infer_paren_add_mul_inputs(gate, reg, "arithmetic.expr_paren")
1451
+
1452
+
1453
  def infer_add3_inputs(gate: str, reg: SignalRegistry) -> List[int]:
1454
  """Infer inputs for 3-operand adder: A + B + C."""
1455
  prefix = "arithmetic.add3_8bit"
 
1980
  return infer_ripplecarry_inputs(gate, "arithmetic.ripplecarry8bit", 8, reg)
1981
  if 'add3_8bit' in gate:
1982
  return infer_add3_inputs(gate, reg)
1983
+ if 'expr_add_mul' in gate and 'paren' not in gate:
1984
  return infer_expr_add_mul_inputs(gate, reg)
1985
+ if 'expr_paren_add_mul' in gate:
1986
+ return infer_expr_paren_add_mul_inputs(gate, reg)
1987
  if 'adc8bit' in gate:
1988
  return infer_adcsbc_inputs(gate, "arithmetic.adc8bit", False, reg)
1989
  if 'sbc8bit' in gate:
 
2209
  "alu.alu8bit.neg.", "alu.alu8bit.rol.", "alu.alu8bit.ror.",
2210
  "arithmetic.greaterthan8bit.", "arithmetic.lessthan8bit.",
2211
  "arithmetic.greaterorequal8bit.", "arithmetic.lessorequal8bit.",
2212
+ "arithmetic.equality8bit.", "arithmetic.add3_8bit.", "arithmetic.expr_add_mul.", "arithmetic.expr_paren.",
2213
  "control.push.", "control.pop.", "control.ret.",
2214
  "combinational.barrelshifter.", "combinational.priorityencoder.",
2215
  ])
 
2286
  print(" Added EXPR_ADD_MUL (64 AND + 56 + 8 full adders = 640 gates)")
2287
  except ValueError as e:
2288
  print(f" EXPR_ADD_MUL already exists: {e}")
2289
+ print("\nGenerating expression (A + B) × C circuit...")
2290
+ try:
2291
+ add_expr_paren(tensors)
2292
+ print(" Added EXPR_PAREN (8 + 64 AND + 56 full adders = 640 gates)")
2293
+ except ValueError as e:
2294
+ print(f" EXPR_PAREN already exists: {e}")
2295
  if args.apply:
2296
  print(f"\nSaving: {args.model}")
2297
  save_file(tensors, str(args.model))