CharlesCNorton committed on
Commit
438de6a
·
1 Parent(s): 3a9b35a

Parameterize CPU circuits for N-bit data and address widths

Browse files

Build changes:
- add_fetch_load_store_buffers: N-bit data bus support
- add_stack_ops: N-bit data, M-bit stack pointer
- add_conditional_jumps: 8 jump types with N-bit addresses
- add_status_flags: Z/N/C/V flags for N-bit ALU results
- add_rol_ror_nbits: Parameterized rotation circuits
- update_manifest: Version 4.0 with data_bits field
- Stack ops moved from cmd_alu to cmd_memory

Eval changes:
- get_manifest: Extract data_bits/addr_bits from model
- BatchedFitnessEvaluator reads manifest for N-bit support
- _test_conditional_jump uses self.addr_bits
- _test_stack_ops uses self.addr_bits for SP width

Tested configurations:
- 8-bit data / 10-bit addr (small memory)
- 32-bit data / 10-bit addr (small memory)

Files changed (4) hide show
  1. build.py +197 -75
  2. eval.py +60 -26
  3. neural_alu32.safetensors +2 -2
  4. neural_computer.safetensors +2 -2
build.py CHANGED
@@ -255,10 +255,19 @@ def add_memory_write_cells(tensors: Dict[str, torch.Tensor], mem_bytes: int) ->
255
  tensors["memory.write.or.bias"] = or_bias
256
 
257
 
258
- def add_fetch_load_store_buffers(tensors: Dict[str, torch.Tensor], addr_bits: int) -> None:
259
- for bit in range(16):
 
 
 
 
 
 
 
 
 
260
  add_gate(tensors, f"control.fetch.ir.bit{bit}", [1.0], [-1.0])
261
- for bit in range(8):
262
  add_gate(tensors, f"control.load.bit{bit}", [1.0], [-1.0])
263
  add_gate(tensors, f"control.store.bit{bit}", [1.0], [-1.0])
264
  for bit in range(addr_bits):
@@ -555,116 +564,197 @@ def add_neg(tensors: Dict[str, torch.Tensor]) -> None:
555
 
556
 
557
  def add_rol_ror(tensors: Dict[str, torch.Tensor]) -> None:
558
- """Add ROL and ROR circuits (rotate left/right).
 
 
559
 
560
- ROL: out[i] = in[i+1] for i<7, out[7] = in[0] (MSB wraps to LSB)
561
- ROR: out[0] = in[7], out[i] = in[i-1] for i>0 (LSB wraps to MSB)
562
 
563
- Identity gates with circular wiring.
 
 
 
 
564
  """
565
  # ROL: rotate left (toward MSB)
566
- for bit in range(8):
567
- src = (bit + 1) % 8 # Circular: bit 7 gets bit 0
568
- add_gate(tensors, f"alu.alu8bit.rol.bit{bit}", [2.0], [-1.0])
569
 
570
  # ROR: rotate right (toward LSB)
571
- for bit in range(8):
572
- src = (bit - 1) % 8 # Circular: bit 0 gets bit 7
573
- add_gate(tensors, f"alu.alu8bit.ror.bit{bit}", [2.0], [-1.0])
574
 
575
 
576
- def add_stack_ops(tensors: Dict[str, torch.Tensor]) -> None:
577
  """Add RET, PUSH, POP circuit components.
578
 
579
  These are higher-level operations that use memory read/write.
580
  We create the control logic gates.
581
 
 
 
 
 
582
  RET: Pop return address from stack, jump to it
583
  PUSH: Decrement SP, write value to [SP]
584
  POP: Read value from [SP], increment SP
585
  """
586
- # SP decrement for PUSH (16-bit)
587
- for bit in range(16):
588
  add_gate(tensors, f"control.push.sp_dec.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0])
589
  add_gate(tensors, f"control.push.sp_dec.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0])
590
  add_gate(tensors, f"control.push.sp_dec.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0])
591
  add_gate(tensors, f"control.push.sp_dec.bit{bit}.borrow", [1.0, 1.0], [-2.0])
592
 
593
- # SP increment for POP (16-bit)
594
- for bit in range(16):
595
  add_gate(tensors, f"control.pop.sp_inc.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0])
596
  add_gate(tensors, f"control.pop.sp_inc.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0])
597
  add_gate(tensors, f"control.pop.sp_inc.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0])
598
  add_gate(tensors, f"control.pop.sp_inc.bit{bit}.carry", [1.0, 1.0], [-2.0])
599
 
600
- # RET uses POP twice (for 16-bit address) then jumps
601
- # Buffer gates for return address
602
- for bit in range(16):
 
 
 
 
 
 
 
603
  add_gate(tensors, f"control.ret.addr.bit{bit}", [2.0], [-1.0])
604
 
605
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
606
  def add_barrel_shifter(tensors: Dict[str, torch.Tensor]) -> None:
607
- """Add barrel shifter circuit.
 
608
 
609
- Shifts input by 0-7 positions based on 3-bit shift amount.
 
 
 
 
610
  Uses layers of 2:1 muxes controlled by shift amount bits.
611
 
612
- Layer 0: shift by 0 or 1 (controlled by shift[2], LSB)
613
- Layer 1: shift by 0 or 2 (controlled by shift[1])
614
- Layer 2: shift by 0 or 4 (controlled by shift[0], MSB)
615
  """
616
- # 3 layers of muxes, 8 bits each
617
- for layer in range(3):
618
- shift_amount = 1 << (2 - layer) # 4, 2, 1 for layers 0, 1, 2
619
- for bit in range(8):
 
 
 
620
  # 2:1 mux: if sel then shifted else original
621
- # NOT for inverting select
622
- add_gate(tensors, f"combinational.barrelshifter.layer{layer}.bit{bit}.not_sel", [-1.0], [0.0])
623
- # AND gates
624
- add_gate(tensors, f"combinational.barrelshifter.layer{layer}.bit{bit}.and_a", [1.0, 1.0], [-2.0])
625
- add_gate(tensors, f"combinational.barrelshifter.layer{layer}.bit{bit}.and_b", [1.0, 1.0], [-2.0])
626
- # OR gate
627
- add_gate(tensors, f"combinational.barrelshifter.layer{layer}.bit{bit}.or", [1.0, 1.0], [-1.0])
628
 
629
 
630
  def add_priority_encoder(tensors: Dict[str, torch.Tensor]) -> None:
631
- """Add priority encoder circuit.
 
 
632
 
633
- Finds the position of the highest set bit (0-7).
634
- Output is 3-bit index + valid flag.
635
 
636
- Uses cascaded comparisons: check bit 7 first, then 6, etc.
 
 
 
 
637
  """
638
- # Check each bit position (8 OR gates to detect any bit set at or above position)
639
- for pos in range(8):
640
- # OR of bits pos through 7
641
- num_inputs = 8 - pos
 
 
 
642
  weights = [1.0] * num_inputs
643
- add_gate(tensors, f"combinational.priorityencoder.any_ge{pos}",
644
- weights, [-1.0])
645
 
646
  # Priority logic: pos N is highest if bit N is set AND no higher bit is set
647
- for pos in range(8):
648
- # bit[pos] AND NOT(any bit > pos)
649
- add_gate(tensors, f"combinational.priorityencoder.is_highest{pos}.not_higher", [-1.0], [0.0])
650
- add_gate(tensors, f"combinational.priorityencoder.is_highest{pos}.and", [1.0, 1.0], [-2.0])
651
-
652
- # Encode position to 3-bit output
653
- # out[0] (LSB): positions 1,3,5,7
654
- # out[1]: positions 2,3,6,7
655
- # out[2] (MSB): positions 4,5,6,7
656
- for out_bit in range(3):
657
  weights = []
658
- for pos in range(8):
659
  if (pos >> out_bit) & 1:
660
  weights.append(1.0)
661
  if weights:
662
- add_gate(tensors, f"combinational.priorityencoder.out{out_bit}",
663
- weights, [-1.0])
664
 
665
  # Valid flag: any bit set
666
- add_gate(tensors, f"combinational.priorityencoder.valid",
667
- [1.0] * 8, [-1.0])
668
 
669
 
670
  def add_comparators(tensors: Dict[str, torch.Tensor]) -> None:
@@ -908,10 +998,19 @@ def add_neg_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
908
  add_gate(tensors, f"alu.alu{bits}bit.neg.inc.bit{bit}.carry", [1.0, 1.0], [-2.0])
909
 
910
 
911
- def update_manifest(tensors: Dict[str, torch.Tensor], addr_bits: int, mem_bytes: int) -> None:
 
 
 
 
 
 
 
 
 
912
  tensors["manifest.memory_bytes"] = torch.tensor([float(mem_bytes)], dtype=torch.float32)
913
  tensors["manifest.pc_width"] = torch.tensor([float(addr_bits)], dtype=torch.float32)
914
- tensors["manifest.version"] = torch.tensor([3.0], dtype=torch.float32)
915
 
916
 
917
  def write_manifest(path: Path, tensors: Dict[str, torch.Tensor]) -> None:
@@ -2028,6 +2127,10 @@ def cmd_memory(args) -> None:
2028
  drop_prefixes(tensors, [
2029
  "memory.addr_decode.", "memory.read.", "memory.write.",
2030
  "control.fetch.ir.", "control.load.", "control.store.", "control.mem_addr.",
 
 
 
 
2031
  ])
2032
  print(f" Now {len(tensors)} tensors")
2033
 
@@ -2040,16 +2143,42 @@ def cmd_memory(args) -> None:
2040
 
2041
  print("\nGenerating buffer gates...")
2042
  try:
2043
- add_fetch_load_store_buffers(tensors, addr_bits)
2044
- print(" Added fetch/load/store/mem_addr buffers")
2045
  except ValueError as e:
2046
  print(f" Buffers already exist: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2047
  else:
2048
  print("\nSkipping memory circuits (addr_bits=0, pure ALU mode)")
2049
 
2050
  print("\nUpdating manifest...")
2051
- update_manifest(tensors, addr_bits, mem_bytes)
2052
- print(f" memory_bytes={mem_bytes:,}, pc_width={addr_bits}")
2053
 
2054
  if args.apply:
2055
  print(f"\nSaving: {args.model}")
@@ -2113,7 +2242,6 @@ def cmd_alu(args) -> None:
2113
  "arithmetic.greaterthan8bit.", "arithmetic.lessthan8bit.",
2114
  "arithmetic.greaterorequal8bit.", "arithmetic.lessorequal8bit.",
2115
  "arithmetic.equality8bit.", "arithmetic.add3_8bit.", "arithmetic.expr_add_mul.", "arithmetic.expr_paren.",
2116
- "control.push.", "control.pop.", "control.ret.",
2117
  "combinational.barrelshifter.", "combinational.priorityencoder.",
2118
  ]
2119
 
@@ -2164,12 +2292,6 @@ def cmd_alu(args) -> None:
2164
  print(" Added ROL (8 gates), ROR (8 gates)")
2165
  except ValueError as e:
2166
  print(f" ROL/ROR already exist: {e}")
2167
- print("\nGenerating stack operation circuits...")
2168
- try:
2169
- add_stack_ops(tensors)
2170
- print(" Added PUSH/POP/RET (144 gates)")
2171
- except ValueError as e:
2172
- print(f" Stack ops already exist: {e}")
2173
  print("\nGenerating barrel shifter...")
2174
  try:
2175
  add_barrel_shifter(tensors)
 
255
  tensors["memory.write.or.bias"] = or_bias
256
 
257
 
258
+ def add_fetch_load_store_buffers(tensors: Dict[str, torch.Tensor], data_bits: int, addr_bits: int) -> None:
259
+ """Add control buffers for fetch, load, store operations.
260
+
261
+ Args:
262
+ data_bits: Width of data bus (8/16/32)
263
+ addr_bits: Width of address bus (determines instruction register width)
264
+ """
265
+ # Instruction register width = opcode (8) + operands (depends on arch)
266
+ # For simplicity, IR width = max(16, addr_bits) to hold jump targets
267
+ ir_bits = max(16, addr_bits)
268
+ for bit in range(ir_bits):
269
  add_gate(tensors, f"control.fetch.ir.bit{bit}", [1.0], [-1.0])
270
+ for bit in range(data_bits):
271
  add_gate(tensors, f"control.load.bit{bit}", [1.0], [-1.0])
272
  add_gate(tensors, f"control.store.bit{bit}", [1.0], [-1.0])
273
  for bit in range(addr_bits):
 
564
 
565
 
566
  def add_rol_ror(tensors: Dict[str, torch.Tensor]) -> None:
567
+ """Add 8-bit ROL and ROR circuits (legacy wrapper)."""
568
+ add_rol_ror_nbits(tensors, 8)
569
+
570
 
571
+ def add_rol_ror_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
572
+ """Add N-bit ROL and ROR circuits (rotate left/right).
573
 
574
+ ROL: out[i] = in[i+1] for i<N-1, out[N-1] = in[0] (MSB wraps to LSB)
575
+ ROR: out[0] = in[N-1], out[i] = in[i-1] for i>0 (LSB wraps to MSB)
576
+
577
+ Args:
578
+ bits: Data width (8, 16, 32, etc.)
579
  """
580
  # ROL: rotate left (toward MSB)
581
+ for bit in range(bits):
582
+ add_gate(tensors, f"alu.alu{bits}bit.rol.bit{bit}", [2.0], [-1.0])
 
583
 
584
  # ROR: rotate right (toward LSB)
585
+ for bit in range(bits):
586
+ add_gate(tensors, f"alu.alu{bits}bit.ror.bit{bit}", [2.0], [-1.0])
 
587
 
588
 
589
+ def add_stack_ops(tensors: Dict[str, torch.Tensor], data_bits: int, addr_bits: int) -> None:
590
  """Add RET, PUSH, POP circuit components.
591
 
592
  These are higher-level operations that use memory read/write.
593
  We create the control logic gates.
594
 
595
+ Args:
596
+ data_bits: Width of data to push/pop (8/16/32)
597
+ addr_bits: Width of stack pointer and return addresses
598
+
599
  RET: Pop return address from stack, jump to it
600
  PUSH: Decrement SP, write value to [SP]
601
  POP: Read value from [SP], increment SP
602
  """
603
+ # SP decrement for PUSH (addr_bits wide)
604
+ for bit in range(addr_bits):
605
  add_gate(tensors, f"control.push.sp_dec.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0])
606
  add_gate(tensors, f"control.push.sp_dec.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0])
607
  add_gate(tensors, f"control.push.sp_dec.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0])
608
  add_gate(tensors, f"control.push.sp_dec.bit{bit}.borrow", [1.0, 1.0], [-2.0])
609
 
610
+ # SP increment for POP (addr_bits wide)
611
+ for bit in range(addr_bits):
612
  add_gate(tensors, f"control.pop.sp_inc.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0])
613
  add_gate(tensors, f"control.pop.sp_inc.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0])
614
  add_gate(tensors, f"control.pop.sp_inc.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0])
615
  add_gate(tensors, f"control.pop.sp_inc.bit{bit}.carry", [1.0, 1.0], [-2.0])
616
 
617
+ # Data buffers for PUSH (data_bits wide)
618
+ for bit in range(data_bits):
619
+ add_gate(tensors, f"control.push.data.bit{bit}", [2.0], [-1.0])
620
+
621
+ # Data buffers for POP (data_bits wide)
622
+ for bit in range(data_bits):
623
+ add_gate(tensors, f"control.pop.data.bit{bit}", [2.0], [-1.0])
624
+
625
+ # RET: Buffer gates for return address (addr_bits wide)
626
+ for bit in range(addr_bits):
627
  add_gate(tensors, f"control.ret.addr.bit{bit}", [2.0], [-1.0])
628
 
629
 
630
+ def add_conditional_jumps(tensors: Dict[str, torch.Tensor], addr_bits: int) -> None:
631
+ """Add conditional jump circuits (JZ, JNZ, JC, JNC, JP, JN, JV, JNV).
632
+
633
+ Each conditional jump is a 2:1 MUX per address bit:
634
+ - If flag is set: output = target_bit
635
+ - If flag is clear: output = pc_bit
636
+
637
+ Structure per bit:
638
+ - not_sel: NOT(flag)
639
+ - and_a: pc_bit AND NOT(flag)
640
+ - and_b: target_bit AND flag
641
+ - or: and_a OR and_b
642
+
643
+ Args:
644
+ addr_bits: Width of program counter / jump target
645
+ """
646
+ jump_types = ['jz', 'jnz', 'jc', 'jnc', 'jp', 'jn', 'jv', 'jnv']
647
+
648
+ for jmp in jump_types:
649
+ for bit in range(addr_bits):
650
+ prefix = f"control.{jmp}.bit{bit}"
651
+ # NOT sel (invert flag)
652
+ add_gate(tensors, f"{prefix}.not_sel", [-1.0], [0.0])
653
+ # AND a: pc_bit AND NOT(flag)
654
+ add_gate(tensors, f"{prefix}.and_a", [1.0, 1.0], [-2.0])
655
+ # AND b: target_bit AND flag
656
+ add_gate(tensors, f"{prefix}.and_b", [1.0, 1.0], [-2.0])
657
+ # OR: combine
658
+ add_gate(tensors, f"{prefix}.or", [1.0, 1.0], [-1.0])
659
+
660
+
661
+ def add_status_flags(tensors: Dict[str, torch.Tensor], data_bits: int) -> None:
662
+ """Add status flag computation circuits (Z, N, C, V).
663
+
664
+ Args:
665
+ data_bits: Width of ALU data (8/16/32)
666
+
667
+ Flags:
668
+ - Z (Zero): NOR of all result bits (1 if result == 0)
669
+ - N (Negative): Copy of MSB (sign bit)
670
+ - C (Carry): Carry out from adder (external input)
671
+ - V (Overflow): XOR of carry into and out of MSB (signed overflow)
672
+ """
673
+ # Z flag: NOR of all bits (result == 0)
674
+ # Single threshold gate: fires if sum of all bits < 1
675
+ add_gate(tensors, "flags.zero", [-1.0] * data_bits, [0.0])
676
+
677
+ # N flag: Buffer for MSB (sign bit)
678
+ add_gate(tensors, "flags.negative", [2.0], [-1.0])
679
+
680
+ # C flag: Buffer for carry out (input from adder)
681
+ add_gate(tensors, "flags.carry", [2.0], [-1.0])
682
+
683
+ # V flag: XOR of carry_in_msb and carry_out_msb
684
+ # Two-layer XOR: (A OR B) AND (A NAND B)
685
+ add_gate(tensors, "flags.overflow.layer1.or", [1.0, 1.0], [-1.0])
686
+ add_gate(tensors, "flags.overflow.layer1.nand", [-1.0, -1.0], [1.0])
687
+ add_gate(tensors, "flags.overflow.layer2", [1.0, 1.0], [-2.0])
688
+
689
+
690
  def add_barrel_shifter(tensors: Dict[str, torch.Tensor]) -> None:
691
+ """Add 8-bit barrel shifter circuit (legacy wrapper)."""
692
+ add_barrel_shifter_nbits(tensors, 8)
693
 
694
+
695
+ def add_barrel_shifter_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
696
+ """Add N-bit barrel shifter circuit.
697
+
698
+ Shifts input by 0 to (bits-1) positions based on ceil(log2(bits))-bit shift amount.
699
  Uses layers of 2:1 muxes controlled by shift amount bits.
700
 
701
+ Args:
702
+ bits: Data width (8, 16, 32, etc.)
 
703
  """
704
+ import math
705
+ num_layers = max(1, math.ceil(math.log2(bits)))
706
+
707
+ for layer in range(num_layers):
708
+ shift_amount = 1 << (num_layers - 1 - layer)
709
+ for bit in range(bits):
710
+ prefix = f"combinational.barrelshifter{bits}.layer{layer}.bit{bit}"
711
  # 2:1 mux: if sel then shifted else original
712
+ add_gate(tensors, f"{prefix}.not_sel", [-1.0], [0.0])
713
+ add_gate(tensors, f"{prefix}.and_a", [1.0, 1.0], [-2.0])
714
+ add_gate(tensors, f"{prefix}.and_b", [1.0, 1.0], [-2.0])
715
+ add_gate(tensors, f"{prefix}.or", [1.0, 1.0], [-1.0])
 
 
 
716
 
717
 
718
  def add_priority_encoder(tensors: Dict[str, torch.Tensor]) -> None:
719
+ """Add 8-bit priority encoder circuit (legacy wrapper)."""
720
+ add_priority_encoder_nbits(tensors, 8)
721
+
722
 
723
+ def add_priority_encoder_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
724
+ """Add N-bit priority encoder circuit.
725
 
726
+ Finds the position of the highest set bit (0 to bits-1).
727
+ Output is ceil(log2(bits))-bit index + valid flag.
728
+
729
+ Args:
730
+ bits: Input width (8, 16, 32, etc.)
731
  """
732
+ import math
733
+ out_bits = max(1, math.ceil(math.log2(bits)))
734
+ prefix = f"combinational.priorityencoder{bits}"
735
+
736
+ # Check each bit position (OR gates to detect any bit set at or above position)
737
+ for pos in range(bits):
738
+ num_inputs = bits - pos
739
  weights = [1.0] * num_inputs
740
+ add_gate(tensors, f"{prefix}.any_ge{pos}", weights, [-1.0])
 
741
 
742
  # Priority logic: pos N is highest if bit N is set AND no higher bit is set
743
+ for pos in range(bits):
744
+ add_gate(tensors, f"{prefix}.is_highest{pos}.not_higher", [-1.0], [0.0])
745
+ add_gate(tensors, f"{prefix}.is_highest{pos}.and", [1.0, 1.0], [-2.0])
746
+
747
+ # Encode position to output bits
748
+ for out_bit in range(out_bits):
 
 
 
 
749
  weights = []
750
+ for pos in range(bits):
751
  if (pos >> out_bit) & 1:
752
  weights.append(1.0)
753
  if weights:
754
+ add_gate(tensors, f"{prefix}.out{out_bit}", weights, [-1.0])
 
755
 
756
  # Valid flag: any bit set
757
+ add_gate(tensors, f"{prefix}.valid", [1.0] * bits, [-1.0])
 
758
 
759
 
760
  def add_comparators(tensors: Dict[str, torch.Tensor]) -> None:
 
998
  add_gate(tensors, f"alu.alu{bits}bit.neg.inc.bit{bit}.carry", [1.0, 1.0], [-2.0])
999
 
1000
 
1001
+ def update_manifest(tensors: Dict[str, torch.Tensor], data_bits: int, addr_bits: int, mem_bytes: int) -> None:
1002
+ """Update manifest metadata tensors.
1003
+
1004
+ Args:
1005
+ data_bits: ALU/register width (8/16/32)
1006
+ addr_bits: Address bus width (determines memory size)
1007
+ mem_bytes: Memory size in bytes (2^addr_bits)
1008
+ """
1009
+ tensors["manifest.data_bits"] = torch.tensor([float(data_bits)], dtype=torch.float32)
1010
+ tensors["manifest.addr_bits"] = torch.tensor([float(addr_bits)], dtype=torch.float32)
1011
  tensors["manifest.memory_bytes"] = torch.tensor([float(mem_bytes)], dtype=torch.float32)
1012
  tensors["manifest.pc_width"] = torch.tensor([float(addr_bits)], dtype=torch.float32)
1013
+ tensors["manifest.version"] = torch.tensor([4.0], dtype=torch.float32) # Bump version for N-bit support
1014
 
1015
 
1016
  def write_manifest(path: Path, tensors: Dict[str, torch.Tensor]) -> None:
 
2127
  drop_prefixes(tensors, [
2128
  "memory.addr_decode.", "memory.read.", "memory.write.",
2129
  "control.fetch.ir.", "control.load.", "control.store.", "control.mem_addr.",
2130
+ "control.push.", "control.pop.", "control.ret.",
2131
+ "control.jz.", "control.jnz.", "control.jc.", "control.jnc.",
2132
+ "control.jp.", "control.jn.", "control.jv.", "control.jnv.",
2133
+ "flags.",
2134
  ])
2135
  print(f" Now {len(tensors)} tensors")
2136
 
 
2143
 
2144
  print("\nGenerating buffer gates...")
2145
  try:
2146
+ add_fetch_load_store_buffers(tensors, args.bits, addr_bits)
2147
+ print(f" Added fetch/load/store/mem_addr buffers ({args.bits}-bit data, {addr_bits}-bit addr)")
2148
  except ValueError as e:
2149
  print(f" Buffers already exist: {e}")
2150
+
2151
+ print("\nGenerating stack operation circuits...")
2152
+ try:
2153
+ add_stack_ops(tensors, args.bits, addr_bits)
2154
+ sp_gates = addr_bits * 4 * 2 # SP inc/dec gates
2155
+ data_gates = args.bits * 2 # PUSH/POP data buffers
2156
+ ret_gates = addr_bits # RET address buffers
2157
+ total_gates = sp_gates + data_gates + ret_gates
2158
+ print(f" Added PUSH/POP/RET ({total_gates} gates: {args.bits}-bit data, {addr_bits}-bit SP)")
2159
+ except ValueError as e:
2160
+ print(f" Stack ops already exist: {e}")
2161
+
2162
+ print("\nGenerating conditional jump circuits...")
2163
+ try:
2164
+ add_conditional_jumps(tensors, addr_bits)
2165
+ jump_gates = 8 * addr_bits * 4 # 8 jump types × addr_bits × 4 gates each
2166
+ print(f" Added JZ/JNZ/JC/JNC/JP/JN/JV/JNV ({jump_gates} gates: {addr_bits}-bit addresses)")
2167
+ except ValueError as e:
2168
+ print(f" Conditional jumps already exist: {e}")
2169
+
2170
+ print("\nGenerating status flag circuits...")
2171
+ try:
2172
+ add_status_flags(tensors, args.bits)
2173
+ print(f" Added Z/N/C/V flags ({args.bits}-bit aware)")
2174
+ except ValueError as e:
2175
+ print(f" Status flags already exist: {e}")
2176
  else:
2177
  print("\nSkipping memory circuits (addr_bits=0, pure ALU mode)")
2178
 
2179
  print("\nUpdating manifest...")
2180
+ update_manifest(tensors, args.bits, addr_bits, mem_bytes)
2181
+ print(f" data_bits={args.bits}, addr_bits={addr_bits}, memory_bytes={mem_bytes:,}")
2182
 
2183
  if args.apply:
2184
  print(f"\nSaving: {args.model}")
 
2242
  "arithmetic.greaterthan8bit.", "arithmetic.lessthan8bit.",
2243
  "arithmetic.greaterorequal8bit.", "arithmetic.lessorequal8bit.",
2244
  "arithmetic.equality8bit.", "arithmetic.add3_8bit.", "arithmetic.expr_add_mul.", "arithmetic.expr_paren.",
 
2245
  "combinational.barrelshifter.", "combinational.priorityencoder.",
2246
  ]
2247
 
 
2292
  print(" Added ROL (8 gates), ROR (8 gates)")
2293
  except ValueError as e:
2294
  print(f" ROL/ROR already exist: {e}")
 
 
 
 
 
 
2295
  print("\nGenerating barrel shifter...")
2296
  try:
2297
  add_barrel_shifter(tensors)
eval.py CHANGED
@@ -67,6 +67,21 @@ def load_metadata(path: str = MODEL_PATH) -> Dict:
67
  return {'signal_registry': {}}
68
 
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def create_population(
71
  base_tensors: Dict[str, torch.Tensor],
72
  pop_size: int,
@@ -889,7 +904,7 @@ class BatchedFitnessEvaluator:
889
  Tests all circuits comprehensively.
890
  """
891
 
892
- def __init__(self, device: str = 'cuda', model_path: str = MODEL_PATH):
893
  self.device = device
894
  self.model_path = model_path
895
  self.metadata = load_metadata(model_path)
@@ -897,6 +912,16 @@ class BatchedFitnessEvaluator:
897
  self.results: List[CircuitResult] = []
898
  self.category_scores: Dict[str, Tuple[float, int]] = {}
899
  self.total_tests = 0
 
 
 
 
 
 
 
 
 
 
900
  self._setup_tests()
901
 
902
  def _setup_tests(self):
@@ -2897,7 +2922,7 @@ class BatchedFitnessEvaluator:
2897
  # =========================================================================
2898
 
2899
  def _test_conditional_jump(self, pop: Dict, name: str, debug: bool) -> Tuple[torch.Tensor, int]:
2900
- """Test conditional jump circuit."""
2901
  pop_size = next(iter(pop.values())).shape[0]
2902
  prefix = f'control.{name}'
2903
 
@@ -2911,7 +2936,7 @@ class BatchedFitnessEvaluator:
2911
  scores = torch.zeros(pop_size, device=self.device)
2912
  total = 0
2913
 
2914
- for bit in range(8):
2915
  bit_prefix = f'{prefix}.bit{bit}'
2916
  try:
2917
  # NOT sel
@@ -2979,27 +3004,34 @@ class BatchedFitnessEvaluator:
2979
  return scores, total
2980
 
2981
  def _test_stack_ops(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
2982
- """Test PUSH/POP/RET stack operation circuits."""
2983
  pop_size = next(iter(pop.values())).shape[0]
2984
  scores = torch.zeros(pop_size, device=self.device)
2985
  total = 0
 
 
2986
 
2987
  if debug:
2988
- print("\n=== STACK OPERATIONS ===")
2989
 
2990
- # Test PUSH SP decrement (16-bit, borrow chain)
2991
  try:
2992
- sp_tests = [0x0000, 0x0001, 0x0100, 0x8000, 0xFFFF, 0x1234]
 
 
 
 
 
2993
  op_scores = torch.zeros(pop_size, device=self.device)
2994
  op_total = 0
2995
 
2996
  for sp_val in sp_tests:
2997
- expected_val = (sp_val - 1) & 0xFFFF
2998
- sp_bits = [float((sp_val >> (15 - i)) & 1) for i in range(16)]
2999
 
3000
  borrow = 1.0
3001
  out_bits = []
3002
- for bit in range(15, -1, -1): # LSB to MSB
3003
  prefix = f'control.push.sp_dec.bit{bit}'
3004
 
3005
  w_or = pop[f'{prefix}.xor.layer1.or.weight'].view(pop_size, 2)
@@ -3024,11 +3056,11 @@ class BatchedFitnessEvaluator:
3024
  borrow = heaviside((borrow_inp * w_borrow).sum(-1) + b_borrow)[0].item()
3025
 
3026
  out = torch.stack(out_bits, dim=-1)
3027
- expected = torch.tensor([((expected_val >> (15 - i)) & 1) for i in range(16)],
3028
  device=self.device, dtype=torch.float32)
3029
  correct = (out == expected.unsqueeze(0)).float().sum(1)
3030
  op_scores += correct
3031
- op_total += 16
3032
 
3033
  scores += op_scores
3034
  total += op_total
@@ -3040,18 +3072,18 @@ class BatchedFitnessEvaluator:
3040
  if debug:
3041
  print(f" control.push.sp_dec: SKIP ({e})")
3042
 
3043
- # Test POP SP increment (16-bit, carry chain)
3044
  try:
3045
  op_scores = torch.zeros(pop_size, device=self.device)
3046
  op_total = 0
3047
 
3048
  for sp_val in sp_tests:
3049
- expected_val = (sp_val + 1) & 0xFFFF
3050
- sp_bits = [float((sp_val >> (15 - i)) & 1) for i in range(16)]
3051
 
3052
  carry = 1.0
3053
  out_bits = []
3054
- for bit in range(15, -1, -1): # LSB to MSB
3055
  prefix = f'control.pop.sp_inc.bit{bit}'
3056
 
3057
  w_or = pop[f'{prefix}.xor.layer1.or.weight'].view(pop_size, 2)
@@ -3074,11 +3106,11 @@ class BatchedFitnessEvaluator:
3074
  carry = heaviside((inp * w_carry).sum(-1) + b_carry)[0].item()
3075
 
3076
  out = torch.stack(out_bits, dim=-1)
3077
- expected = torch.tensor([((expected_val >> (15 - i)) & 1) for i in range(16)],
3078
  device=self.device, dtype=torch.float32)
3079
  correct = (out == expected.unsqueeze(0)).float().sum(1)
3080
  op_scores += correct
3081
- op_total += 16
3082
 
3083
  scores += op_scores
3084
  total += op_total
@@ -3090,27 +3122,29 @@ class BatchedFitnessEvaluator:
3090
  if debug:
3091
  print(f" control.pop.sp_inc: SKIP ({e})")
3092
 
3093
- # Test RET address buffer (16 identity gates)
3094
  try:
3095
  op_scores = torch.zeros(pop_size, device=self.device)
3096
  op_total = 0
3097
 
3098
- addr_tests = [0x0000, 0xFFFF, 0x1234, 0x8000, 0x00FF]
3099
- for addr_val in addr_tests:
3100
- addr_bits = torch.tensor([float((addr_val >> (15 - i)) & 1) for i in range(16)],
 
 
3101
  device=self.device, dtype=torch.float32)
3102
 
3103
  out_bits = []
3104
- for bit in range(16):
3105
  w = pop[f'control.ret.addr.bit{bit}.weight'].view(pop_size)
3106
  b = pop[f'control.ret.addr.bit{bit}.bias'].view(pop_size)
3107
- out = heaviside(addr_bits[bit] * w + b)
3108
  out_bits.append(out)
3109
 
3110
  out = torch.stack(out_bits, dim=-1)
3111
- correct = (out == addr_bits.unsqueeze(0)).float().sum(1)
3112
  op_scores += correct
3113
- op_total += 16
3114
 
3115
  scores += op_scores
3116
  total += op_total
 
67
  return {'signal_registry': {}}
68
 
69
 
70
+ def get_manifest(tensors: Dict[str, torch.Tensor]) -> Dict[str, int]:
71
+ """Extract manifest values from tensors.
72
+
73
+ Returns dict with data_bits, addr_bits, memory_bytes, version.
74
+ Defaults to 8-bit data, 16-bit addr for legacy models.
75
+ """
76
+ return {
77
+ 'data_bits': int(tensors.get('manifest.data_bits', torch.tensor([8.0])).item()),
78
+ 'addr_bits': int(tensors.get('manifest.addr_bits',
79
+ tensors.get('manifest.pc_width', torch.tensor([16.0]))).item()),
80
+ 'memory_bytes': int(tensors.get('manifest.memory_bytes', torch.tensor([65536.0])).item()),
81
+ 'version': float(tensors.get('manifest.version', torch.tensor([1.0])).item()),
82
+ }
83
+
84
+
85
  def create_population(
86
  base_tensors: Dict[str, torch.Tensor],
87
  pop_size: int,
 
904
  Tests all circuits comprehensively.
905
  """
906
 
907
+ def __init__(self, device: str = 'cuda', model_path: str = MODEL_PATH, tensors: Dict[str, torch.Tensor] = None):
908
  self.device = device
909
  self.model_path = model_path
910
  self.metadata = load_metadata(model_path)
 
912
  self.results: List[CircuitResult] = []
913
  self.category_scores: Dict[str, Tuple[float, int]] = {}
914
  self.total_tests = 0
915
+
916
+ # Get manifest for N-bit support
917
+ if tensors is not None:
918
+ self.manifest = get_manifest(tensors)
919
+ else:
920
+ base_tensors = load_model(model_path)
921
+ self.manifest = get_manifest(base_tensors)
922
+ self.data_bits = self.manifest['data_bits']
923
+ self.addr_bits = self.manifest['addr_bits']
924
+
925
  self._setup_tests()
926
 
927
  def _setup_tests(self):
 
2922
  # =========================================================================
2923
 
2924
  def _test_conditional_jump(self, pop: Dict, name: str, debug: bool) -> Tuple[torch.Tensor, int]:
2925
+ """Test conditional jump circuit (N-bit address aware)."""
2926
  pop_size = next(iter(pop.values())).shape[0]
2927
  prefix = f'control.{name}'
2928
 
 
2936
  scores = torch.zeros(pop_size, device=self.device)
2937
  total = 0
2938
 
2939
+ for bit in range(self.addr_bits):
2940
  bit_prefix = f'{prefix}.bit{bit}'
2941
  try:
2942
  # NOT sel
 
3004
  return scores, total
3005
 
3006
  def _test_stack_ops(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
3007
+ """Test PUSH/POP/RET stack operation circuits (N-bit address aware)."""
3008
  pop_size = next(iter(pop.values())).shape[0]
3009
  scores = torch.zeros(pop_size, device=self.device)
3010
  total = 0
3011
+ addr_bits = self.addr_bits
3012
+ addr_mask = (1 << addr_bits) - 1
3013
 
3014
  if debug:
3015
+ print(f"\n=== STACK OPERATIONS ({addr_bits}-bit SP) ===")
3016
 
3017
+ # Test PUSH SP decrement (addr_bits wide, borrow chain)
3018
  try:
3019
+ # Generate test values appropriate for addr_bits
3020
+ sp_tests = [0, 1, addr_mask // 2, addr_mask]
3021
+ if addr_bits >= 8:
3022
+ sp_tests.append(0x100 & addr_mask)
3023
+ if addr_bits >= 12:
3024
+ sp_tests.append(0x1234 & addr_mask)
3025
  op_scores = torch.zeros(pop_size, device=self.device)
3026
  op_total = 0
3027
 
3028
  for sp_val in sp_tests:
3029
+ expected_val = (sp_val - 1) & addr_mask
3030
+ sp_bits = [float((sp_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)]
3031
 
3032
  borrow = 1.0
3033
  out_bits = []
3034
+ for bit in range(addr_bits - 1, -1, -1): # LSB to MSB
3035
  prefix = f'control.push.sp_dec.bit{bit}'
3036
 
3037
  w_or = pop[f'{prefix}.xor.layer1.or.weight'].view(pop_size, 2)
 
3056
  borrow = heaviside((borrow_inp * w_borrow).sum(-1) + b_borrow)[0].item()
3057
 
3058
  out = torch.stack(out_bits, dim=-1)
3059
+ expected = torch.tensor([((expected_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)],
3060
  device=self.device, dtype=torch.float32)
3061
  correct = (out == expected.unsqueeze(0)).float().sum(1)
3062
  op_scores += correct
3063
+ op_total += addr_bits
3064
 
3065
  scores += op_scores
3066
  total += op_total
 
3072
  if debug:
3073
  print(f" control.push.sp_dec: SKIP ({e})")
3074
 
3075
+ # Test POP SP increment (addr_bits wide, carry chain)
3076
  try:
3077
  op_scores = torch.zeros(pop_size, device=self.device)
3078
  op_total = 0
3079
 
3080
  for sp_val in sp_tests:
3081
+ expected_val = (sp_val + 1) & addr_mask
3082
+ sp_bits = [float((sp_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)]
3083
 
3084
  carry = 1.0
3085
  out_bits = []
3086
+ for bit in range(addr_bits - 1, -1, -1): # LSB to MSB
3087
  prefix = f'control.pop.sp_inc.bit{bit}'
3088
 
3089
  w_or = pop[f'{prefix}.xor.layer1.or.weight'].view(pop_size, 2)
 
3106
  carry = heaviside((inp * w_carry).sum(-1) + b_carry)[0].item()
3107
 
3108
  out = torch.stack(out_bits, dim=-1)
3109
+ expected = torch.tensor([((expected_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)],
3110
  device=self.device, dtype=torch.float32)
3111
  correct = (out == expected.unsqueeze(0)).float().sum(1)
3112
  op_scores += correct
3113
+ op_total += addr_bits
3114
 
3115
  scores += op_scores
3116
  total += op_total
 
3122
  if debug:
3123
  print(f" control.pop.sp_inc: SKIP ({e})")
3124
 
3125
+ # Test RET address buffer (addr_bits identity gates)
3126
  try:
3127
  op_scores = torch.zeros(pop_size, device=self.device)
3128
  op_total = 0
3129
 
3130
+ ret_tests = [0, addr_mask, addr_mask // 2, 1]
3131
+ if addr_bits >= 12:
3132
+ ret_tests.append(0x1234 & addr_mask)
3133
+ for addr_val in ret_tests:
3134
+ ret_bits_tensor = torch.tensor([float((addr_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)],
3135
  device=self.device, dtype=torch.float32)
3136
 
3137
  out_bits = []
3138
+ for bit in range(addr_bits):
3139
  w = pop[f'control.ret.addr.bit{bit}.weight'].view(pop_size)
3140
  b = pop[f'control.ret.addr.bit{bit}.bias'].view(pop_size)
3141
+ out = heaviside(ret_bits_tensor[bit] * w + b)
3142
  out_bits.append(out)
3143
 
3144
  out = torch.stack(out_bits, dim=-1)
3145
+ correct = (out == ret_bits_tensor.unsqueeze(0)).float().sum(1)
3146
  op_scores += correct
3147
+ op_total += addr_bits
3148
 
3149
  scores += op_scores
3150
  total += op_total
neural_alu32.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8a292e8d1dc5b29fd84d25d0333599a9946849e456aeb30b7519156dc150a623
3
- size 4985016
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5a0f6cdfb4ba0ebdfc863f43e5f8fd4f41626c0fd4e7258a0a581a117a79d97
3
+ size 5031612
neural_computer.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:eaabeed4fa50c13129fe4f83f6a8f31b6ccd41de12e83c62448460881373fc3e
3
- size 34838348
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08a39c4758f6e5236f84d231be7f2d54364099309a89cf484d607a6544194d20
3
+ size 2591660