"""
Threshold Network for Parity Computation

A ternary threshold network that computes the parity (XOR) of 8 binary inputs.
Weights are constrained to {-1, 0, +1} and activations use the Heaviside step
function.
"""

import json

import torch
import torch.nn as nn


class ThresholdNetwork(nn.Module):
    """
    Binary threshold network with ternary weights.

    Default architecture: 8 -> 32 -> 16 -> 1 (all sizes are constructor args)
    Weights: {-1, 0, +1} (expected from the exported file; not enforced here)
    Activation: Heaviside (x >= 0 -> 1, else 0)

    Parameters are created as zeros; use ``from_safetensors`` or
    ``from_weights`` to obtain a model with trained weights.
    """

    # Parameter attribute name -> tensor name used in the exported weight file.
    _WEIGHT_KEYS = (
        ('layer1_weight', 'layer1.weight'),
        ('layer1_bias', 'layer1.bias'),
        ('layer2_weight', 'layer2.weight'),
        ('layer2_bias', 'layer2.bias'),
        ('output_weight', 'output.weight'),
        ('output_bias', 'output.bias'),
    )

    def __init__(self, n_bits=8, hidden1=32, hidden2=16):
        super().__init__()
        self.n_bits = n_bits
        self.hidden1 = hidden1
        self.hidden2 = hidden2

        self.layer1_weight = nn.Parameter(torch.zeros(hidden1, n_bits))
        self.layer1_bias = nn.Parameter(torch.zeros(hidden1))
        self.layer2_weight = nn.Parameter(torch.zeros(hidden2, hidden1))
        self.layer2_bias = nn.Parameter(torch.zeros(hidden2))
        self.output_weight = nn.Parameter(torch.zeros(1, hidden2))
        self.output_bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Forward pass with Heaviside activation.

        Args:
            x: tensor whose last dimension has ``n_bits`` entries; it is cast
               to float32 before the first layer.

        Returns:
            Float tensor of 0.0/1.0 values with the trailing size-1 output
            dimension squeezed away.
        """
        x = x.float()
        # Each layer: affine map followed by a step (>= 0 -> 1.0, else 0.0).
        x = (nn.functional.linear(x, self.layer1_weight, self.layer1_bias) >= 0).float()
        x = (nn.functional.linear(x, self.layer2_weight, self.layer2_bias) >= 0).float()
        x = (nn.functional.linear(x, self.output_weight, self.output_bias) >= 0).float()
        return x.squeeze(-1)

    @classmethod
    def from_weights(cls, weights):
        """Build a model from a mapping of exported tensor names to tensors.

        Layer sizes are inferred from ``layer1.weight`` (hidden1, n_bits) and
        ``layer2.weight`` (hidden2, ...).

        Args:
            weights: mapping containing the keys ``layer1.weight``,
                ``layer1.bias``, ``layer2.weight``, ``layer2.bias``,
                ``output.weight`` and ``output.bias``.

        Returns:
            A ``ThresholdNetwork`` whose parameters hold float32 copies of
            the given tensors.

        Raises:
            KeyError: if a required tensor name is missing.
            ValueError: if a tensor's shape does not match the architecture
                implied by the layer-1 and layer-2 weight shapes.
        """
        hidden1, n_bits = weights['layer1.weight'].shape
        hidden2 = weights['layer2.weight'].shape[0]
        model = cls(n_bits=n_bits, hidden1=hidden1, hidden2=hidden2)

        for attr, key in cls._WEIGHT_KEYS:
            param = getattr(model, attr)
            tensor = weights[key]
            # Reject mismatched shapes up front; a wrong-sized bias could
            # otherwise broadcast silently inside F.linear.
            if tensor.shape != param.shape:
                raise ValueError(
                    f"{key}: expected shape {tuple(param.shape)}, "
                    f"got {tuple(tensor.shape)}"
                )
            # forward() feeds float32 activations to F.linear, which requires
            # matching dtypes; ternary weights may be stored as int8/float16.
            param.data = tensor.float()
        return model

    @classmethod
    def from_safetensors(cls, path):
        """Load model from SafeTensors file.

        See ``from_weights`` for the expected tensor names and the errors
        raised for missing or mis-shaped tensors.
        """
        # Imported lazily so the module is usable without safetensors
        # unless this loader is actually called.
        from safetensors.torch import load_file

        return cls.from_weights(load_file(path))


def parity(x):
    """Ground truth parity function.

    Returns, as floats, the sum over the last dimension of ``x`` modulo 2
    (1.0 for an odd number of ones, 0.0 for even when x is binary).
    """
    return (x.sum(dim=-1) % 2).float()


if __name__ == '__main__':
    model = ThresholdNetwork.from_safetensors('model.safetensors')

    test_inputs = torch.tensor([
        [0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1],
    ], dtype=torch.float32)

    outputs = model(test_inputs)
    expected = parity(test_inputs)

    print("Input -> Output (Expected)")
    for bits, out, exp in zip(
        test_inputs.int().tolist(),
        outputs.int().tolist(),
        expected.int().tolist(),
    ):
        print(f"{bits} -> {out} ({exp})")