import torch

try:
    from safetensors.torch import load_file
except ImportError:  # only load_model needs safetensors; priority_encode needs just torch
    load_file = None


def load_model(path='model.safetensors'):
    """Load the encoder's weight tensors from a safetensors file.

    Returns the name -> tensor dict produced by ``safetensors.torch.load_file``.
    ``priority_encode`` expects it to contain ``layer1.h0`` .. ``layer1.h7`` and
    ``layer2.y2`` / ``layer2.y1`` / ``layer2.y0`` / ``layer2.v`` weight and bias
    entries.

    Raises:
        ImportError: if the ``safetensors`` package is not installed.
    """
    if load_file is None:
        raise ImportError("safetensors is required to load model weights")
    return load_file(path)


def _fire(x, weight, bias):
    """Evaluate one threshold neuron: return 1 if x . weight + bias >= 0, else 0.

    ``weight`` may be stored as shape (n,) or (1, n). It is flattened so both
    layouts work without ``Tensor.T``, which PyTorch deprecates for non-2-D
    tensors. ``bias`` must hold exactly one element, otherwise ``.item()``
    raises.
    """
    z = x @ weight.reshape(-1) + bias.reshape(-1)
    return int((z >= 0).item())


def priority_encode(i7, i6, i5, i4, i3, i2, i1, i0, weights):
    """8-to-3 priority encoder built from two layers of threshold neurons.

    Args:
        i7..i0: input bits, passed most-significant first (anything ``float()``
            accepts; normally 0 or 1).
        weights: mapping of tensor name -> tensor, e.g. from ``load_model``.
            Layer 1 reads ``layer1.h{k}.weight`` / ``.bias`` for k = 0..7.
            Layer 2 reads ``layer2.{y2,y1,y0,v}.weight`` / ``.bias``.

    Returns:
        Tuple ``(y2, y1, y0, valid)`` of ints, each 0 or 1. Which input "wins"
        is determined entirely by the weights (presumably the highest-index
        set bit; TODO confirm against the trained model).

    Raises:
        KeyError: if a required weight or bias entry is missing.
    """
    x = torch.tensor([float(b) for b in (i7, i6, i5, i4, i3, i2, i1, i0)])

    # Layer 1: eight hidden threshold units over the raw input bits.
    hidden = torch.tensor([
        float(_fire(x, weights[f'layer1.h{k}.weight'], weights[f'layer1.h{k}.bias']))
        for k in range(8)
    ])

    # Layer 2: three output-code bits plus the "any input set" valid flag.
    y2, y1, y0, v = [
        _fire(hidden, weights[f'layer2.{name}.weight'], weights[f'layer2.{name}.bias'])
        for name in ('y2', 'y1', 'y0', 'v')
    ]
    return y2, y1, y0, v


if __name__ == '__main__':
    w = load_model()
    print('Priority Encoder 8 (selected tests)')
    for val in [0, 1, 2, 4, 8, 16, 32, 64, 128, 255]:
        # Split val into 8 bits, MSB first, so bit 7 feeds i7 and bit 0 feeds i0.
        inputs = [(val >> (7 - j)) & 1 for j in range(8)]
        y2, y1, y0, v = priority_encode(*inputs, w)
        idx = 4 * y2 + 2 * y1 + y0
        print(f' {val:3d} ({val:08b}) -> y={idx} ({y2}{y1}{y0}) v={v}')