PortfolioAI committed
Commit 6c34eb3 · Parent: a6eff5f

Add float16.add circuit (93/125 tests passing)


Implements IEEE 754 half-precision addition with:
- Special case detection (NaN, infinity, zero, subnormal)
- Exponent comparison and difference calculation
- Mantissa alignment via barrel shifter
- 12-bit mantissa adder/subtractor
- Result normalization with overflow/underflow handling
- Output assembly with special case multiplexing

~910 gates total. Remaining issues:
- Zero+zero produces incorrect result
- Subtraction (different signs) has bugs
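
The two failing classes above are easy to reproduce against a software oracle. A minimal sketch (not part of this commit) that decodes raw 16-bit encodings with Python's `struct` half-precision `'e'` format, adds in double precision, and re-encodes; note double rounding can differ from a true fp16 add in rare tie cases, and `struct.pack` raises `OverflowError` on results above half-precision range instead of returning infinity:

```python
import struct

def f16_add_ref(a_bits: int, b_bits: int) -> int:
    """Oracle for float16.add: raw 16-bit encodings in, raw encoding out,
    via decode -> double-precision add -> round-to-nearest-even re-encode."""
    a = struct.unpack('<e', a_bits.to_bytes(2, 'little'))[0]
    b = struct.unpack('<e', b_bits.to_bytes(2, 'little'))[0]
    return int.from_bytes(struct.pack('<e', a + b), 'little')

# The two reported failure classes:
assert f16_add_ref(0x0000, 0x0000) == 0x0000  # +0 + +0 must stay +0
assert f16_add_ref(0x3C00, 0xBC00) == 0x0000  # 1.0 + (-1.0): different signs
assert f16_add_ref(0x3C00, 0x3C00) == 0x4000  # 1.0 + 1.0 = 2.0
```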

Files changed (4)
  1. TODO.md +1 -1
  2. arithmetic.safetensors +2 -2
  3. convert_to_explicit_inputs.py +1843 -0
  4. eval.py +158 -0
TODO.md CHANGED
@@ -7,7 +7,7 @@
 - [x] `float16.pack` -- assemble from components (16 gates, 63/63 tests)
 - [x] `float16.cmp` -- comparison a > b (14 gates, 113/113 tests)
 - [x] `float16.normalize` -- CLZ-based shift calculator (51 gates, 14/14 tests)
-- [ ] `float16.add` -- IEEE 754 addition (requires normalize + align + add)
+- [~] `float16.add` -- IEEE 754 addition (~910 gates, 93/125 tests, needs zero+zero and subtraction fixes)
 - [ ] `float16.sub` -- subtraction (add with negated operand)
 - [ ] `float16.mul` -- multiplication
 - [ ] `float16.div` -- division
arithmetic.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2b16619fd1cda08ab7c9ccf567ef77f8001ff7b6f76b8ed6852ad262fbc8d139
-size 1140364
+oid sha256:098369950361600a735b8200b51642a6c11bed441619adc1c1dd609ce298af53
+size 1471280
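
The mantissa-alignment step in the commit message maps to the `shift_s0`..`shift_s3` stages in the diff below: each stage conditionally right-shifts by 1, 2, 4, or 8 depending on one bit of the exponent difference, and `not_exp_diff4` flushes the mantissa to zero when the difference is 16 or more. An illustrative plain-Python sketch of that behavior (not code from this commit):

```python
def align_mantissa(mant: int, exp_diff: int) -> int:
    """Barrel-shifter sketch: right-shift an 11-bit mantissa (10 fraction
    bits plus implicit bit) by exp_diff, one power-of-two stage per bit."""
    for bit, amount in enumerate((1, 2, 4, 8)):  # stages s0..s3
        if (exp_diff >> bit) & 1:
            mant >>= amount
    if (exp_diff >> 4) & 1:  # diff >= 16: every bit shifts out
        mant = 0
    return mant & 0x7FF

assert align_mantissa(0b10000000000, 3) == 0b00010000000
assert align_mantissa(0b10000000000, 16) == 0
```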
convert_to_explicit_inputs.py CHANGED
@@ -1056,11 +1056,976 @@ def infer_inputs_for_gate(gate: str, registry: SignalRegistry, routing: dict) -> List[int]:
         return infer_float16_neg_inputs(gate, registry)
     if gate.startswith('float16.abs'):
         return infer_float16_abs_inputs(gate, registry)
 
     # Default: couldn't infer, return empty (will need manual fix or routing)
     return []
 
 
 def infer_float16_neg_inputs(gate: str, registry: SignalRegistry) -> List[int]:
     """Infer inputs for float16.neg circuit."""
     prefix = "float16.neg"
@@ -1726,6 +2691,874 @@ def build_clz16bit_tensors() -> Dict[str, torch.Tensor]:
     return tensors
 
 
 def build_clz8bit_tensors() -> Dict[str, torch.Tensor]:
     """Build tensors for arithmetic.clz8bit circuit.
 
@@ -1795,6 +3628,12 @@ def main():
 
     print(f"Loaded {len(tensors)} tensors")
 
     # Build new circuits
     print("Building new circuits...")
     clz_tensors = build_clz8bit_tensors()
@@ -1829,6 +3668,10 @@ def main():
     tensors.update(abs_tensors)
     print(f" float16.abs: {len(abs_tensors)} tensors")
 
     print(f"Total tensors: {len(tensors)}")
 
     # Load routing for complex circuits
 
1056
  return infer_float16_neg_inputs(gate, registry)
1057
  if gate.startswith('float16.abs'):
1058
  return infer_float16_abs_inputs(gate, registry)
1059
+ if gate.startswith('float16.add'):
1060
+ return infer_float16_add_inputs(gate, registry)
1061
 
1062
  # Default: couldn't infer, return empty (will need manual fix or routing)
1063
  return []
1064
 
1065
 
1066
+ def infer_float16_add_inputs(gate: str, registry: SignalRegistry) -> List[int]:
1067
+ """Infer inputs for float16.add circuit."""
1068
+ prefix = "float16.add"
1069
+
1070
+ # Register 32 input bits (two 16-bit operands)
1071
+ for i in range(16):
1072
+ registry.register(f"{prefix}.$a[{i}]")
1073
+ registry.register(f"{prefix}.$b[{i}]")
1074
+
1075
+ # Extract exponent bits (10-14)
1076
+ exp_a_bits = [f"{prefix}.$a[{10+i}]" for i in range(5)]
1077
+ exp_b_bits = [f"{prefix}.$b[{10+i}]" for i in range(5)]
1078
+ mant_a_bits = [f"{prefix}.$a[{i}]" for i in range(10)]
1079
+ mant_b_bits = [f"{prefix}.$b[{i}]" for i in range(10)]
1080
+
1081
+ # Stage 0: Special case detection
1082
+ if '.exp_a_all_ones' in gate:
1083
+ return [registry.get_id(b) for b in exp_a_bits]
1084
+ if '.exp_b_all_ones' in gate:
1085
+ return [registry.get_id(b) for b in exp_b_bits]
1086
+ if '.exp_a_zero' in gate:
1087
+ return [registry.get_id(b) for b in exp_a_bits]
1088
+ if '.exp_b_zero' in gate:
1089
+ return [registry.get_id(b) for b in exp_b_bits]
1090
+ if '.mant_a_nonzero' in gate:
1091
+ return [registry.get_id(b) for b in mant_a_bits]
1092
+ if '.mant_b_nonzero' in gate:
1093
+ return [registry.get_id(b) for b in mant_b_bits]
1094
+ if '.mant_a_zero' in gate:
1095
+ return [registry.get_id(b) for b in mant_a_bits]
1096
+ if '.mant_b_zero' in gate:
1097
+ return [registry.get_id(b) for b in mant_b_bits]
1098
+
1099
+ registry.register(f"{prefix}.exp_a_all_ones")
1100
+ registry.register(f"{prefix}.exp_b_all_ones")
1101
+ registry.register(f"{prefix}.exp_a_zero")
1102
+ registry.register(f"{prefix}.exp_b_zero")
1103
+ registry.register(f"{prefix}.mant_a_nonzero")
1104
+ registry.register(f"{prefix}.mant_b_nonzero")
1105
+ registry.register(f"{prefix}.mant_a_zero")
1106
+ registry.register(f"{prefix}.mant_b_zero")
1107
+
1108
+ if '.a_is_nan' in gate:
1109
+ return [registry.get_id(f"{prefix}.exp_a_all_ones"),
1110
+ registry.get_id(f"{prefix}.mant_a_nonzero")]
1111
+ if '.b_is_nan' in gate:
1112
+ return [registry.get_id(f"{prefix}.exp_b_all_ones"),
1113
+ registry.get_id(f"{prefix}.mant_b_nonzero")]
1114
+ if '.a_is_inf' in gate:
1115
+ return [registry.get_id(f"{prefix}.exp_a_all_ones"),
1116
+ registry.get_id(f"{prefix}.mant_a_zero")]
1117
+ if '.b_is_inf' in gate:
1118
+ return [registry.get_id(f"{prefix}.exp_b_all_ones"),
1119
+ registry.get_id(f"{prefix}.mant_b_zero")]
1120
+ if '.a_is_zero' in gate:
1121
+ return [registry.get_id(f"{prefix}.exp_a_zero"),
1122
+ registry.get_id(f"{prefix}.mant_a_zero")]
1123
+ if '.b_is_zero' in gate:
1124
+ return [registry.get_id(f"{prefix}.exp_b_zero"),
1125
+ registry.get_id(f"{prefix}.mant_b_zero")]
1126
+ if '.a_is_subnormal' in gate:
1127
+ return [registry.get_id(f"{prefix}.exp_a_zero"),
1128
+ registry.get_id(f"{prefix}.mant_a_nonzero")]
1129
+ if '.b_is_subnormal' in gate:
1130
+ return [registry.get_id(f"{prefix}.exp_b_zero"),
1131
+ registry.get_id(f"{prefix}.mant_b_nonzero")]
1132
+
1133
+ registry.register(f"{prefix}.a_is_nan")
1134
+ registry.register(f"{prefix}.b_is_nan")
1135
+ registry.register(f"{prefix}.a_is_inf")
1136
+ registry.register(f"{prefix}.b_is_inf")
1137
+
1138
+ if '.either_is_nan' in gate:
1139
+ return [registry.get_id(f"{prefix}.a_is_nan"),
1140
+ registry.get_id(f"{prefix}.b_is_nan")]
1141
+ if '.both_are_inf' in gate:
1142
+ return [registry.get_id(f"{prefix}.a_is_inf"),
1143
+ registry.get_id(f"{prefix}.b_is_inf")]
1144
+
1145
+ # Sign extraction
1146
+ if gate == f"{prefix}.sign_a":
1147
+ return [registry.get_id(f"{prefix}.$a[15]")]
1148
+ if gate == f"{prefix}.sign_b":
1149
+ return [registry.get_id(f"{prefix}.$b[15]")]
1150
+
1151
+ registry.register(f"{prefix}.sign_a")
1152
+ registry.register(f"{prefix}.sign_b")
1153
+
1154
+ if '.signs_differ.layer1' in gate:
1155
+ return [registry.get_id(f"{prefix}.sign_a"),
1156
+ registry.get_id(f"{prefix}.sign_b")]
1157
+ if '.signs_differ.layer2' in gate:
1158
+ return [registry.register(f"{prefix}.signs_differ.layer1.or"),
1159
+ registry.register(f"{prefix}.signs_differ.layer1.nand")]
1160
+
1161
+ registry.register(f"{prefix}.signs_differ.layer2")
1162
+ registry.register(f"{prefix}.either_is_nan")
1163
+ registry.register(f"{prefix}.both_are_inf")
1164
+
1165
+ if '.inf_cancellation' in gate:
1166
+ return [registry.get_id(f"{prefix}.both_are_inf"),
1167
+ registry.get_id(f"{prefix}.signs_differ.layer2")]
1168
+
1169
+ registry.register(f"{prefix}.inf_cancellation")
1170
+
1171
+ if '.result_is_nan' in gate:
1172
+ return [registry.get_id(f"{prefix}.either_is_nan"),
1173
+ registry.get_id(f"{prefix}.inf_cancellation")]
1174
+ if '.either_is_inf' in gate:
1175
+ return [registry.get_id(f"{prefix}.a_is_inf"),
1176
+ registry.get_id(f"{prefix}.b_is_inf")]
1177
+
1178
+ registry.register(f"{prefix}.result_is_nan")
1179
+ registry.register(f"{prefix}.either_is_inf")
1180
+
1181
+ if '.not_result_is_nan' in gate:
1182
+ return [registry.get_id(f"{prefix}.result_is_nan")]
1183
+
1184
+ registry.register(f"{prefix}.not_result_is_nan")
1185
+
1186
+ if '.result_is_inf' in gate:
1187
+ return [registry.get_id(f"{prefix}.either_is_inf"),
1188
+ registry.get_id(f"{prefix}.not_result_is_nan")]
1189
+
1190
+ # Implicit bit
1191
+ if '.implicit_a' in gate:
1192
+ return [registry.get_id(f"{prefix}.exp_a_zero")]
1193
+ if '.implicit_b' in gate:
1194
+ return [registry.get_id(f"{prefix}.exp_b_zero")]
1195
+
1196
+ registry.register(f"{prefix}.implicit_a")
1197
+ registry.register(f"{prefix}.implicit_b")
1198
+
1199
+ # Exponent comparison
1200
+ if '.a_exp_ge_b' in gate or '.a_exp_gt_b' in gate:
1201
+ return [registry.get_id(b) for b in exp_a_bits] + \
1202
+ [registry.get_id(b) for b in exp_b_bits]
1203
+ if '.b_exp_gt_a' in gate and 'sel' not in gate:
1204
+ return [registry.get_id(b) for b in exp_b_bits] + \
1205
+ [registry.get_id(b) for b in exp_a_bits]
1206
+
1207
+ registry.register(f"{prefix}.a_exp_ge_b")
1208
+ registry.register(f"{prefix}.a_exp_gt_b")
1209
+ registry.register(f"{prefix}.b_exp_gt_a")
1210
+
1211
+ if '.b_exp_gt_a_sel' in gate:
1212
+ return [registry.get_id(f"{prefix}.a_exp_ge_b")]
1213
+
1214
+ registry.register(f"{prefix}.b_exp_gt_a_sel")
1215
+
1216
+ # NOT gates for exponent bits
1217
+ match = re.search(r'\.not_exp_b(\d+)', gate)
1218
+ if match:
1219
+ i = int(match.group(1))
1220
+ return [registry.get_id(f"{prefix}.$b[{10+i}]")]
1221
+
1222
+ match = re.search(r'\.not_exp_a(\d+)', gate)
1223
+ if match:
1224
+ i = int(match.group(1))
1225
+ return [registry.get_id(f"{prefix}.$a[{10+i}]")]
1226
+
1227
+ for i in range(5):
1228
+ registry.register(f"{prefix}.not_exp_b{i}")
1229
+ registry.register(f"{prefix}.not_exp_a{i}")
1230
+
1231
+ # Exp diff subtractors (diff_ab and diff_ba)
1232
+ if '.diff_ab.fa' in gate or '.diff_ba.fa' in gate:
1233
+ is_ab = '.diff_ab' in gate
1234
+ match = re.search(r'\.fa(\d+)\.', gate)
1235
+ if match:
1236
+ i = int(match.group(1))
1237
+ fa_prefix = f"{prefix}.diff_{'ab' if is_ab else 'ba'}.fa{i}"
1238
+
1239
+ if is_ab:
1240
+ a_bit = registry.get_id(f"{prefix}.$a[{10+i}]")
1241
+ not_b = registry.get_id(f"{prefix}.not_exp_b{i}")
1242
+ else:
1243
+ a_bit = registry.get_id(f"{prefix}.$b[{10+i}]")
1244
+ not_b = registry.get_id(f"{prefix}.not_exp_a{i}")
1245
+
1246
+ if i == 0:
1247
+ cin = registry.get_id("#1")
1248
+ else:
1249
+ cin = registry.register(f"{prefix}.diff_{'ab' if is_ab else 'ba'}.fa{i-1}.cout")
1250
+
1251
+ if '.xor1.layer1' in gate:
1252
+ return [a_bit, not_b]
1253
+ if '.xor1.layer2' in gate:
1254
+ return [registry.register(f"{fa_prefix}.xor1.layer1.or"),
1255
+ registry.register(f"{fa_prefix}.xor1.layer1.nand")]
1256
+
1257
+ xor1 = registry.register(f"{fa_prefix}.xor1.layer2")
1258
+
1259
+ if '.xor2.layer1' in gate:
1260
+ return [xor1, cin]
1261
+ if '.xor2.layer2' in gate:
1262
+ return [registry.register(f"{fa_prefix}.xor2.layer1.or"),
1263
+ registry.register(f"{fa_prefix}.xor2.layer1.nand")]
1264
+
1265
+ if '.and1' in gate:
1266
+ return [a_bit, not_b]
1267
+ if '.and2' in gate:
1268
+ return [xor1, cin]
1269
+ if '.cout' in gate:
1270
+ return [registry.register(f"{fa_prefix}.and1"),
1271
+ registry.register(f"{fa_prefix}.and2")]
1272
+
1273
+ # Register diff outputs
1274
+ for i in range(5):
1275
+ registry.register(f"{prefix}.diff_ab.fa{i}.xor2.layer2")
1276
+ registry.register(f"{prefix}.diff_ba.fa{i}.xor2.layer2")
1277
+
1278
+ # Exp diff mux
1279
+ match = re.search(r'\.exp_diff_mux(\d+)\.', gate)
1280
+ if match:
1281
+ i = int(match.group(1))
1282
+ if '.and_ab' in gate:
1283
+ return [registry.get_id(f"{prefix}.diff_ab.fa{i}.xor2.layer2"),
1284
+ registry.get_id(f"{prefix}.a_exp_ge_b")]
1285
+ if '.and_ba' in gate:
1286
+ return [registry.get_id(f"{prefix}.diff_ba.fa{i}.xor2.layer2"),
1287
+ registry.get_id(f"{prefix}.b_exp_gt_a_sel")]
1288
+
1289
+ match = re.search(r'\.exp_diff(\d+)$', gate)
1290
+ if match:
1291
+ i = int(match.group(1))
1292
+ return [registry.register(f"{prefix}.exp_diff_mux{i}.and_ab"),
1293
+ registry.register(f"{prefix}.exp_diff_mux{i}.and_ba")]
1294
+
1295
+ for i in range(5):
1296
+ registry.register(f"{prefix}.exp_diff{i}")
1297
+
1298
+ # Exp larger mux
1299
+ match = re.search(r'\.exp_larger_mux(\d+)\.', gate)
1300
+ if match:
1301
+ i = int(match.group(1))
1302
+ if '.and_a' in gate:
1303
+ return [registry.get_id(f"{prefix}.$a[{10+i}]"),
1304
+ registry.get_id(f"{prefix}.a_exp_ge_b")]
1305
+ if '.and_b' in gate:
1306
+ return [registry.get_id(f"{prefix}.$b[{10+i}]"),
1307
+ registry.get_id(f"{prefix}.b_exp_gt_a_sel")]
1308
+
1309
+ match = re.search(r'\.exp_larger(\d+)$', gate)
1310
+ if match:
1311
+ i = int(match.group(1))
1312
+ return [registry.register(f"{prefix}.exp_larger_mux{i}.and_a"),
1313
+ registry.register(f"{prefix}.exp_larger_mux{i}.and_b")]
1314
+
1315
+ for i in range(5):
1316
+ registry.register(f"{prefix}.exp_larger{i}")
1317
+
1318
+ # Mantissa source selection (which mantissa to shift)
1319
+ # mant_shift_src = a_exp_ge_b ? mant_b : mant_a
1320
+ # mant_larger = a_exp_ge_b ? mant_a : mant_b
1321
+ match = re.search(r'\.mant_shift_src(\d+)\.', gate)
1322
+ if match:
1323
+ i = int(match.group(1))
1324
+ if i < 10:
1325
+ mant_a = registry.get_id(f"{prefix}.$a[{i}]")
1326
+ mant_b = registry.get_id(f"{prefix}.$b[{i}]")
1327
+ else:
1328
+ mant_a = registry.get_id(f"{prefix}.implicit_a")
1329
+ mant_b = registry.get_id(f"{prefix}.implicit_b")
1330
+ if '.and_b' in gate:
1331
+ return [mant_b, registry.get_id(f"{prefix}.a_exp_ge_b")]
1332
+ if '.and_a' in gate:
1333
+ return [mant_a, registry.get_id(f"{prefix}.b_exp_gt_a_sel")]
1334
+
1335
+ match = re.search(r'\.mant_shift_src(\d+)$', gate)
1336
+ if match:
1337
+ i = int(match.group(1))
1338
+ return [registry.register(f"{prefix}.mant_shift_src{i}.and_b"),
1339
+ registry.register(f"{prefix}.mant_shift_src{i}.and_a")]
1340
+
1341
+ match = re.search(r'\.mant_larger(\d+)\.', gate)
1342
+ if match:
1343
+ i = int(match.group(1))
1344
+ if i < 10:
1345
+ mant_a = registry.get_id(f"{prefix}.$a[{i}]")
1346
+ mant_b = registry.get_id(f"{prefix}.$b[{i}]")
1347
+ else:
1348
+ mant_a = registry.get_id(f"{prefix}.implicit_a")
1349
+ mant_b = registry.get_id(f"{prefix}.implicit_b")
1350
+ if '.and_a' in gate:
1351
+ return [mant_a, registry.get_id(f"{prefix}.a_exp_ge_b")]
1352
+ if '.and_b' in gate:
1353
+ return [mant_b, registry.get_id(f"{prefix}.b_exp_gt_a_sel")]
1354
+
1355
+ match = re.search(r'\.mant_larger(\d+)$', gate)
1356
+ if match:
1357
+ i = int(match.group(1))
1358
+ return [registry.register(f"{prefix}.mant_larger{i}.and_a"),
1359
+ registry.register(f"{prefix}.mant_larger{i}.and_b")]
1360
+
1361
+ for i in range(11):
1362
+ registry.register(f"{prefix}.mant_shift_src{i}")
1363
+ registry.register(f"{prefix}.mant_larger{i}")
1364
+
1365
+ # NOT gates for exp_diff bits (barrel shifter control)
1366
+ for i in range(5):
1367
+ if f'.not_exp_diff{i}' in gate and f'.not_exp_diff{i}.' not in gate:
1368
+ return [registry.get_id(f"{prefix}.exp_diff{i}")]
1369
+ registry.register(f"{prefix}.not_exp_diff{i}")
1370
+
1371
+ # Barrel shifter stage 0 (shift by 1)
1372
+ match = re.search(r'\.shift_s0_(\d+)\.', gate)
1373
+ if match:
1374
+ i = int(match.group(1))
1375
+ if '.pass' in gate:
1376
+ return [registry.get_id(f"{prefix}.mant_shift_src{i}"),
1377
+ registry.get_id(f"{prefix}.not_exp_diff0")]
1378
+ if '.shift' in gate and i < 10:
1379
+ return [registry.get_id(f"{prefix}.mant_shift_src{i+1}"),
1380
+ registry.get_id(f"{prefix}.exp_diff0")]
1381
+
1382
+ match = re.search(r'\.shift_s0_(\d+)$', gate)
1383
+ if match:
1384
+ i = int(match.group(1))
1385
+ if i < 10:
1386
+ return [registry.register(f"{prefix}.shift_s0_{i}.pass"),
1387
+ registry.register(f"{prefix}.shift_s0_{i}.shift")]
1388
+ else:
1389
+ return [registry.register(f"{prefix}.shift_s0_{i}.pass")]
1390
+
1391
+ for i in range(11):
1392
+ registry.register(f"{prefix}.shift_s0_{i}")
1393
+
1394
+ # Barrel shifter stage 1 (shift by 2)
1395
+ match = re.search(r'\.shift_s1_(\d+)\.', gate)
1396
+ if match:
1397
+ i = int(match.group(1))
1398
+ if '.pass' in gate:
1399
+ return [registry.get_id(f"{prefix}.shift_s0_{i}"),
1400
+ registry.get_id(f"{prefix}.not_exp_diff1")]
1401
+ if '.shift' in gate and i < 9:
1402
+ return [registry.get_id(f"{prefix}.shift_s0_{i+2}"),
1403
+ registry.get_id(f"{prefix}.exp_diff1")]
1404
+
1405
+ match = re.search(r'\.shift_s1_(\d+)$', gate)
1406
+ if match:
1407
+ i = int(match.group(1))
1408
+ if i < 9:
1409
+ return [registry.register(f"{prefix}.shift_s1_{i}.pass"),
1410
+ registry.register(f"{prefix}.shift_s1_{i}.shift")]
1411
+ else:
1412
+ return [registry.register(f"{prefix}.shift_s1_{i}.pass")]
1413
+
1414
+ for i in range(11):
1415
+ registry.register(f"{prefix}.shift_s1_{i}")
1416
+
1417
+ # Barrel shifter stage 2 (shift by 4)
1418
+ match = re.search(r'\.shift_s2_(\d+)\.', gate)
1419
+ if match:
1420
+ i = int(match.group(1))
1421
+ if '.pass' in gate:
1422
+ return [registry.get_id(f"{prefix}.shift_s1_{i}"),
1423
+ registry.get_id(f"{prefix}.not_exp_diff2")]
1424
+ if '.shift' in gate and i < 7:
1425
+ return [registry.get_id(f"{prefix}.shift_s1_{i+4}"),
1426
+ registry.get_id(f"{prefix}.exp_diff2")]
1427
+
1428
+ match = re.search(r'\.shift_s2_(\d+)$', gate)
1429
+ if match:
1430
+ i = int(match.group(1))
1431
+ if i < 7:
1432
+ return [registry.register(f"{prefix}.shift_s2_{i}.pass"),
1433
+ registry.register(f"{prefix}.shift_s2_{i}.shift")]
1434
+ else:
1435
+ return [registry.register(f"{prefix}.shift_s2_{i}.pass")]
1436
+
1437
+ for i in range(11):
1438
+ registry.register(f"{prefix}.shift_s2_{i}")
1439
+
1440
+ # Barrel shifter stage 3 (shift by 8)
1441
+ match = re.search(r'\.shift_s3_(\d+)\.', gate)
1442
+ if match:
1443
+ i = int(match.group(1))
1444
+ if '.pass' in gate:
1445
+ return [registry.get_id(f"{prefix}.shift_s2_{i}"),
1446
+ registry.get_id(f"{prefix}.not_exp_diff3")]
1447
+ if '.shift' in gate and i < 3:
1448
+ return [registry.get_id(f"{prefix}.shift_s2_{i+8}"),
1449
+ registry.get_id(f"{prefix}.exp_diff3")]
1450
+
1451
+ match = re.search(r'\.shift_s3_(\d+)$', gate)
1452
+ if match:
1453
+ i = int(match.group(1))
1454
+ if i < 3:
1455
+ return [registry.register(f"{prefix}.shift_s3_{i}.pass"),
1456
+ registry.register(f"{prefix}.shift_s3_{i}.shift")]
1457
+ else:
1458
+ return [registry.register(f"{prefix}.shift_s3_{i}.pass")]
1459
+
1460
+ for i in range(11):
1461
+ registry.register(f"{prefix}.shift_s3_{i}")
1462
+
1463
+ # mant_aligned (masked by not_exp_diff4)
1464
+ match = re.search(r'\.mant_aligned(\d+)$', gate)
1465
+ if match:
1466
+ i = int(match.group(1))
1467
+ return [registry.get_id(f"{prefix}.shift_s3_{i}"),
1468
+ registry.get_id(f"{prefix}.not_exp_diff4")]
1469
+
1470
+ for i in range(11):
1471
+ registry.register(f"{prefix}.mant_aligned{i}")
1472
+
1473
+ # signs_same = NOT signs_differ
1474
+ if '.signs_same' in gate:
1475
+ return [registry.get_id(f"{prefix}.signs_differ.layer2")]
1476
+
1477
+ registry.register(f"{prefix}.signs_same")
1478
+
1479
+ # Mantissa comparison (for equal exponent case)
1480
+ if '.mant_a_ge_b' in gate:
1481
+ mant_a_full = [registry.get_id(f"{prefix}.$a[{i}]") for i in range(10)] + \
1482
+ [registry.get_id(f"{prefix}.implicit_a")]
1483
+ mant_b_full = [registry.get_id(f"{prefix}.$b[{i}]") for i in range(10)] + \
1484
+ [registry.get_id(f"{prefix}.implicit_b")]
1485
+ return mant_a_full + mant_b_full
1486
+
1487
+ registry.register(f"{prefix}.mant_a_ge_b")
1488
+
1489
+ # NOT gates for mant_aligned (for subtraction)
1490
+ match = re.search(r'\.not_mant_aligned(\d+)$', gate)
1491
+ if match:
1492
+ i = int(match.group(1))
1493
+ return [registry.get_id(f"{prefix}.mant_aligned{i}")]
1494
+
1495
+ for i in range(11):
1496
+ registry.register(f"{prefix}.not_mant_aligned{i}")
1497
+
1498
+ # sub_cin = signs_differ
1499
+ if '.sub_cin' in gate:
1500
+ return [registry.get_id(f"{prefix}.signs_differ.layer2")]
1501
+
1502
+ registry.register(f"{prefix}.sub_cin")
1503
+
1504
+ # addsub_b selection
1505
+ match = re.search(r'\.addsub_b(\d+)\.', gate)
1506
+ if match:
1507
+ i = int(match.group(1))
1508
+ if '.add' in gate:
1509
+ return [registry.get_id(f"{prefix}.mant_aligned{i}"),
1510
+ registry.get_id(f"{prefix}.signs_same")]
1511
+ if '.sub' in gate:
1512
+ return [registry.get_id(f"{prefix}.not_mant_aligned{i}"),
1513
+ registry.get_id(f"{prefix}.signs_differ.layer2")]
1514
+
1515
+ match = re.search(r'\.addsub_b(\d+)$', gate)
1516
+ if match:
1517
+ i = int(match.group(1))
1518
+ return [registry.register(f"{prefix}.addsub_b{i}.add"),
1519
+ registry.register(f"{prefix}.addsub_b{i}.sub")]
1520
+
1521
+ for i in range(11):
1522
+ registry.register(f"{prefix}.addsub_b{i}")
1523
+
1524
+ # 12-bit mantissa adder
1525
+ if '.mant_add.fa' in gate:
1526
+ match = re.search(r'\.mant_add\.fa(\d+)\.', gate)
1527
+ if match:
1528
+ i = int(match.group(1))
1529
+ fa_prefix = f"{prefix}.mant_add.fa{i}"
1530
+
1531
+ if i < 11:
1532
+ a_bit = registry.get_id(f"{prefix}.mant_larger{i}")
1533
+ b_bit = registry.get_id(f"{prefix}.addsub_b{i}")
1534
+ else:
1535
+ a_bit = registry.get_id("#0")
1536
+ b_bit = registry.get_id("#0")
1537
+
1538
+ if i == 0:
1539
+ cin = registry.get_id(f"{prefix}.sub_cin")
1540
+ else:
1541
+ cin = registry.register(f"{prefix}.mant_add.fa{i-1}.cout")
1542
+
1543
+ if '.xor1.layer1' in gate:
1544
+ return [a_bit, b_bit]
1545
+ if '.xor1.layer2' in gate:
1546
+ return [registry.register(f"{fa_prefix}.xor1.layer1.or"),
1547
+ registry.register(f"{fa_prefix}.xor1.layer1.nand")]
1548
+
1549
+ xor1 = registry.register(f"{fa_prefix}.xor1.layer2")
1550
+
1551
+ if '.xor2.layer1' in gate:
1552
+ return [xor1, cin]
1553
+ if '.xor2.layer2' in gate:
1554
+ return [registry.register(f"{fa_prefix}.xor2.layer1.or"),
1555
+ registry.register(f"{fa_prefix}.xor2.layer1.nand")]
1556
+
1557
+ if '.and1' in gate:
1558
+ return [a_bit, b_bit]
1559
+ if '.and2' in gate:
1560
+ return [xor1, cin]
1561
+ if '.cout' in gate:
1562
+ return [registry.register(f"{fa_prefix}.and1"),
1563
+ registry.register(f"{fa_prefix}.and2")]
1564
+
1565
+ for i in range(12):
1566
+ registry.register(f"{prefix}.mant_add.fa{i}.xor2.layer2")
1567
+ registry.register(f"{prefix}.mant_add.fa{i}.cout")
1568
+
1569
+ # Result sign determination
1570
+ if '.not_a_exp_gt_b' in gate:
1571
+ return [registry.get_id(f"{prefix}.a_exp_gt_b")]
1572
+
1573
+ registry.register(f"{prefix}.not_a_exp_gt_b")
1574
+
1575
+ if '.exp_a_eq_b' in gate:
1576
+ return [registry.get_id(f"{prefix}.not_a_exp_gt_b"),
1577
+ registry.get_id(f"{prefix}.b_exp_gt_a_sel")]
1578
+
1579
+ registry.register(f"{prefix}.exp_a_eq_b")
1580
+
1581
+ if '.exp_eq_and_mant_a_ge' in gate:
1582
+ return [registry.get_id(f"{prefix}.exp_a_eq_b"),
1583
+ registry.get_id(f"{prefix}.mant_a_ge_b")]
1584
+
1585
+ registry.register(f"{prefix}.exp_eq_and_mant_a_ge")
1586
+
1587
+ if '.a_magnitude_ge_b' in gate:
1588
+ return [registry.get_id(f"{prefix}.a_exp_gt_b"),
1589
+ registry.get_id(f"{prefix}.exp_eq_and_mant_a_ge")]
1590
+
1591
+ registry.register(f"{prefix}.a_magnitude_ge_b")
1592
+
1593
+ if '.not_a_mag_ge_b' in gate:
1594
+ return [registry.get_id(f"{prefix}.a_magnitude_ge_b")]
1595
+
1596
+ registry.register(f"{prefix}.not_a_mag_ge_b")
1597
+
1598
+ if '.diff_sign_sel_a' in gate:
1599
+ return [registry.get_id(f"{prefix}.sign_a"),
1600
+ registry.get_id(f"{prefix}.a_magnitude_ge_b")]
1601
+
1602
+ if '.diff_sign_sel_b' in gate:
1603
+ return [registry.get_id(f"{prefix}.sign_b"),
1604
+ registry.get_id(f"{prefix}.not_a_mag_ge_b")]
1605
+
1606
+ registry.register(f"{prefix}.diff_sign_sel_a")
1607
+ registry.register(f"{prefix}.diff_sign_sel_b")
1608
+
1609
+ if '.diff_result_sign' in gate:
1610
+ return [registry.get_id(f"{prefix}.diff_sign_sel_a"),
1611
+ registry.get_id(f"{prefix}.diff_sign_sel_b")]
1612
+
1613
+ registry.register(f"{prefix}.diff_result_sign")
1614
+
1615
+ if '.result_sign_same' in gate:
1616
+ return [registry.get_id(f"{prefix}.sign_a"),
1617
+ registry.get_id(f"{prefix}.signs_same")]
1618
+
1619
+ if '.result_sign_diff' in gate:
1620
+ return [registry.get_id(f"{prefix}.diff_result_sign"),
1621
+ registry.get_id(f"{prefix}.signs_differ.layer2")]
1622
+
1623
+ registry.register(f"{prefix}.result_sign_same")
1624
+ registry.register(f"{prefix}.result_sign_diff")
1625
+
1626
+ if gate == f"{prefix}.result_sign":
1627
+ return [registry.get_id(f"{prefix}.result_sign_same"),
1628
+ registry.get_id(f"{prefix}.result_sign_diff")]
1629
+
1630
+ registry.register(f"{prefix}.result_sign")
1631
+
1632
+ # Normalization - sum overflow (bit 11 of sum, not carry out)
1633
+ if '.sum_overflow' in gate:
1634
+ return [registry.get_id(f"{prefix}.mant_add.fa11.xor2.layer2")]
1635
+
1636
+ registry.register(f"{prefix}.sum_overflow")
1637
+
1638
+ # CLZ on bits 10:0 of sum for normalization (11 bits, not 12)
1639
+ sum_bits = [f"{prefix}.mant_add.fa{i}.xor2.layer2" for i in range(11)]
1640
+
1641
+ match = re.search(r'\.sum_pz(\d+)$', gate)
1642
+ if match:
1643
+ k = int(match.group(1))
1644
+ # Check bits 10, 9, 8, ... (from MSB to LSB of 11-bit sum)
1645
+ return [registry.get_id(sum_bits[10-i]) for i in range(k)]
1646
+
1647
+ for k in range(1, 12):
1648
+ registry.register(f"{prefix}.sum_pz{k}")
1649
+
1650
+ pz_ids = [registry.get_id(f"{prefix}.sum_pz{k}") for k in range(1, 12)]
1651
+
1652
+ match = re.search(r'\.sum_ge(\d+)$', gate)
1653
+ if match:
1654
+ return pz_ids
1655
+
1656
+ for k in range(1, 12):
1657
+ registry.register(f"{prefix}.sum_ge{k}")
1658
+
1659
+ match = re.search(r'\.sum_not_ge(\d+)$', gate)
1660
+ if match:
1661
+ k = int(match.group(1))
1662
+ return [registry.get_id(f"{prefix}.sum_ge{k}")]
1663
+
1664
+ for k in [2, 4, 6, 8, 10]:
1665
+ registry.register(f"{prefix}.sum_not_ge{k}")
1666
+
1667
+ if '.norm_shift3' in gate:
1668
+ return [registry.get_id(f"{prefix}.sum_ge8")]
1669
+
1670
+ if '.norm_and_4_7' in gate:
1671
+ return [registry.get_id(f"{prefix}.sum_ge4"),
1672
+ registry.get_id(f"{prefix}.sum_not_ge8")]
1673
+
1674
+ registry.register(f"{prefix}.norm_and_4_7")
1675
+
1676
+ # For 11-bit CLZ (max 11), shift2 = norm_and_4_7 only
1677
+ if '.norm_shift2' in gate:
1678
+ return [registry.get_id(f"{prefix}.norm_and_4_7")]
1679
+
1680
+ if '.norm_and_2_3' in gate:
1681
+ return [registry.get_id(f"{prefix}.sum_ge2"),
1682
+ registry.get_id(f"{prefix}.sum_not_ge4")]
1683
+ if '.norm_and_6_7' in gate:
1684
+ return [registry.get_id(f"{prefix}.sum_ge6"),
1685
+ registry.get_id(f"{prefix}.sum_not_ge8")]
1686
+ # For 11-bit CLZ (max 11), ge10 is sufficient (CLZ 10 or 11)
1687
+ if '.norm_and_10_11' in gate:
1688
+ return [registry.get_id(f"{prefix}.sum_ge10")]
1689
+
1690
+ registry.register(f"{prefix}.norm_and_2_3")
1691
+ registry.register(f"{prefix}.norm_and_6_7")
1692
+ registry.register(f"{prefix}.norm_and_10_11")
1693
+
1694
+ if '.norm_shift1' in gate:
1695
+ return [registry.get_id(f"{prefix}.norm_and_2_3"),
1696
+ registry.get_id(f"{prefix}.norm_and_6_7"),
1697
+ registry.get_id(f"{prefix}.norm_and_10_11")]
1698
+
1699
+ match = re.search(r'\.norm_and_(\d+)$', gate)
1700
+ if match:
1701
+ i = int(match.group(1))
1702
+ if i in [1, 3, 5, 7, 9]:
1703
+ return [registry.get_id(f"{prefix}.sum_ge{i}"),
1704
+ registry.get_id(f"{prefix}.sum_not_ge{i+1}")]
1705
+
1706
+ for i in [1, 3, 5, 7, 9]:
1707
+ registry.register(f"{prefix}.norm_and_{i}")
1708
+
1709
+ if '.norm_shift0' in gate:
1710
+ return [registry.get_id(f"{prefix}.norm_and_{i}") for i in [1, 3, 5, 7, 9]]
1711
+
1712
+ for i in range(4):
1713
+ registry.register(f"{prefix}.norm_shift{i}")
1714
+
1715
+     # Stage 10: Normalization application
+     if '.not_sum_overflow' in gate:
+         return [registry.get_id(f"{prefix}.sum_overflow")]
+ 
+     registry.register(f"{prefix}.not_sum_overflow")
+ 
+     # Overflow mantissa (right-shift by 1)
+     match = re.search(r'\.norm_mant_overflow(\d+)$', gate)
+     if match:
+         i = int(match.group(1))
+         return [registry.get_id(f"{prefix}.mant_add.fa{i+1}.xor2.layer2")]
+ 
+     for i in range(10):
+         registry.register(f"{prefix}.norm_mant_overflow{i}")
+ 
+     # Left barrel shifter NOT gates
+     for i in range(4):
+         if f'.not_norm_shift{i}' in gate and '.not_norm_shift_sub' not in gate:
+             return [registry.get_id(f"{prefix}.norm_shift{i}")]
+         registry.register(f"{prefix}.not_norm_shift{i}")
+ 
+     # Left barrel shifter stage 0
+     match = re.search(r'\.lshift_s0_(\d+)\.', gate)
+     if match:
+         i = int(match.group(1))
+         if '.pass' in gate:
+             return [registry.get_id(f"{prefix}.mant_add.fa{i}.xor2.layer2"),
+                     registry.get_id(f"{prefix}.not_norm_shift0")]
+         if '.shift' in gate and i > 0:
+             return [registry.get_id(f"{prefix}.mant_add.fa{i-1}.xor2.layer2"),
+                     registry.get_id(f"{prefix}.norm_shift0")]
+ 
+     match = re.search(r'\.lshift_s0_(\d+)$', gate)
+     if match:
+         i = int(match.group(1))
+         if i > 0:
+             return [registry.register(f"{prefix}.lshift_s0_{i}.pass"),
+                     registry.register(f"{prefix}.lshift_s0_{i}.shift")]
+         else:
+             return [registry.register(f"{prefix}.lshift_s0_{i}.pass")]
+ 
+     for i in range(11):
+         registry.register(f"{prefix}.lshift_s0_{i}")
+ 
+     # Left barrel shifter stage 1
+     match = re.search(r'\.lshift_s1_(\d+)\.', gate)
+     if match:
+         i = int(match.group(1))
+         if '.pass' in gate:
+             return [registry.get_id(f"{prefix}.lshift_s0_{i}"),
+                     registry.get_id(f"{prefix}.not_norm_shift1")]
+         if '.shift' in gate and i > 1:
+             return [registry.get_id(f"{prefix}.lshift_s0_{i-2}"),
+                     registry.get_id(f"{prefix}.norm_shift1")]
+ 
+     match = re.search(r'\.lshift_s1_(\d+)$', gate)
+     if match:
+         i = int(match.group(1))
+         if i > 1:
+             return [registry.register(f"{prefix}.lshift_s1_{i}.pass"),
+                     registry.register(f"{prefix}.lshift_s1_{i}.shift")]
+         else:
+             return [registry.register(f"{prefix}.lshift_s1_{i}.pass")]
+ 
+     for i in range(11):
+         registry.register(f"{prefix}.lshift_s1_{i}")
+ 
+     # Left barrel shifter stage 2
+     match = re.search(r'\.lshift_s2_(\d+)\.', gate)
+     if match:
+         i = int(match.group(1))
+         if '.pass' in gate:
+             return [registry.get_id(f"{prefix}.lshift_s1_{i}"),
+                     registry.get_id(f"{prefix}.not_norm_shift2")]
+         if '.shift' in gate and i > 3:
+             return [registry.get_id(f"{prefix}.lshift_s1_{i-4}"),
+                     registry.get_id(f"{prefix}.norm_shift2")]
+ 
+     match = re.search(r'\.lshift_s2_(\d+)$', gate)
+     if match:
+         i = int(match.group(1))
+         if i > 3:
+             return [registry.register(f"{prefix}.lshift_s2_{i}.pass"),
+                     registry.register(f"{prefix}.lshift_s2_{i}.shift")]
+         else:
+             return [registry.register(f"{prefix}.lshift_s2_{i}.pass")]
+ 
+     for i in range(11):
+         registry.register(f"{prefix}.lshift_s2_{i}")
+ 
+     # Left barrel shifter stage 3
+     match = re.search(r'\.lshift_s3_(\d+)\.', gate)
+     if match:
+         i = int(match.group(1))
+         if '.pass' in gate:
+             return [registry.get_id(f"{prefix}.lshift_s2_{i}"),
+                     registry.get_id(f"{prefix}.not_norm_shift3")]
+         if '.shift' in gate and i > 7:
+             return [registry.get_id(f"{prefix}.lshift_s2_{i-8}"),
+                     registry.get_id(f"{prefix}.norm_shift3")]
+ 
+     match = re.search(r'\.lshift_s3_(\d+)$', gate)
+     if match:
+         i = int(match.group(1))
+         if i > 7:
+             return [registry.register(f"{prefix}.lshift_s3_{i}.pass"),
+                     registry.register(f"{prefix}.lshift_s3_{i}.shift")]
+         else:
+             return [registry.register(f"{prefix}.lshift_s3_{i}.pass")]
+ 
+     for i in range(11):
+         registry.register(f"{prefix}.lshift_s3_{i}")
+ 
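The `lshift_s0..s3` muxes are a standard four-stage barrel shifter. A behavioral sketch of the same wiring, on little-endian bit lists (function name hypothetical):

```python
def barrel_shift_left(bits, shift):
    """11-bit left shift via four mux stages (shift by 1, 2, 4, 8),
    zero-filling from the low end, as lshift_s0..s3 do."""
    for stage, amount in enumerate((1, 2, 4, 8)):
        if (shift >> stage) & 1:
            bits = [bits[i - amount] if i >= amount else 0
                    for i in range(len(bits))]
    return bits
```

Each stage either passes its input through or takes the bit `amount` positions lower, which is exactly the pass/shift gate pair per output bit.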
+     # Normalized mantissa selection
+     match = re.search(r'\.norm_mant(\d+)\.', gate)
+     if match:
+         i = int(match.group(1))
+         if '.overflow_path' in gate:
+             return [registry.get_id(f"{prefix}.norm_mant_overflow{i}"),
+                     registry.get_id(f"{prefix}.sum_overflow")]
+         if '.normal_path' in gate:
+             return [registry.get_id(f"{prefix}.lshift_s3_{i}"),
+                     registry.get_id(f"{prefix}.not_sum_overflow")]
+ 
+     match = re.search(r'\.norm_mant(\d+)$', gate)
+     if match:
+         i = int(match.group(1))
+         return [registry.register(f"{prefix}.norm_mant{i}.overflow_path"),
+                 registry.register(f"{prefix}.norm_mant{i}.normal_path")]
+ 
+     for i in range(10):
+         registry.register(f"{prefix}.norm_mant{i}")
+ 
+     # Exponent increment (for overflow)
+     if '.exp_inc.ha0.sum' in gate:
+         return [registry.get_id(f"{prefix}.exp_larger0")]
+     if '.exp_inc.ha0.cout' in gate:
+         return [registry.get_id(f"{prefix}.exp_larger0")]
+ 
+     registry.register(f"{prefix}.exp_inc.ha0.sum")
+     registry.register(f"{prefix}.exp_inc.ha0.cout")
+ 
+     for i in range(1, 5):
+         if f'.exp_inc.ha{i}.xor.layer1' in gate:
+             return [registry.get_id(f"{prefix}.exp_larger{i}"),
+                     registry.get_id(f"{prefix}.exp_inc.ha{i-1}.cout")]
+         if f'.exp_inc.ha{i}.sum' in gate:
+             return [registry.register(f"{prefix}.exp_inc.ha{i}.xor.layer1.or"),
+                     registry.register(f"{prefix}.exp_inc.ha{i}.xor.layer1.nand")]
+         if f'.exp_inc.ha{i}.cout' in gate:
+             return [registry.get_id(f"{prefix}.exp_larger{i}"),
+                     registry.get_id(f"{prefix}.exp_inc.ha{i-1}.cout")]
+         registry.register(f"{prefix}.exp_inc.ha{i}.sum")
+         registry.register(f"{prefix}.exp_inc.ha{i}.cout")
+ 
+     # Exponent decrement NOT gates
+     for i in range(4):
+         if f'.not_norm_shift_sub{i}' in gate:
+             return [registry.get_id(f"{prefix}.norm_shift{i}")]
+         registry.register(f"{prefix}.not_norm_shift_sub{i}")
+ 
+     # Exponent decrement (for no overflow)
+     if '.exp_dec.fa' in gate:
+         match = re.search(r'\.exp_dec\.fa(\d+)\.', gate)
+         if match:
+             i = int(match.group(1))
+             fa_prefix = f"{prefix}.exp_dec.fa{i}"
+ 
+             exp_bit = registry.get_id(f"{prefix}.exp_larger{i}")
+             if i < 4:
+                 not_shift = registry.get_id(f"{prefix}.not_norm_shift_sub{i}")
+             else:
+                 not_shift = registry.get_id("#1")
+ 
+             if i == 0:
+                 cin = registry.get_id("#1")
+             else:
+                 cin = registry.register(f"{prefix}.exp_dec.fa{i-1}.cout")
+ 
+             if '.xor1.layer1' in gate:
+                 return [exp_bit, not_shift]
+             if '.xor1.layer2' in gate:
+                 return [registry.register(f"{fa_prefix}.xor1.layer1.or"),
+                         registry.register(f"{fa_prefix}.xor1.layer1.nand")]
+ 
+             xor1 = registry.register(f"{fa_prefix}.xor1.layer2")
+ 
+             if '.xor2.layer1' in gate:
+                 return [xor1, cin]
+             if '.xor2.layer2' in gate:
+                 return [registry.register(f"{fa_prefix}.xor2.layer1.or"),
+                         registry.register(f"{fa_prefix}.xor2.layer1.nand")]
+ 
+             if '.and1' in gate:
+                 return [exp_bit, not_shift]
+             if '.and2' in gate:
+                 return [xor1, cin]
+             if '.cout' in gate:
+                 return [registry.register(f"{fa_prefix}.and1"),
+                         registry.register(f"{fa_prefix}.and2")]
+ 
+     for i in range(5):
+         registry.register(f"{prefix}.exp_dec.fa{i}.xor2.layer2")
+         registry.register(f"{prefix}.exp_dec.fa{i}.cout")
+ 
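Behaviorally, `exp_inc` and `exp_dec` together produce the result exponent: +1 when the mantissa sum carried out, otherwise minus the normalization shift, computed as `exp + NOT(shift) + 1` over 5 bits (with bit 4 of the NOT tied to the `#1` constant, since the shift is only 4 bits). A sketch under those assumptions (function name hypothetical):

```python
def adjust_exponent(exp_larger, overflow, norm_shift):
    """5-bit result exponent: increment on mantissa carry-out, else
    subtract the left-normalization shift via two's complement."""
    if overflow:
        return (exp_larger + 1) & 0x1F            # exp_inc half-adder chain
    # exp_dec: exp + NOT(shift) + 1 (mod 32); norm_shift is 4 bits wide
    return (exp_larger + ((~norm_shift) & 0x1F) + 1) & 0x1F
```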
+     # Result exponent selection
+     match = re.search(r'\.result_exp(\d+)\.', gate)
+     if match:
+         i = int(match.group(1))
+         if '.overflow_path' in gate:
+             if i == 0:
+                 return [registry.get_id(f"{prefix}.exp_inc.ha0.sum"),
+                         registry.get_id(f"{prefix}.sum_overflow")]
+             else:
+                 return [registry.get_id(f"{prefix}.exp_inc.ha{i}.sum"),
+                         registry.get_id(f"{prefix}.sum_overflow")]
+         if '.normal_path' in gate:
+             return [registry.get_id(f"{prefix}.exp_dec.fa{i}.xor2.layer2"),
+                     registry.get_id(f"{prefix}.not_sum_overflow")]
+ 
+     match = re.search(r'\.result_exp(\d+)$', gate)
+     if match:
+         i = int(match.group(1))
+         return [registry.register(f"{prefix}.result_exp{i}.overflow_path"),
+                 registry.register(f"{prefix}.result_exp{i}.normal_path")]
+ 
+     for i in range(5):
+         registry.register(f"{prefix}.result_exp{i}")
+ 
+     # Output assembly
+     if '.not_result_is_inf' in gate:
+         return [registry.get_id(f"{prefix}.result_is_inf")]
+ 
+     registry.register(f"{prefix}.not_result_is_inf")
+     registry.register(f"{prefix}.result_is_inf")
+ 
+     if '.is_normal_result' in gate:
+         return [registry.get_id(f"{prefix}.not_result_is_nan"),
+                 registry.get_id(f"{prefix}.not_result_is_inf")]
+ 
+     registry.register(f"{prefix}.is_normal_result")
+ 
+     # Inf sign selection
+     if '.inf_sign_sel_a' in gate:
+         return [registry.get_id(f"{prefix}.sign_a"),
+                 registry.get_id(f"{prefix}.a_is_inf")]
+     if '.inf_sign_sel_b' in gate:
+         return [registry.get_id(f"{prefix}.sign_b"),
+                 registry.get_id(f"{prefix}.b_is_inf")]
+ 
+     registry.register(f"{prefix}.inf_sign_sel_a")
+     registry.register(f"{prefix}.inf_sign_sel_b")
+ 
+     if '.inf_sign' in gate and '.inf_sign_sel' not in gate:
+         return [registry.get_id(f"{prefix}.inf_sign_sel_a"),
+                 registry.get_id(f"{prefix}.inf_sign_sel_b")]
+ 
+     registry.register(f"{prefix}.inf_sign")
+ 
+     # NaN bits
+     nan_bits = [0]*9 + [1] + [1]*5 + [0]
+     match = re.search(r'\.out_nan(\d+)$', gate)
+     if match:
+         return [registry.get_id(f"{prefix}.result_is_nan")]
+ 
+     # Inf bits
+     match = re.search(r'\.out_inf(\d+)$', gate)
+     if match:
+         return [registry.get_id(f"{prefix}.result_is_inf")]
+ 
+     # Normal output path
+     match = re.search(r'\.out_normal(\d+)$', gate)
+     if match:
+         i = int(match.group(1))
+         if i == 15:
+             return [registry.get_id(f"{prefix}.result_sign")]
+         elif i >= 10:
+             return [registry.get_id(f"{prefix}.result_exp{i-10}")]
+         else:
+             return [registry.get_id(f"{prefix}.norm_mant{i}")]
+ 
+     for i in range(16):
+         registry.register(f"{prefix}.out_normal{i}")
+ 
+     # Final output gates
+     match = re.search(r'\.out(\d+)\.(nan_gate|inf_gate|normal_gate)$', gate)
+     if match:
+         i = int(match.group(1))
+         gate_type = match.group(2)
+         if gate_type == 'nan_gate':
+             nan_val = registry.register(f"{prefix}.out_nan{i}") if nan_bits[i] else registry.get_id("#0")
+             return [nan_val, registry.get_id(f"{prefix}.result_is_nan")]
+         elif gate_type == 'inf_gate':
+             if i >= 10 and i < 15:
+                 inf_val = registry.register(f"{prefix}.out_inf{i}")
+             elif i == 15:
+                 inf_val = registry.get_id(f"{prefix}.inf_sign")
+             else:
+                 inf_val = registry.get_id("#0")
+             return [inf_val, registry.get_id(f"{prefix}.result_is_inf")]
+         elif gate_type == 'normal_gate':
+             return [registry.get_id(f"{prefix}.out_normal{i}"),
+                     registry.get_id(f"{prefix}.is_normal_result")]
+ 
+     match = re.search(r'\.out(\d+)$', gate)
+     if match:
+         i = int(match.group(1))
+         return [registry.register(f"{prefix}.out{i}.nan_gate"),
+                 registry.register(f"{prefix}.out{i}.inf_gate"),
+                 registry.register(f"{prefix}.out{i}.normal_gate")]
+ 
+     return []
+ 
+ 
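Read as a truth table, the output assembly is a one-hot mux over three 16-bit patterns: the quiet-NaN constant, a signed infinity, and the normal packed result. A behavioral sketch of what the `out{i}.*` gates compute (function and argument names hypothetical; bits little-endian, bit 15 = sign):

```python
def assemble_output(normal_bits, inf_sign, is_nan, is_inf):
    """Select quiet-NaN pattern, signed infinity, or the normal result."""
    nan_bits = [0]*9 + [1] + [1]*5 + [0]    # mantissa MSB set, exp = 31
    out = []
    for i in range(16):
        nan_v = nan_bits[i] and is_nan
        if i == 15:
            inf_v = inf_sign and is_inf      # sign comes from inf_sign mux
        else:
            inf_v = (10 <= i < 15) and is_inf  # exponent bits all ones
        norm_v = normal_bits[i] and not is_nan and not is_inf
        out.append(int(bool(nan_v or inf_v or norm_v)))
    return out
```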
  def infer_float16_neg_inputs(gate: str, registry: SignalRegistry) -> List[int]:
      """Infer inputs for float16.neg circuit."""
      prefix = "float16.neg"
 
      return tensors
 
 
+ def build_float16_add_tensors() -> Dict[str, torch.Tensor]:
+     """Build tensors for float16.add circuit.
+ 
+     IEEE 754 half-precision addition with full special case handling:
+     1. Detect special cases (NaN, infinity, zero, subnormal)
+     2. Extract sign, exponent, mantissa from both operands
+     3. Add implicit bit (1 for normal, 0 for subnormal)
+     4. Compare exponents to find which is larger
+     5. Align mantissas by shifting smaller exponent's mantissa right
+     6. Add or subtract mantissas based on signs
+     7. Normalize result and adjust exponent
+     8. Handle overflow (to infinity) and underflow (to zero/subnormal)
+     9. Pack result with correct special case outputs
+ 
+     Inputs: $a[0:15], $b[0:15] (two float16 values)
+     Outputs: out[0:15] (float16 result)
+     """
+     tensors = {}
+     prefix = "float16.add"
+ 
+     # =========================================================================
+     # STAGE 0: SPECIAL CASE DETECTION
+     # =========================================================================
+     # Detect NaN, infinity, zero, and subnormal inputs.
+     # float16 encoding:
+     #   - Zero:      exp=0,  mant=0
+     #   - Subnormal: exp=0,  mant≠0
+     #   - Normal:    0 < exp < 31
+     #   - Infinity:  exp=31, mant=0
+     #   - NaN:       exp=31, mant≠0
+ 
+     # exp_a_all_ones: all 5 exponent bits are 1 (exp = 31)
+     # Threshold gate: sum of exp bits >= 5
+     tensors[f"{prefix}.exp_a_all_ones.weight"] = torch.tensor([1.0] * 5)
+     tensors[f"{prefix}.exp_a_all_ones.bias"] = torch.tensor([-5.0])
+ 
+     tensors[f"{prefix}.exp_b_all_ones.weight"] = torch.tensor([1.0] * 5)
+     tensors[f"{prefix}.exp_b_all_ones.bias"] = torch.tensor([-5.0])
+ 
+     # exp_a_zero: all 5 exponent bits are 0 (NOR gate)
+     tensors[f"{prefix}.exp_a_zero.weight"] = torch.tensor([-1.0] * 5)
+     tensors[f"{prefix}.exp_a_zero.bias"] = torch.tensor([0.0])
+ 
+     tensors[f"{prefix}.exp_b_zero.weight"] = torch.tensor([-1.0] * 5)
+     tensors[f"{prefix}.exp_b_zero.bias"] = torch.tensor([0.0])
+ 
+     # mant_a_nonzero: OR of all 10 mantissa bits
+     tensors[f"{prefix}.mant_a_nonzero.weight"] = torch.tensor([1.0] * 10)
+     tensors[f"{prefix}.mant_a_nonzero.bias"] = torch.tensor([-1.0])
+ 
+     tensors[f"{prefix}.mant_b_nonzero.weight"] = torch.tensor([1.0] * 10)
+     tensors[f"{prefix}.mant_b_nonzero.bias"] = torch.tensor([-1.0])
+ 
+     # mant_a_zero: NOR of all mantissa bits
+     tensors[f"{prefix}.mant_a_zero.weight"] = torch.tensor([-1.0] * 10)
+     tensors[f"{prefix}.mant_a_zero.bias"] = torch.tensor([0.0])
+ 
+     tensors[f"{prefix}.mant_b_zero.weight"] = torch.tensor([-1.0] * 10)
+     tensors[f"{prefix}.mant_b_zero.bias"] = torch.tensor([0.0])
+ 
+     # a_is_nan: exp_a_all_ones AND mant_a_nonzero
+     tensors[f"{prefix}.a_is_nan.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.a_is_nan.bias"] = torch.tensor([-2.0])
+ 
+     tensors[f"{prefix}.b_is_nan.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.b_is_nan.bias"] = torch.tensor([-2.0])
+ 
+     # a_is_inf: exp_a_all_ones AND mant_a_zero
+     tensors[f"{prefix}.a_is_inf.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.a_is_inf.bias"] = torch.tensor([-2.0])
+ 
+     tensors[f"{prefix}.b_is_inf.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.b_is_inf.bias"] = torch.tensor([-2.0])
+ 
+     # a_is_zero: exp_a_zero AND mant_a_zero
+     tensors[f"{prefix}.a_is_zero.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.a_is_zero.bias"] = torch.tensor([-2.0])
+ 
+     tensors[f"{prefix}.b_is_zero.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.b_is_zero.bias"] = torch.tensor([-2.0])
+ 
+     # a_is_subnormal: exp_a_zero AND mant_a_nonzero
+     tensors[f"{prefix}.a_is_subnormal.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.a_is_subnormal.bias"] = torch.tensor([-2.0])
+ 
+     tensors[f"{prefix}.b_is_subnormal.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.b_is_subnormal.bias"] = torch.tensor([-2.0])
+ 
+     # either_is_nan: a_is_nan OR b_is_nan
+     tensors[f"{prefix}.either_is_nan.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.either_is_nan.bias"] = torch.tensor([-1.0])
+ 
+     # both_are_inf: a_is_inf AND b_is_inf
+     tensors[f"{prefix}.both_are_inf.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.both_are_inf.bias"] = torch.tensor([-2.0])
+ 
+     # signs_differ: sign_a XOR sign_b (for inf + (-inf) = NaN case)
+     # XOR layer 1
+     tensors[f"{prefix}.signs_differ.layer1.or.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.signs_differ.layer1.or.bias"] = torch.tensor([-1.0])
+     tensors[f"{prefix}.signs_differ.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
+     tensors[f"{prefix}.signs_differ.layer1.nand.bias"] = torch.tensor([1.0])
+     tensors[f"{prefix}.signs_differ.layer2.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.signs_differ.layer2.bias"] = torch.tensor([-2.0])
+ 
+     # inf_cancellation: both_are_inf AND signs_differ (produces NaN)
+     tensors[f"{prefix}.inf_cancellation.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.inf_cancellation.bias"] = torch.tensor([-2.0])
+ 
+     # result_is_nan: either_is_nan OR inf_cancellation
+     tensors[f"{prefix}.result_is_nan.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.result_is_nan.bias"] = torch.tensor([-1.0])
+ 
+     # either_is_inf: a_is_inf OR b_is_inf
+     tensors[f"{prefix}.either_is_inf.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.either_is_inf.bias"] = torch.tensor([-1.0])
+ 
+     # NOT result_is_nan (for masking inf result)
+     tensors[f"{prefix}.not_result_is_nan.weight"] = torch.tensor([-1.0])
+     tensors[f"{prefix}.not_result_is_nan.bias"] = torch.tensor([0.0])
+ 
+     # result_is_inf: either_is_inf AND NOT result_is_nan
+     tensors[f"{prefix}.result_is_inf.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.result_is_inf.bias"] = torch.tensor([-2.0])
+ 
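The stage-0 gates classify each operand from its raw bits. The same predicates in plain Python, for cross-checking the gate logic (function name hypothetical):

```python
def classify_f16(bits16: int):
    """Return (is_zero, is_subnormal, is_inf, is_nan) for a float16
    bit pattern, matching the stage-0 detector gates."""
    exp = (bits16 >> 10) & 0x1F
    mant = bits16 & 0x3FF
    return (exp == 0 and mant == 0,
            exp == 0 and mant != 0,
            exp == 31 and mant == 0,
            exp == 31 and mant != 0)
```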
+     # =========================================================================
+     # STAGE 1: EXTRACT COMPONENTS
+     # =========================================================================
+     # sign_a = a[15], sign_b = b[15]
+     # exp_a[0:4] = a[10:14], exp_b[0:4] = b[10:14]
+     # mant_a[0:9] = a[0:9], mant_b[0:9] = b[0:9]
+ 
+     # Pass-through gates for sign extraction
+     tensors[f"{prefix}.sign_a.weight"] = torch.tensor([1.0])
+     tensors[f"{prefix}.sign_a.bias"] = torch.tensor([-0.5])
+ 
+     tensors[f"{prefix}.sign_b.weight"] = torch.tensor([1.0])
+     tensors[f"{prefix}.sign_b.bias"] = torch.tensor([-0.5])
+ 
+     # Implicit bit calculation:
+     # For normal numbers, implicit bit = 1
+     # For subnormal numbers (and zero), implicit bit = 0
+     # implicit_a = NOT exp_a_zero (exp=0 means there is no implicit 1)
+     tensors[f"{prefix}.implicit_a.weight"] = torch.tensor([-1.0])
+     tensors[f"{prefix}.implicit_a.bias"] = torch.tensor([0.0])
+ 
+     tensors[f"{prefix}.implicit_b.weight"] = torch.tensor([-1.0])
+     tensors[f"{prefix}.implicit_b.bias"] = torch.tensor([0.0])
+ 
+     # =========================================================================
+     # STAGE 2: EXPONENT COMPARISON
+     # =========================================================================
+     # Compare exp_a vs exp_b using weighted comparison
+     # Weights: bit[i] contributes 2^i to the total
+     # exp_a >= exp_b when weighted(exp_a) - weighted(exp_b) >= 0
+ 
+     weights_exp_a = [float(2**i) for i in range(5)]   # +1, +2, +4, +8, +16
+     weights_exp_b = [-float(2**i) for i in range(5)]  # -1, -2, -4, -8, -16
+ 
+     # a_exp_ge_b: exp_a >= exp_b
+     tensors[f"{prefix}.a_exp_ge_b.weight"] = torch.tensor(weights_exp_a + weights_exp_b)
+     tensors[f"{prefix}.a_exp_ge_b.bias"] = torch.tensor([0.0])  # >= (not strict >)
+ 
+     # a_exp_gt_b: exp_a > exp_b (for strict comparison)
+     tensors[f"{prefix}.a_exp_gt_b.weight"] = torch.tensor(weights_exp_a + weights_exp_b)
+     tensors[f"{prefix}.a_exp_gt_b.bias"] = torch.tensor([-0.5])  # strict >
+ 
+     # b_exp_gt_a: exp_b > exp_a -- same weights as a_exp_gt_b; the input
+     # inference wires exp_b into the positive slots and exp_a into the
+     # negative slots (the operands are swapped, not the weights)
+     tensors[f"{prefix}.b_exp_gt_a.weight"] = torch.tensor(weights_exp_a + weights_exp_b)
+     tensors[f"{prefix}.b_exp_gt_a.bias"] = torch.tensor([-0.5])
+ 
+     # NOT of a_exp_ge_b (for selecting which path)
+     tensors[f"{prefix}.b_exp_gt_a_sel.weight"] = torch.tensor([-1.0])
+     tensors[f"{prefix}.b_exp_gt_a_sel.bias"] = torch.tensor([0.0])
+ 
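A single threshold gate can compare two 5-bit exponents directly: weight +2^i on one operand's bits, -2^i on the other's, then step at the bias. A torch sketch of the `a_exp_ge_b` / `a_exp_gt_b` gates, assuming the activation fires at pre-activation >= 0 (function name hypothetical):

```python
import torch

def exp_compare(exp_a_bits, exp_b_bits):
    """Return (a >= b, a > b) for two little-endian 5-bit exponents,
    using the same weights and biases as a_exp_ge_b / a_exp_gt_b."""
    w = torch.tensor([2.0**i for i in range(5)] +
                     [-(2.0**i) for i in range(5)])
    x = torch.tensor([float(b) for b in list(exp_a_bits) + list(exp_b_bits)])
    pre = torch.dot(w, x).item()
    step = lambda z: int(z >= 0)
    return step(pre + 0.0), step(pre - 0.5)   # biases 0.0 and -0.5
```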
+     # =========================================================================
+     # STAGE 3: COMPUTE EXPONENT DIFFERENCE
+     # =========================================================================
+     # We need |exp_a - exp_b| for the shift amount.
+     # Use 5-bit subtractors: exp_a - exp_b and exp_b - exp_a
+     # Then select based on which exponent is larger.
+ 
+     # 5-bit subtractor for exp_a - exp_b (using two's complement)
+     # NOT gates for exp_b
+     for i in range(5):
+         tensors[f"{prefix}.not_exp_b{i}.weight"] = torch.tensor([-1.0])
+         tensors[f"{prefix}.not_exp_b{i}.bias"] = torch.tensor([0.0])
+ 
+     # Full adders for exp_a + NOT(exp_b) + 1 = exp_a - exp_b
+     # FA0: bit 0
+     # XOR1: exp_a[0] XOR not_exp_b[0]
+     tensors[f"{prefix}.diff_ab.fa0.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.diff_ab.fa0.xor1.layer1.or.bias"] = torch.tensor([-1.0])
+     tensors[f"{prefix}.diff_ab.fa0.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
+     tensors[f"{prefix}.diff_ab.fa0.xor1.layer1.nand.bias"] = torch.tensor([1.0])
+     tensors[f"{prefix}.diff_ab.fa0.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.diff_ab.fa0.xor1.layer2.bias"] = torch.tensor([-2.0])
+ 
+     # XOR2: xor1 XOR cin (cin=1 for subtraction)
+     tensors[f"{prefix}.diff_ab.fa0.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.diff_ab.fa0.xor2.layer1.or.bias"] = torch.tensor([-1.0])
+     tensors[f"{prefix}.diff_ab.fa0.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
+     tensors[f"{prefix}.diff_ab.fa0.xor2.layer1.nand.bias"] = torch.tensor([1.0])
+     tensors[f"{prefix}.diff_ab.fa0.xor2.layer2.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.diff_ab.fa0.xor2.layer2.bias"] = torch.tensor([-2.0])
+ 
+     # Carry: (a AND b) OR (xor1 AND cin)
+     tensors[f"{prefix}.diff_ab.fa0.and1.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.diff_ab.fa0.and1.bias"] = torch.tensor([-2.0])
+     tensors[f"{prefix}.diff_ab.fa0.and2.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.diff_ab.fa0.and2.bias"] = torch.tensor([-2.0])
+     tensors[f"{prefix}.diff_ab.fa0.cout.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.diff_ab.fa0.cout.bias"] = torch.tensor([-1.0])
+ 
+     # FA1-FA4: remaining bits (carry chain)
+     for i in range(1, 5):
+         p = f"{prefix}.diff_ab.fa{i}"
+ 
+         # XOR1: exp_a[i] XOR not_exp_b[i]
+         tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0])
+         tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
+         tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0])
+         tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0])
+ 
+         # XOR2: xor1 XOR carry_in
+         tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0])
+         tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
+         tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0])
+         tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0])
+ 
+         # Carry
+         tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{p}.and1.bias"] = torch.tensor([-2.0])
+         tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{p}.and2.bias"] = torch.tensor([-2.0])
+         tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
+ 
+     # Similarly for exp_b - exp_a
+     # NOT gates for exp_a
+     for i in range(5):
+         tensors[f"{prefix}.not_exp_a{i}.weight"] = torch.tensor([-1.0])
+         tensors[f"{prefix}.not_exp_a{i}.bias"] = torch.tensor([0.0])
+ 
+     # Full adders for exp_b + NOT(exp_a) + 1 = exp_b - exp_a
+     for i in range(5):
+         p = f"{prefix}.diff_ba.fa{i}"
+ 
+         tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0])
+         tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
+         tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0])
+         tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0])
+ 
+         tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0])
+         tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
+         tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0])
+         tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0])
+ 
+         tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{p}.and1.bias"] = torch.tensor([-2.0])
+         tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{p}.and2.bias"] = torch.tensor([-2.0])
+         tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
+ 
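Every full adder in the two subtractors uses the same decomposition: XOR built as (OR, NAND) feeding an AND, and carry as OR of two ANDs. A sketch with the step activation assumed to fire at pre-activation >= 0 (function names hypothetical):

```python
def step(z):
    return int(z >= 0)

def xor_gate(a, b):
    """XOR in two threshold layers: layer1 = (OR, NAND), layer2 = AND."""
    or_ = step(a + b - 1.0)
    nand = step(-a - b + 1.0)
    return step(or_ + nand - 2.0)

def full_adder(a, b, cin):
    """sum = a XOR b XOR cin; cout = (a AND b) OR ((a XOR b) AND cin)."""
    x1 = xor_gate(a, b)
    s = xor_gate(x1, cin)
    and1 = step(a + b - 2.0)
    and2 = step(x1 + cin - 2.0)
    return s, step(and1 + and2 - 1.0)
```

Chaining five of these with `cin=1` on bit 0 and NOT gates on one operand gives the `diff_ab` / `diff_ba` subtractors.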
+     # =========================================================================
+     # STAGE 4: SELECT ABSOLUTE DIFFERENCE
+     # =========================================================================
+     # exp_diff = a_exp_ge_b ? (exp_a - exp_b) : (exp_b - exp_a)
+     # Use 2-to-1 mux for each bit
+ 
+     for i in range(5):
+         # Mux: out = (sel AND b) OR (NOT sel AND a)
+         # sel = b_exp_gt_a_sel (1 if b > a, meaning we want diff_ba)
+         # sel=0 (a>=b) -> use diff_ab, sel=1 (b>a) -> use diff_ba
+ 
+         # AND gate for diff_ab path (when a_exp_ge_b = 1)
+         tensors[f"{prefix}.exp_diff_mux{i}.and_ab.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{prefix}.exp_diff_mux{i}.and_ab.bias"] = torch.tensor([-2.0])
+ 
+         # AND gate for diff_ba path (when b_exp_gt_a_sel = 1, i.e., a_exp_ge_b = 0)
+         tensors[f"{prefix}.exp_diff_mux{i}.and_ba.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{prefix}.exp_diff_mux{i}.and_ba.bias"] = torch.tensor([-2.0])
+ 
+         # OR to combine
+         tensors[f"{prefix}.exp_diff{i}.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{prefix}.exp_diff{i}.bias"] = torch.tensor([-1.0])
+ 
+     # =========================================================================
+     # STAGE 5: SELECT LARGER EXPONENT (for result)
+     # =========================================================================
+     # exp_larger = a_exp_ge_b ? exp_a : exp_b
+ 
+     for i in range(5):
+         # AND gate for exp_a path
+         tensors[f"{prefix}.exp_larger_mux{i}.and_a.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{prefix}.exp_larger_mux{i}.and_a.bias"] = torch.tensor([-2.0])
+ 
+         # AND gate for exp_b path
+         tensors[f"{prefix}.exp_larger_mux{i}.and_b.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{prefix}.exp_larger_mux{i}.and_b.bias"] = torch.tensor([-2.0])
+ 
+         # OR to combine
+         tensors[f"{prefix}.exp_larger{i}.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{prefix}.exp_larger{i}.bias"] = torch.tensor([-1.0])
+ 
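Stages 4-6 reuse the same three-gate mux throughout: two ANDs against complementary selects, ORed together. A behavioral sketch, again assuming the step activation fires at pre-activation >= 0 (function name hypothetical):

```python
def step(z):
    return int(z >= 0)

def mux2(sel, when0, when1):
    """out = when1 if sel else when0, built from a NOT, two ANDs,
    and an OR, as in exp_diff_mux / exp_larger_mux / mant_shift_src."""
    not_sel = step(-sel + 0.0)               # NOT gate
    path0 = step(when0 + not_sel - 2.0)      # AND with NOT sel
    path1 = step(when1 + sel - 2.0)          # AND with sel
    return step(path0 + path1 - 1.0)         # OR combiner
```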
+     # =========================================================================
+     # STAGE 6: MANTISSA ALIGNMENT (Barrel Shifter)
+     # =========================================================================
+     # The smaller exponent's mantissa needs to be shifted right by exp_diff.
+     # Mantissa is 11 bits: implicit bit + 10 explicit mantissa bits.
+     #
+     # We need to:
+     # 1. Select which mantissa to shift (the one with the smaller exponent)
+     # 2. Shift it right by exp_diff positions
+     # 3. The larger mantissa passes through unchanged
+     #
+     # For the barrel shifter, we use cascaded 2-to-1 muxes:
+     # - Stage 0: shift by 0 or 1 (controlled by exp_diff[0])
+     # - Stage 1: shift by 0 or 2 (controlled by exp_diff[1])
+     # - Stage 2: shift by 0 or 4 (controlled by exp_diff[2])
+     # - Stage 3: shift by 0 or 8 (controlled by exp_diff[3])
+     #
+     # If exp_diff >= 11, the shifted mantissa becomes 0 (complete loss).
+ 
+     # First, select which mantissa gets shifted (the one with the smaller exponent)
+     # mant_to_shift = a_exp_ge_b ? mant_b : mant_a
+     # mant_larger = a_exp_ge_b ? mant_a : mant_b
+ 
+     # Full mantissa with implicit bit: 11 bits (bit 10 = implicit, bits 9-0 = explicit)
+     for i in range(11):
+         # mant_shift_src[i] = mux(a_exp_ge_b, mant_b[i], mant_a[i])
+         # When a_exp_ge_b=1, we shift b's mantissa (a has the larger exp)
+         # When a_exp_ge_b=0, we shift a's mantissa (b has the larger exp)
+ 
+         tensors[f"{prefix}.mant_shift_src{i}.and_b.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{prefix}.mant_shift_src{i}.and_b.bias"] = torch.tensor([-2.0])
+ 
+         tensors[f"{prefix}.mant_shift_src{i}.and_a.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{prefix}.mant_shift_src{i}.and_a.bias"] = torch.tensor([-2.0])
+ 
+         tensors[f"{prefix}.mant_shift_src{i}.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{prefix}.mant_shift_src{i}.bias"] = torch.tensor([-1.0])
+ 
+         # mant_larger[i] = mux(a_exp_ge_b, mant_a[i], mant_b[i])
+         tensors[f"{prefix}.mant_larger{i}.and_a.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{prefix}.mant_larger{i}.and_a.bias"] = torch.tensor([-2.0])
+ 
+         tensors[f"{prefix}.mant_larger{i}.and_b.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{prefix}.mant_larger{i}.and_b.bias"] = torch.tensor([-2.0])
+ 
+         tensors[f"{prefix}.mant_larger{i}.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{prefix}.mant_larger{i}.bias"] = torch.tensor([-1.0])
+ 
+     # Barrel shifter stages
+     # Stage 0: shift by 1 if exp_diff[0]=1
+     # NOT exp_diff[0] for the pass-through path
+     tensors[f"{prefix}.not_exp_diff0.weight"] = torch.tensor([-1.0])
+     tensors[f"{prefix}.not_exp_diff0.bias"] = torch.tensor([0.0])
+ 
+     for i in range(11):
+         # Output bit i comes from:
+         # - bit i if not shifting (exp_diff[0]=0)
+         # - bit i+1 if shifting (exp_diff[0]=1), or 0 if i+1 >= 11
+         tensors[f"{prefix}.shift_s0_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{prefix}.shift_s0_{i}.pass.bias"] = torch.tensor([-2.0])
+ 
+         if i < 10:
+             tensors[f"{prefix}.shift_s0_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
+             tensors[f"{prefix}.shift_s0_{i}.shift.bias"] = torch.tensor([-2.0])
+             tensors[f"{prefix}.shift_s0_{i}.weight"] = torch.tensor([1.0, 1.0])
+             tensors[f"{prefix}.shift_s0_{i}.bias"] = torch.tensor([-1.0])
+         else:
+             # MSB: shift path is 0, so just pass-through when not shifting
+             tensors[f"{prefix}.shift_s0_{i}.weight"] = torch.tensor([1.0])
+             tensors[f"{prefix}.shift_s0_{i}.bias"] = torch.tensor([-1.0])
+ 
+     # Stage 1: shift by 2 if exp_diff[1]=1
+     tensors[f"{prefix}.not_exp_diff1.weight"] = torch.tensor([-1.0])
+     tensors[f"{prefix}.not_exp_diff1.bias"] = torch.tensor([0.0])
+ 
+     for i in range(11):
+         tensors[f"{prefix}.shift_s1_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{prefix}.shift_s1_{i}.pass.bias"] = torch.tensor([-2.0])
+ 
+         if i < 9:
+             tensors[f"{prefix}.shift_s1_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
+             tensors[f"{prefix}.shift_s1_{i}.shift.bias"] = torch.tensor([-2.0])
+             tensors[f"{prefix}.shift_s1_{i}.weight"] = torch.tensor([1.0, 1.0])
+             tensors[f"{prefix}.shift_s1_{i}.bias"] = torch.tensor([-1.0])
+         else:
+             tensors[f"{prefix}.shift_s1_{i}.weight"] = torch.tensor([1.0])
+             tensors[f"{prefix}.shift_s1_{i}.bias"] = torch.tensor([-1.0])
+ 
+     # Stage 2: shift by 4 if exp_diff[2]=1
+     tensors[f"{prefix}.not_exp_diff2.weight"] = torch.tensor([-1.0])
+     tensors[f"{prefix}.not_exp_diff2.bias"] = torch.tensor([0.0])
+ 
+     for i in range(11):
+         tensors[f"{prefix}.shift_s2_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
+         tensors[f"{prefix}.shift_s2_{i}.pass.bias"] = torch.tensor([-2.0])
+ 
+         if i < 7:
+             tensors[f"{prefix}.shift_s2_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
+             tensors[f"{prefix}.shift_s2_{i}.shift.bias"] = torch.tensor([-2.0])
+             tensors[f"{prefix}.shift_s2_{i}.weight"] = torch.tensor([1.0, 1.0])
+             tensors[f"{prefix}.shift_s2_{i}.bias"] = torch.tensor([-1.0])
+         else:
+             tensors[f"{prefix}.shift_s2_{i}.weight"] = torch.tensor([1.0])
+             tensors[f"{prefix}.shift_s2_{i}.bias"] = torch.tensor([-1.0])
+ 
+     # Stage 3: shift by 8 if exp_diff[3]=1
+     tensors[f"{prefix}.not_exp_diff3.weight"] = torch.tensor([-1.0])
+     tensors[f"{prefix}.not_exp_diff3.bias"] = torch.tensor([0.0])
+ 
3118
+ for i in range(11):
3119
+ tensors[f"{prefix}.shift_s3_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
3120
+ tensors[f"{prefix}.shift_s3_{i}.pass.bias"] = torch.tensor([-2.0])
3121
+
3122
+ if i < 3:
3123
+ tensors[f"{prefix}.shift_s3_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
3124
+ tensors[f"{prefix}.shift_s3_{i}.shift.bias"] = torch.tensor([-2.0])
3125
+ tensors[f"{prefix}.shift_s3_{i}.weight"] = torch.tensor([1.0, 1.0])
3126
+ else:
3127
+ tensors[f"{prefix}.shift_s3_{i}.weight"] = torch.tensor([1.0])
3128
+ tensors[f"{prefix}.shift_s3_{i}.bias"] = torch.tensor([-1.0])
3129
+
3130
+ # If exp_diff[4]=1 (shift by 16 or more), result is 0
3131
+ # mant_aligned = exp_diff[4] ? 0 : shift_s3 result
3132
+ tensors[f"{prefix}.not_exp_diff4.weight"] = torch.tensor([-1.0])
3133
+ tensors[f"{prefix}.not_exp_diff4.bias"] = torch.tensor([0.0])
3134
+
3135
+ for i in range(11):
3136
+ # Only pass through if exp_diff[4]=0
3137
+ tensors[f"{prefix}.mant_aligned{i}.weight"] = torch.tensor([1.0, 1.0])
3138
+ tensors[f"{prefix}.mant_aligned{i}.bias"] = torch.tensor([-2.0])
3139
+
3140
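The four mux stages above form a logarithmic right barrel shifter: stage k shifts by 2^k when exp_diff[k] is set, and exp_diff[4] zeroes the result entirely. A minimal software model of the same dataflow (plain Python, not the tensor encoding; `barrel_shift_right` is an illustrative name):

```python
def barrel_shift_right(mant: int, exp_diff: int) -> int:
    """Right-shift an 11-bit mantissa by exp_diff using mux stages."""
    assert 0 <= mant < (1 << 11)
    for k in range(4):                  # stages shift by 1, 2, 4, 8
        if (exp_diff >> k) & 1:
            mant >>= (1 << k)
    if (exp_diff >> 4) & 1:             # shift >= 16: everything falls off
        mant = 0
    return mant & 0x7FF

assert barrel_shift_right(0b10000000000, 3) == 0b00010000000
assert barrel_shift_right(0x7FF, 16) == 0
```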
+ # =========================================================================
+ # STAGE 7: MANTISSA ADDITION/SUBTRACTION
+ # =========================================================================
+ # If signs are the same: add mantissas
+ # If signs differ: subtract smaller from larger
+ #
+ # We have:
+ # - mant_larger[10:0]: mantissa of the larger exponent operand
+ # - mant_aligned[10:0]: shifted mantissa of the smaller exponent operand
+ #
+ # For subtraction, we need to know which mantissa is larger.
+ # If exp_a > exp_b, then mant_a is the reference (could be smaller mantissa value)
+ # If exp_a == exp_b, we need to compare mantissas directly.
+ #
+ # signs_same: NOT signs_differ
+ tensors[f"{prefix}.signs_same.weight"] = torch.tensor([-1.0])
+ tensors[f"{prefix}.signs_same.bias"] = torch.tensor([0.0])
+
+ # For the result sign when signs differ:
+ # If exp_a > exp_b: result sign = sign_a
+ # If exp_b > exp_a: result sign = sign_b
+ # If exp_a == exp_b: result sign = sign of larger mantissa
+
+ # Mantissa comparison (for equal exponent case)
+ # Compare mant_a vs mant_b when exponents are equal
+ weights_mant = [float(2**i) for i in range(11)]
+ neg_weights_mant = [-float(2**i) for i in range(11)]
+
+ tensors[f"{prefix}.mant_a_ge_b.weight"] = torch.tensor(weights_mant + neg_weights_mant)
+ tensors[f"{prefix}.mant_a_ge_b.bias"] = torch.tensor([0.0])
+
+ # 12-bit adder for mantissa sum (11 mantissa bits + 1 carry out)
+ # We'll compute mant_larger + mant_aligned (for same sign)
+ # or |mant_larger - mant_aligned| (for different signs)
+
+ # For subtraction, we need: larger_mant - smaller_mant
+ # If exponents differ, larger exp means larger value, so:
+ # result = mant_larger - mant_aligned
+ # If exponents equal, compare mantissas:
+ # result = |mant_a - mant_b|
+
+ # NOT gates for mant_aligned (for subtraction)
+ for i in range(11):
+ tensors[f"{prefix}.not_mant_aligned{i}.weight"] = torch.tensor([-1.0])
+ tensors[f"{prefix}.not_mant_aligned{i}.bias"] = torch.tensor([0.0])
+
+ # 12-bit adder/subtractor
+ # When signs_same=1: add (carry_in = 0)
+ # When signs_same=0: subtract (use NOT mant_aligned, carry_in = 1)
+
+ # Carry input selection: signs_same ? 0 : 1
+ # This is just NOT signs_same = signs_differ
+ tensors[f"{prefix}.sub_cin.weight"] = torch.tensor([1.0])
+ tensors[f"{prefix}.sub_cin.bias"] = torch.tensor([-0.5])
+
+ # Operand B selection: signs_same ? mant_aligned : NOT mant_aligned
+ for i in range(11):
+ # When adding (signs_same=1): use mant_aligned
+ tensors[f"{prefix}.addsub_b{i}.add.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.addsub_b{i}.add.bias"] = torch.tensor([-2.0])
+
+ # When subtracting (signs_same=0 = signs_differ=1): use NOT mant_aligned
+ tensors[f"{prefix}.addsub_b{i}.sub.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.addsub_b{i}.sub.bias"] = torch.tensor([-2.0])
+
+ tensors[f"{prefix}.addsub_b{i}.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.addsub_b{i}.bias"] = torch.tensor([-1.0])
+
+ # 12-bit ripple carry adder for mant_larger + addsub_b + sub_cin
+ for i in range(12):
+ p = f"{prefix}.mant_add.fa{i}"
+
+ tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0])
+ tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
+ tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0])
+ tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0])
+
+ tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0])
+ tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
+ tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0])
+ tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0])
+
+ tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{p}.and1.bias"] = torch.tensor([-2.0])
+ tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{p}.and2.bias"] = torch.tensor([-2.0])
+ tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
+
+ # =========================================================================
+ # STAGE 8: RESULT SIGN DETERMINATION
+ # =========================================================================
+ # When signs_same: result_sign = sign_a (= sign_b)
+ # When signs_differ:
+ # If a has larger magnitude: result_sign = sign_a
+ # If b has larger magnitude: result_sign = sign_b
+ #
+ # Magnitude comparison: consider both exponent and mantissa
+ # a_magnitude_ge_b: (exp_a > exp_b) OR (exp_a == exp_b AND mant_a >= mant_b)
+
+ # exp_a_eq_b: NOT a_exp_gt_b AND NOT b_exp_gt_a
+ tensors[f"{prefix}.not_a_exp_gt_b.weight"] = torch.tensor([-1.0])
+ tensors[f"{prefix}.not_a_exp_gt_b.bias"] = torch.tensor([0.0])
+
+ tensors[f"{prefix}.exp_a_eq_b.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.exp_a_eq_b.bias"] = torch.tensor([-2.0])
+
+ # exp_eq_and_mant_a_ge: exp_a_eq_b AND mant_a_ge_b
+ tensors[f"{prefix}.exp_eq_and_mant_a_ge.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.exp_eq_and_mant_a_ge.bias"] = torch.tensor([-2.0])
+
+ # a_magnitude_ge_b: a_exp_gt_b OR exp_eq_and_mant_a_ge
+ tensors[f"{prefix}.a_magnitude_ge_b.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.a_magnitude_ge_b.bias"] = torch.tensor([-1.0])
+
+ # result_sign when signs_differ:
+ # = a_magnitude_ge_b ? sign_a : sign_b
+ tensors[f"{prefix}.not_a_mag_ge_b.weight"] = torch.tensor([-1.0])
+ tensors[f"{prefix}.not_a_mag_ge_b.bias"] = torch.tensor([0.0])
+
+ tensors[f"{prefix}.diff_sign_sel_a.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.diff_sign_sel_a.bias"] = torch.tensor([-2.0])
+
+ tensors[f"{prefix}.diff_sign_sel_b.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.diff_sign_sel_b.bias"] = torch.tensor([-2.0])
+
+ tensors[f"{prefix}.diff_result_sign.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.diff_result_sign.bias"] = torch.tensor([-1.0])
+
+ # Final result sign: signs_same ? sign_a : diff_result_sign
+ tensors[f"{prefix}.result_sign_same.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.result_sign_same.bias"] = torch.tensor([-2.0])
+
+ tensors[f"{prefix}.result_sign_diff.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.result_sign_diff.bias"] = torch.tensor([-1.0])
+
+ tensors[f"{prefix}.result_sign.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.result_sign.bias"] = torch.tensor([-1.0])
+
+ # =========================================================================
+ # STAGE 9: NORMALIZATION
+ # =========================================================================
+ # The mantissa sum may need normalization:
+ # - If the sum overflowed into bit 11: right shift by 1, increment exponent
+ # - If leading bit is 0: left shift until leading 1 found, decrement exponent
+ #
+ # Use CLZ to find shift amount for left shift case.
+ # The sum is 12 bits (mant_add output).
+
+ # Overflow detection: bit 11 of the mantissa sum
+ tensors[f"{prefix}.sum_overflow.weight"] = torch.tensor([1.0])
+ tensors[f"{prefix}.sum_overflow.bias"] = torch.tensor([-0.5])
+
+ # CLZ on 11-bit sum (bits 10:0) to find normalization shift
+ # For the non-overflow case, count leading zeros starting from bit 10
+ # pz gates: prefix zero detectors on bits 10:0
+ for k in range(1, 12):
+ tensors[f"{prefix}.sum_pz{k}.weight"] = torch.tensor([-1.0] * k)
+ tensors[f"{prefix}.sum_pz{k}.bias"] = torch.tensor([0.0])
+
+ # ge gates: sum of pz >= k (for 11-bit CLZ, max is 11)
+ for k in range(1, 12):
+ tensors[f"{prefix}.sum_ge{k}.weight"] = torch.tensor([1.0] * 11)
+ tensors[f"{prefix}.sum_ge{k}.bias"] = torch.tensor([-float(k)])
+
+ # NOT gates for binary encoding
+ for k in [2, 4, 6, 8, 10]:
+ tensors[f"{prefix}.sum_not_ge{k}.weight"] = torch.tensor([-1.0])
+ tensors[f"{prefix}.sum_not_ge{k}.bias"] = torch.tensor([0.0])
+
+ # Shift amount encoding (4 bits for 0-11)
+ # CLZ of 11 bits can be 0-11
+ tensors[f"{prefix}.norm_shift3.weight"] = torch.tensor([1.0])
+ tensors[f"{prefix}.norm_shift3.bias"] = torch.tensor([-0.5]) # ge8
+
+ tensors[f"{prefix}.norm_and_4_7.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.norm_and_4_7.bias"] = torch.tensor([-2.0])
+ # For 11-bit CLZ (max 11), shift2 = ge4 AND NOT ge8 (no ge12 needed)
+ tensors[f"{prefix}.norm_shift2.weight"] = torch.tensor([1.0])
+ tensors[f"{prefix}.norm_shift2.bias"] = torch.tensor([-0.5])
+
+ tensors[f"{prefix}.norm_and_2_3.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.norm_and_2_3.bias"] = torch.tensor([-2.0])
+ tensors[f"{prefix}.norm_and_6_7.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.norm_and_6_7.bias"] = torch.tensor([-2.0])
+ # For 11-bit CLZ (max 11), ge10 means CLZ is 10 or 11, no need for NOT ge12
+ tensors[f"{prefix}.norm_and_10_11.weight"] = torch.tensor([1.0])
+ tensors[f"{prefix}.norm_and_10_11.bias"] = torch.tensor([-0.5])
+ tensors[f"{prefix}.norm_shift1.weight"] = torch.tensor([1.0, 1.0, 1.0])
+ tensors[f"{prefix}.norm_shift1.bias"] = torch.tensor([-1.0])
+
+ for i in [1, 3, 5, 7, 9]:
+ tensors[f"{prefix}.norm_and_{i}.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.norm_and_{i}.bias"] = torch.tensor([-2.0])
+ tensors[f"{prefix}.norm_shift0.weight"] = torch.tensor([1.0] * 5)
+ tensors[f"{prefix}.norm_shift0.bias"] = torch.tensor([-1.0])
+
+ # =========================================================================
+ # STAGE 10: APPLY NORMALIZATION TO MANTISSA
+ # =========================================================================
+ # Two cases:
+ # 1. Overflow (sum bit 11 set): right-shift mantissa by 1, increment exponent
+ # 2. No overflow: left-shift mantissa by norm_shift, decrement exponent
+
+ # NOT sum_overflow for non-overflow path
+ tensors[f"{prefix}.not_sum_overflow.weight"] = torch.tensor([-1.0])
+ tensors[f"{prefix}.not_sum_overflow.bias"] = torch.tensor([0.0])
+
+ # Overflow mantissa: bits 10:1 of adder_sum (right-shifted by 1)
+ # norm_mant_overflow[i] = adder_sum[i+1] for i in 0..9
+ for i in range(10):
+ tensors[f"{prefix}.norm_mant_overflow{i}.weight"] = torch.tensor([1.0])
+ tensors[f"{prefix}.norm_mant_overflow{i}.bias"] = torch.tensor([-0.5])
+
+ # Non-overflow mantissa: left-shift adder_sum[10:0] by norm_shift amount
+ # This requires a left barrel shifter on the 11-bit sum (bits 10:0)
+
+ # Left barrel shifter stage 0: shift left by 1 if norm_shift[0]=1
+ tensors[f"{prefix}.not_norm_shift0.weight"] = torch.tensor([-1.0])
+ tensors[f"{prefix}.not_norm_shift0.bias"] = torch.tensor([0.0])
+
+ for i in range(11):
+ tensors[f"{prefix}.lshift_s0_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.lshift_s0_{i}.pass.bias"] = torch.tensor([-2.0])
+ if i > 0:
+ tensors[f"{prefix}.lshift_s0_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.lshift_s0_{i}.shift.bias"] = torch.tensor([-2.0])
+ tensors[f"{prefix}.lshift_s0_{i}.weight"] = torch.tensor([1.0, 1.0])
+ else:
+ tensors[f"{prefix}.lshift_s0_{i}.weight"] = torch.tensor([1.0])
+ tensors[f"{prefix}.lshift_s0_{i}.bias"] = torch.tensor([-1.0])
+
+ # Left barrel shifter stage 1: shift left by 2 if norm_shift[1]=1
+ tensors[f"{prefix}.not_norm_shift1.weight"] = torch.tensor([-1.0])
+ tensors[f"{prefix}.not_norm_shift1.bias"] = torch.tensor([0.0])
+
+ for i in range(11):
+ tensors[f"{prefix}.lshift_s1_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.lshift_s1_{i}.pass.bias"] = torch.tensor([-2.0])
+ if i > 1:
+ tensors[f"{prefix}.lshift_s1_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.lshift_s1_{i}.shift.bias"] = torch.tensor([-2.0])
+ tensors[f"{prefix}.lshift_s1_{i}.weight"] = torch.tensor([1.0, 1.0])
+ else:
+ tensors[f"{prefix}.lshift_s1_{i}.weight"] = torch.tensor([1.0])
+ tensors[f"{prefix}.lshift_s1_{i}.bias"] = torch.tensor([-1.0])
+
+ # Left barrel shifter stage 2: shift left by 4 if norm_shift[2]=1
+ tensors[f"{prefix}.not_norm_shift2.weight"] = torch.tensor([-1.0])
+ tensors[f"{prefix}.not_norm_shift2.bias"] = torch.tensor([0.0])
+
+ for i in range(11):
+ tensors[f"{prefix}.lshift_s2_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.lshift_s2_{i}.pass.bias"] = torch.tensor([-2.0])
+ if i > 3:
+ tensors[f"{prefix}.lshift_s2_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.lshift_s2_{i}.shift.bias"] = torch.tensor([-2.0])
+ tensors[f"{prefix}.lshift_s2_{i}.weight"] = torch.tensor([1.0, 1.0])
+ else:
+ tensors[f"{prefix}.lshift_s2_{i}.weight"] = torch.tensor([1.0])
+ tensors[f"{prefix}.lshift_s2_{i}.bias"] = torch.tensor([-1.0])
+
+ # Left barrel shifter stage 3: shift left by 8 if norm_shift[3]=1
+ tensors[f"{prefix}.not_norm_shift3.weight"] = torch.tensor([-1.0])
+ tensors[f"{prefix}.not_norm_shift3.bias"] = torch.tensor([0.0])
+
+ for i in range(11):
+ tensors[f"{prefix}.lshift_s3_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.lshift_s3_{i}.pass.bias"] = torch.tensor([-2.0])
+ if i > 7:
+ tensors[f"{prefix}.lshift_s3_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.lshift_s3_{i}.shift.bias"] = torch.tensor([-2.0])
+ tensors[f"{prefix}.lshift_s3_{i}.weight"] = torch.tensor([1.0, 1.0])
+ else:
+ tensors[f"{prefix}.lshift_s3_{i}.weight"] = torch.tensor([1.0])
+ tensors[f"{prefix}.lshift_s3_{i}.bias"] = torch.tensor([-1.0])
+
+ # Select normalized mantissa: overflow ? overflow_mant : lshift result
+ # Take bits 9:0 for the output mantissa (bit 10 is implicit, dropped)
+ for i in range(10):
+ tensors[f"{prefix}.norm_mant{i}.overflow_path.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.norm_mant{i}.overflow_path.bias"] = torch.tensor([-2.0])
+ tensors[f"{prefix}.norm_mant{i}.normal_path.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.norm_mant{i}.normal_path.bias"] = torch.tensor([-2.0])
+ tensors[f"{prefix}.norm_mant{i}.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.norm_mant{i}.bias"] = torch.tensor([-1.0])
+
+ # =========================================================================
+ # STAGE 11: ADJUST EXPONENT
+ # =========================================================================
+ # Overflow: exp_result = exp_larger + 1
+ # No overflow: exp_result = exp_larger - norm_shift
+
+ # Increment exponent by 1 (for overflow case)
+ # Half adder chain: exp_larger + 1
+ tensors[f"{prefix}.exp_inc.ha0.sum.weight"] = torch.tensor([-1.0]) # NOT for XOR with 1
+ tensors[f"{prefix}.exp_inc.ha0.sum.bias"] = torch.tensor([0.0])
+ tensors[f"{prefix}.exp_inc.ha0.cout.weight"] = torch.tensor([1.0]) # AND with 1 = passthrough
+ tensors[f"{prefix}.exp_inc.ha0.cout.bias"] = torch.tensor([-0.5])
+
+ for i in range(1, 5):
+ # XOR: exp[i] XOR carry_in
+ tensors[f"{prefix}.exp_inc.ha{i}.xor.layer1.or.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.exp_inc.ha{i}.xor.layer1.or.bias"] = torch.tensor([-1.0])
+ tensors[f"{prefix}.exp_inc.ha{i}.xor.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
+ tensors[f"{prefix}.exp_inc.ha{i}.xor.layer1.nand.bias"] = torch.tensor([1.0])
+ tensors[f"{prefix}.exp_inc.ha{i}.sum.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.exp_inc.ha{i}.sum.bias"] = torch.tensor([-2.0])
+ # Carry: exp[i] AND carry_in
+ tensors[f"{prefix}.exp_inc.ha{i}.cout.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.exp_inc.ha{i}.cout.bias"] = torch.tensor([-2.0])
+
+ # Decrement exponent by norm_shift (for non-overflow case)
+ # 5-bit subtractor: exp_larger - norm_shift
+ # NOT gates for norm_shift
+ for i in range(4):
+ tensors[f"{prefix}.not_norm_shift_sub{i}.weight"] = torch.tensor([-1.0])
+ tensors[f"{prefix}.not_norm_shift_sub{i}.bias"] = torch.tensor([0.0])
+
+ # Full adders for exp_larger + NOT(norm_shift) + 1 = exp_larger - norm_shift
+ for i in range(5):
+ p = f"{prefix}.exp_dec.fa{i}"
+ tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{p}.xor1.layer1.or.bias"] = torch.tensor([-1.0])
+ tensors[f"{p}.xor1.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
+ tensors[f"{p}.xor1.layer1.nand.bias"] = torch.tensor([1.0])
+ tensors[f"{p}.xor1.layer2.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{p}.xor1.layer2.bias"] = torch.tensor([-2.0])
+
+ tensors[f"{p}.xor2.layer1.or.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{p}.xor2.layer1.or.bias"] = torch.tensor([-1.0])
+ tensors[f"{p}.xor2.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
+ tensors[f"{p}.xor2.layer1.nand.bias"] = torch.tensor([1.0])
+ tensors[f"{p}.xor2.layer2.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{p}.xor2.layer2.bias"] = torch.tensor([-2.0])
+
+ tensors[f"{p}.and1.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{p}.and1.bias"] = torch.tensor([-2.0])
+ tensors[f"{p}.and2.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{p}.and2.bias"] = torch.tensor([-2.0])
+ tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
+
+ # Select result exponent: overflow ? exp_inc : exp_dec
+ for i in range(5):
+ tensors[f"{prefix}.result_exp{i}.overflow_path.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.result_exp{i}.overflow_path.bias"] = torch.tensor([-2.0])
+ tensors[f"{prefix}.result_exp{i}.normal_path.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.result_exp{i}.normal_path.bias"] = torch.tensor([-2.0])
+ tensors[f"{prefix}.result_exp{i}.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.result_exp{i}.bias"] = torch.tensor([-1.0])
+
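Stage 11 reuses the adder trick from Stage 7: exp_larger - norm_shift is computed as exp_larger + NOT(norm_shift) + 1 over 5 bits, with norm_shift zero-extended. A quick model of that two's complement identity (illustrative helper name):

```python
def sub_via_complement(a: int, b: int, width: int = 5) -> int:
    """a - b modulo 2**width, computed as a + ~b + 1."""
    mask = (1 << width) - 1
    # Invert b within the word (the not_norm_shift_sub gates), add with cin=1.
    return (a + ((~b) & mask) + 1) & mask

assert sub_via_complement(0b10101, 0b00011) == 0b10010
assert sub_via_complement(3, 5) == (3 - 5) & 0b11111  # wraps on underflow
```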
+ # =========================================================================
+ # STAGE 12: OUTPUT ASSEMBLY
+ # =========================================================================
+ # Final output combines:
+ # - Special cases (NaN, Inf) override normal computation
+ # - For NaN: output canonical NaN (0x7E00)
+ # - For Inf: output Inf with correct sign
+ # - For normal: pack normalized result
+
+ # NaN output: 0x7E00 = 0111111000000000
+ nan_bits = [0]*9 + [1] + [1]*5 + [0] # bits 0-15
+
+ # Final output mux: nan ? nan_val : (inf ? inf_val : normal_val)
+ tensors[f"{prefix}.not_result_is_inf.weight"] = torch.tensor([-1.0])
+ tensors[f"{prefix}.not_result_is_inf.bias"] = torch.tensor([0.0])
+
+ # Normal case selector: NOT nan AND NOT inf
+ tensors[f"{prefix}.is_normal_result.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.is_normal_result.bias"] = torch.tensor([-2.0])
+
+ # Inf sign selection
+ tensors[f"{prefix}.inf_sign_sel_a.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.inf_sign_sel_a.bias"] = torch.tensor([-2.0])
+ tensors[f"{prefix}.inf_sign_sel_b.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.inf_sign_sel_b.bias"] = torch.tensor([-2.0])
+ tensors[f"{prefix}.inf_sign.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.inf_sign.bias"] = torch.tensor([-1.0])
+
+ for i in range(16):
+ # NaN path: output NaN bits gated by result_is_nan
+ if nan_bits[i]:
+ tensors[f"{prefix}.out_nan{i}.weight"] = torch.tensor([1.0])
+ tensors[f"{prefix}.out_nan{i}.bias"] = torch.tensor([-0.5])
+
+ # Inf path: exponent bits = 1, mantissa = 0, sign from inf operand
+ if i >= 10 and i < 15:
+ tensors[f"{prefix}.out_inf{i}.weight"] = torch.tensor([1.0])
+ tensors[f"{prefix}.out_inf{i}.bias"] = torch.tensor([-0.5])
+
+ # Normal path
+ if i < 10:
+ # Mantissa bits from norm_mant
+ tensors[f"{prefix}.out_normal{i}.weight"] = torch.tensor([1.0])
+ tensors[f"{prefix}.out_normal{i}.bias"] = torch.tensor([-0.5])
+ elif i < 15:
+ # Exponent bits from result_exp
+ tensors[f"{prefix}.out_normal{i}.weight"] = torch.tensor([1.0])
+ tensors[f"{prefix}.out_normal{i}.bias"] = torch.tensor([-0.5])
+ else:
+ # Sign bit from result_sign
+ tensors[f"{prefix}.out_normal{i}.weight"] = torch.tensor([1.0])
+ tensors[f"{prefix}.out_normal{i}.bias"] = torch.tensor([-0.5])
+
+ # Final output: 3-way mux (nan, inf, normal)
+ tensors[f"{prefix}.out{i}.nan_gate.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.out{i}.nan_gate.bias"] = torch.tensor([-2.0])
+ tensors[f"{prefix}.out{i}.inf_gate.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.out{i}.inf_gate.bias"] = torch.tensor([-2.0])
+ tensors[f"{prefix}.out{i}.normal_gate.weight"] = torch.tensor([1.0, 1.0])
+ tensors[f"{prefix}.out{i}.normal_gate.bias"] = torch.tensor([-2.0])
+ tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0, 1.0, 1.0])
+ tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-1.0])
+
+ return tensors
+
+
  def build_clz8bit_tensors() -> Dict[str, torch.Tensor]:
  """Build tensors for arithmetic.clz8bit circuit.


  print(f"Loaded {len(tensors)} tensors")

+ # Remove old float16.add tensors (we're rebuilding from scratch)
+ old_float16_add = [k for k in tensors.keys() if k.startswith('float16.add')]
+ for k in old_float16_add:
+ del tensors[k]
+ print(f"Removed {len(old_float16_add)} old float16.add tensors")
+
  # Build new circuits
  print("Building new circuits...")
  clz_tensors = build_clz8bit_tensors()

  tensors.update(abs_tensors)
  print(f" float16.abs: {len(abs_tensors)} tensors")

+ add_tensors = build_float16_add_tensors()
+ tensors.update(add_tensors)
+ print(f" float16.add: {len(add_tensors)} tensors")
+
  print(f"Total tensors: {len(tensors)}")

  # Load routing for complex circuits
eval.py CHANGED
@@ -632,6 +632,159 @@ class CircuitEvaluator:

  return TestResult('float16.abs', passed, len(test_values), failures)

+ def test_float16_add(self) -> TestResult:
+ """Test float16.add (IEEE 754 addition)."""
+ prefix = 'float16.add'
+ failures = []
+ passed = 0
+
+ import struct
+ import math
+
+ def float16_to_float(bits):
+ try:
+ return struct.unpack('e', struct.pack('H', bits))[0]
+ except (struct.error, OverflowError):
+ return float('nan')
+
+ def float_to_float16(f):
+ try:
+ return struct.unpack('H', struct.pack('e', f))[0]
+ except (struct.error, OverflowError):
+ return 0x7E00 # NaN
+
+ # Test cases: pairs of (a, b)
+ test_cases = [
+ # Zero cases
+ (0x0000, 0x0000), # +0 + +0 = +0
+ (0x0000, 0x3C00), # +0 + 1.0 = 1.0
+ (0x3C00, 0x0000), # 1.0 + +0 = 1.0
+
+ # Same sign addition
+ (0x3C00, 0x3C00), # 1.0 + 1.0 = 2.0
+ (0x4000, 0x3C00), # 2.0 + 1.0 = 3.0
+ (0x3800, 0x3800), # 0.5 + 0.5 = 1.0
+ (0x4200, 0x4000), # 3.0 + 2.0 = 5.0
+
+ # Different sign (subtraction)
+ (0x4000, 0xBC00), # 2.0 + (-1.0) = 1.0
+ (0x3C00, 0xBC00), # 1.0 + (-1.0) = 0.0
+ (0xBC00, 0x4000), # -1.0 + 2.0 = 1.0
+ (0xC000, 0x3C00), # -2.0 + 1.0 = -1.0
+
+ # Negative + negative
+ (0xBC00, 0xBC00), # -1.0 + -1.0 = -2.0
+ (0xC000, 0xBC00), # -2.0 + -1.0 = -3.0
+
+ # Different exponents
+ (0x4400, 0x3C00), # 4.0 + 1.0 = 5.0
+ (0x4800, 0x3C00), # 8.0 + 1.0 = 9.0
+ (0x3C00, 0x3400), # 1.0 + 0.25 = 1.25
+
+ # Infinity cases
+ (0x7C00, 0x3C00), # +inf + 1.0 = +inf
+ (0x3C00, 0x7C00), # 1.0 + +inf = +inf
+ (0xFC00, 0xBC00), # -inf + -1.0 = -inf
+ (0x7C00, 0xFC00), # +inf + -inf = NaN
+
+ # NaN cases
+ (0x7E00, 0x3C00), # NaN + 1.0 = NaN
+ (0x3C00, 0x7E00), # 1.0 + NaN = NaN
+ ]
+
+ # Add some random test cases
+ import random
+ random.seed(42)
+ for _ in range(50):
+ a = random.randint(0, 0x7BFF) # positive normal
+ b = random.randint(0, 0x7BFF)
+ test_cases.append((a, b))
+ # Some negative combinations
+ if random.random() > 0.5:
+ test_cases.append((a | 0x8000, b))
+ if random.random() > 0.5:
+ test_cases.append((a, b | 0x8000))
+
+ for a_bits, b_bits in test_cases:
+ a_float = float16_to_float(a_bits)
+ b_float = float16_to_float(b_bits)
+
+ # Expected result
+ if math.isnan(a_float) or math.isnan(b_float):
+ expected_nan = True
+ expected_inf = False
+ expected = 0x7E00
+ elif math.isinf(a_float) and math.isinf(b_float):
+ if (a_float > 0) != (b_float > 0):
+ expected_nan = True
+ expected_inf = False
+ expected = 0x7E00
+ else:
+ expected_nan = False
+ expected_inf = True
+ expected = 0x7C00 if a_float > 0 else 0xFC00
+ elif math.isinf(a_float):
+ expected_nan = False
+ expected_inf = True
+ expected = 0x7C00 if a_float > 0 else 0xFC00
+ elif math.isinf(b_float):
+ expected_nan = False
+ expected_inf = True
+ expected = 0x7C00 if b_float > 0 else 0xFC00
+ else:
+ expected_nan = False
+ expected_inf = False
+ result_float = a_float + b_float
+ expected = float_to_float16(result_float)
+
+ # Set up inputs
+ ext = {}
+ for i in range(16):
+ ext[f'{prefix}.$a[{i}]'] = float((a_bits >> i) & 1)
+ ext[f'{prefix}.$b[{i}]'] = float((b_bits >> i) & 1)
+
+ values = self.eval_circuit(prefix, ext)
+
+ # Extract result
+ result = 0
+ for i in range(16):
+ bit = int(values.get(f'{prefix}.out{i}', 0))
+ result |= (bit << i)
+
+ # Check special cases first
+ result_is_nan = int(values.get(f'{prefix}.result_is_nan', 0))
+ result_is_inf = int(values.get(f'{prefix}.result_is_inf', 0))
+
+ # For NaN, check that result_is_nan is set
+ if expected_nan:
+ if result_is_nan == 1:
+ passed += 1
+ else:
+ if len(failures) < 10:
+ failures.append((a_bits, b_bits, 'expected NaN', result, a_float, b_float))
+ # For Inf, check result_is_inf and sign
+ elif expected_inf:
+ expected_sign = (expected >> 15) & 1
+ result_sign = (result >> 15) & 1
+ if result_is_inf == 1 and result_sign == expected_sign:
+ passed += 1
+ else:
+ if len(failures) < 10:
+ failures.append((a_bits, b_bits, expected, result, a_float, b_float))
+ else:
+ # For normal results, accept exact matches or results within 1 ULP
+ if abs(result - expected) <= 1:
+ passed += 1
+ else:
+ if len(failures) < 10:
+ failures.append((a_bits, b_bits, expected, result, a_float, b_float))
+
+ return TestResult('float16.add', passed, len(test_cases), failures)
+
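The helpers above lean on the struct module's `'e'` format code (IEEE 754 binary16, Python 3.6+). A quick sanity check of the round-trip and of the canonical bit patterns the test uses, written with an explicit little-endian prefix for determinism:

```python
import math
import struct

def f16_bits_to_float(bits: int) -> float:
    """Decode a raw binary16 bit pattern via struct's 'e' format."""
    return struct.unpack('<e', struct.pack('<H', bits))[0]

assert f16_bits_to_float(0x3C00) == 1.0        # exp 15 (bias 15), mantissa 0
assert f16_bits_to_float(0x7C00) == math.inf   # exp all-ones, mantissa 0
assert math.isnan(f16_bits_to_float(0x7E00))   # canonical quiet NaN
```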
  # =========================================================================
  # ARITHMETIC TESTS (DIRECT EVALUATION)
  # =========================================================================
@@ -827,6 +980,11 @@ class Evaluator:
  self.results.append(result)
  if verbose:
  self._print_result(result)
+ if 'float16.add.sign_a.weight' in self.eval.tensors:
+ result = self.eval.test_float16_add()
+ self.results.append(result)
+ if verbose:
+ self._print_result(result)

  # Comparators
  if verbose: