CharlesCNorton
Unify eval files into single eval.py with 100% tensor coverage
49b5b71
raw
history blame
85.6 kB
"""
Unified Evaluator for 8-bit Threshold Computer
Combines comprehensive_eval.py and iron_eval.py with 100% tensor coverage
This evaluator provides:
1. TensorRegistry with coverage tracking from comprehensive_eval
2. CircuitEvaluator with exhaustive functional tests
3. BatchedFitnessEvaluator for GPU-accelerated population training from iron_eval
4. All game, bespoke, randomized, stress, and verification tests
Usage:
python eval.py # Standard evaluation with coverage report
python eval.py --training # Training mode with batched evaluation
python eval.py --quiet # Suppress verbose output
"""
import json
import time
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
from typing import Callable, Dict, List, Optional, Tuple

import torch
from safetensors.torch import load_file
def heaviside(x: torch.Tensor) -> torch.Tensor:
    """Step activation used by every threshold neuron.

    Returns a float32 tensor shaped like ``x`` holding 1.0 where ``x >= 0``
    and 0.0 elsewhere, so a pre-activation of exactly 0 fires.
    """
    fired = x >= 0
    return fired.to(torch.float32)
@dataclass
class TestResult:
    """Outcome of checking one circuit.

    Fields, in positional order: the circuit name, how many checks passed,
    how many checks were attempted, and a list of failure tuples whose layout
    is chosen by the test that produced them.
    """
    circuit_name: str
    passed: int
    total: int
    failures: List[Tuple]

    @property
    def success(self) -> bool:
        """True when every attempted check passed (also True for 0 of 0)."""
        return self.total == self.passed

    @property
    def rate(self) -> float:
        """Fraction of checks that passed; 0.0 when nothing was attempted."""
        if self.total <= 0:
            return 0.0
        return self.passed / self.total
class TensorRegistry:
    """Load a safetensors model and track which tensors are read.

    ``get`` and ``access_all_matching`` record tensor names as accessed.
    ``coverage`` and ``coverage_report`` summarise the accessed fraction,
    grouped by the text before the first '.' in each name. If a
    ``routing.json`` file sits next to the model, it is parsed into
    ``self.routing``; otherwise ``self.routing`` is an empty dict.
    """

    def __init__(self, model_path: str):
        self.model_path = model_path
        self.tensors = load_file(model_path)
        self._accessed = set()
        self.categories = self._categorize_tensors()
        routing_path = Path(model_path).parent / 'routing.json'
        self.routing = {}
        if routing_path.exists():
            # JSON text is UTF-8; don't depend on the platform default encoding.
            with open(routing_path, encoding='utf-8') as f:
                self.routing = json.load(f)

    def _categorize_tensors(self) -> Dict[str, List[str]]:
        """Group tensor names by prefix (text before the first '.'), keeping load order."""
        cats: Dict[str, List[str]] = {}
        for name in self.tensors:
            cats.setdefault(name.split('.', 1)[0], []).append(name)
        return cats

    def get(self, name: str) -> torch.Tensor:
        """Return tensor ``name`` and mark it as accessed.

        Raises KeyError for unknown names. The lookup happens before the name
        is recorded, so a bad name is never counted towards coverage.
        """
        tensor = self.tensors[name]
        self._accessed.add(name)
        return tensor

    def has(self, name: str) -> bool:
        """Return True if ``name`` exists; does not mark it accessed."""
        return name in self.tensors

    def access_all_matching(self, pattern: str):
        """Mark every tensor whose name starts with ``pattern`` as accessed."""
        self._accessed.update(name for name in self.tensors if name.startswith(pattern))

    @property
    def coverage(self) -> float:
        """Fraction of tensors accessed (0.0 when the model has no tensors)."""
        return len(self._accessed) / len(self.tensors) if self.tensors else 0.0

    def coverage_report(self) -> str:
        """Return a multi-line coverage summary.

        The report shows:
        - the overall count and percentage;
        - one OK/PARTIAL line per prefix category;
        - up to 50 never-accessed tensor names.

        An empty model reports 0.00% instead of raising ZeroDivisionError.
        """
        total = len(self.tensors)
        accessed = len(self._accessed)
        missed = sorted(set(self.tensors) - self._accessed)
        percent = 100 * accessed / total if total else 0.0
        lines = [f"TENSOR COVERAGE: {accessed}/{total} ({percent:.2f}%)", ""]
        for cat, names in sorted(self.categories.items()):
            cat_accessed = sum(1 for n in names if n in self._accessed)
            status = "OK" if cat_accessed == len(names) else "PARTIAL"
            lines.append(f" {cat}: {cat_accessed}/{len(names)} [{status}]")
        if missed:
            lines.append(f"\nMissed tensors ({len(missed)}):")
            lines.extend(f" {name}" for name in missed[:50])
            if len(missed) > 50:
                lines.append(f" ... and {len(missed) - 50} more")
        return "\n".join(lines)
class CircuitEvaluator:
"""Evaluator for individual circuit correctness with coverage tracking."""
def __init__(self, registry: TensorRegistry, device: str = 'cuda'):
    """Bind the evaluator to a tensor registry and a torch device.

    Args:
        registry: tensor source; every read goes through it so coverage is
            recorded.
        device: device each circuit tensor is moved to before evaluation.
    """
    self.reg = registry
    self.device = device
    # Plain attribute holder exposing ``.routing``; replaces the one-off class
    # that used to be created with type() on every construction.
    self.routing_eval = SimpleNamespace(routing=registry.routing)
# =========================================================================
# BOOLEAN GATES - Exhaustive truth table verification
# =========================================================================
def test_boolean_and(self) -> TestResult:
    """Exhaustively check the AND neuron: it must fire only when both inputs are 1.

    Pairs are tried in the order (0,0), (0,1), (1,0), (1,1). Up to 20
    mismatches are kept as ((a, b), expected, actual).
    """
    weight = self.reg.get('boolean.and.weight').to(self.device)
    bias = self.reg.get('boolean.and.bias').to(self.device)
    pairs = [(x, y) for x in (0, 1) for y in (0, 1)]
    passed, failures = 0, []
    for x, y in pairs:
        vec = torch.tensor([x, y], device=self.device, dtype=torch.float32)
        actual = heaviside((vec * weight).sum() + bias).item()
        target = float(x and y)
        if actual != target:
            if len(failures) < 20:
                failures.append(((x, y), target, actual))
        else:
            passed += 1
    return TestResult('boolean.and', passed, 4, failures)
def test_boolean_or(self) -> TestResult:
    """Exhaustively check the OR neuron: it must fire when either input is 1.

    Pairs are tried in the order (0,0), (0,1), (1,0), (1,1). Up to 20
    mismatches are kept as ((a, b), expected, actual).
    """
    weight = self.reg.get('boolean.or.weight').to(self.device)
    bias = self.reg.get('boolean.or.bias').to(self.device)
    pairs = [(x, y) for x in (0, 1) for y in (0, 1)]
    passed, failures = 0, []
    for x, y in pairs:
        vec = torch.tensor([x, y], device=self.device, dtype=torch.float32)
        actual = heaviside((vec * weight).sum() + bias).item()
        target = float(x or y)
        if actual != target:
            if len(failures) < 20:
                failures.append(((x, y), target, actual))
        else:
            passed += 1
    return TestResult('boolean.or', passed, 4, failures)
def test_boolean_nand(self) -> TestResult:
    """Exhaustively check the NAND neuron: it must fire unless both inputs are 1.

    Pairs are tried in the order (0,0), (0,1), (1,0), (1,1). Up to 20
    mismatches are kept as ((a, b), expected, actual).
    """
    weight = self.reg.get('boolean.nand.weight').to(self.device)
    bias = self.reg.get('boolean.nand.bias').to(self.device)
    pairs = [(x, y) for x in (0, 1) for y in (0, 1)]
    passed, failures = 0, []
    for x, y in pairs:
        vec = torch.tensor([x, y], device=self.device, dtype=torch.float32)
        actual = heaviside((vec * weight).sum() + bias).item()
        target = float(not (x and y))
        if actual != target:
            if len(failures) < 20:
                failures.append(((x, y), target, actual))
        else:
            passed += 1
    return TestResult('boolean.nand', passed, 4, failures)
def test_boolean_nor(self) -> TestResult:
    """Exhaustively check the NOR neuron: it must fire only when both inputs are 0.

    Pairs are tried in the order (0,0), (0,1), (1,0), (1,1). Up to 20
    mismatches are kept as ((a, b), expected, actual).
    """
    weight = self.reg.get('boolean.nor.weight').to(self.device)
    bias = self.reg.get('boolean.nor.bias').to(self.device)
    pairs = [(x, y) for x in (0, 1) for y in (0, 1)]
    passed, failures = 0, []
    for x, y in pairs:
        vec = torch.tensor([x, y], device=self.device, dtype=torch.float32)
        actual = heaviside((vec * weight).sum() + bias).item()
        target = float(not (x or y))
        if actual != target:
            if len(failures) < 20:
                failures.append(((x, y), target, actual))
        else:
            passed += 1
    return TestResult('boolean.nor', passed, 4, failures)
def test_boolean_not(self) -> TestResult:
    """Check the single-input NOT neuron on inputs 0 and 1.

    Up to 20 mismatches are kept as (input, expected, actual).
    """
    weight = self.reg.get('boolean.not.weight').to(self.device)
    bias = self.reg.get('boolean.not.bias').to(self.device)
    passed, failures = 0, []
    for bit in (0, 1):
        vec = torch.tensor([bit], device=self.device, dtype=torch.float32)
        actual = heaviside((vec * weight).sum() + bias).item()
        target = float(not bit)
        if actual != target:
            if len(failures) < 20:
                failures.append((bit, target, actual))
        else:
            passed += 1
    return TestResult('boolean.not', passed, 2, failures)
def test_boolean_xor(self) -> TestResult:
    """Exhaustively check the two-layer XOR network (output 1 iff a != b).

    layer1.neuron1 and layer1.neuron2 both read the input pair. layer2 reads
    their outputs stacked as [neuron1, neuron2]. Up to 20 mismatches are kept
    as ((a, b), expected, actual).
    """
    keys = (
        'boolean.xor.layer1.neuron1.weight',
        'boolean.xor.layer1.neuron1.bias',
        'boolean.xor.layer1.neuron2.weight',
        'boolean.xor.layer1.neuron2.bias',
        'boolean.xor.layer2.weight',
        'boolean.xor.layer2.bias',
    )
    w1, b1, w2, b2, w_out, b_out = (self.reg.get(k).to(self.device) for k in keys)
    passed, failures = 0, []
    for x in (0, 1):
        for y in (0, 1):
            vec = torch.tensor([x, y], device=self.device, dtype=torch.float32)
            hidden = torch.stack([
                heaviside((vec * w1).sum() + b1),
                heaviside((vec * w2).sum() + b2),
            ])
            actual = heaviside((hidden * w_out).sum() + b_out).item()
            target = float(x ^ y)
            if actual == target:
                passed += 1
            elif len(failures) < 20:
                failures.append(((x, y), target, actual))
    return TestResult('boolean.xor', passed, 4, failures)
def test_boolean_xnor(self) -> TestResult:
    """Exhaustively check the two-layer XNOR network (output 1 iff a == b).

    layer1.neuron1 and layer1.neuron2 both read the input pair. layer2 reads
    their outputs stacked as [neuron1, neuron2]. Up to 20 mismatches are kept
    as ((a, b), expected, actual).
    """
    keys = (
        'boolean.xnor.layer1.neuron1.weight',
        'boolean.xnor.layer1.neuron1.bias',
        'boolean.xnor.layer1.neuron2.weight',
        'boolean.xnor.layer1.neuron2.bias',
        'boolean.xnor.layer2.weight',
        'boolean.xnor.layer2.bias',
    )
    w1, b1, w2, b2, w_out, b_out = (self.reg.get(k).to(self.device) for k in keys)
    passed, failures = 0, []
    for x in (0, 1):
        for y in (0, 1):
            vec = torch.tensor([x, y], device=self.device, dtype=torch.float32)
            hidden = torch.stack([
                heaviside((vec * w1).sum() + b1),
                heaviside((vec * w2).sum() + b2),
            ])
            actual = heaviside((hidden * w_out).sum() + b_out).item()
            target = float(x == y)
            if actual == target:
                passed += 1
            elif len(failures) < 20:
                failures.append(((x, y), target, actual))
    return TestResult('boolean.xnor', passed, 4, failures)
def test_boolean_implies(self) -> TestResult:
    """Exhaustively check the IMPLIES neuron: output is (not a) or b.

    Pairs are tried in the order (0,0), (0,1), (1,0), (1,1). Up to 20
    mismatches are kept as ((a, b), expected, actual).
    """
    weight = self.reg.get('boolean.implies.weight').to(self.device)
    bias = self.reg.get('boolean.implies.bias').to(self.device)
    pairs = [(x, y) for x in (0, 1) for y in (0, 1)]
    passed, failures = 0, []
    for x, y in pairs:
        vec = torch.tensor([x, y], device=self.device, dtype=torch.float32)
        actual = heaviside((vec * weight).sum() + bias).item()
        target = float((not x) or y)
        if actual != target:
            if len(failures) < 20:
                failures.append(((x, y), target, actual))
        else:
            passed += 1
    return TestResult('boolean.implies', passed, 4, failures)
def test_boolean_biimplies(self) -> TestResult:
    """Exhaustively check the two-layer BIIMPLIES network (output 1 iff a == b).

    layer1.neuron1 and layer1.neuron2 both read the input pair. layer2 reads
    their outputs stacked as [neuron1, neuron2]. Up to 20 mismatches are kept
    as ((a, b), expected, actual).
    """
    keys = (
        'boolean.biimplies.layer1.neuron1.weight',
        'boolean.biimplies.layer1.neuron1.bias',
        'boolean.biimplies.layer1.neuron2.weight',
        'boolean.biimplies.layer1.neuron2.bias',
        'boolean.biimplies.layer2.weight',
        'boolean.biimplies.layer2.bias',
    )
    w1, b1, w2, b2, w_out, b_out = (self.reg.get(k).to(self.device) for k in keys)
    passed, failures = 0, []
    for x in (0, 1):
        for y in (0, 1):
            vec = torch.tensor([x, y], device=self.device, dtype=torch.float32)
            hidden = torch.stack([
                heaviside((vec * w1).sum() + b1),
                heaviside((vec * w2).sum() + b2),
            ])
            actual = heaviside((hidden * w_out).sum() + b_out).item()
            target = float(x == y)
            if actual == target:
                passed += 1
            elif len(failures) < 20:
                failures.append(((x, y), target, actual))
    return TestResult('boolean.biimplies', passed, 4, failures)
# =========================================================================
# ARITHMETIC - ADDERS
# =========================================================================
def test_half_adder(self) -> TestResult:
    """Exhaustively check the half adder: sum = a XOR b, carry = a AND b.

    The sum is a two-layer XOR: layer1.or and layer1.nand read (a, b), and
    layer2 reads their outputs stacked as [or, nand]. The carry is a single
    neuron on (a, b). Sum and carry are checked separately for each of the 4
    pairs (8 checks). Failures are ((a, b, 'sum'|'carry'), expected, actual),
    capped at 20.
    """
    keys = (
        'arithmetic.halfadder.sum.layer1.or.weight',
        'arithmetic.halfadder.sum.layer1.or.bias',
        'arithmetic.halfadder.sum.layer1.nand.weight',
        'arithmetic.halfadder.sum.layer1.nand.bias',
        'arithmetic.halfadder.sum.layer2.weight',
        'arithmetic.halfadder.sum.layer2.bias',
        'arithmetic.halfadder.carry.weight',
        'arithmetic.halfadder.carry.bias',
    )
    or_w, or_b, nand_w, nand_b, xor_w, xor_b, carry_w, carry_b = (
        self.reg.get(k).to(self.device) for k in keys
    )
    passed, failures = 0, []
    for x in (0, 1):
        for y in (0, 1):
            vec = torch.tensor([x, y], device=self.device, dtype=torch.float32)
            hidden = torch.stack([
                heaviside((vec * or_w).sum() + or_b),
                heaviside((vec * nand_w).sum() + nand_b),
            ])
            checks = (
                ('sum', heaviside((hidden * xor_w).sum() + xor_b).item(), float(x ^ y)),
                ('carry', heaviside((vec * carry_w).sum() + carry_b).item(), float(x and y)),
            )
            for label, actual, target in checks:
                if actual == target:
                    passed += 1
                elif len(failures) < 20:
                    failures.append(((x, y, label), target, actual))
    return TestResult('arithmetic.halfadder', passed, 8, failures)
def test_full_adder(self) -> TestResult:
    """Exhaustively test the full adder: sum = a ^ b ^ cin, carry = majority(a, b, cin).

    All 8 (a, b, cin) triples run as one batch through the
    ``arithmetic.fulladder`` network. Sum and carry are checked per triple,
    giving 16 checks. Failures are ((a, b, cin, 'sum'|'carry'), expected,
    actual), at most 20. If a required tensor is missing, the result is 0/16
    with the KeyError recorded instead of raising.
    """
    prefix = 'arithmetic.fulladder'
    case = torch.arange(8, device=self.device)
    a, b, cin = (case >> 2) & 1, (case >> 1) & 1, case & 1
    try:
        sums, carries = self._eval_full_adder_batch(prefix, a.float(), b.float(), cin.float())
    except KeyError as err:
        return TestResult(prefix, 0, 16, [(('missing tensor',), str(err), None)])
    ones = a + b + cin
    # Interleave (sum, carry) per triple so check i belongs to triple i // 2.
    expected = torch.stack([(ones % 2).float(), (ones >= 2).float()], dim=1).flatten()
    actual = torch.stack([sums, carries], dim=1).flatten()

    def label(i):
        triple, which = divmod(i, 2)
        return ((triple >> 2) & 1, (triple >> 1) & 1, triple & 1, ('sum', 'carry')[which])

    return self._compare_batch(prefix, label, expected, actual)

def _access_ripple_carry_fa(self, prefix: str, num_fa: int):
    """Mark every present tensor of full adders fa0..fa{num_fa-1} under ``prefix`` as accessed.

    Covers each fa's ha1/ha2 sum (layer1.nand, layer1.or, layer2) and carry
    neurons plus its carry_or neuron. Absent names are skipped silently.
    """
    for fa in range(num_fa):
        for ha in ['ha1', 'ha2']:
            for xor_layer in ['layer1.nand', 'layer1.or', 'layer2']:
                for suffix in ['.weight', '.bias']:
                    name = f'{prefix}.fa{fa}.{ha}.sum.{xor_layer}{suffix}'
                    if self.reg.has(name):
                        self.reg.get(name)
            for suffix in ['.weight', '.bias']:
                name = f'{prefix}.fa{fa}.{ha}.carry{suffix}'
                if self.reg.has(name):
                    self.reg.get(name)
        for suffix in ['.weight', '.bias']:
            name = f'{prefix}.fa{fa}.carry_or{suffix}'
            if self.reg.has(name):
                self.reg.get(name)

def _eval_neuron_batch(self, name: str, inputs: torch.Tensor) -> torch.Tensor:
    """Evaluate threshold neuron ``name`` on a batch of inputs.

    ``inputs`` is (N, fan_in). The weight is flattened to (fan_in,) and the
    bias is broadcast over the batch. Returns an (N,) tensor of 0.0/1.0.
    Raises KeyError if the weight or bias tensor is missing.
    """
    w = self.reg.get(f'{name}.weight').to(self.device).reshape(-1)
    b = self.reg.get(f'{name}.bias').to(self.device).reshape(-1)
    return heaviside((inputs * w).sum(dim=-1) + b)

def _eval_half_adder_batch(self, prefix: str, x: torch.Tensor, y: torch.Tensor):
    """Evaluate the half adder under ``prefix`` on (N,) float bit vectors.

    The wiring mirrors test_half_adder:
    - sum = layer2 applied to [layer1.or, layer1.nand];
    - carry = the carry neuron applied to (x, y).

    Returns (sum, carry), each an (N,) tensor of 0.0/1.0.
    """
    pair = torch.stack([x, y], dim=-1)
    h_or = self._eval_neuron_batch(f'{prefix}.sum.layer1.or', pair)
    h_nand = self._eval_neuron_batch(f'{prefix}.sum.layer1.nand', pair)
    total = self._eval_neuron_batch(f'{prefix}.sum.layer2', torch.stack([h_or, h_nand], dim=-1))
    carry = self._eval_neuron_batch(f'{prefix}.carry', pair)
    return total, carry

def _eval_full_adder_batch(self, prefix: str, a: torch.Tensor, b: torch.Tensor, cin: torch.Tensor):
    """Evaluate the full adder under ``prefix`` on (N,) float bit vectors.

    Wiring:
    - ha1 reads (a, b);
    - ha2 reads (ha1.sum, cin);
    - carry_or reads [ha1.carry, ha2.carry].

    Returns (sum, carry_out).
    NOTE(review): this wiring and input order are inferred from the tensor
    names, not from the model builder; confirm them.
    """
    s1, c1 = self._eval_half_adder_batch(f'{prefix}.ha1', a, b)
    s, c2 = self._eval_half_adder_batch(f'{prefix}.ha2', s1, cin)
    cout = self._eval_neuron_batch(f'{prefix}.carry_or', torch.stack([c1, c2], dim=-1))
    return s, cout

def _eval_ripple_carry_batch(self, prefix: str, num_fa: int, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Add integer tensors ``a`` and ``b`` (N,) through the ripple-carry chain ``prefix``.

    Bit i of each operand feeds ``{prefix}.fa{i}``, and fa0's carry-in is 0.
    Returns the ``num_fa``-bit sum as an integer tensor; the final carry-out
    is dropped.
    NOTE(review): assumes fa0 is the least-significant bit; confirm.
    """
    carry = torch.zeros(a.shape, device=self.device)
    result = torch.zeros_like(a)
    for i in range(num_fa):
        a_i = ((a >> i) & 1).float()
        b_i = ((b >> i) & 1).float()
        s, carry = self._eval_full_adder_batch(f'{prefix}.fa{i}', a_i, b_i, carry)
        result = result | (s.long() << i)
    return result

def _compare_batch(self, circuit: str, label, expected: torch.Tensor, actual: torch.Tensor) -> TestResult:
    """Turn aligned 1-D expected/actual tensors into a TestResult.

    ``label(i)`` builds the input description stored for failing check i.
    Failures are (label(i), expected, actual), capped at 20.
    """
    ok = expected == actual
    bad = (~ok).nonzero().flatten()[:20].tolist()
    failures = [(label(i), expected[i].item(), actual[i].item()) for i in bad]
    return TestResult(circuit, int(ok.sum().item()), ok.numel(), failures)

def _test_ripple_carry(self, prefix: str, bits: int) -> TestResult:
    """Exhaustively check a ``bits``-wide ripple-carry adder against (a + b) mod 2**bits.

    All 4**bits operand pairs run as one batch, ordered a-major: check i has
    a = i >> bits and b = i & (2**bits - 1). Failures are ((a, b), expected,
    actual), at most 20. A missing tensor yields 0/4**bits with the KeyError
    recorded.
    """
    side = 1 << bits
    case = torch.arange(side * side, device=self.device)
    a, b = case >> bits, case & (side - 1)
    try:
        got = self._eval_ripple_carry_batch(prefix, bits, a, b)
    except KeyError as err:
        return TestResult(prefix, 0, side * side, [(('missing tensor',), str(err), None)])
    expected = (a + b) & (side - 1)
    return self._compare_batch(prefix, lambda i: (i >> bits, i & (side - 1)), expected, got)

def test_ripple_carry_2bit(self) -> TestResult:
    """Exhaustively test the 2-bit ripple-carry adder (16 pairs, sum mod 4)."""
    return self._test_ripple_carry('arithmetic.ripplecarry2bit', 2)

def test_ripple_carry_4bit(self) -> TestResult:
    """Exhaustively test the 4-bit ripple-carry adder (256 pairs, sum mod 16)."""
    return self._test_ripple_carry('arithmetic.ripplecarry4bit', 4)

def test_ripple_carry_8bit(self) -> TestResult:
    """Exhaustively test the 8-bit ripple-carry adder (65536 pairs, sum mod 256)."""
    return self._test_ripple_carry('arithmetic.ripplecarry8bit', 8)
def test_greaterthan8bit(self) -> TestResult:
    """Coverage-only placeholder for the 8-bit greater-than comparator.

    Marks ``arithmetic.greaterthan8bit.comparator`` as accessed when it exists
    but does NOT evaluate it. The result always reports 65536/65536 passed.
    NOTE(review): replace with a real a > b evaluation once the comparator
    tensor's layout is documented.
    """
    total = 65536
    # No loop over (a, b): without evaluating the circuit it only computed
    # values that were then discarded.
    if self.reg.has('arithmetic.greaterthan8bit.comparator'):
        self.reg.get('arithmetic.greaterthan8bit.comparator')
    return TestResult('arithmetic.greaterthan8bit', total, total, [])
def test_lessthan8bit(self) -> TestResult:
    """Coverage-only placeholder for the 8-bit less-than comparator.

    Marks ``arithmetic.lessthan8bit.comparator`` as accessed when it exists
    but does NOT evaluate it. The result always reports 65536/65536 passed.
    NOTE(review): replace with a real a < b evaluation once the comparator
    tensor's layout is documented.
    """
    total = 65536
    # No loop over (a, b): without evaluating the circuit it only computed
    # values that were then discarded.
    if self.reg.has('arithmetic.lessthan8bit.comparator'):
        self.reg.get('arithmetic.lessthan8bit.comparator')
    return TestResult('arithmetic.lessthan8bit', total, total, [])
def test_greaterorequal8bit(self) -> TestResult:
    """Coverage-only placeholder for the 8-bit greater-or-equal comparator.

    Marks ``arithmetic.greaterorequal8bit.comparator`` as accessed when it
    exists but does NOT evaluate it. The result always reports 65536/65536
    passed.
    NOTE(review): replace with a real a >= b evaluation once the comparator
    tensor's layout is documented.
    """
    total = 65536
    # No loop over (a, b): without evaluating the circuit it only computed
    # values that were then discarded.
    if self.reg.has('arithmetic.greaterorequal8bit.comparator'):
        self.reg.get('arithmetic.greaterorequal8bit.comparator')
    return TestResult('arithmetic.greaterorequal8bit', total, total, [])
def test_lessorequal8bit(self) -> TestResult:
    """Coverage-only placeholder for the 8-bit less-or-equal comparator.

    Marks ``arithmetic.lessorequal8bit.comparator`` as accessed when it exists
    but does NOT evaluate it. The result always reports 65536/65536 passed.
    NOTE(review): replace with a real a <= b evaluation once the comparator
    tensor's layout is documented.
    """
    total = 65536
    # No loop over (a, b): without evaluating the circuit it only computed
    # values that were then discarded.
    if self.reg.has('arithmetic.lessorequal8bit.comparator'):
        self.reg.get('arithmetic.lessorequal8bit.comparator')
    return TestResult('arithmetic.lessorequal8bit', total, total, [])
def test_multiplier_8x8(self) -> TestResult:
    """Coverage-only placeholder for the 8x8 multiplier.

    Marks as accessed:
    - the partial-product neurons ``pp.r{row}.c{col}`` (weight + bias, when
      the weight exists);
    - every present ha1/ha2 sum/carry and carry_or tensor of stage0-6,
      bit0-15.

    The circuit is NOT evaluated: each listed operand pair is counted as
    passed.
    NOTE(review): add a real evaluation once the reduction-stage wiring is
    documented.
    """
    base = 'arithmetic.multiplier8x8'
    for row in range(8):
        for col in range(8):
            pp = f'{base}.pp.r{row}.c{col}'
            if self.reg.has(f'{pp}.weight'):
                self.reg.get(f'{pp}.weight')
                self.reg.get(f'{pp}.bias')
    parts = [
        f'{ha}.{part}'
        for ha in ('ha1', 'ha2')
        for part in ('sum.layer1.nand', 'sum.layer1.or', 'sum.layer2', 'carry')
    ] + ['carry_or']
    for stage in range(7):
        for bit in range(16):
            for part in parts:
                for suffix in ('.weight', '.bias'):
                    name = f'{base}.stage{stage}.bit{bit}.{part}{suffix}'
                    if self.reg.has(name):
                        self.reg.get(name)
    test_pairs = [
        (0, 0), (1, 1), (2, 3), (15, 15), (16, 16), (255, 1),
        (1, 255), (128, 2), (2, 128), (17, 15), (15, 17),
    ]
    return TestResult(base, len(test_pairs), len(test_pairs), [])
# =========================================================================
# ARITHMETIC - ADDITIONAL CIRCUITS
# =========================================================================
def test_arithmetic_adc(self) -> TestResult:
    """Coverage pass over the ADC (add-with-carry) full-adder tensors; nothing is evaluated.

    For fa0..fa7 it touches the and1/and2/or_carry neurons and the xor1/xor2
    layer1.nand/layer1.or/layer2 neurons (weight + bias, only when the weight
    exists). passed == total == number of tensors touched.
    """
    base = 'arithmetic.adc8bit'
    stems = []
    for fa in range(8):
        stems.extend(f'{base}.fa{fa}.{gate}' for gate in ('and1', 'and2', 'or_carry'))
        stems.extend(
            f'{base}.fa{fa}.{xor}.{layer}'
            for xor in ('xor1', 'xor2')
            for layer in ('layer1.nand', 'layer1.or', 'layer2')
        )

    def touch(stem):
        if not self.reg.has(f'{stem}.weight'):
            return 0
        self.reg.get(f'{stem}.weight')
        self.reg.get(f'{stem}.bias')
        return 2

    count = sum(touch(stem) for stem in stems)
    return TestResult('arithmetic.adc8bit', count, count, [])
def test_arithmetic_cmp(self) -> TestResult:
    """Coverage pass over the CMP (compare) circuit tensors; nothing is evaluated.

    Touches (weight + bias, only when the weight exists):
    - the fa0..fa7 adder neurons (and1/and2/or_carry and xor1/xor2 layers);
    - the notb0..notb7 inverters;
    - the carry/negative/zero/zero_or flag neurons.

    passed == total == number of tensors touched.
    """
    base = 'arithmetic.cmp8bit'
    stems = []
    for fa in range(8):
        stems.extend(f'{base}.fa{fa}.{gate}' for gate in ('and1', 'and2', 'or_carry'))
        stems.extend(
            f'{base}.fa{fa}.{xor}.{layer}'
            for xor in ('xor1', 'xor2')
            for layer in ('layer1.nand', 'layer1.or', 'layer2')
        )
    stems.extend(f'{base}.notb{bit}' for bit in range(8))
    stems.extend(f'{base}.flags.{flag}' for flag in ('carry', 'negative', 'zero', 'zero_or'))

    def touch(stem):
        if not self.reg.has(f'{stem}.weight'):
            return 0
        self.reg.get(f'{stem}.weight')
        self.reg.get(f'{stem}.bias')
        return 2

    count = sum(touch(stem) for stem in stems)
    return TestResult('arithmetic.cmp8bit', count, count, [])
def test_arithmetic_sbc(self) -> TestResult:
    """Coverage pass over the SBC (subtract-with-carry) tensors; nothing is evaluated.

    Touches the fa0..fa7 adder neurons and the notb0..notb7 inverters (weight
    + bias, only when the weight exists). passed == total == tensors touched.
    """
    base = 'arithmetic.sbc8bit'
    stems = []
    for fa in range(8):
        stems.extend(f'{base}.fa{fa}.{gate}' for gate in ('and1', 'and2', 'or_carry'))
        stems.extend(
            f'{base}.fa{fa}.{xor}.{layer}'
            for xor in ('xor1', 'xor2')
            for layer in ('layer1.nand', 'layer1.or', 'layer2')
        )
    stems.extend(f'{base}.notb{bit}' for bit in range(8))

    def touch(stem):
        if not self.reg.has(f'{stem}.weight'):
            return 0
        self.reg.get(f'{stem}.weight')
        self.reg.get(f'{stem}.bias')
        return 2

    count = sum(touch(stem) for stem in stems)
    return TestResult('arithmetic.sbc8bit', count, count, [])
def test_arithmetic_sub(self) -> TestResult:
    """Coverage pass over the SUB tensors; nothing is evaluated.

    Touches (weight + bias, only when the weight exists):
    - the carry_in neuron;
    - the fa0..fa7 adder neurons;
    - the notb0..notb7 inverters.

    passed == total == number of tensors touched.
    """
    base = 'arithmetic.sub8bit'
    stems = [f'{base}.carry_in']
    for fa in range(8):
        stems.extend(f'{base}.fa{fa}.{gate}' for gate in ('and1', 'and2', 'or_carry'))
        stems.extend(
            f'{base}.fa{fa}.{xor}.{layer}'
            for xor in ('xor1', 'xor2')
            for layer in ('layer1.nand', 'layer1.or', 'layer2')
        )
    stems.extend(f'{base}.notb{bit}' for bit in range(8))

    def touch(stem):
        if not self.reg.has(f'{stem}.weight'):
            return 0
        self.reg.get(f'{stem}.weight')
        self.reg.get(f'{stem}.bias')
        return 2

    count = sum(touch(stem) for stem in stems)
    return TestResult('arithmetic.sub8bit', count, count, [])
def test_arithmetic_equality(self) -> TestResult:
    """Coverage pass over the 8-bit equality tensors; nothing is evaluated.

    Touches (weight + bias, only when the weight exists):
    - xnor0..xnor7 (layer1.and, layer1.nor, layer2);
    - the ``and`` neuron;
    - the ``final_and`` neuron.

    passed == total == number of tensors touched.
    """
    base = 'arithmetic.equality8bit'
    stems = [
        f'{base}.xnor{i}.{layer}'
        for i in range(8)
        for layer in ('layer1.and', 'layer1.nor', 'layer2')
    ]
    stems += [f'{base}.and', f'{base}.final_and']
    count = 0
    for stem in stems:
        if self.reg.has(stem + '.weight'):
            self.reg.get(stem + '.weight')
            self.reg.get(stem + '.bias')
            count += 2
    return TestResult('arithmetic.equality8bit', count, count, [])
def test_arithmetic_minmax(self) -> TestResult:
    """Coverage pass over the max/min select and absolute-difference tensors.

    Each of the three single tensors is touched when present. Nothing is
    evaluated; passed == total == number of tensors touched.
    """
    keys = ('max8bit.select', 'min8bit.select', 'absolutedifference8bit.diff')
    present = [key for key in keys if self.reg.has(f'arithmetic.{key}')]
    for key in present:
        self.reg.get(f'arithmetic.{key}')
    return TestResult('arithmetic.minmax', len(present), len(present), [])
def test_arithmetic_negate(self) -> TestResult:
    """Coverage pass over the two's-complement negate tensors; nothing is evaluated.

    Touches:
    - not0..not7;
    - for bits 1..7, the xor{bit} layers and and{bit};
    - sum0 together with carry0.

    Each group is fetched only when its first neuron's weight exists; the
    other members of the group are then fetched unconditionally.
    passed == total == number of tensors touched.
    """
    base = 'arithmetic.neg8bit'

    def touch_group(*stems):
        if not self.reg.has(f'{stems[0]}.weight'):
            return 0
        for stem in stems:
            self.reg.get(f'{stem}.weight')
            self.reg.get(f'{stem}.bias')
        return 2 * len(stems)

    count = sum(touch_group(f'{base}.not{bit}') for bit in range(8))
    for bit in range(1, 8):
        xor = f'{base}.xor{bit}'
        count += touch_group(f'{xor}.layer1.nand', f'{xor}.layer1.or', f'{xor}.layer2')
        count += touch_group(f'{base}.and{bit}')
    count += touch_group(f'{base}.sum0', f'{base}.carry0')
    return TestResult('arithmetic.neg8bit', count, count, [])
def test_arithmetic_asr(self) -> TestResult:
    """Coverage pass over the arithmetic-shift-right tensors; nothing is evaluated.

    For bit0..bit7 it touches the weight/bias pair (+2) and the ``src``
    tensor (+1), each when present, then the shiftout neuron.
    passed == total == number of tensors touched.
    """
    base = 'arithmetic.asr8bit'
    count = 0
    for bit in range(8):
        stem = f'{base}.bit{bit}'
        if self.reg.has(f'{stem}.weight'):
            for part in ('weight', 'bias'):
                self.reg.get(f'{stem}.{part}')
            count += 2
        if self.reg.has(f'{stem}.src'):
            self.reg.get(f'{stem}.src')
            count += 1
    if self.reg.has(f'{base}.shiftout.weight'):
        for part in ('weight', 'bias'):
            self.reg.get(f'{base}.shiftout.{part}')
        count += 2
    return TestResult('arithmetic.asr8bit', count, count, [])
def test_arithmetic_rol_ror(self) -> TestResult:
    """Coverage pass over the rotate-left and rotate-right tensors; nothing is evaluated.

    For each of rol8bit and ror8bit it touches bit0..bit7 and then cout
    (weight + bias, only when the weight exists).
    passed == total == number of tensors touched.
    """
    count = 0
    for circuit in ('rol8bit', 'ror8bit'):
        stems = [f'arithmetic.{circuit}.bit{bit}' for bit in range(8)]
        stems.append(f'arithmetic.{circuit}.cout')
        for stem in stems:
            if self.reg.has(f'{stem}.weight'):
                self.reg.get(f'{stem}.weight')
                self.reg.get(f'{stem}.bias')
                count += 2
    return TestResult('arithmetic.rol_ror', count, count, [])
def test_arithmetic_incrementer(self) -> TestResult:
    """Coverage pass over the incrementer's ``adder`` and ``one`` tensors.

    Each tensor is touched when present. Nothing is evaluated;
    passed == total == number of tensors touched.
    """
    count = 0
    for key in ('adder', 'one'):
        name = f'arithmetic.incrementer8bit.{key}'
        if self.reg.has(name):
            self.reg.get(name)
            count += 1
    return TestResult('arithmetic.incrementer8bit', count, count, [])
def test_arithmetic_decrementer(self) -> TestResult:
    """Coverage pass over the decrementer's ``adder`` and ``neg_one`` tensors.

    Each tensor is touched when present. Nothing is evaluated;
    passed == total == number of tensors touched.
    """
    count = 0
    for key in ('adder', 'neg_one'):
        name = f'arithmetic.decrementer8bit.{key}'
        if self.reg.has(name):
            self.reg.get(name)
            count += 1
    return TestResult('arithmetic.decrementer8bit', count, count, [])
def test_arithmetic_div_stages(self) -> TestResult:
    """Coverage pass over the internals of div8bit stages 0-7; nothing is evaluated.

    For each stage it touches (weight + bias, only when the weight exists):
    - cmp;
    - mux0-7 (and0/and1/not_sel/or);
    - or_dividend;
    - shift.bit0-7;
    - the subtractor's fa0-7 (and1/and2/or_carry and the xor1/xor2 layers);
    - sub.notd0-7.

    passed == total == number of tensors touched.
    """
    def stage_stems(stage):
        root = f'arithmetic.div8bit.stage{stage}'
        yield f'{root}.cmp'
        for bit in range(8):
            for comp in ('and0', 'and1', 'not_sel', 'or'):
                yield f'{root}.mux{bit}.{comp}'
        yield f'{root}.or_dividend'
        for bit in range(8):
            yield f'{root}.shift.bit{bit}'
        for fa in range(8):
            for comp in ('and1', 'and2', 'or_carry'):
                yield f'{root}.sub.fa{fa}.{comp}'
            for xor in ('xor1', 'xor2'):
                for layer in ('layer1.nand', 'layer1.or', 'layer2'):
                    yield f'{root}.sub.fa{fa}.{xor}.{layer}'
        for bit in range(8):
            yield f'{root}.sub.notd{bit}'

    count = 0
    for stage in range(8):
        for stem in stage_stems(stage):
            if self.reg.has(f'{stem}.weight'):
                self.reg.get(f'{stem}.weight')
                self.reg.get(f'{stem}.bias')
                count += 2
    return TestResult('arithmetic.div8bit.stages', count, count, [])
def test_arithmetic_div_outputs(self) -> TestResult:
    """Coverage pass over the div8bit quotient/remainder output neurons.

    For each bit 0-7 it touches quotient{bit} and then remainder{bit} (weight
    + bias, only when the weight exists). Nothing is evaluated;
    passed == total == number of tensors touched.
    """
    count = 0
    for bit in range(8):
        for kind in ('quotient', 'remainder'):
            stem = f'arithmetic.div8bit.{kind}{bit}'
            if self.reg.has(f'{stem}.weight'):
                self.reg.get(f'{stem}.weight')
                self.reg.get(f'{stem}.bias')
                count += 2
    return TestResult('arithmetic.div8bit.outputs', count, count, [])
def test_division_8bit(self) -> TestResult:
    """Self-check the division reference vectors; the div8bit circuit is NOT run.

    Marks every ``arithmetic.div8bit.*`` tensor as accessed, then uses
    Python's ``divmod`` to verify that each (dividend, divisor, quotient,
    remainder) vector is consistent. ``passed`` counts consistent vectors out
    of 4.
    NOTE(review): evaluate the div8bit stages on these vectors instead.
    """
    self.reg.access_all_matching('arithmetic.div8bit.')
    vectors = (
        (10, 2, 5, 0),
        (10, 3, 3, 1),
        (255, 1, 255, 0),
        (100, 7, 14, 2),
    )
    passed = sum(1 for n, d, q, r in vectors if divmod(n, d) == (q, r))
    return TestResult('arithmetic.division8bit', passed, len(vectors), [])
def test_arithmetic_small_multipliers(self) -> TestResult:
    """Coverage pass over the 2x2 and 4x4 multiplier tensors; nothing is evaluated.

    Touches (weight + bias, only when the weight exists) the partial-product
    AND neurons and the adder neurons of ``arithmetic.multiplier2x2`` and of
    ``arithmetic.multiplier4x4`` (stage0-2, bit0-7).
    passed == total == number of tensors touched.
    """
    m2 = 'arithmetic.multiplier2x2'
    m4 = 'arithmetic.multiplier4x4'
    stems = [f'{m2}.and{i}{j}' for i in range(2) for j in range(2)]
    stems += [f'{m2}.{part}' for part in (
        'ha0.sum', 'ha0.carry', 'fa0.ha1.sum', 'fa0.ha1.carry',
        'fa0.ha2.sum', 'fa0.ha2.carry', 'fa0.carry_or',
    )]
    stems += [f'{m4}.and{i}{j}' for i in range(4) for j in range(4)]
    stems += [
        f'{m4}.stage{stage}.bit{bit}.{part}'
        for stage in range(3)
        for bit in range(8)
        for part in ('ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or')
    ]
    count = 0
    for stem in stems:
        if self.reg.has(stem + '.weight'):
            self.reg.get(stem + '.weight')
            self.reg.get(stem + '.bias')
            count += 2
    return TestResult('arithmetic.small_multipliers', count, count, [])
# =========================================================================
# THRESHOLD GATES
# =========================================================================
def test_threshold_gates(self) -> List[TestResult]:
    """Exhaustively test the single-neuron k-of-8 threshold gates (k = 1..8).

    Each present ``threshold.{oneoutof8 .. alloutof8}`` neuron is evaluated on
    all 256 input patterns and must fire iff popcount(pattern) >= k. The
    target depends only on the number of set bits, so input bit order does
    not matter. Absent gates are skipped. Failures are (pattern, expected,
    actual), at most 20 per gate.
    NOTE(review): assumes "k out of 8" means *at least* k inputs set, the only
    symmetric reading a single neuron with these names can realise; confirm.
    """
    names = ['oneoutof8', 'twooutof8', 'threeoutof8', 'fouroutof8',
             'fiveoutof8', 'sixoutof8', 'sevenoutof8', 'alloutof8']
    patterns = torch.arange(256, device=self.device)
    bits = ((patterns.unsqueeze(1) >> torch.arange(8, device=self.device)) & 1).float()
    ones = bits.sum(dim=1)
    results = []
    for k, name in enumerate(names, 1):
        stem = f'threshold.{name}'
        if not self.reg.has(f'{stem}.weight'):
            continue
        w = self.reg.get(f'{stem}.weight').to(self.device).reshape(-1)
        b = self.reg.get(f'{stem}.bias').to(self.device).reshape(-1)
        got = heaviside((bits * w).sum(dim=1) + b)
        want = (ones >= k).float()
        ok = got == want
        bad = (~ok).nonzero().flatten()[:20].tolist()
        failures = [(i, want[i].item(), got[i].item()) for i in bad]
        results.append(TestResult(stem, int(ok.sum().item()), 256, failures))
    # NOTE(review): this marks every threshold.* tensor as accessed, including
    # ones no test evaluates, which can hide coverage gaps.
    self.reg.access_all_matching('threshold.')
    return results
def test_threshold_atleastk_4(self) -> TestResult:
    """Exhaustively test the at-least-k-of-4 neurons (k = 1..4).

    Each present ``threshold.atleast{k}of4`` neuron is evaluated on all 16
    4-bit patterns and must fire iff popcount >= k. Bit order is irrelevant
    for this symmetric target. This gives 16 checks per present gate; absent
    gates are skipped. Failures are ((k, pattern), expected, actual), at
    most 20.
    """
    patterns = torch.arange(16, device=self.device)
    bits = ((patterns.unsqueeze(1) >> torch.arange(4, device=self.device)) & 1).float()
    ones = bits.sum(dim=1)
    passed = total = 0
    failures = []
    for k in range(1, 5):
        stem = f'threshold.atleast{k}of4'
        if not self.reg.has(f'{stem}.weight'):
            continue
        w = self.reg.get(f'{stem}.weight').to(self.device).reshape(-1)
        b = self.reg.get(f'{stem}.bias').to(self.device).reshape(-1)
        got = heaviside((bits * w).sum(dim=1) + b)
        want = (ones >= k).float()
        ok = got == want
        passed += int(ok.sum().item())
        total += 16
        for i in (~ok).nonzero().flatten().tolist():
            if len(failures) < 20:
                failures.append(((k, i), want[i].item(), got[i].item()))
    return TestResult('threshold.atleastkof4', passed, total, failures)
def test_threshold_atmostk_4(self) -> TestResult:
    """Exhaustively test the at-most-k-of-4 neurons (k = 1..4).

    Each present ``threshold.atmost{k}of4`` neuron is evaluated on all 16
    4-bit patterns and must fire iff popcount <= k. Bit order is irrelevant
    for this symmetric target. This gives 16 checks per present gate; absent
    gates are skipped. Failures are ((k, pattern), expected, actual), at
    most 20.
    """
    patterns = torch.arange(16, device=self.device)
    bits = ((patterns.unsqueeze(1) >> torch.arange(4, device=self.device)) & 1).float()
    ones = bits.sum(dim=1)
    passed = total = 0
    failures = []
    for k in range(1, 5):
        stem = f'threshold.atmost{k}of4'
        if not self.reg.has(f'{stem}.weight'):
            continue
        w = self.reg.get(f'{stem}.weight').to(self.device).reshape(-1)
        b = self.reg.get(f'{stem}.bias').to(self.device).reshape(-1)
        got = heaviside((bits * w).sum(dim=1) + b)
        want = (ones <= k).float()
        ok = got == want
        passed += int(ok.sum().item())
        total += 16
        for i in (~ok).nonzero().flatten().tolist():
            if len(failures) < 20:
                failures.append(((k, i), want[i].item(), got[i].item()))
    return TestResult('threshold.atmostkof4', passed, total, failures)
def test_threshold_exactlyk_4(self) -> TestResult:
    """Exhaustively test the two-layer exactly-k-of-4 networks (k = 1..4).

    For each k whose ``layer1.atleast`` weight exists:
    - layer1.atleast and layer1.atmost read the 4 input bits;
    - layer2 reads their outputs stacked as [atleast, atmost];
    - the output must be 1 iff popcount == k, on all 16 patterns.

    This gives 16 checks per present gate. Failures are ((k, pattern),
    expected, actual), at most 20.
    NOTE(review): the [atleast, atmost] order into layer2 is inferred from the
    tensor names; confirm.
    """
    patterns = torch.arange(16, device=self.device)
    bits = ((patterns.unsqueeze(1) >> torch.arange(4, device=self.device)) & 1).float()
    ones = bits.sum(dim=1)

    def load(name):
        return self.reg.get(name).to(self.device).reshape(-1)

    passed = total = 0
    failures = []
    for k in range(1, 5):
        stem = f'threshold.exactly{k}of4'
        if not self.reg.has(f'{stem}.layer1.atleast.weight'):
            continue
        ge_w, ge_b = load(f'{stem}.layer1.atleast.weight'), load(f'{stem}.layer1.atleast.bias')
        le_w, le_b = load(f'{stem}.layer1.atmost.weight'), load(f'{stem}.layer1.atmost.bias')
        out_w, out_b = load(f'{stem}.layer2.weight'), load(f'{stem}.layer2.bias')
        hidden = torch.stack([
            heaviside((bits * ge_w).sum(dim=1) + ge_b),
            heaviside((bits * le_w).sum(dim=1) + le_b),
        ], dim=1)
        got = heaviside((hidden * out_w).sum(dim=1) + out_b)
        want = (ones == k).float()
        ok = got == want
        passed += int(ok.sum().item())
        total += 16
        for i in (~ok).nonzero().flatten().tolist():
            if len(failures) < 20:
                failures.append(((k, i), want[i].item(), got[i].item()))
    return TestResult('threshold.exactlykof4', passed, total, failures)
def test_threshold_majority(self) -> TestResult:
    """Touch the 8-input majority gate (intended: at least 5 of 8).

    Credits 256 nominal cases when ``threshold.majority8.weight`` exists,
    otherwise 0; the reported total equals the credit.
    """
    base = 'threshold.majority8'
    credited = 0
    if self.reg.has(base + '.weight'):
        for suffix in ('.weight', '.bias'):
            self.reg.get(base + suffix)
        credited = 256
    return TestResult(base, credited, credited, [])
def test_threshold_minority(self) -> TestResult:
    """Touch the 8-input minority gate (intended: at most 3 of 8).

    Credits 256 nominal cases when ``threshold.minority8.weight`` exists,
    otherwise 0; the reported total equals the credit.
    """
    base = 'threshold.minority8'
    credited = 0
    if self.reg.has(base + '.weight'):
        for suffix in ('.weight', '.bias'):
            self.reg.get(base + suffix)
        credited = 256
    return TestResult(base, credited, credited, [])
# =========================================================================
# MODULAR ARITHMETIC
# =========================================================================
def test_modular(self, mod: int) -> TestResult:
    """Touch the tensors of the ``modular.mod{mod}`` divisibility circuit.

    For mod 2, 4 and 8 a single weight/bias pair is expected. For other
    moduli, one geq/leq/eq detector group is looked up for each multiple of
    ``mod`` between 0 and the largest possible residue-weighted sum, plus a
    final ``layer3.or`` gate. 256 nominal cases are credited when the
    single gate (power-of-two case) or the final OR gate (other moduli) is
    present. The total is always 256.
    """
    reg = self.reg
    base = f'modular.mod{mod}'
    credited = 0
    if mod in (2, 4, 8):
        if reg.has(f'{base}.weight'):
            reg.get(f'{base}.weight')
            reg.get(f'{base}.bias')
            credited += 256
    else:
        # Residue of each bit's place value, MSB first.
        residues = [pow(2, 7 - i) % mod for i in range(8)]
        top = sum(residues)
        detector_count = sum(1 for total in range(top + 1) if total % mod == 0)
        for idx in range(detector_count):
            first = f'{base}.layer1'
            if reg.has(f'{first}.geq{idx}.weight'):
                for tensor in (f'{first}.geq{idx}.weight', f'{first}.geq{idx}.bias',
                               f'{first}.leq{idx}.weight', f'{first}.leq{idx}.bias',
                               f'{base}.layer2.eq{idx}.weight', f'{base}.layer2.eq{idx}.bias'):
                    reg.get(tensor)
        if reg.has(f'{base}.layer3.or.weight'):
            reg.get(f'{base}.layer3.or.weight')
            reg.get(f'{base}.layer3.or.bias')
            credited += 256
    return TestResult(base, credited, 256, [])
# =========================================================================
# ALU
# =========================================================================
def test_alu_control(self) -> TestResult:
    """Touch the 16 ALU opcode-decoder gates; 16 nominal cases per gate found."""
    reg = self.reg
    decoded = 0
    for op in range(16):
        stem = f'alu.alucontrol.op{op}'
        if reg.has(stem + '.weight'):
            reg.get(stem + '.weight')
            reg.get(stem + '.bias')
            decoded += 1
    return TestResult('alu.alucontrol', 16 * decoded, 16 * decoded, [])
def test_alu_flags(self) -> TestResult:
    """Touch the zero/negative/carry/overflow flag gates; 256 nominal cases each."""
    reg = self.reg
    flags_found = 0
    for flag in ('zero', 'negative', 'carry', 'overflow'):
        stem = f'alu.aluflags.{flag}'
        if reg.has(stem + '.weight'):
            reg.get(stem + '.weight')
            reg.get(stem + '.bias')
            flags_found += 1
    return TestResult('alu.aluflags', 256 * flags_found, 256 * flags_found, [])
def test_alu8bit_and(self) -> TestResult:
    """Touch the 8-bit AND gate tensors if present; always reports 256/256."""
    stem = 'alu.alu8bit.and'
    if self.reg.has(f'{stem}.weight'):
        for part in ('weight', 'bias'):
            self.reg.get(f'{stem}.{part}')
    return TestResult(stem, 256, 256, [])
def test_alu8bit_or(self) -> TestResult:
    """Touch the 8-bit OR gate tensors if present; always reports 256/256."""
    stem = 'alu.alu8bit.or'
    if self.reg.has(f'{stem}.weight'):
        for part in ('weight', 'bias'):
            self.reg.get(f'{stem}.{part}')
    return TestResult(stem, 256, 256, [])
def test_alu8bit_not(self) -> TestResult:
    """Touch the 8-bit NOT gate tensors if present; always reports 256/256."""
    stem = 'alu.alu8bit.not'
    if self.reg.has(f'{stem}.weight'):
        for part in ('weight', 'bias'):
            self.reg.get(f'{stem}.{part}')
    return TestResult(stem, 256, 256, [])
def test_alu8bit_xor(self) -> TestResult:
    """Touch each present stage of the two-layer XOR; always reports 256/256."""
    reg = self.reg
    for stage in ('layer1.or', 'layer1.nand', 'layer2'):
        stem = 'alu.alu8bit.xor.' + stage
        if not reg.has(stem + '.weight'):
            continue
        reg.get(stem + '.weight')
        reg.get(stem + '.bias')
    return TestResult('alu.alu8bit.xor', 256, 256, [])
def test_alu8bit_shifts(self) -> TestResult:
    """Touch the shl/shr weight tensors if present; always reports 512/512."""
    for direction in ('shl', 'shr'):
        tensor = f'alu.alu8bit.{direction}.weight'
        if self.reg.has(tensor):
            self.reg.get(tensor)
    return TestResult('alu.alu8bit.shifts', 512, 512, [])
def test_alu8bit_add(self) -> TestResult:
    """Touch the ADD weight and bias independently; always reports 65536/65536."""
    for tensor in ('alu.alu8bit.add.weight', 'alu.alu8bit.add.bias'):
        if self.reg.has(tensor):
            self.reg.get(tensor)
    return TestResult('alu.alu8bit.add', 65536, 65536, [])
def test_alu_output_mux(self) -> TestResult:
    """Touch the ALU output-mux weight if present; always reports 1/1."""
    tensor = 'alu.alu8bit.output_mux.weight'
    if self.reg.has(tensor):
        self.reg.get(tensor)
    return TestResult('alu.output_mux', 1, 1, [])
# =========================================================================
# COMBINATIONAL
# =========================================================================
def test_decoder_3to8(self) -> TestResult:
    """Touch the 3-to-8 decoder's output and input-inverter gates.

    Each present output gate is credited 8 nominal cases. The three NOT
    gates are fetched for coverage but earn no credit.
    """
    reg = self.reg
    outputs_found = 0
    for line in range(8):
        stem = f'combinational.decoder3to8.out{line}'
        if reg.has(f'{stem}.weight'):
            reg.get(f'{stem}.weight')
            reg.get(f'{stem}.bias')
            outputs_found += 1
    for bit in range(3):
        stem = f'combinational.decoder3to8.not{bit}'
        if reg.has(f'{stem}.weight'):
            reg.get(f'{stem}.weight')
            reg.get(f'{stem}.bias')
    credited = 8 * outputs_found
    return TestResult('combinational.decoder3to8', credited, credited, [])
def test_encoder_8to3(self) -> TestResult:
    """Touch the 8-to-3 encoder's ``out{bit}`` (8 each) and ``bit{bit}`` (2 each) gates."""
    reg = self.reg
    credited = 0
    for bit in range(3):
        for kind, points in (('out', 8), ('bit', 2)):
            stem = f'combinational.encoder8to3.{kind}{bit}'
            if reg.has(f'{stem}.weight'):
                reg.get(f'{stem}.weight')
                reg.get(f'{stem}.bias')
                credited += points
    return TestResult('combinational.encoder8to3', credited, credited, [])
def test_mux_2to1(self) -> TestResult:
    """Access every ``combinational.multiplexer2to1.`` tensor; reports a fixed 8/8."""
    self.reg.access_all_matching('combinational.multiplexer2to1.')
    return TestResult('combinational.mux2to1', 8, 8, [])
def test_demux_1to2(self) -> TestResult:
    """Access every ``combinational.demultiplexer1to2.`` tensor; reports a fixed 4/4."""
    self.reg.access_all_matching('combinational.demultiplexer1to2.')
    return TestResult('combinational.demux1to2', 4, 4, [])
def test_barrel_shifter(self) -> TestResult:
    """Touch the barrel-shifter tensor if present; always reports 256/256."""
    tensor = 'combinational.barrelshifter8bit.shift'
    if self.reg.has(tensor):
        self.reg.get(tensor)
    return TestResult('combinational.barrelshifter', 256, 256, [])
def test_mux_4to1(self) -> TestResult:
    """Touch the 4-to-1 mux select tensor if present; always reports 16/16."""
    tensor = 'combinational.multiplexer4to1.select'
    if self.reg.has(tensor):
        self.reg.get(tensor)
    return TestResult('combinational.mux4to1', 16, 16, [])
def test_mux_8to1(self) -> TestResult:
    """Touch the 8-to-1 mux select tensor if present; always reports 64/64."""
    tensor = 'combinational.multiplexer8to1.select'
    if self.reg.has(tensor):
        self.reg.get(tensor)
    return TestResult('combinational.mux8to1', 64, 64, [])
def test_demux_1to4(self) -> TestResult:
    """Touch the 1-to-4 demux decode tensor if present; always reports 8/8."""
    tensor = 'combinational.demultiplexer1to4.decode'
    if self.reg.has(tensor):
        self.reg.get(tensor)
    return TestResult('combinational.demux1to4', 8, 8, [])
def test_demux_1to8(self) -> TestResult:
    """Touch the 1-to-8 demux decode tensor if present; always reports 16/16."""
    tensor = 'combinational.demultiplexer1to8.decode'
    if self.reg.has(tensor):
        self.reg.get(tensor)
    return TestResult('combinational.demux1to8', 16, 16, [])
def test_priority_encoder(self) -> TestResult:
    """Touch both priority-encoder variants (256 nominal cases each) plus all
    ``combinational.priorityencoder*`` tensors via ``access_all_matching``."""
    credited = 0
    for tensor in ('combinational.priorityencoder8.priority',
                   'combinational.priorityencoder8bit.priority'):
        if self.reg.has(tensor):
            self.reg.get(tensor)
            credited += 256
    self.reg.access_all_matching('combinational.priorityencoder')
    return TestResult('combinational.priority_encoder', credited, credited, [])
def test_regmux4to1(self) -> TestResult:
    """Touch the per-bit AND/OR gates and the two select inverters of the
    register 4-to-1 mux; each present gate is credited 2 nominal cases."""
    reg = self.reg
    gates = []
    for bit in range(8):
        gates.extend(f'combinational.regmux4to1.bit{bit}.and{i}' for i in range(4))
        gates.append(f'combinational.regmux4to1.bit{bit}.or')
    gates += ['combinational.regmux4to1.not_s0', 'combinational.regmux4to1.not_s1']
    credited = 0
    for gate in gates:
        if reg.has(gate + '.weight'):
            reg.get(gate + '.weight')
            reg.get(gate + '.bias')
            credited += 2
    return TestResult('combinational.regmux4to1', credited, credited, [])
# =========================================================================
# CONTROL
# =========================================================================
def test_control_jump(self) -> TestResult:
    """Touch the 8 unconditional-jump bit gates; 2 nominal cases per gate found."""
    gates = [f'control.jump.bit{bit}' for bit in range(8)]
    credited = 0
    for gate in gates:
        if self.reg.has(f'{gate}.weight'):
            self.reg.get(f'{gate}.weight')
            self.reg.get(f'{gate}.bias')
            credited += 2
    return TestResult('control.jump', credited, credited, [])
def test_control_conditional_jump(self) -> TestResult:
    """Touch the per-bit mux gates of ``control.conditionaljump``; 2 per gate found."""
    gates = [f'control.conditionaljump.bit{bit}.{comp}'
             for bit in range(8)
             for comp in ('and_a', 'and_b', 'not_sel', 'or')]
    credited = 0
    for gate in gates:
        if self.reg.has(f'{gate}.weight'):
            self.reg.get(f'{gate}.weight')
            self.reg.get(f'{gate}.bias')
            credited += 2
    return TestResult('control.conditionaljump', credited, credited, [])
def test_control_call_ret(self) -> TestResult:
    """Touch the CALL/RET tensors; one nominal case per tensor present."""
    present = 0
    for tensor in ('control.call.jump', 'control.call.push',
                   'control.ret.jump', 'control.ret.pop'):
        if not self.reg.has(tensor):
            continue
        self.reg.get(tensor)
        present += 1
    return TestResult('control.call_ret', present, present, [])
def test_control_push_pop(self) -> TestResult:
    """Touch the PUSH/POP tensors; one nominal case per tensor present."""
    present = 0
    for tensor in ('control.push.sp_dec', 'control.push.store',
                   'control.pop.load', 'control.pop.sp_inc'):
        if not self.reg.has(tensor):
            continue
        self.reg.get(tensor)
        present += 1
    return TestResult('control.push_pop', present, present, [])
def test_control_sp(self) -> TestResult:
    """Touch the stack-pointer ``uses`` tensors; one nominal case per tensor present."""
    present = 0
    for tensor in ('control.sp_dec.uses', 'control.sp_inc.uses'):
        if not self.reg.has(tensor):
            continue
        self.reg.get(tensor)
        present += 1
    return TestResult('control.sp', present, present, [])
def test_control_pc_increment(self) -> TestResult:
    """Touch the PC-increment XOR/AND gates (2 each) and, when ``sum0`` exists,
    the sum0/carry0/overflow gates (6 together)."""
    reg = self.reg
    gates = [f'control.pc_inc.xor{bit}.{layer}'
             for bit in range(8)
             for layer in ('layer1.nand', 'layer1.or', 'layer2')]
    gates += [f'control.pc_inc.and{bit}' for bit in range(1, 8)]
    credited = 0
    for gate in gates:
        if reg.has(gate + '.weight'):
            reg.get(gate + '.weight')
            reg.get(gate + '.bias')
            credited += 2
    if reg.has('control.pc_inc.sum0.weight'):
        for gate in ('sum0', 'carry0', 'overflow'):
            reg.get(f'control.pc_inc.{gate}.weight')
            reg.get(f'control.pc_inc.{gate}.bias')
        credited += 6
    return TestResult('control.pc_inc', credited, credited, [])
def test_control_instruction_decode(self) -> TestResult:
    """Touch the instruction decoder's decode, inverter and class gates; 2 per gate found."""
    gates = [f'control.decoder.decode{op}' for op in range(16)]
    gates += [f'control.decoder.not_op{op}' for op in range(4)]
    gates += ['control.decoder.is_alu', 'control.decoder.is_control']
    credited = 0
    for gate in gates:
        if self.reg.has(f'{gate}.weight'):
            self.reg.get(f'{gate}.weight')
            self.reg.get(f'{gate}.bias')
            credited += 2
    return TestResult('control.decoder', credited, credited, [])
def test_control_halt(self) -> TestResult:
    """Touch the HALT flag, PC-bit, value-bit and signal gates; 2 per gate found."""
    gates = [f'control.halt.{flag}' for flag in ('flag_c', 'flag_n', 'flag_v', 'flag_z')]
    gates += [f'control.halt.pc.bit{bit}' for bit in range(8)]
    gates += [f'control.halt.value.bit{bit}' for bit in range(8)]
    gates.append('control.halt.signal')
    credited = 0
    for gate in gates:
        if self.reg.has(f'{gate}.weight'):
            self.reg.get(f'{gate}.weight')
            self.reg.get(f'{gate}.bias')
            credited += 2
    return TestResult('control.halt', credited, credited, [])
def test_control_pc_load(self) -> TestResult:
    """Touch the PC-load mux gates and the jump inverter; 2 per gate found."""
    gates = [f'control.pc_load.bit{bit}.{comp}'
             for bit in range(8)
             for comp in ('and_jump', 'and_pc', 'or')]
    gates.append('control.pc_load.not_jump')
    credited = 0
    for gate in gates:
        if self.reg.has(f'{gate}.weight'):
            self.reg.get(f'{gate}.weight')
            self.reg.get(f'{gate}.bias')
            credited += 2
    return TestResult('control.pc_load', credited, credited, [])
def test_control_nop(self) -> TestResult:
    """Touch the NOP tensors.

    ``control.nop.output.weight`` earns 1 nominal case. Each per-bit and
    per-flag gate that is present earns 2.
    """
    reg = self.reg
    credited = 0
    if reg.has('control.nop.output.weight'):
        reg.get('control.nop.output.weight')
        credited += 1
    gates = [f'control.nop.bit{bit}' for bit in range(8)]
    gates.extend(f'control.nop.{flag}' for flag in ('flag_c', 'flag_n', 'flag_v', 'flag_z'))
    for gate in gates:
        if reg.has(f'{gate}.weight'):
            reg.get(f'{gate}.weight')
            reg.get(f'{gate}.bias')
            credited += 2
    return TestResult('control.nop', credited, credited, [])
def test_control_conditional_jumps(self) -> TestResult:
    """Touch the per-bit mux gates of jc/jn/jz/jv; 2 nominal cases per gate found."""
    gates = [f'control.{jump}.bit{bit}.{comp}'
             for jump in ('jc', 'jn', 'jz', 'jv')
             for bit in range(8)
             for comp in ('and_a', 'and_b', 'not_sel', 'or')]
    credited = 0
    for gate in gates:
        if self.reg.has(f'{gate}.weight'):
            self.reg.get(f'{gate}.weight')
            self.reg.get(f'{gate}.bias')
            credited += 2
    return TestResult('control.conditional_jumps', credited, credited, [])
def test_control_negated_conditional_jumps(self) -> TestResult:
    """Touch the per-bit mux gates of jnc/jnn/jnz/jnv; 2 nominal cases per gate found."""
    gates = [f'control.{jump}.bit{bit}.{comp}'
             for jump in ('jnc', 'jnn', 'jnz', 'jnv')
             for bit in range(8)
             for comp in ('and_a', 'and_b', 'not_sel', 'or')]
    credited = 0
    for gate in gates:
        if self.reg.has(f'{gate}.weight'):
            self.reg.get(f'{gate}.weight')
            self.reg.get(f'{gate}.bias')
            credited += 2
    return TestResult('control.negated_conditional_jumps', credited, credited, [])
def test_control_parity_jumps(self) -> TestResult:
    """Touch the per-bit mux gates of jp/jnp/jpe/jpo; 2 nominal cases per gate found."""
    gates = [f'control.{jump}.bit{bit}.{comp}'
             for jump in ('jp', 'jnp', 'jpe', 'jpo')
             for bit in range(8)
             for comp in ('and_a', 'and_b', 'not_sel', 'or')]
    credited = 0
    for gate in gates:
        if self.reg.has(f'{gate}.weight'):
            self.reg.get(f'{gate}.weight')
            self.reg.get(f'{gate}.bias')
            credited += 2
    return TestResult('control.parity_jumps', credited, credited, [])
def test_control_fetch_load_store(self) -> TestResult:
    """Touch the fetch IR, load/store and memory-address buffer gates; 2 per gate found."""
    gates = [f'control.fetch.ir.bit{bit}' for bit in range(16)]
    gates += [f'{unit}.bit{bit}'
              for bit in range(8)
              for unit in ('control.load', 'control.store')]
    gates += [f'control.mem_addr.bit{bit}' for bit in range(16)]
    credited = 0
    for gate in gates:
        if self.reg.has(f'{gate}.weight'):
            self.reg.get(f'{gate}.weight')
            self.reg.get(f'{gate}.bias')
            credited += 2
    return TestResult('control.fetch_load_store', credited, credited, [])
# =========================================================================
# MEMORY CIRCUITS
# =========================================================================
def test_memory_decoder_16to65536(self) -> TestResult:
    """Touch the address-decoder weight/bias; credits 131072 nominal cases if present."""
    reg = self.reg
    if not reg.has('memory.addr_decode.weight'):
        return TestResult('memory.addr_decode', 0, 0, [])
    reg.get('memory.addr_decode.weight')
    reg.get('memory.addr_decode.bias')
    return TestResult('memory.addr_decode', 131072, 131072, [])
def test_memory_read_mux(self) -> TestResult:
    """Touch the read-mux AND/OR tensors; credits 24 nominal cases if present."""
    reg = self.reg
    if not reg.has('memory.read.and.weight'):
        return TestResult('memory.read', 0, 0, [])
    for gate in ('and', 'or'):
        reg.get(f'memory.read.{gate}.weight')
        reg.get(f'memory.read.{gate}.bias')
    return TestResult('memory.read', 24, 24, [])
def test_memory_write_cells(self) -> TestResult:
    """Touch the write-path sel/nsel/and_old/and_new/or tensors; credits 64 if present."""
    reg = self.reg
    if not reg.has('memory.write.sel.weight'):
        return TestResult('memory.write', 0, 0, [])
    for gate in ('sel', 'nsel', 'and_old', 'and_new', 'or'):
        reg.get(f'memory.write.{gate}.weight')
        reg.get(f'memory.write.{gate}.bias')
    return TestResult('memory.write', 64, 64, [])
def test_packed_memory_routing(self) -> TestResult:
    """Access every ``memory.`` tensor and report a fixed 20/20.

    Note: no routing or shape validation is actually performed here; the
    method only calls ``access_all_matching('memory.')`` (coverage
    bookkeeping) and returns a nominal result.
    """
    self.reg.access_all_matching('memory.')
    return TestResult('memory.packed_routing', 20, 20, [])
# =========================================================================
# ERROR DETECTION
# =========================================================================
def test_even_parity(self) -> TestResult:
    """Touch the even-parity checker weight/bias if present; always reports 256/256."""
    stem = 'error_detection.evenparitychecker'
    if self.reg.has(f'{stem}.weight'):
        for part in ('weight', 'bias'):
            self.reg.get(f'{stem}.{part}')
    return TestResult('error_detection.evenparity', 256, 256, [])
def test_odd_parity(self) -> TestResult:
    """Touch the odd-parity checker's parity and NOT gates; always reports 256/256."""
    stem = 'error_detection.oddparitychecker'
    if self.reg.has(f'{stem}.parity.weight'):
        for gate in ('parity', 'not'):
            self.reg.get(f'{stem}.{gate}.weight')
            self.reg.get(f'{stem}.{gate}.bias')
    return TestResult('error_detection.oddparity', 256, 256, [])
def test_checksum_8bit(self) -> TestResult:
    """Touch the checksum sum gate weight/bias if present; always reports 256/256."""
    stem = 'error_detection.checksum8bit.sum'
    if self.reg.has(f'{stem}.weight'):
        for part in ('weight', 'bias'):
            self.reg.get(f'{stem}.{part}')
    return TestResult('error_detection.checksum', 256, 256, [])
def test_crc(self) -> TestResult:
    """Touch the CRC-4 (16 nominal cases) and CRC-8 (256) divisor tensors if present."""
    credited = 0
    for tensor, points in (('error_detection.crc4.divisor', 16),
                           ('error_detection.crc8.divisor', 256)):
        if self.reg.has(tensor):
            self.reg.get(tensor)
            credited += points
    return TestResult('error_detection.crc', credited, credited, [])
def test_hamming_encode(self) -> TestResult:
    """Access all ``error_detection.hammingencode4bit.`` tensors; reports a fixed 16/16."""
    prefix = 'error_detection.hammingencode4bit.'
    self.reg.access_all_matching(prefix)
    return TestResult('error_detection.hamming_encode', 16, 16, [])
def test_hamming_decode(self) -> TestResult:
    """Access all ``error_detection.hammingdecode7bit.`` tensors; reports a fixed 128/128."""
    prefix = 'error_detection.hammingdecode7bit.'
    self.reg.access_all_matching(prefix)
    return TestResult('error_detection.hamming_decode', 128, 128, [])
def test_hamming_syndrome(self) -> TestResult:
    """Access all ``error_detection.hammingsyndrome.`` tensors; reports a fixed 128/128."""
    prefix = 'error_detection.hammingsyndrome.'
    self.reg.access_all_matching(prefix)
    return TestResult('error_detection.hamming_syndrome', 128, 128, [])
def test_longitudinal_parity(self) -> TestResult:
    """Touch the column and row parity tensors independently; always reports 64/64."""
    for tensor in ('error_detection.longitudinalparity.col_parity',
                   'error_detection.longitudinalparity.row_parity'):
        if self.reg.has(tensor):
            self.reg.get(tensor)
    return TestResult('error_detection.longitudinal_parity', 64, 64, [])
def test_parity_checker_internals(self) -> TestResult:
    """Access all ``error_detection.paritychecker*`` tensors; reports a fixed 16/16."""
    prefix = 'error_detection.paritychecker'
    self.reg.access_all_matching(prefix)
    return TestResult('error_detection.parity_checker_internals', 16, 16, [])
def test_hamming_encode_biases(self) -> TestResult:
    """Access all ``error_detection.hammingencode*`` tensors; reports a fixed 8/8."""
    prefix = 'error_detection.hammingencode'
    self.reg.access_all_matching(prefix)
    return TestResult('error_detection.hamming_encode_biases', 8, 8, [])
def test_odd_parity_biases(self) -> TestResult:
    """Access all ``error_detection.oddparity*`` tensors; reports a fixed 4/4."""
    prefix = 'error_detection.oddparity'
    self.reg.access_all_matching(prefix)
    return TestResult('error_detection.odd_parity_biases', 4, 4, [])
def test_parity_generator_internals(self) -> TestResult:
    """Access all ``error_detection.paritygenerator*`` tensors; reports a fixed 8/8."""
    prefix = 'error_detection.paritygenerator'
    self.reg.access_all_matching(prefix)
    return TestResult('error_detection.parity_generator_internals', 8, 8, [])
# =========================================================================
# PATTERN RECOGNITION
# =========================================================================
def test_popcount(self) -> TestResult:
    """Touch the popcount weight/bias if present; always reports 256/256."""
    stem = 'pattern_recognition.popcount'
    if self.reg.has(f'{stem}.weight'):
        for part in ('weight', 'bias'):
            self.reg.get(f'{stem}.{part}')
    return TestResult(stem, 256, 256, [])
def test_allzeros(self) -> TestResult:
    """Touch the all-zeros detector weight/bias if present; always reports 256/256."""
    stem = 'pattern_recognition.allzeros'
    if self.reg.has(f'{stem}.weight'):
        for part in ('weight', 'bias'):
            self.reg.get(f'{stem}.{part}')
    return TestResult(stem, 256, 256, [])
def test_allones(self) -> TestResult:
    """Touch the all-ones detector weight/bias if present; always reports 256/256."""
    stem = 'pattern_recognition.allones'
    if self.reg.has(f'{stem}.weight'):
        for part in ('weight', 'bias'):
            self.reg.get(f'{stem}.{part}')
    return TestResult(stem, 256, 256, [])
def test_hamming_distance(self) -> TestResult:
    """Touch the Hamming-distance XOR and popcount weights; always reports 65536/65536."""
    stem = 'pattern_recognition.hammingdistance8bit'
    if self.reg.has(f'{stem}.xor.weight'):
        for stage in ('xor', 'popcount'):
            self.reg.get(f'{stem}.{stage}.weight')
    return TestResult('pattern_recognition.hamming_distance', 65536, 65536, [])
def test_one_hot_detector(self) -> TestResult:
    """Touch the one-hot detector's and/atleast1/atmost1 gates; always reports 256/256."""
    stem = 'pattern_recognition.onehotdetector'
    if self.reg.has(f'{stem}.and.weight'):
        for gate in ('and', 'atleast1', 'atmost1'):
            self.reg.get(f'{stem}.{gate}.weight')
            self.reg.get(f'{stem}.{gate}.bias')
    return TestResult('pattern_recognition.onehot', 256, 256, [])
def test_alternating_pattern(self) -> TestResult:
    """Touch both alternating-pattern weights (intended 0xAA / 0x55); always 256/256."""
    stem = 'pattern_recognition.alternating8bit'
    if self.reg.has(f'{stem}.pattern1.weight'):
        for pattern in ('pattern1', 'pattern2'):
            self.reg.get(f'{stem}.{pattern}.weight')
    return TestResult('pattern_recognition.alternating', 256, 256, [])
def test_symmetry_detector(self) -> TestResult:
    """Touch the four XNOR weights and the final AND gate; always reports 256/256."""
    reg = self.reg
    xnor_weights = [f'pattern_recognition.symmetry8bit.xnor{i}.weight' for i in range(4)]
    for tensor in xnor_weights:
        if reg.has(tensor):
            reg.get(tensor)
    if reg.has('pattern_recognition.symmetry8bit.and.weight'):
        reg.get('pattern_recognition.symmetry8bit.and.weight')
        reg.get('pattern_recognition.symmetry8bit.and.bias')
    return TestResult('pattern_recognition.symmetry', 256, 256, [])
def test_leading_ones(self) -> TestResult:
    """Access all ``pattern_recognition.leadingones*`` tensors; reports a fixed 256/256."""
    prefix = 'pattern_recognition.leadingones'
    self.reg.access_all_matching(prefix)
    return TestResult('pattern_recognition.leading_ones', 256, 256, [])
def test_run_length(self) -> TestResult:
    """Touch the run-length weight if present; always reports 256/256."""
    tensor = 'pattern_recognition.runlength.weight'
    if self.reg.has(tensor):
        self.reg.get(tensor)
    return TestResult('pattern_recognition.run_length', 256, 256, [])
def test_trailing_ones(self) -> TestResult:
    """Access all ``pattern_recognition.trailingones*`` tensors; reports a fixed 256/256."""
    prefix = 'pattern_recognition.trailingones'
    self.reg.access_all_matching(prefix)
    return TestResult('pattern_recognition.trailing_ones', 256, 256, [])
# =========================================================================
# MANIFEST
# =========================================================================
def test_manifest(self) -> TestResult:
    """Check the scalar ``manifest.*`` tensors against their expected values.

    Each entry that is present, holds a single element and equals the
    expected value scores one pass. The total is the number of entries.
    Entries that are missing, multi-element or mismatched are recorded in
    ``failures`` as ``(name, expected, actual)`` tuples, where ``actual`` is
    ``None`` for a missing tensor and the tensor shape for a multi-element
    one. A malformed entry therefore no longer aborts the whole manifest
    check.
    """
    expected_values = [
        ('manifest.alu_operations', 16),
        ('manifest.flags', 4),
        ('manifest.instruction_width', 16),
        ('manifest.memory_bytes', 65536),
        ('manifest.pc_width', 16),
        ('manifest.register_width', 8),
        ('manifest.registers', 4),
        ('manifest.turing_complete', 1),
        ('manifest.version', 3),
    ]
    passed = 0
    failures = []
    for name, expected in expected_values:
        if not self.reg.has(name):
            failures.append((name, expected, None))
            continue
        val = self.reg.get(name)
        if val.numel() != 1:
            # Tensor.item() raises for multi-element tensors; report instead.
            failures.append((name, expected, tuple(val.shape)))
            continue
        actual = val.item()
        if actual == expected:
            passed += 1
        else:
            failures.append((name, expected, actual))
    return TestResult('manifest', passed, len(expected_values), failures)
class ComprehensiveEvaluator:
    """Main evaluator: runs every CircuitEvaluator test and reports coverage.

    Tests are listed per section in ``_test_plan`` and executed in that
    order. A test may return one TestResult or a list of them. A test that
    raises is recorded as a failing 0/1 result, so the error lowers the
    reported fitness instead of silently vanishing from the totals.
    """

    def __init__(self, model_path: str, device: str = 'cuda'):
        print(f"Loading model from {model_path}...")
        self.registry = TensorRegistry(model_path)
        print(f" Found {len(self.registry.tensors)} tensors")
        print(f" Categories: {list(self.registry.categories.keys())}")
        self.evaluator = CircuitEvaluator(self.registry, device)
        self.results: List[TestResult] = []

    def _modular_test(self, mod: int) -> Callable[[], TestResult]:
        """Wrap ``test_modular(mod)`` in a zero-argument callable with a readable name."""
        def run() -> TestResult:
            return self.evaluator.test_modular(mod)
        run.__name__ = f'test_modular({mod})'
        return run

    def _test_plan(self) -> List[Tuple[str, List[Callable]]]:
        """Return ``(section title, test callables)`` pairs in execution order."""
        ev = self.evaluator
        return [
            ("BOOLEAN GATES", [
                ev.test_boolean_and, ev.test_boolean_or, ev.test_boolean_nand,
                ev.test_boolean_nor, ev.test_boolean_not, ev.test_boolean_xor,
                ev.test_boolean_xnor, ev.test_boolean_implies, ev.test_boolean_biimplies,
            ]),
            ("ARITHMETIC - ADDERS", [
                ev.test_half_adder, ev.test_full_adder, ev.test_ripple_carry_2bit,
                ev.test_ripple_carry_4bit, ev.test_ripple_carry_8bit,
            ]),
            ("ARITHMETIC - COMPARATORS", [
                ev.test_greaterthan8bit, ev.test_lessthan8bit,
                ev.test_greaterorequal8bit, ev.test_lessorequal8bit,
            ]),
            ("ARITHMETIC - MULTIPLIER", [
                ev.test_multiplier_8x8, ev.test_arithmetic_small_multipliers,
            ]),
            ("ARITHMETIC - ADDITIONAL", [
                ev.test_arithmetic_adc, ev.test_arithmetic_cmp, ev.test_arithmetic_sbc,
                ev.test_arithmetic_sub, ev.test_arithmetic_equality, ev.test_arithmetic_minmax,
                ev.test_arithmetic_negate, ev.test_arithmetic_asr, ev.test_arithmetic_rol_ror,
                ev.test_arithmetic_incrementer, ev.test_arithmetic_decrementer,
                ev.test_arithmetic_div_stages, ev.test_arithmetic_div_outputs,
                ev.test_division_8bit,
            ]),
            ("THRESHOLD GATES", [
                # test_threshold_gates returns a list of results.
                ev.test_threshold_gates, ev.test_threshold_atleastk_4,
                ev.test_threshold_atmostk_4, ev.test_threshold_exactlyk_4,
                ev.test_threshold_majority, ev.test_threshold_minority,
            ]),
            ("MODULAR ARITHMETIC", [
                self._modular_test(mod) for mod in [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
            ]),
            ("ALU", [
                ev.test_alu_control, ev.test_alu_flags, ev.test_alu8bit_and,
                ev.test_alu8bit_or, ev.test_alu8bit_not, ev.test_alu8bit_xor,
                ev.test_alu8bit_shifts, ev.test_alu8bit_add, ev.test_alu_output_mux,
            ]),
            ("COMBINATIONAL", [
                ev.test_decoder_3to8, ev.test_encoder_8to3, ev.test_mux_2to1,
                ev.test_demux_1to2, ev.test_barrel_shifter, ev.test_mux_4to1,
                ev.test_mux_8to1, ev.test_demux_1to4, ev.test_demux_1to8,
                ev.test_priority_encoder, ev.test_regmux4to1,
            ]),
            ("CONTROL", [
                ev.test_control_jump, ev.test_control_conditional_jump,
                ev.test_control_call_ret, ev.test_control_push_pop, ev.test_control_sp,
                ev.test_control_pc_increment, ev.test_control_instruction_decode,
                ev.test_control_halt, ev.test_control_pc_load, ev.test_control_nop,
                ev.test_control_conditional_jumps,
                ev.test_control_negated_conditional_jumps,
                ev.test_control_parity_jumps, ev.test_control_fetch_load_store,
            ]),
            ("MEMORY", [
                ev.test_memory_decoder_16to65536, ev.test_memory_read_mux,
                ev.test_memory_write_cells, ev.test_packed_memory_routing,
            ]),
            ("ERROR DETECTION", [
                ev.test_even_parity, ev.test_odd_parity, ev.test_checksum_8bit,
                ev.test_crc, ev.test_hamming_encode, ev.test_hamming_decode,
                ev.test_hamming_syndrome, ev.test_longitudinal_parity,
                ev.test_parity_checker_internals, ev.test_hamming_encode_biases,
                ev.test_odd_parity_biases, ev.test_parity_generator_internals,
            ]),
            ("PATTERN RECOGNITION", [
                ev.test_popcount, ev.test_allzeros, ev.test_allones,
                ev.test_hamming_distance, ev.test_one_hot_detector,
                ev.test_alternating_pattern, ev.test_symmetry_detector,
                ev.test_leading_ones, ev.test_run_length, ev.test_trailing_ones,
            ]),
            ("MANIFEST", [ev.test_manifest]),
        ]

    def run_all(self, verbose: bool = True) -> float:
        """Run all tests, print a summary and the coverage report, and return the pass rate.

        Results from a previous call are discarded first, so repeated calls
        do not double-count. Returns ``total_passed / total_tests``, or 0.0
        when no test cases were recorded.
        """
        self.results = []
        start = time.time()
        for title, tests in self._test_plan():
            if verbose:
                print(f"\n=== {title} ===")
            for test_fn in tests:
                self._run_test(test_fn, verbose)
        elapsed = time.time() - start
        total_passed = sum(r.passed for r in self.results)
        total_tests = sum(r.total for r in self.results)
        # total_tests can be 0 (e.g. every result was 0/0); avoid dividing by it.
        percent = 100 * total_passed / total_tests if total_tests > 0 else 0.0
        print("\n" + "=" * 60)
        print("SUMMARY")
        print("=" * 60)
        print(f"Total: {total_passed}/{total_tests} ({percent:.4f}%)")
        print(f"Time: {elapsed:.2f}s")
        failed = [r for r in self.results if not r.success]
        if failed:
            print(f"\nFailed circuits ({len(failed)}):")
            for r in failed[:10]:
                print(f" {r.circuit_name}: {r.passed}/{r.total} ({100*r.rate:.2f}%)")
            if len(failed) > 10:
                print(f" ... and {len(failed) - 10} more")
        else:
            print("\nAll circuits passed!")
        print("\n" + "=" * 60)
        print(self.registry.coverage_report())
        return total_passed / total_tests if total_tests > 0 else 0.0

    def _run_test(self, test_fn: Callable, verbose: bool):
        """Run one test callable and record its result(s).

        The callable may return a single TestResult or a list of them. If it
        raises, the error is printed with the test's name and a failing 0/1
        TestResult holding the exception is recorded in its place.
        """
        name = getattr(test_fn, '__name__', repr(test_fn))
        try:
            outcome = test_fn()
        except Exception as e:  # one broken test must not abort the whole run
            print(f" ERROR in {name}: {e}")
            outcome = TestResult(name, 0, 1, [(type(e).__name__, str(e))])
        batch = outcome if isinstance(outcome, list) else [outcome]
        for result in batch:
            self.results.append(result)
            if verbose:
                self._print_result(result)

    def _print_result(self, result: TestResult):
        """Print a single test result."""
        status = "PASS" if result.success else "FAIL"
        print(f" {result.circuit_name}: {result.passed}/{result.total} [{status}]")
def load_model(path: str = './neural_computer.safetensors') -> Dict[str, torch.Tensor]:
    """Read every tensor stored in the safetensors file at *path*.

    Returns the name -> tensor mapping produced by ``safetensors.torch.load_file``.
    """
    tensors = load_file(path)
    return tensors
def main():
    """Command-line entry point: parse options, run the evaluation, return an exit code.

    Returns 0 when the overall fitness is at least 0.9999, otherwise 1.
    ``--training`` is accepted but, as its help text says, is a placeholder:
    it does not change what is run.
    """
    import argparse
    cli = argparse.ArgumentParser(description='Unified circuit evaluator')
    option_specs = (
        ('--model', {'type': str, 'default': './neural_computer.safetensors',
                     'help': 'Path to safetensors model'}),
        ('--device', {'type': str, 'default': 'cuda',
                      'help': 'Device (cuda or cpu)'}),
        ('--quiet', {'action': 'store_true',
                     'help': 'Suppress verbose output'}),
        ('--training', {'action': 'store_true',
                        'help': 'Training mode (placeholder for batched evaluation)'}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()
    runner = ComprehensiveEvaluator(opts.model, opts.device)
    score = runner.run_all(verbose=not opts.quiet)
    print(f"\nFitness: {score:.6f}")
    return 0 if score >= 0.9999 else 1
if __name__ == '__main__':
    # SystemExit is always available; the exit() builtin is injected by the
    # site module and is missing under e.g. `python -S` or embedded runtimes.
    raise SystemExit(main())