File size: 1,357 Bytes
7bf820e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Load the encoder's parameter tensors from a safetensors file.

    Args:
        path: Location of the ``.safetensors`` checkpoint. The default is a
            relative path, so it is looked up from the current working
            directory.

    Returns:
        The name -> tensor dictionary produced by
        ``safetensors.torch.load_file``.
    """
    state_dict = load_file(path)
    return state_dict

def _threshold_unit(x, weights, name):
    """Evaluate one Heaviside threshold neuron on the 1-D input vector ``x``.

    Returns ``1`` if ``dot(x, w) + b >= 0`` else ``0``, where ``w`` and ``b``
    are ``weights[f'{name}.weight']`` and ``weights[f'{name}.bias']``.

    The weight is flattened, so a ``(1, N)`` row (the layout the previous
    ``x @ w.T`` form required), a plain ``(N,)`` vector, or an ``(N, 1)``
    column are all accepted. This avoids ``Tensor.T``, which PyTorch
    deprecates for non-2-D tensors. Both parameters are cast to ``x``'s
    dtype/device, because ``matmul``/``dot`` do not type-promote. Without
    the cast, a checkpoint stored as float16 or float64 would raise a
    dtype-mismatch error.
    """
    w = weights[f'{name}.weight'].to(device=x.device, dtype=x.dtype).flatten()
    b = weights[f'{name}.bias'].to(device=x.device, dtype=x.dtype)
    # The bias is expected to hold a single element (shape () or (1,)),
    # so the comparison yields one bool and .item() can extract it.
    return int((torch.dot(x, w) + b >= 0).item())


def priority_encode(i7, i6, i5, i4, i3, i2, i1, i0, weights):
    """8-to-3 priority encoder. Returns (y2, y1, y0, valid).

    The encoder is evaluated as a two-layer network of threshold neurons:
    eight hidden units ``layer1.h0`` .. ``layer1.h7`` read the inputs, and
    four output units ``layer2.y2/y1/y0/v`` read the hidden layer.

    Args:
        i7, i6, i5, i4, i3, i2, i1, i0: The eight input lines, fed to the
            network in this order. Each is converted with ``float()``, so
            0/1 ints or bools are expected.
        weights: Mapping of parameter names to tensors, e.g. the result of
            ``load_model()``. It must contain ``layer1.h{k}.weight`` /
            ``.bias`` for k in 0..7 and ``layer2.{y2,y1,y0,v}.weight`` /
            ``.bias``.

    Returns:
        Tuple ``(y2, y1, y0, valid)`` of Python ints in {0, 1}, with
        ``y2`` the most significant index bit. Which input wins is decided
        entirely by the weights. NOTE(review): presumably i7 has the
        highest priority and ``valid`` is 0 only when no input is set;
        confirm against the trained model.
    """
    inp = torch.tensor([float(x) for x in (i7, i6, i5, i4, i3, i2, i1, i0)])
    # Layer 1: hidden units h0..h7, stacked in index order for layer 2.
    hidden = torch.tensor(
        [float(_threshold_unit(inp, weights, f'layer1.h{k}')) for k in range(8)]
    )
    # Layer 2: the three index bits and the valid flag, all fed by `hidden`.
    y2, y1, y0, v = (
        _threshold_unit(hidden, weights, f'layer2.{out}')
        for out in ('y2', 'y1', 'y0', 'v')
    )
    return y2, y1, y0, v

if __name__ == '__main__':
    # Smoke run: the all-zero vector, every one-hot input (bit 0 .. bit 7)
    # and the all-ones vector, printing the decoded index beside the raw bits.
    weights = load_model()
    print('Priority Encoder 8 (selected tests)')
    samples = [0] + [1 << bit for bit in range(8)] + [255]
    for val in samples:
        # Unpack `val` MSB-first, i.e. in the i7 .. i0 argument order.
        bits = [(val >> shift) & 1 for shift in range(7, -1, -1)]
        y2, y1, y0, v = priority_encode(*bits, weights)
        idx = (y2 << 2) | (y1 << 1) | y0
        print(f'  {val:3d} ({val:08b}) -> y={idx} ({y2}{y1}{y0}) v={v}')