"""Threshold-gate network that converts a 7-bit thermometer code to 3-bit binary.

Gate parameters are loaded from a safetensors file. Each gate is a single
threshold unit that outputs 1 when ``sum(inputs * weight) + bias >= 0`` and 0
otherwise.
"""

from collections.abc import Iterable, Mapping

import torch

# Number of bits in the thermometer code accepted by thermometer_to_binary.
THERMOMETER_BITS = 7


def load_model(path: str = 'model.safetensors') -> dict[str, torch.Tensor]:
    """Load the gate tensors from a safetensors file.

    Args:
        path: Path to the ``.safetensors`` file.

    Returns:
        Dict mapping tensor names (e.g. ``'b2.weight'``) to tensors, as
        returned by ``safetensors.torch.load_file``.
    """
    # Imported lazily so thermometer_to_binary can be used with weights from
    # any source without requiring safetensors to be installed.
    from safetensors.torch import load_file

    return load_file(path)


def _fire(inputs: torch.Tensor, weights: Mapping[str, torch.Tensor], gate: str) -> int:
    """Evaluate threshold gate *gate*: 1 if sum(inputs * weight) + bias >= 0, else 0."""
    activation = (inputs * weights[f'{gate}.weight']).sum() + weights[f'{gate}.bias']
    return int(activation >= 0)


def thermometer_to_binary(
    therm: Iterable[float], weights: Mapping[str, torch.Tensor]
) -> tuple[int, int, int]:
    """Convert a 7-bit thermometer code to 3-bit binary using threshold gates.

    Args:
        therm: 7 values convertible with ``float()`` (normally 0/1), bit 0
            first; e.g. ``[1, 1, 1, 0, 0, 0, 0]`` is the code with three ones.
        weights: Gate tensors, e.g. from ``load_model``. Must provide
            ``<gate>.weight`` and ``<gate>.bias`` for the gates ``b2``,
            ``b1_and``, ``b1_or``, ``b0_and0``, ``b0_and2``, ``b0_and4`` and
            ``b0_or``. First-layer weights are multiplied elementwise with the
            7 input values (presumably 7-element vectors — TODO confirm shape).

    Returns:
        Tuple ``(b2, b1, b0)`` of ints, each 0 or 1, most significant bit first.

    Raises:
        ValueError: If ``therm`` does not contain exactly 7 values.
        KeyError: If a required gate tensor is missing from ``weights``.
    """
    bits = [float(x) for x in therm]
    if len(bits) != THERMOMETER_BITS:
        raise ValueError(
            f'expected a {THERMOMETER_BITS}-bit thermometer code, got {len(bits)} values'
        )
    t = torch.tensor(bits)

    # The loaded weights are intended to realise the Boolean formulas below,
    # where t_i is bit i of the thermometer code.

    # b2 = t3 (direct)
    b2 = _fire(t, weights, 'b2')

    # b1 = (t1 AND NOT(t3)) OR t5
    b1_and = _fire(t, weights, 'b1_and')
    b1 = _fire(torch.tensor([float(b1_and), bits[5]]), weights, 'b1_or')

    # b0 = (t0 AND NOT(t1)) OR (t2 AND NOT(t3)) OR (t4 AND NOT(t5)) OR t6
    and0 = _fire(t, weights, 'b0_and0')
    and2 = _fire(t, weights, 'b0_and2')
    and4 = _fire(t, weights, 'b0_and4')
    b0_inputs = torch.tensor([float(and0), float(and2), float(and4), bits[6]])
    b0 = _fire(b0_inputs, weights, 'b0_or')

    return b2, b1, b0


def main() -> None:
    """Load the model and print the converted value for every thermometer code."""
    w = load_model()
    print('Thermometer to Binary Converter')
    for k in range(THERMOMETER_BITS + 1):
        # k leading ones followed by zeros, matching the original table rows.
        therm = [1] * k + [0] * (THERMOMETER_BITS - k)
        b2, b1, b0 = thermometer_to_binary(therm, w)
        print(f"{''.join(map(str, therm))} -> {b2*4 + b1*2 + b0}")


if __name__ == '__main__':
    main()