| """ | |
| Threshold Network for MOD-9 Circuit | |
| For 8-bit inputs the Hamming weight (HW, the number of set bits) ranges from 0 to 8, which is always less than 9, so HW mod 9 = HW. | |
| """ | |
| import torch | |
| from safetensors.torch import load_file | |
class ThresholdMod9:
    """Weighted-sum unit of the MOD-9 threshold network.

    Calling an instance returns ``sum(bits * weight) + bias``. For 8-bit
    inputs HW mod 9 == HW, so with suitable parameters (presumably all-ones
    weights and zero bias -- depends on the loaded file, not verified here)
    the output equals the Hamming weight of the input.
    """

    def __init__(self, weights_dict):
        """Store the unit's parameters.

        Args:
            weights_dict: Mapping providing ``'weight'`` and ``'bias'``
                tensors, e.g. the dict returned by
                ``safetensors.torch.load_file``.

        Raises:
            KeyError: If ``'weight'`` or ``'bias'`` is missing.
        """
        self.weight = weights_dict['weight']
        self.bias = weights_dict['bias']

    def __call__(self, bits):
        """Evaluate the unit on a sequence of bits.

        Args:
            bits: Iterable of values accepted by ``float()`` (normally 0/1).
                The resulting 1-D tensor is multiplied element-wise with
                ``self.weight`` using torch broadcasting rules, so its length
                must broadcast against the weight tensor.

        Returns:
            Tensor holding ``sum(bits * weight) + bias``; its shape follows
            the shape of ``self.bias`` (0-dim for a scalar bias).
        """
        inputs = torch.tensor([float(b) for b in bits])
        return (inputs * self.weight).sum() + self.bias

    @classmethod
    def from_safetensors(cls, path="model.safetensors"):
        """Build an instance from a safetensors file.

        Without ``@classmethod`` this alternate constructor bound the path
        argument to ``cls`` when called on the class and failed.

        Args:
            path: Location of the safetensors file containing ``'weight'``
                and ``'bias'`` tensors.

        Returns:
            A new instance of ``cls``.
        """
        return cls(load_file(path))
if __name__ == "__main__":
    # Demo: feed one 8-bit pattern per Hamming weight 0..8 (set bits first)
    # and show the unit's output next to the expected HW mod 9 value.
    model = ThresholdMod9(load_file("model.safetensors"))
    for weight_count in range(9):
        pattern = [1] * weight_count + [0] * (8 - weight_count)
        result = model(pattern).item()
        print(f"HW={weight_count}: out={result:.0f}, HW mod 9 = {weight_count}")