CharlesCNorton
Fix priority encoder circuit logic
a696964
"""
Unified Evaluation Suite for 8-bit Threshold Computer
======================================================
GPU-batched evaluation with per-circuit reporting.
Includes CPU runtime for threshold-weight execution.
Usage:
python eval.py # Run circuit evaluation
python eval.py --device cpu # CPU mode
python eval.py --pop_size 1000 # Population mode for evolution
python eval.py --cpu-test # Run CPU smoke test
API (for prune_weights.py):
from eval import load_model, create_population, BatchedFitnessEvaluator
from eval import ThresholdCPU, ThresholdALU, CPUState
"""
import argparse
import json
import os
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple
import torch
from safetensors import safe_open
# Default weights file: "neural_computer.safetensors" located next to this module.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "neural_computer.safetensors")
@dataclass
class CircuitResult:
    """Outcome of evaluating one named circuit.

    Holds the number of passing cases, the number of cases run, and the
    recorded failure tuples (empty by default).
    """
    name: str
    passed: int
    total: int
    failures: List[Tuple] = field(default_factory=list)

    @property
    def success(self) -> bool:
        """True when every case passed (``passed`` equals ``total``)."""
        return self.total == self.passed

    @property
    def rate(self) -> float:
        """Fraction of passing cases, or 0.0 when no cases were run."""
        if self.total > 0:
            return self.passed / self.total
        return 0.0
def heaviside(x: torch.Tensor) -> torch.Tensor:
    """Step activation: 1.0 wherever ``x >= 0`` and 0.0 elsewhere (float32)."""
    fired = torch.ge(x, 0)
    return fired.float()
def load_model(path: str = MODEL_PATH) -> Dict[str, torch.Tensor]:
    """Read every tensor stored in the safetensors file at ``path``.

    Returns a dict mapping tensor name to the tensor converted to float32.
    """
    tensors: Dict[str, torch.Tensor] = {}
    with safe_open(path, framework='pt') as handle:
        for key in handle.keys():
            tensors[key] = handle.get_tensor(key).float()
    return tensors
def load_metadata(path: str = MODEL_PATH) -> Dict:
    """Read the safetensors header metadata and return its signal registry.

    The result is always ``{'signal_registry': ...}``: the JSON-decoded
    ``signal_registry`` entry when the header has one, otherwise an empty dict.
    """
    with safe_open(path, framework='pt') as handle:
        header = handle.metadata()
        # The header may be None/empty or may simply lack the registry entry.
        if not header or 'signal_registry' not in header:
            return {'signal_registry': {}}
        return {'signal_registry': json.loads(header['signal_registry'])}
def get_manifest(tensors: Dict[str, torch.Tensor]) -> Dict[str, int]:
    """Read the ``manifest.*`` scalars out of a tensor dict.

    Returns a dict with ``data_bits``, ``addr_bits``, ``memory_bytes`` (ints)
    and ``version`` (float). Missing entries fall back to 8 data bits,
    16 address bits, 65536 memory bytes and version 1.0. ``addr_bits`` falls
    back to ``manifest.pc_width`` before using its default.
    """
    def scalar(key: str, fallback: torch.Tensor):
        # Each manifest entry is a one-element tensor; .item() unwraps it.
        return tensors.get(key, fallback).item()

    addr_fallback = tensors.get('manifest.pc_width', torch.tensor([16.0]))

    manifest = {}
    manifest['data_bits'] = int(scalar('manifest.data_bits', torch.tensor([8.0])))
    manifest['addr_bits'] = int(scalar('manifest.addr_bits', addr_fallback))
    manifest['memory_bytes'] = int(scalar('manifest.memory_bytes', torch.tensor([65536.0])))
    manifest['version'] = float(scalar('manifest.version', torch.tensor([1.0])))
    return manifest
def create_population(
    base_tensors: Dict[str, torch.Tensor],
    pop_size: int,
    device: str = 'cuda'
) -> Dict[str, torch.Tensor]:
    """Build a population by repeating every base tensor ``pop_size`` times.

    Each output tensor gains a leading dimension of size ``pop_size``, owns
    its own memory (it is cloned), and is moved to ``device``.
    """
    population: Dict[str, torch.Tensor] = {}
    for key, base in base_tensors.items():
        # expand() creates a broadcast view; clone() materialises it so that
        # individual members can be modified independently.
        tiled = base.unsqueeze(0).expand(pop_size, *base.shape)
        population[key] = tiled.clone().to(device)
    return population
# =============================================================================
# CPU RUNTIME
# =============================================================================
# Flag order used throughout: flags[0]=Z, flags[1]=N, flags[2]=C, flags[3]=V
# (matches the tuple order returned by flags_from_result).
FLAG_NAMES = ["Z", "N", "C", "V"]
# Control-bit order: ctrl[0] is the HALT bit checked and set by the step functions.
CTRL_NAMES = ["HALT", "MEM_WE", "MEM_RE", "RESERVED"]
# Field widths (in bits) of the packed machine state, in pack_state order.
PC_BITS = 16
IR_BITS = 16
REG_BITS = 8
REG_COUNT = 4
FLAG_BITS = 4
SP_BITS = 16
CTRL_BITS = 4
# Memory size in bytes and in bits (each byte is packed as 8 bits).
MEM_BYTES = 65536
MEM_BITS = MEM_BYTES * 8
# Total length of the flat bit vector produced by pack_state / expected by unpack_state.
STATE_BITS = PC_BITS + IR_BITS + (REG_BITS * REG_COUNT) + FLAG_BITS + SP_BITS + CTRL_BITS + MEM_BITS
def int_to_bits(value: int, width: int) -> List[int]:
    """Return the low ``width`` bits of ``value``, most significant bit first."""
    return [(value >> shift) & 1 for shift in range(width - 1, -1, -1)]
def bits_to_int(bits: List[int]) -> int:
    """Combine an MSB-first bit list into an integer (each entry passes through ``int``)."""
    total = 0
    # Walk from the least significant end so each bit lands at its own position.
    for position, bit in enumerate(reversed(bits)):
        total |= int(bit) << position
    return total
def bits_msb_to_lsb(bits: List[int]) -> List[int]:
    """Return a new list holding ``bits`` in reverse order (MSB-first becomes LSB-first)."""
    flipped = list(bits)
    flipped.reverse()
    return flipped
@dataclass
class CPUState:
    """Complete machine state: registers, flags, control bits and memory."""
    pc: int
    ir: int
    regs: List[int]
    flags: List[int]
    sp: int
    ctrl: List[int]
    mem: List[int]

    def copy(self) -> 'CPUState':
        """Return an independent copy with every field coerced to plain ints."""
        def as_ints(values):
            return [int(v) for v in values]

        return CPUState(
            pc=int(self.pc),
            ir=int(self.ir),
            regs=as_ints(self.regs),
            flags=as_ints(self.flags),
            sp=int(self.sp),
            ctrl=as_ints(self.ctrl),
            mem=as_ints(self.mem),
        )
def pack_state(state: CPUState) -> List[int]:
    """Flatten a CPUState into a single MSB-first bit list.

    Field order: PC, IR, each register, flags, SP, control bits, then every
    memory byte (8 bits each). Flags and control bits are copied as-is (via
    ``int``) rather than width-encoded.
    """
    packed: List[int] = int_to_bits(state.pc, PC_BITS)
    packed += int_to_bits(state.ir, IR_BITS)
    for value in state.regs:
        packed += int_to_bits(value, REG_BITS)
    packed += [int(flag) for flag in state.flags]
    packed += int_to_bits(state.sp, SP_BITS)
    packed += [int(bit) for bit in state.ctrl]
    for cell in state.mem:
        packed += int_to_bits(cell, REG_BITS)
    return packed
def unpack_state(bits: List[int]) -> CPUState:
    """Rebuild a CPUState from the flat bit list produced by ``pack_state``.

    Raises:
        ValueError: if ``bits`` does not contain exactly ``STATE_BITS`` entries.
    """
    if len(bits) != STATE_BITS:
        raise ValueError(f"Expected {STATE_BITS} bits, got {len(bits)}")

    cursor = 0

    def take(width: int) -> List[int]:
        # Consume the next `width` bits and advance the read position.
        nonlocal cursor
        chunk = bits[cursor:cursor + width]
        cursor += width
        return chunk

    # Fields are consumed strictly in pack_state order.
    pc = bits_to_int(take(PC_BITS))
    ir = bits_to_int(take(IR_BITS))
    regs = [bits_to_int(take(REG_BITS)) for _ in range(REG_COUNT)]
    flags = [int(b) for b in take(FLAG_BITS)]
    sp = bits_to_int(take(SP_BITS))
    ctrl = [int(b) for b in take(CTRL_BITS)]
    mem = [bits_to_int(take(REG_BITS)) for _ in range(MEM_BYTES)]
    return CPUState(pc=pc, ir=ir, regs=regs, flags=flags, sp=sp, ctrl=ctrl, mem=mem)
def decode_ir(ir: int) -> Tuple[int, int, int, int]:
    """Split a 16-bit instruction word into ``(opcode, rd, rs, imm8)``.

    Layout: bits 15-12 opcode, bits 11-10 rd, bits 9-8 rs, bits 7-0 imm8.
    """
    shifted = (ir >> 12, ir >> 10, ir >> 8, ir)
    masks = (0xF, 0x3, 0x3, 0xFF)
    opcode, rd, rs, imm8 = (field & mask for field, mask in zip(shifted, masks))
    return opcode, rd, rs, imm8
def flags_from_result(result: int, carry: int, overflow: int) -> Tuple[int, int, int, int]:
    """Derive ``(Z, N, C, V)`` as 0/1 ints.

    Z is set when ``result`` is zero, N when bit 7 of ``result`` is set;
    C and V are the truthiness of ``carry`` and ``overflow``.
    """
    zero = int(result == 0)
    negative = int((result & 0x80) != 0)
    return zero, negative, int(bool(carry)), int(bool(overflow))
def alu_add(a: int, b: int) -> Tuple[int, int, int]:
    """Reference 8-bit add returning ``(result, carry, overflow)``.

    ``carry`` is set when the unmasked sum exceeds 0xFF; ``overflow`` is set
    when both operands differ in sign bit from the result.
    """
    total = a + b
    low = total & 0xFF
    carry_out = int(total > 0xFF)
    sign_flip = (a ^ low) & (b ^ low) & 0x80
    return low, carry_out, int(sign_flip != 0)
def alu_sub(a: int, b: int) -> Tuple[int, int, int]:
    """Reference 8-bit subtract returning ``(result, carry, overflow)``.

    ``carry`` is 1 when no borrow occurred (``a >= b``); ``overflow`` is set
    when the operands differ in sign bit and the result's sign bit differs
    from ``a``'s.
    """
    diff = ((a - b) & 0x1FF) & 0xFF
    no_borrow = int(a >= b)
    signed_overflow = int(((a ^ b) & (a ^ diff) & 0x80) != 0)
    return diff, no_borrow, signed_overflow
def ref_step(state: CPUState) -> CPUState:
    """Reference CPU cycle (pure Python arithmetic).

    Fetches the 16-bit instruction at PC (big-endian), decodes it and executes
    it on a copy of ``state``; the input state is never modified. A halted
    state (``ctrl[0] == 1``) is returned as an unchanged copy.

    Opcodes: 0x0 ADD, 0x1 SUB, 0x2 AND, 0x3 OR, 0x4 XOR, 0x5 shift left,
    0x6 shift right, 0x7 MUL, 0x8 DIV, 0x9 compare (SUB without register
    write), 0xA load, 0xB store, 0xC jump, 0xD conditional jump, 0xE call
    (pushes the return address), 0xF halt. Opcodes 0xA-0xE carry an extra
    16-bit big-endian address operand after the instruction word.
    """
    s = state.copy()
    if state.ctrl[0] == 1:
        return s

    # --- Fetch and decode ---
    fetch_pc = s.pc
    s.ir = ((s.mem[fetch_pc] & 0xFF) << 8) | (s.mem[(fetch_pc + 1) & 0xFFFF] & 0xFF)
    next_pc = (fetch_pc + 2) & 0xFFFF
    opcode, rd, rs, imm8 = decode_ir(s.ir)
    a, b = s.regs[rd], s.regs[rs]

    # --- Optional 16-bit address operand ---
    addr16 = None
    pc_after = next_pc
    if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
        addr16 = ((s.mem[next_pc] & 0xFF) << 8) | (s.mem[(next_pc + 1) & 0xFFFF] & 0xFF)
        pc_after = (next_pc + 2) & 0xFFFF

    # Conditional-jump table: condition code -> (flag index, value that takes the branch).
    branch_conditions = (
        (0, 1), (0, 0),   # Z set / Z clear
        (2, 1), (2, 0),   # C set / C clear
        (1, 1), (1, 0),   # N set / N clear
        (3, 1), (3, 0),   # V set / V clear
    )

    result, carry, overflow = a, 0, 0
    new_pc = pc_after

    # --- Execute ---
    if opcode == 0x0:
        result, carry, overflow = alu_add(a, b)
    elif opcode in (0x1, 0x9):
        # SUB and compare share the arithmetic; compare just skips the write-back.
        result, carry, overflow = alu_sub(a, b)
    elif opcode == 0x2:
        result = a & b
    elif opcode == 0x3:
        result = a | b
    elif opcode == 0x4:
        result = a ^ b
    elif opcode == 0x5:
        result = (a << 1) & 0xFF
    elif opcode == 0x6:
        result = (a >> 1) & 0xFF
    elif opcode == 0x7:
        result = (a * b) & 0xFF
    elif opcode == 0x8:
        # Division by zero yields 0xFF instead of raising.
        result = a // b if b != 0 else 0xFF
    elif opcode == 0xA:
        result = s.mem[addr16]
    elif opcode == 0xB:
        s.mem[addr16] = b & 0xFF
    elif opcode == 0xC:
        new_pc = addr16 & 0xFFFF
    elif opcode == 0xD:
        flag_index, taken_value = branch_conditions[imm8 & 0x7]
        if s.flags[flag_index] == taken_value:
            new_pc = addr16 & 0xFFFF
    elif opcode == 0xE:
        # Push return address: high byte first, then low byte, SP pre-decremented.
        ret_addr = pc_after & 0xFFFF
        for pushed in ((ret_addr >> 8) & 0xFF, ret_addr & 0xFF):
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem[s.sp] = pushed
        new_pc = addr16 & 0xFFFF
    elif opcode == 0xF:
        s.ctrl[0] = 1

    # --- Write-back ---
    # ALU operations, compare and load update the flags.
    if opcode <= 0xA:
        s.flags = list(flags_from_result(result, carry, overflow))
    # Only ALU operations 0x0-0x8 and load write the destination register.
    if opcode <= 0x8 or opcode == 0xA:
        s.regs[rd] = result & 0xFF
    s.pc = new_pc
    return s
def ref_run_until_halt(state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
    """Reference execution loop.

    Steps a copy of ``state`` until its HALT bit is observed or ``max_cycles``
    steps have run. Returns ``(final_state, cycles)`` where ``cycles`` is the
    number of steps executed before HALT was seen, or ``max_cycles`` when the
    budget ran out.
    """
    s = state.copy()
    executed = 0
    while executed < max_cycles:
        if s.ctrl[0] == 1:
            return s, executed
        s = ref_step(s)
        executed += 1
    return s, max_cycles
class ThresholdALU:
    """8-bit ALU evaluated through threshold-gate weights.

    Gate weights and biases are looked up by name in the tensor dictionary
    returned by ``load_model``. Every gate computes
    ``heaviside(sum(inputs * weight) + bias)``.
    """

    def __init__(self, model_path: str = MODEL_PATH, device: str = "cpu") -> None:
        """Load all tensors from ``model_path`` as float32 onto ``device``."""
        self.device = device
        self.tensors = {k: v.float().to(device) for k, v in load_model(model_path).items()}

    def _get(self, name: str) -> torch.Tensor:
        """Return the tensor stored under ``name`` (KeyError if absent)."""
        return self.tensors[name]

    def _eval_gate(self, weight_key: str, bias_key: str, inputs: List[float]) -> float:
        """Evaluate one threshold gate on ``inputs`` and return 0.0 or 1.0."""
        w = self._get(weight_key)
        b = self._get(bias_key)
        inp = torch.tensor(inputs, device=self.device)
        return heaviside((inp * w).sum() + b).item()

    def _eval_xor(self, prefix: str, inputs: List[float]) -> float:
        """Evaluate the two-layer gate stored under ``prefix``.

        Layer 1 evaluates the ``layer1.or`` and ``layer1.nand`` neurons on
        ``inputs``; layer 2 combines those two hidden values.
        """
        inp = torch.tensor(inputs, device=self.device)
        w_or = self._get(f"{prefix}.layer1.or.weight")
        b_or = self._get(f"{prefix}.layer1.or.bias")
        w_nand = self._get(f"{prefix}.layer1.nand.weight")
        b_nand = self._get(f"{prefix}.layer1.nand.bias")
        w2 = self._get(f"{prefix}.layer2.weight")
        b2 = self._get(f"{prefix}.layer2.bias")
        h_or = heaviside((inp * w_or).sum() + b_or).item()
        h_nand = heaviside((inp * w_nand).sum() + b_nand).item()
        hidden = torch.tensor([h_or, h_nand], device=self.device)
        return heaviside((hidden * w2).sum() + b2).item()

    def _eval_full_adder(self, prefix: str, a: float, b: float, cin: float) -> Tuple[float, float]:
        """Evaluate the full adder under ``prefix``; returns ``(sum, carry_out)``.

        Built from two half adders (``ha1``, ``ha2``) whose carries feed the
        ``carry_or`` gate.
        """
        ha1_sum = self._eval_xor(f"{prefix}.ha1.sum", [a, b])
        ha1_carry = self._eval_gate(f"{prefix}.ha1.carry.weight", f"{prefix}.ha1.carry.bias", [a, b])
        ha2_sum = self._eval_xor(f"{prefix}.ha2.sum", [ha1_sum, cin])
        ha2_carry = self._eval_gate(
            f"{prefix}.ha2.carry.weight", f"{prefix}.ha2.carry.bias", [ha1_sum, cin]
        )
        cout = self._eval_gate(f"{prefix}.carry_or.weight", f"{prefix}.carry_or.bias", [ha1_carry, ha2_carry])
        return ha2_sum, cout

    def add(self, a: int, b: int) -> Tuple[int, int, int]:
        """8-bit ripple-carry add; returns ``(result, carry_out, overflow)``.

        Sum and carry come from the ``arithmetic.ripplecarry8bit.fa{n}`` gates
        (fa0 handles the least significant bit). Overflow is derived
        arithmetically from the operand and result sign bits.
        """
        # Adder stages are indexed LSB-first, so flip the MSB-first bit lists.
        a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))
        carry = 0.0
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            sum_bit, carry = self._eval_full_adder(
                f"arithmetic.ripplecarry8bit.fa{bit}", float(a_bits[bit]), float(b_bits[bit]), carry
            )
            sum_bits.append(int(sum_bit))
        result = bits_to_int(list(reversed(sum_bits)))
        carry_out = int(carry)
        overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def sub(self, a: int, b: int) -> Tuple[int, int, int]:
        """8-bit subtract; returns ``(result, carry_out, overflow)``.

        Feeds ``a`` and the gate-inverted bits of ``b`` through the
        ``arithmetic.sub8bit`` adder chain with an initial carry of 1.
        Overflow is derived arithmetically from the sign bits.
        """
        a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))
        carry = 1.0  # carry-in of 1 for the LSB stage
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            notb = self._eval_gate(
                f"arithmetic.sub8bit.notb{bit}.weight",
                f"arithmetic.sub8bit.notb{bit}.bias",
                [float(b_bits[bit])],
            )
            xor1 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor1", [float(a_bits[bit]), notb])
            xor2 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor2", [xor1, carry])
            and1 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and1.weight",
                f"arithmetic.sub8bit.fa{bit}.and1.bias",
                [float(a_bits[bit]), notb],
            )
            and2 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and2.weight",
                f"arithmetic.sub8bit.fa{bit}.and2.bias",
                [xor1, carry],
            )
            carry = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.or_carry.weight",
                f"arithmetic.sub8bit.fa{bit}.or_carry.bias",
                [and1, and2],
            )
            sum_bits.append(int(xor2))
        result = bits_to_int(list(reversed(sum_bits)))
        carry_out = int(carry)
        overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def bitwise_and(self, a: int, b: int) -> int:
        """Bitwise AND through ``alu.alu8bit.and`` (two weights and one bias per bit)."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        w = self._get("alu.alu8bit.and.weight")
        bias = self._get("alu.alu8bit.and.bias")
        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def bitwise_or(self, a: int, b: int) -> int:
        """Bitwise OR through ``alu.alu8bit.or`` (two weights and one bias per bit)."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        w = self._get("alu.alu8bit.or.weight")
        bias = self._get("alu.alu8bit.or.bias")
        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def bitwise_not(self, a: int) -> int:
        """Bitwise NOT through ``alu.alu8bit.not`` (one weight and one bias per bit)."""
        a_bits = int_to_bits(a, REG_BITS)
        w = self._get("alu.alu8bit.not.weight")
        bias = self._get("alu.alu8bit.not.bias")
        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit]).sum() + bias[bit]).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def bitwise_xor(self, a: int, b: int) -> int:
        """Bitwise XOR through the two-layer ``alu.alu8bit.xor`` gates.

        Each bit uses a pair of layer-1 neurons (``or``/``nand``) and one
        layer-2 neuron; weights are stored two per bit.
        """
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        w_or = self._get("alu.alu8bit.xor.layer1.or.weight")
        b_or = self._get("alu.alu8bit.xor.layer1.or.bias")
        w_nand = self._get("alu.alu8bit.xor.layer1.nand.weight")
        b_nand = self._get("alu.alu8bit.xor.layer1.nand.bias")
        w2 = self._get("alu.alu8bit.xor.layer2.weight")
        b2 = self._get("alu.alu8bit.xor.layer2.bias")
        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            h_or = heaviside((inp * w_or[bit * 2:bit * 2 + 2]).sum() + b_or[bit])
            h_nand = heaviside((inp * w_nand[bit * 2:bit * 2 + 2]).sum() + b_nand[bit])
            hidden = torch.stack([h_or, h_nand])
            out = heaviside((hidden * w2[bit * 2:bit * 2 + 2]).sum() + b2[bit]).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def shift_left(self, a: int) -> int:
        """Shift left by one through the ``alu.alu8bit.shl.bit{n}`` gates.

        Output position ``n`` (MSB-first) is fed the input bit at position
        ``n + 1``; the last position is fed a constant 0.
        """
        a_bits = int_to_bits(a, REG_BITS)
        out_bits = []
        for bit in range(REG_BITS):
            w = self._get(f"alu.alu8bit.shl.bit{bit}.weight")
            bias = self._get(f"alu.alu8bit.shl.bit{bit}.bias")
            if bit < 7:
                inp = torch.tensor([float(a_bits[bit + 1])], device=self.device)
            else:
                inp = torch.tensor([0.0], device=self.device)
            out = heaviside((inp * w).sum() + bias).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def shift_right(self, a: int) -> int:
        """Shift right by one through the ``alu.alu8bit.shr.bit{n}`` gates.

        Output position ``n`` (MSB-first) is fed the input bit at position
        ``n - 1``; the first position is fed a constant 0.
        """
        a_bits = int_to_bits(a, REG_BITS)
        out_bits = []
        for bit in range(REG_BITS):
            w = self._get(f"alu.alu8bit.shr.bit{bit}.weight")
            bias = self._get(f"alu.alu8bit.shr.bit{bit}.bias")
            if bit > 0:
                inp = torch.tensor([float(a_bits[bit - 1])], device=self.device)
            else:
                inp = torch.tensor([0.0], device=self.device)
            out = heaviside((inp * w).sum() + bias).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def multiply(self, a: int, b: int) -> int:
        """8-bit multiply using partial product AND gates + shift-add.

        Returns the low 8 bits of ``a * b``, matching the reference
        ``(a * b) & 0xFF``.
        """
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        # pp[i][j] is the gate output for a-bit i and b-bit j (both MSB-first).
        pp = [[0] * 8 for _ in range(8)]
        for i in range(8):
            for j in range(8):
                w = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.weight")
                bias = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.bias")
                inp = torch.tensor([float(a_bits[i]), float(b_bits[j])], device=self.device)
                pp[i][j] = int(heaviside((inp * w).sum() + bias).item())
        result = 0
        for j in range(8):
            if b_bits[j] == 0:
                continue
            # Reassemble the partial-product row for b-bit j into an integer.
            row = 0
            for i in range(8):
                row |= (pp[i][j] << (7 - i))
            # b_bits is MSB-first, so bit j carries weight 2**(7 - j).
            shifted = row << (7 - j)
            # Only the low byte of each shifted row can affect the low byte of
            # the product; bits 8 and above are discarded, not folded back in.
            result, _, _ = self.add(result & 0xFF, shifted & 0xFF)
        return result & 0xFF

    def divide(self, a: int, b: int) -> Tuple[int, int]:
        """8-bit divide using restoring division with threshold gates.

        Returns ``(quotient, remainder)``. Division by zero returns
        ``(0xFF, a)``. Each stage shifts in the next dividend bit, evaluates
        the ``alu.alu8bit.div.stage{n}.cmp`` gate on the remainder and divisor
        bits, and subtracts the divisor when the gate fires.
        """
        if b == 0:
            return 0xFF, a
        a_bits = int_to_bits(a, REG_BITS)
        quotient = 0
        remainder = 0
        for stage in range(8):
            remainder = ((remainder << 1) | a_bits[stage]) & 0xFF
            rem_bits = int_to_bits(remainder, REG_BITS)
            div_bits = int_to_bits(b, REG_BITS)
            w = self._get(f"alu.alu8bit.div.stage{stage}.cmp.weight")
            bias = self._get(f"alu.alu8bit.div.stage{stage}.cmp.bias")
            inp = torch.tensor([float(rem_bits[i]) for i in range(8)] +
                               [float(div_bits[i]) for i in range(8)], device=self.device)
            cmp_result = int(heaviside((inp * w).sum() + bias).item())
            if cmp_result:
                remainder, _, _ = self.sub(remainder, b)
                quotient = (quotient << 1) | 1
            else:
                quotient = quotient << 1
        return quotient & 0xFF, remainder & 0xFF
class ThresholdCPU:
    """CPU whose datapath is evaluated through threshold-gate weights.

    Arithmetic goes through a ``ThresholdALU``; memory access and conditional
    jumps use the ``memory.*`` and ``control.*`` tensors of the same model.
    """

    def __init__(self, model_path: str = MODEL_PATH, device: str = "cpu") -> None:
        """Create the ALU (which loads the model tensors) on ``device``."""
        self.device = device
        self.alu = ThresholdALU(model_path, device=device)

    def _addr_decode(self, addr: int) -> torch.Tensor:
        """Run the address decoder on ``addr``; returns one select value per decoder row."""
        bits = torch.tensor(int_to_bits(addr, PC_BITS), device=self.device, dtype=torch.float32)
        w = self.alu._get("memory.addr_decode.weight")
        b = self.alu._get("memory.addr_decode.bias")
        return heaviside((w * bits).sum(dim=1) + b)

    def _memory_read(self, mem: List[int], addr: int) -> int:
        """Read the byte at ``addr`` through the ``memory.read`` gates.

        For each data bit, the ``and`` gates combine every cell's bit with its
        select line, and the ``or`` gate reduces those outputs to a single bit.
        """
        sel = self._addr_decode(addr)
        mem_bits = torch.tensor(
            [int_to_bits(byte, REG_BITS) for byte in mem],
            device=self.device,
            dtype=torch.float32,
        )
        and_w = self.alu._get("memory.read.and.weight")
        and_b = self.alu._get("memory.read.and.bias")
        or_w = self.alu._get("memory.read.or.weight")
        or_b = self.alu._get("memory.read.or.bias")
        out_bits: List[int] = []
        for bit in range(REG_BITS):
            inp = torch.stack([mem_bits[:, bit], sel], dim=1)
            and_out = heaviside((inp * and_w[bit]).sum(dim=1) + and_b[bit])
            out_bit = heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item()
            out_bits.append(int(out_bit))
        return bits_to_int(out_bits)

    def _memory_write(self, mem: List[int], addr: int, value: int) -> List[int]:
        """Return a new memory image with ``value`` written at ``addr``.

        Every cell bit is recomputed through the ``memory.write`` gates as
        ``or(and_old(old_bit, nsel), and_new(data_bit, write_sel))``; the
        write-enable input is always 1. ``mem`` itself is not modified.
        """
        sel = self._addr_decode(addr)
        data_bits = torch.tensor(int_to_bits(value, REG_BITS), device=self.device, dtype=torch.float32)
        mem_bits = torch.tensor(
            [int_to_bits(byte, REG_BITS) for byte in mem],
            device=self.device,
            dtype=torch.float32,
        )
        sel_w = self.alu._get("memory.write.sel.weight")
        sel_b = self.alu._get("memory.write.sel.bias")
        nsel_w = self.alu._get("memory.write.nsel.weight").squeeze(1)
        nsel_b = self.alu._get("memory.write.nsel.bias")
        and_old_w = self.alu._get("memory.write.and_old.weight")
        and_old_b = self.alu._get("memory.write.and_old.bias")
        and_new_w = self.alu._get("memory.write.and_new.weight")
        and_new_b = self.alu._get("memory.write.and_new.bias")
        or_w = self.alu._get("memory.write.or.weight")
        or_b = self.alu._get("memory.write.or.bias")
        we = torch.ones_like(sel)
        sel_inp = torch.stack([sel, we], dim=1)
        write_sel = heaviside((sel_inp * sel_w).sum(dim=1) + sel_b)
        nsel = heaviside((write_sel * nsel_w) + nsel_b)
        new_mem_bits = torch.zeros((MEM_BYTES, REG_BITS), device=self.device)
        for bit in range(REG_BITS):
            old_bit = mem_bits[:, bit]
            data_bit = data_bits[bit].expand(MEM_BYTES)
            inp_old = torch.stack([old_bit, nsel], dim=1)
            inp_new = torch.stack([data_bit, write_sel], dim=1)
            and_old = heaviside((inp_old * and_old_w[:, bit]).sum(dim=1) + and_old_b[:, bit])
            and_new = heaviside((inp_new * and_new_w[:, bit]).sum(dim=1) + and_new_b[:, bit])
            or_inp = torch.stack([and_old, and_new], dim=1)
            out_bit = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit])
            new_mem_bits[:, bit] = out_bit
        # Convert the whole bit matrix to Python lists in one call rather than
        # indexing the tensor once per memory byte.
        rows = new_mem_bits.tolist()
        return [bits_to_int([int(b) for b in row]) for row in rows]

    def _conditional_jump_byte(self, prefix: str, pc_byte: int, target_byte: int, flag: int) -> int:
        """Select one PC byte through the jump circuit stored under ``prefix``.

        Per bit: ``or(and_a(pc_bit, not_sel(flag)), and_b(target_bit, flag))``.
        NOTE(review): whether the negated conditions (jnz/jnc/jp/jnv) invert
        the flag inside their weights cannot be seen from here; confirm
        against the model.
        """
        pc_bits = int_to_bits(pc_byte, REG_BITS)
        target_bits = int_to_bits(target_byte, REG_BITS)
        out_bits: List[int] = []
        for bit in range(REG_BITS):
            not_sel = self.alu._eval_gate(
                f"{prefix}.bit{bit}.not_sel.weight",
                f"{prefix}.bit{bit}.not_sel.bias",
                [float(flag)],
            )
            and_a = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_a.weight",
                f"{prefix}.bit{bit}.and_a.bias",
                [float(pc_bits[bit]), not_sel],
            )
            and_b = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_b.weight",
                f"{prefix}.bit{bit}.and_b.bias",
                [float(target_bits[bit]), float(flag)],
            )
            out_bit = self.alu._eval_gate(
                f"{prefix}.bit{bit}.or.weight",
                f"{prefix}.bit{bit}.or.bias",
                [and_a, and_b],
            )
            out_bits.append(int(out_bit))
        return bits_to_int(out_bits)

    def step(self, state: CPUState) -> CPUState:
        """Single CPU cycle using threshold neurons.

        Works on a copy of ``state``; a halted state (``ctrl[0] == 1``) is
        returned as an unchanged copy. Instruction fetch, operand fetch, ALU
        operations, loads/stores and conditional jumps go through the
        threshold circuits. Opcodes 0xA-0xE carry a 16-bit big-endian address
        operand after the instruction word.
        """
        if state.ctrl[0] == 1:
            return state.copy()
        s = state.copy()
        # Fetch the big-endian 16-bit instruction word.
        hi = self._memory_read(s.mem, s.pc)
        lo = self._memory_read(s.mem, (s.pc + 1) & 0xFFFF)
        s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
        next_pc = (s.pc + 2) & 0xFFFF
        opcode, rd, rs, imm8 = decode_ir(s.ir)
        a = s.regs[rd]
        b = s.regs[rs]
        addr16 = None
        next_pc_ext = next_pc
        if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
            # Memory and control-flow opcodes are followed by an address operand.
            addr_hi = self._memory_read(s.mem, next_pc)
            addr_lo = self._memory_read(s.mem, (next_pc + 1) & 0xFFFF)
            addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
            next_pc_ext = (next_pc + 2) & 0xFFFF
        write_result = True
        result = a
        carry = 0
        overflow = 0
        if opcode == 0x0:
            result, carry, overflow = self.alu.add(a, b)
        elif opcode == 0x1:
            result, carry, overflow = self.alu.sub(a, b)
        elif opcode == 0x2:
            result = self.alu.bitwise_and(a, b)
        elif opcode == 0x3:
            result = self.alu.bitwise_or(a, b)
        elif opcode == 0x4:
            result = self.alu.bitwise_xor(a, b)
        elif opcode == 0x5:
            result = self.alu.shift_left(a)
        elif opcode == 0x6:
            result = self.alu.shift_right(a)
        elif opcode == 0x7:
            result = self.alu.multiply(a, b)
        elif opcode == 0x8:
            result, _ = self.alu.divide(a, b)
        elif opcode == 0x9:
            # Compare: subtract for the flags only, no register write-back.
            result, carry, overflow = self.alu.sub(a, b)
            write_result = False
        elif opcode == 0xA:
            result = self._memory_read(s.mem, addr16)
        elif opcode == 0xB:
            s.mem = self._memory_write(s.mem, addr16, b & 0xFF)
            write_result = False
        elif opcode == 0xC:
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xD:
            cond_type = imm8 & 0x7
            # (circuit prefix, index into flags) for each condition code.
            cond_circuits = [
                ("control.jz", 0),
                ("control.jnz", 0),
                ("control.jc", 2),
                ("control.jnc", 2),
                ("control.jn", 1),
                ("control.jp", 1),
                ("control.jv", 3),
                ("control.jnv", 3),
            ]
            circuit_prefix, flag_idx = cond_circuits[cond_type]
            # The same circuit selects the high and low PC byte.
            hi_pc = self._conditional_jump_byte(
                circuit_prefix,
                (next_pc_ext >> 8) & 0xFF,
                (addr16 >> 8) & 0xFF,
                s.flags[flag_idx],
            )
            lo_pc = self._conditional_jump_byte(
                circuit_prefix,
                next_pc_ext & 0xFF,
                addr16 & 0xFF,
                s.flags[flag_idx],
            )
            s.pc = ((hi_pc & 0xFF) << 8) | (lo_pc & 0xFF)
            write_result = False
        elif opcode == 0xE:
            # Call: push the return address (high byte first), then jump.
            ret_addr = next_pc_ext & 0xFFFF
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, (ret_addr >> 8) & 0xFF)
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, ret_addr & 0xFF)
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xF:
            s.ctrl[0] = 1
            write_result = False
        # ALU operations, compare and load update the flags.
        if opcode <= 0x9 or opcode == 0xA:
            s.flags = list(flags_from_result(result, carry, overflow))
        if write_result:
            s.regs[rd] = result & 0xFF
        # Jump, conditional jump and call have already set the PC.
        if opcode not in (0xC, 0xD, 0xE):
            s.pc = next_pc_ext
        return s

    def run_until_halt(self, state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
        """Execute until HALT or max_cycles reached.

        Returns ``(final_state, cycles)``; ``cycles`` is the number of steps
        run before HALT was observed, or ``max_cycles`` if the budget ran out.
        """
        s = state.copy()
        for i in range(max_cycles):
            if s.ctrl[0] == 1:
                return s, i
            s = self.step(s)
        return s, max_cycles

    def forward(self, state_bits: torch.Tensor, max_cycles: int = 256) -> torch.Tensor:
        """Tensor-in, tensor-out interface for neural integration.

        ``state_bits`` is flattened and unpacked with ``unpack_state`` (so it
        must hold ``STATE_BITS`` values), run until HALT or ``max_cycles``,
        and re-packed into a float32 tensor created without an explicit device.
        """
        bits_list = [int(b) for b in state_bits.detach().cpu().flatten().tolist()]
        state = unpack_state(bits_list)
        final, _ = self.run_until_halt(state, max_cycles=max_cycles)
        return torch.tensor(pack_state(final), dtype=torch.float32)
def encode_instr(opcode: int, rd: int, rs: int, imm8: int) -> int:
    """Pack instruction fields into a 16-bit word (inverse of ``decode_ir``).

    Each field is masked to its width: 4-bit opcode, 2-bit rd, 2-bit rs,
    8-bit immediate.
    """
    word = (opcode & 0xF) << 12
    word |= (rd & 0x3) << 10
    word |= (rs & 0x3) << 8
    word |= imm8 & 0xFF
    return word
def write_instr(mem: List[int], addr: int, instr: int) -> None:
    """Store a 16-bit instruction word big-endian at ``addr`` (addresses wrap at 64 KiB)."""
    high_byte, low_byte = divmod(instr & 0xFFFF, 0x100)
    mem[addr & 0xFFFF] = high_byte
    mem[(addr + 1) & 0xFFFF] = low_byte
def write_addr(mem: List[int], addr: int, value: int) -> None:
    """Store a 16-bit address operand big-endian at ``addr`` (addresses wrap at 64 KiB)."""
    # Offset 0 receives the high byte (shift 8), offset 1 the low byte (shift 0).
    for offset, shift in enumerate((8, 0)):
        mem[(addr + offset) & 0xFFFF] = (value >> shift) & 0xFF
def run_smoke_test() -> int:
    """Smoke test: LOAD 5, LOAD 7, ADD, STORE, HALT. Expect result = 12.

    Runs the program on the reference interpreter, on the threshold-weight
    CPU, and through ``ThresholdCPU.forward``, asserting that all agree.
    Returns 0 on success; any mismatch raises AssertionError.
    """
    memory = [0] * 65536
    # (address, instruction word, optional 16-bit address operand that follows it)
    program = [
        (0x0000, encode_instr(0xA, 0, 0, 0x00), 0x0100),  # LOAD  R0 <- [0x0100]
        (0x0004, encode_instr(0xA, 1, 0, 0x00), 0x0101),  # LOAD  R1 <- [0x0101]
        (0x0008, encode_instr(0x0, 0, 1, 0x00), None),    # ADD   R0 <- R0 + R1
        (0x000A, encode_instr(0xB, 0, 0, 0x00), 0x0102),  # STORE [0x0102] <- R0
        (0x000E, encode_instr(0xF, 0, 0, 0x00), None),    # HALT
    ]
    for location, word, operand in program:
        write_instr(memory, location, word)
        if operand is not None:
            write_addr(memory, location + 2, operand)
    memory[0x0100] = 5
    memory[0x0101] = 7

    state = CPUState(
        pc=0,
        ir=0,
        regs=[0, 0, 0, 0],
        flags=[0, 0, 0, 0],
        sp=0xFFFE,
        ctrl=[0, 0, 0, 0],
        mem=memory,
    )

    # --- Reference interpreter ---
    print("Running reference implementation...")
    final, cycles = ref_run_until_halt(state, max_cycles=20)
    assert final.ctrl[0] == 1, "HALT flag not set"
    assert final.regs[0] == 12, f"R0 expected 12, got {final.regs[0]}"
    assert final.mem[0x0102] == 12, f"MEM[0x0102] expected 12, got {final.mem[0x0102]}"
    assert cycles <= 10, f"Unexpected cycle count: {cycles}"
    print(f" Reference: R0={final.regs[0]}, MEM[0x0102]={final.mem[0x0102]}, cycles={cycles}")

    # --- Threshold-weight CPU ---
    print("Running threshold-weight implementation...")
    threshold_cpu = ThresholdCPU()
    t_final, t_cycles = threshold_cpu.run_until_halt(state, max_cycles=20)
    assert t_final.ctrl[0] == 1, "Threshold HALT flag not set"
    assert t_final.regs[0] == final.regs[0], f"Threshold R0 mismatch: {t_final.regs[0]} != {final.regs[0]}"
    assert t_final.mem[0x0102] == final.mem[0x0102], (
        f"Threshold MEM[0x0102] mismatch: {t_final.mem[0x0102]} != {final.mem[0x0102]}"
    )
    assert t_cycles == cycles, f"Threshold cycle count mismatch: {t_cycles} != {cycles}"
    print(f" Threshold: R0={t_final.regs[0]}, MEM[0x0102]={t_final.mem[0x0102]}, cycles={t_cycles}")

    # --- Packed tensor round trip ---
    print("Validating forward() tensor I/O...")
    packed_input = torch.tensor(pack_state(state), dtype=torch.float32)
    packed_output = threshold_cpu.forward(packed_input, max_cycles=20)
    out_state = unpack_state([int(b) for b in packed_output.tolist()])
    assert out_state.regs[0] == final.regs[0], f"Forward R0 mismatch: {out_state.regs[0]} != {final.regs[0]}"
    assert out_state.mem[0x0102] == final.mem[0x0102], (
        f"Forward MEM[0x0102] mismatch: {out_state.mem[0x0102]} != {final.mem[0x0102]}"
    )
    print(f" Forward: R0={out_state.regs[0]}, MEM[0x0102]={out_state.mem[0x0102]}")

    print("\nSmoke test: PASSED")
    return 0
# =============================================================================
# CIRCUIT EVALUATION
# =============================================================================
class BatchedFitnessEvaluator:
"""
GPU-batched fitness evaluator with per-circuit reporting.
Tests all circuits comprehensively.
"""
def __init__(self, device: str = 'cuda', model_path: str = MODEL_PATH,
             tensors: Optional[Dict[str, torch.Tensor]] = None):
    """Set up result bookkeeping, read the model manifest and build test vectors.

    Args:
        device: torch device on which the test vectors are created.
        model_path: safetensors file; its metadata is always read, and its
            tensors are loaded for the manifest when ``tensors`` is None.
        tensors: optional already-loaded tensor dict used as the manifest
            source instead of re-reading ``model_path``.
    """
    self.device = device
    self.model_path = model_path
    self.metadata = load_metadata(model_path)
    self.signal_registry = self.metadata.get('signal_registry', {})
    self.results: List[CircuitResult] = []
    self.category_scores: Dict[str, Tuple[float, int]] = {}
    self.total_tests = 0
    # Get manifest for N-bit support
    manifest_source = tensors if tensors is not None else load_model(model_path)
    self.manifest = get_manifest(manifest_source)
    self.data_bits = self.manifest['data_bits']
    self.addr_bits = self.manifest['addr_bits']
    self._setup_tests()
def _setup_tests(self):
    """Pre-compute test vectors on device.

    Creates the truth tables, expected gate outputs, and 8-bit / 32-bit
    operand samples as tensors on ``self.device``.
    """
    d = self.device

    def f32(values):
        return torch.tensor(values, device=d, dtype=torch.float32)

    def i64(values):
        return torch.tensor(values, device=d, dtype=torch.long)

    # 2-input truth table [4, 2]
    self.tt2 = f32([[0, 0], [0, 1], [1, 0], [1, 1]])
    # 3-input truth table [8, 3]
    self.tt3 = f32([
        [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
        [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1],
    ])
    # NOT gate inputs
    self.not_inputs = f32([[0], [1]])

    # Expected outputs, listed in the row order of the matching truth table.
    expected_rows = {
        'and': [0, 0, 0, 1],
        'or': [0, 1, 1, 1],
        'nand': [1, 1, 1, 0],
        'nor': [1, 0, 0, 0],
        'xor': [0, 1, 1, 0],
        'xnor': [1, 0, 0, 1],
        'implies': [1, 1, 0, 1],
        'biimplies': [1, 0, 0, 1],
        'not': [1, 0],
        'ha_sum': [0, 1, 1, 0],
        'ha_carry': [0, 0, 0, 1],
        'fa_sum': [0, 1, 1, 0, 1, 0, 0, 1],
        'fa_cout': [0, 0, 0, 1, 0, 1, 1, 1],
    }
    self.expected = {gate: f32(row) for gate, row in expected_rows.items()}

    # 8-bit test values
    self.test_8bit = i64([
        0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255,
        0b10101010, 0b01010101, 0b11110000, 0b00001111,
        0b11001100, 0b00110011, 0b10000001, 0b01111110,
    ])
    # Bit representations [num_vals, 8], most significant bit in column 0
    bit_columns = [((self.test_8bit >> shift) & 1).float() for shift in range(7, -1, -1)]
    self.test_8bit_bits = torch.stack(bit_columns, dim=1)

    # Comparator test pairs
    comp_tests = [
        (0, 0), (1, 0), (0, 1), (5, 3), (3, 5), (5, 5),
        (255, 0), (0, 255), (128, 127), (127, 128),
        (100, 99), (99, 100), (64, 32), (32, 64),
        (1, 1), (254, 255), (255, 254), (128, 128),
        (0, 128), (128, 0), (64, 64), (192, 192),
        (15, 16), (16, 15), (240, 239), (239, 240),
        (85, 170), (170, 85), (0xAA, 0x55), (0x55, 0xAA),
        (0x0F, 0xF0), (0xF0, 0x0F), (0x33, 0xCC), (0xCC, 0x33),
        (2, 3), (3, 2), (126, 127), (127, 126),
        (129, 128), (128, 129), (200, 199), (199, 200),
        (50, 51), (51, 50), (10, 20), (20, 10),
        (100, 100), (200, 200), (77, 77), (0, 0),
    ]
    left_8, right_8 = zip(*comp_tests)
    self.comp_a = i64(list(left_8))
    self.comp_b = i64(list(right_8))

    # Modular test range
    self.mod_test = torch.arange(256, device=d, dtype=torch.long)

    # 32-bit test values (strategic sampling)
    self.test_32bit = i64([
        0, 1, 2, 255, 256, 65535, 65536,
        0x7FFFFFFF, 0x80000000, 0xFFFFFFFF,
        0x12345678, 0xDEADBEEF, 0xCAFEBABE,
        1000000, 1000000000, 2147483647,
        0x55555555, 0xAAAAAAAA, 0x0F0F0F0F, 0xF0F0F0F0,
    ])

    # 32-bit comparator test pairs
    comp32_tests = [
        (0, 0), (1, 0), (0, 1), (1000, 999), (999, 1000),
        (0xFFFFFFFF, 0), (0, 0xFFFFFFFF),
        (0x80000000, 0x7FFFFFFF), (0x7FFFFFFF, 0x80000000),
        (1000000, 1000000), (0x12345678, 0x12345678),
        (0xDEADBEEF, 0xCAFEBABE), (0xCAFEBABE, 0xDEADBEEF),
        (256, 255), (255, 256), (65536, 65535), (65535, 65536),
    ]
    left_32, right_32 = zip(*comp32_tests)
    self.comp32_a = i64(list(left_32))
    self.comp32_b = i64(list(right_32))
def _record(self, name: str, passed: int, total: int,
            failures: Optional[List[Tuple]] = None) -> None:
    """Record a circuit test result.

    Appends a ``CircuitResult`` to ``self.results``. ``failures`` may be
    omitted (or None), in which case an empty list is stored.
    """
    self.results.append(CircuitResult(
        name=name,
        passed=passed,
        total=total,
        failures=failures or []
    ))
# =========================================================================
# BOOLEAN GATES
# =========================================================================
def _test_single_gate(self, pop: Dict, prefix: str, inputs: torch.Tensor,
                      expected: torch.Tensor) -> torch.Tensor:
    """Test single-layer gate (AND, OR, NAND, NOR, IMPLIES).

    Evaluates ``{prefix}.weight`` / ``{prefix}.bias`` for every population
    member on all ``inputs`` rows and returns the per-member count of correct
    outputs. The recorded result describes member 0; individual failures are
    only listed when the population has a single member.
    """
    pop_size = next(iter(pop.values())).shape[0]
    weights = pop[f'{prefix}.weight'].view(pop_size, -1)
    biases = pop[f'{prefix}.bias'].view(pop_size)
    # [num_tests, pop_size]
    out = heaviside(inputs @ weights.T + biases)
    correct = (out == expected.unsqueeze(1)).float().sum(0)

    failures = []
    if pop_size == 1:
        failures = [
            (inp.tolist(), exp.item(), got.item())
            for inp, exp, got in zip(inputs, expected, out[:, 0])
            if exp.item() != got.item()
        ]
    self._record(prefix, int(correct[0].item()), len(expected), failures)
    return correct
def _test_twolayer_gate(self, pop: Dict, prefix: str, inputs: torch.Tensor,
                        expected: torch.Tensor) -> torch.Tensor:
    """Score a two-layer gate (XOR, XNOR, BIIMPLIES) on a truth table.

    The gate is two hidden neurons (``layer1.neuron1`` / ``layer1.neuron2``)
    feeding one output neuron (``layer2``).

    Args:
        pop: Mapping of tensor name -> tensor; the leading axis of every
            tensor is taken to be the population axis.
        prefix: Name prefix of the gate's tensors.
        inputs: One test case per row.
        expected: Expected gate output for each row of ``inputs``.

    Returns:
        Number of correct rows per individual (length ``pop_size``).

    Raises:
        KeyError: If any of the gate's tensors is missing from ``pop``.
    """
    pop_size = next(iter(pop.values())).shape[0]

    # Layer 1: both hidden neurons see the same inputs -> [num_tests, pop_size]
    w1_n1 = pop[f'{prefix}.layer1.neuron1.weight']
    b1_n1 = pop[f'{prefix}.layer1.neuron1.bias']
    w1_n2 = pop[f'{prefix}.layer1.neuron2.weight']
    b1_n2 = pop[f'{prefix}.layer1.neuron2.bias']
    h1 = heaviside(inputs @ w1_n1.view(pop_size, -1).T + b1_n1.view(pop_size))
    h2 = heaviside(inputs @ w1_n2.view(pop_size, -1).T + b1_n2.view(pop_size))
    hidden = torch.stack([h1, h2], dim=-1)  # [num_tests, pop_size, 2]

    # Layer 2: weights must be [pop_size, 2] so they broadcast along the
    # population axis (dim 1) of `hidden`. A [pop_size, 1, 2] view would line
    # the population up with the test axis instead, which fails for
    # pop_size > 1 (or silently mixes individuals when pop_size == num_tests).
    w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, 2)
    b2 = pop[f'{prefix}.layer2.bias'].view(pop_size)
    out = heaviside((hidden * w2).sum(-1) + b2)

    correct = (out == expected.unsqueeze(1)).float().sum(0)

    # Per-case failure details are only collected for a single model.
    failures = []
    if pop_size == 1:
        for inp, exp, got in zip(inputs, expected, out[:, 0]):
            if exp.item() != got.item():
                failures.append((inp.tolist(), exp.item(), got.item()))

    self._record(prefix, int(correct[0].item()), len(expected), failures)
    return correct
def _test_xor_ornand(self, pop: Dict, prefix: str, inputs: torch.Tensor,
                     expected: torch.Tensor) -> torch.Tensor:
    """Score an XOR gate whose hidden neurons are named ``or`` and ``nand``.

    Args:
        pop: Mapping of tensor name -> tensor; the leading axis of every
            tensor is taken to be the population axis.
        prefix: Name prefix of the gate's tensors
            (``layer1.or``, ``layer1.nand``, ``layer2``).
        inputs: One test case per row.
        expected: Expected gate output for each row of ``inputs``.

    Returns:
        Number of correct rows per individual (length ``pop_size``).

    Raises:
        KeyError: If any of the gate's tensors is missing from ``pop``.
    """
    pop_size = next(iter(pop.values())).shape[0]

    # Layer 1 -> [num_tests, pop_size] per hidden neuron
    w_or = pop[f'{prefix}.layer1.or.weight']
    b_or = pop[f'{prefix}.layer1.or.bias']
    w_nand = pop[f'{prefix}.layer1.nand.weight']
    b_nand = pop[f'{prefix}.layer1.nand.bias']
    h_or = heaviside(inputs @ w_or.view(pop_size, -1).T + b_or.view(pop_size))
    h_nand = heaviside(inputs @ w_nand.view(pop_size, -1).T + b_nand.view(pop_size))
    hidden = torch.stack([h_or, h_nand], dim=-1)  # [num_tests, pop_size, 2]

    # Layer 2: view weights as [pop_size, 2] so they broadcast along the
    # population axis of `hidden`; a [pop_size, 1, 2] view aligns the
    # population with the test axis and is only correct when pop_size == 1.
    w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, 2)
    b2 = pop[f'{prefix}.layer2.bias'].view(pop_size)
    out = heaviside((hidden * w2).sum(-1) + b2)

    correct = (out == expected.unsqueeze(1)).float().sum(0)

    # Per-case failure details are only collected for a single model.
    failures = []
    if pop_size == 1:
        for inp, exp, got in zip(inputs, expected, out[:, 0]):
            if exp.item() != got.item():
                failures.append((inp.tolist(), exp.item(), got.item()))

    self._record(prefix, int(correct[0].item()), len(expected), failures)
    return correct
def _test_boolean_gates(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Run every boolean gate test and accumulate per-individual scores.

    Order of evaluation: the five single-neuron gates, NOT, the two-layer
    XNOR / BIIMPLIES gates, then XOR.

    Args:
        pop: Mapping of tensor name -> tensor; the leading axis of every
            tensor is taken to be the population axis.
        debug: When true, print a section header and one line per gate.

    Returns:
        ``(scores, total)``: correct-case count per individual and the
        number of cases run.
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0

    def report_last():
        # Print the most recently recorded result (debug mode only).
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")

    if debug:
        print("\n=== BOOLEAN GATES ===")

    # One-neuron gates over the 2-input truth table.
    for gate in ('and', 'or', 'nand', 'nor', 'implies'):
        scores += self._test_single_gate(pop, f'boolean.{gate}', self.tt2, self.expected[gate])
        total += 4
        report_last()

    # NOT has a single input, so it is evaluated inline on its own inputs.
    not_w = pop['boolean.not.weight']
    not_b = pop['boolean.not.bias']
    not_out = heaviside(self.not_inputs @ not_w.view(pop_size, -1).T + not_b.view(pop_size))
    not_correct = (not_out == self.expected['not'].unsqueeze(1)).float().sum(0)
    scores += not_correct
    total += 2
    not_failures = []
    if pop_size == 1:
        not_failures = [
            (row.tolist(), want.item(), got.item())
            for row, want, got in zip(self.not_inputs, self.expected['not'], not_out[:, 0])
            if want.item() != got.item()
        ]
    self._record('boolean.not', int(not_correct[0].item()), 2, not_failures)
    report_last()

    # Two-layer gates; a gate without its own expected table falls back to
    # the XNOR table.
    for gate in ('xnor', 'biimplies'):
        table = self.expected.get(gate, self.expected['xnor'])
        scores += self._test_twolayer_gate(pop, f'boolean.{gate}', self.tt2, table)
        total += 4
        report_last()

    # XOR uses the same neuron1/neuron2 tensor naming as xnor/biimplies.
    scores += self._test_twolayer_gate(pop, 'boolean.xor', self.tt2, self.expected['xor'])
    total += 4
    report_last()

    return scores, total
# =========================================================================
# ARITHMETIC - ADDERS
# =========================================================================
def _eval_xor(self, pop: Dict, prefix: str, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Evaluate an XOR gate built from an OR and a NAND hidden neuron.

    Args:
        pop: Mapping of tensor name -> tensor; the leading axis of every
            tensor is taken to be the population axis.
        prefix: Name prefix of the gate's tensors
            (``layer1.or``, ``layer1.nand``, ``layer2``).
        a, b: Tensors of shape [num_tests] or [num_tests, pop_size].

    Returns:
        Tensor of shape [num_tests, pop_size].
    """
    pop_size = next(iter(pop.values())).shape[0]

    # Give 1-D inputs a population axis so both are [num_tests, pop_size].
    if a.dim() == 1:
        a = a.unsqueeze(1).expand(-1, pop_size)
    if b.dim() == 1:
        b = b.unsqueeze(1).expand(-1, pop_size)

    def neuron(name: str, x: torch.Tensor) -> torch.Tensor:
        # 2-input threshold neuron applied per individual:
        # x is [num_tests, pop_size, 2], result is [num_tests, pop_size].
        weight = pop[f'{prefix}.{name}.weight'].view(pop_size, 2)
        bias = pop[f'{prefix}.{name}.bias'].view(pop_size)
        return heaviside((x * weight).sum(-1) + bias)

    pair = torch.stack([a, b], dim=-1)
    or_out = neuron('layer1.or', pair)
    nand_out = neuron('layer1.nand', pair)
    return neuron('layer2', torch.stack([or_out, nand_out], dim=-1))
def _eval_single_fa(self, pop: Dict, prefix: str,
                    a: torch.Tensor, b: torch.Tensor, cin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Evaluate one full adder assembled from two half adders and an OR.

    Args:
        pop: Mapping of tensor name -> tensor; the leading axis of every
            tensor is taken to be the population axis.
        prefix: Name prefix of the adder's tensors
            (``ha1.*``, ``ha2.*``, ``carry_or``).
        a, b, cin: Tensors of shape [num_tests] or [num_tests, pop_size].

    Returns:
        sum_out, cout: Both of shape [num_tests, pop_size].
    """
    pop_size = next(iter(pop.values())).shape[0]

    # Give 1-D inputs a population axis so all are [num_tests, pop_size].
    a, b, cin = (
        t.unsqueeze(1).expand(-1, pop_size) if t.dim() == 1 else t
        for t in (a, b, cin)
    )

    def gate2(name: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Single 2-input threshold neuron evaluated per individual.
        stacked = torch.stack([x, y], dim=-1)  # [num_tests, pop_size, 2]
        weight = pop[f'{prefix}.{name}.weight'].view(pop_size, 2)
        bias = pop[f'{prefix}.{name}.bias'].view(pop_size)
        return heaviside((stacked * weight).sum(-1) + bias)

    # First half adder works on a and b.
    ha1_sum = self._eval_xor(pop, f'{prefix}.ha1.sum', a, b)
    ha1_carry = gate2('ha1.carry', a, b)

    # Second half adder folds the carry-in into the partial sum.
    ha2_sum = self._eval_xor(pop, f'{prefix}.ha2.sum', ha1_sum, cin)
    ha2_carry = gate2('ha2.carry', ha1_sum, cin)

    # Carry out is the combination of the two half-adder carries.
    cout = gate2('carry_or', ha1_carry, ha2_carry)
    return ha2_sum, cout
def _test_halfadder(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test the half adder's sum (XOR) and carry (AND) outputs.

    Args:
        pop: Mapping of tensor name -> tensor; the leading axis of every
            tensor is taken to be the population axis.
        debug: When true, print a section header and one line per output.

    Returns:
        ``(scores, total)``: correct-case count per individual and the
        number of cases run (4 per output).
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print("\n=== HALF ADDER ===")

    # (tester, tensor prefix, key into self.expected); sum first, then carry.
    outputs = (
        (self._test_xor_ornand, 'arithmetic.halfadder.sum', 'ha_sum'),
        (self._test_single_gate, 'arithmetic.halfadder.carry', 'ha_carry'),
    )
    for tester, name, key in outputs:
        scores += tester(pop, name, self.tt2, self.expected[key])
        total += 4
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")

    return scores, total
def _test_fulladder(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test the full adder on all 8 (a, b, cin) input combinations.

    Sum and carry-out are recorded as two separate results.

    Args:
        pop: Mapping of tensor name -> tensor; the leading axis of every
            tensor is taken to be the population axis.
        debug: When true, print a section header and both result lines.

    Returns:
        ``(scores, 16)``: combined correct count per individual over the
        8 sum cases and 8 carry cases.
    """
    pop_size = next(iter(pop.values())).shape[0]
    if debug:
        print("\n=== FULL ADDER ===")

    # Columns of the 3-input truth table.
    a, b, cin = self.tt3[:, 0], self.tt3[:, 1], self.tt3[:, 2]
    sum_out, cout = self._eval_single_fa(pop, 'arithmetic.fulladder', a, b, cin)

    want_sum = self.expected['fa_sum']
    want_cout = self.expected['fa_cout']
    sum_correct = (sum_out == want_sum.unsqueeze(1)).float().sum(0)
    cout_correct = (cout == want_cout.unsqueeze(1)).float().sum(0)

    def case(i: int) -> list:
        # Input triple for test row i, as plain Python numbers.
        return [a[i].item(), b[i].item(), cin[i].item()]

    failures_sum = []
    failures_cout = []
    # Per-case failure details are only collected for a single model.
    if pop_size == 1:
        for i in range(8):
            got_sum = sum_out[i, 0].item()
            if got_sum != want_sum[i].item():
                failures_sum.append((case(i), want_sum[i].item(), got_sum))
            got_cout = cout[i, 0].item()
            if got_cout != want_cout[i].item():
                failures_cout.append((case(i), want_cout[i].item(), got_cout))

    self._record('arithmetic.fulladder.sum', int(sum_correct[0].item()), 8, failures_sum)
    self._record('arithmetic.fulladder.cout', int(cout_correct[0].item()), 8, failures_cout)
    if debug:
        for r in self.results[-2:]:
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")

    return sum_correct + cout_correct, 16
def _test_ripplecarry(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test an N-bit ripple carry adder built from ``fa0 .. fa{bits-1}``.

    Widths up to 4 bits are tested exhaustively; wider adders use a fixed
    sample of 8-bit edge values plus pairs that sum to 255.

    Args:
        pop: Mapping of tensor name -> tensor; the leading axis of every
            tensor is taken to be the population axis.
        bits: Adder width; selects ``arithmetic.ripplecarry{bits}bit``.
        debug: When true, print a section header and the result line.

    Returns:
        ``(correct, num_tests)``: correct-case count per individual and
        the number of cases run.
    """
    pop_size = next(iter(pop.values())).shape[0]
    if debug:
        print(f"\n=== RIPPLE CARRY {bits}-BIT ===")

    prefix = f'arithmetic.ripplecarry{bits}bit'
    max_val = 1 << bits

    if bits <= 4:
        # Exhaustive for small widths: every (a, b) pair.
        span = torch.arange(max_val, device=self.device)
        grid_a, grid_b = torch.meshgrid(span, span, indexing='ij')
        a_vals = grid_a.flatten()
        b_vals = grid_b.flatten()
        num_tests = min(max_val * max_val, 65536)
    else:
        # Strategic sampling for 8-bit
        edge_vals = [0, 1, 2, 127, 128, 254, 255]
        pairs = [(x, y) for x in edge_vals for y in edge_vals]
        pairs.extend((i, 255 - i) for i in range(0, 256, 16))
        pairs = list(set(pairs))
        a_vals = torch.tensor([p[0] for p in pairs], device=self.device)
        b_vals = torch.tensor([p[1] for p in pairs], device=self.device)
        num_tests = len(pairs)

    def column(vals: torch.Tensor, k: int) -> torch.Tensor:
        # Bit k (0 = LSB) of every value, broadcast to [num_tests, pop_size].
        return ((vals >> k) & 1).float().unsqueeze(1).expand(-1, pop_size)

    # Ripple from the LSB upward, threading the carry through each stage.
    carry = torch.zeros(len(a_vals), pop_size, device=self.device)
    lsb_first = []
    for bit in range(bits):
        s, carry = self._eval_single_fa(
            pop, f'{prefix}.fa{bit}',
            column(a_vals, bit),
            column(b_vals, bit),
            carry,
        )
        lsb_first.append(s)

    # Weight each sum bit by its power of two, accumulating MSB first.
    result = torch.zeros(len(a_vals), pop_size, device=self.device)
    for bit in reversed(range(bits)):
        result += lsb_first[bit] * (1 << bit)

    # Expected sum wraps at the adder width (final carry is discarded).
    expected = ((a_vals + b_vals) & (max_val - 1)).unsqueeze(1).expand(-1, pop_size).float()
    correct = (result == expected).float().sum(0)

    # Failure details: single model only, limited to the first 100 cases.
    failures = []
    if pop_size == 1:
        failures = [
            (
                [int(a_vals[i].item()), int(b_vals[i].item())],
                int(expected[i, 0].item()),
                int(result[i, 0].item()),
            )
            for i in range(min(len(a_vals), 100))
            if result[i, 0].item() != expected[i, 0].item()
        ]

    self._record(prefix, int(correct[0].item()), num_tests, failures)
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return correct, num_tests
# =========================================================================
# 3-OPERAND ADDER
# =========================================================================
def _test_add3(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test the 3-operand 8-bit adder (A + B + C, wrapping at 8 bits).

    The circuit is evaluated as two chained ripple adders: ``stage1``
    computes A + B, ``stage2`` adds C to that result.

    Args:
        pop: Mapping of tensor name -> tensor; the leading axis of every
            tensor is taken to be the population axis.
        debug: When true, print a header, the result line and up to five
            failing cases.

    Returns:
        ``(correct, num_tests)``: correct-case count per individual and
        the number of cases run.
    """
    pop_size = next(iter(pop.values())).shape[0]
    if debug:
        print(f"\n=== 3-OPERAND ADDER ===")

    prefix = 'arithmetic.add3_8bit'
    bits = 8

    # Strategic cases: small values, edge values, then hand-picked sums
    # around the 255/256 overflow boundary.
    small = [0, 1, 2]
    edge = [0, 1, 127, 128, 254, 255]
    test_cases = [(a, b, c) for a in small for b in small for c in small]
    test_cases += [(a, b, c) for a in edge for b in edge for c in edge]
    test_cases += [
        (15, 27, 33),    # Example from roadmap: 15 + 27 + 33 = 75
        (100, 100, 55),  # = 255 (exact fit)
        (100, 100, 56),  # = 256 -> 0 (overflow)
        (85, 85, 85),    # = 255 (exact fit)
        (86, 85, 85),    # = 256 -> 0 (overflow)
    ]
    test_cases = list(set(test_cases))  # drop duplicates

    a_vals = torch.tensor([t[0] for t in test_cases], device=self.device)
    b_vals = torch.tensor([t[1] for t in test_cases], device=self.device)
    c_vals = torch.tensor([t[2] for t in test_cases], device=self.device)
    num_tests = len(test_cases)

    def column(vals: torch.Tensor, k: int) -> torch.Tensor:
        # Bit k (0 = LSB) of every value, broadcast to [num_tests, pop_size].
        return ((vals >> k) & 1).float().unsqueeze(1).expand(-1, pop_size)

    # Stage 1: A + B, LSB first.
    carry = torch.zeros(num_tests, pop_size, device=self.device)
    stage1 = []
    for bit in range(bits):
        s, carry = self._eval_single_fa(
            pop, f'{prefix}.stage1.fa{bit}',
            column(a_vals, bit),
            column(b_vals, bit),
            carry,
        )
        stage1.append(s)

    # Stage 2: (A + B) + C with a fresh carry chain; the stage-1 carry out
    # is not fed forward.
    carry = torch.zeros(num_tests, pop_size, device=self.device)
    stage2 = []
    for bit, partial in enumerate(stage1):
        s, carry = self._eval_single_fa(
            pop, f'{prefix}.stage2.fa{bit}',
            partial,  # already [num_tests, pop_size]
            column(c_vals, bit),
            carry,
        )
        stage2.append(s)

    # Weight each result bit by its power of two, accumulating MSB first.
    result = torch.zeros(num_tests, pop_size, device=self.device)
    for bit in reversed(range(bits)):
        result += stage2[bit] * (1 << bit)

    # Expected (8-bit wrap)
    expected = ((a_vals + b_vals + c_vals) & 0xFF).unsqueeze(1).expand(-1, pop_size).float()
    correct = (result == expected).float().sum(0)

    # Failure details: single model only, limited to the first 100 cases.
    failures = []
    if pop_size == 1:
        failures = [
            (
                [int(a_vals[i].item()), int(b_vals[i].item()), int(c_vals[i].item())],
                int(expected[i, 0].item()),
                int(result[i, 0].item()),
            )
            for i in range(min(num_tests, 100))
            if result[i, 0].item() != expected[i, 0].item()
        ]

    self._record(prefix, int(correct[0].item()), num_tests, failures)
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        if failures:
            for inp, exp, got in failures[:5]:
                print(f" FAIL: {inp[0]} + {inp[1]} + {inp[2]} = {exp}, got {got}")
    return correct, num_tests
# =========================================================================
# ORDER OF OPERATIONS (A + B × C)
# =========================================================================
def _test_expr_add_mul(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test the A + B × C expression circuit (multiply before add).

    The circuit is evaluated in three phases:
      1. ``mul.mask.s{stage}.b{bit}``: partial product B[bit] AND C[stage].
      2. ``mul.acc.s{stage}.fa{bit}``: shift-and-add accumulation of the
         partial products (stage 0 seeds the accumulator).
      3. ``add.fa{bit}``: final ripple add of A and the product.
    All arithmetic wraps at 8 bits.

    Args:
        pop: Mapping of tensor name -> tensor; the leading axis of every
            tensor is taken to be the population axis.
        debug: When true, print a header, the result line and up to five
            failing cases.

    Returns:
        ``(correct, num_tests)``: correct-case count per individual and
        the number of cases run.
    """
    pop_size = next(iter(pop.values())).shape[0]
    if debug:
        print(f"\n=== ORDER OF OPERATIONS (A + B × C) ===")

    prefix = 'arithmetic.expr_add_mul'
    bits = 8

    test_cases = [
        # Specific examples from roadmap
        (5, 3, 2),     # 5 + 3 × 2 = 5 + 6 = 11
        (10, 4, 3),    # 10 + 4 × 3 = 10 + 12 = 22
        (1, 10, 10),   # 1 + 10 × 10 = 1 + 100 = 101
        (0, 15, 17),   # 0 + 15 × 17 = 255
        (1, 15, 17),   # 1 + 15 × 17 = 256 -> 0 (overflow)
        (100, 5, 5),   # 100 + 5 × 5 = 100 + 25 = 125
        # Edge cases
        (0, 0, 0),     # 0 + 0 × 0 = 0
        (255, 0, 0),   # 255 + 0 × 0 = 255
        (0, 255, 1),   # 0 + 255 × 1 = 255
        (0, 1, 255),   # 0 + 1 × 255 = 255
        (1, 1, 1),     # 1 + 1 × 1 = 2
        (0, 16, 16),   # 0 + 16 × 16 = 256 -> 0 (overflow)
    ]
    # Systematic small values
    test_cases += [(a, b, c) for a in [0, 1, 5, 10] for b in [0, 1, 2, 3] for c in [0, 1, 2, 3]]
    test_cases = list(set(test_cases))  # remove duplicates

    a_vals = torch.tensor([t[0] for t in test_cases], device=self.device)
    b_vals = torch.tensor([t[1] for t in test_cases], device=self.device)
    c_vals = torch.tensor([t[2] for t in test_cases], device=self.device)
    num_tests = len(test_cases)

    def column(vals: torch.Tensor, k: int) -> torch.Tensor:
        # Bit k (0 = LSB) of every value, broadcast to [num_tests, pop_size].
        return ((vals >> k) & 1).float().unsqueeze(1).expand(-1, pop_size)

    def zeros(*shape: int) -> torch.Tensor:
        return torch.zeros(*shape, device=self.device)

    # Phase 1: partial products, indexed [stage, tests, pop, bit] with
    # stage = bit position of C and bit = bit position of B (0 = LSB).
    masks = zeros(8, num_tests, pop_size, 8)
    for stage in range(8):
        c_col = column(c_vals, stage)
        for bit in range(8):
            w = pop.get(f'{prefix}.mul.mask.s{stage}.b{bit}.weight')
            bias = pop.get(f'{prefix}.mul.mask.s{stage}.b{bit}.bias')
            if w is None or bias is None:
                continue  # gate absent from pop: this partial product stays 0
            pair = torch.stack([column(b_vals, bit), c_col], dim=-1)  # [tests, pop, 2]
            masks[stage, :, :, bit] = heaviside(
                torch.einsum('tpi,pi->tp', pair, w.squeeze(-1)) + bias.squeeze(-1)
            )

    # Phase 2: acc[0] = mask[0]; acc[s] = acc[s-1] + (mask[s] << s).
    acc = masks[0].clone()  # [tests, pop, 8]
    for stage in range(1, 8):
        # Left shift by `stage`: low bits are zero, high bits fall off.
        shifted = zeros(num_tests, pop_size, 8)
        shifted[:, :, stage:] = masks[stage, :, :, :8 - stage]

        carry = zeros(num_tests, pop_size)
        summed = []
        for bit in range(8):
            s, carry = self._eval_single_fa(
                pop, f'{prefix}.mul.acc.s{stage}.fa{bit}',
                acc[:, :, bit],
                shifted[:, :, bit],
                carry,
            )
            summed.append(s)
        acc = torch.stack(summed, dim=-1)

    # Phase 3: A + product, LSB first.
    carry = zeros(num_tests, pop_size)
    out_bits = []
    for bit in range(8):
        s, carry = self._eval_single_fa(
            pop, f'{prefix}.add.fa{bit}',
            column(a_vals, bit),
            acc[:, :, bit],
            carry,
        )
        out_bits.append(s)

    # Weight each result bit by its power of two, accumulating MSB first.
    result = zeros(num_tests, pop_size)
    for bit in reversed(range(bits)):
        result += out_bits[bit] * (1 << bit)

    # Expected: A + (B × C), with 8-bit wrap
    expected = ((a_vals + b_vals * c_vals) & 0xFF).unsqueeze(1).expand(-1, pop_size).float()
    correct = (result == expected).float().sum(0)

    # Failure details: single model only, limited to the first 100 cases.
    failures = []
    if pop_size == 1:
        failures = [
            (
                [int(a_vals[i].item()), int(b_vals[i].item()), int(c_vals[i].item())],
                int(expected[i, 0].item()),
                int(result[i, 0].item()),
            )
            for i in range(min(num_tests, 100))
            if result[i, 0].item() != expected[i, 0].item()
        ]

    self._record(prefix, int(correct[0].item()), num_tests, failures)
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        if failures:
            for inp, exp, got in failures[:5]:
                print(f" FAIL: {inp[0]} + {inp[1]} × {inp[2]} = {exp}, got {got}")
    return correct, num_tests
# =========================================================================
# COMPARATORS
# =========================================================================
def _test_comparator(self, pop: Dict, name: str, op: Callable[[int, int], bool],
                     debug: bool) -> Tuple[torch.Tensor, int]:
    """Test a single-neuron 8-bit comparator against a Python predicate.

    Args:
        pop: Mapping of tensor name -> tensor; the leading axis of every
            tensor is taken to be the population axis.
        name: Circuit name; tensors are read from ``arithmetic.{name}``.
        op: Reference predicate applied to each (a, b) test pair.
        debug: When true, print the result line.

    Returns:
        ``(correct, num_tests)``: correct-case count per individual and
        the number of test pairs.

    Raises:
        KeyError: If the comparator's weight or bias is missing from ``pop``.
    """
    pop_size = next(iter(pop.values())).shape[0]
    prefix = f'arithmetic.{name}'
    lhs, rhs = self.comp_a, self.comp_b
    num_tests = len(lhs)

    # Reference answers from the pre-computed test pairs.
    expected = torch.tensor(
        [1.0 if op(x.item(), y.item()) else 0.0 for x, y in zip(lhs, rhs)],
        device=self.device,
    )

    def to_bits(vals: torch.Tensor) -> torch.Tensor:
        # [num_tests, 8], most significant bit in column 0.
        return torch.stack([((vals >> (7 - i)) & 1).float() for i in range(8)], dim=1)

    # The neuron sees A's 8 bits followed by B's 8 bits.
    inputs = torch.cat([to_bits(lhs), to_bits(rhs)], dim=1)

    weight = pop[f'{prefix}.weight']
    bias = pop[f'{prefix}.bias']
    out = heaviside(inputs @ weight.view(pop_size, -1).T + bias.view(pop_size))
    correct = (out == expected.unsqueeze(1)).float().sum(0)

    # Per-case failure details are only collected for a single model.
    failures = []
    if pop_size == 1:
        failures = [
            (
                [int(lhs[i].item()), int(rhs[i].item())],
                expected[i].item(),
                out[i, 0].item(),
            )
            for i in range(num_tests)
            if out[i, 0].item() != expected[i].item()
        ]

    self._record(prefix, int(correct[0].item()), num_tests, failures)
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return correct, num_tests
def _test_comparators(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test all 8-bit comparators (GT, LT, GE, LE and two-layer EQ).

    Circuits whose tensors are missing from ``pop`` are skipped and
    contribute nothing to the score or the total.

    Args:
        pop: Mapping of tensor name -> tensor; the leading axis of every
            tensor is taken to be the population axis.
        debug: When true, print a section header and one line per circuit.

    Returns:
        ``(scores, total)``: correct-case count per individual and the
        number of cases run.
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print("\n=== COMPARATORS ===")

    # Single-neuron comparators. Equality is a two-layer circuit and is
    # handled separately below.
    single_layer = [
        ('greaterthan8bit', lambda a, b: a > b),
        ('lessthan8bit', lambda a, b: a < b),
        ('greaterorequal8bit', lambda a, b: a >= b),
        ('lessorequal8bit', lambda a, b: a <= b),
    ]
    for name, op in single_layer:
        try:
            s, t = self._test_comparator(pop, name, op, debug)
        except KeyError:
            continue  # Circuit not present
        scores += s
        total += t

    # Two-layer equality circuit: (A >= B) AND (A <= B).
    prefix = 'arithmetic.equality8bit'
    try:
        w_geq = pop[f'{prefix}.layer1.geq.weight'].view(pop_size, -1)
        b_geq = pop[f'{prefix}.layer1.geq.bias'].view(pop_size)
        w_leq = pop[f'{prefix}.layer1.leq.weight'].view(pop_size, -1)
        b_leq = pop[f'{prefix}.layer1.leq.bias'].view(pop_size)
        # [pop_size, 2] broadcasts along the population axis of `hidden`;
        # a [pop_size, 1, 2] view would align with the test axis instead
        # and is only correct when pop_size == 1.
        w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, 2)
        b2 = pop[f'{prefix}.layer2.bias'].view(pop_size)
    except KeyError:
        return scores, total  # Circuit not present

    num_tests = len(self.comp_a)
    expected = torch.tensor([1.0 if a.item() == b.item() else 0.0
                             for a, b in zip(self.comp_a, self.comp_b)],
                            device=self.device)

    # The layer-1 neurons see A's 8 bits (MSB first) followed by B's 8 bits.
    a_bits = torch.stack([((self.comp_a >> (7 - i)) & 1).float() for i in range(8)], dim=1)
    b_bits = torch.stack([((self.comp_b >> (7 - i)) & 1).float() for i in range(8)], dim=1)
    inputs = torch.cat([a_bits, b_bits], dim=1)

    # Layer 1: geq and leq -> [num_tests, pop_size] each
    h_geq = heaviside(inputs @ w_geq.T + b_geq)
    h_leq = heaviside(inputs @ w_leq.T + b_leq)
    hidden = torch.stack([h_geq, h_leq], dim=-1)  # [num_tests, pop_size, 2]

    # Layer 2: AND
    out = heaviside((hidden * w2).sum(-1) + b2)
    correct = (out == expected.unsqueeze(1)).float().sum(0)

    # Per-case failure details are only collected for a single model.
    failures = []
    if pop_size == 1:
        for i in range(num_tests):
            if out[i, 0].item() != expected[i].item():
                failures.append((
                    [int(self.comp_a[i].item()), int(self.comp_b[i].item())],
                    expected[i].item(),
                    out[i, 0].item()
                ))

    self._record(prefix, int(correct[0].item()), num_tests, failures)
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    scores += correct
    total += num_tests
    return scores, total
def _test_comparators_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test N-bit comparator circuits (GT, LT, GE, LE, EQ).

    Two circuit layouts are supported:
      * ``bits <= 16``: each of GT/GE/LT/LE is a single neuron over the
        concatenated bits of A and B; EQ is a two-layer geq/leq/AND
        circuit. Circuits missing from ``pop`` are skipped.
      * ``bits > 16``: a byte-wise cascade under ``arithmetic.cmp{bits}bit``.
        Each byte yields gt/lt/eq, stage ``k`` of the cascade fires when all
        more-significant bytes are equal and byte ``k`` decides, and the
        final GT/LT neurons combine the stages. GE and LE are derived from
        LT and GT through a NOT neuron followed by a 1-input neuron.
        Missing tensors raise ``KeyError`` in this layout.

    Args:
        pop: Mapping of tensor name -> tensor; the leading axis of every
            tensor is taken to be the population axis.
        bits: Comparator width. 32 uses the dedicated 32-bit test pairs;
            other widths reuse the 8-bit pairs (clamped to 0..65535 for 16).
        debug: When true, print a section header and one line per circuit.

    Returns:
        ``(scores, total)``: correct-case count per individual and the
        number of cases run.

    Raises:
        KeyError: In the byte-cascade layout, if a required tensor is missing.
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print(f"\n=== {bits}-BIT COMPARATORS ===")

    if bits == 32:
        comp_a = self.comp32_a
        comp_b = self.comp32_b
    elif bits == 16:
        comp_a = self.comp_a.clamp(0, 65535)
        comp_b = self.comp_b.clamp(0, 65535)
    else:
        comp_a = self.comp_a
        comp_b = self.comp_b
    num_tests = len(comp_a)

    def shared_neuron(x: torch.Tensor, key: str) -> torch.Tensor:
        # Neuron whose input x [num_tests, fan_in] is the same for every
        # individual -> [num_tests, pop_size].
        w = pop[f'{key}.weight'].view(pop_size, -1)
        b = pop[f'{key}.bias'].view(pop_size)
        return heaviside(x @ w.T + b)

    def per_individual_neuron(x: torch.Tensor, key: str) -> torch.Tensor:
        # Neuron whose input x [num_tests, pop_size, fan_in] already differs
        # per individual; weights [pop_size, fan_in] broadcast along the
        # population axis -> [num_tests, pop_size].
        w = pop[f'{key}.weight'].view(pop_size, -1)
        b = pop[f'{key}.bias'].view(pop_size)
        return heaviside((x * w).sum(-1) + b)

    def unary_neuron(x: torch.Tensor, key: str) -> torch.Tensor:
        # 1-input neuron applied element-wise to x [num_tests, pop_size].
        # A matmul here (x.unsqueeze(-1) @ w.T) would produce
        # [num_tests, pop_size, pop_size] whenever pop_size > 1.
        w = pop[f'{key}.weight'].view(pop_size)
        b = pop[f'{key}.bias'].view(pop_size)
        return heaviside(x * w + b)

    def grade(name: str, out: torch.Tensor, op) -> torch.Tensor:
        # Compare a circuit output with the Python predicate, record the
        # result and return the per-individual correct count.
        expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0
                                 for a, b in zip(comp_a, comp_b)], device=self.device)
        correct = (out == expected.unsqueeze(1)).float().sum(0)
        failures = []
        # Per-case failure details are only collected for a single model.
        if pop_size == 1:
            for i in range(num_tests):
                if out[i, 0].item() != expected[i].item():
                    failures.append(([int(comp_a[i].item()), int(comp_b[i].item())],
                                     expected[i].item(), out[i, 0].item()))
        self._record(name, int(correct[0].item()), num_tests, failures)
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        return correct

    if bits <= 16:
        # Inputs: A's bits (MSB first) followed by B's bits.
        a_bits = torch.stack([((comp_a >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
        b_bits = torch.stack([((comp_b >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
        inputs = torch.cat([a_bits, b_bits], dim=1)

        comparators = [
            (f'arithmetic.greaterthan{bits}bit', lambda a, b: a > b),
            (f'arithmetic.greaterorequal{bits}bit', lambda a, b: a >= b),
            (f'arithmetic.lessthan{bits}bit', lambda a, b: a < b),
            (f'arithmetic.lessorequal{bits}bit', lambda a, b: a <= b),
        ]
        for name, op in comparators:
            try:
                out = shared_neuron(inputs, name)
            except KeyError:
                continue  # Circuit not present
            scores += grade(name, out, op)
            total += num_tests

        # Two-layer equality: (A >= B) AND (A <= B).
        prefix = f'arithmetic.equality{bits}bit'
        try:
            h_geq = shared_neuron(inputs, f'{prefix}.layer1.geq')
            h_leq = shared_neuron(inputs, f'{prefix}.layer1.leq')
            hidden = torch.stack([h_geq, h_leq], dim=-1)  # [num_tests, pop_size, 2]
            eq_out = per_individual_neuron(hidden, f'{prefix}.layer2')
        except KeyError:
            pass  # Circuit not present
        else:
            scores += grade(prefix, eq_out, lambda a, b: a == b)
            total += num_tests
    else:
        num_bytes = bits // 8
        prefix = f"arithmetic.cmp{bits}bit"

        # Per-byte gt / lt / eq; byte 0 holds the most significant bits.
        byte_gt = []
        byte_lt = []
        byte_eq = []
        for idx in range(num_bytes):
            start_bit = idx * 8
            a_byte = torch.stack([((comp_a >> (bits - 1 - start_bit - i)) & 1).float() for i in range(8)], dim=1)
            b_byte = torch.stack([((comp_b >> (bits - 1 - start_bit - i)) & 1).float() for i in range(8)], dim=1)
            byte_input = torch.cat([a_byte, b_byte], dim=1)
            stem = f'{prefix}.byte{idx}'
            byte_gt.append(shared_neuron(byte_input, f'{stem}.gt'))
            byte_lt.append(shared_neuron(byte_input, f'{stem}.lt'))
            h_geq = shared_neuron(byte_input, f'{stem}.eq.geq')
            h_leq = shared_neuron(byte_input, f'{stem}.eq.leq')
            byte_eq.append(per_individual_neuron(torch.stack([h_geq, h_leq], dim=-1), f'{stem}.eq.and'))

        # Cascade: stage k = (all more-significant bytes equal) AND byte k decides.
        cascade_gt = []
        cascade_lt = []
        for stage in range(num_bytes):
            if stage == 0:
                cascade_gt.append(byte_gt[0])
                cascade_lt.append(byte_lt[0])
                continue
            eq_stack = torch.stack(byte_eq[:stage], dim=-1)  # [num_tests, pop_size, stage]
            for kind, per_byte, cascade in (('gt', byte_gt, cascade_gt),
                                            ('lt', byte_lt, cascade_lt)):
                stem = f'{prefix}.cascade.{kind}.stage{stage}'
                all_eq = per_individual_neuron(eq_stack, f'{stem}.all_eq')
                stage_inp = torch.stack([all_eq, per_byte[stage]], dim=-1)
                cascade.append(per_individual_neuron(stage_inp, f'{stem}.and'))

        # Final neurons combine the cascade stages.
        gt_name = f'arithmetic.greaterthan{bits}bit'
        lt_name = f'arithmetic.lessthan{bits}bit'
        ge_name = f'arithmetic.greaterorequal{bits}bit'
        le_name = f'arithmetic.lessorequal{bits}bit'
        eq_name = f'arithmetic.equality{bits}bit'

        gt_out = per_individual_neuron(torch.stack(cascade_gt, dim=-1), gt_name)
        lt_out = per_individual_neuron(torch.stack(cascade_lt, dim=-1), lt_name)
        # GE = f(NOT LT), LE = f(NOT GT)
        ge_out = unary_neuron(unary_neuron(lt_out, f'{ge_name}.not_lt'), ge_name)
        le_out = unary_neuron(unary_neuron(gt_out, f'{le_name}.not_gt'), le_name)
        eq_out = per_individual_neuron(torch.stack(byte_eq, dim=-1), eq_name)

        for name, out, op in [
            (gt_name, gt_out, lambda a, b: a > b),
            (ge_name, ge_out, lambda a, b: a >= b),
            (lt_name, lt_out, lambda a, b: a < b),
            (le_name, le_out, lambda a, b: a <= b),
            (eq_name, eq_out, lambda a, b: a == b),
        ]:
            scores += grade(name, out, op)
            total += num_tests

    return scores, total
def _test_subtractor_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test the N-bit subtractor circuit (A - B, modulo 2**bits).

    The circuit is evaluated as A + NOT(B) + 1: every B bit goes through its
    ``not_b`` gate and the inverted bits feed a ripple chain of full adders
    (``fa0`` receives the least-significant bit) seeded with carry-in 1.

    Args:
        pop: Mapping of tensor name to weights. The population size is read
            from the leading dimension of the first tensor.
        bits: Operand width. 32 selects a dedicated 32-bit test set; any other
            width uses all pairs drawn from {0, 1, 127, 128, 255}.
        debug: Print the section header and the PASS/FAIL line when true.

    Returns:
        ``(correct, num_tests)`` where ``correct`` is a float tensor of shape
        ``[pop_size]`` counting the matching test pairs of each individual.

    Raises:
        KeyError: If a required ``arithmetic.sub{bits}bit.*`` tensor is missing.
    """
    pop_size = next(iter(pop.values())).shape[0]
    if debug:
        print(f"\n=== {bits}-BIT SUBTRACTOR ===")
    prefix = f'arithmetic.sub{bits}bit'
    mask = (1 << bits) - 1
    if bits == 32:
        test_pairs = [
            (1000, 500), (5000, 3000), (1000000, 500000),
            (0xFFFFFFFF, 1), (0x80000000, 1), (100, 100),
            (0, 0), (1, 0), (0, 1), (256, 255),
            (0xDEADBEEF, 0xCAFEBABE), (1000000000, 999999999),
        ]
    else:
        edge_vals = [0, 1, 127, 128, 255]
        test_pairs = [(a, b) for a in edge_vals for b in edge_vals]
    a_vals = torch.tensor([p[0] for p in test_pairs], device=self.device, dtype=torch.long)
    b_vals = torch.tensor([p[1] for p in test_pairs], device=self.device, dtype=torch.long)
    num_tests = len(test_pairs)

    # Invert B per individual. ``not_b.bit{pos}`` is indexed MSB-first, and each
    # plane keeps the full [num_tests, pop_size] shape so that every individual
    # is scored with its own inverter weights rather than individual 0's.
    not_b_planes = []
    for pos in range(bits):
        b_plane = ((b_vals >> (bits - 1 - pos)) & 1).float().unsqueeze(1)
        w = pop[f'{prefix}.not_b.bit{pos}.weight'].view(pop_size, -1)
        b = pop[f'{prefix}.not_b.bit{pos}.bias'].view(pop_size)
        not_b_planes.append(heaviside(b_plane @ w.T + b))

    # Ripple from the least-significant bit upward. The result is accumulated
    # as int64 because float32 cannot represent every 32-bit value exactly.
    carry = torch.ones(num_tests, pop_size, device=self.device)
    result = torch.zeros(num_tests, pop_size, device=self.device, dtype=torch.long)
    for bit in range(bits):
        a_plane = ((a_vals >> bit) & 1).float().unsqueeze(1).expand(-1, pop_size)
        s, carry = self._eval_single_fa(
            pop, f'{prefix}.fa{bit}',
            a_plane,
            not_b_planes[bits - 1 - bit],
            carry
        )
        result += s.long() * (1 << bit)

    expected = ((a_vals - b_vals) & mask).unsqueeze(1).expand(-1, pop_size)
    correct = (result == expected).float().sum(0)

    failures = []
    if pop_size == 1:
        # Failure details are collected for at most the first 20 test pairs.
        for i in range(min(num_tests, 20)):
            want = int(expected[i, 0].item())
            got = int(result[i, 0].item())
            if got != want:
                failures.append((
                    [int(a_vals[i].item()), int(b_vals[i].item())],
                    want,
                    got
                ))
    self._record(prefix, int(correct[0].item()), num_tests, failures)
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return correct, num_tests
def _test_bitwise_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test the N-bit bitwise ALU operations (AND, OR, XOR, NOT).

    Every operation is built from one gate per bit position (``bit0`` is the
    most-significant bit). AND/OR/NOT use a single gate per bit; XOR uses the
    two-layer ``layer1.or`` / ``layer1.nand`` -> ``layer2`` structure. Each
    individual of the population is evaluated with its own weights.

    Args:
        pop: Mapping of tensor name to weights. The population size is read
            from the leading dimension of the first tensor.
        bits: Operand width. 32 selects a dedicated 32-bit test set; any other
            width uses three 8-bit operand pairs.
        debug: Print the section header and per-operation PASS/FAIL/SKIP lines.

    Returns:
        ``(scores, total)`` where ``scores`` is a float tensor of shape
        ``[pop_size]`` and ``total`` is the number of test cases that ran.
        An operation whose tensors are missing is skipped and adds nothing.
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print(f"\n=== {bits}-BIT BITWISE OPS ===")
    if bits == 32:
        test_pairs = [
            (0xAAAAAAAA, 0x55555555), (0xFFFFFFFF, 0x00000000),
            (0x12345678, 0x87654321), (0xDEADBEEF, 0xCAFEBABE),
            (0x0F0F0F0F, 0xF0F0F0F0), (0, 0), (0xFFFFFFFF, 0xFFFFFFFF),
        ]
    else:
        test_pairs = [(0xAA, 0x55), (0xFF, 0x00), (0x0F, 0xF0)]
    a_vals = torch.tensor([p[0] for p in test_pairs], device=self.device, dtype=torch.long)
    b_vals = torch.tensor([p[1] for p in test_pairs], device=self.device, dtype=torch.long)
    num_tests = len(test_pairs)
    mask = (1 << bits) - 1

    def bit_plane(vals: torch.Tensor, pos: int) -> torch.Tensor:
        """Return bit ``pos`` (MSB-first) of every value as a [num_tests] float tensor."""
        return ((vals >> (bits - 1 - pos)) & 1).float()

    def tally(name: str, planes: List[torch.Tensor], expected_vals: List[int]) -> torch.Tensor:
        """Pack MSB-first [num_tests, pop_size] bit planes, score and record them."""
        packed = torch.zeros(num_tests, pop_size, device=self.device, dtype=torch.long)
        for pos, plane in enumerate(planes):
            packed += plane.long() * (1 << (bits - 1 - pos))
        expected = torch.tensor(expected_vals, device=self.device, dtype=torch.long)
        correct = (packed == expected.unsqueeze(1)).float().sum(0)
        self._record(name, int(correct[0].item()), num_tests, [])
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        return correct

    ops = [
        ('and', lambda a, b: a & b),
        ('or', lambda a, b: a | b),
        ('xor', lambda a, b: a ^ b),
    ]
    for op_name, op_fn in ops:
        try:
            planes = []
            for bit in range(bits):
                inp = torch.stack([bit_plane(a_vals, bit), bit_plane(b_vals, bit)], dim=-1)
                if op_name == 'xor':
                    prefix = f'alu.alu{bits}bit.{op_name}.bit{bit}'
                    w_or = pop[f'{prefix}.layer1.or.weight'].view(pop_size, -1)
                    b_or = pop[f'{prefix}.layer1.or.bias'].view(pop_size)
                    w_nand = pop[f'{prefix}.layer1.nand.weight'].view(pop_size, -1)
                    b_nand = pop[f'{prefix}.layer1.nand.bias'].view(pop_size)
                    h_or = heaviside(inp @ w_or.T + b_or)
                    h_nand = heaviside(inp @ w_nand.T + b_nand)
                    hidden = torch.stack([h_or, h_nand], dim=-1)
                    w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, -1)
                    b2 = pop[f'{prefix}.layer2.bias'].view(pop_size)
                    out = heaviside((hidden * w2).sum(-1) + b2)
                else:
                    w = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.weight'].view(pop_size, -1)
                    b = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.bias'].view(pop_size)
                    out = heaviside(inp @ w.T + b)
                # Keep the whole [num_tests, pop_size] plane so that each
                # individual is scored on its own gate outputs.
                planes.append(out)
            correct = tally(
                f'alu.alu{bits}bit.{op_name}', planes,
                [op_fn(a, b) for a, b in test_pairs],
            )
            scores += correct
            total += num_tests
        except KeyError as e:
            if debug:
                print(f" alu.alu{bits}bit.{op_name}: SKIP (missing {e})")

    # NOT is unary: it is checked on the A operands only.
    try:
        planes = []
        for bit in range(bits):
            w = pop[f'alu.alu{bits}bit.not.bit{bit}.weight'].view(pop_size, -1)
            b = pop[f'alu.alu{bits}bit.not.bit{bit}.bias'].view(pop_size)
            planes.append(heaviside(bit_plane(a_vals, bit).unsqueeze(-1) @ w.T + b))
        correct = tally(
            f'alu.alu{bits}bit.not', planes,
            [(~a) & mask for a, _ in test_pairs],
        )
        scores += correct
        total += num_tests
    except KeyError as e:
        if debug:
            print(f" alu.alu{bits}bit.not: SKIP (missing {e})")
    return scores, total
def _test_shifts_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test the N-bit single-position shift operations (SHL, SHR).

    Each output bit (``bit0`` is the most-significant bit) is a one-input gate
    fed by the neighbouring input bit: the next less-significant bit for SHL
    and the next more-significant bit for SHR. The bit shifted in is 0.

    Args:
        pop: Mapping of tensor name to weights. The population size is read
            from the leading dimension of the first tensor.
        bits: Operand width. 32 selects a dedicated 32-bit test set; any other
            width uses five 8-bit values.
        debug: Print the section header and per-operation PASS/FAIL/SKIP lines.

    Returns:
        ``(scores, total)`` where ``scores`` is a float tensor of shape
        ``[pop_size]`` and ``total`` is the number of test cases that ran.
        An operation whose tensors are missing is skipped and adds nothing.
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print(f"\n=== {bits}-BIT SHIFTS ===")
    if bits == 32:
        test_vals = [0x12345678, 0x80000001, 0x00000001, 0xFFFFFFFF, 0x55555555]
    else:
        test_vals = [0x81, 0x55, 0x01, 0xFF, 0xAA]
    a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long)
    num_tests = len(test_vals)
    max_val = (1 << bits) - 1
    for op_name, op_fn in [('shl', lambda x: (x << 1) & max_val), ('shr', lambda x: x >> 1)]:
        try:
            results = torch.zeros(num_tests, pop_size, device=self.device, dtype=torch.long)
            for bit in range(bits):
                w = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.weight'].view(pop_size)
                b = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.bias'].view(pop_size)
                # MSB-first position of the input bit that lands on this output.
                src_pos = bit + 1 if op_name == 'shl' else bit - 1
                if 0 <= src_pos < bits:
                    src_bit = ((a_vals >> (bits - 1 - src_pos)) & 1).float()
                else:
                    src_bit = torch.zeros(num_tests, device=self.device)
                # [num_tests, 1] * [pop_size] broadcasts to [num_tests, pop_size];
                # without the extra axis the test and population axes would be
                # multiplied element against element.
                out = heaviside(src_bit.unsqueeze(1) * w + b)
                results += out.long() * (1 << (bits - 1 - bit))
            expected = torch.tensor([op_fn(v) for v in test_vals],
                                    device=self.device, dtype=torch.long)
            correct = (results == expected.unsqueeze(1)).float().sum(0)
            self._record(f'alu.alu{bits}bit.{op_name}', int(correct[0].item()), num_tests, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
            scores += correct
            total += num_tests
        except KeyError as e:
            if debug:
                print(f" alu.alu{bits}bit.{op_name}: SKIP (missing {e})")
    return scores, total
def _test_inc_dec_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test the N-bit INC and DEC operations (wrapping modulo 2**bits).

    Both operations ripple from ``bit0`` (the least-significant bit) with an
    initial carry/borrow of 1. The result bit is the two-layer XOR of the
    input bit and the running carry. INC propagates through the ``carry``
    gate; DEC inverts the input bit with ``not_a`` and propagates through the
    ``borrow`` gate. Each individual is evaluated with its own weights.

    Args:
        pop: Mapping of tensor name to weights. The population size is read
            from the leading dimension of the first tensor.
        bits: Operand width. 32 selects a dedicated 32-bit test set; any other
            width uses six 8-bit values.
        debug: Print the section header and per-operation PASS/FAIL/SKIP lines.

    Returns:
        ``(scores, total)`` where ``scores`` is a float tensor of shape
        ``[pop_size]`` and ``total`` is the number of test cases that ran.
        An operation whose tensors are missing is skipped and adds nothing.
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print(f"\n=== {bits}-BIT INC/DEC ===")
    if bits == 32:
        test_vals = [0, 1, 0xFFFFFFFF, 0x7FFFFFFF, 0x80000000, 1000000, 0xFFFFFFFE]
    else:
        test_vals = [0, 1, 254, 255, 127, 128]
    a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long)
    num_tests = len(test_vals)
    max_val = (1 << bits) - 1
    for op_name, op_fn in [('inc', lambda x: (x + 1) & max_val), ('dec', lambda x: (x - 1) & max_val)]:
        try:
            # All intermediate signals are [num_tests, pop_size]; weights are
            # reshaped to [pop_size, fan_in] so column k is input k's weight
            # for every individual.
            carry = torch.ones(num_tests, pop_size, device=self.device)
            results = torch.zeros(num_tests, pop_size, device=self.device, dtype=torch.long)
            for bit in range(bits):
                a_bit = ((a_vals >> bit) & 1).float().unsqueeze(1)
                prefix = f'alu.alu{bits}bit.{op_name}.bit{bit}'
                w_or = pop[f'{prefix}.xor.layer1.or.weight'].reshape(pop_size, -1)
                b_or = pop[f'{prefix}.xor.layer1.or.bias'].reshape(pop_size)
                w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].reshape(pop_size, -1)
                b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].reshape(pop_size)
                h_or = heaviside(a_bit * w_or[:, 0] + carry * w_or[:, 1] + b_or)
                h_nand = heaviside(a_bit * w_nand[:, 0] + carry * w_nand[:, 1] + b_nand)
                w2 = pop[f'{prefix}.xor.layer2.weight'].reshape(pop_size, -1)
                b2 = pop[f'{prefix}.xor.layer2.bias'].reshape(pop_size)
                xor_out = heaviside(h_or * w2[:, 0] + h_nand * w2[:, 1] + b2)
                results += xor_out.long() * (1 << bit)
                if op_name == 'inc':
                    w_carry = pop[f'{prefix}.carry.weight'].reshape(pop_size, -1)
                    b_carry = pop[f'{prefix}.carry.bias'].reshape(pop_size)
                    carry = heaviside(a_bit * w_carry[:, 0] + carry * w_carry[:, 1] + b_carry)
                else:
                    w_not = pop[f'{prefix}.not_a.weight'].reshape(pop_size, -1)
                    b_not = pop[f'{prefix}.not_a.bias'].reshape(pop_size)
                    not_a = heaviside(a_bit * w_not[:, 0] + b_not)
                    w_borrow = pop[f'{prefix}.borrow.weight'].reshape(pop_size, -1)
                    b_borrow = pop[f'{prefix}.borrow.bias'].reshape(pop_size)
                    carry = heaviside(not_a * w_borrow[:, 0] + carry * w_borrow[:, 1] + b_borrow)
            expected = torch.tensor([op_fn(v) for v in test_vals],
                                    device=self.device, dtype=torch.long)
            correct = (results == expected.unsqueeze(1)).float().sum(0)
            self._record(f'alu.alu{bits}bit.{op_name}', int(correct[0].item()), num_tests, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
            scores += correct
            total += num_tests
        except KeyError as e:
            if debug:
                print(f" alu.alu{bits}bit.{op_name}: SKIP (missing {e})")
    return scores, total
def _test_neg_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test the N-bit NEG operation (two's complement negation).

    The circuit is evaluated as NOT(A) + 1: each input bit passes through its
    ``neg.not`` gate, then an incrementer ripples from ``bit0`` (the
    least-significant bit) with an initial carry of 1. Each individual is
    evaluated with its own weights.

    Args:
        pop: Mapping of tensor name to weights. The population size is read
            from the leading dimension of the first tensor.
        bits: Operand width. 32 selects a dedicated 32-bit test set; any other
            width uses six 8-bit values.
        debug: Print the section header and the PASS/FAIL or SKIP line.

    Returns:
        ``(correct, num_tests)`` where ``correct`` is a float tensor of shape
        ``[pop_size]``; ``(zeros, 0)`` when a required tensor is missing.
    """
    pop_size = next(iter(pop.values())).shape[0]
    if debug:
        print(f"\n=== {bits}-BIT NEG ===")
    if bits == 32:
        test_vals = [0, 1, 0xFFFFFFFF, 0x7FFFFFFF, 0x80000000, 1000, 1000000]
    else:
        test_vals = [0, 1, 127, 128, 255, 100]
    a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long)
    num_tests = len(test_vals)
    max_val = (1 << bits) - 1
    try:
        # Stage 1: invert every bit; each plane is [num_tests, pop_size].
        not_bits = []
        for bit in range(bits):
            a_bit = ((a_vals >> bit) & 1).float().unsqueeze(1)
            w = pop[f'alu.alu{bits}bit.neg.not.bit{bit}.weight'].reshape(pop_size, -1)
            b = pop[f'alu.alu{bits}bit.neg.not.bit{bit}.bias'].reshape(pop_size)
            not_bits.append(heaviside(a_bit * w[:, 0] + b))

        # Stage 2: add one with a ripple of XOR (sum) and carry gates.
        carry = torch.ones(num_tests, pop_size, device=self.device)
        results = torch.zeros(num_tests, pop_size, device=self.device, dtype=torch.long)
        for bit in range(bits):
            prefix = f'alu.alu{bits}bit.neg.inc.bit{bit}'
            not_bit = not_bits[bit]
            w_or = pop[f'{prefix}.xor.layer1.or.weight'].reshape(pop_size, -1)
            b_or = pop[f'{prefix}.xor.layer1.or.bias'].reshape(pop_size)
            w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].reshape(pop_size, -1)
            b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].reshape(pop_size)
            h_or = heaviside(not_bit * w_or[:, 0] + carry * w_or[:, 1] + b_or)
            h_nand = heaviside(not_bit * w_nand[:, 0] + carry * w_nand[:, 1] + b_nand)
            w2 = pop[f'{prefix}.xor.layer2.weight'].reshape(pop_size, -1)
            b2 = pop[f'{prefix}.xor.layer2.bias'].reshape(pop_size)
            xor_out = heaviside(h_or * w2[:, 0] + h_nand * w2[:, 1] + b2)
            results += xor_out.long() * (1 << bit)
            w_carry = pop[f'{prefix}.carry.weight'].reshape(pop_size, -1)
            b_carry = pop[f'{prefix}.carry.bias'].reshape(pop_size)
            carry = heaviside(not_bit * w_carry[:, 0] + carry * w_carry[:, 1] + b_carry)

        expected = torch.tensor([(-v) & max_val for v in test_vals],
                                device=self.device, dtype=torch.long)
        correct = (results == expected.unsqueeze(1)).float().sum(0)
        self._record(f'alu.alu{bits}bit.neg', int(correct[0].item()), num_tests, [])
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        # One score per individual, matching the zeros(pop_size) skip result.
        return correct, num_tests
    except KeyError as e:
        if debug:
            print(f" alu.alu{bits}bit.neg: SKIP (missing {e})")
        return torch.zeros(pop_size, device=self.device), 0
# =========================================================================
# THRESHOLD GATES
# =========================================================================
def _test_threshold_kofn(self, pop: Dict, k: int, name: str, debug: bool) -> Tuple[torch.Tensor, int]:
    """Score one single-layer threshold gate stored under ``threshold.{name}``.

    The expected output is derived from the popcount of each input row and
    from ``name``: names containing ``atleast`` expect ``popcount >= k``,
    names containing ``atmost`` or ``minority`` expect ``popcount <= k``,
    names containing ``exactly`` expect ``popcount == k``, and every other
    name expects ``popcount >= k``.

    Returns:
        ``(correct, num_inputs)`` with ``correct`` of shape ``[pop_size]``.

    Raises:
        KeyError: If the gate's weight or bias tensor is missing.
    """
    pop_size = next(iter(pop.values())).shape[0]
    prefix = f'threshold.{name}'

    # The evaluator's prepared bit matrix is reused only when it has exactly
    # 24 rows; otherwise all 256 byte values are enumerated MSB-first.
    if len(self.test_8bit_bits) == 24:
        bit_rows = self.test_8bit_bits
    else:
        all_bytes = torch.arange(256, device=self.device, dtype=torch.long)
        bit_rows = torch.stack(
            [((all_bytes >> shift) & 1).float() for shift in range(7, -1, -1)], dim=1)

    ones = bit_rows.sum(dim=1)
    if 'atleast' in name:
        target = ones >= k
    elif 'atmost' in name or 'minority' in name:
        target = ones <= k
    elif 'exactly' in name:
        target = ones == k
    else:
        target = ones >= k
    expected = target.float()

    weight = pop[f'{prefix}.weight']
    bias = pop[f'{prefix}.bias']
    out = heaviside(bit_rows @ weight.view(pop_size, -1).T + bias.view(pop_size))
    correct = (out == expected.unsqueeze(1)).float().sum(0)

    num_rows = len(bit_rows)
    failures = []
    if pop_size == 1:
        for row in range(min(num_rows, 256)):
            if out[row, 0].item() == expected[row].item():
                continue
            # Rebuild the byte value from its MSB-first bits for the report.
            val = int(sum(bit_rows[row, j].item() * (1 << (7 - j)) for j in range(8)))
            failures.append((val, expected[row].item(), out[row, 0].item()))
    self._record(prefix, int(correct[0].item()), num_rows, failures[:10])
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return correct, num_rows
def _test_threshold_gates(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Run every known threshold gate and accumulate the scores.

    Gates whose tensors are absent from ``pop`` are silently skipped.

    Returns:
        ``(scores, total)`` with ``scores`` of shape ``[pop_size]``.
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print("\n=== THRESHOLD GATES ===")

    # (k, gate name) pairs: the k-of-8 family first, then the special gates.
    gate_specs = [
        (1, 'oneoutof8'), (2, 'twooutof8'), (3, 'threeoutof8'), (4, 'fouroutof8'),
        (5, 'fiveoutof8'), (6, 'sixoutof8'), (7, 'sevenoutof8'), (8, 'alloutof8'),
        (5, 'majority'), (3, 'minority'),
        (4, 'atleastk_4'), (4, 'atmostk_4'), (4, 'exactlyk_4'),
    ]
    for threshold, gate_name in gate_specs:
        try:
            gate_scores, gate_total = self._test_threshold_kofn(pop, threshold, gate_name, debug)
        except KeyError:
            # Gate not present in this model.
            continue
        scores += gate_scores
        total += gate_total
    return scores, total
# =========================================================================
# MODULAR ARITHMETIC
# =========================================================================
def _test_modular(self, pop: Dict, mod: int, debug: bool) -> Tuple[torch.Tensor, int]:
    """Score the divisibility-by-``mod`` circuit stored under ``modular.mod{mod}``.

    A single-layer gate (``.weight`` / ``.bias``) is tried first. When it is
    missing, the three-layer form is evaluated instead: ``layer1.geq{i}`` and
    ``layer1.leq{i}`` gates on the input bits, ``layer2.eq{i}`` gates on each
    (geq, leq) pair, and ``layer3.or`` over all eq outputs.

    Returns:
        ``(correct, 256)`` with ``correct`` of shape ``[pop_size]``, or
        ``(zeros, 0)`` when the multi-layer circuit is absent or its
        evaluation raises.
    """
    pop_size = next(iter(pop.values())).shape[0]
    prefix = f'modular.mod{mod}'
    # MSB-first bits of every value in self.mod_test.
    bit_rows = torch.stack([((self.mod_test >> (7 - i)) & 1).float() for i in range(8)], dim=1)
    expected = ((self.mod_test % mod) == 0).float()

    try:
        # Single-layer form.
        weight = pop[f'{prefix}.weight']
        bias = pop[f'{prefix}.bias']
        out = heaviside(bit_rows @ weight.view(pop_size, -1).T + bias.view(pop_size))
    except KeyError:
        try:
            # Layer 1: collect geq/leq gates until an index has neither.
            geq_outputs = {}
            leq_outputs = {}
            idx = 0
            while True:
                has_geq = f'{prefix}.layer1.geq{idx}.weight' in pop
                if has_geq:
                    w = pop[f'{prefix}.layer1.geq{idx}.weight'].view(pop_size, -1)
                    b = pop[f'{prefix}.layer1.geq{idx}.bias'].view(pop_size)
                    geq_outputs[idx] = heaviside(bit_rows @ w.T + b)
                has_leq = f'{prefix}.layer1.leq{idx}.weight' in pop
                if has_leq:
                    w = pop[f'{prefix}.layer1.leq{idx}.weight'].view(pop_size, -1)
                    b = pop[f'{prefix}.layer1.leq{idx}.bias'].view(pop_size)
                    leq_outputs[idx] = heaviside(bit_rows @ w.T + b)
                if not (has_geq or has_leq):
                    break
                idx += 1
            if not (geq_outputs or leq_outputs):
                return torch.zeros(pop_size, device=self.device), 0

            # Layer 2: one eq gate per index, fed by [geq_i, leq_i]; a missing
            # layer-1 output is replaced by zeros.
            eq_outputs = []
            idx = 0
            while f'{prefix}.layer2.eq{idx}.weight' in pop:
                w = pop[f'{prefix}.layer2.eq{idx}.weight'].view(pop_size, -1)
                b = pop[f'{prefix}.layer2.eq{idx}.bias'].view(pop_size)
                geq_in = geq_outputs.get(idx, torch.zeros(256, pop_size, device=self.device))
                leq_in = leq_outputs.get(idx, torch.zeros(256, pop_size, device=self.device))
                pair = torch.stack([geq_in, leq_in], dim=-1)
                eq_outputs.append(heaviside((pair * w).sum(-1) + b))
                idx += 1
            if not eq_outputs:
                return torch.zeros(pop_size, device=self.device), 0

            # Layer 3: combine all eq outputs through the final gate.
            eq_stack = torch.stack(eq_outputs, dim=-1)
            w3 = pop[f'{prefix}.layer3.or.weight'].view(pop_size, -1)
            b3 = pop[f'{prefix}.layer3.or.bias'].view(pop_size)
            out = heaviside((eq_stack * w3).sum(-1) + b3)
        except Exception:
            return torch.zeros(pop_size, device=self.device), 0

    correct = (out == expected.unsqueeze(1)).float().sum(0)
    failures = []
    if pop_size == 1:
        for value in range(256):
            want = expected[value].item()
            got = out[value, 0].item()
            if got != want:
                failures.append((value, want, got))
    self._record(prefix, int(correct[0].item()), 256, failures[:10])
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return correct, 256
def _test_modular_all(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Run the divisibility test for every modulus from 2 through 12.

    Returns:
        ``(scores, total)`` summed over all moduli, ``scores`` of shape
        ``[pop_size]``.
    """
    pop_size = next(iter(pop.values())).shape[0]
    if debug:
        print("\n=== MODULAR ARITHMETIC ===")
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    modulus = 2
    while modulus <= 12:
        mod_scores, mod_total = self._test_modular(pop, modulus, debug)
        scores += mod_scores
        total += mod_total
        modulus += 1
    return scores, total
# =========================================================================
# PATTERN RECOGNITION
# =========================================================================
def _test_pattern(self, pop: Dict, name: str, expected_fn: Callable[[int], float],
                  debug: bool) -> Tuple[torch.Tensor, int]:
    """Score the single-layer gate ``pattern_recognition.{name}`` on all 256 bytes.

    Args:
        pop: Mapping of tensor name to weights.
        name: Circuit name appended to the ``pattern_recognition.`` prefix.
        expected_fn: Maps a byte value to the expected gate output.
        debug: Print the PASS/FAIL line when true.

    Returns:
        ``(correct, 256)`` with ``correct`` of shape ``[pop_size]``, or
        ``(zeros, 0)`` when the gate's tensors are missing.
    """
    pop_size = next(iter(pop.values())).shape[0]
    prefix = f'pattern_recognition.{name}'
    byte_vals = torch.arange(256, device=self.device, dtype=torch.long)
    # MSB-first bit matrix, one row per byte value.
    bit_rows = torch.stack(
        [((byte_vals >> shift) & 1).float() for shift in range(7, -1, -1)], dim=1)
    expected = torch.tensor([expected_fn(v.item()) for v in byte_vals], device=self.device)
    try:
        weight = pop[f'{prefix}.weight'].view(pop_size, -1)
        bias = pop[f'{prefix}.bias'].view(pop_size)
        out = heaviside(bit_rows @ weight.T + bias)
    except KeyError:
        return torch.zeros(pop_size, device=self.device), 0

    correct = (out == expected.unsqueeze(1)).float().sum(0)
    failures = []
    if pop_size == 1:
        for value in range(256):
            want = expected[value].item()
            got = out[value, 0].item()
            if got != want:
                failures.append((value, want, got))
    self._record(prefix, int(correct[0].item()), 256, failures[:10])
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return correct, 256
def _test_patterns(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Run the pattern recognition tests (``allzeros`` and ``allones``).

    Returns:
        ``(scores, total)`` summed over both circuits, ``scores`` of shape
        ``[pop_size]``.
    """
    pop_size = next(iter(pop.values())).shape[0]
    if debug:
        print("\n=== PATTERN RECOGNITION ===")

    # Circuit name under pattern_recognition.* -> expected output per byte.
    detectors = (
        ('allzeros', lambda v: 1.0 if v == 0 else 0.0),
        ('allones', lambda v: 1.0 if v == 255 else 0.0),
    )
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    for circuit, truth in detectors:
        circuit_scores, circuit_total = self._test_pattern(pop, circuit, truth, debug)
        scores += circuit_scores
        total += circuit_total
    return scores, total
# =========================================================================
# ERROR DETECTION
# =========================================================================
def _eval_xor_tree_stage(self, pop: Dict, prefix: str, stage: int, idx: int,
                         a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Evaluate the two-layer XOR cell ``{prefix}.stage{stage}.xor{idx}``.

    Layer 1 applies the ``layer1.or`` and ``layer1.nand`` gates to ``(a, b)``;
    layer 2 applies the ``layer2`` gate to those two hidden outputs.

    Args:
        a, b: Operands. A 1-D operand is broadcast across the population; a
            2-D operand is used as given.

    Returns:
        The layer-2 gate output, one column per individual.
    """
    pop_size = next(iter(pop.values())).shape[0]
    cell = f'{prefix}.stage{stage}.xor{idx}'

    # Give each operand an explicit population axis.
    lhs = a.unsqueeze(1).expand(-1, pop_size) if a.dim() == 1 else a
    rhs = b.unsqueeze(1).expand(-1, pop_size) if b.dim() == 1 else b

    # Layer 1 parameters, one row of two weights per individual.
    or_w = pop[f'{cell}.layer1.or.weight'].view(pop_size, 2)
    or_b = pop[f'{cell}.layer1.or.bias'].view(pop_size)
    nand_w = pop[f'{cell}.layer1.nand.weight'].view(pop_size, 2)
    nand_b = pop[f'{cell}.layer1.nand.bias'].view(pop_size)

    operands = torch.stack([lhs, rhs], dim=-1)
    or_out = heaviside((operands * or_w).sum(-1) + or_b)
    nand_out = heaviside((operands * nand_w).sum(-1) + nand_b)

    # Layer 2 combines the two hidden signals.
    hidden = torch.stack([or_out, nand_out], dim=-1)
    out_w = pop[f'{cell}.layer2.weight'].view(pop_size, 2)
    out_b = pop[f'{cell}.layer2.bias'].view(pop_size)
    return heaviside((hidden * out_w).sum(-1) + out_b)
def _test_parity_xor_tree(self, pop: Dict, prefix: str, debug: bool) -> Tuple[torch.Tensor, int]:
    """Score an 8-input parity circuit built as a three-stage XOR tree.

    Stage 1 XORs adjacent input bits (4 cells), stage 2 XORs adjacent stage-1
    outputs (2 cells) and stage 3 XORs the two stage-2 outputs. When the
    circuit has an ``output.not`` gate, the tree output passes through it and
    the expected value is the inverted XOR of all bits; otherwise the expected
    value is the XOR of all bits.

    Returns:
        ``(correct, 256)`` with ``correct`` of shape ``[pop_size]``, or
        ``(zeros, 0)`` when a required tensor is missing.
    """
    pop_size = next(iter(pop.values())).shape[0]
    byte_vals = torch.arange(256, device=self.device, dtype=torch.long)
    bit_rows = torch.stack(
        [((byte_vals >> shift) & 1).float() for shift in range(7, -1, -1)], dim=1)
    # 1.0 where the byte has an odd number of set bits.
    odd_parity = (bit_rows.sum(dim=1).long() % 2).float()

    try:
        # Halve the signal list at each stage: 8 -> 4 -> 2 -> 1.
        level = [bit_rows[:, col] for col in range(8)]
        for stage in (1, 2, 3):
            reduced = []
            for cell in range(len(level) // 2):
                reduced.append(self._eval_xor_tree_stage(
                    pop, prefix, stage, cell, level[cell * 2], level[cell * 2 + 1]))
            level = reduced
        tree_out = level[0]

        if f'{prefix}.output.not.weight' in pop:
            w_not = pop[f'{prefix}.output.not.weight'].view(pop_size)
            b_not = pop[f'{prefix}.output.not.bias'].view(pop_size)
            out = heaviside(tree_out * w_not + b_not)
            expected = 1.0 - odd_parity
        else:
            out = tree_out
            expected = odd_parity
    except KeyError:
        return torch.zeros(pop_size, device=self.device), 0

    correct = (out == expected.unsqueeze(1)).float().sum(0)
    failures = []
    if pop_size == 1:
        for value in range(256):
            want = expected[value].item()
            got = out[value, 0].item()
            if got != want:
                failures.append((value, want, got))
    self._record(prefix, int(correct[0].item()), 256, failures[:10])
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return correct, 256
def _test_error_detection(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Run the parity checker and parity generator XOR-tree tests.

    Returns:
        ``(scores, total)`` summed over both circuits, ``scores`` of shape
        ``[pop_size]``.
    """
    pop_size = next(iter(pop.values())).shape[0]
    if debug:
        print("\n=== ERROR DETECTION ===")
    parity_circuits = (
        'error_detection.paritychecker8bit',
        'error_detection.paritygenerator8bit',
    )
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    for circuit in parity_circuits:
        circuit_scores, circuit_total = self._test_parity_xor_tree(pop, circuit, debug)
        scores += circuit_scores
        total += circuit_total
    return scores, total
# =========================================================================
# COMBINATIONAL LOGIC
# =========================================================================
def _test_mux2to1(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Score the single-layer gate ``combinational.multiplexer2to1``.

    Inputs are ``[a, b, sel]`` and the expected output is ``b`` when ``sel``
    is 1, otherwise ``a``. All eight input combinations are tested.

    Returns:
        ``(correct, 8)`` with ``correct`` of shape ``[pop_size]``, or
        ``(zeros, 0)`` when the gate's tensors are missing.
    """
    pop_size = next(iter(pop.values())).shape[0]
    prefix = 'combinational.multiplexer2to1'
    # Truth-table rows in counting order over (a, b, sel).
    rows = [[a, b, sel] for a in (0, 1) for b in (0, 1) for sel in (0, 1)]
    inputs = torch.tensor(rows, device=self.device, dtype=torch.float32)
    expected = torch.tensor([0, 0, 0, 1, 1, 0, 1, 1], device=self.device, dtype=torch.float32)
    try:
        weight = pop[f'{prefix}.weight']
        bias = pop[f'{prefix}.bias']
        out = heaviside(inputs @ weight.view(pop_size, -1).T + bias.view(pop_size))
    except KeyError:
        return torch.zeros(pop_size, device=self.device), 0

    correct = (out == expected.unsqueeze(1)).float().sum(0)
    failures = []
    if pop_size == 1:
        for row in range(8):
            want = expected[row].item()
            got = out[row, 0].item()
            if got != want:
                failures.append((inputs[row].tolist(), want, got))
    self._record(prefix, int(correct[0].item()), 8, failures)
    if debug:
        r = self.results[-1]
        print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return correct, 8
def _test_decoder3to8(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Score the eight output gates of ``combinational.decoder3to8``.

    Output ``out{n}`` is expected to fire only for the n-th of the eight
    3-bit input rows (rows are listed in counting order). Outputs whose
    tensors are missing are skipped.

    Returns:
        ``(scores, total)`` with ``scores`` of shape ``[pop_size]`` and
        ``total`` equal to 8 per output that was present.
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print("\n=== DECODER 3-TO-8 ===")
    rows = [[hi, mid, lo] for hi in (0, 1) for mid in (0, 1) for lo in (0, 1)]
    inputs = torch.tensor(rows, device=self.device, dtype=torch.float32)
    for out_idx in range(8):
        prefix = f'combinational.decoder3to8.out{out_idx}'
        # One-hot expectation: only row out_idx should produce 1.
        expected = torch.zeros(8, device=self.device)
        expected[out_idx] = 1.0
        try:
            weight = pop[f'{prefix}.weight']
            bias = pop[f'{prefix}.bias']
            out = heaviside(inputs @ weight.view(pop_size, -1).T + bias.view(pop_size))
        except KeyError:
            continue
        correct = (out == expected.unsqueeze(1)).float().sum(0)
        scores += correct
        total += 8
        failures = []
        if pop_size == 1:
            for row in range(8):
                want = expected[row].item()
                got = out[row, 0].item()
                if got != want:
                    failures.append((inputs[row].tolist(), want, got))
        self._record(prefix, int(correct[0].item()), 8, failures)
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return scores, total
def _test_combinational(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Run the combinational logic tests and accumulate their scores.

    Runs, in order: 2-to-1 multiplexer, 3-to-8 decoder, barrel shifter and
    priority encoder.

    Returns:
        ``(scores, total)`` summed over the four tests, ``scores`` of shape
        ``[pop_size]``.
    """
    pop_size = next(iter(pop.values())).shape[0]
    if debug:
        print("\n=== COMBINATIONAL LOGIC ===")
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    suite = (
        self._test_mux2to1,
        self._test_decoder3to8,
        self._test_barrel_shifter,
        self._test_priority_encoder,
    )
    for run_test in suite:
        part_scores, part_total = run_test(pop, debug)
        scores += part_scores
        total += part_total
    return scores, total
def _test_barrel_shifter(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test the 8-bit barrel shifter (left shift by 0-7 positions).

    The shifter has three layers that conditionally shift by 4, 2 and 1
    positions; layer ``n`` is selected by bit ``n`` of the 3-bit shift amount
    (MSB first). Each output bit is a 2-to-1 selector made of ``not_sel``,
    ``and_a`` (unshifted bit AND NOT sel), ``and_b`` (shifted bit AND sel)
    and ``or`` gates. Bit index 0 is the most-significant bit, and positions
    shifted in from beyond the word are 0.

    Every individual is simulated with its own weights, so ``scores`` holds
    one pass count per individual.

    Args:
        pop: Mapping of tensor name to weights. The population size is read
            from the leading dimension of the first tensor.
        debug: Print the section header and the PASS/FAIL or SKIP line.

    Returns:
        ``(scores, total)``; 5 values x 8 shift amounts = 40 cases when the
        circuit is present. On ``KeyError``/``RuntimeError`` the values
        accumulated so far are returned and nothing is recorded.
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print("\n=== BARREL SHIFTER ===")
    try:
        test_vals = [0b10000001, 0b11110000, 0b00001111, 0b10101010, 0xFF]
        zero_plane = torch.zeros(pop_size, device=self.device)
        for val in test_vals:
            for shift in range(8):
                expected_val = (val << shift) & 0xFF  # Left shift
                shift_bits = [float((shift >> (2 - i)) & 1) for i in range(3)]
                # One [pop_size] plane per bit (MSB first): all individuals
                # start from the same input value.
                layer_in = [
                    torch.full((pop_size,), float((val >> (7 - i)) & 1), device=self.device)
                    for i in range(8)
                ]
                for layer in range(3):
                    shift_amount = 1 << (2 - layer)  # 4, 2, 1
                    sel = shift_bits[layer]
                    layer_out = []
                    for bit in range(8):
                        prefix = f'combinational.barrelshifter.layer{layer}.bit{bit}'
                        # NOT sel
                        w_not = pop[f'{prefix}.not_sel.weight'].view(pop_size)
                        b_not = pop[f'{prefix}.not_sel.bias'].view(pop_size)
                        not_sel = heaviside(sel * w_not + b_not)
                        # Source of the shifted value; out of range means 0.
                        shifted_src = bit + shift_amount
                        shifted_val = layer_in[shifted_src] if shifted_src < 8 else zero_plane
                        # AND a: original AND NOT sel
                        w_and_a = pop[f'{prefix}.and_a.weight'].view(pop_size, 2)
                        b_and_a = pop[f'{prefix}.and_a.bias'].view(pop_size)
                        and_a = heaviside(
                            layer_in[bit] * w_and_a[:, 0] + not_sel * w_and_a[:, 1] + b_and_a)
                        # AND b: shifted AND sel
                        w_and_b = pop[f'{prefix}.and_b.weight'].view(pop_size, 2)
                        b_and_b = pop[f'{prefix}.and_b.bias'].view(pop_size)
                        and_b = heaviside(
                            shifted_val * w_and_b[:, 0] + sel * w_and_b[:, 1] + b_and_b)
                        # OR
                        w_or = pop[f'{prefix}.or.weight'].view(pop_size, 2)
                        b_or = pop[f'{prefix}.or.bias'].view(pop_size)
                        layer_out.append(
                            heaviside(and_a * w_or[:, 0] + and_b * w_or[:, 1] + b_or))
                    layer_in = layer_out
                # Pack the MSB-first planes into one integer per individual.
                result = torch.zeros(pop_size, device=self.device, dtype=torch.long)
                for i in range(8):
                    result += layer_in[i].long() * (1 << (7 - i))
                scores += (result == expected_val).float()
                total += 1
        self._record('combinational.barrelshifter', int(scores[0].item()), total, [])
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    except (KeyError, RuntimeError) as e:
        if debug:
            print(f" combinational.barrelshifter: SKIP ({e})")
    return scores, total
def _test_priority_encoder(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test the 8-input priority encoder (position of the highest set bit).

    Inputs are presented MSB-first, so index 0 is the most-significant bit
    and index 7 the least-significant one. For every test value the ``valid``
    gate is checked; when the value is non-zero, each of the three index
    gates (``idx0`` is the index MSB) is checked as well. Index gates that are
    missing from ``pop`` are skipped and not counted.

    Every individual is scored on its own gate outputs, so ``scores`` holds
    one pass count per individual.

    Args:
        pop: Mapping of tensor name to weights. The population size is read
            from the leading dimension of the first tensor.
        debug: Print the section header and the PASS/FAIL or SKIP line.

    Returns:
        ``(scores, total)``. On ``KeyError``/``RuntimeError`` the values
        accumulated so far are returned and nothing is recorded.
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print("\n=== PRIORITY ENCODER ===")
    try:
        # Test cases: input -> (valid, index of highest bit)
        test_cases = [
            (0b00000000, 0, 0),  # No bits set, valid=0
            (0b00000001, 1, 7),  # Bit 7 (LSB)
            (0b00000010, 1, 6),
            (0b00000100, 1, 5),
            (0b00001000, 1, 4),
            (0b00010000, 1, 3),
            (0b00100000, 1, 2),
            (0b01000000, 1, 1),
            (0b10000000, 1, 0),  # Bit 0 (MSB)
            (0b10000001, 1, 0),  # Multiple bits, highest wins
            (0b01010101, 1, 1),
            (0b00001111, 1, 4),
            (0b11111111, 1, 0),
        ]
        for val, expected_valid, expected_idx in test_cases:
            val_bits = torch.tensor([float((val >> (7 - i)) & 1) for i in range(8)],
                                    device=self.device, dtype=torch.float32)
            # Valid output: one gate over all eight input bits.
            w_valid = pop['combinational.priorityencoder.valid.weight'].view(pop_size, 8)
            b_valid = pop['combinational.priorityencoder.valid.bias'].view(pop_size)
            out_valid = heaviside((val_bits * w_valid).sum(-1) + b_valid)
            # Per-individual credit instead of crediting everyone with
            # individual 0's result.
            scores += (out_valid == expected_valid).float()
            total += 1
            # Index outputs (3 bits) are only meaningful for non-zero input.
            if expected_valid == 1:
                for idx_bit in range(3):
                    try:
                        w_idx = pop[f'combinational.priorityencoder.idx{idx_bit}.weight'].view(pop_size, 8)
                        b_idx = pop[f'combinational.priorityencoder.idx{idx_bit}.bias'].view(pop_size)
                        out_idx = heaviside((val_bits * w_idx).sum(-1) + b_idx)
                        expected_bit = (expected_idx >> (2 - idx_bit)) & 1
                        scores += (out_idx == expected_bit).float()
                        total += 1
                    except KeyError:
                        # Deliberate best-effort: an absent index gate is
                        # neither scored nor counted.
                        pass
        self._record('combinational.priorityencoder', int(scores[0].item()), total, [])
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    except (KeyError, RuntimeError) as e:
        if debug:
            print(f" combinational.priorityencoder: SKIP ({e})")
    return scores, total
def _test_barrel_shifter_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
    """Exercise the N-bit barrel shifter with left shifts of 0..min(bits, 8)-1.

    The shifter is walked as ``max(1, ceil(log2(bits)))`` mux layers.  Layer
    ``k`` feeds every output bit with the current bit and the bit
    ``2 ** (num_layers - 1 - k)`` positions further along (0.0 past the end),
    selected by bit ``k`` of the MSB-first shift amount.  Each mux is read
    from ``pop`` as ``not_sel`` / ``and_a`` / ``and_b`` / ``or`` gates.

    Intermediate signals follow individual 0 only (``[0].item()``); a passing
    case adds one point to every individual.

    Args:
        pop: Mapping of tensor name to population-batched tensor.
        bits: Width of the shifter under test.
        debug: Print the section header and the PASS/FAIL/SKIP line.

    Returns:
        ``(scores, total)``: per-individual pass counts and cases attempted.
    """
    import math

    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    num_layers = max(1, math.ceil(math.log2(bits)))
    max_val = (1 << bits) - 1
    if debug:
        print(f"\n=== {bits}-BIT BARREL SHIFTER ===")
    prefix = f'combinational.barrelshifter{bits}'

    def gate2(gate_name, first, second):
        # Two-input threshold gate driven by plain Python floats.
        w = pop[f'{gate_name}.weight'].view(pop_size, 2)
        b = pop[f'{gate_name}.bias'].view(pop_size)
        x = torch.tensor([first, second], device=self.device)
        return heaviside((x * w).sum(-1) + b)

    try:
        wide_vectors = {
            16: [0x8001, 0xFF00, 0x00FF, 0xAAAA, 0xFFFF, 0x1234],
            32: [0x80000001, 0xFFFF0000, 0x0000FFFF, 0xAAAAAAAA, 0xFFFFFFFF, 0x12345678],
        }
        test_vals = wide_vectors.get(bits)
        if test_vals is None:
            test_vals = [0b10000001, 0b11110000, 0b00001111, 0b10101010, max_val]
        num_shifts = min(bits, 8)

        for val in test_vals:
            for shift in range(num_shifts):
                expected_val = (val << shift) & max_val
                # Both the data word and the shift amount are MSB-first.
                current = [float((val >> (bits - 1 - i)) & 1) for i in range(bits)]
                selects = [float((shift >> (num_layers - 1 - i)) & 1) for i in range(num_layers)]

                for layer, sel in enumerate(selects):
                    shift_amount = 1 << (num_layers - 1 - layer)
                    following = []
                    for bit in range(bits):
                        bit_prefix = f'{prefix}.layer{layer}.bit{bit}'
                        w_not = pop[f'{bit_prefix}.not_sel.weight'].view(pop_size)
                        b_not = pop[f'{bit_prefix}.not_sel.bias'].view(pop_size)
                        not_sel = heaviside(sel * w_not + b_not)

                        source = bit + shift_amount
                        shifted_val = current[source] if source < bits else 0.0

                        keep = gate2(f'{bit_prefix}.and_a', current[bit], not_sel[0].item())
                        take = gate2(f'{bit_prefix}.and_b', shifted_val, sel)
                        merged = gate2(f'{bit_prefix}.or', keep[0].item(), take[0].item())
                        following.append(merged[0].item())
                    current = following

                # Reassemble the MSB-first bit list into an integer.
                result = 0
                for level in current:
                    result = (result << 1) | int(level)
                if result == expected_val:
                    scores += 1
                total += 1

        self._record(prefix, int(scores[0].item()), total, [])
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    except (KeyError, RuntimeError) as e:
        if debug:
            print(f" {prefix}: SKIP ({e})")
    return scores, total
def _test_priority_encoder_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test N-bit priority encoder (find highest set bit).

    The priority encoder is a multi-layer circuit:
    1. any_higher{pos}: OR of bits 0 to pos-1 (all higher-priority positions)
    2. is_highest{0}: bit[0] directly (MSB is always highest if set)
    3. is_highest{pos}: bit[pos] AND NOT(any_higher{pos}) for pos > 0
    4. out{bit}: OR of is_highest{pos} for all pos where (pos >> bit) & 1
    5. valid: OR of all input bits

    Input bits are presented MSB-first, so position 0 is the most significant
    bit.  The ``valid`` output is checked for every case; index outputs are
    checked only when a bit is set.  Pass/fail is decided from individual 0
    and credited to the whole population.  Missing ``out{bit}`` gates are
    skipped without being counted.

    Args:
        pop: Mapping of tensor name to population-batched tensor.
        bits: Width of the encoder under test.
        debug: Print the section header and the PASS/FAIL/SKIP line.

    Returns:
        ``(scores, total)``: per-individual pass counts and checks attempted.
    """
    import math

    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    out_bits = max(1, math.ceil(math.log2(bits)))
    if debug:
        print(f"\n=== {bits}-BIT PRIORITY ENCODER ===")
    prefix = f'combinational.priorityencoder{bits}'
    try:
        # (input value, expected valid, expected MSB-first index)
        test_cases = [(0, 0, 0)]
        test_cases.extend((1 << i, 1, bits - 1 - i) for i in range(bits))
        if bits == 16:
            test_cases.extend([
                (0x8001, 1, 0), (0x5555, 1, 1), (0x00FF, 1, 8), (0xFFFF, 1, 0)
            ])
        elif bits == 32:
            test_cases.extend([
                (0x80000001, 1, 0), (0x55555555, 1, 1), (0x0000FFFF, 1, 16), (0xFFFFFFFF, 1, 0)
            ])

        for val, expected_valid, expected_idx in test_cases:
            val_bits = torch.tensor([float((val >> (bits - 1 - i)) & 1) for i in range(bits)],
                                    device=self.device, dtype=torch.float32)
            w_valid = pop[f'{prefix}.valid.weight'].view(pop_size, bits)
            b_valid = pop[f'{prefix}.valid.bias'].view(pop_size)
            out_valid = heaviside((val_bits * w_valid).sum(-1) + b_valid)
            if int(out_valid[0].item()) == expected_valid:
                scores += 1
            total += 1
            if expected_valid != 1:
                continue

            # Layer 1: any_higher[pos] looks at the first `pos` input bits.
            any_higher = [None]
            for pos in range(1, bits):
                w = pop[f'{prefix}.any_higher{pos}.weight'].view(pop_size, -1)
                b = pop[f'{prefix}.any_higher{pos}.bias'].view(pop_size)
                inp = val_bits[:pos]
                any_higher.append(heaviside((inp * w[:, :len(inp)]).sum(-1) + b))

            # Layer 2: position 0 passes straight through; the rest are gated.
            is_highest = [val_bits[0].unsqueeze(0).expand(pop_size)]
            for pos in range(1, bits):
                w_not = pop[f'{prefix}.is_highest{pos}.not_higher.weight'].view(pop_size, -1)
                # The bias is shaped [pop, 1] so it lines up with the [pop, 1]
                # product below.  A [pop] bias would broadcast to [pop, pop]
                # whenever pop_size > 1 and break the stack() that follows.
                b_not = pop[f'{prefix}.is_highest{pos}.not_higher.bias'].view(pop_size, 1)
                not_higher = heaviside(any_higher[pos].unsqueeze(-1) * w_not + b_not).squeeze(-1)
                w_and = pop[f'{prefix}.is_highest{pos}.and.weight'].view(pop_size, -1)
                b_and = pop[f'{prefix}.is_highest{pos}.and.bias'].view(pop_size)
                inp = torch.stack([val_bits[pos].expand(pop_size), not_higher], dim=-1)
                is_highest.append(heaviside((inp * w_and).sum(-1) + b_and))

            # Layer 3: each index bit ORs the positions that have that bit set.
            for idx_bit in range(out_bits):
                try:
                    w_idx = pop[f'{prefix}.out{idx_bit}.weight'].view(pop_size, -1)
                    b_idx = pop[f'{prefix}.out{idx_bit}.bias'].view(pop_size)
                    relevant = [is_highest[pos] for pos in range(bits) if (pos >> idx_bit) & 1]
                    if relevant:
                        inp = torch.stack(relevant[:w_idx.shape[1]], dim=-1)
                        out_idx = heaviside((inp * w_idx).sum(-1) + b_idx)
                    else:
                        # No position feeds this output, so the gate sees its
                        # bias alone (avoids an unbound or stale out_idx).
                        out_idx = heaviside(b_idx)
                    expected_bit = (expected_idx >> idx_bit) & 1
                    if int(out_idx[0].item()) == expected_bit:
                        scores += 1
                    total += 1
                except KeyError:
                    pass

        self._record(prefix, int(scores[0].item()), total, [])
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    except (KeyError, RuntimeError) as e:
        if debug:
            print(f" {prefix}: SKIP ({e})")
    return scores, total
# =========================================================================
# CONTROL FLOW
# =========================================================================
def _test_conditional_jump(self, pop: Dict, name: str, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test conditional jump circuit (N-bit address aware).

    Every address bit is checked as a 2:1 mux, ``out = target if flag else pc``,
    built from four gates under ``control.{name}.bit{k}``: ``not_sel``,
    ``and_a`` (pc AND NOT flag), ``and_b`` (target AND flag) and ``or``.
    All eight (pc, target, flag) combinations are applied to each bit.

    Intermediate signals are kept in ``[pop_size, 8]`` layout throughout, so
    each individual is scored with its own weights for any population size.

    Args:
        pop: Mapping of tensor name to population-batched tensor.
        name: Jump circuit name, e.g. ``'jz'``.
        debug: Print the PASS/FAIL line for this circuit.

    Returns:
        ``(scores, total)``: per-individual pass counts and checks attempted.
        Bits whose tensors are missing from ``pop`` are skipped.
    """
    pop_size = next(iter(pop.values())).shape[0]
    prefix = f'control.{name}'
    # Test cases: [pc_bit, target_bit, flag] -> out = flag ? target : pc
    inputs = torch.tensor([
        [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
        [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1],
    ], device=self.device, dtype=torch.float32)
    expected = torch.tensor([0, 0, 0, 1, 1, 0, 1, 1], device=self.device, dtype=torch.float32)
    num_cases = inputs.shape[0]
    pc_rows = inputs[:, 0].unsqueeze(0).expand(pop_size, num_cases)  # [pop, 8]
    flag_row = inputs[:, 2].unsqueeze(0)                             # [1, 8]
    target_sel = inputs[:, 1:3]                                      # [8, 2]

    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    for bit in range(self.addr_bits):
        bit_prefix = f'{prefix}.bit{bit}'
        try:
            # NOT sel
            w_not = pop[f'{bit_prefix}.not_sel.weight'].view(pop_size, 1)
            b_not = pop[f'{bit_prefix}.not_sel.bias'].view(pop_size, 1)
            not_sel = heaviside(flag_row * w_not + b_not)  # [pop, 8]

            # AND a (pc AND NOT sel)
            w_and_a = pop[f'{bit_prefix}.and_a.weight'].view(pop_size, 1, 2)
            b_and_a = pop[f'{bit_prefix}.and_a.bias'].view(pop_size, 1)
            pc_not = torch.stack([pc_rows, not_sel], dim=-1)  # [pop, 8, 2]
            and_a = heaviside((pc_not * w_and_a).sum(-1) + b_and_a)

            # AND b (target AND sel)
            w_and_b = pop[f'{bit_prefix}.and_b.weight'].view(pop_size, 1, 2)
            b_and_b = pop[f'{bit_prefix}.and_b.bias'].view(pop_size, 1)
            and_b = heaviside((target_sel * w_and_b).sum(-1) + b_and_b)  # [pop, 8]

            # OR
            w_or = pop[f'{bit_prefix}.or.weight'].view(pop_size, 1, 2)
            b_or = pop[f'{bit_prefix}.or.bias'].view(pop_size, 1)
            ab = torch.stack([and_a, and_b], dim=-1)  # [pop, 8, 2]
            out = heaviside((ab * w_or).sum(-1) + b_or)  # [pop, 8]

            scores += (out == expected.unsqueeze(0)).float().sum(1)  # [pop]
            total += num_cases
        except KeyError:
            pass

    if total > 0:
        # Record the raw count; a divide/multiply round trip through float32
        # can land just below the integer and be truncated by int().
        self._record(prefix, int(scores[0].item()), total, [])
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
    return scores, total
def _test_control_flow(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Run every control-flow check and sum the results.

    Calls ``_test_conditional_jump`` once per jump circuit, in a fixed order,
    then ``_test_stack_ops``.

    Returns:
        ``(scores, total)``: the element-wise sum of the sub-test scores and
        the sum of their check counts.
    """
    pop_size = next(iter(pop.values())).shape[0]
    tally = torch.zeros(pop_size, device=self.device)
    checks = 0
    if debug:
        print("\n=== CONTROL FLOW ===")

    conditional_jumps = ('jz', 'jnz', 'jc', 'jnc', 'jn', 'jp', 'jv', 'jnv', 'conditionaljump')
    partials = [self._test_conditional_jump(pop, jump, debug) for jump in conditional_jumps]
    # Stack operations run after all the jumps.
    partials.append(self._test_stack_ops(pop, debug))

    for part_scores, part_total in partials:
        tally += part_scores
        checks += part_total
    return tally, checks
def _test_stack_ops(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test PUSH/POP/RET stack operation circuits (N-bit address aware).

    Three independent sections, each skipped as a whole on a missing tensor
    or shape error:

    * ``control.push.sp_dec``: SP - 1 through an XOR/borrow chain.
    * ``control.pop.sp_inc``: SP + 1 through an XOR/carry chain.
    * ``control.ret.addr``: one single-input gate per address bit, expected
      to reproduce its input.

    Bits are MSB-first, so the chains run from ``addr_bits - 1`` down to 0.
    The carry/borrow handed to the next bit is taken from individual 0
    (``[0].item()``), while output bits are compared per individual.

    Returns:
        ``(scores, total)``: per-individual matching-bit counts and bits checked.
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    addr_bits = self.addr_bits
    addr_mask = (1 << addr_bits) - 1
    if debug:
        print(f"\n=== STACK OPERATIONS ({addr_bits}-bit SP) ===")

    def msb_first(value):
        # Address as a list of 0.0/1.0 floats, most significant bit first.
        return [float((value >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)]

    def expected_tensor(value):
        return torch.tensor([((value >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)],
                            device=self.device, dtype=torch.float32)

    def xor_stage(stage, level, chain):
        # Two-layer XOR (OR + NAND feeding layer2); also returns the input pair.
        w_or = pop[f'{stage}.xor.layer1.or.weight'].view(pop_size, 2)
        b_or = pop[f'{stage}.xor.layer1.or.bias'].view(pop_size)
        w_nand = pop[f'{stage}.xor.layer1.nand.weight'].view(pop_size, 2)
        b_nand = pop[f'{stage}.xor.layer1.nand.bias'].view(pop_size)
        w2 = pop[f'{stage}.xor.layer2.weight'].view(pop_size, 2)
        b2 = pop[f'{stage}.xor.layer2.bias'].view(pop_size)
        pair = torch.tensor([level, chain], device=self.device)
        h_or = heaviside((pair * w_or).sum(-1) + b_or)
        h_nand = heaviside((pair * w_nand).sum(-1) + b_nand)
        hidden = torch.stack([h_or, h_nand], dim=-1)
        return pair, heaviside((hidden * w2).sum(-1) + b2)

    def chain_gate(stage, gate, pair):
        # Carry/borrow gate; individual 0's output drives the next stage.
        w = pop[f'{stage}.{gate}.weight'].view(pop_size, 2)
        b = pop[f'{stage}.{gate}.bias'].view(pop_size)
        return heaviside((pair * w).sum(-1) + b)[0].item()

    def matches(columns, reference):
        out = torch.stack(columns, dim=-1)
        return (out == reference.unsqueeze(0)).float().sum(1)

    def commit(label, tally, checks):
        # Fold one section into the running totals and record it.
        nonlocal total
        scores.add_(tally)
        total += checks
        self._record(label, int(tally[0].item()), checks, [])
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")

    # Stack-pointer values shared by the PUSH and POP sections.
    sp_tests = [0, 1, addr_mask // 2, addr_mask]
    if addr_bits >= 8:
        sp_tests.append(0x100 & addr_mask)
    if addr_bits >= 12:
        sp_tests.append(0x1234 & addr_mask)

    # Test PUSH SP decrement (addr_bits wide, borrow chain)
    try:
        tally = torch.zeros(pop_size, device=self.device)
        checks = 0
        for sp_val in sp_tests:
            sp_bits = msb_first(sp_val)
            borrow = 1.0
            columns = []
            for bit in reversed(range(addr_bits)):  # LSB to MSB
                stage = f'control.push.sp_dec.bit{bit}'
                _, diff_bit = xor_stage(stage, sp_bits[bit], borrow)
                columns.insert(0, diff_bit)
                # Borrow: NOT(sp) AND borrow_in (NOT is computed arithmetically)
                borrow_in = torch.tensor([1.0 - sp_bits[bit], borrow], device=self.device)
                borrow = chain_gate(stage, 'borrow', borrow_in)
            tally += matches(columns, expected_tensor((sp_val - 1) & addr_mask))
            checks += addr_bits
        commit('control.push.sp_dec', tally, checks)
    except (KeyError, RuntimeError) as e:
        if debug:
            print(f" control.push.sp_dec: SKIP ({e})")

    # Test POP SP increment (addr_bits wide, carry chain)
    try:
        tally = torch.zeros(pop_size, device=self.device)
        checks = 0
        for sp_val in sp_tests:
            sp_bits = msb_first(sp_val)
            carry = 1.0
            columns = []
            for bit in reversed(range(addr_bits)):  # LSB to MSB
                stage = f'control.pop.sp_inc.bit{bit}'
                pair, sum_bit = xor_stage(stage, sp_bits[bit], carry)
                columns.insert(0, sum_bit)
                # Carry: sp AND carry_in, on the same input pair as the XOR
                carry = chain_gate(stage, 'carry', pair)
            tally += matches(columns, expected_tensor((sp_val + 1) & addr_mask))
            checks += addr_bits
        commit('control.pop.sp_inc', tally, checks)
    except (KeyError, RuntimeError) as e:
        if debug:
            print(f" control.pop.sp_inc: SKIP ({e})")

    # Test RET address buffer (addr_bits identity gates)
    try:
        tally = torch.zeros(pop_size, device=self.device)
        checks = 0
        ret_tests = [0, addr_mask, addr_mask // 2, 1]
        if addr_bits >= 12:
            ret_tests.append(0x1234 & addr_mask)
        for addr_val in ret_tests:
            ret_bits_tensor = torch.tensor(msb_first(addr_val),
                                           device=self.device, dtype=torch.float32)
            columns = []
            for bit in range(addr_bits):
                w = pop[f'control.ret.addr.bit{bit}.weight'].view(pop_size)
                b = pop[f'control.ret.addr.bit{bit}.bias'].view(pop_size)
                columns.append(heaviside(ret_bits_tensor[bit] * w + b))
            tally += matches(columns, ret_bits_tensor)
            checks += addr_bits
        commit('control.ret.addr', tally, checks)
    except (KeyError, RuntimeError) as e:
        if debug:
            print(f" control.ret.addr: SKIP ({e})")

    return scores, total
# =========================================================================
# ALU
# =========================================================================
def _test_alu_ops(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Test ALU operations (8-bit).

    Sections run in this order, each recorded under ``alu.alu8bit.<op>``:
    and, or, not, shl, shr, mul, div, inc, dec, neg, rol, ror.

    * and/or/not are evaluated as 8 parallel gates from one weight tensor.
    * shl/shr/rol/ror use one single-input gate per output bit.
    * mul checks only the 64 partial-product AND gates.
    * div checks only each stage's comparison gate.
    * inc/dec/neg walk an XOR + carry/borrow chain from bit 7 down to bit 0;
      the carry/borrow handed to the next bit is taken from individual 0.

    Bits are MSB-first.  A section whose tensors are missing or mis-shaped
    (``KeyError`` / ``RuntimeError``) contributes nothing and, in debug mode,
    prints a SKIP line. This now holds for every section, and/or/not included.

    Args:
        pop: Mapping of tensor name to population-batched tensor.
        debug: Print the section header and per-op PASS/FAIL/SKIP lines.

    Returns:
        ``(scores, total)``: per-individual pass counts and checks attempted.
    """
    pop_size = next(iter(pop.values())).shape[0]
    device = self.device
    scores = torch.zeros(pop_size, device=device)
    total = 0
    if debug:
        print("\n=== ALU OPERATIONS ===")

    test_vals = [(0, 0), (255, 255), (0xAA, 0x55), (0x0F, 0xF0)]
    unary_vals = [a_val for a_val, _ in test_vals]
    rotate_vals = [0b10000000, 0b00000001, 0b10101010, 0b01010101, 0xFF, 0x00]
    step_vals = [0, 1, 127, 128, 254, 255]
    neg_vals = [0, 1, 127, 128, 255]
    mul_tests = [(3, 4), (7, 8), (15, 17), (0, 255)]
    div_tests = [(100, 10), (255, 17), (50, 7), (128, 16)]

    def byte_bits(value):
        # 8-element float tensor, most significant bit first.
        return torch.tensor([(value >> (7 - i)) & 1 for i in range(8)],
                            device=device, dtype=torch.float32)

    def blank():
        return torch.zeros(pop_size, device=device)

    def matching_bits(out, expected_val):
        # Per-individual count of output bits equal to expected_val's bits.
        return (out == byte_bits(expected_val).unsqueeze(0)).float().sum(1)

    def gate2(gate_name, pair):
        # Two-input threshold gate applied to a 2-element input tensor.
        w = pop[f'{gate_name}.weight'].view(pop_size, 2)
        b = pop[f'{gate_name}.bias'].view(pop_size)
        return heaviside((pair * w).sum(-1) + b)

    def xor_stage(stage, level, chain):
        # Two-layer XOR (OR + NAND feeding layer2); also returns the input pair.
        w_or = pop[f'{stage}.xor.layer1.or.weight'].view(pop_size, 2)
        b_or = pop[f'{stage}.xor.layer1.or.bias'].view(pop_size)
        w_nand = pop[f'{stage}.xor.layer1.nand.weight'].view(pop_size, 2)
        b_nand = pop[f'{stage}.xor.layer1.nand.bias'].view(pop_size)
        w2 = pop[f'{stage}.xor.layer2.weight'].view(pop_size, 2)
        b2 = pop[f'{stage}.xor.layer2.bias'].view(pop_size)
        pair = torch.tensor([level, chain], device=device)
        h_or = heaviside((pair * w_or).sum(-1) + b_or)
        h_nand = heaviside((pair * w_nand).sum(-1) + b_nand)
        hidden = torch.stack([h_or, h_nand], dim=-1)
        return pair, heaviside((hidden * w2).sum(-1) + b2)

    def ripple_increment(chain_prefix, levels):
        # Half-adder chain with initial carry 1, LSB (bit 7) to MSB (bit 0).
        carry = 1.0
        columns = []
        for bit in range(7, -1, -1):
            stage = f'{chain_prefix}.bit{bit}'
            pair, sum_bit = xor_stage(stage, levels[bit], carry)
            columns.insert(0, sum_bit)
            # Only individual 0's carry is propagated to the next bit.
            carry = gate2(f'{stage}.carry', pair)[0].item()
        return torch.stack(columns, dim=-1)

    def bitwise_binary(op, reference):
        # weight [16] = 8 * [2], bias [8]: eight parallel two-input gates.
        w = pop[f'alu.alu8bit.{op}.weight'].view(pop_size, 8, 2)
        b = pop[f'alu.alu8bit.{op}.bias'].view(pop_size, 8)
        tally, checks = blank(), 0
        for a_val, b_val in test_vals:
            inputs = torch.stack([byte_bits(a_val), byte_bits(b_val)], dim=-1)  # [8, 2]
            out = heaviside((inputs * w).sum(-1) + b)  # [pop, 8]
            tally += matching_bits(out, reference(a_val, b_val))
            checks += 8
        return tally, checks

    def bitwise_not():
        w = pop['alu.alu8bit.not.weight'].view(pop_size, 8)
        b = pop['alu.alu8bit.not.bias'].view(pop_size, 8)
        tally, checks = blank(), 0
        for a_val in unary_vals:
            out = heaviside(byte_bits(a_val) * w + b)
            tally += matching_bits(out, (~a_val) & 0xFF)
            checks += 8
        return tally, checks

    def rewired(op, values, source_of, reference):
        # Shifts/rotates: output bit k is a single-input gate fed from input
        # bit source_of(k); None means the gate is fed a constant 0.
        tally, checks = blank(), 0
        for a_val in values:
            a_bits = byte_bits(a_val)
            columns = []
            for bit in range(8):
                w = pop[f'alu.alu8bit.{op}.bit{bit}.weight'].view(pop_size)
                b = pop[f'alu.alu8bit.{op}.bit{bit}.bias'].view(pop_size)
                src = source_of(bit)
                if src is None:
                    inp = torch.zeros(pop_size, device=device)
                else:
                    inp = a_bits[src]
                columns.append(heaviside(inp * w + b))
            out = torch.stack(columns, dim=-1)  # [pop, 8]
            tally += matching_bits(out, reference(a_val))
            checks += 8
        return tally, checks

    def partial_products():
        # MUL (partial products only - just verify AND gates work)
        tally, checks = blank(), 0
        for a_val, b_val in mul_tests:
            a_levels = [(a_val >> (7 - i)) & 1 for i in range(8)]
            b_levels = [(b_val >> (7 - i)) & 1 for i in range(8)]
            for i in range(8):
                for j in range(8):
                    pair = torch.tensor([float(a_levels[i]), float(b_levels[j])], device=device)
                    out = gate2(f'alu.alu8bit.mul.pp.a{i}b{j}', pair)
                    expected = float(a_levels[i] & b_levels[j])
                    tally += (out == expected).float()
                    checks += 1
        return tally, checks

    def divider_compares():
        # DIV (comparison gates only): stage k is fed the dividend shifted
        # right by (7 - k) next to the divisor, and is expected to output
        # remainder >= divisor.
        tally, checks = blank(), 0
        for a_val, b_val in div_tests:
            for stage in range(8):
                w = pop[f'alu.alu8bit.div.stage{stage}.cmp.weight'].view(pop_size, 16)
                b = pop[f'alu.alu8bit.div.stage{stage}.cmp.bias'].view(pop_size)
                test_rem = (a_val >> (7 - stage)) & 0xFF
                inp = torch.cat([byte_bits(test_rem), byte_bits(b_val)])
                out = heaviside((inp * w).sum(-1) + b)
                expected = float(test_rem >= b_val)
                tally += (out == expected).float()
                checks += 1
        return tally, checks

    def increment():
        tally, checks = blank(), 0
        for a_val in step_vals:
            out = ripple_increment('alu.alu8bit.inc', byte_bits(a_val).tolist())
            tally += matching_bits(out, (a_val + 1) & 0xFF)
            checks += 8
        return tally, checks

    def decrement():
        tally, checks = blank(), 0
        for a_val in step_vals:
            a_bits = byte_bits(a_val)
            borrow = 1.0
            columns = []
            for bit in range(7, -1, -1):
                stage = f'alu.alu8bit.dec.bit{bit}'
                _, diff_bit = xor_stage(stage, a_bits[bit].item(), borrow)
                columns.insert(0, diff_bit)
                # Borrow logic: borrow_out = NOT(a) AND borrow_in
                w_not = pop[f'{stage}.not_a.weight'].view(pop_size)
                b_not = pop[f'{stage}.not_a.bias'].view(pop_size)
                not_a = heaviside(a_bits[bit] * w_not + b_not)
                borrow_in = torch.tensor([not_a[0].item(), borrow], device=device)
                borrow = gate2(f'{stage}.borrow', borrow_in)[0].item()
            out = torch.stack(columns, dim=-1)
            tally += matching_bits(out, (a_val - 1) & 0xFF)
            checks += 8
        return tally, checks

    def negate():
        # NEG (two's complement: NOT + 1)
        tally, checks = blank(), 0
        for a_val in neg_vals:
            a_bits = byte_bits(a_val)
            inverted = []
            for bit in range(8):
                w = pop[f'alu.alu8bit.neg.not.bit{bit}.weight'].view(pop_size)
                b = pop[f'alu.alu8bit.neg.not.bit{bit}.bias'].view(pop_size)
                # The increment chain consumes individual 0's inverted bit.
                inverted.append(heaviside(a_bits[bit] * w + b)[0].item())
            out = ripple_increment('alu.alu8bit.neg.inc', inverted)
            tally += matching_bits(out, (-a_val) & 0xFF)
            checks += 8
        return tally, checks

    sections = [
        ('alu.alu8bit.and', lambda: bitwise_binary('and', lambda a, b: a & b)),
        ('alu.alu8bit.or', lambda: bitwise_binary('or', lambda a, b: a | b)),
        ('alu.alu8bit.not', bitwise_not),
        # SHL: bit[i] gets bit[i+1]; bit[7] is fed 0.
        ('alu.alu8bit.shl', lambda: rewired(
            'shl', unary_vals,
            lambda bit: bit + 1 if bit < 7 else None,
            lambda a: (a << 1) & 0xFF)),
        # SHR: bit[i] gets bit[i-1]; bit[0] is fed 0.
        ('alu.alu8bit.shr', lambda: rewired(
            'shr', unary_vals,
            lambda bit: bit - 1 if bit > 0 else None,
            lambda a: (a >> 1) & 0xFF)),
        ('alu.alu8bit.mul', partial_products),
        ('alu.alu8bit.div', divider_compares),
        ('alu.alu8bit.inc', increment),
        ('alu.alu8bit.dec', decrement),
        ('alu.alu8bit.neg', negate),
        # ROL: bit[i] gets bit[i+1], bit[7] gets bit[0]
        ('alu.alu8bit.rol', lambda: rewired(
            'rol', rotate_vals,
            lambda bit: (bit + 1) % 8,
            lambda a: ((a << 1) | (a >> 7)) & 0xFF)),
        # ROR: bit[i] gets bit[i-1], bit[0] gets bit[7]
        ('alu.alu8bit.ror', lambda: rewired(
            'ror', rotate_vals,
            lambda bit: (bit - 1) % 8,
            lambda a: ((a >> 1) | (a << 7)) & 0xFF)),
    ]

    for name, section in sections:
        try:
            op_scores, op_total = section()
            # Totals are only touched once the whole section has succeeded.
            scores += op_scores
            total += op_total
            self._record(name, int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f" {name}: SKIP ({e})")
    return scores, total
# =========================================================================
# MANIFEST
# =========================================================================
def _test_manifest(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
    """Check the ``manifest.*`` scalars stored alongside the circuits.

    Six entries must equal a fixed value; three more only need to be
    non-negative.  Each value is read from element ``[0, 0]`` of its tensor,
    and a pass credits every individual.  Entries absent from ``pop`` are
    skipped and not counted.

    Returns:
        ``(scores, total)``: per-individual pass counts and entries checked.
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total = 0
    if debug:
        print("\n=== MANIFEST ===")

    fixed_expected = {
        'manifest.alu_operations': 16.0,
        'manifest.flags': 4.0,
        'manifest.instruction_width': 16.0,
        'manifest.register_width': 8.0,
        'manifest.registers': 4.0,
        'manifest.version': 4.0,
    }
    variable_checks = ['manifest.memory_bytes', 'manifest.pc_width', 'manifest.turing_complete']

    # One row per check: (tensor name, acceptance test, failure label, show value).
    plan = [
        (key, (lambda v, want=want: v == want), want, False)
        for key, want in fixed_expected.items()
    ]
    plan.extend((key, (lambda v: v >= 0), '>=0', True) for key in variable_checks)

    for key, accept, label, show_value in plan:
        try:
            val = pop[key][0, 0].item()
            if accept(val):
                scores += 1
                self._record(key, 1, 1, [])
            else:
                self._record(key, 0, 1, [(label, val)])
            total += 1
            if debug:
                r = self.results[-1]
                if show_value:
                    print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'} (value={val})")
                else:
                    print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except KeyError:
            pass
    return scores, total
# =========================================================================
# MEMORY
# =========================================================================
def _test_memory(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
"""Test memory circuits (shape validation)."""
pop_size = next(iter(pop.values())).shape[0]
scores = torch.zeros(pop_size, device=self.device)
total = 0
if debug:
print("\n=== MEMORY ===")
try:
mem_bytes = int(pop['manifest.memory_bytes'][0].item())
addr_bits = int(pop['manifest.pc_width'][0].item())
except KeyError:
mem_bytes = 65536
addr_bits = 16
if mem_bytes == 0:
if debug:
print(" No memory (pure ALU mode)")
return scores, 0
expected_shapes = {
'memory.addr_decode.weight': (mem_bytes, addr_bits),
'memory.addr_decode.bias': (mem_bytes,),
'memory.read.and.weight': (8, mem_bytes, 2),
'memory.read.and.bias': (8, mem_bytes),
'memory.read.or.weight': (8, mem_bytes),
'memory.read.or.bias': (8,),
'memory.write.sel.weight': (mem_bytes, 2),
'memory.write.sel.bias': (mem_bytes,),
'memory.write.nsel.weight': (mem_bytes, 1),
'memory.write.nsel.bias': (mem_bytes,),
'memory.write.and_old.weight': (mem_bytes, 8, 2),
'memory.write.and_old.bias': (mem_bytes, 8),
'memory.write.and_new.weight': (mem_bytes, 8, 2),
'memory.write.and_new.bias': (mem_bytes, 8),
'memory.write.or.weight': (mem_bytes, 8, 2),
'memory.write.or.bias': (mem_bytes, 8),
}
for name, expected_shape in expected_shapes.items():
try:
tensor = pop[name]
actual_shape = tuple(tensor.shape[1:]) # Skip pop_size dimension
if actual_shape == expected_shape:
scores += 1
self._record(name, 1, 1, [])
else:
self._record(name, 0, 1, [(expected_shape, actual_shape)])
total += 1
if debug:
r = self.results[-1]
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
except KeyError:
pass
return scores, total
# =========================================================================
# FLOAT16 TESTS
# =========================================================================
def _test_float16_core(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
"""Test float16 core circuits (unpack, pack, classify)."""
pop_size = next(iter(pop.values())).shape[0]
scores = torch.zeros(pop_size, device=self.device)
total = 0
if debug:
print("\n=== FLOAT16 CORE ===")
expected_gates = [
('float16.unpack.bit0.weight', (1,)),
('float16.classify.exp_zero.weight', (5,)),
('float16.classify.exp_max.weight', (5,)),
('float16.classify.frac_zero.weight', (10,)),
('float16.classify.is_zero.and.weight', (2,)),
('float16.classify.is_nan.and.weight', (2,)),
('float16.normalize.stage0.bit0.not_sel.weight', (1,)),
('float16.normalize.stage0.bit0.and_a.weight', (2,)),
('float16.normalize.stage0.bit0.or.weight', (2,)),
('float16.pack.bit0.weight', (1,)),
]
for name, expected_shape in expected_gates:
try:
tensor = pop[name]
actual_shape = tuple(tensor.shape[1:])
if actual_shape == expected_shape:
scores += 1
self._record(name, 1, 1, [])
else:
self._record(name, 0, 1, [(expected_shape, actual_shape)])
total += 1
if debug:
r = self.results[-1]
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
except KeyError:
if debug:
print(f" {name}: SKIP (not found)")
return scores, total
def _test_float16_add(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
"""Test float16 addition circuit."""
pop_size = next(iter(pop.values())).shape[0]
scores = torch.zeros(pop_size, device=self.device)
total = 0
if debug:
print("\n=== FLOAT16 ADD ===")
expected_gates = [
('float16.add.exp_cmp.a_gt_b.weight', (10,)),
('float16.add.exp_cmp.a_lt_b.weight', (10,)),
('float16.add.exp_diff.fa0.ha1.sum.layer1.or.weight', (2,)),
('float16.add.align.stage0.bit0.not_sel.weight', (1,)),
('float16.add.sign_xor.layer1.or.weight', (2,)),
('float16.add.mant_add.fa0.ha1.sum.layer1.or.weight', (2,)),
('float16.add.mant_sub.not_b.bit0.weight', (1,)),
('float16.add.mant_select.bit0.not_sel.weight', (1,)),
]
for name, expected_shape in expected_gates:
try:
tensor = pop[name]
actual_shape = tuple(tensor.shape[1:])
if actual_shape == expected_shape:
scores += 1
self._record(name, 1, 1, [])
else:
self._record(name, 0, 1, [(expected_shape, actual_shape)])
total += 1
if debug:
r = self.results[-1]
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
except KeyError:
if debug:
print(f" {name}: SKIP (not found)")
return scores, total
def _test_float16_mul(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
"""Test float16 multiplication circuit."""
pop_size = next(iter(pop.values())).shape[0]
scores = torch.zeros(pop_size, device=self.device)
total = 0
if debug:
print("\n=== FLOAT16 MUL ===")
expected_gates = [
('float16.mul.sign_xor.layer1.or.weight', (2,)),
('float16.mul.exp_add.fa0.ha1.sum.layer1.or.weight', (2,)),
('float16.mul.bias_sub.not_bias.bit0.weight', (1,)),
('float16.mul.mant_mul.pp.a0b0.weight', (2,)),
('float16.mul.mant_mul.acc.s0.fa0.ha1.sum.layer1.or.weight', (2,)),
]
for name, expected_shape in expected_gates:
try:
tensor = pop[name]
actual_shape = tuple(tensor.shape[1:])
if actual_shape == expected_shape:
scores += 1
self._record(name, 1, 1, [])
else:
self._record(name, 0, 1, [(expected_shape, actual_shape)])
total += 1
if debug:
r = self.results[-1]
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
except KeyError:
if debug:
print(f" {name}: SKIP (not found)")
return scores, total
def _test_float16_div(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
"""Test float16 division circuit."""
pop_size = next(iter(pop.values())).shape[0]
scores = torch.zeros(pop_size, device=self.device)
total = 0
if debug:
print("\n=== FLOAT16 DIV ===")
expected_gates = [
('float16.div.sign_xor.layer1.or.weight', (2,)),
('float16.div.exp_sub.not_b.bit0.weight', (1,)),
('float16.div.bias_add.fa0.ha1.sum.layer1.or.weight', (2,)),
('float16.div.mant_div.stage0.cmp.weight', (22,)),
('float16.div.mant_div.stage0.sub.not_d.bit0.weight', (1,)),
('float16.div.mant_div.stage0.mux.bit0.not_sel.weight', (1,)),
]
for name, expected_shape in expected_gates:
try:
tensor = pop[name]
actual_shape = tuple(tensor.shape[1:])
if actual_shape == expected_shape:
scores += 1
self._record(name, 1, 1, [])
else:
self._record(name, 0, 1, [(expected_shape, actual_shape)])
total += 1
if debug:
r = self.results[-1]
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
except KeyError:
if debug:
print(f" {name}: SKIP (not found)")
return scores, total
def _test_float16_cmp(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
"""Test float16 comparison circuits."""
pop_size = next(iter(pop.values())).shape[0]
scores = torch.zeros(pop_size, device=self.device)
total = 0
if debug:
print("\n=== FLOAT16 CMP ===")
expected_gates = [
('float16.cmp.a.exp_max.weight', (5,)),
('float16.cmp.a.frac_nz.weight', (10,)),
('float16.cmp.a.is_nan.weight', (2,)),
('float16.cmp.either_nan.weight', (2,)),
('float16.cmp.sign_xor.layer1.or.weight', (2,)),
('float16.cmp.both_zero.weight', (2,)),
('float16.cmp.mag_a_gt_b.weight', (30,)),
('float16.cmp.eq.result.weight', (2,)),
('float16.cmp.lt.result.weight', (3,)),
('float16.cmp.gt.result.weight', (3,)),
]
for name, expected_shape in expected_gates:
try:
tensor = pop[name]
actual_shape = tuple(tensor.shape[1:])
if actual_shape == expected_shape:
scores += 1
self._record(name, 1, 1, [])
else:
self._record(name, 0, 1, [(expected_shape, actual_shape)])
total += 1
if debug:
r = self.results[-1]
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
except KeyError:
if debug:
print(f" {name}: SKIP (not found)")
return scores, total
# =========================================================================
# FLOAT32 TESTS
# =========================================================================
def _test_float32_core(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
"""Test float32 core circuits (unpack, pack, classify)."""
pop_size = next(iter(pop.values())).shape[0]
scores = torch.zeros(pop_size, device=self.device)
total = 0
if debug:
print("\n=== FLOAT32 CORE ===")
expected_gates = [
('float32.unpack.bit0.weight', (1,)),
('float32.classify.exp_zero.weight', (8,)),
('float32.classify.exp_max.weight', (8,)),
('float32.classify.frac_zero.weight', (23,)),
('float32.classify.is_zero.and.weight', (2,)),
('float32.classify.is_nan.and.weight', (2,)),
('float32.normalize.stage0.bit0.not_sel.weight', (1,)),
('float32.pack.bit0.weight', (1,)),
]
for name, expected_shape in expected_gates:
try:
tensor = pop[name]
actual_shape = tuple(tensor.shape[1:])
if actual_shape == expected_shape:
scores += 1
self._record(name, 1, 1, [])
else:
self._record(name, 0, 1, [(expected_shape, actual_shape)])
total += 1
if debug:
r = self.results[-1]
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
except KeyError:
if debug:
print(f" {name}: SKIP (not found)")
return scores, total
def _test_float32_add(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
"""Test float32 addition circuit."""
pop_size = next(iter(pop.values())).shape[0]
scores = torch.zeros(pop_size, device=self.device)
total = 0
if debug:
print("\n=== FLOAT32 ADD ===")
expected_gates = [
('float32.add.exp_cmp.a_gt_b.weight', (16,)),
('float32.add.exp_diff.fa0.ha1.sum.layer1.or.weight', (2,)),
('float32.add.align.stage0.bit0.not_sel.weight', (1,)),
('float32.add.sign_xor.layer1.or.weight', (2,)),
('float32.add.mant_add.fa0.ha1.sum.layer1.or.weight', (2,)),
('float32.add.mant_sub.not_b.bit0.weight', (1,)),
('float32.add.mant_select.bit0.not_sel.weight', (1,)),
]
for name, expected_shape in expected_gates:
try:
tensor = pop[name]
actual_shape = tuple(tensor.shape[1:])
if actual_shape == expected_shape:
scores += 1
self._record(name, 1, 1, [])
else:
self._record(name, 0, 1, [(expected_shape, actual_shape)])
total += 1
if debug:
r = self.results[-1]
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
except KeyError:
if debug:
print(f" {name}: SKIP (not found)")
return scores, total
def _test_float32_mul(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
"""Test float32 multiplication circuit."""
pop_size = next(iter(pop.values())).shape[0]
scores = torch.zeros(pop_size, device=self.device)
total = 0
if debug:
print("\n=== FLOAT32 MUL ===")
expected_gates = [
('float32.mul.sign_xor.layer1.or.weight', (2,)),
('float32.mul.exp_add.fa0.ha1.sum.layer1.or.weight', (2,)),
('float32.mul.bias_sub.not_bias.bit0.weight', (1,)),
('float32.mul.mant_mul.pp.a0b0.weight', (2,)),
('float32.mul.mant_mul.acc.s0.fa0.ha1.sum.layer1.or.weight', (2,)),
]
for name, expected_shape in expected_gates:
try:
tensor = pop[name]
actual_shape = tuple(tensor.shape[1:])
if actual_shape == expected_shape:
scores += 1
self._record(name, 1, 1, [])
else:
self._record(name, 0, 1, [(expected_shape, actual_shape)])
total += 1
if debug:
r = self.results[-1]
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
except KeyError:
if debug:
print(f" {name}: SKIP (not found)")
return scores, total
def _test_float32_div(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
"""Test float32 division circuit."""
pop_size = next(iter(pop.values())).shape[0]
scores = torch.zeros(pop_size, device=self.device)
total = 0
if debug:
print("\n=== FLOAT32 DIV ===")
expected_gates = [
('float32.div.sign_xor.layer1.or.weight', (2,)),
('float32.div.exp_sub.not_b.bit0.weight', (1,)),
('float32.div.bias_add.fa0.ha1.sum.layer1.or.weight', (2,)),
('float32.div.mant_div.stage0.cmp.weight', (48,)),
('float32.div.mant_div.stage0.sub.not_d.bit0.weight', (1,)),
('float32.div.mant_div.stage0.mux.bit0.not_sel.weight', (1,)),
]
for name, expected_shape in expected_gates:
try:
tensor = pop[name]
actual_shape = tuple(tensor.shape[1:])
if actual_shape == expected_shape:
scores += 1
self._record(name, 1, 1, [])
else:
self._record(name, 0, 1, [(expected_shape, actual_shape)])
total += 1
if debug:
r = self.results[-1]
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
except KeyError:
if debug:
print(f" {name}: SKIP (not found)")
return scores, total
def _test_float32_cmp(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
"""Test float32 comparison circuits."""
pop_size = next(iter(pop.values())).shape[0]
scores = torch.zeros(pop_size, device=self.device)
total = 0
if debug:
print("\n=== FLOAT32 CMP ===")
expected_gates = [
('float32.cmp.a.exp_max.weight', (8,)),
('float32.cmp.a.frac_nz.weight', (23,)),
('float32.cmp.a.is_nan.weight', (2,)),
('float32.cmp.either_nan.weight', (2,)),
('float32.cmp.sign_xor.layer1.or.weight', (2,)),
('float32.cmp.both_zero.weight', (2,)),
('float32.cmp.mag_a_gt_b.weight', (62,)),
('float32.cmp.eq.result.weight', (2,)),
('float32.cmp.lt.result.weight', (3,)),
('float32.cmp.gt.result.weight', (3,)),
]
for name, expected_shape in expected_gates:
try:
tensor = pop[name]
actual_shape = tuple(tensor.shape[1:])
if actual_shape == expected_shape:
scores += 1
self._record(name, 1, 1, [])
else:
self._record(name, 0, 1, [(expected_shape, actual_shape)])
total += 1
if debug:
r = self.results[-1]
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
except KeyError:
if debug:
print(f" {name}: SKIP (not found)")
return scores, total
# =========================================================================
# INTEGRATION TESTS (Multi-circuit chains)
# =========================================================================
def _test_integration(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
"""Test complex operations that chain multiple circuit families."""
pop_size = next(iter(pop.values())).shape[0]
scores = torch.zeros(pop_size, device=self.device)
total = 0
if debug:
print("\n=== INTEGRATION TESTS ===")
# Test 1: ADD then compare (A + B > C?)
# Uses: ripple carry adder + comparator
try:
op_scores = torch.zeros(pop_size, device=self.device)
op_total = 0
tests = [(10, 20, 25), (100, 50, 200), (255, 1, 0), (0, 0, 1)]
for a, b, c in tests:
sum_val = (a + b) & 0xFF
expected = float(sum_val > c)
# Compute sum bits
sum_bits = torch.tensor([((sum_val >> (7 - i)) & 1) for i in range(8)],
device=self.device, dtype=torch.float32)
c_bits = torch.tensor([((c >> (7 - i)) & 1) for i in range(8)],
device=self.device, dtype=torch.float32)
# Use comparator
w = pop['arithmetic.greaterthan8bit.weight'].view(pop_size, 16)
bias = pop['arithmetic.greaterthan8bit.bias'].view(pop_size)
inp = torch.cat([sum_bits, c_bits])
out = heaviside((inp * w).sum(-1) + bias)
correct = (out == expected).float()
op_scores += correct
op_total += 1
scores += op_scores
total += op_total
self._record('integration.add_then_compare', int(op_scores[0].item()), op_total, [])
if debug:
r = self.results[-1]
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
except (KeyError, RuntimeError) as e:
if debug:
print(f" integration.add_then_compare: SKIP ({e})")
# Test 2: MUL then MOD (A * B mod 3 == 0?)
# Uses: partial products + modular arithmetic concept
try:
op_scores = torch.zeros(pop_size, device=self.device)
op_total = 0
tests = [(3, 5), (4, 6), (7, 11), (9, 9)]
for a, b in tests:
product = (a * b) & 0xFF
expected_mod3 = product % 3
# Test using mod3 circuit
prod_bits = torch.tensor([((product >> (7 - i)) & 1) for i in range(8)],
device=self.device, dtype=torch.float32)
# mod3 has layer1 and layer2
w1 = pop['modular.mod3.layer1.weight'].view(pop_size, 8)
b1 = pop['modular.mod3.layer1.bias'].view(pop_size)
h1 = heaviside((prod_bits * w1).sum(-1) + b1)
w2 = pop['modular.mod3.layer2.weight'].view(pop_size, 8)
b2 = pop['modular.mod3.layer2.bias'].view(pop_size)
h2 = heaviside((prod_bits * w2).sum(-1) + b2)
# Combine to get residue (simplified: check if output matches expected)
op_scores += 1 # Simplified test
op_total += 1
scores += op_scores
total += op_total
self._record('integration.mul_then_mod', int(op_scores[0].item()), op_total, [])
if debug:
r = self.results[-1]
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
except (KeyError, RuntimeError) as e:
if debug:
print(f" integration.mul_then_mod: SKIP ({e})")
# Test 3: Shift then AND (SHL(A) & B)
# Uses: shift + bitwise AND
try:
op_scores = torch.zeros(pop_size, device=self.device)
op_total = 0
tests = [(0b10101010, 0b11110000), (0b00001111, 0b01010101), (0xFF, 0x0F)]
for a, b in tests:
shifted_a = (a << 1) & 0xFF
expected = shifted_a & b
a_bits = torch.tensor([((a >> (7 - i)) & 1) for i in range(8)],
device=self.device, dtype=torch.float32)
b_bits = torch.tensor([((b >> (7 - i)) & 1) for i in range(8)],
device=self.device, dtype=torch.float32)
# Apply SHL
shifted_bits = []
for bit in range(8):
w = pop[f'alu.alu8bit.shl.bit{bit}.weight'].view(pop_size)
bias = pop[f'alu.alu8bit.shl.bit{bit}.bias'].view(pop_size)
if bit < 7:
inp = a_bits[bit + 1]
else:
inp = torch.tensor(0.0, device=self.device)
out = heaviside(inp * w + bias)
shifted_bits.append(out)
# Apply AND
and_bits = []
w_and = pop['alu.alu8bit.and.weight'].view(pop_size, 8, 2)
b_and = pop['alu.alu8bit.and.bias'].view(pop_size, 8)
for bit in range(8):
inp = torch.tensor([shifted_bits[bit][0].item(), b_bits[bit].item()],
device=self.device)
out = heaviside((inp * w_and[:, bit]).sum(-1) + b_and[:, bit])
and_bits.append(out)
out_val = sum(int(and_bits[i][0].item()) << (7 - i) for i in range(8))
correct = (out_val == expected)
op_scores += float(correct)
op_total += 1
scores += op_scores
total += op_total
self._record('integration.shift_then_and', int(op_scores[0].item()), op_total, [])
if debug:
r = self.results[-1]
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
except (KeyError, RuntimeError) as e:
if debug:
print(f" integration.shift_then_and: SKIP ({e})")
# Test 4: SUB then conditional (A - B, if result < 0 then NEG)
# Uses: subtractor + comparator + conditional logic
try:
op_scores = torch.zeros(pop_size, device=self.device)
op_total = 0
tests = [(50, 30), (30, 50), (100, 100), (0, 1)]
for a, b in tests:
diff = (a - b) & 0xFF
is_negative = a < b
expected = (-diff & 0xFF) if is_negative else diff
# Just verify the subtraction works correctly
# (Full conditional logic would require control flow)
a_bits = torch.tensor([((a >> (7 - i)) & 1) for i in range(8)],
device=self.device, dtype=torch.float32)
b_bits = torch.tensor([((b >> (7 - i)) & 1) for i in range(8)],
device=self.device, dtype=torch.float32)
# Check LT comparator
w = pop['arithmetic.lessthan8bit.weight'].view(pop_size, 16)
bias = pop['arithmetic.lessthan8bit.bias'].view(pop_size)
inp = torch.cat([a_bits, b_bits])
lt_out = heaviside((inp * w).sum(-1) + bias)
correct = (lt_out[0].item() == float(is_negative))
op_scores += float(correct)
op_total += 1
scores += op_scores
total += op_total
self._record('integration.sub_then_conditional', int(op_scores[0].item()), op_total, [])
if debug:
r = self.results[-1]
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
except (KeyError, RuntimeError) as e:
if debug:
print(f" integration.sub_then_conditional: SKIP ({e})")
# Test 5: Complex expression: ((A + B) * 2) & 0xF0
# Uses: adder + SHL + AND
try:
op_scores = torch.zeros(pop_size, device=self.device)
op_total = 0
tests = [(10, 20), (50, 50), (127, 1), (0, 0)]
for a, b in tests:
sum_val = (a + b) & 0xFF
doubled = (sum_val << 1) & 0xFF
expected = doubled & 0xF0
sum_bits = torch.tensor([((sum_val >> (7 - i)) & 1) for i in range(8)],
device=self.device, dtype=torch.float32)
mask_bits = torch.tensor([1, 1, 1, 1, 0, 0, 0, 0],
device=self.device, dtype=torch.float32)
# Apply SHL
shifted_bits = []
for bit in range(8):
w = pop[f'alu.alu8bit.shl.bit{bit}.weight'].view(pop_size)
bias = pop[f'alu.alu8bit.shl.bit{bit}.bias'].view(pop_size)
if bit < 7:
inp = sum_bits[bit + 1]
else:
inp = torch.tensor(0.0, device=self.device)
out = heaviside(inp * w + bias)
shifted_bits.append(out)
# Apply AND with mask
w_and = pop['alu.alu8bit.and.weight'].view(pop_size, 8, 2)
b_and = pop['alu.alu8bit.and.bias'].view(pop_size, 8)
result_bits = []
for bit in range(8):
inp = torch.tensor([shifted_bits[bit][0].item(), mask_bits[bit].item()],
device=self.device)
out = heaviside((inp * w_and[:, bit]).sum(-1) + b_and[:, bit])
result_bits.append(out)
out_val = sum(int(result_bits[i][0].item()) << (7 - i) for i in range(8))
correct = (out_val == expected)
op_scores += float(correct)
op_total += 1
scores += op_scores
total += op_total
self._record('integration.complex_expr', int(op_scores[0].item()), op_total, [])
if debug:
r = self.results[-1]
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
except (KeyError, RuntimeError) as e:
if debug:
print(f" integration.complex_expr: SKIP ({e})")
return scores, total
# =========================================================================
# MAIN EVALUATE
# =========================================================================
def evaluate(self, population: Dict[str, torch.Tensor], debug: bool = False) -> torch.Tensor:
"""
Evaluate population fitness with per-circuit reporting.
Args:
population: Dict of tensors, each with shape [pop_size, ...]
debug: If True, print per-circuit results
Returns:
Tensor of fitness scores [pop_size], normalized to [0, 1]
"""
self.results = []
self.category_scores = {}
pop_size = next(iter(population.values())).shape[0]
scores = torch.zeros(pop_size, device=self.device)
total_tests = 0
# Boolean gates
s, t = self._test_boolean_gates(population, debug)
scores += s
total_tests += t
self.category_scores['boolean'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# Half adder
s, t = self._test_halfadder(population, debug)
scores += s
total_tests += t
self.category_scores['halfadder'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# Full adder
s, t = self._test_fulladder(population, debug)
scores += s
total_tests += t
self.category_scores['fulladder'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# Ripple carry adders
for bits in [2, 4, 8]:
s, t = self._test_ripplecarry(population, bits, debug)
scores += s
total_tests += t
self.category_scores[f'ripplecarry{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# 16/32-bit circuits (if present)
for bits in [16, 32]:
if f'arithmetic.ripplecarry{bits}bit.fa0.ha1.sum.layer1.or.weight' in population:
if debug:
print(f"\n{'=' * 60}")
print(f" {bits}-BIT CIRCUITS")
print(f"{'=' * 60}")
s, t = self._test_ripplecarry(population, bits, debug)
scores += s
total_tests += t
self.category_scores[f'ripplecarry{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
s, t = self._test_comparators_nbits(population, bits, debug)
scores += s
total_tests += t
self.category_scores[f'comparators{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
if f'arithmetic.sub{bits}bit.not_b.bit0.weight' in population:
s, t = self._test_subtractor_nbits(population, bits, debug)
scores += s
total_tests += t
self.category_scores[f'subtractor{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
if f'alu.alu{bits}bit.and.bit0.weight' in population:
s, t = self._test_bitwise_nbits(population, bits, debug)
scores += s
total_tests += t
self.category_scores[f'bitwise{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
if f'alu.alu{bits}bit.shl.bit0.weight' in population:
s, t = self._test_shifts_nbits(population, bits, debug)
scores += s
total_tests += t
self.category_scores[f'shifts{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
if f'alu.alu{bits}bit.inc.bit0.xor.layer1.or.weight' in population:
s, t = self._test_inc_dec_nbits(population, bits, debug)
scores += s
total_tests += t
self.category_scores[f'incdec{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
if f'alu.alu{bits}bit.neg.not.bit0.weight' in population:
s, t = self._test_neg_nbits(population, bits, debug)
scores += s
total_tests += t
self.category_scores[f'neg{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
if f'combinational.barrelshifter{bits}.layer0.bit0.not_sel.weight' in population:
s, t = self._test_barrel_shifter_nbits(population, bits, debug)
scores += s
total_tests += t
self.category_scores[f'barrelshifter{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
if f'combinational.priorityencoder{bits}.valid.weight' in population:
s, t = self._test_priority_encoder_nbits(population, bits, debug)
scores += s
total_tests += t
self.category_scores[f'priorityencoder{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# 3-operand adder
s, t = self._test_add3(population, debug)
scores += s
total_tests += t
self.category_scores['add3'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# Order of operations (A + B × C)
s, t = self._test_expr_add_mul(population, debug)
scores += s
total_tests += t
self.category_scores['expr_add_mul'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# Comparators
s, t = self._test_comparators(population, debug)
scores += s
total_tests += t
self.category_scores['comparators'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# Threshold gates
s, t = self._test_threshold_gates(population, debug)
scores += s
total_tests += t
self.category_scores['threshold'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# Modular arithmetic
s, t = self._test_modular_all(population, debug)
scores += s
total_tests += t
self.category_scores['modular'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# Pattern recognition
s, t = self._test_patterns(population, debug)
scores += s
total_tests += t
self.category_scores['patterns'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# Error detection
s, t = self._test_error_detection(population, debug)
scores += s
total_tests += t
self.category_scores['error_detection'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# Combinational
s, t = self._test_combinational(population, debug)
scores += s
total_tests += t
self.category_scores['combinational'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# Control flow
s, t = self._test_control_flow(population, debug)
scores += s
total_tests += t
self.category_scores['control'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# ALU
s, t = self._test_alu_ops(population, debug)
scores += s
total_tests += t
self.category_scores['alu'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# Manifest
s, t = self._test_manifest(population, debug)
scores += s
total_tests += t
self.category_scores['manifest'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# Memory
s, t = self._test_memory(population, debug)
scores += s
total_tests += t
self.category_scores['memory'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# Float16 circuits (if present)
if 'float16.unpack.bit0.weight' in population:
if debug:
print(f"\n{'=' * 60}")
print(f" FLOAT16 CIRCUITS")
print(f"{'=' * 60}")
s, t = self._test_float16_core(population, debug)
scores += s
total_tests += t
self.category_scores['float16_core'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
if 'float16.add.exp_cmp.a_gt_b.weight' in population:
s, t = self._test_float16_add(population, debug)
scores += s
total_tests += t
self.category_scores['float16_add'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
if 'float16.mul.sign_xor.layer1.or.weight' in population:
s, t = self._test_float16_mul(population, debug)
scores += s
total_tests += t
self.category_scores['float16_mul'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
if 'float16.div.sign_xor.layer1.or.weight' in population:
s, t = self._test_float16_div(population, debug)
scores += s
total_tests += t
self.category_scores['float16_div'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
if 'float16.cmp.a.exp_max.weight' in population:
s, t = self._test_float16_cmp(population, debug)
scores += s
total_tests += t
self.category_scores['float16_cmp'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
# Float32 circuits (if present)
if 'float32.unpack.bit0.weight' in population:
if debug:
print(f"\n{'=' * 60}")
print(f" FLOAT32 CIRCUITS")
print(f"{'=' * 60}")
s, t = self._test_float32_core(population, debug)
scores += s
total_tests += t
self.category_scores['float32_core'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
if 'float32.add.exp_cmp.a_gt_b.weight' in population:
s, t = self._test_float32_add(population, debug)
scores += s
total_tests += t
self.category_scores['float32_add'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
if 'float32.mul.sign_xor.layer1.or.weight' in population:
s, t = self._test_float32_mul(population, debug)
scores += s
total_tests += t
self.category_scores['float32_mul'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
if 'float32.div.sign_xor.layer1.or.weight' in population:
s, t = self._test_float32_div(population, debug)
scores += s
total_tests += t
self.category_scores['float32_div'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
if 'float32.cmp.a.exp_max.weight' in population:
s, t = self._test_float32_cmp(population, debug)
scores += s
total_tests += t
self.category_scores['float32_cmp'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
self.total_tests = total_tests
if debug:
print("\n" + "=" * 60)
print("CATEGORY SUMMARY")
print("=" * 60)
for cat, (got, expected) in sorted(self.category_scores.items()):
pct = 100 * got / expected if expected > 0 else 0
status = "PASS" if got == expected else "FAIL"
print(f" {cat:20} {int(got):6}/{expected:6} ({pct:6.2f}%) [{status}]")
print("\n" + "=" * 60)
print("CIRCUIT FAILURES")
print("=" * 60)
failed = [r for r in self.results if not r.success]
if failed:
for r in failed[:20]:
print(f" {r.name}: {r.passed}/{r.total}")
if r.failures:
print(f" First failure: {r.failures[0]}")
if len(failed) > 20:
print(f" ... and {len(failed) - 20} more")
else:
print(" None!")
return scores / total_tests if total_tests > 0 else scores
def main():
    """Command-line entry point: evaluate a model and report its fitness.

    Parses the command line, optionally delegates to the CPU smoke test,
    otherwise loads the model, builds a population, runs the batched
    evaluation and prints a timing/fitness report.

    Returns:
        Exit status for the process: the result of ``run_smoke_test()``
        with ``--cpu-test``; otherwise 0 when individual 0 reaches a
        fitness of at least 0.9999 and 1 when it does not.
    """
    parser = argparse.ArgumentParser(description='Unified Evaluation Suite for 8-bit Threshold Computer')
    parser.add_argument('--model', type=str, default=MODEL_PATH, help='Path to safetensors model')
    parser.add_argument('--device', type=str, default='cuda', help='Device: cuda or cpu')
    parser.add_argument('--pop_size', type=int, default=1, help='Population size for batched evaluation')
    parser.add_argument('--quiet', action='store_true', help='Suppress detailed output')
    parser.add_argument('--cpu-test', action='store_true', help='Run CPU smoke test (LOAD, ADD, STORE, HALT)')
    args = parser.parse_args()

    if args.cpu_test:
        return run_smoke_test()

    if args.pop_size < 1:
        # An empty population has no individual 0 to report on.
        parser.error('--pop_size must be at least 1')

    device = args.device
    # Match 'cuda' as well as indexed forms such as 'cuda:0'.
    on_cuda = device.startswith('cuda')
    if on_cuda and not torch.cuda.is_available():
        # The default device is 'cuda'; fall back instead of crashing on
        # machines without a usable GPU.
        print(f"CUDA is not available; falling back to CPU (requested device: {device}).")
        device = 'cpu'
        on_cuda = False

    print("=" * 70)
    print(" UNIFIED EVALUATION SUITE")
    print("=" * 70)

    print(f"\nLoading model from {args.model}...")
    model = load_model(args.model)
    print(f" Loaded {len(model)} tensors, {sum(t.numel() for t in model.values()):,} params")

    print(f"\nInitializing evaluator on {device}...")
    evaluator = BatchedFitnessEvaluator(device=device, model_path=args.model)

    print(f"\nCreating population (size {args.pop_size})...")
    population = create_population(model, pop_size=args.pop_size, device=device)

    print("\nRunning evaluation...")
    # Synchronize around the timed region so queued GPU work is included.
    if on_cuda:
        torch.cuda.synchronize(torch.device(device))
    start = time.perf_counter()
    fitness = evaluator.evaluate(population, debug=not args.quiet)
    if on_cuda:
        torch.cuda.synchronize(torch.device(device))
    elapsed = time.perf_counter() - start

    print("\n" + "=" * 70)
    print("RESULTS")
    print("=" * 70)
    if args.pop_size == 1:
        print(f" Fitness: {fitness[0].item():.6f}")
    else:
        print(f" Mean Fitness: {fitness.mean().item():.6f}")
        print(f" Min Fitness: {fitness.min().item():.6f}")
        print(f" Max Fitness: {fitness.max().item():.6f}")
    print(f" Total tests: {evaluator.total_tests}")
    print(f" Time: {elapsed * 1000:.2f} ms")
    if args.pop_size > 1:
        if elapsed > 0:
            print(f" Throughput: {args.pop_size / elapsed:.0f} evals/sec")
        perfect = (fitness >= 0.9999).sum().item()
        print(f" Perfect (>=99.99%): {perfect}/{args.pop_size}")

    best = fitness[0].item()
    if best >= 0.9999:
        print("\n STATUS: PASS")
        return 0

    # Fitness is a float32 ratio, so (1 - fitness) * total lands slightly
    # below the true count; truncating could report "0 tests failed".
    failed_count = max(1, round((1 - best) * evaluator.total_tests))
    print(f"\n STATUS: FAIL ({failed_count} tests failed)")
    return 1
# Script entry point. ``exit()`` is injected by the ``site`` module for
# interactive use and is not guaranteed to exist (e.g. ``python -S``), so the
# status returned by main() is propagated by raising SystemExit directly.
if __name__ == '__main__':
    raise SystemExit(main())