| import torch | |
| from safetensors.torch import load_file | |
def load_model(path='model.safetensors'):
    """Load every tensor stored in the safetensors file at *path*.

    Returns whatever ``safetensors.torch.load_file`` returns for that file
    (a mapping from tensor name to tensor).
    """
    tensors = load_file(path)
    return tensors
def _threshold_unit(inp, weight, bias):
    """Evaluate one step-activation neuron and return its 0/1 output as an int.

    The weighted sum is formed elementwise instead of with ``inp @ weight.T``.
    This keeps 1-D weight vectors working without the deprecated ``.T`` on a
    non-2-D tensor. It also lets a weight/bias stored in another floating dtype
    (e.g. float16) be type-promoted, where ``matmul`` would raise on the dtype
    mismatch.
    """
    # weight may be (in,) or (1, in); summing the last axis handles both.
    pre_activation = (inp * weight).sum(dim=-1) + bias
    # .item() requires exactly one output value, i.e. a single neuron.
    return int((pre_activation >= 0).item())


def xor2(a, b, prefix, w):
    """Evaluate a two-layer threshold network named ``prefix`` on bits (a, b).

    Layer 1 runs the ``or`` and ``nand`` neurons on ``[a, b]``. Layer 2 feeds
    their two outputs into the ``and`` neuron. Assuming the stored weights
    implement the gates their key names suggest, the result is XOR(a, b).

    Args:
        a, b: input bits; anything ``float()`` accepts (presumably 0/1).
        prefix: key prefix selecting this gate's tensors in *w*. For example,
            ``'xor2'`` reads ``'xor2.or.weight'``, ``'xor2.or.bias'``,
            ``'xor2.nand.weight'``, ... ``'xor2.and.bias'``.
        w: mapping from tensor name to ``torch.Tensor``.

    Returns:
        int: 1 if the ``and`` neuron's pre-activation is >= 0, else 0.

    Raises:
        KeyError: if any of the six expected tensors is missing from *w*.
    """
    def neuron(gate, x):
        return _threshold_unit(x, w[f'{prefix}.{gate}.weight'], w[f'{prefix}.{gate}.bias'])

    inp = torch.tensor([float(a), float(b)])
    hidden = torch.tensor([float(neuron('or', inp)), float(neuron('nand', inp))])
    return neuron('and', hidden)
def prefix_xor(x3, x2, x1, x0, w):
    """Return the running XOR of four bits, scanning from x3 down to x0.

    The first output is x3 unchanged. Each later output combines the previous
    output with the next bit through the ``xor2`` gate network. Gate ``'xor2'``
    handles x2, ``'xor1'`` handles x1 and ``'xor0'`` handles x0. Assuming those
    gates implement XOR, the last element is the parity of all four bits.

    Returns:
        tuple: (y3, y2, y1, y0).
    """
    running = x3
    outputs = [running]
    for bit, gate in ((x2, 'xor2'), (x1, 'xor1'), (x0, 'xor0')):
        running = xor2(running, bit, gate, w)
        outputs.append(running)
    return tuple(outputs)
if __name__ == '__main__':
    # Demo: load the gate weights and print the prefix-XOR of a few 4-bit samples.
    weights = load_model()
    print('Prefix-XOR (running parity):')
    samples = (0b0000, 0b0001, 0b0011, 0b0111, 0b1111, 0b1010)
    for value in samples:
        # Split into bits, most significant first: (x3, x2, x1, x0).
        bits = tuple((value >> shift) & 1 for shift in (3, 2, 1, 0))
        outs = prefix_xor(*bits, weights)
        in_bits = ''.join(str(bit) for bit in bits)
        out_bits = ''.join(str(bit) for bit in outs)
        print(f'{in_bits} -> {out_bits} (parity={outs[-1]})')