| """ |
| Run the CPU program suite against a threshold-computer variant. |
| |
| Loads weights, instantiates GenericThresholdCPU, runs each program from |
| cpu_programs.SUITE, and verifies expected memory contents at HALT. |
| |
| Usage: |
| python test_cpu.py # default: 1KB variant, fast |
| python test_cpu.py --model neural_computer.safetensors # 64KB canonical, slow |
| python test_cpu.py --only fib,sum_n # subset of suite |
| """ |
|
|
| from __future__ import annotations |
| import argparse |
| import os |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import torch |
| from safetensors import safe_open |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from eval_all import GenericThresholdCPU, get_manifest |
| from cpu_programs import SUITE |
|
|
|
|
def load_tensors(path: Path):
    """Load every tensor stored in a safetensors file.

    Args:
        path: Location of the ``.safetensors`` file to read.

    Returns:
        A dict mapping each tensor name in the file to that tensor,
        converted with ``.float()``.
    """
    # The file is opened with the PyTorch framework backend and closed again
    # as soon as the dict has been built.
    with safe_open(str(path), framework="pt") as handle:
        return {key: handle.get_tensor(key).float() for key in handle.keys()}
|
|
|
|
def run_program(cpu: GenericThresholdCPU, mem, max_cycles: int):
    """Execute one program image on ``cpu`` from a freshly reset state.

    Args:
        cpu: CPU object exposing ``addr_bits`` and ``run(state, max_cycles=...)``.
        mem: Initial memory image; it is copied, so the caller's sequence is
            not the object handed to the CPU.
        max_cycles: Cycle budget forwarded to ``cpu.run``.

    Returns:
        A ``(final_state, cycles, elapsed_seconds)`` tuple, where the first
        two items are whatever ``cpu.run`` returned and the last is the
        wall-clock duration of that call.
    """
    # The stack pointer starts at the highest address the CPU can express.
    top_of_memory = (1 << cpu.addr_bits) - 1
    initial_state = dict(
        pc=0,
        regs=[0 for _ in range(4)],
        flags=[0 for _ in range(4)],
        mem=list(mem),
        halted=False,
        sp=top_of_memory,
    )

    started = time.perf_counter()
    final_state, cycle_count = cpu.run(initial_state, max_cycles=max_cycles)
    elapsed = time.perf_counter() - started
    return final_state, cycle_count, elapsed
|
|
|
|
def check_expected(final, expected: dict) -> tuple[bool, list[str]]:
    """Compare final memory contents against the expected values.

    Args:
        final: Final CPU state; only ``final["mem"]`` is read.
        expected: Mapping of memory address to the value expected there.

    Returns:
        ``(ok, failures)`` where ``ok`` is True when every address matched and
        ``failures`` holds one human-readable line per mismatch.
    """
    memory = final["mem"]
    size = len(memory)
    failures = []
    for addr, want in expected.items():
        # An address outside the memory image is reported as a failure rather
        # than raising IndexError (which would abort the whole suite) or, for
        # a negative address, silently reading from the end of the list.
        if not 0 <= addr < size:
            failures.append(
                f"M[0x{addr:04X}] is out of range (memory size {size})"
            )
            continue
        got = memory[addr]
        if got != want:
            failures.append(f"M[0x{addr:04X}] = {got} (expected {want})")
    return (not failures), failures
|
|
|
|
def main() -> int:
    """Run the program suite against one model variant and print a report.

    Parses ``--model`` and ``--only``, loads the tensors, builds the CPU and
    runs every selected program from ``SUITE``, checking both the expected
    memory contents and that the program halted.

    Returns:
        Process exit code: 0 when no program failed, 1 when at least one
        failed, 2 when a precondition is not met (unknown ``--only`` name or
        a variant with less than 256 bytes of memory).
    """
    parser = argparse.ArgumentParser(description="Run the CPU program suite")
    parser.add_argument(
        "--model", type=str,
        default=os.path.join(os.path.dirname(__file__),
                             "variants", "neural_computer8_small.safetensors"),
        help="Path to .safetensors variant",
    )
    parser.add_argument(
        "--only", type=str, default="",
        help="Comma-separated subset of program names to run",
    )
    args = parser.parse_args()

    # Materialize once so the suite can be scanned for names and then run.
    suite = list(SUITE)
    only = {s.strip() for s in args.only.split(",") if s.strip()}

    # Reject misspelled names up front (before the potentially slow model
    # load); otherwise every program is skipped and the run "passes" 0/0.
    unknown = sorted(only - {name for name, _ in suite})
    if unknown:
        print(f"ERROR: unknown program name(s) in --only: {', '.join(unknown)}")
        return 2

    print(f"Loading {args.model}")
    tensors = load_tensors(Path(args.model))
    manifest = get_manifest(tensors)
    print(f"Manifest: data={manifest['data_bits']}-bit, addr={manifest['addr_bits']}-bit, "
          f"mem={manifest['memory_bytes']}B")

    if manifest["memory_bytes"] < 256:
        print(f"ERROR: variant has {manifest['memory_bytes']}B memory; "
              f"the suite needs at least 256B (scratchpad).")
        return 2

    cpu = GenericThresholdCPU(tensors)

    print()
    print("=" * 80)
    print(f" CPU PROGRAM SUITE ({manifest['memory_bytes']}B mem, {manifest['data_bits']}-bit ALU)")
    print("=" * 80)

    pass_count = 0
    fail_count = 0
    skip_count = 0

    for name, builder in suite:
        if only and name not in only:
            skip_count += 1
            continue
        try:
            mem, expected, max_cycles, desc = builder(manifest["memory_bytes"])
        except Exception as e:
            # A broken builder counts as a failure but must not stop the
            # remaining programs from running.
            print(f" {name:18} BUILD ERROR: {e}")
            fail_count += 1
            continue

        if len(mem) != manifest["memory_bytes"]:
            print(f" {name:18} SKIP (program built {len(mem)}B, "
                  f"variant has {manifest['memory_bytes']}B)")
            skip_count += 1
            continue

        final, cycles, elapsed = run_program(cpu, mem, max_cycles)
        ok, failures = check_expected(final, expected)
        if not final["halted"]:
            # Always report a missing HALT, and report it first: it is usually
            # the root cause of any memory mismatches and must survive the
            # truncation of the failure list below.
            ok = False
            failures.insert(0, f"did not HALT within {max_cycles} cycles")

        status = "PASS" if ok else "FAIL"
        if ok:
            pass_count += 1
        else:
            fail_count += 1
        print(f" {name:18} {status} ({cycles:>3} cyc, {elapsed:>5.2f}s) {desc}")
        if not ok:
            # Cap the detail output so one badly broken program does not
            # flood the report.
            for line in failures[:6]:
                print(f" - {line}")

    print()
    print("=" * 80)
    total = pass_count + fail_count
    print(f" PASS: {pass_count}/{total} FAIL: {fail_count}/{total} SKIP: {skip_count}")
    return 0 if fail_count == 0 else 1
|
|
|
|
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())
|
|