PortfolioAI committed on
Commit
fd6cf07
·
1 Parent(s): 4479be9

Fix toint, fromint; improve mul/div inference

Browse files

- float16.toint: 93/93 (fixed not_mag regex, added output gating)
- float16.fromint: 53/53 (fixed clz_and/not_in regex, added not_is_zero)
- float16.mul: 3/84 -> 13/84 (fixed not_15_bits, carry logic, NaN bit 9)
- float16.div: 2/53 -> 5/53 (fixed NaN bit 9)

Remaining: mul/div col_sum precision needs full adder trees

Files changed (3) hide show
  1. TODO.md +4 -4
  2. arithmetic.safetensors +2 -2
  3. convert_to_explicit_inputs.py +63 -21
TODO.md CHANGED
@@ -9,10 +9,10 @@
9
  - [x] `float16.normalize` -- CLZ-based shift calculator (51 gates, 14/14 tests)
10
  - [x] `float16.add` -- IEEE 754 addition (~998 gates, 125/125 tests)
11
  - [x] `float16.sub` -- IEEE 754 subtraction (via add with -b, 115/115 tests)
12
- - [ ] `float16.mul` -- IEEE 754 multiplication (766 gates, 3/84 tests, algorithm bugs)
13
- - [ ] `float16.div` -- IEEE 754 division (1854 gates, 2/53 tests, algorithm bugs)
14
- - [ ] `float16.toint` -- float16 to int16 (401 gates, 54/93 tests, debugging shift logic)
15
- - [ ] `float16.fromint` -- int16 to float16 (478 gates, 1/53 tests, algorithm bugs)
16
  - [x] `float16.neg` -- sign flip (16 gates, 58/58 tests)
17
  - [x] `float16.abs` -- clear sign bit (16 gates, 58/58 tests)
18
 
 
9
  - [x] `float16.normalize` -- CLZ-based shift calculator (51 gates, 14/14 tests)
10
  - [x] `float16.add` -- IEEE 754 addition (~998 gates, 125/125 tests)
11
  - [x] `float16.sub` -- IEEE 754 subtraction (via add with -b, 115/115 tests)
12
+ - [ ] `float16.mul` -- IEEE 754 multiplication (766 gates, 13/84 tests, col_sum precision)
13
+ - [ ] `float16.div` -- IEEE 754 division (1854 gates, 5/53 tests, col_sum precision)
14
+ - [x] `float16.toint` -- float16 to int16 (401 gates, 93/93 tests)
15
+ - [x] `float16.fromint` -- int16 to float16 (478 gates, 53/53 tests)
16
  - [x] `float16.neg` -- sign flip (16 gates, 58/58 tests)
17
  - [x] `float16.abs` -- clear sign bit (16 gates, 58/58 tests)
18
 
arithmetic.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5c7bc258f49a9a4c85321d0980f843bb989e8fb84c4d8bc65883f24c6e306334
3
- size 2863492
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3bb8fad90726b27a8bd9502c5cb4154242a5f8a6d046c4ba69470be55bb98624
3
+ size 2865388
convert_to_explicit_inputs.py CHANGED
@@ -2441,12 +2441,25 @@ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
2441
  return registry.get_id(f"{prefix}.col{col}_sum")
2442
  return registry.get_id("#0")
2443
 
2444
- def get_col_ge1(col):
2445
- # ge1 only exists for columns with >= 2 PPs
 
 
 
2446
  if col == 0 or col == 20:
2447
  return registry.get_id("#0") # No carry from single PP columns
2448
  elif col < 21:
2449
- return registry.get_id(f"{prefix}.col{col}_ge1")
 
 
 
 
 
 
 
 
 
 
2450
  return registry.get_id("#0")
2451
 
2452
  if i == 0:
@@ -2455,7 +2468,7 @@ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
2455
  cin = registry.get_id("#0")
2456
  else:
2457
  a_bit = get_col_sum(i) if i < 21 else registry.get_id("#0")
2458
- b_bit = get_col_ge1(i - 1) if i < 22 else registry.get_id("#0")
2459
  cin = registry.register(f"{prefix}.prod_fa{i-1}.cout")
2460
 
2461
  if '.xor1.layer1' in gate:
@@ -2514,7 +2527,8 @@ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
2514
  registry.register(f"{prefix}.exp_add.fa{i}.xor2.layer2")
2515
  registry.register(f"{prefix}.exp_add.fa{i}.cout")
2516
 
2517
- not_15_bits = [1, 1, 1, 1, 0, 0]
 
2518
  if '.exp_sub.fa' in gate:
2519
  match = re.search(r'\.exp_sub\.fa(\d+)\.', gate)
2520
  if match:
@@ -2649,9 +2663,11 @@ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
2649
  if match:
2650
  i = int(match.group(1))
2651
  if '.nan_gate' in gate:
2652
- nan_bit = registry.get_id("#1") if i >= 10 and i < 15 else registry.get_id("#0")
 
2653
  return [nan_bit, registry.get_id(f"{prefix}.result_is_nan")]
2654
  if '.inf_gate' in gate:
 
2655
  inf_bit = registry.get_id("#1") if i >= 10 and i < 15 else registry.get_id("#0")
2656
  return [inf_bit, registry.get_id(f"{prefix}.result_is_inf")]
2657
  if '.zero_gate' in gate:
@@ -3018,9 +3034,11 @@ def infer_float16_div_inputs(gate: str, registry: SignalRegistry) -> List[int]:
3018
  if match:
3019
  i = int(match.group(1))
3020
  if '.nan_gate' in gate:
3021
- nan_bit = registry.get_id("#1") if i >= 10 and i < 15 else registry.get_id("#0")
 
3022
  return [nan_bit, registry.get_id(f"{prefix}.result_is_nan")]
3023
  if '.inf_gate' in gate:
 
3024
  inf_bit = registry.get_id("#1") if i >= 10 and i < 15 else registry.get_id("#0")
3025
  return [inf_bit, registry.get_id(f"{prefix}.result_is_inf")]
3026
  if '.zero_gate' in gate:
@@ -3229,9 +3247,12 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
3229
  registry.register(f"{prefix}.rshift_s{stage}_{i}")
3230
 
3231
  # === NEGATION ===
 
 
 
 
 
3232
  for i in range(16):
3233
- if f'.not_mag{i}' in gate:
3234
- return [registry.get_id(f"{prefix}.rshift_s3_{i}")]
3235
  registry.register(f"{prefix}.not_mag{i}")
3236
 
3237
  if '.neg.fa' in gate:
@@ -3264,12 +3285,15 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
3264
  i = int(match.group(1))
3265
  sign = registry.get_id(f"{prefix}.$x[15]")
3266
  not_sign = registry.register(f"{prefix}.not_sign")
 
3267
  if '.pos_path' in gate:
3268
  return [registry.get_id(f"{prefix}.rshift_s3_{i}"),
3269
- not_sign]
 
3270
  if '.neg_path' in gate:
3271
  return [registry.get_id(f"{prefix}.neg.fa{i}.xor.layer2"),
3272
- sign]
 
3273
 
3274
  # not_sign gate
3275
  if '.not_sign' in gate:
@@ -3294,6 +3318,8 @@ def infer_float16_fromint_inputs(gate: str, registry: SignalRegistry) -> List[in
3294
 
3295
  in_bits = [f"{prefix}.$x[{i}]" for i in range(16)]
3296
 
 
 
3297
  if '.is_zero' in gate:
3298
  return [registry.get_id(b) for b in in_bits]
3299
  if '.is_negative' in gate:
@@ -3302,12 +3328,16 @@ def infer_float16_fromint_inputs(gate: str, registry: SignalRegistry) -> List[in
3302
  return [registry.get_id(f"{prefix}.is_negative")]
3303
 
3304
  registry.register(f"{prefix}.is_zero")
 
3305
  registry.register(f"{prefix}.is_negative")
3306
  registry.register(f"{prefix}.not_negative")
3307
 
 
 
 
 
 
3308
  for i in range(16):
3309
- if f'.not_in{i}' in gate:
3310
- return [registry.get_id(in_bits[i])]
3311
  registry.register(f"{prefix}.not_in{i}")
3312
 
3313
  if '.abs.fa' in gate:
@@ -3421,10 +3451,14 @@ def infer_float16_fromint_inputs(gate: str, registry: SignalRegistry) -> List[in
3421
  registry.register(f"{prefix}.clz_and_14_15")
3422
  registry.register(f"{prefix}.clz1")
3423
 
3424
- for i in [1, 3, 5, 7, 9, 11, 13, 15]:
3425
- if f'.clz_and_{i}' in gate:
 
 
3426
  return [registry.get_id(f"{prefix}.ge{i}"),
3427
  registry.get_id(f"{prefix}.not_ge{i+1}")]
 
 
3428
  registry.register(f"{prefix}.clz_and_{i}")
3429
 
3430
  if '.clz0' in gate:
@@ -3519,7 +3553,7 @@ def infer_float16_fromint_inputs(gate: str, registry: SignalRegistry) -> List[in
3519
  val = registry.get_id(f"{prefix}.exp_calc.fa{i-10}.xor2.layer2")
3520
  else:
3521
  val = registry.get_id(f"{prefix}.is_negative")
3522
- not_zero = registry.get_id(f"{prefix}.is_zero")
3523
  return [val, not_zero]
3524
 
3525
  match = re.search(r'\.out(\d+)$', gate)
@@ -5971,16 +6005,20 @@ def build_float16_toint_tensors() -> Dict[str, torch.Tensor]:
5971
 
5972
  # === OUTPUT SELECTION ===
5973
  # Select between positive path, negative path, and zero
 
5974
 
5975
  # NOT of sign bit for muxing positive path
5976
  tensors[f"{prefix}.not_sign.weight"] = torch.tensor([-1.0])
5977
  tensors[f"{prefix}.not_sign.bias"] = torch.tensor([0.0])
5978
 
5979
  for i in range(16):
5980
- tensors[f"{prefix}.out{i}.pos_path.weight"] = torch.tensor([1.0, 1.0])
5981
- tensors[f"{prefix}.out{i}.pos_path.bias"] = torch.tensor([-2.0])
5982
- tensors[f"{prefix}.out{i}.neg_path.weight"] = torch.tensor([1.0, 1.0])
5983
- tensors[f"{prefix}.out{i}.neg_path.bias"] = torch.tensor([-2.0])
 
 
 
5984
  tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0, 1.0])
5985
  tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-1.0])
5986
 
@@ -6005,6 +6043,10 @@ def build_float16_fromint_tensors() -> Dict[str, torch.Tensor]:
6005
  tensors[f"{prefix}.is_zero.weight"] = torch.tensor([-1.0] * 16)
6006
  tensors[f"{prefix}.is_zero.bias"] = torch.tensor([0.0])
6007
 
 
 
 
 
6008
  # Check if negative (sign bit)
6009
  tensors[f"{prefix}.is_negative.weight"] = torch.tensor([1.0])
6010
  tensors[f"{prefix}.is_negative.bias"] = torch.tensor([-0.5])
@@ -6050,7 +6092,7 @@ def build_float16_fromint_tensors() -> Dict[str, torch.Tensor]:
6050
  tensors[f"{prefix}.ge{k}.bias"] = torch.tensor([-float(k)])
6051
 
6052
  # CLZ binary encoding
6053
- for k in [2, 4, 8, 16]:
6054
  tensors[f"{prefix}.not_ge{k}.weight"] = torch.tensor([-1.0])
6055
  tensors[f"{prefix}.not_ge{k}.bias"] = torch.tensor([0.0])
6056
 
 
2441
  return registry.get_id(f"{prefix}.col{col}_sum")
2442
  return registry.get_id("#0")
2443
 
2444
+ def get_col_carry(col):
2445
+ # Carry from column = sum >= 2, which is ge2
2446
+ # For columns with count >= 3, ge2 exists
2447
+ # For columns with count == 2, no ge2 (but carry is rare anyway)
2448
+ # For single-bit columns, no carry possible
2449
  if col == 0 or col == 20:
2450
  return registry.get_id("#0") # No carry from single PP columns
2451
  elif col < 21:
2452
+ # ge2 exists for columns with 3+ PPs
2453
+ # For 2-PP columns (col 1 and col 19), ge2 doesn't exist
2454
+ # but those columns can only produce carry if both PPs are 1,
2455
+ # which is relatively rare. For now, we use #0 for 2-PP columns.
2456
+ ge2_name = f"{prefix}.col{col}_ge2"
2457
+ ge2_id = registry.get_id(ge2_name)
2458
+ if ge2_id != -1:
2459
+ return ge2_id
2460
+ else:
2461
+ # 2-PP columns: no ge2, return 0 (imprecise but safe)
2462
+ return registry.get_id("#0")
2463
  return registry.get_id("#0")
2464
 
2465
  if i == 0:
 
2468
  cin = registry.get_id("#0")
2469
  else:
2470
  a_bit = get_col_sum(i) if i < 21 else registry.get_id("#0")
2471
+ b_bit = get_col_carry(i - 1) if i < 22 else registry.get_id("#0")
2472
  cin = registry.register(f"{prefix}.prod_fa{i-1}.cout")
2473
 
2474
  if '.xor1.layer1' in gate:
 
2527
  registry.register(f"{prefix}.exp_add.fa{i}.xor2.layer2")
2528
  registry.register(f"{prefix}.exp_add.fa{i}.cout")
2529
 
2530
+ # NOT(15) = NOT(001111) = 110000 in 6-bit, little-endian: [0, 0, 0, 0, 1, 1]
2531
+ not_15_bits = [0, 0, 0, 0, 1, 1]
2532
  if '.exp_sub.fa' in gate:
2533
  match = re.search(r'\.exp_sub\.fa(\d+)\.', gate)
2534
  if match:
 
2663
  if match:
2664
  i = int(match.group(1))
2665
  if '.nan_gate' in gate:
2666
+ # Canonical NaN = 0x7E00 = 0_11111_1000000000, bits 9-14 are 1
2667
+ nan_bit = registry.get_id("#1") if (i >= 9 and i < 15) else registry.get_id("#0")
2668
  return [nan_bit, registry.get_id(f"{prefix}.result_is_nan")]
2669
  if '.inf_gate' in gate:
2670
+ # Inf = 0x7C00 = 0_11111_0000000000, bits 10-14 are 1
2671
  inf_bit = registry.get_id("#1") if i >= 10 and i < 15 else registry.get_id("#0")
2672
  return [inf_bit, registry.get_id(f"{prefix}.result_is_inf")]
2673
  if '.zero_gate' in gate:
 
3034
  if match:
3035
  i = int(match.group(1))
3036
  if '.nan_gate' in gate:
3037
+ # Canonical NaN = 0x7E00 = 0_11111_1000000000, bits 9-14 are 1
3038
+ nan_bit = registry.get_id("#1") if (i >= 9 and i < 15) else registry.get_id("#0")
3039
  return [nan_bit, registry.get_id(f"{prefix}.result_is_nan")]
3040
  if '.inf_gate' in gate:
3041
+ # Inf = 0x7C00 = 0_11111_0000000000, bits 10-14 are 1
3042
  inf_bit = registry.get_id("#1") if i >= 10 and i < 15 else registry.get_id("#0")
3043
  return [inf_bit, registry.get_id(f"{prefix}.result_is_inf")]
3044
  if '.zero_gate' in gate:
 
3247
  registry.register(f"{prefix}.rshift_s{stage}_{i}")
3248
 
3249
  # === NEGATION ===
3250
+ match = re.search(r'\.not_mag(\d+)$', gate)
3251
+ if match:
3252
+ i = int(match.group(1))
3253
+ return [registry.get_id(f"{prefix}.rshift_s3_{i}")]
3254
+
3255
  for i in range(16):
 
 
3256
  registry.register(f"{prefix}.not_mag{i}")
3257
 
3258
  if '.neg.fa' in gate:
 
3285
  i = int(match.group(1))
3286
  sign = registry.get_id(f"{prefix}.$x[15]")
3287
  not_sign = registry.register(f"{prefix}.not_sign")
3288
+ not_result_zero = registry.get_id(f"{prefix}.not_result_is_zero")
3289
  if '.pos_path' in gate:
3290
  return [registry.get_id(f"{prefix}.rshift_s3_{i}"),
3291
+ not_sign,
3292
+ not_result_zero]
3293
  if '.neg_path' in gate:
3294
  return [registry.get_id(f"{prefix}.neg.fa{i}.xor.layer2"),
3295
+ sign,
3296
+ not_result_zero]
3297
 
3298
  # not_sign gate
3299
  if '.not_sign' in gate:
 
3318
 
3319
  in_bits = [f"{prefix}.$x[{i}]" for i in range(16)]
3320
 
3321
+ if '.not_is_zero' in gate:
3322
+ return [registry.get_id(f"{prefix}.is_zero")]
3323
  if '.is_zero' in gate:
3324
  return [registry.get_id(b) for b in in_bits]
3325
  if '.is_negative' in gate:
 
3328
  return [registry.get_id(f"{prefix}.is_negative")]
3329
 
3330
  registry.register(f"{prefix}.is_zero")
3331
+ registry.register(f"{prefix}.not_is_zero")
3332
  registry.register(f"{prefix}.is_negative")
3333
  registry.register(f"{prefix}.not_negative")
3334
 
3335
+ match = re.search(r'\.not_in(\d+)$', gate)
3336
+ if match:
3337
+ i = int(match.group(1))
3338
+ return [registry.get_id(in_bits[i])]
3339
+
3340
  for i in range(16):
 
 
3341
  registry.register(f"{prefix}.not_in{i}")
3342
 
3343
  if '.abs.fa' in gate:
 
3451
  registry.register(f"{prefix}.clz_and_14_15")
3452
  registry.register(f"{prefix}.clz1")
3453
 
3454
+ match = re.search(r'\.clz_and_(\d+)$', gate)
3455
+ if match:
3456
+ i = int(match.group(1))
3457
+ if i in [1, 3, 5, 7, 9, 11, 13, 15]:
3458
  return [registry.get_id(f"{prefix}.ge{i}"),
3459
  registry.get_id(f"{prefix}.not_ge{i+1}")]
3460
+
3461
+ for i in [1, 3, 5, 7, 9, 11, 13, 15]:
3462
  registry.register(f"{prefix}.clz_and_{i}")
3463
 
3464
  if '.clz0' in gate:
 
3553
  val = registry.get_id(f"{prefix}.exp_calc.fa{i-10}.xor2.layer2")
3554
  else:
3555
  val = registry.get_id(f"{prefix}.is_negative")
3556
+ not_zero = registry.get_id(f"{prefix}.not_is_zero")
3557
  return [val, not_zero]
3558
 
3559
  match = re.search(r'\.out(\d+)$', gate)
 
6005
 
6006
  # === OUTPUT SELECTION ===
6007
  # Select between positive path, negative path, and zero
6008
+ # Gate by not_result_is_zero to force output to 0 for |value| < 1
6009
 
6010
  # NOT of sign bit for muxing positive path
6011
  tensors[f"{prefix}.not_sign.weight"] = torch.tensor([-1.0])
6012
  tensors[f"{prefix}.not_sign.bias"] = torch.tensor([0.0])
6013
 
6014
  for i in range(16):
6015
+ # pos_path = shifted_value AND not_sign AND not_result_is_zero
6016
+ tensors[f"{prefix}.out{i}.pos_path.weight"] = torch.tensor([1.0, 1.0, 1.0])
6017
+ tensors[f"{prefix}.out{i}.pos_path.bias"] = torch.tensor([-3.0])
6018
+ # neg_path = negated_value AND sign AND not_result_is_zero
6019
+ tensors[f"{prefix}.out{i}.neg_path.weight"] = torch.tensor([1.0, 1.0, 1.0])
6020
+ tensors[f"{prefix}.out{i}.neg_path.bias"] = torch.tensor([-3.0])
6021
+ # out = pos_path OR neg_path
6022
  tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0, 1.0])
6023
  tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-1.0])
6024
 
 
6043
  tensors[f"{prefix}.is_zero.weight"] = torch.tensor([-1.0] * 16)
6044
  tensors[f"{prefix}.is_zero.bias"] = torch.tensor([0.0])
6045
 
6046
+ # NOT is_zero for gating normal output
6047
+ tensors[f"{prefix}.not_is_zero.weight"] = torch.tensor([-1.0])
6048
+ tensors[f"{prefix}.not_is_zero.bias"] = torch.tensor([0.0])
6049
+
6050
  # Check if negative (sign bit)
6051
  tensors[f"{prefix}.is_negative.weight"] = torch.tensor([1.0])
6052
  tensors[f"{prefix}.is_negative.bias"] = torch.tensor([-0.5])
 
6092
  tensors[f"{prefix}.ge{k}.bias"] = torch.tensor([-float(k)])
6093
 
6094
  # CLZ binary encoding
6095
+ for k in [2, 4, 6, 8, 10, 12, 14, 16]:
6096
  tensors[f"{prefix}.not_ge{k}.weight"] = torch.tensor([-1.0])
6097
  tensors[f"{prefix}.not_ge{k}.bias"] = torch.tensor([0.0])
6098