"""
Unified training script for threshold circuit LLM integration.
Modes:
--mode router : Train only OpRouter with ground truth bits (sanity check)
--mode interface : Train BitEncoder + OpRouter with ground truth bits (sanity check)
--mode llm : Train extractor with LLM hidden states (the real training)
LLM mode options:
--unfreeze_layers N : Unfreeze top N transformer layers (default 0 = fully frozen)
Hardware Profile (NVIDIA RTX 6000 Ada, 48 GB):
VRAM scaling with unfreeze_layers=4:
batch_size | VRAM (MB) | % of 48 GB
-----------+-----------+-----------
       512 |     5,784 |      11.8%
     1,024 |     7,384 |      15.0%
     4,096 |    16,534 |      33.6%
    13,000 |    39,000 |      79.4%  <-- recommended for the ~80% utilization target
Examples:
python train.py --mode llm --epochs 100 --batch_size 256
python train.py --mode llm --epochs 100 --batch_size 4096 --unfreeze_layers 4
"""
import torch
import torch.nn as nn
import torch.optim as optim
import time
import argparse
import random
from model import (
ThresholdALU, DirectCircuitModel, OpRouter,
ArithmeticModel, OPERATIONS, OP_SYMBOLS
)
from circuits import FrozenThresholdCircuits
from fitness import generate_batch, compute_fitness, compute_loss
DEVICE = 'cuda'
ONES = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine',
'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen',
'seventeen', 'eighteen', 'nineteen']
TENS = ['', '', 'twenty', 'thirty', 'forty', 'fifty', 'sixty', 'seventy', 'eighty', 'ninety']
def int_to_words(n: int) -> str:
"""Convert integer 0-255 to English words."""
if n < 0 or n > 255:
return str(n)
if n < 20:
return ONES[n]
if n < 100:
if n % 10 == 0:
return TENS[n // 10]
return f"{TENS[n // 10]} {ONES[n % 10]}"
if n % 100 == 0:
return f"{ONES[n // 100]} hundred"
if n % 100 < 20:
return f"{ONES[n // 100]} hundred {ONES[n % 100]}"
if n % 10 == 0:
return f"{ONES[n // 100]} hundred {TENS[(n % 100) // 10]}"
return f"{ONES[n // 100]} hundred {TENS[(n % 100) // 10]} {ONES[n % 10]}"
def int_to_bits(val: int, device: str = 'cuda') -> torch.Tensor:
    """Pack an integer 0-255 into an 8-element bit tensor, most significant bit first."""
bits = torch.zeros(8, device=device)
for i in range(8):
bits[7-i] = (val >> i) & 1
return bits
def bits_to_int(bits: torch.Tensor) -> int:
    """Decode an MSB-first 8-bit tensor back to an integer, thresholding each bit at 0.5."""
val = 0
for i in range(8):
if bits[i].item() > 0.5:
val += 1 << (7-i)
return val
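# Sanity-check sketch for the bit helpers (illustrative, not run by the script):
#   for v in range(256):
#       assert bits_to_int(int_to_bits(v, device='cpu')) == v  # MSB-first round trip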
NL_TEMPLATES = {
'add': [
"What is {a} plus {b}?",
"Calculate {a} + {b}",
"Add {a} and {b}",
"What's the sum of {a} and {b}?",
"If I have {a} and get {b} more, how many total?",
"{a} + {b} = ?",
"Compute {a} plus {b}",
],
'sub': [
"What is {a} minus {b}?",
"Calculate {a} - {b}",
"Subtract {b} from {a}",
"What's {a} take away {b}?",
"If I have {a} and lose {b}, how many left?",
"{a} - {b} = ?",
"Compute {a} minus {b}",
],
'mul': [
"What is {a} times {b}?",
"Calculate {a} * {b}",
"Multiply {a} by {b}",
"What's {a} multiplied by {b}?",
"{a} * {b} = ?",
"Compute {a} times {b}",
"What is the product of {a} and {b}?",
],
'gt': [
"Is {a} greater than {b}?",
"Is {a} > {b}?",
"Check if {a} is larger than {b}",
"Compare: is {a} more than {b}?",
"{a} > {b}?",
],
'lt': [
"Is {a} less than {b}?",
"Is {a} < {b}?",
"Check if {a} is smaller than {b}",
"Compare: is {a} fewer than {b}?",
"{a} < {b}?",
],
'eq': [
"Is {a} equal to {b}?",
"Is {a} == {b}?",
"Does {a} equal {b}?",
"Check if {a} equals {b}",
"Are {a} and {b} the same?",
],
}
def generate_problem(max_val: int = 255):
"""
Generate a random arithmetic problem for LLM training.
Randomly mixes digit and word formats.
"""
a = random.randint(0, max_val)
b = random.randint(0, max_val)
op = random.choice(OPERATIONS)
fmt = random.choice(['digits', 'words', 'nl_digits', 'nl_words'])
if fmt == 'digits':
sym = OP_SYMBOLS[op]
text = f"{a} {sym} {b}"
elif fmt == 'words':
a_word = int_to_words(a)
b_word = int_to_words(b)
op_word = {'add': 'plus', 'sub': 'minus', 'mul': 'times',
'gt': 'greater than', 'lt': 'less than', 'eq': 'equals'}[op]
text = f"{a_word} {op_word} {b_word}"
elif fmt == 'nl_digits':
template = random.choice(NL_TEMPLATES[op])
text = template.format(a=a, b=b)
else:
template = random.choice(NL_TEMPLATES[op])
text = template.format(a=int_to_words(a), b=int_to_words(b))
if op == 'add':
result = (a + b) & 0xFF
elif op == 'sub':
result = (a - b) & 0xFF
elif op == 'mul':
result = (a * b) & 0xFF
elif op == 'gt':
result = 1 if a > b else 0
elif op == 'lt':
result = 1 if a < b else 0
elif op == 'eq':
result = 1 if a == b else 0
return text, a, b, op, result
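# Example draws (illustrative): generate_problem() can yield tuples such as
#   ("Calculate 17 + 250", 17, 250, 'add', 11)    # (17 + 250) & 0xFF == 11
#   ("seven times three", 7, 3, 'mul', 21)        # 'words' format
# depending on the sampled format and operands.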
def get_curriculum_max(epoch: int, total_epochs: int) -> int:
"""
Curriculum learning: start with small numbers, gradually increase.
Epoch 0-20%: 0-9 (single digit)
Epoch 20-50%: 0-99 (two digit)
Epoch 50-100%: 0-255 (full range)
"""
progress = epoch / total_epochs
if progress < 0.2:
return 9
elif progress < 0.5:
return 99
else:
return 255
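# With epochs=100 this schedule means: epochs 0-19 sample operands from 0-9,
# epochs 20-49 from 0-99, and epochs 50-99 from the full 0-255 range.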
def train_router(epochs: int = 100, batch_size: int = 256, lr: float = 1e-2, device: str = 'cuda'):
"""Train only the router with ground truth bits."""
print("=" * 70)
print(" ROUTER-ONLY TRAINING (Ground Truth Bits)")
print("=" * 70)
circuits = FrozenThresholdCircuits(device=device)
router = OpRouter(input_dim=16 + 6, hidden_dim=64, n_ops=6).to(device)
print(f"\nRouter parameters: {sum(p.numel() for p in router.parameters()):,}")
def model_fn(a_bits, b_bits, op_onehot):
x = torch.cat([a_bits, b_bits, op_onehot], dim=-1)
op_weights = router(x)
return circuits(a_bits, b_bits, op_weights)
initial_fitness = compute_fitness(model_fn, n_samples=1000, device=device)
print(f"Initial fitness: {initial_fitness:.4f}")
optimizer = optim.AdamW(router.parameters(), lr=lr)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
print("\nTraining...")
print("-" * 70)
best_fitness = initial_fitness
start_time = time.perf_counter()
for epoch in range(epochs):
router.train()
epoch_loss = 0.0
for _ in range(100):
batch = generate_batch(batch_size, device)
optimizer.zero_grad()
x = torch.cat([batch['a_bits'], batch['b_bits'], batch['op_onehot']], dim=-1)
op_weights = router(x)
pred_bits = circuits(batch['a_bits'], batch['b_bits'], op_weights)
loss = compute_loss(pred_bits, batch['expected_bits'])
loss.backward()
optimizer.step()
epoch_loss += loss.item()
scheduler.step()
if (epoch + 1) % 10 == 0 or epoch == 0:
router.eval()
fitness, details = compute_fitness(model_fn, n_samples=2000, device=device, return_details=True)
elapsed = time.perf_counter() - start_time
if fitness > best_fitness:
best_fitness = fitness
marker = " *"
else:
marker = ""
print(f"Epoch {epoch+1:3d} | Loss: {epoch_loss/100:.4f} | "
f"Fitness: {fitness:.4f}{marker} | Time: {elapsed:.1f}s")
if fitness >= 0.9999:
print("\n TARGET: 100% FITNESS ACHIEVED")
break
print("\n" + "=" * 70)
print(" RESULTS")
print("=" * 70)
router.eval()
final_fitness, details = compute_fitness(model_fn, n_samples=5000, device=device, return_details=True)
print(f"\nFinal fitness: {final_fitness:.4f}")
print(f"\nPer-operation:")
for op in OPERATIONS:
acc = details['by_op'][op]['accuracy']
print(f" {op}: {acc:.4f}")
print(f"\nTotal time: {time.perf_counter() - start_time:.1f}s")
if final_fitness >= 0.99:
print("\nCONCLUSION: Router successfully learned operation dispatch.")
print(" With correct bit encoding, 100% is achievable.")
save_path = "D:/8bit-threshold-computer/llm_integration/trained/router.pt"
torch.save({
'router_state_dict': router.state_dict(),
'final_fitness': final_fitness,
'params': sum(p.numel() for p in router.parameters()),
}, save_path)
print(f"\nSaved trained router to: {save_path}")
return router, final_fitness
def get_gpu_memory():
"""Get GPU memory usage in MB."""
if torch.cuda.is_available():
return torch.cuda.memory_allocated() / 1024 / 1024, torch.cuda.max_memory_allocated() / 1024 / 1024
return 0, 0
def train_interface(epochs: int = 200, batch_size: int = 512, lr: float = 1e-3,
eval_interval: int = 10, device: str = 'cuda'):
"""Train BitEncoder + OpRouter with ground truth bits."""
print("=" * 70)
print(" INTERFACE TRAINING (Encoder + Router)")
print("=" * 70)
print(f" Started at: {time.strftime('%H:%M:%S')}")
print("\n[1/4] Verifying frozen circuits...")
print(" Creating DirectCircuitModel...", end=" ", flush=True)
direct_model = DirectCircuitModel(device=device)
mem, max_mem = get_gpu_memory()
print(f"done. VRAM: {mem:.0f}MB")
def direct_fn(a, b, op):
return direct_model(a, b, op)
print(" Running fitness check (1000 samples)...", end=" ", flush=True)
circuit_fitness = compute_fitness(direct_fn, n_samples=1000, device=device)
print(f"done. Fitness: {circuit_fitness:.4f}")
if circuit_fitness < 0.999:
print(" ERROR: Circuits not achieving 100%. Aborting.")
return None, 0.0
print(" STATUS: PASS")
print("\n[2/4] Initializing model...")
print(" Creating ThresholdALU...", end=" ", flush=True)
model = ThresholdALU(device=device)
mem, max_mem = get_gpu_memory()
print(f"done. VRAM: {mem:.0f}MB")
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f" Trainable parameters: {trainable_params:,}")
def model_fn(a, b, op):
return model(a, b, op)
print(" Running initial fitness check...", end=" ", flush=True)
initial_fitness = compute_fitness(model_fn, n_samples=1000, device=device)
print(f"done. Fitness: {initial_fitness:.4f}")
print("\n[3/4] Setting up training...")
print(" Creating optimizer...", end=" ", flush=True)
optimizer = optim.AdamW(model.parameters(), lr=lr)
print("done.")
print(" Creating scheduler...", end=" ", flush=True)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
print("done.")
print(f" Config: lr={lr}, batch_size={batch_size}, epochs={epochs}")
print("\n[4/4] Training...")
print(" Generating first batch to warm up...", end=" ", flush=True)
warmup_batch = generate_batch(batch_size, device)
mem, max_mem = get_gpu_memory()
print(f"done. VRAM: {mem:.0f}MB (max: {max_mem:.0f}MB)")
print("-" * 70)
best_fitness = initial_fitness
start_time = time.perf_counter()
n_batches = 100
for epoch in range(epochs):
model.train()
epoch_loss = 0.0
epoch_start = time.perf_counter()
for batch_idx in range(n_batches):
batch = generate_batch(batch_size, device)
optimizer.zero_grad()
pred_bits = model(batch['a_bits'], batch['b_bits'], batch['op_onehot'])
loss = compute_loss(pred_bits, batch['expected_bits'])
loss.backward()
optimizer.step()
epoch_loss += loss.item()
if batch_idx == 0 and epoch == 0:
mem, max_mem = get_gpu_memory()
print(f" First forward/backward done. VRAM: {mem:.0f}MB (max: {max_mem:.0f}MB)")
if (batch_idx + 1) % 25 == 0:
avg_so_far = epoch_loss / (batch_idx + 1)
print(f" Epoch {epoch+1} batch {batch_idx+1}/{n_batches} | loss: {avg_so_far:.4f}", flush=True)
scheduler.step()
avg_loss = epoch_loss / n_batches
epoch_time = time.perf_counter() - epoch_start
if (epoch + 1) % 5 == 0 or epoch == 0: # Eval every 5 epochs
model.eval()
fitness, details = compute_fitness(
model_fn, n_samples=2000, device=device, return_details=True
)
elapsed = time.perf_counter() - start_time
if fitness > best_fitness:
best_fitness = fitness
marker = " *"
else:
marker = ""
mem, _ = get_gpu_memory()
print(f"Epoch {epoch+1:4d} | Loss: {avg_loss:.4f} | "
f"Fitness: {fitness:.4f}{marker} | "
f"LR: {scheduler.get_last_lr()[0]:.2e} | "
f"VRAM: {mem:.0f}MB | "
f"Time: {elapsed:.1f}s ({epoch_time:.1f}s/epoch)")
if fitness >= 0.9999:
print("\n" + "=" * 70)
print(" TARGET ACHIEVED: 100% FITNESS")
print("=" * 70)
break
print("\n" + "=" * 70)
print(" TRAINING COMPLETE")
print("=" * 70)
model.eval()
final_fitness, details = compute_fitness(
model_fn, n_samples=5000, device=device, return_details=True
)
print(f"\nFinal fitness: {final_fitness:.4f}")
print(f"Best fitness: {best_fitness:.4f}")
print(f"\nPer-operation breakdown:")
for op in OPERATIONS:
acc = details['by_op'][op]['accuracy']
print(f" {op:6}: {acc:.4f}")
print(f"\nTotal time: {time.perf_counter() - start_time:.1f}s")
save_path = "D:/8bit-threshold-computer/llm_integration/trained/interface.pt"
torch.save({
'encoder_state_dict': model.encoder.state_dict(),
'router_state_dict': model.router.state_dict(),
'final_fitness': final_fitness,
'best_fitness': best_fitness,
}, save_path)
print(f"\nSaved trained model to: {save_path}")
return model, final_fitness
def compute_llm_loss(pred_bits, a_bits, b_bits, op_logits,
target_result, target_a, target_b, target_op_idx,
bit_weight: float = 2.0):
"""
Multi-component loss for LLM training.
bit_weight: multiplier for a/b bit losses (default 2x since extraction is the bottleneck)
"""
result_loss = nn.functional.binary_cross_entropy_with_logits(
pred_bits, target_result
)
a_bits_safe = torch.clamp(a_bits, 0.0, 1.0)
b_bits_safe = torch.clamp(b_bits, 0.0, 1.0)
a_bits_safe = torch.nan_to_num(a_bits_safe, nan=0.5, posinf=1.0, neginf=0.0)
b_bits_safe = torch.nan_to_num(b_bits_safe, nan=0.5, posinf=1.0, neginf=0.0)
a_loss = nn.functional.binary_cross_entropy(
torch.clamp(a_bits_safe, 1e-6, 1-1e-6), target_a
)
b_loss = nn.functional.binary_cross_entropy(
torch.clamp(b_bits_safe, 1e-6, 1-1e-6), target_b
)
op_loss = nn.functional.cross_entropy(op_logits, target_op_idx)
total = result_loss + bit_weight * a_loss + bit_weight * b_loss + op_loss
total = torch.nan_to_num(total, nan=10.0, posinf=10.0, neginf=0.0)
return total, {
'result': result_loss.item() if not torch.isnan(result_loss) else 10.0,
'a': a_loss.item() if not torch.isnan(a_loss) else 10.0,
'b': b_loss.item() if not torch.isnan(b_loss) else 10.0,
'op': op_loss.item() if not torch.isnan(op_loss) else 10.0
}
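# With the default bit_weight, total = result_loss + 2*a_loss + 2*b_loss + op_loss;
# the 2x weight on the operand-bit terms reflects that extraction is the bottleneck.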
def value_to_digits(value: int) -> list:
"""Convert integer value to list of digits (hundreds, tens, ones)."""
digits = []
for place in [100, 10, 1]:
digit = (value // place) % 10
digits.append(digit)
return digits
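# Examples: value_to_digits(7) == [0, 0, 7]; value_to_digits(255) == [2, 5, 5]
# (always three entries: hundreds, tens, ones).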
def compute_positional_digit_loss(pred_bits, op_logits, a_digit_logits_list, b_digit_logits_list,
target_result, target_op_idx, target_a_values, target_b_values,
device, digit_weight: float = 5.0):
"""
Loss for positional digit extraction with DIRECT digit supervision.
This provides strong gradients by directly supervising digit classification
instead of going through the value -> bits conversion.
"""
result_loss = nn.functional.binary_cross_entropy_with_logits(
pred_bits, target_result
)
op_loss = nn.functional.cross_entropy(op_logits, target_op_idx)
a_digit_loss = torch.tensor(0.0, device=device)
b_digit_loss = torch.tensor(0.0, device=device)
n_a_digits = 0
n_b_digits = 0
for i, (a_logits_list, b_logits_list) in enumerate(zip(a_digit_logits_list, b_digit_logits_list)):
target_a = target_a_values[i].item()
target_b = target_b_values[i].item()
a_digits = value_to_digits(int(target_a))
b_digits = value_to_digits(int(target_b))
n_a = len(a_logits_list)
n_b = len(b_logits_list)
if n_a > 0:
target_a_digits = a_digits[-n_a:]
for j, logits in enumerate(a_logits_list):
target_digit = torch.tensor([target_a_digits[j]], device=device, dtype=torch.long)
a_digit_loss = a_digit_loss + nn.functional.cross_entropy(logits.unsqueeze(0), target_digit)
n_a_digits += 1
if n_b > 0:
target_b_digits = b_digits[-n_b:]
for j, logits in enumerate(b_logits_list):
target_digit = torch.tensor([target_b_digits[j]], device=device, dtype=torch.long)
b_digit_loss = b_digit_loss + nn.functional.cross_entropy(logits.unsqueeze(0), target_digit)
n_b_digits += 1
if n_a_digits > 0:
a_digit_loss = a_digit_loss / n_a_digits
if n_b_digits > 0:
b_digit_loss = b_digit_loss / n_b_digits
total = result_loss + digit_weight * a_digit_loss + digit_weight * b_digit_loss + op_loss
total = torch.nan_to_num(total, nan=10.0, posinf=10.0, neginf=0.0)
return total, {
'result': result_loss.item() if not torch.isnan(result_loss) else 10.0,
'a_digit': a_digit_loss.item() if not torch.isnan(a_digit_loss) else 10.0,
'b_digit': b_digit_loss.item() if not torch.isnan(b_digit_loss) else 10.0,
'op': op_loss.item() if not torch.isnan(op_loss) else 10.0
}
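# Target alignment (illustrative): if the extractor emits logits for only the last
# two digit positions of a = 42, the targets are value_to_digits(42)[-2:] == [4, 2].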
def compute_hybrid_loss(pred_bits, op_logits, used_lookup,
a_digit_logits, b_digit_logits,
target_result, target_a_values, target_b_values, target_op_idx,
device, digit_weight: float = 2.0):
"""
Loss for hybrid extraction with digit-level prediction.
Uses cross-entropy on each digit (hundreds, tens, ones) for word samples.
Samples using digit lookup are already 100% accurate - no loss computed.
"""
result_loss = nn.functional.binary_cross_entropy_with_logits(
pred_bits, target_result
)
op_loss = nn.functional.cross_entropy(op_logits, target_op_idx)
word_mask = ~used_lookup
n_words = word_mask.sum().item()
if n_words > 0 and a_digit_logits is not None and b_digit_logits is not None:
target_a_word = target_a_values[word_mask].long()
target_b_word = target_b_values[word_mask].long()
a_hundreds = target_a_word // 100
a_tens = (target_a_word % 100) // 10
a_ones = target_a_word % 10
b_hundreds = target_b_word // 100
b_tens = (target_b_word % 100) // 10
b_ones = target_b_word % 10
a_logits = a_digit_logits.view(-1, 3, 10)
b_logits = b_digit_logits.view(-1, 3, 10)
a_digit_loss = (
nn.functional.cross_entropy(a_logits[:, 0], a_hundreds) +
nn.functional.cross_entropy(a_logits[:, 1], a_tens) +
nn.functional.cross_entropy(a_logits[:, 2], a_ones)
) / 3.0
b_digit_loss = (
nn.functional.cross_entropy(b_logits[:, 0], b_hundreds) +
nn.functional.cross_entropy(b_logits[:, 1], b_tens) +
nn.functional.cross_entropy(b_logits[:, 2], b_ones)
) / 3.0
else:
a_digit_loss = torch.tensor(0.0, device=device)
b_digit_loss = torch.tensor(0.0, device=device)
total = result_loss + op_loss + digit_weight * (a_digit_loss + b_digit_loss)
total = torch.nan_to_num(total, nan=10.0, posinf=10.0, neginf=0.0)
return total, {
'result': result_loss.item() if not torch.isnan(result_loss) else 10.0,
'a_digit': a_digit_loss.item() if not torch.isnan(a_digit_loss) else 10.0,
'b_digit': b_digit_loss.item() if not torch.isnan(b_digit_loss) else 10.0,
'op': op_loss.item() if not torch.isnan(op_loss) else 10.0,
'n_words': n_words,
'n_lookup': used_lookup.sum().item()
}
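# Digit targets are the base-10 decomposition of each operand, e.g. a = 207 gives
# hundreds=2, tens=0, ones=7, and each position is a separate 10-way classification.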
def evaluate_llm(model, n_samples: int = 500):
"""Evaluate LLM model on random problems (mixed digit/word format)."""
model.extractor.eval()
correct = 0
op_correct = 0
for _ in range(n_samples):
text, a, b, op, expected = generate_problem()
with torch.no_grad():
outputs = model([text])
result_bits = outputs[0]
op_logits = outputs[3]
pred_result = bits_to_int(result_bits[0])
pred_op = OPERATIONS[op_logits[0].argmax().item()]
if pred_result == expected:
correct += 1
if pred_op == op:
op_correct += 1
model.extractor.train()
return correct / n_samples, op_correct / n_samples
def train_llm(epochs: int = 100, batch_size: int = 256, lr: float = 3e-4,
unfreeze_layers: int = 0, extract_layer: int = -1,
position_extract: bool = False, digit_pred: bool = False,
positional_digit: bool = False, device: str = 'cuda'):
"""
Train extractor with LLM hidden states.
Args:
unfreeze_layers: Number of top transformer layers to unfreeze (0 = fully frozen)
extract_layer: Which layer to extract from (-1 = last)
position_extract: Use position-specific extraction (legacy)
digit_pred: Predict digits instead of bits (legacy)
positional_digit: Use positional digit extraction (legacy, 100% on digits only)
"""
hybrid = not (positional_digit or position_extract or digit_pred)
print("=" * 70)
print(" LLM TRAINING")
if unfreeze_layers > 0:
print(f" {unfreeze_layers} transformer layers unfrozen")
else:
print(" LLM frozen")
if extract_layer != -1:
print(f" Extracting from layer {extract_layer}")
if hybrid:
print(" HYBRID extraction (digit lookup + word learning)")
elif positional_digit:
print(" POSITIONAL DIGIT extraction (legacy, 100% on digits only)")
elif position_extract:
print(" Position-specific extraction (legacy)")
elif digit_pred:
print(" Digit-level prediction (legacy)")
print("=" * 70)
print("\nInitializing model...")
model = ArithmeticModel(
device=device,
unfreeze_layers=unfreeze_layers,
extract_layer=extract_layer,
position_extract=position_extract,
digit_pred=digit_pred,
positional_digit=positional_digit,
hybrid=hybrid
)
optimizer = optim.AdamW(model.trainable_parameters(), lr=lr)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
print(f"\nTraining config:")
print(f" Epochs: {epochs}")
print(f" Batch size: {batch_size}")
print(f" Learning rate: {lr}")
print(f" Unfreeze layers: {unfreeze_layers}")
print(f" Samples/epoch: {batch_size * 20}")
print(f"\nInitial evaluation (200 samples)...")
acc, op_acc = evaluate_llm(model, 200)
print(f" Accuracy: {acc:.4f}, Op accuracy: {op_acc:.4f}")
print(f"\nStarting training...")
print("-" * 70)
best_acc = acc
start_time = time.perf_counter()
for epoch in range(epochs):
model.extractor.train()
if unfreeze_layers > 0:
model.llm.train()
max_val = get_curriculum_max(epoch, epochs)
epoch_loss = 0
if hybrid:
epoch_losses = {'result': 0, 'a_digit': 0, 'b_digit': 0, 'op': 0, 'n_words': 0, 'n_lookup': 0}
elif positional_digit:
epoch_losses = {'result': 0, 'a_digit': 0, 'b_digit': 0, 'op': 0}
else:
epoch_losses = {'result': 0, 'a': 0, 'b': 0, 'op': 0}
n_batches = 20
epoch_start = time.perf_counter()
for batch_idx in range(n_batches):
batch_texts = []
batch_a = []
batch_b = []
batch_op = []
batch_result = []
batch_a_values = []
batch_b_values = []
for _ in range(batch_size):
text, a, b, op, result = generate_problem(max_val)
batch_texts.append(text)
batch_a.append(int_to_bits(a, device))
batch_b.append(int_to_bits(b, device))
batch_op.append(OPERATIONS.index(op))
batch_result.append(int_to_bits(result, device))
batch_a_values.append(a)
batch_b_values.append(b)
target_a = torch.stack(batch_a)
target_b = torch.stack(batch_b)
target_op = torch.tensor(batch_op, device=device)
target_result = torch.stack(batch_result)
target_a_values = torch.tensor(batch_a_values, device=device, dtype=torch.float32)
target_b_values = torch.tensor(batch_b_values, device=device, dtype=torch.float32)
optimizer.zero_grad()
outputs = model(batch_texts)
pred_bits, a_bits, b_bits, op_logits = outputs[0], outputs[1], outputs[2], outputs[3]
if hybrid:
a_values, b_values, used_lookup = outputs[4], outputs[5], outputs[6]
a_digit_logits, b_digit_logits = outputs[7], outputs[8]
loss, losses = compute_hybrid_loss(
pred_bits, op_logits, used_lookup,
a_digit_logits, b_digit_logits,
target_result, target_a_values, target_b_values, target_op, device
)
elif positional_digit:
a_digit_logits_list = outputs[7]
b_digit_logits_list = outputs[8]
loss, losses = compute_positional_digit_loss(
pred_bits, op_logits, a_digit_logits_list, b_digit_logits_list,
target_result, target_op, target_a_values, target_b_values, device
)
else:
loss, losses = compute_llm_loss(
pred_bits, a_bits, b_bits, op_logits,
target_result, target_a, target_b, target_op
)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.trainable_parameters(), 1.0)
optimizer.step()
epoch_loss += loss.item()
for k in epoch_losses:
epoch_losses[k] += losses[k]
if (batch_idx + 1) % 5 == 0:
avg_so_far = epoch_loss / (batch_idx + 1)
print(f" Epoch {epoch+1} batch {batch_idx+1}/{n_batches} | loss: {avg_so_far:.4f}", flush=True)
epoch_time = time.perf_counter() - epoch_start
scheduler.step()
avg_loss = epoch_loss / n_batches
for k in epoch_losses:
epoch_losses[k] /= n_batches
acc, op_acc = evaluate_llm(model, 300)
elapsed = time.perf_counter() - start_time
marker = " *" if acc > best_acc else ""
if acc > best_acc:
best_acc = acc
mem, _ = get_gpu_memory()
print(f"Epoch {epoch+1:3d} | Loss: {avg_loss:.4f} | "
f"Acc: {acc:.4f}{marker} | OpAcc: {op_acc:.4f} | "
f"Range: 0-{max_val} | VRAM: {mem:.0f}MB | Time: {elapsed:.0f}s")
if hybrid:
print(f" Losses - result:{epoch_losses['result']:.4f} "
f"a_digit:{epoch_losses['a_digit']:.4f} b_digit:{epoch_losses['b_digit']:.4f} "
f"op:{epoch_losses['op']:.4f} | words:{epoch_losses['n_words']:.0f} lookup:{epoch_losses['n_lookup']:.0f}")
elif positional_digit:
print(f" Losses - result:{epoch_losses['result']:.4f} "
f"a_digit:{epoch_losses['a_digit']:.4f} b_digit:{epoch_losses['b_digit']:.4f} "
f"op:{epoch_losses['op']:.4f}")
else:
print(f" Losses - result:{epoch_losses['result']:.4f} "
f"a:{epoch_losses['a']:.4f} b:{epoch_losses['b']:.4f} "
f"op:{epoch_losses['op']:.4f}")
if acc >= 0.99:
print("\n" + "=" * 70)
print(" TARGET ACHIEVED: 99%+ ACCURACY")
print("=" * 70)
break
print("\n" + "=" * 70)
print(" FINAL EVALUATION")
print("=" * 70)
acc, op_acc = evaluate_llm(model, 1000)
print(f"Final accuracy: {acc:.4f}")
print(f"Final op accuracy: {op_acc:.4f}")
print(f"Best accuracy: {best_acc:.4f}")
print("\nSample predictions:")
for _ in range(10):
text, a, b, op, expected = generate_problem()
with torch.no_grad():
outputs = model([text])
result_bits, a_bits, b_bits, op_logits = outputs[0], outputs[1], outputs[2], outputs[3]
pred = bits_to_int(result_bits[0])
pred_a = bits_to_int(a_bits[0])
pred_b = bits_to_int(b_bits[0])
pred_op = OPERATIONS[op_logits[0].argmax().item()]
status = "OK" if pred == expected else "WRONG"
print(f" '{text}' = {expected} | pred={pred} (a={pred_a}, b={pred_b}, op={pred_op}) [{status}]")
save_path = "D:/8bit-threshold-computer/llm_integration/trained/llm.pt"
save_dict = {
'extractor_state_dict': model.extractor.state_dict(),
'final_accuracy': acc,
'best_accuracy': best_acc,
'unfreeze_layers': unfreeze_layers,
}
if unfreeze_layers > 0:
save_dict['llm_state_dict'] = {
k: v for k, v in model.llm.state_dict().items()
if any(f'layers.{i}.' in k for i in range(len(model.llm.model.layers) - unfreeze_layers, len(model.llm.model.layers)))
}
torch.save(save_dict, save_path)
print(f"\nSaved to: {save_path}")
return model, acc
def main():
parser = argparse.ArgumentParser(
description='Unified training for threshold circuit LLM integration',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Modes:
router - Train only OpRouter with ground truth bits (sanity check)
interface - Train BitEncoder + OpRouter with ground truth bits (sanity check)
llm - Train extractor with LLM hidden states (the real training)
LLM options:
  --unfreeze_layers N   Fine-tune the top N transformer layers
  --extract_layer N     Extract hidden states from layer N (-1 = last)
  --position_extract    Use position-specific extraction (legacy)
  --digit_pred          Predict digits instead of bits (legacy)
  --positional_digit    Use positional digit extraction (legacy)
Baked-in: curriculum learning (0-9 -> 0-99 -> 0-255), 2x loss weight for a/b losses
Examples:
python train.py --mode llm --epochs 100
python train.py --mode llm --position_extract
python train.py --mode llm --digit_pred --extract_layer 12
python train.py --mode llm --unfreeze_layers 4 --batch_size 4096
"""
)
parser.add_argument('--mode', type=str, required=True,
choices=['router', 'interface', 'llm'],
help='Training mode')
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
parser.add_argument('--batch_size', type=int, default=512, help='Batch size (default: 512)')
parser.add_argument('--lr', type=float, default=None,
help='Learning rate (default: mode-specific)')
parser.add_argument('--unfreeze_layers', type=int, default=0,
help='Unfreeze top N transformer layers (default 0 = frozen)')
parser.add_argument('--extract_layer', type=int, default=0,
help='Which layer to extract from (default: 0 = embeddings, best for digits)')
parser.add_argument('--position_extract', action='store_true',
help='Use position-specific extraction (legacy)')
parser.add_argument('--digit_pred', action='store_true',
help='Predict digits instead of bits (legacy)')
parser.add_argument('--positional_digit', action='store_true', default=False,
help='Use positional digit extraction (legacy, 100%% on digits only)')
parser.add_argument('--device', type=str, default='cuda', help='Device')
args = parser.parse_args()
torch.manual_seed(42)
random.seed(42)
if args.mode == 'router':
lr = args.lr if args.lr is not None else 1e-2
train_router(epochs=args.epochs, batch_size=args.batch_size, lr=lr, device=args.device)
elif args.mode == 'interface':
lr = args.lr if args.lr is not None else 1e-3
model, fitness = train_interface(
epochs=args.epochs, batch_size=args.batch_size, lr=lr, device=args.device
)
print("\n" + "=" * 70)
print(" EXPERIMENT SUMMARY")
print("=" * 70)
print(f"\n Control (Vanilla SmolLM2-360M): 11.90%")
print(f" Experimental (Trained Interface): {100*fitness:.2f}%")
if fitness > 0:
print(f"\n Improvement: {100*(fitness - 0.119)/0.119:.1f}%")
if fitness >= 0.99:
print("\n CONCLUSION: Frozen threshold circuits + trained interface")
print(" achieves near-perfect arithmetic accuracy.")
print(" Core thesis VALIDATED.")
else:
print(f"\n CONCLUSION: Further training or architecture changes needed.")
print(f" Current gap: {100*(1.0 - fitness):.2f}%")
elif args.mode == 'llm':
lr = args.lr if args.lr is not None else 3e-4
train_llm(
epochs=args.epochs,
batch_size=args.batch_size,
lr=lr,
unfreeze_layers=args.unfreeze_layers,
extract_layer=args.extract_layer,
position_extract=args.position_extract,
digit_pred=args.digit_pred,
positional_digit=args.positional_digit,
device=args.device
)
if __name__ == "__main__":
main()