import sys

import torch


def load_model(path='model.safetensors'):
    """Load the popcount3 network weights from a safetensors file.

    Args:
        path: path to the ``.safetensors`` file.

    Returns:
        The dict of tensor name -> tensor returned by safetensors' ``load_file``.
    """
    # Imported here (its only use) so popcount3 can be driven with a plain
    # dict of tensors without safetensors being installed.
    from safetensors.torch import load_file
    return load_file(path)


def _fire(inputs, w, name):
    """Evaluate threshold neuron ``name`` on ``inputs`` and return 0 or 1.

    The neuron reads ``w[name + '.weight']`` and ``w[name + '.bias']`` and
    fires (returns 1) when ``inputs @ weight.T + bias >= 0``.
    """
    # Compute in float64: a float32 input matmul'd with a float16/bfloat16/
    # float64 weight raises a dtype mismatch, and float64 represents every
    # float16/bfloat16/float32 value exactly.
    weight = w[f'{name}.weight'].to(torch.float64)
    bias = w[f'{name}.bias'].to(torch.float64)
    x = torch.tensor([float(v) for v in inputs], dtype=torch.float64)
    # .item() needs a single-element result; assumes weight is (1, n) and
    # bias is (1,) -- TODO confirm against the saved model.
    return int((x @ weight.T + bias >= 0).item())


def popcount3(x0, x1, x2, w):
    """3-bit population count: returns (out1, out0) where count = 2*out1 + out0.

    Args:
        x0, x1, x2: input bits; each must equal 0 or 1 (bools are accepted).
        w: mapping of tensor name -> tensor (e.g. from ``load_model``) with
            ``.weight``/``.bias`` entries for ``atleast1``, ``atleast2``,
            ``atleast3``, ``xor.or``, ``xor.nand`` and ``xor.and``.

    Returns:
        Tuple ``(out1, out0)`` of ints, each 0 or 1.

    Raises:
        ValueError: if any input is not 0 or 1.
        KeyError: if ``w`` is missing one of the required tensors.
    """
    bits = (x0, x1, x2)
    for bit in bits:
        if bit not in (0, 1):
            raise ValueError(f'popcount3 inputs must be 0 or 1, got {bit!r}')

    at1 = _fire(bits, w, 'atleast1')
    at2 = _fire(bits, w, 'atleast2')
    at3 = _fire(bits, w, 'atleast3')

    # High bit of the count; assuming atleast2 fires for count >= 2, this is
    # set exactly for counts 2 and 3.
    out1 = at2

    # XOR(at1, at2) as a two-layer network: AND(OR(a, b), NAND(a, b)).
    pair = (at1, at2)
    or_out = _fire(pair, w, 'xor.or')
    nand_out = _fire(pair, w, 'xor.nand')
    xor_result = _fire((or_out, nand_out), w, 'xor.and')

    # Low bit (parity): with atleastK firing for count >= K, xor_result marks
    # count == 1 and at3 marks count == 3, so XOR-ing them gives "count odd".
    out0 = xor_result ^ at3
    return out1, out0


def main(w=None):
    """Print the popcount3 truth table and return the number of failing rows.

    Args:
        w: weight mapping to test; loaded with ``load_model()`` when None.
    """
    if w is None:
        w = load_model()
    print('popcount3 truth table:')
    print('x0 x1 x2 | count | out1 out0')
    print('---------+-------+----------')
    failures = 0
    for i in range(8):
        x0, x1, x2 = (i >> 0) & 1, (i >> 1) & 1, (i >> 2) & 1
        out1, out0 = popcount3(x0, x1, x2, w)
        result = 2 * out1 + out0
        expected = x0 + x1 + x2
        ok = result == expected
        if not ok:
            failures += 1
        status = 'OK' if ok else 'FAIL'
        print(f' {x0} {x1} {x2} | {expected} | {out1} {out0} {status}')
    return failures


if __name__ == '__main__':
    # Non-zero exit status so CI / shell callers notice a broken network.
    sys.exit(1 if main() else 0)