CharlesCNorton committed
Commit a6b5f0c · 1 Parent(s): a47c121

Refactor LLM integration: consolidate training scripts, add extraction strategies

TRAINING SCRIPT CONSOLIDATION:
- Merge train_passthrough.py, train_passthrough_router.py, train_llm.py into unified train.py
- Remove redundant 5-mode system (router, interface, embeddings, llm, enhanced)
- New 3-mode system: router (sanity), interface (sanity), llm (real training)
- Delete orphaned trained_passthrough_router.pt checkpoint
- Move outputs to trained/ subfolder (router.pt, interface.pt, llm.pt)

MODEL ARCHITECTURE UPDATES (model.py):
- Add ArithmeticModel: unified LLM + extractor + frozen circuits
- Add Extractor: attention pooling + per-bit extraction networks
- Add PositionExtractor: position-specific extraction from token positions
- Add DigitExtractor: predict digits (0-9) then convert to bits
- Add AttentionPooling: learnable CLS token attention over sequence
- Add MultiHeadBitExtractor: 8 separate networks for 8 bits
- Add HiddenStateExtractor: simple MLP-based bit extraction
- Remove EmbeddingArithmeticModel (mean pooling failed, ~33% accuracy plateau)
- Remove AugmentedArithmeticModel (merged into ArithmeticModel)

NEW TRAINING FEATURES (baked-in):
- Curriculum learning: 0-9 (epochs 0-20%) -> 0-99 (20-50%) -> 0-255 (50-100%)
- Loss reweighting: 2x multiplier for a/b bit losses (extraction is bottleneck)
- Per-batch progress reporting every 5 batches
- Per-epoch VRAM and timing stats
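The loss reweighting above can be sketched as follows. This is a minimal illustration, not the script's actual code (train.py's llm-mode loss is not shown in this excerpt); the helper name and the use of BCE-with-logits are assumptions, with the 2x multiplier applied to the operand-bit terms as the commit describes.

```python
import math

import torch
import torch.nn.functional as F


def weighted_bit_loss(result_logits: torch.Tensor, result_target: torch.Tensor,
                      a_logits: torch.Tensor, a_target: torch.Tensor,
                      b_logits: torch.Tensor, b_target: torch.Tensor,
                      ab_weight: float = 2.0) -> torch.Tensor:
    """BCE on the circuit result bits plus up-weighted BCE on the extracted
    operand bits (extraction is the bottleneck, hence the 2x multiplier)."""
    result_loss = F.binary_cross_entropy_with_logits(result_logits, result_target)
    a_loss = F.binary_cross_entropy_with_logits(a_logits, a_target)
    b_loss = F.binary_cross_entropy_with_logits(b_logits, b_target)
    return result_loss + ab_weight * (a_loss + b_loss)


# With zero logits and all-ones targets each BCE term equals ln(2),
# so the weighted total is (1 + 2*2) * ln(2).
logits = torch.zeros(4, 8)
targets = torch.ones(4, 8)
total = weighted_bit_loss(logits, targets, logits, targets, logits, targets)
assert abs(total.item() - 5 * math.log(2)) < 1e-5
```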

NEW CLI ARGUMENTS (--mode llm):
- --unfreeze_layers N: fine-tune top N transformer layers (default 0 = frozen)
- --extract_layer N: extract from layer N (-1 = last, try 12 for middle)
- --position_extract: use position-specific extraction instead of pooling
- --digit_pred: predict digits instead of bits (aligns with tokenization)

RATIONALE:
- Embeddings mode removed: mean pooling loses positional info, can't distinguish "47" from "74"
- Operation classification works (97-100%), bit extraction is the bottleneck (~33% accuracy)
- Position-specific and digit-level extraction may better align with LLM representations
- Curriculum learning helps model learn simpler cases before harder ones
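The mean-pooling failure mode above is easy to demonstrate: averaging token embeddings is permutation-invariant, so any two orderings of the same tokens pool to the same vector. The toy embeddings below are random stand-ins, not actual SmolLM2 hidden states.

```python
import torch

torch.manual_seed(0)
# Stand-in embeddings for digit tokens "4" and "7".
d4, d7 = torch.randn(8), torch.randn(8)

pooled_47 = torch.stack([d4, d7]).mean(dim=0)  # "47"
pooled_74 = torch.stack([d7, d4]).mean(dim=0)  # "74"

# Mean pooling discards token order: both numbers collapse to one vector,
# so no downstream extractor can tell "47" from "74".
assert torch.equal(pooled_47, pooled_74)
```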

USAGE:
python train.py --mode llm --epochs 100 # baseline
python train.py --mode llm --position_extract # position-specific
python train.py --mode llm --digit_pred # digit prediction
python train.py --mode llm --extract_layer 12 # middle layer
python train.py --mode llm --unfreeze_layers 4 # fine-tune LLM

llm_integration/model.py CHANGED
@@ -1,6 +1,7 @@
 """
 Trainable interface layers for frozen threshold circuits.
 BitEncoder, OpRouter, BitDecoder wrap the frozen circuits.
+HiddenStateExtractor and AugmentedArithmeticModel for LLM integration.
 """
 
 import torch
@@ -8,6 +9,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 from circuits import FrozenThresholdCircuits, heaviside_ste
 
+MODEL_ID = 'HuggingFaceTB/SmolLM2-360M-Instruct'
+OPERATIONS = ['add', 'sub', 'mul', 'gt', 'lt', 'eq']
+OP_SYMBOLS = {'add': '+', 'sub': '-', 'mul': '*', 'gt': '>', 'lt': '<', 'eq': '=='}
+
 
 class BitEncoder(nn.Module):
     """
@@ -50,12 +55,6 @@ class OpRouter(nn.Module):
     """
 
     def __init__(self, input_dim: int = 16 + 6, hidden_dim: int = 32, n_ops: int = 6):
-        """
-        Args:
-            input_dim: Input dimension
-            hidden_dim: Hidden layer dimension
-            n_ops: Number of operations to route between
-        """
         super().__init__()
         self.net = nn.Sequential(
             nn.Linear(input_dim, hidden_dim),
@@ -83,21 +82,10 @@ class BitDecoder(nn.Module):
     """
 
     def __init__(self, output_dim: int = 8):
-        """
-        Args:
-            output_dim: Output dimension (8 bits for result)
-        """
         super().__init__()
         self.output_dim = output_dim
 
     def forward(self, result_bits: torch.Tensor) -> torch.Tensor:
-        """
-        Args:
-            result_bits: [batch, 8] result bits from circuits
-
-        Returns:
-            output: [batch, 8] processed output
-        """
         return result_bits
 
 
@@ -149,15 +137,6 @@ class ThresholdALU(nn.Module):
                 op_onehot: torch.Tensor) -> torch.Tensor:
         """
         Direct forward through circuits (bypass encoder/router for testing).
-        Uses ground truth bits and operation directly.
-
-        Args:
-            a_bits: [batch, 8] operand A bits
-            b_bits: [batch, 8] operand B bits
-            op_onehot: [batch, 6] one-hot operation
-
-        Returns:
-            result_bits: [batch, 8] output bits
         """
         return self.circuits(a_bits, b_bits, op_onehot)
 
@@ -175,10 +154,496 @@ class DirectCircuitModel(nn.Module):
 
     def forward(self, a_bits: torch.Tensor, b_bits: torch.Tensor,
                 op_onehot: torch.Tensor) -> torch.Tensor:
-        """Direct circuit execution."""
         return self.circuits(a_bits, b_bits, op_onehot)
 
 
+class HiddenStateExtractor(nn.Module):
+    """
+    Extracts operands and operation from LLM hidden states.
+    This is the hard part - must learn to parse numbers from embeddings.
+    """
+
+    def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256):
+        super().__init__()
+
+        self.a_extractor = nn.Sequential(
+            nn.Linear(hidden_dim, intermediate_dim),
+            nn.GELU(),
+            nn.Linear(intermediate_dim, 8),
+        )
+
+        self.b_extractor = nn.Sequential(
+            nn.Linear(hidden_dim, intermediate_dim),
+            nn.GELU(),
+            nn.Linear(intermediate_dim, 8),
+        )
+
+        self.op_router = nn.Sequential(
+            nn.Linear(hidden_dim, intermediate_dim),
+            nn.GELU(),
+            nn.Linear(intermediate_dim, len(OPERATIONS)),
+        )
+
+    def forward(self, hidden_states: torch.Tensor):
+        """
+        Args:
+            hidden_states: [batch, hidden_dim] from LLM
+
+        Returns:
+            a_bits: [batch, 8]
+            b_bits: [batch, 8]
+            op_logits: [batch, 6]
+        """
+        a_logits = self.a_extractor(hidden_states)
+        b_logits = self.b_extractor(hidden_states)
+        op_logits = self.op_router(hidden_states)
+
+        a_soft = torch.sigmoid(a_logits)
+        b_soft = torch.sigmoid(b_logits)
+
+        a_hard = heaviside_ste(a_logits)
+        b_hard = heaviside_ste(b_logits)
+
+        a_bits = a_hard - a_soft.detach() + a_soft
+        b_bits = b_hard - b_soft.detach() + b_soft
+
+        return a_bits, b_bits, op_logits
+
+
+class AttentionPooling(nn.Module):
+    """
+    Learnable attention pooling over sequence positions.
+    Replaces mean pooling - learns which tokens matter for extraction.
+    """
+
+    def __init__(self, hidden_dim: int = 960, num_heads: int = 4):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = hidden_dim // num_heads
+
+        self.query = nn.Linear(hidden_dim, hidden_dim)
+        self.key = nn.Linear(hidden_dim, hidden_dim)
+        self.value = nn.Linear(hidden_dim, hidden_dim)
+        self.out_proj = nn.Linear(hidden_dim, hidden_dim)
+
+        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
+
+    def forward(self, embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            embeddings: [batch, seq_len, hidden_dim]
+            mask: [batch, seq_len] attention mask (1 = attend, 0 = ignore)
+
+        Returns:
+            pooled: [batch, hidden_dim]
+        """
+        batch_size, seq_len, hidden_dim = embeddings.shape
+
+        cls_expanded = self.cls_token.expand(batch_size, -1, -1)
+        embeddings = torch.cat([cls_expanded, embeddings], dim=1)
+
+        cls_mask = torch.ones(batch_size, 1, device=mask.device)
+        mask = torch.cat([cls_mask, mask], dim=1)
+
+        Q = self.query(embeddings[:, :1, :])
+        K = self.key(embeddings)
+        V = self.value(embeddings)
+
+        Q = Q.view(batch_size, 1, self.num_heads, self.head_dim).transpose(1, 2)
+        K = K.view(batch_size, seq_len + 1, self.num_heads, self.head_dim).transpose(1, 2)
+        V = V.view(batch_size, seq_len + 1, self.num_heads, self.head_dim).transpose(1, 2)
+
+        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
+
+        mask_expanded = mask.unsqueeze(1).unsqueeze(2)
+        scores = scores.masked_fill(mask_expanded == 0, -1e9)
+
+        attn_weights = torch.softmax(scores, dim=-1)
+        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
+
+        context = torch.matmul(attn_weights, V)
+        context = context.transpose(1, 2).contiguous().view(batch_size, 1, hidden_dim)
+
+        pooled = self.out_proj(context).squeeze(1)
+        pooled = torch.nan_to_num(pooled, nan=0.0)
+
+        return pooled
+
+
+class MultiHeadBitExtractor(nn.Module):
+    """
+    8 separate extractors for 8 bits - each bit gets its own specialized network.
+    More expressive than single MLP predicting all 8 bits at once.
+    """
+
+    def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 128):
+        super().__init__()
+
+        self.bit_extractors = nn.ModuleList([
+            nn.Sequential(
+                nn.Linear(hidden_dim, intermediate_dim),
+                nn.GELU(),
+                nn.Linear(intermediate_dim, 1),
+            )
+            for _ in range(8)
+        ])
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            hidden_states: [batch, hidden_dim]
+
+        Returns:
+            bits: [batch, 8] - one bit from each extractor
+        """
+        hidden_states = torch.nan_to_num(hidden_states, nan=0.0)
+
+        bit_logits = [extractor(hidden_states) for extractor in self.bit_extractors]
+        logits = torch.cat(bit_logits, dim=-1)
+        logits = torch.clamp(logits, -20, 20)
+
+        soft = torch.sigmoid(logits)
+        hard = heaviside_ste(logits)
+        bits = hard - soft.detach() + soft
+
+        return bits, logits
+
+
+class Extractor(nn.Module):
+    """
+    Extracts operands and operation from LLM hidden states.
+    Uses attention pooling and per-bit extraction networks.
+    """
+
+    def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256, num_heads: int = 4):
+        super().__init__()
+
+        self.attention_pool = AttentionPooling(hidden_dim, num_heads)
+
+        self.a_extractor = MultiHeadBitExtractor(hidden_dim, intermediate_dim // 2)
+        self.b_extractor = MultiHeadBitExtractor(hidden_dim, intermediate_dim // 2)
+
+        self.op_router = nn.Sequential(
+            nn.Linear(hidden_dim, intermediate_dim),
+            nn.GELU(),
+            nn.Linear(intermediate_dim, len(OPERATIONS)),
+        )
+
+    def forward(self, embeddings: torch.Tensor, mask: torch.Tensor):
+        """
+        Args:
+            embeddings: [batch, seq_len, hidden_dim]
+            mask: [batch, seq_len]
+
+        Returns:
+            a_bits: [batch, 8]
+            b_bits: [batch, 8]
+            op_logits: [batch, 6]
+        """
+        pooled = self.attention_pool(embeddings, mask)
+
+        a_bits, _ = self.a_extractor(pooled)
+        b_bits, _ = self.b_extractor(pooled)
+        op_logits = self.op_router(pooled)
+
+        return a_bits, b_bits, op_logits
+
+
+class PositionExtractor(nn.Module):
+    """
+    Position-specific extraction.
+    Extracts operand A from first token positions, operand B from later positions.
+    For "47 + 86": positions 0-2 for A, position 3-4 for op, positions 5-7 for B.
+    """
+
+    def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256):
+        super().__init__()
+
+        self.a_extractor = nn.Sequential(
+            nn.Linear(hidden_dim * 3, intermediate_dim),
+            nn.GELU(),
+            nn.Linear(intermediate_dim, 8),
+        )
+
+        self.b_extractor = nn.Sequential(
+            nn.Linear(hidden_dim * 3, intermediate_dim),
+            nn.GELU(),
+            nn.Linear(intermediate_dim, 8),
+        )
+
+        self.op_router = nn.Sequential(
+            nn.Linear(hidden_dim * 2, intermediate_dim),
+            nn.GELU(),
+            nn.Linear(intermediate_dim, len(OPERATIONS)),
+        )
+
+    def forward(self, hidden: torch.Tensor, mask: torch.Tensor):
+        """
+        Args:
+            hidden: [batch, seq_len, hidden_dim]
+            mask: [batch, seq_len]
+
+        Returns:
+            a_bits, b_bits, op_logits
+        """
+        batch_size, seq_len, hidden_dim = hidden.shape
+
+        seq_lens = mask.sum(dim=1).long()
+
+        a_features = []
+        b_features = []
+        op_features = []
+
+        for i in range(batch_size):
+            slen = seq_lens[i].item()
+            start = seq_len - slen
+
+            a_pos = hidden[i, start:start+3, :].reshape(-1)
+            if a_pos.shape[0] < hidden_dim * 3:
+                a_pos = F.pad(a_pos, (0, hidden_dim * 3 - a_pos.shape[0]))
+
+            op_pos = hidden[i, start+3:start+5, :].reshape(-1)
+            if op_pos.shape[0] < hidden_dim * 2:
+                op_pos = F.pad(op_pos, (0, hidden_dim * 2 - op_pos.shape[0]))
+
+            b_pos = hidden[i, start+5:start+8, :].reshape(-1)
+            if b_pos.shape[0] < hidden_dim * 3:
+                b_pos = F.pad(b_pos, (0, hidden_dim * 3 - b_pos.shape[0]))
+
+            a_features.append(a_pos)
+            b_features.append(b_pos)
+            op_features.append(op_pos)
+
+        a_features = torch.stack(a_features)
+        b_features = torch.stack(b_features)
+        op_features = torch.stack(op_features)
+
+        a_logits = self.a_extractor(a_features)
+        b_logits = self.b_extractor(b_features)
+        op_logits = self.op_router(op_features)
+
+        a_soft = torch.sigmoid(a_logits)
+        b_soft = torch.sigmoid(b_logits)
+        a_hard = heaviside_ste(a_logits)
+        b_hard = heaviside_ste(b_logits)
+        a_bits = a_hard - a_soft.detach() + a_soft
+        b_bits = b_hard - b_soft.detach() + b_soft
+
+        return a_bits, b_bits, op_logits
+
+
+class DigitExtractor(nn.Module):
+    """
+    Digit-level extraction: predicts digits (0-9) then converts to bits.
+    More aligned with tokenization.
+    """
+
+    def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256, num_heads: int = 4):
+        super().__init__()
+
+        self.attention_pool = AttentionPooling(hidden_dim, num_heads)
+
+        self.a_digit_pred = nn.Sequential(
+            nn.Linear(hidden_dim, intermediate_dim),
+            nn.GELU(),
+            nn.Linear(intermediate_dim, 3 * 10),
+        )
+
+        self.b_digit_pred = nn.Sequential(
+            nn.Linear(hidden_dim, intermediate_dim),
+            nn.GELU(),
+            nn.Linear(intermediate_dim, 3 * 10),
+        )
+
+        self.op_router = nn.Sequential(
+            nn.Linear(hidden_dim, intermediate_dim),
+            nn.GELU(),
+            nn.Linear(intermediate_dim, len(OPERATIONS)),
+        )
+
+    def digits_to_bits(self, digit_logits: torch.Tensor) -> torch.Tensor:
+        """
+        Convert 3-digit predictions to 8-bit representation.
+        digit_logits: [batch, 30] (3 digits * 10 classes each)
+        Returns: [batch, 8] bits
+        """
+        batch_size = digit_logits.shape[0]
+
+        logits = digit_logits.view(batch_size, 3, 10)
+        probs = torch.softmax(logits, dim=-1)
+
+        digit_values = torch.arange(10, device=digit_logits.device).float()
+        soft_digits = (probs * digit_values).sum(dim=-1)
+
+        hundreds = soft_digits[:, 0]
+        tens = soft_digits[:, 1]
+        ones = soft_digits[:, 2]
+
+        value = hundreds * 100 + tens * 10 + ones
+        value = torch.clamp(value, 0, 255)
+
+        bits = []
+        for i in range(7, -1, -1):
+            bit = torch.fmod(torch.floor(value / (2 ** i)), 2)
+            bits.append(bit)
+
+        return torch.stack(bits, dim=-1)
+
+    def forward(self, hidden: torch.Tensor, mask: torch.Tensor):
+        """
+        Returns:
+            a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits
+        """
+        pooled = self.attention_pool(hidden, mask)
+
+        a_digit_logits = self.a_digit_pred(pooled)
+        b_digit_logits = self.b_digit_pred(pooled)
+        op_logits = self.op_router(pooled)
+
+        a_bits = self.digits_to_bits(a_digit_logits)
+        b_bits = self.digits_to_bits(b_digit_logits)
+
+        return a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits
+
+
+class ArithmeticModel(nn.Module):
+    """
+    LLM + extractor + frozen threshold circuits.
+    Optionally unfreeze top N transformer layers with --unfreeze_layers.
+    """
+
+    def __init__(self, device: str = 'cuda', unfreeze_layers: int = 0,
+                 extract_layer: int = -1, position_extract: bool = False,
+                 digit_pred: bool = False):
+        super().__init__()
+        self.device = device
+        self.unfreeze_layers = unfreeze_layers
+        self.extract_layer = extract_layer
+        self.position_extract = position_extract
+        self.digit_pred = digit_pred
+
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        print("[1/4] Loading tokenizer...", flush=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+        self.tokenizer.padding_side = 'left'
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+        print("  Tokenizer loaded.", flush=True)
+
+        print("[2/4] Loading SmolLM2-360M...", flush=True)
+        self.llm = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            torch_dtype=torch.float16,
+            device_map=device,
+            output_hidden_states=True
+        )
+
+        for param in self.llm.parameters():
+            param.requires_grad = False
+
+        if unfreeze_layers > 0:
+            num_layers = len(self.llm.model.layers)
+            layers_to_unfreeze = list(range(num_layers - unfreeze_layers, num_layers))
+            print(f"  Unfreezing layers {layers_to_unfreeze}...", flush=True)
+            for layer_idx in layers_to_unfreeze:
+                for param in self.llm.model.layers[layer_idx].parameters():
+                    param.requires_grad = True
+
+        hidden_dim = self.llm.config.hidden_size
+        llm_params = sum(p.numel() for p in self.llm.parameters())
+        trainable_llm = sum(p.numel() for p in self.llm.parameters() if p.requires_grad)
+        print(f"  LLM loaded. Hidden dim: {hidden_dim}", flush=True)
+        print(f"  LLM params: {llm_params:,} total, {trainable_llm:,} trainable", flush=True)
+
+        print("[3/4] Loading threshold circuits...", flush=True)
+        self.circuits = FrozenThresholdCircuits(device=device)
+        print(f"  Circuits loaded. {len(self.circuits.weights)} tensors", flush=True)
+
+        print("[4/4] Initializing extractor...", flush=True)
+        if position_extract:
+            print("  Using position-specific extraction", flush=True)
+            self.extractor = PositionExtractor(
+                hidden_dim=hidden_dim,
+                intermediate_dim=256
+            ).to(device)
+        elif digit_pred:
+            print("  Using digit-level prediction", flush=True)
+            self.extractor = DigitExtractor(
+                hidden_dim=hidden_dim,
+                intermediate_dim=256,
+                num_heads=4
+            ).to(device)
+        else:
+            self.extractor = Extractor(
+                hidden_dim=hidden_dim,
+                intermediate_dim=256,
+                num_heads=4
+            ).to(device)
+
+        if extract_layer != -1:
+            print(f"  Extracting from layer {extract_layer}", flush=True)
+
+        trainable_ext = sum(p.numel() for p in self.extractor.parameters())
+        total_trainable = trainable_llm + trainable_ext
+        print(f"  Extractor params: {trainable_ext:,}", flush=True)
+        print(f"  Total trainable: {total_trainable:,}", flush=True)
+
+    def get_hidden_states(self, texts: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Get hidden states from specified layer."""
+        inputs = self.tokenizer(
+            texts,
+            return_tensors='pt',
+            padding=True,
+            truncation=True,
+            max_length=64
+        ).to(self.device)
+
+        if self.unfreeze_layers > 0:
+            outputs = self.llm(**inputs, output_hidden_states=True)
+        else:
+            with torch.no_grad():
+                outputs = self.llm(**inputs, output_hidden_states=True)
+
+        hidden = outputs.hidden_states[self.extract_layer].float()
+        mask = inputs.attention_mask.float()
+
+        return hidden, mask
+
+    def forward(self, texts: list[str]):
+        """
+        Full forward pass: text -> hidden states -> extractor -> circuits -> result
+
+        Returns:
+            result_bits, a_bits, b_bits, op_logits
+            If digit_pred: also returns a_digit_logits, b_digit_logits
+        """
+        hidden, mask = self.get_hidden_states(texts)
+
+        extractor_out = self.extractor(hidden, mask)
+
+        if self.digit_pred:
+            a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits = extractor_out
+        else:
+            a_bits, b_bits, op_logits = extractor_out
+            a_digit_logits, b_digit_logits = None, None
+
+        op_probs = torch.softmax(op_logits, dim=-1)
+
+        result_bits = self.circuits(a_bits, b_bits, op_probs)
+
+        if self.digit_pred:
+            return result_bits, a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits
+        return result_bits, a_bits, b_bits, op_logits
+
+    def trainable_parameters(self):
+        """Return all trainable parameters for optimizer."""
+        params = list(self.extractor.parameters())
+        if self.unfreeze_layers > 0:
+            params += [p for p in self.llm.parameters() if p.requires_grad]
+        return params
+
+
 if __name__ == "__main__":
     import sys
     sys.path.insert(0, '.')
llm_integration/train.py ADDED
@@ -0,0 +1,673 @@
+"""
+Unified training script for threshold circuit LLM integration.
+
+Modes:
+    --mode router    : Train only OpRouter with ground truth bits (sanity check)
+    --mode interface : Train BitEncoder + OpRouter with ground truth bits (sanity check)
+    --mode llm       : Train extractor with LLM hidden states (the real training)
+
+LLM mode options:
+    --unfreeze_layers N : Unfreeze top N transformer layers (default 0 = fully frozen)
+
+Hardware Profile (NVIDIA RTX 6000 Ada 48GB):
+    VRAM Scaling (unfreeze_layers=4):
+    batch_size | VRAM (MB) | %
+    -----------+-----------+------
+    512        | 5,784     | 11.8%
+    1,024      | 7,384     | 15.0%
+    4,096      | 16,534    | 33.6%
+    13,000     | 39,000    | 79.4%  <-- recommended for 80% target
+
+Examples:
+    python train.py --mode llm --epochs 100 --batch_size 256
+    python train.py --mode llm --epochs 100 --batch_size 4096 --unfreeze_layers 4
+"""
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import time
+import argparse
+import random
+
+from model import (
+    ThresholdALU, DirectCircuitModel, OpRouter,
+    ArithmeticModel, OPERATIONS, OP_SYMBOLS
+)
+from circuits import FrozenThresholdCircuits
+from fitness import generate_batch, compute_fitness, compute_loss
+
+DEVICE = 'cuda'
+
+
+def int_to_bits(val: int, device: str = 'cuda') -> torch.Tensor:
+    bits = torch.zeros(8, device=device)
+    for i in range(8):
+        bits[7-i] = (val >> i) & 1
+    return bits
+
+
+def bits_to_int(bits: torch.Tensor) -> int:
+    val = 0
+    for i in range(8):
+        if bits[i].item() > 0.5:
+            val += 1 << (7-i)
+    return val
+
+
+def generate_problem(max_val: int = 255):
+    """Generate a random arithmetic problem for LLM training."""
+    a = random.randint(0, max_val)
+    b = random.randint(0, max_val)
+    op = random.choice(OPERATIONS)
+
+    sym = OP_SYMBOLS[op]
+    text = f"{a} {sym} {b}"
+
+    if op == 'add':
+        result = (a + b) & 0xFF
+    elif op == 'sub':
+        result = (a - b) & 0xFF
+    elif op == 'mul':
+        result = (a * b) & 0xFF
+    elif op == 'gt':
+        result = 1 if a > b else 0
+    elif op == 'lt':
+        result = 1 if a < b else 0
+    elif op == 'eq':
+        result = 1 if a == b else 0
+
+    return text, a, b, op, result
+
+
+def get_curriculum_max(epoch: int, total_epochs: int) -> int:
+    """
+    Curriculum learning: start with small numbers, gradually increase.
+    Epoch 0-20%:   0-9   (single digit)
+    Epoch 20-50%:  0-99  (two digit)
+    Epoch 50-100%: 0-255 (full range)
+    """
+    progress = epoch / total_epochs
+    if progress < 0.2:
+        return 9
+    elif progress < 0.5:
+        return 99
+    else:
+        return 255
+
+
+def train_router(epochs: int = 100, batch_size: int = 256, lr: float = 1e-2, device: str = 'cuda'):
+    """Train only the router with ground truth bits."""
+    print("=" * 70)
+    print(" ROUTER-ONLY TRAINING (Ground Truth Bits)")
+    print("=" * 70)
+
+    circuits = FrozenThresholdCircuits(device=device)
+    router = OpRouter(input_dim=16 + 6, hidden_dim=64, n_ops=6).to(device)
+
+    print(f"\nRouter parameters: {sum(p.numel() for p in router.parameters()):,}")
+
+    def model_fn(a_bits, b_bits, op_onehot):
+        x = torch.cat([a_bits, b_bits, op_onehot], dim=-1)
+        op_weights = router(x)
+        return circuits(a_bits, b_bits, op_weights)
+
+    initial_fitness = compute_fitness(model_fn, n_samples=1000, device=device)
+    print(f"Initial fitness: {initial_fitness:.4f}")
+
+    optimizer = optim.AdamW(router.parameters(), lr=lr)
+    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+
+    print("\nTraining...")
+    print("-" * 70)
+
+    best_fitness = initial_fitness
+    start_time = time.perf_counter()
+
+    for epoch in range(epochs):
+        router.train()
+        epoch_loss = 0.0
+
+        for _ in range(100):
+            batch = generate_batch(batch_size, device)
+
+            optimizer.zero_grad()
+
+            x = torch.cat([batch['a_bits'], batch['b_bits'], batch['op_onehot']], dim=-1)
+            op_weights = router(x)
+            pred_bits = circuits(batch['a_bits'], batch['b_bits'], op_weights)
+
+            loss = compute_loss(pred_bits, batch['expected_bits'])
+            loss.backward()
+            optimizer.step()
+
+            epoch_loss += loss.item()
+
+        scheduler.step()
+
+        if (epoch + 1) % 10 == 0 or epoch == 0:
+            router.eval()
+            fitness, details = compute_fitness(model_fn, n_samples=2000, device=device, return_details=True)
+            elapsed = time.perf_counter() - start_time
+
+            if fitness > best_fitness:
+                best_fitness = fitness
+                marker = " *"
+            else:
+                marker = ""
+
+            print(f"Epoch {epoch+1:3d} | Loss: {epoch_loss/100:.4f} | "
+                  f"Fitness: {fitness:.4f}{marker} | Time: {elapsed:.1f}s")
+
+            if fitness >= 0.9999:
+                print("\n TARGET: 100% FITNESS ACHIEVED")
+                break
+
+    print("\n" + "=" * 70)
+    print(" RESULTS")
+    print("=" * 70)
+
+    router.eval()
+    final_fitness, details = compute_fitness(model_fn, n_samples=5000, device=device, return_details=True)
+
+    print(f"\nFinal fitness: {final_fitness:.4f}")
+    print(f"\nPer-operation:")
+    for op in OPERATIONS:
+        acc = details['by_op'][op]['accuracy']
+        print(f"  {op}: {acc:.4f}")
+
+    print(f"\nTotal time: {time.perf_counter() - start_time:.1f}s")
+
+    if final_fitness >= 0.99:
+        print("\nCONCLUSION: Router successfully learned operation dispatch.")
+        print("  With correct bit encoding, 100% is achievable.")
+
+    save_path = "D:/8bit-threshold-computer/llm_integration/trained/router.pt"
+    torch.save({
+        'router_state_dict': router.state_dict(),
+        'final_fitness': final_fitness,
+        'params': sum(p.numel() for p in router.parameters()),
+    }, save_path)
+    print(f"\nSaved trained router to: {save_path}")
+
+    return router, final_fitness
+
+
+def get_gpu_memory():
+    """Get GPU memory usage in MB."""
+    if torch.cuda.is_available():
+        return torch.cuda.memory_allocated() / 1024 / 1024, torch.cuda.max_memory_allocated() / 1024 / 1024
+    return 0, 0
+
+
+def train_interface(epochs: int = 200, batch_size: int = 512, lr: float = 1e-3,
+                    eval_interval: int = 10, device: str = 'cuda'):
+    """Train BitEncoder + OpRouter with ground truth bits."""
+    print("=" * 70)
+    print(" INTERFACE TRAINING (Encoder + Router)")
+    print("=" * 70)
+    print(f" Started at: {time.strftime('%H:%M:%S')}")
+
+    print("\n[1/4] Verifying frozen circuits...")
+    print("  Creating DirectCircuitModel...", end=" ", flush=True)
+    direct_model = DirectCircuitModel(device=device)
+    mem, max_mem = get_gpu_memory()
+    print(f"done. VRAM: {mem:.0f}MB")
+
+    def direct_fn(a, b, op):
+        return direct_model(a, b, op)
+
+    print("  Running fitness check (1000 samples)...", end=" ", flush=True)
+    circuit_fitness = compute_fitness(direct_fn, n_samples=1000, device=device)
+    print(f"done. Fitness: {circuit_fitness:.4f}")
+    if circuit_fitness < 0.999:
+        print("  ERROR: Circuits not achieving 100%. Aborting.")
+        return None, 0.0
+    print("  STATUS: PASS")
+
+    print("\n[2/4] Initializing model...")
+    print("  Creating ThresholdALU...", end=" ", flush=True)
+    model = ThresholdALU(device=device)
+    mem, max_mem = get_gpu_memory()
+    print(f"done. VRAM: {mem:.0f}MB")
+
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print(f"  Trainable parameters: {trainable_params:,}")
+
+    def model_fn(a, b, op):
+        return model(a, b, op)
+
+    print("  Running initial fitness check...", end=" ", flush=True)
+    initial_fitness = compute_fitness(model_fn, n_samples=1000, device=device)
+    print(f"done. Fitness: {initial_fitness:.4f}")
+
+    print("\n[3/4] Setting up training...")
+    print("  Creating optimizer...", end=" ", flush=True)
+    optimizer = optim.AdamW(model.parameters(), lr=lr)
+    print("done.")
+    print("  Creating scheduler...", end=" ", flush=True)
+    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+    print("done.")
+
+    print(f"  Config: lr={lr}, batch_size={batch_size}, epochs={epochs}")
+
+    print("\n[4/4] Training...")
+    print("  Generating first batch to warm up...", end=" ", flush=True)
+    warmup_batch = generate_batch(batch_size, device)
+    mem, max_mem = get_gpu_memory()
+    print(f"done. VRAM: {mem:.0f}MB (max: {max_mem:.0f}MB)")
+
+    print("-" * 70)
+
+    best_fitness = initial_fitness
+    start_time = time.perf_counter()
+    n_batches = 100
+
+    for epoch in range(epochs):
+        model.train()
+        epoch_loss = 0.0
+        epoch_start = time.perf_counter()
+
+        for batch_idx in range(n_batches):
+            batch = generate_batch(batch_size, device)
+
+            optimizer.zero_grad()
+
+            pred_bits = model(batch['a_bits'], batch['b_bits'], batch['op_onehot'])
+
+            loss = compute_loss(pred_bits, batch['expected_bits'])
279
+
280
+ loss.backward()
281
+ optimizer.step()
282
+
283
+ epoch_loss += loss.item()
284
+
285
+ if batch_idx == 0 and epoch == 0:
286
+ mem, max_mem = get_gpu_memory()
287
+ print(f" First forward/backward done. VRAM: {mem:.0f}MB (max: {max_mem:.0f}MB)")
288
+
289
+ if (batch_idx + 1) % 25 == 0:
290
+ avg_so_far = epoch_loss / (batch_idx + 1)
291
+ print(f" Epoch {epoch+1} batch {batch_idx+1}/{n_batches} | loss: {avg_so_far:.4f}", flush=True)
292
+
293
+ scheduler.step()
294
+
295
+ avg_loss = epoch_loss / n_batches
296
+ epoch_time = time.perf_counter() - epoch_start
297
+
298
+ if (epoch + 1) % 5 == 0 or epoch == 0: # Eval every 5 epochs
299
+ model.eval()
300
+ fitness, details = compute_fitness(
301
+ model_fn, n_samples=2000, device=device, return_details=True
302
+ )
303
+
304
+ elapsed = time.perf_counter() - start_time
305
+
306
+ if fitness > best_fitness:
307
+ best_fitness = fitness
308
+ marker = " *"
309
+ else:
310
+ marker = ""
311
+
312
+ mem, _ = get_gpu_memory()
313
+ print(f"Epoch {epoch+1:4d} | Loss: {avg_loss:.4f} | "
314
+ f"Fitness: {fitness:.4f}{marker} | "
315
+ f"LR: {scheduler.get_last_lr()[0]:.2e} | "
316
+ f"VRAM: {mem:.0f}MB | "
317
+ f"Time: {elapsed:.1f}s ({epoch_time:.1f}s/epoch)")
318
+
319
+ if fitness >= 0.9999:
320
+ print("\n" + "=" * 70)
321
+ print(" TARGET ACHIEVED: 100% FITNESS")
322
+ print("=" * 70)
323
+ break
324
+
325
+ print("\n" + "=" * 70)
326
+ print(" TRAINING COMPLETE")
327
+ print("=" * 70)
328
+
329
+ model.eval()
330
+ final_fitness, details = compute_fitness(
331
+ model_fn, n_samples=5000, device=device, return_details=True
332
+ )
333
+
334
+ print(f"\nFinal fitness: {final_fitness:.4f}")
335
+ print(f"Best fitness: {best_fitness:.4f}")
336
+ print(f"\nPer-operation breakdown:")
337
+ for op in OPERATIONS:
338
+ acc = details['by_op'][op]['accuracy']
339
+ print(f" {op:6}: {acc:.4f}")
340
+
341
+ print(f"\nTotal time: {time.perf_counter() - start_time:.1f}s")
342
+
343
+ save_path = "D:/8bit-threshold-computer/llm_integration/trained/interface.pt"
344
+ torch.save({
345
+ 'encoder_state_dict': model.encoder.state_dict(),
346
+ 'router_state_dict': model.router.state_dict(),
347
+ 'final_fitness': final_fitness,
348
+ 'best_fitness': best_fitness,
349
+ }, save_path)
350
+ print(f"\nSaved trained model to: {save_path}")
351
+
352
+ return model, final_fitness
353
+
354
+
355
+ def compute_llm_loss(pred_bits, a_bits, b_bits, op_logits,
356
+ target_result, target_a, target_b, target_op_idx,
357
+ bit_weight: float = 2.0):
358
+ """
359
+ Multi-component loss for LLM training.
360
+ bit_weight: multiplier for a/b bit losses (default 2x since extraction is the bottleneck)
361
+ """
362
+ result_loss = nn.functional.binary_cross_entropy_with_logits(
363
+ pred_bits, target_result
364
+ )
365
+
366
+ a_bits_safe = torch.clamp(a_bits, 0.0, 1.0)
367
+ b_bits_safe = torch.clamp(b_bits, 0.0, 1.0)
368
+ a_bits_safe = torch.nan_to_num(a_bits_safe, nan=0.5, posinf=1.0, neginf=0.0)
369
+ b_bits_safe = torch.nan_to_num(b_bits_safe, nan=0.5, posinf=1.0, neginf=0.0)
370
+
371
+ a_loss = nn.functional.binary_cross_entropy(
372
+ torch.clamp(a_bits_safe, 1e-6, 1-1e-6), target_a
373
+ )
374
+ b_loss = nn.functional.binary_cross_entropy(
375
+ torch.clamp(b_bits_safe, 1e-6, 1-1e-6), target_b
376
+ )
377
+
378
+ op_loss = nn.functional.cross_entropy(op_logits, target_op_idx)
379
+
380
+ total = result_loss + bit_weight * a_loss + bit_weight * b_loss + op_loss
381
+ total = torch.nan_to_num(total, nan=10.0, posinf=10.0, neginf=0.0)
382
+
383
+ return total, {
384
+ 'result': result_loss.item() if not torch.isnan(result_loss) else 10.0,
385
+ 'a': a_loss.item() if not torch.isnan(a_loss) else 10.0,
386
+ 'b': b_loss.item() if not torch.isnan(b_loss) else 10.0,
387
+ 'op': op_loss.item() if not torch.isnan(op_loss) else 10.0
388
+ }
389
+
390
+
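The effect of `bit_weight` is easy to see with scalar binary cross-entropy. A minimal pure-Python sketch with hypothetical per-component loss values (not taken from any real run), showing how the 2x multiplier makes the operand-extraction terms dominate the total:

```python
import math

def bce(p, y, eps=1e-6):
    # Scalar binary cross-entropy with the same clamping as the training loss.
    p = min(max(p, eps), 1 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# Hypothetical per-component losses for one sample:
result_loss = bce(0.8, 1.0)   # result bit fairly confident
a_loss = bce(0.6, 1.0)        # operand extraction less confident
b_loss = bce(0.55, 0.0)
op_loss = 0.05                # operation routing is nearly solved

bit_weight = 2.0
total = result_loss + bit_weight * a_loss + bit_weight * b_loss + op_loss
# The weighted a/b terms contribute most of the total (and of the gradient),
# matching the observation that operand extraction, not routing, is the bottleneck.
print(total)
```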
+ def evaluate_llm(model, n_samples: int = 500):
+     """Evaluate LLM model on random problems."""
+     model.extractor.eval()
+     correct = 0
+     op_correct = 0
+
+     for _ in range(n_samples):
+         text, a, b, op, expected = generate_problem()
+
+         with torch.no_grad():
+             result_bits, a_bits, b_bits, op_logits = model([text])
+
+         pred_result = bits_to_int(result_bits[0])
+         pred_op = OPERATIONS[op_logits[0].argmax().item()]
+
+         if pred_result == expected:
+             correct += 1
+         if pred_op == op:
+             op_correct += 1
+
+     model.extractor.train()
+     return correct / n_samples, op_correct / n_samples
+
+
+ def train_llm(epochs: int = 100, batch_size: int = 256, lr: float = 3e-4,
+               unfreeze_layers: int = 0, extract_layer: int = -1,
+               position_extract: bool = False, digit_pred: bool = False,
+               device: str = 'cuda'):
+     """
+     Train extractor with LLM hidden states.
+
+     Args:
+         unfreeze_layers: Number of top transformer layers to unfreeze (0 = fully frozen)
+         extract_layer: Which layer to extract from (-1 = last)
+         position_extract: Use position-specific extraction
+         digit_pred: Predict digits instead of bits
+     """
+     print("=" * 70)
+     print(" LLM TRAINING")
+     if unfreeze_layers > 0:
+         print(f" {unfreeze_layers} transformer layers unfrozen")
+     else:
+         print(" LLM frozen")
+     if extract_layer != -1:
+         print(f" Extracting from layer {extract_layer}")
+     if position_extract:
+         print(" Position-specific extraction")
+     if digit_pred:
+         print(" Digit-level prediction")
+     print("=" * 70)
+
+     print("\nInitializing model...")
+     model = ArithmeticModel(
+         device=device,
+         unfreeze_layers=unfreeze_layers,
+         extract_layer=extract_layer,
+         position_extract=position_extract,
+         digit_pred=digit_pred
+     )
+
+     optimizer = optim.AdamW(model.trainable_parameters(), lr=lr)
+     scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+
+     print(f"\nTraining config:")
+     print(f" Epochs: {epochs}")
+     print(f" Batch size: {batch_size}")
+     print(f" Learning rate: {lr}")
+     print(f" Unfreeze layers: {unfreeze_layers}")
+     print(f" Samples/epoch: {batch_size * 20}")
+
+     print(f"\nInitial evaluation (200 samples)...")
+     acc, op_acc = evaluate_llm(model, n_samples=200)
+     print(f" Accuracy: {acc:.4f}, Op accuracy: {op_acc:.4f}")
+
+     print(f"\nStarting training...")
+     print("-" * 70)
+
+     best_acc = acc
+     start_time = time.perf_counter()
+
+     for epoch in range(epochs):
+         model.extractor.train()
+         if unfreeze_layers > 0:
+             model.llm.train()
+
+         max_val = get_curriculum_max(epoch, epochs)
+
+         epoch_loss = 0
+         epoch_losses = {'result': 0, 'a': 0, 'b': 0, 'op': 0}
+         n_batches = 20
+         epoch_start = time.perf_counter()
+
+         for batch_idx in range(n_batches):
+             batch_texts = []
+             batch_a = []
+             batch_b = []
+             batch_op = []
+             batch_result = []
+
+             for _ in range(batch_size):
+                 text, a, b, op, result = generate_problem(max_val)
+                 batch_texts.append(text)
+                 batch_a.append(int_to_bits(a, device))
+                 batch_b.append(int_to_bits(b, device))
+                 batch_op.append(OPERATIONS.index(op))
+                 batch_result.append(int_to_bits(result, device))
+
+             target_a = torch.stack(batch_a)
+             target_b = torch.stack(batch_b)
+             target_op = torch.tensor(batch_op, device=device)
+             target_result = torch.stack(batch_result)
+
+             optimizer.zero_grad()
+
+             pred_bits, a_bits, b_bits, op_logits = model(batch_texts)
+
+             loss, losses = compute_llm_loss(
+                 pred_bits, a_bits, b_bits, op_logits,
+                 target_result, target_a, target_b, target_op
+             )
+
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.trainable_parameters(), 1.0)
+             optimizer.step()
+
+             epoch_loss += loss.item()
+             for k in epoch_losses:
+                 epoch_losses[k] += losses[k]
+
+             if (batch_idx + 1) % 5 == 0:
+                 avg_so_far = epoch_loss / (batch_idx + 1)
+                 print(f" Epoch {epoch+1} batch {batch_idx+1}/{n_batches} | loss: {avg_so_far:.4f}", flush=True)
+
+         epoch_time = time.perf_counter() - epoch_start
+         scheduler.step()
+
+         avg_loss = epoch_loss / n_batches
+         for k in epoch_losses:
+             epoch_losses[k] /= n_batches
+
+         acc, op_acc = evaluate_llm(model, n_samples=300)
+         elapsed = time.perf_counter() - start_time
+
+         marker = " *" if acc > best_acc else ""
+         if acc > best_acc:
+             best_acc = acc
+
+         mem, _ = get_gpu_memory()
+         print(f"Epoch {epoch+1:3d} | Loss: {avg_loss:.4f} | "
+               f"Acc: {acc:.4f}{marker} | OpAcc: {op_acc:.4f} | "
+               f"Range: 0-{max_val} | VRAM: {mem:.0f}MB | Time: {elapsed:.0f}s")
+         print(f" Losses - result:{epoch_losses['result']:.4f} "
+               f"a:{epoch_losses['a']:.4f} b:{epoch_losses['b']:.4f} "
+               f"op:{epoch_losses['op']:.4f}")
+
+     print("\n" + "=" * 70)
+     print(" FINAL EVALUATION")
+     print("=" * 70)
+
+     acc, op_acc = evaluate_llm(model, n_samples=1000)
+     print(f"Final accuracy: {acc:.4f}")
+     print(f"Final op accuracy: {op_acc:.4f}")
+     print(f"Best accuracy: {best_acc:.4f}")
+
+     print("\nSample predictions:")
+     for _ in range(10):
+         text, a, b, op, expected = generate_problem()
+         with torch.no_grad():
+             result_bits, a_bits, b_bits, op_logits = model([text])
+         pred = bits_to_int(result_bits[0])
+         pred_a = bits_to_int(a_bits[0])
+         pred_b = bits_to_int(b_bits[0])
+         pred_op = OPERATIONS[op_logits[0].argmax().item()]
+
+         status = "OK" if pred == expected else "WRONG"
+         print(f" '{text}' = {expected} | pred={pred} (a={pred_a}, b={pred_b}, op={pred_op}) [{status}]")
+
+     save_path = "D:/8bit-threshold-computer/llm_integration/trained/llm.pt"
+     save_dict = {
+         'extractor_state_dict': model.extractor.state_dict(),
+         'final_accuracy': acc,
+         'best_accuracy': best_acc,
+         'unfreeze_layers': unfreeze_layers,
+     }
+     if unfreeze_layers > 0:
+         save_dict['llm_state_dict'] = {
+             k: v for k, v in model.llm.state_dict().items()
+             if any(f'layers.{i}.' in k for i in range(len(model.llm.model.layers) - unfreeze_layers, len(model.llm.model.layers)))
+         }
+     torch.save(save_dict, save_path)
+     print(f"\nSaved to: {save_path}")
+
+     return model, acc
+
+
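The `get_curriculum_max` call above is defined elsewhere in the file; a plausible pure-Python sketch matching the schedule stated in the commit message (0-9 for the first 20% of epochs, 0-99 to 50%, then the full 0-255 range) — the exact thresholds and open/closed boundaries are an assumption:

```python
def get_curriculum_max(epoch: int, epochs: int) -> int:
    """Hypothetical curriculum schedule: easy single-digit operands first,
    then two-digit, then the full 8-bit range (per the commit message)."""
    progress = epoch / epochs
    if progress < 0.2:
        return 9      # epochs 0-20%: operands in 0-9
    if progress < 0.5:
        return 99     # 20-50%: operands in 0-99
    return 255        # 50-100%: full 0-255 range

schedule = [get_curriculum_max(e, 100) for e in range(100)]
print(schedule[0], schedule[25], schedule[75])
```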
+ def main():
+     parser = argparse.ArgumentParser(
+         description='Unified training for threshold circuit LLM integration',
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ Modes:
+   router    - Train only OpRouter with ground truth bits (sanity check)
+   interface - Train BitEncoder + OpRouter with ground truth bits (sanity check)
+   llm       - Train extractor with LLM hidden states (the real training)
+
+ LLM options:
+   --unfreeze_layers N   Fine-tune top N transformer layers
+   --extract_layer N     Extract from layer N (-1 = last)
+   --position_extract    Use position-specific extraction
+   --digit_pred          Predict digits instead of bits
+
+ Baked-in: curriculum learning (0-9 -> 0-99 -> 0-255), 2x loss weight for a/b
+
+ Examples:
+   python train.py --mode llm --epochs 100
+   python train.py --mode llm --position_extract
+   python train.py --mode llm --digit_pred --extract_layer 12
+   python train.py --mode llm --unfreeze_layers 4 --batch_size 4096
+ """
+     )
+     parser.add_argument('--mode', type=str, required=True,
+                         choices=['router', 'interface', 'llm'],
+                         help='Training mode')
+     parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
+     parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
+     parser.add_argument('--lr', type=float, default=None,
+                         help='Learning rate (default: mode-specific)')
+     parser.add_argument('--unfreeze_layers', type=int, default=0,
+                         help='Unfreeze top N transformer layers (default 0 = frozen)')
+     parser.add_argument('--extract_layer', type=int, default=-1,
+                         help='Which layer to extract from (-1 = last)')
+     parser.add_argument('--position_extract', action='store_true',
+                         help='Use position-specific extraction')
+     parser.add_argument('--digit_pred', action='store_true',
+                         help='Predict digits instead of bits')
+     parser.add_argument('--device', type=str, default='cuda', help='Device')
+     args = parser.parse_args()
+
+     torch.manual_seed(42)
+     random.seed(42)
+
+     if args.mode == 'router':
+         lr = args.lr if args.lr is not None else 1e-2
+         train_router(epochs=args.epochs, batch_size=args.batch_size, lr=lr, device=args.device)
+
+     elif args.mode == 'interface':
+         lr = args.lr if args.lr is not None else 1e-3
+         model, fitness = train_interface(
+             epochs=args.epochs, batch_size=args.batch_size, lr=lr, device=args.device
+         )
+
+         print("\n" + "=" * 70)
+         print(" EXPERIMENT SUMMARY")
+         print("=" * 70)
+         print(f"\n Control (Vanilla SmolLM2-360M): 11.90%")
+         print(f" Experimental (Trained Interface): {100*fitness:.2f}%")
+         if fitness > 0:
+             print(f"\n Improvement: {100*(fitness - 0.119)/0.119:.1f}%")
+
+         if fitness >= 0.99:
+             print("\n CONCLUSION: Frozen threshold circuits + trained interface")
+             print(" achieves near-perfect arithmetic accuracy.")
+             print(" Core thesis VALIDATED.")
+         else:
+             print(f"\n CONCLUSION: Further training or architecture changes needed.")
+             print(f" Current gap: {100*(1.0 - fitness):.2f}%")
+
+     elif args.mode == 'llm':
+         lr = args.lr if args.lr is not None else 3e-4
+         train_llm(
+             epochs=args.epochs,
+             batch_size=args.batch_size,
+             lr=lr,
+             unfreeze_layers=args.unfreeze_layers,
+             extract_layer=args.extract_layer,
+             position_extract=args.position_extract,
+             digit_pred=args.digit_pred,
+             device=args.device
+         )
+
+
+ if __name__ == "__main__":
+     main()
llm_integration/train_llm.py DELETED
@@ -1,387 +0,0 @@
- """
- LLM Integration Training
-
- Train interface layers to extract operands from SmolLM2 hidden states.
- The hard part: learning to parse "47 + 86" into bits from embeddings.
- """
-
- import torch
- import torch.nn as nn
- import torch.optim as optim
- import random
- import time
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from circuits import FrozenThresholdCircuits, heaviside_ste
-
- DEVICE = 'cuda'
- MODEL_ID = 'HuggingFaceTB/SmolLM2-360M-Instruct'
-
- OPERATIONS = ['add', 'sub', 'mul', 'gt', 'lt', 'eq']
- OP_SYMBOLS = {'add': '+', 'sub': '-', 'mul': '*', 'gt': '>', 'lt': '<', 'eq': '=='}
-
-
- class HiddenStateExtractor(nn.Module):
-     """
-     Extracts operands and operation from LLM hidden states.
-     This is the hard part - must learn to parse numbers from embeddings.
-     """
-
-     def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256):
-         super().__init__()
-
-         self.a_extractor = nn.Sequential(
-             nn.Linear(hidden_dim, intermediate_dim),
-             nn.GELU(),
-             nn.Linear(intermediate_dim, 8),
-         )
-
-         self.b_extractor = nn.Sequential(
-             nn.Linear(hidden_dim, intermediate_dim),
-             nn.GELU(),
-             nn.Linear(intermediate_dim, 8),
-         )
-
-         self.op_router = nn.Sequential(
-             nn.Linear(hidden_dim, intermediate_dim),
-             nn.GELU(),
-             nn.Linear(intermediate_dim, len(OPERATIONS)),
-         )
-
-     def forward(self, hidden_states: torch.Tensor):
-         """
-         Args:
-             hidden_states: [batch, hidden_dim] from LLM
-
-         Returns:
-             a_bits: [batch, 8]
-             b_bits: [batch, 8]
-             op_logits: [batch, 6]
-         """
-         a_logits = self.a_extractor(hidden_states)
-         b_logits = self.b_extractor(hidden_states)
-         op_logits = self.op_router(hidden_states)
-
-         a_soft = torch.sigmoid(a_logits)
-         b_soft = torch.sigmoid(b_logits)
-
-         a_hard = heaviside_ste(a_logits)
-         b_hard = heaviside_ste(b_logits)
-
-         a_bits = a_hard - a_soft.detach() + a_soft
-         b_bits = b_hard - b_soft.detach() + b_soft
-
-         return a_bits, b_bits, op_logits
-
-
- class AugmentedArithmeticModel(nn.Module):
-     """
-     SmolLM2 + frozen threshold circuits.
-     Trains only the extraction interface.
-     """
-
-     def __init__(self, device: str = 'cuda'):
-         super().__init__()
-         self.device = device
-
-         print("[1/4] Loading tokenizer...", flush=True)
-         self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-         self.tokenizer.padding_side = 'left'
-         if self.tokenizer.pad_token is None:
-             self.tokenizer.pad_token = self.tokenizer.eos_token
-         print(" Tokenizer loaded.", flush=True)
-
-         print("[2/4] Loading SmolLM2-360M...", flush=True)
-         self.llm = AutoModelForCausalLM.from_pretrained(
-             MODEL_ID,
-             torch_dtype=torch.float16,
-             device_map=device,
-             output_hidden_states=True
-         )
-         self.llm.eval()
-         for param in self.llm.parameters():
-             param.requires_grad = False
-
-         hidden_dim = self.llm.config.hidden_size
-         llm_params = sum(p.numel() for p in self.llm.parameters())
-         print(f" LLM loaded. Hidden dim: {hidden_dim}, Params: {llm_params:,}", flush=True)
-
-         print("[3/4] Loading threshold circuits...", flush=True)
-         self.circuits = FrozenThresholdCircuits(device=device)
-         print(f" Circuits loaded. {len(self.circuits.weights)} tensors", flush=True)
-
-         print("[4/4] Initializing extractor...", flush=True)
-         self.extractor = HiddenStateExtractor(
-             hidden_dim=hidden_dim,
-             intermediate_dim=256
-         ).to(device)
-
-         trainable = sum(p.numel() for p in self.extractor.parameters())
-         print(f" Extractor ready. Trainable params: {trainable:,}", flush=True)
-
-     def get_hidden_states(self, texts: list[str]) -> torch.Tensor:
-         """Get hidden states from last layer for each input."""
-         inputs = self.tokenizer(
-             texts,
-             return_tensors='pt',
-             padding=True,
-             truncation=True,
-             max_length=64
-         ).to(self.device)
-
-         with torch.no_grad():
-             outputs = self.llm(**inputs, output_hidden_states=True)
-
-         last_hidden = outputs.hidden_states[-1]
-         mask = inputs.attention_mask
-         seq_lens = mask.sum(dim=1) - 1
-         batch_size = last_hidden.shape[0]
-
-         final_hidden = torch.stack([
-             last_hidden[i, seq_lens[i], :]
-             for i in range(batch_size)
-         ])
-
-         return final_hidden.float()
-
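One subtlety in the deleted `get_hidden_states`: `seq_lens = mask.sum(dim=1) - 1` is the right-padding formula, but the tokenizer above is configured with `padding_side = 'left'`, where the last real token always sits at the final position. A minimal list-based sketch of the two cases (hypothetical masks, plain Python instead of tensors):

```python
def last_token_index(mask, padding_side):
    # Index of the last real (non-pad) token in a padded sequence.
    n_real = sum(mask)
    if padding_side == 'right':
        return n_real - 1       # real tokens occupy the prefix
    return len(mask) - 1        # left padding: real tokens end the sequence

right_mask = [1, 1, 1, 0, 0]    # e.g. "47 + 86" padded on the right
left_mask = [0, 0, 1, 1, 1]     # same tokens padded on the left

print(last_token_index(right_mask, 'right'))
print(last_token_index(left_mask, 'left'))
```

Applying `sum(mask) - 1` to the left-padded mask would point at index 2 (a real token, but not the last one), which may be one reason this extraction path was replaced.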
-     def forward(self, texts: list[str]):
-         """
-         Full forward pass: text → hidden states → extractor → circuits → result
-         """
-         hidden = self.get_hidden_states(texts)
-
-         a_bits, b_bits, op_logits = self.extractor(hidden)
-
-         op_probs = torch.softmax(op_logits, dim=-1)
-
-         result_bits = self.circuits(a_bits, b_bits, op_probs)
-
-         return result_bits, a_bits, b_bits, op_logits
-
-
- def generate_problem():
-     """Generate a random arithmetic problem."""
-     a = random.randint(0, 255)
-     b = random.randint(0, 255)
-     op = random.choice(OPERATIONS)
-
-     sym = OP_SYMBOLS[op]
-     text = f"{a} {sym} {b}"
-
-     if op == 'add':
-         result = (a + b) & 0xFF
-     elif op == 'sub':
-         result = (a - b) & 0xFF
-     elif op == 'mul':
-         result = (a * b) & 0xFF
-     elif op == 'gt':
-         result = 1 if a > b else 0
-     elif op == 'lt':
-         result = 1 if a < b else 0
-     elif op == 'eq':
-         result = 1 if a == b else 0
-
-     return text, a, b, op, result
-
-
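The training targets use modulo-256 (8-bit wraparound) semantics via the `& 0xFF` mask, so subtraction and multiplication stay in 0-255 exactly as they would in an 8-bit ALU. A few quick checks:

```python
# 8-bit wraparound arithmetic, mirroring generate_problem's & 0xFF masking.
assert (200 + 100) & 0xFF == 44     # 300 mod 256
assert (5 - 9) & 0xFF == 252        # negative results wrap around
assert (16 * 17) & 0xFF == 16       # 272 mod 256
print("8-bit wraparound checks passed")
```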
- def int_to_bits(val: int, device: str = 'cuda') -> torch.Tensor:
-     bits = torch.zeros(8, device=device)
-     for i in range(8):
-         bits[7-i] = (val >> i) & 1
-     return bits
-
-
- def bits_to_int(bits: torch.Tensor) -> int:
-     val = 0
-     for i in range(8):
-         if bits[i].item() > 0.5:
-             val += 1 << (7-i)
-     return val
-
-
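The two helpers above use an MSB-first layout: index 0 of the bit vector is the most significant bit. A pure-Python mirror (hypothetical list-based versions, for illustration only) makes the encoding and the round-trip property easy to verify:

```python
def int_to_bits_list(val: int) -> list[int]:
    # MSB-first: bit vector index 0 holds the most significant bit.
    return [(val >> (7 - i)) & 1 for i in range(8)]

def bits_to_int_list(bits: list[int]) -> int:
    return sum(b << (7 - i) for i, b in enumerate(bits))

assert int_to_bits_list(47) == [0, 0, 1, 0, 1, 1, 1, 1]   # 47 = 0b00101111
assert all(bits_to_int_list(int_to_bits_list(v)) == v for v in range(256))
print("round-trip OK")
```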
- def compute_loss(pred_bits, a_bits, b_bits, op_logits,
-                  target_result, target_a, target_b, target_op_idx):
-     """
-     Multi-component loss:
-     1. Result bits match expected
-     2. Extracted A bits match actual A
-     3. Extracted B bits match actual B
-     4. Operation classification correct
-     """
-     result_loss = nn.functional.binary_cross_entropy_with_logits(
-         pred_bits, target_result
-     )
-
-     a_loss = nn.functional.binary_cross_entropy(
-         torch.clamp(a_bits, 1e-7, 1-1e-7), target_a
-     )
-     b_loss = nn.functional.binary_cross_entropy(
-         torch.clamp(b_bits, 1e-7, 1-1e-7), target_b
-     )
-
-     op_loss = nn.functional.cross_entropy(op_logits, target_op_idx)
-
-     total = result_loss + a_loss + b_loss + op_loss
-
-     return total, {
-         'result': result_loss.item(),
-         'a': a_loss.item(),
-         'b': b_loss.item(),
-         'op': op_loss.item()
-     }
-
-
- def evaluate(model, n_samples: int = 500):
-     """Evaluate on random problems."""
-     model.extractor.eval()
-     correct = 0
-     op_correct = 0
-
-     for _ in range(n_samples):
-         text, a, b, op, expected = generate_problem()
-
-         with torch.no_grad():
-             result_bits, a_bits, b_bits, op_logits = model([text])
-
-         pred_result = bits_to_int(result_bits[0])
-         pred_op = OPERATIONS[op_logits[0].argmax().item()]
-
-         if pred_result == expected:
-             correct += 1
-         if pred_op == op:
-             op_correct += 1
-
-     model.extractor.train()
-     return correct / n_samples, op_correct / n_samples
-
-
- def train(epochs: int = 100, batch_size: int = 256, lr: float = 3e-4):
-     print("=" * 70, flush=True)
-     print(" LLM INTEGRATION TRAINING", flush=True)
-     print(" Learning to extract operands from hidden states", flush=True)
-     print("=" * 70, flush=True)
-
-     print("\nInitializing model...", flush=True)
-     model = AugmentedArithmeticModel(device=DEVICE)
-
-     optimizer = optim.AdamW(model.extractor.parameters(), lr=lr)
-     scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
-
-     print(f"\nTraining config:", flush=True)
-     print(f" Epochs: {epochs}", flush=True)
-     print(f" Batch size: {batch_size}", flush=True)
-     print(f" Learning rate: {lr}", flush=True)
-     print(f" Samples/epoch: {batch_size * 20}", flush=True)
-
-     print(f"\nInitial evaluation (200 samples)...", flush=True)
-     acc, op_acc = evaluate(model, n_samples=200)
-     print(f" Accuracy: {acc:.4f}, Op accuracy: {op_acc:.4f}", flush=True)
-
-     print(f"\nStarting training...", flush=True)
-     print("-" * 70, flush=True)
-
-     best_acc = acc
-     start_time = time.perf_counter()
-
-     for epoch in range(epochs):
-         model.extractor.train()
-         epoch_loss = 0
-         epoch_losses = {'result': 0, 'a': 0, 'b': 0, 'op': 0}
-         n_batches = 20  # 20 batches * 128 = 2,560 samples/epoch
-
-         for batch_idx in range(n_batches):
-             batch_texts = []
-             batch_a = []
-             batch_b = []
-             batch_op = []
-             batch_result = []
-
-             for _ in range(batch_size):
-                 text, a, b, op, result = generate_problem()
-                 batch_texts.append(text)
-                 batch_a.append(int_to_bits(a, DEVICE))
-                 batch_b.append(int_to_bits(b, DEVICE))
-                 batch_op.append(OPERATIONS.index(op))
-                 batch_result.append(int_to_bits(result, DEVICE))
-
-             target_a = torch.stack(batch_a)
-             target_b = torch.stack(batch_b)
-             target_op = torch.tensor(batch_op, device=DEVICE)
-             target_result = torch.stack(batch_result)
-
-             optimizer.zero_grad()
-
-             pred_bits, a_bits, b_bits, op_logits = model(batch_texts)
-
-             loss, losses = compute_loss(
-                 pred_bits, a_bits, b_bits, op_logits,
-                 target_result, target_a, target_b, target_op
-             )
-
-             loss.backward()
-             torch.nn.utils.clip_grad_norm_(model.extractor.parameters(), 1.0)
-             optimizer.step()
-
-             epoch_loss += loss.item()
-             for k in epoch_losses:
-                 epoch_losses[k] += losses[k]
-
-         scheduler.step()
-
-         avg_loss = epoch_loss / n_batches
-         for k in epoch_losses:
-             epoch_losses[k] /= n_batches
-
-         if (epoch + 1) % 5 == 0 or epoch == 0:
-             acc, op_acc = evaluate(model, n_samples=300)
-             elapsed = time.perf_counter() - start_time
-
-             marker = " *" if acc > best_acc else ""
-             if acc > best_acc:
-                 best_acc = acc
-
-             print(f"Epoch {epoch+1:3d} | Loss: {avg_loss:.4f} | "
-                   f"Acc: {acc:.4f}{marker} | OpAcc: {op_acc:.4f} | "
-                   f"Time: {elapsed:.0f}s")
-             print(f" Losses - result:{epoch_losses['result']:.4f} "
-                   f"a:{epoch_losses['a']:.4f} b:{epoch_losses['b']:.4f} "
-                   f"op:{epoch_losses['op']:.4f}")
-
-     print("\n" + "=" * 70)
-     print(" FINAL EVALUATION")
-     print("=" * 70)
-
-     acc, op_acc = evaluate(model, n_samples=1000)
-     print(f"Final accuracy: {acc:.4f}")
-     print(f"Final op accuracy: {op_acc:.4f}")
-     print(f"Best accuracy: {best_acc:.4f}")
-
-     print("\nSample predictions:")
-     for _ in range(10):
-         text, a, b, op, expected = generate_problem()
-         with torch.no_grad():
-             result_bits, a_bits, b_bits, op_logits = model([text])
-         pred = bits_to_int(result_bits[0])
-         pred_a = bits_to_int(a_bits[0])
-         pred_b = bits_to_int(b_bits[0])
-         pred_op = OPERATIONS[op_logits[0].argmax().item()]
-
-         status = "OK" if pred == expected else "WRONG"
-         print(f" '{text}' = {expected} | pred={pred} (a={pred_a}, b={pred_b}, op={pred_op}) [{status}]")
-
-     save_path = "D:/8bit-threshold-computer/llm_integration/trained_extractor.pt"
-     torch.save({
-         'extractor_state_dict': model.extractor.state_dict(),
-         'final_accuracy': acc,
-         'best_accuracy': best_acc,
-     }, save_path)
-     print(f"\nSaved to: {save_path}")
-
-     return model, acc
-
-
- if __name__ == "__main__":
-     random.seed(42)
-     torch.manual_seed(42)
-
-     train(epochs=100, batch_size=384, lr=3e-4)
llm_integration/train_passthrough.py DELETED
@@ -1,182 +0,0 @@
- """
- Training script for ThresholdALU interface layers.
- Trains encoder/router to correctly use frozen threshold circuits.
- """
-
- import torch
- import torch.nn as nn
- import torch.optim as optim
- import time
- import argparse
- from model import ThresholdALU, DirectCircuitModel
- from fitness import generate_batch, compute_fitness, compute_loss, OPERATIONS
-
-
- def train(
-     epochs: int = 100,
-     batch_size: int = 512,
-     lr: float = 1e-3,
-     eval_interval: int = 10,
-     eval_samples: int = 2000,
-     device: str = 'cuda'
- ):
-     print("=" * 70)
-     print(" THRESHOLD ALU INTERFACE TRAINING")
-     print("=" * 70)
-
-     print("\n[1/4] Verifying frozen circuits...")
-     direct_model = DirectCircuitModel(device=device)
-
-     def direct_fn(a, b, op):
-         return direct_model(a, b, op)
-
-     circuit_fitness = compute_fitness(direct_fn, n_samples=1000, device=device)
-     print(f" Frozen circuit fitness: {circuit_fitness:.4f}")
-     if circuit_fitness < 0.999:
-         print(" ERROR: Circuits not achieving 100%. Aborting.")
-         return
-     print(" STATUS: PASS")
-
-     print("\n[2/4] Initializing model...")
-     model = ThresholdALU(device=device)
-
-     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-     print(f" Trainable parameters: {trainable_params:,}")
-
-     def model_fn(a, b, op):
-         return model(a, b, op)
-
-     initial_fitness = compute_fitness(model_fn, n_samples=1000, device=device)
-     print(f" Initial fitness: {initial_fitness:.4f}")
-
-     print("\n[3/4] Setting up training...")
-     optimizer = optim.AdamW(model.parameters(), lr=lr)
-     scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
-
-     print(f" Optimizer: AdamW")
-     print(f" Learning rate: {lr}")
-     print(f" Batch size: {batch_size}")
-     print(f" Epochs: {epochs}")
-
-     print("\n[4/4] Training...")
-     print("-" * 70)
-
-     best_fitness = initial_fitness
-     start_time = time.perf_counter()
-
-     for epoch in range(epochs):
-         model.train()
-         epoch_loss = 0.0
-         n_batches = 100
-
-         for _ in range(n_batches):
-             batch = generate_batch(batch_size, device)
-
-             optimizer.zero_grad()
-
-             pred_bits = model(batch['a_bits'], batch['b_bits'], batch['op_onehot'])
-
-             loss = compute_loss(pred_bits, batch['expected_bits'])
-
-             loss.backward()
-             optimizer.step()
-
-             epoch_loss += loss.item()
-
-         scheduler.step()
-
-         avg_loss = epoch_loss / n_batches
-
-         if (epoch + 1) % eval_interval == 0 or epoch == 0:
-             model.eval()
-             fitness, details = compute_fitness(
-                 model_fn, n_samples=eval_samples, device=device, return_details=True
-             )
-
-             elapsed = time.perf_counter() - start_time
-
-             if fitness > best_fitness:
-                 best_fitness = fitness
-                 marker = " *"
-             else:
-                 marker = ""
-
-             print(f"Epoch {epoch+1:4d} | Loss: {avg_loss:.4f} | "
-                   f"Fitness: {fitness:.4f}{marker} | "
-                   f"LR: {scheduler.get_last_lr()[0]:.2e} | "
-                   f"Time: {elapsed:.1f}s")
-
-             if fitness >= 0.9999:
-                 print("\n" + "=" * 70)
-                 print(" TARGET ACHIEVED: 100% FITNESS")
-                 print("=" * 70)
-                 break
-
-     print("\n" + "=" * 70)
-     print(" TRAINING COMPLETE")
-     print("=" * 70)
-
-     model.eval()
-     final_fitness, details = compute_fitness(
-         model_fn, n_samples=5000, device=device, return_details=True
-     )
-
-     print(f"\nFinal fitness: {final_fitness:.4f}")
-     print(f"Best fitness: {best_fitness:.4f}")
-     print(f"\nPer-operation breakdown:")
-     for op in OPERATIONS:
-         acc = details['by_op'][op]['accuracy']
-         print(f" {op:6}: {acc:.4f}")
-
-     print(f"\nTotal time: {time.perf_counter() - start_time:.1f}s")
-
-     # Save trained model
-     save_path = "D:/8bit-threshold-computer/llm_integration/trained_model.pt"
-     torch.save({
-         'encoder_state_dict': model.encoder.state_dict(),
-         'router_state_dict': model.router.state_dict(),
-         'final_fitness': final_fitness,
-         'best_fitness': best_fitness,
-     }, save_path)
-     print(f"\nSaved trained model to: {save_path}")
-
-     return model, final_fitness
-
-
- def main():
-     parser = argparse.ArgumentParser(description='Train ThresholdALU interface')
-     parser.add_argument('--epochs', type=int, default=200, help='Number of epochs')
-     parser.add_argument('--batch_size', type=int, default=512, help='Batch size')
-     parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
-     parser.add_argument('--eval_interval', type=int, default=10, help='Eval every N epochs')
-     parser.add_argument('--device', type=str, default='cuda', help='Device')
-     args = parser.parse_args()
-
-     torch.manual_seed(42)
-
-     model, fitness = train(
158
- epochs=args.epochs,
159
- batch_size=args.batch_size,
160
- lr=args.lr,
161
- eval_interval=args.eval_interval,
162
- device=args.device
163
- )
164
-
165
- print("\n" + "=" * 70)
166
- print(" EXPERIMENT SUMMARY")
167
- print("=" * 70)
168
- print(f"\n Control (Vanilla SmolLM2-360M): 11.90%")
169
- print(f" Experimental (Trained Interface): {100*fitness:.2f}%")
170
- print(f"\n Improvement: {100*(fitness - 0.119)/0.119:.1f}%")
171
-
172
- if fitness >= 0.99:
173
- print("\n CONCLUSION: Frozen threshold circuits + trained interface")
174
- print(" achieves near-perfect arithmetic accuracy.")
175
- print(" Core thesis VALIDATED.")
176
- else:
177
- print(f"\n CONCLUSION: Further training or architecture changes needed.")
178
- print(f" Current gap: {100*(1.0 - fitness):.2f}%")
179
-
180
-
181
- if __name__ == "__main__":
182
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
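Both the deleted script above and the consolidated `train.py` exchange operands with the frozen circuits as 8-bit vectors (`a_bits`, `b_bits`). As a minimal standalone sketch of that encoding — assuming LSB-first bit order, which the diff itself does not show — the conversion helpers could look like:

```python
def to_bits(n, width=8):
    """Unsigned integer -> bit list (LSB first; bit order is an assumption)."""
    return [(n >> i) & 1 for i in range(width)]

def from_bits(bits):
    """Bit list (LSB first) -> unsigned integer."""
    return sum(b << i for i, b in enumerate(bits))

# Round-trip over the full 8-bit operand range used by the circuits (0-255).
assert all(from_bits(to_bits(n)) == n for n in range(256))
```

This also illustrates why pooling that discards position fails at extraction: "47" and "74" share digits but map to different bit vectors.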
llm_integration/train_passthrough_router.py DELETED
@@ -1,106 +0,0 @@
- """
- Train only the router with ground truth bits.
- Proves that operation routing can be learned perfectly.
- """
-
- import torch
- import torch.optim as optim
- import time
- from model import OpRouter
- from circuits import FrozenThresholdCircuits
- from fitness import generate_batch, compute_fitness, compute_loss, OPERATIONS
-
- device = 'cuda'
-
- print("=" * 70)
- print(" ROUTER-ONLY TRAINING (Ground Truth Bits)")
- print("=" * 70)
-
- circuits = FrozenThresholdCircuits(device=device)
- router = OpRouter(input_dim=16 + 6, hidden_dim=64, n_ops=6).to(device)
-
- print(f"\nRouter parameters: {sum(p.numel() for p in router.parameters()):,}")
-
- def model_fn(a_bits, b_bits, op_onehot):
-     x = torch.cat([a_bits, b_bits, op_onehot], dim=-1)
-     op_weights = router(x)
-     return circuits(a_bits, b_bits, op_weights)
-
- initial_fitness = compute_fitness(model_fn, n_samples=1000, device=device)
- print(f"Initial fitness: {initial_fitness:.4f}")
-
- optimizer = optim.AdamW(router.parameters(), lr=1e-2)
- scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
-
- print("\nTraining...")
- print("-" * 70)
-
- best_fitness = initial_fitness
- start_time = time.perf_counter()
-
- for epoch in range(100):
-     router.train()
-     epoch_loss = 0.0
-
-     for _ in range(100):
-         batch = generate_batch(256, device)
-
-         optimizer.zero_grad()
-
-         x = torch.cat([batch['a_bits'], batch['b_bits'], batch['op_onehot']], dim=-1)
-         op_weights = router(x)
-         pred_bits = circuits(batch['a_bits'], batch['b_bits'], op_weights)
-
-         loss = compute_loss(pred_bits, batch['expected_bits'])
-         loss.backward()
-         optimizer.step()
-
-         epoch_loss += loss.item()
-
-     scheduler.step()
-
-     if (epoch + 1) % 10 == 0 or epoch == 0:
-         router.eval()
-         fitness, details = compute_fitness(model_fn, n_samples=2000, device=device, return_details=True)
-         elapsed = time.perf_counter() - start_time
-
-         if fitness > best_fitness:
-             best_fitness = fitness
-             marker = " *"
-         else:
-             marker = ""
-
-         print(f"Epoch {epoch+1:3d} | Loss: {epoch_loss/100:.4f} | "
-               f"Fitness: {fitness:.4f}{marker} | Time: {elapsed:.1f}s")
-
-         if fitness >= 0.9999:
-             print("\n TARGET: 100% FITNESS ACHIEVED")
-             break
-
- print("\n" + "=" * 70)
- print(" RESULTS")
- print("=" * 70)
-
- router.eval()
- final_fitness, details = compute_fitness(model_fn, n_samples=5000, device=device, return_details=True)
-
- print(f"\nFinal fitness: {final_fitness:.4f}")
- print(f"\nPer-operation:")
- for op in OPERATIONS:
-     acc = details['by_op'][op]['accuracy']
-     print(f" {op}: {acc:.4f}")
-
- print(f"\nTotal time: {time.perf_counter() - start_time:.1f}s")
-
- if final_fitness >= 0.99:
-     print("\nCONCLUSION: Router successfully learned operation dispatch.")
-     print(" With correct bit encoding, 100% is achievable.")
-
- # Save trained router weights
- save_path = "D:/8bit-threshold-computer/llm_integration/trained_router.pt"
- torch.save({
-     'router_state_dict': router.state_dict(),
-     'final_fitness': final_fitness,
-     'params': sum(p.numel() for p in router.parameters()),
- }, save_path)
- print(f"\nSaved trained router to: {save_path}")

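The deleted router script dispatches through `circuits(a_bits, b_bits, op_weights)`, blending per-operation circuit outputs by the router's weights. A framework-free sketch of that soft dispatch follows; the softmax-mixture form is an assumption for illustration, since the actual combination lives inside `FrozenThresholdCircuits` in `circuits.py`:

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of logits."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def route(op_logits, circuit_outputs):
    """Soft dispatch (assumed form): weight each frozen circuit's output
    bits by the router's softmax probability for that operation."""
    weights = softmax(op_logits)
    n_bits = len(circuit_outputs[0])
    return [sum(weights[k] * circuit_outputs[k][i] for k in range(len(weights)))
            for i in range(n_bits)]
```

With a confident router (a large logit gap), the mixture collapses to the selected circuit's exact bits, which is why router-only training can reach 100% fitness given ground-truth bit encodings.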
llm_integration/{trained_passthrough_router.pt → trained/router.pt} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:2b33772a74d3891031225298d33d57663c36719e438b5bc9f9039f9e57d636df
- size 10147
+ oid sha256:0ddfc24cd4a98b65de8d434bb843ebd24f8c902d067201fd7954e7b623a8ebcd
+ size 9811