|
|
""" |
|
|
Unified Evaluation Suite for 8-bit Threshold Computer |
|
|
====================================================== |
|
|
GPU-batched evaluation with per-circuit reporting. |
|
|
Includes CPU runtime for threshold-weight execution. |
|
|
|
|
|
Usage: |
|
|
python eval.py # Run circuit evaluation |
|
|
python eval.py --device cpu # CPU mode |
|
|
python eval.py --pop_size 1000 # Population mode for evolution |
|
|
python eval.py --cpu-test # Run CPU smoke test |
|
|
|
|
|
API (for prune_weights.py): |
|
|
from eval import load_model, create_population, BatchedFitnessEvaluator |
|
|
from eval import ThresholdCPU, ThresholdALU, CPUState |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import os |
|
|
import time |
|
|
from collections import defaultdict |
|
|
from dataclasses import dataclass, field |
|
|
from typing import Callable, Dict, List, Optional, Tuple |
|
|
|
|
|
import torch |
|
|
from safetensors import safe_open |
|
|
|
|
|
|
|
|
# Default location of the trained threshold-weight model, resolved relative
# to this file so the suite works regardless of the current working directory.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "neural_computer.safetensors")
|
|
|
|
|
|
|
|
@dataclass
class CircuitResult:
    """Outcome of one circuit's test battery."""
    name: str
    passed: int
    total: int
    failures: List[Tuple] = field(default_factory=list)

    @property
    def success(self) -> bool:
        """True when every test case passed."""
        return self.total == self.passed

    @property
    def rate(self) -> float:
        """Fraction of passing cases; 0.0 when no tests were run."""
        if self.total <= 0:
            return 0.0
        return self.passed / self.total
|
|
|
|
|
|
|
|
def heaviside(x: torch.Tensor) -> torch.Tensor:
    """Hard threshold activation: 1.0 where x >= 0, otherwise 0.0."""
    return torch.ge(x, 0).to(torch.float32)
|
|
|
|
|
|
|
|
def load_model(path: str = MODEL_PATH) -> Dict[str, torch.Tensor]:
    """Read every tensor from a safetensors file, upcast to float32."""
    tensors: Dict[str, torch.Tensor] = {}
    with safe_open(path, framework='pt') as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key).float()
    return tensors
|
|
|
|
|
|
|
|
def load_metadata(path: str = MODEL_PATH) -> Dict:
    """Load metadata from a safetensors file (includes signal_registry)."""
    with safe_open(path, framework='pt') as f:
        raw = f.metadata()
    registry = {}
    if raw and 'signal_registry' in raw:
        registry = json.loads(raw['signal_registry'])
    return {'signal_registry': registry}
|
|
|
|
|
|
|
|
def get_manifest(tensors: Dict[str, torch.Tensor]) -> Dict[str, int]:
    """Extract manifest values from tensors.

    Returns dict with data_bits, addr_bits, memory_bytes, version.
    Defaults to 8-bit data, 16-bit addr for legacy models (addr_bits
    falls back to the older 'manifest.pc_width' key first).
    """
    def scalar(key: str, fallback: torch.Tensor) -> torch.Tensor:
        return tensors.get(key, fallback)

    addr_fallback = scalar('manifest.pc_width', torch.tensor([16.0]))
    manifest = {
        'data_bits': int(scalar('manifest.data_bits', torch.tensor([8.0])).item()),
        'addr_bits': int(scalar('manifest.addr_bits', addr_fallback).item()),
        'memory_bytes': int(scalar('manifest.memory_bytes', torch.tensor([65536.0])).item()),
        'version': float(scalar('manifest.version', torch.tensor([1.0])).item()),
    }
    return manifest
|
|
|
|
|
|
|
|
def create_population(
    base_tensors: Dict[str, torch.Tensor],
    pop_size: int,
    device: str = 'cuda'
) -> Dict[str, torch.Tensor]:
    """Replicate each base tensor pop_size times along a new leading axis."""
    population: Dict[str, torch.Tensor] = {}
    for name, tensor in base_tensors.items():
        tiled = tensor.unsqueeze(0).expand(pop_size, *tensor.shape)
        population[name] = tiled.clone().to(device)
    return population
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Flag bit order used throughout: Z (zero), N (negative), C (carry), V (overflow).
FLAG_NAMES = ["Z", "N", "C", "V"]
# Control bit order; bit 0 (HALT) is the latch checked by the step functions.
CTRL_NAMES = ["HALT", "MEM_WE", "MEM_RE", "RESERVED"]

# Architectural widths (in bits) for the 8-bit CPU with a 16-bit address space.
PC_BITS = 16
IR_BITS = 16
REG_BITS = 8
REG_COUNT = 4
FLAG_BITS = 4
SP_BITS = 16
CTRL_BITS = 4
MEM_BYTES = 65536
MEM_BITS = MEM_BYTES * 8

# Total width of a flattened machine state (the pack_state/unpack_state format).
STATE_BITS = PC_BITS + IR_BITS + (REG_BITS * REG_COUNT) + FLAG_BITS + SP_BITS + CTRL_BITS + MEM_BITS
|
|
|
|
|
|
|
|
def int_to_bits(value: int, width: int) -> List[int]:
    """MSB-first list of the low *width* bits of *value*."""
    bits = []
    for shift in range(width - 1, -1, -1):
        bits.append((value >> shift) & 1)
    return bits
|
|
|
|
|
|
|
|
def bits_to_int(bits: List[int]) -> int:
    """Interpret an MSB-first bit sequence as an unsigned integer."""
    acc = 0
    for b in bits:
        acc = acc * 2 + int(b)
    return acc
|
|
|
|
|
|
|
|
def bits_msb_to_lsb(bits: List[int]) -> List[int]:
    """Return a reversed copy of *bits* (MSB-first -> LSB-first)."""
    return bits[::-1]
|
|
|
|
|
|
|
|
@dataclass
class CPUState:
    """Full architectural state: PC/IR/SP, registers, flags, control bits,
    and the entire byte memory."""
    pc: int
    ir: int
    regs: List[int]
    flags: List[int]
    sp: int
    ctrl: List[int]
    mem: List[int]

    def copy(self) -> 'CPUState':
        """Deep copy: list fields are duplicated and values coerced to int."""
        return CPUState(
            pc=int(self.pc),
            ir=int(self.ir),
            regs=list(map(int, self.regs)),
            flags=list(map(int, self.flags)),
            sp=int(self.sp),
            ctrl=list(map(int, self.ctrl)),
            mem=list(map(int, self.mem)),
        )
|
|
|
|
|
|
|
|
def pack_state(state: CPUState) -> List[int]:
    """Flatten a CPUState into an MSB-first bit list (STATE_BITS long for a
    full 64K memory): PC, IR, registers, flags, SP, ctrl, then memory."""
    bits: List[int] = []
    bits += int_to_bits(state.pc, PC_BITS)
    bits += int_to_bits(state.ir, IR_BITS)
    for value in state.regs:
        bits += int_to_bits(value, REG_BITS)
    bits += [int(flag) for flag in state.flags]
    bits += int_to_bits(state.sp, SP_BITS)
    bits += [int(c) for c in state.ctrl]
    for value in state.mem:
        bits += int_to_bits(value, REG_BITS)
    return bits
|
|
|
|
|
|
|
|
def unpack_state(bits: List[int]) -> CPUState:
    """Inverse of pack_state.

    Raises:
        ValueError: if the bit list is not exactly STATE_BITS long.
    """
    if len(bits) != STATE_BITS:
        raise ValueError(f"Expected {STATE_BITS} bits, got {len(bits)}")

    cursor = 0

    def take(width: int) -> List[int]:
        # Consume the next *width* bits from the stream.
        nonlocal cursor
        chunk = bits[cursor:cursor + width]
        cursor += width
        return chunk

    pc = bits_to_int(take(PC_BITS))
    ir = bits_to_int(take(IR_BITS))
    regs = [bits_to_int(take(REG_BITS)) for _ in range(REG_COUNT)]
    flags = [int(b) for b in take(FLAG_BITS)]
    sp = bits_to_int(take(SP_BITS))
    ctrl = [int(b) for b in take(CTRL_BITS)]
    mem = [bits_to_int(take(REG_BITS)) for _ in range(MEM_BYTES)]

    return CPUState(pc=pc, ir=ir, regs=regs, flags=flags, sp=sp, ctrl=ctrl, mem=mem)
|
|
|
|
|
|
|
|
def decode_ir(ir: int) -> Tuple[int, int, int, int]:
    """Split a 16-bit instruction word into (opcode, rd, rs, imm8).

    Layout (MSB to LSB): [opcode:4][rd:2][rs:2][imm8:8].
    """
    return (
        (ir >> 12) & 0xF,
        (ir >> 10) & 0x3,
        (ir >> 8) & 0x3,
        ir & 0xFF,
    )
|
|
|
|
|
|
|
|
def flags_from_result(result: int, carry: int, overflow: int) -> Tuple[int, int, int, int]:
    """Derive the (Z, N, C, V) flag bits from an ALU result.

    N mirrors bit 7 of the result; carry/overflow are squashed to 0/1.
    """
    z = int(result == 0)
    n = (result >> 7) & 1
    c = int(bool(carry))
    v = int(bool(overflow))
    return z, n, c, v
|
|
|
|
|
|
|
|
def alu_add(a: int, b: int) -> Tuple[int, int, int]:
    """8-bit add: returns (result mod 256, carry_out, signed_overflow)."""
    total = a + b
    result = total & 0xFF
    carry = int(total > 0xFF)
    # Signed overflow: both operands disagree with the result's sign bit.
    overflow = int(bool((a ^ result) & (b ^ result) & 0x80))
    return result, carry, overflow
|
|
|
|
|
|
|
|
def alu_sub(a: int, b: int) -> Tuple[int, int, int]:
    """8-bit subtract: returns (result mod 256, carry, signed_overflow).

    Carry follows the no-borrow convention: set when a >= b.
    """
    raw = (a - b) & 0x1FF
    result = raw & 0xFF
    carry = int(a >= b)
    # Signed overflow: operand signs differ and the result sign flipped.
    overflow = int(bool((a ^ b) & (a ^ result) & 0x80))
    return result, carry, overflow
|
|
|
|
|
|
|
|
def ref_step(state: CPUState) -> CPUState:
    """Reference CPU cycle (pure Python arithmetic).

    Runs exactly one fetch/decode/execute cycle and returns a new
    CPUState; the argument is never mutated.  A halted machine
    (ctrl[0] == 1) is returned unchanged.  This is the golden model the
    threshold-weight implementation (ThresholdCPU.step) is compared to.
    """
    # HALT latch set: the machine is stopped, nothing advances.
    if state.ctrl[0] == 1:
        return state.copy()

    s = state.copy()

    # Fetch: instructions are 16 bits, stored big-endian; PC wraps at 64K.
    hi = s.mem[s.pc]
    lo = s.mem[(s.pc + 1) & 0xFFFF]
    s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
    next_pc = (s.pc + 2) & 0xFFFF

    opcode, rd, rs, imm8 = decode_ir(s.ir)
    a = s.regs[rd]
    b = s.regs[rs]

    # Memory/control opcodes (LOAD/STORE/JMP/Jcc/CALL) are followed by a
    # 16-bit big-endian operand address; fetching it advances PC again.
    addr16 = None
    next_pc_ext = next_pc
    if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
        addr_hi = s.mem[next_pc]
        addr_lo = s.mem[(next_pc + 1) & 0xFFFF]
        addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
        next_pc_ext = (next_pc + 2) & 0xFFFF

    write_result = True
    result = a
    carry = 0
    overflow = 0

    if opcode == 0x0:  # ADD rd, rs
        result, carry, overflow = alu_add(a, b)
    elif opcode == 0x1:  # SUB rd, rs
        result, carry, overflow = alu_sub(a, b)
    elif opcode == 0x2:  # AND
        result = a & b
    elif opcode == 0x3:  # OR
        result = a | b
    elif opcode == 0x4:  # XOR
        result = a ^ b
    elif opcode == 0x5:  # SHL (logical shift left by one)
        result = (a << 1) & 0xFF
    elif opcode == 0x6:  # SHR (logical shift right by one)
        result = (a >> 1) & 0xFF
    elif opcode == 0x7:  # MUL (product mod 256)
        result = (a * b) & 0xFF
    elif opcode == 0x8:  # DIV (quotient; divide-by-zero yields 0xFF)
        if b == 0:
            result = 0xFF
        else:
            result = a // b
    elif opcode == 0x9:  # CMP: subtract for flags only, no register write
        result, carry, overflow = alu_sub(a, b)
        write_result = False
    elif opcode == 0xA:  # LOAD rd, [addr16]
        result = s.mem[addr16]
    elif opcode == 0xB:  # STORE [addr16], rs
        s.mem[addr16] = b & 0xFF
        write_result = False
    elif opcode == 0xC:  # JMP addr16
        s.pc = addr16 & 0xFFFF
        write_result = False
    elif opcode == 0xD:  # Conditional jump; low 3 bits of imm8 pick the condition.
        cond_type = imm8 & 0x7
        # Flag layout: flags[0]=Z, flags[1]=N, flags[2]=C, flags[3]=V (FLAG_NAMES).
        if cond_type == 0:  # JZ: Z set
            take_branch = s.flags[0] == 1
        elif cond_type == 1:  # JNZ: Z clear
            take_branch = s.flags[0] == 0
        elif cond_type == 2:  # JC: C set
            take_branch = s.flags[2] == 1
        elif cond_type == 3:  # JNC: C clear
            take_branch = s.flags[2] == 0
        elif cond_type == 4:  # JN: N set
            take_branch = s.flags[1] == 1
        elif cond_type == 5:  # JP: N clear
            take_branch = s.flags[1] == 0
        elif cond_type == 6:  # JV: V set
            take_branch = s.flags[3] == 1
        else:  # JNV: V clear
            take_branch = s.flags[3] == 0
        if take_branch:
            s.pc = addr16 & 0xFFFF
        else:
            s.pc = next_pc_ext
        write_result = False
    elif opcode == 0xE:  # CALL: push big-endian return address, then jump.
        ret_addr = next_pc_ext & 0xFFFF
        s.sp = (s.sp - 1) & 0xFFFF
        s.mem[s.sp] = (ret_addr >> 8) & 0xFF
        s.sp = (s.sp - 1) & 0xFFFF
        s.mem[s.sp] = ret_addr & 0xFF
        s.pc = addr16 & 0xFFFF
        write_result = False
    elif opcode == 0xF:  # HALT: set the latch checked at the top of step.
        s.ctrl[0] = 1
        write_result = False

    # Flags update for ALU ops and LOAD.  (0x7 and 0x8 are already covered
    # by opcode <= 0x9, so this condition is effectively "opcode <= 0xA".)
    if opcode <= 0x9 or opcode in (0xA, 0x7, 0x8):
        s.flags = list(flags_from_result(result, carry, overflow))

    if write_result:
        s.regs[rd] = result & 0xFF

    # Control-flow opcodes (JMP/Jcc/CALL) have already set PC themselves.
    if opcode not in (0xC, 0xD, 0xE):
        s.pc = next_pc_ext

    return s
|
|
|
|
|
|
|
|
def ref_run_until_halt(state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
    """Step the reference CPU until HALT; returns (final_state, cycles_used).

    The cycle count equals max_cycles when the machine never halts.
    """
    current = state.copy()
    cycles = 0
    while cycles < max_cycles:
        if current.ctrl[0] == 1:
            return current, cycles
        current = ref_step(current)
        cycles += 1
    return current, max_cycles
|
|
|
|
|
|
|
|
class ThresholdALU:
    """8-bit ALU executed entirely with threshold (Heaviside) neurons.

    Every logic gate is a stored weight/bias pair from the safetensors
    model; Python only wires gate outputs to gate inputs.  Bit lists
    exchanged with int_to_bits/bits_to_int are MSB-first; add/sub convert
    to LSB-first internally so fa0 is the least significant stage.
    """

    def __init__(self, model_path: str = MODEL_PATH, device: str = "cpu") -> None:
        self.device = device
        # Keep every model tensor as float32 on the requested device.
        self.tensors = {k: v.float().to(device) for k, v in load_model(model_path).items()}

    def _get(self, name: str) -> torch.Tensor:
        """Fetch a named weight/bias tensor (KeyError if absent)."""
        return self.tensors[name]

    def _eval_gate(self, weight_key: str, bias_key: str, inputs: List[float]) -> float:
        """Evaluate one threshold gate: step(w . x + b) as a Python float."""
        w = self._get(weight_key)
        b = self._get(bias_key)
        inp = torch.tensor(inputs, device=self.device)
        return heaviside((inp * w).sum() + b).item()

    def _eval_xor(self, prefix: str, inputs: List[float]) -> float:
        """Two-layer XOR: layer 1 is an OR gate and a NAND gate, layer 2
        combines the two hidden activations."""
        inp = torch.tensor(inputs, device=self.device)
        w_or = self._get(f"{prefix}.layer1.or.weight")
        b_or = self._get(f"{prefix}.layer1.or.bias")
        w_nand = self._get(f"{prefix}.layer1.nand.weight")
        b_nand = self._get(f"{prefix}.layer1.nand.bias")
        w2 = self._get(f"{prefix}.layer2.weight")
        b2 = self._get(f"{prefix}.layer2.bias")

        h_or = heaviside((inp * w_or).sum() + b_or).item()
        h_nand = heaviside((inp * w_nand).sum() + b_nand).item()
        hidden = torch.tensor([h_or, h_nand], device=self.device)
        return heaviside((hidden * w2).sum() + b2).item()

    def _eval_full_adder(self, prefix: str, a: float, b: float, cin: float) -> Tuple[float, float]:
        """Full adder built from two half adders; returns (sum, carry_out)."""
        ha1_sum = self._eval_xor(f"{prefix}.ha1.sum", [a, b])
        ha1_carry = self._eval_gate(f"{prefix}.ha1.carry.weight", f"{prefix}.ha1.carry.bias", [a, b])

        ha2_sum = self._eval_xor(f"{prefix}.ha2.sum", [ha1_sum, cin])
        ha2_carry = self._eval_gate(
            f"{prefix}.ha2.carry.weight", f"{prefix}.ha2.carry.bias", [ha1_sum, cin]
        )

        cout = self._eval_gate(f"{prefix}.carry_or.weight", f"{prefix}.carry_or.bias", [ha1_carry, ha2_carry])
        return ha2_sum, cout

    def add(self, a: int, b: int) -> Tuple[int, int, int]:
        """Ripple-carry 8-bit add; returns (result, carry_out, overflow)."""
        a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))

        carry = 0.0
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            sum_bit, carry = self._eval_full_adder(
                f"arithmetic.ripplecarry8bit.fa{bit}", float(a_bits[bit]), float(b_bits[bit]), carry
            )
            sum_bits.append(int(sum_bit))

        result = bits_to_int(list(reversed(sum_bits)))
        carry_out = int(carry)
        # Signed overflow is computed arithmetically rather than with gates.
        overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def sub(self, a: int, b: int) -> Tuple[int, int, int]:
        """8-bit subtract as a + NOT(b) + 1; returns (result, carry, overflow)."""
        a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))

        carry = 1.0  # the +1 of the two's-complement negation
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            notb = self._eval_gate(
                f"arithmetic.sub8bit.notb{bit}.weight",
                f"arithmetic.sub8bit.notb{bit}.bias",
                [float(b_bits[bit])],
            )

            # Explicit full-adder stage: sum = a ^ notb ^ cin.
            xor1 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor1", [float(a_bits[bit]), notb])
            xor2 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor2", [xor1, carry])

            # Carry out = (a AND notb) OR (xor1 AND cin).
            and1 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and1.weight",
                f"arithmetic.sub8bit.fa{bit}.and1.bias",
                [float(a_bits[bit]), notb],
            )
            and2 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and2.weight",
                f"arithmetic.sub8bit.fa{bit}.and2.bias",
                [xor1, carry],
            )
            carry = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.or_carry.weight",
                f"arithmetic.sub8bit.fa{bit}.or_carry.bias",
                [and1, and2],
            )

            sum_bits.append(int(xor2))

        result = bits_to_int(list(reversed(sum_bits)))
        carry_out = int(carry)
        overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def bitwise_and(self, a: int, b: int) -> int:
        """Bitwise AND via one threshold gate per bit position.

        NOTE(review): assumes the weight tensor packs the 8 gates as
        consecutive input pairs (w[2*bit:2*bit+2]) - confirm against the
        model builder.
        """
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        w = self._get("alu.alu8bit.and.weight")
        bias = self._get("alu.alu8bit.and.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def bitwise_or(self, a: int, b: int) -> int:
        """Bitwise OR via one threshold gate per bit position."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        w = self._get("alu.alu8bit.or.weight")
        bias = self._get("alu.alu8bit.or.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def bitwise_not(self, a: int) -> int:
        """Bitwise NOT via one single-input threshold gate per bit."""
        a_bits = int_to_bits(a, REG_BITS)
        w = self._get("alu.alu8bit.not.weight")
        bias = self._get("alu.alu8bit.not.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit]).sum() + bias[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def bitwise_xor(self, a: int, b: int) -> int:
        """Bitwise XOR: per-bit two-layer OR/NAND network (see _eval_xor)."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)

        w_or = self._get("alu.alu8bit.xor.layer1.or.weight")
        b_or = self._get("alu.alu8bit.xor.layer1.or.bias")
        w_nand = self._get("alu.alu8bit.xor.layer1.nand.weight")
        b_nand = self._get("alu.alu8bit.xor.layer1.nand.bias")
        w2 = self._get("alu.alu8bit.xor.layer2.weight")
        b2 = self._get("alu.alu8bit.xor.layer2.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            h_or = heaviside((inp * w_or[bit * 2:bit * 2 + 2]).sum() + b_or[bit])
            h_nand = heaviside((inp * w_nand[bit * 2:bit * 2 + 2]).sum() + b_nand[bit])
            hidden = torch.stack([h_or, h_nand])
            out = heaviside((hidden * w2[bit * 2:bit * 2 + 2]).sum() + b2[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def shift_left(self, a: int) -> int:
        """Logical shift left by one: each output bit is gated from the
        next-lower-significance input bit; LSB is fed a constant 0."""
        a_bits = int_to_bits(a, REG_BITS)
        out_bits = []
        for bit in range(REG_BITS):
            w = self._get(f"alu.alu8bit.shl.bit{bit}.weight")
            bias = self._get(f"alu.alu8bit.shl.bit{bit}.bias")
            if bit < 7:
                inp = torch.tensor([float(a_bits[bit + 1])], device=self.device)
            else:
                inp = torch.tensor([0.0], device=self.device)
            out = heaviside((inp * w).sum() + bias).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def shift_right(self, a: int) -> int:
        """Logical shift right by one; MSB is fed a constant 0."""
        a_bits = int_to_bits(a, REG_BITS)
        out_bits = []
        for bit in range(REG_BITS):
            w = self._get(f"alu.alu8bit.shr.bit{bit}.weight")
            bias = self._get(f"alu.alu8bit.shr.bit{bit}.bias")
            if bit > 0:
                inp = torch.tensor([float(a_bits[bit - 1])], device=self.device)
            else:
                inp = torch.tensor([0.0], device=self.device)
            out = heaviside((inp * w).sum() + bias).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def multiply(self, a: int, b: int) -> int:
        """8-bit multiply (product mod 256) via partial-product AND gates
        plus shift-add through the ripple-carry adder."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)

        # Partial products: pp[i][j] = a_bit_i AND b_bit_j (threshold gates).
        pp = [[0] * 8 for _ in range(8)]
        for i in range(8):
            for j in range(8):
                w = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.weight")
                bias = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.bias")
                inp = torch.tensor([float(a_bits[i]), float(b_bits[j])], device=self.device)
                pp[i][j] = int(heaviside((inp * w).sum() + bias).item())

        # Shift-add accumulation.  Working mod 256, any bits shifted above
        # bit 7 contribute nothing, so the wrapping 8-bit add IS the whole
        # reduction.
        # BUG FIX: the previous version folded the high byte of each shifted
        # row back into the low byte ((result + (shifted >> 8)) & 0xFF),
        # which is incorrect modular arithmetic - e.g. multiply(255, 2)
        # returned 255 instead of the reference (255 * 2) & 0xFF == 254.
        result = 0
        for j in range(8):
            if b_bits[j] == 0:
                continue
            row = 0
            for i in range(8):
                row |= (pp[i][j] << (7 - i))
            shifted = row << (7 - j)
            result, _, _ = self.add(result & 0xFF, shifted & 0xFF)

        return result & 0xFF

    def divide(self, a: int, b: int) -> Tuple[int, int]:
        """8-bit restoring division; returns (quotient, remainder).

        Divide-by-zero follows the CPU convention: quotient 0xFF, and the
        dividend is returned as the remainder.
        """
        if b == 0:
            return 0xFF, a

        a_bits = int_to_bits(a, REG_BITS)

        quotient = 0
        remainder = 0

        for stage in range(8):
            # Shift the next dividend bit into the partial remainder.
            remainder = ((remainder << 1) | a_bits[stage]) & 0xFF

            rem_bits = int_to_bits(remainder, REG_BITS)
            div_bits = int_to_bits(b, REG_BITS)

            # Threshold comparator decides whether remainder >= divisor.
            w = self._get(f"alu.alu8bit.div.stage{stage}.cmp.weight")
            bias = self._get(f"alu.alu8bit.div.stage{stage}.cmp.bias")
            inp = torch.tensor([float(rem_bits[i]) for i in range(8)] +
                               [float(div_bits[i]) for i in range(8)], device=self.device)
            cmp_result = int(heaviside((inp * w).sum() + bias).item())

            if cmp_result:
                remainder, _, _ = self.sub(remainder, b)
                quotient = (quotient << 1) | 1
            else:
                quotient = quotient << 1

        return quotient & 0xFF, remainder & 0xFF
|
|
|
|
|
|
|
|
class ThresholdCPU:
    """CPU whose datapath runs on threshold-weight circuits.

    Mirrors ref_step()'s fetch/decode/execute semantics, but routes ALU
    operations, memory access and conditional jumps through the threshold
    neurons stored in the safetensors model; Python only sequences the
    circuits and does PC/SP bookkeeping.
    """

    def __init__(self, model_path: str = MODEL_PATH, device: str = "cpu") -> None:
        # All gate evaluation is delegated to the ALU's tensor store.
        self.device = device
        self.alu = ThresholdALU(model_path, device=device)

    def _addr_decode(self, addr: int) -> torch.Tensor:
        """Address decode: returns a per-byte selector vector over memory.

        NOTE(review): assumes memory.addr_decode.weight is laid out
        (MEM_BYTES, PC_BITS) so the dim=1 sum yields one pre-activation
        per byte - confirm against the model builder.
        """
        bits = torch.tensor(int_to_bits(addr, PC_BITS), device=self.device, dtype=torch.float32)
        w = self.alu._get("memory.addr_decode.weight")
        b = self.alu._get("memory.addr_decode.bias")
        return heaviside((w * bits).sum(dim=1) + b)

    def _memory_read(self, mem: List[int], addr: int) -> int:
        """Read one byte: AND each byte's bits with the selector, OR-reduce."""
        sel = self._addr_decode(addr)
        # Expand all of memory to a (MEM_BYTES, 8) float bit matrix.
        mem_bits = torch.tensor(
            [int_to_bits(byte, REG_BITS) for byte in mem],
            device=self.device,
            dtype=torch.float32,
        )
        and_w = self.alu._get("memory.read.and.weight")
        and_b = self.alu._get("memory.read.and.bias")
        or_w = self.alu._get("memory.read.or.weight")
        or_b = self.alu._get("memory.read.or.bias")

        out_bits: List[int] = []
        for bit in range(REG_BITS):
            # Gate this bit of every byte with its selector line, then OR
            # across all of memory to pick out the selected byte's bit.
            inp = torch.stack([mem_bits[:, bit], sel], dim=1)
            and_out = heaviside((inp * and_w[bit]).sum(dim=1) + and_b[bit])
            out_bit = heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item()
            out_bits.append(int(out_bit))

        return bits_to_int(out_bits)

    def _memory_write(self, mem: List[int], addr: int, value: int) -> List[int]:
        """Write one byte: new = (old AND NOT sel) OR (data AND sel).

        Returns a fresh byte list; the input list is not modified.
        """
        sel = self._addr_decode(addr)
        data_bits = torch.tensor(int_to_bits(value, REG_BITS), device=self.device, dtype=torch.float32)
        mem_bits = torch.tensor(
            [int_to_bits(byte, REG_BITS) for byte in mem],
            device=self.device,
            dtype=torch.float32,
        )

        sel_w = self.alu._get("memory.write.sel.weight")
        sel_b = self.alu._get("memory.write.sel.bias")
        nsel_w = self.alu._get("memory.write.nsel.weight").squeeze(1)
        nsel_b = self.alu._get("memory.write.nsel.bias")
        and_old_w = self.alu._get("memory.write.and_old.weight")
        and_old_b = self.alu._get("memory.write.and_old.bias")
        and_new_w = self.alu._get("memory.write.and_new.weight")
        and_new_b = self.alu._get("memory.write.and_new.bias")
        or_w = self.alu._get("memory.write.or.weight")
        or_b = self.alu._get("memory.write.or.bias")

        # Write-enable held high: sel AND we picks the target byte;
        # nsel is its complement, used by the keep-old-value path.
        we = torch.ones_like(sel)
        sel_inp = torch.stack([sel, we], dim=1)
        write_sel = heaviside((sel_inp * sel_w).sum(dim=1) + sel_b)
        nsel = heaviside((write_sel * nsel_w) + nsel_b)

        new_mem_bits = torch.zeros((MEM_BYTES, REG_BITS), device=self.device)
        for bit in range(REG_BITS):
            old_bit = mem_bits[:, bit]
            data_bit = data_bits[bit].expand(MEM_BYTES)
            inp_old = torch.stack([old_bit, nsel], dim=1)
            inp_new = torch.stack([data_bit, write_sel], dim=1)

            # Mux per byte: keep old bit where nsel, take data bit where sel.
            and_old = heaviside((inp_old * and_old_w[:, bit]).sum(dim=1) + and_old_b[:, bit])
            and_new = heaviside((inp_new * and_new_w[:, bit]).sum(dim=1) + and_new_b[:, bit])
            or_inp = torch.stack([and_old, and_new], dim=1)
            out_bit = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit])
            new_mem_bits[:, bit] = out_bit

        return [bits_to_int([int(b) for b in new_mem_bits[i].tolist()]) for i in range(MEM_BYTES)]

    def _conditional_jump_byte(self, prefix: str, pc_byte: int, target_byte: int, flag: int) -> int:
        """Per-bit 2:1 mux between pc_byte and target_byte driven by a flag.

        NOTE(review): the same raw flag bit is passed for both polarities
        (jz/jnz etc.), so the negation appears to be baked into each
        circuit's weights - confirm against the model builder.
        """
        pc_bits = int_to_bits(pc_byte, REG_BITS)
        target_bits = int_to_bits(target_byte, REG_BITS)

        out_bits: List[int] = []
        for bit in range(REG_BITS):
            not_sel = self.alu._eval_gate(
                f"{prefix}.bit{bit}.not_sel.weight",
                f"{prefix}.bit{bit}.not_sel.bias",
                [float(flag)],
            )
            and_a = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_a.weight",
                f"{prefix}.bit{bit}.and_a.bias",
                [float(pc_bits[bit]), not_sel],
            )
            and_b = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_b.weight",
                f"{prefix}.bit{bit}.and_b.bias",
                [float(target_bits[bit]), float(flag)],
            )
            out_bit = self.alu._eval_gate(
                f"{prefix}.bit{bit}.or.weight",
                f"{prefix}.bit{bit}.or.bias",
                [and_a, and_b],
            )
            out_bits.append(int(out_bit))

        return bits_to_int(out_bits)

    def step(self, state: CPUState) -> CPUState:
        """Single CPU cycle using threshold neurons.

        Same contract as ref_step(): returns a new state, never mutates
        the argument, and leaves a halted machine untouched.
        """
        if state.ctrl[0] == 1:
            return state.copy()

        s = state.copy()

        # Fetch the 16-bit big-endian instruction through the memory circuits.
        hi = self._memory_read(s.mem, s.pc)
        lo = self._memory_read(s.mem, (s.pc + 1) & 0xFFFF)
        s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
        next_pc = (s.pc + 2) & 0xFFFF

        opcode, rd, rs, imm8 = decode_ir(s.ir)
        a = s.regs[rd]
        b = s.regs[rs]

        # Memory/control opcodes carry a trailing 16-bit operand address.
        addr16 = None
        next_pc_ext = next_pc
        if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
            addr_hi = self._memory_read(s.mem, next_pc)
            addr_lo = self._memory_read(s.mem, (next_pc + 1) & 0xFFFF)
            addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
            next_pc_ext = (next_pc + 2) & 0xFFFF

        write_result = True
        result = a
        carry = 0
        overflow = 0

        if opcode == 0x0:  # ADD
            result, carry, overflow = self.alu.add(a, b)
        elif opcode == 0x1:  # SUB
            result, carry, overflow = self.alu.sub(a, b)
        elif opcode == 0x2:  # AND
            result = self.alu.bitwise_and(a, b)
        elif opcode == 0x3:  # OR
            result = self.alu.bitwise_or(a, b)
        elif opcode == 0x4:  # XOR
            result = self.alu.bitwise_xor(a, b)
        elif opcode == 0x5:  # SHL
            result = self.alu.shift_left(a)
        elif opcode == 0x6:  # SHR
            result = self.alu.shift_right(a)
        elif opcode == 0x7:  # MUL (mod 256)
            result = self.alu.multiply(a, b)
        elif opcode == 0x8:  # DIV (quotient only; remainder discarded)
            result, _ = self.alu.divide(a, b)
        elif opcode == 0x9:  # CMP: flags only, no register write
            result, carry, overflow = self.alu.sub(a, b)
            write_result = False
        elif opcode == 0xA:  # LOAD rd, [addr16]
            result = self._memory_read(s.mem, addr16)
        elif opcode == 0xB:  # STORE [addr16], rs
            s.mem = self._memory_write(s.mem, addr16, b & 0xFF)
            write_result = False
        elif opcode == 0xC:  # JMP
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xD:  # Conditional jump through the mux circuits.
            cond_type = imm8 & 0x7
            # (circuit prefix, flag index) per condition; flag order Z,N,C,V.
            cond_circuits = [
                ("control.jz", 0),
                ("control.jnz", 0),
                ("control.jc", 2),
                ("control.jnc", 2),
                ("control.jn", 1),
                ("control.jp", 1),
                ("control.jv", 3),
                ("control.jnv", 3),
            ]
            circuit_prefix, flag_idx = cond_circuits[cond_type]
            # The 16-bit PC mux runs as two independent 8-bit halves.
            hi_pc = self._conditional_jump_byte(
                circuit_prefix,
                (next_pc_ext >> 8) & 0xFF,
                (addr16 >> 8) & 0xFF,
                s.flags[flag_idx],
            )
            lo_pc = self._conditional_jump_byte(
                circuit_prefix,
                next_pc_ext & 0xFF,
                addr16 & 0xFF,
                s.flags[flag_idx],
            )
            s.pc = ((hi_pc & 0xFF) << 8) | (lo_pc & 0xFF)
            write_result = False
        elif opcode == 0xE:  # CALL: push big-endian return address, then jump.
            ret_addr = next_pc_ext & 0xFFFF
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, (ret_addr >> 8) & 0xFF)
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, ret_addr & 0xFF)
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xF:  # HALT
            s.ctrl[0] = 1
            write_result = False

        # Flags for ALU ops and LOAD (same update set as ref_step).
        if opcode <= 0x9 or opcode == 0xA:
            s.flags = list(flags_from_result(result, carry, overflow))

        if write_result:
            s.regs[rd] = result & 0xFF

        # Control-flow opcodes have already set PC themselves.
        if opcode not in (0xC, 0xD, 0xE):
            s.pc = next_pc_ext

        return s

    def run_until_halt(self, state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
        """Execute until HALT or max_cycles reached.

        Returns (final_state, cycles_executed); the count equals
        max_cycles when the machine never halts.
        """
        s = state.copy()
        for i in range(max_cycles):
            if s.ctrl[0] == 1:
                return s, i
            s = self.step(s)
        return s, max_cycles

    def forward(self, state_bits: torch.Tensor, max_cycles: int = 256) -> torch.Tensor:
        """Tensor-in, tensor-out interface for neural integration.

        Accepts a flat STATE_BITS-long tensor (the pack_state layout),
        runs to halt, and returns the packed final state as float32.
        """
        bits_list = [int(b) for b in state_bits.detach().cpu().flatten().tolist()]
        state = unpack_state(bits_list)
        final, _ = self.run_until_halt(state, max_cycles=max_cycles)
        return torch.tensor(pack_state(final), dtype=torch.float32)
|
|
|
|
|
|
|
|
def encode_instr(opcode: int, rd: int, rs: int, imm8: int) -> int:
    """Pack fields into a 16-bit word: [opcode:4][rd:2][rs:2][imm8:8]."""
    word = (opcode & 0xF) << 12
    word |= (rd & 0x3) << 10
    word |= (rs & 0x3) << 8
    word |= imm8 & 0xFF
    return word
|
|
|
|
|
|
|
|
def write_instr(mem: List[int], addr: int, instr: int) -> None:
    """Store a 16-bit instruction big-endian at mem[addr], wrapping at 64K."""
    base = addr & 0xFFFF
    mem[base] = (instr >> 8) & 0xFF
    mem[(base + 1) & 0xFFFF] = instr & 0xFF
|
|
|
|
|
|
|
|
def write_addr(mem: List[int], addr: int, value: int) -> None:
    """Store a 16-bit value big-endian at mem[addr], wrapping at 64K."""
    hi_byte = (value >> 8) & 0xFF
    lo_byte = value & 0xFF
    mem[addr & 0xFFFF] = hi_byte
    mem[(addr + 1) & 0xFFFF] = lo_byte
|
|
|
|
|
|
|
|
def run_smoke_test() -> int:
    """Smoke test: LOAD 5, LOAD 7, ADD, STORE, HALT. Expect result = 12.

    Runs the same program on the reference interpreter, the
    threshold-weight CPU, and the forward() tensor interface, asserting
    that all three agree.  Returns 0 on success; raises AssertionError
    on any mismatch.
    """
    mem = [0] * 65536

    # Program: each LOAD/STORE instruction is followed by its 16-bit address.
    write_instr(mem, 0x0000, encode_instr(0xA, 0, 0, 0x00))  # LOAD R0, [0x0100]
    write_addr(mem, 0x0002, 0x0100)
    write_instr(mem, 0x0004, encode_instr(0xA, 1, 0, 0x00))  # LOAD R1, [0x0101]
    write_addr(mem, 0x0006, 0x0101)
    write_instr(mem, 0x0008, encode_instr(0x0, 0, 1, 0x00))  # ADD R0, R1
    write_instr(mem, 0x000A, encode_instr(0xB, 0, 0, 0x00))  # STORE [0x0102], R0
    write_addr(mem, 0x000C, 0x0102)
    write_instr(mem, 0x000E, encode_instr(0xF, 0, 0, 0x00))  # HALT

    # Operand bytes the two LOADs read.
    mem[0x0100] = 5
    mem[0x0101] = 7

    state = CPUState(
        pc=0,
        ir=0,
        regs=[0, 0, 0, 0],
        flags=[0, 0, 0, 0],
        sp=0xFFFE,  # stack pointer starts near the top of memory
        ctrl=[0, 0, 0, 0],
        mem=mem,
    )

    print("Running reference implementation...")
    final, cycles = ref_run_until_halt(state, max_cycles=20)

    assert final.ctrl[0] == 1, "HALT flag not set"
    assert final.regs[0] == 12, f"R0 expected 12, got {final.regs[0]}"
    assert final.mem[0x0102] == 12, f"MEM[0x0102] expected 12, got {final.mem[0x0102]}"
    assert cycles <= 10, f"Unexpected cycle count: {cycles}"
    print(f" Reference: R0={final.regs[0]}, MEM[0x0102]={final.mem[0x0102]}, cycles={cycles}")

    # Same program, same initial state, through the threshold circuits.
    print("Running threshold-weight implementation...")
    threshold_cpu = ThresholdCPU()
    t_final, t_cycles = threshold_cpu.run_until_halt(state, max_cycles=20)

    assert t_final.ctrl[0] == 1, "Threshold HALT flag not set"
    assert t_final.regs[0] == final.regs[0], f"Threshold R0 mismatch: {t_final.regs[0]} != {final.regs[0]}"
    assert t_final.mem[0x0102] == final.mem[0x0102], (
        f"Threshold MEM[0x0102] mismatch: {t_final.mem[0x0102]} != {final.mem[0x0102]}"
    )
    assert t_cycles == cycles, f"Threshold cycle count mismatch: {t_cycles} != {cycles}"
    print(f" Threshold: R0={t_final.regs[0]}, MEM[0x0102]={t_final.mem[0x0102]}, cycles={t_cycles}")

    # Round-trip the packed bit-vector interface as well.
    print("Validating forward() tensor I/O...")
    bits = torch.tensor(pack_state(state), dtype=torch.float32)
    out_bits = threshold_cpu.forward(bits, max_cycles=20)
    out_state = unpack_state([int(b) for b in out_bits.tolist()])
    assert out_state.regs[0] == final.regs[0], f"Forward R0 mismatch: {out_state.regs[0]} != {final.regs[0]}"
    assert out_state.mem[0x0102] == final.mem[0x0102], (
        f"Forward MEM[0x0102] mismatch: {out_state.mem[0x0102]} != {final.mem[0x0102]}"
    )
    print(f" Forward: R0={out_state.regs[0]}, MEM[0x0102]={out_state.mem[0x0102]}")

    print("\nSmoke test: PASSED")
    return 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BatchedFitnessEvaluator: |
|
|
""" |
|
|
GPU-batched fitness evaluator with per-circuit reporting. |
|
|
Tests all circuits comprehensively. |
|
|
""" |
|
|
|
|
|
    def __init__(self, device: str = 'cuda', model_path: str = MODEL_PATH, tensors: Optional[Dict[str, torch.Tensor]] = None):
        """Set up the evaluator: load metadata, the manifest, and test vectors.

        Args:
            device: torch device for the pre-computed test tensors.
            model_path: safetensors file for metadata and (if needed) tensors.
            tensors: already-loaded model tensors; when given, the manifest
                is read from them instead of re-loading the model file.
        """
        self.device = device
        self.model_path = model_path
        self.metadata = load_metadata(model_path)
        # Circuit signal registry from the safetensors metadata JSON;
        # empty dict for legacy models without one.
        self.signal_registry = self.metadata.get('signal_registry', {})
        self.results: List[CircuitResult] = []          # per-circuit outcomes
        self.category_scores: Dict[str, Tuple[float, int]] = {}
        self.total_tests = 0

        # Manifest (bit widths / memory size) comes from the supplied
        # tensors when given, otherwise from a fresh load of the model.
        if tensors is not None:
            self.manifest = get_manifest(tensors)
        else:
            base_tensors = load_model(model_path)
            self.manifest = get_manifest(base_tensors)
        self.data_bits = self.manifest['data_bits']
        self.addr_bits = self.manifest['addr_bits']

        self._setup_tests()
|
|
|
|
|
    def _setup_tests(self):
        """Pre-compute test vectors on device.

        Populates, on ``self.device``:
          - tt2 / tt3: exhaustive 2- and 3-input truth tables (float 0/1).
          - expected: per-gate expected-output vectors matching tt2/tt3 row
            order (MSB-first input ordering).
          - not_inputs: 1-input truth table for the NOT gate.
          - test_8bit / test_8bit_bits: representative 8-bit values and their
            MSB-first bit decomposition.
          - comp_a / comp_b: paired 8-bit operands for comparator tests.
          - mod_test: all 256 byte values.
          - test_32bit, comp32_a / comp32_b: 32-bit values and comparator pairs.
        """
        d = self.device

        # All 4 combinations of 2 binary inputs.
        self.tt2 = torch.tensor(
            [[0, 0], [0, 1], [1, 0], [1, 1]],
            device=d, dtype=torch.float32
        )

        # All 8 combinations of 3 binary inputs (a, b, carry-in).
        self.tt3 = torch.tensor([
            [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
            [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]
        ], device=d, dtype=torch.float32)

        # Expected outputs, row-aligned with tt2 (2-input gates), the
        # 1-input table (not), or tt3 (full-adder sum/carry-out).
        self.expected = {
            'and': torch.tensor([0, 0, 0, 1], device=d, dtype=torch.float32),
            'or': torch.tensor([0, 1, 1, 1], device=d, dtype=torch.float32),
            'nand': torch.tensor([1, 1, 1, 0], device=d, dtype=torch.float32),
            'nor': torch.tensor([1, 0, 0, 0], device=d, dtype=torch.float32),
            'xor': torch.tensor([0, 1, 1, 0], device=d, dtype=torch.float32),
            'xnor': torch.tensor([1, 0, 0, 1], device=d, dtype=torch.float32),
            'implies': torch.tensor([1, 1, 0, 1], device=d, dtype=torch.float32),
            'biimplies': torch.tensor([1, 0, 0, 1], device=d, dtype=torch.float32),
            'not': torch.tensor([1, 0], device=d, dtype=torch.float32),
            'ha_sum': torch.tensor([0, 1, 1, 0], device=d, dtype=torch.float32),
            'ha_carry': torch.tensor([0, 0, 0, 1], device=d, dtype=torch.float32),
            'fa_sum': torch.tensor([0, 1, 1, 0, 1, 0, 0, 1], device=d, dtype=torch.float32),
            'fa_cout': torch.tensor([0, 0, 0, 1, 0, 1, 1, 1], device=d, dtype=torch.float32),
        }

        # 1-input truth table for NOT.
        self.not_inputs = torch.tensor([[0], [1]], device=d, dtype=torch.float32)

        # Representative 8-bit values: powers of two ± 1, extremes, and
        # alternating/checker bit patterns.
        self.test_8bit = torch.tensor([
            0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255,
            0b10101010, 0b01010101, 0b11110000, 0b00001111,
            0b11001100, 0b00110011, 0b10000001, 0b01111110
        ], device=d, dtype=torch.long)

        # MSB-first bit decomposition: column i holds bit (7 - i).
        self.test_8bit_bits = torch.stack([
            ((self.test_8bit >> (7 - i)) & 1).float() for i in range(8)
        ], dim=1)

        # Comparator operand pairs: equal, adjacent, extreme, and
        # complementary-pattern cases.
        # NOTE(review): (0, 0) appears both first and last — duplicate pair.
        comp_tests = [
            (0, 0), (1, 0), (0, 1), (5, 3), (3, 5), (5, 5),
            (255, 0), (0, 255), (128, 127), (127, 128),
            (100, 99), (99, 100), (64, 32), (32, 64),
            (1, 1), (254, 255), (255, 254), (128, 128),
            (0, 128), (128, 0), (64, 64), (192, 192),
            (15, 16), (16, 15), (240, 239), (239, 240),
            (85, 170), (170, 85), (0xAA, 0x55), (0x55, 0xAA),
            (0x0F, 0xF0), (0xF0, 0x0F), (0x33, 0xCC), (0xCC, 0x33),
            (2, 3), (3, 2), (126, 127), (127, 126),
            (129, 128), (128, 129), (200, 199), (199, 200),
            (50, 51), (51, 50), (10, 20), (20, 10),
            (100, 100), (200, 200), (77, 77), (0, 0)
        ]
        self.comp_a = torch.tensor([c[0] for c in comp_tests], device=d, dtype=torch.long)
        self.comp_b = torch.tensor([c[1] for c in comp_tests], device=d, dtype=torch.long)

        # Exhaustive byte range for modular-arithmetic tests.
        self.mod_test = torch.arange(256, device=d, dtype=torch.long)

        # Representative 32-bit values: boundaries, well-known constants,
        # and alternating bit patterns.
        self.test_32bit = torch.tensor([
            0, 1, 2, 255, 256, 65535, 65536,
            0x7FFFFFFF, 0x80000000, 0xFFFFFFFF,
            0x12345678, 0xDEADBEEF, 0xCAFEBABE,
            1000000, 1000000000, 2147483647,
            0x55555555, 0xAAAAAAAA, 0x0F0F0F0F, 0xF0F0F0F0
        ], device=d, dtype=torch.long)

        # 32-bit comparator pairs around carry/sign boundaries.
        comp32_tests = [
            (0, 0), (1, 0), (0, 1), (1000, 999), (999, 1000),
            (0xFFFFFFFF, 0), (0, 0xFFFFFFFF),
            (0x80000000, 0x7FFFFFFF), (0x7FFFFFFF, 0x80000000),
            (1000000, 1000000), (0x12345678, 0x12345678),
            (0xDEADBEEF, 0xCAFEBABE), (0xCAFEBABE, 0xDEADBEEF),
            (256, 255), (255, 256), (65536, 65535), (65535, 65536),
        ]
        self.comp32_a = torch.tensor([c[0] for c in comp32_tests], device=d, dtype=torch.long)
        self.comp32_b = torch.tensor([c[1] for c in comp32_tests], device=d, dtype=torch.long)
|
|
|
|
|
def _record(self, name: str, passed: int, total: int, failures: List[Tuple] = None): |
|
|
"""Record a circuit test result.""" |
|
|
self.results.append(CircuitResult( |
|
|
name=name, |
|
|
passed=passed, |
|
|
total=total, |
|
|
failures=failures or [] |
|
|
)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _test_single_gate(self, pop: Dict, prefix: str, inputs: torch.Tensor, |
|
|
expected: torch.Tensor) -> torch.Tensor: |
|
|
"""Test single-layer gate (AND, OR, NAND, NOR, IMPLIES).""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
w = pop[f'{prefix}.weight'] |
|
|
b = pop[f'{prefix}.bias'] |
|
|
|
|
|
|
|
|
out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size)) |
|
|
correct = (out == expected.unsqueeze(1)).float().sum(0) |
|
|
|
|
|
failures = [] |
|
|
if pop_size == 1: |
|
|
for i, (inp, exp, got) in enumerate(zip(inputs, expected, out[:, 0])): |
|
|
if exp.item() != got.item(): |
|
|
failures.append((inp.tolist(), exp.item(), got.item())) |
|
|
|
|
|
self._record(prefix, int(correct[0].item()), len(expected), failures) |
|
|
return correct |
|
|
|
|
|
def _test_twolayer_gate(self, pop: Dict, prefix: str, inputs: torch.Tensor, |
|
|
expected: torch.Tensor) -> torch.Tensor: |
|
|
"""Test two-layer gate (XOR, XNOR, BIIMPLIES).""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
|
|
|
|
|
|
w1_n1 = pop[f'{prefix}.layer1.neuron1.weight'] |
|
|
b1_n1 = pop[f'{prefix}.layer1.neuron1.bias'] |
|
|
w1_n2 = pop[f'{prefix}.layer1.neuron2.weight'] |
|
|
b1_n2 = pop[f'{prefix}.layer1.neuron2.bias'] |
|
|
|
|
|
h1 = heaviside(inputs @ w1_n1.view(pop_size, -1).T + b1_n1.view(pop_size)) |
|
|
h2 = heaviside(inputs @ w1_n2.view(pop_size, -1).T + b1_n2.view(pop_size)) |
|
|
hidden = torch.stack([h1, h2], dim=-1) |
|
|
|
|
|
|
|
|
w2 = pop[f'{prefix}.layer2.weight'] |
|
|
b2 = pop[f'{prefix}.layer2.bias'] |
|
|
out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size)) |
|
|
|
|
|
correct = (out == expected.unsqueeze(1)).float().sum(0) |
|
|
|
|
|
failures = [] |
|
|
if pop_size == 1: |
|
|
for i, (inp, exp, got) in enumerate(zip(inputs, expected, out[:, 0])): |
|
|
if exp.item() != got.item(): |
|
|
failures.append((inp.tolist(), exp.item(), got.item())) |
|
|
|
|
|
self._record(prefix, int(correct[0].item()), len(expected), failures) |
|
|
return correct |
|
|
|
|
|
def _test_xor_ornand(self, pop: Dict, prefix: str, inputs: torch.Tensor, |
|
|
expected: torch.Tensor) -> torch.Tensor: |
|
|
"""Test XOR with or/nand layer naming.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
|
|
|
w_or = pop[f'{prefix}.layer1.or.weight'] |
|
|
b_or = pop[f'{prefix}.layer1.or.bias'] |
|
|
w_nand = pop[f'{prefix}.layer1.nand.weight'] |
|
|
b_nand = pop[f'{prefix}.layer1.nand.bias'] |
|
|
|
|
|
h_or = heaviside(inputs @ w_or.view(pop_size, -1).T + b_or.view(pop_size)) |
|
|
h_nand = heaviside(inputs @ w_nand.view(pop_size, -1).T + b_nand.view(pop_size)) |
|
|
hidden = torch.stack([h_or, h_nand], dim=-1) |
|
|
|
|
|
w2 = pop[f'{prefix}.layer2.weight'] |
|
|
b2 = pop[f'{prefix}.layer2.bias'] |
|
|
out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size)) |
|
|
|
|
|
correct = (out == expected.unsqueeze(1)).float().sum(0) |
|
|
|
|
|
failures = [] |
|
|
if pop_size == 1: |
|
|
for i, (inp, exp, got) in enumerate(zip(inputs, expected, out[:, 0])): |
|
|
if exp.item() != got.item(): |
|
|
failures.append((inp.tolist(), exp.item(), got.item())) |
|
|
|
|
|
self._record(prefix, int(correct[0].item()), len(expected), failures) |
|
|
return correct |
|
|
|
|
|
def _test_boolean_gates(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test all boolean gates.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== BOOLEAN GATES ===") |
|
|
|
|
|
|
|
|
for gate in ['and', 'or', 'nand', 'nor', 'implies']: |
|
|
scores += self._test_single_gate(pop, f'boolean.{gate}', self.tt2, self.expected[gate]) |
|
|
total += 4 |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
|
|
|
|
|
|
w = pop['boolean.not.weight'] |
|
|
b = pop['boolean.not.bias'] |
|
|
out = heaviside(self.not_inputs @ w.view(pop_size, -1).T + b.view(pop_size)) |
|
|
correct = (out == self.expected['not'].unsqueeze(1)).float().sum(0) |
|
|
scores += correct |
|
|
total += 2 |
|
|
|
|
|
failures = [] |
|
|
if pop_size == 1: |
|
|
for inp, exp, got in zip(self.not_inputs, self.expected['not'], out[:, 0]): |
|
|
if exp.item() != got.item(): |
|
|
failures.append((inp.tolist(), exp.item(), got.item())) |
|
|
self._record('boolean.not', int(correct[0].item()), 2, failures) |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
|
|
|
|
|
|
for gate in ['xnor', 'biimplies']: |
|
|
scores += self._test_twolayer_gate(pop, f'boolean.{gate}', self.tt2, self.expected.get(gate, self.expected['xnor'])) |
|
|
total += 4 |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
|
|
|
|
|
|
scores += self._test_twolayer_gate(pop, 'boolean.xor', self.tt2, self.expected['xor']) |
|
|
total += 4 |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
|
|
|
return scores, total |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _eval_xor(self, pop: Dict, prefix: str, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: |
|
|
"""Evaluate XOR gate with or/nand decomposition. |
|
|
|
|
|
Args: |
|
|
a, b: Tensors of shape [num_tests] or [num_tests, pop_size] |
|
|
|
|
|
Returns: |
|
|
Tensor of shape [num_tests, pop_size] |
|
|
""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
|
|
|
|
|
|
if a.dim() == 1: |
|
|
a = a.unsqueeze(1).expand(-1, pop_size) |
|
|
if b.dim() == 1: |
|
|
b = b.unsqueeze(1).expand(-1, pop_size) |
|
|
|
|
|
|
|
|
inputs = torch.stack([a, b], dim=-1) |
|
|
|
|
|
w_or = pop[f'{prefix}.layer1.or.weight'].view(pop_size, 2) |
|
|
b_or = pop[f'{prefix}.layer1.or.bias'].view(pop_size) |
|
|
w_nand = pop[f'{prefix}.layer1.nand.weight'].view(pop_size, 2) |
|
|
b_nand = pop[f'{prefix}.layer1.nand.bias'].view(pop_size) |
|
|
|
|
|
|
|
|
h_or = heaviside((inputs * w_or).sum(-1) + b_or) |
|
|
h_nand = heaviside((inputs * w_nand).sum(-1) + b_nand) |
|
|
|
|
|
|
|
|
hidden = torch.stack([h_or, h_nand], dim=-1) |
|
|
|
|
|
w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, 2) |
|
|
b2 = pop[f'{prefix}.layer2.bias'].view(pop_size) |
|
|
return heaviside((hidden * w2).sum(-1) + b2) |
|
|
|
|
|
def _eval_single_fa(self, pop: Dict, prefix: str, |
|
|
a: torch.Tensor, b: torch.Tensor, cin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
"""Evaluate single full adder. |
|
|
|
|
|
Args: |
|
|
a, b, cin: Tensors of shape [num_tests] or [num_tests, pop_size] |
|
|
|
|
|
Returns: |
|
|
sum_out, cout: Both of shape [num_tests, pop_size] |
|
|
""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
|
|
|
|
|
|
if a.dim() == 1: |
|
|
a = a.unsqueeze(1).expand(-1, pop_size) |
|
|
if b.dim() == 1: |
|
|
b = b.unsqueeze(1).expand(-1, pop_size) |
|
|
if cin.dim() == 1: |
|
|
cin = cin.unsqueeze(1).expand(-1, pop_size) |
|
|
|
|
|
|
|
|
ha1_sum = self._eval_xor(pop, f'{prefix}.ha1.sum', a, b) |
|
|
|
|
|
|
|
|
ab = torch.stack([a, b], dim=-1) |
|
|
w_c1 = pop[f'{prefix}.ha1.carry.weight'].view(pop_size, 2) |
|
|
b_c1 = pop[f'{prefix}.ha1.carry.bias'].view(pop_size) |
|
|
ha1_carry = heaviside((ab * w_c1).sum(-1) + b_c1) |
|
|
|
|
|
|
|
|
ha2_sum = self._eval_xor(pop, f'{prefix}.ha2.sum', ha1_sum, cin) |
|
|
|
|
|
|
|
|
sc = torch.stack([ha1_sum, cin], dim=-1) |
|
|
w_c2 = pop[f'{prefix}.ha2.carry.weight'].view(pop_size, 2) |
|
|
b_c2 = pop[f'{prefix}.ha2.carry.bias'].view(pop_size) |
|
|
ha2_carry = heaviside((sc * w_c2).sum(-1) + b_c2) |
|
|
|
|
|
|
|
|
carries = torch.stack([ha1_carry, ha2_carry], dim=-1) |
|
|
w_cout = pop[f'{prefix}.carry_or.weight'].view(pop_size, 2) |
|
|
b_cout = pop[f'{prefix}.carry_or.bias'].view(pop_size) |
|
|
cout = heaviside((carries * w_cout).sum(-1) + b_cout) |
|
|
|
|
|
return ha2_sum, cout |
|
|
|
|
|
def _test_halfadder(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test half adder.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== HALF ADDER ===") |
|
|
|
|
|
|
|
|
scores += self._test_xor_ornand(pop, 'arithmetic.halfadder.sum', self.tt2, self.expected['ha_sum']) |
|
|
total += 4 |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
|
|
|
|
|
|
scores += self._test_single_gate(pop, 'arithmetic.halfadder.carry', self.tt2, self.expected['ha_carry']) |
|
|
total += 4 |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
|
|
|
return scores, total |
|
|
|
|
|
def _test_fulladder(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test full adder with all 8 input combinations.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
|
|
|
if debug: |
|
|
print("\n=== FULL ADDER ===") |
|
|
|
|
|
a = self.tt3[:, 0] |
|
|
b = self.tt3[:, 1] |
|
|
cin = self.tt3[:, 2] |
|
|
|
|
|
sum_out, cout = self._eval_single_fa(pop, 'arithmetic.fulladder', a, b, cin) |
|
|
|
|
|
sum_correct = (sum_out == self.expected['fa_sum'].unsqueeze(1)).float().sum(0) |
|
|
cout_correct = (cout == self.expected['fa_cout'].unsqueeze(1)).float().sum(0) |
|
|
|
|
|
failures_sum = [] |
|
|
failures_cout = [] |
|
|
if pop_size == 1: |
|
|
for i in range(8): |
|
|
if sum_out[i, 0].item() != self.expected['fa_sum'][i].item(): |
|
|
failures_sum.append(([a[i].item(), b[i].item(), cin[i].item()], |
|
|
self.expected['fa_sum'][i].item(), sum_out[i, 0].item())) |
|
|
if cout[i, 0].item() != self.expected['fa_cout'][i].item(): |
|
|
failures_cout.append(([a[i].item(), b[i].item(), cin[i].item()], |
|
|
self.expected['fa_cout'][i].item(), cout[i, 0].item())) |
|
|
|
|
|
self._record('arithmetic.fulladder.sum', int(sum_correct[0].item()), 8, failures_sum) |
|
|
self._record('arithmetic.fulladder.cout', int(cout_correct[0].item()), 8, failures_cout) |
|
|
|
|
|
if debug: |
|
|
for r in self.results[-2:]: |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
|
|
|
return sum_correct + cout_correct, 16 |
|
|
|
|
|
    def _test_ripplecarry(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test N-bit ripple carry adder.

        For bits <= 4 tests the full cross product of operands; for wider
        adders tests a sampled set of edge-value pairs. Records a result
        under 'arithmetic.ripplecarry{bits}bit'.

        Args:
            pop: Population tensor dict with '{prefix}.fa{k}' full adders.
            bits: Adder width in bits.
            debug: When True, print a PASS/FAIL summary line.

        Returns:
            (correct, num_tests): [pop_size] correct counts and test count.
        """
        pop_size = next(iter(pop.values())).shape[0]

        if debug:
            print(f"\n=== RIPPLE CARRY {bits}-BIT ===")

        prefix = f'arithmetic.ripplecarry{bits}bit'
        max_val = 1 << bits
        # Provisional count; the sampled branch below overwrites it.
        num_tests = min(max_val * max_val, 65536)

        if bits <= 4:
            # Exhaustive: all (a, b) pairs via a meshgrid.
            test_a = torch.arange(max_val, device=self.device)
            test_b = torch.arange(max_val, device=self.device)
            a_vals, b_vals = torch.meshgrid(test_a, test_b, indexing='ij')
            a_vals = a_vals.flatten()
            b_vals = b_vals.flatten()
        else:
            # Sampled: edge-value cross product plus complementary pairs.
            # NOTE(review): values here are 8-bit specific (127/128/254/255,
            # 255 - i) — assumes bits == 8 on this path; confirm callers.
            edge_vals = [0, 1, 2, 127, 128, 254, 255]
            pairs = [(a, b) for a in edge_vals for b in edge_vals]
            for i in range(0, 256, 16):
                pairs.append((i, 255 - i))
            pairs = list(set(pairs))
            a_vals = torch.tensor([p[0] for p in pairs], device=self.device)
            b_vals = torch.tensor([p[1] for p in pairs], device=self.device)
            num_tests = len(pairs)

        # MSB-first bit decomposition: column i holds bit (bits - 1 - i).
        a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
        b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)

        # Ripple the carry from LSB (fa0) to MSB; bit_idx maps the FA index
        # (LSB-first) back into the MSB-first bit columns.
        carry = torch.zeros(len(a_vals), pop_size, device=self.device)
        sum_bits = []

        for bit in range(bits):
            bit_idx = bits - 1 - bit
            s, carry = self._eval_single_fa(
                pop, f'{prefix}.fa{bit}',
                a_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
                b_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
                carry
            )
            sum_bits.append(s)

        # Reverse to MSB-first, then fold bits back into integer values.
        sum_bits = torch.stack(sum_bits[::-1], dim=-1)
        result = torch.zeros(len(a_vals), pop_size, device=self.device)
        for i in range(bits):
            result += sum_bits[:, :, i] * (1 << (bits - 1 - i))

        # Reference: modular addition (final carry-out is discarded).
        expected = ((a_vals + b_vals) & (max_val - 1)).unsqueeze(1).expand(-1, pop_size).float()
        correct = (result == expected).float().sum(0)

        # Detailed failures (capped at 100) only for single-individual runs.
        failures = []
        if pop_size == 1:
            for i in range(min(len(a_vals), 100)):
                if result[i, 0].item() != expected[i, 0].item():
                    failures.append((
                        [int(a_vals[i].item()), int(b_vals[i].item())],
                        int(expected[i, 0].item()),
                        int(result[i, 0].item())
                    ))

        self._record(prefix, int(correct[0].item()), num_tests, failures)
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")

        return correct, num_tests
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    def _test_add3(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test 3-operand 8-bit adder (A + B + C).

        The circuit is two chained 8-bit ripple-carry stages:
        stage1 computes A + B, stage2 adds C to that partial sum.
        Records a result under 'arithmetic.add3_8bit'.

        Args:
            pop: Population tensor dict with stage1/stage2 full adders.
            debug: When True, print PASS/FAIL and up to 5 failing cases.

        Returns:
            (correct, num_tests): [pop_size] correct counts and test count.
        """
        pop_size = next(iter(pop.values())).shape[0]

        if debug:
            print(f"\n=== 3-OPERAND ADDER ===")

        prefix = 'arithmetic.add3_8bit'
        bits = 8

        # Build test triples: small values, edge-value cross product, and a
        # few hand-picked overflow-boundary cases; de-duplicated via set().
        test_cases = []

        for a in [0, 1, 2]:
            for b in [0, 1, 2]:
                for c in [0, 1, 2]:
                    test_cases.append((a, b, c))

        edge = [0, 1, 127, 128, 254, 255]
        for a in edge:
            for b in edge:
                for c in edge:
                    test_cases.append((a, b, c))

        test_cases.extend([
            (15, 27, 33),
            (100, 100, 55),
            (100, 100, 56),
            (85, 85, 85),
            (86, 85, 85),
        ])
        test_cases = list(set(test_cases))

        a_vals = torch.tensor([t[0] for t in test_cases], device=self.device)
        b_vals = torch.tensor([t[1] for t in test_cases], device=self.device)
        c_vals = torch.tensor([t[2] for t in test_cases], device=self.device)
        num_tests = len(test_cases)

        # MSB-first bit decomposition of all three operands.
        a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
        b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
        c_bits = torch.stack([((c_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)

        # Stage 1: ripple-carry A + B, LSB-first FA chain; bit_idx maps the
        # FA index back into the MSB-first bit columns.
        carry1 = torch.zeros(num_tests, pop_size, device=self.device)
        stage1_bits = []
        for bit in range(bits):
            bit_idx = bits - 1 - bit
            s, carry1 = self._eval_single_fa(
                pop, f'{prefix}.stage1.fa{bit}',
                a_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
                b_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
                carry1
            )
            stage1_bits.append(s)

        # Stage 2: ripple-carry (A + B) + C. stage1_bits is already
        # LSB-first (indexed by FA position), so stage1_bits[bit] lines up
        # with c's bit_idx column.
        carry2 = torch.zeros(num_tests, pop_size, device=self.device)
        result_bits = []
        for bit in range(bits):
            bit_idx = bits - 1 - bit
            s, carry2 = self._eval_single_fa(
                pop, f'{prefix}.stage2.fa{bit}',
                stage1_bits[bit],
                c_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
                carry2
            )
            result_bits.append(s)

        # Reverse to MSB-first and fold bits back into integer values.
        result_bits = torch.stack(result_bits[::-1], dim=-1)
        result = torch.zeros(num_tests, pop_size, device=self.device)
        for i in range(bits):
            result += result_bits[:, :, i] * (1 << (bits - 1 - i))

        # Reference: modular (mod-256) three-operand sum.
        expected = ((a_vals + b_vals + c_vals) & 0xFF).unsqueeze(1).expand(-1, pop_size).float()
        correct = (result == expected).float().sum(0)

        # Detailed failures (capped at 100) only for single-individual runs.
        failures = []
        if pop_size == 1:
            for i in range(min(num_tests, 100)):
                if result[i, 0].item() != expected[i, 0].item():
                    failures.append((
                        [int(a_vals[i].item()), int(b_vals[i].item()), int(c_vals[i].item())],
                        int(expected[i, 0].item()),
                        int(result[i, 0].item())
                    ))

        self._record(prefix, int(correct[0].item()), num_tests, failures)
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
            if failures:
                for inp, exp, got in failures[:5]:
                    print(f" FAIL: {inp[0]} + {inp[1]} + {inp[2]} = {exp}, got {got}")

        return correct, num_tests
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    def _test_expr_add_mul(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test A + B × C expression circuit (order of operations).

        The circuit is a shift-and-add multiplier (B × C, mod 256) followed
        by a final ripple-carry addition of A. Records a result under
        'arithmetic.expr_add_mul'.

        Args:
            pop: Population tensor dict with mul.mask / mul.acc / add keys.
            debug: When True, print PASS/FAIL and up to 5 failing cases.

        Returns:
            (correct, num_tests): [pop_size] correct counts and test count.
        """
        pop_size = next(iter(pop.values())).shape[0]

        if debug:
            print(f"\n=== ORDER OF OPERATIONS (A + B × C) ===")

        prefix = 'arithmetic.expr_add_mul'
        bits = 8

        # Hand-picked precedence cases, boundary cases, and a small cross
        # product of small operands; de-duplicated via set().
        test_cases = []

        test_cases.extend([
            (5, 3, 2),
            (10, 4, 3),
            (1, 10, 10),
            (0, 15, 17),
            (1, 15, 17),
            (100, 5, 5),
        ])

        test_cases.extend([
            (0, 0, 0),
            (255, 0, 0),
            (0, 255, 1),
            (0, 1, 255),
            (1, 1, 1),
            (0, 16, 16),
        ])

        for a in [0, 1, 5, 10]:
            for b in [0, 1, 2, 3]:
                for c in [0, 1, 2, 3]:
                    test_cases.append((a, b, c))

        test_cases = list(set(test_cases))

        a_vals = torch.tensor([t[0] for t in test_cases], device=self.device)
        b_vals = torch.tensor([t[1] for t in test_cases], device=self.device)
        c_vals = torch.tensor([t[2] for t in test_cases], device=self.device)
        num_tests = len(test_cases)

        # MSB-first bit decomposition of all three operands.
        a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
        b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
        c_bits = torch.stack([((c_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)

        # Partial-product masks: masks[stage, :, :, bit] is B's bit gated by
        # C's bit for that stage (both indexed LSB-first via 7 - index).
        # Mask neurons missing from the population leave zeros in place.
        masks = torch.zeros(8, num_tests, pop_size, 8, device=self.device)
        for stage in range(8):
            c_stage_bit = c_bits[:, 7 - stage].unsqueeze(1).expand(-1, pop_size)
            for bit in range(8):
                b_bit_val = b_bits[:, 7 - bit].unsqueeze(1).expand(-1, pop_size)

                w = pop.get(f'{prefix}.mul.mask.s{stage}.b{bit}.weight')
                bias = pop.get(f'{prefix}.mul.mask.s{stage}.b{bit}.bias')
                if w is not None and bias is not None:
                    w = w.squeeze(-1)
                    b_tensor = bias.squeeze(-1)

                    # einsum 'tpi,pi->tp': per-(test, individual) dot product
                    # of the (b_bit, c_bit) pair with that neuron's weights.
                    inp = torch.stack([b_bit_val, c_stage_bit], dim=-1)
                    out = heaviside(torch.einsum('tpi,pi->tp', inp, w) + b_tensor)
                    masks[stage, :, :, bit] = out

        # Accumulate partial products: acc starts at stage 0's mask, then
        # each later stage is shifted left by its index and added in via a
        # per-bit FA chain (all arithmetic is mod 256 — no 9th bit kept).
        acc = masks[0].clone()

        for stage in range(1, 8):
            # Shift this stage's partial product left by `stage` positions
            # (LSB-first indexing: destination bit >= stage).
            shifted_mask = torch.zeros(num_tests, pop_size, 8, device=self.device)
            for bit in range(8):
                if bit >= stage:
                    shifted_mask[:, :, bit] = masks[stage, :, :, bit - stage]

            # Ripple-carry add shifted_mask into the accumulator.
            carry = torch.zeros(num_tests, pop_size, device=self.device)
            new_acc = torch.zeros(num_tests, pop_size, 8, device=self.device)
            for bit in range(8):
                s, carry = self._eval_single_fa(
                    pop, f'{prefix}.mul.acc.s{stage}.fa{bit}',
                    acc[:, :, bit],
                    shifted_mask[:, :, bit],
                    carry
                )
                new_acc[:, :, bit] = s
            acc = new_acc

        # Final stage: ripple-carry add A to the product accumulator.
        carry = torch.zeros(num_tests, pop_size, device=self.device)
        result_bits = []
        for bit in range(8):
            a_bit_val = a_bits[:, 7 - bit].unsqueeze(1).expand(-1, pop_size)
            s, carry = self._eval_single_fa(
                pop, f'{prefix}.add.fa{bit}',
                a_bit_val,
                acc[:, :, bit],
                carry
            )
            result_bits.append(s)

        # Reverse to MSB-first and fold bits back into integer values.
        result_bits = torch.stack(result_bits[::-1], dim=-1)
        result = torch.zeros(num_tests, pop_size, device=self.device)
        for i in range(bits):
            result += result_bits[:, :, i] * (1 << (bits - 1 - i))

        # Reference: A + B*C with multiplication binding tighter, mod 256.
        expected = ((a_vals + b_vals * c_vals) & 0xFF).unsqueeze(1).expand(-1, pop_size).float()
        correct = (result == expected).float().sum(0)

        # Detailed failures (capped at 100) only for single-individual runs.
        failures = []
        if pop_size == 1:
            for i in range(min(num_tests, 100)):
                if result[i, 0].item() != expected[i, 0].item():
                    failures.append((
                        [int(a_vals[i].item()), int(b_vals[i].item()), int(c_vals[i].item())],
                        int(expected[i, 0].item()),
                        int(result[i, 0].item())
                    ))

        self._record(prefix, int(correct[0].item()), num_tests, failures)
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
            if failures:
                for inp, exp, got in failures[:5]:
                    print(f" FAIL: {inp[0]} + {inp[1]} × {inp[2]} = {exp}, got {got}")

        return correct, num_tests
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _test_comparator(self, pop: Dict, name: str, op: Callable[[int, int], bool], |
|
|
debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test 8-bit comparator.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
prefix = f'arithmetic.{name}' |
|
|
|
|
|
|
|
|
expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0 |
|
|
for a, b in zip(self.comp_a, self.comp_b)], |
|
|
device=self.device) |
|
|
|
|
|
|
|
|
a_bits = torch.stack([((self.comp_a >> (7 - i)) & 1).float() for i in range(8)], dim=1) |
|
|
b_bits = torch.stack([((self.comp_b >> (7 - i)) & 1).float() for i in range(8)], dim=1) |
|
|
inputs = torch.cat([a_bits, b_bits], dim=1) |
|
|
|
|
|
w = pop[f'{prefix}.weight'] |
|
|
b = pop[f'{prefix}.bias'] |
|
|
out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size)) |
|
|
|
|
|
correct = (out == expected.unsqueeze(1)).float().sum(0) |
|
|
|
|
|
failures = [] |
|
|
if pop_size == 1: |
|
|
for i in range(len(self.comp_a)): |
|
|
if out[i, 0].item() != expected[i].item(): |
|
|
failures.append(( |
|
|
[int(self.comp_a[i].item()), int(self.comp_b[i].item())], |
|
|
expected[i].item(), |
|
|
out[i, 0].item() |
|
|
)) |
|
|
|
|
|
self._record(prefix, int(correct[0].item()), len(self.comp_a), failures) |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
|
|
|
return correct, len(self.comp_a) |
|
|
|
|
|
def _test_comparators(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test all comparators.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== COMPARATORS ===") |
|
|
|
|
|
comparators = [ |
|
|
('greaterthan8bit', lambda a, b: a > b), |
|
|
('lessthan8bit', lambda a, b: a < b), |
|
|
('greaterorequal8bit', lambda a, b: a >= b), |
|
|
('lessorequal8bit', lambda a, b: a <= b), |
|
|
('equality8bit', lambda a, b: a == b), |
|
|
] |
|
|
|
|
|
for name, op in comparators: |
|
|
if name == 'equality8bit': |
|
|
continue |
|
|
try: |
|
|
s, t = self._test_comparator(pop, name, op, debug) |
|
|
scores += s |
|
|
total += t |
|
|
except KeyError: |
|
|
pass |
|
|
|
|
|
|
|
|
try: |
|
|
prefix = 'arithmetic.equality8bit' |
|
|
expected = torch.tensor([1.0 if a.item() == b.item() else 0.0 |
|
|
for a, b in zip(self.comp_a, self.comp_b)], |
|
|
device=self.device) |
|
|
|
|
|
a_bits = torch.stack([((self.comp_a >> (7 - i)) & 1).float() for i in range(8)], dim=1) |
|
|
b_bits = torch.stack([((self.comp_b >> (7 - i)) & 1).float() for i in range(8)], dim=1) |
|
|
inputs = torch.cat([a_bits, b_bits], dim=1) |
|
|
|
|
|
|
|
|
w_geq = pop[f'{prefix}.layer1.geq.weight'] |
|
|
b_geq = pop[f'{prefix}.layer1.geq.bias'] |
|
|
w_leq = pop[f'{prefix}.layer1.leq.weight'] |
|
|
b_leq = pop[f'{prefix}.layer1.leq.bias'] |
|
|
|
|
|
h_geq = heaviside(inputs @ w_geq.view(pop_size, -1).T + b_geq.view(pop_size)) |
|
|
h_leq = heaviside(inputs @ w_leq.view(pop_size, -1).T + b_leq.view(pop_size)) |
|
|
hidden = torch.stack([h_geq, h_leq], dim=-1) |
|
|
|
|
|
|
|
|
w2 = pop[f'{prefix}.layer2.weight'] |
|
|
b2 = pop[f'{prefix}.layer2.bias'] |
|
|
out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size)) |
|
|
|
|
|
correct = (out == expected.unsqueeze(1)).float().sum(0) |
|
|
|
|
|
failures = [] |
|
|
if pop_size == 1: |
|
|
for i in range(len(self.comp_a)): |
|
|
if out[i, 0].item() != expected[i].item(): |
|
|
failures.append(( |
|
|
[int(self.comp_a[i].item()), int(self.comp_b[i].item())], |
|
|
expected[i].item(), |
|
|
out[i, 0].item() |
|
|
)) |
|
|
|
|
|
self._record(prefix, int(correct[0].item()), len(self.comp_a), failures) |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
scores += correct |
|
|
total += len(self.comp_a) |
|
|
except KeyError: |
|
|
pass |
|
|
|
|
|
return scores, total |
|
|
|
|
|
    def _test_comparators_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test N-bit comparator circuits (GT, LT, GE, LE, EQ).

        For bits <= 16 each comparator is a single threshold layer applied to
        the concatenated MSB-first bit vectors of A and B (equality uses a
        two-layer geq/leq AND construction).  For wider widths (32) the
        comparison is built from per-byte gt/lt/eq sub-circuits combined by a
        priority cascade, then reduced to the five final outputs.

        Args:
            pop: dict of flattened circuit tensors; every value's leading
                dimension is the population size.
            bits: comparator width (8, 16, or 32).
            debug: when True, print a per-circuit PASS/FAIL line.

        Returns:
            (scores, total): per-individual count of correct test cases, and
            the number of test cases attempted (0 contribution for circuits
            whose weights are missing from `pop` — those raise KeyError and
            are skipped).
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total = 0

        if debug:
            print(f"\n=== {bits}-BIT COMPARATORS ===")

        # Pick the operand set for this width.  16-bit reuses the 8-bit
        # operand tensors clamped into range (so values are still <= 255;
        # presumably intentional reuse — TODO confirm coverage is adequate).
        if bits == 32:
            comp_a = self.comp32_a
            comp_b = self.comp32_b
        elif bits == 16:
            comp_a = self.comp_a.clamp(0, 65535)
            comp_b = self.comp_b.clamp(0, 65535)
        else:
            comp_a = self.comp_a
            comp_b = self.comp_b

        num_tests = len(comp_a)

        if bits <= 16:
            # MSB-first bit decomposition; inputs is (num_tests, 2*bits).
            a_bits = torch.stack([((comp_a >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
            b_bits = torch.stack([((comp_b >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
            inputs = torch.cat([a_bits, b_bits], dim=1)

            # Single-layer comparators: one weight vector + bias per circuit.
            comparators = [
                (f'arithmetic.greaterthan{bits}bit', lambda a, b: a > b),
                (f'arithmetic.greaterorequal{bits}bit', lambda a, b: a >= b),
                (f'arithmetic.lessthan{bits}bit', lambda a, b: a < b),
                (f'arithmetic.lessorequal{bits}bit', lambda a, b: a <= b),
            ]

            for name, op in comparators:
                try:
                    expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0
                                             for a, b in zip(comp_a, comp_b)], device=self.device)
                    w = pop[f'{name}.weight']
                    b = pop[f'{name}.bias']
                    # out is (num_tests, pop_size): one prediction per individual.
                    out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size))
                    correct = (out == expected.unsqueeze(1)).float().sum(0)
                    failures = []
                    # Failure details are only collected in single-individual
                    # (reporting) mode to keep population evaluation cheap.
                    if pop_size == 1:
                        for i in range(num_tests):
                            if out[i, 0].item() != expected[i].item():
                                failures.append(([int(comp_a[i].item()), int(comp_b[i].item())],
                                                 expected[i].item(), out[i, 0].item()))
                    self._record(name, int(correct[0].item()), num_tests, failures)
                    if debug:
                        r = self.results[-1]
                        print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
                    scores += correct
                    total += num_tests
                except KeyError:
                    pass

            # Equality circuit: AND of a (A >= B) neuron and a (A <= B) neuron.
            prefix = f'arithmetic.equality{bits}bit'
            try:
                expected = torch.tensor([1.0 if a.item() == b.item() else 0.0
                                         for a, b in zip(comp_a, comp_b)], device=self.device)
                w_geq = pop[f'{prefix}.layer1.geq.weight']
                b_geq = pop[f'{prefix}.layer1.geq.bias']
                w_leq = pop[f'{prefix}.layer1.leq.weight']
                b_leq = pop[f'{prefix}.layer1.leq.bias']
                h_geq = heaviside(inputs @ w_geq.view(pop_size, -1).T + b_geq.view(pop_size))
                h_leq = heaviside(inputs @ w_leq.view(pop_size, -1).T + b_leq.view(pop_size))
                hidden = torch.stack([h_geq, h_leq], dim=-1)
                w2 = pop[f'{prefix}.layer2.weight']
                b2 = pop[f'{prefix}.layer2.bias']
                out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size))
                correct = (out == expected.unsqueeze(1)).float().sum(0)
                failures = []
                if pop_size == 1:
                    for i in range(num_tests):
                        if out[i, 0].item() != expected[i].item():
                            failures.append(([int(comp_a[i].item()), int(comp_b[i].item())],
                                             expected[i].item(), out[i, 0].item()))
                self._record(prefix, int(correct[0].item()), num_tests, failures)
                if debug:
                    r = self.results[-1]
                    print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
                scores += correct
                total += num_tests
            except KeyError:
                pass
        else:
            # Wide (byte-cascade) path.  Keys missing here propagate KeyError
            # out of the method — unlike the <=16-bit path there is no
            # per-circuit try/except.  NOTE(review): confirm that is intended.
            num_bytes = bits // 8
            prefix = f"arithmetic.cmp{bits}bit"

            # Per-byte primitive outputs, most-significant byte first.
            byte_gt = []
            byte_lt = []
            byte_eq = []

            for b in range(num_bytes):
                start_bit = b * 8
                a_byte = torch.stack([((comp_a >> (bits - 1 - start_bit - i)) & 1).float() for i in range(8)], dim=1)
                b_byte = torch.stack([((comp_b >> (bits - 1 - start_bit - i)) & 1).float() for i in range(8)], dim=1)
                byte_input = torch.cat([a_byte, b_byte], dim=1)

                w_gt = pop[f'{prefix}.byte{b}.gt.weight'].view(pop_size, -1)
                b_gt = pop[f'{prefix}.byte{b}.gt.bias'].view(pop_size)
                byte_gt.append(heaviside(byte_input @ w_gt.T + b_gt))

                w_lt = pop[f'{prefix}.byte{b}.lt.weight'].view(pop_size, -1)
                b_lt = pop[f'{prefix}.byte{b}.lt.bias'].view(pop_size)
                byte_lt.append(heaviside(byte_input @ w_lt.T + b_lt))

                # Byte equality = AND(geq, leq), same construction as the
                # narrow equality circuit above.
                w_geq = pop[f'{prefix}.byte{b}.eq.geq.weight'].view(pop_size, -1)
                b_geq = pop[f'{prefix}.byte{b}.eq.geq.bias'].view(pop_size)
                w_leq = pop[f'{prefix}.byte{b}.eq.leq.weight'].view(pop_size, -1)
                b_leq = pop[f'{prefix}.byte{b}.eq.leq.bias'].view(pop_size)
                h_geq = heaviside(byte_input @ w_geq.T + b_geq)
                h_leq = heaviside(byte_input @ w_leq.T + b_leq)
                w_and = pop[f'{prefix}.byte{b}.eq.and.weight'].view(pop_size, -1)
                b_and = pop[f'{prefix}.byte{b}.eq.and.bias'].view(pop_size)
                eq_inp = torch.stack([h_geq, h_leq], dim=-1)
                byte_eq.append(heaviside((eq_inp * w_and).sum(-1) + b_and))

            # Priority cascade: byte b decides only if all more-significant
            # bytes compared equal.
            cascade_gt = []
            cascade_lt = []
            for b in range(num_bytes):
                if b == 0:
                    cascade_gt.append(byte_gt[0])
                    cascade_lt.append(byte_lt[0])
                else:
                    eq_stack = torch.stack(byte_eq[:b], dim=-1)
                    w_all_eq = pop[f'{prefix}.cascade.gt.stage{b}.all_eq.weight'].view(pop_size, -1)
                    b_all_eq = pop[f'{prefix}.cascade.gt.stage{b}.all_eq.bias'].view(pop_size)
                    all_eq_gt = heaviside((eq_stack * w_all_eq).sum(-1) + b_all_eq)
                    w_and = pop[f'{prefix}.cascade.gt.stage{b}.and.weight'].view(pop_size, -1)
                    b_and = pop[f'{prefix}.cascade.gt.stage{b}.and.bias'].view(pop_size)
                    stage_inp = torch.stack([all_eq_gt, byte_gt[b]], dim=-1)
                    cascade_gt.append(heaviside((stage_inp * w_and).sum(-1) + b_and))

                    w_all_eq_lt = pop[f'{prefix}.cascade.lt.stage{b}.all_eq.weight'].view(pop_size, -1)
                    b_all_eq_lt = pop[f'{prefix}.cascade.lt.stage{b}.all_eq.bias'].view(pop_size)
                    all_eq_lt = heaviside((eq_stack * w_all_eq_lt).sum(-1) + b_all_eq_lt)
                    w_and_lt = pop[f'{prefix}.cascade.lt.stage{b}.and.weight'].view(pop_size, -1)
                    b_and_lt = pop[f'{prefix}.cascade.lt.stage{b}.and.bias'].view(pop_size)
                    stage_inp_lt = torch.stack([all_eq_lt, byte_lt[b]], dim=-1)
                    cascade_lt.append(heaviside((stage_inp_lt * w_and_lt).sum(-1) + b_and_lt))

            # OR-reduce the cascade stages into the final GT / LT outputs.
            gt_stack = torch.stack(cascade_gt, dim=-1)
            w_gt_or = pop[f'arithmetic.greaterthan{bits}bit.weight'].view(pop_size, -1)
            b_gt_or = pop[f'arithmetic.greaterthan{bits}bit.bias'].view(pop_size)
            gt_out = heaviside((gt_stack * w_gt_or).sum(-1) + b_gt_or)

            lt_stack = torch.stack(cascade_lt, dim=-1)
            w_lt_or = pop[f'arithmetic.lessthan{bits}bit.weight'].view(pop_size, -1)
            b_lt_or = pop[f'arithmetic.lessthan{bits}bit.bias'].view(pop_size)
            lt_out = heaviside((lt_stack * w_lt_or).sum(-1) + b_lt_or)

            # GE = NOT(LT) passed through a final neuron; LE = NOT(GT).
            # NOTE(review): the `@ w.T` on an unsqueezed scalar yields a
            # (num_tests, pop_size, pop_size) product, so the trailing
            # squeeze(-1) only restores shape when pop_size == 1 — confirm
            # this path is restricted to single-individual evaluation.
            w_not_lt = pop[f'arithmetic.greaterorequal{bits}bit.not_lt.weight'].view(pop_size, -1)
            b_not_lt = pop[f'arithmetic.greaterorequal{bits}bit.not_lt.bias'].view(pop_size)
            not_lt = heaviside(lt_out.unsqueeze(-1) @ w_not_lt.T + b_not_lt).squeeze(-1)
            w_ge = pop[f'arithmetic.greaterorequal{bits}bit.weight'].view(pop_size, -1)
            b_ge = pop[f'arithmetic.greaterorequal{bits}bit.bias'].view(pop_size)
            ge_out = heaviside(not_lt.unsqueeze(-1) @ w_ge.T + b_ge).squeeze(-1)

            w_not_gt = pop[f'arithmetic.lessorequal{bits}bit.not_gt.weight'].view(pop_size, -1)
            b_not_gt = pop[f'arithmetic.lessorequal{bits}bit.not_gt.bias'].view(pop_size)
            not_gt = heaviside(gt_out.unsqueeze(-1) @ w_not_gt.T + b_not_gt).squeeze(-1)
            w_le = pop[f'arithmetic.lessorequal{bits}bit.weight'].view(pop_size, -1)
            b_le = pop[f'arithmetic.lessorequal{bits}bit.bias'].view(pop_size)
            le_out = heaviside(not_gt.unsqueeze(-1) @ w_le.T + b_le).squeeze(-1)

            # EQ = AND over all per-byte equality outputs.
            eq_stack = torch.stack(byte_eq, dim=-1)
            w_eq_all = pop[f'arithmetic.equality{bits}bit.weight'].view(pop_size, -1)
            b_eq_all = pop[f'arithmetic.equality{bits}bit.bias'].view(pop_size)
            eq_out = heaviside((eq_stack * w_eq_all).sum(-1) + b_eq_all)

            # Score all five outputs against Python-int reference results.
            for name, out, op in [
                (f'arithmetic.greaterthan{bits}bit', gt_out, lambda a, b: a > b),
                (f'arithmetic.greaterorequal{bits}bit', ge_out, lambda a, b: a >= b),
                (f'arithmetic.lessthan{bits}bit', lt_out, lambda a, b: a < b),
                (f'arithmetic.lessorequal{bits}bit', le_out, lambda a, b: a <= b),
                (f'arithmetic.equality{bits}bit', eq_out, lambda a, b: a == b),
            ]:
                expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0
                                         for a, b in zip(comp_a, comp_b)], device=self.device)
                correct = (out == expected.unsqueeze(1)).float().sum(0)
                failures = []
                if pop_size == 1:
                    for i in range(num_tests):
                        if out[i, 0].item() != expected[i].item():
                            failures.append(([int(comp_a[i].item()), int(comp_b[i].item())],
                                             expected[i].item(), out[i, 0].item()))
                self._record(name, int(correct[0].item()), num_tests, failures)
                if debug:
                    r = self.results[-1]
                    print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
                scores += correct
                total += num_tests

        return scores, total
|
|
|
|
|
    def _test_subtractor_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test N-bit subtractor circuit (A - B).

        Implements subtraction as two's complement addition: invert every bit
        of B through the circuit's per-bit NOT gates, then ripple a full-adder
        chain (LSB first) with the carry-in preset to 1.  The reference value
        is (A - B) masked to `bits` bits.

        Args:
            pop: dict of circuit tensors, leading dim = population size.
            bits: operand width (8 or 32 test sets are provided).
            debug: when True, print a PASS/FAIL summary line.

        Returns:
            (correct, num_tests): per-individual correct-case counts and the
            number of (A, B) pairs evaluated.
        """
        pop_size = next(iter(pop.values())).shape[0]

        if debug:
            print(f"\n=== {bits}-BIT SUBTRACTOR ===")

        prefix = f'arithmetic.sub{bits}bit'
        max_val = 1 << bits

        # Hand-picked edge cases for 32-bit; exhaustive cross product of
        # boundary bytes otherwise.
        if bits == 32:
            test_pairs = [
                (1000, 500), (5000, 3000), (1000000, 500000),
                (0xFFFFFFFF, 1), (0x80000000, 1), (100, 100),
                (0, 0), (1, 0), (0, 1), (256, 255),
                (0xDEADBEEF, 0xCAFEBABE), (1000000000, 999999999),
            ]
        else:
            test_pairs = [(a, b) for a in [0, 1, 127, 128, 255] for b in [0, 1, 127, 128, 255]]

        a_vals = torch.tensor([p[0] for p in test_pairs], device=self.device, dtype=torch.long)
        b_vals = torch.tensor([p[1] for p in test_pairs], device=self.device, dtype=torch.long)
        num_tests = len(test_pairs)

        # MSB-first bit planes, shape (num_tests, bits).
        a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
        b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)

        # ~B through the circuit's NOT gates.
        # NOTE(review): the `[:, 0]` keeps only population member 0's NOT
        # output, yet the adder below is evaluated for all members — confirm
        # this method is only meaningful for pop_size == 1.
        not_b_bits = torch.zeros_like(b_bits)
        for bit in range(bits):
            w = pop[f'{prefix}.not_b.bit{bit}.weight'].view(pop_size, -1)
            b = pop[f'{prefix}.not_b.bit{bit}.bias'].view(pop_size)
            not_b_bits[:, bit] = heaviside(b_bits[:, bit:bit+1] @ w.T + b)[:, 0]

        # Ripple the full-adder chain LSB-first; initial carry of 1 completes
        # the two's complement (~B + 1).
        carry = torch.ones(num_tests, pop_size, device=self.device)
        sum_bits = []

        for bit in range(bits):
            bit_idx = bits - 1 - bit
            s, carry = self._eval_single_fa(
                pop, f'{prefix}.fa{bit}',
                a_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
                not_b_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
                carry
            )
            sum_bits.append(s)

        # Reverse back to MSB-first and recombine into an integer value.
        sum_bits = torch.stack(sum_bits[::-1], dim=-1)
        result = torch.zeros(num_tests, pop_size, device=self.device)
        for i in range(bits):
            result += sum_bits[:, :, i] * (1 << (bits - 1 - i))

        expected = ((a_vals - b_vals) & (max_val - 1)).unsqueeze(1).expand(-1, pop_size).float()
        correct = (result == expected).float().sum(0)

        # Detailed failures only in single-individual mode, capped at 20.
        failures = []
        if pop_size == 1:
            for i in range(min(num_tests, 20)):
                if result[i, 0].item() != expected[i, 0].item():
                    failures.append((
                        [int(a_vals[i].item()), int(b_vals[i].item())],
                        int(expected[i, 0].item()),
                        int(result[i, 0].item())
                    ))

        self._record(prefix, int(correct[0].item()), num_tests, failures)
        if debug:
            r = self.results[-1]
            print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")

        return correct, num_tests
|
|
|
|
|
def _test_bitwise_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test N-bit bitwise operations (AND, OR, XOR, NOT).""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print(f"\n=== {bits}-BIT BITWISE OPS ===") |
|
|
|
|
|
if bits == 32: |
|
|
test_pairs = [ |
|
|
(0xAAAAAAAA, 0x55555555), (0xFFFFFFFF, 0x00000000), |
|
|
(0x12345678, 0x87654321), (0xDEADBEEF, 0xCAFEBABE), |
|
|
(0x0F0F0F0F, 0xF0F0F0F0), (0, 0), (0xFFFFFFFF, 0xFFFFFFFF), |
|
|
] |
|
|
else: |
|
|
test_pairs = [(0xAA, 0x55), (0xFF, 0x00), (0x0F, 0xF0)] |
|
|
|
|
|
a_vals = torch.tensor([p[0] for p in test_pairs], device=self.device, dtype=torch.long) |
|
|
b_vals = torch.tensor([p[1] for p in test_pairs], device=self.device, dtype=torch.long) |
|
|
num_tests = len(test_pairs) |
|
|
|
|
|
ops = [ |
|
|
('and', lambda a, b: a & b), |
|
|
('or', lambda a, b: a | b), |
|
|
('xor', lambda a, b: a ^ b), |
|
|
] |
|
|
|
|
|
for op_name, op_fn in ops: |
|
|
try: |
|
|
result_bits = [] |
|
|
for bit in range(bits): |
|
|
a_bit = ((a_vals >> (bits - 1 - bit)) & 1).float() |
|
|
b_bit = ((b_vals >> (bits - 1 - bit)) & 1).float() |
|
|
|
|
|
if op_name == 'xor': |
|
|
prefix = f'alu.alu{bits}bit.{op_name}.bit{bit}' |
|
|
w_or = pop[f'{prefix}.layer1.or.weight'].view(pop_size, -1) |
|
|
b_or = pop[f'{prefix}.layer1.or.bias'].view(pop_size) |
|
|
w_nand = pop[f'{prefix}.layer1.nand.weight'].view(pop_size, -1) |
|
|
b_nand = pop[f'{prefix}.layer1.nand.bias'].view(pop_size) |
|
|
inp = torch.stack([a_bit, b_bit], dim=-1) |
|
|
h_or = heaviside(inp @ w_or.T + b_or) |
|
|
h_nand = heaviside(inp @ w_nand.T + b_nand) |
|
|
hidden = torch.stack([h_or, h_nand], dim=-1) |
|
|
w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, -1) |
|
|
b2 = pop[f'{prefix}.layer2.bias'].view(pop_size) |
|
|
out = heaviside((hidden * w2).sum(-1) + b2) |
|
|
else: |
|
|
w = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.weight'].view(pop_size, -1) |
|
|
b = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.bias'].view(pop_size) |
|
|
inp = torch.stack([a_bit, b_bit], dim=-1) |
|
|
out = heaviside(inp @ w.T + b) |
|
|
|
|
|
result_bits.append(out[:, 0] if out.dim() > 1 else out) |
|
|
|
|
|
result = sum(int(result_bits[i][j].item()) << (bits - 1 - i) |
|
|
for i in range(bits) for j in range(1)) |
|
|
results = torch.tensor([sum(int(result_bits[i][j].item()) << (bits - 1 - i) |
|
|
for i in range(bits)) for j in range(num_tests)], |
|
|
device=self.device) |
|
|
expected = torch.tensor([op_fn(a.item(), b.item()) for a, b in zip(a_vals, b_vals)], |
|
|
device=self.device) |
|
|
|
|
|
correct = (results == expected).float().sum() |
|
|
self._record(f'alu.alu{bits}bit.{op_name}', int(correct.item()), num_tests, []) |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
scores += correct |
|
|
total += num_tests |
|
|
except KeyError as e: |
|
|
if debug: |
|
|
print(f" alu.alu{bits}bit.{op_name}: SKIP (missing {e})") |
|
|
|
|
|
try: |
|
|
test_vals = a_vals |
|
|
result_bits = [] |
|
|
for bit in range(bits): |
|
|
a_bit = ((test_vals >> (bits - 1 - bit)) & 1).float() |
|
|
w = pop[f'alu.alu{bits}bit.not.bit{bit}.weight'].view(pop_size, -1) |
|
|
b = pop[f'alu.alu{bits}bit.not.bit{bit}.bias'].view(pop_size) |
|
|
out = heaviside(a_bit.unsqueeze(-1) @ w.T + b) |
|
|
result_bits.append(out[:, 0]) |
|
|
|
|
|
results = torch.tensor([sum(int(result_bits[i][j].item()) << (bits - 1 - i) |
|
|
for i in range(bits)) for j in range(num_tests)], |
|
|
device=self.device) |
|
|
expected = torch.tensor([(~a.item()) & ((1 << bits) - 1) for a in test_vals], |
|
|
device=self.device) |
|
|
|
|
|
correct = (results == expected).float().sum() |
|
|
self._record(f'alu.alu{bits}bit.not', int(correct.item()), num_tests, []) |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
scores += correct |
|
|
total += num_tests |
|
|
except KeyError as e: |
|
|
if debug: |
|
|
print(f" alu.alu{bits}bit.not: SKIP (missing {e})") |
|
|
|
|
|
return scores, total |
|
|
|
|
|
    def _test_shifts_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test N-bit shift operations (SHL, SHR).

        Shifting by one position is wired as routing: output bit i of SHL
        reads input bit i+1 (LSB side fills with 0), SHR reads bit i-1 (MSB
        side fills with 0); the per-bit gate is a scalar weight + bias
        identity neuron applied to the routed source bit.

        Args:
            pop: dict of circuit tensors, leading dim = population size.
            bits: operand width.
            debug: when True, print per-circuit PASS/FAIL/SKIP lines.

        Returns:
            (scores, total): accumulated correct-case counts and the number
            of test cases attempted (missing circuits are skipped).
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total = 0

        if debug:
            print(f"\n=== {bits}-BIT SHIFTS ===")

        if bits == 32:
            test_vals = [0x12345678, 0x80000001, 0x00000001, 0xFFFFFFFF, 0x55555555]
        else:
            test_vals = [0x81, 0x55, 0x01, 0xFF, 0xAA]

        a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long)
        num_tests = len(test_vals)
        max_val = (1 << bits) - 1

        # Reference: logical shift by one, masked to the operand width.
        for op_name, op_fn in [('shl', lambda x: (x << 1) & max_val), ('shr', lambda x: x >> 1)]:
            try:
                result_bits = []
                for bit in range(bits):
                    a_bit = ((a_vals >> (bits - 1 - bit)) & 1).float()
                    # Each gate holds a single scalar weight per individual.
                    # NOTE(review): `src_bit * w` broadcasts a (num_tests,)
                    # vector against a (pop_size,) vector — only shape-valid
                    # when pop_size == 1 (or equals num_tests); confirm this
                    # tester is single-individual only.
                    w = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.weight'].view(pop_size)
                    b = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.bias'].view(pop_size)

                    # Select the routed source bit (MSB-first indexing).
                    if op_name == 'shl':
                        if bit < bits - 1:
                            src_bit = ((a_vals >> (bits - 2 - bit)) & 1).float()
                        else:
                            src_bit = torch.zeros_like(a_bit)   # LSB fills with 0
                    else:
                        if bit > 0:
                            src_bit = ((a_vals >> (bits - bit)) & 1).float()
                        else:
                            src_bit = torch.zeros_like(a_bit)   # MSB fills with 0

                    out = heaviside(src_bit * w + b)
                    result_bits.append(out)

                # Recombine per-bit outputs (MSB-first) into integers.
                results = torch.tensor([sum(int(result_bits[i][j].item()) << (bits - 1 - i)
                                            for i in range(bits)) for j in range(num_tests)],
                                       device=self.device)
                expected = torch.tensor([op_fn(a.item()) for a in a_vals], device=self.device)

                correct = (results == expected).float().sum()
                self._record(f'alu.alu{bits}bit.{op_name}', int(correct.item()), num_tests, [])
                if debug:
                    r = self.results[-1]
                    print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
                scores += correct
                total += num_tests
            except KeyError as e:
                if debug:
                    print(f"  alu.alu{bits}bit.{op_name}: SKIP (missing {e})")

        return scores, total
|
|
|
|
|
    def _test_inc_dec_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test N-bit INC and DEC operations.

        Both are ripple chains processed LSB-first with an initial carry of 1:
        the result bit is XOR(a, carry) built from OR+NAND layers; INC then
        propagates carry = AND(a, carry) while DEC propagates a borrow
        computed from NOT(a) and the incoming carry.

        NOTE(review): weights are read with `.item()` / scalar indexing, so
        only population member 0 is evaluated regardless of pop_size —
        confirm this tester is meant for single-individual reporting only
        (the correct-count is still added to every individual's score).

        Args:
            pop: dict of circuit tensors, leading dim = population size.
            bits: operand width.
            debug: when True, print per-circuit PASS/FAIL/SKIP lines.

        Returns:
            (scores, total): accumulated correct-case counts and the number
            of test cases attempted (missing circuits are skipped).
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total = 0

        if debug:
            print(f"\n=== {bits}-BIT INC/DEC ===")

        # Wrap-around boundaries plus mid-range values.
        if bits == 32:
            test_vals = [0, 1, 0xFFFFFFFF, 0x7FFFFFFF, 0x80000000, 1000000, 0xFFFFFFFE]
        else:
            test_vals = [0, 1, 254, 255, 127, 128]

        a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long)
        num_tests = len(test_vals)
        max_val = (1 << bits) - 1

        for op_name, op_fn in [('inc', lambda x: (x + 1) & max_val), ('dec', lambda x: (x - 1) & max_val)]:
            try:
                # Carry-in of 1 seeds the +1 (or, through the borrow path, -1).
                carry = torch.ones(num_tests, device=self.device)
                result_bits = []

                for bit in range(bits):
                    # LSB-first here (unlike the MSB-first bitwise testers).
                    a_bit = ((a_vals >> bit) & 1).float()

                    prefix = f'alu.alu{bits}bit.{op_name}.bit{bit}'
                    w_or = pop[f'{prefix}.xor.layer1.or.weight'].flatten()
                    b_or = pop[f'{prefix}.xor.layer1.or.bias'].item()
                    w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].flatten()
                    b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].item()

                    h_or = heaviside(a_bit * w_or[0] + carry * w_or[1] + b_or)
                    h_nand = heaviside(a_bit * w_nand[0] + carry * w_nand[1] + b_nand)

                    # Result bit = XOR(a, carry) via AND(OR, NAND).
                    w2 = pop[f'{prefix}.xor.layer2.weight'].flatten()
                    b2 = pop[f'{prefix}.xor.layer2.bias'].item()
                    xor_out = heaviside(h_or * w2[0] + h_nand * w2[1] + b2)
                    result_bits.append(xor_out)

                    if op_name == 'inc':
                        # Carry out = AND(a, carry_in).
                        w_carry = pop[f'{prefix}.carry.weight'].flatten()
                        b_carry = pop[f'{prefix}.carry.bias'].item()
                        carry = heaviside(a_bit * w_carry[0] + carry * w_carry[1] + b_carry)
                    else:
                        # Borrow out = AND(NOT(a), carry_in).
                        w_not = pop[f'{prefix}.not_a.weight'].flatten()
                        b_not = pop[f'{prefix}.not_a.bias'].item()
                        not_a = heaviside(a_bit * w_not[0] + b_not)
                        w_borrow = pop[f'{prefix}.borrow.weight'].flatten()
                        b_borrow = pop[f'{prefix}.borrow.bias'].item()
                        carry = heaviside(not_a * w_borrow[0] + carry * w_borrow[1] + b_borrow)

                # Recombine LSB-first result bits into integers.
                results = torch.tensor([sum(int(result_bits[bit][j].item()) << bit
                                            for bit in range(bits)) for j in range(num_tests)],
                                       device=self.device)
                expected = torch.tensor([op_fn(a.item()) for a in a_vals], device=self.device)

                correct = (results == expected).float().sum()
                self._record(f'alu.alu{bits}bit.{op_name}', int(correct.item()), num_tests, [])
                if debug:
                    r = self.results[-1]
                    print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
                scores += correct
                total += num_tests
            except KeyError as e:
                if debug:
                    print(f"  alu.alu{bits}bit.{op_name}: SKIP (missing {e})")

        return scores, total
|
|
|
|
|
    def _test_neg_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test N-bit NEG operation (two's complement negation).

        NEG(x) = INC(NOT(x)): first invert every bit through the circuit's
        NOT gates, then ripple an increment chain (LSB-first, carry-in 1)
        using the same XOR/carry construction as `_test_inc_dec_nbits`.

        NOTE(review): weights are read with `.item()` / scalar indexing, so
        only population member 0 is evaluated; the success path also returns
        a shape-(1,) tensor rather than a per-individual (pop_size,) vector —
        confirm single-individual use (the KeyError path does return a
        pop_size-shaped zero vector).

        Args:
            pop: dict of circuit tensors, leading dim = population size.
            bits: operand width.
            debug: when True, print a PASS/FAIL/SKIP line.

        Returns:
            (correct, num_tests), or (zeros, 0) when circuit weights are
            missing from `pop`.
        """
        pop_size = next(iter(pop.values())).shape[0]

        if debug:
            print(f"\n=== {bits}-BIT NEG ===")

        if bits == 32:
            test_vals = [0, 1, 0xFFFFFFFF, 0x7FFFFFFF, 0x80000000, 1000, 1000000]
        else:
            test_vals = [0, 1, 127, 128, 255, 100]

        a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long)
        num_tests = len(test_vals)
        max_val = (1 << bits) - 1

        try:
            # Stage 1: bitwise NOT through the circuit (LSB-first order).
            not_bits = []
            for bit in range(bits):
                a_bit = ((a_vals >> bit) & 1).float()
                w = pop[f'alu.alu{bits}bit.neg.not.bit{bit}.weight'].flatten()
                b = pop[f'alu.alu{bits}bit.neg.not.bit{bit}.bias'].item()
                not_bits.append(heaviside(a_bit * w[0] + b))

            # Stage 2: increment the inverted value (carry-in 1 => +1).
            carry = torch.ones(num_tests, device=self.device)
            result_bits = []

            for bit in range(bits):
                prefix = f'alu.alu{bits}bit.neg.inc.bit{bit}'
                not_bit = not_bits[bit]

                w_or = pop[f'{prefix}.xor.layer1.or.weight'].flatten()
                b_or = pop[f'{prefix}.xor.layer1.or.bias'].item()
                w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].flatten()
                b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].item()

                h_or = heaviside(not_bit * w_or[0] + carry * w_or[1] + b_or)
                h_nand = heaviside(not_bit * w_nand[0] + carry * w_nand[1] + b_nand)

                # Result bit = XOR(not_bit, carry) via AND(OR, NAND).
                w2 = pop[f'{prefix}.xor.layer2.weight'].flatten()
                b2 = pop[f'{prefix}.xor.layer2.bias'].item()
                xor_out = heaviside(h_or * w2[0] + h_nand * w2[1] + b2)
                result_bits.append(xor_out)

                # Carry out = AND(not_bit, carry_in).
                w_carry = pop[f'{prefix}.carry.weight'].flatten()
                b_carry = pop[f'{prefix}.carry.bias'].item()
                carry = heaviside(not_bit * w_carry[0] + carry * w_carry[1] + b_carry)

            # Recombine LSB-first result bits into integers.
            results = torch.tensor([sum(int(result_bits[bit][j].item()) << bit
                                        for bit in range(bits)) for j in range(num_tests)],
                                   device=self.device)
            expected = torch.tensor([(-a.item()) & max_val for a in a_vals], device=self.device)

            correct = (results == expected).float().sum()
            self._record(f'alu.alu{bits}bit.neg', int(correct.item()), num_tests, [])
            if debug:
                r = self.results[-1]
                print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")

            return torch.tensor([correct], device=self.device), num_tests
        except KeyError as e:
            if debug:
                print(f"  alu.alu{bits}bit.neg: SKIP (missing {e})")
            return torch.zeros(pop_size, device=self.device), 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    def _test_threshold_kofn(self, pop: Dict, k: int, name: str, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test k-of-n threshold gate.

        The gate is a single threshold neuron over 8 input bits whose target
        is a function of the input popcount, selected by substring of `name`:
        'atleast*' => popcount >= k, 'atmost*'/'minority' => popcount <= k,
        'exactly*' => popcount == k, anything else => popcount >= k (the
        default also covers the plain k-of-8 and 'majority' gates).

        Args:
            pop: dict of circuit tensors, leading dim = population size.
            k: threshold value compared against the popcount.
            name: gate name under the `threshold.` namespace.
            debug: when True, print a PASS/FAIL line.

        Returns:
            (correct, num_cases): per-individual correct counts and the
            number of input vectors evaluated.

        Raises:
            KeyError: when the gate's weights are absent from `pop`
            (callers catch and skip).
        """
        pop_size = next(iter(pop.values())).shape[0]
        prefix = f'threshold.{name}'

        # Reuse a cached input set when available, otherwise sweep all 256
        # byte values (MSB-first bit layout).
        # NOTE(review): the `== 24` guard looks like it expects a specific
        # reduced fixture size — confirm 24 is the intended row count of
        # self.test_8bit_bits and not a stale magic number.
        inputs = self.test_8bit_bits if len(self.test_8bit_bits) == 24 else None
        if inputs is None:
            test_vals = torch.arange(256, device=self.device, dtype=torch.long)
            inputs = torch.stack([((test_vals >> (7 - i)) & 1).float() for i in range(8)], dim=1)

        # Number of set bits per input row drives the reference label.
        popcounts = inputs.sum(dim=1)

        if 'atleast' in name:
            expected = (popcounts >= k).float()
        elif 'atmost' in name or 'minority' in name:
            # 'minority' is scored as popcount <= k (caller passes k=3).
            expected = (popcounts <= k).float()
        elif 'exactly' in name:
            expected = (popcounts == k).float()
        else:
            # Plain k-of-8 gates and 'majority' default to >= k.
            expected = (popcounts >= k).float()

        w = pop[f'{prefix}.weight']
        b = pop[f'{prefix}.bias']
        out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size))

        correct = (out == expected.unsqueeze(1)).float().sum(0)

        # Detailed failures only in single-individual mode; reconstruct the
        # byte value from the MSB-first bit row for readable reporting.
        failures = []
        if pop_size == 1:
            for i in range(min(len(inputs), 256)):
                if out[i, 0].item() != expected[i].item():
                    val = int(sum(inputs[i, j].item() * (1 << (7 - j)) for j in range(8)))
                    failures.append((val, expected[i].item(), out[i, 0].item()))

        self._record(prefix, int(correct[0].item()), len(inputs), failures[:10])
        if debug:
            r = self.results[-1]
            print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")

        return correct, len(inputs)
|
|
|
|
|
def _test_threshold_gates(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test all threshold gates.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== THRESHOLD GATES ===") |
|
|
|
|
|
|
|
|
kofn_gates = [ |
|
|
(1, 'oneoutof8'), (2, 'twooutof8'), (3, 'threeoutof8'), (4, 'fouroutof8'), |
|
|
(5, 'fiveoutof8'), (6, 'sixoutof8'), (7, 'sevenoutof8'), (8, 'alloutof8'), |
|
|
] |
|
|
|
|
|
for k, name in kofn_gates: |
|
|
try: |
|
|
s, t = self._test_threshold_kofn(pop, k, name, debug) |
|
|
scores += s |
|
|
total += t |
|
|
except KeyError: |
|
|
pass |
|
|
|
|
|
|
|
|
special = [ |
|
|
(5, 'majority'), (3, 'minority'), |
|
|
(4, 'atleastk_4'), (4, 'atmostk_4'), (4, 'exactlyk_4'), |
|
|
] |
|
|
|
|
|
for k, name in special: |
|
|
try: |
|
|
s, t = self._test_threshold_kofn(pop, k, name, debug) |
|
|
scores += s |
|
|
total += t |
|
|
except KeyError: |
|
|
pass |
|
|
|
|
|
return scores, total |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    def _test_modular(self, pop: Dict, mod: int, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test modular divisibility circuit (multi-layer for non-powers-of-2).

        Output should be 1 iff the 8-bit input is divisible by `mod`, over
        all 256 values of `self.mod_test`.  Two circuit layouts are
        supported: a single threshold layer keyed `modular.mod{mod}.weight`,
        or a three-layer fallback that discovers an arbitrary number of
        layer1 geq{i}/leq{i} neurons, combines them into layer2 eq{i}
        interval detectors, and OR-reduces those via layer3.

        Args:
            pop: dict of circuit tensors, leading dim = population size.
            mod: modulus to test divisibility by.
            debug: when True, print a PASS/FAIL line.

        Returns:
            (correct, 256) on success; (zeros, 0) when the circuit is absent
            or the fallback evaluation fails (best-effort: any exception in
            the multi-layer path is swallowed and scored as skipped).
        """
        pop_size = next(iter(pop.values())).shape[0]
        prefix = f'modular.mod{mod}'

        # MSB-first bit decomposition of the test values; reference label is
        # divisibility of the raw value.
        inputs = torch.stack([((self.mod_test >> (7 - i)) & 1).float() for i in range(8)], dim=1)
        expected = ((self.mod_test % mod) == 0).float()

        # Fast path: single-layer circuit.
        try:
            w = pop[f'{prefix}.weight']
            b = pop[f'{prefix}.bias']
            out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size))
        except KeyError:
            # Multi-layer fallback (used for moduli that are not powers of 2).
            try:
                # Layer 1: probe for geq{i}/leq{i} neuron pairs until an
                # index has neither key.  Indices may be sparse within a
                # pair (only geq or only leq present).
                geq_outputs = {}
                leq_outputs = {}
                i = 0
                while True:
                    found = False
                    if f'{prefix}.layer1.geq{i}.weight' in pop:
                        w = pop[f'{prefix}.layer1.geq{i}.weight'].view(pop_size, -1)
                        b = pop[f'{prefix}.layer1.geq{i}.bias'].view(pop_size)
                        geq_outputs[i] = heaviside(inputs @ w.T + b)
                        found = True
                    if f'{prefix}.layer1.leq{i}.weight' in pop:
                        w = pop[f'{prefix}.layer1.leq{i}.weight'].view(pop_size, -1)
                        b = pop[f'{prefix}.layer1.leq{i}.bias'].view(pop_size)
                        leq_outputs[i] = heaviside(inputs @ w.T + b)
                        found = True
                    if not found:
                        break
                    i += 1

                if not geq_outputs and not leq_outputs:
                    return torch.zeros(pop_size, device=self.device), 0

                # Layer 2: each eq{i} ANDs the matching geq/leq pair; a
                # missing side is substituted with zeros.  NOTE(review): the
                # hard-coded 256 assumes self.mod_test covers all byte
                # values — confirm.
                eq_outputs = []
                i = 0
                while f'{prefix}.layer2.eq{i}.weight' in pop:
                    w = pop[f'{prefix}.layer2.eq{i}.weight'].view(pop_size, -1)
                    b = pop[f'{prefix}.layer2.eq{i}.bias'].view(pop_size)
                    eq_in = torch.stack([geq_outputs.get(i, torch.zeros(256, pop_size, device=self.device)),
                                         leq_outputs.get(i, torch.zeros(256, pop_size, device=self.device))], dim=-1)
                    eq_out = heaviside((eq_in * w).sum(-1) + b)
                    eq_outputs.append(eq_out)
                    i += 1

                if not eq_outputs:
                    return torch.zeros(pop_size, device=self.device), 0

                # Layer 3: OR over all interval detectors.
                eq_stack = torch.stack(eq_outputs, dim=-1)
                w3 = pop[f'{prefix}.layer3.or.weight'].view(pop_size, -1)
                b3 = pop[f'{prefix}.layer3.or.bias'].view(pop_size)
                out = heaviside((eq_stack * w3).sum(-1) + b3)

            except Exception as e:
                # Deliberate best-effort: malformed fallback circuits score 0
                # instead of aborting the whole evaluation run.
                return torch.zeros(pop_size, device=self.device), 0

        correct = (out == expected.unsqueeze(1)).float().sum(0)

        # Detailed failures only in single-individual mode (index doubles as
        # the input value when mod_test is arange(256)).
        failures = []
        if pop_size == 1:
            for i in range(256):
                if out[i, 0].item() != expected[i].item():
                    failures.append((i, expected[i].item(), out[i, 0].item()))

        self._record(prefix, int(correct[0].item()), 256, failures[:10])
        if debug:
            r = self.results[-1]
            print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")

        return correct, 256
|
|
|
|
|
def _test_modular_all(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test all modular arithmetic circuits.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== MODULAR ARITHMETIC ===") |
|
|
|
|
|
for mod in range(2, 13): |
|
|
s, t = self._test_modular(pop, mod, debug) |
|
|
scores += s |
|
|
total += t |
|
|
|
|
|
return scores, total |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _test_pattern(self, pop: Dict, name: str, expected_fn: Callable[[int], float], |
|
|
debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test pattern recognition circuit.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
prefix = f'pattern_recognition.{name}' |
|
|
|
|
|
test_vals = torch.arange(256, device=self.device, dtype=torch.long) |
|
|
inputs = torch.stack([((test_vals >> (7 - i)) & 1).float() for i in range(8)], dim=1) |
|
|
expected = torch.tensor([expected_fn(v.item()) for v in test_vals], device=self.device) |
|
|
|
|
|
try: |
|
|
w = pop[f'{prefix}.weight'].view(pop_size, -1) |
|
|
b = pop[f'{prefix}.bias'].view(pop_size) |
|
|
out = heaviside(inputs @ w.T + b) |
|
|
except KeyError: |
|
|
return torch.zeros(pop_size, device=self.device), 0 |
|
|
|
|
|
correct = (out == expected.unsqueeze(1)).float().sum(0) |
|
|
|
|
|
failures = [] |
|
|
if pop_size == 1: |
|
|
for i in range(256): |
|
|
if out[i, 0].item() != expected[i].item(): |
|
|
failures.append((i, expected[i].item(), out[i, 0].item())) |
|
|
|
|
|
self._record(prefix, int(correct[0].item()), 256, failures[:10]) |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
|
|
|
return correct, 256 |
|
|
|
|
|
def _test_patterns(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test pattern recognition circuits.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== PATTERN RECOGNITION ===") |
|
|
|
|
|
|
|
|
patterns = [ |
|
|
('allzeros', lambda v: 1.0 if v == 0 else 0.0), |
|
|
('allones', lambda v: 1.0 if v == 255 else 0.0), |
|
|
] |
|
|
|
|
|
for name, fn in patterns: |
|
|
s, t = self._test_pattern(pop, name, fn, debug) |
|
|
scores += s |
|
|
total += t |
|
|
|
|
|
return scores, total |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _eval_xor_tree_stage(self, pop: Dict, prefix: str, stage: int, idx: int, |
|
|
a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: |
|
|
"""Evaluate a single XOR in the parity tree.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
xor_prefix = f'{prefix}.stage{stage}.xor{idx}' |
|
|
|
|
|
|
|
|
if a.dim() == 1: |
|
|
a = a.unsqueeze(1).expand(-1, pop_size) |
|
|
if b.dim() == 1: |
|
|
b = b.unsqueeze(1).expand(-1, pop_size) |
|
|
|
|
|
|
|
|
w_or = pop[f'{xor_prefix}.layer1.or.weight'].view(pop_size, 2) |
|
|
b_or = pop[f'{xor_prefix}.layer1.or.bias'].view(pop_size) |
|
|
w_nand = pop[f'{xor_prefix}.layer1.nand.weight'].view(pop_size, 2) |
|
|
b_nand = pop[f'{xor_prefix}.layer1.nand.bias'].view(pop_size) |
|
|
|
|
|
inputs = torch.stack([a, b], dim=-1) |
|
|
h_or = heaviside((inputs * w_or).sum(-1) + b_or) |
|
|
h_nand = heaviside((inputs * w_nand).sum(-1) + b_nand) |
|
|
|
|
|
|
|
|
hidden = torch.stack([h_or, h_nand], dim=-1) |
|
|
w2 = pop[f'{xor_prefix}.layer2.weight'].view(pop_size, 2) |
|
|
b2 = pop[f'{xor_prefix}.layer2.bias'].view(pop_size) |
|
|
return heaviside((hidden * w2).sum(-1) + b2) |
|
|
|
|
|
def _test_parity_xor_tree(self, pop: Dict, prefix: str, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test parity circuit with XOR tree structure.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
|
|
|
test_vals = torch.arange(256, device=self.device, dtype=torch.long) |
|
|
inputs = torch.stack([((test_vals >> (7 - i)) & 1).float() for i in range(8)], dim=1) |
|
|
|
|
|
|
|
|
popcounts = inputs.sum(dim=1) |
|
|
xor_result = (popcounts.long() % 2).float() |
|
|
|
|
|
try: |
|
|
|
|
|
s1_out = [] |
|
|
for i in range(4): |
|
|
xor_out = self._eval_xor_tree_stage(pop, prefix, 1, i, inputs[:, i*2], inputs[:, i*2+1]) |
|
|
s1_out.append(xor_out) |
|
|
|
|
|
|
|
|
s2_out = [] |
|
|
for i in range(2): |
|
|
xor_out = self._eval_xor_tree_stage(pop, prefix, 2, i, s1_out[i*2], s1_out[i*2+1]) |
|
|
s2_out.append(xor_out) |
|
|
|
|
|
|
|
|
s3_out = self._eval_xor_tree_stage(pop, prefix, 3, 0, s2_out[0], s2_out[1]) |
|
|
|
|
|
|
|
|
if f'{prefix}.output.not.weight' in pop: |
|
|
w_not = pop[f'{prefix}.output.not.weight'].view(pop_size) |
|
|
b_not = pop[f'{prefix}.output.not.bias'].view(pop_size) |
|
|
out = heaviside(s3_out * w_not + b_not) |
|
|
|
|
|
expected = 1.0 - xor_result |
|
|
else: |
|
|
out = s3_out |
|
|
expected = xor_result |
|
|
|
|
|
except KeyError as e: |
|
|
return torch.zeros(pop_size, device=self.device), 0 |
|
|
|
|
|
correct = (out == expected.unsqueeze(1)).float().sum(0) |
|
|
|
|
|
failures = [] |
|
|
if pop_size == 1: |
|
|
for i in range(256): |
|
|
if out[i, 0].item() != expected[i].item(): |
|
|
failures.append((i, expected[i].item(), out[i, 0].item())) |
|
|
|
|
|
self._record(prefix, int(correct[0].item()), 256, failures[:10]) |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
|
|
|
return correct, 256 |
|
|
|
|
|
def _test_error_detection(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test error detection circuits.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== ERROR DETECTION ===") |
|
|
|
|
|
|
|
|
for prefix in ['error_detection.paritychecker8bit', 'error_detection.paritygenerator8bit']: |
|
|
s, t = self._test_parity_xor_tree(pop, prefix, debug) |
|
|
scores += s |
|
|
total += t |
|
|
|
|
|
return scores, total |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _test_mux2to1(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test 2-to-1 multiplexer.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
prefix = 'combinational.multiplexer2to1' |
|
|
|
|
|
|
|
|
inputs = torch.tensor([ |
|
|
[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], |
|
|
[1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1], |
|
|
], device=self.device, dtype=torch.float32) |
|
|
expected = torch.tensor([0, 0, 0, 1, 1, 0, 1, 1], device=self.device, dtype=torch.float32) |
|
|
|
|
|
try: |
|
|
w = pop[f'{prefix}.weight'] |
|
|
b = pop[f'{prefix}.bias'] |
|
|
out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size)) |
|
|
except KeyError: |
|
|
return torch.zeros(pop_size, device=self.device), 0 |
|
|
|
|
|
correct = (out == expected.unsqueeze(1)).float().sum(0) |
|
|
|
|
|
failures = [] |
|
|
if pop_size == 1: |
|
|
for i in range(8): |
|
|
if out[i, 0].item() != expected[i].item(): |
|
|
failures.append((inputs[i].tolist(), expected[i].item(), out[i, 0].item())) |
|
|
|
|
|
self._record(prefix, int(correct[0].item()), 8, failures) |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
|
|
|
return correct, 8 |
|
|
|
|
|
def _test_decoder3to8(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test 3-to-8 decoder.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== DECODER 3-TO-8 ===") |
|
|
|
|
|
inputs = torch.tensor([ |
|
|
[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], |
|
|
[1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1], |
|
|
], device=self.device, dtype=torch.float32) |
|
|
|
|
|
for out_idx in range(8): |
|
|
prefix = f'combinational.decoder3to8.out{out_idx}' |
|
|
expected = torch.zeros(8, device=self.device) |
|
|
expected[out_idx] = 1.0 |
|
|
|
|
|
try: |
|
|
w = pop[f'{prefix}.weight'] |
|
|
b = pop[f'{prefix}.bias'] |
|
|
out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size)) |
|
|
except KeyError: |
|
|
continue |
|
|
|
|
|
correct = (out == expected.unsqueeze(1)).float().sum(0) |
|
|
scores += correct |
|
|
total += 8 |
|
|
|
|
|
failures = [] |
|
|
if pop_size == 1: |
|
|
for i in range(8): |
|
|
if out[i, 0].item() != expected[i].item(): |
|
|
failures.append((inputs[i].tolist(), expected[i].item(), out[i, 0].item())) |
|
|
|
|
|
self._record(prefix, int(correct[0].item()), 8, failures) |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
|
|
|
return scores, total |
|
|
|
|
|
def _test_combinational(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test combinational logic circuits.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== COMBINATIONAL LOGIC ===") |
|
|
|
|
|
s, t = self._test_mux2to1(pop, debug) |
|
|
scores += s |
|
|
total += t |
|
|
|
|
|
s, t = self._test_decoder3to8(pop, debug) |
|
|
scores += s |
|
|
total += t |
|
|
|
|
|
s, t = self._test_barrel_shifter(pop, debug) |
|
|
scores += s |
|
|
total += t |
|
|
|
|
|
s, t = self._test_priority_encoder(pop, debug) |
|
|
scores += s |
|
|
total += t |
|
|
|
|
|
return scores, total |
|
|
|
|
|
    def _test_barrel_shifter(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test barrel shifter (shift by 0-7 positions).

        Drives the 3-layer mux network bit-by-bit in a Python loop. Each
        gate's output is collapsed with ``[0].item()``, so the ripple is
        driven by population member 0 only; ``scores += 1`` then broadcasts
        member 0's pass/fail to the whole population.

        Returns (per-member correct counts, total shift cases attempted).
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total = 0

        if debug:
            print("\n=== BARREL SHIFTER ===")

        try:
            # Probe values: edge bits, contiguous blocks, alternating bits, all-ones.
            test_vals = [0b10000001, 0b11110000, 0b00001111, 0b10101010, 0xFF]

            for val in test_vals:
                for shift in range(8):
                    # Left shift; bits shifted out of the byte are dropped.
                    expected_val = (val << shift) & 0xFF
                    # MSB-first decompositions of the value and shift amount.
                    val_bits = [float((val >> (7 - i)) & 1) for i in range(8)]
                    shift_bits = [float((shift >> (2 - i)) & 1) for i in range(3)]

                    # Three mux layers shifting by 4, 2, then 1 position,
                    # each selected by one bit of the shift amount (MSB first).
                    layer_in = val_bits[:]
                    for layer in range(3):
                        shift_amount = 1 << (2 - layer)
                        sel = shift_bits[layer]
                        layer_out = []

                        for bit in range(8):
                            prefix = f'combinational.barrelshifter.layer{layer}.bit{bit}'

                            # NOT of this layer's select bit.
                            w_not = pop[f'{prefix}.not_sel.weight'].view(pop_size)
                            b_not = pop[f'{prefix}.not_sel.bias'].view(pop_size)
                            not_sel = heaviside(sel * w_not + b_not)

                            # Shifted-branch source: `shift_amount` positions
                            # toward the LSB; zero-filled past the edge.
                            shifted_src = bit + shift_amount
                            if shifted_src < 8:
                                shifted_val = layer_in[shifted_src]
                            else:
                                shifted_val = 0.0

                            # Mux branch A: pass-through gated by NOT(sel).
                            w_and_a = pop[f'{prefix}.and_a.weight'].view(pop_size, 2)
                            b_and_a = pop[f'{prefix}.and_a.bias'].view(pop_size)
                            inp_a = torch.tensor([layer_in[bit], not_sel[0].item()], device=self.device)
                            and_a = heaviside((inp_a * w_and_a).sum(-1) + b_and_a)

                            # Mux branch B: shifted source gated by sel.
                            w_and_b = pop[f'{prefix}.and_b.weight'].view(pop_size, 2)
                            b_and_b = pop[f'{prefix}.and_b.bias'].view(pop_size)
                            inp_b = torch.tensor([shifted_val, sel], device=self.device)
                            and_b = heaviside((inp_b * w_and_b).sum(-1) + b_and_b)

                            # OR of the two branches completes the 2:1 mux.
                            w_or = pop[f'{prefix}.or.weight'].view(pop_size, 2)
                            b_or = pop[f'{prefix}.or.bias'].view(pop_size)
                            inp_or = torch.tensor([and_a[0].item(), and_b[0].item()], device=self.device)
                            out = heaviside((inp_or * w_or).sum(-1) + b_or)
                            layer_out.append(out[0].item())

                        layer_in = layer_out

                    # Reassemble the MSB-first bit list into an integer.
                    result = sum(int(layer_in[i]) << (7 - i) for i in range(8))
                    if result == expected_val:
                        scores += 1
                    total += 1

            self._record('combinational.barrelshifter', int(scores[0].item()), total, [])
            if debug:
                r = self.results[-1]
                print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            # Missing weights (or shape mismatch) skip the whole circuit.
            if debug:
                print(f"  combinational.barrelshifter: SKIP ({e})")

        return scores, total
|
|
|
|
|
    def _test_priority_encoder(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test priority encoder (find highest set bit).

        Checks the `valid` output on every case and, when the input is
        non-zero, each of the 3 index output bits. Results are read from
        population member 0 (`[0].item()`) and broadcast to every member
        via ``scores += 1``.

        Returns (per-member correct counts, total cases attempted).
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total = 0

        if debug:
            print("\n=== PRIORITY ENCODER ===")

        try:
            # (input byte, expected valid flag, expected index of the highest
            #  set bit, counting the MSB as index 0).
            test_cases = [
                (0b00000000, 0, 0),
                (0b00000001, 1, 7),
                (0b00000010, 1, 6),
                (0b00000100, 1, 5),
                (0b00001000, 1, 4),
                (0b00010000, 1, 3),
                (0b00100000, 1, 2),
                (0b01000000, 1, 1),
                (0b10000000, 1, 0),
                (0b10000001, 1, 0),
                (0b01010101, 1, 1),
                (0b00001111, 1, 4),
                (0b11111111, 1, 0),
            ]

            for val, expected_valid, expected_idx in test_cases:
                # MSB-first bit decomposition of the input byte.
                val_bits = torch.tensor([float((val >> (7 - i)) & 1) for i in range(8)],
                                        device=self.device, dtype=torch.float32)

                # `valid` is a single 8-input unit expected to fire iff any
                # input bit is set.
                w_valid = pop['combinational.priorityencoder.valid.weight'].view(pop_size, 8)
                b_valid = pop['combinational.priorityencoder.valid.bias'].view(pop_size)
                out_valid = heaviside((val_bits * w_valid).sum(-1) + b_valid)

                if int(out_valid[0].item()) == expected_valid:
                    scores += 1
                total += 1

                # Index bits are only checked when some input bit is set.
                if expected_valid == 1:
                    for idx_bit in range(3):
                        try:
                            w_idx = pop[f'combinational.priorityencoder.idx{idx_bit}.weight'].view(pop_size, 8)
                            b_idx = pop[f'combinational.priorityencoder.idx{idx_bit}.bias'].view(pop_size)
                            out_idx = heaviside((val_bits * w_idx).sum(-1) + b_idx)
                            # idx0 is the MSB of the 3-bit index.
                            expected_bit = (expected_idx >> (2 - idx_bit)) & 1
                            if int(out_idx[0].item()) == expected_bit:
                                scores += 1
                            total += 1
                        except KeyError:
                            # Missing index outputs are simply not counted.
                            pass

            self._record('combinational.priorityencoder', int(scores[0].item()), total, [])
            if debug:
                r = self.results[-1]
                print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            # Missing `valid` weights (or shape mismatch) skip the circuit.
            if debug:
                print(f"  combinational.priorityencoder: SKIP ({e})")

        return scores, total
|
|
|
|
|
    def _test_barrel_shifter_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test N-bit barrel shifter (shift by 0 to bits-1 positions).

        N-bit generalization of the 8-bit shifter test: ceil(log2(bits))
        mux layers, each selected by one bit of the shift amount (MSB
        first). Gate outputs are collapsed with ``[0].item()``, so the
        ripple follows population member 0 and ``scores += 1`` broadcasts
        its result to every member.

        Returns (per-member correct counts, total shift cases attempted).
        """
        import math
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total = 0
        num_layers = max(1, math.ceil(math.log2(bits)))
        max_val = (1 << bits) - 1

        if debug:
            print(f"\n=== {bits}-BIT BARREL SHIFTER ===")

        prefix = f'combinational.barrelshifter{bits}'
        try:
            # Width-specific probe values (edge bits, blocks, alternating bits).
            if bits == 16:
                test_vals = [0x8001, 0xFF00, 0x00FF, 0xAAAA, 0xFFFF, 0x1234]
            elif bits == 32:
                test_vals = [0x80000001, 0xFFFF0000, 0x0000FFFF, 0xAAAAAAAA, 0xFFFFFFFF, 0x12345678]
            else:
                test_vals = [0b10000001, 0b11110000, 0b00001111, 0b10101010, max_val]

            # Only shifts 0..min(bits,8)-1 are exercised (presumably to bound
            # runtime for wide shifters — TODO confirm intent).
            num_shifts = min(bits, 8)
            for val in test_vals:
                for shift in range(num_shifts):
                    # Left shift; bits shifted past the width are dropped.
                    expected_val = (val << shift) & max_val
                    val_bits = [float((val >> (bits - 1 - i)) & 1) for i in range(bits)]
                    shift_bits = [float((shift >> (num_layers - 1 - i)) & 1) for i in range(num_layers)]

                    layer_in = val_bits[:]
                    for layer in range(num_layers):
                        # Layer `layer` shifts by 2^(num_layers-1-layer) when selected.
                        shift_amount = 1 << (num_layers - 1 - layer)
                        sel = shift_bits[layer]
                        layer_out = []

                        for bit in range(bits):
                            bit_prefix = f'{prefix}.layer{layer}.bit{bit}'

                            # NOT of this layer's select bit.
                            w_not = pop[f'{bit_prefix}.not_sel.weight'].view(pop_size)
                            b_not = pop[f'{bit_prefix}.not_sel.bias'].view(pop_size)
                            not_sel = heaviside(sel * w_not + b_not)

                            # Shifted-branch source; zero-filled past the edge.
                            shifted_src = bit + shift_amount
                            if shifted_src < bits:
                                shifted_val = layer_in[shifted_src]
                            else:
                                shifted_val = 0.0

                            # Mux branch A: pass-through gated by NOT(sel).
                            w_and_a = pop[f'{bit_prefix}.and_a.weight'].view(pop_size, 2)
                            b_and_a = pop[f'{bit_prefix}.and_a.bias'].view(pop_size)
                            inp_a = torch.tensor([layer_in[bit], not_sel[0].item()], device=self.device)
                            and_a = heaviside((inp_a * w_and_a).sum(-1) + b_and_a)

                            # Mux branch B: shifted source gated by sel.
                            w_and_b = pop[f'{bit_prefix}.and_b.weight'].view(pop_size, 2)
                            b_and_b = pop[f'{bit_prefix}.and_b.bias'].view(pop_size)
                            inp_b = torch.tensor([shifted_val, sel], device=self.device)
                            and_b = heaviside((inp_b * w_and_b).sum(-1) + b_and_b)

                            # OR of the two branches completes the 2:1 mux.
                            w_or = pop[f'{bit_prefix}.or.weight'].view(pop_size, 2)
                            b_or = pop[f'{bit_prefix}.or.bias'].view(pop_size)
                            inp_or = torch.tensor([and_a[0].item(), and_b[0].item()], device=self.device)
                            out = heaviside((inp_or * w_or).sum(-1) + b_or)
                            layer_out.append(out[0].item())

                        layer_in = layer_out

                    # Reassemble the MSB-first bit list into an integer.
                    result = sum(int(layer_in[i]) << (bits - 1 - i) for i in range(bits))
                    if result == expected_val:
                        scores += 1
                    total += 1

            self._record(prefix, int(scores[0].item()), total, [])
            if debug:
                r = self.results[-1]
                print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            # Missing weights (or shape mismatch) skip the whole circuit.
            if debug:
                print(f"  {prefix}: SKIP ({e})")

        return scores, total
|
|
|
|
|
    def _test_priority_encoder_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test N-bit priority encoder (find highest set bit).

        The priority encoder is a multi-layer circuit:
        1. any_higher{pos}: OR of bits 0 to pos-1 (all higher-priority positions)
        2. is_highest{0}: bit[0] directly (MSB is always highest if set)
        3. is_highest{pos}: bit[pos] AND NOT(any_higher{pos}) for pos > 0
        4. out{bit}: OR of is_highest{pos} for all pos where (pos >> bit) & 1
        5. valid: OR of all input bits

        Scoring reads member 0 of each output (`[0].item()`) and broadcasts
        the pass/fail to the whole population via ``scores += 1``.

        Returns (per-member correct counts, total cases attempted).
        """
        import math
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total = 0
        # Number of index output bits needed to encode positions 0..bits-1.
        out_bits = max(1, math.ceil(math.log2(bits)))

        if debug:
            print(f"\n=== {bits}-BIT PRIORITY ENCODER ===")

        prefix = f'combinational.priorityencoder{bits}'
        try:
            # One-hot cases for every bit position, plus width-specific mixes.
            test_cases = [(0, 0, 0)]
            for i in range(bits):
                test_cases.append((1 << i, 1, bits - 1 - i))
            if bits == 16:
                test_cases.extend([
                    (0x8001, 1, 0), (0x5555, 1, 1), (0x00FF, 1, 8), (0xFFFF, 1, 0)
                ])
            elif bits == 32:
                test_cases.extend([
                    (0x80000001, 1, 0), (0x55555555, 1, 1), (0x0000FFFF, 1, 16), (0xFFFFFFFF, 1, 0)
                ])

            for val, expected_valid, expected_idx in test_cases:
                # MSB-first decomposition; position 0 is the MSB.
                val_bits = torch.tensor([float((val >> (bits - 1 - i)) & 1) for i in range(bits)],
                                        device=self.device, dtype=torch.float32)

                # Layer 5: `valid` is expected to fire iff any input bit is set.
                w_valid = pop[f'{prefix}.valid.weight'].view(pop_size, bits)
                b_valid = pop[f'{prefix}.valid.bias'].view(pop_size)
                out_valid = heaviside((val_bits * w_valid).sum(-1) + b_valid)

                if int(out_valid[0].item()) == expected_valid:
                    scores += 1
                total += 1

                if expected_valid == 1:
                    # Layer 1: any_higher{pos} over the pos higher-priority bits.
                    any_higher = [None]  # position 0 has no higher-priority bits
                    for pos in range(1, bits):
                        w = pop[f'{prefix}.any_higher{pos}.weight'].view(pop_size, -1)
                        b = pop[f'{prefix}.any_higher{pos}.bias'].view(pop_size)
                        inp = val_bits[:pos]
                        # Weight rows may be wider than pos; truncate to match.
                        out = heaviside((inp * w[:, :len(inp)]).sum(-1) + b)
                        any_higher.append(out)

                    # Layers 2-3: is_highest{pos}.
                    is_highest = []
                    for pos in range(bits):
                        if pos == 0:
                            # MSB: highest-priority whenever it is set (no gate).
                            is_high = val_bits[0].unsqueeze(0).expand(pop_size)
                        else:
                            w_not = pop[f'{prefix}.is_highest{pos}.not_higher.weight'].view(pop_size, -1)
                            b_not = pop[f'{prefix}.is_highest{pos}.not_higher.bias'].view(pop_size)
                            not_higher = heaviside(any_higher[pos].unsqueeze(-1) * w_not + b_not).squeeze(-1)

                            w_and = pop[f'{prefix}.is_highest{pos}.and.weight'].view(pop_size, -1)
                            b_and = pop[f'{prefix}.is_highest{pos}.and.bias'].view(pop_size)
                            inp = torch.stack([val_bits[pos].expand(pop_size), not_higher], dim=-1)
                            is_high = heaviside((inp * w_and).sum(-1) + b_and)
                        is_highest.append(is_high)

                    # Layer 4: each index bit ORs the is_highest lines whose
                    # position has that bit set (idx_bit 0 = LSB of the index).
                    for idx_bit in range(out_bits):
                        try:
                            w_idx = pop[f'{prefix}.out{idx_bit}.weight'].view(pop_size, -1)
                            b_idx = pop[f'{prefix}.out{idx_bit}.bias'].view(pop_size)
                            relevant = [is_highest[pos] for pos in range(bits) if (pos >> idx_bit) & 1]
                            if len(relevant) > 0:
                                # Truncate to the unit's fan-in.
                                inp = torch.stack(relevant[:w_idx.shape[1]], dim=-1)
                                out_idx = heaviside((inp * w_idx).sum(-1) + b_idx)
                                expected_bit = (expected_idx >> idx_bit) & 1
                                if int(out_idx[0].item()) == expected_bit:
                                    scores += 1
                                total += 1
                        except KeyError:
                            # Missing index outputs are simply not counted.
                            pass

            self._record(prefix, int(scores[0].item()), total, [])
            if debug:
                r = self.results[-1]
                print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            # Missing weights (or shape mismatch) skip the whole circuit.
            if debug:
                print(f"  {prefix}: SKIP ({e})")

        return scores, total
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _test_conditional_jump(self, pop: Dict, name: str, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test conditional jump circuit (N-bit address aware).""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
prefix = f'control.{name}' |
|
|
|
|
|
|
|
|
inputs = torch.tensor([ |
|
|
[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], |
|
|
[1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1], |
|
|
], device=self.device, dtype=torch.float32) |
|
|
expected = torch.tensor([0, 0, 0, 1, 1, 0, 1, 1], device=self.device, dtype=torch.float32) |
|
|
|
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
for bit in range(self.addr_bits): |
|
|
bit_prefix = f'{prefix}.bit{bit}' |
|
|
try: |
|
|
|
|
|
w_not = pop[f'{bit_prefix}.not_sel.weight'] |
|
|
b_not = pop[f'{bit_prefix}.not_sel.bias'] |
|
|
flag = inputs[:, 2:3] |
|
|
not_sel = heaviside(flag @ w_not.view(pop_size, -1).T + b_not.view(pop_size)) |
|
|
|
|
|
|
|
|
w_and_a = pop[f'{bit_prefix}.and_a.weight'] |
|
|
b_and_a = pop[f'{bit_prefix}.and_a.bias'] |
|
|
pc_not = torch.cat([inputs[:, 0:1], not_sel], dim=-1) |
|
|
and_a = heaviside((pc_not * w_and_a.view(pop_size, 1, 2)).sum(-1) + b_and_a.view(pop_size, 1)) |
|
|
|
|
|
|
|
|
w_and_b = pop[f'{bit_prefix}.and_b.weight'] |
|
|
b_and_b = pop[f'{bit_prefix}.and_b.bias'] |
|
|
target_sel = inputs[:, 1:3] |
|
|
and_b = heaviside((target_sel * w_and_b.view(pop_size, 1, 2)).sum(-1) + b_and_b.view(pop_size, 1)) |
|
|
|
|
|
|
|
|
w_or = pop[f'{bit_prefix}.or.weight'] |
|
|
b_or = pop[f'{bit_prefix}.or.bias'] |
|
|
|
|
|
and_a_2d = and_a.view(8, pop_size) |
|
|
and_b_2d = and_b.view(8, pop_size) |
|
|
ab = torch.stack([and_a_2d, and_b_2d], dim=-1) |
|
|
out = heaviside((ab * w_or.view(pop_size, 2)).sum(-1) + b_or.view(pop_size)) |
|
|
|
|
|
correct = (out == expected.unsqueeze(1)).float().sum(0) |
|
|
scores += correct |
|
|
total += 8 |
|
|
|
|
|
except KeyError: |
|
|
pass |
|
|
|
|
|
if total > 0: |
|
|
self._record(prefix, int((scores[0] / total * total).item()), total, []) |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
|
|
|
return scores, total |
|
|
|
|
|
def _test_control_flow(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test control flow circuits.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== CONTROL FLOW ===") |
|
|
|
|
|
jumps = ['jz', 'jnz', 'jc', 'jnc', 'jn', 'jp', 'jv', 'jnv', 'conditionaljump'] |
|
|
for name in jumps: |
|
|
s, t = self._test_conditional_jump(pop, name, debug) |
|
|
scores += s |
|
|
total += t |
|
|
|
|
|
|
|
|
s, t = self._test_stack_ops(pop, debug) |
|
|
scores += s |
|
|
total += t |
|
|
|
|
|
return scores, total |
|
|
|
|
|
def _test_stack_ops(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test PUSH/POP/RET stack operation circuits (N-bit address aware).""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
addr_bits = self.addr_bits |
|
|
addr_mask = (1 << addr_bits) - 1 |
|
|
|
|
|
if debug: |
|
|
print(f"\n=== STACK OPERATIONS ({addr_bits}-bit SP) ===") |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
sp_tests = [0, 1, addr_mask // 2, addr_mask] |
|
|
if addr_bits >= 8: |
|
|
sp_tests.append(0x100 & addr_mask) |
|
|
if addr_bits >= 12: |
|
|
sp_tests.append(0x1234 & addr_mask) |
|
|
op_scores = torch.zeros(pop_size, device=self.device) |
|
|
op_total = 0 |
|
|
|
|
|
for sp_val in sp_tests: |
|
|
expected_val = (sp_val - 1) & addr_mask |
|
|
sp_bits = [float((sp_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)] |
|
|
|
|
|
borrow = 1.0 |
|
|
out_bits = [] |
|
|
for bit in range(addr_bits - 1, -1, -1): |
|
|
prefix = f'control.push.sp_dec.bit{bit}' |
|
|
|
|
|
w_or = pop[f'{prefix}.xor.layer1.or.weight'].view(pop_size, 2) |
|
|
b_or = pop[f'{prefix}.xor.layer1.or.bias'].view(pop_size) |
|
|
w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].view(pop_size, 2) |
|
|
b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].view(pop_size) |
|
|
w2 = pop[f'{prefix}.xor.layer2.weight'].view(pop_size, 2) |
|
|
b2 = pop[f'{prefix}.xor.layer2.bias'].view(pop_size) |
|
|
|
|
|
inp = torch.tensor([sp_bits[bit], borrow], device=self.device) |
|
|
h_or = heaviside((inp * w_or).sum(-1) + b_or) |
|
|
h_nand = heaviside((inp * w_nand).sum(-1) + b_nand) |
|
|
hidden = torch.stack([h_or, h_nand], dim=-1) |
|
|
diff_bit = heaviside((hidden * w2).sum(-1) + b2) |
|
|
out_bits.insert(0, diff_bit) |
|
|
|
|
|
|
|
|
not_sp = 1.0 - sp_bits[bit] |
|
|
w_borrow = pop[f'{prefix}.borrow.weight'].view(pop_size, 2) |
|
|
b_borrow = pop[f'{prefix}.borrow.bias'].view(pop_size) |
|
|
borrow_inp = torch.tensor([not_sp, borrow], device=self.device) |
|
|
borrow = heaviside((borrow_inp * w_borrow).sum(-1) + b_borrow)[0].item() |
|
|
|
|
|
out = torch.stack(out_bits, dim=-1) |
|
|
expected = torch.tensor([((expected_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)], |
|
|
device=self.device, dtype=torch.float32) |
|
|
correct = (out == expected.unsqueeze(0)).float().sum(1) |
|
|
op_scores += correct |
|
|
op_total += addr_bits |
|
|
|
|
|
scores += op_scores |
|
|
total += op_total |
|
|
self._record('control.push.sp_dec', int(op_scores[0].item()), op_total, []) |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
except (KeyError, RuntimeError) as e: |
|
|
if debug: |
|
|
print(f" control.push.sp_dec: SKIP ({e})") |
|
|
|
|
|
|
|
|
try: |
|
|
op_scores = torch.zeros(pop_size, device=self.device) |
|
|
op_total = 0 |
|
|
|
|
|
for sp_val in sp_tests: |
|
|
expected_val = (sp_val + 1) & addr_mask |
|
|
sp_bits = [float((sp_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)] |
|
|
|
|
|
carry = 1.0 |
|
|
out_bits = [] |
|
|
for bit in range(addr_bits - 1, -1, -1): |
|
|
prefix = f'control.pop.sp_inc.bit{bit}' |
|
|
|
|
|
w_or = pop[f'{prefix}.xor.layer1.or.weight'].view(pop_size, 2) |
|
|
b_or = pop[f'{prefix}.xor.layer1.or.bias'].view(pop_size) |
|
|
w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].view(pop_size, 2) |
|
|
b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].view(pop_size) |
|
|
w2 = pop[f'{prefix}.xor.layer2.weight'].view(pop_size, 2) |
|
|
b2 = pop[f'{prefix}.xor.layer2.bias'].view(pop_size) |
|
|
|
|
|
inp = torch.tensor([sp_bits[bit], carry], device=self.device) |
|
|
h_or = heaviside((inp * w_or).sum(-1) + b_or) |
|
|
h_nand = heaviside((inp * w_nand).sum(-1) + b_nand) |
|
|
hidden = torch.stack([h_or, h_nand], dim=-1) |
|
|
sum_bit = heaviside((hidden * w2).sum(-1) + b2) |
|
|
out_bits.insert(0, sum_bit) |
|
|
|
|
|
|
|
|
w_carry = pop[f'{prefix}.carry.weight'].view(pop_size, 2) |
|
|
b_carry = pop[f'{prefix}.carry.bias'].view(pop_size) |
|
|
carry = heaviside((inp * w_carry).sum(-1) + b_carry)[0].item() |
|
|
|
|
|
out = torch.stack(out_bits, dim=-1) |
|
|
expected = torch.tensor([((expected_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)], |
|
|
device=self.device, dtype=torch.float32) |
|
|
correct = (out == expected.unsqueeze(0)).float().sum(1) |
|
|
op_scores += correct |
|
|
op_total += addr_bits |
|
|
|
|
|
scores += op_scores |
|
|
total += op_total |
|
|
self._record('control.pop.sp_inc', int(op_scores[0].item()), op_total, []) |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
except (KeyError, RuntimeError) as e: |
|
|
if debug: |
|
|
print(f" control.pop.sp_inc: SKIP ({e})") |
|
|
|
|
|
|
|
|
try: |
|
|
op_scores = torch.zeros(pop_size, device=self.device) |
|
|
op_total = 0 |
|
|
|
|
|
ret_tests = [0, addr_mask, addr_mask // 2, 1] |
|
|
if addr_bits >= 12: |
|
|
ret_tests.append(0x1234 & addr_mask) |
|
|
for addr_val in ret_tests: |
|
|
ret_bits_tensor = torch.tensor([float((addr_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)], |
|
|
device=self.device, dtype=torch.float32) |
|
|
|
|
|
out_bits = [] |
|
|
for bit in range(addr_bits): |
|
|
w = pop[f'control.ret.addr.bit{bit}.weight'].view(pop_size) |
|
|
b = pop[f'control.ret.addr.bit{bit}.bias'].view(pop_size) |
|
|
out = heaviside(ret_bits_tensor[bit] * w + b) |
|
|
out_bits.append(out) |
|
|
|
|
|
out = torch.stack(out_bits, dim=-1) |
|
|
correct = (out == ret_bits_tensor.unsqueeze(0)).float().sum(1) |
|
|
op_scores += correct |
|
|
op_total += addr_bits |
|
|
|
|
|
scores += op_scores |
|
|
total += op_total |
|
|
self._record('control.ret.addr', int(op_scores[0].item()), op_total, []) |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
except (KeyError, RuntimeError) as e: |
|
|
if debug: |
|
|
print(f" control.ret.addr: SKIP ({e})") |
|
|
|
|
|
return scores, total |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    def _test_alu_ops(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test ALU operations (8-bit bitwise).

        Drives each 8-bit ALU sub-circuit found in ``pop`` (AND, OR, NOT,
        SHL, SHR, MUL partial products, DIV stage comparators, INC, DEC,
        NEG, ROL, ROR) with a fixed battery of operand values and counts
        matching output bits per individual.

        Args:
            pop: Mapping from tensor name to a batched weight/bias tensor
                whose leading dimension is the population size.
            debug: When True, prints a PASS/FAIL line per circuit and a
                SKIP line for circuits whose tensors are missing or
                mis-shaped (AND/OR/NOT skip silently).

        Returns:
            ``(scores, total)`` where ``scores`` has shape (pop_size,) and
            holds the number of correct output bits per individual, and
            ``total`` is the maximum attainable score.
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total = 0

        if debug:
            print("\n=== ALU OPERATIONS ===")

        # Operand pairs: all-zeros, all-ones, and complementary alternating
        # patterns -- exercises both polarities of every bit position.
        test_vals = [(0, 0), (255, 255), (0xAA, 0x55), (0x0F, 0xF0)]

        # --- AND: one 2-input threshold gate per output bit. ---
        # NOTE(review): this section accumulates directly into scores/total
        # instead of a local op_scores/op_total like every other op. It is
        # equivalent only because AND runs first (both are still zero here);
        # it also skips silently with no debug SKIP line -- confirm intent.
        try:
            w = pop['alu.alu8bit.and.weight'].view(pop_size, 8, 2)
            b = pop['alu.alu8bit.and.bias'].view(pop_size, 8)

            for a_val, b_val in test_vals:
                # MSB-first bit decomposition of each operand.
                a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)
                b_bits = torch.tensor([((b_val >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)

                # (8, 2): per output bit, the (a, b) input pair.
                inputs = torch.stack([a_bits, b_bits], dim=-1)

                out = heaviside((inputs * w).sum(-1) + b)
                expected = torch.tensor([((a_val & b_val) >> (7 - i)) & 1 for i in range(8)],
                                        device=self.device, dtype=torch.float32)
                correct = (out == expected.unsqueeze(0)).float().sum(1)
                scores += correct
                total += 8

            # Reporting uses individual 0 as the representative.
            self._record('alu.alu8bit.and', int(scores[0].item()), total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError):
            pass

        # --- OR: same gate topology as AND. ---
        try:
            w = pop['alu.alu8bit.or.weight'].view(pop_size, 8, 2)
            b = pop['alu.alu8bit.or.bias'].view(pop_size, 8)
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            for a_val, b_val in test_vals:
                a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)
                b_bits = torch.tensor([((b_val >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)
                inputs = torch.stack([a_bits, b_bits], dim=-1)
                out = heaviside((inputs * w).sum(-1) + b)
                expected = torch.tensor([((a_val | b_val) >> (7 - i)) & 1 for i in range(8)],
                                        device=self.device, dtype=torch.float32)
                correct = (out == expected.unsqueeze(0)).float().sum(1)
                op_scores += correct
                op_total += 8

            scores += op_scores
            total += op_total
            self._record('alu.alu8bit.or', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError):
            pass

        # --- NOT: one single-input gate per output bit. ---
        try:
            w = pop['alu.alu8bit.not.weight'].view(pop_size, 8)
            b = pop['alu.alu8bit.not.bias'].view(pop_size, 8)
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            for a_val, _ in test_vals:
                a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)
                out = heaviside(a_bits * w + b)
                expected = torch.tensor([(((~a_val) & 0xFF) >> (7 - i)) & 1 for i in range(8)],
                                        device=self.device, dtype=torch.float32)
                correct = (out == expected.unsqueeze(0)).float().sum(1)
                op_scores += correct
                op_total += 8

            scores += op_scores
            total += op_total
            self._record('alu.alu8bit.not', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError):
            pass

        # --- SHL: logical shift left by 1; each output bit is a buffer of
        # the next-lower input bit, LSB gets a constant 0. ---
        try:
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            for a_val, _ in test_vals:
                expected_val = (a_val << 1) & 0xFF
                a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)
                out_bits = []
                for bit in range(8):
                    w = pop[f'alu.alu8bit.shl.bit{bit}.weight'].view(pop_size)
                    b = pop[f'alu.alu8bit.shl.bit{bit}.bias'].view(pop_size)
                    if bit < 7:
                        # Output bit (MSB-first index) mirrors the input bit
                        # one position to its right.
                        inp = a_bits[bit + 1].unsqueeze(0).expand(pop_size)
                    else:
                        # Shifted-in zero at the LSB.
                        inp = torch.zeros(pop_size, device=self.device)
                    out = heaviside(inp * w + b)
                    out_bits.append(out)
                out = torch.stack(out_bits, dim=-1)
                expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)],
                                        device=self.device, dtype=torch.float32)
                correct = (out == expected.unsqueeze(0)).float().sum(1)
                op_scores += correct
                op_total += 8

            scores += op_scores
            total += op_total
            self._record('alu.alu8bit.shl', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f" alu.alu8bit.shl: SKIP ({e})")

        # --- SHR: logical shift right by 1 (mirror of SHL). ---
        try:
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            for a_val, _ in test_vals:
                expected_val = (a_val >> 1) & 0xFF
                a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)
                out_bits = []
                for bit in range(8):
                    w = pop[f'alu.alu8bit.shr.bit{bit}.weight'].view(pop_size)
                    b = pop[f'alu.alu8bit.shr.bit{bit}.bias'].view(pop_size)
                    if bit > 0:
                        inp = a_bits[bit - 1].unsqueeze(0).expand(pop_size)
                    else:
                        # Shifted-in zero at the MSB.
                        inp = torch.zeros(pop_size, device=self.device)
                    out = heaviside(inp * w + b)
                    out_bits.append(out)
                out = torch.stack(out_bits, dim=-1)
                expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)],
                                        device=self.device, dtype=torch.float32)
                correct = (out == expected.unsqueeze(0)).float().sum(1)
                op_scores += correct
                op_total += 8

            scores += op_scores
            total += op_total
            self._record('alu.alu8bit.shr', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f" alu.alu8bit.shr: SKIP ({e})")

        # --- MUL: tests only the 64 partial-product AND gates (a_i & b_j),
        # not the full accumulation tree. ---
        try:
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            mul_tests = [(3, 4), (7, 8), (15, 17), (0, 255)]
            for a_val, b_val in mul_tests:
                a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)
                b_bits = torch.tensor([((b_val >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)

                for i in range(8):
                    for j in range(8):
                        w = pop[f'alu.alu8bit.mul.pp.a{i}b{j}.weight'].view(pop_size, 2)
                        b = pop[f'alu.alu8bit.mul.pp.a{i}b{j}.bias'].view(pop_size)
                        inp = torch.tensor([a_bits[i].item(), b_bits[j].item()], device=self.device)
                        out = heaviside((inp * w).sum(-1) + b)
                        expected = float(int(a_bits[i].item()) & int(b_bits[j].item()))
                        correct = (out == expected).float()
                        op_scores += correct
                        op_total += 1

            scores += op_scores
            total += op_total
            self._record('alu.alu8bit.mul', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f" alu.alu8bit.mul: SKIP ({e})")

        # --- DIV: tests only each restoring-division stage's comparator
        # (remainder >= divisor), with a synthetic remainder per stage. ---
        try:
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            div_tests = [(100, 10), (255, 17), (50, 7), (128, 16)]
            for a_val, b_val in div_tests:

                for stage in range(8):
                    w = pop[f'alu.alu8bit.div.stage{stage}.cmp.weight'].view(pop_size, 16)
                    b = pop[f'alu.alu8bit.div.stage{stage}.cmp.bias'].view(pop_size)

                    # Synthetic remainder: the dividend's top (stage+1) bits,
                    # as a stand-in for the stage's running remainder.
                    test_rem = (a_val >> (7 - stage)) & 0xFF
                    rem_bits = torch.tensor([((test_rem >> (7 - i)) & 1) for i in range(8)],
                                            device=self.device, dtype=torch.float32)
                    div_bits = torch.tensor([((b_val >> (7 - i)) & 1) for i in range(8)],
                                            device=self.device, dtype=torch.float32)
                    inp = torch.cat([rem_bits, div_bits])

                    out = heaviside((inp * w).sum(-1) + b)
                    expected = float(test_rem >= b_val)
                    correct = (out == expected).float()
                    op_scores += correct
                    op_total += 1

            scores += op_scores
            total += op_total
            self._record('alu.alu8bit.div', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f" alu.alu8bit.div: SKIP ({e})")

        # --- INC: ripple half-adder chain LSB->MSB with carry-in 1.
        # Each bit's sum is XOR(a, carry) built from OR/NAND + layer2. ---
        try:
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            inc_tests = [0, 1, 127, 128, 254, 255]
            for a_val in inc_tests:
                expected_val = (a_val + 1) & 0xFF
                a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)

                # Carry-in of 1 implements +1.
                carry = 1.0
                out_bits = []
                for bit in range(7, -1, -1):

                    w_or = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer1.or.weight'].view(pop_size, 2)
                    b_or = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer1.or.bias'].view(pop_size)
                    w_nand = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer1.nand.weight'].view(pop_size, 2)
                    b_nand = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer1.nand.bias'].view(pop_size)
                    w2 = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer2.weight'].view(pop_size, 2)
                    b2 = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer2.bias'].view(pop_size)

                    inp = torch.tensor([a_bits[bit].item(), carry], device=self.device)
                    h_or = heaviside((inp * w_or).sum(-1) + b_or)
                    h_nand = heaviside((inp * w_nand).sum(-1) + b_nand)
                    hidden = torch.stack([h_or, h_nand], dim=-1)
                    sum_bit = heaviside((hidden * w2).sum(-1) + b2)
                    out_bits.insert(0, sum_bit)

                    # NOTE(review): carry is taken from individual 0 only
                    # ([0].item()), so every individual in the population is
                    # scored against individual 0's ripple carry -- assumes
                    # homogeneous carry gates across the population; confirm.
                    w_carry = pop[f'alu.alu8bit.inc.bit{bit}.carry.weight'].view(pop_size, 2)
                    b_carry = pop[f'alu.alu8bit.inc.bit{bit}.carry.bias'].view(pop_size)
                    carry = heaviside((inp * w_carry).sum(-1) + b_carry)[0].item()

                out = torch.stack(out_bits, dim=-1)
                expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)],
                                        device=self.device, dtype=torch.float32)
                correct = (out == expected.unsqueeze(0)).float().sum(1)
                op_scores += correct
                op_total += 8

            scores += op_scores
            total += op_total
            self._record('alu.alu8bit.inc', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f" alu.alu8bit.inc: SKIP ({e})")

        # --- DEC: ripple subtractor chain LSB->MSB with borrow-in 1.
        # borrow_out = NOT(a) AND borrow_in, via an explicit not_a gate. ---
        try:
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            dec_tests = [0, 1, 127, 128, 254, 255]
            for a_val in dec_tests:
                expected_val = (a_val - 1) & 0xFF
                a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)

                # Borrow-in of 1 implements -1.
                borrow = 1.0
                out_bits = []
                for bit in range(7, -1, -1):
                    w_or = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer1.or.weight'].view(pop_size, 2)
                    b_or = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer1.or.bias'].view(pop_size)
                    w_nand = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer1.nand.weight'].view(pop_size, 2)
                    b_nand = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer1.nand.bias'].view(pop_size)
                    w2 = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer2.weight'].view(pop_size, 2)
                    b2 = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer2.bias'].view(pop_size)

                    inp = torch.tensor([a_bits[bit].item(), borrow], device=self.device)
                    h_or = heaviside((inp * w_or).sum(-1) + b_or)
                    h_nand = heaviside((inp * w_nand).sum(-1) + b_nand)
                    hidden = torch.stack([h_or, h_nand], dim=-1)
                    diff_bit = heaviside((hidden * w2).sum(-1) + b2)
                    out_bits.insert(0, diff_bit)

                    # Invert a for the borrow term.
                    w_not = pop[f'alu.alu8bit.dec.bit{bit}.not_a.weight'].view(pop_size)
                    b_not = pop[f'alu.alu8bit.dec.bit{bit}.not_a.bias'].view(pop_size)
                    not_a = heaviside(a_bits[bit] * w_not + b_not)

                    # NOTE(review): as in INC, the borrow is serialized
                    # through individual 0 only ([0].item()) -- confirm.
                    w_borrow = pop[f'alu.alu8bit.dec.bit{bit}.borrow.weight'].view(pop_size, 2)
                    b_borrow = pop[f'alu.alu8bit.dec.bit{bit}.borrow.bias'].view(pop_size)
                    borrow_inp = torch.tensor([not_a[0].item(), borrow], device=self.device)
                    borrow = heaviside((borrow_inp * w_borrow).sum(-1) + b_borrow)[0].item()

                out = torch.stack(out_bits, dim=-1)
                expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)],
                                        device=self.device, dtype=torch.float32)
                correct = (out == expected.unsqueeze(0)).float().sum(1)
                op_scores += correct
                op_total += 8

            scores += op_scores
            total += op_total
            self._record('alu.alu8bit.dec', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f" alu.alu8bit.dec: SKIP ({e})")

        # --- NEG: two's complement = bitwise NOT followed by +1 (same
        # increment-chain topology as INC, under the neg.inc prefix). ---
        try:
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            neg_tests = [0, 1, 127, 128, 255]
            for a_val in neg_tests:
                expected_val = (-a_val) & 0xFF
                a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)

                # Stage 1: per-bit inversion.
                not_bits = []
                for bit in range(8):
                    w = pop[f'alu.alu8bit.neg.not.bit{bit}.weight'].view(pop_size)
                    b = pop[f'alu.alu8bit.neg.not.bit{bit}.bias'].view(pop_size)
                    not_bit = heaviside(a_bits[bit] * w + b)
                    not_bits.append(not_bit)

                # Stage 2: increment the inverted value (carry-in 1).
                carry = 1.0
                out_bits = []
                for bit in range(7, -1, -1):
                    w_or = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer1.or.weight'].view(pop_size, 2)
                    b_or = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer1.or.bias'].view(pop_size)
                    w_nand = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer1.nand.weight'].view(pop_size, 2)
                    b_nand = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer1.nand.bias'].view(pop_size)
                    w2 = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer2.weight'].view(pop_size, 2)
                    b2 = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer2.bias'].view(pop_size)

                    # NOTE(review): both the NOT stage output and the carry
                    # are read from individual 0 only -- confirm this is the
                    # intended population-scoring semantics (see INC/DEC).
                    inp = torch.tensor([not_bits[bit][0].item(), carry], device=self.device)
                    h_or = heaviside((inp * w_or).sum(-1) + b_or)
                    h_nand = heaviside((inp * w_nand).sum(-1) + b_nand)
                    hidden = torch.stack([h_or, h_nand], dim=-1)
                    sum_bit = heaviside((hidden * w2).sum(-1) + b2)
                    out_bits.insert(0, sum_bit)

                    w_carry = pop[f'alu.alu8bit.neg.inc.bit{bit}.carry.weight'].view(pop_size, 2)
                    b_carry = pop[f'alu.alu8bit.neg.inc.bit{bit}.carry.bias'].view(pop_size)
                    carry = heaviside((inp * w_carry).sum(-1) + b_carry)[0].item()

                out = torch.stack(out_bits, dim=-1)
                expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)],
                                        device=self.device, dtype=torch.float32)
                correct = (out == expected.unsqueeze(0)).float().sum(1)
                op_scores += correct
                op_total += 8

            scores += op_scores
            total += op_total
            self._record('alu.alu8bit.neg', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f" alu.alu8bit.neg: SKIP ({e})")

        # --- ROL: rotate left by 1; each output bit buffers input bit
        # (bit+1) mod 8 in MSB-first indexing. ---
        try:
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            rol_tests = [0b10000000, 0b00000001, 0b10101010, 0b01010101, 0xFF, 0x00]
            for a_val in rol_tests:
                expected_val = ((a_val << 1) | (a_val >> 7)) & 0xFF
                a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)

                out_bits = []
                for bit in range(8):
                    w = pop[f'alu.alu8bit.rol.bit{bit}.weight'].view(pop_size)
                    b = pop[f'alu.alu8bit.rol.bit{bit}.bias'].view(pop_size)

                    src_bit = (bit + 1) % 8
                    out = heaviside(a_bits[src_bit] * w + b)
                    out_bits.append(out)

                out = torch.stack(out_bits, dim=-1)
                expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)],
                                        device=self.device, dtype=torch.float32)
                correct = (out == expected.unsqueeze(0)).float().sum(1)
                op_scores += correct
                op_total += 8

            scores += op_scores
            total += op_total
            self._record('alu.alu8bit.rol', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f" alu.alu8bit.rol: SKIP ({e})")

        # --- ROR: rotate right by 1 (mirror of ROL). ---
        try:
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            ror_tests = [0b10000000, 0b00000001, 0b10101010, 0b01010101, 0xFF, 0x00]
            for a_val in ror_tests:
                expected_val = ((a_val >> 1) | (a_val << 7)) & 0xFF
                a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)

                out_bits = []
                for bit in range(8):
                    w = pop[f'alu.alu8bit.ror.bit{bit}.weight'].view(pop_size)
                    b = pop[f'alu.alu8bit.ror.bit{bit}.bias'].view(pop_size)

                    src_bit = (bit - 1) % 8
                    out = heaviside(a_bits[src_bit] * w + b)
                    out_bits.append(out)

                out = torch.stack(out_bits, dim=-1)
                expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)],
                                        device=self.device, dtype=torch.float32)
                correct = (out == expected.unsqueeze(0)).float().sum(1)
                op_scores += correct
                op_total += 8

            scores += op_scores
            total += op_total
            self._record('alu.alu8bit.ror', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f" alu.alu8bit.ror: SKIP ({e})")

        return scores, total
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _test_manifest(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Verify manifest values.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== MANIFEST ===") |
|
|
|
|
|
fixed_expected = { |
|
|
'manifest.alu_operations': 16.0, |
|
|
'manifest.flags': 4.0, |
|
|
'manifest.instruction_width': 16.0, |
|
|
'manifest.register_width': 8.0, |
|
|
'manifest.registers': 4.0, |
|
|
'manifest.version': 4.0, |
|
|
} |
|
|
|
|
|
for name, exp_val in fixed_expected.items(): |
|
|
try: |
|
|
val = pop[name][0, 0].item() |
|
|
if val == exp_val: |
|
|
scores += 1 |
|
|
self._record(name, 1, 1, []) |
|
|
else: |
|
|
self._record(name, 0, 1, [(exp_val, val)]) |
|
|
total += 1 |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
except KeyError: |
|
|
pass |
|
|
|
|
|
variable_checks = ['manifest.memory_bytes', 'manifest.pc_width', 'manifest.turing_complete'] |
|
|
for name in variable_checks: |
|
|
try: |
|
|
val = pop[name][0, 0].item() |
|
|
valid = val >= 0 |
|
|
if valid: |
|
|
scores += 1 |
|
|
self._record(name, 1, 1, []) |
|
|
else: |
|
|
self._record(name, 0, 1, [('>=0', val)]) |
|
|
total += 1 |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'} (value={val})") |
|
|
except KeyError: |
|
|
pass |
|
|
|
|
|
return scores, total |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _test_memory(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test memory circuits (shape validation).""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== MEMORY ===") |
|
|
|
|
|
try: |
|
|
mem_bytes = int(pop['manifest.memory_bytes'][0].item()) |
|
|
addr_bits = int(pop['manifest.pc_width'][0].item()) |
|
|
except KeyError: |
|
|
mem_bytes = 65536 |
|
|
addr_bits = 16 |
|
|
|
|
|
if mem_bytes == 0: |
|
|
if debug: |
|
|
print(" No memory (pure ALU mode)") |
|
|
return scores, 0 |
|
|
|
|
|
expected_shapes = { |
|
|
'memory.addr_decode.weight': (mem_bytes, addr_bits), |
|
|
'memory.addr_decode.bias': (mem_bytes,), |
|
|
'memory.read.and.weight': (8, mem_bytes, 2), |
|
|
'memory.read.and.bias': (8, mem_bytes), |
|
|
'memory.read.or.weight': (8, mem_bytes), |
|
|
'memory.read.or.bias': (8,), |
|
|
'memory.write.sel.weight': (mem_bytes, 2), |
|
|
'memory.write.sel.bias': (mem_bytes,), |
|
|
'memory.write.nsel.weight': (mem_bytes, 1), |
|
|
'memory.write.nsel.bias': (mem_bytes,), |
|
|
'memory.write.and_old.weight': (mem_bytes, 8, 2), |
|
|
'memory.write.and_old.bias': (mem_bytes, 8), |
|
|
'memory.write.and_new.weight': (mem_bytes, 8, 2), |
|
|
'memory.write.and_new.bias': (mem_bytes, 8), |
|
|
'memory.write.or.weight': (mem_bytes, 8, 2), |
|
|
'memory.write.or.bias': (mem_bytes, 8), |
|
|
} |
|
|
|
|
|
for name, expected_shape in expected_shapes.items(): |
|
|
try: |
|
|
tensor = pop[name] |
|
|
actual_shape = tuple(tensor.shape[1:]) |
|
|
if actual_shape == expected_shape: |
|
|
scores += 1 |
|
|
self._record(name, 1, 1, []) |
|
|
else: |
|
|
self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
|
|
total += 1 |
|
|
|
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
except KeyError: |
|
|
pass |
|
|
|
|
|
return scores, total |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _test_float16_core(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test float16 core circuits (unpack, pack, classify).""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== FLOAT16 CORE ===") |
|
|
|
|
|
expected_gates = [ |
|
|
('float16.unpack.bit0.weight', (1,)), |
|
|
('float16.classify.exp_zero.weight', (5,)), |
|
|
('float16.classify.exp_max.weight', (5,)), |
|
|
('float16.classify.frac_zero.weight', (10,)), |
|
|
('float16.classify.is_zero.and.weight', (2,)), |
|
|
('float16.classify.is_nan.and.weight', (2,)), |
|
|
('float16.normalize.stage0.bit0.not_sel.weight', (1,)), |
|
|
('float16.normalize.stage0.bit0.and_a.weight', (2,)), |
|
|
('float16.normalize.stage0.bit0.or.weight', (2,)), |
|
|
('float16.pack.bit0.weight', (1,)), |
|
|
] |
|
|
|
|
|
for name, expected_shape in expected_gates: |
|
|
try: |
|
|
tensor = pop[name] |
|
|
actual_shape = tuple(tensor.shape[1:]) |
|
|
if actual_shape == expected_shape: |
|
|
scores += 1 |
|
|
self._record(name, 1, 1, []) |
|
|
else: |
|
|
self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
|
|
total += 1 |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
except KeyError: |
|
|
if debug: |
|
|
print(f" {name}: SKIP (not found)") |
|
|
|
|
|
return scores, total |
|
|
|
|
|
def _test_float16_add(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test float16 addition circuit.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== FLOAT16 ADD ===") |
|
|
|
|
|
expected_gates = [ |
|
|
('float16.add.exp_cmp.a_gt_b.weight', (10,)), |
|
|
('float16.add.exp_cmp.a_lt_b.weight', (10,)), |
|
|
('float16.add.exp_diff.fa0.ha1.sum.layer1.or.weight', (2,)), |
|
|
('float16.add.align.stage0.bit0.not_sel.weight', (1,)), |
|
|
('float16.add.sign_xor.layer1.or.weight', (2,)), |
|
|
('float16.add.mant_add.fa0.ha1.sum.layer1.or.weight', (2,)), |
|
|
('float16.add.mant_sub.not_b.bit0.weight', (1,)), |
|
|
('float16.add.mant_select.bit0.not_sel.weight', (1,)), |
|
|
] |
|
|
|
|
|
for name, expected_shape in expected_gates: |
|
|
try: |
|
|
tensor = pop[name] |
|
|
actual_shape = tuple(tensor.shape[1:]) |
|
|
if actual_shape == expected_shape: |
|
|
scores += 1 |
|
|
self._record(name, 1, 1, []) |
|
|
else: |
|
|
self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
|
|
total += 1 |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
except KeyError: |
|
|
if debug: |
|
|
print(f" {name}: SKIP (not found)") |
|
|
|
|
|
return scores, total |
|
|
|
|
|
def _test_float16_mul(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test float16 multiplication circuit.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== FLOAT16 MUL ===") |
|
|
|
|
|
expected_gates = [ |
|
|
('float16.mul.sign_xor.layer1.or.weight', (2,)), |
|
|
('float16.mul.exp_add.fa0.ha1.sum.layer1.or.weight', (2,)), |
|
|
('float16.mul.bias_sub.not_bias.bit0.weight', (1,)), |
|
|
('float16.mul.mant_mul.pp.a0b0.weight', (2,)), |
|
|
('float16.mul.mant_mul.acc.s0.fa0.ha1.sum.layer1.or.weight', (2,)), |
|
|
] |
|
|
|
|
|
for name, expected_shape in expected_gates: |
|
|
try: |
|
|
tensor = pop[name] |
|
|
actual_shape = tuple(tensor.shape[1:]) |
|
|
if actual_shape == expected_shape: |
|
|
scores += 1 |
|
|
self._record(name, 1, 1, []) |
|
|
else: |
|
|
self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
|
|
total += 1 |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
except KeyError: |
|
|
if debug: |
|
|
print(f" {name}: SKIP (not found)") |
|
|
|
|
|
return scores, total |
|
|
|
|
|
def _test_float16_div(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test float16 division circuit.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== FLOAT16 DIV ===") |
|
|
|
|
|
expected_gates = [ |
|
|
('float16.div.sign_xor.layer1.or.weight', (2,)), |
|
|
('float16.div.exp_sub.not_b.bit0.weight', (1,)), |
|
|
('float16.div.bias_add.fa0.ha1.sum.layer1.or.weight', (2,)), |
|
|
('float16.div.mant_div.stage0.cmp.weight', (22,)), |
|
|
('float16.div.mant_div.stage0.sub.not_d.bit0.weight', (1,)), |
|
|
('float16.div.mant_div.stage0.mux.bit0.not_sel.weight', (1,)), |
|
|
] |
|
|
|
|
|
for name, expected_shape in expected_gates: |
|
|
try: |
|
|
tensor = pop[name] |
|
|
actual_shape = tuple(tensor.shape[1:]) |
|
|
if actual_shape == expected_shape: |
|
|
scores += 1 |
|
|
self._record(name, 1, 1, []) |
|
|
else: |
|
|
self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
|
|
total += 1 |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
except KeyError: |
|
|
if debug: |
|
|
print(f" {name}: SKIP (not found)") |
|
|
|
|
|
return scores, total |
|
|
|
|
|
def _test_float16_cmp(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test float16 comparison circuits.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== FLOAT16 CMP ===") |
|
|
|
|
|
expected_gates = [ |
|
|
('float16.cmp.a.exp_max.weight', (5,)), |
|
|
('float16.cmp.a.frac_nz.weight', (10,)), |
|
|
('float16.cmp.a.is_nan.weight', (2,)), |
|
|
('float16.cmp.either_nan.weight', (2,)), |
|
|
('float16.cmp.sign_xor.layer1.or.weight', (2,)), |
|
|
('float16.cmp.both_zero.weight', (2,)), |
|
|
('float16.cmp.mag_a_gt_b.weight', (30,)), |
|
|
('float16.cmp.eq.result.weight', (2,)), |
|
|
('float16.cmp.lt.result.weight', (3,)), |
|
|
('float16.cmp.gt.result.weight', (3,)), |
|
|
] |
|
|
|
|
|
for name, expected_shape in expected_gates: |
|
|
try: |
|
|
tensor = pop[name] |
|
|
actual_shape = tuple(tensor.shape[1:]) |
|
|
if actual_shape == expected_shape: |
|
|
scores += 1 |
|
|
self._record(name, 1, 1, []) |
|
|
else: |
|
|
self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
|
|
total += 1 |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
except KeyError: |
|
|
if debug: |
|
|
print(f" {name}: SKIP (not found)") |
|
|
|
|
|
return scores, total |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _test_float32_core(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test float32 core circuits (unpack, pack, classify).""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== FLOAT32 CORE ===") |
|
|
|
|
|
expected_gates = [ |
|
|
('float32.unpack.bit0.weight', (1,)), |
|
|
('float32.classify.exp_zero.weight', (8,)), |
|
|
('float32.classify.exp_max.weight', (8,)), |
|
|
('float32.classify.frac_zero.weight', (23,)), |
|
|
('float32.classify.is_zero.and.weight', (2,)), |
|
|
('float32.classify.is_nan.and.weight', (2,)), |
|
|
('float32.normalize.stage0.bit0.not_sel.weight', (1,)), |
|
|
('float32.pack.bit0.weight', (1,)), |
|
|
] |
|
|
|
|
|
for name, expected_shape in expected_gates: |
|
|
try: |
|
|
tensor = pop[name] |
|
|
actual_shape = tuple(tensor.shape[1:]) |
|
|
if actual_shape == expected_shape: |
|
|
scores += 1 |
|
|
self._record(name, 1, 1, []) |
|
|
else: |
|
|
self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
|
|
total += 1 |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
except KeyError: |
|
|
if debug: |
|
|
print(f" {name}: SKIP (not found)") |
|
|
|
|
|
return scores, total |
|
|
|
|
|
def _test_float32_add(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test float32 addition circuit.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== FLOAT32 ADD ===") |
|
|
|
|
|
expected_gates = [ |
|
|
('float32.add.exp_cmp.a_gt_b.weight', (16,)), |
|
|
('float32.add.exp_diff.fa0.ha1.sum.layer1.or.weight', (2,)), |
|
|
('float32.add.align.stage0.bit0.not_sel.weight', (1,)), |
|
|
('float32.add.sign_xor.layer1.or.weight', (2,)), |
|
|
('float32.add.mant_add.fa0.ha1.sum.layer1.or.weight', (2,)), |
|
|
('float32.add.mant_sub.not_b.bit0.weight', (1,)), |
|
|
('float32.add.mant_select.bit0.not_sel.weight', (1,)), |
|
|
] |
|
|
|
|
|
for name, expected_shape in expected_gates: |
|
|
try: |
|
|
tensor = pop[name] |
|
|
actual_shape = tuple(tensor.shape[1:]) |
|
|
if actual_shape == expected_shape: |
|
|
scores += 1 |
|
|
self._record(name, 1, 1, []) |
|
|
else: |
|
|
self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
|
|
total += 1 |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
except KeyError: |
|
|
if debug: |
|
|
print(f" {name}: SKIP (not found)") |
|
|
|
|
|
return scores, total |
|
|
|
|
|
def _test_float32_mul(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test float32 multiplication circuit.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== FLOAT32 MUL ===") |
|
|
|
|
|
expected_gates = [ |
|
|
('float32.mul.sign_xor.layer1.or.weight', (2,)), |
|
|
('float32.mul.exp_add.fa0.ha1.sum.layer1.or.weight', (2,)), |
|
|
('float32.mul.bias_sub.not_bias.bit0.weight', (1,)), |
|
|
('float32.mul.mant_mul.pp.a0b0.weight', (2,)), |
|
|
('float32.mul.mant_mul.acc.s0.fa0.ha1.sum.layer1.or.weight', (2,)), |
|
|
] |
|
|
|
|
|
for name, expected_shape in expected_gates: |
|
|
try: |
|
|
tensor = pop[name] |
|
|
actual_shape = tuple(tensor.shape[1:]) |
|
|
if actual_shape == expected_shape: |
|
|
scores += 1 |
|
|
self._record(name, 1, 1, []) |
|
|
else: |
|
|
self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
|
|
total += 1 |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
except KeyError: |
|
|
if debug: |
|
|
print(f" {name}: SKIP (not found)") |
|
|
|
|
|
return scores, total |
|
|
|
|
|
def _test_float32_div(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test float32 division circuit.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== FLOAT32 DIV ===") |
|
|
|
|
|
expected_gates = [ |
|
|
('float32.div.sign_xor.layer1.or.weight', (2,)), |
|
|
('float32.div.exp_sub.not_b.bit0.weight', (1,)), |
|
|
('float32.div.bias_add.fa0.ha1.sum.layer1.or.weight', (2,)), |
|
|
('float32.div.mant_div.stage0.cmp.weight', (48,)), |
|
|
('float32.div.mant_div.stage0.sub.not_d.bit0.weight', (1,)), |
|
|
('float32.div.mant_div.stage0.mux.bit0.not_sel.weight', (1,)), |
|
|
] |
|
|
|
|
|
for name, expected_shape in expected_gates: |
|
|
try: |
|
|
tensor = pop[name] |
|
|
actual_shape = tuple(tensor.shape[1:]) |
|
|
if actual_shape == expected_shape: |
|
|
scores += 1 |
|
|
self._record(name, 1, 1, []) |
|
|
else: |
|
|
self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
|
|
total += 1 |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
except KeyError: |
|
|
if debug: |
|
|
print(f" {name}: SKIP (not found)") |
|
|
|
|
|
return scores, total |
|
|
|
|
|
def _test_float32_cmp(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
|
|
"""Test float32 comparison circuits.""" |
|
|
pop_size = next(iter(pop.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total = 0 |
|
|
|
|
|
if debug: |
|
|
print("\n=== FLOAT32 CMP ===") |
|
|
|
|
|
expected_gates = [ |
|
|
('float32.cmp.a.exp_max.weight', (8,)), |
|
|
('float32.cmp.a.frac_nz.weight', (23,)), |
|
|
('float32.cmp.a.is_nan.weight', (2,)), |
|
|
('float32.cmp.either_nan.weight', (2,)), |
|
|
('float32.cmp.sign_xor.layer1.or.weight', (2,)), |
|
|
('float32.cmp.both_zero.weight', (2,)), |
|
|
('float32.cmp.mag_a_gt_b.weight', (62,)), |
|
|
('float32.cmp.eq.result.weight', (2,)), |
|
|
('float32.cmp.lt.result.weight', (3,)), |
|
|
('float32.cmp.gt.result.weight', (3,)), |
|
|
] |
|
|
|
|
|
for name, expected_shape in expected_gates: |
|
|
try: |
|
|
tensor = pop[name] |
|
|
actual_shape = tuple(tensor.shape[1:]) |
|
|
if actual_shape == expected_shape: |
|
|
scores += 1 |
|
|
self._record(name, 1, 1, []) |
|
|
else: |
|
|
self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
|
|
total += 1 |
|
|
if debug: |
|
|
r = self.results[-1] |
|
|
print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
|
|
except KeyError: |
|
|
if debug: |
|
|
print(f" {name}: SKIP (not found)") |
|
|
|
|
|
return scores, total |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    def _test_integration(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test complex operations that chain multiple circuit families.

        Each sub-test drives two or more circuits in sequence (adder then
        comparator, shifter then AND, ...) over a handful of hand-picked
        operand tuples. A sub-test whose weights are missing from ``pop`` (or
        that raises at runtime) is skipped and contributes nothing to total.

        Bits are encoded MSB-first throughout: tensor index 0 holds bit 7.

        NOTE(review): this method is not called from evaluate() in this part
        of the file -- confirm whether a caller elsewhere wires it in.

        Args:
            pop: Dict of weight tensors, each shaped [pop_size, ...].
            debug: If True, print one PASS/FAIL/SKIP line per sub-test.

        Returns:
            (scores, total): per-member passed-case counts [pop_size] and the
            number of test cases actually run.
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total = 0

        if debug:
            print("\n=== INTEGRATION TESTS ===")

        # --- add_then_compare: (a + b) mod 256 computed on the host, then the
        # 8-bit greater-than circuit decides "sum > c".
        try:
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            tests = [(10, 20, 25), (100, 50, 200), (255, 1, 0), (0, 0, 1)]
            for a, b, c in tests:
                sum_val = (a + b) & 0xFF
                expected = float(sum_val > c)

                sum_bits = torch.tensor([((sum_val >> (7 - i)) & 1) for i in range(8)],
                                        device=self.device, dtype=torch.float32)
                c_bits = torch.tensor([((c >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)

                # Comparator input is the 16-bit concatenation [sum_bits, c_bits].
                w = pop['arithmetic.greaterthan8bit.weight'].view(pop_size, 16)
                bias = pop['arithmetic.greaterthan8bit.bias'].view(pop_size)
                inp = torch.cat([sum_bits, c_bits])
                out = heaviside((inp * w).sum(-1) + bias)
                correct = (out == expected).float()
                op_scores += correct
                op_total += 1

            scores += op_scores
            total += op_total
            # The recorded pass count reflects population member 0 only.
            self._record('integration.add_then_compare', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f"  integration.add_then_compare: SKIP ({e})")

        # --- mul_then_mod: (a * b) mod 256 on the host, then the two-layer
        # mod-3 reducer network.
        try:
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            tests = [(3, 5), (4, 6), (7, 11), (9, 9)]
            for a, b in tests:
                product = (a * b) & 0xFF
                # NOTE(review): expected_mod3 is computed but never compared
                # against the circuit output below.
                expected_mod3 = product % 3

                prod_bits = torch.tensor([((product >> (7 - i)) & 1) for i in range(8)],
                                         device=self.device, dtype=torch.float32)

                w1 = pop['modular.mod3.layer1.weight'].view(pop_size, 8)
                b1 = pop['modular.mod3.layer1.bias'].view(pop_size)
                h1 = heaviside((prod_bits * w1).sum(-1) + b1)

                w2 = pop['modular.mod3.layer2.weight'].view(pop_size, 8)
                b2 = pop['modular.mod3.layer2.bias'].view(pop_size)
                h2 = heaviside((prod_bits * w2).sum(-1) + b2)

                # NOTE(review): h1/h2 are discarded and every case is scored a
                # pass, so this sub-test only verifies the circuit runs without
                # error, not that it is correct. A real check needs the (h1, h2)
                # encoding of the mod-3 result -- TODO confirm with circuit docs.
                op_scores += 1
                op_total += 1

            scores += op_scores
            total += op_total
            self._record('integration.mul_then_mod', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f"  integration.mul_then_mod: SKIP ({e})")

        # --- shift_then_and: (a << 1) via the per-bit shl gates, then the
        # 8-bit AND circuit against b.
        try:
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            tests = [(0b10101010, 0b11110000), (0b00001111, 0b01010101), (0xFF, 0x0F)]
            for a, b in tests:
                shifted_a = (a << 1) & 0xFF
                expected = shifted_a & b

                a_bits = torch.tensor([((a >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)
                b_bits = torch.tensor([((b >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)

                # Logical shift-left: output bit i takes input bit i+1 (MSB-first
                # layout), and the LSB gate is fed a constant 0.
                shifted_bits = []
                for bit in range(8):
                    w = pop[f'alu.alu8bit.shl.bit{bit}.weight'].view(pop_size)
                    bias = pop[f'alu.alu8bit.shl.bit{bit}.bias'].view(pop_size)
                    if bit < 7:
                        inp = a_bits[bit + 1]
                    else:
                        inp = torch.tensor(0.0, device=self.device)
                    out = heaviside(inp * w + bias)
                    shifted_bits.append(out)

                # AND the shifted value with b, one bit at a time. The .item()
                # reads take member 0's shifted bit for ALL members' AND inputs
                # -- NOTE(review): for pop_size > 1 this mixes members; confirm
                # intended for population mode.
                and_bits = []
                w_and = pop['alu.alu8bit.and.weight'].view(pop_size, 8, 2)
                b_and = pop['alu.alu8bit.and.bias'].view(pop_size, 8)
                for bit in range(8):
                    inp = torch.tensor([shifted_bits[bit][0].item(), b_bits[bit].item()],
                                       device=self.device)
                    out = heaviside((inp * w_and[:, bit]).sum(-1) + b_and[:, bit])
                    and_bits.append(out)

                # Reassemble the integer from member 0's output bits; the
                # resulting pass/fail is applied to the whole population.
                out_val = sum(int(and_bits[i][0].item()) << (7 - i) for i in range(8))
                correct = (out_val == expected)
                op_scores += float(correct)
                op_total += 1

            scores += op_scores
            total += op_total
            self._record('integration.shift_then_and', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f"  integration.shift_then_and: SKIP ({e})")

        # --- sub_then_conditional: only the less-than comparator's "a < b"
        # decision is verified here.
        try:
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            tests = [(50, 30), (30, 50), (100, 100), (0, 1)]
            for a, b in tests:
                diff = (a - b) & 0xFF
                is_negative = a < b
                # NOTE(review): expected (the conditionally negated difference)
                # is computed but never used -- only lt_out is checked below.
                expected = (-diff & 0xFF) if is_negative else diff

                a_bits = torch.tensor([((a >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)
                b_bits = torch.tensor([((b >> (7 - i)) & 1) for i in range(8)],
                                      device=self.device, dtype=torch.float32)

                w = pop['arithmetic.lessthan8bit.weight'].view(pop_size, 16)
                bias = pop['arithmetic.lessthan8bit.bias'].view(pop_size)
                inp = torch.cat([a_bits, b_bits])
                lt_out = heaviside((inp * w).sum(-1) + bias)

                # Member 0's comparator output decides the whole population's score.
                correct = (lt_out[0].item() == float(is_negative))
                op_scores += float(correct)
                op_total += 1

            scores += op_scores
            total += op_total
            self._record('integration.sub_then_conditional', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f"  integration.sub_then_conditional: SKIP ({e})")

        # --- complex_expr: ((a + b) << 1) & 0xF0, chaining the shl gates and
        # the AND circuit with a constant upper-nibble mask.
        try:
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            tests = [(10, 20), (50, 50), (127, 1), (0, 0)]
            for a, b in tests:
                sum_val = (a + b) & 0xFF
                doubled = (sum_val << 1) & 0xFF
                expected = doubled & 0xF0

                sum_bits = torch.tensor([((sum_val >> (7 - i)) & 1) for i in range(8)],
                                        device=self.device, dtype=torch.float32)
                # MSB-first mask bits: [1,1,1,1,0,0,0,0] encodes 0xF0.
                mask_bits = torch.tensor([1, 1, 1, 1, 0, 0, 0, 0],
                                         device=self.device, dtype=torch.float32)

                # Shift left by one via the per-bit shl gates (LSB fed 0).
                shifted_bits = []
                for bit in range(8):
                    w = pop[f'alu.alu8bit.shl.bit{bit}.weight'].view(pop_size)
                    bias = pop[f'alu.alu8bit.shl.bit{bit}.bias'].view(pop_size)
                    if bit < 7:
                        inp = sum_bits[bit + 1]
                    else:
                        inp = torch.tensor(0.0, device=self.device)
                    out = heaviside(inp * w + bias)
                    shifted_bits.append(out)

                # AND with the mask; as above, member 0's shifted bits feed every
                # member's AND gate inputs via .item().
                w_and = pop['alu.alu8bit.and.weight'].view(pop_size, 8, 2)
                b_and = pop['alu.alu8bit.and.bias'].view(pop_size, 8)
                result_bits = []
                for bit in range(8):
                    inp = torch.tensor([shifted_bits[bit][0].item(), mask_bits[bit].item()],
                                       device=self.device)
                    out = heaviside((inp * w_and[:, bit]).sum(-1) + b_and[:, bit])
                    result_bits.append(out)

                out_val = sum(int(result_bits[i][0].item()) << (7 - i) for i in range(8))
                correct = (out_val == expected)
                op_scores += float(correct)
                op_total += 1

            scores += op_scores
            total += op_total
            self._record('integration.complex_expr', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f"  integration.complex_expr: SKIP ({e})")

        return scores, total
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate(self, population: Dict[str, torch.Tensor], debug: bool = False) -> torch.Tensor: |
|
|
""" |
|
|
Evaluate population fitness with per-circuit reporting. |
|
|
|
|
|
Args: |
|
|
population: Dict of tensors, each with shape [pop_size, ...] |
|
|
debug: If True, print per-circuit results |
|
|
|
|
|
Returns: |
|
|
Tensor of fitness scores [pop_size], normalized to [0, 1] |
|
|
""" |
|
|
self.results = [] |
|
|
self.category_scores = {} |
|
|
|
|
|
pop_size = next(iter(population.values())).shape[0] |
|
|
scores = torch.zeros(pop_size, device=self.device) |
|
|
total_tests = 0 |
|
|
|
|
|
|
|
|
s, t = self._test_boolean_gates(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['boolean'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
s, t = self._test_halfadder(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['halfadder'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
s, t = self._test_fulladder(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['fulladder'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
for bits in [2, 4, 8]: |
|
|
s, t = self._test_ripplecarry(population, bits, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores[f'ripplecarry{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
for bits in [16, 32]: |
|
|
if f'arithmetic.ripplecarry{bits}bit.fa0.ha1.sum.layer1.or.weight' in population: |
|
|
if debug: |
|
|
print(f"\n{'=' * 60}") |
|
|
print(f" {bits}-BIT CIRCUITS") |
|
|
print(f"{'=' * 60}") |
|
|
|
|
|
s, t = self._test_ripplecarry(population, bits, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores[f'ripplecarry{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
s, t = self._test_comparators_nbits(population, bits, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores[f'comparators{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
if f'arithmetic.sub{bits}bit.not_b.bit0.weight' in population: |
|
|
s, t = self._test_subtractor_nbits(population, bits, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores[f'subtractor{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
if f'alu.alu{bits}bit.and.bit0.weight' in population: |
|
|
s, t = self._test_bitwise_nbits(population, bits, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores[f'bitwise{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
if f'alu.alu{bits}bit.shl.bit0.weight' in population: |
|
|
s, t = self._test_shifts_nbits(population, bits, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores[f'shifts{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
if f'alu.alu{bits}bit.inc.bit0.xor.layer1.or.weight' in population: |
|
|
s, t = self._test_inc_dec_nbits(population, bits, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores[f'incdec{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
if f'alu.alu{bits}bit.neg.not.bit0.weight' in population: |
|
|
s, t = self._test_neg_nbits(population, bits, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores[f'neg{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
if f'combinational.barrelshifter{bits}.layer0.bit0.not_sel.weight' in population: |
|
|
s, t = self._test_barrel_shifter_nbits(population, bits, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores[f'barrelshifter{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
if f'combinational.priorityencoder{bits}.valid.weight' in population: |
|
|
s, t = self._test_priority_encoder_nbits(population, bits, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores[f'priorityencoder{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
s, t = self._test_add3(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['add3'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
s, t = self._test_expr_add_mul(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['expr_add_mul'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
s, t = self._test_comparators(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['comparators'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
s, t = self._test_threshold_gates(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['threshold'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
s, t = self._test_modular_all(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['modular'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
s, t = self._test_patterns(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['patterns'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
s, t = self._test_error_detection(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['error_detection'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
s, t = self._test_combinational(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['combinational'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
s, t = self._test_control_flow(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['control'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
s, t = self._test_alu_ops(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['alu'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
s, t = self._test_manifest(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['manifest'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
s, t = self._test_memory(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['memory'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
if 'float16.unpack.bit0.weight' in population: |
|
|
if debug: |
|
|
print(f"\n{'=' * 60}") |
|
|
print(f" FLOAT16 CIRCUITS") |
|
|
print(f"{'=' * 60}") |
|
|
|
|
|
s, t = self._test_float16_core(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['float16_core'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
if 'float16.add.exp_cmp.a_gt_b.weight' in population: |
|
|
s, t = self._test_float16_add(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['float16_add'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
if 'float16.mul.sign_xor.layer1.or.weight' in population: |
|
|
s, t = self._test_float16_mul(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['float16_mul'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
if 'float16.div.sign_xor.layer1.or.weight' in population: |
|
|
s, t = self._test_float16_div(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['float16_div'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
if 'float16.cmp.a.exp_max.weight' in population: |
|
|
s, t = self._test_float16_cmp(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['float16_cmp'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
|
|
|
if 'float32.unpack.bit0.weight' in population: |
|
|
if debug: |
|
|
print(f"\n{'=' * 60}") |
|
|
print(f" FLOAT32 CIRCUITS") |
|
|
print(f"{'=' * 60}") |
|
|
|
|
|
s, t = self._test_float32_core(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['float32_core'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
if 'float32.add.exp_cmp.a_gt_b.weight' in population: |
|
|
s, t = self._test_float32_add(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['float32_add'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
if 'float32.mul.sign_xor.layer1.or.weight' in population: |
|
|
s, t = self._test_float32_mul(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['float32_mul'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
if 'float32.div.sign_xor.layer1.or.weight' in population: |
|
|
s, t = self._test_float32_div(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['float32_div'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
if 'float32.cmp.a.exp_max.weight' in population: |
|
|
s, t = self._test_float32_cmp(population, debug) |
|
|
scores += s |
|
|
total_tests += t |
|
|
self.category_scores['float32_cmp'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
|
|
|
|
|
self.total_tests = total_tests |
|
|
|
|
|
if debug: |
|
|
print("\n" + "=" * 60) |
|
|
print("CATEGORY SUMMARY") |
|
|
print("=" * 60) |
|
|
for cat, (got, expected) in sorted(self.category_scores.items()): |
|
|
pct = 100 * got / expected if expected > 0 else 0 |
|
|
status = "PASS" if got == expected else "FAIL" |
|
|
print(f" {cat:20} {int(got):6}/{expected:6} ({pct:6.2f}%) [{status}]") |
|
|
|
|
|
print("\n" + "=" * 60) |
|
|
print("CIRCUIT FAILURES") |
|
|
print("=" * 60) |
|
|
failed = [r for r in self.results if not r.success] |
|
|
if failed: |
|
|
for r in failed[:20]: |
|
|
print(f" {r.name}: {r.passed}/{r.total}") |
|
|
if r.failures: |
|
|
print(f" First failure: {r.failures[0]}") |
|
|
if len(failed) > 20: |
|
|
print(f" ... and {len(failed) - 20} more") |
|
|
else: |
|
|
print(" None!") |
|
|
|
|
|
return scores / total_tests if total_tests > 0 else scores |
|
|
|
|
|
|
|
|
def main():
    """CLI entry point for the evaluation suite.

    Returns:
        Process exit code: 0 when member 0's fitness is >= 99.99%, 1 otherwise
        (or the smoke test's return code under --cpu-test).
    """
    parser = argparse.ArgumentParser(description='Unified Evaluation Suite for 8-bit Threshold Computer')
    parser.add_argument('--model', type=str, default=MODEL_PATH, help='Path to safetensors model')
    parser.add_argument('--device', type=str, default='cuda', help='Device: cuda or cpu')
    parser.add_argument('--pop_size', type=int, default=1, help='Population size for batched evaluation')
    parser.add_argument('--quiet', action='store_true', help='Suppress detailed output')
    parser.add_argument('--cpu-test', action='store_true', help='Run CPU smoke test (LOAD, ADD, STORE, HALT)')
    args = parser.parse_args()

    if args.cpu_test:
        return run_smoke_test()

    # Robustness: the default --device is cuda, but torch.cuda.synchronize()
    # hard-crashes on CUDA-less machines -- fall back to CPU with a warning.
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("WARNING: CUDA requested but not available; falling back to CPU")
        args.device = 'cpu'

    print("=" * 70)
    print(" UNIFIED EVALUATION SUITE")
    print("=" * 70)

    print(f"\nLoading model from {args.model}...")
    model = load_model(args.model)
    print(f"  Loaded {len(model)} tensors, {sum(t.numel() for t in model.values()):,} params")

    print(f"\nInitializing evaluator on {args.device}...")
    evaluator = BatchedFitnessEvaluator(device=args.device, model_path=args.model)

    print(f"\nCreating population (size {args.pop_size})...")
    population = create_population(model, pop_size=args.pop_size, device=args.device)

    print("\nRunning evaluation...")
    if args.device == 'cuda':
        # Synchronize so the timer measures only the evaluation itself.
        torch.cuda.synchronize()
    start = time.perf_counter()

    fitness = evaluator.evaluate(population, debug=not args.quiet)

    if args.device == 'cuda':
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    print("\n" + "=" * 70)
    print("RESULTS")
    print("=" * 70)

    if args.pop_size == 1:
        print(f"  Fitness: {fitness[0].item():.6f}")
    else:
        print(f"  Mean Fitness: {fitness.mean().item():.6f}")
        print(f"  Min Fitness: {fitness.min().item():.6f}")
        print(f"  Max Fitness: {fitness.max().item():.6f}")

    print(f"  Total tests: {evaluator.total_tests}")
    print(f"  Time: {elapsed * 1000:.2f} ms")

    if args.pop_size > 1:
        print(f"  Throughput: {args.pop_size / elapsed:.0f} evals/sec")
        perfect = (fitness >= 0.9999).sum().item()
        print(f"  Perfect (>=99.99%): {perfect}/{args.pop_size}")

    if fitness[0].item() >= 0.9999:
        print("\n  STATUS: PASS")
        return 0
    # Bug fix: the old int((1 - fitness) * total_tests) truncated float error,
    # e.g. fitness 0.99 over 100 tests gives 0.9999... -> int -> 0 failures.
    # Round the passed count and subtract instead.
    passed_count = round(fitness[0].item() * evaluator.total_tests)
    failed_count = evaluator.total_tests - passed_count
    print(f"\n  STATUS: FAIL ({failed_count} tests failed)")
    return 1
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # raise SystemExit over the site-provided exit() helper: exit() is meant
    # for interactive use and is absent when Python runs without the site
    # module (python -S).
    raise SystemExit(main())
|
|
|