PortfolioAI commited on
Commit ·
689739c
1
Parent(s): b8ed073
Add float16.normalize, neg, abs circuits
Browse files- float16.normalize: CLZ-based shift calculator (51 gates)
- float16.neg: sign flip (16 gates)
- float16.abs: clear sign bit (16 gates)
- All 100% pass rate
- TODO.md +10 -10
- __pycache__/eval.cpython-311.pyc +0 -0
- arithmetic.safetensors +2 -2
- arithmetic_legacy.safetensors +3 -0
- convert_to_explicit_inputs.py +99 -3
- eval.py +134 -0
TODO.md
CHANGED
|
@@ -3,22 +3,22 @@
|
|
| 3 |
## High Priority
|
| 4 |
|
| 5 |
### Floating Point Circuits
|
| 6 |
-
- [x] `float16.unpack` -- extract sign, exponent, mantissa
|
| 7 |
-
- [x] `float16.pack` -- assemble from components
|
| 8 |
-
- [
|
| 9 |
-
- [
|
| 10 |
-
- [ ] `float16.
|
|
|
|
| 11 |
- [ ] `float16.mul` -- multiplication
|
| 12 |
- [ ] `float16.div` -- division
|
| 13 |
-
- [x] `float16.
|
| 14 |
-
- [
|
| 15 |
-
- [ ] `float16.abs` -- absolute value
|
| 16 |
- [ ] `float16.toint` -- convert to integer
|
| 17 |
- [ ] `float16.fromint` -- convert from integer
|
| 18 |
|
| 19 |
### Supporting Infrastructure
|
| 20 |
-
- [x] `arithmetic.clz8bit` -- count leading zeros (
|
| 21 |
-
- [x] `arithmetic.clz16bit` -- 16-bit count leading zeros
|
| 22 |
|
| 23 |
## Medium Priority
|
| 24 |
|
|
|
|
| 3 |
## High Priority
|
| 4 |
|
| 5 |
### Floating Point Circuits
|
| 6 |
+
- [x] `float16.unpack` -- extract sign, exponent, mantissa (16 gates, 63/63 tests)
|
| 7 |
+
- [x] `float16.pack` -- assemble from components (16 gates, 63/63 tests)
|
| 8 |
+
- [x] `float16.cmp` -- comparison a > b (14 gates, 113/113 tests)
|
| 9 |
+
- [x] `float16.normalize` -- CLZ-based shift calculator (51 gates, 14/14 tests)
|
| 10 |
+
- [ ] `float16.add` -- IEEE 754 addition (requires normalize + align + add)
|
| 11 |
+
- [ ] `float16.sub` -- subtraction (add with negated operand)
|
| 12 |
- [ ] `float16.mul` -- multiplication
|
| 13 |
- [ ] `float16.div` -- division
|
| 14 |
+
- [x] `float16.neg` -- sign flip (16 gates, 58/58 tests)
|
| 15 |
+
- [x] `float16.abs` -- clear sign bit (16 gates, 58/58 tests)
|
|
|
|
| 16 |
- [ ] `float16.toint` -- convert to integer
|
| 17 |
- [ ] `float16.fromint` -- convert from integer
|
| 18 |
|
| 19 |
### Supporting Infrastructure
|
| 20 |
+
- [x] `arithmetic.clz8bit` -- 8-bit count leading zeros (30 gates, 256/256 tests)
|
| 21 |
+
- [x] `arithmetic.clz16bit` -- 16-bit count leading zeros (63 gates, 217/217 tests)
|
| 22 |
|
| 23 |
## Medium Priority
|
| 24 |
|
__pycache__/eval.cpython-311.pyc
ADDED
|
Binary file (41.8 kB). View file
|
|
|
arithmetic.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2b16619fd1cda08ab7c9ccf567ef77f8001ff7b6f76b8ed6852ad262fbc8d139
|
| 3 |
+
size 1140364
|
arithmetic_legacy.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b53234c708c9f134e154f7e8ddbc251ea9a89e087fc34693c69963f3e21a6be0
|
| 3 |
+
size 575300
|
convert_to_explicit_inputs.py
CHANGED
|
@@ -1052,11 +1052,53 @@ def infer_inputs_for_gate(gate: str, registry: SignalRegistry, routing: dict) ->
|
|
| 1052 |
return infer_float16_cmp_inputs(gate, registry)
|
| 1053 |
if 'normalize' in gate:
|
| 1054 |
return infer_float16_normalize_inputs(gate, registry)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
|
| 1056 |
# Default: couldn't infer, return empty (will need manual fix or routing)
|
| 1057 |
return []
|
| 1058 |
|
| 1059 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1060 |
def infer_float16_normalize_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
| 1061 |
"""Infer inputs for float16.normalize circuit."""
|
| 1062 |
prefix = "float16.normalize"
|
|
@@ -1106,7 +1148,7 @@ def infer_float16_normalize_inputs(gate: str, registry: SignalRegistry) -> List[
|
|
| 1106 |
k = int(match.group(1))
|
| 1107 |
return [registry.get_id(f"{prefix}.ge{k}")]
|
| 1108 |
|
| 1109 |
-
for k in [2, 4, 8]:
|
| 1110 |
registry.register(f"{prefix}.not_ge{k}")
|
| 1111 |
|
| 1112 |
# AND gates for ranges
|
|
@@ -1117,8 +1159,7 @@ def infer_float16_normalize_inputs(gate: str, registry: SignalRegistry) -> List[
|
|
| 1117 |
if '.and_6_7' in gate:
|
| 1118 |
return [registry.get_id(f"{prefix}.ge6"), registry.get_id(f"{prefix}.not_ge8")]
|
| 1119 |
if '.and_10_11' in gate:
|
| 1120 |
-
return [registry.get_id(f"{prefix}.ge10"), registry.get_id(f"{prefix}.
|
| 1121 |
-
# Note: and_10_11 should be ge10 AND NOT ge12, but we don't have not_ge12
|
| 1122 |
|
| 1123 |
# Odd AND gates
|
| 1124 |
match = re.search(r'\.and_(\d+)$', gate)
|
|
@@ -1330,6 +1371,49 @@ def infer_float16_unpack_inputs(gate: str, registry: SignalRegistry) -> List[int
|
|
| 1330 |
return []
|
| 1331 |
|
| 1332 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1333 |
def build_float16_normalize_tensors() -> Dict[str, torch.Tensor]:
|
| 1334 |
"""Build tensors for float16.normalize circuit.
|
| 1335 |
|
|
@@ -1733,6 +1817,18 @@ def main():
|
|
| 1733 |
tensors.update(cmp_tensors)
|
| 1734 |
print(f" float16.cmp: {len(cmp_tensors)} tensors")
|
| 1735 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1736 |
print(f"Total tensors: {len(tensors)}")
|
| 1737 |
|
| 1738 |
# Load routing for complex circuits
|
|
|
|
| 1052 |
return infer_float16_cmp_inputs(gate, registry)
|
| 1053 |
if 'normalize' in gate:
|
| 1054 |
return infer_float16_normalize_inputs(gate, registry)
|
| 1055 |
+
if gate.startswith('float16.neg'):
|
| 1056 |
+
return infer_float16_neg_inputs(gate, registry)
|
| 1057 |
+
if gate.startswith('float16.abs'):
|
| 1058 |
+
return infer_float16_abs_inputs(gate, registry)
|
| 1059 |
|
| 1060 |
# Default: couldn't infer, return empty (will need manual fix or routing)
|
| 1061 |
return []
|
| 1062 |
|
| 1063 |
|
| 1064 |
+
def infer_float16_neg_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
| 1065 |
+
"""Infer inputs for float16.neg circuit."""
|
| 1066 |
+
prefix = "float16.neg"
|
| 1067 |
+
|
| 1068 |
+
# Register 16-bit input
|
| 1069 |
+
for i in range(16):
|
| 1070 |
+
registry.register(f"{prefix}.$x[{i}]")
|
| 1071 |
+
|
| 1072 |
+
# Output gates
|
| 1073 |
+
match = re.search(r'\.out(\d+)', gate)
|
| 1074 |
+
if match:
|
| 1075 |
+
i = int(match.group(1))
|
| 1076 |
+
return [registry.get_id(f"{prefix}.$x[{i}]")]
|
| 1077 |
+
|
| 1078 |
+
return []
|
| 1079 |
+
|
| 1080 |
+
|
| 1081 |
+
def infer_float16_abs_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
| 1082 |
+
"""Infer inputs for float16.abs circuit."""
|
| 1083 |
+
prefix = "float16.abs"
|
| 1084 |
+
|
| 1085 |
+
# Register 16-bit input
|
| 1086 |
+
for i in range(16):
|
| 1087 |
+
registry.register(f"{prefix}.$x[{i}]")
|
| 1088 |
+
|
| 1089 |
+
# Output gates
|
| 1090 |
+
match = re.search(r'\.out(\d+)', gate)
|
| 1091 |
+
if match:
|
| 1092 |
+
i = int(match.group(1))
|
| 1093 |
+
if i == 15:
|
| 1094 |
+
# Sign bit output doesn't depend on input (always 0)
|
| 1095 |
+
# But we still need an input for the gate structure
|
| 1096 |
+
return [registry.get_id(f"{prefix}.$x[15]")]
|
| 1097 |
+
return [registry.get_id(f"{prefix}.$x[{i}]")]
|
| 1098 |
+
|
| 1099 |
+
return []
|
| 1100 |
+
|
| 1101 |
+
|
| 1102 |
def infer_float16_normalize_inputs(gate: str, registry: SignalRegistry) -> List[int]:
|
| 1103 |
"""Infer inputs for float16.normalize circuit."""
|
| 1104 |
prefix = "float16.normalize"
|
|
|
|
| 1148 |
k = int(match.group(1))
|
| 1149 |
return [registry.get_id(f"{prefix}.ge{k}")]
|
| 1150 |
|
| 1151 |
+
for k in [2, 4, 6, 8, 10, 12]:
|
| 1152 |
registry.register(f"{prefix}.not_ge{k}")
|
| 1153 |
|
| 1154 |
# AND gates for ranges
|
|
|
|
| 1159 |
if '.and_6_7' in gate:
|
| 1160 |
return [registry.get_id(f"{prefix}.ge6"), registry.get_id(f"{prefix}.not_ge8")]
|
| 1161 |
if '.and_10_11' in gate:
|
| 1162 |
+
return [registry.get_id(f"{prefix}.ge10"), registry.get_id(f"{prefix}.not_ge12")]
|
|
|
|
| 1163 |
|
| 1164 |
# Odd AND gates
|
| 1165 |
match = re.search(r'\.and_(\d+)$', gate)
|
|
|
|
| 1371 |
return []
|
| 1372 |
|
| 1373 |
|
| 1374 |
+
def build_float16_neg_tensors() -> Dict[str, torch.Tensor]:
|
| 1375 |
+
"""Build tensors for float16.neg circuit.
|
| 1376 |
+
|
| 1377 |
+
Negates a float16 by flipping the sign bit.
|
| 1378 |
+
All other bits pass through unchanged.
|
| 1379 |
+
"""
|
| 1380 |
+
tensors = {}
|
| 1381 |
+
prefix = "float16.neg"
|
| 1382 |
+
|
| 1383 |
+
# Sign bit: NOT of input sign
|
| 1384 |
+
tensors[f"{prefix}.out15.weight"] = torch.tensor([-1.0])
|
| 1385 |
+
tensors[f"{prefix}.out15.bias"] = torch.tensor([0.0])
|
| 1386 |
+
|
| 1387 |
+
# All other bits: pass through
|
| 1388 |
+
for i in range(15):
|
| 1389 |
+
tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0])
|
| 1390 |
+
tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-0.5])
|
| 1391 |
+
|
| 1392 |
+
return tensors
|
| 1393 |
+
|
| 1394 |
+
|
| 1395 |
+
def build_float16_abs_tensors() -> Dict[str, torch.Tensor]:
|
| 1396 |
+
"""Build tensors for float16.abs circuit.
|
| 1397 |
+
|
| 1398 |
+
Absolute value: clear the sign bit, pass all others.
|
| 1399 |
+
"""
|
| 1400 |
+
tensors = {}
|
| 1401 |
+
prefix = "float16.abs"
|
| 1402 |
+
|
| 1403 |
+
# Sign bit: always 0 (use constant #0)
|
| 1404 |
+
# Actually, we can just not output bit 15, or output 0
|
| 1405 |
+
# For consistency, let's output 0 by using bias that never fires
|
| 1406 |
+
tensors[f"{prefix}.out15.weight"] = torch.tensor([1.0])
|
| 1407 |
+
tensors[f"{prefix}.out15.bias"] = torch.tensor([-2.0]) # never fires
|
| 1408 |
+
|
| 1409 |
+
# All other bits: pass through
|
| 1410 |
+
for i in range(15):
|
| 1411 |
+
tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0])
|
| 1412 |
+
tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-0.5])
|
| 1413 |
+
|
| 1414 |
+
return tensors
|
| 1415 |
+
|
| 1416 |
+
|
| 1417 |
def build_float16_normalize_tensors() -> Dict[str, torch.Tensor]:
|
| 1418 |
"""Build tensors for float16.normalize circuit.
|
| 1419 |
|
|
|
|
| 1817 |
tensors.update(cmp_tensors)
|
| 1818 |
print(f" float16.cmp: {len(cmp_tensors)} tensors")
|
| 1819 |
|
| 1820 |
+
norm_tensors = build_float16_normalize_tensors()
|
| 1821 |
+
tensors.update(norm_tensors)
|
| 1822 |
+
print(f" float16.normalize: {len(norm_tensors)} tensors")
|
| 1823 |
+
|
| 1824 |
+
neg_tensors = build_float16_neg_tensors()
|
| 1825 |
+
tensors.update(neg_tensors)
|
| 1826 |
+
print(f" float16.neg: {len(neg_tensors)} tensors")
|
| 1827 |
+
|
| 1828 |
+
abs_tensors = build_float16_abs_tensors()
|
| 1829 |
+
tensors.update(abs_tensors)
|
| 1830 |
+
print(f" float16.abs: {len(abs_tensors)} tensors")
|
| 1831 |
+
|
| 1832 |
print(f"Total tensors: {len(tensors)}")
|
| 1833 |
|
| 1834 |
# Load routing for complex circuits
|
eval.py
CHANGED
|
@@ -513,6 +513,125 @@ class CircuitEvaluator:
|
|
| 513 |
|
| 514 |
return TestResult('float16.cmp', passed, len(test_cases), failures)
|
| 515 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 516 |
# =========================================================================
|
| 517 |
# ARITHMETIC TESTS (DIRECT EVALUATION)
|
| 518 |
# =========================================================================
|
|
@@ -693,6 +812,21 @@ class Evaluator:
|
|
| 693 |
self.results.append(result)
|
| 694 |
if verbose:
|
| 695 |
self._print_result(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 696 |
|
| 697 |
# Comparators
|
| 698 |
if verbose:
|
|
|
|
| 513 |
|
| 514 |
return TestResult('float16.cmp', passed, len(test_cases), failures)
|
| 515 |
|
| 516 |
+
def test_float16_normalize(self) -> TestResult:
|
| 517 |
+
"""Test float16.normalize shift amount calculation."""
|
| 518 |
+
prefix = 'float16.normalize'
|
| 519 |
+
failures = []
|
| 520 |
+
passed = 0
|
| 521 |
+
|
| 522 |
+
# Test cases: 13-bit mantissa values and expected shift amounts
|
| 523 |
+
# Shift amount = CLZ of bits 11:0 (excluding overflow bit 12)
|
| 524 |
+
test_cases = [
|
| 525 |
+
(0b1_000000000000, 0), # Overflow bit set -> shift 0
|
| 526 |
+
(0b0_100000000000, 0), # Bit 11 set -> CLZ=0
|
| 527 |
+
(0b0_010000000000, 1), # Bit 10 set -> CLZ=1
|
| 528 |
+
(0b0_001000000000, 2), # Bit 9 set -> CLZ=2
|
| 529 |
+
(0b0_000100000000, 3), # etc
|
| 530 |
+
(0b0_000010000000, 4),
|
| 531 |
+
(0b0_000001000000, 5),
|
| 532 |
+
(0b0_000000100000, 6),
|
| 533 |
+
(0b0_000000010000, 7),
|
| 534 |
+
(0b0_000000001000, 8),
|
| 535 |
+
(0b0_000000000100, 9),
|
| 536 |
+
(0b0_000000000010, 10),
|
| 537 |
+
(0b0_000000000001, 11),
|
| 538 |
+
(0b0_000000000000, 12), # All zeros -> CLZ=12 (max shift)
|
| 539 |
+
]
|
| 540 |
+
|
| 541 |
+
for mant, expected_shift in test_cases:
|
| 542 |
+
overflow = (mant >> 12) & 1
|
| 543 |
+
|
| 544 |
+
# Set up inputs
|
| 545 |
+
ext = {}
|
| 546 |
+
for i in range(13):
|
| 547 |
+
ext[f'{prefix}.$m[{i}]'] = float((mant >> i) & 1)
|
| 548 |
+
|
| 549 |
+
values = self.eval_circuit(prefix, ext)
|
| 550 |
+
|
| 551 |
+
# Get shift amount (masked by not_overflow)
|
| 552 |
+
shift = 0
|
| 553 |
+
for i in range(4):
|
| 554 |
+
bit = int(values.get(f'{prefix}.out_shift{i}', 0))
|
| 555 |
+
shift |= (bit << i)
|
| 556 |
+
|
| 557 |
+
# Check overflow detection
|
| 558 |
+
got_overflow = int(values.get(f'{prefix}.overflow', 0))
|
| 559 |
+
is_zero = int(values.get(f'{prefix}.is_zero', 0))
|
| 560 |
+
|
| 561 |
+
# Expected: if overflow, shift output should be 0 (masked)
|
| 562 |
+
if overflow:
|
| 563 |
+
expected_out = 0
|
| 564 |
+
else:
|
| 565 |
+
expected_out = expected_shift
|
| 566 |
+
|
| 567 |
+
if shift == expected_out and got_overflow == overflow:
|
| 568 |
+
passed += 1
|
| 569 |
+
else:
|
| 570 |
+
if len(failures) < 10:
|
| 571 |
+
failures.append((mant, expected_shift, shift, overflow, got_overflow))
|
| 572 |
+
|
| 573 |
+
return TestResult('float16.normalize', passed, len(test_cases), failures)
|
| 574 |
+
|
| 575 |
+
def test_float16_neg(self) -> TestResult:
|
| 576 |
+
"""Test float16.neg (sign flip)."""
|
| 577 |
+
prefix = 'float16.neg'
|
| 578 |
+
failures = []
|
| 579 |
+
passed = 0
|
| 580 |
+
|
| 581 |
+
test_values = [0x0000, 0x8000, 0x3C00, 0xBC00, 0x4000, 0x7C00, 0xFC00, 0x7BFF]
|
| 582 |
+
|
| 583 |
+
import random
|
| 584 |
+
random.seed(42)
|
| 585 |
+
for _ in range(50):
|
| 586 |
+
test_values.append(random.randint(0, 0xFFFF))
|
| 587 |
+
|
| 588 |
+
for val in test_values:
|
| 589 |
+
# Expected: flip bit 15
|
| 590 |
+
expected = val ^ 0x8000
|
| 591 |
+
|
| 592 |
+
ext = {f'{prefix}.$x[{i}]': float((val >> i) & 1) for i in range(16)}
|
| 593 |
+
values = self.eval_circuit(prefix, ext)
|
| 594 |
+
|
| 595 |
+
result = sum(int(values.get(f'{prefix}.out{i}', 0)) << i for i in range(16))
|
| 596 |
+
|
| 597 |
+
if result == expected:
|
| 598 |
+
passed += 1
|
| 599 |
+
else:
|
| 600 |
+
if len(failures) < 10:
|
| 601 |
+
failures.append((val, expected, result))
|
| 602 |
+
|
| 603 |
+
return TestResult('float16.neg', passed, len(test_values), failures)
|
| 604 |
+
|
| 605 |
+
def test_float16_abs(self) -> TestResult:
|
| 606 |
+
"""Test float16.abs (clear sign bit)."""
|
| 607 |
+
prefix = 'float16.abs'
|
| 608 |
+
failures = []
|
| 609 |
+
passed = 0
|
| 610 |
+
|
| 611 |
+
test_values = [0x0000, 0x8000, 0x3C00, 0xBC00, 0x4000, 0x7C00, 0xFC00, 0x7BFF]
|
| 612 |
+
|
| 613 |
+
import random
|
| 614 |
+
random.seed(42)
|
| 615 |
+
for _ in range(50):
|
| 616 |
+
test_values.append(random.randint(0, 0xFFFF))
|
| 617 |
+
|
| 618 |
+
for val in test_values:
|
| 619 |
+
# Expected: clear bit 15
|
| 620 |
+
expected = val & 0x7FFF
|
| 621 |
+
|
| 622 |
+
ext = {f'{prefix}.$x[{i}]': float((val >> i) & 1) for i in range(16)}
|
| 623 |
+
values = self.eval_circuit(prefix, ext)
|
| 624 |
+
|
| 625 |
+
result = sum(int(values.get(f'{prefix}.out{i}', 0)) << i for i in range(16))
|
| 626 |
+
|
| 627 |
+
if result == expected:
|
| 628 |
+
passed += 1
|
| 629 |
+
else:
|
| 630 |
+
if len(failures) < 10:
|
| 631 |
+
failures.append((val, expected, result))
|
| 632 |
+
|
| 633 |
+
return TestResult('float16.abs', passed, len(test_values), failures)
|
| 634 |
+
|
| 635 |
# =========================================================================
|
| 636 |
# ARITHMETIC TESTS (DIRECT EVALUATION)
|
| 637 |
# =========================================================================
|
|
|
|
| 812 |
self.results.append(result)
|
| 813 |
if verbose:
|
| 814 |
self._print_result(result)
|
| 815 |
+
if 'float16.normalize.overflow.weight' in self.eval.tensors:
|
| 816 |
+
result = self.eval.test_float16_normalize()
|
| 817 |
+
self.results.append(result)
|
| 818 |
+
if verbose:
|
| 819 |
+
self._print_result(result)
|
| 820 |
+
if 'float16.neg.out0.weight' in self.eval.tensors:
|
| 821 |
+
result = self.eval.test_float16_neg()
|
| 822 |
+
self.results.append(result)
|
| 823 |
+
if verbose:
|
| 824 |
+
self._print_result(result)
|
| 825 |
+
if 'float16.abs.out0.weight' in self.eval.tensors:
|
| 826 |
+
result = self.eval.test_float16_abs()
|
| 827 |
+
self.results.append(result)
|
| 828 |
+
if verbose:
|
| 829 |
+
self._print_result(result)
|
| 830 |
|
| 831 |
# Comparators
|
| 832 |
if verbose:
|