"""CPU validation programs for the threshold computer.

A small assembler and a suite of programs that exercise the ISA end-to-end:
arithmetic, control flow, memory access, self-modifying code, all eight
conditional jumps, the call mechanism, a sort, and cross-checks of the MUL,
DIV and bitwise operations.

Each program returns ``(mem, expected, max_cycles, description)`` where:

    mem         : list[int]      -- complete memory image
    expected    : dict[int, int] -- {address: expected_value} verified at HALT
    max_cycles  : int            -- cycle budget (an infinite loop will hit this)
    description : str            -- short human-readable summary

Programs target the 1 KB profile (addr_bits=10) by default but use only the
low 256 bytes so they also run on scratchpad and larger profiles.

All programs assume the CPU starts with PC=0 and SP defaulting to addr_mask
(the highest address; CALL pre-decrements before writing).
"""

from __future__ import annotations

from typing import Dict, List, Tuple

# ----------------------------------------------------------------------------
# Mini assembler
# ----------------------------------------------------------------------------

_OPCODE_NAMES = {
    "add": 0x0, "sub": 0x1, "and": 0x2, "or": 0x3,
    "xor": 0x4, "shl": 0x5, "shr": 0x6, "mul": 0x7,
    "div": 0x8, "cmp": 0x9, "load": 0xA, "store": 0xB,
    "jmp": 0xC, "jcc": 0xD, "call": 0xE, "halt": 0xF,
}

_COND = {"jz": 0, "jnz": 1, "jc": 2, "jnc": 3, "jn": 4, "jp": 5, "jv": 6, "jnv": 7}


def _enc(opcode: int, rd: int = 0, rs: int = 0, imm: int = 0) -> int:
    """Pack one 16-bit instruction word: [opcode:4][rd:2][rs:2][imm:8].

    Every field is masked to its width; range checking is left to the caller
    (``Asm`` validates register indices before encoding).
    """
    return ((opcode & 0xF) << 12) | ((rd & 0x3) << 10) | ((rs & 0x3) << 8) | (imm & 0xFF)


class Asm:
    """Tiny assembler for the threshold-computer ISA.

    Instructions are 16-bit big-endian words; LOAD/STORE/JMP/Jcc/CALL are
    followed by a 16-bit big-endian address word. Label references are
    recorded as fixups and resolved by ``assemble``, so a label may be used
    before it is defined.

    Each address may be emitted only once: emitting over an already-written
    byte (for example code that has grown into a data region placed with
    ``org``) raises ``ValueError`` instead of silently corrupting the image.

    Usage:
        a = Asm(size=256)
        a.org(0)
        a.label("start")
        a.load(0, "data")
        a.halt()
        a.org(0x80); a.label("data"); a.db(42)
        mem = a.assemble()
    """

    def __init__(self, size: int):
        self.mem: List[int] = [0] * size
        self.pc: int = 0
        self.labels: Dict[str, int] = {}
        self.fixups: List[Tuple[int, str]] = []
        # One flag per byte: set once the byte has been emitted (overlap check).
        self._used: List[bool] = [False] * size

    # --- Layout -------------------------------------------------------------

    def org(self, addr: int) -> None:
        """Set the address at which subsequent bytes are emitted."""
        self.pc = addr

    def label(self, name: str) -> None:
        """Bind *name* to the current emission address."""
        self.equ(name, self.pc)

    def equ(self, name: str, addr: int) -> None:
        """Bind *name* to an explicit address that need not equal ``pc``.

        Raises:
            ValueError: if *name* is already defined.
        """
        if name in self.labels:
            raise ValueError(f"duplicate label: {name}")
        self.labels[name] = addr

    # --- Raw data -----------------------------------------------------------

    def _put(self, value: int) -> None:
        """Emit one byte (masked to 8 bits) at ``pc`` and advance ``pc``.

        Raises:
            IndexError: if ``pc`` is outside the image (negative included,
                which plain list indexing would otherwise wrap silently).
            ValueError: if the byte at ``pc`` was already emitted.
        """
        if not 0 <= self.pc < len(self.mem):
            raise IndexError(f"address {self.pc:#x} outside {len(self.mem)}-byte image")
        if self._used[self.pc]:
            raise ValueError(f"overlapping emission at {self.pc:#x}")
        self.mem[self.pc] = value & 0xFF
        self._used[self.pc] = True
        self.pc += 1

    def db(self, *values: int) -> None:
        """Emit bytes, each masked to 8 bits."""
        for v in values:
            self._put(v)

    def dw(self, value: int) -> None:
        """Emit a 16-bit word, big-endian."""
        self._put(value >> 8)
        self._put(value)

    def daddr(self, label: str) -> None:
        """Emit a 16-bit placeholder that ``assemble`` patches with *label*'s address."""
        self.fixups.append((self.pc, label))
        self.dw(0)

    # --- ALU ops (no immediate) ---------------------------------------------

    @staticmethod
    def _reg(r: int) -> int:
        """Validate a register index (encoded in a 2-bit field)."""
        if not 0 <= r <= 3:
            raise ValueError(f"register index out of range 0..3: {r}")
        return r

    def _alu(self, op: int, rd: int, rs: int) -> None:
        self.dw(_enc(op, self._reg(rd), self._reg(rs)))

    def add(self, rd: int, rs: int) -> None:
        self._alu(_OPCODE_NAMES["add"], rd, rs)

    def sub(self, rd: int, rs: int) -> None:
        self._alu(_OPCODE_NAMES["sub"], rd, rs)

    def and_(self, rd: int, rs: int) -> None:
        self._alu(_OPCODE_NAMES["and"], rd, rs)

    def or_(self, rd: int, rs: int) -> None:
        self._alu(_OPCODE_NAMES["or"], rd, rs)

    def xor(self, rd: int, rs: int) -> None:
        self._alu(_OPCODE_NAMES["xor"], rd, rs)

    def shl(self, rd: int) -> None:
        self._alu(_OPCODE_NAMES["shl"], rd, 0)

    def shr(self, rd: int) -> None:
        self._alu(_OPCODE_NAMES["shr"], rd, 0)

    def mul(self, rd: int, rs: int) -> None:
        self._alu(_OPCODE_NAMES["mul"], rd, rs)

    def div(self, rd: int, rs: int) -> None:
        self._alu(_OPCODE_NAMES["div"], rd, rs)

    def cmp(self, rd: int, rs: int) -> None:
        self._alu(_OPCODE_NAMES["cmp"], rd, rs)

    # --- Memory + control (address-extended) --------------------------------

    def _with_addr(self, word: int, label: str) -> None:
        """Emit an instruction word followed by a fixup for *label*'s address."""
        self.dw(word)
        self.daddr(label)

    def load(self, rd: int, label: str) -> None:
        self._with_addr(_enc(_OPCODE_NAMES["load"], self._reg(rd), 0), label)

    def store(self, rs: int, label: str) -> None:
        self._with_addr(_enc(_OPCODE_NAMES["store"], 0, self._reg(rs)), label)

    def jmp(self, label: str) -> None:
        self._with_addr(_enc(_OPCODE_NAMES["jmp"]), label)

    def jcc(self, cond: str, label: str) -> None:
        self._with_addr(_enc(_OPCODE_NAMES["jcc"], 0, 0, _COND[cond]), label)

    def jz(self, label: str) -> None:
        self.jcc("jz", label)

    def jnz(self, label: str) -> None:
        self.jcc("jnz", label)

    def jc(self, label: str) -> None:
        self.jcc("jc", label)

    def jnc(self, label: str) -> None:
        self.jcc("jnc", label)

    def jn(self, label: str) -> None:
        self.jcc("jn", label)

    def jp(self, label: str) -> None:
        self.jcc("jp", label)

    def jv(self, label: str) -> None:
        self.jcc("jv", label)

    def jnv(self, label: str) -> None:
        self.jcc("jnv", label)

    def call(self, label: str) -> None:
        self._with_addr(_enc(_OPCODE_NAMES["call"]), label)

    def halt(self) -> None:
        self.dw(_enc(_OPCODE_NAMES["halt"]))

    def assemble(self) -> List[int]:
        """Resolve all label fixups and return a copy of the memory image.

        Raises:
            ValueError: on an undefined label, or a label address that does
                not fit in a 16-bit address word.
        """
        for offset, label in self.fixups:
            if label not in self.labels:
                raise ValueError(f"undefined label: {label}")
            target = self.labels[label]
            if not 0 <= target <= 0xFFFF:
                raise ValueError(f"label {label} address {target:#x} does not fit in 16 bits")
            self.mem[offset] = (target >> 8) & 0xFF
            self.mem[offset + 1] = target & 0xFF
        return list(self.mem)


# ----------------------------------------------------------------------------
# Programs
# ----------------------------------------------------------------------------

ProgramResult = Tuple[List[int], Dict[int, int], int, str]


def fib(n: int = 11, mem_size: int = 256) -> ProgramResult:
    """Iterative Fibonacci F(N), 8-bit wrap. F(11) = 89.

    F(13) = 233 still fits in 8 bits; F(14) = 377 overflows (wraps to 121).

    Algorithm: keep (a, b) = (F(k), F(k+1)) in (R0, R1); after N steps
    a = F(N). Per iteration: temp=b, b=a+b, a=temp; n--. R2 is the
    down-counter and R3 is scratch.

    Raises:
        ValueError: if *n* does not fit in the program's 8-bit counter.
    """
    if not 0 <= n <= 0xFF:
        raise ValueError(f"n must be in 0..255 (8-bit counter), got {n}")

    a = Asm(mem_size)
    a.org(0)
    a.load(2, "n_addr")       # R2 = n
    a.load(0, "zero_addr")    # R0 = 0 = F(0)
    a.load(1, "one_addr")     # R1 = 1 = F(1)
    a.label("loop")
    a.load(3, "zero_addr")    # R3 = 0
    a.cmp(2, 3)               # n == 0?
    a.jz("done")
    a.load(3, "zero_addr")    # R3 = 0
    a.add(3, 1)               # R3 = b (saved old b)
    a.add(1, 0)               # R1 = a + b (new b)
    a.load(0, "zero_addr")    # R0 = 0
    a.add(0, 3)               # R0 = old b (new a)
    a.load(3, "one_addr")     # R3 = 1
    a.sub(2, 3)               # n--
    a.jmp("loop")
    a.label("done")
    a.store(0, "out")         # OUT = a
    a.halt()

    a.org(0x80)
    a.label("zero_addr"); a.db(0)
    a.label("one_addr"); a.db(1)
    a.label("n_addr"); a.db(n)
    a.label("out"); a.db(0)
    mem = a.assemble()

    fa, fb = 0, 1
    for _ in range(n):
        fa, fb = fb, (fa + fb) & 0xFF
    # 11 instructions per iteration + 8 fixed; 16 * (n + 2) leaves headroom.
    return mem, {a.labels["out"]: fa}, 16 * (n + 2), f"Fibonacci F({n}) = {fa}"


def sum_n(n: int = 10, mem_size: int = 256) -> ProgramResult:
    """Compute 1 + 2 + ... + N using the Z flag from SUB to terminate.

    No explicit zero register is needed for the loop test; SUB sets Z when its
    result is zero. R0 = accumulator, R1 = counter (n down to 1), R2 = 1.

    Raises:
        ValueError: if *n* is not in 1..255. With n = 0 the counter would
            wrap to 255 and the loop would run 256 times.
    """
    if not 1 <= n <= 0xFF:
        raise ValueError(f"n must be in 1..255 (8-bit down-counter), got {n}")

    a = Asm(mem_size)
    a.org(0)
    a.load(0, "zero_addr")    # acc = 0
    a.load(1, "n_addr")       # i = n
    a.load(2, "one_addr")     # 1
    a.label("loop")
    a.add(0, 1)               # acc += i
    a.sub(1, 2)               # i-- (Z set when i becomes 0)
    a.jnz("loop")
    a.store(0, "out")
    a.halt()

    a.org(0x80)
    a.label("zero_addr"); a.db(0)
    a.label("one_addr"); a.db(1)
    a.label("n_addr"); a.db(n)
    a.label("out"); a.db(0)
    mem = a.assemble()

    expected = (n * (n + 1) // 2) & 0xFF
    # 3 instructions per iteration + 5 fixed = 3n + 5 <= 4n + 4 for n >= 1.
    return mem, {a.labels["out"]: expected}, 4 + 4 * n, f"sum 1..{n} = {expected}"


def self_mod_jmp(mem_size: int = 256) -> ProgramResult:
    """Self-modifying code: rewrite a JMP target's low byte at runtime.

    The JMP initially targets path_a (which would write 0xAA to OUT). Before
    reaching it, the code overwrites the LSB of the JMP's address word so it
    points to path_b (which writes 0xBB). Successful execution therefore ends
    with OUT = 0xBB.
    """
    a = Asm(mem_size)
    a.org(0)
    a.label("start")
    a.load(0, "new_lsb")            # R0 = LSB of path_b's address
    a.store(0, "jmp_target_lsb")    # patch the JMP's address LSB (label bound below)
    a.label("the_jmp")
    a.jmp("path_a")                 # initially -> path_a; after patch -> path_b
    # The JMP is an opcode word followed by a big-endian address word, so its
    # address LSB is at the_jmp + 3. Fixups resolve in assemble(), so the
    # label may be bound after the STORE that references it.
    a.equ("jmp_target_lsb", a.labels["the_jmp"] + 3)

    a.label("path_a")
    a.load(1, "val_a"); a.store(1, "out"); a.halt()
    a.label("path_b")
    a.load(1, "val_b"); a.store(1, "out"); a.halt()

    # Patching only the LSB redirects the JMP only if both targets share a high byte.
    if a.labels["path_a"] >> 8 != a.labels["path_b"] >> 8:
        raise ValueError("path_a and path_b must share an address high byte")

    a.org(0x80)
    a.label("val_a"); a.db(0xAA)
    a.label("val_b"); a.db(0xBB)
    # new_lsb_word holds path_b's full 16-bit address; new_lsb names its LSB byte.
    a.label("new_lsb_word"); a.daddr("path_b")
    a.equ("new_lsb", a.labels["new_lsb_word"] + 1)
    a.label("out"); a.db(0)
    mem = a.assemble()
    return mem, {a.labels["out"]: 0xBB}, 30, "self-modifying JMP target"


def all_branches(mem_size: int = 256) -> ProgramResult:
    """Drive all eight conditional jumps; each taken branch writes a unique marker.

    Test plan (each step sets flags, then a Jcc; when the branch is taken as
    expected, the corresponding marker is written):

      JZ:  CMP 5, 5 (equal)               -> Z=1 -> M[OUT0] = 0xA0
      JNZ: CMP 5, 3 (unequal)             -> Z=0 -> M[OUT1] = 0xA1
      JC:  ADD 255 + 1 (unsigned carry)   -> C=1 -> M[OUT2] = 0xA2
      JNC: ADD 1 + 1 (no carry)           -> C=0 -> M[OUT3] = 0xA3
      JN:  SUB 0 - 1 = 0xFF (MSB set)     -> N=1 -> M[OUT4] = 0xA4
      JP:  ADD 0 + 1 = 1 (MSB clear)      -> N=0 -> M[OUT5] = 0xA5
      JV:  ADD 127 + 1 = 128 (signed ovf) -> V=1 -> M[OUT6] = 0xA6
      JNV: ADD 1 + 1 = 2 (no signed ovf)  -> V=0 -> M[OUT7] = 0xA7

    A failing branch jumps to ``fail``, which writes 0xEE to FAIL_ADDR and
    halts, so that marker and all later markers are left at zero.
    """
    a = Asm(mem_size)
    # (condition, lhs data label, rhs data label, flag-setting op).
    # Case i writes marker 0xA0 + i to out<i> when its branch is taken.
    cases = (
        ("jz", "v5", "v5", a.cmp),
        ("jnz", "v5", "v3", a.cmp),
        ("jc", "v255", "v1", a.add),
        ("jnc", "v1", "v1", a.add),
        ("jn", "v0", "v1", a.sub),
        ("jp", "v0", "v1", a.add),
        ("jv", "v127", "v1", a.add),
        ("jnv", "v1", "v1", a.add),
    )

    a.org(0)
    for i, (cond, lhs, rhs, set_flags) in enumerate(cases):
        ok = f"ok_{cond}"
        a.load(0, lhs)
        a.load(1, rhs)
        set_flags(0, 1)
        a.jcc(cond, ok)
        a.jmp("fail")
        a.label(ok)
        a.load(2, f"m_a{i}")
        a.store(2, f"out{i}")
    a.jmp("end")

    a.label("fail")
    a.load(2, "v_fail"); a.store(2, "fail_addr"); a.halt()
    a.label("end")
    a.halt()

    # Code fills exactly 0x00-0xDF (8 cases x 26 bytes + 16-byte epilogue);
    # Asm raises if the code ever grows into the data placed here.
    a.org(0xE0)
    for value in (0, 1, 3, 5, 127, 255):
        a.label(f"v{value}"); a.db(value)
    for i in range(len(cases)):
        a.label(f"m_a{i}"); a.db(0xA0 + i)
    a.label("v_fail"); a.db(0xEE)
    for i in range(len(cases)):
        a.label(f"out{i}"); a.db(0)
    a.label("fail_addr"); a.db(0)
    mem = a.assemble()

    expected = {a.labels[f"out{i}"]: 0xA0 + i for i in range(len(cases))}
    expected[a.labels["fail_addr"]] = 0  # fail path must not run
    return mem, expected, 200, "all 8 conditional jumps (JZ/JNZ/JC/JNC/JN/JP/JV/JNV)"


def call_pushes_pc(mem_size: int = 256) -> ProgramResult:
    """Verify CALL pushes the return address (next-instruction PC) onto the stack.

    SP starts at addr_mask, taken here as mem_size - 1 (mem_size is assumed
    to be the CPU's full power-of-two address space). CALL decrements SP and
    writes the return-address high byte, then decrements again and writes the
    low byte. After HALT we expect:

      - mem[addr_mask - 2] = low byte of the return address
      - mem[addr_mask - 1] = high byte of the return address
      - the fall-through (no-transfer) path was NOT taken
      - the callee was reached
    """
    a = Asm(mem_size)
    a.org(0)
    a.label("caller")
    a.load(0, "marker_val")
    a.store(0, "marker_addr")       # written before CALL
    a.label("call_site")
    a.call("callee")
    # Runs only if CALL failed to transfer control; it would write 0xDD.
    a.load(0, "noret_val")
    a.store(0, "noret_addr")
    a.halt()

    a.label("callee")
    a.load(0, "callee_val")
    a.store(0, "callee_addr")
    a.halt()

    a.org(0x40)
    a.label("marker_val"); a.db(0x11)
    a.label("marker_addr"); a.db(0)
    a.label("noret_val"); a.db(0xDD)
    a.label("noret_addr"); a.db(0)
    a.label("callee_val"); a.db(0x22)
    a.label("callee_addr"); a.db(0)
    mem = a.assemble()

    addr_mask = mem_size - 1
    return_addr = a.labels["call_site"] + 4  # CALL is a 4-byte instruction
    expected = {
        a.labels["marker_addr"]: 0x11,                                  # pre-CALL store ran
        a.labels["callee_addr"]: 0x22,                                  # callee reached
        a.labels["noret_addr"]: 0,                                      # fall-through did not run
        (addr_mask - 2) & addr_mask: return_addr & 0xFF,                # return LSB on stack
        (addr_mask - 1) & addr_mask: (return_addr >> 8) & 0xFF,         # return MSB on stack
    }
    return mem, expected, 30, "CALL pushes return PC onto stack"


def bubble_sort_4(mem_size: int = 256) -> ProgramResult:
    """Sort a 4-byte array (unsigned, ascending) with unrolled compare-swaps.

    Algorithm: bubble sort, fully unrolled (no loops). Each compare-swap is
    7 instructions / 26 bytes; 3 passes x 3 positions = 9 compare-swaps
    (63 instructions) plus HALT. For each position i in (0, 1, 2):

        if A[i] > A[i+1]: swap A[i], A[i+1]

    Three passes leave the array sorted.
    """
    a = Asm(mem_size)
    addrs = ["a0", "a1", "a2", "a3"]

    a.org(0)
    for pass_no in range(3):
        for i in range(3):
            x, y = addrs[i], addrs[i + 1]
            skip = f"skip_{pass_no}_{i}"
            a.load(0, x)        # R0 = A[i]
            a.load(1, y)        # R1 = A[i+1]
            a.cmp(0, 1)         # flags from A[i] - A[i+1]; C=1 means no borrow (A[i] >= A[i+1])
            # Swap only when A[i] > A[i+1], i.e. Z=0 and C=1.
            a.jz(skip)          # equal -> skip
            a.jnc(skip)         # A[i] < A[i+1] (borrowed) -> skip
            a.store(1, x)       # A[i]   = old A[i+1]
            a.store(0, y)       # A[i+1] = old A[i]
            a.label(skip)
    a.halt()

    # Code ends at 0xEC (9 x 26 bytes + HALT); data starts at 0xF0.
    initial = [42, 7, 200, 19]
    a.org(0xF0)
    for name, val in zip(addrs, initial):
        a.label(name); a.db(val)
    mem = a.assemble()

    sorted_vals = sorted(initial)
    expected = {a.labels[name]: v for name, v in zip(addrs, sorted_vals)}
    return mem, expected, 800, f"bubble sort {initial} -> {sorted_vals}"


def cross_check_mul(mem_size: int = 256) -> ProgramResult:
    """Cross-check the threshold MUL circuit against repeated ADD.

    Computes A * B (mod 256) two ways and stores both results:

      1. direct:       R0 = A; MUL R0, R1 (R1 = B)        -> OUT_MUL
      2. repeated add: acc = 0; repeat B times: acc += A   -> OUT_ADD

    Both locations must hold (A * B) & 0xFF. The repeated-add loop counts B
    down to zero with SUB and exits via JNZ, so B must be non-zero.
    """
    a = Asm(mem_size)
    A_VAL = 17
    B_VAL = 9
    expected_product = (A_VAL * B_VAL) & 0xFF

    a.org(0)
    # --- direct multiply ---
    a.load(0, "A")
    a.load(1, "B")
    a.mul(0, 1)
    a.store(0, "out_mul")
    # --- repeated-add multiply ---
    a.load(0, "zero")         # acc = 0
    a.load(1, "A")            # addend
    a.load(2, "B")            # counter
    a.load(3, "one")          # 1
    a.label("rep_loop")
    a.add(0, 1)               # acc += A
    a.sub(2, 3)               # B--
    a.jnz("rep_loop")
    a.store(0, "out_add")
    a.halt()

    a.org(0x80)
    a.label("A"); a.db(A_VAL)
    a.label("B"); a.db(B_VAL)
    a.label("zero"); a.db(0)
    a.label("one"); a.db(1)
    a.label("out_mul"); a.db(0)
    a.label("out_add"); a.db(0)
    mem = a.assemble()

    expected = {
        a.labels["out_mul"]: expected_product,
        a.labels["out_add"]: expected_product,
    }
    return mem, expected, 80, f"MUL vs repeated ADD: {A_VAL} * {B_VAL} = {expected_product}"


def div_via_repeated_sub(mem_size: int = 256) -> ProgramResult:
    """Compute floor(A/B) and (A mod B) by repeated subtraction.

    Loop: while A >= B { A -= B; quotient += 1 }

    Uses CMP + JNC (C=0 means the compare borrowed, i.e. A < B), SUB, ADD,
    JMP, STORE and HALT. The result is then cross-checked against the on-chip
    8-bit DIV opcode (0x8) in a second pass. Both quotients, plus the
    remainder, are written to OUT locations and verified.
    """
    A_VAL = 100
    B_VAL = 7
    expected_q = A_VAL // B_VAL   # 14
    expected_r = A_VAL % B_VAL    # 2

    a = Asm(mem_size)
    a.org(0)
    # ---- Repeated-subtraction division ----
    a.load(0, "A")            # R0 = A (becomes the remainder)
    a.load(1, "B")            # R1 = B (divisor)
    a.load(2, "ZERO")         # R2 = 0 (becomes the quotient)
    a.load(3, "ONE")          # R3 = 1 (increment)
    a.label("loop")
    a.cmp(0, 1)               # C=1 (no borrow) iff R0 >= R1
    a.jnc("done")             # R0 < R1 -> exit loop
    a.sub(0, 1)               # R0 -= B
    a.add(2, 3)               # quotient += 1
    a.jmp("loop")
    a.label("done")
    a.store(2, "OUT_Q_RPT")   # quotient via repeated sub
    a.store(0, "OUT_R_RPT")   # remainder via repeated sub
    # ---- Direct DIV opcode for cross-check ----
    a.load(0, "A")
    a.load(1, "B")
    a.div(0, 1)               # DIV R0, R1 -> R0 = R0 / R1 (8-bit DIV)
    a.store(0, "OUT_Q_DIV")
    a.halt()

    a.org(0x80)
    a.label("A"); a.db(A_VAL)
    a.label("B"); a.db(B_VAL)
    a.label("ZERO"); a.db(0)
    a.label("ONE"); a.db(1)
    a.label("OUT_Q_RPT"); a.db(0)
    a.label("OUT_R_RPT"); a.db(0)
    a.label("OUT_Q_DIV"); a.db(0)
    mem = a.assemble()

    expected = {
        a.labels["OUT_Q_RPT"]: expected_q,
        a.labels["OUT_R_RPT"]: expected_r,
        a.labels["OUT_Q_DIV"]: expected_q,
    }
    # Instruction count:
    #   4 setup LOADs
    #   + 5 per iteration (CMP, JNC, SUB, ADD, JMP)
    #   + 2 for the final CMP/JNC exit
    #   + 7 tail instructions (2 STORE, 2 LOAD, DIV, STORE, HALT)
    #   = 5q + 13.
    # The budget is twice that, for headroom.
    max_cycles = 2 * (5 * expected_q + 13)
    return mem, expected, max_cycles, (
        f"{A_VAL} / {B_VAL}: quotient {expected_q} (repeated SUB) "
        f"matches DIV opcode result; remainder {expected_r}"
    )


def bitwise_chain(mem_size: int = 256) -> ProgramResult:
    """Run a chain of bitwise ops and verify each intermediate value.

    Sequence:
        R0 = A & B   (AND)
        R0 = R0 | C  (OR)
        R0 = R0 ^ D  (XOR)
        R0 = R0 << 1 (SHL)
        R0 = R0 >> 1 (SHR)

    R0 is stored after each step, so every intermediate value is checked and
    any single-op regression is caught.
    """
    A = 0xCC  # 11001100
    B = 0xF0  # 11110000
    C = 0x0F  # 00001111
    D = 0xAA  # 10101010
    s1 = A & B                # 0xC0
    s2 = s1 | C               # 0xCF
    s3 = s2 ^ D               # 0x65
    s4 = (s3 << 1) & 0xFF     # 0xCA
    s5 = s4 >> 1              # 0x65

    a = Asm(mem_size)
    a.org(0)
    a.load(0, "A"); a.load(1, "B"); a.and_(0, 1); a.store(0, "S1")
    a.load(1, "C"); a.or_(0, 1); a.store(0, "S2")
    a.load(1, "D"); a.xor(0, 1); a.store(0, "S3")
    a.shl(0); a.store(0, "S4")
    a.shr(0); a.store(0, "S5")
    a.halt()

    a.org(0x80)
    a.label("A"); a.db(A)
    a.label("B"); a.db(B)
    a.label("C"); a.db(C)
    a.label("D"); a.db(D)
    a.label("S1"); a.db(0)
    a.label("S2"); a.db(0)
    a.label("S3"); a.db(0)
    a.label("S4"); a.db(0)
    a.label("S5"); a.db(0)
    mem = a.assemble()

    expected = {
        a.labels["S1"]: s1,
        a.labels["S2"]: s2,
        a.labels["S3"]: s3,
        a.labels["S4"]: s4,
        a.labels["S5"]: s5,
    }
    return mem, expected, 30, (
        f"bitwise chain AND/OR/XOR/SHL/SHR -> {s1:#x},{s2:#x},{s3:#x},{s4:#x},{s5:#x}"
    )


SUITE = [
    ("fib", lambda mem_size: fib(11, mem_size)),
    ("sum_n", lambda mem_size: sum_n(10, mem_size)),
    ("self_mod_jmp", lambda mem_size: self_mod_jmp(mem_size)),
    ("all_branches", lambda mem_size: all_branches(mem_size)),
    ("call_pushes_pc", lambda mem_size: call_pushes_pc(mem_size)),
    ("bubble_sort_4", lambda mem_size: bubble_sort_4(mem_size)),
    ("cross_check_mul", lambda mem_size: cross_check_mul(mem_size)),
    ("div_via_repeated_sub", lambda mem_size: div_via_repeated_sub(mem_size)),
    ("bitwise_chain", lambda mem_size: bitwise_chain(mem_size)),
]