CharlesCNorton committed on
Commit
d09568e
·
1 Parent(s): 32eb1de

Consolidate CPU modules into single core.py

Browse files

Merge state.py, cycle.py, and threshold_cpu.py into cpu/core.py.

Files changed (3)
  1. cpu/{threshold_cpu.py → core.py} +279 -70
  2. cpu/cycle.py +0 -148
  3. cpu/state.py +0 -103
cpu/{threshold_cpu.py → core.py} RENAMED
@@ -1,29 +1,41 @@
1
  """
2
- Threshold-weight runtime for the 8-bit CPU.
3
 
4
- Implements a reference cycle using the frozen circuit weights for core ALU ops.
 
5
  """
6
 
7
  from __future__ import annotations
8
 
 
9
  from pathlib import Path
10
  from typing import List, Tuple
11
 
12
  import torch
13
  from safetensors.torch import load_file
14
 
15
- from .state import CPUState, pack_state, unpack_state, REG_BITS, PC_BITS, MEM_BYTES
16
 
 
 
17
 
18
- def heaviside(x: torch.Tensor) -> torch.Tensor:
19
- return (x >= 0).float()
 
 
 
 
 
 
 
 
 
20
 
21
 
22
- def int_to_bits_msb(value: int, width: int) -> List[int]:
23
  return [(value >> (width - 1 - i)) & 1 for i in range(width)]
24
 
25
 
26
- def bits_to_int_msb(bits: List[int]) -> int:
27
  value = 0
28
  for bit in bits:
29
  value = (value << 1) | int(bit)
@@ -34,6 +46,217 @@ def bits_msb_to_lsb(bits: List[int]) -> List[int]:
34
  return list(reversed(bits))
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  DEFAULT_MODEL_PATH = Path(__file__).resolve().parent.parent / "neural_computer.safetensors"
38
 
39
 
@@ -78,8 +301,8 @@ class ThresholdALU:
78
  return ha2_sum, cout
79
 
80
  def add(self, a: int, b: int) -> Tuple[int, int, int]:
81
- a_bits = bits_msb_to_lsb(int_to_bits_msb(a, REG_BITS))
82
- b_bits = bits_msb_to_lsb(int_to_bits_msb(b, REG_BITS))
83
 
84
  carry = 0.0
85
  sum_bits: List[int] = []
@@ -89,16 +312,16 @@ class ThresholdALU:
89
  )
90
  sum_bits.append(int(sum_bit))
91
 
92
- result = bits_to_int_msb(list(reversed(sum_bits)))
93
  carry_out = int(carry)
94
  overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0
95
  return result, carry_out, overflow
96
 
97
  def sub(self, a: int, b: int) -> Tuple[int, int, int]:
98
- a_bits = bits_msb_to_lsb(int_to_bits_msb(a, REG_BITS))
99
- b_bits = bits_msb_to_lsb(int_to_bits_msb(b, REG_BITS))
100
 
101
- carry = 1.0 # two's complement carry-in
102
  sum_bits: List[int] = []
103
  for bit in range(REG_BITS):
104
  notb = self._eval_gate(
@@ -128,14 +351,14 @@ class ThresholdALU:
128
 
129
  sum_bits.append(int(xor2))
130
 
131
- result = bits_to_int_msb(list(reversed(sum_bits)))
132
  carry_out = int(carry)
133
  overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0
134
  return result, carry_out, overflow
135
 
136
  def bitwise_and(self, a: int, b: int) -> int:
137
- a_bits = int_to_bits_msb(a, REG_BITS)
138
- b_bits = int_to_bits_msb(b, REG_BITS)
139
  w = self._get("alu.alu8bit.and.weight")
140
  bias = self._get("alu.alu8bit.and.bias")
141
 
@@ -145,11 +368,11 @@ class ThresholdALU:
145
  out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
146
  out_bits.append(int(out))
147
 
148
- return bits_to_int_msb(out_bits)
149
 
150
  def bitwise_or(self, a: int, b: int) -> int:
151
- a_bits = int_to_bits_msb(a, REG_BITS)
152
- b_bits = int_to_bits_msb(b, REG_BITS)
153
  w = self._get("alu.alu8bit.or.weight")
154
  bias = self._get("alu.alu8bit.or.bias")
155
 
@@ -159,10 +382,10 @@ class ThresholdALU:
159
  out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
160
  out_bits.append(int(out))
161
 
162
- return bits_to_int_msb(out_bits)
163
 
164
  def bitwise_not(self, a: int) -> int:
165
- a_bits = int_to_bits_msb(a, REG_BITS)
166
  w = self._get("alu.alu8bit.not.weight")
167
  bias = self._get("alu.alu8bit.not.bias")
168
 
@@ -172,11 +395,11 @@ class ThresholdALU:
172
  out = heaviside((inp * w[bit]).sum() + bias[bit]).item()
173
  out_bits.append(int(out))
174
 
175
- return bits_to_int_msb(out_bits)
176
 
177
  def bitwise_xor(self, a: int, b: int) -> int:
178
- a_bits = int_to_bits_msb(a, REG_BITS)
179
- b_bits = int_to_bits_msb(b, REG_BITS)
180
 
181
  w_or = self._get("alu.alu8bit.xor.layer1.or.weight")
182
  b_or = self._get("alu.alu8bit.xor.layer1.or.bias")
@@ -194,7 +417,7 @@ class ThresholdALU:
194
  out = heaviside((hidden * w2[bit * 2:bit * 2 + 2]).sum() + b2[bit]).item()
195
  out_bits.append(int(out))
196
 
197
- return bits_to_int_msb(out_bits)
198
 
199
 
200
  class ThresholdCPU:
@@ -202,24 +425,8 @@ class ThresholdCPU:
202
  self.device = device
203
  self.alu = ThresholdALU(str(model_path), device=device)
204
 
205
- @staticmethod
206
- def decode_ir(ir: int) -> Tuple[int, int, int, int]:
207
- opcode = (ir >> 12) & 0xF
208
- rd = (ir >> 10) & 0x3
209
- rs = (ir >> 8) & 0x3
210
- imm8 = ir & 0xFF
211
- return opcode, rd, rs, imm8
212
-
213
- @staticmethod
214
- def flags_from_result(result: int, carry: int, overflow: int) -> List[int]:
215
- z = 1 if result == 0 else 0
216
- n = 1 if (result & 0x80) else 0
217
- c = 1 if carry else 0
218
- v = 1 if overflow else 0
219
- return [z, n, c, v]
220
-
221
  def _addr_decode(self, addr: int) -> torch.Tensor:
222
- bits = torch.tensor(int_to_bits_msb(addr, PC_BITS), device=self.device, dtype=torch.float32)
223
  w = self.alu._get("memory.addr_decode.weight")
224
  b = self.alu._get("memory.addr_decode.bias")
225
  return heaviside((w * bits).sum(dim=1) + b)
@@ -227,7 +434,7 @@ class ThresholdCPU:
227
  def _memory_read(self, mem: List[int], addr: int) -> int:
228
  sel = self._addr_decode(addr)
229
  mem_bits = torch.tensor(
230
- [int_to_bits_msb(byte, REG_BITS) for byte in mem],
231
  device=self.device,
232
  dtype=torch.float32,
233
  )
@@ -243,13 +450,13 @@ class ThresholdCPU:
243
  out_bit = heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item()
244
  out_bits.append(int(out_bit))
245
 
246
- return bits_to_int_msb(out_bits)
247
 
248
  def _memory_write(self, mem: List[int], addr: int, value: int) -> List[int]:
249
  sel = self._addr_decode(addr)
250
- data_bits = torch.tensor(int_to_bits_msb(value, REG_BITS), device=self.device, dtype=torch.float32)
251
  mem_bits = torch.tensor(
252
- [int_to_bits_msb(byte, REG_BITS) for byte in mem],
253
  device=self.device,
254
  dtype=torch.float32,
255
  )
@@ -283,11 +490,11 @@ class ThresholdCPU:
283
  out_bit = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit])
284
  new_mem_bits[:, bit] = out_bit
285
 
286
- return [bits_to_int_msb([int(b) for b in new_mem_bits[i].tolist()]) for i in range(MEM_BYTES)]
287
 
288
  def _conditional_jump_byte(self, prefix: str, pc_byte: int, target_byte: int, flag: int) -> int:
289
- pc_bits = int_to_bits_msb(pc_byte, REG_BITS)
290
- target_bits = int_to_bits_msb(target_byte, REG_BITS)
291
 
292
  out_bits: List[int] = []
293
  for bit in range(REG_BITS):
@@ -313,21 +520,21 @@ class ThresholdCPU:
313
  )
314
  out_bits.append(int(out_bit))
315
 
316
- return bits_to_int_msb(out_bits)
317
 
318
  def step(self, state: CPUState) -> CPUState:
319
- if state.ctrl[0] == 1: # HALT
 
320
  return state.copy()
321
 
322
  s = state.copy()
323
 
324
- # Fetch: two bytes, big-endian
325
  hi = self._memory_read(s.mem, s.pc)
326
  lo = self._memory_read(s.mem, (s.pc + 1) & 0xFFFF)
327
  s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
328
  next_pc = (s.pc + 2) & 0xFFFF
329
 
330
- opcode, rd, rs, imm8 = self.decode_ir(s.ir)
331
  a = s.regs[rd]
332
  b = s.regs[rs]
333
 
@@ -344,45 +551,45 @@ class ThresholdCPU:
344
  carry = 0
345
  overflow = 0
346
 
347
- if opcode == 0x0: # ADD
348
  result, carry, overflow = self.alu.add(a, b)
349
- elif opcode == 0x1: # SUB
350
  result, carry, overflow = self.alu.sub(a, b)
351
- elif opcode == 0x2: # AND
352
  result = self.alu.bitwise_and(a, b)
353
- elif opcode == 0x3: # OR
354
  result = self.alu.bitwise_or(a, b)
355
- elif opcode == 0x4: # XOR
356
  result = self.alu.bitwise_xor(a, b)
357
- elif opcode == 0x5: # SHL
358
  carry = 1 if (a & 0x80) else 0
359
  result = (a << 1) & 0xFF
360
- elif opcode == 0x6: # SHR
361
  carry = 1 if (a & 0x01) else 0
362
  result = (a >> 1) & 0xFF
363
- elif opcode == 0x7: # MUL
364
  full = a * b
365
  result = full & 0xFF
366
  carry = 1 if full > 0xFF else 0
367
- elif opcode == 0x8: # DIV
368
  if b == 0:
369
  result = 0
370
  carry = 1
371
  overflow = 1
372
  else:
373
  result = (a // b) & 0xFF
374
- elif opcode == 0x9: # CMP
375
  result, carry, overflow = self.alu.sub(a, b)
376
  write_result = False
377
- elif opcode == 0xA: # LOAD
378
  result = self._memory_read(s.mem, addr16)
379
- elif opcode == 0xB: # STORE
380
  s.mem = self._memory_write(s.mem, addr16, b & 0xFF)
381
  write_result = False
382
- elif opcode == 0xC: # JMP
383
  s.pc = addr16 & 0xFFFF
384
  write_result = False
385
- elif opcode == 0xD: # JZ
386
  hi_pc = self._conditional_jump_byte(
387
  "control.jz",
388
  (next_pc_ext >> 8) & 0xFF,
@@ -397,7 +604,7 @@ class ThresholdCPU:
397
  )
398
  s.pc = ((hi_pc & 0xFF) << 8) | (lo_pc & 0xFF)
399
  write_result = False
400
- elif opcode == 0xE: # CALL
401
  ret_addr = next_pc_ext & 0xFFFF
402
  s.sp = (s.sp - 1) & 0xFFFF
403
  s.mem = self._memory_write(s.mem, s.sp, (ret_addr >> 8) & 0xFF)
@@ -405,12 +612,12 @@ class ThresholdCPU:
405
  s.mem = self._memory_write(s.mem, s.sp, ret_addr & 0xFF)
406
  s.pc = addr16 & 0xFFFF
407
  write_result = False
408
- elif opcode == 0xF: # HALT
409
  s.ctrl[0] = 1
410
  write_result = False
411
 
412
  if opcode <= 0x9 or opcode == 0xA:
413
- s.flags = self.flags_from_result(result, carry, overflow)
414
 
415
  if write_result:
416
  s.regs[rd] = result & 0xFF
@@ -421,6 +628,7 @@ class ThresholdCPU:
421
  return s
422
 
423
  def run_until_halt(self, state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
 
424
  s = state.copy()
425
  for i in range(max_cycles):
426
  if s.ctrl[0] == 1:
@@ -429,6 +637,7 @@ class ThresholdCPU:
429
  return s, max_cycles
430
 
431
  def forward(self, state_bits: torch.Tensor, max_cycles: int = 256) -> torch.Tensor:
 
432
  bits_list = [int(b) for b in state_bits.detach().cpu().flatten().tolist()]
433
  state = unpack_state(bits_list)
434
  final, _ = self.run_until_halt(state, max_cycles=max_cycles)
 
1
  """
2
+ 8-bit Threshold Computer - Combined CPU Module
3
 
4
+ State layout, reference cycle, and threshold-weight runtime in one file.
5
+ All multi-bit fields are MSB-first.
6
  """
7
 
8
  from __future__ import annotations
9
 
10
+ from dataclasses import dataclass
11
  from pathlib import Path
12
  from typing import List, Tuple
13
 
14
  import torch
15
  from safetensors.torch import load_file
16
 
 
17
 
18
+ FLAG_NAMES = ["Z", "N", "C", "V"]
19
+ CTRL_NAMES = ["HALT", "MEM_WE", "MEM_RE", "RESERVED"]
20
 
21
+ PC_BITS = 16
22
+ IR_BITS = 16
23
+ REG_BITS = 8
24
+ REG_COUNT = 4
25
+ FLAG_BITS = 4
26
+ SP_BITS = 16
27
+ CTRL_BITS = 4
28
+ MEM_BYTES = 65536
29
+ MEM_BITS = MEM_BYTES * 8
30
+
31
+ STATE_BITS = PC_BITS + IR_BITS + (REG_BITS * REG_COUNT) + FLAG_BITS + SP_BITS + CTRL_BITS + MEM_BITS
32
 
33
 
34
+ def int_to_bits(value: int, width: int) -> List[int]:
35
  return [(value >> (width - 1 - i)) & 1 for i in range(width)]
36
 
37
 
38
+ def bits_to_int(bits: List[int]) -> int:
39
  value = 0
40
  for bit in bits:
41
  value = (value << 1) | int(bit)
 
46
  return list(reversed(bits))
47
 
48
 
49
+ @dataclass
50
+ class CPUState:
51
+ pc: int
52
+ ir: int
53
+ regs: List[int]
54
+ flags: List[int]
55
+ sp: int
56
+ ctrl: List[int]
57
+ mem: List[int]
58
+
59
+ def copy(self) -> CPUState:
60
+ return CPUState(
61
+ pc=int(self.pc),
62
+ ir=int(self.ir),
63
+ regs=[int(r) for r in self.regs],
64
+ flags=[int(f) for f in self.flags],
65
+ sp=int(self.sp),
66
+ ctrl=[int(c) for c in self.ctrl],
67
+ mem=[int(m) for m in self.mem],
68
+ )
69
+
70
+
71
+ def pack_state(state: CPUState) -> List[int]:
72
+ bits: List[int] = []
73
+ bits.extend(int_to_bits(state.pc, PC_BITS))
74
+ bits.extend(int_to_bits(state.ir, IR_BITS))
75
+ for reg in state.regs:
76
+ bits.extend(int_to_bits(reg, REG_BITS))
77
+ bits.extend([int(f) for f in state.flags])
78
+ bits.extend(int_to_bits(state.sp, SP_BITS))
79
+ bits.extend([int(c) for c in state.ctrl])
80
+ for byte in state.mem:
81
+ bits.extend(int_to_bits(byte, REG_BITS))
82
+ return bits
83
+
84
+
85
+ def unpack_state(bits: List[int]) -> CPUState:
86
+ if len(bits) != STATE_BITS:
87
+ raise ValueError(f"Expected {STATE_BITS} bits, got {len(bits)}")
88
+
89
+ idx = 0
90
+ pc = bits_to_int(bits[idx:idx + PC_BITS])
91
+ idx += PC_BITS
92
+ ir = bits_to_int(bits[idx:idx + IR_BITS])
93
+ idx += IR_BITS
94
+
95
+ regs = []
96
+ for _ in range(REG_COUNT):
97
+ regs.append(bits_to_int(bits[idx:idx + REG_BITS]))
98
+ idx += REG_BITS
99
+
100
+ flags = [int(b) for b in bits[idx:idx + FLAG_BITS]]
101
+ idx += FLAG_BITS
102
+
103
+ sp = bits_to_int(bits[idx:idx + SP_BITS])
104
+ idx += SP_BITS
105
+
106
+ ctrl = [int(b) for b in bits[idx:idx + CTRL_BITS]]
107
+ idx += CTRL_BITS
108
+
109
+ mem = []
110
+ for _ in range(MEM_BYTES):
111
+ mem.append(bits_to_int(bits[idx:idx + REG_BITS]))
112
+ idx += REG_BITS
113
+
114
+ return CPUState(pc=pc, ir=ir, regs=regs, flags=flags, sp=sp, ctrl=ctrl, mem=mem)
115
+
116
+
117
+ def decode_ir(ir: int) -> Tuple[int, int, int, int]:
118
+ opcode = (ir >> 12) & 0xF
119
+ rd = (ir >> 10) & 0x3
120
+ rs = (ir >> 8) & 0x3
121
+ imm8 = ir & 0xFF
122
+ return opcode, rd, rs, imm8
123
+
124
+
125
+ def flags_from_result(result: int, carry: int, overflow: int) -> Tuple[int, int, int, int]:
126
+ z = 1 if result == 0 else 0
127
+ n = 1 if (result & 0x80) else 0
128
+ c = 1 if carry else 0
129
+ v = 1 if overflow else 0
130
+ return z, n, c, v
131
+
132
+
133
+ def alu_add(a: int, b: int) -> Tuple[int, int, int]:
134
+ full = a + b
135
+ result = full & 0xFF
136
+ carry = 1 if full > 0xFF else 0
137
+ overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0
138
+ return result, carry, overflow
139
+
140
+
141
+ def alu_sub(a: int, b: int) -> Tuple[int, int, int]:
142
+ full = (a - b) & 0x1FF
143
+ result = full & 0xFF
144
+ carry = 1 if a >= b else 0
145
+ overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0
146
+ return result, carry, overflow
147
+
148
+
149
+ def ref_step(state: CPUState) -> CPUState:
150
+ """Reference CPU cycle (pure Python arithmetic)."""
151
+ if state.ctrl[0] == 1:
152
+ return state.copy()
153
+
154
+ s = state.copy()
155
+
156
+ hi = s.mem[s.pc]
157
+ lo = s.mem[(s.pc + 1) & 0xFFFF]
158
+ s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
159
+ next_pc = (s.pc + 2) & 0xFFFF
160
+
161
+ opcode, rd, rs, imm8 = decode_ir(s.ir)
162
+ a = s.regs[rd]
163
+ b = s.regs[rs]
164
+
165
+ addr16 = None
166
+ next_pc_ext = next_pc
167
+ if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
168
+ addr_hi = s.mem[next_pc]
169
+ addr_lo = s.mem[(next_pc + 1) & 0xFFFF]
170
+ addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
171
+ next_pc_ext = (next_pc + 2) & 0xFFFF
172
+
173
+ write_result = True
174
+ result = a
175
+ carry = 0
176
+ overflow = 0
177
+
178
+ if opcode == 0x0:
179
+ result, carry, overflow = alu_add(a, b)
180
+ elif opcode == 0x1:
181
+ result, carry, overflow = alu_sub(a, b)
182
+ elif opcode == 0x2:
183
+ result = a & b
184
+ elif opcode == 0x3:
185
+ result = a | b
186
+ elif opcode == 0x4:
187
+ result = a ^ b
188
+ elif opcode == 0x5:
189
+ carry = 1 if (a & 0x80) else 0
190
+ result = (a << 1) & 0xFF
191
+ elif opcode == 0x6:
192
+ carry = 1 if (a & 0x01) else 0
193
+ result = (a >> 1) & 0xFF
194
+ elif opcode == 0x7:
195
+ full = a * b
196
+ result = full & 0xFF
197
+ carry = 1 if full > 0xFF else 0
198
+ elif opcode == 0x8:
199
+ if b == 0:
200
+ result = 0
201
+ carry = 1
202
+ overflow = 1
203
+ else:
204
+ result = (a // b) & 0xFF
205
+ elif opcode == 0x9:
206
+ result, carry, overflow = alu_sub(a, b)
207
+ write_result = False
208
+ elif opcode == 0xA:
209
+ result = s.mem[addr16]
210
+ elif opcode == 0xB:
211
+ s.mem[addr16] = b & 0xFF
212
+ write_result = False
213
+ elif opcode == 0xC:
214
+ s.pc = addr16 & 0xFFFF
215
+ write_result = False
216
+ elif opcode == 0xD:
217
+ if s.flags[0] == 1:
218
+ s.pc = addr16 & 0xFFFF
219
+ else:
220
+ s.pc = next_pc_ext
221
+ write_result = False
222
+ elif opcode == 0xE:
223
+ ret_addr = next_pc_ext & 0xFFFF
224
+ s.sp = (s.sp - 1) & 0xFFFF
225
+ s.mem[s.sp] = (ret_addr >> 8) & 0xFF
226
+ s.sp = (s.sp - 1) & 0xFFFF
227
+ s.mem[s.sp] = ret_addr & 0xFF
228
+ s.pc = addr16 & 0xFFFF
229
+ write_result = False
230
+ elif opcode == 0xF:
231
+ s.ctrl[0] = 1
232
+ write_result = False
233
+
234
+ if opcode <= 0x9 or opcode in (0xA, 0x7, 0x8):
235
+ s.flags = list(flags_from_result(result, carry, overflow))
236
+
237
+ if write_result:
238
+ s.regs[rd] = result & 0xFF
239
+
240
+ if opcode not in (0xC, 0xD, 0xE):
241
+ s.pc = next_pc_ext
242
+
243
+ return s
244
+
245
+
246
+ def ref_run_until_halt(state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
247
+ """Reference execution loop."""
248
+ s = state.copy()
249
+ for i in range(max_cycles):
250
+ if s.ctrl[0] == 1:
251
+ return s, i
252
+ s = ref_step(s)
253
+ return s, max_cycles
254
+
255
+
256
+ def heaviside(x: torch.Tensor) -> torch.Tensor:
257
+ return (x >= 0).float()
258
+
259
+
260
  DEFAULT_MODEL_PATH = Path(__file__).resolve().parent.parent / "neural_computer.safetensors"
261
 
262
 
 
301
  return ha2_sum, cout
302
 
303
  def add(self, a: int, b: int) -> Tuple[int, int, int]:
304
+ a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
305
+ b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))
306
 
307
  carry = 0.0
308
  sum_bits: List[int] = []
 
312
  )
313
  sum_bits.append(int(sum_bit))
314
 
315
+ result = bits_to_int(list(reversed(sum_bits)))
316
  carry_out = int(carry)
317
  overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0
318
  return result, carry_out, overflow
319
 
320
  def sub(self, a: int, b: int) -> Tuple[int, int, int]:
321
+ a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
322
+ b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))
323
 
324
+ carry = 1.0
325
  sum_bits: List[int] = []
326
  for bit in range(REG_BITS):
327
  notb = self._eval_gate(
 
351
 
352
  sum_bits.append(int(xor2))
353
 
354
+ result = bits_to_int(list(reversed(sum_bits)))
355
  carry_out = int(carry)
356
  overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0
357
  return result, carry_out, overflow
358
 
359
  def bitwise_and(self, a: int, b: int) -> int:
360
+ a_bits = int_to_bits(a, REG_BITS)
361
+ b_bits = int_to_bits(b, REG_BITS)
362
  w = self._get("alu.alu8bit.and.weight")
363
  bias = self._get("alu.alu8bit.and.bias")
364
 
 
368
  out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
369
  out_bits.append(int(out))
370
 
371
+ return bits_to_int(out_bits)
372
 
373
  def bitwise_or(self, a: int, b: int) -> int:
374
+ a_bits = int_to_bits(a, REG_BITS)
375
+ b_bits = int_to_bits(b, REG_BITS)
376
  w = self._get("alu.alu8bit.or.weight")
377
  bias = self._get("alu.alu8bit.or.bias")
378
 
 
382
  out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
383
  out_bits.append(int(out))
384
 
385
+ return bits_to_int(out_bits)
386
 
387
  def bitwise_not(self, a: int) -> int:
388
+ a_bits = int_to_bits(a, REG_BITS)
389
  w = self._get("alu.alu8bit.not.weight")
390
  bias = self._get("alu.alu8bit.not.bias")
391
 
 
395
  out = heaviside((inp * w[bit]).sum() + bias[bit]).item()
396
  out_bits.append(int(out))
397
 
398
+ return bits_to_int(out_bits)
399
 
400
  def bitwise_xor(self, a: int, b: int) -> int:
401
+ a_bits = int_to_bits(a, REG_BITS)
402
+ b_bits = int_to_bits(b, REG_BITS)
403
 
404
  w_or = self._get("alu.alu8bit.xor.layer1.or.weight")
405
  b_or = self._get("alu.alu8bit.xor.layer1.or.bias")
 
417
  out = heaviside((hidden * w2[bit * 2:bit * 2 + 2]).sum() + b2[bit]).item()
418
  out_bits.append(int(out))
419
 
420
+ return bits_to_int(out_bits)
421
 
422
 
423
  class ThresholdCPU:
 
425
  self.device = device
426
  self.alu = ThresholdALU(str(model_path), device=device)
427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  def _addr_decode(self, addr: int) -> torch.Tensor:
429
+ bits = torch.tensor(int_to_bits(addr, PC_BITS), device=self.device, dtype=torch.float32)
430
  w = self.alu._get("memory.addr_decode.weight")
431
  b = self.alu._get("memory.addr_decode.bias")
432
  return heaviside((w * bits).sum(dim=1) + b)
 
434
  def _memory_read(self, mem: List[int], addr: int) -> int:
435
  sel = self._addr_decode(addr)
436
  mem_bits = torch.tensor(
437
+ [int_to_bits(byte, REG_BITS) for byte in mem],
438
  device=self.device,
439
  dtype=torch.float32,
440
  )
 
450
  out_bit = heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item()
451
  out_bits.append(int(out_bit))
452
 
453
+ return bits_to_int(out_bits)
454
 
455
  def _memory_write(self, mem: List[int], addr: int, value: int) -> List[int]:
456
  sel = self._addr_decode(addr)
457
+ data_bits = torch.tensor(int_to_bits(value, REG_BITS), device=self.device, dtype=torch.float32)
458
  mem_bits = torch.tensor(
459
+ [int_to_bits(byte, REG_BITS) for byte in mem],
460
  device=self.device,
461
  dtype=torch.float32,
462
  )
 
490
  out_bit = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit])
491
  new_mem_bits[:, bit] = out_bit
492
 
493
+ return [bits_to_int([int(b) for b in new_mem_bits[i].tolist()]) for i in range(MEM_BYTES)]
494
 
495
  def _conditional_jump_byte(self, prefix: str, pc_byte: int, target_byte: int, flag: int) -> int:
496
+ pc_bits = int_to_bits(pc_byte, REG_BITS)
497
+ target_bits = int_to_bits(target_byte, REG_BITS)
498
 
499
  out_bits: List[int] = []
500
  for bit in range(REG_BITS):
 
520
  )
521
  out_bits.append(int(out_bit))
522
 
523
+ return bits_to_int(out_bits)
524
 
525
  def step(self, state: CPUState) -> CPUState:
526
+ """Single CPU cycle using threshold neurons."""
527
+ if state.ctrl[0] == 1:
528
  return state.copy()
529
 
530
  s = state.copy()
531
 
 
532
  hi = self._memory_read(s.mem, s.pc)
533
  lo = self._memory_read(s.mem, (s.pc + 1) & 0xFFFF)
534
  s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
535
  next_pc = (s.pc + 2) & 0xFFFF
536
 
537
+ opcode, rd, rs, imm8 = decode_ir(s.ir)
538
  a = s.regs[rd]
539
  b = s.regs[rs]
540
 
 
551
  carry = 0
552
  overflow = 0
553
 
554
+ if opcode == 0x0:
555
  result, carry, overflow = self.alu.add(a, b)
556
+ elif opcode == 0x1:
557
  result, carry, overflow = self.alu.sub(a, b)
558
+ elif opcode == 0x2:
559
  result = self.alu.bitwise_and(a, b)
560
+ elif opcode == 0x3:
561
  result = self.alu.bitwise_or(a, b)
562
+ elif opcode == 0x4:
563
  result = self.alu.bitwise_xor(a, b)
564
+ elif opcode == 0x5:
565
  carry = 1 if (a & 0x80) else 0
566
  result = (a << 1) & 0xFF
567
+ elif opcode == 0x6:
568
  carry = 1 if (a & 0x01) else 0
569
  result = (a >> 1) & 0xFF
570
+ elif opcode == 0x7:
571
  full = a * b
572
  result = full & 0xFF
573
  carry = 1 if full > 0xFF else 0
574
+ elif opcode == 0x8:
575
  if b == 0:
576
  result = 0
577
  carry = 1
578
  overflow = 1
579
  else:
580
  result = (a // b) & 0xFF
581
+ elif opcode == 0x9:
582
  result, carry, overflow = self.alu.sub(a, b)
583
  write_result = False
584
+ elif opcode == 0xA:
585
  result = self._memory_read(s.mem, addr16)
586
+ elif opcode == 0xB:
587
  s.mem = self._memory_write(s.mem, addr16, b & 0xFF)
588
  write_result = False
589
+ elif opcode == 0xC:
590
  s.pc = addr16 & 0xFFFF
591
  write_result = False
592
+ elif opcode == 0xD:
593
  hi_pc = self._conditional_jump_byte(
594
  "control.jz",
595
  (next_pc_ext >> 8) & 0xFF,
 
604
  )
605
  s.pc = ((hi_pc & 0xFF) << 8) | (lo_pc & 0xFF)
606
  write_result = False
607
+ elif opcode == 0xE:
608
  ret_addr = next_pc_ext & 0xFFFF
609
  s.sp = (s.sp - 1) & 0xFFFF
610
  s.mem = self._memory_write(s.mem, s.sp, (ret_addr >> 8) & 0xFF)
 
612
  s.mem = self._memory_write(s.mem, s.sp, ret_addr & 0xFF)
613
  s.pc = addr16 & 0xFFFF
614
  write_result = False
615
+ elif opcode == 0xF:
616
  s.ctrl[0] = 1
617
  write_result = False
618
 
619
  if opcode <= 0x9 or opcode == 0xA:
620
+ s.flags = list(flags_from_result(result, carry, overflow))
621
 
622
  if write_result:
623
  s.regs[rd] = result & 0xFF
 
628
  return s
629
 
630
  def run_until_halt(self, state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
631
+ """Execute until HALT or max_cycles reached."""
632
  s = state.copy()
633
  for i in range(max_cycles):
634
  if s.ctrl[0] == 1:
 
637
  return s, max_cycles
638
 
639
  def forward(self, state_bits: torch.Tensor, max_cycles: int = 256) -> torch.Tensor:
640
+ """Tensor-in, tensor-out interface for neural integration."""
641
  bits_list = [int(b) for b in state_bits.detach().cpu().flatten().tolist()]
642
  state = unpack_state(bits_list)
643
  final, _ = self.run_until_halt(state, max_cycles=max_cycles)
cpu/cycle.py DELETED
@@ -1,148 +0,0 @@
1
- """
2
- Reference CPU cycle (software) for the threshold computer.
3
- Implements fetch/decode/execute over the state layout.
4
- """
5
-
6
- from __future__ import annotations
7
-
8
- from typing import Tuple
9
-
10
- from .state import CPUState
11
-
12
-
13
- def decode_ir(ir: int) -> Tuple[int, int, int, int]:
14
- opcode = (ir >> 12) & 0xF
15
- rd = (ir >> 10) & 0x3
16
- rs = (ir >> 8) & 0x3
17
- imm8 = ir & 0xFF
18
- return opcode, rd, rs, imm8
19
-
20
-
21
- def _flags_from_result(result: int, carry: int, overflow: int) -> Tuple[int, int, int, int]:
22
- z = 1 if result == 0 else 0
23
- n = 1 if (result & 0x80) else 0
24
- c = 1 if carry else 0
25
- v = 1 if overflow else 0
26
- return z, n, c, v
27
-
28
-
29
- def _alu_add(a: int, b: int) -> Tuple[int, int, int]:
30
- full = a + b
31
- result = full & 0xFF
32
- carry = 1 if full > 0xFF else 0
33
- overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0
34
- return result, carry, overflow
35
-
36
-
37
- def _alu_sub(a: int, b: int) -> Tuple[int, int, int]:
38
- full = (a - b) & 0x1FF
39
- result = full & 0xFF
40
- carry = 1 if a >= b else 0
41
- overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0
42
- return result, carry, overflow
43
-
44
-
45
- def step(state: CPUState) -> CPUState:
46
- if state.ctrl[0] == 1: # HALT
47
- return state.copy()
48
-
49
- s = state.copy()
50
-
51
- # Fetch: two bytes, big-endian
52
- hi = s.mem[s.pc]
53
- lo = s.mem[(s.pc + 1) & 0xFFFF]
54
- s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
55
- next_pc = (s.pc + 2) & 0xFFFF
56
-
57
- opcode, rd, rs, imm8 = decode_ir(s.ir)
58
- a = s.regs[rd]
59
- b = s.regs[rs]
60
-
61
- addr16 = None
62
- next_pc_ext = next_pc
63
- if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
64
- addr_hi = s.mem[next_pc]
65
- addr_lo = s.mem[(next_pc + 1) & 0xFFFF]
66
- addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
67
- next_pc_ext = (next_pc + 2) & 0xFFFF
68
-
69
- write_result = True
70
- result = a
71
- carry = 0
72
- overflow = 0
73
-
74
- if opcode == 0x0: # ADD
75
- result, carry, overflow = _alu_add(a, b)
76
- elif opcode == 0x1: # SUB
77
- result, carry, overflow = _alu_sub(a, b)
78
- elif opcode == 0x2: # AND
79
- result = a & b
80
- elif opcode == 0x3: # OR
81
- result = a | b
82
- elif opcode == 0x4: # XOR
83
- result = a ^ b
84
- elif opcode == 0x5: # SHL
85
- carry = 1 if (a & 0x80) else 0
86
- result = (a << 1) & 0xFF
87
- elif opcode == 0x6: # SHR
88
- carry = 1 if (a & 0x01) else 0
89
- result = (a >> 1) & 0xFF
90
- elif opcode == 0x7: # MUL
91
- full = a * b
92
- result = full & 0xFF
93
- carry = 1 if full > 0xFF else 0
94
- elif opcode == 0x8: # DIV
95
- if b == 0:
96
- result = 0
97
- carry = 1
98
- overflow = 1
99
- else:
100
- result = (a // b) & 0xFF
101
- elif opcode == 0x9: # CMP
102
- result, carry, overflow = _alu_sub(a, b)
103
- write_result = False
104
- elif opcode == 0xA: # LOAD
105
- result = s.mem[addr16]
106
- elif opcode == 0xB: # STORE
107
- s.mem[addr16] = b & 0xFF
108
- write_result = False
109
- elif opcode == 0xC: # JMP
110
- s.pc = addr16 & 0xFFFF
111
- write_result = False
112
- elif opcode == 0xD: # JZ
113
- if s.flags[0] == 1:
114
- s.pc = addr16 & 0xFFFF
115
- else:
116
- s.pc = next_pc_ext
117
- write_result = False
118
- elif opcode == 0xE: # CALL
119
- ret_addr = next_pc_ext & 0xFFFF
120
- s.sp = (s.sp - 1) & 0xFFFF
121
- s.mem[s.sp] = (ret_addr >> 8) & 0xFF
122
- s.sp = (s.sp - 1) & 0xFFFF
123
- s.mem[s.sp] = ret_addr & 0xFF
124
- s.pc = addr16 & 0xFFFF
125
- write_result = False
126
- elif opcode == 0xF: # HALT
127
- s.ctrl[0] = 1
128
- write_result = False
129
-
130
- if opcode <= 0x9 or opcode in (0xA, 0x7, 0x8):
131
- s.flags = list(_flags_from_result(result, carry, overflow))
132
-
133
- if write_result:
134
- s.regs[rd] = result & 0xFF
135
-
136
- if opcode not in (0xC, 0xD, 0xE):
137
- s.pc = next_pc_ext
138
-
139
- return s
140
-
141
-
142
- def run_until_halt(state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
143
- s = state.copy()
144
- for i in range(max_cycles):
145
- if s.ctrl[0] == 1:
146
- return s, i
147
- s = step(s)
148
- return s, max_cycles
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cpu/state.py DELETED
@@ -1,103 +0,0 @@
1
- """
2
- State layout helpers for the 8-bit threshold computer.
3
- All multi-bit fields are MSB-first.
4
- """
5
-
6
- from __future__ import annotations
7
-
8
- from dataclasses import dataclass
9
- from typing import List
10
-
11
- FLAG_NAMES = ["Z", "N", "C", "V"]
12
- CTRL_NAMES = ["HALT", "MEM_WE", "MEM_RE", "RESERVED"]
13
-
14
- PC_BITS = 16
15
- IR_BITS = 16
16
- REG_BITS = 8
17
- REG_COUNT = 4
18
- FLAG_BITS = 4
19
- SP_BITS = 16
20
- CTRL_BITS = 4
21
- MEM_BYTES = 65536
22
- MEM_BITS = MEM_BYTES * 8
23
-
24
- STATE_BITS = PC_BITS + IR_BITS + (REG_BITS * REG_COUNT) + FLAG_BITS + SP_BITS + CTRL_BITS + MEM_BITS
25
-
26
-
27
- def int_to_bits(value: int, width: int) -> List[int]:
28
- return [(value >> (width - 1 - i)) & 1 for i in range(width)]
29
-
30
-
31
- def bits_to_int(bits: List[int]) -> int:
32
- value = 0
33
- for bit in bits:
34
- value = (value << 1) | int(bit)
35
- return value
36
-
37
-
38
- @dataclass
39
- class CPUState:
40
- pc: int
41
- ir: int
42
- regs: List[int]
43
- flags: List[int]
44
- sp: int
45
- ctrl: List[int]
46
- mem: List[int]
47
-
48
- def copy(self) -> "CPUState":
49
- return CPUState(
50
- pc=int(self.pc),
51
- ir=int(self.ir),
52
- regs=[int(r) for r in self.regs],
53
- flags=[int(f) for f in self.flags],
54
- sp=int(self.sp),
55
- ctrl=[int(c) for c in self.ctrl],
56
- mem=[int(m) for m in self.mem],
57
- )
58
-
59
-
60
- def pack_state(state: CPUState) -> List[int]:
61
- bits: List[int] = []
62
- bits.extend(int_to_bits(state.pc, PC_BITS))
63
- bits.extend(int_to_bits(state.ir, IR_BITS))
64
- for reg in state.regs:
65
- bits.extend(int_to_bits(reg, REG_BITS))
66
- bits.extend([int(f) for f in state.flags])
67
- bits.extend(int_to_bits(state.sp, SP_BITS))
68
- bits.extend([int(c) for c in state.ctrl])
69
- for byte in state.mem:
70
- bits.extend(int_to_bits(byte, REG_BITS))
71
- return bits
72
-
73
-
74
- def unpack_state(bits: List[int]) -> CPUState:
75
- if len(bits) != STATE_BITS:
76
- raise ValueError(f"Expected {STATE_BITS} bits, got {len(bits)}")
77
-
78
- idx = 0
79
- pc = bits_to_int(bits[idx:idx + PC_BITS])
80
- idx += PC_BITS
81
- ir = bits_to_int(bits[idx:idx + IR_BITS])
82
- idx += IR_BITS
83
-
84
- regs = []
85
- for _ in range(REG_COUNT):
86
- regs.append(bits_to_int(bits[idx:idx + REG_BITS]))
87
- idx += REG_BITS
88
-
89
- flags = [int(b) for b in bits[idx:idx + FLAG_BITS]]
90
- idx += FLAG_BITS
91
-
92
- sp = bits_to_int(bits[idx:idx + SP_BITS])
93
- idx += SP_BITS
94
-
95
- ctrl = [int(b) for b in bits[idx:idx + CTRL_BITS]]
96
- idx += CTRL_BITS
97
-
98
- mem = []
99
- for _ in range(MEM_BYTES):
100
- mem.append(bits_to_int(bits[idx:idx + REG_BITS]))
101
- idx += REG_BITS
102
-
103
- return CPUState(pc=pc, ir=ir, regs=regs, flags=flags, sp=sp, ctrl=ctrl, mem=mem)