"""Evaluate a threshold-gate popcount network whose weights live in a safetensors file."""
import torch


def load_model(path='model.safetensors'):
    """Load the network's weight tensors from a safetensors file.

    Args:
        path: Path to the ``.safetensors`` file.

    Returns:
        The name -> tensor mapping returned by ``safetensors.torch.load_file``.
    """
    # Imported lazily: popcount4 only needs a name -> tensor mapping, so the
    # safetensors dependency is only required when actually loading a file.
    from safetensors.torch import load_file

    return load_file(path)


def _neuron(inp, w, name):
    """Return 1 if layer-1 neuron ``name`` fires on ``inp`` (weighted sum + bias >= 0), else 0."""
    weight = w[f'{name}.weight']
    bias = w[f'{name}.bias']
    # Cast the input to the stored weight dtype so non-float32 weights (e.g.
    # float64) do not raise a dtype mismatch in the matmul; 0/1 values are
    # exactly representable in every float dtype.
    x = inp.to(weight.dtype)
    # .item() requires a single-element result, i.e. one output neuron.
    return int((x @ weight.T + bias >= 0).item())


def _step(x):
    """Heaviside step used by the hand-wired later layers: 1 if x >= 0 else 0."""
    return int(x >= 0)


def popcount4(a, b, c, d, w):
    """Count number of 1 bits in 4-bit input using a threshold-gate network.

    Args:
        a, b, c, d: Input bits, presumably 0 or 1. Values are passed through
            ``float()`` unchecked.
        w: Mapping with ``'<name>.weight'`` and ``'<name>.bias'`` tensors for the
            layer-1 neurons y2, ge2, le3, xor_ab_or, xor_ab_nand, xor_cd_or and
            xor_cd_nand. Each neuron must produce a single output. Weight shape
            is assumed to be (1, 4) and bias shape (1,) -- TODO confirm against
            the model file.

    Returns:
        [y2, y1, y0]: the binary count, most significant bit first.
    """
    inp = torch.tensor([float(a), float(b), float(c), float(d)])

    # Layer 1: neurons whose weights come from ``w``.
    y2 = _neuron(inp, w, 'y2')
    ge2 = _neuron(inp, w, 'ge2')
    le3 = _neuron(inp, w, 'le3')
    xor_ab_or = _neuron(inp, w, 'xor_ab_or')
    xor_ab_nand = _neuron(inp, w, 'xor_ab_nand')
    xor_cd_or = _neuron(inp, w, 'xor_cd_or')
    xor_cd_nand = _neuron(inp, w, 'xor_cd_nand')

    # Layer 2: each gate is an AND of two bits (fires only when both are 1).
    # y1 = ge2 AND le3; xor_ab / xor_cd = OR AND NAND, i.e. an XOR built from
    # the layer-1 pairs.
    y1 = _step(ge2 + le3 - 2)
    xor_ab = _step(xor_ab_or + xor_ab_nand - 2)
    xor_cd = _step(xor_cd_or + xor_cd_nand - 2)

    # Layer 3: OR and NAND of the two partial XORs.
    xor_final_or = _step(xor_ab + xor_cd - 1)
    xor_final_nand = _step(-xor_ab - xor_cd + 1)

    # Layer 4: OR AND NAND = XOR(xor_ab, xor_cd), the low bit of the count.
    y0 = _step(xor_final_or + xor_final_nand - 2)

    return [y2, y1, y0]


def main(path='model.safetensors'):
    """Print the popcount4 truth table for all 16 inputs and check it against a+b+c+d."""
    w = load_model(path)
    print('popcount4 truth table:')
    print('abcd | count | y2 y1 y0')
    print('-----+-------+---------')
    mismatches = 0
    for i in range(16):
        a, b, c, d = (i >> 3) & 1, (i >> 2) & 1, (i >> 1) & 1, i & 1
        y2, y1, y0 = popcount4(a, b, c, d, w)
        count = a + b + c + d
        expected = [(count >> 2) & 1, (count >> 1) & 1, count & 1]
        flag = ''
        if [y2, y1, y0] != expected:
            mismatches += 1
            flag = '  <-- expected ' + ' '.join(str(bit) for bit in expected)
        # Widths chosen so values line up under the header columns.
        print(f'{a}{b}{c}{d} | {count:^5} | {y2:>2} {y1:>2} {y0:>2}{flag}')
    if mismatches:
        raise SystemExit(f'{mismatches} of 16 rows do not match the true popcount')
    print('All 16 rows match the true popcount.')


if __name__ == '__main__':
    main()