import torch from safetensors.torch import load_file def load_model(path='model.safetensors'): return load_file(path) def mux2(d0, d1, s, weights): """2:1 Multiplexer: returns d0 if s=0, d1 if s=1.""" inp = torch.tensor([float(d0), float(d1), float(s)]) # Layer 1: Selection gates sel0 = int((inp @ weights['sel0.weight'].T + weights['sel0.bias'] >= 0).item()) sel1 = int((inp @ weights['sel1.weight'].T + weights['sel1.bias'] >= 0).item()) # Layer 2: OR gate l1 = torch.tensor([float(sel0), float(sel1)]) return int((l1 @ weights['or.weight'].T + weights['or.bias'] >= 0).item()) if __name__ == '__main__': w = load_model() print('MUX2 Truth Table:') print('s d0 d1 | out | expected') print('-' * 28) for s in [0, 1]: for d0 in [0, 1]: for d1 in [0, 1]: result = mux2(d0, d1, s, w) expected = d1 if s else d0 status = 'OK' if result == expected else 'FAIL' print(f'{s} {d0} {d1} | {result} | {expected} {status}')