File size: 1,558 Bytes
2a0195d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Load the stored tensors from a safetensors file.

    Args:
        path: Location of the file; passed unchanged to
            ``safetensors.torch.load_file``.

    Returns:
        Whatever ``load_file`` returns for ``path``.
    """
    tensors = load_file(path)
    return tensors

def _threshold_gate(inputs, weights, prefix):
    """Evaluate one threshold layer: fire where ``w @ x + b >= 0``.

    Args:
        inputs: Tensor or sequence of numbers fed to the layer.
        weights: Mapping holding ``'<prefix>.weight'`` and ``'<prefix>.bias'``.
        prefix: Key prefix selecting the layer inside ``weights``.

    Returns:
        Tensor of 0/1 activations in the dtype of the layer's weight.
    """
    weight = weights[prefix + '.weight']
    bias = weights[prefix + '.bias']
    # Build the input in the weight's own dtype/device so the matmul also
    # works when the checkpoint is not stored as default-dtype CPU tensors.
    x = torch.as_tensor(inputs, dtype=weight.dtype, device=weight.device)
    return (x @ weight.T + bias >= 0).to(weight.dtype)


def _xor_gate(inputs, weights, prefix):
    """Evaluate the two-layer gate stored under ``<prefix>.layer1/.layer2``."""
    hidden = _threshold_gate(inputs, weights, prefix + '.layer1')
    return _threshold_gate(hidden, weights, prefix + '.layer2')


def full_subtractor(a, b, bin_in, weights):
    """Full subtractor: computes a - b - bin, returns (diff, borrow_out).

    The circuit is two chained half subtractors followed by a gate that
    combines their borrows. Every gate is a threshold unit read from
    ``weights`` and fires when ``w @ x + b >= 0``.

    Args:
        a: Minuend bit (converted with ``float``).
        b: Subtrahend bit (converted with ``float``).
        bin_in: Incoming borrow bit (converted with ``float``).
        weights: Mapping of tensor names to tensors. Keys read:
            ``hs1.xor.layer1``, ``hs1.xor.layer2``, ``hs1.borrow``,
            ``hs2.xor.layer1``, ``hs2.xor.layer2``, ``hs2.borrow`` and
            ``bout``, each with a ``.weight`` and ``.bias`` entry.

    Returns:
        Tuple ``(diff, bout)`` of Python ints.

    Raises:
        KeyError: If a required tensor is missing from ``weights``.
    """
    # HS1: a - b -> d1, b1
    operands = [float(a), float(b)]
    d1 = _xor_gate(operands, weights, 'hs1.xor').item()
    b1 = _threshold_gate(operands, weights, 'hs1.borrow').item()

    # HS2: d1 - bin -> diff, b2
    stage2 = [d1, float(bin_in)]
    diff = _xor_gate(stage2, weights, 'hs2.xor').item()
    b2 = _threshold_gate(stage2, weights, 'hs2.borrow').item()

    # Final borrow: OR(b1, b2)
    bout = _threshold_gate([b1, b2], weights, 'bout').item()

    return int(diff), int(bout)

if __name__ == '__main__':
    # Load the weights once, then print every input combination in order
    # (a is the slowest-changing bit, bin_in the fastest).
    w = load_model()
    print('FullSubtractor truth table:')
    print('a b bin | diff bout')
    cases = [(a, b, bin_in) for a in (0, 1) for b in (0, 1) for bin_in in (0, 1)]
    for a, b, bin_in in cases:
        diff, bout = full_subtractor(a, b, bin_in, w)
        print(f'{a} {b}  {bin_in}  |  {diff}    {bout}')