""" Inference code for mod5-verified threshold network. This network computes MOD-5 (Hamming weight mod 5) on 8-bit binary inputs. """ import torch import torch.nn as nn from safetensors.torch import load_file def heaviside(x): """Heaviside step function: 1 if x >= 0, else 0.""" return (x >= 0).float() class Mod5Network(nn.Module): """ Verified threshold network for MOD-5 computation. Architecture: 8 -> 9 -> 4 -> 5 - Layer 1: Thermometer encoding (9 neurons detect HW >= k) - Layer 2: MOD-5 detection using (1,1,1,1,-4) weight pattern - Output: 5-class classification """ def __init__(self): super().__init__() self.layer1 = nn.Linear(8, 9) self.layer2 = nn.Linear(9, 4) self.output = nn.Linear(4, 5) def forward(self, x): """Forward pass with Heaviside activation.""" x = x.float() x = heaviside(self.layer1(x)) x = heaviside(self.layer2(x)) x = self.output(x) return x def predict(self, x): """Get predicted class (0, 1, 2, 3, or 4).""" return self.forward(x).argmax(dim=-1) @classmethod def from_safetensors(cls, path): """Load model from safetensors file.""" model = cls() weights = load_file(path) model.layer1.weight.data = weights['layer1.weight'] model.layer1.bias.data = weights['layer1.bias'] model.layer2.weight.data = weights['layer2.weight'] model.layer2.bias.data = weights['layer2.bias'] model.output.weight.data = weights['output.weight'] model.output.bias.data = weights['output.bias'] return model def mod5_reference(x): """Reference implementation: Hamming weight mod 5.""" return (x.sum(dim=-1) % 5).long() def verify(model, verbose=True): """Verify model on all 256 inputs.""" inputs = torch.zeros(256, 8) for i in range(256): for j in range(8): inputs[i, j] = (i >> j) & 1 targets = mod5_reference(inputs) predictions = model.predict(inputs) correct = (predictions == targets).sum().item() if verbose: print(f"Verification: {correct}/256 ({100*correct/256:.1f}%)") if correct < 256: errors = (predictions != targets).nonzero(as_tuple=True)[0] print(f"Errors at indices: {errors[:10].tolist()}") return correct == 256 def demo(): """Demonstration of MOD-5 computation.""" print("Loading mod5-verified model...") model = Mod5Network.from_safetensors('model.safetensors') print("\nVerifying on all 256 inputs...") verify(model) print("\nExample predictions:") test_cases = [ [0, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 1, 1], ] for bits in test_cases: x = torch.tensor([bits], dtype=torch.float32) hw = sum(bits) pred = model.predict(x).item() expected = hw % 5 status = "OK" if pred == expected else "ERROR" print(f" {bits} -> HW={hw}, pred={pred}, expected={expected} [{status}]") if __name__ == '__main__': demo()