# threshold-popcount4 / create_safetensors.py
# CharlesCNorton
# 4-bit population count, magnitude 55
# f7d5919
import torch
from safetensors.torch import save_file
# Architecture (every neuron is a Heaviside unit: fires iff w·x + b >= 0):
# y2 = (sum >= 4) - 1 layer
# y1 = (sum >= 2) AND (sum <= 3) - 2 layers
# y0 = XOR4(a,b,c,d) = XOR(XOR(a,b), XOR(c,d)) - 4 layers
#
# Table of (neuron name, weight row, bias). Each entry is stored below as
# '<name>.weight' with shape [1, fan_in] and '<name>.bias' with shape [1].
_NEURON_SPECS = (
    # Layer 1 (inputs: a, b, c, d)
    ('y2', [1.0, 1.0, 1.0, 1.0], -4.0),            # sum >= 4
    ('ge2', [1.0, 1.0, 1.0, 1.0], -2.0),           # sum >= 2
    ('le3', [-1.0, -1.0, -1.0, -1.0], 3.0),        # sum <= 3
    ('xor_ab_or', [1.0, 1.0, 0.0, 0.0], -1.0),     # a OR b
    ('xor_ab_nand', [-1.0, -1.0, 0.0, 0.0], 1.0),  # a NAND b
    ('xor_cd_or', [0.0, 0.0, 1.0, 1.0], -1.0),     # c OR d
    ('xor_cd_nand', [0.0, 0.0, -1.0, -1.0], 1.0),  # c NAND d
    # Layer 2
    ('y1', [1.0, 1.0], -2.0),      # AND(ge2, le3); inputs ordered [ge2, le3]
    ('xor_ab', [1.0, 1.0], -2.0),  # AND(xor_ab_or, xor_ab_nand)
    ('xor_cd', [1.0, 1.0], -2.0),  # AND(xor_cd_or, xor_cd_nand)
    # Layer 3: OR / NAND halves of XOR(xor_ab, xor_cd)
    ('xor_final_or', [1.0, 1.0], -1.0),
    ('xor_final_nand', [-1.0, -1.0], 1.0),
    # Layer 4: y0 = AND(xor_final_or, xor_final_nand)
    ('y0', [1.0, 1.0], -2.0),
)

weights = {}
for _name, _row, _bias in _NEURON_SPECS:
    # Insert weight then bias per neuron, preserving the original key order.
    weights[f'{_name}.weight'] = torch.tensor([_row], dtype=torch.float32)
    weights[f'{_name}.bias'] = torch.tensor([_bias], dtype=torch.float32)
save_file(weights, 'model.safetensors')
# Verify
def popcount4(a, b, c, d):
inp = torch.tensor([float(a), float(b), float(c), float(d)])
# Layer 1
y2 = int((inp @ weights['y2.weight'].T + weights['y2.bias'] >= 0).item())
ge2 = int((inp @ weights['ge2.weight'].T + weights['ge2.bias'] >= 0).item())
le3 = int((inp @ weights['le3.weight'].T + weights['le3.bias'] >= 0).item())
xor_ab_or = int((inp @ weights['xor_ab_or.weight'].T + weights['xor_ab_or.bias'] >= 0).item())
xor_ab_nand = int((inp @ weights['xor_ab_nand.weight'].T + weights['xor_ab_nand.bias'] >= 0).item())
xor_cd_or = int((inp @ weights['xor_cd_or.weight'].T + weights['xor_cd_or.bias'] >= 0).item())
xor_cd_nand = int((inp @ weights['xor_cd_nand.weight'].T + weights['xor_cd_nand.bias'] >= 0).item())
# Layer 2
l2_y1_in = torch.tensor([float(ge2), float(le3)])
y1 = int((l2_y1_in @ weights['y1.weight'].T + weights['y1.bias'] >= 0).item())
l2_xor_ab_in = torch.tensor([float(xor_ab_or), float(xor_ab_nand)])
xor_ab = int((l2_xor_ab_in @ weights['xor_ab.weight'].T + weights['xor_ab.bias'] >= 0).item())
l2_xor_cd_in = torch.tensor([float(xor_cd_or), float(xor_cd_nand)])
xor_cd = int((l2_xor_cd_in @ weights['xor_cd.weight'].T + weights['xor_cd.bias'] >= 0).item())
# Layer 3
l3_in = torch.tensor([float(xor_ab), float(xor_cd)])
xor_final_or = int((l3_in @ weights['xor_final_or.weight'].T + weights['xor_final_or.bias'] >= 0).item())
xor_final_nand = int((l3_in @ weights['xor_final_nand.weight'].T + weights['xor_final_nand.bias'] >= 0).item())
# Layer 4
l4_in = torch.tensor([float(xor_final_or), float(xor_final_nand)])
y0 = int((l4_in @ weights['y0.weight'].T + weights['y0.bias'] >= 0).item())
return [y2, y1, y0]
# Exhaustively check all 16 inputs, report statistics, and exit non-zero on failure.
print("Verifying popcount4...")
errors = 0
for i in range(16):
    # Split i into four input bits, MSB first (a = bit 3, d = bit 0).
    a, b, c, d = (i >> 3) & 1, (i >> 2) & 1, (i >> 1) & 1, i & 1
    result = popcount4(a, b, c, d)
    count = a + b + c + d
    # Expected output: 3-bit binary encoding of count, MSB first.
    expected = [(count >> 2) & 1, (count >> 1) & 1, count & 1]
    if result != expected:
        errors += 1
        print(f"ERROR: {a}{b}{c}{d} count={count} -> {result}, expected {expected}")
if errors == 0:
    print("All 16 test cases passed!")
else:
    print(f"FAILED: {errors} errors")
# Magnitude = sum of absolute values of every weight and bias.
mag = sum(t.abs().sum().item() for t in weights.values())
print(f"Magnitude: {mag:.0f}")
# Each neuron owns exactly one bias element, so counting bias elements counts
# neurons without hard-coding a number that can drift from `weights`.
neurons = sum(t.numel() for key, t in weights.items() if key.endswith('.bias'))
print(f"Neurons: {neurons}")
print(f"Parameters: {sum(t.numel() for t in weights.values())}")
if errors:
    # Report failure through the process exit status (e.g. for CI); the
    # statistics above have already been printed.
    raise SystemExit(1)