File size: 1,516 Bytes
2f8703b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Load all tensors from the safetensors file at *path*.

    Returns whatever ``safetensors.torch.load_file`` returns for that file:
    a mapping from tensor name to tensor.
    """
    tensors = load_file(path)
    return tensors

def _threshold_unit(inputs, w, name):
    """Evaluate the threshold neuron *name*: return 1 if inputs @ W.T + b >= 0, else 0.

    Weights are read from ``w[name + '.weight']`` and ``w[name + '.bias']``.
    Everything is promoted to float64 so checkpoints stored as float16,
    bfloat16, float32, float64 or integer tensors all work; the promotion is
    exact for each of those dtypes, so no threshold decision is perturbed.
    """
    weight = w[f'{name}.weight'].to(torch.float64)
    bias = w[f'{name}.bias'].to(torch.float64)
    x = torch.tensor([float(v) for v in inputs], dtype=torch.float64)
    return int((x @ weight.T + bias >= 0).item())


def popcount3(x0, x1, x2, w):
    """3-bit population count: returns (out1, out0) where count = 2*out1 + out0.

    Args:
        x0, x1, x2: input bits (each converted with ``float``; intended to be 0/1).
        w: mapping of parameter names to tensors, e.g. from ``load_model``. It
            must contain weight/bias pairs for ``atleast1``, ``atleast2``,
            ``atleast3``, ``xor.or``, ``xor.nand`` and ``xor.and``.

    Returns:
        Tuple ``(out1, out0)`` of Python ints.

    Raises:
        KeyError: if a required parameter is missing from *w*.
    """
    bits = (x0, x1, x2)
    at1 = _threshold_unit(bits, w, 'atleast1')
    at2 = _threshold_unit(bits, w, 'atleast2')
    at3 = _threshold_unit(bits, w, 'atleast3')

    # The 2s bit is set exactly when at least two inputs are set.
    out1 = at2

    # XOR(at1, at2) via OR/NAND feeding an AND; it is 1 only when count == 1.
    or_out = _threshold_unit((at1, at2), w, 'xor.or')
    nand_out = _threshold_unit((at1, at2), w, 'xor.nand')
    xor_result = _threshold_unit((or_out, nand_out), w, 'xor.and')

    # Folding in at3 (in plain Python) also sets the 1s bit for count == 3,
    # giving parity = at1 ^ at2 ^ at3.
    out0 = xor_result ^ at3

    return out1, out0

if __name__ == '__main__':
    # Exhaustively compare popcount3 with the arithmetic bit count over all
    # eight 3-bit inputs and print the resulting truth table.
    w = load_model()
    print('popcount3 truth table:')
    print('x0 x1 x2 | count | out1 out0')
    print('---------+-------+----------')
    failures = 0
    for i in range(8):
        # Bit k of i drives input xk, so x0 is the least-significant bit.
        x0, x1, x2 = (i >> 0) & 1, (i >> 1) & 1, (i >> 2) & 1
        out1, out0 = popcount3(x0, x1, x2, w)
        result = 2 * out1 + out0
        expected = x0 + x1 + x2
        ok = result == expected
        if not ok:
            failures += 1
        status = 'OK' if ok else 'FAIL'
        print(f' {x0}  {x1}  {x2}  |   {expected}   |  {out1}    {out0}    {status}')
    # Exit non-zero so scripts/CI can detect a mismatch, not just a printed FAIL.
    if failures:
        raise SystemExit(1)