# threshold-mux2 / model.py
# Uploaded by phanerozoic using huggingface_hub (commit 76f4679, verified).
import torch
from safetensors.torch import load_file
def load_model(path='model.safetensors'):
    """Load every tensor stored in a safetensors file.

    Args:
        path: Location of the ``.safetensors`` file. A relative path
            (including the default) is resolved against the current
            working directory.

    Returns:
        The mapping of tensor names to tensors produced by
        ``safetensors.torch.load_file``.
    """
    state_dict = load_file(path)
    return state_dict
def _fire(x, weights, name):
    """Evaluate the single threshold neuron ``name`` on input vector ``x``.

    Reads ``weights[name + '.weight']`` (W) and ``weights[name + '.bias']``
    (b) and returns ``1`` if ``x @ W.T + b >= 0`` else ``0``. ``.item()``
    requires the result to have exactly one element, so W is expected to
    hold a single output row.
    """
    w = weights[f'{name}.weight']
    b = weights[f'{name}.bias']
    # torch.matmul does not promote dtypes, so match the stored weights'
    # dtype/device instead of assuming default float32 on the CPU.
    x = x.to(dtype=w.dtype, device=w.device)
    return int((x @ w.T + b >= 0).item())


def mux2(d0, d1, s, weights):
    """2:1 Multiplexer: returns d0 if s=0, d1 if s=1.

    The circuit is two layers of step-activated threshold neurons. Which
    output is produced for a given input depends entirely on ``weights``.

    Args:
        d0: First data input (expected 0 or 1).
        d1: Second data input (expected 0 or 1).
        s: Select input (expected 0 or 1).
        weights: Mapping containing ``sel0``, ``sel1`` and ``or`` neuron
            parameters under ``'<name>.weight'`` / ``'<name>.bias'`` keys,
            e.g. as returned by ``load_model``.

    Returns:
        The circuit output as an ``int`` (0 or 1).
    """
    inp = torch.tensor([float(d0), float(d1), float(s)])
    # Layer 1: selection gates (presumably one gating d0 and one gating d1,
    # as determined by the trained weights).
    sel0 = _fire(inp, weights, 'sel0')
    sel1 = _fire(inp, weights, 'sel1')
    # Layer 2: OR gate over the two selection outputs.
    hidden = torch.tensor([float(sel0), float(sel1)])
    return _fire(hidden, weights, 'or')
if __name__ == '__main__':
    # Exhaustive self-check: evaluate the mux on every (s, d0, d1) input,
    # print one truth-table row per combination, and exit non-zero on any
    # mismatch so automated checks can detect a bad weight file.
    import itertools

    w = load_model()
    print('MUX2 Truth Table:')
    print('s d0 d1 | out | expected')
    print('-' * 28)
    failures = 0
    # product(..., repeat=3) varies d1 fastest and s slowest, matching the
    # original nested-loop order.
    for s, d0, d1 in itertools.product([0, 1], repeat=3):
        result = mux2(d0, d1, s, w)
        expected = d1 if s else d0
        status = 'OK' if result == expected else 'FAIL'
        if result != expected:
            failures += 1
        print(f'{s} {d0} {d1} | {result} | {expected} {status}')
    if failures:
        # SystemExit with a message prints it to stderr and exits with status 1.
        raise SystemExit(f'{failures} MUX2 truth-table row(s) failed')