# 8bit-threshold-computer/eval/cpu_cycle_test.py
# PortfolioAI -- commit ea46629: "Add packed memory routing and 16-bit addressing"
# (2.89 kB; the web file viewer's "raw / history / blame" links are omitted)
"""
Basic CPU cycle smoke test.
"""
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
import torch
from cpu.cycle import run_until_halt
from cpu.state import CPUState, pack_state, unpack_state
from cpu.threshold_cpu import ThresholdCPU
def encode(opcode: int, rd: int, rs: int, imm8: int) -> int:
    """Pack the four instruction fields into one 16-bit instruction word.

    Layout, most significant bit first:

        bits 15-12  opcode  (4 bits)
        bits 11-10  rd      (2 bits)
        bits  9-8   rs      (2 bits)
        bits  7-0   imm8    (8 bits)

    Every field is masked to its width, so out-of-range values are
    truncated silently rather than rejected.
    """
    # Build the word left to right: shift what we have so far to make room
    # for the next (narrower-masked) field, then OR that field in.
    word = opcode & 0xF
    word = (word << 2) | (rd & 0x3)
    word = (word << 2) | (rs & 0x3)
    word = (word << 8) | (imm8 & 0xFF)
    return word
def _write_word_be(mem, addr, word):
    """Store the low 16 bits of *word* into *mem* at *addr*, high byte first.

    Both byte addresses are reduced modulo 0x10000, so a write at 0xFFFF
    puts the high byte at 0xFFFF and the low byte at 0x0000.
    """
    mem[addr & 0xFFFF] = (word >> 8) & 0xFF
    mem[(addr + 1) & 0xFFFF] = word & 0xFF


def write_instr(mem, addr, instr):
    """Write a 16-bit instruction word into *mem* at *addr* (big-endian).

    Args:
        mem: Mutable, index-assignable byte memory; modified in place.
        addr: Address of the high byte; wrapped to 16 bits.
        instr: Instruction word; only its low 16 bits are stored.
    """
    _write_word_be(mem, addr, instr)


def write_addr(mem, addr, value):
    """Write a 16-bit address operand into *mem* at *addr* (big-endian).

    Uses the same byte layout as ``write_instr``; it is kept as a separate
    name so that call sites say whether they are emitting an instruction or
    an operand.

    Args:
        mem: Mutable, index-assignable byte memory; modified in place.
        addr: Address of the high byte; wrapped to 16 bits.
        value: Operand value; only its low 16 bits are stored.
    """
    _write_word_be(mem, addr, value)
def _build_program_memory() -> list:
    """Return a new 65536-entry memory image holding the smoke-test program.

    Program layout (instruction words are two bytes; LOAD and STORE are
    each followed by a two-byte address operand):

        0x0000  LOAD  R0, [0x0100]
        0x0004  LOAD  R1, [0x0101]
        0x0008  ADD   R0, R1
        0x000A  STORE R0 -> [0x0102]
        0x000E  HALT

    Data: MEM[0x0100] = 5 and MEM[0x0101] = 7, so the expected sum is 12.
    """
    mem = [0] * 65536

    write_instr(mem, 0x0000, encode(0xA, 0, 0, 0x00))  # LOAD R0, [addr]
    write_addr(mem, 0x0002, 0x0100)
    write_instr(mem, 0x0004, encode(0xA, 1, 0, 0x00))  # LOAD R1, [addr]
    write_addr(mem, 0x0006, 0x0101)
    write_instr(mem, 0x0008, encode(0x0, 0, 1, 0x00))  # ADD R0, R1
    write_instr(mem, 0x000A, encode(0xB, 0, 0, 0x00))  # STORE R0 -> [addr]
    write_addr(mem, 0x000C, 0x0102)
    write_instr(mem, 0x000E, encode(0xF, 0, 0, 0x00))  # HALT

    mem[0x0100] = 5
    mem[0x0101] = 7
    return mem


def _initial_state() -> CPUState:
    """Build a brand-new initial CPU state with the test program loaded.

    Every call returns an object that shares no list with any earlier one,
    so each runtime under test starts from pristine registers and memory.
    """
    return CPUState(
        pc=0,
        ir=0,
        regs=[0, 0, 0, 0],
        flags=[0, 0, 0, 0],
        sp=0xFFFE,
        ctrl=[0, 0, 0, 0],
        mem=_build_program_memory(),
    )


def main() -> None:
    """Run the 5 + 7 program on each execution path and cross-check results.

    Three paths are exercised: the reference ``run_until_halt``, the
    ``ThresholdCPU.run_until_halt`` runtime, and ``ThresholdCPU.forward`` on
    the packed bit representation. The reference result is checked against
    hard-coded expectations; the other two are checked against the reference.

    Each path is handed its own freshly built initial state. Passing one
    shared state object to all three would let an in-place write by an
    earlier run (for example the STORE to MEM[0x0102]) leak into the later
    runs and make their comparisons pass vacuously.

    Raises:
        AssertionError: If any check fails.
    """
    # Reference implementation against absolute expectations.
    final, cycles = run_until_halt(_initial_state(), max_cycles=20)

    # ctrl[0] is treated as the HALT flag, per the assertion message below.
    assert final.ctrl[0] == 1, "HALT flag not set"
    assert final.regs[0] == 12, f"R0 expected 12, got {final.regs[0]}"
    assert final.mem[0x0102] == 12, f"MEM[0x0102] expected 12, got {final.mem[0x0102]}"
    assert cycles <= 10, f"Unexpected cycle count: {cycles}"

    # Threshold-weight runtime should match reference behavior.
    threshold_cpu = ThresholdCPU()
    t_final, t_cycles = threshold_cpu.run_until_halt(_initial_state(), max_cycles=20)

    assert t_final.ctrl[0] == 1, "Threshold HALT flag not set"
    assert t_final.regs[0] == final.regs[0], f"Threshold R0 mismatch: {t_final.regs[0]} != {final.regs[0]}"
    assert t_final.mem[0x0102] == final.mem[0x0102], (
        f"Threshold MEM[0x0102] mismatch: {t_final.mem[0x0102]} != {final.mem[0x0102]}"
    )
    assert t_cycles == cycles, f"Threshold cycle count mismatch: {t_cycles} != {cycles}"

    # Validate forward() state I/O: pack a pristine state to bits, run it,
    # and unpack the resulting bits back into a state for comparison.
    bits = torch.tensor(pack_state(_initial_state()), dtype=torch.float32)
    out_bits = threshold_cpu.forward(bits, max_cycles=20)
    out_state = unpack_state([int(b) for b in out_bits.tolist()])

    assert out_state.regs[0] == final.regs[0], f"Forward R0 mismatch: {out_state.regs[0]} != {final.regs[0]}"
    assert out_state.mem[0x0102] == final.mem[0x0102], (
        f"Forward MEM[0x0102] mismatch: {out_state.mem[0x0102]} != {final.mem[0x0102]}"
    )

    print("cpu_cycle_test: ok")


# Run the smoke test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()