CharlesCNorton commited on
Commit
1b10db6
·
1 Parent(s): d09568e

Consolidate LLM modules into single core.py

Browse files

Merge circuit_llm.py, guide.md, and train_circuit_interface.py into llm/core.py.
Documentation integrated as module docstring.

llm/{circuit_llm.py → core.py} RENAMED
@@ -1,606 +1,766 @@
1
- """
2
- Circuit-Augmented LLM: Embedding threshold logic circuits into SmolLM2
3
- ======================================================================
4
-
5
- Replaces/augments MLP layers with frozen threshold circuits for exact arithmetic.
6
- """
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- from typing import Dict, Optional, Tuple
12
- from safetensors.torch import load_file
13
- from transformers import AutoModelForCausalLM, AutoTokenizer
14
- import warnings
15
- warnings.filterwarnings('ignore')
16
-
17
-
18
- # =============================================================================
19
- # HEAVISIDE WITH STRAIGHT-THROUGH ESTIMATOR
20
- # =============================================================================
21
-
22
- class HeavisideSTE(torch.autograd.Function):
23
- """Heaviside step function with straight-through estimator for backprop."""
24
-
25
- @staticmethod
26
- def forward(ctx, x):
27
- return (x >= 0).float()
28
-
29
- @staticmethod
30
- def backward(ctx, grad_output):
31
- # STE: pass gradient through unchanged
32
- return grad_output
33
-
34
-
35
- def heaviside(x: torch.Tensor) -> torch.Tensor:
36
- """Heaviside step: 1 if x >= 0, else 0. Uses STE for training."""
37
- return HeavisideSTE.apply(x)
38
-
39
-
40
- # =============================================================================
41
- # CIRCUIT EXECUTOR - Runs the frozen threshold circuits
42
- # =============================================================================
43
-
44
- class CircuitExecutor(nn.Module):
45
- """
46
- Executes threshold logic circuits from the safetensors file.
47
- All circuit weights are frozen - only interface layers train.
48
- """
49
-
50
- def __init__(self, circuit_path: str, device: str = 'cpu'):
51
- super().__init__()
52
- self.device = device
53
-
54
- # Load all circuit tensors
55
- raw_circuits = load_file(circuit_path)
56
-
57
- # Store as frozen parameters (use underscores for valid param names)
58
- self.circuits = {}
59
- for k, v in raw_circuits.items():
60
- safe_name = k.replace('.', '__')
61
- self.register_buffer(safe_name, v.float().to(device))
62
- self.circuits[k] = safe_name
63
-
64
- def _get(self, name: str) -> torch.Tensor:
65
- """Get circuit tensor by original dotted name."""
66
- return getattr(self, self.circuits[name])
67
-
68
- # -------------------------------------------------------------------------
69
- # Boolean Gates
70
- # -------------------------------------------------------------------------
71
-
72
- def eval_and(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
73
- """AND gate: output 1 iff both inputs are 1."""
74
- inp = torch.stack([a, b], dim=-1)
75
- w = self._get('boolean.and.weight')
76
- bias = self._get('boolean.and.bias')
77
- return heaviside(inp @ w + bias)
78
-
79
- def eval_or(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
80
- """OR gate: output 1 if either input is 1."""
81
- inp = torch.stack([a, b], dim=-1)
82
- w = self._get('boolean.or.weight')
83
- bias = self._get('boolean.or.bias')
84
- return heaviside(inp @ w + bias)
85
-
86
- def eval_xor(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
87
- """XOR gate: two-layer network (not linearly separable)."""
88
- inp = torch.stack([a, b], dim=-1)
89
-
90
- # Layer 1: OR and NAND neurons
91
- w1_n1 = self._get('boolean.xor.layer1.neuron1.weight')
92
- b1_n1 = self._get('boolean.xor.layer1.neuron1.bias')
93
- w1_n2 = self._get('boolean.xor.layer1.neuron2.weight')
94
- b1_n2 = self._get('boolean.xor.layer1.neuron2.bias')
95
-
96
- h1 = heaviside(inp @ w1_n1 + b1_n1)
97
- h2 = heaviside(inp @ w1_n2 + b1_n2)
98
- hidden = torch.stack([h1, h2], dim=-1)
99
-
100
- # Layer 2: AND of hidden
101
- w2 = self._get('boolean.xor.layer2.weight')
102
- b2 = self._get('boolean.xor.layer2.bias')
103
-
104
- return heaviside(hidden @ w2 + b2)
105
-
106
- # -------------------------------------------------------------------------
107
- # Arithmetic: Full Adder
108
- # -------------------------------------------------------------------------
109
-
110
- def eval_full_adder(self, a: torch.Tensor, b: torch.Tensor,
111
- cin: torch.Tensor, prefix: str) -> Tuple[torch.Tensor, torch.Tensor]:
112
- """
113
- Full adder: sum = a XOR b XOR cin, cout = (a AND b) OR (cin AND (a XOR b))
114
- Returns (sum_bit, carry_out)
115
- """
116
- inp_ab = torch.stack([a, b], dim=-1)
117
-
118
- # HA1: a XOR b
119
- w1_or = self._get(f'{prefix}.ha1.sum.layer1.or.weight')
120
- b1_or = self._get(f'{prefix}.ha1.sum.layer1.or.bias')
121
- w1_nand = self._get(f'{prefix}.ha1.sum.layer1.nand.weight')
122
- b1_nand = self._get(f'{prefix}.ha1.sum.layer1.nand.bias')
123
- w2 = self._get(f'{prefix}.ha1.sum.layer2.weight')
124
- b2 = self._get(f'{prefix}.ha1.sum.layer2.bias')
125
-
126
- h_or = heaviside(inp_ab @ w1_or + b1_or)
127
- h_nand = heaviside(inp_ab @ w1_nand + b1_nand)
128
- hidden = torch.stack([h_or, h_nand], dim=-1)
129
- ha1_sum = heaviside(hidden @ w2 + b2)
130
-
131
- # HA1 carry
132
- w_c1 = self._get(f'{prefix}.ha1.carry.weight')
133
- b_c1 = self._get(f'{prefix}.ha1.carry.bias')
134
- ha1_carry = heaviside(inp_ab @ w_c1 + b_c1)
135
-
136
- # HA2: ha1_sum XOR cin
137
- inp_ha2 = torch.stack([ha1_sum, cin], dim=-1)
138
- w1_or = self._get(f'{prefix}.ha2.sum.layer1.or.weight')
139
- b1_or = self._get(f'{prefix}.ha2.sum.layer1.or.bias')
140
- w1_nand = self._get(f'{prefix}.ha2.sum.layer1.nand.weight')
141
- b1_nand = self._get(f'{prefix}.ha2.sum.layer1.nand.bias')
142
- w2 = self._get(f'{prefix}.ha2.sum.layer2.weight')
143
- b2 = self._get(f'{prefix}.ha2.sum.layer2.bias')
144
-
145
- h_or = heaviside(inp_ha2 @ w1_or + b1_or)
146
- h_nand = heaviside(inp_ha2 @ w1_nand + b1_nand)
147
- hidden = torch.stack([h_or, h_nand], dim=-1)
148
- ha2_sum = heaviside(hidden @ w2 + b2)
149
-
150
- # HA2 carry
151
- w_c2 = self._get(f'{prefix}.ha2.carry.weight')
152
- b_c2 = self._get(f'{prefix}.ha2.carry.bias')
153
- ha2_carry = heaviside(inp_ha2 @ w_c2 + b_c2)
154
-
155
- # Carry out = ha1_carry OR ha2_carry
156
- inp_cout = torch.stack([ha1_carry, ha2_carry], dim=-1)
157
- w_or = self._get(f'{prefix}.carry_or.weight')
158
- b_or = self._get(f'{prefix}.carry_or.bias')
159
- cout = heaviside(inp_cout @ w_or + b_or)
160
-
161
- return ha2_sum, cout
162
-
163
- # -------------------------------------------------------------------------
164
- # Arithmetic: 8-bit Ripple Carry Adder
165
- # -------------------------------------------------------------------------
166
-
167
- def add_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
168
- """
169
- 8-bit ripple carry addition.
170
- a_bits, b_bits: [..., 8] tensors of bits (LSB first)
171
- Returns: (result_bits [..., 8], carry_out [...])
172
- """
173
- batch_shape = a_bits.shape[:-1]
174
- carry = torch.zeros(batch_shape, device=a_bits.device)
175
- result_bits = []
176
-
177
- for i in range(8):
178
- a_i = a_bits[..., i]
179
- b_i = b_bits[..., i]
180
- sum_bit, carry = self.eval_full_adder(
181
- a_i, b_i, carry,
182
- f'arithmetic.ripplecarry8bit.fa{i}'
183
- )
184
- result_bits.append(sum_bit)
185
-
186
- return torch.stack(result_bits, dim=-1), carry
187
-
188
- # -------------------------------------------------------------------------
189
- # Arithmetic: 8-bit Comparators
190
- # -------------------------------------------------------------------------
191
-
192
- def greater_than_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
193
- """Returns 1 if a > b, else 0. Bits are MSB first."""
194
- diff = a_bits - b_bits # [..., 8]
195
- w = self._get('arithmetic.greaterthan8bit.comparator')
196
- score = (diff * w).sum(dim=-1)
197
- return (score > 0).float()
198
-
199
- def less_than_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
200
- """Returns 1 if a < b, else 0. Bits are MSB first."""
201
- diff = b_bits - a_bits # [..., 8]
202
- w = self._get('arithmetic.lessthan8bit.comparator')
203
- score = (diff * w).sum(dim=-1)
204
- return (score > 0).float()
205
-
206
- def equal_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
207
- """Returns 1 if a == b, else 0."""
208
- gt = self.greater_than_8bit(a_bits, b_bits)
209
- lt = self.less_than_8bit(a_bits, b_bits)
210
- return (1 - gt) * (1 - lt)
211
-
212
-
213
- # =============================================================================
214
- # BIT EXTRACTION / INJECTION INTERFACES
215
- # =============================================================================
216
-
217
- class BitExtractor(nn.Module):
218
- """
219
- Learns to extract 8-bit operands from token embeddings.
220
- Maps embedding -> 16 bits (two 8-bit operands).
221
- """
222
-
223
- def __init__(self, d_model: int):
224
- super().__init__()
225
- self.d_model = d_model
226
-
227
- # Project to logits, then binarize
228
- self.proj = nn.Linear(d_model, 16)
229
-
230
- # Learnable temperature for sigmoid approximation during training
231
- self.temperature = nn.Parameter(torch.tensor(1.0))
232
-
233
- def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
234
- """
235
- x: [..., d_model]
236
- Returns: a_bits [..., 8], b_bits [..., 8] (LSB first for arithmetic)
237
- """
238
- logits = self.proj(x) # [..., 16]
239
-
240
- # Binarize with STE
241
- bits = heaviside(logits)
242
-
243
- # Split into two operands
244
- a_bits = bits[..., :8]
245
- b_bits = bits[..., 8:]
246
-
247
- return a_bits, b_bits
248
-
249
-
250
- class BitInjector(nn.Module):
251
- """
252
- Learns to inject circuit results back into embedding space.
253
- Maps 16 bits (result + flags) -> embedding delta.
254
- """
255
-
256
- def __init__(self, d_model: int):
257
- super().__init__()
258
- self.d_model = d_model
259
-
260
- # Project bits to embedding
261
- self.proj = nn.Linear(16, d_model)
262
-
263
- # Learnable scale
264
- self.scale = nn.Parameter(torch.tensor(0.1))
265
-
266
- def forward(self, result_bits: torch.Tensor, flags: torch.Tensor) -> torch.Tensor:
267
- """
268
- result_bits: [..., 8]
269
- flags: [..., 8] (carry, overflow, zero, negative, etc.)
270
- Returns: [..., d_model]
271
- """
272
- combined = torch.cat([result_bits, flags], dim=-1) # [..., 16]
273
- return self.proj(combined) * self.scale
274
-
275
-
276
- # =============================================================================
277
- # CIRCUIT-AUGMENTED MLP BLOCK
278
- # =============================================================================
279
-
280
- class CircuitAugmentedMLP(nn.Module):
281
- """
282
- MLP block augmented with frozen threshold circuits.
283
-
284
- The original MLP path runs in parallel with the circuit path.
285
- A learned router decides how much to use each.
286
- """
287
-
288
- def __init__(
289
- self,
290
- d_model: int,
291
- intermediate_size: int,
292
- circuit_path: str,
293
- device: str = 'cpu'
294
- ):
295
- super().__init__()
296
- self.d_model = d_model
297
-
298
- # Original MLP components (will be loaded from pretrained)
299
- self.gate_proj = nn.Linear(d_model, intermediate_size, bias=False)
300
- self.up_proj = nn.Linear(d_model, intermediate_size, bias=False)
301
- self.down_proj = nn.Linear(intermediate_size, d_model, bias=False)
302
- self.act_fn = nn.SiLU()
303
-
304
- # Circuit components
305
- self.circuits = CircuitExecutor(circuit_path, device)
306
- self.bit_extractor = BitExtractor(d_model)
307
- self.bit_injector = BitInjector(d_model)
308
-
309
- # Router: decides circuit vs MLP contribution
310
- self.router = nn.Sequential(
311
- nn.Linear(d_model, 64),
312
- nn.ReLU(),
313
- nn.Linear(64, 2),
314
- nn.Softmax(dim=-1)
315
- )
316
-
317
- # Operation selector (which arithmetic op to perform)
318
- self.op_selector = nn.Sequential(
319
- nn.Linear(d_model, 32),
320
- nn.ReLU(),
321
- nn.Linear(32, 4), # add, sub, compare, passthrough
322
- nn.Softmax(dim=-1)
323
- )
324
-
325
- def _compute_flags(self, result_bits: torch.Tensor, carry: torch.Tensor) -> torch.Tensor:
326
- """Compute status flags from result."""
327
- batch_shape = result_bits.shape[:-1]
328
-
329
- # Zero flag: all bits are 0
330
- zero = (result_bits.sum(dim=-1) == 0).float()
331
-
332
- # Negative flag: MSB is 1 (two's complement)
333
- negative = result_bits[..., 7]
334
-
335
- # Carry flag
336
- carry_flag = carry
337
-
338
- # Pad to 8 flags
339
- flags = torch.zeros(*batch_shape, 8, device=result_bits.device)
340
- flags[..., 0] = zero
341
- flags[..., 1] = negative
342
- flags[..., 2] = carry_flag
343
-
344
- return flags
345
-
346
- def _circuit_forward(self, x: torch.Tensor) -> torch.Tensor:
347
- """Run input through threshold circuits."""
348
- # Extract operands
349
- a_bits, b_bits = self.bit_extractor(x)
350
-
351
- # Get operation weights
352
- op_weights = self.op_selector(x) # [..., 4]
353
-
354
- # Compute addition
355
- add_result, add_carry = self.circuits.add_8bit(a_bits, b_bits)
356
- add_flags = self._compute_flags(add_result, add_carry)
357
-
358
- # Compute subtraction (a + (~b) + 1, simplified: just use add for now)
359
- # For MVP, we'll focus on addition
360
-
361
- # Inject result back
362
- circuit_delta = self.bit_injector(add_result, add_flags)
363
-
364
- return circuit_delta
365
-
366
- def forward(self, x: torch.Tensor) -> torch.Tensor:
367
- """
368
- x: [batch, seq_len, d_model]
369
- Returns: [batch, seq_len, d_model]
370
- """
371
- # Original MLP path
372
- mlp_out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
373
-
374
- # Circuit path
375
- circuit_out = self._circuit_forward(x)
376
-
377
- # Route between paths
378
- route_weights = self.router(x) # [..., 2]
379
- mlp_weight = route_weights[..., 0:1]
380
- circuit_weight = route_weights[..., 1:2]
381
-
382
- # Combine: MLP output + weighted circuit contribution
383
- output = mlp_out + circuit_weight * circuit_out
384
-
385
- return output
386
-
387
-
388
- # =============================================================================
389
- # MODEL SURGERY: Insert circuits into SmolLM2
390
- # =============================================================================
391
-
392
- def augment_smollm2_with_circuits(
393
- model: AutoModelForCausalLM,
394
- circuit_path: str,
395
- layer_indices: list = None,
396
- device: str = 'cpu'
397
- ) -> AutoModelForCausalLM:
398
- """
399
- Surgically insert circuit blocks into SmolLM2's MLP layers.
400
-
401
- Args:
402
- model: Pretrained SmolLM2 model
403
- circuit_path: Path to neural_computer.safetensors
404
- layer_indices: Which layers to augment (default: middle layers)
405
- device: Device for circuit tensors
406
-
407
- Returns:
408
- Modified model with circuit-augmented MLPs
409
- """
410
- config = model.config
411
- num_layers = config.num_hidden_layers
412
-
413
- # Default: augment middle third of layers
414
- if layer_indices is None:
415
- start = num_layers // 3
416
- end = 2 * num_layers // 3
417
- layer_indices = list(range(start, end))
418
-
419
- print(f"Augmenting layers {layer_indices} with threshold circuits...")
420
-
421
- for idx in layer_indices:
422
- layer = model.model.layers[idx]
423
- old_mlp = layer.mlp
424
-
425
- # Create augmented MLP
426
- new_mlp = CircuitAugmentedMLP(
427
- d_model=config.hidden_size,
428
- intermediate_size=config.intermediate_size,
429
- circuit_path=circuit_path,
430
- device=device
431
- )
432
-
433
- # Copy pretrained weights
434
- new_mlp.gate_proj.weight.data = old_mlp.gate_proj.weight.data.clone()
435
- new_mlp.up_proj.weight.data = old_mlp.up_proj.weight.data.clone()
436
- new_mlp.down_proj.weight.data = old_mlp.down_proj.weight.data.clone()
437
-
438
- # Replace
439
- layer.mlp = new_mlp
440
-
441
- # Freeze circuit weights, keep interfaces trainable
442
- for name, param in model.named_parameters():
443
- if 'circuits' in name:
444
- param.requires_grad = False
445
-
446
- print(f"Done. Circuit weights frozen, interfaces trainable.")
447
-
448
- return model
449
-
450
-
451
- # =============================================================================
452
- # TRAINING UTILITIES
453
- # =============================================================================
454
-
455
- def generate_arithmetic_batch(batch_size: int, max_val: int = 255) -> Tuple[list, list]:
456
- """Generate batch of arithmetic problems and solutions."""
457
- prompts = []
458
- targets = []
459
-
460
- for _ in range(batch_size):
461
- a = torch.randint(0, max_val + 1, (1,)).item()
462
- b = torch.randint(0, max_val + 1, (1,)).item()
463
- result = (a + b) % 256
464
-
465
- prompts.append(f"{a} + {b} =")
466
- targets.append(f" {result}")
467
-
468
- return prompts, targets
469
-
470
-
471
- def evaluate_arithmetic(
472
- model: AutoModelForCausalLM,
473
- tokenizer: AutoTokenizer,
474
- n_problems: int = 100,
475
- device: str = 'cpu'
476
- ) -> dict:
477
- """Evaluate model on random arithmetic problems."""
478
- correct = 0
479
- total = 0
480
- errors = []
481
-
482
- model.eval()
483
-
484
- for _ in range(n_problems):
485
- a = torch.randint(0, 256, (1,)).item()
486
- b = torch.randint(0, 256, (1,)).item()
487
- expected = (a + b) % 256
488
-
489
- prompt = f"{a} + {b} ="
490
- inputs = tokenizer(prompt, return_tensors='pt').to(device)
491
-
492
- with torch.no_grad():
493
- outputs = model.generate(
494
- **inputs,
495
- max_new_tokens=10,
496
- do_sample=False,
497
- pad_token_id=tokenizer.eos_token_id
498
- )
499
-
500
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
501
-
502
- # Extract number from response
503
- try:
504
- # Find the part after "="
505
- answer_part = response.split('=')[-1].strip()
506
- # Extract first number
507
- predicted = int(''.join(c for c in answer_part.split()[0] if c.isdigit()))
508
-
509
- if predicted == expected:
510
- correct += 1
511
- else:
512
- errors.append((a, b, expected, predicted))
513
- except:
514
- errors.append((a, b, expected, "parse_error"))
515
-
516
- total += 1
517
-
518
- return {
519
- 'accuracy': correct / total,
520
- 'correct': correct,
521
- 'total': total,
522
- 'errors': errors[:10] # First 10 errors
523
- }
524
-
525
-
526
- # =============================================================================
527
- # MAIN: Demo
528
- # =============================================================================
529
-
530
- if __name__ == "__main__":
531
- import argparse
532
-
533
- parser = argparse.ArgumentParser(description='Circuit-Augmented LLM Demo')
534
- parser.add_argument('--circuit-path', type=str,
535
- default='./neural_computer.safetensors',
536
- help='Path to circuit weights')
537
- parser.add_argument('--device', type=str, default='cpu',
538
- help='Device (cpu or cuda)')
539
- parser.add_argument('--eval-only', action='store_true',
540
- help='Only evaluate, do not augment')
541
- args = parser.parse_args()
542
-
543
- print("=" * 70)
544
- print(" CIRCUIT-AUGMENTED LLM")
545
- print("=" * 70)
546
-
547
- # Load tokenizer and model
548
- print("\n[1] Loading SmolLM2-360M...")
549
- model_id = "HuggingFaceTB/SmolLM2-360M"
550
- tokenizer = AutoTokenizer.from_pretrained(model_id)
551
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
552
-
553
- print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
554
-
555
- # Baseline evaluation
556
- print("\n[2] Baseline arithmetic evaluation...")
557
- baseline = evaluate_arithmetic(model, tokenizer, n_problems=50, device=args.device)
558
- print(f" Accuracy: {baseline['accuracy']*100:.1f}% ({baseline['correct']}/{baseline['total']})")
559
- if baseline['errors']:
560
- print(f" Sample errors:")
561
- for a, b, exp, got in baseline['errors'][:5]:
562
- print(f" {a} + {b} = {exp}, model said {got}")
563
-
564
- if args.eval_only:
565
- print("\nDone (eval only mode).")
566
- exit(0)
567
-
568
- # Augment with circuits
569
- print(f"\n[3] Augmenting with threshold circuits...")
570
- print(f" Circuit path: {args.circuit_path}")
571
- model = augment_smollm2_with_circuits(
572
- model,
573
- args.circuit_path,
574
- device=args.device
575
- )
576
-
577
- new_params = sum(p.numel() for p in model.parameters())
578
- trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
579
- print(f" Total parameters: {new_params:,}")
580
- print(f" Trainable parameters: {trainable:,}")
581
-
582
- # Test circuit execution directly
583
- print("\n[4] Testing circuit execution...")
584
- circuit_exec = CircuitExecutor(args.circuit_path, args.device)
585
-
586
- test_cases = [(127, 128), (255, 1), (0, 0), (100, 55)]
587
- for a, b in test_cases:
588
- # Convert to bits (LSB first)
589
- a_bits = torch.tensor([(a >> i) & 1 for i in range(8)], dtype=torch.float32)
590
- b_bits = torch.tensor([(b >> i) & 1 for i in range(8)], dtype=torch.float32)
591
-
592
- result_bits, carry = circuit_exec.add_8bit(
593
- a_bits.unsqueeze(0),
594
- b_bits.unsqueeze(0)
595
- )
596
-
597
- # Convert result bits back to int
598
- result = sum(int(result_bits[0, i].item()) * (2**i) for i in range(8))
599
- expected = (a + b) % 256
600
-
601
- status = "OK" if result == expected else "FAIL"
602
- print(f" {a} + {b} = {result} (expected {expected}) [{status}]")
603
-
604
- print("\n[5] Model ready for fine-tuning.")
605
- print(" Next: Train interface layers on arithmetic examples.")
606
- print("=" * 70)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Circuit-Augmented LLM: Embedding Threshold Logic Circuits into Transformers
3
+ ============================================================================
4
+
5
+ Embeds frozen, proven-correct arithmetic circuits into transformer MLP layers.
6
+ The model learns call dispatch (when to use circuits), not arithmetic.
7
+
8
+ ARCHITECTURE
9
+ ------------
10
+ Standard LLM MLPs are augmented with a parallel circuit path:
11
+
12
+ x ──┬── MLP path ────────────────┬── + ── output
13
+ │ │
14
+ └── BitExtractor ── Circuit ─┴── BitInjector
15
+
16
+ Router (learned weighting)
17
+
18
+ THRESHOLD LOGIC
19
+ ---------------
20
+ Each gate: output = 1 if (Σ wᵢxᵢ + b) ≥ 0 else 0
21
+
22
+ Examples:
23
+ AND: w=[1,1], b=-2 → fires only when both inputs are 1
24
+ OR: w=[1,1], b=-1 → fires when either input is 1
25
+ XOR: 2-layer network (not linearly separable)
26
+
27
+ Full adder = 2 half-adders + carry OR, ~4 threshold layers.
28
+ 8-bit ripple carry = 8 chained full adders, ~32 threshold layers.
29
+
30
+ TRAINING
31
+ --------
32
+ Only interface layers train (~1.37M params):
33
+ - BitExtractor: embedding → operand bits
34
+ - BitInjector: result bits → embedding delta
35
+ - Router: when to use circuits vs MLP
36
+
37
+ Circuits are frozen (proven correct via 6,590 exhaustive tests).
38
+ Uses Straight-Through Estimator for Heaviside gradient flow.
39
+
40
+ TARGET: SmolLM2-360M
41
+ - 960 hidden dim, 32 layers, 361M params
42
+ - Augment middle third (layers 10-20)
43
+ - Baseline arithmetic: ~5-10%
44
+ - Target: >95% (circuit-accurate)
45
+
46
+ USAGE
47
+ -----
48
+ # Augment model
49
+ model = augment_smollm2_with_circuits(model, "neural_computer.safetensors")
50
+
51
+ # Train interface
52
+ model = train_interface(model, tokenizer, n_epochs=3)
53
+
54
+ # Evaluate
55
+ results = evaluate_arithmetic(model, tokenizer, n_problems=100)
56
+
57
+ REFERENCES
58
+ ----------
59
+ 1. McCulloch & Pitts (1943). Logical Calculus of Ideas in Nervous Activity
60
+ 2. Muroga (1971). Threshold Logic and Its Applications
61
+ 3. Bengio et al. (2013). Estimating Gradients Through Stochastic Neurons (STE)
62
+ 4. Ma et al. (2024). The Era of 1-bit LLMs (BitNet)
63
+ """
64
+
65
+ from __future__ import annotations
66
+
67
+ import argparse
68
+ import warnings
69
+ from typing import Dict, List, Optional, Tuple
70
+
71
+ import torch
72
+ import torch.nn as nn
73
+ import torch.nn.functional as F
74
+ from safetensors.torch import load_file
75
+ from torch.utils.data import DataLoader, Dataset
76
+ from tqdm import tqdm
77
+ from transformers import AutoModelForCausalLM, AutoTokenizer
78
+
79
+ warnings.filterwarnings("ignore")
80
+
81
+
82
+ class HeavisideSTE(torch.autograd.Function):
83
+ """Heaviside step function with straight-through estimator for backprop."""
84
+
85
+ @staticmethod
86
+ def forward(ctx, x):
87
+ return (x >= 0).float()
88
+
89
+ @staticmethod
90
+ def backward(ctx, grad_output):
91
+ return grad_output
92
+
93
+
94
+ def heaviside(x: torch.Tensor) -> torch.Tensor:
95
+ """Heaviside step: 1 if x >= 0, else 0. Uses STE for training."""
96
+ return HeavisideSTE.apply(x)
97
+
98
+
99
+ class CircuitExecutor(nn.Module):
100
+ """
101
+ Executes threshold logic circuits from safetensors.
102
+ All circuit weights are frozen.
103
+ """
104
+
105
+ def __init__(self, circuit_path: str, device: str = "cpu"):
106
+ super().__init__()
107
+ self.device = device
108
+
109
+ raw_circuits = load_file(circuit_path)
110
+
111
+ self.circuits = {}
112
+ for k, v in raw_circuits.items():
113
+ safe_name = k.replace(".", "__")
114
+ self.register_buffer(safe_name, v.float().to(device))
115
+ self.circuits[k] = safe_name
116
+
117
+ def _get(self, name: str) -> torch.Tensor:
118
+ return getattr(self, self.circuits[name])
119
+
120
+ def eval_and(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
121
+ inp = torch.stack([a, b], dim=-1)
122
+ w = self._get("boolean.and.weight")
123
+ bias = self._get("boolean.and.bias")
124
+ return heaviside(inp @ w + bias)
125
+
126
+ def eval_or(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
127
+ inp = torch.stack([a, b], dim=-1)
128
+ w = self._get("boolean.or.weight")
129
+ bias = self._get("boolean.or.bias")
130
+ return heaviside(inp @ w + bias)
131
+
132
+ def eval_xor(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
133
+ inp = torch.stack([a, b], dim=-1)
134
+
135
+ w1_n1 = self._get("boolean.xor.layer1.neuron1.weight")
136
+ b1_n1 = self._get("boolean.xor.layer1.neuron1.bias")
137
+ w1_n2 = self._get("boolean.xor.layer1.neuron2.weight")
138
+ b1_n2 = self._get("boolean.xor.layer1.neuron2.bias")
139
+
140
+ h1 = heaviside(inp @ w1_n1 + b1_n1)
141
+ h2 = heaviside(inp @ w1_n2 + b1_n2)
142
+ hidden = torch.stack([h1, h2], dim=-1)
143
+
144
+ w2 = self._get("boolean.xor.layer2.weight")
145
+ b2 = self._get("boolean.xor.layer2.bias")
146
+
147
+ return heaviside(hidden @ w2 + b2)
148
+
149
+ def eval_full_adder(
150
+ self, a: torch.Tensor, b: torch.Tensor, cin: torch.Tensor, prefix: str
151
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
152
+ inp_ab = torch.stack([a, b], dim=-1)
153
+
154
+ w1_or = self._get(f"{prefix}.ha1.sum.layer1.or.weight")
155
+ b1_or = self._get(f"{prefix}.ha1.sum.layer1.or.bias")
156
+ w1_nand = self._get(f"{prefix}.ha1.sum.layer1.nand.weight")
157
+ b1_nand = self._get(f"{prefix}.ha1.sum.layer1.nand.bias")
158
+ w2 = self._get(f"{prefix}.ha1.sum.layer2.weight")
159
+ b2 = self._get(f"{prefix}.ha1.sum.layer2.bias")
160
+
161
+ h_or = heaviside(inp_ab @ w1_or + b1_or)
162
+ h_nand = heaviside(inp_ab @ w1_nand + b1_nand)
163
+ hidden = torch.stack([h_or, h_nand], dim=-1)
164
+ ha1_sum = heaviside(hidden @ w2 + b2)
165
+
166
+ w_c1 = self._get(f"{prefix}.ha1.carry.weight")
167
+ b_c1 = self._get(f"{prefix}.ha1.carry.bias")
168
+ ha1_carry = heaviside(inp_ab @ w_c1 + b_c1)
169
+
170
+ inp_ha2 = torch.stack([ha1_sum, cin], dim=-1)
171
+ w1_or = self._get(f"{prefix}.ha2.sum.layer1.or.weight")
172
+ b1_or = self._get(f"{prefix}.ha2.sum.layer1.or.bias")
173
+ w1_nand = self._get(f"{prefix}.ha2.sum.layer1.nand.weight")
174
+ b1_nand = self._get(f"{prefix}.ha2.sum.layer1.nand.bias")
175
+ w2 = self._get(f"{prefix}.ha2.sum.layer2.weight")
176
+ b2 = self._get(f"{prefix}.ha2.sum.layer2.bias")
177
+
178
+ h_or = heaviside(inp_ha2 @ w1_or + b1_or)
179
+ h_nand = heaviside(inp_ha2 @ w1_nand + b1_nand)
180
+ hidden = torch.stack([h_or, h_nand], dim=-1)
181
+ ha2_sum = heaviside(hidden @ w2 + b2)
182
+
183
+ w_c2 = self._get(f"{prefix}.ha2.carry.weight")
184
+ b_c2 = self._get(f"{prefix}.ha2.carry.bias")
185
+ ha2_carry = heaviside(inp_ha2 @ w_c2 + b_c2)
186
+
187
+ inp_cout = torch.stack([ha1_carry, ha2_carry], dim=-1)
188
+ w_or = self._get(f"{prefix}.carry_or.weight")
189
+ b_or = self._get(f"{prefix}.carry_or.bias")
190
+ cout = heaviside(inp_cout @ w_or + b_or)
191
+
192
+ return ha2_sum, cout
193
+
194
+ def add_8bit(
195
+ self, a_bits: torch.Tensor, b_bits: torch.Tensor
196
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
197
+ """
198
+ 8-bit ripple carry addition.
199
+ a_bits, b_bits: [..., 8] tensors (LSB first)
200
+ Returns: (result_bits [..., 8], carry_out [...])
201
+ """
202
+ batch_shape = a_bits.shape[:-1]
203
+ carry = torch.zeros(batch_shape, device=a_bits.device)
204
+ result_bits = []
205
+
206
+ for i in range(8):
207
+ a_i = a_bits[..., i]
208
+ b_i = b_bits[..., i]
209
+ sum_bit, carry = self.eval_full_adder(
210
+ a_i, b_i, carry, f"arithmetic.ripplecarry8bit.fa{i}"
211
+ )
212
+ result_bits.append(sum_bit)
213
+
214
+ return torch.stack(result_bits, dim=-1), carry
215
+
216
+ def greater_than_8bit(
217
+ self, a_bits: torch.Tensor, b_bits: torch.Tensor
218
+ ) -> torch.Tensor:
219
+ diff = a_bits - b_bits
220
+ w = self._get("arithmetic.greaterthan8bit.comparator")
221
+ score = (diff * w).sum(dim=-1)
222
+ return (score > 0).float()
223
+
224
+ def less_than_8bit(
225
+ self, a_bits: torch.Tensor, b_bits: torch.Tensor
226
+ ) -> torch.Tensor:
227
+ diff = b_bits - a_bits
228
+ w = self._get("arithmetic.lessthan8bit.comparator")
229
+ score = (diff * w).sum(dim=-1)
230
+ return (score > 0).float()
231
+
232
+ def equal_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
233
+ gt = self.greater_than_8bit(a_bits, b_bits)
234
+ lt = self.less_than_8bit(a_bits, b_bits)
235
+ return (1 - gt) * (1 - lt)
236
+
237
+
238
+ class BitExtractor(nn.Module):
239
+ """Maps embedding -> two 8-bit operands."""
240
+
241
+ def __init__(self, d_model: int):
242
+ super().__init__()
243
+ self.d_model = d_model
244
+ self.proj = nn.Linear(d_model, 16)
245
+ self.temperature = nn.Parameter(torch.tensor(1.0))
246
+
247
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
248
+ logits = self.proj(x)
249
+ bits = heaviside(logits)
250
+ a_bits = bits[..., :8]
251
+ b_bits = bits[..., 8:]
252
+ return a_bits, b_bits
253
+
254
+
255
+ class BitInjector(nn.Module):
256
+ """Maps result bits -> embedding delta."""
257
+
258
+ def __init__(self, d_model: int):
259
+ super().__init__()
260
+ self.d_model = d_model
261
+ self.proj = nn.Linear(16, d_model)
262
+ self.scale = nn.Parameter(torch.tensor(0.1))
263
+
264
+ def forward(self, result_bits: torch.Tensor, flags: torch.Tensor) -> torch.Tensor:
265
+ combined = torch.cat([result_bits, flags], dim=-1)
266
+ return self.proj(combined) * self.scale
267
+
268
+
269
+ class CircuitAugmentedMLP(nn.Module):
270
+ """
271
+ MLP block augmented with frozen threshold circuits.
272
+ Original MLP runs in parallel with circuit path; router decides weighting.
273
+ """
274
+
275
+ def __init__(
276
+ self,
277
+ d_model: int,
278
+ intermediate_size: int,
279
+ circuit_path: str,
280
+ device: str = "cpu",
281
+ ):
282
+ super().__init__()
283
+ self.d_model = d_model
284
+
285
+ self.gate_proj = nn.Linear(d_model, intermediate_size, bias=False)
286
+ self.up_proj = nn.Linear(d_model, intermediate_size, bias=False)
287
+ self.down_proj = nn.Linear(intermediate_size, d_model, bias=False)
288
+ self.act_fn = nn.SiLU()
289
+
290
+ self.circuits = CircuitExecutor(circuit_path, device)
291
+ self.bit_extractor = BitExtractor(d_model)
292
+ self.bit_injector = BitInjector(d_model)
293
+
294
+ self.router = nn.Sequential(
295
+ nn.Linear(d_model, 64),
296
+ nn.ReLU(),
297
+ nn.Linear(64, 2),
298
+ nn.Softmax(dim=-1),
299
+ )
300
+
301
+ self.op_selector = nn.Sequential(
302
+ nn.Linear(d_model, 32),
303
+ nn.ReLU(),
304
+ nn.Linear(32, 4),
305
+ nn.Softmax(dim=-1),
306
+ )
307
+
308
+ def _compute_flags(
309
+ self, result_bits: torch.Tensor, carry: torch.Tensor
310
+ ) -> torch.Tensor:
311
+ batch_shape = result_bits.shape[:-1]
312
+
313
+ zero = (result_bits.sum(dim=-1) == 0).float()
314
+ negative = result_bits[..., 7]
315
+ carry_flag = carry
316
+
317
+ flags = torch.zeros(*batch_shape, 8, device=result_bits.device)
318
+ flags[..., 0] = zero
319
+ flags[..., 1] = negative
320
+ flags[..., 2] = carry_flag
321
+
322
+ return flags
323
+
324
+ def _circuit_forward(self, x: torch.Tensor) -> torch.Tensor:
325
+ a_bits, b_bits = self.bit_extractor(x)
326
+ add_result, add_carry = self.circuits.add_8bit(a_bits, b_bits)
327
+ add_flags = self._compute_flags(add_result, add_carry)
328
+ circuit_delta = self.bit_injector(add_result, add_flags)
329
+ return circuit_delta
330
+
331
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
332
+ mlp_out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
333
+
334
+ circuit_out = self._circuit_forward(x)
335
+
336
+ route_weights = self.router(x)
337
+ circuit_weight = route_weights[..., 1:2]
338
+
339
+ output = mlp_out + circuit_weight * circuit_out
340
+
341
+ return output
342
+
343
+
344
+ def augment_smollm2_with_circuits(
345
+ model: AutoModelForCausalLM,
346
+ circuit_path: str,
347
+ layer_indices: list = None,
348
+ device: str = "cpu",
349
+ ) -> AutoModelForCausalLM:
350
+ """
351
+ Insert circuit blocks into SmolLM2's MLP layers.
352
+
353
+ Args:
354
+ model: Pretrained SmolLM2
355
+ circuit_path: Path to neural_computer.safetensors
356
+ layer_indices: Which layers to augment (default: middle third)
357
+ device: Device for circuit tensors
358
+
359
+ Returns:
360
+ Model with circuit-augmented MLPs
361
+ """
362
+ config = model.config
363
+ num_layers = config.num_hidden_layers
364
+
365
+ if layer_indices is None:
366
+ start = num_layers // 3
367
+ end = 2 * num_layers // 3
368
+ layer_indices = list(range(start, end))
369
+
370
+ print(f"Augmenting layers {layer_indices} with threshold circuits...")
371
+
372
+ for idx in layer_indices:
373
+ layer = model.model.layers[idx]
374
+ old_mlp = layer.mlp
375
+
376
+ new_mlp = CircuitAugmentedMLP(
377
+ d_model=config.hidden_size,
378
+ intermediate_size=config.intermediate_size,
379
+ circuit_path=circuit_path,
380
+ device=device,
381
+ )
382
+
383
+ new_mlp.gate_proj.weight.data = old_mlp.gate_proj.weight.data.clone()
384
+ new_mlp.up_proj.weight.data = old_mlp.up_proj.weight.data.clone()
385
+ new_mlp.down_proj.weight.data = old_mlp.down_proj.weight.data.clone()
386
+
387
+ layer.mlp = new_mlp
388
+
389
+ for name, param in model.named_parameters():
390
+ if "circuits" in name:
391
+ param.requires_grad = False
392
+
393
+ print("Done. Circuit weights frozen, interfaces trainable.")
394
+
395
+ return model
396
+
397
+
398
+ def generate_arithmetic_batch(
399
+ batch_size: int, max_val: int = 255
400
+ ) -> Tuple[list, list]:
401
+ """Generate batch of arithmetic problems and solutions."""
402
+ prompts = []
403
+ targets = []
404
+
405
+ for _ in range(batch_size):
406
+ a = torch.randint(0, max_val + 1, (1,)).item()
407
+ b = torch.randint(0, max_val + 1, (1,)).item()
408
+ result = (a + b) % 256
409
+
410
+ prompts.append(f"{a} + {b} =")
411
+ targets.append(f" {result}")
412
+
413
+ return prompts, targets
414
+
415
+
416
+ def evaluate_arithmetic(
417
+ model: AutoModelForCausalLM,
418
+ tokenizer: AutoTokenizer,
419
+ n_problems: int = 100,
420
+ device: str = "cpu",
421
+ ) -> dict:
422
+ """Evaluate model on random arithmetic problems."""
423
+ correct = 0
424
+ total = 0
425
+ errors = []
426
+
427
+ model.eval()
428
+
429
+ for _ in range(n_problems):
430
+ a = torch.randint(0, 256, (1,)).item()
431
+ b = torch.randint(0, 256, (1,)).item()
432
+ expected = (a + b) % 256
433
+
434
+ prompt = f"{a} + {b} ="
435
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
436
+
437
+ with torch.no_grad():
438
+ outputs = model.generate(
439
+ **inputs,
440
+ max_new_tokens=10,
441
+ do_sample=False,
442
+ pad_token_id=tokenizer.eos_token_id,
443
+ )
444
+
445
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
446
+
447
+ try:
448
+ answer_part = response.split("=")[-1].strip()
449
+ predicted = int("".join(c for c in answer_part.split()[0] if c.isdigit()))
450
+
451
+ if predicted == expected:
452
+ correct += 1
453
+ else:
454
+ errors.append((a, b, expected, predicted))
455
+ except:
456
+ errors.append((a, b, expected, "parse_error"))
457
+
458
+ total += 1
459
+
460
+ return {
461
+ "accuracy": correct / total,
462
+ "correct": correct,
463
+ "total": total,
464
+ "errors": errors[:10],
465
+ }
466
+
467
+
468
+ class ArithmeticDataset(Dataset):
469
+ """Dataset of 8-bit addition problems."""
470
+
471
+ def __init__(self, tokenizer, n_samples: int = 10000, max_val: int = 255):
472
+ self.tokenizer = tokenizer
473
+ self.n_samples = n_samples
474
+ self.max_val = max_val
475
+
476
+ self.examples = []
477
+ for _ in range(n_samples):
478
+ a = torch.randint(0, max_val + 1, (1,)).item()
479
+ b = torch.randint(0, max_val + 1, (1,)).item()
480
+ result = (a + b) % 256
481
+
482
+ prompt = f"{a} + {b} ="
483
+ target = f" {result}"
484
+
485
+ self.examples.append((prompt, target, a, b, result))
486
+
487
+ def __len__(self):
488
+ return len(self.examples)
489
+
490
+ def __getitem__(self, idx):
491
+ prompt, target, a, b, result = self.examples[idx]
492
+
493
+ prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
494
+ target_ids = self.tokenizer.encode(target, add_special_tokens=False)
495
+
496
+ input_ids = prompt_ids + target_ids
497
+ labels = [-100] * len(prompt_ids) + target_ids
498
+
499
+ return {
500
+ "input_ids": torch.tensor(input_ids),
501
+ "labels": torch.tensor(labels),
502
+ "a": a,
503
+ "b": b,
504
+ "result": result,
505
+ }
506
+
507
+
508
+ def collate_fn(batch):
509
+ """Collate with padding."""
510
+ max_len = max(len(item["input_ids"]) for item in batch)
511
+
512
+ input_ids = []
513
+ labels = []
514
+ attention_mask = []
515
+
516
+ for item in batch:
517
+ pad_len = max_len - len(item["input_ids"])
518
+
519
+ input_ids.append(
520
+ torch.cat([item["input_ids"], torch.zeros(pad_len, dtype=torch.long)])
521
+ )
522
+ labels.append(
523
+ torch.cat(
524
+ [item["labels"], torch.full((pad_len,), -100, dtype=torch.long)]
525
+ )
526
+ )
527
+ attention_mask.append(
528
+ torch.cat([torch.ones(len(item["input_ids"])), torch.zeros(pad_len)])
529
+ )
530
+
531
+ return {
532
+ "input_ids": torch.stack(input_ids),
533
+ "labels": torch.stack(labels),
534
+ "attention_mask": torch.stack(attention_mask),
535
+ }
536
+
537
+
538
+ def train_interface(
539
+ model: AutoModelForCausalLM,
540
+ tokenizer: AutoTokenizer,
541
+ n_epochs: int = 3,
542
+ batch_size: int = 16,
543
+ lr: float = 1e-4,
544
+ n_train_samples: int = 10000,
545
+ device: str = "cpu",
546
+ eval_every: int = 500,
547
+ ):
548
+ """
549
+ Train the circuit interface layers.
550
+
551
+ Only trains:
552
+ - bit_extractor (embedding -> bits)
553
+ - bit_injector (bits -> embedding)
554
+ - router (circuit vs MLP weighting)
555
+ - op_selector (which operation)
556
+ """
557
+ print("\n" + "=" * 70)
558
+ print(" TRAINING CIRCUIT INTERFACE")
559
+ print("=" * 70)
560
+
561
+ interface_params = []
562
+ frozen_count = 0
563
+ trainable_count = 0
564
+
565
+ for name, param in model.named_parameters():
566
+ if any(
567
+ x in name for x in ["bit_extractor", "bit_injector", "router", "op_selector"]
568
+ ):
569
+ param.requires_grad = True
570
+ interface_params.append(param)
571
+ trainable_count += param.numel()
572
+ else:
573
+ param.requires_grad = False
574
+ frozen_count += param.numel()
575
+
576
+ print(f"\n Frozen parameters: {frozen_count:,}")
577
+ print(f" Trainable parameters: {trainable_count:,}")
578
+ print(f" Training {len(interface_params)} parameter groups")
579
+
580
+ print(f"\n Creating dataset ({n_train_samples} examples)...")
581
+ dataset = ArithmeticDataset(tokenizer, n_samples=n_train_samples)
582
+ dataloader = DataLoader(
583
+ dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
584
+ )
585
+
586
+ optimizer = torch.optim.AdamW(interface_params, lr=lr)
587
+
588
+ model.to(device)
589
+ model.train()
590
+
591
+ global_step = 0
592
+ total_loss = 0
593
+
594
+ for epoch in range(n_epochs):
595
+ print(f"\n Epoch {epoch + 1}/{n_epochs}")
596
+ print(" " + "-" * 60)
597
+
598
+ epoch_loss = 0
599
+ epoch_steps = 0
600
+
601
+ pbar = tqdm(dataloader, desc=" Training", leave=False)
602
+
603
+ for batch in pbar:
604
+ input_ids = batch["input_ids"].to(device)
605
+ labels = batch["labels"].to(device)
606
+ attention_mask = batch["attention_mask"].to(device)
607
+
608
+ outputs = model(
609
+ input_ids=input_ids, attention_mask=attention_mask, labels=labels
610
+ )
611
+
612
+ loss = outputs.loss
613
+
614
+ optimizer.zero_grad()
615
+ loss.backward()
616
+ optimizer.step()
617
+
618
+ epoch_loss += loss.item()
619
+ epoch_steps += 1
620
+ global_step += 1
621
+ total_loss += loss.item()
622
+
623
+ pbar.set_postfix({"loss": f"{loss.item():.4f}"})
624
+
625
+ if global_step % eval_every == 0:
626
+ model.eval()
627
+ eval_results = evaluate_arithmetic(
628
+ model, tokenizer, n_problems=50, device=device
629
+ )
630
+ print(
631
+ f"\n Step {global_step}: Loss={total_loss/eval_every:.4f}, "
632
+ f"Accuracy={eval_results['accuracy']*100:.1f}%"
633
+ )
634
+ total_loss = 0
635
+ model.train()
636
+
637
+ avg_loss = epoch_loss / epoch_steps
638
+ print(f"\n Epoch {epoch + 1} complete. Avg loss: {avg_loss:.4f}")
639
+
640
+ model.eval()
641
+ eval_results = evaluate_arithmetic(
642
+ model, tokenizer, n_problems=100, device=device
643
+ )
644
+ print(
645
+ f" Evaluation: {eval_results['accuracy']*100:.1f}% "
646
+ f"({eval_results['correct']}/{eval_results['total']})"
647
+ )
648
+
649
+ if eval_results["errors"]:
650
+ print(" Sample errors:")
651
+ for a, b, exp, got in eval_results["errors"][:3]:
652
+ print(f" {a} + {b} = {exp}, model said {got}")
653
+
654
+ model.train()
655
+
656
+ print("\n" + "=" * 70)
657
+ print(" TRAINING COMPLETE")
658
+ print("=" * 70)
659
+
660
+ return model
661
+
662
+
663
+ if __name__ == "__main__":
664
+ parser = argparse.ArgumentParser(description="Circuit-Augmented LLM")
665
+ parser.add_argument(
666
+ "--circuit-path",
667
+ type=str,
668
+ default="./neural_computer.safetensors",
669
+ help="Path to circuit weights",
670
+ )
671
+ parser.add_argument("--device", type=str, default="cpu", help="Device")
672
+ parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
673
+ parser.add_argument("--batch-size", type=int, default=8, help="Batch size")
674
+ parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
675
+ parser.add_argument(
676
+ "--n-samples", type=int, default=5000, help="Training samples"
677
+ )
678
+ parser.add_argument(
679
+ "--eval-only", action="store_true", help="Only evaluate baseline"
680
+ )
681
+ args = parser.parse_args()
682
+
683
+ print("=" * 70)
684
+ print(" CIRCUIT-AUGMENTED LLM")
685
+ print("=" * 70)
686
+
687
+ print("\n[1] Loading SmolLM2-360M...")
688
+ model_id = "HuggingFaceTB/SmolLM2-360M"
689
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
690
+ tokenizer.pad_token = tokenizer.eos_token
691
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
692
+
693
+ print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
694
+
695
+ print("\n[2] Baseline arithmetic evaluation...")
696
+ baseline = evaluate_arithmetic(model, tokenizer, n_problems=50, device=args.device)
697
+ print(
698
+ f" Accuracy: {baseline['accuracy']*100:.1f}% "
699
+ f"({baseline['correct']}/{baseline['total']})"
700
+ )
701
+ if baseline["errors"]:
702
+ print(" Sample errors:")
703
+ for a, b, exp, got in baseline["errors"][:5]:
704
+ print(f" {a} + {b} = {exp}, model said {got}")
705
+
706
+ if args.eval_only:
707
+ print("\nDone (eval only mode).")
708
+ exit(0)
709
+
710
+ print(f"\n[3] Augmenting with threshold circuits...")
711
+ print(f" Circuit path: {args.circuit_path}")
712
+ model = augment_smollm2_with_circuits(model, args.circuit_path, device=args.device)
713
+
714
+ new_params = sum(p.numel() for p in model.parameters())
715
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
716
+ print(f" Total parameters: {new_params:,}")
717
+ print(f" Trainable parameters: {trainable:,}")
718
+
719
+ print("\n[4] Testing circuit execution...")
720
+ circuit_exec = CircuitExecutor(args.circuit_path, args.device)
721
+
722
+ test_cases = [(127, 128), (255, 1), (0, 0), (100, 55)]
723
+ for a, b in test_cases:
724
+ a_bits = torch.tensor([(a >> i) & 1 for i in range(8)], dtype=torch.float32)
725
+ b_bits = torch.tensor([(b >> i) & 1 for i in range(8)], dtype=torch.float32)
726
+
727
+ result_bits, carry = circuit_exec.add_8bit(
728
+ a_bits.unsqueeze(0), b_bits.unsqueeze(0)
729
+ )
730
+
731
+ result = sum(int(result_bits[0, i].item()) * (2**i) for i in range(8))
732
+ expected = (a + b) % 256
733
+
734
+ status = "OK" if result == expected else "FAIL"
735
+ print(f" {a} + {b} = {result} (expected {expected}) [{status}]")
736
+
737
+ print("\n[5] Training interface layers...")
738
+ model = train_interface(
739
+ model,
740
+ tokenizer,
741
+ n_epochs=args.epochs,
742
+ batch_size=args.batch_size,
743
+ lr=args.lr,
744
+ n_train_samples=args.n_samples,
745
+ device=args.device,
746
+ )
747
+
748
+ print("\n[6] Final evaluation...")
749
+ final = evaluate_arithmetic(model, tokenizer, n_problems=100, device=args.device)
750
+ print(f" Final accuracy: {final['accuracy']*100:.1f}%")
751
+ print(
752
+ f" Improvement: {baseline['accuracy']*100:.1f}% -> {final['accuracy']*100:.1f}%"
753
+ )
754
+
755
+ save_path = "./circuit_augmented_smollm2.pt"
756
+ print(f"\n[7] Saving to {save_path}...")
757
+ torch.save(
758
+ {
759
+ "model_state_dict": model.state_dict(),
760
+ "baseline_accuracy": baseline["accuracy"],
761
+ "final_accuracy": final["accuracy"],
762
+ },
763
+ save_path,
764
+ )
765
+
766
+ print("\nDone!")
llm/guide.md DELETED
@@ -1,615 +0,0 @@
1
- # Embedding Threshold Logic Circuits into Transformer MLPs
2
-
3
- ## Technical Implementation Guide
4
-
5
- ---
6
-
7
- ## 1. Core Thesis
8
-
9
- Standard LLMs fail at arithmetic because they're interpolators—they approximate functions over training distributions rather than compute exact results. A 360M parameter model trained on internet text has seen "127 + 128 = 255" zero or few times, so it guesses "140" based on pattern matching.
10
-
11
- We solve this by embedding **frozen, proven-correct arithmetic circuits** directly into the transformer's MLP layers. The circuits use threshold logic (weighted sums + step activation), which is structurally compatible with neural network layers. We train only the **interface layers** that learn to:
12
-
13
- 1. Extract operands from token embeddings
14
- 2. Route computation through the circuits
15
- 3. Inject results back into the residual stream
16
-
17
- The model learns **call dispatch**, not arithmetic. The arithmetic is already solved.
18
-
19
- ---
20
-
21
- ## 2. Threshold Logic Fundamentals
22
-
23
- ### 2.1 Single Threshold Gate
24
-
25
- A threshold gate computes:
26
-
27
- ```
28
- output = 1 if (Σ wᵢxᵢ + b) ≥ 0
29
- 0 otherwise
30
- ```
31
-
32
- This is a neuron with Heaviside step activation. With integer weights `w` and bias `b`, it computes a Boolean function of binary inputs.
33
-
34
- **Example: AND gate**
35
- ```
36
- w = [1, 1], b = -2
37
- AND(0,0) = H(0 + 0 - 2) = H(-2) = 0
38
- AND(0,1) = H(0 + 1 - 2) = H(-1) = 0
39
- AND(1,0) = H(1 + 0 - 2) = H(-1) = 0
40
- AND(1,1) = H(1 + 1 - 2) = H(0) = 1
41
- ```
42
-
43
- **Example: OR gate**
44
- ```
45
- w = [1, 1], b = -1
46
- OR(0,0) = H(0 + 0 - 1) = H(-1) = 0
47
- OR(0,1) = H(0 + 1 - 1) = H(0) = 1
48
- OR(1,0) = H(1 + 0 - 1) = H(0) = 1
49
- OR(1,1) = H(1 + 1 - 1) = H(1) = 1
50
- ```
51
-
52
- ### 2.2 Multi-Layer Circuits
53
-
54
- XOR is not linearly separable—it requires two layers:
55
-
56
- ```
57
- Layer 1:
58
- neuron1 (OR): w=[1,1], b=-1 → fires if a OR b
59
- neuron2 (NAND): w=[-1,-1], b=1 → fires if NOT(a AND b)
60
-
61
- Layer 2:
62
- neuron3 (AND): w=[1,1], b=-2 → fires if both layer1 outputs are 1
63
-
64
- XOR(a,b) = AND(OR(a,b), NAND(a,b))
65
- ```
66
-
67
- ### 2.3 Full Adder
68
-
69
- A full adder computes `sum` and `carry_out` from inputs `a`, `b`, `carry_in`:
70
-
71
- ```
72
- sum = a XOR b XOR cin
73
- cout = (a AND b) OR (cin AND (a XOR b))
74
- ```
75
-
76
- Implementation uses two half-adders chained:
77
-
78
- ```
79
- HA1: (a, b) → (sum1 = a XOR b, carry1 = a AND b)
80
- HA2: (sum1, cin) → (sum2 = sum1 XOR cin, carry2 = sum1 AND cin)
81
- cout = carry1 OR carry2
82
- final_sum = sum2
83
- ```
84
-
85
- Each XOR is 2 layers, each AND/OR is 1 layer. Total depth: ~4 layers per full adder.
86
-
87
- ### 2.4 8-bit Ripple Carry Adder
88
-
89
- Chain 8 full adders, propagating carry:
90
-
91
- ```
92
- FA0: (a[0], b[0], 0) → (sum[0], c0)
93
- FA1: (a[1], b[1], c0) → (sum[1], c1)
94
- FA2: (a[2], b[2], c1) → (sum[2], c2)
95
- ...
96
- FA7: (a[7], b[7], c6) → (sum[7], c7)
97
- ```
98
-
99
- Total circuit depth: ~32 threshold layers (8 FAs × 4 layers each).
100
-
101
- ---
102
-
103
- ## 3. Circuit Inventory
104
-
105
- The `neural_computer.safetensors` contains 24,200 tensors / 40,323 parameters implementing:
106
-
107
- | Category | Circuits | Tensors |
108
- |----------|----------|---------|
109
- | Boolean | AND, OR, NOT, NAND, NOR, XOR, XNOR, IMPLIES, BIIMPLIES | ~30 |
110
- | Arithmetic | Half adder, Full adder, Ripple carry 2/4/8-bit, 8×8 multiplier | ~800 |
111
- | Comparators | GT, LT, GEQ, LEQ, EQ (8-bit) | ~50 |
112
- | ALU | 16-operation ALU, opcode decoder, flag computation | ~400 |
113
- | Control | JMP, JZ, JNZ, JC, JNC, JN, JP, CALL, RET, PUSH, POP | ~200 |
114
- | Modular | Divisibility by 2-12 | ~600 |
115
- | Error Detection | Parity, CRC, Hamming, checksum | ~200 |
116
- | Pattern | Popcount, leading zeros, symmetry | ~150 |
117
- | Threshold | k-of-n gates, majority, minority | ~100 |
118
-
119
- All weights are integers. All activations are Heaviside. Verified with 6,590 exhaustive tests.
120
-
121
- ---
122
-
123
- ## 4. Transformer Integration Architecture
124
-
125
- ### 4.1 Target: SmolLM2-360M
126
-
127
- ```
128
- Architecture: LlamaForCausalLM
129
- Hidden dim: 960
130
- Layers: 32
131
- Heads: 15
132
- MLP expansion: 4x (intermediate = 3840)
133
- Vocab: 49152
134
- Parameters: 361,821,120
135
- ```
136
-
137
- Standard MLP block:
138
- ```python
139
- def forward(x): # x: [batch, seq, 960]
140
- gate = self.gate_proj(x) # [batch, seq, 3840]
141
- up = self.up_proj(x) # [batch, seq, 3840]
142
- hidden = silu(gate) * up # SwiGLU activation
143
- return self.down_proj(hidden) # [batch, seq, 960]
144
- ```
145
-
146
- ### 4.2 Augmented MLP Block
147
-
148
- ```python
149
- def forward(x): # x: [batch, seq, 960]
150
- # Original MLP path (unchanged)
151
- mlp_out = self.down_proj(silu(self.gate_proj(x)) * self.up_proj(x))
152
-
153
- # Circuit path (new)
154
- a_bits, b_bits = self.bit_extractor(x) # [batch, seq, 8] each
155
- result_bits, carry = self.circuits.add_8bit(a_bits, b_bits)
156
- flags = self.compute_flags(result_bits, carry)
157
- circuit_delta = self.bit_injector(result_bits, flags)
158
-
159
- # Routing
160
- route_weights = self.router(x) # [batch, seq, 2] softmax
161
-
162
- # Combine
163
- return mlp_out + route_weights[..., 1:2] * circuit_delta
164
- ```
165
-
166
- ### 4.3 Layer Selection
167
-
168
- We augment the **middle third** of layers (10-20 of 32):
169
-
170
- - Early layers (0-9): Token/position encoding, not arithmetic-relevant
171
- - Middle layers (10-20): Abstract reasoning, computation
172
- - Late layers (21-31): Output formatting, vocabulary projection
173
-
174
- Rationale: Arithmetic computation happens in middle layers where the model processes relationships between tokens. Early layers haven't built sufficient representations; late layers are committed to output tokens.
175
-
176
- ---
177
-
178
- ## 5. Interface Layers (Trainable)
179
-
180
- ### 5.1 BitExtractor
181
-
182
- Maps token embedding → two 8-bit operands.
183
-
184
- ```python
185
- class BitExtractor(nn.Module):
186
- def __init__(self, d_model=960):
187
- self.proj = nn.Linear(d_model, 16) # 960 → 16
188
-
189
- def forward(self, x):
190
- logits = self.proj(x) # [batch, seq, 16]
191
- bits = heaviside(logits) # binarize with STE
192
- a_bits = bits[..., :8] # first operand
193
- b_bits = bits[..., 8:] # second operand
194
- return a_bits, b_bits # both [batch, seq, 8], LSB first
195
- ```
196
-
197
- **What it learns**: Which embedding dimensions encode numeric magnitude. For token "127", it must learn that certain activation patterns correspond to bits `[1,1,1,1,1,1,1,0]`.
198
-
199
- **Parameters**: 960 × 16 + 16 = 15,376
200
-
201
- ### 5.2 BitInjector
202
-
203
- Maps circuit outputs → embedding delta.
204
-
205
- ```python
206
- class BitInjector(nn.Module):
207
- def __init__(self, d_model=960):
208
- self.proj = nn.Linear(16, d_model) # 16 → 960
209
- self.scale = nn.Parameter(torch.tensor(0.1))
210
-
211
- def forward(self, result_bits, flags):
212
- combined = torch.cat([result_bits, flags], dim=-1) # [batch, seq, 16]
213
- return self.proj(combined) * self.scale # [batch, seq, 960]
214
- ```
215
-
216
- **What it learns**: How to inject the result bits back into embedding space such that subsequent layers (and the final vocabulary projection) produce the correct output tokens.
217
-
218
- **Parameters**: 16 × 960 + 960 + 1 = 16,321
219
-
220
- ### 5.3 Router
221
-
222
- Decides when to use circuit path.
223
-
224
- ```python
225
- class Router(nn.Module):
226
- def __init__(self, d_model=960):
227
- self.net = nn.Sequential(
228
- nn.Linear(d_model, 64),
229
- nn.ReLU(),
230
- nn.Linear(64, 2),
231
- nn.Softmax(dim=-1)
232
- )
233
-
234
- def forward(self, x):
235
- return self.net(x) # [batch, seq, 2]: [mlp_weight, circuit_weight]
236
- ```
237
-
238
- **What it learns**: "This position contains arithmetic" → route through circuits. "This is prose" → use normal MLP.
239
-
240
- **Parameters**: 960 × 64 + 64 + 64 × 2 + 2 = 61,698
241
-
242
- ### 5.4 Total Trainable Parameters
243
-
244
- Per augmented layer:
245
- ```
246
- BitExtractor: 15,376
247
- BitInjector: 16,321
248
- Router: 61,698
249
- OpSelector: ~31,000
250
- ───────────────────────
251
- Total: ~124,395 per layer
252
- ```
253
-
254
- For 11 augmented layers: **~1.37M trainable parameters**
255
-
256
- This is 0.38% of the model. The other 99.62% (including all circuit weights) is frozen.
257
-
258
- ---
259
-
260
- ## 6. Gradient Flow Through Heaviside
261
-
262
- ### 6.1 The Problem
263
-
264
- Heaviside has zero gradient almost everywhere:
265
-
266
- ```
267
- H(x) = 1 if x ≥ 0 else 0
268
- dH/dx = 0 for x ≠ 0, undefined at x = 0
269
- ```
270
-
271
- Standard backprop would give zero gradients to BitExtractor.
272
-
273
- ### 6.2 Straight-Through Estimator (STE)
274
-
275
- We use STE: forward pass uses true Heaviside, backward pass pretends it's identity.
276
-
277
- ```python
278
- class HeavisideSTE(torch.autograd.Function):
279
- @staticmethod
280
- def forward(ctx, x):
281
- return (x >= 0).float() # true step function
282
-
283
- @staticmethod
284
- def backward(ctx, grad_output):
285
- return grad_output # pass gradient through unchanged
286
- ```
287
-
288
- **Intuition**: "If making the input larger would have helped the output, increase the input." The gradient tells us the direction even though the function is flat.
289
-
290
- ### 6.3 Alternative: Sigmoid Annealing
291
-
292
- During training, use sigmoid with increasing temperature:
293
-
294
- ```python
295
- def soft_heaviside(x, temperature):
296
- return torch.sigmoid(x * temperature)
297
-
298
- # temperature: 1 → 10 → 100 over training
299
- # At high temperature, sigmoid ≈ step function
300
- ```
301
-
302
- This provides smoother gradients early in training, then sharpens to true binary at inference.
303
-
304
- ---
305
-
306
- ## 7. Training Strategy
307
-
308
- ### 7.1 Data Generation
309
-
310
- Generate arithmetic problems exhaustively:
311
-
312
- ```python
313
- def generate_batch(batch_size):
314
- a = torch.randint(0, 256, (batch_size,))
315
- b = torch.randint(0, 256, (batch_size,))
316
- result = (a + b) % 256
317
-
318
- prompts = [f"{a[i]} + {b[i]} =" for i in range(batch_size)]
319
- targets = [f" {result[i]}" for i in range(batch_size)]
320
-
321
- return prompts, targets
322
- ```
323
-
324
- For 8-bit addition, there are 256 × 256 = 65,536 unique problems. We can cover the entire space.
325
-
326
- ### 7.2 Loss Function
327
-
328
- Standard cross-entropy on next-token prediction:
329
-
330
- ```python
331
- outputs = model(input_ids, attention_mask=mask, labels=labels)
332
- loss = outputs.loss # CE loss, only on target tokens
333
- ```
334
-
335
- Labels are masked for prompt tokens (`-100`), so loss only backprops through the answer.
336
-
337
- ### 7.3 Optimizer Configuration
338
-
339
- ```python
340
- # Only train interface layers
341
- interface_params = [p for n, p in model.named_parameters()
342
- if any(x in n for x in ['bit_extractor', 'bit_injector', 'router'])]
343
-
344
- optimizer = AdamW(interface_params, lr=1e-4, weight_decay=0.01)
345
- scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
346
- ```
347
-
348
- ### 7.4 Curriculum Learning
349
-
350
- Start simple, increase difficulty:
351
-
352
- ```
353
- Phase 1 (epochs 1-2): Single-digit addition (0-9 + 0-9)
354
- Phase 2 (epochs 3-4): Two-digit addition (0-99 + 0-99)
355
- Phase 3 (epochs 5-7): Full 8-bit addition (0-255 + 0-255)
356
- Phase 4 (epochs 8-10): Adversarial cases (carry chains: 127+128, 255+1)
357
- ```
358
-
359
- This helps the interface layers learn the basic extraction pattern before tackling hard cases.
360
-
361
- ### 7.5 Training Hyperparameters
362
-
363
- ```
364
- Model: SmolLM2-360M
365
- Augmented: Layers 10-20 (11 layers)
366
- Trainable: 1.37M parameters
367
- Frozen: 362M parameters (including 5.6K circuit params)
368
-
369
- Batch size: 32
370
- Learning rate: 1e-4
371
- Epochs: 10
372
- Samples: 10,000 per epoch
373
- Warmup: 500 steps
374
- Device: RTX 6000 Ada (48GB)
375
-
376
- Expected time: ~30 minutes total
377
- ```
378
-
379
- ---
380
-
381
- ## 8. Forward Pass Walkthrough
382
-
383
- Input: `"127 + 128 ="`
384
-
385
- ### 8.1 Tokenization
386
-
387
- ```
388
- Tokens: ["127", " +", " 128", " ="]
389
- IDs: [12700, 489, 13824, 284] # hypothetical
390
- ```
391
-
392
- ### 8.2 Embedding
393
-
394
- ```
395
- embeddings = embed(input_ids) # [1, 4, 960]
396
- ```
397
-
398
- ### 8.3 Layers 0-9 (Unchanged)
399
-
400
- Standard attention + MLP, building representations.
401
-
402
- ### 8.4 Layer 10 (Augmented)
403
-
404
- ```python
405
- # After attention
406
- x = layer_norm(attn_output + residual) # [1, 4, 960]
407
-
408
- # MLP path
409
- mlp_out = down_proj(silu(gate_proj(x)) * up_proj(x))
410
-
411
- # Circuit path
412
- a_bits, b_bits = bit_extractor(x)
413
- # Position 0 ("127"): a_bits ≈ [1,1,1,1,1,1,1,0] if well-trained
414
- # Position 2 ("128"): b_bits ≈ [0,0,0,0,0,0,0,1]
415
- # (In practice, extraction happens per-position; aggregation is learned)
416
-
417
- result_bits, carry = circuits.add_8bit(a_bits, b_bits)
418
- # result_bits = [1,1,1,1,1,1,1,1] = 255
419
-
420
- flags = compute_flags(result_bits, carry)
421
- # zero=0, negative=1, carry=1
422
-
423
- circuit_delta = bit_injector(result_bits, flags) # [1, 4, 960]
424
-
425
- # Routing
426
- route = router(x) # [1, 4, 2]
427
- # Position 3 ("="): route ≈ [0.1, 0.9] → use circuits
428
- # Position 1 ("+"): route ≈ [0.8, 0.2] → mostly MLP
429
-
430
- # Combine
431
- output = mlp_out + route[..., 1:2] * circuit_delta
432
- ```
433
-
434
- ### 8.5 Layers 11-31
435
-
436
- Continue processing, eventually projecting to vocabulary.
437
-
438
- ### 8.6 Output
439
-
440
- ```
441
- logits = lm_head(final_hidden) # [1, 4, 49152]
442
- next_token = argmax(logits[0, 3, :]) # token after "="
443
- # Should decode to "255" (possibly as " 255" or "255")
444
- ```
445
-
446
- ---
447
-
448
- ## 9. Inference Characteristics
449
-
450
- ### 9.1 Exactness
451
-
452
- At inference, Heaviside is true step function—no approximation. If BitExtractor correctly maps "127" → bits and "128" → bits, the circuit **will** output 255. The only failure mode is incorrect extraction.
453
-
454
- ### 9.2 Latency
455
-
456
- Circuit computation adds ~5-10% overhead:
457
- - BitExtractor: 1 linear layer (960→16)
458
- - Circuits: ~32 threshold layers, but sparse and tiny
459
- - BitInjector: 1 linear layer (16→960)
460
- - Router: 2 linear layers
461
-
462
- The circuits have only 40,323 parameters total—negligible versus the 361M in the base model.
463
-
464
- ### 9.3 Generalization
465
-
466
- Once the interface learns the mapping, it generalizes to **all** 65,536 8-bit additions. There's no memorization—the circuits compute.
467
-
468
- ---
469
-
470
- ## 10. Evaluation Metrics
471
-
472
- ### 10.1 Arithmetic Accuracy
473
-
474
- ```python
475
- def eval_accuracy(model, n_problems=1000):
476
- correct = 0
477
- for _ in range(n_problems):
478
- a, b = random 8-bit values
479
- expected = (a + b) % 256
480
- predicted = model.generate(f"{a} + {b} =")
481
- if parse_int(predicted) == expected:
482
- correct += 1
483
- return correct / n_problems
484
- ```
485
-
486
- **Baseline SmolLM2**: ~5-10% (guessing based on patterns)
487
- **Target**: >95% (circuit-accurate)
488
-
489
- ### 10.2 Edge Case Performance
490
-
491
- Specifically test:
492
- - Carry propagation: 127+128, 255+1, 128+128
493
- - Zeros: 0+0, 0+255
494
- - Identity: x+0 for various x
495
- - Commutativity: verify a+b == b+a
496
-
497
- ### 10.3 Non-Arithmetic Preservation
498
-
499
- Verify general capability isn't degraded:
500
- - Perplexity on held-out text
501
- - Common benchmarks (HellaSwag, etc.)
502
-
503
- The augmentation should be **additive**—circuits help arithmetic, MLP handles everything else via routing.
504
-
505
- ---
506
-
507
- ## 11. Extension Roadmap
508
-
509
- ### 11.1 Additional Operations
510
-
511
- The circuit inventory includes:
512
- - Subtraction (via two's complement)
513
- - Multiplication (8×8 → 16-bit)
514
- - Division (iterative subtraction)
515
- - Bitwise ops (AND, OR, XOR, shifts)
516
- - Comparisons (GT, LT, EQ)
517
-
518
- Each needs its own extraction/injection interface, or a unified interface with operation selection.
519
-
520
- ### 11.2 Multi-Operand Expressions
521
-
522
- For "15 + 27 + 33 =", need:
523
- - Operand count detection
524
- - Sequential circuit invocation
525
- - Accumulator pattern
526
-
527
- ### 11.3 Larger Bit Widths
528
-
529
- 16-bit and 32-bit arithmetic require:
530
- - Larger circuits (or chained 8-bit)
531
- - Wider BitExtractor (32 or 64 output dims)
532
- - More training data
533
-
534
- ### 11.4 Symbolic Integration
535
-
536
- Ultimate goal: the model recognizes when it needs to compute, invokes circuits, and integrates results into coherent natural language output.
537
-
538
- ```
539
- User: "If I have 127 apples and buy 128 more, how many do I have?"
540
- Model: [extracts 127, 128] [routes to circuit] [gets 255]
541
- "You would have 255 apples."
542
- ```
543
-
544
- ---
545
-
546
- ## 12. File Structure
547
-
548
- ```
549
- 8bit-threshold-computer/
550
- ├── neural_computer.safetensors # Frozen circuits (24,200 tensors)
551
- ├── circuit_llm.py # Integration architecture
552
- ├── train_circuit_interface.py # Training loop
553
- ├── iron_eval.py # Circuit verification (6,590 tests)
554
- ├── skeptic_test.py # Algebraic identity tests (127 tests)
555
- ├── prune_weights.py # Weight optimization
556
- ├── tensors.txt # Tensor manifest
557
- ├── guide.md # This document
558
- └── README.md # Project overview
559
- ```
560
-
561
- ---
562
-
563
- ## 13. Key Equations
564
-
565
- ### Heaviside Step
566
- ```
567
- H(x) = 1 if x ≥ 0 else 0
568
- ```
569
-
570
- ### Threshold Gate
571
- ```
572
- f(x₁,...,xₙ) = H(Σᵢ wᵢxᵢ + b)
573
- ```
574
-
575
- ### Full Adder
576
- ```
577
- sum = a ⊕ b ⊕ cᵢₙ
578
- cₒᵤₜ = (a ∧ b) ∨ (cᵢₙ ∧ (a ⊕ b))
579
- ```
580
-
581
- ### STE Gradient
582
- ```
583
- Forward: y = H(x)
584
- Backward: ∂L/∂x = ∂L/∂y
585
- ```
586
-
587
- ### Router Combination
588
- ```
589
- output = mlp_out + softmax(router(x))[1] × circuit_delta
590
- ```
591
-
592
- ---
593
-
594
- ## 14. References
595
-
596
- 1. McCulloch & Pitts (1943). "A Logical Calculus of Ideas Immanent in Nervous Activity"
597
- 2. Muroga (1971). "Threshold Logic and Its Applications"
598
- 3. Siegelmann & Sontag (1995). "On the Computational Power of Neural Nets"
599
- 4. Bengio et al. (2013). "Estimating or Propagating Gradients Through Stochastic Neurons"
600
- 5. Ma et al. (2024). "The Era of 1-bit LLMs" (BitNet b1.58)
601
- 6. HuggingFace (2024). "SmolLM2: Small Language Models"
602
-
603
- ---
604
-
605
- ## 15. Summary
606
-
607
- We embed a proven-correct 8-bit threshold logic computer into SmolLM2's MLP layers. The circuits are frozen; we train only the interface layers that learn call dispatch. This gives the LLM exact arithmetic capability without training it to "do math"—the math is already done.
608
-
609
- The approach is:
610
- - **Sound**: Circuits verified with 6,590 tests
611
- - **Efficient**: 1.37M trainable params, 5.6K circuit params
612
- - **Exact**: Heaviside at inference means no approximation error
613
- - **Composable**: Add more circuits (multiply, compare, etc.) with same pattern
614
-
615
- The model learns when to call the calculator, not how to calculate.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
llm/train_circuit_interface.py DELETED
@@ -1,306 +0,0 @@
1
- """
2
- Train the circuit interface layers on arithmetic examples.
3
- ============================================================
4
-
5
- The threshold circuits are frozen - we only train:
6
- - BitExtractor: embedding -> operand bits
7
- - BitInjector: result bits -> embedding
8
- - Router: when to use circuits vs MLP
9
- """
10
-
11
- import torch
12
- import torch.nn as nn
13
- from torch.utils.data import Dataset, DataLoader
14
- from transformers import AutoModelForCausalLM, AutoTokenizer
15
- from tqdm import tqdm
16
- import argparse
17
- import warnings
18
- warnings.filterwarnings('ignore')
19
-
20
- from circuit_llm import (
21
- augment_smollm2_with_circuits,
22
- evaluate_arithmetic,
23
- CircuitExecutor
24
- )
25
-
26
-
27
- # =============================================================================
28
- # ARITHMETIC DATASET
29
- # =============================================================================
30
-
31
- class ArithmeticDataset(Dataset):
32
- """Dataset of 8-bit addition problems."""
33
-
34
- def __init__(self, tokenizer, n_samples: int = 10000, max_val: int = 255):
35
- self.tokenizer = tokenizer
36
- self.n_samples = n_samples
37
- self.max_val = max_val
38
-
39
- # Pre-generate all examples
40
- self.examples = []
41
- for _ in range(n_samples):
42
- a = torch.randint(0, max_val + 1, (1,)).item()
43
- b = torch.randint(0, max_val + 1, (1,)).item()
44
- result = (a + b) % 256
45
-
46
- prompt = f"{a} + {b} ="
47
- target = f" {result}"
48
-
49
- self.examples.append((prompt, target, a, b, result))
50
-
51
- def __len__(self):
52
- return len(self.examples)
53
-
54
- def __getitem__(self, idx):
55
- prompt, target, a, b, result = self.examples[idx]
56
-
57
- # Tokenize
58
- prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
59
- target_ids = self.tokenizer.encode(target, add_special_tokens=False)
60
-
61
- input_ids = prompt_ids + target_ids
62
- labels = [-100] * len(prompt_ids) + target_ids # Only predict target
63
-
64
- return {
65
- 'input_ids': torch.tensor(input_ids),
66
- 'labels': torch.tensor(labels),
67
- 'a': a,
68
- 'b': b,
69
- 'result': result
70
- }
71
-
72
-
73
- def collate_fn(batch):
74
- """Collate with padding."""
75
- max_len = max(len(item['input_ids']) for item in batch)
76
-
77
- input_ids = []
78
- labels = []
79
- attention_mask = []
80
-
81
- for item in batch:
82
- pad_len = max_len - len(item['input_ids'])
83
-
84
- input_ids.append(
85
- torch.cat([item['input_ids'], torch.zeros(pad_len, dtype=torch.long)])
86
- )
87
- labels.append(
88
- torch.cat([item['labels'], torch.full((pad_len,), -100, dtype=torch.long)])
89
- )
90
- attention_mask.append(
91
- torch.cat([torch.ones(len(item['input_ids'])), torch.zeros(pad_len)])
92
- )
93
-
94
- return {
95
- 'input_ids': torch.stack(input_ids),
96
- 'labels': torch.stack(labels),
97
- 'attention_mask': torch.stack(attention_mask),
98
- }
99
-
100
-
101
- # =============================================================================
102
- # TRAINING LOOP
103
- # =============================================================================
104
-
105
- def train_interface(
106
- model: AutoModelForCausalLM,
107
- tokenizer: AutoTokenizer,
108
- n_epochs: int = 3,
109
- batch_size: int = 16,
110
- lr: float = 1e-4,
111
- n_train_samples: int = 10000,
112
- device: str = 'cpu',
113
- eval_every: int = 500
114
- ):
115
- """
116
- Train the circuit interface layers.
117
-
118
- Only trains:
119
- - bit_extractor (embedding -> bits)
120
- - bit_injector (bits -> embedding)
121
- - router (circuit vs MLP weighting)
122
- - op_selector (which operation)
123
- """
124
- print("\n" + "=" * 70)
125
- print(" TRAINING CIRCUIT INTERFACE")
126
- print("=" * 70)
127
-
128
- # Freeze everything except interface layers
129
- interface_params = []
130
- frozen_count = 0
131
- trainable_count = 0
132
-
133
- for name, param in model.named_parameters():
134
- if any(x in name for x in ['bit_extractor', 'bit_injector', 'router', 'op_selector']):
135
- param.requires_grad = True
136
- interface_params.append(param)
137
- trainable_count += param.numel()
138
- else:
139
- param.requires_grad = False
140
- frozen_count += param.numel()
141
-
142
- print(f"\n Frozen parameters: {frozen_count:,}")
143
- print(f" Trainable parameters: {trainable_count:,}")
144
- print(f" Training {len(interface_params)} parameter groups")
145
-
146
- # Create dataset
147
- print(f"\n Creating dataset ({n_train_samples} examples)...")
148
- dataset = ArithmeticDataset(tokenizer, n_samples=n_train_samples)
149
- dataloader = DataLoader(
150
- dataset,
151
- batch_size=batch_size,
152
- shuffle=True,
153
- collate_fn=collate_fn
154
- )
155
-
156
- # Optimizer
157
- optimizer = torch.optim.AdamW(interface_params, lr=lr)
158
-
159
- # Training
160
- model.to(device)
161
- model.train()
162
-
163
- global_step = 0
164
- total_loss = 0
165
-
166
- for epoch in range(n_epochs):
167
- print(f"\n Epoch {epoch + 1}/{n_epochs}")
168
- print(" " + "-" * 60)
169
-
170
- epoch_loss = 0
171
- epoch_steps = 0
172
-
173
- pbar = tqdm(dataloader, desc=f" Training", leave=False)
174
-
175
- for batch in pbar:
176
- input_ids = batch['input_ids'].to(device)
177
- labels = batch['labels'].to(device)
178
- attention_mask = batch['attention_mask'].to(device)
179
-
180
- # Forward
181
- outputs = model(
182
- input_ids=input_ids,
183
- attention_mask=attention_mask,
184
- labels=labels
185
- )
186
-
187
- loss = outputs.loss
188
-
189
- # Backward
190
- optimizer.zero_grad()
191
- loss.backward()
192
- optimizer.step()
193
-
194
- # Logging
195
- epoch_loss += loss.item()
196
- epoch_steps += 1
197
- global_step += 1
198
- total_loss += loss.item()
199
-
200
- pbar.set_postfix({'loss': f'{loss.item():.4f}'})
201
-
202
- # Periodic evaluation
203
- if global_step % eval_every == 0:
204
- model.eval()
205
- eval_results = evaluate_arithmetic(model, tokenizer, n_problems=50, device=device)
206
- print(f"\n Step {global_step}: Loss={total_loss/eval_every:.4f}, "
207
- f"Accuracy={eval_results['accuracy']*100:.1f}%")
208
- total_loss = 0
209
- model.train()
210
-
211
- avg_loss = epoch_loss / epoch_steps
212
- print(f"\n Epoch {epoch + 1} complete. Avg loss: {avg_loss:.4f}")
213
-
214
- # End of epoch evaluation
215
- model.eval()
216
- eval_results = evaluate_arithmetic(model, tokenizer, n_problems=100, device=device)
217
- print(f" Evaluation: {eval_results['accuracy']*100:.1f}% "
218
- f"({eval_results['correct']}/{eval_results['total']})")
219
-
220
- if eval_results['errors']:
221
- print(f" Sample errors:")
222
- for a, b, exp, got in eval_results['errors'][:3]:
223
- print(f" {a} + {b} = {exp}, model said {got}")
224
-
225
- model.train()
226
-
227
- print("\n" + "=" * 70)
228
- print(" TRAINING COMPLETE")
229
- print("=" * 70)
230
-
231
- return model
232
-
233
-
234
- # =============================================================================
235
- # MAIN
236
- # =============================================================================
237
-
238
- if __name__ == "__main__":
239
- parser = argparse.ArgumentParser(description='Train Circuit Interface')
240
- parser.add_argument('--circuit-path', type=str,
241
- default='./neural_computer.safetensors',
242
- help='Path to circuit weights')
243
- parser.add_argument('--device', type=str, default='cpu',
244
- help='Device (cpu or cuda)')
245
- parser.add_argument('--epochs', type=int, default=3,
246
- help='Number of epochs')
247
- parser.add_argument('--batch-size', type=int, default=8,
248
- help='Batch size')
249
- parser.add_argument('--lr', type=float, default=1e-4,
250
- help='Learning rate')
251
- parser.add_argument('--n-samples', type=int, default=5000,
252
- help='Number of training samples')
253
- args = parser.parse_args()
254
-
255
- print("=" * 70)
256
- print(" CIRCUIT-AUGMENTED LLM TRAINING")
257
- print("=" * 70)
258
-
259
- # Load model
260
- print("\n[1] Loading SmolLM2-360M...")
261
- model_id = "HuggingFaceTB/SmolLM2-360M"
262
- tokenizer = AutoTokenizer.from_pretrained(model_id)
263
- tokenizer.pad_token = tokenizer.eos_token
264
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
265
-
266
- # Baseline
267
- print("\n[2] Baseline evaluation...")
268
- baseline = evaluate_arithmetic(model, tokenizer, n_problems=50, device=args.device)
269
- print(f" Baseline accuracy: {baseline['accuracy']*100:.1f}%")
270
-
271
- # Augment
272
- print("\n[3] Augmenting with circuits...")
273
- model = augment_smollm2_with_circuits(
274
- model,
275
- args.circuit_path,
276
- device=args.device
277
- )
278
-
279
- # Train
280
- print("\n[4] Training interface layers...")
281
- model = train_interface(
282
- model,
283
- tokenizer,
284
- n_epochs=args.epochs,
285
- batch_size=args.batch_size,
286
- lr=args.lr,
287
- n_train_samples=args.n_samples,
288
- device=args.device
289
- )
290
-
291
- # Final evaluation
292
- print("\n[5] Final evaluation...")
293
- final = evaluate_arithmetic(model, tokenizer, n_problems=100, device=args.device)
294
- print(f" Final accuracy: {final['accuracy']*100:.1f}%")
295
- print(f" Improvement: {baseline['accuracy']*100:.1f}% -> {final['accuracy']*100:.1f}%")
296
-
297
- # Save
298
- save_path = './circuit_augmented_smollm2.pt'
299
- print(f"\n[6] Saving to {save_path}...")
300
- torch.save({
301
- 'model_state_dict': model.state_dict(),
302
- 'baseline_accuracy': baseline['accuracy'],
303
- 'final_accuracy': final['accuracy']
304
- }, save_path)
305
-
306
- print("\nDone!")