# threshold-equals2 / model.py
# CharlesCNorton
# 2-bit equality comparator, magnitude 18
# 41edcb6
import torch
from safetensors.torch import load_file
def load_model(path='model.safetensors'):
    """Read the comparator's weight tensors from a safetensors file.

    Args:
        path: Location of the ``.safetensors`` file; defaults to
            ``model.safetensors`` relative to the current working directory.

    Returns:
        The mapping of tensor names to torch tensors produced by
        ``safetensors.torch.load_file``; ``equals2`` expects keys of the form
        ``layerN.<neuron>.weight`` / ``layerN.<neuron>.bias``.
    """
    weights = load_file(path)
    return weights
def _fire(x, w, name):
    """Evaluate one threshold neuron: 1 if ``sum(x * weight) + bias >= 0``, else 0.

    ``name`` is the neuron's key prefix in ``w`` (e.g. ``'layer1.and1'``); its
    parameters are read from ``w[name + '.weight']`` and ``w[name + '.bias']``.
    """
    return int((x * w[f'{name}.weight']).sum() + w[f'{name}.bias'] >= 0)


def equals2(a1, a0, b1, b0, w):
    """Run the 3-layer threshold-gate equality circuit on two 2-bit numbers.

    Args:
        a1, a0: High and low bit of operand A (the ``__main__`` driver passes
            ``(a >> 1) & 1`` and ``a & 1``).
        b1, b0: High and low bit of operand B, in the same encoding.
        w: Mapping of neuron parameters, as returned by ``load_model``. Each
            neuron ``<layer>.<gate>`` needs ``.weight`` and ``.bias`` tensors.

    Returns:
        int 0 or 1: the output of the final ``layer3.eq`` neuron. The circuit
        is meant to output 1 exactly when A == B (the ``__main__`` truth table
        prints it under ``a == b?``). The actual result depends on the loaded
        weights.
    """
    x = torch.tensor([float(a1), float(a0), float(b1), float(b0)])
    # Hidden layers are named after the gate each neuron is meant to realise;
    # the tuple order fixes the input order seen by the next layer.
    h1 = torch.tensor(
        [float(_fire(x, w, f'layer1.{n}')) for n in ('and1', 'nor1', 'and0', 'nor0')]
    )
    h2 = torch.tensor(
        [float(_fire(h1, w, f'layer2.{n}')) for n in ('xnor1', 'xnor0')]
    )
    return _fire(h2, w, 'layer3.eq')
if __name__ == '__main__':
    # Print the full truth table: every pair (a, b) of 2-bit values, a-major.
    weights = load_model()
    print('equals2 truth table:')
    for idx in range(16):
        a, b = divmod(idx, 4)
        # divmod(v, 2) splits a 2-bit value into (high bit, low bit).
        a1, a0 = divmod(a, 2)
        b1, b0 = divmod(b, 2)
        print(f' {a} == {b}? {equals2(a1, a0, b1, b0, weights)}')