|
|
""" |
|
|
Frozen threshold circuit wrapper for LLM integration. |
|
|
Loads safetensors and provides differentiable-compatible execution. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from safetensors import safe_open |
|
|
from typing import Dict, Tuple |
|
|
|
|
|
# Default location of the pretrained circuit checkpoint (safetensors format).
# NOTE(review): hard-coded absolute Windows path — callers can override it via
# FrozenThresholdCircuits(model_path=...), but consider an env var or config.
MODEL_PATH = "D:/8bit-threshold-computer/neural_computer.safetensors"
|
|
|
|
|
|
|
|
def heaviside(x: torch.Tensor) -> torch.Tensor:
    """Standard Heaviside step function: 1.0 where x >= 0, else 0.0."""
    return x.ge(0).to(torch.float32)
|
|
|
|
|
|
|
|
class HeavisideSTE(torch.autograd.Function):
    """Binary step activation whose backward pass is the identity map
    (straight-through estimator): forward is exact thresholding, backward
    pretends the derivative is 1 everywhere so gradients can pass through."""

    @staticmethod
    def forward(ctx, x):
        # Exact (non-differentiable) step: 1.0 where x >= 0, else 0.0.
        return x.ge(0).to(torch.float32)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pass the upstream gradient along unchanged.
        return grad_output
|
|
|
|
|
|
|
|
def heaviside_ste(x: torch.Tensor) -> torch.Tensor:
    """Step function whose gradient is estimated straight-through (identity)."""
    result: torch.Tensor = HeavisideSTE.apply(x)
    return result
|
|
|
|
|
|
|
|
class FrozenThresholdCircuits(nn.Module):
    """
    Wrapper for frozen threshold logic circuits.

    All weights are frozen - no gradients flow through circuit internals.
    Gradients flow through inputs/outputs via STE.

    All bit vectors are float tensors of 0.0/1.0 values in MSB-first order.
    """

    def __init__(self, model_path: str = MODEL_PATH, device: str = 'cuda'):
        """
        Args:
            model_path: safetensors checkpoint holding the gate weights.
            device: device the gate weights are loaded onto.
        """
        super().__init__()
        self.device = device
        self.weights: Dict[str, torch.Tensor] = {}
        self._load_weights(model_path)

    def _load_weights(self, path: str) -> None:
        """Load every tensor from the safetensors file as float32 on self.device."""
        with safe_open(path, framework='pt') as f:
            for name in f.keys():
                # Tensors loaded via safetensors do not require grad, so the
                # circuit internals stay frozen by construction.
                self.weights[name] = f.get_tensor(name).to(self.device).float()

    def _gate(self, inputs: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        """Execute a single threshold gate, step(w . x + b), with STE gradient.

        Args:
            inputs: [..., n] gate inputs.
            weight: n-element weight tensor (any shape; flattened here).
            bias: bias tensor (any shape; flattened here).

        Returns:
            [...] gate output bits.
        """
        weight = weight.view(-1)
        bias = bias.view(-1)
        pre_activation = (inputs * weight).sum(dim=-1) + bias
        return heaviside_ste(pre_activation)

    def _xor(self, a: torch.Tensor, b: torch.Tensor, prefix: str) -> torch.Tensor:
        """XOR via OR-NAND-AND pattern (2 layers of threshold gates)."""
        inputs = torch.stack([a, b], dim=-1)

        h_or = self._gate(inputs,
                          self.weights[f'{prefix}.layer1.or.weight'],
                          self.weights[f'{prefix}.layer1.or.bias'])
        h_nand = self._gate(inputs,
                            self.weights[f'{prefix}.layer1.nand.weight'],
                            self.weights[f'{prefix}.layer1.nand.bias'])

        # Layer 2 combines the OR and NAND outputs: (a|b) & ~(a&b) == a^b.
        hidden = torch.stack([h_or, h_nand], dim=-1)
        return self._gate(hidden,
                          self.weights[f'{prefix}.layer2.weight'],
                          self.weights[f'{prefix}.layer2.bias'])

    def _full_adder(self, a: torch.Tensor, b: torch.Tensor, cin: torch.Tensor,
                    prefix: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Full adder built from two half adders plus a carry-OR gate.

        Args:
            a, b, cin: [batch] input bits.
            prefix: weight-key prefix for this adder's gates.

        Returns:
            (sum_bit, carry_out), each [batch].
        """
        # Half adder 1: a + b.
        ha1_sum = self._xor(a, b, f'{prefix}.ha1.sum')
        ha1_carry = self._gate(torch.stack([a, b], dim=-1),
                               self.weights[f'{prefix}.ha1.carry.weight'],
                               self.weights[f'{prefix}.ha1.carry.bias'])

        # Half adder 2: (a ^ b) + cin.
        ha2_sum = self._xor(ha1_sum, cin, f'{prefix}.ha2.sum')
        ha2_carry = self._gate(torch.stack([ha1_sum, cin], dim=-1),
                               self.weights[f'{prefix}.ha2.carry.weight'],
                               self.weights[f'{prefix}.ha2.carry.bias'])

        # Carry out if either half adder produced a carry.
        cout = self._gate(torch.stack([ha1_carry, ha2_carry], dim=-1),
                          self.weights[f'{prefix}.carry_or.weight'],
                          self.weights[f'{prefix}.carry_or.bias'])

        return ha2_sum, cout

    def _ripple_add(self, a_bits: torch.Tensor, b_bits: torch.Tensor,
                    carry: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Shared 8-bit ripple-carry loop (LSB to MSB) used by add and sub.

        Args:
            a_bits, b_bits: [batch, 8] MSB-first operands.
            carry: [batch] carry-in (0 for addition, 1 for two's-complement sub).

        Returns:
            (result_bits [batch, 8] MSB-first, carry_out [batch]).
        """
        result_bits = []
        for bit in range(8):
            bit_idx = 7 - bit  # column index: bit 0 lives in the last column
            s, carry = self._full_adder(
                a_bits[:, bit_idx],
                b_bits[:, bit_idx],
                carry,
                f'arithmetic.ripplecarry8bit.fa{bit}'
            )
            result_bits.insert(0, s)  # prepend to keep MSB-first ordering
        return torch.stack(result_bits, dim=1), carry

    def add_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        8-bit ripple carry addition.

        Args:
            a_bits: [batch, 8] MSB-first
            b_bits: [batch, 8] MSB-first

        Returns:
            result_bits: [batch, 8] MSB-first
            carry_out: [batch] final carry
        """
        # Allocate the carry on the operands' device (not self.device) so the
        # circuit works wherever the caller keeps its activations.
        carry = torch.zeros(a_bits.shape[0], device=a_bits.device)
        return self._ripple_add(a_bits, b_bits, carry)

    def sub_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        8-bit subtraction via two's complement: A - B = A + (~B) + 1

        Args:
            a_bits: [batch, 8] MSB-first
            b_bits: [batch, 8] MSB-first

        Returns:
            result_bits: [batch, 8] MSB-first
            borrow_out: [batch] (inverted carry)
        """
        b_inv = 1.0 - b_bits  # bitwise NOT on {0, 1} bits
        # Carry-in of 1 supplies the "+1" of the two's complement.
        carry = torch.ones(a_bits.shape[0], device=a_bits.device)
        result, carry = self._ripple_add(a_bits, b_inv, carry)
        return result, 1.0 - carry

    def mul_8bit(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
        """
        8-bit multiplication via shift-add (software implementation using adder circuits).
        Only keeps low 8 bits of result (matches 8-bit wrap behavior).

        Args:
            a_bits: [batch, 8] MSB-first
            b_bits: [batch, 8] MSB-first

        Returns:
            result_bits: [batch, 8] MSB-first (low 8 bits of product)
        """
        acc = torch.zeros_like(a_bits)

        for i in range(8):
            b_bit = b_bits[:, 7 - i]           # multiplier bit i (LSB-first walk)
            pp = a_bits * b_bit.unsqueeze(1)   # partial product: a_bits or zeros

            # Shift the partial product left by i bits; high bits that fall
            # outside 8 columns are dropped (8-bit wraparound).
            shifted_pp = torch.zeros_like(pp)
            shifted_pp[:, :8 - i] = pp[:, i:]

            acc, _ = self.add_8bit(acc, shifted_pp)

        return acc

    def _threshold_compare(self, a_bits: torch.Tensor, b_bits: torch.Tensor,
                           key: str) -> torch.Tensor:
        """Single-layer 16-input threshold comparison over [A | B]."""
        inputs = torch.cat([a_bits, b_bits], dim=-1)
        return self._gate(inputs,
                          self.weights[f'{key}.weight'],
                          self.weights[f'{key}.bias'])

    def compare_gt(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
        """A > B comparison. Returns [batch] result bits."""
        return self._threshold_compare(a_bits, b_bits, 'arithmetic.greaterthan8bit')

    def compare_lt(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
        """A < B comparison. Returns [batch] result bits."""
        return self._threshold_compare(a_bits, b_bits, 'arithmetic.lessthan8bit')

    def compare_eq(self, a_bits: torch.Tensor, b_bits: torch.Tensor) -> torch.Tensor:
        """A == B comparison (two-layer: geq/leq gates combined by layer 2)."""
        inputs = torch.cat([a_bits, b_bits], dim=-1)
        prefix = 'arithmetic.equality8bit'

        h_geq = self._gate(inputs,
                           self.weights[f'{prefix}.layer1.geq.weight'],
                           self.weights[f'{prefix}.layer1.geq.bias'])
        h_leq = self._gate(inputs,
                           self.weights[f'{prefix}.layer1.leq.weight'],
                           self.weights[f'{prefix}.layer1.leq.bias'])

        hidden = torch.stack([h_geq, h_leq], dim=-1)
        return self._gate(hidden,
                          self.weights[f'{prefix}.layer2.weight'],
                          self.weights[f'{prefix}.layer2.bias'])

    def forward(self, a_bits: torch.Tensor, b_bits: torch.Tensor,
                op_onehot: torch.Tensor) -> torch.Tensor:
        """
        Execute operation based on one-hot selector.
        Uses soft routing during training for gradient flow.

        Args:
            a_bits: [batch, 8] operand A
            b_bits: [batch, 8] operand B
            op_onehot: [batch, 6] one-hot operation selector
                       [add, sub, mul, gt, lt, eq]

        Returns:
            result_bits: [batch, 8] result (comparisons in bit 7, rest zeros)
        """
        batch_size = a_bits.shape[0]

        # Every candidate op is computed so that a soft selector lets
        # gradients flow into all branches during training.
        add_result, _ = self.add_8bit(a_bits, b_bits)
        sub_result, _ = self.sub_8bit(a_bits, b_bits)
        mul_result = self.mul_8bit(a_bits, b_bits)

        def expand(cmp: torch.Tensor) -> torch.Tensor:
            # Place the scalar comparison bit in the LSB column (index 7).
            out = torch.zeros(batch_size, 8, device=a_bits.device)
            out[:, 7] = cmp
            return out

        gt_expanded = expand(self.compare_gt(a_bits, b_bits))
        lt_expanded = expand(self.compare_lt(a_bits, b_bits))
        eq_expanded = expand(self.compare_eq(a_bits, b_bits))

        results = torch.stack([
            add_result,
            sub_result,
            mul_result,
            gt_expanded,
            lt_expanded,
            eq_expanded
        ], dim=1)  # [batch, 6, 8]

        # Soft routing: weighted sum over the op axis. A hard one-hot selects
        # exactly one result; a soft selector blends them differentiably.
        op_weights = op_onehot.unsqueeze(-1)  # [batch, 6, 1]
        return (results * op_weights).sum(dim=1)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke-test the frozen circuits on a few hand-picked values.

    def _bits(value: int) -> torch.Tensor:
        """Encode an integer as a [1, 8] MSB-first float bit tensor on CUDA."""
        row = [float((value >> (7 - i)) & 1) for i in range(8)]
        return torch.tensor([row], device='cuda', dtype=torch.float32)

    def _to_int(bits: torch.Tensor) -> int:
        """Decode the first row of a [batch, 8] MSB-first bit tensor to an int."""
        return sum(int(bits[0, i].item()) << (7 - i) for i in range(8))

    print("Testing frozen circuits...")

    circuits = FrozenThresholdCircuits(device='cuda')
    print(f"Loaded {len(circuits.weights)} tensors")

    # Addition: 5 + 3.
    result, _ = circuits.add_8bit(_bits(5), _bits(3))
    val = _to_int(result)
    print(f"5 + 3 = {val} (expected 8)")

    # Subtraction: 100 - 37.
    result, _ = circuits.sub_8bit(_bits(100), _bits(37))
    val = _to_int(result)
    print(f"100 - 37 = {val} (expected 63)")

    # Multiplication: 12 * 11.
    result = circuits.mul_8bit(_bits(12), _bits(11))
    val = _to_int(result)
    print(f"12 * 11 = {val} (expected 132)")

    # Comparisons on 50 vs 30.
    a, b = _bits(50), _bits(30)
    gt = circuits.compare_gt(a, b)
    lt = circuits.compare_lt(a, b)
    eq = circuits.compare_eq(a, b)
    print(f"50 > 30: {int(gt[0].item())} (expected 1)")
    print(f"50 < 30: {int(lt[0].item())} (expected 0)")
    print(f"50 == 30: {int(eq[0].item())} (expected 0)")

    print("\nTesting batched forward...")
    batch_a = torch.randint(0, 2, (16, 8), device='cuda', dtype=torch.float32)
    batch_b = torch.randint(0, 2, (16, 8), device='cuda', dtype=torch.float32)
    op = torch.zeros(16, 6, device='cuda')
    op[:, 0] = 1.0  # route everything through the adder

    result = circuits(batch_a, batch_b, op)
    print(f"Batch output shape: {result.shape}")
    print("Done.")
|
|
|