| | import torch
|
| | from safetensors.torch import load_file
|
| |
|
def load_model(path='model.safetensors'):
    """Read the safetensors checkpoint stored at *path*.

    Args:
        path: filesystem path of the ``.safetensors`` file. The default,
            ``model.safetensors``, is a relative path, so it is resolved
            against the current working directory.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for that file: a
        mapping from tensor name to tensor, which ``popcount3`` indexes by
        names such as ``'atleast1.weight'``.
    """
    state = load_file(path)
    return state
|
| |
|
def _threshold_unit(x, w, name):
    """Fire one threshold neuron: return 1 if ``weight @ x + bias >= 0`` else 0.

    The weight and bias are looked up in *w* as ``f'{name}.weight'`` and
    ``f'{name}.bias'``. Everything is computed in float64. That way,
    weights stored in float16, bfloat16, float32 or float64 all combine
    with *x* without a dtype-mismatch error from ``matmul``, and the
    up-cast is exact for those formats. *x* must already be float64.
    """
    weight = w[f'{name}.weight'].to(torch.float64)
    bias = w[f'{name}.bias'].to(torch.float64)
    # ``weight @ x`` is the same product as the original ``x @ weight.T``
    # for a (1, n) or (n,) weight, without transposing.
    return int((weight @ x + bias >= 0).item())


def popcount3(x0, x1, x2, w):
    """3-bit population count: returns (out1, out0) where count = 2*out1 + out0.

    The count is computed by a small network of threshold neurons, each of
    the form ``step(weight @ x + bias)`` with ``step(v) = 1 if v >= 0``.

    Args:
        x0, x1, x2: the three input bits. Each is passed through ``float()``
            and is not validated; presumably 0/1 (ints or bools).
        w: mapping from parameter name to tensor, e.g. the dict returned by
            ``load_model``. Reads the ``.weight``/``.bias`` pairs for
            ``atleast1``, ``atleast2``, ``atleast3``, ``xor.or``,
            ``xor.nand`` and ``xor.and``. Any floating dtype is accepted.

    Returns:
        Tuple ``(out1, out0)`` of Python ints.
    """
    x = torch.tensor([float(x0), float(x1), float(x2)], dtype=torch.float64)

    # Layer 1: judging by their names, these presumably fire when at least
    # 1, 2 and 3 of the inputs are set.
    at1 = _threshold_unit(x, w, 'atleast1')
    at2 = _threshold_unit(x, w, 'atleast2')
    at3 = _threshold_unit(x, w, 'atleast3')

    # High bit: set when the count is >= 2 (given the atleast2 semantics above).
    out1 = at2

    # XOR(at1, at2) built from an OR and a NAND neuron feeding an AND neuron.
    pair = torch.tensor([float(at1), float(at2)], dtype=torch.float64)
    hidden = torch.tensor(
        [
            float(_threshold_unit(pair, w, 'xor.or')),
            float(_threshold_unit(pair, w, 'xor.nand')),
        ],
        dtype=torch.float64,
    )
    xor_result = _threshold_unit(hidden, w, 'xor.and')

    # Low bit: at1 ^ at2 ^ at3. The final XOR with at3 is plain Python,
    # not a weighted gate.
    out0 = xor_result ^ at3

    return out1, out0
|
| |
|
if __name__ == '__main__':
    # Self-check: run popcount3 on all 8 input combinations using the
    # weights in model.safetensors and print a truth table. Exit with
    # status 1 if any row disagrees with the arithmetic bit count, so the
    # script can be used as a pass/fail check.
    import sys

    w = load_model()
    print('popcount3 truth table:')
    # Column separators in the header, divider and rows all line up
    # (at character offsets 9, 17 and 29).
    print('x0 x1 x2 | count | out1 out0 | status')
    print('---------+-------+-----------+-------')
    failures = 0
    for i in range(8):
        x0, x1, x2 = (i >> 0) & 1, (i >> 1) & 1, (i >> 2) & 1
        out1, out0 = popcount3(x0, x1, x2, w)
        result = 2 * out1 + out0
        expected = x0 + x1 + x2
        ok = result == expected
        if not ok:
            failures += 1
        status = 'OK' if ok else 'FAIL'
        print(f' {x0}  {x1}  {x2} |   {expected}   |  {out1}    {out0}   | {status}')
    if failures:
        sys.exit(1)
|
| |
|