File size: 1,468 Bytes
3635a46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Read the tensors stored in a safetensors file.

    Args:
        path: Location of the safetensors file. Defaults to
            'model.safetensors', resolved against the current working
            directory.

    Returns:
        The object produced by ``safetensors.torch.load_file`` for ``path``;
        callers in this file index it with '<gate>.weight' / '<gate>.bias'
        string keys.
    """
    tensors = load_file(path)
    return tensors

def barrelshift4(a3, a2, a1, a0, s1, s0, weights):
    """Evaluate a 4-bit left barrel shifter built from threshold gates.

    The data word is a3 a2 a1 a0 (a3 first) and the shift amount is
    s = 2*s1 + s0 positions to the left.

    Args:
        a3, a2, a1, a0: Data bits; each is converted with ``float()``.
        s1, s0: Shift-select bits; each is converted with ``float()``.
        weights: Mapping indexed by '<gate>.weight' and '<gate>.bias' for
            every hidden gate and for the outputs 'y3', 'y2', 'y1', 'y0'.
            Each gate must reduce its input vector to a single element,
            because the result is read back with ``.item()``.
            NOTE(review): presumably weights are shaped (1, n_inputs) with
            a one-element bias -- confirm against the model file.

    Returns:
        A list of four ints ``[y3, y2, y1, y0]``, each 0 or 1.
    """
    def fire(signal, gate):
        # Step activation: the gate outputs 1 when the affine
        # pre-activation is non-negative, otherwise 0.
        pre_activation = signal @ weights[f'{gate}.weight'].T + weights[f'{gate}.bias']
        return int((pre_activation >= 0).item())

    # Input vector layout: four data bits followed by the two select bits.
    bits = torch.tensor([float(b) for b in (a3, a2, a1, a0, s1, s0)])

    # Hidden layer: ten gates, evaluated in this fixed order because the
    # output layer consumes their activations positionally.
    hidden_gates = (
        'a3_s00',
        'a2_s00', 'a2_s01',
        'a1_s00', 'a1_s01', 'a1_s10',
        'a0_s00', 'a0_s01', 'a0_s10', 'a0_s11',
    )
    hidden = torch.tensor([float(fire(bits, gate)) for gate in hidden_gates])

    # Output layer: one gate per result bit, most significant bit first.
    return [fire(hidden, gate) for gate in ('y3', 'y2', 'y1', 'y0')]

if __name__ == '__main__':
    # Demo: load the stored gate parameters and print a few sample shifts.
    model_weights = load_model()
    print('barrelshift4 examples:')
    # Each row is (a3, a2, a1, a0, s1, s0); results are truncated to 4 bits.
    examples = (
        (0, 0, 0, 1, 0, 0),  # 0001 << 0 = 0001
        (0, 0, 0, 1, 0, 1),  # 0001 << 1 = 0010
        (0, 0, 0, 1, 1, 0),  # 0001 << 2 = 0100
        (0, 0, 0, 1, 1, 1),  # 0001 << 3 = 1000
        (1, 0, 1, 0, 0, 1),  # 1010 << 1 = 0100 (top bit is shifted out)
    )
    for a3, a2, a1, a0, s1, s0 in examples:
        shifted = barrelshift4(a3, a2, a1, a0, s1, s0, model_weights)
        amount = 2 * s1 + s0
        operand = f'{a3}{a2}{a1}{a0}'
        # barrelshift4 returns [y3, y2, y1, y0]; print them MSB first.
        outcome = ''.join(str(bit) for bit in shifted)
        print(f'  {operand} << {amount} = {outcome}')