CharlesCNorton committed on
Commit ·
3c7f544
1
Parent(s): b8c48ca
Add degree-mode trig LUTs
Browse files
- arithmetic.safetensors +2 -2
- build.py +35 -0
- calculator.py +39 -7
- eval.py +18 -0
arithmetic.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6bc68cbb34ffc8d5a3dcdda3c63a2a806175e6e90ee83c7fe7ca37ed287e087d
|
| 3 |
+
size 436688020
|
build.py
CHANGED
|
@@ -89,6 +89,20 @@ def unary_float32(op_fn: Callable[[torch.Tensor], torch.Tensor]) -> Callable[[to
|
|
| 89 |
return _fn
|
| 90 |
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
def build_float16_lut_match_tensors(prefix: str) -> Dict[str, torch.Tensor]:
|
| 93 |
"""Build exact-match gates for all 16-bit patterns under prefix.matchXXXX."""
|
| 94 |
tensors: Dict[str, torch.Tensor] = {}
|
|
@@ -1302,6 +1316,9 @@ def infer_inputs_for_gate(gate: str, registry: SignalRegistry, routing: dict) ->
|
|
| 1302 |
gate.startswith('float16.log2') or gate.startswith('float16.log10') or \
|
| 1303 |
gate.startswith('float16.sin') or gate.startswith('float16.cos') or \
|
| 1304 |
gate.startswith('float16.tan') or gate.startswith('float16.tanh') or \
|
|
|
|
|
|
|
|
|
|
| 1305 |
gate.startswith('float16.asin') or gate.startswith('float16.acos') or \
|
| 1306 |
gate.startswith('float16.atan') or gate.startswith('float16.sinh') or \
|
| 1307 |
gate.startswith('float16.cosh') or gate.startswith('float16.floor') or \
|
|
@@ -11081,6 +11098,9 @@ def main():
|
|
| 11081 |
k.startswith('float16.log2') or k.startswith('float16.log10') or
|
| 11082 |
k.startswith('float16.sin') or k.startswith('float16.cos') or
|
| 11083 |
k.startswith('float16.tan') or k.startswith('float16.tanh') or
|
|
|
|
|
|
|
|
|
|
| 11084 |
k.startswith('float16.asin') or k.startswith('float16.acos') or
|
| 11085 |
k.startswith('float16.atan') or k.startswith('float16.sinh') or
|
| 11086 |
k.startswith('float16.cosh') or k.startswith('float16.floor') or
|
|
@@ -11205,6 +11225,14 @@ def main():
|
|
| 11205 |
"ceil": unary_float32(torch.ceil),
|
| 11206 |
"round": unary_float32(torch.round),
|
| 11207 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11208 |
lut_outputs: Dict[str, List[int]] = {}
|
| 11209 |
for name, fn in unary_ops.items():
|
| 11210 |
print(f" computing float16.{name} LUT...")
|
|
@@ -11213,6 +11241,13 @@ def main():
|
|
| 11213 |
op_tensors = build_float16_lut_output_tensors(f"float16.{name}", outputs)
|
| 11214 |
tensors.update(op_tensors)
|
| 11215 |
print(f" float16.{name}: {len(op_tensors)} tensors")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11216 |
|
| 11217 |
# float16.pow (ln -> mul -> exp)
|
| 11218 |
pow_tensors = build_float16_pow_tensors(mul_tensors,
|
|
|
|
| 89 |
return _fn
|
| 90 |
|
| 91 |
|
| 92 |
+
def wrap_deg_trig(op_fn: Callable[[torch.Tensor], torch.Tensor]) -> Callable[[torch.Tensor], torch.Tensor]:
|
| 93 |
+
"""Wrap a trig op so its input is read as degrees: x is cast with .float() and scaled by pi/180 before op_fn is called."""
|
| 94 |
+
def _fn(x: torch.Tensor) -> torch.Tensor:
|
| 95 |
+
return op_fn(x.float() * (math.pi / 180.0))
|
| 96 |
+
return _fn
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def wrap_inv_trig_deg(op_fn: Callable[[torch.Tensor], torch.Tensor]) -> Callable[[torch.Tensor], torch.Tensor]:
|
| 100 |
+
"""Wrap an inverse trig op so its result is returned in degrees: op_fn is applied to x.float() and the output is scaled by 180/pi."""
|
| 101 |
+
def _fn(x: torch.Tensor) -> torch.Tensor:
|
| 102 |
+
return op_fn(x.float()) * (180.0 / math.pi)
|
| 103 |
+
return _fn
|
| 104 |
+
|
| 105 |
+
|
| 106 |
def build_float16_lut_match_tensors(prefix: str) -> Dict[str, torch.Tensor]:
|
| 107 |
"""Build exact-match gates for all 16-bit patterns under prefix.matchXXXX."""
|
| 108 |
tensors: Dict[str, torch.Tensor] = {}
|
|
|
|
| 1316 |
gate.startswith('float16.log2') or gate.startswith('float16.log10') or \
|
| 1317 |
gate.startswith('float16.sin') or gate.startswith('float16.cos') or \
|
| 1318 |
gate.startswith('float16.tan') or gate.startswith('float16.tanh') or \
|
| 1319 |
+
gate.startswith('float16.sin_deg') or gate.startswith('float16.cos_deg') or \
|
| 1320 |
+
gate.startswith('float16.tan_deg') or gate.startswith('float16.asin_deg') or \
|
| 1321 |
+
gate.startswith('float16.acos_deg') or gate.startswith('float16.atan_deg') or \
|
| 1322 |
gate.startswith('float16.asin') or gate.startswith('float16.acos') or \
|
| 1323 |
gate.startswith('float16.atan') or gate.startswith('float16.sinh') or \
|
| 1324 |
gate.startswith('float16.cosh') or gate.startswith('float16.floor') or \
|
|
|
|
| 11098 |
k.startswith('float16.log2') or k.startswith('float16.log10') or
|
| 11099 |
k.startswith('float16.sin') or k.startswith('float16.cos') or
|
| 11100 |
k.startswith('float16.tan') or k.startswith('float16.tanh') or
|
| 11101 |
+
k.startswith('float16.sin_deg') or k.startswith('float16.cos_deg') or
|
| 11102 |
+
k.startswith('float16.tan_deg') or k.startswith('float16.asin_deg') or
|
| 11103 |
+
k.startswith('float16.acos_deg') or k.startswith('float16.atan_deg') or
|
| 11104 |
k.startswith('float16.asin') or k.startswith('float16.acos') or
|
| 11105 |
k.startswith('float16.atan') or k.startswith('float16.sinh') or
|
| 11106 |
k.startswith('float16.cosh') or k.startswith('float16.floor') or
|
|
|
|
| 11225 |
"ceil": unary_float32(torch.ceil),
|
| 11226 |
"round": unary_float32(torch.round),
|
| 11227 |
}
|
| 11228 |
+
deg_ops = {
|
| 11229 |
+
"sin_deg": wrap_deg_trig(torch.sin),
|
| 11230 |
+
"cos_deg": wrap_deg_trig(torch.cos),
|
| 11231 |
+
"tan_deg": wrap_deg_trig(torch.tan),
|
| 11232 |
+
"asin_deg": wrap_inv_trig_deg(torch.asin),
|
| 11233 |
+
"acos_deg": wrap_inv_trig_deg(torch.acos),
|
| 11234 |
+
"atan_deg": wrap_inv_trig_deg(torch.atan),
|
| 11235 |
+
}
|
| 11236 |
lut_outputs: Dict[str, List[int]] = {}
|
| 11237 |
for name, fn in unary_ops.items():
|
| 11238 |
print(f" computing float16.{name} LUT...")
|
|
|
|
| 11241 |
op_tensors = build_float16_lut_output_tensors(f"float16.{name}", outputs)
|
| 11242 |
tensors.update(op_tensors)
|
| 11243 |
print(f" float16.{name}: {len(op_tensors)} tensors")
|
| 11244 |
+
for name, fn in deg_ops.items():
|
| 11245 |
+
print(f" computing float16.{name} LUT...")
|
| 11246 |
+
outputs = compute_float16_unary_lut_outputs(fn)
|
| 11247 |
+
lut_outputs[name] = outputs
|
| 11248 |
+
op_tensors = build_float16_lut_output_tensors(f"float16.{name}", outputs)
|
| 11249 |
+
tensors.update(op_tensors)
|
| 11250 |
+
print(f" float16.{name}: {len(op_tensors)} tensors")
|
| 11251 |
|
| 11252 |
# float16.pow (ln -> mul -> exp)
|
| 11253 |
pow_tensors = build_float16_pow_tensors(mul_tensors,
|
calculator.py
CHANGED
|
@@ -493,11 +493,18 @@ class ThresholdCalculator:
|
|
| 493 |
out_int = bits_to_int(result.bits)
|
| 494 |
return float16_bits_to_float(out_int), result
|
| 495 |
|
| 496 |
-
def evaluate_rpn(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 497 |
"""Evaluate an expression from RPN tokens using float16 circuits."""
|
| 498 |
total_elapsed = 0.0
|
| 499 |
total_gates = 0
|
| 500 |
non_gate_events: List[str] = []
|
|
|
|
|
|
|
| 501 |
|
| 502 |
def run_prefix(prefix: str, inputs: Dict[str, object]) -> int:
|
| 503 |
nonlocal total_elapsed, total_gates
|
|
@@ -550,7 +557,12 @@ class ThresholdCalculator:
|
|
| 550 |
if not stack:
|
| 551 |
raise RuntimeError("stack underflow")
|
| 552 |
x = stack.pop()
|
| 553 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 554 |
stack.append(out)
|
| 555 |
continue
|
| 556 |
if tok in {"+", "-", "*", "/", "^"}:
|
|
@@ -590,9 +602,16 @@ class ThresholdCalculator:
|
|
| 590 |
non_gate_events=non_gate_events,
|
| 591 |
)
|
| 592 |
|
| 593 |
-
def evaluate_expr(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 594 |
"""Evaluate a calculator expression using float16 circuits."""
|
| 595 |
expr = normalize_expr(expr)
|
|
|
|
|
|
|
| 596 |
tree = ast.parse(expr, mode="eval")
|
| 597 |
|
| 598 |
total_elapsed = 0.0
|
|
@@ -666,13 +685,25 @@ class ThresholdCalculator:
|
|
| 666 |
if fname == "log10":
|
| 667 |
return run_prefix("float16.log10", {"x": x})
|
| 668 |
if fname == "sin":
|
| 669 |
-
|
|
|
|
| 670 |
if fname == "cos":
|
| 671 |
-
|
|
|
|
| 672 |
if fname == "tan":
|
| 673 |
-
|
|
|
|
| 674 |
if fname == "tanh":
|
| 675 |
return run_prefix("float16.tanh", {"x": x})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 676 |
if fname == "asin":
|
| 677 |
return run_prefix("float16.asin", {"x": x})
|
| 678 |
if fname == "acos":
|
|
@@ -720,6 +751,7 @@ def main() -> int:
|
|
| 720 |
parser.add_argument("--inputs", nargs="*", help="Explicit inputs as name=value (e.g., a=0x3c00)")
|
| 721 |
parser.add_argument("--hex", action="store_true", help="Parse numeric inputs as hex")
|
| 722 |
parser.add_argument("--expr", help="Evaluate expression using float16 circuits")
|
|
|
|
| 723 |
parser.add_argument("--json", action="store_true", help="Output JSON result")
|
| 724 |
parser.add_argument("--strict", action="store_true", help="Warn if any non-gate path is used")
|
| 725 |
args = parser.parse_args()
|
|
@@ -750,7 +782,7 @@ def main() -> int:
|
|
| 750 |
|
| 751 |
if args.expr or (args.prefix and not args.values and not args.inputs and looks_like_expression(args.prefix)):
|
| 752 |
expr = args.expr if args.expr else args.prefix
|
| 753 |
-
result = calc.evaluate_expr(expr)
|
| 754 |
out_int = bits_to_int(result.bits)
|
| 755 |
return emit_result("expr", out_int, result, expr=expr)
|
| 756 |
|
|
|
|
| 493 |
out_int = bits_to_int(result.bits)
|
| 494 |
return float16_bits_to_float(out_int), result
|
| 495 |
|
| 496 |
+
def evaluate_rpn(
|
| 497 |
+
self,
|
| 498 |
+
tokens: Sequence[str],
|
| 499 |
+
force_gate_eval: bool = True,
|
| 500 |
+
angle_mode: str = "rad",
|
| 501 |
+
) -> EvalResult:
|
| 502 |
"""Evaluate an expression from RPN tokens using float16 circuits."""
|
| 503 |
total_elapsed = 0.0
|
| 504 |
total_gates = 0
|
| 505 |
non_gate_events: List[str] = []
|
| 506 |
+
angle_mode = (angle_mode or "rad").lower()
|
| 507 |
+
use_degrees = angle_mode.startswith("deg")
|
| 508 |
|
| 509 |
def run_prefix(prefix: str, inputs: Dict[str, object]) -> int:
|
| 510 |
nonlocal total_elapsed, total_gates
|
|
|
|
| 557 |
if not stack:
|
| 558 |
raise RuntimeError("stack underflow")
|
| 559 |
x = stack.pop()
|
| 560 |
+
prefix = unary_ops[tok]
|
| 561 |
+
if use_degrees and tok in ("sin", "cos", "tan"):
|
| 562 |
+
prefix = f"float16.{tok}_deg"
|
| 563 |
+
elif use_degrees and tok in ("asin", "acos", "atan"):
|
| 564 |
+
prefix = f"float16.{tok}_deg"
|
| 565 |
+
out = run_prefix(prefix, {"x": x})
|
| 566 |
stack.append(out)
|
| 567 |
continue
|
| 568 |
if tok in {"+", "-", "*", "/", "^"}:
|
|
|
|
| 602 |
non_gate_events=non_gate_events,
|
| 603 |
)
|
| 604 |
|
| 605 |
+
def evaluate_expr(
|
| 606 |
+
self,
|
| 607 |
+
expr: str,
|
| 608 |
+
force_gate_eval: bool = True,
|
| 609 |
+
angle_mode: str = "rad",
|
| 610 |
+
) -> EvalResult:
|
| 611 |
"""Evaluate a calculator expression using float16 circuits."""
|
| 612 |
expr = normalize_expr(expr)
|
| 613 |
+
angle_mode = (angle_mode or "rad").lower()
|
| 614 |
+
use_degrees = angle_mode.startswith("deg")
|
| 615 |
tree = ast.parse(expr, mode="eval")
|
| 616 |
|
| 617 |
total_elapsed = 0.0
|
|
|
|
| 685 |
if fname == "log10":
|
| 686 |
return run_prefix("float16.log10", {"x": x})
|
| 687 |
if fname == "sin":
|
| 688 |
+
prefix = "float16.sin_deg" if use_degrees else "float16.sin"
|
| 689 |
+
return run_prefix(prefix, {"x": x})
|
| 690 |
if fname == "cos":
|
| 691 |
+
prefix = "float16.cos_deg" if use_degrees else "float16.cos"
|
| 692 |
+
return run_prefix(prefix, {"x": x})
|
| 693 |
if fname == "tan":
|
| 694 |
+
prefix = "float16.tan_deg" if use_degrees else "float16.tan"
|
| 695 |
+
return run_prefix(prefix, {"x": x})
|
| 696 |
if fname == "tanh":
|
| 697 |
return run_prefix("float16.tanh", {"x": x})
|
| 698 |
+
if fname == "asin":
|
| 699 |
+
prefix = "float16.asin_deg" if use_degrees else "float16.asin"
|
| 700 |
+
return run_prefix(prefix, {"x": x})
|
| 701 |
+
if fname == "acos":
|
| 702 |
+
prefix = "float16.acos_deg" if use_degrees else "float16.acos"
|
| 703 |
+
return run_prefix(prefix, {"x": x})
|
| 704 |
+
if fname == "atan":
|
| 705 |
+
prefix = "float16.atan_deg" if use_degrees else "float16.atan"
|
| 706 |
+
return run_prefix(prefix, {"x": x})
|
| 707 |
if fname == "asin":
|
| 708 |
return run_prefix("float16.asin", {"x": x})
|
| 709 |
if fname == "acos":
|
|
|
|
| 751 |
parser.add_argument("--inputs", nargs="*", help="Explicit inputs as name=value (e.g., a=0x3c00)")
|
| 752 |
parser.add_argument("--hex", action="store_true", help="Parse numeric inputs as hex")
|
| 753 |
parser.add_argument("--expr", help="Evaluate expression using float16 circuits")
|
| 754 |
+
parser.add_argument("--angle", default="rad", choices=["rad", "deg"], help="Angle mode for trig functions")
|
| 755 |
parser.add_argument("--json", action="store_true", help="Output JSON result")
|
| 756 |
parser.add_argument("--strict", action="store_true", help="Warn if any non-gate path is used")
|
| 757 |
args = parser.parse_args()
|
|
|
|
| 782 |
|
| 783 |
if args.expr or (args.prefix and not args.values and not args.inputs and looks_like_expression(args.prefix)):
|
| 784 |
expr = args.expr if args.expr else args.prefix
|
| 785 |
+
result = calc.evaluate_expr(expr, angle_mode=args.angle)
|
| 786 |
out_int = bits_to_int(result.bits)
|
| 787 |
return emit_result("expr", out_int, result, expr=expr)
|
| 788 |
|
eval.py
CHANGED
|
@@ -859,6 +859,18 @@ def float16_expected_bits_unary(op: str, a_bits: int) -> Tuple[int, bool]:
|
|
| 859 |
out = torch.ceil(a32).item()
|
| 860 |
elif op == "round":
|
| 861 |
out = torch.round(a32).item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 862 |
else:
|
| 863 |
raise ValueError(f"unknown op: {op}")
|
| 864 |
if out != out:
|
|
@@ -2707,6 +2719,12 @@ def test_float16_unary(ctx: EvalContext) -> List[TestResult]:
|
|
| 2707 |
("float16.cos", "cos"),
|
| 2708 |
("float16.tan", "tan"),
|
| 2709 |
("float16.tanh", "tanh"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2710 |
("float16.asin", "asin"),
|
| 2711 |
("float16.acos", "acos"),
|
| 2712 |
("float16.atan", "atan"),
|
|
|
|
| 859 |
out = torch.ceil(a32).item()
|
| 860 |
elif op == "round":
|
| 861 |
out = torch.round(a32).item()
|
| 862 |
+
elif op == "sin_deg":
|
| 863 |
+
out = torch.sin(a32 * (math.pi / 180.0)).item()
|
| 864 |
+
elif op == "cos_deg":
|
| 865 |
+
out = torch.cos(a32 * (math.pi / 180.0)).item()
|
| 866 |
+
elif op == "tan_deg":
|
| 867 |
+
out = torch.tan(a32 * (math.pi / 180.0)).item()
|
| 868 |
+
elif op == "asin_deg":
|
| 869 |
+
out = (torch.asin(a32) * (180.0 / math.pi)).item()
|
| 870 |
+
elif op == "acos_deg":
|
| 871 |
+
out = (torch.acos(a32) * (180.0 / math.pi)).item()
|
| 872 |
+
elif op == "atan_deg":
|
| 873 |
+
out = (torch.atan(a32) * (180.0 / math.pi)).item()
|
| 874 |
else:
|
| 875 |
raise ValueError(f"unknown op: {op}")
|
| 876 |
if out != out:
|
|
|
|
| 2719 |
("float16.cos", "cos"),
|
| 2720 |
("float16.tan", "tan"),
|
| 2721 |
("float16.tanh", "tanh"),
|
| 2722 |
+
("float16.sin_deg", "sin_deg"),
|
| 2723 |
+
("float16.cos_deg", "cos_deg"),
|
| 2724 |
+
("float16.tan_deg", "tan_deg"),
|
| 2725 |
+
("float16.asin_deg", "asin_deg"),
|
| 2726 |
+
("float16.acos_deg", "acos_deg"),
|
| 2727 |
+
("float16.atan_deg", "atan_deg"),
|
| 2728 |
("float16.asin", "asin"),
|
| 2729 |
("float16.acos", "acos"),
|
| 2730 |
("float16.atan", "atan"),
|