| """ |
| Shared fitness function for threshold circuit LLM integration. |
| Randomized tests, no answer supervision - fitness IS the training signal. |
| """ |
|
|
| import torch |
| import random |
| from typing import Callable, Dict, Tuple, List |
|
|
| OPERATIONS = ['add', 'sub', 'mul', 'gt', 'lt', 'eq'] |
|
|
def ground_truth(a: int, b: int, op: str) -> int:
    """Compute expected result (8-bit arithmetic)."""
    if op == 'add':
        return (a + b) & 0xFF
    elif op == 'sub':
        return (a - b) & 0xFF
    elif op == 'mul':
        return (a * b) & 0xFF
    elif op == 'gt':
        return 1 if a > b else 0
    elif op == 'lt':
        return 1 if a < b else 0
    elif op == 'eq':
        return 1 if a == b else 0
    else:
        raise ValueError(f"Unknown op: {op}")
|
|
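# Example values (sanity checks, easy to verify by hand):
#   ground_truth(200, 100, 'add') == 44    # 300 wraps to 300 - 256
#   ground_truth(3, 5, 'sub') == 254       # -2 wraps to 254
#   ground_truth(16, 16, 'mul') == 0       # 256 wraps to 0
#   ground_truth(7, 7, 'eq') == 1          # comparisons return 0 or 1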
|
|
def int_to_bits(val: int, n_bits: int = 8) -> torch.Tensor:
    """Convert integer to bit tensor (MSB first)."""
    bits = torch.zeros(n_bits)
    for i in range(n_bits):
        bits[n_bits - 1 - i] = (val >> i) & 1
    return bits
|
|
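# Example: int_to_bits(5) -> tensor([0., 0., 0., 0., 0., 1., 0., 1.])
# (5 = 0b00000101, most significant bit first).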
|
|
def bits_to_int(bits: torch.Tensor) -> int:
    """Convert a single bit vector to an integer (MSB first)."""
    val = 0
    n_bits = bits.shape[-1]
    for i in range(n_bits):
        val += int(bits[..., i].item()) << (n_bits - 1 - i)
    return val
|
|
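# Round trip: bits_to_int(int_to_bits(x)) == x for any x in [0, 255].
# Note that bits_to_int calls .item() per position, so it expects a single
# bit vector, not a batch.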
|
|
def op_to_idx(op: str) -> int:
    """Convert operation string to index."""
    return OPERATIONS.index(op)
|
|
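# op_to_idx and idx_to_op are inverses of each other, e.g.
# op_to_idx('mul') == 2 and idx_to_op(2) == 'mul'.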
|
|
def idx_to_op(idx: int) -> str:
    """Convert index to operation string."""
    return OPERATIONS[idx]
|
|
def generate_batch(batch_size: int, device: str = 'cuda') -> Dict[str, torch.Tensor]:
    """
    Generate a batch of random arithmetic problems.

    Returns:
        Dict with:
            'a': [batch_size] int tensor of first operands
            'b': [batch_size] int tensor of second operands
            'op': [batch_size] int tensor of operation indices
            'a_bits': [batch_size, 8] bit tensor
            'b_bits': [batch_size, 8] bit tensor
            'op_onehot': [batch_size, 6] one-hot operation tensor
            'expected': [batch_size] int tensor of expected results
            'expected_bits': [batch_size, 8] bit tensor of expected results
    """
    a_vals = torch.randint(0, 256, (batch_size,), device=device)
    b_vals = torch.randint(0, 256, (batch_size,), device=device)
    op_indices = torch.randint(0, len(OPERATIONS), (batch_size,), device=device)

    # Unpack operands into MSB-first bit planes.
    a_bits = torch.zeros(batch_size, 8, device=device)
    b_bits = torch.zeros(batch_size, 8, device=device)
    for i in range(8):
        a_bits[:, 7 - i] = (a_vals >> i) & 1
        b_bits[:, 7 - i] = (b_vals >> i) & 1

    op_onehot = torch.zeros(batch_size, len(OPERATIONS), device=device)
    op_onehot.scatter_(1, op_indices.unsqueeze(1), 1.0)

    # Ground-truth results, computed per element on the host.
    expected = torch.zeros(batch_size, dtype=torch.long, device=device)
    for i in range(batch_size):
        a, b, op_idx = a_vals[i].item(), b_vals[i].item(), op_indices[i].item()
        expected[i] = ground_truth(a, b, idx_to_op(op_idx))

    expected_bits = torch.zeros(batch_size, 8, device=device)
    for i in range(8):
        expected_bits[:, 7 - i] = (expected >> i) & 1

    return {
        'a': a_vals,
        'b': b_vals,
        'op': op_indices,
        'a_bits': a_bits.float(),
        'b_bits': b_bits.float(),
        'op_onehot': op_onehot.float(),
        'expected': expected,
        'expected_bits': expected_bits.float(),
    }
|
|
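# Example usage (illustrative):
#   batch = generate_batch(256, device='cpu')
#   batch['a_bits'].shape    -> torch.Size([256, 8])
#   batch['op_onehot'].shape -> torch.Size([256, 6])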
|
|
def compute_fitness(
    model_fn: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
    n_samples: int = 10000,
    batch_size: int = 256,
    device: str = 'cuda',
    return_details: bool = False
) -> float | Tuple[float, Dict]:
    """
    Compute fitness score for a model.

    Args:
        model_fn: Function that takes (a_bits, b_bits, op_onehot) and returns result_bits
        n_samples: Number of test cases
        batch_size: Batch size for evaluation
        device: Device to run on
        return_details: If True, also return a per-operation breakdown

    Returns:
        Fitness score in [0, 1], optionally with a details dict
    """
    correct = 0
    total = 0
    op_correct = {op: 0 for op in OPERATIONS}
    op_total = {op: 0 for op in OPERATIONS}

    for _ in range(0, n_samples, batch_size):
        actual_batch = min(batch_size, n_samples - total)
        batch = generate_batch(actual_batch, device)

        with torch.no_grad():
            pred_bits = model_fn(batch['a_bits'], batch['b_bits'], batch['op_onehot'])

        # Threshold soft outputs to hard bits before decoding.
        pred_bits_binary = (pred_bits > 0.5).float()

        for i in range(actual_batch):
            # Decode the predicted 8-bit value (MSB first).
            pred_val = 0
            for j in range(8):
                pred_val += int(pred_bits_binary[i, j].item()) << (7 - j)

            expected_val = batch['expected'][i].item()
            op_name = idx_to_op(batch['op'][i].item())

            op_total[op_name] += 1
            total += 1

            # Exact match on all 8 bits counts as correct; no partial credit.
            if pred_val == expected_val:
                correct += 1
                op_correct[op_name] += 1

    fitness = correct / total if total > 0 else 0.0

    if return_details:
        details = {
            'correct': correct,
            'total': total,
            'by_op': {
                op: {
                    'correct': op_correct[op],
                    'total': op_total[op],
                    'accuracy': op_correct[op] / op_total[op] if op_total[op] > 0 else 0.0
                }
                for op in OPERATIONS
            }
        }
        return fitness, details

    return fitness
|
|
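# Example usage (illustrative; my_model_fn is any callable with the
# (a_bits, b_bits, op_onehot) -> pred_bits signature):
#   fitness, details = compute_fitness(my_model_fn, n_samples=1000,
#                                      device='cpu', return_details=True)
#   print(details['by_op']['add']['accuracy'])
# Chance level is ~1/256 per problem, since all 8 output bits must match.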
|
|
def compute_bit_accuracy(pred_bits: torch.Tensor, expected_bits: torch.Tensor) -> float:
    """Compute per-bit accuracy (for gradient signal analysis)."""
    pred_binary = (pred_bits > 0.5).float()
    return (pred_binary == expected_bits).float().mean().item()
|
|
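# A random predictor scores ~0.5 here (each bit is a coin flip against the
# target), so per-bit accuracy gives a much denser signal than the
# all-or-nothing fitness above.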
|
|
def compute_loss(pred_bits: torch.Tensor, expected_bits: torch.Tensor) -> torch.Tensor:
    """Binary cross-entropy loss on output bits."""
    pred_clamped = pred_bits.clamp(1e-7, 1 - 1e-7)
    return -(expected_bits * torch.log(pred_clamped) +
             (1 - expected_bits) * torch.log(1 - pred_clamped)).mean()
|
|
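# This computes the same quantity as torch.nn.functional.binary_cross_entropy
# with mean reduction; the explicit clamp keeps the logs finite when the model
# emits hard 0/1 predictions.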
|
|
if __name__ == "__main__":
    print("Testing fitness module...")

    batch = generate_batch(8, 'cpu')
    print("\nSample batch:")
    for i in range(4):
        a, b = batch['a'][i].item(), batch['b'][i].item()
        op = idx_to_op(batch['op'][i].item())
        expected = batch['expected'][i].item()
        print(f"  {a} {op} {b} = {expected}")

    # Baseline: uniform random output bits.
    def random_model(a_bits, b_bits, op_onehot):
        return torch.rand(a_bits.shape[0], 8, device=a_bits.device)

    fitness = compute_fitness(random_model, n_samples=1000, batch_size=100, device='cpu')
    print(f"\nRandom model fitness: {fitness:.4f} (expected ~0.004 for 8-bit)")

    # Oracle: decodes the inputs and emits the ground-truth bits.
    def perfect_model(a_bits, b_bits, op_onehot):
        batch_size = a_bits.shape[0]
        results = torch.zeros(batch_size, 8, device=a_bits.device)
        for i in range(batch_size):
            a = sum(int(a_bits[i, j].item()) << (7 - j) for j in range(8))
            b = sum(int(b_bits[i, j].item()) << (7 - j) for j in range(8))
            op_idx = op_onehot[i].argmax().item()
            result = ground_truth(a, b, idx_to_op(op_idx))
            for j in range(8):
                results[i, 7 - j] = (result >> j) & 1
        return results

    fitness = compute_fitness(perfect_model, n_samples=1000, batch_size=100, device='cpu')
    print(f"Perfect model fitness: {fitness:.4f} (expected 1.0)")
|
|