import torch from safetensors.torch import load_file def load_model(path='model.safetensors'): return load_file(path) def iszero8(bits, weights): inp = torch.tensor([float(b) for b in bits]) return int((inp @ weights['neuron.weight'].T + weights['neuron.bias'] >= 0).item()) if __name__ == '__main__': w = load_model() print('iszero8: outputs 1 only for input 00000000') for i in [0, 1, 128, 255]: bits = [(i >> j) & 1 for j in range(8)] print(f' {i:08b} -> {iszero8(bits, w)}')