| |
| |
|
|
| import os |
| from collections import OrderedDict |
| from torch.autograd import Variable |
| from options.test_options import TestOptions |
| from models.models import create_model |
| from models.mapping_model import Pix2PixHDModel_Mapping |
| import util.util as util |
| from PIL import Image |
| import torch |
| import torchvision.utils as vutils |
| import torchvision.transforms as transforms |
| import numpy as np |
| import cv2 |
|
|
def data_transforms(img, method=Image.BILINEAR, scale=False, target_size=256, multiple=4):
    """Resize a PIL image so that both sides are multiples of ``multiple``.

    When ``scale`` is true the image is first rescaled, keeping its aspect
    ratio, so that its shorter side becomes ``target_size``. Otherwise the
    original size is kept and only rounded to the nearest multiple.

    Args:
        img: PIL image to resize.
        method: PIL resampling filter forwarded to ``Image.resize``.
        scale: if true, rescale the shorter side to ``target_size`` first.
        target_size: shorter-side length used when ``scale`` is true.
        multiple: each output side is rounded to the nearest multiple of this
            value, and never goes below ``multiple`` itself.

    Returns:
        ``img`` itself when its size already satisfies the constraints,
        otherwise a resized copy.
    """
    ow, oh = img.size
    pw, ph = ow, oh
    if scale:
        if ow < oh:
            ow = target_size
            oh = ph / pw * target_size
        else:
            oh = target_size
            ow = pw / ph * target_size

    # Clamp to at least one `multiple`. For very small sides (<= 2 px with the
    # default multiple of 4), round() returns 0, which would otherwise request
    # an empty, zero-width or zero-height image.
    h = max(multiple, int(round(oh / multiple) * multiple))
    w = max(multiple, int(round(ow / multiple) * multiple))

    if (h == ph) and (w == pw):
        return img

    return img.resize((w, h), method)
|
|
|
|
def data_transforms_rgb_old(img, size=256):
    """Return a ``size`` x ``size`` center crop of a PIL image.

    If the image's width or height is below ``size``, it is first resized so
    that its shorter side equals ``size`` (aspect ratio kept). The crop is then
    always taken from an image at least ``size`` pixels on each side.

    Args:
        img: PIL image.
        size: side length of the square output crop.

    Returns:
        The center-cropped image.
    """
    width, height = img.size
    if width < size or height < size:
        # transforms.Scale was a deprecated alias of transforms.Resize that
        # newer torchvision releases no longer provide. Resize's default
        # interpolation is bilinear, matching the old Image.BILINEAR argument.
        img = transforms.Resize(size)(img)
    return transforms.CenterCrop(size)(img)
|
|
|
|
def irregular_hole_synthesize(img, mask):
    """Whiten the pixels of ``img`` that are selected by ``mask``.

    The mask is read as 8-bit intensities and scaled to ``[0, 1]``. Each pixel
    value is then linearly blended towards 255 by that weight: mask value 255
    turns the pixel white, and mask value 0 leaves it unchanged. The blended
    array is truncated back to uint8 and returned as an RGB PIL image.

    NOTE(review): assumes ``img`` and ``mask`` have broadcast-compatible array
    shapes. The caller in this script converts both to RGB but does not check
    that their sizes match.
    """
    pixels = np.array(img).astype("uint8")
    weight = np.array(mask).astype("uint8") / 255
    whitened = pixels * (1 - weight) + weight * 255
    return Image.fromarray(whitened.astype("uint8")).convert("RGB")
|
|
|
|
def parameter_set(opt):
    """Write the fixed restoration-model settings onto ``opt`` in place.

    The shared settings are always applied. Then, depending on the mode flags
    already present on ``opt``:

    * ``Quality_restore`` selects ``mapping_quality`` with the quality VAEs.
    * ``Scratch_and_Quality_restore`` enables the non-local / mask settings and
      selects ``mapping_scratch`` (quality VAE A, scratch VAE B). If both mode
      flags are set, these values are applied last and therefore win.
    * ``HR``, consulted only in scratch mode, additionally sets
      ``mapping_exp``, ``inference_optimize``, ``mask_dilation`` and switches
      the name to ``mapping_Patch_Attention``.
    """
    shared_settings = {
        "serial_batches": True,
        "no_flip": True,
        "label_nc": 0,
        "n_downsample_global": 3,
        "mc": 64,
        "k_size": 4,
        "start_r": 1,
        "mapping_n_block": 6,
        "map_mc": 512,
        "no_instance": True,
        "checkpoints_dir": "./checkpoints/restoration",
    }
    for attr, value in shared_settings.items():
        setattr(opt, attr, value)
    root = opt.checkpoints_dir

    if opt.Quality_restore:
        opt.name = "mapping_quality"
        opt.load_pretrainA = os.path.join(root, "VAE_A_quality")
        opt.load_pretrainB = os.path.join(root, "VAE_B_quality")

    if not opt.Scratch_and_Quality_restore:
        return

    scratch_settings = {
        "NL_res": True,
        "use_SN": True,
        "correlation_renormalize": True,
        "NL_use_mask": True,
        "NL_fusion_method": "combine",
        "non_local": "Setting_42",
        "name": "mapping_scratch",
        "load_pretrainA": os.path.join(root, "VAE_A_quality"),
        "load_pretrainB": os.path.join(root, "VAE_B_scratch"),
    }
    for attr, value in scratch_settings.items():
        setattr(opt, attr, value)

    # NOTE(review): HR is treated as a refinement of scratch mode, because
    # mask_dilation is only used on the NL_use_mask path, which only this mode
    # enables. The pasted source lost its indentation; confirm this nesting.
    if opt.HR:
        opt.mapping_exp = 1
        opt.inference_optimize = True
        opt.mask_dilation = 3
        opt.name = "mapping_Patch_Attention"
|
|
|
|
if __name__ == "__main__":
    # Parse command-line options, then overlay the fixed model settings.
    opt = TestOptions().parse(save=False)
    parameter_set(opt)

    model = Pix2PixHDModel_Mapping()
    model.initialize(opt)
    model.eval()

    # Output layout: <outputs_dir>/{input_image,restored_image,origin}/<name>
    for sub_dir in ("input_image", "restored_image", "origin"):
        os.makedirs(os.path.join(opt.outputs_dir, sub_dir), exist_ok=True)

    input_loader = sorted(os.listdir(opt.test_input))

    # Masks are paired with inputs by position: after sorting both listings,
    # the i-th mask entry belongs to the i-th input entry.
    mask_loader = []
    if opt.test_mask != "":
        mask_loader = sorted(os.listdir(opt.test_mask))
    elif opt.NL_use_mask:
        # Without this check the loop below would fail with a NameError on the
        # first image.
        raise SystemExit(
            "NL_use_mask is set but no mask directory was given (test_mask is empty)"
        )

    img_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    mask_transform = transforms.ToTensor()
    dilation_kernel = np.ones((3, 3), np.uint8)

    # Iterate over the inputs (not the masks), so that a mask directory of a
    # different length neither drops inputs nor indexes past the input list.
    for i, input_name in enumerate(input_loader):
        input_file = os.path.join(opt.test_input, input_name)
        if not os.path.isfile(input_file):
            print("Skipping non-file %s" % input_name)
            continue
        input_image = Image.open(input_file).convert("RGB")

        print("Now you are processing %s" % (input_name))

        if opt.NL_use_mask:
            if i >= len(mask_loader):
                print("Skip %s: no matching mask in %s" % (input_name, opt.test_mask))
                continue
            mask = Image.open(os.path.join(opt.test_mask, mask_loader[i])).convert("RGB")
            if opt.mask_dilation != 0:
                dilated = cv2.dilate(
                    np.array(mask), dilation_kernel, iterations=opt.mask_dilation
                )
                mask = Image.fromarray(dilated.astype("uint8"))
            # Keep the unmasked image for the "origin" output.
            origin = input_image
            input_image = irregular_hole_synthesize(input_image, mask)
            # Single-channel mask with a leading batch dimension.
            mask = mask_transform(mask)[:1, :, :].unsqueeze(0)
            input_tensor = img_transform(input_image).unsqueeze(0)
        else:
            if opt.test_mode == "Scale":
                input_image = data_transforms(input_image, scale=True)
            elif opt.test_mode == "Full":
                input_image = data_transforms(input_image, scale=False)
            elif opt.test_mode == "Crop":
                input_image = data_transforms_rgb_old(input_image)
            origin = input_image
            input_tensor = img_transform(input_image).unsqueeze(0)
            mask = torch.zeros_like(input_tensor)

        try:
            with torch.no_grad():
                generated = model.inference(input_tensor, mask)
        except Exception as ex:
            # Best effort: one failing image should not abort the whole run.
            print("Skip %s due to an error:\n%s" % (input_name, str(ex)))
            continue

        # ".jpg" inputs are written with a ".png" name; the writers below pick
        # the file format from the extension.
        if input_name.endswith(".jpg"):
            input_name = input_name[:-4] + ".png"

        # NOTE: normalize=True makes save_image min-max rescale the tensor, so
        # the saved image is stretched to the full [0, 1] range.
        vutils.save_image(
            (input_tensor + 1.0) / 2.0,
            os.path.join(opt.outputs_dir, "input_image", input_name),
            nrow=1,
            padding=0,
            normalize=True,
        )
        vutils.save_image(
            (generated.detach().cpu() + 1.0) / 2.0,
            os.path.join(opt.outputs_dir, "restored_image", input_name),
            nrow=1,
            padding=0,
            normalize=True,
        )

        origin.save(os.path.join(opt.outputs_dir, "origin", input_name))