"""Evaluate a 4-input Winner-Take-All circuit built from threshold neurons.

Each output neuron ``y{i}`` has a weight tensor (``y{i}.weight``) and a bias
(``y{i}.bias``) loaded from a safetensors file; it fires (outputs 1) when
``sum(inputs * weight) + bias >= 0``.
"""
from typing import Dict, List, Mapping, Sequence

import torch


def load_model(path: str = 'model.safetensors') -> Dict[str, torch.Tensor]:
    """Load the circuit's weight tensors from a safetensors file.

    Args:
        path: Path to the ``.safetensors`` file.

    Returns:
        Mapping from tensor name (e.g. ``'y0.weight'``, ``'y0.bias'``) to tensor,
        as returned by ``safetensors.torch.load_file``.
    """
    # Imported lazily so winner_take_all() can be used (and tested) without
    # safetensors installed; only loading weights from disk needs it.
    from safetensors.torch import load_file

    return load_file(path)


def winner_take_all(
    inputs: Sequence[float],
    weights: Mapping[str, torch.Tensor],
    num_outputs: int = 4,
) -> List[int]:
    """Evaluate each threshold neuron ``y0 .. y{num_outputs-1}`` on ``inputs``.

    Neuron ``i`` outputs 1 when ``sum(inputs * weights['y{i}.weight'])
    + weights['y{i}.bias'] >= 0`` and 0 otherwise. The circuit is intended to
    produce a one-hot output when exactly one input is active, but the actual
    outputs depend entirely on the supplied weights.

    Args:
        inputs: Input activations (typically 0/1); each is converted to float.
        weights: Tensor mapping containing ``y{i}.weight`` and ``y{i}.bias`` for
            every neuron index ``i < num_outputs``.
        num_outputs: Number of output neurons to evaluate (default 4).

    Returns:
        One int (0 or 1) per neuron, in index order.

    Raises:
        KeyError: If a required ``y{i}.weight`` / ``y{i}.bias`` entry is missing.
        ValueError: If the number of inputs does not match a neuron's weight
            count (which would otherwise broadcast silently, e.g. for a
            single-element input).
    """
    values = [float(x) for x in inputs]
    outputs: List[int] = []
    for i in range(num_outputs):
        weight = weights[f'y{i}.weight']
        bias = weights[f'y{i}.bias']
        if weight.numel() != len(values):
            raise ValueError(
                f'neuron y{i} expects {weight.numel()} inputs, got {len(values)}'
            )
        # Build the input vector on the weight's device so that weights living
        # off the CPU do not cause a device-mismatch error.
        x = torch.tensor(values, device=weight.device)
        outputs.append(int((x * weight).sum() + bias >= 0))
    return outputs


def main() -> None:
    """Load ``model.safetensors`` and print the circuit's response to sample patterns."""
    w = load_model()
    print('4-input Winner-Take-All')
    print('Single active -> that input wins')
    print('Multiple active -> no winner (tie)')
    print()
    tests = [
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
        [1, 1, 0, 0],
        [0, 0, 0, 0],
        [1, 1, 1, 1],
    ]
    for inputs in tests:
        outputs = winner_take_all(inputs, w)
        print(f'{inputs} -> {outputs}')


if __name__ == '__main__':
    main()