import torch from safetensors.torch import load_file def load_model(path='model.safetensors'): return load_file(path) def xor_gate(a, b, w, idx): inp = torch.tensor([float(a), float(b)]) l1 = (inp @ w[f'xor{idx}.layer1.weight'].T + w[f'xor{idx}.layer1.bias'] >= 0).float() return int((l1 @ w[f'xor{idx}.layer2.weight'].T + w[f'xor{idx}.layer2.bias'] >= 0).item()) def equal(a, b, weights): """8-bit equality comparator. a, b: lists of 8 bits each (LSB first) Returns: 1 if a == b, 0 otherwise """ xors = [xor_gate(a[i], b[i], weights, i) for i in range(8)] xor_vec = torch.tensor([float(x) for x in xors]) return int((xor_vec @ weights['nor.weight'].T + weights['nor.bias'] >= 0).item()) if __name__ == '__main__': w = load_model() print('8-bit Equal Comparator') print('a == b tests:') tests = [(0, 0), (0, 1), (127, 127), (127, 128), (255, 255), (255, 0), (100, 100)] for a_val, b_val in tests: a = [(a_val >> i) & 1 for i in range(8)] b = [(b_val >> i) & 1 for i in range(8)] result = equal(a, b, w) print(f'{a_val:3d} == {b_val:3d} = {result}')