# Source header (repository page residue, preserved as a comment):
# Author: CharlesCNorton
# Commit 6087b2e: "Add SHL, SHR, MUL, DIV, and comparator circuits"
# (raw / history / blame view, 70.3 kB)
"""
Unified Evaluation Suite for 8-bit Threshold Computer
======================================================
GPU-batched evaluation with per-circuit reporting.
Usage:
python eval.py # Run evaluation
python eval.py --device cpu # CPU mode
python eval.py --pop_size 1000 # Population mode for evolution
API (for prune_weights.py):
from eval import load_model, create_population, BatchedFitnessEvaluator
"""
import argparse
import json
import os
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple
import torch
from safetensors import safe_open
# Default weights file: neural_computer.safetensors located next to this script.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "neural_computer.safetensors")
@dataclass
class CircuitResult:
    """Outcome of exercising one circuit against its test vectors."""
    name: str                                     # circuit identifier, e.g. 'boolean.and'
    passed: int                                   # number of test cases that matched
    total: int                                    # number of test cases run
    failures: List[Tuple] = field(default_factory=list)  # (input, expected, got) triples

    @property
    def success(self) -> bool:
        """True when every test case passed."""
        return self.total - self.passed == 0

    @property
    def rate(self) -> float:
        """Fraction of cases passed; 0.0 for an empty test set."""
        if self.total > 0:
            return self.passed / self.total
        return 0.0
def heaviside(x: torch.Tensor) -> torch.Tensor:
    """Step activation: each element maps to 1.0 when >= 0, otherwise 0.0."""
    return torch.ge(x, 0).to(torch.float32)
def load_model(path: str = MODEL_PATH) -> Dict[str, torch.Tensor]:
    """Read every tensor from a safetensors file, upcast to float32."""
    with safe_open(path, framework='pt') as f:
        tensors = {}
        for name in f.keys():
            tensors[name] = f.get_tensor(name).float()
        return tensors
def load_metadata(path: str = MODEL_PATH) -> Dict:
    """Extract the JSON 'signal_registry' from safetensors file metadata.

    Returns:
        {'signal_registry': dict} — empty dict when the file carries no
        metadata or no 'signal_registry' key.
    """
    with safe_open(path, framework='pt') as f:
        meta = f.metadata()
        registry = {}
        if meta and 'signal_registry' in meta:
            registry = json.loads(meta['signal_registry'])
        return {'signal_registry': registry}
def create_population(
    base_tensors: Dict[str, torch.Tensor],
    pop_size: int,
    device: str = 'cuda'
) -> Dict[str, torch.Tensor]:
    """Tile each tensor along a new leading population axis.

    Args:
        base_tensors: Name -> tensor mapping for one model.
        pop_size: Number of population members to replicate.
        device: Destination device for the replicated tensors.
    Returns:
        Name -> tensor mapping where each tensor has shape [pop_size, *orig].
    """
    population = {}
    for name, tensor in base_tensors.items():
        tiled = tensor.unsqueeze(0).expand(pop_size, *tensor.shape)
        # clone() materializes the expanded view so members can be mutated independently
        population[name] = tiled.clone().to(device)
    return population
class BatchedFitnessEvaluator:
"""
GPU-batched fitness evaluator with per-circuit reporting.
Tests all circuits comprehensively.
"""
def __init__(self, device: str = 'cuda', model_path: str = MODEL_PATH):
    """Load model metadata and pre-build all test vectors on `device`.

    Args:
        device: Torch device string used for every test tensor.
        model_path: Path to the safetensors weights file.
    """
    self.device = device
    self.model_path = model_path
    # Metadata read from the safetensors file; schema of the registry is
    # defined by the model writer (see load_metadata).
    self.metadata = load_metadata(model_path)
    self.signal_registry = self.metadata.get('signal_registry', {})
    # Per-circuit outcomes accumulated in test-execution order.
    self.results: List[CircuitResult] = []
    self.category_scores: Dict[str, Tuple[float, int]] = {}
    self.total_tests = 0
    self._setup_tests()
def _setup_tests(self):
    """Pre-compute test vectors on device.

    Builds truth tables, expected gate outputs, sampled 8-bit values,
    comparator input pairs, and the 0-255 modular range as tensors on
    self.device so individual tests avoid per-call transfers.
    """
    d = self.device
    # 2-input truth table [4, 2]
    self.tt2 = torch.tensor(
        [[0, 0], [0, 1], [1, 0], [1, 1]],
        device=d, dtype=torch.float32
    )
    # 3-input truth table [8, 3]
    self.tt3 = torch.tensor([
        [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
        [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]
    ], device=d, dtype=torch.float32)
    # Boolean gate expected outputs, row-aligned with tt2 (tt3 for the
    # full-adder entries fa_sum / fa_cout).
    self.expected = {
        'and': torch.tensor([0, 0, 0, 1], device=d, dtype=torch.float32),
        'or': torch.tensor([0, 1, 1, 1], device=d, dtype=torch.float32),
        'nand': torch.tensor([1, 1, 1, 0], device=d, dtype=torch.float32),
        'nor': torch.tensor([1, 0, 0, 0], device=d, dtype=torch.float32),
        'xor': torch.tensor([0, 1, 1, 0], device=d, dtype=torch.float32),
        'xnor': torch.tensor([1, 0, 0, 1], device=d, dtype=torch.float32),
        'implies': torch.tensor([1, 1, 0, 1], device=d, dtype=torch.float32),
        'biimplies': torch.tensor([1, 0, 0, 1], device=d, dtype=torch.float32),
        'not': torch.tensor([1, 0], device=d, dtype=torch.float32),
        'ha_sum': torch.tensor([0, 1, 1, 0], device=d, dtype=torch.float32),
        'ha_carry': torch.tensor([0, 0, 0, 1], device=d, dtype=torch.float32),
        'fa_sum': torch.tensor([0, 1, 1, 0, 1, 0, 0, 1], device=d, dtype=torch.float32),
        'fa_cout': torch.tensor([0, 0, 0, 1, 0, 1, 1, 1], device=d, dtype=torch.float32),
    }
    # NOT gate inputs
    self.not_inputs = torch.tensor([[0], [1]], device=d, dtype=torch.float32)
    # 8-bit test values: boundaries, powers of two, and alternating patterns (24 values).
    self.test_8bit = torch.tensor([
        0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255,
        0b10101010, 0b01010101, 0b11110000, 0b00001111,
        0b11001100, 0b00110011, 0b10000001, 0b01111110
    ], device=d, dtype=torch.long)
    # Bit representations [num_vals, 8], MSB first (column 0 is bit 7).
    self.test_8bit_bits = torch.stack([
        ((self.test_8bit >> (7 - i)) & 1).float() for i in range(8)
    ], dim=1)
    # Comparator test pairs covering equality, adjacency, and extreme values.
    comp_tests = [
        (0, 0), (1, 0), (0, 1), (5, 3), (3, 5), (5, 5),
        (255, 0), (0, 255), (128, 127), (127, 128),
        (100, 99), (99, 100), (64, 32), (32, 64),
        (1, 1), (254, 255), (255, 254), (128, 128),
        (0, 128), (128, 0), (64, 64), (192, 192),
        (15, 16), (16, 15), (240, 239), (239, 240),
        (85, 170), (170, 85), (0xAA, 0x55), (0x55, 0xAA),
        (0x0F, 0xF0), (0xF0, 0x0F), (0x33, 0xCC), (0xCC, 0x33),
        (2, 3), (3, 2), (126, 127), (127, 126),
        (129, 128), (128, 129), (200, 199), (199, 200),
        (50, 51), (51, 50), (10, 20), (20, 10),
        (100, 100), (200, 200), (77, 77), (0, 0)
    ]
    self.comp_a = torch.tensor([c[0] for c in comp_tests], device=d, dtype=torch.long)
    self.comp_b = torch.tensor([c[1] for c in comp_tests], device=d, dtype=torch.long)
    # Modular test range: every byte value 0-255.
    self.mod_test = torch.arange(256, device=d, dtype=torch.long)
def _record(self, name: str, passed: int, total: int, failures: List[Tuple] = None):
    """Append a CircuitResult for one finished circuit test."""
    # `failures or []` also normalizes a passed-in empty list to a fresh one.
    result = CircuitResult(name, passed, total, failures or [])
    self.results.append(result)
# =========================================================================
# BOOLEAN GATES
# =========================================================================
def _test_single_gate(self, pop: Dict, prefix: str, inputs: torch.Tensor,
                      expected: torch.Tensor) -> torch.Tensor:
    """Test single-layer gate (AND, OR, NAND, NOR, IMPLIES).

    Args:
        pop: Population tensors keyed '<prefix>.weight' / '<prefix>.bias'.
        inputs: [num_tests, n_in] truth-table rows.
        expected: [num_tests] target outputs.
    Returns:
        [pop_size] count of correct rows per population member.
    """
    pop_size = next(iter(pop.values())).shape[0]
    w = pop[f'{prefix}.weight']
    b = pop[f'{prefix}.bias']
    # out: [num_tests, pop_size]
    out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size))
    correct = (out == expected.unsqueeze(1)).float().sum(0)
    failures = []
    if pop_size == 1:
        # Per-case failure detail only in single-model mode.
        # (Previous enumerate() index was unused; iterate rows directly.)
        for inp, exp, got in zip(inputs, expected, out[:, 0]):
            if exp.item() != got.item():
                failures.append((inp.tolist(), exp.item(), got.item()))
    self._record(prefix, int(correct[0].item()), len(expected), failures)
    return correct
def _test_twolayer_gate(self, pop: Dict, prefix: str, inputs: torch.Tensor,
                        expected: torch.Tensor) -> torch.Tensor:
    """Test two-layer gate (XOR, XNOR, BIIMPLIES).

    Layer 1 is two independent neurons on the raw inputs; layer 2
    thresholds a weighted sum of their activations.

    Returns:
        [pop_size] count of correct truth-table rows per population member.
    """
    pop_size = next(iter(pop.values())).shape[0]
    # Layer 1
    w1_n1 = pop[f'{prefix}.layer1.neuron1.weight']
    b1_n1 = pop[f'{prefix}.layer1.neuron1.bias']
    w1_n2 = pop[f'{prefix}.layer1.neuron2.weight']
    b1_n2 = pop[f'{prefix}.layer1.neuron2.bias']
    h1 = heaviside(inputs @ w1_n1.view(pop_size, -1).T + b1_n1.view(pop_size))
    h2 = heaviside(inputs @ w1_n2.view(pop_size, -1).T + b1_n2.view(pop_size))
    hidden = torch.stack([h1, h2], dim=-1)
    # Layer 2
    w2 = pop[f'{prefix}.layer2.weight']
    b2 = pop[f'{prefix}.layer2.bias']
    out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size))
    correct = (out == expected.unsqueeze(1)).float().sum(0)
    failures = []
    if pop_size == 1:
        # Previous enumerate() index was unused; iterate rows directly.
        for inp, exp, got in zip(inputs, expected, out[:, 0]):
            if exp.item() != got.item():
                failures.append((inp.tolist(), exp.item(), got.item()))
    self._record(prefix, int(correct[0].item()), len(expected), failures)
    return correct
def _test_xor_ornand(self, pop: Dict, prefix: str, inputs: torch.Tensor,
                     expected: torch.Tensor) -> torch.Tensor:
    """Test XOR with or/nand layer naming.

    Same topology as _test_twolayer_gate but layer-1 neurons are keyed
    'layer1.or' / 'layer1.nand' instead of 'neuron1' / 'neuron2'.

    Returns:
        [pop_size] count of correct truth-table rows per population member.
    """
    pop_size = next(iter(pop.values())).shape[0]
    w_or = pop[f'{prefix}.layer1.or.weight']
    b_or = pop[f'{prefix}.layer1.or.bias']
    w_nand = pop[f'{prefix}.layer1.nand.weight']
    b_nand = pop[f'{prefix}.layer1.nand.bias']
    h_or = heaviside(inputs @ w_or.view(pop_size, -1).T + b_or.view(pop_size))
    h_nand = heaviside(inputs @ w_nand.view(pop_size, -1).T + b_nand.view(pop_size))
    hidden = torch.stack([h_or, h_nand], dim=-1)
    w2 = pop[f'{prefix}.layer2.weight']
    b2 = pop[f'{prefix}.layer2.bias']
    out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size))
    correct = (out == expected.unsqueeze(1)).float().sum(0)
    failures = []
    if pop_size == 1:
        # Previous enumerate() index was unused; iterate rows directly.
        for inp, exp, got in zip(inputs, expected, out[:, 0]):
            if exp.item() != got.item():
                failures.append((inp.tolist(), exp.item(), got.item()))
    self._record(prefix, int(correct[0].item()), len(expected), failures)
    return correct
def _test_boolean_gates(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test all boolean gates.

    Covers the five single-layer gates, the 1-input NOT, and the
    two-layer XOR / XNOR / BIIMPLIES circuits.

    Returns:
        (scores, total): per-member correct counts and number of tests run.
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0

    def report_last():
        # Shared debug formatting for the most recent circuit result.
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")

    if debug:
        print("\n=== BOOLEAN GATES ===")
    # Single-layer gates
    for gate in ['and', 'or', 'nand', 'nor', 'implies']:
        scores += self._test_single_gate(pop, f'boolean.{gate}', self.tt2, self.expected[gate])
        total += 4
        if debug:
            report_last()
    # NOT gate (single neuron, one input)
    w = pop['boolean.not.weight']
    b = pop['boolean.not.bias']
    out = heaviside(self.not_inputs @ w.view(pop_size, -1).T + b.view(pop_size))
    correct = (out == self.expected['not'].unsqueeze(1)).float().sum(0)
    scores += correct
    total += 2
    failures = []
    if pop_size == 1:
        for inp, exp, got in zip(self.not_inputs, self.expected['not'], out[:, 0]):
            if exp.item() != got.item():
                failures.append((inp.tolist(), exp.item(), got.item()))
    self._record('boolean.not', int(correct[0].item()), 2, failures)
    if debug:
        report_last()
    # Two-layer gates. Both 'xnor' and 'biimplies' exist in self.expected
    # (see _setup_tests), so the previous .get(..., fallback) was dead code.
    for gate in ['xnor', 'biimplies']:
        scores += self._test_twolayer_gate(pop, f'boolean.{gate}', self.tt2, self.expected[gate])
        total += 4
        if debug:
            report_last()
    # XOR uses the same neuron1/neuron2 layout as xnor/biimplies
    scores += self._test_twolayer_gate(pop, 'boolean.xor', self.tt2, self.expected['xor'])
    total += 4
    if debug:
        report_last()
    return scores, total
# =========================================================================
# ARITHMETIC - ADDERS
# =========================================================================
def _eval_xor(self, pop: Dict, prefix: str, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Evaluate an XOR composed of an OR and a NAND feeding a second layer.

    Args:
        a, b: Bit tensors of shape [num_tests] or [num_tests, pop_size].
    Returns:
        [num_tests, pop_size] XOR outputs.
    """
    pop_size = next(iter(pop.values())).shape[0]
    # Broadcast 1-D inputs across the population axis.
    if a.dim() == 1:
        a = a.unsqueeze(1).expand(-1, pop_size)
    if b.dim() == 1:
        b = b.unsqueeze(1).expand(-1, pop_size)
    pair = torch.stack([a, b], dim=-1)  # [num_tests, pop_size, 2]
    # Layer 1: OR and NAND neurons.
    or_w = pop[f'{prefix}.layer1.or.weight'].view(pop_size, 2)
    or_b = pop[f'{prefix}.layer1.or.bias'].view(pop_size)
    nand_w = pop[f'{prefix}.layer1.nand.weight'].view(pop_size, 2)
    nand_b = pop[f'{prefix}.layer1.nand.bias'].view(pop_size)
    act_or = heaviside((pair * or_w).sum(-1) + or_b)      # [num_tests, pop_size]
    act_nand = heaviside((pair * nand_w).sum(-1) + nand_b)
    # Layer 2 combines the two activations.
    merged = torch.stack([act_or, act_nand], dim=-1)      # [num_tests, pop_size, 2]
    out_w = pop[f'{prefix}.layer2.weight'].view(pop_size, 2)
    out_b = pop[f'{prefix}.layer2.bias'].view(pop_size)
    return heaviside((merged * out_w).sum(-1) + out_b)
def _eval_single_fa(self, pop: Dict, prefix: str,
a: torch.Tensor, b: torch.Tensor, cin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Evaluate single full adder.
Args:
a, b, cin: Tensors of shape [num_tests] or [num_tests, pop_size]
Returns:
sum_out, cout: Both of shape [num_tests, pop_size]
"""
pop_size = next(iter(pop.values())).shape[0]
# Ensure inputs are [num_tests, pop_size]
if a.dim() == 1:
a = a.unsqueeze(1).expand(-1, pop_size)
if b.dim() == 1:
b = b.unsqueeze(1).expand(-1, pop_size)
if cin.dim() == 1:
cin = cin.unsqueeze(1).expand(-1, pop_size)
# Half adder 1: a XOR b -> [num_tests, pop_size]
ha1_sum = self._eval_xor(pop, f'{prefix}.ha1.sum', a, b)
# Half adder 1 carry: a AND b
ab = torch.stack([a, b], dim=-1) # [num_tests, pop_size, 2]
w_c1 = pop[f'{prefix}.ha1.carry.weight'].view(pop_size, 2)
b_c1 = pop[f'{prefix}.ha1.carry.bias'].view(pop_size)
ha1_carry = heaviside((ab * w_c1).sum(-1) + b_c1)
# Half adder 2: ha1_sum XOR cin
ha2_sum = self._eval_xor(pop, f'{prefix}.ha2.sum', ha1_sum, cin)
# Half adder 2 carry
sc = torch.stack([ha1_sum, cin], dim=-1)
w_c2 = pop[f'{prefix}.ha2.carry.weight'].view(pop_size, 2)
b_c2 = pop[f'{prefix}.ha2.carry.bias'].view(pop_size)
ha2_carry = heaviside((sc * w_c2).sum(-1) + b_c2)
# Carry out: ha1_carry OR ha2_carry
carries = torch.stack([ha1_carry, ha2_carry], dim=-1)
w_cout = pop[f'{prefix}.carry_or.weight'].view(pop_size, 2)
b_cout = pop[f'{prefix}.carry_or.bias'].view(pop_size)
cout = heaviside((carries * w_cout).sum(-1) + b_cout)
return ha2_sum, cout
def _test_halfadder(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test half adder: sum (XOR via or/nand) then carry (AND)."""
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print("\n=== HALF ADDER ===")
    # (runner, circuit prefix, expected outputs) for the two sub-circuits,
    # executed in the original order: sum first, then carry.
    subtests = [
        (self._test_xor_ornand, 'arithmetic.halfadder.sum', self.expected['ha_sum']),
        (self._test_single_gate, 'arithmetic.halfadder.carry', self.expected['ha_carry']),
    ]
    for runner, circuit, target in subtests:
        scores += runner(pop, circuit, self.tt2, target)
        total += 4
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return scores, total
def _test_fulladder(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test full adder with all 8 input combinations.

    Returns:
        (sum_correct + cout_correct, 16): per-member correct count over
        8 sum checks plus 8 carry-out checks.
    """
    pop_size = next(iter(pop.values())).shape[0]
    if debug:
        print("\n=== FULL ADDER ===")
    # tt3 columns are (a, b, carry-in).
    a = self.tt3[:, 0]
    b = self.tt3[:, 1]
    cin = self.tt3[:, 2]
    sum_out, cout = self._eval_single_fa(pop, 'arithmetic.fulladder', a, b, cin)
    sum_correct = (sum_out == self.expected['fa_sum'].unsqueeze(1)).float().sum(0)
    cout_correct = (cout == self.expected['fa_cout'].unsqueeze(1)).float().sum(0)
    failures_sum = []
    failures_cout = []
    if pop_size == 1:
        # Per-row failure detail only in single-model mode.
        for i in range(8):
            if sum_out[i, 0].item() != self.expected['fa_sum'][i].item():
                failures_sum.append(([a[i].item(), b[i].item(), cin[i].item()],
                                     self.expected['fa_sum'][i].item(), sum_out[i, 0].item()))
            if cout[i, 0].item() != self.expected['fa_cout'][i].item():
                failures_cout.append(([a[i].item(), b[i].item(), cin[i].item()],
                                      self.expected['fa_cout'][i].item(), cout[i, 0].item()))
    self._record('arithmetic.fulladder.sum', int(sum_correct[0].item()), 8, failures_sum)
    self._record('arithmetic.fulladder.cout', int(cout_correct[0].item()), 8, failures_cout)
    if debug:
        for r in self.results[-2:]:
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return sum_correct + cout_correct, 16
def _test_ripplecarry(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test N-bit ripple carry adder.

    Exhaustive over all (a, b) pairs when bits <= 4; otherwise a
    strategic sample of edge values and complementary pairs. Expected
    result is wrap-around addition modulo 2**bits (final carry dropped).

    Args:
        bits: Adder width in bits.
    Returns:
        (correct, num_tests) per population member.
    """
    pop_size = next(iter(pop.values())).shape[0]
    if debug:
        print(f"\n=== RIPPLE CARRY {bits}-BIT ===")
    prefix = f'arithmetic.ripplecarry{bits}bit'
    max_val = 1 << bits
    num_tests = min(max_val * max_val, 65536)  # overwritten below for bits > 4
    if bits <= 4:
        # Exhaustive for small widths
        test_a = torch.arange(max_val, device=self.device)
        test_b = torch.arange(max_val, device=self.device)
        a_vals, b_vals = torch.meshgrid(test_a, test_b, indexing='ij')
        a_vals = a_vals.flatten()
        b_vals = b_vals.flatten()
    else:
        # Strategic sampling for 8-bit
        edge_vals = [0, 1, 2, 127, 128, 254, 255]
        pairs = [(a, b) for a in edge_vals for b in edge_vals]
        for i in range(0, 256, 16):
            pairs.append((i, 255 - i))
        # NOTE(review): set() dedup makes test ordering nondeterministic
        # across runs; pass/fail counts are unaffected.
        pairs = list(set(pairs))
        a_vals = torch.tensor([p[0] for p in pairs], device=self.device)
        b_vals = torch.tensor([p[1] for p in pairs], device=self.device)
        num_tests = len(pairs)
    # Convert to bits [num_tests, bits], MSB first (column 0 = top bit).
    a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
    b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
    # Evaluate ripple carry: thread the carry LSB-to-MSB through fa0..faN-1.
    carry = torch.zeros(len(a_vals), pop_size, device=self.device)
    sum_bits = []
    for bit in range(bits):
        bit_idx = bits - 1 - bit  # LSB first
        s, carry = self._eval_single_fa(
            pop, f'{prefix}.fa{bit}',
            a_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
            b_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
            carry
        )
        sum_bits.append(s)
    # Reconstruct result as an integer value per test row.
    sum_bits = torch.stack(sum_bits[::-1], dim=-1)  # MSB first
    result = torch.zeros(len(a_vals), pop_size, device=self.device)
    for i in range(bits):
        result += sum_bits[:, :, i] * (1 << (bits - 1 - i))
    # Expected: (a + b) mod 2**bits.
    expected = ((a_vals + b_vals) & (max_val - 1)).unsqueeze(1).expand(-1, pop_size).float()
    correct = (result == expected).float().sum(0)
    failures = []
    if pop_size == 1:
        for i in range(min(len(a_vals), 100)):  # cap failure scan at 100 rows
            if result[i, 0].item() != expected[i, 0].item():
                failures.append((
                    [int(a_vals[i].item()), int(b_vals[i].item())],
                    int(expected[i, 0].item()),
                    int(result[i, 0].item())
                ))
    self._record(prefix, int(correct[0].item()), num_tests, failures)
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return correct, num_tests
# =========================================================================
# COMPARATORS
# =========================================================================
def _test_comparator(self, pop: Dict, name: str, op: Callable[[int, int], bool],
                     debug: bool) -> Tuple[torch.Tensor, int]:
    """Test 8-bit comparator.

    Args:
        name: Circuit name under the 'arithmetic.' namespace.
        op: Python reference predicate that defines the expected output.
    Returns:
        (correct, num_tests) per population member.
    """
    pop_size = next(iter(pop.values())).shape[0]
    prefix = f'arithmetic.{name}'
    # Use pre-computed test pairs; expectation comes from the reference op.
    expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0
                             for a, b in zip(self.comp_a, self.comp_b)],
                            device=self.device)
    # Convert to bits (MSB first), then concatenate a-bits and b-bits.
    a_bits = torch.stack([((self.comp_a >> (7 - i)) & 1).float() for i in range(8)], dim=1)
    b_bits = torch.stack([((self.comp_b >> (7 - i)) & 1).float() for i in range(8)], dim=1)
    inputs = torch.cat([a_bits, b_bits], dim=1)
    w = pop[f'{prefix}.weight']
    b = pop[f'{prefix}.bias']
    out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size))
    correct = (out == expected.unsqueeze(1)).float().sum(0)
    failures = []
    if pop_size == 1:
        for i in range(len(self.comp_a)):
            if out[i, 0].item() != expected[i].item():
                failures.append((
                    [int(self.comp_a[i].item()), int(self.comp_b[i].item())],
                    expected[i].item(),
                    out[i, 0].item()
                ))
    self._record(prefix, int(correct[0].item()), len(self.comp_a), failures)
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return correct, len(self.comp_a)
def _test_comparators(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test all comparators.

    Single-layer comparators (>, <, >=, <=) run through _test_comparator;
    equality is a two-layer circuit (geq AND leq) evaluated inline below.
    Circuits missing from the model are skipped silently.
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print("\n=== COMPARATORS ===")
    comparators = [
        ('greaterthan8bit', lambda a, b: a > b),
        ('lessthan8bit', lambda a, b: a < b),
        ('greaterorequal8bit', lambda a, b: a >= b),
        ('lessorequal8bit', lambda a, b: a <= b),
        ('equality8bit', lambda a, b: a == b),
    ]
    for name, op in comparators:
        if name == 'equality8bit':
            continue  # Handle separately as two-layer
        try:
            s, t = self._test_comparator(pop, name, op, debug)
            scores += s
            total += t
        except KeyError:
            pass  # Circuit not present
    # Two-layer equality circuit
    try:
        prefix = 'arithmetic.equality8bit'
        expected = torch.tensor([1.0 if a.item() == b.item() else 0.0
                                 for a, b in zip(self.comp_a, self.comp_b)],
                                device=self.device)
        a_bits = torch.stack([((self.comp_a >> (7 - i)) & 1).float() for i in range(8)], dim=1)
        b_bits = torch.stack([((self.comp_b >> (7 - i)) & 1).float() for i in range(8)], dim=1)
        inputs = torch.cat([a_bits, b_bits], dim=1)
        # Layer 1: geq and leq neurons over the concatenated bit inputs.
        w_geq = pop[f'{prefix}.layer1.geq.weight']
        b_geq = pop[f'{prefix}.layer1.geq.bias']
        w_leq = pop[f'{prefix}.layer1.leq.weight']
        b_leq = pop[f'{prefix}.layer1.leq.bias']
        h_geq = heaviside(inputs @ w_geq.view(pop_size, -1).T + b_geq.view(pop_size))
        h_leq = heaviside(inputs @ w_leq.view(pop_size, -1).T + b_leq.view(pop_size))
        hidden = torch.stack([h_geq, h_leq], dim=-1)  # [num_tests, pop_size, 2]
        # Layer 2: AND — (a >= b) and (a <= b) together imply a == b.
        w2 = pop[f'{prefix}.layer2.weight']
        b2 = pop[f'{prefix}.layer2.bias']
        out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size))
        correct = (out == expected.unsqueeze(1)).float().sum(0)
        failures = []
        if pop_size == 1:
            for i in range(len(self.comp_a)):
                if out[i, 0].item() != expected[i].item():
                    failures.append((
                        [int(self.comp_a[i].item()), int(self.comp_b[i].item())],
                        expected[i].item(),
                        out[i, 0].item()
                    ))
        self._record(prefix, int(correct[0].item()), len(self.comp_a), failures)
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        scores += correct
        total += len(self.comp_a)
    except KeyError:
        pass
    return scores, total
# =========================================================================
# THRESHOLD GATES
# =========================================================================
def _test_threshold_kofn(self, pop: Dict, k: int, name: str, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test k-of-n threshold gate.

    Expected output derives from the popcount of each input pattern; the
    comparison direction (>=, <=, ==) is chosen from the gate's name.

    NOTE(review): self.test_8bit_bits always has 24 rows (see
    _setup_tests), so the first branch is always taken and only those 24
    sampled patterns are tested — the 256-pattern fallback is
    effectively dead code. Confirm whether full coverage was intended.
    """
    pop_size = next(iter(pop.values())).shape[0]
    prefix = f'threshold.{name}'
    inputs = self.test_8bit_bits if len(self.test_8bit_bits) == 24 else None
    if inputs is None:
        # Fallback: all 256 byte patterns (currently unreachable, see NOTE).
        test_vals = torch.arange(256, device=self.device, dtype=torch.long)
        inputs = torch.stack([((test_vals >> (7 - i)) & 1).float() for i in range(8)], dim=1)
    popcounts = inputs.sum(dim=1)
    # Names without a qualifier (e.g. 'oneoutof8', 'majority') fall through
    # to the at-least-k branch below.
    if 'atleast' in name:
        expected = (popcounts >= k).float()
    elif 'atmost' in name or 'minority' in name:
        # minority = popcount <= 3 (less than half of 8)
        expected = (popcounts <= k).float()
    elif 'exactly' in name:
        expected = (popcounts == k).float()
    else:
        # Standard k-of-n (at least k), including majority (>= 5)
        expected = (popcounts >= k).float()
    w = pop[f'{prefix}.weight']
    b = pop[f'{prefix}.bias']
    out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size))
    correct = (out == expected.unsqueeze(1)).float().sum(0)
    failures = []
    if pop_size == 1:
        for i in range(min(len(inputs), 256)):
            if out[i, 0].item() != expected[i].item():
                # Rebuild the integer value from its bit row for reporting.
                val = int(sum(inputs[i, j].item() * (1 << (7 - j)) for j in range(8)))
                failures.append((val, expected[i].item(), out[i, 0].item()))
    self._record(prefix, int(correct[0].item()), len(inputs), failures[:10])
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return correct, len(inputs)
def _test_threshold_gates(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Run every threshold circuit: plain k-of-8 gates, then the specials."""
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print("\n=== THRESHOLD GATES ===")
    # (k, circuit name) in execution order: the eight k-of-8 gates first,
    # then the named special gates. Missing circuits are skipped.
    gate_specs = [
        (1, 'oneoutof8'), (2, 'twooutof8'), (3, 'threeoutof8'), (4, 'fouroutof8'),
        (5, 'fiveoutof8'), (6, 'sixoutof8'), (7, 'sevenoutof8'), (8, 'alloutof8'),
        (5, 'majority'), (3, 'minority'),
        (4, 'atleastk_4'), (4, 'atmostk_4'), (4, 'exactlyk_4'),
    ]
    for k, name in gate_specs:
        try:
            gate_scores, gate_total = self._test_threshold_kofn(pop, k, name, debug)
        except KeyError:
            continue  # Circuit absent from this model.
        scores += gate_scores
        total += gate_total
    return scores, total
# =========================================================================
# MODULAR ARITHMETIC
# =========================================================================
def _test_modular(self, pop: Dict, mod: int, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test modular divisibility circuit (multi-layer for non-powers-of-2).

    Output should be 1 exactly when the input byte is divisible by `mod`.
    Tries a single-neuron circuit first; falls back to the three-layer
    geq/leq -> eq -> OR structure. Returns (zeros, 0) when the circuit is
    absent or malformed.
    """
    pop_size = next(iter(pop.values())).shape[0]
    prefix = f'modular.mod{mod}'
    # Test 0-255
    inputs = torch.stack([((self.mod_test >> (7 - i)) & 1).float() for i in range(8)], dim=1)
    expected = ((self.mod_test % mod) == 0).float()
    # Try single layer first (powers of 2)
    try:
        w = pop[f'{prefix}.weight']
        b = pop[f'{prefix}.bias']
        out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size))
    except KeyError:
        # Multi-layer structure: layer1 (geq/leq) -> layer2 (eq) -> layer3 (or)
        try:
            # Layer 1: geq and leq neurons, discovered by probing indices
            # until neither key exists.
            geq_outputs = {}
            leq_outputs = {}
            i = 0
            while True:
                found = False
                if f'{prefix}.layer1.geq{i}.weight' in pop:
                    w = pop[f'{prefix}.layer1.geq{i}.weight'].view(pop_size, -1)
                    b = pop[f'{prefix}.layer1.geq{i}.bias'].view(pop_size)
                    geq_outputs[i] = heaviside(inputs @ w.T + b)  # [256, pop_size]
                    found = True
                if f'{prefix}.layer1.leq{i}.weight' in pop:
                    w = pop[f'{prefix}.layer1.leq{i}.weight'].view(pop_size, -1)
                    b = pop[f'{prefix}.layer1.leq{i}.bias'].view(pop_size)
                    leq_outputs[i] = heaviside(inputs @ w.T + b)
                    found = True
                if not found:
                    break
                i += 1
            if not geq_outputs and not leq_outputs:
                return torch.zeros(pop_size, device=self.device), 0
            # Layer 2: eq neurons (AND of geq and leq for same index)
            eq_outputs = []
            i = 0
            while f'{prefix}.layer2.eq{i}.weight' in pop:
                w = pop[f'{prefix}.layer2.eq{i}.weight'].view(pop_size, -1)
                b = pop[f'{prefix}.layer2.eq{i}.bias'].view(pop_size)
                # Input is [geq_i, leq_i]; a missing half defaults to zeros.
                eq_in = torch.stack([geq_outputs.get(i, torch.zeros(256, pop_size, device=self.device)),
                                     leq_outputs.get(i, torch.zeros(256, pop_size, device=self.device))], dim=-1)
                eq_out = heaviside((eq_in * w).sum(-1) + b)
                eq_outputs.append(eq_out)
                i += 1
            if not eq_outputs:
                return torch.zeros(pop_size, device=self.device), 0
            # Layer 3: OR of all eq outputs
            eq_stack = torch.stack(eq_outputs, dim=-1)  # [256, pop_size, num_eq]
            w3 = pop[f'{prefix}.layer3.or.weight'].view(pop_size, -1)
            b3 = pop[f'{prefix}.layer3.or.bias'].view(pop_size)
            out = heaviside((eq_stack * w3).sum(-1) + b3)  # [256, pop_size]
        except Exception as e:
            # NOTE(review): deliberate best-effort catch — a malformed
            # circuit scores zero instead of aborting the run. The bound
            # name `e` is unused.
            return torch.zeros(pop_size, device=self.device), 0
    correct = (out == expected.unsqueeze(1)).float().sum(0)
    failures = []
    if pop_size == 1:
        for i in range(256):
            if out[i, 0].item() != expected[i].item():
                failures.append((i, expected[i].item(), out[i, 0].item()))
    self._record(prefix, int(correct[0].item()), 256, failures[:10])
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return correct, 256
def _test_modular_all(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Run the mod-2 through mod-12 divisibility circuits and sum scores."""
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print("\n=== MODULAR ARITHMETIC ===")
    for modulus in range(2, 13):
        mod_scores, mod_total = self._test_modular(pop, modulus, debug)
        scores += mod_scores
        total += mod_total
    return scores, total
# =========================================================================
# PATTERN RECOGNITION
# =========================================================================
def _test_pattern(self, pop: Dict, name: str, expected_fn: Callable[[int], float],
                  debug: bool) -> Tuple[torch.Tensor, int]:
    """Exercise one single-neuron pattern detector over all 256 byte values.

    Args:
        expected_fn: Maps a byte value (0-255) to the expected 0.0/1.0 output.
    Returns:
        (correct, 256) per population member, or (zeros, 0) if absent.
    """
    pop_size = next(iter(pop.values())).shape[0]
    prefix = f'pattern_recognition.{name}'
    values = torch.arange(256, device=self.device, dtype=torch.long)
    inputs = torch.stack([((values >> (7 - i)) & 1).float() for i in range(8)], dim=1)
    expected = torch.tensor([expected_fn(v.item()) for v in values], device=self.device)
    try:
        w = pop[f'{prefix}.weight'].view(pop_size, -1)
        b = pop[f'{prefix}.bias'].view(pop_size)
    except KeyError:
        # Circuit not present in this model.
        return torch.zeros(pop_size, device=self.device), 0
    out = heaviside(inputs @ w.T + b)
    correct = (out == expected.unsqueeze(1)).float().sum(0)
    failures = []
    if pop_size == 1:
        for val in range(256):
            if out[val, 0].item() != expected[val].item():
                failures.append((val, expected[val].item(), out[val, 0].item()))
    self._record(prefix, int(correct[0].item()), 256, failures[:10])
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return correct, 256
def _test_patterns(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Run the all-zeros and all-ones byte detectors."""
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print("\n=== PATTERN RECOGNITION ===")
    # Circuit keys follow the model naming: pattern_recognition.<name>
    detectors = (
        ('allzeros', lambda v: 1.0 if v == 0 else 0.0),
        ('allones', lambda v: 1.0 if v == 255 else 0.0),
    )
    for det_name, det_fn in detectors:
        det_scores, det_total = self._test_pattern(pop, det_name, det_fn, debug)
        scores += det_scores
        total += det_total
    return scores, total
# =========================================================================
# ERROR DETECTION
# =========================================================================
def _eval_xor_tree_stage(self, pop: Dict, prefix: str, stage: int, idx: int,
                         a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Evaluate one XOR node of the parity tree (or/nand then combine)."""
    pop_size = next(iter(pop.values())).shape[0]
    xor_prefix = f'{prefix}.stage{stage}.xor{idx}'
    # Promote 1-D inputs to [num_tests, pop_size].
    if a.dim() == 1:
        a = a.unsqueeze(1).expand(-1, pop_size)
    if b.dim() == 1:
        b = b.unsqueeze(1).expand(-1, pop_size)
    pair = torch.stack([a, b], dim=-1)  # [num_tests, pop_size, 2]
    # Layer 1: OR and NAND of the two inputs.
    or_w = pop[f'{xor_prefix}.layer1.or.weight'].view(pop_size, 2)
    or_b = pop[f'{xor_prefix}.layer1.or.bias'].view(pop_size)
    nand_w = pop[f'{xor_prefix}.layer1.nand.weight'].view(pop_size, 2)
    nand_b = pop[f'{xor_prefix}.layer1.nand.bias'].view(pop_size)
    act_or = heaviside((pair * or_w).sum(-1) + or_b)
    act_nand = heaviside((pair * nand_w).sum(-1) + nand_b)
    # Layer 2 combines the two activations into the XOR output.
    merged = torch.stack([act_or, act_nand], dim=-1)
    out_w = pop[f'{xor_prefix}.layer2.weight'].view(pop_size, 2)
    out_b = pop[f'{xor_prefix}.layer2.bias'].view(pop_size)
    return heaviside((merged * out_w).sum(-1) + out_b)
def _test_parity_xor_tree(self, pop: Dict, prefix: str, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test parity circuit with XOR tree structure.

    Folds the 8 input bits through a 4-2-1 XOR tree over all 256 byte
    values. If an output NOT stage exists (parity checker), expected
    output is the inverted XOR; otherwise the raw XOR (generator).
    Returns (zeros, 0) when the tree's weights are absent.
    """
    pop_size = next(iter(pop.values())).shape[0]
    test_vals = torch.arange(256, device=self.device, dtype=torch.long)
    inputs = torch.stack([((test_vals >> (7 - i)) & 1).float() for i in range(8)], dim=1)
    # XOR of all bits: 1 if odd number of 1s
    popcounts = inputs.sum(dim=1)
    xor_result = (popcounts.long() % 2).float()
    try:
        # Stage 1: 4 XORs (pairs of bits)
        s1_out = []
        for i in range(4):
            xor_out = self._eval_xor_tree_stage(pop, prefix, 1, i, inputs[:, i*2], inputs[:, i*2+1])
            s1_out.append(xor_out)
        # Stage 2: 2 XORs
        s2_out = []
        for i in range(2):
            xor_out = self._eval_xor_tree_stage(pop, prefix, 2, i, s1_out[i*2], s1_out[i*2+1])
            s2_out.append(xor_out)
        # Stage 3: 1 XOR
        s3_out = self._eval_xor_tree_stage(pop, prefix, 3, 0, s2_out[0], s2_out[1])
        # Output NOT (for parity checker - inverts the XOR result)
        if f'{prefix}.output.not.weight' in pop:
            w_not = pop[f'{prefix}.output.not.weight'].view(pop_size)
            b_not = pop[f'{prefix}.output.not.bias'].view(pop_size)
            out = heaviside(s3_out * w_not + b_not)
            # Checker outputs 1 if even parity (XOR=0), so expected is inverted xor_result
            expected = 1.0 - xor_result
        else:
            out = s3_out
            expected = xor_result
    except KeyError:
        # Tree weights missing: circuit absent, contributes nothing.
        # (Previously bound the exception as `e` without using it.)
        return torch.zeros(pop_size, device=self.device), 0
    correct = (out == expected.unsqueeze(1)).float().sum(0)
    failures = []
    if pop_size == 1:
        for i in range(256):
            if out[i, 0].item() != expected[i].item():
                failures.append((i, expected[i].item(), out[i, 0].item()))
    self._record(prefix, int(correct[0].item()), 256, failures[:10])
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return correct, 256
def _test_error_detection(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Run the error-detection suite: both 8-bit parity XOR-tree circuits."""
    n_individuals = next(iter(pop.values())).shape[0]
    running = torch.zeros(n_individuals, device=self.device)
    case_count = 0
    if debug:
        print("\n=== ERROR DETECTION ===")
    # Parity checker and generator share the XOR-tree test harness.
    circuit_names = (
        'error_detection.paritychecker8bit',
        'error_detection.paritygenerator8bit',
    )
    for circuit in circuit_names:
        circuit_scores, circuit_cases = self._test_parity_xor_tree(pop, circuit, debug)
        running = running + circuit_scores
        case_count += circuit_cases
    return running, case_count
# =========================================================================
# COMBINATIONAL LOGIC
# =========================================================================
def _test_mux2to1(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Exercise the 2-to-1 multiplexer over its full 8-row truth table."""
    n = next(iter(pop.values())).shape[0]
    prefix = 'combinational.multiplexer2to1'
    # Rows are [a, b, sel]; the mux outputs b when sel=1, otherwise a.
    rows = [[a, b, s] for a in (0, 1) for b in (0, 1) for s in (0, 1)]
    inputs = torch.tensor(rows, device=self.device, dtype=torch.float32)
    expected = torch.tensor([b if s else a for a, b, s in rows],
                            device=self.device, dtype=torch.float32)
    try:
        w = pop[f'{prefix}.weight']
        b = pop[f'{prefix}.bias']
    except KeyError:
        return torch.zeros(n, device=self.device), 0
    # [8, n]: one row per truth-table entry, one column per individual.
    out = heaviside(inputs @ w.view(n, -1).T + b.view(n))
    hits = (out == expected.unsqueeze(1)).float().sum(0)
    mistakes = []
    if n == 1:
        for idx in range(8):
            if out[idx, 0].item() != expected[idx].item():
                mistakes.append((inputs[idx].tolist(), expected[idx].item(), out[idx, 0].item()))
    self._record(prefix, int(hits[0].item()), 8, mistakes)
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return hits, 8
def _test_decoder3to8(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Check each one-hot output line of the 3-to-8 decoder against all 8 codes."""
    n = next(iter(pop.values())).shape[0]
    running = torch.zeros(n, device=self.device)
    case_count = 0
    if debug:
        print("\n=== DECODER 3-TO-8 ===")
    # All 8 input codes, MSB first: v=5 -> [1, 0, 1].
    codes = torch.tensor([[(v >> 2) & 1, (v >> 1) & 1, v & 1] for v in range(8)],
                         device=self.device, dtype=torch.float32)
    for line in range(8):
        prefix = f'combinational.decoder3to8.out{line}'
        try:
            w = pop[f'{prefix}.weight']
            b = pop[f'{prefix}.bias']
        except KeyError:
            continue
        # Output line `line` must fire only for input code `line`.
        one_hot = torch.zeros(8, device=self.device)
        one_hot[line] = 1.0
        out = heaviside(codes @ w.view(n, -1).T + b.view(n))  # [8, n]
        hits = (out == one_hot.unsqueeze(1)).float().sum(0)
        running += hits
        case_count += 8
        mistakes = []
        if n == 1:
            for row in range(8):
                if out[row, 0].item() != one_hot[row].item():
                    mistakes.append((codes[row].tolist(), one_hot[row].item(), out[row, 0].item()))
        self._record(prefix, int(hits[0].item()), 8, mistakes)
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return running, case_count
def _test_combinational(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Aggregate the combinational-logic sub-tests (mux, then decoder)."""
    n_individuals = next(iter(pop.values())).shape[0]
    if debug:
        print("\n=== COMBINATIONAL LOGIC ===")
    running = torch.zeros(n_individuals, device=self.device)
    case_count = 0
    for sub_test in (self._test_mux2to1, self._test_decoder3to8):
        sub_scores, sub_cases = sub_test(pop, debug)
        running = running + sub_scores
        case_count += sub_cases
    return running, case_count
# =========================================================================
# CONTROL FLOW
# =========================================================================
def _test_conditional_jump(self, pop: Dict, name: str, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test one conditional-jump circuit: out = flag ? target : pc, per bit.

    Each of the 8 PC bits is an independent 2-to-1 mux built from four
    gates: NOT(flag), AND(pc, NOT flag), AND(target, flag), and an OR of
    the two AND legs.  Every bit is exercised over the full 3-input
    truth table.

    Fixes vs. the previous version:
      * The pc/target inputs are now expanded to [8, pop_size] and stacked,
        instead of `torch.cat`-ing a [8, 1] column onto a [8, pop_size]
        activation — the old shape only broadcast correctly for
        pop_size == 1 and raised a RuntimeError for larger populations.
        For pop_size == 1 the computed values are unchanged.
      * The recorded pass count is `int(scores[0].item())` directly; the
        old `scores[0] / total * total` was a redundant float round-trip.

    Args:
        pop: Parameter tensors, each shaped [pop_size, ...].
        name: Jump circuit name under the 'control.' namespace (e.g. 'jz').
        debug: If True, print a per-circuit PASS/FAIL line.

    Returns:
        (scores [pop_size], number of test cases run).
    """
    pop_size = next(iter(pop.values())).shape[0]
    prefix = f'control.{name}'
    # Test cases: [pc_bit, target_bit, flag] -> out = flag ? target : pc
    inputs = torch.tensor([
        [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
        [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1],
    ], device=self.device, dtype=torch.float32)
    expected = torch.tensor([0, 0, 0, 1, 1, 0, 1, 1], device=self.device, dtype=torch.float32)
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    # Shared stimulus columns, broadcast across the population: [8, pop].
    pc = inputs[:, 0:1].expand(8, pop_size)
    target = inputs[:, 1:2].expand(8, pop_size)
    flag = inputs[:, 2:3]                       # [8, 1]
    sel = flag.expand(8, pop_size)              # [8, pop]
    for bit in range(8):
        bit_prefix = f'{prefix}.bit{bit}'
        try:
            # NOT sel: [8, pop]
            w_not = pop[f'{bit_prefix}.not_sel.weight']
            b_not = pop[f'{bit_prefix}.not_sel.bias']
            not_sel = heaviside(flag @ w_not.view(pop_size, -1).T + b_not.view(pop_size))
            # AND a: pc AND NOT sel -> [8, pop]
            w_and_a = pop[f'{bit_prefix}.and_a.weight'].view(pop_size, 2)
            b_and_a = pop[f'{bit_prefix}.and_a.bias'].view(pop_size)
            and_a = heaviside(
                (torch.stack([pc, not_sel], dim=-1) * w_and_a).sum(-1) + b_and_a)
            # AND b: target AND sel -> [8, pop]
            w_and_b = pop[f'{bit_prefix}.and_b.weight'].view(pop_size, 2)
            b_and_b = pop[f'{bit_prefix}.and_b.bias'].view(pop_size)
            and_b = heaviside(
                (torch.stack([target, sel], dim=-1) * w_and_b).sum(-1) + b_and_b)
            # OR of the two AND legs -> [8, pop]
            w_or = pop[f'{bit_prefix}.or.weight']
            b_or = pop[f'{bit_prefix}.or.bias']
            ab = torch.stack([and_a, and_b], dim=-1)  # [8, pop, 2]
            out = heaviside((ab * w_or.view(pop_size, 2)).sum(-1) + b_or.view(pop_size))
            scores += (out == expected.unsqueeze(1)).float().sum(0)  # [pop]
            total += 8
        except KeyError:
            # Missing gate tensors: this bit contributes no test cases.
            pass
    if total > 0:
        self._record(prefix, int(scores[0].item()), total, [])
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return scores, total
def _test_control_flow(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Aggregate all conditional-jump circuit tests."""
    n_individuals = next(iter(pop.values())).shape[0]
    if debug:
        print("\n=== CONTROL FLOW ===")
    running = torch.zeros(n_individuals, device=self.device)
    case_count = 0
    # Every jump variant shares the same per-bit mux structure.
    jump_names = ('jz', 'jnz', 'jc', 'jnc', 'jn', 'jp', 'jv', 'jnv', 'conditionaljump')
    for jump in jump_names:
        jump_scores, jump_cases = self._test_conditional_jump(pop, jump, debug)
        running = running + jump_scores
        case_count += jump_cases
    return running, case_count
# =========================================================================
# ALU
# =========================================================================
def _test_alu_ops(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test ALU operations: bitwise AND/OR/NOT, SHL/SHR, MUL partial
    products, and DIV comparison gates.

    Bitwise ops are 8 parallel 2-input (1-input for NOT) threshold gates
    with weights viewed as [pop, 8, 2] (or [pop, 8]).  Shifts are 8
    independent 1-input gates wired to the neighbouring source bit.  MUL
    is checked only at its 64 partial-product AND gates; DIV only at its
    per-stage remainder >= divisor comparison gates.

    Fix vs. the previous version: the AND section now accumulates into
    its own op_scores/op_total like every other op, instead of recording
    the method-wide cumulative totals (which was only correct because
    AND happened to run first).  Recorded values are unchanged.

    Args:
        pop: Parameter tensors, each shaped [pop_size, ...].
        debug: If True, print per-op PASS/FAIL (and SKIP for missing
            SHL/SHR/MUL/DIV tensors).

    Returns:
        (scores [pop_size], total number of test cases run).
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print("\n=== ALU OPERATIONS ===")
    # Operand pairs covering all-zeros, all-ones, and alternating patterns.
    test_vals = [(0, 0), (255, 255), (0xAA, 0x55), (0x0F, 0xF0)]

    def bits_of(val: int) -> torch.Tensor:
        # MSB-first 8-bit expansion of val as a float tensor.
        return torch.tensor([(val >> (7 - i)) & 1 for i in range(8)],
                            device=self.device, dtype=torch.float32)

    def report(name: str, op_scores: torch.Tensor, op_total: int) -> None:
        # Record individual 0's result and optionally print it.
        self._record(name, int(op_scores[0].item()), op_total, [])
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")

    # Bitwise binary ops: AND then OR (order matters for report sequence).
    for op_name, combine in (('and', lambda x, y: x & y), ('or', lambda x, y: x | y)):
        try:
            w = pop[f'alu.alu8bit.{op_name}.weight'].view(pop_size, 8, 2)  # [pop, 8, 2]
            b = pop[f'alu.alu8bit.{op_name}.bias'].view(pop_size, 8)       # [pop, 8]
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0
            for a_val, b_val in test_vals:
                inputs = torch.stack([bits_of(a_val), bits_of(b_val)], dim=-1)  # [8, 2]
                out = heaviside((inputs * w).sum(-1) + b)                       # [pop, 8]
                expected = bits_of(combine(a_val, b_val))
                op_scores += (out == expected.unsqueeze(0)).float().sum(1)      # [pop]
                op_total += 8
            scores += op_scores
            total += op_total
            report(f'alu.alu8bit.{op_name}', op_scores, op_total)
        except (KeyError, RuntimeError):
            pass
    # NOT (unary)
    try:
        w = pop['alu.alu8bit.not.weight'].view(pop_size, 8)
        b = pop['alu.alu8bit.not.bias'].view(pop_size, 8)
        op_scores = torch.zeros(pop_size, device=self.device)
        op_total = 0
        for a_val, _ in test_vals:
            out = heaviside(bits_of(a_val) * w + b)
            expected = bits_of((~a_val) & 0xFF)
            op_scores += (out == expected.unsqueeze(0)).float().sum(1)
            op_total += 8
        scores += op_scores
        total += op_total
        report('alu.alu8bit.not', op_scores, op_total)
    except (KeyError, RuntimeError):
        pass
    # Shifts: output bit k copies source bit (k + delta); out-of-range
    # sources are driven with a constant 0 (shift-in of zero).
    for op_name, delta in (('shl', 1), ('shr', -1)):
        try:
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0
            for a_val, _ in test_vals:
                shifted = (a_val << 1) & 0xFF if op_name == 'shl' else (a_val >> 1) & 0xFF
                a_bits = bits_of(a_val)
                out_bits = []
                for bit in range(8):
                    w = pop[f'alu.alu8bit.{op_name}.bit{bit}.weight'].view(pop_size)
                    b = pop[f'alu.alu8bit.{op_name}.bit{bit}.bias'].view(pop_size)
                    src = bit + delta
                    if 0 <= src <= 7:
                        inp = a_bits[src].unsqueeze(0).expand(pop_size)
                    else:
                        inp = torch.zeros(pop_size, device=self.device)
                    out_bits.append(heaviside(inp * w + b))
                out = torch.stack(out_bits, dim=-1)  # [pop, 8]
                expected = bits_of(shifted)
                op_scores += (out == expected.unsqueeze(0)).float().sum(1)
                op_total += 8
            scores += op_scores
            total += op_total
            report(f'alu.alu8bit.{op_name}', op_scores, op_total)
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f" alu.alu8bit.{op_name}: SKIP ({e})")
    # MUL (partial products only - just verify AND gates work)
    try:
        op_scores = torch.zeros(pop_size, device=self.device)
        op_total = 0
        mul_tests = [(3, 4), (7, 8), (15, 17), (0, 255)]
        for a_val, b_val in mul_tests:
            a_bits = bits_of(a_val)
            b_bits = bits_of(b_val)
            # Test partial product AND gates: pp[i][j] = a_bits[i] & b_bits[j]
            for i in range(8):
                for j in range(8):
                    w = pop[f'alu.alu8bit.mul.pp.a{i}b{j}.weight'].view(pop_size, 2)
                    b = pop[f'alu.alu8bit.mul.pp.a{i}b{j}.bias'].view(pop_size)
                    inp = torch.tensor([a_bits[i].item(), b_bits[j].item()], device=self.device)
                    out = heaviside((inp * w).sum(-1) + b)
                    expected = float(int(a_bits[i].item()) & int(b_bits[j].item()))
                    op_scores += (out == expected).float()
                    op_total += 1
        scores += op_scores
        total += op_total
        report('alu.alu8bit.mul', op_scores, op_total)
    except (KeyError, RuntimeError) as e:
        if debug:
            print(f" alu.alu8bit.mul: SKIP ({e})")
    # DIV (comparison gates only)
    try:
        op_scores = torch.zeros(pop_size, device=self.device)
        op_total = 0
        div_tests = [(100, 10), (255, 17), (50, 7), (128, 16)]
        for a_val, b_val in div_tests:
            div_bits = bits_of(b_val)
            for stage in range(8):
                w = pop[f'alu.alu8bit.div.stage{stage}.cmp.weight'].view(pop_size, 16)
                b = pop[f'alu.alu8bit.div.stage{stage}.cmp.bias'].view(pop_size)
                # Simplified stimulus: a shifted slice of the dividend
                # stands in for the running remainder at this stage.
                test_rem = (a_val >> (7 - stage)) & 0xFF
                rem_bits = bits_of(test_rem)
                inp = torch.cat([rem_bits, div_bits])  # [16]
                out = heaviside((inp * w).sum(-1) + b)
                expected = float(test_rem >= b_val)
                op_scores += (out == expected).float()
                op_total += 1
        scores += op_scores
        total += op_total
        report('alu.alu8bit.div', op_scores, op_total)
    except (KeyError, RuntimeError) as e:
        if debug:
            print(f" alu.alu8bit.div: SKIP ({e})")
    return scores, total
# =========================================================================
# MANIFEST
# =========================================================================
def _test_manifest(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Verify manifest constants against their expected values.

    Fix vs. the previous version: scoring is now per-individual.  The
    old code inspected only individual 0 (`pop[name][0, 0]`) and then
    credited the *entire* population with `scores += 1`, which is wrong
    once individuals diverge (e.g. after mutation).  Per-circuit
    recording still reflects individual 0, matching the other tests.

    Args:
        pop: Parameter tensors; each manifest entry is shaped [pop_size, 1].
        debug: If True, print a per-entry PASS/FAIL line.

    Returns:
        (scores [pop_size], number of manifest entries checked).
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print("\n=== MANIFEST ===")
    expected = {
        'manifest.alu_operations': 16.0,
        'manifest.flags': 4.0,
        'manifest.instruction_width': 16.0,
        'manifest.memory_bytes': 65536.0,
        'manifest.pc_width': 16.0,
        'manifest.register_width': 8.0,
        'manifest.registers': 4.0,
        'manifest.turing_complete': 1.0,
        'manifest.version': 3.0,
    }
    for name, exp_val in expected.items():
        try:
            vals = pop[name].view(pop_size, -1)[:, 0]  # one scalar per individual
        except KeyError:
            continue
        hit = (vals == exp_val).float()  # [pop_size]
        scores += hit
        total += 1
        if hit[0].item() == 1.0:
            self._record(name, 1, 1, [])
        else:
            self._record(name, 0, 1, [(exp_val, vals[0].item())])
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return scores, total
# =========================================================================
# MEMORY
# =========================================================================
def _test_memory(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Validate memory-circuit tensor shapes (too large to test exhaustively)."""
    n = next(iter(pop.values())).shape[0]
    running = torch.zeros(n, device=self.device)
    checked = 0
    if debug:
        print("\n=== MEMORY ===")
    expected_shapes = {
        'memory.addr_decode.weight': (65536, 16),
        'memory.addr_decode.bias': (65536,),
        'memory.read.and.weight': (8, 65536, 2),
        'memory.read.and.bias': (8, 65536),
        'memory.read.or.weight': (8, 65536),
        'memory.read.or.bias': (8,),
        'memory.write.sel.weight': (65536, 2),
        'memory.write.sel.bias': (65536,),
        'memory.write.nsel.weight': (65536, 1),
        'memory.write.nsel.bias': (65536,),
        'memory.write.and_old.weight': (65536, 8, 2),
        'memory.write.and_old.bias': (65536, 8),
        'memory.write.and_new.weight': (65536, 8, 2),
        'memory.write.and_new.bias': (65536, 8),
        'memory.write.or.weight': (65536, 8, 2),
        'memory.write.or.bias': (65536, 8),
    }
    for tensor_name, want in expected_shapes.items():
        if tensor_name not in pop:
            continue
        # Shapes are shared across the population, so drop the leading
        # pop dimension and credit all individuals together.
        got = tuple(pop[tensor_name].shape[1:])
        ok = got == want
        if ok:
            running += 1
        self._record(tensor_name, 1 if ok else 0, 1, [] if ok else [(want, got)])
        checked += 1
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return running, checked
# =========================================================================
# MAIN EVALUATE
# =========================================================================
def evaluate(self, population: Dict[str, torch.Tensor], debug: bool = False) -> torch.Tensor:
    """
    Evaluate population fitness with per-circuit reporting.

    Runs every circuit category in a fixed order, accumulating raw
    pass counts per individual and recording per-category summaries in
    `self.category_scores`.  Previously each category repeated the same
    accumulate/record stanza inline; they are now driven from a single
    (name, callable) table — order, category names, and output are
    unchanged.

    Args:
        population: Dict of tensors, each with shape [pop_size, ...]
        debug: If True, print per-circuit results

    Returns:
        Tensor of fitness scores [pop_size], normalized to [0, 1]
    """
    self.results = []
    self.category_scores = {}
    pop_size = next(iter(population.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total_tests = 0
    # (category name, test callable) in reporting order.  Ripple-carry
    # widths are bound via a default argument to avoid late binding.
    test_specs = [
        ('boolean', self._test_boolean_gates),
        ('halfadder', self._test_halfadder),
        ('fulladder', self._test_fulladder),
        ('ripplecarry2', lambda p, d, b=2: self._test_ripplecarry(p, b, d)),
        ('ripplecarry4', lambda p, d, b=4: self._test_ripplecarry(p, b, d)),
        ('ripplecarry8', lambda p, d, b=8: self._test_ripplecarry(p, b, d)),
        ('comparators', self._test_comparators),
        ('threshold', self._test_threshold_gates),
        ('modular', self._test_modular_all),
        ('patterns', self._test_patterns),
        ('error_detection', self._test_error_detection),
        ('combinational', self._test_combinational),
        ('control', self._test_control_flow),
        ('alu', self._test_alu_ops),
        ('manifest', self._test_manifest),
        ('memory', self._test_memory),
    ]
    for category, run_test in test_specs:
        s, t = run_test(population, debug)
        scores += s
        total_tests += t
        # Single individual: raw score.  Population: mean score.
        self.category_scores[category] = (
            s[0].item() if pop_size == 1 else s.mean().item(), t)
    self.total_tests = total_tests
    if debug:
        print("\n" + "=" * 60)
        print("CATEGORY SUMMARY")
        print("=" * 60)
        for cat, (got, expected) in sorted(self.category_scores.items()):
            pct = 100 * got / expected if expected > 0 else 0
            status = "PASS" if got == expected else "FAIL"
            print(f" {cat:20} {int(got):6}/{expected:6} ({pct:6.2f}%) [{status}]")
        print("\n" + "=" * 60)
        print("CIRCUIT FAILURES")
        print("=" * 60)
        failed = [r for r in self.results if not r.success]
        if failed:
            for r in failed[:20]:
                print(f" {r.name}: {r.passed}/{r.total}")
                if r.failures:
                    print(f" First failure: {r.failures[0]}")
            if len(failed) > 20:
                print(f" ... and {len(failed) - 20} more")
        else:
            print(" None!")
    return scores / total_tests if total_tests > 0 else scores
def main():
    """CLI entry point: load the model, run the evaluation suite, report.

    Returns:
        Process exit code: 0 if individual 0 reaches >= 99.99% fitness,
        1 otherwise.  (With --pop_size > 1, the PASS/FAIL status still
        reflects individual 0 only; aggregate stats are printed above.)
    """
    parser = argparse.ArgumentParser(description='Unified Evaluation Suite for 8-bit Threshold Computer')
    parser.add_argument('--model', type=str, default=MODEL_PATH, help='Path to safetensors model')
    parser.add_argument('--device', type=str, default='cuda', help='Device: cuda or cpu')
    parser.add_argument('--pop_size', type=int, default=1, help='Population size for batched evaluation')
    parser.add_argument('--quiet', action='store_true', help='Suppress detailed output')
    args = parser.parse_args()
    print("=" * 70)
    print(" UNIFIED EVALUATION SUITE")
    print("=" * 70)
    print(f"\nLoading model from {args.model}...")
    model = load_model(args.model)
    print(f" Loaded {len(model)} tensors, {sum(t.numel() for t in model.values()):,} params")
    print(f"\nInitializing evaluator on {args.device}...")
    evaluator = BatchedFitnessEvaluator(device=args.device, model_path=args.model)
    print(f"\nCreating population (size {args.pop_size})...")
    population = create_population(model, pop_size=args.pop_size, device=args.device)
    print("\nRunning evaluation...")
    # Synchronize around the timed region so asynchronous CUDA kernel
    # launches are actually included in the measurement.
    if args.device == 'cuda':
        torch.cuda.synchronize()
    start = time.perf_counter()
    fitness = evaluator.evaluate(population, debug=not args.quiet)
    if args.device == 'cuda':
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    print("\n" + "=" * 70)
    print("RESULTS")
    print("=" * 70)
    if args.pop_size == 1:
        print(f" Fitness: {fitness[0].item():.6f}")
    else:
        print(f" Mean Fitness: {fitness.mean().item():.6f}")
        print(f" Min Fitness: {fitness.min().item():.6f}")
        print(f" Max Fitness: {fitness.max().item():.6f}")
    print(f" Total tests: {evaluator.total_tests}")
    print(f" Time: {elapsed * 1000:.2f} ms")
    if args.pop_size > 1:
        print(f" Throughput: {args.pop_size / elapsed:.0f} evals/sec")
        perfect = (fitness >= 0.9999).sum().item()
        print(f" Perfect (>=99.99%): {perfect}/{args.pop_size}")
    if fitness[0].item() >= 0.9999:
        print("\n STATUS: PASS")
        return 0
    else:
        # round(), not int(): fitness is a float ratio, and truncation
        # could under-count the number of failed tests by one.
        failed_count = round((1 - fitness[0].item()) * evaluator.total_tests)
        print(f"\n STATUS: FAIL ({failed_count} tests failed)")
        return 1
if __name__ == '__main__':
    # SystemExit with main()'s return code; avoids the site-module
    # convenience exit(), which is not guaranteed to exist in all
    # interpreter configurations.
    raise SystemExit(main())