| import torch | |
| from safetensors.torch import load_file | |
def load_model(path='model.safetensors'):
    """Read the comparator's tensors from a safetensors checkpoint.

    Args:
        path: Filesystem location of the ``.safetensors`` file.

    Returns:
        The mapping produced by ``safetensors.torch.load_file`` for *path*
        (tensor name -> ``torch.Tensor``).
    """
    state = load_file(path)
    return state
def _neuron(x, weights, name):
    """Evaluate threshold unit *name*: 1 if ``x @ W.T + b >= 0`` else 0."""
    w = weights[f'{name}.weight']
    bias = weights[f'{name}.bias']
    # Match the checkpoint's dtype/device; a default-float32 CPU input would
    # otherwise make the matmul fail for float64 or non-CPU weights.
    x = x.to(device=w.device, dtype=w.dtype)
    # .item() requires the layer to produce exactly one output value.
    return int((x @ w.T + bias >= 0).item())


def compare8(a, b, weights):
    """8-bit magnitude comparator. Returns (GT, LT, EQ).

    The 16 input bits (a's bits MSB-first, followed by b's bits MSB-first)
    feed the 'gt' and 'lt' threshold units. The 'eq' unit sees only the
    two outputs [GT, LT].

    Args:
        a: Left operand, an integer in [0, 255].
        b: Right operand, an integer in [0, 255].
        weights: Mapping holding 'gt.weight', 'gt.bias', 'lt.weight',
            'lt.bias', 'eq.weight' and 'eq.bias' tensors. Presumably the
            gt/lt weights are (1, 16) and the eq weight is (1, 2); confirm
            against the checkpoint.

    Returns:
        Tuple ``(gt, lt, eq)`` of ints, each 0 or 1.

    Raises:
        ValueError: If *a* or *b* is outside [0, 255]. Only the low 8 bits
            are encoded, so such values would otherwise be compared silently
            as the wrong numbers (e.g. 256 as 0, -1 as 255).
    """
    for label, value in (('a', a), ('b', b)):
        if not 0 <= value <= 255:
            raise ValueError(f'{label} must be in [0, 255], got {value!r}')
    bits = [(v >> (7 - i)) & 1 for v in (a, b) for i in range(8)]
    inp = torch.tensor(bits, dtype=torch.float32)
    gt = _neuron(inp, weights, 'gt')
    lt = _neuron(inp, weights, 'lt')
    eq = _neuron(torch.tensor([float(gt), float(lt)]), weights, 'eq')
    return gt, lt, eq
if __name__ == '__main__':
    # Demo: push a few operand pairs through the comparator and print the
    # resulting relation alongside the raw output flags.
    weights = load_model()
    print('8-bit Magnitude Comparator:')
    pairs = [(0, 0), (1, 0), (0, 1), (127, 128), (255, 255), (200, 100)]
    for a, b in pairs:
        gt, lt, eq = compare8(a, b, weights)
        # GT takes precedence, then LT; '=' only when neither fired.
        if gt:
            rel = '>'
        elif lt:
            rel = '<'
        else:
            rel = '='
        print(f' {a:3d} {rel} {b:3d} (GT={gt}, LT={lt}, EQ={eq})')