phanerozoic committed on
Commit
96eb2fc
·
verified ·
1 Parent(s): 58d0763

Delete circuit_llm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. circuit_llm.py +0 -606
circuit_llm.py DELETED
@@ -1,606 +0,0 @@
1
- """
2
- Circuit-Augmented LLM: Embedding threshold logic circuits into SmolLM2
3
- ======================================================================
4
-
5
- Replaces/augments MLP layers with frozen threshold circuits for exact arithmetic.
6
- """
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- from typing import Dict, Optional, Tuple
12
- from safetensors.torch import load_file
13
- from transformers import AutoModelForCausalLM, AutoTokenizer
14
- import warnings
15
- warnings.filterwarnings('ignore')
16
-
17
-
18
- # =============================================================================
19
- # HEAVISIDE WITH STRAIGHT-THROUGH ESTIMATOR
20
- # =============================================================================
21
-
22
class HeavisideSTE(torch.autograd.Function):
    """Heaviside step with a straight-through estimator (STE).

    Forward is the exact hard threshold (1 for x >= 0, else 0); backward
    pretends the op was the identity so gradients can flow through the
    binarization during training.
    """

    @staticmethod
    def forward(ctx, x):
        # Exact hard threshold, emitted as float for downstream matmuls.
        return torch.ge(x, 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: gradient of the identity.
        return grad_output
33
-
34
-
35
def heaviside(x: torch.Tensor) -> torch.Tensor:
    """Hard step (1 if x >= 0 else 0) that stays trainable via the STE."""
    return HeavisideSTE.apply(x)
38
-
39
-
40
- # =============================================================================
41
- # CIRCUIT EXECUTOR - Runs the frozen threshold circuits
42
- # =============================================================================
43
-
44
class CircuitExecutor(nn.Module):
    """Executes the frozen threshold-logic circuits from a safetensors file.

    Circuit tensors are registered as buffers (never trained); only the
    interface layers around this module are meant to learn.
    """

    def __init__(self, circuit_path: str, device: str = 'cpu'):
        super().__init__()
        self.device = device

        # Buffer names may not contain dots, so every dotted circuit name is
        # mapped to a double-underscore attribute name; `self.circuits` keeps
        # the dotted-name -> attribute-name lookup table.
        self.circuits = {}
        for key, tensor in load_file(circuit_path).items():
            attr = key.replace('.', '__')
            self.register_buffer(attr, tensor.float().to(device))
            self.circuits[key] = attr

    def _get(self, name: str) -> torch.Tensor:
        """Fetch a circuit tensor by its original dotted name."""
        return getattr(self, self.circuits[name])

    # -------------------------------------------------------------------------
    # Shared primitives
    # -------------------------------------------------------------------------

    def _neuron(self, inp: torch.Tensor, w_name: str, b_name: str) -> torch.Tensor:
        """Evaluate one threshold neuron: H(inp @ w + b)."""
        return heaviside(inp @ self._get(w_name) + self._get(b_name))

    def _xor_block(self, inp: torch.Tensor, base: str) -> torch.Tensor:
        """Two-layer XOR sub-circuit (OR + NAND feeding an AND).

        `base` is the dotted prefix holding `.layer1.or`, `.layer1.nand`
        and `.layer2` weight/bias tensors.
        """
        h_or = self._neuron(inp, f'{base}.layer1.or.weight', f'{base}.layer1.or.bias')
        h_nand = self._neuron(inp, f'{base}.layer1.nand.weight', f'{base}.layer1.nand.bias')
        hidden = torch.stack([h_or, h_nand], dim=-1)
        return self._neuron(hidden, f'{base}.layer2.weight', f'{base}.layer2.bias')

    # -------------------------------------------------------------------------
    # Boolean Gates
    # -------------------------------------------------------------------------

    def eval_and(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """AND gate: 1 iff both inputs are 1."""
        pair = torch.stack([a, b], dim=-1)
        return self._neuron(pair, 'boolean.and.weight', 'boolean.and.bias')

    def eval_or(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """OR gate: 1 if either input is 1."""
        pair = torch.stack([a, b], dim=-1)
        return self._neuron(pair, 'boolean.or.weight', 'boolean.or.bias')

    def eval_xor(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """XOR gate: needs two layers since XOR is not linearly separable."""
        pair = torch.stack([a, b], dim=-1)
        h1 = self._neuron(pair, 'boolean.xor.layer1.neuron1.weight',
                          'boolean.xor.layer1.neuron1.bias')
        h2 = self._neuron(pair, 'boolean.xor.layer1.neuron2.weight',
                          'boolean.xor.layer1.neuron2.bias')
        hidden = torch.stack([h1, h2], dim=-1)
        return self._neuron(hidden, 'boolean.xor.layer2.weight', 'boolean.xor.layer2.bias')

    # -------------------------------------------------------------------------
    # Arithmetic: Full Adder
    # -------------------------------------------------------------------------

    def eval_full_adder(self, a: torch.Tensor, b: torch.Tensor,
                        cin: torch.Tensor, prefix: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Full adder built from two half adders plus a carry-OR.

        sum = a XOR b XOR cin; cout = (a AND b) OR (cin AND (a XOR b)).
        Returns (sum_bit, carry_out).
        """
        ab = torch.stack([a, b], dim=-1)

        # Half adder 1: sum = a XOR b, carry = a AND b.
        ha1_sum = self._xor_block(ab, f'{prefix}.ha1.sum')
        ha1_carry = self._neuron(ab, f'{prefix}.ha1.carry.weight', f'{prefix}.ha1.carry.bias')

        # Half adder 2: combine ha1_sum with the incoming carry.
        sc = torch.stack([ha1_sum, cin], dim=-1)
        ha2_sum = self._xor_block(sc, f'{prefix}.ha2.sum')
        ha2_carry = self._neuron(sc, f'{prefix}.ha2.carry.weight', f'{prefix}.ha2.carry.bias')

        # Carry out = ha1_carry OR ha2_carry.
        carries = torch.stack([ha1_carry, ha2_carry], dim=-1)
        cout = self._neuron(carries, f'{prefix}.carry_or.weight', f'{prefix}.carry_or.bias')

        return ha2_sum, cout

    # -------------------------------------------------------------------------
    # Arithmetic: 8-bit Ripple Carry Adder
    # -------------------------------------------------------------------------

    def add_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """8-bit ripple-carry addition.

        a_bits, b_bits: [..., 8] bit tensors, LSB first.
        Returns (result_bits [..., 8], carry_out [...]).
        """
        carry = torch.zeros(a_bits.shape[:-1], device=a_bits.device)
        out_bits = []
        for i in range(8):
            sum_bit, carry = self.eval_full_adder(
                a_bits[..., i], b_bits[..., i], carry,
                f'arithmetic.ripplecarry8bit.fa{i}'
            )
            out_bits.append(sum_bit)
        return torch.stack(out_bits, dim=-1), carry

    # -------------------------------------------------------------------------
    # Arithmetic: 8-bit Comparators
    # -------------------------------------------------------------------------

    def greater_than_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
        """1 if a > b, else 0. NOTE(review): docstring says MSB first while
        add_8bit is LSB first — the comparator weight vector must encode
        whichever order the stored circuit expects; confirm against the
        circuit file."""
        score = ((a_bits - b_bits) * self._get('arithmetic.greaterthan8bit.comparator')).sum(dim=-1)
        return (score > 0).float()

    def less_than_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
        """1 if a < b, else 0 (same bit-order caveat as greater_than_8bit)."""
        score = ((b_bits - a_bits) * self._get('arithmetic.lessthan8bit.comparator')).sum(dim=-1)
        return (score > 0).float()

    def equal_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
        """1 if a == b, else 0 — derived as NOT(a>b) AND NOT(a<b)."""
        gt = self.greater_than_8bit(a_bits, b_bits)
        lt = self.less_than_8bit(a_bits, b_bits)
        return (1 - gt) * (1 - lt)
211
-
212
-
213
- # =============================================================================
214
- # BIT EXTRACTION / INJECTION INTERFACES
215
- # =============================================================================
216
-
217
class BitExtractor(nn.Module):
    """Learned interface that reads two 8-bit operands out of an embedding.

    A single linear layer maps the embedding to 16 logits, which are
    binarized with the straight-through Heaviside and split into operand A
    and operand B (LSB first, matching the adder circuits).
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        # Embedding -> 16 bit logits (two 8-bit operands).
        self.proj = nn.Linear(d_model, 16)
        # Kept for a sigmoid/temperature relaxation during training;
        # not consumed in this forward pass.
        self.temperature = nn.Parameter(torch.tensor(1.0))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """x: [..., d_model] -> (a_bits [..., 8], b_bits [..., 8])."""
        bits = heaviside(self.proj(x))
        return bits[..., :8], bits[..., 8:]
248
-
249
-
250
class BitInjector(nn.Module):
    """Learned interface that writes circuit results back into embedding space.

    Concatenates 8 result bits with 8 status-flag bits, projects the 16-dim
    vector to the model dimension, and scales by a learnable gain so the
    delta starts small relative to the residual stream.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        # 16 bits (result + flags) -> embedding delta.
        self.proj = nn.Linear(16, d_model)
        # Learnable output gain, initialized small.
        self.scale = nn.Parameter(torch.tensor(0.1))

    def forward(self, result_bits: torch.Tensor, flags: torch.Tensor) -> torch.Tensor:
        """result_bits: [..., 8]; flags: [..., 8] -> delta [..., d_model]."""
        packed = torch.cat([result_bits, flags], dim=-1)
        return self.proj(packed) * self.scale
274
-
275
-
276
- # =============================================================================
277
- # CIRCUIT-AUGMENTED MLP BLOCK
278
- # =============================================================================
279
-
280
class CircuitAugmentedMLP(nn.Module):
    """MLP block augmented with frozen threshold circuits.

    The pretrained SwiGLU MLP path runs unchanged; in parallel, a circuit
    path extracts two 8-bit operands from the hidden state, runs the frozen
    ripple-carry adder, and injects the result as an additive delta. The
    delta is gated by the circuit share of a learned softmax router; the
    MLP output is always passed through at full strength.
    """

    def __init__(
        self,
        d_model: int,
        intermediate_size: int,
        circuit_path: str,
        device: str = 'cpu'
    ):
        super().__init__()
        self.d_model = d_model

        # Original MLP components (weights copied from the pretrained model).
        self.gate_proj = nn.Linear(d_model, intermediate_size, bias=False)
        self.up_proj = nn.Linear(d_model, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, d_model, bias=False)
        self.act_fn = nn.SiLU()

        # Circuit components (executor buffers are frozen by the caller).
        self.circuits = CircuitExecutor(circuit_path, device)
        self.bit_extractor = BitExtractor(d_model)
        self.bit_injector = BitInjector(d_model)

        # Router: softmax over (mlp, circuit); only the circuit share is
        # used to gate the injected delta (see forward()).
        self.router = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Linear(64, 2),
            nn.Softmax(dim=-1)
        )

        # Operation selector (add, sub, compare, passthrough). Currently
        # unused by the forward pass — kept for checkpoint compatibility
        # and future multi-op support.
        self.op_selector = nn.Sequential(
            nn.Linear(d_model, 32),
            nn.ReLU(),
            nn.Linear(32, 4),
            nn.Softmax(dim=-1)
        )

    def _compute_flags(self, result_bits: torch.Tensor, carry: torch.Tensor) -> torch.Tensor:
        """Build an 8-slot flag vector: [zero, negative, carry, 0, 0, 0, 0, 0]."""
        batch_shape = result_bits.shape[:-1]
        flags = torch.zeros(*batch_shape, 8, device=result_bits.device)
        flags[..., 0] = (result_bits.sum(dim=-1) == 0).float()  # zero: all bits 0
        flags[..., 1] = result_bits[..., 7]                      # negative: MSB (two's complement view)
        flags[..., 2] = carry                                    # carry out of bit 7
        return flags

    def _circuit_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the hidden state through the threshold circuits.

        MVP: only 8-bit addition is wired up; subtraction/comparison and the
        op_selector dispatch are future work.
        """
        a_bits, b_bits = self.bit_extractor(x)
        add_result, add_carry = self.circuits.add_8bit(a_bits, b_bits)
        add_flags = self._compute_flags(add_result, add_carry)
        return self.bit_injector(add_result, add_flags)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [batch, seq_len, d_model] -> [batch, seq_len, d_model]."""
        # Original SwiGLU MLP path.
        mlp_out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        # Circuit path, gated by the router's circuit share. The MLP share
        # (router[..., 0]) is deliberately not applied: the pretrained path
        # always contributes at full strength and the circuit is additive.
        circuit_out = self._circuit_forward(x)
        circuit_weight = self.router(x)[..., 1:2]

        return mlp_out + circuit_weight * circuit_out
386
-
387
-
388
- # =============================================================================
389
- # MODEL SURGERY: Insert circuits into SmolLM2
390
- # =============================================================================
391
-
392
def augment_smollm2_with_circuits(
    model: AutoModelForCausalLM,
    circuit_path: str,
    layer_indices: Optional[list] = None,
    device: str = 'cpu'
) -> AutoModelForCausalLM:
    """Surgically insert circuit blocks into SmolLM2's MLP layers.

    Args:
        model: Pretrained SmolLM2 model.
        circuit_path: Path to neural_computer.safetensors.
        layer_indices: Which layers to augment (default: middle third).
        device: Device for circuit tensors.

    Returns:
        The same model object, with the selected MLPs replaced in place by
        CircuitAugmentedMLP (pretrained projection weights copied over) and
        all circuit buffers/parameters frozen.
    """
    config = model.config
    num_layers = config.num_hidden_layers

    # Default: augment the middle third of the stack.
    if layer_indices is None:
        start = num_layers // 3
        end = 2 * num_layers // 3
        layer_indices = list(range(start, end))

    print(f"Augmenting layers {layer_indices} with threshold circuits...")

    for idx in layer_indices:
        layer = model.model.layers[idx]
        old_mlp = layer.mlp

        new_mlp = CircuitAugmentedMLP(
            d_model=config.hidden_size,
            intermediate_size=config.intermediate_size,
            circuit_path=circuit_path,
            device=device
        )

        # Copy the pretrained projection weights into the augmented block.
        new_mlp.gate_proj.weight.data = old_mlp.gate_proj.weight.data.clone()
        new_mlp.up_proj.weight.data = old_mlp.up_proj.weight.data.clone()
        new_mlp.down_proj.weight.data = old_mlp.down_proj.weight.data.clone()

        layer.mlp = new_mlp

    # Freeze anything under a `circuits` submodule; interface layers
    # (extractor, injector, router, op_selector) stay trainable.
    for name, param in model.named_parameters():
        if 'circuits' in name:
            param.requires_grad = False

    print("Done. Circuit weights frozen, interfaces trainable.")

    return model
449
-
450
-
451
- # =============================================================================
452
- # TRAINING UTILITIES
453
- # =============================================================================
454
-
455
def generate_arithmetic_batch(batch_size: int, max_val: int = 255) -> Tuple[list, list]:
    """Create a batch of addition prompts and their mod-256 answers.

    Returns:
        (prompts, targets) — prompts like "12 + 34 =" and targets like
        " 46" (leading space so target concatenates cleanly after prompt).
    """
    prompts, targets = [], []
    for _ in range(batch_size):
        # Draw operands one at a time so the torch RNG stream is unchanged.
        lhs = torch.randint(0, max_val + 1, (1,)).item()
        rhs = torch.randint(0, max_val + 1, (1,)).item()
        prompts.append(f"{lhs} + {rhs} =")
        targets.append(f" {(lhs + rhs) % 256}")
    return prompts, targets
469
-
470
-
471
def evaluate_arithmetic(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    n_problems: int = 100,
    device: str = 'cpu'
) -> dict:
    """Evaluate the model on random 8-bit addition problems.

    Greedy-decodes "a + b =" prompts and parses the first integer after the
    last '=' in the decoded response.

    Args:
        model: Causal LM (assumed already on `device` — confirm at call site).
        tokenizer: Matching tokenizer.
        n_problems: Number of random problems to run.
        device: Device the prompt tensors are moved to.

    Returns:
        dict with 'accuracy', 'correct', 'total', and up to 10 'errors'
        (tuples of (a, b, expected, predicted-or-"parse_error")).
    """
    correct = 0
    total = 0
    errors = []

    model.eval()

    for _ in range(n_problems):
        a = torch.randint(0, 256, (1,)).item()
        b = torch.randint(0, 256, (1,)).item()
        expected = (a + b) % 256

        prompt = f"{a} + {b} ="
        inputs = tokenizer(prompt, return_tensors='pt').to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=10,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Parse the first number after the last '='. Previously a bare
        # `except:` swallowed everything including KeyboardInterrupt; only
        # parse failures should be treated as wrong answers.
        try:
            answer_part = response.split('=')[-1].strip()
            predicted = int(''.join(c for c in answer_part.split()[0] if c.isdigit()))

            if predicted == expected:
                correct += 1
            else:
                errors.append((a, b, expected, predicted))
        except (ValueError, IndexError):
            errors.append((a, b, expected, "parse_error"))

        total += 1

    return {
        # Guard n_problems == 0 against ZeroDivisionError.
        'accuracy': correct / total if total else 0.0,
        'correct': correct,
        'total': total,
        'errors': errors[:10]  # first 10 errors only
    }
524
-
525
-
526
- # =============================================================================
527
- # MAIN: Demo
528
- # =============================================================================
529
-
530
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Circuit-Augmented LLM Demo')
    parser.add_argument('--circuit-path', type=str,
                        default='./neural_computer.safetensors',
                        help='Path to circuit weights')
    parser.add_argument('--device', type=str, default='cpu',
                        help='Device (cpu or cuda)')
    parser.add_argument('--eval-only', action='store_true',
                        help='Only evaluate, do not augment')
    args = parser.parse_args()

    print("=" * 70)
    print(" CIRCUIT-AUGMENTED LLM")
    print("=" * 70)

    # Load tokenizer and model
    print("\n[1] Loading SmolLM2-360M...")
    model_id = "HuggingFaceTB/SmolLM2-360M"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
    # Fix: move the model to the requested device. Previously the model
    # stayed on CPU while evaluate_arithmetic moved inputs to args.device,
    # which crashes under --device cuda.
    model = model.to(args.device)

    print(f"    Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Baseline evaluation before any surgery.
    print("\n[2] Baseline arithmetic evaluation...")
    baseline = evaluate_arithmetic(model, tokenizer, n_problems=50, device=args.device)
    print(f"    Accuracy: {baseline['accuracy']*100:.1f}% ({baseline['correct']}/{baseline['total']})")
    if baseline['errors']:
        print(f"    Sample errors:")
        for a, b, exp, got in baseline['errors'][:5]:
            print(f"      {a} + {b} = {exp}, model said {got}")

    if args.eval_only:
        print("\nDone (eval only mode).")
        raise SystemExit(0)

    # Augment with circuits
    print(f"\n[3] Augmenting with threshold circuits...")
    print(f"    Circuit path: {args.circuit_path}")
    model = augment_smollm2_with_circuits(
        model,
        args.circuit_path,
        device=args.device
    )

    new_params = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"    Total parameters: {new_params:,}")
    print(f"    Trainable parameters: {trainable:,}")

    # Sanity-check circuit execution directly on a few known sums.
    print("\n[4] Testing circuit execution...")
    circuit_exec = CircuitExecutor(args.circuit_path, args.device)

    test_cases = [(127, 128), (255, 1), (0, 0), (100, 55)]
    for a, b in test_cases:
        # Convert to bits (LSB first)
        a_bits = torch.tensor([(a >> i) & 1 for i in range(8)], dtype=torch.float32)
        b_bits = torch.tensor([(b >> i) & 1 for i in range(8)], dtype=torch.float32)

        result_bits, carry = circuit_exec.add_8bit(
            a_bits.unsqueeze(0),
            b_bits.unsqueeze(0)
        )

        # Convert result bits back to int (LSB first).
        result = sum(int(result_bits[0, i].item()) * (2**i) for i in range(8))
        expected = (a + b) % 256

        status = "OK" if result == expected else "FAIL"
        print(f"    {a} + {b} = {result} (expected {expected}) [{status}]")

    print("\n[5] Model ready for fine-tuning.")
    print("    Next: Train interface layers on arithmetic examples.")
    print("=" * 70)