""" Unified training script for threshold circuit LLM integration. Modes: --mode router : Train only OpRouter with ground truth bits (sanity check) --mode interface : Train BitEncoder + OpRouter with ground truth bits (sanity check) --mode llm : Train extractor with LLM hidden states (the real training) LLM mode options: --unfreeze_layers N : Unfreeze top N transformer layers (default 0 = fully frozen) Hardware Profile (NVIDIA RTX 6000 Ada 48GB): VRAM Scaling (unfreeze_layers=4): batch_size | VRAM | % -----------+---------+------ 512 | 5,784 | 11.8% 1,024 | 7,384 | 15.0% 4,096 | 16,534 | 33.6% 13,000 | 39,000 | 79.4% <-- recommended for 80% target Examples: python train.py --mode llm --epochs 100 --batch_size 256 python train.py --mode llm --epochs 100 --batch_size 4096 --unfreeze_layers 4 """ import torch import torch.nn as nn import torch.optim as optim import time import argparse import random from model import ( ThresholdALU, DirectCircuitModel, OpRouter, ArithmeticModel, OPERATIONS, OP_SYMBOLS ) from circuits import FrozenThresholdCircuits from fitness import generate_batch, compute_fitness, compute_loss DEVICE = 'cuda' ONES = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen'] TENS = ['', '', 'twenty', 'thirty', 'forty', 'fifty', 'sixty', 'seventy', 'eighty', 'ninety'] def int_to_words(n: int) -> str: """Convert integer 0-255 to English words.""" if n < 0 or n > 255: return str(n) if n < 20: return ONES[n] if n < 100: if n % 10 == 0: return TENS[n // 10] return f"{TENS[n // 10]} {ONES[n % 10]}" if n % 100 == 0: return f"{ONES[n // 100]} hundred" if n % 100 < 20: return f"{ONES[n // 100]} hundred {ONES[n % 100]}" if n % 10 == 0: return f"{ONES[n // 100]} hundred {TENS[(n % 100) // 10]}" return f"{ONES[n // 100]} hundred {TENS[(n % 100) // 10]} {ONES[n % 10]}" def int_to_bits(val: int, device: str = 'cuda') -> torch.Tensor: bits = torch.zeros(8, device=device) for i in range(8): bits[7-i] = (val >> i) & 1 return bits def bits_to_int(bits: torch.Tensor) -> int: val = 0 for i in range(8): if bits[i].item() > 0.5: val += 1 << (7-i) return val NL_TEMPLATES = { 'add': [ "What is {a} plus {b}?", "Calculate {a} + {b}", "Add {a} and {b}", "What's the sum of {a} and {b}?", "If I have {a} and get {b} more, how many total?", "{a} + {b} = ?", "Compute {a} plus {b}", ], 'sub': [ "What is {a} minus {b}?", "Calculate {a} - {b}", "Subtract {b} from {a}", "What's {a} take away {b}?", "If I have {a} and lose {b}, how many left?", "{a} - {b} = ?", "Compute {a} minus {b}", ], 'mul': [ "What is {a} times {b}?", "Calculate {a} * {b}", "Multiply {a} by {b}", "What's {a} multiplied by {b}?", "{a} * {b} = ?", "Compute {a} times {b}", "What is the product of {a} and {b}?", ], 'gt': [ "Is {a} greater than {b}?", "Is {a} > {b}?", "Check if {a} is larger than {b}", "Compare: is {a} more than {b}?", "{a} > {b}?", ], 'lt': [ "Is {a} less than {b}?", "Is {a} < {b}?", "Check if {a} is smaller than {b}", "Compare: is {a} fewer than {b}?", "{a} < {b}?", ], 'eq': [ "Is {a} equal to {b}?", "Is {a} == {b}?", "Does {a} equal {b}?", "Check if {a} equals {b}", "Are {a} and {b} the same?", ], } def generate_problem(max_val: int = 255): """ Generate a random arithmetic problem for LLM training. Randomly mixes digit and word formats. """ a = random.randint(0, max_val) b = random.randint(0, max_val) op = random.choice(OPERATIONS) fmt = random.choice(['digits', 'words', 'nl_digits', 'nl_words']) if fmt == 'digits': sym = OP_SYMBOLS[op] text = f"{a} {sym} {b}" elif fmt == 'words': a_word = int_to_words(a) b_word = int_to_words(b) op_word = {'add': 'plus', 'sub': 'minus', 'mul': 'times', 'gt': 'greater than', 'lt': 'less than', 'eq': 'equals'}[op] text = f"{a_word} {op_word} {b_word}" elif fmt == 'nl_digits': template = random.choice(NL_TEMPLATES[op]) text = template.format(a=a, b=b) else: template = random.choice(NL_TEMPLATES[op]) text = template.format(a=int_to_words(a), b=int_to_words(b)) if op == 'add': result = (a + b) & 0xFF elif op == 'sub': result = (a - b) & 0xFF elif op == 'mul': result = (a * b) & 0xFF elif op == 'gt': result = 1 if a > b else 0 elif op == 'lt': result = 1 if a < b else 0 elif op == 'eq': result = 1 if a == b else 0 return text, a, b, op, result def get_curriculum_max(epoch: int, total_epochs: int) -> int: """ Curriculum learning: start with small numbers, gradually increase. Epoch 0-20%: 0-9 (single digit) Epoch 20-50%: 0-99 (two digit) Epoch 50-100%: 0-255 (full range) """ progress = epoch / total_epochs if progress < 0.2: return 9 elif progress < 0.5: return 99 else: return 255 def train_router(epochs: int = 100, batch_size: int = 256, lr: float = 1e-2, device: str = 'cuda'): """Train only the router with ground truth bits.""" print("=" * 70) print(" ROUTER-ONLY TRAINING (Ground Truth Bits)") print("=" * 70) circuits = FrozenThresholdCircuits(device=device) router = OpRouter(input_dim=16 + 6, hidden_dim=64, n_ops=6).to(device) print(f"\nRouter parameters: {sum(p.numel() for p in router.parameters()):,}") def model_fn(a_bits, b_bits, op_onehot): x = torch.cat([a_bits, b_bits, op_onehot], dim=-1) op_weights = router(x) return circuits(a_bits, b_bits, op_weights) initial_fitness = compute_fitness(model_fn, n_samples=1000, device=device) print(f"Initial fitness: {initial_fitness:.4f}") optimizer = optim.AdamW(router.parameters(), lr=lr) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) print("\nTraining...") print("-" * 70) best_fitness = initial_fitness start_time = time.perf_counter() for epoch in range(epochs): router.train() epoch_loss = 0.0 for _ in range(100): batch = generate_batch(batch_size, device) optimizer.zero_grad() x = torch.cat([batch['a_bits'], batch['b_bits'], batch['op_onehot']], dim=-1) op_weights = router(x) pred_bits = circuits(batch['a_bits'], batch['b_bits'], op_weights) loss = compute_loss(pred_bits, batch['expected_bits']) loss.backward() optimizer.step() epoch_loss += loss.item() scheduler.step() if (epoch + 1) % 10 == 0 or epoch == 0: router.eval() fitness, details = compute_fitness(model_fn, n_samples=2000, device=device, return_details=True) elapsed = time.perf_counter() - start_time if fitness > best_fitness: best_fitness = fitness marker = " *" else: marker = "" print(f"Epoch {epoch+1:3d} | Loss: {epoch_loss/100:.4f} | " f"Fitness: {fitness:.4f}{marker} | Time: {elapsed:.1f}s") if fitness >= 0.9999: print("\n TARGET: 100% FITNESS ACHIEVED") break print("\n" + "=" * 70) print(" RESULTS") print("=" * 70) router.eval() final_fitness, details = compute_fitness(model_fn, n_samples=5000, device=device, return_details=True) print(f"\nFinal fitness: {final_fitness:.4f}") print(f"\nPer-operation:") for op in OPERATIONS: acc = details['by_op'][op]['accuracy'] print(f" {op}: {acc:.4f}") print(f"\nTotal time: {time.perf_counter() - start_time:.1f}s") if final_fitness >= 0.99: print("\nCONCLUSION: Router successfully learned operation dispatch.") print(" With correct bit encoding, 100% is achievable.") save_path = "D:/8bit-threshold-computer/llm_integration/trained/router.pt" torch.save({ 'router_state_dict': router.state_dict(), 'final_fitness': final_fitness, 'params': sum(p.numel() for p in router.parameters()), }, save_path) print(f"\nSaved trained router to: {save_path}") return router, final_fitness def get_gpu_memory(): """Get GPU memory usage in MB.""" if torch.cuda.is_available(): return torch.cuda.memory_allocated() / 1024 / 1024, torch.cuda.max_memory_allocated() / 1024 / 1024 return 0, 0 def train_interface(epochs: int = 200, batch_size: int = 512, lr: float = 1e-3, eval_interval: int = 10, device: str = 'cuda'): """Train BitEncoder + OpRouter with ground truth bits.""" print("=" * 70) print(" INTERFACE TRAINING (Encoder + Router)") print("=" * 70) print(f" Started at: {time.strftime('%H:%M:%S')}") print("\n[1/4] Verifying frozen circuits...") print(" Creating DirectCircuitModel...", end=" ", flush=True) direct_model = DirectCircuitModel(device=device) mem, max_mem = get_gpu_memory() print(f"done. VRAM: {mem:.0f}MB") def direct_fn(a, b, op): return direct_model(a, b, op) print(" Running fitness check (1000 samples)...", end=" ", flush=True) circuit_fitness = compute_fitness(direct_fn, n_samples=1000, device=device) print(f"done. Fitness: {circuit_fitness:.4f}") if circuit_fitness < 0.999: print(" ERROR: Circuits not achieving 100%. Aborting.") return None, 0.0 print(" STATUS: PASS") print("\n[2/4] Initializing model...") print(" Creating ThresholdALU...", end=" ", flush=True) model = ThresholdALU(device=device) mem, max_mem = get_gpu_memory() print(f"done. VRAM: {mem:.0f}MB") trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f" Trainable parameters: {trainable_params:,}") def model_fn(a, b, op): return model(a, b, op) print(" Running initial fitness check...", end=" ", flush=True) initial_fitness = compute_fitness(model_fn, n_samples=1000, device=device) print(f"done. Fitness: {initial_fitness:.4f}") print("\n[3/4] Setting up training...") print(" Creating optimizer...", end=" ", flush=True) optimizer = optim.AdamW(model.parameters(), lr=lr) print("done.") print(" Creating scheduler...", end=" ", flush=True) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) print("done.") print(f" Config: lr={lr}, batch_size={batch_size}, epochs={epochs}") print("\n[4/4] Training...") print(" Generating first batch to warm up...", end=" ", flush=True) warmup_batch = generate_batch(batch_size, device) mem, max_mem = get_gpu_memory() print(f"done. VRAM: {mem:.0f}MB (max: {max_mem:.0f}MB)") print("-" * 70) best_fitness = initial_fitness start_time = time.perf_counter() n_batches = 100 for epoch in range(epochs): model.train() epoch_loss = 0.0 epoch_start = time.perf_counter() for batch_idx in range(n_batches): batch = generate_batch(batch_size, device) optimizer.zero_grad() pred_bits = model(batch['a_bits'], batch['b_bits'], batch['op_onehot']) loss = compute_loss(pred_bits, batch['expected_bits']) loss.backward() optimizer.step() epoch_loss += loss.item() if batch_idx == 0 and epoch == 0: mem, max_mem = get_gpu_memory() print(f" First forward/backward done. VRAM: {mem:.0f}MB (max: {max_mem:.0f}MB)") if (batch_idx + 1) % 25 == 0: avg_so_far = epoch_loss / (batch_idx + 1) print(f" Epoch {epoch+1} batch {batch_idx+1}/{n_batches} | loss: {avg_so_far:.4f}", flush=True) scheduler.step() avg_loss = epoch_loss / n_batches epoch_time = time.perf_counter() - epoch_start if (epoch + 1) % 5 == 0 or epoch == 0: # Eval every 5 epochs model.eval() fitness, details = compute_fitness( model_fn, n_samples=2000, device=device, return_details=True ) elapsed = time.perf_counter() - start_time if fitness > best_fitness: best_fitness = fitness marker = " *" else: marker = "" mem, _ = get_gpu_memory() print(f"Epoch {epoch+1:4d} | Loss: {avg_loss:.4f} | " f"Fitness: {fitness:.4f}{marker} | " f"LR: {scheduler.get_last_lr()[0]:.2e} | " f"VRAM: {mem:.0f}MB | " f"Time: {elapsed:.1f}s ({epoch_time:.1f}s/epoch)") if fitness >= 0.9999: print("\n" + "=" * 70) print(" TARGET ACHIEVED: 100% FITNESS") print("=" * 70) break print("\n" + "=" * 70) print(" TRAINING COMPLETE") print("=" * 70) model.eval() final_fitness, details = compute_fitness( model_fn, n_samples=5000, device=device, return_details=True ) print(f"\nFinal fitness: {final_fitness:.4f}") print(f"Best fitness: {best_fitness:.4f}") print(f"\nPer-operation breakdown:") for op in OPERATIONS: acc = details['by_op'][op]['accuracy'] print(f" {op:6}: {acc:.4f}") print(f"\nTotal time: {time.perf_counter() - start_time:.1f}s") save_path = "D:/8bit-threshold-computer/llm_integration/trained/interface.pt" torch.save({ 'encoder_state_dict': model.encoder.state_dict(), 'router_state_dict': model.router.state_dict(), 'final_fitness': final_fitness, 'best_fitness': best_fitness, }, save_path) print(f"\nSaved trained model to: {save_path}") return model, final_fitness def compute_llm_loss(pred_bits, a_bits, b_bits, op_logits, target_result, target_a, target_b, target_op_idx, bit_weight: float = 2.0): """ Multi-component loss for LLM training. bit_weight: multiplier for a/b bit losses (default 2x since extraction is the bottleneck) """ result_loss = nn.functional.binary_cross_entropy_with_logits( pred_bits, target_result ) a_bits_safe = torch.clamp(a_bits, 0.0, 1.0) b_bits_safe = torch.clamp(b_bits, 0.0, 1.0) a_bits_safe = torch.nan_to_num(a_bits_safe, nan=0.5, posinf=1.0, neginf=0.0) b_bits_safe = torch.nan_to_num(b_bits_safe, nan=0.5, posinf=1.0, neginf=0.0) a_loss = nn.functional.binary_cross_entropy( torch.clamp(a_bits_safe, 1e-6, 1-1e-6), target_a ) b_loss = nn.functional.binary_cross_entropy( torch.clamp(b_bits_safe, 1e-6, 1-1e-6), target_b ) op_loss = nn.functional.cross_entropy(op_logits, target_op_idx) total = result_loss + bit_weight * a_loss + bit_weight * b_loss + op_loss total = torch.nan_to_num(total, nan=10.0, posinf=10.0, neginf=0.0) return total, { 'result': result_loss.item() if not torch.isnan(result_loss) else 10.0, 'a': a_loss.item() if not torch.isnan(a_loss) else 10.0, 'b': b_loss.item() if not torch.isnan(b_loss) else 10.0, 'op': op_loss.item() if not torch.isnan(op_loss) else 10.0 } def value_to_digits(value: int) -> list: """Convert integer value to list of digits (hundreds, tens, ones).""" digits = [] for place in [100, 10, 1]: digit = (value // place) % 10 digits.append(digit) return digits def compute_positional_digit_loss(pred_bits, op_logits, a_digit_logits_list, b_digit_logits_list, target_result, target_op_idx, target_a_values, target_b_values, device, digit_weight: float = 5.0): """ Loss for positional digit extraction with DIRECT digit supervision. This provides strong gradients by directly supervising digit classification instead of going through the value -> bits conversion. """ result_loss = nn.functional.binary_cross_entropy_with_logits( pred_bits, target_result ) op_loss = nn.functional.cross_entropy(op_logits, target_op_idx) a_digit_loss = torch.tensor(0.0, device=device) b_digit_loss = torch.tensor(0.0, device=device) n_a_digits = 0 n_b_digits = 0 for i, (a_logits_list, b_logits_list) in enumerate(zip(a_digit_logits_list, b_digit_logits_list)): target_a = target_a_values[i].item() target_b = target_b_values[i].item() a_digits = value_to_digits(int(target_a)) b_digits = value_to_digits(int(target_b)) n_a = len(a_logits_list) n_b = len(b_logits_list) if n_a > 0: target_a_digits = a_digits[-n_a:] for j, logits in enumerate(a_logits_list): target_digit = torch.tensor([target_a_digits[j]], device=device, dtype=torch.long) a_digit_loss = a_digit_loss + nn.functional.cross_entropy(logits.unsqueeze(0), target_digit) n_a_digits += 1 if n_b > 0: target_b_digits = b_digits[-n_b:] for j, logits in enumerate(b_logits_list): target_digit = torch.tensor([target_b_digits[j]], device=device, dtype=torch.long) b_digit_loss = b_digit_loss + nn.functional.cross_entropy(logits.unsqueeze(0), target_digit) n_b_digits += 1 if n_a_digits > 0: a_digit_loss = a_digit_loss / n_a_digits if n_b_digits > 0: b_digit_loss = b_digit_loss / n_b_digits total = result_loss + digit_weight * a_digit_loss + digit_weight * b_digit_loss + op_loss total = torch.nan_to_num(total, nan=10.0, posinf=10.0, neginf=0.0) return total, { 'result': result_loss.item() if not torch.isnan(result_loss) else 10.0, 'a_digit': a_digit_loss.item() if not torch.isnan(a_digit_loss) else 10.0, 'b_digit': b_digit_loss.item() if not torch.isnan(b_digit_loss) else 10.0, 'op': op_loss.item() if not torch.isnan(op_loss) else 10.0 } def compute_hybrid_loss(pred_bits, op_logits, used_lookup, a_digit_logits, b_digit_logits, target_result, target_a_values, target_b_values, target_op_idx, device, digit_weight: float = 2.0): """ Loss for hybrid extraction with digit-level prediction. Uses cross-entropy on each digit (hundreds, tens, ones) for word samples. Samples using digit lookup are already 100% accurate - no loss computed. """ result_loss = nn.functional.binary_cross_entropy_with_logits( pred_bits, target_result ) op_loss = nn.functional.cross_entropy(op_logits, target_op_idx) word_mask = ~used_lookup n_words = word_mask.sum().item() if n_words > 0 and a_digit_logits is not None and b_digit_logits is not None: target_a_word = target_a_values[word_mask].long() target_b_word = target_b_values[word_mask].long() a_hundreds = target_a_word // 100 a_tens = (target_a_word % 100) // 10 a_ones = target_a_word % 10 b_hundreds = target_b_word // 100 b_tens = (target_b_word % 100) // 10 b_ones = target_b_word % 10 a_logits = a_digit_logits.view(-1, 3, 10) b_logits = b_digit_logits.view(-1, 3, 10) a_digit_loss = ( nn.functional.cross_entropy(a_logits[:, 0], a_hundreds) + nn.functional.cross_entropy(a_logits[:, 1], a_tens) + nn.functional.cross_entropy(a_logits[:, 2], a_ones) ) / 3.0 b_digit_loss = ( nn.functional.cross_entropy(b_logits[:, 0], b_hundreds) + nn.functional.cross_entropy(b_logits[:, 1], b_tens) + nn.functional.cross_entropy(b_logits[:, 2], b_ones) ) / 3.0 else: a_digit_loss = torch.tensor(0.0, device=device) b_digit_loss = torch.tensor(0.0, device=device) total = result_loss + op_loss + digit_weight * (a_digit_loss + b_digit_loss) total = torch.nan_to_num(total, nan=10.0, posinf=10.0, neginf=0.0) return total, { 'result': result_loss.item() if not torch.isnan(result_loss) else 10.0, 'a_digit': a_digit_loss.item() if not torch.isnan(a_digit_loss) else 10.0, 'b_digit': b_digit_loss.item() if not torch.isnan(b_digit_loss) else 10.0, 'op': op_loss.item() if not torch.isnan(op_loss) else 10.0, 'n_words': n_words, 'n_lookup': used_lookup.sum().item() } def evaluate_llm(model, n_samples: int = 500): """Evaluate LLM model on random problems (mixed digit/word format).""" model.extractor.eval() correct = 0 op_correct = 0 for _ in range(n_samples): text, a, b, op, expected = generate_problem() with torch.no_grad(): outputs = model([text]) result_bits = outputs[0] op_logits = outputs[3] pred_result = bits_to_int(result_bits[0]) pred_op = OPERATIONS[op_logits[0].argmax().item()] if pred_result == expected: correct += 1 if pred_op == op: op_correct += 1 model.extractor.train() return correct / n_samples, op_correct / n_samples def train_llm(epochs: int = 100, batch_size: int = 256, lr: float = 3e-4, unfreeze_layers: int = 0, extract_layer: int = -1, position_extract: bool = False, digit_pred: bool = False, positional_digit: bool = False, device: str = 'cuda'): """ Train extractor with LLM hidden states. Args: unfreeze_layers: Number of top transformer layers to unfreeze (0 = fully frozen) extract_layer: Which layer to extract from (-1 = last) position_extract: Use position-specific extraction (legacy) digit_pred: Predict digits instead of bits (legacy) positional_digit: Use positional digit extraction (legacy, 100% on digits only) """ hybrid = not (positional_digit or position_extract or digit_pred) print("=" * 70) print(" LLM TRAINING") if unfreeze_layers > 0: print(f" {unfreeze_layers} transformer layers unfrozen") else: print(" LLM frozen") if extract_layer != -1: print(f" Extracting from layer {extract_layer}") if hybrid: print(" HYBRID extraction (digit lookup + word learning)") elif positional_digit: print(" POSITIONAL DIGIT extraction (legacy, 100% on digits only)") elif position_extract: print(" Position-specific extraction (legacy)") elif digit_pred: print(" Digit-level prediction (legacy)") print("=" * 70) print("\nInitializing model...") model = ArithmeticModel( device=device, unfreeze_layers=unfreeze_layers, extract_layer=extract_layer, position_extract=position_extract, digit_pred=digit_pred, positional_digit=positional_digit, hybrid=hybrid ) optimizer = optim.AdamW(model.trainable_parameters(), lr=lr) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) print(f"\nTraining config:") print(f" Epochs: {epochs}") print(f" Batch size: {batch_size}") print(f" Learning rate: {lr}") print(f" Unfreeze layers: {unfreeze_layers}") print(f" Samples/epoch: {batch_size * 20}") print(f"\nInitial evaluation (200 samples)...") acc, op_acc = evaluate_llm(model, 200) print(f" Accuracy: {acc:.4f}, Op accuracy: {op_acc:.4f}") print(f"\nStarting training...") print("-" * 70) best_acc = acc start_time = time.perf_counter() for epoch in range(epochs): model.extractor.train() if unfreeze_layers > 0: model.llm.train() max_val = get_curriculum_max(epoch, epochs) epoch_loss = 0 if hybrid: epoch_losses = {'result': 0, 'a_digit': 0, 'b_digit': 0, 'op': 0, 'n_words': 0, 'n_lookup': 0} elif positional_digit: epoch_losses = {'result': 0, 'a_digit': 0, 'b_digit': 0, 'op': 0} else: epoch_losses = {'result': 0, 'a': 0, 'b': 0, 'op': 0} n_batches = 20 epoch_start = time.perf_counter() for batch_idx in range(n_batches): batch_texts = [] batch_a = [] batch_b = [] batch_op = [] batch_result = [] batch_a_values = [] batch_b_values = [] for _ in range(batch_size): text, a, b, op, result = generate_problem(max_val) batch_texts.append(text) batch_a.append(int_to_bits(a, device)) batch_b.append(int_to_bits(b, device)) batch_op.append(OPERATIONS.index(op)) batch_result.append(int_to_bits(result, device)) batch_a_values.append(a) batch_b_values.append(b) target_a = torch.stack(batch_a) target_b = torch.stack(batch_b) target_op = torch.tensor(batch_op, device=device) target_result = torch.stack(batch_result) target_a_values = torch.tensor(batch_a_values, device=device, dtype=torch.float32) target_b_values = torch.tensor(batch_b_values, device=device, dtype=torch.float32) optimizer.zero_grad() outputs = model(batch_texts) pred_bits, a_bits, b_bits, op_logits = outputs[0], outputs[1], outputs[2], outputs[3] if hybrid: a_values, b_values, used_lookup = outputs[4], outputs[5], outputs[6] a_digit_logits, b_digit_logits = outputs[7], outputs[8] loss, losses = compute_hybrid_loss( pred_bits, op_logits, used_lookup, a_digit_logits, b_digit_logits, target_result, target_a_values, target_b_values, target_op, device ) elif positional_digit: a_digit_logits_list = outputs[7] b_digit_logits_list = outputs[8] loss, losses = compute_positional_digit_loss( pred_bits, op_logits, a_digit_logits_list, b_digit_logits_list, target_result, target_op, target_a_values, target_b_values, device ) else: loss, losses = compute_llm_loss( pred_bits, a_bits, b_bits, op_logits, target_result, target_a, target_b, target_op ) loss.backward() torch.nn.utils.clip_grad_norm_(model.trainable_parameters(), 1.0) optimizer.step() epoch_loss += loss.item() for k in epoch_losses: epoch_losses[k] += losses[k] if (batch_idx + 1) % 5 == 0: avg_so_far = epoch_loss / (batch_idx + 1) print(f" Epoch {epoch+1} batch {batch_idx+1}/{n_batches} | loss: {avg_so_far:.4f}", flush=True) epoch_time = time.perf_counter() - epoch_start scheduler.step() avg_loss = epoch_loss / n_batches for k in epoch_losses: epoch_losses[k] /= n_batches acc, op_acc = evaluate_llm(model, 300) elapsed = time.perf_counter() - start_time marker = " *" if acc > best_acc else "" if acc > best_acc: best_acc = acc mem, _ = get_gpu_memory() print(f"Epoch {epoch+1:3d} | Loss: {avg_loss:.4f} | " f"Acc: {acc:.4f}{marker} | OpAcc: {op_acc:.4f} | " f"Range: 0-{max_val} | VRAM: {mem:.0f}MB | Time: {elapsed:.0f}s") if hybrid: print(f" Losses - result:{epoch_losses['result']:.4f} " f"a_digit:{epoch_losses['a_digit']:.4f} b_digit:{epoch_losses['b_digit']:.4f} " f"op:{epoch_losses['op']:.4f} | words:{epoch_losses['n_words']:.0f} lookup:{epoch_losses['n_lookup']:.0f}") elif positional_digit: print(f" Losses - result:{epoch_losses['result']:.4f} " f"a_digit:{epoch_losses['a_digit']:.4f} b_digit:{epoch_losses['b_digit']:.4f} " f"op:{epoch_losses['op']:.4f}") else: print(f" Losses - result:{epoch_losses['result']:.4f} " f"a:{epoch_losses['a']:.4f} b:{epoch_losses['b']:.4f} " f"op:{epoch_losses['op']:.4f}") if acc >= 0.99: print("\n" + "=" * 70) print(" TARGET ACHIEVED: 99%+ ACCURACY") print("=" * 70) break print("\n" + "=" * 70) print(" FINAL EVALUATION") print("=" * 70) acc, op_acc = evaluate_llm(model, 1000) print(f"Final accuracy: {acc:.4f}") print(f"Final op accuracy: {op_acc:.4f}") print(f"Best accuracy: {best_acc:.4f}") print("\nSample predictions:") for _ in range(10): text, a, b, op, expected = generate_problem() with torch.no_grad(): outputs = model([text]) result_bits, a_bits, b_bits, op_logits = outputs[0], outputs[1], outputs[2], outputs[3] pred = bits_to_int(result_bits[0]) pred_a = bits_to_int(a_bits[0]) pred_b = bits_to_int(b_bits[0]) pred_op = OPERATIONS[op_logits[0].argmax().item()] status = "OK" if pred == expected else "WRONG" print(f" '{text}' = {expected} | pred={pred} (a={pred_a}, b={pred_b}, op={pred_op}) [{status}]") save_path = "D:/8bit-threshold-computer/llm_integration/trained/llm.pt" save_dict = { 'extractor_state_dict': model.extractor.state_dict(), 'final_accuracy': acc, 'best_accuracy': best_acc, 'unfreeze_layers': unfreeze_layers, } if unfreeze_layers > 0: save_dict['llm_state_dict'] = { k: v for k, v in model.llm.state_dict().items() if any(f'layers.{i}.' in k for i in range(len(model.llm.model.layers) - unfreeze_layers, len(model.llm.model.layers))) } torch.save(save_dict, save_path) print(f"\nSaved to: {save_path}") return model, acc def main(): parser = argparse.ArgumentParser( description='Unified training for threshold circuit LLM integration', formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Modes: router - Train only OpRouter with ground truth bits (sanity check) interface - Train BitEncoder + OpRouter with ground truth bits (sanity check) llm - Train extractor with LLM hidden states (the real training) LLM options: --unfreeze_layers N Fine-tune top N transformer layers --extract_layer N Extract from layer N (-1 = last) --position_extract Use position-specific extraction --digit_pred Predict digits instead of bits Baked-in: curriculum learning (0-9 -> 0-99 -> 0-255), 2x loss weight for a/b Examples: python train.py --mode llm --epochs 100 python train.py --mode llm --position_extract python train.py --mode llm --digit_pred --extract_layer 12 python train.py --mode llm --unfreeze_layers 4 --batch_size 4096 """ ) parser.add_argument('--mode', type=str, required=True, choices=['router', 'interface', 'llm'], help='Training mode') parser.add_argument('--epochs', type=int, default=100, help='Number of epochs') parser.add_argument('--batch_size', type=int, default=512, help='Batch size (default: 512)') parser.add_argument('--lr', type=float, default=None, help='Learning rate (default: mode-specific)') parser.add_argument('--unfreeze_layers', type=int, default=0, help='Unfreeze top N transformer layers (default 0 = frozen)') parser.add_argument('--extract_layer', type=int, default=0, help='Which layer to extract from (default: 0 = embeddings, best for digits)') parser.add_argument('--position_extract', action='store_true', help='Use position-specific extraction (legacy)') parser.add_argument('--digit_pred', action='store_true', help='Predict digits instead of bits (legacy)') parser.add_argument('--positional_digit', action='store_true', default=False, help='Use positional digit extraction (legacy, 100%% on digits only)') parser.add_argument('--device', type=str, default='cuda', help='Device') args = parser.parse_args() torch.manual_seed(42) random.seed(42) if args.mode == 'router': lr = args.lr if args.lr is not None else 1e-2 train_router(epochs=args.epochs, batch_size=args.batch_size, lr=lr, device=args.device) elif args.mode == 'interface': lr = args.lr if args.lr is not None else 1e-3 model, fitness = train_interface( epochs=args.epochs, batch_size=args.batch_size, lr=lr, device=args.device ) print("\n" + "=" * 70) print(" EXPERIMENT SUMMARY") print("=" * 70) print(f"\n Control (Vanilla SmolLM2-360M): 11.90%") print(f" Experimental (Trained Interface): {100*fitness:.2f}%") if fitness > 0: print(f"\n Improvement: {100*(fitness - 0.119)/0.119:.1f}%") if fitness >= 0.99: print("\n CONCLUSION: Frozen threshold circuits + trained interface") print(" achieves near-perfect arithmetic accuracy.") print(" Core thesis VALIDATED.") else: print(f"\n CONCLUSION: Further training or architecture changes needed.") print(f" Current gap: {100*(1.0 - fitness):.2f}%") elif args.mode == 'llm': lr = args.lr if args.lr is not None else 3e-4 train_llm( epochs=args.epochs, batch_size=args.batch_size, lr=lr, unfreeze_layers=args.unfreeze_layers, extract_layer=args.extract_layer, position_extract=args.position_extract, digit_pred=args.digit_pred, positional_digit=args.positional_digit, device=args.device ) if __name__ == "__main__": main()