| | |
| | |
| |
|
| | import os |
| | from collections import OrderedDict |
| | from torch.autograd import Variable |
| | from options.test_options import TestOptions |
| | from models.models import create_model |
| | from models.mapping_model import Pix2PixHDModel_Mapping |
| | import util.util as util |
| | from PIL import Image |
| | import torch |
| | import torchvision.utils as vutils |
| | import torchvision.transforms as transforms |
| | import numpy as np |
| | import cv2 |
| |
|
def data_transforms(img, method=Image.BILINEAR, scale=False):
    """Resize *img* so that both sides are multiples of 4.

    Args:
        img: Image object exposing ``size`` as ``(width, height)`` and a
            ``resize((width, height), method)`` method.
        method: Resampling filter forwarded to ``img.resize``.
        scale: When true, the target size is first rescaled so that the
            shorter side becomes 256 while the aspect ratio is kept.

    Returns:
        ``img`` itself when it already has the target size, otherwise the
        result of ``img.resize``.
    """
    orig_w, orig_h = img.size
    target_w, target_h = orig_w, orig_h
    if scale:
        if orig_w < orig_h:
            target_w = 256
            target_h = orig_h / orig_w * 256
        else:
            target_h = 256
            target_w = orig_w / orig_h * 256

    # Snap each side to the nearest multiple of 4. Sides shorter than 2 px
    # would round down to 0 and make resize() fail, so clamp to 4.
    h = max(4, int(round(target_h / 4) * 4))
    w = max(4, int(round(target_w / 4) * 4))

    if h == orig_h and w == orig_w:
        return img

    return img.resize((w, h), method)
| |
|
| |
|
def data_transforms_rgb_old(img):
    """Return a 256x256 center crop of *img*, enlarging small images first.

    Args:
        img: Image object exposing ``size`` as ``(width, height)``.

    Returns:
        The output of ``transforms.CenterCrop(256)``. When either side of
        *img* is below 256 it is first passed through a bilinear resize
        transform configured with size 256.
    """
    w, h = img.size
    resized = img
    if w < 256 or h < 256:
        # ``transforms.Scale`` is the legacy name of ``transforms.Resize`` and
        # no longer exists in current torchvision; fall back to it only on
        # installs that predate ``Resize``.
        resize_cls = getattr(transforms, "Resize", None) or transforms.Scale
        resized = resize_cls(256, Image.BILINEAR)(img)
    return transforms.CenterCrop(256)(resized)
| |
|
| |
|
def irregular_hole_synthesize(img, mask):
    """Blend *img* towards white (255) wherever *mask* is set.

    Both arguments are converted to ``uint8`` arrays. The mask is divided by
    255 to obtain a per-pixel weight: weight 0 keeps the image pixel and
    weight 1 (mask value 255) replaces it with 255.

    Args:
        img: Image convertible with ``np.array``.
        mask: Mask convertible with ``np.array``; it must broadcast against
            the image array.

    Returns:
        The blended picture as an RGB ``Image``.
    """
    pixels = np.array(img).astype("uint8")
    weight = np.array(mask).astype("uint8") / 255

    kept = pixels * (1 - weight)  # image contribution outside the mask
    filled = weight * 255  # white contribution inside the mask
    blended = kept + filled

    return Image.fromarray(blended.astype("uint8")).convert("RGB")
| |
|
| |
|
def parameter_set(opt):
    """Write the fixed inference settings onto *opt* in place.

    A set of defaults is always assigned. Then, depending on the flags
    ``opt.Quality_restore``, ``opt.Scratch_and_Quality_restore`` and
    ``opt.HR``, the experiment name and the two pretrained-checkpoint paths
    (plus several mask/attention switches) are filled in.

    Args:
        opt: Mutable options object; attributes are read and assigned
            directly on it.

    Returns:
        None.
    """
    defaults = {
        "serial_batches": True,
        "no_flip": True,
        "label_nc": 0,
        "n_downsample_global": 3,
        "mc": 64,
        "k_size": 4,
        "start_r": 1,
        "mapping_n_block": 6,
        "map_mc": 512,
        "no_instance": True,
        "checkpoints_dir": "./checkpoints/restoration",
    }
    for attr, value in defaults.items():
        setattr(opt, attr, value)

    ckpt_root = opt.checkpoints_dir

    if opt.Quality_restore:
        opt.name = "mapping_quality"
        opt.load_pretrainA = os.path.join(ckpt_root, "VAE_A_quality")
        opt.load_pretrainB = os.path.join(ckpt_root, "VAE_B_quality")

    if not opt.Scratch_and_Quality_restore:
        return

    # The scratch pipeline is applied after the quality branch, so its
    # values win when both flags are set.
    opt.NL_res = True
    opt.use_SN = True
    opt.correlation_renormalize = True
    opt.NL_use_mask = True
    opt.NL_fusion_method = "combine"
    opt.non_local = "Setting_42"
    opt.name = "mapping_scratch"
    opt.load_pretrainA = os.path.join(ckpt_root, "VAE_A_quality")
    opt.load_pretrainB = os.path.join(ckpt_root, "VAE_B_scratch")

    # NOTE(review): the HR overrides are treated as part of the scratch
    # branch (they set mask-specific options) -- confirm against the
    # original indentation.
    if opt.HR:
        opt.mapping_exp = 1
        opt.inference_optimize = True
        opt.mask_dilation = 3
        opt.name = "mapping_Patch_Attention"
| |
|
| |
|
if __name__ == "__main__":
    # Parse the test options, then overlay the fixed inference settings.
    opt = TestOptions().parse(save=False)
    parameter_set(opt)

    model = Pix2PixHDModel_Mapping()
    model.initialize(opt)
    model.eval()

    # One output folder per saved artefact. exist_ok avoids the
    # check-then-create race of a separate os.path.exists() test.
    for sub_dir in ("input_image", "restored_image", "origin"):
        os.makedirs(os.path.join(opt.outputs_dir, sub_dir), exist_ok=True)

    input_names = sorted(os.listdir(opt.test_input))
    dataset_size = len(input_names)

    mask_names = None
    if opt.test_mask != "":
        mask_names = sorted(os.listdir(opt.test_mask))
        # Inputs and masks are paired by their position in the sorted
        # listings, so only indices present in both lists can be processed.
        # Using the mask count alone would index past the end of the input
        # list whenever there are more masks than inputs.
        dataset_size = min(len(input_names), len(mask_names))

    # ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
    img_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    mask_transform = transforms.ToTensor()

    for index in range(dataset_size):
        input_name = input_names[index]
        input_file = os.path.join(opt.test_input, input_name)
        if not os.path.isfile(input_file):
            print("Skipping non-file %s" % input_name)
            continue
        image = Image.open(input_file).convert("RGB")

        print("Now you are processing %s" % (input_name))

        if opt.NL_use_mask:
            if mask_names is None:
                # Fail with an explicit message instead of a NameError on the
                # never-created mask list.
                raise ValueError(
                    "opt.NL_use_mask is enabled but opt.test_mask is empty; "
                    "a mask directory is required"
                )
            mask_path = os.path.join(opt.test_mask, mask_names[index])
            mask = Image.open(mask_path).convert("RGB")
            if opt.mask_dilation != 0:
                kernel = np.ones((3, 3), np.uint8)
                dilated = cv2.dilate(
                    np.array(mask), kernel, iterations=opt.mask_dilation
                )
                mask = Image.fromarray(dilated.astype("uint8"))
            origin = image
            image = irregular_hole_synthesize(image, mask)
            # Keep only the first channel of the RGB mask, then add a batch
            # dimension.
            mask = mask_transform(mask)[:1, :, :].unsqueeze(0)
            input_tensor = img_transform(image).unsqueeze(0)
        else:
            if opt.test_mode == "Scale":
                image = data_transforms(image, scale=True)
            elif opt.test_mode == "Full":
                image = data_transforms(image, scale=False)
            elif opt.test_mode == "Crop":
                image = data_transforms_rgb_old(image)
            origin = image
            input_tensor = img_transform(image).unsqueeze(0)
            # No mask in this mode: feed an all-zero tensor of the input's shape.
            mask = torch.zeros_like(input_tensor)

        # Best effort per image: report the failure and move on to the next file.
        try:
            with torch.no_grad():
                generated = model.inference(input_tensor, mask)
        except Exception as ex:
            print("Skip %s due to an error:\n%s" % (input_name, str(ex)))
            continue

        if input_name.endswith(".jpg"):
            input_name = input_name[:-4] + ".png"

        # (x + 1) / 2 undoes the [-1, 1] normalisation before writing to disk.
        vutils.save_image(
            (input_tensor + 1.0) / 2.0,
            os.path.join(opt.outputs_dir, "input_image", input_name),
            nrow=1,
            padding=0,
            normalize=True,
        )
        vutils.save_image(
            (generated.data.cpu() + 1.0) / 2.0,
            os.path.join(opt.outputs_dir, "restored_image", input_name),
            nrow=1,
            padding=0,
            normalize=True,
        )

        origin.save(os.path.join(opt.outputs_dir, "origin", input_name))