CharlesCNorton commited on
Commit
1f44a34
·
1 Parent(s): 8a1465b

Add 32-bit ALU support with 1KB memory profile

Browse files

- build.py: Add --bits {8,16,32} flag for N-bit circuit generation
- build.py: Add 'small' memory profile (1KB, 10-bit addresses)
- build.py: Add 32-bit generators for adder, subtractor, comparators,
multiplier, divider, bitwise ops, shifts, inc/dec, neg
- eval.py: Add 32-bit test data and comparator testing
- README.md: Document 32-bit support, pivot to from-scratch extractor

32-bit adder verified: 1M + 2M = 3M, 0xDEAD0000 + 0xBEEF = 0xDEADBEEF

TODO:
- Add missing 32-bit eval tests (sub, mul, div, bitwise, shifts)
- Fix 32-bit comparator precision (float32 mantissa overflow on 2^31 weights)
Planned fix: cascaded byte-wise comparison

Files changed (4) hide show
  1. README.md +109 -76
  2. build.py +266 -9
  3. eval.py +140 -0
  4. neural_alu32.safetensors +3 -0
README.md CHANGED
@@ -12,30 +12,31 @@ tags:
12
 
13
  # 8bit-threshold-computer
14
 
15
- **A Turing-complete 8-bit CPU implemented entirely as threshold logic gates.**
16
 
17
  Every logic gate is a threshold neuron: `output = 1 if (Σ wᵢxᵢ + b) ≥ 0 else 0`
18
 
19
  ```
20
- Tensors: 11,581
21
- Parameters: 8,290,134 (full CPU) / 32,397 (pure ALU for LLM)
22
  ```
23
 
24
  ---
25
 
26
  ## What Is This?
27
 
28
- A complete 8-bit processor where every operation—from Boolean logic to arithmetic to control flow—is implemented using only weighted sums and step functions. No traditional gates.
29
 
30
- | Component | Specification |
31
- |-----------|---------------|
32
- | Registers | 4 × 8-bit general purpose |
33
- | Memory | Configurable: 0B (pure ALU) to 64KB (full CPU) |
34
- | ALU | 16 operations (ADD, SUB, AND, OR, XOR, NOT, SHL, SHR, MUL, DIV, INC, DEC, NEG, ROL, ROR, CMP) |
35
- | Flags | Zero, Negative, Carry, Overflow |
36
- | Control | JMP, JZ, JNZ, JC, JNC, JN, JP, JV, JNV, CALL, RET, PUSH, POP |
 
37
 
38
- **Turing complete.** Verified with loops, conditionals, recursion, and self-modification.
39
 
40
  ---
41
 
@@ -583,27 +584,29 @@ Head-to-head on 50 random cases: SmolLM2 got 7/50 (14%), circuits got 50/50 (100
583
 
584
  **Stage 3: LLM Integration — IN PROGRESS**
585
 
586
- The actual challenge: train an interface that extracts operands and operations from LLM hidden states (not from pre-formatted bit inputs).
587
 
588
  ```
589
  "47 + 86"
590
 
591
- [SmolLM2 hidden states: (seq_len, 960)]
592
 
593
- Extractor (must LEARN: hidden → a_bits, b_bits, op_logits)
594
 
595
  [Frozen threshold circuits]
596
 
597
  [Result bits] → 133
598
  ```
599
 
600
- **Training Infrastructure** (`train.py`):
 
 
601
 
602
  | Mode | Description | Status |
603
  |------|-------------|--------|
604
  | `--mode router` | Train OpRouter with ground truth bits | 100% achieved |
605
  | `--mode interface` | Train BitEncoder + OpRouter | Ready |
606
- | `--mode llm` | Train from LLM hidden states | Active development |
607
 
608
  **LLM Mode Options**:
609
  - `--unfreeze_layers N`: Fine-tune top N transformer layers
@@ -615,95 +618,125 @@ Extractor (must LEARN: hidden → a_bits, b_bits, op_logits)
615
  - `Extractor`: Attention pooling + per-bit MLPs
616
  - `PositionExtractor`: Position-aware (operand A from positions 0-2, B from 5-7)
617
  - `DigitExtractor`: Predicts 3 digits per operand, converts to bits
 
618
 
619
  **Curriculum Learning**: Training progresses 0-9 → 0-99 → 0-255 over epochs.
620
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
  #### Proof of Concept Scope
622
 
623
- - **8-bit operands** (0-255)
624
  - **Six operations**: ADD, SUB, MUL, GT, LT, EQ
625
- - **Pure ALU profile** (no memory access)
626
 
627
  **Current Status**:
628
- - Circuit validation: Complete (100% on all operations)
629
- - LLM baseline: Measured (11.90%)
630
- - SmolLM2 architecture analysis: Complete (see `SMOLLM2_ARCHITECTURE.md`)
631
- - Extraction training: In progress
 
632
 
633
  ### Extension Roadmap
634
 
635
- The following extensions are planned after proof-of-concept validation:
636
-
637
- 1. **16-bit operations (0-65535)** — Chain two 8-bit circuits with carry propagation. ADD16: low = ADD8(A_lo, B_lo), high = ADD8(A_hi, B_hi, carry_out). MUL16: four partial products + shift-add. Doubles operand extraction width. This extension is a priority as it dramatically expands the useful range of arithmetic operations.
638
 
639
- 2. **Parenthetical expressions ((5 + 3) × 2 = 16)** — Explicit grouping overrides precedence. Parser must recognize parens and build correct tree. Evaluation proceeds innermost-out. Adds complexity to extraction layer.
 
 
 
 
 
 
 
640
 
641
- 3. **Multi-operation chains (a + b - c × d)** Sequential dispatch through multiple circuits with intermediate result routing. Requires state management in interface layers.
642
 
643
- 4. **Floating point arithmetic** IEEE 754-style with separate circuits for mantissa and exponent. ADD: align exponents, add mantissas, renormalize. MUL: add exponents, multiply mantissas. Requires sign handling, overflow detection, and rounding logic.
644
 
645
- 5. **Full CPU integration** — Enable memory access circuits for stateful computation. Allows multi-step algorithms executed entirely within threshold logic.
646
 
647
- ### Completed Extensions
648
 
649
- - **3-operand addition (15 + 27 + 33 = 75)** `arithmetic.add3_8bit` chains two 8-bit ripple carry stages. 16 full adders, 144 gates, 240 test cases verified.
650
-
651
- - **Order of operations (5 + 3 × 2 = 11)** — `arithmetic.expr_add_mul` computes A + (B × C) using shift-add multiplication then addition. 64 AND gates + 64 full adders, 73 test cases verified.
652
-
653
- ---
654
 
655
- ## Files
656
 
657
- ### Core
658
 
659
- | File | Description |
660
- |------|-------------|
661
- | `neural_computer.safetensors` | Frozen threshold circuits (~8.29M params full, ~32K pure ALU) |
662
- | `eval.py` | Unified evaluation suite (GPU-batched, exhaustive testing) |
663
- | `build.py` | Circuit generator with configurable memory profiles |
664
- | `prune_weights.py` | Weight magnitude pruning (GPU-batched, binary search conflict resolution) |
665
 
666
- ### LLM Integration (`llm_integration/`)
667
 
668
- | File | Description |
669
- |------|-------------|
670
- | `SMOLLM2_ARCHITECTURE.md` | Complete technical analysis of SmolLM2-360M (layers, weights, tokenization) |
671
- | `baseline.py` | SmolLM2-360M vanilla arithmetic evaluation (11.90% baseline) |
672
- | `circuits.py` | Frozen threshold circuit wrapper with STE gradients |
673
- | `fitness.py` | Shared fitness function (randomized arithmetic, no answer supervision) |
674
- | `model.py` | Interface layers: `BitEncoder`, `OpRouter`, `Extractor`, `ArithmeticModel` |
675
- | `train.py` | Unified training: `--mode router`, `--mode interface`, `--mode llm` |
676
- | `trained/router.pt` | Trained OpRouter checkpoint (100% with ground truth bits) |
677
 
678
- ### Build Tool Usage
679
 
680
  ```bash
681
- # Full CPU (64KB memory, default)
682
- python build.py memory --apply
 
 
683
 
684
- # LLM integration profiles
685
- python build.py --memory-profile none memory --apply # Pure ALU (32K params)
686
- python build.py --memory-profile registers memory --apply # 16-byte register file
687
- python build.py --memory-profile scratchpad memory --apply # 256-byte scratchpad
688
 
689
- # Custom memory size
690
- python build.py --addr-bits 6 memory --apply # 64 bytes (2^6)
691
-
692
- # Regenerate ALU and input metadata
693
- python build.py alu --apply
694
- python build.py inputs --apply
695
- python build.py all --apply # memory + alu + inputs
696
  ```
697
 
698
- Memory profiles:
 
 
 
 
 
 
 
 
699
 
700
- | Profile | Addr Bits | Memory | Memory Params | Total Params |
701
- |---------|-----------|--------|---------------|--------------|
702
- | `none` | 0 | 0B | 0 | ~32K |
703
- | `registers` | 4 | 16B | ~2K | ~34K |
704
- | `scratchpad` | 8 | 256B | ~30K | ~63K |
705
- | `reduced` | 12 | 4KB | ~516K | ~549K |
706
- | `full` | 16 | 64KB | ~8.26M | ~8.29M |
 
707
 
708
  ---
709
 
 
12
 
13
  # 8bit-threshold-computer
14
 
15
+ **A Turing-complete CPU implemented entirely as threshold logic gates, with 8-bit and 32-bit ALU support.**
16
 
17
  Every logic gate is a threshold neuron: `output = 1 if (Σ wᵢxᵢ + b) ≥ 0 else 0`
18
 
19
  ```
20
+ 8-bit CPU: 8,290,134 params (full) / 32,397 params (pure ALU)
21
+ 32-bit ALU: 202,869 params (1KB scratch memory)
22
  ```
23
 
24
  ---
25
 
26
  ## What Is This?
27
 
28
+ A complete processor where every operation—from Boolean logic to arithmetic to control flow—is implemented using only weighted sums and step functions. No traditional gates.
29
 
30
+ | Component | 8-bit CPU | 32-bit ALU |
31
+ |-----------|-----------|------------|
32
+ | Registers | 4 × 8-bit | N/A (pure computation) |
33
+ | Memory | 0B–64KB configurable | 1KB scratch |
34
+ | ALU | 16 ops @ 8-bit | ADD, SUB, MUL, DIV, CMP, bitwise, shifts |
35
+ | Precision | 0–255 | 0–4,294,967,295 |
36
+ | Flags | Z, N, C, V | Carry/overflow |
37
+ | Control | Full ISA | Stateless |
38
 
39
+ **Turing complete.** The 8-bit CPU is verified with loops, conditionals, recursion, and self-modification. The 32-bit ALU extends arithmetic to practical ranges (0–4B) where 8-bit (0–255) is insufficient.
40
 
41
  ---
42
 
 
584
 
585
  **Stage 3: LLM Integration — IN PROGRESS**
586
 
587
+ The challenge: train an interface that extracts operands and operations from natural language (not from pre-formatted bit inputs).
588
 
589
  ```
590
  "47 + 86"
591
 
592
+ [Language Model / Extractor]
593
 
594
+ [a_bits, b_bits, op_logits]
595
 
596
  [Frozen threshold circuits]
597
 
598
  [Result bits] → 133
599
  ```
600
 
601
+ **SmolLM2 Approach** (`llm_integration/`):
602
+
603
+ Initial experiments used SmolLM2-360M-Instruct as the language understanding backbone.
604
 
605
  | Mode | Description | Status |
606
  |------|-------------|--------|
607
  | `--mode router` | Train OpRouter with ground truth bits | 100% achieved |
608
  | `--mode interface` | Train BitEncoder + OpRouter | Ready |
609
+ | `--mode llm` | Train from LLM hidden states | Explored |
610
 
611
  **LLM Mode Options**:
612
  - `--unfreeze_layers N`: Fine-tune top N transformer layers
 
618
  - `Extractor`: Attention pooling + per-bit MLPs
619
  - `PositionExtractor`: Position-aware (operand A from positions 0-2, B from 5-7)
620
  - `DigitExtractor`: Predicts 3 digits per operand, converts to bits
621
+ - `HybridExtractor`: Digit lookup + MLP fallback for word inputs
622
 
623
  **Curriculum Learning**: Training progresses 0-9 → 0-99 → 0-255 over epochs.
624
 
625
+ **Observations**: SmolLM2 integration proved challenging—360M parameters of pre-trained representations largely irrelevant to arithmetic parsing, high VRAM requirements, and gradient conflicts between frozen circuits and pre-trained weights.
626
+
627
+ **Pivot: From-Scratch Extractor**
628
+
629
+ Given that the task is fundamentally simple—parse `(a, b, op)` from structured text—a lightweight purpose-built model may be more appropriate than adapting a general LLM.
630
+
631
+ ```
632
+ "one thousand plus two thousand"
633
+
634
+ [Char-level tokenizer: ~40 tokens]
635
+
636
+ [Small transformer: ~1-5M params]
637
+
638
+ [3 heads: a_value, b_value, op_idx]
639
+
640
+ [Frozen 32-bit threshold circuits]
641
+
642
+ 3000
643
+ ```
644
+
645
+ **Design principles**:
646
+ - **Minimal Python**: All parsing logic learned in weights, not hardcoded
647
+ - **Character-level input**: No word tokenization; model learns "forty seven" = 47
648
+ - **From-scratch training**: No pre-trained weights to conflict with
649
+ - **32-bit target**: Practical arithmetic range (0–4,294,967,295)
650
+
651
+ **Planned architecture**:
652
+ - Vocab: ~40 chars (a-z, 0-9, space, operators)
653
+ - Embedding: 40 × 128d
654
+ - Encoder: 2-3 transformer layers
655
+ - Output heads: `a_classifier`, `b_classifier`, `op_classifier`
656
+ - Total: ~1-5M params (vs 360M for SmolLM2)
657
+
658
+ This approach treats the problem as what it is: a structured parsing task where the frozen circuits handle all computation. The extractor need only learn the mapping from text to operands—no world knowledge required.
659
+
660
  #### Proof of Concept Scope
661
 
662
+ - **32-bit operands** (0–4,294,967,295)
663
  - **Six operations**: ADD, SUB, MUL, GT, LT, EQ
664
+ - **Structured input**: Digits ("1000 + 2000") and number words ("one thousand plus two thousand")
665
 
666
  **Current Status**:
667
+ - Circuit validation: Complete (100% on 8-bit operations)
668
+ - 32-bit circuits: Built and tested (adder verified on 1M+2M=3M, etc.)
669
+ - LLM baseline: Measured (11.90% - establishes control condition)
670
+ - SmolLM2 integration: Infrastructure complete, training explored
671
+ - From-scratch extractor: Design phase
672
 
673
  ### Extension Roadmap
674
 
675
+ #### Completed
 
 
676
 
677
+ 1. **32-bit operations (0–4,294,967,295)** — Full 32-bit ALU implemented via `--bits 32` flag:
678
+ - 32-bit ripple carry adder (32 chained full adders) — **verified**
679
+ - 32-bit subtractor (NOT + adder with carry-in)
680
+ - 32-bit multiplication (1024 partial product ANDs)
681
+ - 32-bit division (32 restoring stages)
682
+ - 32-bit comparators (GT, LT, GE, LE, EQ)
683
+ - 32-bit bitwise ops (AND, OR, XOR, NOT)
684
+ - 32-bit shifts (SHL, SHR), INC, DEC, NEG
685
 
686
+ **Known issue**: Single-layer 32-bit comparators use weights up to 2³¹, which exceeds float32 mantissa precision (24 bits). Comparisons between large numbers differing only in low bits may fail. Fix planned: cascaded byte-wise comparison (compare MSB first, if equal compare next byte, etc.).
687
 
688
+ 2. **3-operand addition (15 + 27 + 33 = 75)** `arithmetic.add3_8bit` chains two 8-bit ripple carry stages. 16 full adders, 144 gates, 240 test cases verified.
689
 
690
+ 3. **Order of operations (5 + 3 × 2 = 11)** — `arithmetic.expr_add_mul` computes A + (B × C) using shift-add multiplication then addition. 64 AND gates + 64 full adders, 73 test cases verified.
691
 
692
+ #### Planned
693
 
694
+ 1. **Cascaded 32-bit comparators** Replace single-layer weighted comparison with multi-layer byte-wise cascade. Each byte comparison uses 8-bit weights (max 128), well within float32 precision. Hardware-accurate and extensible to 64-bit, 128-bit, etc.
 
 
 
 
695
 
696
+ 2. **Parenthetical expressions ((5 + 3) × 2 = 16)** — Explicit grouping overrides precedence. Parser must recognize parens and build correct tree. Evaluation proceeds innermost-out.
697
 
698
+ 3. **Multi-operation chains (a + b - c × d)** — Sequential dispatch through multiple circuits with intermediate result routing. Requires state management in interface layers.
699
 
700
+ 4. **Floating point arithmetic** — IEEE 754-style with separate circuits for mantissa and exponent. ADD: align exponents, add mantissas, renormalize. MUL: add exponents, multiply mantissas.
 
 
 
 
 
701
 
702
+ 5. **Full CPU integration** — Enable memory access circuits for stateful computation. Allows multi-step algorithms executed entirely within threshold logic.
703
 
704
+ ---
 
 
 
 
 
 
 
 
705
 
706
+ ## Build Tool
707
 
708
  ```bash
709
+ # 8-bit CPU (default)
710
+ python build.py --apply all # Full 64KB memory
711
+ python build.py -m none --apply all # Pure ALU (32K params)
712
+ python build.py -m scratchpad --apply all # 256-byte scratch
713
 
714
+ # 32-bit ALU
715
+ python build.py --bits 32 -m small --apply all # 1KB scratch (~203K params)
716
+ python build.py --bits 32 -m none --apply all # Pure 32-bit ALU
 
717
 
718
+ # Custom configurations
719
+ python build.py --bits 16 --addr-bits 6 --apply all # 16-bit ALU, 64 bytes memory
 
 
 
 
 
720
  ```
721
 
722
+ **Bit widths** (`--bits`):
723
+
724
+ | Width | Range | Use Case |
725
+ |-------|-------|----------|
726
+ | 8 | 0–255 | Full CPU, legacy |
727
+ | 16 | 0–65,535 | Extended arithmetic |
728
+ | 32 | 0–4,294,967,295 | Practical arithmetic |
729
+
730
+ **Memory profiles** (`-m`):
731
 
732
+ | Profile | Size | Params | Use Case |
733
+ |---------|------|--------|----------|
734
+ | `none` | 0B | ~32K | Pure ALU |
735
+ | `registers` | 16B | ~34K | Minimal state |
736
+ | `scratchpad` | 256B | ~63K | 8-bit scratch |
737
+ | `small` | 1KB | ~123K | 32-bit scratch |
738
+ | `reduced` | 4KB | ~549K | Small programs |
739
+ | `full` | 64KB | ~8.29M | Full CPU |
740
 
741
  ---
742
 
build.py CHANGED
@@ -121,11 +121,14 @@ DEFAULT_MEM_BYTES = 1 << DEFAULT_ADDR_BITS
121
  MEMORY_PROFILES = {
122
  "full": 16, # 64KB - full CPU mode
123
  "reduced": 12, # 4KB - reduced CPU
 
124
  "scratchpad": 8, # 256 bytes - LLM scratchpad
125
  "registers": 4, # 16 bytes - LLM register file
126
  "none": 0, # Pure ALU, no memory
127
  }
128
 
 
 
129
 
130
  def load_tensors(path: Path) -> Dict[str, torch.Tensor]:
131
  tensors: Dict[str, torch.Tensor] = {}
@@ -674,6 +677,163 @@ def add_comparators(tensors: Dict[str, torch.Tensor]) -> None:
674
  add_gate(tensors, "arithmetic.equality8bit.layer2", [1.0, 1.0], [-2.0])
675
 
676
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
677
  def update_manifest(tensors: Dict[str, torch.Tensor], addr_bits: int, mem_bytes: int) -> None:
678
  tensors["manifest.memory_bytes"] = torch.tensor([float(mem_bytes)], dtype=torch.float32)
679
  tensors["manifest.pc_width"] = torch.tensor([float(addr_bits)], dtype=torch.float32)
@@ -1863,14 +2023,15 @@ def cmd_inputs(args) -> None:
1863
 
1864
 
1865
  def cmd_alu(args) -> None:
 
1866
  print("=" * 60)
1867
- print(" BUILD ALU CIRCUITS")
1868
  print("=" * 60)
1869
  print(f"\nLoading: {args.model}")
1870
  tensors = load_tensors(args.model)
1871
  print(f" Loaded {len(tensors)} tensors")
1872
- print("\nDropping existing ALU extension tensors...")
1873
- drop_prefixes(tensors, [
1874
  "alu.alu8bit.shl.", "alu.alu8bit.shr.",
1875
  "alu.alu8bit.mul.", "alu.alu8bit.div.",
1876
  "alu.alu8bit.inc.", "alu.alu8bit.dec.",
@@ -1880,7 +2041,18 @@ def cmd_alu(args) -> None:
1880
  "arithmetic.equality8bit.", "arithmetic.add3_8bit.", "arithmetic.expr_add_mul.", "arithmetic.expr_paren.",
1881
  "control.push.", "control.pop.", "control.ret.",
1882
  "combinational.barrelshifter.", "combinational.priorityencoder.",
1883
- ])
 
 
 
 
 
 
 
 
 
 
 
1884
  print(f" Now {len(tensors)} tensors")
1885
  print("\nGenerating SHL/SHR circuits...")
1886
  try:
@@ -1960,13 +2132,85 @@ def cmd_alu(args) -> None:
1960
  print(" Added EXPR_PAREN (8 + 64 AND + 56 full adders = 640 gates)")
1961
  except ValueError as e:
1962
  print(f" EXPR_PAREN already exists: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1963
  if args.apply:
1964
  print(f"\nSaving: {args.model}")
1965
  save_file(tensors, str(args.model))
1966
  print(" Done.")
1967
  else:
1968
  print("\n[DRY-RUN] Use --apply to save.")
 
1969
  print(f"\nTotal: {len(tensors)} tensors")
 
 
1970
  print("=" * 60)
1971
 
1972
 
@@ -1987,26 +2231,39 @@ def main() -> None:
1987
  Memory Profiles:
1988
  full 64KB (16-bit addr) - Full CPU mode
1989
  reduced 4KB (12-bit addr) - Reduced CPU
 
1990
  scratchpad 256B (8-bit addr) - LLM scratchpad
1991
  registers 16B (4-bit addr) - LLM register file
1992
  none 0B (no memory) - Pure ALU for LLM
1993
 
 
 
 
 
 
1994
  Examples:
1995
  python build.py memory --memory-profile none --apply # LLM-only (no RAM)
1996
- python build.py memory --memory-profile scratchpad # 256-byte scratchpad
1997
- python build.py memory --addr-bits 6 # Custom: 64 bytes
1998
- python build.py memory # Default: 64KB
1999
  """
2000
  )
2001
  parser.add_argument("--model", type=Path, default=MODEL_PATH, help="Model path")
2002
  parser.add_argument("--apply", action="store_true", help="Apply changes (default: dry-run)")
2003
  parser.add_argument("--manifest", action="store_true", help="Write tensors.txt manifest (memory only)")
 
 
 
 
 
 
 
2004
 
2005
  mem_group = parser.add_mutually_exclusive_group()
2006
  mem_group.add_argument(
2007
  "--memory-profile", "-m",
2008
  choices=list(MEMORY_PROFILES.keys()),
2009
- help="Memory size profile (full/reduced/scratchpad/registers/none)"
2010
  )
2011
  mem_group.add_argument(
2012
  "--addr-bits", "-a",
@@ -2018,7 +2275,7 @@ Examples:
2018
 
2019
  subparsers = parser.add_subparsers(dest="command", help="Subcommands")
2020
  subparsers.add_parser("memory", help="Generate memory circuits (size controlled by --memory-profile or --addr-bits)")
2021
- subparsers.add_parser("alu", help="Generate ALU extension circuits (SHL, SHR, comparators)")
2022
  subparsers.add_parser("inputs", help="Add .inputs metadata tensors")
2023
  subparsers.add_parser("all", help="Run memory, alu, then inputs")
2024
 
 
121
  MEMORY_PROFILES = {
122
  "full": 16, # 64KB - full CPU mode
123
  "reduced": 12, # 4KB - reduced CPU
124
+ "small": 10, # 1KB - 32-bit arithmetic scratch
125
  "scratchpad": 8, # 256 bytes - LLM scratchpad
126
  "registers": 4, # 16 bytes - LLM register file
127
  "none": 0, # Pure ALU, no memory
128
  }
129
 
130
+ SUPPORTED_BITS = [8, 16, 32]
131
+
132
 
133
  def load_tensors(path: Path) -> Dict[str, torch.Tensor]:
134
  tensors: Dict[str, torch.Tensor] = {}
 
677
  add_gate(tensors, "arithmetic.equality8bit.layer2", [1.0, 1.0], [-2.0])
678
 
679
 
680
+ def add_ripple_carry_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
681
+ """Add N-bit ripple carry adder circuit.
682
+
683
+ Creates a chain of full adders for N-bit addition.
684
+ Works for 8, 16, or 32 bits.
685
+
686
+ Inputs: $a[0..N-1], $b[0..N-1] (MSB-first)
687
+ Outputs: fa0-fa{N-1} sum bits, fa{N-1}.carry_or for overflow
688
+ """
689
+ prefix = f"arithmetic.ripplecarry{bits}bit"
690
+ for bit in range(bits):
691
+ add_full_adder(tensors, f"{prefix}.fa{bit}")
692
+
693
+
694
+ def add_sub_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
695
+ """Add N-bit subtractor circuit (A - B).
696
+
697
+ Uses two's complement: A - B = A + (~B) + 1
698
+
699
+ Structure:
700
+ - NOT gates for each bit of B
701
+ - N-bit ripple carry adder with carry_in = 1
702
+
703
+ The carry_in=1 is handled by the adder's fa0 having cin=#1 instead of #0.
704
+ """
705
+ prefix = f"arithmetic.sub{bits}bit"
706
+
707
+ for bit in range(bits):
708
+ add_gate(tensors, f"{prefix}.not_b.bit{bit}", [-1.0], [0.0])
709
+
710
+ for bit in range(bits):
711
+ add_full_adder(tensors, f"{prefix}.fa{bit}")
712
+
713
+
714
+ def add_comparators_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
715
+ """Add N-bit comparator circuits (GT, LT, GE, LE, EQ).
716
+
717
+ Uses weighted sum comparison extended to N bits.
718
+ For N=32: weights are 2^31, 2^30, ..., 2^0 for A, negated for B.
719
+ """
720
+ pos_weights = [float(1 << (bits - 1 - i)) for i in range(bits)]
721
+ neg_weights = [-w for w in pos_weights]
722
+
723
+ gt_weights = pos_weights + neg_weights
724
+ lt_weights = neg_weights + pos_weights
725
+
726
+ add_gate(tensors, f"arithmetic.greaterthan{bits}bit", gt_weights, [-1.0])
727
+ add_gate(tensors, f"arithmetic.greaterorequal{bits}bit", gt_weights, [0.0])
728
+ add_gate(tensors, f"arithmetic.lessthan{bits}bit", lt_weights, [-1.0])
729
+ add_gate(tensors, f"arithmetic.lessorequal{bits}bit", lt_weights, [0.0])
730
+
731
+ add_gate(tensors, f"arithmetic.equality{bits}bit.layer1.geq", gt_weights, [0.0])
732
+ add_gate(tensors, f"arithmetic.equality{bits}bit.layer1.leq", lt_weights, [0.0])
733
+ add_gate(tensors, f"arithmetic.equality{bits}bit.layer2", [1.0, 1.0], [-2.0])
734
+
735
+
736
+ def add_mul_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
737
+ """Add N-bit multiplication circuit.
738
+
739
+ Produces low N bits of the 2N-bit result.
740
+
741
+ Structure:
742
+ - N*N AND gates for partial products P[i][j] = A[i] AND B[j]
743
+ - Shift-add accumulation using existing adder circuits
744
+
745
+ For 32-bit: 1024 AND gates for partial products.
746
+ """
747
+ for i in range(bits):
748
+ for j in range(bits):
749
+ add_gate(tensors, f"alu.alu{bits}bit.mul.pp.a{i}b{j}", [1.0, 1.0], [-2.0])
750
+
751
+
752
+ def add_div_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
753
+ """Add N-bit division circuit.
754
+
755
+ Uses restoring division algorithm with N iterations.
756
+ """
757
+ pos_weights = [float(1 << (bits - 1 - i)) for i in range(bits)]
758
+ neg_weights = [-w for w in pos_weights]
759
+ cmp_weights = pos_weights + neg_weights
760
+
761
+ for stage in range(bits):
762
+ add_gate(tensors, f"alu.alu{bits}bit.div.stage{stage}.cmp", cmp_weights, [0.0])
763
+
764
+ for stage in range(bits):
765
+ for bit in range(bits):
766
+ add_gate(tensors, f"alu.alu{bits}bit.div.stage{stage}.mux.bit{bit}.not_sel", [-1.0], [0.0])
767
+ add_gate(tensors, f"alu.alu{bits}bit.div.stage{stage}.mux.bit{bit}.and_a", [1.0, 1.0], [-2.0])
768
+ add_gate(tensors, f"alu.alu{bits}bit.div.stage{stage}.mux.bit{bit}.and_b", [1.0, 1.0], [-2.0])
769
+ add_gate(tensors, f"alu.alu{bits}bit.div.stage{stage}.mux.bit{bit}.or", [1.0, 1.0], [-1.0])
770
+
771
+
772
+ def add_bitwise_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
773
+ """Add N-bit bitwise operation circuits (AND, OR, XOR, NOT).
774
+
775
+ These are simply N copies of the 1-bit gates.
776
+ """
777
+ for bit in range(bits):
778
+ add_gate(tensors, f"alu.alu{bits}bit.and.bit{bit}", [1.0, 1.0], [-2.0])
779
+
780
+ for bit in range(bits):
781
+ add_gate(tensors, f"alu.alu{bits}bit.or.bit{bit}", [1.0, 1.0], [-1.0])
782
+
783
+ for bit in range(bits):
784
+ add_gate(tensors, f"alu.alu{bits}bit.xor.bit{bit}.layer1.or", [1.0, 1.0], [-1.0])
785
+ add_gate(tensors, f"alu.alu{bits}bit.xor.bit{bit}.layer1.nand", [-1.0, -1.0], [1.0])
786
+ add_gate(tensors, f"alu.alu{bits}bit.xor.bit{bit}.layer2", [1.0, 1.0], [-2.0])
787
+
788
+ for bit in range(bits):
789
+ add_gate(tensors, f"alu.alu{bits}bit.not.bit{bit}", [-1.0], [0.0])
790
+
791
+
792
+ def add_shift_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
793
+ """Add N-bit shift circuits (SHL, SHR by 1 position).
794
+
795
+ SHL: out[i] = in[i+1] for i<N-1, out[N-1] = 0
796
+ SHR: out[0] = 0, out[i] = in[i-1] for i>0
797
+ """
798
+ for bit in range(bits):
799
+ if bit < bits - 1:
800
+ add_gate(tensors, f"alu.alu{bits}bit.shl.bit{bit}", [2.0], [-1.0])
801
+ else:
802
+ add_gate(tensors, f"alu.alu{bits}bit.shl.bit{bit}", [0.0], [-1.0])
803
+
804
+ for bit in range(bits):
805
+ if bit > 0:
806
+ add_gate(tensors, f"alu.alu{bits}bit.shr.bit{bit}", [2.0], [-1.0])
807
+ else:
808
+ add_gate(tensors, f"alu.alu{bits}bit.shr.bit{bit}", [0.0], [-1.0])
809
+
810
+
811
+ def add_inc_dec_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
812
+ """Add N-bit INC and DEC circuits."""
813
+ for bit in range(bits):
814
+ add_gate(tensors, f"alu.alu{bits}bit.inc.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0])
815
+ add_gate(tensors, f"alu.alu{bits}bit.inc.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0])
816
+ add_gate(tensors, f"alu.alu{bits}bit.inc.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0])
817
+ add_gate(tensors, f"alu.alu{bits}bit.inc.bit{bit}.carry", [1.0, 1.0], [-2.0])
818
+
819
+ for bit in range(bits):
820
+ add_gate(tensors, f"alu.alu{bits}bit.dec.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0])
821
+ add_gate(tensors, f"alu.alu{bits}bit.dec.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0])
822
+ add_gate(tensors, f"alu.alu{bits}bit.dec.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0])
823
+ add_gate(tensors, f"alu.alu{bits}bit.dec.bit{bit}.not_a", [-1.0], [0.0])
824
+ add_gate(tensors, f"alu.alu{bits}bit.dec.bit{bit}.borrow", [1.0, 1.0], [-2.0])
825
+
826
+
827
+ def add_neg_nbits(tensors: Dict[str, torch.Tensor], bits: int) -> None:
828
+ """Add N-bit NEG circuit (two's complement negation)."""
829
+ for bit in range(bits):
830
+ add_gate(tensors, f"alu.alu{bits}bit.neg.not.bit{bit}", [-1.0], [0.0])
831
+ add_gate(tensors, f"alu.alu{bits}bit.neg.inc.bit{bit}.xor.layer1.or", [1.0, 1.0], [-1.0])
832
+ add_gate(tensors, f"alu.alu{bits}bit.neg.inc.bit{bit}.xor.layer1.nand", [-1.0, -1.0], [1.0])
833
+ add_gate(tensors, f"alu.alu{bits}bit.neg.inc.bit{bit}.xor.layer2", [1.0, 1.0], [-2.0])
834
+ add_gate(tensors, f"alu.alu{bits}bit.neg.inc.bit{bit}.carry", [1.0, 1.0], [-2.0])
835
+
836
+
837
  def update_manifest(tensors: Dict[str, torch.Tensor], addr_bits: int, mem_bytes: int) -> None:
838
  tensors["manifest.memory_bytes"] = torch.tensor([float(mem_bytes)], dtype=torch.float32)
839
  tensors["manifest.pc_width"] = torch.tensor([float(addr_bits)], dtype=torch.float32)
 
2023
 
2024
 
2025
  def cmd_alu(args) -> None:
2026
+ bits = getattr(args, 'bits', 8) or 8
2027
  print("=" * 60)
2028
+ print(f" BUILD ALU CIRCUITS ({bits}-bit)")
2029
  print("=" * 60)
2030
  print(f"\nLoading: {args.model}")
2031
  tensors = load_tensors(args.model)
2032
  print(f" Loaded {len(tensors)} tensors")
2033
+
2034
+ drop_list = [
2035
  "alu.alu8bit.shl.", "alu.alu8bit.shr.",
2036
  "alu.alu8bit.mul.", "alu.alu8bit.div.",
2037
  "alu.alu8bit.inc.", "alu.alu8bit.dec.",
 
2041
  "arithmetic.equality8bit.", "arithmetic.add3_8bit.", "arithmetic.expr_add_mul.", "arithmetic.expr_paren.",
2042
  "control.push.", "control.pop.", "control.ret.",
2043
  "combinational.barrelshifter.", "combinational.priorityencoder.",
2044
+ ]
2045
+
2046
+ if bits in [16, 32]:
2047
+ drop_list.extend([
2048
+ f"alu.alu{bits}bit.", f"arithmetic.ripplecarry{bits}bit.",
2049
+ f"arithmetic.sub{bits}bit.", f"arithmetic.greaterthan{bits}bit.",
2050
+ f"arithmetic.lessthan{bits}bit.", f"arithmetic.greaterorequal{bits}bit.",
2051
+ f"arithmetic.lessorequal{bits}bit.", f"arithmetic.equality{bits}bit.",
2052
+ ])
2053
+
2054
+ print("\nDropping existing ALU extension tensors...")
2055
+ drop_prefixes(tensors, drop_list)
2056
  print(f" Now {len(tensors)} tensors")
2057
  print("\nGenerating SHL/SHR circuits...")
2058
  try:
 
2132
  print(" Added EXPR_PAREN (8 + 64 AND + 56 full adders = 640 gates)")
2133
  except ValueError as e:
2134
  print(f" EXPR_PAREN already exists: {e}")
2135
+
2136
+ if bits in [16, 32]:
2137
+ print(f"\n{'=' * 60}")
2138
+ print(f" GENERATING {bits}-BIT CIRCUITS")
2139
+ print(f"{'=' * 60}")
2140
+
2141
+ print(f"\nGenerating {bits}-bit ripple carry adder...")
2142
+ try:
2143
+ add_ripple_carry_nbits(tensors, bits)
2144
+ print(f" Added {bits}-bit adder ({bits} full adders = {bits * 9} gates)")
2145
+ except ValueError as e:
2146
+ print(f" {bits}-bit adder already exists: {e}")
2147
+
2148
+ print(f"\nGenerating {bits}-bit subtractor...")
2149
+ try:
2150
+ add_sub_nbits(tensors, bits)
2151
+ print(f" Added {bits}-bit subtractor ({bits} NOT + {bits} full adders)")
2152
+ except ValueError as e:
2153
+ print(f" {bits}-bit subtractor already exists: {e}")
2154
+
2155
+ print(f"\nGenerating {bits}-bit comparators...")
2156
+ try:
2157
+ add_comparators_nbits(tensors, bits)
2158
+ print(f" Added {bits}-bit GT, GE, LT, LE, EQ")
2159
+ except ValueError as e:
2160
+ print(f" {bits}-bit comparators already exist: {e}")
2161
+
2162
+ print(f"\nGenerating {bits}-bit multiplication...")
2163
+ try:
2164
+ add_mul_nbits(tensors, bits)
2165
+ print(f" Added {bits}-bit MUL ({bits * bits} partial product AND gates)")
2166
+ except ValueError as e:
2167
+ print(f" {bits}-bit MUL already exists: {e}")
2168
+
2169
+ print(f"\nGenerating {bits}-bit division...")
2170
+ try:
2171
+ add_div_nbits(tensors, bits)
2172
+ print(f" Added {bits}-bit DIV ({bits} stages)")
2173
+ except ValueError as e:
2174
+ print(f" {bits}-bit DIV already exists: {e}")
2175
+
2176
+ print(f"\nGenerating {bits}-bit bitwise ops (AND, OR, XOR, NOT)...")
2177
+ try:
2178
+ add_bitwise_nbits(tensors, bits)
2179
+ print(f" Added {bits}-bit AND, OR, XOR, NOT")
2180
+ except ValueError as e:
2181
+ print(f" {bits}-bit bitwise ops already exist: {e}")
2182
+
2183
+ print(f"\nGenerating {bits}-bit shift ops (SHL, SHR)...")
2184
+ try:
2185
+ add_shift_nbits(tensors, bits)
2186
+ print(f" Added {bits}-bit SHL, SHR")
2187
+ except ValueError as e:
2188
+ print(f" {bits}-bit shift ops already exist: {e}")
2189
+
2190
+ print(f"\nGenerating {bits}-bit INC/DEC...")
2191
+ try:
2192
+ add_inc_dec_nbits(tensors, bits)
2193
+ print(f" Added {bits}-bit INC, DEC")
2194
+ except ValueError as e:
2195
+ print(f" {bits}-bit INC/DEC already exist: {e}")
2196
+
2197
+ print(f"\nGenerating {bits}-bit NEG...")
2198
+ try:
2199
+ add_neg_nbits(tensors, bits)
2200
+ print(f" Added {bits}-bit NEG")
2201
+ except ValueError as e:
2202
+ print(f" {bits}-bit NEG already exists: {e}")
2203
+
2204
  if args.apply:
2205
  print(f"\nSaving: {args.model}")
2206
  save_file(tensors, str(args.model))
2207
  print(" Done.")
2208
  else:
2209
  print("\n[DRY-RUN] Use --apply to save.")
2210
+
2211
  print(f"\nTotal: {len(tensors)} tensors")
2212
+ total_params = sum(t.numel() for t in tensors.values())
2213
+ print(f"Total params: {total_params:,}")
2214
  print("=" * 60)
2215
 
2216
 
 
2231
  Memory Profiles:
2232
  full 64KB (16-bit addr) - Full CPU mode
2233
  reduced 4KB (12-bit addr) - Reduced CPU
2234
+ small 1KB (10-bit addr) - 32-bit arithmetic scratch
2235
  scratchpad 256B (8-bit addr) - LLM scratchpad
2236
  registers 16B (4-bit addr) - LLM register file
2237
  none 0B (no memory) - Pure ALU for LLM
2238
 
2239
+ ALU Bit Widths:
2240
+ 8 Standard 8-bit ALU (default)
2241
+ 16 16-bit ALU (0-65535)
2242
+ 32 32-bit ALU (0-4294967295)
2243
+
2244
  Examples:
2245
  python build.py memory --memory-profile none --apply # LLM-only (no RAM)
2246
+ python build.py memory --memory-profile small --apply # 1KB for 32-bit scratch
2247
+ python build.py alu --bits 32 --apply # 32-bit ALU circuits
2248
+ python build.py all --bits 32 -m small --apply # Full 32-bit build
2249
  """
2250
  )
2251
  parser.add_argument("--model", type=Path, default=MODEL_PATH, help="Model path")
2252
  parser.add_argument("--apply", action="store_true", help="Apply changes (default: dry-run)")
2253
  parser.add_argument("--manifest", action="store_true", help="Write tensors.txt manifest (memory only)")
2254
+ parser.add_argument(
2255
+ "--bits", "-b",
2256
+ type=int,
2257
+ choices=SUPPORTED_BITS,
2258
+ default=8,
2259
+ help="ALU bit width: 8 (default), 16, or 32"
2260
+ )
2261
 
2262
  mem_group = parser.add_mutually_exclusive_group()
2263
  mem_group.add_argument(
2264
  "--memory-profile", "-m",
2265
  choices=list(MEMORY_PROFILES.keys()),
2266
+ help="Memory size profile (full/reduced/small/scratchpad/registers/none)"
2267
  )
2268
  mem_group.add_argument(
2269
  "--addr-bits", "-a",
 
2275
 
2276
  subparsers = parser.add_subparsers(dest="command", help="Subcommands")
2277
  subparsers.add_parser("memory", help="Generate memory circuits (size controlled by --memory-profile or --addr-bits)")
2278
+ subparsers.add_parser("alu", help="Generate ALU extension circuits (use --bits for 16/32-bit)")
2279
  subparsers.add_parser("inputs", help="Add .inputs metadata tensors")
2280
  subparsers.add_parser("all", help="Run memory, alu, then inputs")
2281
 
eval.py CHANGED
@@ -968,6 +968,27 @@ class BatchedFitnessEvaluator:
968
  # Modular test range
969
  self.mod_test = torch.arange(256, device=d, dtype=torch.long)
970
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
971
  def _record(self, name: str, passed: int, total: int, failures: List[Tuple] = None):
972
  """Record a circuit test result."""
973
  self.results.append(CircuitResult(
@@ -1705,6 +1726,107 @@ class BatchedFitnessEvaluator:
1705
 
1706
  return scores, total
1707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1708
  # =========================================================================
1709
  # THRESHOLD GATES
1710
  # =========================================================================
@@ -3399,6 +3521,24 @@ class BatchedFitnessEvaluator:
3399
  total_tests += t
3400
  self.category_scores[f'ripplecarry{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
3401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3402
  # 3-operand adder
3403
  s, t = self._test_add3(population, debug)
3404
  scores += s
 
968
  # Modular test range
969
  self.mod_test = torch.arange(256, device=d, dtype=torch.long)
970
 
971
+ # 32-bit test values (strategic sampling)
972
+ self.test_32bit = torch.tensor([
973
+ 0, 1, 2, 255, 256, 65535, 65536,
974
+ 0x7FFFFFFF, 0x80000000, 0xFFFFFFFF,
975
+ 0x12345678, 0xDEADBEEF, 0xCAFEBABE,
976
+ 1000000, 1000000000, 2147483647,
977
+ 0x55555555, 0xAAAAAAAA, 0x0F0F0F0F, 0xF0F0F0F0
978
+ ], device=d, dtype=torch.long)
979
+
980
+ # 32-bit comparator test pairs
981
+ comp32_tests = [
982
+ (0, 0), (1, 0), (0, 1), (1000, 999), (999, 1000),
983
+ (0xFFFFFFFF, 0), (0, 0xFFFFFFFF),
984
+ (0x80000000, 0x7FFFFFFF), (0x7FFFFFFF, 0x80000000),
985
+ (1000000, 1000000), (0x12345678, 0x12345678),
986
+ (0xDEADBEEF, 0xCAFEBABE), (0xCAFEBABE, 0xDEADBEEF),
987
+ (256, 255), (255, 256), (65536, 65535), (65535, 65536),
988
+ ]
989
+ self.comp32_a = torch.tensor([c[0] for c in comp32_tests], device=d, dtype=torch.long)
990
+ self.comp32_b = torch.tensor([c[1] for c in comp32_tests], device=d, dtype=torch.long)
991
+
992
  def _record(self, name: str, passed: int, total: int, failures: List[Tuple] = None):
993
  """Record a circuit test result."""
994
  self.results.append(CircuitResult(
 
1726
 
1727
  return scores, total
1728
 
1729
+ def _test_comparators_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
1730
+ """Test N-bit comparator circuits (GT, LT, GE, LE, EQ)."""
1731
+ pop_size = next(iter(pop.values())).shape[0]
1732
+ scores = torch.zeros(pop_size, device=self.device)
1733
+ total = 0
1734
+
1735
+ if debug:
1736
+ print(f"\n=== {bits}-BIT COMPARATORS ===")
1737
+
1738
+ if bits == 32:
1739
+ comp_a = self.comp32_a
1740
+ comp_b = self.comp32_b
1741
+ elif bits == 16:
1742
+ comp_a = self.comp_a.clamp(0, 65535)
1743
+ comp_b = self.comp_b.clamp(0, 65535)
1744
+ else:
1745
+ comp_a = self.comp_a
1746
+ comp_b = self.comp_b
1747
+
1748
+ a_bits = torch.stack([((comp_a >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
1749
+ b_bits = torch.stack([((comp_b >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
1750
+ inputs = torch.cat([a_bits, b_bits], dim=1)
1751
+
1752
+ comparators = [
1753
+ (f'arithmetic.greaterthan{bits}bit', lambda a, b: a > b),
1754
+ (f'arithmetic.greaterorequal{bits}bit', lambda a, b: a >= b),
1755
+ (f'arithmetic.lessthan{bits}bit', lambda a, b: a < b),
1756
+ (f'arithmetic.lessorequal{bits}bit', lambda a, b: a <= b),
1757
+ ]
1758
+
1759
+ for name, op in comparators:
1760
+ try:
1761
+ expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0
1762
+ for a, b in zip(comp_a, comp_b)], device=self.device)
1763
+
1764
+ w = pop[f'{name}.weight']
1765
+ b = pop[f'{name}.bias']
1766
+ out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size))
1767
+
1768
+ correct = (out == expected.unsqueeze(1)).float().sum(0)
1769
+
1770
+ failures = []
1771
+ if pop_size == 1:
1772
+ for i in range(len(comp_a)):
1773
+ if out[i, 0].item() != expected[i].item():
1774
+ failures.append((
1775
+ [int(comp_a[i].item()), int(comp_b[i].item())],
1776
+ expected[i].item(),
1777
+ out[i, 0].item()
1778
+ ))
1779
+
1780
+ self._record(name, int(correct[0].item()), len(comp_a), failures)
1781
+ if debug:
1782
+ r = self.results[-1]
1783
+ print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
1784
+ scores += correct
1785
+ total += len(comp_a)
1786
+ except KeyError:
1787
+ pass
1788
+
1789
+ prefix = f'arithmetic.equality{bits}bit'
1790
+ try:
1791
+ expected = torch.tensor([1.0 if a.item() == b.item() else 0.0
1792
+ for a, b in zip(comp_a, comp_b)], device=self.device)
1793
+
1794
+ w_geq = pop[f'{prefix}.layer1.geq.weight']
1795
+ b_geq = pop[f'{prefix}.layer1.geq.bias']
1796
+ w_leq = pop[f'{prefix}.layer1.leq.weight']
1797
+ b_leq = pop[f'{prefix}.layer1.leq.bias']
1798
+
1799
+ h_geq = heaviside(inputs @ w_geq.view(pop_size, -1).T + b_geq.view(pop_size))
1800
+ h_leq = heaviside(inputs @ w_leq.view(pop_size, -1).T + b_leq.view(pop_size))
1801
+ hidden = torch.stack([h_geq, h_leq], dim=-1)
1802
+
1803
+ w2 = pop[f'{prefix}.layer2.weight']
1804
+ b2 = pop[f'{prefix}.layer2.bias']
1805
+ out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size))
1806
+
1807
+ correct = (out == expected.unsqueeze(1)).float().sum(0)
1808
+
1809
+ failures = []
1810
+ if pop_size == 1:
1811
+ for i in range(len(comp_a)):
1812
+ if out[i, 0].item() != expected[i].item():
1813
+ failures.append((
1814
+ [int(comp_a[i].item()), int(comp_b[i].item())],
1815
+ expected[i].item(),
1816
+ out[i, 0].item()
1817
+ ))
1818
+
1819
+ self._record(prefix, int(correct[0].item()), len(comp_a), failures)
1820
+ if debug:
1821
+ r = self.results[-1]
1822
+ print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
1823
+ scores += correct
1824
+ total += len(comp_a)
1825
+ except KeyError:
1826
+ pass
1827
+
1828
+ return scores, total
1829
+
1830
  # =========================================================================
1831
  # THRESHOLD GATES
1832
  # =========================================================================
 
3521
  total_tests += t
3522
  self.category_scores[f'ripplecarry{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
3523
 
3524
+ # 16/32-bit circuits (if present)
3525
+ for bits in [16, 32]:
3526
+ if f'arithmetic.ripplecarry{bits}bit.fa0.ha1.sum.layer1.or.weight' in population:
3527
+ if debug:
3528
+ print(f"\n{'=' * 60}")
3529
+ print(f" {bits}-BIT CIRCUITS")
3530
+ print(f"{'=' * 60}")
3531
+
3532
+ s, t = self._test_ripplecarry(population, bits, debug)
3533
+ scores += s
3534
+ total_tests += t
3535
+ self.category_scores[f'ripplecarry{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
3536
+
3537
+ s, t = self._test_comparators_nbits(population, bits, debug)
3538
+ scores += s
3539
+ total_tests += t
3540
+ self.category_scores[f'comparators{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t)
3541
+
3542
  # 3-operand adder
3543
  s, t = self._test_add3(population, debug)
3544
  scores += s
neural_alu32.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:788a277fbff9e44eb9006f5f76839ced42d90c1ff31513b36b34c9ee604e3d97
3
+ size 4972488