CharlesCNorton committed
Commit
56aed4a
·
1 Parent(s): 1b10db6

Integrate LLM guide into README, remove llm folder


- Add LLM Integration section with full architecture docs
- Remove formal verification/Coq references
- Delete llm/ folder (content now in README)

Files changed (2)
  1. README.md +243 -67
  2. llm/core.py +0 -766
README.md CHANGED
@@ -17,8 +17,8 @@ tags:
17
  Every logic gate is a threshold neuron: `output = 1 if (Σ wᵢxᵢ + b) ≥ 0 else 0`
18
 
19
  ```
20
- Tensors: 6,296
21
- Parameters: 8,267,667
22
  ```
23
 
24
  ---
@@ -30,7 +30,7 @@ A complete 8-bit processor where every operation—from Boolean logic to arithme
30
  | Component | Specification |
31
  |-----------|---------------|
32
  | Registers | 4 × 8-bit general purpose |
33
- | Memory | 64KB addressable |
34
  | ALU | 16 operations (ADD, SUB, AND, OR, XOR, NOT, SHL, SHR, INC, DEC, CMP, NEG, PASS, ZERO, ONES, NOP) |
35
  | Flags | Zero, Negative, Carry, Overflow |
36
  | Control | JMP, JZ, JNZ, JC, JNC, JN, JP, JV, JNV, CALL, RET, PUSH, POP |
@@ -85,16 +85,16 @@ The weights in this repository implement a complete 8-bit computer: registers, A
85
  | Arithmetic | 18 | Half/full adder, 2/4/8-bit ripple carry, comparators |
86
  | ALU | 3 | 8-bit ALU, control decoder, flag computation |
87
  | Combinational | 10 | MUX (2:1, 4:1, 8:1), DEMUX, encoders, decoders |
88
- | Control Flow | 16 | JMP, conditional jumps, CALL, RET, PUSH, POP |
89
- | Error Detection | 11 | Parity (XOR tree), checksum, CRC, Hamming |
90
- | Modular | 11 | Divisibility by 2-12 (multi-layer for non-powers-of-2) |
91
- | Threshold | 13 | k-of-n gates, majority, minority, exactly-k |
92
- | Pattern | 10 | Popcount, leading/trailing ones, symmetry |
93
- | Memory | 3 | 16-bit addr decoder, 65536x8 read mux, write cell update (packed) |
94
 
95
  ---
96
 
97
- ## Usage
98
 
99
  ```python
100
  import torch
@@ -113,43 +113,43 @@ for a, b_in in [(0,0), (0,1), (1,0), (1,1)]:
113
  inp = torch.tensor([a, b_in], dtype=torch.float32)
114
  out = heaviside(inp @ w + b)
115
  print(f"AND({a}, {b_in}) = {int(out.item())}")
116
- ```
117
-
118
- ---
119
-
120
- ## State Tensor Layout
121
-
122
- All multi-bit fields are **MSB-first** (index 0 is the most-significant bit).
123
-
124
- ```
125
- [ PC[16] | IR[16] | R0[8] R1[8] R2[8] R3[8] | FLAGS[4] | SP[16] | CTRL[4] | MEM[65536][8] ]
126
- ```
127
-
128
- Flags are ordered as: `Z, N, C, V`.
129
-
130
- Control bits are ordered as: `HALT, MEM_WE, MEM_RE, RESERVED`.
131
-
132
- Total state size: `524376` bits.
133
-
134
- ---
135
-
136
- ## Instruction Encoding (16-bit)
137
-
138
- All instruction bits are **MSB-first**.
139
-
140
- ```
141
- 15..12 11..10 9..8 7..0
142
- opcode rd rs imm8
143
- ```
144
-
145
- Interpretation:
146
- - **R-type**: `rd = rd op rs` (imm8 ignored).
147
- - **I-type**: `rd = op rd, imm8` (rs ignored).
148
- - **Address-extended**: `LOAD`, `STORE`, `JMP`, `JZ`, `CALL` consume the next word as a 16-bit address (big-endian). `imm8` is reserved, and the PC skips 4 bytes when the jump is not taken.
149
-
150
- ---
151
-
152
- ## Verification
153
 
154
  The model includes `iron_eval.py` which exhaustively tests all circuits:
155
 
@@ -162,11 +162,11 @@ python iron_eval.py
162
 
163
  | Category | Status | Notes |
164
  |----------|--------|-------|
165
- | Boolean gates | Exhaustively tested | Coq proofs available |
166
- | Arithmetic | Exhaustively tested | Coq proofs available |
167
- | ALU | Exhaustively tested | Coq proofs available |
168
- | Control flow | Exhaustively tested | Coq proofs available |
169
- | Threshold | Exhaustively tested | Coq proofs available |
170
  | Modular (mod 3,5,6,7,9,10,11,12) | Exhaustively tested | Multi-layer, hand-constructed |
171
  | Parity | Exhaustively tested | XOR tree, hand-constructed |
172
  | Modular (mod 2,4,8) | Exhaustively tested | Single-layer, trivial |
@@ -184,15 +184,15 @@ All circuits pass exhaustive testing over their full input domains.
184
  ```
185
  {category}.{circuit}[.{layer}][.{component}].{weight|bias}
186
 
187
- Examples:
188
- boolean.and.weight
189
- boolean.xor.layer1.neuron1.weight
190
- arithmetic.ripplecarry8bit.fa7.ha2.sum.layer1.or.weight
191
- modular.mod5.layer2.eq3.weight
192
- error_detection.paritychecker8bit.stage2.xor1.layer1.nand.bias
193
-
194
- Memory circuits are stored as packed tensors to keep the safetensors header size manageable
195
- (e.g., `memory.addr_decode.weight`, `memory.read.and.weight`, `memory.write.and_old.weight`).
196
  ```
197
 
198
  ---
@@ -207,13 +207,185 @@ All weights are integers. All activations are Heaviside step. Designed for:
207
 
208
  ---
209
 
210
  ## Files
211
 
212
  | File | Description |
213
  |------|-------------|
214
- | `neural_computer.safetensors` | 6,296 tensors, 8,267,667 parameters |
215
- | `iron_eval.py` | Comprehensive test suite |
216
- | `prune_weights.py` | Weight optimization tool |
217
 
218
  ---
219
 
@@ -237,7 +409,11 @@ MIT
237
 
238
  ---
239
 
240
- ## Links
241
 
242
- - [Coq Proofs](https://github.com/CharlesCNorton/coq-circuits) Formal verification for core circuits
243
- - [HuggingFace](https://huggingface.co/phanerozoic) Other models
 
 
 
 
 
17
  Every logic gate is a threshold neuron: `output = 1 if (Σ wᵢxᵢ + b) ≥ 0 else 0`
18
 
19
  ```
20
+ Tensors: 6,296
21
+ Parameters: 8,267,667
22
  ```
23
 
24
  ---
 
30
  | Component | Specification |
31
  |-----------|---------------|
32
  | Registers | 4 × 8-bit general purpose |
33
+ | Memory | 64KB addressable |
34
  | ALU | 16 operations (ADD, SUB, AND, OR, XOR, NOT, SHL, SHR, INC, DEC, CMP, NEG, PASS, ZERO, ONES, NOP) |
35
  | Flags | Zero, Negative, Carry, Overflow |
36
  | Control | JMP, JZ, JNZ, JC, JNC, JN, JP, JV, JNV, CALL, RET, PUSH, POP |
 
85
  | Arithmetic | 18 | Half/full adder, 2/4/8-bit ripple carry, comparators |
86
  | ALU | 3 | 8-bit ALU, control decoder, flag computation |
87
  | Combinational | 10 | MUX (2:1, 4:1, 8:1), DEMUX, encoders, decoders |
88
+ | Control Flow | 16 | JMP, conditional jumps, CALL, RET, PUSH, POP |
89
+ | Error Detection | 11 | Parity (XOR tree), checksum, CRC, Hamming |
90
+ | Modular | 11 | Divisibility by 2-12 (multi-layer for non-powers-of-2) |
91
+ | Threshold | 13 | k-of-n gates, majority, minority, exactly-k |
92
+ | Pattern | 10 | Popcount, leading/trailing ones, symmetry |
93
+ | Memory | 3 | 16-bit addr decoder, 65536x8 read mux, write cell update (packed) |
94
 
95
  ---
96
 
97
+ ## Usage
98
 
99
  ```python
100
  import torch
 
113
  inp = torch.tensor([a, b_in], dtype=torch.float32)
114
  out = heaviside(inp @ w + b)
115
  print(f"AND({a}, {b_in}) = {int(out.item())}")
116
+ ```
117
+
118
+ ---
119
+
120
+ ## State Tensor Layout
121
+
122
+ All multi-bit fields are **MSB-first** (index 0 is the most-significant bit).
123
+
124
+ ```
125
+ [ PC[16] | IR[16] | R0[8] R1[8] R2[8] R3[8] | FLAGS[4] | SP[16] | CTRL[4] | MEM[65536][8] ]
126
+ ```
127
+
128
+ Flags are ordered as: `Z, N, C, V`.
129
+
130
+ Control bits are ordered as: `HALT, MEM_WE, MEM_RE, RESERVED`.
131
+
132
+ Total state size: `524376` bits.
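As a sanity check, the field widths in the layout above sum to the stated total:

```python
# Field widths in bits, taken from the layout above; MEM is 65536 cells of 8 bits.
fields = {"PC": 16, "IR": 16, "R0-R3": 4 * 8, "FLAGS": 4, "SP": 16, "CTRL": 4, "MEM": 65536 * 8}
total = sum(fields.values())
print(total)  # 524376
```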
133
+
134
+ ---
135
+
136
+ ## Instruction Encoding (16-bit)
137
+
138
+ All instruction bits are **MSB-first**.
139
+
140
+ ```
141
+ 15..12 11..10 9..8 7..0
142
+ opcode rd rs imm8
143
+ ```
144
+
145
+ Interpretation:
146
+ - **R-type**: `rd = rd op rs` (imm8 ignored).
147
+ - **I-type**: `rd = op rd, imm8` (rs ignored).
148
+ - **Address-extended**: `LOAD`, `STORE`, `JMP`, `JZ`, `CALL` consume the next word as a 16-bit address (big-endian). `imm8` is reserved, and the PC skips 4 bytes when the jump is not taken.
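The field split above can be expressed as a small decoder (a sketch; the function name and return shape are ours, not part of the repository):

```python
def decode(word: int):
    """Split a 16-bit instruction word into opcode/rd/rs/imm8 (MSB-first)."""
    opcode = (word >> 12) & 0xF   # bits 15..12
    rd = (word >> 10) & 0x3       # bits 11..10
    rs = (word >> 8) & 0x3        # bits 9..8
    imm8 = word & 0xFF            # bits 7..0
    return opcode, rd, rs, imm8

print(decode(0b1010_01_10_00001111))  # (10, 1, 2, 15)
```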
149
+

150
+ ---
151
+
152
+ ## Verification
153
 
154
  The model includes `iron_eval.py` which exhaustively tests all circuits:
155
 
 
162
 
163
  | Category | Status | Notes |
164
  |----------|--------|-------|
165
+ | Boolean gates | Exhaustively tested | All 2^n input combinations |
166
+ | Arithmetic | Exhaustively tested | Full 8-bit range |
167
+ | ALU | Exhaustively tested | All operations, all inputs |
168
+ | Control flow | Exhaustively tested | Branch/jump conditions |
169
+ | Threshold | Exhaustively tested | k-of-n, majority, etc. |
170
  | Modular (mod 3,5,6,7,9,10,11,12) | Exhaustively tested | Multi-layer, hand-constructed |
171
  | Parity | Exhaustively tested | XOR tree, hand-constructed |
172
  | Modular (mod 2,4,8) | Exhaustively tested | Single-layer, trivial |
 
184
  ```
185
  {category}.{circuit}[.{layer}][.{component}].{weight|bias}
186
 
187
+ Examples:
188
+ boolean.and.weight
189
+ boolean.xor.layer1.neuron1.weight
190
+ arithmetic.ripplecarry8bit.fa7.ha2.sum.layer1.or.weight
191
+ modular.mod5.layer2.eq3.weight
192
+ error_detection.paritychecker8bit.stage2.xor1.layer1.nand.bias
193
+
194
+ Memory circuits are stored as packed tensors to keep the safetensors header size manageable
195
+ (e.g., `memory.addr_decode.weight`, `memory.read.and.weight`, `memory.write.and_old.weight`).
196
  ```
197
 
198
  ---
 
207
 
208
  ---
209
 
210
+ ## LLM Integration
211
+
212
+ The threshold circuits can be embedded into transformer MLP layers to give LLMs exact arithmetic capability.
213
+
214
+ ### Core Thesis
215
+
216
+ Standard LLMs fail at arithmetic because they are interpolators: they approximate functions over the training distribution rather than computing exact results. A 360M-parameter model trained on internet text has rarely, if ever, seen "127 + 128 = 255", so it guesses by pattern matching.
217
+
218
+ We solve this by embedding **frozen, proven-correct arithmetic circuits** directly into the transformer's MLP layers. The circuits use threshold logic (weighted sums + step activation), which is structurally compatible with neural network layers. We train only the **interface layers** that learn to:
219
+
220
+ 1. Extract operands from token embeddings
221
+ 2. Route computation through the circuits
222
+ 3. Inject results back into the residual stream
223
+
224
+ The model learns **call dispatch**, not arithmetic. The arithmetic is already solved.
225
+
226
+ ### Architecture
227
+
228
+ Standard MLP block with parallel circuit path:
229
+
230
+ ```
231
+ x ──┬── MLP path ────────────────┬── + ── output
232
+     │                            │
233
+     └── BitExtractor ── Circuit ─┴── BitInjector
234
+
235
+            Router (learned weighting)
236
+ ```
237
+
238
+ Augmented MLP forward pass:
239
+
240
+ ```python
241
+ def forward(self, x):  # x: [batch, seq, d_model]
242
+     # Original MLP path (unchanged)
243
+     mlp_out = self.down_proj(silu(self.gate_proj(x)) * self.up_proj(x))
244
+
245
+     # Circuit path (new)
246
+     a_bits, b_bits = self.bit_extractor(x)  # [batch, seq, 8] each
247
+     result_bits, carry = self.circuits.add_8bit(a_bits, b_bits)
248
+     flags = self.compute_flags(result_bits, carry)
249
+     circuit_delta = self.bit_injector(result_bits, flags)
250
+
251
+     # Routing
252
+     route_weights = self.router(x)  # [batch, seq, 2] softmax
253
+
254
+     # Combine
255
+     return mlp_out + route_weights[..., 1:2] * circuit_delta
256
+ ```
257
+
258
+ ### Threshold Logic Fundamentals
259
+
260
+ A threshold gate computes:
261
+
262
+ ```
263
+ output = 1 if (Σ wᵢxᵢ + b) ≥ 0
264
+ 0 otherwise
265
+ ```
266
+
267
+ Example gates:
268
+
269
+ ```
270
+ AND: w=[1,1], b=-2
271
+   AND(0,0) = H(-2) = 0
272
+   AND(1,1) = H(0) = 1
273
+
274
+ OR: w=[1,1], b=-1
275
+   OR(0,1) = H(0) = 1
276
+   OR(1,1) = H(1) = 1
277
+
278
+ XOR: requires 2 layers (not linearly separable)
279
+   Layer 1: OR + NAND
280
+   Layer 2: AND
281
+ ```
282
+
283
+ Full adder = 2 half-adders + carry OR, ~4 threshold layers.
284
+ 8-bit ripple carry = 8 chained full adders, ~32 threshold layers.
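The constructions above are easy to verify in plain Python (the NAND weights below are our own choice, consistent with the threshold rule; only AND and OR are given explicitly in this section):

```python
def H(x):
    """Heaviside step: 1 if x >= 0 else 0."""
    return 1 if x >= 0 else 0

def gate(w, b, xs):
    """A threshold gate: weighted sum plus bias, then step."""
    return H(sum(wi * xi for wi, xi in zip(w, xs)) + b)

AND = lambda a, b: gate([1, 1], -2, [a, b])
OR = lambda a, b: gate([1, 1], -1, [a, b])
NAND = lambda a, b: gate([-1, -1], 1, [a, b])   # assumed weights
XOR = lambda a, b: AND(OR(a, b), NAND(a, b))    # layer 1: OR + NAND; layer 2: AND

for a in (0, 1):
    for b in (0, 1):
        assert XOR(a, b) == a ^ b
```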
285
+
286
+ ### Interface Layers (Trainable)
287
+
288
+ **BitExtractor** — Maps embedding → two 8-bit operands:
289
+
290
+ ```python
291
+ class BitExtractor(nn.Module):
292
+     def __init__(self, d_model):
+         super().__init__()
293
+         self.proj = nn.Linear(d_model, 16)
294
+
295
+     def forward(self, x):
296
+         logits = self.proj(x)
297
+         bits = heaviside(logits)  # STE for training
298
+         return bits[..., :8], bits[..., 8:]
299
+ ```
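A quick shape check of the extractor sketch (assumes PyTorch; `d_model=960` as for SmolLM2-360M; a hard threshold stands in for the STE, as at inference):

```python
import torch
import torch.nn as nn

proj = nn.Linear(960, 16)              # the extractor's only learned layer
x = torch.randn(2, 5, 960)             # [batch, seq, d_model]
bits = (proj(x) >= 0).float()          # hard Heaviside at inference
a_bits, b_bits = bits[..., :8], bits[..., 8:]
assert a_bits.shape == (2, 5, 8) and b_bits.shape == (2, 5, 8)
assert set(bits.unique().tolist()) <= {0.0, 1.0}   # strictly binary operands
```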
300
+
301
+ **BitInjector** — Maps result bits → embedding delta:
302
+
303
+ ```python
304
+ class BitInjector(nn.Module):
305
+     def __init__(self, d_model):
+         super().__init__()
306
+         self.proj = nn.Linear(16, d_model)
307
+         self.scale = nn.Parameter(torch.tensor(0.1))
308
+
309
+     def forward(self, result_bits, flags):
310
+         combined = torch.cat([result_bits, flags], dim=-1)
311
+         return self.proj(combined) * self.scale
312
+ ```
313
+
314
+ **Router** — Decides when to use circuits:
315
+
316
+ ```python
317
+ class Router(nn.Module):
318
+     def __init__(self, d_model):
+         super().__init__()
319
+         self.net = nn.Sequential(
320
+             nn.Linear(d_model, 64), nn.ReLU(),
321
+             nn.Linear(64, 2), nn.Softmax(dim=-1),
322
+         )
+
+     def forward(self, x):
+         return self.net(x)
323
+ ```
324
+
325
+ ### Trainable Parameters
326
+
327
+ For SmolLM2-360M (d_model=960), augmenting 11 layers:
328
+
329
+ | Component | Params/Layer |
330
+ |-----------|-------------|
331
+ | BitExtractor | 15,376 |
332
+ | BitInjector | 16,321 |
333
+ | Router | 61,698 |
334
+ | OpSelector | ~31,000 |
335
+ | **Total** | ~124,395 |
336
+
337
+ **11 layers × 124,395 = ~1.37M trainable parameters** (0.38% of model)
338
+
339
+ ### Gradient Flow
340
+
341
+ Heaviside has zero gradient almost everywhere. We use **Straight-Through Estimator (STE)**:
342
+
343
+ ```python
344
+ class HeavisideSTE(torch.autograd.Function):
345
+     @staticmethod
346
+     def forward(ctx, x):
347
+         return (x >= 0).float()
348
+
349
+     @staticmethod
350
+     def backward(ctx, grad_output):
351
+         return grad_output  # pass through unchanged
352
+ ```
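A minimal check that the estimator really passes gradients through the flat step (assumes PyTorch):

```python
import torch

class HeavisideSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return (x >= 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # identity gradient (straight-through)

x = torch.tensor([-0.5, 0.3], requires_grad=True)
HeavisideSTE.apply(x).sum().backward()
print(x.grad)  # gradient of 1.0 at every element, despite the flat step
```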
353
+
354
+ ### Training Strategy
355
+
356
+ 1. **Data**: Generate 8-bit arithmetic problems exhaustively (256×256 = 65,536 unique)
357
+ 2. **Loss**: Cross-entropy on answer tokens only (prompt masked with -100)
358
+ 3. **Optimizer**: AdamW on interface params only, lr=1e-4
359
+ 4. **Curriculum**: Single-digit → two-digit → full 8-bit → adversarial (127+128, 255+1)
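Step 1 can be sketched directly; the prompt format matches the training code's `"{a} + {b} ="` convention:

```python
def make_dataset():
    """All 65,536 8-bit additions, each exactly once; results wrap mod 256."""
    return [(f"{a} + {b} =", f" {(a + b) % 256}")
            for a in range(256) for b in range(256)]

data = make_dataset()
print(len(data))               # 65536
print(data[127 * 256 + 128])   # ('127 + 128 =', ' 255')
```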
360
+
361
+ ### Inference
362
+
363
+ At inference, Heaviside is a true step function, with no approximation. If the BitExtractor extracts the operands correctly, the circuit **will** output the correct result. The circuit path adds roughly 5-10% latency overhead.
364
+
365
+ ### Target Performance
366
+
367
+ | Model | Baseline | Target |
368
+ |-------|----------|--------|
369
+ | SmolLM2-360M | ~5-10% | >95% |
370
+
371
+ Once trained, the interface generalizes to **all** 65,536 8-bit additions: nothing is memorized; the circuits compute.
372
+
373
+ ### Extension Roadmap
374
+
375
+ - **Additional operations**: Subtraction, multiplication, division, bitwise ops, comparisons
376
+ - **Multi-operand**: "15 + 27 + 33 =" via accumulator pattern
377
+ - **Larger widths**: 16-bit, 32-bit via chained circuits
378
+ - **Symbolic integration**: Natural language problems → extract operands → compute → generate answer
379
+
380
+ ---
381
+
382
  ## Files
383
 
384
  | File | Description |
385
  |------|-------------|
386
+ | `neural_computer.safetensors` | 6,296 tensors, 8,267,667 parameters |
387
+ | `cpu/core.py` | CPU state, reference cycle, threshold runtime |
388
+ | `eval/iron_eval.py` | Comprehensive test suite |
389
 
390
  ---
391
 
 
409
 
410
  ---
411
 
412
+ ## References
413
 
414
+ 1. McCulloch & Pitts (1943). "A Logical Calculus of the Ideas Immanent in Nervous Activity"
415
+ 2. Muroga (1971). "Threshold Logic and Its Applications"
416
+ 3. Siegelmann & Sontag (1995). "On the Computational Power of Neural Nets"
417
+ 4. Bengio et al. (2013). "Estimating or Propagating Gradients Through Stochastic Neurons"
418
+ 5. Ma et al. (2024). "The Era of 1-bit LLMs" (BitNet b1.58)
419
+ 6. HuggingFace (2024). "SmolLM2: Small Language Models"
llm/core.py DELETED
@@ -1,766 +0,0 @@
1
- """
2
- Circuit-Augmented LLM: Embedding Threshold Logic Circuits into Transformers
3
- ============================================================================
4
-
5
- Embeds frozen, proven-correct arithmetic circuits into transformer MLP layers.
6
- The model learns call dispatch (when to use circuits), not arithmetic.
7
-
8
- ARCHITECTURE
9
- ------------
10
- Standard LLM MLPs are augmented with a parallel circuit path:
11
-
12
- x ──┬── MLP path ────────────────┬── + ── output
13
- │ │
14
- └── BitExtractor ── Circuit ─┴── BitInjector
15
-
16
- Router (learned weighting)
17
-
18
- THRESHOLD LOGIC
19
- ---------------
20
- Each gate: output = 1 if (Σ wᵢxᵢ + b) ≥ 0 else 0
21
-
22
- Examples:
23
- AND: w=[1,1], b=-2 → fires only when both inputs are 1
24
- OR: w=[1,1], b=-1 → fires when either input is 1
25
- XOR: 2-layer network (not linearly separable)
26
-
27
- Full adder = 2 half-adders + carry OR, ~4 threshold layers.
28
- 8-bit ripple carry = 8 chained full adders, ~32 threshold layers.
29
-
30
- TRAINING
31
- --------
32
- Only interface layers train (~1.37M params):
33
- - BitExtractor: embedding → operand bits
34
- - BitInjector: result bits → embedding delta
35
- - Router: when to use circuits vs MLP
36
-
37
- Circuits are frozen (proven correct via 6,590 exhaustive tests).
38
- Uses Straight-Through Estimator for Heaviside gradient flow.
39
-
40
- TARGET: SmolLM2-360M
41
- - 960 hidden dim, 32 layers, 361M params
42
- - Augment middle third (layers 10-20)
43
- - Baseline arithmetic: ~5-10%
44
- - Target: >95% (circuit-accurate)
45
-
46
- USAGE
47
- -----
48
- # Augment model
49
- model = augment_smollm2_with_circuits(model, "neural_computer.safetensors")
50
-
51
- # Train interface
52
- model = train_interface(model, tokenizer, n_epochs=3)
53
-
54
- # Evaluate
55
- results = evaluate_arithmetic(model, tokenizer, n_problems=100)
56
-
57
- REFERENCES
58
- ----------
59
- 1. McCulloch & Pitts (1943). Logical Calculus of Ideas in Nervous Activity
60
- 2. Muroga (1971). Threshold Logic and Its Applications
61
- 3. Bengio et al. (2013). Estimating Gradients Through Stochastic Neurons (STE)
62
- 4. Ma et al. (2024). The Era of 1-bit LLMs (BitNet)
63
- """
64
-
65
- from __future__ import annotations
66
-
67
- import argparse
68
- import warnings
69
- from typing import Dict, List, Optional, Tuple
70
-
71
- import torch
72
- import torch.nn as nn
73
- import torch.nn.functional as F
74
- from safetensors.torch import load_file
75
- from torch.utils.data import DataLoader, Dataset
76
- from tqdm import tqdm
77
- from transformers import AutoModelForCausalLM, AutoTokenizer
78
-
79
- warnings.filterwarnings("ignore")
80
-
81
-
82
- class HeavisideSTE(torch.autograd.Function):
83
- """Heaviside step function with straight-through estimator for backprop."""
84
-
85
- @staticmethod
86
- def forward(ctx, x):
87
- return (x >= 0).float()
88
-
89
- @staticmethod
90
- def backward(ctx, grad_output):
91
- return grad_output
92
-
93
-
94
- def heaviside(x: torch.Tensor) -> torch.Tensor:
95
- """Heaviside step: 1 if x >= 0, else 0. Uses STE for training."""
96
- return HeavisideSTE.apply(x)
97
-
98
-
99
- class CircuitExecutor(nn.Module):
100
- """
101
- Executes threshold logic circuits from safetensors.
102
- All circuit weights are frozen.
103
- """
104
-
105
- def __init__(self, circuit_path: str, device: str = "cpu"):
106
- super().__init__()
107
- self.device = device
108
-
109
- raw_circuits = load_file(circuit_path)
110
-
111
- self.circuits = {}
112
- for k, v in raw_circuits.items():
113
- safe_name = k.replace(".", "__")
114
- self.register_buffer(safe_name, v.float().to(device))
115
- self.circuits[k] = safe_name
116
-
117
- def _get(self, name: str) -> torch.Tensor:
118
- return getattr(self, self.circuits[name])
119
-
120
- def eval_and(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
121
- inp = torch.stack([a, b], dim=-1)
122
- w = self._get("boolean.and.weight")
123
- bias = self._get("boolean.and.bias")
124
- return heaviside(inp @ w + bias)
125
-
126
- def eval_or(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
127
- inp = torch.stack([a, b], dim=-1)
128
- w = self._get("boolean.or.weight")
129
- bias = self._get("boolean.or.bias")
130
- return heaviside(inp @ w + bias)
131
-
132
- def eval_xor(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
133
- inp = torch.stack([a, b], dim=-1)
134
-
135
- w1_n1 = self._get("boolean.xor.layer1.neuron1.weight")
136
- b1_n1 = self._get("boolean.xor.layer1.neuron1.bias")
137
- w1_n2 = self._get("boolean.xor.layer1.neuron2.weight")
138
- b1_n2 = self._get("boolean.xor.layer1.neuron2.bias")
139
-
140
- h1 = heaviside(inp @ w1_n1 + b1_n1)
141
- h2 = heaviside(inp @ w1_n2 + b1_n2)
142
- hidden = torch.stack([h1, h2], dim=-1)
143
-
144
- w2 = self._get("boolean.xor.layer2.weight")
145
- b2 = self._get("boolean.xor.layer2.bias")
146
-
147
- return heaviside(hidden @ w2 + b2)
148
-
149
- def eval_full_adder(
150
- self, a: torch.Tensor, b: torch.Tensor, cin: torch.Tensor, prefix: str
151
- ) -> Tuple[torch.Tensor, torch.Tensor]:
152
- inp_ab = torch.stack([a, b], dim=-1)
153
-
154
- w1_or = self._get(f"{prefix}.ha1.sum.layer1.or.weight")
155
- b1_or = self._get(f"{prefix}.ha1.sum.layer1.or.bias")
156
- w1_nand = self._get(f"{prefix}.ha1.sum.layer1.nand.weight")
157
- b1_nand = self._get(f"{prefix}.ha1.sum.layer1.nand.bias")
158
- w2 = self._get(f"{prefix}.ha1.sum.layer2.weight")
159
- b2 = self._get(f"{prefix}.ha1.sum.layer2.bias")
160
-
161
- h_or = heaviside(inp_ab @ w1_or + b1_or)
162
- h_nand = heaviside(inp_ab @ w1_nand + b1_nand)
163
- hidden = torch.stack([h_or, h_nand], dim=-1)
164
- ha1_sum = heaviside(hidden @ w2 + b2)
165
-
166
- w_c1 = self._get(f"{prefix}.ha1.carry.weight")
167
- b_c1 = self._get(f"{prefix}.ha1.carry.bias")
168
- ha1_carry = heaviside(inp_ab @ w_c1 + b_c1)
169
-
170
- inp_ha2 = torch.stack([ha1_sum, cin], dim=-1)
171
- w1_or = self._get(f"{prefix}.ha2.sum.layer1.or.weight")
172
- b1_or = self._get(f"{prefix}.ha2.sum.layer1.or.bias")
173
- w1_nand = self._get(f"{prefix}.ha2.sum.layer1.nand.weight")
174
- b1_nand = self._get(f"{prefix}.ha2.sum.layer1.nand.bias")
175
- w2 = self._get(f"{prefix}.ha2.sum.layer2.weight")
176
- b2 = self._get(f"{prefix}.ha2.sum.layer2.bias")
177
-
178
- h_or = heaviside(inp_ha2 @ w1_or + b1_or)
179
- h_nand = heaviside(inp_ha2 @ w1_nand + b1_nand)
180
- hidden = torch.stack([h_or, h_nand], dim=-1)
181
- ha2_sum = heaviside(hidden @ w2 + b2)
182
-
183
- w_c2 = self._get(f"{prefix}.ha2.carry.weight")
184
- b_c2 = self._get(f"{prefix}.ha2.carry.bias")
185
- ha2_carry = heaviside(inp_ha2 @ w_c2 + b_c2)
186
-
187
- inp_cout = torch.stack([ha1_carry, ha2_carry], dim=-1)
188
- w_or = self._get(f"{prefix}.carry_or.weight")
189
- b_or = self._get(f"{prefix}.carry_or.bias")
190
- cout = heaviside(inp_cout @ w_or + b_or)
191
-
192
- return ha2_sum, cout
193
-
194
- def add_8bit(
195
- self, a_bits: torch.Tensor, b_bits: torch.Tensor
196
- ) -> Tuple[torch.Tensor, torch.Tensor]:
197
- """
198
- 8-bit ripple carry addition.
199
- a_bits, b_bits: [..., 8] tensors (LSB first)
200
- Returns: (result_bits [..., 8], carry_out [...])
201
- """
202
- batch_shape = a_bits.shape[:-1]
203
- carry = torch.zeros(batch_shape, device=a_bits.device)
204
- result_bits = []
205
-
206
- for i in range(8):
207
- a_i = a_bits[..., i]
208
- b_i = b_bits[..., i]
209
- sum_bit, carry = self.eval_full_adder(
210
- a_i, b_i, carry, f"arithmetic.ripplecarry8bit.fa{i}"
211
- )
212
- result_bits.append(sum_bit)
213
-
214
- return torch.stack(result_bits, dim=-1), carry
215
-
216
- def greater_than_8bit(
217
- self, a_bits: torch.Tensor, b_bits: torch.Tensor
218
- ) -> torch.Tensor:
219
- diff = a_bits - b_bits
220
- w = self._get("arithmetic.greaterthan8bit.comparator")
221
- score = (diff * w).sum(dim=-1)
222
- return (score > 0).float()
223
-
224
- def less_than_8bit(
225
- self, a_bits: torch.Tensor, b_bits: torch.Tensor
226
- ) -> torch.Tensor:
227
- diff = b_bits - a_bits
228
- w = self._get("arithmetic.lessthan8bit.comparator")
229
- score = (diff * w).sum(dim=-1)
230
- return (score > 0).float()
231
-
232
- def equal_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
233
- gt = self.greater_than_8bit(a_bits, b_bits)
234
- lt = self.less_than_8bit(a_bits, b_bits)
235
- return (1 - gt) * (1 - lt)
236
-
237
-
238
- class BitExtractor(nn.Module):
239
- """Maps embedding -> two 8-bit operands."""
240
-
241
- def __init__(self, d_model: int):
242
- super().__init__()
243
- self.d_model = d_model
244
- self.proj = nn.Linear(d_model, 16)
245
- self.temperature = nn.Parameter(torch.tensor(1.0))
246
-
247
- def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
248
- logits = self.proj(x)
249
- bits = heaviside(logits)
250
- a_bits = bits[..., :8]
251
- b_bits = bits[..., 8:]
252
- return a_bits, b_bits
253
-
254
-
255
- class BitInjector(nn.Module):
256
- """Maps result bits -> embedding delta."""
257
-
258
- def __init__(self, d_model: int):
259
- super().__init__()
260
- self.d_model = d_model
261
- self.proj = nn.Linear(16, d_model)
262
- self.scale = nn.Parameter(torch.tensor(0.1))
263
-
264
- def forward(self, result_bits: torch.Tensor, flags: torch.Tensor) -> torch.Tensor:
265
- combined = torch.cat([result_bits, flags], dim=-1)
266
- return self.proj(combined) * self.scale
267
-
268
-
269
- class CircuitAugmentedMLP(nn.Module):
270
- """
271
- MLP block augmented with frozen threshold circuits.
272
- Original MLP runs in parallel with circuit path; router decides weighting.
273
- """
274
-
275
- def __init__(
276
- self,
277
- d_model: int,
278
- intermediate_size: int,
279
- circuit_path: str,
280
- device: str = "cpu",
281
- ):
282
- super().__init__()
283
- self.d_model = d_model
284
-
285
- self.gate_proj = nn.Linear(d_model, intermediate_size, bias=False)
286
- self.up_proj = nn.Linear(d_model, intermediate_size, bias=False)
287
- self.down_proj = nn.Linear(intermediate_size, d_model, bias=False)
288
- self.act_fn = nn.SiLU()
289
-
290
- self.circuits = CircuitExecutor(circuit_path, device)
291
- self.bit_extractor = BitExtractor(d_model)
292
- self.bit_injector = BitInjector(d_model)
293
-
294
- self.router = nn.Sequential(
295
- nn.Linear(d_model, 64),
296
- nn.ReLU(),
297
- nn.Linear(64, 2),
298
- nn.Softmax(dim=-1),
299
- )
300
-
301
- self.op_selector = nn.Sequential(
302
- nn.Linear(d_model, 32),
303
- nn.ReLU(),
304
- nn.Linear(32, 4),
305
- nn.Softmax(dim=-1),
306
- )
307
-
308
- def _compute_flags(
309
- self, result_bits: torch.Tensor, carry: torch.Tensor
310
- ) -> torch.Tensor:
311
- batch_shape = result_bits.shape[:-1]
312
-
313
- zero = (result_bits.sum(dim=-1) == 0).float()
314
- negative = result_bits[..., 7]
315
- carry_flag = carry
316
-
317
- flags = torch.zeros(*batch_shape, 8, device=result_bits.device)
318
- flags[..., 0] = zero
319
- flags[..., 1] = negative
320
- flags[..., 2] = carry_flag
321
-
322
- return flags
323
-
324
- def _circuit_forward(self, x: torch.Tensor) -> torch.Tensor:
325
- a_bits, b_bits = self.bit_extractor(x)
326
- add_result, add_carry = self.circuits.add_8bit(a_bits, b_bits)
327
- add_flags = self._compute_flags(add_result, add_carry)
328
- circuit_delta = self.bit_injector(add_result, add_flags)
329
- return circuit_delta
330
-
331
- def forward(self, x: torch.Tensor) -> torch.Tensor:
332
- mlp_out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
333
-
334
- circuit_out = self._circuit_forward(x)
335
-
336
- route_weights = self.router(x)
337
- circuit_weight = route_weights[..., 1:2]
338
-
339
- output = mlp_out + circuit_weight * circuit_out
340
-
341
- return output
342
-
343
-
344
- def augment_smollm2_with_circuits(
345
- model: AutoModelForCausalLM,
346
- circuit_path: str,
347
- layer_indices: list = None,
348
- device: str = "cpu",
349
- ) -> AutoModelForCausalLM:
350
- """
351
- Insert circuit blocks into SmolLM2's MLP layers.
352
-
353
- Args:
354
- model: Pretrained SmolLM2
355
- circuit_path: Path to neural_computer.safetensors
356
- layer_indices: Which layers to augment (default: middle third)
357
- device: Device for circuit tensors
358
-
359
- Returns:
360
- Model with circuit-augmented MLPs
361
- """
362
- config = model.config
363
- num_layers = config.num_hidden_layers
364
-
365
- if layer_indices is None:
366
- start = num_layers // 3
367
- end = 2 * num_layers // 3
368
- layer_indices = list(range(start, end))
369
-
370
- print(f"Augmenting layers {layer_indices} with threshold circuits...")
371
-
372
- for idx in layer_indices:
373
- layer = model.model.layers[idx]
374
- old_mlp = layer.mlp
375
-
376
- new_mlp = CircuitAugmentedMLP(
377
- d_model=config.hidden_size,
378
- intermediate_size=config.intermediate_size,
379
- circuit_path=circuit_path,
380
- device=device,
381
- )
382
-
383
- new_mlp.gate_proj.weight.data = old_mlp.gate_proj.weight.data.clone()
384
- new_mlp.up_proj.weight.data = old_mlp.up_proj.weight.data.clone()
385
- new_mlp.down_proj.weight.data = old_mlp.down_proj.weight.data.clone()
386
-
387
- layer.mlp = new_mlp
388
-
389
- for name, param in model.named_parameters():
390
- if "circuits" in name:
391
- param.requires_grad = False
392
-
393
- print("Done. Circuit weights frozen, interfaces trainable.")
394
-
395
- return model
396
-
397
-
398
- def generate_arithmetic_batch(
399
- batch_size: int, max_val: int = 255
400
- ) -> Tuple[list, list]:
401
- """Generate batch of arithmetic problems and solutions."""
402
- prompts = []
403
- targets = []
404
-
405
- for _ in range(batch_size):
406
- a = torch.randint(0, max_val + 1, (1,)).item()
407
- b = torch.randint(0, max_val + 1, (1,)).item()
408
- result = (a + b) % 256
409
-
410
- prompts.append(f"{a} + {b} =")
411
- targets.append(f" {result}")
412
-
413
- return prompts, targets
414
-
415
-
416
- def evaluate_arithmetic(
417
- model: AutoModelForCausalLM,
418
- tokenizer: AutoTokenizer,
419
- n_problems: int = 100,
420
- device: str = "cpu",
421
- ) -> dict:
422
- """Evaluate model on random arithmetic problems."""
423
- correct = 0
424
- total = 0
425
- errors = []
426
-
427
- model.eval()
428
-
429
- for _ in range(n_problems):
430
- a = torch.randint(0, 256, (1,)).item()
431
- b = torch.randint(0, 256, (1,)).item()
432
- expected = (a + b) % 256
433
-
434
- prompt = f"{a} + {b} ="
435
- inputs = tokenizer(prompt, return_tensors="pt").to(device)
436
-
437
- with torch.no_grad():
438
- outputs = model.generate(
439
- **inputs,
440
- max_new_tokens=10,
441
- do_sample=False,
442
- pad_token_id=tokenizer.eos_token_id,
443
- )
444
-
445
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
446
-
447
- try:
448
- answer_part = response.split("=")[-1].strip()
449
- predicted = int("".join(c for c in answer_part.split()[0] if c.isdigit()))
450
-
451
- if predicted == expected:
452
- correct += 1
453
- else:
454
- errors.append((a, b, expected, predicted))
455
- except:
456
- errors.append((a, b, expected, "parse_error"))
457
-
458
- total += 1
459
-
460
- return {
461
- "accuracy": correct / total,
462
- "correct": correct,
463
- "total": total,
464
- "errors": errors[:10],
465
- }
466
-
467
-
-class ArithmeticDataset(Dataset):
-    """Dataset of 8-bit addition problems."""
-
-    def __init__(self, tokenizer, n_samples: int = 10000, max_val: int = 255):
-        self.tokenizer = tokenizer
-        self.n_samples = n_samples
-        self.max_val = max_val
-
-        self.examples = []
-        for _ in range(n_samples):
-            a = torch.randint(0, max_val + 1, (1,)).item()
-            b = torch.randint(0, max_val + 1, (1,)).item()
-            result = (a + b) % 256
-
-            prompt = f"{a} + {b} ="
-            target = f" {result}"
-
-            self.examples.append((prompt, target, a, b, result))
-
-    def __len__(self):
-        return len(self.examples)
-
-    def __getitem__(self, idx):
-        prompt, target, a, b, result = self.examples[idx]
-
-        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
-        target_ids = self.tokenizer.encode(target, add_special_tokens=False)
-
-        input_ids = prompt_ids + target_ids
-        labels = [-100] * len(prompt_ids) + target_ids
-
-        return {
-            "input_ids": torch.tensor(input_ids),
-            "labels": torch.tensor(labels),
-            "a": a,
-            "b": b,
-            "result": result,
-        }
-
-
-def collate_fn(batch):
-    """Collate with padding."""
-    max_len = max(len(item["input_ids"]) for item in batch)
-
-    input_ids = []
-    labels = []
-    attention_mask = []
-
-    for item in batch:
-        pad_len = max_len - len(item["input_ids"])
-
-        input_ids.append(
-            torch.cat([item["input_ids"], torch.zeros(pad_len, dtype=torch.long)])
-        )
-        labels.append(
-            torch.cat(
-                [item["labels"], torch.full((pad_len,), -100, dtype=torch.long)]
-            )
-        )
-        attention_mask.append(
-            torch.cat([torch.ones(len(item["input_ids"])), torch.zeros(pad_len)])
-        )
-
-    return {
-        "input_ids": torch.stack(input_ids),
-        "labels": torch.stack(labels),
-        "attention_mask": torch.stack(attention_mask),
-    }
-
-
-def train_interface(
-    model: AutoModelForCausalLM,
-    tokenizer: AutoTokenizer,
-    n_epochs: int = 3,
-    batch_size: int = 16,
-    lr: float = 1e-4,
-    n_train_samples: int = 10000,
-    device: str = "cpu",
-    eval_every: int = 500,
-):
-    """
-    Train the circuit interface layers.
-
-    Only trains:
-    - bit_extractor (embedding -> bits)
-    - bit_injector (bits -> embedding)
-    - router (circuit vs MLP weighting)
-    - op_selector (which operation)
-    """
-    print("\n" + "=" * 70)
-    print(" TRAINING CIRCUIT INTERFACE")
-    print("=" * 70)
-
-    interface_params = []
-    frozen_count = 0
-    trainable_count = 0
-
-    for name, param in model.named_parameters():
-        if any(
-            x in name for x in ["bit_extractor", "bit_injector", "router", "op_selector"]
-        ):
-            param.requires_grad = True
-            interface_params.append(param)
-            trainable_count += param.numel()
-        else:
-            param.requires_grad = False
-            frozen_count += param.numel()
-
-    print(f"\n Frozen parameters: {frozen_count:,}")
-    print(f" Trainable parameters: {trainable_count:,}")
-    print(f" Training {len(interface_params)} parameter groups")
-
-    print(f"\n Creating dataset ({n_train_samples} examples)...")
-    dataset = ArithmeticDataset(tokenizer, n_samples=n_train_samples)
-    dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
-    )
-
-    optimizer = torch.optim.AdamW(interface_params, lr=lr)
-
-    model.to(device)
-    model.train()
-
-    global_step = 0
-    total_loss = 0
-
-    for epoch in range(n_epochs):
-        print(f"\n Epoch {epoch + 1}/{n_epochs}")
-        print(" " + "-" * 60)
-
-        epoch_loss = 0
-        epoch_steps = 0
-
-        pbar = tqdm(dataloader, desc=" Training", leave=False)
-
-        for batch in pbar:
-            input_ids = batch["input_ids"].to(device)
-            labels = batch["labels"].to(device)
-            attention_mask = batch["attention_mask"].to(device)
-
-            outputs = model(
-                input_ids=input_ids, attention_mask=attention_mask, labels=labels
-            )
-
-            loss = outputs.loss
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-
-            epoch_loss += loss.item()
-            epoch_steps += 1
-            global_step += 1
-            total_loss += loss.item()
-
-            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
-
-            if global_step % eval_every == 0:
-                model.eval()
-                eval_results = evaluate_arithmetic(
-                    model, tokenizer, n_problems=50, device=device
-                )
-                print(
-                    f"\n Step {global_step}: Loss={total_loss/eval_every:.4f}, "
-                    f"Accuracy={eval_results['accuracy']*100:.1f}%"
-                )
-                total_loss = 0
-                model.train()
-
-        avg_loss = epoch_loss / epoch_steps
-        print(f"\n Epoch {epoch + 1} complete. Avg loss: {avg_loss:.4f}")
-
-        model.eval()
-        eval_results = evaluate_arithmetic(
-            model, tokenizer, n_problems=100, device=device
-        )
-        print(
-            f" Evaluation: {eval_results['accuracy']*100:.1f}% "
-            f"({eval_results['correct']}/{eval_results['total']})"
-        )
-
-        if eval_results["errors"]:
-            print(" Sample errors:")
-            for a, b, exp, got in eval_results["errors"][:3]:
-                print(f" {a} + {b} = {exp}, model said {got}")
-
-        model.train()
-
-    print("\n" + "=" * 70)
-    print(" TRAINING COMPLETE")
-    print("=" * 70)
-
-    return model
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Circuit-Augmented LLM")
-    parser.add_argument(
-        "--circuit-path",
-        type=str,
-        default="./neural_computer.safetensors",
-        help="Path to circuit weights",
-    )
-    parser.add_argument("--device", type=str, default="cpu", help="Device")
-    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
-    parser.add_argument("--batch-size", type=int, default=8, help="Batch size")
-    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
-    parser.add_argument(
-        "--n-samples", type=int, default=5000, help="Training samples"
-    )
-    parser.add_argument(
-        "--eval-only", action="store_true", help="Only evaluate baseline"
-    )
-    args = parser.parse_args()
-
-    print("=" * 70)
-    print(" CIRCUIT-AUGMENTED LLM")
-    print("=" * 70)
-
-    print("\n[1] Loading SmolLM2-360M...")
-    model_id = "HuggingFaceTB/SmolLM2-360M"
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.pad_token = tokenizer.eos_token
-    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
-
-    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
-
-    print("\n[2] Baseline arithmetic evaluation...")
-    baseline = evaluate_arithmetic(model, tokenizer, n_problems=50, device=args.device)
-    print(
-        f" Accuracy: {baseline['accuracy']*100:.1f}% "
-        f"({baseline['correct']}/{baseline['total']})"
-    )
-    if baseline["errors"]:
-        print(" Sample errors:")
-        for a, b, exp, got in baseline["errors"][:5]:
-            print(f" {a} + {b} = {exp}, model said {got}")
-
-    if args.eval_only:
-        print("\nDone (eval only mode).")
-        exit(0)
-
-    print(f"\n[3] Augmenting with threshold circuits...")
-    print(f" Circuit path: {args.circuit_path}")
-    model = augment_smollm2_with_circuits(model, args.circuit_path, device=args.device)
-
-    new_params = sum(p.numel() for p in model.parameters())
-    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    print(f" Total parameters: {new_params:,}")
-    print(f" Trainable parameters: {trainable:,}")
-
-    print("\n[4] Testing circuit execution...")
-    circuit_exec = CircuitExecutor(args.circuit_path, args.device)
-
-    test_cases = [(127, 128), (255, 1), (0, 0), (100, 55)]
-    for a, b in test_cases:
-        a_bits = torch.tensor([(a >> i) & 1 for i in range(8)], dtype=torch.float32)
-        b_bits = torch.tensor([(b >> i) & 1 for i in range(8)], dtype=torch.float32)
-
-        result_bits, carry = circuit_exec.add_8bit(
-            a_bits.unsqueeze(0), b_bits.unsqueeze(0)
-        )
-
-        result = sum(int(result_bits[0, i].item()) * (2**i) for i in range(8))
-        expected = (a + b) % 256
-
-        status = "OK" if result == expected else "FAIL"
-        print(f" {a} + {b} = {result} (expected {expected}) [{status}]")
-
-    print("\n[5] Training interface layers...")
-    model = train_interface(
-        model,
-        tokenizer,
-        n_epochs=args.epochs,
-        batch_size=args.batch_size,
-        lr=args.lr,
-        n_train_samples=args.n_samples,
-        device=args.device,
-    )
-
-    print("\n[6] Final evaluation...")
-    final = evaluate_arithmetic(model, tokenizer, n_problems=100, device=args.device)
-    print(f" Final accuracy: {final['accuracy']*100:.1f}%")
-    print(
-        f" Improvement: {baseline['accuracy']*100:.1f}% -> {final['accuracy']*100:.1f}%"
-    )
-
-    save_path = "./circuit_augmented_smollm2.pt"
-    print(f"\n[7] Saving to {save_path}...")
-    torch.save(
-        {
-            "model_state_dict": model.state_dict(),
-            "baseline_accuracy": baseline["accuracy"],
-            "final_accuracy": final["accuracy"],
-        },
-        save_path,
-    )
-
-    print("\nDone!")
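
The deleted script packs integers into LSB-first bit vectors via `(a >> i) & 1` before calling the circuit adder, and unpacks results with `sum(bit_i * 2**i)`. A minimal pure-Python sketch of that convention (the helper names `to_bits`/`from_bits` are hypothetical, not from the repo):

```python
def to_bits(value: int, width: int = 8) -> list[int]:
    """LSB-first bit vector, matching the (value >> i) & 1 packing in core.py."""
    return [(value >> i) & 1 for i in range(width)]

def from_bits(bits: list[int]) -> int:
    """Inverse unpacking: sum of bit_i * 2**i."""
    return sum(b * (1 << i) for i, b in enumerate(bits))

def add_mod_256(a: int, b: int) -> int:
    """Wrap-around 8-bit addition, the reference the circuit is checked against."""
    return from_bits(to_bits((a + b) % 256))

assert to_bits(5) == [1, 0, 1, 0, 0, 0, 0, 0]
assert from_bits(to_bits(200)) == 200
assert add_mod_256(255, 1) == 0
```

These are the same cases the script's `[4] Testing circuit execution` step exercises (e.g. `255 + 1 = 0` wraps modulo 256); converting such a vector to a `torch.float32` tensor reproduces the exact inputs fed to `circuit_exec.add_8bit`.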