| """ |
| Basic CPU cycle smoke test. |
| """ |
|
|
| import sys |
| from pathlib import Path |
|
|
| sys.path.append(str(Path(__file__).resolve().parent.parent)) |
|
|
| import torch |
|
|
| from cpu.cycle import run_until_halt |
| from cpu.state import CPUState, pack_state, unpack_state |
| from cpu.threshold_cpu import ThresholdCPU |
|
|
|
|
def encode(opcode: int, rd: int, rs: int, imm8: int) -> int:
    """Pack four instruction fields into a single 16-bit instruction word.

    Bit layout, most to least significant: opcode[15:12], rd[11:10],
    rs[9:8], imm8[7:0]. Each field is masked to its width before shifting,
    so out-of-range values are truncated instead of spilling into a
    neighbouring field.
    """
    layout = (
        (opcode, 0xF, 12),
        (rd, 0x3, 10),
        (rs, 0x3, 8),
        (imm8, 0xFF, 0),
    )
    word = 0
    for field, width_mask, shift in layout:
        word |= (field & width_mask) << shift
    return word
|
|
|
|
def _write_be16(mem, addr, value):
    """Store the low 16 bits of *value* big-endian at ``mem[addr]``, ``mem[addr + 1]``.

    The high byte goes to the lower address. Both addresses wrap modulo
    0x10000, so a word written at 0xFFFF places its low byte at 0x0000.
    """
    mem[addr & 0xFFFF] = (value >> 8) & 0xFF
    mem[(addr + 1) & 0xFFFF] = value & 0xFF


def write_instr(mem, addr, instr):
    """Write a 16-bit instruction word (high byte first) into *mem* at *addr*."""
    _write_be16(mem, addr, instr)


def write_addr(mem, addr, value):
    """Write a 16-bit address operand (high byte first) into *mem* at *addr*."""
    _write_be16(mem, addr, value)
|
|
|
|
def _build_memory():
    """Return a fresh 64 KiB memory image holding the test program and its data.

    The checks in ``main`` expect the program to leave 5 + 7 = 12 in both R0
    and MEM[0x0102] and then set the HALT flag. The per-instruction opcode
    names below are inferred from those expectations.
    NOTE(review): confirm them against the ISA definition in ``cpu/``.
    """
    mem = [0] * 65536

    write_instr(mem, 0x0000, encode(0xA, 0, 0, 0x00))  # presumably LOAD R0, [next word]
    write_addr(mem, 0x0002, 0x0100)
    write_instr(mem, 0x0004, encode(0xA, 1, 0, 0x00))  # presumably LOAD R1, [next word]
    write_addr(mem, 0x0006, 0x0101)
    write_instr(mem, 0x0008, encode(0x0, 0, 1, 0x00))  # presumably ADD R0, R1
    write_instr(mem, 0x000A, encode(0xB, 0, 0, 0x00))  # presumably STORE R0, [next word]
    write_addr(mem, 0x000C, 0x0102)
    write_instr(mem, 0x000E, encode(0xF, 0, 0, 0x00))  # presumably HALT

    # Data operands read by the program.
    mem[0x0100] = 5
    mem[0x0101] = 7
    return mem


def _initial_state():
    """Return a new CPUState with the test program loaded.

    The state has PC=0, IR=0, zeroed registers/flags/control bits and
    SP=0xFFFE. Every call builds a brand-new object (including its memory
    list), so one backend's run cannot leak state into the next one's input,
    whether or not the runners mutate their argument.
    """
    return CPUState(
        pc=0,
        ir=0,
        regs=[0, 0, 0, 0],
        flags=[0, 0, 0, 0],
        sp=0xFFFE,
        ctrl=[0, 0, 0, 0],
        mem=_build_memory(),
    )


def main() -> None:
    """Run the test program on the reference and threshold CPUs and cross-check them.

    Raises AssertionError on the first failed check (note: ``python -O``
    strips asserts). Prints ``cpu_cycle_test: ok`` when everything passes.
    """
    max_cycles = 20

    # Reference implementation: checked against hard-coded expectations.
    final, cycles = run_until_halt(_initial_state(), max_cycles=max_cycles)

    assert final.ctrl[0] == 1, "HALT flag not set"
    assert final.regs[0] == 12, f"R0 expected 12, got {final.regs[0]}"
    assert final.mem[0x0102] == 12, f"MEM[0x0102] expected 12, got {final.mem[0x0102]}"
    assert cycles <= 10, f"Unexpected cycle count: {cycles}"

    # Threshold CPU, stepped via its own run_until_halt: must agree with the
    # reference on halt flag, R0, the stored result and the cycle count.
    # A fresh state is used so the comparison is not vacuous if the reference
    # run wrote into the memory list it was given.
    threshold_cpu = ThresholdCPU()
    t_final, t_cycles = threshold_cpu.run_until_halt(_initial_state(), max_cycles=max_cycles)

    assert t_final.ctrl[0] == 1, "Threshold HALT flag not set"
    assert t_final.regs[0] == final.regs[0], f"Threshold R0 mismatch: {t_final.regs[0]} != {final.regs[0]}"
    assert t_final.mem[0x0102] == final.mem[0x0102], (
        f"Threshold MEM[0x0102] mismatch: {t_final.mem[0x0102]} != {final.mem[0x0102]}"
    )
    assert t_cycles == cycles, f"Threshold cycle count mismatch: {t_cycles} != {cycles}"

    # Same threshold CPU driven through forward() on the packed bit vector of
    # a pristine initial state; the output bits are unpacked back into a state.
    bits = torch.tensor(pack_state(_initial_state()), dtype=torch.float32)
    out_bits = threshold_cpu.forward(bits, max_cycles=max_cycles)
    # NOTE(review): int() truncates toward zero; this assumes forward() emits
    # exact 0.0 / 1.0 values.
    out_state = unpack_state([int(b) for b in out_bits.tolist()])
    assert out_state.regs[0] == final.regs[0], f"Forward R0 mismatch: {out_state.regs[0]} != {final.regs[0]}"
    assert out_state.mem[0x0102] == final.mem[0x0102], (
        f"Forward MEM[0x0102] mismatch: {out_state.mem[0x0102]} != {final.mem[0x0102]}"
    )

    print("cpu_cycle_test: ok")


if __name__ == "__main__":
    main()
|
|