File size: 4,346 Bytes
6e3b69a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
Run the CPU program suite against a threshold-computer variant.

Loads weights, instantiates GenericThresholdCPU, runs each program from
cpu_programs.SUITE, and verifies expected memory contents at HALT.

Usage:
    python test_cpu.py                                       # default: 1KB variant, fast
    python test_cpu.py --model neural_computer.safetensors   # 64KB canonical, slow
    python test_cpu.py --only fib,sum_n                      # subset of suite
"""

from __future__ import annotations
import argparse
import os
import sys
import time
from pathlib import Path

import torch
from safetensors import safe_open

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from eval_all import GenericThresholdCPU, get_manifest
from cpu_programs import SUITE


def load_tensors(path: Path):
    """Read every tensor stored in a safetensors file into a dict.

    Args:
        path: Location of the ``.safetensors`` file to open.

    Returns:
        A dict mapping each tensor name in the file to that tensor after
        ``.float()`` has been applied to it.
    """
    # The file is opened with the "pt" framework and closed by the context
    # manager as soon as the dict has been built.
    with safe_open(str(path), framework="pt") as handle:
        return {key: handle.get_tensor(key).float() for key in handle.keys()}


def run_program(cpu: GenericThresholdCPU, mem, max_cycles: int):
    """Execute one memory image on ``cpu`` from a fresh initial state.

    Args:
        cpu: Object exposing ``addr_bits`` and ``run(state, max_cycles=...)``.
        mem: Initial memory contents; copied into a new list before the run.
        max_cycles: Cycle limit forwarded to ``cpu.run``.

    Returns:
        A ``(final_state, cycles, elapsed_seconds)`` tuple, where the first
        two items are whatever ``cpu.run`` returned and the last is the
        wall-clock duration of that call measured with ``time.perf_counter``.
    """
    # The stack pointer starts at the highest address the CPU can express.
    highest_address = (1 << cpu.addr_bits) - 1
    initial_state = dict(
        pc=0,
        regs=[0, 0, 0, 0],
        flags=[0, 0, 0, 0],
        mem=list(mem),
        halted=False,
        sp=highest_address,
    )
    started = time.perf_counter()
    final, cycles = cpu.run(initial_state, max_cycles=max_cycles)
    elapsed = time.perf_counter() - started
    return final, cycles, elapsed


def check_expected(final, expected: dict) -> tuple[bool, list[str]]:
    """Compare memory in a final CPU state against expected values.

    Args:
        final: Final state mapping; ``final["mem"]`` is indexed by address.
        expected: Mapping of address -> value that memory should hold.

    Returns:
        ``(ok, failures)`` where ``failures`` holds one message per address
        whose value differs (in ``expected`` iteration order) and ``ok`` is
        True exactly when that list is empty.
    """
    mismatches = [
        f"M[0x{address:04X}] = {final['mem'][address]} (expected {value})"
        for address, value in expected.items()
        if final["mem"][address] != value
    ]
    return (not mismatches), mismatches


def main() -> int:
    """Run the program suite against one model variant and print a report.

    Parses ``--model`` and ``--only``, loads the tensors, then builds and
    runs every selected program from ``SUITE``. One status line is printed
    per program, followed by a PASS/FAIL/SKIP summary.

    Returns:
        0 if no program failed; 1 if at least one program failed (builder
        errors count as failures); 2 if the run could not start because
        ``--only`` names a program that is not in the suite or the variant
        has less than 256 bytes of memory.
    """
    parser = argparse.ArgumentParser(description="Run the CPU program suite")
    parser.add_argument(
        "--model", type=str,
        default=os.path.join(os.path.dirname(__file__),
                             "variants", "neural_computer8_small.safetensors"),
        help="Path to .safetensors variant",
    )
    parser.add_argument(
        "--only", type=str, default="",
        help="Comma-separated subset of program names to run",
    )
    args = parser.parse_args()

    only = {s.strip() for s in args.only.split(",") if s.strip()}
    # Reject misspelled names up front: otherwise a typo silently selects
    # nothing and the run "succeeds" with zero programs executed. Checked
    # before loading the model because loading can be slow.
    unknown = sorted(only - {name for name, _ in SUITE})
    if unknown:
        print(f"ERROR: unknown program name(s) in --only: {', '.join(unknown)}")
        return 2

    print(f"Loading {args.model}")
    tensors = load_tensors(Path(args.model))
    manifest = get_manifest(tensors)
    print(f"Manifest: data={manifest['data_bits']}-bit, addr={manifest['addr_bits']}-bit, "
          f"mem={manifest['memory_bytes']}B")

    if manifest["memory_bytes"] < 256:
        print(f"ERROR: variant has {manifest['memory_bytes']}B memory; "
              f"the suite needs at least 256B (scratchpad).")
        return 2

    cpu = GenericThresholdCPU(tensors)

    print()
    print("=" * 80)
    print(f" CPU PROGRAM SUITE  ({manifest['memory_bytes']}B mem, {manifest['data_bits']}-bit ALU)")
    print("=" * 80)

    max_shown = 6  # cap on failure lines printed per program
    pass_count = 0
    fail_count = 0
    skip_count = 0

    for name, builder in SUITE:
        if only and name not in only:
            skip_count += 1
            continue
        try:
            mem, expected, max_cycles, desc = builder(manifest["memory_bytes"])
        except Exception as e:
            # A builder that cannot produce a program counts as a failure
            # rather than aborting the rest of the suite.
            print(f"  {name:18}  BUILD ERROR: {e}")
            fail_count += 1
            continue

        if len(mem) != manifest["memory_bytes"]:
            print(f"  {name:18}  SKIP (program built {len(mem)}B, "
                  f"variant has {manifest['memory_bytes']}B)")
            skip_count += 1
            continue

        final, cycles, elapsed = run_program(cpu, mem, max_cycles)
        ok, failures = check_expected(final, expected)
        if not final["halted"]:
            # Always report a missing HALT, and report it first: it is the
            # likely cause of any memory mismatches listed after it.
            ok = False
            failures.insert(0, f"did not HALT within {max_cycles} cycles")

        status = "PASS" if ok else "FAIL"
        if ok:
            pass_count += 1
        else:
            fail_count += 1
        print(f"  {name:18}  {status}  ({cycles:>3} cyc, {elapsed:>5.2f}s)  {desc}")
        if not ok:
            for message in failures[:max_shown]:
                print(f"    - {message}")
            hidden = len(failures) - max_shown
            if hidden > 0:
                print(f"    - ... and {hidden} more")

    print()
    print("=" * 80)
    total = pass_count + fail_count
    print(f"  PASS: {pass_count}/{total}   FAIL: {fail_count}/{total}   SKIP: {skip_count}")
    return 0 if fail_count == 0 else 1


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())