File size: 1,879 Bytes
25d9c47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
"""
Inference code for mod7-verified threshold network.
This network computes MOD-7 (Hamming weight mod 7) on 8-bit binary inputs.
"""
import torch
import torch.nn as nn
from safetensors.torch import load_file
def heaviside(x):
    """
    Unit step activation: 1.0 where x >= 0, otherwise 0.0.

    The result is always a float32 tensor with the same shape as ``x``.
    """
    mask = torch.ge(x, 0)
    return mask.to(torch.float32)
class Mod7Network(nn.Module):
    """
    Verified threshold network for MOD-7 computation.

    Architecture: 8 -> 9 -> 6 -> 7. Both hidden layers apply a Heaviside
    step; the output layer returns raw scores and ``predict`` takes their
    argmax as the predicted class index (0-6).
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(8, 9)
        self.layer2 = nn.Linear(9, 6)
        self.output = nn.Linear(6, 7)

    def forward(self, x):
        """
        Return the 7 output scores for input ``x``.

        ``x`` must have a last dimension of size 8 (``layer1`` is
        ``nn.Linear(8, 9)``). It is cast to float32 first, so integer or
        boolean bit tensors are accepted.
        """
        x = x.float()
        x = heaviside(self.layer1(x))
        x = heaviside(self.layer2(x))
        return self.output(x)

    @torch.no_grad()
    def predict(self, x):
        """
        Return the argmax over the 7 output scores as int64 class indices.

        Runs without autograd tracking because only the discrete prediction
        is needed.
        """
        return self.forward(x).argmax(dim=-1)

    @classmethod
    def from_state_dict(cls, weights):
        """
        Build a model from a mapping of parameter name -> tensor.

        Only this module's own keys are read ('layer1.weight', 'layer1.bias',
        'layer2.weight', 'layer2.bias', 'output.weight', 'output.bias');
        extra keys are ignored. Raises KeyError if one of them is missing and
        RuntimeError if a tensor's shape does not match the parameter.
        Values are copied (and cast) into the model's float32 parameters.
        """
        model = cls()
        # Index each expected key explicitly so a missing one raises KeyError;
        # unrelated entries in the checkpoint are simply skipped.
        state = {name: weights[name] for name in model.state_dict()}
        # load_state_dict checks every shape (a wrong-shaped bias would
        # otherwise broadcast silently) and copies into the existing float32
        # parameters, so a float16/float64 checkpoint still works with the
        # float32 inputs produced by forward().
        model.load_state_dict(state)
        return model

    @classmethod
    def from_safetensors(cls, path):
        """
        Load a model from the ``.safetensors`` checkpoint at ``path``.

        Raises KeyError on a missing parameter and RuntimeError on a shape
        mismatch (see ``from_state_dict``).
        """
        return cls.from_state_dict(load_file(path))
def mod7_reference(x):
    """
    Ground-truth MOD-7: sum ``x`` over its last dimension, reduce modulo 7.

    Returns an int64 tensor with the last dimension removed.
    """
    hamming_weight = x.sum(dim=-1)
    return torch.remainder(hamming_weight, 7).to(torch.long)
def verify(model):
    """
    Exhaustively check ``model`` against ``mod7_reference`` on all 8-bit inputs.

    Builds all 256 binary input vectors and compares ``model.predict``
    against the reference. It prints the number of matches and returns True
    only if all 256 predictions are correct.
    """
    num_bits = 8
    num_inputs = 1 << num_bits  # 256
    # Vectorized truth table: row i holds the bits of i, least-significant
    # bit first (column j == (i >> j) & 1), as float32.
    codes = torch.arange(num_inputs).unsqueeze(1)
    shifts = torch.arange(num_bits)
    inputs = ((codes >> shifts) & 1).float()
    targets = mod7_reference(inputs)
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        predictions = model.predict(inputs)
    correct = (predictions == targets).sum().item()
    print(f"Verification: {correct}/256 ({100*correct/256:.1f}%)")
    return correct == 256
if __name__ == '__main__':
    import sys

    # An optional first CLI argument overrides the default checkpoint path.
    checkpoint = sys.argv[1] if len(sys.argv) > 1 else 'model.safetensors'
    model = Mod7Network.from_safetensors(checkpoint)
    # Propagate the result as the exit status (0 only if all 256 inputs are
    # correct) so shell scripts / CI can detect a failed verification.
    raise SystemExit(0 if verify(model) else 1)
|