# File size: 1,204 Bytes
# 41edcb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Read the weights file at ``path`` and return its contents.

    Args:
        path: Location of the safetensors file; defaults to
            ``'model.safetensors'`` relative to the working directory.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for that file.
        Code in this file indexes the result by parameter name, e.g.
        ``w['layer1.and1.weight']``.
    """
    tensors = load_file(path)
    return tensors

def equals2(a1, a0, b1, b0, w):
    """Run the three-layer threshold network in ``w`` on two 2-bit operands.

    Args:
        a1, a0: Bits of the first operand (the script below passes the
            high bit as ``a1`` and the low bit as ``a0``).
        b1, b0: Bits of the second operand, in the same order.
        w: Mapping from parameter name to tensor. It must provide
            ``'<neuron>.weight'`` and ``'<neuron>.bias'`` for
            ``layer1.{and1,nor1,and0,nor0}``, ``layer2.{xnor1,xnor0}``
            and ``layer3.eq``.

    Returns:
        ``1`` if the final neuron's weighted sum plus bias is ``>= 0``,
        otherwise ``0``. With weights implementing the gates the neuron
        names suggest, this is presumably 1 exactly when both operands are
        equal; that depends on the weights and is not enforced here.
    """

    def fire(activations, neuron):
        # Step activation: output 1 when the weighted sum plus bias is
        # non-negative, 0 otherwise.
        pre_activation = (
            (activations * w[f'{neuron}.weight']).sum() + w[f'{neuron}.bias']
        )
        return int(pre_activation >= 0)

    bits = torch.tensor([float(bit) for bit in (a1, a0, b1, b0)])

    # Layer 1: four neurons, each reading all four input bits.
    first = torch.tensor([
        float(fire(bits, f'layer1.{gate}'))
        for gate in ('and1', 'nor1', 'and0', 'nor0')
    ])

    # Layer 2: two neurons, each reading all four layer-1 outputs.
    second = torch.tensor([
        float(fire(first, f'layer2.{gate}'))
        for gate in ('xnor1', 'xnor0')
    ])

    # Layer 3: a single output neuron over the two layer-2 outputs.
    return fire(second, 'layer3.eq')

if __name__ == '__main__':
    # Load the default weights file, then print the network's answer for
    # every ordered pair of 2-bit values (a varies slowest): 16 rows.
    weights = load_model()
    print('equals2 truth table:')
    for pair in range(16):
        a, b = divmod(pair, 4)
        # Split each value into (high bit, low bit).
        a1, a0 = divmod(a, 2)
        b1, b0 = divmod(b, 2)
        verdict = equals2(a1, a0, b1, b0, weights)
        print(f'  {a} == {b}? {verdict}')