CharlesCNorton commited on
Commit
659bba6
·
1 Parent(s): fe6e20e

Remove duplicate function definitions in build.py

Browse files
Files changed (1) hide show
  1. build.py +0 -332
build.py CHANGED
@@ -378,108 +378,6 @@ def add_expr_paren(tensors: Dict[str, torch.Tensor]) -> None:
378
  add_full_adder(tensors, f"{prefix}.mul.acc.s{stage}.fa{bit}")
379
 
380
 
381
- def add_expr_paren(tensors: Dict[str, torch.Tensor]) -> None:
382
- """Add expression circuit for (A + B) × C (parenthetical grouping).
383
-
384
- Computes (A + B) × C where parentheses override default precedence.
385
-
386
- Structure:
387
- - Stage 1: Add A + B (8 full adders) → temp
388
- - Stage 2: Multiply temp × C using shift-add algorithm
389
- - 8 mask stages: mask[i] = temp AND C[i] (8 AND gates each)
390
- - 7 accumulator adders to sum shifted masked values
391
-
392
- Inputs: $a[0-7], $b[0-7], $c[0-7] (MSB-first, 8-bit each)
393
- Output: 8-bit result of (A + B) × C, wrapping on overflow
394
-
395
- Total: 8 full adders (add) + 64 AND gates + 7×8 full adders (mul) = ~640 gates
396
- """
397
- prefix = "arithmetic.expr_paren"
398
-
399
- # Stage 1: Add A + B → temp
400
- for bit in range(8):
401
- add_full_adder(tensors, f"{prefix}.add.fa{bit}")
402
-
403
- # Stage 2: Multiply temp × C using shift-add
404
- # Mask AND gates: mask[stage][bit] = temp[bit] AND C[stage]
405
- for stage in range(8):
406
- for bit in range(8):
407
- add_gate(tensors, f"{prefix}.mul.mask.s{stage}.b{bit}", [1.0, 1.0], [-2.0])
408
-
409
- # Accumulator adders for shift-add multiplication
410
- for stage in range(1, 8): # 7 accumulator adders
411
- for bit in range(8):
412
- add_full_adder(tensors, f"{prefix}.mul.acc.s{stage}.fa{bit}")
413
-
414
-
415
- def add_expr_paren(tensors: Dict[str, torch.Tensor]) -> None:
416
- """Add expression circuit for (A + B) × C (parenthetical grouping).
417
-
418
- Computes (A + B) × C where addition is evaluated first due to parentheses.
419
-
420
- Structure:
421
- - Stage 1: Add A + B (8-bit ripple carry adder)
422
- - Stage 2: Multiply sum × C using shift-add algorithm
423
- - 8 mask stages: mask[i] = sum AND C[i] (8 AND gates each)
424
- - 7 accumulator adders to sum masked values
425
-
426
- Inputs: $a[0-7], $b[0-7], $c[0-7] (MSB-first, 8-bit each)
427
- Output: 8-bit result of (A + B) × C, wrapping on overflow
428
-
429
- Total: 8 full adders (add) + 64 AND gates + 56 full adders (mul) = ~640 gates
430
- """
431
- prefix = "arithmetic.expr_paren"
432
-
433
- # Stage 1: Add A + B
434
- for bit in range(8):
435
- add_full_adder(tensors, f"{prefix}.add.fa{bit}")
436
-
437
- # Stage 2: Multiply sum × C using shift-add
438
- # Mask AND gates: mask[stage][bit] = sum[bit] AND C[stage]
439
- for stage in range(8):
440
- for bit in range(8):
441
- add_gate(tensors, f"{prefix}.mul.mask.s{stage}.b{bit}", [1.0, 1.0], [-2.0])
442
-
443
- # Accumulator adders for shift-add multiplication
444
- for stage in range(1, 8): # 7 accumulator adders
445
- for bit in range(8):
446
- add_full_adder(tensors, f"{prefix}.mul.acc.s{stage}.fa{bit}")
447
-
448
-
449
- def add_expr_paren_add_mul(tensors: Dict[str, torch.Tensor]) -> None:
450
- """Add expression circuit for (A + B) × C (parenthetical grouping).
451
-
452
- Computes (A + B) × C where parentheses override default precedence.
453
-
454
- Structure:
455
- - Stage 1: Add A + B (8-bit ripple carry) → temp
456
- - Stage 2: Multiply temp × C using shift-add algorithm
457
- - 8 mask stages: mask[i] = temp AND C[i]
458
- - 7 accumulator adders to sum masked values
459
-
460
- Inputs: $a[0-7], $b[0-7], $c[0-7] (MSB-first, 8-bit each)
461
- Output: 8-bit result of (A + B) × C, wrapping on overflow
462
-
463
- Total: 8 full adders (add) + 64 AND gates + 56 full adders (mul) = ~640 gates
464
- """
465
- prefix = "arithmetic.expr_paren_add_mul"
466
-
467
- # Stage 1: Add A + B → temp
468
- for bit in range(8):
469
- add_full_adder(tensors, f"{prefix}.add.fa{bit}")
470
-
471
- # Stage 2: Multiply temp × C using shift-add
472
- # Mask AND gates: mask[stage][bit] = temp[bit] AND C[stage]
473
- for stage in range(8):
474
- for bit in range(8):
475
- add_gate(tensors, f"{prefix}.mul.mask.s{stage}.b{bit}", [1.0, 1.0], [-2.0])
476
-
477
- # Accumulator adders for multiplication
478
- for stage in range(1, 8): # 7 accumulator adders
479
- for bit in range(8):
480
- add_full_adder(tensors, f"{prefix}.mul.acc.s{stage}.fa{bit}")
481
-
482
-
483
  def add_add3(tensors: Dict[str, torch.Tensor]) -> None:
484
  """Add 3-operand 8-bit adder circuit.
485
 
@@ -1220,236 +1118,6 @@ def infer_expr_paren_inputs(gate: str, reg: SignalRegistry) -> List[int]:
1220
  return []
1221
 
1222
 
1223
- def infer_expr_paren_inputs(gate: str, reg: SignalRegistry) -> List[int]:
1224
- """Infer inputs for (A + B) × C expression circuit (parenthetical grouping).
1225
-
1226
- Circuit structure:
1227
- - Add stage: add.fa[bit] computes A[bit] + B[bit]
1228
- - Mask stage: mask.s[stage].b[bit] = sum[bit] AND C[stage]
1229
- - Accumulator stages 1-7: acc.s[stage] = acc.s[stage-1] + (mask.s[stage] << stage)
1230
-
1231
- Bit ordering: MSB-first externally, LSB-first internally (fa0 = LSB, fa7 = MSB)
1232
- """
1233
- prefix = "arithmetic.expr_paren"
1234
-
1235
- # Register all inputs
1236
- for i in range(8):
1237
- reg.register(f"$a[{i}]")
1238
- reg.register(f"$b[{i}]")
1239
- reg.register(f"$c[{i}]")
1240
-
1241
- # Add stage: A + B
1242
- if '.add.fa' in gate and '.mul.' not in gate:
1243
- m = re.search(r'\.fa(\d+)\.', gate)
1244
- if not m:
1245
- return []
1246
- bit = int(m.group(1))
1247
-
1248
- # A and B inputs (MSB-first to positional)
1249
- a_input = reg.get_id(f"$a[{7-bit}]")
1250
- b_input = reg.get_id(f"$b[{7-bit}]")
1251
-
1252
- # Carry input
1253
- if bit == 0:
1254
- cin = reg.get_id("#0")
1255
- else:
1256
- cin = reg.register(f"{prefix}.add.fa{bit-1}.carry_or")
1257
-
1258
- fa_prefix = f"{prefix}.add.fa{bit}"
1259
-
1260
- if '.ha1.sum.layer1' in gate:
1261
- return [a_input, b_input]
1262
- if '.ha1.sum.layer2' in gate:
1263
- return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")]
1264
- if '.ha1.carry' in gate and '.layer' not in gate:
1265
- return [a_input, b_input]
1266
- if '.ha2.sum.layer1' in gate:
1267
- return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
1268
- if '.ha2.sum.layer2' in gate:
1269
- return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")]
1270
- if '.ha2.carry' in gate and '.layer' not in gate:
1271
- return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
1272
- if '.carry_or' in gate:
1273
- return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")]
1274
- return []
1275
-
1276
- # Mask AND gates: mask.s[stage].b[bit] = sum[bit] AND C[stage]
1277
- if '.mul.mask.' in gate:
1278
- m = re.search(r'\.s(\d+)\.b(\d+)', gate)
1279
- if m:
1280
- stage = int(m.group(1))
1281
- bit = int(m.group(2))
1282
- # sum[bit] is the output of add.fa[bit]
1283
- sum_input = reg.register(f"{prefix}.add.fa{bit}.ha2.sum.layer2")
1284
- # C[stage] in MSB-first
1285
- c_input = reg.get_id(f"$c[{7-stage}]")
1286
- return [sum_input, c_input]
1287
- return []
1288
-
1289
- # Accumulator adders: acc.s[stage].fa[bit]
1290
- if '.mul.acc.' in gate:
1291
- m = re.search(r'\.s(\d+)\.fa(\d+)\.', gate)
1292
- if not m:
1293
- return []
1294
- stage = int(m.group(1)) # 1-7
1295
- bit = int(m.group(2)) # 0-7
1296
-
1297
- # A input: previous stage output
1298
- if stage == 1:
1299
- # First accumulator: A = mask.s0.b[bit] (AND gate output)
1300
- a_input = reg.register(f"{prefix}.mul.mask.s0.b{bit}")
1301
- else:
1302
- # Later stages: A = previous accumulator sum
1303
- a_input = reg.register(f"{prefix}.mul.acc.s{stage-1}.fa{bit}.ha2.sum.layer2")
1304
-
1305
- # B input: (mask.s[stage] << stage)[bit]
1306
- if bit < stage:
1307
- b_input = reg.get_id("#0")
1308
- else:
1309
- b_input = reg.register(f"{prefix}.mul.mask.s{stage}.b{bit-stage}")
1310
-
1311
- # Carry input
1312
- if bit == 0:
1313
- cin = reg.get_id("#0")
1314
- else:
1315
- cin = reg.register(f"{prefix}.mul.acc.s{stage}.fa{bit-1}.carry_or")
1316
-
1317
- fa_prefix = f"{prefix}.mul.acc.s{stage}.fa{bit}"
1318
-
1319
- if '.ha1.sum.layer1' in gate:
1320
- return [a_input, b_input]
1321
- if '.ha1.sum.layer2' in gate:
1322
- return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")]
1323
- if '.ha1.carry' in gate and '.layer' not in gate:
1324
- return [a_input, b_input]
1325
- if '.ha2.sum.layer1' in gate:
1326
- return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
1327
- if '.ha2.sum.layer2' in gate:
1328
- return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")]
1329
- if '.ha2.carry' in gate and '.layer' not in gate:
1330
- return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
1331
- if '.carry_or' in gate:
1332
- return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")]
1333
- return []
1334
-
1335
- return []
1336
-
1337
-
1338
- def infer_expr_paren_add_mul_inputs(gate: str, reg: SignalRegistry) -> List[int]:
1339
- """Infer inputs for (A + B) × C expression circuit (parenthetical grouping).
1340
-
1341
- Circuit structure:
1342
- - Add stage: A + B → temp (8-bit ripple carry)
1343
- - Mask stage: mask.s[stage].b[bit] = temp[bit] AND C[stage]
1344
- - Accumulator stages 1-7: acc.s[stage] = acc.s[stage-1] + (mask.s[stage] << stage)
1345
-
1346
- Bit ordering: MSB-first externally, LSB-first internally (fa0 = LSB, fa7 = MSB)
1347
- """
1348
- prefix = "arithmetic.expr_paren_add_mul"
1349
-
1350
- # Register all inputs
1351
- for i in range(8):
1352
- reg.register(f"$a[{i}]")
1353
- reg.register(f"$b[{i}]")
1354
- reg.register(f"$c[{i}]")
1355
-
1356
- # Add stage: A + B → temp
1357
- if '.add.fa' in gate and '.mul.' not in gate:
1358
- m = re.search(r'\.fa(\d+)\.', gate)
1359
- if not m:
1360
- return []
1361
- bit = int(m.group(1))
1362
-
1363
- # A input: $a[7-bit] (MSB-first to positional bit)
1364
- a_input = reg.get_id(f"$a[{7-bit}]")
1365
- # B input: $b[7-bit]
1366
- b_input = reg.get_id(f"$b[{7-bit}]")
1367
- # Carry input
1368
- if bit == 0:
1369
- cin = reg.get_id("#0")
1370
- else:
1371
- cin = reg.register(f"{prefix}.add.fa{bit-1}.carry_or")
1372
-
1373
- fa_prefix = f"{prefix}.add.fa{bit}"
1374
-
1375
- if '.ha1.sum.layer1' in gate:
1376
- return [a_input, b_input]
1377
- if '.ha1.sum.layer2' in gate:
1378
- return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")]
1379
- if '.ha1.carry' in gate and '.layer' not in gate:
1380
- return [a_input, b_input]
1381
- if '.ha2.sum.layer1' in gate:
1382
- return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
1383
- if '.ha2.sum.layer2' in gate:
1384
- return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")]
1385
- if '.ha2.carry' in gate and '.layer' not in gate:
1386
- return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
1387
- if '.carry_or' in gate:
1388
- return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")]
1389
- return []
1390
-
1391
- # Mask AND gates: mask.s[stage].b[bit] = temp[bit] AND C[stage]
1392
- if '.mul.mask.' in gate:
1393
- m = re.search(r'\.s(\d+)\.b(\d+)', gate)
1394
- if m:
1395
- stage = int(m.group(1))
1396
- bit = int(m.group(2))
1397
- # temp[bit] is the sum output from add stage
1398
- temp_bit = reg.register(f"{prefix}.add.fa{bit}.ha2.sum.layer2")
1399
- # C[stage] in MSB-first
1400
- c_input = reg.get_id(f"$c[{7-stage}]")
1401
- return [temp_bit, c_input]
1402
- return []
1403
-
1404
- # Accumulator adders: acc.s[stage].fa[bit]
1405
- if '.mul.acc.' in gate:
1406
- m = re.search(r'\.s(\d+)\.fa(\d+)\.', gate)
1407
- if not m:
1408
- return []
1409
- stage = int(m.group(1)) # 1-7
1410
- bit = int(m.group(2)) # 0-7
1411
-
1412
- # A input: previous stage output
1413
- if stage == 1:
1414
- # First accumulator: A = mask.s0.b[bit] (AND gate output)
1415
- a_input = reg.register(f"{prefix}.mul.mask.s0.b{bit}")
1416
- else:
1417
- # Later stages: A = previous accumulator sum
1418
- a_input = reg.register(f"{prefix}.mul.acc.s{stage-1}.fa{bit}.ha2.sum.layer2")
1419
-
1420
- # B input: (mask.s[stage] << stage)[bit]
1421
- if bit < stage:
1422
- b_input = reg.get_id("#0")
1423
- else:
1424
- b_input = reg.register(f"{prefix}.mul.mask.s{stage}.b{bit-stage}")
1425
-
1426
- # Carry input
1427
- if bit == 0:
1428
- cin = reg.get_id("#0")
1429
- else:
1430
- cin = reg.register(f"{prefix}.mul.acc.s{stage}.fa{bit-1}.carry_or")
1431
-
1432
- fa_prefix = f"{prefix}.mul.acc.s{stage}.fa{bit}"
1433
-
1434
- if '.ha1.sum.layer1' in gate:
1435
- return [a_input, b_input]
1436
- if '.ha1.sum.layer2' in gate:
1437
- return [reg.register(f"{fa_prefix}.ha1.sum.layer1.or"), reg.register(f"{fa_prefix}.ha1.sum.layer1.nand")]
1438
- if '.ha1.carry' in gate and '.layer' not in gate:
1439
- return [a_input, b_input]
1440
- if '.ha2.sum.layer1' in gate:
1441
- return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
1442
- if '.ha2.sum.layer2' in gate:
1443
- return [reg.register(f"{fa_prefix}.ha2.sum.layer1.or"), reg.register(f"{fa_prefix}.ha2.sum.layer1.nand")]
1444
- if '.ha2.carry' in gate and '.layer' not in gate:
1445
- return [reg.register(f"{fa_prefix}.ha1.sum.layer2"), cin]
1446
- if '.carry_or' in gate:
1447
- return [reg.register(f"{fa_prefix}.ha1.carry"), reg.register(f"{fa_prefix}.ha2.carry")]
1448
- return []
1449
-
1450
- return []
1451
-
1452
-
1453
  def infer_add3_inputs(gate: str, reg: SignalRegistry) -> List[int]:
1454
  """Infer inputs for 3-operand adder: A + B + C."""
1455
  prefix = "arithmetic.add3_8bit"
 
378
  add_full_adder(tensors, f"{prefix}.mul.acc.s{stage}.fa{bit}")
379
 
380
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  def add_add3(tensors: Dict[str, torch.Tensor]) -> None:
382
  """Add 3-operand 8-bit adder circuit.
383
 
 
1118
  return []
1119
 
1120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1121
  def infer_add3_inputs(gate: str, reg: SignalRegistry) -> List[int]:
1122
  """Infer inputs for 3-operand adder: A + B + C."""
1123
  prefix = "arithmetic.add3_8bit"