| import torch | |
| from safetensors.torch import load_file | |
def load_model(path='model.safetensors'):
    """Load the network parameters from a safetensors file.

    Args:
        path: File to read. A relative path is resolved against the current
            working directory.

    Returns:
        The mapping of tensor name -> tensor produced by
        ``safetensors.torch.load_file``.
    """
    state = load_file(path)
    return state
def ctz8(bits, w):
    """Count trailing zeros in 8-bit input. bits[0] is LSB.

    Each neuron ``name`` in ``w`` is a step unit. It fires (1) when
    ``bits @ w[name + '.weight'].T + w[name + '.bias'] >= 0`` and is silent
    (0) otherwise. Neurons ``p0``..``p7`` and ``pZ`` are read.

    The way their outputs are combined assumes the following; confirm
    against the weight file:

    - ``p{i}`` fires exactly when bit ``i`` is the lowest set bit.
    - ``pZ`` fires exactly when every bit is zero.

    Args:
        bits: Iterable of 8 values (0/1). ``bits[0]`` is the LSB.
        w: Mapping of parameter name -> tensor, e.g. from ``load_model()``.

    Returns:
        Tuple ``(c3, c2, c1, c0)`` of ints, MSB first. The count is
        ``8*c3 + 4*c2 + 2*c1 + c0``.

    Raises:
        ValueError: if ``bits`` does not contain exactly 8 entries.
    """
    bits = list(bits)
    if len(bits) != 8:
        raise ValueError(f'ctz8 expects 8 bits, got {len(bits)}')
    inp = torch.tensor([float(b) for b in bits])

    def fire(name):
        # Evaluate one threshold neuron and return 0/1 as a Python int.
        weight = w[f'{name}.weight']
        # matmul rejects mixed dtypes/devices, so match the stored parameters
        # (otherwise float16/bfloat16/float64 checkpoints crash here).
        x = inp.to(dtype=weight.dtype, device=weight.device)
        return int((x @ weight.T + w[f'{name}.bias'] >= 0).item())

    p = [fire(f'p{i}') for i in range(8)]
    # Binary-encode the index of the firing p-neuron. Output bit k is the OR
    # of every p[i] whose index i has bit k set:
    # c0 <- 1,3,5,7; c1 <- 2,3,6,7; c2 <- 4,5,6,7.
    c0, c1, c2 = (
        int(any(p[i] for i in range(8) if (i >> k) & 1)) for k in range(3)
    )
    # c3 alone carries weight 8. Presumably pZ is the all-zero detector,
    # since CTZ(0) = 8 = 0b1000.
    c3 = fire('pZ')
    return c3, c2, c1, c0
if __name__ == '__main__':
    # Run ctz8 on a fixed set of inputs and compare each result with a
    # reference trailing-zero count.
    w = load_model()
    print('CTZ8 selected tests:')
    # Column widths (9 / 5 / 10 / 6 chars) match the separator row below.
    print('Input    | CTZ | c3c2c1c0 | check')
    print('---------+-----+----------+------')
    test_vals = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x00, 0x06, 0xF0, 0xFF]
    failures = 0
    for val in test_vals:
        bits = [(val >> j) & 1 for j in range(8)]  # LSB first, as ctz8 expects
        c3, c2, c1, c0 = ctz8(bits, w)
        count = 8*c3 + 4*c2 + 2*c1 + c0
        # Reference: index of the lowest set bit; CTZ(0) is defined as 8.
        expected = 8 if val == 0 else (val & -val).bit_length() - 1
        ok = count == expected
        failures += not ok
        status = 'ok' if ok else f'FAIL (expected {expected})'
        code = f'{c3}{c2}{c1}{c0}'
        print(f'{val:08b} | {count:>3} | {code:<8} | {status}')
    print(f'{len(test_vals) - failures}/{len(test_vals)} passed')
    if failures:
        raise SystemExit(1)