File size: 1,577 Bytes
2198727
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370ec48
 
2198727
 
 
370ec48
2198727
 
370ec48
2198727
 
 
370ec48
2198727
370ec48
2198727
370ec48
 
2198727
370ec48
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Load every tensor stored in a safetensors checkpoint.

    Args:
        path: Location of the ``.safetensors`` file. Defaults to
            ``'model.safetensors'``.

    Returns:
        The mapping of tensor name -> tensor produced by
        ``safetensors.torch.load_file``.
    """
    tensors = load_file(path)
    return tensors

def half_sub(a, b, w, prefix):
    """Evaluate a threshold-gate half subtractor on two input bits.

    Two sub-networks are read from ``w`` under ``prefix``:

    * ``xor``: two layers of step neurons (``xor.layer1`` feeds
      ``xor.layer2``) producing the difference bit.
    * ``borrow``: a single step neuron producing the borrow bit.

    The logical function computed depends entirely on the loaded weights.
    The names suggest ``a XOR b`` and ``(not a) and b``; this is not checked
    here.

    Args:
        a, b: Input bits; anything ``float()`` accepts (intended to be 0/1).
        w: Mapping of parameter name -> tensor (e.g. from ``load_model``).
        prefix: Key prefix selecting this half subtractor's parameters.

    Returns:
        ``(d, borrow)`` as Python floats, each either 0.0 or 1.0.
    """
    def neuron(x, layer):
        # Heaviside step: fire when the affine pre-activation is non-negative.
        weight = w[f'{prefix}.{layer}.weight']
        bias = w[f'{prefix}.{layer}.bias']
        return x @ weight.T + bias >= 0

    bits = torch.tensor([float(a), float(b)])
    hidden = neuron(bits, 'xor.layer1').float()
    diff = float(neuron(hidden, 'xor.layer2').item())
    borrow_out = float(neuron(bits, 'borrow').item())
    return diff, borrow_out

def full_sub(a, b, bin_in, w, prefix):
    """Evaluate a threshold-gate full subtractor built from two half subtractors.

    The stages are:

    1. The first half subtractor (``{prefix}.hs1``) combines ``a`` and ``b``.
    2. Its difference output goes, together with ``bin_in``, into the second
       half subtractor (``{prefix}.hs2``).
    3. The two intermediate borrows are merged by one step neuron
       (``{prefix}.bout``).

    Judging by the structure, that neuron is presumably an OR; this depends
    on the loaded weights.

    Args:
        a, b: Operand bits at this position (intended to be 0/1).
        bin_in: Borrow coming in from the less-significant position.
        w: Mapping of parameter name -> tensor.
        prefix: Key prefix selecting this full subtractor's parameters.

    Returns:
        ``(d, bout)`` as Python ints: the difference bit and the borrow-out.
    """
    partial_diff, borrow_first = half_sub(a, b, w, f'{prefix}.hs1')
    diff, borrow_second = half_sub(partial_diff, bin_in, w, f'{prefix}.hs2')
    gate_in = torch.tensor([borrow_first, borrow_second])
    # Step neuron over both intermediate borrows.
    fired = gate_in @ w[f'{prefix}.bout.weight'].T + w[f'{prefix}.bout.bias'] >= 0
    return int(diff), int(fired.item())

def less_than(a, b, weights):
    """8-bit less-than comparator.

    Ripples a borrow through eight chained full subtractors (parameters
    ``fs0`` .. ``fs7``), computing ``a - b`` bit by bit from the LSB.
    The final borrow-out is returned; for a correctly weighted subtractor
    this is 1 exactly when unsigned ``a`` is less than unsigned ``b``.
    The per-bit difference outputs are not needed and are discarded.

    Args:
        a, b: Sequences of exactly 8 bits each (LSB first).
        weights: Mapping of parameter name -> tensor (e.g. from
            ``load_model``) containing the ``fs0`` .. ``fs7`` parameters.

    Returns:
        1 if a < b, 0 otherwise.

    Raises:
        ValueError: If ``a`` or ``b`` is not exactly 8 bits long. Without
            this check, longer inputs would have their upper bits silently
            ignored (a wrong answer) and shorter ones would fail with an
            opaque ``IndexError``.
    """
    n_bits = 8
    if len(a) != n_bits or len(b) != n_bits:
        raise ValueError(
            f'less_than expects {n_bits}-bit operands, '
            f'got len(a)={len(a)}, len(b)={len(b)}'
        )
    borrow = 0  # no borrow into the least-significant position
    for i, (bit_a, bit_b) in enumerate(zip(a, b)):
        _, borrow = full_sub(bit_a, bit_b, borrow, weights, f'fs{i}')
    return borrow

if __name__ == '__main__':
    # Demo: push a fixed set of operand pairs through the comparator and
    # print each verdict.
    weights = load_model()
    print('8-bit LessThan Comparator')
    print('a < b tests:')
    cases = [
        (0, 0), (0, 1), (1, 0), (127, 128),
        (255, 0), (0, 255), (100, 100), (99, 100),
    ]
    for a_val, b_val in cases:
        # Decompose each operand into 8 bits, least-significant first.
        a_bits, b_bits = ([(v >> k) & 1 for k in range(8)] for v in (a_val, b_val))
        result = less_than(a_bits, b_bits, weights)
        print(f'{a_val:3d} < {b_val:3d} = {result}')