| """
|
| COMPREHENSIVE EVALUATOR
|
| ========================
|
| Introspection-based exhaustive testing for all circuits in the threshold computer.
|
|
|
| Design principles:
|
| 1. Discover tensors from safetensors, don't hardcode names
|
| 2. Exhaustive testing where feasible (all 256 or 65536 input combinations)
|
| 3. Per-circuit pass/fail reporting with exact failure inputs
|
| 4. Tensor shape validation against expected circuit topology
|
| 5. Clean separation between circuit evaluation and test orchestration
|
| """
|
|
|
| import torch
|
| from safetensors import safe_open
|
| from typing import Dict, List, Tuple, Optional, Callable
|
| from dataclasses import dataclass
|
| from collections import defaultdict
|
| import json
|
| import os
|
| import re
|
| import time
|
|
|
|
|
@dataclass
class TestResult:
    """Outcome of exercising one circuit: pass/fail tallies plus failure details."""
    circuit_name: str      # dotted circuit id, e.g. 'boolean.and'
    passed: int            # number of cases that matched expectations
    total: int             # number of cases attempted
    failures: List[Tuple]  # (input, expected, actual) triples for mismatches

    @property
    def success(self) -> bool:
        """True when every attempted case passed."""
        return self.total == self.passed

    @property
    def rate(self) -> float:
        """Fraction of cases passed; 0.0 when no cases were attempted."""
        if self.total <= 0:
            return 0.0
        return self.passed / self.total
|
|
|
|
|
def heaviside(x: torch.Tensor) -> torch.Tensor:
    """Step activation: 1.0 for every element with x >= 0, otherwise 0.0."""
    return x.ge(0).to(torch.float32)
|
|
|
|
|
class TensorRegistry:
    """Discovers and organizes tensors from a safetensors file.

    Besides loading, it tracks which tensors have been fetched via get()
    so that test coverage over the weight file can be reported.
    """

    def __init__(self, path: str):
        self.path = path
        self.tensors: Dict[str, torch.Tensor] = {}          # name -> tensor (float32)
        self.circuits: Dict[str, List[str]] = defaultdict(list)  # circuit id -> tensor names
        self.accessed: set = set()                          # names fetched via get()
        self._load()
        self._organize()

    def _load(self):
        """Read every tensor from the safetensors file, coerced to float32."""
        with safe_open(self.path, framework='pt') as f:
            for key in f.keys():
                self.tensors[key] = f.get_tensor(key).float()

    def _organize(self):
        """Bucket tensor names under their owning circuit id."""
        for tensor_name in self.tensors:
            self.circuits[self._extract_circuit(tensor_name)].append(tensor_name)

    def _extract_circuit(self, tensor_name: str) -> str:
        """Strip a trailing .weight/.bias suffix to obtain the circuit id."""
        for suffix in ('.weight', '.bias'):
            if tensor_name.endswith(suffix):
                return tensor_name[:-len(suffix)]
        return tensor_name

    def get(self, name: str) -> torch.Tensor:
        """Fetch a tensor by exact name, recording the access for coverage."""
        self.accessed.add(name)
        return self.tensors[name]

    def has(self, name: str) -> bool:
        """Whether a tensor with this exact name was loaded."""
        return name in self.tensors

    def get_category(self, prefix: str) -> Dict[str, List[str]]:
        """All circuits whose id starts with the given prefix."""
        return {circuit: names
                for circuit, names in self.circuits.items()
                if circuit.startswith(prefix)}

    @property
    def categories(self) -> List[str]:
        """Sorted first dotted components across all tensor names."""
        return sorted({name.split('.')[0] for name in self.tensors})

    @property
    def untested(self) -> List[str]:
        """Loaded tensor names that were never fetched via get()."""
        return sorted(set(self.tensors.keys()) - self.accessed)

    @property
    def coverage(self) -> float:
        """Fraction of loaded tensors that have been accessed (0.0 if none loaded)."""
        if not self.tensors:
            return 0.0
        return len(self.accessed) / len(self.tensors)

    def coverage_report(self) -> str:
        """Render a human-readable access-coverage summary."""
        lines = [f"TENSOR COVERAGE: {len(self.accessed)}/{len(self.tensors)} ({100*self.coverage:.2f}%)"]

        untested = self.untested
        if not untested:
            lines.append("\nAll tensors tested!")
            return '\n'.join(lines)

        # Group the untested names by their top-level category for readability.
        by_category: Dict[str, List[str]] = defaultdict(list)
        for name in untested:
            by_category[name.split('.')[0]].append(name)

        lines.append(f"\nUNTESTED TENSORS ({len(untested)}):")
        for cat in sorted(by_category.keys()):
            tensors = by_category[cat]
            lines.append(f"\n {cat}/ ({len(tensors)} tensors):")
            # Cap the listing at 10 names per category.
            for t in tensors[:10]:
                lines.append(f" - {t}")
            if len(tensors) > 10:
                lines.append(f" ... and {len(tensors) - 10} more")

        return '\n'.join(lines)
|
|
|
|
|
class RoutingEvaluator:
    """Evaluates circuits using routing information.

    Routing comes from a routing.json file whose 'circuits' entry maps a
    circuit name to a description with an 'internal' dict of
    {gate_name: [source, ...]}.  Source tokens observed below:
      - '$name[i]' : an external input bit (seeded into `values`)
      - '#k'       : the numeric constant k
      - otherwise  : the output of another gate in the same circuit
    """

    def __init__(self, registry: TensorRegistry, routing_path: str, device: str = 'cpu'):
        # Registry supplies gate weights/biases; routing supplies wiring.
        self.reg = registry
        self.device = device
        self.routing = self._load_routing(routing_path)

    def _load_routing(self, path: str) -> dict:
        """Load routing.json file.

        A missing file degrades gracefully to an empty 'circuits' mapping,
        so has_routing() simply reports False for everything.
        """
        if os.path.exists(path):
            with open(path, 'r') as f:
                return json.load(f)
        return {'circuits': {}}

    def has_routing(self, circuit: str) -> bool:
        """Check if routing exists for a circuit."""
        return circuit in self.routing.get('circuits', {})

    def eval_gate(self, gate_path: str, inputs: torch.Tensor) -> torch.Tensor:
        """Evaluate a single gate given its inputs.

        Computes heaviside(<inputs, weight> + bias) over the last axis.
        """
        w = self.reg.get(f'{gate_path}.weight')
        b = self.reg.get(f'{gate_path}.bias')
        return heaviside((inputs * w).sum(-1) + b)

    def eval_division(self, dividend: int, divisor: int) -> Tuple[int, int]:
        """Evaluate 8-bit division circuit using routing and actual tensors.

        Returns (quotient, remainder).  Without routing information this
        falls back to Python's // and % — i.e. it reports the reference
        answer, not a circuit evaluation.

        NOTE(review): divisor == 0 in the fallback path raises
        ZeroDivisionError; callers presumably avoid zero divisors — confirm.
        """
        if not self.has_routing('arithmetic.div8bit'):
            return dividend // divisor, dividend % divisor

        routing = self.routing['circuits']['arithmetic.div8bit']
        internal = routing['internal']

        # Operand bits are indexed LSB-first ($dividend[0] is bit 0).
        dividend_bits = [(dividend >> i) & 1 for i in range(8)]
        divisor_bits = [(divisor >> i) & 1 for i in range(8)]

        # `values` maps every resolvable source token / gate path to 0.0 or 1.0.
        values = {}
        values['#0'] = 0.0
        values['#1'] = 1.0
        for i in range(8):
            values[f'$dividend[{i}]'] = float(dividend_bits[i])
            values[f'$divisor[{i}]'] = float(divisor_bits[i])

        def resolve(src: str) -> float:
            # Order matters: explicit entries first, then generic '#k'
            # constants, then fully-qualified gate outputs.
            if src in values:
                return values[src]
            if src.startswith('#'):
                return float(src[1:])
            full_path = f'arithmetic.div8bit.{src}'
            if full_path in values:
                return values[full_path]
            raise KeyError(f"Cannot resolve: {src}")

        def eval_gate_from_routing(gate_name: str, sources: list) -> float:
            gate_path = f'arithmetic.div8bit.{gate_name}'
            if not self.reg.has(f'{gate_path}.weight'):
                # No trained tensor for this gate: fall back to an AND over
                # the sources (sum >= count holds only when all are 1,
                # assuming 0/1 source values).
                inp_vals = [resolve(s) for s in sources]
                return float(sum(inp_vals) >= len(inp_vals))

            w = self.reg.get(f'{gate_path}.weight')
            b = self.reg.get(f'{gate_path}.bias')
            inp_vals = torch.tensor([resolve(s) for s in sources], device=self.device, dtype=torch.float32)
            return heaviside((inp_vals * w).sum() + b).item()

        # Evaluate the eight division stages in order; gates inside a stage
        # are topologically sorted so sources are computed before consumers.
        for stage in range(8):
            stage_gates = [g for g in internal.keys() if g.startswith(f'stage{stage}.')]
            sorted_stage_gates = self._topological_sort_subset(internal, stage_gates)
            for gate_name in sorted_stage_gates:
                sources = internal[gate_name]
                values[f'arithmetic.div8bit.{gate_name}'] = eval_gate_from_routing(gate_name, sources)

        # Dedicated output gates (if present in the routing) are evaluated
        # last; they may read any stage output.
        for gate_name in ['quotient0', 'quotient1', 'quotient2', 'quotient3',
                          'quotient4', 'quotient5', 'quotient6', 'quotient7',
                          'remainder0', 'remainder1', 'remainder2', 'remainder3',
                          'remainder4', 'remainder5', 'remainder6', 'remainder7']:
            if gate_name in internal:
                sources = internal[gate_name]
                values[f'arithmetic.div8bit.{gate_name}'] = eval_gate_from_routing(gate_name, sources)

        # Quotient bits are read from each stage's comparator, MSB-first
        # (stage 0 contributes bit 7 via the (7 - i) shift below); remainder
        # bits come from stage 7's muxes, LSB-indexed.  Missing gates
        # default to 0.  NOTE(review): this looks like restoring division
        # with stage 0 handling the MSB — confirm against routing.json.
        quotient_bits = [int(values.get(f'arithmetic.div8bit.stage{i}.cmp', 0)) for i in range(8)]
        remainder_bits = [int(values.get(f'arithmetic.div8bit.stage7.mux{i}.or', 0)) for i in range(8)]

        quotient = sum(quotient_bits[i] << (7 - i) for i in range(8))
        remainder = sum(remainder_bits[i] << i for i in range(8))

        return quotient, remainder

    def _topological_sort_subset(self, internal: dict, subset: list) -> list:
        """Sort a subset of gates in dependency order within that subset.

        Edges to inputs ('$...'), constants ('#...'), or gates outside the
        subset are ignored.  Uses iterative-free DFS with a `temp` marker;
        note a back-edge (cycle) is silently skipped rather than reported,
        so cyclic routing would yield a best-effort order.
        """
        subset_set = set(subset)
        deps = {}
        for gate in subset:
            deps[gate] = set()
            for src in internal.get(gate, []):
                if src.startswith('$') or src.startswith('#'):
                    continue
                if src in subset_set:
                    deps[gate].add(src)

        result = []
        visited = set()
        temp = set()

        def visit(node):
            if node in temp:
                # Cycle detected: bail out of this path silently.
                return
            if node in visited:
                return
            temp.add(node)
            for dep in deps.get(node, []):
                visit(dep)
            temp.remove(node)
            visited.add(node)
            result.append(node)

        for node in subset:
            visit(node)

        return result
|
|
|
|
|
| class CircuitEvaluator:
|
| """Evaluates individual circuit types."""
|
|
|
    def __init__(self, registry: TensorRegistry, device: str = 'cuda', routing_path: Optional[str] = None):
        """Set up evaluation over `registry`, moving all tensors to `device`.

        When `routing_path` is omitted, routing.json is looked up in the
        parent of this file's directory.
        """
        self.reg = registry
        self.device = device
        if routing_path is None:
            # Default: <parent of this file's directory>/routing.json
            routing_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'routing.json')
        self.routing_eval = RoutingEvaluator(registry, routing_path, device)
        self._move_to_device()
|
|
|
| def _move_to_device(self):
|
| """Move all tensors to target device."""
|
| for name in self.reg.tensors:
|
| self.reg.tensors[name] = self.reg.tensors[name].to(self.device)
|
|
|
|
|
|
|
|
|
|
|
| def eval_single_layer(self, prefix: str, inputs: torch.Tensor) -> torch.Tensor:
|
| """Evaluate a single-layer threshold gate: heaviside(inputs @ w + b)."""
|
| w = self.reg.get(f'{prefix}.weight')
|
| b = self.reg.get(f'{prefix}.bias')
|
| return heaviside(inputs @ w + b)
|
|
|
| def eval_two_layer_xor(self, prefix: str, inputs: torch.Tensor) -> torch.Tensor:
|
| """Evaluate two-layer XOR: OR/NAND -> AND structure."""
|
|
|
| w_or = self.reg.get(f'{prefix}.layer1.or.weight')
|
| b_or = self.reg.get(f'{prefix}.layer1.or.bias')
|
| w_nand = self.reg.get(f'{prefix}.layer1.nand.weight')
|
| b_nand = self.reg.get(f'{prefix}.layer1.nand.bias')
|
|
|
| h_or = heaviside(inputs @ w_or + b_or)
|
| h_nand = heaviside(inputs @ w_nand + b_nand)
|
| hidden = torch.stack([h_or, h_nand], dim=-1)
|
|
|
|
|
| w2 = self.reg.get(f'{prefix}.layer2.weight')
|
| b2 = self.reg.get(f'{prefix}.layer2.bias')
|
| return heaviside((hidden * w2).sum(-1) + b2)
|
|
|
| def eval_two_layer_neuron(self, prefix: str, inputs: torch.Tensor) -> torch.Tensor:
|
| """Evaluate two-layer gate with neuron1/neuron2 naming."""
|
| w1_n1 = self.reg.get(f'{prefix}.layer1.neuron1.weight')
|
| b1_n1 = self.reg.get(f'{prefix}.layer1.neuron1.bias')
|
| w1_n2 = self.reg.get(f'{prefix}.layer1.neuron2.weight')
|
| b1_n2 = self.reg.get(f'{prefix}.layer1.neuron2.bias')
|
|
|
| h1 = heaviside(inputs @ w1_n1 + b1_n1)
|
| h2 = heaviside(inputs @ w1_n2 + b1_n2)
|
| hidden = torch.stack([h1, h2], dim=-1)
|
|
|
| w2 = self.reg.get(f'{prefix}.layer2.weight')
|
| b2 = self.reg.get(f'{prefix}.layer2.bias')
|
| return heaviside((hidden * w2).sum(-1) + b2)
|
|
|
| def eval_two_layer_xnor(self, prefix: str, inputs: torch.Tensor) -> torch.Tensor:
|
| """Evaluate two-layer XNOR: AND/NOR -> OR structure."""
|
| w_and = self.reg.get(f'{prefix}.layer1.and.weight')
|
| b_and = self.reg.get(f'{prefix}.layer1.and.bias')
|
| w_nor = self.reg.get(f'{prefix}.layer1.nor.weight')
|
| b_nor = self.reg.get(f'{prefix}.layer1.nor.bias')
|
|
|
| h_and = heaviside(inputs @ w_and + b_and)
|
| h_nor = heaviside(inputs @ w_nor + b_nor)
|
| hidden = torch.stack([h_and, h_nor], dim=-1)
|
|
|
| w2 = self.reg.get(f'{prefix}.layer2.weight')
|
| b2 = self.reg.get(f'{prefix}.layer2.bias')
|
| return heaviside((hidden * w2).sum(-1) + b2)
|
|
|
|
|
|
|
|
|
|
|
| def test_boolean_and(self) -> TestResult:
|
| """Test AND gate exhaustively."""
|
| inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
|
| expected = torch.tensor([0,0,0,1], device=self.device, dtype=torch.float32)
|
|
|
| output = self.eval_single_layer('boolean.and', inputs)
|
|
|
| failures = []
|
| passed = 0
|
| for i in range(4):
|
| if output[i] == expected[i]:
|
| passed += 1
|
| else:
|
| failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
|
|
|
| return TestResult('boolean.and', passed, 4, failures)
|
|
|
| def test_boolean_or(self) -> TestResult:
|
| """Test OR gate exhaustively."""
|
| inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
|
| expected = torch.tensor([0,1,1,1], device=self.device, dtype=torch.float32)
|
|
|
| output = self.eval_single_layer('boolean.or', inputs)
|
|
|
| failures = []
|
| passed = 0
|
| for i in range(4):
|
| if output[i] == expected[i]:
|
| passed += 1
|
| else:
|
| failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
|
|
|
| return TestResult('boolean.or', passed, 4, failures)
|
|
|
| def test_boolean_nand(self) -> TestResult:
|
| """Test NAND gate exhaustively."""
|
| inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
|
| expected = torch.tensor([1,1,1,0], device=self.device, dtype=torch.float32)
|
|
|
| output = self.eval_single_layer('boolean.nand', inputs)
|
|
|
| failures = []
|
| passed = 0
|
| for i in range(4):
|
| if output[i] == expected[i]:
|
| passed += 1
|
| else:
|
| failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
|
|
|
| return TestResult('boolean.nand', passed, 4, failures)
|
|
|
| def test_boolean_nor(self) -> TestResult:
|
| """Test NOR gate exhaustively."""
|
| inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
|
| expected = torch.tensor([1,0,0,0], device=self.device, dtype=torch.float32)
|
|
|
| output = self.eval_single_layer('boolean.nor', inputs)
|
|
|
| failures = []
|
| passed = 0
|
| for i in range(4):
|
| if output[i] == expected[i]:
|
| passed += 1
|
| else:
|
| failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
|
|
|
| return TestResult('boolean.nor', passed, 4, failures)
|
|
|
| def test_boolean_not(self) -> TestResult:
|
| """Test NOT gate exhaustively."""
|
| inputs = torch.tensor([[0],[1]], device=self.device, dtype=torch.float32)
|
| expected = torch.tensor([1,0], device=self.device, dtype=torch.float32)
|
|
|
| output = self.eval_single_layer('boolean.not', inputs)
|
|
|
| failures = []
|
| passed = 0
|
| for i in range(2):
|
| if output[i] == expected[i]:
|
| passed += 1
|
| else:
|
| failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
|
|
|
| return TestResult('boolean.not', passed, 2, failures)
|
|
|
| def test_boolean_xor(self) -> TestResult:
|
| """Test XOR gate exhaustively."""
|
| inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
|
| expected = torch.tensor([0,1,1,0], device=self.device, dtype=torch.float32)
|
|
|
| output = self.eval_two_layer_neuron('boolean.xor', inputs)
|
|
|
| failures = []
|
| passed = 0
|
| for i in range(4):
|
| if output[i] == expected[i]:
|
| passed += 1
|
| else:
|
| failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
|
|
|
| return TestResult('boolean.xor', passed, 4, failures)
|
|
|
| def test_boolean_xnor(self) -> TestResult:
|
| """Test XNOR gate exhaustively."""
|
| inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
|
| expected = torch.tensor([1,0,0,1], device=self.device, dtype=torch.float32)
|
|
|
| output = self.eval_two_layer_neuron('boolean.xnor', inputs)
|
|
|
| failures = []
|
| passed = 0
|
| for i in range(4):
|
| if output[i] == expected[i]:
|
| passed += 1
|
| else:
|
| failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
|
|
|
| return TestResult('boolean.xnor', passed, 4, failures)
|
|
|
| def test_boolean_implies(self) -> TestResult:
|
| """Test IMPLIES gate exhaustively."""
|
| inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
|
| expected = torch.tensor([1,1,0,1], device=self.device, dtype=torch.float32)
|
|
|
| output = self.eval_single_layer('boolean.implies', inputs)
|
|
|
| failures = []
|
| passed = 0
|
| for i in range(4):
|
| if output[i] == expected[i]:
|
| passed += 1
|
| else:
|
| failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
|
|
|
| return TestResult('boolean.implies', passed, 4, failures)
|
|
|
|
|
|
|
|
|
|
|
| def eval_half_adder(self, prefix: str, a: torch.Tensor, b: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| """Evaluate half adder, return (sum, carry)."""
|
| inputs = torch.stack([a, b], dim=-1)
|
|
|
|
|
| sum_out = self.eval_two_layer_xor(f'{prefix}.sum', inputs)
|
|
|
|
|
| carry_out = self.eval_single_layer(f'{prefix}.carry', inputs)
|
|
|
| return sum_out, carry_out
|
|
|
| def test_half_adder(self) -> TestResult:
|
| """Test half adder exhaustively."""
|
| failures = []
|
| passed = 0
|
|
|
| for a in [0, 1]:
|
| for b in [0, 1]:
|
| a_t = torch.tensor([float(a)], device=self.device)
|
| b_t = torch.tensor([float(b)], device=self.device)
|
|
|
| sum_out, carry_out = self.eval_half_adder('arithmetic.halfadder', a_t, b_t)
|
|
|
| expected_sum = a ^ b
|
| expected_carry = a & b
|
|
|
| if sum_out.item() == expected_sum and carry_out.item() == expected_carry:
|
| passed += 1
|
| else:
|
| failures.append(((a, b), (expected_sum, expected_carry),
|
| (sum_out.item(), carry_out.item())))
|
|
|
| return TestResult('arithmetic.halfadder', passed, 4, failures)
|
|
|
|
|
|
|
|
|
|
|
| def eval_full_adder(self, prefix: str, a: torch.Tensor, b: torch.Tensor,
|
| cin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| """Evaluate full adder, return (sum, carry_out)."""
|
|
|
| ha1_sum, ha1_carry = self.eval_half_adder(f'{prefix}.ha1', a, b)
|
|
|
|
|
| ha2_sum, ha2_carry = self.eval_half_adder(f'{prefix}.ha2', ha1_sum, cin)
|
|
|
|
|
| carry_inputs = torch.stack([ha1_carry, ha2_carry], dim=-1)
|
| carry_out = self.eval_single_layer(f'{prefix}.carry_or', carry_inputs)
|
|
|
| return ha2_sum, carry_out
|
|
|
| def test_full_adder(self) -> TestResult:
|
| """Test full adder exhaustively."""
|
| failures = []
|
| passed = 0
|
|
|
| for a in [0, 1]:
|
| for b in [0, 1]:
|
| for cin in [0, 1]:
|
| a_t = torch.tensor([float(a)], device=self.device)
|
| b_t = torch.tensor([float(b)], device=self.device)
|
| cin_t = torch.tensor([float(cin)], device=self.device)
|
|
|
| sum_out, cout = self.eval_full_adder('arithmetic.fulladder', a_t, b_t, cin_t)
|
|
|
| expected_sum = (a + b + cin) & 1
|
| expected_cout = (a + b + cin) >> 1
|
|
|
| if sum_out.item() == expected_sum and cout.item() == expected_cout:
|
| passed += 1
|
| else:
|
| failures.append(((a, b, cin), (expected_sum, expected_cout),
|
| (sum_out.item(), cout.item())))
|
|
|
| return TestResult('arithmetic.fulladder', passed, 8, failures)
|
|
|
|
|
|
|
|
|
|
|
| def eval_ripple_carry(self, prefix: str, a: int, b: int, bits: int) -> Tuple[int, int]:
|
| """Evaluate N-bit ripple carry adder, return (sum, carry_out)."""
|
| carry = torch.tensor([0.0], device=self.device)
|
| result_bits = []
|
|
|
| for i in range(bits):
|
| a_bit = torch.tensor([float((a >> i) & 1)], device=self.device)
|
| b_bit = torch.tensor([float((b >> i) & 1)], device=self.device)
|
|
|
| sum_bit, carry = self.eval_full_adder(f'{prefix}.fa{i}', a_bit, b_bit, carry)
|
| result_bits.append(int(sum_bit.item()))
|
|
|
| result = sum(bit << i for i, bit in enumerate(result_bits))
|
| return result, int(carry.item())
|
|
|
| def test_ripple_carry_8bit(self) -> TestResult:
|
| """Test 8-bit ripple carry adder exhaustively (all 65536 combinations)."""
|
| failures = []
|
| passed = 0
|
| total = 256 * 256
|
|
|
| for a in range(256):
|
| for b in range(256):
|
| result, cout = self.eval_ripple_carry('arithmetic.ripplecarry8bit', a, b, 8)
|
|
|
| expected = (a + b) & 0xFF
|
| expected_cout = 1 if (a + b) > 255 else 0
|
|
|
| if result == expected and cout == expected_cout:
|
| passed += 1
|
| else:
|
| if len(failures) < 100:
|
| failures.append(((a, b), (expected, expected_cout), (result, cout)))
|
|
|
| return TestResult('arithmetic.ripplecarry8bit', passed, total, failures)
|
|
|
| def test_ripple_carry_4bit(self) -> TestResult:
|
| """Test 4-bit ripple carry adder exhaustively."""
|
| failures = []
|
| passed = 0
|
| total = 16 * 16
|
|
|
| for a in range(16):
|
| for b in range(16):
|
| result, cout = self.eval_ripple_carry('arithmetic.ripplecarry4bit', a, b, 4)
|
|
|
| expected = (a + b) & 0xF
|
| expected_cout = 1 if (a + b) > 15 else 0
|
|
|
| if result == expected and cout == expected_cout:
|
| passed += 1
|
| else:
|
| failures.append(((a, b), (expected, expected_cout), (result, cout)))
|
|
|
| return TestResult('arithmetic.ripplecarry4bit', passed, total, failures)
|
|
|
| def test_ripple_carry_2bit(self) -> TestResult:
|
| """Test 2-bit ripple carry adder exhaustively."""
|
| failures = []
|
| passed = 0
|
| total = 4 * 4
|
|
|
| for a in range(4):
|
| for b in range(4):
|
| result, cout = self.eval_ripple_carry('arithmetic.ripplecarry2bit', a, b, 2)
|
|
|
| expected = (a + b) & 0x3
|
| expected_cout = 1 if (a + b) > 3 else 0
|
|
|
| if result == expected and cout == expected_cout:
|
| passed += 1
|
| else:
|
| failures.append(((a, b), (expected, expected_cout), (result, cout)))
|
|
|
| return TestResult('arithmetic.ripplecarry2bit', passed, total, failures)
|
|
|
|
|
|
|
|
|
|
|
| def test_comparator_8bit(self, name: str, op: Callable[[int, int], bool]) -> TestResult:
|
| """Test 8-bit comparator exhaustively."""
|
| failures = []
|
| passed = 0
|
| total = 256 * 256
|
|
|
| w = self.reg.get(f'arithmetic.{name}.comparator')
|
|
|
| for a in range(256):
|
| for b in range(256):
|
| a_bits = torch.tensor([(a >> (7-i)) & 1 for i in range(8)],
|
| device=self.device, dtype=torch.float32)
|
| b_bits = torch.tensor([(b >> (7-i)) & 1 for i in range(8)],
|
| device=self.device, dtype=torch.float32)
|
|
|
| if 'less' in name:
|
| diff = b_bits - a_bits
|
| else:
|
| diff = a_bits - b_bits
|
|
|
| score = (diff * w).sum()
|
|
|
| if 'equal' in name:
|
| result = int(score >= 0)
|
| else:
|
| result = int(score > 0)
|
|
|
| expected = int(op(a, b))
|
|
|
| if result == expected:
|
| passed += 1
|
| else:
|
| if len(failures) < 100:
|
| failures.append(((a, b), expected, result))
|
|
|
| return TestResult(f'arithmetic.{name}', passed, total, failures)
|
|
|
| def test_greaterthan8bit(self) -> TestResult:
|
| return self.test_comparator_8bit('greaterthan8bit', lambda a, b: a > b)
|
|
|
| def test_lessthan8bit(self) -> TestResult:
|
| return self.test_comparator_8bit('lessthan8bit', lambda a, b: a < b)
|
|
|
| def test_greaterorequal8bit(self) -> TestResult:
|
| return self.test_comparator_8bit('greaterorequal8bit', lambda a, b: a >= b)
|
|
|
| def test_lessorequal8bit(self) -> TestResult:
|
| return self.test_comparator_8bit('lessorequal8bit', lambda a, b: a <= b)
|
|
|
|
|
|
|
|
|
|
|
| def test_multiplier_8x8(self) -> TestResult:
|
| """Test 8x8 multiplier with representative cases."""
|
|
|
|
|
| test_cases = []
|
|
|
|
|
| for a in [0, 1, 127, 128, 255]:
|
| for b in [0, 1, 127, 128, 255]:
|
| test_cases.append((a, b))
|
|
|
|
|
| for a in [1, 2, 4, 8, 16, 32, 64, 128]:
|
| for b in [1, 2, 4, 8, 16, 32, 64, 128]:
|
| test_cases.append((a, b))
|
|
|
|
|
| patterns = [0xAA, 0x55, 0x0F, 0xF0, 0x33, 0xCC]
|
| for a in patterns:
|
| for b in patterns:
|
| test_cases.append((a, b))
|
|
|
|
|
| for a in range(16):
|
| for b in range(16):
|
| test_cases.append((a, b))
|
|
|
| test_cases = list(set(test_cases))
|
|
|
| failures = []
|
| passed = 0
|
|
|
| for a, b in test_cases:
|
| result = self._eval_multiplier_8x8(a, b)
|
| expected = (a * b) & 0xFFFF
|
|
|
| if result == expected:
|
| passed += 1
|
| else:
|
| if len(failures) < 100:
|
| failures.append(((a, b), expected, result))
|
|
|
| return TestResult('arithmetic.multiplier8x8', passed, len(test_cases), failures)
|
|
|
| def _eval_multiplier_8x8(self, a: int, b: int) -> int:
|
| """Evaluate 8x8 multiplier."""
|
|
|
| pp = [[0] * 8 for _ in range(8)]
|
|
|
| for row in range(8):
|
| for col in range(8):
|
| a_bit = (a >> col) & 1
|
| b_bit = (b >> row) & 1
|
|
|
| inputs = torch.tensor([[float(a_bit), float(b_bit)]], device=self.device)
|
| w = self.reg.get(f'arithmetic.multiplier8x8.pp.r{row}.c{col}.weight')
|
| b_tensor = self.reg.get(f'arithmetic.multiplier8x8.pp.r{row}.c{col}.bias')
|
|
|
| pp[row][col] = int(heaviside((inputs * w).sum() + b_tensor).item())
|
|
|
|
|
| result_bits = [0] * 16
|
| for col in range(8):
|
| result_bits[col] = pp[0][col]
|
|
|
|
|
| for stage in range(7):
|
| row_idx = stage + 1
|
| shift = row_idx
|
| sum_width = 8 + stage + 1
|
|
|
| carry = 0
|
| for bit in range(sum_width):
|
| if bit < shift:
|
| pp_bit = 0
|
| elif bit <= shift + 7:
|
| pp_bit = pp[row_idx][bit - shift]
|
| else:
|
| pp_bit = 0
|
|
|
| prev_bit = result_bits[bit] if bit < 16 else 0
|
|
|
|
|
| prefix = f'arithmetic.multiplier8x8.stage{stage}.bit{bit}'
|
|
|
| total = prev_bit + pp_bit + carry
|
| sum_bit, new_carry = self._eval_multiplier_fa(prefix, prev_bit, pp_bit, carry)
|
|
|
| if bit < 16:
|
| result_bits[bit] = sum_bit
|
| carry = new_carry
|
|
|
| if sum_width < 16:
|
| result_bits[sum_width] = carry
|
|
|
| return sum(result_bits[i] << i for i in range(16))
|
|
|
| def _eval_multiplier_fa(self, prefix: str, a: int, b: int, cin: int) -> Tuple[int, int]:
|
| """Evaluate a full adder in the multiplier."""
|
| a_t = torch.tensor([float(a)], device=self.device)
|
| b_t = torch.tensor([float(b)], device=self.device)
|
| cin_t = torch.tensor([float(cin)], device=self.device)
|
|
|
|
|
| inp_ab = torch.stack([a_t, b_t], dim=-1)
|
| ha1_sum = self.eval_two_layer_xor(f'{prefix}.ha1.sum', inp_ab)
|
| ha1_carry = self.eval_single_layer(f'{prefix}.ha1.carry', inp_ab)
|
|
|
|
|
| inp_ha2 = torch.stack([ha1_sum, cin_t], dim=-1)
|
| ha2_sum = self.eval_two_layer_xor(f'{prefix}.ha2.sum', inp_ha2)
|
| ha2_carry = self.eval_single_layer(f'{prefix}.ha2.carry', inp_ha2)
|
|
|
|
|
| carry_inp = torch.stack([ha1_carry, ha2_carry], dim=-1)
|
| cout = self.eval_single_layer(f'{prefix}.carry_or', carry_inp)
|
|
|
| return int(ha2_sum.item()), int(cout.item())
|
|
|
|
|
|
|
|
|
|
|
| def test_threshold_kofn(self, k: int, name: str) -> TestResult:
|
| """Test k-of-n threshold gate exhaustively over 8-bit inputs."""
|
| failures = []
|
| passed = 0
|
|
|
| w = self.reg.get(f'threshold.{name}.weight')
|
| b = self.reg.get(f'threshold.{name}.bias')
|
|
|
| for val in range(256):
|
| bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
|
| device=self.device, dtype=torch.float32)
|
|
|
| output = heaviside((bits * w).sum() + b)
|
| popcount = bin(val).count('1')
|
| expected = float(popcount >= k)
|
|
|
| if output.item() == expected:
|
| passed += 1
|
| else:
|
| failures.append((val, expected, output.item()))
|
|
|
| return TestResult(f'threshold.{name}', passed, 256, failures)
|
|
|
| def test_threshold_gates(self) -> List[TestResult]:
|
| """Test all threshold gates."""
|
| results = []
|
|
|
| threshold_gates = [
|
| (1, 'oneoutof8'),
|
| (2, 'twooutof8'),
|
| (3, 'threeoutof8'),
|
| (4, 'fouroutof8'),
|
| (5, 'fiveoutof8'),
|
| (6, 'sixoutof8'),
|
| (7, 'sevenoutof8'),
|
| (8, 'alloutof8'),
|
| ]
|
|
|
| for k, name in threshold_gates:
|
| if self.reg.has(f'threshold.{name}.weight'):
|
| results.append(self.test_threshold_kofn(k, name))
|
|
|
| return results
|
|
|
|
|
|
|
|
|
|
|
    def test_modular(self, mod: int) -> TestResult:
        """Test divisibility-by-mod circuit exhaustively over all 256 byte values.

        Two circuit topologies:
          - mod in {2, 4, 8}: a single threshold gate over the 8 (MSB-first)
            bits suffices (divisibility by a power of two is a check on the
            low bits).
          - otherwise: a three-layer network of interval detectors — layer 1
            holds geq/leq threshold pairs, layer 2 ANDs each pair into an
            equality detector, layer 3 ORs all detectors together.
            Presumably one detector exists per multiple of `mod` in 0..255;
            the count is probed from the registry.
        """
        failures = []
        passed = 0

        if mod in [2, 4, 8]:
            w = self.reg.get(f'modular.mod{mod}.weight')
            b = self.reg.get(f'modular.mod{mod}.bias')

            for val in range(256):
                # Bits are laid out MSB-first (bit 7 first), matching the weights.
                bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
                                    device=self.device, dtype=torch.float32)
                # NOTE(review): uses b.item() here while other tests add the
                # bias tensor directly; numerically equivalent for a scalar bias.
                output = heaviside((bits * w).sum() + b.item())
                expected = float(val % mod == 0)

                if output.item() == expected:
                    passed += 1
                else:
                    failures.append((val, expected, output.item()))
        else:
            # Probe how many geq/leq detector pairs this modulus has.
            num_detectors = 0
            while self.reg.has(f'modular.mod{mod}.layer1.geq{num_detectors}.weight'):
                num_detectors += 1

            for val in range(256):
                bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
                                    device=self.device, dtype=torch.float32)

                # Layer 1: per-detector (geq, leq) threshold outputs.
                layer1_outputs = []
                for idx in range(num_detectors):
                    w_geq = self.reg.get(f'modular.mod{mod}.layer1.geq{idx}.weight')
                    b_geq = self.reg.get(f'modular.mod{mod}.layer1.geq{idx}.bias').item()
                    w_leq = self.reg.get(f'modular.mod{mod}.layer1.leq{idx}.weight')
                    b_leq = self.reg.get(f'modular.mod{mod}.layer1.leq{idx}.bias').item()

                    geq = heaviside((bits * w_geq).sum() + b_geq).item()
                    leq = heaviside((bits * w_leq).sum() + b_leq).item()
                    layer1_outputs.append((geq, leq))

                # Layer 2: combine each (geq, leq) pair into an equality flag.
                layer2_outputs = []
                for idx in range(num_detectors):
                    w_eq = self.reg.get(f'modular.mod{mod}.layer2.eq{idx}.weight')
                    b_eq = self.reg.get(f'modular.mod{mod}.layer2.eq{idx}.bias').item()
                    geq, leq = layer1_outputs[idx]
                    combined = torch.tensor([geq, leq], device=self.device, dtype=torch.float32)
                    eq = heaviside((combined * w_eq).sum() + b_eq).item()
                    layer2_outputs.append(eq)

                # Layer 3: OR over all equality flags.
                layer2_stack = torch.tensor(layer2_outputs, device=self.device, dtype=torch.float32)
                w_or = self.reg.get(f'modular.mod{mod}.layer3.or.weight')
                b_or = self.reg.get(f'modular.mod{mod}.layer3.or.bias').item()
                output = heaviside((layer2_stack * w_or).sum() + b_or).item()

                expected = float(val % mod == 0)

                if output == expected:
                    passed += 1
                else:
                    failures.append((val, expected, output))

        return TestResult(f'modular.mod{mod}', passed, 256, failures)
|
|
|
|
|
|
|
|
|
|
|
| def test_alu_control(self) -> TestResult:
|
| """Test ALU opcode decoder (4-bit to 16 one-hot)."""
|
| failures = []
|
| passed = 0
|
| total = 16 * 16
|
|
|
| for opcode in range(16):
|
| opcode_bits = torch.tensor([(opcode >> (3-i)) & 1 for i in range(4)],
|
| device=self.device, dtype=torch.float32)
|
|
|
| for op_idx in range(16):
|
| w = self.reg.get(f'alu.alucontrol.op{op_idx}.weight')
|
| b = self.reg.get(f'alu.alucontrol.op{op_idx}.bias')
|
|
|
| output = heaviside((opcode_bits * w).sum() + b)
|
| expected = float(op_idx == opcode)
|
|
|
| if output.item() == expected:
|
| passed += 1
|
| else:
|
| failures.append(((opcode, op_idx), expected, output.item()))
|
|
|
| return TestResult('alu.alucontrol', passed, total, failures)
|
|
|
| def test_alu_flags(self) -> TestResult:
|
| """Test ALU flag computation (zero, negative, carry, overflow)."""
|
| failures = []
|
| passed = 0
|
|
|
|
|
| for val in range(256):
|
| bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
|
| device=self.device, dtype=torch.float32)
|
|
|
| w_zero = self.reg.get('alu.aluflags.zero.weight')
|
| b_zero = self.reg.get('alu.aluflags.zero.bias')
|
|
|
| output = heaviside((bits * w_zero).sum() + b_zero)
|
| expected = float(val == 0)
|
|
|
| if output.item() == expected:
|
| passed += 1
|
| else:
|
| failures.append((f'zero({val})', expected, output.item()))
|
|
|
|
|
| for val in range(256):
|
| bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
|
| device=self.device, dtype=torch.float32)
|
|
|
| w_neg = self.reg.get('alu.aluflags.negative.weight')
|
| b_neg = self.reg.get('alu.aluflags.negative.bias')
|
|
|
| output = heaviside((bits * w_neg).sum() + b_neg)
|
| expected = float((val & 0x80) != 0)
|
|
|
| if output.item() == expected:
|
| passed += 1
|
| else:
|
| failures.append((f'neg({val})', expected, output.item()))
|
|
|
|
|
| if self.reg.has('alu.aluflags.carry.weight'):
|
| self.reg.get('alu.aluflags.carry.weight')
|
| self.reg.get('alu.aluflags.carry.bias')
|
| passed += 2
|
|
|
| if self.reg.has('alu.aluflags.overflow.weight'):
|
| self.reg.get('alu.aluflags.overflow.weight')
|
| self.reg.get('alu.aluflags.overflow.bias')
|
| passed += 2
|
|
|
| return TestResult('alu.aluflags', passed, passed, failures)
|
|
|
|
|
|
|
|
|
|
|
| def test_popcount(self) -> TestResult:
|
| """Test popcount circuit."""
|
| failures = []
|
| passed = 0
|
|
|
| w = self.reg.get('pattern_recognition.popcount.weight')
|
| b = self.reg.get('pattern_recognition.popcount.bias')
|
|
|
| for val in range(256):
|
| bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
|
| device=self.device, dtype=torch.float32)
|
|
|
| output = (bits * w).sum() + b
|
| expected = float(bin(val).count('1'))
|
|
|
| if output.item() == expected:
|
| passed += 1
|
| else:
|
| failures.append((val, expected, output.item()))
|
|
|
| return TestResult('pattern_recognition.popcount', passed, 256, failures)
|
|
|
| def test_allzeros(self) -> TestResult:
|
| """Test all-zeros detector."""
|
| failures = []
|
| passed = 0
|
|
|
| w = self.reg.get('pattern_recognition.allzeros.weight')
|
| b = self.reg.get('pattern_recognition.allzeros.bias')
|
|
|
| for val in range(256):
|
| bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
|
| device=self.device, dtype=torch.float32)
|
|
|
| output = heaviside((bits * w).sum() + b)
|
| expected = float(val == 0)
|
|
|
| if output.item() == expected:
|
| passed += 1
|
| else:
|
| failures.append((val, expected, output.item()))
|
|
|
| return TestResult('pattern_recognition.allzeros', passed, 256, failures)
|
|
|
| def test_allones(self) -> TestResult:
|
| """Test all-ones detector."""
|
| failures = []
|
| passed = 0
|
|
|
| w = self.reg.get('pattern_recognition.allones.weight')
|
| b = self.reg.get('pattern_recognition.allones.bias')
|
|
|
| for val in range(256):
|
| bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
|
| device=self.device, dtype=torch.float32)
|
|
|
| output = heaviside((bits * w).sum() + b)
|
| expected = float(val == 255)
|
|
|
| if output.item() == expected:
|
| passed += 1
|
| else:
|
| failures.append((val, expected, output.item()))
|
|
|
| return TestResult('pattern_recognition.allones', passed, 256, failures)
|
|
|
|
|
|
|
|
|
|
|
| def test_division_8bit(self) -> TestResult:
|
| """Test 8-bit division circuit with representative sample."""
|
| if not self.reg.has('arithmetic.div8bit.quotient0.weight'):
|
| return TestResult('arithmetic.div8bit', 0, 0, [('NOT FOUND', '', '')])
|
|
|
| failures = []
|
| passed = 0
|
| total = 0
|
|
|
|
|
| test_cases = []
|
|
|
|
|
| test_cases.extend([(0, d) for d in [1, 2, 127, 255]])
|
| test_cases.extend([(255, d) for d in [1, 2, 15, 16, 17, 127, 255]])
|
| test_cases.extend([(d, 1) for d in range(0, 256, 16)])
|
| test_cases.extend([(d, 255) for d in range(0, 256, 16)])
|
|
|
|
|
| for dividend in [1, 2, 4, 8, 16, 32, 64, 128]:
|
| for divisor in [1, 2, 4, 8, 16, 32, 64, 128]:
|
| test_cases.append((dividend, divisor))
|
|
|
|
|
| for dividend in range(0, 256, 8):
|
| for divisor in range(1, 256, 8):
|
| test_cases.append((dividend, divisor))
|
|
|
| test_cases = list(set(test_cases))
|
|
|
| for dividend, divisor in test_cases:
|
| expected_q = dividend // divisor
|
| expected_r = dividend % divisor
|
|
|
| q, r = self._eval_division(dividend, divisor)
|
|
|
| if q == expected_q and r == expected_r:
|
| passed += 1
|
| else:
|
| if len(failures) < 100:
|
| failures.append(((dividend, divisor), (expected_q, expected_r), (q, r)))
|
|
|
| total += 1
|
|
|
| return TestResult('arithmetic.div8bit', passed, total, failures)
|
|
|
    def _eval_division(self, dividend: int, divisor: int) -> Tuple[int, int]:
        """Evaluate the 8-bit division circuit using routing and actual tensors.

        Thin wrapper: delegates to the routing evaluator and returns the
        ``(quotient, remainder)`` pair it produces.
        """
        return self.routing_eval.eval_division(dividend, divisor)
|
|
|
|
|
|
|
|
|
|
|
| def test_boolean_biimplies(self) -> TestResult:
|
| """Test BIIMPLIES (XNOR) gate exhaustively."""
|
| inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
|
| expected = torch.tensor([1,0,0,1], device=self.device, dtype=torch.float32)
|
|
|
| output = self.eval_two_layer_neuron('boolean.biimplies', inputs)
|
|
|
| failures = []
|
| passed = 0
|
| for i in range(4):
|
| if output[i] == expected[i]:
|
| passed += 1
|
| else:
|
| failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
|
|
|
| return TestResult('boolean.biimplies', passed, 4, failures)
|
|
|
|
|
|
|
|
|
|
|
| def test_alu8bit_and(self) -> TestResult:
|
| """Test ALU 8-bit AND operation."""
|
| failures = []
|
| passed = 0
|
|
|
| w = self.reg.get('alu.alu8bit.and.weight')
|
| b = self.reg.get('alu.alu8bit.and.bias')
|
|
|
|
|
| test_cases = [(0x00, 0x00), (0xFF, 0xFF), (0xAA, 0x55), (0x0F, 0xF0),
|
| (0xFF, 0x00), (0x12, 0x34), (0xCC, 0x33)]
|
|
|
| for a, b_val in test_cases:
|
| for bit in range(8):
|
| a_bit = (a >> (7-bit)) & 1
|
| b_bit = (b_val >> (7-bit)) & 1
|
| inp = torch.tensor([float(a_bit), float(b_bit)], device=self.device)
|
|
|
|
|
| output = heaviside((inp * w[bit*2:bit*2+2]).sum() + b[bit]).item()
|
| expected = float(a_bit & b_bit)
|
|
|
| if output == expected:
|
| passed += 1
|
| else:
|
| failures.append(((a, b_val, bit), expected, output))
|
|
|
| return TestResult('alu.alu8bit.and', passed, len(test_cases) * 8, failures)
|
|
|
| def test_alu8bit_or(self) -> TestResult:
|
| """Test ALU 8-bit OR operation."""
|
| failures = []
|
| passed = 0
|
|
|
| w = self.reg.get('alu.alu8bit.or.weight')
|
| b = self.reg.get('alu.alu8bit.or.bias')
|
|
|
| test_cases = [(0x00, 0x00), (0xFF, 0xFF), (0xAA, 0x55), (0x0F, 0xF0),
|
| (0xFF, 0x00), (0x12, 0x34), (0xCC, 0x33)]
|
|
|
| for a, b_val in test_cases:
|
| for bit in range(8):
|
| a_bit = (a >> (7-bit)) & 1
|
| b_bit = (b_val >> (7-bit)) & 1
|
| inp = torch.tensor([float(a_bit), float(b_bit)], device=self.device)
|
|
|
| output = heaviside((inp * w[bit*2:bit*2+2]).sum() + b[bit]).item()
|
| expected = float(a_bit | b_bit)
|
|
|
| if output == expected:
|
| passed += 1
|
| else:
|
| failures.append(((a, b_val, bit), expected, output))
|
|
|
| return TestResult('alu.alu8bit.or', passed, len(test_cases) * 8, failures)
|
|
|
| def test_alu8bit_not(self) -> TestResult:
|
| """Test ALU 8-bit NOT operation."""
|
| failures = []
|
| passed = 0
|
|
|
| w = self.reg.get('alu.alu8bit.not.weight')
|
| b = self.reg.get('alu.alu8bit.not.bias')
|
|
|
| for val in range(256):
|
| for bit in range(8):
|
| inp_bit = (val >> (7-bit)) & 1
|
| inp = torch.tensor([float(inp_bit)], device=self.device)
|
|
|
| output = heaviside((inp * w[bit]).sum() + b[bit]).item()
|
| expected = float(1 - inp_bit)
|
|
|
| if output == expected:
|
| passed += 1
|
| else:
|
| failures.append(((val, bit), expected, output))
|
|
|
| return TestResult('alu.alu8bit.not', passed, 256 * 8, failures)
|
|
|
| def test_alu8bit_xor(self) -> TestResult:
|
| """Test ALU 8-bit XOR operation via the two-layer structure."""
|
| failures = []
|
| passed = 0
|
|
|
|
|
| test_cases = [(0x00, 0x00), (0xFF, 0xFF), (0xAA, 0x55), (0x0F, 0xF0),
|
| (0xFF, 0x00), (0x00, 0xFF), (0x12, 0x34)]
|
|
|
| for a, b_val in test_cases:
|
| for bit in range(8):
|
| a_bit = (a >> (7-bit)) & 1
|
| b_bit = (b_val >> (7-bit)) & 1
|
| inp = torch.tensor([float(a_bit), float(b_bit)], device=self.device)
|
|
|
|
|
| w_nand = self.reg.get('alu.alu8bit.xor.layer1.nand.weight')
|
| b_nand = self.reg.get('alu.alu8bit.xor.layer1.nand.bias')
|
| w_or = self.reg.get('alu.alu8bit.xor.layer1.or.weight')
|
| b_or = self.reg.get('alu.alu8bit.xor.layer1.or.bias')
|
|
|
| h_nand = heaviside((inp * w_nand[bit*2:bit*2+2]).sum() + b_nand[bit]).item()
|
| h_or = heaviside((inp * w_or[bit*2:bit*2+2]).sum() + b_or[bit]).item()
|
|
|
|
|
| w2 = self.reg.get('alu.alu8bit.xor.layer2.weight')
|
| b2 = self.reg.get('alu.alu8bit.xor.layer2.bias')
|
| hidden = torch.tensor([h_nand, h_or], device=self.device)
|
| output = heaviside((hidden * w2[bit*2:bit*2+2]).sum() + b2[bit]).item()
|
|
|
| expected = float(a_bit ^ b_bit)
|
|
|
| if output == expected:
|
| passed += 1
|
| else:
|
| failures.append(((a, b_val, bit), expected, output))
|
|
|
| return TestResult('alu.alu8bit.xor', passed, len(test_cases) * 8, failures)
|
|
|
| def test_alu8bit_shifts(self) -> TestResult:
|
| """Test ALU 8-bit shift mask weights (SHL, SHR).
|
|
|
| These weights mask out the bit that gets lost during shift:
|
| - SHL mask: [0,1,1,1,1,1,1,1] - masks bit 0 (MSB lost in left shift)
|
| - SHR mask: [1,1,1,1,1,1,1,0] - masks bit 7 (LSB lost in right shift)
|
|
|
| The actual bit routing is handled elsewhere.
|
| """
|
| failures = []
|
| passed = 0
|
|
|
| w_shl = self.reg.get('alu.alu8bit.shl.weight')
|
| w_shr = self.reg.get('alu.alu8bit.shr.weight')
|
|
|
|
|
| expected_shl_mask = [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
|
| for i in range(8):
|
| if w_shl[i].item() == expected_shl_mask[i]:
|
| passed += 1
|
| else:
|
| failures.append((f'shl.weight[{i}]', expected_shl_mask[i], w_shl[i].item()))
|
|
|
|
|
| expected_shr_mask = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]
|
| for i in range(8):
|
| if w_shr[i].item() == expected_shr_mask[i]:
|
| passed += 1
|
| else:
|
| failures.append((f'shr.weight[{i}]', expected_shr_mask[i], w_shr[i].item()))
|
|
|
| return TestResult('alu.alu8bit.shifts', passed, 16, failures)
|
|
|
| def test_alu8bit_add(self) -> TestResult:
|
| """Test ALU 8-bit ADD weight/bias (just verify they exist and have correct shape)."""
|
| failures = []
|
| passed = 0
|
|
|
| w = self.reg.get('alu.alu8bit.add.weight')
|
| b = self.reg.get('alu.alu8bit.add.bias')
|
|
|
|
|
| if w.shape[0] == 16:
|
| passed += 1
|
| else:
|
| failures.append(('add.weight.shape', 16, w.shape[0]))
|
|
|
| if b.shape[0] == 1:
|
| passed += 1
|
| else:
|
| failures.append(('add.bias.shape', 1, b.shape[0]))
|
|
|
| return TestResult('alu.alu8bit.add', passed, 2, failures)
|
|
|
| def test_alu_output_mux(self) -> TestResult:
|
| """Test ALU output mux weight."""
|
| w = self.reg.get('alu.alu8bit.output_mux.weight')
|
|
|
| passed = 1 if w.shape[0] == 32 else 0
|
| failures = [] if passed else [('output_mux.shape', 32, w.shape[0])]
|
|
|
| return TestResult('alu.alu8bit.output_mux', passed, 1, failures)
|
|
|
|
|
|
|
|
|
|
|
| def test_decoder_3to8(self) -> TestResult:
|
| """Test 3-to-8 decoder exhaustively."""
|
| failures = []
|
| passed = 0
|
|
|
| for sel in range(8):
|
| sel_bits = torch.tensor([(sel >> (2-i)) & 1 for i in range(3)],
|
| device=self.device, dtype=torch.float32)
|
|
|
| for out_idx in range(8):
|
| w = self.reg.get(f'combinational.decoder3to8.out{out_idx}.weight')
|
| b = self.reg.get(f'combinational.decoder3to8.out{out_idx}.bias')
|
|
|
| output = heaviside((sel_bits * w).sum() + b).item()
|
| expected = float(out_idx == sel)
|
|
|
| if output == expected:
|
| passed += 1
|
| else:
|
| failures.append(((sel, out_idx), expected, output))
|
|
|
| return TestResult('combinational.decoder3to8', passed, 64, failures)
|
|
|
| def test_encoder_8to3(self) -> TestResult:
|
| """Test 8-to-3 priority encoder."""
|
| failures = []
|
| passed = 0
|
|
|
| for val in range(256):
|
| bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
|
| device=self.device, dtype=torch.float32)
|
|
|
| for bit_idx in range(3):
|
| w = self.reg.get(f'combinational.encoder8to3.bit{bit_idx}.weight')
|
| b = self.reg.get(f'combinational.encoder8to3.bit{bit_idx}.bias')
|
|
|
| output = heaviside((bits * w).sum() + b).item()
|
|
|
|
|
| if val == 0:
|
| expected = 0.0
|
| else:
|
| highest = 7 - (val.bit_length() - 1)
|
| expected = float((highest >> bit_idx) & 1)
|
|
|
|
|
| passed += 1
|
|
|
| return TestResult('combinational.encoder8to3', passed, 256 * 3, failures)
|
|
|
| def test_mux_2to1(self) -> TestResult:
|
| """Test 2-to-1 multiplexer exhaustively."""
|
| failures = []
|
| passed = 0
|
|
|
| for a in [0, 1]:
|
| for b in [0, 1]:
|
| for sel in [0, 1]:
|
|
|
| w_and0 = self.reg.get('combinational.multiplexer2to1.and0.weight')
|
| b_and0 = self.reg.get('combinational.multiplexer2to1.and0.bias')
|
| w_and1 = self.reg.get('combinational.multiplexer2to1.and1.weight')
|
| b_and1 = self.reg.get('combinational.multiplexer2to1.and1.bias')
|
| w_or = self.reg.get('combinational.multiplexer2to1.or.weight')
|
| b_or = self.reg.get('combinational.multiplexer2to1.or.bias')
|
| w_not = self.reg.get('combinational.multiplexer2to1.not_s.weight')
|
| b_not = self.reg.get('combinational.multiplexer2to1.not_s.bias')
|
|
|
| sel_t = torch.tensor([float(sel)], device=self.device)
|
| not_sel = heaviside(sel_t * w_not + b_not).item()
|
|
|
| inp0 = torch.tensor([float(a), not_sel], device=self.device)
|
| inp1 = torch.tensor([float(b), float(sel)], device=self.device)
|
|
|
| h0 = heaviside((inp0 * w_and0).sum() + b_and0).item()
|
| h1 = heaviside((inp1 * w_and1).sum() + b_and1).item()
|
|
|
| or_inp = torch.tensor([h0, h1], device=self.device)
|
| output = heaviside((or_inp * w_or).sum() + b_or).item()
|
|
|
| expected = float(b if sel else a)
|
|
|
| if output == expected:
|
| passed += 1
|
| else:
|
| failures.append(((a, b, sel), expected, output))
|
|
|
| return TestResult('combinational.multiplexer2to1', passed, 8, failures)
|
|
|
| def test_demux_1to2(self) -> TestResult:
|
| """Test 1-to-2 demultiplexer exhaustively.
|
|
|
| and0 has weights [1, -1] (inp, -sel) with bias -1 -> outputs inp AND NOT sel
|
| and1 has weights [1, 1] (inp, sel) with bias -2 -> outputs inp AND sel
|
| """
|
| failures = []
|
| passed = 0
|
|
|
| w_and0 = self.reg.get('combinational.demultiplexer1to2.and0.weight')
|
| b_and0 = self.reg.get('combinational.demultiplexer1to2.and0.bias')
|
| w_and1 = self.reg.get('combinational.demultiplexer1to2.and1.weight')
|
| b_and1 = self.reg.get('combinational.demultiplexer1to2.and1.bias')
|
|
|
| for inp in [0, 1]:
|
| for sel in [0, 1]:
|
|
|
| inp_vec = torch.tensor([float(inp), float(sel)], device=self.device)
|
|
|
| out0 = heaviside((inp_vec * w_and0).sum() + b_and0).item()
|
| out1 = heaviside((inp_vec * w_and1).sum() + b_and1).item()
|
|
|
| expected0 = float(inp == 1 and sel == 0)
|
| expected1 = float(inp == 1 and sel == 1)
|
|
|
| if out0 == expected0:
|
| passed += 1
|
| else:
|
| failures.append(((inp, sel, 'out0'), expected0, out0))
|
|
|
| if out1 == expected1:
|
| passed += 1
|
| else:
|
| failures.append(((inp, sel, 'out1'), expected1, out1))
|
|
|
| return TestResult('combinational.demultiplexer1to2', passed, 8, failures)
|
|
|
| def test_barrel_shifter(self) -> TestResult:
|
| """Test barrel shifter weight existence."""
|
| w = self.reg.get('combinational.barrelshifter8bit.shift')
|
| passed = 1 if w is not None else 0
|
| return TestResult('combinational.barrelshifter8bit', passed, 1, [])
|
|
|
| def test_mux_4to1(self) -> TestResult:
|
| """Test 4-to-1 multiplexer select weight."""
|
| w = self.reg.get('combinational.multiplexer4to1.select')
|
| passed = 1 if w is not None else 0
|
| return TestResult('combinational.multiplexer4to1', passed, 1, [])
|
|
|
| def test_mux_8to1(self) -> TestResult:
|
| """Test 8-to-1 multiplexer select weight."""
|
| w = self.reg.get('combinational.multiplexer8to1.select')
|
| passed = 1 if w is not None else 0
|
| return TestResult('combinational.multiplexer8to1', passed, 1, [])
|
|
|
| def test_demux_1to4(self) -> TestResult:
|
| """Test 1-to-4 demultiplexer decode weight."""
|
| w = self.reg.get('combinational.demultiplexer1to4.decode')
|
| passed = 1 if w is not None else 0
|
| return TestResult('combinational.demultiplexer1to4', passed, 1, [])
|
|
|
| def test_demux_1to8(self) -> TestResult:
|
| """Test 1-to-8 demultiplexer decode weight."""
|
| w = self.reg.get('combinational.demultiplexer1to8.decode')
|
| passed = 1 if w is not None else 0
|
| return TestResult('combinational.demultiplexer1to8', passed, 1, [])
|
|
|
| def test_priority_encoder(self) -> TestResult:
|
| """Test priority encoder weight."""
|
| if self.reg.has('combinational.priorityencoder8bit.priority'):
|
| self.reg.get('combinational.priorityencoder8bit.priority')
|
| return TestResult('combinational.priorityencoder8bit', 1, 1, [])
|
| return TestResult('combinational.priorityencoder8bit', 0, 1, [])
|
|
|
|
|
|
|
|
|
|
|
| def test_even_parity(self) -> TestResult:
|
| """Test even parity checker exhaustively."""
|
| failures = []
|
| passed = 0
|
|
|
| w = self.reg.get('error_detection.evenparitychecker.weight')
|
| b = self.reg.get('error_detection.evenparitychecker.bias')
|
|
|
| for val in range(256):
|
| bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
|
| device=self.device, dtype=torch.float32)
|
|
|
| parity_sum = (bits * w).sum().item()
|
|
|
| expected = float(bin(val).count('1') % 2 == 0)
|
| output = float(parity_sum % 2 == 0)
|
|
|
| if output == expected:
|
| passed += 1
|
| else:
|
| failures.append((val, expected, output))
|
|
|
| return TestResult('error_detection.evenparitychecker', passed, 256, failures)
|
|
|
| def test_odd_parity(self) -> TestResult:
|
| """Test odd parity checker."""
|
| failures = []
|
| passed = 0
|
|
|
| w_par = self.reg.get('error_detection.oddparitychecker.parity.weight')
|
| w_not = self.reg.get('error_detection.oddparitychecker.not.weight')
|
|
|
| for val in range(256):
|
| bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
|
| device=self.device, dtype=torch.float32)
|
|
|
| parity_sum = (bits * w_par).sum().item()
|
|
|
| expected = float(bin(val).count('1') % 2 == 1)
|
| output = float(parity_sum % 2 == 1)
|
|
|
| if output == expected:
|
| passed += 1
|
| else:
|
| failures.append((val, expected, output))
|
|
|
| return TestResult('error_detection.oddparitychecker', passed, 256, failures)
|
|
|
| def test_checksum_8bit(self) -> TestResult:
|
| """Test 8-bit checksum circuit."""
|
| w = self.reg.get('error_detection.checksum8bit.sum.weight')
|
| b = self.reg.get('error_detection.checksum8bit.sum.bias')
|
|
|
| passed = 2 if w is not None and b is not None else 0
|
| return TestResult('error_detection.checksum8bit', passed, 2, [])
|
|
|
| def test_crc(self) -> TestResult:
|
| """Test CRC divisor tensors exist."""
|
| passed = 0
|
| failures = []
|
|
|
| if self.reg.has('error_detection.crc4.divisor'):
|
| self.reg.get('error_detection.crc4.divisor')
|
| passed += 1
|
| else:
|
| failures.append(('crc4.divisor', 'exists', 'missing'))
|
|
|
| if self.reg.has('error_detection.crc8.divisor'):
|
| self.reg.get('error_detection.crc8.divisor')
|
| passed += 1
|
| else:
|
| failures.append(('crc8.divisor', 'exists', 'missing'))
|
|
|
| return TestResult('error_detection.crc', passed, 2, failures)
|
|
|
| def test_hamming_encode(self) -> TestResult:
|
| """Test Hamming encoder parity weights."""
|
| passed = 0
|
|
|
| for i in range(4):
|
| if self.reg.has(f'error_detection.hammingencode4bit.p{i}.weight'):
|
| self.reg.get(f'error_detection.hammingencode4bit.p{i}.weight')
|
| passed += 1
|
|
|
| return TestResult('error_detection.hammingencode4bit', passed, 4, [])
|
|
|
| def test_hamming_decode(self) -> TestResult:
|
| """Test Hamming decoder syndrome weights."""
|
| passed = 0
|
|
|
| for i in range(1, 4):
|
| if self.reg.has(f'error_detection.hammingdecode7bit.s{i}.weight'):
|
| self.reg.get(f'error_detection.hammingdecode7bit.s{i}.weight')
|
| self.reg.get(f'error_detection.hammingdecode7bit.s{i}.bias')
|
| passed += 2
|
|
|
| return TestResult('error_detection.hammingdecode7bit', passed, 6, [])
|
|
|
| def test_hamming_syndrome(self) -> TestResult:
|
| """Test Hamming syndrome weights (no biases)."""
|
| passed = 0
|
|
|
| for i in range(1, 4):
|
| if self.reg.has(f'error_detection.hammingsyndrome.s{i}.weight'):
|
| self.reg.get(f'error_detection.hammingsyndrome.s{i}.weight')
|
| passed += 1
|
|
|
| return TestResult('error_detection.hammingsyndrome', passed, 3, [])
|
|
|
| def test_longitudinal_parity(self) -> TestResult:
|
| """Test longitudinal parity weights."""
|
| passed = 0
|
|
|
| if self.reg.has('error_detection.longitudinalparity.col_parity'):
|
| self.reg.get('error_detection.longitudinalparity.col_parity')
|
| passed += 1
|
|
|
| if self.reg.has('error_detection.longitudinalparity.row_parity'):
|
| self.reg.get('error_detection.longitudinalparity.row_parity')
|
| passed += 1
|
|
|
| return TestResult('error_detection.longitudinalparity', passed, 2, [])
|
|
|
| def test_parity_checker_internals(self) -> TestResult:
|
| """Test parity checker XOR tree internals."""
|
| passed = 0
|
|
|
|
|
| for i in range(4):
|
| for layer in ['layer1.nand', 'layer1.or', 'layer2']:
|
| if self.reg.has(f'error_detection.paritychecker8bit.stage1.xor{i}.{layer}.weight'):
|
| self.reg.get(f'error_detection.paritychecker8bit.stage1.xor{i}.{layer}.weight')
|
| self.reg.get(f'error_detection.paritychecker8bit.stage1.xor{i}.{layer}.bias')
|
| passed += 2
|
|
|
|
|
| for i in range(2):
|
| for layer in ['layer1.nand', 'layer1.or', 'layer2']:
|
| if self.reg.has(f'error_detection.paritychecker8bit.stage2.xor{i}.{layer}.weight'):
|
| self.reg.get(f'error_detection.paritychecker8bit.stage2.xor{i}.{layer}.weight')
|
| self.reg.get(f'error_detection.paritychecker8bit.stage2.xor{i}.{layer}.bias')
|
| passed += 2
|
|
|
|
|
| for layer in ['layer1.nand', 'layer1.or', 'layer2']:
|
| if self.reg.has(f'error_detection.paritychecker8bit.stage3.xor0.{layer}.weight'):
|
| self.reg.get(f'error_detection.paritychecker8bit.stage3.xor0.{layer}.weight')
|
| self.reg.get(f'error_detection.paritychecker8bit.stage3.xor0.{layer}.bias')
|
| passed += 2
|
|
|
|
|
| if self.reg.has('error_detection.paritychecker8bit.output.not.weight'):
|
| self.reg.get('error_detection.paritychecker8bit.output.not.weight')
|
| self.reg.get('error_detection.paritychecker8bit.output.not.bias')
|
| passed += 2
|
|
|
| return TestResult('error_detection.paritychecker8bit.internals', passed, passed, [])
|
|
|
| def test_hamming_encode_biases(self) -> TestResult:
|
| """Test Hamming encode biases."""
|
| passed = 0
|
|
|
| for i in range(4):
|
| if self.reg.has(f'error_detection.hammingencode4bit.p{i}.bias'):
|
| self.reg.get(f'error_detection.hammingencode4bit.p{i}.bias')
|
| passed += 1
|
|
|
| return TestResult('error_detection.hammingencode4bit.biases', passed, passed, [])
|
|
|
| def test_odd_parity_biases(self) -> TestResult:
|
| """Test odd parity checker biases."""
|
| passed = 0
|
|
|
| if self.reg.has('error_detection.oddparitychecker.parity.bias'):
|
| self.reg.get('error_detection.oddparitychecker.parity.bias')
|
| passed += 1
|
|
|
| if self.reg.has('error_detection.oddparitychecker.not.bias'):
|
| self.reg.get('error_detection.oddparitychecker.not.bias')
|
| passed += 1
|
|
|
| return TestResult('error_detection.oddparitychecker.biases', passed, passed, [])
|
|
|
| def test_parity_generator_internals(self) -> TestResult:
|
| """Test parity generator XOR tree internals."""
|
| passed = 0
|
|
|
|
|
| for i in range(4):
|
| for layer in ['layer1.nand', 'layer1.or', 'layer2']:
|
| if self.reg.has(f'error_detection.paritygenerator8bit.stage1.xor{i}.{layer}.weight'):
|
| self.reg.get(f'error_detection.paritygenerator8bit.stage1.xor{i}.{layer}.weight')
|
| self.reg.get(f'error_detection.paritygenerator8bit.stage1.xor{i}.{layer}.bias')
|
| passed += 2
|
|
|
|
|
| for i in range(2):
|
| for layer in ['layer1.nand', 'layer1.or', 'layer2']:
|
| if self.reg.has(f'error_detection.paritygenerator8bit.stage2.xor{i}.{layer}.weight'):
|
| self.reg.get(f'error_detection.paritygenerator8bit.stage2.xor{i}.{layer}.weight')
|
| self.reg.get(f'error_detection.paritygenerator8bit.stage2.xor{i}.{layer}.bias')
|
| passed += 2
|
|
|
|
|
| for layer in ['layer1.nand', 'layer1.or', 'layer2']:
|
| if self.reg.has(f'error_detection.paritygenerator8bit.stage3.xor0.{layer}.weight'):
|
| self.reg.get(f'error_detection.paritygenerator8bit.stage3.xor0.{layer}.weight')
|
| self.reg.get(f'error_detection.paritygenerator8bit.stage3.xor0.{layer}.bias')
|
| passed += 2
|
|
|
|
|
| if self.reg.has('error_detection.paritygenerator8bit.output.not.weight'):
|
| self.reg.get('error_detection.paritygenerator8bit.output.not.weight')
|
| self.reg.get('error_detection.paritygenerator8bit.output.not.bias')
|
| passed += 2
|
|
|
| return TestResult('error_detection.paritygenerator8bit.internals', passed, passed, [])
|
|
|
|
|
|
|
|
|
|
|
| def test_hamming_distance(self) -> TestResult:
|
| """Test Hamming distance circuit."""
|
| passed = 0
|
|
|
| if self.reg.has('pattern_recognition.hammingdistance8bit.xor.weight'):
|
| self.reg.get('pattern_recognition.hammingdistance8bit.xor.weight')
|
| passed += 1
|
|
|
| if self.reg.has('pattern_recognition.hammingdistance8bit.popcount.weight'):
|
| self.reg.get('pattern_recognition.hammingdistance8bit.popcount.weight')
|
| passed += 1
|
|
|
| return TestResult('pattern_recognition.hammingdistance8bit', passed, 2, [])
|
|
|
| def test_one_hot_detector(self) -> TestResult:
|
| """Test one-hot detector exhaustively."""
|
| failures = []
|
| passed = 0
|
|
|
| w_atleast1 = self.reg.get('pattern_recognition.onehotdetector.atleast1.weight')
|
| b_atleast1 = self.reg.get('pattern_recognition.onehotdetector.atleast1.bias')
|
| w_atmost1 = self.reg.get('pattern_recognition.onehotdetector.atmost1.weight')
|
| b_atmost1 = self.reg.get('pattern_recognition.onehotdetector.atmost1.bias')
|
| w_and = self.reg.get('pattern_recognition.onehotdetector.and.weight')
|
| b_and = self.reg.get('pattern_recognition.onehotdetector.and.bias')
|
|
|
| for val in range(256):
|
| bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
|
| device=self.device, dtype=torch.float32)
|
|
|
| atleast1 = heaviside((bits * w_atleast1).sum() + b_atleast1).item()
|
| atmost1 = heaviside((bits * w_atmost1).sum() + b_atmost1).item()
|
|
|
| hidden = torch.tensor([atleast1, atmost1], device=self.device)
|
| output = heaviside((hidden * w_and).sum() + b_and).item()
|
|
|
|
|
| popcount = bin(val).count('1')
|
| expected = float(popcount == 1)
|
|
|
| if output == expected:
|
| passed += 1
|
| else:
|
| failures.append((val, expected, output))
|
|
|
| return TestResult('pattern_recognition.onehotdetector', passed, 256, failures)
|
|
|
| def test_alternating_pattern(self) -> TestResult:
|
| """Test alternating pattern detector."""
|
| passed = 0
|
|
|
| if self.reg.has('pattern_recognition.alternating8bit.pattern1.weight'):
|
| self.reg.get('pattern_recognition.alternating8bit.pattern1.weight')
|
| passed += 1
|
|
|
| if self.reg.has('pattern_recognition.alternating8bit.pattern2.weight'):
|
| self.reg.get('pattern_recognition.alternating8bit.pattern2.weight')
|
| passed += 1
|
|
|
| return TestResult('pattern_recognition.alternating8bit', passed, 2, [])
|
|
|
| def test_symmetry_detector(self) -> TestResult:
|
| """Test symmetry detector weights."""
|
| passed = 0
|
|
|
| for i in range(4):
|
| if self.reg.has(f'pattern_recognition.symmetry8bit.xnor{i}.weight'):
|
| self.reg.get(f'pattern_recognition.symmetry8bit.xnor{i}.weight')
|
| passed += 1
|
|
|
| if self.reg.has('pattern_recognition.symmetry8bit.and.weight'):
|
| self.reg.get('pattern_recognition.symmetry8bit.and.weight')
|
| self.reg.get('pattern_recognition.symmetry8bit.and.bias')
|
| passed += 2
|
|
|
| return TestResult('pattern_recognition.symmetry8bit', passed, 6, [])
|
|
|
| def test_leading_ones(self) -> TestResult:
|
| """Test leading ones counter."""
|
| if self.reg.has('pattern_recognition.leadingones.weight'):
|
| self.reg.get('pattern_recognition.leadingones.weight')
|
| return TestResult('pattern_recognition.leadingones', 1, 1, [])
|
| return TestResult('pattern_recognition.leadingones', 0, 1, [])
|
|
|
| def test_run_length(self) -> TestResult:
|
| """Test run length counter."""
|
| if self.reg.has('pattern_recognition.runlength.weight'):
|
| self.reg.get('pattern_recognition.runlength.weight')
|
| return TestResult('pattern_recognition.runlength', 1, 1, [])
|
| return TestResult('pattern_recognition.runlength', 0, 1, [])
|
|
|
| def test_trailing_ones(self) -> TestResult:
|
| """Test trailing ones counter."""
|
| if self.reg.has('pattern_recognition.trailingones.weight'):
|
| self.reg.get('pattern_recognition.trailingones.weight')
|
| return TestResult('pattern_recognition.trailingones', 1, 1, [])
|
| return TestResult('pattern_recognition.trailingones', 0, 1, [])
|
|
|
|
|
|
|
|
|
|
|
| def test_threshold_atleastk_4(self) -> TestResult:
|
| """Test at-least-k threshold for 4-bit inputs."""
|
| passed = 0
|
|
|
| if self.reg.has('threshold.atleastk_4.weight'):
|
| self.reg.get('threshold.atleastk_4.weight')
|
| self.reg.get('threshold.atleastk_4.bias')
|
| passed += 2
|
|
|
| return TestResult('threshold.atleastk_4', passed, 2, [])
|
|
|
| def test_threshold_atmostk_4(self) -> TestResult:
|
| """Test at-most-k threshold for 4-bit inputs."""
|
| passed = 0
|
|
|
| if self.reg.has('threshold.atmostk_4.weight'):
|
| self.reg.get('threshold.atmostk_4.weight')
|
| self.reg.get('threshold.atmostk_4.bias')
|
| passed += 2
|
|
|
| return TestResult('threshold.atmostk_4', passed, 2, [])
|
|
|
| def test_threshold_exactlyk_4(self) -> TestResult:
|
| """Test exactly-k threshold for 4-bit inputs."""
|
| passed = 0
|
|
|
| for comp in ['atleast', 'atmost', 'and']:
|
| if self.reg.has(f'threshold.exactlyk_4.{comp}.weight'):
|
| self.reg.get(f'threshold.exactlyk_4.{comp}.weight')
|
| self.reg.get(f'threshold.exactlyk_4.{comp}.bias')
|
| passed += 2
|
|
|
| return TestResult('threshold.exactlyk_4', passed, 6, [])
|
|
|
| def test_threshold_majority(self) -> TestResult:
|
| """Test majority gate."""
|
| passed = 0
|
|
|
| if self.reg.has('threshold.majority.weight'):
|
| self.reg.get('threshold.majority.weight')
|
| self.reg.get('threshold.majority.bias')
|
| passed += 2
|
|
|
| return TestResult('threshold.majority', passed, 2, [])
|
|
|
| def test_threshold_minority(self) -> TestResult:
|
| """Test minority gate."""
|
| passed = 0
|
|
|
| if self.reg.has('threshold.minority.weight'):
|
| self.reg.get('threshold.minority.weight')
|
| self.reg.get('threshold.minority.bias')
|
| passed += 2
|
|
|
| return TestResult('threshold.minority', passed, 2, [])
|
|
|
|
|
|
|
|
|
|
|
| def test_manifest(self) -> TestResult:
|
| """Test manifest metadata tensors."""
|
| manifest_tensors = [ |
| ('manifest.alu_operations', 16), |
| ('manifest.flags', 4), |
| ('manifest.instruction_width', 16), |
| ('manifest.memory_bytes', 65536), |
| ('manifest.pc_width', 16), |
| ('manifest.register_width', 8), |
| ('manifest.registers', 4), |
| ('manifest.turing_complete', 1), |
| ('manifest.version', 3), |
| ] |
|
|
| failures = []
|
| passed = 0
|
|
|
| for name, expected_value in manifest_tensors:
|
| if self.reg.has(name):
|
| val = self.reg.get(name).item()
|
| if val == expected_value:
|
| passed += 1
|
| else:
|
| failures.append((name, expected_value, val))
|
| else:
|
| failures.append((name, 'exists', 'missing'))
|
|
|
| return TestResult('manifest', passed, len(manifest_tensors), failures)
|
|
|
|
|
|
|
|
|
|
|
| def test_control_jump(self) -> TestResult:
|
| """Test jump instruction bit loaders."""
|
| passed = 0
|
|
|
| for bit in range(8):
|
| if self.reg.has(f'control.jump.bit{bit}.weight'):
|
| self.reg.get(f'control.jump.bit{bit}.weight')
|
| self.reg.get(f'control.jump.bit{bit}.bias')
|
| passed += 2
|
|
|
| return TestResult('control.jump', passed, 16, [])
|
|
|
| def test_control_conditional_jump(self) -> TestResult:
|
| """Test conditional jump mux circuits."""
|
| passed = 0
|
|
|
| for bit in range(8):
|
| for comp in ['and_a', 'and_b', 'not_sel', 'or']:
|
| if self.reg.has(f'control.conditionaljump.bit{bit}.{comp}.weight'):
|
| self.reg.get(f'control.conditionaljump.bit{bit}.{comp}.weight')
|
| self.reg.get(f'control.conditionaljump.bit{bit}.{comp}.bias')
|
| passed += 2
|
|
|
| return TestResult('control.conditionaljump', passed, 64, [])
|
|
|
| def test_control_call_ret(self) -> TestResult:
|
| """Test CALL/RET control signals."""
|
| passed = 0
|
|
|
| for sig in ['call.jump', 'call.push', 'ret.jump', 'ret.pop']:
|
| if self.reg.has(f'control.{sig}'):
|
| self.reg.get(f'control.{sig}')
|
| passed += 1
|
|
|
| return TestResult('control.call_ret', passed, 4, [])
|
|
|
| def test_control_push_pop(self) -> TestResult:
|
| """Test PUSH/POP control signals."""
|
| passed = 0
|
|
|
| for sig in ['push.sp_dec', 'push.store', 'pop.load', 'pop.sp_inc']:
|
| if self.reg.has(f'control.{sig}'):
|
| self.reg.get(f'control.{sig}')
|
| passed += 1
|
|
|
| return TestResult('control.push_pop', passed, 4, [])
|
|
|
def test_control_sp(self) -> TestResult:
    """Check the stack-pointer increment/decrement usage tensors exist."""
    found = 0
    for sig in ('sp_dec.uses', 'sp_inc.uses'):
        if self.reg.has(f'control.{sig}'):
            self.reg.get(f'control.{sig}')
            found += 1
    return TestResult('control.sp', found, 2, [])
def test_control_pc_increment(self) -> TestResult:
    """Touch all PC-increment circuit tensors (control.pc_inc).

    NOTE: total equals the touched count, so this reports coverage
    rather than an independent pass/fail verdict.
    """
    found = 0
    base = 'control.pc_inc'
    # XOR cells for bits 1..7: three sub-gates, weight + bias each.
    for idx in range(1, 8):
        if self.reg.has(f'{base}.xor{idx}.layer1.nand.weight'):
            for part in ('layer1.nand', 'layer1.or', 'layer2'):
                self.reg.get(f'{base}.xor{idx}.{part}.weight')
                self.reg.get(f'{base}.xor{idx}.{part}.bias')
            found += 6
    # Carry-chain AND gates for bits 1..7.
    for idx in range(1, 8):
        if self.reg.has(f'{base}.and{idx}.weight'):
            self.reg.get(f'{base}.and{idx}.weight')
            self.reg.get(f'{base}.and{idx}.bias')
            found += 2
    # Bit-0 sum/carry plus the overflow detector.
    if self.reg.has(f'{base}.sum0.weight'):
        for part in ('sum0', 'carry0', 'overflow'):
            self.reg.get(f'{base}.{part}.weight')
            self.reg.get(f'{base}.{part}.bias')
        found += 6
    return TestResult('control.pc_inc', found, found, [])
def test_control_instruction_decode(self) -> TestResult:
    """Check instruction-decoder tensors (44 expected).

    16 opcode decoders, 4 opcode-bit inverters, and the is_alu /
    is_control classifiers — weight + bias each.
    """
    found = 0
    stems = [f'control.decoder.decode{op}' for op in range(16)]
    stems += [f'control.decoder.not_op{op}' for op in range(4)]
    stems += ['control.decoder.is_alu', 'control.decoder.is_control']
    for stem in stems:
        if self.reg.has(f'{stem}.weight'):
            self.reg.get(f'{stem}.weight')
            self.reg.get(f'{stem}.bias')
            found += 2
    return TestResult('control.decoder', found, 44, [])
def test_control_register_mux(self) -> TestResult:
    """Check register 4-to-1 mux tensors (84 expected).

    Per bit: four AND gates plus one OR; globally: the two select-line
    inverters. Weight + bias for each gate.
    """
    found = 0
    base = 'combinational.regmux4to1'
    for idx in range(8):
        per_bit = [f'{base}.bit{idx}.and{k}' for k in range(4)]
        per_bit.append(f'{base}.bit{idx}.or')
        for stem in per_bit:
            if self.reg.has(f'{stem}.weight'):
                self.reg.get(f'{stem}.weight')
                self.reg.get(f'{stem}.bias')
                found += 2
    for sel in ('not_s0', 'not_s1'):
        if self.reg.has(f'{base}.{sel}.weight'):
            self.reg.get(f'{base}.{sel}.weight')
            self.reg.get(f'{base}.{sel}.bias')
            found += 2
    return TestResult('combinational.regmux4to1', found, 84, [])
def test_control_halt(self) -> TestResult:
    """Touch halt-circuit tensors: flag gates, PC bits, value bits, signal.

    NOTE: total equals the touched count (coverage-style result).
    """
    found = 0
    stems = [f'control.halt.{f}' for f in ('flag_c', 'flag_n', 'flag_v', 'flag_z')]
    stems += [f'control.halt.pc.bit{b}' for b in range(8)]
    stems += [f'control.halt.value.bit{b}' for b in range(8)]
    stems.append('control.halt.signal')
    for stem in stems:
        if self.reg.has(f'{stem}.weight'):
            self.reg.get(f'{stem}.weight')
            self.reg.get(f'{stem}.bias')
            found += 2
    return TestResult('control.halt', found, found, [])
def test_control_pc_load(self) -> TestResult:
    """Touch PC-load mux tensors: 8 bits x 3 gates, plus the not_jump inverter."""
    found = 0
    stems = [f'control.pc_load.bit{b}.{g}'
             for b in range(8)
             for g in ('and_jump', 'and_pc', 'or')]
    stems.append('control.pc_load.not_jump')
    for stem in stems:
        if self.reg.has(f'{stem}.weight'):
            self.reg.get(f'{stem}.weight')
            self.reg.get(f'{stem}.bias')
            found += 2
    return TestResult('control.pc_load', found, found, [])
def test_control_nop(self) -> TestResult:
    """Touch NOP instruction tensors: output gate, data bits, flag bits."""
    found = 0
    # The output gate contributes a weight tensor only.
    if self.reg.has('control.nop.output.weight'):
        self.reg.get('control.nop.output.weight')
        found += 1
    stems = [f'control.nop.bit{b}' for b in range(8)]
    stems += [f'control.nop.{f}' for f in ('flag_c', 'flag_n', 'flag_v', 'flag_z')]
    for stem in stems:
        if self.reg.has(f'{stem}.weight'):
            self.reg.get(f'{stem}.weight')
            self.reg.get(f'{stem}.bias')
            found += 2
    return TestResult('control.nop', found, found, [])
def test_control_conditional_jumps(self) -> TestResult:
    """Touch flag-conditional jump circuits (jc, jn, jz, jv): 8 bits x 4 gates each."""
    found = 0
    for jt in ('jc', 'jn', 'jz', 'jv'):
        for idx in range(8):
            for gate in ('and_a', 'and_b', 'not_sel', 'or'):
                stem = f'control.{jt}.bit{idx}.{gate}'
                if self.reg.has(f'{stem}.weight'):
                    self.reg.get(f'{stem}.weight')
                    self.reg.get(f'{stem}.bias')
                    found += 2
    return TestResult('control.conditional_jumps', found, found, [])
def test_control_negated_conditional_jumps(self) -> TestResult:
    """Touch negated conditional jump circuits (jnc, jnn, jnz, jnv)."""
    found = 0
    for jt in ('jnc', 'jnn', 'jnz', 'jnv'):
        for idx in range(8):
            for gate in ('and_a', 'and_b', 'not_sel', 'or'):
                stem = f'control.{jt}.bit{idx}.{gate}'
                if self.reg.has(f'{stem}.weight'):
                    self.reg.get(f'{stem}.weight')
                    self.reg.get(f'{stem}.bias')
                    found += 2
    return TestResult('control.negated_conditional_jumps', found, found, [])
def test_control_parity_jumps(self) -> TestResult:
    """Touch parity/positivity conditional jumps (jp, jnp, jpe, jpo).

    (The loop covers all four variants, not just jp/jnp.)
    """
    found = 0
    for jt in ('jp', 'jnp', 'jpe', 'jpo'):
        for idx in range(8):
            for gate in ('and_a', 'and_b', 'not_sel', 'or'):
                stem = f'control.{jt}.bit{idx}.{gate}'
                if self.reg.has(f'{stem}.weight'):
                    self.reg.get(f'{stem}.weight')
                    self.reg.get(f'{stem}.bias')
                    found += 2
    return TestResult('control.parity_jumps', found, found, [])
|
def test_memory_decoder_16to65536(self) -> TestResult:
    """Test 16-to-65536 address decoder with full-address coverage.

    For every 16-bit address, checks that the decoder row for that
    address fires (expected 1.0) and that a mismatching row — the next
    address, wrapping at 0xFFFF — stays off (expected 0.0). That is
    2 * 65536 checks; at most 20 failures are recorded.
    """
    failures = []
    passed = 0
    mem_size = 1 << 16
    total = mem_size * 2

    # Packed decoder: one row of weight/bias per memory word.
    w_all = self.reg.get('memory.addr_decode.weight')
    b_all = self.reg.get('memory.addr_decode.bias')

    for addr in range(mem_size):
        # MSB-first address vector: element i holds address bit (15 - i).
        addr_bits = torch.tensor([(addr >> (15 - i)) & 1 for i in range(16)],
                                 device=self.device, dtype=torch.float32)

        # Matching row must output 1.
        out_idx = addr
        w = w_all[out_idx]
        b = b_all[out_idx]
        output = heaviside((addr_bits * w).sum() + b).item()
        expected = 1.0
        if output == expected:
            passed += 1
        elif len(failures) < 20:
            failures.append(((addr, out_idx), expected, output))

        # Neighbouring row (addr + 1, wrapping) must output 0.
        out_idx = (addr + 1) & 0xFFFF
        w = w_all[out_idx]
        b = b_all[out_idx]
        output = heaviside((addr_bits * w).sum() + b).item()
        expected = 0.0
        if output == expected:
            passed += 1
        elif len(failures) < 20:
            failures.append(((addr, out_idx), expected, output))

    return TestResult('memory.addr_decode', passed, total, failures)
def test_memory_read_mux(self) -> TestResult:
    """Test 64KB memory read mux for a few representative addresses.

    Builds a deterministic synthetic memory image (byte = (addr * 37)
    & 0xFF), decodes each test address into 65536 select lines, then
    checks that for each of the 8 data bits the AND/OR mux tree
    reproduces the byte stored at the selected address. Bit index 0 is
    the MSB (stored bits are extracted with 7 - bit).
    """
    failures = []
    passed = 0
    total = 0

    mem_size = 1 << 16
    # Pseudo-content so every expected byte is computable.
    mem = [(addr * 37) & 0xFF for addr in range(mem_size)]
    test_addrs = [0x0000, 0x1234, 0xFFFF]

    dec_w = self.reg.get('memory.addr_decode.weight')
    dec_b = self.reg.get('memory.addr_decode.bias')
    and_w = self.reg.get('memory.read.and.weight')
    and_b = self.reg.get('memory.read.and.bias')
    or_w = self.reg.get('memory.read.or.weight')
    or_b = self.reg.get('memory.read.or.bias')

    for addr in test_addrs:
        # MSB-first 16-bit address vector.
        addr_bits = torch.tensor([(addr >> (15 - i)) & 1 for i in range(16)],
                                 device=self.device, dtype=torch.float32)

        # One select line per memory word (only the addressed one should fire).
        selects = []
        for out_idx in range(mem_size):
            output = heaviside((addr_bits * dec_w[out_idx]).sum() + dec_b[out_idx]).item()
            selects.append(output)

        for bit in range(8):
            # AND each stored bit with its word's select line...
            and_vals = []
            for out_idx in range(mem_size):
                mem_bit = float((mem[out_idx] >> (7 - bit)) & 1)
                inp = torch.tensor([mem_bit, selects[out_idx]], device=self.device)
                w = and_w[bit, out_idx]
                b = and_b[bit, out_idx]
                and_vals.append(heaviside((inp * w).sum() + b).item())

            # ...then OR across all words to recover the addressed bit.
            or_inp = torch.tensor(and_vals, device=self.device)
            output = heaviside((or_inp * or_w[bit]).sum() + or_b[bit]).item()
            expected = float((mem[addr] >> (7 - bit)) & 1)

            total += 1
            if output == expected:
                passed += 1
            elif len(failures) < 20:
                failures.append(((addr, bit), expected, output))

    return TestResult('memory.read', passed, total, failures)
def test_memory_write_cells(self) -> TestResult:
    """Test memory cell update logic with and without write enable.

    For each (data, address, write_enable) case, evaluates the cell
    update network at a handful of sample addresses: select = decode
    AND write_enable, nsel = NOT select, new cell bit =
    (old AND nsel) OR (data AND sel). The cell addressed with
    write_en == 1 must take the new data; every other sampled cell
    must keep its old value.
    """
    failures = []
    passed = 0
    total = 0

    mem_size = 1 << 16
    # Deterministic pseudo-content for the "old" memory image.
    mem = [(addr * 13 + 7) & 0xFF for addr in range(mem_size)]
    # (write_data, write_addr, write_enable)
    test_cases = [
        (0xA5, 42, 1.0),
        (0x3C, 0xBEEF, 0.0),
    ]

    dec_w = self.reg.get('memory.addr_decode.weight')
    dec_b = self.reg.get('memory.addr_decode.bias')
    sel_w = self.reg.get('memory.write.sel.weight')
    sel_b = self.reg.get('memory.write.sel.bias')
    nsel_w = self.reg.get('memory.write.nsel.weight')
    nsel_b = self.reg.get('memory.write.nsel.bias')
    and_old_w = self.reg.get('memory.write.and_old.weight')
    and_old_b = self.reg.get('memory.write.and_old.bias')
    and_new_w = self.reg.get('memory.write.and_new.weight')
    and_new_b = self.reg.get('memory.write.and_new.bias')
    or_w = self.reg.get('memory.write.or.weight')
    or_b = self.reg.get('memory.write.or.bias')

    for write_data, write_addr, write_en in test_cases:
        # MSB-first 16-bit write-address vector.
        addr_bits = torch.tensor([(write_addr >> (15 - i)) & 1 for i in range(16)],
                                 device=self.device, dtype=torch.float32)

        # Sample the target cell, its neighbour, and the two extremes
        # rather than evaluating all 65536 cells.
        sample_addrs = [write_addr, (write_addr + 1) & 0xFFFF, 0x0000, 0xFFFF]
        decodes = {}
        for out_idx in sample_addrs:
            decodes[out_idx] = heaviside((addr_bits * dec_w[out_idx]).sum() + dec_b[out_idx]).item()

        for out_idx in sample_addrs:
            # select = decode AND write_enable
            sel_inp = torch.tensor([decodes[out_idx], write_en], device=self.device)
            sel = heaviside((sel_inp * sel_w[out_idx]).sum() + sel_b[out_idx]).item()

            # nsel = NOT select
            nsel = heaviside(sel * nsel_w[out_idx] + nsel_b[out_idx]).item()

            for bit in range(8):
                old_bit = float((mem[out_idx] >> (7 - bit)) & 1)
                data_bit = float((write_data >> (7 - bit)) & 1)

                # Keep path: old bit gated by nsel.
                inp_old = torch.tensor([old_bit, nsel], device=self.device)
                w_old = and_old_w[out_idx, bit]
                b_old = and_old_b[out_idx, bit]
                and_old = heaviside((inp_old * w_old).sum() + b_old).item()

                # Write path: new data bit gated by sel.
                inp_new = torch.tensor([data_bit, sel], device=self.device)
                w_new = and_new_w[out_idx, bit]
                b_new = and_new_b[out_idx, bit]
                and_new = heaviside((inp_new * w_new).sum() + b_new).item()

                # Merge the two paths.
                inp_or = torch.tensor([and_old, and_new], device=self.device)
                w_or = or_w[out_idx, bit]
                b_or = or_b[out_idx, bit]
                output = heaviside((inp_or * w_or).sum() + b_or).item()

                # Only the enabled, addressed cell takes the new data.
                expected = data_bit if (write_en == 1.0 and out_idx == write_addr) else old_bit

                total += 1
                if output == expected:
                    passed += 1
                elif len(failures) < 20:
                    failures.append(((out_idx, bit), expected, output))

    return TestResult('memory.write', passed, total, failures)
def test_control_fetch_load_store(self) -> TestResult:
    """Check fetch/load/store buffer gate tensors against an explicit total.

    16 IR bits, 8 load + 8 store data bits, and 16 memory-address bits —
    weight + bias each.
    """
    found = 0
    total = 0
    stems = [f'control.fetch.ir.bit{b}' for b in range(16)]
    stems += [f'{name}.bit{b}' for b in range(8)
              for name in ('control.load', 'control.store')]
    stems += [f'control.mem_addr.bit{b}' for b in range(16)]
    for stem in stems:
        total += 2
        if self.reg.has(f'{stem}.weight'):
            self.reg.get(f'{stem}.weight')
            self.reg.get(f'{stem}.bias')
            found += 2
    return TestResult('control.fetch_load_store', found, total, [])
def test_packed_memory_routing(self) -> TestResult:
    """Validate packed memory tensor routing and shapes.

    Three checks: (1) each memory circuit has a routing entry,
    (2) the routing entries reference at least one non-empty tensor key,
    (3) every expected packed tensor exists and has the shape implied by
    the manifest dimensions (falling back to 64KB / 16-bit PC / 8-bit
    registers when the manifest tensors are absent).
    """
    failures = []
    passed = 0
    total = 0

    circuits = ["memory.addr_decode", "memory.read", "memory.write"]
    # NOTE(review): assumes self.routing_eval.routing is a dict with a
    # "circuits" -> {circuit: {"internal": {...: [tensor keys]}}} layout —
    # confirm against the routing evaluator's definition.
    routing = self.routing_eval.routing.get("circuits", {})
    routing_keys = set()

    # (1) Every memory circuit must appear in the routing table; collect
    # the tensor keys its internal wiring references.
    for circuit in circuits:
        total += 1
        if circuit not in routing:
            failures.append((circuit, "routing", "missing"))
            continue
        passed += 1
        internal = routing[circuit].get("internal", {})
        for value in internal.values():
            if isinstance(value, list):
                routing_keys.update(value)

    # (2) The collected key set must be non-empty and contain no falsy keys.
    total += 1
    if routing_keys and all(key for key in routing_keys):
        passed += 1
    else:
        failures.append(("packed_keys", "non-empty", "empty"))

    # Manifest-driven dimensions, with defaults when the manifest is absent.
    mem_bytes = int(self.reg.get("manifest.memory_bytes").item()) if self.reg.has("manifest.memory_bytes") else 65536
    pc_width = int(self.reg.get("manifest.pc_width").item()) if self.reg.has("manifest.pc_width") else 16
    reg_width = int(self.reg.get("manifest.register_width").item()) if self.reg.has("manifest.register_width") else 8

    # (3) Expected shape of every packed memory tensor.
    expected_shapes = {
        "memory.addr_decode.weight": (mem_bytes, pc_width),
        "memory.addr_decode.bias": (mem_bytes,),
        "memory.read.and.weight": (reg_width, mem_bytes, 2),
        "memory.read.and.bias": (reg_width, mem_bytes),
        "memory.read.or.weight": (reg_width, mem_bytes),
        "memory.read.or.bias": (reg_width,),
        "memory.write.sel.weight": (mem_bytes, 2),
        "memory.write.sel.bias": (mem_bytes,),
        "memory.write.nsel.weight": (mem_bytes, 1),
        "memory.write.nsel.bias": (mem_bytes,),
        "memory.write.and_old.weight": (mem_bytes, reg_width, 2),
        "memory.write.and_old.bias": (mem_bytes, reg_width),
        "memory.write.and_new.weight": (mem_bytes, reg_width, 2),
        "memory.write.and_new.bias": (mem_bytes, reg_width),
        "memory.write.or.weight": (mem_bytes, reg_width, 2),
        "memory.write.or.bias": (mem_bytes, reg_width),
    }

    for key, expected in expected_shapes.items():
        total += 1
        # A tensor must be referenced by routing AND present AND well-shaped.
        if key not in routing_keys:
            failures.append((key, "routing_ref", "missing"))
            continue
        if not self.reg.has(key):
            failures.append((key, "tensor_exists", "missing"))
            continue
        actual = tuple(self.reg.get(key).shape)
        if actual == expected:
            passed += 1
        else:
            failures.append((key, expected, actual))

    return TestResult('memory.packed_routing', passed, total, failures)
|
|
|
|
|
|
|
|
def test_arithmetic_adc(self) -> TestResult:
    """Check ADC (add with carry) full-adder tensors; 144 expected.

    Per full adder: three carry gates plus two XOR cells of three
    sub-gates each, weight + bias throughout.
    """
    found = 0
    for fa in range(8):
        base = f'arithmetic.adc8bit.fa{fa}'
        for gate in ('and1', 'and2', 'or_carry'):
            if self.reg.has(f'{base}.{gate}.weight'):
                self.reg.get(f'{base}.{gate}.weight')
                self.reg.get(f'{base}.{gate}.bias')
                found += 2
        for xor in ('xor1', 'xor2'):
            for layer in ('layer1.nand', 'layer1.or', 'layer2'):
                if self.reg.has(f'{base}.{xor}.{layer}.weight'):
                    self.reg.get(f'{base}.{xor}.{layer}.weight')
                    self.reg.get(f'{base}.{xor}.{layer}.bias')
                    found += 2
    return TestResult('arithmetic.adc8bit', found, 144, [])
def test_arithmetic_cmp(self) -> TestResult:
    """Check CMP circuit tensors; 28 expected.

    Counts one weight per full adder's and1 gate, weight + bias per
    notb inverter, and one weight per flag gate.
    """
    found = 0
    for fa in range(8):
        if self.reg.has(f'arithmetic.cmp8bit.fa{fa}.and1.weight'):
            self.reg.get(f'arithmetic.cmp8bit.fa{fa}.and1.weight')
            found += 1
    for idx in range(8):
        if self.reg.has(f'arithmetic.cmp8bit.notb{idx}.weight'):
            self.reg.get(f'arithmetic.cmp8bit.notb{idx}.weight')
            self.reg.get(f'arithmetic.cmp8bit.notb{idx}.bias')
            found += 2
    for flag in ('carry', 'negative', 'zero', 'zero_or'):
        if self.reg.has(f'arithmetic.cmp8bit.flags.{flag}.weight'):
            self.reg.get(f'arithmetic.cmp8bit.flags.{flag}.weight')
            found += 1
    return TestResult('arithmetic.cmp8bit', found, 28, [])
def test_arithmetic_equality(self) -> TestResult:
    """Check equality-circuit XNOR gate tensors; 48 expected."""
    found = 0
    for idx in range(8):
        for layer in ('layer1.and', 'layer1.nor', 'layer2'):
            stem = f'arithmetic.equality8bit.xnor{idx}.{layer}'
            if self.reg.has(f'{stem}.weight'):
                self.reg.get(f'{stem}.weight')
                self.reg.get(f'{stem}.bias')
                found += 2
    return TestResult('arithmetic.equality8bit', found, 48, [])
def test_arithmetic_minmax(self) -> TestResult:
    """Check the min/max/absolute-difference selector tensors; 3 expected."""
    found = 0
    for suffix in ('max8bit.select', 'min8bit.select', 'absolutedifference8bit.diff'):
        if self.reg.has(f'arithmetic.{suffix}'):
            self.reg.get(f'arithmetic.{suffix}')
            found += 1
    return TestResult('arithmetic.minmax', found, 3, [])
def test_arithmetic_negate(self) -> TestResult:
    """Test negate (two's complement) circuit - arithmetic.neg8bit.

    BUGFIX: the original checked ``arithmetic.neg8bit.and{bit}.bias``
    once, after the xor-bias loop, using the stale loop variable
    (always ``bit == 7``); it now iterates bits 1..7 explicitly so all
    carry-chain AND biases are covered.

    NOTE: total equals the touched count (coverage-style result), and
    the bias-only sweeps re-fetch biases already fetched alongside
    their weights — kept for parity with the original counting scheme.
    """
    passed = 0

    # NOT gates for all 8 bits (weight + bias).
    for bit in range(8):
        if self.reg.has(f'arithmetic.neg8bit.not{bit}.weight'):
            self.reg.get(f'arithmetic.neg8bit.not{bit}.weight')
            self.reg.get(f'arithmetic.neg8bit.not{bit}.bias')
            passed += 2

    # XOR cell weights for bits 1..7.
    for bit in range(1, 8):
        if self.reg.has(f'arithmetic.neg8bit.xor{bit}.layer1.nand.weight'):
            self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.nand.weight')
            self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.or.weight')
            self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer2.weight')
            passed += 3

    # Carry-chain AND gates for bits 1..7 (weight + bias).
    for bit in range(1, 8):
        if self.reg.has(f'arithmetic.neg8bit.and{bit}.weight'):
            self.reg.get(f'arithmetic.neg8bit.and{bit}.weight')
            self.reg.get(f'arithmetic.neg8bit.and{bit}.bias')
            passed += 2

    # Bit-0 sum/carry weights.
    if self.reg.has('arithmetic.neg8bit.sum0.weight'):
        self.reg.get('arithmetic.neg8bit.sum0.weight')
        self.reg.get('arithmetic.neg8bit.carry0.weight')
        passed += 2

    # Bias-only sweep: NOT gates.
    for bit in range(8):
        if self.reg.has(f'arithmetic.neg8bit.not{bit}.bias'):
            self.reg.get(f'arithmetic.neg8bit.not{bit}.bias')
            passed += 1

    # Bias-only sweep: XOR cells.
    for bit in range(1, 8):
        if self.reg.has(f'arithmetic.neg8bit.xor{bit}.layer1.nand.bias'):
            self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.nand.bias')
            self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.or.bias')
            self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer2.bias')
            passed += 3

    # Bias-only sweep: AND gates (was a single check on stale bit == 7).
    for bit in range(1, 8):
        if self.reg.has(f'arithmetic.neg8bit.and{bit}.bias'):
            self.reg.get(f'arithmetic.neg8bit.and{bit}.bias')
            passed += 1

    # Bit-0 sum/carry biases.
    if self.reg.has('arithmetic.neg8bit.sum0.bias'):
        self.reg.get('arithmetic.neg8bit.sum0.bias')
        self.reg.get('arithmetic.neg8bit.carry0.bias')
        passed += 2

    return TestResult('arithmetic.neg8bit', passed, passed, [])
def test_arithmetic_asr(self) -> TestResult:
    """Check ASR circuit tensors; 26 expected (each bit: weight/bias/src)."""
    found = 0
    for idx in range(8):
        if self.reg.has(f'arithmetic.asr8bit.bit{idx}.weight'):
            for part in ('weight', 'bias', 'src'):
                self.reg.get(f'arithmetic.asr8bit.bit{idx}.{part}')
            found += 3
    if self.reg.has('arithmetic.asr8bit.shiftout.weight'):
        self.reg.get('arithmetic.asr8bit.shiftout.weight')
        self.reg.get('arithmetic.asr8bit.shiftout.bias')
        found += 2
    return TestResult('arithmetic.asr8bit', found, 26, [])
def test_arithmetic_incrementer(self) -> TestResult:
    """Check incrementer tensors (adder wiring plus the constant one)."""
    found = 0
    for name in ('arithmetic.incrementer8bit.adder', 'arithmetic.incrementer8bit.one'):
        if self.reg.has(name):
            self.reg.get(name)
            found += 1
    return TestResult('arithmetic.incrementer8bit', found, 2, [])
def test_arithmetic_decrementer(self) -> TestResult:
    """Check decrementer tensors (adder wiring plus the constant minus-one)."""
    found = 0
    for name in ('arithmetic.decrementer8bit.adder', 'arithmetic.decrementer8bit.neg_one'):
        if self.reg.has(name):
            self.reg.get(name)
            found += 1
    return TestResult('arithmetic.decrementer8bit', found, 2, [])
def test_arithmetic_adc_internals(self) -> TestResult:
    """Touch ADC full-adder internal tensors.

    Same coverage as test_arithmetic_adc, but reported coverage-style
    (total equals the touched count).
    """
    found = 0
    for fa in range(8):
        base = f'arithmetic.adc8bit.fa{fa}'
        for gate in ('and1', 'and2', 'or_carry'):
            if self.reg.has(f'{base}.{gate}.weight'):
                self.reg.get(f'{base}.{gate}.weight')
                self.reg.get(f'{base}.{gate}.bias')
                found += 2
        for xor in ('xor1', 'xor2'):
            for layer in ('layer1.nand', 'layer1.or', 'layer2'):
                if self.reg.has(f'{base}.{xor}.{layer}.weight'):
                    self.reg.get(f'{base}.{xor}.{layer}.weight')
                    self.reg.get(f'{base}.{xor}.{layer}.bias')
                    found += 2
    return TestResult('arithmetic.adc8bit.internals', found, found, [])
def test_arithmetic_cmp_internals(self) -> TestResult:
    """Touch CMP full-adder, notb, and flag tensors (weight + bias each)."""
    found = 0
    for fa in range(8):
        base = f'arithmetic.cmp8bit.fa{fa}'
        for gate in ('and1', 'and2', 'or_carry'):
            if self.reg.has(f'{base}.{gate}.weight'):
                self.reg.get(f'{base}.{gate}.weight')
                self.reg.get(f'{base}.{gate}.bias')
                found += 2
        for xor in ('xor1', 'xor2'):
            for layer in ('layer1.nand', 'layer1.or', 'layer2'):
                if self.reg.has(f'{base}.{xor}.{layer}.weight'):
                    self.reg.get(f'{base}.{xor}.{layer}.weight')
                    self.reg.get(f'{base}.{xor}.{layer}.bias')
                    found += 2
    for idx in range(8):
        if self.reg.has(f'arithmetic.cmp8bit.notb{idx}.weight'):
            self.reg.get(f'arithmetic.cmp8bit.notb{idx}.weight')
            self.reg.get(f'arithmetic.cmp8bit.notb{idx}.bias')
            found += 2
    for flag in ('carry', 'negative', 'zero', 'zero_or'):
        stem = f'arithmetic.cmp8bit.flags.{flag}'
        if self.reg.has(f'{stem}.weight'):
            self.reg.get(f'{stem}.weight')
            self.reg.get(f'{stem}.bias')
            found += 2
    return TestResult('arithmetic.cmp8bit.internals', found, found, [])
def test_arithmetic_sbc_internals(self) -> TestResult:
    """Touch SBC (subtract with carry) full-adder and notb tensors."""
    found = 0
    for fa in range(8):
        base = f'arithmetic.sbc8bit.fa{fa}'
        for gate in ('and1', 'and2', 'or_carry'):
            if self.reg.has(f'{base}.{gate}.weight'):
                self.reg.get(f'{base}.{gate}.weight')
                self.reg.get(f'{base}.{gate}.bias')
                found += 2
        for xor in ('xor1', 'xor2'):
            for layer in ('layer1.nand', 'layer1.or', 'layer2'):
                if self.reg.has(f'{base}.{xor}.{layer}.weight'):
                    self.reg.get(f'{base}.{xor}.{layer}.weight')
                    self.reg.get(f'{base}.{xor}.{layer}.bias')
                    found += 2
    for idx in range(8):
        if self.reg.has(f'arithmetic.sbc8bit.notb{idx}.weight'):
            self.reg.get(f'arithmetic.sbc8bit.notb{idx}.weight')
            self.reg.get(f'arithmetic.sbc8bit.notb{idx}.bias')
            found += 2
    return TestResult('arithmetic.sbc8bit.internals', found, found, [])
def test_arithmetic_sub_internals(self) -> TestResult:
    """Touch SUB internal tensors: carry-in gate, full adders, notb inverters."""
    found = 0
    # Constant carry-in used for two's-complement subtraction.
    if self.reg.has('arithmetic.sub8bit.carry_in.weight'):
        self.reg.get('arithmetic.sub8bit.carry_in.weight')
        self.reg.get('arithmetic.sub8bit.carry_in.bias')
        found += 2
    for fa in range(8):
        base = f'arithmetic.sub8bit.fa{fa}'
        for gate in ('and1', 'and2', 'or_carry'):
            if self.reg.has(f'{base}.{gate}.weight'):
                self.reg.get(f'{base}.{gate}.weight')
                self.reg.get(f'{base}.{gate}.bias')
                found += 2
        for xor in ('xor1', 'xor2'):
            for layer in ('layer1.nand', 'layer1.or', 'layer2'):
                if self.reg.has(f'{base}.{xor}.{layer}.weight'):
                    self.reg.get(f'{base}.{xor}.{layer}.weight')
                    self.reg.get(f'{base}.{xor}.{layer}.bias')
                    found += 2
    for idx in range(8):
        if self.reg.has(f'arithmetic.sub8bit.notb{idx}.weight'):
            self.reg.get(f'arithmetic.sub8bit.notb{idx}.weight')
            self.reg.get(f'arithmetic.sub8bit.notb{idx}.bias')
            found += 2
    return TestResult('arithmetic.sub8bit.internals', found, found, [])
def test_arithmetic_equality_internals(self) -> TestResult:
    """Touch equality XNOR internals plus the combining AND gate."""
    found = 0
    for idx in range(8):
        for layer in ('layer1.and', 'layer1.nor', 'layer2'):
            stem = f'arithmetic.equality8bit.xnor{idx}.{layer}'
            if self.reg.has(f'{stem}.weight'):
                self.reg.get(f'{stem}.weight')
                self.reg.get(f'{stem}.bias')
                found += 2
    if self.reg.has('arithmetic.equality8bit.and.weight'):
        self.reg.get('arithmetic.equality8bit.and.weight')
        self.reg.get('arithmetic.equality8bit.and.bias')
        found += 2
    return TestResult('arithmetic.equality8bit.internals', found, found, [])
def test_arithmetic_rol_ror(self) -> TestResult:
    """Touch ROL/ROR rotate tensors: eight bit gates plus carry-out, each circuit."""
    found = 0
    for circuit in ('rol8bit', 'ror8bit'):
        stems = [f'arithmetic.{circuit}.bit{b}' for b in range(8)]
        stems.append(f'arithmetic.{circuit}.cout')
        for stem in stems:
            if self.reg.has(f'{stem}.weight'):
                self.reg.get(f'{stem}.weight')
                self.reg.get(f'{stem}.bias')
                found += 2
    return TestResult('arithmetic.rol_ror', found, found, [])
def test_arithmetic_div_stages(self) -> TestResult:
    """Touch every tensor of the eight division stages.

    Per stage: comparator, per-bit restore muxes, dividend OR, shift
    network, subtractor full adders (with XOR cells), and the divisor
    NOT gates. Weight + bias each.
    """
    found = 0
    for stage in range(8):
        base = f'arithmetic.div8bit.stage{stage}'

        # Stage comparator.
        if self.reg.has(f'{base}.cmp.weight'):
            self.reg.get(f'{base}.cmp.weight')
            self.reg.get(f'{base}.cmp.bias')
            found += 2

        # Per-bit restore muxes.
        for idx in range(8):
            for gate in ('and0', 'and1', 'not_sel', 'or'):
                stem = f'{base}.mux{idx}.{gate}'
                if self.reg.has(f'{stem}.weight'):
                    self.reg.get(f'{stem}.weight')
                    self.reg.get(f'{stem}.bias')
                    found += 2

        # Dividend OR gate.
        if self.reg.has(f'{base}.or_dividend.weight'):
            self.reg.get(f'{base}.or_dividend.weight')
            self.reg.get(f'{base}.or_dividend.bias')
            found += 2

        # Shift network.
        for idx in range(8):
            if self.reg.has(f'{base}.shift.bit{idx}.weight'):
                self.reg.get(f'{base}.shift.bit{idx}.weight')
                self.reg.get(f'{base}.shift.bit{idx}.bias')
                found += 2

        # Subtractor full adders (carry gates then XOR cells).
        for fa in range(8):
            for gate in ('and1', 'and2', 'or_carry'):
                stem = f'{base}.sub.fa{fa}.{gate}'
                if self.reg.has(f'{stem}.weight'):
                    self.reg.get(f'{stem}.weight')
                    self.reg.get(f'{stem}.bias')
                    found += 2
            for xor in ('xor1', 'xor2'):
                for layer in ('layer1.nand', 'layer1.or', 'layer2'):
                    stem = f'{base}.sub.fa{fa}.{xor}.{layer}'
                    if self.reg.has(f'{stem}.weight'):
                        self.reg.get(f'{stem}.weight')
                        self.reg.get(f'{stem}.bias')
                        found += 2

        # Divisor NOT gates feeding the subtractor.
        for idx in range(8):
            if self.reg.has(f'{base}.sub.notd{idx}.weight'):
                self.reg.get(f'{base}.sub.notd{idx}.weight')
                self.reg.get(f'{base}.sub.notd{idx}.bias')
                found += 2

    return TestResult('arithmetic.div8bit.stages', found, found, [])
def test_arithmetic_div_outputs(self) -> TestResult:
    """Touch division quotient and remainder output tensors for every bit."""
    found = 0
    for idx in range(8):
        for out in ('quotient', 'remainder'):
            stem = f'arithmetic.div8bit.{out}{idx}'
            if self.reg.has(f'{stem}.weight'):
                self.reg.get(f'{stem}.weight')
                self.reg.get(f'{stem}.bias')
                found += 2
    return TestResult('arithmetic.div8bit.outputs', found, found, [])
def test_arithmetic_multiplier_internals(self) -> TestResult:
    """Test multiplier internal partial products and adder stages.

    BUGFIX: the original built stage tensor keys as
    ``{comp}{suffix[1:]}``, which dropped the dot separator (producing
    e.g. ``ha1.sumweight`` instead of ``ha1.sum.weight``), so no stage
    tensor could ever match. Keys are now assembled with the dot.
    """
    passed = 0

    # 8x8 grid of partial-product AND gates (weight + bias each).
    for row in range(8):
        for col in range(8):
            stem = f'arithmetic.multiplier8x8.pp.r{row}.c{col}'
            if self.reg.has(f'{stem}.weight'):
                self.reg.get(f'{stem}.weight')
                self.reg.get(f'{stem}.bias')
                passed += 2

    # Seven accumulation stages of 16-bit adders: two half adders plus
    # a carry OR per bit, weight and bias checked individually.
    for stage in range(7):
        for bit in range(16):
            for comp in ['ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or']:
                for suffix in ['weight', 'bias']:
                    name = f'arithmetic.multiplier8x8.stage{stage}.bit{bit}.{comp}.{suffix}'
                    if self.reg.has(name):
                        self.reg.get(name)
                        passed += 1

    return TestResult('arithmetic.multiplier8x8.internals', passed, passed, [])
| def test_arithmetic_ripple_internals(self) -> TestResult:
|
| """Test ripple carry adder internal full adders."""
|
| passed = 0
|
|
|
|
|
| for fa in range(8):
|
| for comp in ['ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or']:
|
| if self.reg.has(f'arithmetic.ripplecarry8bit.fa{fa}.{comp}.weight'):
|
| self.reg.get(f'arithmetic.ripplecarry8bit.fa{fa}.{comp}.weight')
|
| self.reg.get(f'arithmetic.ripplecarry8bit.fa{fa}.{comp}.bias')
|
| passed += 2
|
|
|
|
|
| for fa in range(4):
|
| for comp in ['ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or']:
|
| if self.reg.has(f'arithmetic.ripplecarry4bit.fa{fa}.{comp}.weight'):
|
| self.reg.get(f'arithmetic.ripplecarry4bit.fa{fa}.{comp}.weight')
|
| self.reg.get(f'arithmetic.ripplecarry4bit.fa{fa}.{comp}.bias')
|
| passed += 2
|
|
|
|
|
| for fa in range(2):
|
| for comp in ['ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or']:
|
| if self.reg.has(f'arithmetic.ripplecarry2bit.fa{fa}.{comp}.weight'):
|
| self.reg.get(f'arithmetic.ripplecarry2bit.fa{fa}.{comp}.weight')
|
| self.reg.get(f'arithmetic.ripplecarry2bit.fa{fa}.{comp}.bias')
|
| passed += 2
|
|
|
| return TestResult('arithmetic.ripplecarry.internals', passed, passed, [])
|
|
|
| def test_arithmetic_equality_final(self) -> TestResult:
|
| """Test equality final AND gate."""
|
| passed = 0
|
|
|
| if self.reg.has('arithmetic.equality8bit.final_and.weight'):
|
| self.reg.get('arithmetic.equality8bit.final_and.weight')
|
| self.reg.get('arithmetic.equality8bit.final_and.bias')
|
| passed += 2
|
|
|
| return TestResult('arithmetic.equality8bit.final', passed, passed, [])
|
|
|
| def test_arithmetic_small_multipliers(self) -> TestResult:
|
| """Test 2x2 and 4x4 multiplier circuits."""
|
| passed = 0
|
|
|
|
|
| for a in range(2):
|
| for b in range(2):
|
| if self.reg.has(f'arithmetic.multiplier2x2.and{a}{b}.weight'):
|
| self.reg.get(f'arithmetic.multiplier2x2.and{a}{b}.weight')
|
| self.reg.get(f'arithmetic.multiplier2x2.and{a}{b}.bias')
|
| passed += 2
|
|
|
|
|
| for comp in ['ha0.sum', 'ha0.carry', 'fa0.ha1.sum', 'fa0.ha1.carry', 'fa0.ha2.sum', 'fa0.ha2.carry', 'fa0.carry_or']:
|
| if self.reg.has(f'arithmetic.multiplier2x2.{comp}.weight'):
|
| self.reg.get(f'arithmetic.multiplier2x2.{comp}.weight')
|
| self.reg.get(f'arithmetic.multiplier2x2.{comp}.bias')
|
| passed += 2
|
|
|
|
|
| for a in range(4):
|
| for b in range(4):
|
| if self.reg.has(f'arithmetic.multiplier4x4.and{a}{b}.weight'):
|
| self.reg.get(f'arithmetic.multiplier4x4.and{a}{b}.weight')
|
| self.reg.get(f'arithmetic.multiplier4x4.and{a}{b}.bias')
|
| passed += 2
|
|
|
|
|
| for stage in range(3):
|
| for bit in range(8):
|
| for comp in ['ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or']:
|
| if self.reg.has(f'arithmetic.multiplier4x4.stage{stage}.bit{bit}.{comp}.weight'):
|
| self.reg.get(f'arithmetic.multiplier4x4.stage{stage}.bit{bit}.{comp}.weight')
|
| self.reg.get(f'arithmetic.multiplier4x4.stage{stage}.bit{bit}.{comp}.bias')
|
| passed += 2
|
|
|
| return TestResult('arithmetic.small_multipliers', passed, passed, [])
|
|
|
|
|
class ComprehensiveEvaluator:
    """Main evaluator that runs all tests and reports results.

    Orchestrates every per-circuit test on a CircuitEvaluator, collects
    TestResult records, prints a summary plus the registry's coverage
    report, and returns the overall pass rate as a fitness in [0, 1].
    """

    def __init__(self, model_path: str, device: str = 'cuda'):
        """Load the tensor registry and build the circuit evaluator.

        Args:
            model_path: Path to the safetensors model file.
            device: Torch device string used by CircuitEvaluator.
        """
        print(f"Loading model from {model_path}...")
        self.registry = TensorRegistry(model_path)
        print(f" Found {len(self.registry.tensors)} tensors")
        # NOTE(review): `categories` is not defined in the visible part of
        # TensorRegistry (which builds `circuits`) — presumably added
        # elsewhere in the class; confirm the attribute exists.
        print(f" Categories: {self.registry.categories}")

        self.evaluator = CircuitEvaluator(self.registry, device)
        self.results: List[TestResult] = []

    def run_all(self, verbose: bool = True) -> float:
        """Run all tests and return overall pass rate.

        Section headers print only when *verbose*; the tests themselves
        always run, so verbosity never affects the fitness score.
        """
        start = time.time()

        if verbose:
            print("\n=== BOOLEAN GATES ===")
        self._run_test(self.evaluator.test_boolean_and, verbose)
        self._run_test(self.evaluator.test_boolean_or, verbose)
        self._run_test(self.evaluator.test_boolean_nand, verbose)
        self._run_test(self.evaluator.test_boolean_nor, verbose)
        self._run_test(self.evaluator.test_boolean_not, verbose)
        self._run_test(self.evaluator.test_boolean_xor, verbose)
        self._run_test(self.evaluator.test_boolean_xnor, verbose)
        self._run_test(self.evaluator.test_boolean_implies, verbose)
        self._run_test(self.evaluator.test_boolean_biimplies, verbose)

        if verbose:
            print("\n=== ARITHMETIC - ADDERS ===")
        self._run_test(self.evaluator.test_half_adder, verbose)
        self._run_test(self.evaluator.test_full_adder, verbose)
        self._run_test(self.evaluator.test_ripple_carry_2bit, verbose)
        self._run_test(self.evaluator.test_ripple_carry_4bit, verbose)
        self._run_test(self.evaluator.test_ripple_carry_8bit, verbose)

        if verbose:
            print("\n=== ARITHMETIC - COMPARATORS ===")
        self._run_test(self.evaluator.test_greaterthan8bit, verbose)
        self._run_test(self.evaluator.test_lessthan8bit, verbose)
        self._run_test(self.evaluator.test_greaterorequal8bit, verbose)
        self._run_test(self.evaluator.test_lessorequal8bit, verbose)

        if verbose:
            print("\n=== ARITHMETIC - MULTIPLIER ===")
        self._run_test(self.evaluator.test_multiplier_8x8, verbose)

        if verbose:
            print("\n=== ARITHMETIC - ADDITIONAL ===")
        self._run_test(self.evaluator.test_arithmetic_adc, verbose)
        self._run_test(self.evaluator.test_arithmetic_cmp, verbose)
        self._run_test(self.evaluator.test_arithmetic_equality, verbose)
        self._run_test(self.evaluator.test_arithmetic_minmax, verbose)
        self._run_test(self.evaluator.test_arithmetic_negate, verbose)
        self._run_test(self.evaluator.test_arithmetic_asr, verbose)
        self._run_test(self.evaluator.test_arithmetic_incrementer, verbose)
        self._run_test(self.evaluator.test_arithmetic_decrementer, verbose)
        self._run_test(self.evaluator.test_arithmetic_adc_internals, verbose)
        self._run_test(self.evaluator.test_arithmetic_cmp_internals, verbose)
        self._run_test(self.evaluator.test_arithmetic_sbc_internals, verbose)
        self._run_test(self.evaluator.test_arithmetic_sub_internals, verbose)
        self._run_test(self.evaluator.test_arithmetic_equality_internals, verbose)
        self._run_test(self.evaluator.test_arithmetic_rol_ror, verbose)
        self._run_test(self.evaluator.test_arithmetic_div_stages, verbose)
        self._run_test(self.evaluator.test_arithmetic_div_outputs, verbose)
        self._run_test(self.evaluator.test_arithmetic_multiplier_internals, verbose)
        self._run_test(self.evaluator.test_arithmetic_ripple_internals, verbose)
        self._run_test(self.evaluator.test_arithmetic_equality_final, verbose)
        self._run_test(self.evaluator.test_arithmetic_small_multipliers, verbose)

        if verbose:
            print("\n=== THRESHOLD GATES ===")
        # test_threshold_gates yields multiple TestResults, so it is
        # consumed directly instead of going through _run_test.
        for result in self.evaluator.test_threshold_gates():
            self.results.append(result)
            if verbose:
                self._print_result(result)
        self._run_test(self.evaluator.test_threshold_atleastk_4, verbose)
        self._run_test(self.evaluator.test_threshold_atmostk_4, verbose)
        self._run_test(self.evaluator.test_threshold_exactlyk_4, verbose)
        self._run_test(self.evaluator.test_threshold_majority, verbose)
        self._run_test(self.evaluator.test_threshold_minority, verbose)

        if verbose:
            print("\n=== MODULAR ARITHMETIC ===")
        for mod in [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]:
            # Two naming schemes exist for modular circuits; test whichever
            # is present.  The default arg binds `mod` at definition time.
            if self.registry.has(f'modular.mod{mod}.weight') or \
               self.registry.has(f'modular.mod{mod}.layer1.geq0.weight'):
                self._run_test(lambda m=mod: self.evaluator.test_modular(m), verbose)

        if verbose:
            print("\n=== ALU ===")
        self._run_test(self.evaluator.test_alu_control, verbose)
        self._run_test(self.evaluator.test_alu_flags, verbose)
        self._run_test(self.evaluator.test_alu8bit_and, verbose)
        self._run_test(self.evaluator.test_alu8bit_or, verbose)
        self._run_test(self.evaluator.test_alu8bit_not, verbose)
        self._run_test(self.evaluator.test_alu8bit_xor, verbose)
        self._run_test(self.evaluator.test_alu8bit_shifts, verbose)
        self._run_test(self.evaluator.test_alu8bit_add, verbose)
        self._run_test(self.evaluator.test_alu_output_mux, verbose)

        if verbose:
            print("\n=== COMBINATIONAL ===")
        self._run_test(self.evaluator.test_decoder_3to8, verbose)
        self._run_test(self.evaluator.test_encoder_8to3, verbose)
        self._run_test(self.evaluator.test_mux_2to1, verbose)
        self._run_test(self.evaluator.test_demux_1to2, verbose)
        self._run_test(self.evaluator.test_barrel_shifter, verbose)
        self._run_test(self.evaluator.test_mux_4to1, verbose)
        self._run_test(self.evaluator.test_mux_8to1, verbose)
        self._run_test(self.evaluator.test_demux_1to4, verbose)
        self._run_test(self.evaluator.test_demux_1to8, verbose)
        self._run_test(self.evaluator.test_priority_encoder, verbose)

        if verbose:
            print("\n=== CONTROL ===")
        self._run_test(self.evaluator.test_control_jump, verbose)
        self._run_test(self.evaluator.test_control_conditional_jump, verbose)
        self._run_test(self.evaluator.test_control_call_ret, verbose)
        self._run_test(self.evaluator.test_control_push_pop, verbose)
        self._run_test(self.evaluator.test_control_sp, verbose)
        self._run_test(self.evaluator.test_control_pc_increment, verbose)
        self._run_test(self.evaluator.test_control_instruction_decode, verbose)
        self._run_test(self.evaluator.test_control_register_mux, verbose)
        self._run_test(self.evaluator.test_control_halt, verbose)
        self._run_test(self.evaluator.test_control_pc_load, verbose)
        self._run_test(self.evaluator.test_control_nop, verbose)
        self._run_test(self.evaluator.test_control_conditional_jumps, verbose)
        self._run_test(self.evaluator.test_control_negated_conditional_jumps, verbose)
        self._run_test(self.evaluator.test_control_parity_jumps, verbose)

        if verbose:
            print("\n=== MEMORY ===")
        self._run_test(self.evaluator.test_memory_decoder_16to65536, verbose)
        self._run_test(self.evaluator.test_memory_read_mux, verbose)
        self._run_test(self.evaluator.test_memory_write_cells, verbose)
        self._run_test(self.evaluator.test_control_fetch_load_store, verbose)
        self._run_test(self.evaluator.test_packed_memory_routing, verbose)

        if verbose:
            print("\n=== ERROR DETECTION ===")
        self._run_test(self.evaluator.test_even_parity, verbose)
        self._run_test(self.evaluator.test_odd_parity, verbose)
        self._run_test(self.evaluator.test_checksum_8bit, verbose)
        self._run_test(self.evaluator.test_crc, verbose)
        self._run_test(self.evaluator.test_hamming_encode, verbose)
        self._run_test(self.evaluator.test_hamming_decode, verbose)
        self._run_test(self.evaluator.test_hamming_syndrome, verbose)
        self._run_test(self.evaluator.test_longitudinal_parity, verbose)
        self._run_test(self.evaluator.test_parity_checker_internals, verbose)
        self._run_test(self.evaluator.test_hamming_encode_biases, verbose)
        self._run_test(self.evaluator.test_odd_parity_biases, verbose)
        self._run_test(self.evaluator.test_parity_generator_internals, verbose)

        if verbose:
            print("\n=== PATTERN RECOGNITION ===")
        self._run_test(self.evaluator.test_popcount, verbose)
        self._run_test(self.evaluator.test_allzeros, verbose)
        self._run_test(self.evaluator.test_allones, verbose)
        self._run_test(self.evaluator.test_hamming_distance, verbose)
        self._run_test(self.evaluator.test_one_hot_detector, verbose)
        self._run_test(self.evaluator.test_alternating_pattern, verbose)
        self._run_test(self.evaluator.test_symmetry_detector, verbose)
        self._run_test(self.evaluator.test_leading_ones, verbose)
        self._run_test(self.evaluator.test_run_length, verbose)
        self._run_test(self.evaluator.test_trailing_ones, verbose)

        if verbose:
            print("\n=== MANIFEST ===")
        self._run_test(self.evaluator.test_manifest, verbose)

        if verbose:
            print("\n=== DIVISION ===")
        self._run_test(self.evaluator.test_division_8bit, verbose)

        elapsed = time.time() - start

        total_passed = sum(r.passed for r in self.results)
        total_tests = sum(r.total for r in self.results)

        print("\n" + "=" * 60)
        print("SUMMARY")
        print("=" * 60)
        # Guard the percentage: previously this raised ZeroDivisionError
        # when no tests were discovered, even though the return value
        # below was already guarded against total_tests == 0.
        rate_pct = 100 * total_passed / total_tests if total_tests > 0 else 0.0
        print(f"Total: {total_passed}/{total_tests} ({rate_pct:.4f}%)")
        print(f"Time: {elapsed:.2f}s")

        failed = [r for r in self.results if not r.success]
        if failed:
            print(f"\nFailed circuits ({len(failed)}):")
            for r in failed:
                print(f" {r.circuit_name}: {r.passed}/{r.total} ({100*r.rate:.2f}%)")
                if r.failures:
                    print(f" First failure: input={r.failures[0][0]}, expected={r.failures[0][1]}, got={r.failures[0][2]}")
        else:
            print("\nAll circuits passed!")

        print("\n" + "=" * 60)
        print(self.registry.coverage_report())

        return total_passed / total_tests if total_tests > 0 else 0.0

    def _run_test(self, test_fn: Callable, verbose: bool):
        """Run a single test and record its result.

        NOTE: a test that raises is printed but NOT recorded, so errored
        circuits are excluded from (rather than counted against) the
        fitness score — preserved as existing best-effort behavior.
        """
        try:
            result = test_fn()
            self.results.append(result)
            if verbose:
                self._print_result(result)
        except Exception as e:
            print(f" ERROR: {e}")

    def _print_result(self, result: TestResult):
        """Print a single test result with pass/fail status."""
        status = "PASS" if result.success else "FAIL"
        print(f" {result.circuit_name}: {result.passed}/{result.total} [{status}]")
        if not result.success and result.failures:
            print(f" First failure: {result.failures[0]}")
|
|
|
|
|
def main():
    """CLI entry point: evaluate a model and report its fitness.

    Returns a process exit status: 0 when fitness is essentially perfect
    (>= 0.9999), otherwise 1.
    """
    import argparse

    cli = argparse.ArgumentParser(description='Comprehensive circuit evaluator')
    cli.add_argument('--model', type=str, default='./neural_computer.safetensors',
                     help='Path to safetensors model')
    cli.add_argument('--device', type=str, default='cuda',
                     help='Device (cuda or cpu)')
    cli.add_argument('--quiet', action='store_true',
                     help='Suppress verbose output')
    opts = cli.parse_args()

    runner = ComprehensiveEvaluator(opts.model, opts.device)
    score = runner.run_all(verbose=not opts.quiet)

    print(f"\nFitness: {score:.6f}")
    return 0 if score >= 0.9999 else 1
|
|
|
|
|
if __name__ == '__main__':
    # Raise SystemExit directly: the bare `exit()` builtin is installed
    # by the `site` module and is unavailable when Python runs with -S.
    raise SystemExit(main())
|
|
|