"""Exactly-K-out-of-8 detector built from three threshold-logic gates.

The gate parameters are looked up by name in a weights mapping that must
provide ``atleast.weight``/``atleast.bias``, ``atmost.weight``/``atmost.bias``
and ``and.weight``/``and.bias``; ``load_model`` reads them from a
safetensors file.
"""
import torch


def load_model(path='model.safetensors'):
    """Load the detector's gate tensors from a safetensors file.

    Args:
        path: Path to the ``.safetensors`` file.

    Returns:
        A dict mapping tensor names to ``torch.Tensor`` values, as returned
        by ``safetensors.torch.load_file``.
    """
    # Imported lazily: exactlyK() only needs torch, so weights supplied from
    # any other source (e.g. an in-memory dict) work without safetensors.
    from safetensors.torch import load_file

    return load_file(path)


def _gate(inputs, weights, name):
    """Evaluate threshold gate ``name``: 1 if sum(inputs * w) + b >= 0, else 0."""
    activation = (inputs * weights[f'{name}.weight']).sum() + weights[f'{name}.bias']
    return int(activation >= 0)


def exactlyK(bits, weights):
    """Exactly-K-out-of-8 detector.

    K itself is not a parameter: it is encoded in the ``atleast``/``atmost``
    gate weights and biases.

    Args:
        bits: sequence of binary values, one per gate input (8 for the
            shipped model).
        weights: mapping with the ``atleast``, ``atmost`` and ``and`` gate
            ``.weight``/``.bias`` tensors.

    Returns:
        1 if the ``atleast`` and ``atmost`` gates both fire (i.e. exactly K
        bits are set for suitably trained weights), 0 otherwise.

    Raises:
        ValueError: if the number of bits does not match the number of
            ``atleast`` gate weights. Without this check a single bit would
            silently broadcast across all gate inputs.
    """
    inp = torch.tensor([float(b) for b in bits])
    expected = weights['atleast.weight'].numel()
    if inp.numel() != expected:
        raise ValueError(f'expected {expected} bits, got {inp.numel()}')

    atleast = _gate(inp, weights, 'atleast')
    atmost = _gate(inp, weights, 'atmost')
    # The final AND gate combines the two first-layer outputs.
    return _gate(torch.tensor([float(atleast), float(atmost)]), weights, 'and')


if __name__ == '__main__':
    w = load_model()
    print('ExactlyKOutOf8 Detector')
    # Sweep every Hamming weight 0..8 using the pattern 1...10...0.
    for hw in range(9):
        bits = [1] * hw + [0] * (8 - hw)
        result = exactlyK(bits, w)
        print(f'HW={hw}: {result}')