"""Evaluate a 2-bit equality comparator built from threshold neurons.

Each neuron computes ``sum(x * weight) + bias`` and outputs 1 when that
value is >= 0, else 0. Weights are looked up by name in a mapping, which
the script loads from a safetensors file.
"""
import torch


def load_model(path='model.safetensors'):
    """Load every tensor stored in the safetensors file at *path*.

    Returns a dict mapping tensor names to ``torch.Tensor`` objects.
    """
    # Imported here rather than at module level so that ``equals2`` can be
    # used with any name -> tensor mapping without requiring safetensors.
    from safetensors.torch import load_file

    return load_file(path)


def _fire(x, w, name):
    """Return 1 if neuron *name* fires on input vector *x*, else 0.

    The neuron fires when ``(x * w[name + '.weight']).sum() + w[name + '.bias']``
    is >= 0. The weight is assumed to broadcast against *x*; presumably both
    are 1-D of equal length (TODO confirm against the stored model).
    """
    activation = (x * w[f'{name}.weight']).sum() + w[f'{name}.bias']
    return int(activation >= 0)


def equals2(a1, a0, b1, b0, w):
    """Return 1 if the 2-bit numbers (a1 a0) and (b1 b0) compare equal, else 0.

    *a1*/*b1* are the high bits and *a0*/*b0* the low bits. Each bit should
    be 0 or 1; it is converted with ``float()`` and not validated.

    *w* maps tensor names to tensors and must contain ``<neuron>.weight`` and
    ``<neuron>.bias`` for these neurons:

    - layer1: ``and1``, ``nor1``, ``and0``, ``nor0``. Their input is
      ``[a1, a0, b1, b0]``.
    - layer2: ``xnor1``, ``xnor0``. Their input is the four layer1 outputs in
      that order.
    - layer3: ``eq``. Its input is the two layer2 outputs in that order.

    Whether the result really is equality depends entirely on the weights in *w*.
    """
    inp = torch.tensor([float(a1), float(a0), float(b1), float(b0)])
    # Each layer's 0/1 outputs are packed into a float vector that feeds the next layer.
    layer1 = torch.tensor(
        [float(_fire(inp, w, f'layer1.{n}')) for n in ('and1', 'nor1', 'and0', 'nor0')]
    )
    layer2 = torch.tensor(
        [float(_fire(layer1, w, f'layer2.{n}')) for n in ('xnor1', 'xnor0')]
    )
    return _fire(layer2, w, 'layer3.eq')


if __name__ == '__main__':
    w = load_model()
    print('equals2 truth table:')
    for a in range(4):
        for b in range(4):
            # Split each 2-bit number into (high bit, low bit).
            a1, a0, b1, b0 = (a >> 1) & 1, a & 1, (b >> 1) & 1, b & 1
            print(f' {a} == {b}? {equals2(a1, a0, b1, b0, w)}')