|
|
from typing import Mapping

import torch
from safetensors.torch import load_file
|
|
|
|
|
def load_model(path: str = 'model.safetensors'):
    """Load the network parameters from a safetensors file.

    Thin wrapper around ``safetensors.torch.load_file``.

    Args:
        path: Location of the ``.safetensors`` file. Defaults to
            ``'model.safetensors'`` (resolved relative to the current
            working directory).

    Returns:
        Whatever ``load_file`` returns for ``path``. ``barrelshift4`` indexes
        the result with string keys such as ``'y3.weight'`` / ``'y3.bias'``.
    """
    return load_file(path)
|
|
|
|
|
def barrelshift4(a3, a2, a1, a0, s1, s0, weights: Mapping[str, torch.Tensor]):
    """4-bit left barrel shifter. Shifts left by s = 2*s1 + s0 positions.

    The shifter is evaluated as a two-layer network of threshold neurons.
    A neuron called ``name`` outputs 1 when
    ``x @ weights['<name>.weight'].T + weights['<name>.bias'] >= 0`` and 0
    otherwise. The actual logic function is therefore determined entirely by
    the supplied ``weights``; this function only fixes the wiring.

    Args:
        a3, a2, a1, a0: Data bits; each is converted with ``float()``.
        s1, s0: Shift-amount bits; each is converted with ``float()``.
        weights: Mapping from parameter name to tensor. It must provide
            ``'<name>.weight'`` and ``'<name>.bias'`` for every hidden neuron
            (``a3_s00`` ... ``a0_s11``) and every output neuron
            (``y3`` ... ``y0``). Each neuron's pre-activation must reduce to
            a single element, because it is read back with ``.item()``.

    Returns:
        ``[y3, y2, y1, y0]`` as a list of ints, each 0 or 1.

    Raises:
        KeyError: If ``weights`` is a dict and a required entry is missing.
    """
    # Order matters: the hidden activations are fed to the output layer in
    # exactly this order, so it must match the layout the output weights
    # were built for.
    hidden_names = (
        'a3_s00',
        'a2_s00', 'a2_s01',
        'a1_s00', 'a1_s01', 'a1_s10',
        'a0_s00', 'a0_s01', 'a0_s10', 'a0_s11',
    )
    output_names = ('y3', 'y2', 'y1', 'y0')

    def fires(x, name):
        """Return 1 if threshold neuron ``name`` fires on input ``x``, else 0."""
        pre_activation = x @ weights[f'{name}.weight'].T + weights[f'{name}.bias']
        return int((pre_activation >= 0).item())

    # Layer 0: the six input bits as a float vector [a3, a2, a1, a0, s1, s0].
    inp = torch.tensor([float(bit) for bit in (a3, a2, a1, a0, s1, s0)])

    # Layer 1: ten hidden neurons, re-packed as a float vector for layer 2.
    hidden = torch.tensor([float(fires(inp, name)) for name in hidden_names])

    # Layer 2: one neuron per output bit, most significant (y3) first.
    return [fires(hidden, name) for name in output_names]
|
|
|
|
|
if __name__ == '__main__':
    # Demo: load the weights from the default path and print a few shifts.
    w = load_model()

    print('barrelshift4 examples:')

    # Each case is (a3, a2, a1, a0, s1, s0).
    test_cases = [
        (0, 0, 0, 1, 0, 0),
        (0, 0, 0, 1, 0, 1),
        (0, 0, 0, 1, 1, 0),
        (0, 0, 0, 1, 1, 1),
        (1, 0, 1, 0, 0, 1),
    ]

    for a3, a2, a1, a0, s1, s0 in test_cases:
        result = barrelshift4(a3, a2, a1, a0, s1, s0, w)
        # Shift amount as an integer, for display only.
        s = s1 * 2 + s0
        print(f' {a3}{a2}{a1}{a0} << {s} = {result[0]}{result[1]}{result[2]}{result[3]}')
|
|
|