| | import torch
|
| | from safetensors.torch import load_file
|
| |
|
def load_model(path='model.safetensors'):
    """Load the gate weights stored in a safetensors checkpoint.

    Args:
        path: location of the checkpoint file; defaults to
            'model.safetensors'.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for *path* — a
        mapping from tensor name to tensor, which the functions in this file
        index with keys such as 'cout.weight' or 'xor_xy.or.bias'.
    """
    state = load_file(path)
    return state
|
| |
|
def xor2(a, b, prefix, w):
    """Evaluate a two-layer threshold-neuron network on the bits (a, b).

    Layer 1 feeds ``(a, b)`` through the ``{prefix}.or`` and
    ``{prefix}.nand`` neurons. Layer 2 feeds their two outputs through the
    ``{prefix}.and`` neuron. A neuron outputs 1 when
    ``inputs @ weight.T + bias >= 0`` and 0 otherwise. With OR/NAND/AND
    weights this computes ``a XOR b``.

    Args:
        a, b: input bits (converted with ``float()``).
        prefix: key prefix selecting this gate's tensors in *w*.
        w: mapping from tensor name to torch tensor (e.g. from
            ``load_model``). Assumes each neuron yields a single output,
            since ``.item()`` is called on the comparison result.

    Returns:
        The AND neuron's output as an int (0 or 1).
    """
    def fire(name, bits):
        # Heaviside step applied to an affine map of the input bits.
        vec = torch.tensor([float(bit) for bit in bits])
        pre = vec @ w[f'{prefix}.{name}.weight'].T + w[f'{prefix}.{name}.bias']
        return int((pre >= 0).item())

    hidden = (fire('or', (a, b)), fire('nand', (a, b)))
    return fire('and', hidden)
|
| |
|
def compress_4to2(x, y, z, w_in, cin, weights):
    """4:2 compressor: x+y+z+w+cin = sum + 2*carry + 2*cout.

    Sum path: the five input bits are folded through a chain of ``xor2``
    networks using the prefixes 'xor_xy', 'xor_xyz', 'xor_xyzw' and
    'xor_sum'.

    Carry path: ``cout`` is one threshold neuron over (x, y, z), and
    ``carry`` is one threshold neuron over (x^y^z, w_in, cin). Each outputs
    1 when ``inputs @ weight.T + bias >= 0``. Whether they realise majority
    functions depends on the stored weights (presumably they do — confirm
    against the checkpoint).

    Args:
        x, y, z, w_in, cin: input bits.
        weights: mapping from tensor name to torch tensor (e.g. from
            ``load_model``).

    Returns:
        ``(sum_out, carry, cout)`` as ints.
    """
    def fire(name, bits):
        # One threshold neuron named `name` applied to `bits`.
        vec = torch.tensor([float(bit) for bit in bits])
        pre = vec @ weights[name + '.weight'].T + weights[name + '.bias']
        return int((pre >= 0).item())

    # Sum path: chained XOR networks, x ^ y ^ z ^ w_in ^ cin.
    parity3 = xor2(xor2(x, y, 'xor_xy', weights), z, 'xor_xyz', weights)
    sum_out = xor2(xor2(parity3, w_in, 'xor_xyzw', weights), cin, 'xor_sum', weights)

    # Carry path: both outputs carry weight 2 in the compressor identity.
    cout = fire('cout', (x, y, z))
    carry = fire('carry', (parity3, w_in, cin))

    return sum_out, carry, cout
|
| |
|
if __name__ == '__main__':
    # Script entry: exercise the compressor on a selection of inputs and
    # check the identity x+y+z+w+cin == sum + 2*carry + 2*cout.
    import itertools

    w = load_model()
    print('4:2 Compressor selected tests:')
    print('x y z w cin | sum carry cout | verify')
    print('------------+----------------+-------')
    failures = 0
    for total in range(6):
        for x in (0, 1):
            # For each (total, x) pair, test only the first (y, z, w_in, cin)
            # combination, in lexicographic order, that brings the input sum
            # to `total`. Skip the pair when no combination exists.
            # This replaces a nested for/else/continue/break pyramid that
            # selected the same rows.
            combo = next(
                (rest for rest in itertools.product((0, 1), repeat=4)
                 if x + sum(rest) == total),
                None,
            )
            if combo is None:
                continue
            y, z, w_in, cin = combo
            s, carry, cout = compress_4to2(x, y, z, w_in, cin, w)
            ok = total == s + 2*carry + 2*cout
            if not ok:
                failures += 1
            check = 'OK' if ok else 'FAIL'
            print(f'{x} {y} {z} {w_in} {cin} | {s} {carry} {cout} | {check}')
    # Signal failures through the exit status, not just the printed table.
    if failures:
        raise SystemExit(1)
|
| |
|