"""8-to-3 priority encoder built from threshold neurons whose parameters are
loaded from a safetensors file."""

import torch

# Which winner neurons feed each output bit: output bit k is driven by the
# winners whose index has bit k set.
_OUTPUT_TAPS = {
    'y0': (1, 3, 5, 7),
    'y1': (2, 3, 6, 7),
    'y2': (4, 5, 6, 7),
}

_NUM_INPUTS = 8


def load_model(path='model.safetensors'):
    """Load the encoder's neuron parameters.

    Args:
        path: Path to a safetensors file.

    Returns:
        Mapping from tensor name (e.g. ``'winner0.weight'``) to tensor, as
        returned by ``safetensors.torch.load_file``.
    """
    # Imported lazily: only this function needs safetensors, so the encoder
    # itself can be imported and exercised with a plain dict of tensors.
    from safetensors.torch import load_file

    return load_file(path)


def _fire(x, weights, name):
    """Evaluate threshold neuron ``name``: 1 if sum(x * w) + b >= 0, else 0."""
    pre_activation = (x * weights[f'{name}.weight']).sum() + weights[f'{name}.bias']
    return int(pre_activation >= 0)


def priority_encode(bits, weights):
    """8-to-3 priority encoder with valid bit.

    Returns (y2, y1, y0, valid) where y2y1y0 is the index of highest active
    input. Highest index (x7) has highest priority.

    Args:
        bits: Iterable of exactly 8 input values; element i is input x_i.
        weights: Mapping with ``winner{i}.weight``/``.bias`` (i = 0..7),
            ``y0``/``y1``/``y2`` and ``valid`` weight/bias tensors, e.g. the
            result of ``load_model``. Each ``winner{i}`` neuron is expected to
            fire only when x_i is the highest active input; that behaviour
            comes entirely from the weights, not from this code.

    Returns:
        Tuple of four ints (y2, y1, y0, valid), each 0 or 1.

    Raises:
        ValueError: If ``bits`` does not contain exactly 8 values.
    """
    bits = list(bits)
    if len(bits) != _NUM_INPUTS:
        # Without this check a 1-element input would silently broadcast
        # against the 8-wide weight vectors and yield a meaningless result.
        raise ValueError(
            f'expected {_NUM_INPUTS} input bits, got {len(bits)}'
        )
    inp = torch.tensor([float(b) for b in bits])

    # Layer 1: one "winner" neuron per input position.
    winners = [_fire(inp, weights, f'winner{i}') for i in range(_NUM_INPUTS)]

    # Layer 2: each output bit ORs together the winners that set it.
    outputs = {
        name: _fire(torch.tensor([float(winners[i]) for i in taps]), weights, name)
        for name, taps in _OUTPUT_TAPS.items()
    }

    # The valid bit looks directly at the raw inputs, not at the winners.
    valid = _fire(inp, weights, 'valid')
    return outputs['y2'], outputs['y1'], outputs['y0'], valid


def main():
    """Load the model and print the encoding of each one-hot input and of all zeros."""
    w = load_model()
    print('8-to-3 Priority Encoder (highest index wins)')
    print('Input -> Index, Valid')

    # Single-bit tests
    for i in range(_NUM_INPUTS):
        bits = [0] * _NUM_INPUTS
        bits[i] = 1
        y2, y1, y0, valid = priority_encode(bits, w)
        print(f'x{i} only: {y2*4 + y1*2 + y0}, {valid}')

    # No input
    y2, y1, y0, valid = priority_encode([0] * _NUM_INPUTS, w)
    print(f'None: {y2*4 + y1*2 + y0}, valid={valid}')


if __name__ == '__main__':
    main()