"""
Unified Evaluation Suite for 8-bit Threshold Computer
======================================================

GPU-batched evaluation with per-circuit reporting.
Includes CPU runtime for threshold-weight execution.

Usage:
    python eval.py                    # Run circuit evaluation
    python eval.py --device cpu       # CPU mode
    python eval.py --pop_size 1000    # Population mode for evolution
    python eval.py --cpu-test         # Run CPU smoke test

API (for prune_weights.py):
    from eval import load_model, create_population, BatchedFitnessEvaluator
    from eval import ThresholdCPU, ThresholdALU, CPUState
"""

import argparse
import json
import os
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple

import torch
from safetensors import safe_open

# Default weights file, resolved relative to this module's directory.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "neural_computer.safetensors")


@dataclass
class CircuitResult:
    """Result for a single circuit test."""

    name: str  # name / tensor prefix of the circuit that was tested
    passed: int  # number of test vectors whose output matched the expectation
    total: int  # number of test vectors evaluated
    failures: List[Tuple] = field(default_factory=list)  # e.g. (inputs, expected, got)

    @property
    def success(self) -> bool:
        """True when every test vector passed."""
        return self.passed == self.total

    @property
    def rate(self) -> float:
        """Fraction of passing vectors; 0.0 when no vectors were run."""
        return self.passed / self.total if self.total > 0 else 0.0


def heaviside(x: torch.Tensor) -> torch.Tensor:
    """Threshold activation: 1 if x >= 0, else 0 (returned as float)."""
    return (x >= 0).float()


def load_model(path: str = MODEL_PATH) -> Dict[str, torch.Tensor]:
    """Load every tensor from a safetensors file, cast to float32.

    Args:
        path: Path to the ``.safetensors`` weights file.

    Returns:
        Mapping of tensor name -> float tensor.
    """
    with safe_open(path, framework='pt') as f:
        return {name: f.get_tensor(name).float() for name in f.keys()}


def load_metadata(path: str = MODEL_PATH) -> Dict:
    """Load metadata from safetensors (includes signal_registry).

    The ``signal_registry`` header entry is stored as a JSON string and is
    decoded here. Files without metadata, or without that key, yield
    ``{'signal_registry': {}}``.
    """
    with safe_open(path, framework='pt') as f:
        meta = f.metadata()
    if meta and 'signal_registry' in meta:
        return {'signal_registry': json.loads(meta['signal_registry'])}
    return {'signal_registry': {}}


def get_manifest(tensors: Dict[str, torch.Tensor]) -> Dict[str, int]:
    """Extract manifest values from tensors.

    Returns dict with data_bits, addr_bits, memory_bytes, version.
    Defaults to 8-bit data, 16-bit addr for legacy models (and 65536 memory
    bytes, version 1.0). ``addr_bits`` falls back to ``manifest.pc_width``
    before the 16-bit default. Note: ``version`` is returned as a float.
    """
    return {
        'data_bits': int(tensors.get('manifest.data_bits', torch.tensor([8.0])).item()),
        'addr_bits': int(
            tensors.get(
                'manifest.addr_bits',
                tensors.get('manifest.pc_width', torch.tensor([16.0])),
            ).item()
        ),
        'memory_bytes': int(tensors.get('manifest.memory_bytes', torch.tensor([65536.0])).item()),
        'version': float(tensors.get('manifest.version', torch.tensor([1.0])).item()),
    }


def create_population(
    base_tensors: Dict[str, torch.Tensor],
    pop_size: int,
    device: str = 'cuda',
) -> Dict[str, torch.Tensor]:
    """Replicate base tensors for batched population evaluation.

    Each tensor of shape ``S`` becomes an independent (cloned) tensor of shape
    ``(pop_size, *S)`` on ``device``, so individual members can be mutated
    without affecting each other or ``base_tensors``.
    """
    # Move to the target device *before* replicating so the pop_size-times
    # larger copy is only materialised once, on the target device.
    return {
        name: tensor.to(device).unsqueeze(0).expand(pop_size, *tensor.shape).clone()
        for name, tensor in base_tensors.items()
    }


# =============================================================================
# CPU RUNTIME
# =============================================================================

FLAG_NAMES = ["Z", "N", "C", "V"]
CTRL_NAMES = ["HALT", "MEM_WE", "MEM_RE", "RESERVED"]

PC_BITS = 16
IR_BITS = 16
REG_BITS = 8
REG_COUNT = 4
FLAG_BITS = 4
SP_BITS = 16
CTRL_BITS = 4
MEM_BYTES = 65536
MEM_BITS = MEM_BYTES * 8

# Total width of a packed CPUState (see pack_state / unpack_state).
STATE_BITS = PC_BITS + IR_BITS + (REG_BITS * REG_COUNT) + FLAG_BITS + SP_BITS + CTRL_BITS + MEM_BITS

# Conditional-jump condition (imm8 & 0x7) -> (flag index, value that takes the branch).
# Flag indices follow FLAG_NAMES: 0=Z, 1=N, 2=C, 3=V.
_BRANCH_CONDITIONS: Tuple[Tuple[int, int], ...] = (
    (0, 1),  # 0: jz  - Z set
    (0, 0),  # 1: jnz - Z clear
    (2, 1),  # 2: jc  - C set
    (2, 0),  # 3: jnc - C clear
    (1, 1),  # 4: jn  - N set
    (1, 0),  # 5: jp  - N clear
    (3, 1),  # 6: jv  - V set
    (3, 0),  # 7: jnv - V clear
)


def int_to_bits(value: int, width: int) -> List[int]:
    """Return the low ``width`` bits of ``value``, most-significant bit first."""
    return [(value >> (width - 1 - i)) & 1 for i in range(width)]


def bits_to_int(bits: List[int]) -> int:
    """Interpret ``bits`` (most-significant bit first) as an unsigned integer."""
    value = 0
    for bit in bits:
        value = (value << 1) | int(bit)
    return value


def bits_msb_to_lsb(bits: List[int]) -> List[int]:
    """Reverse a bit list (MSB-first <-> LSB-first)."""
    return list(reversed(bits))


@dataclass
class CPUState:
    """Complete architectural state of the CPU; field order matches pack_state()."""

    pc: int  # program counter (PC_BITS wide)
    ir: int  # instruction register (IR_BITS wide)
    regs: List[int]  # REG_COUNT general-purpose registers, REG_BITS each
    flags: List[int]  # 0/1 flags in FLAG_NAMES order: [Z, N, C, V]
    sp: int  # stack pointer (SP_BITS wide)
    ctrl: List[int]  # control bits in CTRL_NAMES order; ctrl[0] is HALT
    mem: List[int]  # memory bytes (MEM_BYTES for a full state)

    def copy(self) -> 'CPUState':
        """Return a deep copy; all lists are duplicated and values coerced to int."""
        return CPUState(
            pc=int(self.pc),
            ir=int(self.ir),
            regs=[int(r) for r in self.regs],
            flags=[int(f) for f in self.flags],
            sp=int(self.sp),
            ctrl=[int(c) for c in self.ctrl],
            mem=[int(m) for m in self.mem],
        )


def pack_state(state: CPUState) -> List[int]:
    """Serialise a CPUState into a flat MSB-first bit list.

    Layout: pc, ir, regs..., flags, sp, ctrl, mem... For a canonically sized
    state the result is STATE_BITS long and round-trips through unpack_state().
    """
    bits: List[int] = []
    bits.extend(int_to_bits(state.pc, PC_BITS))
    bits.extend(int_to_bits(state.ir, IR_BITS))
    for reg in state.regs:
        bits.extend(int_to_bits(reg, REG_BITS))
    bits.extend(int(f) for f in state.flags)
    bits.extend(int_to_bits(state.sp, SP_BITS))
    bits.extend(int(c) for c in state.ctrl)
    for byte in state.mem:
        bits.extend(int_to_bits(byte, REG_BITS))
    return bits


def unpack_state(bits: List[int]) -> CPUState:
    """Inverse of pack_state().

    Raises:
        ValueError: if ``bits`` is not exactly STATE_BITS long.
    """
    if len(bits) != STATE_BITS:
        raise ValueError(f"Expected {STATE_BITS} bits, got {len(bits)}")

    idx = 0
    pc = bits_to_int(bits[idx:idx + PC_BITS])
    idx += PC_BITS
    ir = bits_to_int(bits[idx:idx + IR_BITS])
    idx += IR_BITS

    regs = []
    for _ in range(REG_COUNT):
        regs.append(bits_to_int(bits[idx:idx + REG_BITS]))
        idx += REG_BITS

    flags = [int(b) for b in bits[idx:idx + FLAG_BITS]]
    idx += FLAG_BITS
    sp = bits_to_int(bits[idx:idx + SP_BITS])
    idx += SP_BITS
    ctrl = [int(b) for b in bits[idx:idx + CTRL_BITS]]
    idx += CTRL_BITS

    mem = []
    for _ in range(MEM_BYTES):
        mem.append(bits_to_int(bits[idx:idx + REG_BITS]))
        idx += REG_BITS

    return CPUState(pc=pc, ir=ir, regs=regs, flags=flags, sp=sp, ctrl=ctrl, mem=mem)


def decode_ir(ir: int) -> Tuple[int, int, int, int]:
    """Split a 16-bit instruction word into (opcode, rd, rs, imm8).

    Bit layout: [15:12] opcode, [11:10] rd, [9:8] rs, [7:0] imm8.
    """
    opcode = (ir >> 12) & 0xF
    rd = (ir >> 10) & 0x3
    rs = (ir >> 8) & 0x3
    imm8 = ir & 0xFF
    return opcode, rd, rs, imm8


def flags_from_result(result: int, carry: int, overflow: int) -> Tuple[int, int, int, int]:
    """Derive (Z, N, C, V) from an 8-bit result plus carry/overflow indicators."""
    z = 1 if result == 0 else 0
    n = 1 if (result & 0x80) else 0
    c = 1 if carry else 0
    v = 1 if overflow else 0
    return z, n, c, v


def alu_add(a: int, b: int) -> Tuple[int, int, int]:
    """8-bit add: returns (result, carry-out, signed overflow)."""
    full = a + b
    result = full & 0xFF
    carry = 1 if full > 0xFF else 0
    # Signed overflow: both operands differ in sign from the result.
    overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0
    return result, carry, overflow


def alu_sub(a: int, b: int) -> Tuple[int, int, int]:
    """8-bit subtract a - b: returns (result, carry, signed overflow).

    Carry follows the "no borrow" convention: 1 when a >= b.
    """
    result = (a - b) & 0xFF
    carry = 1 if a >= b else 0
    # Signed overflow: operands differ in sign and the result's sign differs from a.
    overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0
    return result, carry, overflow


def ref_step(state: CPUState) -> CPUState:
    """Reference CPU cycle (pure Python arithmetic).

    Fetches the big-endian instruction word at ``pc``; opcodes 0xA-0xE also
    fetch a big-endian 16-bit address word that follows it. Opcodes:
    0 ADD, 1 SUB, 2 AND, 3 OR, 4 XOR, 5 SHL, 6 SHR, 7 MUL (low byte),
    8 DIV (x/0 -> 0xFF), 9 compare (SUB, flags only), A rd <- mem[addr],
    B mem[addr] <- rs, C jump, D conditional jump (imm8 & 7 selects the
    condition), E call (pushes return address high byte then low byte),
    F set HALT. Returns a new state; ``state`` is not modified. A halted
    state is returned as an unchanged copy.
    """
    if state.ctrl[0] == 1:
        return state.copy()

    s = state.copy()
    hi = s.mem[s.pc]
    lo = s.mem[(s.pc + 1) & 0xFFFF]
    s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
    next_pc = (s.pc + 2) & 0xFFFF

    opcode, rd, rs, imm8 = decode_ir(s.ir)
    a = s.regs[rd]
    b = s.regs[rs]

    # Memory/control opcodes carry an extra 16-bit address word.
    addr16 = None
    next_pc_ext = next_pc
    if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
        addr_hi = s.mem[next_pc]
        addr_lo = s.mem[(next_pc + 1) & 0xFFFF]
        addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
        next_pc_ext = (next_pc + 2) & 0xFFFF

    write_result = True
    result = a
    carry = 0
    overflow = 0

    if opcode == 0x0:
        result, carry, overflow = alu_add(a, b)
    elif opcode == 0x1:
        result, carry, overflow = alu_sub(a, b)
    elif opcode == 0x2:
        result = a & b
    elif opcode == 0x3:
        result = a | b
    elif opcode == 0x4:
        result = a ^ b
    elif opcode == 0x5:
        result = (a << 1) & 0xFF
    elif opcode == 0x6:
        result = (a >> 1) & 0xFF
    elif opcode == 0x7:
        result = (a * b) & 0xFF
    elif opcode == 0x8:
        result = 0xFF if b == 0 else a // b
    elif opcode == 0x9:
        result, carry, overflow = alu_sub(a, b)
        write_result = False
    elif opcode == 0xA:
        result = s.mem[addr16]
    elif opcode == 0xB:
        s.mem[addr16] = b & 0xFF
        write_result = False
    elif opcode == 0xC:
        s.pc = addr16 & 0xFFFF
        write_result = False
    elif opcode == 0xD:
        flag_idx, taken_when = _BRANCH_CONDITIONS[imm8 & 0x7]
        if s.flags[flag_idx] == taken_when:
            s.pc = addr16 & 0xFFFF
        else:
            s.pc = next_pc_ext
        write_result = False
    elif opcode == 0xE:
        ret_addr = next_pc_ext & 0xFFFF
        s.sp = (s.sp - 1) & 0xFFFF
        s.mem[s.sp] = (ret_addr >> 8) & 0xFF
        s.sp = (s.sp - 1) & 0xFFFF
        s.mem[s.sp] = ret_addr & 0xFF
        s.pc = addr16 & 0xFFFF
        write_result = False
    elif opcode == 0xF:
        s.ctrl[0] = 1
        write_result = False

    # Flags are updated by the ALU opcodes (0x0-0x9) and by LOAD (0xA).
    if opcode <= 0xA:
        s.flags = list(flags_from_result(result, carry, overflow))
    if write_result:
        s.regs[rd] = result & 0xFF
    # Jumps/calls already set pc; everything else falls through.
    if opcode not in (0xC, 0xD, 0xE):
        s.pc = next_pc_ext
    return s


def ref_run_until_halt(state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
    """Reference execution loop.

    Returns (final state, cycles executed); stops early once HALT is set,
    otherwise after ``max_cycles`` steps.
    """
    s = state.copy()
    for i in range(max_cycles):
        if s.ctrl[0] == 1:
            return s, i
        s = ref_step(s)
    return s, max_cycles


class ThresholdALU:
    """8-bit ALU in which every bit is computed by threshold gates from the model.

    Each gate evaluates ``heaviside(sum(inputs * weight) + bias)`` using the
    weight/bias tensors named after the gate in the loaded model.
    """

    def __init__(self, model_path: str = MODEL_PATH, device: str = "cpu") -> None:
        self.device = device
        self.tensors = {k: v.float().to(device) for k, v in load_model(model_path).items()}

    def _get(self, name: str) -> torch.Tensor:
        """Return the named model tensor (KeyError if it is missing)."""
        return self.tensors[name]

    def _eval_gate(self, weight_key: str, bias_key: str, inputs: List[float]) -> float:
        """Evaluate a single threshold neuron; returns 0.0 or 1.0."""
        w = self._get(weight_key)
        b = self._get(bias_key)
        inp = torch.tensor(inputs, device=self.device)
        return heaviside((inp * w).sum() + b).item()

    def _eval_xor(self, prefix: str, inputs: List[float]) -> float:
        """Evaluate a two-layer XOR: layer1 OR and NAND neurons feed the layer2 neuron."""
        inp = torch.tensor(inputs, device=self.device)
        w_or = self._get(f"{prefix}.layer1.or.weight")
        b_or = self._get(f"{prefix}.layer1.or.bias")
        w_nand = self._get(f"{prefix}.layer1.nand.weight")
        b_nand = self._get(f"{prefix}.layer1.nand.bias")
        w2 = self._get(f"{prefix}.layer2.weight")
        b2 = self._get(f"{prefix}.layer2.bias")

        h_or = heaviside((inp * w_or).sum() + b_or).item()
        h_nand = heaviside((inp * w_nand).sum() + b_nand).item()
        hidden = torch.tensor([h_or, h_nand], device=self.device)
        return heaviside((hidden * w2).sum() + b2).item()

    def _eval_full_adder(self, prefix: str, a: float, b: float, cin: float) -> Tuple[float, float]:
        """Full adder from two half adders plus a carry OR; returns (sum, carry-out)."""
        ha1_sum = self._eval_xor(f"{prefix}.ha1.sum", [a, b])
        ha1_carry = self._eval_gate(f"{prefix}.ha1.carry.weight", f"{prefix}.ha1.carry.bias", [a, b])
        ha2_sum = self._eval_xor(f"{prefix}.ha2.sum", [ha1_sum, cin])
        ha2_carry = self._eval_gate(
            f"{prefix}.ha2.carry.weight", f"{prefix}.ha2.carry.bias", [ha1_sum, cin]
        )
        cout = self._eval_gate(
            f"{prefix}.carry_or.weight", f"{prefix}.carry_or.bias", [ha1_carry, ha2_carry]
        )
        return ha2_sum, cout

    def add(self, a: int, b: int) -> Tuple[int, int, int]:
        """8-bit ripple-carry add; returns (result, carry-out, signed overflow)."""
        # The ripple chain runs from the least-significant bit (fa0) upward.
        a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))
        carry = 0.0
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            sum_bit, carry = self._eval_full_adder(
                f"arithmetic.ripplecarry8bit.fa{bit}", float(a_bits[bit]), float(b_bits[bit]), carry
            )
            sum_bits.append(int(sum_bit))
        result = bits_to_int(list(reversed(sum_bits)))
        carry_out = int(carry)
        overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def sub(self, a: int, b: int) -> Tuple[int, int, int]:
        """8-bit subtract a - b as a + ~b + 1; returns (result, carry, signed overflow)."""
        a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))
        carry = 1.0  # the "+1" of the two's-complement negation
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            notb = self._eval_gate(
                f"arithmetic.sub8bit.notb{bit}.weight",
                f"arithmetic.sub8bit.notb{bit}.bias",
                [float(b_bits[bit])],
            )
            xor1 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor1", [float(a_bits[bit]), notb])
            xor2 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor2", [xor1, carry])
            and1 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and1.weight",
                f"arithmetic.sub8bit.fa{bit}.and1.bias",
                [float(a_bits[bit]), notb],
            )
            and2 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and2.weight",
                f"arithmetic.sub8bit.fa{bit}.and2.bias",
                [xor1, carry],
            )
            carry = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.or_carry.weight",
                f"arithmetic.sub8bit.fa{bit}.or_carry.bias",
                [and1, and2],
            )
            sum_bits.append(int(xor2))
        result = bits_to_int(list(reversed(sum_bits)))
        carry_out = int(carry)
        overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def _bitwise_two_input(self, gate: str, a: int, b: int) -> int:
        """Apply the per-bit 2-input gate ``alu.alu8bit.<gate>`` to each bit pair of a, b.

        Bit ``n`` (MSB first) uses weights ``[2n:2n+2]`` and bias ``[n]``.
        """
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        w = self._get(f"alu.alu8bit.{gate}.weight")
        bias = self._get(f"alu.alu8bit.{gate}.bias")
        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def bitwise_and(self, a: int, b: int) -> int:
        """8-bit AND through the ``alu.alu8bit.and`` gates."""
        return self._bitwise_two_input("and", a, b)

    def bitwise_or(self, a: int, b: int) -> int:
        """8-bit OR through the ``alu.alu8bit.or`` gates."""
        return self._bitwise_two_input("or", a, b)

    def bitwise_not(self, a: int) -> int:
        """8-bit NOT through the ``alu.alu8bit.not`` gates (one weight per bit)."""
        a_bits = int_to_bits(a, REG_BITS)
        w = self._get("alu.alu8bit.not.weight")
        bias = self._get("alu.alu8bit.not.bias")
        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit]).sum() + bias[bit]).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def bitwise_xor(self, a: int, b: int) -> int:
        """8-bit XOR using the per-bit two-layer OR/NAND -> layer2 gates."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        w_or = self._get("alu.alu8bit.xor.layer1.or.weight")
        b_or = self._get("alu.alu8bit.xor.layer1.or.bias")
        w_nand = self._get("alu.alu8bit.xor.layer1.nand.weight")
        b_nand = self._get("alu.alu8bit.xor.layer1.nand.bias")
        w2 = self._get("alu.alu8bit.xor.layer2.weight")
        b2 = self._get("alu.alu8bit.xor.layer2.bias")
        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            h_or = heaviside((inp * w_or[bit * 2:bit * 2 + 2]).sum() + b_or[bit])
            h_nand = heaviside((inp * w_nand[bit * 2:bit * 2 + 2]).sum() + b_nand[bit])
            hidden = torch.stack([h_or, h_nand])
            out = heaviside((hidden * w2[bit * 2:bit * 2 + 2]).sum() + b2[bit]).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def shift_left(self, a: int) -> int:
        """Logical shift left by one: output bit n takes input bit n+1; a 0 enters at the LSB."""
        a_bits = int_to_bits(a, REG_BITS)
        out_bits = []
        for bit in range(REG_BITS):
            w = self._get(f"alu.alu8bit.shl.bit{bit}.weight")
            bias = self._get(f"alu.alu8bit.shl.bit{bit}.bias")
            source = float(a_bits[bit + 1]) if bit < 7 else 0.0
            inp = torch.tensor([source], device=self.device)
            out = heaviside((inp * w).sum() + bias).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def shift_right(self, a: int) -> int:
        """Logical shift right by one: output bit n takes input bit n-1; a 0 enters at the MSB."""
        a_bits = int_to_bits(a, REG_BITS)
        out_bits = []
        for bit in range(REG_BITS):
            w = self._get(f"alu.alu8bit.shr.bit{bit}.weight")
            bias = self._get(f"alu.alu8bit.shr.bit{bit}.bias")
            source = float(a_bits[bit - 1]) if bit > 0 else 0.0
            inp = torch.tensor([source], device=self.device)
            out = heaviside((inp * w).sum() + bias).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def multiply(self, a: int, b: int) -> int:
        """8-bit multiply using partial product AND gates + shift-add.

        Returns the low byte of ``a * b``, matching the reference MUL opcode.
        """
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)

        # pp[i][j] = a_bits[i] AND b_bits[j] (both indices MSB first).
        pp = [[0] * 8 for _ in range(8)]
        for i in range(8):
            for j in range(8):
                w = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.weight")
                bias = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.bias")
                inp = torch.tensor([float(a_bits[i]), float(b_bits[j])], device=self.device)
                pp[i][j] = int(heaviside((inp * w).sum() + bias).item())

        result = 0
        for j in range(8):
            if b_bits[j] == 0:
                continue
            row = 0
            for i in range(8):
                row |= pp[i][j] << (7 - i)
            shifted = row << (7 - j)
            # Only the low byte of the product is kept: bits shifted past bit 7
            # are discarded, and the 8-bit adder drops its carry-out. (Folding
            # the high bits back into the low byte would corrupt the product,
            # e.g. 3 * 128 would give 129 instead of 128.)
            result, _, _ = self.add(result, shifted & 0xFF)
        return result & 0xFF

    def divide(self, a: int, b: int) -> Tuple[int, int]:
        """8-bit divide using restoring division with threshold gates.

        Returns (quotient, remainder); division by zero returns (0xFF, a).
        """
        if b == 0:
            return 0xFF, a

        a_bits = int_to_bits(a, REG_BITS)
        div_bits = int_to_bits(b, REG_BITS)  # loop-invariant
        quotient = 0
        remainder = 0
        for stage in range(8):
            # Bring down the next dividend bit (MSB first).
            remainder = ((remainder << 1) | a_bits[stage]) & 0xFF
            rem_bits = int_to_bits(remainder, REG_BITS)
            w = self._get(f"alu.alu8bit.div.stage{stage}.cmp.weight")
            bias = self._get(f"alu.alu8bit.div.stage{stage}.cmp.bias")
            inp = torch.tensor(
                [float(x) for x in rem_bits] + [float(x) for x in div_bits],
                device=self.device,
            )
            cmp_result = int(heaviside((inp * w).sum() + bias).item())
            if cmp_result:
                remainder, _, _ = self.sub(remainder, b)
                quotient = (quotient << 1) | 1
            else:
                quotient = quotient << 1
        return quotient & 0xFF, remainder & 0xFF


class ThresholdCPU:
    """CPU whose memory access, ALU and conditional jumps run on threshold gates.

    Instruction semantics mirror ref_step(); see its docstring for the opcode map.
    """

    # Condition (imm8 & 0x7) -> (conditional-jump circuit prefix, index into flags).
    # Flag indices match _BRANCH_CONDITIONS used by the reference model.
    _COND_CIRCUITS: Tuple[Tuple[str, int], ...] = (
        ("control.jz", 0),
        ("control.jnz", 0),
        ("control.jc", 2),
        ("control.jnc", 2),
        ("control.jn", 1),
        ("control.jp", 1),
        ("control.jv", 3),
        ("control.jnv", 3),
    )

    def __init__(self, model_path: str = MODEL_PATH, device: str = "cpu") -> None:
        self.device = device
        self.alu = ThresholdALU(model_path, device=device)

    def _mem_to_bits(self, mem: List[int]) -> torch.Tensor:
        """Expand bytes into a float ``[len(mem), REG_BITS]`` MSB-first 0/1 matrix.

        Vectorised equivalent of stacking ``int_to_bits(byte, REG_BITS)`` for
        every byte, avoiding a Python loop over all of memory on each access.
        """
        values = torch.tensor(mem, dtype=torch.long, device=self.device)
        shifts = torch.arange(REG_BITS - 1, -1, -1, device=self.device)
        return ((values.unsqueeze(1) >> shifts) & 1).float()

    def _bits_to_mem(self, bits: torch.Tensor) -> List[int]:
        """Collapse a ``[N, REG_BITS]`` MSB-first 0/1 matrix back into N byte values."""
        weights = torch.tensor(
            [1 << (REG_BITS - 1 - i) for i in range(REG_BITS)],
            dtype=torch.long,
            device=self.device,
        )
        return (bits.long() * weights).sum(dim=1).tolist()

    def _addr_decode(self, addr: int) -> torch.Tensor:
        """Run the ``memory.addr_decode`` layer on the PC_BITS-wide address.

        Produces one heaviside output per row of the decode weight matrix;
        presumably a one-hot select over memory cells (depends on the weights).
        """
        bits = torch.tensor(int_to_bits(addr, PC_BITS), device=self.device, dtype=torch.float32)
        w = self.alu._get("memory.addr_decode.weight")
        b = self.alu._get("memory.addr_decode.bias")
        return heaviside((w * bits).sum(dim=1) + b)

    def _memory_read(self, mem: List[int], addr: int) -> int:
        """Read one byte: per bit, AND each cell with its select line, then OR-reduce."""
        sel = self._addr_decode(addr)
        mem_bits = self._mem_to_bits(mem)
        and_w = self.alu._get("memory.read.and.weight")
        and_b = self.alu._get("memory.read.and.bias")
        or_w = self.alu._get("memory.read.or.weight")
        or_b = self.alu._get("memory.read.or.bias")

        out_bits: List[int] = []
        for bit in range(REG_BITS):
            inp = torch.stack([mem_bits[:, bit], sel], dim=1)
            and_out = heaviside((inp * and_w[bit]).sum(dim=1) + and_b[bit])
            out_bit = heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item()
            out_bits.append(int(out_bit))
        return bits_to_int(out_bits)

    def _memory_write(self, mem: List[int], addr: int, value: int) -> List[int]:
        """Write ``value`` at ``addr`` through the write-mux gates; returns new memory.

        Every cell computes (old AND NOT write_sel) OR (data AND write_sel), so
        the whole memory image is rebuilt; ``mem`` itself is not mutated.
        """
        sel = self._addr_decode(addr)
        data_bits = torch.tensor(int_to_bits(value, REG_BITS), device=self.device, dtype=torch.float32)
        mem_bits = self._mem_to_bits(mem)

        sel_w = self.alu._get("memory.write.sel.weight")
        sel_b = self.alu._get("memory.write.sel.bias")
        nsel_w = self.alu._get("memory.write.nsel.weight").squeeze(1)
        nsel_b = self.alu._get("memory.write.nsel.bias")
        and_old_w = self.alu._get("memory.write.and_old.weight")
        and_old_b = self.alu._get("memory.write.and_old.bias")
        and_new_w = self.alu._get("memory.write.and_new.weight")
        and_new_b = self.alu._get("memory.write.and_new.bias")
        or_w = self.alu._get("memory.write.or.weight")
        or_b = self.alu._get("memory.write.or.bias")

        # Write-enable is always asserted here; it gates the address select.
        we = torch.ones_like(sel)
        sel_inp = torch.stack([sel, we], dim=1)
        write_sel = heaviside((sel_inp * sel_w).sum(dim=1) + sel_b)
        nsel = heaviside((write_sel * nsel_w) + nsel_b)

        new_mem_bits = torch.zeros((MEM_BYTES, REG_BITS), device=self.device)
        for bit in range(REG_BITS):
            old_bit = mem_bits[:, bit]
            data_bit = data_bits[bit].expand(MEM_BYTES)
            inp_old = torch.stack([old_bit, nsel], dim=1)
            inp_new = torch.stack([data_bit, write_sel], dim=1)
            and_old = heaviside((inp_old * and_old_w[:, bit]).sum(dim=1) + and_old_b[:, bit])
            and_new = heaviside((inp_new * and_new_w[:, bit]).sum(dim=1) + and_new_b[:, bit])
            or_inp = torch.stack([and_old, and_new], dim=1)
            new_mem_bits[:, bit] = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit])

        return self._bits_to_mem(new_mem_bits)

    def _conditional_jump_byte(self, prefix: str, pc_byte: int, target_byte: int, flag: int) -> int:
        """Select one PC byte from ``pc_byte`` / ``target_byte`` using ``prefix``'s gates.

        Per bit: not_sel(flag), and_a(pc, not_sel), and_b(target, flag),
        or(and_a, and_b). Which input wins for a given flag value is determined
        by the circuit's weights.
        """
        pc_bits = int_to_bits(pc_byte, REG_BITS)
        target_bits = int_to_bits(target_byte, REG_BITS)
        out_bits: List[int] = []
        for bit in range(REG_BITS):
            not_sel = self.alu._eval_gate(
                f"{prefix}.bit{bit}.not_sel.weight",
                f"{prefix}.bit{bit}.not_sel.bias",
                [float(flag)],
            )
            and_a = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_a.weight",
                f"{prefix}.bit{bit}.and_a.bias",
                [float(pc_bits[bit]), not_sel],
            )
            and_b = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_b.weight",
                f"{prefix}.bit{bit}.and_b.bias",
                [float(target_bits[bit]), float(flag)],
            )
            out_bit = self.alu._eval_gate(
                f"{prefix}.bit{bit}.or.weight",
                f"{prefix}.bit{bit}.or.bias",
                [and_a, and_b],
            )
            out_bits.append(int(out_bit))
        return bits_to_int(out_bits)

    def step(self, state: CPUState) -> CPUState:
        """Single CPU cycle using threshold neurons.

        Mirrors ref_step() instruction for instruction, but memory access, ALU
        operations and conditional-jump PC selection go through threshold-gate
        circuits. Returns a new state; ``state`` is not modified.
        """
        if state.ctrl[0] == 1:
            return state.copy()

        s = state.copy()
        hi = self._memory_read(s.mem, s.pc)
        lo = self._memory_read(s.mem, (s.pc + 1) & 0xFFFF)
        s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
        next_pc = (s.pc + 2) & 0xFFFF

        opcode, rd, rs, imm8 = decode_ir(s.ir)
        a = s.regs[rd]
        b = s.regs[rs]

        # Memory/control opcodes carry an extra 16-bit address word.
        addr16 = None
        next_pc_ext = next_pc
        if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
            addr_hi = self._memory_read(s.mem, next_pc)
            addr_lo = self._memory_read(s.mem, (next_pc + 1) & 0xFFFF)
            addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
            next_pc_ext = (next_pc + 2) & 0xFFFF

        write_result = True
        result = a
        carry = 0
        overflow = 0

        if opcode == 0x0:
            result, carry, overflow = self.alu.add(a, b)
        elif opcode == 0x1:
            result, carry, overflow = self.alu.sub(a, b)
        elif opcode == 0x2:
            result = self.alu.bitwise_and(a, b)
        elif opcode == 0x3:
            result = self.alu.bitwise_or(a, b)
        elif opcode == 0x4:
            result = self.alu.bitwise_xor(a, b)
        elif opcode == 0x5:
            result = self.alu.shift_left(a)
        elif opcode == 0x6:
            result = self.alu.shift_right(a)
        elif opcode == 0x7:
            result = self.alu.multiply(a, b)
        elif opcode == 0x8:
            result, _ = self.alu.divide(a, b)
        elif opcode == 0x9:
            result, carry, overflow = self.alu.sub(a, b)
            write_result = False
        elif opcode == 0xA:
            result = self._memory_read(s.mem, addr16)
        elif opcode == 0xB:
            s.mem = self._memory_write(s.mem, addr16, b & 0xFF)
            write_result = False
        elif opcode == 0xC:
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xD:
            circuit_prefix, flag_idx = self._COND_CIRCUITS[imm8 & 0x7]
            flag = s.flags[flag_idx]
            # Each PC byte is chosen between fall-through and target by the circuit.
            hi_pc = self._conditional_jump_byte(
                circuit_prefix, (next_pc_ext >> 8) & 0xFF, (addr16 >> 8) & 0xFF, flag
            )
            lo_pc = self._conditional_jump_byte(
                circuit_prefix, next_pc_ext & 0xFF, addr16 & 0xFF, flag
            )
            s.pc = ((hi_pc & 0xFF) << 8) | (lo_pc & 0xFF)
            write_result = False
        elif opcode == 0xE:
            ret_addr = next_pc_ext & 0xFFFF
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, (ret_addr >> 8) & 0xFF)
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, ret_addr & 0xFF)
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xF:
            s.ctrl[0] = 1
            write_result = False

        # Flags are updated by the ALU opcodes (0x0-0x9) and by LOAD (0xA).
        if opcode <= 0xA:
            s.flags = list(flags_from_result(result, carry, overflow))
        if write_result:
            s.regs[rd] = result & 0xFF
        # Jumps/calls already set pc; everything else falls through.
        if opcode not in (0xC, 0xD, 0xE):
            s.pc = next_pc_ext
        return s

    def run_until_halt(self, state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
        """Execute until HALT or max_cycles reached; returns (final state, cycles run)."""
        s = state.copy()
        for i in range(max_cycles):
            if s.ctrl[0] == 1:
                return s, i
            s = self.step(s)
        return s, max_cycles

    def forward(self, state_bits: torch.Tensor, max_cycles: int = 256) -> torch.Tensor:
        """Tensor-in, tensor-out interface for neural integration.

        ``state_bits`` must flatten to a packed state (STATE_BITS values). The
        packed final state is returned as a float32 tensor on the CPU.
        """
        bits_list = [int(b) for b in state_bits.detach().cpu().flatten().tolist()]
        state = unpack_state(bits_list)
        final, _ = self.run_until_halt(state, max_cycles=max_cycles)
        return torch.tensor(pack_state(final), dtype=torch.float32)


def encode_instr(opcode: int, rd: int, rs: int, imm8: int) -> int:
    """Assemble a 16-bit instruction word (inverse of decode_ir)."""
    return ((opcode & 0xF) << 12) | ((rd & 0x3) << 10) | ((rs & 0x3) << 8) | (imm8 & 0xFF)


def write_instr(mem: List[int], addr: int, instr: int) -> None:
    """Store a 16-bit instruction big-endian at ``addr`` (wrapping at 0xFFFF)."""
    mem[addr & 0xFFFF] = (instr >> 8) & 0xFF
    mem[(addr + 1) & 0xFFFF] = instr & 0xFF


def write_addr(mem: List[int], addr: int, value: int) -> None:
    """Store a 16-bit address word big-endian at ``addr`` (wrapping at 0xFFFF)."""
    mem[addr & 0xFFFF] = (value >> 8) & 0xFF
    mem[(addr + 1) & 0xFFFF] = value & 0xFF


def run_smoke_test() -> int:
    """Smoke test: LOAD 5, LOAD 7, ADD, STORE, HALT. Expect result = 12.

    Runs the program on the reference model, on ThresholdCPU, and through
    ThresholdCPU.forward(), asserting they agree. Returns 0 on success.
    """
    mem = [0] * MEM_BYTES
    write_instr(mem, 0x0000, encode_instr(0xA, 0, 0, 0x00))  # R0 <- [0x0100]
    write_addr(mem, 0x0002, 0x0100)
    write_instr(mem, 0x0004, encode_instr(0xA, 1, 0, 0x00))  # R1 <- [0x0101]
    write_addr(mem, 0x0006, 0x0101)
    write_instr(mem, 0x0008, encode_instr(0x0, 0, 1, 0x00))  # R0 <- R0 + R1
    write_instr(mem, 0x000A, encode_instr(0xB, 0, 0, 0x00))  # [0x0102] <- R0
    write_addr(mem, 0x000C, 0x0102)
    write_instr(mem, 0x000E, encode_instr(0xF, 0, 0, 0x00))  # HALT

    mem[0x0100] = 5
    mem[0x0101] = 7

    state = CPUState(
        pc=0,
        ir=0,
        regs=[0, 0, 0, 0],
        flags=[0, 0, 0, 0],
        sp=0xFFFE,
        ctrl=[0, 0, 0, 0],
        mem=mem,
    )

    print("Running reference implementation...")
    final, cycles = ref_run_until_halt(state, max_cycles=20)

    assert final.ctrl[0] == 1, "HALT flag not set"
    assert final.regs[0] == 12, f"R0 expected 12, got {final.regs[0]}"
    assert final.mem[0x0102] == 12, f"MEM[0x0102] expected 12, got {final.mem[0x0102]}"
    assert cycles <= 10, f"Unexpected cycle count: {cycles}"
    print(f" Reference: R0={final.regs[0]}, MEM[0x0102]={final.mem[0x0102]}, cycles={cycles}")

    print("Running threshold-weight implementation...")
    threshold_cpu = ThresholdCPU()
    t_final, t_cycles = threshold_cpu.run_until_halt(state, max_cycles=20)

    assert t_final.ctrl[0] == 1, "Threshold HALT flag not set"
    assert t_final.regs[0] == final.regs[0], f"Threshold R0 mismatch: {t_final.regs[0]} != {final.regs[0]}"
    assert t_final.mem[0x0102] == final.mem[0x0102], (
        f"Threshold MEM[0x0102] mismatch: {t_final.mem[0x0102]} != {final.mem[0x0102]}"
    )
    assert t_cycles == cycles, f"Threshold cycle count mismatch: {t_cycles} != {cycles}"
    print(f" Threshold: R0={t_final.regs[0]}, MEM[0x0102]={t_final.mem[0x0102]}, cycles={t_cycles}")

    print("Validating forward() tensor I/O...")
    bits = torch.tensor(pack_state(state), dtype=torch.float32)
    out_bits = threshold_cpu.forward(bits, max_cycles=20)
    out_state = unpack_state([int(b) for b in out_bits.tolist()])

    assert out_state.regs[0] == final.regs[0], f"Forward R0 mismatch: {out_state.regs[0]} != {final.regs[0]}"
    assert out_state.mem[0x0102] == final.mem[0x0102], (
        f"Forward MEM[0x0102] mismatch: {out_state.mem[0x0102]} != {final.mem[0x0102]}"
    )
    print(f" Forward: R0={out_state.regs[0]}, MEM[0x0102]={out_state.mem[0x0102]}")

    print("\nSmoke test: PASSED")
    return 0


# NOTE(review): the next line (start of BatchedFitnessEvaluator) is kept byte-identical.
# ============================================================================= # CIRCUIT EVALUATION # ============================================================================= class BatchedFitnessEvaluator: """ GPU-batched fitness evaluator with per-circuit reporting. Tests all circuits comprehensively. """ def __init__(self, device: str = 'cuda', model_path: str = MODEL_PATH, tensors: Dict[str, torch.Tensor] = None): self.device = device self.model_path = model_path self.metadata = load_metadata(model_path) self.signal_registry = self.metadata.get('signal_registry', {}) self.results: List[CircuitResult] = [] self.category_scores: Dict[str, Tuple[float, int]] = {} self.total_tests = 0 # Get manifest for N-bit support if tensors is not None: self.manifest = get_manifest(tensors) else: base_tensors = load_model(model_path) self.manifest = get_manifest(base_tensors) self.data_bits = self.manifest['data_bits'] self.addr_bits = self.manifest['addr_bits'] self._setup_tests() def _setup_tests(self): """Pre-compute test vectors on device.""" d = self.device # 2-input truth table [4, 2] self.tt2 = torch.tensor( [[0, 0], [0, 1], [1, 0], [1, 1]], device=d, dtype=torch.float32 ) # 3-input truth table [8, 3] self.tt3 = torch.tensor([ [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1] ], device=d, dtype=torch.float32) # Boolean gate expected outputs self.expected = { 'and': torch.tensor([0, 0, 0, 1], device=d, dtype=torch.float32), 'or': torch.tensor([0, 1, 1, 1], device=d, dtype=torch.float32), 'nand': torch.tensor([1, 1, 1, 0], device=d, dtype=torch.float32), 'nor': torch.tensor([1, 0, 0, 0], device=d, dtype=torch.float32), 'xor': torch.tensor([0, 1, 1, 0], device=d, dtype=torch.float32), 'xnor': torch.tensor([1, 0, 0,
1], device=d, dtype=torch.float32), 'implies': torch.tensor([1, 1, 0, 1], device=d, dtype=torch.float32), 'biimplies': torch.tensor([1, 0, 0, 1], device=d, dtype=torch.float32), 'not': torch.tensor([1, 0], device=d, dtype=torch.float32), 'ha_sum': torch.tensor([0, 1, 1, 0], device=d, dtype=torch.float32), 'ha_carry': torch.tensor([0, 0, 0, 1], device=d, dtype=torch.float32), 'fa_sum': torch.tensor([0, 1, 1, 0, 1, 0, 0, 1], device=d, dtype=torch.float32), 'fa_cout': torch.tensor([0, 0, 0, 1, 0, 1, 1, 1], device=d, dtype=torch.float32), } # NOT gate inputs self.not_inputs = torch.tensor([[0], [1]], device=d, dtype=torch.float32) # 8-bit test values self.test_8bit = torch.tensor([ 0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 0b10101010, 0b01010101, 0b11110000, 0b00001111, 0b11001100, 0b00110011, 0b10000001, 0b01111110 ], device=d, dtype=torch.long) # Bit representations [num_vals, 8] self.test_8bit_bits = torch.stack([ ((self.test_8bit >> (7 - i)) & 1).float() for i in range(8) ], dim=1) # Comparator test pairs comp_tests = [ (0, 0), (1, 0), (0, 1), (5, 3), (3, 5), (5, 5), (255, 0), (0, 255), (128, 127), (127, 128), (100, 99), (99, 100), (64, 32), (32, 64), (1, 1), (254, 255), (255, 254), (128, 128), (0, 128), (128, 0), (64, 64), (192, 192), (15, 16), (16, 15), (240, 239), (239, 240), (85, 170), (170, 85), (0xAA, 0x55), (0x55, 0xAA), (0x0F, 0xF0), (0xF0, 0x0F), (0x33, 0xCC), (0xCC, 0x33), (2, 3), (3, 2), (126, 127), (127, 126), (129, 128), (128, 129), (200, 199), (199, 200), (50, 51), (51, 50), (10, 20), (20, 10), (100, 100), (200, 200), (77, 77), (0, 0) ] self.comp_a = torch.tensor([c[0] for c in comp_tests], device=d, dtype=torch.long) self.comp_b = torch.tensor([c[1] for c in comp_tests], device=d, dtype=torch.long) # Modular test range self.mod_test = torch.arange(256, device=d, dtype=torch.long) # 32-bit test values (strategic sampling) self.test_32bit = torch.tensor([ 0, 1, 2, 255, 256, 65535, 65536, 0x7FFFFFFF, 0x80000000, 0xFFFFFFFF, 0x12345678, 
0xDEADBEEF, 0xCAFEBABE, 1000000, 1000000000, 2147483647, 0x55555555, 0xAAAAAAAA, 0x0F0F0F0F, 0xF0F0F0F0 ], device=d, dtype=torch.long) # 32-bit comparator test pairs comp32_tests = [ (0, 0), (1, 0), (0, 1), (1000, 999), (999, 1000), (0xFFFFFFFF, 0), (0, 0xFFFFFFFF), (0x80000000, 0x7FFFFFFF), (0x7FFFFFFF, 0x80000000), (1000000, 1000000), (0x12345678, 0x12345678), (0xDEADBEEF, 0xCAFEBABE), (0xCAFEBABE, 0xDEADBEEF), (256, 255), (255, 256), (65536, 65535), (65535, 65536), ] self.comp32_a = torch.tensor([c[0] for c in comp32_tests], device=d, dtype=torch.long) self.comp32_b = torch.tensor([c[1] for c in comp32_tests], device=d, dtype=torch.long) def _record(self, name: str, passed: int, total: int, failures: List[Tuple] = None): """Record a circuit test result.""" self.results.append(CircuitResult( name=name, passed=passed, total=total, failures=failures or [] )) # ========================================================================= # BOOLEAN GATES # ========================================================================= def _test_single_gate(self, pop: Dict, prefix: str, inputs: torch.Tensor, expected: torch.Tensor) -> torch.Tensor: """Test single-layer gate (AND, OR, NAND, NOR, IMPLIES).""" pop_size = next(iter(pop.values())).shape[0] w = pop[f'{prefix}.weight'] b = pop[f'{prefix}.bias'] # [num_tests, pop_size] out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size)) correct = (out == expected.unsqueeze(1)).float().sum(0) failures = [] if pop_size == 1: for i, (inp, exp, got) in enumerate(zip(inputs, expected, out[:, 0])): if exp.item() != got.item(): failures.append((inp.tolist(), exp.item(), got.item())) self._record(prefix, int(correct[0].item()), len(expected), failures) return correct def _test_twolayer_gate(self, pop: Dict, prefix: str, inputs: torch.Tensor, expected: torch.Tensor) -> torch.Tensor: """Test two-layer gate (XOR, XNOR, BIIMPLIES).""" pop_size = next(iter(pop.values())).shape[0] # Layer 1 w1_n1 = 
pop[f'{prefix}.layer1.neuron1.weight'] b1_n1 = pop[f'{prefix}.layer1.neuron1.bias'] w1_n2 = pop[f'{prefix}.layer1.neuron2.weight'] b1_n2 = pop[f'{prefix}.layer1.neuron2.bias'] h1 = heaviside(inputs @ w1_n1.view(pop_size, -1).T + b1_n1.view(pop_size)) h2 = heaviside(inputs @ w1_n2.view(pop_size, -1).T + b1_n2.view(pop_size)) hidden = torch.stack([h1, h2], dim=-1) # Layer 2 w2 = pop[f'{prefix}.layer2.weight'] b2 = pop[f'{prefix}.layer2.bias'] out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size)) correct = (out == expected.unsqueeze(1)).float().sum(0) failures = [] if pop_size == 1: for i, (inp, exp, got) in enumerate(zip(inputs, expected, out[:, 0])): if exp.item() != got.item(): failures.append((inp.tolist(), exp.item(), got.item())) self._record(prefix, int(correct[0].item()), len(expected), failures) return correct def _test_xor_ornand(self, pop: Dict, prefix: str, inputs: torch.Tensor, expected: torch.Tensor) -> torch.Tensor: """Test XOR with or/nand layer naming.""" pop_size = next(iter(pop.values())).shape[0] w_or = pop[f'{prefix}.layer1.or.weight'] b_or = pop[f'{prefix}.layer1.or.bias'] w_nand = pop[f'{prefix}.layer1.nand.weight'] b_nand = pop[f'{prefix}.layer1.nand.bias'] h_or = heaviside(inputs @ w_or.view(pop_size, -1).T + b_or.view(pop_size)) h_nand = heaviside(inputs @ w_nand.view(pop_size, -1).T + b_nand.view(pop_size)) hidden = torch.stack([h_or, h_nand], dim=-1) w2 = pop[f'{prefix}.layer2.weight'] b2 = pop[f'{prefix}.layer2.bias'] out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size)) correct = (out == expected.unsqueeze(1)).float().sum(0) failures = [] if pop_size == 1: for i, (inp, exp, got) in enumerate(zip(inputs, expected, out[:, 0])): if exp.item() != got.item(): failures.append((inp.tolist(), exp.item(), got.item())) self._record(prefix, int(correct[0].item()), len(expected), failures) return correct def _test_boolean_gates(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test all 
boolean gates.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== BOOLEAN GATES ===") # Single-layer gates for gate in ['and', 'or', 'nand', 'nor', 'implies']: scores += self._test_single_gate(pop, f'boolean.{gate}', self.tt2, self.expected[gate]) total += 4 if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") # NOT gate w = pop['boolean.not.weight'] b = pop['boolean.not.bias'] out = heaviside(self.not_inputs @ w.view(pop_size, -1).T + b.view(pop_size)) correct = (out == self.expected['not'].unsqueeze(1)).float().sum(0) scores += correct total += 2 failures = [] if pop_size == 1: for inp, exp, got in zip(self.not_inputs, self.expected['not'], out[:, 0]): if exp.item() != got.item(): failures.append((inp.tolist(), exp.item(), got.item())) self._record('boolean.not', int(correct[0].item()), 2, failures) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") # Two-layer gates for gate in ['xnor', 'biimplies']: scores += self._test_twolayer_gate(pop, f'boolean.{gate}', self.tt2, self.expected.get(gate, self.expected['xnor'])) total += 4 if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") # XOR with neuron1/neuron2 naming (same as xnor/biimplies) scores += self._test_twolayer_gate(pop, 'boolean.xor', self.tt2, self.expected['xor']) total += 4 if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") return scores, total # ========================================================================= # ARITHMETIC - ADDERS # ========================================================================= def _eval_xor(self, pop: Dict, prefix: str, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """Evaluate XOR gate with or/nand decomposition. 
Args: a, b: Tensors of shape [num_tests] or [num_tests, pop_size] Returns: Tensor of shape [num_tests, pop_size] """ pop_size = next(iter(pop.values())).shape[0] # Ensure inputs are [num_tests, pop_size] if a.dim() == 1: a = a.unsqueeze(1).expand(-1, pop_size) if b.dim() == 1: b = b.unsqueeze(1).expand(-1, pop_size) # inputs: [num_tests, pop_size, 2] inputs = torch.stack([a, b], dim=-1) w_or = pop[f'{prefix}.layer1.or.weight'].view(pop_size, 2) b_or = pop[f'{prefix}.layer1.or.bias'].view(pop_size) w_nand = pop[f'{prefix}.layer1.nand.weight'].view(pop_size, 2) b_nand = pop[f'{prefix}.layer1.nand.bias'].view(pop_size) # [num_tests, pop_size] h_or = heaviside((inputs * w_or).sum(-1) + b_or) h_nand = heaviside((inputs * w_nand).sum(-1) + b_nand) # hidden: [num_tests, pop_size, 2] hidden = torch.stack([h_or, h_nand], dim=-1) w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, 2) b2 = pop[f'{prefix}.layer2.bias'].view(pop_size) return heaviside((hidden * w2).sum(-1) + b2) def _eval_single_fa(self, pop: Dict, prefix: str, a: torch.Tensor, b: torch.Tensor, cin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Evaluate single full adder. 
Args: a, b, cin: Tensors of shape [num_tests] or [num_tests, pop_size] Returns: sum_out, cout: Both of shape [num_tests, pop_size] """ pop_size = next(iter(pop.values())).shape[0] # Ensure inputs are [num_tests, pop_size] if a.dim() == 1: a = a.unsqueeze(1).expand(-1, pop_size) if b.dim() == 1: b = b.unsqueeze(1).expand(-1, pop_size) if cin.dim() == 1: cin = cin.unsqueeze(1).expand(-1, pop_size) # Half adder 1: a XOR b -> [num_tests, pop_size] ha1_sum = self._eval_xor(pop, f'{prefix}.ha1.sum', a, b) # Half adder 1 carry: a AND b ab = torch.stack([a, b], dim=-1) # [num_tests, pop_size, 2] w_c1 = pop[f'{prefix}.ha1.carry.weight'].view(pop_size, 2) b_c1 = pop[f'{prefix}.ha1.carry.bias'].view(pop_size) ha1_carry = heaviside((ab * w_c1).sum(-1) + b_c1) # Half adder 2: ha1_sum XOR cin ha2_sum = self._eval_xor(pop, f'{prefix}.ha2.sum', ha1_sum, cin) # Half adder 2 carry sc = torch.stack([ha1_sum, cin], dim=-1) w_c2 = pop[f'{prefix}.ha2.carry.weight'].view(pop_size, 2) b_c2 = pop[f'{prefix}.ha2.carry.bias'].view(pop_size) ha2_carry = heaviside((sc * w_c2).sum(-1) + b_c2) # Carry out: ha1_carry OR ha2_carry carries = torch.stack([ha1_carry, ha2_carry], dim=-1) w_cout = pop[f'{prefix}.carry_or.weight'].view(pop_size, 2) b_cout = pop[f'{prefix}.carry_or.bias'].view(pop_size) cout = heaviside((carries * w_cout).sum(-1) + b_cout) return ha2_sum, cout def _test_halfadder(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test half adder.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== HALF ADDER ===") # Sum (XOR) scores += self._test_xor_ornand(pop, 'arithmetic.halfadder.sum', self.tt2, self.expected['ha_sum']) total += 4 if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") # Carry (AND) scores += self._test_single_gate(pop, 'arithmetic.halfadder.carry', self.tt2, self.expected['ha_carry']) total += 4 if debug: r = self.results[-1] 
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") return scores, total def _test_fulladder(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test full adder with all 8 input combinations.""" pop_size = next(iter(pop.values())).shape[0] if debug: print("\n=== FULL ADDER ===") a = self.tt3[:, 0] b = self.tt3[:, 1] cin = self.tt3[:, 2] sum_out, cout = self._eval_single_fa(pop, 'arithmetic.fulladder', a, b, cin) sum_correct = (sum_out == self.expected['fa_sum'].unsqueeze(1)).float().sum(0) cout_correct = (cout == self.expected['fa_cout'].unsqueeze(1)).float().sum(0) failures_sum = [] failures_cout = [] if pop_size == 1: for i in range(8): if sum_out[i, 0].item() != self.expected['fa_sum'][i].item(): failures_sum.append(([a[i].item(), b[i].item(), cin[i].item()], self.expected['fa_sum'][i].item(), sum_out[i, 0].item())) if cout[i, 0].item() != self.expected['fa_cout'][i].item(): failures_cout.append(([a[i].item(), b[i].item(), cin[i].item()], self.expected['fa_cout'][i].item(), cout[i, 0].item())) self._record('arithmetic.fulladder.sum', int(sum_correct[0].item()), 8, failures_sum) self._record('arithmetic.fulladder.cout', int(cout_correct[0].item()), 8, failures_cout) if debug: for r in self.results[-2:]: print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") return sum_correct + cout_correct, 16 def _test_ripplecarry(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]: """Test N-bit ripple carry adder.""" pop_size = next(iter(pop.values())).shape[0] if debug: print(f"\n=== RIPPLE CARRY {bits}-BIT ===") prefix = f'arithmetic.ripplecarry{bits}bit' max_val = 1 << bits num_tests = min(max_val * max_val, 65536) if bits <= 4: # Exhaustive for small widths test_a = torch.arange(max_val, device=self.device) test_b = torch.arange(max_val, device=self.device) a_vals, b_vals = torch.meshgrid(test_a, test_b, indexing='ij') a_vals = a_vals.flatten() b_vals = b_vals.flatten() else: # Strategic sampling for 
8-bit edge_vals = [0, 1, 2, 127, 128, 254, 255] pairs = [(a, b) for a in edge_vals for b in edge_vals] for i in range(0, 256, 16): pairs.append((i, 255 - i)) pairs = list(set(pairs)) a_vals = torch.tensor([p[0] for p in pairs], device=self.device) b_vals = torch.tensor([p[1] for p in pairs], device=self.device) num_tests = len(pairs) # Convert to bits [num_tests, bits] a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1) b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1) # Evaluate ripple carry carry = torch.zeros(len(a_vals), pop_size, device=self.device) sum_bits = [] for bit in range(bits): bit_idx = bits - 1 - bit # LSB first s, carry = self._eval_single_fa( pop, f'{prefix}.fa{bit}', a_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size), b_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size), carry ) sum_bits.append(s) # Reconstruct result sum_bits = torch.stack(sum_bits[::-1], dim=-1) # MSB first result = torch.zeros(len(a_vals), pop_size, device=self.device) for i in range(bits): result += sum_bits[:, :, i] * (1 << (bits - 1 - i)) # Expected expected = ((a_vals + b_vals) & (max_val - 1)).unsqueeze(1).expand(-1, pop_size).float() correct = (result == expected).float().sum(0) failures = [] if pop_size == 1: for i in range(min(len(a_vals), 100)): if result[i, 0].item() != expected[i, 0].item(): failures.append(( [int(a_vals[i].item()), int(b_vals[i].item())], int(expected[i, 0].item()), int(result[i, 0].item()) )) self._record(prefix, int(correct[0].item()), num_tests, failures) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") return correct, num_tests # ========================================================================= # 3-OPERAND ADDER # ========================================================================= def _test_add3(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test 3-operand 8-bit adder (A + B 
+ C).""" pop_size = next(iter(pop.values())).shape[0] if debug: print(f"\n=== 3-OPERAND ADDER ===") prefix = 'arithmetic.add3_8bit' bits = 8 # Strategic test cases for 3-operand addition # Include edge cases and overflow scenarios test_cases = [] # Small values for a in [0, 1, 2]: for b in [0, 1, 2]: for c in [0, 1, 2]: test_cases.append((a, b, c)) # Edge values edge = [0, 1, 127, 128, 254, 255] for a in edge: for b in edge: for c in edge: test_cases.append((a, b, c)) # Specific multi-operand expression tests test_cases.extend([ (15, 27, 33), # Example from roadmap: 15 + 27 + 33 = 75 (100, 100, 55), # = 255 (exact fit) (100, 100, 56), # = 256 -> 0 (overflow) (85, 85, 85), # = 255 (exact fit) (86, 85, 85), # = 256 -> 0 (overflow) ]) test_cases = list(set(test_cases)) a_vals = torch.tensor([t[0] for t in test_cases], device=self.device) b_vals = torch.tensor([t[1] for t in test_cases], device=self.device) c_vals = torch.tensor([t[2] for t in test_cases], device=self.device) num_tests = len(test_cases) # Convert to bits [num_tests, bits] MSB-first a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1) b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1) c_bits = torch.stack([((c_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1) # Stage 1: A + B carry1 = torch.zeros(num_tests, pop_size, device=self.device) stage1_bits = [] for bit in range(bits): bit_idx = bits - 1 - bit # LSB first s, carry1 = self._eval_single_fa( pop, f'{prefix}.stage1.fa{bit}', a_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size), b_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size), carry1 ) stage1_bits.append(s) # Stage 2: stage1_result + C carry2 = torch.zeros(num_tests, pop_size, device=self.device) result_bits = [] for bit in range(bits): bit_idx = bits - 1 - bit # LSB first s, carry2 = self._eval_single_fa( pop, f'{prefix}.stage2.fa{bit}', stage1_bits[bit], # Already [num_tests, pop_size] c_bits[:, 
bit_idx].unsqueeze(1).expand(-1, pop_size), carry2 ) result_bits.append(s) # Reconstruct result (bits are in LSB-first order, need to reverse for MSB-first) result_bits = torch.stack(result_bits[::-1], dim=-1) # MSB first result = torch.zeros(num_tests, pop_size, device=self.device) for i in range(bits): result += result_bits[:, :, i] * (1 << (bits - 1 - i)) # Expected (8-bit wrap) expected = ((a_vals + b_vals + c_vals) & 0xFF).unsqueeze(1).expand(-1, pop_size).float() correct = (result == expected).float().sum(0) failures = [] if pop_size == 1: for i in range(min(num_tests, 100)): if result[i, 0].item() != expected[i, 0].item(): failures.append(( [int(a_vals[i].item()), int(b_vals[i].item()), int(c_vals[i].item())], int(expected[i, 0].item()), int(result[i, 0].item()) )) self._record(prefix, int(correct[0].item()), num_tests, failures) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") if failures: for inp, exp, got in failures[:5]: print(f" FAIL: {inp[0]} + {inp[1]} + {inp[2]} = {exp}, got {got}") return correct, num_tests # ========================================================================= # ORDER OF OPERATIONS (A + B × C) # ========================================================================= def _test_expr_add_mul(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test A + B × C expression circuit (order of operations).""" pop_size = next(iter(pop.values())).shape[0] if debug: print(f"\n=== ORDER OF OPERATIONS (A + B × C) ===") prefix = 'arithmetic.expr_add_mul' bits = 8 # Test cases for order of operations test_cases = [] # Specific examples from roadmap test_cases.extend([ (5, 3, 2), # 5 + 3 × 2 = 5 + 6 = 11 (10, 4, 3), # 10 + 4 × 3 = 10 + 12 = 22 (1, 10, 10), # 1 + 10 × 10 = 1 + 100 = 101 (0, 15, 17), # 0 + 15 × 17 = 255 (1, 15, 17), # 1 + 15 × 17 = 256 -> 0 (overflow) (100, 5, 5), # 100 + 5 × 5 = 100 + 25 = 125 ]) # Edge cases test_cases.extend([ (0, 0, 0), # 0 + 0 × 0 = 0 (255, 
0, 0), # 255 + 0 × 0 = 255 (0, 255, 1), # 0 + 255 × 1 = 255 (0, 1, 255), # 0 + 1 × 255 = 255 (1, 1, 1), # 1 + 1 × 1 = 2 (0, 16, 16), # 0 + 16 × 16 = 256 -> 0 (overflow) ]) # Systematic small values for a in [0, 1, 5, 10]: for b in [0, 1, 2, 3]: for c in [0, 1, 2, 3]: test_cases.append((a, b, c)) # Remove duplicates test_cases = list(set(test_cases)) a_vals = torch.tensor([t[0] for t in test_cases], device=self.device) b_vals = torch.tensor([t[1] for t in test_cases], device=self.device) c_vals = torch.tensor([t[2] for t in test_cases], device=self.device) num_tests = len(test_cases) # Convert to bits [num_tests, bits] MSB-first a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1) b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1) c_bits = torch.stack([((c_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1) # Evaluate mask stage: mask[stage][bit] = B[bit] AND C[stage] # In the circuit: mask.s[stage].b[bit] operates on positional bits # stage 0 = LSB of C (c_bits[:, 7]), stage 7 = MSB of C (c_bits[:, 0]) # bit 0 = LSB of B (b_bits[:, 7]), bit 7 = MSB of B (b_bits[:, 0]) masks = torch.zeros(8, num_tests, pop_size, 8, device=self.device) # [stage, tests, pop, bits] for stage in range(8): c_stage_bit = c_bits[:, 7 - stage].unsqueeze(1).expand(-1, pop_size) # C[stage] for bit in range(8): b_bit_val = b_bits[:, 7 - bit].unsqueeze(1).expand(-1, pop_size) # B[bit] # AND gate w = pop.get(f'{prefix}.mul.mask.s{stage}.b{bit}.weight') bias = pop.get(f'{prefix}.mul.mask.s{stage}.b{bit}.bias') if w is not None and bias is not None: w = w.squeeze(-1) # [pop] b_tensor = bias.squeeze(-1) # [pop] # Properly broadcast for batch evaluation inp = torch.stack([b_bit_val, c_stage_bit], dim=-1) # [tests, pop, 2] out = heaviside(torch.einsum('tpi,pi->tp', inp, w) + b_tensor) masks[stage, :, :, bit] = out # Accumulator stages # acc[0] = mask[0] (no shift) # acc[1] = acc[0] + (mask[1] << 1) # ... 
# acc[7] = acc[6] + (mask[7] << 7) acc = masks[0].clone() # [tests, pop, 8] - start with mask[0] for stage in range(1, 8): # Create shifted mask: (mask[stage] << stage) # Shift left by 'stage' positions: bits 0..stage-1 become 0, bit k becomes mask[stage][k-stage] shifted_mask = torch.zeros(num_tests, pop_size, 8, device=self.device) for bit in range(8): if bit >= stage: shifted_mask[:, :, bit] = masks[stage, :, :, bit - stage] # else: remains 0 # Add acc + shifted_mask using full adders carry = torch.zeros(num_tests, pop_size, device=self.device) new_acc = torch.zeros(num_tests, pop_size, 8, device=self.device) for bit in range(8): s, carry = self._eval_single_fa( pop, f'{prefix}.mul.acc.s{stage}.fa{bit}', acc[:, :, bit], shifted_mask[:, :, bit], carry ) new_acc[:, :, bit] = s acc = new_acc # Final add stage: A + acc (multiplication result) carry = torch.zeros(num_tests, pop_size, device=self.device) result_bits = [] for bit in range(8): a_bit_val = a_bits[:, 7 - bit].unsqueeze(1).expand(-1, pop_size) s, carry = self._eval_single_fa( pop, f'{prefix}.add.fa{bit}', a_bit_val, acc[:, :, bit], carry ) result_bits.append(s) # Reconstruct result result_bits = torch.stack(result_bits[::-1], dim=-1) # MSB first result = torch.zeros(num_tests, pop_size, device=self.device) for i in range(bits): result += result_bits[:, :, i] * (1 << (bits - 1 - i)) # Expected: A + (B × C), with 8-bit wrap expected = ((a_vals + b_vals * c_vals) & 0xFF).unsqueeze(1).expand(-1, pop_size).float() correct = (result == expected).float().sum(0) failures = [] if pop_size == 1: for i in range(min(num_tests, 100)): if result[i, 0].item() != expected[i, 0].item(): failures.append(( [int(a_vals[i].item()), int(b_vals[i].item()), int(c_vals[i].item())], int(expected[i, 0].item()), int(result[i, 0].item()) )) self._record(prefix, int(correct[0].item()), num_tests, failures) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") if failures: for inp, 
exp, got in failures[:5]: print(f" FAIL: {inp[0]} + {inp[1]} × {inp[2]} = {exp}, got {got}") return correct, num_tests # ========================================================================= # COMPARATORS # ========================================================================= def _test_comparator(self, pop: Dict, name: str, op: Callable[[int, int], bool], debug: bool) -> Tuple[torch.Tensor, int]: """Test 8-bit comparator.""" pop_size = next(iter(pop.values())).shape[0] prefix = f'arithmetic.{name}' # Use pre-computed test pairs expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0 for a, b in zip(self.comp_a, self.comp_b)], device=self.device) # Convert to bits a_bits = torch.stack([((self.comp_a >> (7 - i)) & 1).float() for i in range(8)], dim=1) b_bits = torch.stack([((self.comp_b >> (7 - i)) & 1).float() for i in range(8)], dim=1) inputs = torch.cat([a_bits, b_bits], dim=1) w = pop[f'{prefix}.weight'] b = pop[f'{prefix}.bias'] out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size)) correct = (out == expected.unsqueeze(1)).float().sum(0) failures = [] if pop_size == 1: for i in range(len(self.comp_a)): if out[i, 0].item() != expected[i].item(): failures.append(( [int(self.comp_a[i].item()), int(self.comp_b[i].item())], expected[i].item(), out[i, 0].item() )) self._record(prefix, int(correct[0].item()), len(self.comp_a), failures) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") return correct, len(self.comp_a) def _test_comparators(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test all comparators.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== COMPARATORS ===") comparators = [ ('greaterthan8bit', lambda a, b: a > b), ('lessthan8bit', lambda a, b: a < b), ('greaterorequal8bit', lambda a, b: a >= b), ('lessorequal8bit', lambda a, b: a <= b), ('equality8bit', lambda a, b: a == 
b), ] for name, op in comparators: if name == 'equality8bit': continue # Handle separately as two-layer try: s, t = self._test_comparator(pop, name, op, debug) scores += s total += t except KeyError: pass # Circuit not present # Two-layer equality circuit try: prefix = 'arithmetic.equality8bit' expected = torch.tensor([1.0 if a.item() == b.item() else 0.0 for a, b in zip(self.comp_a, self.comp_b)], device=self.device) a_bits = torch.stack([((self.comp_a >> (7 - i)) & 1).float() for i in range(8)], dim=1) b_bits = torch.stack([((self.comp_b >> (7 - i)) & 1).float() for i in range(8)], dim=1) inputs = torch.cat([a_bits, b_bits], dim=1) # Layer 1: geq and leq w_geq = pop[f'{prefix}.layer1.geq.weight'] b_geq = pop[f'{prefix}.layer1.geq.bias'] w_leq = pop[f'{prefix}.layer1.leq.weight'] b_leq = pop[f'{prefix}.layer1.leq.bias'] h_geq = heaviside(inputs @ w_geq.view(pop_size, -1).T + b_geq.view(pop_size)) h_leq = heaviside(inputs @ w_leq.view(pop_size, -1).T + b_leq.view(pop_size)) hidden = torch.stack([h_geq, h_leq], dim=-1) # [num_tests, pop_size, 2] # Layer 2: AND w2 = pop[f'{prefix}.layer2.weight'] b2 = pop[f'{prefix}.layer2.bias'] out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size)) correct = (out == expected.unsqueeze(1)).float().sum(0) failures = [] if pop_size == 1: for i in range(len(self.comp_a)): if out[i, 0].item() != expected[i].item(): failures.append(( [int(self.comp_a[i].item()), int(self.comp_b[i].item())], expected[i].item(), out[i, 0].item() )) self._record(prefix, int(correct[0].item()), len(self.comp_a), failures) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") scores += correct total += len(self.comp_a) except KeyError: pass return scores, total def _test_comparators_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]: """Test N-bit comparator circuits (GT, LT, GE, LE, EQ).""" pop_size = next(iter(pop.values())).shape[0] scores = 
torch.zeros(pop_size, device=self.device) total = 0 if debug: print(f"\n=== {bits}-BIT COMPARATORS ===") if bits == 32: comp_a = self.comp32_a comp_b = self.comp32_b elif bits == 16: comp_a = self.comp_a.clamp(0, 65535) comp_b = self.comp_b.clamp(0, 65535) else: comp_a = self.comp_a comp_b = self.comp_b num_tests = len(comp_a) if bits <= 16: a_bits = torch.stack([((comp_a >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1) b_bits = torch.stack([((comp_b >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1) inputs = torch.cat([a_bits, b_bits], dim=1) comparators = [ (f'arithmetic.greaterthan{bits}bit', lambda a, b: a > b), (f'arithmetic.greaterorequal{bits}bit', lambda a, b: a >= b), (f'arithmetic.lessthan{bits}bit', lambda a, b: a < b), (f'arithmetic.lessorequal{bits}bit', lambda a, b: a <= b), ] for name, op in comparators: try: expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0 for a, b in zip(comp_a, comp_b)], device=self.device) w = pop[f'{name}.weight'] b = pop[f'{name}.bias'] out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size)) correct = (out == expected.unsqueeze(1)).float().sum(0) failures = [] if pop_size == 1: for i in range(num_tests): if out[i, 0].item() != expected[i].item(): failures.append(([int(comp_a[i].item()), int(comp_b[i].item())], expected[i].item(), out[i, 0].item())) self._record(name, int(correct[0].item()), num_tests, failures) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") scores += correct total += num_tests except KeyError: pass prefix = f'arithmetic.equality{bits}bit' try: expected = torch.tensor([1.0 if a.item() == b.item() else 0.0 for a, b in zip(comp_a, comp_b)], device=self.device) w_geq = pop[f'{prefix}.layer1.geq.weight'] b_geq = pop[f'{prefix}.layer1.geq.bias'] w_leq = pop[f'{prefix}.layer1.leq.weight'] b_leq = pop[f'{prefix}.layer1.leq.bias'] h_geq = heaviside(inputs @ w_geq.view(pop_size, -1).T + b_geq.view(pop_size)) 
h_leq = heaviside(inputs @ w_leq.view(pop_size, -1).T + b_leq.view(pop_size)) hidden = torch.stack([h_geq, h_leq], dim=-1) w2 = pop[f'{prefix}.layer2.weight'] b2 = pop[f'{prefix}.layer2.bias'] out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size)) correct = (out == expected.unsqueeze(1)).float().sum(0) failures = [] if pop_size == 1: for i in range(num_tests): if out[i, 0].item() != expected[i].item(): failures.append(([int(comp_a[i].item()), int(comp_b[i].item())], expected[i].item(), out[i, 0].item())) self._record(prefix, int(correct[0].item()), num_tests, failures) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") scores += correct total += num_tests except KeyError: pass else: num_bytes = bits // 8 prefix = f"arithmetic.cmp{bits}bit" byte_gt = [] byte_lt = [] byte_eq = [] for b in range(num_bytes): start_bit = b * 8 a_byte = torch.stack([((comp_a >> (bits - 1 - start_bit - i)) & 1).float() for i in range(8)], dim=1) b_byte = torch.stack([((comp_b >> (bits - 1 - start_bit - i)) & 1).float() for i in range(8)], dim=1) byte_input = torch.cat([a_byte, b_byte], dim=1) w_gt = pop[f'{prefix}.byte{b}.gt.weight'].view(pop_size, -1) b_gt = pop[f'{prefix}.byte{b}.gt.bias'].view(pop_size) byte_gt.append(heaviside(byte_input @ w_gt.T + b_gt)) w_lt = pop[f'{prefix}.byte{b}.lt.weight'].view(pop_size, -1) b_lt = pop[f'{prefix}.byte{b}.lt.bias'].view(pop_size) byte_lt.append(heaviside(byte_input @ w_lt.T + b_lt)) w_geq = pop[f'{prefix}.byte{b}.eq.geq.weight'].view(pop_size, -1) b_geq = pop[f'{prefix}.byte{b}.eq.geq.bias'].view(pop_size) w_leq = pop[f'{prefix}.byte{b}.eq.leq.weight'].view(pop_size, -1) b_leq = pop[f'{prefix}.byte{b}.eq.leq.bias'].view(pop_size) h_geq = heaviside(byte_input @ w_geq.T + b_geq) h_leq = heaviside(byte_input @ w_leq.T + b_leq) w_and = pop[f'{prefix}.byte{b}.eq.and.weight'].view(pop_size, -1) b_and = pop[f'{prefix}.byte{b}.eq.and.bias'].view(pop_size) eq_inp = 
torch.stack([h_geq, h_leq], dim=-1) byte_eq.append(heaviside((eq_inp * w_and).sum(-1) + b_and)) cascade_gt = [] cascade_lt = [] for b in range(num_bytes): if b == 0: cascade_gt.append(byte_gt[0]) cascade_lt.append(byte_lt[0]) else: eq_stack = torch.stack(byte_eq[:b], dim=-1) w_all_eq = pop[f'{prefix}.cascade.gt.stage{b}.all_eq.weight'].view(pop_size, -1) b_all_eq = pop[f'{prefix}.cascade.gt.stage{b}.all_eq.bias'].view(pop_size) all_eq_gt = heaviside((eq_stack * w_all_eq).sum(-1) + b_all_eq) w_and = pop[f'{prefix}.cascade.gt.stage{b}.and.weight'].view(pop_size, -1) b_and = pop[f'{prefix}.cascade.gt.stage{b}.and.bias'].view(pop_size) stage_inp = torch.stack([all_eq_gt, byte_gt[b]], dim=-1) cascade_gt.append(heaviside((stage_inp * w_and).sum(-1) + b_and)) w_all_eq_lt = pop[f'{prefix}.cascade.lt.stage{b}.all_eq.weight'].view(pop_size, -1) b_all_eq_lt = pop[f'{prefix}.cascade.lt.stage{b}.all_eq.bias'].view(pop_size) all_eq_lt = heaviside((eq_stack * w_all_eq_lt).sum(-1) + b_all_eq_lt) w_and_lt = pop[f'{prefix}.cascade.lt.stage{b}.and.weight'].view(pop_size, -1) b_and_lt = pop[f'{prefix}.cascade.lt.stage{b}.and.bias'].view(pop_size) stage_inp_lt = torch.stack([all_eq_lt, byte_lt[b]], dim=-1) cascade_lt.append(heaviside((stage_inp_lt * w_and_lt).sum(-1) + b_and_lt)) gt_stack = torch.stack(cascade_gt, dim=-1) w_gt_or = pop[f'arithmetic.greaterthan{bits}bit.weight'].view(pop_size, -1) b_gt_or = pop[f'arithmetic.greaterthan{bits}bit.bias'].view(pop_size) gt_out = heaviside((gt_stack * w_gt_or).sum(-1) + b_gt_or) lt_stack = torch.stack(cascade_lt, dim=-1) w_lt_or = pop[f'arithmetic.lessthan{bits}bit.weight'].view(pop_size, -1) b_lt_or = pop[f'arithmetic.lessthan{bits}bit.bias'].view(pop_size) lt_out = heaviside((lt_stack * w_lt_or).sum(-1) + b_lt_or) w_not_lt = pop[f'arithmetic.greaterorequal{bits}bit.not_lt.weight'].view(pop_size, -1) b_not_lt = pop[f'arithmetic.greaterorequal{bits}bit.not_lt.bias'].view(pop_size) not_lt = heaviside(lt_out.unsqueeze(-1) @ w_not_lt.T + 
b_not_lt).squeeze(-1) w_ge = pop[f'arithmetic.greaterorequal{bits}bit.weight'].view(pop_size, -1) b_ge = pop[f'arithmetic.greaterorequal{bits}bit.bias'].view(pop_size) ge_out = heaviside(not_lt.unsqueeze(-1) @ w_ge.T + b_ge).squeeze(-1) w_not_gt = pop[f'arithmetic.lessorequal{bits}bit.not_gt.weight'].view(pop_size, -1) b_not_gt = pop[f'arithmetic.lessorequal{bits}bit.not_gt.bias'].view(pop_size) not_gt = heaviside(gt_out.unsqueeze(-1) @ w_not_gt.T + b_not_gt).squeeze(-1) w_le = pop[f'arithmetic.lessorequal{bits}bit.weight'].view(pop_size, -1) b_le = pop[f'arithmetic.lessorequal{bits}bit.bias'].view(pop_size) le_out = heaviside(not_gt.unsqueeze(-1) @ w_le.T + b_le).squeeze(-1) eq_stack = torch.stack(byte_eq, dim=-1) w_eq_all = pop[f'arithmetic.equality{bits}bit.weight'].view(pop_size, -1) b_eq_all = pop[f'arithmetic.equality{bits}bit.bias'].view(pop_size) eq_out = heaviside((eq_stack * w_eq_all).sum(-1) + b_eq_all) for name, out, op in [ (f'arithmetic.greaterthan{bits}bit', gt_out, lambda a, b: a > b), (f'arithmetic.greaterorequal{bits}bit', ge_out, lambda a, b: a >= b), (f'arithmetic.lessthan{bits}bit', lt_out, lambda a, b: a < b), (f'arithmetic.lessorequal{bits}bit', le_out, lambda a, b: a <= b), (f'arithmetic.equality{bits}bit', eq_out, lambda a, b: a == b), ]: expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0 for a, b in zip(comp_a, comp_b)], device=self.device) correct = (out == expected.unsqueeze(1)).float().sum(0) failures = [] if pop_size == 1: for i in range(num_tests): if out[i, 0].item() != expected[i].item(): failures.append(([int(comp_a[i].item()), int(comp_b[i].item())], expected[i].item(), out[i, 0].item())) self._record(name, int(correct[0].item()), num_tests, failures) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") scores += correct total += num_tests return scores, total def _test_subtractor_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]: """Test 
N-bit subtractor circuit (A - B).""" pop_size = next(iter(pop.values())).shape[0] if debug: print(f"\n=== {bits}-BIT SUBTRACTOR ===") prefix = f'arithmetic.sub{bits}bit' max_val = 1 << bits if bits == 32: test_pairs = [ (1000, 500), (5000, 3000), (1000000, 500000), (0xFFFFFFFF, 1), (0x80000000, 1), (100, 100), (0, 0), (1, 0), (0, 1), (256, 255), (0xDEADBEEF, 0xCAFEBABE), (1000000000, 999999999), ] else: test_pairs = [(a, b) for a in [0, 1, 127, 128, 255] for b in [0, 1, 127, 128, 255]] a_vals = torch.tensor([p[0] for p in test_pairs], device=self.device, dtype=torch.long) b_vals = torch.tensor([p[1] for p in test_pairs], device=self.device, dtype=torch.long) num_tests = len(test_pairs) a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1) b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1) not_b_bits = torch.zeros_like(b_bits) for bit in range(bits): w = pop[f'{prefix}.not_b.bit{bit}.weight'].view(pop_size, -1) b = pop[f'{prefix}.not_b.bit{bit}.bias'].view(pop_size) not_b_bits[:, bit] = heaviside(b_bits[:, bit:bit+1] @ w.T + b)[:, 0] carry = torch.ones(num_tests, pop_size, device=self.device) sum_bits = [] for bit in range(bits): bit_idx = bits - 1 - bit s, carry = self._eval_single_fa( pop, f'{prefix}.fa{bit}', a_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size), not_b_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size), carry ) sum_bits.append(s) sum_bits = torch.stack(sum_bits[::-1], dim=-1) result = torch.zeros(num_tests, pop_size, device=self.device) for i in range(bits): result += sum_bits[:, :, i] * (1 << (bits - 1 - i)) expected = ((a_vals - b_vals) & (max_val - 1)).unsqueeze(1).expand(-1, pop_size).float() correct = (result == expected).float().sum(0) failures = [] if pop_size == 1: for i in range(min(num_tests, 20)): if result[i, 0].item() != expected[i, 0].item(): failures.append(( [int(a_vals[i].item()), int(b_vals[i].item())], int(expected[i, 0].item()), int(result[i, 
0].item()) )) self._record(prefix, int(correct[0].item()), num_tests, failures) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") return correct, num_tests def _test_bitwise_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]: """Test N-bit bitwise operations (AND, OR, XOR, NOT).""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print(f"\n=== {bits}-BIT BITWISE OPS ===") if bits == 32: test_pairs = [ (0xAAAAAAAA, 0x55555555), (0xFFFFFFFF, 0x00000000), (0x12345678, 0x87654321), (0xDEADBEEF, 0xCAFEBABE), (0x0F0F0F0F, 0xF0F0F0F0), (0, 0), (0xFFFFFFFF, 0xFFFFFFFF), ] else: test_pairs = [(0xAA, 0x55), (0xFF, 0x00), (0x0F, 0xF0)] a_vals = torch.tensor([p[0] for p in test_pairs], device=self.device, dtype=torch.long) b_vals = torch.tensor([p[1] for p in test_pairs], device=self.device, dtype=torch.long) num_tests = len(test_pairs) ops = [ ('and', lambda a, b: a & b), ('or', lambda a, b: a | b), ('xor', lambda a, b: a ^ b), ] for op_name, op_fn in ops: try: result_bits = [] for bit in range(bits): a_bit = ((a_vals >> (bits - 1 - bit)) & 1).float() b_bit = ((b_vals >> (bits - 1 - bit)) & 1).float() if op_name == 'xor': prefix = f'alu.alu{bits}bit.{op_name}.bit{bit}' w_or = pop[f'{prefix}.layer1.or.weight'].view(pop_size, -1) b_or = pop[f'{prefix}.layer1.or.bias'].view(pop_size) w_nand = pop[f'{prefix}.layer1.nand.weight'].view(pop_size, -1) b_nand = pop[f'{prefix}.layer1.nand.bias'].view(pop_size) inp = torch.stack([a_bit, b_bit], dim=-1) h_or = heaviside(inp @ w_or.T + b_or) h_nand = heaviside(inp @ w_nand.T + b_nand) hidden = torch.stack([h_or, h_nand], dim=-1) w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, -1) b2 = pop[f'{prefix}.layer2.bias'].view(pop_size) out = heaviside((hidden * w2).sum(-1) + b2) else: w = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.weight'].view(pop_size, -1) b = 
pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.bias'].view(pop_size) inp = torch.stack([a_bit, b_bit], dim=-1) out = heaviside(inp @ w.T + b) result_bits.append(out[:, 0] if out.dim() > 1 else out) result = sum(int(result_bits[i][j].item()) << (bits - 1 - i) for i in range(bits) for j in range(1)) results = torch.tensor([sum(int(result_bits[i][j].item()) << (bits - 1 - i) for i in range(bits)) for j in range(num_tests)], device=self.device) expected = torch.tensor([op_fn(a.item(), b.item()) for a, b in zip(a_vals, b_vals)], device=self.device) correct = (results == expected).float().sum() self._record(f'alu.alu{bits}bit.{op_name}', int(correct.item()), num_tests, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") scores += correct total += num_tests except KeyError as e: if debug: print(f" alu.alu{bits}bit.{op_name}: SKIP (missing {e})") try: test_vals = a_vals result_bits = [] for bit in range(bits): a_bit = ((test_vals >> (bits - 1 - bit)) & 1).float() w = pop[f'alu.alu{bits}bit.not.bit{bit}.weight'].view(pop_size, -1) b = pop[f'alu.alu{bits}bit.not.bit{bit}.bias'].view(pop_size) out = heaviside(a_bit.unsqueeze(-1) @ w.T + b) result_bits.append(out[:, 0]) results = torch.tensor([sum(int(result_bits[i][j].item()) << (bits - 1 - i) for i in range(bits)) for j in range(num_tests)], device=self.device) expected = torch.tensor([(~a.item()) & ((1 << bits) - 1) for a in test_vals], device=self.device) correct = (results == expected).float().sum() self._record(f'alu.alu{bits}bit.not', int(correct.item()), num_tests, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") scores += correct total += num_tests except KeyError as e: if debug: print(f" alu.alu{bits}bit.not: SKIP (missing {e})") return scores, total def _test_shifts_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]: """Test N-bit shift operations (SHL, SHR).""" pop_size = 
next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print(f"\n=== {bits}-BIT SHIFTS ===") if bits == 32: test_vals = [0x12345678, 0x80000001, 0x00000001, 0xFFFFFFFF, 0x55555555] else: test_vals = [0x81, 0x55, 0x01, 0xFF, 0xAA] a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long) num_tests = len(test_vals) max_val = (1 << bits) - 1 for op_name, op_fn in [('shl', lambda x: (x << 1) & max_val), ('shr', lambda x: x >> 1)]: try: result_bits = [] for bit in range(bits): a_bit = ((a_vals >> (bits - 1 - bit)) & 1).float() w = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.weight'].view(pop_size) b = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.bias'].view(pop_size) if op_name == 'shl': if bit < bits - 1: src_bit = ((a_vals >> (bits - 2 - bit)) & 1).float() else: src_bit = torch.zeros_like(a_bit) else: if bit > 0: src_bit = ((a_vals >> (bits - bit)) & 1).float() else: src_bit = torch.zeros_like(a_bit) out = heaviside(src_bit * w + b) result_bits.append(out) results = torch.tensor([sum(int(result_bits[i][j].item()) << (bits - 1 - i) for i in range(bits)) for j in range(num_tests)], device=self.device) expected = torch.tensor([op_fn(a.item()) for a in a_vals], device=self.device) correct = (results == expected).float().sum() self._record(f'alu.alu{bits}bit.{op_name}', int(correct.item()), num_tests, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") scores += correct total += num_tests except KeyError as e: if debug: print(f" alu.alu{bits}bit.{op_name}: SKIP (missing {e})") return scores, total def _test_inc_dec_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]: """Test N-bit INC and DEC operations.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print(f"\n=== {bits}-BIT INC/DEC ===") if bits == 32: test_vals = [0, 1, 0xFFFFFFFF, 0x7FFFFFFF, 0x80000000, 
1000000, 0xFFFFFFFE] else: test_vals = [0, 1, 254, 255, 127, 128] a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long) num_tests = len(test_vals) max_val = (1 << bits) - 1 for op_name, op_fn in [('inc', lambda x: (x + 1) & max_val), ('dec', lambda x: (x - 1) & max_val)]: try: carry = torch.ones(num_tests, device=self.device) result_bits = [] for bit in range(bits): a_bit = ((a_vals >> bit) & 1).float() prefix = f'alu.alu{bits}bit.{op_name}.bit{bit}' w_or = pop[f'{prefix}.xor.layer1.or.weight'].flatten() b_or = pop[f'{prefix}.xor.layer1.or.bias'].item() w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].flatten() b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].item() h_or = heaviside(a_bit * w_or[0] + carry * w_or[1] + b_or) h_nand = heaviside(a_bit * w_nand[0] + carry * w_nand[1] + b_nand) w2 = pop[f'{prefix}.xor.layer2.weight'].flatten() b2 = pop[f'{prefix}.xor.layer2.bias'].item() xor_out = heaviside(h_or * w2[0] + h_nand * w2[1] + b2) result_bits.append(xor_out) if op_name == 'inc': w_carry = pop[f'{prefix}.carry.weight'].flatten() b_carry = pop[f'{prefix}.carry.bias'].item() carry = heaviside(a_bit * w_carry[0] + carry * w_carry[1] + b_carry) else: w_not = pop[f'{prefix}.not_a.weight'].flatten() b_not = pop[f'{prefix}.not_a.bias'].item() not_a = heaviside(a_bit * w_not[0] + b_not) w_borrow = pop[f'{prefix}.borrow.weight'].flatten() b_borrow = pop[f'{prefix}.borrow.bias'].item() carry = heaviside(not_a * w_borrow[0] + carry * w_borrow[1] + b_borrow) results = torch.tensor([sum(int(result_bits[bit][j].item()) << bit for bit in range(bits)) for j in range(num_tests)], device=self.device) expected = torch.tensor([op_fn(a.item()) for a in a_vals], device=self.device) correct = (results == expected).float().sum() self._record(f'alu.alu{bits}bit.{op_name}', int(correct.item()), num_tests, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") scores += correct total += num_tests except KeyError as 
e: if debug: print(f" alu.alu{bits}bit.{op_name}: SKIP (missing {e})") return scores, total def _test_neg_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]: """Test N-bit NEG operation (two's complement negation).""" pop_size = next(iter(pop.values())).shape[0] if debug: print(f"\n=== {bits}-BIT NEG ===") if bits == 32: test_vals = [0, 1, 0xFFFFFFFF, 0x7FFFFFFF, 0x80000000, 1000, 1000000] else: test_vals = [0, 1, 127, 128, 255, 100] a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long) num_tests = len(test_vals) max_val = (1 << bits) - 1 try: not_bits = [] for bit in range(bits): a_bit = ((a_vals >> bit) & 1).float() w = pop[f'alu.alu{bits}bit.neg.not.bit{bit}.weight'].flatten() b = pop[f'alu.alu{bits}bit.neg.not.bit{bit}.bias'].item() not_bits.append(heaviside(a_bit * w[0] + b)) carry = torch.ones(num_tests, device=self.device) result_bits = [] for bit in range(bits): prefix = f'alu.alu{bits}bit.neg.inc.bit{bit}' not_bit = not_bits[bit] w_or = pop[f'{prefix}.xor.layer1.or.weight'].flatten() b_or = pop[f'{prefix}.xor.layer1.or.bias'].item() w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].flatten() b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].item() h_or = heaviside(not_bit * w_or[0] + carry * w_or[1] + b_or) h_nand = heaviside(not_bit * w_nand[0] + carry * w_nand[1] + b_nand) w2 = pop[f'{prefix}.xor.layer2.weight'].flatten() b2 = pop[f'{prefix}.xor.layer2.bias'].item() xor_out = heaviside(h_or * w2[0] + h_nand * w2[1] + b2) result_bits.append(xor_out) w_carry = pop[f'{prefix}.carry.weight'].flatten() b_carry = pop[f'{prefix}.carry.bias'].item() carry = heaviside(not_bit * w_carry[0] + carry * w_carry[1] + b_carry) results = torch.tensor([sum(int(result_bits[bit][j].item()) << bit for bit in range(bits)) for j in range(num_tests)], device=self.device) expected = torch.tensor([(-a.item()) & max_val for a in a_vals], device=self.device) correct = (results == expected).float().sum() self._record(f'alu.alu{bits}bit.neg', 
int(correct.item()), num_tests, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") return torch.tensor([correct], device=self.device), num_tests except KeyError as e: if debug: print(f" alu.alu{bits}bit.neg: SKIP (missing {e})") return torch.zeros(pop_size, device=self.device), 0 # ========================================================================= # THRESHOLD GATES # ========================================================================= def _test_threshold_kofn(self, pop: Dict, k: int, name: str, debug: bool) -> Tuple[torch.Tensor, int]: """Test k-of-n threshold gate.""" pop_size = next(iter(pop.values())).shape[0] prefix = f'threshold.{name}' # Test all 256 8-bit patterns inputs = self.test_8bit_bits if len(self.test_8bit_bits) == 24 else None if inputs is None: test_vals = torch.arange(256, device=self.device, dtype=torch.long) inputs = torch.stack([((test_vals >> (7 - i)) & 1).float() for i in range(8)], dim=1) # For k-of-8: output 1 if popcount >= k (for "at least k") # For exact naming like "oneoutof8", it's exactly k=1 popcounts = inputs.sum(dim=1) if 'atleast' in name: expected = (popcounts >= k).float() elif 'atmost' in name or 'minority' in name: # minority = popcount <= 3 (less than half of 8) expected = (popcounts <= k).float() elif 'exactly' in name: expected = (popcounts == k).float() else: # Standard k-of-n (at least k), including majority (>= 5) expected = (popcounts >= k).float() w = pop[f'{prefix}.weight'] b = pop[f'{prefix}.bias'] out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size)) correct = (out == expected.unsqueeze(1)).float().sum(0) failures = [] if pop_size == 1: for i in range(min(len(inputs), 256)): if out[i, 0].item() != expected[i].item(): val = int(sum(inputs[i, j].item() * (1 << (7 - j)) for j in range(8))) failures.append((val, expected[i].item(), out[i, 0].item())) self._record(prefix, int(correct[0].item()), len(inputs), failures[:10]) if debug: r = 
self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") return correct, len(inputs) def _test_threshold_gates(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test all threshold gates.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== THRESHOLD GATES ===") # k-of-8 gates kofn_gates = [ (1, 'oneoutof8'), (2, 'twooutof8'), (3, 'threeoutof8'), (4, 'fouroutof8'), (5, 'fiveoutof8'), (6, 'sixoutof8'), (7, 'sevenoutof8'), (8, 'alloutof8'), ] for k, name in kofn_gates: try: s, t = self._test_threshold_kofn(pop, k, name, debug) scores += s total += t except KeyError: pass # Special gates special = [ (5, 'majority'), (3, 'minority'), (4, 'atleastk_4'), (4, 'atmostk_4'), (4, 'exactlyk_4'), ] for k, name in special: try: s, t = self._test_threshold_kofn(pop, k, name, debug) scores += s total += t except KeyError: pass return scores, total # ========================================================================= # MODULAR ARITHMETIC # ========================================================================= def _test_modular(self, pop: Dict, mod: int, debug: bool) -> Tuple[torch.Tensor, int]: """Test modular divisibility circuit (multi-layer for non-powers-of-2).""" pop_size = next(iter(pop.values())).shape[0] prefix = f'modular.mod{mod}' # Test 0-255 inputs = torch.stack([((self.mod_test >> (7 - i)) & 1).float() for i in range(8)], dim=1) expected = ((self.mod_test % mod) == 0).float() # Try single layer first (powers of 2) try: w = pop[f'{prefix}.weight'] b = pop[f'{prefix}.bias'] out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size)) except KeyError: # Multi-layer structure: layer1 (geq/leq) -> layer2 (eq) -> layer3 (or) try: # Layer 1: geq and leq neurons geq_outputs = {} leq_outputs = {} i = 0 while True: found = False if f'{prefix}.layer1.geq{i}.weight' in pop: w = pop[f'{prefix}.layer1.geq{i}.weight'].view(pop_size, -1) b = 
pop[f'{prefix}.layer1.geq{i}.bias'].view(pop_size) geq_outputs[i] = heaviside(inputs @ w.T + b) # [256, pop_size] found = True if f'{prefix}.layer1.leq{i}.weight' in pop: w = pop[f'{prefix}.layer1.leq{i}.weight'].view(pop_size, -1) b = pop[f'{prefix}.layer1.leq{i}.bias'].view(pop_size) leq_outputs[i] = heaviside(inputs @ w.T + b) found = True if not found: break i += 1 if not geq_outputs and not leq_outputs: return torch.zeros(pop_size, device=self.device), 0 # Layer 2: eq neurons (AND of geq and leq for same index) eq_outputs = [] i = 0 while f'{prefix}.layer2.eq{i}.weight' in pop: w = pop[f'{prefix}.layer2.eq{i}.weight'].view(pop_size, -1) b = pop[f'{prefix}.layer2.eq{i}.bias'].view(pop_size) # Input is [geq_i, leq_i] eq_in = torch.stack([geq_outputs.get(i, torch.zeros(256, pop_size, device=self.device)), leq_outputs.get(i, torch.zeros(256, pop_size, device=self.device))], dim=-1) eq_out = heaviside((eq_in * w).sum(-1) + b) eq_outputs.append(eq_out) i += 1 if not eq_outputs: return torch.zeros(pop_size, device=self.device), 0 # Layer 3: OR of all eq outputs eq_stack = torch.stack(eq_outputs, dim=-1) # [256, pop_size, num_eq] w3 = pop[f'{prefix}.layer3.or.weight'].view(pop_size, -1) b3 = pop[f'{prefix}.layer3.or.bias'].view(pop_size) out = heaviside((eq_stack * w3).sum(-1) + b3) # [256, pop_size] except Exception as e: return torch.zeros(pop_size, device=self.device), 0 correct = (out == expected.unsqueeze(1)).float().sum(0) failures = [] if pop_size == 1: for i in range(256): if out[i, 0].item() != expected[i].item(): failures.append((i, expected[i].item(), out[i, 0].item())) self._record(prefix, int(correct[0].item()), 256, failures[:10]) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") return correct, 256 def _test_modular_all(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test all modular arithmetic circuits.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, 
device=self.device) total = 0 if debug: print("\n=== MODULAR ARITHMETIC ===") for mod in range(2, 13): s, t = self._test_modular(pop, mod, debug) scores += s total += t return scores, total # ========================================================================= # PATTERN RECOGNITION # ========================================================================= def _test_pattern(self, pop: Dict, name: str, expected_fn: Callable[[int], float], debug: bool) -> Tuple[torch.Tensor, int]: """Test pattern recognition circuit.""" pop_size = next(iter(pop.values())).shape[0] prefix = f'pattern_recognition.{name}' test_vals = torch.arange(256, device=self.device, dtype=torch.long) inputs = torch.stack([((test_vals >> (7 - i)) & 1).float() for i in range(8)], dim=1) expected = torch.tensor([expected_fn(v.item()) for v in test_vals], device=self.device) try: w = pop[f'{prefix}.weight'].view(pop_size, -1) b = pop[f'{prefix}.bias'].view(pop_size) out = heaviside(inputs @ w.T + b) except KeyError: return torch.zeros(pop_size, device=self.device), 0 correct = (out == expected.unsqueeze(1)).float().sum(0) failures = [] if pop_size == 1: for i in range(256): if out[i, 0].item() != expected[i].item(): failures.append((i, expected[i].item(), out[i, 0].item())) self._record(prefix, int(correct[0].item()), 256, failures[:10]) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") return correct, 256 def _test_patterns(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test pattern recognition circuits.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== PATTERN RECOGNITION ===") # Use correct naming: pattern_recognition.allzeros, pattern_recognition.allones patterns = [ ('allzeros', lambda v: 1.0 if v == 0 else 0.0), ('allones', lambda v: 1.0 if v == 255 else 0.0), ] for name, fn in patterns: s, t = self._test_pattern(pop, name, fn, debug) 
scores += s total += t return scores, total # ========================================================================= # ERROR DETECTION # ========================================================================= def _eval_xor_tree_stage(self, pop: Dict, prefix: str, stage: int, idx: int, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """Evaluate a single XOR in the parity tree.""" pop_size = next(iter(pop.values())).shape[0] xor_prefix = f'{prefix}.stage{stage}.xor{idx}' # Ensure 2D: [256, pop_size] if a.dim() == 1: a = a.unsqueeze(1).expand(-1, pop_size) if b.dim() == 1: b = b.unsqueeze(1).expand(-1, pop_size) # Layer 1: OR and NAND w_or = pop[f'{xor_prefix}.layer1.or.weight'].view(pop_size, 2) b_or = pop[f'{xor_prefix}.layer1.or.bias'].view(pop_size) w_nand = pop[f'{xor_prefix}.layer1.nand.weight'].view(pop_size, 2) b_nand = pop[f'{xor_prefix}.layer1.nand.bias'].view(pop_size) inputs = torch.stack([a, b], dim=-1) # [256, pop_size, 2] h_or = heaviside((inputs * w_or).sum(-1) + b_or) h_nand = heaviside((inputs * w_nand).sum(-1) + b_nand) # Layer 2 hidden = torch.stack([h_or, h_nand], dim=-1) w2 = pop[f'{xor_prefix}.layer2.weight'].view(pop_size, 2) b2 = pop[f'{xor_prefix}.layer2.bias'].view(pop_size) return heaviside((hidden * w2).sum(-1) + b2) def _test_parity_xor_tree(self, pop: Dict, prefix: str, debug: bool) -> Tuple[torch.Tensor, int]: """Test parity circuit with XOR tree structure.""" pop_size = next(iter(pop.values())).shape[0] test_vals = torch.arange(256, device=self.device, dtype=torch.long) inputs = torch.stack([((test_vals >> (7 - i)) & 1).float() for i in range(8)], dim=1) # XOR of all bits: 1 if odd number of 1s popcounts = inputs.sum(dim=1) xor_result = (popcounts.long() % 2).float() try: # Stage 1: 4 XORs (pairs of bits) s1_out = [] for i in range(4): xor_out = self._eval_xor_tree_stage(pop, prefix, 1, i, inputs[:, i*2], inputs[:, i*2+1]) s1_out.append(xor_out) # Stage 2: 2 XORs s2_out = [] for i in range(2): xor_out = 
self._eval_xor_tree_stage(pop, prefix, 2, i, s1_out[i*2], s1_out[i*2+1]) s2_out.append(xor_out) # Stage 3: 1 XOR s3_out = self._eval_xor_tree_stage(pop, prefix, 3, 0, s2_out[0], s2_out[1]) # Output NOT (for parity checker - inverts the XOR result) if f'{prefix}.output.not.weight' in pop: w_not = pop[f'{prefix}.output.not.weight'].view(pop_size) b_not = pop[f'{prefix}.output.not.bias'].view(pop_size) out = heaviside(s3_out * w_not + b_not) # Checker outputs 1 if even parity (XOR=0), so expected is inverted xor_result expected = 1.0 - xor_result else: out = s3_out expected = xor_result except KeyError as e: return torch.zeros(pop_size, device=self.device), 0 correct = (out == expected.unsqueeze(1)).float().sum(0) failures = [] if pop_size == 1: for i in range(256): if out[i, 0].item() != expected[i].item(): failures.append((i, expected[i].item(), out[i, 0].item())) self._record(prefix, int(correct[0].item()), 256, failures[:10]) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") return correct, 256 def _test_error_detection(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test error detection circuits.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== ERROR DETECTION ===") # XOR tree parity circuits for prefix in ['error_detection.paritychecker8bit', 'error_detection.paritygenerator8bit']: s, t = self._test_parity_xor_tree(pop, prefix, debug) scores += s total += t return scores, total # ========================================================================= # COMBINATIONAL LOGIC # ========================================================================= def _test_mux2to1(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test 2-to-1 multiplexer.""" pop_size = next(iter(pop.values())).shape[0] prefix = 'combinational.multiplexer2to1' # Inputs: [a, b, sel] -> out = sel ? 
b : a inputs = torch.tensor([ [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1], ], device=self.device, dtype=torch.float32) expected = torch.tensor([0, 0, 0, 1, 1, 0, 1, 1], device=self.device, dtype=torch.float32) try: w = pop[f'{prefix}.weight'] b = pop[f'{prefix}.bias'] out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size)) except KeyError: return torch.zeros(pop_size, device=self.device), 0 correct = (out == expected.unsqueeze(1)).float().sum(0) failures = [] if pop_size == 1: for i in range(8): if out[i, 0].item() != expected[i].item(): failures.append((inputs[i].tolist(), expected[i].item(), out[i, 0].item())) self._record(prefix, int(correct[0].item()), 8, failures) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") return correct, 8 def _test_decoder3to8(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test 3-to-8 decoder.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== DECODER 3-TO-8 ===") inputs = torch.tensor([ [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1], ], device=self.device, dtype=torch.float32) for out_idx in range(8): prefix = f'combinational.decoder3to8.out{out_idx}' expected = torch.zeros(8, device=self.device) expected[out_idx] = 1.0 try: w = pop[f'{prefix}.weight'] b = pop[f'{prefix}.bias'] out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size)) except KeyError: continue correct = (out == expected.unsqueeze(1)).float().sum(0) scores += correct total += 8 failures = [] if pop_size == 1: for i in range(8): if out[i, 0].item() != expected[i].item(): failures.append((inputs[i].tolist(), expected[i].item(), out[i, 0].item())) self._record(prefix, int(correct[0].item()), 8, failures) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") return 
scores, total def _test_combinational(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test combinational logic circuits.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== COMBINATIONAL LOGIC ===") s, t = self._test_mux2to1(pop, debug) scores += s total += t s, t = self._test_decoder3to8(pop, debug) scores += s total += t s, t = self._test_barrel_shifter(pop, debug) scores += s total += t s, t = self._test_priority_encoder(pop, debug) scores += s total += t return scores, total def _test_barrel_shifter(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test barrel shifter (shift by 0-7 positions).""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== BARREL SHIFTER ===") try: # Test all shift amounts 0-7 with various input patterns test_vals = [0b10000001, 0b11110000, 0b00001111, 0b10101010, 0xFF] for val in test_vals: for shift in range(8): expected_val = (val << shift) & 0xFF # Left shift val_bits = [float((val >> (7 - i)) & 1) for i in range(8)] shift_bits = [float((shift >> (2 - i)) & 1) for i in range(3)] # Process through 3 layers layer_in = val_bits[:] for layer in range(3): shift_amount = 1 << (2 - layer) # 4, 2, 1 sel = shift_bits[layer] layer_out = [] for bit in range(8): prefix = f'combinational.barrelshifter.layer{layer}.bit{bit}' # NOT sel w_not = pop[f'{prefix}.not_sel.weight'].view(pop_size) b_not = pop[f'{prefix}.not_sel.bias'].view(pop_size) not_sel = heaviside(sel * w_not + b_not) # Source for shifted value shifted_src = bit + shift_amount if shifted_src < 8: shifted_val = layer_in[shifted_src] else: shifted_val = 0.0 # AND a: original AND NOT sel w_and_a = pop[f'{prefix}.and_a.weight'].view(pop_size, 2) b_and_a = pop[f'{prefix}.and_a.bias'].view(pop_size) inp_a = torch.tensor([layer_in[bit], not_sel[0].item()], device=self.device) and_a = heaviside((inp_a * 
w_and_a).sum(-1) + b_and_a) # AND b: shifted AND sel w_and_b = pop[f'{prefix}.and_b.weight'].view(pop_size, 2) b_and_b = pop[f'{prefix}.and_b.bias'].view(pop_size) inp_b = torch.tensor([shifted_val, sel], device=self.device) and_b = heaviside((inp_b * w_and_b).sum(-1) + b_and_b) # OR w_or = pop[f'{prefix}.or.weight'].view(pop_size, 2) b_or = pop[f'{prefix}.or.bias'].view(pop_size) inp_or = torch.tensor([and_a[0].item(), and_b[0].item()], device=self.device) out = heaviside((inp_or * w_or).sum(-1) + b_or) layer_out.append(out[0].item()) layer_in = layer_out # Check result result = sum(int(layer_in[i]) << (7 - i) for i in range(8)) if result == expected_val: scores += 1 total += 1 self._record('combinational.barrelshifter', int(scores[0].item()), total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" combinational.barrelshifter: SKIP ({e})") return scores, total def _test_priority_encoder(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test priority encoder (find highest set bit).""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== PRIORITY ENCODER ===") try: # Test cases: input -> (valid, index of highest bit) test_cases = [ (0b00000000, 0, 0), # No bits set, valid=0 (0b00000001, 1, 7), # Bit 7 (LSB) (0b00000010, 1, 6), (0b00000100, 1, 5), (0b00001000, 1, 4), (0b00010000, 1, 3), (0b00100000, 1, 2), (0b01000000, 1, 1), (0b10000000, 1, 0), # Bit 0 (MSB) (0b10000001, 1, 0), # Multiple bits, highest wins (0b01010101, 1, 1), (0b00001111, 1, 4), (0b11111111, 1, 0), ] for val, expected_valid, expected_idx in test_cases: val_bits = torch.tensor([float((val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) # Valid output: OR of all input bits w_valid = pop['combinational.priorityencoder.valid.weight'].view(pop_size, 8) b_valid = 
pop['combinational.priorityencoder.valid.bias'].view(pop_size) out_valid = heaviside((val_bits * w_valid).sum(-1) + b_valid) if int(out_valid[0].item()) == expected_valid: scores += 1 total += 1 # Index outputs (3 bits) if expected_valid == 1: for idx_bit in range(3): try: w_idx = pop[f'combinational.priorityencoder.idx{idx_bit}.weight'].view(pop_size, 8) b_idx = pop[f'combinational.priorityencoder.idx{idx_bit}.bias'].view(pop_size) out_idx = heaviside((val_bits * w_idx).sum(-1) + b_idx) expected_bit = (expected_idx >> (2 - idx_bit)) & 1 if int(out_idx[0].item()) == expected_bit: scores += 1 total += 1 except KeyError: pass self._record('combinational.priorityencoder', int(scores[0].item()), total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" combinational.priorityencoder: SKIP ({e})") return scores, total def _test_barrel_shifter_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]: """Test N-bit barrel shifter (shift by 0 to bits-1 positions).""" import math pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 num_layers = max(1, math.ceil(math.log2(bits))) max_val = (1 << bits) - 1 if debug: print(f"\n=== {bits}-BIT BARREL SHIFTER ===") prefix = f'combinational.barrelshifter{bits}' try: if bits == 16: test_vals = [0x8001, 0xFF00, 0x00FF, 0xAAAA, 0xFFFF, 0x1234] elif bits == 32: test_vals = [0x80000001, 0xFFFF0000, 0x0000FFFF, 0xAAAAAAAA, 0xFFFFFFFF, 0x12345678] else: test_vals = [0b10000001, 0b11110000, 0b00001111, 0b10101010, max_val] num_shifts = min(bits, 8) for val in test_vals: for shift in range(num_shifts): expected_val = (val << shift) & max_val val_bits = [float((val >> (bits - 1 - i)) & 1) for i in range(bits)] shift_bits = [float((shift >> (num_layers - 1 - i)) & 1) for i in range(num_layers)] layer_in = val_bits[:] for layer in range(num_layers): 
shift_amount = 1 << (num_layers - 1 - layer) sel = shift_bits[layer] layer_out = [] for bit in range(bits): bit_prefix = f'{prefix}.layer{layer}.bit{bit}' w_not = pop[f'{bit_prefix}.not_sel.weight'].view(pop_size) b_not = pop[f'{bit_prefix}.not_sel.bias'].view(pop_size) not_sel = heaviside(sel * w_not + b_not) shifted_src = bit + shift_amount if shifted_src < bits: shifted_val = layer_in[shifted_src] else: shifted_val = 0.0 w_and_a = pop[f'{bit_prefix}.and_a.weight'].view(pop_size, 2) b_and_a = pop[f'{bit_prefix}.and_a.bias'].view(pop_size) inp_a = torch.tensor([layer_in[bit], not_sel[0].item()], device=self.device) and_a = heaviside((inp_a * w_and_a).sum(-1) + b_and_a) w_and_b = pop[f'{bit_prefix}.and_b.weight'].view(pop_size, 2) b_and_b = pop[f'{bit_prefix}.and_b.bias'].view(pop_size) inp_b = torch.tensor([shifted_val, sel], device=self.device) and_b = heaviside((inp_b * w_and_b).sum(-1) + b_and_b) w_or = pop[f'{bit_prefix}.or.weight'].view(pop_size, 2) b_or = pop[f'{bit_prefix}.or.bias'].view(pop_size) inp_or = torch.tensor([and_a[0].item(), and_b[0].item()], device=self.device) out = heaviside((inp_or * w_or).sum(-1) + b_or) layer_out.append(out[0].item()) layer_in = layer_out result = sum(int(layer_in[i]) << (bits - 1 - i) for i in range(bits)) if result == expected_val: scores += 1 total += 1 self._record(prefix, int(scores[0].item()), total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" {prefix}: SKIP ({e})") return scores, total def _test_priority_encoder_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]: """Test N-bit priority encoder (find highest set bit). The priority encoder is a multi-layer circuit: 1. any_higher{pos}: OR of bits 0 to pos-1 (all higher-priority positions) 2. is_highest{0}: bit[0] directly (MSB is always highest if set) 3. is_highest{pos}: bit[pos] AND NOT(any_higher{pos}) for pos > 0 4. 
out{bit}: OR of is_highest{pos} for all pos where (pos >> bit) & 1 5. valid: OR of all input bits """ import math pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 out_bits = max(1, math.ceil(math.log2(bits))) if debug: print(f"\n=== {bits}-BIT PRIORITY ENCODER ===") prefix = f'combinational.priorityencoder{bits}' try: test_cases = [(0, 0, 0)] for i in range(bits): test_cases.append((1 << i, 1, bits - 1 - i)) if bits == 16: test_cases.extend([ (0x8001, 1, 0), (0x5555, 1, 1), (0x00FF, 1, 8), (0xFFFF, 1, 0) ]) elif bits == 32: test_cases.extend([ (0x80000001, 1, 0), (0x55555555, 1, 1), (0x0000FFFF, 1, 16), (0xFFFFFFFF, 1, 0) ]) for val, expected_valid, expected_idx in test_cases: val_bits = torch.tensor([float((val >> (bits - 1 - i)) & 1) for i in range(bits)], device=self.device, dtype=torch.float32) w_valid = pop[f'{prefix}.valid.weight'].view(pop_size, bits) b_valid = pop[f'{prefix}.valid.bias'].view(pop_size) out_valid = heaviside((val_bits * w_valid).sum(-1) + b_valid) if int(out_valid[0].item()) == expected_valid: scores += 1 total += 1 if expected_valid == 1: any_higher = [None] for pos in range(1, bits): w = pop[f'{prefix}.any_higher{pos}.weight'].view(pop_size, -1) b = pop[f'{prefix}.any_higher{pos}.bias'].view(pop_size) inp = val_bits[:pos] out = heaviside((inp * w[:, :len(inp)]).sum(-1) + b) any_higher.append(out) is_highest = [] for pos in range(bits): if pos == 0: is_high = val_bits[0].unsqueeze(0).expand(pop_size) else: w_not = pop[f'{prefix}.is_highest{pos}.not_higher.weight'].view(pop_size, -1) b_not = pop[f'{prefix}.is_highest{pos}.not_higher.bias'].view(pop_size) not_higher = heaviside(any_higher[pos].unsqueeze(-1) * w_not + b_not).squeeze(-1) w_and = pop[f'{prefix}.is_highest{pos}.and.weight'].view(pop_size, -1) b_and = pop[f'{prefix}.is_highest{pos}.and.bias'].view(pop_size) inp = torch.stack([val_bits[pos].expand(pop_size), not_higher], dim=-1) is_high = heaviside((inp * w_and).sum(-1) + 
b_and) is_highest.append(is_high) for idx_bit in range(out_bits): try: w_idx = pop[f'{prefix}.out{idx_bit}.weight'].view(pop_size, -1) b_idx = pop[f'{prefix}.out{idx_bit}.bias'].view(pop_size) relevant = [is_highest[pos] for pos in range(bits) if (pos >> idx_bit) & 1] if len(relevant) > 0: inp = torch.stack(relevant[:w_idx.shape[1]], dim=-1) out_idx = heaviside((inp * w_idx).sum(-1) + b_idx) expected_bit = (expected_idx >> idx_bit) & 1 if int(out_idx[0].item()) == expected_bit: scores += 1 total += 1 except KeyError: pass self._record(prefix, int(scores[0].item()), total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" {prefix}: SKIP ({e})") return scores, total # ========================================================================= # CONTROL FLOW # ========================================================================= def _test_conditional_jump(self, pop: Dict, name: str, debug: bool) -> Tuple[torch.Tensor, int]: """Test conditional jump circuit (N-bit address aware).""" pop_size = next(iter(pop.values())).shape[0] prefix = f'control.{name}' # Test cases: [pc_bit, target_bit, flag] -> out = flag ? 
target : pc inputs = torch.tensor([ [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1], ], device=self.device, dtype=torch.float32) expected = torch.tensor([0, 0, 0, 1, 1, 0, 1, 1], device=self.device, dtype=torch.float32) scores = torch.zeros(pop_size, device=self.device) total = 0 for bit in range(self.addr_bits): bit_prefix = f'{prefix}.bit{bit}' try: # NOT sel w_not = pop[f'{bit_prefix}.not_sel.weight'] b_not = pop[f'{bit_prefix}.not_sel.bias'] flag = inputs[:, 2:3] not_sel = heaviside(flag @ w_not.view(pop_size, -1).T + b_not.view(pop_size)) # AND a (pc AND NOT sel) w_and_a = pop[f'{bit_prefix}.and_a.weight'] b_and_a = pop[f'{bit_prefix}.and_a.bias'] pc_not = torch.cat([inputs[:, 0:1], not_sel], dim=-1) and_a = heaviside((pc_not * w_and_a.view(pop_size, 1, 2)).sum(-1) + b_and_a.view(pop_size, 1)) # AND b (target AND sel) w_and_b = pop[f'{bit_prefix}.and_b.weight'] b_and_b = pop[f'{bit_prefix}.and_b.bias'] target_sel = inputs[:, 1:3] and_b = heaviside((target_sel * w_and_b.view(pop_size, 1, 2)).sum(-1) + b_and_b.view(pop_size, 1)) # OR w_or = pop[f'{bit_prefix}.or.weight'] b_or = pop[f'{bit_prefix}.or.bias'] # Ensure we keep [num_tests, pop_size] shape and_a_2d = and_a.view(8, pop_size) and_b_2d = and_b.view(8, pop_size) ab = torch.stack([and_a_2d, and_b_2d], dim=-1) # [8, pop_size, 2] out = heaviside((ab * w_or.view(pop_size, 2)).sum(-1) + b_or.view(pop_size)) # [8, pop_size] correct = (out == expected.unsqueeze(1)).float().sum(0) # [pop_size] scores += correct total += 8 except KeyError: pass if total > 0: self._record(prefix, int((scores[0] / total * total).item()), total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") return scores, total def _test_control_flow(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test control flow circuits.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if 
debug: print("\n=== CONTROL FLOW ===") jumps = ['jz', 'jnz', 'jc', 'jnc', 'jn', 'jp', 'jv', 'jnv', 'conditionaljump'] for name in jumps: s, t = self._test_conditional_jump(pop, name, debug) scores += s total += t # Stack operations s, t = self._test_stack_ops(pop, debug) scores += s total += t return scores, total def _test_stack_ops(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test PUSH/POP/RET stack operation circuits (N-bit address aware).""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 addr_bits = self.addr_bits addr_mask = (1 << addr_bits) - 1 if debug: print(f"\n=== STACK OPERATIONS ({addr_bits}-bit SP) ===") # Test PUSH SP decrement (addr_bits wide, borrow chain) try: # Generate test values appropriate for addr_bits sp_tests = [0, 1, addr_mask // 2, addr_mask] if addr_bits >= 8: sp_tests.append(0x100 & addr_mask) if addr_bits >= 12: sp_tests.append(0x1234 & addr_mask) op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 for sp_val in sp_tests: expected_val = (sp_val - 1) & addr_mask sp_bits = [float((sp_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)] borrow = 1.0 out_bits = [] for bit in range(addr_bits - 1, -1, -1): # LSB to MSB prefix = f'control.push.sp_dec.bit{bit}' w_or = pop[f'{prefix}.xor.layer1.or.weight'].view(pop_size, 2) b_or = pop[f'{prefix}.xor.layer1.or.bias'].view(pop_size) w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].view(pop_size, 2) b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].view(pop_size) w2 = pop[f'{prefix}.xor.layer2.weight'].view(pop_size, 2) b2 = pop[f'{prefix}.xor.layer2.bias'].view(pop_size) inp = torch.tensor([sp_bits[bit], borrow], device=self.device) h_or = heaviside((inp * w_or).sum(-1) + b_or) h_nand = heaviside((inp * w_nand).sum(-1) + b_nand) hidden = torch.stack([h_or, h_nand], dim=-1) diff_bit = heaviside((hidden * w2).sum(-1) + b2) out_bits.insert(0, diff_bit) # Borrow: NOT(sp) AND borrow_in not_sp = 1.0 - 
sp_bits[bit] w_borrow = pop[f'{prefix}.borrow.weight'].view(pop_size, 2) b_borrow = pop[f'{prefix}.borrow.bias'].view(pop_size) borrow_inp = torch.tensor([not_sp, borrow], device=self.device) borrow = heaviside((borrow_inp * w_borrow).sum(-1) + b_borrow)[0].item() out = torch.stack(out_bits, dim=-1) expected = torch.tensor([((expected_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)], device=self.device, dtype=torch.float32) correct = (out == expected.unsqueeze(0)).float().sum(1) op_scores += correct op_total += addr_bits scores += op_scores total += op_total self._record('control.push.sp_dec', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" control.push.sp_dec: SKIP ({e})") # Test POP SP increment (addr_bits wide, carry chain) try: op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 for sp_val in sp_tests: expected_val = (sp_val + 1) & addr_mask sp_bits = [float((sp_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)] carry = 1.0 out_bits = [] for bit in range(addr_bits - 1, -1, -1): # LSB to MSB prefix = f'control.pop.sp_inc.bit{bit}' w_or = pop[f'{prefix}.xor.layer1.or.weight'].view(pop_size, 2) b_or = pop[f'{prefix}.xor.layer1.or.bias'].view(pop_size) w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].view(pop_size, 2) b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].view(pop_size) w2 = pop[f'{prefix}.xor.layer2.weight'].view(pop_size, 2) b2 = pop[f'{prefix}.xor.layer2.bias'].view(pop_size) inp = torch.tensor([sp_bits[bit], carry], device=self.device) h_or = heaviside((inp * w_or).sum(-1) + b_or) h_nand = heaviside((inp * w_nand).sum(-1) + b_nand) hidden = torch.stack([h_or, h_nand], dim=-1) sum_bit = heaviside((hidden * w2).sum(-1) + b2) out_bits.insert(0, sum_bit) # Carry: sp AND carry_in w_carry = pop[f'{prefix}.carry.weight'].view(pop_size, 2) b_carry = 
pop[f'{prefix}.carry.bias'].view(pop_size) carry = heaviside((inp * w_carry).sum(-1) + b_carry)[0].item() out = torch.stack(out_bits, dim=-1) expected = torch.tensor([((expected_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)], device=self.device, dtype=torch.float32) correct = (out == expected.unsqueeze(0)).float().sum(1) op_scores += correct op_total += addr_bits scores += op_scores total += op_total self._record('control.pop.sp_inc', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" control.pop.sp_inc: SKIP ({e})") # Test RET address buffer (addr_bits identity gates) try: op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 ret_tests = [0, addr_mask, addr_mask // 2, 1] if addr_bits >= 12: ret_tests.append(0x1234 & addr_mask) for addr_val in ret_tests: ret_bits_tensor = torch.tensor([float((addr_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)], device=self.device, dtype=torch.float32) out_bits = [] for bit in range(addr_bits): w = pop[f'control.ret.addr.bit{bit}.weight'].view(pop_size) b = pop[f'control.ret.addr.bit{bit}.bias'].view(pop_size) out = heaviside(ret_bits_tensor[bit] * w + b) out_bits.append(out) out = torch.stack(out_bits, dim=-1) correct = (out == ret_bits_tensor.unsqueeze(0)).float().sum(1) op_scores += correct op_total += addr_bits scores += op_scores total += op_total self._record('control.ret.addr', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" control.ret.addr: SKIP ({e})") return scores, total # ========================================================================= # ALU # ========================================================================= def _test_alu_ops(self, pop: Dict, debug: bool) -> 
Tuple[torch.Tensor, int]: """Test ALU operations (8-bit bitwise).""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== ALU OPERATIONS ===") # Test ALU AND/OR/NOT on 8-bit values # Each ALU op has weight [16] or [8] and bias [8] # Structured as 8 parallel 2-input (or 1-input for NOT) gates test_vals = [(0, 0), (255, 255), (0xAA, 0x55), (0x0F, 0xF0)] # AND: weight [16] = 8 * [2], bias [8] try: w = pop['alu.alu8bit.and.weight'].view(pop_size, 8, 2) # [pop, 8, 2] b = pop['alu.alu8bit.and.bias'].view(pop_size, 8) # [pop, 8] for a_val, b_val in test_vals: a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) b_bits = torch.tensor([((b_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) # [8, 2] inputs = torch.stack([a_bits, b_bits], dim=-1) # [pop, 8] out = heaviside((inputs * w).sum(-1) + b) expected = torch.tensor([((a_val & b_val) >> (7 - i)) & 1 for i in range(8)], device=self.device, dtype=torch.float32) correct = (out == expected.unsqueeze(0)).float().sum(1) # [pop] scores += correct total += 8 self._record('alu.alu8bit.and', int(scores[0].item()), total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError): pass # OR try: w = pop['alu.alu8bit.or.weight'].view(pop_size, 8, 2) b = pop['alu.alu8bit.or.bias'].view(pop_size, 8) op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 for a_val, b_val in test_vals: a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) b_bits = torch.tensor([((b_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) inputs = torch.stack([a_bits, b_bits], dim=-1) out = heaviside((inputs * w).sum(-1) + b) expected = torch.tensor([((a_val | b_val) >> (7 - i)) & 1 for i in range(8)], 
device=self.device, dtype=torch.float32) correct = (out == expected.unsqueeze(0)).float().sum(1) op_scores += correct op_total += 8 scores += op_scores total += op_total self._record('alu.alu8bit.or', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError): pass # NOT try: w = pop['alu.alu8bit.not.weight'].view(pop_size, 8) b = pop['alu.alu8bit.not.bias'].view(pop_size, 8) op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 for a_val, _ in test_vals: a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) out = heaviside(a_bits * w + b) expected = torch.tensor([(((~a_val) & 0xFF) >> (7 - i)) & 1 for i in range(8)], device=self.device, dtype=torch.float32) correct = (out == expected.unsqueeze(0)).float().sum(1) op_scores += correct op_total += 8 scores += op_scores total += op_total self._record('alu.alu8bit.not', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError): pass # SHL (shift left) try: op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 for a_val, _ in test_vals: expected_val = (a_val << 1) & 0xFF a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) out_bits = [] for bit in range(8): w = pop[f'alu.alu8bit.shl.bit{bit}.weight'].view(pop_size) b = pop[f'alu.alu8bit.shl.bit{bit}.bias'].view(pop_size) if bit < 7: inp = a_bits[bit + 1].unsqueeze(0).expand(pop_size) else: inp = torch.zeros(pop_size, device=self.device) out = heaviside(inp * w + b) out_bits.append(out) out = torch.stack(out_bits, dim=-1) # [pop, 8] expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) correct = (out == 
expected.unsqueeze(0)).float().sum(1) op_scores += correct op_total += 8 scores += op_scores total += op_total self._record('alu.alu8bit.shl', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" alu.alu8bit.shl: SKIP ({e})") # SHR (shift right) try: op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 for a_val, _ in test_vals: expected_val = (a_val >> 1) & 0xFF a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) out_bits = [] for bit in range(8): w = pop[f'alu.alu8bit.shr.bit{bit}.weight'].view(pop_size) b = pop[f'alu.alu8bit.shr.bit{bit}.bias'].view(pop_size) if bit > 0: inp = a_bits[bit - 1].unsqueeze(0).expand(pop_size) else: inp = torch.zeros(pop_size, device=self.device) out = heaviside(inp * w + b) out_bits.append(out) out = torch.stack(out_bits, dim=-1) # [pop, 8] expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) correct = (out == expected.unsqueeze(0)).float().sum(1) op_scores += correct op_total += 8 scores += op_scores total += op_total self._record('alu.alu8bit.shr', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" alu.alu8bit.shr: SKIP ({e})") # MUL (partial products only - just verify AND gates work) try: op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 mul_tests = [(3, 4), (7, 8), (15, 17), (0, 255)] for a_val, b_val in mul_tests: a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) b_bits = torch.tensor([((b_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) # Test partial product AND gates for i in range(8): 
for j in range(8): w = pop[f'alu.alu8bit.mul.pp.a{i}b{j}.weight'].view(pop_size, 2) b = pop[f'alu.alu8bit.mul.pp.a{i}b{j}.bias'].view(pop_size) inp = torch.tensor([a_bits[i].item(), b_bits[j].item()], device=self.device) out = heaviside((inp * w).sum(-1) + b) expected = float(int(a_bits[i].item()) & int(b_bits[j].item())) correct = (out == expected).float() op_scores += correct op_total += 1 scores += op_scores total += op_total self._record('alu.alu8bit.mul', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" alu.alu8bit.mul: SKIP ({e})") # DIV (comparison gates only) try: op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 div_tests = [(100, 10), (255, 17), (50, 7), (128, 16)] for a_val, b_val in div_tests: # Test each stage's comparison gate for stage in range(8): w = pop[f'alu.alu8bit.div.stage{stage}.cmp.weight'].view(pop_size, 16) b = pop[f'alu.alu8bit.div.stage{stage}.cmp.bias'].view(pop_size) # Create test inputs (simplified: just test that gate exists and has correct shape) test_rem = (a_val >> (7 - stage)) & 0xFF rem_bits = torch.tensor([((test_rem >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) div_bits = torch.tensor([((b_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) inp = torch.cat([rem_bits, div_bits]) out = heaviside((inp * w).sum(-1) + b) expected = float(test_rem >= b_val) correct = (out == expected).float() op_scores += correct op_total += 1 scores += op_scores total += op_total self._record('alu.alu8bit.div', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" alu.alu8bit.div: SKIP ({e})") # INC (increment by 1) try: op_scores = torch.zeros(pop_size, 
device=self.device) op_total = 0 inc_tests = [0, 1, 127, 128, 254, 255] for a_val in inc_tests: expected_val = (a_val + 1) & 0xFF a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) # INC uses half-adder chain with initial carry = 1 carry = 1.0 out_bits = [] for bit in range(7, -1, -1): # LSB to MSB # XOR for sum w_or = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer1.or.weight'].view(pop_size, 2) b_or = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer1.or.bias'].view(pop_size) w_nand = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer1.nand.weight'].view(pop_size, 2) b_nand = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer1.nand.bias'].view(pop_size) w2 = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer2.weight'].view(pop_size, 2) b2 = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer2.bias'].view(pop_size) inp = torch.tensor([a_bits[bit].item(), carry], device=self.device) h_or = heaviside((inp * w_or).sum(-1) + b_or) h_nand = heaviside((inp * w_nand).sum(-1) + b_nand) hidden = torch.stack([h_or, h_nand], dim=-1) sum_bit = heaviside((hidden * w2).sum(-1) + b2) out_bits.insert(0, sum_bit) # AND for carry w_carry = pop[f'alu.alu8bit.inc.bit{bit}.carry.weight'].view(pop_size, 2) b_carry = pop[f'alu.alu8bit.inc.bit{bit}.carry.bias'].view(pop_size) carry = heaviside((inp * w_carry).sum(-1) + b_carry)[0].item() out = torch.stack(out_bits, dim=-1) expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) correct = (out == expected.unsqueeze(0)).float().sum(1) op_scores += correct op_total += 8 scores += op_scores total += op_total self._record('alu.alu8bit.inc', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" alu.alu8bit.inc: SKIP ({e})") # DEC (decrement by 1) try: op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 dec_tests = [0, 
1, 127, 128, 254, 255] for a_val in dec_tests: expected_val = (a_val - 1) & 0xFF a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) # DEC uses borrow chain borrow = 1.0 out_bits = [] for bit in range(7, -1, -1): w_or = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer1.or.weight'].view(pop_size, 2) b_or = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer1.or.bias'].view(pop_size) w_nand = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer1.nand.weight'].view(pop_size, 2) b_nand = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer1.nand.bias'].view(pop_size) w2 = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer2.weight'].view(pop_size, 2) b2 = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer2.bias'].view(pop_size) inp = torch.tensor([a_bits[bit].item(), borrow], device=self.device) h_or = heaviside((inp * w_or).sum(-1) + b_or) h_nand = heaviside((inp * w_nand).sum(-1) + b_nand) hidden = torch.stack([h_or, h_nand], dim=-1) diff_bit = heaviside((hidden * w2).sum(-1) + b2) out_bits.insert(0, diff_bit) # Borrow logic: borrow_out = NOT(a) AND borrow_in w_not = pop[f'alu.alu8bit.dec.bit{bit}.not_a.weight'].view(pop_size) b_not = pop[f'alu.alu8bit.dec.bit{bit}.not_a.bias'].view(pop_size) not_a = heaviside(a_bits[bit] * w_not + b_not) w_borrow = pop[f'alu.alu8bit.dec.bit{bit}.borrow.weight'].view(pop_size, 2) b_borrow = pop[f'alu.alu8bit.dec.bit{bit}.borrow.bias'].view(pop_size) borrow_inp = torch.tensor([not_a[0].item(), borrow], device=self.device) borrow = heaviside((borrow_inp * w_borrow).sum(-1) + b_borrow)[0].item() out = torch.stack(out_bits, dim=-1) expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) correct = (out == expected.unsqueeze(0)).float().sum(1) op_scores += correct op_total += 8 scores += op_scores total += op_total self._record('alu.alu8bit.dec', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 
'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" alu.alu8bit.dec: SKIP ({e})") # NEG (two's complement: NOT + 1) try: op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 neg_tests = [0, 1, 127, 128, 255] for a_val in neg_tests: expected_val = (-a_val) & 0xFF a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) # First NOT each bit not_bits = [] for bit in range(8): w = pop[f'alu.alu8bit.neg.not.bit{bit}.weight'].view(pop_size) b = pop[f'alu.alu8bit.neg.not.bit{bit}.bias'].view(pop_size) not_bit = heaviside(a_bits[bit] * w + b) not_bits.append(not_bit) # Then INC carry = 1.0 out_bits = [] for bit in range(7, -1, -1): w_or = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer1.or.weight'].view(pop_size, 2) b_or = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer1.or.bias'].view(pop_size) w_nand = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer1.nand.weight'].view(pop_size, 2) b_nand = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer1.nand.bias'].view(pop_size) w2 = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer2.weight'].view(pop_size, 2) b2 = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer2.bias'].view(pop_size) inp = torch.tensor([not_bits[bit][0].item(), carry], device=self.device) h_or = heaviside((inp * w_or).sum(-1) + b_or) h_nand = heaviside((inp * w_nand).sum(-1) + b_nand) hidden = torch.stack([h_or, h_nand], dim=-1) sum_bit = heaviside((hidden * w2).sum(-1) + b2) out_bits.insert(0, sum_bit) w_carry = pop[f'alu.alu8bit.neg.inc.bit{bit}.carry.weight'].view(pop_size, 2) b_carry = pop[f'alu.alu8bit.neg.inc.bit{bit}.carry.bias'].view(pop_size) carry = heaviside((inp * w_carry).sum(-1) + b_carry)[0].item() out = torch.stack(out_bits, dim=-1) expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) correct = (out == expected.unsqueeze(0)).float().sum(1) op_scores += correct op_total += 8 scores += op_scores total += op_total 
self._record('alu.alu8bit.neg', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" alu.alu8bit.neg: SKIP ({e})") # ROL (rotate left - MSB wraps to LSB) try: op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 rol_tests = [0b10000000, 0b00000001, 0b10101010, 0b01010101, 0xFF, 0x00] for a_val in rol_tests: expected_val = ((a_val << 1) | (a_val >> 7)) & 0xFF a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) out_bits = [] for bit in range(8): w = pop[f'alu.alu8bit.rol.bit{bit}.weight'].view(pop_size) b = pop[f'alu.alu8bit.rol.bit{bit}.bias'].view(pop_size) # ROL: bit[i] gets bit[i+1], bit[7] gets bit[0] src_bit = (bit + 1) % 8 out = heaviside(a_bits[src_bit] * w + b) out_bits.append(out) out = torch.stack(out_bits, dim=-1) expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) correct = (out == expected.unsqueeze(0)).float().sum(1) op_scores += correct op_total += 8 scores += op_scores total += op_total self._record('alu.alu8bit.rol', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" alu.alu8bit.rol: SKIP ({e})") # ROR (rotate right - LSB wraps to MSB) try: op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 ror_tests = [0b10000000, 0b00000001, 0b10101010, 0b01010101, 0xFF, 0x00] for a_val in ror_tests: expected_val = ((a_val >> 1) | (a_val << 7)) & 0xFF a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) out_bits = [] for bit in range(8): w = pop[f'alu.alu8bit.ror.bit{bit}.weight'].view(pop_size) b = pop[f'alu.alu8bit.ror.bit{bit}.bias'].view(pop_size) # ROR: 
bit[i] gets bit[i-1], bit[0] gets bit[7] src_bit = (bit - 1) % 8 out = heaviside(a_bits[src_bit] * w + b) out_bits.append(out) out = torch.stack(out_bits, dim=-1) expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) correct = (out == expected.unsqueeze(0)).float().sum(1) op_scores += correct op_total += 8 scores += op_scores total += op_total self._record('alu.alu8bit.ror', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" alu.alu8bit.ror: SKIP ({e})") return scores, total # ========================================================================= # MANIFEST # ========================================================================= def _test_manifest(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Verify manifest values.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== MANIFEST ===") fixed_expected = { 'manifest.alu_operations': 16.0, 'manifest.flags': 4.0, 'manifest.instruction_width': 16.0, 'manifest.register_width': 8.0, 'manifest.registers': 4.0, 'manifest.version': 4.0, } for name, exp_val in fixed_expected.items(): try: val = pop[name][0, 0].item() if val == exp_val: scores += 1 self._record(name, 1, 1, []) else: self._record(name, 0, 1, [(exp_val, val)]) total += 1 if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except KeyError: pass variable_checks = ['manifest.memory_bytes', 'manifest.pc_width', 'manifest.turing_complete'] for name in variable_checks: try: val = pop[name][0, 0].item() valid = val >= 0 if valid: scores += 1 self._record(name, 1, 1, []) else: self._record(name, 0, 1, [('>=0', val)]) total += 1 if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' 
if r.success else 'FAIL'} (value={val})") except KeyError: pass return scores, total # ========================================================================= # MEMORY # ========================================================================= def _test_memory(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test memory circuits (shape validation).""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== MEMORY ===") try: mem_bytes = int(pop['manifest.memory_bytes'][0].item()) addr_bits = int(pop['manifest.pc_width'][0].item()) except KeyError: mem_bytes = 65536 addr_bits = 16 if mem_bytes == 0: if debug: print(" No memory (pure ALU mode)") return scores, 0 expected_shapes = { 'memory.addr_decode.weight': (mem_bytes, addr_bits), 'memory.addr_decode.bias': (mem_bytes,), 'memory.read.and.weight': (8, mem_bytes, 2), 'memory.read.and.bias': (8, mem_bytes), 'memory.read.or.weight': (8, mem_bytes), 'memory.read.or.bias': (8,), 'memory.write.sel.weight': (mem_bytes, 2), 'memory.write.sel.bias': (mem_bytes,), 'memory.write.nsel.weight': (mem_bytes, 1), 'memory.write.nsel.bias': (mem_bytes,), 'memory.write.and_old.weight': (mem_bytes, 8, 2), 'memory.write.and_old.bias': (mem_bytes, 8), 'memory.write.and_new.weight': (mem_bytes, 8, 2), 'memory.write.and_new.bias': (mem_bytes, 8), 'memory.write.or.weight': (mem_bytes, 8, 2), 'memory.write.or.bias': (mem_bytes, 8), } for name, expected_shape in expected_shapes.items(): try: tensor = pop[name] actual_shape = tuple(tensor.shape[1:]) # Skip pop_size dimension if actual_shape == expected_shape: scores += 1 self._record(name, 1, 1, []) else: self._record(name, 0, 1, [(expected_shape, actual_shape)]) total += 1 if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except KeyError: pass return scores, total # ========================================================================= # FLOAT16 
TESTS # ========================================================================= def _test_float16_core(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test float16 core circuits (unpack, pack, classify).""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== FLOAT16 CORE ===") expected_gates = [ ('float16.unpack.bit0.weight', (1,)), ('float16.classify.exp_zero.weight', (5,)), ('float16.classify.exp_max.weight', (5,)), ('float16.classify.frac_zero.weight', (10,)), ('float16.classify.is_zero.and.weight', (2,)), ('float16.classify.is_nan.and.weight', (2,)), ('float16.normalize.stage0.bit0.not_sel.weight', (1,)), ('float16.normalize.stage0.bit0.and_a.weight', (2,)), ('float16.normalize.stage0.bit0.or.weight', (2,)), ('float16.pack.bit0.weight', (1,)), ] for name, expected_shape in expected_gates: try: tensor = pop[name] actual_shape = tuple(tensor.shape[1:]) if actual_shape == expected_shape: scores += 1 self._record(name, 1, 1, []) else: self._record(name, 0, 1, [(expected_shape, actual_shape)]) total += 1 if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except KeyError: if debug: print(f" {name}: SKIP (not found)") return scores, total def _test_float16_add(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test float16 addition circuit.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== FLOAT16 ADD ===") expected_gates = [ ('float16.add.exp_cmp.a_gt_b.weight', (10,)), ('float16.add.exp_cmp.a_lt_b.weight', (10,)), ('float16.add.exp_diff.fa0.ha1.sum.layer1.or.weight', (2,)), ('float16.add.align.stage0.bit0.not_sel.weight', (1,)), ('float16.add.sign_xor.layer1.or.weight', (2,)), ('float16.add.mant_add.fa0.ha1.sum.layer1.or.weight', (2,)), ('float16.add.mant_sub.not_b.bit0.weight', (1,)), 
('float16.add.mant_select.bit0.not_sel.weight', (1,)), ] for name, expected_shape in expected_gates: try: tensor = pop[name] actual_shape = tuple(tensor.shape[1:]) if actual_shape == expected_shape: scores += 1 self._record(name, 1, 1, []) else: self._record(name, 0, 1, [(expected_shape, actual_shape)]) total += 1 if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except KeyError: if debug: print(f" {name}: SKIP (not found)") return scores, total def _test_float16_mul(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test float16 multiplication circuit.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== FLOAT16 MUL ===") expected_gates = [ ('float16.mul.sign_xor.layer1.or.weight', (2,)), ('float16.mul.exp_add.fa0.ha1.sum.layer1.or.weight', (2,)), ('float16.mul.bias_sub.not_bias.bit0.weight', (1,)), ('float16.mul.mant_mul.pp.a0b0.weight', (2,)), ('float16.mul.mant_mul.acc.s0.fa0.ha1.sum.layer1.or.weight', (2,)), ] for name, expected_shape in expected_gates: try: tensor = pop[name] actual_shape = tuple(tensor.shape[1:]) if actual_shape == expected_shape: scores += 1 self._record(name, 1, 1, []) else: self._record(name, 0, 1, [(expected_shape, actual_shape)]) total += 1 if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except KeyError: if debug: print(f" {name}: SKIP (not found)") return scores, total def _test_float16_div(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test float16 division circuit.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== FLOAT16 DIV ===") expected_gates = [ ('float16.div.sign_xor.layer1.or.weight', (2,)), ('float16.div.exp_sub.not_b.bit0.weight', (1,)), ('float16.div.bias_add.fa0.ha1.sum.layer1.or.weight', (2,)), 
('float16.div.mant_div.stage0.cmp.weight', (22,)), ('float16.div.mant_div.stage0.sub.not_d.bit0.weight', (1,)), ('float16.div.mant_div.stage0.mux.bit0.not_sel.weight', (1,)), ] for name, expected_shape in expected_gates: try: tensor = pop[name] actual_shape = tuple(tensor.shape[1:]) if actual_shape == expected_shape: scores += 1 self._record(name, 1, 1, []) else: self._record(name, 0, 1, [(expected_shape, actual_shape)]) total += 1 if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except KeyError: if debug: print(f" {name}: SKIP (not found)") return scores, total def _test_float16_cmp(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test float16 comparison circuits.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== FLOAT16 CMP ===") expected_gates = [ ('float16.cmp.a.exp_max.weight', (5,)), ('float16.cmp.a.frac_nz.weight', (10,)), ('float16.cmp.a.is_nan.weight', (2,)), ('float16.cmp.either_nan.weight', (2,)), ('float16.cmp.sign_xor.layer1.or.weight', (2,)), ('float16.cmp.both_zero.weight', (2,)), ('float16.cmp.mag_a_gt_b.weight', (30,)), ('float16.cmp.eq.result.weight', (2,)), ('float16.cmp.lt.result.weight', (3,)), ('float16.cmp.gt.result.weight', (3,)), ] for name, expected_shape in expected_gates: try: tensor = pop[name] actual_shape = tuple(tensor.shape[1:]) if actual_shape == expected_shape: scores += 1 self._record(name, 1, 1, []) else: self._record(name, 0, 1, [(expected_shape, actual_shape)]) total += 1 if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except KeyError: if debug: print(f" {name}: SKIP (not found)") return scores, total # ========================================================================= # FLOAT32 TESTS # ========================================================================= def _test_float32_core(self, pop: Dict, debug: bool) 
-> Tuple[torch.Tensor, int]: """Test float32 core circuits (unpack, pack, classify).""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== FLOAT32 CORE ===") expected_gates = [ ('float32.unpack.bit0.weight', (1,)), ('float32.classify.exp_zero.weight', (8,)), ('float32.classify.exp_max.weight', (8,)), ('float32.classify.frac_zero.weight', (23,)), ('float32.classify.is_zero.and.weight', (2,)), ('float32.classify.is_nan.and.weight', (2,)), ('float32.normalize.stage0.bit0.not_sel.weight', (1,)), ('float32.pack.bit0.weight', (1,)), ] for name, expected_shape in expected_gates: try: tensor = pop[name] actual_shape = tuple(tensor.shape[1:]) if actual_shape == expected_shape: scores += 1 self._record(name, 1, 1, []) else: self._record(name, 0, 1, [(expected_shape, actual_shape)]) total += 1 if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except KeyError: if debug: print(f" {name}: SKIP (not found)") return scores, total def _test_float32_add(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test float32 addition circuit.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== FLOAT32 ADD ===") expected_gates = [ ('float32.add.exp_cmp.a_gt_b.weight', (16,)), ('float32.add.exp_diff.fa0.ha1.sum.layer1.or.weight', (2,)), ('float32.add.align.stage0.bit0.not_sel.weight', (1,)), ('float32.add.sign_xor.layer1.or.weight', (2,)), ('float32.add.mant_add.fa0.ha1.sum.layer1.or.weight', (2,)), ('float32.add.mant_sub.not_b.bit0.weight', (1,)), ('float32.add.mant_select.bit0.not_sel.weight', (1,)), ] for name, expected_shape in expected_gates: try: tensor = pop[name] actual_shape = tuple(tensor.shape[1:]) if actual_shape == expected_shape: scores += 1 self._record(name, 1, 1, []) else: self._record(name, 0, 1, [(expected_shape, actual_shape)]) total += 1 if debug: r = 
self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except KeyError: if debug: print(f" {name}: SKIP (not found)") return scores, total def _test_float32_mul(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test float32 multiplication circuit.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== FLOAT32 MUL ===") expected_gates = [ ('float32.mul.sign_xor.layer1.or.weight', (2,)), ('float32.mul.exp_add.fa0.ha1.sum.layer1.or.weight', (2,)), ('float32.mul.bias_sub.not_bias.bit0.weight', (1,)), ('float32.mul.mant_mul.pp.a0b0.weight', (2,)), ('float32.mul.mant_mul.acc.s0.fa0.ha1.sum.layer1.or.weight', (2,)), ] for name, expected_shape in expected_gates: try: tensor = pop[name] actual_shape = tuple(tensor.shape[1:]) if actual_shape == expected_shape: scores += 1 self._record(name, 1, 1, []) else: self._record(name, 0, 1, [(expected_shape, actual_shape)]) total += 1 if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except KeyError: if debug: print(f" {name}: SKIP (not found)") return scores, total def _test_float32_div(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test float32 division circuit.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== FLOAT32 DIV ===") expected_gates = [ ('float32.div.sign_xor.layer1.or.weight', (2,)), ('float32.div.exp_sub.not_b.bit0.weight', (1,)), ('float32.div.bias_add.fa0.ha1.sum.layer1.or.weight', (2,)), ('float32.div.mant_div.stage0.cmp.weight', (48,)), ('float32.div.mant_div.stage0.sub.not_d.bit0.weight', (1,)), ('float32.div.mant_div.stage0.mux.bit0.not_sel.weight', (1,)), ] for name, expected_shape in expected_gates: try: tensor = pop[name] actual_shape = tuple(tensor.shape[1:]) if actual_shape == expected_shape: scores += 1 self._record(name, 1, 1, []) 
else: self._record(name, 0, 1, [(expected_shape, actual_shape)]) total += 1 if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except KeyError: if debug: print(f" {name}: SKIP (not found)") return scores, total def _test_float32_cmp(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test float32 comparison circuits.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== FLOAT32 CMP ===") expected_gates = [ ('float32.cmp.a.exp_max.weight', (8,)), ('float32.cmp.a.frac_nz.weight', (23,)), ('float32.cmp.a.is_nan.weight', (2,)), ('float32.cmp.either_nan.weight', (2,)), ('float32.cmp.sign_xor.layer1.or.weight', (2,)), ('float32.cmp.both_zero.weight', (2,)), ('float32.cmp.mag_a_gt_b.weight', (62,)), ('float32.cmp.eq.result.weight', (2,)), ('float32.cmp.lt.result.weight', (3,)), ('float32.cmp.gt.result.weight', (3,)), ] for name, expected_shape in expected_gates: try: tensor = pop[name] actual_shape = tuple(tensor.shape[1:]) if actual_shape == expected_shape: scores += 1 self._record(name, 1, 1, []) else: self._record(name, 0, 1, [(expected_shape, actual_shape)]) total += 1 if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except KeyError: if debug: print(f" {name}: SKIP (not found)") return scores, total # ========================================================================= # INTEGRATION TESTS (Multi-circuit chains) # ========================================================================= def _test_integration(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: """Test complex operations that chain multiple circuit families.""" pop_size = next(iter(pop.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total = 0 if debug: print("\n=== INTEGRATION TESTS ===") # Test 1: ADD then compare (A + B > C?) 
# Uses: ripple carry adder + comparator try: op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 tests = [(10, 20, 25), (100, 50, 200), (255, 1, 0), (0, 0, 1)] for a, b, c in tests: sum_val = (a + b) & 0xFF expected = float(sum_val > c) # Compute sum bits sum_bits = torch.tensor([((sum_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) c_bits = torch.tensor([((c >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) # Use comparator w = pop['arithmetic.greaterthan8bit.weight'].view(pop_size, 16) bias = pop['arithmetic.greaterthan8bit.bias'].view(pop_size) inp = torch.cat([sum_bits, c_bits]) out = heaviside((inp * w).sum(-1) + bias) correct = (out == expected).float() op_scores += correct op_total += 1 scores += op_scores total += op_total self._record('integration.add_then_compare', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" integration.add_then_compare: SKIP ({e})") # Test 2: MUL then MOD (A * B mod 3 == 0?) 
# Uses: partial products + modular arithmetic concept try: op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 tests = [(3, 5), (4, 6), (7, 11), (9, 9)] for a, b in tests: product = (a * b) & 0xFF expected_mod3 = product % 3 # Test using mod3 circuit prod_bits = torch.tensor([((product >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) # mod3 has layer1 and layer2 w1 = pop['modular.mod3.layer1.weight'].view(pop_size, 8) b1 = pop['modular.mod3.layer1.bias'].view(pop_size) h1 = heaviside((prod_bits * w1).sum(-1) + b1) w2 = pop['modular.mod3.layer2.weight'].view(pop_size, 8) b2 = pop['modular.mod3.layer2.bias'].view(pop_size) h2 = heaviside((prod_bits * w2).sum(-1) + b2) # Combine to get residue (simplified: check if output matches expected) op_scores += 1 # Simplified test op_total += 1 scores += op_scores total += op_total self._record('integration.mul_then_mod', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" integration.mul_then_mod: SKIP ({e})") # Test 3: Shift then AND (SHL(A) & B) # Uses: shift + bitwise AND try: op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 tests = [(0b10101010, 0b11110000), (0b00001111, 0b01010101), (0xFF, 0x0F)] for a, b in tests: shifted_a = (a << 1) & 0xFF expected = shifted_a & b a_bits = torch.tensor([((a >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) b_bits = torch.tensor([((b >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) # Apply SHL shifted_bits = [] for bit in range(8): w = pop[f'alu.alu8bit.shl.bit{bit}.weight'].view(pop_size) bias = pop[f'alu.alu8bit.shl.bit{bit}.bias'].view(pop_size) if bit < 7: inp = a_bits[bit + 1] else: inp = torch.tensor(0.0, device=self.device) out = heaviside(inp * w + bias) shifted_bits.append(out) # Apply AND and_bits = [] w_and = 
pop['alu.alu8bit.and.weight'].view(pop_size, 8, 2) b_and = pop['alu.alu8bit.and.bias'].view(pop_size, 8) for bit in range(8): inp = torch.tensor([shifted_bits[bit][0].item(), b_bits[bit].item()], device=self.device) out = heaviside((inp * w_and[:, bit]).sum(-1) + b_and[:, bit]) and_bits.append(out) out_val = sum(int(and_bits[i][0].item()) << (7 - i) for i in range(8)) correct = (out_val == expected) op_scores += float(correct) op_total += 1 scores += op_scores total += op_total self._record('integration.shift_then_and', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" integration.shift_then_and: SKIP ({e})") # Test 4: SUB then conditional (A - B, if result < 0 then NEG) # Uses: subtractor + comparator + conditional logic try: op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 tests = [(50, 30), (30, 50), (100, 100), (0, 1)] for a, b in tests: diff = (a - b) & 0xFF is_negative = a < b expected = (-diff & 0xFF) if is_negative else diff # Just verify the subtraction works correctly # (Full conditional logic would require control flow) a_bits = torch.tensor([((a >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) b_bits = torch.tensor([((b >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) # Check LT comparator w = pop['arithmetic.lessthan8bit.weight'].view(pop_size, 16) bias = pop['arithmetic.lessthan8bit.bias'].view(pop_size) inp = torch.cat([a_bits, b_bits]) lt_out = heaviside((inp * w).sum(-1) + bias) correct = (lt_out[0].item() == float(is_negative)) op_scores += float(correct) op_total += 1 scores += op_scores total += op_total self._record('integration.sub_then_conditional', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, 
RuntimeError) as e: if debug: print(f" integration.sub_then_conditional: SKIP ({e})") # Test 5: Complex expression: ((A + B) * 2) & 0xF0 # Uses: adder + SHL + AND try: op_scores = torch.zeros(pop_size, device=self.device) op_total = 0 tests = [(10, 20), (50, 50), (127, 1), (0, 0)] for a, b in tests: sum_val = (a + b) & 0xFF doubled = (sum_val << 1) & 0xFF expected = doubled & 0xF0 sum_bits = torch.tensor([((sum_val >> (7 - i)) & 1) for i in range(8)], device=self.device, dtype=torch.float32) mask_bits = torch.tensor([1, 1, 1, 1, 0, 0, 0, 0], device=self.device, dtype=torch.float32) # Apply SHL shifted_bits = [] for bit in range(8): w = pop[f'alu.alu8bit.shl.bit{bit}.weight'].view(pop_size) bias = pop[f'alu.alu8bit.shl.bit{bit}.bias'].view(pop_size) if bit < 7: inp = sum_bits[bit + 1] else: inp = torch.tensor(0.0, device=self.device) out = heaviside(inp * w + bias) shifted_bits.append(out) # Apply AND with mask w_and = pop['alu.alu8bit.and.weight'].view(pop_size, 8, 2) b_and = pop['alu.alu8bit.and.bias'].view(pop_size, 8) result_bits = [] for bit in range(8): inp = torch.tensor([shifted_bits[bit][0].item(), mask_bits[bit].item()], device=self.device) out = heaviside((inp * w_and[:, bit]).sum(-1) + b_and[:, bit]) result_bits.append(out) out_val = sum(int(result_bits[i][0].item()) << (7 - i) for i in range(8)) correct = (out_val == expected) op_scores += float(correct) op_total += 1 scores += op_scores total += op_total self._record('integration.complex_expr', int(op_scores[0].item()), op_total, []) if debug: r = self.results[-1] print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") except (KeyError, RuntimeError) as e: if debug: print(f" integration.complex_expr: SKIP ({e})") return scores, total # ========================================================================= # MAIN EVALUATE # ========================================================================= def evaluate(self, population: Dict[str, torch.Tensor], debug: bool = False) -> 
torch.Tensor: """ Evaluate population fitness with per-circuit reporting. Args: population: Dict of tensors, each with shape [pop_size, ...] debug: If True, print per-circuit results Returns: Tensor of fitness scores [pop_size], normalized to [0, 1] """ self.results = [] self.category_scores = {} pop_size = next(iter(population.values())).shape[0] scores = torch.zeros(pop_size, device=self.device) total_tests = 0 # Boolean gates s, t = self._test_boolean_gates(population, debug) scores += s total_tests += t self.category_scores['boolean'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # Half adder s, t = self._test_halfadder(population, debug) scores += s total_tests += t self.category_scores['halfadder'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # Full adder s, t = self._test_fulladder(population, debug) scores += s total_tests += t self.category_scores['fulladder'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # Ripple carry adders for bits in [2, 4, 8]: s, t = self._test_ripplecarry(population, bits, debug) scores += s total_tests += t self.category_scores[f'ripplecarry{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # 16/32-bit circuits (if present) for bits in [16, 32]: if f'arithmetic.ripplecarry{bits}bit.fa0.ha1.sum.layer1.or.weight' in population: if debug: print(f"\n{'=' * 60}") print(f" {bits}-BIT CIRCUITS") print(f"{'=' * 60}") s, t = self._test_ripplecarry(population, bits, debug) scores += s total_tests += t self.category_scores[f'ripplecarry{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) s, t = self._test_comparators_nbits(population, bits, debug) scores += s total_tests += t self.category_scores[f'comparators{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) if f'arithmetic.sub{bits}bit.not_b.bit0.weight' in population: s, t = self._test_subtractor_nbits(population, bits, debug) scores += s total_tests += t self.category_scores[f'subtractor{bits}'] = (s[0].item() if 
pop_size == 1 else s.mean().item(), t) if f'alu.alu{bits}bit.and.bit0.weight' in population: s, t = self._test_bitwise_nbits(population, bits, debug) scores += s total_tests += t self.category_scores[f'bitwise{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) if f'alu.alu{bits}bit.shl.bit0.weight' in population: s, t = self._test_shifts_nbits(population, bits, debug) scores += s total_tests += t self.category_scores[f'shifts{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) if f'alu.alu{bits}bit.inc.bit0.xor.layer1.or.weight' in population: s, t = self._test_inc_dec_nbits(population, bits, debug) scores += s total_tests += t self.category_scores[f'incdec{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) if f'alu.alu{bits}bit.neg.not.bit0.weight' in population: s, t = self._test_neg_nbits(population, bits, debug) scores += s total_tests += t self.category_scores[f'neg{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) if f'combinational.barrelshifter{bits}.layer0.bit0.not_sel.weight' in population: s, t = self._test_barrel_shifter_nbits(population, bits, debug) scores += s total_tests += t self.category_scores[f'barrelshifter{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) if f'combinational.priorityencoder{bits}.valid.weight' in population: s, t = self._test_priority_encoder_nbits(population, bits, debug) scores += s total_tests += t self.category_scores[f'priorityencoder{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # 3-operand adder s, t = self._test_add3(population, debug) scores += s total_tests += t self.category_scores['add3'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # Order of operations (A + B × C) s, t = self._test_expr_add_mul(population, debug) scores += s total_tests += t self.category_scores['expr_add_mul'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # Comparators s, t = self._test_comparators(population, debug) scores += s 
total_tests += t self.category_scores['comparators'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # Threshold gates s, t = self._test_threshold_gates(population, debug) scores += s total_tests += t self.category_scores['threshold'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # Modular arithmetic s, t = self._test_modular_all(population, debug) scores += s total_tests += t self.category_scores['modular'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # Pattern recognition s, t = self._test_patterns(population, debug) scores += s total_tests += t self.category_scores['patterns'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # Error detection s, t = self._test_error_detection(population, debug) scores += s total_tests += t self.category_scores['error_detection'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # Combinational s, t = self._test_combinational(population, debug) scores += s total_tests += t self.category_scores['combinational'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # Control flow s, t = self._test_control_flow(population, debug) scores += s total_tests += t self.category_scores['control'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # ALU s, t = self._test_alu_ops(population, debug) scores += s total_tests += t self.category_scores['alu'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # Manifest s, t = self._test_manifest(population, debug) scores += s total_tests += t self.category_scores['manifest'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # Memory s, t = self._test_memory(population, debug) scores += s total_tests += t self.category_scores['memory'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # Float16 circuits (if present) if 'float16.unpack.bit0.weight' in population: if debug: print(f"\n{'=' * 60}") print(f" FLOAT16 CIRCUITS") print(f"{'=' * 60}") s, t = self._test_float16_core(population, debug) scores += s total_tests += t 
self.category_scores['float16_core'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) if 'float16.add.exp_cmp.a_gt_b.weight' in population: s, t = self._test_float16_add(population, debug) scores += s total_tests += t self.category_scores['float16_add'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) if 'float16.mul.sign_xor.layer1.or.weight' in population: s, t = self._test_float16_mul(population, debug) scores += s total_tests += t self.category_scores['float16_mul'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) if 'float16.div.sign_xor.layer1.or.weight' in population: s, t = self._test_float16_div(population, debug) scores += s total_tests += t self.category_scores['float16_div'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) if 'float16.cmp.a.exp_max.weight' in population: s, t = self._test_float16_cmp(population, debug) scores += s total_tests += t self.category_scores['float16_cmp'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) # Float32 circuits (if present) if 'float32.unpack.bit0.weight' in population: if debug: print(f"\n{'=' * 60}") print(f" FLOAT32 CIRCUITS") print(f"{'=' * 60}") s, t = self._test_float32_core(population, debug) scores += s total_tests += t self.category_scores['float32_core'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) if 'float32.add.exp_cmp.a_gt_b.weight' in population: s, t = self._test_float32_add(population, debug) scores += s total_tests += t self.category_scores['float32_add'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) if 'float32.mul.sign_xor.layer1.or.weight' in population: s, t = self._test_float32_mul(population, debug) scores += s total_tests += t self.category_scores['float32_mul'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) if 'float32.div.sign_xor.layer1.or.weight' in population: s, t = self._test_float32_div(population, debug) scores += s total_tests += t self.category_scores['float32_div'] = (s[0].item() if pop_size == 1 else 
s.mean().item(), t) if 'float32.cmp.a.exp_max.weight' in population: s, t = self._test_float32_cmp(population, debug) scores += s total_tests += t self.category_scores['float32_cmp'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) self.total_tests = total_tests if debug: print("\n" + "=" * 60) print("CATEGORY SUMMARY") print("=" * 60) for cat, (got, expected) in sorted(self.category_scores.items()): pct = 100 * got / expected if expected > 0 else 0 status = "PASS" if got == expected else "FAIL" print(f" {cat:20} {int(got):6}/{expected:6} ({pct:6.2f}%) [{status}]") print("\n" + "=" * 60) print("CIRCUIT FAILURES") print("=" * 60) failed = [r for r in self.results if not r.success] if failed: for r in failed[:20]: print(f" {r.name}: {r.passed}/{r.total}") if r.failures: print(f" First failure: {r.failures[0]}") if len(failed) > 20: print(f" ... and {len(failed) - 20} more") else: print(" None!") return scores / total_tests if total_tests > 0 else scores def main(): parser = argparse.ArgumentParser(description='Unified Evaluation Suite for 8-bit Threshold Computer') parser.add_argument('--model', type=str, default=MODEL_PATH, help='Path to safetensors model') parser.add_argument('--device', type=str, default='cuda', help='Device: cuda or cpu') parser.add_argument('--pop_size', type=int, default=1, help='Population size for batched evaluation') parser.add_argument('--quiet', action='store_true', help='Suppress detailed output') parser.add_argument('--cpu-test', action='store_true', help='Run CPU smoke test (LOAD, ADD, STORE, HALT)') args = parser.parse_args() if args.cpu_test: return run_smoke_test() print("=" * 70) print(" UNIFIED EVALUATION SUITE") print("=" * 70) print(f"\nLoading model from {args.model}...") model = load_model(args.model) print(f" Loaded {len(model)} tensors, {sum(t.numel() for t in model.values()):,} params") print(f"\nInitializing evaluator on {args.device}...") evaluator = BatchedFitnessEvaluator(device=args.device, 
model_path=args.model) print(f"\nCreating population (size {args.pop_size})...") population = create_population(model, pop_size=args.pop_size, device=args.device) print("\nRunning evaluation...") if args.device == 'cuda': torch.cuda.synchronize() start = time.perf_counter() fitness = evaluator.evaluate(population, debug=not args.quiet) if args.device == 'cuda': torch.cuda.synchronize() elapsed = time.perf_counter() - start print("\n" + "=" * 70) print("RESULTS") print("=" * 70) if args.pop_size == 1: print(f" Fitness: {fitness[0].item():.6f}") else: print(f" Mean Fitness: {fitness.mean().item():.6f}") print(f" Min Fitness: {fitness.min().item():.6f}") print(f" Max Fitness: {fitness.max().item():.6f}") print(f" Total tests: {evaluator.total_tests}") print(f" Time: {elapsed * 1000:.2f} ms") if args.pop_size > 1: print(f" Throughput: {args.pop_size / elapsed:.0f} evals/sec") perfect = (fitness >= 0.9999).sum().item() print(f" Perfect (>=99.99%): {perfect}/{args.pop_size}") if fitness[0].item() >= 0.9999: print("\n STATUS: PASS") return 0 else: failed_count = int((1 - fitness[0].item()) * evaluator.total_tests) print(f"\n STATUS: FAIL ({failed_count} tests failed)") return 1 if __name__ == '__main__': exit(main())