|
|
"""
|
|
|
Inference code for mod7-verified threshold network.
|
|
|
|
|
|
This network computes MOD-7 (Hamming weight mod 7) on 8-bit binary inputs.
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from safetensors.torch import load_file
|
|
|
|
|
|
|
|
|
def heaviside(x):
    """Apply the unit step function elementwise.

    Returns a float32 tensor holding 1.0 wherever ``x >= 0`` (including -0.0)
    and 0.0 everywhere else.
    """
    step_mask = x.ge(0)
    return step_mask.to(torch.float32)
|
|
|
|
|
|
|
|
|
class Mod7Network(nn.Module):
    """
    Verified threshold network for MOD-7 computation.

    Architecture: 8 -> 9 -> 6 -> 7

    Both hidden layers are followed by a Heaviside step; the output layer
    returns raw scores, and ``predict`` takes their argmax as the class index.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(8, 9)
        self.layer2 = nn.Linear(9, 6)
        self.output = nn.Linear(6, 7)

    def forward(self, x):
        """Return the 7 output scores for ``x`` (last dimension of size 8).

        ``x`` is cast to float32 first so integer/bool bit tensors are accepted.
        """
        x = x.float()
        x = heaviside(self.layer1(x))
        x = heaviside(self.layer2(x))
        return self.output(x)

    # Inference only: the step activations are not differentiable and the
    # argmax result carries no gradient, so skip building an autograd graph.
    @torch.no_grad()
    def predict(self, x):
        """Return the predicted class index (argmax over the last dimension)."""
        return self.forward(x).argmax(dim=-1)

    @classmethod
    def from_state_dict(cls, weights):
        """Build a model from a mapping of parameter names to tensors.

        Args:
            weights: mapping that must contain every key of this model's
                ``state_dict()`` (``layer1.weight``, ``layer1.bias``, ...).
                Extra keys are ignored.

        Returns:
            A new ``Mod7Network`` whose parameters are float32 copies of the
            given tensors.

        Raises:
            KeyError: if any required parameter is missing.
            RuntimeError: if a tensor's shape does not match the layer
                (raised by ``load_state_dict``).
        """
        model = cls()
        expected = list(model.state_dict().keys())
        missing = [name for name in expected if name not in weights]
        if missing:
            raise KeyError(", ".join(missing))
        # load_state_dict copies into the existing float32 parameters, which
        # casts other dtypes (e.g. fp16 checkpoints) and rejects wrong shapes
        # up front instead of failing later inside a matmul.
        model.load_state_dict({name: weights[name] for name in expected}, strict=True)
        return model

    @classmethod
    def from_safetensors(cls, path):
        """Load a model from a ``.safetensors`` file at ``path``.

        See ``from_state_dict`` for the required keys and raised errors.
        """
        return cls.from_state_dict(load_file(path))
|
|
|
|
|
|
|
|
|
def mod7_reference(x):
    """Ground-truth label: the sum over the last dimension modulo 7, as int64.

    For 0/1 bit vectors the sum is the Hamming weight.
    """
    hamming_weight = x.sum(dim=-1)
    return torch.remainder(hamming_weight, 7).to(torch.long)
|
|
|
|
|
|
|
|
|
def enumerate_inputs(n_bits=8):
    """Return every ``n_bits``-bit vector as a float32 tensor of shape (2**n_bits, n_bits).

    Row ``i`` holds the binary digits of ``i``, least-significant bit in
    column 0.

    Raises:
        ValueError: if ``n_bits`` is negative.
    """
    if n_bits < 0:
        raise ValueError(f"n_bits must be non-negative, got {n_bits}")
    codes = torch.arange(2 ** n_bits).unsqueeze(1)   # (2**n_bits, 1)
    shifts = torch.arange(n_bits)                    # (n_bits,)
    # Broadcasting shift extracts bit j of every code in one vectorized op.
    return ((codes >> shifts) & 1).float()


def verify(model, n_bits=8):
    """Exhaustively check ``model.predict`` against ``mod7_reference``.

    Args:
        model: object with a ``predict(inputs)`` method returning class
            indices; its input width must equal ``n_bits``.
        n_bits: width of the input bit vectors (default 8, i.e. 256 inputs).

    Returns:
        True if every input is classified correctly. The accuracy is also
        printed.
    """
    inputs = enumerate_inputs(n_bits)
    total = inputs.shape[0]
    targets = mod7_reference(inputs)
    predictions = model.predict(inputs)
    correct = (predictions == targets).sum().item()
    print(f"Verification: {correct}/{total} ({100*correct/total:.1f}%)")
    return correct == total
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    import sys

    # Optional first CLI argument overrides the default weights path.
    weights_path = sys.argv[1] if len(sys.argv) > 1 else 'model.safetensors'
    model = Mod7Network.from_safetensors(weights_path)
    # Exit non-zero on a failed verification so scripts/CI can detect it.
    sys.exit(0 if verify(model) else 1)
|
|
|
|