# Author: CharlesCNorton
# Commit: Validate proof of concept: 100% arithmetic fitness with frozen circuits
# Hash: 084c69c
"""
Frozen threshold circuit wrapper for LLM integration.
Loads safetensors and provides differentiable-compatible execution.
"""
import torch
import torch.nn as nn
from safetensors import safe_open
from typing import Dict, Tuple
MODEL_PATH = "D:/8bit-threshold-computer/neural_computer.safetensors"
def heaviside(x: torch.Tensor) -> torch.Tensor:
    """Standard Heaviside step: 1.0 where x >= 0, 0.0 elsewhere (float output)."""
    return torch.ge(x, 0).float()
class HeavisideSTE(torch.autograd.Function):
    """Binary step activation with a straight-through gradient estimator.

    Forward emits 1.0 where the input is >= 0 and 0.0 otherwise; backward
    passes the upstream gradient through unchanged, as if the step were
    the identity function.
    """

    @staticmethod
    def forward(ctx, x):
        # Hard threshold at zero; result is float32 regardless of input dtype.
        return torch.ge(x, 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: the (zero almost everywhere) true derivative is
        # replaced by the identity so training signals can reach the inputs.
        return grad_output
def heaviside_ste(x: torch.Tensor) -> torch.Tensor:
    """Threshold x at zero while letting gradients pass straight through."""
    step = HeavisideSTE.apply
    return step(x)
class FrozenThresholdCircuits(nn.Module):
    """
    Wrapper for frozen threshold logic circuits.
    All weights are frozen - no gradients flow through circuit internals.
    Gradients flow through inputs/outputs via STE.

    Conventions (per the method docstrings below):
      * Bit tensors are MSB-first: column 0 is bit 7, column 7 is bit 0.
      * Bits are represented as 0.0/1.0 float values.
    """

    def __init__(self, model_path: str = MODEL_PATH, device: str = 'cuda'):
        """Load the frozen gate weights from *model_path* onto *device*.

        Args:
            model_path: Path to the safetensors file holding the gate
                weights and biases.
            device: Torch device string all weights (and internal zero/one
                tensors) are placed on.
        """
        super().__init__()
        self.device = device
        # Plain dict rather than registered buffers: the tensors are frozen
        # and intentionally invisible to state_dict()/parameters().
        self.weights: Dict[str, torch.Tensor] = {}
        self._load_weights(model_path)

    def _load_weights(self, path: str) -> None:
        """Load weights from safetensors file."""
        with safe_open(path, framework='pt') as f:
            for name in f.keys():
                # Cast to float32 so gate arithmetic matches the float bit tensors.
                tensor = f.get_tensor(name).to(self.device).float()
                self.weights[name] = tensor

    def _gate(self, inputs: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        """Execute single threshold gate with STE.

        Computes step(w . x + b) over the last dimension of *inputs*;
        the output drops that dimension.
        """
        weight = weight.view(-1)
        bias = bias.view(-1)
        pre_activation = (inputs * weight).sum(dim=-1) + bias
        return heaviside_ste(pre_activation)

    def _xor(self, a: torch.Tensor, b: torch.Tensor, prefix: str) -> torch.Tensor:
        """XOR via OR-NAND-AND pattern (2 layers).

        Layer 1 computes OR(a, b) and NAND(a, b) in parallel; layer 2
        combines them (their AND is 1 exactly when a != b).
        """
        inputs = torch.stack([a, b], dim=-1)
        w_or = self.weights[f'{prefix}.layer1.or.weight']
        b_or = self.weights[f'{prefix}.layer1.or.bias']
        w_nand = self.weights[f'{prefix}.layer1.nand.weight']
        b_nand = self.weights[f'{prefix}.layer1.nand.bias']
        h_or = self._gate(inputs, w_or, b_or)
        h_nand = self._gate(inputs, w_nand, b_nand)
        hidden = torch.stack([h_or, h_nand], dim=-1)
        w2 = self.weights[f'{prefix}.layer2.weight']
        b2 = self.weights[f'{prefix}.layer2.bias']
        return self._gate(hidden, w2, b2)

    def _full_adder(self, a: torch.Tensor, b: torch.Tensor, cin: torch.Tensor,
                    prefix: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Full adder: sum and carry out.

        Built from two half adders (each a XOR for the sum plus a carry
        gate) and a final OR gate combining the two carries.
        """
        # Half adder 1: a + b.
        ha1_sum = self._xor(a, b, f'{prefix}.ha1.sum')
        inp_carry1 = torch.stack([a, b], dim=-1)
        w_c1 = self.weights[f'{prefix}.ha1.carry.weight']
        b_c1 = self.weights[f'{prefix}.ha1.carry.bias']
        ha1_carry = self._gate(inp_carry1, w_c1, b_c1)
        # Half adder 2: (a XOR b) + carry-in.
        ha2_sum = self._xor(ha1_sum, cin, f'{prefix}.ha2.sum')
        inp_carry2 = torch.stack([ha1_sum, cin], dim=-1)
        w_c2 = self.weights[f'{prefix}.ha2.carry.weight']
        b_c2 = self.weights[f'{prefix}.ha2.carry.bias']
        ha2_carry = self._gate(inp_carry2, w_c2, b_c2)
        # Carry out: fires if either half adder produced a carry.
        inp_cout = torch.stack([ha1_carry, ha2_carry], dim=-1)
        w_cout = self.weights[f'{prefix}.carry_or.weight']
        b_cout = self.weights[f'{prefix}.carry_or.bias']
        cout = self._gate(inp_cout, w_cout, b_cout)
        return ha2_sum, cout

    def add_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        8-bit ripple carry addition.
        Args:
            a_bits: [batch, 8] MSB-first
            b_bits: [batch, 8] MSB-first
        Returns:
            result_bits: [batch, 8] MSB-first
            carry_out: [batch] final carry
        """
        batch_size = a_bits.shape[0]
        carry = torch.zeros(batch_size, device=self.device)
        result_bits = []
        # Walk from LSB (column 7) to MSB (column 0), rippling the carry
        # through full adder fa0..fa7.
        for bit in range(8):
            bit_idx = 7 - bit
            s, carry = self._full_adder(
                a_bits[:, bit_idx],
                b_bits[:, bit_idx],
                carry,
                f'arithmetic.ripplecarry8bit.fa{bit}'
            )
            result_bits.insert(0, s)  # prepend to keep MSB-first order
        result = torch.stack(result_bits, dim=1)
        return result, carry

    def sub_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        8-bit subtraction via two's complement: A - B = A + (~B) + 1
        Args:
            a_bits: [batch, 8] MSB-first
            b_bits: [batch, 8] MSB-first
        Returns:
            result_bits: [batch, 8] MSB-first
            borrow_out: [batch] (inverted carry)
        """
        # Bitwise complement of B (bits are 0.0/1.0 floats).
        b_inv = 1.0 - b_bits
        batch_size = a_bits.shape[0]
        # Seeding the carry with 1 supplies the "+1" of two's complement.
        carry = torch.ones(batch_size, device=self.device)
        result_bits = []
        for bit in range(8):
            bit_idx = 7 - bit
            s, carry = self._full_adder(
                a_bits[:, bit_idx],
                b_inv[:, bit_idx],
                carry,
                f'arithmetic.ripplecarry8bit.fa{bit}'
            )
            result_bits.insert(0, s)
        result = torch.stack(result_bits, dim=1)
        # In two's-complement subtraction the borrow is the inverted carry out.
        borrow = 1.0 - carry
        return result, borrow

    def mul_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
        """
        8-bit multiplication via shift-add (software implementation using adder circuits).
        Only keeps low 8 bits of result (matches 8-bit wrap behavior).
        Args:
            a_bits: [batch, 8] MSB-first
            b_bits: [batch, 8] MSB-first
        Returns:
            result_bits: [batch, 8] MSB-first (low 8 bits of product)
        """
        batch_size = a_bits.shape[0]
        acc = torch.zeros(batch_size, 8, device=self.device)
        for i in range(8):
            # i-th bit of B counted from the LSB (column 7 is bit 0).
            b_bit = b_bits[:, 7 - i]
            # Partial product: A masked by that bit of B.
            pp = a_bits * b_bit.unsqueeze(1)
            shifted_pp = torch.zeros(batch_size, 8, device=self.device)
            # Shift the partial product left by i; bits above bit 7 are dropped.
            for j in range(8):
                dst_idx = j + i
                if dst_idx < 8:
                    shifted_pp[:, 7 - dst_idx] = pp[:, 7 - j]
            acc, _ = self.add_8bit(acc, shifted_pp)  # carry out (overflow) discarded
        return acc

    def compare_gt(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
        """A > B comparison (single threshold gate over the 16 concatenated bits)."""
        inputs = torch.cat([a_bits, b_bits], dim=-1)
        w = self.weights['arithmetic.greaterthan8bit.weight'].view(-1)
        b = self.weights['arithmetic.greaterthan8bit.bias'].view(-1)
        return heaviside_ste((inputs * w).sum(dim=-1) + b)

    def compare_lt(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
        """A < B comparison (single threshold gate over the 16 concatenated bits)."""
        inputs = torch.cat([a_bits, b_bits], dim=-1)
        w = self.weights['arithmetic.lessthan8bit.weight'].view(-1)
        b = self.weights['arithmetic.lessthan8bit.bias'].view(-1)
        return heaviside_ste((inputs * w).sum(dim=-1) + b)

    def compare_eq(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
        """A == B comparison (two-layer).

        Layer 1 evaluates A >= B and A <= B gates; layer 2 combines them
        (both firing means equality).
        """
        inputs = torch.cat([a_bits, b_bits], dim=-1)
        prefix = 'arithmetic.equality8bit'
        w_geq = self.weights[f'{prefix}.layer1.geq.weight'].view(-1)
        b_geq = self.weights[f'{prefix}.layer1.geq.bias'].view(-1)
        w_leq = self.weights[f'{prefix}.layer1.leq.weight'].view(-1)
        b_leq = self.weights[f'{prefix}.layer1.leq.bias'].view(-1)
        h_geq = heaviside_ste((inputs * w_geq).sum(dim=-1) + b_geq)
        h_leq = heaviside_ste((inputs * w_leq).sum(dim=-1) + b_leq)
        hidden = torch.stack([h_geq, h_leq], dim=-1)
        w2 = self.weights[f'{prefix}.layer2.weight'].view(-1)
        b2 = self.weights[f'{prefix}.layer2.bias'].view(-1)
        return heaviside_ste((hidden * w2).sum(dim=-1) + b2)

    def forward(self, a_bits: torch.Tensor, b_bits: torch.Tensor,
                op_onehot: torch.Tensor) -> torch.Tensor:
        """
        Execute operation based on one-hot selector.
        Uses soft routing during training for gradient flow.
        Args:
            a_bits: [batch, 8] operand A
            b_bits: [batch, 8] operand B
            op_onehot: [batch, 6] one-hot operation selector
                [add, sub, mul, gt, lt, eq]
        Returns:
            result_bits: [batch, 8] result (comparisons in bit 7, rest zeros)
        """
        batch_size = a_bits.shape[0]
        # All six operations are evaluated unconditionally so gradients can
        # flow to op_onehot; the selector then blends the results.
        add_result, _ = self.add_8bit(a_bits, b_bits)
        sub_result, _ = self.sub_8bit(a_bits, b_bits)
        mul_result = self.mul_8bit(a_bits, b_bits)
        gt_result = self.compare_gt(a_bits, b_bits)
        lt_result = self.compare_lt(a_bits, b_bits)
        eq_result = self.compare_eq(a_bits, b_bits)
        # Scalar comparison outputs are placed in the LSB (column 7) of an
        # otherwise-zero 8-bit vector so all branches share one shape.
        cmp_expanded = torch.zeros(batch_size, 8, device=self.device)
        gt_expanded = cmp_expanded.clone()
        gt_expanded[:, 7] = gt_result
        lt_expanded = cmp_expanded.clone()
        lt_expanded[:, 7] = lt_result
        eq_expanded = cmp_expanded.clone()
        eq_expanded[:, 7] = eq_result
        results = torch.stack([
            add_result,
            sub_result,
            mul_result,
            gt_expanded,
            lt_expanded,
            eq_expanded
        ], dim=1)  # [batch, 6, 8]
        # Weighted sum over the op axis; with a hard one-hot this selects
        # exactly one result, with soft weights it blends them.
        op_weights = op_onehot.unsqueeze(-1)
        output = (results * op_weights).sum(dim=1)
        return output
if __name__ == "__main__":
    # Smoke test: exercise each arithmetic/comparison circuit with known values.
    # Fall back to CPU so the script also runs on machines without CUDA
    # (the original hardcoded 'cuda' and crashed there).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def bits(value: int) -> torch.Tensor:
        """Encode an 8-bit unsigned integer as a [1, 8] MSB-first float tensor."""
        return torch.tensor([[(value >> (7 - i)) & 1 for i in range(8)]],
                            device=device, dtype=torch.float32)

    def to_int(result_bits: torch.Tensor) -> int:
        """Decode row 0 of a [*, 8] MSB-first bit tensor back to an integer."""
        return sum(int(result_bits[0, i].item()) << (7 - i) for i in range(8))

    print("Testing frozen circuits...")
    circuits = FrozenThresholdCircuits(device=device)
    print(f"Loaded {len(circuits.weights)} tensors")

    result, carry = circuits.add_8bit(bits(5), bits(3))
    print(f"5 + 3 = {to_int(result)} (expected 8)")

    result, _ = circuits.sub_8bit(bits(100), bits(37))
    print(f"100 - 37 = {to_int(result)} (expected 63)")

    result = circuits.mul_8bit(bits(12), bits(11))
    print(f"12 * 11 = {to_int(result)} (expected 132)")

    gt = circuits.compare_gt(bits(50), bits(30))
    lt = circuits.compare_lt(bits(50), bits(30))
    eq = circuits.compare_eq(bits(50), bits(30))
    print(f"50 > 30: {int(gt[0].item())} (expected 1)")
    print(f"50 < 30: {int(lt[0].item())} (expected 0)")
    print(f"50 == 30: {int(eq[0].item())} (expected 0)")

    print("\nTesting batched forward...")
    batch_a = torch.randint(0, 2, (16, 8), device=device, dtype=torch.float32)
    batch_b = torch.randint(0, 2, (16, 8), device=device, dtype=torch.float32)
    op = torch.zeros(16, 6, device=device)
    op[:, 0] = 1.0  # select the 'add' operation for the whole batch
    result = circuits(batch_a, batch_b, op)
    print(f"Batch output shape: {result.shape}")
    print("Done.")