| | """ |
| | Threshold-weight runtime for the 8-bit CPU. |
| | |
| | Implements a reference cycle using the frozen circuit weights for core ALU ops. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | from pathlib import Path |
| | from typing import List, Tuple |
| |
|
| | import torch |
| | from safetensors.torch import load_file |
| |
|
| | from .state import CPUState, pack_state, unpack_state, REG_BITS, PC_BITS, MEM_BYTES |
| |
|
| |
|
def heaviside(x: torch.Tensor) -> torch.Tensor:
    """Unit-step activation: 1.0 where ``x >= 0``, 0.0 elsewhere."""
    return torch.ge(x, 0).to(torch.float32)
| |
|
| |
|
def int_to_bits_msb(value: int, width: int) -> List[int]:
    """Split ``value`` into ``width`` bits, most-significant bit first."""
    shifts = range(width - 1, -1, -1)
    return [(value >> s) & 1 for s in shifts]
| |
|
| |
|
def bits_to_int_msb(bits: List[int]) -> int:
    """Reassemble an integer from a most-significant-first bit list."""
    result = 0
    # Walk from the least-significant end so each bit lands at its shift.
    for position, bit in enumerate(reversed(bits)):
        result |= int(bit) << position
    return result
| |
|
| |
|
def bits_msb_to_lsb(bits: List[int]) -> List[int]:
    """Return a new list with the bit order flipped (MSB-first <-> LSB-first)."""
    return bits[::-1]
| |
|
| |
|
| | DEFAULT_MODEL_PATH = Path(__file__).resolve().parent.parent / "neural_computer.safetensors" |
| |
|
| |
|
class ThresholdALU:
    """8-bit ALU evaluated gate-by-gate from frozen threshold weights.

    Every logic gate in the checkpoint is a perceptron that fires when
    ``weight . inputs + bias >= 0``.  The methods below walk those gates
    in circuit order instead of doing native integer arithmetic, so the
    results come from the stored weights themselves.
    """

    def __init__(self, model_path: str, device: str = "cpu") -> None:
        """Load every tensor in the safetensors file onto ``device`` as float32."""
        self.device = device
        self.tensors = {k: v.float().to(device) for k, v in load_file(model_path).items()}

    def _get(self, name: str) -> torch.Tensor:
        """Look up one named weight/bias tensor from the checkpoint."""
        return self.tensors[name]

    def _eval_gate(self, weight_key: str, bias_key: str, inputs: List[float]) -> float:
        """Evaluate a single threshold gate: step(w . x + b) -> 0.0 or 1.0."""
        w = self._get(weight_key)
        b = self._get(bias_key)
        inp = torch.tensor(inputs, device=self.device)
        return heaviside((inp * w).sum() + b).item()

    def _eval_xor(self, prefix: str, inputs: List[float]) -> float:
        """Evaluate a two-layer XOR circuit.

        Layer 1 holds an OR gate and a NAND gate; layer 2 combines their
        outputs (the classic two-layer XOR decomposition).
        """
        inp = torch.tensor(inputs, device=self.device)
        w_or = self._get(f"{prefix}.layer1.or.weight")
        b_or = self._get(f"{prefix}.layer1.or.bias")
        w_nand = self._get(f"{prefix}.layer1.nand.weight")
        b_nand = self._get(f"{prefix}.layer1.nand.bias")
        w2 = self._get(f"{prefix}.layer2.weight")
        b2 = self._get(f"{prefix}.layer2.bias")

        h_or = heaviside((inp * w_or).sum() + b_or).item()
        h_nand = heaviside((inp * w_nand).sum() + b_nand).item()
        hidden = torch.tensor([h_or, h_nand], device=self.device)
        return heaviside((hidden * w2).sum() + b2).item()

    def _eval_full_adder(self, prefix: str, a: float, b: float, cin: float) -> Tuple[float, float]:
        """One full adder built from two half adders plus a carry OR.

        Returns ``(sum_bit, carry_out)`` as 0.0/1.0 floats.
        """
        ha1_sum = self._eval_xor(f"{prefix}.ha1.sum", [a, b])
        ha1_carry = self._eval_gate(f"{prefix}.ha1.carry.weight", f"{prefix}.ha1.carry.bias", [a, b])

        ha2_sum = self._eval_xor(f"{prefix}.ha2.sum", [ha1_sum, cin])
        ha2_carry = self._eval_gate(
            f"{prefix}.ha2.carry.weight", f"{prefix}.ha2.carry.bias", [ha1_sum, cin]
        )

        cout = self._eval_gate(
            f"{prefix}.carry_or.weight", f"{prefix}.carry_or.bias", [ha1_carry, ha2_carry]
        )
        return ha2_sum, cout

    def _bitwise_pairwise(self, gate: str, a: int, b: int) -> int:
        """Apply a two-input gate bank (``'and'`` or ``'or'``) bit by bit.

        The bank stores two flat weights per output bit, hence the
        ``bit * 2`` slicing into the weight vector.
        """
        a_bits = int_to_bits_msb(a, REG_BITS)
        b_bits = int_to_bits_msb(b, REG_BITS)
        w = self._get(f"alu.alu8bit.{gate}.weight")
        bias = self._get(f"alu.alu8bit.{gate}.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
            out_bits.append(int(out))

        return bits_to_int_msb(out_bits)

    def add(self, a: int, b: int) -> Tuple[int, int, int]:
        """8-bit ripple-carry addition through the circuit adders.

        Returns ``(result, carry_out, signed_overflow)``.
        """
        # Full adders fa0..fa7 are indexed from the least-significant bit.
        a_bits = bits_msb_to_lsb(int_to_bits_msb(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits_msb(b, REG_BITS))

        carry = 0.0
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            sum_bit, carry = self._eval_full_adder(
                f"arithmetic.ripplecarry8bit.fa{bit}", float(a_bits[bit]), float(b_bits[bit]), carry
            )
            sum_bits.append(int(sum_bit))

        result = bits_to_int_msb(list(reversed(sum_bits)))
        carry_out = int(carry)
        # Signed overflow: both operands' sign bits differ from the result's.
        overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def sub(self, a: int, b: int) -> Tuple[int, int, int]:
        """8-bit subtraction computed as ``a + ~b + 1`` (two's complement).

        Returns ``(result, carry_out, signed_overflow)``.
        """
        a_bits = bits_msb_to_lsb(int_to_bits_msb(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits_msb(b, REG_BITS))

        carry = 1.0  # the "+1" of the two's-complement negation
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            notb = self._eval_gate(
                f"arithmetic.sub8bit.notb{bit}.weight",
                f"arithmetic.sub8bit.notb{bit}.bias",
                [float(b_bits[bit])],
            )

            # Full adder spelled out gate-by-gate: sum = a ^ ~b ^ cin.
            xor1 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor1", [float(a_bits[bit]), notb])
            xor2 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor2", [xor1, carry])

            and1 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and1.weight",
                f"arithmetic.sub8bit.fa{bit}.and1.bias",
                [float(a_bits[bit]), notb],
            )
            and2 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and2.weight",
                f"arithmetic.sub8bit.fa{bit}.and2.bias",
                [xor1, carry],
            )
            carry = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.or_carry.weight",
                f"arithmetic.sub8bit.fa{bit}.or_carry.bias",
                [and1, and2],
            )

            sum_bits.append(int(xor2))

        result = bits_to_int_msb(list(reversed(sum_bits)))
        carry_out = int(carry)
        # Signed overflow: operands differ in sign and the result flipped from a's sign.
        overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def bitwise_and(self, a: int, b: int) -> int:
        """Bitwise AND through the frozen AND gate bank."""
        return self._bitwise_pairwise("and", a, b)

    def bitwise_or(self, a: int, b: int) -> int:
        """Bitwise OR through the frozen OR gate bank."""
        return self._bitwise_pairwise("or", a, b)

    def bitwise_not(self, a: int) -> int:
        """Bitwise NOT through the frozen inverter bank (one weight per bit)."""
        a_bits = int_to_bits_msb(a, REG_BITS)
        w = self._get("alu.alu8bit.not.weight")
        bias = self._get("alu.alu8bit.not.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit]).sum() + bias[bit]).item()
            out_bits.append(int(out))

        return bits_to_int_msb(out_bits)

    def bitwise_xor(self, a: int, b: int) -> int:
        """Bitwise XOR through the two-layer XOR bank (OR/NAND into layer 2)."""
        a_bits = int_to_bits_msb(a, REG_BITS)
        b_bits = int_to_bits_msb(b, REG_BITS)

        w_or = self._get("alu.alu8bit.xor.layer1.or.weight")
        b_or = self._get("alu.alu8bit.xor.layer1.or.bias")
        w_nand = self._get("alu.alu8bit.xor.layer1.nand.weight")
        b_nand = self._get("alu.alu8bit.xor.layer1.nand.bias")
        w2 = self._get("alu.alu8bit.xor.layer2.weight")
        b2 = self._get("alu.alu8bit.xor.layer2.bias")

        out_bits = []
        for bit in range(REG_BITS):
            # Two flat weights per bit for each gate, hence the bit * 2 slices.
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            h_or = heaviside((inp * w_or[bit * 2:bit * 2 + 2]).sum() + b_or[bit])
            h_nand = heaviside((inp * w_nand[bit * 2:bit * 2 + 2]).sum() + b_nand[bit])
            hidden = torch.stack([h_or, h_nand])
            out = heaviside((hidden * w2[bit * 2:bit * 2 + 2]).sum() + b2[bit]).item()
            out_bits.append(int(out))

        return bits_to_int_msb(out_bits)
| |
|
| |
|
class ThresholdCPU:
    """Reference interpreter for the 8-bit CPU built on ThresholdALU.

    ``step`` executes one fetch/decode/execute cycle, using the frozen
    circuit weights for the ALU, the memory read/write muxes, and the
    conditional-jump datapath.  State travels in a ``CPUState``.
    """

    def __init__(self, model_path: str | Path = DEFAULT_MODEL_PATH, device: str = "cpu") -> None:
        """Load the circuit weights and build the ALU on ``device``."""
        self.device = device
        self.alu = ThresholdALU(str(model_path), device=device)

    @staticmethod
    def decode_ir(ir: int) -> Tuple[int, int, int, int]:
        """Split a 16-bit instruction word into ``(opcode, rd, rs, imm8)``.

        Layout: opcode in bits 15-12, rd in bits 11-10, rs in bits 9-8,
        imm8 in bits 7-0.
        """
        opcode = (ir >> 12) & 0xF
        rd = (ir >> 10) & 0x3
        rs = (ir >> 8) & 0x3
        imm8 = ir & 0xFF
        return opcode, rd, rs, imm8

    @staticmethod
    def flags_from_result(result: int, carry: int, overflow: int) -> List[int]:
        """Build the flag vector ``[Z, N, C, V]`` from an 8-bit result."""
        z = 1 if result == 0 else 0
        n = 1 if (result & 0x80) else 0  # bit 7 is the sign bit
        c = 1 if carry else 0
        v = 1 if overflow else 0
        return [z, n, c, v]

    def _addr_decode(self, addr: int) -> torch.Tensor:
        """Run the address-decoder circuit on ``addr``.

        Returns one select line per memory byte (weight rows act as
        per-address matchers over the PC_BITS address bits).
        """
        bits = torch.tensor(int_to_bits_msb(addr, PC_BITS), device=self.device, dtype=torch.float32)
        w = self.alu._get("memory.addr_decode.weight")
        b = self.alu._get("memory.addr_decode.bias")
        return heaviside((w * bits).sum(dim=1) + b)

    def _memory_read(self, mem: List[int], addr: int) -> int:
        """Read one byte through the circuit's AND/OR read mux.

        For each bit position: AND every byte's bit with its select line,
        then OR the column, so only the selected byte's bits survive.
        """
        sel = self._addr_decode(addr)
        mem_bits = torch.tensor(
            [int_to_bits_msb(byte, REG_BITS) for byte in mem],
            device=self.device,
            dtype=torch.float32,
        )
        and_w = self.alu._get("memory.read.and.weight")
        and_b = self.alu._get("memory.read.and.bias")
        or_w = self.alu._get("memory.read.or.weight")
        or_b = self.alu._get("memory.read.or.bias")

        out_bits: List[int] = []
        for bit in range(REG_BITS):
            # Pair every byte's bit with its select line: shape (MEM_BYTES, 2).
            inp = torch.stack([mem_bits[:, bit], sel], dim=1)
            and_out = heaviside((inp * and_w[bit]).sum(dim=1) + and_b[bit])
            out_bit = heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item()
            out_bits.append(int(out_bit))

        return bits_to_int_msb(out_bits)

    def _memory_write(self, mem: List[int], addr: int, value: int) -> List[int]:
        """Write ``value`` at ``addr`` through the circuit's write mux.

        Per byte i and bit j the new bit is
        ``(old AND NOT write_sel[i]) OR (data AND write_sel[i])``, so only
        the selected byte takes the new value.  Returns the new memory
        image as a list of ints; ``mem`` itself is not mutated.
        """
        sel = self._addr_decode(addr)
        data_bits = torch.tensor(int_to_bits_msb(value, REG_BITS), device=self.device, dtype=torch.float32)
        mem_bits = torch.tensor(
            [int_to_bits_msb(byte, REG_BITS) for byte in mem],
            device=self.device,
            dtype=torch.float32,
        )

        sel_w = self.alu._get("memory.write.sel.weight")
        sel_b = self.alu._get("memory.write.sel.bias")
        nsel_w = self.alu._get("memory.write.nsel.weight").squeeze(1)
        nsel_b = self.alu._get("memory.write.nsel.bias")
        and_old_w = self.alu._get("memory.write.and_old.weight")
        and_old_b = self.alu._get("memory.write.and_old.bias")
        and_new_w = self.alu._get("memory.write.and_new.weight")
        and_new_b = self.alu._get("memory.write.and_new.bias")
        or_w = self.alu._get("memory.write.or.weight")
        or_b = self.alu._get("memory.write.or.bias")

        # Write-enable held high: write_sel = sel AND we, nsel = NOT write_sel.
        we = torch.ones_like(sel)
        sel_inp = torch.stack([sel, we], dim=1)
        write_sel = heaviside((sel_inp * sel_w).sum(dim=1) + sel_b)
        nsel = heaviside((write_sel * nsel_w) + nsel_b)

        new_mem_bits = torch.zeros((MEM_BYTES, REG_BITS), device=self.device)
        for bit in range(REG_BITS):
            old_bit = mem_bits[:, bit]
            data_bit = data_bits[bit].expand(MEM_BYTES)
            inp_old = torch.stack([old_bit, nsel], dim=1)
            inp_new = torch.stack([data_bit, write_sel], dim=1)

            # Keep the old bit where not selected, take the data bit where selected.
            and_old = heaviside((inp_old * and_old_w[:, bit]).sum(dim=1) + and_old_b[:, bit])
            and_new = heaviside((inp_new * and_new_w[:, bit]).sum(dim=1) + and_new_b[:, bit])
            or_inp = torch.stack([and_old, and_new], dim=1)
            out_bit = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit])
            new_mem_bits[:, bit] = out_bit

        return [bits_to_int_msb([int(b) for b in new_mem_bits[i].tolist()]) for i in range(MEM_BYTES)]

    def _conditional_jump_byte(self, prefix: str, pc_byte: int, target_byte: int, flag: int) -> int:
        """Per-bit 2:1 mux: ``target_byte`` when ``flag`` is set, else ``pc_byte``.

        Each bit computes ``(pc_bit AND NOT flag) OR (target_bit AND flag)``
        from the frozen gate weights under ``prefix``.
        """
        pc_bits = int_to_bits_msb(pc_byte, REG_BITS)
        target_bits = int_to_bits_msb(target_byte, REG_BITS)

        out_bits: List[int] = []
        for bit in range(REG_BITS):
            not_sel = self.alu._eval_gate(
                f"{prefix}.bit{bit}.not_sel.weight",
                f"{prefix}.bit{bit}.not_sel.bias",
                [float(flag)],
            )
            and_a = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_a.weight",
                f"{prefix}.bit{bit}.and_a.bias",
                [float(pc_bits[bit]), not_sel],
            )
            and_b = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_b.weight",
                f"{prefix}.bit{bit}.and_b.bias",
                [float(target_bits[bit]), float(flag)],
            )
            out_bit = self.alu._eval_gate(
                f"{prefix}.bit{bit}.or.weight",
                f"{prefix}.bit{bit}.or.bias",
                [and_a, and_b],
            )
            out_bits.append(int(out_bit))

        return bits_to_int_msb(out_bits)

    def step(self, state: CPUState) -> CPUState:
        """Execute one fetch/decode/execute cycle and return the new state.

        A halted CPU (``ctrl[0] == 1``) is returned unchanged; the input
        state is always copied, never mutated in place.
        """
        if state.ctrl[0] == 1:
            return state.copy()

        s = state.copy()

        # Fetch: the 16-bit instruction word is stored big-endian at PC.
        hi = self._memory_read(s.mem, s.pc)
        lo = self._memory_read(s.mem, (s.pc + 1) & 0xFFFF)
        s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
        next_pc = (s.pc + 2) & 0xFFFF

        opcode, rd, rs, imm8 = self.decode_ir(s.ir)
        a = s.regs[rd]
        b = s.regs[rs]

        # Memory/control opcodes (0xA-0xE) carry a 16-bit big-endian
        # address in the two bytes following the instruction word.
        addr16 = None
        next_pc_ext = next_pc
        if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
            addr_hi = self._memory_read(s.mem, next_pc)
            addr_lo = self._memory_read(s.mem, (next_pc + 1) & 0xFFFF)
            addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
            next_pc_ext = (next_pc + 2) & 0xFFFF

        write_result = True
        result = a
        carry = 0
        overflow = 0

        if opcode == 0x0:  # rd <- rd + rs (circuit adder)
            result, carry, overflow = self.alu.add(a, b)
        elif opcode == 0x1:  # rd <- rd - rs (circuit subtractor)
            result, carry, overflow = self.alu.sub(a, b)
        elif opcode == 0x2:  # rd <- rd & rs
            result = self.alu.bitwise_and(a, b)
        elif opcode == 0x3:  # rd <- rd | rs
            result = self.alu.bitwise_or(a, b)
        elif opcode == 0x4:  # rd <- rd ^ rs
            result = self.alu.bitwise_xor(a, b)
        elif opcode == 0x5:  # shift left one; carry takes the old bit 7
            carry = 1 if (a & 0x80) else 0
            result = (a << 1) & 0xFF
        elif opcode == 0x6:  # logical shift right one; carry takes the old bit 0
            carry = 1 if (a & 0x01) else 0
            result = (a >> 1) & 0xFF
        elif opcode == 0x7:  # multiply; keep the low byte, carry flags truncation
            full = a * b
            result = full & 0xFF
            carry = 1 if full > 0xFF else 0
        elif opcode == 0x8:  # integer divide; division by zero -> 0 with C and V set
            if b == 0:
                result = 0
                carry = 1
                overflow = 1
            else:
                result = (a // b) & 0xFF
        elif opcode == 0x9:  # compare: subtract for flags only, no writeback
            result, carry, overflow = self.alu.sub(a, b)
            write_result = False
        elif opcode == 0xA:  # load rd from memory[addr16]
            result = self._memory_read(s.mem, addr16)
        elif opcode == 0xB:  # store rs to memory[addr16]
            s.mem = self._memory_write(s.mem, addr16, b & 0xFF)
            write_result = False
        elif opcode == 0xC:  # unconditional jump to addr16
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xD:  # jump if the zero flag is set, muxed byte-wise by the circuit
            hi_pc = self._conditional_jump_byte(
                "control.jz",
                (next_pc_ext >> 8) & 0xFF,
                (addr16 >> 8) & 0xFF,
                s.flags[0],
            )
            lo_pc = self._conditional_jump_byte(
                "control.jz",
                next_pc_ext & 0xFF,
                addr16 & 0xFF,
                s.flags[0],
            )
            s.pc = ((hi_pc & 0xFF) << 8) | (lo_pc & 0xFF)
            write_result = False
        elif opcode == 0xE:  # call: push return address (hi byte, then lo) and jump
            ret_addr = next_pc_ext & 0xFFFF
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, (ret_addr >> 8) & 0xFF)
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, ret_addr & 0xFF)
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xF:  # halt: latch the control bit; step() then no-ops
            s.ctrl[0] = 1
            write_result = False

        # ALU ops (0x0-0x9) and LOAD (0xA) refresh the flags.
        if opcode <= 0x9 or opcode == 0xA:
            s.flags = self.flags_from_result(result, carry, overflow)

        if write_result:
            s.regs[rd] = result & 0xFF

        # Branch opcodes set PC themselves; everything else falls through.
        if opcode not in (0xC, 0xD, 0xE):
            s.pc = next_pc_ext

        return s

    def run_until_halt(self, state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
        """Step until the halt bit is set or the cycle budget runs out.

        Returns ``(final_state, cycles_executed)``; when the budget is hit
        the reported count is ``max_cycles``.
        """
        s = state.copy()
        for i in range(max_cycles):
            if s.ctrl[0] == 1:
                return s, i
            s = self.step(s)
        return s, max_cycles

    def forward(self, state_bits: torch.Tensor, max_cycles: int = 256) -> torch.Tensor:
        """Run from a packed bit-vector state and return the packed final state.

        NOTE(review): assumes pack_state/unpack_state round-trip a flat 0/1
        bit layout — confirm against the state module.
        """
        bits_list = [int(b) for b in state_bits.detach().cpu().flatten().tolist()]
        state = unpack_state(bits_list)
        final, _ = self.run_until_halt(state, max_cycles=max_cycles)
        return torch.tensor(pack_state(final), dtype=torch.float32)
| |
|