# tiny-mod5-verified / model.py
# Author: phanerozoic
# Initial commit: verified MOD-5 threshold circuit
# Commit 278368f (verified)
"""
Inference code for mod5-verified threshold network.
This network computes MOD-5 (Hamming weight mod 5) on 8-bit binary inputs.
"""
import torch
import torch.nn as nn
from safetensors.torch import load_file
def heaviside(x):
    """Apply the unit step function elementwise.

    Returns a float32 tensor that is 1.0 wherever ``x >= 0`` (the step is
    closed at zero) and 0.0 everywhere else, including NaN positions,
    because NaN compares false.
    """
    at_or_above_zero = torch.ge(x, 0)
    return at_or_above_zero.to(torch.float32)
class Mod5Network(nn.Module):
    """
    Verified threshold network for MOD-5 computation.

    Architecture: 8 -> 9 -> 4 -> 5
    - Layer 1: Thermometer encoding (9 neurons detect HW >= k)
    - Layer 2: MOD-5 detection using (1,1,1,1,-4) weight pattern
    - Output: 5-class classification

    The layer roles above describe the checkpoint's weights. A freshly
    constructed instance carries the default random nn.Linear initialization
    until weights are loaded with from_safetensors / from_state_dict.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(8, 9)
        self.layer2 = nn.Linear(9, 4)
        self.output = nn.Linear(4, 5)

    def forward(self, x):
        """Forward pass with Heaviside activation.

        Args:
            x: tensor whose last dimension has size 8. It is cast to float32
                before the first linear layer.

        Returns:
            Raw output-layer scores, with a last dimension of size 5.
        """
        x = x.float()
        x = heaviside(self.layer1(x))
        x = heaviside(self.layer2(x))
        return self.output(x)

    @torch.no_grad()
    def predict(self, x):
        """Get predicted class (0, 1, 2, 3, or 4).

        Runs without autograd tracking. This is an inference helper, and the
        step activations have no useful gradient anyway.
        """
        return self.forward(x).argmax(dim=-1)

    @classmethod
    def from_state_dict(cls, weights):
        """Build a model from a mapping of parameter name -> tensor.

        Args:
            weights: mapping that must contain every key of this module's
                state_dict ('layer1.weight', 'layer1.bias', ...). Any extra
                keys are ignored.

        Returns:
            A new Mod5Network holding copies of the given tensors.

        Raises:
            KeyError: a required parameter is missing from ``weights``.
            RuntimeError: a tensor's shape does not match its layer
                (raised by ``load_state_dict``).
        """
        model = cls()
        required = list(model.state_dict())
        missing = [name for name in required if name not in weights]
        if missing:
            raise KeyError(f"checkpoint is missing parameters: {missing}")
        # Forward only this module's own keys, so extra checkpoint entries are
        # ignored. load_state_dict still validates every shape and copies the
        # values into the existing float32 parameters. That avoids the
        # dtype/shape mismatches that raw ``.data =`` assignment let through.
        model.load_state_dict({name: weights[name] for name in required})
        return model

    @classmethod
    def from_safetensors(cls, path):
        """Load model from safetensors file (see from_state_dict for checks)."""
        return cls.from_state_dict(load_file(path))
def mod5_reference(x):
    """Compute ground-truth labels: Hamming weight modulo 5.

    Sums ``x`` over its last dimension and takes the remainder mod 5 using
    tensor ``%`` semantics (``torch.remainder``). The remainder is taken
    before the cast, and the result is returned as int64.
    """
    hamming = torch.sum(x, dim=-1)
    return torch.remainder(hamming, 5).to(torch.long)
def verify(model, verbose=True):
    """Verify model on all 256 inputs.

    Enumerates every 8-bit pattern, where bit j of row i is ``(i >> j) & 1``,
    so column 0 is the least-significant bit. The model's predictions are
    compared against ``mod5_reference``.

    Args:
        model: object exposing ``predict(inputs)`` that returns class indices
            (in this file, a Mod5Network).
        verbose: when True, print the accuracy and, on failure, up to the
            first 10 mismatching input indices.

    Returns:
        True if and only if all 256 predictions match the reference.
    """
    # Vectorized bit expansion: (256, 1) >> (8,) broadcasts to (256, 8).
    # This replaces the element-by-element Python double loop.
    codes = torch.arange(256).unsqueeze(1)
    shifts = torch.arange(8)
    inputs = ((codes >> shifts) & 1).float()
    targets = mod5_reference(inputs)
    predictions = model.predict(inputs)
    correct = (predictions == targets).sum().item()
    if verbose:
        print(f"Verification: {correct}/256 ({100*correct/256:.1f}%)")
        if correct < 256:
            errors = (predictions != targets).nonzero(as_tuple=True)[0]
            print(f"Errors at indices: {errors[:10].tolist()}")
    return correct == 256
def demo():
    """Run an end-to-end MOD-5 demonstration.

    Loads 'model.safetensors' from the current working directory and runs the
    exhaustive check via verify(). It then prints, for k = 0..8, the
    prediction on the input whose first k bits are set, marking it OK or
    ERROR against k % 5.
    """
    print("Loading mod5-verified model...")
    net = Mod5Network.from_safetensors('model.safetensors')

    print("\nVerifying on all 256 inputs...")
    verify(net)

    print("\nExample predictions:")
    # One example per Hamming weight: `ones` ones followed by 8 - ones zeros.
    examples = [[1] * ones + [0] * (8 - ones) for ones in range(9)]
    for bits in examples:
        hw = sum(bits)
        expected = hw % 5
        batch = torch.tensor([bits], dtype=torch.float32)
        pred = net.predict(batch).item()
        status = "ERROR" if pred != expected else "OK"
        print(f" {bits} -> HW={hw}, pred={pred}, expected={expected} [{status}]")
if __name__ == '__main__':
    # Script entry point: load the checkpoint, verify it, and print examples.
    demo()