# 8bit-threshold-computer / cpu / threshold_cpu.py
# Source: PortfolioAI — commit ea46629 ("Add packed memory routing and 16-bit addressing")
"""
Threshold-weight runtime for the 8-bit CPU.
Implements a reference cycle using the frozen circuit weights for core ALU ops.
"""
from __future__ import annotations
from pathlib import Path
from typing import List, Tuple
import torch
from safetensors.torch import load_file
from .state import CPUState, pack_state, unpack_state, REG_BITS, PC_BITS, MEM_BYTES
def heaviside(x: torch.Tensor) -> torch.Tensor:
    """Hard threshold activation: 1.0 where x >= 0, 0.0 elsewhere (float32)."""
    return x.ge(0).float()
def int_to_bits_msb(value: int, width: int) -> List[int]:
    """Expand *value* into *width* bits, most-significant bit first."""
    bits: List[int] = []
    for shift in range(width - 1, -1, -1):
        bits.append((value >> shift) & 1)
    return bits
def bits_to_int_msb(bits: List[int]) -> int:
    """Collapse an MSB-first bit list back into a non-negative integer."""
    total = 0
    for position, bit in enumerate(reversed(bits)):
        total |= int(bit) << position
    return total
def bits_msb_to_lsb(bits: List[int]) -> List[int]:
    """Return a new list with the bit order reversed (MSB-first -> LSB-first)."""
    return bits[::-1]
DEFAULT_MODEL_PATH = Path(__file__).resolve().parent.parent / "neural_computer.safetensors"
class ThresholdALU:
    """Reference evaluator for the frozen threshold-gate ALU circuits.

    Every logic gate is a single linear threshold unit evaluated as
    ``heaviside(w . x + b)``; multi-layer circuits (XOR, full adders) are
    composed gate by gate from named tensors in the checkpoint.
    """

    def __init__(self, model_path: str, device: str = "cpu") -> None:
        """Load the frozen circuit weights from *model_path* onto *device*."""
        self.device = device
        # Cast everything to float32 once so gate evaluation never re-converts.
        self.tensors = {k: v.float().to(device) for k, v in load_file(model_path).items()}

    def _get(self, name: str) -> torch.Tensor:
        """Return the checkpoint tensor stored under *name*."""
        return self.tensors[name]

    def _eval_gate(self, weight_key: str, bias_key: str, inputs: List[float]) -> float:
        """Evaluate a single threshold gate on *inputs*; returns 0.0 or 1.0."""
        w = self._get(weight_key)
        b = self._get(bias_key)
        inp = torch.tensor(inputs, device=self.device)
        return heaviside((inp * w).sum() + b).item()

    def _eval_xor(self, prefix: str, inputs: List[float]) -> float:
        """Evaluate a two-layer XOR circuit rooted at *prefix*.

        Layer 1 computes OR and NAND of the inputs; layer 2 combines them,
        the classic threshold-gate XOR decomposition (XOR = OR AND NAND).
        """
        inp = torch.tensor(inputs, device=self.device)
        w_or = self._get(f"{prefix}.layer1.or.weight")
        b_or = self._get(f"{prefix}.layer1.or.bias")
        w_nand = self._get(f"{prefix}.layer1.nand.weight")
        b_nand = self._get(f"{prefix}.layer1.nand.bias")
        w2 = self._get(f"{prefix}.layer2.weight")
        b2 = self._get(f"{prefix}.layer2.bias")
        h_or = heaviside((inp * w_or).sum() + b_or).item()
        h_nand = heaviside((inp * w_nand).sum() + b_nand).item()
        hidden = torch.tensor([h_or, h_nand], device=self.device)
        return heaviside((hidden * w2).sum() + b2).item()

    def _eval_full_adder(self, prefix: str, a: float, b: float, cin: float) -> Tuple[float, float]:
        """Evaluate one full adder (two half adders plus a carry OR).

        Returns ``(sum_bit, carry_out)``.
        """
        ha1_sum = self._eval_xor(f"{prefix}.ha1.sum", [a, b])
        ha1_carry = self._eval_gate(f"{prefix}.ha1.carry.weight", f"{prefix}.ha1.carry.bias", [a, b])
        ha2_sum = self._eval_xor(f"{prefix}.ha2.sum", [ha1_sum, cin])
        ha2_carry = self._eval_gate(
            f"{prefix}.ha2.carry.weight", f"{prefix}.ha2.carry.bias", [ha1_sum, cin]
        )
        cout = self._eval_gate(f"{prefix}.carry_or.weight", f"{prefix}.carry_or.bias", [ha1_carry, ha2_carry])
        return ha2_sum, cout

    def add(self, a: int, b: int) -> Tuple[int, int, int]:
        """Add two REG_BITS-wide values through the ripple-carry circuit.

        Returns ``(result, carry_out, signed_overflow)``.
        """
        # The ripple-carry chain runs LSB-first, so flip the MSB-first bits.
        a_bits = bits_msb_to_lsb(int_to_bits_msb(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits_msb(b, REG_BITS))
        carry = 0.0
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            sum_bit, carry = self._eval_full_adder(
                f"arithmetic.ripplecarry8bit.fa{bit}", float(a_bits[bit]), float(b_bits[bit]), carry
            )
            sum_bits.append(int(sum_bit))
        result = bits_to_int_msb(list(reversed(sum_bits)))
        carry_out = int(carry)
        # Signed overflow: both operands share a sign bit that differs from
        # the result's sign bit.
        overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def sub(self, a: int, b: int) -> Tuple[int, int, int]:
        """Subtract ``b`` from ``a`` via two's complement (a + ~b + 1).

        Returns ``(result, carry_out, signed_overflow)``; carry_out == 1
        means no borrow occurred.
        """
        a_bits = bits_msb_to_lsb(int_to_bits_msb(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits_msb(b, REG_BITS))
        carry = 1.0  # two's complement carry-in
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            # Invert the subtrahend bit, then run a gate-level full adder.
            notb = self._eval_gate(
                f"arithmetic.sub8bit.notb{bit}.weight",
                f"arithmetic.sub8bit.notb{bit}.bias",
                [float(b_bits[bit])],
            )
            xor1 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor1", [float(a_bits[bit]), notb])
            xor2 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor2", [xor1, carry])
            and1 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and1.weight",
                f"arithmetic.sub8bit.fa{bit}.and1.bias",
                [float(a_bits[bit]), notb],
            )
            and2 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and2.weight",
                f"arithmetic.sub8bit.fa{bit}.and2.bias",
                [xor1, carry],
            )
            carry = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.or_carry.weight",
                f"arithmetic.sub8bit.fa{bit}.or_carry.bias",
                [and1, and2],
            )
            sum_bits.append(int(xor2))
        result = bits_to_int_msb(list(reversed(sum_bits)))
        carry_out = int(carry)
        # Overflow when the operands differ in sign and the result's sign
        # disagrees with the minuend's.
        overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def _eval_bitwise(self, op_name: str, columns: List[List[int]]) -> int:
        """Evaluate the one-layer bitwise circuit ``alu.alu8bit.<op_name>``.

        *columns* holds one MSB-first bit list per operand.  The circuit's
        weight tensor packs ``fan_in`` weights per output bit lane, so bit
        ``i`` uses the slice ``w[i*fan_in:(i+1)*fan_in]`` and bias ``b[i]``.
        Returns the result as an integer.
        """
        fan_in = len(columns)
        w = self._get(f"alu.alu8bit.{op_name}.weight")
        bias = self._get(f"alu.alu8bit.{op_name}.bias")
        out_bits: List[int] = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(col[bit]) for col in columns], device=self.device)
            lane = w[bit * fan_in:(bit + 1) * fan_in]
            out_bits.append(int(heaviside((inp * lane).sum() + bias[bit]).item()))
        return bits_to_int_msb(out_bits)

    def bitwise_and(self, a: int, b: int) -> int:
        """Bitwise AND through the per-bit AND gate circuit."""
        return self._eval_bitwise(
            "and", [int_to_bits_msb(a, REG_BITS), int_to_bits_msb(b, REG_BITS)]
        )

    def bitwise_or(self, a: int, b: int) -> int:
        """Bitwise OR through the per-bit OR gate circuit."""
        return self._eval_bitwise(
            "or", [int_to_bits_msb(a, REG_BITS), int_to_bits_msb(b, REG_BITS)]
        )

    def bitwise_not(self, a: int) -> int:
        """Bitwise NOT through the per-bit inverter circuit."""
        return self._eval_bitwise("not", [int_to_bits_msb(a, REG_BITS)])

    def bitwise_xor(self, a: int, b: int) -> int:
        """Bitwise XOR through the shared two-layer (OR/NAND -> AND) circuit.

        Unlike AND/OR/NOT this circuit has a hidden layer, so it is evaluated
        inline rather than through :meth:`_eval_bitwise`.
        """
        a_bits = int_to_bits_msb(a, REG_BITS)
        b_bits = int_to_bits_msb(b, REG_BITS)
        w_or = self._get("alu.alu8bit.xor.layer1.or.weight")
        b_or = self._get("alu.alu8bit.xor.layer1.or.bias")
        w_nand = self._get("alu.alu8bit.xor.layer1.nand.weight")
        b_nand = self._get("alu.alu8bit.xor.layer1.nand.bias")
        w2 = self._get("alu.alu8bit.xor.layer2.weight")
        b2 = self._get("alu.alu8bit.xor.layer2.bias")
        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            h_or = heaviside((inp * w_or[bit * 2:bit * 2 + 2]).sum() + b_or[bit])
            h_nand = heaviside((inp * w_nand[bit * 2:bit * 2 + 2]).sum() + b_nand[bit])
            hidden = torch.stack([h_or, h_nand])
            out = heaviside((hidden * w2[bit * 2:bit * 2 + 2]).sum() + b2[bit]).item()
            out_bits.append(int(out))
        return bits_to_int_msb(out_bits)
class ThresholdCPU:
    """Fetch/decode/execute loop for the 8-bit threshold-gate CPU.

    Arithmetic and logic route through :class:`ThresholdALU`; memory reads,
    memory writes and the conditional-jump PC mux are likewise evaluated
    through frozen threshold circuits rather than native Python operators.
    Shifts, MUL and DIV are computed natively (no circuit exists for them
    in the checkpoint used here).
    """

    def __init__(self, model_path: str | Path = DEFAULT_MODEL_PATH, device: str = "cpu") -> None:
        """Create a CPU backed by the circuit weights at *model_path*."""
        self.device = device
        self.alu = ThresholdALU(str(model_path), device=device)

    @staticmethod
    def decode_ir(ir: int) -> Tuple[int, int, int, int]:
        """Split a 16-bit instruction word into ``(opcode, rd, rs, imm8)``.

        Bit layout (MSB first): opcode[15:12], rd[11:10], rs[9:8], imm8[7:0].
        """
        opcode = (ir >> 12) & 0xF
        rd = (ir >> 10) & 0x3
        rs = (ir >> 8) & 0x3
        imm8 = ir & 0xFF
        return opcode, rd, rs, imm8

    @staticmethod
    def flags_from_result(result: int, carry: int, overflow: int) -> List[int]:
        """Build the flag vector ``[Z, N, C, V]`` from an 8-bit ALU result."""
        z = 1 if result == 0 else 0
        n = 1 if (result & 0x80) else 0  # sign bit of the 8-bit result
        c = 1 if carry else 0
        v = 1 if overflow else 0
        return [z, n, c, v]

    def _addr_decode(self, addr: int) -> torch.Tensor:
        """Decode *addr* into a per-cell selection vector via the address circuit."""
        bits = torch.tensor(int_to_bits_msb(addr, PC_BITS), device=self.device, dtype=torch.float32)
        w = self.alu._get("memory.addr_decode.weight")
        b = self.alu._get("memory.addr_decode.bias")
        return heaviside((w * bits).sum(dim=1) + b)

    def _memory_read(self, mem: List[int], addr: int) -> int:
        """Read one byte from *mem* at *addr* through the AND/OR mux circuit.

        Each output bit ANDs every cell's bit with the cell-select line, then
        ORs across all cells — only the selected cell contributes.
        """
        sel = self._addr_decode(addr)
        mem_bits = torch.tensor(
            [int_to_bits_msb(byte, REG_BITS) for byte in mem],
            device=self.device,
            dtype=torch.float32,
        )
        and_w = self.alu._get("memory.read.and.weight")
        and_b = self.alu._get("memory.read.and.bias")
        or_w = self.alu._get("memory.read.or.weight")
        or_b = self.alu._get("memory.read.or.bias")
        out_bits: List[int] = []
        for bit in range(REG_BITS):
            inp = torch.stack([mem_bits[:, bit], sel], dim=1)
            and_out = heaviside((inp * and_w[bit]).sum(dim=1) + and_b[bit])
            out_bit = heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item()
            out_bits.append(int(out_bit))
        return bits_to_int_msb(out_bits)

    def _memory_write(self, mem: List[int], addr: int, value: int) -> List[int]:
        """Return a copy of *mem* with *value* written at *addr*.

        Latch-style mux: old bits pass through cells that are NOT selected,
        data bits pass through the selected cell, and the two paths are ORed.
        The write-enable line is held at 1 for every call.
        """
        sel = self._addr_decode(addr)
        data_bits = torch.tensor(int_to_bits_msb(value, REG_BITS), device=self.device, dtype=torch.float32)
        mem_bits = torch.tensor(
            [int_to_bits_msb(byte, REG_BITS) for byte in mem],
            device=self.device,
            dtype=torch.float32,
        )
        sel_w = self.alu._get("memory.write.sel.weight")
        sel_b = self.alu._get("memory.write.sel.bias")
        nsel_w = self.alu._get("memory.write.nsel.weight").squeeze(1)
        nsel_b = self.alu._get("memory.write.nsel.bias")
        and_old_w = self.alu._get("memory.write.and_old.weight")
        and_old_b = self.alu._get("memory.write.and_old.bias")
        and_new_w = self.alu._get("memory.write.and_new.weight")
        and_new_b = self.alu._get("memory.write.and_new.bias")
        or_w = self.alu._get("memory.write.or.weight")
        or_b = self.alu._get("memory.write.or.bias")
        we = torch.ones_like(sel)  # write enable: always asserted
        sel_inp = torch.stack([sel, we], dim=1)
        write_sel = heaviside((sel_inp * sel_w).sum(dim=1) + sel_b)
        nsel = heaviside((write_sel * nsel_w) + nsel_b)  # inverted select
        new_mem_bits = torch.zeros((MEM_BYTES, REG_BITS), device=self.device)
        for bit in range(REG_BITS):
            old_bit = mem_bits[:, bit]
            data_bit = data_bits[bit].expand(MEM_BYTES)
            inp_old = torch.stack([old_bit, nsel], dim=1)
            inp_new = torch.stack([data_bit, write_sel], dim=1)
            and_old = heaviside((inp_old * and_old_w[:, bit]).sum(dim=1) + and_old_b[:, bit])
            and_new = heaviside((inp_new * and_new_w[:, bit]).sum(dim=1) + and_new_b[:, bit])
            or_inp = torch.stack([and_old, and_new], dim=1)
            out_bit = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit])
            new_mem_bits[:, bit] = out_bit
        return [bits_to_int_msb([int(b) for b in new_mem_bits[i].tolist()]) for i in range(MEM_BYTES)]

    def _conditional_jump_byte(self, prefix: str, pc_byte: int, target_byte: int, flag: int) -> int:
        """Mux one PC byte: *target_byte* when *flag* is 1, else *pc_byte*.

        Per bit: out = (pc AND NOT flag) OR (target AND flag), each gate taken
        from the circuit rooted at *prefix*.
        """
        pc_bits = int_to_bits_msb(pc_byte, REG_BITS)
        target_bits = int_to_bits_msb(target_byte, REG_BITS)
        out_bits: List[int] = []
        for bit in range(REG_BITS):
            not_sel = self.alu._eval_gate(
                f"{prefix}.bit{bit}.not_sel.weight",
                f"{prefix}.bit{bit}.not_sel.bias",
                [float(flag)],
            )
            and_a = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_a.weight",
                f"{prefix}.bit{bit}.and_a.bias",
                [float(pc_bits[bit]), not_sel],
            )
            and_b = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_b.weight",
                f"{prefix}.bit{bit}.and_b.bias",
                [float(target_bits[bit]), float(flag)],
            )
            out_bit = self.alu._eval_gate(
                f"{prefix}.bit{bit}.or.weight",
                f"{prefix}.bit{bit}.or.bias",
                [and_a, and_b],
            )
            out_bits.append(int(out_bit))
        return bits_to_int_msb(out_bits)

    def step(self, state: CPUState) -> CPUState:
        """Execute one fetch/decode/execute cycle and return the next state.

        A halted CPU (ctrl[0] == 1) is returned unchanged (as a copy).
        """
        if state.ctrl[0] == 1:  # HALT latch set: nothing to do
            return state.copy()
        s = state.copy()
        # Fetch: the 16-bit instruction word is stored big-endian in memory.
        hi = self._memory_read(s.mem, s.pc)
        lo = self._memory_read(s.mem, (s.pc + 1) & 0xFFFF)
        s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
        next_pc = (s.pc + 2) & 0xFFFF
        opcode, rd, rs, imm8 = self.decode_ir(s.ir)
        # imm8 is decoded but unused here: the memory/jump opcodes below read
        # a trailing 16-bit address word instead of the inline immediate.
        a = s.regs[rd]
        b = s.regs[rs]
        addr16 = None
        next_pc_ext = next_pc
        if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
            # LOAD/STORE/JMP/JZ/CALL carry a big-endian 16-bit address word
            # immediately after the instruction.
            addr_hi = self._memory_read(s.mem, next_pc)
            addr_lo = self._memory_read(s.mem, (next_pc + 1) & 0xFFFF)
            addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
            next_pc_ext = (next_pc + 2) & 0xFFFF
        write_result = True
        result = a
        carry = 0
        overflow = 0
        if opcode == 0x0:  # ADD
            result, carry, overflow = self.alu.add(a, b)
        elif opcode == 0x1:  # SUB
            result, carry, overflow = self.alu.sub(a, b)
        elif opcode == 0x2:  # AND
            result = self.alu.bitwise_and(a, b)
        elif opcode == 0x3:  # OR
            result = self.alu.bitwise_or(a, b)
        elif opcode == 0x4:  # XOR
            result = self.alu.bitwise_xor(a, b)
        elif opcode == 0x5:  # SHL — carry receives the shifted-out MSB
            carry = 1 if (a & 0x80) else 0
            result = (a << 1) & 0xFF
        elif opcode == 0x6:  # SHR — carry receives the shifted-out LSB
            carry = 1 if (a & 0x01) else 0
            result = (a >> 1) & 0xFF
        elif opcode == 0x7:  # MUL — low byte kept, carry flags truncation
            full = a * b
            result = full & 0xFF
            carry = 1 if full > 0xFF else 0
        elif opcode == 0x8:  # DIV — division by zero yields 0 with C and V set
            if b == 0:
                result = 0
                carry = 1
                overflow = 1
            else:
                result = (a // b) & 0xFF
        elif opcode == 0x9:  # CMP — subtract for flags only
            result, carry, overflow = self.alu.sub(a, b)
            write_result = False
        elif opcode == 0xA:  # LOAD
            result = self._memory_read(s.mem, addr16)
        elif opcode == 0xB:  # STORE — stores rs, not rd
            s.mem = self._memory_write(s.mem, addr16, b & 0xFF)
            write_result = False
        elif opcode == 0xC:  # JMP
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xD:  # JZ — PC muxed byte-wise on the Z flag
            hi_pc = self._conditional_jump_byte(
                "control.jz",
                (next_pc_ext >> 8) & 0xFF,
                (addr16 >> 8) & 0xFF,
                s.flags[0],
            )
            lo_pc = self._conditional_jump_byte(
                "control.jz",
                next_pc_ext & 0xFF,
                addr16 & 0xFF,
                s.flags[0],
            )
            s.pc = ((hi_pc & 0xFF) << 8) | (lo_pc & 0xFF)
            write_result = False
        elif opcode == 0xE:  # CALL — push big-endian return address, then jump
            ret_addr = next_pc_ext & 0xFFFF
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, (ret_addr >> 8) & 0xFF)
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, ret_addr & 0xFF)
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xF:  # HALT — latch the halt bit
            s.ctrl[0] = 1
            write_result = False
        # ALU ops (0x0-0x9) and LOAD update flags; STORE, control flow and
        # HALT leave them untouched.
        if opcode <= 0xA:
            s.flags = self.flags_from_result(result, carry, overflow)
        if write_result:
            s.regs[rd] = result & 0xFF
        # JMP/JZ/CALL already set the PC themselves.
        if opcode not in (0xC, 0xD, 0xE):
            s.pc = next_pc_ext
        return s

    def run_until_halt(self, state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
        """Step the CPU until HALT or *max_cycles*; returns (state, cycles)."""
        s = state.copy()
        for i in range(max_cycles):
            if s.ctrl[0] == 1:
                return s, i
            s = self.step(s)
        return s, max_cycles

    def forward(self, state_bits: torch.Tensor, max_cycles: int = 256) -> torch.Tensor:
        """Bit-vector interface: unpack the state, run to halt, repack."""
        bits_list = [int(b) for b in state_bits.detach().cpu().flatten().tolist()]
        state = unpack_state(bits_list)
        final, _ = self.run_until_halt(state, max_cycles=max_cycles)
        return torch.tensor(pack_state(final), dtype=torch.float32)