import torch
from typing import Mapping, Tuple


def load_model(path='model.safetensors'):
    """Load the encoder's weight tensors from a safetensors file.

    Args:
        path: Path to the ``.safetensors`` file. A relative path is resolved
            against the current working directory.

    Returns:
        The mapping of tensor name -> tensor produced by
        ``safetensors.torch.load_file``.
    """
    # Imported lazily: only loading from disk needs safetensors, so
    # encode8to3 stays usable with in-memory weights without it installed.
    from safetensors.torch import load_file

    return load_file(path)


# Output threshold units, most-significant bit first. Each one is looked up
# in the weights mapping as "<name>.weight" and "<name>.bias".
_OUTPUT_UNITS = ('y2', 'y1', 'y0')


def _fire(inp: torch.Tensor, weights: Mapping[str, torch.Tensor], name: str) -> int:
    """Return 1 if threshold unit *name* fires on *inp* (w.x + b >= 0), else 0."""
    pre_activation = inp @ weights[f'{name}.weight'].T + weights[f'{name}.bias']
    # .item() requires the result to hold exactly one element, i.e. one
    # output neuron per unit.
    return int((pre_activation >= 0).item())


def encode8to3(i7, i6, i5, i4, i3, i2, i1, i0,
               weights: Mapping[str, torch.Tensor]) -> Tuple[int, int, int]:
    """Priority encoder: returns binary index of highest-set input.

    Each output bit is a single threshold unit: it is 1 when
    ``inp @ W.T + b >= 0``, where ``inp`` is the 8 inputs ordered i7..i0.
    Whether the result really is the priority-encoded index depends entirely
    on the supplied weights.

    Args:
        i7..i0: Input bits (anything ``float()`` accepts; normally 0/1 or bool),
            i7 being the highest-priority input.
        weights: Mapping containing 'y2.weight', 'y2.bias', 'y1.weight',
            'y1.bias', 'y0.weight', 'y0.bias'. Presumably each weight is
            shaped (1, 8) and each bias (1,) -- TODO confirm against the
            model file.

    Returns:
        ``(y2, y1, y0)`` as ints in {0, 1}, most-significant bit first.
    """
    ref = weights['y2.weight']
    # Match the weights' dtype/device: torch matmul rejects mixed dtypes, so a
    # hard-coded float32 input would fail for float64/int weights.
    inp = torch.tensor(
        [float(b) for b in (i7, i6, i5, i4, i3, i2, i1, i0)],
        dtype=ref.dtype,
        device=ref.device,
    )
    y2, y1, y0 = (_fire(inp, weights, name) for name in _OUTPUT_UNITS)
    return y2, y1, y0


def _main():
    """Print the encoder output for every one-hot input and for all-ones."""
    w = load_model()
    print('8-to-3 Priority Encoder examples:')
    examples = [0b10000000, 0b01000000, 0b00100000, 0b00010000,
                0b00001000, 0b00000100, 0b00000010, 0b00000001,
                0b11111111]
    for val in examples:
        # MSB-first bit list: bits[0] is i7, bits[7] is i0.
        bits = [(val >> (7 - i)) & 1 for i in range(8)]
        y2, y1, y0 = encode8to3(*bits, w)
        print(f' I={"".join(map(str,bits))} -> {y2}{y1}{y0} (={4*y2+2*y1+y0})')


if __name__ == '__main__':
    _main()