# Source: phanerozoic's Hugging Face repository ("Upload folder using huggingface_hub", commit 2f8703b, verified).
import torch
from safetensors.torch import load_file
def load_model(path='model.safetensors'):
    """Read every tensor stored in the safetensors file at *path*.

    Args:
        path: Location of the ``.safetensors`` checkpoint. A relative path
            (including the default) is resolved against the current working
            directory.

    Returns:
        The name -> tensor mapping produced by ``safetensors.torch.load_file``.
    """
    tensors = load_file(path)
    return tensors
def _threshold_unit(w, name, inputs):
    """Evaluate the threshold neuron ``name`` on ``inputs`` and return 0 or 1.

    Computes ``x @ w[name + '.weight'].T + w[name + '.bias']`` and returns 1
    when the result is >= 0, else 0 (a Heaviside step).

    The input tensor is created with the weight's dtype and device because
    ``torch.matmul`` does not promote mismatched dtypes. Without this, a
    checkpoint stored in anything other than the default float dtype would
    raise instead of evaluating.
    """
    weight = w[name + '.weight']
    bias = w[name + '.bias']
    x = torch.tensor([float(v) for v in inputs],
                     dtype=weight.dtype, device=weight.device)
    # .item() requires the layer to produce exactly one output value.
    return int((x @ weight.T + bias >= 0).item())


def popcount3(x0, x1, x2, w):
    """3-bit population count: returns (out1, out0) where count = 2*out1 + out0.

    Args:
        x0, x1, x2: Input bits. Anything ``float()`` accepts works; the
            intended values are 0 or 1.
        w: Mapping of tensor names to tensors, e.g. the result of
            ``load_model``. It must contain ``atleast1``, ``atleast2``,
            ``atleast3``, ``xor.or``, ``xor.nand`` and ``xor.and``, each with
            a ``.weight`` and a ``.bias`` entry.

    Returns:
        Tuple ``(out1, out0)`` of ints in {0, 1}.
    """
    bits = (x0, x1, x2)
    # First layer: atK is meant to fire when at least K inputs are set, which
    # forms a thermometer code (assumes the stored weights implement this;
    # confirm against the checkpoint).
    at1 = _threshold_unit(w, 'atleast1', bits)
    at2 = _threshold_unit(w, 'atleast2', bits)
    at3 = _threshold_unit(w, 'atleast3', bits)

    # High bit: count >= 2.
    out1 = at2

    # XOR(at1, at2) built from threshold gates: AND(OR(a, b), NAND(a, b)).
    pair = (at1, at2)
    or_out = _threshold_unit(w, 'xor.or', pair)
    nand_out = _threshold_unit(w, 'xor.nand', pair)
    xor_result = _threshold_unit(w, 'xor.and', (or_out, nand_out))

    # Low bit is the parity of the count. For a thermometer code that parity
    # is at1 ^ at2 ^ at3. The last XOR is done in Python, not by a network gate.
    out0 = xor_result ^ at3
    return out1, out0
if __name__ == '__main__':
    import sys

    # Exhaustively check the network against every 3-bit input and print a
    # truth table. The process exits with status 1 if any row is wrong, so the
    # script can gate automated checks.
    w = load_model()
    print('popcount3 truth table:')
    print('x0 x1 x2 | count | out1 out0')
    print('---------+-------+----------')
    failures = 0
    for i in range(8):
        # x0 is the least-significant bit of the row index.
        x0, x1, x2 = (i >> 0) & 1, (i >> 1) & 1, (i >> 2) & 1
        out1, out0 = popcount3(x0, x1, x2, w)
        result = 2 * out1 + out0
        expected = x0 + x1 + x2
        if result == expected:
            status = 'OK'
        else:
            status = f'FAIL (got {result})'
            failures += 1
        # Field widths match the header: 9 chars, '|', 7 chars, '|', out bits.
        print(f' {x0}  {x1}  {x2} |   {expected}   |   {out1}    {out0}   {status}')
    if failures:
        print(f'{failures} of 8 rows failed')
        sys.exit(1)