File size: 854 Bytes
2978b39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Load every tensor stored in a safetensors file.

    path: location of the .safetensors file; the default is resolved
        relative to the current working directory.
    Returns: dict mapping tensor names to torch tensors, exactly as
        produced by safetensors.torch.load_file.
    """
    state_dict = load_file(path)
    return state_dict

def _threshold_unit(x, weight, bias):
    """Evaluate one threshold neuron: 1 if sum(x * weight) + bias >= 0, else 0."""
    return int((x * weight).sum() + bias >= 0)


def exactlyK(bits, weights):
    """Exactly-K-out-of-N detector built from three threshold units.

    bits: sequence of binary values (0/1), one per input of the
        'atleast' unit (8 in this file's demo).
    weights: mapping holding the tensors 'atleast.weight', 'atleast.bias',
        'atmost.weight', 'atmost.bias', 'and.weight' and 'and.bias'
        (e.g. the dict returned by load_model()).
    Returns: 1 if exactly K bits are set, 0 otherwise. K is not a parameter
        of this function; it is determined by the biases in `weights`.
    Raises: ValueError if len(bits) does not match the width of the
        'atleast' weight vector. Without this check torch broadcasting
        would silently accept e.g. a single bit and return a wrong answer.
    """
    inp = torch.tensor([float(b) for b in bits])
    width = weights['atleast.weight'].shape[-1]
    if inp.numel() != width:
        raise ValueError(f'expected {width} bits, got {inp.numel()}')

    # First layer: two threshold units over the raw input bits.
    atleast = _threshold_unit(inp, weights['atleast.weight'], weights['atleast.bias'])
    atmost = _threshold_unit(inp, weights['atmost.weight'], weights['atmost.bias'])

    # Second layer: one threshold unit (keyed 'and') over the two
    # first-layer outputs.
    hidden = torch.tensor([float(atleast), float(atmost)])
    return _threshold_unit(hidden, weights['and.weight'], weights['and.bias'])

if __name__ == '__main__':
    # Demo: sweep every Hamming weight 0..8 over an 8-bit input and
    # print the detector's output for each one.
    model = load_model()
    print('ExactlyKOutOf8 Detector')
    for hw in range(8 + 1):
        # Test vector: `hw` leading ones followed by zeros.
        vector = [int(i < hw) for i in range(8)]
        result = exactlyK(vector, model)
        print(f'HW={hw}: {result}')