""" 8-bit Threshold Computer - CPU Runtime State layout, reference cycle, and threshold-weight execution. All multi-bit fields are MSB-first. Usage: python threshold_cpu.py # Run smoke test python threshold_cpu.py --help # Show options """ from __future__ import annotations import argparse from dataclasses import dataclass from pathlib import Path from typing import List, Tuple import torch from safetensors.torch import load_file FLAG_NAMES = ["Z", "N", "C", "V"] CTRL_NAMES = ["HALT", "MEM_WE", "MEM_RE", "RESERVED"] PC_BITS = 16 IR_BITS = 16 REG_BITS = 8 REG_COUNT = 4 FLAG_BITS = 4 SP_BITS = 16 CTRL_BITS = 4 MEM_BYTES = 65536 MEM_BITS = MEM_BYTES * 8 STATE_BITS = PC_BITS + IR_BITS + (REG_BITS * REG_COUNT) + FLAG_BITS + SP_BITS + CTRL_BITS + MEM_BITS DEFAULT_MODEL_PATH = Path(__file__).resolve().parent / "neural_computer.safetensors" def int_to_bits(value: int, width: int) -> List[int]: return [(value >> (width - 1 - i)) & 1 for i in range(width)] def bits_to_int(bits: List[int]) -> int: value = 0 for bit in bits: value = (value << 1) | int(bit) return value def bits_msb_to_lsb(bits: List[int]) -> List[int]: return list(reversed(bits)) @dataclass class CPUState: pc: int ir: int regs: List[int] flags: List[int] sp: int ctrl: List[int] mem: List[int] def copy(self) -> CPUState: return CPUState( pc=int(self.pc), ir=int(self.ir), regs=[int(r) for r in self.regs], flags=[int(f) for f in self.flags], sp=int(self.sp), ctrl=[int(c) for c in self.ctrl], mem=[int(m) for m in self.mem], ) def pack_state(state: CPUState) -> List[int]: bits: List[int] = [] bits.extend(int_to_bits(state.pc, PC_BITS)) bits.extend(int_to_bits(state.ir, IR_BITS)) for reg in state.regs: bits.extend(int_to_bits(reg, REG_BITS)) bits.extend([int(f) for f in state.flags]) bits.extend(int_to_bits(state.sp, SP_BITS)) bits.extend([int(c) for c in state.ctrl]) for byte in state.mem: bits.extend(int_to_bits(byte, REG_BITS)) return bits def unpack_state(bits: List[int]) -> CPUState: if len(bits) != STATE_BITS: raise ValueError(f"Expected {STATE_BITS} bits, got {len(bits)}") idx = 0 pc = bits_to_int(bits[idx:idx + PC_BITS]) idx += PC_BITS ir = bits_to_int(bits[idx:idx + IR_BITS]) idx += IR_BITS regs = [] for _ in range(REG_COUNT): regs.append(bits_to_int(bits[idx:idx + REG_BITS])) idx += REG_BITS flags = [int(b) for b in bits[idx:idx + FLAG_BITS]] idx += FLAG_BITS sp = bits_to_int(bits[idx:idx + SP_BITS]) idx += SP_BITS ctrl = [int(b) for b in bits[idx:idx + CTRL_BITS]] idx += CTRL_BITS mem = [] for _ in range(MEM_BYTES): mem.append(bits_to_int(bits[idx:idx + REG_BITS])) idx += REG_BITS return CPUState(pc=pc, ir=ir, regs=regs, flags=flags, sp=sp, ctrl=ctrl, mem=mem) def decode_ir(ir: int) -> Tuple[int, int, int, int]: opcode = (ir >> 12) & 0xF rd = (ir >> 10) & 0x3 rs = (ir >> 8) & 0x3 imm8 = ir & 0xFF return opcode, rd, rs, imm8 def flags_from_result(result: int, carry: int, overflow: int) -> Tuple[int, int, int, int]: z = 1 if result == 0 else 0 n = 1 if (result & 0x80) else 0 c = 1 if carry else 0 v = 1 if overflow else 0 return z, n, c, v def alu_add(a: int, b: int) -> Tuple[int, int, int]: full = a + b result = full & 0xFF carry = 1 if full > 0xFF else 0 overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0 return result, carry, overflow def alu_sub(a: int, b: int) -> Tuple[int, int, int]: full = (a - b) & 0x1FF result = full & 0xFF carry = 1 if a >= b else 0 overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0 return result, carry, overflow def ref_step(state: CPUState) -> CPUState: """Reference CPU cycle (pure Python arithmetic).""" if state.ctrl[0] == 1: return state.copy() s = state.copy() hi = s.mem[s.pc] lo = s.mem[(s.pc + 1) & 0xFFFF] s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF) next_pc = (s.pc + 2) & 0xFFFF opcode, rd, rs, imm8 = decode_ir(s.ir) a = s.regs[rd] b = s.regs[rs] addr16 = None next_pc_ext = next_pc if opcode in (0xA, 0xB, 0xC, 0xD, 0xE): addr_hi = s.mem[next_pc] addr_lo = s.mem[(next_pc + 1) & 0xFFFF] addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF) next_pc_ext = (next_pc + 2) & 0xFFFF write_result = True result = a carry = 0 overflow = 0 if opcode == 0x0: result, carry, overflow = alu_add(a, b) elif opcode == 0x1: result, carry, overflow = alu_sub(a, b) elif opcode == 0x2: result = a & b elif opcode == 0x3: result = a | b elif opcode == 0x4: result = a ^ b elif opcode == 0x5: result = (a << 1) & 0xFF elif opcode == 0x6: result = (a >> 1) & 0xFF elif opcode == 0x7: result = (a * b) & 0xFF elif opcode == 0x8: if b == 0: result = 0xFF else: result = a // b elif opcode == 0x9: result, carry, overflow = alu_sub(a, b) write_result = False elif opcode == 0xA: result = s.mem[addr16] elif opcode == 0xB: s.mem[addr16] = b & 0xFF write_result = False elif opcode == 0xC: s.pc = addr16 & 0xFFFF write_result = False elif opcode == 0xD: cond_type = imm8 & 0x7 if cond_type == 0: take_branch = s.flags[0] == 1 elif cond_type == 1: take_branch = s.flags[0] == 0 elif cond_type == 2: take_branch = s.flags[2] == 1 elif cond_type == 3: take_branch = s.flags[2] == 0 elif cond_type == 4: take_branch = s.flags[1] == 1 elif cond_type == 5: take_branch = s.flags[1] == 0 elif cond_type == 6: take_branch = s.flags[3] == 1 else: take_branch = s.flags[3] == 0 if take_branch: s.pc = addr16 & 0xFFFF else: s.pc = next_pc_ext write_result = False elif opcode == 0xE: ret_addr = next_pc_ext & 0xFFFF s.sp = (s.sp - 1) & 0xFFFF s.mem[s.sp] = (ret_addr >> 8) & 0xFF s.sp = (s.sp - 1) & 0xFFFF s.mem[s.sp] = ret_addr & 0xFF s.pc = addr16 & 0xFFFF write_result = False elif opcode == 0xF: s.ctrl[0] = 1 write_result = False if opcode <= 0x9 or opcode in (0xA, 0x7, 0x8): s.flags = list(flags_from_result(result, carry, overflow)) if write_result: s.regs[rd] = result & 0xFF if opcode not in (0xC, 0xD, 0xE): s.pc = next_pc_ext return s def ref_run_until_halt(state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]: """Reference execution loop.""" s = state.copy() for i in range(max_cycles): if s.ctrl[0] == 1: return s, i s = ref_step(s) return s, max_cycles def heaviside(x: torch.Tensor) -> torch.Tensor: return (x >= 0).float() class ThresholdALU: def __init__(self, model_path: str, device: str = "cpu") -> None: self.device = device self.tensors = {k: v.float().to(device) for k, v in load_file(model_path).items()} def _get(self, name: str) -> torch.Tensor: return self.tensors[name] def _eval_gate(self, weight_key: str, bias_key: str, inputs: List[float]) -> float: w = self._get(weight_key) b = self._get(bias_key) inp = torch.tensor(inputs, device=self.device) return heaviside((inp * w).sum() + b).item() def _eval_xor(self, prefix: str, inputs: List[float]) -> float: inp = torch.tensor(inputs, device=self.device) w_or = self._get(f"{prefix}.layer1.or.weight") b_or = self._get(f"{prefix}.layer1.or.bias") w_nand = self._get(f"{prefix}.layer1.nand.weight") b_nand = self._get(f"{prefix}.layer1.nand.bias") w2 = self._get(f"{prefix}.layer2.weight") b2 = self._get(f"{prefix}.layer2.bias") h_or = heaviside((inp * w_or).sum() + b_or).item() h_nand = heaviside((inp * w_nand).sum() + b_nand).item() hidden = torch.tensor([h_or, h_nand], device=self.device) return heaviside((hidden * w2).sum() + b2).item() def _eval_full_adder(self, prefix: str, a: float, b: float, cin: float) -> Tuple[float, float]: ha1_sum = self._eval_xor(f"{prefix}.ha1.sum", [a, b]) ha1_carry = self._eval_gate(f"{prefix}.ha1.carry.weight", f"{prefix}.ha1.carry.bias", [a, b]) ha2_sum = self._eval_xor(f"{prefix}.ha2.sum", [ha1_sum, cin]) ha2_carry = self._eval_gate( f"{prefix}.ha2.carry.weight", f"{prefix}.ha2.carry.bias", [ha1_sum, cin] ) cout = self._eval_gate(f"{prefix}.carry_or.weight", f"{prefix}.carry_or.bias", [ha1_carry, ha2_carry]) return ha2_sum, cout def add(self, a: int, b: int) -> Tuple[int, int, int]: a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS)) b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS)) carry = 0.0 sum_bits: List[int] = [] for bit in range(REG_BITS): sum_bit, carry = self._eval_full_adder( f"arithmetic.ripplecarry8bit.fa{bit}", float(a_bits[bit]), float(b_bits[bit]), carry ) sum_bits.append(int(sum_bit)) result = bits_to_int(list(reversed(sum_bits))) carry_out = int(carry) overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0 return result, carry_out, overflow def sub(self, a: int, b: int) -> Tuple[int, int, int]: a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS)) b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS)) carry = 1.0 sum_bits: List[int] = [] for bit in range(REG_BITS): notb = self._eval_gate( f"arithmetic.sub8bit.notb{bit}.weight", f"arithmetic.sub8bit.notb{bit}.bias", [float(b_bits[bit])], ) xor1 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor1", [float(a_bits[bit]), notb]) xor2 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor2", [xor1, carry]) and1 = self._eval_gate( f"arithmetic.sub8bit.fa{bit}.and1.weight", f"arithmetic.sub8bit.fa{bit}.and1.bias", [float(a_bits[bit]), notb], ) and2 = self._eval_gate( f"arithmetic.sub8bit.fa{bit}.and2.weight", f"arithmetic.sub8bit.fa{bit}.and2.bias", [xor1, carry], ) carry = self._eval_gate( f"arithmetic.sub8bit.fa{bit}.or_carry.weight", f"arithmetic.sub8bit.fa{bit}.or_carry.bias", [and1, and2], ) sum_bits.append(int(xor2)) result = bits_to_int(list(reversed(sum_bits))) carry_out = int(carry) overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0 return result, carry_out, overflow def bitwise_and(self, a: int, b: int) -> int: a_bits = int_to_bits(a, REG_BITS) b_bits = int_to_bits(b, REG_BITS) w = self._get("alu.alu8bit.and.weight") bias = self._get("alu.alu8bit.and.bias") out_bits = [] for bit in range(REG_BITS): inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device) out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item() out_bits.append(int(out)) return bits_to_int(out_bits) def bitwise_or(self, a: int, b: int) -> int: a_bits = int_to_bits(a, REG_BITS) b_bits = int_to_bits(b, REG_BITS) w = self._get("alu.alu8bit.or.weight") bias = self._get("alu.alu8bit.or.bias") out_bits = [] for bit in range(REG_BITS): inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device) out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item() out_bits.append(int(out)) return bits_to_int(out_bits) def bitwise_not(self, a: int) -> int: a_bits = int_to_bits(a, REG_BITS) w = self._get("alu.alu8bit.not.weight") bias = self._get("alu.alu8bit.not.bias") out_bits = [] for bit in range(REG_BITS): inp = torch.tensor([float(a_bits[bit])], device=self.device) out = heaviside((inp * w[bit]).sum() + bias[bit]).item() out_bits.append(int(out)) return bits_to_int(out_bits) def bitwise_xor(self, a: int, b: int) -> int: a_bits = int_to_bits(a, REG_BITS) b_bits = int_to_bits(b, REG_BITS) w_or = self._get("alu.alu8bit.xor.layer1.or.weight") b_or = self._get("alu.alu8bit.xor.layer1.or.bias") w_nand = self._get("alu.alu8bit.xor.layer1.nand.weight") b_nand = self._get("alu.alu8bit.xor.layer1.nand.bias") w2 = self._get("alu.alu8bit.xor.layer2.weight") b2 = self._get("alu.alu8bit.xor.layer2.bias") out_bits = [] for bit in range(REG_BITS): inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device) h_or = heaviside((inp * w_or[bit * 2:bit * 2 + 2]).sum() + b_or[bit]) h_nand = heaviside((inp * w_nand[bit * 2:bit * 2 + 2]).sum() + b_nand[bit]) hidden = torch.stack([h_or, h_nand]) out = heaviside((hidden * w2[bit * 2:bit * 2 + 2]).sum() + b2[bit]).item() out_bits.append(int(out)) return bits_to_int(out_bits) def shift_left(self, a: int) -> int: a_bits = int_to_bits(a, REG_BITS) out_bits = [] for bit in range(REG_BITS): w = self.alu._get(f"alu.alu8bit.shl.bit{bit}.weight") bias = self.alu._get(f"alu.alu8bit.shl.bit{bit}.bias") if bit < 7: inp = torch.tensor([float(a_bits[bit + 1])], device=self.device) else: inp = torch.tensor([0.0], device=self.device) out = heaviside((inp * w).sum() + bias).item() out_bits.append(int(out)) return bits_to_int(out_bits) def shift_right(self, a: int) -> int: a_bits = int_to_bits(a, REG_BITS) out_bits = [] for bit in range(REG_BITS): w = self.alu._get(f"alu.alu8bit.shr.bit{bit}.weight") bias = self.alu._get(f"alu.alu8bit.shr.bit{bit}.bias") if bit > 0: inp = torch.tensor([float(a_bits[bit - 1])], device=self.device) else: inp = torch.tensor([0.0], device=self.device) out = heaviside((inp * w).sum() + bias).item() out_bits.append(int(out)) return bits_to_int(out_bits) def multiply(self, a: int, b: int) -> int: """8-bit multiply using partial product AND gates + shift-add.""" a_bits = int_to_bits(a, REG_BITS) b_bits = int_to_bits(b, REG_BITS) # Compute all 64 partial products using AND gates pp = [[0] * 8 for _ in range(8)] for i in range(8): for j in range(8): w = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.weight") bias = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.bias") inp = torch.tensor([float(a_bits[i]), float(b_bits[j])], device=self.device) pp[i][j] = int(heaviside((inp * w).sum() + bias).item()) # Shift-add accumulation using existing 8-bit adder # Row j contributes A*B[j] shifted left by (7-j) positions result = 0 for j in range(8): if b_bits[j] == 0: continue # Construct the partial product row (A masked by B[j]) row = 0 for i in range(8): row |= (pp[i][j] << (7 - i)) # Shift by position (7-j means B[7] is LSB, B[0] is MSB) shifted = row << (7 - j) # Add to result using threshold adder result, _, _ = self.add(result & 0xFF, shifted & 0xFF) # Handle overflow into high byte if shifted > 255 or result > 255: result = (result + (shifted >> 8)) & 0xFF return result & 0xFF def divide(self, a: int, b: int) -> Tuple[int, int]: """8-bit divide using restoring division with threshold gates.""" if b == 0: return 0xFF, a # Division by zero: return max quotient, original dividend a_bits = int_to_bits(a, REG_BITS) quotient = 0 remainder = 0 for stage in range(8): # Shift remainder left and bring in next dividend bit remainder = ((remainder << 1) | a_bits[stage]) & 0xFF # Compare remainder >= divisor using threshold gate rem_bits = int_to_bits(remainder, REG_BITS) div_bits = int_to_bits(b, REG_BITS) w = self._get(f"alu.alu8bit.div.stage{stage}.cmp.weight") bias = self._get(f"alu.alu8bit.div.stage{stage}.cmp.bias") inp = torch.tensor([float(rem_bits[i]) for i in range(8)] + [float(div_bits[i]) for i in range(8)], device=self.device) cmp_result = int(heaviside((inp * w).sum() + bias).item()) # If remainder >= divisor, subtract and set quotient bit if cmp_result: remainder, _, _ = self.sub(remainder, b) quotient = (quotient << 1) | 1 else: quotient = quotient << 1 return quotient & 0xFF, remainder & 0xFF class ThresholdCPU: def __init__(self, model_path: str | Path = DEFAULT_MODEL_PATH, device: str = "cpu") -> None: self.device = device self.alu = ThresholdALU(str(model_path), device=device) def _addr_decode(self, addr: int) -> torch.Tensor: bits = torch.tensor(int_to_bits(addr, PC_BITS), device=self.device, dtype=torch.float32) w = self.alu._get("memory.addr_decode.weight") b = self.alu._get("memory.addr_decode.bias") return heaviside((w * bits).sum(dim=1) + b) def _memory_read(self, mem: List[int], addr: int) -> int: sel = self._addr_decode(addr) mem_bits = torch.tensor( [int_to_bits(byte, REG_BITS) for byte in mem], device=self.device, dtype=torch.float32, ) and_w = self.alu._get("memory.read.and.weight") and_b = self.alu._get("memory.read.and.bias") or_w = self.alu._get("memory.read.or.weight") or_b = self.alu._get("memory.read.or.bias") out_bits: List[int] = [] for bit in range(REG_BITS): inp = torch.stack([mem_bits[:, bit], sel], dim=1) and_out = heaviside((inp * and_w[bit]).sum(dim=1) + and_b[bit]) out_bit = heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item() out_bits.append(int(out_bit)) return bits_to_int(out_bits) def _memory_write(self, mem: List[int], addr: int, value: int) -> List[int]: sel = self._addr_decode(addr) data_bits = torch.tensor(int_to_bits(value, REG_BITS), device=self.device, dtype=torch.float32) mem_bits = torch.tensor( [int_to_bits(byte, REG_BITS) for byte in mem], device=self.device, dtype=torch.float32, ) sel_w = self.alu._get("memory.write.sel.weight") sel_b = self.alu._get("memory.write.sel.bias") nsel_w = self.alu._get("memory.write.nsel.weight").squeeze(1) nsel_b = self.alu._get("memory.write.nsel.bias") and_old_w = self.alu._get("memory.write.and_old.weight") and_old_b = self.alu._get("memory.write.and_old.bias") and_new_w = self.alu._get("memory.write.and_new.weight") and_new_b = self.alu._get("memory.write.and_new.bias") or_w = self.alu._get("memory.write.or.weight") or_b = self.alu._get("memory.write.or.bias") we = torch.ones_like(sel) sel_inp = torch.stack([sel, we], dim=1) write_sel = heaviside((sel_inp * sel_w).sum(dim=1) + sel_b) nsel = heaviside((write_sel * nsel_w) + nsel_b) new_mem_bits = torch.zeros((MEM_BYTES, REG_BITS), device=self.device) for bit in range(REG_BITS): old_bit = mem_bits[:, bit] data_bit = data_bits[bit].expand(MEM_BYTES) inp_old = torch.stack([old_bit, nsel], dim=1) inp_new = torch.stack([data_bit, write_sel], dim=1) and_old = heaviside((inp_old * and_old_w[:, bit]).sum(dim=1) + and_old_b[:, bit]) and_new = heaviside((inp_new * and_new_w[:, bit]).sum(dim=1) + and_new_b[:, bit]) or_inp = torch.stack([and_old, and_new], dim=1) out_bit = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit]) new_mem_bits[:, bit] = out_bit return [bits_to_int([int(b) for b in new_mem_bits[i].tolist()]) for i in range(MEM_BYTES)] def _conditional_jump_byte(self, prefix: str, pc_byte: int, target_byte: int, flag: int) -> int: pc_bits = int_to_bits(pc_byte, REG_BITS) target_bits = int_to_bits(target_byte, REG_BITS) out_bits: List[int] = [] for bit in range(REG_BITS): not_sel = self.alu._eval_gate( f"{prefix}.bit{bit}.not_sel.weight", f"{prefix}.bit{bit}.not_sel.bias", [float(flag)], ) and_a = self.alu._eval_gate( f"{prefix}.bit{bit}.and_a.weight", f"{prefix}.bit{bit}.and_a.bias", [float(pc_bits[bit]), not_sel], ) and_b = self.alu._eval_gate( f"{prefix}.bit{bit}.and_b.weight", f"{prefix}.bit{bit}.and_b.bias", [float(target_bits[bit]), float(flag)], ) out_bit = self.alu._eval_gate( f"{prefix}.bit{bit}.or.weight", f"{prefix}.bit{bit}.or.bias", [and_a, and_b], ) out_bits.append(int(out_bit)) return bits_to_int(out_bits) def step(self, state: CPUState) -> CPUState: """Single CPU cycle using threshold neurons.""" if state.ctrl[0] == 1: return state.copy() s = state.copy() hi = self._memory_read(s.mem, s.pc) lo = self._memory_read(s.mem, (s.pc + 1) & 0xFFFF) s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF) next_pc = (s.pc + 2) & 0xFFFF opcode, rd, rs, imm8 = decode_ir(s.ir) a = s.regs[rd] b = s.regs[rs] addr16 = None next_pc_ext = next_pc if opcode in (0xA, 0xB, 0xC, 0xD, 0xE): addr_hi = self._memory_read(s.mem, next_pc) addr_lo = self._memory_read(s.mem, (next_pc + 1) & 0xFFFF) addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF) next_pc_ext = (next_pc + 2) & 0xFFFF write_result = True result = a carry = 0 overflow = 0 if opcode == 0x0: result, carry, overflow = self.alu.add(a, b) elif opcode == 0x1: result, carry, overflow = self.alu.sub(a, b) elif opcode == 0x2: result = self.alu.bitwise_and(a, b) elif opcode == 0x3: result = self.alu.bitwise_or(a, b) elif opcode == 0x4: result = self.alu.bitwise_xor(a, b) elif opcode == 0x5: result = self.alu.shift_left(a) elif opcode == 0x6: result = self.alu.shift_right(a) elif opcode == 0x7: result = self.alu.multiply(a, b) elif opcode == 0x8: result, _ = self.alu.divide(a, b) elif opcode == 0x9: result, carry, overflow = self.alu.sub(a, b) write_result = False elif opcode == 0xA: result = self._memory_read(s.mem, addr16) elif opcode == 0xB: s.mem = self._memory_write(s.mem, addr16, b & 0xFF) write_result = False elif opcode == 0xC: s.pc = addr16 & 0xFFFF write_result = False elif opcode == 0xD: cond_type = imm8 & 0x7 cond_circuits = [ ("control.jz", 0), ("control.jnz", 0), ("control.jc", 2), ("control.jnc", 2), ("control.jn", 1), ("control.jp", 1), ("control.jv", 3), ("control.jnv", 3), ] circuit_prefix, flag_idx = cond_circuits[cond_type] hi_pc = self._conditional_jump_byte( circuit_prefix, (next_pc_ext >> 8) & 0xFF, (addr16 >> 8) & 0xFF, s.flags[flag_idx], ) lo_pc = self._conditional_jump_byte( circuit_prefix, next_pc_ext & 0xFF, addr16 & 0xFF, s.flags[flag_idx], ) s.pc = ((hi_pc & 0xFF) << 8) | (lo_pc & 0xFF) write_result = False elif opcode == 0xE: ret_addr = next_pc_ext & 0xFFFF s.sp = (s.sp - 1) & 0xFFFF s.mem = self._memory_write(s.mem, s.sp, (ret_addr >> 8) & 0xFF) s.sp = (s.sp - 1) & 0xFFFF s.mem = self._memory_write(s.mem, s.sp, ret_addr & 0xFF) s.pc = addr16 & 0xFFFF write_result = False elif opcode == 0xF: s.ctrl[0] = 1 write_result = False if opcode <= 0x9 or opcode == 0xA: s.flags = list(flags_from_result(result, carry, overflow)) if write_result: s.regs[rd] = result & 0xFF if opcode not in (0xC, 0xD, 0xE): s.pc = next_pc_ext return s def run_until_halt(self, state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]: """Execute until HALT or max_cycles reached.""" s = state.copy() for i in range(max_cycles): if s.ctrl[0] == 1: return s, i s = self.step(s) return s, max_cycles def forward(self, state_bits: torch.Tensor, max_cycles: int = 256) -> torch.Tensor: """Tensor-in, tensor-out interface for neural integration.""" bits_list = [int(b) for b in state_bits.detach().cpu().flatten().tolist()] state = unpack_state(bits_list) final, _ = self.run_until_halt(state, max_cycles=max_cycles) return torch.tensor(pack_state(final), dtype=torch.float32) def encode_instr(opcode: int, rd: int, rs: int, imm8: int) -> int: return ((opcode & 0xF) << 12) | ((rd & 0x3) << 10) | ((rs & 0x3) << 8) | (imm8 & 0xFF) def write_instr(mem: List[int], addr: int, instr: int) -> None: mem[addr & 0xFFFF] = (instr >> 8) & 0xFF mem[(addr + 1) & 0xFFFF] = instr & 0xFF def write_addr(mem: List[int], addr: int, value: int) -> None: mem[addr & 0xFFFF] = (value >> 8) & 0xFF mem[(addr + 1) & 0xFFFF] = value & 0xFF def run_smoke_test() -> None: """Smoke test: LOAD 5, LOAD 7, ADD, STORE, HALT. Expect result = 12.""" mem = [0] * 65536 write_instr(mem, 0x0000, encode_instr(0xA, 0, 0, 0x00)) write_addr(mem, 0x0002, 0x0100) write_instr(mem, 0x0004, encode_instr(0xA, 1, 0, 0x00)) write_addr(mem, 0x0006, 0x0101) write_instr(mem, 0x0008, encode_instr(0x0, 0, 1, 0x00)) write_instr(mem, 0x000A, encode_instr(0xB, 0, 0, 0x00)) write_addr(mem, 0x000C, 0x0102) write_instr(mem, 0x000E, encode_instr(0xF, 0, 0, 0x00)) mem[0x0100] = 5 mem[0x0101] = 7 state = CPUState( pc=0, ir=0, regs=[0, 0, 0, 0], flags=[0, 0, 0, 0], sp=0xFFFE, ctrl=[0, 0, 0, 0], mem=mem, ) print("Running reference implementation...") final, cycles = ref_run_until_halt(state, max_cycles=20) assert final.ctrl[0] == 1, "HALT flag not set" assert final.regs[0] == 12, f"R0 expected 12, got {final.regs[0]}" assert final.mem[0x0102] == 12, f"MEM[0x0102] expected 12, got {final.mem[0x0102]}" assert cycles <= 10, f"Unexpected cycle count: {cycles}" print(f" Reference: R0={final.regs[0]}, MEM[0x0102]={final.mem[0x0102]}, cycles={cycles}") print("Running threshold-weight implementation...") threshold_cpu = ThresholdCPU() t_final, t_cycles = threshold_cpu.run_until_halt(state, max_cycles=20) assert t_final.ctrl[0] == 1, "Threshold HALT flag not set" assert t_final.regs[0] == final.regs[0], f"Threshold R0 mismatch: {t_final.regs[0]} != {final.regs[0]}" assert t_final.mem[0x0102] == final.mem[0x0102], ( f"Threshold MEM[0x0102] mismatch: {t_final.mem[0x0102]} != {final.mem[0x0102]}" ) assert t_cycles == cycles, f"Threshold cycle count mismatch: {t_cycles} != {cycles}" print(f" Threshold: R0={t_final.regs[0]}, MEM[0x0102]={t_final.mem[0x0102]}, cycles={t_cycles}") print("Validating forward() tensor I/O...") bits = torch.tensor(pack_state(state), dtype=torch.float32) out_bits = threshold_cpu.forward(bits, max_cycles=20) out_state = unpack_state([int(b) for b in out_bits.tolist()]) assert out_state.regs[0] == final.regs[0], f"Forward R0 mismatch: {out_state.regs[0]} != {final.regs[0]}" assert out_state.mem[0x0102] == final.mem[0x0102], ( f"Forward MEM[0x0102] mismatch: {out_state.mem[0x0102]} != {final.mem[0x0102]}" ) print(f" Forward: R0={out_state.regs[0]}, MEM[0x0102]={out_state.mem[0x0102]}") print("\nSmoke test: PASSED") if __name__ == "__main__": parser = argparse.ArgumentParser(description="8-bit Threshold CPU") parser.add_argument("--model", type=Path, default=DEFAULT_MODEL_PATH, help="Path to safetensors model") args = parser.parse_args() if args.model != DEFAULT_MODEL_PATH: DEFAULT_MODEL_PATH = args.model run_smoke_test()