PortfolioAI committed on
Commit ·
fd6cf07
1
Parent(s): 4479be9
Fix toint, fromint; improve mul/div inference
Browse files
- float16.toint: 93/93 (fixed not_mag regex, added output gating)
- float16.fromint: 53/53 (fixed clz_and/not_in regex, added not_is_zero)
- float16.mul: 3/84 -> 13/84 (fixed not_15_bits, carry logic, NaN bit 9)
- float16.div: 2/53 -> 5/53 (fixed NaN bit 9)
Remaining: mul/div col_sum precision needs full adder trees
- TODO.md +4 -4
- arithmetic.safetensors +2 -2
- convert_to_explicit_inputs.py +63 -21
TODO.md
CHANGED
|
@@ -9,10 +9,10 @@
|
|
| 9 |
- [x] `float16.normalize` -- CLZ-based shift calculator (51 gates, 14/14 tests)
|
| 10 |
- [x] `float16.add` -- IEEE 754 addition (~998 gates, 125/125 tests)
|
| 11 |
- [x] `float16.sub` -- IEEE 754 subtraction (via add with -b, 115/115 tests)
|
| 12 |
-
- [ ] `float16.mul` -- IEEE 754 multiplication (766 gates,
|
| 13 |
-
- [ ] `float16.div` -- IEEE 754 division (1854 gates,
|
| 14 |
-
- [
|
| 15 |
-
- [
|
| 16 |
- [x] `float16.neg` -- sign flip (16 gates, 58/58 tests)
|
| 17 |
- [x] `float16.abs` -- clear sign bit (16 gates, 58/58 tests)
|
| 18 |
|
|
|
|
| 9 |
- [x] `float16.normalize` -- CLZ-based shift calculator (51 gates, 14/14 tests)
|
| 10 |
- [x] `float16.add` -- IEEE 754 addition (~998 gates, 125/125 tests)
|
| 11 |
- [x] `float16.sub` -- IEEE 754 subtraction (via add with -b, 115/115 tests)
|
| 12 |
+
- [ ] `float16.mul` -- IEEE 754 multiplication (766 gates, 13/84 tests, col_sum precision)
|
| 13 |
+
- [ ] `float16.div` -- IEEE 754 division (1854 gates, 5/53 tests, col_sum precision)
|
| 14 |
+
- [x] `float16.toint` -- float16 to int16 (401 gates, 93/93 tests)
|
| 15 |
+
- [x] `float16.fromint` -- int16 to float16 (478 gates, 53/53 tests)
|
| 16 |
- [x] `float16.neg` -- sign flip (16 gates, 58/58 tests)
|
| 17 |
- [x] `float16.abs` -- clear sign bit (16 gates, 58/58 tests)
|
| 18 |
|
arithmetic.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3bb8fad90726b27a8bd9502c5cb4154242a5f8a6d046c4ba69470be55bb98624
|
| 3 |
+
size 2865388
|
convert_to_explicit_inputs.py
CHANGED
|
@@ -2441,12 +2441,25 @@ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
|
| 2441 |
return registry.get_id(f"{prefix}.col{col}_sum")
|
| 2442 |
return registry.get_id("#0")
|
| 2443 |
|
| 2444 |
-
def
|
| 2445 |
-
#
|
|
|
|
|
|
|
|
|
|
| 2446 |
if col == 0 or col == 20:
|
| 2447 |
return registry.get_id("#0") # No carry from single PP columns
|
| 2448 |
elif col < 21:
|
| 2449 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2450 |
return registry.get_id("#0")
|
| 2451 |
|
| 2452 |
if i == 0:
|
|
@@ -2455,7 +2468,7 @@ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
|
| 2455 |
cin = registry.get_id("#0")
|
| 2456 |
else:
|
| 2457 |
a_bit = get_col_sum(i) if i < 21 else registry.get_id("#0")
|
| 2458 |
-
b_bit =
|
| 2459 |
cin = registry.register(f"{prefix}.prod_fa{i-1}.cout")
|
| 2460 |
|
| 2461 |
if '.xor1.layer1' in gate:
|
|
@@ -2514,7 +2527,8 @@ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
|
| 2514 |
registry.register(f"{prefix}.exp_add.fa{i}.xor2.layer2")
|
| 2515 |
registry.register(f"{prefix}.exp_add.fa{i}.cout")
|
| 2516 |
|
| 2517 |
-
|
|
|
|
| 2518 |
if '.exp_sub.fa' in gate:
|
| 2519 |
match = re.search(r'\.exp_sub\.fa(\d+)\.', gate)
|
| 2520 |
if match:
|
|
@@ -2649,9 +2663,11 @@ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
|
| 2649 |
if match:
|
| 2650 |
i = int(match.group(1))
|
| 2651 |
if '.nan_gate' in gate:
|
| 2652 |
-
|
|
|
|
| 2653 |
return [nan_bit, registry.get_id(f"{prefix}.result_is_nan")]
|
| 2654 |
if '.inf_gate' in gate:
|
|
|
|
| 2655 |
inf_bit = registry.get_id("#1") if i >= 10 and i < 15 else registry.get_id("#0")
|
| 2656 |
return [inf_bit, registry.get_id(f"{prefix}.result_is_inf")]
|
| 2657 |
if '.zero_gate' in gate:
|
|
@@ -3018,9 +3034,11 @@ def infer_float16_div_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
|
| 3018 |
if match:
|
| 3019 |
i = int(match.group(1))
|
| 3020 |
if '.nan_gate' in gate:
|
| 3021 |
-
|
|
|
|
| 3022 |
return [nan_bit, registry.get_id(f"{prefix}.result_is_nan")]
|
| 3023 |
if '.inf_gate' in gate:
|
|
|
|
| 3024 |
inf_bit = registry.get_id("#1") if i >= 10 and i < 15 else registry.get_id("#0")
|
| 3025 |
return [inf_bit, registry.get_id(f"{prefix}.result_is_inf")]
|
| 3026 |
if '.zero_gate' in gate:
|
|
@@ -3229,9 +3247,12 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
|
|
| 3229 |
registry.register(f"{prefix}.rshift_s{stage}_{i}")
|
| 3230 |
|
| 3231 |
# === NEGATION ===
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3232 |
for i in range(16):
|
| 3233 |
-
if f'.not_mag{i}' in gate:
|
| 3234 |
-
return [registry.get_id(f"{prefix}.rshift_s3_{i}")]
|
| 3235 |
registry.register(f"{prefix}.not_mag{i}")
|
| 3236 |
|
| 3237 |
if '.neg.fa' in gate:
|
|
@@ -3264,12 +3285,15 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
|
|
| 3264 |
i = int(match.group(1))
|
| 3265 |
sign = registry.get_id(f"{prefix}.$x[15]")
|
| 3266 |
not_sign = registry.register(f"{prefix}.not_sign")
|
|
|
|
| 3267 |
if '.pos_path' in gate:
|
| 3268 |
return [registry.get_id(f"{prefix}.rshift_s3_{i}"),
|
| 3269 |
-
not_sign
|
|
|
|
| 3270 |
if '.neg_path' in gate:
|
| 3271 |
return [registry.get_id(f"{prefix}.neg.fa{i}.xor.layer2"),
|
| 3272 |
-
sign
|
|
|
|
| 3273 |
|
| 3274 |
# not_sign gate
|
| 3275 |
if '.not_sign' in gate:
|
|
@@ -3294,6 +3318,8 @@ def infer_float16_fromint_inputs(gate: str, registry: SignalRegistry) -> List[in
|
|
| 3294 |
|
| 3295 |
in_bits = [f"{prefix}.$x[{i}]" for i in range(16)]
|
| 3296 |
|
|
|
|
|
|
|
| 3297 |
if '.is_zero' in gate:
|
| 3298 |
return [registry.get_id(b) for b in in_bits]
|
| 3299 |
if '.is_negative' in gate:
|
|
@@ -3302,12 +3328,16 @@ def infer_float16_fromint_inputs(gate: str, registry: SignalRegistry) -> List[in
|
|
| 3302 |
return [registry.get_id(f"{prefix}.is_negative")]
|
| 3303 |
|
| 3304 |
registry.register(f"{prefix}.is_zero")
|
|
|
|
| 3305 |
registry.register(f"{prefix}.is_negative")
|
| 3306 |
registry.register(f"{prefix}.not_negative")
|
| 3307 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3308 |
for i in range(16):
|
| 3309 |
-
if f'.not_in{i}' in gate:
|
| 3310 |
-
return [registry.get_id(in_bits[i])]
|
| 3311 |
registry.register(f"{prefix}.not_in{i}")
|
| 3312 |
|
| 3313 |
if '.abs.fa' in gate:
|
|
@@ -3421,10 +3451,14 @@ def infer_float16_fromint_inputs(gate: str, registry: SignalRegistry) -> List[in
|
|
| 3421 |
registry.register(f"{prefix}.clz_and_14_15")
|
| 3422 |
registry.register(f"{prefix}.clz1")
|
| 3423 |
|
| 3424 |
-
|
| 3425 |
-
|
|
|
|
|
|
|
| 3426 |
return [registry.get_id(f"{prefix}.ge{i}"),
|
| 3427 |
registry.get_id(f"{prefix}.not_ge{i+1}")]
|
|
|
|
|
|
|
| 3428 |
registry.register(f"{prefix}.clz_and_{i}")
|
| 3429 |
|
| 3430 |
if '.clz0' in gate:
|
|
@@ -3519,7 +3553,7 @@ def infer_float16_fromint_inputs(gate: str, registry: SignalRegistry) -> List[in
|
|
| 3519 |
val = registry.get_id(f"{prefix}.exp_calc.fa{i-10}.xor2.layer2")
|
| 3520 |
else:
|
| 3521 |
val = registry.get_id(f"{prefix}.is_negative")
|
| 3522 |
-
not_zero = registry.get_id(f"{prefix}.
|
| 3523 |
return [val, not_zero]
|
| 3524 |
|
| 3525 |
match = re.search(r'\.out(\d+)$', gate)
|
|
@@ -5971,16 +6005,20 @@ def build_float16_toint_tensors() -> Dict[str, torch.Tensor]:
|
|
| 5971 |
|
| 5972 |
# === OUTPUT SELECTION ===
|
| 5973 |
# Select between positive path, negative path, and zero
|
|
|
|
| 5974 |
|
| 5975 |
# NOT of sign bit for muxing positive path
|
| 5976 |
tensors[f"{prefix}.not_sign.weight"] = torch.tensor([-1.0])
|
| 5977 |
tensors[f"{prefix}.not_sign.bias"] = torch.tensor([0.0])
|
| 5978 |
|
| 5979 |
for i in range(16):
|
| 5980 |
-
|
| 5981 |
-
tensors[f"{prefix}.out{i}.pos_path.
|
| 5982 |
-
tensors[f"{prefix}.out{i}.
|
| 5983 |
-
|
|
|
|
|
|
|
|
|
|
| 5984 |
tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 5985 |
tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-1.0])
|
| 5986 |
|
|
@@ -6005,6 +6043,10 @@ def build_float16_fromint_tensors() -> Dict[str, torch.Tensor]:
|
|
| 6005 |
tensors[f"{prefix}.is_zero.weight"] = torch.tensor([-1.0] * 16)
|
| 6006 |
tensors[f"{prefix}.is_zero.bias"] = torch.tensor([0.0])
|
| 6007 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6008 |
# Check if negative (sign bit)
|
| 6009 |
tensors[f"{prefix}.is_negative.weight"] = torch.tensor([1.0])
|
| 6010 |
tensors[f"{prefix}.is_negative.bias"] = torch.tensor([-0.5])
|
|
@@ -6050,7 +6092,7 @@ def build_float16_fromint_tensors() -> Dict[str, torch.Tensor]:
|
|
| 6050 |
tensors[f"{prefix}.ge{k}.bias"] = torch.tensor([-float(k)])
|
| 6051 |
|
| 6052 |
# CLZ binary encoding
|
| 6053 |
-
for k in [2, 4, 8, 16]:
|
| 6054 |
tensors[f"{prefix}.not_ge{k}.weight"] = torch.tensor([-1.0])
|
| 6055 |
tensors[f"{prefix}.not_ge{k}.bias"] = torch.tensor([0.0])
|
| 6056 |
|
|
|
|
| 2441 |
return registry.get_id(f"{prefix}.col{col}_sum")
|
| 2442 |
return registry.get_id("#0")
|
| 2443 |
|
| 2444 |
+
def get_col_carry(col):
|
| 2445 |
+
# Carry from column = sum >= 2, which is ge2
|
| 2446 |
+
# For columns with count >= 3, ge2 exists
|
| 2447 |
+
# For columns with count == 2, no ge2 (but carry is rare anyway)
|
| 2448 |
+
# For single-bit columns, no carry possible
|
| 2449 |
if col == 0 or col == 20:
|
| 2450 |
return registry.get_id("#0") # No carry from single PP columns
|
| 2451 |
elif col < 21:
|
| 2452 |
+
# ge2 exists for columns with 3+ PPs
|
| 2453 |
+
# For 2-PP columns (col 1 and col 19), ge2 doesn't exist
|
| 2454 |
+
# but those columns can only produce carry if both PPs are 1,
|
| 2455 |
+
# which is relatively rare. For now, we use #0 for 2-PP columns.
|
| 2456 |
+
ge2_name = f"{prefix}.col{col}_ge2"
|
| 2457 |
+
ge2_id = registry.get_id(ge2_name)
|
| 2458 |
+
if ge2_id != -1:
|
| 2459 |
+
return ge2_id
|
| 2460 |
+
else:
|
| 2461 |
+
# 2-PP columns: no ge2, return 0 (imprecise but safe)
|
| 2462 |
+
return registry.get_id("#0")
|
| 2463 |
return registry.get_id("#0")
|
| 2464 |
|
| 2465 |
if i == 0:
|
|
|
|
| 2468 |
cin = registry.get_id("#0")
|
| 2469 |
else:
|
| 2470 |
a_bit = get_col_sum(i) if i < 21 else registry.get_id("#0")
|
| 2471 |
+
b_bit = get_col_carry(i - 1) if i < 22 else registry.get_id("#0")
|
| 2472 |
cin = registry.register(f"{prefix}.prod_fa{i-1}.cout")
|
| 2473 |
|
| 2474 |
if '.xor1.layer1' in gate:
|
|
|
|
| 2527 |
registry.register(f"{prefix}.exp_add.fa{i}.xor2.layer2")
|
| 2528 |
registry.register(f"{prefix}.exp_add.fa{i}.cout")
|
| 2529 |
|
| 2530 |
+
# NOT(15) = NOT(001111) = 110000 in 6-bit, little-endian: [0, 0, 0, 0, 1, 1]
|
| 2531 |
+
not_15_bits = [0, 0, 0, 0, 1, 1]
|
| 2532 |
if '.exp_sub.fa' in gate:
|
| 2533 |
match = re.search(r'\.exp_sub\.fa(\d+)\.', gate)
|
| 2534 |
if match:
|
|
|
|
| 2663 |
if match:
|
| 2664 |
i = int(match.group(1))
|
| 2665 |
if '.nan_gate' in gate:
|
| 2666 |
+
# Canonical NaN = 0x7E00 = 0_11111_1000000000, bits 9-14 are 1
|
| 2667 |
+
nan_bit = registry.get_id("#1") if (i >= 9 and i < 15) else registry.get_id("#0")
|
| 2668 |
return [nan_bit, registry.get_id(f"{prefix}.result_is_nan")]
|
| 2669 |
if '.inf_gate' in gate:
|
| 2670 |
+
# Inf = 0x7C00 = 0_11111_0000000000, bits 10-14 are 1
|
| 2671 |
inf_bit = registry.get_id("#1") if i >= 10 and i < 15 else registry.get_id("#0")
|
| 2672 |
return [inf_bit, registry.get_id(f"{prefix}.result_is_inf")]
|
| 2673 |
if '.zero_gate' in gate:
|
|
|
|
| 3034 |
if match:
|
| 3035 |
i = int(match.group(1))
|
| 3036 |
if '.nan_gate' in gate:
|
| 3037 |
+
# Canonical NaN = 0x7E00 = 0_11111_1000000000, bits 9-14 are 1
|
| 3038 |
+
nan_bit = registry.get_id("#1") if (i >= 9 and i < 15) else registry.get_id("#0")
|
| 3039 |
return [nan_bit, registry.get_id(f"{prefix}.result_is_nan")]
|
| 3040 |
if '.inf_gate' in gate:
|
| 3041 |
+
# Inf = 0x7C00 = 0_11111_0000000000, bits 10-14 are 1
|
| 3042 |
inf_bit = registry.get_id("#1") if i >= 10 and i < 15 else registry.get_id("#0")
|
| 3043 |
return [inf_bit, registry.get_id(f"{prefix}.result_is_inf")]
|
| 3044 |
if '.zero_gate' in gate:
|
|
|
|
| 3247 |
registry.register(f"{prefix}.rshift_s{stage}_{i}")
|
| 3248 |
|
| 3249 |
# === NEGATION ===
|
| 3250 |
+
match = re.search(r'\.not_mag(\d+)$', gate)
|
| 3251 |
+
if match:
|
| 3252 |
+
i = int(match.group(1))
|
| 3253 |
+
return [registry.get_id(f"{prefix}.rshift_s3_{i}")]
|
| 3254 |
+
|
| 3255 |
for i in range(16):
|
|
|
|
|
|
|
| 3256 |
registry.register(f"{prefix}.not_mag{i}")
|
| 3257 |
|
| 3258 |
if '.neg.fa' in gate:
|
|
|
|
| 3285 |
i = int(match.group(1))
|
| 3286 |
sign = registry.get_id(f"{prefix}.$x[15]")
|
| 3287 |
not_sign = registry.register(f"{prefix}.not_sign")
|
| 3288 |
+
not_result_zero = registry.get_id(f"{prefix}.not_result_is_zero")
|
| 3289 |
if '.pos_path' in gate:
|
| 3290 |
return [registry.get_id(f"{prefix}.rshift_s3_{i}"),
|
| 3291 |
+
not_sign,
|
| 3292 |
+
not_result_zero]
|
| 3293 |
if '.neg_path' in gate:
|
| 3294 |
return [registry.get_id(f"{prefix}.neg.fa{i}.xor.layer2"),
|
| 3295 |
+
sign,
|
| 3296 |
+
not_result_zero]
|
| 3297 |
|
| 3298 |
# not_sign gate
|
| 3299 |
if '.not_sign' in gate:
|
|
|
|
| 3318 |
|
| 3319 |
in_bits = [f"{prefix}.$x[{i}]" for i in range(16)]
|
| 3320 |
|
| 3321 |
+
if '.not_is_zero' in gate:
|
| 3322 |
+
return [registry.get_id(f"{prefix}.is_zero")]
|
| 3323 |
if '.is_zero' in gate:
|
| 3324 |
return [registry.get_id(b) for b in in_bits]
|
| 3325 |
if '.is_negative' in gate:
|
|
|
|
| 3328 |
return [registry.get_id(f"{prefix}.is_negative")]
|
| 3329 |
|
| 3330 |
registry.register(f"{prefix}.is_zero")
|
| 3331 |
+
registry.register(f"{prefix}.not_is_zero")
|
| 3332 |
registry.register(f"{prefix}.is_negative")
|
| 3333 |
registry.register(f"{prefix}.not_negative")
|
| 3334 |
|
| 3335 |
+
match = re.search(r'\.not_in(\d+)$', gate)
|
| 3336 |
+
if match:
|
| 3337 |
+
i = int(match.group(1))
|
| 3338 |
+
return [registry.get_id(in_bits[i])]
|
| 3339 |
+
|
| 3340 |
for i in range(16):
|
|
|
|
|
|
|
| 3341 |
registry.register(f"{prefix}.not_in{i}")
|
| 3342 |
|
| 3343 |
if '.abs.fa' in gate:
|
|
|
|
| 3451 |
registry.register(f"{prefix}.clz_and_14_15")
|
| 3452 |
registry.register(f"{prefix}.clz1")
|
| 3453 |
|
| 3454 |
+
match = re.search(r'\.clz_and_(\d+)$', gate)
|
| 3455 |
+
if match:
|
| 3456 |
+
i = int(match.group(1))
|
| 3457 |
+
if i in [1, 3, 5, 7, 9, 11, 13, 15]:
|
| 3458 |
return [registry.get_id(f"{prefix}.ge{i}"),
|
| 3459 |
registry.get_id(f"{prefix}.not_ge{i+1}")]
|
| 3460 |
+
|
| 3461 |
+
for i in [1, 3, 5, 7, 9, 11, 13, 15]:
|
| 3462 |
registry.register(f"{prefix}.clz_and_{i}")
|
| 3463 |
|
| 3464 |
if '.clz0' in gate:
|
|
|
|
| 3553 |
val = registry.get_id(f"{prefix}.exp_calc.fa{i-10}.xor2.layer2")
|
| 3554 |
else:
|
| 3555 |
val = registry.get_id(f"{prefix}.is_negative")
|
| 3556 |
+
not_zero = registry.get_id(f"{prefix}.not_is_zero")
|
| 3557 |
return [val, not_zero]
|
| 3558 |
|
| 3559 |
match = re.search(r'\.out(\d+)$', gate)
|
|
|
|
| 6005 |
|
| 6006 |
# === OUTPUT SELECTION ===
|
| 6007 |
# Select between positive path, negative path, and zero
|
| 6008 |
+
# Gate by not_result_is_zero to force output to 0 for |value| < 1
|
| 6009 |
|
| 6010 |
# NOT of sign bit for muxing positive path
|
| 6011 |
tensors[f"{prefix}.not_sign.weight"] = torch.tensor([-1.0])
|
| 6012 |
tensors[f"{prefix}.not_sign.bias"] = torch.tensor([0.0])
|
| 6013 |
|
| 6014 |
for i in range(16):
|
| 6015 |
+
# pos_path = shifted_value AND not_sign AND not_result_is_zero
|
| 6016 |
+
tensors[f"{prefix}.out{i}.pos_path.weight"] = torch.tensor([1.0, 1.0, 1.0])
|
| 6017 |
+
tensors[f"{prefix}.out{i}.pos_path.bias"] = torch.tensor([-3.0])
|
| 6018 |
+
# neg_path = negated_value AND sign AND not_result_is_zero
|
| 6019 |
+
tensors[f"{prefix}.out{i}.neg_path.weight"] = torch.tensor([1.0, 1.0, 1.0])
|
| 6020 |
+
tensors[f"{prefix}.out{i}.neg_path.bias"] = torch.tensor([-3.0])
|
| 6021 |
+
# out = pos_path OR neg_path
|
| 6022 |
tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 6023 |
tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-1.0])
|
| 6024 |
|
|
|
|
| 6043 |
tensors[f"{prefix}.is_zero.weight"] = torch.tensor([-1.0] * 16)
|
| 6044 |
tensors[f"{prefix}.is_zero.bias"] = torch.tensor([0.0])
|
| 6045 |
|
| 6046 |
+
# NOT is_zero for gating normal output
|
| 6047 |
+
tensors[f"{prefix}.not_is_zero.weight"] = torch.tensor([-1.0])
|
| 6048 |
+
tensors[f"{prefix}.not_is_zero.bias"] = torch.tensor([0.0])
|
| 6049 |
+
|
| 6050 |
# Check if negative (sign bit)
|
| 6051 |
tensors[f"{prefix}.is_negative.weight"] = torch.tensor([1.0])
|
| 6052 |
tensors[f"{prefix}.is_negative.bias"] = torch.tensor([-0.5])
|
|
|
|
| 6092 |
tensors[f"{prefix}.ge{k}.bias"] = torch.tensor([-float(k)])
|
| 6093 |
|
| 6094 |
# CLZ binary encoding
|
| 6095 |
+
for k in [2, 4, 6, 8, 10, 12, 14, 16]:
|
| 6096 |
tensors[f"{prefix}.not_ge{k}.weight"] = torch.tensor([-1.0])
|
| 6097 |
tensors[f"{prefix}.not_ge{k}.bias"] = torch.tensor([0.0])
|
| 6098 |
|