CharlesCNorton
committed on
Commit
·
438de6a
1
Parent(s):
3a9b35a
Parameterize CPU circuits for N-bit data and address widths
Browse files
Build changes:
- add_fetch_load_store_buffers: N-bit data bus support
- add_stack_ops: N-bit data, M-bit stack pointer
- add_conditional_jumps: 8 jump types with N-bit addresses
- add_status_flags: Z/N/C/V flags for N-bit ALU results
- add_rol_ror_nbits: Parameterized rotation circuits
- update_manifest: Version 4.0 with data_bits field
- Stack ops moved from cmd_alu to cmd_memory
Eval changes:
- get_manifest: Extract data_bits/addr_bits from model
- BatchedFitnessEvaluator reads manifest for N-bit support
- _test_conditional_jump uses self.addr_bits
- _test_stack_ops uses self.addr_bits for SP width
Tested configurations:
- 8-bit data / 10-bit addr (small memory)
- 32-bit data / 10-bit addr (small memory)
- build.py +197 -75
- eval.py +60 -26
- neural_alu32.safetensors +2 -2
- neural_computer.safetensors +2 -2
build.py
CHANGED
|
@@ -255,10 +255,19 @@ def add_memory_write_cells(tensors: Dict[str, torch.Tensor], mem_bytes: int) ->
|
|
| 255 |
tensors["memory.write.or.bias"] = or_bias
|
| 256 |
|
| 257 |
|
| 258 |
-
def add_fetch_load_store_buffers(tensors: Dict[str, torch.Tensor], addr_bits: int) -> None:
|
| 259 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
add_gate(tensors, f"control.fetch.ir.bit{bit}", [1.0], [-1.0])
|
| 261 |
-
for bit in range(
|
| 262 |
add_gate(tensors, f"control.load.bit{bit}", [1.0], [-1.0])
|
| 263 |
add_gate(tensors, f"control.store.bit{bit}", [1.0], [-1.0])
|
| 264 |
for bit in range(addr_bits):
|
|
@@ -555,116 +564,197 @@ def add_neg(tensors: Dict[str, torch.Tensor]) -> None:
|
|
| 555 |
|
| 556 |
|
| 557 |
def add_rol_ror(tensors: Dict[str, torch.Tensor]) -> None:
|
| 558 |
-
"""Add ROL and ROR circuits (
|
|
|
|
|
|
|
| 559 |
|
| 560 |
-
|
| 561 |
-
|
| 562 |
|
| 563 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 564 |
"""
|
| 565 |
# ROL: rotate left (toward MSB)
|
| 566 |
-
for bit in range(
|
| 567 |
-
|
| 568 |
-
add_gate(tensors, f"alu.alu8bit.rol.bit{bit}", [2.0], [-1.0])
|
| 569 |
|
| 570 |
# ROR: rotate right (toward LSB)
|
| 571 |
-
for bit in range(
|
| 572 |
-
|
| 573 |
-
add_gate(tensors, f"alu.alu8bit.ror.bit{bit}", [2.0], [-1.0])
|
| 574 |
|
| 575 |
|
| 576 |
-
def add_stack_ops(tensors: Dict[str, torch.Tensor]) -> None:
|
| 577 |
"""Add RET, PUSH, POP circuit components.
|
| 578 |
|
| 579 |
These are higher-level operations that use memory read/write.
|
| 580 |
We create the control logic gates.
|
| 581 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 582 |
RET: Pop return address from stack, jump to it
|
| 583 |
PUSH: Decrement SP, write value to [SP]
|
| 584 |
POP: Read value from [SP], increment SP
|
| 585 |
"""
|
| 586 |
-
# SP decrement for PUSH (
|
| 587 |
-
for bit in range(
|
| 588 |
add_gate(tensors, f"control.push.sp_dec.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0])
|
| 589 |
add_gate(tensors, f"control.push.sp_dec.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0])
|
| 590 |
add_gate(tensors, f"control.push.sp_dec.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0])
|
| 591 |
add_gate(tensors, f"control.push.sp_dec.bit{bit}.borrow", [1.0, 1.0], [-2.0])
|
| 592 |
|
| 593 |
-
# SP increment for POP (
|
| 594 |
-
for bit in range(
|
| 595 |
add_gate(tensors, f"control.pop.sp_inc.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0])
|
| 596 |
add_gate(tensors, f"control.pop.sp_inc.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0])
|
| 597 |
add_gate(tensors, f"control.pop.sp_inc.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0])
|
| 598 |
add_gate(tensors, f"control.pop.sp_inc.bit{bit}.carry", [1.0, 1.0], [-2.0])
|
| 599 |
|
| 600 |
-
#
|
| 601 |
-
|
| 602 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
add_gate(tensors, f"control.ret.addr.bit{bit}", [2.0], [-1.0])
|
| 604 |
|
| 605 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 606 |
def add_barrel_shifter(tensors: Dict[str, torch.Tensor]) -> None:
|
| 607 |
-
"""Add barrel shifter circuit.
|
|
|
|
| 608 |
|
| 609 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 610 |
Uses layers of 2:1 muxes controlled by shift amount bits.
|
| 611 |
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
Layer 2: shift by 0 or 4 (controlled by shift[0], MSB)
|
| 615 |
"""
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
|
|
|
|
|
|
|
|
|
| 620 |
# 2:1 mux: if sel then shifted else original
|
| 621 |
-
|
| 622 |
-
add_gate(tensors, f"
|
| 623 |
-
|
| 624 |
-
add_gate(tensors, f"
|
| 625 |
-
add_gate(tensors, f"combinational.barrelshifter.layer{layer}.bit{bit}.and_b", [1.0, 1.0], [-2.0])
|
| 626 |
-
# OR gate
|
| 627 |
-
add_gate(tensors, f"combinational.barrelshifter.layer{layer}.bit{bit}.or", [1.0, 1.0], [-1.0])
|
| 628 |
|
| 629 |
|
| 630 |
def add_priority_encoder(tensors: Dict[str, torch.Tensor]) -> None:
|
| 631 |
-
"""Add priority encoder circuit.
|
|
|
|
|
|
|
| 632 |
|
| 633 |
-
|
| 634 |
-
|
| 635 |
|
| 636 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 637 |
"""
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
|
|
|
|
|
|
|
|
|
| 642 |
weights = [1.0] * num_inputs
|
| 643 |
-
add_gate(tensors, f"
|
| 644 |
-
weights, [-1.0])
|
| 645 |
|
| 646 |
# Priority logic: pos N is highest if bit N is set AND no higher bit is set
|
| 647 |
-
for pos in range(
|
| 648 |
-
|
| 649 |
-
add_gate(tensors, f"
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
# out[0] (LSB): positions 1,3,5,7
|
| 654 |
-
# out[1]: positions 2,3,6,7
|
| 655 |
-
# out[2] (MSB): positions 4,5,6,7
|
| 656 |
-
for out_bit in range(3):
|
| 657 |
weights = []
|
| 658 |
-
for pos in range(
|
| 659 |
if (pos >> out_bit) & 1:
|
| 660 |
weights.append(1.0)
|
| 661 |
if weights:
|
| 662 |
-
add_gate(tensors, f"
|
| 663 |
-
weights, [-1.0])
|
| 664 |
|
| 665 |
# Valid flag: any bit set
|
| 666 |
-
add_gate(tensors, f"
|
| 667 |
-
[1.0] * 8, [-1.0])
|
| 668 |
|
| 669 |
|
| 670 |
def add_comparators(tensors: Dict[str, torch.Tensor]) -> None:
|
|
@@ -908,10 +998,19 @@ def add_neg_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
|
|
| 908 |
add_gate(tensors, f"alu.alu{bits}bit.neg.inc.bit{bit}.carry", [1.0, 1.0], [-2.0])
|
| 909 |
|
| 910 |
|
| 911 |
-
def update_manifest(tensors: Dict[str, torch.Tensor], addr_bits: int, mem_bytes: int) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 912 |
tensors["manifest.memory_bytes"] = torch.tensor([float(mem_bytes)], dtype=torch.float32)
|
| 913 |
tensors["manifest.pc_width"] = torch.tensor([float(addr_bits)], dtype=torch.float32)
|
| 914 |
-
tensors["manifest.version"] = torch.tensor([
|
| 915 |
|
| 916 |
|
| 917 |
def write_manifest(path: Path, tensors: Dict[str, torch.Tensor]) -> None:
|
|
@@ -2028,6 +2127,10 @@ def cmd_memory(args) -> None:
|
|
| 2028 |
drop_prefixes(tensors, [
|
| 2029 |
"memory.addr_decode.", "memory.read.", "memory.write.",
|
| 2030 |
"control.fetch.ir.", "control.load.", "control.store.", "control.mem_addr.",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2031 |
])
|
| 2032 |
print(f" Now {len(tensors)} tensors")
|
| 2033 |
|
|
@@ -2040,16 +2143,42 @@ def cmd_memory(args) -> None:
|
|
| 2040 |
|
| 2041 |
print("\nGenerating buffer gates...")
|
| 2042 |
try:
|
| 2043 |
-
add_fetch_load_store_buffers(tensors, addr_bits)
|
| 2044 |
-
print(" Added fetch/load/store/mem_addr buffers")
|
| 2045 |
except ValueError as e:
|
| 2046 |
print(f" Buffers already exist: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2047 |
else:
|
| 2048 |
print("\nSkipping memory circuits (addr_bits=0, pure ALU mode)")
|
| 2049 |
|
| 2050 |
print("\nUpdating manifest...")
|
| 2051 |
-
update_manifest(tensors, addr_bits, mem_bytes)
|
| 2052 |
-
print(f"
|
| 2053 |
|
| 2054 |
if args.apply:
|
| 2055 |
print(f"\nSaving: {args.model}")
|
|
@@ -2113,7 +2242,6 @@ def cmd_alu(args) -> None:
|
|
| 2113 |
"arithmetic.greaterthan8bit.", "arithmetic.lessthan8bit.",
|
| 2114 |
"arithmetic.greaterorequal8bit.", "arithmetic.lessorequal8bit.",
|
| 2115 |
"arithmetic.equality8bit.", "arithmetic.add3_8bit.", "arithmetic.expr_add_mul.", "arithmetic.expr_paren.",
|
| 2116 |
-
"control.push.", "control.pop.", "control.ret.",
|
| 2117 |
"combinational.barrelshifter.", "combinational.priorityencoder.",
|
| 2118 |
]
|
| 2119 |
|
|
@@ -2164,12 +2292,6 @@ def cmd_alu(args) -> None:
|
|
| 2164 |
print(" Added ROL (8 gates), ROR (8 gates)")
|
| 2165 |
except ValueError as e:
|
| 2166 |
print(f" ROL/ROR already exist: {e}")
|
| 2167 |
-
print("\nGenerating stack operation circuits...")
|
| 2168 |
-
try:
|
| 2169 |
-
add_stack_ops(tensors)
|
| 2170 |
-
print(" Added PUSH/POP/RET (144 gates)")
|
| 2171 |
-
except ValueError as e:
|
| 2172 |
-
print(f" Stack ops already exist: {e}")
|
| 2173 |
print("\nGenerating barrel shifter...")
|
| 2174 |
try:
|
| 2175 |
add_barrel_shifter(tensors)
|
|
|
|
| 255 |
tensors["memory.write.or.bias"] = or_bias
|
| 256 |
|
| 257 |
|
| 258 |
+
def add_fetch_load_store_buffers(tensors: Dict[str, torch.Tensor], data_bits: int, addr_bits: int) -> None:
|
| 259 |
+
"""Add control buffers for fetch, load, store operations.
|
| 260 |
+
|
| 261 |
+
Args:
|
| 262 |
+
data_bits: Width of data bus (8/16/32)
|
| 263 |
+
addr_bits: Width of address bus (determines instruction register width)
|
| 264 |
+
"""
|
| 265 |
+
# Instruction register width = opcode (8) + operands (depends on arch)
|
| 266 |
+
# For simplicity, IR width = max(16, addr_bits) to hold jump targets
|
| 267 |
+
ir_bits = max(16, addr_bits)
|
| 268 |
+
for bit in range(ir_bits):
|
| 269 |
add_gate(tensors, f"control.fetch.ir.bit{bit}", [1.0], [-1.0])
|
| 270 |
+
for bit in range(data_bits):
|
| 271 |
add_gate(tensors, f"control.load.bit{bit}", [1.0], [-1.0])
|
| 272 |
add_gate(tensors, f"control.store.bit{bit}", [1.0], [-1.0])
|
| 273 |
for bit in range(addr_bits):
|
|
|
|
| 564 |
|
| 565 |
|
| 566 |
def add_rol_ror(tensors: Dict[str, torch.Tensor]) -> None:
    """Legacy entry point that builds the 8-bit ROL and ROR gates.

    Kept so existing callers that expect the fixed 8-bit rotate circuits
    keep working; all of the work is delegated to ``add_rol_ror_nbits``.

    Args:
        tensors: Mapping of tensor names to tensors; gates are added to it
            in place.
    """
    add_rol_ror_nbits(tensors, bits=8)
|
| 569 |
+
|
| 570 |
|
| 571 |
+
def add_rol_ror_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
    """Register one single-input gate per bit for N-bit ROL and ROR.

    Intended wiring (performed by whoever evaluates the gates, not here):
        ROL: out[i] = in[i+1] for i < N-1, out[N-1] = in[0]
             (MSB wraps to LSB)
        ROR: out[0] = in[N-1], out[i] = in[i-1] for i > 0
             (LSB wraps to MSB)

    Args:
        tensors: Mapping of tensor names to tensors; gates are added to it
            in place via ``add_gate``.
        bits: Data width (8, 16, 32, etc.). It is also embedded in the gate
            names as ``alu.alu{bits}bit``.
    """
    # All ROL gates are created first (rotate toward MSB), then all ROR
    # gates (rotate toward LSB). Every gate has weight 2.0 and bias -1.0.
    for rotation in ("rol", "ror"):
        for idx in range(bits):
            add_gate(tensors, f"alu.alu{bits}bit.{rotation}.bit{idx}", [2.0], [-1.0])
|
|
|
|
| 587 |
|
| 588 |
|
| 589 |
+
def add_stack_ops(tensors: Dict[str, torch.Tensor], data_bits: int, addr_bits: int) -> None:
    """Register the control gates used by PUSH, POP and RET.

    These are higher-level operations that rely on memory read/write; only
    the control-logic gates are created here.

        RET:  pop the return address from the stack and jump to it
        PUSH: decrement SP, then write the value to [SP]
        POP:  read the value from [SP], then increment SP

    Args:
        tensors: Mapping of tensor names to tensors; gates are added to it
            in place via ``add_gate``.
        data_bits: Width of the data being pushed/popped (8/16/32).
        addr_bits: Width of the stack pointer and of return addresses.
    """
    # Two-layer XOR (OR + NAND feeding an AND) shared by both SP chains.
    xor_stage = (
        ("xor.layer1.or", (1.0, 1.0), -1.0),
        ("xor.layer1.nand", (-1.0, -1.0), 1.0),
        ("xor.layer2", (1.0, 1.0), -2.0),
    )

    # SP decrement for PUSH (borrow chain) and SP increment for POP (carry
    # chain), each addr_bits wide: one XOR plus one chain gate per bit.
    sp_chains = (
        ("control.push.sp_dec", "borrow"),
        ("control.pop.sp_inc", "carry"),
    )
    for stem, chain_gate in sp_chains:
        for idx in range(addr_bits):
            base = f"{stem}.bit{idx}"
            for suffix, weights, bias in xor_stage:
                add_gate(tensors, f"{base}.{suffix}", list(weights), [bias])
            add_gate(tensors, f"{base}.{chain_gate}", [1.0, 1.0], [-2.0])

    # Plain buffer gates: PUSH data and POP data are data_bits wide, the RET
    # return address is addr_bits wide.
    buffers = (
        ("control.push.data", data_bits),
        ("control.pop.data", data_bits),
        ("control.ret.addr", addr_bits),
    )
    for stem, width in buffers:
        for idx in range(width):
            add_gate(tensors, f"{stem}.bit{idx}", [2.0], [-1.0])
|
| 628 |
|
| 629 |
|
| 630 |
+
def add_conditional_jumps(tensors: Dict[str, torch.Tensor], addr_bits: int) -> None:
    """Register the gates for JZ, JNZ, JC, JNC, JP, JN, JV and JNV.

    Every conditional jump is modelled as a 2:1 MUX per address bit:
        - flag set   -> output = target_bit
        - flag clear -> output = pc_bit

    Gates created for each bit of each jump:
        - not_sel: NOT(flag)
        - and_a:   pc_bit AND NOT(flag)
        - and_b:   target_bit AND flag
        - or:      and_a OR and_b

    Args:
        tensors: Mapping of tensor names to tensors; gates are added to it
            in place via ``add_gate``.
        addr_bits: Width of the program counter / jump target.
    """
    # (gate suffix, weights, bias) for one MUX cell, in creation order.
    mux_cell = (
        ("not_sel", (-1.0,), 0.0),
        ("and_a", (1.0, 1.0), -2.0),
        ("and_b", (1.0, 1.0), -2.0),
        ("or", (1.0, 1.0), -1.0),
    )

    for mnemonic in ("jz", "jnz", "jc", "jnc", "jp", "jn", "jv", "jnv"):
        for idx in range(addr_bits):
            cell = f"control.{mnemonic}.bit{idx}"
            for suffix, weights, bias in mux_cell:
                add_gate(tensors, f"{cell}.{suffix}", list(weights), [bias])
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
def add_status_flags(tensors: Dict[str, torch.Tensor], data_bits: int) -> None:
    """Register the gates that compute the Z, N, C and V status flags.

    Flags:
        - Z (Zero):     NOR of all result bits (1 if result == 0)
        - N (Negative): copy of the MSB (sign bit)
        - C (Carry):    carry out of the adder (external input)
        - V (Overflow): XOR of the carry into and out of the MSB
                        (signed overflow)

    Args:
        tensors: Mapping of tensor names to tensors; gates are added to it
            in place via ``add_gate``.
        data_bits: Width of the ALU data (8/16/32). Only the zero flag
            depends on it: that gate takes one input per data bit.
    """
    flag_gates = (
        # Z: one threshold gate over all bits, meant to fire only when the
        # sum of the bits is below 1 (presumably the activation fires at 0).
        ("flags.zero", [-1.0] * data_bits, [0.0]),
        # N: buffer for the MSB.
        ("flags.negative", [2.0], [-1.0]),
        # C: buffer for the adder's carry out.
        ("flags.carry", [2.0], [-1.0]),
        # V: two-layer XOR, (A OR B) AND (A NAND B).
        ("flags.overflow.layer1.or", [1.0, 1.0], [-1.0]),
        ("flags.overflow.layer1.nand", [-1.0, -1.0], [1.0]),
        ("flags.overflow.layer2", [1.0, 1.0], [-2.0]),
    )
    for gate_name, weights, bias in flag_gates:
        add_gate(tensors, gate_name, weights, bias)
|
| 688 |
+
|
| 689 |
+
|
| 690 |
def add_barrel_shifter(tensors: Dict[str, torch.Tensor]) -> None:
    """Legacy entry point that builds the 8-bit barrel shifter.

    All of the work is delegated to ``add_barrel_shifter_nbits`` with a
    fixed width of 8.

    Args:
        tensors: Mapping of tensor names to tensors; gates are added to it
            in place.
    """
    add_barrel_shifter_nbits(tensors, bits=8)
|
| 693 |
|
| 694 |
+
|
| 695 |
+
def add_barrel_shifter_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
    """Register the gates of an N-bit barrel shifter.

    The shifter is meant to shift its input by 0 to (bits-1) positions
    based on a ceil(log2(bits))-bit shift amount. It is built from layers
    of 2:1 muxes, one layer per shift-amount bit and one mux per data bit.

    Each mux consists of four gates named
    ``combinational.barrelshifter{bits}.layer{L}.bit{B}.<gate>``:
        - not_sel: NOT(sel)
        - and_a / and_b: the two gated mux inputs
        - or: combines and_a and and_b

    The gate parameters are identical for every layer. Which input is the
    "shifted" one, and by how much, is decided by the wiring outside this
    function.

    Args:
        tensors: Mapping of tensor names to tensors; gates are added to it
            in place via ``add_gate``.
        bits: Data width (8, 16, 32, etc.). Must be at least 1.

    Raises:
        ValueError: If ``bits`` is smaller than 1. The previous
            implementation raised the same type from ``math.log2``.
    """
    if bits < 1:
        raise ValueError(f"bits must be >= 1, got {bits}")

    # ceil(log2(bits)) in exact integer arithmetic, with at least one layer
    # so that a 1-bit shifter still gets a mux.
    num_layers = max(1, (bits - 1).bit_length())

    for layer in range(num_layers):
        for bit in range(bits):
            prefix = f"combinational.barrelshifter{bits}.layer{layer}.bit{bit}"
            # 2:1 mux: if sel then shifted else original
            add_gate(tensors, f"{prefix}.not_sel", [-1.0], [0.0])
            add_gate(tensors, f"{prefix}.and_a", [1.0, 1.0], [-2.0])
            add_gate(tensors, f"{prefix}.and_b", [1.0, 1.0], [-2.0])
            add_gate(tensors, f"{prefix}.or", [1.0, 1.0], [-1.0])
|
|
|
|
|
|
|
|
|
|
| 716 |
|
| 717 |
|
| 718 |
def add_priority_encoder(tensors: Dict[str, torch.Tensor]) -> None:
    """Legacy entry point that builds the 8-bit priority encoder.

    All of the work is delegated to ``add_priority_encoder_nbits`` with a
    fixed width of 8.

    Args:
        tensors: Mapping of tensor names to tensors; gates are added to it
            in place.
    """
    add_priority_encoder_nbits(tensors, bits=8)
|
| 721 |
+
|
| 722 |
|
| 723 |
+
def add_priority_encoder_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
    """Register the gates of an N-bit priority encoder.

    The encoder is meant to find the position of the highest set bit
    (0 to bits-1) and output it as a ceil(log2(bits))-bit index plus a
    valid flag. All gates are named under
    ``combinational.priorityencoder{bits}``.

    Args:
        tensors: Mapping of tensor names to tensors; gates are added to it
            in place via ``add_gate``.
        bits: Input width (8, 16, 32, etc.). Must be at least 1.

    Raises:
        ValueError: If ``bits`` is smaller than 1. The previous
            implementation raised the same type from ``math.log2``.
    """
    if bits < 1:
        raise ValueError(f"bits must be >= 1, got {bits}")

    # ceil(log2(bits)) in exact integer arithmetic, at least one output bit.
    out_bits = max(1, (bits - 1).bit_length())
    prefix = f"combinational.priorityencoder{bits}"

    # OR gates, one per position, each taking (bits - pos) inputs: "is any
    # bit set at or above this position".
    for pos in range(bits):
        add_gate(tensors, f"{prefix}.any_ge{pos}", [1.0] * (bits - pos), [-1.0])

    # Priority logic: pos N is highest if bit N is set AND no higher bit is set
    for pos in range(bits):
        add_gate(tensors, f"{prefix}.is_highest{pos}.not_higher", [-1.0], [0.0])
        add_gate(tensors, f"{prefix}.is_highest{pos}.and", [1.0, 1.0], [-2.0])

    # Encode position to output bits: output bit k is an OR over every
    # position whose binary index has bit k set.
    for out_bit in range(out_bits):
        fan_in = sum(1 for pos in range(bits) if (pos >> out_bit) & 1)
        # No position contributes when bits == 1; skip the gate then.
        if fan_in:
            add_gate(tensors, f"{prefix}.out{out_bit}", [1.0] * fan_in, [-1.0])

    # Valid flag: any bit set
    add_gate(tensors, f"{prefix}.valid", [1.0] * bits, [-1.0])
|
|
|
|
| 758 |
|
| 759 |
|
| 760 |
def add_comparators(tensors: Dict[str, torch.Tensor]) -> None:
|
|
|
|
| 998 |
add_gate(tensors, f"alu.alu{bits}bit.neg.inc.bit{bit}.carry", [1.0, 1.0], [-2.0])
|
| 999 |
|
| 1000 |
|
| 1001 |
+
def update_manifest(tensors: Dict[str, torch.Tensor], data_bits: int, addr_bits: int, mem_bytes: int) -> None:
    """Write the manifest metadata tensors into ``tensors``.

    Each value is stored as a one-element float32 tensor under a
    ``manifest.*`` key; existing entries are overwritten.

    Args:
        tensors: Mapping of tensor names to tensors, updated in place.
        data_bits: ALU/register width (8/16/32).
        addr_bits: Address bus width (determines memory size). It is stored
            under both ``manifest.addr_bits`` and ``manifest.pc_width``.
        mem_bytes: Memory size in bytes (2^addr_bits).
    """
    manifest_values = (
        ("manifest.data_bits", float(data_bits)),
        ("manifest.addr_bits", float(addr_bits)),
        ("manifest.memory_bytes", float(mem_bytes)),
        ("manifest.pc_width", float(addr_bits)),
        # Version was bumped to 4.0 for N-bit support.
        ("manifest.version", 4.0),
    )
    for key, value in manifest_values:
        tensors[key] = torch.tensor([value], dtype=torch.float32)
|
| 1014 |
|
| 1015 |
|
| 1016 |
def write_manifest(path: Path, tensors: Dict[str, torch.Tensor]) -> None:
|
|
|
|
| 2127 |
drop_prefixes(tensors, [
|
| 2128 |
"memory.addr_decode.", "memory.read.", "memory.write.",
|
| 2129 |
"control.fetch.ir.", "control.load.", "control.store.", "control.mem_addr.",
|
| 2130 |
+
"control.push.", "control.pop.", "control.ret.",
|
| 2131 |
+
"control.jz.", "control.jnz.", "control.jc.", "control.jnc.",
|
| 2132 |
+
"control.jp.", "control.jn.", "control.jv.", "control.jnv.",
|
| 2133 |
+
"flags.",
|
| 2134 |
])
|
| 2135 |
print(f" Now {len(tensors)} tensors")
|
| 2136 |
|
|
|
|
| 2143 |
|
| 2144 |
print("\nGenerating buffer gates...")
|
| 2145 |
try:
|
| 2146 |
+
add_fetch_load_store_buffers(tensors, args.bits, addr_bits)
|
| 2147 |
+
print(f" Added fetch/load/store/mem_addr buffers ({args.bits}-bit data, {addr_bits}-bit addr)")
|
| 2148 |
except ValueError as e:
|
| 2149 |
print(f" Buffers already exist: {e}")
|
| 2150 |
+
|
| 2151 |
+
print("\nGenerating stack operation circuits...")
|
| 2152 |
+
try:
|
| 2153 |
+
add_stack_ops(tensors, args.bits, addr_bits)
|
| 2154 |
+
sp_gates = addr_bits * 4 * 2 # SP inc/dec gates
|
| 2155 |
+
data_gates = args.bits * 2 # PUSH/POP data buffers
|
| 2156 |
+
ret_gates = addr_bits # RET address buffers
|
| 2157 |
+
total_gates = sp_gates + data_gates + ret_gates
|
| 2158 |
+
print(f" Added PUSH/POP/RET ({total_gates} gates: {args.bits}-bit data, {addr_bits}-bit SP)")
|
| 2159 |
+
except ValueError as e:
|
| 2160 |
+
print(f" Stack ops already exist: {e}")
|
| 2161 |
+
|
| 2162 |
+
print("\nGenerating conditional jump circuits...")
|
| 2163 |
+
try:
|
| 2164 |
+
add_conditional_jumps(tensors, addr_bits)
|
| 2165 |
+
jump_gates = 8 * addr_bits * 4 # 8 jump types × addr_bits × 4 gates each
|
| 2166 |
+
print(f" Added JZ/JNZ/JC/JNC/JP/JN/JV/JNV ({jump_gates} gates: {addr_bits}-bit addresses)")
|
| 2167 |
+
except ValueError as e:
|
| 2168 |
+
print(f" Conditional jumps already exist: {e}")
|
| 2169 |
+
|
| 2170 |
+
print("\nGenerating status flag circuits...")
|
| 2171 |
+
try:
|
| 2172 |
+
add_status_flags(tensors, args.bits)
|
| 2173 |
+
print(f" Added Z/N/C/V flags ({args.bits}-bit aware)")
|
| 2174 |
+
except ValueError as e:
|
| 2175 |
+
print(f" Status flags already exist: {e}")
|
| 2176 |
else:
|
| 2177 |
print("\nSkipping memory circuits (addr_bits=0, pure ALU mode)")
|
| 2178 |
|
| 2179 |
print("\nUpdating manifest...")
|
| 2180 |
+
update_manifest(tensors, args.bits, addr_bits, mem_bytes)
|
| 2181 |
+
print(f" data_bits={args.bits}, addr_bits={addr_bits}, memory_bytes={mem_bytes:,}")
|
| 2182 |
|
| 2183 |
if args.apply:
|
| 2184 |
print(f"\nSaving: {args.model}")
|
|
|
|
| 2242 |
"arithmetic.greaterthan8bit.", "arithmetic.lessthan8bit.",
|
| 2243 |
"arithmetic.greaterorequal8bit.", "arithmetic.lessorequal8bit.",
|
| 2244 |
"arithmetic.equality8bit.", "arithmetic.add3_8bit.", "arithmetic.expr_add_mul.", "arithmetic.expr_paren.",
|
|
|
|
| 2245 |
"combinational.barrelshifter.", "combinational.priorityencoder.",
|
| 2246 |
]
|
| 2247 |
|
|
|
|
| 2292 |
print(" Added ROL (8 gates), ROR (8 gates)")
|
| 2293 |
except ValueError as e:
|
| 2294 |
print(f" ROL/ROR already exist: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2295 |
print("\nGenerating barrel shifter...")
|
| 2296 |
try:
|
| 2297 |
add_barrel_shifter(tensors)
|
eval.py
CHANGED
|
@@ -67,6 +67,21 @@ def load_metadata(path: str = MODEL_PATH) -> Dict:
|
|
| 67 |
return {'signal_registry': {}}
|
| 68 |
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
def create_population(
|
| 71 |
base_tensors: Dict[str, torch.Tensor],
|
| 72 |
pop_size: int,
|
|
@@ -889,7 +904,7 @@ class BatchedFitnessEvaluator:
|
|
| 889 |
Tests all circuits comprehensively.
|
| 890 |
"""
|
| 891 |
|
| 892 |
-
def __init__(self, device: str = 'cuda', model_path: str = MODEL_PATH):
|
| 893 |
self.device = device
|
| 894 |
self.model_path = model_path
|
| 895 |
self.metadata = load_metadata(model_path)
|
|
@@ -897,6 +912,16 @@ class BatchedFitnessEvaluator:
|
|
| 897 |
self.results: List[CircuitResult] = []
|
| 898 |
self.category_scores: Dict[str, Tuple[float, int]] = {}
|
| 899 |
self.total_tests = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 900 |
self._setup_tests()
|
| 901 |
|
| 902 |
def _setup_tests(self):
|
|
@@ -2897,7 +2922,7 @@ class BatchedFitnessEvaluator:
|
|
| 2897 |
# =========================================================================
|
| 2898 |
|
| 2899 |
def _test_conditional_jump(self, pop: Dict, name: str, debug: bool) -> Tuple[torch.Tensor, int]:
|
| 2900 |
-
"""Test conditional jump circuit."""
|
| 2901 |
pop_size = next(iter(pop.values())).shape[0]
|
| 2902 |
prefix = f'control.{name}'
|
| 2903 |
|
|
@@ -2911,7 +2936,7 @@ class BatchedFitnessEvaluator:
|
|
| 2911 |
scores = torch.zeros(pop_size, device=self.device)
|
| 2912 |
total = 0
|
| 2913 |
|
| 2914 |
-
for bit in range(
|
| 2915 |
bit_prefix = f'{prefix}.bit{bit}'
|
| 2916 |
try:
|
| 2917 |
# NOT sel
|
|
@@ -2979,27 +3004,34 @@ class BatchedFitnessEvaluator:
|
|
| 2979 |
return scores, total
|
| 2980 |
|
| 2981 |
def _test_stack_ops(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
|
| 2982 |
-
"""Test PUSH/POP/RET stack operation circuits."""
|
| 2983 |
pop_size = next(iter(pop.values())).shape[0]
|
| 2984 |
scores = torch.zeros(pop_size, device=self.device)
|
| 2985 |
total = 0
|
|
|
|
|
|
|
| 2986 |
|
| 2987 |
if debug:
|
| 2988 |
-
print("\n=== STACK OPERATIONS ===")
|
| 2989 |
|
| 2990 |
-
# Test PUSH SP decrement (
|
| 2991 |
try:
|
| 2992 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2993 |
op_scores = torch.zeros(pop_size, device=self.device)
|
| 2994 |
op_total = 0
|
| 2995 |
|
| 2996 |
for sp_val in sp_tests:
|
| 2997 |
-
expected_val = (sp_val - 1) &
|
| 2998 |
-
sp_bits = [float((sp_val >> (
|
| 2999 |
|
| 3000 |
borrow = 1.0
|
| 3001 |
out_bits = []
|
| 3002 |
-
for bit in range(
|
| 3003 |
prefix = f'control.push.sp_dec.bit{bit}'
|
| 3004 |
|
| 3005 |
w_or = pop[f'{prefix}.xor.layer1.or.weight'].view(pop_size, 2)
|
|
@@ -3024,11 +3056,11 @@ class BatchedFitnessEvaluator:
|
|
| 3024 |
borrow = heaviside((borrow_inp * w_borrow).sum(-1) + b_borrow)[0].item()
|
| 3025 |
|
| 3026 |
out = torch.stack(out_bits, dim=-1)
|
| 3027 |
-
expected = torch.tensor([((expected_val >> (
|
| 3028 |
device=self.device, dtype=torch.float32)
|
| 3029 |
correct = (out == expected.unsqueeze(0)).float().sum(1)
|
| 3030 |
op_scores += correct
|
| 3031 |
-
op_total +=
|
| 3032 |
|
| 3033 |
scores += op_scores
|
| 3034 |
total += op_total
|
|
@@ -3040,18 +3072,18 @@ class BatchedFitnessEvaluator:
|
|
| 3040 |
if debug:
|
| 3041 |
print(f" control.push.sp_dec: SKIP ({e})")
|
| 3042 |
|
| 3043 |
-
# Test POP SP increment (
|
| 3044 |
try:
|
| 3045 |
op_scores = torch.zeros(pop_size, device=self.device)
|
| 3046 |
op_total = 0
|
| 3047 |
|
| 3048 |
for sp_val in sp_tests:
|
| 3049 |
-
expected_val = (sp_val + 1) &
|
| 3050 |
-
sp_bits = [float((sp_val >> (
|
| 3051 |
|
| 3052 |
carry = 1.0
|
| 3053 |
out_bits = []
|
| 3054 |
-
for bit in range(
|
| 3055 |
prefix = f'control.pop.sp_inc.bit{bit}'
|
| 3056 |
|
| 3057 |
w_or = pop[f'{prefix}.xor.layer1.or.weight'].view(pop_size, 2)
|
|
@@ -3074,11 +3106,11 @@ class BatchedFitnessEvaluator:
|
|
| 3074 |
carry = heaviside((inp * w_carry).sum(-1) + b_carry)[0].item()
|
| 3075 |
|
| 3076 |
out = torch.stack(out_bits, dim=-1)
|
| 3077 |
-
expected = torch.tensor([((expected_val >> (
|
| 3078 |
device=self.device, dtype=torch.float32)
|
| 3079 |
correct = (out == expected.unsqueeze(0)).float().sum(1)
|
| 3080 |
op_scores += correct
|
| 3081 |
-
op_total +=
|
| 3082 |
|
| 3083 |
scores += op_scores
|
| 3084 |
total += op_total
|
|
@@ -3090,27 +3122,29 @@ class BatchedFitnessEvaluator:
|
|
| 3090 |
if debug:
|
| 3091 |
print(f" control.pop.sp_inc: SKIP ({e})")
|
| 3092 |
|
| 3093 |
-
# Test RET address buffer (
|
| 3094 |
try:
|
| 3095 |
op_scores = torch.zeros(pop_size, device=self.device)
|
| 3096 |
op_total = 0
|
| 3097 |
|
| 3098 |
-
|
| 3099 |
-
|
| 3100 |
-
|
|
|
|
|
|
|
| 3101 |
device=self.device, dtype=torch.float32)
|
| 3102 |
|
| 3103 |
out_bits = []
|
| 3104 |
-
for bit in range(
|
| 3105 |
w = pop[f'control.ret.addr.bit{bit}.weight'].view(pop_size)
|
| 3106 |
b = pop[f'control.ret.addr.bit{bit}.bias'].view(pop_size)
|
| 3107 |
-
out = heaviside(
|
| 3108 |
out_bits.append(out)
|
| 3109 |
|
| 3110 |
out = torch.stack(out_bits, dim=-1)
|
| 3111 |
-
correct = (out ==
|
| 3112 |
op_scores += correct
|
| 3113 |
-
op_total +=
|
| 3114 |
|
| 3115 |
scores += op_scores
|
| 3116 |
total += op_total
|
|
|
|
| 67 |
return {'signal_registry': {}}
|
| 68 |
|
| 69 |
|
| 70 |
+
def get_manifest(tensors: Dict[str, torch.Tensor]) -> Dict[str, float]:
    """Extract the manifest values stored in ``tensors``.

    Legacy models without manifest entries fall back to 8-bit data,
    16-bit addresses, 65536 bytes of memory and version 1.0. When
    ``manifest.addr_bits`` is missing, ``manifest.pc_width`` is used
    before the 16-bit default.

    Args:
        tensors: Mapping of tensor names to tensors. Manifest entries are
            read with ``.item()``, so each must hold a single element.

    Returns:
        A dict with the keys ``data_bits``, ``addr_bits`` and
        ``memory_bytes`` (ints, truncated from the stored floats) and
        ``version`` (a float).
    """
    def _scalar(key, default):
        # Read a one-element tensor as a Python number, or return
        # ``default`` without allocating a placeholder tensor.
        entry = tensors.get(key)
        return default if entry is None else entry.item()

    addr_bits = _scalar('manifest.addr_bits', None)
    if addr_bits is None:
        # Fall back to the legacy key, then to the 16-bit default.
        addr_bits = _scalar('manifest.pc_width', 16.0)

    return {
        'data_bits': int(_scalar('manifest.data_bits', 8.0)),
        'addr_bits': int(addr_bits),
        'memory_bytes': int(_scalar('manifest.memory_bytes', 65536.0)),
        'version': float(_scalar('manifest.version', 1.0)),
    }
|
| 83 |
+
|
| 84 |
+
|
| 85 |
def create_population(
|
| 86 |
base_tensors: Dict[str, torch.Tensor],
|
| 87 |
pop_size: int,
|
|
|
|
| 904 |
Tests all circuits comprehensively.
|
| 905 |
"""
|
| 906 |
|
| 907 |
+
def __init__(self, device: str = 'cuda', model_path: str = MODEL_PATH, tensors: Dict[str, torch.Tensor] = None):
|
| 908 |
self.device = device
|
| 909 |
self.model_path = model_path
|
| 910 |
self.metadata = load_metadata(model_path)
|
|
|
|
| 912 |
self.results: List[CircuitResult] = []
|
| 913 |
self.category_scores: Dict[str, Tuple[float, int]] = {}
|
| 914 |
self.total_tests = 0
|
| 915 |
+
|
| 916 |
+
# Get manifest for N-bit support
|
| 917 |
+
if tensors is not None:
|
| 918 |
+
self.manifest = get_manifest(tensors)
|
| 919 |
+
else:
|
| 920 |
+
base_tensors = load_model(model_path)
|
| 921 |
+
self.manifest = get_manifest(base_tensors)
|
| 922 |
+
self.data_bits = self.manifest['data_bits']
|
| 923 |
+
self.addr_bits = self.manifest['addr_bits']
|
| 924 |
+
|
| 925 |
self._setup_tests()
|
| 926 |
|
| 927 |
def _setup_tests(self):
|
|
|
|
| 2922 |
# =========================================================================
|
| 2923 |
|
| 2924 |
def _test_conditional_jump(self, pop: Dict, name: str, debug: bool) -> Tuple[torch.Tensor, int]:
|
| 2925 |
+
"""Test conditional jump circuit (N-bit address aware)."""
|
| 2926 |
pop_size = next(iter(pop.values())).shape[0]
|
| 2927 |
prefix = f'control.{name}'
|
| 2928 |
|
|
|
|
| 2936 |
scores = torch.zeros(pop_size, device=self.device)
|
| 2937 |
total = 0
|
| 2938 |
|
| 2939 |
+
for bit in range(self.addr_bits):
|
| 2940 |
bit_prefix = f'{prefix}.bit{bit}'
|
| 2941 |
try:
|
| 2942 |
# NOT sel
|
|
|
|
| 3004 |
return scores, total
|
| 3005 |
|
| 3006 |
def _test_stack_ops(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
|
| 3007 |
+
"""Test PUSH/POP/RET stack operation circuits (N-bit address aware)."""
|
| 3008 |
pop_size = next(iter(pop.values())).shape[0]
|
| 3009 |
scores = torch.zeros(pop_size, device=self.device)
|
| 3010 |
total = 0
|
| 3011 |
+
addr_bits = self.addr_bits
|
| 3012 |
+
addr_mask = (1 << addr_bits) - 1
|
| 3013 |
|
| 3014 |
if debug:
|
| 3015 |
+
print(f"\n=== STACK OPERATIONS ({addr_bits}-bit SP) ===")
|
| 3016 |
|
| 3017 |
+
# Test PUSH SP decrement (addr_bits wide, borrow chain)
|
| 3018 |
try:
|
| 3019 |
+
# Generate test values appropriate for addr_bits
|
| 3020 |
+
sp_tests = [0, 1, addr_mask // 2, addr_mask]
|
| 3021 |
+
if addr_bits >= 8:
|
| 3022 |
+
sp_tests.append(0x100 & addr_mask)
|
| 3023 |
+
if addr_bits >= 12:
|
| 3024 |
+
sp_tests.append(0x1234 & addr_mask)
|
| 3025 |
op_scores = torch.zeros(pop_size, device=self.device)
|
| 3026 |
op_total = 0
|
| 3027 |
|
| 3028 |
for sp_val in sp_tests:
|
| 3029 |
+
expected_val = (sp_val - 1) & addr_mask
|
| 3030 |
+
sp_bits = [float((sp_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)]
|
| 3031 |
|
| 3032 |
borrow = 1.0
|
| 3033 |
out_bits = []
|
| 3034 |
+
for bit in range(addr_bits - 1, -1, -1): # LSB to MSB
|
| 3035 |
prefix = f'control.push.sp_dec.bit{bit}'
|
| 3036 |
|
| 3037 |
w_or = pop[f'{prefix}.xor.layer1.or.weight'].view(pop_size, 2)
|
|
|
|
| 3056 |
borrow = heaviside((borrow_inp * w_borrow).sum(-1) + b_borrow)[0].item()
|
| 3057 |
|
| 3058 |
out = torch.stack(out_bits, dim=-1)
|
| 3059 |
+
expected = torch.tensor([((expected_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)],
|
| 3060 |
device=self.device, dtype=torch.float32)
|
| 3061 |
correct = (out == expected.unsqueeze(0)).float().sum(1)
|
| 3062 |
op_scores += correct
|
| 3063 |
+
op_total += addr_bits
|
| 3064 |
|
| 3065 |
scores += op_scores
|
| 3066 |
total += op_total
|
|
|
|
| 3072 |
if debug:
|
| 3073 |
print(f" control.push.sp_dec: SKIP ({e})")
|
| 3074 |
|
| 3075 |
+
# Test POP SP increment (addr_bits wide, carry chain)
|
| 3076 |
try:
|
| 3077 |
op_scores = torch.zeros(pop_size, device=self.device)
|
| 3078 |
op_total = 0
|
| 3079 |
|
| 3080 |
for sp_val in sp_tests:
|
| 3081 |
+
expected_val = (sp_val + 1) & addr_mask
|
| 3082 |
+
sp_bits = [float((sp_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)]
|
| 3083 |
|
| 3084 |
carry = 1.0
|
| 3085 |
out_bits = []
|
| 3086 |
+
for bit in range(addr_bits - 1, -1, -1): # LSB to MSB
|
| 3087 |
prefix = f'control.pop.sp_inc.bit{bit}'
|
| 3088 |
|
| 3089 |
w_or = pop[f'{prefix}.xor.layer1.or.weight'].view(pop_size, 2)
|
|
|
|
| 3106 |
carry = heaviside((inp * w_carry).sum(-1) + b_carry)[0].item()
|
| 3107 |
|
| 3108 |
out = torch.stack(out_bits, dim=-1)
|
| 3109 |
+
expected = torch.tensor([((expected_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)],
|
| 3110 |
device=self.device, dtype=torch.float32)
|
| 3111 |
correct = (out == expected.unsqueeze(0)).float().sum(1)
|
| 3112 |
op_scores += correct
|
| 3113 |
+
op_total += addr_bits
|
| 3114 |
|
| 3115 |
scores += op_scores
|
| 3116 |
total += op_total
|
|
|
|
| 3122 |
if debug:
|
| 3123 |
print(f" control.pop.sp_inc: SKIP ({e})")
|
| 3124 |
|
| 3125 |
+
# Test RET address buffer (addr_bits identity gates)
|
| 3126 |
try:
|
| 3127 |
op_scores = torch.zeros(pop_size, device=self.device)
|
| 3128 |
op_total = 0
|
| 3129 |
|
| 3130 |
+
ret_tests = [0, addr_mask, addr_mask // 2, 1]
|
| 3131 |
+
if addr_bits >= 12:
|
| 3132 |
+
ret_tests.append(0x1234 & addr_mask)
|
| 3133 |
+
for addr_val in ret_tests:
|
| 3134 |
+
ret_bits_tensor = torch.tensor([float((addr_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)],
|
| 3135 |
device=self.device, dtype=torch.float32)
|
| 3136 |
|
| 3137 |
out_bits = []
|
| 3138 |
+
for bit in range(addr_bits):
|
| 3139 |
w = pop[f'control.ret.addr.bit{bit}.weight'].view(pop_size)
|
| 3140 |
b = pop[f'control.ret.addr.bit{bit}.bias'].view(pop_size)
|
| 3141 |
+
out = heaviside(ret_bits_tensor[bit] * w + b)
|
| 3142 |
out_bits.append(out)
|
| 3143 |
|
| 3144 |
out = torch.stack(out_bits, dim=-1)
|
| 3145 |
+
correct = (out == ret_bits_tensor.unsqueeze(0)).float().sum(1)
|
| 3146 |
op_scores += correct
|
| 3147 |
+
op_total += addr_bits
|
| 3148 |
|
| 3149 |
scores += op_scores
|
| 3150 |
total += op_total
|
neural_alu32.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b5a0f6cdfb4ba0ebdfc863f43e5f8fd4f41626c0fd4e7258a0a581a117a79d97
|
| 3 |
+
size 5031612
|
neural_computer.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:08a39c4758f6e5236f84d231be7f2d54364099309a89cf484d607a6544194d20
|
| 3 |
+
size 2591660
|