|
|
import torch
|
|
|
from safetensors.torch import load_file
|
|
|
|
|
|
def load_model(path='model.safetensors'):
    """Read the safetensors checkpoint at ``path`` and hand back its contents.

    The result is exactly what ``safetensors.torch.load_file`` produces for
    that file (a mapping from tensor name to tensor); no post-processing
    is applied here.
    """
    state = load_file(path)
    return state
|
|
|
|
|
|
def _threshold(x, weight, bias):
    """Apply one threshold (Heaviside) layer: ``step(x @ weight.T + bias)``.

    ``weight`` and ``bias`` are cast to ``x``'s dtype and device first, so a
    checkpoint stored in another precision (float16, float64, ...) or on
    another device does not trip torch's "same dtype" matmul requirement.
    Returns a 0/1 tensor in ``x``'s dtype.
    """
    weight = weight.to(dtype=x.dtype, device=x.device)
    bias = bias.to(dtype=x.dtype, device=x.device)
    return (x @ weight.T + bias >= 0).to(x.dtype)


def _half_subtractor(x, y, weights, prefix):
    """Evaluate the half-subtractor stage whose tensors live under ``prefix``.

    Reads ``{prefix}.xor.layer{1,2}.{weight,bias}`` (a two-layer sub-network)
    and ``{prefix}.borrow.{weight,bias}`` (a single layer) from ``weights``.
    Returns ``(xor_out, borrow_out)`` as Python floats (0.0 or 1.0); what
    they compute depends on the stored weights -- presumably x XOR y and
    (NOT x) AND y.
    """
    inp = torch.tensor([float(x), float(y)])
    hidden = _threshold(
        inp,
        weights[f'{prefix}.xor.layer1.weight'],
        weights[f'{prefix}.xor.layer1.bias'],
    )
    # .item() requires the output layers to produce a single element,
    # exactly as the original inline code did.
    xor_out = _threshold(
        hidden,
        weights[f'{prefix}.xor.layer2.weight'],
        weights[f'{prefix}.xor.layer2.bias'],
    ).item()
    borrow_out = _threshold(
        inp,
        weights[f'{prefix}.borrow.weight'],
        weights[f'{prefix}.borrow.bias'],
    ).item()
    return xor_out, borrow_out


def full_subtractor(a, b, bin_in, weights):
    """Full subtractor: computes a - b - bin, returns (diff, borrow_out)

    Chains two half-subtractor stages ('hs1' on (a, b), then 'hs2' on
    (hs1 difference, bin_in)) and combines their two borrows with the
    single threshold unit stored under 'bout'.

    Args:
        a, b, bin_in: input bits (anything ``float()`` accepts; 0/1 expected).
        weights: mapping of tensor name -> tensor, e.g. from ``load_model``.
            Any floating dtype is accepted; tensors are cast to the input
            dtype before use.

    Returns:
        ``(diff, bout)`` as Python ints, each 0 or 1.

    Raises:
        KeyError: if a required tensor name is missing from ``weights``.
    """
    d1, b1 = _half_subtractor(a, b, weights, 'hs1')
    diff, b2 = _half_subtractor(d1, bin_in, weights, 'hs2')
    bout = _threshold(
        torch.tensor([b1, b2]), weights['bout.weight'], weights['bout.bias']
    ).item()
    return int(diff), int(bout)
|
|
|
|
|
|
if __name__ == '__main__':
    # Load the checkpoint once and print the full 3-input truth table.
    weights = load_model()
    for header_line in ('FullSubtractor truth table:', 'a b bin | diff bout'):
        print(header_line)
    # Same enumeration order as three nested loops over a, b, bin_in.
    combos = [(x, y, z) for x in (0, 1) for y in (0, 1) for z in (0, 1)]
    for a, b, bin_in in combos:
        diff, bout = full_subtractor(a, b, bin_in, weights)
        print(f'{a} {b} {bin_in} | {diff} {bout}')
|
|
|
|