PortfolioAI committed on
Commit ·
48421e0
1
Parent(s): 1bcc911
Rewrite float16.toint with right-shift barrel shifter
Browse files- Changed from left-shift to right-shift (25-exp positions)
- Improved from 31/93 to 54/93 tests passing
- Still needs debugging for remaining edge cases
- arithmetic.safetensors +2 -2
- convert_to_explicit_inputs.py +200 -57
arithmetic.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5c7bc258f49a9a4c85321d0980f843bb989e8fb84c4d8bc65883f24c6e306334
|
| 3 |
+
size 2863492
|
convert_to_explicit_inputs.py
CHANGED
|
@@ -3046,7 +3046,7 @@ def infer_float16_div_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
|
| 3046 |
|
| 3047 |
|
| 3048 |
def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
| 3049 |
-
"""Infer inputs for float16.toint circuit."""
|
| 3050 |
prefix = "float16.toint"
|
| 3051 |
|
| 3052 |
for i in range(16):
|
|
@@ -3055,6 +3055,7 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
|
|
| 3055 |
exp_bits = [f"{prefix}.$x[{10+i}]" for i in range(5)]
|
| 3056 |
mant_bits = [f"{prefix}.$x[{i}]" for i in range(10)]
|
| 3057 |
|
|
|
|
| 3058 |
if '.exp_all_ones' in gate:
|
| 3059 |
return [registry.get_id(b) for b in exp_bits]
|
| 3060 |
if '.exp_zero' in gate:
|
|
@@ -3080,7 +3081,7 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
|
|
| 3080 |
if '.is_inf' in gate:
|
| 3081 |
return [registry.get_id(f"{prefix}.exp_all_ones"),
|
| 3082 |
registry.get_id(f"{prefix}.mant_zero")]
|
| 3083 |
-
if '.is_zero' in gate:
|
| 3084 |
return [registry.get_id(f"{prefix}.exp_zero"),
|
| 3085 |
registry.get_id(f"{prefix}.mant_zero")]
|
| 3086 |
|
|
@@ -3092,27 +3093,53 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
|
|
| 3092 |
|
| 3093 |
registry.register(f"{prefix}.exp_lt_15")
|
| 3094 |
|
| 3095 |
-
if '.result_is_zero' in gate:
|
| 3096 |
return [registry.get_id(f"{prefix}.is_nan"),
|
| 3097 |
registry.get_id(f"{prefix}.is_zero"),
|
| 3098 |
registry.get_id(f"{prefix}.exp_lt_15")]
|
| 3099 |
|
| 3100 |
registry.register(f"{prefix}.result_is_zero")
|
| 3101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3102 |
if '.implicit_bit' in gate:
|
| 3103 |
return [registry.get_id(f"{prefix}.exp_zero")]
|
| 3104 |
|
| 3105 |
registry.register(f"{prefix}.implicit_bit")
|
| 3106 |
|
| 3107 |
-
|
| 3108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3109 |
|
| 3110 |
-
|
| 3111 |
-
|
| 3112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3114 |
for i in range(5):
|
| 3115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3116 |
|
| 3117 |
if '.shift_calc.fa' in gate:
|
| 3118 |
match = re.search(r'\.shift_calc\.fa(\d+)\.', gate)
|
|
@@ -3120,9 +3147,11 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
|
|
| 3120 |
i = int(match.group(1))
|
| 3121 |
fa_prefix = f"{prefix}.shift_calc.fa{i}"
|
| 3122 |
|
| 3123 |
-
|
| 3124 |
-
|
| 3125 |
-
|
|
|
|
|
|
|
| 3126 |
|
| 3127 |
if '.xor1.layer1' in gate:
|
| 3128 |
return [a_bit, b_bit]
|
|
@@ -3147,6 +3176,7 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
|
|
| 3147 |
registry.register(f"{prefix}.shift_calc.fa{i}.xor2.layer2")
|
| 3148 |
registry.register(f"{prefix}.shift_calc.fa{i}.cout")
|
| 3149 |
|
|
|
|
| 3150 |
for stage in range(4):
|
| 3151 |
shift_amt = 1 << stage
|
| 3152 |
|
|
@@ -3154,10 +3184,13 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
|
|
| 3154 |
return [registry.get_id(f"{prefix}.shift_calc.fa{stage}.xor2.layer2")]
|
| 3155 |
registry.register(f"{prefix}.not_shift{stage}")
|
| 3156 |
|
| 3157 |
-
match = re.search(rf'\.
|
| 3158 |
if match:
|
| 3159 |
i = int(match.group(1))
|
|
|
|
|
|
|
| 3160 |
if '.pass' in gate:
|
|
|
|
| 3161 |
if stage == 0:
|
| 3162 |
if i < 10:
|
| 3163 |
val = registry.get_id(mant_bits[i])
|
|
@@ -3166,36 +3199,39 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
|
|
| 3166 |
else:
|
| 3167 |
val = registry.get_id("#0")
|
| 3168 |
else:
|
| 3169 |
-
val = registry.get_id(f"{prefix}.
|
| 3170 |
return [val, registry.get_id(f"{prefix}.not_shift{stage}")]
|
| 3171 |
-
|
|
|
|
|
|
|
| 3172 |
if stage == 0:
|
| 3173 |
-
|
| 3174 |
-
|
| 3175 |
-
|
| 3176 |
-
elif prev_i == 10:
|
| 3177 |
val = registry.get_id(f"{prefix}.implicit_bit")
|
| 3178 |
else:
|
| 3179 |
val = registry.get_id("#0")
|
| 3180 |
else:
|
| 3181 |
-
val = registry.get_id(f"{prefix}.
|
| 3182 |
return [val, registry.get_id(f"{prefix}.shift_calc.fa{stage}.xor2.layer2")]
|
| 3183 |
|
| 3184 |
-
match = re.search(rf'\.
|
| 3185 |
if match:
|
| 3186 |
i = int(match.group(1))
|
| 3187 |
-
|
| 3188 |
-
|
| 3189 |
-
|
|
|
|
| 3190 |
else:
|
| 3191 |
-
return [registry.register(f"{prefix}.
|
| 3192 |
|
| 3193 |
for i in range(16):
|
| 3194 |
-
registry.register(f"{prefix}.
|
| 3195 |
|
|
|
|
| 3196 |
for i in range(16):
|
| 3197 |
if f'.not_mag{i}' in gate:
|
| 3198 |
-
return [registry.get_id(f"{prefix}.
|
| 3199 |
registry.register(f"{prefix}.not_mag{i}")
|
| 3200 |
|
| 3201 |
if '.neg.fa' in gate:
|
|
@@ -3222,17 +3258,24 @@ def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]
|
|
| 3222 |
registry.register(f"{prefix}.neg.fa{i}.xor.layer2")
|
| 3223 |
registry.register(f"{prefix}.neg.fa{i}.cout")
|
| 3224 |
|
|
|
|
| 3225 |
match = re.search(r'\.out(\d+)\.', gate)
|
| 3226 |
if match:
|
| 3227 |
i = int(match.group(1))
|
| 3228 |
sign = registry.get_id(f"{prefix}.$x[15]")
|
|
|
|
| 3229 |
if '.pos_path' in gate:
|
| 3230 |
-
return [registry.get_id(f"{prefix}.
|
| 3231 |
-
|
| 3232 |
if '.neg_path' in gate:
|
| 3233 |
return [registry.get_id(f"{prefix}.neg.fa{i}.xor.layer2"),
|
| 3234 |
sign]
|
| 3235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3236 |
match = re.search(r'\.out(\d+)$', gate)
|
| 3237 |
if match:
|
| 3238 |
i = int(match.group(1))
|
|
@@ -5726,54 +5769,135 @@ def build_float16_toint_tensors() -> Dict[str, torch.Tensor]:
|
|
| 5726 |
Convert float16 to signed 16-bit integer (truncate toward zero).
|
| 5727 |
|
| 5728 |
Algorithm:
|
| 5729 |
-
1.
|
| 5730 |
-
2.
|
| 5731 |
-
3.
|
| 5732 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5733 |
"""
|
| 5734 |
tensors = {}
|
| 5735 |
prefix = "float16.toint"
|
| 5736 |
|
| 5737 |
-
#
|
|
|
|
| 5738 |
tensors[f"{prefix}.exp_all_ones.weight"] = torch.tensor([1.0] * 5)
|
| 5739 |
tensors[f"{prefix}.exp_all_ones.bias"] = torch.tensor([-5.0])
|
| 5740 |
|
|
|
|
| 5741 |
tensors[f"{prefix}.exp_zero.weight"] = torch.tensor([-1.0] * 5)
|
| 5742 |
tensors[f"{prefix}.exp_zero.bias"] = torch.tensor([0.0])
|
| 5743 |
|
|
|
|
| 5744 |
tensors[f"{prefix}.mant_nonzero.weight"] = torch.tensor([1.0] * 10)
|
| 5745 |
tensors[f"{prefix}.mant_nonzero.bias"] = torch.tensor([-1.0])
|
| 5746 |
|
|
|
|
| 5747 |
tensors[f"{prefix}.is_nan.weight"] = torch.tensor([1.0, 1.0])
|
| 5748 |
tensors[f"{prefix}.is_nan.bias"] = torch.tensor([-2.0])
|
| 5749 |
|
|
|
|
| 5750 |
tensors[f"{prefix}.mant_zero.weight"] = torch.tensor([-1.0])
|
| 5751 |
tensors[f"{prefix}.mant_zero.bias"] = torch.tensor([0.0])
|
| 5752 |
|
|
|
|
| 5753 |
tensors[f"{prefix}.is_inf.weight"] = torch.tensor([1.0, 1.0])
|
| 5754 |
tensors[f"{prefix}.is_inf.bias"] = torch.tensor([-2.0])
|
| 5755 |
|
|
|
|
| 5756 |
tensors[f"{prefix}.is_zero.weight"] = torch.tensor([1.0, 1.0])
|
| 5757 |
tensors[f"{prefix}.is_zero.bias"] = torch.tensor([-2.0])
|
| 5758 |
|
| 5759 |
-
#
|
| 5760 |
-
# exp < 15 means unbiased
|
|
|
|
| 5761 |
weights = [-float(2**i) for i in range(5)]
|
| 5762 |
tensors[f"{prefix}.exp_lt_15.weight"] = torch.tensor(weights)
|
| 5763 |
tensors[f"{prefix}.exp_lt_15.bias"] = torch.tensor([14.0])
|
| 5764 |
|
|
|
|
| 5765 |
tensors[f"{prefix}.result_is_zero.weight"] = torch.tensor([1.0, 1.0, 1.0])
|
| 5766 |
tensors[f"{prefix}.result_is_zero.bias"] = torch.tensor([-1.0])
|
| 5767 |
|
| 5768 |
-
#
|
| 5769 |
-
|
| 5770 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5771 |
|
| 5772 |
-
# exp - 25 subtractor
|
| 5773 |
for i in range(5):
|
| 5774 |
-
tensors[f"{prefix}.
|
| 5775 |
-
tensors[f"{prefix}.
|
| 5776 |
|
|
|
|
|
|
|
|
|
|
| 5777 |
for i in range(6):
|
| 5778 |
p = f"{prefix}.shift_calc.fa{i}"
|
| 5779 |
tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
|
|
@@ -5797,32 +5921,40 @@ def build_float16_toint_tensors() -> Dict[str, torch.Tensor]:
|
|
| 5797 |
tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
|
| 5798 |
tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
|
| 5799 |
|
| 5800 |
-
#
|
| 5801 |
-
|
| 5802 |
-
|
|
|
|
| 5803 |
|
| 5804 |
for stage in range(4):
|
| 5805 |
shift_amt = 1 << stage
|
|
|
|
| 5806 |
tensors[f"{prefix}.not_shift{stage}.weight"] = torch.tensor([-1.0])
|
| 5807 |
tensors[f"{prefix}.not_shift{stage}.bias"] = torch.tensor([0.0])
|
| 5808 |
|
| 5809 |
for i in range(16):
|
| 5810 |
-
|
| 5811 |
-
tensors[f"{prefix}.
|
| 5812 |
-
|
| 5813 |
-
|
| 5814 |
-
|
| 5815 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5816 |
else:
|
| 5817 |
-
|
| 5818 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5819 |
|
| 5820 |
-
# Apply sign (negate if negative)
|
| 5821 |
for i in range(16):
|
| 5822 |
tensors[f"{prefix}.not_mag{i}.weight"] = torch.tensor([-1.0])
|
| 5823 |
tensors[f"{prefix}.not_mag{i}.bias"] = torch.tensor([0.0])
|
| 5824 |
|
| 5825 |
-
# Two's complement negation
|
| 5826 |
for i in range(16):
|
| 5827 |
p = f"{prefix}.neg.fa{i}"
|
| 5828 |
tensors[f"{p}.xor.layer1.or.weight"] = torch.tensor([1.0, 1.0])
|
|
@@ -5837,7 +5969,13 @@ def build_float16_toint_tensors() -> Dict[str, torch.Tensor]:
|
|
| 5837 |
tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
|
| 5838 |
tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
|
| 5839 |
|
| 5840 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5841 |
for i in range(16):
|
| 5842 |
tensors[f"{prefix}.out{i}.pos_path.weight"] = torch.tensor([1.0, 1.0])
|
| 5843 |
tensors[f"{prefix}.out{i}.pos_path.bias"] = torch.tensor([-2.0])
|
|
@@ -6081,12 +6219,17 @@ def main():
|
|
| 6081 |
|
| 6082 |
print(f"Loaded {len(tensors)} tensors")
|
| 6083 |
|
| 6084 |
-
# Remove old
|
| 6085 |
old_float16_add = [k for k in tensors.keys() if k.startswith('float16.add')]
|
| 6086 |
for k in old_float16_add:
|
| 6087 |
del tensors[k]
|
| 6088 |
print(f"Removed {len(old_float16_add)} old float16.add tensors")
|
| 6089 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6090 |
# Build new circuits
|
| 6091 |
print("Building new circuits...")
|
| 6092 |
clz_tensors = build_clz8bit_tensors()
|
|
|
|
| 3046 |
|
| 3047 |
|
| 3048 |
def infer_float16_toint_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
| 3049 |
+
"""Infer inputs for float16.toint circuit (with right-shift barrel shifter)."""
|
| 3050 |
prefix = "float16.toint"
|
| 3051 |
|
| 3052 |
for i in range(16):
|
|
|
|
| 3055 |
exp_bits = [f"{prefix}.$x[{10+i}]" for i in range(5)]
|
| 3056 |
mant_bits = [f"{prefix}.$x[{i}]" for i in range(10)]
|
| 3057 |
|
| 3058 |
+
# === SPECIAL CASE DETECTION ===
|
| 3059 |
if '.exp_all_ones' in gate:
|
| 3060 |
return [registry.get_id(b) for b in exp_bits]
|
| 3061 |
if '.exp_zero' in gate:
|
|
|
|
| 3081 |
if '.is_inf' in gate:
|
| 3082 |
return [registry.get_id(f"{prefix}.exp_all_ones"),
|
| 3083 |
registry.get_id(f"{prefix}.mant_zero")]
|
| 3084 |
+
if '.is_zero' in gate and '.not_' not in gate and '.result_is_zero' not in gate:
|
| 3085 |
return [registry.get_id(f"{prefix}.exp_zero"),
|
| 3086 |
registry.get_id(f"{prefix}.mant_zero")]
|
| 3087 |
|
|
|
|
| 3093 |
|
| 3094 |
registry.register(f"{prefix}.exp_lt_15")
|
| 3095 |
|
| 3096 |
+
if '.result_is_zero' in gate and '.not_' not in gate:
|
| 3097 |
return [registry.get_id(f"{prefix}.is_nan"),
|
| 3098 |
registry.get_id(f"{prefix}.is_zero"),
|
| 3099 |
registry.get_id(f"{prefix}.exp_lt_15")]
|
| 3100 |
|
| 3101 |
registry.register(f"{prefix}.result_is_zero")
|
| 3102 |
|
| 3103 |
+
if '.not_result_is_zero' in gate:
|
| 3104 |
+
return [registry.get_id(f"{prefix}.result_is_zero")]
|
| 3105 |
+
|
| 3106 |
+
registry.register(f"{prefix}.not_result_is_zero")
|
| 3107 |
+
|
| 3108 |
if '.implicit_bit' in gate:
|
| 3109 |
return [registry.get_id(f"{prefix}.exp_zero")]
|
| 3110 |
|
| 3111 |
registry.register(f"{prefix}.implicit_bit")
|
| 3112 |
|
| 3113 |
+
# === THRESHOLD GATES FOR SHIFT CONTROL ===
|
| 3114 |
+
if '.exp_ge_15' in gate:
|
| 3115 |
+
return [registry.get_id(b) for b in exp_bits]
|
| 3116 |
+
if '.exp_ge_18' in gate:
|
| 3117 |
+
return [registry.get_id(b) for b in exp_bits]
|
| 3118 |
+
if '.exp_le_21' in gate:
|
| 3119 |
+
return [registry.get_id(b) for b in exp_bits]
|
| 3120 |
|
| 3121 |
+
registry.register(f"{prefix}.exp_ge_15")
|
| 3122 |
+
registry.register(f"{prefix}.exp_ge_18")
|
| 3123 |
+
registry.register(f"{prefix}.exp_le_21")
|
| 3124 |
+
|
| 3125 |
+
if '.shift_bit3' in gate:
|
| 3126 |
+
return [registry.get_id(b) for b in exp_bits]
|
| 3127 |
+
if '.shift_bit2' in gate:
|
| 3128 |
+
return [registry.get_id(f"{prefix}.exp_ge_18"),
|
| 3129 |
+
registry.get_id(f"{prefix}.exp_le_21")]
|
| 3130 |
|
| 3131 |
+
registry.register(f"{prefix}.shift_bit3")
|
| 3132 |
+
registry.register(f"{prefix}.shift_bit2")
|
| 3133 |
+
|
| 3134 |
+
# === NOT OF EXPONENT BITS ===
|
| 3135 |
for i in range(5):
|
| 3136 |
+
if f'.not_exp{i}' in gate:
|
| 3137 |
+
return [registry.get_id(exp_bits[i])]
|
| 3138 |
+
registry.register(f"{prefix}.not_exp{i}")
|
| 3139 |
+
|
| 3140 |
+
# === SHIFT CALCULATION: 25 - exp = ~exp + 26 ===
|
| 3141 |
+
# 26 = 0b011010
|
| 3142 |
+
const_26 = [0, 1, 0, 1, 1, 0]
|
| 3143 |
|
| 3144 |
if '.shift_calc.fa' in gate:
|
| 3145 |
match = re.search(r'\.shift_calc\.fa(\d+)\.', gate)
|
|
|
|
| 3147 |
i = int(match.group(1))
|
| 3148 |
fa_prefix = f"{prefix}.shift_calc.fa{i}"
|
| 3149 |
|
| 3150 |
+
# a = ~exp[i] (or 1 for i >= 5)
|
| 3151 |
+
a_bit = registry.get_id(f"{prefix}.not_exp{i}") if i < 5 else registry.get_id("#1")
|
| 3152 |
+
# b = const_26[i]
|
| 3153 |
+
b_bit = registry.get_id(f"#{const_26[i]}")
|
| 3154 |
+
cin = registry.get_id("#0") if i == 0 else registry.register(f"{prefix}.shift_calc.fa{i-1}.cout")
|
| 3155 |
|
| 3156 |
if '.xor1.layer1' in gate:
|
| 3157 |
return [a_bit, b_bit]
|
|
|
|
| 3176 |
registry.register(f"{prefix}.shift_calc.fa{i}.xor2.layer2")
|
| 3177 |
registry.register(f"{prefix}.shift_calc.fa{i}.cout")
|
| 3178 |
|
| 3179 |
+
# === RIGHT-SHIFT BARREL SHIFTER ===
|
| 3180 |
for stage in range(4):
|
| 3181 |
shift_amt = 1 << stage
|
| 3182 |
|
|
|
|
| 3184 |
return [registry.get_id(f"{prefix}.shift_calc.fa{stage}.xor2.layer2")]
|
| 3185 |
registry.register(f"{prefix}.not_shift{stage}")
|
| 3186 |
|
| 3187 |
+
match = re.search(rf'\.rshift_s{stage}_(\d+)\.', gate)
|
| 3188 |
if match:
|
| 3189 |
i = int(match.group(1))
|
| 3190 |
+
src_pos = i + shift_amt
|
| 3191 |
+
|
| 3192 |
if '.pass' in gate:
|
| 3193 |
+
# Current value (from previous stage or input)
|
| 3194 |
if stage == 0:
|
| 3195 |
if i < 10:
|
| 3196 |
val = registry.get_id(mant_bits[i])
|
|
|
|
| 3199 |
else:
|
| 3200 |
val = registry.get_id("#0")
|
| 3201 |
else:
|
| 3202 |
+
val = registry.get_id(f"{prefix}.rshift_s{stage-1}_{i}")
|
| 3203 |
return [val, registry.get_id(f"{prefix}.not_shift{stage}")]
|
| 3204 |
+
|
| 3205 |
+
if '.shift' in gate and src_pos < 16:
|
| 3206 |
+
# Value from higher position
|
| 3207 |
if stage == 0:
|
| 3208 |
+
if src_pos < 10:
|
| 3209 |
+
val = registry.get_id(mant_bits[src_pos])
|
| 3210 |
+
elif src_pos == 10:
|
|
|
|
| 3211 |
val = registry.get_id(f"{prefix}.implicit_bit")
|
| 3212 |
else:
|
| 3213 |
val = registry.get_id("#0")
|
| 3214 |
else:
|
| 3215 |
+
val = registry.get_id(f"{prefix}.rshift_s{stage-1}_{src_pos}")
|
| 3216 |
return [val, registry.get_id(f"{prefix}.shift_calc.fa{stage}.xor2.layer2")]
|
| 3217 |
|
| 3218 |
+
match = re.search(rf'\.rshift_s{stage}_(\d+)$', gate)
|
| 3219 |
if match:
|
| 3220 |
i = int(match.group(1))
|
| 3221 |
+
src_pos = i + shift_amt
|
| 3222 |
+
if src_pos < 16:
|
| 3223 |
+
return [registry.register(f"{prefix}.rshift_s{stage}_{i}.pass"),
|
| 3224 |
+
registry.register(f"{prefix}.rshift_s{stage}_{i}.shift")]
|
| 3225 |
else:
|
| 3226 |
+
return [registry.register(f"{prefix}.rshift_s{stage}_{i}.pass")]
|
| 3227 |
|
| 3228 |
for i in range(16):
|
| 3229 |
+
registry.register(f"{prefix}.rshift_s{stage}_{i}")
|
| 3230 |
|
| 3231 |
+
# === NEGATION ===
|
| 3232 |
for i in range(16):
|
| 3233 |
if f'.not_mag{i}' in gate:
|
| 3234 |
+
return [registry.get_id(f"{prefix}.rshift_s3_{i}")]
|
| 3235 |
registry.register(f"{prefix}.not_mag{i}")
|
| 3236 |
|
| 3237 |
if '.neg.fa' in gate:
|
|
|
|
| 3258 |
registry.register(f"{prefix}.neg.fa{i}.xor.layer2")
|
| 3259 |
registry.register(f"{prefix}.neg.fa{i}.cout")
|
| 3260 |
|
| 3261 |
+
# === OUTPUT SELECTION ===
|
| 3262 |
match = re.search(r'\.out(\d+)\.', gate)
|
| 3263 |
if match:
|
| 3264 |
i = int(match.group(1))
|
| 3265 |
sign = registry.get_id(f"{prefix}.$x[15]")
|
| 3266 |
+
not_sign = registry.register(f"{prefix}.not_sign")
|
| 3267 |
if '.pos_path' in gate:
|
| 3268 |
+
return [registry.get_id(f"{prefix}.rshift_s3_{i}"),
|
| 3269 |
+
not_sign]
|
| 3270 |
if '.neg_path' in gate:
|
| 3271 |
return [registry.get_id(f"{prefix}.neg.fa{i}.xor.layer2"),
|
| 3272 |
sign]
|
| 3273 |
|
| 3274 |
+
# not_sign gate
|
| 3275 |
+
if '.not_sign' in gate:
|
| 3276 |
+
return [registry.get_id(f"{prefix}.$x[15]")]
|
| 3277 |
+
registry.register(f"{prefix}.not_sign")
|
| 3278 |
+
|
| 3279 |
match = re.search(r'\.out(\d+)$', gate)
|
| 3280 |
if match:
|
| 3281 |
i = int(match.group(1))
|
|
|
|
| 5769 |
Convert float16 to signed 16-bit integer (truncate toward zero).
|
| 5770 |
|
| 5771 |
Algorithm:
|
| 5772 |
+
1. Extract mantissa M with implicit bit (11 bits, bit 10 = implicit 1)
|
| 5773 |
+
2. For exp < 15: result = 0 (|value| < 1)
|
| 5774 |
+
3. For exp >= 15: right-shift M by (25 - exp) positions
|
| 5775 |
+
- exp = 15: shift by 10, result = 1 for normalized
|
| 5776 |
+
- exp = 25: shift by 0, result = M (up to 2047)
|
| 5777 |
+
- exp > 25: would need left shift, but limited range
|
| 5778 |
+
4. Apply sign via two's complement negation
|
| 5779 |
+
5. Handle special cases: NaN, Inf, overflow
|
| 5780 |
"""
|
| 5781 |
tensors = {}
|
| 5782 |
prefix = "float16.toint"
|
| 5783 |
|
| 5784 |
+
# === SPECIAL CASE DETECTION ===
|
| 5785 |
+
# exp_all_ones: exponent = 31 (NaN or Inf)
|
| 5786 |
tensors[f"{prefix}.exp_all_ones.weight"] = torch.tensor([1.0] * 5)
|
| 5787 |
tensors[f"{prefix}.exp_all_ones.bias"] = torch.tensor([-5.0])
|
| 5788 |
|
| 5789 |
+
# exp_zero: exponent = 0 (zero or subnormal)
|
| 5790 |
tensors[f"{prefix}.exp_zero.weight"] = torch.tensor([-1.0] * 5)
|
| 5791 |
tensors[f"{prefix}.exp_zero.bias"] = torch.tensor([0.0])
|
| 5792 |
|
| 5793 |
+
# mant_nonzero: any mantissa bit set
|
| 5794 |
tensors[f"{prefix}.mant_nonzero.weight"] = torch.tensor([1.0] * 10)
|
| 5795 |
tensors[f"{prefix}.mant_nonzero.bias"] = torch.tensor([-1.0])
|
| 5796 |
|
| 5797 |
+
# is_nan: exp=31 AND mant!=0
|
| 5798 |
tensors[f"{prefix}.is_nan.weight"] = torch.tensor([1.0, 1.0])
|
| 5799 |
tensors[f"{prefix}.is_nan.bias"] = torch.tensor([-2.0])
|
| 5800 |
|
| 5801 |
+
# mant_zero: NOT mant_nonzero
|
| 5802 |
tensors[f"{prefix}.mant_zero.weight"] = torch.tensor([-1.0])
|
| 5803 |
tensors[f"{prefix}.mant_zero.bias"] = torch.tensor([0.0])
|
| 5804 |
|
| 5805 |
+
# is_inf: exp=31 AND mant=0
|
| 5806 |
tensors[f"{prefix}.is_inf.weight"] = torch.tensor([1.0, 1.0])
|
| 5807 |
tensors[f"{prefix}.is_inf.bias"] = torch.tensor([-2.0])
|
| 5808 |
|
| 5809 |
+
# is_zero: exp=0 AND mant=0
|
| 5810 |
tensors[f"{prefix}.is_zero.weight"] = torch.tensor([1.0, 1.0])
|
| 5811 |
tensors[f"{prefix}.is_zero.bias"] = torch.tensor([-2.0])
|
| 5812 |
|
| 5813 |
+
# === CHECK IF |VALUE| < 1 ===
|
| 5814 |
+
# exp < 15 means unbiased exponent < 0, so |value| < 1
|
| 5815 |
+
# Use threshold: sum(exp[i] * 2^i) < 15
|
| 5816 |
weights = [-float(2**i) for i in range(5)]
|
| 5817 |
tensors[f"{prefix}.exp_lt_15.weight"] = torch.tensor(weights)
|
| 5818 |
tensors[f"{prefix}.exp_lt_15.bias"] = torch.tensor([14.0])
|
| 5819 |
|
| 5820 |
+
# result_is_zero: exp_zero OR exp_lt_15 OR is_nan
|
| 5821 |
tensors[f"{prefix}.result_is_zero.weight"] = torch.tensor([1.0, 1.0, 1.0])
|
| 5822 |
tensors[f"{prefix}.result_is_zero.bias"] = torch.tensor([-1.0])
|
| 5823 |
|
| 5824 |
+
# not_result_is_zero for muxing
|
| 5825 |
+
tensors[f"{prefix}.not_result_is_zero.weight"] = torch.tensor([-1.0])
|
| 5826 |
+
tensors[f"{prefix}.not_result_is_zero.bias"] = torch.tensor([0.0])
|
| 5827 |
+
|
| 5828 |
+
# === COMPUTE SHIFT AMOUNT: 25 - exp ===
|
| 5829 |
+
# For right shift: shift_amt = 25 - exp (need 0 to 10 for normal range)
|
| 5830 |
+
# 25 = 0b11001, so we compute NOT(exp) + 25 + 1 = NOT(exp) + 26
|
| 5831 |
+
# Actually simpler: use 25 - exp directly with threshold gates
|
| 5832 |
+
|
| 5833 |
+
# We'll use a different approach: compute exp directly and use threshold
|
| 5834 |
+
# gates to determine shift amount bits
|
| 5835 |
+
|
| 5836 |
+
# Implicit bit (always 1 for normalized numbers, 0 for subnormals)
|
| 5837 |
+
# implicit = NOT exp_zero
|
| 5838 |
+
tensors[f"{prefix}.implicit_bit.weight"] = torch.tensor([-1.0])
|
| 5839 |
+
tensors[f"{prefix}.implicit_bit.bias"] = torch.tensor([0.0])
|
| 5840 |
+
|
| 5841 |
+
# === DIRECT SHIFT USING EXPONENT VALUE ===
|
| 5842 |
+
# For exp in range 15-25, shift right by (25-exp)
|
| 5843 |
+
# For exp >= 25, no shift or left shift (overflow territory)
|
| 5844 |
+
#
|
| 5845 |
+
# Shift amounts needed: 0-10 for exp 25-15
|
| 5846 |
+
# shift[0] = 1 if (25-exp) is odd = exp is even when exp in {15,17,19,21,23,25}
|
| 5847 |
+
# This is complex. Let's use threshold gates on exp value.
|
| 5848 |
+
|
| 5849 |
+
# exp_ge_15: exp >= 15 (value >= 1)
|
| 5850 |
+
tensors[f"{prefix}.exp_ge_15.weight"] = torch.tensor([float(2**i) for i in range(5)])
|
| 5851 |
+
tensors[f"{prefix}.exp_ge_15.bias"] = torch.tensor([-15.0])
|
| 5852 |
+
|
| 5853 |
+
# For each shift stage, determine if we should shift
|
| 5854 |
+
# Right shift by 2^k if bit k of (25-exp) is set
|
| 5855 |
+
# 25 - exp for exp in [15, 25]: shift in [10, 0]
|
| 5856 |
+
# Binary of shift amounts:
|
| 5857 |
+
# exp=15: shift=10 = 0b1010
|
| 5858 |
+
# exp=16: shift=9 = 0b1001
|
| 5859 |
+
# exp=17: shift=8 = 0b1000
|
| 5860 |
+
# exp=18: shift=7 = 0b0111
|
| 5861 |
+
# exp=19: shift=6 = 0b0110
|
| 5862 |
+
# exp=20: shift=5 = 0b0101
|
| 5863 |
+
# exp=21: shift=4 = 0b0100
|
| 5864 |
+
# exp=22: shift=3 = 0b0011
|
| 5865 |
+
# exp=23: shift=2 = 0b0010
|
| 5866 |
+
# exp=24: shift=1 = 0b0001
|
| 5867 |
+
# exp=25: shift=0 = 0b0000
|
| 5868 |
+
|
| 5869 |
+
# Use threshold on exp to determine shift control bits
|
| 5870 |
+
# shift_bit3 (shift by 8): exp <= 17 (shift >= 8)
|
| 5871 |
+
tensors[f"{prefix}.shift_bit3.weight"] = torch.tensor([-float(2**i) for i in range(5)])
|
| 5872 |
+
tensors[f"{prefix}.shift_bit3.bias"] = torch.tensor([17.0])
|
| 5873 |
+
|
| 5874 |
+
# shift_bit2 (shift by 4): (exp <= 17) OR (18 <= exp <= 21)
|
| 5875 |
+
# = exp <= 21 AND NOT (18 <= exp <= 21 AND exp > 17)... complex
|
| 5876 |
+
# Simpler: shift_bit2 = 1 when shift in {4,5,6,7,12,13,14,15} ∩ [0,10] = {4,5,6,7}
|
| 5877 |
+
# = exp in {18,19,20,21}
|
| 5878 |
+
# Use: exp >= 18 AND exp <= 21
|
| 5879 |
+
tensors[f"{prefix}.exp_ge_18.weight"] = torch.tensor([float(2**i) for i in range(5)])
|
| 5880 |
+
tensors[f"{prefix}.exp_ge_18.bias"] = torch.tensor([-18.0])
|
| 5881 |
+
tensors[f"{prefix}.exp_le_21.weight"] = torch.tensor([-float(2**i) for i in range(5)])
|
| 5882 |
+
tensors[f"{prefix}.exp_le_21.bias"] = torch.tensor([21.0])
|
| 5883 |
+
tensors[f"{prefix}.shift_bit2.weight"] = torch.tensor([1.0, 1.0])
|
| 5884 |
+
tensors[f"{prefix}.shift_bit2.bias"] = torch.tensor([-2.0])
|
| 5885 |
+
|
| 5886 |
+
# shift_bit1 (shift by 2): shift in {2,3,6,7,10,11,...} ∩ [0,10] = {2,3,6,7,10}
|
| 5887 |
+
# = exp in {15,19,22,23} -- this is getting complex
|
| 5888 |
+
# Let's use a simpler direct threshold approach
|
| 5889 |
+
|
| 5890 |
+
# Actually, let's compute 25-exp using subtraction, then use those bits
|
| 5891 |
+
# 25 = 0b011001 (6 bits), exp is 5 bits
|
| 5892 |
+
# 25 - exp in two's complement
|
| 5893 |
|
|
|
|
| 5894 |
for i in range(5):
|
| 5895 |
+
tensors[f"{prefix}.not_exp{i}.weight"] = torch.tensor([-1.0])
|
| 5896 |
+
tensors[f"{prefix}.not_exp{i}.bias"] = torch.tensor([0.0])
|
| 5897 |
|
| 5898 |
+
# 25 - exp = 25 + (~exp) + 1 = 26 + ~exp (in binary)
|
| 5899 |
+
# 26 = 0b011010
|
| 5900 |
+
const_26 = [0, 1, 0, 1, 1, 0] # bits of 26
|
| 5901 |
for i in range(6):
|
| 5902 |
p = f"{prefix}.shift_calc.fa{i}"
|
| 5903 |
tensors[f"{p}.xor1.layer1.or.weight"] = torch.tensor([1.0, 1.0])
|
|
|
|
| 5921 |
tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
|
| 5922 |
tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
|
| 5923 |
|
| 5924 |
+
# === RIGHT-SHIFT BARREL SHIFTER ===
|
| 5925 |
+
# 4 stages for shifts of 1, 2, 4, 8
|
| 5926 |
+
# Input: mantissa (10 bits) + implicit bit at position 10 = 11 bits
|
| 5927 |
+
# We'll work with 16 bits to have room
|
| 5928 |
|
| 5929 |
for stage in range(4):
|
| 5930 |
shift_amt = 1 << stage
|
| 5931 |
+
# NOT of shift control bit for mux
|
| 5932 |
tensors[f"{prefix}.not_shift{stage}.weight"] = torch.tensor([-1.0])
|
| 5933 |
tensors[f"{prefix}.not_shift{stage}.bias"] = torch.tensor([0.0])
|
| 5934 |
|
| 5935 |
for i in range(16):
|
| 5936 |
+
# pass: keep current position (AND with NOT shift_bit)
|
| 5937 |
+
tensors[f"{prefix}.rshift_s{stage}_{i}.pass.weight"] = torch.tensor([1.0, 1.0])
|
| 5938 |
+
tensors[f"{prefix}.rshift_s{stage}_{i}.pass.bias"] = torch.tensor([-2.0])
|
| 5939 |
+
|
| 5940 |
+
# shift: take from higher position (AND with shift_bit)
|
| 5941 |
+
src_pos = i + shift_amt
|
| 5942 |
+
if src_pos < 16:
|
| 5943 |
+
tensors[f"{prefix}.rshift_s{stage}_{i}.shift.weight"] = torch.tensor([1.0, 1.0])
|
| 5944 |
+
tensors[f"{prefix}.rshift_s{stage}_{i}.shift.bias"] = torch.tensor([-2.0])
|
| 5945 |
+
tensors[f"{prefix}.rshift_s{stage}_{i}.weight"] = torch.tensor([1.0, 1.0])
|
| 5946 |
else:
|
| 5947 |
+
# Shift in 0 from above
|
| 5948 |
+
tensors[f"{prefix}.rshift_s{stage}_{i}.weight"] = torch.tensor([1.0])
|
| 5949 |
+
tensors[f"{prefix}.rshift_s{stage}_{i}.bias"] = torch.tensor([-1.0])
|
| 5950 |
+
|
| 5951 |
+
# === TWO'S COMPLEMENT NEGATION FOR NEGATIVE FLOATS ===
|
| 5952 |
+
# If sign bit is 1, negate the result
|
| 5953 |
|
|
|
|
| 5954 |
for i in range(16):
|
| 5955 |
tensors[f"{prefix}.not_mag{i}.weight"] = torch.tensor([-1.0])
|
| 5956 |
tensors[f"{prefix}.not_mag{i}.bias"] = torch.tensor([0.0])
|
| 5957 |
|
|
|
|
| 5958 |
for i in range(16):
|
| 5959 |
p = f"{prefix}.neg.fa{i}"
|
| 5960 |
tensors[f"{p}.xor.layer1.or.weight"] = torch.tensor([1.0, 1.0])
|
|
|
|
| 5969 |
tensors[f"{p}.cout.weight"] = torch.tensor([1.0, 1.0])
|
| 5970 |
tensors[f"{p}.cout.bias"] = torch.tensor([-1.0])
|
| 5971 |
|
| 5972 |
+
# === OUTPUT SELECTION ===
|
| 5973 |
+
# Select between positive path, negative path, and zero
|
| 5974 |
+
|
| 5975 |
+
# NOT of sign bit for muxing positive path
|
| 5976 |
+
tensors[f"{prefix}.not_sign.weight"] = torch.tensor([-1.0])
|
| 5977 |
+
tensors[f"{prefix}.not_sign.bias"] = torch.tensor([0.0])
|
| 5978 |
+
|
| 5979 |
for i in range(16):
|
| 5980 |
tensors[f"{prefix}.out{i}.pos_path.weight"] = torch.tensor([1.0, 1.0])
|
| 5981 |
tensors[f"{prefix}.out{i}.pos_path.bias"] = torch.tensor([-2.0])
|
|
|
|
| 6219 |
|
| 6220 |
print(f"Loaded {len(tensors)} tensors")
|
| 6221 |
|
| 6222 |
+
# Remove old tensors for circuits we're rebuilding
|
| 6223 |
old_float16_add = [k for k in tensors.keys() if k.startswith('float16.add')]
|
| 6224 |
for k in old_float16_add:
|
| 6225 |
del tensors[k]
|
| 6226 |
print(f"Removed {len(old_float16_add)} old float16.add tensors")
|
| 6227 |
|
| 6228 |
+
old_float16_toint = [k for k in tensors.keys() if k.startswith('float16.toint')]
|
| 6229 |
+
for k in old_float16_toint:
|
| 6230 |
+
del tensors[k]
|
| 6231 |
+
print(f"Removed {len(old_float16_toint)} old float16.toint tensors")
|
| 6232 |
+
|
| 6233 |
# Build new circuits
|
| 6234 |
print("Building new circuits...")
|
| 6235 |
clz_tensors = build_clz8bit_tensors()
|