CharlesCNorton
4-bit population count, magnitude 55
f7d5919
import torch
from safetensors.torch import load_file
def load_model(path='model.safetensors'):
    """Read every tensor stored in a safetensors checkpoint.

    Args:
        path: Location of the ``.safetensors`` file. The default,
            ``'model.safetensors'``, is a relative path and so is resolved
            against the current working directory.

    Returns:
        The mapping of stored tensor names to tensors produced by
        ``safetensors.torch.load_file``.
    """
    state = load_file(path)
    return state
def _threshold_unit(inp, w, name):
    """Evaluate the stored threshold neuron ``name`` on the input vector.

    Looks up ``w[name + '.weight']`` and ``w[name + '.bias']`` and returns 1
    when ``inp @ weight.T + bias >= 0``, otherwise 0. The comparison result
    is read with ``.item()``, so the neuron must produce exactly one output.
    """
    weight = w[f'{name}.weight']
    bias = w[f'{name}.bias']
    # torch.matmul does not promote mismatched dtypes: a float32 input
    # against weights stored as e.g. float64 or float16 raises a
    # RuntimeError. Casting the input to the weight's dtype avoids that and
    # is a no-op for float32 weights; 0/1 inputs are exactly representable.
    x = inp.to(weight.dtype)
    return int((x @ weight.T + bias >= 0).item())


def popcount4(a, b, c, d, w):
    """Count number of 1 bits in 4-bit input. Returns [y2, y1, y0] = binary count.

    Layer 1 is evaluated from the tensors in ``w``. Layers 2-4 are fixed
    threshold gates of the form ``int(weighted_sum + bias >= 0)``.

    Args:
        a, b, c, d: Input bits, expected to be 0 or 1 (each is passed
            through ``float``).
        w: Mapping from parameter name to tensor, e.g. the dict returned by
            ``load_model``. Must contain ``<name>.weight`` and
            ``<name>.bias`` for y2, ge2, le3, xor_ab_or, xor_ab_nand,
            xor_cd_or and xor_cd_nand.

    Returns:
        ``[y2, y1, y0]`` as ints, most significant bit first, so that the
        count is ``4*y2 + 2*y1 + y0`` provided the loaded weights implement
        the intended layer-1 gates.
    """
    inp = torch.tensor([float(a), float(b), float(c), float(d)])

    # Layer 1: threshold neurons defined by the loaded weights. Judging by
    # their names, y2 presumably fires only when all four bits are set,
    # ge2 / le3 on count >= 2 / count <= 3, and the *_or / *_nand pairs are
    # the two halves of XOR(a, b) and XOR(c, d). The actual functions are
    # whatever the weights encode; verify against the checkpoint.
    y2 = _threshold_unit(inp, w, 'y2')
    ge2 = _threshold_unit(inp, w, 'ge2')
    le3 = _threshold_unit(inp, w, 'le3')
    xor_ab_or = _threshold_unit(inp, w, 'xor_ab_or')
    xor_ab_nand = _threshold_unit(inp, w, 'xor_ab_nand')
    xor_cd_or = _threshold_unit(inp, w, 'xor_cd_or')
    xor_cd_nand = _threshold_unit(inp, w, 'xor_cd_nand')

    # Layer 2: AND gates. For 0/1 inputs, x + y - 2 >= 0 only when both are 1.
    y1 = int(ge2 + le3 - 2 >= 0)
    xor_ab = int(xor_ab_or + xor_ab_nand - 2 >= 0)
    xor_cd = int(xor_cd_or + xor_cd_nand - 2 >= 0)

    # Layer 3: OR (x + y - 1 >= 0) and NAND (-x - y + 1 >= 0) of the pair XORs.
    xor_final_or = int(xor_ab + xor_cd - 1 >= 0)
    xor_final_nand = int(-xor_ab - xor_cd + 1 >= 0)

    # Layer 4: AND(OR, NAND) == XOR(xor_ab, xor_cd). When xor_ab / xor_cd are
    # the pair XORs, this is the parity of the input, i.e. the count's LSB.
    y0 = int(xor_final_or + xor_final_nand - 2 >= 0)
    return [y2, y1, y0]
if __name__ == '__main__':
    # Print the full 16-row truth table for the loaded network.
    w = load_model()
    print('popcount4 truth table:')
    print('abcd | count | y2 y1 y0')
    print('-----+-------+---------')
    for i in range(16):
        # Unpack i into four bits, with a as the most significant.
        a, b, c, d = (i >> 3) & 1, (i >> 2) & 1, (i >> 1) & 1, i & 1
        y2, y1, y0 = popcount4(a, b, c, d, w)
        count = a + b + c + d
        expected = [(count >> 2) & 1, (count >> 1) & 1, count & 1]
        # Flag rows where the network's output disagrees with the true count,
        # so a bad checkpoint is visible without reading every row.
        flag = '' if [y2, y1, y0] == expected else '  MISMATCH'
        # Pad each field to its header's width ('count' is 5 wide and each
        # 'y*' is 2 wide) so the values line up under the column titles.
        print(f'{a}{b}{c}{d} | {count:^5} | {y2:>2} {y1:>2} {y0:>2}{flag}')