threshold-exactly4outof5 / create_safetensors.py
CharlesCNorton
Exactly 4 of 5 threshold circuit, magnitude 22
035d82d
import torch
from safetensors.torch import save_file
# Parameters of a two-layer threshold network computing "exactly 4 of 5".
# Every neuron outputs 1 when (weights . inputs + bias) >= 0, otherwise 0.
weights = {
    # Hidden layer, one row per neuron:
    #   row 0: (x1 + ... + x5) - 4 >= 0  ->  at least four inputs are 1
    #   row 1: 4 - (x1 + ... + x5) >= 0  ->  at most four inputs are 1
    'layer1.weight': torch.tensor([[1.0] * 5, [-1.0] * 5], dtype=torch.float32),
    'layer1.bias': torch.tensor([-4.0, 4.0], dtype=torch.float32),
    # Output neuron: h0 + h1 - 2 >= 0, i.e. fires only when both hidden neurons fire.
    'layer2.weight': torch.tensor([[1.0, 1.0]], dtype=torch.float32),
    'layer2.bias': torch.tensor([-2.0], dtype=torch.float32),
}
# Persist the parameters to a relative path (the current working directory).
save_file(weights, 'model.safetensors')
def exactly4of5(a, b, c, d, e):
    """Run the module-level threshold circuit in ``weights`` on five inputs.

    Each argument is converted with ``float()``. Each layer applies a step
    activation: 1.0 where (W @ x + b) >= 0, else 0.0. For 0/1 inputs the
    result is 1 exactly when four of the five inputs are 1, otherwise 0.

    Returns:
        The single output neuron's value as an int.
    """
    x = torch.tensor([float(v) for v in (a, b, c, d, e)])

    w_hidden, b_hidden = weights['layer1.weight'], weights['layer1.bias']
    w_out, b_out = weights['layer2.weight'], weights['layer2.bias']

    # Step activation: a neuron fires when its pre-activation is non-negative.
    hidden = (w_hidden @ x + b_hidden >= 0).float()
    fired = (w_out @ hidden + b_out >= 0).float()
    return int(fired.item())
# Exhaustively compare the circuit with the "exactly four of five" truth table.
print("Verifying exactly4outof5...")
errors = 0
for i in range(32):
    # Decode i into five input bits, least-significant bit first.
    bits = [(i >> j) & 1 for j in range(5)]
    result = exactly4of5(*bits)
    expected = 1 if sum(bits) == 4 else 0
    if result != expected:
        errors += 1
        print(f"ERROR: {bits} -> {result}, expected {expected}")
if errors == 0:
    print("All 32 test cases passed!")
else:
    # Also summarize failures, so a broken circuit is obvious at a glance
    # instead of ending silently after the individual ERROR lines.
    print(f"{errors} of 32 test cases FAILED")
# Magnitude: sum of absolute values of every weight and bias tensor.
print(f"Magnitude: {sum(t.abs().sum().item() for t in weights.values()):.0f}")