CharlesCNorton committed on
Commit 6c2c63e · 1 Parent(s): 939c167

Rebuild float16 LUT/pow + 16-bit arithmetic; fix neg16bit


- build.py: add float16 LUT match/output generation (sqrt/rsqrt/exp/ln/log2/sin/cos/tan/tanh) and pow via ln*mul->exp; add float16 half conversion helpers and LUT output builders; add 16-bit arithmetic builders (ripplecarry/adc/sbc/sub/cmp/equality/neg/asr/rol/ror/clz) plus comparator/constant vectors; add gate helpers for NOT/AND/OR/XOR/XNOR; extend input inference for new circuits and 16-bit variants; infer multiplier2x2 .andXY inputs; define neg16bit sum0 as NOT(not0).

- eval.py: add float16 LUT and pow tests with direct LUT index evaluation; add 16-bit arithmetic/comparator/shift tests; add topo/alias caching; update negation evaluation to handle sum0/carry0 arity; update orphan/selector tests to 16-bit.

- arithmetic.safetensors: regenerate tensors and .inputs registry with zero missing inputs (626,374 tensors / 208,788 gates).

- README: update circuit list (float16 LUT+pow, 16-bit integer), accuracy notes, counts, and TODO status.
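The LUT generation described above reduces each unary float16 op to 65,536 exact-match threshold gates feeding 16 one-hot output gates. A match gate can be sketched as follows (a minimal illustration of the gate shape; `match_gate` and `bits_of` are hypothetical helpers, not names from this repo):

```python
def match_gate(pattern: int, x_bits: list) -> int:
    """Threshold unit that fires iff x_bits equals the 16-bit pattern.

    Weights are +1 where the pattern bit is 1 and -1 where it is 0;
    bias is -(popcount(pattern) - 0.5), so the pre-activation is +0.5
    on an exact match and <= -0.5 otherwise.
    """
    weights = [1.0 if (pattern >> i) & 1 else -1.0 for i in range(16)]
    bias = -(bin(pattern).count("1") - 0.5)
    activation = sum(w * x for w, x in zip(weights, x_bits)) + bias
    return 1 if activation > 0 else 0

def bits_of(value: int) -> list:
    """Little-endian bit vector of a 16-bit value."""
    return [(value >> i) & 1 for i in range(16)]
```

Each LUT output bit is then an OR over the one-hot match vector (weight 1.0 on every pattern whose result has that bit set, bias -0.5).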

Files changed (4)
  1. README.md +22 -18
  2. arithmetic.safetensors +2 -2
  3. build.py +613 -75
  4. eval.py +601 -39
README.md CHANGED
@@ -22,26 +22,34 @@ Each gate is a threshold logic unit: `output = step(weights · inputs + bias)`.
 
 | File | Description |
 |------|-------------|
-| `arithmetic.safetensors` | 33,451 tensors encoding 11,147 gates |
-| `eval.py` | Test harness (208,637 tests) |
+| `arithmetic.safetensors` | 626,374 tensors encoding 208,788 gates |
+| `eval.py` | Test harness (211,581 tests) |
 | `build.py` | Builds tensors and infers gate connectivity |
 
 ## Circuits
 
 **Float16 (IEEE 754)**
 - `float16.add`, `float16.sub`, `float16.mul`, `float16.div`
+- `float16.sqrt`, `float16.rsqrt`, `float16.pow`
+- `float16.exp`, `float16.ln`, `float16.log2`
+- `float16.sin`, `float16.cos`, `float16.tan`, `float16.tanh`
 - `float16.neg`, `float16.abs`, `float16.cmp`
 - `float16.toint`, `float16.fromint`
 - `float16.pack`, `float16.unpack`, `float16.normalize`
 
 Handles NaN, Inf, zero, subnormals. Mantissa alignment via barrel shifter. Normalization via CLZ.
 
-**8-bit Integer**
-- Adders: half, full, ripple carry (2/4/8 bit), add-with-carry
-- Subtraction: sub8bit, sbc8bit, neg8bit
-- Comparison: cmp8bit, equality8bit
-- Shifts: asr8bit, rol8bit, ror8bit
-- CLZ: 8-bit and 16-bit
+Accuracy/rounding:
+- Unary transcendental ops are LUT-backed over all 65,536 float16 inputs.
+- Outputs match torch.float16 results (round-to-nearest-even); NaNs are canonicalized to 0x7E00.
+- `float16.pow` is defined as exp(b * ln(a)) with float16 rounding at each stage.
+
+**16-bit Integer**
+- Adders: half, full, ripple carry (2/4/16 bit), add-with-carry (adc16bit)
+- Subtraction: sub16bit, sbc16bit, neg16bit
+- Comparison: cmp16bit, equality16bit
+- Shifts: asr16bit, rol16bit, ror16bit
+- CLZ: 16-bit
 
 **Modular Arithmetic**
 - mod2 through mod12 (divisibility testing)
@@ -141,24 +149,20 @@ This began as an attempt to build a complete threshold-logic CPU. The CPU is in 
 - Float16 core (add/sub/mul/div)
 - Float16 utilities (pack/unpack/normalize/conversions)
 - Float16 IEEE-754 half compliance for add/sub/mul/div + toint/fromint (including subnormals)
-- 8-bit integer arithmetic
+- Float16 unary LUTs (sqrt/rsqrt/exp/ln/log2/sin/cos/tan/tanh)
+- Float16 pow via exp(b * ln(a))
+- 16-bit integer arithmetic (add/sub/cmp/shifts/CLZ)
 - Boolean, threshold, modular, pattern recognition, combinational
 
 **Next:**
-- Float16 sqrt, rsqrt, pow
-- Float16 exp, ln, log2
-- Float16 trig (sin, cos, tan via CORDIC)
-- Float16 tanh (ML activation)
+- TBD
 
 **Cleanup:**
-- Rip out 8-bit integer circuits, replace with 16-bit
-- 8-bit was scaffolding for float16 development, not the product
+- None (8-bit arithmetic scaffolding removed)
 
 ## TODO (Unified)
 
-1. Define accuracy/rounding specs and implement float16 sqrt/rsqrt/pow/exp/ln/log2.
-2. Implement float16 trig (sin/cos/tan via CORDIC) and tanh with explicit accuracy targets.
-3. Replace 8-bit integer circuits with 16-bit and remove 8-bit scaffolding.
+None.
 
 ## License
 
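The staged-rounding definition of `float16.pow` (exp of b times ln a, with float16 rounding after each stage) can be reproduced on the host with Python's built-in half-precision codec. A sketch under that assumption; `round_f16` and `f16_pow` are illustrative helpers, not code from this repo:

```python
import math
import struct

def round_f16(x: float) -> float:
    """Round a double to the nearest representable float16 (ties to even)."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

def f16_pow(a: float, b: float) -> float:
    """pow(a, b) as exp(b * ln(a)), rounding to float16 after each stage."""
    ln_a = round_f16(math.log(a))    # stage 1: ln LUT output
    prod = round_f16(b * ln_a)       # stage 2: float16 multiply
    return round_f16(math.exp(prod)) # stage 3: exp LUT output
```

Because ln(a) is rounded before the multiply, results can differ from a correctly rounded pow by a few ULPs; that is inherent to the three-stage definition, not a circuit bug.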
arithmetic.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:418e84ff00ccfcbeb22920eaa70851d4cd6d6d6945f624fbd9354ab340c11bcb
-size 4281848
+oid sha256:7437f06026058f699dad09dd4b657ed6fe81d4845a2b0dc7213b2bbb048273c6
+size 247445516
build.py CHANGED
@@ -17,8 +17,10 @@ from safetensors import safe_open
17
  from safetensors.torch import save_file
18
  import json
19
  import re
 
 
20
  from collections import defaultdict
21
- from typing import Dict, List, Tuple, Set
22
 
23
  class SignalRegistry:
24
  """Manages signal ID assignments."""
@@ -46,6 +48,77 @@ class SignalRegistry:
46
  return json.dumps(self.id_to_name)
47
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def extract_gate_name(tensor_name: str) -> str:
50
  """Extract gate name from tensor name (remove .weight or .bias suffix)."""
51
  if tensor_name.endswith('.weight'):
@@ -92,6 +165,85 @@ def infer_boolean_inputs(gate: str, registry: SignalRegistry) -> List[int]:
92
  return []
93
 
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  def infer_halfadder_inputs(gate: str, prefix: str, registry: SignalRegistry) -> List[int]:
96
  """Infer inputs for half adder gates."""
97
  registry.register(f"{prefix}.$a")
@@ -264,27 +416,27 @@ def infer_modular_inputs(gate: str, registry: SignalRegistry) -> List[int]:
264
 
265
  def infer_comparator_inputs(gate: str, registry: SignalRegistry) -> List[int]:
266
  """Infer inputs for comparator gates."""
267
- # 8-bit inputs a and b
268
  prefix = gate.rsplit('.', 1)[0] # Remove .comparator
 
269
 
270
  inputs = []
271
- for i in range(8):
272
  registry.register(f"{prefix}.$a[{i}]")
273
  registry.register(f"{prefix}.$b[{i}]")
274
 
275
  # Comparator takes difference of bit pairs
276
- for i in range(8):
277
  inputs.append(registry.get_id(f"{prefix}.$a[{i}]"))
278
- for i in range(8):
279
  inputs.append(registry.get_id(f"{prefix}.$b[{i}]"))
280
 
281
  return inputs
282
 
283
 
284
- def infer_adc_sbc_inputs(gate: str, prefix: str, registry: SignalRegistry) -> List[int]:
285
  """Infer inputs for ADC/SBC (add/subtract with carry) gates."""
286
  # Register inputs
287
- for i in range(8):
288
  registry.register(f"{prefix}.$a[{i}]")
289
  registry.register(f"{prefix}.$b[{i}]")
290
  registry.register(f"{prefix}.$cin")
@@ -346,11 +498,9 @@ def infer_adc_sbc_inputs(gate: str, prefix: str, registry: SignalRegistry) -> Li
346
  return []
347
 
348
 
349
- def infer_sub8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
350
- """Infer inputs for SUB8BIT (subtraction via complement addition)."""
351
- prefix = "arithmetic.sub8bit"
352
-
353
- for i in range(8):
354
  registry.register(f"{prefix}.$a[{i}]")
355
  registry.register(f"{prefix}.$b[{i}]")
356
 
@@ -404,15 +554,22 @@ def infer_sub8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
404
  return []
405
 
406
 
407
- def infer_cmp8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
408
- """Infer inputs for CMP8BIT (compare via subtraction)."""
409
- prefix = "arithmetic.cmp8bit"
410
 
411
- for i in range(8):
 
 
 
 
 
 
 
 
412
  registry.register(f"{prefix}.$a[{i}]")
413
  registry.register(f"{prefix}.$b[{i}]")
414
 
415
- # Similar to sub8bit
416
  if '.notb' in gate:
417
  match = re.search(r'\.notb(\d+)', gate)
418
  if match:
@@ -454,23 +611,28 @@ def infer_cmp8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
454
  return [registry.register(f"{fa_prefix}.and1"),
455
  registry.register(f"{fa_prefix}.and2")]
456
 
457
- # Flag outputs
458
  if '.flags.' in gate:
459
- # Flags take the result bits
460
- return [registry.register(f"{prefix}.fa{i}.sum") for i in range(8)]
461
 
462
  return []
463
 
464
 
465
- def infer_equality8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
466
- """Infer inputs for equality circuit (XNOR chain + AND)."""
467
- prefix = "arithmetic.equality8bit"
468
 
469
- for i in range(8):
 
 
 
 
 
 
 
 
470
  registry.register(f"{prefix}.$a[{i}]")
471
  registry.register(f"{prefix}.$b[{i}]")
472
 
473
- # XNOR gates
474
  match = re.search(r'\.xnor(\d+)\.', gate)
475
  if match:
476
  idx = int(match.group(1))
@@ -484,29 +646,36 @@ def infer_equality8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
484
  nor_out = registry.register(f"{prefix}.xnor{idx}.layer1.nor")
485
  return [and_out, nor_out]
486
 
487
- # Final AND
488
  if '.and' in gate or '.final_and' in gate:
489
- return [registry.register(f"{prefix}.xnor{i}") for i in range(8)]
490
 
491
  return []
492
 
493
 
494
- def infer_neg8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
495
- """Infer inputs for NEG8BIT (two's complement negation)."""
496
- prefix = "arithmetic.neg8bit"
497
 
498
- for i in range(8):
 
 
 
 
 
 
 
 
499
  registry.register(f"{prefix}.$x[{i}]")
500
 
501
- # NOT gates
502
  if '.not' in gate and 'layer' not in gate:
503
  match = re.search(r'\.not(\d+)', gate)
504
  if match:
505
  idx = int(match.group(1))
506
  return [registry.get_id(f"{prefix}.$x[{idx}]")]
507
 
508
- # Increment by 1 (add chain)
509
- if '.sum0' in gate or '.carry0' in gate:
 
510
  return [registry.register(f"{prefix}.not0"), registry.get_id("#1")]
511
 
512
  match = re.search(r'\.xor(\d+)\.', gate)
@@ -538,19 +707,41 @@ def infer_neg8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
538
  return []
539
 
540
 
 
 
 
 
 
 
 
 
 
 
541
  def infer_shift_rotate_inputs(gate: str, registry: SignalRegistry) -> List[int]:
542
  """Infer inputs for ASR, ROL, ROR."""
543
  # Determine which circuit
544
- if 'asr8bit' in gate:
 
 
 
 
 
 
 
 
 
545
  prefix = "arithmetic.asr8bit"
 
546
  elif 'rol8bit' in gate:
547
  prefix = "arithmetic.rol8bit"
 
548
  elif 'ror8bit' in gate:
549
  prefix = "arithmetic.ror8bit"
 
550
  else:
551
  return []
552
 
553
- for i in range(8):
554
  registry.register(f"{prefix}.$x[{i}]")
555
 
556
  # Bit selectors
@@ -558,12 +749,12 @@ def infer_shift_rotate_inputs(gate: str, registry: SignalRegistry) -> List[int]:
558
  if match:
559
  idx = int(match.group(1))
560
  # Each output bit selects from input bits based on shift
561
- return [registry.get_id(f"{prefix}.$x[{i}]") for i in range(8)]
562
 
563
  # Carry/shift out
564
  if '.cout' in gate or '.shiftout' in gate:
565
  if 'rol' in gate:
566
- return [registry.get_id(f"{prefix}.$x[7]")] # MSB shifts out
567
  elif 'ror' in gate:
568
  return [registry.get_id(f"{prefix}.$x[0]")] # LSB shifts out
569
  elif 'asr' in gate:
@@ -603,6 +794,15 @@ def infer_multiplier_inputs(gate: str, registry: SignalRegistry) -> List[int]:
603
  return [registry.get_id(f"{prefix}.$a[{col}]"),
604
  registry.get_id(f"{prefix}.$b[{row}]")]
605
 
 
 
 
 
 
 
 
 
 
606
  # Stage adders
607
  match = re.search(r'\.stage(\d+)\.bit(\d+)\.', gate)
608
  if match:
@@ -661,40 +861,60 @@ def infer_multiplier_inputs(gate: str, registry: SignalRegistry) -> List[int]:
661
 
662
  def infer_incr_decr_inputs(gate: str, registry: SignalRegistry) -> List[int]:
663
  """Infer inputs for incrementer/decrementer."""
664
- if 'incrementer' in gate:
 
 
 
 
 
 
665
  prefix = "arithmetic.incrementer8bit"
 
666
  elif 'decrementer' in gate:
667
  prefix = "arithmetic.decrementer8bit"
 
668
  else:
669
  return []
670
 
671
- for i in range(8):
672
  registry.register(f"{prefix}.$x[{i}]")
673
 
674
  # These typically just reference adder and constant
675
- return [registry.get_id(f"{prefix}.$x[{i}]") for i in range(8)]
676
 
677
 
678
  def infer_minmax_inputs(gate: str, registry: SignalRegistry) -> List[int]:
679
  """Infer inputs for min/max/absolutedifference."""
680
- if 'max8bit' in gate:
 
 
 
 
 
 
 
 
 
681
  prefix = "arithmetic.max8bit"
 
682
  elif 'min8bit' in gate:
683
  prefix = "arithmetic.min8bit"
 
684
  elif 'absolutedifference' in gate:
685
  prefix = "arithmetic.absolutedifference8bit"
 
686
  else:
687
  return []
688
 
689
- for i in range(8):
690
  registry.register(f"{prefix}.$a[{i}]")
691
  registry.register(f"{prefix}.$b[{i}]")
692
 
693
  # Select/diff weights take comparison + both operands
694
  inputs = []
695
- for i in range(8):
696
  inputs.append(registry.get_id(f"{prefix}.$a[{i}]"))
697
- for i in range(8):
698
  inputs.append(registry.get_id(f"{prefix}.$b[{i}]"))
699
  return inputs
700
 
@@ -993,6 +1213,8 @@ def infer_inputs_for_gate(gate: str, registry: SignalRegistry, routing: dict) ->
993
  # Ripple carry adders
994
  if 'ripplecarry8bit' in gate:
995
  return infer_ripplecarry_inputs(gate, 'arithmetic.ripplecarry8bit', 8, registry)
 
 
996
  if 'ripplecarry4bit' in gate:
997
  return infer_ripplecarry_inputs(gate, 'arithmetic.ripplecarry4bit', 4, registry)
998
  if 'ripplecarry2bit' in gate:
@@ -1000,28 +1222,41 @@ def infer_inputs_for_gate(gate: str, registry: SignalRegistry, routing: dict) ->
1000
 
1001
  # ADC/SBC
1002
  if 'adc8bit' in gate:
1003
- return infer_adc_sbc_inputs(gate, 'arithmetic.adc8bit', registry)
 
 
1004
  if 'sbc8bit' in gate:
1005
- return infer_adc_sbc_inputs(gate, 'arithmetic.sbc8bit', registry)
 
 
1006
 
1007
  # SUB
1008
  if 'sub8bit' in gate:
1009
  return infer_sub8bit_inputs(gate, registry)
 
 
1010
 
1011
  # CMP
1012
  if 'cmp8bit' in gate:
1013
  return infer_cmp8bit_inputs(gate, registry)
 
 
1014
 
1015
  # Equality
1016
  if 'equality8bit' in gate:
1017
  return infer_equality8bit_inputs(gate, registry)
 
 
1018
 
1019
  # Negate
1020
  if 'neg8bit' in gate:
1021
  return infer_neg8bit_inputs(gate, registry)
 
 
1022
 
1023
  # Shifts and rotates
1024
- if 'asr8bit' in gate or 'rol8bit' in gate or 'ror8bit' in gate:
 
1025
  return infer_shift_rotate_inputs(gate, registry)
1026
 
1027
  # Multipliers
@@ -1038,7 +1273,9 @@ def infer_inputs_for_gate(gate: str, registry: SignalRegistry, routing: dict) ->
1038
 
1039
  # Comparators
1040
  if 'greaterthan8bit' in gate or 'lessthan8bit' in gate or \
1041
- 'greaterorequal8bit' in gate or 'lessorequal8bit' in gate:
 
 
1042
  return infer_comparator_inputs(gate, registry)
1043
 
1044
  # CLZ (count leading zeros)
@@ -1049,6 +1286,16 @@ def infer_inputs_for_gate(gate: str, registry: SignalRegistry, routing: dict) ->
1049
 
1050
  # Float16 circuits
1051
  if gate.startswith('float16.'):
 
 
 
 
 
 
 
 
 
 
1052
  if 'unpack' in gate:
1053
  return infer_float16_unpack_inputs(gate, registry)
1054
  if 'pack' in gate:
@@ -2786,18 +3033,24 @@ def infer_float16_sub_inputs(gate: str, registry: SignalRegistry) -> List[int]:
2786
  return []
2787
 
2788
 
2789
- def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
2790
- """Infer inputs for float16.mul circuit."""
2791
- prefix = "float16.mul"
 
 
 
 
 
2792
 
2793
- for i in range(16):
2794
- registry.register(f"{prefix}.$a[{i}]")
2795
- registry.register(f"{prefix}.$b[{i}]")
 
2796
 
2797
- exp_a_bits = [f"{prefix}.$a[{10+i}]" for i in range(5)]
2798
- exp_b_bits = [f"{prefix}.$b[{10+i}]" for i in range(5)]
2799
- mant_a_bits = [f"{prefix}.$a[{i}]" for i in range(10)]
2800
- mant_b_bits = [f"{prefix}.$b[{i}]" for i in range(10)]
2801
 
2802
  if '.exp_a_all_ones' in gate:
2803
  return [registry.get_id(b) for b in exp_a_bits]
@@ -2842,11 +3095,11 @@ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
2842
  match = re.search(r'\.mant_a_norm(\d+)$', gate)
2843
  if match:
2844
  i = int(match.group(1))
2845
- return [registry.get_id(f"{prefix}.$a[{i}]")]
2846
  match = re.search(r'\.mant_b_norm(\d+)$', gate)
2847
  if match:
2848
  i = int(match.group(1))
2849
- return [registry.get_id(f"{prefix}.$b[{i}]")]
2850
 
2851
  for i in range(10):
2852
  registry.register(f"{prefix}.mant_a_norm{i}")
@@ -2919,11 +3172,11 @@ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
2919
  registry.register(f"{prefix}.result_is_zero")
2920
 
2921
  if '.result_sign.layer1.or' in gate:
2922
- return [registry.get_id(f"{prefix}.$a[15]"),
2923
- registry.get_id(f"{prefix}.$b[15]")]
2924
  if '.result_sign.layer1.nand' in gate:
2925
- return [registry.get_id(f"{prefix}.$a[15]"),
2926
- registry.get_id(f"{prefix}.$b[15]")]
2927
  if '.result_sign.layer2' in gate:
2928
  return [registry.register(f"{prefix}.result_sign.layer1.or"),
2929
  registry.register(f"{prefix}.result_sign.layer1.nand")]
@@ -2944,11 +3197,11 @@ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry) -> List[int]:
2944
  if i == 10:
2945
  a_bit = registry.get_id(f"{prefix}.implicit_a")
2946
  else:
2947
- a_bit = registry.get_id(f"{prefix}.$a[{i}]")
2948
  if j == 10:
2949
  b_bit = registry.get_id(f"{prefix}.implicit_b")
2950
  else:
2951
- b_bit = registry.get_id(f"{prefix}.$b[{j}]")
2952
  return [a_bit, b_bit]
2953
 
2954
  for i in range(11):
@@ -6810,6 +7063,32 @@ def build_float16_unpack_tensors() -> Dict[str, torch.Tensor]:
6810
  return tensors
6811
 
6812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6813
  def build_clz16bit_tensors() -> Dict[str, torch.Tensor]:
6814
  """Build tensors for arithmetic.clz16bit circuit.
6815
 
@@ -10595,6 +10874,166 @@ def build_clz8bit_tensors() -> Dict[str, torch.Tensor]:
10595
  return tensors
10596
 
10597
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10598
  def main():
10599
  print("Loading existing tensors...")
10600
  tensors = {}
@@ -10625,6 +11064,31 @@ def main():
10625
  del tensors[k]
10626
  print(f"Removed {len(old_float16_div)} old float16.div tensors")
10627
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10628
  # Remove broken mod2/mod4/mod8 tensors
10629
  old_mod_power2 = [k for k in tensors.keys() if k.startswith('modular.mod2') or
10630
  k.startswith('modular.mod4') or k.startswith('modular.mod8')]
@@ -10647,10 +11111,6 @@ def main():
10647
 
10648
  # Build new circuits
10649
  print("Building new circuits...")
10650
- clz_tensors = build_clz8bit_tensors()
10651
- tensors.update(clz_tensors)
10652
- print(f" CLZ8BIT: {len(clz_tensors)} tensors")
10653
-
10654
  clz16_tensors = build_clz16bit_tensors()
10655
  tensors.update(clz16_tensors)
10656
  print(f" CLZ16BIT: {len(clz16_tensors)} tensors")
@@ -10703,14 +11163,92 @@ def main():
10703
  tensors.update(fromint_tensors)
10704
  print(f" float16.fromint: {len(fromint_tensors)} tensors")
10705
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10706
  mod_power2_tensors = build_modular_power2_tensors()
10707
  tensors.update(mod_power2_tensors)
10708
  print(f" modular.mod2/4/8: {len(mod_power2_tensors)} tensors")
10709
 
10710
- bitwise_tensors = build_bitwise_shift_tensors()
10711
- tensors.update(bitwise_tensors)
10712
- print(f" bitwise shifts: {len(bitwise_tensors)} tensors")
10713
-
10714
  symmetry_tensors = build_symmetry8bit_tensors()
10715
  tensors.update(symmetry_tensors)
10716
  print(f" symmetry8bit: {len(symmetry_tensors)} tensors")
 
17
  from safetensors.torch import save_file
18
  import json
19
  import re
20
+ import struct
21
+ import math
22
  from collections import defaultdict
23
+ from typing import Dict, List, Tuple, Set, Callable, Optional
24
 
25
  class SignalRegistry:
26
  """Manages signal ID assignments."""
 
48
  return json.dumps(self.id_to_name)
49
 
50
 
51
+ def float16_bits_to_float(bits: int) -> float:
52
+ """Interpret 16-bit int as IEEE-754 float16."""
53
+ packed = struct.pack('>H', bits & 0xFFFF)
54
+ return struct.unpack('>e', packed)[0]
55
+
56
+
57
+ def float16_float_to_bits(val: float) -> int:
58
+ """Convert float to IEEE-754 float16 bits with canonical NaN."""
59
+ try:
60
+ packed = struct.pack('>e', float(val))
61
+ return struct.unpack('>H', packed)[0]
62
+ except (OverflowError, struct.error):
63
+ if val == float('inf'):
64
+ return 0x7C00
65
+ if val == float('-inf'):
66
+ return 0xFC00
67
+ if val != val:
68
+ return 0x7E00
69
+ return 0x7BFF if val > 0 else 0xFBFF
70
+
71
+
72
+ def compute_float16_unary_lut_outputs(op_fn: Callable[[torch.Tensor], torch.Tensor]) -> List[int]:
73
+ """Compute output bits for all 65536 float16 inputs using a unary op."""
74
+ outputs: List[int] = [0] * 65536
75
+ for bits in range(65536):
76
+ val = float16_bits_to_float(bits)
77
+ out = op_fn(torch.tensor(val, dtype=torch.float16)).item()
78
+ if out != out:
79
+ outputs[bits] = 0x7E00
80
+ else:
81
+ outputs[bits] = float16_float_to_bits(float(out))
82
+ return outputs
83
+
84
+
85
+ def build_float16_lut_match_tensors(prefix: str) -> Dict[str, torch.Tensor]:
86
+ """Build exact-match gates for all 16-bit patterns under prefix.matchXXXX."""
87
+ tensors: Dict[str, torch.Tensor] = {}
88
+ for bits in range(65536):
89
+ ones = bits.bit_count()
90
+ weights = [1.0 if (bits >> i) & 1 else -1.0 for i in range(16)]
91
+ bias = -(ones - 0.5)
92
+ name = f"{prefix}.match{bits:04x}"
93
+ tensors[f"{name}.weight"] = torch.tensor(weights)
94
+ tensors[f"{name}.bias"] = torch.tensor([bias])
95
+ return tensors
96
+
97
+
98
+ def build_float16_lut_output_tensors(prefix: str, outputs: List[int]) -> Dict[str, torch.Tensor]:
99
+ """Build LUT output gates (prefix.out0..out15) using one-hot match inputs."""
100
+ tensors: Dict[str, torch.Tensor] = {}
101
+ for bit in range(16):
102
+ weights = torch.zeros(65536)
103
+ for idx, out_bits in enumerate(outputs):
104
+ if (out_bits >> bit) & 1:
105
+ weights[idx] = 1.0
106
+ tensors[f"{prefix}.out{bit}.weight"] = weights
107
+ tensors[f"{prefix}.out{bit}.bias"] = torch.tensor([-0.5])
108
+ return tensors
109
+
110
+
111
+ def clone_prefix_tensors(src: Dict[str, torch.Tensor], old_prefix: str,
112
+ new_prefix: str) -> Dict[str, torch.Tensor]:
113
+ """Clone tensors and rewrite the prefix in tensor names."""
114
+ out: Dict[str, torch.Tensor] = {}
115
+ for name, tensor in src.items():
116
+ if name.startswith(old_prefix + "."):
117
+ out_name = new_prefix + name[len(old_prefix):]
118
+ out[out_name] = tensor.clone()
119
+ return out
120
+
121
+
122
  def extract_gate_name(tensor_name: str) -> str:
123
  """Extract gate name from tensor name (remove .weight or .bias suffix)."""
124
  if tensor_name.endswith('.weight'):
 
165
  return []
166
 
167
 
168
+ def get_lut_match_ids(registry: SignalRegistry, match_prefix: str) -> List[int]:
169
+ """Get (and cache) match gate IDs for a LUT prefix."""
170
+ cache = getattr(registry, "_lut_match_ids", None)
171
+ if cache is None:
172
+ cache = {}
173
+ setattr(registry, "_lut_match_ids", cache)
174
+ if match_prefix not in cache:
175
+ cache[match_prefix] = [registry.register(f"{match_prefix}.match{idx:04x}") for idx in range(65536)]
176
+ return cache[match_prefix]
177
+
178
+
179
+ def infer_float16_lut_match_inputs(gate: str, registry: SignalRegistry,
180
+ match_prefix: str, input_bits: List[str]) -> List[int]:
181
+ """Infer inputs for LUT match gates (exact pattern match)."""
182
+ if not gate.startswith(f"{match_prefix}.match"):
183
+ return []
184
+ for name in input_bits:
185
+ registry.register(name)
186
+ return [registry.get_id(name) for name in input_bits]
187
+
188
+
189
+ def infer_float16_lut_out_inputs(gate: str, registry: SignalRegistry, match_prefix: str) -> List[int]:
190
+ """Infer inputs for LUT output gates (one-hot match vector)."""
191
+ match = re.search(r'\.out(\d+)$', gate)
192
+ if not match:
193
+ return []
194
+ return get_lut_match_ids(registry, match_prefix)
195
+
196
+
197
+ def infer_float16_lut_inputs(gate: str, registry: SignalRegistry) -> List[int]:
198
+ """Infer inputs for shared float16.lut match gates."""
199
+ prefix = "float16.lut"
200
+ input_bits = [f"{prefix}.$x[{i}]" for i in range(16)]
201
+ return infer_float16_lut_match_inputs(gate, registry, prefix, input_bits)
202
+
203
+
204
+ def infer_float16_pow_inputs(gate: str, registry: SignalRegistry) -> List[int]:
205
+ """Infer inputs for float16.pow circuit (ln -> mul -> exp)."""
206
+ prefix = "float16.pow"
207
+
208
+ # External inputs
209
+ for i in range(16):
210
+ registry.register(f"{prefix}.$a[{i}]")
211
+ registry.register(f"{prefix}.$b[{i}]")
212
+
213
+ # ln subcircuit (match + outputs)
214
+ ln_prefix = f"{prefix}.ln"
215
+ ln_input_bits = [f"{prefix}.$a[{i}]" for i in range(16)]
216
+ inputs = infer_float16_lut_match_inputs(gate, registry, ln_prefix, ln_input_bits)
217
+ if inputs:
218
+ return inputs
219
+ if gate.startswith(f"{ln_prefix}."):
220
+ return infer_float16_lut_out_inputs(gate, registry, ln_prefix)
221
+
222
+ # mul subcircuit (a = ln.out, b = external b)
223
+ if gate.startswith(f"{prefix}.mul."):
224
+ a_bits = [f"{ln_prefix}.out{i}" for i in range(16)]
225
+ b_bits = [f"{prefix}.$b[{i}]" for i in range(16)]
226
+ return infer_float16_mul_inputs(gate, registry, prefix=f"{prefix}.mul",
227
+ a_bits=a_bits, b_bits=b_bits)
228
+
229
+ # exp subcircuit (match + outputs) with input from mul outputs
230
+ exp_prefix = f"{prefix}.exp"
231
+ exp_input_bits = [f"{prefix}.mul.out{i}" for i in range(16)]
232
+ inputs = infer_float16_lut_match_inputs(gate, registry, exp_prefix, exp_input_bits)
233
+ if inputs:
234
+ return inputs
235
+ if gate.startswith(f"{exp_prefix}."):
236
+ return infer_float16_lut_out_inputs(gate, registry, exp_prefix)
237
+
238
+ # pow outputs (pass-through from exp.out)
239
+ match = re.search(r'\.out(\d+)$', gate)
240
+ if match:
241
+ i = int(match.group(1))
242
+ return [registry.get_id(f"{exp_prefix}.out{i}")]
243
+
244
+ return []
245
+
246
+
247
  def infer_halfadder_inputs(gate: str, prefix: str, registry: SignalRegistry) -> List[int]:
248
  """Infer inputs for half adder gates."""
249
  registry.register(f"{prefix}.$a")
 
416
 
417
  def infer_comparator_inputs(gate: str, registry: SignalRegistry) -> List[int]:
418
  """Infer inputs for comparator gates."""
 
419
  prefix = gate.rsplit('.', 1)[0] # Remove .comparator
420
+ bits = 16 if "16bit" in prefix else 8
421
 
422
  inputs = []
423
+ for i in range(bits):
424
  registry.register(f"{prefix}.$a[{i}]")
425
  registry.register(f"{prefix}.$b[{i}]")
426
 
427
  # Comparator takes difference of bit pairs
428
+ for i in range(bits):
429
  inputs.append(registry.get_id(f"{prefix}.$a[{i}]"))
430
+ for i in range(bits):
431
  inputs.append(registry.get_id(f"{prefix}.$b[{i}]"))
432
 
433
  return inputs
434
 
435
 
436
+ def infer_adc_sbc_inputs(gate: str, prefix: str, registry: SignalRegistry, bits: int = 8) -> List[int]:
437
  """Infer inputs for ADC/SBC (add/subtract with carry) gates."""
438
  # Register inputs
439
+ for i in range(bits):
440
  registry.register(f"{prefix}.$a[{i}]")
441
  registry.register(f"{prefix}.$b[{i}]")
442
  registry.register(f"{prefix}.$cin")
 
498
  return []
499
 
500
 
501
+ def infer_sub_inputs(gate: str, prefix: str, bits: int, registry: SignalRegistry) -> List[int]:
502
+ """Infer inputs for subtractor (complement addition) gates."""
503
+ for i in range(bits):
 
 
504
  registry.register(f"{prefix}.$a[{i}]")
505
  registry.register(f"{prefix}.$b[{i}]")
506
 
 
554
  return []
555
 
556
 
557
+ def infer_sub8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
558
+ """Infer inputs for SUB8BIT (subtraction via complement addition)."""
559
+ return infer_sub_inputs(gate, "arithmetic.sub8bit", 8, registry)
560
 
561
+
562
+ def infer_sub16bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
563
+ """Infer inputs for SUB16BIT (subtraction via complement addition)."""
564
+ return infer_sub_inputs(gate, "arithmetic.sub16bit", 16, registry)
565
+
566
+
567
+ def infer_cmp_inputs(gate: str, prefix: str, bits: int, registry: SignalRegistry) -> List[int]:
568
+ """Infer inputs for comparator via subtraction."""
569
+ for i in range(bits):
570
  registry.register(f"{prefix}.$a[{i}]")
571
  registry.register(f"{prefix}.$b[{i}]")
572
 
 
573
  if '.notb' in gate:
574
  match = re.search(r'\.notb(\d+)', gate)
575
  if match:
 
611
  return [registry.register(f"{fa_prefix}.and1"),
612
  registry.register(f"{fa_prefix}.and2")]
613
 
 
614
  if '.flags.' in gate:
615
+ return [registry.register(f"{prefix}.fa{i}.sum") for i in range(bits)]
 
616
 
617
  return []
618
 
619
 
620
+ def infer_cmp8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
621
+ """Infer inputs for CMP8BIT (compare via subtraction)."""
622
+ return infer_cmp_inputs(gate, "arithmetic.cmp8bit", 8, registry)
623
 
624
+
625
+ def infer_cmp16bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
626
+ """Infer inputs for CMP16BIT (compare via subtraction)."""
627
+ return infer_cmp_inputs(gate, "arithmetic.cmp16bit", 16, registry)
628
+
629
+
630
+ def infer_equality_inputs(gate: str, prefix: str, bits: int, registry: SignalRegistry) -> List[int]:
631
+ """Infer inputs for equality circuit (XNOR chain + AND)."""
632
+ for i in range(bits):
633
  registry.register(f"{prefix}.$a[{i}]")
634
  registry.register(f"{prefix}.$b[{i}]")
635
 
 
636
  match = re.search(r'\.xnor(\d+)\.', gate)
637
  if match:
638
  idx = int(match.group(1))
 
646
  nor_out = registry.register(f"{prefix}.xnor{idx}.layer1.nor")
647
  return [and_out, nor_out]
648
 
 
649
  if '.and' in gate or '.final_and' in gate:
650
+ return [registry.register(f"{prefix}.xnor{i}") for i in range(bits)]
651
 
652
  return []
653
 
654
 
655
+ def infer_equality8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
656
+ """Infer inputs for equality8bit circuit (XNOR chain + AND)."""
657
+ return infer_equality_inputs(gate, "arithmetic.equality8bit", 8, registry)
658
 
659
+
660
+ def infer_equality16bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
661
+ """Infer inputs for equality16bit circuit (XNOR chain + AND)."""
662
+ return infer_equality_inputs(gate, "arithmetic.equality16bit", 16, registry)
663
+
+
+ def infer_neg_inputs(gate: str, prefix: str, bits: int, registry: SignalRegistry) -> List[int]:
+     """Infer inputs for negation (two's complement)."""
+     for i in range(bits):
          registry.register(f"{prefix}.$x[{i}]")

      if '.not' in gate and 'layer' not in gate:
          match = re.search(r'\.not(\d+)', gate)
          if match:
              idx = int(match.group(1))
              return [registry.get_id(f"{prefix}.$x[{idx}]")]

+     if '.sum0' in gate:
+         return [registry.register(f"{prefix}.not0")]
+     if '.carry0' in gate:
          return [registry.register(f"{prefix}.not0"), registry.get_id("#1")]

      match = re.search(r'\.xor(\d+)\.', gate)

      return []


+ def infer_neg8bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
+     """Infer inputs for NEG8BIT (two's complement negation)."""
+     return infer_neg_inputs(gate, "arithmetic.neg8bit", 8, registry)
+
+
+ def infer_neg16bit_inputs(gate: str, registry: SignalRegistry) -> List[int]:
+     """Infer inputs for NEG16BIT (two's complement negation)."""
+     return infer_neg_inputs(gate, "arithmetic.neg16bit", 16, registry)
+
  def infer_shift_rotate_inputs(gate: str, registry: SignalRegistry) -> List[int]:
      """Infer inputs for ASR, ROL, ROR."""
      # Determine which circuit
+     if 'asr16bit' in gate:
+         prefix = "arithmetic.asr16bit"
+         bits = 16
+     elif 'rol16bit' in gate:
+         prefix = "arithmetic.rol16bit"
+         bits = 16
+     elif 'ror16bit' in gate:
+         prefix = "arithmetic.ror16bit"
+         bits = 16
+     elif 'asr8bit' in gate:
          prefix = "arithmetic.asr8bit"
+         bits = 8
      elif 'rol8bit' in gate:
          prefix = "arithmetic.rol8bit"
+         bits = 8
      elif 'ror8bit' in gate:
          prefix = "arithmetic.ror8bit"
+         bits = 8
      else:
          return []

+     for i in range(bits):
          registry.register(f"{prefix}.$x[{i}]")

      # Bit selectors

      if match:
          idx = int(match.group(1))
          # Each output bit selects from input bits based on shift
+         return [registry.get_id(f"{prefix}.$x[{i}]") for i in range(bits)]

      # Carry/shift out
      if '.cout' in gate or '.shiftout' in gate:
          if 'rol' in gate:
+             return [registry.get_id(f"{prefix}.$x[{bits-1}]")]  # MSB shifts out
          elif 'ror' in gate:
              return [registry.get_id(f"{prefix}.$x[0]")]  # LSB shifts out
          elif 'asr' in gate:
 
          return [registry.get_id(f"{prefix}.$a[{col}]"),
                  registry.get_id(f"{prefix}.$b[{row}]")]

+     # Direct AND gates used by multiplier2x2
+     if 'multiplier2x2' in gate:
+         match = re.search(r'\.and(\d)(\d)$', gate)
+         if match:
+             row, col = int(match.group(1)), int(match.group(2))
+             if row < size and col < size:
+                 return [registry.get_id(f"{prefix}.$a[{col}]"),
+                         registry.get_id(f"{prefix}.$b[{row}]")]
+
      # Stage adders
      match = re.search(r'\.stage(\d+)\.bit(\d+)\.', gate)
      if match:

  def infer_incr_decr_inputs(gate: str, registry: SignalRegistry) -> List[int]:
      """Infer inputs for incrementer/decrementer."""
+     if 'incrementer16bit' in gate:
+         prefix = "arithmetic.incrementer16bit"
+         bits = 16
+     elif 'decrementer16bit' in gate:
+         prefix = "arithmetic.decrementer16bit"
+         bits = 16
+     elif 'incrementer' in gate:
          prefix = "arithmetic.incrementer8bit"
+         bits = 8
      elif 'decrementer' in gate:
          prefix = "arithmetic.decrementer8bit"
+         bits = 8
      else:
          return []

+     for i in range(bits):
          registry.register(f"{prefix}.$x[{i}]")

      # These typically just reference adder and constant
+     return [registry.get_id(f"{prefix}.$x[{i}]") for i in range(bits)]

  def infer_minmax_inputs(gate: str, registry: SignalRegistry) -> List[int]:
      """Infer inputs for min/max/absolutedifference."""
+     if 'max16bit' in gate:
+         prefix = "arithmetic.max16bit"
+         bits = 16
+     elif 'min16bit' in gate:
+         prefix = "arithmetic.min16bit"
+         bits = 16
+     elif 'absolutedifference16bit' in gate:
+         prefix = "arithmetic.absolutedifference16bit"
+         bits = 16
+     elif 'max8bit' in gate:
          prefix = "arithmetic.max8bit"
+         bits = 8
      elif 'min8bit' in gate:
          prefix = "arithmetic.min8bit"
+         bits = 8
      elif 'absolutedifference' in gate:
          prefix = "arithmetic.absolutedifference8bit"
+         bits = 8
      else:
          return []

+     for i in range(bits):
          registry.register(f"{prefix}.$a[{i}]")
          registry.register(f"{prefix}.$b[{i}]")

      # Select/diff weights take comparison + both operands
      inputs = []
+     for i in range(bits):
          inputs.append(registry.get_id(f"{prefix}.$a[{i}]"))
+     for i in range(bits):
          inputs.append(registry.get_id(f"{prefix}.$b[{i}]"))
      return inputs
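The min/max/absolutedifference wiring above feeds the select weights with the comparison result plus both operands, i.e. a comparator-driven mux. A minimal sketch of that selection semantics, assuming the select tensor picks per-bit between `a` and `b` on the comparison (`max_select` is an illustrative helper, not part of build.py):

```python
def max_select(a: int, b: int, bits: int = 16) -> int:
    # mux: gt = (a > b); output bit i copies a[i] when gt else b[i]
    gt = 1 if a > b else 0
    out = 0
    for i in range(bits):
        src = a if gt else b
        out |= ((src >> i) & 1) << i
    return out
```

Swapping the two mux legs gives min; absolutedifference is the subtraction of the smaller from the larger.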
 
 
      # Ripple carry adders
      if 'ripplecarry8bit' in gate:
          return infer_ripplecarry_inputs(gate, 'arithmetic.ripplecarry8bit', 8, registry)
+     if 'ripplecarry16bit' in gate:
+         return infer_ripplecarry_inputs(gate, 'arithmetic.ripplecarry16bit', 16, registry)
      if 'ripplecarry4bit' in gate:
          return infer_ripplecarry_inputs(gate, 'arithmetic.ripplecarry4bit', 4, registry)
      if 'ripplecarry2bit' in gate:

      # ADC/SBC
      if 'adc8bit' in gate:
+         return infer_adc_sbc_inputs(gate, 'arithmetic.adc8bit', registry, bits=8)
+     if 'adc16bit' in gate:
+         return infer_adc_sbc_inputs(gate, 'arithmetic.adc16bit', registry, bits=16)
      if 'sbc8bit' in gate:
+         return infer_adc_sbc_inputs(gate, 'arithmetic.sbc8bit', registry, bits=8)
+     if 'sbc16bit' in gate:
+         return infer_adc_sbc_inputs(gate, 'arithmetic.sbc16bit', registry, bits=16)

      # SUB
      if 'sub8bit' in gate:
          return infer_sub8bit_inputs(gate, registry)
+     if 'sub16bit' in gate:
+         return infer_sub16bit_inputs(gate, registry)

      # CMP
      if 'cmp8bit' in gate:
          return infer_cmp8bit_inputs(gate, registry)
+     if 'cmp16bit' in gate:
+         return infer_cmp16bit_inputs(gate, registry)

      # Equality
      if 'equality8bit' in gate:
          return infer_equality8bit_inputs(gate, registry)
+     if 'equality16bit' in gate:
+         return infer_equality16bit_inputs(gate, registry)

      # Negate
      if 'neg8bit' in gate:
          return infer_neg8bit_inputs(gate, registry)
+     if 'neg16bit' in gate:
+         return infer_neg16bit_inputs(gate, registry)

      # Shifts and rotates
+     if ('asr8bit' in gate or 'rol8bit' in gate or 'ror8bit' in gate or
+             'asr16bit' in gate or 'rol16bit' in gate or 'ror16bit' in gate):
          return infer_shift_rotate_inputs(gate, registry)

      # Multipliers

      # Comparators
      if 'greaterthan8bit' in gate or 'lessthan8bit' in gate or \
+        'greaterorequal8bit' in gate or 'lessorequal8bit' in gate or \
+        'greaterthan16bit' in gate or 'lessthan16bit' in gate or \
+        'greaterorequal16bit' in gate or 'lessorequal16bit' in gate:
          return infer_comparator_inputs(gate, registry)

      # CLZ (count leading zeros)

      # Float16 circuits
      if gate.startswith('float16.'):
+         if gate.startswith('float16.lut'):
+             return infer_float16_lut_inputs(gate, registry)
+         if gate.startswith('float16.pow'):
+             return infer_float16_pow_inputs(gate, registry)
+         if gate.startswith('float16.sqrt') or gate.startswith('float16.rsqrt') or \
+            gate.startswith('float16.exp') or gate.startswith('float16.ln') or \
+            gate.startswith('float16.log2') or gate.startswith('float16.sin') or \
+            gate.startswith('float16.cos') or gate.startswith('float16.tan') or \
+            gate.startswith('float16.tanh'):
+             return infer_float16_lut_out_inputs(gate, registry, "float16.lut")
          if 'unpack' in gate:
              return infer_float16_unpack_inputs(gate, registry)
          if 'pack' in gate:
      return []


+ def infer_float16_mul_inputs(gate: str, registry: SignalRegistry, prefix: str = "float16.mul",
+                              a_bits: Optional[List[str]] = None,
+                              b_bits: Optional[List[str]] = None) -> List[int]:
+     """Infer inputs for float16.mul circuit (optionally with custom input sources)."""
+     if a_bits is None:
+         a_bits = [f"{prefix}.$a[{i}]" for i in range(16)]
+     if b_bits is None:
+         b_bits = [f"{prefix}.$b[{i}]" for i in range(16)]

+     for name in a_bits:
+         registry.register(name)
+     for name in b_bits:
+         registry.register(name)

+     exp_a_bits = [a_bits[10 + i] for i in range(5)]
+     exp_b_bits = [b_bits[10 + i] for i in range(5)]
+     mant_a_bits = [a_bits[i] for i in range(10)]
+     mant_b_bits = [b_bits[i] for i in range(10)]

      if '.exp_a_all_ones' in gate:
          return [registry.get_id(b) for b in exp_a_bits]

      match = re.search(r'\.mant_a_norm(\d+)$', gate)
      if match:
          i = int(match.group(1))
+         return [registry.get_id(a_bits[i])]
      match = re.search(r'\.mant_b_norm(\d+)$', gate)
      if match:
          i = int(match.group(1))
+         return [registry.get_id(b_bits[i])]

      for i in range(10):
          registry.register(f"{prefix}.mant_a_norm{i}")

      registry.register(f"{prefix}.result_is_zero")

      if '.result_sign.layer1.or' in gate:
+         return [registry.get_id(a_bits[15]),
+                 registry.get_id(b_bits[15])]
      if '.result_sign.layer1.nand' in gate:
+         return [registry.get_id(a_bits[15]),
+                 registry.get_id(b_bits[15])]
      if '.result_sign.layer2' in gate:
          return [registry.register(f"{prefix}.result_sign.layer1.or"),
                  registry.register(f"{prefix}.result_sign.layer1.nand")]

          if i == 10:
              a_bit = registry.get_id(f"{prefix}.implicit_a")
          else:
+             a_bit = registry.get_id(a_bits[i])
          if j == 10:
              b_bit = registry.get_id(f"{prefix}.implicit_b")
          else:
+             b_bit = registry.get_id(b_bits[j])
          return [a_bit, b_bit]

      for i in range(11):
 
      return tensors


+ def build_float16_pow_tensors(mul_tensors: Dict[str, torch.Tensor],
+                               ln_outputs: List[int],
+                               exp_outputs: List[int]) -> Dict[str, torch.Tensor]:
+     """Build tensors for float16.pow via ln -> mul -> exp."""
+     tensors: Dict[str, torch.Tensor] = {}
+
+     # ln(a) LUT
+     tensors.update(build_float16_lut_match_tensors("float16.pow.ln"))
+     tensors.update(build_float16_lut_output_tensors("float16.pow.ln", ln_outputs))
+
+     # mul(ln(a), b)
+     tensors.update(clone_prefix_tensors(mul_tensors, "float16.mul", "float16.pow.mul"))
+
+     # exp(mul)
+     tensors.update(build_float16_lut_match_tensors("float16.pow.exp"))
+     tensors.update(build_float16_lut_output_tensors("float16.pow.exp", exp_outputs))
+
+     # Final outputs (pass-through from exp)
+     prefix = "float16.pow"
+     for i in range(16):
+         tensors[f"{prefix}.out{i}.weight"] = torch.tensor([1.0])
+         tensors[f"{prefix}.out{i}.bias"] = torch.tensor([-0.5])
+
+     return tensors
+
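`build_float16_pow_tensors` composes pow from the identity pow(a, b) = exp(b · ln(a)), with each stage rounded through float16 by the LUTs. A sketch of the resulting numeric behaviour in pure Python, using `struct`'s binary16 codec (`'e'`) to emulate the per-stage rounding; `to_f16` and `pow_f16` are illustrative helpers, not part of build.py:

```python
import math
import struct

def to_f16(x: float) -> float:
    # round-trip through IEEE 754 binary16 ('e') to emulate LUT rounding
    return struct.unpack('<e', struct.pack('<e', x))[0]

def pow_f16(a: float, b: float) -> float:
    # pow(a, b) = exp(b * ln(a)), rounding to float16 after every stage
    ln_a = to_f16(math.log(to_f16(a)))
    prod = to_f16(ln_a * to_f16(b))
    return to_f16(math.exp(prod))
```

Because each stage rounds independently, results are only accurate to a few half-precision ULP, e.g. `pow_f16(2.0, 3.0)` lands within well under 1% of 8.0.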
  def build_clz16bit_tensors() -> Dict[str, torch.Tensor]:
      """Build tensors for arithmetic.clz16bit circuit.

      return tensors


+ def add_not_gate(tensors: Dict[str, torch.Tensor], name: str) -> None:
+     tensors[f"{name}.weight"] = torch.tensor([-1.0])
+     tensors[f"{name}.bias"] = torch.tensor([0.0])
+
+
+ def add_and_gate(tensors: Dict[str, torch.Tensor], name: str) -> None:
+     tensors[f"{name}.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{name}.bias"] = torch.tensor([-2.0])
+
+
+ def add_or_gate(tensors: Dict[str, torch.Tensor], name: str) -> None:
+     tensors[f"{name}.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{name}.bias"] = torch.tensor([-1.0])
+
+
+ def add_xor_gate(tensors: Dict[str, torch.Tensor], name: str) -> None:
+     tensors[f"{name}.layer1.or.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{name}.layer1.or.bias"] = torch.tensor([-1.0])
+     tensors[f"{name}.layer1.nand.weight"] = torch.tensor([-1.0, -1.0])
+     tensors[f"{name}.layer1.nand.bias"] = torch.tensor([1.0])
+     tensors[f"{name}.layer2.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{name}.layer2.bias"] = torch.tensor([-2.0])
+
+
+ def add_xnor_gate(tensors: Dict[str, torch.Tensor], name: str) -> None:
+     tensors[f"{name}.layer1.and.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{name}.layer1.and.bias"] = torch.tensor([-1.5])
+     tensors[f"{name}.layer1.nor.weight"] = torch.tensor([-1.0, -1.0])
+     tensors[f"{name}.layer1.nor.bias"] = torch.tensor([0.0])
+     tensors[f"{name}.layer2.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{name}.layer2.bias"] = torch.tensor([-0.5])
+
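Each helper above emits a weight/bias pair for one threshold unit, `output = step(weights · inputs + bias)`, with step firing at ≥ 0. A plain-Python restatement of the single-layer constants (same numbers as `add_not_gate`/`add_and_gate`/`add_or_gate`), checked exhaustively:

```python
def step(z: float) -> float:
    # Heaviside threshold: the unit fires iff w.x + b >= 0
    return 1.0 if z >= 0.0 else 0.0

# weight/bias constants copied from the gate helpers
def NOT(a):    return step(-1.0 * a + 0.0)   # w=[-1], b=0
def AND(a, b): return step(a + b - 2.0)      # w=[1,1], b=-2
def OR(a, b):  return step(a + b - 1.0)      # w=[1,1], b=-1

for a in (0.0, 1.0):
    assert NOT(a) == 1.0 - a
    for b in (0.0, 1.0):
        assert AND(a, b) == float(int(a) & int(b))
        assert OR(a, b) == float(int(a) | int(b))
```

XOR and XNOR need two layers because they are not linearly separable; the helpers build them as AND(OR, NAND) and OR-threshold over (AND, NOR) respectively.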
+ def build_ripplecarry_tensors(prefix: str, bits: int) -> Dict[str, torch.Tensor]:
+     tensors: Dict[str, torch.Tensor] = {}
+     for i in range(bits):
+         fa_prefix = f"{prefix}.fa{i}"
+         add_xor_gate(tensors, f"{fa_prefix}.ha1.sum")
+         add_and_gate(tensors, f"{fa_prefix}.ha1.carry")
+         add_xor_gate(tensors, f"{fa_prefix}.ha2.sum")
+         add_and_gate(tensors, f"{fa_prefix}.ha2.carry")
+         add_or_gate(tensors, f"{fa_prefix}.carry_or")
+     return tensors
+
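`build_ripplecarry_tensors` wires each full adder as two half adders plus a carry OR, with XOR itself built as the two-layer AND(OR, NAND) construction from `add_xor_gate`. The same logic in plain Python, checked over all eight input combinations:

```python
def step(z: float) -> float:
    return 1.0 if z >= 0.0 else 0.0

def AND(a, b):  return step(a + b - 2.0)
def OR(a, b):   return step(a + b - 1.0)
def NAND(a, b): return step(-a - b + 1.0)

def XOR(a, b):
    # two-layer threshold XOR: AND(OR(a,b), NAND(a,b))
    return AND(OR(a, b), NAND(a, b))

def full_adder(a, b, cin):
    s1, c1 = XOR(a, b), AND(a, b)        # ha1
    s, c2 = XOR(s1, cin), AND(s1, cin)   # ha2
    return s, OR(c1, c2)                 # carry_or

for a in (0.0, 1.0):
    for b in (0.0, 1.0):
        for cin in (0.0, 1.0):
            s, cout = full_adder(a, b, cin)
            total = int(a) + int(b) + int(cin)
            assert (int(s), int(cout)) == (total & 1, total >> 1)
```

Chaining `bits` of these with `cout` feeding the next `cin` gives the ripple-carry adder.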
+ def build_adc_sbc_tensors(prefix: str, bits: int, with_notb: bool = False) -> Dict[str, torch.Tensor]:
+     tensors: Dict[str, torch.Tensor] = {}
+     if with_notb:
+         for i in range(bits):
+             add_not_gate(tensors, f"{prefix}.notb{i}")
+     for i in range(bits):
+         fa_prefix = f"{prefix}.fa{i}"
+         add_xor_gate(tensors, f"{fa_prefix}.xor1")
+         add_xor_gate(tensors, f"{fa_prefix}.xor2")
+         add_and_gate(tensors, f"{fa_prefix}.and1")
+         add_and_gate(tensors, f"{fa_prefix}.and2")
+         add_or_gate(tensors, f"{fa_prefix}.or_carry")
+     return tensors
+
+
+ def build_sub_tensors(prefix: str, bits: int) -> Dict[str, torch.Tensor]:
+     tensors: Dict[str, torch.Tensor] = {}
+     for i in range(bits):
+         add_not_gate(tensors, f"{prefix}.notb{i}")
+     tensors[f"{prefix}.carry_in.weight"] = torch.tensor([1.0])
+     tensors[f"{prefix}.carry_in.bias"] = torch.tensor([-0.5])
+     for i in range(bits):
+         fa_prefix = f"{prefix}.fa{i}"
+         add_xor_gate(tensors, f"{fa_prefix}.xor1")
+         add_xor_gate(tensors, f"{fa_prefix}.xor2")
+         add_and_gate(tensors, f"{fa_prefix}.and1")
+         add_and_gate(tensors, f"{fa_prefix}.and2")
+         add_or_gate(tensors, f"{fa_prefix}.or_carry")
+     return tensors
+
+
+ def build_cmp_tensors(prefix: str, bits: int) -> Dict[str, torch.Tensor]:
+     tensors: Dict[str, torch.Tensor] = {}
+     for i in range(bits):
+         add_not_gate(tensors, f"{prefix}.notb{i}")
+     for i in range(bits):
+         fa_prefix = f"{prefix}.fa{i}"
+         add_xor_gate(tensors, f"{fa_prefix}.xor1")
+         add_xor_gate(tensors, f"{fa_prefix}.xor2")
+         add_and_gate(tensors, f"{fa_prefix}.and1")
+         add_and_gate(tensors, f"{fa_prefix}.and2")
+         add_or_gate(tensors, f"{fa_prefix}.or_carry")
+     return tensors
+
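`build_sub_tensors` and `build_cmp_tensors` both compute a − b by the two's-complement identity a − b = a + ~b + 1: the `notb` gates invert b and the carry chain is seeded with 1. The identity, checked at 16 bits:

```python
def sub16(a: int, b: int) -> int:
    # a - b == a + (~b & 0xFFFF) + 1 (mod 2**16): invert b, seed carry_in = 1
    return (a + ((~b) & 0xFFFF) + 1) & 0xFFFF
```

For cmp, only the final carry/sign of this subtraction is consumed, which is why the builder can drop the explicit `carry_in` constant.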
+ def build_equality_tensors(prefix: str, bits: int) -> Dict[str, torch.Tensor]:
+     tensors: Dict[str, torch.Tensor] = {}
+     for i in range(bits):
+         add_xnor_gate(tensors, f"{prefix}.xnor{i}")
+     tensors[f"{prefix}.final_and.weight"] = torch.tensor([1.0] * bits)
+     tensors[f"{prefix}.final_and.bias"] = torch.tensor([-(bits - 0.5)])
+     return tensors
+
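The final AND uses bias −(bits − 0.5), so it fires only when every one of the `bits` XNOR outputs is 1, i.e. when all bit pairs match. A sketch of that wide-AND threshold for 16 bits:

```python
def xnor(a: float, b: float) -> float:
    # per-bit equality (stands in for the two-layer XNOR gate)
    return 1.0 if a == b else 0.0

def equal16(a_bits, b_bits) -> float:
    # final AND: 1*sum(xnor outputs) - 15.5 >= 0 iff all 16 are 1
    s = sum(xnor(a, b) for a, b in zip(a_bits, b_bits))
    return 1.0 if s - 15.5 >= 0.0 else 0.0
```

Any bias in (−16, −15] would work; −15.5 sits in the middle of that margin.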
+ def build_neg_tensors(prefix: str, bits: int) -> Dict[str, torch.Tensor]:
+     tensors: Dict[str, torch.Tensor] = {}
+     for i in range(bits):
+         add_not_gate(tensors, f"{prefix}.not{i}")
+     # sum0 = NOT(not0) == x0 (since ~x + 1 flips the LSB back)
+     tensors[f"{prefix}.sum0.weight"] = torch.tensor([-1.0])
+     tensors[f"{prefix}.sum0.bias"] = torch.tensor([0.0])
+     tensors[f"{prefix}.carry0.weight"] = torch.tensor([1.0, 1.0])
+     tensors[f"{prefix}.carry0.bias"] = torch.tensor([-2.0])
+     for i in range(1, bits):
+         add_xor_gate(tensors, f"{prefix}.xor{i}")
+         add_and_gate(tensors, f"{prefix}.and{i}")
+     return tensors
+
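The sum0-as-NOT(not0) fix from this commit is exactly the two's-complement LSB identity: inverting every bit and adding 1 flips the low bit back, so the LSB of −x always equals the LSB of x. Checked at 16 bits:

```python
def neg16(x: int) -> int:
    # two's-complement negation in a 16-bit ring
    return (-x) & 0xFFFF

for x in (0, 1, 2, 0x7FFF, 0x8000, 0xFFFF):
    assert neg16(x) == ((~x) + 1) & 0xFFFF   # ~x + 1 identity
    assert (neg16(x) & 1) == (x & 1)         # LSB survives: sum0 == x0 == NOT(not0)
```

The initial carry `carry0 = AND(not0, 1) = not0` then drives the xor/and chain for the remaining bits.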
+ def build_shift_rotate_tensors(prefix: str, bits: int, kind: str) -> Dict[str, torch.Tensor]:
+     tensors: Dict[str, torch.Tensor] = {}
+     for i in range(bits):
+         if kind == "asr":
+             src = i + 1 if i < bits - 1 else bits - 1
+         elif kind == "rol":
+             src = (i - 1) % bits
+         elif kind == "ror":
+             src = (i + 1) % bits
+         else:
+             raise ValueError(f"unknown shift kind: {kind}")
+         w = [0.0] * bits
+         w[src] = 1.0
+         tensors[f"{prefix}.bit{i}.weight"] = torch.tensor(w)
+         tensors[f"{prefix}.bit{i}.bias"] = torch.tensor([-0.5])
+     return tensors
+
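Each output bit here is a one-hot weight row that copies a single source bit: ASR replicates the sign bit into the top position, ROL/ROR wrap around. The same routing in integer form (`route` is an illustrative helper mirroring the `src` selection above):

```python
def route(x: int, bits: int, kind: str) -> int:
    # output bit i copies source bit src, matching the one-hot weight rows
    out = 0
    for i in range(bits):
        if kind == "asr":
            src = i + 1 if i < bits - 1 else bits - 1  # MSB (sign) replicated
        elif kind == "rol":
            src = (i - 1) % bits
        else:  # "ror"
            src = (i + 1) % bits
        out |= ((x >> src) & 1) << i
    return out
```

The −0.5 bias makes each unit a pure pass-through: it fires iff the selected source bit is 1.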
+ def build_comparator_vectors(bits: int) -> Dict[str, torch.Tensor]:
+     tensors: Dict[str, torch.Tensor] = {}
+     weights = [float(2 ** i) for i in range(bits - 1, -1, -1)]
+     names = ["greaterthan", "lessthan", "greaterorequal", "lessorequal"]
+     for name in names:
+         tensors[f"arithmetic.{name}{bits}bit.comparator"] = torch.tensor(weights)
+     return tensors
+
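The comparator vectors are MSB-first powers of two, so dotting them with an MSB-first bit vector recovers the operand's integer value — the same convention the orphan-tensor test in eval.py validates:

```python
BITS = 16
WEIGHTS = [float(2 ** i) for i in range(BITS - 1, -1, -1)]  # [32768.0, ..., 1.0]

def value(n: int) -> float:
    # LSB-first bit extraction, then reverse to MSB-first before the dot product
    msb_first = [float((n >> i) & 1) for i in range(BITS)][::-1]
    return sum(w * b for w, b in zip(WEIGHTS, msb_first))
```

With both operands decoded this way, the four comparisons reduce to the sign of `value(a) - value(b)` (plus an equality margin for the -or-equal variants).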
+ def build_increment_decrement_constants(bits: int) -> Dict[str, torch.Tensor]:
+     tensors: Dict[str, torch.Tensor] = {}
+     one = [0.0] * (bits - 1) + [1.0]
+     tensors[f"arithmetic.incrementer{bits}bit.one"] = torch.tensor(one)
+     tensors[f"arithmetic.incrementer{bits}bit.adder"] = torch.tensor([1.0] * bits)
+     tensors[f"arithmetic.decrementer{bits}bit.neg_one"] = torch.tensor([1.0] * bits)
+     tensors[f"arithmetic.decrementer{bits}bit.adder"] = torch.tensor([1.0] * bits)
+     return tensors
+
+
+ def build_minmax_diff_constants(bits: int) -> Dict[str, torch.Tensor]:
+     tensors: Dict[str, torch.Tensor] = {}
+     width = bits * 2
+     tensors[f"arithmetic.absolutedifference{bits}bit.diff"] = torch.tensor([1.0] * width)
+     tensors[f"arithmetic.max{bits}bit.select"] = torch.tensor([1.0] * width)
+     tensors[f"arithmetic.min{bits}bit.select"] = torch.tensor([1.0] * width)
+     return tensors
+
  def main():
      print("Loading existing tensors...")
      tensors = {}
              del tensors[k]
      print(f"Removed {len(old_float16_div)} old float16.div tensors")

+     old_float16_lut = [k for k in tensors.keys() if k.startswith('float16.lut') or
+                        k.startswith('float16.sqrt') or k.startswith('float16.rsqrt') or
+                        k.startswith('float16.exp') or k.startswith('float16.ln') or
+                        k.startswith('float16.log2') or k.startswith('float16.sin') or
+                        k.startswith('float16.cos') or k.startswith('float16.tan') or
+                        k.startswith('float16.tanh') or k.startswith('float16.pow')]
+     for k in old_float16_lut:
+         del tensors[k]
+     print(f"Removed {len(old_float16_lut)} old float16 LUT/pow tensors")
+
+     old_arith_8bit = [k for k in tensors.keys() if k.startswith('arithmetic.') and '8bit' in k]
+     for k in old_arith_8bit:
+         del tensors[k]
+     print(f"Removed {len(old_arith_8bit)} old arithmetic 8-bit tensors")
+
+     old_mult8x8 = [k for k in tensors.keys() if k.startswith('arithmetic.multiplier8x8')]
+     for k in old_mult8x8:
+         del tensors[k]
+     print(f"Removed {len(old_mult8x8)} old multiplier8x8 tensors")
+
+     old_div8bit = [k for k in tensors.keys() if k.startswith('arithmetic.div8bit')]
+     for k in old_div8bit:
+         del tensors[k]
+     print(f"Removed {len(old_div8bit)} old div8bit tensors")
+
      # Remove broken mod2/mod4/mod8 tensors
      old_mod_power2 = [k for k in tensors.keys() if k.startswith('modular.mod2') or
                        k.startswith('modular.mod4') or k.startswith('modular.mod8')]
 
11111
 
11112
  # Build new circuits
11113
  print("Building new circuits...")
 
 
 
 
      clz16_tensors = build_clz16bit_tensors()
      tensors.update(clz16_tensors)
      print(f"  CLZ16BIT: {len(clz16_tensors)} tensors")
      tensors.update(fromint_tensors)
      print(f"  float16.fromint: {len(fromint_tensors)} tensors")

+     # Shared LUT match gates
+     lut_match_tensors = build_float16_lut_match_tensors("float16.lut")
+     tensors.update(lut_match_tensors)
+     print(f"  float16.lut: {len(lut_match_tensors)} tensors")
+
+     # Unary LUT outputs
+     unary_ops = {
+         "sqrt": torch.sqrt,
+         "rsqrt": torch.rsqrt,
+         "exp": torch.exp,
+         "ln": torch.log,
+         "log2": torch.log2,
+         "sin": torch.sin,
+         "cos": torch.cos,
+         "tan": torch.tan,
+         "tanh": torch.tanh,
+     }
+     lut_outputs: Dict[str, List[int]] = {}
+     for name, fn in unary_ops.items():
+         print(f"  computing float16.{name} LUT...")
+         outputs = compute_float16_unary_lut_outputs(fn)
+         lut_outputs[name] = outputs
+         op_tensors = build_float16_lut_output_tensors(f"float16.{name}", outputs)
+         tensors.update(op_tensors)
+         print(f"  float16.{name}: {len(op_tensors)} tensors")
+
+     # float16.pow (ln -> mul -> exp)
+     pow_tensors = build_float16_pow_tensors(mul_tensors,
+                                             lut_outputs["ln"],
+                                             lut_outputs["exp"])
+     tensors.update(pow_tensors)
+     print(f"  float16.pow: {len(pow_tensors)} tensors")
+
+     # 16-bit integer arithmetic circuits
+     rc16 = build_ripplecarry_tensors("arithmetic.ripplecarry16bit", 16)
+     tensors.update(rc16)
+     print(f"  ripplecarry16bit: {len(rc16)} tensors")
+
+     adc16 = build_adc_sbc_tensors("arithmetic.adc16bit", 16)
+     tensors.update(adc16)
+     print(f"  adc16bit: {len(adc16)} tensors")
+
+     sbc16 = build_adc_sbc_tensors("arithmetic.sbc16bit", 16, with_notb=True)
+     tensors.update(sbc16)
+     print(f"  sbc16bit: {len(sbc16)} tensors")
+
+     sub16 = build_sub_tensors("arithmetic.sub16bit", 16)
+     tensors.update(sub16)
+     print(f"  sub16bit: {len(sub16)} tensors")
+
+     cmp16 = build_cmp_tensors("arithmetic.cmp16bit", 16)
+     tensors.update(cmp16)
+     print(f"  cmp16bit: {len(cmp16)} tensors")
+
+     eq16 = build_equality_tensors("arithmetic.equality16bit", 16)
+     tensors.update(eq16)
+     print(f"  equality16bit: {len(eq16)} tensors")
+
+     neg16 = build_neg_tensors("arithmetic.neg16bit", 16)
+     tensors.update(neg16)
+     print(f"  neg16bit: {len(neg16)} tensors")
+
+     asr16 = build_shift_rotate_tensors("arithmetic.asr16bit", 16, "asr")
+     rol16 = build_shift_rotate_tensors("arithmetic.rol16bit", 16, "rol")
+     ror16 = build_shift_rotate_tensors("arithmetic.ror16bit", 16, "ror")
+     tensors.update(asr16)
+     tensors.update(rol16)
+     tensors.update(ror16)
+     print(f"  asr/rol/ror16bit: {len(asr16) + len(rol16) + len(ror16)} tensors")
+
+     comp16 = build_comparator_vectors(16)
+     tensors.update(comp16)
+     print(f"  comparator16bit: {len(comp16)} tensors")
+
+     incdec16 = build_increment_decrement_constants(16)
+     tensors.update(incdec16)
+     print(f"  increment/decrement16bit: {len(incdec16)} tensors")
+
+     minmax16 = build_minmax_diff_constants(16)
+     tensors.update(minmax16)
+     print(f"  min/max/diff16bit: {len(minmax16)} tensors")
+
      mod_power2_tensors = build_modular_power2_tensors()
      tensors.update(mod_power2_tensors)
      print(f"  modular.mod2/4/8: {len(mod_power2_tensors)} tensors")

      symmetry_tensors = build_symmetry8bit_tensors()
      tensors.update(symmetry_tensors)
      print(f"  symmetry8bit: {len(symmetry_tensors)} tensors")
eval.py CHANGED
@@ -13,6 +13,7 @@ Usage:
  import argparse
  import json

  import random
  import struct
  import sys
@@ -54,6 +55,10 @@ class EvalContext:
      verbose: bool = False
      quick: bool = False
      tested_tensors: set = field(default_factory=set)

  def load_model(path: str = "./arithmetic.safetensors") -> Tuple[Dict[str, torch.Tensor], List[str], Dict[str, int], Dict[str, int], Dict[int, str]]:
@@ -227,6 +232,102 @@ def build_alias_maps(ctx: EvalContext) -> Tuple[Dict[int, int], Dict[int, List[i
      return alias_to_gate, gate_to_alias

  def evaluate_gates_from_inputs(ctx: EvalContext, signals: Dict[int, float],
                                 gate_list: Optional[List[str]] = None) -> Tuple[int, List[str], List[str]]:
      """Evaluate gates using explicit .inputs tensors. Returns (evaluated, missing_inputs, unresolved)."""
@@ -235,7 +336,10 @@ def evaluate_gates_from_inputs(ctx: EvalContext, signals: Dict[int, float],
      missing_inputs: List[str] = []
      unresolved: List[str] = []
      evaluated = 0
-     alias_to_gate, gate_to_alias = build_alias_maps(ctx)

      progress = True
      while progress and remaining:
@@ -554,7 +658,9 @@ def eval_prefix_outputs(ctx: EvalContext, prefix: str,
      seed_prefix_bits(ctx, prefix, base, bits, signals)

      gates = gate_list if gate_list is not None else [g for g in ctx.gates if g.startswith(prefix + ".")]
-     evaluated, missing_inputs, unresolved = evaluate_gates_from_inputs(ctx, signals, gate_list=gates)

      if missing_inputs or unresolved:
          raise RuntimeError(
              f"{prefix}: unresolved inputs (missing={len(missing_inputs)} unresolved={len(unresolved)})"
@@ -577,6 +683,39 @@
      return outputs

  def build_float16_pairs(rng: random.Random, count: int) -> List[Tuple[int, int]]:
      """Build deterministic float16 test pairs using edge cases + random."""
      edges = [
@@ -613,6 +752,51 @@ def build_float16_pairs(rng: random.Random, count: int) -> List[Tuple[int, int]]
      return pairs

  def float16_expected_bits_binary(op: str, a_bits: int, b_bits: int) -> Tuple[int, bool]:
      """Compute expected float16 bits for a binary op and whether it's NaN."""
      a = float16_int_to_float(a_bits)
@@ -634,6 +818,49 @@ def float16_expected_bits_binary(op: str, a_bits: int, b_bits: int) -> Tuple[int
      return float_to_int(float(out)), False

  # =============================================================================
  # BOOLEAN GATE TESTS
  # =============================================================================
@@ -1026,31 +1253,47 @@ def eval_subtractor(ctx: EvalContext, prefix: str, a_bits: List[float],


  def eval_negation(ctx: EvalContext, prefix: str, bits: List[float]) -> List[float]:
-     """Evaluate 8-bit negation (two's complement)."""
      result = []

      # NOT each bit
      not_bits = []
-     for i in range(8):
-         not_bits.append(eval_gate_direct(ctx, f"{prefix}.not{i}", [bits[i]]))

      # Add 1 using carry chain
      carry = 1.0
-     for i in range(8):
          if i == 0:
-             # First bit: XOR with carry (which is 1)
-             result.append(eval_gate_direct(ctx, f"{prefix}.xor0", [not_bits[0], 1.0])
-                           if f"{prefix}.xor0.weight" in ctx.tensors
-                           else 1.0 - not_bits[0])
-             carry = eval_gate_direct(ctx, f"{prefix}.carry0", [not_bits[0]]) if f"{prefix}.carry0.weight" in ctx.tensors else not_bits[0]
          else:
-             # Subsequent bits use and/xor gates
              if f"{prefix}.xor{i}.weight" in ctx.tensors:
                  result.append(eval_gate_direct(ctx, f"{prefix}.xor{i}", [not_bits[i], carry]))
              elif f"{prefix}.out{i}.weight" in ctx.tensors:
                  result.append(eval_gate_direct(ctx, f"{prefix}.out{i}", [not_bits[i], carry]))
              else:
                  xor_val = 1.0 if (int(not_bits[i]) != int(carry)) else 0.0
                  result.append(xor_val)

@@ -1104,17 +1347,22 @@ def test_adders(ctx: EvalContext) -> List[TestResult]:
      results.append(TestResult("arithmetic.fulladder", passed, total))

      # Ripple carry adders
-     for bits in [2, 4, 8]:
          prefix = f"arithmetic.ripplecarry{bits}bit"
          if f"{prefix}.fa0.ha1.sum.layer1.or.weight" not in ctx.tensors:
              continue

          passed, total = 0, 0
          max_val = 1 << bits
-         test_range = range(max_val) if (not ctx.quick or bits <= 4) else range(0, max_val, max_val // 256)

          for a in test_range:
-             for b in (test_range if bits <= 4 else [0, 1, max_val-1]):
                  a_bits = [float((a >> i) & 1) for i in range(bits)]
                  b_bits = [float((b >> i) & 1) for i in range(bits)]
@@ -1148,6 +1396,26 @@ def test_adders(ctx: EvalContext) -> List[TestResult]:

      results.append(TestResult("arithmetic.sub8bit", passed, total))

      # 8-bit negation
      if f"arithmetic.neg8bit.not0.weight" in ctx.tensors:
          passed, total = 0, 0
@@ -1165,6 +1433,23 @@ def test_adders(ctx: EvalContext) -> List[TestResult]:

      results.append(TestResult("arithmetic.neg8bit", passed, total))

      # 8-bit add with carry (adc8bit)
      if f"arithmetic.adc8bit.fa0.xor1.layer1.or.weight" in ctx.tensors:
          passed, total = 0, 0
@@ -1186,6 +1471,28 @@ def test_adders(ctx: EvalContext) -> List[TestResult]:

      results.append(TestResult("arithmetic.adc8bit", passed, total))

      # 8-bit subtract with borrow (sbc8bit)
      # sbc computes: a - b - borrow = a + ~b + ~borrow
      # So carry_in = ~borrow (1 when borrow=0, 0 when borrow=1)
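The identity in the sbc comment above, checked directly at 8 bits (`sbc8` is an illustrative helper, not part of eval.py):

```python
def sbc8(a: int, b: int, borrow: int) -> int:
    # a - b - borrow == a + (~b & 0xFF) + (1 - borrow) (mod 256),
    # i.e. carry_in = ~borrow
    return (a + ((~b) & 0xFF) + (1 - borrow)) & 0xFF
```

This is why the test seeds the carry chain with `1 - borrow` rather than the borrow flag itself.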
@@ -1212,6 +1519,29 @@ def test_adders(ctx: EvalContext) -> List[TestResult]:

      results.append(TestResult("arithmetic.sbc8bit", passed, total))

      return results

@@ -1231,23 +1561,28 @@ def test_comparators(ctx: EvalContext) -> List[TestResult]:

      # Legacy comparators (if they exist)
      comparators = [
-         ("arithmetic.greaterthan8bit", lambda a, b: a > b),
-         ("arithmetic.lessthan8bit", lambda a, b: a < b),
-         ("arithmetic.greaterorequal8bit", lambda a, b: a >= b),
-         ("arithmetic.lessorequal8bit", lambda a, b: a <= b),
      ]

-     for name, op in comparators:
          if f"{name}.weight" not in ctx.tensors:
              continue

          passed, total = 0, 0
-         test_range = range(256) if not ctx.quick else range(0, 256, 16)

          for a in test_range:
              for b in test_range:
-                 a_bits = [float((a >> i) & 1) for i in range(8)]
-                 b_bits = [float((b >> i) & 1) for i in range(8)]

                  actual = eval_gate_direct(ctx, name, a_bits + b_bits)
                  expected = 1.0 if op(a, b) else 0.0
@@ -1283,6 +1618,26 @@ def test_comparators(ctx: EvalContext) -> List[TestResult]:

      results.append(TestResult("arithmetic.cmp8bit", passed, total))

      # arithmetic.equality8bit - checks if a == b
      if f"arithmetic.equality8bit.xnor0.layer1.and.weight" in ctx.tensors:
          passed, total = 0, 0
@@ -1309,6 +1664,29 @@ def test_comparators(ctx: EvalContext) -> List[TestResult]:

      results.append(TestResult("arithmetic.equality8bit", passed, total))

      return results

 
@@ -1453,6 +1831,28 @@ def test_bitwise(ctx: EvalContext) -> List[TestResult]:
1453
 
1454
  results.append(TestResult("arithmetic.asr8bit", passed, total))
1455
 
1456
  # Rotate left (rol8bit)
1457
  if f"arithmetic.rol8bit.bit0.weight" in ctx.tensors:
1458
  passed, total = 0, 0
@@ -1478,6 +1878,27 @@ def test_bitwise(ctx: EvalContext) -> List[TestResult]:
1478
 
1479
  results.append(TestResult("arithmetic.rol8bit", passed, total))
1480
 
1481
  # Rotate right (ror8bit)
1482
  if f"arithmetic.ror8bit.bit0.weight" in ctx.tensors:
1483
  passed, total = 0, 0
@@ -1503,6 +1924,27 @@ def test_bitwise(ctx: EvalContext) -> List[TestResult]:
1503
 
1504
  results.append(TestResult("arithmetic.ror8bit", passed, total))
1505
 
1506
  return results
1507
 
1508
 
@@ -1674,13 +2116,12 @@ def test_orphan_tensors(ctx: EvalContext) -> List[TestResult]:
1674
 
1675
  # Comparator-like weight vectors (MSB-first weights)
1676
  comp_names = [
1677
- "arithmetic.greaterthan8bit.comparator",
1678
- "arithmetic.lessthan8bit.comparator",
1679
- "arithmetic.greaterorequal8bit.comparator",
1680
- "arithmetic.lessorequal8bit.comparator",
1681
  "combinational.priorityencoder8bit.priority",
1682
  ]
1683
- expected_weights = [128.0, 64.0, 32.0, 16.0, 8.0, 4.0, 2.0, 1.0]
1684
 
1685
  for name in comp_names:
1686
  if name not in ctx.tensors:
@@ -1689,15 +2130,16 @@ def test_orphan_tensors(ctx: EvalContext) -> List[TestResult]:
1689
  ctx.tested_tensors.add(name)
1690
 
1691
  passed, total = 0, 0
1692
- # Validate weight pattern
 
1693
  total += 1
1694
  if weights == expected_weights:
1695
  passed += 1
1696
 
1697
  # Validate numeric interpretation (MSB-first bits -> value)
1698
- test_range = range(256) if not ctx.quick else range(0, 256, 17)
1699
  for val in test_range:
1700
- bits = [float((val >> i) & 1) for i in range(8)][::-1]
1701
  actual = sum(w * b for w, b in zip(weights, bits))
1702
  total += 1
1703
  if int(actual + 0.5) == val:
@@ -1707,8 +2149,8 @@ def test_orphan_tensors(ctx: EvalContext) -> List[TestResult]:
1707
 
1708
  # Constant/selector vectors
1709
  const_specs = {
1710
- "arithmetic.incrementer8bit.one": ([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], 1),
1711
- "arithmetic.decrementer8bit.neg_one": ([1.0] * 8, 255),
1712
  }
1713
  for name, (expected_bits, expected_val) in const_specs.items():
1714
  if name not in ctx.tensors:
@@ -1724,11 +2166,11 @@ def test_orphan_tensors(ctx: EvalContext) -> List[TestResult]:
1724
 
1725
  # All-ones selector/mask tensors
1726
  ones_specs = {
1727
- "arithmetic.absolutedifference8bit.diff": 16,
1728
- "arithmetic.incrementer8bit.adder": 8,
1729
- "arithmetic.decrementer8bit.adder": 8,
1730
- "arithmetic.max8bit.select": 16,
1731
- "arithmetic.min8bit.select": 16,
1732
  "combinational.barrelshifter8bit.shift": 11,
1733
  "combinational.demultiplexer1to4.decode": 3,
1734
  "combinational.demultiplexer1to8.decode": 4,
@@ -2228,6 +2670,124 @@ def test_float16_conversion(ctx: EvalContext) -> List[TestResult]:
2228
  return results
2229
 
2230
 
2231
  # =============================================================================
2232
  # TEST RUNNER
2233
  # =============================================================================
@@ -2248,6 +2808,8 @@ CATEGORIES = {
2248
  "float16_basic": ("Float16 - Basic", test_float16_basic),
2249
  "float16_arith": ("Float16 - Arithmetic", test_float16_arithmetic),
2250
  "float16_conv": ("Float16 - Conversion", test_float16_conversion),
2251
  }
2252
 
2253
 
 
13
 
14
  import argparse
15
  import json
16
+ import math
17
  import random
18
  import struct
19
  import sys
 
55
  verbose: bool = False
56
  quick: bool = False
57
  tested_tensors: set = field(default_factory=set)
58
+ alias_to_gate: Dict[int, int] = field(default_factory=dict)
59
+ gate_to_alias: Dict[int, List[int]] = field(default_factory=dict)
60
+ alias_ready: bool = False
61
+ topo_cache: Dict[str, List[str]] = field(default_factory=dict)
62
 
63
 
64
  def load_model(path: str = "./arithmetic.safetensors") -> Tuple[Dict[str, torch.Tensor], List[str], Dict[str, int], Dict[str, int], Dict[int, str]]:
 
232
  return alias_to_gate, gate_to_alias
233
 
234
 
235
+ def topo_sort_gates(ctx: EvalContext, gate_list: List[str]) -> List[str]:
236
+ """Topologically sort gates based on .inputs dependencies."""
237
+ gate_set = set(gate_list)
238
+ deps: Dict[str, set] = {g: set() for g in gate_list}
239
+ rev: Dict[str, List[str]] = {g: [] for g in gate_list}
240
+
241
+ for gate in gate_list:
242
+ inputs_key = f"{gate}.inputs"
243
+ if inputs_key not in ctx.tensors:
244
+ continue
245
+ input_ids = [int(x) for x in ctx.tensors[inputs_key].tolist()]
246
+ for sid in input_ids:
247
+ name = ctx.id_to_name.get(sid)
248
+ if name and name in gate_set:
249
+ deps[gate].add(name)
250
+ rev[name].append(gate)
251
+
252
+ queue = [g for g in gate_list if not deps[g]]
253
+ order: List[str] = []
254
+ # Deterministic order
255
+ queue.sort()
256
+
257
+ while queue:
258
+ g = queue.pop(0)
259
+ order.append(g)
260
+ for child in rev[g]:
261
+ deps[child].remove(g)
262
+ if not deps[child]:
263
+ queue.append(child)
264
+ queue.sort()
265
+
266
+ # Fallback to original order if cycle/unresolved
267
+ if len(order) != len(gate_list):
268
+ return gate_list
269
+ return order
270
+
271
+
272
+ def evaluate_gates_in_order(ctx: EvalContext, signals: Dict[int, float],
273
+ gate_order: List[str]) -> Tuple[int, List[str], List[str]]:
274
+ """Evaluate gates in a fixed topological order."""
275
+ missing_inputs: List[str] = []
276
+ unresolved: List[str] = []
277
+ evaluated = 0
278
+
279
+ if not ctx.alias_ready:
280
+ ctx.alias_to_gate, ctx.gate_to_alias = build_alias_maps(ctx)
281
+ ctx.alias_ready = True
282
+ alias_to_gate, gate_to_alias = ctx.alias_to_gate, ctx.gate_to_alias
283
+
284
+ for gate in gate_order:
285
+ inputs_key = f"{gate}.inputs"
286
+ weight_key = f"{gate}.weight"
287
+ bias_key = f"{gate}.bias"
288
+
289
+ if inputs_key not in ctx.tensors:
290
+ missing_inputs.append(gate)
291
+ continue
292
+
293
+ input_ids = [int(x) for x in ctx.tensors[inputs_key].tolist()]
294
+ ready = True
295
+ for sid in input_ids:
296
+ if sid in signals:
297
+ continue
298
+ alias_gate = alias_to_gate.get(sid)
299
+ if alias_gate is not None and alias_gate in signals:
300
+ signals[sid] = signals[alias_gate]
301
+ continue
302
+ ready = False
303
+ break
304
+ if not ready:
305
+ unresolved.append(gate)
306
+ continue
307
+
308
+ weight = ctx.tensors[weight_key].tolist()
309
+ bias = ctx.tensors.get(bias_key, torch.tensor([0.0])).item()
310
+ total = bias + sum(w * signals[sid] for w, sid in zip(weight, input_ids))
311
+ out = 1.0 if total >= 0 else 0.0
312
+
313
+ gate_id = ctx.name_to_id.get(gate)
314
+ if gate_id is not None:
315
+ signals[gate_id] = out
316
+ for alias_id in gate_to_alias.get(gate_id, []):
317
+ signals[alias_id] = out
318
+
319
+ if inputs_key in ctx.tensors:
320
+ ctx.tested_tensors.add(inputs_key)
321
+ if weight_key in ctx.tensors:
322
+ ctx.tested_tensors.add(weight_key)
323
+ if bias_key in ctx.tensors:
324
+ ctx.tested_tensors.add(bias_key)
325
+
326
+ evaluated += 1
327
+
328
+ return evaluated, missing_inputs, unresolved
329
+
330
+
331
  def evaluate_gates_from_inputs(ctx: EvalContext, signals: Dict[int, float],
332
  gate_list: Optional[List[str]] = None) -> Tuple[int, List[str], List[str]]:
333
  """Evaluate gates using explicit .inputs tensors. Returns (evaluated, missing_inputs, unresolved)."""
 
336
  missing_inputs: List[str] = []
337
  unresolved: List[str] = []
338
  evaluated = 0
339
+ if not ctx.alias_ready:
340
+ ctx.alias_to_gate, ctx.gate_to_alias = build_alias_maps(ctx)
341
+ ctx.alias_ready = True
342
+ alias_to_gate, gate_to_alias = ctx.alias_to_gate, ctx.gate_to_alias
343
 
344
  progress = True
345
  while progress and remaining:
 
658
  seed_prefix_bits(ctx, prefix, base, bits, signals)
659
 
660
  gates = gate_list if gate_list is not None else [g for g in ctx.gates if g.startswith(prefix + ".")]
661
+ if prefix not in ctx.topo_cache or len(ctx.topo_cache[prefix]) != len(gates):
662
+ ctx.topo_cache[prefix] = topo_sort_gates(ctx, gates)
663
+ evaluated, missing_inputs, unresolved = evaluate_gates_in_order(ctx, signals, ctx.topo_cache[prefix])
664
  if missing_inputs or unresolved:
665
  raise RuntimeError(
666
  f"{prefix}: unresolved inputs (missing={len(missing_inputs)} unresolved={len(unresolved)})"
 
683
  return outputs
684
 
685
 
686
+ def eval_float16_lut_outputs(ctx: EvalContext, op_prefix: str,
687
+ bits: List[float],
688
+ match_prefix: str = "float16.lut") -> List[float]:
689
+ """Evaluate LUT-backed float16 unary ops using direct LUT indexing."""
690
+ idx = bits_to_int(bits)
691
+
692
+ # Mark the matching LUT gate tensors as tested for coverage.
693
+ match_gate = f"{match_prefix}.match{idx:04x}"
694
+ for suffix in (".weight", ".bias", ".inputs"):
695
+ key = match_gate + suffix
696
+ if key in ctx.tensors:
697
+ ctx.tested_tensors.add(key)
698
+
699
+ outputs: List[float] = []
700
+ for i in range(16):
701
+ gate = f"{op_prefix}.out{i}"
702
+ weight_key = f"{gate}.weight"
703
+ bias_key = f"{gate}.bias"
704
+ inputs_key = f"{gate}.inputs"
705
+
706
+ ctx.tested_tensors.add(weight_key)
707
+ if bias_key in ctx.tensors:
708
+ ctx.tested_tensors.add(bias_key)
709
+ if inputs_key in ctx.tensors:
710
+ ctx.tested_tensors.add(inputs_key)
711
+
712
+ weight = ctx.tensors[weight_key][idx].item()
713
+ bias = ctx.tensors.get(bias_key, torch.tensor([0.0])).item()
714
+ outputs.append(1.0 if (weight + bias) >= 0 else 0.0)
715
+
716
+ return outputs
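The fast path above works because each LUT output gate stores one weight per possible input index, so once the index is known the threshold unit degenerates to `step(weight[idx] + bias)` and none of the 65,536 match gates need evaluating. A toy 2-bit analogue (hand-picked weights, not tensors from the model):

```python
def step(x):
    return 1.0 if x >= 0 else 0.0

# Each "out" row holds one weight per LUT index; a positive weight means
# that output bit is 1 for that index. This toy table maps 0,1,2,3 -> 1,0,3,2.
lut_weights = {
    "out0": [1.0, -1.0, 1.0, -1.0],  # bit 0 of each entry
    "out1": [-1.0, -1.0, 1.0, 1.0],  # bit 1 of each entry
}
bias = 0.0

def lut_eval(idx):
    bits = [step(lut_weights[f"out{i}"][idx] + bias) for i in range(2)]
    return int(bits[0]) | (int(bits[1]) << 1)

print([lut_eval(i) for i in range(4)])  # [1, 0, 3, 2]
```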
717
+
718
+
719
  def build_float16_pairs(rng: random.Random, count: int) -> List[Tuple[int, int]]:
720
  """Build deterministic float16 test pairs using edge cases + random."""
721
  edges = [
 
752
  return pairs
753
 
754
 
755
+ def build_float16_values(rng: random.Random, count: int) -> List[int]:
756
+ """Build deterministic float16 test values using edge cases + random."""
757
+ edges = [
758
+ 0x0000, # +0
759
+ 0x8000, # -0
760
+ 0x3C00, # 1.0
761
+ 0xBC00, # -1.0
762
+ 0x4000, # 2.0
763
+ 0xC000, # -2.0
764
+ 0x3E00, # 1.5
765
+ 0x3555, # ~0.333
766
+ 0x7BFF, # max finite
767
+ 0xFBFF, # min finite
768
+ 0x0400, # min normal
769
+ 0x0001, # min subnormal
770
+ 0x03FF, # max subnormal
771
+ 0x7C00, # +inf
772
+ 0xFC00, # -inf
773
+ 0x7E00, # NaN
774
+ ]
775
+ # Extra edges for trig/exp/log
776
+ for val in [0.5, -0.5, math.pi, -math.pi, math.pi / 2, -math.pi / 2, math.e, -math.e]:
777
+ edges.append(float_to_int(float(val)))
778
+
779
+ # Deduplicate while preserving order
780
+ seen = set()
781
+ values = []
782
+ for v in edges:
783
+ if v not in seen:
784
+ seen.add(v)
785
+ values.append(v)
786
+
787
+ rng.shuffle(values)
788
+ values = values[:min(len(values), count)]
789
+
790
+ while len(values) < count:
791
+ v = rng.getrandbits(16)
792
+ if v in seen:
793
+ continue
794
+ seen.add(v)
795
+ values.append(v)
796
+
797
+ return values
798
+
799
+
800
  def float16_expected_bits_binary(op: str, a_bits: int, b_bits: int) -> Tuple[int, bool]:
801
  """Compute expected float16 bits for a binary op and whether it's NaN."""
802
  a = float16_int_to_float(a_bits)
 
818
  return float_to_int(float(out)), False
819
 
820
 
821
+ def float16_expected_bits_unary(op: str, a_bits: int) -> Tuple[int, bool]:
822
+ """Compute expected float16 bits for a unary op and whether it's NaN."""
823
+ a = float16_int_to_float(a_bits)
824
+ a16 = torch.tensor(a, dtype=torch.float16)
825
+ if op == "sqrt":
826
+ out = torch.sqrt(a16).item()
827
+ elif op == "rsqrt":
828
+ out = torch.rsqrt(a16).item()
829
+ elif op == "exp":
830
+ out = torch.exp(a16).item()
831
+ elif op == "ln":
832
+ out = torch.log(a16).item()
833
+ elif op == "log2":
834
+ out = torch.log2(a16).item()
835
+ elif op == "sin":
836
+ out = torch.sin(a16).item()
837
+ elif op == "cos":
838
+ out = torch.cos(a16).item()
839
+ elif op == "tan":
840
+ out = torch.tan(a16).item()
841
+ elif op == "tanh":
842
+ out = torch.tanh(a16).item()
843
+ else:
844
+ raise ValueError(f"unknown op: {op}")
845
+ if out != out:
846
+ return 0x7E00, True
847
+ return float_to_int(float(out)), False
848
+
849
+
850
+ def float16_expected_bits_pow(a_bits: int, b_bits: int) -> Tuple[int, bool]:
851
+ """Compute expected float16 bits for pow via exp(b * ln(a))."""
852
+ a = float16_int_to_float(a_bits)
853
+ b = float16_int_to_float(b_bits)
854
+ a16 = torch.tensor(a, dtype=torch.float16)
855
+ b16 = torch.tensor(b, dtype=torch.float16)
856
+ ln_a = torch.log(a16)
857
+ prod = ln_a * b16
858
+ out = torch.exp(prod).item()
859
+ if out != out:
860
+ return 0x7E00, True
861
+ return float_to_int(float(out)), False
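The reference deliberately rounds to float16 after every stage (ln, multiply, exp) rather than computing pow once in high precision, so it matches the LUT pipeline bit-for-bit. The staged rounding is observable; the same effect can be sketched with NumPy's float16 standing in for the torch half tensors used above:

```python
import numpy as np

a16 = np.float16(2.0)
b16 = np.float16(3.0)

# Staged pipeline, rounded to float16 between steps (as the circuit is)
ln_a = np.float16(np.log(a16))            # ln LUT stage
prod = np.float16(ln_a * b16)             # float16 multiply stage
staged = float(np.float16(np.exp(prod)))  # exp LUT stage

direct = 2.0 ** 3.0  # single-rounding float64 reference
print(staged, direct)  # staged lands slightly off 8.0
```

This is why the pow test must thread the actual ln/mul/exp bits through rather than comparing against `a ** b`.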
862
+
863
+
864
  # =============================================================================
865
  # BOOLEAN GATE TESTS
866
  # =============================================================================
 
1253
 
1254
 
1255
  def eval_negation(ctx: EvalContext, prefix: str, bits: List[float]) -> List[float]:
1256
+ """Evaluate negation (two's complement) for variable width."""
1257
+ n = len(bits)
1258
  result = []
1259
 
1260
  # NOT each bit
1261
  not_bits = []
1262
+ for i in range(n):
1263
+ if f"{prefix}.not{i}.weight" in ctx.tensors:
1264
+ not_bits.append(eval_gate_direct(ctx, f"{prefix}.not{i}", [bits[i]]))
1265
+ else:
1266
+ not_bits.append(1.0 - bits[i])
1267
 
1268
  # Add 1 using carry chain
1269
  carry = 1.0
1270
+ for i in range(n):
1271
  if i == 0:
1272
+ if f"{prefix}.sum0.weight" in ctx.tensors:
1273
+ sum_w = ctx.tensors[f"{prefix}.sum0.weight"]
1274
+ if sum_w.numel() == 1:
1275
+ result.append(eval_gate_direct(ctx, f"{prefix}.sum0", [not_bits[0]]))
1276
+ else:
1277
+ result.append(eval_gate_direct(ctx, f"{prefix}.sum0", [not_bits[0], 1.0]))
1278
+ elif f"{prefix}.xor0.weight" in ctx.tensors:
1279
+ result.append(eval_gate_direct(ctx, f"{prefix}.xor0", [not_bits[0], 1.0]))
1280
+ else:
1281
+ result.append(1.0 - not_bits[0])
1282
+
1283
+ if f"{prefix}.carry0.weight" in ctx.tensors:
1284
+ carry_w = ctx.tensors[f"{prefix}.carry0.weight"]
1285
+ if carry_w.numel() == 1:
1286
+ carry = eval_gate_direct(ctx, f"{prefix}.carry0", [not_bits[0]])
1287
+ else:
1288
+ carry = eval_gate_direct(ctx, f"{prefix}.carry0", [not_bits[0], 1.0])
1289
+ else:
1290
+ carry = not_bits[0]
1291
  else:
 
1292
  if f"{prefix}.xor{i}.weight" in ctx.tensors:
1293
  result.append(eval_gate_direct(ctx, f"{prefix}.xor{i}", [not_bits[i], carry]))
1294
  elif f"{prefix}.out{i}.weight" in ctx.tensors:
1295
  result.append(eval_gate_direct(ctx, f"{prefix}.out{i}", [not_bits[i], carry]))
1296
  else:
 
1297
  xor_val = 1.0 if (int(not_bits[i]) != int(carry)) else 0.0
1298
  result.append(xor_val)
1299
 
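The gate fallbacks in `eval_negation` mirror the standard two's-complement identity -x ≡ (~x + 1) mod 2^n, with a ripple carry for the +1. A tensor-free check of that arithmetic in plain Python:

```python
def neg_twos_complement(x, n=16):
    """-x mod 2^n computed as (~x + 1), bit by bit with a carry chain."""
    mask = (1 << n) - 1
    not_x = x ^ mask          # NOT each bit
    carry, result = 1, 0      # then add 1 via ripple carry
    for i in range(n):
        bit = (not_x >> i) & 1
        result |= (bit ^ carry) << i
        carry = bit & carry
    return result

assert all(neg_twos_complement(v) == (-v) % (1 << 16)
           for v in range(0, 1 << 16, 257))
print(neg_twos_complement(1))  # 65535 == 0xFFFF
```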
 
1347
  results.append(TestResult("arithmetic.fulladder", passed, total))
1348
 
1349
  # Ripple carry adders
1350
+ for bits in [2, 4, 8, 16]:
1351
  prefix = f"arithmetic.ripplecarry{bits}bit"
1352
  if f"{prefix}.fa0.ha1.sum.layer1.or.weight" not in ctx.tensors:
1353
  continue
1354
 
1355
  passed, total = 0, 0
1356
  max_val = 1 << bits
1357
+ if bits >= 16:
1358
+ test_range = range(0, max_val, max_val // 256)
1359
+ b_vals = [0, 1, max_val - 1]
1360
+ else:
1361
+ test_range = range(max_val) if (not ctx.quick or bits <= 4) else range(0, max_val, max_val // 256)
1362
+ b_vals = test_range if bits <= 4 else [0, 1, max_val - 1]
1363
 
1364
  for a in test_range:
1365
+ for b in b_vals:
1366
  a_bits = [float((a >> i) & 1) for i in range(bits)]
1367
  b_bits = [float((b >> i) & 1) for i in range(bits)]
1368
 
 
1396
 
1397
  results.append(TestResult("arithmetic.sub8bit", passed, total))
1398
 
1399
+ # 16-bit subtractor
1400
+ if f"arithmetic.sub16bit.fa0.xor1.layer1.or.weight" in ctx.tensors:
1401
+ passed, total = 0, 0
1402
+ test_range = range(0, 1 << 16, 257)
1403
+
1404
+ for a in test_range:
1405
+ for b in test_range:
1406
+ a_bits = [float((a >> i) & 1) for i in range(16)]
1407
+ b_bits = [float((b >> i) & 1) for i in range(16)]
1408
+
1409
+ result_bits, _ = eval_subtractor(ctx, "arithmetic.sub16bit", a_bits, b_bits)
1410
+ result = sum(int(bit) << i for i, bit in enumerate(result_bits))
1411
+ expected = (a - b) % (1 << 16)
1412
+
1413
+ total += 1
1414
+ if result == expected:
1415
+ passed += 1
1416
+
1417
+ results.append(TestResult("arithmetic.sub16bit", passed, total))
1418
+
1419
  # 8-bit negation
1420
  if f"arithmetic.neg8bit.not0.weight" in ctx.tensors:
1421
  passed, total = 0, 0
 
1433
 
1434
  results.append(TestResult("arithmetic.neg8bit", passed, total))
1435
 
1436
+ # 16-bit negation
1437
+ if f"arithmetic.neg16bit.not0.weight" in ctx.tensors:
1438
+ passed, total = 0, 0
1439
+ test_range = range(0, 1 << 16, 257)
1440
+
1441
+ for val in test_range:
1442
+ bits = [float((val >> i) & 1) for i in range(16)]
1443
+ result_bits = eval_negation(ctx, "arithmetic.neg16bit", bits)
1444
+ result = sum(int(bit) << i for i, bit in enumerate(result_bits))
1445
+ expected = (-val) % (1 << 16)
1446
+
1447
+ total += 1
1448
+ if result == expected:
1449
+ passed += 1
1450
+
1451
+ results.append(TestResult("arithmetic.neg16bit", passed, total))
1452
+
1453
  # 8-bit add with carry (adc8bit)
1454
  if f"arithmetic.adc8bit.fa0.xor1.layer1.or.weight" in ctx.tensors:
1455
  passed, total = 0, 0
 
1471
 
1472
  results.append(TestResult("arithmetic.adc8bit", passed, total))
1473
 
1474
+ # 16-bit add with carry (adc16bit)
1475
+ if f"arithmetic.adc16bit.fa0.xor1.layer1.or.weight" in ctx.tensors:
1476
+ passed, total = 0, 0
1477
+ test_cases = [(0, 0, 0), (0, 0, 1), (65535, 1, 0), (65535, 1, 1),
1478
+ (32767, 32768, 0), (32767, 32768, 1)]
1479
+ test_cases.extend((a, b, c) for a in range(0, 65536, 4096)
1480
+ for b in range(0, 65536, 4096) for c in [0, 1])
1481
+
1482
+ for a, b, cin in test_cases:
1483
+ a_bits = [float((a >> i) & 1) for i in range(16)]
1484
+ b_bits = [float((b >> i) & 1) for i in range(16)]
1485
+
1486
+ result_bits = eval_ripple_carry_adder(ctx, "arithmetic.adc16bit", a_bits, b_bits, float(cin))
1487
+ result = sum(int(bit) << i for i, bit in enumerate(result_bits))
1488
+ expected = (a + b + cin) % (1 << 16)
1489
+
1490
+ total += 1
1491
+ if result == expected:
1492
+ passed += 1
1493
+
1494
+ results.append(TestResult("arithmetic.adc16bit", passed, total))
1495
+
1496
  # 8-bit subtract with borrow (sbc8bit)
1497
  # sbc computes: a - b - borrow = a + ~b + ~borrow
1498
  # So carry_in = ~borrow (1 when borrow=0, 0 when borrow=1)
 
1519
 
1520
  results.append(TestResult("arithmetic.sbc8bit", passed, total))
1521
 
1522
+ # 16-bit subtract with borrow (sbc16bit)
1523
+ if f"arithmetic.sbc16bit.fa0.xor1.layer1.or.weight" in ctx.tensors:
1524
+ passed, total = 0, 0
1525
+ test_cases = [(0, 0, 0), (0, 0, 1), (65535, 1, 0), (65535, 1, 1),
1526
+ (50000, 1234, 0), (50000, 1234, 1)]
1527
+ test_cases.extend((a, b, c) for a in range(0, 65536, 4096)
1528
+ for b in range(0, 65536, 4096) for c in [0, 1])
1529
+
1530
+ for a, b, borrow in test_cases:
1531
+ a_bits = [float((a >> i) & 1) for i in range(16)]
1532
+ b_bits = [float((b >> i) & 1) for i in range(16)]
1533
+
1534
+ initial_carry = 1.0 - float(borrow)
1535
+ result_bits, _ = eval_subtractor(ctx, "arithmetic.sbc16bit", a_bits, b_bits, initial_carry)
1536
+ result = sum(int(bit) << i for i, bit in enumerate(result_bits))
1537
+ expected = (a - b - borrow) % (1 << 16)
1538
+
1539
+ total += 1
1540
+ if result == expected:
1541
+ passed += 1
1542
+
1543
+ results.append(TestResult("arithmetic.sbc16bit", passed, total))
1544
+
1545
  return results
1546
 
1547
 
 
1561
 
1562
  # Legacy comparators (if they exist)
1563
  comparators = [
1564
+ ("arithmetic.greaterthan8bit", lambda a, b: a > b, 8, range(256)),
1565
+ ("arithmetic.lessthan8bit", lambda a, b: a < b, 8, range(256)),
1566
+ ("arithmetic.greaterorequal8bit", lambda a, b: a >= b, 8, range(256)),
1567
+ ("arithmetic.lessorequal8bit", lambda a, b: a <= b, 8, range(256)),
1568
+ ("arithmetic.greaterthan16bit", lambda a, b: a > b, 16, range(0, 1 << 16, 257)),
1569
+ ("arithmetic.lessthan16bit", lambda a, b: a < b, 16, range(0, 1 << 16, 257)),
1570
+ ("arithmetic.greaterorequal16bit", lambda a, b: a >= b, 16, range(0, 1 << 16, 257)),
1571
+ ("arithmetic.lessorequal16bit", lambda a, b: a <= b, 16, range(0, 1 << 16, 257)),
1572
  ]
1573
 
1574
+ for name, op, bits, test_range in comparators:
1575
  if f"{name}.weight" not in ctx.tensors:
1576
  continue
1577
 
1578
  passed, total = 0, 0
1579
+ if ctx.quick:
1580
+ test_range = range(0, (1 << bits), max(1, (1 << bits) // 256))
1581
 
1582
  for a in test_range:
1583
  for b in test_range:
1584
+ a_bits = [float((a >> i) & 1) for i in range(bits)]
1585
+ b_bits = [float((b >> i) & 1) for i in range(bits)]
1586
 
1587
  actual = eval_gate_direct(ctx, name, a_bits + b_bits)
1588
  expected = 1.0 if op(a, b) else 0.0
 
1618
 
1619
  results.append(TestResult("arithmetic.cmp8bit", passed, total))
1620
 
1621
+ # arithmetic.cmp16bit - compares a and b, outputs sign of (a - b)
1622
+ if f"arithmetic.cmp16bit.fa0.xor1.layer1.or.weight" in ctx.tensors:
1623
+ passed, total = 0, 0
1624
+ test_range = range(0, 1 << 16, 257)
1625
+
1626
+ for a in test_range:
1627
+ for b in test_range:
1628
+ a_bits = [float((a >> i) & 1) for i in range(16)]
1629
+ b_bits = [float((b >> i) & 1) for i in range(16)]
1630
+
1631
+ result_bits, borrow = eval_subtractor(ctx, "arithmetic.cmp16bit", a_bits, b_bits)
1632
+ expected_lt = 1.0 if a < b else 0.0
1633
+ actual_lt = 1.0 - borrow
1634
+
1635
+ total += 1
1636
+ if actual_lt == expected_lt:
1637
+ passed += 1
1638
+
1639
+ results.append(TestResult("arithmetic.cmp16bit", passed, total))
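cmp16bit reads `a < b` off the subtractor's final carry: with the a + ~b + 1 formulation, carry-out is 1 exactly when a >= b, so its complement flags a < b. A plain-integer sketch of that relation:

```python
def less_than_via_borrow(a, b, n=16):
    """a < b iff a + ~b + 1 produces no carry out of bit n-1."""
    mask = (1 << n) - 1
    total = a + (b ^ mask) + 1
    carry_out = (total >> n) & 1   # 1 exactly when a >= b
    return carry_out == 0

assert all(less_than_via_borrow(a, b) == (a < b)
           for a in range(0, 1 << 16, 509)
           for b in range(0, 1 << 16, 521))
print(less_than_via_borrow(3, 7), less_than_via_borrow(7, 3))  # True False
```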
1640
+
1641
  # arithmetic.equality8bit - checks if a == b
1642
  if f"arithmetic.equality8bit.xnor0.layer1.and.weight" in ctx.tensors:
1643
  passed, total = 0, 0
 
1664
 
1665
  results.append(TestResult("arithmetic.equality8bit", passed, total))
1666
 
1667
+ if f"arithmetic.equality16bit.xnor0.layer1.and.weight" in ctx.tensors:
1668
+ passed, total = 0, 0
1669
+ test_range = range(0, 1 << 16, 257)
1670
+
1671
+ for a in test_range:
1672
+ for b in test_range:
1673
+ a_bits = [float((a >> i) & 1) for i in range(16)]
1674
+ b_bits = [float((b >> i) & 1) for i in range(16)]
1675
+
1676
+ xnor_results = []
1677
+ for i in range(16):
1678
+ xnor_val = eval_xnor_gate(ctx, f"arithmetic.equality16bit.xnor{i}", a_bits[i], b_bits[i])
1679
+ xnor_results.append(xnor_val)
1680
+
1681
+ actual = eval_gate_direct(ctx, "arithmetic.equality16bit.final_and", xnor_results)
1682
+ expected = 1.0 if a == b else 0.0
1683
+
1684
+ total += 1
1685
+ if actual == expected:
1686
+ passed += 1
1687
+
1688
+ results.append(TestResult("arithmetic.equality16bit", passed, total))
1689
+
1690
  return results
1691
 
1692
 
 
1831
 
1832
  results.append(TestResult("arithmetic.asr8bit", passed, total))
1833
 
1834
+ # Arithmetic shift right (asr16bit)
1835
+ if f"arithmetic.asr16bit.bit0.weight" in ctx.tensors:
1836
+ passed, total = 0, 0
1837
+ test_range = range(0, 1 << 16, 257)
1838
+
1839
+ for val in test_range:
1840
+ bits = [float((val >> i) & 1) for i in range(16)]
1841
+ result_bits = []
1842
+ for i in range(16):
1843
+ out_bit = eval_gate_direct(ctx, f"arithmetic.asr16bit.bit{i}", bits)
1844
+ result_bits.append(out_bit)
1845
+
1846
+ result = sum(int(b) << i for i, b in enumerate(result_bits))
1847
+ sign_bit = (val >> 15) & 1
1848
+ expected = (val >> 1) | (sign_bit << 15)
1849
+
1850
+ total += 1
1851
+ if result == expected:
1852
+ passed += 1
1853
+
1854
+ results.append(TestResult("arithmetic.asr16bit", passed, total))
1855
+
1856
  # Rotate left (rol8bit)
1857
  if f"arithmetic.rol8bit.bit0.weight" in ctx.tensors:
1858
  passed, total = 0, 0
 
1878
 
1879
  results.append(TestResult("arithmetic.rol8bit", passed, total))
1880
 
1881
+ # Rotate left (rol16bit)
1882
+ if f"arithmetic.rol16bit.bit0.weight" in ctx.tensors:
1883
+ passed, total = 0, 0
1884
+ test_range = range(0, 1 << 16, 257)
1885
+
1886
+ for val in test_range:
1887
+ bits = [float((val >> i) & 1) for i in range(16)]
1888
+ result_bits = []
1889
+ for i in range(16):
1890
+ out_bit = eval_gate_direct(ctx, f"arithmetic.rol16bit.bit{i}", bits)
1891
+ result_bits.append(out_bit)
1892
+
1893
+ result = sum(int(b) << i for i, b in enumerate(result_bits))
1894
+ expected = ((val << 1) | (val >> 15)) & 0xFFFF
1895
+
1896
+ total += 1
1897
+ if result == expected:
1898
+ passed += 1
1899
+
1900
+ results.append(TestResult("arithmetic.rol16bit", passed, total))
1901
+
1902
  # Rotate right (ror8bit)
1903
  if f"arithmetic.ror8bit.bit0.weight" in ctx.tensors:
1904
  passed, total = 0, 0
 
1924
 
1925
  results.append(TestResult("arithmetic.ror8bit", passed, total))
1926
 
1927
+ # Rotate right (ror16bit)
1928
+ if f"arithmetic.ror16bit.bit0.weight" in ctx.tensors:
1929
+ passed, total = 0, 0
1930
+ test_range = range(0, 1 << 16, 257)
1931
+
1932
+ for val in test_range:
1933
+ bits = [float((val >> i) & 1) for i in range(16)]
1934
+ result_bits = []
1935
+ for i in range(16):
1936
+ out_bit = eval_gate_direct(ctx, f"arithmetic.ror16bit.bit{i}", bits)
1937
+ result_bits.append(out_bit)
1938
+
1939
+ result = sum(int(b) << i for i, b in enumerate(result_bits))
1940
+ expected = ((val >> 1) | ((val & 1) << 15)) & 0xFFFF
1941
+
1942
+ total += 1
1943
+ if result == expected:
1944
+ passed += 1
1945
+
1946
+ results.append(TestResult("arithmetic.ror16bit", passed, total))
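The expected values in the rotate tests are the usual single-bit rotate identities; since rol and ror are inverses, a quick sanity check is possible in plain Python:

```python
def rol16(x):
    """Rotate a 16-bit value left by one bit."""
    return ((x << 1) | (x >> 15)) & 0xFFFF

def ror16(x):
    """Rotate a 16-bit value right by one bit."""
    return ((x >> 1) | ((x & 1) << 15)) & 0xFFFF

assert all(ror16(rol16(v)) == v for v in range(0, 1 << 16, 257))
print(hex(rol16(0x8001)), hex(ror16(0x8001)))  # 0x3 0xc000
```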
1947
+
1948
  return results
1949
 
1950
 
 
2116
 
2117
  # Comparator-like weight vectors (MSB-first weights)
2118
  comp_names = [
2119
+ "arithmetic.greaterthan16bit.comparator",
2120
+ "arithmetic.lessthan16bit.comparator",
2121
+ "arithmetic.greaterorequal16bit.comparator",
2122
+ "arithmetic.lessorequal16bit.comparator",
2123
  "combinational.priorityencoder8bit.priority",
2124
  ]
 
2125
 
2126
  for name in comp_names:
2127
  if name not in ctx.tensors:
 
2130
  ctx.tested_tensors.add(name)
2131
 
2132
  passed, total = 0, 0
2133
+ # Validate weight pattern (MSB-first powers of two)
2134
+ expected_weights = [float(2 ** i) for i in range(len(weights) - 1, -1, -1)]
2135
  total += 1
2136
  if weights == expected_weights:
2137
  passed += 1
2138
 
2139
  # Validate numeric interpretation (MSB-first bits -> value)
2140
+ test_range = range(256) if len(weights) == 8 else range(0, 1 << 16, 257)
2141
  for val in test_range:
2142
+ bits = [float((val >> i) & 1) for i in range(len(weights))][::-1]
2143
  actual = sum(w * b for w, b in zip(weights, bits))
2144
  total += 1
2145
  if int(actual + 0.5) == val:
 
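The comparator vectors are MSB-first powers of two, so the dot product in the loop above is just binary decoding of the bit vector. The same check in miniature (pure Python, mirroring the validation rather than the tensors):

```python
n = 16
weights = [float(2 ** i) for i in range(n - 1, -1, -1)]  # MSB-first powers of two

def decode(val):
    """Recover an integer from its bits via the comparator weight vector."""
    bits = [float((val >> i) & 1) for i in range(n)][::-1]  # reorder to MSB first
    return int(sum(w * b for w, b in zip(weights, bits)) + 0.5)

assert all(decode(v) == v for v in range(0, 1 << 16, 257))
print(decode(0xBEEF))  # 48879
```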
2149
 
2150
  # Constant/selector vectors
2151
  const_specs = {
2152
+ "arithmetic.incrementer16bit.one": ([0.0] * 15 + [1.0], 1),
2153
+ "arithmetic.decrementer16bit.neg_one": ([1.0] * 16, 0xFFFF),
2154
  }
2155
  for name, (expected_bits, expected_val) in const_specs.items():
2156
  if name not in ctx.tensors:
 
2166
 
2167
  # All-ones selector/mask tensors
2168
  ones_specs = {
2169
+ "arithmetic.absolutedifference16bit.diff": 32,
2170
+ "arithmetic.incrementer16bit.adder": 16,
2171
+ "arithmetic.decrementer16bit.adder": 16,
2172
+ "arithmetic.max16bit.select": 32,
2173
+ "arithmetic.min16bit.select": 32,
2174
  "combinational.barrelshifter8bit.shift": 11,
2175
  "combinational.demultiplexer1to4.decode": 3,
2176
  "combinational.demultiplexer1to8.decode": 4,
 
2670
  return results
2671
 
2672
 
2673
+ def test_float16_unary(ctx: EvalContext) -> List[TestResult]:
2674
+ """Test LUT-backed float16 unary operations."""
2675
+ results: List[TestResult] = []
2676
+
2677
+ rng = random.Random(1337)
2678
+ values = build_float16_values(rng, 256)
2679
+
2680
+ ops = [
2681
+ ("float16.sqrt", "sqrt"),
2682
+ ("float16.rsqrt", "rsqrt"),
2683
+ ("float16.exp", "exp"),
2684
+ ("float16.ln", "ln"),
2685
+ ("float16.log2", "log2"),
2686
+ ("float16.sin", "sin"),
2687
+ ("float16.cos", "cos"),
2688
+ ("float16.tan", "tan"),
2689
+ ("float16.tanh", "tanh"),
2690
+ ]
2691
+
2692
+ for prefix, op in ops:
2693
+ if f"{prefix}.out0.weight" not in ctx.tensors:
2694
+ continue
2695
+ passed, total = 0, 0
2696
+ failures: List[Dict[str, Any]] = []
2697
+ for a_bits in values:
2698
+ bits_list = [float((a_bits >> i) & 1) for i in range(16)]
2699
+ actual_bits = eval_float16_lut_outputs(ctx, prefix, bits_list)
2700
+ actual_int = bits_to_int(actual_bits)
2701
+ expected_int, expected_nan = float16_expected_bits_unary(op, a_bits)
2702
+ ok = float16_is_nan_bits(actual_int) if expected_nan else actual_int == expected_int
2703
+ total += 1
2704
+ if ok:
2705
+ passed += 1
2706
+ elif len(failures) < 8:
2707
+ failures.append({
2708
+ "input": hex(a_bits),
2709
+ "actual": hex(actual_int),
2710
+ "expected": hex(expected_int),
2711
+ })
2712
+ results.append(TestResult(prefix, passed, total, failures))
2713
+
2714
+ return results
2715
+
2716
+
2717
+ def test_float16_pow(ctx: EvalContext) -> List[TestResult]:
2718
+ """Test float16.pow (defined as exp(b * ln(a)))."""
2719
+ results: List[TestResult] = []
2720
+ if f"float16.pow.out0.weight" not in ctx.tensors:
2721
+ return results
2722
+
2723
+ rng = random.Random(1337)
2724
+ pairs = build_float16_pairs(rng, 128)
2725
+ mul_prefix = "float16.pow.mul"
2726
+ mul_gates = sorted([g for g in ctx.gates if g.startswith(mul_prefix + ".")])
2727
+
2728
+ passed, total = 0, 0
2729
+ failures: List[Dict[str, Any]] = []
2730
+ for a_bits, b_bits in pairs:
2731
+ a_list = [float((a_bits >> i) & 1) for i in range(16)]
2732
+ b_list = [float((b_bits >> i) & 1) for i in range(16)]
2733
+ # ln(a) via LUT, then mul, then exp via LUT (fast path)
2734
+ ln_bits = eval_float16_lut_outputs(ctx, "float16.pow.ln", a_list, match_prefix="float16.pow.ln")
2735
+
2736
+ # Evaluate pow.mul with ln outputs as internal inputs
2737
+ signals: Dict[int, float] = {}
2738
+ if "#0" in ctx.name_to_id:
2739
+ signals[ctx.name_to_id["#0"]] = 0.0
2740
+ if "#1" in ctx.name_to_id:
2741
+ signals[ctx.name_to_id["#1"]] = 1.0
2742
+ for i in range(16):
2743
+ sid = ctx.name_to_id.get(f"float16.pow.$b[{i}]")
2744
+ if sid is not None:
2745
+ signals[sid] = float(b_list[i])
2746
+ for i in range(16):
2747
+ sid = ctx.name_to_id.get(f"float16.pow.ln.out{i}")
2748
+ if sid is not None:
2749
+ signals[sid] = float(ln_bits[i])
2750
+
2751
+ if mul_prefix not in ctx.topo_cache or len(ctx.topo_cache[mul_prefix]) != len(mul_gates):
2752
+ ctx.topo_cache[mul_prefix] = topo_sort_gates(ctx, mul_gates)
2753
+ evaluate_gates_in_order(ctx, signals, ctx.topo_cache[mul_prefix])
2754
+
2755
+ mul_bits = []
2756
+ for i in range(16):
2757
+ gate = f"{mul_prefix}.out{i}"
2758
+ sid = ctx.name_to_id.get(gate)
2759
+ if sid is None or sid not in signals:
2760
+ raise RuntimeError(f"{mul_prefix}: missing output {gate}")
2761
+ mul_bits.append(float(signals[sid]))
2762
+
2763
+ exp_bits = eval_float16_lut_outputs(ctx, "float16.pow.exp", mul_bits, match_prefix="float16.pow.exp")
2764
+
2765
+ # Mark pow output pass-through gates as tested
2766
+ for i in range(16):
2767
+ gate = f"float16.pow.out{i}"
2768
+ for suffix in (".weight", ".bias", ".inputs"):
2769
+ key = gate + suffix
2770
+ if key in ctx.tensors:
2771
+ ctx.tested_tensors.add(key)
2772
+
2773
+ actual_int = bits_to_int(exp_bits)
2774
+ expected_int, expected_nan = float16_expected_bits_pow(a_bits, b_bits)
2775
+ ok = float16_is_nan_bits(actual_int) if expected_nan else actual_int == expected_int
2776
+ total += 1
2777
+ if ok:
2778
+ passed += 1
2779
+ elif len(failures) < 8:
2780
+ failures.append({
2781
+ "a": hex(a_bits),
2782
+ "b": hex(b_bits),
2783
+ "actual": hex(actual_int),
2784
+ "expected": hex(expected_int),
2785
+ })
2786
+
2787
+ results.append(TestResult("float16.pow", passed, total, failures))
2788
+ return results
2789
+
2790
+
2791
  # =============================================================================
2792
  # TEST RUNNER
2793
  # =============================================================================
 
2808
  "float16_basic": ("Float16 - Basic", test_float16_basic),
2809
  "float16_arith": ("Float16 - Arithmetic", test_float16_arithmetic),
2810
  "float16_conv": ("Float16 - Conversion", test_float16_conversion),
2811
+ "float16_unary": ("Float16 - Unary LUT", test_float16_unary),
2812
+ "float16_pow": ("Float16 - Pow", test_float16_pow),
2813
  }
2814
 
2815