| import torch | |
| from safetensors.torch import load_file | |
def load_model(path='model.safetensors'):
    """Load the network parameters from a safetensors checkpoint.

    Args:
        path: Filesystem path of the ``.safetensors`` file to read;
            defaults to ``'model.safetensors'``.

    Returns:
        The tensor mapping produced by ``safetensors.torch.load_file``,
        keyed by tensor name (e.g. ``'y0.weight'``, ``'y0.bias'``).
    """
    state_dict = load_file(path)
    return state_dict
def winner_take_all(inputs, weights):
    """Evaluate a 4-output Winner-Take-All layer of hard-threshold neurons.

    Output ``i`` (for i in 0..3) is computed as::

        y_i = 1 if sum(inputs * weights['y{i}.weight']) + weights['y{i}.bias'] >= 0 else 0

    Whether the result is one-hot when exactly one input is active (and all
    zeros otherwise) depends entirely on the supplied weights; this function
    only evaluates the four units.

    Args:
        inputs: Sequence of numbers (typically 0/1). Must contain as many
            elements as each ``y{i}.weight`` tensor.
        weights: Mapping containing tensors ``'y{i}.weight'`` and
            ``'y{i}.bias'`` for i = 0..3 (e.g. the dict returned by
            ``load_model``). A weight may be 1-D or a single row such as
            shape (1, N); each bias must hold a single value.

    Returns:
        List of four ints, each 0 or 1.

    Raises:
        ValueError: If ``inputs`` and a weight tensor differ in element count.
        KeyError: If a required ``y{i}.weight`` / ``y{i}.bias`` entry is missing.
    """
    inp = torch.tensor([float(x) for x in inputs])
    outputs = []
    for i in range(4):
        # Flatten so a (1, N) row vector is treated the same as an N-vector.
        w = weights[f'y{i}.weight'].reshape(-1)
        # Without this check a length-1 input would silently broadcast
        # against the whole weight vector and yield a meaningless answer.
        if w.numel() != inp.numel():
            raise ValueError(
                f"y{i} expects {w.numel()} inputs, got {inp.numel()}"
            )
        activation = (inp * w).sum() + weights[f'y{i}.bias']
        outputs.append(int(activation >= 0))
    return outputs
if __name__ == '__main__':
    # Demo: load the stored weights, print a short banner, then run a fixed
    # set of input patterns through the Winner-Take-All layer.
    model_weights = load_model()
    banner = (
        '4-input Winner-Take-All',
        'Single active -> that input wins',
        'Multiple active -> no winner (tie)',
        '',
    )
    for text in banner:
        print(text)
    cases = [
        [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1],  # one active input
        [1, 1, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1],               # multiple or no active inputs
    ]
    for case in cases:
        result = winner_take_all(case, model_weights)
        print(f'{case} -> {result}')