CharlesCNorton committed on
Commit
313da7e
·
1 Parent(s): ad82d36

Add float16 constant circuits

Browse files
Files changed (4) hide show
  1. arithmetic.safetensors +2 -2
  2. build.py +28 -0
  3. calculator.py +24 -4
  4. eval.py +27 -0
arithmetic.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a94399cf37d56ba839e1dbae35510cfa6eccebc6e413e5200c8df54d707552f0
3
- size 443825664
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d97b4a0b0f1f8d8c858c46e3077140092e54fb7648021aa757d08284e920e2a0
3
+ size 443848032
build.py CHANGED
@@ -69,6 +69,21 @@ def float16_float_to_bits(val: float) -> int:
69
  return 0x7BFF if val > 0 else 0xFBFF
70
 
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  def compute_float16_unary_lut_outputs(op_fn: Callable[[torch.Tensor], torch.Tensor]) -> List[int]:
73
  """Compute output bits for all 65536 float16 inputs using a unary op."""
74
  outputs: List[int] = [0] * 65536
@@ -1353,6 +1368,8 @@ def infer_inputs_for_gate(gate: str, registry: SignalRegistry, routing: dict) ->
1353
 
1354
  # Float16 circuits
1355
  if gate.startswith('float16.'):
 
 
1356
  if gate.endswith('.domain_not'):
1357
  prefix = gate[:-len('.domain_not')]
1358
  registry.register(f"{prefix}.domain")
@@ -11260,6 +11277,17 @@ def main():
11260
  tensors.update(fromint_tensors)
11261
  print(f" float16.fromint: {len(fromint_tensors)} tensors")
11262
 
 
 
 
 
 
 
 
 
 
 
 
11263
  # Shared LUT match gates
11264
  lut_match_tensors = build_float16_lut_match_tensors("float16.lut")
11265
  tensors.update(lut_match_tensors)
 
69
  return 0x7BFF if val > 0 else 0xFBFF
70
 
71
 
72
def build_float16_const_tensors(prefix: str, value: float) -> Dict[str, torch.Tensor]:
    """Build constant float16 outputs (prefix.out0..out15) using #1 input.

    Each output bit is realized as a one-input threshold unit driven by the
    constant-1 signal: set bits get weight 1.0 / bias -0.5 (net 0.5, fires),
    clear bits get weight -1.0 / bias 0.0 (net -1.0, stays off).
    """
    bits = float16_float_to_bits(value)
    tensors: Dict[str, torch.Tensor] = {}
    for idx in range(16):
        # Choose the (weight, bias) pair according to bit idx of the constant.
        if (bits >> idx) & 1:
            weight, bias = 1.0, -0.5
        else:
            weight, bias = -1.0, 0.0
        tensors[f"{prefix}.out{idx}.weight"] = torch.tensor([weight])
        tensors[f"{prefix}.out{idx}.bias"] = torch.tensor([bias])
    return tensors
85
+
86
+
87
  def compute_float16_unary_lut_outputs(op_fn: Callable[[torch.Tensor], torch.Tensor]) -> List[int]:
88
  """Compute output bits for all 65536 float16 inputs using a unary op."""
89
  outputs: List[int] = [0] * 65536
 
1368
 
1369
  # Float16 circuits
1370
  if gate.startswith('float16.'):
1371
+ if gate.startswith('float16.const_'):
1372
+ return [registry.get_id("#1")]
1373
  if gate.endswith('.domain_not'):
1374
  prefix = gate[:-len('.domain_not')]
1375
  registry.register(f"{prefix}.domain")
 
11277
  tensors.update(fromint_tensors)
11278
  print(f" float16.fromint: {len(fromint_tensors)} tensors")
11279
 
11280
+ const_map = {
11281
+ "pi": math.pi,
11282
+ "e": math.e,
11283
+ "deg2rad": math.pi / 180.0,
11284
+ "rad2deg": 180.0 / math.pi,
11285
+ }
11286
+ for name, value in const_map.items():
11287
+ const_tensors = build_float16_const_tensors(f"float16.const_{name}", value)
11288
+ tensors.update(const_tensors)
11289
+ print(f" float16.const_{name}: {len(const_tensors)} tensors")
11290
+
11291
  # Shared LUT match gates
11292
  lut_match_tensors = build_float16_lut_match_tensors("float16.lut")
11293
  tensors.update(lut_match_tensors)
calculator.py CHANGED
@@ -159,6 +159,7 @@ class ThresholdCalculator:
159
  self._id_to_gate: Dict[int, str] = {}
160
  self._topo_cache: Dict[Tuple[str, Tuple[str, ...]], List[str]] = {}
161
  self._compiled: Dict[Tuple[str, Tuple[str, ...]], CompiledCircuit] = {}
 
162
  self._load()
163
 
164
  def _load(self) -> None:
@@ -493,6 +494,17 @@ class ThresholdCalculator:
493
  out_int = bits_to_int(result.bits)
494
  return float16_bits_to_float(out_int), result
495
 
 
 
 
 
 
 
 
 
 
 
 
496
  def evaluate_rpn(
497
  self,
498
  tokens: Sequence[str],
@@ -525,9 +537,13 @@ class ThresholdCalculator:
525
 
526
  def const_to_bits(tok: str) -> int:
527
  if tok == "pi":
528
- return float_to_float16_bits(math.pi)
529
  if tok == "e":
530
- return float_to_float16_bits(math.e)
 
 
 
 
531
  if tok == "inf":
532
  return float_to_float16_bits(float("inf"))
533
  if tok == "nan":
@@ -671,9 +687,13 @@ class ThresholdCalculator:
671
  if isinstance(node, ast.Name):
672
  name = node.id
673
  if name == "pi":
674
- return float_to_float16_bits(math.pi)
675
  if name == "e":
676
- return float_to_float16_bits(math.e)
 
 
 
 
677
  if name == "inf":
678
  return float_to_float16_bits(float("inf"))
679
  if name == "nan":
 
159
  self._id_to_gate: Dict[int, str] = {}
160
  self._topo_cache: Dict[Tuple[str, Tuple[str, ...]], List[str]] = {}
161
  self._compiled: Dict[Tuple[str, Tuple[str, ...]], CompiledCircuit] = {}
162
+ self._const_cache: Dict[str, int] = {}
163
  self._load()
164
 
165
  def _load(self) -> None:
 
494
  out_int = bits_to_int(result.bits)
495
  return float16_bits_to_float(out_int), result
496
 
497
def _const_bits(self, name: str, fallback: float) -> int:
    """Return the float16 bit pattern for named constant *name*, memoized.

    Prefers the dedicated constant circuit ``float16.const_<name>`` when its
    tensors are loaded; otherwise converts *fallback* to float16 bits on the
    host. The resolved value is cached in ``self._const_cache``.
    """
    cached = self._const_cache.get(name)
    if cached is not None:
        return cached
    prefix = f"float16.const_{name}"
    if f"{prefix}.out0.weight" in self.tensors:
        # Evaluate the 16-bit constant circuit (no inputs beyond the #1 line).
        result = self.evaluate_prefix(prefix, {}, out_bits=16)
        bits = bits_to_int(result.bits)
    else:
        # Circuit not present in this tensor set — fall back to host math.
        bits = float_to_float16_bits(fallback)
    self._const_cache[name] = bits
    return bits
507
+
508
  def evaluate_rpn(
509
  self,
510
  tokens: Sequence[str],
 
537
 
538
  def const_to_bits(tok: str) -> int:
539
  if tok == "pi":
540
+ return self._const_bits("pi", math.pi)
541
  if tok == "e":
542
+ return self._const_bits("e", math.e)
543
+ if tok == "deg2rad":
544
+ return self._const_bits("deg2rad", math.pi / 180.0)
545
+ if tok == "rad2deg":
546
+ return self._const_bits("rad2deg", 180.0 / math.pi)
547
  if tok == "inf":
548
  return float_to_float16_bits(float("inf"))
549
  if tok == "nan":
 
687
  if isinstance(node, ast.Name):
688
  name = node.id
689
  if name == "pi":
690
+ return self._const_bits("pi", math.pi)
691
  if name == "e":
692
+ return self._const_bits("e", math.e)
693
+ if name == "deg2rad":
694
+ return self._const_bits("deg2rad", math.pi / 180.0)
695
+ if name == "rad2deg":
696
+ return self._const_bits("rad2deg", 180.0 / math.pi)
697
  if name == "inf":
698
  return float_to_float16_bits(float("inf"))
699
  if name == "nan":
eval.py CHANGED
@@ -936,6 +936,32 @@ def float16_expected_domain(op: str, a_bits: int) -> int:
936
  return 0
937
 
938
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
939
  # =============================================================================
940
  # BOOLEAN GATE TESTS
941
  # =============================================================================
@@ -2994,6 +3020,7 @@ CATEGORIES = {
2994
  "float16_arith": ("Float16 - Arithmetic", test_float16_arithmetic),
2995
  "float16_conv": ("Float16 - Conversion", test_float16_conversion),
2996
  "float16_unary": ("Float16 - Unary LUT", test_float16_unary),
 
2997
  "float16_domain": ("Float16 - Domain Flags", test_float16_domain_flags),
2998
  "float16_checked": ("Float16 - Checked Outputs", test_float16_checked_outputs),
2999
  "float16_pow": ("Float16 - Pow", test_float16_pow),
 
936
  return 0
937
 
938
 
939
def test_float16_constants(ctx: EvalContext) -> List[TestResult]:
    """Test float16 constant-output circuits."""
    results: List[TestResult] = []
    consts = (
        ("float16.const_pi", math.pi),
        ("float16.const_e", math.e),
        ("float16.const_deg2rad", math.pi / 180.0),
        ("float16.const_rad2deg", 180.0 / math.pi),
    )
    for prefix, value in consts:
        # Skip constants whose circuit tensors are absent from this build.
        if f"{prefix}.out0.weight" not in ctx.tensors:
            continue
        expected = float_to_int(value)
        actual = bits_to_int(eval_prefix_outputs(ctx, prefix, {}))
        passed = int(actual == expected)
        if passed:
            failures = []
        else:
            failures = [{
                "expected": hex(expected),
                "actual": hex(actual),
            }]
        results.append(TestResult(prefix, passed, 1, failures))
    return results
963
+
964
+
965
  # =============================================================================
966
  # BOOLEAN GATE TESTS
967
  # =============================================================================
 
3020
  "float16_arith": ("Float16 - Arithmetic", test_float16_arithmetic),
3021
  "float16_conv": ("Float16 - Conversion", test_float16_conversion),
3022
  "float16_unary": ("Float16 - Unary LUT", test_float16_unary),
3023
+ "float16_constants": ("Float16 - Constants", test_float16_constants),
3024
  "float16_domain": ("Float16 - Domain Flags", test_float16_domain_flags),
3025
  "float16_checked": ("Float16 - Checked Outputs", test_float16_checked_outputs),
3026
  "float16_pow": ("Float16 - Pow", test_float16_pow),