File size: 915 Bytes
f5fb315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Load the network parameters from a safetensors checkpoint.

    Args:
        path: Filesystem path of the ``.safetensors`` file to read.

    Returns:
        The name -> tensor mapping produced by ``safetensors.torch.load_file``.
    """
    state_dict = load_file(path)
    return state_dict

def winner_take_all(inputs, weights, num_outputs=4):
    """Evaluate a Winner-Take-All layer built from threshold neurons.

    Output neuron ``i`` fires (1) when
    ``sum(inputs * y{i}.weight) + y{i}.bias >= 0`` and is silent (0)
    otherwise. The demo in this module expects the stored weights to give a
    one-hot result when exactly one input is active and no winner when several
    (or none) are active. That behaviour comes from the weights, not from
    this function.

    Args:
        inputs: Sequence of numeric (0/1 or bool) activations. Its length must
            equal the number of elements in every ``y{i}.weight`` tensor.
        weights: Mapping containing ``'y{i}.weight'`` and ``'y{i}.bias'``
            tensors for each ``i in range(num_outputs)``, e.g. the dict
            returned by ``load_model``.
        num_outputs: Number of output neurons to evaluate (default 4).

    Returns:
        list[int]: One 0/1 entry per output neuron, in index order.

    Raises:
        KeyError: If a required ``y{i}.weight`` / ``y{i}.bias`` key is missing.
        ValueError: If ``len(inputs)`` does not match a neuron's fan-in.
    """
    inp = torch.tensor([float(x) for x in inputs])
    outputs = []
    for i in range(num_outputs):
        # Flatten so both [n] and [1, n] (nn.Linear-style) layouts are accepted.
        w = weights[f'y{i}.weight'].reshape(-1)
        if w.numel() != inp.numel():
            # Without this check a length-1 input would silently broadcast
            # against the whole weight vector and yield a bogus activation.
            raise ValueError(
                f'y{i} expects {w.numel()} inputs, got {inp.numel()}'
            )
        pre_activation = (inp * w).sum() + weights[f'y{i}.bias']
        # Heaviside step: fire when the pre-activation is non-negative.
        outputs.append(int(pre_activation >= 0))
    return outputs

if __name__ == '__main__':
    # Demo: load the stored parameters, describe the expected behaviour,
    # then show the layer's response to a handful of input patterns.
    params = load_model()
    banner = (
        '4-input Winner-Take-All',
        'Single active -> that input wins',
        'Multiple active -> no winner (tie)',
        '',
    )
    for line in banner:
        print(line)

    single_active = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
    other_patterns = [[1, 1, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1]]
    for case in single_active + other_patterns:
        result = winner_take_all(case, params)
        print(f'{case} -> {result}')