| import torch | |
| from safetensors.torch import load_file | |
def load_model(path='model.safetensors'):
    """Read every tensor stored in a safetensors file.

    Args:
        path: Location of the ``.safetensors`` file. Defaults to
            ``'model.safetensors'``, resolved relative to the current
            working directory.

    Returns:
        The name -> tensor mapping produced by ``safetensors.torch.load_file``.
    """
    state = load_file(path)
    return state
def half_sub(a, b, w, prefix):
    """Evaluate one threshold-logic half subtractor on the bits ``a`` and ``b``.

    Every neuron computes ``1 if x @ W.T + bias >= 0 else 0`` using the
    parameters ``{prefix}.<name>.weight`` / ``{prefix}.<name>.bias`` from ``w``.

    Args:
        a, b: Input bits (anything ``float()`` accepts; expected to be 0 or 1).
        w: Mapping of parameter names to tensors (e.g. from ``load_model``).
        prefix: Key prefix selecting this half subtractor's parameters.

    Returns:
        ``(d, borrow)`` as Python floats. ``d`` is the output of the two-layer
        ``{prefix}.xor`` network and ``borrow`` is the output of the single
        ``{prefix}.borrow`` neuron.
    """
    def fire(x, name):
        # One threshold layer: 1 where x @ W.T + bias >= 0, else 0.
        weight = w[f'{prefix}.{name}.weight']
        bias = w[f'{prefix}.{name}.bias']
        # torch matmul does not type-promote. Bring the input to the stored
        # parameters' dtype/device so that float64/bf16/fp16 weights work too.
        x = x.to(dtype=weight.dtype, device=weight.device)
        return (x @ weight.T + bias >= 0).to(weight.dtype)

    inp = torch.tensor([float(a), float(b)])
    hidden = fire(inp, 'xor.layer1')
    d = float(fire(hidden, 'xor.layer2').item())
    borrow = float(fire(inp, 'borrow').item())
    return d, borrow
def full_sub(a, b, bin_in, w, prefix):
    """Evaluate one threshold-logic full-subtractor stage.

    Chains two half subtractors: ``{prefix}.hs1`` on ``(a, b)``, then
    ``{prefix}.hs2`` on the first difference and ``bin_in``. Their two borrow
    outputs are fed into the ``{prefix}.bout`` threshold neuron.

    Args:
        a, b: Minuend and subtrahend bits (expected to be 0 or 1).
        bin_in: Incoming borrow bit (expected to be 0 or 1).
        w: Mapping of parameter names to tensors.
        prefix: Key prefix selecting this stage's parameters.

    Returns:
        ``(d, bout)`` as ints: the difference bit and the outgoing borrow.
    """
    d1, b1 = half_sub(a, b, w, f'{prefix}.hs1')
    d, b2 = half_sub(d1, bin_in, w, f'{prefix}.hs2')
    weight = w[f'{prefix}.bout.weight']
    bias = w[f'{prefix}.bout.bias']
    # Build the input in the parameters' dtype/device. torch matmul does not
    # type-promote, so a float32 tensor against non-float32 weights would raise.
    borrows_in = torch.tensor([b1, b2], dtype=weight.dtype, device=weight.device)
    bout = int((borrows_in @ weight.T + bias >= 0).item())
    return int(d), bout
def less_than(a, b, weights):
    """8-bit less-than comparator.

    Ripples ``a - b`` through eight full-subtractor stages (parameter prefixes
    ``fs0`` .. ``fs7``, least significant first). It returns the borrow out of
    the last stage, which for a correct subtractor chain is 1 exactly when
    unsigned ``a`` is smaller than ``b``.

    Args:
        a, b: Sequences of 8 bits each (LSB first).
        weights: Mapping of parameter names to tensors.

    Returns:
        1 if a < b, 0 otherwise.

    Raises:
        ValueError: If ``a`` or ``b`` does not contain exactly 8 bits.
    """
    if len(a) != 8 or len(b) != 8:
        # Otherwise extra high bits would be silently ignored (wrong result)
        # and short inputs would die with an IndexError midway through.
        raise ValueError(f'expected 8 bits per operand, got {len(a)} and {len(b)}')
    borrow = 0  # no borrow into the least significant stage
    for i, (a_bit, b_bit) in enumerate(zip(a, b)):
        # The difference bit is not needed; only the borrow propagates.
        _, borrow = full_sub(a_bit, b_bit, borrow, weights, f'fs{i}')
    return borrow
if __name__ == '__main__':
    # Demo: load the weights and print the comparator's answer for a few pairs.
    params = load_model()
    print('8-bit LessThan Comparator')
    print('a < b tests:')

    def to_bits(value):
        # Low 8 bits of value, least significant bit first.
        return [(value >> pos) & 1 for pos in range(8)]

    cases = (
        (0, 0), (0, 1), (1, 0), (127, 128),
        (255, 0), (0, 255), (100, 100), (99, 100),
    )
    for lhs, rhs in cases:
        outcome = less_than(to_bits(lhs), to_bits(rhs), params)
        print(f'{lhs:3d} < {rhs:3d} = {outcome}')