CharlesCNorton committed on
Commit ·
55eb692
1
Parent(s): 3c7f544
Add float16 domain flags
Browse files- arithmetic.safetensors +2 -2
- build.py +49 -0
- calculator.py +59 -46
- eval.py +80 -0
arithmetic.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a54f3c41068308469d61313565f8f8457799b84214eac14ff14588bd14309c92
|
| 3 |
+
size 443768896
|
build.py
CHANGED
|
@@ -82,6 +82,24 @@ def compute_float16_unary_lut_outputs(op_fn: Callable[[torch.Tensor], torch.Tens
|
|
| 82 |
return outputs
|
| 83 |
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
def unary_float32(op_fn: Callable[[torch.Tensor], torch.Tensor]) -> Callable[[torch.Tensor], torch.Tensor]:
|
| 86 |
"""Wrap unary op to run in float32 for wider CPU support, returning float32."""
|
| 87 |
def _fn(x: torch.Tensor) -> torch.Tensor:
|
|
@@ -129,6 +147,18 @@ def build_float16_lut_output_tensors(prefix: str, outputs: List[int]) -> Dict[st
|
|
| 129 |
return tensors
|
| 130 |
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
def clone_prefix_tensors(src: Dict[str, torch.Tensor], old_prefix: str,
|
| 133 |
new_prefix: str) -> Dict[str, torch.Tensor]:
|
| 134 |
"""Clone tensors and rewrite the prefix in tensor names."""
|
|
@@ -209,6 +239,8 @@ def infer_float16_lut_match_inputs(gate: str, registry: SignalRegistry,
|
|
| 209 |
|
| 210 |
def infer_float16_lut_out_inputs(gate: str, registry: SignalRegistry, match_prefix: str) -> List[int]:
|
| 211 |
"""Infer inputs for LUT output gates (one-hot match vector)."""
|
|
|
|
|
|
|
| 212 |
match = re.search(r'\.out(\d+)$', gate)
|
| 213 |
if not match:
|
| 214 |
return []
|
|
@@ -11233,6 +11265,17 @@ def main():
|
|
| 11233 |
"acos_deg": wrap_inv_trig_deg(torch.acos),
|
| 11234 |
"atan_deg": wrap_inv_trig_deg(torch.atan),
|
| 11235 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11236 |
lut_outputs: Dict[str, List[int]] = {}
|
| 11237 |
for name, fn in unary_ops.items():
|
| 11238 |
print(f" computing float16.{name} LUT...")
|
|
@@ -11248,6 +11291,12 @@ def main():
|
|
| 11248 |
op_tensors = build_float16_lut_output_tensors(f"float16.{name}", outputs)
|
| 11249 |
tensors.update(op_tensors)
|
| 11250 |
print(f" float16.{name}: {len(op_tensors)} tensors")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11251 |
|
| 11252 |
# float16.pow (ln -> mul -> exp)
|
| 11253 |
pow_tensors = build_float16_pow_tensors(mul_tensors,
|
|
|
|
| 82 |
return outputs
|
| 83 |
|
| 84 |
|
| 85 |
+
def compute_float16_domain_flags(op: str) -> List[int]:
    """Compute domain error flags (1=invalid) for all 65536 float16 inputs.

    Args:
        op: Operation name; only certain ops restrict their domain
            (sqrt/rsqrt: negative inputs, ln/log2/log10: non-positive
            inputs, asin/acos and their _deg variants: |x| > 1).

    Returns:
        A list of 65536 ints indexed by the raw 16-bit pattern: 1 where the
        input is outside the op's domain (or NaN), 0 otherwise.
    """
    def _is_invalid(val: float) -> bool:
        # NaN never compares equal to itself, and is invalid for every op.
        if val != val:
            return True
        if op in ("sqrt", "rsqrt"):
            return val < 0
        if op in ("ln", "log2", "log10"):
            return val <= 0
        if op in ("asin", "acos", "asin_deg", "acos_deg"):
            return abs(val) > 1.0
        return False

    return [1 if _is_invalid(float16_bits_to_float(bits)) else 0
            for bits in range(65536)]
| 101 |
+
|
| 102 |
+
|
| 103 |
def unary_float32(op_fn: Callable[[torch.Tensor], torch.Tensor]) -> Callable[[torch.Tensor], torch.Tensor]:
|
| 104 |
"""Wrap unary op to run in float32 for wider CPU support, returning float32."""
|
| 105 |
def _fn(x: torch.Tensor) -> torch.Tensor:
|
|
|
|
| 147 |
return tensors
|
| 148 |
|
| 149 |
|
| 150 |
+
def build_float16_lut_flag_tensors(prefix: str, flags: List[int], flag_name: str = "domain") -> Dict[str, torch.Tensor]:
    """Build a 1-bit LUT flag gate (prefix.{flag_name}) using one-hot match inputs.

    Args:
        prefix: Tensor-name prefix for the op (e.g. "float16.sqrt").
        flags: Per-input flag values indexed by the 16-bit input pattern
            (normally 65536 entries); any truthy entry sets a 1.0 weight.
        flag_name: Name of the flag gate appended to the prefix.

    Returns:
        Mapping with "{prefix}.{flag_name}.weight" (length-65536 float vector,
        1.0 at flagged indices, 0.0 elsewhere) and "{prefix}.{flag_name}.bias"
        ([-0.5], so the threshold gate fires iff the hot match input is flagged).
    """
    weights = torch.zeros(65536)
    # Single vectorized copy instead of 65536 per-element tensor item
    # assignments in a Python loop; truthy flags map to 1.0, matching the
    # original element-wise behavior (trailing indices stay 0.0 if fewer
    # than 65536 flags are supplied).
    if flags:
        flag_tensor = torch.as_tensor([1.0 if f else 0.0 for f in flags],
                                      dtype=weights.dtype)
        weights[: flag_tensor.numel()] = flag_tensor
    tensors: Dict[str, torch.Tensor] = {
        f"{prefix}.{flag_name}.weight": weights,
        f"{prefix}.{flag_name}.bias": torch.tensor([-0.5]),
    }
    return tensors
|
| 160 |
+
|
| 161 |
+
|
| 162 |
def clone_prefix_tensors(src: Dict[str, torch.Tensor], old_prefix: str,
|
| 163 |
new_prefix: str) -> Dict[str, torch.Tensor]:
|
| 164 |
"""Clone tensors and rewrite the prefix in tensor names."""
|
|
|
|
| 239 |
|
| 240 |
def infer_float16_lut_out_inputs(gate: str, registry: SignalRegistry, match_prefix: str) -> List[int]:
|
| 241 |
"""Infer inputs for LUT output gates (one-hot match vector)."""
|
| 242 |
+
if gate.endswith(".domain"):
|
| 243 |
+
return get_lut_match_ids(registry, match_prefix)
|
| 244 |
match = re.search(r'\.out(\d+)$', gate)
|
| 245 |
if not match:
|
| 246 |
return []
|
|
|
|
| 11265 |
"acos_deg": wrap_inv_trig_deg(torch.acos),
|
| 11266 |
"atan_deg": wrap_inv_trig_deg(torch.atan),
|
| 11267 |
}
|
| 11268 |
+
domain_ops = [
|
| 11269 |
+
"sqrt",
|
| 11270 |
+
"rsqrt",
|
| 11271 |
+
"ln",
|
| 11272 |
+
"log2",
|
| 11273 |
+
"log10",
|
| 11274 |
+
"asin",
|
| 11275 |
+
"acos",
|
| 11276 |
+
"asin_deg",
|
| 11277 |
+
"acos_deg",
|
| 11278 |
+
]
|
| 11279 |
lut_outputs: Dict[str, List[int]] = {}
|
| 11280 |
for name, fn in unary_ops.items():
|
| 11281 |
print(f" computing float16.{name} LUT...")
|
|
|
|
| 11291 |
op_tensors = build_float16_lut_output_tensors(f"float16.{name}", outputs)
|
| 11292 |
tensors.update(op_tensors)
|
| 11293 |
print(f" float16.{name}: {len(op_tensors)} tensors")
|
| 11294 |
+
for name in domain_ops:
|
| 11295 |
+
print(f" computing float16.{name} domain flags...")
|
| 11296 |
+
flags = compute_float16_domain_flags(name)
|
| 11297 |
+
flag_tensors = build_float16_lut_flag_tensors(f"float16.{name}", flags, flag_name="domain")
|
| 11298 |
+
tensors.update(flag_tensors)
|
| 11299 |
+
print(f" float16.{name}.domain: {len(flag_tensors)} tensors")
|
| 11300 |
|
| 11301 |
# float16.pow (ln -> mul -> exp)
|
| 11302 |
pow_tensors = build_float16_pow_tensors(mul_tensors,
|
calculator.py
CHANGED
|
@@ -506,12 +506,12 @@ class ThresholdCalculator:
|
|
| 506 |
angle_mode = (angle_mode or "rad").lower()
|
| 507 |
use_degrees = angle_mode.startswith("deg")
|
| 508 |
|
| 509 |
-
def run_prefix(prefix: str, inputs: Dict[str, object]) ->
|
| 510 |
nonlocal total_elapsed, total_gates
|
| 511 |
-
res = self.evaluate_prefix(prefix, inputs, out_bits=16)
|
| 512 |
total_elapsed += res.elapsed_s
|
| 513 |
total_gates += res.gates_evaluated
|
| 514 |
-
return
|
| 515 |
|
| 516 |
def const_to_bits(tok: str) -> int:
|
| 517 |
if tok == "pi":
|
|
@@ -562,7 +562,16 @@ class ThresholdCalculator:
|
|
| 562 |
prefix = f"float16.{tok}_deg"
|
| 563 |
elif use_degrees and tok in ("asin", "acos", "atan"):
|
| 564 |
prefix = f"float16.{tok}_deg"
|
| 565 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
stack.append(out)
|
| 567 |
continue
|
| 568 |
if tok in {"+", "-", "*", "/", "^"}:
|
|
@@ -571,16 +580,16 @@ class ThresholdCalculator:
|
|
| 571 |
b = stack.pop()
|
| 572 |
a = stack.pop()
|
| 573 |
if tok == "+":
|
| 574 |
-
out = run_prefix("float16.add", {"a": a, "b": b})
|
| 575 |
elif tok == "-":
|
| 576 |
b_flip = b ^ 0x8000
|
| 577 |
-
out = run_prefix("float16.sub", {"a": a, "b": b_flip})
|
| 578 |
elif tok == "*":
|
| 579 |
-
out = run_prefix("float16.mul", {"a": a, "b": b})
|
| 580 |
elif tok == "/":
|
| 581 |
-
out = run_prefix("float16.div", {"a": a, "b": b})
|
| 582 |
else:
|
| 583 |
-
out = run_prefix("float16.pow", {"a": a, "b": b})
|
| 584 |
stack.append(out)
|
| 585 |
continue
|
| 586 |
stack.append(const_to_bits(tok))
|
|
@@ -591,7 +600,7 @@ class ThresholdCalculator:
|
|
| 591 |
out_bits = stack.pop()
|
| 592 |
if total_gates == 0:
|
| 593 |
if force_gate_eval:
|
| 594 |
-
out_bits = run_prefix("float16.add", {"a": out_bits, "b": 0})
|
| 595 |
else:
|
| 596 |
non_gate_events.append("constant_expression_no_gates")
|
| 597 |
|
|
@@ -618,12 +627,22 @@ class ThresholdCalculator:
|
|
| 618 |
total_gates = 0
|
| 619 |
non_gate_events: List[str] = []
|
| 620 |
|
| 621 |
-
def run_prefix(prefix: str, inputs: Dict[str, object]) ->
|
| 622 |
nonlocal total_elapsed, total_gates
|
| 623 |
-
res = self.evaluate_prefix(prefix, inputs, out_bits=16)
|
| 624 |
total_elapsed += res.elapsed_s
|
| 625 |
total_gates += res.gates_evaluated
|
| 626 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 627 |
|
| 628 |
def eval_node(node: ast.AST) -> int:
|
| 629 |
if isinstance(node, ast.Expression):
|
|
@@ -648,22 +667,22 @@ class ThresholdCalculator:
|
|
| 648 |
return eval_node(node.operand)
|
| 649 |
if isinstance(node.op, ast.USub):
|
| 650 |
x = eval_node(node.operand)
|
| 651 |
-
return run_prefix("float16.neg", {"x": x})
|
| 652 |
raise RuntimeError("unsupported unary operator")
|
| 653 |
if isinstance(node, ast.BinOp):
|
| 654 |
a = eval_node(node.left)
|
| 655 |
b = eval_node(node.right)
|
| 656 |
if isinstance(node.op, ast.Add):
|
| 657 |
-
return run_prefix("float16.add", {"a": a, "b": b})
|
| 658 |
if isinstance(node.op, ast.Sub):
|
| 659 |
b_flip = b ^ 0x8000
|
| 660 |
-
return run_prefix("float16.sub", {"a": a, "b": b_flip})
|
| 661 |
if isinstance(node.op, ast.Mult):
|
| 662 |
-
return run_prefix("float16.mul", {"a": a, "b": b})
|
| 663 |
if isinstance(node.op, ast.Div):
|
| 664 |
-
return run_prefix("float16.div", {"a": a, "b": b})
|
| 665 |
if isinstance(node.op, ast.Pow):
|
| 666 |
-
return run_prefix("float16.pow", {"a": a, "b": b})
|
| 667 |
raise RuntimeError("unsupported binary operator")
|
| 668 |
if isinstance(node, ast.Call):
|
| 669 |
if not isinstance(node.func, ast.Name):
|
|
@@ -673,57 +692,51 @@ class ThresholdCalculator:
|
|
| 673 |
raise RuntimeError(f"{fname} expects one argument")
|
| 674 |
x = eval_node(node.args[0])
|
| 675 |
if fname == "sqrt":
|
| 676 |
-
return
|
| 677 |
if fname == "rsqrt":
|
| 678 |
-
return
|
| 679 |
if fname == "exp":
|
| 680 |
-
return
|
| 681 |
if fname in ("ln", "log"):
|
| 682 |
-
return
|
| 683 |
if fname == "log2":
|
| 684 |
-
return
|
| 685 |
if fname == "log10":
|
| 686 |
-
return
|
| 687 |
if fname == "sin":
|
| 688 |
prefix = "float16.sin_deg" if use_degrees else "float16.sin"
|
| 689 |
-
return
|
| 690 |
if fname == "cos":
|
| 691 |
prefix = "float16.cos_deg" if use_degrees else "float16.cos"
|
| 692 |
-
return
|
| 693 |
if fname == "tan":
|
| 694 |
prefix = "float16.tan_deg" if use_degrees else "float16.tan"
|
| 695 |
-
return
|
| 696 |
if fname == "tanh":
|
| 697 |
-
return
|
| 698 |
if fname == "asin":
|
| 699 |
prefix = "float16.asin_deg" if use_degrees else "float16.asin"
|
| 700 |
-
return
|
| 701 |
if fname == "acos":
|
| 702 |
prefix = "float16.acos_deg" if use_degrees else "float16.acos"
|
| 703 |
-
return
|
| 704 |
if fname == "atan":
|
| 705 |
prefix = "float16.atan_deg" if use_degrees else "float16.atan"
|
| 706 |
-
return
|
| 707 |
-
if fname == "asin":
|
| 708 |
-
return run_prefix("float16.asin", {"x": x})
|
| 709 |
-
if fname == "acos":
|
| 710 |
-
return run_prefix("float16.acos", {"x": x})
|
| 711 |
-
if fname == "atan":
|
| 712 |
-
return run_prefix("float16.atan", {"x": x})
|
| 713 |
if fname == "sinh":
|
| 714 |
-
return
|
| 715 |
if fname == "cosh":
|
| 716 |
-
return
|
| 717 |
if fname == "floor":
|
| 718 |
-
return
|
| 719 |
if fname == "ceil":
|
| 720 |
-
return
|
| 721 |
if fname == "round":
|
| 722 |
-
return
|
| 723 |
if fname == "abs":
|
| 724 |
-
return
|
| 725 |
if fname == "neg":
|
| 726 |
-
return
|
| 727 |
raise RuntimeError(f"unsupported function: {fname}")
|
| 728 |
raise RuntimeError("unsupported expression")
|
| 729 |
|
|
@@ -731,7 +744,7 @@ class ThresholdCalculator:
|
|
| 731 |
if total_gates == 0:
|
| 732 |
if force_gate_eval:
|
| 733 |
# Route constants through float16.add with +0 to ensure gate-level evaluation.
|
| 734 |
-
out_bits = run_prefix("float16.add", {"a": out_bits, "b": 0})
|
| 735 |
else:
|
| 736 |
non_gate_events.append("constant_expression_no_gates")
|
| 737 |
return EvalResult(
|
|
|
|
| 506 |
angle_mode = (angle_mode or "rad").lower()
|
| 507 |
use_degrees = angle_mode.startswith("deg")
|
| 508 |
|
| 509 |
+
def run_prefix(prefix: str, inputs: Dict[str, object], outputs: Optional[List[str]] = None) -> EvalResult:
|
| 510 |
nonlocal total_elapsed, total_gates
|
| 511 |
+
res = self.evaluate_prefix(prefix, inputs, out_bits=16, outputs=outputs)
|
| 512 |
total_elapsed += res.elapsed_s
|
| 513 |
total_gates += res.gates_evaluated
|
| 514 |
+
return res
|
| 515 |
|
| 516 |
def const_to_bits(tok: str) -> int:
|
| 517 |
if tok == "pi":
|
|
|
|
| 562 |
prefix = f"float16.{tok}_deg"
|
| 563 |
elif use_degrees and tok in ("asin", "acos", "atan"):
|
| 564 |
prefix = f"float16.{tok}_deg"
|
| 565 |
+
if f"{prefix}.domain.weight" in self.tensors:
|
| 566 |
+
outs = [f"{prefix}.out{i}" for i in range(16)] + [f"{prefix}.domain"]
|
| 567 |
+
res = run_prefix(prefix, {"x": x}, outputs=outs)
|
| 568 |
+
if res.bits[16] >= 0.5:
|
| 569 |
+
x_val = float16_bits_to_float(x)
|
| 570 |
+
raise RuntimeError(f"domain error: {tok}({x_val})")
|
| 571 |
+
out = bits_to_int(res.bits[:16])
|
| 572 |
+
else:
|
| 573 |
+
res = run_prefix(prefix, {"x": x})
|
| 574 |
+
out = bits_to_int(res.bits)
|
| 575 |
stack.append(out)
|
| 576 |
continue
|
| 577 |
if tok in {"+", "-", "*", "/", "^"}:
|
|
|
|
| 580 |
b = stack.pop()
|
| 581 |
a = stack.pop()
|
| 582 |
if tok == "+":
|
| 583 |
+
out = bits_to_int(run_prefix("float16.add", {"a": a, "b": b}).bits)
|
| 584 |
elif tok == "-":
|
| 585 |
b_flip = b ^ 0x8000
|
| 586 |
+
out = bits_to_int(run_prefix("float16.sub", {"a": a, "b": b_flip}).bits)
|
| 587 |
elif tok == "*":
|
| 588 |
+
out = bits_to_int(run_prefix("float16.mul", {"a": a, "b": b}).bits)
|
| 589 |
elif tok == "/":
|
| 590 |
+
out = bits_to_int(run_prefix("float16.div", {"a": a, "b": b}).bits)
|
| 591 |
else:
|
| 592 |
+
out = bits_to_int(run_prefix("float16.pow", {"a": a, "b": b}).bits)
|
| 593 |
stack.append(out)
|
| 594 |
continue
|
| 595 |
stack.append(const_to_bits(tok))
|
|
|
|
| 600 |
out_bits = stack.pop()
|
| 601 |
if total_gates == 0:
|
| 602 |
if force_gate_eval:
|
| 603 |
+
out_bits = bits_to_int(run_prefix("float16.add", {"a": out_bits, "b": 0}).bits)
|
| 604 |
else:
|
| 605 |
non_gate_events.append("constant_expression_no_gates")
|
| 606 |
|
|
|
|
| 627 |
total_gates = 0
|
| 628 |
non_gate_events: List[str] = []
|
| 629 |
|
| 630 |
+
def run_prefix(prefix: str, inputs: Dict[str, object], outputs: Optional[List[str]] = None) -> EvalResult:
|
| 631 |
nonlocal total_elapsed, total_gates
|
| 632 |
+
res = self.evaluate_prefix(prefix, inputs, out_bits=16, outputs=outputs)
|
| 633 |
total_elapsed += res.elapsed_s
|
| 634 |
total_gates += res.gates_evaluated
|
| 635 |
+
return res
|
| 636 |
+
|
| 637 |
+
def run_unary(prefix: str, x_bits: int, fname: str) -> int:
|
| 638 |
+
if f"{prefix}.domain.weight" in self.tensors:
|
| 639 |
+
outs = [f"{prefix}.out{i}" for i in range(16)] + [f"{prefix}.domain"]
|
| 640 |
+
res = run_prefix(prefix, {"x": x_bits}, outputs=outs)
|
| 641 |
+
if res.bits[16] >= 0.5:
|
| 642 |
+
x_val = float16_bits_to_float(x_bits)
|
| 643 |
+
raise RuntimeError(f"domain error: {fname}({x_val})")
|
| 644 |
+
return bits_to_int(res.bits[:16])
|
| 645 |
+
return bits_to_int(run_prefix(prefix, {"x": x_bits}).bits)
|
| 646 |
|
| 647 |
def eval_node(node: ast.AST) -> int:
|
| 648 |
if isinstance(node, ast.Expression):
|
|
|
|
| 667 |
return eval_node(node.operand)
|
| 668 |
if isinstance(node.op, ast.USub):
|
| 669 |
x = eval_node(node.operand)
|
| 670 |
+
return bits_to_int(run_prefix("float16.neg", {"x": x}).bits)
|
| 671 |
raise RuntimeError("unsupported unary operator")
|
| 672 |
if isinstance(node, ast.BinOp):
|
| 673 |
a = eval_node(node.left)
|
| 674 |
b = eval_node(node.right)
|
| 675 |
if isinstance(node.op, ast.Add):
|
| 676 |
+
return bits_to_int(run_prefix("float16.add", {"a": a, "b": b}).bits)
|
| 677 |
if isinstance(node.op, ast.Sub):
|
| 678 |
b_flip = b ^ 0x8000
|
| 679 |
+
return bits_to_int(run_prefix("float16.sub", {"a": a, "b": b_flip}).bits)
|
| 680 |
if isinstance(node.op, ast.Mult):
|
| 681 |
+
return bits_to_int(run_prefix("float16.mul", {"a": a, "b": b}).bits)
|
| 682 |
if isinstance(node.op, ast.Div):
|
| 683 |
+
return bits_to_int(run_prefix("float16.div", {"a": a, "b": b}).bits)
|
| 684 |
if isinstance(node.op, ast.Pow):
|
| 685 |
+
return bits_to_int(run_prefix("float16.pow", {"a": a, "b": b}).bits)
|
| 686 |
raise RuntimeError("unsupported binary operator")
|
| 687 |
if isinstance(node, ast.Call):
|
| 688 |
if not isinstance(node.func, ast.Name):
|
|
|
|
| 692 |
raise RuntimeError(f"{fname} expects one argument")
|
| 693 |
x = eval_node(node.args[0])
|
| 694 |
if fname == "sqrt":
|
| 695 |
+
return run_unary("float16.sqrt", x, fname)
|
| 696 |
if fname == "rsqrt":
|
| 697 |
+
return run_unary("float16.rsqrt", x, fname)
|
| 698 |
if fname == "exp":
|
| 699 |
+
return run_unary("float16.exp", x, fname)
|
| 700 |
if fname in ("ln", "log"):
|
| 701 |
+
return run_unary("float16.ln", x, fname)
|
| 702 |
if fname == "log2":
|
| 703 |
+
return run_unary("float16.log2", x, fname)
|
| 704 |
if fname == "log10":
|
| 705 |
+
return run_unary("float16.log10", x, fname)
|
| 706 |
if fname == "sin":
|
| 707 |
prefix = "float16.sin_deg" if use_degrees else "float16.sin"
|
| 708 |
+
return run_unary(prefix, x, fname)
|
| 709 |
if fname == "cos":
|
| 710 |
prefix = "float16.cos_deg" if use_degrees else "float16.cos"
|
| 711 |
+
return run_unary(prefix, x, fname)
|
| 712 |
if fname == "tan":
|
| 713 |
prefix = "float16.tan_deg" if use_degrees else "float16.tan"
|
| 714 |
+
return run_unary(prefix, x, fname)
|
| 715 |
if fname == "tanh":
|
| 716 |
+
return run_unary("float16.tanh", x, fname)
|
| 717 |
if fname == "asin":
|
| 718 |
prefix = "float16.asin_deg" if use_degrees else "float16.asin"
|
| 719 |
+
return run_unary(prefix, x, fname)
|
| 720 |
if fname == "acos":
|
| 721 |
prefix = "float16.acos_deg" if use_degrees else "float16.acos"
|
| 722 |
+
return run_unary(prefix, x, fname)
|
| 723 |
if fname == "atan":
|
| 724 |
prefix = "float16.atan_deg" if use_degrees else "float16.atan"
|
| 725 |
+
return run_unary(prefix, x, fname)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 726 |
if fname == "sinh":
|
| 727 |
+
return run_unary("float16.sinh", x, fname)
|
| 728 |
if fname == "cosh":
|
| 729 |
+
return run_unary("float16.cosh", x, fname)
|
| 730 |
if fname == "floor":
|
| 731 |
+
return run_unary("float16.floor", x, fname)
|
| 732 |
if fname == "ceil":
|
| 733 |
+
return run_unary("float16.ceil", x, fname)
|
| 734 |
if fname == "round":
|
| 735 |
+
return run_unary("float16.round", x, fname)
|
| 736 |
if fname == "abs":
|
| 737 |
+
return run_unary("float16.abs", x, fname)
|
| 738 |
if fname == "neg":
|
| 739 |
+
return run_unary("float16.neg", x, fname)
|
| 740 |
raise RuntimeError(f"unsupported function: {fname}")
|
| 741 |
raise RuntimeError("unsupported expression")
|
| 742 |
|
|
|
|
| 744 |
if total_gates == 0:
|
| 745 |
if force_gate_eval:
|
| 746 |
# Route constants through float16.add with +0 to ensure gate-level evaluation.
|
| 747 |
+
out_bits = bits_to_int(run_prefix("float16.add", {"a": out_bits, "b": 0}).bits)
|
| 748 |
else:
|
| 749 |
non_gate_events.append("constant_expression_no_gates")
|
| 750 |
return EvalResult(
|
eval.py
CHANGED
|
@@ -716,6 +716,33 @@ def eval_float16_lut_outputs(ctx: EvalContext, op_prefix: str,
|
|
| 716 |
return outputs
|
| 717 |
|
| 718 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 719 |
def build_float16_pairs(rng: random.Random, count: int) -> List[Tuple[int, int]]:
|
| 720 |
"""Build deterministic float16 test pairs using edge cases + random."""
|
| 721 |
edges = [
|
|
@@ -892,6 +919,20 @@ def float16_expected_bits_pow(a_bits: int, b_bits: int) -> Tuple[int, bool]:
|
|
| 892 |
return float_to_int(float(out)), False
|
| 893 |
|
| 894 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 895 |
# =============================================================================
|
| 896 |
# BOOLEAN GATE TESTS
|
| 897 |
# =============================================================================
|
|
@@ -2760,6 +2801,44 @@ def test_float16_unary(ctx: EvalContext) -> List[TestResult]:
|
|
| 2760 |
return results
|
| 2761 |
|
| 2762 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2763 |
def test_float16_pow(ctx: EvalContext) -> List[TestResult]:
|
| 2764 |
"""Test float16.pow (defined as exp(b * ln(a)))."""
|
| 2765 |
results: List[TestResult] = []
|
|
@@ -2855,6 +2934,7 @@ CATEGORIES = {
|
|
| 2855 |
"float16_arith": ("Float16 - Arithmetic", test_float16_arithmetic),
|
| 2856 |
"float16_conv": ("Float16 - Conversion", test_float16_conversion),
|
| 2857 |
"float16_unary": ("Float16 - Unary LUT", test_float16_unary),
|
|
|
|
| 2858 |
"float16_pow": ("Float16 - Pow", test_float16_pow),
|
| 2859 |
}
|
| 2860 |
|
|
|
|
| 716 |
return outputs
|
| 717 |
|
| 718 |
|
| 719 |
+
def eval_float16_lut_flag(ctx: EvalContext, op_prefix: str,
                          bits: List[float],
                          flag: str = "domain",
                          match_prefix: str = "float16.lut") -> float:
    """Evaluate a LUT-backed 1-bit flag using direct LUT indexing.

    Args:
        ctx: Evaluation context holding the tensors and coverage bookkeeping.
        op_prefix: Prefix of the op whose flag gate is read (e.g. "float16.sqrt").
        bits: 16 input bit values (LSB first) selecting the LUT entry.
        flag: Flag gate name under the op prefix.
        match_prefix: Prefix of the shared one-hot match gates.

    Returns:
        1.0 when the thresholded flag gate fires for this input, else 0.0.
    """
    idx = bits_to_int(bits)

    # Mark the shared one-hot match gate tensors for this index as covered.
    match_gate = f"{match_prefix}.match{idx:04x}"
    for suffix in (".weight", ".bias", ".inputs"):
        candidate = match_gate + suffix
        if candidate in ctx.tensors:
            ctx.tested_tensors.add(candidate)

    gate = f"{op_prefix}.{flag}"
    weight_key = f"{gate}.weight"
    bias_key = f"{gate}.bias"
    inputs_key = f"{gate}.inputs"
    ctx.tested_tensors.add(weight_key)
    for optional_key in (bias_key, inputs_key):
        if optional_key in ctx.tensors:
            ctx.tested_tensors.add(optional_key)

    # Direct LUT lookup: the one-hot match input picks out a single weight,
    # so the gate reduces to thresholding weight + bias at zero.
    weight = ctx.tensors[weight_key][idx].item()
    bias = ctx.tensors.get(bias_key, torch.tensor([0.0])).item()
    return 1.0 if (weight + bias) >= 0 else 0.0
|
| 744 |
+
|
| 745 |
+
|
| 746 |
def build_float16_pairs(rng: random.Random, count: int) -> List[Tuple[int, int]]:
|
| 747 |
"""Build deterministic float16 test pairs using edge cases + random."""
|
| 748 |
edges = [
|
|
|
|
| 919 |
return float_to_int(float(out)), False
|
| 920 |
|
| 921 |
|
| 922 |
+
def float16_expected_domain(op: str, a_bits: int) -> int:
    """Compute expected domain flag (1=invalid) for unary ops.

    Args:
        op: Operation name (restricted-domain ops: sqrt/rsqrt, ln/log2/log10,
            asin/acos and their _deg variants).
        a_bits: Raw 16-bit float16 input pattern.

    Returns:
        1 when the decoded input is NaN or outside the op's domain, else 0.
    """
    a = float16_int_to_float(a_bits)
    # a != a detects NaN, which is invalid for every op.
    invalid = (
        a != a
        or (op in ("sqrt", "rsqrt") and a < 0)
        or (op in ("ln", "log2", "log10") and a <= 0)
        or (op in ("asin", "acos", "asin_deg", "acos_deg") and abs(a) > 1.0)
    )
    return int(invalid)
|
| 934 |
+
|
| 935 |
+
|
| 936 |
# =============================================================================
|
| 937 |
# BOOLEAN GATE TESTS
|
| 938 |
# =============================================================================
|
|
|
|
| 2801 |
return results
|
| 2802 |
|
| 2803 |
|
| 2804 |
+
def test_float16_domain_flags(ctx: EvalContext) -> List[TestResult]:
    """Test float16 domain flag outputs.

    For each restricted-domain unary op whose .domain gate exists in the
    checkpoint, compares the LUT-backed flag against the reference
    float16_expected_domain over a deterministic sample of inputs.
    """
    results: List[TestResult] = []
    rng = random.Random(1337)
    values = build_float16_values(rng, 256)
    ops = [
        ("float16.sqrt", "sqrt"),
        ("float16.rsqrt", "rsqrt"),
        ("float16.ln", "ln"),
        ("float16.log2", "log2"),
        ("float16.log10", "log10"),
        ("float16.asin", "asin"),
        ("float16.acos", "acos"),
        ("float16.asin_deg", "asin_deg"),
        ("float16.acos_deg", "acos_deg"),
    ]
    for prefix, op in ops:
        # Skip ops whose domain-flag gate is absent from this checkpoint.
        if f"{prefix}.domain.weight" not in ctx.tensors:
            continue
        passed = 0
        total = 0
        failures: List[Dict[str, Any]] = []
        for a_bits in values:
            # Unpack the 16-bit pattern LSB-first into float bit values.
            bit_vector = [float((a_bits >> shift) & 1) for shift in range(16)]
            actual = int(eval_float16_lut_flag(ctx, prefix, bit_vector))
            expected = float16_expected_domain(op, a_bits)
            total += 1
            if actual == expected:
                passed += 1
            elif len(failures) < 8:
                failures.append({
                    "input": hex(a_bits),
                    "actual": actual,
                    "expected": expected,
                })
        results.append(TestResult(f"{prefix}.domain", passed, total, failures))
    return results
|
| 2840 |
+
|
| 2841 |
+
|
| 2842 |
def test_float16_pow(ctx: EvalContext) -> List[TestResult]:
|
| 2843 |
"""Test float16.pow (defined as exp(b * ln(a)))."""
|
| 2844 |
results: List[TestResult] = []
|
|
|
|
| 2934 |
"float16_arith": ("Float16 - Arithmetic", test_float16_arithmetic),
|
| 2935 |
"float16_conv": ("Float16 - Conversion", test_float16_conversion),
|
| 2936 |
"float16_unary": ("Float16 - Unary LUT", test_float16_unary),
|
| 2937 |
+
"float16_domain": ("Float16 - Domain Flags", test_float16_domain_flags),
|
| 2938 |
"float16_pow": ("Float16 - Pow", test_float16_pow),
|
| 2939 |
}
|
| 2940 |
|