File size: 1,082 Bytes
5937f23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Read every tensor stored in the safetensors file at *path*.

    Args:
        path: Location of the checkpoint; a relative path (including the
            default ``'model.safetensors'``) resolves against the current
            working directory.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for *path*
        (a mapping of tensor name -> tensor).
    """
    tensors = load_file(path)
    return tensors

def encode8to3(i7, i6, i5, i4, i3, i2, i1, i0, weights):
    """Priority encoder: returns binary index of highest-set input.

    Each output bit is a single threshold unit evaluated as
    ``step(x @ W.T + b)``, where ``step(z) = 1`` iff ``z >= 0`` and
    ``x = [i7, i6, i5, i4, i3, i2, i1, i0]`` (i7 first).

    Args:
        i7, i6, i5, i4, i3, i2, i1, i0: Input lines; any value ``float()``
            accepts (``0``/``1``, ``bool``, ...).
        weights: Mapping holding ``'<name>.weight'`` and ``'<name>.bias'``
            tensors for each ``name`` in ``('y2', 'y1', 'y0')``, e.g. the dict
            returned by ``load_model``. Presumably each weight is ``(1, 8)``
            and each bias ``(1,)`` -- TODO confirm against the checkpoint.

    Returns:
        Tuple ``(y2, y1, y0)`` of Python ints in ``{0, 1}``, most significant
        bit first. Whether this equals the true priority index depends on the
        supplied weights.
    """
    inp = torch.tensor([float(i7), float(i6), float(i5), float(i4),
                        float(i3), float(i2), float(i1), float(i0)])

    def fire(name):
        """Evaluate threshold unit *name* on ``inp`` and return 0 or 1."""
        w = weights[f'{name}.weight']
        b = weights[f'{name}.bias']
        # Checkpoints may be stored as half/bfloat16/float64 or live on another
        # device; matmul needs matching dtypes and devices, so compute in the
        # wider of the input/weight dtypes on the weight's device.
        dtype = torch.promote_types(inp.dtype, w.dtype)
        x = inp.to(device=w.device, dtype=dtype)
        # .t() transposes a 2-D weight and leaves a 1-D weight unchanged,
        # without the deprecation warning that .T emits for non-2-D tensors.
        z = x @ w.to(dtype).t() + b.to(dtype)
        return int((z >= 0).item())

    return tuple(fire(name) for name in ('y2', 'y1', 'y0'))

if __name__ == '__main__':
    # Demo: feed each one-hot input (i7 down to i0), then all ones, through
    # the encoder and print the resulting 3-bit code and its decimal value.
    weights = load_model()
    print('8-to-3 Priority Encoder examples:')
    samples = [1 << shift for shift in range(7, -1, -1)] + [0xFF]
    for sample in samples:
        bit_list = [int(ch) for ch in format(sample, '08b')]  # MSB (i7) first
        hi, mid, lo = encode8to3(*bit_list, weights)
        pattern = ''.join(str(b) for b in bit_list)
        code = (hi << 2) | (mid << 1) | lo
        print(f'  I={pattern} -> {hi}{mid}{lo} (={code})')