| | import torch
|
| | from safetensors.torch import load_file
|
| |
|
def load_model(path='model.safetensors'):
    """Load every tensor stored in a safetensors checkpoint.

    Args:
        path: Filesystem location of the ``.safetensors`` file.

    Returns:
        The mapping of tensor names to tensors produced by
        ``safetensors.torch.load_file``.
    """
    state = load_file(path)
    return state
|
| |
|
def _threshold_unit(x, weights, name):
    """Evaluate the threshold neuron ``name`` on input vector ``x``; return 0 or 1.

    Reads ``weights[f'{name}.weight']`` and ``weights[f'{name}.bias']`` and
    returns ``int(x @ W.T + b >= 0)``. The result must contain exactly one
    element, otherwise ``.item()`` raises.
    """
    w = weights[f'{name}.weight']
    b = weights[f'{name}.bias']
    # torch.matmul does not promote mixed dtypes, so match the stored
    # parameters (e.g. a float64 checkpoint against a float32 input).
    x = x.to(dtype=w.dtype, device=w.device)
    return int((x @ w.T + b >= 0).item())


def mux2(d0, d1, s, weights):
    """2:1 Multiplexer: returns d0 if s=0, d1 if s=1.

    Evaluates a two-layer threshold network read from ``weights``:

    * layer 1: neurons ``sel0`` and ``sel1`` on the input ``[d0, d1, s]``;
    * layer 2: neuron ``or`` on the layer-1 outputs ``[sel0, sel1]``.

    Whether the result actually matches the multiplexer truth table depends
    on the values stored in ``weights``. Presumably ``sel0`` gates d0 when
    s=0 and ``sel1`` gates d1 when s=1 (TODO confirm against the checkpoint).

    Args:
        d0, d1: Data inputs (anything convertible with ``float``).
        s: Select input.
        weights: Mapping holding ``sel0.weight``/``sel0.bias``,
            ``sel1.weight``/``sel1.bias`` and ``or.weight``/``or.bias`` tensors.

    Returns:
        0 or 1.
    """
    inp = torch.tensor([float(d0), float(d1), float(s)])
    sel0 = _threshold_unit(inp, weights, 'sel0')
    sel1 = _threshold_unit(inp, weights, 'sel1')
    hidden = torch.tensor([float(sel0), float(sel1)])
    return _threshold_unit(hidden, weights, 'or')
|
| |
|
if __name__ == '__main__':
    # Exhaustively compare the loaded network with the reference MUX2
    # behaviour over all 2**3 input combinations.
    w = load_model()
    print('MUX2 Truth Table:')
    print('s d0 d1 | out | expected')
    print('-' * 28)
    failures = 0
    for s in [0, 1]:
        for d0 in [0, 1]:
            for d1 in [0, 1]:
                result = mux2(d0, d1, s, w)
                expected = d1 if s else d0
                ok = result == expected
                if not ok:
                    failures += 1
                status = 'OK' if ok else 'FAIL'
                print(f'{s} {d0} {d1} | {result} | {expected} {status}')
    # Report mismatches through the exit status so CI / scripts can detect
    # a bad checkpoint instead of always exiting 0.
    if failures:
        raise SystemExit(1)
|
| |
|