PortfolioAI committed on
Commit
689739c
·
1 Parent(s): b8ed073

Add float16.normalize, neg, abs circuits

Browse files

- float16.normalize: CLZ-based shift calculator (51 gates)
- float16.neg: sign flip (16 gates)
- float16.abs: clear sign bit (16 gates)
- All 100% pass rate

TODO.md CHANGED
@@ -3,22 +3,22 @@
3
  ## High Priority
4
 
5
  ### Floating Point Circuits
6
- - [x] `float16.unpack` -- extract sign, exponent, mantissa from IEEE 754 half-precision
7
- - [x] `float16.pack` -- assemble from components
8
- - [ ] `float16.normalize` -- normalize after arithmetic
9
- - [ ] `float16.add` -- 16-bit IEEE 754 addition
10
- - [ ] `float16.sub` -- subtraction
 
11
  - [ ] `float16.mul` -- multiplication
12
  - [ ] `float16.div` -- division
13
- - [x] `float16.cmp` -- comparison (>)
14
- - [ ] `float16.neg` -- negation
15
- - [ ] `float16.abs` -- absolute value
16
  - [ ] `float16.toint` -- convert to integer
17
  - [ ] `float16.fromint` -- convert from integer
18
 
19
  ### Supporting Infrastructure
20
- - [x] `arithmetic.clz8bit` -- count leading zeros (needed for float normalization)
21
- - [x] `arithmetic.clz16bit` -- 16-bit count leading zeros
22
 
23
  ## Medium Priority
24
 
 
3
  ## High Priority
4
 
5
  ### Floating Point Circuits
6
+ - [x] `float16.unpack` -- extract sign, exponent, mantissa (16 gates, 63/63 tests)
7
+ - [x] `float16.pack` -- assemble from components (16 gates, 63/63 tests)
8
+ - [x] `float16.cmp` -- comparison a > b (14 gates, 113/113 tests)
9
+ - [x] `float16.normalize` -- CLZ-based shift calculator (51 gates, 14/14 tests)
10
+ - [ ] `float16.add` -- IEEE 754 addition (requires normalize + align + add)
11
+ - [ ] `float16.sub` -- subtraction (add with negated operand)
12
  - [ ] `float16.mul` -- multiplication
13
  - [ ] `float16.div` -- division
14
+ - [x] `float16.neg` -- sign flip (16 gates, 58/58 tests)
15
+ - [x] `float16.abs` -- clear sign bit (16 gates, 58/58 tests)
 
16
  - [ ] `float16.toint` -- convert to integer
17
  - [ ] `float16.fromint` -- convert from integer
18
 
19
  ### Supporting Infrastructure
20
+ - [x] `arithmetic.clz8bit` -- 8-bit count leading zeros (30 gates, 256/256 tests)
21
+ - [x] `arithmetic.clz16bit` -- 16-bit count leading zeros (63 gates, 217/217 tests)
22
 
23
  ## Medium Priority
24
 
__pycache__/eval.cpython-311.pyc ADDED
Binary file (41.8 kB). View file
 
arithmetic.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ebe8e155f964f27d26a8a35750f6af361556a65c1178a1c96e4dd5eea95a66c4
3
- size 1111188
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b16619fd1cda08ab7c9ccf567ef77f8001ff7b6f76b8ed6852ad262fbc8d139
3
+ size 1140364
arithmetic_legacy.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b53234c708c9f134e154f7e8ddbc251ea9a89e087fc34693c69963f3e21a6be0
3
+ size 575300
convert_to_explicit_inputs.py CHANGED
@@ -1052,11 +1052,53 @@ def infer_inputs_for_gate(gate: str, registry: SignalRegistry, routing: dict) ->
1052
  return infer_float16_cmp_inputs(gate, registry)
1053
  if 'normalize' in gate:
1054
  return infer_float16_normalize_inputs(gate, registry)
 
 
 
 
1055
 
1056
  # Default: couldn't infer, return empty (will need manual fix or routing)
1057
  return []
1058
 
1059
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1060
  def infer_float16_normalize_inputs(gate: str, registry: SignalRegistry) -> List[int]:
1061
  """Infer inputs for float16.normalize circuit."""
1062
  prefix = "float16.normalize"
@@ -1106,7 +1148,7 @@ def infer_float16_normalize_inputs(gate: str, registry: SignalRegistry) -> List[
1106
  k = int(match.group(1))
1107
  return [registry.get_id(f"{prefix}.ge{k}")]
1108
 
1109
- for k in [2, 4, 8]:
1110
  registry.register(f"{prefix}.not_ge{k}")
1111
 
1112
  # AND gates for ranges
@@ -1117,8 +1159,7 @@ def infer_float16_normalize_inputs(gate: str, registry: SignalRegistry) -> List[
1117
  if '.and_6_7' in gate:
1118
  return [registry.get_id(f"{prefix}.ge6"), registry.get_id(f"{prefix}.not_ge8")]
1119
  if '.and_10_11' in gate:
1120
- return [registry.get_id(f"{prefix}.ge10"), registry.get_id(f"{prefix}.ge12")]
1121
- # Note: and_10_11 should be ge10 AND NOT ge12, but we don't have not_ge12
1122
 
1123
  # Odd AND gates
1124
  match = re.search(r'\.and_(\d+)$', gate)
@@ -1330,6 +1371,49 @@ def infer_float16_unpack_inputs(gate: str, registry: SignalRegistry) -> List[int
1330
  return []
1331
 
1332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1333
  def build_float16_normalize_tensors() -> Dict[str, torch.Tensor]:
1334
  """Build tensors for float16.normalize circuit.
1335
 
@@ -1733,6 +1817,18 @@ def main():
1733
  tensors.update(cmp_tensors)
1734
  print(f" float16.cmp: {len(cmp_tensors)} tensors")
1735
 
 
 
 
 
 
 
 
 
 
 
 
 
1736
  print(f"Total tensors: {len(tensors)}")
1737
 
1738
  # Load routing for complex circuits
 
1052
  return infer_float16_cmp_inputs(gate, registry)
1053
  if 'normalize' in gate:
1054
  return infer_float16_normalize_inputs(gate, registry)
1055
+ if gate.startswith('float16.neg'):
1056
+ return infer_float16_neg_inputs(gate, registry)
1057
+ if gate.startswith('float16.abs'):
1058
+ return infer_float16_abs_inputs(gate, registry)
1059
 
1060
  # Default: couldn't infer, return empty (will need manual fix or routing)
1061
  return []
1062
 
1063
 
1064
def infer_float16_neg_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for float16.neg circuit.

    Registers the 16 input-bit signals, then maps each ``.outN`` gate to
    its corresponding input bit: every output reads exactly one input.
    """
    prefix = "float16.neg"

    # Ensure every input-bit signal exists in the registry first.
    for bit in range(16):
        registry.register(f"{prefix}.$x[{bit}]")

    # Output gate outN depends only on input bit N.
    m = re.search(r'\.out(\d+)', gate)
    if m is None:
        return []
    return [registry.get_id(f"{prefix}.$x[{int(m.group(1))}]")]
1079
+
1080
+
1081
def infer_float16_abs_inputs(gate: str, registry: SignalRegistry) -> List[int]:
    """Infer inputs for float16.abs circuit.

    Registers the 16 input-bit signals and wires each ``.outN`` gate to
    input bit N.  The sign output (out15) is forced to 0 by its weights,
    but structurally it still reads $x[15] like every other bit.
    """
    prefix = "float16.abs"

    # Ensure every input-bit signal exists in the registry first.
    for bit in range(16):
        registry.register(f"{prefix}.$x[{bit}]")

    m = re.search(r'\.out(\d+)', gate)
    if m is None:
        return []
    # Bit 15's gate clears the sign via its bias, yet takes $x[15] as its
    # input so the gate structure stays uniform across all 16 outputs.
    return [registry.get_id(f"{prefix}.$x[{int(m.group(1))}]")]
1100
+
1101
+
1102
  def infer_float16_normalize_inputs(gate: str, registry: SignalRegistry) -> List[int]:
1103
  """Infer inputs for float16.normalize circuit."""
1104
  prefix = "float16.normalize"
 
1148
  k = int(match.group(1))
1149
  return [registry.get_id(f"{prefix}.ge{k}")]
1150
 
1151
+ for k in [2, 4, 6, 8, 10, 12]:
1152
  registry.register(f"{prefix}.not_ge{k}")
1153
 
1154
  # AND gates for ranges
 
1159
  if '.and_6_7' in gate:
1160
  return [registry.get_id(f"{prefix}.ge6"), registry.get_id(f"{prefix}.not_ge8")]
1161
  if '.and_10_11' in gate:
1162
+ return [registry.get_id(f"{prefix}.ge10"), registry.get_id(f"{prefix}.not_ge12")]
 
1163
 
1164
  # Odd AND gates
1165
  match = re.search(r'\.and_(\d+)$', gate)
 
1371
  return []
1372
 
1373
 
1374
def build_float16_neg_tensors() -> Dict[str, torch.Tensor]:
    """Build tensors for float16.neg circuit.

    IEEE 754 negation toggles the sign bit (bit 15); all other bits pass
    through unchanged.  Returns one weight and one bias tensor per output
    gate (32 tensors total).
    """
    prefix = "float16.neg"
    tensors: Dict[str, torch.Tensor] = {}

    # Sign bit: inverter on the input sign (weight -1, bias 0).
    tensors[f"{prefix}.out15.weight"] = torch.tensor([-1.0])
    tensors[f"{prefix}.out15.bias"] = torch.tensor([0.0])

    # Bits 0..14: identity gates -- fire exactly when the input bit is 1.
    for bit in range(15):
        tensors[f"{prefix}.out{bit}.weight"] = torch.tensor([1.0])
        tensors[f"{prefix}.out{bit}.bias"] = torch.tensor([-0.5])

    return tensors
1393
+
1394
+
1395
def build_float16_abs_tensors() -> Dict[str, torch.Tensor]:
    """Build tensors for float16.abs circuit.

    Absolute value clears the sign bit (bit 15) and passes the remaining
    15 bits through unchanged.  Returns one weight and one bias tensor
    per output gate (32 tensors total).
    """
    prefix = "float16.abs"
    tensors: Dict[str, torch.Tensor] = {}

    # Sign bit: constant 0 -- bias -2.0 keeps the gate from ever firing,
    # whatever the input sign is.
    tensors[f"{prefix}.out15.weight"] = torch.tensor([1.0])
    tensors[f"{prefix}.out15.bias"] = torch.tensor([-2.0])

    # Bits 0..14: identity gates -- fire exactly when the input bit is 1.
    for bit in range(15):
        tensors[f"{prefix}.out{bit}.weight"] = torch.tensor([1.0])
        tensors[f"{prefix}.out{bit}.bias"] = torch.tensor([-0.5])

    return tensors
1415
+
1416
+
1417
  def build_float16_normalize_tensors() -> Dict[str, torch.Tensor]:
1418
  """Build tensors for float16.normalize circuit.
1419
 
 
1817
  tensors.update(cmp_tensors)
1818
  print(f" float16.cmp: {len(cmp_tensors)} tensors")
1819
 
1820
+ norm_tensors = build_float16_normalize_tensors()
1821
+ tensors.update(norm_tensors)
1822
+ print(f" float16.normalize: {len(norm_tensors)} tensors")
1823
+
1824
+ neg_tensors = build_float16_neg_tensors()
1825
+ tensors.update(neg_tensors)
1826
+ print(f" float16.neg: {len(neg_tensors)} tensors")
1827
+
1828
+ abs_tensors = build_float16_abs_tensors()
1829
+ tensors.update(abs_tensors)
1830
+ print(f" float16.abs: {len(abs_tensors)} tensors")
1831
+
1832
  print(f"Total tensors: {len(tensors)}")
1833
 
1834
  # Load routing for complex circuits
eval.py CHANGED
@@ -513,6 +513,125 @@ class CircuitEvaluator:
513
 
514
  return TestResult('float16.cmp', passed, len(test_cases), failures)
515
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
516
  # =========================================================================
517
  # ARITHMETIC TESTS (DIRECT EVALUATION)
518
  # =========================================================================
@@ -693,6 +812,21 @@ class Evaluator:
693
  self.results.append(result)
694
  if verbose:
695
  self._print_result(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
696
 
697
  # Comparators
698
  if verbose:
 
513
 
514
  return TestResult('float16.cmp', passed, len(test_cases), failures)
515
 
516
def test_float16_normalize(self) -> TestResult:
    """Test float16.normalize shift amount calculation.

    The circuit takes a 13-bit mantissa (bit 12 is the overflow bit) and
    outputs a 4-bit left-shift amount equal to the count of leading zeros
    over bits 11..0.  When the overflow bit is set, the shift output is
    masked to 0 and the ``overflow`` signal must fire.

    Fix: removed the unused ``is_zero`` read -- it was fetched from the
    evaluated values but never checked, so it only obscured the test.
    """
    prefix = 'float16.normalize'
    failures = []
    passed = 0

    # (13-bit mantissa, expected CLZ of bits 11:0); all-zero saturates at 12.
    test_cases = [
        (0b1_000000000000, 0),   # overflow bit set -> shift masked to 0
        (0b0_100000000000, 0),   # bit 11 set -> CLZ = 0
        (0b0_010000000000, 1),   # bit 10 set -> CLZ = 1
        (0b0_001000000000, 2),
        (0b0_000100000000, 3),
        (0b0_000010000000, 4),
        (0b0_000001000000, 5),
        (0b0_000000100000, 6),
        (0b0_000000010000, 7),
        (0b0_000000001000, 8),
        (0b0_000000000100, 9),
        (0b0_000000000010, 10),
        (0b0_000000000001, 11),
        (0b0_000000000000, 12),  # all zeros -> max shift
    ]

    for mant, expected_shift in test_cases:
        overflow = (mant >> 12) & 1

        # Drive the 13 mantissa input bits.
        ext = {f'{prefix}.$m[{i}]': float((mant >> i) & 1) for i in range(13)}
        values = self.eval_circuit(prefix, ext)

        # Assemble the 4-bit shift output (already masked by not_overflow).
        shift = 0
        for i in range(4):
            shift |= int(values.get(f'{prefix}.out_shift{i}', 0)) << i

        got_overflow = int(values.get(f'{prefix}.overflow', 0))

        # When overflow fires, the masked shift output must read 0.
        expected_out = 0 if overflow else expected_shift

        if shift == expected_out and got_overflow == overflow:
            passed += 1
        elif len(failures) < 10:
            failures.append((mant, expected_shift, shift, overflow, got_overflow))

    return TestResult('float16.normalize', passed, len(test_cases), failures)
574
+
575
def test_float16_neg(self) -> TestResult:
    """Test float16.neg: negation must toggle only the sign bit."""
    import random

    prefix = 'float16.neg'
    failures = []
    passed = 0

    # Edge patterns: +/-0, +/-1.0, 2.0, +/-inf, largest finite half.
    test_values = [0x0000, 0x8000, 0x3C00, 0xBC00, 0x4000, 0x7C00, 0xFC00, 0x7BFF]

    # Deterministic random sample of the full 16-bit space.
    random.seed(42)
    test_values += [random.randint(0, 0xFFFF) for _ in range(50)]

    for val in test_values:
        expected = val ^ 0x8000  # negation == flip bit 15

        ext = {f'{prefix}.$x[{i}]': float((val >> i) & 1) for i in range(16)}
        values = self.eval_circuit(prefix, ext)

        result = sum(int(values.get(f'{prefix}.out{i}', 0)) << i for i in range(16))

        if result == expected:
            passed += 1
        elif len(failures) < 10:
            failures.append((val, expected, result))

    return TestResult('float16.neg', passed, len(test_values), failures)
604
+
605
def test_float16_abs(self) -> TestResult:
    """Test float16.abs: absolute value must clear only the sign bit."""
    import random

    prefix = 'float16.abs'
    failures = []
    passed = 0

    # Edge patterns: +/-0, +/-1.0, 2.0, +/-inf, largest finite half.
    test_values = [0x0000, 0x8000, 0x3C00, 0xBC00, 0x4000, 0x7C00, 0xFC00, 0x7BFF]

    # Deterministic random sample of the full 16-bit space.
    random.seed(42)
    test_values += [random.randint(0, 0xFFFF) for _ in range(50)]

    for val in test_values:
        expected = val & 0x7FFF  # abs == clear bit 15

        ext = {f'{prefix}.$x[{i}]': float((val >> i) & 1) for i in range(16)}
        values = self.eval_circuit(prefix, ext)

        result = sum(int(values.get(f'{prefix}.out{i}', 0)) << i for i in range(16))

        if result == expected:
            passed += 1
        elif len(failures) < 10:
            failures.append((val, expected, result))

    return TestResult('float16.abs', passed, len(test_values), failures)
634
+
635
  # =========================================================================
636
  # ARITHMETIC TESTS (DIRECT EVALUATION)
637
  # =========================================================================
 
812
  self.results.append(result)
813
  if verbose:
814
  self._print_result(result)
815
+ if 'float16.normalize.overflow.weight' in self.eval.tensors:
816
+ result = self.eval.test_float16_normalize()
817
+ self.results.append(result)
818
+ if verbose:
819
+ self._print_result(result)
820
+ if 'float16.neg.out0.weight' in self.eval.tensors:
821
+ result = self.eval.test_float16_neg()
822
+ self.results.append(result)
823
+ if verbose:
824
+ self._print_result(result)
825
+ if 'float16.abs.out0.weight' in self.eval.tensors:
826
+ result = self.eval.test_float16_abs()
827
+ self.results.append(result)
828
+ if verbose:
829
+ self._print_result(result)
830
 
831
  # Comparators
832
  if verbose: