phanerozoic's picture
Upload folder using huggingface_hub
6529868 verified
import torch
from safetensors.torch import load_file
def load_model(path='model.safetensors'):
    """Load the stored weights from a safetensors file.

    Args:
        path: Location of the safetensors file; defaults to
            ``'model.safetensors'`` in the current working directory.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for ``path``; the
        other functions in this file index it by parameter-name strings.
    """
    tensors = load_file(path)
    return tensors
def xor2(a, b, prefix, w):
    """Evaluate the two-layer threshold network stored under ``prefix``.

    The first layer runs the ``or`` and ``nand`` units on ``(a, b)``; the
    second layer runs the ``and`` unit on those two results. When the stored
    weights realise the gates they are named after, the result is ``a XOR b``.

    Args:
        a: First input bit (anything ``float()`` accepts).
        b: Second input bit (anything ``float()`` accepts).
        prefix: Key prefix selecting the gate group, e.g. ``'xor1'``.
        w: Mapping holding ``'{prefix}.{gate}.weight'`` and
            ``'{prefix}.{gate}.bias'`` tensors for ``or``, ``nand``, ``and``.

    Returns:
        ``1`` or ``0`` as a Python ``int``.
    """

    def fire(signal, gate):
        # Step-activation unit: emits 1 when weighted sum + bias is >= 0.
        pre_activation = (
            signal @ w[f'{prefix}.{gate}.weight'].T + w[f'{prefix}.{gate}.bias']
        )
        return int((pre_activation >= 0).item())

    inputs = torch.tensor([float(a), float(b)])
    # Hidden layer is evaluated in the fixed order OR, then NAND.
    hidden = torch.tensor([float(fire(inputs, gate)) for gate in ('or', 'nand')])
    return fire(hidden, 'and')
def compress_3to2(x, y, z, weights):
    """Run the 3:2 compressor network on three input bits.

    The sum bit is produced by chaining two ``xor2`` evaluations (gate groups
    ``'xor1'`` and ``'xor2'``); the carry bit comes from a single threshold
    unit stored under ``'maj.weight'`` / ``'maj.bias'``. The intended contract
    is ``x + y + z == sum + 2 * carry``, which holds when the stored weights
    implement those gates.

    Args:
        x: First input bit.
        y: Second input bit.
        z: Third input bit.
        weights: Mapping of parameter names to tensors.

    Returns:
        A ``(sum, carry)`` tuple of Python ints.
    """
    partial = xor2(x, y, 'xor1', weights)
    sum_bit = xor2(partial, z, 'xor2', weights)

    # Carry: one step-activation unit applied to all three inputs at once.
    bits = torch.tensor([float(bit) for bit in (x, y, z)])
    fired = bits @ weights['maj.weight'].T + weights['maj.bias'] >= 0
    carry_bit = int(fired.item())
    return sum_bit, carry_bit
if __name__ == '__main__':
    # Script entry point: load the weights from the default path, then print
    # the full truth table and check each row against x + y + z = sum + 2*carry.
    w = load_model()
    for banner in (
        '3:2 Compressor Truth Table:',
        'x y z | sum carry | verify',
        '------+-----------+-------',
    ):
        print(banner)
    # Counting 0..7 and peeling off bits visits (x, y, z) in the order
    # 000, 001, 010, ..., 111 with x as the most significant bit.
    for row in range(8):
        x, y, z = (row >> 2) & 1, (row >> 1) & 1, row & 1
        s, c = compress_3to2(x, y, z, w)
        if (x + y + z) == (s + 2 * c):
            check = 'OK'
        else:
            check = 'FAIL'
        print(f'{x} {y} {z} | {s} {c} | {check}')