# File size: 1,653 Bytes
# f7d5919
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Load the safetensors file at ``path`` and return the result of ``load_file``.

    Args:
        path: Location of the ``.safetensors`` file; defaults to
            ``'model.safetensors'`` relative to the working directory.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for that file.
    """
    tensors = load_file(path)
    return tensors

def popcount4(a, b, c, d, w):
    """Run the four input bits through the threshold-gate popcount circuit.

    Args:
        a, b, c, d: Input bits; each one is converted with ``float()``.
        w: Mapping that provides a ``'<name>.weight'`` and ``'<name>.bias'``
            tensor for every first-layer neuron: ``y2``, ``ge2``, ``le3``,
            ``xor_ab_or``, ``xor_ab_nand``, ``xor_cd_or`` and ``xor_cd_nand``.

    Returns:
        ``[y2, y1, y0]`` as 0/1 ints. With weights trained/constructed for
        this circuit that is the number of set input bits in binary, most
        significant bit first (the later layers are fixed gates; the first
        layer depends entirely on the contents of ``w``).
    """
    bits = torch.tensor([float(v) for v in (a, b, c, d)])

    def fire(name):
        # Step activation: the neuron outputs 1 when its pre-activation is >= 0.
        pre_activation = bits @ w[f'{name}.weight'].T + w[f'{name}.bias']
        return int((pre_activation >= 0).item())

    def both(p, q):
        # Two-input AND over 0/1 values, expressed as a threshold on their sum.
        return int(p + q >= 2)

    # Layer 1: the only layer that reads learned parameters from ``w``.
    first_layer = (
        'y2',
        'ge2',
        'le3',
        'xor_ab_or',
        'xor_ab_nand',
        'xor_cd_or',
        'xor_cd_nand',
    )
    y2, ge2, le3, ab_or, ab_nand, cd_or, cd_nand = [fire(n) for n in first_layer]

    # Layer 2: AND pairs of layer-1 outputs (XOR is built as OR AND NAND).
    y1 = both(ge2, le3)
    parity_ab = both(ab_or, ab_nand)
    parity_cd = both(cd_or, cd_nand)

    # Layer 3: OR and NAND of the two pair parities.
    pair_sum = parity_ab + parity_cd
    final_or = int(pair_sum >= 1)
    final_nand = int(pair_sum <= 1)

    # Layer 4: XOR of the two parities gives the low bit.
    y0 = both(final_or, final_nand)

    return [y2, y1, y0]

if __name__ == '__main__':
    # Script entry: print the circuit's output for every 4-bit input next to
    # the arithmetic bit count so the two can be compared by eye.
    w = load_model()
    for header_line in (
        'popcount4 truth table:',
        'abcd | count | y2 y1 y0',
        '-----+-------+---------',
    ):
        print(header_line)
    for value in range(16):
        # Split the 4-bit number into bits, most significant (a) first.
        a, b, c, d = ((value >> shift) & 1 for shift in (3, 2, 1, 0))
        y2, y1, y0 = popcount4(a, b, c, d, w)
        count = sum((a, b, c, d))
        print(f'{a}{b}{c}{d} |   {count}   |  {y2}  {y1}  {y0}')