"""Evaluate a full subtractor built from threshold-logic neurons stored in a safetensors file."""
import itertools

import torch

try:
    from safetensors.torch import load_file
except ImportError:  # safetensors is only needed by load_model(); evaluation needs just torch
    load_file = None


def load_model(path='model.safetensors'):
    """Load the gate weights from a safetensors file.

    Args:
        path: Path to the ``.safetensors`` file.

    Returns:
        The dict of tensors returned by ``safetensors.torch.load_file``.

    Raises:
        ImportError: If the ``safetensors`` package is not installed.
    """
    if load_file is None:
        raise ImportError("load_model() requires the 'safetensors' package")
    return load_file(path)


def _gate(x, weights, name):
    """Apply one threshold layer: 1.0 where ``x @ W.T + b >= 0``, else 0.0.

    ``W`` and ``b`` are read from ``weights[name + '.weight']`` and
    ``weights[name + '.bias']``. They are cast to ``x``'s dtype so that
    weights stored as float16/bfloat16/float64 do not trigger a matmul
    dtype-mismatch error.
    """
    w = weights[f'{name}.weight'].to(x.dtype)
    bias = weights[f'{name}.bias'].to(x.dtype)
    return (x @ w.T + bias >= 0).float()


def _bit(x, weights, name):
    """Apply a single-output threshold layer and return its result as an int (0 or 1)."""
    return int(_gate(x, weights, name).item())


def full_subtractor(a, b, bin_in, weights):
    """Full subtractor: computes a - b - bin, returns (diff, borrow_out).

    The circuit is two cascaded half subtractors:

    - HS1 computes ``a - b`` and produces ``d1`` and ``b1``.
    - HS2 computes ``d1 - bin_in`` and produces ``diff`` and ``b2``.
    - The final borrow is the ``bout`` neuron applied to ``(b1, b2)``.

    Each XOR is a two-layer threshold network (``*.xor.layer1`` then
    ``*.xor.layer2``).

    Args:
        a: Minuend bit (0 or 1).
        b: Subtrahend bit (0 or 1).
        bin_in: Incoming borrow bit (0 or 1).
        weights: Mapping of ``'<gate>.weight'`` / ``'<gate>.bias'`` to
            tensors. The ``.T`` usage expects 2-D ``(out, in)`` weights.
            Single-output gates must yield exactly one element.

    Returns:
        Tuple ``(diff, bout)`` of ints.

    Raises:
        ValueError: If any input is not 0 or 1.
    """
    for arg_name, bit in (('a', a), ('b', b), ('bin_in', bin_in)):
        if bit not in (0, 1):
            raise ValueError(f'{arg_name} must be 0 or 1, got {bit!r}')

    # HS1: a - b -> d1, b1
    inp = torch.tensor([float(a), float(b)])
    d1 = _bit(_gate(inp, weights, 'hs1.xor.layer1'), weights, 'hs1.xor.layer2')
    b1 = _bit(inp, weights, 'hs1.borrow')

    # HS2: d1 - bin -> diff, b2
    inp2 = torch.tensor([float(d1), float(bin_in)])
    diff = _bit(_gate(inp2, weights, 'hs2.xor.layer1'), weights, 'hs2.xor.layer2')
    b2 = _bit(inp2, weights, 'hs2.borrow')

    # Final borrow: OR(b1, b2)
    bout = _bit(torch.tensor([float(b1), float(b2)]), weights, 'bout')
    return diff, bout


def main():
    """Load the default model file and print the full-subtractor truth table."""
    w = load_model()
    print('FullSubtractor truth table:')
    print('a b bin | diff bout')
    # product(..., repeat=3) yields (a, b, bin) in the same order as three nested loops.
    for a, b, bin_in in itertools.product([0, 1], repeat=3):
        diff, bout = full_subtractor(a, b, bin_in, w)
        print(f'{a} {b} {bin_in} | {diff} {bout}')


if __name__ == '__main__':
    main()