8bit-threshold-computer / threshold_cpu.py
CharlesCNorton
Add SHL, SHR, MUL, DIV, and comparator circuits
6087b2e
raw
history blame
30.2 kB
"""
8-bit Threshold Computer - CPU Runtime
State layout, reference cycle, and threshold-weight execution.
All multi-bit fields are MSB-first.
Usage:
python threshold_cpu.py # Run smoke test
python threshold_cpu.py --help # Show options
"""
from __future__ import annotations
import argparse
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple
import torch
from safetensors.torch import load_file
# Flag / control-line labels. Index order matches CPUState.flags (as produced by
# flags_from_result: Z, N, C, V) and CPUState.ctrl (ctrl[0] is the HALT line).
FLAG_NAMES = ["Z", "N", "C", "V"]
CTRL_NAMES = ["HALT", "MEM_WE", "MEM_RE", "RESERVED"]

# Widths, in bits, of each field of the packed state vector (see pack_state).
PC_BITS = 16
IR_BITS = 16
REG_BITS = 8
REG_COUNT = 4
FLAG_BITS = len(FLAG_NAMES)
SP_BITS = 16
CTRL_BITS = len(CTRL_NAMES)
MEM_BYTES = 1 << 16
MEM_BITS = 8 * MEM_BYTES
# Total length of the flattened state: scalar fields followed by all of memory.
STATE_BITS = sum(
    (PC_BITS, IR_BITS, REG_BITS * REG_COUNT, FLAG_BITS, SP_BITS, CTRL_BITS, MEM_BITS)
)
# Checkpoint expected next to this source file.
DEFAULT_MODEL_PATH = Path(__file__).resolve().parent / "neural_computer.safetensors"
def int_to_bits(value: int, width: int) -> List[int]:
    """Return the low ``width`` bits of ``value`` as a list of 0/1 ints, MSB first."""
    return [(value >> shift) & 1 for shift in range(width - 1, -1, -1)]
def bits_to_int(bits: List[int]) -> int:
    """Fold an MSB-first sequence of bits (anything ``int()`` accepts) into an int."""
    acc = 0
    for b in map(int, bits):
        acc = acc << 1 | b
    return acc
def bits_msb_to_lsb(bits: List[int]) -> List[int]:
    """Return a new list with the bit order reversed (MSB-first -> LSB-first)."""
    return [*reversed(bits)]
@dataclass
class CPUState:
    """Complete architectural state of the CPU.

    Scalar fields are plain ints; ``regs``, ``flags``, ``ctrl`` and ``mem`` are
    lists of ints.
    """

    pc: int  # program counter (PC_BITS wide)
    ir: int  # instruction register: last fetched instruction word
    regs: List[int]  # general-purpose registers, REG_COUNT entries
    flags: List[int]  # 0/1 flags in FLAG_NAMES order
    sp: int  # stack pointer (SP_BITS wide)
    ctrl: List[int]  # 0/1 control lines in CTRL_NAMES order; ctrl[0] is HALT
    mem: List[int]  # memory image, one int per byte address

    def copy(self) -> CPUState:
        """Return an independent copy; every field and list element is coerced to ``int``."""
        return CPUState(
            pc=int(self.pc),
            ir=int(self.ir),
            regs=list(map(int, self.regs)),
            flags=list(map(int, self.flags)),
            sp=int(self.sp),
            ctrl=list(map(int, self.ctrl)),
            mem=list(map(int, self.mem)),
        )
def pack_state(state: CPUState) -> List[int]:
    """Flatten ``state`` into a single MSB-first bit list.

    Field order: pc, ir, each register, flags, sp, ctrl, then every memory byte.
    ``flags`` and ``ctrl`` are copied bit-for-bit; no length check is performed.
    """
    chunks: List[List[int]] = [
        int_to_bits(state.pc, PC_BITS),
        int_to_bits(state.ir, IR_BITS),
    ]
    chunks.extend(int_to_bits(reg, REG_BITS) for reg in state.regs)
    chunks.append([int(flag) for flag in state.flags])
    chunks.append(int_to_bits(state.sp, SP_BITS))
    chunks.append([int(line) for line in state.ctrl])
    chunks.extend(int_to_bits(byte, REG_BITS) for byte in state.mem)
    return [bit for chunk in chunks for bit in chunk]
def unpack_state(bits: List[int]) -> CPUState:
    """Rebuild a CPUState from a bit list laid out as ``pack_state`` produces it.

    Raises:
        ValueError: if ``bits`` is not exactly STATE_BITS long.
    """
    if len(bits) != STATE_BITS:
        raise ValueError(f"Expected {STATE_BITS} bits, got {len(bits)}")
    cursor = 0

    def take(width: int) -> List[int]:
        # Consume the next ``width`` bits and advance the read cursor.
        nonlocal cursor
        field = bits[cursor:cursor + width]
        cursor += width
        return field

    pc = bits_to_int(take(PC_BITS))
    ir = bits_to_int(take(IR_BITS))
    regs = [bits_to_int(take(REG_BITS)) for _ in range(REG_COUNT)]
    flags = [int(b) for b in take(FLAG_BITS)]
    sp = bits_to_int(take(SP_BITS))
    ctrl = [int(b) for b in take(CTRL_BITS)]
    mem = [bits_to_int(take(REG_BITS)) for _ in range(MEM_BYTES)]
    return CPUState(pc=pc, ir=ir, regs=regs, flags=flags, sp=sp, ctrl=ctrl, mem=mem)
def decode_ir(ir: int) -> Tuple[int, int, int, int]:
    """Split an instruction word into ``(opcode, rd, rs, imm8)``.

    Layout: bits 15-12 opcode, 11-10 rd, 9-8 rs, 7-0 imm8.
    """
    imm8 = ir & 0xFF
    rs = (ir >> 8) & 0b11
    rd = (ir >> 10) & 0b11
    opcode = (ir >> 12) & 0b1111
    return opcode, rd, rs, imm8
def flags_from_result(result: int, carry: int, overflow: int) -> Tuple[int, int, int, int]:
    """Return ``(Z, N, C, V)`` as 0/1 ints.

    Z is set when ``result == 0``, N mirrors bit 7 of ``result``, and C / V are
    the truthiness of ``carry`` / ``overflow``.
    """
    return (
        int(result == 0),
        int(bool(result & 0x80)),
        int(bool(carry)),
        int(bool(overflow)),
    )
def alu_add(a: int, b: int) -> Tuple[int, int, int]:
    """Reference 8-bit add: return ``(sum & 0xFF, carry_out, signed_overflow)``."""
    total = a + b
    low = total & 0xFF
    carry_out = int(total > 0xFF)
    # Signed overflow: result sign differs from the sign of both operands.
    sign_flip = (a ^ low) & (b ^ low) & 0x80
    return low, carry_out, int(bool(sign_flip))
def alu_sub(a: int, b: int) -> Tuple[int, int, int]:
    """Reference 8-bit subtract: return ``((a - b) & 0xFF, no_borrow, signed_overflow)``.

    The carry output is 1 when ``a >= b``, i.e. when no borrow occurs.
    """
    low = (a - b) & 0xFF
    no_borrow = int(a >= b)
    # Signed overflow: operands had different signs and the result's sign
    # differs from the minuend's.
    sign_flip = (a ^ b) & (a ^ low) & 0x80
    return low, no_borrow, int(bool(sign_flip))
def ref_step(state: CPUState) -> CPUState:
    """Reference CPU cycle (pure Python arithmetic).

    Fetches the two-byte, high-byte-first instruction word at ``pc``, decodes it,
    and applies it to a copy of ``state``; the argument itself is not mutated.
    Opcodes 0xA-0xE also consume a two-byte, high-byte-first address operand that
    follows the instruction word. If ``ctrl[0]`` (HALT) is already 1, an
    unchanged copy is returned.
    """
    if state.ctrl[0] == 1:
        return state.copy()
    s = state.copy()

    # Fetch and decode.
    s.ir = ((s.mem[s.pc] & 0xFF) << 8) | (s.mem[(s.pc + 1) & 0xFFFF] & 0xFF)
    fallthrough = (s.pc + 2) & 0xFFFF
    opcode, rd, rs, imm8 = decode_ir(s.ir)
    a, b = s.regs[rd], s.regs[rs]

    addr16 = None
    if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
        addr16 = ((s.mem[fallthrough] & 0xFF) << 8) | (s.mem[(fallthrough + 1) & 0xFFFF] & 0xFF)
        fallthrough = (fallthrough + 2) & 0xFFFF

    # Execute.
    result, carry, overflow = a, 0, 0
    if opcode == 0x0:
        result, carry, overflow = alu_add(a, b)
    elif opcode in (0x1, 0x9):
        # 0x9 runs the same subtraction; below, it updates flags but not rd.
        result, carry, overflow = alu_sub(a, b)
    elif opcode == 0x2:
        result = a & b
    elif opcode == 0x3:
        result = a | b
    elif opcode == 0x4:
        result = a ^ b
    elif opcode == 0x5:
        result = (a << 1) & 0xFF
    elif opcode == 0x6:
        result = (a >> 1) & 0xFF
    elif opcode == 0x7:
        result = (a * b) & 0xFF
    elif opcode == 0x8:
        result = a // b if b != 0 else 0xFF  # division by zero yields 0xFF
    elif opcode == 0xA:
        result = s.mem[addr16]
    elif opcode == 0xB:
        s.mem[addr16] = b & 0xFF
    elif opcode == 0xC:
        s.pc = addr16 & 0xFFFF
    elif opcode == 0xD:
        # imm8[2:0] selects (flag index, flag value that takes the branch).
        flag_idx, wanted = (
            (0, 1), (0, 0), (2, 1), (2, 0), (1, 1), (1, 0), (3, 1), (3, 0),
        )[imm8 & 0x7]
        s.pc = (addr16 & 0xFFFF) if s.flags[flag_idx] == wanted else fallthrough
    elif opcode == 0xE:
        # Push the return address (high byte first); sp is decremented before each write.
        for byte in ((fallthrough >> 8) & 0xFF, fallthrough & 0xFF):
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem[s.sp] = byte
        s.pc = addr16 & 0xFFFF
    elif opcode == 0xF:
        s.ctrl[0] = 1

    # Write back.
    if opcode <= 0xA:
        s.flags = list(flags_from_result(result, carry, overflow))
    if opcode <= 0x8 or opcode == 0xA:
        s.regs[rd] = result & 0xFF
    if opcode not in (0xC, 0xD, 0xE):
        s.pc = fallthrough
    return s
def ref_run_until_halt(state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
    """Reference execution loop.

    Steps a copy of ``state`` with ``ref_step`` until HALT (``ctrl[0] == 1``) is
    observed or ``max_cycles`` steps have run. Returns the final state and the
    number of steps executed before HALT was seen, or ``max_cycles`` if the loop
    ran out.
    """
    current = state.copy()
    for cycle in range(max_cycles):
        if current.ctrl[0] == 1:
            break
        current = ref_step(current)
    else:
        return current, max_cycles
    return current, cycle
def heaviside(x: torch.Tensor) -> torch.Tensor:
    """Elementwise step function: 1.0 where ``x >= 0``, else 0.0 (float32)."""
    return x.ge(0).to(torch.float32)
class ThresholdALU:
    """8-bit ALU whose operations are evaluated with threshold (Heaviside) gates.

    Each gate computes ``heaviside(sum(inputs * weight) + bias)`` using weight and
    bias tensors looked up by name in a safetensors checkpoint. Public methods
    take and return plain Python ints.
    """

    def __init__(self, model_path: str, device: str = "cpu") -> None:
        self.device = device
        # Cast every checkpoint tensor to float32 and move it to ``device`` once.
        self.tensors = {k: v.float().to(device) for k, v in load_file(model_path).items()}

    def _get(self, name: str) -> torch.Tensor:
        """Return the checkpoint tensor ``name`` (KeyError if it is missing)."""
        return self.tensors[name]

    def _eval_gate(self, weight_key: str, bias_key: str, inputs: List[float]) -> float:
        """Evaluate one threshold neuron on ``inputs``; returns 0.0 or 1.0."""
        w = self._get(weight_key)
        b = self._get(bias_key)
        inp = torch.tensor(inputs, device=self.device)
        return heaviside((inp * w).sum() + b).item()

    def _eval_xor(self, prefix: str, inputs: List[float]) -> float:
        """Evaluate the two-layer XOR network stored under ``prefix``.

        Layer 1 has gates named ``or`` and ``nand``; layer 2 combines their two
        outputs. Returns 0.0 or 1.0.
        """
        inp = torch.tensor(inputs, device=self.device)
        w_or = self._get(f"{prefix}.layer1.or.weight")
        b_or = self._get(f"{prefix}.layer1.or.bias")
        w_nand = self._get(f"{prefix}.layer1.nand.weight")
        b_nand = self._get(f"{prefix}.layer1.nand.bias")
        w2 = self._get(f"{prefix}.layer2.weight")
        b2 = self._get(f"{prefix}.layer2.bias")
        h_or = heaviside((inp * w_or).sum() + b_or).item()
        h_nand = heaviside((inp * w_nand).sum() + b_nand).item()
        hidden = torch.tensor([h_or, h_nand], device=self.device)
        return heaviside((hidden * w2).sum() + b2).item()

    def _eval_full_adder(self, prefix: str, a: float, b: float, cin: float) -> Tuple[float, float]:
        """Evaluate a full adder built from two half adders and a carry OR gate.

        Returns ``(sum_bit, carry_out)`` as floats.
        """
        ha1_sum = self._eval_xor(f"{prefix}.ha1.sum", [a, b])
        ha1_carry = self._eval_gate(f"{prefix}.ha1.carry.weight", f"{prefix}.ha1.carry.bias", [a, b])
        ha2_sum = self._eval_xor(f"{prefix}.ha2.sum", [ha1_sum, cin])
        ha2_carry = self._eval_gate(
            f"{prefix}.ha2.carry.weight", f"{prefix}.ha2.carry.bias", [ha1_sum, cin]
        )
        cout = self._eval_gate(f"{prefix}.carry_or.weight", f"{prefix}.carry_or.bias", [ha1_carry, ha2_carry])
        return ha2_sum, cout

    def add(self, a: int, b: int) -> Tuple[int, int, int]:
        """Add two bytes through the 8-stage ripple-carry adder (LSB first, carry-in 0).

        Returns ``(result, carry_out, overflow)``. ``overflow`` is derived in Python
        from the operand/result sign bits rather than by a gate.
        """
        a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))
        carry = 0.0
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            sum_bit, carry = self._eval_full_adder(
                f"arithmetic.ripplecarry8bit.fa{bit}", float(a_bits[bit]), float(b_bits[bit]), carry
            )
            sum_bits.append(int(sum_bit))
        result = bits_to_int(list(reversed(sum_bits)))
        carry_out = int(carry)
        overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def sub(self, a: int, b: int) -> Tuple[int, int, int]:
        """Subtract ``b`` from ``a`` bit by bit, LSB first, with carry-in 1.

        Each stage feeds ``b``'s bit through a ``notb`` gate (presumably an
        inverter — confirm against the checkpoint) into a gate-level full adder.
        Returns ``(result, carry_out, overflow)``; ``overflow`` is derived in
        Python from sign bits.
        """
        a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))
        carry = 1.0
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            notb = self._eval_gate(
                f"arithmetic.sub8bit.notb{bit}.weight",
                f"arithmetic.sub8bit.notb{bit}.bias",
                [float(b_bits[bit])],
            )
            xor1 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor1", [float(a_bits[bit]), notb])
            xor2 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor2", [xor1, carry])
            and1 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and1.weight",
                f"arithmetic.sub8bit.fa{bit}.and1.bias",
                [float(a_bits[bit]), notb],
            )
            and2 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and2.weight",
                f"arithmetic.sub8bit.fa{bit}.and2.bias",
                [xor1, carry],
            )
            carry = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.or_carry.weight",
                f"arithmetic.sub8bit.fa{bit}.or_carry.bias",
                [and1, and2],
            )
            sum_bits.append(int(xor2))
        result = bits_to_int(list(reversed(sum_bits)))
        carry_out = int(carry)
        overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def bitwise_and(self, a: int, b: int) -> int:
        """Bitwise AND; gate ``bit`` uses weights ``w[2*bit:2*bit+2]`` and ``bias[bit]``."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        w = self._get("alu.alu8bit.and.weight")
        bias = self._get("alu.alu8bit.and.bias")
        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def bitwise_or(self, a: int, b: int) -> int:
        """Bitwise OR; gate ``bit`` uses weights ``w[2*bit:2*bit+2]`` and ``bias[bit]``."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        w = self._get("alu.alu8bit.or.weight")
        bias = self._get("alu.alu8bit.or.bias")
        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def bitwise_not(self, a: int) -> int:
        """Bitwise NOT; gate ``bit`` uses weight ``w[bit]`` and ``bias[bit]``."""
        a_bits = int_to_bits(a, REG_BITS)
        w = self._get("alu.alu8bit.not.weight")
        bias = self._get("alu.alu8bit.not.bias")
        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit]).sum() + bias[bit]).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def bitwise_xor(self, a: int, b: int) -> int:
        """Bitwise XOR via a per-bit two-layer (OR, NAND -> layer2) network."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        w_or = self._get("alu.alu8bit.xor.layer1.or.weight")
        b_or = self._get("alu.alu8bit.xor.layer1.or.bias")
        w_nand = self._get("alu.alu8bit.xor.layer1.nand.weight")
        b_nand = self._get("alu.alu8bit.xor.layer1.nand.bias")
        w2 = self._get("alu.alu8bit.xor.layer2.weight")
        b2 = self._get("alu.alu8bit.xor.layer2.bias")
        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            h_or = heaviside((inp * w_or[bit * 2:bit * 2 + 2]).sum() + b_or[bit])
            h_nand = heaviside((inp * w_nand[bit * 2:bit * 2 + 2]).sum() + b_nand[bit])
            hidden = torch.stack([h_or, h_nand])
            out = heaviside((hidden * w2[bit * 2:bit * 2 + 2]).sum() + b2[bit]).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def shift_left(self, a: int) -> int:
        """Logical shift left by one through per-bit gates; a 0 enters the LSB.

        Output bit ``bit`` (MSB-first index) is driven by input bit ``bit + 1``.
        """
        a_bits = int_to_bits(a, REG_BITS)
        out_bits = []
        for bit in range(REG_BITS):
            # Fixed: weights live on this ALU; the old ``self.alu`` lookup raised
            # AttributeError because ThresholdALU has no ``alu`` attribute.
            w = self._get(f"alu.alu8bit.shl.bit{bit}.weight")
            bias = self._get(f"alu.alu8bit.shl.bit{bit}.bias")
            if bit < REG_BITS - 1:
                inp = torch.tensor([float(a_bits[bit + 1])], device=self.device)
            else:
                inp = torch.tensor([0.0], device=self.device)
            out = heaviside((inp * w).sum() + bias).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def shift_right(self, a: int) -> int:
        """Logical shift right by one through per-bit gates; a 0 enters the MSB.

        Output bit ``bit`` (MSB-first index) is driven by input bit ``bit - 1``.
        """
        a_bits = int_to_bits(a, REG_BITS)
        out_bits = []
        for bit in range(REG_BITS):
            w = self._get(f"alu.alu8bit.shr.bit{bit}.weight")
            bias = self._get(f"alu.alu8bit.shr.bit{bit}.bias")
            if bit > 0:
                inp = torch.tensor([float(a_bits[bit - 1])], device=self.device)
            else:
                inp = torch.tensor([0.0], device=self.device)
            out = heaviside((inp * w).sum() + bias).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def multiply(self, a: int, b: int) -> int:
        """8-bit multiply: return ``(a * b) & 0xFF`` using AND-gate partial products + shift-add.

        Assumes the ``mul.pp`` gates compute AND of their two inputs (per their
        naming — confirm against the checkpoint).
        """
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        # pp[i][j] = gate(a_bits[i], b_bits[j]), MSB-first indices.
        pp = [[0] * REG_BITS for _ in range(REG_BITS)]
        for i in range(REG_BITS):
            for j in range(REG_BITS):
                w = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.weight")
                bias = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.bias")
                inp = torch.tensor([float(a_bits[i]), float(b_bits[j])], device=self.device)
                pp[i][j] = int(heaviside((inp * w).sum() + bias).item())
        result = 0
        for j in range(REG_BITS):
            if b_bits[j] == 0:
                continue
            # Reassemble row j (A masked by B[j]) as an integer.
            row = 0
            for i in range(REG_BITS):
                row |= pp[i][j] << (REG_BITS - 1 - i)
            # b_bits[j] has weight 2**(7 - j). Only the low byte of each shifted
            # row can influence an 8-bit product, so higher bits are discarded.
            # (Adding them back into the low byte, as before, corrupted results,
            # e.g. 0x80 * 3 gave 0x81.)
            shifted = (row << (REG_BITS - 1 - j)) & 0xFF
            result, _, _ = self.add(result, shifted)
        return result & 0xFF

    def divide(self, a: int, b: int) -> Tuple[int, int]:
        """8-bit restoring division; returns ``(quotient, remainder)``.

        Each stage uses a single comparator gate over the 16 bits
        ``remainder || divisor`` to decide whether to subtract. Division by zero
        returns ``(0xFF, a)`` without evaluating any gates.
        """
        if b == 0:
            return 0xFF, a
        a_bits = int_to_bits(a, REG_BITS)
        div_bits = int_to_bits(b, REG_BITS)  # divisor bits are the same every stage
        quotient = 0
        remainder = 0
        for stage in range(REG_BITS):
            # Shift the remainder left and bring in the next dividend bit (MSB first).
            remainder = ((remainder << 1) | a_bits[stage]) & 0xFF
            rem_bits = int_to_bits(remainder, REG_BITS)
            w = self._get(f"alu.alu8bit.div.stage{stage}.cmp.weight")
            bias = self._get(f"alu.alu8bit.div.stage{stage}.cmp.bias")
            inp = torch.tensor([float(x) for x in rem_bits + div_bits], device=self.device)
            cmp_result = int(heaviside((inp * w).sum() + bias).item())
            # Comparator fired (remainder >= divisor): subtract and emit a 1 bit.
            if cmp_result:
                remainder, _, _ = self.sub(remainder, b)
                quotient = (quotient << 1) | 1
            else:
                quotient = quotient << 1
        return quotient & 0xFF, remainder & 0xFF
class ThresholdCPU:
    """CPU whose fetch, ALU, memory and branch-select logic run on threshold gates.

    ``step`` mirrors the instruction handling of ``ref_step`` but evaluates
    arithmetic through :class:`ThresholdALU` and memory access / conditional
    branch selection through gate tensors from the same checkpoint.
    """

    # (circuit prefix, index into CPUState.flags) for each 3-bit branch condition
    # code taken from imm8. How each circuit uses its flag input (e.g. whether the
    # "n" variants invert it) is encoded in the checkpoint weights.
    _COND_CIRCUITS: Tuple[Tuple[str, int], ...] = (
        ("control.jz", 0),
        ("control.jnz", 0),
        ("control.jc", 2),
        ("control.jnc", 2),
        ("control.jn", 1),
        ("control.jp", 1),
        ("control.jv", 3),
        ("control.jnv", 3),
    )

    def __init__(self, model_path: str | Path = DEFAULT_MODEL_PATH, device: str = "cpu") -> None:
        self.device = device
        self.alu = ThresholdALU(str(model_path), device=device)

    def _mem_to_bits(self, mem: List[int]) -> torch.Tensor:
        """Return ``mem`` as a ``(len(mem), REG_BITS)`` float32 tensor of MSB-first bits.

        Vectorised equivalent of stacking ``int_to_bits(byte, REG_BITS)`` for every
        byte, so a memory access no longer makes one Python call per address.
        """
        shifts = torch.arange(REG_BITS - 1, -1, -1, device=self.device)
        byte_vals = torch.tensor(mem, dtype=torch.int64, device=self.device)
        return ((byte_vals.unsqueeze(1) >> shifts) & 1).to(torch.float32)

    def _bits_to_mem(self, bits: torch.Tensor) -> List[int]:
        """Pack each row of 0/1 MSB-first bits back into a Python int byte value."""
        shifts = torch.arange(REG_BITS - 1, -1, -1, device=bits.device)
        weights = torch.pow(2, shifts)
        return (bits.to(torch.int64) * weights).sum(dim=1).tolist()

    def _addr_decode(self, addr: int) -> torch.Tensor:
        """Evaluate the address-decoder gates for ``addr``; one output per decoder row.

        The select vector is presumably one-hot over memory addresses; that
        depends on the checkpoint weights.
        """
        bits = torch.tensor(int_to_bits(addr, PC_BITS), device=self.device, dtype=torch.float32)
        w = self.alu._get("memory.addr_decode.weight")
        b = self.alu._get("memory.addr_decode.bias")
        return heaviside((w * bits).sum(dim=1) + b)

    def _memory_read(self, mem: List[int], addr: int) -> int:
        """Read the byte at ``addr`` through per-bit AND (data, select) / OR-reduce gates."""
        sel = self._addr_decode(addr)
        mem_bits = self._mem_to_bits(mem)
        and_w = self.alu._get("memory.read.and.weight")
        and_b = self.alu._get("memory.read.and.bias")
        or_w = self.alu._get("memory.read.or.weight")
        or_b = self.alu._get("memory.read.or.bias")
        out_bits: List[int] = []
        for bit in range(REG_BITS):
            inp = torch.stack([mem_bits[:, bit], sel], dim=1)
            and_out = heaviside((inp * and_w[bit]).sum(dim=1) + and_b[bit])
            out_bit = heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item()
            out_bits.append(int(out_bit))
        return bits_to_int(out_bits)

    def _memory_write(self, mem: List[int], addr: int, value: int) -> List[int]:
        """Return a new memory image with ``value`` written at ``addr``.

        Every byte is recomputed through gates. With write-enable tied high, each
        bit combines an ``and_old`` gate (old bit, nsel) and an ``and_new`` gate
        (data bit, write_sel) through an ``or`` gate. ``mem`` itself is not modified.
        """
        sel = self._addr_decode(addr)
        data_bits = torch.tensor(int_to_bits(value, REG_BITS), device=self.device, dtype=torch.float32)
        mem_bits = self._mem_to_bits(mem)
        sel_w = self.alu._get("memory.write.sel.weight")
        sel_b = self.alu._get("memory.write.sel.bias")
        nsel_w = self.alu._get("memory.write.nsel.weight").squeeze(1)
        nsel_b = self.alu._get("memory.write.nsel.bias")
        and_old_w = self.alu._get("memory.write.and_old.weight")
        and_old_b = self.alu._get("memory.write.and_old.bias")
        and_new_w = self.alu._get("memory.write.and_new.weight")
        and_new_b = self.alu._get("memory.write.and_new.bias")
        or_w = self.alu._get("memory.write.or.weight")
        or_b = self.alu._get("memory.write.or.bias")
        we = torch.ones_like(sel)  # write-enable is always asserted here
        sel_inp = torch.stack([sel, we], dim=1)
        write_sel = heaviside((sel_inp * sel_w).sum(dim=1) + sel_b)
        nsel = heaviside((write_sel * nsel_w) + nsel_b)
        new_mem_bits = torch.zeros((MEM_BYTES, REG_BITS), device=self.device)
        for bit in range(REG_BITS):
            old_bit = mem_bits[:, bit]
            data_bit = data_bits[bit].expand(MEM_BYTES)
            inp_old = torch.stack([old_bit, nsel], dim=1)
            inp_new = torch.stack([data_bit, write_sel], dim=1)
            and_old = heaviside((inp_old * and_old_w[:, bit]).sum(dim=1) + and_old_b[:, bit])
            and_new = heaviside((inp_new * and_new_w[:, bit]).sum(dim=1) + and_new_b[:, bit])
            or_inp = torch.stack([and_old, and_new], dim=1)
            out_bit = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit])
            new_mem_bits[:, bit] = out_bit
        return self._bits_to_mem(new_mem_bits)

    def _conditional_jump_byte(self, prefix: str, pc_byte: int, target_byte: int, flag: int) -> int:
        """Select between ``pc_byte`` and ``target_byte`` bit by bit using ``prefix`` gates.

        Per bit: ``or(and_a(pc_bit, not_sel(flag)), and_b(target_bit, flag))``.
        """
        pc_bits = int_to_bits(pc_byte, REG_BITS)
        target_bits = int_to_bits(target_byte, REG_BITS)
        out_bits: List[int] = []
        for bit in range(REG_BITS):
            not_sel = self.alu._eval_gate(
                f"{prefix}.bit{bit}.not_sel.weight",
                f"{prefix}.bit{bit}.not_sel.bias",
                [float(flag)],
            )
            and_a = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_a.weight",
                f"{prefix}.bit{bit}.and_a.bias",
                [float(pc_bits[bit]), not_sel],
            )
            and_b = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_b.weight",
                f"{prefix}.bit{bit}.and_b.bias",
                [float(target_bits[bit]), float(flag)],
            )
            out_bit = self.alu._eval_gate(
                f"{prefix}.bit{bit}.or.weight",
                f"{prefix}.bit{bit}.or.bias",
                [and_a, and_b],
            )
            out_bits.append(int(out_bit))
        return bits_to_int(out_bits)

    def step(self, state: CPUState) -> CPUState:
        """Single CPU cycle using threshold neurons.

        Returns a new state; a halted input (``ctrl[0] == 1``) is returned as an
        unchanged copy.
        """
        if state.ctrl[0] == 1:
            return state.copy()
        s = state.copy()
        hi = self._memory_read(s.mem, s.pc)
        lo = self._memory_read(s.mem, (s.pc + 1) & 0xFFFF)
        s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
        next_pc = (s.pc + 2) & 0xFFFF
        opcode, rd, rs, imm8 = decode_ir(s.ir)
        a = s.regs[rd]
        b = s.regs[rs]
        addr16 = None
        next_pc_ext = next_pc
        # Opcodes 0xA-0xE carry a trailing two-byte address operand.
        if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
            addr_hi = self._memory_read(s.mem, next_pc)
            addr_lo = self._memory_read(s.mem, (next_pc + 1) & 0xFFFF)
            addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
            next_pc_ext = (next_pc + 2) & 0xFFFF
        write_result = True
        result = a
        carry = 0
        overflow = 0
        if opcode == 0x0:
            result, carry, overflow = self.alu.add(a, b)
        elif opcode == 0x1:
            result, carry, overflow = self.alu.sub(a, b)
        elif opcode == 0x2:
            result = self.alu.bitwise_and(a, b)
        elif opcode == 0x3:
            result = self.alu.bitwise_or(a, b)
        elif opcode == 0x4:
            result = self.alu.bitwise_xor(a, b)
        elif opcode == 0x5:
            result = self.alu.shift_left(a)
        elif opcode == 0x6:
            result = self.alu.shift_right(a)
        elif opcode == 0x7:
            result = self.alu.multiply(a, b)
        elif opcode == 0x8:
            result, _ = self.alu.divide(a, b)
        elif opcode == 0x9:
            # Same subtraction as 0x1, but only the flags are kept.
            result, carry, overflow = self.alu.sub(a, b)
            write_result = False
        elif opcode == 0xA:
            result = self._memory_read(s.mem, addr16)
        elif opcode == 0xB:
            s.mem = self._memory_write(s.mem, addr16, b & 0xFF)
            write_result = False
        elif opcode == 0xC:
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xD:
            circuit_prefix, flag_idx = self._COND_CIRCUITS[imm8 & 0x7]
            hi_pc = self._conditional_jump_byte(
                circuit_prefix,
                (next_pc_ext >> 8) & 0xFF,
                (addr16 >> 8) & 0xFF,
                s.flags[flag_idx],
            )
            lo_pc = self._conditional_jump_byte(
                circuit_prefix,
                next_pc_ext & 0xFF,
                addr16 & 0xFF,
                s.flags[flag_idx],
            )
            s.pc = ((hi_pc & 0xFF) << 8) | (lo_pc & 0xFF)
            write_result = False
        elif opcode == 0xE:
            # Push the return address high byte first; sp is decremented before each write.
            ret_addr = next_pc_ext & 0xFFFF
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, (ret_addr >> 8) & 0xFF)
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, ret_addr & 0xFF)
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xF:
            s.ctrl[0] = 1
            write_result = False
        if opcode <= 0x9 or opcode == 0xA:
            s.flags = list(flags_from_result(result, carry, overflow))
        if write_result:
            s.regs[rd] = result & 0xFF
        if opcode not in (0xC, 0xD, 0xE):
            s.pc = next_pc_ext
        return s

    def run_until_halt(self, state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
        """Execute until HALT or max_cycles reached.

        Returns the final state and the number of steps run before HALT was
        observed, or ``max_cycles`` if the budget ran out.
        """
        s = state.copy()
        for i in range(max_cycles):
            if s.ctrl[0] == 1:
                return s, i
            s = self.step(s)
        return s, max_cycles

    def forward(self, state_bits: torch.Tensor, max_cycles: int = 256) -> torch.Tensor:
        """Tensor-in, tensor-out interface for neural integration.

        ``state_bits`` is flattened and must hold STATE_BITS values laid out as
        ``pack_state`` produces. Returns the packed final state as a float32 CPU tensor.
        """
        bits_list = [int(b) for b in state_bits.detach().cpu().flatten().tolist()]
        state = unpack_state(bits_list)
        final, _ = self.run_until_halt(state, max_cycles=max_cycles)
        return torch.tensor(pack_state(final), dtype=torch.float32)
def encode_instr(opcode: int, rd: int, rs: int, imm8: int) -> int:
    """Pack fields into a 16-bit word: [opcode:4][rd:2][rs:2][imm8:8] (inverse of decode_ir).

    Each field is masked to its width before packing.
    """
    word = opcode & 0xF
    word = (word << 2) | (rd & 0x3)
    word = (word << 2) | (rs & 0x3)
    word = (word << 8) | (imm8 & 0xFF)
    return word
def write_instr(mem: List[int], addr: int, instr: int) -> None:
    """Store a 16-bit instruction high byte first at ``addr`` and ``addr + 1`` (addresses wrap at 16 bits)."""
    for offset, byte in enumerate(((instr >> 8) & 0xFF, instr & 0xFF)):
        mem[(addr + offset) & 0xFFFF] = byte
def write_addr(mem: List[int], addr: int, value: int) -> None:
    """Store a 16-bit address operand high byte first at ``addr`` and ``addr + 1`` (wrapping)."""
    hi, lo = divmod(value & 0xFFFF, 0x100)
    mem[addr & 0xFFFF] = hi
    mem[(addr + 1) & 0xFFFF] = lo
def run_smoke_test(model_path: str | Path | None = None) -> None:
    """Smoke test: LOAD 5, LOAD 7, ADD, STORE, HALT. Expect result = 12.

    Runs the same program through the reference interpreter, the threshold CPU,
    and the tensor ``forward`` interface, and asserts that they agree.

    Args:
        model_path: Safetensors checkpoint for the threshold CPU. ``None`` (the
            default) reads the module-level ``DEFAULT_MODEL_PATH`` at call time.
            That makes a rebinding by the ``__main__`` ``--model`` handler take
            effect. ``ThresholdCPU()`` with no argument would instead use the
            value captured when its ``def`` was executed.

    Raises:
        AssertionError: if any implementation produces an unexpected result.
    """
    if model_path is None:
        model_path = DEFAULT_MODEL_PATH
    mem = [0] * MEM_BYTES
    # Program; opcodes 0xA/0xB are followed by a two-byte address operand.
    write_instr(mem, 0x0000, encode_instr(0xA, 0, 0, 0x00))  # R0 <- MEM[0x0100]
    write_addr(mem, 0x0002, 0x0100)
    write_instr(mem, 0x0004, encode_instr(0xA, 1, 0, 0x00))  # R1 <- MEM[0x0101]
    write_addr(mem, 0x0006, 0x0101)
    write_instr(mem, 0x0008, encode_instr(0x0, 0, 1, 0x00))  # R0 <- R0 + R1
    write_instr(mem, 0x000A, encode_instr(0xB, 0, 0, 0x00))  # MEM[0x0102] <- R0
    write_addr(mem, 0x000C, 0x0102)
    write_instr(mem, 0x000E, encode_instr(0xF, 0, 0, 0x00))  # HALT
    mem[0x0100] = 5
    mem[0x0101] = 7
    state = CPUState(
        pc=0,
        ir=0,
        regs=[0, 0, 0, 0],
        flags=[0, 0, 0, 0],
        sp=0xFFFE,
        ctrl=[0, 0, 0, 0],
        mem=mem,
    )
    print("Running reference implementation...")
    final, cycles = ref_run_until_halt(state, max_cycles=20)
    assert final.ctrl[0] == 1, "HALT flag not set"
    assert final.regs[0] == 12, f"R0 expected 12, got {final.regs[0]}"
    assert final.mem[0x0102] == 12, f"MEM[0x0102] expected 12, got {final.mem[0x0102]}"
    assert cycles <= 10, f"Unexpected cycle count: {cycles}"
    print(f" Reference: R0={final.regs[0]}, MEM[0x0102]={final.mem[0x0102]}, cycles={cycles}")
    print("Running threshold-weight implementation...")
    threshold_cpu = ThresholdCPU(model_path)
    t_final, t_cycles = threshold_cpu.run_until_halt(state, max_cycles=20)
    assert t_final.ctrl[0] == 1, "Threshold HALT flag not set"
    assert t_final.regs[0] == final.regs[0], f"Threshold R0 mismatch: {t_final.regs[0]} != {final.regs[0]}"
    assert t_final.mem[0x0102] == final.mem[0x0102], (
        f"Threshold MEM[0x0102] mismatch: {t_final.mem[0x0102]} != {final.mem[0x0102]}"
    )
    assert t_cycles == cycles, f"Threshold cycle count mismatch: {t_cycles} != {cycles}"
    print(f" Threshold: R0={t_final.regs[0]}, MEM[0x0102]={t_final.mem[0x0102]}, cycles={t_cycles}")
    print("Validating forward() tensor I/O...")
    bits = torch.tensor(pack_state(state), dtype=torch.float32)
    out_bits = threshold_cpu.forward(bits, max_cycles=20)
    out_state = unpack_state([int(b) for b in out_bits.tolist()])
    assert out_state.regs[0] == final.regs[0], f"Forward R0 mismatch: {out_state.regs[0]} != {final.regs[0]}"
    assert out_state.mem[0x0102] == final.mem[0x0102], (
        f"Forward MEM[0x0102] mismatch: {out_state.mem[0x0102]} != {final.mem[0x0102]}"
    )
    print(f" Forward: R0={out_state.regs[0]}, MEM[0x0102]={out_state.mem[0x0102]}")
    print("\nSmoke test: PASSED")
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="8-bit Threshold CPU")
    cli.add_argument("--model", type=Path, default=DEFAULT_MODEL_PATH, help="Path to safetensors model")
    options = cli.parse_args()
    # Rebind the module-level default to the --model value. ThresholdCPU's own
    # default argument was captured when its ``def`` ran, so only code that reads
    # DEFAULT_MODEL_PATH at call time observes this rebinding.
    if options.model != DEFAULT_MODEL_PATH:
        DEFAULT_MODEL_PATH = options.model
    run_smoke_test()