CharlesCNorton committed on
Commit
55eb692
·
1 Parent(s): 3c7f544

Add float16 domain flags

Browse files
Files changed (4) hide show
  1. arithmetic.safetensors +2 -2
  2. build.py +49 -0
  3. calculator.py +59 -46
  4. eval.py +80 -0
arithmetic.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6bc68cbb34ffc8d5a3dcdda3c63a2a806175e6e90ee83c7fe7ca37ed287e087d
3
- size 436688020
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a54f3c41068308469d61313565f8f8457799b84214eac14ff14588bd14309c92
3
+ size 443768896
build.py CHANGED
@@ -82,6 +82,24 @@ def compute_float16_unary_lut_outputs(op_fn: Callable[[torch.Tensor], torch.Tens
82
  return outputs
83
 
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  def unary_float32(op_fn: Callable[[torch.Tensor], torch.Tensor]) -> Callable[[torch.Tensor], torch.Tensor]:
86
  """Wrap unary op to run in float32 for wider CPU support, returning float32."""
87
  def _fn(x: torch.Tensor) -> torch.Tensor:
@@ -129,6 +147,18 @@ def build_float16_lut_output_tensors(prefix: str, outputs: List[int]) -> Dict[st
129
  return tensors
130
 
131
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  def clone_prefix_tensors(src: Dict[str, torch.Tensor], old_prefix: str,
133
  new_prefix: str) -> Dict[str, torch.Tensor]:
134
  """Clone tensors and rewrite the prefix in tensor names."""
@@ -209,6 +239,8 @@ def infer_float16_lut_match_inputs(gate: str, registry: SignalRegistry,
209
 
210
  def infer_float16_lut_out_inputs(gate: str, registry: SignalRegistry, match_prefix: str) -> List[int]:
211
  """Infer inputs for LUT output gates (one-hot match vector)."""
 
 
212
  match = re.search(r'\.out(\d+)$', gate)
213
  if not match:
214
  return []
@@ -11233,6 +11265,17 @@ def main():
11233
  "acos_deg": wrap_inv_trig_deg(torch.acos),
11234
  "atan_deg": wrap_inv_trig_deg(torch.atan),
11235
  }
 
 
 
 
 
 
 
 
 
 
 
11236
  lut_outputs: Dict[str, List[int]] = {}
11237
  for name, fn in unary_ops.items():
11238
  print(f" computing float16.{name} LUT...")
@@ -11248,6 +11291,12 @@ def main():
11248
  op_tensors = build_float16_lut_output_tensors(f"float16.{name}", outputs)
11249
  tensors.update(op_tensors)
11250
  print(f" float16.{name}: {len(op_tensors)} tensors")
 
 
 
 
 
 
11251
 
11252
  # float16.pow (ln -> mul -> exp)
11253
  pow_tensors = build_float16_pow_tensors(mul_tensors,
 
82
  return outputs
83
 
84
 
85
def compute_float16_domain_flags(op: str) -> List[int]:
    """Compute domain error flags (1=invalid) for all 65536 float16 inputs.

    NaN inputs are flagged invalid for every op; otherwise the flag depends
    on the op's real-valued domain. Ops outside the known groups never flag
    non-NaN inputs.
    """
    # Op groups and the real-valued domain each one requires.
    nonneg_ops = ("sqrt", "rsqrt")                         # requires val >= 0
    positive_ops = ("ln", "log2", "log10")                 # requires val > 0
    unit_ops = ("asin", "acos", "asin_deg", "acos_deg")    # requires |val| <= 1

    def _invalid(val: float) -> bool:
        # NaN compares unequal to itself and is invalid everywhere.
        if val != val:
            return True
        if op in nonneg_ops:
            return val < 0
        if op in positive_ops:
            return val <= 0
        if op in unit_ops:
            return abs(val) > 1.0
        return False

    return [1 if _invalid(float16_bits_to_float(bits)) else 0
            for bits in range(65536)]
101
+
102
+
103
  def unary_float32(op_fn: Callable[[torch.Tensor], torch.Tensor]) -> Callable[[torch.Tensor], torch.Tensor]:
104
  """Wrap unary op to run in float32 for wider CPU support, returning float32."""
105
  def _fn(x: torch.Tensor) -> torch.Tensor:
 
147
  return tensors
148
 
149
 
150
def build_float16_lut_flag_tensors(prefix: str, flags: List[int], flag_name: str = "domain") -> Dict[str, torch.Tensor]:
    """Build a 1-bit LUT flag gate (prefix.{flag_name}) using one-hot match inputs.

    The weight vector has a 1.0 at every flagged index; with a -0.5 bias the
    gate fires exactly when the matched input is flagged.
    """
    weights = torch.zeros(65536)
    hot_indices = [idx for idx, flag in enumerate(flags) if flag]
    if hot_indices:
        weights[hot_indices] = 1.0
    return {
        f"{prefix}.{flag_name}.weight": weights,
        f"{prefix}.{flag_name}.bias": torch.tensor([-0.5]),
    }
160
+
161
+
162
  def clone_prefix_tensors(src: Dict[str, torch.Tensor], old_prefix: str,
163
  new_prefix: str) -> Dict[str, torch.Tensor]:
164
  """Clone tensors and rewrite the prefix in tensor names."""
 
239
 
240
  def infer_float16_lut_out_inputs(gate: str, registry: SignalRegistry, match_prefix: str) -> List[int]:
241
  """Infer inputs for LUT output gates (one-hot match vector)."""
242
+ if gate.endswith(".domain"):
243
+ return get_lut_match_ids(registry, match_prefix)
244
  match = re.search(r'\.out(\d+)$', gate)
245
  if not match:
246
  return []
 
11265
  "acos_deg": wrap_inv_trig_deg(torch.acos),
11266
  "atan_deg": wrap_inv_trig_deg(torch.atan),
11267
  }
11268
+ domain_ops = [
11269
+ "sqrt",
11270
+ "rsqrt",
11271
+ "ln",
11272
+ "log2",
11273
+ "log10",
11274
+ "asin",
11275
+ "acos",
11276
+ "asin_deg",
11277
+ "acos_deg",
11278
+ ]
11279
  lut_outputs: Dict[str, List[int]] = {}
11280
  for name, fn in unary_ops.items():
11281
  print(f" computing float16.{name} LUT...")
 
11291
  op_tensors = build_float16_lut_output_tensors(f"float16.{name}", outputs)
11292
  tensors.update(op_tensors)
11293
  print(f" float16.{name}: {len(op_tensors)} tensors")
11294
+ for name in domain_ops:
11295
+ print(f" computing float16.{name} domain flags...")
11296
+ flags = compute_float16_domain_flags(name)
11297
+ flag_tensors = build_float16_lut_flag_tensors(f"float16.{name}", flags, flag_name="domain")
11298
+ tensors.update(flag_tensors)
11299
+ print(f" float16.{name}.domain: {len(flag_tensors)} tensors")
11300
 
11301
  # float16.pow (ln -> mul -> exp)
11302
  pow_tensors = build_float16_pow_tensors(mul_tensors,
calculator.py CHANGED
@@ -506,12 +506,12 @@ class ThresholdCalculator:
506
  angle_mode = (angle_mode or "rad").lower()
507
  use_degrees = angle_mode.startswith("deg")
508
 
509
- def run_prefix(prefix: str, inputs: Dict[str, object]) -> int:
510
  nonlocal total_elapsed, total_gates
511
- res = self.evaluate_prefix(prefix, inputs, out_bits=16)
512
  total_elapsed += res.elapsed_s
513
  total_gates += res.gates_evaluated
514
- return bits_to_int(res.bits)
515
 
516
  def const_to_bits(tok: str) -> int:
517
  if tok == "pi":
@@ -562,7 +562,16 @@ class ThresholdCalculator:
562
  prefix = f"float16.{tok}_deg"
563
  elif use_degrees and tok in ("asin", "acos", "atan"):
564
  prefix = f"float16.{tok}_deg"
565
- out = run_prefix(prefix, {"x": x})
 
 
 
 
 
 
 
 
 
566
  stack.append(out)
567
  continue
568
  if tok in {"+", "-", "*", "/", "^"}:
@@ -571,16 +580,16 @@ class ThresholdCalculator:
571
  b = stack.pop()
572
  a = stack.pop()
573
  if tok == "+":
574
- out = run_prefix("float16.add", {"a": a, "b": b})
575
  elif tok == "-":
576
  b_flip = b ^ 0x8000
577
- out = run_prefix("float16.sub", {"a": a, "b": b_flip})
578
  elif tok == "*":
579
- out = run_prefix("float16.mul", {"a": a, "b": b})
580
  elif tok == "/":
581
- out = run_prefix("float16.div", {"a": a, "b": b})
582
  else:
583
- out = run_prefix("float16.pow", {"a": a, "b": b})
584
  stack.append(out)
585
  continue
586
  stack.append(const_to_bits(tok))
@@ -591,7 +600,7 @@ class ThresholdCalculator:
591
  out_bits = stack.pop()
592
  if total_gates == 0:
593
  if force_gate_eval:
594
- out_bits = run_prefix("float16.add", {"a": out_bits, "b": 0})
595
  else:
596
  non_gate_events.append("constant_expression_no_gates")
597
 
@@ -618,12 +627,22 @@ class ThresholdCalculator:
618
  total_gates = 0
619
  non_gate_events: List[str] = []
620
 
621
- def run_prefix(prefix: str, inputs: Dict[str, object]) -> int:
622
  nonlocal total_elapsed, total_gates
623
- res = self.evaluate_prefix(prefix, inputs, out_bits=16)
624
  total_elapsed += res.elapsed_s
625
  total_gates += res.gates_evaluated
626
- return bits_to_int(res.bits)
 
 
 
 
 
 
 
 
 
 
627
 
628
  def eval_node(node: ast.AST) -> int:
629
  if isinstance(node, ast.Expression):
@@ -648,22 +667,22 @@ class ThresholdCalculator:
648
  return eval_node(node.operand)
649
  if isinstance(node.op, ast.USub):
650
  x = eval_node(node.operand)
651
- return run_prefix("float16.neg", {"x": x})
652
  raise RuntimeError("unsupported unary operator")
653
  if isinstance(node, ast.BinOp):
654
  a = eval_node(node.left)
655
  b = eval_node(node.right)
656
  if isinstance(node.op, ast.Add):
657
- return run_prefix("float16.add", {"a": a, "b": b})
658
  if isinstance(node.op, ast.Sub):
659
  b_flip = b ^ 0x8000
660
- return run_prefix("float16.sub", {"a": a, "b": b_flip})
661
  if isinstance(node.op, ast.Mult):
662
- return run_prefix("float16.mul", {"a": a, "b": b})
663
  if isinstance(node.op, ast.Div):
664
- return run_prefix("float16.div", {"a": a, "b": b})
665
  if isinstance(node.op, ast.Pow):
666
- return run_prefix("float16.pow", {"a": a, "b": b})
667
  raise RuntimeError("unsupported binary operator")
668
  if isinstance(node, ast.Call):
669
  if not isinstance(node.func, ast.Name):
@@ -673,57 +692,51 @@ class ThresholdCalculator:
673
  raise RuntimeError(f"{fname} expects one argument")
674
  x = eval_node(node.args[0])
675
  if fname == "sqrt":
676
- return run_prefix("float16.sqrt", {"x": x})
677
  if fname == "rsqrt":
678
- return run_prefix("float16.rsqrt", {"x": x})
679
  if fname == "exp":
680
- return run_prefix("float16.exp", {"x": x})
681
  if fname in ("ln", "log"):
682
- return run_prefix("float16.ln", {"x": x})
683
  if fname == "log2":
684
- return run_prefix("float16.log2", {"x": x})
685
  if fname == "log10":
686
- return run_prefix("float16.log10", {"x": x})
687
  if fname == "sin":
688
  prefix = "float16.sin_deg" if use_degrees else "float16.sin"
689
- return run_prefix(prefix, {"x": x})
690
  if fname == "cos":
691
  prefix = "float16.cos_deg" if use_degrees else "float16.cos"
692
- return run_prefix(prefix, {"x": x})
693
  if fname == "tan":
694
  prefix = "float16.tan_deg" if use_degrees else "float16.tan"
695
- return run_prefix(prefix, {"x": x})
696
  if fname == "tanh":
697
- return run_prefix("float16.tanh", {"x": x})
698
  if fname == "asin":
699
  prefix = "float16.asin_deg" if use_degrees else "float16.asin"
700
- return run_prefix(prefix, {"x": x})
701
  if fname == "acos":
702
  prefix = "float16.acos_deg" if use_degrees else "float16.acos"
703
- return run_prefix(prefix, {"x": x})
704
  if fname == "atan":
705
  prefix = "float16.atan_deg" if use_degrees else "float16.atan"
706
- return run_prefix(prefix, {"x": x})
707
- if fname == "asin":
708
- return run_prefix("float16.asin", {"x": x})
709
- if fname == "acos":
710
- return run_prefix("float16.acos", {"x": x})
711
- if fname == "atan":
712
- return run_prefix("float16.atan", {"x": x})
713
  if fname == "sinh":
714
- return run_prefix("float16.sinh", {"x": x})
715
  if fname == "cosh":
716
- return run_prefix("float16.cosh", {"x": x})
717
  if fname == "floor":
718
- return run_prefix("float16.floor", {"x": x})
719
  if fname == "ceil":
720
- return run_prefix("float16.ceil", {"x": x})
721
  if fname == "round":
722
- return run_prefix("float16.round", {"x": x})
723
  if fname == "abs":
724
- return run_prefix("float16.abs", {"x": x})
725
  if fname == "neg":
726
- return run_prefix("float16.neg", {"x": x})
727
  raise RuntimeError(f"unsupported function: {fname}")
728
  raise RuntimeError("unsupported expression")
729
 
@@ -731,7 +744,7 @@ class ThresholdCalculator:
731
  if total_gates == 0:
732
  if force_gate_eval:
733
  # Route constants through float16.add with +0 to ensure gate-level evaluation.
734
- out_bits = run_prefix("float16.add", {"a": out_bits, "b": 0})
735
  else:
736
  non_gate_events.append("constant_expression_no_gates")
737
  return EvalResult(
 
506
  angle_mode = (angle_mode or "rad").lower()
507
  use_degrees = angle_mode.startswith("deg")
508
 
509
+ def run_prefix(prefix: str, inputs: Dict[str, object], outputs: Optional[List[str]] = None) -> EvalResult:
510
  nonlocal total_elapsed, total_gates
511
+ res = self.evaluate_prefix(prefix, inputs, out_bits=16, outputs=outputs)
512
  total_elapsed += res.elapsed_s
513
  total_gates += res.gates_evaluated
514
+ return res
515
 
516
  def const_to_bits(tok: str) -> int:
517
  if tok == "pi":
 
562
  prefix = f"float16.{tok}_deg"
563
  elif use_degrees and tok in ("asin", "acos", "atan"):
564
  prefix = f"float16.{tok}_deg"
565
+ if f"{prefix}.domain.weight" in self.tensors:
566
+ outs = [f"{prefix}.out{i}" for i in range(16)] + [f"{prefix}.domain"]
567
+ res = run_prefix(prefix, {"x": x}, outputs=outs)
568
+ if res.bits[16] >= 0.5:
569
+ x_val = float16_bits_to_float(x)
570
+ raise RuntimeError(f"domain error: {tok}({x_val})")
571
+ out = bits_to_int(res.bits[:16])
572
+ else:
573
+ res = run_prefix(prefix, {"x": x})
574
+ out = bits_to_int(res.bits)
575
  stack.append(out)
576
  continue
577
  if tok in {"+", "-", "*", "/", "^"}:
 
580
  b = stack.pop()
581
  a = stack.pop()
582
  if tok == "+":
583
+ out = bits_to_int(run_prefix("float16.add", {"a": a, "b": b}).bits)
584
  elif tok == "-":
585
  b_flip = b ^ 0x8000
586
+ out = bits_to_int(run_prefix("float16.sub", {"a": a, "b": b_flip}).bits)
587
  elif tok == "*":
588
+ out = bits_to_int(run_prefix("float16.mul", {"a": a, "b": b}).bits)
589
  elif tok == "/":
590
+ out = bits_to_int(run_prefix("float16.div", {"a": a, "b": b}).bits)
591
  else:
592
+ out = bits_to_int(run_prefix("float16.pow", {"a": a, "b": b}).bits)
593
  stack.append(out)
594
  continue
595
  stack.append(const_to_bits(tok))
 
600
  out_bits = stack.pop()
601
  if total_gates == 0:
602
  if force_gate_eval:
603
+ out_bits = bits_to_int(run_prefix("float16.add", {"a": out_bits, "b": 0}).bits)
604
  else:
605
  non_gate_events.append("constant_expression_no_gates")
606
 
 
627
  total_gates = 0
628
  non_gate_events: List[str] = []
629
 
630
+ def run_prefix(prefix: str, inputs: Dict[str, object], outputs: Optional[List[str]] = None) -> EvalResult:
631
  nonlocal total_elapsed, total_gates
632
+ res = self.evaluate_prefix(prefix, inputs, out_bits=16, outputs=outputs)
633
  total_elapsed += res.elapsed_s
634
  total_gates += res.gates_evaluated
635
+ return res
636
+
637
+ def run_unary(prefix: str, x_bits: int, fname: str) -> int:
638
+ if f"{prefix}.domain.weight" in self.tensors:
639
+ outs = [f"{prefix}.out{i}" for i in range(16)] + [f"{prefix}.domain"]
640
+ res = run_prefix(prefix, {"x": x_bits}, outputs=outs)
641
+ if res.bits[16] >= 0.5:
642
+ x_val = float16_bits_to_float(x_bits)
643
+ raise RuntimeError(f"domain error: {fname}({x_val})")
644
+ return bits_to_int(res.bits[:16])
645
+ return bits_to_int(run_prefix(prefix, {"x": x_bits}).bits)
646
 
647
  def eval_node(node: ast.AST) -> int:
648
  if isinstance(node, ast.Expression):
 
667
  return eval_node(node.operand)
668
  if isinstance(node.op, ast.USub):
669
  x = eval_node(node.operand)
670
+ return bits_to_int(run_prefix("float16.neg", {"x": x}).bits)
671
  raise RuntimeError("unsupported unary operator")
672
  if isinstance(node, ast.BinOp):
673
  a = eval_node(node.left)
674
  b = eval_node(node.right)
675
  if isinstance(node.op, ast.Add):
676
+ return bits_to_int(run_prefix("float16.add", {"a": a, "b": b}).bits)
677
  if isinstance(node.op, ast.Sub):
678
  b_flip = b ^ 0x8000
679
+ return bits_to_int(run_prefix("float16.sub", {"a": a, "b": b_flip}).bits)
680
  if isinstance(node.op, ast.Mult):
681
+ return bits_to_int(run_prefix("float16.mul", {"a": a, "b": b}).bits)
682
  if isinstance(node.op, ast.Div):
683
+ return bits_to_int(run_prefix("float16.div", {"a": a, "b": b}).bits)
684
  if isinstance(node.op, ast.Pow):
685
+ return bits_to_int(run_prefix("float16.pow", {"a": a, "b": b}).bits)
686
  raise RuntimeError("unsupported binary operator")
687
  if isinstance(node, ast.Call):
688
  if not isinstance(node.func, ast.Name):
 
692
  raise RuntimeError(f"{fname} expects one argument")
693
  x = eval_node(node.args[0])
694
  if fname == "sqrt":
695
+ return run_unary("float16.sqrt", x, fname)
696
  if fname == "rsqrt":
697
+ return run_unary("float16.rsqrt", x, fname)
698
  if fname == "exp":
699
+ return run_unary("float16.exp", x, fname)
700
  if fname in ("ln", "log"):
701
+ return run_unary("float16.ln", x, fname)
702
  if fname == "log2":
703
+ return run_unary("float16.log2", x, fname)
704
  if fname == "log10":
705
+ return run_unary("float16.log10", x, fname)
706
  if fname == "sin":
707
  prefix = "float16.sin_deg" if use_degrees else "float16.sin"
708
+ return run_unary(prefix, x, fname)
709
  if fname == "cos":
710
  prefix = "float16.cos_deg" if use_degrees else "float16.cos"
711
+ return run_unary(prefix, x, fname)
712
  if fname == "tan":
713
  prefix = "float16.tan_deg" if use_degrees else "float16.tan"
714
+ return run_unary(prefix, x, fname)
715
  if fname == "tanh":
716
+ return run_unary("float16.tanh", x, fname)
717
  if fname == "asin":
718
  prefix = "float16.asin_deg" if use_degrees else "float16.asin"
719
+ return run_unary(prefix, x, fname)
720
  if fname == "acos":
721
  prefix = "float16.acos_deg" if use_degrees else "float16.acos"
722
+ return run_unary(prefix, x, fname)
723
  if fname == "atan":
724
  prefix = "float16.atan_deg" if use_degrees else "float16.atan"
725
+ return run_unary(prefix, x, fname)
 
 
 
 
 
 
726
  if fname == "sinh":
727
+ return run_unary("float16.sinh", x, fname)
728
  if fname == "cosh":
729
+ return run_unary("float16.cosh", x, fname)
730
  if fname == "floor":
731
+ return run_unary("float16.floor", x, fname)
732
  if fname == "ceil":
733
+ return run_unary("float16.ceil", x, fname)
734
  if fname == "round":
735
+ return run_unary("float16.round", x, fname)
736
  if fname == "abs":
737
+ return run_unary("float16.abs", x, fname)
738
  if fname == "neg":
739
+ return run_unary("float16.neg", x, fname)
740
  raise RuntimeError(f"unsupported function: {fname}")
741
  raise RuntimeError("unsupported expression")
742
 
 
744
  if total_gates == 0:
745
  if force_gate_eval:
746
  # Route constants through float16.add with +0 to ensure gate-level evaluation.
747
+ out_bits = bits_to_int(run_prefix("float16.add", {"a": out_bits, "b": 0}).bits)
748
  else:
749
  non_gate_events.append("constant_expression_no_gates")
750
  return EvalResult(
eval.py CHANGED
@@ -716,6 +716,33 @@ def eval_float16_lut_outputs(ctx: EvalContext, op_prefix: str,
716
  return outputs
717
 
718
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
719
  def build_float16_pairs(rng: random.Random, count: int) -> List[Tuple[int, int]]:
720
  """Build deterministic float16 test pairs using edge cases + random."""
721
  edges = [
@@ -892,6 +919,20 @@ def float16_expected_bits_pow(a_bits: int, b_bits: int) -> Tuple[int, bool]:
892
  return float_to_int(float(out)), False
893
 
894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
895
  # =============================================================================
896
  # BOOLEAN GATE TESTS
897
  # =============================================================================
@@ -2760,6 +2801,44 @@ def test_float16_unary(ctx: EvalContext) -> List[TestResult]:
2760
  return results
2761
 
2762
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2763
  def test_float16_pow(ctx: EvalContext) -> List[TestResult]:
2764
  """Test float16.pow (defined as exp(b * ln(a)))."""
2765
  results: List[TestResult] = []
@@ -2855,6 +2934,7 @@ CATEGORIES = {
2855
  "float16_arith": ("Float16 - Arithmetic", test_float16_arithmetic),
2856
  "float16_conv": ("Float16 - Conversion", test_float16_conversion),
2857
  "float16_unary": ("Float16 - Unary LUT", test_float16_unary),
 
2858
  "float16_pow": ("Float16 - Pow", test_float16_pow),
2859
  }
2860
 
 
716
  return outputs
717
 
718
 
719
def eval_float16_lut_flag(ctx: EvalContext, op_prefix: str,
                          bits: List[float],
                          flag: str = "domain",
                          match_prefix: str = "float16.lut") -> float:
    """Evaluate a LUT-backed 1-bit flag using direct LUT indexing.

    Marks the shared one-hot match gate and the flag gate's tensors as
    tested, then thresholds weight[idx] + bias at zero.
    """
    idx = bits_to_int(bits)

    # Coverage bookkeeping for the shared match gate feeding this flag.
    match_gate = f"{match_prefix}.match{idx:04x}"
    for suffix in (".weight", ".bias", ".inputs"):
        candidate = match_gate + suffix
        if candidate in ctx.tensors:
            ctx.tested_tensors.add(candidate)

    gate = f"{op_prefix}.{flag}"
    weight_key = f"{gate}.weight"
    bias_key = f"{gate}.bias"
    inputs_key = f"{gate}.inputs"

    # The weight tensor is required; bias/inputs are optional.
    ctx.tested_tensors.add(weight_key)
    for optional_key in (bias_key, inputs_key):
        if optional_key in ctx.tensors:
            ctx.tested_tensors.add(optional_key)

    weight = ctx.tensors[weight_key][idx].item()
    bias = ctx.tensors.get(bias_key, torch.tensor([0.0])).item()
    return 1.0 if (weight + bias) >= 0 else 0.0
744
+
745
+
746
  def build_float16_pairs(rng: random.Random, count: int) -> List[Tuple[int, int]]:
747
  """Build deterministic float16 test pairs using edge cases + random."""
748
  edges = [
 
919
  return float_to_int(float(out)), False
920
 
921
 
922
def float16_expected_domain(op: str, a_bits: int) -> int:
    """Compute expected domain flag (1=invalid) for unary ops."""
    a = float16_int_to_float(a_bits)
    # NaN is invalid for every op (NaN != NaN).
    if a != a:
        return 1
    out_of_domain = (
        (op in ("sqrt", "rsqrt") and a < 0)
        or (op in ("ln", "log2", "log10") and a <= 0)
        or (op in ("asin", "acos", "asin_deg", "acos_deg") and abs(a) > 1.0)
    )
    return 1 if out_of_domain else 0
934
+
935
+
936
  # =============================================================================
937
  # BOOLEAN GATE TESTS
938
  # =============================================================================
 
2801
  return results
2802
 
2803
 
2804
def test_float16_domain_flags(ctx: EvalContext) -> List[TestResult]:
    """Test float16 domain flag outputs."""
    rng = random.Random(1337)
    values = build_float16_values(rng, 256)
    # (tensor prefix, op name for expected-flag computation)
    ops = [
        ("float16.sqrt", "sqrt"),
        ("float16.rsqrt", "rsqrt"),
        ("float16.ln", "ln"),
        ("float16.log2", "log2"),
        ("float16.log10", "log10"),
        ("float16.asin", "asin"),
        ("float16.acos", "acos"),
        ("float16.asin_deg", "asin_deg"),
        ("float16.acos_deg", "acos_deg"),
    ]
    results: List[TestResult] = []
    for prefix, op in ops:
        # Skip ops whose domain gate is absent from the loaded model.
        if f"{prefix}.domain.weight" not in ctx.tensors:
            continue
        passed = 0
        total = 0
        failures: List[Dict[str, Any]] = []
        for a_bits in values:
            bit_vector = [float((a_bits >> i) & 1) for i in range(16)]
            actual = eval_float16_lut_flag(ctx, prefix, bit_vector)
            expected = float16_expected_domain(op, a_bits)
            total += 1
            if int(actual) == expected:
                passed += 1
            elif len(failures) < 8:
                failures.append({
                    "input": hex(a_bits),
                    "actual": int(actual),
                    "expected": expected,
                })
        results.append(TestResult(f"{prefix}.domain", passed, total, failures))
    return results
2840
+
2841
+
2842
  def test_float16_pow(ctx: EvalContext) -> List[TestResult]:
2843
  """Test float16.pow (defined as exp(b * ln(a)))."""
2844
  results: List[TestResult] = []
 
2934
  "float16_arith": ("Float16 - Arithmetic", test_float16_arithmetic),
2935
  "float16_conv": ("Float16 - Conversion", test_float16_conversion),
2936
  "float16_unary": ("Float16 - Unary LUT", test_float16_unary),
2937
+ "float16_domain": ("Float16 - Domain Flags", test_float16_domain_flags),
2938
  "float16_pow": ("Float16 - Pow", test_float16_pow),
2939
  }
2940