""" Threshold-weight runtime for the 8-bit CPU. Implements a reference cycle using the frozen circuit weights for core ALU ops. """ from __future__ import annotations from pathlib import Path from typing import List, Tuple import torch from safetensors.torch import load_file from .state import CPUState, pack_state, unpack_state, REG_BITS, PC_BITS, MEM_BYTES def heaviside(x: torch.Tensor) -> torch.Tensor: return (x >= 0).float() def int_to_bits_msb(value: int, width: int) -> List[int]: return [(value >> (width - 1 - i)) & 1 for i in range(width)] def bits_to_int_msb(bits: List[int]) -> int: value = 0 for bit in bits: value = (value << 1) | int(bit) return value def bits_msb_to_lsb(bits: List[int]) -> List[int]: return list(reversed(bits)) DEFAULT_MODEL_PATH = Path(__file__).resolve().parent.parent / "neural_computer.safetensors" class ThresholdALU: def __init__(self, model_path: str, device: str = "cpu") -> None: self.device = device self.tensors = {k: v.float().to(device) for k, v in load_file(model_path).items()} def _get(self, name: str) -> torch.Tensor: return self.tensors[name] def _eval_gate(self, weight_key: str, bias_key: str, inputs: List[float]) -> float: w = self._get(weight_key) b = self._get(bias_key) inp = torch.tensor(inputs, device=self.device) return heaviside((inp * w).sum() + b).item() def _eval_xor(self, prefix: str, inputs: List[float]) -> float: inp = torch.tensor(inputs, device=self.device) w_or = self._get(f"{prefix}.layer1.or.weight") b_or = self._get(f"{prefix}.layer1.or.bias") w_nand = self._get(f"{prefix}.layer1.nand.weight") b_nand = self._get(f"{prefix}.layer1.nand.bias") w2 = self._get(f"{prefix}.layer2.weight") b2 = self._get(f"{prefix}.layer2.bias") h_or = heaviside((inp * w_or).sum() + b_or).item() h_nand = heaviside((inp * w_nand).sum() + b_nand).item() hidden = torch.tensor([h_or, h_nand], device=self.device) return heaviside((hidden * w2).sum() + b2).item() def _eval_full_adder(self, prefix: str, a: float, b: float, cin: float) -> Tuple[float, float]: ha1_sum = self._eval_xor(f"{prefix}.ha1.sum", [a, b]) ha1_carry = self._eval_gate(f"{prefix}.ha1.carry.weight", f"{prefix}.ha1.carry.bias", [a, b]) ha2_sum = self._eval_xor(f"{prefix}.ha2.sum", [ha1_sum, cin]) ha2_carry = self._eval_gate( f"{prefix}.ha2.carry.weight", f"{prefix}.ha2.carry.bias", [ha1_sum, cin] ) cout = self._eval_gate(f"{prefix}.carry_or.weight", f"{prefix}.carry_or.bias", [ha1_carry, ha2_carry]) return ha2_sum, cout def add(self, a: int, b: int) -> Tuple[int, int, int]: a_bits = bits_msb_to_lsb(int_to_bits_msb(a, REG_BITS)) b_bits = bits_msb_to_lsb(int_to_bits_msb(b, REG_BITS)) carry = 0.0 sum_bits: List[int] = [] for bit in range(REG_BITS): sum_bit, carry = self._eval_full_adder( f"arithmetic.ripplecarry8bit.fa{bit}", float(a_bits[bit]), float(b_bits[bit]), carry ) sum_bits.append(int(sum_bit)) result = bits_to_int_msb(list(reversed(sum_bits))) carry_out = int(carry) overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0 return result, carry_out, overflow def sub(self, a: int, b: int) -> Tuple[int, int, int]: a_bits = bits_msb_to_lsb(int_to_bits_msb(a, REG_BITS)) b_bits = bits_msb_to_lsb(int_to_bits_msb(b, REG_BITS)) carry = 1.0 # two's complement carry-in sum_bits: List[int] = [] for bit in range(REG_BITS): notb = self._eval_gate( f"arithmetic.sub8bit.notb{bit}.weight", f"arithmetic.sub8bit.notb{bit}.bias", [float(b_bits[bit])], ) xor1 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor1", [float(a_bits[bit]), notb]) xor2 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor2", [xor1, carry]) and1 = self._eval_gate( f"arithmetic.sub8bit.fa{bit}.and1.weight", f"arithmetic.sub8bit.fa{bit}.and1.bias", [float(a_bits[bit]), notb], ) and2 = self._eval_gate( f"arithmetic.sub8bit.fa{bit}.and2.weight", f"arithmetic.sub8bit.fa{bit}.and2.bias", [xor1, carry], ) carry = self._eval_gate( f"arithmetic.sub8bit.fa{bit}.or_carry.weight", f"arithmetic.sub8bit.fa{bit}.or_carry.bias", [and1, and2], ) sum_bits.append(int(xor2)) result = bits_to_int_msb(list(reversed(sum_bits))) carry_out = int(carry) overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0 return result, carry_out, overflow def bitwise_and(self, a: int, b: int) -> int: a_bits = int_to_bits_msb(a, REG_BITS) b_bits = int_to_bits_msb(b, REG_BITS) w = self._get("alu.alu8bit.and.weight") bias = self._get("alu.alu8bit.and.bias") out_bits = [] for bit in range(REG_BITS): inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device) out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item() out_bits.append(int(out)) return bits_to_int_msb(out_bits) def bitwise_or(self, a: int, b: int) -> int: a_bits = int_to_bits_msb(a, REG_BITS) b_bits = int_to_bits_msb(b, REG_BITS) w = self._get("alu.alu8bit.or.weight") bias = self._get("alu.alu8bit.or.bias") out_bits = [] for bit in range(REG_BITS): inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device) out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item() out_bits.append(int(out)) return bits_to_int_msb(out_bits) def bitwise_not(self, a: int) -> int: a_bits = int_to_bits_msb(a, REG_BITS) w = self._get("alu.alu8bit.not.weight") bias = self._get("alu.alu8bit.not.bias") out_bits = [] for bit in range(REG_BITS): inp = torch.tensor([float(a_bits[bit])], device=self.device) out = heaviside((inp * w[bit]).sum() + bias[bit]).item() out_bits.append(int(out)) return bits_to_int_msb(out_bits) def bitwise_xor(self, a: int, b: int) -> int: a_bits = int_to_bits_msb(a, REG_BITS) b_bits = int_to_bits_msb(b, REG_BITS) w_or = self._get("alu.alu8bit.xor.layer1.or.weight") b_or = self._get("alu.alu8bit.xor.layer1.or.bias") w_nand = self._get("alu.alu8bit.xor.layer1.nand.weight") b_nand = self._get("alu.alu8bit.xor.layer1.nand.bias") w2 = self._get("alu.alu8bit.xor.layer2.weight") b2 = self._get("alu.alu8bit.xor.layer2.bias") out_bits = [] for bit in range(REG_BITS): inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device) h_or = heaviside((inp * w_or[bit * 2:bit * 2 + 2]).sum() + b_or[bit]) h_nand = heaviside((inp * w_nand[bit * 2:bit * 2 + 2]).sum() + b_nand[bit]) hidden = torch.stack([h_or, h_nand]) out = heaviside((hidden * w2[bit * 2:bit * 2 + 2]).sum() + b2[bit]).item() out_bits.append(int(out)) return bits_to_int_msb(out_bits) class ThresholdCPU: def __init__(self, model_path: str | Path = DEFAULT_MODEL_PATH, device: str = "cpu") -> None: self.device = device self.alu = ThresholdALU(str(model_path), device=device) @staticmethod def decode_ir(ir: int) -> Tuple[int, int, int, int]: opcode = (ir >> 12) & 0xF rd = (ir >> 10) & 0x3 rs = (ir >> 8) & 0x3 imm8 = ir & 0xFF return opcode, rd, rs, imm8 @staticmethod def flags_from_result(result: int, carry: int, overflow: int) -> List[int]: z = 1 if result == 0 else 0 n = 1 if (result & 0x80) else 0 c = 1 if carry else 0 v = 1 if overflow else 0 return [z, n, c, v] def _addr_decode(self, addr: int) -> torch.Tensor: bits = torch.tensor(int_to_bits_msb(addr, PC_BITS), device=self.device, dtype=torch.float32) w = self.alu._get("memory.addr_decode.weight") b = self.alu._get("memory.addr_decode.bias") return heaviside((w * bits).sum(dim=1) + b) def _memory_read(self, mem: List[int], addr: int) -> int: sel = self._addr_decode(addr) mem_bits = torch.tensor( [int_to_bits_msb(byte, REG_BITS) for byte in mem], device=self.device, dtype=torch.float32, ) and_w = self.alu._get("memory.read.and.weight") and_b = self.alu._get("memory.read.and.bias") or_w = self.alu._get("memory.read.or.weight") or_b = self.alu._get("memory.read.or.bias") out_bits: List[int] = [] for bit in range(REG_BITS): inp = torch.stack([mem_bits[:, bit], sel], dim=1) and_out = heaviside((inp * and_w[bit]).sum(dim=1) + and_b[bit]) out_bit = heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item() out_bits.append(int(out_bit)) return bits_to_int_msb(out_bits) def _memory_write(self, mem: List[int], addr: int, value: int) -> List[int]: sel = self._addr_decode(addr) data_bits = torch.tensor(int_to_bits_msb(value, REG_BITS), device=self.device, dtype=torch.float32) mem_bits = torch.tensor( [int_to_bits_msb(byte, REG_BITS) for byte in mem], device=self.device, dtype=torch.float32, ) sel_w = self.alu._get("memory.write.sel.weight") sel_b = self.alu._get("memory.write.sel.bias") nsel_w = self.alu._get("memory.write.nsel.weight").squeeze(1) nsel_b = self.alu._get("memory.write.nsel.bias") and_old_w = self.alu._get("memory.write.and_old.weight") and_old_b = self.alu._get("memory.write.and_old.bias") and_new_w = self.alu._get("memory.write.and_new.weight") and_new_b = self.alu._get("memory.write.and_new.bias") or_w = self.alu._get("memory.write.or.weight") or_b = self.alu._get("memory.write.or.bias") we = torch.ones_like(sel) sel_inp = torch.stack([sel, we], dim=1) write_sel = heaviside((sel_inp * sel_w).sum(dim=1) + sel_b) nsel = heaviside((write_sel * nsel_w) + nsel_b) new_mem_bits = torch.zeros((MEM_BYTES, REG_BITS), device=self.device) for bit in range(REG_BITS): old_bit = mem_bits[:, bit] data_bit = data_bits[bit].expand(MEM_BYTES) inp_old = torch.stack([old_bit, nsel], dim=1) inp_new = torch.stack([data_bit, write_sel], dim=1) and_old = heaviside((inp_old * and_old_w[:, bit]).sum(dim=1) + and_old_b[:, bit]) and_new = heaviside((inp_new * and_new_w[:, bit]).sum(dim=1) + and_new_b[:, bit]) or_inp = torch.stack([and_old, and_new], dim=1) out_bit = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit]) new_mem_bits[:, bit] = out_bit return [bits_to_int_msb([int(b) for b in new_mem_bits[i].tolist()]) for i in range(MEM_BYTES)] def _conditional_jump_byte(self, prefix: str, pc_byte: int, target_byte: int, flag: int) -> int: pc_bits = int_to_bits_msb(pc_byte, REG_BITS) target_bits = int_to_bits_msb(target_byte, REG_BITS) out_bits: List[int] = [] for bit in range(REG_BITS): not_sel = self.alu._eval_gate( f"{prefix}.bit{bit}.not_sel.weight", f"{prefix}.bit{bit}.not_sel.bias", [float(flag)], ) and_a = self.alu._eval_gate( f"{prefix}.bit{bit}.and_a.weight", f"{prefix}.bit{bit}.and_a.bias", [float(pc_bits[bit]), not_sel], ) and_b = self.alu._eval_gate( f"{prefix}.bit{bit}.and_b.weight", f"{prefix}.bit{bit}.and_b.bias", [float(target_bits[bit]), float(flag)], ) out_bit = self.alu._eval_gate( f"{prefix}.bit{bit}.or.weight", f"{prefix}.bit{bit}.or.bias", [and_a, and_b], ) out_bits.append(int(out_bit)) return bits_to_int_msb(out_bits) def step(self, state: CPUState) -> CPUState: if state.ctrl[0] == 1: # HALT return state.copy() s = state.copy() # Fetch: two bytes, big-endian hi = self._memory_read(s.mem, s.pc) lo = self._memory_read(s.mem, (s.pc + 1) & 0xFFFF) s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF) next_pc = (s.pc + 2) & 0xFFFF opcode, rd, rs, imm8 = self.decode_ir(s.ir) a = s.regs[rd] b = s.regs[rs] addr16 = None next_pc_ext = next_pc if opcode in (0xA, 0xB, 0xC, 0xD, 0xE): addr_hi = self._memory_read(s.mem, next_pc) addr_lo = self._memory_read(s.mem, (next_pc + 1) & 0xFFFF) addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF) next_pc_ext = (next_pc + 2) & 0xFFFF write_result = True result = a carry = 0 overflow = 0 if opcode == 0x0: # ADD result, carry, overflow = self.alu.add(a, b) elif opcode == 0x1: # SUB result, carry, overflow = self.alu.sub(a, b) elif opcode == 0x2: # AND result = self.alu.bitwise_and(a, b) elif opcode == 0x3: # OR result = self.alu.bitwise_or(a, b) elif opcode == 0x4: # XOR result = self.alu.bitwise_xor(a, b) elif opcode == 0x5: # SHL carry = 1 if (a & 0x80) else 0 result = (a << 1) & 0xFF elif opcode == 0x6: # SHR carry = 1 if (a & 0x01) else 0 result = (a >> 1) & 0xFF elif opcode == 0x7: # MUL full = a * b result = full & 0xFF carry = 1 if full > 0xFF else 0 elif opcode == 0x8: # DIV if b == 0: result = 0 carry = 1 overflow = 1 else: result = (a // b) & 0xFF elif opcode == 0x9: # CMP result, carry, overflow = self.alu.sub(a, b) write_result = False elif opcode == 0xA: # LOAD result = self._memory_read(s.mem, addr16) elif opcode == 0xB: # STORE s.mem = self._memory_write(s.mem, addr16, b & 0xFF) write_result = False elif opcode == 0xC: # JMP s.pc = addr16 & 0xFFFF write_result = False elif opcode == 0xD: # JZ hi_pc = self._conditional_jump_byte( "control.jz", (next_pc_ext >> 8) & 0xFF, (addr16 >> 8) & 0xFF, s.flags[0], ) lo_pc = self._conditional_jump_byte( "control.jz", next_pc_ext & 0xFF, addr16 & 0xFF, s.flags[0], ) s.pc = ((hi_pc & 0xFF) << 8) | (lo_pc & 0xFF) write_result = False elif opcode == 0xE: # CALL ret_addr = next_pc_ext & 0xFFFF s.sp = (s.sp - 1) & 0xFFFF s.mem = self._memory_write(s.mem, s.sp, (ret_addr >> 8) & 0xFF) s.sp = (s.sp - 1) & 0xFFFF s.mem = self._memory_write(s.mem, s.sp, ret_addr & 0xFF) s.pc = addr16 & 0xFFFF write_result = False elif opcode == 0xF: # HALT s.ctrl[0] = 1 write_result = False if opcode <= 0x9 or opcode == 0xA: s.flags = self.flags_from_result(result, carry, overflow) if write_result: s.regs[rd] = result & 0xFF if opcode not in (0xC, 0xD, 0xE): s.pc = next_pc_ext return s def run_until_halt(self, state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]: s = state.copy() for i in range(max_cycles): if s.ctrl[0] == 1: return s, i s = self.step(s) return s, max_cycles def forward(self, state_bits: torch.Tensor, max_cycles: int = 256) -> torch.Tensor: bits_list = [int(b) for b in state_bits.detach().cpu().flatten().tolist()] state = unpack_state(bits_list) final, _ = self.run_until_halt(state, max_cycles=max_cycles) return torch.tensor(pack_state(final), dtype=torch.float32)