import torch

try:
    from safetensors.torch import load_file
except ImportError:  # safetensors is only needed by load_model(), not by the shifter itself
    load_file = None


def load_model(path='model.safetensors'):
    """Load the barrel-shifter weights from a safetensors file.

    Args:
        path: Path to the ``.safetensors`` file.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns: a dict mapping tensor
        names (e.g. ``'y3.weight'``, ``'y3.bias'``) to tensors.

    Raises:
        ImportError: If the ``safetensors`` package is not installed.
    """
    if load_file is None:
        raise ImportError('safetensors is required to load model weights')
    return load_file(path)


# Hidden-layer neurons, in the order their outputs are fed to the output layer.
# NOTE(review): the names suggest "aI_sXY fires when bit aI is set and
# (s1, s0) == (X, Y)"; that is inferred from naming, confirm against the weights.
_L1_NAMES = ['a3_s00', 'a2_s00', 'a2_s01', 'a1_s00', 'a1_s01', 'a1_s10',
             'a0_s00', 'a0_s01', 'a0_s10', 'a0_s11']
# Output-layer neurons, most significant output bit first.
_L2_NAMES = ['y3', 'y2', 'y1', 'y0']


def _threshold_layer(x, weights, names):
    """Evaluate one layer of step (Heaviside) threshold neurons.

    Neuron ``name`` outputs 1 when ``x @ W.T + b >= 0`` and 0 otherwise, where
    ``W = weights[name + '.weight']`` and ``b = weights[name + '.bias']``.

    Returns:
        A list of ints (0 or 1), one per entry of ``names``, in the same order.
    """
    out = []
    for name in names:
        # Cast to the activation dtype so half/bfloat16/float64/int weight files
        # do not make the matmul fail with a dtype mismatch.
        w = weights[f'{name}.weight'].to(x.dtype)
        b = weights[f'{name}.bias'].to(x.dtype)
        out.append(int((x @ w.T + b >= 0).item()))
    return out


def barrelshift4(a3, a2, a1, a0, s1, s0, weights):
    """4-bit left barrel shifter. Shifts left by s = 2*s1 + s0 positions.

    The data word is ``a3 a2 a1 a0`` (a3 is the MSB). Per the examples in
    ``main()``, bits shifted past a3 are dropped and zeros enter at a0,
    e.g. ``1010 << 1 = 0100``.

    Args:
        a3, a2, a1, a0: Data bits (0 or 1), most significant first.
        s1, s0: Shift-amount bits (0 or 1), most significant first.
        weights: Mapping from ``'<neuron>.weight'`` / ``'<neuron>.bias'`` to
            tensors, e.g. as returned by ``load_model()``.

    Returns:
        ``[y3, y2, y1, y0]``: the shifted word as a list of ints (0 or 1),
        most significant bit first.
    """
    inp = torch.tensor([float(a3), float(a2), float(a1), float(a0),
                        float(s1), float(s0)])
    # Layer 1: decode (data bit, shift amount) combinations.
    hidden = torch.tensor(_threshold_layer(inp, weights, _L1_NAMES), dtype=inp.dtype)
    # Layer 2: combine the hidden neurons into the output bits.
    return _threshold_layer(hidden, weights, _L2_NAMES)


def main():
    """Run a few demo shifts and cross-check them against integer arithmetic."""
    w = load_model()
    print('barrelshift4 examples:')
    test_cases = [
        (0, 0, 0, 1, 0, 0),  # 0001 << 0 = 0001
        (0, 0, 0, 1, 0, 1),  # 0001 << 1 = 0010
        (0, 0, 0, 1, 1, 0),  # 0001 << 2 = 0100
        (0, 0, 0, 1, 1, 1),  # 0001 << 3 = 1000
        (1, 0, 1, 0, 0, 1),  # 1010 << 1 = 0100
    ]
    for a3, a2, a1, a0, s1, s0 in test_cases:
        result = barrelshift4(a3, a2, a1, a0, s1, s0, w)
        s = s1 * 2 + s0
        print(f' {a3}{a2}{a1}{a0} << {s} = {result[0]}{result[1]}{result[2]}{result[3]}')
        # Reference: plain integer shift truncated to 4 bits.
        value = (a3 << 3) | (a2 << 2) | (a1 << 1) | a0
        shifted = (value << s) & 0b1111
        expected = [(shifted >> k) & 1 for k in (3, 2, 1, 0)]
        if result != expected:
            print(f'    MISMATCH: expected {"".join(map(str, expected))}')


if __name__ == '__main__':
    main()