File size: 1,109 Bytes
f4de00e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Read the comparator's tensors from a safetensors file.

    Args:
        path: Filesystem location of the ``.safetensors`` checkpoint.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for ``path``
        (a mapping from tensor name to tensor).
    """
    state_dict = load_file(path)
    return state_dict

def _threshold_unit(x, weight, bias):
    """Evaluate one linear threshold unit: 1 if ``x @ weight.T + bias >= 0``, else 0.

    ``x`` is moved to ``weight``'s dtype and device first, because matmul does
    not type-promote (a float32 input against float64 / half weights raises).
    For float32 CPU weights the ``.to`` call is a no-op.
    ``.item()`` requires the unit to produce exactly one output.
    """
    x = x.to(dtype=weight.dtype, device=weight.device)
    return int((x @ weight.T + bias >= 0).item())

def compare4(a3, a2, a1, a0, b3, b2, b1, b0, weights):
    """4-bit magnitude comparator. Returns (GT, LT, EQ).

    Args:
        a3, a2, a1, a0: Bits of operand A, most significant first
            (expected to be 0/1; they are converted with ``float``).
        b3, b2, b1, b0: Bits of operand B, most significant first.
        weights: Mapping with tensors ``'gt.weight'``, ``'gt.bias'``,
            ``'lt.weight'``, ``'lt.bias'``, ``'eq.weight'`` and ``'eq.bias'``
            (e.g. the result of ``load_model``). Each weight/bias pair must
            describe a single-output unit.

    Returns:
        Tuple ``(gt, lt, eq)`` of Python ints, each 0 or 1.
    """
    # First layer: GT and LT each look at all eight input bits, ordered A then B.
    bits = torch.tensor([float(a3), float(a2), float(a1), float(a0),
                         float(b3), float(b2), float(b1), float(b0)])
    gt = _threshold_unit(bits, weights['gt.weight'], weights['gt.bias'])
    lt = _threshold_unit(bits, weights['lt.weight'], weights['lt.bias'])
    # Second layer: EQ is computed only from the two first-layer outputs.
    flags = torch.tensor([float(gt), float(lt)])
    eq = _threshold_unit(flags, weights['eq.weight'], weights['eq.bias'])
    return gt, lt, eq

if __name__ == '__main__':
    # Demo: run a handful of operand pairs through the comparator and print the flags.
    w = load_model()
    print('Comparator4bit examples:')
    demo_pairs = [(5, 3), (3, 5), (7, 7), (0, 15), (15, 0)]
    for lhs, rhs in demo_pairs:
        # Split each operand into its four bits, most significant bit first.
        lhs_bits = [(lhs >> shift) & 1 for shift in (3, 2, 1, 0)]
        rhs_bits = [(rhs >> shift) & 1 for shift in (3, 2, 1, 0)]
        gt, lt, eq = compare4(*lhs_bits, *rhs_bits, w)
        print(f'  A={lhs:2d}, B={rhs:2d} -> GT={gt}, LT={lt}, EQ={eq}')