CharlesCNorton commited on
Commit ·
7fc245f
1
Parent(s): 313da7e
Add float16 deg2rad/rad2deg circuits
Browse files- Add deg2rad/rad2deg to unary op LUT build and gate input inference\n- Extend eval expected-value logic and float16 unary tests for deg2rad/rad2deg\n- Wire calculator function dispatch for deg2rad/rad2deg\n- Rebuild arithmetic.safetensors with the new LUT tensors
- arithmetic.safetensors +2 -2
- build.py +4 -0
- calculator.py +6 -0
- eval.py +6 -0
arithmetic.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a64904f581a9aaad1a8908933bd2338a3b73737d05a296986d87f6ccf73b4bf8
|
| 3 |
+
size 469024560
|
build.py
CHANGED
|
@@ -1392,6 +1392,7 @@ def infer_inputs_for_gate(gate: str, registry: SignalRegistry, routing: dict) ->
|
|
| 1392 |
if gate.startswith('float16.sqrt') or gate.startswith('float16.rsqrt') or \
|
| 1393 |
gate.startswith('float16.exp') or gate.startswith('float16.ln') or \
|
| 1394 |
gate.startswith('float16.log2') or gate.startswith('float16.log10') or \
|
|
|
|
| 1395 |
gate.startswith('float16.sin') or gate.startswith('float16.cos') or \
|
| 1396 |
gate.startswith('float16.tan') or gate.startswith('float16.tanh') or \
|
| 1397 |
gate.startswith('float16.sin_deg') or gate.startswith('float16.cos_deg') or \
|
|
@@ -11174,6 +11175,7 @@ def main():
|
|
| 11174 |
k.startswith('float16.sqrt') or k.startswith('float16.rsqrt') or
|
| 11175 |
k.startswith('float16.exp') or k.startswith('float16.ln') or
|
| 11176 |
k.startswith('float16.log2') or k.startswith('float16.log10') or
|
|
|
|
| 11177 |
k.startswith('float16.sin') or k.startswith('float16.cos') or
|
| 11178 |
k.startswith('float16.tan') or k.startswith('float16.tanh') or
|
| 11179 |
k.startswith('float16.sin_deg') or k.startswith('float16.cos_deg') or
|
|
@@ -11301,6 +11303,8 @@ def main():
|
|
| 11301 |
"ln": torch.log,
|
| 11302 |
"log2": torch.log2,
|
| 11303 |
"log10": unary_float32(torch.log10),
|
|
|
|
|
|
|
| 11304 |
"sin": torch.sin,
|
| 11305 |
"cos": torch.cos,
|
| 11306 |
"tan": torch.tan,
|
|
|
|
| 1392 |
if gate.startswith('float16.sqrt') or gate.startswith('float16.rsqrt') or \
|
| 1393 |
gate.startswith('float16.exp') or gate.startswith('float16.ln') or \
|
| 1394 |
gate.startswith('float16.log2') or gate.startswith('float16.log10') or \
|
| 1395 |
+
gate.startswith('float16.deg2rad') or gate.startswith('float16.rad2deg') or \
|
| 1396 |
gate.startswith('float16.sin') or gate.startswith('float16.cos') or \
|
| 1397 |
gate.startswith('float16.tan') or gate.startswith('float16.tanh') or \
|
| 1398 |
gate.startswith('float16.sin_deg') or gate.startswith('float16.cos_deg') or \
|
|
|
|
| 11175 |
k.startswith('float16.sqrt') or k.startswith('float16.rsqrt') or
|
| 11176 |
k.startswith('float16.exp') or k.startswith('float16.ln') or
|
| 11177 |
k.startswith('float16.log2') or k.startswith('float16.log10') or
|
| 11178 |
+
k.startswith('float16.deg2rad') or k.startswith('float16.rad2deg') or
|
| 11179 |
k.startswith('float16.sin') or k.startswith('float16.cos') or
|
| 11180 |
k.startswith('float16.tan') or k.startswith('float16.tanh') or
|
| 11181 |
k.startswith('float16.sin_deg') or k.startswith('float16.cos_deg') or
|
|
|
|
| 11303 |
"ln": torch.log,
|
| 11304 |
"log2": torch.log2,
|
| 11305 |
"log10": unary_float32(torch.log10),
|
| 11306 |
+
"deg2rad": unary_float32(lambda x: x * (math.pi / 180.0)),
|
| 11307 |
+
"rad2deg": unary_float32(lambda x: x * (180.0 / math.pi)),
|
| 11308 |
"sin": torch.sin,
|
| 11309 |
"cos": torch.cos,
|
| 11310 |
"tan": torch.tan,
|
calculator.py
CHANGED
|
@@ -562,6 +562,8 @@ class ThresholdCalculator:
|
|
| 562 |
"log": "float16.ln",
|
| 563 |
"log2": "float16.log2",
|
| 564 |
"log10": "float16.log10",
|
|
|
|
|
|
|
| 565 |
"sin": "float16.sin",
|
| 566 |
"cos": "float16.cos",
|
| 567 |
"tan": "float16.tan",
|
|
@@ -740,6 +742,10 @@ class ThresholdCalculator:
|
|
| 740 |
return run_unary("float16.log2", x, fname)
|
| 741 |
if fname == "log10":
|
| 742 |
return run_unary("float16.log10", x, fname)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 743 |
if fname == "sin":
|
| 744 |
prefix = "float16.sin_deg" if use_degrees else "float16.sin"
|
| 745 |
return run_unary(prefix, x, fname)
|
|
|
|
| 562 |
"log": "float16.ln",
|
| 563 |
"log2": "float16.log2",
|
| 564 |
"log10": "float16.log10",
|
| 565 |
+
"deg2rad": "float16.deg2rad",
|
| 566 |
+
"rad2deg": "float16.rad2deg",
|
| 567 |
"sin": "float16.sin",
|
| 568 |
"cos": "float16.cos",
|
| 569 |
"tan": "float16.tan",
|
|
|
|
| 742 |
return run_unary("float16.log2", x, fname)
|
| 743 |
if fname == "log10":
|
| 744 |
return run_unary("float16.log10", x, fname)
|
| 745 |
+
if fname == "deg2rad":
|
| 746 |
+
return run_unary("float16.deg2rad", x, fname)
|
| 747 |
+
if fname == "rad2deg":
|
| 748 |
+
return run_unary("float16.rad2deg", x, fname)
|
| 749 |
if fname == "sin":
|
| 750 |
prefix = "float16.sin_deg" if use_degrees else "float16.sin"
|
| 751 |
return run_unary(prefix, x, fname)
|
eval.py
CHANGED
|
@@ -865,6 +865,10 @@ def float16_expected_bits_unary(op: str, a_bits: int) -> Tuple[int, bool]:
|
|
| 865 |
out = torch.log2(a16).item()
|
| 866 |
elif op == "log10":
|
| 867 |
out = torch.log10(a32).item()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 868 |
elif op == "sin":
|
| 869 |
out = torch.sin(a16).item()
|
| 870 |
elif op == "cos":
|
|
@@ -2785,6 +2789,8 @@ def test_float16_unary(ctx: EvalContext) -> List[TestResult]:
|
|
| 2785 |
("float16.ln", "ln"),
|
| 2786 |
("float16.log2", "log2"),
|
| 2787 |
("float16.log10", "log10"),
|
|
|
|
|
|
|
| 2788 |
("float16.sin", "sin"),
|
| 2789 |
("float16.cos", "cos"),
|
| 2790 |
("float16.tan", "tan"),
|
|
|
|
| 865 |
out = torch.log2(a16).item()
|
| 866 |
elif op == "log10":
|
| 867 |
out = torch.log10(a32).item()
|
| 868 |
+
elif op == "deg2rad":
|
| 869 |
+
out = (a32 * (math.pi / 180.0)).item()
|
| 870 |
+
elif op == "rad2deg":
|
| 871 |
+
out = (a32 * (180.0 / math.pi)).item()
|
| 872 |
elif op == "sin":
|
| 873 |
out = torch.sin(a16).item()
|
| 874 |
elif op == "cos":
|
|
|
|
| 2789 |
("float16.ln", "ln"),
|
| 2790 |
("float16.log2", "log2"),
|
| 2791 |
("float16.log10", "log10"),
|
| 2792 |
+
("float16.deg2rad", "deg2rad"),
|
| 2793 |
+
("float16.rad2deg", "rad2deg"),
|
| 2794 |
("float16.sin", "sin"),
|
| 2795 |
("float16.cos", "cos"),
|
| 2796 |
("float16.tan", "tan"),
|