# File: backup/scripts/colorizer.py
# Uploaded by killbill007 in commit 93871a1 ("Upload 754 files"; 1.41 kB).
import functools

import cv2
import numpy as np
import onnxruntime
@functools.lru_cache(maxsize=8)
def _load_session(model_path):
    """Create (once per path) and return an ONNX Runtime session for *model_path*."""
    # Building a session parses and optimises the whole model, so it is cached
    # instead of being rebuilt for every image. Passing the available providers
    # explicitly is required by GPU-enabled ORT builds (>= 1.9), which raise
    # ValueError when `providers` is omitted.
    return onnxruntime.InferenceSession(
        model_path, providers=onnxruntime.get_available_providers()
    )


def _as_bgr(image):
    """Return *image* as a 3-channel BGR array (accepts gray, BGR or BGRA)."""
    if image.ndim == 2 or image.shape[2] == 1:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
    return image


def colorize_image(input_image, model_path='models/deoldify_artistic.onnx'):
    """Colorize an image with a DeOldify ONNX model.

    The network supplies only the colour: the result takes its chroma (LAB
    a/b channels) from the model output and its lightness (LAB L channel)
    from the input image, so the input's full-resolution detail is kept.

    Args:
        input_image: image array in OpenCV channel order -- H x W (or
            H x W x 1) grayscale, H x W x 3 BGR, or H x W x 4 BGRA. Presumably
            8-bit (uint8); other dtypes are not handled specially.
        model_path: path of the DeOldify ONNX model. The session is cached
            per path, so repeated calls do not reload the model.

    Returns:
        Colorized H x W x 3 BGR image (uint8 for uint8 input) with the same
        height and width as the input.
    """
    bgr_image = _as_bgr(input_image)
    session = _load_session(model_path)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    # Preprocess: gray replicated to 3 RGB channels, resized to 256x256,
    # laid out as NCHW float32. No value normalisation is applied.
    gray = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2GRAY)
    model_input = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    model_input = cv2.resize(model_input, (256, 256))
    model_input = model_input.transpose((2, 0, 1))
    model_input = np.expand_dims(model_input, axis=0).astype(np.float32)

    # Run inference and take the single image out of the batch (CHW).
    prediction = session.run([output_name], {input_name: model_input})[0][0]

    # Postprocess: back to HWC at the input resolution, then to uint8.
    color = prediction.transpose(1, 2, 0)
    color = cv2.resize(color, (bgr_image.shape[1], bgr_image.shape[0]))
    # Round and clip before casting: a bare astype(uint8) wraps out-of-range
    # floats (e.g. 256.0 -> 0, -1.0 -> 255), producing wrong-colour speckles.
    color = np.clip(np.rint(color), 0, 255).astype(np.uint8)
    # The model was fed RGB, so its output is treated as RGB here.
    color = cv2.cvtColor(color, cv2.COLOR_RGB2BGR)

    # Keep the network's chroma but the input's own lightness. Using the real
    # L channel (rather than a raw colour channel) preserves the input's tones
    # after the LAB -> BGR round trip, for gray and colour inputs alike.
    lightness = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2LAB)[:, :, 0]
    _, color_a, color_b = cv2.split(cv2.cvtColor(color, cv2.COLOR_BGR2LAB))
    merged = cv2.merge((lightness, color_a, color_b))
    return cv2.cvtColor(merged, cv2.COLOR_LAB2BGR)
if __name__ == "__main__":
    # Minimal CLI: read an image, colorize it, write the result.
    import argparse

    parser = argparse.ArgumentParser(
        description="Colorize an image with the DeOldify ONNX model."
    )
    parser.add_argument("input", help="path of the image to colorize")
    parser.add_argument("output", help="path to write the colorized image to")
    args = parser.parse_args()

    # cv2.imread signals failure by returning None rather than raising.
    image = cv2.imread(args.input, cv2.IMREAD_COLOR)
    if image is None:
        raise SystemExit(f"could not read image: {args.input}")
    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(args.output, colorize_image(image)):
        raise SystemExit(f"could not write image: {args.output}")