File size: 3,670 Bytes
0caed3c
 
 
 
 
 
 
063c371
8ed6625
0caed3c
 
 
 
 
 
 
 
 
 
 
 
 
8ed6625
0caed3c
 
 
 
 
 
 
 
 
 
 
 
 
df23063
0caed3c
 
063c371
 
 
 
 
 
 
 
 
 
 
0caed3c
8ed6625
 
 
 
 
df23063
8ed6625
f6df16f
8ed6625
 
 
84022c3
8ed6625
f6df16f
 
 
df23063
 
 
 
 
 
 
 
 
f6df16f
 
 
 
 
063c371
f6df16f
 
 
 
063c371
f6df16f
 
 
063c371
f6df16f
063c371
f6df16f
 
063c371
f6df16f
 
063c371
f6df16f
 
 
 
063c371
f6df16f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import numpy as np
from skimage import color, io
import torch
import torch.nn.functional as F
from PIL import Image
from models import ColorEncoder, ColorUNet
from extractor.manga_panel_extractor import PanelExtractor
import argparse

# Restrict PyTorch to the first GPU (device index 0).
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

def mkdirs(path):
    """Create directory *path* (including parents) if it does not already exist."""
    # exist_ok avoids the check-then-create race of the exists()/makedirs() pair.
    os.makedirs(path, exist_ok=True)

def Lab2RGB_out(img_lab):
    """Convert a (1, 3, H, W) Lab tensor to an (H, W, 3) uint8 RGB array.

    Undoes the training-time shift of the L channel (stored as L - 50),
    then maps Lab -> RGB and clips to valid byte range.
    """
    lab = img_lab.detach().cpu()
    luminance = lab[:, :1, :, :] + 50  # restore true L range
    chroma = lab[:, 1:, :, :]
    lab_hwc = torch.cat((luminance, chroma), 1)[0, ...].numpy().transpose(1, 2, 0)
    rgb = color.lab2rgb(lab_hwc)
    return (np.clip(rgb, 0, 1) * 255).astype("uint8")

def RGB2Lab(inputs):
    """Convert an RGB image to CIE Lab via skimage."""
    lab = color.rgb2lab(inputs)
    return lab

def Normalize(inputs):
    """Center the L channel of an (H, W, 3) Lab image.

    Returns a float32 array with channels (L - 50, a, b); a and b pass
    through unchanged.
    """
    shifted_l = inputs[:, :, 0:1] - 50
    chroma = inputs[:, :, 1:3]
    return np.concatenate((shifted_l, chroma), axis=2).astype('float32')

def numpy2tensor(inputs):
    """Convert an (H, W, C) numpy array to a (C, H, W) torch tensor."""
    chw = inputs.transpose(2, 0, 1)
    return torch.from_numpy(chw)

def tensor2numpy(inputs):
    """Convert a batched (1, C, H, W) tensor to an (H, W, C) numpy array."""
    chw = inputs[0, ...].detach().cpu().numpy()
    return chw.transpose(1, 2, 0)

def preprocessing(inputs):
    """Prepare a PIL RGB image for the network.

    Returns a pair ``(rgb, lab)`` of float32 tensors, each with a leading
    batch dimension: raw RGB values and the L-centered Lab conversion.
    """
    lab = numpy2tensor(Normalize(RGB2Lab(inputs)))
    rgb = numpy2tensor(np.array(inputs, 'float32'))
    return rgb.unsqueeze(0), lab.unsqueeze(0)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Colorize manga images.")
    # BUG FIX: the code below handles a single image, not a directory.
    parser.add_argument("-i", "--input", type=str, required=True, help="Path to input image")
    parser.add_argument("-r", "--reference", type=str, required=True, help="Path to reference image")
    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory")
    parser.add_argument("-ckpt", "--checkpoint", type=str, required=True, help="Path to model checkpoint")

    args = parser.parse_args()

    # Fall back to CPU so the script still runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # BUG FIX: img_path was referenced below but never assigned (NameError).
    img_path = args.input
    output_directory = args.output
    ckpt_path = args.checkpoint
    reference_image_path = args.reference
    imgsize = 256  # the networks expect 256x256 inputs

    # map_location keeps checkpoint loading CPU-safe regardless of where it was saved.
    ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)

    colorEncoder = ColorEncoder().to(device)
    colorEncoder.load_state_dict(ckpt["colorEncoder"])
    colorEncoder.eval()

    colorUNet = ColorUNet().to(device)
    colorUNet.load_state_dict(ckpt["colorUNet"])
    colorUNet.eval()

    img_name = os.path.splitext(os.path.basename(img_path))[0]
    img1 = Image.open(img_path).convert("RGB")
    width, height = img1.size
    img1, img1_lab = preprocessing(img1)
    img2, img2_lab = preprocessing(Image.open(reference_image_path).convert("RGB"))

    img1 = img1.to(device)
    img1_lab = img1_lab.to(device)
    img2 = img2.to(device)
    img2_lab = img2_lab.to(device)

    with torch.no_grad():
        # Reference RGB scaled to [0, 1]; target L channel scaled by 50 to
        # roughly [-1, 1] (L was centered at 0 during preprocessing).
        img2_resize = F.interpolate(img2 / 255., size=(imgsize, imgsize), mode='bilinear', align_corners=False)
        img1_L_resize = F.interpolate(img1_lab[:, :1, :, :] / 50., size=(imgsize, imgsize), mode='bilinear', align_corners=False)

        color_vector = colorEncoder(img2_resize)

        # Predicted ab channels are scaled back to the Lab ab range (*110),
        # then upsampled to the original image resolution.
        fake_ab = colorUNet((img1_L_resize, color_vector))
        fake_ab = F.interpolate(fake_ab * 110, size=(height, width), mode='bilinear', align_corners=False)

        # Recombine the original-resolution L with the predicted ab, convert to RGB.
        fake_img = torch.cat((img1_lab[:, :1, :, :], fake_ab), 1)
        fake_img = Lab2RGB_out(fake_img)

        # BUG FIX: honor the required -o/--output argument; previously the
        # result was written next to the input image and args.output ignored.
        out_folder = output_directory
        mkdirs(out_folder)
        out_img_path = os.path.join(out_folder, f'{img_name}_color.png')
        io.imsave(out_img_path, fake_img)

    print(f'Colored image has been saved to {out_img_path}.')