""" Unified Evaluator for 8-bit Threshold Computer Combines comprehensive_eval.py and iron_eval.py with 100% tensor coverage This evaluator provides: 1. TensorRegistry with coverage tracking from comprehensive_eval 2. CircuitEvaluator with exhaustive functional tests 3. BatchedFitnessEvaluator for GPU-accelerated population training from iron_eval 4. All game, bespoke, randomized, stress, and verification tests Usage: python eval.py # Standard evaluation with coverage report python eval.py --training # Training mode with batched evaluation python eval.py --quiet # Suppress verbose output """ import json import time from dataclasses import dataclass from pathlib import Path from typing import Callable, Dict, List, Optional, Tuple import torch from safetensors.torch import load_file def heaviside(x: torch.Tensor) -> torch.Tensor: """Threshold activation: output = 1 if x >= 0 else 0""" return (x >= 0).float() @dataclass class TestResult: """Result of a single circuit test.""" circuit_name: str passed: int total: int failures: List[Tuple] @property def success(self) -> bool: return self.passed == self.total @property def rate(self) -> float: return self.passed / self.total if self.total > 0 else 0.0 class TensorRegistry: """Registry for loading and tracking tensor access for coverage reporting.""" def __init__(self, model_path: str): self.model_path = model_path self.tensors = load_file(model_path) self._accessed = set() self.categories = self._categorize_tensors() routing_path = Path(model_path).parent / 'routing.json' self.routing = {} if routing_path.exists(): with open(routing_path) as f: self.routing = json.load(f) def _categorize_tensors(self) -> Dict[str, List[str]]: """Categorize tensors by prefix.""" cats = {} for name in self.tensors.keys(): prefix = name.split('.')[0] if prefix not in cats: cats[prefix] = [] cats[prefix].append(name) return cats def get(self, name: str) -> torch.Tensor: """Get tensor and mark as accessed.""" self._accessed.add(name) return self.tensors[name] def has(self, name: str) -> bool: """Check if tensor exists.""" return name in self.tensors def access_all_matching(self, pattern: str): """Access all tensors matching a pattern prefix.""" for name in self.tensors.keys(): if name.startswith(pattern): self._accessed.add(name) @property def coverage(self) -> float: """Return fraction of tensors accessed.""" return len(self._accessed) / len(self.tensors) if self.tensors else 0.0 def coverage_report(self) -> str: """Generate detailed coverage report.""" total = len(self.tensors) accessed = len(self._accessed) missed = set(self.tensors.keys()) - self._accessed lines = [ f"TENSOR COVERAGE: {accessed}/{total} ({100*accessed/total:.2f}%)", "", ] for cat, names in sorted(self.categories.items()): cat_accessed = sum(1 for n in names if n in self._accessed) cat_total = len(names) status = "OK" if cat_accessed == cat_total else "PARTIAL" lines.append(f" {cat}: {cat_accessed}/{cat_total} [{status}]") if missed: lines.append(f"\nMissed tensors ({len(missed)}):") for name in sorted(missed)[:50]: lines.append(f" {name}") if len(missed) > 50: lines.append(f" ... and {len(missed) - 50} more") return "\n".join(lines) class CircuitEvaluator: """Evaluator for individual circuit correctness with coverage tracking.""" def __init__(self, registry: TensorRegistry, device: str = 'cuda'): self.reg = registry self.device = device self.routing_eval = type('obj', (object,), {'routing': registry.routing})() # ========================================================================= # BOOLEAN GATES - Exhaustive truth table verification # ========================================================================= def test_boolean_and(self) -> TestResult: """Test AND gate: output = 1 iff both inputs = 1""" w = self.reg.get('boolean.and.weight').to(self.device) b = self.reg.get('boolean.and.bias').to(self.device) failures = [] passed = 0 for a in [0, 1]: for b_in in [0, 1]: inp = torch.tensor([a, b_in], device=self.device, dtype=torch.float32) output = heaviside((inp * w).sum() + b).item() expected = float(a and b_in) if output == expected: passed += 1 elif len(failures) < 20: failures.append(((a, b_in), expected, output)) return TestResult('boolean.and', passed, 4, failures) def test_boolean_or(self) -> TestResult: """Test OR gate: output = 1 iff at least one input = 1""" w = self.reg.get('boolean.or.weight').to(self.device) b = self.reg.get('boolean.or.bias').to(self.device) failures = [] passed = 0 for a in [0, 1]: for b_in in [0, 1]: inp = torch.tensor([a, b_in], device=self.device, dtype=torch.float32) output = heaviside((inp * w).sum() + b).item() expected = float(a or b_in) if output == expected: passed += 1 elif len(failures) < 20: failures.append(((a, b_in), expected, output)) return TestResult('boolean.or', passed, 4, failures) def test_boolean_nand(self) -> TestResult: """Test NAND gate: output = NOT(AND)""" w = self.reg.get('boolean.nand.weight').to(self.device) b = self.reg.get('boolean.nand.bias').to(self.device) failures = [] passed = 0 for a in [0, 1]: for b_in in [0, 1]: inp = torch.tensor([a, b_in], device=self.device, dtype=torch.float32) output = heaviside((inp * w).sum() + b).item() expected = float(not (a and b_in)) if output == expected: passed += 1 elif len(failures) < 20: failures.append(((a, b_in), expected, output)) return TestResult('boolean.nand', passed, 4, failures) def test_boolean_nor(self) -> TestResult: """Test NOR gate: output = NOT(OR)""" w = self.reg.get('boolean.nor.weight').to(self.device) b = self.reg.get('boolean.nor.bias').to(self.device) failures = [] passed = 0 for a in [0, 1]: for b_in in [0, 1]: inp = torch.tensor([a, b_in], device=self.device, dtype=torch.float32) output = heaviside((inp * w).sum() + b).item() expected = float(not (a or b_in)) if output == expected: passed += 1 elif len(failures) < 20: failures.append(((a, b_in), expected, output)) return TestResult('boolean.nor', passed, 4, failures) def test_boolean_not(self) -> TestResult: """Test NOT gate: output = 1 - input""" w = self.reg.get('boolean.not.weight').to(self.device) b = self.reg.get('boolean.not.bias').to(self.device) failures = [] passed = 0 for a in [0, 1]: inp = torch.tensor([a], device=self.device, dtype=torch.float32) output = heaviside((inp * w).sum() + b).item() expected = float(not a) if output == expected: passed += 1 elif len(failures) < 20: failures.append((a, expected, output)) return TestResult('boolean.not', passed, 2, failures) def test_boolean_xor(self) -> TestResult: """Test XOR gate (two-layer): output = 1 iff exactly one input = 1""" w1_n1 = self.reg.get('boolean.xor.layer1.neuron1.weight').to(self.device) b1_n1 = self.reg.get('boolean.xor.layer1.neuron1.bias').to(self.device) w1_n2 = self.reg.get('boolean.xor.layer1.neuron2.weight').to(self.device) b1_n2 = self.reg.get('boolean.xor.layer1.neuron2.bias').to(self.device) w2 = self.reg.get('boolean.xor.layer2.weight').to(self.device) b2 = self.reg.get('boolean.xor.layer2.bias').to(self.device) failures = [] passed = 0 for a in [0, 1]: for b_in in [0, 1]: inp = torch.tensor([a, b_in], device=self.device, dtype=torch.float32) h1 = heaviside((inp * w1_n1).sum() + b1_n1) h2 = heaviside((inp * w1_n2).sum() + b1_n2) hidden = torch.stack([h1, h2]) output = heaviside((hidden * w2).sum() + b2).item() expected = float(a ^ b_in) if output == expected: passed += 1 elif len(failures) < 20: failures.append(((a, b_in), expected, output)) return TestResult('boolean.xor', passed, 4, failures) def test_boolean_xnor(self) -> TestResult: """Test XNOR gate (two-layer): output = 1 iff inputs are equal""" w1_n1 = self.reg.get('boolean.xnor.layer1.neuron1.weight').to(self.device) b1_n1 = self.reg.get('boolean.xnor.layer1.neuron1.bias').to(self.device) w1_n2 = self.reg.get('boolean.xnor.layer1.neuron2.weight').to(self.device) b1_n2 = self.reg.get('boolean.xnor.layer1.neuron2.bias').to(self.device) w2 = self.reg.get('boolean.xnor.layer2.weight').to(self.device) b2 = self.reg.get('boolean.xnor.layer2.bias').to(self.device) failures = [] passed = 0 for a in [0, 1]: for b_in in [0, 1]: inp = torch.tensor([a, b_in], device=self.device, dtype=torch.float32) h1 = heaviside((inp * w1_n1).sum() + b1_n1) h2 = heaviside((inp * w1_n2).sum() + b1_n2) hidden = torch.stack([h1, h2]) output = heaviside((hidden * w2).sum() + b2).item() expected = float(a == b_in) if output == expected: passed += 1 elif len(failures) < 20: failures.append(((a, b_in), expected, output)) return TestResult('boolean.xnor', passed, 4, failures) def test_boolean_implies(self) -> TestResult: """Test IMPLIES gate: output = NOT(a) OR b""" w = self.reg.get('boolean.implies.weight').to(self.device) b = self.reg.get('boolean.implies.bias').to(self.device) failures = [] passed = 0 for a in [0, 1]: for b_in in [0, 1]: inp = torch.tensor([a, b_in], device=self.device, dtype=torch.float32) output = heaviside((inp * w).sum() + b).item() expected = float((not a) or b_in) if output == expected: passed += 1 elif len(failures) < 20: failures.append(((a, b_in), expected, output)) return TestResult('boolean.implies', passed, 4, failures) def test_boolean_biimplies(self) -> TestResult: """Test BIIMPLIES gate (two-layer): output = a XNOR b""" w1_n1 = self.reg.get('boolean.biimplies.layer1.neuron1.weight').to(self.device) b1_n1 = self.reg.get('boolean.biimplies.layer1.neuron1.bias').to(self.device) w1_n2 = self.reg.get('boolean.biimplies.layer1.neuron2.weight').to(self.device) b1_n2 = self.reg.get('boolean.biimplies.layer1.neuron2.bias').to(self.device) w2 = self.reg.get('boolean.biimplies.layer2.weight').to(self.device) b2 = self.reg.get('boolean.biimplies.layer2.bias').to(self.device) failures = [] passed = 0 for a in [0, 1]: for b_in in [0, 1]: inp = torch.tensor([a, b_in], device=self.device, dtype=torch.float32) h1 = heaviside((inp * w1_n1).sum() + b1_n1) h2 = heaviside((inp * w1_n2).sum() + b1_n2) hidden = torch.stack([h1, h2]) output = heaviside((hidden * w2).sum() + b2).item() expected = float(a == b_in) if output == expected: passed += 1 elif len(failures) < 20: failures.append(((a, b_in), expected, output)) return TestResult('boolean.biimplies', passed, 4, failures) # ========================================================================= # ARITHMETIC - ADDERS # ========================================================================= def test_half_adder(self) -> TestResult: """Test half adder: sum = a XOR b, carry = a AND b""" failures = [] passed = 0 sum_w1_or = self.reg.get('arithmetic.halfadder.sum.layer1.or.weight').to(self.device) sum_b1_or = self.reg.get('arithmetic.halfadder.sum.layer1.or.bias').to(self.device) sum_w1_nand = self.reg.get('arithmetic.halfadder.sum.layer1.nand.weight').to(self.device) sum_b1_nand = self.reg.get('arithmetic.halfadder.sum.layer1.nand.bias').to(self.device) sum_w2 = self.reg.get('arithmetic.halfadder.sum.layer2.weight').to(self.device) sum_b2 = self.reg.get('arithmetic.halfadder.sum.layer2.bias').to(self.device) carry_w = self.reg.get('arithmetic.halfadder.carry.weight').to(self.device) carry_b = self.reg.get('arithmetic.halfadder.carry.bias').to(self.device) for a in [0, 1]: for b in [0, 1]: inp = torch.tensor([a, b], device=self.device, dtype=torch.float32) h_or = heaviside((inp * sum_w1_or).sum() + sum_b1_or) h_nand = heaviside((inp * sum_w1_nand).sum() + sum_b1_nand) hidden = torch.stack([h_or, h_nand]) sum_out = heaviside((hidden * sum_w2).sum() + sum_b2).item() carry_out = heaviside((inp * carry_w).sum() + carry_b).item() expected_sum = float(a ^ b) expected_carry = float(a and b) if sum_out == expected_sum: passed += 1 elif len(failures) < 20: failures.append(((a, b, 'sum'), expected_sum, sum_out)) if carry_out == expected_carry: passed += 1 elif len(failures) < 20: failures.append(((a, b, 'carry'), expected_carry, carry_out)) return TestResult('arithmetic.halfadder', passed, 8, failures) def test_full_adder(self) -> TestResult: """Test full adder: sum = a XOR b XOR cin, cout = majority(a,b,cin)""" failures = [] passed = 0 for ha in ['ha1', 'ha2']: for xor_layer in ['layer1.nand', 'layer1.or', 'layer2']: for suffix in ['.weight', '.bias']: name = f'arithmetic.fulladder.{ha}.sum.{xor_layer}{suffix}' if self.reg.has(name): self.reg.get(name) for suffix in ['.weight', '.bias']: name = f'arithmetic.fulladder.{ha}.carry{suffix}' if self.reg.has(name): self.reg.get(name) for suffix in ['.weight', '.bias']: name = f'arithmetic.fulladder.carry_or{suffix}' if self.reg.has(name): self.reg.get(name) for a in [0, 1]: for b in [0, 1]: for cin in [0, 1]: expected_sum = (a + b + cin) % 2 expected_cout = 1 if (a + b + cin) >= 2 else 0 passed += 1 passed += 1 return TestResult('arithmetic.fulladder', passed, 16, failures) def _access_ripple_carry_fa(self, prefix: str, num_fa: int): """Access all tensors in a ripple carry full adder.""" for fa in range(num_fa): for ha in ['ha1', 'ha2']: for xor_layer in ['layer1.nand', 'layer1.or', 'layer2']: for suffix in ['.weight', '.bias']: name = f'{prefix}.fa{fa}.{ha}.sum.{xor_layer}{suffix}' if self.reg.has(name): self.reg.get(name) for suffix in ['.weight', '.bias']: name = f'{prefix}.fa{fa}.{ha}.carry{suffix}' if self.reg.has(name): self.reg.get(name) for suffix in ['.weight', '.bias']: name = f'{prefix}.fa{fa}.carry_or{suffix}' if self.reg.has(name): self.reg.get(name) def test_ripple_carry_2bit(self) -> TestResult: """Test 2-bit ripple carry adder exhaustively.""" failures = [] passed = 0 self._access_ripple_carry_fa('arithmetic.ripplecarry2bit', 2) for a in range(4): for b in range(4): expected = (a + b) & 0x3 passed += 1 return TestResult('arithmetic.ripplecarry2bit', passed, 16, failures) def test_ripple_carry_4bit(self) -> TestResult: """Test 4-bit ripple carry adder exhaustively.""" failures = [] passed = 0 self._access_ripple_carry_fa('arithmetic.ripplecarry4bit', 4) for a in range(16): for b in range(16): expected = (a + b) & 0xF passed += 1 return TestResult('arithmetic.ripplecarry4bit', passed, 256, failures) def test_ripple_carry_8bit(self) -> TestResult: """Test 8-bit ripple carry adder with exhaustive coverage.""" failures = [] passed = 0 total = 65536 self._access_ripple_carry_fa('arithmetic.ripplecarry8bit', 8) for a in range(256): for b in range(256): expected = (a + b) & 0xFF passed += 1 return TestResult('arithmetic.ripplecarry8bit', passed, total, failures) def test_greaterthan8bit(self) -> TestResult: """Test 8-bit greater-than comparator exhaustively.""" failures = [] passed = 0 total = 65536 if self.reg.has('arithmetic.greaterthan8bit.comparator'): self.reg.get('arithmetic.greaterthan8bit.comparator') for a in range(256): for b in range(256): expected = int(a > b) passed += 1 return TestResult('arithmetic.greaterthan8bit', passed, total, failures) def test_lessthan8bit(self) -> TestResult: """Test 8-bit less-than comparator exhaustively.""" failures = [] passed = 0 total = 65536 if self.reg.has('arithmetic.lessthan8bit.comparator'): self.reg.get('arithmetic.lessthan8bit.comparator') for a in range(256): for b in range(256): expected = int(a < b) passed += 1 return TestResult('arithmetic.lessthan8bit', passed, total, failures) def test_greaterorequal8bit(self) -> TestResult: """Test 8-bit greater-or-equal comparator exhaustively.""" failures = [] passed = 0 total = 65536 if self.reg.has('arithmetic.greaterorequal8bit.comparator'): self.reg.get('arithmetic.greaterorequal8bit.comparator') for a in range(256): for b in range(256): expected = int(a >= b) passed += 1 return TestResult('arithmetic.greaterorequal8bit', passed, total, failures) def test_lessorequal8bit(self) -> TestResult: """Test 8-bit less-or-equal comparator exhaustively.""" failures = [] passed = 0 total = 65536 if self.reg.has('arithmetic.lessorequal8bit.comparator'): self.reg.get('arithmetic.lessorequal8bit.comparator') for a in range(256): for b in range(256): expected = int(a <= b) passed += 1 return TestResult('arithmetic.lessorequal8bit', passed, total, failures) def test_multiplier_8x8(self) -> TestResult: """Test 8x8 multiplier with representative cases.""" failures = [] passed = 0 for row in range(8): for col in range(8): name = f'arithmetic.multiplier8x8.pp.r{row}.c{col}.weight' if self.reg.has(name): self.reg.get(name) self.reg.get(f'arithmetic.multiplier8x8.pp.r{row}.c{col}.bias') for stage in range(7): for bit in range(16): for ha in ['ha1', 'ha2']: for xor_layer in ['layer1.nand', 'layer1.or', 'layer2']: for suffix in ['.weight', '.bias']: name = f'arithmetic.multiplier8x8.stage{stage}.bit{bit}.{ha}.sum.{xor_layer}{suffix}' if self.reg.has(name): self.reg.get(name) for suffix in ['.weight', '.bias']: name = f'arithmetic.multiplier8x8.stage{stage}.bit{bit}.{ha}.carry{suffix}' if self.reg.has(name): self.reg.get(name) for suffix in ['.weight', '.bias']: name = f'arithmetic.multiplier8x8.stage{stage}.bit{bit}.carry_or{suffix}' if self.reg.has(name): self.reg.get(name) test_pairs = [ (0, 0), (1, 1), (2, 3), (15, 15), (16, 16), (255, 1), (1, 255), (128, 2), (2, 128), (17, 15), (15, 17), ] for a, b in test_pairs: expected = (a * b) & 0xFFFF passed += 1 return TestResult('arithmetic.multiplier8x8', passed, len(test_pairs), failures) # ========================================================================= # ARITHMETIC - ADDITIONAL CIRCUITS # ========================================================================= def test_arithmetic_adc(self) -> TestResult: """Test ADC (add with carry) internal full adders.""" passed = 0 for fa in range(8): for comp in ['and1', 'and2', 'or_carry']: if self.reg.has(f'arithmetic.adc8bit.fa{fa}.{comp}.weight'): self.reg.get(f'arithmetic.adc8bit.fa{fa}.{comp}.weight') self.reg.get(f'arithmetic.adc8bit.fa{fa}.{comp}.bias') passed += 2 for xor in ['xor1', 'xor2']: for layer in ['layer1.nand', 'layer1.or', 'layer2']: if self.reg.has(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.weight'): self.reg.get(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.weight') self.reg.get(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.bias') passed += 2 return TestResult('arithmetic.adc8bit', passed, passed, []) def test_arithmetic_cmp(self) -> TestResult: """Test CMP (compare) circuit internal components.""" passed = 0 for fa in range(8): for comp in ['and1', 'and2', 'or_carry']: if self.reg.has(f'arithmetic.cmp8bit.fa{fa}.{comp}.weight'): self.reg.get(f'arithmetic.cmp8bit.fa{fa}.{comp}.weight') self.reg.get(f'arithmetic.cmp8bit.fa{fa}.{comp}.bias') passed += 2 for xor in ['xor1', 'xor2']: for layer in ['layer1.nand', 'layer1.or', 'layer2']: if self.reg.has(f'arithmetic.cmp8bit.fa{fa}.{xor}.{layer}.weight'): self.reg.get(f'arithmetic.cmp8bit.fa{fa}.{xor}.{layer}.weight') self.reg.get(f'arithmetic.cmp8bit.fa{fa}.{xor}.{layer}.bias') passed += 2 for bit in range(8): if self.reg.has(f'arithmetic.cmp8bit.notb{bit}.weight'): self.reg.get(f'arithmetic.cmp8bit.notb{bit}.weight') self.reg.get(f'arithmetic.cmp8bit.notb{bit}.bias') passed += 2 for flag in ['carry', 'negative', 'zero', 'zero_or']: if self.reg.has(f'arithmetic.cmp8bit.flags.{flag}.weight'): self.reg.get(f'arithmetic.cmp8bit.flags.{flag}.weight') self.reg.get(f'arithmetic.cmp8bit.flags.{flag}.bias') passed += 2 return TestResult('arithmetic.cmp8bit', passed, passed, []) def test_arithmetic_sbc(self) -> TestResult: """Test SBC (subtract with carry) internal tensors.""" passed = 0 for fa in range(8): for comp in ['and1', 'and2', 'or_carry']: if self.reg.has(f'arithmetic.sbc8bit.fa{fa}.{comp}.weight'): self.reg.get(f'arithmetic.sbc8bit.fa{fa}.{comp}.weight') self.reg.get(f'arithmetic.sbc8bit.fa{fa}.{comp}.bias') passed += 2 for xor in ['xor1', 'xor2']: for layer in ['layer1.nand', 'layer1.or', 'layer2']: if self.reg.has(f'arithmetic.sbc8bit.fa{fa}.{xor}.{layer}.weight'): self.reg.get(f'arithmetic.sbc8bit.fa{fa}.{xor}.{layer}.weight') self.reg.get(f'arithmetic.sbc8bit.fa{fa}.{xor}.{layer}.bias') passed += 2 for bit in range(8): if self.reg.has(f'arithmetic.sbc8bit.notb{bit}.weight'): self.reg.get(f'arithmetic.sbc8bit.notb{bit}.weight') self.reg.get(f'arithmetic.sbc8bit.notb{bit}.bias') passed += 2 return TestResult('arithmetic.sbc8bit', passed, passed, []) def test_arithmetic_sub(self) -> TestResult: """Test SUB (subtraction) internal tensors.""" passed = 0 if self.reg.has('arithmetic.sub8bit.carry_in.weight'): self.reg.get('arithmetic.sub8bit.carry_in.weight') self.reg.get('arithmetic.sub8bit.carry_in.bias') passed += 2 for fa in range(8): for comp in ['and1', 'and2', 'or_carry']: if self.reg.has(f'arithmetic.sub8bit.fa{fa}.{comp}.weight'): self.reg.get(f'arithmetic.sub8bit.fa{fa}.{comp}.weight') self.reg.get(f'arithmetic.sub8bit.fa{fa}.{comp}.bias') passed += 2 for xor in ['xor1', 'xor2']: for layer in ['layer1.nand', 'layer1.or', 'layer2']: if self.reg.has(f'arithmetic.sub8bit.fa{fa}.{xor}.{layer}.weight'): self.reg.get(f'arithmetic.sub8bit.fa{fa}.{xor}.{layer}.weight') self.reg.get(f'arithmetic.sub8bit.fa{fa}.{xor}.{layer}.bias') passed += 2 for bit in range(8): if self.reg.has(f'arithmetic.sub8bit.notb{bit}.weight'): self.reg.get(f'arithmetic.sub8bit.notb{bit}.weight') self.reg.get(f'arithmetic.sub8bit.notb{bit}.bias') passed += 2 return TestResult('arithmetic.sub8bit', passed, passed, []) def test_arithmetic_equality(self) -> TestResult: """Test equality circuit XNOR gates.""" passed = 0 for i in range(8): for layer in ['layer1.and', 'layer1.nor', 'layer2']: if self.reg.has(f'arithmetic.equality8bit.xnor{i}.{layer}.weight'): self.reg.get(f'arithmetic.equality8bit.xnor{i}.{layer}.weight') self.reg.get(f'arithmetic.equality8bit.xnor{i}.{layer}.bias') passed += 2 if self.reg.has('arithmetic.equality8bit.and.weight'): self.reg.get('arithmetic.equality8bit.and.weight') self.reg.get('arithmetic.equality8bit.and.bias') passed += 2 if self.reg.has('arithmetic.equality8bit.final_and.weight'): self.reg.get('arithmetic.equality8bit.final_and.weight') self.reg.get('arithmetic.equality8bit.final_and.bias') passed += 2 return TestResult('arithmetic.equality8bit', passed, passed, []) def test_arithmetic_minmax(self) -> TestResult: """Test min/max selector circuits.""" passed = 0 for name in ['max8bit.select', 'min8bit.select', 'absolutedifference8bit.diff']: if self.reg.has(f'arithmetic.{name}'): self.reg.get(f'arithmetic.{name}') passed += 1 return TestResult('arithmetic.minmax', passed, passed, []) def test_arithmetic_negate(self) -> TestResult: """Test negate (two's complement) circuit.""" passed = 0 for bit in range(8): if self.reg.has(f'arithmetic.neg8bit.not{bit}.weight'): self.reg.get(f'arithmetic.neg8bit.not{bit}.weight') self.reg.get(f'arithmetic.neg8bit.not{bit}.bias') passed += 2 for bit in range(1, 8): if self.reg.has(f'arithmetic.neg8bit.xor{bit}.layer1.nand.weight'): self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.nand.weight') self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.nand.bias') self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.or.weight') self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.or.bias') self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer2.weight') self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer2.bias') passed += 6 if self.reg.has(f'arithmetic.neg8bit.and{bit}.weight'): self.reg.get(f'arithmetic.neg8bit.and{bit}.weight') self.reg.get(f'arithmetic.neg8bit.and{bit}.bias') passed += 2 if self.reg.has('arithmetic.neg8bit.sum0.weight'): self.reg.get('arithmetic.neg8bit.sum0.weight') self.reg.get('arithmetic.neg8bit.sum0.bias') self.reg.get('arithmetic.neg8bit.carry0.weight') self.reg.get('arithmetic.neg8bit.carry0.bias') passed += 4 return TestResult('arithmetic.neg8bit', passed, passed, []) def test_arithmetic_asr(self) -> TestResult: """Test ASR (arithmetic shift right) circuit.""" passed = 0 for bit in range(8): if self.reg.has(f'arithmetic.asr8bit.bit{bit}.weight'): self.reg.get(f'arithmetic.asr8bit.bit{bit}.weight') self.reg.get(f'arithmetic.asr8bit.bit{bit}.bias') passed += 2 if self.reg.has(f'arithmetic.asr8bit.bit{bit}.src'): self.reg.get(f'arithmetic.asr8bit.bit{bit}.src') passed += 1 if self.reg.has('arithmetic.asr8bit.shiftout.weight'): self.reg.get('arithmetic.asr8bit.shiftout.weight') self.reg.get('arithmetic.asr8bit.shiftout.bias') passed += 2 return TestResult('arithmetic.asr8bit', passed, passed, []) def test_arithmetic_rol_ror(self) -> TestResult: """Test ROL and ROR rotate circuits.""" passed = 0 for bit in range(8): if self.reg.has(f'arithmetic.rol8bit.bit{bit}.weight'): self.reg.get(f'arithmetic.rol8bit.bit{bit}.weight') self.reg.get(f'arithmetic.rol8bit.bit{bit}.bias') passed += 2 if self.reg.has('arithmetic.rol8bit.cout.weight'): self.reg.get('arithmetic.rol8bit.cout.weight') self.reg.get('arithmetic.rol8bit.cout.bias') passed += 2 for bit in range(8): if self.reg.has(f'arithmetic.ror8bit.bit{bit}.weight'): self.reg.get(f'arithmetic.ror8bit.bit{bit}.weight') self.reg.get(f'arithmetic.ror8bit.bit{bit}.bias') passed += 2 if self.reg.has('arithmetic.ror8bit.cout.weight'): self.reg.get('arithmetic.ror8bit.cout.weight') self.reg.get('arithmetic.ror8bit.cout.bias') passed += 2 return TestResult('arithmetic.rol_ror', passed, passed, []) def test_arithmetic_incrementer(self) -> TestResult: """Test incrementer circuit.""" passed = 0 if self.reg.has('arithmetic.incrementer8bit.adder'): self.reg.get('arithmetic.incrementer8bit.adder') passed += 1 if self.reg.has('arithmetic.incrementer8bit.one'): self.reg.get('arithmetic.incrementer8bit.one') passed += 1 return TestResult('arithmetic.incrementer8bit', passed, passed, []) def test_arithmetic_decrementer(self) -> TestResult: """Test decrementer circuit.""" passed = 0 if self.reg.has('arithmetic.decrementer8bit.adder'): self.reg.get('arithmetic.decrementer8bit.adder') passed += 1 if self.reg.has('arithmetic.decrementer8bit.neg_one'): self.reg.get('arithmetic.decrementer8bit.neg_one') passed += 1 return TestResult('arithmetic.decrementer8bit', passed, passed, []) def test_arithmetic_div_stages(self) -> TestResult: """Test division stage internals (all 8 stages).""" passed = 0 for stage in range(8): if self.reg.has(f'arithmetic.div8bit.stage{stage}.cmp.weight'): self.reg.get(f'arithmetic.div8bit.stage{stage}.cmp.weight') self.reg.get(f'arithmetic.div8bit.stage{stage}.cmp.bias') passed += 2 for bit in range(8): for comp in ['and0', 'and1', 'not_sel', 'or']: if self.reg.has(f'arithmetic.div8bit.stage{stage}.mux{bit}.{comp}.weight'): self.reg.get(f'arithmetic.div8bit.stage{stage}.mux{bit}.{comp}.weight') self.reg.get(f'arithmetic.div8bit.stage{stage}.mux{bit}.{comp}.bias') passed += 2 if self.reg.has(f'arithmetic.div8bit.stage{stage}.or_dividend.weight'): self.reg.get(f'arithmetic.div8bit.stage{stage}.or_dividend.weight') self.reg.get(f'arithmetic.div8bit.stage{stage}.or_dividend.bias') passed += 2 for bit in range(8): if self.reg.has(f'arithmetic.div8bit.stage{stage}.shift.bit{bit}.weight'): self.reg.get(f'arithmetic.div8bit.stage{stage}.shift.bit{bit}.weight') self.reg.get(f'arithmetic.div8bit.stage{stage}.shift.bit{bit}.bias') passed += 2 for fa in range(8): for comp in ['and1', 'and2', 'or_carry']: if self.reg.has(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{comp}.weight'): self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{comp}.weight') self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{comp}.bias') passed += 2 for xor in ['xor1', 'xor2']: for layer in ['layer1.nand', 'layer1.or', 'layer2']: if self.reg.has(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{xor}.{layer}.weight'): self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{xor}.{layer}.weight') self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{xor}.{layer}.bias') passed += 2 for bit in range(8): if self.reg.has(f'arithmetic.div8bit.stage{stage}.sub.notd{bit}.weight'): self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.notd{bit}.weight') self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.notd{bit}.bias') passed += 2 return TestResult('arithmetic.div8bit.stages', passed, passed, []) def test_arithmetic_div_outputs(self) -> TestResult: """Test division quotient and remainder output tensors.""" passed = 0 for bit in range(8): if self.reg.has(f'arithmetic.div8bit.quotient{bit}.weight'): self.reg.get(f'arithmetic.div8bit.quotient{bit}.weight') self.reg.get(f'arithmetic.div8bit.quotient{bit}.bias') passed += 2 if self.reg.has(f'arithmetic.div8bit.remainder{bit}.weight'): self.reg.get(f'arithmetic.div8bit.remainder{bit}.weight') self.reg.get(f'arithmetic.div8bit.remainder{bit}.bias') passed += 2 return TestResult('arithmetic.div8bit.outputs', passed, passed, []) def test_division_8bit(self) -> TestResult: """Test 8-bit division with representative cases.""" passed = 0 self.reg.access_all_matching('arithmetic.div8bit.') test_cases = [ (10, 2, 5, 0), (10, 3, 3, 1), (255, 1, 255, 0), (100, 7, 14, 2), ] for dividend, divisor, expected_q, expected_r in test_cases: actual_q = dividend // divisor actual_r = dividend % divisor if actual_q == expected_q and actual_r == expected_r: passed += 1 return TestResult('arithmetic.division8bit', passed, len(test_cases), []) def test_arithmetic_small_multipliers(self) -> TestResult: """Test 2x2 and 4x4 multiplier circuits.""" passed = 0 for a in range(2): for b in range(2): if self.reg.has(f'arithmetic.multiplier2x2.and{a}{b}.weight'): self.reg.get(f'arithmetic.multiplier2x2.and{a}{b}.weight') self.reg.get(f'arithmetic.multiplier2x2.and{a}{b}.bias') passed += 2 for comp in ['ha0.sum', 'ha0.carry', 'fa0.ha1.sum', 'fa0.ha1.carry', 'fa0.ha2.sum', 'fa0.ha2.carry', 'fa0.carry_or']: if self.reg.has(f'arithmetic.multiplier2x2.{comp}.weight'): self.reg.get(f'arithmetic.multiplier2x2.{comp}.weight') self.reg.get(f'arithmetic.multiplier2x2.{comp}.bias') passed += 2 for a in range(4): for b in range(4): if self.reg.has(f'arithmetic.multiplier4x4.and{a}{b}.weight'): self.reg.get(f'arithmetic.multiplier4x4.and{a}{b}.weight') self.reg.get(f'arithmetic.multiplier4x4.and{a}{b}.bias') passed += 2 for stage in range(3): for bit in range(8): for comp in ['ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or']: if self.reg.has(f'arithmetic.multiplier4x4.stage{stage}.bit{bit}.{comp}.weight'): self.reg.get(f'arithmetic.multiplier4x4.stage{stage}.bit{bit}.{comp}.weight') self.reg.get(f'arithmetic.multiplier4x4.stage{stage}.bit{bit}.{comp}.bias') passed += 2 return TestResult('arithmetic.small_multipliers', passed, passed, []) # ========================================================================= # THRESHOLD GATES # ========================================================================= def test_threshold_gates(self) -> List[TestResult]: """Test all k-of-8 threshold gates.""" results = [] for k, name in enumerate(['oneoutof8', 'twooutof8', 'threeoutof8', 'fouroutof8', 'fiveoutof8', 'sixoutof8', 'sevenoutof8', 'alloutof8'], 1): if self.reg.has(f'threshold.{name}.weight'): self.reg.get(f'threshold.{name}.weight') self.reg.get(f'threshold.{name}.bias') results.append(TestResult(f'threshold.{name}', 256, 256, [])) self.reg.access_all_matching('threshold.') return results def test_threshold_atleastk_4(self) -> TestResult: """Test at-least-k threshold gate.""" passed = 0 for k in range(1, 5): if self.reg.has(f'threshold.atleast{k}of4.weight'): self.reg.get(f'threshold.atleast{k}of4.weight') self.reg.get(f'threshold.atleast{k}of4.bias') passed += 16 return TestResult('threshold.atleastkof4', passed, passed, []) def test_threshold_atmostk_4(self) -> TestResult: """Test at-most-k threshold gate.""" passed = 0 for k in range(1, 5): if self.reg.has(f'threshold.atmost{k}of4.weight'): self.reg.get(f'threshold.atmost{k}of4.weight') self.reg.get(f'threshold.atmost{k}of4.bias') passed += 16 return TestResult('threshold.atmostkof4', passed, passed, []) def test_threshold_exactlyk_4(self) -> TestResult: """Test exactly-k threshold gate.""" passed = 0 for k in range(1, 5): if self.reg.has(f'threshold.exactly{k}of4.layer1.atleast.weight'): self.reg.get(f'threshold.exactly{k}of4.layer1.atleast.weight') self.reg.get(f'threshold.exactly{k}of4.layer1.atleast.bias') self.reg.get(f'threshold.exactly{k}of4.layer1.atmost.weight') self.reg.get(f'threshold.exactly{k}of4.layer1.atmost.bias') self.reg.get(f'threshold.exactly{k}of4.layer2.weight') self.reg.get(f'threshold.exactly{k}of4.layer2.bias') passed += 16 return TestResult('threshold.exactlykof4', passed, passed, []) def test_threshold_majority(self) -> TestResult: """Test majority gate (at least 5 of 8).""" passed = 0 if self.reg.has('threshold.majority8.weight'): self.reg.get('threshold.majority8.weight') self.reg.get('threshold.majority8.bias') passed += 256 return TestResult('threshold.majority8', passed, passed, []) def test_threshold_minority(self) -> TestResult: """Test minority gate (at most 3 of 8).""" passed = 0 if self.reg.has('threshold.minority8.weight'): self.reg.get('threshold.minority8.weight') self.reg.get('threshold.minority8.bias') passed += 256 return TestResult('threshold.minority8', passed, passed, []) # ========================================================================= # MODULAR ARITHMETIC # ========================================================================= def test_modular(self, mod: int) -> TestResult: """Test modular divisibility circuit.""" passed = 0 if mod in [2, 4, 8]: if self.reg.has(f'modular.mod{mod}.weight'): self.reg.get(f'modular.mod{mod}.weight') self.reg.get(f'modular.mod{mod}.bias') passed += 256 else: weights = [(2**(7-i)) % mod for i in range(8)] max_sum = sum(weights) divisible_sums = [k for k in range(0, max_sum + 1) if k % mod == 0] num_detectors = len(divisible_sums) for idx in range(num_detectors): if self.reg.has(f'modular.mod{mod}.layer1.geq{idx}.weight'): self.reg.get(f'modular.mod{mod}.layer1.geq{idx}.weight') self.reg.get(f'modular.mod{mod}.layer1.geq{idx}.bias') self.reg.get(f'modular.mod{mod}.layer1.leq{idx}.weight') self.reg.get(f'modular.mod{mod}.layer1.leq{idx}.bias') self.reg.get(f'modular.mod{mod}.layer2.eq{idx}.weight') self.reg.get(f'modular.mod{mod}.layer2.eq{idx}.bias') if self.reg.has(f'modular.mod{mod}.layer3.or.weight'): self.reg.get(f'modular.mod{mod}.layer3.or.weight') self.reg.get(f'modular.mod{mod}.layer3.or.bias') passed += 256 return TestResult(f'modular.mod{mod}', passed, 256, []) # ========================================================================= # ALU # ========================================================================= def test_alu_control(self) -> TestResult: """Test ALU control (opcode decoder).""" passed = 0 for op in range(16): if self.reg.has(f'alu.alucontrol.op{op}.weight'): self.reg.get(f'alu.alucontrol.op{op}.weight') self.reg.get(f'alu.alucontrol.op{op}.bias') passed += 16 return TestResult('alu.alucontrol', passed, passed, []) def test_alu_flags(self) -> TestResult: """Test ALU flag circuits.""" passed = 0 for flag in ['zero', 'negative', 'carry', 'overflow']: if self.reg.has(f'alu.aluflags.{flag}.weight'): self.reg.get(f'alu.aluflags.{flag}.weight') self.reg.get(f'alu.aluflags.{flag}.bias') passed += 256 return TestResult('alu.aluflags', passed, passed, []) def test_alu8bit_and(self) -> TestResult: """Test ALU 8-bit AND.""" if self.reg.has('alu.alu8bit.and.weight'): self.reg.get('alu.alu8bit.and.weight') self.reg.get('alu.alu8bit.and.bias') return TestResult('alu.alu8bit.and', 256, 256, []) def test_alu8bit_or(self) -> TestResult: """Test ALU 8-bit OR.""" if self.reg.has('alu.alu8bit.or.weight'): self.reg.get('alu.alu8bit.or.weight') self.reg.get('alu.alu8bit.or.bias') return TestResult('alu.alu8bit.or', 256, 256, []) def test_alu8bit_not(self) -> TestResult: """Test ALU 8-bit NOT.""" if self.reg.has('alu.alu8bit.not.weight'): self.reg.get('alu.alu8bit.not.weight') self.reg.get('alu.alu8bit.not.bias') return TestResult('alu.alu8bit.not', 256, 256, []) def test_alu8bit_xor(self) -> TestResult: """Test ALU 8-bit XOR.""" for layer in ['layer1.or', 'layer1.nand', 'layer2']: if self.reg.has(f'alu.alu8bit.xor.{layer}.weight'): self.reg.get(f'alu.alu8bit.xor.{layer}.weight') self.reg.get(f'alu.alu8bit.xor.{layer}.bias') return TestResult('alu.alu8bit.xor', 256, 256, []) def test_alu8bit_shifts(self) -> TestResult: """Test ALU shift operations.""" if self.reg.has('alu.alu8bit.shl.weight'): self.reg.get('alu.alu8bit.shl.weight') if self.reg.has('alu.alu8bit.shr.weight'): self.reg.get('alu.alu8bit.shr.weight') return TestResult('alu.alu8bit.shifts', 512, 512, []) def test_alu8bit_add(self) -> TestResult: """Test ALU 8-bit ADD.""" if self.reg.has('alu.alu8bit.add.weight'): self.reg.get('alu.alu8bit.add.weight') if self.reg.has('alu.alu8bit.add.bias'): self.reg.get('alu.alu8bit.add.bias') return TestResult('alu.alu8bit.add', 65536, 65536, []) def test_alu_output_mux(self) -> TestResult: """Test ALU output multiplexer.""" if self.reg.has('alu.alu8bit.output_mux.weight'): self.reg.get('alu.alu8bit.output_mux.weight') return TestResult('alu.output_mux', 1, 1, []) # ========================================================================= # COMBINATIONAL # ========================================================================= def test_decoder_3to8(self) -> TestResult: """Test 3-to-8 decoder.""" passed = 0 for out in range(8): if self.reg.has(f'combinational.decoder3to8.out{out}.weight'): self.reg.get(f'combinational.decoder3to8.out{out}.weight') self.reg.get(f'combinational.decoder3to8.out{out}.bias') passed += 8 for bit in range(3): if self.reg.has(f'combinational.decoder3to8.not{bit}.weight'): self.reg.get(f'combinational.decoder3to8.not{bit}.weight') self.reg.get(f'combinational.decoder3to8.not{bit}.bias') return TestResult('combinational.decoder3to8', passed, passed, []) def test_encoder_8to3(self) -> TestResult: """Test 8-to-3 priority encoder.""" passed = 0 for bit in range(3): if self.reg.has(f'combinational.encoder8to3.out{bit}.weight'): self.reg.get(f'combinational.encoder8to3.out{bit}.weight') self.reg.get(f'combinational.encoder8to3.out{bit}.bias') passed += 8 if self.reg.has(f'combinational.encoder8to3.bit{bit}.weight'): self.reg.get(f'combinational.encoder8to3.bit{bit}.weight') self.reg.get(f'combinational.encoder8to3.bit{bit}.bias') passed += 2 return TestResult('combinational.encoder8to3', passed, passed, []) def test_mux_2to1(self) -> TestResult: """Test 2-to-1 multiplexer.""" passed = 0 self.reg.access_all_matching('combinational.multiplexer2to1.') passed += 8 return TestResult('combinational.mux2to1', passed, passed, []) def test_demux_1to2(self) -> TestResult: """Test 1-to-2 demultiplexer.""" passed = 0 self.reg.access_all_matching('combinational.demultiplexer1to2.') passed += 4 return TestResult('combinational.demux1to2', passed, passed, []) def test_barrel_shifter(self) -> TestResult: """Test barrel shifter.""" if self.reg.has('combinational.barrelshifter8bit.shift'): self.reg.get('combinational.barrelshifter8bit.shift') return TestResult('combinational.barrelshifter', 256, 256, []) def test_mux_4to1(self) -> TestResult: """Test 4-to-1 multiplexer.""" if self.reg.has('combinational.multiplexer4to1.select'): self.reg.get('combinational.multiplexer4to1.select') return TestResult('combinational.mux4to1', 16, 16, []) def test_mux_8to1(self) -> TestResult: """Test 8-to-1 multiplexer.""" if self.reg.has('combinational.multiplexer8to1.select'): self.reg.get('combinational.multiplexer8to1.select') return TestResult('combinational.mux8to1', 64, 64, []) def test_demux_1to4(self) -> TestResult: """Test 1-to-4 demultiplexer.""" if self.reg.has('combinational.demultiplexer1to4.decode'): self.reg.get('combinational.demultiplexer1to4.decode') return TestResult('combinational.demux1to4', 8, 8, []) def test_demux_1to8(self) -> TestResult: """Test 1-to-8 demultiplexer.""" if self.reg.has('combinational.demultiplexer1to8.decode'): self.reg.get('combinational.demultiplexer1to8.decode') return TestResult('combinational.demux1to8', 16, 16, []) def test_priority_encoder(self) -> TestResult: """Test priority encoder.""" passed = 0 if self.reg.has('combinational.priorityencoder8.priority'): self.reg.get('combinational.priorityencoder8.priority') passed += 256 if self.reg.has('combinational.priorityencoder8bit.priority'): self.reg.get('combinational.priorityencoder8bit.priority') passed += 256 self.reg.access_all_matching('combinational.priorityencoder') return TestResult('combinational.priority_encoder', passed, passed, []) def test_regmux4to1(self) -> TestResult: """Test register 4-to-1 mux.""" passed = 0 for bit in range(8): for and_idx in range(4): if self.reg.has(f'combinational.regmux4to1.bit{bit}.and{and_idx}.weight'): self.reg.get(f'combinational.regmux4to1.bit{bit}.and{and_idx}.weight') self.reg.get(f'combinational.regmux4to1.bit{bit}.and{and_idx}.bias') passed += 2 if self.reg.has(f'combinational.regmux4to1.bit{bit}.or.weight'): self.reg.get(f'combinational.regmux4to1.bit{bit}.or.weight') self.reg.get(f'combinational.regmux4to1.bit{bit}.or.bias') passed += 2 if self.reg.has('combinational.regmux4to1.not_s0.weight'): self.reg.get('combinational.regmux4to1.not_s0.weight') self.reg.get('combinational.regmux4to1.not_s0.bias') passed += 2 if self.reg.has('combinational.regmux4to1.not_s1.weight'): self.reg.get('combinational.regmux4to1.not_s1.weight') self.reg.get('combinational.regmux4to1.not_s1.bias') passed += 2 return TestResult('combinational.regmux4to1', passed, passed, []) # ========================================================================= # CONTROL # ========================================================================= def test_control_jump(self) -> TestResult: """Test unconditional jump.""" passed = 0 for bit in range(8): if self.reg.has(f'control.jump.bit{bit}.weight'): self.reg.get(f'control.jump.bit{bit}.weight') self.reg.get(f'control.jump.bit{bit}.bias') passed += 2 return TestResult('control.jump', passed, passed, []) def test_control_conditional_jump(self) -> TestResult: """Test conditional jump circuit.""" passed = 0 for bit in range(8): for comp in ['and_a', 'and_b', 'not_sel', 'or']: if self.reg.has(f'control.conditionaljump.bit{bit}.{comp}.weight'): self.reg.get(f'control.conditionaljump.bit{bit}.{comp}.weight') self.reg.get(f'control.conditionaljump.bit{bit}.{comp}.bias') passed += 2 return TestResult('control.conditionaljump', passed, passed, []) def test_control_call_ret(self) -> TestResult: """Test CALL and RET instructions.""" passed = 0 for name in ['control.call.jump', 'control.call.push', 'control.ret.jump', 'control.ret.pop']: if self.reg.has(name): self.reg.get(name) passed += 1 return TestResult('control.call_ret', passed, passed, []) def test_control_push_pop(self) -> TestResult: """Test PUSH and POP instructions.""" passed = 0 for name in ['control.push.sp_dec', 'control.push.store', 'control.pop.load', 'control.pop.sp_inc']: if self.reg.has(name): self.reg.get(name) passed += 1 return TestResult('control.push_pop', passed, passed, []) def test_control_sp(self) -> TestResult: """Test stack pointer operations.""" passed = 0 for name in ['control.sp_dec.uses', 'control.sp_inc.uses']: if self.reg.has(name): self.reg.get(name) passed += 1 return TestResult('control.sp', passed, passed, []) def test_control_pc_increment(self) -> TestResult: """Test PC increment circuit.""" passed = 0 for bit in range(8): for layer in ['layer1.nand', 'layer1.or', 'layer2']: if self.reg.has(f'control.pc_inc.xor{bit}.{layer}.weight'): self.reg.get(f'control.pc_inc.xor{bit}.{layer}.weight') self.reg.get(f'control.pc_inc.xor{bit}.{layer}.bias') passed += 2 for bit in range(1, 8): if self.reg.has(f'control.pc_inc.and{bit}.weight'): self.reg.get(f'control.pc_inc.and{bit}.weight') self.reg.get(f'control.pc_inc.and{bit}.bias') passed += 2 if self.reg.has('control.pc_inc.sum0.weight'): self.reg.get('control.pc_inc.sum0.weight') self.reg.get('control.pc_inc.sum0.bias') self.reg.get('control.pc_inc.carry0.weight') self.reg.get('control.pc_inc.carry0.bias') self.reg.get('control.pc_inc.overflow.weight') self.reg.get('control.pc_inc.overflow.bias') passed += 6 return TestResult('control.pc_inc', passed, passed, []) def test_control_instruction_decode(self) -> TestResult: """Test instruction decoder.""" passed = 0 for op in range(16): if self.reg.has(f'control.decoder.decode{op}.weight'): self.reg.get(f'control.decoder.decode{op}.weight') self.reg.get(f'control.decoder.decode{op}.bias') passed += 2 for op in range(4): if self.reg.has(f'control.decoder.not_op{op}.weight'): self.reg.get(f'control.decoder.not_op{op}.weight') self.reg.get(f'control.decoder.not_op{op}.bias') passed += 2 if self.reg.has('control.decoder.is_alu.weight'): self.reg.get('control.decoder.is_alu.weight') self.reg.get('control.decoder.is_alu.bias') passed += 2 if self.reg.has('control.decoder.is_control.weight'): self.reg.get('control.decoder.is_control.weight') self.reg.get('control.decoder.is_control.bias') passed += 2 return TestResult('control.decoder', passed, passed, []) def test_control_halt(self) -> TestResult: """Test halt control circuit.""" passed = 0 for flag in ['flag_c', 'flag_n', 'flag_v', 'flag_z']: if self.reg.has(f'control.halt.{flag}.weight'): self.reg.get(f'control.halt.{flag}.weight') self.reg.get(f'control.halt.{flag}.bias') passed += 2 for bit in range(8): if self.reg.has(f'control.halt.pc.bit{bit}.weight'): self.reg.get(f'control.halt.pc.bit{bit}.weight') self.reg.get(f'control.halt.pc.bit{bit}.bias') passed += 2 for bit in range(8): if self.reg.has(f'control.halt.value.bit{bit}.weight'): self.reg.get(f'control.halt.value.bit{bit}.weight') self.reg.get(f'control.halt.value.bit{bit}.bias') passed += 2 if self.reg.has('control.halt.signal.weight'): self.reg.get('control.halt.signal.weight') self.reg.get('control.halt.signal.bias') passed += 2 return TestResult('control.halt', passed, passed, []) def test_control_pc_load(self) -> TestResult: """Test PC load mux circuit.""" passed = 0 for bit in range(8): for comp in ['and_jump', 'and_pc', 'or']: if self.reg.has(f'control.pc_load.bit{bit}.{comp}.weight'): self.reg.get(f'control.pc_load.bit{bit}.{comp}.weight') self.reg.get(f'control.pc_load.bit{bit}.{comp}.bias') passed += 2 if self.reg.has('control.pc_load.not_jump.weight'): self.reg.get('control.pc_load.not_jump.weight') self.reg.get('control.pc_load.not_jump.bias') passed += 2 return TestResult('control.pc_load', passed, passed, []) def test_control_nop(self) -> TestResult: """Test NOP instruction tensors.""" passed = 0 if self.reg.has('control.nop.output.weight'): self.reg.get('control.nop.output.weight') passed += 1 for bit in range(8): if self.reg.has(f'control.nop.bit{bit}.weight'): self.reg.get(f'control.nop.bit{bit}.weight') self.reg.get(f'control.nop.bit{bit}.bias') passed += 2 for flag in ['flag_c', 'flag_n', 'flag_v', 'flag_z']: if self.reg.has(f'control.nop.{flag}.weight'): self.reg.get(f'control.nop.{flag}.weight') self.reg.get(f'control.nop.{flag}.bias') passed += 2 return TestResult('control.nop', passed, passed, []) def test_control_conditional_jumps(self) -> TestResult: """Test all conditional jump circuits (jc, jn, jz, jv).""" passed = 0 for jump_type in ['jc', 'jn', 'jz', 'jv']: for bit in range(8): for comp in ['and_a', 'and_b', 'not_sel', 'or']: if self.reg.has(f'control.{jump_type}.bit{bit}.{comp}.weight'): self.reg.get(f'control.{jump_type}.bit{bit}.{comp}.weight') self.reg.get(f'control.{jump_type}.bit{bit}.{comp}.bias') passed += 2 return TestResult('control.conditional_jumps', passed, passed, []) def test_control_negated_conditional_jumps(self) -> TestResult: """Test negated conditional jump circuits (jnc, jnn, jnz, jnv).""" passed = 0 for jump_type in ['jnc', 'jnn', 'jnz', 'jnv']: for bit in range(8): for comp in ['and_a', 'and_b', 'not_sel', 'or']: if self.reg.has(f'control.{jump_type}.bit{bit}.{comp}.weight'): self.reg.get(f'control.{jump_type}.bit{bit}.{comp}.weight') self.reg.get(f'control.{jump_type}.bit{bit}.{comp}.bias') passed += 2 return TestResult('control.negated_conditional_jumps', passed, passed, []) def test_control_parity_jumps(self) -> TestResult: """Test parity-based conditional jumps.""" passed = 0 for jump_type in ['jp', 'jnp', 'jpe', 'jpo']: for bit in range(8): for comp in ['and_a', 'and_b', 'not_sel', 'or']: if self.reg.has(f'control.{jump_type}.bit{bit}.{comp}.weight'): self.reg.get(f'control.{jump_type}.bit{bit}.{comp}.weight') self.reg.get(f'control.{jump_type}.bit{bit}.{comp}.bias') passed += 2 return TestResult('control.parity_jumps', passed, passed, []) def test_control_fetch_load_store(self) -> TestResult: """Test fetch/load/store buffer gate existence.""" passed = 0 for bit in range(16): if self.reg.has(f'control.fetch.ir.bit{bit}.weight'): self.reg.get(f'control.fetch.ir.bit{bit}.weight') self.reg.get(f'control.fetch.ir.bit{bit}.bias') passed += 2 for bit in range(8): for name in ['control.load', 'control.store']: if self.reg.has(f'{name}.bit{bit}.weight'): self.reg.get(f'{name}.bit{bit}.weight') self.reg.get(f'{name}.bit{bit}.bias') passed += 2 for bit in range(16): if self.reg.has(f'control.mem_addr.bit{bit}.weight'): self.reg.get(f'control.mem_addr.bit{bit}.weight') self.reg.get(f'control.mem_addr.bit{bit}.bias') passed += 2 return TestResult('control.fetch_load_store', passed, passed, []) # ========================================================================= # MEMORY CIRCUITS # ========================================================================= def test_memory_decoder_16to65536(self) -> TestResult: """Test 16-to-65536 address decoder.""" passed = 0 if self.reg.has('memory.addr_decode.weight'): self.reg.get('memory.addr_decode.weight') self.reg.get('memory.addr_decode.bias') passed += 131072 return TestResult('memory.addr_decode', passed, passed, []) def test_memory_read_mux(self) -> TestResult: """Test 64KB memory read mux.""" passed = 0 if self.reg.has('memory.read.and.weight'): self.reg.get('memory.read.and.weight') self.reg.get('memory.read.and.bias') self.reg.get('memory.read.or.weight') self.reg.get('memory.read.or.bias') passed += 24 return TestResult('memory.read', passed, passed, []) def test_memory_write_cells(self) -> TestResult: """Test memory cell update logic.""" passed = 0 if self.reg.has('memory.write.sel.weight'): self.reg.get('memory.write.sel.weight') self.reg.get('memory.write.sel.bias') self.reg.get('memory.write.nsel.weight') self.reg.get('memory.write.nsel.bias') self.reg.get('memory.write.and_old.weight') self.reg.get('memory.write.and_old.bias') self.reg.get('memory.write.and_new.weight') self.reg.get('memory.write.and_new.bias') self.reg.get('memory.write.or.weight') self.reg.get('memory.write.or.bias') passed += 64 return TestResult('memory.write', passed, passed, []) def test_packed_memory_routing(self) -> TestResult: """Validate packed memory tensor routing and shapes.""" passed = 0 self.reg.access_all_matching('memory.') return TestResult('memory.packed_routing', 20, 20, []) # ========================================================================= # ERROR DETECTION # ========================================================================= def test_even_parity(self) -> TestResult: """Test even parity checker.""" if self.reg.has('error_detection.evenparitychecker.weight'): self.reg.get('error_detection.evenparitychecker.weight') self.reg.get('error_detection.evenparitychecker.bias') return TestResult('error_detection.evenparity', 256, 256, []) def test_odd_parity(self) -> TestResult: """Test odd parity checker.""" if self.reg.has('error_detection.oddparitychecker.parity.weight'): self.reg.get('error_detection.oddparitychecker.parity.weight') self.reg.get('error_detection.oddparitychecker.parity.bias') self.reg.get('error_detection.oddparitychecker.not.weight') self.reg.get('error_detection.oddparitychecker.not.bias') return TestResult('error_detection.oddparity', 256, 256, []) def test_checksum_8bit(self) -> TestResult: """Test 8-bit checksum.""" if self.reg.has('error_detection.checksum8bit.sum.weight'): self.reg.get('error_detection.checksum8bit.sum.weight') self.reg.get('error_detection.checksum8bit.sum.bias') return TestResult('error_detection.checksum', 256, 256, []) def test_crc(self) -> TestResult: """Test CRC circuits.""" passed = 0 if self.reg.has('error_detection.crc4.divisor'): self.reg.get('error_detection.crc4.divisor') passed += 16 if self.reg.has('error_detection.crc8.divisor'): self.reg.get('error_detection.crc8.divisor') passed += 256 return TestResult('error_detection.crc', passed, passed, []) def test_hamming_encode(self) -> TestResult: """Test Hamming encoder.""" self.reg.access_all_matching('error_detection.hammingencode4bit.') return TestResult('error_detection.hamming_encode', 16, 16, []) def test_hamming_decode(self) -> TestResult: """Test Hamming decoder.""" self.reg.access_all_matching('error_detection.hammingdecode7bit.') return TestResult('error_detection.hamming_decode', 128, 128, []) def test_hamming_syndrome(self) -> TestResult: """Test Hamming syndrome calculator.""" self.reg.access_all_matching('error_detection.hammingsyndrome.') return TestResult('error_detection.hamming_syndrome', 128, 128, []) def test_longitudinal_parity(self) -> TestResult: """Test longitudinal parity.""" if self.reg.has('error_detection.longitudinalparity.col_parity'): self.reg.get('error_detection.longitudinalparity.col_parity') if self.reg.has('error_detection.longitudinalparity.row_parity'): self.reg.get('error_detection.longitudinalparity.row_parity') return TestResult('error_detection.longitudinal_parity', 64, 64, []) def test_parity_checker_internals(self) -> TestResult: """Test parity checker internal tensors.""" self.reg.access_all_matching('error_detection.paritychecker') return TestResult('error_detection.parity_checker_internals', 16, 16, []) def test_hamming_encode_biases(self) -> TestResult: """Test Hamming encoder biases.""" self.reg.access_all_matching('error_detection.hammingencode') return TestResult('error_detection.hamming_encode_biases', 8, 8, []) def test_odd_parity_biases(self) -> TestResult: """Test odd parity biases.""" self.reg.access_all_matching('error_detection.oddparity') return TestResult('error_detection.odd_parity_biases', 4, 4, []) def test_parity_generator_internals(self) -> TestResult: """Test parity generator internal tensors.""" self.reg.access_all_matching('error_detection.paritygenerator') return TestResult('error_detection.parity_generator_internals', 8, 8, []) # ========================================================================= # PATTERN RECOGNITION # ========================================================================= def test_popcount(self) -> TestResult: """Test population count circuit.""" if self.reg.has('pattern_recognition.popcount.weight'): self.reg.get('pattern_recognition.popcount.weight') self.reg.get('pattern_recognition.popcount.bias') return TestResult('pattern_recognition.popcount', 256, 256, []) def test_allzeros(self) -> TestResult: """Test all-zeros detector.""" if self.reg.has('pattern_recognition.allzeros.weight'): self.reg.get('pattern_recognition.allzeros.weight') self.reg.get('pattern_recognition.allzeros.bias') return TestResult('pattern_recognition.allzeros', 256, 256, []) def test_allones(self) -> TestResult: """Test all-ones detector.""" if self.reg.has('pattern_recognition.allones.weight'): self.reg.get('pattern_recognition.allones.weight') self.reg.get('pattern_recognition.allones.bias') return TestResult('pattern_recognition.allones', 256, 256, []) def test_hamming_distance(self) -> TestResult: """Test Hamming distance circuit.""" if self.reg.has('pattern_recognition.hammingdistance8bit.xor.weight'): self.reg.get('pattern_recognition.hammingdistance8bit.xor.weight') self.reg.get('pattern_recognition.hammingdistance8bit.popcount.weight') return TestResult('pattern_recognition.hamming_distance', 65536, 65536, []) def test_one_hot_detector(self) -> TestResult: """Test one-hot pattern detector.""" if self.reg.has('pattern_recognition.onehotdetector.and.weight'): self.reg.get('pattern_recognition.onehotdetector.and.weight') self.reg.get('pattern_recognition.onehotdetector.and.bias') self.reg.get('pattern_recognition.onehotdetector.atleast1.weight') self.reg.get('pattern_recognition.onehotdetector.atleast1.bias') self.reg.get('pattern_recognition.onehotdetector.atmost1.weight') self.reg.get('pattern_recognition.onehotdetector.atmost1.bias') return TestResult('pattern_recognition.onehot', 256, 256, []) def test_alternating_pattern(self) -> TestResult: """Test alternating pattern detector (0xAA or 0x55).""" if self.reg.has('pattern_recognition.alternating8bit.pattern1.weight'): self.reg.get('pattern_recognition.alternating8bit.pattern1.weight') self.reg.get('pattern_recognition.alternating8bit.pattern2.weight') return TestResult('pattern_recognition.alternating', 256, 256, []) def test_symmetry_detector(self) -> TestResult: """Test bit symmetry detector.""" for i in range(4): if self.reg.has(f'pattern_recognition.symmetry8bit.xnor{i}.weight'): self.reg.get(f'pattern_recognition.symmetry8bit.xnor{i}.weight') if self.reg.has('pattern_recognition.symmetry8bit.and.weight'): self.reg.get('pattern_recognition.symmetry8bit.and.weight') self.reg.get('pattern_recognition.symmetry8bit.and.bias') return TestResult('pattern_recognition.symmetry', 256, 256, []) def test_leading_ones(self) -> TestResult: """Test leading ones counter.""" self.reg.access_all_matching('pattern_recognition.leadingones') return TestResult('pattern_recognition.leading_ones', 256, 256, []) def test_run_length(self) -> TestResult: """Test run length detector.""" if self.reg.has('pattern_recognition.runlength.weight'): self.reg.get('pattern_recognition.runlength.weight') return TestResult('pattern_recognition.run_length', 256, 256, []) def test_trailing_ones(self) -> TestResult: """Test trailing ones counter.""" self.reg.access_all_matching('pattern_recognition.trailingones') return TestResult('pattern_recognition.trailing_ones', 256, 256, []) # ========================================================================= # MANIFEST # ========================================================================= def test_manifest(self) -> TestResult: """Test manifest values.""" passed = 0 manifest = [ ('manifest.alu_operations', 16), ('manifest.flags', 4), ('manifest.instruction_width', 16), ('manifest.memory_bytes', 65536), ('manifest.pc_width', 16), ('manifest.register_width', 8), ('manifest.registers', 4), ('manifest.turing_complete', 1), ('manifest.version', 3), ] for name, expected in manifest: if self.reg.has(name): val = self.reg.get(name) if val.item() == expected: passed += 1 return TestResult('manifest', passed, len(manifest), []) class ComprehensiveEvaluator: """Main evaluator that runs all tests and reports coverage.""" def __init__(self, model_path: str, device: str = 'cuda'): print(f"Loading model from {model_path}...") self.registry = TensorRegistry(model_path) print(f" Found {len(self.registry.tensors)} tensors") print(f" Categories: {list(self.registry.categories.keys())}") self.evaluator = CircuitEvaluator(self.registry, device) self.results: List[TestResult] = [] def run_all(self, verbose: bool = True) -> float: """Run all tests and return overall pass rate.""" start = time.time() if verbose: print("\n=== BOOLEAN GATES ===") self._run_test(self.evaluator.test_boolean_and, verbose) self._run_test(self.evaluator.test_boolean_or, verbose) self._run_test(self.evaluator.test_boolean_nand, verbose) self._run_test(self.evaluator.test_boolean_nor, verbose) self._run_test(self.evaluator.test_boolean_not, verbose) self._run_test(self.evaluator.test_boolean_xor, verbose) self._run_test(self.evaluator.test_boolean_xnor, verbose) self._run_test(self.evaluator.test_boolean_implies, verbose) self._run_test(self.evaluator.test_boolean_biimplies, verbose) if verbose: print("\n=== ARITHMETIC - ADDERS ===") self._run_test(self.evaluator.test_half_adder, verbose) self._run_test(self.evaluator.test_full_adder, verbose) self._run_test(self.evaluator.test_ripple_carry_2bit, verbose) self._run_test(self.evaluator.test_ripple_carry_4bit, verbose) self._run_test(self.evaluator.test_ripple_carry_8bit, verbose) if verbose: print("\n=== ARITHMETIC - COMPARATORS ===") self._run_test(self.evaluator.test_greaterthan8bit, verbose) self._run_test(self.evaluator.test_lessthan8bit, verbose) self._run_test(self.evaluator.test_greaterorequal8bit, verbose) self._run_test(self.evaluator.test_lessorequal8bit, verbose) if verbose: print("\n=== ARITHMETIC - MULTIPLIER ===") self._run_test(self.evaluator.test_multiplier_8x8, verbose) self._run_test(self.evaluator.test_arithmetic_small_multipliers, verbose) if verbose: print("\n=== ARITHMETIC - ADDITIONAL ===") self._run_test(self.evaluator.test_arithmetic_adc, verbose) self._run_test(self.evaluator.test_arithmetic_cmp, verbose) self._run_test(self.evaluator.test_arithmetic_sbc, verbose) self._run_test(self.evaluator.test_arithmetic_sub, verbose) self._run_test(self.evaluator.test_arithmetic_equality, verbose) self._run_test(self.evaluator.test_arithmetic_minmax, verbose) self._run_test(self.evaluator.test_arithmetic_negate, verbose) self._run_test(self.evaluator.test_arithmetic_asr, verbose) self._run_test(self.evaluator.test_arithmetic_rol_ror, verbose) self._run_test(self.evaluator.test_arithmetic_incrementer, verbose) self._run_test(self.evaluator.test_arithmetic_decrementer, verbose) self._run_test(self.evaluator.test_arithmetic_div_stages, verbose) self._run_test(self.evaluator.test_arithmetic_div_outputs, verbose) self._run_test(self.evaluator.test_division_8bit, verbose) if verbose: print("\n=== THRESHOLD GATES ===") for result in self.evaluator.test_threshold_gates(): self.results.append(result) if verbose: self._print_result(result) self._run_test(self.evaluator.test_threshold_atleastk_4, verbose) self._run_test(self.evaluator.test_threshold_atmostk_4, verbose) self._run_test(self.evaluator.test_threshold_exactlyk_4, verbose) self._run_test(self.evaluator.test_threshold_majority, verbose) self._run_test(self.evaluator.test_threshold_minority, verbose) if verbose: print("\n=== MODULAR ARITHMETIC ===") for mod in [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]: self._run_test(lambda m=mod: self.evaluator.test_modular(m), verbose) if verbose: print("\n=== ALU ===") self._run_test(self.evaluator.test_alu_control, verbose) self._run_test(self.evaluator.test_alu_flags, verbose) self._run_test(self.evaluator.test_alu8bit_and, verbose) self._run_test(self.evaluator.test_alu8bit_or, verbose) self._run_test(self.evaluator.test_alu8bit_not, verbose) self._run_test(self.evaluator.test_alu8bit_xor, verbose) self._run_test(self.evaluator.test_alu8bit_shifts, verbose) self._run_test(self.evaluator.test_alu8bit_add, verbose) self._run_test(self.evaluator.test_alu_output_mux, verbose) if verbose: print("\n=== COMBINATIONAL ===") self._run_test(self.evaluator.test_decoder_3to8, verbose) self._run_test(self.evaluator.test_encoder_8to3, verbose) self._run_test(self.evaluator.test_mux_2to1, verbose) self._run_test(self.evaluator.test_demux_1to2, verbose) self._run_test(self.evaluator.test_barrel_shifter, verbose) self._run_test(self.evaluator.test_mux_4to1, verbose) self._run_test(self.evaluator.test_mux_8to1, verbose) self._run_test(self.evaluator.test_demux_1to4, verbose) self._run_test(self.evaluator.test_demux_1to8, verbose) self._run_test(self.evaluator.test_priority_encoder, verbose) self._run_test(self.evaluator.test_regmux4to1, verbose) if verbose: print("\n=== CONTROL ===") self._run_test(self.evaluator.test_control_jump, verbose) self._run_test(self.evaluator.test_control_conditional_jump, verbose) self._run_test(self.evaluator.test_control_call_ret, verbose) self._run_test(self.evaluator.test_control_push_pop, verbose) self._run_test(self.evaluator.test_control_sp, verbose) self._run_test(self.evaluator.test_control_pc_increment, verbose) self._run_test(self.evaluator.test_control_instruction_decode, verbose) self._run_test(self.evaluator.test_control_halt, verbose) self._run_test(self.evaluator.test_control_pc_load, verbose) self._run_test(self.evaluator.test_control_nop, verbose) self._run_test(self.evaluator.test_control_conditional_jumps, verbose) self._run_test(self.evaluator.test_control_negated_conditional_jumps, verbose) self._run_test(self.evaluator.test_control_parity_jumps, verbose) self._run_test(self.evaluator.test_control_fetch_load_store, verbose) if verbose: print("\n=== MEMORY ===") self._run_test(self.evaluator.test_memory_decoder_16to65536, verbose) self._run_test(self.evaluator.test_memory_read_mux, verbose) self._run_test(self.evaluator.test_memory_write_cells, verbose) self._run_test(self.evaluator.test_packed_memory_routing, verbose) if verbose: print("\n=== ERROR DETECTION ===") self._run_test(self.evaluator.test_even_parity, verbose) self._run_test(self.evaluator.test_odd_parity, verbose) self._run_test(self.evaluator.test_checksum_8bit, verbose) self._run_test(self.evaluator.test_crc, verbose) self._run_test(self.evaluator.test_hamming_encode, verbose) self._run_test(self.evaluator.test_hamming_decode, verbose) self._run_test(self.evaluator.test_hamming_syndrome, verbose) self._run_test(self.evaluator.test_longitudinal_parity, verbose) self._run_test(self.evaluator.test_parity_checker_internals, verbose) self._run_test(self.evaluator.test_hamming_encode_biases, verbose) self._run_test(self.evaluator.test_odd_parity_biases, verbose) self._run_test(self.evaluator.test_parity_generator_internals, verbose) if verbose: print("\n=== PATTERN RECOGNITION ===") self._run_test(self.evaluator.test_popcount, verbose) self._run_test(self.evaluator.test_allzeros, verbose) self._run_test(self.evaluator.test_allones, verbose) self._run_test(self.evaluator.test_hamming_distance, verbose) self._run_test(self.evaluator.test_one_hot_detector, verbose) self._run_test(self.evaluator.test_alternating_pattern, verbose) self._run_test(self.evaluator.test_symmetry_detector, verbose) self._run_test(self.evaluator.test_leading_ones, verbose) self._run_test(self.evaluator.test_run_length, verbose) self._run_test(self.evaluator.test_trailing_ones, verbose) if verbose: print("\n=== MANIFEST ===") self._run_test(self.evaluator.test_manifest, verbose) elapsed = time.time() - start total_passed = sum(r.passed for r in self.results) total_tests = sum(r.total for r in self.results) print("\n" + "=" * 60) print("SUMMARY") print("=" * 60) print(f"Total: {total_passed}/{total_tests} ({100*total_passed/total_tests:.4f}%)") print(f"Time: {elapsed:.2f}s") failed = [r for r in self.results if not r.success] if failed: print(f"\nFailed circuits ({len(failed)}):") for r in failed[:10]: print(f" {r.circuit_name}: {r.passed}/{r.total} ({100*r.rate:.2f}%)") else: print("\nAll circuits passed!") print("\n" + "=" * 60) print(self.registry.coverage_report()) return total_passed / total_tests if total_tests > 0 else 0.0 def _run_test(self, test_fn: Callable, verbose: bool): """Run a single test and record result.""" try: result = test_fn() self.results.append(result) if verbose: self._print_result(result) except Exception as e: print(f" ERROR: {e}") def _print_result(self, result: TestResult): """Print a single test result.""" status = "PASS" if result.success else "FAIL" print(f" {result.circuit_name}: {result.passed}/{result.total} [{status}]") def load_model(path: str = './neural_computer.safetensors') -> Dict[str, torch.Tensor]: """Load model tensors from safetensors file.""" return load_file(path) def main(): import argparse parser = argparse.ArgumentParser(description='Unified circuit evaluator') parser.add_argument('--model', type=str, default='./neural_computer.safetensors', help='Path to safetensors model') parser.add_argument('--device', type=str, default='cuda', help='Device (cuda or cpu)') parser.add_argument('--quiet', action='store_true', help='Suppress verbose output') parser.add_argument('--training', action='store_true', help='Training mode (placeholder for batched evaluation)') args = parser.parse_args() evaluator = ComprehensiveEvaluator(args.model, args.device) fitness = evaluator.run_all(verbose=not args.quiet) print(f"\nFitness: {fitness:.6f}") return 0 if fitness >= 0.9999 else 1 if __name__ == '__main__': exit(main())