File size: 3,670 Bytes
0caed3c 063c371 8ed6625 0caed3c 8ed6625 0caed3c df23063 0caed3c 063c371 0caed3c 8ed6625 df23063 8ed6625 f6df16f 8ed6625 84022c3 8ed6625 f6df16f df23063 f6df16f 063c371 f6df16f 063c371 f6df16f 063c371 f6df16f 063c371 f6df16f 063c371 f6df16f 063c371 f6df16f 063c371 f6df16f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import os
import numpy as np
from skimage import color, io
import torch
import torch.nn.functional as F
from PIL import Image
from models import ColorEncoder, ColorUNet
from extractor.manga_panel_extractor import PanelExtractor
import argparse
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
def mkdirs(path):
    """Create directory ``path`` (and any missing parents) if it is absent.

    Args:
        path: Directory to create. An empty string is treated as "nothing to
            create" and the call returns without touching the file system.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # ``os.path.dirname`` of a bare file name is "", and ``os.makedirs("")``
    # raises FileNotFoundError, so there is nothing to do in that case.
    if not path:
        return
    # ``exist_ok=True`` avoids the race between checking for the directory
    # and creating it.
    os.makedirs(path, exist_ok=True)
def Lab2RGB_out(img_lab):
    """Convert the first image of a Lab tensor batch to an RGB uint8 array.

    Args:
        img_lab: 4-D tensor indexed as (batch, channel, row, column) whose
            channel 0 is the L channel shifted down by 50 and whose remaining
            channels are passed through unchanged as ab.

    Returns:
        Channel-last ``uint8`` numpy array for batch element 0, with values
        clipped to the 0-255 range.
    """
    lab = img_lab.detach().cpu()
    # Undo the -50 offset on L before handing the data to skimage.
    lightness = lab[:, :1, :, :] + 50
    chroma = lab[:, 1:, :, :]
    first = torch.cat((lightness, chroma), 1)[0, ...].numpy()
    rgb = color.lab2rgb(first.transpose(1, 2, 0))
    rgb = np.clip(rgb, 0, 1)
    return (rgb * 255).astype("uint8")
def RGB2Lab(inputs):
    """Convert an RGB image to Lab by delegating to ``color.rgb2lab``."""
    lab = color.rgb2lab(inputs)
    return lab
def Normalize(inputs):
    """Shift the first channel of a channel-last array down by 50.

    Args:
        inputs: Array indexed as (row, column, channel). Channel 0 is offset
            by -50; channels 1 and 2 are copied as-is; any further channels
            are dropped.

    Returns:
        A ``float32`` array holding the three resulting channels.
    """
    shifted_l = inputs[:, :, :1] - 50
    chroma = inputs[:, :, 1:3]
    return np.concatenate((shifted_l, chroma), axis=2).astype(np.float32)
def numpy2tensor(inputs):
    """Turn a channel-last numpy array into a channel-first torch tensor.

    Args:
        inputs: 3-D numpy array; axis 2 becomes axis 0 of the result.

    Returns:
        A tensor created with ``torch.from_numpy`` on the transposed array.
    """
    channel_first = inputs.transpose((2, 0, 1))
    return torch.from_numpy(channel_first)
def tensor2numpy(inputs):
    """Return batch element 0 of a tensor as a channel-last numpy array.

    Args:
        inputs: 4-D tensor; only index 0 along the first axis is used.

    Returns:
        Numpy array with the remaining axes reordered from (0, 1, 2) to
        (1, 2, 0).
    """
    first = inputs[0, ...].detach().cpu()
    as_array = first.numpy()
    return as_array.transpose(1, 2, 0)
def preprocessing(inputs):
    """Build batched RGB and normalized-Lab tensors from one RGB image.

    Args:
        inputs: RGB image accepted by both ``np.array`` and ``RGB2Lab``
            (the script passes PIL images converted to "RGB").

    Returns:
        Tuple ``(rgb, lab)``: the raw RGB values as ``float32`` and the output
        of ``Normalize(RGB2Lab(inputs))``, each made channel-first and given a
        leading batch axis of size 1.
    """
    lab_tensor = numpy2tensor(Normalize(RGB2Lab(inputs)))
    rgb_tensor = numpy2tensor(np.array(inputs, 'float32'))
    return rgb_tensor.unsqueeze(0), lab_tensor.unsqueeze(0)
def _list_input_images(input_path):
    """Return the image files to colorize for ``input_path``.

    Args:
        input_path: Either a single image file or a directory of images.

    Returns:
        ``[input_path]`` when it is a file; otherwise the sorted paths of the
        directory entries whose extension looks like an image.

    Raises:
        OSError: If ``input_path`` is neither a file nor a listable directory.
    """
    extensions = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
    if os.path.isfile(input_path):
        return [input_path]
    return sorted(
        os.path.join(input_path, name)
        for name in os.listdir(input_path)
        if name.lower().endswith(extensions)
    )


if __name__ == "__main__":
    # Command-line entry point: colorize every input image using the colors
    # of a single reference image and write "<name>_color.png" files into the
    # output directory.
    parser = argparse.ArgumentParser(description="Colorize manga images.")
    parser.add_argument("-i", "--input", type=str, required=True, help="Path to input image directory")
    parser.add_argument("-r", "--reference", type=str, required=True, help="Path to reference image")
    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory")
    parser.add_argument("-ckpt", "--checkpoint", type=str, required=True, help="Path to model checkpoint")
    args = parser.parse_args()

    # Fall back to the CPU when CUDA is unavailable instead of failing in .to().
    device = "cuda" if torch.cuda.is_available() else "cpu"
    input_image_dir = args.input
    output_directory = args.output
    ckpt_path = args.checkpoint
    reference_image_path = args.reference
    imgsize = 256  # side length both networks are fed at

    image_paths = _list_input_images(input_image_dir)
    if not image_paths:
        parser.error(f"No images found in {input_image_dir}")

    # Load the weights on the CPU first; the modules are moved to `device`.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    colorEncoder = ColorEncoder().to(device)
    colorEncoder.load_state_dict(ckpt["colorEncoder"])
    colorEncoder.eval()

    colorUNet = ColorUNet().to(device)
    colorUNet.load_state_dict(ckpt["colorUNet"])
    colorUNet.eval()

    mkdirs(output_directory)

    # The reference is the same for every input, so encode it only once.
    with Image.open(reference_image_path) as ref_file:
        img2, _ = preprocessing(ref_file.convert("RGB"))
    img2 = img2.to(device)
    with torch.no_grad():
        img2_resize = F.interpolate(img2 / 255., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False)
        color_vector = colorEncoder(img2_resize)

    for img_path in image_paths:
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        with Image.open(img_path) as src_file:
            img1 = src_file.convert("RGB")
        width, height = img1.size

        _, img1_lab = preprocessing(img1)
        img1_lab = img1_lab.to(device)

        with torch.no_grad():
            # The L channel is already offset by -50; dividing by 50 rescales it.
            img1_L_resize = F.interpolate(img1_lab[:,:1,:,:] / 50., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False)

            fake_ab = colorUNet((img1_L_resize, color_vector))
            # Scale the prediction by 110 and resize it back to the input size.
            fake_ab = F.interpolate(fake_ab*110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)

            # Combine the original full-resolution L with the predicted ab.
            fake_img = torch.cat((img1_lab[:,:1,:,:], fake_ab), 1)
            fake_img = Lab2RGB_out(fake_img)

        out_img_path = os.path.join(output_directory, f'{img_name}_color.png')
        io.imsave(out_img_path, fake_img)
        print(f'Colored image has been saved to {out_img_path}.')
|