|
|
""" |
|
|
Unified training script for threshold circuit LLM integration. |
|
|
|
|
|
Modes: |
|
|
--mode router : Train only OpRouter with ground truth bits (sanity check) |
|
|
--mode interface : Train BitEncoder + OpRouter with ground truth bits (sanity check) |
|
|
--mode llm : Train extractor with LLM hidden states (the real training) |
|
|
|
|
|
LLM mode options: |
|
|
--unfreeze_layers N : Unfreeze top N transformer layers (default 0 = fully frozen) |
|
|
|
|
|
Hardware Profile (NVIDIA RTX 6000 Ada 48GB): |
|
|
VRAM Scaling (unfreeze_layers=4): |
|
|
batch_size | VRAM | % |
|
|
-----------+---------+------ |
|
|
512 | 5,784 | 11.8% |
|
|
1,024 | 7,384 | 15.0% |
|
|
4,096 | 16,534 | 33.6% |
|
|
13,000 | 39,000 | 79.4% <-- recommended for 80% target |
|
|
|
|
|
Examples: |
|
|
python train.py --mode llm --epochs 100 --batch_size 256 |
|
|
python train.py --mode llm --epochs 100 --batch_size 4096 --unfreeze_layers 4 |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
import time |
|
|
import argparse |
|
|
import random |
|
|
|
|
|
from model import ( |
|
|
ThresholdALU, DirectCircuitModel, OpRouter, |
|
|
ArithmeticModel, OPERATIONS, OP_SYMBOLS |
|
|
) |
|
|
from circuits import FrozenThresholdCircuits |
|
|
from fitness import generate_batch, compute_fitness, compute_loss |
|
|
|
|
|
# Default device name. The functions in this file take their own ``device``
# argument, each of which defaults to the same string.
DEVICE = 'cuda'

# English names of 0-19, indexed by value.
ONES = [
    'zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine',
    'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen',
    'seventeen', 'eighteen', 'nineteen',
]

# English names of the multiples of ten, indexed by the tens digit. Slots 0 and
# 1 are empty because every value below 20 is looked up in ONES instead.
TENS = ['', '', 'twenty', 'thirty', 'forty', 'fifty', 'sixty', 'seventy', 'eighty', 'ninety']


def int_to_words(n: int) -> str:
    """Spell an integer in the range 0-255 as space-separated English words.

    No "and" and no hyphen are inserted, e.g. ``121`` becomes
    ``"one hundred twenty one"``. Values outside 0-255 are returned as their
    decimal string instead of being spelled out.
    """
    if not 0 <= n <= 255:
        return str(n)

    hundreds, remainder = divmod(n, 100)
    words = []

    if hundreds:
        words.append(f"{ONES[hundreds]} hundred")

    if remainder < 20:
        # A zero remainder is only spoken when it is the whole number ("zero").
        if remainder or not hundreds:
            words.append(ONES[remainder])
    else:
        tens, ones = divmod(remainder, 10)
        words.append(TENS[tens])
        if ones:
            words.append(ONES[ones])

    return " ".join(words)
|
|
|
|
|
|
|
|
def int_to_bits(val: int, device: str = 'cuda') -> torch.Tensor:
    """Encode the low 8 bits of ``val`` as a length-8 tensor, most significant bit first.

    Args:
        val: Integer to encode. Only bits 0-7 are used, so values outside
            0-255 are effectively taken modulo 256.
        device: Torch device the tensor is created on.

    Returns:
        A tensor of shape ``(8,)`` in torch's default dtype holding 0.0 / 1.0,
        where index 0 is bit 7 and index 7 is bit 0.
    """
    # Build the bit pattern on the host and create the tensor in one go: the
    # training loop calls this several times per sample, and writing the eight
    # elements one by one into a device tensor costs eight tiny device
    # operations per call.
    bit_values = [(val >> shift) & 1 for shift in range(7, -1, -1)]
    return torch.tensor(bit_values, dtype=torch.get_default_dtype(), device=device)
|
|
|
|
|
|
|
|
def bits_to_int(bits: torch.Tensor) -> int:
    """Decode the first 8 entries of ``bits`` (most significant bit first) into an int.

    An entry counts as a set bit when it is strictly greater than 0.5, so the
    function accepts soft values as well as exact 0/1 values.

    Args:
        bits: 1-D tensor with at least 8 elements; extra elements are ignored.

    Returns:
        The decoded value in the range 0-255.

    Raises:
        IndexError: If ``bits`` holds fewer than 8 elements.
    """
    # One bulk transfer instead of eight ``.item()`` calls, each of which
    # forces its own device-to-host synchronisation for a CUDA tensor.
    values = bits[:8].detach().tolist()
    if len(values) < 8:
        raise IndexError(f"bits_to_int expects at least 8 bits, got {len(values)}")

    decoded = 0
    for bit in values:
        decoded = (decoded << 1) | int(bit > 0.5)
    return decoded
|
|
|
|
|
|
|
|
# Natural-language prompt templates, keyed by operation name. Each template is
# filled with ``str.format(a=..., b=...)``; the operands may be passed either
# as integers or as spelled-out words. The order inside each list is
# significant for reproducibility because templates are drawn with
# ``random.choice`` under a fixed seed.
NL_TEMPLATES = {
    # Arithmetic operations.
    'add': [
        "What is {a} plus {b}?",
        "Calculate {a} + {b}",
        "Add {a} and {b}",
        "What's the sum of {a} and {b}?",
        "If I have {a} and get {b} more, how many total?",
        "{a} + {b} = ?",
        "Compute {a} plus {b}",
    ],
    'sub': [
        "What is {a} minus {b}?",
        "Calculate {a} - {b}",
        "Subtract {b} from {a}",
        "What's {a} take away {b}?",
        "If I have {a} and lose {b}, how many left?",
        "{a} - {b} = ?",
        "Compute {a} minus {b}",
    ],
    'mul': [
        "What is {a} times {b}?",
        "Calculate {a} * {b}",
        "Multiply {a} by {b}",
        "What's {a} multiplied by {b}?",
        "{a} * {b} = ?",
        "Compute {a} times {b}",
        "What is the product of {a} and {b}?",
    ],
    # Comparison operations.
    'gt': [
        "Is {a} greater than {b}?",
        "Is {a} > {b}?",
        "Check if {a} is larger than {b}",
        "Compare: is {a} more than {b}?",
        "{a} > {b}?",
    ],
    'lt': [
        "Is {a} less than {b}?",
        "Is {a} < {b}?",
        "Check if {a} is smaller than {b}",
        "Compare: is {a} fewer than {b}?",
        "{a} < {b}?",
    ],
    'eq': [
        "Is {a} equal to {b}?",
        "Is {a} == {b}?",
        "Does {a} equal {b}?",
        "Check if {a} equals {b}",
        "Are {a} and {b} the same?",
    ],
}
|
|
|
|
|
|
|
|
def generate_problem(max_val: int = 255):
    """Draw one random arithmetic or comparison problem for LLM training.

    Both operands are drawn uniformly from ``0..max_val`` and the operation
    from ``OPERATIONS``. The prompt is rendered in one of four equally likely
    formats: bare digits with the operator symbol, bare words, or a
    natural-language template filled with digits or with words.

    Returns:
        ``(text, a, b, op, result)``. For ``add`` / ``sub`` / ``mul`` the
        result is wrapped to 8 bits; for ``gt`` / ``lt`` / ``eq`` it is 1 when
        the comparison holds and 0 otherwise.
    """
    # The order of the random draws is part of the behaviour: with a fixed
    # seed it determines the exact sequence of generated problems.
    a = random.randint(0, max_val)
    b = random.randint(0, max_val)
    op = random.choice(OPERATIONS)
    fmt = random.choice(['digits', 'words', 'nl_digits', 'nl_words'])

    if fmt == 'digits':
        text = f"{a} {OP_SYMBOLS[op]} {b}"
    elif fmt == 'words':
        spoken_op = {
            'add': 'plus', 'sub': 'minus', 'mul': 'times',
            'gt': 'greater than', 'lt': 'less than', 'eq': 'equals',
        }[op]
        text = f"{int_to_words(a)} {spoken_op} {int_to_words(b)}"
    else:
        # Both natural-language formats draw a template first and differ only
        # in how the operands are rendered.
        template = random.choice(NL_TEMPLATES[op])
        if fmt == 'nl_digits':
            text = template.format(a=a, b=b)
        else:
            text = template.format(a=int_to_words(a), b=int_to_words(b))

    if op in ('add', 'sub', 'mul'):
        raw = a + b if op == 'add' else a - b if op == 'sub' else a * b
        result = raw & 0xFF  # wrap to 8 bits
    elif op in ('gt', 'lt', 'eq'):
        holds = a > b if op == 'gt' else a < b if op == 'lt' else a == b
        result = 1 if holds else 0

    return text, a, b, op, result
|
|
|
|
|
|
|
|
def get_curriculum_max(epoch: int, total_epochs: int) -> int:
    """Return the largest operand allowed at ``epoch`` under the curriculum.

    Training starts on small numbers and widens the range with progress
    (``epoch / total_epochs``):

        progress < 20%   -> 9    (single digit)
        progress < 50%   -> 99   (two digits)
        otherwise        -> 255  (full 8-bit range)
    """
    progress = epoch / total_epochs

    # (exclusive upper bound on progress, operand cap) for the early stages.
    stages = ((0.2, 9), (0.5, 99))
    for threshold, cap in stages:
        if progress < threshold:
            return cap
    return 255
|
|
|
|
|
|
|
|
def train_router(epochs: int = 100, batch_size: int = 256, lr: float = 1e-2, device: str = 'cuda'):
    """Train only the OpRouter on ground-truth operand bits (sanity check).

    The router receives the concatenation of ``a_bits``, ``b_bits`` and
    ``op_onehot`` and produces operation weights that are fed, together with
    the operand bits, to the frozen threshold circuits. Only the router's
    parameters are optimised (AdamW with cosine annealing).

    Fitness is evaluated after the first epoch and then every 10 epochs;
    training stops early once an evaluation reaches 0.9999.

    Args:
        epochs: Maximum number of epochs (100 optimisation steps each).
        batch_size: Samples per generated batch.
        lr: AdamW learning rate.
        device: Torch device for the circuits, the router and the data.

    Returns:
        ``(router, final_fitness)``; the router's weights are also saved to
        ``trained/router.pt`` under the hard-coded project directory.
    """
    import os  # local import: only needed to prepare the checkpoint directory

    n_batches = 100  # optimisation steps per epoch

    print("=" * 70)
    print(" ROUTER-ONLY TRAINING (Ground Truth Bits)")
    print("=" * 70)

    circuits = FrozenThresholdCircuits(device=device)
    # NOTE(review): input_dim=16 + 6 presumably means two 8-bit operands plus
    # a 6-way one-hot operation code - confirm against generate_batch.
    router = OpRouter(input_dim=16 + 6, hidden_dim=64, n_ops=6).to(device)

    n_params = sum(p.numel() for p in router.parameters())
    print(f"\nRouter parameters: {n_params:,}")

    def model_fn(a_bits, b_bits, op_onehot):
        x = torch.cat([a_bits, b_bits, op_onehot], dim=-1)
        op_weights = router(x)
        return circuits(a_bits, b_bits, op_weights)

    initial_fitness = compute_fitness(model_fn, n_samples=1000, device=device)
    print(f"Initial fitness: {initial_fitness:.4f}")

    optimizer = optim.AdamW(router.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    print("\nTraining...")
    print("-" * 70)

    best_fitness = initial_fitness
    start_time = time.perf_counter()

    for epoch in range(epochs):
        router.train()
        epoch_loss = 0.0

        for _ in range(n_batches):
            batch = generate_batch(batch_size, device)

            optimizer.zero_grad()
            pred_bits = model_fn(batch['a_bits'], batch['b_bits'], batch['op_onehot'])
            loss = compute_loss(pred_bits, batch['expected_bits'])
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        scheduler.step()

        if (epoch + 1) % 10 == 0 or epoch == 0:
            router.eval()
            fitness, _ = compute_fitness(model_fn, n_samples=2000, device=device, return_details=True)
            elapsed = time.perf_counter() - start_time

            if fitness > best_fitness:
                best_fitness = fitness
                marker = " *"
            else:
                marker = ""

            print(f"Epoch {epoch+1:3d} | Loss: {epoch_loss/n_batches:.4f} | "
                  f"Fitness: {fitness:.4f}{marker} | Time: {elapsed:.1f}s")

            if fitness >= 0.9999:
                print("\n TARGET: 100% FITNESS ACHIEVED")
                break

    print("\n" + "=" * 70)
    print(" RESULTS")
    print("=" * 70)

    router.eval()
    final_fitness, details = compute_fitness(model_fn, n_samples=5000, device=device, return_details=True)

    print(f"\nFinal fitness: {final_fitness:.4f}")
    print(f"\nPer-operation:")
    for op in OPERATIONS:
        acc = details['by_op'][op]['accuracy']
        print(f" {op}: {acc:.4f}")

    print(f"\nTotal time: {time.perf_counter() - start_time:.1f}s")

    if final_fitness >= 0.99:
        print("\nCONCLUSION: Router successfully learned operation dispatch.")
        print(" With correct bit encoding, 100% is achievable.")

    save_path = "D:/8bit-threshold-computer/llm_integration/trained/router.pt"
    # Make sure the target directory exists: torch.save does not create it,
    # and failing here would throw away the whole training run.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save({
        'router_state_dict': router.state_dict(),
        'final_fitness': final_fitness,
        'params': n_params,
    }, save_path)
    print(f"\nSaved trained router to: {save_path}")

    return router, final_fitness
|
|
|
|
|
|
|
|
def get_gpu_memory():
    """Report CUDA memory usage in MB as ``(currently allocated, peak allocated)``.

    Both figures come from torch's allocator statistics for the current
    device. ``(0, 0)`` is returned when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0, 0

    allocated_mb = torch.cuda.memory_allocated() / 1024 / 1024
    peak_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
    return allocated_mb, peak_mb
|
|
|
|
|
|
|
|
def train_interface(epochs: int = 200, batch_size: int = 512, lr: float = 1e-3,
                    eval_interval: int = 10, device: str = 'cuda'):
    """Train the ThresholdALU interface (BitEncoder + OpRouter) on ground-truth bits.

    The run has four stages, each announced on stdout:

    1. Verify that ``DirectCircuitModel`` reaches at least 0.999 fitness;
       otherwise abort.
    2. Build the ``ThresholdALU`` and measure its initial fitness.
    3. Create the AdamW optimiser and cosine-annealing scheduler.
    4. Train for up to ``epochs`` epochs of 100 generated batches each,
       stopping early once an evaluation reaches 0.9999 fitness.

    Args:
        epochs: Maximum number of epochs.
        batch_size: Samples per generated batch.
        lr: AdamW learning rate.
        eval_interval: Evaluate fitness after the first epoch and then every
            ``eval_interval`` epochs. Values below 1 are treated as 1.
        device: Torch device for the models and the data.

    Returns:
        ``(model, final_fitness)``, or ``(None, 0.0)`` when the circuit
        verification in stage 1 fails. On success the encoder and router
        weights are also saved to ``trained/interface.pt`` under the
        hard-coded project directory.
    """
    import os  # local import: only needed to prepare the checkpoint directory

    n_batches = 100  # optimisation steps per epoch
    # The evaluation cadence used to be hard-coded, silently ignoring the
    # eval_interval argument. Guard against 0 / negative values so the modulo
    # below can never divide by zero.
    eval_every = max(1, eval_interval)

    print("=" * 70)
    print(" INTERFACE TRAINING (Encoder + Router)")
    print("=" * 70)
    print(f" Started at: {time.strftime('%H:%M:%S')}")

    # ------------------------------------------------------------------ 1/4
    print("\n[1/4] Verifying frozen circuits...")
    print(" Creating DirectCircuitModel...", end=" ", flush=True)
    direct_model = DirectCircuitModel(device=device)
    mem, max_mem = get_gpu_memory()
    print(f"done. VRAM: {mem:.0f}MB")

    def direct_fn(a, b, op):
        return direct_model(a, b, op)

    print(" Running fitness check (1000 samples)...", end=" ", flush=True)
    circuit_fitness = compute_fitness(direct_fn, n_samples=1000, device=device)
    print(f"done. Fitness: {circuit_fitness:.4f}")
    if circuit_fitness < 0.999:
        # Training an interface on top of broken circuits would be meaningless.
        print(" ERROR: Circuits not achieving 100%. Aborting.")
        return None, 0.0
    print(" STATUS: PASS")

    # ------------------------------------------------------------------ 2/4
    print("\n[2/4] Initializing model...")
    print(" Creating ThresholdALU...", end=" ", flush=True)
    model = ThresholdALU(device=device)
    mem, max_mem = get_gpu_memory()
    print(f"done. VRAM: {mem:.0f}MB")

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" Trainable parameters: {trainable_params:,}")

    def model_fn(a, b, op):
        return model(a, b, op)

    print(" Running initial fitness check...", end=" ", flush=True)
    initial_fitness = compute_fitness(model_fn, n_samples=1000, device=device)
    print(f"done. Fitness: {initial_fitness:.4f}")

    # ------------------------------------------------------------------ 3/4
    print("\n[3/4] Setting up training...")
    print(" Creating optimizer...", end=" ", flush=True)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    print("done.")
    print(" Creating scheduler...", end=" ", flush=True)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    print("done.")

    print(f" Config: lr={lr}, batch_size={batch_size}, epochs={epochs}")

    # ------------------------------------------------------------------ 4/4
    print("\n[4/4] Training...")
    print(" Generating first batch to warm up...", end=" ", flush=True)
    # The batch itself is not trained on; it is kept alive so that the VRAM
    # figure printed below includes one batch worth of data.
    warmup_batch = generate_batch(batch_size, device)
    mem, max_mem = get_gpu_memory()
    print(f"done. VRAM: {mem:.0f}MB (max: {max_mem:.0f}MB)")

    print("-" * 70)

    best_fitness = initial_fitness
    start_time = time.perf_counter()

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        epoch_start = time.perf_counter()

        for batch_idx in range(n_batches):
            batch = generate_batch(batch_size, device)

            optimizer.zero_grad()
            pred_bits = model(batch['a_bits'], batch['b_bits'], batch['op_onehot'])
            loss = compute_loss(pred_bits, batch['expected_bits'])
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            if batch_idx == 0 and epoch == 0:
                mem, max_mem = get_gpu_memory()
                print(f" First forward/backward done. VRAM: {mem:.0f}MB (max: {max_mem:.0f}MB)")

            if (batch_idx + 1) % 25 == 0:
                avg_so_far = epoch_loss / (batch_idx + 1)
                print(f" Epoch {epoch+1} batch {batch_idx+1}/{n_batches} | loss: {avg_so_far:.4f}", flush=True)

        scheduler.step()

        avg_loss = epoch_loss / n_batches
        epoch_time = time.perf_counter() - epoch_start

        if (epoch + 1) % eval_every == 0 or epoch == 0:
            model.eval()
            fitness, _ = compute_fitness(
                model_fn, n_samples=2000, device=device, return_details=True
            )

            elapsed = time.perf_counter() - start_time

            if fitness > best_fitness:
                best_fitness = fitness
                marker = " *"
            else:
                marker = ""

            mem, _ = get_gpu_memory()
            print(f"Epoch {epoch+1:4d} | Loss: {avg_loss:.4f} | "
                  f"Fitness: {fitness:.4f}{marker} | "
                  f"LR: {scheduler.get_last_lr()[0]:.2e} | "
                  f"VRAM: {mem:.0f}MB | "
                  f"Time: {elapsed:.1f}s ({epoch_time:.1f}s/epoch)")

            if fitness >= 0.9999:
                print("\n" + "=" * 70)
                print(" TARGET ACHIEVED: 100% FITNESS")
                print("=" * 70)
                break

    print("\n" + "=" * 70)
    print(" TRAINING COMPLETE")
    print("=" * 70)

    model.eval()
    final_fitness, details = compute_fitness(
        model_fn, n_samples=5000, device=device, return_details=True
    )

    print(f"\nFinal fitness: {final_fitness:.4f}")
    print(f"Best fitness: {best_fitness:.4f}")
    print(f"\nPer-operation breakdown:")
    for op in OPERATIONS:
        acc = details['by_op'][op]['accuracy']
        print(f" {op:6}: {acc:.4f}")

    print(f"\nTotal time: {time.perf_counter() - start_time:.1f}s")

    save_path = "D:/8bit-threshold-computer/llm_integration/trained/interface.pt"
    # Make sure the target directory exists: torch.save does not create it,
    # and failing here would throw away the whole training run.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save({
        'encoder_state_dict': model.encoder.state_dict(),
        'router_state_dict': model.router.state_dict(),
        'final_fitness': final_fitness,
        'best_fitness': best_fitness,
    }, save_path)
    print(f"\nSaved trained model to: {save_path}")

    return model, final_fitness
|
|
|
|
|
|
|
|
def compute_llm_loss(pred_bits, a_bits, b_bits, op_logits,
                     target_result, target_a, target_b, target_op_idx,
                     bit_weight: float = 2.0):
    """Combine result, operand-bit and operation losses for LLM training.

    The total is::

        BCE-with-logits(pred_bits, target_result)
        + bit_weight * BCE(a_bits, target_a)
        + bit_weight * BCE(b_bits, target_b)
        + cross-entropy(op_logits, target_op_idx)

    ``a_bits`` and ``b_bits`` are treated as probabilities: they are clamped
    to [0, 1], NaNs are replaced by 0.5, and the result is kept a small
    epsilon away from 0 and 1 so the logarithms stay finite.

    Args:
        bit_weight: Multiplier for the two operand-bit losses (default 2x
            since extraction is the bottleneck).

    Returns:
        ``(total, parts)`` where ``total`` is the loss tensor (a non-finite
        total is replaced by 10.0, or 0.0 for negative infinity) and ``parts``
        maps ``'result'``, ``'a'``, ``'b'`` and ``'op'`` to Python floats,
        with 10.0 reported for a NaN component.
    """
    functional = nn.functional

    def _as_probability(bits):
        # Same three-step sanitisation for both operands.
        bounded = torch.clamp(bits, 0.0, 1.0)
        bounded = torch.nan_to_num(bounded, nan=0.5, posinf=1.0, neginf=0.0)
        return torch.clamp(bounded, 1e-6, 1-1e-6)

    def _reportable(component):
        # Scalar for logging; NaN is reported as a fixed penalty value.
        if torch.isnan(component):
            return 10.0
        return component.item()

    result_loss = functional.binary_cross_entropy_with_logits(pred_bits, target_result)
    a_loss = functional.binary_cross_entropy(_as_probability(a_bits), target_a)
    b_loss = functional.binary_cross_entropy(_as_probability(b_bits), target_b)
    op_loss = functional.cross_entropy(op_logits, target_op_idx)

    total = result_loss + bit_weight * a_loss + bit_weight * b_loss + op_loss
    total = torch.nan_to_num(total, nan=10.0, posinf=10.0, neginf=0.0)

    parts = {
        'result': _reportable(result_loss),
        'a': _reportable(a_loss),
        'b': _reportable(b_loss),
        'op': _reportable(op_loss),
    }
    return total, parts
|
|
|
|
|
|
|
|
def value_to_digits(value: int) -> list:
    """Split ``value`` into its ``[hundreds, tens, ones]`` decimal digits.

    Always returns exactly three digits; anything above the hundreds place is
    dropped (``1234`` gives ``[2, 3, 4]``).
    """
    return [(value // place) % 10 for place in (100, 10, 1)]
|
|
|
|
|
|
|
|
def _mean_digit_cross_entropy(logits_lists, values, device):
    """Mean cross-entropy over every digit slot in ``logits_lists``.

    ``logits_lists[i]`` holds one logits vector per extracted digit of sample
    ``i`` (most significant first) and ``values[i]`` is that sample's true
    operand. A sample with ``n`` slots is supervised with the last ``n``
    decimal digits of its value. Returns a zero tensor when no slots exist.
    """
    flat_logits = []
    flat_targets = []
    for i, logits_list in enumerate(logits_lists):
        n_slots = len(logits_list)
        if n_slots == 0:
            continue
        value = int(values[i])
        flat_logits.extend(logits_list)
        # Last n decimal digits, most significant first. Unlike slicing a
        # fixed three-digit list this also works for more than three slots
        # (the extra leading digits are simply the higher decimal places).
        flat_targets.extend((value // 10 ** k) % 10 for k in range(n_slots - 1, -1, -1))

    if not flat_logits:
        return torch.tensor(0.0, device=device)

    targets = torch.tensor(flat_targets, device=device, dtype=torch.long)
    # One batched call with the default 'mean' reduction equals summing the
    # per-digit losses and dividing by the number of digits.
    return nn.functional.cross_entropy(torch.stack(flat_logits), targets)


def compute_positional_digit_loss(pred_bits, op_logits, a_digit_logits_list, b_digit_logits_list,
                                  target_result, target_op_idx, target_a_values, target_b_values,
                                  device, digit_weight: float = 5.0):
    """Loss for positional digit extraction with DIRECT digit supervision.

    Digit classification is supervised directly instead of going through the
    value -> bits conversion, which provides strong gradients. The total is::

        BCE-with-logits(pred_bits, target_result)
        + digit_weight * mean digit cross-entropy for operand a
        + digit_weight * mean digit cross-entropy for operand b
        + cross-entropy(op_logits, target_op_idx)

    Args:
        a_digit_logits_list / b_digit_logits_list: Per sample, a sequence with
            one logits vector per extracted digit (most significant first).
            Only samples present in both lists are used.
        target_a_values / target_b_values: True operand values, one per sample.
        device: Device for the tensors created here.
        digit_weight: Multiplier for the two digit losses.

    Returns:
        ``(total, parts)`` where ``total`` is the loss tensor (a non-finite
        total is replaced by 10.0, or 0.0 for negative infinity) and ``parts``
        maps ``'result'``, ``'a_digit'``, ``'b_digit'`` and ``'op'`` to Python
        floats, with 10.0 reported for a NaN component.
    """
    result_loss = nn.functional.binary_cross_entropy_with_logits(pred_bits, target_result)
    op_loss = nn.functional.cross_entropy(op_logits, target_op_idx)

    # Keep the pairing of the two operands: a sample counts only when it
    # appears in both logits lists.
    paired = list(zip(a_digit_logits_list, b_digit_logits_list))
    # Move the targets to the host once instead of one .item() sync per sample.
    a_values = target_a_values.tolist()
    b_values = target_b_values.tolist()

    a_digit_loss = _mean_digit_cross_entropy([a for a, _ in paired], a_values, device)
    b_digit_loss = _mean_digit_cross_entropy([b for _, b in paired], b_values, device)

    total = result_loss + digit_weight * a_digit_loss + digit_weight * b_digit_loss + op_loss
    total = torch.nan_to_num(total, nan=10.0, posinf=10.0, neginf=0.0)

    return total, {
        'result': result_loss.item() if not torch.isnan(result_loss) else 10.0,
        'a_digit': a_digit_loss.item() if not torch.isnan(a_digit_loss) else 10.0,
        'b_digit': b_digit_loss.item() if not torch.isnan(b_digit_loss) else 10.0,
        'op': op_loss.item() if not torch.isnan(op_loss) else 10.0
    }
|
|
|
|
|
|
|
|
def compute_hybrid_loss(pred_bits, op_logits, used_lookup,
                        a_digit_logits, b_digit_logits,
                        target_result, target_a_values, target_b_values, target_op_idx,
                        device, digit_weight: float = 2.0):
    """Loss for hybrid extraction with digit-level prediction.

    Samples flagged in ``used_lookup`` took their operands from the digit
    lookup and get no digit loss. For the remaining ("word") samples each
    operand is supervised with the average of three cross-entropies, one per
    decimal place (hundreds, tens, ones). The total is::

        BCE-with-logits(pred_bits, target_result)
        + cross-entropy(op_logits, target_op_idx)
        + digit_weight * (a digit loss + b digit loss)

    Args:
        used_lookup: Boolean mask over the batch, True where lookup was used.
        a_digit_logits / b_digit_logits: Digit logits, reshaped here to
            ``(-1, 3, 10)``; may be ``None``, in which case both digit losses
            are zero. NOTE(review): these are presumably produced for the
            word samples only, since the targets are masked but the logits
            are not - confirm against the model.
        device: Device for the zero tensors created when there is no digit loss.
        digit_weight: Multiplier for the combined digit loss.

    Returns:
        ``(total, parts)`` where ``parts`` holds the scalar ``'result'``,
        ``'a_digit'``, ``'b_digit'`` and ``'op'`` losses (10.0 for NaN) plus
        the ``'n_words'`` and ``'n_lookup'`` sample counts.
    """
    functional = nn.functional

    def _three_place_loss(digit_logits, values):
        # Average cross-entropy over the hundreds, tens and ones places.
        hundreds = values // 100
        tens = (values % 100) // 10
        ones = values % 10
        per_place = digit_logits.view(-1, 3, 10)
        summed = (
            functional.cross_entropy(per_place[:, 0], hundreds) +
            functional.cross_entropy(per_place[:, 1], tens) +
            functional.cross_entropy(per_place[:, 2], ones)
        )
        return summed / 3.0

    def _reportable(component):
        # Scalar for logging; NaN is reported as a fixed penalty value.
        if torch.isnan(component):
            return 10.0
        return component.item()

    result_loss = functional.binary_cross_entropy_with_logits(pred_bits, target_result)
    op_loss = functional.cross_entropy(op_logits, target_op_idx)

    word_mask = ~used_lookup
    n_words = word_mask.sum().item()
    have_digit_logits = a_digit_logits is not None and b_digit_logits is not None

    if n_words > 0 and have_digit_logits:
        word_a = target_a_values[word_mask].long()
        word_b = target_b_values[word_mask].long()
        a_digit_loss = _three_place_loss(a_digit_logits, word_a)
        b_digit_loss = _three_place_loss(b_digit_logits, word_b)
    else:
        a_digit_loss = torch.tensor(0.0, device=device)
        b_digit_loss = torch.tensor(0.0, device=device)

    total = result_loss + op_loss + digit_weight * (a_digit_loss + b_digit_loss)
    total = torch.nan_to_num(total, nan=10.0, posinf=10.0, neginf=0.0)

    parts = {
        'result': _reportable(result_loss),
        'a_digit': _reportable(a_digit_loss),
        'b_digit': _reportable(b_digit_loss),
        'op': _reportable(op_loss),
        'n_words': n_words,
        'n_lookup': used_lookup.sum().item(),
    }
    return total, parts
|
|
|
|
|
|
|
|
def evaluate_llm(model, n_samples: int = 500):
    """Evaluate LLM model on random problems (mixed digit/word format).

    Problems are generated over the full 0-255 operand range and fed to the
    model one at a time with gradients disabled. A problem counts as correct
    when the decoded result bits equal the expected result, and as
    operation-correct when the arg-max of the operation logits names the
    true operation.

    The extractor (and the LLM, if it is currently in training mode) is
    switched to eval mode for the duration of the call and put back into the
    mode it was in before, even if evaluation raises.

    Args:
        model: Model exposing ``extractor`` and returning a sequence whose
            item 0 is the result bits and item 3 the operation logits.
        n_samples: Number of problems to evaluate.

    Returns:
        ``(accuracy, op_accuracy)`` as fractions in [0, 1]; ``(0.0, 0.0)``
        when ``n_samples`` is not positive.
    """
    if n_samples <= 0:
        # Nothing to evaluate; avoid dividing by zero below.
        return 0.0, 0.0

    extractor_was_training = model.extractor.training
    # When top LLM layers are unfrozen the training loop puts the LLM in
    # train mode; evaluation should not run with e.g. dropout active.
    llm = getattr(model, 'llm', None)
    llm_was_training = bool(getattr(llm, 'training', False))

    model.extractor.eval()
    if llm_was_training:
        llm.eval()

    correct = 0
    op_correct = 0
    try:
        with torch.no_grad():
            for _ in range(n_samples):
                text, a, b, op, expected = generate_problem()

                outputs = model([text])
                result_bits = outputs[0]
                op_logits = outputs[3]

                pred_result = bits_to_int(result_bits[0])
                pred_op = OPERATIONS[op_logits[0].argmax().item()]

                if pred_result == expected:
                    correct += 1
                if pred_op == op:
                    op_correct += 1
    finally:
        # Restore the previous modes rather than unconditionally forcing
        # train mode, so callers that evaluate an eval-mode model keep it so.
        model.extractor.train(extractor_was_training)
        if llm_was_training:
            llm.train()

    return correct / n_samples, op_correct / n_samples
|
|
|
|
|
|
|
|
def _make_llm_batch(batch_size: int, max_val: int, device: str):
    """Sample ``batch_size`` problems and pack their supervision targets into tensors."""
    texts = []
    a_bits, b_bits, result_bits = [], [], []
    op_indices = []
    a_values, b_values = [], []

    for _ in range(batch_size):
        text, a, b, op, result = generate_problem(max_val)
        texts.append(text)
        a_bits.append(int_to_bits(a, device))
        b_bits.append(int_to_bits(b, device))
        op_indices.append(OPERATIONS.index(op))
        result_bits.append(int_to_bits(result, device))
        a_values.append(a)
        b_values.append(b)

    return {
        'texts': texts,
        'a_bits': torch.stack(a_bits),
        'b_bits': torch.stack(b_bits),
        'op': torch.tensor(op_indices, device=device),
        'result_bits': torch.stack(result_bits),
        'a_values': torch.tensor(a_values, device=device, dtype=torch.float32),
        'b_values': torch.tensor(b_values, device=device, dtype=torch.float32),
    }


def train_llm(epochs: int = 100, batch_size: int = 256, lr: float = 3e-4,
              unfreeze_layers: int = 0, extract_layer: int = -1,
              position_extract: bool = False, digit_pred: bool = False,
              positional_digit: bool = False, device: str = 'cuda'):
    """Train extractor with LLM hidden states.

    Each epoch runs 20 freshly generated batches whose operand range follows
    the curriculum (0-9 -> 0-99 -> 0-255), then evaluates on 300 full-range
    problems. Training stops early once that accuracy reaches 0.99. Gradients
    are clipped to norm 1.0 and the learning rate follows cosine annealing.

    Hybrid extraction is used unless one of the legacy flags is set. The loss
    depends on the mode: ``compute_hybrid_loss`` (hybrid),
    ``compute_positional_digit_loss`` (``positional_digit``) or
    ``compute_llm_loss`` (the remaining legacy modes).

    Args:
        epochs: Maximum number of epochs.
        batch_size: Problems per batch.
        lr: AdamW learning rate.
        unfreeze_layers: Number of top transformer layers to unfreeze (0 = fully frozen)
        extract_layer: Which layer to extract from (-1 = last)
        position_extract: Use position-specific extraction (legacy)
        digit_pred: Predict digits instead of bits (legacy)
        positional_digit: Use positional digit extraction (legacy, 100% on digits only)
        device: Torch device for the model and the data.

    Returns:
        ``(model, final_accuracy)``. The extractor weights (and, when layers
        were unfrozen, the weights of those layers) are also saved to
        ``trained/llm.pt`` under the hard-coded project directory. The model
        is left in eval mode.
    """
    import os  # local import: only needed to prepare the checkpoint directory

    n_batches = 20  # optimisation steps per epoch
    hybrid = not (positional_digit or position_extract or digit_pred)

    print("=" * 70)
    print(" LLM TRAINING")
    if unfreeze_layers > 0:
        print(f" {unfreeze_layers} transformer layers unfrozen")
    else:
        print(" LLM frozen")
    if extract_layer != -1:
        print(f" Extracting from layer {extract_layer}")
    if hybrid:
        print(" HYBRID extraction (digit lookup + word learning)")
    elif positional_digit:
        print(" POSITIONAL DIGIT extraction (legacy, 100% on digits only)")
    elif position_extract:
        print(" Position-specific extraction (legacy)")
    elif digit_pred:
        print(" Digit-level prediction (legacy)")
    print("=" * 70)

    print("\nInitializing model...")
    model = ArithmeticModel(
        device=device,
        unfreeze_layers=unfreeze_layers,
        extract_layer=extract_layer,
        position_extract=position_extract,
        digit_pred=digit_pred,
        positional_digit=positional_digit,
        hybrid=hybrid
    )

    optimizer = optim.AdamW(model.trainable_parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    print(f"\nTraining config:")
    print(f" Epochs: {epochs}")
    print(f" Batch size: {batch_size}")
    print(f" Learning rate: {lr}")
    print(f" Unfreeze layers: {unfreeze_layers}")
    print(f" Samples/epoch: {batch_size * n_batches}")

    print(f"\nInitial evaluation (200 samples)...")
    acc, op_acc = evaluate_llm(model, 200)
    print(f" Accuracy: {acc:.4f}, Op accuracy: {op_acc:.4f}")

    print(f"\nStarting training...")
    print("-" * 70)

    best_acc = acc
    start_time = time.perf_counter()

    # Names of the loss components reported by the loss function of each mode.
    if hybrid:
        loss_keys = ('result', 'a_digit', 'b_digit', 'op', 'n_words', 'n_lookup')
    elif positional_digit:
        loss_keys = ('result', 'a_digit', 'b_digit', 'op')
    else:
        loss_keys = ('result', 'a', 'b', 'op')

    for epoch in range(epochs):
        model.extractor.train()
        if unfreeze_layers > 0:
            model.llm.train()

        max_val = get_curriculum_max(epoch, epochs)

        epoch_loss = 0
        epoch_losses = {key: 0 for key in loss_keys}

        for batch_idx in range(n_batches):
            batch = _make_llm_batch(batch_size, max_val, device)

            optimizer.zero_grad()

            outputs = model(batch['texts'])
            pred_bits, a_bits, b_bits, op_logits = outputs[0], outputs[1], outputs[2], outputs[3]

            if hybrid:
                # outputs[4] and outputs[5] (the extracted operand values) are
                # not needed for the loss.
                used_lookup = outputs[6]
                a_digit_logits, b_digit_logits = outputs[7], outputs[8]
                loss, losses = compute_hybrid_loss(
                    pred_bits, op_logits, used_lookup,
                    a_digit_logits, b_digit_logits,
                    batch['result_bits'], batch['a_values'], batch['b_values'], batch['op'], device
                )
            elif positional_digit:
                a_digit_logits_list = outputs[7]
                b_digit_logits_list = outputs[8]
                loss, losses = compute_positional_digit_loss(
                    pred_bits, op_logits, a_digit_logits_list, b_digit_logits_list,
                    batch['result_bits'], batch['op'], batch['a_values'], batch['b_values'], device
                )
            else:
                loss, losses = compute_llm_loss(
                    pred_bits, a_bits, b_bits, op_logits,
                    batch['result_bits'], batch['a_bits'], batch['b_bits'], batch['op']
                )

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.trainable_parameters(), 1.0)
            optimizer.step()

            epoch_loss += loss.item()
            for k in epoch_losses:
                epoch_losses[k] += losses[k]

            if (batch_idx + 1) % 5 == 0:
                avg_so_far = epoch_loss / (batch_idx + 1)
                print(f" Epoch {epoch+1} batch {batch_idx+1}/{n_batches} | loss: {avg_so_far:.4f}", flush=True)

        scheduler.step()

        avg_loss = epoch_loss / n_batches
        for k in epoch_losses:
            epoch_losses[k] /= n_batches

        acc, op_acc = evaluate_llm(model, 300)
        elapsed = time.perf_counter() - start_time

        marker = " *" if acc > best_acc else ""
        if acc > best_acc:
            best_acc = acc

        mem, _ = get_gpu_memory()
        print(f"Epoch {epoch+1:3d} | Loss: {avg_loss:.4f} | "
              f"Acc: {acc:.4f}{marker} | OpAcc: {op_acc:.4f} | "
              f"Range: 0-{max_val} | VRAM: {mem:.0f}MB | Time: {elapsed:.0f}s")
        if hybrid:
            print(f" Losses - result:{epoch_losses['result']:.4f} "
                  f"a_digit:{epoch_losses['a_digit']:.4f} b_digit:{epoch_losses['b_digit']:.4f} "
                  f"op:{epoch_losses['op']:.4f} | words:{epoch_losses['n_words']:.0f} lookup:{epoch_losses['n_lookup']:.0f}")
        elif positional_digit:
            print(f" Losses - result:{epoch_losses['result']:.4f} "
                  f"a_digit:{epoch_losses['a_digit']:.4f} b_digit:{epoch_losses['b_digit']:.4f} "
                  f"op:{epoch_losses['op']:.4f}")
        else:
            print(f" Losses - result:{epoch_losses['result']:.4f} "
                  f"a:{epoch_losses['a']:.4f} b:{epoch_losses['b']:.4f} "
                  f"op:{epoch_losses['op']:.4f}")

        if acc >= 0.99:
            print("\n" + "=" * 70)
            print(" TARGET ACHIEVED: 99%+ ACCURACY")
            print("=" * 70)
            break

    print("\n" + "=" * 70)
    print(" FINAL EVALUATION")
    print("=" * 70)

    # Training is over: evaluate and demo with the LLM in eval mode.
    if unfreeze_layers > 0:
        model.llm.eval()

    acc, op_acc = evaluate_llm(model, 1000)
    print(f"Final accuracy: {acc:.4f}")
    print(f"Final op accuracy: {op_acc:.4f}")
    print(f"Best accuracy: {best_acc:.4f}")

    # Set after the evaluation call, which may put the extractor back into
    # train mode; the sample predictions below must not run in train mode.
    model.extractor.eval()

    print("\nSample predictions:")
    for _ in range(10):
        text, a, b, op, expected = generate_problem()
        with torch.no_grad():
            outputs = model([text])
        result_bits, a_bits, b_bits, op_logits = outputs[0], outputs[1], outputs[2], outputs[3]
        pred = bits_to_int(result_bits[0])
        pred_a = bits_to_int(a_bits[0])
        pred_b = bits_to_int(b_bits[0])
        pred_op = OPERATIONS[op_logits[0].argmax().item()]

        status = "OK" if pred == expected else "WRONG"
        print(f" '{text}' = {expected} | pred={pred} (a={pred_a}, b={pred_b}, op={pred_op}) [{status}]")

    save_path = "D:/8bit-threshold-computer/llm_integration/trained/llm.pt"
    save_dict = {
        'extractor_state_dict': model.extractor.state_dict(),
        'final_accuracy': acc,
        'best_accuracy': best_acc,
        'unfreeze_layers': unfreeze_layers,
    }
    if unfreeze_layers > 0:
        # Only the fine-tuned (top) layers are stored, not the whole LLM.
        n_layers = len(model.llm.model.layers)
        tuned_markers = [f'layers.{i}.' for i in range(n_layers - unfreeze_layers, n_layers)]
        save_dict['llm_state_dict'] = {
            k: v for k, v in model.llm.state_dict().items()
            if any(tag in k for tag in tuned_markers)
        }

    # Make sure the target directory exists: torch.save does not create it,
    # and failing here would throw away the whole training run.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(save_dict, save_path)
    print(f"\nSaved to: {save_path}")

    return model, acc
|
|
|
|
|
|
|
|
def main():
    """Parse the command line and run the selected training mode.

    ``--mode`` selects ``router``, ``interface`` or ``llm`` training. When
    ``--lr`` is omitted a mode-specific default is used (1e-2, 1e-3 and 3e-4
    respectively). Both torch and ``random`` are seeded with 42 before
    training. If a CUDA device is requested but CUDA is not available, the
    run falls back to the CPU with a warning instead of crashing later.
    """
    parser = argparse.ArgumentParser(
        description='Unified training for threshold circuit LLM integration',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Modes:
router - Train only OpRouter with ground truth bits (sanity check)
interface - Train BitEncoder + OpRouter with ground truth bits (sanity check)
llm - Train extractor with LLM hidden states (the real training)

LLM options:
--unfreeze_layers N Fine-tune top N transformer layers
--extract_layer N Extract from layer N (-1 = last)
--position_extract Use position-specific extraction
--digit_pred Predict digits instead of bits

Baked-in: curriculum learning (0-9 -> 0-99 -> 0-255), 2x loss weight for a/b

Examples:
python train.py --mode llm --epochs 100
python train.py --mode llm --position_extract
python train.py --mode llm --digit_pred --extract_layer 12
python train.py --mode llm --unfreeze_layers 4 --batch_size 4096
"""
    )
    parser.add_argument('--mode', type=str, required=True,
                        choices=['router', 'interface', 'llm'],
                        help='Training mode')
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=512, help='Batch size (default: 512)')
    parser.add_argument('--lr', type=float, default=None,
                        help='Learning rate (default: mode-specific)')
    parser.add_argument('--unfreeze_layers', type=int, default=0,
                        help='Unfreeze top N transformer layers (default 0 = frozen)')
    parser.add_argument('--extract_layer', type=int, default=0,
                        help='Which layer to extract from (default: 0 = embeddings, best for digits)')
    parser.add_argument('--position_extract', action='store_true',
                        help='Use position-specific extraction (legacy)')
    parser.add_argument('--digit_pred', action='store_true',
                        help='Predict digits instead of bits (legacy)')
    parser.add_argument('--positional_digit', action='store_true', default=False,
                        help='Use positional digit extraction (legacy, 100%% on digits only)')
    parser.add_argument('--device', type=str, default='cuda', help='Device')
    args = parser.parse_args()

    # --device defaults to 'cuda'; without this check a machine without CUDA
    # only fails deep inside model construction with an opaque error.
    device = args.device
    if device.startswith('cuda') and not torch.cuda.is_available():
        print(f"WARNING: device '{device}' requested but CUDA is not available; falling back to 'cpu'.")
        device = 'cpu'

    torch.manual_seed(42)
    random.seed(42)

    if args.mode == 'router':
        lr = args.lr if args.lr is not None else 1e-2
        train_router(epochs=args.epochs, batch_size=args.batch_size, lr=lr, device=device)

    elif args.mode == 'interface':
        lr = args.lr if args.lr is not None else 1e-3
        model, fitness = train_interface(
            epochs=args.epochs, batch_size=args.batch_size, lr=lr, device=device
        )

        print("\n" + "=" * 70)
        print(" EXPERIMENT SUMMARY")
        print("=" * 70)
        print(f"\n Control (Vanilla SmolLM2-360M): 11.90%")
        print(f" Experimental (Trained Interface): {100*fitness:.2f}%")
        if fitness > 0:
            # Relative improvement over the 11.90% control accuracy.
            print(f"\n Improvement: {100*(fitness - 0.119)/0.119:.1f}%")

        if fitness >= 0.99:
            print("\n CONCLUSION: Frozen threshold circuits + trained interface")
            print(" achieves near-perfect arithmetic accuracy.")
            print(" Core thesis VALIDATED.")
        else:
            print(f"\n CONCLUSION: Further training or architecture changes needed.")
            print(f" Current gap: {100*(1.0 - fitness):.2f}%")

    elif args.mode == 'llm':
        lr = args.lr if args.lr is not None else 3e-4
        train_llm(
            epochs=args.epochs,
            batch_size=args.batch_size,
            lr=lr,
            unfreeze_layers=args.unfreeze_layers,
            extract_layer=args.extract_layer,
            position_extract=args.position_extract,
            digit_pred=args.digit_pred,
            positional_digit=args.positional_digit,
            device=device
        )


if __name__ == "__main__":
    main()
|
|
|