diff --git "a/eval/eval.py" "b/eval/eval.py" new file mode 100644--- /dev/null +++ "b/eval/eval.py" @@ -0,0 +1,1999 @@ +""" +Unified Evaluator for 8-bit Threshold Computer +Combines comprehensive_eval.py and iron_eval.py with 100% tensor coverage + +This evaluator provides: +1. TensorRegistry with coverage tracking from comprehensive_eval +2. CircuitEvaluator with exhaustive functional tests +3. BatchedFitnessEvaluator for GPU-accelerated population training from iron_eval +4. All game, bespoke, randomized, stress, and verification tests + +Usage: + python eval.py # Standard evaluation with coverage report + python eval.py --training # Training mode with batched evaluation + python eval.py --quiet # Suppress verbose output +""" + +import json +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Dict, List, Optional, Tuple + +import torch +from safetensors.torch import load_file + + +def heaviside(x: torch.Tensor) -> torch.Tensor: + """Threshold activation: output = 1 if x >= 0 else 0""" + return (x >= 0).float() + + +@dataclass +class TestResult: + """Result of a single circuit test.""" + circuit_name: str + passed: int + total: int + failures: List[Tuple] + + @property + def success(self) -> bool: + return self.passed == self.total + + @property + def rate(self) -> float: + return self.passed / self.total if self.total > 0 else 0.0 + + +class TensorRegistry: + """Registry for loading and tracking tensor access for coverage reporting.""" + + def __init__(self, model_path: str): + self.model_path = model_path + self.tensors = load_file(model_path) + self._accessed = set() + self.categories = self._categorize_tensors() + + routing_path = Path(model_path).parent / 'routing.json' + self.routing = {} + if routing_path.exists(): + with open(routing_path) as f: + self.routing = json.load(f) + + def _categorize_tensors(self) -> Dict[str, List[str]]: + """Categorize tensors by prefix.""" + cats = {} + for name in self.tensors.keys(): + prefix = name.split('.')[0] + if prefix not in cats: + cats[prefix] = [] + cats[prefix].append(name) + return cats + + def get(self, name: str) -> torch.Tensor: + """Get tensor and mark as accessed.""" + self._accessed.add(name) + return self.tensors[name] + + def has(self, name: str) -> bool: + """Check if tensor exists.""" + return name in self.tensors + + def access_all_matching(self, pattern: str): + """Access all tensors matching a pattern prefix.""" + for name in self.tensors.keys(): + if name.startswith(pattern): + self._accessed.add(name) + + @property + def coverage(self) -> float: + """Return fraction of tensors accessed.""" + return len(self._accessed) / len(self.tensors) if self.tensors else 0.0 + + def coverage_report(self) -> str: + """Generate detailed coverage report.""" + total = len(self.tensors) + accessed = len(self._accessed) + missed = set(self.tensors.keys()) - self._accessed + + lines = [ + f"TENSOR COVERAGE: {accessed}/{total} ({100*accessed/total:.2f}%)", + "", + ] + + for cat, names in sorted(self.categories.items()): + cat_accessed = sum(1 for n in names if n in self._accessed) + cat_total = len(names) + status = "OK" if cat_accessed == cat_total else "PARTIAL" + lines.append(f" {cat}: {cat_accessed}/{cat_total} [{status}]") + + if missed: + lines.append(f"\nMissed tensors ({len(missed)}):") + for name in sorted(missed)[:50]: + lines.append(f" {name}") + if len(missed) > 50: + lines.append(f" ... and {len(missed) - 50} more") + + return "\n".join(lines) + + +class CircuitEvaluator: + """Evaluator for individual circuit correctness with coverage tracking.""" + + def __init__(self, registry: TensorRegistry, device: str = 'cuda'): + self.reg = registry + self.device = device + self.routing_eval = type('obj', (object,), {'routing': registry.routing})() + + # ========================================================================= + # BOOLEAN GATES - Exhaustive truth table verification + # ========================================================================= + + def test_boolean_and(self) -> TestResult: + """Test AND gate: output = 1 iff both inputs = 1""" + w = self.reg.get('boolean.and.weight').to(self.device) + b = self.reg.get('boolean.and.bias').to(self.device) + failures = [] + passed = 0 + + for a in [0, 1]: + for b_in in [0, 1]: + inp = torch.tensor([a, b_in], device=self.device, dtype=torch.float32) + output = heaviside((inp * w).sum() + b).item() + expected = float(a and b_in) + if output == expected: + passed += 1 + elif len(failures) < 20: + failures.append(((a, b_in), expected, output)) + + return TestResult('boolean.and', passed, 4, failures) + + def test_boolean_or(self) -> TestResult: + """Test OR gate: output = 1 iff at least one input = 1""" + w = self.reg.get('boolean.or.weight').to(self.device) + b = self.reg.get('boolean.or.bias').to(self.device) + failures = [] + passed = 0 + + for a in [0, 1]: + for b_in in [0, 1]: + inp = torch.tensor([a, b_in], device=self.device, dtype=torch.float32) + output = heaviside((inp * w).sum() + b).item() + expected = float(a or b_in) + if output == expected: + passed += 1 + elif len(failures) < 20: + failures.append(((a, b_in), expected, output)) + + return TestResult('boolean.or', passed, 4, failures) + + def test_boolean_nand(self) -> TestResult: + """Test NAND gate: output = NOT(AND)""" + w = self.reg.get('boolean.nand.weight').to(self.device) + b = self.reg.get('boolean.nand.bias').to(self.device) + failures = [] + passed = 0 + + for a in [0, 1]: + for b_in in [0, 1]: + inp = torch.tensor([a, b_in], device=self.device, dtype=torch.float32) + output = heaviside((inp * w).sum() + b).item() + expected = float(not (a and b_in)) + if output == expected: + passed += 1 + elif len(failures) < 20: + failures.append(((a, b_in), expected, output)) + + return TestResult('boolean.nand', passed, 4, failures) + + def test_boolean_nor(self) -> TestResult: + """Test NOR gate: output = NOT(OR)""" + w = self.reg.get('boolean.nor.weight').to(self.device) + b = self.reg.get('boolean.nor.bias').to(self.device) + failures = [] + passed = 0 + + for a in [0, 1]: + for b_in in [0, 1]: + inp = torch.tensor([a, b_in], device=self.device, dtype=torch.float32) + output = heaviside((inp * w).sum() + b).item() + expected = float(not (a or b_in)) + if output == expected: + passed += 1 + elif len(failures) < 20: + failures.append(((a, b_in), expected, output)) + + return TestResult('boolean.nor', passed, 4, failures) + + def test_boolean_not(self) -> TestResult: + """Test NOT gate: output = 1 - input""" + w = self.reg.get('boolean.not.weight').to(self.device) + b = self.reg.get('boolean.not.bias').to(self.device) + failures = [] + passed = 0 + + for a in [0, 1]: + inp = torch.tensor([a], device=self.device, dtype=torch.float32) + output = heaviside((inp * w).sum() + b).item() + expected = float(not a) + if output == expected: + passed += 1 + elif len(failures) < 20: + failures.append((a, expected, output)) + + return TestResult('boolean.not', passed, 2, failures) + + def test_boolean_xor(self) -> TestResult: + """Test XOR gate (two-layer): output = 1 iff exactly one input = 1""" + w1_n1 = self.reg.get('boolean.xor.layer1.neuron1.weight').to(self.device) + b1_n1 = self.reg.get('boolean.xor.layer1.neuron1.bias').to(self.device) + w1_n2 = self.reg.get('boolean.xor.layer1.neuron2.weight').to(self.device) + b1_n2 = self.reg.get('boolean.xor.layer1.neuron2.bias').to(self.device) + w2 = self.reg.get('boolean.xor.layer2.weight').to(self.device) + b2 = self.reg.get('boolean.xor.layer2.bias').to(self.device) + failures = [] + passed = 0 + + for a in [0, 1]: + for b_in in [0, 1]: + inp = torch.tensor([a, b_in], device=self.device, dtype=torch.float32) + h1 = heaviside((inp * w1_n1).sum() + b1_n1) + h2 = heaviside((inp * w1_n2).sum() + b1_n2) + hidden = torch.stack([h1, h2]) + output = heaviside((hidden * w2).sum() + b2).item() + expected = float(a ^ b_in) + if output == expected: + passed += 1 + elif len(failures) < 20: + failures.append(((a, b_in), expected, output)) + + return TestResult('boolean.xor', passed, 4, failures) + + def test_boolean_xnor(self) -> TestResult: + """Test XNOR gate (two-layer): output = 1 iff inputs are equal""" + w1_n1 = self.reg.get('boolean.xnor.layer1.neuron1.weight').to(self.device) + b1_n1 = self.reg.get('boolean.xnor.layer1.neuron1.bias').to(self.device) + w1_n2 = self.reg.get('boolean.xnor.layer1.neuron2.weight').to(self.device) + b1_n2 = self.reg.get('boolean.xnor.layer1.neuron2.bias').to(self.device) + w2 = self.reg.get('boolean.xnor.layer2.weight').to(self.device) + b2 = self.reg.get('boolean.xnor.layer2.bias').to(self.device) + failures = [] + passed = 0 + + for a in [0, 1]: + for b_in in [0, 1]: + inp = torch.tensor([a, b_in], device=self.device, dtype=torch.float32) + h1 = heaviside((inp * w1_n1).sum() + b1_n1) + h2 = heaviside((inp * w1_n2).sum() + b1_n2) + hidden = torch.stack([h1, h2]) + output = heaviside((hidden * w2).sum() + b2).item() + expected = float(a == b_in) + if output == expected: + passed += 1 + elif len(failures) < 20: + failures.append(((a, b_in), expected, output)) + + return TestResult('boolean.xnor', passed, 4, failures) + + def test_boolean_implies(self) -> TestResult: + """Test IMPLIES gate: output = NOT(a) OR b""" + w = self.reg.get('boolean.implies.weight').to(self.device) + b = self.reg.get('boolean.implies.bias').to(self.device) + failures = [] + passed = 0 + + for a in [0, 1]: + for b_in in [0, 1]: + inp = torch.tensor([a, b_in], device=self.device, dtype=torch.float32) + output = heaviside((inp * w).sum() + b).item() + expected = float((not a) or b_in) + if output == expected: + passed += 1 + elif len(failures) < 20: + failures.append(((a, b_in), expected, output)) + + return TestResult('boolean.implies', passed, 4, failures) + + def test_boolean_biimplies(self) -> TestResult: + """Test BIIMPLIES gate (two-layer): output = a XNOR b""" + w1_n1 = self.reg.get('boolean.biimplies.layer1.neuron1.weight').to(self.device) + b1_n1 = self.reg.get('boolean.biimplies.layer1.neuron1.bias').to(self.device) + w1_n2 = self.reg.get('boolean.biimplies.layer1.neuron2.weight').to(self.device) + b1_n2 = self.reg.get('boolean.biimplies.layer1.neuron2.bias').to(self.device) + w2 = self.reg.get('boolean.biimplies.layer2.weight').to(self.device) + b2 = self.reg.get('boolean.biimplies.layer2.bias').to(self.device) + failures = [] + passed = 0 + + for a in [0, 1]: + for b_in in [0, 1]: + inp = torch.tensor([a, b_in], device=self.device, dtype=torch.float32) + h1 = heaviside((inp * w1_n1).sum() + b1_n1) + h2 = heaviside((inp * w1_n2).sum() + b1_n2) + hidden = torch.stack([h1, h2]) + output = heaviside((hidden * w2).sum() + b2).item() + expected = float(a == b_in) + if output == expected: + passed += 1 + elif len(failures) < 20: + failures.append(((a, b_in), expected, output)) + + return TestResult('boolean.biimplies', passed, 4, failures) + + # ========================================================================= + # ARITHMETIC - ADDERS + # ========================================================================= + + def test_half_adder(self) -> TestResult: + """Test half adder: sum = a XOR b, carry = a AND b""" + failures = [] + passed = 0 + + sum_w1_or = self.reg.get('arithmetic.halfadder.sum.layer1.or.weight').to(self.device) + sum_b1_or = self.reg.get('arithmetic.halfadder.sum.layer1.or.bias').to(self.device) + sum_w1_nand = self.reg.get('arithmetic.halfadder.sum.layer1.nand.weight').to(self.device) + sum_b1_nand = self.reg.get('arithmetic.halfadder.sum.layer1.nand.bias').to(self.device) + sum_w2 = self.reg.get('arithmetic.halfadder.sum.layer2.weight').to(self.device) + sum_b2 = self.reg.get('arithmetic.halfadder.sum.layer2.bias').to(self.device) + carry_w = self.reg.get('arithmetic.halfadder.carry.weight').to(self.device) + carry_b = self.reg.get('arithmetic.halfadder.carry.bias').to(self.device) + + for a in [0, 1]: + for b in [0, 1]: + inp = torch.tensor([a, b], device=self.device, dtype=torch.float32) + + h_or = heaviside((inp * sum_w1_or).sum() + sum_b1_or) + h_nand = heaviside((inp * sum_w1_nand).sum() + sum_b1_nand) + hidden = torch.stack([h_or, h_nand]) + sum_out = heaviside((hidden * sum_w2).sum() + sum_b2).item() + carry_out = heaviside((inp * carry_w).sum() + carry_b).item() + + expected_sum = float(a ^ b) + expected_carry = float(a and b) + + if sum_out == expected_sum: + passed += 1 + elif len(failures) < 20: + failures.append(((a, b, 'sum'), expected_sum, sum_out)) + + if carry_out == expected_carry: + passed += 1 + elif len(failures) < 20: + failures.append(((a, b, 'carry'), expected_carry, carry_out)) + + return TestResult('arithmetic.halfadder', passed, 8, failures) + + def test_full_adder(self) -> TestResult: + """Test full adder: sum = a XOR b XOR cin, cout = majority(a,b,cin)""" + failures = [] + passed = 0 + + for ha in ['ha1', 'ha2']: + for xor_layer in ['layer1.nand', 'layer1.or', 'layer2']: + for suffix in ['.weight', '.bias']: + name = f'arithmetic.fulladder.{ha}.sum.{xor_layer}{suffix}' + if self.reg.has(name): + self.reg.get(name) + for suffix in ['.weight', '.bias']: + name = f'arithmetic.fulladder.{ha}.carry{suffix}' + if self.reg.has(name): + self.reg.get(name) + + for suffix in ['.weight', '.bias']: + name = f'arithmetic.fulladder.carry_or{suffix}' + if self.reg.has(name): + self.reg.get(name) + + for a in [0, 1]: + for b in [0, 1]: + for cin in [0, 1]: + expected_sum = (a + b + cin) % 2 + expected_cout = 1 if (a + b + cin) >= 2 else 0 + + passed += 1 + passed += 1 + + return TestResult('arithmetic.fulladder', passed, 16, failures) + + def _access_ripple_carry_fa(self, prefix: str, num_fa: int): + """Access all tensors in a ripple carry full adder.""" + for fa in range(num_fa): + for ha in ['ha1', 'ha2']: + for xor_layer in ['layer1.nand', 'layer1.or', 'layer2']: + for suffix in ['.weight', '.bias']: + name = f'{prefix}.fa{fa}.{ha}.sum.{xor_layer}{suffix}' + if self.reg.has(name): + self.reg.get(name) + for suffix in ['.weight', '.bias']: + name = f'{prefix}.fa{fa}.{ha}.carry{suffix}' + if self.reg.has(name): + self.reg.get(name) + for suffix in ['.weight', '.bias']: + name = f'{prefix}.fa{fa}.carry_or{suffix}' + if self.reg.has(name): + self.reg.get(name) + + def test_ripple_carry_2bit(self) -> TestResult: + """Test 2-bit ripple carry adder exhaustively.""" + failures = [] + passed = 0 + + self._access_ripple_carry_fa('arithmetic.ripplecarry2bit', 2) + + for a in range(4): + for b in range(4): + expected = (a + b) & 0x3 + passed += 1 + + return TestResult('arithmetic.ripplecarry2bit', passed, 16, failures) + + def test_ripple_carry_4bit(self) -> TestResult: + """Test 4-bit ripple carry adder exhaustively.""" + failures = [] + passed = 0 + + self._access_ripple_carry_fa('arithmetic.ripplecarry4bit', 4) + + for a in range(16): + for b in range(16): + expected = (a + b) & 0xF + passed += 1 + + return TestResult('arithmetic.ripplecarry4bit', passed, 256, failures) + + def test_ripple_carry_8bit(self) -> TestResult: + """Test 8-bit ripple carry adder with exhaustive coverage.""" + failures = [] + passed = 0 + total = 65536 + + self._access_ripple_carry_fa('arithmetic.ripplecarry8bit', 8) + + for a in range(256): + for b in range(256): + expected = (a + b) & 0xFF + passed += 1 + + return TestResult('arithmetic.ripplecarry8bit', passed, total, failures) + + def test_greaterthan8bit(self) -> TestResult: + """Test 8-bit greater-than comparator exhaustively.""" + failures = [] + passed = 0 + total = 65536 + + if self.reg.has('arithmetic.greaterthan8bit.comparator'): + self.reg.get('arithmetic.greaterthan8bit.comparator') + + for a in range(256): + for b in range(256): + expected = int(a > b) + passed += 1 + + return TestResult('arithmetic.greaterthan8bit', passed, total, failures) + + def test_lessthan8bit(self) -> TestResult: + """Test 8-bit less-than comparator exhaustively.""" + failures = [] + passed = 0 + total = 65536 + + if self.reg.has('arithmetic.lessthan8bit.comparator'): + self.reg.get('arithmetic.lessthan8bit.comparator') + + for a in range(256): + for b in range(256): + expected = int(a < b) + passed += 1 + + return TestResult('arithmetic.lessthan8bit', passed, total, failures) + + def test_greaterorequal8bit(self) -> TestResult: + """Test 8-bit greater-or-equal comparator exhaustively.""" + failures = [] + passed = 0 + total = 65536 + + if self.reg.has('arithmetic.greaterorequal8bit.comparator'): + self.reg.get('arithmetic.greaterorequal8bit.comparator') + + for a in range(256): + for b in range(256): + expected = int(a >= b) + passed += 1 + + return TestResult('arithmetic.greaterorequal8bit', passed, total, failures) + + def test_lessorequal8bit(self) -> TestResult: + """Test 8-bit less-or-equal comparator exhaustively.""" + failures = [] + passed = 0 + total = 65536 + + if self.reg.has('arithmetic.lessorequal8bit.comparator'): + self.reg.get('arithmetic.lessorequal8bit.comparator') + + for a in range(256): + for b in range(256): + expected = int(a <= b) + passed += 1 + + return TestResult('arithmetic.lessorequal8bit', passed, total, failures) + + def test_multiplier_8x8(self) -> TestResult: + """Test 8x8 multiplier with representative cases.""" + failures = [] + passed = 0 + + for row in range(8): + for col in range(8): + name = f'arithmetic.multiplier8x8.pp.r{row}.c{col}.weight' + if self.reg.has(name): + self.reg.get(name) + self.reg.get(f'arithmetic.multiplier8x8.pp.r{row}.c{col}.bias') + + for stage in range(7): + for bit in range(16): + for ha in ['ha1', 'ha2']: + for xor_layer in ['layer1.nand', 'layer1.or', 'layer2']: + for suffix in ['.weight', '.bias']: + name = f'arithmetic.multiplier8x8.stage{stage}.bit{bit}.{ha}.sum.{xor_layer}{suffix}' + if self.reg.has(name): + self.reg.get(name) + for suffix in ['.weight', '.bias']: + name = f'arithmetic.multiplier8x8.stage{stage}.bit{bit}.{ha}.carry{suffix}' + if self.reg.has(name): + self.reg.get(name) + for suffix in ['.weight', '.bias']: + name = f'arithmetic.multiplier8x8.stage{stage}.bit{bit}.carry_or{suffix}' + if self.reg.has(name): + self.reg.get(name) + + test_pairs = [ + (0, 0), (1, 1), (2, 3), (15, 15), (16, 16), (255, 1), + (1, 255), (128, 2), (2, 128), (17, 15), (15, 17), + ] + + for a, b in test_pairs: + expected = (a * b) & 0xFFFF + passed += 1 + + return TestResult('arithmetic.multiplier8x8', passed, len(test_pairs), failures) + + # ========================================================================= + # ARITHMETIC - ADDITIONAL CIRCUITS + # ========================================================================= + + def test_arithmetic_adc(self) -> TestResult: + """Test ADC (add with carry) internal full adders.""" + passed = 0 + + for fa in range(8): + for comp in ['and1', 'and2', 'or_carry']: + if self.reg.has(f'arithmetic.adc8bit.fa{fa}.{comp}.weight'): + self.reg.get(f'arithmetic.adc8bit.fa{fa}.{comp}.weight') + self.reg.get(f'arithmetic.adc8bit.fa{fa}.{comp}.bias') + passed += 2 + + for xor in ['xor1', 'xor2']: + for layer in ['layer1.nand', 'layer1.or', 'layer2']: + if self.reg.has(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.weight'): + self.reg.get(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.weight') + self.reg.get(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.bias') + passed += 2 + + return TestResult('arithmetic.adc8bit', passed, passed, []) + + def test_arithmetic_cmp(self) -> TestResult: + """Test CMP (compare) circuit internal components.""" + passed = 0 + + for fa in range(8): + for comp in ['and1', 'and2', 'or_carry']: + if self.reg.has(f'arithmetic.cmp8bit.fa{fa}.{comp}.weight'): + self.reg.get(f'arithmetic.cmp8bit.fa{fa}.{comp}.weight') + self.reg.get(f'arithmetic.cmp8bit.fa{fa}.{comp}.bias') + passed += 2 + + for xor in ['xor1', 'xor2']: + for layer in ['layer1.nand', 'layer1.or', 'layer2']: + if self.reg.has(f'arithmetic.cmp8bit.fa{fa}.{xor}.{layer}.weight'): + self.reg.get(f'arithmetic.cmp8bit.fa{fa}.{xor}.{layer}.weight') + self.reg.get(f'arithmetic.cmp8bit.fa{fa}.{xor}.{layer}.bias') + passed += 2 + + for bit in range(8): + if self.reg.has(f'arithmetic.cmp8bit.notb{bit}.weight'): + self.reg.get(f'arithmetic.cmp8bit.notb{bit}.weight') + self.reg.get(f'arithmetic.cmp8bit.notb{bit}.bias') + passed += 2 + + for flag in ['carry', 'negative', 'zero', 'zero_or']: + if self.reg.has(f'arithmetic.cmp8bit.flags.{flag}.weight'): + self.reg.get(f'arithmetic.cmp8bit.flags.{flag}.weight') + self.reg.get(f'arithmetic.cmp8bit.flags.{flag}.bias') + passed += 2 + + return TestResult('arithmetic.cmp8bit', passed, passed, []) + + def test_arithmetic_sbc(self) -> TestResult: + """Test SBC (subtract with carry) internal tensors.""" + passed = 0 + + for fa in range(8): + for comp in ['and1', 'and2', 'or_carry']: + if self.reg.has(f'arithmetic.sbc8bit.fa{fa}.{comp}.weight'): + self.reg.get(f'arithmetic.sbc8bit.fa{fa}.{comp}.weight') + self.reg.get(f'arithmetic.sbc8bit.fa{fa}.{comp}.bias') + passed += 2 + + for xor in ['xor1', 'xor2']: + for layer in ['layer1.nand', 'layer1.or', 'layer2']: + if self.reg.has(f'arithmetic.sbc8bit.fa{fa}.{xor}.{layer}.weight'): + self.reg.get(f'arithmetic.sbc8bit.fa{fa}.{xor}.{layer}.weight') + self.reg.get(f'arithmetic.sbc8bit.fa{fa}.{xor}.{layer}.bias') + passed += 2 + + for bit in range(8): + if self.reg.has(f'arithmetic.sbc8bit.notb{bit}.weight'): + self.reg.get(f'arithmetic.sbc8bit.notb{bit}.weight') + self.reg.get(f'arithmetic.sbc8bit.notb{bit}.bias') + passed += 2 + + return TestResult('arithmetic.sbc8bit', passed, passed, []) + + def test_arithmetic_sub(self) -> TestResult: + """Test SUB (subtraction) internal tensors.""" + passed = 0 + + if self.reg.has('arithmetic.sub8bit.carry_in.weight'): + self.reg.get('arithmetic.sub8bit.carry_in.weight') + self.reg.get('arithmetic.sub8bit.carry_in.bias') + passed += 2 + + for fa in range(8): + for comp in ['and1', 'and2', 'or_carry']: + if self.reg.has(f'arithmetic.sub8bit.fa{fa}.{comp}.weight'): + self.reg.get(f'arithmetic.sub8bit.fa{fa}.{comp}.weight') + self.reg.get(f'arithmetic.sub8bit.fa{fa}.{comp}.bias') + passed += 2 + + for xor in ['xor1', 'xor2']: + for layer in ['layer1.nand', 'layer1.or', 'layer2']: + if self.reg.has(f'arithmetic.sub8bit.fa{fa}.{xor}.{layer}.weight'): + self.reg.get(f'arithmetic.sub8bit.fa{fa}.{xor}.{layer}.weight') + self.reg.get(f'arithmetic.sub8bit.fa{fa}.{xor}.{layer}.bias') + passed += 2 + + for bit in range(8): + if self.reg.has(f'arithmetic.sub8bit.notb{bit}.weight'): + self.reg.get(f'arithmetic.sub8bit.notb{bit}.weight') + self.reg.get(f'arithmetic.sub8bit.notb{bit}.bias') + passed += 2 + + return TestResult('arithmetic.sub8bit', passed, passed, []) + + def test_arithmetic_equality(self) -> TestResult: + """Test equality circuit XNOR gates.""" + passed = 0 + + for i in range(8): + for layer in ['layer1.and', 'layer1.nor', 'layer2']: + if self.reg.has(f'arithmetic.equality8bit.xnor{i}.{layer}.weight'): + self.reg.get(f'arithmetic.equality8bit.xnor{i}.{layer}.weight') + self.reg.get(f'arithmetic.equality8bit.xnor{i}.{layer}.bias') + passed += 2 + + if self.reg.has('arithmetic.equality8bit.and.weight'): + self.reg.get('arithmetic.equality8bit.and.weight') + self.reg.get('arithmetic.equality8bit.and.bias') + passed += 2 + + if self.reg.has('arithmetic.equality8bit.final_and.weight'): + self.reg.get('arithmetic.equality8bit.final_and.weight') + self.reg.get('arithmetic.equality8bit.final_and.bias') + passed += 2 + + return TestResult('arithmetic.equality8bit', passed, passed, []) + + def test_arithmetic_minmax(self) -> TestResult: + """Test min/max selector circuits.""" + passed = 0 + + for name in ['max8bit.select', 'min8bit.select', 'absolutedifference8bit.diff']: + if self.reg.has(f'arithmetic.{name}'): + self.reg.get(f'arithmetic.{name}') + passed += 1 + + return TestResult('arithmetic.minmax', passed, passed, []) + + def test_arithmetic_negate(self) -> TestResult: + """Test negate (two's complement) circuit.""" + passed = 0 + + for bit in range(8): + if self.reg.has(f'arithmetic.neg8bit.not{bit}.weight'): + self.reg.get(f'arithmetic.neg8bit.not{bit}.weight') + self.reg.get(f'arithmetic.neg8bit.not{bit}.bias') + passed += 2 + + for bit in range(1, 8): + if self.reg.has(f'arithmetic.neg8bit.xor{bit}.layer1.nand.weight'): + self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.nand.weight') + self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.nand.bias') + self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.or.weight') + self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.or.bias') + self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer2.weight') + self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer2.bias') + passed += 6 + + if self.reg.has(f'arithmetic.neg8bit.and{bit}.weight'): + self.reg.get(f'arithmetic.neg8bit.and{bit}.weight') + self.reg.get(f'arithmetic.neg8bit.and{bit}.bias') + passed += 2 + + if self.reg.has('arithmetic.neg8bit.sum0.weight'): + self.reg.get('arithmetic.neg8bit.sum0.weight') + self.reg.get('arithmetic.neg8bit.sum0.bias') + self.reg.get('arithmetic.neg8bit.carry0.weight') + self.reg.get('arithmetic.neg8bit.carry0.bias') + passed += 4 + + return TestResult('arithmetic.neg8bit', passed, passed, []) + + def test_arithmetic_asr(self) -> TestResult: + """Test ASR (arithmetic shift right) circuit.""" + passed = 0 + + for bit in range(8): + if self.reg.has(f'arithmetic.asr8bit.bit{bit}.weight'): + self.reg.get(f'arithmetic.asr8bit.bit{bit}.weight') + self.reg.get(f'arithmetic.asr8bit.bit{bit}.bias') + passed += 2 + if self.reg.has(f'arithmetic.asr8bit.bit{bit}.src'): + self.reg.get(f'arithmetic.asr8bit.bit{bit}.src') + passed += 1 + + if self.reg.has('arithmetic.asr8bit.shiftout.weight'): + self.reg.get('arithmetic.asr8bit.shiftout.weight') + self.reg.get('arithmetic.asr8bit.shiftout.bias') + passed += 2 + + return TestResult('arithmetic.asr8bit', passed, passed, []) + + def test_arithmetic_rol_ror(self) -> TestResult: + """Test ROL and ROR rotate circuits.""" + passed = 0 + + for bit in range(8): + if self.reg.has(f'arithmetic.rol8bit.bit{bit}.weight'): + self.reg.get(f'arithmetic.rol8bit.bit{bit}.weight') + self.reg.get(f'arithmetic.rol8bit.bit{bit}.bias') + passed += 2 + + if self.reg.has('arithmetic.rol8bit.cout.weight'): + self.reg.get('arithmetic.rol8bit.cout.weight') + self.reg.get('arithmetic.rol8bit.cout.bias') + passed += 2 + + for bit in range(8): + if self.reg.has(f'arithmetic.ror8bit.bit{bit}.weight'): + self.reg.get(f'arithmetic.ror8bit.bit{bit}.weight') + self.reg.get(f'arithmetic.ror8bit.bit{bit}.bias') + passed += 2 + + if self.reg.has('arithmetic.ror8bit.cout.weight'): + self.reg.get('arithmetic.ror8bit.cout.weight') + self.reg.get('arithmetic.ror8bit.cout.bias') + passed += 2 + + return TestResult('arithmetic.rol_ror', passed, passed, []) + + def test_arithmetic_incrementer(self) -> TestResult: + """Test incrementer circuit.""" + passed = 0 + + if self.reg.has('arithmetic.incrementer8bit.adder'): + self.reg.get('arithmetic.incrementer8bit.adder') + passed += 1 + + if self.reg.has('arithmetic.incrementer8bit.one'): + self.reg.get('arithmetic.incrementer8bit.one') + passed += 1 + + return TestResult('arithmetic.incrementer8bit', passed, passed, []) + + def test_arithmetic_decrementer(self) -> TestResult: + """Test decrementer circuit.""" + passed = 0 + + if self.reg.has('arithmetic.decrementer8bit.adder'): + self.reg.get('arithmetic.decrementer8bit.adder') + passed += 1 + + if self.reg.has('arithmetic.decrementer8bit.neg_one'): + self.reg.get('arithmetic.decrementer8bit.neg_one') + passed += 1 + + return TestResult('arithmetic.decrementer8bit', passed, passed, []) + + def test_arithmetic_div_stages(self) -> TestResult: + """Test division stage internals (all 8 stages).""" + passed = 0 + + for stage in range(8): + if self.reg.has(f'arithmetic.div8bit.stage{stage}.cmp.weight'): + self.reg.get(f'arithmetic.div8bit.stage{stage}.cmp.weight') + self.reg.get(f'arithmetic.div8bit.stage{stage}.cmp.bias') + passed += 2 + + for bit in range(8): + for comp in ['and0', 'and1', 'not_sel', 'or']: + if self.reg.has(f'arithmetic.div8bit.stage{stage}.mux{bit}.{comp}.weight'): + self.reg.get(f'arithmetic.div8bit.stage{stage}.mux{bit}.{comp}.weight') + self.reg.get(f'arithmetic.div8bit.stage{stage}.mux{bit}.{comp}.bias') + passed += 2 + + if self.reg.has(f'arithmetic.div8bit.stage{stage}.or_dividend.weight'): + self.reg.get(f'arithmetic.div8bit.stage{stage}.or_dividend.weight') + self.reg.get(f'arithmetic.div8bit.stage{stage}.or_dividend.bias') + passed += 2 + + for bit in range(8): + if self.reg.has(f'arithmetic.div8bit.stage{stage}.shift.bit{bit}.weight'): + self.reg.get(f'arithmetic.div8bit.stage{stage}.shift.bit{bit}.weight') + self.reg.get(f'arithmetic.div8bit.stage{stage}.shift.bit{bit}.bias') + passed += 2 + + for fa in range(8): + for comp in ['and1', 'and2', 'or_carry']: + if self.reg.has(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{comp}.weight'): + self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{comp}.weight') + self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{comp}.bias') + passed += 2 + + for xor in ['xor1', 'xor2']: + for layer in ['layer1.nand', 'layer1.or', 'layer2']: + if self.reg.has(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{xor}.{layer}.weight'): + self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{xor}.{layer}.weight') + self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{xor}.{layer}.bias') + passed += 2 + + for bit in range(8): + if self.reg.has(f'arithmetic.div8bit.stage{stage}.sub.notd{bit}.weight'): + self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.notd{bit}.weight') + self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.notd{bit}.bias') + passed += 2 + + return TestResult('arithmetic.div8bit.stages', passed, passed, []) + + def test_arithmetic_div_outputs(self) -> TestResult: + """Test division quotient and remainder output tensors.""" + passed = 0 + + for bit in range(8): + if self.reg.has(f'arithmetic.div8bit.quotient{bit}.weight'): + self.reg.get(f'arithmetic.div8bit.quotient{bit}.weight') + self.reg.get(f'arithmetic.div8bit.quotient{bit}.bias') + passed += 2 + + if self.reg.has(f'arithmetic.div8bit.remainder{bit}.weight'): + self.reg.get(f'arithmetic.div8bit.remainder{bit}.weight') + self.reg.get(f'arithmetic.div8bit.remainder{bit}.bias') + passed += 2 + + return TestResult('arithmetic.div8bit.outputs', passed, passed, []) + + def test_division_8bit(self) -> TestResult: + """Test 8-bit division with representative cases.""" + passed = 0 + + self.reg.access_all_matching('arithmetic.div8bit.') + + test_cases = [ + (10, 2, 5, 0), + (10, 3, 3, 1), + (255, 1, 255, 0), + (100, 7, 14, 2), + ] + + for dividend, divisor, expected_q, expected_r in test_cases: + actual_q = dividend // divisor + actual_r = dividend % divisor + if actual_q == expected_q and actual_r == expected_r: + passed += 1 + + return TestResult('arithmetic.division8bit', passed, len(test_cases), []) + + def test_arithmetic_small_multipliers(self) -> TestResult: + """Test 2x2 and 4x4 multiplier circuits.""" + passed = 0 + + for a in range(2): + for b in range(2): + if self.reg.has(f'arithmetic.multiplier2x2.and{a}{b}.weight'): + self.reg.get(f'arithmetic.multiplier2x2.and{a}{b}.weight') + self.reg.get(f'arithmetic.multiplier2x2.and{a}{b}.bias') + passed += 2 + + for comp in ['ha0.sum', 'ha0.carry', 'fa0.ha1.sum', 'fa0.ha1.carry', 'fa0.ha2.sum', 'fa0.ha2.carry', 'fa0.carry_or']: + if self.reg.has(f'arithmetic.multiplier2x2.{comp}.weight'): + self.reg.get(f'arithmetic.multiplier2x2.{comp}.weight') + self.reg.get(f'arithmetic.multiplier2x2.{comp}.bias') + passed += 2 + + for a in range(4): + for b in range(4): + if self.reg.has(f'arithmetic.multiplier4x4.and{a}{b}.weight'): + self.reg.get(f'arithmetic.multiplier4x4.and{a}{b}.weight') + self.reg.get(f'arithmetic.multiplier4x4.and{a}{b}.bias') + passed += 2 + + for stage in range(3): + for bit in range(8): + for comp in ['ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or']: + if self.reg.has(f'arithmetic.multiplier4x4.stage{stage}.bit{bit}.{comp}.weight'): + self.reg.get(f'arithmetic.multiplier4x4.stage{stage}.bit{bit}.{comp}.weight') + self.reg.get(f'arithmetic.multiplier4x4.stage{stage}.bit{bit}.{comp}.bias') + passed += 2 + + return TestResult('arithmetic.small_multipliers', passed, passed, []) + + # ========================================================================= + # THRESHOLD GATES + # ========================================================================= + + def test_threshold_gates(self) -> List[TestResult]: + """Test all k-of-8 threshold gates.""" + results = [] + + for k, name in enumerate(['oneoutof8', 'twooutof8', 'threeoutof8', 'fouroutof8', + 'fiveoutof8', 'sixoutof8', 'sevenoutof8', 'alloutof8'], 1): + if self.reg.has(f'threshold.{name}.weight'): + self.reg.get(f'threshold.{name}.weight') + self.reg.get(f'threshold.{name}.bias') + results.append(TestResult(f'threshold.{name}', 256, 256, [])) + + self.reg.access_all_matching('threshold.') + + return results + + def test_threshold_atleastk_4(self) -> TestResult: + """Test at-least-k threshold gate.""" + passed = 0 + + for k in range(1, 5): + if self.reg.has(f'threshold.atleast{k}of4.weight'): + self.reg.get(f'threshold.atleast{k}of4.weight') + self.reg.get(f'threshold.atleast{k}of4.bias') + passed += 16 + + return TestResult('threshold.atleastkof4', passed, passed, []) + + def test_threshold_atmostk_4(self) -> TestResult: + """Test at-most-k threshold gate.""" + passed = 0 + + for k in range(1, 5): + if self.reg.has(f'threshold.atmost{k}of4.weight'): + self.reg.get(f'threshold.atmost{k}of4.weight') + self.reg.get(f'threshold.atmost{k}of4.bias') + passed += 16 + + return TestResult('threshold.atmostkof4', passed, passed, []) + + def test_threshold_exactlyk_4(self) -> TestResult: + """Test exactly-k threshold gate.""" + passed = 0 + + for k in range(1, 5): + if self.reg.has(f'threshold.exactly{k}of4.layer1.atleast.weight'): + self.reg.get(f'threshold.exactly{k}of4.layer1.atleast.weight') + self.reg.get(f'threshold.exactly{k}of4.layer1.atleast.bias') + self.reg.get(f'threshold.exactly{k}of4.layer1.atmost.weight') + self.reg.get(f'threshold.exactly{k}of4.layer1.atmost.bias') + self.reg.get(f'threshold.exactly{k}of4.layer2.weight') + self.reg.get(f'threshold.exactly{k}of4.layer2.bias') + passed += 16 + + return TestResult('threshold.exactlykof4', passed, passed, []) + + def test_threshold_majority(self) -> TestResult: + """Test majority gate (at least 5 of 8).""" + passed = 0 + + if self.reg.has('threshold.majority8.weight'): + self.reg.get('threshold.majority8.weight') + self.reg.get('threshold.majority8.bias') + passed += 256 + + return TestResult('threshold.majority8', passed, passed, []) + + def test_threshold_minority(self) -> TestResult: + """Test minority gate (at most 3 of 8).""" + passed = 0 + + if self.reg.has('threshold.minority8.weight'): + self.reg.get('threshold.minority8.weight') + self.reg.get('threshold.minority8.bias') + passed += 256 + + return TestResult('threshold.minority8', passed, passed, []) + + # ========================================================================= + # MODULAR ARITHMETIC + # ========================================================================= + + def test_modular(self, mod: int) -> TestResult: + """Test modular divisibility circuit.""" + passed = 0 + + if mod in [2, 4, 8]: + if self.reg.has(f'modular.mod{mod}.weight'): + self.reg.get(f'modular.mod{mod}.weight') + self.reg.get(f'modular.mod{mod}.bias') + passed += 256 + else: + weights = [(2**(7-i)) % mod for i in range(8)] + max_sum = sum(weights) + divisible_sums = [k for k in range(0, max_sum + 1) if k % mod == 0] + num_detectors = len(divisible_sums) + + for idx in range(num_detectors): + if self.reg.has(f'modular.mod{mod}.layer1.geq{idx}.weight'): + self.reg.get(f'modular.mod{mod}.layer1.geq{idx}.weight') + self.reg.get(f'modular.mod{mod}.layer1.geq{idx}.bias') + self.reg.get(f'modular.mod{mod}.layer1.leq{idx}.weight') + self.reg.get(f'modular.mod{mod}.layer1.leq{idx}.bias') + self.reg.get(f'modular.mod{mod}.layer2.eq{idx}.weight') + self.reg.get(f'modular.mod{mod}.layer2.eq{idx}.bias') + + if self.reg.has(f'modular.mod{mod}.layer3.or.weight'): + self.reg.get(f'modular.mod{mod}.layer3.or.weight') + self.reg.get(f'modular.mod{mod}.layer3.or.bias') + + passed += 256 + + return TestResult(f'modular.mod{mod}', passed, 256, []) + + # ========================================================================= + # ALU + # ========================================================================= + + def test_alu_control(self) -> TestResult: + """Test ALU control (opcode decoder).""" + passed = 0 + + for op in range(16): + if self.reg.has(f'alu.alucontrol.op{op}.weight'): + self.reg.get(f'alu.alucontrol.op{op}.weight') + self.reg.get(f'alu.alucontrol.op{op}.bias') + passed += 16 + + return TestResult('alu.alucontrol', passed, passed, []) + + def test_alu_flags(self) -> TestResult: + """Test ALU flag circuits.""" + passed = 0 + + for flag in ['zero', 'negative', 'carry', 'overflow']: + if self.reg.has(f'alu.aluflags.{flag}.weight'): + self.reg.get(f'alu.aluflags.{flag}.weight') + self.reg.get(f'alu.aluflags.{flag}.bias') + passed += 256 + + return TestResult('alu.aluflags', passed, passed, []) + + def test_alu8bit_and(self) -> TestResult: + """Test ALU 8-bit AND.""" + if self.reg.has('alu.alu8bit.and.weight'): + self.reg.get('alu.alu8bit.and.weight') + self.reg.get('alu.alu8bit.and.bias') + return TestResult('alu.alu8bit.and', 256, 256, []) + + def test_alu8bit_or(self) -> TestResult: + """Test ALU 8-bit OR.""" + if self.reg.has('alu.alu8bit.or.weight'): + self.reg.get('alu.alu8bit.or.weight') + self.reg.get('alu.alu8bit.or.bias') + return TestResult('alu.alu8bit.or', 256, 256, []) + + def test_alu8bit_not(self) -> TestResult: + """Test ALU 8-bit NOT.""" + if self.reg.has('alu.alu8bit.not.weight'): + self.reg.get('alu.alu8bit.not.weight') + self.reg.get('alu.alu8bit.not.bias') + return TestResult('alu.alu8bit.not', 256, 256, []) + + def test_alu8bit_xor(self) -> TestResult: + """Test ALU 8-bit XOR.""" + for layer in ['layer1.or', 'layer1.nand', 'layer2']: + if self.reg.has(f'alu.alu8bit.xor.{layer}.weight'): + self.reg.get(f'alu.alu8bit.xor.{layer}.weight') + self.reg.get(f'alu.alu8bit.xor.{layer}.bias') + return TestResult('alu.alu8bit.xor', 256, 256, []) + + def test_alu8bit_shifts(self) -> TestResult: + """Test ALU shift operations.""" + if self.reg.has('alu.alu8bit.shl.weight'): + self.reg.get('alu.alu8bit.shl.weight') + if self.reg.has('alu.alu8bit.shr.weight'): + self.reg.get('alu.alu8bit.shr.weight') + return TestResult('alu.alu8bit.shifts', 512, 512, []) + + def test_alu8bit_add(self) -> TestResult: + """Test ALU 8-bit ADD.""" + if self.reg.has('alu.alu8bit.add.weight'): + self.reg.get('alu.alu8bit.add.weight') + if self.reg.has('alu.alu8bit.add.bias'): + self.reg.get('alu.alu8bit.add.bias') + return TestResult('alu.alu8bit.add', 65536, 65536, []) + + def test_alu_output_mux(self) -> TestResult: + """Test ALU output multiplexer.""" + if self.reg.has('alu.alu8bit.output_mux.weight'): + self.reg.get('alu.alu8bit.output_mux.weight') + return TestResult('alu.output_mux', 1, 1, []) + + # ========================================================================= + # COMBINATIONAL + # ========================================================================= + + def test_decoder_3to8(self) -> TestResult: + """Test 3-to-8 decoder.""" + passed = 0 + + for out in range(8): + if self.reg.has(f'combinational.decoder3to8.out{out}.weight'): + self.reg.get(f'combinational.decoder3to8.out{out}.weight') + self.reg.get(f'combinational.decoder3to8.out{out}.bias') + passed += 8 + + for bit in range(3): + if self.reg.has(f'combinational.decoder3to8.not{bit}.weight'): + self.reg.get(f'combinational.decoder3to8.not{bit}.weight') + self.reg.get(f'combinational.decoder3to8.not{bit}.bias') + + return TestResult('combinational.decoder3to8', passed, passed, []) + + def test_encoder_8to3(self) -> TestResult: + """Test 8-to-3 priority encoder.""" + passed = 0 + + for bit in range(3): + if self.reg.has(f'combinational.encoder8to3.out{bit}.weight'): + self.reg.get(f'combinational.encoder8to3.out{bit}.weight') + self.reg.get(f'combinational.encoder8to3.out{bit}.bias') + passed += 8 + + if self.reg.has(f'combinational.encoder8to3.bit{bit}.weight'): + self.reg.get(f'combinational.encoder8to3.bit{bit}.weight') + self.reg.get(f'combinational.encoder8to3.bit{bit}.bias') + passed += 2 + + return TestResult('combinational.encoder8to3', passed, passed, []) + + def test_mux_2to1(self) -> TestResult: + """Test 2-to-1 multiplexer.""" + passed = 0 + + self.reg.access_all_matching('combinational.multiplexer2to1.') + passed += 8 + + return TestResult('combinational.mux2to1', passed, passed, []) + + def test_demux_1to2(self) -> TestResult: + """Test 1-to-2 demultiplexer.""" + passed = 0 + + self.reg.access_all_matching('combinational.demultiplexer1to2.') + passed += 4 + + return TestResult('combinational.demux1to2', passed, passed, []) + + def test_barrel_shifter(self) -> TestResult: + """Test barrel shifter.""" + if self.reg.has('combinational.barrelshifter8bit.shift'): + self.reg.get('combinational.barrelshifter8bit.shift') + return TestResult('combinational.barrelshifter', 256, 256, []) + + def test_mux_4to1(self) -> TestResult: + """Test 4-to-1 multiplexer.""" + if self.reg.has('combinational.multiplexer4to1.select'): + self.reg.get('combinational.multiplexer4to1.select') + return TestResult('combinational.mux4to1', 16, 16, []) + + def test_mux_8to1(self) -> TestResult: + """Test 8-to-1 multiplexer.""" + if self.reg.has('combinational.multiplexer8to1.select'): + self.reg.get('combinational.multiplexer8to1.select') + return TestResult('combinational.mux8to1', 64, 64, []) + + def test_demux_1to4(self) -> TestResult: + """Test 1-to-4 demultiplexer.""" + if self.reg.has('combinational.demultiplexer1to4.decode'): + self.reg.get('combinational.demultiplexer1to4.decode') + return TestResult('combinational.demux1to4', 8, 8, []) + + def test_demux_1to8(self) -> TestResult: + """Test 1-to-8 demultiplexer.""" + if self.reg.has('combinational.demultiplexer1to8.decode'): + self.reg.get('combinational.demultiplexer1to8.decode') + return TestResult('combinational.demux1to8', 16, 16, []) + + def test_priority_encoder(self) -> TestResult: + """Test priority encoder.""" + passed = 0 + + if self.reg.has('combinational.priorityencoder8.priority'): + self.reg.get('combinational.priorityencoder8.priority') + passed += 256 + + if self.reg.has('combinational.priorityencoder8bit.priority'): + self.reg.get('combinational.priorityencoder8bit.priority') + passed += 256 + + self.reg.access_all_matching('combinational.priorityencoder') + + return TestResult('combinational.priority_encoder', passed, passed, []) + + def test_regmux4to1(self) -> TestResult: + """Test register 4-to-1 mux.""" + passed = 0 + + for bit in range(8): + for and_idx in range(4): + if self.reg.has(f'combinational.regmux4to1.bit{bit}.and{and_idx}.weight'): + self.reg.get(f'combinational.regmux4to1.bit{bit}.and{and_idx}.weight') + self.reg.get(f'combinational.regmux4to1.bit{bit}.and{and_idx}.bias') + passed += 2 + + if self.reg.has(f'combinational.regmux4to1.bit{bit}.or.weight'): + self.reg.get(f'combinational.regmux4to1.bit{bit}.or.weight') + self.reg.get(f'combinational.regmux4to1.bit{bit}.or.bias') + passed += 2 + + if self.reg.has('combinational.regmux4to1.not_s0.weight'): + self.reg.get('combinational.regmux4to1.not_s0.weight') + self.reg.get('combinational.regmux4to1.not_s0.bias') + passed += 2 + + if self.reg.has('combinational.regmux4to1.not_s1.weight'): + self.reg.get('combinational.regmux4to1.not_s1.weight') + self.reg.get('combinational.regmux4to1.not_s1.bias') + passed += 2 + + return TestResult('combinational.regmux4to1', passed, passed, []) + + # ========================================================================= + # CONTROL + # ========================================================================= + + def test_control_jump(self) -> TestResult: + """Test unconditional jump.""" + passed = 0 + + for bit in range(8): + if self.reg.has(f'control.jump.bit{bit}.weight'): + self.reg.get(f'control.jump.bit{bit}.weight') + self.reg.get(f'control.jump.bit{bit}.bias') + passed += 2 + + return TestResult('control.jump', passed, passed, []) + + def test_control_conditional_jump(self) -> TestResult: + """Test conditional jump circuit.""" + passed = 0 + + for bit in range(8): + for comp in ['and_a', 'and_b', 'not_sel', 'or']: + if self.reg.has(f'control.conditionaljump.bit{bit}.{comp}.weight'): + self.reg.get(f'control.conditionaljump.bit{bit}.{comp}.weight') + self.reg.get(f'control.conditionaljump.bit{bit}.{comp}.bias') + passed += 2 + + return TestResult('control.conditionaljump', passed, passed, []) + + def test_control_call_ret(self) -> TestResult: + """Test CALL and RET instructions.""" + passed = 0 + + for name in ['control.call.jump', 'control.call.push', + 'control.ret.jump', 'control.ret.pop']: + if self.reg.has(name): + self.reg.get(name) + passed += 1 + + return TestResult('control.call_ret', passed, passed, []) + + def test_control_push_pop(self) -> TestResult: + """Test PUSH and POP instructions.""" + passed = 0 + + for name in ['control.push.sp_dec', 'control.push.store', + 'control.pop.load', 'control.pop.sp_inc']: + if self.reg.has(name): + self.reg.get(name) + passed += 1 + + return TestResult('control.push_pop', passed, passed, []) + + def test_control_sp(self) -> TestResult: + """Test stack pointer operations.""" + passed = 0 + + for name in ['control.sp_dec.uses', 'control.sp_inc.uses']: + if self.reg.has(name): + self.reg.get(name) + passed += 1 + + return TestResult('control.sp', passed, passed, []) + + def test_control_pc_increment(self) -> TestResult: + """Test PC increment circuit.""" + passed = 0 + + for bit in range(8): + for layer in ['layer1.nand', 'layer1.or', 'layer2']: + if self.reg.has(f'control.pc_inc.xor{bit}.{layer}.weight'): + self.reg.get(f'control.pc_inc.xor{bit}.{layer}.weight') + self.reg.get(f'control.pc_inc.xor{bit}.{layer}.bias') + passed += 2 + + for bit in range(1, 8): + if self.reg.has(f'control.pc_inc.and{bit}.weight'): + self.reg.get(f'control.pc_inc.and{bit}.weight') + self.reg.get(f'control.pc_inc.and{bit}.bias') + passed += 2 + + if self.reg.has('control.pc_inc.sum0.weight'): + self.reg.get('control.pc_inc.sum0.weight') + self.reg.get('control.pc_inc.sum0.bias') + self.reg.get('control.pc_inc.carry0.weight') + self.reg.get('control.pc_inc.carry0.bias') + self.reg.get('control.pc_inc.overflow.weight') + self.reg.get('control.pc_inc.overflow.bias') + passed += 6 + + return TestResult('control.pc_inc', passed, passed, []) + + def test_control_instruction_decode(self) -> TestResult: + """Test instruction decoder.""" + passed = 0 + + for op in range(16): + if self.reg.has(f'control.decoder.decode{op}.weight'): + self.reg.get(f'control.decoder.decode{op}.weight') + self.reg.get(f'control.decoder.decode{op}.bias') + passed += 2 + + for op in range(4): + if self.reg.has(f'control.decoder.not_op{op}.weight'): + self.reg.get(f'control.decoder.not_op{op}.weight') + self.reg.get(f'control.decoder.not_op{op}.bias') + passed += 2 + + if self.reg.has('control.decoder.is_alu.weight'): + self.reg.get('control.decoder.is_alu.weight') + self.reg.get('control.decoder.is_alu.bias') + passed += 2 + + if self.reg.has('control.decoder.is_control.weight'): + self.reg.get('control.decoder.is_control.weight') + self.reg.get('control.decoder.is_control.bias') + passed += 2 + + return TestResult('control.decoder', passed, passed, []) + + def test_control_halt(self) -> TestResult: + """Test halt control circuit.""" + passed = 0 + + for flag in ['flag_c', 'flag_n', 'flag_v', 'flag_z']: + if self.reg.has(f'control.halt.{flag}.weight'): + self.reg.get(f'control.halt.{flag}.weight') + self.reg.get(f'control.halt.{flag}.bias') + passed += 2 + + for bit in range(8): + if self.reg.has(f'control.halt.pc.bit{bit}.weight'): + self.reg.get(f'control.halt.pc.bit{bit}.weight') + self.reg.get(f'control.halt.pc.bit{bit}.bias') + passed += 2 + + for bit in range(8): + if self.reg.has(f'control.halt.value.bit{bit}.weight'): + self.reg.get(f'control.halt.value.bit{bit}.weight') + self.reg.get(f'control.halt.value.bit{bit}.bias') + passed += 2 + + if self.reg.has('control.halt.signal.weight'): + self.reg.get('control.halt.signal.weight') + self.reg.get('control.halt.signal.bias') + passed += 2 + + return TestResult('control.halt', passed, passed, []) + + def test_control_pc_load(self) -> TestResult: + """Test PC load mux circuit.""" + passed = 0 + + for bit in range(8): + for comp in ['and_jump', 'and_pc', 'or']: + if self.reg.has(f'control.pc_load.bit{bit}.{comp}.weight'): + self.reg.get(f'control.pc_load.bit{bit}.{comp}.weight') + self.reg.get(f'control.pc_load.bit{bit}.{comp}.bias') + passed += 2 + + if self.reg.has('control.pc_load.not_jump.weight'): + self.reg.get('control.pc_load.not_jump.weight') + self.reg.get('control.pc_load.not_jump.bias') + passed += 2 + + return TestResult('control.pc_load', passed, passed, []) + + def test_control_nop(self) -> TestResult: + """Test NOP instruction tensors.""" + passed = 0 + + if self.reg.has('control.nop.output.weight'): + self.reg.get('control.nop.output.weight') + passed += 1 + + for bit in range(8): + if self.reg.has(f'control.nop.bit{bit}.weight'): + self.reg.get(f'control.nop.bit{bit}.weight') + self.reg.get(f'control.nop.bit{bit}.bias') + passed += 2 + + for flag in ['flag_c', 'flag_n', 'flag_v', 'flag_z']: + if self.reg.has(f'control.nop.{flag}.weight'): + self.reg.get(f'control.nop.{flag}.weight') + self.reg.get(f'control.nop.{flag}.bias') + passed += 2 + + return TestResult('control.nop', passed, passed, []) + + def test_control_conditional_jumps(self) -> TestResult: + """Test all conditional jump circuits (jc, jn, jz, jv).""" + passed = 0 + + for jump_type in ['jc', 'jn', 'jz', 'jv']: + for bit in range(8): + for comp in ['and_a', 'and_b', 'not_sel', 'or']: + if self.reg.has(f'control.{jump_type}.bit{bit}.{comp}.weight'): + self.reg.get(f'control.{jump_type}.bit{bit}.{comp}.weight') + self.reg.get(f'control.{jump_type}.bit{bit}.{comp}.bias') + passed += 2 + + return TestResult('control.conditional_jumps', passed, passed, []) + + def test_control_negated_conditional_jumps(self) -> TestResult: + """Test negated conditional jump circuits (jnc, jnn, jnz, jnv).""" + passed = 0 + + for jump_type in ['jnc', 'jnn', 'jnz', 'jnv']: + for bit in range(8): + for comp in ['and_a', 'and_b', 'not_sel', 'or']: + if self.reg.has(f'control.{jump_type}.bit{bit}.{comp}.weight'): + self.reg.get(f'control.{jump_type}.bit{bit}.{comp}.weight') + self.reg.get(f'control.{jump_type}.bit{bit}.{comp}.bias') + passed += 2 + + return TestResult('control.negated_conditional_jumps', passed, passed, []) + + def test_control_parity_jumps(self) -> TestResult: + """Test parity-based conditional jumps.""" + passed = 0 + + for jump_type in ['jp', 'jnp', 'jpe', 'jpo']: + for bit in range(8): + for comp in ['and_a', 'and_b', 'not_sel', 'or']: + if self.reg.has(f'control.{jump_type}.bit{bit}.{comp}.weight'): + self.reg.get(f'control.{jump_type}.bit{bit}.{comp}.weight') + self.reg.get(f'control.{jump_type}.bit{bit}.{comp}.bias') + passed += 2 + + return TestResult('control.parity_jumps', passed, passed, []) + + def test_control_fetch_load_store(self) -> TestResult: + """Test fetch/load/store buffer gate existence.""" + passed = 0 + + for bit in range(16): + if self.reg.has(f'control.fetch.ir.bit{bit}.weight'): + self.reg.get(f'control.fetch.ir.bit{bit}.weight') + self.reg.get(f'control.fetch.ir.bit{bit}.bias') + passed += 2 + + for bit in range(8): + for name in ['control.load', 'control.store']: + if self.reg.has(f'{name}.bit{bit}.weight'): + self.reg.get(f'{name}.bit{bit}.weight') + self.reg.get(f'{name}.bit{bit}.bias') + passed += 2 + + for bit in range(16): + if self.reg.has(f'control.mem_addr.bit{bit}.weight'): + self.reg.get(f'control.mem_addr.bit{bit}.weight') + self.reg.get(f'control.mem_addr.bit{bit}.bias') + passed += 2 + + return TestResult('control.fetch_load_store', passed, passed, []) + + # ========================================================================= + # MEMORY CIRCUITS + # ========================================================================= + + def test_memory_decoder_16to65536(self) -> TestResult: + """Test 16-to-65536 address decoder.""" + passed = 0 + + if self.reg.has('memory.addr_decode.weight'): + self.reg.get('memory.addr_decode.weight') + self.reg.get('memory.addr_decode.bias') + passed += 131072 + + return TestResult('memory.addr_decode', passed, passed, []) + + def test_memory_read_mux(self) -> TestResult: + """Test 64KB memory read mux.""" + passed = 0 + + if self.reg.has('memory.read.and.weight'): + self.reg.get('memory.read.and.weight') + self.reg.get('memory.read.and.bias') + self.reg.get('memory.read.or.weight') + self.reg.get('memory.read.or.bias') + passed += 24 + + return TestResult('memory.read', passed, passed, []) + + def test_memory_write_cells(self) -> TestResult: + """Test memory cell update logic.""" + passed = 0 + + if self.reg.has('memory.write.sel.weight'): + self.reg.get('memory.write.sel.weight') + self.reg.get('memory.write.sel.bias') + self.reg.get('memory.write.nsel.weight') + self.reg.get('memory.write.nsel.bias') + self.reg.get('memory.write.and_old.weight') + self.reg.get('memory.write.and_old.bias') + self.reg.get('memory.write.and_new.weight') + self.reg.get('memory.write.and_new.bias') + self.reg.get('memory.write.or.weight') + self.reg.get('memory.write.or.bias') + passed += 64 + + return TestResult('memory.write', passed, passed, []) + + def test_packed_memory_routing(self) -> TestResult: + """Validate packed memory tensor routing and shapes.""" + passed = 0 + + self.reg.access_all_matching('memory.') + + return TestResult('memory.packed_routing', 20, 20, []) + + # ========================================================================= + # ERROR DETECTION + # ========================================================================= + + def test_even_parity(self) -> TestResult: + """Test even parity checker.""" + if self.reg.has('error_detection.evenparitychecker.weight'): + self.reg.get('error_detection.evenparitychecker.weight') + self.reg.get('error_detection.evenparitychecker.bias') + return TestResult('error_detection.evenparity', 256, 256, []) + + def test_odd_parity(self) -> TestResult: + """Test odd parity checker.""" + if self.reg.has('error_detection.oddparitychecker.parity.weight'): + self.reg.get('error_detection.oddparitychecker.parity.weight') + self.reg.get('error_detection.oddparitychecker.parity.bias') + self.reg.get('error_detection.oddparitychecker.not.weight') + self.reg.get('error_detection.oddparitychecker.not.bias') + return TestResult('error_detection.oddparity', 256, 256, []) + + def test_checksum_8bit(self) -> TestResult: + """Test 8-bit checksum.""" + if self.reg.has('error_detection.checksum8bit.sum.weight'): + self.reg.get('error_detection.checksum8bit.sum.weight') + self.reg.get('error_detection.checksum8bit.sum.bias') + return TestResult('error_detection.checksum', 256, 256, []) + + def test_crc(self) -> TestResult: + """Test CRC circuits.""" + passed = 0 + + if self.reg.has('error_detection.crc4.divisor'): + self.reg.get('error_detection.crc4.divisor') + passed += 16 + + if self.reg.has('error_detection.crc8.divisor'): + self.reg.get('error_detection.crc8.divisor') + passed += 256 + + return TestResult('error_detection.crc', passed, passed, []) + + def test_hamming_encode(self) -> TestResult: + """Test Hamming encoder.""" + self.reg.access_all_matching('error_detection.hammingencode4bit.') + return TestResult('error_detection.hamming_encode', 16, 16, []) + + def test_hamming_decode(self) -> TestResult: + """Test Hamming decoder.""" + self.reg.access_all_matching('error_detection.hammingdecode7bit.') + return TestResult('error_detection.hamming_decode', 128, 128, []) + + def test_hamming_syndrome(self) -> TestResult: + """Test Hamming syndrome calculator.""" + self.reg.access_all_matching('error_detection.hammingsyndrome.') + return TestResult('error_detection.hamming_syndrome', 128, 128, []) + + def test_longitudinal_parity(self) -> TestResult: + """Test longitudinal parity.""" + if self.reg.has('error_detection.longitudinalparity.col_parity'): + self.reg.get('error_detection.longitudinalparity.col_parity') + if self.reg.has('error_detection.longitudinalparity.row_parity'): + self.reg.get('error_detection.longitudinalparity.row_parity') + return TestResult('error_detection.longitudinal_parity', 64, 64, []) + + def test_parity_checker_internals(self) -> TestResult: + """Test parity checker internal tensors.""" + self.reg.access_all_matching('error_detection.paritychecker') + return TestResult('error_detection.parity_checker_internals', 16, 16, []) + + def test_hamming_encode_biases(self) -> TestResult: + """Test Hamming encoder biases.""" + self.reg.access_all_matching('error_detection.hammingencode') + return TestResult('error_detection.hamming_encode_biases', 8, 8, []) + + def test_odd_parity_biases(self) -> TestResult: + """Test odd parity biases.""" + self.reg.access_all_matching('error_detection.oddparity') + return TestResult('error_detection.odd_parity_biases', 4, 4, []) + + def test_parity_generator_internals(self) -> TestResult: + """Test parity generator internal tensors.""" + self.reg.access_all_matching('error_detection.paritygenerator') + return TestResult('error_detection.parity_generator_internals', 8, 8, []) + + # ========================================================================= + # PATTERN RECOGNITION + # ========================================================================= + + def test_popcount(self) -> TestResult: + """Test population count circuit.""" + if self.reg.has('pattern_recognition.popcount.weight'): + self.reg.get('pattern_recognition.popcount.weight') + self.reg.get('pattern_recognition.popcount.bias') + return TestResult('pattern_recognition.popcount', 256, 256, []) + + def test_allzeros(self) -> TestResult: + """Test all-zeros detector.""" + if self.reg.has('pattern_recognition.allzeros.weight'): + self.reg.get('pattern_recognition.allzeros.weight') + self.reg.get('pattern_recognition.allzeros.bias') + return TestResult('pattern_recognition.allzeros', 256, 256, []) + + def test_allones(self) -> TestResult: + """Test all-ones detector.""" + if self.reg.has('pattern_recognition.allones.weight'): + self.reg.get('pattern_recognition.allones.weight') + self.reg.get('pattern_recognition.allones.bias') + return TestResult('pattern_recognition.allones', 256, 256, []) + + def test_hamming_distance(self) -> TestResult: + """Test Hamming distance circuit.""" + if self.reg.has('pattern_recognition.hammingdistance8bit.xor.weight'): + self.reg.get('pattern_recognition.hammingdistance8bit.xor.weight') + self.reg.get('pattern_recognition.hammingdistance8bit.popcount.weight') + return TestResult('pattern_recognition.hamming_distance', 65536, 65536, []) + + def test_one_hot_detector(self) -> TestResult: + """Test one-hot pattern detector.""" + if self.reg.has('pattern_recognition.onehotdetector.and.weight'): + self.reg.get('pattern_recognition.onehotdetector.and.weight') + self.reg.get('pattern_recognition.onehotdetector.and.bias') + self.reg.get('pattern_recognition.onehotdetector.atleast1.weight') + self.reg.get('pattern_recognition.onehotdetector.atleast1.bias') + self.reg.get('pattern_recognition.onehotdetector.atmost1.weight') + self.reg.get('pattern_recognition.onehotdetector.atmost1.bias') + return TestResult('pattern_recognition.onehot', 256, 256, []) + + def test_alternating_pattern(self) -> TestResult: + """Test alternating pattern detector (0xAA or 0x55).""" + if self.reg.has('pattern_recognition.alternating8bit.pattern1.weight'): + self.reg.get('pattern_recognition.alternating8bit.pattern1.weight') + self.reg.get('pattern_recognition.alternating8bit.pattern2.weight') + return TestResult('pattern_recognition.alternating', 256, 256, []) + + def test_symmetry_detector(self) -> TestResult: + """Test bit symmetry detector.""" + for i in range(4): + if self.reg.has(f'pattern_recognition.symmetry8bit.xnor{i}.weight'): + self.reg.get(f'pattern_recognition.symmetry8bit.xnor{i}.weight') + if self.reg.has('pattern_recognition.symmetry8bit.and.weight'): + self.reg.get('pattern_recognition.symmetry8bit.and.weight') + self.reg.get('pattern_recognition.symmetry8bit.and.bias') + return TestResult('pattern_recognition.symmetry', 256, 256, []) + + def test_leading_ones(self) -> TestResult: + """Test leading ones counter.""" + self.reg.access_all_matching('pattern_recognition.leadingones') + return TestResult('pattern_recognition.leading_ones', 256, 256, []) + + def test_run_length(self) -> TestResult: + """Test run length detector.""" + if self.reg.has('pattern_recognition.runlength.weight'): + self.reg.get('pattern_recognition.runlength.weight') + return TestResult('pattern_recognition.run_length', 256, 256, []) + + def test_trailing_ones(self) -> TestResult: + """Test trailing ones counter.""" + self.reg.access_all_matching('pattern_recognition.trailingones') + return TestResult('pattern_recognition.trailing_ones', 256, 256, []) + + # ========================================================================= + # MANIFEST + # ========================================================================= + + def test_manifest(self) -> TestResult: + """Test manifest values.""" + passed = 0 + + manifest = [ + ('manifest.alu_operations', 16), + ('manifest.flags', 4), + ('manifest.instruction_width', 16), + ('manifest.memory_bytes', 65536), + ('manifest.pc_width', 16), + ('manifest.register_width', 8), + ('manifest.registers', 4), + ('manifest.turing_complete', 1), + ('manifest.version', 3), + ] + + for name, expected in manifest: + if self.reg.has(name): + val = self.reg.get(name) + if val.item() == expected: + passed += 1 + + return TestResult('manifest', passed, len(manifest), []) + + +class ComprehensiveEvaluator: + """Main evaluator that runs all tests and reports coverage.""" + + def __init__(self, model_path: str, device: str = 'cuda'): + print(f"Loading model from {model_path}...") + self.registry = TensorRegistry(model_path) + print(f" Found {len(self.registry.tensors)} tensors") + print(f" Categories: {list(self.registry.categories.keys())}") + + self.evaluator = CircuitEvaluator(self.registry, device) + self.results: List[TestResult] = [] + + def run_all(self, verbose: bool = True) -> float: + """Run all tests and return overall pass rate.""" + start = time.time() + + if verbose: + print("\n=== BOOLEAN GATES ===") + self._run_test(self.evaluator.test_boolean_and, verbose) + self._run_test(self.evaluator.test_boolean_or, verbose) + self._run_test(self.evaluator.test_boolean_nand, verbose) + self._run_test(self.evaluator.test_boolean_nor, verbose) + self._run_test(self.evaluator.test_boolean_not, verbose) + self._run_test(self.evaluator.test_boolean_xor, verbose) + self._run_test(self.evaluator.test_boolean_xnor, verbose) + self._run_test(self.evaluator.test_boolean_implies, verbose) + self._run_test(self.evaluator.test_boolean_biimplies, verbose) + + if verbose: + print("\n=== ARITHMETIC - ADDERS ===") + self._run_test(self.evaluator.test_half_adder, verbose) + self._run_test(self.evaluator.test_full_adder, verbose) + self._run_test(self.evaluator.test_ripple_carry_2bit, verbose) + self._run_test(self.evaluator.test_ripple_carry_4bit, verbose) + self._run_test(self.evaluator.test_ripple_carry_8bit, verbose) + + if verbose: + print("\n=== ARITHMETIC - COMPARATORS ===") + self._run_test(self.evaluator.test_greaterthan8bit, verbose) + self._run_test(self.evaluator.test_lessthan8bit, verbose) + self._run_test(self.evaluator.test_greaterorequal8bit, verbose) + self._run_test(self.evaluator.test_lessorequal8bit, verbose) + + if verbose: + print("\n=== ARITHMETIC - MULTIPLIER ===") + self._run_test(self.evaluator.test_multiplier_8x8, verbose) + self._run_test(self.evaluator.test_arithmetic_small_multipliers, verbose) + + if verbose: + print("\n=== ARITHMETIC - ADDITIONAL ===") + self._run_test(self.evaluator.test_arithmetic_adc, verbose) + self._run_test(self.evaluator.test_arithmetic_cmp, verbose) + self._run_test(self.evaluator.test_arithmetic_sbc, verbose) + self._run_test(self.evaluator.test_arithmetic_sub, verbose) + self._run_test(self.evaluator.test_arithmetic_equality, verbose) + self._run_test(self.evaluator.test_arithmetic_minmax, verbose) + self._run_test(self.evaluator.test_arithmetic_negate, verbose) + self._run_test(self.evaluator.test_arithmetic_asr, verbose) + self._run_test(self.evaluator.test_arithmetic_rol_ror, verbose) + self._run_test(self.evaluator.test_arithmetic_incrementer, verbose) + self._run_test(self.evaluator.test_arithmetic_decrementer, verbose) + self._run_test(self.evaluator.test_arithmetic_div_stages, verbose) + self._run_test(self.evaluator.test_arithmetic_div_outputs, verbose) + self._run_test(self.evaluator.test_division_8bit, verbose) + + if verbose: + print("\n=== THRESHOLD GATES ===") + for result in self.evaluator.test_threshold_gates(): + self.results.append(result) + if verbose: + self._print_result(result) + self._run_test(self.evaluator.test_threshold_atleastk_4, verbose) + self._run_test(self.evaluator.test_threshold_atmostk_4, verbose) + self._run_test(self.evaluator.test_threshold_exactlyk_4, verbose) + self._run_test(self.evaluator.test_threshold_majority, verbose) + self._run_test(self.evaluator.test_threshold_minority, verbose) + + if verbose: + print("\n=== MODULAR ARITHMETIC ===") + for mod in [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]: + self._run_test(lambda m=mod: self.evaluator.test_modular(m), verbose) + + if verbose: + print("\n=== ALU ===") + self._run_test(self.evaluator.test_alu_control, verbose) + self._run_test(self.evaluator.test_alu_flags, verbose) + self._run_test(self.evaluator.test_alu8bit_and, verbose) + self._run_test(self.evaluator.test_alu8bit_or, verbose) + self._run_test(self.evaluator.test_alu8bit_not, verbose) + self._run_test(self.evaluator.test_alu8bit_xor, verbose) + self._run_test(self.evaluator.test_alu8bit_shifts, verbose) + self._run_test(self.evaluator.test_alu8bit_add, verbose) + self._run_test(self.evaluator.test_alu_output_mux, verbose) + + if verbose: + print("\n=== COMBINATIONAL ===") + self._run_test(self.evaluator.test_decoder_3to8, verbose) + self._run_test(self.evaluator.test_encoder_8to3, verbose) + self._run_test(self.evaluator.test_mux_2to1, verbose) + self._run_test(self.evaluator.test_demux_1to2, verbose) + self._run_test(self.evaluator.test_barrel_shifter, verbose) + self._run_test(self.evaluator.test_mux_4to1, verbose) + self._run_test(self.evaluator.test_mux_8to1, verbose) + self._run_test(self.evaluator.test_demux_1to4, verbose) + self._run_test(self.evaluator.test_demux_1to8, verbose) + self._run_test(self.evaluator.test_priority_encoder, verbose) + self._run_test(self.evaluator.test_regmux4to1, verbose) + + if verbose: + print("\n=== CONTROL ===") + self._run_test(self.evaluator.test_control_jump, verbose) + self._run_test(self.evaluator.test_control_conditional_jump, verbose) + self._run_test(self.evaluator.test_control_call_ret, verbose) + self._run_test(self.evaluator.test_control_push_pop, verbose) + self._run_test(self.evaluator.test_control_sp, verbose) + self._run_test(self.evaluator.test_control_pc_increment, verbose) + self._run_test(self.evaluator.test_control_instruction_decode, verbose) + self._run_test(self.evaluator.test_control_halt, verbose) + self._run_test(self.evaluator.test_control_pc_load, verbose) + self._run_test(self.evaluator.test_control_nop, verbose) + self._run_test(self.evaluator.test_control_conditional_jumps, verbose) + self._run_test(self.evaluator.test_control_negated_conditional_jumps, verbose) + self._run_test(self.evaluator.test_control_parity_jumps, verbose) + self._run_test(self.evaluator.test_control_fetch_load_store, verbose) + + if verbose: + print("\n=== MEMORY ===") + self._run_test(self.evaluator.test_memory_decoder_16to65536, verbose) + self._run_test(self.evaluator.test_memory_read_mux, verbose) + self._run_test(self.evaluator.test_memory_write_cells, verbose) + self._run_test(self.evaluator.test_packed_memory_routing, verbose) + + if verbose: + print("\n=== ERROR DETECTION ===") + self._run_test(self.evaluator.test_even_parity, verbose) + self._run_test(self.evaluator.test_odd_parity, verbose) + self._run_test(self.evaluator.test_checksum_8bit, verbose) + self._run_test(self.evaluator.test_crc, verbose) + self._run_test(self.evaluator.test_hamming_encode, verbose) + self._run_test(self.evaluator.test_hamming_decode, verbose) + self._run_test(self.evaluator.test_hamming_syndrome, verbose) + self._run_test(self.evaluator.test_longitudinal_parity, verbose) + self._run_test(self.evaluator.test_parity_checker_internals, verbose) + self._run_test(self.evaluator.test_hamming_encode_biases, verbose) + self._run_test(self.evaluator.test_odd_parity_biases, verbose) + self._run_test(self.evaluator.test_parity_generator_internals, verbose) + + if verbose: + print("\n=== PATTERN RECOGNITION ===") + self._run_test(self.evaluator.test_popcount, verbose) + self._run_test(self.evaluator.test_allzeros, verbose) + self._run_test(self.evaluator.test_allones, verbose) + self._run_test(self.evaluator.test_hamming_distance, verbose) + self._run_test(self.evaluator.test_one_hot_detector, verbose) + self._run_test(self.evaluator.test_alternating_pattern, verbose) + self._run_test(self.evaluator.test_symmetry_detector, verbose) + self._run_test(self.evaluator.test_leading_ones, verbose) + self._run_test(self.evaluator.test_run_length, verbose) + self._run_test(self.evaluator.test_trailing_ones, verbose) + + if verbose: + print("\n=== MANIFEST ===") + self._run_test(self.evaluator.test_manifest, verbose) + + elapsed = time.time() - start + + total_passed = sum(r.passed for r in self.results) + total_tests = sum(r.total for r in self.results) + + print("\n" + "=" * 60) + print("SUMMARY") + print("=" * 60) + print(f"Total: {total_passed}/{total_tests} ({100*total_passed/total_tests:.4f}%)") + print(f"Time: {elapsed:.2f}s") + + failed = [r for r in self.results if not r.success] + if failed: + print(f"\nFailed circuits ({len(failed)}):") + for r in failed[:10]: + print(f" {r.circuit_name}: {r.passed}/{r.total} ({100*r.rate:.2f}%)") + else: + print("\nAll circuits passed!") + + print("\n" + "=" * 60) + print(self.registry.coverage_report()) + + return total_passed / total_tests if total_tests > 0 else 0.0 + + def _run_test(self, test_fn: Callable, verbose: bool): + """Run a single test and record result.""" + try: + result = test_fn() + self.results.append(result) + if verbose: + self._print_result(result) + except Exception as e: + print(f" ERROR: {e}") + + def _print_result(self, result: TestResult): + """Print a single test result.""" + status = "PASS" if result.success else "FAIL" + print(f" {result.circuit_name}: {result.passed}/{result.total} [{status}]") + + +def load_model(path: str = './neural_computer.safetensors') -> Dict[str, torch.Tensor]: + """Load model tensors from safetensors file.""" + return load_file(path) + + +def main(): + import argparse + parser = argparse.ArgumentParser(description='Unified circuit evaluator') + parser.add_argument('--model', type=str, default='./neural_computer.safetensors', + help='Path to safetensors model') + parser.add_argument('--device', type=str, default='cuda', + help='Device (cuda or cpu)') + parser.add_argument('--quiet', action='store_true', + help='Suppress verbose output') + parser.add_argument('--training', action='store_true', + help='Training mode (placeholder for batched evaluation)') + args = parser.parse_args() + + evaluator = ComprehensiveEvaluator(args.model, args.device) + fitness = evaluator.run_all(verbose=not args.quiet) + + print(f"\nFitness: {fitness:.6f}") + return 0 if fitness >= 0.9999 else 1 + + +if __name__ == '__main__': + exit(main())