import itertools

import torch


def load_model(path='model.safetensors'):
    """Load a safetensors checkpoint and return the result of ``load_file``.

    The returned mapping is indexed by parameter name (e.g.
    ``'xor1.or.weight'``, ``'maj.bias'``) by the functions below.
    """
    # Imported lazily so the gate functions can be used with in-memory
    # weight dicts without requiring safetensors to be installed.
    from safetensors.torch import load_file

    return load_file(path)


def _neuron(inputs, w, prefix):
    """Evaluate one threshold neuron ``step(x @ W.T + b)`` and return 0 or 1.

    Reads ``w[f'{prefix}.weight']`` and ``w[f'{prefix}.bias']``. The neuron
    fires (returns 1) when the pre-activation is >= 0.
    """
    weight = w[f'{prefix}.weight']
    bias = w[f'{prefix}.bias']
    # Match the stored parameters' dtype/device; a hard-coded float32 input
    # would make the matmul fail for e.g. float64 checkpoints.
    x = torch.tensor(
        [float(v) for v in inputs], dtype=weight.dtype, device=weight.device
    )
    # .item() assumes the neuron has a single output — presumably weight is
    # shaped (1, n_inputs) and bias (1,); TODO confirm against the checkpoint.
    return int((x @ weight.T + bias >= 0).item())


def xor2(a, b, prefix, w):
    """Compute XOR of two bits with a two-layer threshold network.

    Layer 1 evaluates the ``{prefix}.or`` and ``{prefix}.nand`` neurons on
    ``(a, b)``; layer 2 feeds their outputs to the ``{prefix}.and`` neuron.
    The result is XOR only if the stored weights implement those gates.

    Args:
        a, b: Input bits (anything convertible with ``float``).
        prefix: Key prefix of this gate's parameters in ``w``.
        w: Mapping from parameter name to tensor.

    Returns:
        The output bit as an ``int`` (0 or 1).
    """
    or_out = _neuron((a, b), w, f'{prefix}.or')
    nand_out = _neuron((a, b), w, f'{prefix}.nand')
    return _neuron((or_out, nand_out), w, f'{prefix}.and')


def compress_3to2(x, y, z, weights):
    """3:2 compressor: returns (sum, carry) where x+y+z = sum + 2*carry.

    ``sum`` chains the ``xor1`` and ``xor2`` gates; ``carry`` is the single
    ``maj`` neuron over all three inputs. The identity holds provided the
    weights implement XOR and 3-input majority.
    """
    xor_xy = xor2(x, y, 'xor1', weights)
    sum_out = xor2(xor_xy, z, 'xor2', weights)
    carry = _neuron((x, y, z), weights, 'maj')
    return sum_out, carry


def main():
    """Print and self-check the compressor's full 3-input truth table."""
    w = load_model()
    print('3:2 Compressor Truth Table:')
    print('x y z | sum carry | verify')
    print('------+-----------+-------')
    # product() yields inputs in the same order as nested x/y/z loops.
    for x, y, z in itertools.product([0, 1], repeat=3):
        s, c = compress_3to2(x, y, z, w)
        check = 'OK' if (x + y + z) == (s + 2 * c) else 'FAIL'
        print(f'{x} {y} {z} | {s} {c} | {check}')


if __name__ == '__main__':
    main()