"""Colorize images with a DeOldify ONNX model via OpenCV and ONNX Runtime."""

import argparse
import functools

import cv2
import numpy as np
import onnxruntime

DEFAULT_MODEL_PATH = 'models/deoldify_artistic.onnx'
# (width, height) as passed to cv2.resize when preparing the model input.
DEFAULT_MODEL_SIZE = (256, 256)


@functools.lru_cache(maxsize=4)
def _get_session(model_path, providers):
    """Load (once per path/providers pair) and return an ONNX Runtime session.

    Loading a model is expensive, so sessions are cached. Changes to the model
    file on disk are not picked up until the process restarts.
    """
    return onnxruntime.InferenceSession(
        model_path,
        providers=None if providers is None else list(providers),
    )


def _as_bgr(image):
    """Return *image* as a 3-channel BGR uint8 array.

    Accepts (H, W) or (H, W, 1) grayscale, (H, W, 3) BGR and (H, W, 4) BGRA.
    Alpha is dropped for BGRA input.

    Raises:
        ValueError: if the image is empty, not uint8, or has an unsupported shape.
    """
    image = np.asarray(image)
    if image.size == 0:
        raise ValueError('input image is empty')
    if image.dtype != np.uint8:
        # The Lab merge below works on 8-bit channels; other depths cannot be
        # merged with the 8-bit a/b channels produced from the model output.
        raise ValueError(f'expected a uint8 image, got dtype {image.dtype}')
    channels = image.shape[2] if image.ndim == 3 else (1 if image.ndim == 2 else None)
    if channels == 1:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if channels == 3:
        return image
    if channels == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
    raise ValueError(
        f'expected an (H, W), (H, W, 1), (H, W, 3) or (H, W, 4) image, got shape {image.shape}'
    )


def _prepare_input(bgr_image, model_size):
    """Build the model input: a grayscale copy as 3-channel RGB, NCHW float32."""
    gray = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2GRAY)
    gray_rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    resized = cv2.resize(gray_rgb, model_size)
    # HWC -> 1xCxHxW. Pixel values are left in 0..255, as the original pipeline
    # fed them; presumably that is what this DeOldify export expects.
    return np.expand_dims(resized.transpose(2, 0, 1), axis=0).astype(np.float32)


def _merge_color(bgr_image, model_output):
    """Combine the input's Lab lightness with the model's Lab colour (a/b) channels.

    *model_output* is the first batch item of the model output, CxHxW. It is
    treated as RGB in the 0..255 range, matching how the original code swapped
    it to BGR and cast it to uint8 (assumption carried over, not verified here).
    """
    height, width = bgr_image.shape[:2]
    color = np.ascontiguousarray(model_output.transpose(1, 2, 0), dtype=np.float32)
    color = cv2.resize(color, (width, height))
    # Clip before casting: out-of-range floats would otherwise wrap in uint8.
    color = np.clip(color, 0, 255).round().astype(np.uint8)
    color = cv2.cvtColor(color, cv2.COLOR_RGB2BGR)
    _, a_channel, b_channel = cv2.split(cv2.cvtColor(color, cv2.COLOR_BGR2LAB))
    # Take lightness from the input's own Lab conversion (not its blue channel)
    # so the output keeps the original brightness/detail.
    lightness, _, _ = cv2.split(cv2.cvtColor(bgr_image, cv2.COLOR_BGR2LAB))
    merged = cv2.merge((lightness, a_channel, b_channel))
    return cv2.cvtColor(merged, cv2.COLOR_LAB2BGR)


def colorize_image(input_image, *, model_path=DEFAULT_MODEL_PATH,
                   model_size=DEFAULT_MODEL_SIZE, providers=None):
    """Colorize *input_image* with a DeOldify ONNX model.

    Args:
        input_image: uint8 image, grayscale (H, W)/(H, W, 1), BGR or BGRA.
        model_path: path of the ONNX model to run.
        model_size: (width, height) the image is resized to before inference.
        providers: optional sequence of ONNX Runtime execution providers;
            None lets ONNX Runtime use its default.

    Returns:
        A uint8 BGR image with the same height and width as the input.

    Raises:
        ValueError: if the input image is empty, not uint8, or has an
            unsupported shape.
    """
    bgr_image = _as_bgr(input_image)
    session = _get_session(
        str(model_path), None if providers is None else tuple(providers)
    )
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    model_input = _prepare_input(bgr_image, tuple(model_size))
    model_output = session.run([output_name], {input_name: model_input})[0][0]
    return _merge_color(bgr_image, model_output)


def main(argv=None):
    """Command-line entry point: colorize one image file into another."""
    parser = argparse.ArgumentParser(
        description='Colorize an image with a DeOldify ONNX model.'
    )
    parser.add_argument('input', help='path of the image to colorize')
    parser.add_argument('output', help='path to write the colorized image to')
    parser.add_argument('--model', default=DEFAULT_MODEL_PATH,
                        help='path of the ONNX model (default: %(default)s)')
    args = parser.parse_args(argv)

    image = cv2.imread(args.input, cv2.IMREAD_COLOR)
    if image is None:
        parser.error(f'could not read image: {args.input}')
    result = colorize_image(image, model_path=args.model)
    if not cv2.imwrite(args.output, result):
        parser.error(f'could not write image: {args.output}')


if __name__ == "__main__":
    main()