|
|
import torch |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
def load_model(path='model.safetensors'):
    """Read a safetensors checkpoint and return its tensors.

    Args:
        path: Location of the ``.safetensors`` file. A relative path is
            resolved against the current working directory.

    Returns:
        The mapping of parameter names to tensors produced by
        ``safetensors.torch.load_file``.
    """
    tensors = load_file(path)
    return tensors
|
|
|
|
|
def _step(z):
    """Return the Heaviside step of *z*: 1 if z >= 0, else 0."""
    return int(z >= 0)


def _neuron(inp, w, name):
    """Evaluate the learned threshold unit *name* stored in *w* on *inp*.

    Reads ``w[name + '.weight']`` and ``w[name + '.bias']`` and returns
    ``_step(inp @ weight.T + bias)`` as a Python int.
    """
    weight = w[f'{name}.weight']
    bias = w[f'{name}.bias']
    # torch matmul does not promote dtypes, so cast the input to the
    # checkpoint's dtype (e.g. a float64 checkpoint) instead of assuming
    # float32, which would raise a dtype-mismatch RuntimeError.
    x = inp.to(weight.dtype)
    # .item() requires a one-element result; presumably each unit is a single
    # neuron whose weight has shape (1, 4) -- TODO confirm against the model.
    return _step((x @ weight.T + bias).item())


def popcount4(a, b, c, d, w):
    """Count number of 1 bits in 4-bit input. Returns [y2, y1, y0] = binary count.

    Args:
        a, b, c, d: Input bits (anything ``float()`` accepts; expected 0 or 1).
        w: Mapping of parameter names to tensors, e.g. from ``load_model()``.
            Must contain ``<unit>.weight`` and ``<unit>.bias`` for the units
            y2, ge2, le3, xor_ab_or, xor_ab_nand, xor_cd_or, xor_cd_nand.

    Returns:
        List ``[y2, y1, y0]`` of ints, most significant bit first.
    """
    inp = torch.tensor([float(a), float(b), float(c), float(d)])

    # Layer 1: learned threshold units read from the checkpoint. What each
    # computes depends on the stored weights; the names suggest y2 = count >= 4,
    # ge2 = count >= 2, le3 = count <= 3, and OR/NAND pairs over (a, b) and
    # (c, d) -- presumably; verify against the model.
    y2 = _neuron(inp, w, 'y2')
    ge2 = _neuron(inp, w, 'ge2')
    le3 = _neuron(inp, w, 'le3')
    xor_ab_or = _neuron(inp, w, 'xor_ab_or')
    xor_ab_nand = _neuron(inp, w, 'xor_ab_nand')
    xor_cd_or = _neuron(inp, w, 'xor_cd_or')
    xor_cd_nand = _neuron(inp, w, 'xor_cd_nand')

    # Layer 2: fixed two-input AND gates (fires only when both inputs are 1).
    y1 = _step(ge2 + le3 - 2)
    # XOR(x, y) = AND(OR(x, y), NAND(x, y)).
    xor_ab = _step(xor_ab_or + xor_ab_nand - 2)
    xor_cd = _step(xor_cd_or + xor_cd_nand - 2)

    # Layers 3-4: y0 = XOR(xor_ab, xor_cd), built from OR, NAND, then AND.
    xor_final_or = _step(xor_ab + xor_cd - 1)
    xor_final_nand = _step(1 - xor_ab - xor_cd)
    y0 = _step(xor_final_or + xor_final_nand - 2)

    return [y2, y1, y0]
|
|
|
|
|
if __name__ == '__main__':
    # Print every 4-bit input with the network's output, then report whether
    # each row matches the binary encoding of the true bit count.
    w = load_model()
    print('popcount4 truth table:')
    print('abcd | count | y2 y1 y0')
    print('-----+-------+---------')
    mismatches = 0
    for i in range(16):
        a, b, c, d = (i >> 3) & 1, (i >> 2) & 1, (i >> 1) & 1, i & 1
        result = popcount4(a, b, c, d, w)
        count = a + b + c + d
        # Pad each field to its header's width so rows line up under
        # "count" and "y2 y1 y0" (unpadded values drifted left).
        print(f'{a}{b}{c}{d} | {count:^5} | {result[0]:>2} {result[1]:>2} {result[2]:>2}')
        expected = [(count >> 2) & 1, (count >> 1) & 1, count & 1]
        if result != expected:
            mismatches += 1
    if mismatches:
        print(f'{mismatches} of 16 rows disagree with the expected count')
    else:
        print('All 16 rows match the expected count')
|
|
|