import torch
from safetensors.torch import load_file


def load_model(path='model.safetensors'):
    """Load the neuron's tensors from a safetensors file.

    Args:
        path: Path to the ``.safetensors`` file.

    Returns:
        The dict of name -> tensor returned by ``safetensors.torch.load_file``.
    """
    return load_file(path)


def at_least_1_of_6(bits, weights):
    """Returns 1 if at least 1 of 6 inputs is high.

    Evaluates a single threshold neuron: the output is 1 when
    ``bits @ weight.T + bias >= 0`` and 0 otherwise.

    Args:
        bits: Sequence of input values, each convertible with ``float``.
            Its length must match the input dimension of
            ``weights['neuron.weight']`` (6 for this circuit).
        weights: Mapping containing ``'neuron.weight'`` and ``'neuron.bias'``
            tensors. Presumably weight is (1, 6) and bias is (1,) — TODO
            confirm; ``.item()`` below requires the result to hold exactly
            one element.

    Returns:
        int: 1 if the neuron fires, else 0.
    """
    weight = weights['neuron.weight']
    bias = weights['neuron.bias']
    # Match the stored weight's dtype/device: a default float32 CPU tensor
    # cannot be matmul'd against float64/float16/bfloat16 (or non-CPU) weights.
    inputs = torch.tensor(
        [float(b) for b in bits], dtype=weight.dtype, device=weight.device
    )
    activation = inputs @ weight.T + bias
    return int((activation >= 0).item())


if __name__ == '__main__':
    w = load_model()
    print('1-out-of-6 selected tests:')
    for hw in range(7):
        # Pattern with Hamming weight hw: the first hw of the 6 bits are set.
        bits = [1 if i < hw else 0 for i in range(6)]
        result = at_least_1_of_6(bits, w)
        expected = 1 if hw >= 1 else 0
        status = 'OK' if result == expected else 'FAIL'
        print(f' HW={hw}: {bits} -> {result} {status}')