CharlesCNorton
4-bit left barrel shifter, magnitude 61
3635a46
import torch
from safetensors.torch import load_file
def load_model(path='model.safetensors'):
    """Read every tensor stored in a safetensors file.

    Args:
        path: file to read; defaults to the relative path 'model.safetensors'.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for ``path`` - a
        mapping from tensor name to tensor, which ``barrelshift4`` indexes
        with keys such as 'y3.weight'.
    """
    tensors = load_file(path)
    return tensors
def barrelshift4(a3, a2, a1, a0, s1, s0, weights):
    """Evaluate the 4-bit left barrel shifter network on one input.

    The shift amount is s = 2*s1 + s0. Judging by the examples under
    ``__main__`` (e.g. 1010 << 1 = 0100), bits shifted past a3 are dropped.

    Args:
        a3, a2, a1, a0: data bits, most significant first (presumably 0/1).
        s1, s0: shift-amount bits, most significant first (presumably 0/1).
        weights: mapping holding '<unit>.weight' and '<unit>.bias' tensors
            for every unit named below, e.g. the dict from ``load_model``.

    Returns:
        list of four ints [y3, y2, y1, y0], most significant first, each 0 or 1.
    """
    def step(x, unit):
        # Heaviside threshold unit: 1 when w . x + b >= 0, otherwise 0.
        # `.item()` requires a single-element result; presumably each weight
        # is shaped (1, fan_in) and each bias (1,) - confirm against the file.
        activation = x @ weights[f'{unit}.weight'].T + weights[f'{unit}.bias']
        return int((activation >= 0).item())

    # Network input vector, ordered [a3, a2, a1, a0, s1, s0].
    x = torch.tensor([float(bit) for bit in (a3, a2, a1, a0, s1, s0)])

    # Hidden-unit names pair a data bit with a 2-bit shift code (a1_s10 is
    # a1 with s = 2). Only pairs with bit index + shift <= 3 exist, matching
    # the dropped-overflow behaviour; presumably each unit fires when its
    # bit is set and the shift code matches.
    hidden_units = (
        'a3_s00',
        'a2_s00', 'a2_s01',
        'a1_s00', 'a1_s01', 'a1_s10',
        'a0_s00', 'a0_s01', 'a0_s10', 'a0_s11',
    )
    hidden = torch.tensor([float(step(x, unit)) for unit in hidden_units])

    # The output layer reads the full 10-element hidden vector for each bit.
    return [step(hidden, unit) for unit in ('y3', 'y2', 'y1', 'y0')]
if __name__ == '__main__':
    # Demo: push a few hand-picked shifts through the network and print them.
    model_weights = load_model()
    print('barrelshift4 examples:')
    # Each tuple is (a3, a2, a1, a0, s1, s0); the comment is the expected result.
    examples = (
        (0, 0, 0, 1, 0, 0),  # 0001 << 0 = 0001
        (0, 0, 0, 1, 0, 1),  # 0001 << 1 = 0010
        (0, 0, 0, 1, 1, 0),  # 0001 << 2 = 0100
        (0, 0, 0, 1, 1, 1),  # 0001 << 3 = 1000
        (1, 0, 1, 0, 0, 1),  # 1010 << 1 = 0100
    )
    for a3, a2, a1, a0, s1, s0 in examples:
        shifted_bits = barrelshift4(a3, a2, a1, a0, s1, s0, model_weights)
        amount = 2 * s1 + s0
        operand = f'{a3}{a2}{a1}{a0}'
        shifted = ''.join(str(bit) for bit in shifted_bits)
        print(f' {operand} << {amount} = {shifted}')