phanerozoic committed on
Commit
aa8a1bb
·
verified ·
1 Parent(s): 7b6c28e

Clean up: remove duplicates, pycache, and misplaced files

Browse files
__pycache__/circuit_llm.cpython-311.pyc DELETED
Binary file (30.9 kB)
 
__pycache__/evolve_weights.cpython-312.pyc DELETED
Binary file (16.7 kB)
 
__pycache__/iron_eval.cpython-311.pyc DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:889b91e75c83db93f3bb79eca5185dcc75309927c4b558944425b30365110603
3
- size 274934
 
 
 
 
__pycache__/iron_eval.cpython-312.pyc DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e35bdab47ec64d0490d088431d6d3034160c3af4a308569c0a6cc6294ef45e98
3
- size 213694
 
 
 
 
circuit_llm.py DELETED
@@ -1,606 +0,0 @@
1
- """
2
- Circuit-Augmented LLM: Embedding threshold logic circuits into SmolLM2
3
- ======================================================================
4
-
5
- Replaces/augments MLP layers with frozen threshold circuits for exact arithmetic.
6
- """
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- from typing import Dict, Optional, Tuple
12
- from safetensors.torch import load_file
13
- from transformers import AutoModelForCausalLM, AutoTokenizer
14
- import warnings
15
- warnings.filterwarnings('ignore')
16
-
17
-
18
- # =============================================================================
19
- # HEAVISIDE WITH STRAIGHT-THROUGH ESTIMATOR
20
- # =============================================================================
21
-
22
- class HeavisideSTE(torch.autograd.Function):
23
- """Heaviside step function with straight-through estimator for backprop."""
24
-
25
- @staticmethod
26
- def forward(ctx, x):
27
- return (x >= 0).float()
28
-
29
- @staticmethod
30
- def backward(ctx, grad_output):
31
- # STE: pass gradient through unchanged
32
- return grad_output
33
-
34
-
35
- def heaviside(x: torch.Tensor) -> torch.Tensor:
36
- """Heaviside step: 1 if x >= 0, else 0. Uses STE for training."""
37
- return HeavisideSTE.apply(x)
38
-
39
-
40
- # =============================================================================
41
- # CIRCUIT EXECUTOR - Runs the frozen threshold circuits
42
- # =============================================================================
43
-
44
- class CircuitExecutor(nn.Module):
45
- """
46
- Executes threshold logic circuits from the safetensors file.
47
- All circuit weights are frozen - only interface layers train.
48
- """
49
-
50
- def __init__(self, circuit_path: str, device: str = 'cpu'):
51
- super().__init__()
52
- self.device = device
53
-
54
- # Load all circuit tensors
55
- raw_circuits = load_file(circuit_path)
56
-
57
- # Store as frozen parameters (use underscores for valid param names)
58
- self.circuits = {}
59
- for k, v in raw_circuits.items():
60
- safe_name = k.replace('.', '__')
61
- self.register_buffer(safe_name, v.float().to(device))
62
- self.circuits[k] = safe_name
63
-
64
- def _get(self, name: str) -> torch.Tensor:
65
- """Get circuit tensor by original dotted name."""
66
- return getattr(self, self.circuits[name])
67
-
68
- # -------------------------------------------------------------------------
69
- # Boolean Gates
70
- # -------------------------------------------------------------------------
71
-
72
- def eval_and(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
73
- """AND gate: output 1 iff both inputs are 1."""
74
- inp = torch.stack([a, b], dim=-1)
75
- w = self._get('boolean.and.weight')
76
- bias = self._get('boolean.and.bias')
77
- return heaviside(inp @ w + bias)
78
-
79
- def eval_or(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
80
- """OR gate: output 1 if either input is 1."""
81
- inp = torch.stack([a, b], dim=-1)
82
- w = self._get('boolean.or.weight')
83
- bias = self._get('boolean.or.bias')
84
- return heaviside(inp @ w + bias)
85
-
86
- def eval_xor(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
87
- """XOR gate: two-layer network (not linearly separable)."""
88
- inp = torch.stack([a, b], dim=-1)
89
-
90
- # Layer 1: OR and NAND neurons
91
- w1_n1 = self._get('boolean.xor.layer1.neuron1.weight')
92
- b1_n1 = self._get('boolean.xor.layer1.neuron1.bias')
93
- w1_n2 = self._get('boolean.xor.layer1.neuron2.weight')
94
- b1_n2 = self._get('boolean.xor.layer1.neuron2.bias')
95
-
96
- h1 = heaviside(inp @ w1_n1 + b1_n1)
97
- h2 = heaviside(inp @ w1_n2 + b1_n2)
98
- hidden = torch.stack([h1, h2], dim=-1)
99
-
100
- # Layer 2: AND of hidden
101
- w2 = self._get('boolean.xor.layer2.weight')
102
- b2 = self._get('boolean.xor.layer2.bias')
103
-
104
- return heaviside(hidden @ w2 + b2)
105
-
106
- # -------------------------------------------------------------------------
107
- # Arithmetic: Full Adder
108
- # -------------------------------------------------------------------------
109
-
110
- def eval_full_adder(self, a: torch.Tensor, b: torch.Tensor,
111
- cin: torch.Tensor, prefix: str) -> Tuple[torch.Tensor, torch.Tensor]:
112
- """
113
- Full adder: sum = a XOR b XOR cin, cout = (a AND b) OR (cin AND (a XOR b))
114
- Returns (sum_bit, carry_out)
115
- """
116
- inp_ab = torch.stack([a, b], dim=-1)
117
-
118
- # HA1: a XOR b
119
- w1_or = self._get(f'{prefix}.ha1.sum.layer1.or.weight')
120
- b1_or = self._get(f'{prefix}.ha1.sum.layer1.or.bias')
121
- w1_nand = self._get(f'{prefix}.ha1.sum.layer1.nand.weight')
122
- b1_nand = self._get(f'{prefix}.ha1.sum.layer1.nand.bias')
123
- w2 = self._get(f'{prefix}.ha1.sum.layer2.weight')
124
- b2 = self._get(f'{prefix}.ha1.sum.layer2.bias')
125
-
126
- h_or = heaviside(inp_ab @ w1_or + b1_or)
127
- h_nand = heaviside(inp_ab @ w1_nand + b1_nand)
128
- hidden = torch.stack([h_or, h_nand], dim=-1)
129
- ha1_sum = heaviside(hidden @ w2 + b2)
130
-
131
- # HA1 carry
132
- w_c1 = self._get(f'{prefix}.ha1.carry.weight')
133
- b_c1 = self._get(f'{prefix}.ha1.carry.bias')
134
- ha1_carry = heaviside(inp_ab @ w_c1 + b_c1)
135
-
136
- # HA2: ha1_sum XOR cin
137
- inp_ha2 = torch.stack([ha1_sum, cin], dim=-1)
138
- w1_or = self._get(f'{prefix}.ha2.sum.layer1.or.weight')
139
- b1_or = self._get(f'{prefix}.ha2.sum.layer1.or.bias')
140
- w1_nand = self._get(f'{prefix}.ha2.sum.layer1.nand.weight')
141
- b1_nand = self._get(f'{prefix}.ha2.sum.layer1.nand.bias')
142
- w2 = self._get(f'{prefix}.ha2.sum.layer2.weight')
143
- b2 = self._get(f'{prefix}.ha2.sum.layer2.bias')
144
-
145
- h_or = heaviside(inp_ha2 @ w1_or + b1_or)
146
- h_nand = heaviside(inp_ha2 @ w1_nand + b1_nand)
147
- hidden = torch.stack([h_or, h_nand], dim=-1)
148
- ha2_sum = heaviside(hidden @ w2 + b2)
149
-
150
- # HA2 carry
151
- w_c2 = self._get(f'{prefix}.ha2.carry.weight')
152
- b_c2 = self._get(f'{prefix}.ha2.carry.bias')
153
- ha2_carry = heaviside(inp_ha2 @ w_c2 + b_c2)
154
-
155
- # Carry out = ha1_carry OR ha2_carry
156
- inp_cout = torch.stack([ha1_carry, ha2_carry], dim=-1)
157
- w_or = self._get(f'{prefix}.carry_or.weight')
158
- b_or = self._get(f'{prefix}.carry_or.bias')
159
- cout = heaviside(inp_cout @ w_or + b_or)
160
-
161
- return ha2_sum, cout
162
-
163
- # -------------------------------------------------------------------------
164
- # Arithmetic: 8-bit Ripple Carry Adder
165
- # -------------------------------------------------------------------------
166
-
167
- def add_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
168
- """
169
- 8-bit ripple carry addition.
170
- a_bits, b_bits: [..., 8] tensors of bits (LSB first)
171
- Returns: (result_bits [..., 8], carry_out [...])
172
- """
173
- batch_shape = a_bits.shape[:-1]
174
- carry = torch.zeros(batch_shape, device=a_bits.device)
175
- result_bits = []
176
-
177
- for i in range(8):
178
- a_i = a_bits[..., i]
179
- b_i = b_bits[..., i]
180
- sum_bit, carry = self.eval_full_adder(
181
- a_i, b_i, carry,
182
- f'arithmetic.ripplecarry8bit.fa{i}'
183
- )
184
- result_bits.append(sum_bit)
185
-
186
- return torch.stack(result_bits, dim=-1), carry
187
-
188
- # -------------------------------------------------------------------------
189
- # Arithmetic: 8-bit Comparators
190
- # -------------------------------------------------------------------------
191
-
192
- def greater_than_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
193
- """Returns 1 if a > b, else 0. Bits are MSB first."""
194
- diff = a_bits - b_bits # [..., 8]
195
- w = self._get('arithmetic.greaterthan8bit.comparator')
196
- score = (diff * w).sum(dim=-1)
197
- return (score > 0).float()
198
-
199
- def less_than_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
200
- """Returns 1 if a < b, else 0. Bits are MSB first."""
201
- diff = b_bits - a_bits # [..., 8]
202
- w = self._get('arithmetic.lessthan8bit.comparator')
203
- score = (diff * w).sum(dim=-1)
204
- return (score > 0).float()
205
-
206
- def equal_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
207
- """Returns 1 if a == b, else 0."""
208
- gt = self.greater_than_8bit(a_bits, b_bits)
209
- lt = self.less_than_8bit(a_bits, b_bits)
210
- return (1 - gt) * (1 - lt)
211
-
212
-
213
- # =============================================================================
214
- # BIT EXTRACTION / INJECTION INTERFACES
215
- # =============================================================================
216
-
217
- class BitExtractor(nn.Module):
218
- """
219
- Learns to extract 8-bit operands from token embeddings.
220
- Maps embedding -> 16 bits (two 8-bit operands).
221
- """
222
-
223
- def __init__(self, d_model: int):
224
- super().__init__()
225
- self.d_model = d_model
226
-
227
- # Project to logits, then binarize
228
- self.proj = nn.Linear(d_model, 16)
229
-
230
- # Learnable temperature for sigmoid approximation during training
231
- self.temperature = nn.Parameter(torch.tensor(1.0))
232
-
233
- def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
234
- """
235
- x: [..., d_model]
236
- Returns: a_bits [..., 8], b_bits [..., 8] (LSB first for arithmetic)
237
- """
238
- logits = self.proj(x) # [..., 16]
239
-
240
- # Binarize with STE
241
- bits = heaviside(logits)
242
-
243
- # Split into two operands
244
- a_bits = bits[..., :8]
245
- b_bits = bits[..., 8:]
246
-
247
- return a_bits, b_bits
248
-
249
-
250
- class BitInjector(nn.Module):
251
- """
252
- Learns to inject circuit results back into embedding space.
253
- Maps 16 bits (result + flags) -> embedding delta.
254
- """
255
-
256
- def __init__(self, d_model: int):
257
- super().__init__()
258
- self.d_model = d_model
259
-
260
- # Project bits to embedding
261
- self.proj = nn.Linear(16, d_model)
262
-
263
- # Learnable scale
264
- self.scale = nn.Parameter(torch.tensor(0.1))
265
-
266
- def forward(self, result_bits: torch.Tensor, flags: torch.Tensor) -> torch.Tensor:
267
- """
268
- result_bits: [..., 8]
269
- flags: [..., 8] (carry, overflow, zero, negative, etc.)
270
- Returns: [..., d_model]
271
- """
272
- combined = torch.cat([result_bits, flags], dim=-1) # [..., 16]
273
- return self.proj(combined) * self.scale
274
-
275
-
276
- # =============================================================================
277
- # CIRCUIT-AUGMENTED MLP BLOCK
278
- # =============================================================================
279
-
280
- class CircuitAugmentedMLP(nn.Module):
281
- """
282
- MLP block augmented with frozen threshold circuits.
283
-
284
- The original MLP path runs in parallel with the circuit path.
285
- A learned router decides how much to use each.
286
- """
287
-
288
- def __init__(
289
- self,
290
- d_model: int,
291
- intermediate_size: int,
292
- circuit_path: str,
293
- device: str = 'cpu'
294
- ):
295
- super().__init__()
296
- self.d_model = d_model
297
-
298
- # Original MLP components (will be loaded from pretrained)
299
- self.gate_proj = nn.Linear(d_model, intermediate_size, bias=False)
300
- self.up_proj = nn.Linear(d_model, intermediate_size, bias=False)
301
- self.down_proj = nn.Linear(intermediate_size, d_model, bias=False)
302
- self.act_fn = nn.SiLU()
303
-
304
- # Circuit components
305
- self.circuits = CircuitExecutor(circuit_path, device)
306
- self.bit_extractor = BitExtractor(d_model)
307
- self.bit_injector = BitInjector(d_model)
308
-
309
- # Router: decides circuit vs MLP contribution
310
- self.router = nn.Sequential(
311
- nn.Linear(d_model, 64),
312
- nn.ReLU(),
313
- nn.Linear(64, 2),
314
- nn.Softmax(dim=-1)
315
- )
316
-
317
- # Operation selector (which arithmetic op to perform)
318
- self.op_selector = nn.Sequential(
319
- nn.Linear(d_model, 32),
320
- nn.ReLU(),
321
- nn.Linear(32, 4), # add, sub, compare, passthrough
322
- nn.Softmax(dim=-1)
323
- )
324
-
325
- def _compute_flags(self, result_bits: torch.Tensor, carry: torch.Tensor) -> torch.Tensor:
326
- """Compute status flags from result."""
327
- batch_shape = result_bits.shape[:-1]
328
-
329
- # Zero flag: all bits are 0
330
- zero = (result_bits.sum(dim=-1) == 0).float()
331
-
332
- # Negative flag: MSB is 1 (two's complement)
333
- negative = result_bits[..., 7]
334
-
335
- # Carry flag
336
- carry_flag = carry
337
-
338
- # Pad to 8 flags
339
- flags = torch.zeros(*batch_shape, 8, device=result_bits.device)
340
- flags[..., 0] = zero
341
- flags[..., 1] = negative
342
- flags[..., 2] = carry_flag
343
-
344
- return flags
345
-
346
- def _circuit_forward(self, x: torch.Tensor) -> torch.Tensor:
347
- """Run input through threshold circuits."""
348
- # Extract operands
349
- a_bits, b_bits = self.bit_extractor(x)
350
-
351
- # Get operation weights
352
- op_weights = self.op_selector(x) # [..., 4]
353
-
354
- # Compute addition
355
- add_result, add_carry = self.circuits.add_8bit(a_bits, b_bits)
356
- add_flags = self._compute_flags(add_result, add_carry)
357
-
358
- # Compute subtraction (a + (~b) + 1, simplified: just use add for now)
359
- # For MVP, we'll focus on addition
360
-
361
- # Inject result back
362
- circuit_delta = self.bit_injector(add_result, add_flags)
363
-
364
- return circuit_delta
365
-
366
- def forward(self, x: torch.Tensor) -> torch.Tensor:
367
- """
368
- x: [batch, seq_len, d_model]
369
- Returns: [batch, seq_len, d_model]
370
- """
371
- # Original MLP path
372
- mlp_out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
373
-
374
- # Circuit path
375
- circuit_out = self._circuit_forward(x)
376
-
377
- # Route between paths
378
- route_weights = self.router(x) # [..., 2]
379
- mlp_weight = route_weights[..., 0:1]
380
- circuit_weight = route_weights[..., 1:2]
381
-
382
- # Combine: MLP output + weighted circuit contribution
383
- output = mlp_out + circuit_weight * circuit_out
384
-
385
- return output
386
-
387
-
388
- # =============================================================================
389
- # MODEL SURGERY: Insert circuits into SmolLM2
390
- # =============================================================================
391
-
392
- def augment_smollm2_with_circuits(
393
- model: AutoModelForCausalLM,
394
- circuit_path: str,
395
- layer_indices: list = None,
396
- device: str = 'cpu'
397
- ) -> AutoModelForCausalLM:
398
- """
399
- Surgically insert circuit blocks into SmolLM2's MLP layers.
400
-
401
- Args:
402
- model: Pretrained SmolLM2 model
403
- circuit_path: Path to neural_computer.safetensors
404
- layer_indices: Which layers to augment (default: middle layers)
405
- device: Device for circuit tensors
406
-
407
- Returns:
408
- Modified model with circuit-augmented MLPs
409
- """
410
- config = model.config
411
- num_layers = config.num_hidden_layers
412
-
413
- # Default: augment middle third of layers
414
- if layer_indices is None:
415
- start = num_layers // 3
416
- end = 2 * num_layers // 3
417
- layer_indices = list(range(start, end))
418
-
419
- print(f"Augmenting layers {layer_indices} with threshold circuits...")
420
-
421
- for idx in layer_indices:
422
- layer = model.model.layers[idx]
423
- old_mlp = layer.mlp
424
-
425
- # Create augmented MLP
426
- new_mlp = CircuitAugmentedMLP(
427
- d_model=config.hidden_size,
428
- intermediate_size=config.intermediate_size,
429
- circuit_path=circuit_path,
430
- device=device
431
- )
432
-
433
- # Copy pretrained weights
434
- new_mlp.gate_proj.weight.data = old_mlp.gate_proj.weight.data.clone()
435
- new_mlp.up_proj.weight.data = old_mlp.up_proj.weight.data.clone()
436
- new_mlp.down_proj.weight.data = old_mlp.down_proj.weight.data.clone()
437
-
438
- # Replace
439
- layer.mlp = new_mlp
440
-
441
- # Freeze circuit weights, keep interfaces trainable
442
- for name, param in model.named_parameters():
443
- if 'circuits' in name:
444
- param.requires_grad = False
445
-
446
- print(f"Done. Circuit weights frozen, interfaces trainable.")
447
-
448
- return model
449
-
450
-
451
- # =============================================================================
452
- # TRAINING UTILITIES
453
- # =============================================================================
454
-
455
- def generate_arithmetic_batch(batch_size: int, max_val: int = 255) -> Tuple[list, list]:
456
- """Generate batch of arithmetic problems and solutions."""
457
- prompts = []
458
- targets = []
459
-
460
- for _ in range(batch_size):
461
- a = torch.randint(0, max_val + 1, (1,)).item()
462
- b = torch.randint(0, max_val + 1, (1,)).item()
463
- result = (a + b) % 256
464
-
465
- prompts.append(f"{a} + {b} =")
466
- targets.append(f" {result}")
467
-
468
- return prompts, targets
469
-
470
-
471
- def evaluate_arithmetic(
472
- model: AutoModelForCausalLM,
473
- tokenizer: AutoTokenizer,
474
- n_problems: int = 100,
475
- device: str = 'cpu'
476
- ) -> dict:
477
- """Evaluate model on random arithmetic problems."""
478
- correct = 0
479
- total = 0
480
- errors = []
481
-
482
- model.eval()
483
-
484
- for _ in range(n_problems):
485
- a = torch.randint(0, 256, (1,)).item()
486
- b = torch.randint(0, 256, (1,)).item()
487
- expected = (a + b) % 256
488
-
489
- prompt = f"{a} + {b} ="
490
- inputs = tokenizer(prompt, return_tensors='pt').to(device)
491
-
492
- with torch.no_grad():
493
- outputs = model.generate(
494
- **inputs,
495
- max_new_tokens=10,
496
- do_sample=False,
497
- pad_token_id=tokenizer.eos_token_id
498
- )
499
-
500
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
501
-
502
- # Extract number from response
503
- try:
504
- # Find the part after "="
505
- answer_part = response.split('=')[-1].strip()
506
- # Extract first number
507
- predicted = int(''.join(c for c in answer_part.split()[0] if c.isdigit()))
508
-
509
- if predicted == expected:
510
- correct += 1
511
- else:
512
- errors.append((a, b, expected, predicted))
513
- except:
514
- errors.append((a, b, expected, "parse_error"))
515
-
516
- total += 1
517
-
518
- return {
519
- 'accuracy': correct / total,
520
- 'correct': correct,
521
- 'total': total,
522
- 'errors': errors[:10] # First 10 errors
523
- }
524
-
525
-
526
- # =============================================================================
527
- # MAIN: Demo
528
- # =============================================================================
529
-
530
- if __name__ == "__main__":
531
- import argparse
532
-
533
- parser = argparse.ArgumentParser(description='Circuit-Augmented LLM Demo')
534
- parser.add_argument('--circuit-path', type=str,
535
- default='./neural_computer.safetensors',
536
- help='Path to circuit weights')
537
- parser.add_argument('--device', type=str, default='cpu',
538
- help='Device (cpu or cuda)')
539
- parser.add_argument('--eval-only', action='store_true',
540
- help='Only evaluate, do not augment')
541
- args = parser.parse_args()
542
-
543
- print("=" * 70)
544
- print(" CIRCUIT-AUGMENTED LLM")
545
- print("=" * 70)
546
-
547
- # Load tokenizer and model
548
- print("\n[1] Loading SmolLM2-360M...")
549
- model_id = "HuggingFaceTB/SmolLM2-360M"
550
- tokenizer = AutoTokenizer.from_pretrained(model_id)
551
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
552
-
553
- print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
554
-
555
- # Baseline evaluation
556
- print("\n[2] Baseline arithmetic evaluation...")
557
- baseline = evaluate_arithmetic(model, tokenizer, n_problems=50, device=args.device)
558
- print(f" Accuracy: {baseline['accuracy']*100:.1f}% ({baseline['correct']}/{baseline['total']})")
559
- if baseline['errors']:
560
- print(f" Sample errors:")
561
- for a, b, exp, got in baseline['errors'][:5]:
562
- print(f" {a} + {b} = {exp}, model said {got}")
563
-
564
- if args.eval_only:
565
- print("\nDone (eval only mode).")
566
- exit(0)
567
-
568
- # Augment with circuits
569
- print(f"\n[3] Augmenting with threshold circuits...")
570
- print(f" Circuit path: {args.circuit_path}")
571
- model = augment_smollm2_with_circuits(
572
- model,
573
- args.circuit_path,
574
- device=args.device
575
- )
576
-
577
- new_params = sum(p.numel() for p in model.parameters())
578
- trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
579
- print(f" Total parameters: {new_params:,}")
580
- print(f" Trainable parameters: {trainable:,}")
581
-
582
- # Test circuit execution directly
583
- print("\n[4] Testing circuit execution...")
584
- circuit_exec = CircuitExecutor(args.circuit_path, args.device)
585
-
586
- test_cases = [(127, 128), (255, 1), (0, 0), (100, 55)]
587
- for a, b in test_cases:
588
- # Convert to bits (LSB first)
589
- a_bits = torch.tensor([(a >> i) & 1 for i in range(8)], dtype=torch.float32)
590
- b_bits = torch.tensor([(b >> i) & 1 for i in range(8)], dtype=torch.float32)
591
-
592
- result_bits, carry = circuit_exec.add_8bit(
593
- a_bits.unsqueeze(0),
594
- b_bits.unsqueeze(0)
595
- )
596
-
597
- # Convert result bits back to int
598
- result = sum(int(result_bits[0, i].item()) * (2**i) for i in range(8))
599
- expected = (a + b) % 256
600
-
601
- status = "OK" if result == expected else "FAIL"
602
- print(f" {a} + {b} = {result} (expected {expected}) [{status}]")
603
-
604
- print("\n[5] Model ready for fine-tuning.")
605
- print(" Next: Train interface layers on arithmetic examples.")
606
- print("=" * 70)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
guide.md DELETED
@@ -1,615 +0,0 @@
1
- # Embedding Threshold Logic Circuits into Transformer MLPs
2
-
3
- ## Technical Implementation Guide
4
-
5
- ---
6
-
7
- ## 1. Core Thesis
8
-
9
- Standard LLMs fail at arithmetic because they're interpolators—they approximate functions over training distributions rather than compute exact results. A 360M parameter model trained on internet text has seen "127 + 128 = 255" zero or few times, so it guesses "140" based on pattern matching.
10
-
11
- We solve this by embedding **frozen, proven-correct arithmetic circuits** directly into the transformer's MLP layers. The circuits use threshold logic (weighted sums + step activation), which is structurally compatible with neural network layers. We train only the **interface layers** that learn to:
12
-
13
- 1. Extract operands from token embeddings
14
- 2. Route computation through the circuits
15
- 3. Inject results back into the residual stream
16
-
17
- The model learns **call dispatch**, not arithmetic. The arithmetic is already solved.
18
-
19
- ---
20
-
21
- ## 2. Threshold Logic Fundamentals
22
-
23
- ### 2.1 Single Threshold Gate
24
-
25
- A threshold gate computes:
26
-
27
- ```
28
- output = 1 if (Σ wᵢxᵢ + b) ≥ 0
29
- 0 otherwise
30
- ```
31
-
32
- This is a neuron with Heaviside step activation. With integer weights `w` and bias `b`, it computes a Boolean function of binary inputs.
33
-
34
- **Example: AND gate**
35
- ```
36
- w = [1, 1], b = -2
37
- AND(0,0) = H(0 + 0 - 2) = H(-2) = 0
38
- AND(0,1) = H(0 + 1 - 2) = H(-1) = 0
39
- AND(1,0) = H(1 + 0 - 2) = H(-1) = 0
40
- AND(1,1) = H(1 + 1 - 2) = H(0) = 1
41
- ```
42
-
43
- **Example: OR gate**
44
- ```
45
- w = [1, 1], b = -1
46
- OR(0,0) = H(0 + 0 - 1) = H(-1) = 0
47
- OR(0,1) = H(0 + 1 - 1) = H(0) = 1
48
- OR(1,0) = H(1 + 0 - 1) = H(0) = 1
49
- OR(1,1) = H(1 + 1 - 1) = H(1) = 1
50
- ```
51
-
52
- ### 2.2 Multi-Layer Circuits
53
-
54
- XOR is not linearly separable—it requires two layers:
55
-
56
- ```
57
- Layer 1:
58
- neuron1 (OR): w=[1,1], b=-1 → fires if a OR b
59
- neuron2 (NAND): w=[-1,-1], b=1 → fires if NOT(a AND b)
60
-
61
- Layer 2:
62
- neuron3 (AND): w=[1,1], b=-2 → fires if both layer1 outputs are 1
63
-
64
- XOR(a,b) = AND(OR(a,b), NAND(a,b))
65
- ```
66
-
67
- ### 2.3 Full Adder
68
-
69
- A full adder computes `sum` and `carry_out` from inputs `a`, `b`, `carry_in`:
70
-
71
- ```
72
- sum = a XOR b XOR cin
73
- cout = (a AND b) OR (cin AND (a XOR b))
74
- ```
75
-
76
- Implementation uses two half-adders chained:
77
-
78
- ```
79
- HA1: (a, b) → (sum1 = a XOR b, carry1 = a AND b)
80
- HA2: (sum1, cin) → (sum2 = sum1 XOR cin, carry2 = sum1 AND cin)
81
- cout = carry1 OR carry2
82
- final_sum = sum2
83
- ```
84
-
85
- Each XOR is 2 layers, each AND/OR is 1 layer. Total depth: ~4 layers per full adder.
86
-
87
- ### 2.4 8-bit Ripple Carry Adder
88
-
89
- Chain 8 full adders, propagating carry:
90
-
91
- ```
92
- FA0: (a[0], b[0], 0) → (sum[0], c0)
93
- FA1: (a[1], b[1], c0) → (sum[1], c1)
94
- FA2: (a[2], b[2], c1) → (sum[2], c2)
95
- ...
96
- FA7: (a[7], b[7], c6) → (sum[7], c7)
97
- ```
98
-
99
- Total circuit depth: ~32 threshold layers (8 FAs × 4 layers each).
100
-
101
- ---
102
-
103
- ## 3. Circuit Inventory
104
-
105
- The `neural_computer.safetensors` contains 3,122 tensors / 5,648 parameters implementing:
106
-
107
- | Category | Circuits | Tensors |
108
- |----------|----------|---------|
109
- | Boolean | AND, OR, NOT, NAND, NOR, XOR, XNOR, IMPLIES, BIIMPLIES | ~30 |
110
- | Arithmetic | Half adder, Full adder, Ripple carry 2/4/8-bit, 8×8 multiplier | ~800 |
111
- | Comparators | GT, LT, GEQ, LEQ, EQ (8-bit) | ~50 |
112
- | ALU | 16-operation ALU, opcode decoder, flag computation | ~400 |
113
- | Control | JMP, JZ, JNZ, JC, JNC, JN, JP, CALL, RET, PUSH, POP | ~200 |
114
- | Modular | Divisibility by 2-12 | ~600 |
115
- | Error Detection | Parity, CRC, Hamming, checksum | ~200 |
116
- | Pattern | Popcount, leading zeros, symmetry | ~150 |
117
- | Threshold | k-of-n gates, majority, minority | ~100 |
118
-
119
- All weights are integers. All activations are Heaviside. Verified with 6,590 exhaustive tests.
120
-
121
- ---
122
-
123
- ## 4. Transformer Integration Architecture
124
-
125
- ### 4.1 Target: SmolLM2-360M
126
-
127
- ```
128
- Architecture: LlamaForCausalLM
129
- Hidden dim: 960
130
- Layers: 32
131
- Heads: 15
132
- MLP expansion: ~2.67x (intermediate = 2560)
133
- Vocab: 49152
134
- Parameters: 361,821,120
135
- ```
136
-
137
- Standard MLP block:
138
- ```python
139
- def forward(x): # x: [batch, seq, 960]
140
- gate = self.gate_proj(x) # [batch, seq, 2560]
140
- up = self.up_proj(x) # [batch, seq, 2560]
142
- hidden = silu(gate) * up # SwiGLU activation
143
- return self.down_proj(hidden) # [batch, seq, 960]
144
- ```
145
-
146
- ### 4.2 Augmented MLP Block
147
-
148
- ```python
149
- def forward(x): # x: [batch, seq, 960]
150
- # Original MLP path (unchanged)
151
- mlp_out = self.down_proj(silu(self.gate_proj(x)) * self.up_proj(x))
152
-
153
- # Circuit path (new)
154
- a_bits, b_bits = self.bit_extractor(x) # [batch, seq, 8] each
155
- result_bits, carry = self.circuits.add_8bit(a_bits, b_bits)
156
- flags = self.compute_flags(result_bits, carry)
157
- circuit_delta = self.bit_injector(result_bits, flags)
158
-
159
- # Routing
160
- route_weights = self.router(x) # [batch, seq, 2] softmax
161
-
162
- # Combine
163
- return mlp_out + route_weights[..., 1:2] * circuit_delta
164
- ```
165
-
166
- ### 4.3 Layer Selection
167
-
168
- We augment the **middle third** of layers (10-20 of 32):
169
-
170
- - Early layers (0-9): Token/position encoding, not arithmetic-relevant
171
- - Middle layers (10-20): Abstract reasoning, computation
172
- - Late layers (21-31): Output formatting, vocabulary projection
173
-
174
- Rationale: Arithmetic computation happens in middle layers where the model processes relationships between tokens. Early layers haven't built sufficient representations; late layers are committed to output tokens.
175
-
176
- ---
177
-
178
- ## 5. Interface Layers (Trainable)
179
-
180
- ### 5.1 BitExtractor
181
-
182
- Maps token embedding → two 8-bit operands.
183
-
184
- ```python
185
- class BitExtractor(nn.Module):
186
- def __init__(self, d_model=960):
187
- self.proj = nn.Linear(d_model, 16) # 960 → 16
188
-
189
- def forward(self, x):
190
- logits = self.proj(x) # [batch, seq, 16]
191
- bits = heaviside(logits) # binarize with STE
192
- a_bits = bits[..., :8] # first operand
193
- b_bits = bits[..., 8:] # second operand
194
- return a_bits, b_bits # both [batch, seq, 8], LSB first
195
- ```
196
-
197
- **What it learns**: Which embedding dimensions encode numeric magnitude. For token "127", it must learn that certain activation patterns correspond to bits `[1,1,1,1,1,1,1,0]`.
198
-
199
- **Parameters**: 960 × 16 + 16 = 15,376
200
-
201
- ### 5.2 BitInjector
202
-
203
- Maps circuit outputs → embedding delta.
204
-
205
- ```python
206
- class BitInjector(nn.Module):
207
- def __init__(self, d_model=960):
208
- self.proj = nn.Linear(16, d_model) # 16 → 960
209
- self.scale = nn.Parameter(torch.tensor(0.1))
210
-
211
- def forward(self, result_bits, flags):
212
- combined = torch.cat([result_bits, flags], dim=-1) # [batch, seq, 16]
213
- return self.proj(combined) * self.scale # [batch, seq, 960]
214
- ```
215
-
216
- **What it learns**: How to inject the result bits back into embedding space such that subsequent layers (and the final vocabulary projection) produce the correct output tokens.
217
-
218
- **Parameters**: 16 × 960 + 960 + 1 = 16,321
219
-
220
- ### 5.3 Router
221
-
222
- Decides when to use circuit path.
223
-
224
- ```python
225
- class Router(nn.Module):
226
- def __init__(self, d_model=960):
227
- self.net = nn.Sequential(
228
- nn.Linear(d_model, 64),
229
- nn.ReLU(),
230
- nn.Linear(64, 2),
231
- nn.Softmax(dim=-1)
232
- )
233
-
234
- def forward(self, x):
235
- return self.net(x) # [batch, seq, 2]: [mlp_weight, circuit_weight]
236
- ```
237
-
238
- **What it learns**: "This position contains arithmetic" → route through circuits. "This is prose" → use normal MLP.
239
-
240
- **Parameters**: 960 × 64 + 64 + 64 × 2 + 2 = 61,634
241
-
242
- ### 5.4 Total Trainable Parameters
243
-
244
- Per augmented layer:
245
- ```
246
- BitExtractor: 15,376
247
- BitInjector: 16,321
248
- Router: 61,634
248
- OpSelector: ~31,000
249
- ───────────────────────
250
- Total: ~124,331 per layer
252
- ```
253
-
254
- For 11 augmented layers: **~1.37M trainable parameters**
255
-
256
- This is 0.38% of the model. The other 99.62% (including all circuit weights) is frozen.
257
-
258
- ---
259
-
260
- ## 6. Gradient Flow Through Heaviside
261
-
262
- ### 6.1 The Problem
263
-
264
- Heaviside has zero gradient almost everywhere:
265
-
266
- ```
267
- H(x) = 1 if x ≥ 0 else 0
268
- dH/dx = 0 for x ≠ 0, undefined at x = 0
269
- ```
270
-
271
- Standard backprop would give zero gradients to BitExtractor.
272
-
273
- ### 6.2 Straight-Through Estimator (STE)
274
-
275
- We use STE: forward pass uses true Heaviside, backward pass pretends it's identity.
276
-
277
- ```python
278
- class HeavisideSTE(torch.autograd.Function):
279
- @staticmethod
280
- def forward(ctx, x):
281
- return (x >= 0).float() # true step function
282
-
283
- @staticmethod
284
- def backward(ctx, grad_output):
285
- return grad_output # pass gradient through unchanged
286
- ```
287
-
288
- **Intuition**: "If making the input larger would have helped the output, increase the input." The gradient tells us the direction even though the function is flat.
289
-
290
- ### 6.3 Alternative: Sigmoid Annealing
291
-
292
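- During training, use a sigmoid with an increasing sharpness factor (named `temperature` below; it multiplies the input, so it acts as an inverse temperature):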
293
-
294
- ```python
295
- def soft_heaviside(x, temperature):
296
- return torch.sigmoid(x * temperature)
297
-
298
- # temperature: 1 → 10 → 100 over training
299
- # At high temperature, sigmoid ≈ step function
300
- ```
301
-
302
- This provides smoother gradients early in training, then sharpens to true binary at inference.
303
-
304
- ---
305
-
306
- ## 7. Training Strategy
307
-
308
- ### 7.1 Data Generation
309
-
310
- Generate arithmetic problems by uniform random sampling:
311
-
312
- ```python
313
- def generate_batch(batch_size):
314
- a = torch.randint(0, 256, (batch_size,))
315
- b = torch.randint(0, 256, (batch_size,))
316
- result = (a + b) % 256
317
-
318
- prompts = [f"{a[i]} + {b[i]} =" for i in range(batch_size)]
319
- targets = [f" {result[i]}" for i in range(batch_size)]
320
-
321
- return prompts, targets
322
- ```
323
-
324
- For 8-bit addition, there are 256 × 256 = 65,536 unique problems. We can cover the entire space.
325
-
326
- ### 7.2 Loss Function
327
-
328
- Standard cross-entropy on next-token prediction:
329
-
330
- ```python
331
- outputs = model(input_ids, attention_mask=mask, labels=labels)
332
- loss = outputs.loss # CE loss, only on target tokens
333
- ```
334
-
335
- Labels are masked for prompt tokens (`-100`), so loss only backprops through the answer.
336
-
337
- ### 7.3 Optimizer Configuration
338
-
339
- ```python
340
- # Only train interface layers
341
- interface_params = [p for n, p in model.named_parameters()
342
- if any(x in n for x in ['bit_extractor', 'bit_injector', 'router'])]
343
-
344
- optimizer = AdamW(interface_params, lr=1e-4, weight_decay=0.01)
345
- scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
346
- ```
347
-
348
- ### 7.4 Curriculum Learning
349
-
350
- Start simple, increase difficulty:
351
-
352
- ```
353
- Phase 1 (epochs 1-2): Single-digit addition (0-9 + 0-9)
354
- Phase 2 (epochs 3-4): Two-digit addition (0-99 + 0-99)
355
- Phase 3 (epochs 5-7): Full 8-bit addition (0-255 + 0-255)
356
- Phase 4 (epochs 8-10): Adversarial cases (carry chains: 127+1, 255+1; no-carry boundary: 127+128)
357
- ```
358
-
359
- This helps the interface layers learn the basic extraction pattern before tackling hard cases.
360
-
361
- ### 7.5 Training Hyperparameters
362
-
363
- ```
364
- Model: SmolLM2-360M
365
- Augmented: Layers 10-20 (11 layers)
366
- Trainable: 1.37M parameters
367
- Frozen: 362M parameters (including 5.6K circuit params)
368
-
369
- Batch size: 32
370
- Learning rate: 1e-4
371
- Epochs: 10
372
- Samples: 10,000 per epoch
373
- Warmup: 500 steps
374
- Device: RTX 6000 Ada (48GB)
375
-
376
- Expected time: ~30 minutes total
377
- ```
378
-
379
- ---
380
-
381
- ## 8. Forward Pass Walkthrough
382
-
383
- Input: `"127 + 128 ="`
384
-
385
- ### 8.1 Tokenization
386
-
387
- ```
388
- Tokens: ["127", " +", " 128", " ="]
389
- IDs: [12700, 489, 13824, 284] # hypothetical
390
- ```
391
-
392
- ### 8.2 Embedding
393
-
394
- ```
395
- embeddings = embed(input_ids) # [1, 4, 960]
396
- ```
397
-
398
- ### 8.3 Layers 0-9 (Unchanged)
399
-
400
- Standard attention + MLP, building representations.
401
-
402
- ### 8.4 Layer 10 (Augmented)
403
-
404
- ```python
405
- # After attention
406
- x = layer_norm(attn_output + residual) # [1, 4, 960]
407
-
408
- # MLP path
409
- mlp_out = down_proj(silu(gate_proj(x)) * up_proj(x))
410
-
411
- # Circuit path
412
- a_bits, b_bits = bit_extractor(x)
413
- # Position 0 ("127"): a_bits ≈ [1,1,1,1,1,1,1,0] if well-trained
414
- # Position 2 ("128"): b_bits ≈ [0,0,0,0,0,0,0,1]
415
- # (In practice, extraction happens per-position; aggregation is learned)
416
-
417
- result_bits, carry = circuits.add_8bit(a_bits, b_bits)
418
- # result_bits = [1,1,1,1,1,1,1,1] = 255
419
-
420
- flags = compute_flags(result_bits, carry)
421
- # zero=0, negative=1, carry=0 (127 + 128 = 255 produces no carry-out)
422
-
423
- circuit_delta = bit_injector(result_bits, flags) # [1, 4, 960]
424
-
425
- # Routing
426
- route = router(x) # [1, 4, 2]
427
- # Position 3 ("="): route ≈ [0.1, 0.9] → use circuits
428
- # Position 1 ("+"): route ≈ [0.8, 0.2] → mostly MLP
429
-
430
- # Combine
431
- output = mlp_out + route[..., 1:2] * circuit_delta
432
- ```
433
-
434
- ### 8.5 Layers 11-31
435
-
436
- Continue processing, eventually projecting to vocabulary.
437
-
438
- ### 8.6 Output
439
-
440
- ```
441
- logits = lm_head(final_hidden) # [1, 4, 49152]
442
- next_token = argmax(logits[0, 3, :]) # token after "="
443
- # Should decode to "255" (possibly as " 255" or "255")
444
- ```
445
-
446
- ---
447
-
448
- ## 9. Inference Characteristics
449
-
450
- ### 9.1 Exactness
451
-
452
- At inference, Heaviside is a true step function—no approximation. If BitExtractor correctly maps "127" → bits and "128" → bits, the circuit **will** output 255. The only failure mode is incorrect extraction.
453
-
454
- ### 9.2 Latency
455
-
456
- Circuit computation adds ~5-10% overhead:
457
- - BitExtractor: 1 linear layer (960→16)
458
- - Circuits: ~32 threshold layers, but sparse and tiny
459
- - BitInjector: 1 linear layer (16→960)
460
- - Router: 2 linear layers
461
-
462
- The circuits have only 5,648 parameters total—negligible versus the 361M in the base model.
463
-
464
- ### 9.3 Generalization
465
-
466
- Once the interface learns the mapping, it generalizes to **all** 65,536 8-bit additions. There's no memorization—the circuits compute.
467
-
468
- ---
469
-
470
- ## 10. Evaluation Metrics
471
-
472
- ### 10.1 Arithmetic Accuracy
473
-
474
- ```python
475
- def eval_accuracy(model, n_problems=1000):
476
- correct = 0
477
- for _ in range(n_problems):
478
- a, b = random 8-bit values
479
- expected = (a + b) % 256
480
- predicted = model.generate(f"{a} + {b} =")
481
- if parse_int(predicted) == expected:
482
- correct += 1
483
- return correct / n_problems
484
- ```
485
-
486
- **Baseline SmolLM2**: ~5-10% (guessing based on patterns)
487
- **Target**: >95% (circuit-accurate)
488
-
489
- ### 10.2 Edge Case Performance
490
-
491
- Specifically test:
492
- - Carry propagation: 127+1, 255+1, 128+128 (plus the no-carry boundary 127+128)
493
- - Zeros: 0+0, 0+255
494
- - Identity: x+0 for various x
495
- - Commutativity: verify a+b == b+a
496
-
497
- ### 10.3 Non-Arithmetic Preservation
498
-
499
- Verify general capability isn't degraded:
500
- - Perplexity on held-out text
501
- - Common benchmarks (HellaSwag, etc.)
502
-
503
- The augmentation should be **additive**—circuits help arithmetic, MLP handles everything else via routing.
504
-
505
- ---
506
-
507
- ## 11. Extension Roadmap
508
-
509
- ### 11.1 Additional Operations
510
-
511
- The circuit inventory includes:
512
- - Subtraction (via two's complement)
513
- - Multiplication (8×8 → 16-bit)
514
- - Division (iterative subtraction)
515
- - Bitwise ops (AND, OR, XOR, shifts)
516
- - Comparisons (GT, LT, EQ)
517
-
518
- Each needs its own extraction/injection interface, or a unified interface with operation selection.
519
-
520
- ### 11.2 Multi-Operand Expressions
521
-
522
- For "15 + 27 + 33 =", need:
523
- - Operand count detection
524
- - Sequential circuit invocation
525
- - Accumulator pattern
526
-
527
- ### 11.3 Larger Bit Widths
528
-
529
- 16-bit and 32-bit arithmetic require:
530
- - Larger circuits (or chained 8-bit)
531
- - Wider BitExtractor (32 or 64 output dims)
532
- - More training data
533
-
534
- ### 11.4 Symbolic Integration
535
-
536
- Ultimate goal: the model recognizes when it needs to compute, invokes circuits, and integrates results into coherent natural language output.
537
-
538
- ```
539
- User: "If I have 127 apples and buy 128 more, how many do I have?"
540
- Model: [extracts 127, 128] [routes to circuit] [gets 255]
541
- "You would have 255 apples."
542
- ```
543
-
544
- ---
545
-
546
- ## 12. File Structure
547
-
548
- ```
549
- 8bit-threshold-computer/
550
- ├── neural_computer.safetensors # Frozen circuits (3,122 tensors)
551
- ├── circuit_llm.py # Integration architecture
552
- ├── train_circuit_interface.py # Training loop
553
- ├── iron_eval.py # Circuit verification (6,590 tests)
554
- ├── skeptic_test.py # Algebraic identity tests (127 tests)
555
- ├── prune_weights.py # Weight optimization
556
- ├── tensors.txt # Tensor manifest
557
- ├── guide.md # This document
558
- └── README.md # Project overview
559
- ```
560
-
561
- ---
562
-
563
- ## 13. Key Equations
564
-
565
- ### Heaviside Step
566
- ```
567
- H(x) = 1 if x ≥ 0 else 0
568
- ```
569
-
570
- ### Threshold Gate
571
- ```
572
- f(x₁,...,xₙ) = H(Σᵢ wᵢxᵢ + b)
573
- ```
574
-
575
- ### Full Adder
576
- ```
577
- sum = a ⊕ b ⊕ cᵢₙ
578
- cₒᵤₜ = (a ∧ b) ∨ (cᵢₙ ∧ (a ⊕ b))
579
- ```
580
-
581
- ### STE Gradient
582
- ```
583
- Forward: y = H(x)
584
- Backward: ∂L/∂x = ∂L/∂y
585
- ```
586
-
587
- ### Router Combination
588
- ```
589
- output = mlp_out + softmax(router(x))[1] × circuit_delta
590
- ```
591
-
592
- ---
593
-
594
- ## 14. References
595
-
596
- 1. McCulloch & Pitts (1943). "A Logical Calculus of the Ideas Immanent in Nervous Activity"
597
- 2. Muroga (1971). "Threshold Logic and Its Applications"
598
- 3. Siegelmann & Sontag (1995). "On the Computational Power of Neural Nets"
599
- 4. Bengio et al. (2013). "Estimating or Propagating Gradients Through Stochastic Neurons"
600
- 5. Ma et al. (2024). "The Era of 1-bit LLMs" (BitNet b1.58)
601
- 6. HuggingFace (2024). "SmolLM2: Small Language Models"
602
-
603
- ---
604
-
605
- ## 15. Summary
606
-
607
- We embed a proven-correct 8-bit threshold logic computer into SmolLM2's MLP layers. The circuits are frozen; we train only the interface layers that learn call dispatch. This gives the LLM exact arithmetic capability without training it to "do math"—the math is already done.
608
-
609
- The approach is:
610
- - **Sound**: Circuits verified with 6,590 tests
611
- - **Efficient**: 1.37M trainable params, 5.6K circuit params
612
- - **Exact**: Heaviside at inference means no approximation error
613
- - **Composable**: Add more circuits (multiply, compare, etc.) with same pattern
614
-
615
- The model learns when to call the calculator, not how to calculate.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
llm/todo.md DELETED
@@ -1,46 +0,0 @@
1
- Step 1: Install dependencies
2
-
3
- pip install torch transformers safetensors tqdm
4
-
5
- Step 2: Test circuit execution standalone
6
-
7
- cd C:\Users\cnort\8bit-threshold-computer\llm
8
- python circuit_llm.py --circuit-path ../neural_computer.safetensors --device cpu
9
-
10
- Expected output:
11
- - Load SmolLM2-360M
12
- - Baseline arithmetic evaluation (~1-10% accuracy)
13
- - Augment layers 10-20 with threshold circuits
14
- - Test circuit execution (127+128=255, etc.)
15
-
16
- Step 3: Train interface layers (CPU)
17
-
18
- python train_circuit_interface.py \
19
- --circuit-path ../neural_computer.safetensors \
20
- --device cpu \
21
- --epochs 3 \
22
- --batch-size 4 \
23
- --n-samples 2000
24
-
25
- Expected:
26
- - Freeze 362M params
27
- - Train ~1.37M interface params (BitExtractor, BitInjector, Router)
28
- - 3 epochs on arithmetic examples (reduced samples for CPU)
29
- - Accuracy should improve from ~10% to >90%
30
- - Training time: ~1-2 hours on CPU
31
-
32
- Step 4: Save trained model
33
-
34
- Output: circuit_augmented_smollm2.pt
35
-
36
- Hardware Requirements
37
-
38
- - GPU recommended (RTX 6000 Ada per guide.md)
39
- - CPU possible but slow (~10x longer)
40
- - ~4GB VRAM for SmolLM2-360M
41
-
42
- Success Criteria
43
-
44
- - Baseline accuracy: ~1-10% (SmolLM2 guessing)
45
- - Post-training accuracy: >98% (circuits computing)
46
- - Circuit tests pass: 127+128=255, 255+1=0, etc.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prune_weights.py DELETED
@@ -1,481 +0,0 @@
1
- """
2
- BATCHED WEIGHT PRUNING (GPU-optimized)
3
- ======================================
4
- Phase 1: Batch eval all candidates in parallel
5
- Phase 2: Apply all successes at once, binary search if conflicts
6
- """
7
-
8
- import torch
9
- import time
10
- import argparse
11
- from safetensors.torch import save_file
12
- from iron_eval import BatchedFitnessEvaluator, create_population, load_model
13
-
14
- torch.manual_seed(0)
15
-
16
-
17
- def format_time(seconds):
18
- if seconds < 60:
19
- return f"{seconds:.1f}s"
20
- elif seconds < 3600:
21
- return f"{seconds/60:.1f}m"
22
- else:
23
- return f"{seconds/3600:.1f}h"
24
-
25
-
26
- def format_eta(elapsed, done, total):
27
- if done == 0:
28
- return "calculating..."
29
- rate = done / elapsed
30
- remaining = (total - done) / rate
31
- return format_time(remaining)
32
-
33
-
34
- def apply_reductions(model, reductions):
35
- """Apply a list of (name, flat_idx, shape, old_val) reductions."""
36
- for name, flat_idx, shape, old_val in reductions:
37
- new_val = old_val - 1 if old_val > 0 else old_val + 1
38
- flat = model[name].flatten()
39
- if flat[flat_idx].item() == old_val:
40
- flat[flat_idx] = new_val
41
- model[name] = flat.view(shape)
42
-
43
-
44
- def revert_reductions(model, reductions):
45
- """Revert a list of reductions."""
46
- for name, flat_idx, shape, old_val in reductions:
47
- flat = model[name].flatten()
48
- new_val = old_val - 1 if old_val > 0 else old_val + 1
49
- if flat[flat_idx].item() == new_val:
50
- flat[flat_idx] = old_val
51
- model[name] = flat.view(shape)
52
-
53
-
54
- def check_fitness(model, evaluator, device):
55
- """Check model fitness."""
56
- torch.manual_seed(0)
57
- pop = create_population(model, 1, device)
58
- return evaluator.evaluate(pop, debug=False)[0].item()
59
-
60
-
61
- def sequential_conflict_resolution(model, evaluator, device, candidates, base_magnitude):
62
- """
63
- Sequential fallback - tests and applies reductions one at a time.
64
- Slower but guarantees no interaction bugs.
65
- """
66
- accepted = []
67
- for i, (name, flat_idx, shape, old_val) in enumerate(candidates):
68
- apply_reductions(model, [(name, flat_idx, shape, old_val)])
69
- fitness = check_fitness(model, evaluator, device)
70
- if fitness >= 0.9999:
71
- accepted.append((name, flat_idx, shape, old_val))
72
- if (i + 1) % 50 == 0:
73
- current_mag = sum(t.abs().sum().item() for t in model.values())
74
- reduction_pct = 100 * (1 - current_mag / base_magnitude)
75
- print(f" Sequential: {len(accepted)}/{i+1} accepted | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
76
- else:
77
- revert_reductions(model, [(name, flat_idx, shape, old_val)])
78
- return accepted
79
-
80
-
81
- def batched_conflict_resolution(model, evaluator, device, candidates, base_magnitude):
82
- """
83
- Batched binary search - evaluates multiple branches in parallel.
84
- Uses BFS instead of DFS to maximize batching opportunities.
85
- Verifies cumulative effect after each batch to prevent interaction bugs.
86
- """
87
- if len(candidates) == 0:
88
- return []
89
-
90
- # First try all at once
91
- print(f" Trying {len(candidates)} reductions at once...")
92
- apply_reductions(model, candidates)
93
- fitness = check_fitness(model, evaluator, device)
94
-
95
- if fitness >= 0.9999:
96
- current_mag = sum(t.abs().sum().item() for t in model.values())
97
- reduction_pct = 100 * (1 - current_mag / base_magnitude)
98
- print(f" ALL {len(candidates)} OK | fitness={fitness:.6f} | "
99
- f"mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
100
- return candidates
101
-
102
- # Conflict - revert and use batched BFS
103
- revert_reductions(model, candidates)
104
- print(f" CONFLICT (fitness={fitness:.6f}), starting batched resolution...")
105
-
106
- accepted = []
107
- # Queue of (candidate_list, depth) to process
108
- pending = [(candidates, 0)]
109
-
110
- while pending:
111
- # Collect all pending groups for batch evaluation
112
- to_eval = []
113
- for group, depth in pending:
114
- if len(group) == 0:
115
- continue
116
- elif len(group) == 1:
117
- to_eval.append((group, depth, 'single'))
118
- else:
119
- to_eval.append((group, depth, 'group'))
120
-
121
- pending = []
122
-
123
- if not to_eval:
124
- break
125
-
126
- # Build batch: create model variants for each group
127
- batch_size = len(to_eval)
128
- print(f" Batch evaluating {batch_size} groups...")
129
-
130
- # Create population for batch eval
131
- pop = {}
132
- for name, tensor in model.items():
133
- pop[name] = tensor.unsqueeze(0).expand(batch_size, *tensor.shape).clone().to(device)
134
-
135
- # Apply each group's reductions to its population slot
136
- for idx, (group, depth, gtype) in enumerate(to_eval):
137
- for name, flat_idx, shape, old_val in group:
138
- new_val = old_val - 1 if old_val > 0 else old_val + 1
139
- flat_view = pop[name][idx].flatten()
140
- # Check if not already modified in base model
141
- base_val = model[name].flatten()[flat_idx].item()
142
- if base_val == old_val:
143
- flat_view[flat_idx] = new_val
144
-
145
- # Batch evaluate
146
- torch.manual_seed(0)
147
- fitnesses = evaluator.evaluate(pop, debug=False)
148
-
149
- # Process results - collect accepted groups first, then verify
150
- batch_accepted = []
151
- ok_count = 0
152
- conflict_count = 0
153
- fail_count = 0
154
-
155
- for idx, (group, depth, gtype) in enumerate(to_eval):
156
- fit = fitnesses[idx].item()
157
- indent = " " + " " * depth
158
-
159
- if fit >= 0.9999:
160
- batch_accepted.append((group, depth, indent))
161
- ok_count += len(group)
162
- else:
163
- if len(group) == 1:
164
- name, flat_idx, shape, old_val = group[0]
165
- print(f"{indent}[1/1] FAIL {name}[{flat_idx}] | fitness={fit:.6f}")
166
- fail_count += 1
167
- else:
168
- mid = len(group) // 2
169
- left = group[:mid]
170
- right = group[mid:]
171
- print(f"{indent}CONFLICT ({len(group)}) fitness={fit:.6f} -> split {len(left)}+{len(right)}")
172
- pending.append((left, depth + 1))
173
- pending.append((right, depth + 1))
174
- conflict_count += 1
175
-
176
- # Apply all batch-accepted reductions
177
- all_batch_reductions = []
178
- for group, depth, indent in batch_accepted:
179
- apply_reductions(model, group)
180
- all_batch_reductions.extend(group)
181
-
182
- # Verify cumulative effect
183
- if all_batch_reductions:
184
- verify_fitness = check_fitness(model, evaluator, device)
185
- if verify_fitness >= 0.9999:
186
- # All good - commit these reductions
187
- for group, depth, indent in batch_accepted:
188
- current_mag = sum(t.abs().sum().item() for t in model.values())
189
- reduction_pct = 100 * (1 - current_mag / base_magnitude)
190
- if len(group) == 1:
191
- name, flat_idx, shape, old_val = group[0]
192
- print(f"{indent}[1/1] OK {name}[{flat_idx}] | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
193
- else:
194
- print(f"{indent}ALL {len(group)} OK | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
195
- accepted.extend(all_batch_reductions)
196
- print(f" Batch result: {ok_count} accepted, {conflict_count} split, {fail_count} failed")
197
- else:
198
- # Interaction bug detected - revert and use sequential fallback
199
- print(f" INTERACTION BUG detected (batch fitness={verify_fitness:.6f})")
200
- print(f" Reverting {len(all_batch_reductions)} reductions, falling back to sequential...")
201
- revert_reductions(model, all_batch_reductions)
202
-
203
- # Process each group sequentially
204
- seq_accepted = sequential_conflict_resolution(
205
- model, evaluator, device, all_batch_reductions, base_magnitude
206
- )
207
- accepted.extend(seq_accepted)
208
- print(f" Sequential fallback: {len(seq_accepted)}/{len(all_batch_reductions)} accepted")
209
- else:
210
- print(f" Batch result: {ok_count} accepted, {conflict_count} split, {fail_count} failed")
211
-
212
- return accepted
213
-
214
-
215
- def prune_weights(
216
- passes: int = 10,
217
- batch_size: int = 5000,
218
- device: str = 'cuda',
219
- checkpoint_path: str = "D:/8bit-threshold-computer/pruned.safetensors"
220
- ):
221
- print("=" * 80)
222
- print(" BATCHED WEIGHT PRUNING (GPU-optimized)")
223
- print("=" * 80)
224
- print(f" Device: {device}")
225
- print(f" Batch size: {batch_size}")
226
- print(f" Max passes: {passes}")
227
- print("=" * 80)
228
-
229
- # Load model
230
- print("\n[1/4] LOADING MODEL...")
231
- load_start = time.perf_counter()
232
- model = load_model()
233
- load_time = time.perf_counter() - load_start
234
-
235
- n_params = sum(t.numel() for t in model.values())
236
- n_tensors = len(model)
237
- base_magnitude = sum(t.abs().sum().item() for t in model.values())
238
- base_max = max(t.abs().max().item() for t in model.values())
239
- nonzero_params = sum((t != 0).sum().item() for t in model.values())
240
-
241
- print(f" Loaded in {load_time:.2f}s")
242
- print(f" Tensors: {n_tensors}")
243
- print(f" Parameters: {n_params}")
244
- print(f" Non-zero parameters: {nonzero_params}")
245
- print(f" Total magnitude: {base_magnitude:.0f}")
246
- print(f" Max weight: {base_max:.0f}")
247
-
248
- # Initialize evaluator
249
- print("\n[2/4] INITIALIZING EVALUATOR...")
250
- eval_start = time.perf_counter()
251
- evaluator = BatchedFitnessEvaluator(device=device)
252
- eval_time = time.perf_counter() - eval_start
253
- print(f" Initialized in {eval_time:.2f}s")
254
-
255
- # Verify initial fitness
256
- print("\n[3/4] VERIFYING BASE MODEL...")
257
- initial_fitness = check_fitness(model, evaluator, device)
258
- print(f" Fitness: {initial_fitness:.6f}")
259
-
260
- if initial_fitness < 0.9999:
261
- print(f" ERROR: Base model fitness {initial_fitness:.6f} < 0.9999")
262
- return None
263
-
264
- print(f" STATUS: PASS")
265
-
266
- # Build parameter list
267
- print("\n[4/4] BUILDING PARAMETER INDEX...")
268
- param_list = []
269
- for name, tensor in model.items():
270
- flat = tensor.flatten()
271
- for i in range(len(flat)):
272
- param_list.append((name, i, tensor.shape))
273
- print(f" Indexed {len(param_list)} parameters")
274
-
275
- # Main pruning loop
276
- print("\n" + "=" * 80)
277
- print(" PRUNING STARTED")
278
- print("=" * 80)
279
-
280
- total_reductions = 0
281
- pruning_start = time.perf_counter()
282
-
283
- for pass_num in range(passes):
284
- torch.manual_seed(0)
285
- pass_start = time.perf_counter()
286
-
287
- print(f"\n{'='*80}")
288
- print(f" PASS {pass_num + 1}/{passes}")
289
- print(f"{'='*80}")
290
-
291
- # Count candidates
292
- candidates = []
293
- for name, idx, shape in param_list:
294
- flat = model[name].flatten()
295
- val = flat[idx].item()
296
- if val != 0:
297
- candidates.append((name, idx, shape, val))
298
-
299
- n_candidates = len(candidates)
300
- print(f"\n Candidates: {n_candidates} non-zero weights")
301
-
302
- if n_candidates == 0:
303
- print(f" No candidates remaining. Stopping.")
304
- break
305
-
306
- # Phase 1: Batch evaluation
307
- print(f"\n PHASE 1: Batch evaluation (testing each reduction independently)")
308
- print(f" " + "-" * 60)
309
- phase1_start = time.perf_counter()
310
- successful_candidates = []
311
- n_batches = (n_candidates + batch_size - 1) // batch_size
312
-
313
- for batch_idx, batch_start_idx in enumerate(range(0, n_candidates, batch_size)):
314
- batch = candidates[batch_start_idx:batch_start_idx + batch_size]
315
- batch_len = len(batch)
316
- batch_start_time = time.perf_counter()
317
-
318
- # Build population
319
- pop = {}
320
- for name, tensor in model.items():
321
- pop[name] = tensor.unsqueeze(0).expand(batch_len, *tensor.shape).clone().to(device)
322
-
323
- # Apply reductions
324
- for pop_idx, (name, flat_idx, shape, old_val) in enumerate(batch):
325
- new_val = old_val - 1 if old_val > 0 else old_val + 1
326
- flat_view = pop[name][pop_idx].flatten()
327
- flat_view[flat_idx] = new_val
328
-
329
- # Evaluate
330
- torch.manual_seed(0)
331
- if device == 'cuda':
332
- torch.cuda.synchronize()
333
- fitness = evaluator.evaluate(pop, debug=False)
334
- if device == 'cuda':
335
- torch.cuda.synchronize()
336
-
337
- # Collect successes
338
- batch_successes = 0
339
- for pop_idx, (name, flat_idx, shape, old_val) in enumerate(batch):
340
- if fitness[pop_idx].item() >= 0.9999:
341
- successful_candidates.append((name, flat_idx, shape, old_val))
342
- batch_successes += 1
343
-
344
- batch_time = time.perf_counter() - batch_start_time
345
- elapsed = time.perf_counter() - phase1_start
346
- done = batch_start_idx + batch_len
347
- eta = format_eta(elapsed, done, n_candidates)
348
- throughput = batch_len / batch_time
349
-
350
- print(f" Batch {batch_idx + 1}/{n_batches}: "
351
- f"{batch_successes}/{batch_len} passed ({100*batch_successes/batch_len:.1f}%) | "
352
- f"Total OK: {len(successful_candidates)} | "
353
- f"Progress: {done}/{n_candidates} ({100*done/n_candidates:.1f}%) | "
354
- f"Speed: {throughput:.0f}/s | "
355
- f"ETA: {eta}")
356
-
357
- phase1_time = time.perf_counter() - phase1_start
358
- print(f"\n Phase 1 complete: {len(successful_candidates)}/{n_candidates} candidates "
359
- f"({100*len(successful_candidates)/n_candidates:.1f}%) in {format_time(phase1_time)}")
360
-
361
- # Phase 2: Apply with conflict resolution
362
- if len(successful_candidates) == 0:
363
- print(f"\n No reductions possible. Stopping.")
364
- break
365
-
366
- print(f"\n PHASE 2: Apply reductions with conflict resolution")
367
- print(f" " + "-" * 60)
368
- phase2_start = time.perf_counter()
369
-
370
- accepted = batched_conflict_resolution(model, evaluator, device, successful_candidates, base_magnitude)
371
- pass_reductions = len(accepted)
372
-
373
- phase2_time = time.perf_counter() - phase2_start
374
- print(f"\n Phase 2 complete: {pass_reductions} reductions applied in {format_time(phase2_time)}")
375
-
376
- # Pass summary
377
- total_reductions += pass_reductions
378
- current_magnitude = sum(t.abs().sum().item() for t in model.values())
379
- current_nonzero = sum((t != 0).sum().item() for t in model.values())
380
- pass_time = time.perf_counter() - pass_start
381
- reduction_pct = 100 * (1 - current_magnitude / base_magnitude)
382
-
383
- print(f"\n PASS {pass_num + 1} SUMMARY:")
384
- print(f" Reductions this pass: {pass_reductions}")
385
- print(f" Total reductions: {total_reductions}")
386
- print(f" Current magnitude: {current_magnitude:.0f} (-{reduction_pct:.2f}%)")
387
- print(f" Current non-zero: {current_nonzero}")
388
- print(f" Pass time: {format_time(pass_time)}")
389
-
390
- # Verify after pass
391
- print(f"\n Verifying model integrity...")
392
- fitness = check_fitness(model, evaluator, device)
393
- print(f" Fitness: {fitness:.6f} {'PASS' if fitness >= 0.9999 else 'FAIL'}")
394
-
395
- # Save checkpoint after each pass
396
- checkpoint_name = checkpoint_path.replace('.safetensors', f'_pass{pass_num + 1}.safetensors')
397
- print(f"\n Saving checkpoint: {checkpoint_name}")
398
- save_file(model, checkpoint_name)
399
- print(f" Saved. Magnitude: {current_magnitude:.0f} (-{reduction_pct:.2f}%)")
400
-
401
- # Also save as "latest" for easy access
402
- latest_path = checkpoint_path.replace('.safetensors', '_latest.safetensors')
403
- save_file(model, latest_path)
404
- print(f" Also saved as: {latest_path}")
405
-
406
- if pass_reductions == 0:
407
- print(f"\n No reductions achieved. Stopping early.")
408
- break
409
-
410
- # Final summary
411
- pruning_time = time.perf_counter() - pruning_start
412
- final_magnitude = sum(t.abs().sum().item() for t in model.values())
413
- final_max = max(t.abs().max().item() for t in model.values())
414
- final_nonzero = sum((t != 0).sum().item() for t in model.values())
415
- reduction_pct = 100 * (1 - final_magnitude / base_magnitude)
416
-
417
- print("\n" + "=" * 80)
418
- print(" PRUNING COMPLETE")
419
- print("=" * 80)
420
- print(f"\n RESULTS:")
421
- print(f" Original magnitude: {base_magnitude:.0f}")
422
- print(f" Final magnitude: {final_magnitude:.0f}")
423
- print(f" Reduction: {reduction_pct:.2f}%")
424
- print(f" Total reductions: {total_reductions}")
425
- print(f" Original non-zero: {nonzero_params}")
426
- print(f" Final non-zero: {final_nonzero}")
427
- print(f" Zeros created: {nonzero_params - final_nonzero}")
428
- print(f" Max weight: {final_max:.0f}")
429
- print(f" Total time: {format_time(pruning_time)}")
430
-
431
- # Save
432
- print(f"\n SAVING to {checkpoint_path}...")
433
- save_file(model, checkpoint_path)
434
- print(f" Saved.")
435
-
436
- # Final verification
437
- print(f"\n FINAL VERIFICATION...")
438
- from safetensors import safe_open
439
- f = safe_open(checkpoint_path, framework='numpy')
440
- verify_model = {name: torch.tensor(f.get_tensor(name)).float() for name in f.keys()}
441
- verify_fitness = check_fitness(verify_model, evaluator, device)
442
- print(f" Fitness: {verify_fitness:.6f}")
443
-
444
- if verify_fitness >= 0.9999:
445
- print(f" STATUS: PASS")
446
- else:
447
- print(f" STATUS: FAIL - Model corrupted!")
448
-
449
- print("\n" + "=" * 80)
450
- return model
451
-
452
-
453
- MAX_BATCH_SIZE = 80000
454
-
455
- if __name__ == "__main__":
456
- parser = argparse.ArgumentParser(description='Batched Weight Pruning')
457
- parser.add_argument('--passes', type=int, default=10,
458
- help='Maximum pruning passes (default: 10)')
459
- parser.add_argument('--batch_size', type=int, default=80000,
460
- help=f'Batch size for parallel evaluation (default: 80000, max: {MAX_BATCH_SIZE})')
461
- parser.add_argument('--device', type=str, default='cuda',
462
- help='Device: cuda or cpu (default: cuda)')
463
- parser.add_argument('--output', type=str,
464
- default='D:/8bit-threshold-computer/pruned.safetensors',
465
- help='Output path')
466
- args = parser.parse_args()
467
-
468
- if args.batch_size > MAX_BATCH_SIZE:
469
- print(f"WARNING: batch_size {args.batch_size} exceeds maximum {MAX_BATCH_SIZE}. Clamping.")
470
- args.batch_size = MAX_BATCH_SIZE
471
-
472
- print(f"\nStarting at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
473
-
474
- prune_weights(
475
- passes=args.passes,
476
- batch_size=args.batch_size,
477
- device=args.device,
478
- checkpoint_path=args.output
479
- )
480
-
481
- print(f"\nFinished at {time.strftime('%Y-%m-%d %H:%M:%S')}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
stress_test.py DELETED
@@ -1,367 +0,0 @@
1
- """
2
- WILD STRESS TESTS - Push the threshold CPU to its limits
3
- """
4
- import torch
5
- from safetensors.torch import load_file
6
-
7
- model = load_file('./neural_computer.safetensors')
8
- model = {k: v.float() for k, v in model.items()}
9
-
10
- def heaviside(x):
11
- return (x >= 0).float()
12
-
13
- def int_to_bits(val, width=8):
14
- return torch.tensor([(val >> (width-1-i)) & 1 for i in range(width)], dtype=torch.float32)
15
-
16
- def bits_to_int(bits):
17
- val = 0
18
- for i, b in enumerate(bits):
19
- val |= (int(b.item()) << (len(bits)-1-i))
20
- return val
21
-
22
- # === BASIC PRIMITIVES ===
23
-
24
- def eval_xor(a, b):
25
- inp = torch.tensor([float(a), float(b)], dtype=torch.float32)
26
- w1_n1 = model['boolean.xor.layer1.neuron1.weight']
27
- b1_n1 = model['boolean.xor.layer1.neuron1.bias']
28
- w1_n2 = model['boolean.xor.layer1.neuron2.weight']
29
- b1_n2 = model['boolean.xor.layer1.neuron2.bias']
30
- w2 = model['boolean.xor.layer2.weight']
31
- b2 = model['boolean.xor.layer2.bias']
32
- h1 = heaviside(inp @ w1_n1 + b1_n1)
33
- h2 = heaviside(inp @ w1_n2 + b1_n2)
34
- hidden = torch.tensor([h1.item(), h2.item()])
35
- return int(heaviside(hidden @ w2 + b2).item())
36
-
37
- def eval_and(a, b):
38
- inp = torch.tensor([float(a), float(b)], dtype=torch.float32)
39
- return int(heaviside(inp @ model['boolean.and.weight'] + model['boolean.and.bias']).item())
40
-
41
- def eval_or(a, b):
42
- inp = torch.tensor([float(a), float(b)], dtype=torch.float32)
43
- return int(heaviside(inp @ model['boolean.or.weight'] + model['boolean.or.bias']).item())
44
-
45
- def eval_not(a):
46
- inp = torch.tensor([float(a)], dtype=torch.float32)
47
- return int(heaviside(inp @ model['boolean.not.weight'] + model['boolean.not.bias']).item())
48
-
49
- def eval_xor_arith(inp, prefix):
50
- w1_or = model[f'{prefix}.layer1.or.weight']
51
- b1_or = model[f'{prefix}.layer1.or.bias']
52
- w1_nand = model[f'{prefix}.layer1.nand.weight']
53
- b1_nand = model[f'{prefix}.layer1.nand.bias']
54
- w2 = model[f'{prefix}.layer2.weight']
55
- b2 = model[f'{prefix}.layer2.bias']
56
- h_or = heaviside(inp @ w1_or + b1_or)
57
- h_nand = heaviside(inp @ w1_nand + b1_nand)
58
- hidden = torch.tensor([h_or.item(), h_nand.item()])
59
- return heaviside(hidden @ w2 + b2).item()
60
-
61
- def eval_full_adder(a, b, cin, prefix):
62
- inp_ab = torch.tensor([a, b], dtype=torch.float32)
63
- ha1_sum = eval_xor_arith(inp_ab, f'{prefix}.ha1.sum')
64
- ha1_carry = heaviside(inp_ab @ model[f'{prefix}.ha1.carry.weight'] + model[f'{prefix}.ha1.carry.bias']).item()
65
- inp_ha2 = torch.tensor([ha1_sum, cin], dtype=torch.float32)
66
- ha2_sum = eval_xor_arith(inp_ha2, f'{prefix}.ha2.sum')
67
- ha2_carry = heaviside(inp_ha2 @ model[f'{prefix}.ha2.carry.weight'] + model[f'{prefix}.ha2.carry.bias']).item()
68
- inp_cout = torch.tensor([ha1_carry, ha2_carry], dtype=torch.float32)
69
- cout = heaviside(inp_cout @ model[f'{prefix}.carry_or.weight'] + model[f'{prefix}.carry_or.bias']).item()
70
- return int(ha2_sum), int(cout)
71
-
72
- def add_8bit(a, b):
73
- carry = 0.0
74
- result = 0
75
- for i in range(8):
76
- s, carry = eval_full_adder(float((a >> i) & 1), float((b >> i) & 1), carry, f'arithmetic.ripplecarry8bit.fa{i}')
77
- result |= (s << i)
78
- return result, int(carry)
79
-
80
- def sub_8bit(a, b):
81
- # a - b = a + (~b + 1)
82
- not_b = 0
83
- for i in range(8):
84
- not_b |= (eval_not((b >> i) & 1) << i)
85
- temp, _ = add_8bit(a, not_b)
86
- result, _ = add_8bit(temp, 1)
87
- return result
88
-
89
- def gt(a, b):
90
- a_bits, b_bits = int_to_bits(a), int_to_bits(b)
91
- w = model['arithmetic.greaterthan8bit.comparator']
92
- return 1 if ((a_bits - b_bits) @ w).item() > 0 else 0
93
-
94
- def lt(a, b):
95
- a_bits, b_bits = int_to_bits(a), int_to_bits(b)
96
- w = model['arithmetic.lessthan8bit.comparator']
97
- return 1 if ((b_bits - a_bits) @ w).item() > 0 else 0
98
-
99
- def eq(a, b):
100
- return 1 if (gt(a,b) == 0 and lt(a,b) == 0) else 0
101
-
102
- def popcount(val):
103
- bits = int_to_bits(val)
104
- w = model['pattern_recognition.popcount.weight']
105
- b = model['pattern_recognition.popcount.bias']
106
- return int((bits @ w + b).item())
107
-
108
- print('='*70)
109
- print('WILD STRESS TESTS')
110
- print('='*70)
111
-
112
- # === TEST 1: FACTORIAL ===
113
- print('\n[1] FACTORIAL via chained multiply-add')
114
- def factorial(n):
115
- result = 1
116
- for i in range(2, n+1):
117
- new_result = 0
118
- for _ in range(i):
119
- new_result, _ = add_8bit(new_result, result)
120
- new_result &= 0xFF
121
- result = new_result
122
- return result
123
-
124
- for n in [1, 2, 3, 4, 5]:
125
- got = factorial(n)
126
- expected = [1, 1, 2, 6, 24, 120][n]
127
- status = 'OK' if got == expected else 'FAIL'
128
- print(f' {n}! = {got} (expected {expected}) [{status}]')
129
-
130
- # === TEST 2: GCD ===
131
- print('\n[2] GCD via Euclidean algorithm')
132
- def gcd(a, b):
133
- iterations = 0
134
- while not eq(b, 0) and iterations < 100:
135
- temp = a
136
- while not lt(temp, b) and not eq(temp, 0) and iterations < 100:
137
- temp = sub_8bit(temp, b)
138
- iterations += 1
139
- a, b = b, temp
140
- iterations += 1
141
- return a
142
-
143
- test_gcds = [(48, 18, 6), (100, 35, 5), (252, 105, 21), (17, 13, 1), (128, 64, 64)]
144
- for a, b, expected in test_gcds:
145
- got = gcd(a, b)
146
- status = 'OK' if got == expected else 'FAIL'
147
- print(f' gcd({a}, {b}) = {got} (expected {expected}) [{status}]')
148
-
149
- # === TEST 3: FIBONACCI ===
150
- print('\n[3] FIBONACCI until overflow')
151
- def fib_sequence():
152
- a, b = 0, 1
153
- seq = [a, b]
154
- for _ in range(20):
155
- next_val, carry = add_8bit(a, b)
156
- if carry:
157
- break
158
- seq.append(next_val)
159
- a, b = b, next_val
160
- return seq
161
-
162
- fib = fib_sequence()
163
- expected_fib = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233]
164
- print(f' Computed: {fib[:len(expected_fib)]}')
165
- print(f' Expected: {expected_fib}')
166
- print(f' Match: {fib[:len(expected_fib)] == expected_fib}')
167
-
168
- # === TEST 4: PRIME CHECK ===
169
- print('\n[4] PRIME CHECK via trial division')
170
- def is_prime(n):
171
- if n < 2: return False
172
- if n == 2: return True
173
- if (n & 1) == 0: return False
174
-
175
- i = 3
176
- while i * i <= n and i < n:
177
- temp = n
178
- while temp >= i:
179
- temp = sub_8bit(temp, i)
180
- if eq(temp, 0):
181
- return False
182
- i += 2
183
- return True
184
-
185
- primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
186
- non_primes = [4, 6, 8, 9, 10, 12, 14, 15, 16, 18]
187
-
188
- prime_pass = sum(1 for p in primes if is_prime(p))
189
- non_prime_pass = sum(1 for n in non_primes if not is_prime(n))
190
- print(f' Primes correctly identified: {prime_pass}/10')
191
- print(f' Non-primes correctly rejected: {non_prime_pass}/10')
192
-
193
- # === TEST 5: INTEGER SQRT ===
194
- print('\n[5] INTEGER SQUARE ROOT via binary search')
195
- def isqrt(n):
196
- if n == 0: return 0
197
- lo, hi = 1, min(n, 15) # limit for 8-bit
198
- result = 0
199
- iterations = 0
200
- while lo <= hi and iterations < 50:
201
- mid = (lo + hi) >> 1
202
- sq = 0
203
- for _ in range(mid):
204
- sq, _ = add_8bit(sq, mid)
205
- sq &= 0xFF
206
-
207
- if sq <= n:
208
- result = mid
209
- lo = mid + 1
210
- else:
211
- hi = mid - 1
212
- iterations += 1
213
- return result
214
-
215
- sqrt_tests = [(0, 0), (1, 1), (4, 2), (9, 3), (16, 4), (25, 5), (36, 6), (49, 7), (64, 8), (81, 9), (100, 10), (144, 12)]
216
- sqrt_pass = 0
217
- for n, expected in sqrt_tests:
218
- got = isqrt(n)
219
- if got == expected:
220
- sqrt_pass += 1
221
- print(f' Passed: {sqrt_pass}/{len(sqrt_tests)}')
222
-
223
- # === TEST 6: COLLATZ ===
224
- print('\n[6] COLLATZ CONJECTURE iterations')
225
- def collatz_steps(n):
226
- steps = 0
227
- while n != 1 and steps < 200:
228
- if (n & 1) == 0:
229
- n = n >> 1
230
- else:
231
- temp, _ = add_8bit(n, n)
232
- temp, _ = add_8bit(temp, n)
233
- n, _ = add_8bit(temp, 1)
234
- n &= 0xFF
235
- steps += 1
236
- if n == 0: break
237
- return steps
238
-
239
- collatz_tests = [(1, 0), (2, 1), (3, 7), (6, 8)]
240
- for start, expected in collatz_tests:
241
- got = collatz_steps(start)
242
- status = 'OK' if got == expected else f'got {got}'
243
- print(f' collatz({start}) = {got} steps [{status}]')
244
-
245
- # === TEST 7: SORT BY POPCOUNT ===
246
- print('\n[7] SORT BY HAMMING WEIGHT (popcount)')
247
- values = [0b11111111, 0b00000001, 0b10101010, 0b00001111, 0b11110000, 0b00000000]
248
- weighted = [(v, popcount(v)) for v in values]
249
- for i in range(len(weighted)):
250
- for j in range(len(weighted) - 1):
251
- if gt(weighted[j][1], weighted[j+1][1]):
252
- weighted[j], weighted[j+1] = weighted[j+1], weighted[j]
253
-
254
- print(f' Sorted by popcount:')
255
- for v, p in weighted:
256
- print(f' {bin(v):>12} -> popcount = {p}')
257
-
258
- # === TEST 8: XOR CHECKSUM ===
259
- print('\n[8] XOR CHECKSUM of message')
260
- message = [0x48, 0x65, 0x6C, 0x6C, 0x6F] # "Hello"
261
- checksum = 0
262
- for byte in message:
263
- for i in range(8):
264
- bit_a = (checksum >> i) & 1
265
- bit_b = (byte >> i) & 1
266
- xor_bit = eval_xor(bit_a, bit_b)
267
- checksum = (checksum & ~(1 << i)) | (xor_bit << i)
268
-
269
- expected_checksum = 0x48 ^ 0x65 ^ 0x6C ^ 0x6C ^ 0x6F
270
- status = 'OK' if checksum == expected_checksum else 'FAIL'
271
- print(f' Message: {[hex(b) for b in message]}')
272
- print(f' XOR checksum: {hex(checksum)} (expected {hex(expected_checksum)}) [{status}]')
273
-
274
- # === TEST 9: PARITY TREE ===
275
- print('\n[9] 8-BIT PARITY (full XOR tree)')
276
- def parity_8bit(val):
277
- bits = [(val >> i) & 1 for i in range(8)]
278
- s1 = [eval_xor(bits[0], bits[1]), eval_xor(bits[2], bits[3]),
279
- eval_xor(bits[4], bits[5]), eval_xor(bits[6], bits[7])]
280
- s2 = [eval_xor(s1[0], s1[1]), eval_xor(s1[2], s1[3])]
281
- return eval_xor(s2[0], s2[1])
282
-
283
- parity_tests = [(0x00, 0), (0xFF, 0), (0x01, 1), (0x03, 0), (0x07, 1), (0xAA, 0), (0x55, 0), (0x81, 0), (0x80, 1)]
284
- parity_pass = sum(1 for v, exp in parity_tests if parity_8bit(v) == exp)
285
- print(f' Passed: {parity_pass}/{len(parity_tests)}')
286
-
287
- # === TEST 10: OVERFLOW CASCADE ===
288
- print('\n[10] OVERFLOW CASCADE (255 + 1 chain)')
289
- val = 255
290
- carries = []
291
- for i in range(5):
292
- val, carry = add_8bit(val, 1)
293
- carries.append(carry)
294
- print(f' 255 -> +1 -> +1 -> +1 -> +1 -> +1')
295
- print(f' Carries: {carries}')
296
- print(f' Final value: {val} (expected 4) [{"OK" if val == 4 else "FAIL"}]')
297
-
298
- # === TEST 11: POWER OF 2 CHECK ===
299
- print('\n[11] POWER OF 2 detection (popcount == 1)')
300
- def is_power_of_2(n):
301
- if n == 0: return False
302
- return popcount(n) == 1
303
-
304
- pow2_tests = [(1, True), (2, True), (4, True), (8, True), (16, True), (32, True), (64, True), (128, True),
305
- (3, False), (5, False), (6, False), (7, False), (9, False), (15, False), (255, False)]
306
- pow2_pass = sum(1 for n, exp in pow2_tests if is_power_of_2(n) == exp)
307
- print(f' Passed: {pow2_pass}/{len(pow2_tests)}')
308
-
309
- # === TEST 12: BYTE REVERSE ===
310
- print('\n[12] BYTE REVERSE via bit manipulation')
311
- def reverse_bits(val):
312
- result = 0
313
- for i in range(8):
314
- bit = (val >> i) & 1
315
- result |= (bit << (7 - i))
316
- return result
317
-
318
- reverse_tests = [(0b10000000, 0b00000001), (0b11110000, 0b00001111), (0b10101010, 0b01010101), (0b00000000, 0b00000000), (0b11111111, 0b11111111)]
319
- reverse_pass = sum(1 for inp, exp in reverse_tests if reverse_bits(inp) == exp)
320
- print(f' Passed: {reverse_pass}/{len(reverse_tests)}')
321
-
322
- # === TEST 13: MAX/MIN via comparator ===
323
- print('\n[13] MAX and MIN of array')
324
- def find_max(arr):
325
- m = arr[0]
326
- for x in arr[1:]:
327
- if gt(x, m):
328
- m = x
329
- return m
330
-
331
- def find_min(arr):
332
- m = arr[0]
333
- for x in arr[1:]:
334
- if lt(x, m):
335
- m = x
336
- return m
337
-
338
- test_arr = [42, 17, 255, 0, 128, 64, 33]
339
- got_max = find_max(test_arr)
340
- got_min = find_min(test_arr)
341
- print(f' Array: {test_arr}')
342
- print(f' Max: {got_max} (expected 255) [{"OK" if got_max == 255 else "FAIL"}]')
343
- print(f' Min: {got_min} (expected 0) [{"OK" if got_min == 0 else "FAIL"}]')
344
-
345
- # === TEST 14: LFSR (pseudo-random) ===
346
- print('\n[14] 8-BIT LFSR (taps at 8,6,5,4)')
347
- def lfsr_step(state):
348
- # Taps: 8, 6, 5, 4 (for maximal length)
349
- bit = eval_xor((state >> 0) & 1, (state >> 2) & 1)
350
- bit = eval_xor(bit, (state >> 3) & 1)
351
- bit = eval_xor(bit, (state >> 4) & 1)
352
- return ((state >> 1) | (bit << 7)) & 0xFF
353
-
354
- state = 1
355
- seen = set()
356
- for i in range(300):
357
- if state in seen:
358
- break
359
- seen.add(state)
360
- state = lfsr_step(state)
361
-
362
- print(f' Period: {len(seen)} (max possible: 255)')
363
- print(f' Full period: {"OK" if len(seen) == 255 else "FAIL"}')
364
-
365
- print('\n' + '='*70)
366
- print('STRESS TESTS COMPLETE')
367
- print('='*70)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/iron_eval.py DELETED
The diff for this file is too large to render. See raw diff
 
train_circuit_interface.py DELETED
@@ -1,306 +0,0 @@
1
- """
2
- Train the circuit interface layers on arithmetic examples.
3
- ============================================================
4
-
5
- The threshold circuits are frozen - we only train:
6
- - BitExtractor: embedding -> operand bits
7
- - BitInjector: result bits -> embedding
8
- - Router: when to use circuits vs MLP
9
- """
10
-
11
- import torch
12
- import torch.nn as nn
13
- from torch.utils.data import Dataset, DataLoader
14
- from transformers import AutoModelForCausalLM, AutoTokenizer
15
- from tqdm import tqdm
16
- import argparse
17
- import warnings
18
- warnings.filterwarnings('ignore')
19
-
20
- from circuit_llm import (
21
- augment_smollm2_with_circuits,
22
- evaluate_arithmetic,
23
- CircuitExecutor
24
- )
25
-
26
-
27
- # =============================================================================
28
- # ARITHMETIC DATASET
29
- # =============================================================================
30
-
31
class ArithmeticDataset(Dataset):
    """Dataset of 8-bit addition problems rendered as text.

    Each example is the prompt ``"{a} + {b} ="`` followed by the target
    ``" {result}"`` where ``result = (a + b) % 256`` (the sum wraps at 8 bits).

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            that returns a list of token ids.
        n_samples: Number of problems to pre-generate.
        max_val: Inclusive upper bound for each operand.
        seed: Optional seed. When given, operands are drawn from a private
            ``torch.Generator`` so the dataset is reproducible and the global
            RNG state is left untouched. When ``None`` the global RNG is used.
    """

    def __init__(self, tokenizer, n_samples: int = 10000, max_val: int = 255,
                 seed: "int | None" = None):
        self.tokenizer = tokenizer
        self.n_samples = n_samples
        self.max_val = max_val

        generator = None
        if seed is not None:
            generator = torch.Generator()
            generator.manual_seed(seed)

        # Draw every operand pair in a single call instead of issuing two
        # scalar randint()/.item() round-trips per example.
        operand_pairs = torch.randint(
            0, max_val + 1, (max(n_samples, 0), 2), generator=generator
        ).tolist()

        # Pre-generate all examples as (prompt, target, a, b, result).
        self.examples = []
        for a, b in operand_pairs:
            result = (a + b) % 256
            self.examples.append((f"{a} + {b} =", f" {result}", a, b, result))

    def __len__(self):
        """Return the number of pre-generated examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize example ``idx``.

        Returns a dict with ``input_ids`` (prompt tokens followed by target
        tokens), ``labels`` (``-100`` over the prompt so only the target
        contributes to the loss), and the raw integers ``a``, ``b``, ``result``.
        """
        prompt, target, a, b, result = self.examples[idx]

        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        target_ids = self.tokenizer.encode(target, add_special_tokens=False)

        return {
            'input_ids': torch.tensor(prompt_ids + target_ids),
            # -100 is the label value the loss ignores: only predict target.
            'labels': torch.tensor([-100] * len(prompt_ids) + target_ids),
            'a': a,
            'b': b,
            'result': result,
        }
71
-
72
-
73
def collate_fn(batch, pad_token_id: int = 0):
    """Right-pad a list of examples to a common length and stack them.

    Args:
        batch: Non-empty list of dicts, each holding 1-D ``input_ids`` and
            ``labels`` tensors of equal length.
        pad_token_id: Id written into padded ``input_ids`` positions. The
            default of 0 matches the previous hard-coded behaviour; padded
            positions are masked out by ``attention_mask`` either way.

    Returns:
        Dict of int64 tensors shaped ``(len(batch), max_len)``:
        ``input_ids``, ``labels`` (padding set to ``-100``) and
        ``attention_mask`` (1 for real tokens, 0 for padding).

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    lengths = [len(item['input_ids']) for item in batch]
    max_len = max(lengths)
    n_rows = len(batch)

    # Allocate the padded outputs once and copy each example into its row,
    # rather than concatenating a fresh padding tensor per example.
    input_ids = torch.full((n_rows, max_len), pad_token_id, dtype=torch.long)
    labels = torch.full((n_rows, max_len), -100, dtype=torch.long)
    # Integer mask, same dtype as the other tensors in the batch.
    attention_mask = torch.zeros((n_rows, max_len), dtype=torch.long)

    for row, (item, length) in enumerate(zip(batch, lengths)):
        input_ids[row, :length] = item['input_ids']
        labels[row, :length] = item['labels']
        attention_mask[row, :length] = 1

    return {
        'input_ids': input_ids,
        'labels': labels,
        'attention_mask': attention_mask,
    }
99
-
100
-
101
- # =============================================================================
102
- # TRAINING LOOP
103
- # =============================================================================
104
-
105
def train_interface(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    n_epochs: int = 3,
    batch_size: int = 16,
    lr: float = 1e-4,
    n_train_samples: int = 10000,
    device: str = 'cpu',
    eval_every: int = 500
):
    """
    Train the circuit interface layers.

    Only parameters whose name contains one of these substrings are trained;
    every other parameter has ``requires_grad`` set to False:
    - bit_extractor (embedding -> bits)
    - bit_injector (bits -> embedding)
    - router (circuit vs MLP weighting)
    - op_selector (which operation)

    Args:
        model: Causal LM already augmented with the interface layers.
        tokenizer: Tokenizer used to build the arithmetic dataset and for
            evaluation.
        n_epochs: Number of passes over the generated dataset.
        batch_size: Examples per optimisation step.
        lr: AdamW learning rate.
        n_train_samples: Number of addition problems to generate.
        device: Device the model and batches are moved to.
        eval_every: Run a 50-problem evaluation every this many optimisation
            steps. A value <= 0 disables the periodic evaluation (the
            end-of-epoch evaluation still runs).

    Returns:
        The same ``model`` object, trained in place and left in train mode.

    Raises:
        ValueError: If the model has no parameters matching the interface
            layer names (nothing to train).
    """
    interface_keys = ('bit_extractor', 'bit_injector', 'router', 'op_selector')

    print("\n" + "=" * 70)
    print(" TRAINING CIRCUIT INTERFACE")
    print("=" * 70)

    # Freeze everything except interface layers
    interface_params = []
    frozen_count = 0
    trainable_count = 0

    for name, param in model.named_parameters():
        is_interface = any(key in name for key in interface_keys)
        param.requires_grad = is_interface
        if is_interface:
            interface_params.append(param)
            trainable_count += param.numel()
        else:
            frozen_count += param.numel()

    if not interface_params:
        # Fail here with an actionable message instead of letting the
        # optimizer reject an empty parameter list further down.
        raise ValueError(
            "No interface parameters found (expected names containing one of "
            f"{interface_keys}); was the model augmented with circuits?"
        )

    print(f"\n Frozen parameters: {frozen_count:,}")
    print(f" Trainable parameters: {trainable_count:,}")
    print(f" Training {len(interface_params)} parameter groups")

    # Create dataset
    print(f"\n Creating dataset ({n_train_samples} examples)...")
    dataset = ArithmeticDataset(tokenizer, n_samples=n_train_samples)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )

    # Optimizer
    optimizer = torch.optim.AdamW(interface_params, lr=lr)

    # Training
    model.to(device)
    model.train()

    # `global_step % eval_every` would raise ZeroDivisionError for 0, so a
    # non-positive interval simply turns the periodic evaluation off.
    periodic_eval = eval_every > 0

    global_step = 0
    total_loss = 0  # running loss since the last periodic evaluation

    for epoch in range(n_epochs):
        print(f"\n Epoch {epoch + 1}/{n_epochs}")
        print(" " + "-" * 60)

        epoch_loss = 0
        epoch_steps = 0

        pbar = tqdm(dataloader, desc=" Training", leave=False)

        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # Forward
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs.loss

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Logging
            step_loss = loss.item()
            epoch_loss += step_loss
            epoch_steps += 1
            global_step += 1
            total_loss += step_loss

            pbar.set_postfix({'loss': f'{step_loss:.4f}'})

            # Periodic evaluation
            if periodic_eval and global_step % eval_every == 0:
                model.eval()
                eval_results = evaluate_arithmetic(model, tokenizer, n_problems=50, device=device)
                print(f"\n Step {global_step}: Loss={total_loss/eval_every:.4f}, "
                      f"Accuracy={eval_results['accuracy']*100:.1f}%")
                total_loss = 0
                model.train()

        # Guard the division in case the loader yielded no batches.
        avg_loss = epoch_loss / epoch_steps if epoch_steps else float('nan')
        print(f"\n Epoch {epoch + 1} complete. Avg loss: {avg_loss:.4f}")

        # End of epoch evaluation
        model.eval()
        eval_results = evaluate_arithmetic(model, tokenizer, n_problems=100, device=device)
        print(f" Evaluation: {eval_results['accuracy']*100:.1f}% "
              f"({eval_results['correct']}/{eval_results['total']})")

        if eval_results['errors']:
            print(" Sample errors:")
            for a, b, exp, got in eval_results['errors'][:3]:
                print(f" {a} + {b} = {exp}, model said {got}")

        model.train()

    print("\n" + "=" * 70)
    print(" TRAINING COMPLETE")
    print("=" * 70)

    return model
232
-
233
-
234
- # =============================================================================
235
- # MAIN
236
- # =============================================================================
237
-
238
def _parse_args(argv=None):
    """Parse the command-line options of the training script."""
    parser = argparse.ArgumentParser(description='Train Circuit Interface')
    parser.add_argument('--circuit-path', type=str,
                        default='./neural_computer.safetensors',
                        help='Path to circuit weights')
    parser.add_argument('--device', type=str, default='cpu',
                        help='Device (cpu or cuda)')
    parser.add_argument('--epochs', type=int, default=3,
                        help='Number of epochs')
    parser.add_argument('--batch-size', type=int, default=8,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='Learning rate')
    parser.add_argument('--n-samples', type=int, default=5000,
                        help='Number of training samples')
    return parser.parse_args(argv)


def main(argv=None):
    """Load SmolLM2, augment it with circuits, train the interface and save.

    Steps: baseline evaluation -> circuit augmentation -> interface training
    -> final evaluation -> checkpoint written to
    ``./circuit_augmented_smollm2.pt``.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.
    """
    args = _parse_args(argv)

    print("=" * 70)
    print(" CIRCUIT-AUGMENTED LLM TRAINING")
    print("=" * 70)

    # Load model
    print("\n[1] Loading SmolLM2-360M...")
    model_id = "HuggingFaceTB/SmolLM2-360M"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

    # The baseline evaluation below is told to run on args.device, but the
    # freshly loaded model had not been moved there yet. Moving it first is a
    # no-op when it is already placed.
    # NOTE(review): evaluate_arithmetic may also move the model itself - confirm.
    model.to(args.device)

    # Baseline
    print("\n[2] Baseline evaluation...")
    baseline = evaluate_arithmetic(model, tokenizer, n_problems=50, device=args.device)
    print(f" Baseline accuracy: {baseline['accuracy']*100:.1f}%")

    # Augment
    print("\n[3] Augmenting with circuits...")
    model = augment_smollm2_with_circuits(
        model,
        args.circuit_path,
        device=args.device
    )

    # Train
    print("\n[4] Training interface layers...")
    model = train_interface(
        model,
        tokenizer,
        n_epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        n_train_samples=args.n_samples,
        device=args.device
    )

    # Final evaluation
    print("\n[5] Final evaluation...")
    final = evaluate_arithmetic(model, tokenizer, n_problems=100, device=args.device)
    print(f" Final accuracy: {final['accuracy']*100:.1f}%")
    print(f" Improvement: {baseline['accuracy']*100:.1f}% -> {final['accuracy']*100:.1f}%")

    # Save
    save_path = './circuit_augmented_smollm2.pt'
    print(f"\n[6] Saving to {save_path}...")
    torch.save({
        'model_state_dict': model.state_dict(),
        'baseline_accuracy': baseline['accuracy'],
        'final_accuracy': final['accuracy']
    }, save_path)

    print("\nDone!")


if __name__ == "__main__":
    main()