File size: 1,031 Bytes
96d9c09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Load the comparator's tensors from a safetensors checkpoint.

    Args:
        path: Location of the ``.safetensors`` file. Defaults to
            ``'model.safetensors'``, a relative path, so it is looked up in
            the current working directory.

    Returns:
        The result of ``safetensors.torch.load_file(path)``: a dict mapping
        tensor names to tensors.
    """
    state = load_file(path)
    return state

def compare8(a, b, weights):
    """8-bit magnitude comparator. Returns (GT, LT, EQ).

    Evaluates three threshold neurons stored in ``weights``:

    * ``gt`` and ``lt`` each read a 16-element input: the 8 bits of ``a``
      followed by the 8 bits of ``b``, most-significant bit first.
    * ``eq`` reads the two-element vector ``[gt, lt]``.

    A neuron outputs 1 when ``x @ W.T + bias >= 0``, and 0 otherwise.

    Args:
        a: First operand, an integer in ``[0, 255]``.
        b: Second operand, an integer in ``[0, 255]``.
        weights: Mapping with tensors under ``'gt.weight'``, ``'gt.bias'``,
            ``'lt.weight'``, ``'lt.bias'``, ``'eq.weight'`` and
            ``'eq.bias'``. Each layer is assumed to have a single output,
            because ``.item()`` is called on it. Shapes are not checked here.

    Returns:
        Tuple ``(gt, lt, eq)`` of ints, each 0 or 1.

    Raises:
        ValueError: If ``a`` or ``b`` is outside ``[0, 255]``.
    """
    for name, value in (('a', a), ('b', b)):
        # Without this check, the shift/mask below silently truncates large
        # values and sign-extends negative ones, producing a wrong answer.
        if not 0 <= value <= 255:
            raise ValueError(f'{name} must be in [0, 255], got {value!r}')

    # MSB-first bits of a, then MSB-first bits of b (16 inputs total).
    bits = [(a >> (7 - i)) & 1 for i in range(8)]
    bits += [(b >> (7 - i)) & 1 for i in range(8)]

    gt = _fire(bits, weights, 'gt')
    lt = _fire(bits, weights, 'lt')
    eq = _fire([gt, lt], weights, 'eq')
    return gt, lt, eq


def _fire(values, weights, layer):
    """Evaluate the threshold neuron ``layer`` on ``values`` and return 0 or 1."""
    w = weights[f'{layer}.weight']
    # Build the input in the weights' own dtype/device. A plain float32 CPU
    # tensor would make the matmul fail for float16/bfloat16/float64 or
    # non-CPU weights.
    x = torch.tensor([float(v) for v in values], dtype=w.dtype, device=w.device)
    return int((x @ w.T + weights[f'{layer}.bias'] >= 0).item())

if __name__ == '__main__':
    # Demo: run the comparator on a few operand pairs and print the relation
    # it reports alongside the raw GT/LT/EQ outputs.
    weights = load_model()
    print('8-bit Magnitude Comparator:')
    demo_pairs = (
        (0, 0),
        (1, 0),
        (0, 1),
        (127, 128),
        (255, 255),
        (200, 100),
    )
    for a, b in demo_pairs:
        gt, lt, eq = compare8(a, b, weights)
        # GT takes precedence over LT. '=' is shown only when neither fires;
        # EQ itself is not consulted when picking the symbol.
        if gt:
            rel = '>'
        elif lt:
            rel = '<'
        else:
            rel = '='
        print(f'  {a:3d} {rel} {b:3d}  (GT={gt}, LT={lt}, EQ={eq})')