File size: 1,357 Bytes
7bf820e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import torch
from safetensors.torch import load_file
def load_model(path='model.safetensors'):
    """Load all tensors stored in the safetensors file at *path*.

    This is a thin wrapper around ``safetensors.torch.load_file``. The
    returned mapping is what ``priority_encode`` expects as ``weights``.
    """
    state = load_file(path)
    return state
def _threshold_neuron(x, weights, name):
    """Evaluate one threshold unit: return 1 if ``x . w + b >= 0`` else 0.

    ``w`` is ``weights[f'{name}.weight']`` and ``b`` is
    ``weights[f'{name}.bias']``. The weight is flattened, so a ``(1, N)``
    (nn.Linear-style), ``(N,)`` or ``(N, 1)`` layout all give the same dot
    product without relying on ``Tensor.T``, which is deprecated for
    non-2-D tensors. Both parameters are cast to ``x``'s dtype, because
    ``@`` does not type-promote. Without the cast, weights stored as
    float16/float64 would raise a dtype-mismatch error against the float
    input.
    """
    w = weights[f'{name}.weight'].reshape(-1).to(x.dtype)
    b = weights[f'{name}.bias'].to(x.dtype)
    return int((x @ w + b >= 0).item())


def priority_encode(i7, i6, i5, i4, i3, i2, i1, i0, weights):
    """8-to-3 priority encoder. Returns (y2, y1, y0, valid).

    The encoder is evaluated as two layers of threshold neurons.

    Args:
        i7 .. i0: Input bits, fed to layer 1 in the order i7 first, i0 last.
            Each value is converted with ``float()``.
        weights: Mapping from parameter name to tensor, for example the
            result of ``load_model()``. It must provide ``.weight`` and
            ``.bias`` entries for ``layer1.h0`` .. ``layer1.h7`` and for
            ``layer2.y2``, ``layer2.y1``, ``layer2.y0`` and ``layer2.v``.

    Returns:
        Tuple ``(y2, y1, y0, valid)`` of Python ints, each 0 or 1.
        ``4*y2 + 2*y1 + y0`` is read as the encoded index.
    """
    inp = torch.tensor([float(x) for x in (i7, i6, i5, i4, i3, i2, i1, i0)])
    # Layer 1: eight hidden threshold units h0..h7, kept in index order.
    hidden = torch.tensor(
        [float(_threshold_neuron(inp, weights, f'layer1.h{k}')) for k in range(8)]
    )
    # Layer 2: three code bits plus the valid flag, evaluated in return order.
    return tuple(
        _threshold_neuron(hidden, weights, f'layer2.{name}')
        for name in ('y2', 'y1', 'y0', 'v')
    )
if __name__ == '__main__':
    # Demo: run the encoder on zero, every single-bit input, and all-ones.
    params = load_model()
    print('Priority Encoder 8 (selected tests)')
    samples = [0] + [1 << s for s in range(8)] + [255]
    for val in samples:
        # Most-significant bit first, matching the (i7, ..., i0) argument order.
        bits = [(val >> shift) & 1 for shift in range(7, -1, -1)]
        y2, y1, y0, v = priority_encode(*bits, params)
        code = (y2 << 2) | (y1 << 1) | y0
        print(f' {val:3d} ({val:08b}) -> y={code} ({y2}{y1}{y0}) v={v}')
|