"""Evaluate a 4-bit prefix-XOR (running parity) circuit from stored gate weights.

Each 2-input XOR stage is evaluated as a two-layer threshold network whose
parameters are looked up by name in a weight mapping (``<prefix>.or``,
``<prefix>.nand`` and ``<prefix>.and``, each with ``.weight`` and ``.bias``).
"""
from typing import Dict, Mapping, Tuple

import torch

try:
    from safetensors.torch import load_file
except ImportError:  # Only file loading needs safetensors; gate logic does not.
    load_file = None


def load_model(path='model.safetensors'):
    """Load the gate weights stored in the safetensors file at *path*.

    Args:
        path: Location of the ``.safetensors`` file.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for *path*; the rest
        of this module indexes it by string key and treats the values as
        tensors.

    Raises:
        ImportError: If the ``safetensors`` package is not installed.
    """
    if load_file is None:
        raise ImportError(
            'safetensors is required to load model weights '
            '(pip install safetensors)'
        )
    return load_file(path)


def _threshold_gate(values, name, w):
    """Return 1 if ``values @ weight.T + bias >= 0`` for gate *name*, else 0."""
    weight = w[f'{name}.weight']
    bias = w[f'{name}.bias']
    # Build the input with the weight's dtype/device: matmul does not
    # type-promote, so a default float32 CPU tensor would fail against
    # weights stored in another precision or on another device.
    x = torch.tensor(
        [float(v) for v in values],
        dtype=weight.dtype,
        device=weight.device,
    )
    return int((x @ weight.T + bias >= 0).item())


def xor2(a, b, prefix, w):
    """Evaluate the two-layer gate network stored under *prefix* on bits a, b.

    The first layer evaluates the ``<prefix>.or`` and ``<prefix>.nand`` gates
    on ``(a, b)``; the second layer feeds both results into ``<prefix>.and``.
    With weights that actually implement OR, NAND and AND this computes
    ``a XOR b``.

    Args:
        a: First input bit (anything ``float()`` accepts).
        b: Second input bit.
        prefix: Key prefix selecting which stage's weights to use.
        w: Mapping from parameter name to tensor.

    Returns:
        The output of the ``and`` gate as ``0`` or ``1``.

    Raises:
        KeyError: If a required ``<prefix>.<gate>.weight``/``.bias`` entry is
            missing from *w*.
    """
    or_out = _threshold_gate((a, b), f'{prefix}.or', w)
    nand_out = _threshold_gate((a, b), f'{prefix}.nand', w)
    return _threshold_gate((or_out, nand_out), f'{prefix}.and', w)


def prefix_xor(x3, x2, x1, x0, w):
    """Compute the running XOR of four bits, starting from ``x3``.

    ``y3`` is ``x3`` unchanged; each following output is the previous output
    combined with the next input bit through :func:`xor2`, using the weight
    prefixes ``'xor2'``, ``'xor1'`` and ``'xor0'`` respectively.

    Args:
        x3, x2, x1, x0: Input bits, in the order they are folded.
        w: Mapping from parameter name to tensor (see :func:`xor2`).

    Returns:
        The tuple ``(y3, y2, y1, y0)``; ``y0`` is the combination of all four
        inputs.
    """
    y3 = x3
    y2 = xor2(y3, x2, 'xor2', w)
    y1 = xor2(y2, x1, 'xor1', w)
    y0 = xor2(y1, x0, 'xor0', w)
    return y3, y2, y1, y0


def main():
    """Load the default weight file and print a few sample evaluations."""
    w = load_model()
    print('Prefix-XOR (running parity):')
    for i in [0b0000, 0b0001, 0b0011, 0b0111, 0b1111, 0b1010]:
        # Split the 4-bit pattern into individual bits, most significant first.
        x3, x2, x1, x0 = (i >> 3) & 1, (i >> 2) & 1, (i >> 1) & 1, i & 1
        y3, y2, y1, y0 = prefix_xor(x3, x2, x1, x0, w)
        print(f'{x3}{x2}{x1}{x0} -> {y3}{y2}{y1}{y0} (parity={y0})')


if __name__ == '__main__':
    main()