CharlesCNorton commited on
Commit
3c7f544
·
1 Parent(s): b8c48ca

Add degree-mode trig LUTs

Browse files
Files changed (4) hide show
  1. arithmetic.safetensors +2 -2
  2. build.py +35 -0
  3. calculator.py +39 -7
  4. eval.py +18 -0
arithmetic.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4a24be4920c915a295d057e727c11e68459e5271dcfa6a40831c047f00da7061
3
- size 361034908
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6bc68cbb34ffc8d5a3dcdda3c63a2a806175e6e90ee83c7fe7ca37ed287e087d
3
+ size 436688020
build.py CHANGED
@@ -89,6 +89,20 @@ def unary_float32(op_fn: Callable[[torch.Tensor], torch.Tensor]) -> Callable[[to
89
  return _fn
90
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def build_float16_lut_match_tensors(prefix: str) -> Dict[str, torch.Tensor]:
93
  """Build exact-match gates for all 16-bit patterns under prefix.matchXXXX."""
94
  tensors: Dict[str, torch.Tensor] = {}
@@ -1302,6 +1316,9 @@ def infer_inputs_for_gate(gate: str, registry: SignalRegistry, routing: dict) ->
1302
  gate.startswith('float16.log2') or gate.startswith('float16.log10') or \
1303
  gate.startswith('float16.sin') or gate.startswith('float16.cos') or \
1304
  gate.startswith('float16.tan') or gate.startswith('float16.tanh') or \
 
 
 
1305
  gate.startswith('float16.asin') or gate.startswith('float16.acos') or \
1306
  gate.startswith('float16.atan') or gate.startswith('float16.sinh') or \
1307
  gate.startswith('float16.cosh') or gate.startswith('float16.floor') or \
@@ -11081,6 +11098,9 @@ def main():
11081
  k.startswith('float16.log2') or k.startswith('float16.log10') or
11082
  k.startswith('float16.sin') or k.startswith('float16.cos') or
11083
  k.startswith('float16.tan') or k.startswith('float16.tanh') or
 
 
 
11084
  k.startswith('float16.asin') or k.startswith('float16.acos') or
11085
  k.startswith('float16.atan') or k.startswith('float16.sinh') or
11086
  k.startswith('float16.cosh') or k.startswith('float16.floor') or
@@ -11205,6 +11225,14 @@ def main():
11205
  "ceil": unary_float32(torch.ceil),
11206
  "round": unary_float32(torch.round),
11207
  }
 
 
 
 
 
 
 
 
11208
  lut_outputs: Dict[str, List[int]] = {}
11209
  for name, fn in unary_ops.items():
11210
  print(f" computing float16.{name} LUT...")
@@ -11213,6 +11241,13 @@ def main():
11213
  op_tensors = build_float16_lut_output_tensors(f"float16.{name}", outputs)
11214
  tensors.update(op_tensors)
11215
  print(f" float16.{name}: {len(op_tensors)} tensors")
 
 
 
 
 
 
 
11216
 
11217
  # float16.pow (ln -> mul -> exp)
11218
  pow_tensors = build_float16_pow_tensors(mul_tensors,
 
89
  return _fn
90
 
91
 
92
+ def wrap_deg_trig(op_fn: Callable[[torch.Tensor], torch.Tensor]) -> Callable[[torch.Tensor], torch.Tensor]:
93
+ """Wrap trig op to interpret input as degrees."""
94
+ def _fn(x: torch.Tensor) -> torch.Tensor:
95
+ return op_fn(x.float() * (math.pi / 180.0))
96
+ return _fn
97
+
98
+
99
+ def wrap_inv_trig_deg(op_fn: Callable[[torch.Tensor], torch.Tensor]) -> Callable[[torch.Tensor], torch.Tensor]:
100
+ """Wrap inverse trig op to return degrees."""
101
+ def _fn(x: torch.Tensor) -> torch.Tensor:
102
+ return op_fn(x.float()) * (180.0 / math.pi)
103
+ return _fn
104
+
105
+
106
  def build_float16_lut_match_tensors(prefix: str) -> Dict[str, torch.Tensor]:
107
  """Build exact-match gates for all 16-bit patterns under prefix.matchXXXX."""
108
  tensors: Dict[str, torch.Tensor] = {}
 
1316
  gate.startswith('float16.log2') or gate.startswith('float16.log10') or \
1317
  gate.startswith('float16.sin') or gate.startswith('float16.cos') or \
1318
  gate.startswith('float16.tan') or gate.startswith('float16.tanh') or \
1319
+ gate.startswith('float16.sin_deg') or gate.startswith('float16.cos_deg') or \
1320
+ gate.startswith('float16.tan_deg') or gate.startswith('float16.asin_deg') or \
1321
+ gate.startswith('float16.acos_deg') or gate.startswith('float16.atan_deg') or \
1322
  gate.startswith('float16.asin') or gate.startswith('float16.acos') or \
1323
  gate.startswith('float16.atan') or gate.startswith('float16.sinh') or \
1324
  gate.startswith('float16.cosh') or gate.startswith('float16.floor') or \
 
11098
  k.startswith('float16.log2') or k.startswith('float16.log10') or
11099
  k.startswith('float16.sin') or k.startswith('float16.cos') or
11100
  k.startswith('float16.tan') or k.startswith('float16.tanh') or
11101
+ k.startswith('float16.sin_deg') or k.startswith('float16.cos_deg') or
11102
+ k.startswith('float16.tan_deg') or k.startswith('float16.asin_deg') or
11103
+ k.startswith('float16.acos_deg') or k.startswith('float16.atan_deg') or
11104
  k.startswith('float16.asin') or k.startswith('float16.acos') or
11105
  k.startswith('float16.atan') or k.startswith('float16.sinh') or
11106
  k.startswith('float16.cosh') or k.startswith('float16.floor') or
 
11225
  "ceil": unary_float32(torch.ceil),
11226
  "round": unary_float32(torch.round),
11227
  }
11228
+ deg_ops = {
11229
+ "sin_deg": wrap_deg_trig(torch.sin),
11230
+ "cos_deg": wrap_deg_trig(torch.cos),
11231
+ "tan_deg": wrap_deg_trig(torch.tan),
11232
+ "asin_deg": wrap_inv_trig_deg(torch.asin),
11233
+ "acos_deg": wrap_inv_trig_deg(torch.acos),
11234
+ "atan_deg": wrap_inv_trig_deg(torch.atan),
11235
+ }
11236
  lut_outputs: Dict[str, List[int]] = {}
11237
  for name, fn in unary_ops.items():
11238
  print(f" computing float16.{name} LUT...")
 
11241
  op_tensors = build_float16_lut_output_tensors(f"float16.{name}", outputs)
11242
  tensors.update(op_tensors)
11243
  print(f" float16.{name}: {len(op_tensors)} tensors")
11244
+ for name, fn in deg_ops.items():
11245
+ print(f" computing float16.{name} LUT...")
11246
+ outputs = compute_float16_unary_lut_outputs(fn)
11247
+ lut_outputs[name] = outputs
11248
+ op_tensors = build_float16_lut_output_tensors(f"float16.{name}", outputs)
11249
+ tensors.update(op_tensors)
11250
+ print(f" float16.{name}: {len(op_tensors)} tensors")
11251
 
11252
  # float16.pow (ln -> mul -> exp)
11253
  pow_tensors = build_float16_pow_tensors(mul_tensors,
calculator.py CHANGED
@@ -493,11 +493,18 @@ class ThresholdCalculator:
493
  out_int = bits_to_int(result.bits)
494
  return float16_bits_to_float(out_int), result
495
 
496
- def evaluate_rpn(self, tokens: Sequence[str], force_gate_eval: bool = True) -> EvalResult:
 
 
 
 
 
497
  """Evaluate an expression from RPN tokens using float16 circuits."""
498
  total_elapsed = 0.0
499
  total_gates = 0
500
  non_gate_events: List[str] = []
 
 
501
 
502
  def run_prefix(prefix: str, inputs: Dict[str, object]) -> int:
503
  nonlocal total_elapsed, total_gates
@@ -550,7 +557,12 @@ class ThresholdCalculator:
550
  if not stack:
551
  raise RuntimeError("stack underflow")
552
  x = stack.pop()
553
- out = run_prefix(unary_ops[tok], {"x": x})
 
 
 
 
 
554
  stack.append(out)
555
  continue
556
  if tok in {"+", "-", "*", "/", "^"}:
@@ -590,9 +602,16 @@ class ThresholdCalculator:
590
  non_gate_events=non_gate_events,
591
  )
592
 
593
- def evaluate_expr(self, expr: str, force_gate_eval: bool = True) -> EvalResult:
 
 
 
 
 
594
  """Evaluate a calculator expression using float16 circuits."""
595
  expr = normalize_expr(expr)
 
 
596
  tree = ast.parse(expr, mode="eval")
597
 
598
  total_elapsed = 0.0
@@ -666,13 +685,25 @@ class ThresholdCalculator:
666
  if fname == "log10":
667
  return run_prefix("float16.log10", {"x": x})
668
  if fname == "sin":
669
- return run_prefix("float16.sin", {"x": x})
 
670
  if fname == "cos":
671
- return run_prefix("float16.cos", {"x": x})
 
672
  if fname == "tan":
673
- return run_prefix("float16.tan", {"x": x})
 
674
  if fname == "tanh":
675
  return run_prefix("float16.tanh", {"x": x})
 
 
 
 
 
 
 
 
 
676
  if fname == "asin":
677
  return run_prefix("float16.asin", {"x": x})
678
  if fname == "acos":
@@ -720,6 +751,7 @@ def main() -> int:
720
  parser.add_argument("--inputs", nargs="*", help="Explicit inputs as name=value (e.g., a=0x3c00)")
721
  parser.add_argument("--hex", action="store_true", help="Parse numeric inputs as hex")
722
  parser.add_argument("--expr", help="Evaluate expression using float16 circuits")
 
723
  parser.add_argument("--json", action="store_true", help="Output JSON result")
724
  parser.add_argument("--strict", action="store_true", help="Warn if any non-gate path is used")
725
  args = parser.parse_args()
@@ -750,7 +782,7 @@ def main() -> int:
750
 
751
  if args.expr or (args.prefix and not args.values and not args.inputs and looks_like_expression(args.prefix)):
752
  expr = args.expr if args.expr else args.prefix
753
- result = calc.evaluate_expr(expr)
754
  out_int = bits_to_int(result.bits)
755
  return emit_result("expr", out_int, result, expr=expr)
756
 
 
493
  out_int = bits_to_int(result.bits)
494
  return float16_bits_to_float(out_int), result
495
 
496
+ def evaluate_rpn(
497
+ self,
498
+ tokens: Sequence[str],
499
+ force_gate_eval: bool = True,
500
+ angle_mode: str = "rad",
501
+ ) -> EvalResult:
502
  """Evaluate an expression from RPN tokens using float16 circuits."""
503
  total_elapsed = 0.0
504
  total_gates = 0
505
  non_gate_events: List[str] = []
506
+ angle_mode = (angle_mode or "rad").lower()
507
+ use_degrees = angle_mode.startswith("deg")
508
 
509
  def run_prefix(prefix: str, inputs: Dict[str, object]) -> int:
510
  nonlocal total_elapsed, total_gates
 
557
  if not stack:
558
  raise RuntimeError("stack underflow")
559
  x = stack.pop()
560
+ prefix = unary_ops[tok]
561
+ if use_degrees and tok in ("sin", "cos", "tan"):
562
+ prefix = f"float16.{tok}_deg"
563
+ elif use_degrees and tok in ("asin", "acos", "atan"):
564
+ prefix = f"float16.{tok}_deg"
565
+ out = run_prefix(prefix, {"x": x})
566
  stack.append(out)
567
  continue
568
  if tok in {"+", "-", "*", "/", "^"}:
 
602
  non_gate_events=non_gate_events,
603
  )
604
 
605
+ def evaluate_expr(
606
+ self,
607
+ expr: str,
608
+ force_gate_eval: bool = True,
609
+ angle_mode: str = "rad",
610
+ ) -> EvalResult:
611
  """Evaluate a calculator expression using float16 circuits."""
612
  expr = normalize_expr(expr)
613
+ angle_mode = (angle_mode or "rad").lower()
614
+ use_degrees = angle_mode.startswith("deg")
615
  tree = ast.parse(expr, mode="eval")
616
 
617
  total_elapsed = 0.0
 
685
  if fname == "log10":
686
  return run_prefix("float16.log10", {"x": x})
687
  if fname == "sin":
688
+ prefix = "float16.sin_deg" if use_degrees else "float16.sin"
689
+ return run_prefix(prefix, {"x": x})
690
  if fname == "cos":
691
+ prefix = "float16.cos_deg" if use_degrees else "float16.cos"
692
+ return run_prefix(prefix, {"x": x})
693
  if fname == "tan":
694
+ prefix = "float16.tan_deg" if use_degrees else "float16.tan"
695
+ return run_prefix(prefix, {"x": x})
696
  if fname == "tanh":
697
  return run_prefix("float16.tanh", {"x": x})
698
+ if fname == "asin":
699
+ prefix = "float16.asin_deg" if use_degrees else "float16.asin"
700
+ return run_prefix(prefix, {"x": x})
701
+ if fname == "acos":
702
+ prefix = "float16.acos_deg" if use_degrees else "float16.acos"
703
+ return run_prefix(prefix, {"x": x})
704
+ if fname == "atan":
705
+ prefix = "float16.atan_deg" if use_degrees else "float16.atan"
706
+ return run_prefix(prefix, {"x": x})
707
  if fname == "asin":
708
  return run_prefix("float16.asin", {"x": x})
709
  if fname == "acos":
 
751
  parser.add_argument("--inputs", nargs="*", help="Explicit inputs as name=value (e.g., a=0x3c00)")
752
  parser.add_argument("--hex", action="store_true", help="Parse numeric inputs as hex")
753
  parser.add_argument("--expr", help="Evaluate expression using float16 circuits")
754
+ parser.add_argument("--angle", default="rad", choices=["rad", "deg"], help="Angle mode for trig functions")
755
  parser.add_argument("--json", action="store_true", help="Output JSON result")
756
  parser.add_argument("--strict", action="store_true", help="Warn if any non-gate path is used")
757
  args = parser.parse_args()
 
782
 
783
  if args.expr or (args.prefix and not args.values and not args.inputs and looks_like_expression(args.prefix)):
784
  expr = args.expr if args.expr else args.prefix
785
+ result = calc.evaluate_expr(expr, angle_mode=args.angle)
786
  out_int = bits_to_int(result.bits)
787
  return emit_result("expr", out_int, result, expr=expr)
788
 
eval.py CHANGED
@@ -859,6 +859,18 @@ def float16_expected_bits_unary(op: str, a_bits: int) -> Tuple[int, bool]:
859
  out = torch.ceil(a32).item()
860
  elif op == "round":
861
  out = torch.round(a32).item()
 
 
 
 
 
 
 
 
 
 
 
 
862
  else:
863
  raise ValueError(f"unknown op: {op}")
864
  if out != out:
@@ -2707,6 +2719,12 @@ def test_float16_unary(ctx: EvalContext) -> List[TestResult]:
2707
  ("float16.cos", "cos"),
2708
  ("float16.tan", "tan"),
2709
  ("float16.tanh", "tanh"),
 
 
 
 
 
 
2710
  ("float16.asin", "asin"),
2711
  ("float16.acos", "acos"),
2712
  ("float16.atan", "atan"),
 
859
  out = torch.ceil(a32).item()
860
  elif op == "round":
861
  out = torch.round(a32).item()
862
+ elif op == "sin_deg":
863
+ out = torch.sin(a32 * (math.pi / 180.0)).item()
864
+ elif op == "cos_deg":
865
+ out = torch.cos(a32 * (math.pi / 180.0)).item()
866
+ elif op == "tan_deg":
867
+ out = torch.tan(a32 * (math.pi / 180.0)).item()
868
+ elif op == "asin_deg":
869
+ out = (torch.asin(a32) * (180.0 / math.pi)).item()
870
+ elif op == "acos_deg":
871
+ out = (torch.acos(a32) * (180.0 / math.pi)).item()
872
+ elif op == "atan_deg":
873
+ out = (torch.atan(a32) * (180.0 / math.pi)).item()
874
  else:
875
  raise ValueError(f"unknown op: {op}")
876
  if out != out:
 
2719
  ("float16.cos", "cos"),
2720
  ("float16.tan", "tan"),
2721
  ("float16.tanh", "tanh"),
2722
+ ("float16.sin_deg", "sin_deg"),
2723
+ ("float16.cos_deg", "cos_deg"),
2724
+ ("float16.tan_deg", "tan_deg"),
2725
+ ("float16.asin_deg", "asin_deg"),
2726
+ ("float16.acos_deg", "acos_deg"),
2727
+ ("float16.atan_deg", "atan_deg"),
2728
  ("float16.asin", "asin"),
2729
  ("float16.acos", "acos"),
2730
  ("float16.atan", "atan"),