| import torch | |
| from safetensors.torch import load_file | |
def load_model(path='model.safetensors'):
    """Read the tensors stored in the safetensors file at *path*.

    Args:
        path: Location of the ``.safetensors`` file; defaults to
            ``'model.safetensors'``, resolved against the current
            working directory.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for *path*
        (a mapping from tensor name to tensor).
    """
    tensors = load_file(path)
    return tensors
def encode8to3(i7, i6, i5, i4, i3, i2, i1, i0, weights):
    """Evaluate the three threshold units of the 8-to-3 priority encoder.

    Each output bit is 1 when ``inputs @ W.T + b >= 0`` for its unit and 0
    otherwise. The resulting truth table is determined entirely by the
    tensors in *weights*; with the intended weights the outputs form the
    binary index of the highest-set input.

    Args:
        i7, i6, i5, i4, i3, i2, i1, i0: Input lines, each convertible with
            ``float()`` (normally 0 or 1). They are placed in the input
            vector in this order, ``i7`` first.
        weights: Mapping with the tensors ``'y2.weight'``, ``'y2.bias'``,
            ``'y1.weight'``, ``'y1.bias'``, ``'y0.weight'`` and
            ``'y0.bias'``. Every unit must reduce to a single element,
            because the result is read back with ``.item()``.

    Returns:
        Tuple ``(y2, y1, y0)`` of ints, each 0 or 1 (``y2`` is the most
        significant bit).
    """
    values = [float(bit) for bit in (i7, i6, i5, i4, i3, i2, i1, i0)]

    def fire(unit):
        weight = weights[f'{unit}.weight']
        # matmul does not promote dtypes, so a default float32 input would
        # raise for float64/integer checkpoints; build the input vector
        # with the weight's own dtype and device instead.
        inp = torch.tensor(values, dtype=weight.dtype, device=weight.device)
        return int((inp @ weight.T + weights[f'{unit}.bias'] >= 0).item())

    return fire('y2'), fire('y1'), fire('y0')
if __name__ == '__main__':
    # Demo: run the encoder on each one-hot input (bit 7 down to bit 0)
    # and finally on the all-ones input.
    w = load_model()
    print('8-to-3 Priority Encoder examples:')
    patterns = [1 << pos for pos in range(7, -1, -1)]
    patterns.append(0b11111111)
    for pattern in patterns:
        # Most significant bit first, matching the i7..i0 argument order.
        bits = [(pattern >> shift) & 1 for shift in range(7, -1, -1)]
        y2, y1, y0 = encode8to3(*bits, w)
        print(f' I={"".join(map(str,bits))} -> {y2}{y1}{y0} (={4*y2+2*y1+y0})')