File size: 1,103 Bytes
76f4679
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Read a safetensors checkpoint from disk.

    Args:
        path: Location of the ``.safetensors`` file. Defaults to
            ``'model.safetensors'``; relative paths are resolved against the
            current working directory.

    Returns:
        The result of ``safetensors.torch.load_file(path)``: a dict mapping
        tensor names to tensors.
    """
    state = load_file(path)
    return state

def _threshold_gate(x, weights, name):
    """Evaluate the threshold neuron stored under ``name`` on input ``x``.

    Reads ``weights[f'{name}.weight']`` and ``weights[f'{name}.bias']`` and
    returns 1 if ``x @ weight.T + bias >= 0``, else 0 (a Heaviside step).
    Both parameters are upcast to float32 to match ``x``, because
    ``torch.matmul`` does not promote mixed dtypes. A checkpoint saved as
    float64, half precision or integers would otherwise raise.
    """
    # NOTE(review): assumes an nn.Linear-style weight of shape (1, fan_in) and
    # a bias of shape (1,), so the result has exactly one element for .item()
    # — TODO confirm against the checkpoint.
    weight = weights[f'{name}.weight'].to(torch.float32)
    bias = weights[f'{name}.bias'].to(torch.float32)
    return int((x @ weight.T + bias >= 0).item())

def mux2(d0, d1, s, weights):
    """2:1 Multiplexer: returns d0 if s=0, d1 if s=1.

    Evaluated as a two-layer threshold circuit whose parameters come from
    ``weights`` (for example the dict returned by ``load_model``):

    * layer 1: gates ``sel0`` and ``sel1`` applied to the vector ``[d0, d1, s]``;
    * layer 2: gate ``or`` applied to ``[sel0, sel1]``.

    Args:
        d0: Data input that should be passed through when ``s`` is 0.
        d1: Data input that should be passed through when ``s`` is 1.
        s: Select line.
        weights: Mapping holding ``sel0``, ``sel1`` and ``or`` ``.weight`` /
            ``.bias`` tensors.

    Returns:
        The circuit output as an int, 0 or 1.
    """
    # Explicit dtype so the result does not depend on torch's global default.
    # The input order [d0, d1, s] is the one the layer-1 weights were built for.
    inp = torch.tensor([float(d0), float(d1), float(s)], dtype=torch.float32)

    # Layer 1: selection gates
    sel0 = _threshold_gate(inp, weights, 'sel0')
    sel1 = _threshold_gate(inp, weights, 'sel1')

    # Layer 2: OR gate
    hidden = torch.tensor([float(sel0), float(sel1)], dtype=torch.float32)
    return _threshold_gate(hidden, weights, 'or')

if __name__ == '__main__':
    # Exhaustively check the circuit against the MUX truth table, printing
    # one row per input combination.
    w = load_model()
    print('MUX2 Truth Table:')
    print('s  d0 d1 | out | expected')
    print('-' * 28)
    # Same row order as three nested loops: s outermost, d1 innermost.
    cases = [(s, d0, d1) for s in (0, 1) for d0 in (0, 1) for d1 in (0, 1)]
    failures = 0
    for s, d0, d1 in cases:
        result = mux2(d0, d1, s, w)
        expected = d1 if s else d0
        if result == expected:
            status = 'OK'
        else:
            status = 'FAIL'
            failures += 1
        print(f'{s}  {d0}  {d1}  |  {result}  |    {expected}     {status}')
    # Signal failure through the exit status so this self-check can gate CI.
    # SystemExit with a message prints it to stderr and exits with code 1.
    if failures:
        raise SystemExit(f'{failures} of {len(cases)} cases FAILED')