PortfolioAI committed on
Commit
48421e0
·
1 Parent(s): 1bcc911

Rewrite float16.toint with right-shift barrel shifter

Browse files

- Changed from left-shift to right-shift (25-exp positions)
- Improved from 31/93 to 54/93 tests passing
- Still needs debugging for remaining edge cases

arithmetic.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9bdd850c9d33e5e667744caf0ce5dee8afde5aa5fafbf20eb812fa647c556626
3
- size 2860992
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c7bc258f49a9a4c85321d0980f843bb989e8fb84c4d8bc65883f24c6e306334
3
+ size 2863492
convert_to_explicit_inputs.py CHANGED
@@ -3046,7 +3046,7 @@ def infer_float16_div_inputs(gate: str, registry: SignalRegistry) -> List[int]:
3046
 
3047
 
3048
  def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]:
3049
- """Infer inputs for float16.toint circuit."""
3050
  prefix = "float16.toint"
3051
 
3052
  for i in range(16):
@@ -3055,6 +3055,7 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
3055
  exp_bits = [f"{prefix}.$x[{10+i}]" for i in range(5)]
3056
  mant_bits = [f"{prefix}.$x[{i}]" for i in range(10)]
3057
 
 
3058
  if '.exp_all_ones' in gate:
3059
  return [registry.get_id(b) for b in exp_bits]
3060
  if '.exp_zero' in gate:
@@ -3080,7 +3081,7 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
3080
  if '.is_inf' in gate:
3081
  return [registry.get_id(f"{prefix}.exp_all_ones"),
3082
  registry.get_id(f"{prefix}.mant_zero")]
3083
- if '.is_zero' in gate:
3084
  return [registry.get_id(f"{prefix}.exp_zero"),
3085
  registry.get_id(f"{prefix}.mant_zero")]
3086
 
@@ -3092,27 +3093,53 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
3092
 
3093
  registry.register(f"{prefix}.exp_lt_15")
3094
 
3095
- if '.result_is_zero' in gate:
3096
  return [registry.get_id(f"{prefix}.is_nan"),
3097
  registry.get_id(f"{prefix}.is_zero"),
3098
  registry.get_id(f"{prefix}.exp_lt_15")]
3099
 
3100
  registry.register(f"{prefix}.result_is_zero")
3101
 
 
 
 
 
 
3102
  if '.implicit_bit' in gate:
3103
  return [registry.get_id(f"{prefix}.exp_zero")]
3104
 
3105
  registry.register(f"{prefix}.implicit_bit")
3106
 
3107
- bits_25 = [1, 0, 0, 1, 1, 0]
3108
- not_25 = [0, 1, 1, 0, 0, 1]
 
 
 
 
 
3109
 
3110
- match = re.search(r'\.not_25_(\d+)$', gate)
3111
- if match:
3112
- return [registry.get_id(f"#{bits_25[int(match.group(1))]}")]
 
 
 
 
 
 
3113
 
 
 
 
 
3114
  for i in range(5):
3115
- registry.register(f"{prefix}.not_25_{i}")
 
 
 
 
 
 
3116
 
3117
  if '.shift_calc.fa' in gate:
3118
  match = re.search(r'\.shift_calc\.fa(\d+)\.', gate)
@@ -3120,9 +3147,11 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
3120
  i = int(match.group(1))
3121
  fa_prefix = f"{prefix}.shift_calc.fa{i}"
3122
 
3123
- a_bit = registry.get_id(exp_bits[i]) if i < 5 else registry.get_id("#0")
3124
- b_bit = registry.get_id(f"#{not_25[i]}")
3125
- cin = registry.get_id("#1") if i == 0 else registry.register(f"{prefix}.shift_calc.fa{i-1}.cout")
 
 
3126
 
3127
  if '.xor1.layer1' in gate:
3128
  return [a_bit, b_bit]
@@ -3147,6 +3176,7 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
3147
  registry.register(f"{prefix}.shift_calc.fa{i}.xor2.layer2")
3148
  registry.register(f"{prefix}.shift_calc.fa{i}.cout")
3149
 
 
3150
  for stage in range(4):
3151
  shift_amt = 1 << stage
3152
 
@@ -3154,10 +3184,13 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
3154
  return [registry.get_id(f"{prefix}.shift_calc.fa{stage}.xor2.layer2")]
3155
  registry.register(f"{prefix}.not_shift{stage}")
3156
 
3157
- match = re.search(rf'\.lshift_s{stage}_(\d+)\.', gate)
3158
  if match:
3159
  i = int(match.group(1))
 
 
3160
  if '.pass' in gate:
 
3161
  if stage == 0:
3162
  if i < 10:
3163
  val = registry.get_id(mant_bits[i])
@@ -3166,36 +3199,39 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
3166
  else:
3167
  val = registry.get_id("#0")
3168
  else:
3169
- val = registry.get_id(f"{prefix}.lshift_s{stage-1}_{i}")
3170
  return [val, registry.get_id(f"{prefix}.not_shift{stage}")]
3171
- if '.shift' in gate and i >= shift_amt:
 
 
3172
  if stage == 0:
3173
- prev_i = i - shift_amt
3174
- if prev_i < 10:
3175
- val = registry.get_id(mant_bits[prev_i])
3176
- elif prev_i == 10:
3177
  val = registry.get_id(f"{prefix}.implicit_bit")
3178
  else:
3179
  val = registry.get_id("#0")
3180
  else:
3181
- val = registry.get_id(f"{prefix}.lshift_s{stage-1}_{i-shift_amt}")
3182
  return [val, registry.get_id(f"{prefix}.shift_calc.fa{stage}.xor2.layer2")]
3183
 
3184
- match = re.search(rf'\.lshift_s{stage}_(\d+)$', gate)
3185
  if match:
3186
  i = int(match.group(1))
3187
- if i >= shift_amt:
3188
- return [registry.register(f"{prefix}.lshift_s{stage}_{i}.pass"),
3189
- registry.register(f"{prefix}.lshift_s{stage}_{i}.shift")]
 
3190
  else:
3191
- return [registry.register(f"{prefix}.lshift_s{stage}_{i}.pass")]
3192
 
3193
  for i in range(16):
3194
- registry.register(f"{prefix}.lshift_s{stage}_{i}")
3195
 
 
3196
  for i in range(16):
3197
  if f'.not_mag{i}' in gate:
3198
- return [registry.get_id(f"{prefix}.lshift_s3_{i}")]
3199
  registry.register(f"{prefix}.not_mag{i}")
3200
 
3201
  if '.neg.fa' in gate:
@@ -3222,17 +3258,24 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
3222
  registry.register(f"{prefix}.neg.fa{i}.xor.layer2")
3223
  registry.register(f"{prefix}.neg.fa{i}.cout")
3224
 
 
3225
  match = re.search(r'\.out(\d+)\.', gate)
3226
  if match:
3227
  i = int(match.group(1))
3228
  sign = registry.get_id(f"{prefix}.$x[15]")
 
3229
  if '.pos_path' in gate:
3230
- return [registry.get_id(f"{prefix}.lshift_s3_{i}"),
3231
- registry.get_id(f"{prefix}.$x[15]")]
3232
  if '.neg_path' in gate:
3233
  return [registry.get_id(f"{prefix}.neg.fa{i}.xor.layer2"),
3234
  sign]
3235
 
 
 
 
 
 
3236
  match = re.search(r'\.out(\d+)$', gate)
3237
  if match:
3238
  i = int(match.group(1))
@@ -5726,54 +5769,135 @@ def build_float16_toint_tensors() -> Dict[str, torch.Tensor]:
5726
  Convert float16 to signed 16-bit integer (truncate toward zero).
5727
 
5728
  Algorithm:
5729
- 1. If NaN or Inf, return 0 or max/min int
5730
- 2. If |value| < 1, return 0
5731
- 3. Shift mantissa by (exponent - 15 - 10) positions
5732
- 4. Apply sign
 
 
 
 
5733
  """
5734
  tensors = {}
5735
  prefix = "float16.toint"
5736
 
5737
- # Special case detection
 
5738
  tensors[f"{prefix}.exp_all_ones.weight"] = torch.tensor([1.0] * 5)
5739
  tensors[f"{prefix}.exp_all_ones.bias"] = torch.tensor([-5.0])
5740
 
 
5741
  tensors[f"{prefix}.exp_zero.weight"] = torch.tensor([-1.0] * 5)
5742
  tensors[f"{prefix}.exp_zero.bias"] = torch.tensor([0.0])
5743
 
 
5744
  tensors[f"{prefix}.mant_nonzero.weight"] = torch.tensor([1.0] * 10)
5745
  tensors[f"{prefix}.mant_nonzero.bias"] = torch.tensor([-1.0])
5746
 
 
5747
  tensors[f"{prefix}.is_nan.weight"] = torch.tensor([1.0, 1.0])
5748
  tensors[f"{prefix}.is_nan.bias"] = torch.tensor([-2.0])
5749
 
 
5750
  tensors[f"{prefix}.mant_zero.weight"] = torch.tensor([-1.0])
5751
  tensors[f"{prefix}.mant_zero.bias"] = torch.tensor([0.0])
5752
 
 
5753
  tensors[f"{prefix}.is_inf.weight"] = torch.tensor([1.0, 1.0])
5754
  tensors[f"{prefix}.is_inf.bias"] = torch.tensor([-2.0])
5755
 
 
5756
  tensors[f"{prefix}.is_zero.weight"] = torch.tensor([1.0, 1.0])
5757
  tensors[f"{prefix}.is_zero.bias"] = torch.tensor([-2.0])
5758
 
5759
- # Check if exponent < 15 (|value| < 1)
5760
- # exp < 15 means unbiased exp < 0, so result is 0
 
5761
  weights = [-float(2**i) for i in range(5)]
5762
  tensors[f"{prefix}.exp_lt_15.weight"] = torch.tensor(weights)
5763
  tensors[f"{prefix}.exp_lt_15.bias"] = torch.tensor([14.0])
5764
 
 
5765
  tensors[f"{prefix}.result_is_zero.weight"] = torch.tensor([1.0, 1.0, 1.0])
5766
  tensors[f"{prefix}.result_is_zero.bias"] = torch.tensor([-1.0])
5767
 
5768
- # Compute shift amount: exp - 15 - 10 = exp - 25
5769
- # If positive, left shift mantissa; if negative, right shift
5770
- # For int16, max shift is 15 (for values up to 32767)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5771
 
5772
- # exp - 25 subtractor
5773
  for i in range(5):
5774
- tensors[f"{prefix}.not_25_{i}.weight"] = torch.tensor([1.0])
5775
- tensors[f"{prefix}.not_25_{i}.bias"] = torch.tensor([-0.5])
5776
 
 
 
 
5777
  for i in range(6):
5778
  p = f"{prefix}.shift_calc.fa{i}"
5779
  tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
@@ -5797,32 +5921,40 @@ def build_float16_toint_tensors() -> Dict[str, torch.Tensor]:
5797
  tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
5798
  tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
5799
 
5800
- # Barrel shifter (left shift mantissa)
5801
- tensors[f"{prefix}.implicit_bit.weight"] = torch.tensor([-1.0])
5802
- tensors[f"{prefix}.implicit_bit.bias"] = torch.tensor([0.0])
 
5803
 
5804
  for stage in range(4):
5805
  shift_amt = 1 << stage
 
5806
  tensors[f"{prefix}.not_shift{stage}.weight"] = torch.tensor([-1.0])
5807
  tensors[f"{prefix}.not_shift{stage}.bias"] = torch.tensor([0.0])
5808
 
5809
  for i in range(16):
5810
- tensors[f"{prefix}.lshift_s{stage}_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
5811
- tensors[f"{prefix}.lshift_s{stage}_{i}.pass.bias"] = torch.tensor([-2.0])
5812
- if i >= shift_amt:
5813
- tensors[f"{prefix}.lshift_s{stage}_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
5814
- tensors[f"{prefix}.lshift_s{stage}_{i}.shift.bias"] = torch.tensor([-2.0])
5815
- tensors[f"{prefix}.lshift_s{stage}_{i}.weight"] = torch.tensor([1.0, 1.0])
 
 
 
 
5816
  else:
5817
- tensors[f"{prefix}.lshift_s{stage}_{i}.weight"] = torch.tensor([1.0])
5818
- tensors[f"{prefix}.lshift_s{stage}_{i}.bias"] = torch.tensor([-1.0])
 
 
 
 
5819
 
5820
- # Apply sign (negate if negative)
5821
  for i in range(16):
5822
  tensors[f"{prefix}.not_mag{i}.weight"] = torch.tensor([-1.0])
5823
  tensors[f"{prefix}.not_mag{i}.bias"] = torch.tensor([0.0])
5824
 
5825
- # Two's complement negation
5826
  for i in range(16):
5827
  p = f"{prefix}.neg.fa{i}"
5828
  tensors[f"{p}.xor.layer1.or.weight"] = torch.tensor([1.0, 1.0])
@@ -5837,7 +5969,13 @@ def build_float16_toint_tensors() -> Dict[str, torch.Tensor]:
5837
  tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
5838
  tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
5839
 
5840
- # Output selection
 
 
 
 
 
 
5841
  for i in range(16):
5842
  tensors[f"{prefix}.out{i}.pos_path.weight"] = torch.tensor([1.0, 1.0])
5843
  tensors[f"{prefix}.out{i}.pos_path.bias"] = torch.tensor([-2.0])
@@ -6081,12 +6219,17 @@ def main():
6081
 
6082
  print(f"Loaded {len(tensors)} tensors")
6083
 
6084
- # Remove old float16.add tensors (we're rebuilding from scratch)
6085
  old_float16_add = [k for k in tensors.keys() if k.startswith('float16.add')]
6086
  for k in old_float16_add:
6087
  del tensors[k]
6088
  print(f"Removed {len(old_float16_add)} old float16.add tensors")
6089
 
 
 
 
 
 
6090
  # Build new circuits
6091
  print("Building new circuits...")
6092
  clz_tensors = build_clz8bit_tensors()
 
3046
 
3047
 
3048
  def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]:
3049
+ """Infer inputs for float16.toint circuit (with right-shift barrel shifter)."""
3050
  prefix = "float16.toint"
3051
 
3052
  for i in range(16):
 
3055
  exp_bits = [f"{prefix}.$x[{10+i}]" for i in range(5)]
3056
  mant_bits = [f"{prefix}.$x[{i}]" for i in range(10)]
3057
 
3058
+ # === SPECIAL CASE DETECTION ===
3059
  if '.exp_all_ones' in gate:
3060
  return [registry.get_id(b) for b in exp_bits]
3061
  if '.exp_zero' in gate:
 
3081
  if '.is_inf' in gate:
3082
  return [registry.get_id(f"{prefix}.exp_all_ones"),
3083
  registry.get_id(f"{prefix}.mant_zero")]
3084
+ if '.is_zero' in gate and '.not_' not in gate and '.result_is_zero' not in gate:
3085
  return [registry.get_id(f"{prefix}.exp_zero"),
3086
  registry.get_id(f"{prefix}.mant_zero")]
3087
 
 
3093
 
3094
  registry.register(f"{prefix}.exp_lt_15")
3095
 
3096
+ if '.result_is_zero' in gate and '.not_' not in gate:
3097
  return [registry.get_id(f"{prefix}.is_nan"),
3098
  registry.get_id(f"{prefix}.is_zero"),
3099
  registry.get_id(f"{prefix}.exp_lt_15")]
3100
 
3101
  registry.register(f"{prefix}.result_is_zero")
3102
 
3103
+ if '.not_result_is_zero' in gate:
3104
+ return [registry.get_id(f"{prefix}.result_is_zero")]
3105
+
3106
+ registry.register(f"{prefix}.not_result_is_zero")
3107
+
3108
  if '.implicit_bit' in gate:
3109
  return [registry.get_id(f"{prefix}.exp_zero")]
3110
 
3111
  registry.register(f"{prefix}.implicit_bit")
3112
 
3113
+ # === THRESHOLD GATES FOR SHIFT CONTROL ===
3114
+ if '.exp_ge_15' in gate:
3115
+ return [registry.get_id(b) for b in exp_bits]
3116
+ if '.exp_ge_18' in gate:
3117
+ return [registry.get_id(b) for b in exp_bits]
3118
+ if '.exp_le_21' in gate:
3119
+ return [registry.get_id(b) for b in exp_bits]
3120
 
3121
+ registry.register(f"{prefix}.exp_ge_15")
3122
+ registry.register(f"{prefix}.exp_ge_18")
3123
+ registry.register(f"{prefix}.exp_le_21")
3124
+
3125
+ if '.shift_bit3' in gate:
3126
+ return [registry.get_id(b) for b in exp_bits]
3127
+ if '.shift_bit2' in gate:
3128
+ return [registry.get_id(f"{prefix}.exp_ge_18"),
3129
+ registry.get_id(f"{prefix}.exp_le_21")]
3130
 
3131
+ registry.register(f"{prefix}.shift_bit3")
3132
+ registry.register(f"{prefix}.shift_bit2")
3133
+
3134
+ # === NOT OF EXPONENT BITS ===
3135
  for i in range(5):
3136
+ if f'.not_exp{i}' in gate:
3137
+ return [registry.get_id(exp_bits[i])]
3138
+ registry.register(f"{prefix}.not_exp{i}")
3139
+
3140
+ # === SHIFT CALCULATION: 25 - exp = ~exp + 26 ===
3141
+ # 26 = 0b011010
3142
+ const_26 = [0, 1, 0, 1, 1, 0]
3143
 
3144
  if '.shift_calc.fa' in gate:
3145
  match = re.search(r'\.shift_calc\.fa(\d+)\.', gate)
 
3147
  i = int(match.group(1))
3148
  fa_prefix = f"{prefix}.shift_calc.fa{i}"
3149
 
3150
+ # a = ~exp[i] (or 1 for i >= 5)
3151
+ a_bit = registry.get_id(f"{prefix}.not_exp{i}") if i < 5 else registry.get_id("#1")
3152
+ # b = const_26[i]
3153
+ b_bit = registry.get_id(f"#{const_26[i]}")
3154
+ cin = registry.get_id("#0") if i == 0 else registry.register(f"{prefix}.shift_calc.fa{i-1}.cout")
3155
 
3156
  if '.xor1.layer1' in gate:
3157
  return [a_bit, b_bit]
 
3176
  registry.register(f"{prefix}.shift_calc.fa{i}.xor2.layer2")
3177
  registry.register(f"{prefix}.shift_calc.fa{i}.cout")
3178
 
3179
+ # === RIGHT-SHIFT BARREL SHIFTER ===
3180
  for stage in range(4):
3181
  shift_amt = 1 << stage
3182
 
 
3184
  return [registry.get_id(f"{prefix}.shift_calc.fa{stage}.xor2.layer2")]
3185
  registry.register(f"{prefix}.not_shift{stage}")
3186
 
3187
+ match = re.search(rf'\.rshift_s{stage}_(\d+)\.', gate)
3188
  if match:
3189
  i = int(match.group(1))
3190
+ src_pos = i + shift_amt
3191
+
3192
  if '.pass' in gate:
3193
+ # Current value (from previous stage or input)
3194
  if stage == 0:
3195
  if i < 10:
3196
  val = registry.get_id(mant_bits[i])
 
3199
  else:
3200
  val = registry.get_id("#0")
3201
  else:
3202
+ val = registry.get_id(f"{prefix}.rshift_s{stage-1}_{i}")
3203
  return [val, registry.get_id(f"{prefix}.not_shift{stage}")]
3204
+
3205
+ if '.shift' in gate and src_pos < 16:
3206
+ # Value from higher position
3207
  if stage == 0:
3208
+ if src_pos < 10:
3209
+ val = registry.get_id(mant_bits[src_pos])
3210
+ elif src_pos == 10:
 
3211
  val = registry.get_id(f"{prefix}.implicit_bit")
3212
  else:
3213
  val = registry.get_id("#0")
3214
  else:
3215
+ val = registry.get_id(f"{prefix}.rshift_s{stage-1}_{src_pos}")
3216
  return [val, registry.get_id(f"{prefix}.shift_calc.fa{stage}.xor2.layer2")]
3217
 
3218
+ match = re.search(rf'\.rshift_s{stage}_(\d+)$', gate)
3219
  if match:
3220
  i = int(match.group(1))
3221
+ src_pos = i + shift_amt
3222
+ if src_pos < 16:
3223
+ return [registry.register(f"{prefix}.rshift_s{stage}_{i}.pass"),
3224
+ registry.register(f"{prefix}.rshift_s{stage}_{i}.shift")]
3225
  else:
3226
+ return [registry.register(f"{prefix}.rshift_s{stage}_{i}.pass")]
3227
 
3228
  for i in range(16):
3229
+ registry.register(f"{prefix}.rshift_s{stage}_{i}")
3230
 
3231
+ # === NEGATION ===
3232
  for i in range(16):
3233
  if f'.not_mag{i}' in gate:
3234
+ return [registry.get_id(f"{prefix}.rshift_s3_{i}")]
3235
  registry.register(f"{prefix}.not_mag{i}")
3236
 
3237
  if '.neg.fa' in gate:
 
3258
  registry.register(f"{prefix}.neg.fa{i}.xor.layer2")
3259
  registry.register(f"{prefix}.neg.fa{i}.cout")
3260
 
3261
+ # === OUTPUT SELECTION ===
3262
  match = re.search(r'\.out(\d+)\.', gate)
3263
  if match:
3264
  i = int(match.group(1))
3265
  sign = registry.get_id(f"{prefix}.$x[15]")
3266
+ not_sign = registry.register(f"{prefix}.not_sign")
3267
  if '.pos_path' in gate:
3268
+ return [registry.get_id(f"{prefix}.rshift_s3_{i}"),
3269
+ not_sign]
3270
  if '.neg_path' in gate:
3271
  return [registry.get_id(f"{prefix}.neg.fa{i}.xor.layer2"),
3272
  sign]
3273
 
3274
+ # not_sign gate
3275
+ if '.not_sign' in gate:
3276
+ return [registry.get_id(f"{prefix}.$x[15]")]
3277
+ registry.register(f"{prefix}.not_sign")
3278
+
3279
  match = re.search(r'\.out(\d+)$', gate)
3280
  if match:
3281
  i = int(match.group(1))
 
5769
  Convert float16 to signed 16-bit integer (truncate toward zero).
5770
 
5771
  Algorithm:
5772
+ 1. Extract mantissa M with implicit bit (11 bits, bit 10 = implicit 1)
5773
+ 2. For exp < 15: result = 0 (|value| < 1)
5774
+ 3. For exp >= 15: right-shift M by (25 - exp) positions
5775
+ - exp = 15: shift by 10, result = 1 for normalized
5776
+ - exp = 25: shift by 0, result = M (up to 2047)
5777
+ - exp > 25: would need left shift, but limited range
5778
+ 4. Apply sign via two's complement negation
5779
+ 5. Handle special cases: NaN, Inf, overflow
5780
  """
5781
  tensors = {}
5782
  prefix = "float16.toint"
5783
 
5784
+ # === SPECIAL CASE DETECTION ===
5785
+ # exp_all_ones: exponent = 31 (NaN or Inf)
5786
  tensors[f"{prefix}.exp_all_ones.weight"] = torch.tensor([1.0] * 5)
5787
  tensors[f"{prefix}.exp_all_ones.bias"] = torch.tensor([-5.0])
5788
 
5789
+ # exp_zero: exponent = 0 (zero or subnormal)
5790
  tensors[f"{prefix}.exp_zero.weight"] = torch.tensor([-1.0] * 5)
5791
  tensors[f"{prefix}.exp_zero.bias"] = torch.tensor([0.0])
5792
 
5793
+ # mant_nonzero: any mantissa bit set
5794
  tensors[f"{prefix}.mant_nonzero.weight"] = torch.tensor([1.0] * 10)
5795
  tensors[f"{prefix}.mant_nonzero.bias"] = torch.tensor([-1.0])
5796
 
5797
+ # is_nan: exp=31 AND mant!=0
5798
  tensors[f"{prefix}.is_nan.weight"] = torch.tensor([1.0, 1.0])
5799
  tensors[f"{prefix}.is_nan.bias"] = torch.tensor([-2.0])
5800
 
5801
+ # mant_zero: NOT mant_nonzero
5802
  tensors[f"{prefix}.mant_zero.weight"] = torch.tensor([-1.0])
5803
  tensors[f"{prefix}.mant_zero.bias"] = torch.tensor([0.0])
5804
 
5805
+ # is_inf: exp=31 AND mant=0
5806
  tensors[f"{prefix}.is_inf.weight"] = torch.tensor([1.0, 1.0])
5807
  tensors[f"{prefix}.is_inf.bias"] = torch.tensor([-2.0])
5808
 
5809
+ # is_zero: exp=0 AND mant=0
5810
  tensors[f"{prefix}.is_zero.weight"] = torch.tensor([1.0, 1.0])
5811
  tensors[f"{prefix}.is_zero.bias"] = torch.tensor([-2.0])
5812
 
5813
+ # === CHECK IF |VALUE| < 1 ===
5814
+ # exp < 15 means unbiased exponent < 0, so |value| < 1
5815
+ # Use threshold: sum(exp[i] * 2^i) < 15
5816
  weights = [-float(2**i) for i in range(5)]
5817
  tensors[f"{prefix}.exp_lt_15.weight"] = torch.tensor(weights)
5818
  tensors[f"{prefix}.exp_lt_15.bias"] = torch.tensor([14.0])
5819
 
5820
+ # result_is_zero: exp_zero OR exp_lt_15 OR is_nan
5821
  tensors[f"{prefix}.result_is_zero.weight"] = torch.tensor([1.0, 1.0, 1.0])
5822
  tensors[f"{prefix}.result_is_zero.bias"] = torch.tensor([-1.0])
5823
 
5824
+ # not_result_is_zero for muxing
5825
+ tensors[f"{prefix}.not_result_is_zero.weight"] = torch.tensor([-1.0])
5826
+ tensors[f"{prefix}.not_result_is_zero.bias"] = torch.tensor([0.0])
5827
+
5828
+ # === COMPUTE SHIFT AMOUNT: 25 - exp ===
5829
+ # For right shift: shift_amt = 25 - exp (need 0 to 10 for normal range)
5830
+ # 25 = 0b11001, so we compute NOT(exp) + 25 + 1 = NOT(exp) + 26
5831
+ # Actually simpler: use 25 - exp directly with threshold gates
5832
+
5833
+ # We'll use a different approach: compute exp directly and use threshold
5834
+ # gates to determine shift amount bits
5835
+
5836
+ # Implicit bit (always 1 for normalized numbers, 0 for subnormals)
5837
+ # implicit = NOT exp_zero
5838
+ tensors[f"{prefix}.implicit_bit.weight"] = torch.tensor([-1.0])
5839
+ tensors[f"{prefix}.implicit_bit.bias"] = torch.tensor([0.0])
5840
+
5841
+ # === DIRECT SHIFT USING EXPONENT VALUE ===
5842
+ # For exp in range 15-25, shift right by (25-exp)
5843
+ # For exp >= 25, no shift or left shift (overflow territory)
5844
+ #
5845
+ # Shift amounts needed: 0-10 for exp 25-15
5846
+ # shift[0] = 1 if (25-exp) is odd = exp is even when exp in {15,17,19,21,23,25}
5847
+ # This is complex. Let's use threshold gates on exp value.
5848
+
5849
+ # exp_ge_15: exp >= 15 (value >= 1)
5850
+ tensors[f"{prefix}.exp_ge_15.weight"] = torch.tensor([float(2**i) for i in range(5)])
5851
+ tensors[f"{prefix}.exp_ge_15.bias"] = torch.tensor([-15.0])
5852
+
5853
+ # For each shift stage, determine if we should shift
5854
+ # Right shift by 2^k if bit k of (25-exp) is set
5855
+ # 25 - exp for exp in [15, 25]: shift in [10, 0]
5856
+ # Binary of shift amounts:
5857
+ # exp=15: shift=10 = 0b1010
5858
+ # exp=16: shift=9 = 0b1001
5859
+ # exp=17: shift=8 = 0b1000
5860
+ # exp=18: shift=7 = 0b0111
5861
+ # exp=19: shift=6 = 0b0110
5862
+ # exp=20: shift=5 = 0b0101
5863
+ # exp=21: shift=4 = 0b0100
5864
+ # exp=22: shift=3 = 0b0011
5865
+ # exp=23: shift=2 = 0b0010
5866
+ # exp=24: shift=1 = 0b0001
5867
+ # exp=25: shift=0 = 0b0000
5868
+
5869
+ # Use threshold on exp to determine shift control bits
5870
+ # shift_bit3 (shift by 8): exp <= 17 (shift >= 8)
5871
+ tensors[f"{prefix}.shift_bit3.weight"] = torch.tensor([-float(2**i) for i in range(5)])
5872
+ tensors[f"{prefix}.shift_bit3.bias"] = torch.tensor([17.0])
5873
+
5874
+ # shift_bit2 (shift by 4): (exp <= 17) OR (18 <= exp <= 21)
5875
+ # = exp <= 21 AND NOT (18 <= exp <= 21 AND exp > 17)... complex
5876
+ # Simpler: shift_bit2 = 1 when shift in {4,5,6,7,12,13,14,15} ∩ [0,10] = {4,5,6,7}
5877
+ # = exp in {18,19,20,21}
5878
+ # Use: exp >= 18 AND exp <= 21
5879
+ tensors[f"{prefix}.exp_ge_18.weight"] = torch.tensor([float(2**i) for i in range(5)])
5880
+ tensors[f"{prefix}.exp_ge_18.bias"] = torch.tensor([-18.0])
5881
+ tensors[f"{prefix}.exp_le_21.weight"] = torch.tensor([-float(2**i) for i in range(5)])
5882
+ tensors[f"{prefix}.exp_le_21.bias"] = torch.tensor([21.0])
5883
+ tensors[f"{prefix}.shift_bit2.weight"] = torch.tensor([1.0, 1.0])
5884
+ tensors[f"{prefix}.shift_bit2.bias"] = torch.tensor([-2.0])
5885
+
5886
+ # shift_bit1 (shift by 2): shift in {2,3,6,7,10,11,...} ∩ [0,10] = {2,3,6,7,10}
5887
+ # = exp in {15,19,22,23} -- this is getting complex
5888
+ # Let's use a simpler direct threshold approach
5889
+
5890
+ # Actually, let's compute 25-exp using subtraction, then use those bits
5891
+ # 25 = 0b011001 (6 bits), exp is 5 bits
5892
+ # 25 - exp in two's complement
5893
 
 
5894
  for i in range(5):
5895
+ tensors[f"{prefix}.not_exp{i}.weight"] = torch.tensor([-1.0])
5896
+ tensors[f"{prefix}.not_exp{i}.bias"] = torch.tensor([0.0])
5897
 
5898
+ # 25 - exp = 25 + (~exp) + 1 = 26 + ~exp (in binary)
5899
+ # 26 = 0b011010
5900
+ const_26 = [0, 1, 0, 1, 1, 0] # bits of 26
5901
  for i in range(6):
5902
  p = f"{prefix}.shift_calc.fa{i}"
5903
  tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
 
5921
  tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
5922
  tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
5923
 
5924
+ # === RIGHT-SHIFT BARREL SHIFTER ===
5925
+ # 4 stages for shifts of 1, 2, 4, 8
5926
+ # Input: mantissa (10 bits) + implicit bit at position 10 = 11 bits
5927
+ # We'll work with 16 bits to have room
5928
 
5929
  for stage in range(4):
5930
  shift_amt = 1 << stage
5931
+ # NOT of shift control bit for mux
5932
  tensors[f"{prefix}.not_shift{stage}.weight"] = torch.tensor([-1.0])
5933
  tensors[f"{prefix}.not_shift{stage}.bias"] = torch.tensor([0.0])
5934
 
5935
  for i in range(16):
5936
+ # pass: keep current position (AND with NOT shift_bit)
5937
+ tensors[f"{prefix}.rshift_s{stage}_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
5938
+ tensors[f"{prefix}.rshift_s{stage}_{i}.pass.bias"] = torch.tensor([-2.0])
5939
+
5940
+ # shift: take from higher position (AND with shift_bit)
5941
+ src_pos = i + shift_amt
5942
+ if src_pos < 16:
5943
+ tensors[f"{prefix}.rshift_s{stage}_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
5944
+ tensors[f"{prefix}.rshift_s{stage}_{i}.shift.bias"] = torch.tensor([-2.0])
5945
+ tensors[f"{prefix}.rshift_s{stage}_{i}.weight"] = torch.tensor([1.0, 1.0])
5946
  else:
5947
+ # Shift in 0 from above
5948
+ tensors[f"{prefix}.rshift_s{stage}_{i}.weight"] = torch.tensor([1.0])
5949
+ tensors[f"{prefix}.rshift_s{stage}_{i}.bias"] = torch.tensor([-1.0])
5950
+
5951
+ # === TWO'S COMPLEMENT NEGATION FOR NEGATIVE FLOATS ===
5952
+ # If sign bit is 1, negate the result
5953
 
 
5954
  for i in range(16):
5955
  tensors[f"{prefix}.not_mag{i}.weight"] = torch.tensor([-1.0])
5956
  tensors[f"{prefix}.not_mag{i}.bias"] = torch.tensor([0.0])
5957
 
 
5958
  for i in range(16):
5959
  p = f"{prefix}.neg.fa{i}"
5960
  tensors[f"{p}.xor.layer1.or.weight"] = torch.tensor([1.0, 1.0])
 
5969
  tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
5970
  tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
5971
 
5972
+ # === OUTPUT SELECTION ===
5973
+ # Select between positive path, negative path, and zero
5974
+
5975
+ # NOT of sign bit for muxing positive path
5976
+ tensors[f"{prefix}.not_sign.weight"] = torch.tensor([-1.0])
5977
+ tensors[f"{prefix}.not_sign.bias"] = torch.tensor([0.0])
5978
+
5979
  for i in range(16):
5980
  tensors[f"{prefix}.out{i}.pos_path.weight"] = torch.tensor([1.0, 1.0])
5981
  tensors[f"{prefix}.out{i}.pos_path.bias"] = torch.tensor([-2.0])
 
6219
 
6220
  print(f"Loaded {len(tensors)} tensors")
6221
 
6222
+ # Remove old tensors for circuits we're rebuilding
6223
  old_float16_add = [k for k in tensors.keys() if k.startswith('float16.add')]
6224
  for k in old_float16_add:
6225
  del tensors[k]
6226
  print(f"Removed {len(old_float16_add)} old float16.add tensors")
6227
 
6228
+ old_float16_toint = [k for k in tensors.keys() if k.startswith('float16.toint')]
6229
+ for k in old_float16_toint:
6230
+ del tensors[k]
6231
+ print(f"Removed {len(old_float16_toint)} old float16.toint tensors")
6232
+
6233
  # Build new circuits
6234
  print("Building new circuits...")
6235
  clz_tensors = build_clz8bit_tensors()