# Uploaded by phanerozoic to Hugging Face ("Upload folder using huggingface_hub"),
# commit f5fb315 (verified).
import torch
from safetensors.torch import load_file
def load_model(path='model.safetensors'):
    """Read a safetensors checkpoint from disk.

    Args:
        path: Location of the ``.safetensors`` file to read.

    Returns:
        The mapping of tensor name -> tensor produced by
        ``safetensors.torch.load_file``.
    """
    state_dict = load_file(path)
    return state_dict
def winner_take_all(inputs, weights, *, n_outputs=4):
    """Evaluate threshold neurons ``y0 .. y{n_outputs-1}`` on *inputs*.

    Each output neuron ``y{i}`` fires (1) when the weighted sum of the inputs
    plus its bias is >= 0, and stays silent (0) otherwise. The circuit is
    intended to act as a Winner-Take-All: a one-hot output when exactly one
    input is active. That property comes from the weights supplied, not from
    this function.

    Args:
        inputs: Sequence of numeric (typically 0/1) input values.
        weights: Mapping holding ``y{i}.weight`` and ``y{i}.bias`` tensors for
            every output neuron, e.g. the dict returned by ``load_model``.
        n_outputs: Number of output neurons to evaluate. Defaults to 4, the
            size of the original circuit.

    Returns:
        List of ``n_outputs`` Python ints, each 0 or 1.

    Raises:
        KeyError: If a ``y{i}.weight`` / ``y{i}.bias`` entry is missing.
        ValueError: If a neuron's weight tensor does not have exactly one
            entry per input.
    """
    inp = torch.tensor([float(x) for x in inputs])

    def _fire(i):
        # One Heaviside threshold unit: step(w . x + b), with step(0) = 1.
        weight = weights[f'y{i}.weight']
        bias = weights[f'y{i}.bias']
        # Refuse mismatched sizes instead of letting torch broadcasting
        # silently stretch a short input across the weight vector.
        if weight.numel() != inp.numel():
            raise ValueError(
                f'y{i}.weight has {weight.numel()} entries but '
                f'{inp.numel()} inputs were given'
            )
        # flatten() makes (1, n), (n,) and (n, 1) weight layouts all compute
        # a plain dot product rather than an accidental outer-product sum.
        return int((inp * weight.flatten()).sum() + bias >= 0)

    return [_fire(i) for i in range(n_outputs)]
if __name__ == '__main__':
    # Demo: load weights from the default checkpoint path and print the
    # circuit's response to a fixed set of input patterns.
    weights = load_model()

    banner = (
        '4-input Winner-Take-All',
        'Single active -> that input wins',
        'Multiple active -> no winner (tie)',
        '',
    )
    for text in banner:
        print(text)

    # Each one-hot pattern, then a pair, all-off and all-on.
    cases = (
        [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1],
        [1, 1, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1],
    )
    for pattern in cases:
        print(f'{pattern} -> {winner_take_all(pattern, weights)}')