|
|
""" |
|
|
8-bit Threshold Computer - CPU Runtime |
|
|
|
|
|
State layout, reference cycle, and threshold-weight execution. |
|
|
All multi-bit fields are MSB-first. |
|
|
|
|
|
Usage: |
|
|
python threshold_cpu.py # Run smoke test |
|
|
python threshold_cpu.py --help # Show options |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import argparse |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import List, Tuple |
|
|
|
|
|
import torch |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
|
|
|
# Flag register bit order (MSB-first): Zero, Negative, Carry, oVerflow.
FLAG_NAMES = ["Z", "N", "C", "V"]

# Control register bit order (MSB-first); only HALT is acted on in this file.
CTRL_NAMES = ["HALT", "MEM_WE", "MEM_RE", "RESERVED"]

# Field widths (in bits) of the architectural state.
PC_BITS = 16
IR_BITS = 16
REG_BITS = 8
REG_COUNT = 4
FLAG_BITS = 4
SP_BITS = 16
CTRL_BITS = 4
MEM_BYTES = 65536  # full 16-bit address space
MEM_BITS = MEM_BYTES * 8

# Total width of one flattened state bit-vector (see pack_state/unpack_state).
STATE_BITS = PC_BITS + IR_BITS + (REG_BITS * REG_COUNT) + FLAG_BITS + SP_BITS + CTRL_BITS + MEM_BITS

# Default location of the threshold-gate weights, next to this source file.
DEFAULT_MODEL_PATH = Path(__file__).resolve().parent / "neural_computer.safetensors"
|
|
|
|
|
|
|
|
def int_to_bits(value: int, width: int) -> List[int]:
    """Return *value* as a list of ``width`` bits, most significant first."""
    bits: List[int] = []
    for shift in range(width - 1, -1, -1):
        bits.append((value >> shift) & 1)
    return bits
|
|
|
|
|
|
|
|
def bits_to_int(bits: List[int]) -> int:
    """Fold an MSB-first bit list back into a non-negative integer."""
    total = 0
    for position, bit in enumerate(reversed(bits)):
        total |= int(bit) << position
    return total
|
|
|
|
|
|
|
|
def bits_msb_to_lsb(bits: List[int]) -> List[int]:
    """Return a reversed copy of *bits* (MSB-first order -> LSB-first)."""
    return bits[::-1]
|
|
|
|
|
|
|
|
@dataclass
class CPUState:
    """Complete architectural state of the 8-bit machine."""

    pc: int           # program counter, 16-bit
    ir: int           # instruction register, 16-bit
    regs: List[int]   # general-purpose registers (REG_COUNT bytes)
    flags: List[int]  # [Z, N, C, V] as 0/1 ints
    sp: int           # stack pointer, 16-bit
    ctrl: List[int]   # control lines; index 0 is HALT
    mem: List[int]    # 65536 memory bytes

    def copy(self) -> CPUState:
        """Return an independent deep copy (lists duplicated, ints coerced)."""
        return CPUState(
            pc=int(self.pc),
            ir=int(self.ir),
            regs=list(map(int, self.regs)),
            flags=list(map(int, self.flags)),
            sp=int(self.sp),
            ctrl=list(map(int, self.ctrl)),
            mem=list(map(int, self.mem)),
        )
|
|
|
|
|
|
|
|
def pack_state(state: CPUState) -> List[int]:
    """Flatten a CPUState into a STATE_BITS-long MSB-first bit list."""
    out: List[int] = []
    out += int_to_bits(state.pc, PC_BITS)
    out += int_to_bits(state.ir, IR_BITS)
    for value in state.regs:
        out += int_to_bits(value, REG_BITS)
    out += [int(flag) for flag in state.flags]
    out += int_to_bits(state.sp, SP_BITS)
    out += [int(line) for line in state.ctrl]
    for byte in state.mem:
        out += int_to_bits(byte, REG_BITS)
    return out
|
|
|
|
|
|
|
|
def unpack_state(bits: List[int]) -> CPUState:
    """Inverse of pack_state: rebuild a CPUState from a flat bit list.

    Raises:
        ValueError: if *bits* is not exactly STATE_BITS long.
    """
    if len(bits) != STATE_BITS:
        raise ValueError(f"Expected {STATE_BITS} bits, got {len(bits)}")

    cursor = 0

    def take(width: int) -> List[int]:
        # Consume the next *width* bits from the flat vector.
        nonlocal cursor
        chunk = bits[cursor:cursor + width]
        cursor += width
        return chunk

    pc = bits_to_int(take(PC_BITS))
    ir = bits_to_int(take(IR_BITS))
    regs = [bits_to_int(take(REG_BITS)) for _ in range(REG_COUNT)]
    flags = [int(b) for b in take(FLAG_BITS)]
    sp = bits_to_int(take(SP_BITS))
    ctrl = [int(b) for b in take(CTRL_BITS)]
    mem = [bits_to_int(take(REG_BITS)) for _ in range(MEM_BYTES)]

    return CPUState(pc=pc, ir=ir, regs=regs, flags=flags, sp=sp, ctrl=ctrl, mem=mem)
|
|
|
|
|
|
|
|
def decode_ir(ir: int) -> Tuple[int, int, int, int]:
    """Split a 16-bit instruction word into (opcode, rd, rs, imm8)."""
    return (
        (ir >> 12) & 0xF,  # opcode: top nibble
        (ir >> 10) & 0x3,  # rd: destination register index
        (ir >> 8) & 0x3,   # rs: source register index
        ir & 0xFF,         # imm8: immediate / condition byte
    )
|
|
|
|
|
|
|
|
def flags_from_result(result: int, carry: int, overflow: int) -> Tuple[int, int, int, int]:
    """Derive the (Z, N, C, V) flag tuple from an 8-bit ALU result."""
    zero = int(result == 0)
    negative = int(bool(result & 0x80))  # sign bit of the 8-bit result
    return zero, negative, int(bool(carry)), int(bool(overflow))
|
|
|
|
|
|
|
|
def alu_add(a: int, b: int) -> Tuple[int, int, int]:
    """8-bit add: returns (result, carry_out, signed_overflow)."""
    total = a + b
    result = total & 0xFF
    carry_out = int(total > 0xFF)
    # Signed overflow: both operands disagree in sign with the result.
    overflow = int(bool(((a ^ result) & (b ^ result)) & 0x80))
    return result, carry_out, overflow
|
|
|
|
|
|
|
|
def alu_sub(a: int, b: int) -> Tuple[int, int, int]:
    """8-bit subtract: returns (result, carry_out, signed_overflow).

    Carry follows the "carry = no borrow" convention (set when a >= b).
    """
    raw = (a - b) & 0x1FF
    result = raw & 0xFF
    carry_out = int(a >= b)
    overflow = int(bool(((a ^ b) & (a ^ result)) & 0x80))
    return result, carry_out, overflow
|
|
|
|
|
|
|
|
def ref_step(state: CPUState) -> CPUState:
    """Reference CPU cycle (pure Python arithmetic).

    Executes one fetch/decode/execute cycle and returns the new state;
    the input state is never mutated. A halted CPU is a fixed point.
    """
    if state.ctrl[0] == 1:
        return state.copy()

    s = state.copy()

    # Fetch: instructions are 16-bit, stored big-endian in memory.
    hi = s.mem[s.pc]
    lo = s.mem[(s.pc + 1) & 0xFFFF]
    s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
    next_pc = (s.pc + 2) & 0xFFFF

    opcode, rd, rs, imm8 = decode_ir(s.ir)
    a = s.regs[rd]
    b = s.regs[rs]

    # Memory/branch opcodes carry a 16-bit address in the following word.
    addr16 = None
    next_pc_ext = next_pc
    if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
        addr_hi = s.mem[next_pc]
        addr_lo = s.mem[(next_pc + 1) & 0xFFFF]
        addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
        next_pc_ext = (next_pc + 2) & 0xFFFF

    write_result = True
    result = a
    carry = 0
    overflow = 0

    if opcode == 0x0:  # ADD rd, rs
        result, carry, overflow = alu_add(a, b)
    elif opcode == 0x1:  # SUB rd, rs
        result, carry, overflow = alu_sub(a, b)
    elif opcode == 0x2:  # AND
        result = a & b
    elif opcode == 0x3:  # OR
        result = a | b
    elif opcode == 0x4:  # XOR
        result = a ^ b
    elif opcode == 0x5:  # SHL (logical, bit 0 filled with 0)
        result = (a << 1) & 0xFF
    elif opcode == 0x6:  # SHR (logical)
        result = (a >> 1) & 0xFF
    elif opcode == 0x7:  # MUL (low 8 bits only)
        result = (a * b) & 0xFF
    elif opcode == 0x8:  # DIV (quotient; divide-by-zero yields 0xFF)
        if b == 0:
            result = 0xFF
        else:
            result = a // b
    elif opcode == 0x9:  # CMP: subtract for flags only, result discarded
        result, carry, overflow = alu_sub(a, b)
        write_result = False
    elif opcode == 0xA:  # LOAD rd, [addr16]
        result = s.mem[addr16]
    elif opcode == 0xB:  # STORE [addr16], rs
        s.mem[addr16] = b & 0xFF
        write_result = False
    elif opcode == 0xC:  # JMP addr16
        s.pc = addr16 & 0xFFFF
        write_result = False
    elif opcode == 0xD:  # conditional jump; low 3 bits of imm8 pick the condition
        cond_type = imm8 & 0x7
        # Flag layout is [Z, N, C, V] (see FLAG_NAMES).
        if cond_type == 0:    # JZ
            take_branch = s.flags[0] == 1
        elif cond_type == 1:  # JNZ
            take_branch = s.flags[0] == 0
        elif cond_type == 2:  # JC
            take_branch = s.flags[2] == 1
        elif cond_type == 3:  # JNC
            take_branch = s.flags[2] == 0
        elif cond_type == 4:  # JN (negative set)
            take_branch = s.flags[1] == 1
        elif cond_type == 5:  # JP (negative clear)
            take_branch = s.flags[1] == 0
        elif cond_type == 6:  # JV (overflow set)
            take_branch = s.flags[3] == 1
        else:                 # JNV (overflow clear)
            take_branch = s.flags[3] == 0
        if take_branch:
            s.pc = addr16 & 0xFFFF
        else:
            s.pc = next_pc_ext
        write_result = False
    elif opcode == 0xE:  # CALL: push return address big-endian, then jump
        ret_addr = next_pc_ext & 0xFFFF
        s.sp = (s.sp - 1) & 0xFFFF
        s.mem[s.sp] = (ret_addr >> 8) & 0xFF
        s.sp = (s.sp - 1) & 0xFFFF
        s.mem[s.sp] = ret_addr & 0xFF
        s.pc = addr16 & 0xFFFF
        write_result = False
    elif opcode == 0xF:  # HALT
        s.ctrl[0] = 1
        write_result = False

    # ALU ops and LOAD update the flags. NOTE(review): the 0x7/0x8 members of
    # the tuple are redundant (already covered by opcode <= 0x9); this is
    # equivalent to `opcode <= 0xA`, matching ThresholdCPU.step.
    if opcode <= 0x9 or opcode in (0xA, 0x7, 0x8):
        s.flags = list(flags_from_result(result, carry, overflow))

    if write_result:
        s.regs[rd] = result & 0xFF

    # Jumps/calls already set PC; everything else advances past any operand word.
    if opcode not in (0xC, 0xD, 0xE):
        s.pc = next_pc_ext

    return s
|
|
|
|
|
|
|
|
def ref_run_until_halt(state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
    """Reference execution loop.

    Steps the reference CPU until the HALT control line goes high or
    *max_cycles* steps have executed; returns (final_state, cycles_run).
    """
    current = state.copy()
    cycle = 0
    while cycle < max_cycles:
        if current.ctrl[0] == 1:
            return current, cycle
        current = ref_step(current)
        cycle += 1
    return current, max_cycles
|
|
|
|
|
|
|
|
def heaviside(x: torch.Tensor) -> torch.Tensor:
    """Unit step activation: 1.0 where x >= 0, else 0.0 (float32 tensor)."""
    return (x >= 0).to(torch.float32)
|
|
|
|
|
|
|
|
class ThresholdALU:
    """8-bit ALU realized with threshold (step-activation) neurons.

    Weight and bias tensors for every gate are loaded from a safetensors
    file and looked up by dotted name; each boolean gate is evaluated as
    ``heaviside(w . x + b)``. Only bit packing/unpacking uses Python
    arithmetic.

    Fixes relative to the previous revision:
      * shift_left/shift_right referenced ``self.alu._get`` (an attribute
        that exists on ThresholdCPU, not here) and raised AttributeError.
      * multiply added the discarded high byte of each shifted partial
        product back into the mod-256 result, diverging from the reference
        ``(a * b) & 0xFF``.
    """

    def __init__(self, model_path: str, device: str = "cpu") -> None:
        """Load all gate weights from *model_path* onto *device* as floats."""
        self.device = device
        self.tensors = {k: v.float().to(device) for k, v in load_file(model_path).items()}

    def _get(self, name: str) -> torch.Tensor:
        """Return the weight/bias tensor registered under dotted *name*."""
        return self.tensors[name]

    def _eval_gate(self, weight_key: str, bias_key: str, inputs: List[float]) -> float:
        """Evaluate one threshold gate on *inputs*; returns 0.0 or 1.0."""
        w = self._get(weight_key)
        b = self._get(bias_key)
        inp = torch.tensor(inputs, device=self.device)
        return heaviside((inp * w).sum() + b).item()

    def _eval_xor(self, prefix: str, inputs: List[float]) -> float:
        """Evaluate a two-layer XOR circuit (OR + NAND feeding layer 2)."""
        inp = torch.tensor(inputs, device=self.device)
        w_or = self._get(f"{prefix}.layer1.or.weight")
        b_or = self._get(f"{prefix}.layer1.or.bias")
        w_nand = self._get(f"{prefix}.layer1.nand.weight")
        b_nand = self._get(f"{prefix}.layer1.nand.bias")
        w2 = self._get(f"{prefix}.layer2.weight")
        b2 = self._get(f"{prefix}.layer2.bias")

        h_or = heaviside((inp * w_or).sum() + b_or).item()
        h_nand = heaviside((inp * w_nand).sum() + b_nand).item()
        hidden = torch.tensor([h_or, h_nand], device=self.device)
        return heaviside((hidden * w2).sum() + b2).item()

    def _eval_full_adder(self, prefix: str, a: float, b: float, cin: float) -> Tuple[float, float]:
        """Full adder built from two half adders; returns (sum_bit, carry_out)."""
        ha1_sum = self._eval_xor(f"{prefix}.ha1.sum", [a, b])
        ha1_carry = self._eval_gate(f"{prefix}.ha1.carry.weight", f"{prefix}.ha1.carry.bias", [a, b])

        ha2_sum = self._eval_xor(f"{prefix}.ha2.sum", [ha1_sum, cin])
        ha2_carry = self._eval_gate(
            f"{prefix}.ha2.carry.weight", f"{prefix}.ha2.carry.bias", [ha1_sum, cin]
        )

        # Either half-adder producing a carry carries out of the full adder.
        cout = self._eval_gate(f"{prefix}.carry_or.weight", f"{prefix}.carry_or.bias", [ha1_carry, ha2_carry])
        return ha2_sum, cout

    def add(self, a: int, b: int) -> Tuple[int, int, int]:
        """Ripple-carry add; returns (result, carry_out, signed_overflow)."""
        a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))

        carry = 0.0
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            sum_bit, carry = self._eval_full_adder(
                f"arithmetic.ripplecarry8bit.fa{bit}", float(a_bits[bit]), float(b_bits[bit]), carry
            )
            sum_bits.append(int(sum_bit))

        result = bits_to_int(list(reversed(sum_bits)))
        carry_out = int(carry)
        # Overflow is computed arithmetically with the same formula as alu_add.
        overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def sub(self, a: int, b: int) -> Tuple[int, int, int]:
        """Subtract as a + NOT(b) + 1 via gate-level full adders.

        Returns (result, carry_out, signed_overflow); carry_out follows the
        "carry = no borrow" convention (1 when a >= b).
        """
        a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))

        carry = 1.0  # +1 of the two's-complement negation
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            # Invert the subtrahend bit with a NOT gate.
            notb = self._eval_gate(
                f"arithmetic.sub8bit.notb{bit}.weight",
                f"arithmetic.sub8bit.notb{bit}.bias",
                [float(b_bits[bit])],
            )

            # Full adder decomposed into XOR/AND/OR gates.
            xor1 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor1", [float(a_bits[bit]), notb])
            xor2 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor2", [xor1, carry])

            and1 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and1.weight",
                f"arithmetic.sub8bit.fa{bit}.and1.bias",
                [float(a_bits[bit]), notb],
            )
            and2 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and2.weight",
                f"arithmetic.sub8bit.fa{bit}.and2.bias",
                [xor1, carry],
            )
            carry = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.or_carry.weight",
                f"arithmetic.sub8bit.fa{bit}.or_carry.bias",
                [and1, and2],
            )

            sum_bits.append(int(xor2))

        result = bits_to_int(list(reversed(sum_bits)))
        carry_out = int(carry)
        overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def bitwise_and(self, a: int, b: int) -> int:
        """Per-bit AND via one threshold gate per output bit."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        w = self._get("alu.alu8bit.and.weight")
        bias = self._get("alu.alu8bit.and.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def bitwise_or(self, a: int, b: int) -> int:
        """Per-bit OR via one threshold gate per output bit."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        w = self._get("alu.alu8bit.or.weight")
        bias = self._get("alu.alu8bit.or.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def bitwise_not(self, a: int) -> int:
        """Per-bit NOT via one threshold gate per output bit."""
        a_bits = int_to_bits(a, REG_BITS)
        w = self._get("alu.alu8bit.not.weight")
        bias = self._get("alu.alu8bit.not.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit]).sum() + bias[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def bitwise_xor(self, a: int, b: int) -> int:
        """Per-bit XOR via the shared two-layer (OR/NAND -> layer2) circuit."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)

        w_or = self._get("alu.alu8bit.xor.layer1.or.weight")
        b_or = self._get("alu.alu8bit.xor.layer1.or.bias")
        w_nand = self._get("alu.alu8bit.xor.layer1.nand.weight")
        b_nand = self._get("alu.alu8bit.xor.layer1.nand.bias")
        w2 = self._get("alu.alu8bit.xor.layer2.weight")
        b2 = self._get("alu.alu8bit.xor.layer2.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            h_or = heaviside((inp * w_or[bit * 2:bit * 2 + 2]).sum() + b_or[bit])
            h_nand = heaviside((inp * w_nand[bit * 2:bit * 2 + 2]).sum() + b_nand[bit])
            hidden = torch.stack([h_or, h_nand])
            out = heaviside((hidden * w2[bit * 2:bit * 2 + 2]).sum() + b2[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def shift_left(self, a: int) -> int:
        """Logical shift left by one (LSB filled with 0)."""
        a_bits = int_to_bits(a, REG_BITS)
        out_bits = []
        for bit in range(REG_BITS):
            # BUG FIX: these lookups previously went through ``self.alu._get``,
            # which does not exist on ThresholdALU (copy-paste from ThresholdCPU).
            w = self._get(f"alu.alu8bit.shl.bit{bit}.weight")
            bias = self._get(f"alu.alu8bit.shl.bit{bit}.bias")
            if bit < 7:
                # MSB-first layout: output bit i takes input bit i+1.
                inp = torch.tensor([float(a_bits[bit + 1])], device=self.device)
            else:
                inp = torch.tensor([0.0], device=self.device)
            out = heaviside((inp * w).sum() + bias).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def shift_right(self, a: int) -> int:
        """Logical shift right by one (MSB filled with 0)."""
        a_bits = int_to_bits(a, REG_BITS)
        out_bits = []
        for bit in range(REG_BITS):
            # BUG FIX: same ``self.alu._get`` -> ``self._get`` as shift_left.
            w = self._get(f"alu.alu8bit.shr.bit{bit}.weight")
            bias = self._get(f"alu.alu8bit.shr.bit{bit}.bias")
            if bit > 0:
                # MSB-first layout: output bit i takes input bit i-1.
                inp = torch.tensor([float(a_bits[bit - 1])], device=self.device)
            else:
                inp = torch.tensor([0.0], device=self.device)
            out = heaviside((inp * w).sum() + bias).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def multiply(self, a: int, b: int) -> int:
        """8-bit multiply using partial product AND gates + shift-add.

        Returns (a * b) mod 256, matching the reference ``(a * b) & 0xFF``.
        """
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)

        # Partial products pp[i][j] = a_bits[i] AND b_bits[j] via gates.
        pp = [[0] * 8 for _ in range(8)]
        for i in range(8):
            for j in range(8):
                w = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.weight")
                bias = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.bias")
                inp = torch.tensor([float(a_bits[i]), float(b_bits[j])], device=self.device)
                pp[i][j] = int(heaviside((inp * w).sum() + bias).item())

        # Shift-add accumulation, modulo 256. Truncating each shifted row to
        # 8 bits before adding is exact mod-256 arithmetic, so no carry
        # correction is needed. (BUG FIX: the previous
        # ``result + (shifted >> 8)`` correction re-added discarded high bits,
        # e.g. multiply(2, 128) returned 1 instead of 0.)
        result = 0
        for j in range(8):
            if b_bits[j] == 0:
                continue

            # Reassemble row j of the partial products as an integer (== a
            # when b_bits[j] is 1), weighted by bit j's MSB-first position.
            row = 0
            for i in range(8):
                row |= (pp[i][j] << (7 - i))

            shifted = row << (7 - j)

            result, _, _ = self.add(result & 0xFF, shifted & 0xFF)

        return result & 0xFF

    def divide(self, a: int, b: int) -> Tuple[int, int]:
        """8-bit divide using restoring division with threshold gates.

        Returns (quotient, remainder); divide-by-zero yields (0xFF, a) to
        match the reference semantics.
        """
        if b == 0:
            return 0xFF, a

        a_bits = int_to_bits(a, REG_BITS)

        quotient = 0
        remainder = 0

        for stage in range(8):
            # Shift the next dividend bit into the remainder.
            remainder = ((remainder << 1) | a_bits[stage]) & 0xFF

            rem_bits = int_to_bits(remainder, REG_BITS)
            div_bits = int_to_bits(b, REG_BITS)

            # Gate-based comparison: remainder >= divisor for this stage.
            w = self._get(f"alu.alu8bit.div.stage{stage}.cmp.weight")
            bias = self._get(f"alu.alu8bit.div.stage{stage}.cmp.bias")
            inp = torch.tensor([float(rem_bits[i]) for i in range(8)] +
                               [float(div_bits[i]) for i in range(8)], device=self.device)
            cmp_result = int(heaviside((inp * w).sum() + bias).item())

            if cmp_result:
                # Subtract succeeds: keep the reduced remainder, emit a 1 bit.
                remainder, _, _ = self.sub(remainder, b)
                quotient = (quotient << 1) | 1
            else:
                quotient = quotient << 1

        return quotient & 0xFF, remainder & 0xFF
|
|
|
|
|
|
|
|
class ThresholdCPU:
    """Full CPU whose fetch/decode/execute datapath runs on threshold gates.

    Composes a ThresholdALU (arithmetic/logic) with gate-based memory
    read/write ports and conditional-jump circuits; step() mirrors the
    control flow of ref_step.
    """

    def __init__(self, model_path: str | Path = DEFAULT_MODEL_PATH, device: str = "cpu") -> None:
        # NOTE(review): the default for model_path is bound when this class is
        # defined, so reassigning DEFAULT_MODEL_PATH afterwards does not
        # change it — pass model_path explicitly to override.
        self.device = device
        self.alu = ThresholdALU(str(model_path), device=device)

    def _addr_decode(self, addr: int) -> torch.Tensor:
        """Decode *addr* into a per-address selector vector via one gate layer.

        Output length matches the memory (one entry per byte); presumably
        one-hot at *addr* — determined by the trained weights.
        """
        bits = torch.tensor(int_to_bits(addr, PC_BITS), device=self.device, dtype=torch.float32)
        w = self.alu._get("memory.addr_decode.weight")
        b = self.alu._get("memory.addr_decode.bias")
        return heaviside((w * bits).sum(dim=1) + b)

    def _memory_read(self, mem: List[int], addr: int) -> int:
        """Read one byte from *mem* at *addr* using AND/OR gate reduction."""
        sel = self._addr_decode(addr)
        # Whole memory as a (MEM_BYTES, REG_BITS) float bit matrix.
        mem_bits = torch.tensor(
            [int_to_bits(byte, REG_BITS) for byte in mem],
            device=self.device,
            dtype=torch.float32,
        )
        and_w = self.alu._get("memory.read.and.weight")
        and_b = self.alu._get("memory.read.and.bias")
        or_w = self.alu._get("memory.read.or.weight")
        or_b = self.alu._get("memory.read.or.bias")

        out_bits: List[int] = []
        for bit in range(REG_BITS):
            # AND each address's bit with its selector, then OR across addresses.
            inp = torch.stack([mem_bits[:, bit], sel], dim=1)
            and_out = heaviside((inp * and_w[bit]).sum(dim=1) + and_b[bit])
            out_bit = heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item()
            out_bits.append(int(out_bit))

        return bits_to_int(out_bits)

    def _memory_write(self, mem: List[int], addr: int, value: int) -> List[int]:
        """Return a new memory image with *value* stored at *addr*.

        Per-bit write-port mux, evaluated across all addresses at once:
        new = (old AND NOT selected) OR (data AND selected).
        The input *mem* list is not mutated.
        """
        sel = self._addr_decode(addr)
        data_bits = torch.tensor(int_to_bits(value, REG_BITS), device=self.device, dtype=torch.float32)
        mem_bits = torch.tensor(
            [int_to_bits(byte, REG_BITS) for byte in mem],
            device=self.device,
            dtype=torch.float32,
        )

        sel_w = self.alu._get("memory.write.sel.weight")
        sel_b = self.alu._get("memory.write.sel.bias")
        nsel_w = self.alu._get("memory.write.nsel.weight").squeeze(1)
        nsel_b = self.alu._get("memory.write.nsel.bias")
        and_old_w = self.alu._get("memory.write.and_old.weight")
        and_old_b = self.alu._get("memory.write.and_old.bias")
        and_new_w = self.alu._get("memory.write.and_new.weight")
        and_new_b = self.alu._get("memory.write.and_new.bias")
        or_w = self.alu._get("memory.write.or.weight")
        or_b = self.alu._get("memory.write.or.bias")

        # Write-enable is hard-wired high; gate the selector with it, then invert.
        we = torch.ones_like(sel)
        sel_inp = torch.stack([sel, we], dim=1)
        write_sel = heaviside((sel_inp * sel_w).sum(dim=1) + sel_b)
        nsel = heaviside((write_sel * nsel_w) + nsel_b)

        new_mem_bits = torch.zeros((MEM_BYTES, REG_BITS), device=self.device)
        for bit in range(REG_BITS):
            old_bit = mem_bits[:, bit]
            data_bit = data_bits[bit].expand(MEM_BYTES)
            inp_old = torch.stack([old_bit, nsel], dim=1)
            inp_new = torch.stack([data_bit, write_sel], dim=1)

            and_old = heaviside((inp_old * and_old_w[:, bit]).sum(dim=1) + and_old_b[:, bit])
            and_new = heaviside((inp_new * and_new_w[:, bit]).sum(dim=1) + and_new_b[:, bit])
            or_inp = torch.stack([and_old, and_new], dim=1)
            out_bit = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit])
            new_mem_bits[:, bit] = out_bit

        # Repack every byte into a fresh memory list.
        return [bits_to_int([int(b) for b in new_mem_bits[i].tolist()]) for i in range(MEM_BYTES)]

    def _conditional_jump_byte(self, prefix: str, pc_byte: int, target_byte: int, flag: int) -> int:
        """Per-byte conditional PC selection using the circuit under *prefix*.

        Gate names suggest a per-bit 2:1 mux (not_sel / and_a / and_b / or);
        whether *flag* selects the target directly or inverted is presumably
        encoded in each circuit's trained weights (control.jz vs control.jnz).
        """
        pc_bits = int_to_bits(pc_byte, REG_BITS)
        target_bits = int_to_bits(target_byte, REG_BITS)

        out_bits: List[int] = []
        for bit in range(REG_BITS):
            not_sel = self.alu._eval_gate(
                f"{prefix}.bit{bit}.not_sel.weight",
                f"{prefix}.bit{bit}.not_sel.bias",
                [float(flag)],
            )
            and_a = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_a.weight",
                f"{prefix}.bit{bit}.and_a.bias",
                [float(pc_bits[bit]), not_sel],
            )
            and_b = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_b.weight",
                f"{prefix}.bit{bit}.and_b.bias",
                [float(target_bits[bit]), float(flag)],
            )
            out_bit = self.alu._eval_gate(
                f"{prefix}.bit{bit}.or.weight",
                f"{prefix}.bit{bit}.or.bias",
                [and_a, and_b],
            )
            out_bits.append(int(out_bit))

        return bits_to_int(out_bits)

    def step(self, state: CPUState) -> CPUState:
        """Single CPU cycle using threshold neurons.

        Mirrors ref_step: fetch, decode, execute, flag update, PC update.
        The input state is never mutated; a halted CPU is a fixed point.
        """
        if state.ctrl[0] == 1:
            return state.copy()

        s = state.copy()

        # Fetch: 16-bit big-endian instruction through the gate memory port.
        hi = self._memory_read(s.mem, s.pc)
        lo = self._memory_read(s.mem, (s.pc + 1) & 0xFFFF)
        s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
        next_pc = (s.pc + 2) & 0xFFFF

        opcode, rd, rs, imm8 = decode_ir(s.ir)
        a = s.regs[rd]
        b = s.regs[rs]

        # Memory/branch opcodes carry a 16-bit operand address in the next word.
        addr16 = None
        next_pc_ext = next_pc
        if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
            addr_hi = self._memory_read(s.mem, next_pc)
            addr_lo = self._memory_read(s.mem, (next_pc + 1) & 0xFFFF)
            addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
            next_pc_ext = (next_pc + 2) & 0xFFFF

        write_result = True
        result = a
        carry = 0
        overflow = 0

        if opcode == 0x0:  # ADD
            result, carry, overflow = self.alu.add(a, b)
        elif opcode == 0x1:  # SUB
            result, carry, overflow = self.alu.sub(a, b)
        elif opcode == 0x2:  # AND
            result = self.alu.bitwise_and(a, b)
        elif opcode == 0x3:  # OR
            result = self.alu.bitwise_or(a, b)
        elif opcode == 0x4:  # XOR
            result = self.alu.bitwise_xor(a, b)
        elif opcode == 0x5:  # SHL
            result = self.alu.shift_left(a)
        elif opcode == 0x6:  # SHR
            result = self.alu.shift_right(a)
        elif opcode == 0x7:  # MUL (low byte)
            result = self.alu.multiply(a, b)
        elif opcode == 0x8:  # DIV (quotient only; remainder discarded)
            result, _ = self.alu.divide(a, b)
        elif opcode == 0x9:  # CMP: flags only
            result, carry, overflow = self.alu.sub(a, b)
            write_result = False
        elif opcode == 0xA:  # LOAD rd, [addr16]
            result = self._memory_read(s.mem, addr16)
        elif opcode == 0xB:  # STORE [addr16], rs
            s.mem = self._memory_write(s.mem, addr16, b & 0xFF)
            write_result = False
        elif opcode == 0xC:  # JMP
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xD:  # conditional jump; low 3 bits of imm8 pick the circuit
            cond_type = imm8 & 0x7
            # (circuit prefix, flag index into [Z, N, C, V]); negated
            # conditions presumably invert inside the circuit weights.
            cond_circuits = [
                ("control.jz", 0),
                ("control.jnz", 0),
                ("control.jc", 2),
                ("control.jnc", 2),
                ("control.jn", 1),
                ("control.jp", 1),
                ("control.jv", 3),
                ("control.jnv", 3),
            ]
            circuit_prefix, flag_idx = cond_circuits[cond_type]
            # Mux each PC byte between fall-through and branch target.
            hi_pc = self._conditional_jump_byte(
                circuit_prefix,
                (next_pc_ext >> 8) & 0xFF,
                (addr16 >> 8) & 0xFF,
                s.flags[flag_idx],
            )
            lo_pc = self._conditional_jump_byte(
                circuit_prefix,
                next_pc_ext & 0xFF,
                addr16 & 0xFF,
                s.flags[flag_idx],
            )
            s.pc = ((hi_pc & 0xFF) << 8) | (lo_pc & 0xFF)
            write_result = False
        elif opcode == 0xE:  # CALL: push big-endian return address, then jump
            ret_addr = next_pc_ext & 0xFFFF
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, (ret_addr >> 8) & 0xFF)
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, ret_addr & 0xFF)
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xF:  # HALT
            s.ctrl[0] = 1
            write_result = False

        # ALU ops (0x0-0x9) and LOAD (0xA) update the flags.
        if opcode <= 0x9 or opcode == 0xA:
            s.flags = list(flags_from_result(result, carry, overflow))

        if write_result:
            s.regs[rd] = result & 0xFF

        # Jumps/calls already set PC; everything else advances past operands.
        if opcode not in (0xC, 0xD, 0xE):
            s.pc = next_pc_ext

        return s

    def run_until_halt(self, state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
        """Execute until HALT or max_cycles reached.

        Returns (final_state, cycles_run); cycles_run equals max_cycles when
        the cap is hit without halting.
        """
        s = state.copy()
        for i in range(max_cycles):
            if s.ctrl[0] == 1:
                return s, i
            s = self.step(s)
        return s, max_cycles

    def forward(self, state_bits: torch.Tensor, max_cycles: int = 256) -> torch.Tensor:
        """Tensor-in, tensor-out interface for neural integration.

        *state_bits* is flattened and truncated to ints (expected to hold
        STATE_BITS 0/1 values; unpack_state raises on a length mismatch),
        executed until halt/cap, and the final state is repacked as a CPU
        float32 tensor.
        """
        bits_list = [int(b) for b in state_bits.detach().cpu().flatten().tolist()]
        state = unpack_state(bits_list)
        final, _ = self.run_until_halt(state, max_cycles=max_cycles)
        return torch.tensor(pack_state(final), dtype=torch.float32)
|
|
|
|
|
|
|
|
def encode_instr(opcode: int, rd: int, rs: int, imm8: int) -> int:
    """Pack opcode/rd/rs/imm8 fields into one 16-bit instruction word."""
    word = (opcode & 0xF) << 12
    word |= (rd & 0x3) << 10
    word |= (rs & 0x3) << 8
    word |= imm8 & 0xFF
    return word
|
|
|
|
|
|
|
|
def write_instr(mem: List[int], addr: int, instr: int) -> None:
    """Store a 16-bit instruction big-endian at mem[addr], mem[addr+1] (wrapping)."""
    hi, lo = (instr >> 8) & 0xFF, instr & 0xFF
    mem[addr & 0xFFFF] = hi
    mem[(addr + 1) & 0xFFFF] = lo
|
|
|
|
|
|
|
|
def write_addr(mem: List[int], addr: int, value: int) -> None:
    """Store a 16-bit value big-endian at mem[addr], mem[addr+1] (wrapping)."""
    for offset, byte in enumerate(((value >> 8) & 0xFF, value & 0xFF)):
        mem[(addr + offset) & 0xFFFF] = byte
|
|
|
|
|
|
|
|
def run_smoke_test(model_path: Path | None = None) -> None:
    """Smoke test: LOAD 5, LOAD 7, ADD, STORE, HALT. Expect result = 12.

    Args:
        model_path: Path to the threshold-weight safetensors file. When None,
            the module-level DEFAULT_MODEL_PATH is read at call time, so a
            ``--model`` override that rebinds the global takes effect.

    Raises:
        AssertionError: if either implementation produces unexpected results.
    """
    # BUG FIX: the previous version relied on ThresholdCPU's keyword default,
    # which was bound at class-definition time, so reassigning
    # DEFAULT_MODEL_PATH (the --model flag) had no effect. Resolve lazily.
    if model_path is None:
        model_path = DEFAULT_MODEL_PATH

    mem = [0] * 65536

    # Program: two absolute LOADs, ADD, absolute STORE, HALT.
    write_instr(mem, 0x0000, encode_instr(0xA, 0, 0, 0x00))  # LOAD R0, [0x0100]
    write_addr(mem, 0x0002, 0x0100)
    write_instr(mem, 0x0004, encode_instr(0xA, 1, 0, 0x00))  # LOAD R1, [0x0101]
    write_addr(mem, 0x0006, 0x0101)
    write_instr(mem, 0x0008, encode_instr(0x0, 0, 1, 0x00))  # ADD R0, R1
    write_instr(mem, 0x000A, encode_instr(0xB, 0, 0, 0x00))  # STORE [0x0102], R0
    write_addr(mem, 0x000C, 0x0102)
    write_instr(mem, 0x000E, encode_instr(0xF, 0, 0, 0x00))  # HALT

    # Input data.
    mem[0x0100] = 5
    mem[0x0101] = 7

    state = CPUState(
        pc=0,
        ir=0,
        regs=[0, 0, 0, 0],
        flags=[0, 0, 0, 0],
        sp=0xFFFE,
        ctrl=[0, 0, 0, 0],
        mem=mem,
    )

    print("Running reference implementation...")
    final, cycles = ref_run_until_halt(state, max_cycles=20)

    assert final.ctrl[0] == 1, "HALT flag not set"
    assert final.regs[0] == 12, f"R0 expected 12, got {final.regs[0]}"
    assert final.mem[0x0102] == 12, f"MEM[0x0102] expected 12, got {final.mem[0x0102]}"
    assert cycles <= 10, f"Unexpected cycle count: {cycles}"
    print(f"  Reference: R0={final.regs[0]}, MEM[0x0102]={final.mem[0x0102]}, cycles={cycles}")

    print("Running threshold-weight implementation...")
    threshold_cpu = ThresholdCPU(model_path)
    t_final, t_cycles = threshold_cpu.run_until_halt(state, max_cycles=20)

    assert t_final.ctrl[0] == 1, "Threshold HALT flag not set"
    assert t_final.regs[0] == final.regs[0], f"Threshold R0 mismatch: {t_final.regs[0]} != {final.regs[0]}"
    assert t_final.mem[0x0102] == final.mem[0x0102], (
        f"Threshold MEM[0x0102] mismatch: {t_final.mem[0x0102]} != {final.mem[0x0102]}"
    )
    assert t_cycles == cycles, f"Threshold cycle count mismatch: {t_cycles} != {cycles}"
    print(f"  Threshold: R0={t_final.regs[0]}, MEM[0x0102]={t_final.mem[0x0102]}, cycles={t_cycles}")

    print("Validating forward() tensor I/O...")
    bits = torch.tensor(pack_state(state), dtype=torch.float32)
    out_bits = threshold_cpu.forward(bits, max_cycles=20)
    out_state = unpack_state([int(b) for b in out_bits.tolist()])
    assert out_state.regs[0] == final.regs[0], f"Forward R0 mismatch: {out_state.regs[0]} != {final.regs[0]}"
    assert out_state.mem[0x0102] == final.mem[0x0102], (
        f"Forward MEM[0x0102] mismatch: {out_state.mem[0x0102]} != {final.mem[0x0102]}"
    )
    print(f"  Forward: R0={out_state.regs[0]}, MEM[0x0102]={out_state.mem[0x0102]}")

    print("\nSmoke test: PASSED")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="8-bit Threshold CPU")
    arg_parser.add_argument(
        "--model",
        type=Path,
        default=DEFAULT_MODEL_PATH,
        help="Path to safetensors model",
    )
    cli_args = arg_parser.parse_args()

    # Rebind the module-level default when the user supplied a different path.
    if cli_args.model != DEFAULT_MODEL_PATH:
        DEFAULT_MODEL_PATH = cli_args.model

    run_smoke_test()
|
|
|