File size: 676 Bytes
1575d4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Load the tensors stored in a safetensors file.

    Args:
        path: Location of the safetensors file, forwarded unchanged to
            ``safetensors.torch.load_file``. Defaults to
            ``'model.safetensors'`` (relative to the working directory).

    Returns:
        The object returned by ``load_file``; the rest of this file
        indexes it by tensor name (e.g. ``'y0.weight'``).
    """
    tensors = load_file(path)
    return tensors

def demux(d, s, weights):
    """1:2 Demultiplexer: routes d to y0 if s=0, to y1 if s=1.

    Each output is a single threshold neuron: it fires (1) when
    ``input @ weight.T + bias >= 0`` and stays off (0) otherwise.

    Args:
        d: Data bit; converted with ``float()``.
        s: Select bit; converted with ``float()``.
        weights: Mapping that provides the tensors ``'y0.weight'``,
            ``'y0.bias'``, ``'y1.weight'`` and ``'y1.bias'``. Each neuron's
            pre-activation must be a single-element tensor, because the
            result is read with ``.item()``.

    Returns:
        Tuple ``(y0, y1)`` of ints, each 0 or 1.

    Raises:
        KeyError: If one of the four expected tensors is missing.
    """
    def _fires(name):
        # Evaluate one threshold neuron and return its output as 0 or 1.
        weight = weights[f'{name}.weight']
        bias = weights[f'{name}.bias']
        # Build the input with the weight's dtype and device: matmul does not
        # type-promote, so a default float32 CPU input would raise for
        # weights stored as e.g. float64 or placed on another device.
        inp = torch.tensor(
            [float(d), float(s)], dtype=weight.dtype, device=weight.device
        )
        return int((inp @ weight.T + bias >= 0).item())

    return _fires('y0'), _fires('y1')

if __name__ == '__main__':
    # Script entry point: load the default weight file and print the
    # demultiplexer's output for every (data, select) combination.
    weights = load_model()
    print('DEMUX truth table:')
    # Same visiting order as nested loops: d is the outer bit, s the inner.
    for d, s in [(d, s) for d in (0, 1) for s in (0, 1)]:
        y0, y1 = demux(d, s, weights)
        print(f'DEMUX({d}, s={s}) = y0={y0}, y1={y1}')