CharlesCNorton
Validate proof of concept: 100% arithmetic fitness with frozen circuits
084c69c
raw
history blame
7.37 kB
"""
Shared fitness function for threshold circuit LLM integration.
Randomized tests, no answer supervision - fitness IS the training signal.
"""
import torch
import random
from typing import Callable, Dict, Tuple, List
# Supported operations. List ORDER matters: the index of each name defines
# both the operation-index tensor values and the one-hot encoding columns
# used throughout this module.
OPERATIONS = ['add', 'sub', 'mul', 'gt', 'lt', 'eq']
def ground_truth(a: int, b: int, op: str) -> int:
    """Return the expected value of ``a <op> b`` under 8-bit semantics.

    Arithmetic ops ('add', 'sub', 'mul') wrap modulo 256 via a low-byte
    mask; comparison ops ('gt', 'lt', 'eq') yield 0 or 1.

    Raises:
        ValueError: if ``op`` is not one of the recognized operations.
    """
    dispatch = {
        'add': lambda x, y: (x + y) & 0xFF,
        'sub': lambda x, y: (x - y) & 0xFF,
        'mul': lambda x, y: (x * y) & 0xFF,
        'gt': lambda x, y: int(x > y),
        'lt': lambda x, y: int(x < y),
        'eq': lambda x, y: int(x == y),
    }
    if op not in dispatch:
        raise ValueError(f"Unknown op: {op}")
    return dispatch[op](a, b)
def int_to_bits(val: int, n_bits: int = 8) -> torch.Tensor:
    """Encode ``val`` as a float tensor of ``n_bits`` bits, MSB first."""
    shifts = range(n_bits - 1, -1, -1)
    return torch.tensor([(val >> s) & 1 for s in shifts], dtype=torch.float32)
def bits_to_int(bits: torch.Tensor) -> int:
    """Decode an MSB-first bit tensor back into a Python integer."""
    n_bits = bits.shape[-1]
    total = 0
    # Column 0 is the most significant bit, so it carries the largest shift.
    for pos, shift in zip(range(n_bits), range(n_bits - 1, -1, -1)):
        total |= int(bits[..., pos].item()) << shift
    return total
def op_to_idx(op: str) -> int:
    """Return the position of operation ``op`` within OPERATIONS."""
    idx = OPERATIONS.index(op)
    return idx
def idx_to_op(idx: int) -> str:
    """Return the operation name stored at position ``idx`` of OPERATIONS."""
    name = OPERATIONS[idx]
    return name
def generate_batch(batch_size: int, device: str = 'cuda') -> Dict[str, torch.Tensor]:
    """
    Generate a batch of random arithmetic problems.

    All fields are computed with vectorized tensor ops — the original
    per-sample Python loop called `.item()` once per element, forcing a
    host/device sync per sample on GPU.

    Args:
        batch_size: Number of problems to generate.
        device: Device on which every returned tensor is allocated.

    Returns:
        Dict with:
            'a': [batch_size] int tensor of first operands (0..255)
            'b': [batch_size] int tensor of second operands (0..255)
            'op': [batch_size] int tensor of operation indices
            'a_bits': [batch_size, 8] float bit tensor (MSB first)
            'b_bits': [batch_size, 8] float bit tensor (MSB first)
            'op_onehot': [batch_size, 6] one-hot operation tensor
            'expected': [batch_size] long tensor of expected results
            'expected_bits': [batch_size, 8] float bit tensor of expected results
    """
    a_vals = torch.randint(0, 256, (batch_size,), device=device)
    b_vals = torch.randint(0, 256, (batch_size,), device=device)
    op_indices = torch.randint(0, len(OPERATIONS), (batch_size,), device=device)

    # MSB-first bit decomposition via broadcast shifts: column j holds bit (7-j).
    shifts = torch.arange(7, -1, -1, device=device)
    a_bits = (a_vals.unsqueeze(1) >> shifts) & 1
    b_bits = (b_vals.unsqueeze(1) >> shifts) & 1

    op_onehot = torch.zeros(batch_size, len(OPERATIONS), device=device)
    op_onehot.scatter_(1, op_indices.unsqueeze(1), 1.0)

    # Vectorized ground truth: row k holds OPERATIONS[k] applied elementwise.
    # Keep this stack in sync with the OPERATIONS list order. int64 subtraction
    # is two's-complement, so `& 0xFF` wraps exactly like scalar ground_truth().
    candidates = torch.stack([
        (a_vals + b_vals) & 0xFF,    # add
        (a_vals - b_vals) & 0xFF,    # sub
        (a_vals * b_vals) & 0xFF,    # mul
        (a_vals > b_vals).long(),    # gt
        (a_vals < b_vals).long(),    # lt
        (a_vals == b_vals).long(),   # eq
    ], dim=0)
    # Select, per sample, the row matching its operation index.
    expected = candidates.gather(0, op_indices.unsqueeze(0)).squeeze(0)

    expected_bits = (expected.unsqueeze(1) >> shifts) & 1
    return {
        'a': a_vals,
        'b': b_vals,
        'op': op_indices,
        'a_bits': a_bits.float(),
        'b_bits': b_bits.float(),
        'op_onehot': op_onehot.float(),
        'expected': expected,
        'expected_bits': expected_bits.float(),
    }
def compute_fitness(
    model_fn: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
    n_samples: int = 10000,
    batch_size: int = 256,
    device: str = 'cuda',
    return_details: bool = False
) -> float | Tuple[float, Dict]:
    """
    Compute fitness score for a model.

    Scoring is fully vectorized — the original implementation reconstructed
    each predicted integer bit-by-bit with `.item()` calls (8 host/device
    syncs per sample on GPU), which dominated evaluation time.

    Args:
        model_fn: Function that takes (a_bits, b_bits, op_onehot) and returns result_bits
        n_samples: Number of test cases
        batch_size: Batch size for evaluation
        device: Device to run on
        return_details: If True, return per-operation breakdown

    Returns:
        Fitness score in [0, 1], optionally with details dict
    """
    correct = 0
    total = 0
    op_correct = {op: 0 for op in OPERATIONS}
    op_total = {op: 0 for op in OPERATIONS}
    for _ in range(0, n_samples, batch_size):
        actual_batch = min(batch_size, n_samples - total)
        batch = generate_batch(actual_batch, device)
        with torch.no_grad():
            pred_bits = model_fn(batch['a_bits'], batch['b_bits'], batch['op_onehot'])
        # Threshold to hard bits, then rebuild integers: column j is bit (7-j).
        pred_bits_binary = (pred_bits > 0.5).long()
        weights = 2 ** torch.arange(7, -1, -1, device=pred_bits_binary.device)
        pred_vals = (pred_bits_binary * weights).sum(dim=1)
        correct_mask = pred_vals.to(batch['expected'].device) == batch['expected']
        correct += int(correct_mask.sum().item())
        total += actual_batch
        # Per-operation tallies via boolean masks (one sync per op, not per sample).
        for idx, op_name in enumerate(OPERATIONS):
            op_mask = batch['op'] == idx
            op_total[op_name] += int(op_mask.sum().item())
            op_correct[op_name] += int((op_mask & correct_mask).sum().item())
    fitness = correct / total if total > 0 else 0.0
    if return_details:
        details = {
            'correct': correct,
            'total': total,
            'by_op': {
                op: {
                    'correct': op_correct[op],
                    'total': op_total[op],
                    'accuracy': op_correct[op] / op_total[op] if op_total[op] > 0 else 0.0
                }
                for op in OPERATIONS
            }
        }
        return fitness, details
    return fitness
def compute_bit_accuracy(pred_bits: torch.Tensor, expected_bits: torch.Tensor) -> float:
    """Fraction of output bits that match the target after 0.5-thresholding."""
    hard_bits = (pred_bits > 0.5).float()
    matches = hard_bits.eq(expected_bits)
    return matches.float().mean().item()
def compute_loss(pred_bits: torch.Tensor, expected_bits: torch.Tensor) -> torch.Tensor:
    """Mean binary cross-entropy between predicted bit probabilities and targets.

    Predictions are clamped away from {0, 1} so the logarithms stay finite.
    """
    probs = pred_bits.clamp(1e-7, 1 - 1e-7)
    per_bit = expected_bits * probs.log() + (1 - expected_bits) * (1 - probs).log()
    return -per_bit.mean()
if __name__ == "__main__":
print("Testing fitness module...")
batch = generate_batch(8, 'cpu')
print(f"\nSample batch:")
for i in range(4):
a, b = batch['a'][i].item(), batch['b'][i].item()
op = idx_to_op(batch['op'][i].item())
expected = batch['expected'][i].item()
print(f" {a} {op} {b} = {expected}")
def random_model(a_bits, b_bits, op_onehot):
return torch.rand(a_bits.shape[0], 8, device=a_bits.device)
fitness = compute_fitness(random_model, n_samples=1000, batch_size=100, device='cpu')
print(f"\nRandom model fitness: {fitness:.4f} (expected ~0.004 for 8-bit)")
def perfect_model(a_bits, b_bits, op_onehot):
batch_size = a_bits.shape[0]
results = torch.zeros(batch_size, 8, device=a_bits.device)
for i in range(batch_size):
a = sum(int(a_bits[i, j].item()) << (7-j) for j in range(8))
b = sum(int(b_bits[i, j].item()) << (7-j) for j in range(8))
op_idx = op_onehot[i].argmax().item()
result = ground_truth(a, b, idx_to_op(op_idx))
for j in range(8):
results[i, 7-j] = (result >> j) & 1
return results
fitness = compute_fitness(perfect_model, n_samples=1000, batch_size=100, device='cpu')
print(f"Perfect model fitness: {fitness:.4f} (expected 1.0)")