CharlesCNorton committed on
Commit
3a45f0c
·
1 Parent(s): d15242d

Fold threshold_cpu.py into eval.py

Browse files

- Move CPU runtime (ThresholdALU, ThresholdCPU, CPUState) into eval.py
- Add --cpu-test CLI flag for smoke test
- Delete threshold_cpu.py
- eval.py now serves as single entry point for both circuit evaluation and CPU execution

Files changed (2) hide show
  1. eval.py +812 -1
  2. threshold_cpu.py +0 -842
eval.py CHANGED
@@ -2,14 +2,17 @@
2
  Unified Evaluation Suite for 8-bit Threshold Computer
3
  ======================================================
4
  GPU-batched evaluation with per-circuit reporting.
 
5
 
6
  Usage:
7
- python eval.py # Run evaluation
8
  python eval.py --device cpu # CPU mode
9
  python eval.py --pop_size 1000 # Population mode for evolution
 
10
 
11
  API (for prune_weights.py):
12
  from eval import load_model, create_population, BatchedFitnessEvaluator
 
13
  """
14
 
15
  import argparse
@@ -76,6 +79,810 @@ def create_population(
76
  }
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  class BatchedFitnessEvaluator:
80
  """
81
  GPU-batched fitness evaluator with per-circuit reporting.
@@ -2698,8 +3505,12 @@ def main():
2698
  parser.add_argument('--device', type=str, default='cuda', help='Device: cuda or cpu')
2699
  parser.add_argument('--pop_size', type=int, default=1, help='Population size for batched evaluation')
2700
  parser.add_argument('--quiet', action='store_true', help='Suppress detailed output')
 
2701
  args = parser.parse_args()
2702
 
 
 
 
2703
  print("=" * 70)
2704
  print(" UNIFIED EVALUATION SUITE")
2705
  print("=" * 70)
 
2
  Unified Evaluation Suite for 8-bit Threshold Computer
3
  ======================================================
4
  GPU-batched evaluation with per-circuit reporting.
5
+ Includes CPU runtime for threshold-weight execution.
6
 
7
  Usage:
8
+ python eval.py # Run circuit evaluation
9
  python eval.py --device cpu # CPU mode
10
  python eval.py --pop_size 1000 # Population mode for evolution
11
+ python eval.py --cpu-test # Run CPU smoke test
12
 
13
  API (for prune_weights.py):
14
  from eval import load_model, create_population, BatchedFitnessEvaluator
15
+ from eval import ThresholdCPU, ThresholdALU, CPUState
16
  """
17
 
18
  import argparse
 
79
  }
80
 
81
 
82
# =============================================================================
# CPU RUNTIME
# =============================================================================

# Human-readable names for the 4 status-flag bits, index-aligned with
# CPUState.flags: Z(ero), N(egative), C(arry), (o)V(erflow).
FLAG_NAMES = ["Z", "N", "C", "V"]
# Names for the 4 control bits, index-aligned with CPUState.ctrl.
# Only ctrl[0] (HALT) is read by the step functions below.
CTRL_NAMES = ["HALT", "MEM_WE", "MEM_RE", "RESERVED"]

# Bit widths of the architectural state fields (all multi-bit fields MSB-first).
PC_BITS = 16            # program counter
IR_BITS = 16            # instruction register
REG_BITS = 8            # width of each general-purpose register / memory byte
REG_COUNT = 4           # number of general-purpose registers (R0..R3)
FLAG_BITS = 4           # Z, N, C, V
SP_BITS = 16            # stack pointer
CTRL_BITS = 4           # HALT, MEM_WE, MEM_RE, RESERVED
MEM_BYTES = 65536       # full 16-bit address space
MEM_BITS = MEM_BYTES * 8

# Total number of bits in a flat packed state vector (see pack_state/unpack_state).
STATE_BITS = PC_BITS + IR_BITS + (REG_BITS * REG_COUNT) + FLAG_BITS + SP_BITS + CTRL_BITS + MEM_BITS
100
+
101
+
102
def int_to_bits(value: int, width: int) -> List[int]:
    """Return *value* as a list of *width* bits, most-significant bit first."""
    return [(value >> shift) & 1 for shift in range(width - 1, -1, -1)]
104
+
105
+
106
def bits_to_int(bits: List[int]) -> int:
    """Fold an MSB-first bit sequence back into a non-negative integer."""
    total = 0
    for b in bits:
        total = total * 2 + int(b)
    return total
111
+
112
+
113
def bits_msb_to_lsb(bits: List[int]) -> List[int]:
    """Return a new list with the bit order reversed (MSB-first -> LSB-first)."""
    return bits[::-1]
115
+
116
+
117
@dataclass
class CPUState:
    """Complete architectural state of the 8-bit threshold computer.

    All fields hold plain Python ints; multi-bit fields are interpreted
    MSB-first when packed to/from a flat bit vector.
    """

    pc: int           # 16-bit program counter
    ir: int           # 16-bit instruction register
    regs: List[int]   # REG_COUNT general-purpose 8-bit registers
    flags: List[int]  # Z, N, C, V bits
    sp: int           # 16-bit stack pointer
    ctrl: List[int]   # control bits; ctrl[0] is HALT
    mem: List[int]    # MEM_BYTES bytes of memory

    def copy(self) -> 'CPUState':
        """Return an independent copy (fresh lists, all values coerced to int)."""
        return CPUState(
            pc=int(self.pc),
            ir=int(self.ir),
            regs=list(map(int, self.regs)),
            flags=list(map(int, self.flags)),
            sp=int(self.sp),
            ctrl=list(map(int, self.ctrl)),
            mem=list(map(int, self.mem)),
        )
137
+
138
+
139
def pack_state(state: CPUState) -> List[int]:
    """Flatten a CPUState into one MSB-first bit list of length STATE_BITS.

    Field order: pc, ir, regs, flags, sp, ctrl, mem — the inverse of
    unpack_state.
    """
    out: List[int] = []
    out += int_to_bits(state.pc, PC_BITS)
    out += int_to_bits(state.ir, IR_BITS)
    for value in state.regs:
        out += int_to_bits(value, REG_BITS)
    out += [int(flag) for flag in state.flags]
    out += int_to_bits(state.sp, SP_BITS)
    out += [int(line) for line in state.ctrl]
    for byte in state.mem:
        out += int_to_bits(byte, REG_BITS)
    return out
151
+
152
+
153
def unpack_state(bits: List[int]) -> CPUState:
    """Decode a flat MSB-first bit list of length STATE_BITS into a CPUState.

    Raises:
        ValueError: if *bits* does not have exactly STATE_BITS entries.
    """
    if len(bits) != STATE_BITS:
        raise ValueError(f"Expected {STATE_BITS} bits, got {len(bits)}")

    cursor = 0

    def take(width: int) -> List[int]:
        # Consume the next *width* bits from the flat vector.
        nonlocal cursor
        chunk = bits[cursor:cursor + width]
        cursor += width
        return chunk

    pc = bits_to_int(take(PC_BITS))
    ir = bits_to_int(take(IR_BITS))
    regs = [bits_to_int(take(REG_BITS)) for _ in range(REG_COUNT)]
    flags = [int(b) for b in take(FLAG_BITS)]
    sp = bits_to_int(take(SP_BITS))
    ctrl = [int(b) for b in take(CTRL_BITS)]
    mem = [bits_to_int(take(REG_BITS)) for _ in range(MEM_BYTES)]

    return CPUState(pc=pc, ir=ir, regs=regs, flags=flags, sp=sp, ctrl=ctrl, mem=mem)
183
+
184
+
185
def decode_ir(ir: int) -> Tuple[int, int, int, int]:
    """Split a 16-bit instruction word into (opcode, rd, rs, imm8).

    Layout (MSB to LSB): opcode[15:12], rd[11:10], rs[9:8], imm8[7:0].
    """
    return ((ir >> 12) & 0xF, (ir >> 10) & 0x3, (ir >> 8) & 0x3, ir & 0xFF)
191
+
192
+
193
def flags_from_result(result: int, carry: int, overflow: int) -> Tuple[int, int, int, int]:
    """Derive the (Z, N, C, V) flag bits for an 8-bit ALU result."""
    zero = int(result == 0)
    negative = int(bool(result & 0x80))  # sign bit of the 8-bit value
    return zero, negative, int(bool(carry)), int(bool(overflow))
199
+
200
+
201
def alu_add(a: int, b: int) -> Tuple[int, int, int]:
    """8-bit add.

    Returns (result, carry_out, overflow): result is a + b mod 256,
    carry_out flags an unsigned wrap, overflow flags a signed (two's
    complement) wrap.
    """
    total = a + b
    result = total & 0xFF
    carry_out = int(total > 0xFF)
    # Signed overflow: both operands disagree in sign with the result.
    overflow = int(bool((a ^ result) & (b ^ result) & 0x80))
    return result, carry_out, overflow
207
+
208
+
209
def alu_sub(a: int, b: int) -> Tuple[int, int, int]:
    """8-bit subtract a - b.

    Returns (result, carry, overflow): carry is 1 when no borrow occurred
    (a >= b), overflow flags a signed (two's complement) wrap.
    """
    raw = (a - b) & 0x1FF
    result = raw & 0xFF
    carry_out = int(a >= b)
    # Signed overflow: operands differ in sign and result disagrees with a.
    overflow = int(bool((a ^ b) & (a ^ result) & 0x80))
    return result, carry_out, overflow
215
+
216
+
217
def ref_step(state: CPUState) -> CPUState:
    """Reference CPU cycle (pure Python arithmetic).

    Fetches, decodes, and executes one instruction, returning a new
    CPUState; the input state is never mutated. A halted CPU (ctrl[0] set)
    is returned unchanged.

    ISA summary (opcode in bits 15:12):
      0x0 ADD  0x1 SUB  0x2 AND  0x3 OR   0x4 XOR
      0x5 SHL  0x6 SHR  0x7 MUL  0x8 DIV  0x9 CMP
      0xA LOAD addr16   0xB STORE addr16  0xC JMP addr16
      0xD Jcc addr16    0xE CALL addr16   0xF HALT
    Opcodes 0xA-0xE carry a 16-bit big-endian address operand after the
    instruction word.
    """
    if state.ctrl[0] == 1:
        return state.copy()

    s = state.copy()

    # Fetch: 16-bit instruction, big-endian, with PC wrap at 64 KiB.
    hi = s.mem[s.pc]
    lo = s.mem[(s.pc + 1) & 0xFFFF]
    s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
    next_pc = (s.pc + 2) & 0xFFFF

    opcode, rd, rs, imm8 = decode_ir(s.ir)
    a = s.regs[rd]
    b = s.regs[rs]

    # Memory/control opcodes consume an extra 16-bit address operand.
    addr16 = None
    next_pc_ext = next_pc
    if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
        addr_hi = s.mem[next_pc]
        addr_lo = s.mem[(next_pc + 1) & 0xFFFF]
        addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
        next_pc_ext = (next_pc + 2) & 0xFFFF

    write_result = True
    result = a
    carry = 0
    overflow = 0

    if opcode == 0x0:  # ADD
        result, carry, overflow = alu_add(a, b)
    elif opcode == 0x1:  # SUB
        result, carry, overflow = alu_sub(a, b)
    elif opcode == 0x2:  # AND
        result = a & b
    elif opcode == 0x3:  # OR
        result = a | b
    elif opcode == 0x4:  # XOR
        result = a ^ b
    elif opcode == 0x5:  # SHL (logical, shifts in 0)
        result = (a << 1) & 0xFF
    elif opcode == 0x6:  # SHR (logical, shifts in 0)
        result = (a >> 1) & 0xFF
    elif opcode == 0x7:  # MUL (low byte of the product)
        result = (a * b) & 0xFF
    elif opcode == 0x8:  # DIV (quotient; divide-by-zero yields 0xFF)
        if b == 0:
            result = 0xFF
        else:
            result = a // b
    elif opcode == 0x9:  # CMP: SUB flags only, no register write
        result, carry, overflow = alu_sub(a, b)
        write_result = False
    elif opcode == 0xA:  # LOAD rd <- mem[addr16]
        result = s.mem[addr16]
    elif opcode == 0xB:  # STORE mem[addr16] <- rs
        s.mem[addr16] = b & 0xFF
        write_result = False
    elif opcode == 0xC:  # JMP
        s.pc = addr16 & 0xFFFF
        write_result = False
    elif opcode == 0xD:  # Jcc: imm8[2:0] selects the flag condition
        cond_type = imm8 & 0x7
        if cond_type == 0:    # JZ
            take_branch = s.flags[0] == 1
        elif cond_type == 1:  # JNZ
            take_branch = s.flags[0] == 0
        elif cond_type == 2:  # JC
            take_branch = s.flags[2] == 1
        elif cond_type == 3:  # JNC
            take_branch = s.flags[2] == 0
        elif cond_type == 4:  # JN
            take_branch = s.flags[1] == 1
        elif cond_type == 5:  # JP
            take_branch = s.flags[1] == 0
        elif cond_type == 6:  # JV
            take_branch = s.flags[3] == 1
        else:                 # JNV
            take_branch = s.flags[3] == 0
        if take_branch:
            s.pc = addr16 & 0xFFFF
        else:
            s.pc = next_pc_ext
        write_result = False
    elif opcode == 0xE:  # CALL: push 16-bit return address big-endian, SP grows down
        ret_addr = next_pc_ext & 0xFFFF
        s.sp = (s.sp - 1) & 0xFFFF
        s.mem[s.sp] = (ret_addr >> 8) & 0xFF
        s.sp = (s.sp - 1) & 0xFFFF
        s.mem[s.sp] = ret_addr & 0xFF
        s.pc = addr16 & 0xFFFF
        write_result = False
    elif opcode == 0xF:  # HALT
        s.ctrl[0] = 1
        write_result = False

    # ALU ops (0x0-0x9) and LOAD (0xA) update the flags. Simplified from the
    # redundant `opcode <= 0x9 or opcode in (0xA, 0x7, 0x8)` — 0x7/0x8 are
    # already covered by <= 0x9 — matching ThresholdCPU.step's condition.
    if opcode <= 0xA:
        s.flags = list(flags_from_result(result, carry, overflow))

    if write_result:
        s.regs[rd] = result & 0xFF

    # Control-flow opcodes already set PC themselves.
    if opcode not in (0xC, 0xD, 0xE):
        s.pc = next_pc_ext

    return s
323
+
324
+
325
def ref_run_until_halt(state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
    """Run ref_step repeatedly until HALT or the cycle budget is exhausted.

    Returns (final_state, cycles_executed); cycles_executed equals
    max_cycles when the CPU never halts within the budget.
    """
    current = state.copy()
    cycles = 0
    while cycles < max_cycles:
        if current.ctrl[0] == 1:
            return current, cycles
        current = ref_step(current)
        cycles += 1
    return current, max_cycles
333
+
334
+
335
class ThresholdALU:
    """8-bit ALU evaluated gate-by-gate from threshold-neuron weights.

    Loads a safetensors model whose tensors encode individual threshold
    gates (weight + bias pairs); each gate fires when the weighted input
    sum plus bias crosses zero (via the `heaviside` helper). Tensor key
    layout is fixed by the trained model — assumed, not verifiable from
    this file alone.
    """

    def __init__(self, model_path: str = MODEL_PATH, device: str = "cpu") -> None:
        self.device = device
        # All gate tensors, keyed by dotted circuit path, cast to float32.
        self.tensors = {k: v.float().to(device) for k, v in load_model(model_path).items()}

    def _get(self, name: str) -> torch.Tensor:
        # Fetch one gate tensor by its dotted key; raises KeyError if absent.
        return self.tensors[name]

    def _eval_gate(self, weight_key: str, bias_key: str, inputs: List[float]) -> float:
        """Evaluate a single threshold gate: H(w . x + b), returned as 0.0/1.0."""
        w = self._get(weight_key)
        b = self._get(bias_key)
        inp = torch.tensor(inputs, device=self.device)
        return heaviside((inp * w).sum() + b).item()

    def _eval_xor(self, prefix: str, inputs: List[float]) -> float:
        """Evaluate a two-layer XOR circuit: XOR = AND(OR(a,b), NAND(a,b))."""
        inp = torch.tensor(inputs, device=self.device)
        w_or = self._get(f"{prefix}.layer1.or.weight")
        b_or = self._get(f"{prefix}.layer1.or.bias")
        w_nand = self._get(f"{prefix}.layer1.nand.weight")
        b_nand = self._get(f"{prefix}.layer1.nand.bias")
        w2 = self._get(f"{prefix}.layer2.weight")
        b2 = self._get(f"{prefix}.layer2.bias")

        h_or = heaviside((inp * w_or).sum() + b_or).item()
        h_nand = heaviside((inp * w_nand).sum() + b_nand).item()
        hidden = torch.tensor([h_or, h_nand], device=self.device)
        return heaviside((hidden * w2).sum() + b2).item()

    def _eval_full_adder(self, prefix: str, a: float, b: float, cin: float) -> Tuple[float, float]:
        """One full adder built from two half adders plus a carry OR.

        Returns (sum_bit, carry_out).
        """
        ha1_sum = self._eval_xor(f"{prefix}.ha1.sum", [a, b])
        ha1_carry = self._eval_gate(f"{prefix}.ha1.carry.weight", f"{prefix}.ha1.carry.bias", [a, b])

        ha2_sum = self._eval_xor(f"{prefix}.ha2.sum", [ha1_sum, cin])
        ha2_carry = self._eval_gate(
            f"{prefix}.ha2.carry.weight", f"{prefix}.ha2.carry.bias", [ha1_sum, cin]
        )

        cout = self._eval_gate(f"{prefix}.carry_or.weight", f"{prefix}.carry_or.bias", [ha1_carry, ha2_carry])
        return ha2_sum, cout

    def add(self, a: int, b: int) -> Tuple[int, int, int]:
        """Ripple-carry 8-bit add through the fa0..fa7 circuits.

        Returns (result, carry_out, overflow); overflow is computed
        arithmetically, not by a gate.
        """
        a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))

        carry = 0.0
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            # LSB-first ripple: fa0 handles bit 0, carry chains upward.
            sum_bit, carry = self._eval_full_adder(
                f"arithmetic.ripplecarry8bit.fa{bit}", float(a_bits[bit]), float(b_bits[bit]), carry
            )
            sum_bits.append(int(sum_bit))

        result = bits_to_int(list(reversed(sum_bits)))
        carry_out = int(carry)
        overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def sub(self, a: int, b: int) -> Tuple[int, int, int]:
        """8-bit subtract via two's complement: a + NOT(b) + 1, gate by gate."""
        a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))

        # Initial carry-in of 1 supplies the "+1" of the two's complement.
        carry = 1.0
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            notb = self._eval_gate(
                f"arithmetic.sub8bit.notb{bit}.weight",
                f"arithmetic.sub8bit.notb{bit}.bias",
                [float(b_bits[bit])],
            )

            # Full adder expanded inline: sum = a XOR notb XOR cin.
            xor1 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor1", [float(a_bits[bit]), notb])
            xor2 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor2", [xor1, carry])

            and1 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and1.weight",
                f"arithmetic.sub8bit.fa{bit}.and1.bias",
                [float(a_bits[bit]), notb],
            )
            and2 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and2.weight",
                f"arithmetic.sub8bit.fa{bit}.and2.bias",
                [xor1, carry],
            )
            carry = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.or_carry.weight",
                f"arithmetic.sub8bit.fa{bit}.or_carry.bias",
                [and1, and2],
            )

            sum_bits.append(int(xor2))

        result = bits_to_int(list(reversed(sum_bits)))
        carry_out = int(carry)
        overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def bitwise_and(self, a: int, b: int) -> int:
        """Bitwise AND via 8 independent 2-input gates.

        Assumes the weight tensor packs gates as flat pairs [a_i, b_i]
        per bit (hence the bit*2 slicing) — layout fixed by the model.
        """
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        w = self._get("alu.alu8bit.and.weight")
        bias = self._get("alu.alu8bit.and.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def bitwise_or(self, a: int, b: int) -> int:
        """Bitwise OR via 8 independent 2-input gates (same layout as AND)."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        w = self._get("alu.alu8bit.or.weight")
        bias = self._get("alu.alu8bit.or.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def bitwise_not(self, a: int) -> int:
        """Bitwise NOT via 8 independent 1-input inverter gates."""
        a_bits = int_to_bits(a, REG_BITS)
        w = self._get("alu.alu8bit.not.weight")
        bias = self._get("alu.alu8bit.not.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit]).sum() + bias[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def bitwise_xor(self, a: int, b: int) -> int:
        """Bitwise XOR via 8 two-layer OR/NAND -> AND circuits."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)

        w_or = self._get("alu.alu8bit.xor.layer1.or.weight")
        b_or = self._get("alu.alu8bit.xor.layer1.or.bias")
        w_nand = self._get("alu.alu8bit.xor.layer1.nand.weight")
        b_nand = self._get("alu.alu8bit.xor.layer1.nand.bias")
        w2 = self._get("alu.alu8bit.xor.layer2.weight")
        b2 = self._get("alu.alu8bit.xor.layer2.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            h_or = heaviside((inp * w_or[bit * 2:bit * 2 + 2]).sum() + b_or[bit])
            h_nand = heaviside((inp * w_nand[bit * 2:bit * 2 + 2]).sum() + b_nand[bit])
            hidden = torch.stack([h_or, h_nand])
            out = heaviside((hidden * w2[bit * 2:bit * 2 + 2]).sum() + b2[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def shift_left(self, a: int) -> int:
        """Logical shift left by 1: each output bit routes from the next
        lower-significance input bit; bit 0 (index 7, MSB-first) gets 0."""
        a_bits = int_to_bits(a, REG_BITS)
        out_bits = []
        for bit in range(REG_BITS):
            w = self._get(f"alu.alu8bit.shl.bit{bit}.weight")
            bias = self._get(f"alu.alu8bit.shl.bit{bit}.bias")
            if bit < 7:
                inp = torch.tensor([float(a_bits[bit + 1])], device=self.device)
            else:
                # LSB position: shift in a constant 0.
                inp = torch.tensor([0.0], device=self.device)
            out = heaviside((inp * w).sum() + bias).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def shift_right(self, a: int) -> int:
        """Logical shift right by 1: mirror of shift_left; the MSB gets 0."""
        a_bits = int_to_bits(a, REG_BITS)
        out_bits = []
        for bit in range(REG_BITS):
            w = self._get(f"alu.alu8bit.shr.bit{bit}.weight")
            bias = self._get(f"alu.alu8bit.shr.bit{bit}.bias")
            if bit > 0:
                inp = torch.tensor([float(a_bits[bit - 1])], device=self.device)
            else:
                # MSB position: shift in a constant 0.
                inp = torch.tensor([0.0], device=self.device)
            out = heaviside((inp * w).sum() + bias).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def multiply(self, a: int, b: int) -> int:
        """8-bit multiply using partial product AND gates + shift-add."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)

        # Partial products: pp[i][j] = a_bit_i AND b_bit_j (MSB-first indices),
        # each evaluated through its own threshold gate.
        pp = [[0] * 8 for _ in range(8)]
        for i in range(8):
            for j in range(8):
                w = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.weight")
                bias = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.bias")
                inp = torch.tensor([float(a_bits[i]), float(b_bits[j])], device=self.device)
                pp[i][j] = int(heaviside((inp * w).sum() + bias).item())

        result = 0
        for j in range(8):
            if b_bits[j] == 0:
                continue
            row = 0
            for i in range(8):
                row |= (pp[i][j] << (7 - i))
            shifted = row << (7 - j)
            result, _, _ = self.add(result & 0xFF, shifted & 0xFF)
            # NOTE(review): this post-add adjustment folds overflow bits of
            # `shifted` back into the low byte; whether it matches
            # (a*b) & 0xFF for all inputs is not provable from this file —
            # verify against the reference MUL.
            if shifted > 255 or result > 255:
                result = (result + (shifted >> 8)) & 0xFF

        return result & 0xFF

    def divide(self, a: int, b: int) -> Tuple[int, int]:
        """8-bit divide using restoring division with threshold gates."""
        # Divide-by-zero convention mirrors the reference CPU: 0xFF quotient.
        if b == 0:
            return 0xFF, a

        a_bits = int_to_bits(a, REG_BITS)

        quotient = 0
        remainder = 0

        for stage in range(8):
            # Shift in the next dividend bit (MSB first).
            remainder = ((remainder << 1) | a_bits[stage]) & 0xFF

            rem_bits = int_to_bits(remainder, REG_BITS)
            div_bits = int_to_bits(b, REG_BITS)

            # Per-stage comparator gate; presumably fires when
            # remainder >= divisor — layout fixed by the trained model.
            w = self._get(f"alu.alu8bit.div.stage{stage}.cmp.weight")
            bias = self._get(f"alu.alu8bit.div.stage{stage}.cmp.bias")
            inp = torch.tensor([float(rem_bits[i]) for i in range(8)] +
                               [float(div_bits[i]) for i in range(8)], device=self.device)
            cmp_result = int(heaviside((inp * w).sum() + bias).item())

            if cmp_result:
                remainder, _, _ = self.sub(remainder, b)
                quotient = (quotient << 1) | 1
            else:
                quotient = quotient << 1

        return quotient & 0xFF, remainder & 0xFF
578
+
579
+
580
class ThresholdCPU:
    """Full CPU cycle executed through threshold-neuron circuits.

    Mirrors ref_step's ISA exactly, but routes ALU ops, memory access, and
    conditional jumps through the gate tensors held by ThresholdALU.
    """

    def __init__(self, model_path: str = MODEL_PATH, device: str = "cpu") -> None:
        self.device = device
        self.alu = ThresholdALU(model_path, device=device)

    def _addr_decode(self, addr: int) -> torch.Tensor:
        """One-hot address decode over all MEM_BYTES lines.

        Assumes the decode weight is shaped (MEM_BYTES, PC_BITS) so the
        dim=1 sum scores each memory line against the 16 address bits —
        layout fixed by the trained model.
        """
        bits = torch.tensor(int_to_bits(addr, PC_BITS), device=self.device, dtype=torch.float32)
        w = self.alu._get("memory.addr_decode.weight")
        b = self.alu._get("memory.addr_decode.bias")
        return heaviside((w * bits).sum(dim=1) + b)

    def _memory_read(self, mem: List[int], addr: int) -> int:
        """Read one byte through AND-mask + OR-reduce gates over all lines."""
        sel = self._addr_decode(addr)
        mem_bits = torch.tensor(
            [int_to_bits(byte, REG_BITS) for byte in mem],
            device=self.device,
            dtype=torch.float32,
        )
        and_w = self.alu._get("memory.read.and.weight")
        and_b = self.alu._get("memory.read.and.bias")
        or_w = self.alu._get("memory.read.or.weight")
        or_b = self.alu._get("memory.read.or.bias")

        out_bits: List[int] = []
        for bit in range(REG_BITS):
            # Gate each line's bit with its select line, then OR across lines.
            inp = torch.stack([mem_bits[:, bit], sel], dim=1)
            and_out = heaviside((inp * and_w[bit]).sum(dim=1) + and_b[bit])
            out_bit = heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item()
            out_bits.append(int(out_bit))

        return bits_to_int(out_bits)

    def _memory_write(self, mem: List[int], addr: int, value: int) -> List[int]:
        """Write one byte via select/mux gates; returns a whole new memory image.

        Per bit: new = (old AND NOT sel) OR (data AND sel), evaluated as
        threshold gates across all MEM_BYTES lines at once.
        """
        sel = self._addr_decode(addr)
        data_bits = torch.tensor(int_to_bits(value, REG_BITS), device=self.device, dtype=torch.float32)
        mem_bits = torch.tensor(
            [int_to_bits(byte, REG_BITS) for byte in mem],
            device=self.device,
            dtype=torch.float32,
        )

        sel_w = self.alu._get("memory.write.sel.weight")
        sel_b = self.alu._get("memory.write.sel.bias")
        nsel_w = self.alu._get("memory.write.nsel.weight").squeeze(1)
        nsel_b = self.alu._get("memory.write.nsel.bias")
        and_old_w = self.alu._get("memory.write.and_old.weight")
        and_old_b = self.alu._get("memory.write.and_old.bias")
        and_new_w = self.alu._get("memory.write.and_new.weight")
        and_new_b = self.alu._get("memory.write.and_new.bias")
        or_w = self.alu._get("memory.write.or.weight")
        or_b = self.alu._get("memory.write.or.bias")

        # Write-enable is hard-wired high here; the sel gate combines it with
        # the decoded address line.
        we = torch.ones_like(sel)
        sel_inp = torch.stack([sel, we], dim=1)
        write_sel = heaviside((sel_inp * sel_w).sum(dim=1) + sel_b)
        nsel = heaviside((write_sel * nsel_w) + nsel_b)

        new_mem_bits = torch.zeros((MEM_BYTES, REG_BITS), device=self.device)
        for bit in range(REG_BITS):
            old_bit = mem_bits[:, bit]
            data_bit = data_bits[bit].expand(MEM_BYTES)
            inp_old = torch.stack([old_bit, nsel], dim=1)
            inp_new = torch.stack([data_bit, write_sel], dim=1)

            and_old = heaviside((inp_old * and_old_w[:, bit]).sum(dim=1) + and_old_b[:, bit])
            and_new = heaviside((inp_new * and_new_w[:, bit]).sum(dim=1) + and_new_b[:, bit])
            or_inp = torch.stack([and_old, and_new], dim=1)
            out_bit = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit])
            new_mem_bits[:, bit] = out_bit

        return [bits_to_int([int(b) for b in new_mem_bits[i].tolist()]) for i in range(MEM_BYTES)]

    def _conditional_jump_byte(self, prefix: str, pc_byte: int, target_byte: int, flag: int) -> int:
        """Per-bit 2:1 mux: output target_byte when flag selects the branch,
        else pc_byte. Built from NOT/AND/AND/OR gates per bit."""
        pc_bits = int_to_bits(pc_byte, REG_BITS)
        target_bits = int_to_bits(target_byte, REG_BITS)

        out_bits: List[int] = []
        for bit in range(REG_BITS):
            not_sel = self.alu._eval_gate(
                f"{prefix}.bit{bit}.not_sel.weight",
                f"{prefix}.bit{bit}.not_sel.bias",
                [float(flag)],
            )
            and_a = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_a.weight",
                f"{prefix}.bit{bit}.and_a.bias",
                [float(pc_bits[bit]), not_sel],
            )
            and_b = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_b.weight",
                f"{prefix}.bit{bit}.and_b.bias",
                [float(target_bits[bit]), float(flag)],
            )
            out_bit = self.alu._eval_gate(
                f"{prefix}.bit{bit}.or.weight",
                f"{prefix}.bit{bit}.or.bias",
                [and_a, and_b],
            )
            out_bits.append(int(out_bit))

        return bits_to_int(out_bits)

    def step(self, state: CPUState) -> CPUState:
        """Single CPU cycle using threshold neurons."""
        if state.ctrl[0] == 1:
            return state.copy()

        s = state.copy()

        # Fetch: 16-bit instruction, big-endian, through the gate memory.
        hi = self._memory_read(s.mem, s.pc)
        lo = self._memory_read(s.mem, (s.pc + 1) & 0xFFFF)
        s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
        next_pc = (s.pc + 2) & 0xFFFF

        opcode, rd, rs, imm8 = decode_ir(s.ir)
        a = s.regs[rd]
        b = s.regs[rs]

        # Memory/control opcodes consume a 16-bit address operand.
        addr16 = None
        next_pc_ext = next_pc
        if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
            addr_hi = self._memory_read(s.mem, next_pc)
            addr_lo = self._memory_read(s.mem, (next_pc + 1) & 0xFFFF)
            addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
            next_pc_ext = (next_pc + 2) & 0xFFFF

        write_result = True
        result = a
        carry = 0
        overflow = 0

        if opcode == 0x0:  # ADD
            result, carry, overflow = self.alu.add(a, b)
        elif opcode == 0x1:  # SUB
            result, carry, overflow = self.alu.sub(a, b)
        elif opcode == 0x2:  # AND
            result = self.alu.bitwise_and(a, b)
        elif opcode == 0x3:  # OR
            result = self.alu.bitwise_or(a, b)
        elif opcode == 0x4:  # XOR
            result = self.alu.bitwise_xor(a, b)
        elif opcode == 0x5:  # SHL
            result = self.alu.shift_left(a)
        elif opcode == 0x6:  # SHR
            result = self.alu.shift_right(a)
        elif opcode == 0x7:  # MUL
            result = self.alu.multiply(a, b)
        elif opcode == 0x8:  # DIV (quotient only; remainder discarded)
            result, _ = self.alu.divide(a, b)
        elif opcode == 0x9:  # CMP: flags only
            result, carry, overflow = self.alu.sub(a, b)
            write_result = False
        elif opcode == 0xA:  # LOAD
            result = self._memory_read(s.mem, addr16)
        elif opcode == 0xB:  # STORE
            s.mem = self._memory_write(s.mem, addr16, b & 0xFF)
            write_result = False
        elif opcode == 0xC:  # JMP
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xD:  # Jcc via per-condition mux circuits
            cond_type = imm8 & 0x7
            # (circuit prefix, flag index) per condition; the negated
            # conditions (jnz/jnc/jp/jnv) presumably invert inside the
            # circuit, since the raw flag is passed either way — verify
            # against the model.
            cond_circuits = [
                ("control.jz", 0),
                ("control.jnz", 0),
                ("control.jc", 2),
                ("control.jnc", 2),
                ("control.jn", 1),
                ("control.jp", 1),
                ("control.jv", 3),
                ("control.jnv", 3),
            ]
            circuit_prefix, flag_idx = cond_circuits[cond_type]
            hi_pc = self._conditional_jump_byte(
                circuit_prefix,
                (next_pc_ext >> 8) & 0xFF,
                (addr16 >> 8) & 0xFF,
                s.flags[flag_idx],
            )
            lo_pc = self._conditional_jump_byte(
                circuit_prefix,
                next_pc_ext & 0xFF,
                addr16 & 0xFF,
                s.flags[flag_idx],
            )
            s.pc = ((hi_pc & 0xFF) << 8) | (lo_pc & 0xFF)
            write_result = False
        elif opcode == 0xE:  # CALL: push return address big-endian, SP grows down
            ret_addr = next_pc_ext & 0xFFFF
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, (ret_addr >> 8) & 0xFF)
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, ret_addr & 0xFF)
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xF:  # HALT
            s.ctrl[0] = 1
            write_result = False

        # ALU ops and LOAD update the flags, matching ref_step.
        if opcode <= 0x9 or opcode == 0xA:
            s.flags = list(flags_from_result(result, carry, overflow))

        if write_result:
            s.regs[rd] = result & 0xFF

        # Control-flow opcodes already set PC themselves.
        if opcode not in (0xC, 0xD, 0xE):
            s.pc = next_pc_ext

        return s

    def run_until_halt(self, state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
        """Execute until HALT or max_cycles reached."""
        s = state.copy()
        for i in range(max_cycles):
            if s.ctrl[0] == 1:
                return s, i
            s = self.step(s)
        return s, max_cycles

    def forward(self, state_bits: torch.Tensor, max_cycles: int = 256) -> torch.Tensor:
        """Tensor-in, tensor-out interface for neural integration."""
        # Round-trips through CPUState; gradients do not flow through this.
        bits_list = [int(b) for b in state_bits.detach().cpu().flatten().tolist()]
        state = unpack_state(bits_list)
        final, _ = self.run_until_halt(state, max_cycles=max_cycles)
        return torch.tensor(pack_state(final), dtype=torch.float32)
805
+
806
+
807
def encode_instr(opcode: int, rd: int, rs: int, imm8: int) -> int:
    """Pack opcode/rd/rs/imm8 fields into one 16-bit instruction word."""
    word = (opcode & 0xF) << 12
    word |= (rd & 0x3) << 10
    word |= (rs & 0x3) << 8
    word |= imm8 & 0xFF
    return word
809
+
810
+
811
def write_instr(mem: List[int], addr: int, instr: int) -> None:
    """Store a 16-bit instruction big-endian at mem[addr], wrapping at 64 KiB."""
    hi, lo = (instr >> 8) & 0xFF, instr & 0xFF
    mem[addr & 0xFFFF] = hi
    mem[(addr + 1) & 0xFFFF] = lo
814
+
815
+
816
def write_addr(mem: List[int], addr: int, value: int) -> None:
    """Store a 16-bit value big-endian at mem[addr], wrapping at 64 KiB."""
    hi, lo = (value >> 8) & 0xFF, value & 0xFF
    mem[addr & 0xFFFF] = hi
    mem[(addr + 1) & 0xFFFF] = lo
819
+
820
+
821
def run_smoke_test() -> int:
    """Smoke test: LOAD 5, LOAD 7, ADD, STORE, HALT. Expect result = 12.

    Runs the same program on the pure-Python reference CPU, the
    threshold-weight CPU, and the tensor forward() interface, asserting
    all three agree. Returns 0 on success (CLI exit code); assertion
    failures propagate.
    """
    mem = [0] * 65536

    # Program (big-endian; LOAD/STORE carry a 16-bit address operand):
    #   0x0000: LOAD R0, [0x0100]
    #   0x0004: LOAD R1, [0x0101]
    #   0x0008: ADD  R0, R1
    #   0x000A: STORE [0x0102], R0
    #   0x000E: HALT
    write_instr(mem, 0x0000, encode_instr(0xA, 0, 0, 0x00))
    write_addr(mem, 0x0002, 0x0100)
    write_instr(mem, 0x0004, encode_instr(0xA, 1, 0, 0x00))
    write_addr(mem, 0x0006, 0x0101)
    write_instr(mem, 0x0008, encode_instr(0x0, 0, 1, 0x00))
    write_instr(mem, 0x000A, encode_instr(0xB, 0, 0, 0x00))
    write_addr(mem, 0x000C, 0x0102)
    write_instr(mem, 0x000E, encode_instr(0xF, 0, 0, 0x00))

    # Input operands.
    mem[0x0100] = 5
    mem[0x0101] = 7

    state = CPUState(
        pc=0,
        ir=0,
        regs=[0, 0, 0, 0],
        flags=[0, 0, 0, 0],
        sp=0xFFFE,
        ctrl=[0, 0, 0, 0],
        mem=mem,
    )

    # 1) Pure-Python reference run.
    print("Running reference implementation...")
    final, cycles = ref_run_until_halt(state, max_cycles=20)

    assert final.ctrl[0] == 1, "HALT flag not set"
    assert final.regs[0] == 12, f"R0 expected 12, got {final.regs[0]}"
    assert final.mem[0x0102] == 12, f"MEM[0x0102] expected 12, got {final.mem[0x0102]}"
    assert cycles <= 10, f"Unexpected cycle count: {cycles}"
    print(f" Reference: R0={final.regs[0]}, MEM[0x0102]={final.mem[0x0102]}, cycles={cycles}")

    # 2) Threshold-weight run must match the reference exactly.
    print("Running threshold-weight implementation...")
    threshold_cpu = ThresholdCPU()
    t_final, t_cycles = threshold_cpu.run_until_halt(state, max_cycles=20)

    assert t_final.ctrl[0] == 1, "Threshold HALT flag not set"
    assert t_final.regs[0] == final.regs[0], f"Threshold R0 mismatch: {t_final.regs[0]} != {final.regs[0]}"
    assert t_final.mem[0x0102] == final.mem[0x0102], (
        f"Threshold MEM[0x0102] mismatch: {t_final.mem[0x0102]} != {final.mem[0x0102]}"
    )
    assert t_cycles == cycles, f"Threshold cycle count mismatch: {t_cycles} != {cycles}"
    print(f" Threshold: R0={t_final.regs[0]}, MEM[0x0102]={t_final.mem[0x0102]}, cycles={t_cycles}")

    # 3) Tensor pack/unpack round trip through forward().
    print("Validating forward() tensor I/O...")
    bits = torch.tensor(pack_state(state), dtype=torch.float32)
    out_bits = threshold_cpu.forward(bits, max_cycles=20)
    out_state = unpack_state([int(b) for b in out_bits.tolist()])
    assert out_state.regs[0] == final.regs[0], f"Forward R0 mismatch: {out_state.regs[0]} != {final.regs[0]}"
    assert out_state.mem[0x0102] == final.mem[0x0102], (
        f"Forward MEM[0x0102] mismatch: {out_state.mem[0x0102]} != {final.mem[0x0102]}"
    )
    print(f" Forward: R0={out_state.regs[0]}, MEM[0x0102]={out_state.mem[0x0102]}")

    print("\nSmoke test: PASSED")
    return 0
880
+
881
+
882
+ # =============================================================================
883
+ # CIRCUIT EVALUATION
884
+ # =============================================================================
885
+
886
  class BatchedFitnessEvaluator:
887
  """
888
  GPU-batched fitness evaluator with per-circuit reporting.
 
3505
  parser.add_argument('--device', type=str, default='cuda', help='Device: cuda or cpu')
3506
  parser.add_argument('--pop_size', type=int, default=1, help='Population size for batched evaluation')
3507
  parser.add_argument('--quiet', action='store_true', help='Suppress detailed output')
3508
+ parser.add_argument('--cpu-test', action='store_true', help='Run CPU smoke test (LOAD, ADD, STORE, HALT)')
3509
  args = parser.parse_args()
3510
 
3511
+ if args.cpu_test:
3512
+ return run_smoke_test()
3513
+
3514
  print("=" * 70)
3515
  print(" UNIFIED EVALUATION SUITE")
3516
  print("=" * 70)
threshold_cpu.py DELETED
@@ -1,842 +0,0 @@
1
- """
2
- 8-bit Threshold Computer - CPU Runtime
3
-
4
- State layout, reference cycle, and threshold-weight execution.
5
- All multi-bit fields are MSB-first.
6
-
7
- Usage:
8
- python threshold_cpu.py # Run smoke test
9
- python threshold_cpu.py --help # Show options
10
- """
11
-
12
- from __future__ import annotations
13
-
14
- import argparse
15
- from dataclasses import dataclass
16
- from pathlib import Path
17
- from typing import List, Tuple
18
-
19
- import torch
20
- from safetensors.torch import load_file
21
-
22
-
23
- FLAG_NAMES = ["Z", "N", "C", "V"]
24
- CTRL_NAMES = ["HALT", "MEM_WE", "MEM_RE", "RESERVED"]
25
-
26
- PC_BITS = 16
27
- IR_BITS = 16
28
- REG_BITS = 8
29
- REG_COUNT = 4
30
- FLAG_BITS = 4
31
- SP_BITS = 16
32
- CTRL_BITS = 4
33
- MEM_BYTES = 65536
34
- MEM_BITS = MEM_BYTES * 8
35
-
36
- STATE_BITS = PC_BITS + IR_BITS + (REG_BITS * REG_COUNT) + FLAG_BITS + SP_BITS + CTRL_BITS + MEM_BITS
37
-
38
- DEFAULT_MODEL_PATH = Path(__file__).resolve().parent / "neural_computer.safetensors"
39
-
40
-
41
def int_to_bits(value: int, width: int) -> List[int]:
    """Return *value* as a list of *width* bits, MSB first."""
    return [(value >> shift) & 1 for shift in range(width - 1, -1, -1)]
43
-
44
-
45
def bits_to_int(bits: List[int]) -> int:
    """Fold an MSB-first bit list back into an integer."""
    acc = 0
    for b in bits:
        acc = acc * 2 + int(b)
    return acc
50
-
51
-
52
def bits_msb_to_lsb(bits: List[int]) -> List[int]:
    """Return a new list with the bit order reversed (MSB-first -> LSB-first)."""
    return bits[::-1]
54
-
55
-
56
@dataclass
class CPUState:
    """Full architectural state of the 8-bit machine.

    Multi-bit fields are plain ints; flags/ctrl are bit lists; mem is a
    list of byte values.
    """

    pc: int
    ir: int
    regs: List[int]
    flags: List[int]
    sp: int
    ctrl: List[int]
    mem: List[int]

    def copy(self) -> CPUState:
        """Deep copy: every list is duplicated, every scalar coerced to int."""
        return CPUState(
            pc=int(self.pc),
            ir=int(self.ir),
            regs=list(map(int, self.regs)),
            flags=list(map(int, self.flags)),
            sp=int(self.sp),
            ctrl=list(map(int, self.ctrl)),
            mem=list(map(int, self.mem)),
        )
76
-
77
-
78
def pack_state(state: CPUState) -> List[int]:
    """Serialize *state* into the flat MSB-first bit vector (STATE_BITS long)."""
    out: List[int] = []
    out += int_to_bits(state.pc, PC_BITS)
    out += int_to_bits(state.ir, IR_BITS)
    for value in state.regs:
        out += int_to_bits(value, REG_BITS)
    out += [int(flag) for flag in state.flags]
    out += int_to_bits(state.sp, SP_BITS)
    out += [int(line) for line in state.ctrl]
    for byte in state.mem:
        out += int_to_bits(byte, REG_BITS)
    return out
90
-
91
-
92
def unpack_state(bits: List[int]) -> CPUState:
    """Inverse of pack_state: rebuild a CPUState from a flat MSB-first bit vector."""
    if len(bits) != STATE_BITS:
        raise ValueError(f"Expected {STATE_BITS} bits, got {len(bits)}")

    pos = 0

    def take(n: int) -> List[int]:
        # Consume the next n bits of the vector.
        nonlocal pos
        chunk = bits[pos:pos + n]
        pos += n
        return chunk

    pc = bits_to_int(take(PC_BITS))
    ir = bits_to_int(take(IR_BITS))
    regs = [bits_to_int(take(REG_BITS)) for _ in range(REG_COUNT)]
    flags = [int(b) for b in take(FLAG_BITS)]
    sp = bits_to_int(take(SP_BITS))
    ctrl = [int(b) for b in take(CTRL_BITS)]
    mem = [bits_to_int(take(REG_BITS)) for _ in range(MEM_BYTES)]

    return CPUState(pc=pc, ir=ir, regs=regs, flags=flags, sp=sp, ctrl=ctrl, mem=mem)
122
-
123
-
124
def decode_ir(ir: int) -> Tuple[int, int, int, int]:
    """Split a 16-bit instruction word into (opcode, rd, rs, imm8) fields."""
    return (ir >> 12) & 0xF, (ir >> 10) & 0x3, (ir >> 8) & 0x3, ir & 0xFF
130
-
131
-
132
def flags_from_result(result: int, carry: int, overflow: int) -> Tuple[int, int, int, int]:
    """Derive the (Z, N, C, V) flag tuple from an 8-bit ALU result."""
    zero = int(result == 0)
    negative = 1 if result & 0x80 else 0
    return zero, negative, int(bool(carry)), int(bool(overflow))
138
-
139
-
140
def alu_add(a: int, b: int) -> Tuple[int, int, int]:
    """8-bit add: return (result, carry_out, signed_overflow)."""
    total = a + b
    result = total % 256
    carry = int(total >= 256)
    # Signed overflow: both operands' sign bits disagree with the result's.
    overflow = int(bool((a ^ result) & (b ^ result) & 0x80))
    return result, carry, overflow
146
-
147
-
148
def alu_sub(a: int, b: int) -> Tuple[int, int, int]:
    """8-bit subtract: return (result, carry, signed_overflow).

    Carry follows the "no borrow" convention: set when a >= b.
    """
    result = (a - b) % 256
    carry = int(a >= b)
    overflow = int(bool((a ^ b) & (a ^ result) & 0x80))
    return result, carry, overflow
154
-
155
-
156
def ref_step(state: CPUState) -> CPUState:
    """Reference CPU cycle (pure Python arithmetic).

    Executes exactly one fetch/decode/execute cycle and returns a NEW
    CPUState; the input state is never mutated. A halted machine
    (ctrl[0] == 1) is returned unchanged.
    """
    if state.ctrl[0] == 1:
        return state.copy()

    s = state.copy()

    # Fetch: 16-bit instruction word, big-endian, at PC.
    hi = s.mem[s.pc]
    lo = s.mem[(s.pc + 1) & 0xFFFF]
    s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
    next_pc = (s.pc + 2) & 0xFFFF

    opcode, rd, rs, imm8 = decode_ir(s.ir)
    a = s.regs[rd]
    b = s.regs[rs]

    # Memory/control opcodes (LOAD/STORE/JMP/Jcc/CALL) carry an extra
    # 16-bit big-endian address operand in the two bytes after the word.
    addr16 = None
    next_pc_ext = next_pc
    if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
        addr_hi = s.mem[next_pc]
        addr_lo = s.mem[(next_pc + 1) & 0xFFFF]
        addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
        next_pc_ext = (next_pc + 2) & 0xFFFF

    write_result = True
    result = a
    carry = 0
    overflow = 0

    if opcode == 0x0:  # ADD rd, rs
        result, carry, overflow = alu_add(a, b)
    elif opcode == 0x1:  # SUB rd, rs
        result, carry, overflow = alu_sub(a, b)
    elif opcode == 0x2:  # AND
        result = a & b
    elif opcode == 0x3:  # OR
        result = a | b
    elif opcode == 0x4:  # XOR
        result = a ^ b
    elif opcode == 0x5:  # SHL (logical, by one)
        result = (a << 1) & 0xFF
    elif opcode == 0x6:  # SHR (logical, by one)
        result = (a >> 1) & 0xFF
    elif opcode == 0x7:  # MUL (low byte only)
        result = (a * b) & 0xFF
    elif opcode == 0x8:  # DIV (quotient; divide-by-zero yields 0xFF)
        if b == 0:
            result = 0xFF
        else:
            result = a // b
    elif opcode == 0x9:  # CMP: SUB for flags only, result discarded
        result, carry, overflow = alu_sub(a, b)
        write_result = False
    elif opcode == 0xA:  # LOAD rd, [addr16]
        result = s.mem[addr16]
    elif opcode == 0xB:  # STORE [addr16], rs
        s.mem[addr16] = b & 0xFF
        write_result = False
    elif opcode == 0xC:  # JMP addr16
        s.pc = addr16 & 0xFFFF
        write_result = False
    elif opcode == 0xD:  # Jcc addr16 — condition selected by imm8 low bits
        cond_type = imm8 & 0x7
        # Flag indices: 0=Z, 1=N, 2=C, 3=V (see FLAG_NAMES).
        if cond_type == 0:
            take_branch = s.flags[0] == 1
        elif cond_type == 1:
            take_branch = s.flags[0] == 0
        elif cond_type == 2:
            take_branch = s.flags[2] == 1
        elif cond_type == 3:
            take_branch = s.flags[2] == 0
        elif cond_type == 4:
            take_branch = s.flags[1] == 1
        elif cond_type == 5:
            take_branch = s.flags[1] == 0
        elif cond_type == 6:
            take_branch = s.flags[3] == 1
        else:
            take_branch = s.flags[3] == 0
        if take_branch:
            s.pc = addr16 & 0xFFFF
        else:
            s.pc = next_pc_ext
        write_result = False
    elif opcode == 0xE:  # CALL addr16: push return address big-endian, jump
        ret_addr = next_pc_ext & 0xFFFF
        s.sp = (s.sp - 1) & 0xFFFF
        s.mem[s.sp] = (ret_addr >> 8) & 0xFF
        s.sp = (s.sp - 1) & 0xFFFF
        s.mem[s.sp] = ret_addr & 0xFF
        s.pc = addr16 & 0xFFFF
        write_result = False
    elif opcode == 0xF:  # HALT: latch the halt control line
        s.ctrl[0] = 1
        write_result = False

    # Flags update for ALU ops and LOAD (0x7/0x8 are already <= 0x9, so
    # this condition is effectively opcode <= 0xA).
    if opcode <= 0x9 or opcode in (0xA, 0x7, 0x8):
        s.flags = list(flags_from_result(result, carry, overflow))

    if write_result:
        s.regs[rd] = result & 0xFF

    # Control-flow opcodes already set PC themselves.
    if opcode not in (0xC, 0xD, 0xE):
        s.pc = next_pc_ext

    return s
262
-
263
-
264
def ref_run_until_halt(state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
    """Run the reference CPU until HALT; return (final_state, cycles_used).

    If the halt line never asserts, stops after *max_cycles* and reports
    max_cycles as the cycle count.
    """
    current = state.copy()
    cycles = 0
    while cycles < max_cycles:
        if current.ctrl[0] == 1:
            return current, cycles
        current = ref_step(current)
        cycles += 1
    return current, max_cycles
272
-
273
-
274
def heaviside(x: torch.Tensor) -> torch.Tensor:
    """Hard threshold activation: 1.0 where x >= 0, else 0.0."""
    return torch.ge(x, 0).float()
276
-
277
-
278
class ThresholdALU:
    """ALU built entirely from threshold (Heaviside) gates.

    Gate weights/biases are loaded from a safetensors checkpoint; each
    operation is computed by evaluating those tensors through `heaviside`,
    with Python used only for bookkeeping (bit packing, flags).

    BUG FIX vs. previous revision: `shift_left`/`shift_right` called
    `self.alu._get(...)` — but ThresholdALU has no `.alu` attribute (that
    lives on ThresholdCPU), so SHL/SHR raised AttributeError. They now
    call `self._get(...)` like every other method of this class.
    """

    def __init__(self, model_path: str, device: str = "cpu") -> None:
        self.device = device
        # Load every tensor once, cast to float, on the requested device.
        self.tensors = {k: v.float().to(device) for k, v in load_file(model_path).items()}

    def _get(self, name: str) -> torch.Tensor:
        """Look up a weight/bias tensor by its checkpoint key."""
        return self.tensors[name]

    def _eval_gate(self, weight_key: str, bias_key: str, inputs: List[float]) -> float:
        """Evaluate one threshold gate: H(w · x + b)."""
        w = self._get(weight_key)
        b = self._get(bias_key)
        inp = torch.tensor(inputs, device=self.device)
        return heaviside((inp * w).sum() + b).item()

    def _eval_xor(self, prefix: str, inputs: List[float]) -> float:
        """Evaluate a 2-layer XOR network (OR + NAND feeding an AND layer)."""
        inp = torch.tensor(inputs, device=self.device)
        w_or = self._get(f"{prefix}.layer1.or.weight")
        b_or = self._get(f"{prefix}.layer1.or.bias")
        w_nand = self._get(f"{prefix}.layer1.nand.weight")
        b_nand = self._get(f"{prefix}.layer1.nand.bias")
        w2 = self._get(f"{prefix}.layer2.weight")
        b2 = self._get(f"{prefix}.layer2.bias")

        h_or = heaviside((inp * w_or).sum() + b_or).item()
        h_nand = heaviside((inp * w_nand).sum() + b_nand).item()
        hidden = torch.tensor([h_or, h_nand], device=self.device)
        return heaviside((hidden * w2).sum() + b2).item()

    def _eval_full_adder(self, prefix: str, a: float, b: float, cin: float) -> Tuple[float, float]:
        """Evaluate a full adder built from two half adders plus a carry OR."""
        ha1_sum = self._eval_xor(f"{prefix}.ha1.sum", [a, b])
        ha1_carry = self._eval_gate(f"{prefix}.ha1.carry.weight", f"{prefix}.ha1.carry.bias", [a, b])

        ha2_sum = self._eval_xor(f"{prefix}.ha2.sum", [ha1_sum, cin])
        ha2_carry = self._eval_gate(
            f"{prefix}.ha2.carry.weight", f"{prefix}.ha2.carry.bias", [ha1_sum, cin]
        )

        cout = self._eval_gate(f"{prefix}.carry_or.weight", f"{prefix}.carry_or.bias", [ha1_carry, ha2_carry])
        return ha2_sum, cout

    def add(self, a: int, b: int) -> Tuple[int, int, int]:
        """8-bit ripple-carry add: return (result, carry_out, signed_overflow)."""
        a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))

        carry = 0.0
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            sum_bit, carry = self._eval_full_adder(
                f"arithmetic.ripplecarry8bit.fa{bit}", float(a_bits[bit]), float(b_bits[bit]), carry
            )
            sum_bits.append(int(sum_bit))

        result = bits_to_int(list(reversed(sum_bits)))
        carry_out = int(carry)
        # Overflow computed arithmetically from the gate-produced result.
        overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def sub(self, a: int, b: int) -> Tuple[int, int, int]:
        """8-bit subtract via two's complement (invert b, carry-in = 1)."""
        a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))

        carry = 1.0
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            notb = self._eval_gate(
                f"arithmetic.sub8bit.notb{bit}.weight",
                f"arithmetic.sub8bit.notb{bit}.bias",
                [float(b_bits[bit])],
            )

            xor1 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor1", [float(a_bits[bit]), notb])
            xor2 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor2", [xor1, carry])

            and1 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and1.weight",
                f"arithmetic.sub8bit.fa{bit}.and1.bias",
                [float(a_bits[bit]), notb],
            )
            and2 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and2.weight",
                f"arithmetic.sub8bit.fa{bit}.and2.bias",
                [xor1, carry],
            )
            carry = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.or_carry.weight",
                f"arithmetic.sub8bit.fa{bit}.or_carry.bias",
                [and1, and2],
            )

            sum_bits.append(int(xor2))

        result = bits_to_int(list(reversed(sum_bits)))
        carry_out = int(carry)
        overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def bitwise_and(self, a: int, b: int) -> int:
        """Bitwise AND through per-bit threshold gates."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        w = self._get("alu.alu8bit.and.weight")
        bias = self._get("alu.alu8bit.and.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def bitwise_or(self, a: int, b: int) -> int:
        """Bitwise OR through per-bit threshold gates."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        w = self._get("alu.alu8bit.or.weight")
        bias = self._get("alu.alu8bit.or.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def bitwise_not(self, a: int) -> int:
        """Bitwise NOT through per-bit threshold gates."""
        a_bits = int_to_bits(a, REG_BITS)
        w = self._get("alu.alu8bit.not.weight")
        bias = self._get("alu.alu8bit.not.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit]).sum() + bias[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def bitwise_xor(self, a: int, b: int) -> int:
        """Bitwise XOR: per-bit 2-layer (OR, NAND) -> AND network."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)

        w_or = self._get("alu.alu8bit.xor.layer1.or.weight")
        b_or = self._get("alu.alu8bit.xor.layer1.or.bias")
        w_nand = self._get("alu.alu8bit.xor.layer1.nand.weight")
        b_nand = self._get("alu.alu8bit.xor.layer1.nand.bias")
        w2 = self._get("alu.alu8bit.xor.layer2.weight")
        b2 = self._get("alu.alu8bit.xor.layer2.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            h_or = heaviside((inp * w_or[bit * 2:bit * 2 + 2]).sum() + b_or[bit])
            h_nand = heaviside((inp * w_nand[bit * 2:bit * 2 + 2]).sum() + b_nand[bit])
            hidden = torch.stack([h_or, h_nand])
            out = heaviside((hidden * w2[bit * 2:bit * 2 + 2]).sum() + b2[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def shift_left(self, a: int) -> int:
        """Logical shift left by one bit; the new LSB is 0."""
        a_bits = int_to_bits(a, REG_BITS)
        out_bits = []
        for bit in range(REG_BITS):
            # FIX: was `self.alu._get` (AttributeError — no .alu on this class).
            w = self._get(f"alu.alu8bit.shl.bit{bit}.weight")
            bias = self._get(f"alu.alu8bit.shl.bit{bit}.bias")
            if bit < 7:
                inp = torch.tensor([float(a_bits[bit + 1])], device=self.device)
            else:
                inp = torch.tensor([0.0], device=self.device)
            out = heaviside((inp * w).sum() + bias).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def shift_right(self, a: int) -> int:
        """Logical shift right by one bit; the new MSB is 0."""
        a_bits = int_to_bits(a, REG_BITS)
        out_bits = []
        for bit in range(REG_BITS):
            # FIX: was `self.alu._get` (AttributeError — no .alu on this class).
            w = self._get(f"alu.alu8bit.shr.bit{bit}.weight")
            bias = self._get(f"alu.alu8bit.shr.bit{bit}.bias")
            if bit > 0:
                inp = torch.tensor([float(a_bits[bit - 1])], device=self.device)
            else:
                inp = torch.tensor([0.0], device=self.device)
            out = heaviside((inp * w).sum() + bias).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def multiply(self, a: int, b: int) -> int:
        """8-bit multiply using partial product AND gates + shift-add."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)

        # Compute all 64 partial products using AND gates
        pp = [[0] * 8 for _ in range(8)]
        for i in range(8):
            for j in range(8):
                w = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.weight")
                bias = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.bias")
                inp = torch.tensor([float(a_bits[i]), float(b_bits[j])], device=self.device)
                pp[i][j] = int(heaviside((inp * w).sum() + bias).item())

        # Shift-add accumulation using existing 8-bit adder
        # Row j contributes A*B[j] shifted left by (7-j) positions
        result = 0
        for j in range(8):
            if b_bits[j] == 0:
                continue
            # Construct the partial product row (A masked by B[j])
            row = 0
            for i in range(8):
                row |= (pp[i][j] << (7 - i))
            # Shift by position (7-j means B[7] is LSB, B[0] is MSB)
            shifted = row << (7 - j)
            # Add to result using threshold adder
            result, _, _ = self.add(result & 0xFF, shifted & 0xFF)
            # Handle overflow into high byte
            if shifted > 255 or result > 255:
                result = (result + (shifted >> 8)) & 0xFF

        return result & 0xFF

    def divide(self, a: int, b: int) -> Tuple[int, int]:
        """8-bit divide using restoring division with threshold gates."""
        if b == 0:
            return 0xFF, a  # Division by zero: return max quotient, original dividend

        a_bits = int_to_bits(a, REG_BITS)

        quotient = 0
        remainder = 0

        for stage in range(8):
            # Shift remainder left and bring in next dividend bit
            remainder = ((remainder << 1) | a_bits[stage]) & 0xFF

            # Compare remainder >= divisor using threshold gate
            rem_bits = int_to_bits(remainder, REG_BITS)
            div_bits = int_to_bits(b, REG_BITS)

            w = self._get(f"alu.alu8bit.div.stage{stage}.cmp.weight")
            bias = self._get(f"alu.alu8bit.div.stage{stage}.cmp.bias")
            inp = torch.tensor([float(rem_bits[i]) for i in range(8)] +
                               [float(div_bits[i]) for i in range(8)], device=self.device)
            cmp_result = int(heaviside((inp * w).sum() + bias).item())

            # If remainder >= divisor, subtract and set quotient bit
            if cmp_result:
                remainder, _, _ = self.sub(remainder, b)
                quotient = (quotient << 1) | 1
            else:
                quotient = quotient << 1

        return quotient & 0xFF, remainder & 0xFF
531
-
532
-
533
class ThresholdCPU:
    """CPU whose datapath is executed through threshold-gate weights.

    Mirrors the semantics of `ref_step` but routes memory access, the
    ALU, and conditional jumps through the loaded safetensors circuits.
    """

    def __init__(self, model_path: str | Path = DEFAULT_MODEL_PATH, device: str = "cpu") -> None:
        self.device = device
        # All weights live in the ALU's tensor table; other circuits
        # (memory, control) are fetched through self.alu._get as well.
        self.alu = ThresholdALU(str(model_path), device=device)

    def _addr_decode(self, addr: int) -> torch.Tensor:
        """One-hot address decode: a 65536-long selector vector for *addr*."""
        bits = torch.tensor(int_to_bits(addr, PC_BITS), device=self.device, dtype=torch.float32)
        w = self.alu._get("memory.addr_decode.weight")
        b = self.alu._get("memory.addr_decode.bias")
        return heaviside((w * bits).sum(dim=1) + b)

    def _memory_read(self, mem: List[int], addr: int) -> int:
        """Read one byte via select-AND-then-OR gate trees over all of memory."""
        sel = self._addr_decode(addr)
        mem_bits = torch.tensor(
            [int_to_bits(byte, REG_BITS) for byte in mem],
            device=self.device,
            dtype=torch.float32,
        )
        and_w = self.alu._get("memory.read.and.weight")
        and_b = self.alu._get("memory.read.and.bias")
        or_w = self.alu._get("memory.read.or.weight")
        or_b = self.alu._get("memory.read.or.bias")

        out_bits: List[int] = []
        for bit in range(REG_BITS):
            # Mask every memory cell's bit with the selector, then OR-reduce.
            inp = torch.stack([mem_bits[:, bit], sel], dim=1)
            and_out = heaviside((inp * and_w[bit]).sum(dim=1) + and_b[bit])
            out_bit = heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item()
            out_bits.append(int(out_bit))

        return bits_to_int(out_bits)

    def _memory_write(self, mem: List[int], addr: int, value: int) -> List[int]:
        """Return a new memory image with *value* written at *addr*.

        Implemented as a gate-level mux: selected cell takes the data
        bits, every other cell keeps its old bits.
        """
        sel = self._addr_decode(addr)
        data_bits = torch.tensor(int_to_bits(value, REG_BITS), device=self.device, dtype=torch.float32)
        mem_bits = torch.tensor(
            [int_to_bits(byte, REG_BITS) for byte in mem],
            device=self.device,
            dtype=torch.float32,
        )

        sel_w = self.alu._get("memory.write.sel.weight")
        sel_b = self.alu._get("memory.write.sel.bias")
        nsel_w = self.alu._get("memory.write.nsel.weight").squeeze(1)
        nsel_b = self.alu._get("memory.write.nsel.bias")
        and_old_w = self.alu._get("memory.write.and_old.weight")
        and_old_b = self.alu._get("memory.write.and_old.bias")
        and_new_w = self.alu._get("memory.write.and_new.weight")
        and_new_b = self.alu._get("memory.write.and_new.bias")
        or_w = self.alu._get("memory.write.or.weight")
        or_b = self.alu._get("memory.write.or.bias")

        # Write-enable is asserted for the whole operation.
        we = torch.ones_like(sel)
        sel_inp = torch.stack([sel, we], dim=1)
        write_sel = heaviside((sel_inp * sel_w).sum(dim=1) + sel_b)
        nsel = heaviside((write_sel * nsel_w) + nsel_b)

        new_mem_bits = torch.zeros((MEM_BYTES, REG_BITS), device=self.device)
        for bit in range(REG_BITS):
            old_bit = mem_bits[:, bit]
            data_bit = data_bits[bit].expand(MEM_BYTES)
            inp_old = torch.stack([old_bit, nsel], dim=1)
            inp_new = torch.stack([data_bit, write_sel], dim=1)

            # keep = old & not-selected; take = data & selected; out = keep | take
            and_old = heaviside((inp_old * and_old_w[:, bit]).sum(dim=1) + and_old_b[:, bit])
            and_new = heaviside((inp_new * and_new_w[:, bit]).sum(dim=1) + and_new_b[:, bit])
            or_inp = torch.stack([and_old, and_new], dim=1)
            out_bit = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit])
            new_mem_bits[:, bit] = out_bit

        return [bits_to_int([int(b) for b in new_mem_bits[i].tolist()]) for i in range(MEM_BYTES)]

    def _conditional_jump_byte(self, prefix: str, pc_byte: int, target_byte: int, flag: int) -> int:
        """Mux one PC byte: target_byte when *flag* is set, else pc_byte."""
        pc_bits = int_to_bits(pc_byte, REG_BITS)
        target_bits = int_to_bits(target_byte, REG_BITS)

        out_bits: List[int] = []
        for bit in range(REG_BITS):
            not_sel = self.alu._eval_gate(
                f"{prefix}.bit{bit}.not_sel.weight",
                f"{prefix}.bit{bit}.not_sel.bias",
                [float(flag)],
            )
            and_a = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_a.weight",
                f"{prefix}.bit{bit}.and_a.bias",
                [float(pc_bits[bit]), not_sel],
            )
            and_b = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_b.weight",
                f"{prefix}.bit{bit}.and_b.bias",
                [float(target_bits[bit]), float(flag)],
            )
            out_bit = self.alu._eval_gate(
                f"{prefix}.bit{bit}.or.weight",
                f"{prefix}.bit{bit}.or.bias",
                [and_a, and_b],
            )
            out_bits.append(int(out_bit))

        return bits_to_int(out_bits)

    def step(self, state: CPUState) -> CPUState:
        """Single CPU cycle using threshold neurons."""
        # Halted machines latch, mirroring ref_step.
        if state.ctrl[0] == 1:
            return state.copy()

        s = state.copy()

        # Fetch via the gate-level memory read path.
        hi = self._memory_read(s.mem, s.pc)
        lo = self._memory_read(s.mem, (s.pc + 1) & 0xFFFF)
        s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
        next_pc = (s.pc + 2) & 0xFFFF

        opcode, rd, rs, imm8 = decode_ir(s.ir)
        a = s.regs[rd]
        b = s.regs[rs]

        # Memory/control opcodes fetch a trailing 16-bit address operand.
        addr16 = None
        next_pc_ext = next_pc
        if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
            addr_hi = self._memory_read(s.mem, next_pc)
            addr_lo = self._memory_read(s.mem, (next_pc + 1) & 0xFFFF)
            addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
            next_pc_ext = (next_pc + 2) & 0xFFFF

        write_result = True
        result = a
        carry = 0
        overflow = 0

        if opcode == 0x0:  # ADD
            result, carry, overflow = self.alu.add(a, b)
        elif opcode == 0x1:  # SUB
            result, carry, overflow = self.alu.sub(a, b)
        elif opcode == 0x2:  # AND
            result = self.alu.bitwise_and(a, b)
        elif opcode == 0x3:  # OR
            result = self.alu.bitwise_or(a, b)
        elif opcode == 0x4:  # XOR
            result = self.alu.bitwise_xor(a, b)
        elif opcode == 0x5:  # SHL
            result = self.alu.shift_left(a)
        elif opcode == 0x6:  # SHR
            result = self.alu.shift_right(a)
        elif opcode == 0x7:  # MUL
            result = self.alu.multiply(a, b)
        elif opcode == 0x8:  # DIV (quotient only; remainder discarded)
            result, _ = self.alu.divide(a, b)
        elif opcode == 0x9:  # CMP (flags only)
            result, carry, overflow = self.alu.sub(a, b)
            write_result = False
        elif opcode == 0xA:  # LOAD
            result = self._memory_read(s.mem, addr16)
        elif opcode == 0xB:  # STORE
            s.mem = self._memory_write(s.mem, addr16, b & 0xFF)
            write_result = False
        elif opcode == 0xC:  # JMP
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xD:  # Jcc via per-condition mux circuits
            cond_type = imm8 & 0x7
            # (circuit prefix, flag index) per condition; for the negated
            # conditions the circuit itself encodes the inversion.
            cond_circuits = [
                ("control.jz", 0),
                ("control.jnz", 0),
                ("control.jc", 2),
                ("control.jnc", 2),
                ("control.jn", 1),
                ("control.jp", 1),
                ("control.jv", 3),
                ("control.jnv", 3),
            ]
            circuit_prefix, flag_idx = cond_circuits[cond_type]
            hi_pc = self._conditional_jump_byte(
                circuit_prefix,
                (next_pc_ext >> 8) & 0xFF,
                (addr16 >> 8) & 0xFF,
                s.flags[flag_idx],
            )
            lo_pc = self._conditional_jump_byte(
                circuit_prefix,
                next_pc_ext & 0xFF,
                addr16 & 0xFF,
                s.flags[flag_idx],
            )
            s.pc = ((hi_pc & 0xFF) << 8) | (lo_pc & 0xFF)
            write_result = False
        elif opcode == 0xE:  # CALL: push big-endian return address, jump
            ret_addr = next_pc_ext & 0xFFFF
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, (ret_addr >> 8) & 0xFF)
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, ret_addr & 0xFF)
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xF:  # HALT
            s.ctrl[0] = 1
            write_result = False

        # Flags update for ALU ops and LOAD (same set as ref_step).
        if opcode <= 0x9 or opcode == 0xA:
            s.flags = list(flags_from_result(result, carry, overflow))

        if write_result:
            s.regs[rd] = result & 0xFF

        # Control-flow opcodes already set PC themselves.
        if opcode not in (0xC, 0xD, 0xE):
            s.pc = next_pc_ext

        return s

    def run_until_halt(self, state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
        """Execute until HALT or max_cycles reached."""
        s = state.copy()
        for i in range(max_cycles):
            if s.ctrl[0] == 1:
                return s, i
            s = self.step(s)
        return s, max_cycles

    def forward(self, state_bits: torch.Tensor, max_cycles: int = 256) -> torch.Tensor:
        """Tensor-in, tensor-out interface for neural integration."""
        bits_list = [int(b) for b in state_bits.detach().cpu().flatten().tolist()]
        state = unpack_state(bits_list)
        final, _ = self.run_until_halt(state, max_cycles=max_cycles)
        return torch.tensor(pack_state(final), dtype=torch.float32)
758
-
759
-
760
def encode_instr(opcode: int, rd: int, rs: int, imm8: int) -> int:
    """Pack opcode/rd/rs/imm8 fields into one 16-bit instruction word."""
    word = (opcode & 0xF) << 12
    word |= (rd & 0x3) << 10
    word |= (rs & 0x3) << 8
    return word | (imm8 & 0xFF)
762
-
763
-
764
def write_instr(mem: List[int], addr: int, instr: int) -> None:
    """Store a 16-bit instruction big-endian at mem[addr], mem[addr+1]."""
    hi, lo = (instr >> 8) & 0xFF, instr & 0xFF
    mem[addr & 0xFFFF] = hi
    mem[(addr + 1) & 0xFFFF] = lo
767
-
768
-
769
def write_addr(mem: List[int], addr: int, value: int) -> None:
    """Store a 16-bit address big-endian at mem[addr], mem[addr+1]."""
    for offset, byte in enumerate(((value >> 8) & 0xFF, value & 0xFF)):
        mem[(addr + offset) & 0xFFFF] = byte
772
-
773
-
774
def run_smoke_test() -> None:
    """Smoke test: LOAD 5, LOAD 7, ADD, STORE, HALT. Expect result = 12.

    Runs the same program on the pure-Python reference CPU and the
    threshold-weight CPU, then cross-checks final register, memory, and
    cycle counts, plus the tensor forward() interface.
    """
    mem = [0] * 65536

    # Program at 0x0000:
    #   LOAD  R0, [0x0100]
    #   LOAD  R1, [0x0101]
    #   ADD   R0, R1
    #   STORE [0x0102], R0
    #   HALT
    write_instr(mem, 0x0000, encode_instr(0xA, 0, 0, 0x00))
    write_addr(mem, 0x0002, 0x0100)
    write_instr(mem, 0x0004, encode_instr(0xA, 1, 0, 0x00))
    write_addr(mem, 0x0006, 0x0101)
    write_instr(mem, 0x0008, encode_instr(0x0, 0, 1, 0x00))
    write_instr(mem, 0x000A, encode_instr(0xB, 0, 0, 0x00))
    write_addr(mem, 0x000C, 0x0102)
    write_instr(mem, 0x000E, encode_instr(0xF, 0, 0, 0x00))

    # Operand bytes.
    mem[0x0100] = 5
    mem[0x0101] = 7

    state = CPUState(
        pc=0,
        ir=0,
        regs=[0, 0, 0, 0],
        flags=[0, 0, 0, 0],
        sp=0xFFFE,  # stack grows downward from top of memory
        ctrl=[0, 0, 0, 0],
        mem=mem,
    )

    print("Running reference implementation...")
    final, cycles = ref_run_until_halt(state, max_cycles=20)

    assert final.ctrl[0] == 1, "HALT flag not set"
    assert final.regs[0] == 12, f"R0 expected 12, got {final.regs[0]}"
    assert final.mem[0x0102] == 12, f"MEM[0x0102] expected 12, got {final.mem[0x0102]}"
    assert cycles <= 10, f"Unexpected cycle count: {cycles}"
    print(f" Reference: R0={final.regs[0]}, MEM[0x0102]={final.mem[0x0102]}, cycles={cycles}")

    # The threshold CPU must reproduce the reference results exactly.
    print("Running threshold-weight implementation...")
    threshold_cpu = ThresholdCPU()
    t_final, t_cycles = threshold_cpu.run_until_halt(state, max_cycles=20)

    assert t_final.ctrl[0] == 1, "Threshold HALT flag not set"
    assert t_final.regs[0] == final.regs[0], f"Threshold R0 mismatch: {t_final.regs[0]} != {final.regs[0]}"
    assert t_final.mem[0x0102] == final.mem[0x0102], (
        f"Threshold MEM[0x0102] mismatch: {t_final.mem[0x0102]} != {final.mem[0x0102]}"
    )
    assert t_cycles == cycles, f"Threshold cycle count mismatch: {t_cycles} != {cycles}"
    print(f" Threshold: R0={t_final.regs[0]}, MEM[0x0102]={t_final.mem[0x0102]}, cycles={t_cycles}")

    # Round-trip through the flat-bit tensor interface as well.
    print("Validating forward() tensor I/O...")
    bits = torch.tensor(pack_state(state), dtype=torch.float32)
    out_bits = threshold_cpu.forward(bits, max_cycles=20)
    out_state = unpack_state([int(b) for b in out_bits.tolist()])
    assert out_state.regs[0] == final.regs[0], f"Forward R0 mismatch: {out_state.regs[0]} != {final.regs[0]}"
    assert out_state.mem[0x0102] == final.mem[0x0102], (
        f"Forward MEM[0x0102] mismatch: {out_state.mem[0x0102]} != {final.mem[0x0102]}"
    )
    print(f" Forward: R0={out_state.regs[0]}, MEM[0x0102]={out_state.mem[0x0102]}")

    print("\nSmoke test: PASSED")
832
-
833
-
834
- if __name__ == "__main__":
835
- parser = argparse.ArgumentParser(description="8-bit Threshold CPU")
836
- parser.add_argument("--model", type=Path, default=DEFAULT_MODEL_PATH, help="Path to safetensors model")
837
- args = parser.parse_args()
838
-
839
- if args.model != DEFAULT_MODEL_PATH:
840
- DEFAULT_MODEL_PATH = args.model
841
-
842
- run_smoke_test()