"""
Pruned Threshold Network for Parity Computation

A minimal ternary threshold network (8->11->3->1) that computes 8-bit parity.
Pruned from the original 8->32->16->1 architecture with 83.3% parameter
reduction.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def _heaviside(z):
    """Return 1 where ``z >= 0`` and 0 elsewhere, in the dtype of ``z``."""
    return (z >= 0).to(z.dtype)


class PrunedThresholdNetwork(nn.Module):
    """
    Pruned binary threshold network with ternary weights.

    Architecture: 8 -> 11 -> 3 -> 1
    Weights: {-1, 0, +1}
    Activation: Heaviside (x >= 0 -> 1, else 0)

    Note: the loaders copy whatever values the checkpoint contains; they do
    not enforce that weights are actually ternary.
    """

    # (attribute name, checkpoint key) for every parameter, in layer order.
    _PARAM_KEYS = (
        ('layer1_weight', 'layer1.weight'),
        ('layer1_bias', 'layer1.bias'),
        ('layer2_weight', 'layer2.weight'),
        ('layer2_bias', 'layer2.bias'),
        ('output_weight', 'output.weight'),
        ('output_bias', 'output.bias'),
    )

    def __init__(self, n_bits=8, hidden1=11, hidden2=3):
        super().__init__()
        self.n_bits = n_bits
        self.hidden1 = hidden1
        self.hidden2 = hidden2
        self.layer1_weight = nn.Parameter(torch.zeros(hidden1, n_bits))
        self.layer1_bias = nn.Parameter(torch.zeros(hidden1))
        self.layer2_weight = nn.Parameter(torch.zeros(hidden2, hidden1))
        self.layer2_bias = nn.Parameter(torch.zeros(hidden2))
        self.output_weight = nn.Parameter(torch.zeros(1, hidden2))
        self.output_bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Forward pass with Heaviside activation.

        Args:
            x: tensor whose last dimension has size ``n_bits`` (bit values,
                presumably 0/1). Any integer or float dtype is accepted; it is
                cast to the parameter dtype.

        Returns:
            Tensor of 0/1 values with the trailing output dimension squeezed
            away, i.e. shape ``x.shape[:-1]``.
        """
        # Match the parameter dtype (float32 by default) so F.linear never
        # sees mixed dtypes, even after e.g. model.double().
        x = x.to(self.layer1_weight.dtype)
        x = _heaviside(F.linear(x, self.layer1_weight, self.layer1_bias))
        x = _heaviside(F.linear(x, self.layer2_weight, self.layer2_bias))
        x = _heaviside(F.linear(x, self.output_weight, self.output_bias))
        return x.squeeze(-1)

    @classmethod
    def from_tensors(cls, weights):
        """Build a model from a mapping of checkpoint keys to tensors.

        Layer sizes are inferred from ``layer1.weight`` (hidden1, n_bits) and
        ``layer2.weight`` (hidden2, ...). Values are copied into the model's
        own float parameters, so checkpoints stored in other dtypes (e.g.
        int8 ternary weights) still produce a usable model.

        Raises:
            KeyError: if a required key is missing.
            ValueError: if a tensor's shape does not match the architecture
                inferred from the layer-1 and layer-2 weights.
        """
        hidden1, n_bits = weights['layer1.weight'].shape
        hidden2 = weights['layer2.weight'].shape[0]
        model = cls(n_bits=n_bits, hidden1=hidden1, hidden2=hidden2)
        with torch.no_grad():
            for attr, key in cls._PARAM_KEYS:
                param = getattr(model, attr)
                tensor = weights[key]
                # Explicit check: copy_ would otherwise broadcast a
                # mis-shaped tensor (e.g. a length-1 bias) silently.
                if tensor.shape != param.shape:
                    raise ValueError(
                        f'{key}: expected shape {tuple(param.shape)}, '
                        f'got {tuple(tensor.shape)}'
                    )
                param.copy_(tensor)
        return model

    @classmethod
    def from_safetensors(cls, path):
        """Load model from SafeTensors file."""
        from safetensors.torch import load_file

        return cls.from_tensors(load_file(path))


def parity(x):
    """Ground truth parity function: 1.0 if the last dimension sums to an odd number, else 0.0."""
    return (x.sum(dim=-1) % 2).float()


if __name__ == '__main__':
    model = PrunedThresholdNetwork.from_safetensors('model.safetensors')

    # Exhaustively test every input of the model's bit width (MSB first).
    n_bits = model.n_bits
    n_inputs = 2 ** n_bits
    all_inputs = torch.tensor(
        [[int(b) for b in format(i, f'0{n_bits}b')] for i in range(n_inputs)],
        dtype=torch.float32,
    )
    with torch.no_grad():
        outputs = model(all_inputs)
    expected = parity(all_inputs)
    correct = (outputs == expected).sum().item()
    print(f'Accuracy: {correct}/{n_inputs} ({100 * correct / n_inputs:.1f}%)')