File size: 1,179 Bytes
bf721a0
 
 
 
 
 
 
 
 
 
 
 
185e212
 
bf721a0
 
185e212
bf721a0
 
 
 
 
185e212
bf721a0
185e212
bf721a0
185e212
 
bf721a0
185e212
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Read all tensors stored in a safetensors file.

    Args:
        path: Location of the ``.safetensors`` file. Defaults to
            ``'model.safetensors'``; a relative path is resolved against the
            current working directory.

    Returns:
        The result of ``safetensors.torch.load_file(path)``: a mapping from
        tensor name to tensor. The gate functions in this module look up
        their parameters in it by key.
    """
    state = load_file(path)
    return state

def xor_gate(a, b, w, idx):
    """Evaluate the ``idx``-th two-layer threshold XOR network on two bits.

    Both layers are affine maps followed by a hard step: a unit outputs 1
    when its pre-activation is >= 0, otherwise 0.

    Args:
        a, b: Input bits; each is passed through ``float()``.
        w: Mapping of parameter name to tensor. Reads the keys
            ``xor{idx}.layer1.weight``, ``xor{idx}.layer1.bias``,
            ``xor{idx}.layer2.weight`` and ``xor{idx}.layer2.bias``.
        idx: Which XOR sub-network to use (selects the key prefix).

    Returns:
        0 or 1, the step output of the second layer. Assumes layer 2 has a
        single output unit, because ``.item()`` requires a one-element result.
    """
    x = torch.tensor([float(a), float(b)])

    # Hidden layer: affine transform, then a >= 0 step to {0., 1.}.
    pre_hidden = x @ w[f'xor{idx}.layer1.weight'].T + w[f'xor{idx}.layer1.bias']
    hidden = (pre_hidden >= 0).float()

    # Output layer: the same step on a single unit, converted to a Python int.
    pre_out = hidden @ w[f'xor{idx}.layer2.weight'].T + w[f'xor{idx}.layer2.bias']
    fired = (pre_out >= 0).item()
    return int(fired)

def equal(a, b, weights):
    """8-bit equality comparator built from threshold-logic gates.

    Each bit pair ``(a[i], b[i])`` is fed through its own XOR network
    (``xor_gate`` with index ``i``). The eight XOR outputs are then combined
    by one threshold unit (``nor.weight`` / ``nor.bias``) that is meant to
    fire only when every XOR output is 0, i.e. when all bits agree.

    Args:
        a, b: Sequences of exactly 8 bits each (LSB first).
        weights: Mapping of parameter name to tensor, e.g. as returned by
            ``load_model``.

    Returns:
        1 if a == b, 0 otherwise.

    Raises:
        ValueError: If ``a`` or ``b`` does not contain exactly 8 bits.
            Previously, extra bits were silently ignored (giving a wrong
            answer) and missing bits surfaced as an ``IndexError``.
    """
    width = 8
    for name, bits in (('a', a), ('b', b)):
        if len(bits) != width:
            raise ValueError(
                f'{name} must contain exactly {width} bits, got {len(bits)}'
            )

    # One XOR sub-network per bit position; gate i uses weights 'xor{i}.*'.
    xors = [xor_gate(x, y, weights, i) for i, (x, y) in enumerate(zip(a, b))]
    xor_vec = torch.tensor([float(x) for x in xors])
    # Final threshold unit: fires (>= 0) only when no bit position differs.
    return int((xor_vec @ weights['nor.weight'].T + weights['nor.bias'] >= 0).item())

if __name__ == '__main__':
    # Demo: run the comparator on a handful of hand-picked operand pairs.
    model_weights = load_model()
    print('8-bit Equal Comparator')
    print('a == b tests:')
    pairs = [
        (0, 0),
        (0, 1),
        (127, 127),
        (127, 128),
        (255, 255),
        (255, 0),
        (100, 100),
    ]
    for a_val, b_val in pairs:
        # Expand both operands into 8-bit lists, least-significant bit first.
        a_bits, b_bits = (
            [(operand >> bit) & 1 for bit in range(8)] for operand in (a_val, b_val)
        )
        result = equal(a_bits, b_bits, model_weights)
        print(f'{a_val:3d} == {b_val:3d} = {result}')