| import torch | |
| from safetensors.torch import load_file | |
def load_model(path='model.safetensors'):
    """Read a safetensors checkpoint and return its tensors.

    Args:
        path: location of the ``.safetensors`` file to read.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for ``path``
        (a mapping of parameter name to tensor).
    """
    tensors = load_file(path)
    return tensors
def _fire(inputs, weight, bias):
    """Evaluate one threshold (step-activation) neuron.

    Returns 1 if ``sum(inputs * weight) + bias >= 0``, else 0.

    Raises:
        ValueError: if the number of inputs differs from the number of
            weights. This is checked explicitly because torch broadcasting
            would otherwise silently stretch a single input across every
            weight and produce a wrong answer instead of an error.
    """
    if inputs.numel() != weight.numel():
        raise ValueError(
            f'expected {weight.numel()} inputs, got {inputs.numel()}'
        )
    return int((inputs * weight).sum() + bias >= 0)


def exactlyK(bits, weights):
    """Exactly-K-out-of-8 detector built from three threshold neurons.

    K is not a parameter of this function. It is determined by the loaded
    weights and biases of the ``atleast`` and ``atmost`` neurons, which are
    presumably tuned to "count >= K" and "count <= K" respectively. Their
    two outputs are then combined by the ``and`` neuron.

    Args:
        bits: sequence of binary values (0/1 or bools). Its length must
            equal the number of ``atleast``/``atmost`` weights (8 for this
            detector).
        weights: mapping of parameter name to tensor, such as the one
            returned by ``load_model``. It must contain ``atleast.weight``,
            ``atleast.bias``, ``atmost.weight``, ``atmost.bias``,
            ``and.weight`` and ``and.bias``.

    Returns:
        1 if exactly K bits are set, 0 otherwise (given weights that
        implement the at-least / at-most / AND layers described above).

    Raises:
        ValueError: if ``len(bits)`` does not match the neurons' fan-in.
    """
    inp = torch.tensor([float(b) for b in bits])
    atleast = _fire(inp, weights['atleast.weight'], weights['atleast.bias'])
    atmost = _fire(inp, weights['atmost.weight'], weights['atmost.bias'])
    # Second layer: AND the two first-layer decisions together.
    hidden = torch.tensor([float(atleast), float(atmost)])
    return _fire(hidden, weights['and.weight'], weights['and.bias'])
if __name__ == '__main__':
    weights = load_model()
    print('ExactlyKOutOf8 Detector')
    # One representative input per Hamming weight 0..8: the first
    # `hamming` positions are set, the remaining ones are clear.
    for hamming in range(8 + 1):
        pattern = [1 if pos < hamming else 0 for pos in range(8)]
        print(f'HW={hamming}: {exactlyK(pattern, weights)}')