|
|
"""
|
|
|
Inference code for mod5-verified threshold network.
|
|
|
|
|
|
This network computes MOD-5 (Hamming weight mod 5) on 8-bit binary inputs.
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from safetensors.torch import load_file
|
|
|
|
|
|
|
|
|
def heaviside(x):
    """Elementwise unit step of tensor ``x``, returned as float32.

    Entries with ``x >= 0`` (including -0.0) map to 1.0; all others,
    including NaN, map to 0.0.
    """
    is_nonnegative = torch.ge(x, 0)
    return is_nonnegative.to(torch.float32)
|
|
|
|
|
|
|
|
|
class Mod5Network(nn.Module):
    """
    Verified threshold network for MOD-5 computation.

    Architecture: 8 -> 9 -> 4 -> 5
    - Layer 1: Thermometer encoding (9 neurons detect HW >= k)
    - Layer 2: MOD-5 detection using (1,1,1,1,-4) weight pattern
    - Output: 5-class classification
    """

    # Checkpoint keys that are loaded; they match the names produced by
    # ``named_parameters()`` for the three Linear layers defined in __init__.
    _PARAM_NAMES = (
        'layer1.weight', 'layer1.bias',
        'layer2.weight', 'layer2.bias',
        'output.weight', 'output.bias',
    )

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(8, 9)
        self.layer2 = nn.Linear(9, 4)
        self.output = nn.Linear(4, 5)

    def forward(self, x):
        """Forward pass with Heaviside activation.

        Args:
            x: tensor whose last dimension has size 8 (presumably 0/1 bits);
                it is cast to float32 before the first layer.

        Returns:
            Raw output scores with a last dimension of size 5.
        """
        x = x.float()
        x = heaviside(self.layer1(x))
        x = heaviside(self.layer2(x))
        return self.output(x)

    def predict(self, x):
        """Get predicted class (0, 1, 2, 3, or 4)."""
        # Pure inference: no autograd graph is needed for an argmax.
        with torch.no_grad():
            return self.forward(x).argmax(dim=-1)

    @classmethod
    def from_state_dict(cls, weights):
        """Build a model from a mapping of parameter name -> tensor.

        Only the keys in ``_PARAM_NAMES`` are read; extra keys are ignored.
        Values are copied into the freshly constructed parameters, so the
        model keeps its own float32 dtype regardless of the stored dtype.

        Raises:
            KeyError: if a required parameter is missing from ``weights``.
            ValueError: if a tensor's shape does not match the architecture.
        """
        model = cls()
        params = dict(model.named_parameters())
        with torch.no_grad():
            for name in cls._PARAM_NAMES:
                source = weights[name]
                target = params[name]
                # copy_ would silently broadcast a smaller tensor, so the
                # shape is checked explicitly.
                if source.shape != target.shape:
                    raise ValueError(
                        f"{name}: expected shape {tuple(target.shape)}, "
                        f"got {tuple(source.shape)}"
                    )
                target.copy_(source)
        return model

    @classmethod
    def from_safetensors(cls, path):
        """Load model from safetensors file (see ``from_state_dict``)."""
        return cls.from_state_dict(load_file(path))
|
|
|
|
|
|
|
|
|
def mod5_reference(x):
    """Ground-truth label: sum over the last dimension of ``x``, modulo 5.

    For 0/1 inputs the sum is the Hamming weight. The result is int64.
    """
    weight = x.sum(dim=-1)
    return torch.remainder(weight, 5).to(torch.long)
|
|
|
|
|
|
|
|
|
def verify(model, verbose=True):
    """Verify model on all 256 inputs.

    Compares ``model.predict`` against ``mod5_reference`` on every 8-bit
    input, where column j of row i holds bit j of i (least-significant
    bit first).

    Args:
        model: object exposing ``predict(inputs)`` that returns class indices.
        verbose: if true, print the accuracy and up to 10 failing indices.

    Returns:
        True if all 256 predictions match the reference, else False.
    """
    # Vectorized truth table: shape (256, 8), float32, row i = bits of i.
    codes = torch.arange(256).unsqueeze(1)
    shifts = torch.arange(8)
    inputs = ((codes >> shifts) & 1).float()

    targets = mod5_reference(inputs)
    # Inference only; avoid building an autograd graph.
    with torch.no_grad():
        predictions = model.predict(inputs)

    correct = (predictions == targets).sum().item()

    if verbose:
        print(f"Verification: {correct}/256 ({100*correct/256:.1f}%)")
        if correct < 256:
            errors = (predictions != targets).nonzero(as_tuple=True)[0]
            print(f"Errors at indices: {errors[:10].tolist()}")

    return correct == 256
|
|
|
|
|
|
|
|
|
def demo(path='model.safetensors'):
    """Demonstration of MOD-5 computation.

    Loads the network, runs the exhaustive check, then prints one example
    prediction per Hamming weight 0..8.

    Args:
        path: safetensors checkpoint to load; the default is resolved
            relative to the current working directory.
    """
    print("Loading mod5-verified model...")
    model = Mod5Network.from_safetensors(path)

    print("\nVerifying on all 256 inputs...")
    verify(model)

    print("\nExample predictions:")
    # k leading ones followed by zeros, for k = 0..8.
    test_cases = [[1] * k + [0] * (8 - k) for k in range(9)]

    for bits in test_cases:
        x = torch.tensor([bits], dtype=torch.float32)
        hw = sum(bits)
        pred = model.predict(x).item()
        expected = hw % 5
        status = "OK" if pred == expected else "ERROR"
        print(f" {bits} -> HW={hw}, pred={pred}, expected={expected} [{status}]")
|
|
|
|
|
|
|
|
|
# Script entry point: run the demo only when executed directly, not on import.
if '__main__' == __name__:
    demo()
|
|
|
|