""" Inference code for mod7-verified threshold network. This network computes MOD-7 (Hamming weight mod 7) on 8-bit binary inputs. """ import torch import torch.nn as nn from safetensors.torch import load_file def heaviside(x): return (x >= 0).float() class Mod7Network(nn.Module): """ Verified threshold network for MOD-7 computation. Architecture: 8 -> 9 -> 6 -> 7 """ def __init__(self): super().__init__() self.layer1 = nn.Linear(8, 9) self.layer2 = nn.Linear(9, 6) self.output = nn.Linear(6, 7) def forward(self, x): x = x.float() x = heaviside(self.layer1(x)) x = heaviside(self.layer2(x)) return self.output(x) def predict(self, x): return self.forward(x).argmax(dim=-1) @classmethod def from_safetensors(cls, path): model = cls() weights = load_file(path) model.layer1.weight.data = weights['layer1.weight'] model.layer1.bias.data = weights['layer1.bias'] model.layer2.weight.data = weights['layer2.weight'] model.layer2.bias.data = weights['layer2.bias'] model.output.weight.data = weights['output.weight'] model.output.bias.data = weights['output.bias'] return model def mod7_reference(x): return (x.sum(dim=-1) % 7).long() def verify(model): inputs = torch.zeros(256, 8) for i in range(256): for j in range(8): inputs[i, j] = (i >> j) & 1 targets = mod7_reference(inputs) predictions = model.predict(inputs) correct = (predictions == targets).sum().item() print(f"Verification: {correct}/256 ({100*correct/256:.1f}%)") return correct == 256 if __name__ == '__main__': model = Mod7Network.from_safetensors('model.safetensors') verify(model)