|
|
import functools

import cv2
import numpy as np
import onnxruntime
|
|
|
|
|
@functools.lru_cache(maxsize=4)
def _load_session(model_path):
    """Return the ONNX Runtime session for *model_path*, creating it once."""
    # Building an InferenceSession parses and loads the model file, so reuse
    # it across calls instead of paying that cost for every image.
    return onnxruntime.InferenceSession(model_path)


def _as_bgr(image):
    """Return *image* as a 3-channel array, converting gray/BGRA input."""
    if image.ndim == 2:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.ndim == 3:
        channels = image.shape[2]
        if channels == 3:
            return image
        if channels == 1:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        if channels == 4:
            # The alpha plane is discarded; the result is always 3-channel.
            return cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
    raise ValueError(
        f"input_image must have shape (H, W), (H, W, 1), (H, W, 3) or "
        f"(H, W, 4); got {image.shape}"
    )


def colorize_image(
    input_image,
    model_path='models/deoldify_artistic.onnx',
    model_size=(256, 256),
):
    """Colorize an image with an ONNX colorization model.

    The image is reduced to grayscale, run through the model at
    ``model_size``, and the predicted colour is merged back onto the
    full-resolution input in LAB space.

    Args:
        input_image: ``numpy.ndarray`` of dtype ``uint8``. A 3-channel array
            is processed as BGR (OpenCV order); 2-D, single-channel and
            4-channel (BGRA) arrays are converted to BGR first.
        model_path: Path of the ONNX model file. The session is cached per
            path, so repeated calls do not reload the model.
        model_size: ``(width, height)`` the frame is resized to before
            inference.

    Returns:
        A ``uint8`` 3-channel BGR array with the same height and width as
        ``input_image``.

    Raises:
        TypeError: If ``input_image`` is not a ``numpy.ndarray``.
        ValueError: If ``input_image`` is empty, is not ``uint8``, or has an
            unsupported shape.
    """
    # Validate before touching the model so bad input fails fast and clearly.
    if not isinstance(input_image, np.ndarray):
        raise TypeError(
            f"input_image must be a numpy.ndarray, got {type(input_image).__name__}"
        )
    if input_image.size == 0:
        raise ValueError("input_image is empty")
    if input_image.dtype != np.uint8:
        raise ValueError(
            f"input_image must have dtype uint8, got {input_image.dtype}"
        )
    bgr_image = _as_bgr(input_image)
    height, width = bgr_image.shape[:2]

    session = _load_session(model_path)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    # Model input: grayscale replicated over three channels, resized, then
    # laid out as a (1, 3, H, W) float32 batch.
    gray = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2GRAY)
    model_input = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    model_input = cv2.resize(model_input, tuple(model_size))
    model_input = model_input.transpose((2, 0, 1))
    model_input = np.expand_dims(model_input, axis=0).astype(np.float32)

    # First (only requested) output, first item of the batch.
    prediction = session.run([output_name], {input_name: model_input})[0][0]

    # Back to channel-last layout at the original resolution.
    colorized = prediction.transpose(1, 2, 0)
    colorized = cv2.resize(colorized, (width, height))
    # Saturate before the uint8 cast; without this, out-of-range float
    # predictions would wrap around instead of clamping to 0/255.
    colorized = np.clip(colorized, 0, 255)
    colorized = cv2.cvtColor(colorized, cv2.COLOR_BGR2RGB).astype(np.uint8)

    # Keep only the two chroma planes of the prediction and pair them with a
    # lightness plane taken from the full-resolution input.
    # NOTE(review): the lightness plane is the input's first (blue) channel,
    # not a true LAB L channel; this matches luminance only for gray input
    # where B == G == R -- confirm this is intended for coloured input.
    lightness, _, _ = cv2.split(bgr_image)
    _, chroma_a, chroma_b = cv2.split(cv2.cvtColor(colorized, cv2.COLOR_BGR2LAB))
    merged_lab = cv2.merge((lightness, chroma_a, chroma_b))

    # 8-bit LAB in, 8-bit BGR out, so no further cast is needed.
    return cv2.cvtColor(merged_lab, cv2.COLOR_LAB2BGR)
|
|
|
|
|
if __name__ == "__main__":
    # Running this file directly performs no action.
    ...
|
|
|
|
|
|