File size: 1,412 Bytes
93871a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from functools import lru_cache

import cv2
import numpy as np
import onnxruntime

@lru_cache(maxsize=4)
def _get_session(model_path):
    """Return an ONNX Runtime session for ``model_path``, loading it only once.

    Building an ``InferenceSession`` reads and optimizes the whole model, so
    sessions are cached per path instead of being rebuilt for every image.
    A model file replaced on disk is not picked up until the cached session
    is evicted (e.g. via ``_get_session.cache_clear()``).
    """
    return onnxruntime.InferenceSession(model_path)


def _to_bgr(image):
    """Return ``image`` as a 3-channel uint8 BGR array.

    Accepts single-channel grayscale (2-D or ``(h, w, 1)``), BGR, or BGRA.

    Raises:
        TypeError: if ``image`` is not ``uint8``.
        ValueError: if ``image`` is not 2-D/3-D or has an unsupported
            number of channels.
    """
    if image.dtype != np.uint8:
        raise TypeError(f"expected a uint8 image, got dtype {image.dtype}")
    if image.ndim not in (2, 3):
        raise ValueError(f"expected a 2-D or 3-D image array, got {image.ndim}-D")
    channels = 1 if image.ndim == 2 else image.shape[2]
    if channels == 1:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if channels == 3:
        return image
    if channels == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
    raise ValueError(f"unsupported number of channels: {channels}")


def colorize_image(input_image, model_path='models/deoldify_artistic.onnx', model_size=(256, 256)):
    """Colorize an image with a DeOldify-style ONNX model.

    The model predicts colors from a grayscale copy of the image resized to
    ``model_size``. Its chroma (LAB a/b channels) is upscaled back to the input
    resolution and combined with the lightness (LAB L channel) of the
    full-resolution input, so fine detail is not lost to the downscale.

    Args:
        input_image: uint8 image array, either 3-channel BGR (OpenCV order),
            4-channel BGRA, or single-channel grayscale.
        model_path: Path of the ONNX model file. Sessions are cached per path.
        model_size: ``(width, height)`` the image is resized to before
            inference (``cv2.resize`` dsize order).

    Returns:
        The colorized image as a uint8 BGR array with the same height and
        width as ``input_image``.

    Raises:
        TypeError: if ``input_image`` is not ``uint8``.
        ValueError: if ``input_image`` has an unsupported shape.
    """
    bgr_image = _to_bgr(input_image)
    height, width = bgr_image.shape[:2]

    # Preprocess: the model sees only a grayscale version of the image,
    # replicated to 3 RGB channels, as NCHW float32 with 0-255 values.
    gray = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2GRAY)
    model_input = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    model_input = cv2.resize(model_input, model_size)
    model_input = model_input.transpose((2, 0, 1))
    model_input = np.expand_dims(model_input, axis=0).astype(np.float32)

    # Run inference on the (cached) session; take the first batch item.
    session = _get_session(model_path)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    prediction = session.run([output_name], {input_name: model_input})[0][0]

    # Postprocess: CHW -> HWC, then back to the input resolution. The
    # prediction's channels are treated as RGB order.
    prediction = prediction.transpose(1, 2, 0)
    prediction = cv2.resize(prediction, (width, height))
    # Clip before casting: NumPy float -> uint8 casts of out-of-range values
    # do not saturate (the result is platform-dependent, commonly wrapped),
    # which shows up as speckles in over/under-shooting regions.
    prediction = np.clip(prediction, 0, 255).round().astype(np.uint8)

    # Keep the prediction's chroma but take lightness from the input's own
    # LAB conversion. The input's raw blue channel is not LAB lightness and
    # would shift brightness (and ignore red/green on non-gray inputs).
    colorized_lab = cv2.cvtColor(prediction, cv2.COLOR_RGB2LAB)
    colorized_lab[:, :, 0] = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2LAB)[:, :, 0]
    return cv2.cvtColor(colorized_lab, cv2.COLOR_LAB2BGR)

if __name__ == "__main__":
    # Running this module directly is a deliberate no-op; it only exposes
    # colorize_image() for import.
    ...