CharlesCNorton commited on
Commit
52beb94
·
1 Parent(s): a854d3d

Consolidate CPU into threshold_cpu.py at root

Browse files

- Merge cpu/core.py and eval/cpu_cycle_test.py
- Test now runs as __main__ block
- Remove cpu/ folder

eval/cpu_cycle_test.py DELETED
@@ -1,87 +0,0 @@
1
- """
2
- Basic CPU cycle smoke test.
3
- """
4
-
5
- import sys
6
- from pathlib import Path
7
-
8
- sys.path.append(str(Path(__file__).resolve().parent.parent))
9
-
10
- import torch
11
-
12
- from cpu.cycle import run_until_halt
13
- from cpu.state import CPUState, pack_state, unpack_state
14
- from cpu.threshold_cpu import ThresholdCPU
15
-
16
-
17
- def encode(opcode: int, rd: int, rs: int, imm8: int) -> int:
18
- return ((opcode & 0xF) << 12) | ((rd & 0x3) << 10) | ((rs & 0x3) << 8) | (imm8 & 0xFF)
19
-
20
-
21
- def write_instr(mem, addr, instr):
22
- mem[addr & 0xFFFF] = (instr >> 8) & 0xFF
23
- mem[(addr + 1) & 0xFFFF] = instr & 0xFF
24
-
25
-
26
- def write_addr(mem, addr, value):
27
- mem[addr & 0xFFFF] = (value >> 8) & 0xFF
28
- mem[(addr + 1) & 0xFFFF] = value & 0xFF
29
-
30
-
31
- def main() -> None:
32
- mem = [0] * 65536
33
-
34
- write_instr(mem, 0x0000, encode(0xA, 0, 0, 0x00)) # LOAD R0, [addr]
35
- write_addr(mem, 0x0002, 0x0100)
36
- write_instr(mem, 0x0004, encode(0xA, 1, 0, 0x00)) # LOAD R1, [addr]
37
- write_addr(mem, 0x0006, 0x0101)
38
- write_instr(mem, 0x0008, encode(0x0, 0, 1, 0x00)) # ADD R0, R1
39
- write_instr(mem, 0x000A, encode(0xB, 0, 0, 0x00)) # STORE R0 -> [addr]
40
- write_addr(mem, 0x000C, 0x0102)
41
- write_instr(mem, 0x000E, encode(0xF, 0, 0, 0x00)) # HALT
42
-
43
- mem[0x0100] = 5
44
- mem[0x0101] = 7
45
-
46
- state = CPUState(
47
- pc=0,
48
- ir=0,
49
- regs=[0, 0, 0, 0],
50
- flags=[0, 0, 0, 0],
51
- sp=0xFFFE,
52
- ctrl=[0, 0, 0, 0],
53
- mem=mem,
54
- )
55
-
56
- final, cycles = run_until_halt(state, max_cycles=20)
57
-
58
- assert final.ctrl[0] == 1, "HALT flag not set"
59
- assert final.regs[0] == 12, f"R0 expected 12, got {final.regs[0]}"
60
- assert final.mem[0x0102] == 12, f"MEM[0x0102] expected 12, got {final.mem[0x0102]}"
61
- assert cycles <= 10, f"Unexpected cycle count: {cycles}"
62
-
63
- # Threshold-weight runtime should match reference behavior.
64
- threshold_cpu = ThresholdCPU()
65
- t_final, t_cycles = threshold_cpu.run_until_halt(state, max_cycles=20)
66
-
67
- assert t_final.ctrl[0] == 1, "Threshold HALT flag not set"
68
- assert t_final.regs[0] == final.regs[0], f"Threshold R0 mismatch: {t_final.regs[0]} != {final.regs[0]}"
69
- assert t_final.mem[0x0102] == final.mem[0x0102], (
70
- f"Threshold MEM[0x0102] mismatch: {t_final.mem[0x0102]} != {final.mem[0x0102]}"
71
- )
72
- assert t_cycles == cycles, f"Threshold cycle count mismatch: {t_cycles} != {cycles}"
73
-
74
- # Validate forward() state I/O.
75
- bits = torch.tensor(pack_state(state), dtype=torch.float32)
76
- out_bits = threshold_cpu.forward(bits, max_cycles=20)
77
- out_state = unpack_state([int(b) for b in out_bits.tolist()])
78
- assert out_state.regs[0] == final.regs[0], f"Forward R0 mismatch: {out_state.regs[0]} != {final.regs[0]}"
79
- assert out_state.mem[0x0102] == final.mem[0x0102], (
80
- f"Forward MEM[0x0102] mismatch: {out_state.mem[0x0102]} != {final.mem[0x0102]}"
81
- )
82
-
83
- print("cpu_cycle_test: ok")
84
-
85
-
86
- if __name__ == "__main__":
87
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cpu/core.py → threshold_cpu.py RENAMED
@@ -1,12 +1,17 @@
1
  """
2
- 8-bit Threshold Computer - Combined CPU Module
3
 
4
- State layout, reference cycle, and threshold-weight runtime in one file.
5
  All multi-bit fields are MSB-first.
 
 
 
 
6
  """
7
 
8
  from __future__ import annotations
9
 
 
10
  from dataclasses import dataclass
11
  from pathlib import Path
12
  from typing import List, Tuple
@@ -30,6 +35,8 @@ MEM_BITS = MEM_BYTES * 8
30
 
31
  STATE_BITS = PC_BITS + IR_BITS + (REG_BITS * REG_COUNT) + FLAG_BITS + SP_BITS + CTRL_BITS + MEM_BITS
32
 
 
 
33
 
34
  def int_to_bits(value: int, width: int) -> List[int]:
35
  return [(value >> (width - 1 - i)) & 1 for i in range(width)]
@@ -257,9 +264,6 @@ def heaviside(x: torch.Tensor) -> torch.Tensor:
257
  return (x >= 0).float()
258
 
259
 
260
- DEFAULT_MODEL_PATH = Path(__file__).resolve().parent.parent / "neural_computer.safetensors"
261
-
262
-
263
  class ThresholdALU:
264
  def __init__(self, model_path: str, device: str = "cpu") -> None:
265
  self.device = device
@@ -642,3 +646,88 @@ class ThresholdCPU:
642
  state = unpack_state(bits_list)
643
  final, _ = self.run_until_halt(state, max_cycles=max_cycles)
644
  return torch.tensor(pack_state(final), dtype=torch.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ 8-bit Threshold Computer - CPU Runtime
3
 
4
+ State layout, reference cycle, and threshold-weight execution.
5
  All multi-bit fields are MSB-first.
6
+
7
+ Usage:
8
+ python threshold_cpu.py # Run smoke test
9
+ python threshold_cpu.py --help # Show options
10
  """
11
 
12
  from __future__ import annotations
13
 
14
+ import argparse
15
  from dataclasses import dataclass
16
  from pathlib import Path
17
  from typing import List, Tuple
 
35
 
36
  STATE_BITS = PC_BITS + IR_BITS + (REG_BITS * REG_COUNT) + FLAG_BITS + SP_BITS + CTRL_BITS + MEM_BITS
37
 
38
+ DEFAULT_MODEL_PATH = Path(__file__).resolve().parent / "neural_computer.safetensors"
39
+
40
 
41
  def int_to_bits(value: int, width: int) -> List[int]:
42
  return [(value >> (width - 1 - i)) & 1 for i in range(width)]
 
264
  return (x >= 0).float()
265
 
266
 
 
 
 
267
  class ThresholdALU:
268
  def __init__(self, model_path: str, device: str = "cpu") -> None:
269
  self.device = device
 
646
  state = unpack_state(bits_list)
647
  final, _ = self.run_until_halt(state, max_cycles=max_cycles)
648
  return torch.tensor(pack_state(final), dtype=torch.float32)
649
+
650
+
651
+ def encode_instr(opcode: int, rd: int, rs: int, imm8: int) -> int:
652
+ return ((opcode & 0xF) << 12) | ((rd & 0x3) << 10) | ((rs & 0x3) << 8) | (imm8 & 0xFF)
653
+
654
+
655
+ def write_instr(mem: List[int], addr: int, instr: int) -> None:
656
+ mem[addr & 0xFFFF] = (instr >> 8) & 0xFF
657
+ mem[(addr + 1) & 0xFFFF] = instr & 0xFF
658
+
659
+
660
+ def write_addr(mem: List[int], addr: int, value: int) -> None:
661
+ mem[addr & 0xFFFF] = (value >> 8) & 0xFF
662
+ mem[(addr + 1) & 0xFFFF] = value & 0xFF
663
+
664
+
665
+ def run_smoke_test() -> None:
666
+ """Smoke test: LOAD 5, LOAD 7, ADD, STORE, HALT. Expect result = 12."""
667
+ mem = [0] * 65536
668
+
669
+ write_instr(mem, 0x0000, encode_instr(0xA, 0, 0, 0x00))
670
+ write_addr(mem, 0x0002, 0x0100)
671
+ write_instr(mem, 0x0004, encode_instr(0xA, 1, 0, 0x00))
672
+ write_addr(mem, 0x0006, 0x0101)
673
+ write_instr(mem, 0x0008, encode_instr(0x0, 0, 1, 0x00))
674
+ write_instr(mem, 0x000A, encode_instr(0xB, 0, 0, 0x00))
675
+ write_addr(mem, 0x000C, 0x0102)
676
+ write_instr(mem, 0x000E, encode_instr(0xF, 0, 0, 0x00))
677
+
678
+ mem[0x0100] = 5
679
+ mem[0x0101] = 7
680
+
681
+ state = CPUState(
682
+ pc=0,
683
+ ir=0,
684
+ regs=[0, 0, 0, 0],
685
+ flags=[0, 0, 0, 0],
686
+ sp=0xFFFE,
687
+ ctrl=[0, 0, 0, 0],
688
+ mem=mem,
689
+ )
690
+
691
+ print("Running reference implementation...")
692
+ final, cycles = ref_run_until_halt(state, max_cycles=20)
693
+
694
+ assert final.ctrl[0] == 1, "HALT flag not set"
695
+ assert final.regs[0] == 12, f"R0 expected 12, got {final.regs[0]}"
696
+ assert final.mem[0x0102] == 12, f"MEM[0x0102] expected 12, got {final.mem[0x0102]}"
697
+ assert cycles <= 10, f"Unexpected cycle count: {cycles}"
698
+ print(f" Reference: R0={final.regs[0]}, MEM[0x0102]={final.mem[0x0102]}, cycles={cycles}")
699
+
700
+ print("Running threshold-weight implementation...")
701
+ threshold_cpu = ThresholdCPU()
702
+ t_final, t_cycles = threshold_cpu.run_until_halt(state, max_cycles=20)
703
+
704
+ assert t_final.ctrl[0] == 1, "Threshold HALT flag not set"
705
+ assert t_final.regs[0] == final.regs[0], f"Threshold R0 mismatch: {t_final.regs[0]} != {final.regs[0]}"
706
+ assert t_final.mem[0x0102] == final.mem[0x0102], (
707
+ f"Threshold MEM[0x0102] mismatch: {t_final.mem[0x0102]} != {final.mem[0x0102]}"
708
+ )
709
+ assert t_cycles == cycles, f"Threshold cycle count mismatch: {t_cycles} != {cycles}"
710
+ print(f" Threshold: R0={t_final.regs[0]}, MEM[0x0102]={t_final.mem[0x0102]}, cycles={t_cycles}")
711
+
712
+ print("Validating forward() tensor I/O...")
713
+ bits = torch.tensor(pack_state(state), dtype=torch.float32)
714
+ out_bits = threshold_cpu.forward(bits, max_cycles=20)
715
+ out_state = unpack_state([int(b) for b in out_bits.tolist()])
716
+ assert out_state.regs[0] == final.regs[0], f"Forward R0 mismatch: {out_state.regs[0]} != {final.regs[0]}"
717
+ assert out_state.mem[0x0102] == final.mem[0x0102], (
718
+ f"Forward MEM[0x0102] mismatch: {out_state.mem[0x0102]} != {final.mem[0x0102]}"
719
+ )
720
+ print(f" Forward: R0={out_state.regs[0]}, MEM[0x0102]={out_state.mem[0x0102]}")
721
+
722
+ print("\nSmoke test: PASSED")
723
+
724
+
725
+ if __name__ == "__main__":
726
+ parser = argparse.ArgumentParser(description="8-bit Threshold CPU")
727
+ parser.add_argument("--model", type=Path, default=DEFAULT_MODEL_PATH, help="Path to safetensors model")
728
+ args = parser.parse_args()
729
+
730
+ if args.model != DEFAULT_MODEL_PATH:
731
+ DEFAULT_MODEL_PATH = args.model
732
+
733
+ run_smoke_test()