import torch from safetensors.torch import load_file def load_model(path='model.safetensors'): return load_file(path) def xor2(a, b, prefix, w): inp = torch.tensor([float(a), float(b)]) or_out = int((inp @ w[f'{prefix}.or.weight'].T + w[f'{prefix}.or.bias'] >= 0).item()) nand_out = int((inp @ w[f'{prefix}.nand.weight'].T + w[f'{prefix}.nand.bias'] >= 0).item()) l1 = torch.tensor([float(or_out), float(nand_out)]) return int((l1 @ w[f'{prefix}.and.weight'].T + w[f'{prefix}.and.bias'] >= 0).item()) def compress_4to2(x, y, z, w_in, cin, weights): """4:2 compressor: x+y+z+w+cin = sum + 2*carry + 2*cout.""" xy = xor2(x, y, 'xor_xy', weights) xyz = xor2(xy, z, 'xor_xyz', weights) xyzw = xor2(xyz, w_in, 'xor_xyzw', weights) sum_out = xor2(xyzw, cin, 'xor_sum', weights) inp_cout = torch.tensor([float(x), float(y), float(z)]) cout = int((inp_cout @ weights['cout.weight'].T + weights['cout.bias'] >= 0).item()) inp_carry = torch.tensor([float(xyz), float(w_in), float(cin)]) carry = int((inp_carry @ weights['carry.weight'].T + weights['carry.bias'] >= 0).item()) return sum_out, carry, cout if __name__ == '__main__': w = load_model() print('4:2 Compressor selected tests:') print('x y z w cin | sum carry cout | verify') print('------------+----------------+-------') for total in range(6): # Generate a combination with this total for x in [0, 1]: for y in [0, 1]: for z in [0, 1]: for w_in in [0, 1]: for cin in [0, 1]: if x + y + z + w_in + cin == total: s, carry, cout = compress_4to2(x, y, z, w_in, cin, w) check = 'OK' if total == s + 2*carry + 2*cout else 'FAIL' print(f'{x} {y} {z} {w_in} {cin} | {s} {carry} {cout} | {check}') break else: continue break else: continue break else: continue break