| import torch | |
| from safetensors.torch import load_file | |
def load_model(path='model.safetensors'):
    """Read the safetensors file at *path* and return what ``load_file`` yields.

    Callers in this file index the result with string keys such as
    ``'maj.weight'`` to obtain individual tensors.
    """
    tensors = load_file(path)
    return tensors
def xor2(a, b, prefix, w):
    """Evaluate a two-layer threshold-gate XOR circuit on bits *a* and *b*.

    Layer 1 fires the ``or`` and ``nand`` gates on ``(a, b)``; layer 2 fires
    the ``and`` gate on those two outputs. A gate outputs 1 when
    ``inputs @ weight.T + bias >= 0`` and 0 otherwise.

    Args:
        a: First input bit (anything ``float()`` accepts).
        b: Second input bit (anything ``float()`` accepts).
        prefix: Key prefix; tensors are looked up as
            ``f'{prefix}.{gate}.weight'`` and ``f'{prefix}.{gate}.bias'``
            for ``gate`` in ``or``, ``nand``, ``and``.
        w: Mapping from those keys to torch tensors.

    Returns:
        The circuit output as an ``int`` (0 or 1).
    """

    def fire(values, gate):
        weight = w[f'{prefix}.{gate}.weight']
        bias = w[f'{prefix}.{gate}.bias']
        # matmul does not type-promote, so a float32 input against float64
        # (or other non-float32) weights would raise. Evaluate in a dtype
        # that holds both, on the same device as the weights.
        dtype = torch.promote_types(weight.dtype, torch.float32)
        inp = torch.tensor(
            [float(v) for v in values], dtype=dtype, device=weight.device
        )
        pre_activation = inp @ weight.to(dtype).T + bias.to(dtype)
        return int((pre_activation >= 0).item())

    or_out = fire((a, b), 'or')
    nand_out = fire((a, b), 'nand')
    return fire((or_out, nand_out), 'and')
def compress_3to2(x, y, z, weights):
    """3:2 compressor: returns (sum, carry) where x+y+z = sum + 2*carry.

    The sum bit chains two XOR sub-circuits (key prefixes ``'xor1'`` and
    ``'xor2'``): first on ``(x, y)``, then on that result and ``z``. The
    carry bit is a single threshold gate over ``(x, y, z)`` using
    ``weights['maj.weight']`` and ``weights['maj.bias']``; it is 1 when
    ``inputs @ weight.T + bias >= 0``. The identity above is the intended
    contract and holds only if the stored weights implement those gates.

    Args:
        x, y, z: Input bits (anything ``float()`` accepts).
        weights: Mapping from gate keys to torch tensors.

    Returns:
        Tuple ``(sum_out, carry)`` of ints.
    """
    xor_xy = xor2(x, y, 'xor1', weights)
    sum_out = xor2(xor_xy, z, 'xor2', weights)

    maj_weight = weights['maj.weight']
    maj_bias = weights['maj.bias']
    # matmul does not type-promote, so a float32 input against non-float32
    # weights would raise. Evaluate the carry gate in a dtype that holds
    # both, on the same device as the weights.
    dtype = torch.promote_types(maj_weight.dtype, torch.float32)
    inp = torch.tensor(
        [float(x), float(y), float(z)], dtype=dtype, device=maj_weight.device
    )
    pre_activation = inp @ maj_weight.to(dtype).T + maj_bias.to(dtype)
    carry = int((pre_activation >= 0).item())
    return sum_out, carry
if __name__ == '__main__':
    # Print the compressor's output for all eight input combinations and
    # check each row against the identity x + y + z == sum + 2 * carry.
    w = load_model()
    print('3:2 Compressor Truth Table:')
    print('x y z | sum carry | verify')
    print('------+-----------+-------')
    bits = (0, 1)
    # Same row order as three nested loops: x outermost, z innermost.
    combos = [(i, j, k) for i in bits for j in bits for k in bits]
    for x, y, z in combos:
        s, c = compress_3to2(x, y, z, w)
        if x + y + z == s + 2 * c:
            check = 'OK'
        else:
            check = 'FAIL'
        print(f'{x} {y} {z} | {s} {c} | {check}')