| | import os
|
| | import torch
|
| | import torch.nn as nn
|
| | from torch.utils.data import DataLoader
|
| | from types import SimpleNamespace
|
| | from deepfillv2 import test_dataset, utils
|
| | from config import *
|
| |
|
class InpaintingTester:
    """Inpaint images with a pretrained generator from ``deepfillv2``.

    The generator is built from the hyperparameters in ``config``, its weights
    are loaded from ``DEEPFILL_MODEL_PATH``, and it is moved to ``GPU_DEVICE``.
    """

    def __init__(self, save_path, resize_to=None):
        """Build the generator, load its weights and move it to ``GPU_DEVICE``.

        Args:
            save_path: Directory that ``process_multiple_images`` writes its
                results into.
            resize_to: Size argument forwarded to
                ``test_dataset.InpaintDataset``. ``None`` means use
                ``RESIZE_TO`` from ``config``.
        """
        if resize_to is None:
            resize_to = RESIZE_TO
        self.save_path = save_path
        self.setsize = resize_to

        # create_generator reads its settings as attributes, so the config
        # constants are bundled into an attribute-style namespace.
        opt = SimpleNamespace(
            pad_type=PAD_TYPE,
            in_channels=IN_CHANNELS,
            out_channels=OUT_CHANNELS,
            latent_channels=LATENT_CHANNELS,
            activation=ACTIVATION,
            norm=NORM,
            init_type=INIT_TYPE,
            init_gain=INIT_GAIN,
            use_cuda=CUDA,
            gpu_device=GPU_DEVICE,
        )

        # Eval mode up front: the generator is only used for no-grad inference
        # in this class.
        self.generator = utils.create_generator(opt).eval()
        self.load_model_generator(self.generator)
        self.generator = self.generator.to(GPU_DEVICE)

    def load_model_generator(self, generator):
        """Load the checkpoint at ``DEEPFILL_MODEL_PATH`` into *generator* in place.

        Tensors are mapped onto ``GPU_DEVICE`` while loading. ``weights_only=True``
        restricts unpickling to tensors and plain containers rather than
        arbitrary pickled objects. ``load_state_dict`` is called with its
        default (strict) key matching.
        """
        pretrained_dict = torch.load(
            DEEPFILL_MODEL_PATH,
            map_location=torch.device(GPU_DEVICE),
            weights_only=True,
        )
        generator.load_state_dict(pretrained_dict)

    def process_image(self, in_image, mask_image, save_image_path):
        """Inpaint one image/mask pair and save the generator's second output.

        Args:
            in_image: Image source handed to ``test_dataset.InpaintDataset``
                (presumably a file path -- confirm against that class).
            mask_image: Mask source handed to ``InpaintDataset`` together with
                *in_image*.
            save_image_path: Target path. Its directory is the output folder
                (created if missing) and its basename is the sample name passed
                to ``utils.save_sample_png``, which decides the final filename.
        """
        dataset = test_dataset.InpaintDataset(in_image, mask_image, self.setsize)
        # NOTE(review): this starts 8 worker processes on every call although the
        # dataset appears to hold a single sample; num_workers=0 is likely cheaper.
        dataloader = DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=8,
            pin_memory=True,
        )

        results_path = os.path.dirname(save_image_path)
        sample_name = os.path.basename(save_image_path)
        # dirname() is "" for a bare filename and os.makedirs("") raises;
        # exist_ok=True also removes the exists()/makedirs() race.
        if results_path:
            os.makedirs(results_path, exist_ok=True)

        for img, mask in dataloader:
            img = img.to(GPU_DEVICE)
            mask = mask.to(GPU_DEVICE)

            with torch.no_grad():
                # The generator returns two outputs; only the second is saved.
                _, second_out = self.generator(img, mask)

            # Keep the original pixels where mask is 0 and the generator output
            # where mask is 1 (assumes mask values in [0, 1] with 1 marking the
            # hole -- confirm against InpaintDataset).
            second_out_wholeimg = img * (1 - mask) + second_out * mask

            utils.save_sample_png(
                sample_folder=results_path,
                sample_name=sample_name,
                img_list=[second_out_wholeimg],
                name_list=["second_out"],
                pixel_max_cnt=255,
            )

    @staticmethod
    def _png_path(path):
        """Return *path* with its final extension (if any) replaced by ``.png``."""
        return os.path.splitext(path)[0] + ".png"

    def process_multiple_images(self, image_mask_pairs):
        """Inpaint each ``(image_path, mask_path)`` pair, continuing past failures.

        Each result is saved under ``self.save_path`` using the image's basename.

        Returns:
            A list aligned with *image_mask_pairs*. Each entry is the expected
            output path ``<save_path>/<image stem>.png``, or ``None`` if that
            pair raised. NOTE(review): the ``.png`` name mirrors what
            ``utils.save_sample_png`` is assumed to write -- verify there.
        """
        png_images = []
        for img_path, mask_path in image_mask_pairs:
            try:
                save_image_path = os.path.join(
                    self.save_path, os.path.basename(img_path)
                )
                print(f"Processing: {img_path} and {mask_path}")
                self.process_image(img_path, mask_path, save_image_path)
                # splitext only touches the final suffix; str.replace would also
                # rewrite matching text in directory names, and with an empty
                # extension would insert ".png" between every character.
                png_path = self._png_path(save_image_path)
            except Exception as e:
                # Best-effort batch: record the failure and move on.
                png_images.append(None)
                print(f"Error: {e}")
            else:
                png_images.append(png_path)
        return png_images
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|