File size: 1,358 Bytes
6529868
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Read a safetensors checkpoint and hand back its tensors.

    Args:
        path: Location of the ``.safetensors`` file. Defaults to
            ``'model.safetensors'``; a relative path is resolved against the
            current working directory.

    Returns:
        The mapping of tensor name -> tensor produced by
        ``safetensors.torch.load_file`` for *path*.
    """
    state = load_file(path)
    return state

def xor2(a, b, prefix, w):
    """Evaluate a two-layer threshold-logic XOR gate.

    The gate uses three threshold neurons whose parameters live in *w*:
    ``{prefix}.or`` and ``{prefix}.nand`` both see the raw inputs, and
    ``{prefix}.and`` combines their two outputs. A neuron outputs 1 when
    ``inputs @ weight.T + bias >= 0`` and 0 otherwise.

    Args:
        a, b: Input bits (anything ``float()`` accepts; expected 0 or 1).
        prefix: Key prefix of this gate's parameters in *w*, e.g. ``'xor1'``.
        w: Mapping from parameter name to tensor, such as the dict returned by
            ``load_model``. Each weight is presumably shaped ``(1, 2)`` and each
            bias ``(1,)`` so the result holds a single element for ``.item()``.
            TODO: confirm this against the checkpoint.

    Returns:
        int: 1 if the final AND neuron fires, else 0.
    """
    def fire(bits, name):
        # Evaluate one threshold neuron ``{prefix}.{name}`` on *bits*.
        weight = w[f'{prefix}.{name}.weight']
        # Build the input in the weight's dtype. torch.matmul rejects mixed
        # dtypes, so a default float32 input would raise against float64 or
        # bfloat16 weights.
        inp = torch.tensor([float(v) for v in bits], dtype=weight.dtype)
        return int((inp @ weight.T + w[f'{prefix}.{name}.bias'] >= 0).item())

    or_out = fire((a, b), 'or')
    nand_out = fire((a, b), 'nand')
    return fire((or_out, nand_out), 'and')

def compress_3to2(x, y, z, weights):
    """3:2 compressor: returns (sum, carry) where x+y+z = sum + 2*carry.

    The sum bit is produced by chaining two ``xor2`` gates (parameters under
    ``'xor1.*'`` and ``'xor2.*'``). The carry bit comes from a single
    threshold neuron ``'maj'`` over all three inputs. It outputs 1 when
    ``[x, y, z] @ maj.weight.T + maj.bias >= 0``. For the identity above to
    hold, it must fire exactly when at least two inputs are 1.

    Args:
        x, y, z: Input bits (anything ``float()`` accepts; expected 0 or 1).
        weights: Mapping from parameter name to tensor, e.g. from
            ``load_model``.

    Returns:
        tuple[int, int]: ``(sum_bit, carry_bit)``.
    """
    xor_xy = xor2(x, y, 'xor1', weights)
    sum_out = xor2(xor_xy, z, 'xor2', weights)

    maj_weight = weights['maj.weight']
    # Build the input in the weight's dtype. torch.matmul rejects mixed-dtype
    # operands, so a default float32 input would fail against other dtypes.
    inp = torch.tensor([float(x), float(y), float(z)], dtype=maj_weight.dtype)
    carry = int((inp @ maj_weight.T + weights['maj.bias'] >= 0).item())

    return sum_out, carry

if __name__ == '__main__':
    import itertools
    import sys

    # Exhaustively check the compressor over all 8 input combinations.
    w = load_model()
    print('3:2 Compressor Truth Table:')
    print('x y z | sum carry | verify')
    print('------+-----------+-------')
    failures = 0
    # product() yields (x, y, z) in the same order as three nested 0/1 loops.
    for x, y, z in itertools.product([0, 1], repeat=3):
        s, c = compress_3to2(x, y, z, w)
        ok = (x + y + z) == (s + 2 * c)
        if not ok:
            failures += 1
        check = 'OK' if ok else 'FAIL'
        print(f'{x} {y} {z} |  {s}    {c}   | {check}')
    # Exit non-zero on any mismatch so shell scripts / CI can detect a bad model.
    if failures:
        sys.exit(f'{failures} of 8 rows failed verification')