Spaces:
Runtime error
Runtime error
| from gui.ui_win import Ui_Form | |
| from gui.ui_draw import * | |
| from PIL import Image, ImageQt | |
| import numpy as np | |
| import random, io, os | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| from util import task, util | |
| from dataloader.image_folder import make_dataset | |
| from dataloader.data_loader import get_transform | |
| from model import create_model | |
class ui_model(QtWidgets.QWidget, Ui_Form):
    """Interactive image-completion (inpainting) window.

    The user picks a pretrained model, loads an image, chooses or draws a
    mask, runs the network and can browse/save the original, masked and
    completed images.

    Mask convention after ``set_input``: 1 marks a kept pixel, 0 marks a hole
    that the network fills in.
    """
    shape = 'line'  # current drawing shape; copied onto the painter panel
    CurrentWidth = 1  # current brush thickness; copied onto the painter panel

    def __init__(self, opt):
        super(ui_model, self).__init__()
        self.setupUi(self)
        self.opt = opt
        self.show_result_flag = False  # True once fill_image has produced results for the current image
        self.mask_type = None  # None/'none' (no mask yet), 'draw', 'center' or 'external'
        self.img_power = None  # preprocessed PIL image fed to the network
        self.model_names = ['celeba', 'ffhq', 'imagenet', 'places2']
        self.img_root = './examples/'
        self.img_files = ['celeba/img', 'ffhq/img', 'imagenet/img', 'places2/img']
        self.show_logo()
        self.comboBox.activated.connect(self.load_model)  # select model
        self.pushButton_2.clicked.connect(self.select_image)  # manually select an image
        self.pushButton_3.clicked.connect(self.random_image)  # randomly select an image
        self.pushButton_4.clicked.connect(self.load_mask)  # manually select a mask
        self.pushButton_5.clicked.connect(self.random_mask)  # randomly select a mask
        # draw/erase the mask. ``toggled`` is emitted for the button being
        # unchecked as well as the one being checked, so only the newly
        # checked button may change the mask mode.
        self.radioButton.toggled.connect(lambda checked: self._shape_toggled('line', checked))
        self.radioButton_2.toggled.connect(lambda checked: self._shape_toggled('rectangle', checked))
        self.radioButton_3.toggled.connect(lambda checked: self._shape_toggled('center', checked))
        self.spinBox.valueChanged.connect(self.change_thickness)
        self.pushButton.clicked.connect(self.clear_mask)
        # fill image
        self.pushButton_6.clicked.connect(self.fill_image)
        self.comboBox_2.activated.connect(self.show_result)
        self.pushButton_7.clicked.connect(self.save_result)
        # NOTE: this mutates the caller's ``opt`` so the GUI always uses this preprocessing
        opt.preprocess = 'scale_shortside'
        # returns a PIL image (load_image unpacks its ``.size``)
        self.transform_o = get_transform(opt, convert=False, augment=False)
        self.pil2tensor = transforms.ToTensor()

    def _warn(self, text):
        """Show a warning dialog.

        Used instead of raising: an exception escaping a PyQt5 slot aborts
        the whole application.
        """
        QtWidgets.QMessageBox.warning(self, 'Warning', text)

    def _panel(self):
        """Return the current painter panel, or None before any image has been shown."""
        return getattr(self, 'PaintPanel', None)

    def _shape_toggled(self, masktype, checked):
        """Forward a radio-button toggle to ``draw_mask``, ignoring the un-check event."""
        if checked:
            self.draw_mask(masktype)

    def _center_box(self):
        """Return the (left, top, right, bottom) pixel box of the centre mask.

        The box covers the central half of the preprocessed image in each
        dimension. ``show_image`` and ``set_input`` both use it, so the
        displayed box and the network mask always agree.
        """
        return (int(self.pw / 4), int(self.ph / 4), int(3 * self.pw / 4), int(3 * self.ph / 4))

    def show_logo(self):
        """Show the NTU and BTC logos at fixed positions near the top of the window."""
        logos = (
            ("./gui/logo/NTU_logo.jpg", (1000, 10, 140, 50), 140),
            ("./gui/logo/BTC_logo.png", (1200, 10, 70, 50), 70),
        )
        for path, geometry, size in logos:
            label = QtWidgets.QLabel(self)
            label.setGeometry(*geometry)
            pixmap = QtGui.QPixmap(path).scaled(size, size, QtCore.Qt.KeepAspectRatio,
                                                QtCore.Qt.SmoothTransformation)
            label.setPixmap(pixmap)
            label.show()

    def show_image(self, img):
        """Display ``img`` in a new painter panel with the selected mask overlaid.

        'center' paints the centre box white. 'external' composites the mask
        file over the image. Any other mask type shows the image unchanged.
        """
        show_img = img.copy()
        if self.mask_type == 'center':
            show_img.paste((255, 255, 255), box=self._center_box())
        elif self.mask_type == 'external':
            with Image.open(self.mname) as raw_mask:
                resized = raw_mask.resize(self.img_power.size)
            mask = resized.convert('RGB')
            mask_L = resized.convert('L')
            show_img = Image.composite(mask, show_img, mask_L)
        self.new_painter(ImageQt.ImageQt(show_img))

    def show_result(self):
        """Show the view selected in comboBox_2.

        The views are: 0 original, 1 masked input, 2 completion (refined
        when enabled).
        """
        value = self.comboBox_2.currentIndex()
        if value == 0:
            if self.img_power is None:
                return
            shown = self.img_power
        else:
            if not self.show_result_flag:
                self._warn('Please fill the image first')
                return
            if value == 1:
                # holes are shown in white
                tensor = torch.where(self.mask > 0, self.img_m, torch.ones_like(self.img_m))
            elif value == 2:
                tensor = self.img_ref_out if 'refine' in self.opt.coarse_or_refine else self.img_out
            else:
                return
            shown = Image.fromarray(util.tensor2im(tensor.detach()))
        self.new_painter(ImageQt.ImageQt(shown))

    def save_result(self):
        """Write the original, mask, masked input, raw output and completion(s) to ``opt.results_dir``."""
        if not self.show_result_flag:
            self._warn('Please fill the image first')
            return
        util.mkdir(self.opt.results_dir)
        img_name = os.path.basename(self.fname)
        data_name = os.path.basename(self.opt.img_file).split('.')[0]
        iteration = self.PaintPanel.iteration
        outputs = [
            # save the original image
            ('%s_%s_%s' % ('original', data_name, img_name), self.img_truth),
            # holes are saved in white
            ('%s_%s_%d_%s' % ('mask', data_name, iteration, img_name), 1 - self.mask.repeat(1, 3, 1, 1)),
            # masked input: holes filled with white
            ('%s_%s_%d_%s' % ('masked_img', data_name, iteration, img_name),
             torch.where(self.mask < 0.2, torch.ones_like(self.img_truth), self.img_truth)),
            # raw generator output at the fixed network resolution
            ('%s_%s_%d_%s' % ('g', data_name, iteration, img_name), self.img_g),
            # completion blended into the known pixels
            ('%s_%s_%d_%s' % ('out', data_name, iteration, img_name), self.img_out),
        ]
        # save the refined results
        if 'tc' in self.opt.model and 'refine' in self.opt.coarse_or_refine:
            outputs.append(('%s_%s_%d_%s' % ('ref', data_name, iteration, img_name), self.img_ref_out))
        for file_name, tensor in outputs:
            util.save_image(util.tensor2im(tensor), os.path.join(self.opt.results_dir, file_name))

    def load_model(self):
        """Load the pretrained model chosen in comboBox (index 0 is the 'choose a model' placeholder)."""
        value = self.comboBox.currentIndex()
        if value == 0:
            self._warn('Please choose a model')
            return
        index = value - 1  # define the model type and dataset type
        self.opt.name = self.model_names[index]
        self.opt.img_file = self.img_root + self.img_files[index % len(self.img_files)]
        self.model = create_model(self.opt)
        self.model.setup(self.opt)

    def load_image(self, fname):
        """Open ``fname`` as RGB, preprocess it and return the network-sized PIL image.

        Also records the original size (``ow``/``oh``) and the preprocessed
        size (``pw``/``ph``).
        """
        with Image.open(fname) as raw:
            self.img_o = raw.convert('RGB')
        self.ow, self.oh = self.img_o.size
        self.img_power = self.transform_o(self.img_o)
        self.pw, self.ph = self.img_power.size
        # results computed for a previous image no longer apply
        self.show_result_flag = False
        return self.img_power

    def select_image(self):
        """Let the user pick an image file and show it without a mask."""
        fname, _ = QtWidgets.QFileDialog.getOpenFileName(self, 'select the image', self.opt.img_file, '*')
        if not fname:  # dialog cancelled
            return
        self.fname = fname
        img = self.load_image(self.fname)
        self.mask_type = 'none'
        self.show_image(img)

    def random_image(self):
        """Load and show a randomly chosen image from ``opt.img_file``."""
        image_paths, image_size = make_dataset(self.opt.img_file)
        if image_size == 0:
            self._warn('No image found in %s' % self.opt.img_file)
            return
        self.fname = image_paths[random.randint(0, image_size - 1)]
        img = self.load_image(self.fname)
        self.mask_type = 'none'
        self.show_image(img)

    def load_mask(self):
        """Let the user pick a mask file and overlay it on the current image."""
        if self.img_power is None:
            self._warn('Please load an image first')
            return
        mname, _ = QtWidgets.QFileDialog.getOpenFileName(self, 'select the mask', self.opt.mask_file, '*')
        if not mname:  # dialog cancelled: keep the previous mask
            return
        self.mname = mname
        self.mask_type = 'external'
        self.show_image(self.img_power)

    def random_mask(self):
        """Overlay a randomly chosen mask from ``opt.mask_file`` on the current image."""
        if self.opt.mask_file == 'none':
            self._warn('Please input the mask path')
            return
        if self.img_power is None:
            self._warn('Please load an image first')
            return
        mask_paths, mask_size = make_dataset(self.opt.mask_file)
        if mask_size == 0:
            self._warn('No mask found in %s' % self.opt.mask_file)
            return
        self.mname = mask_paths[random.randint(0, mask_size - 1)]
        self.mask_type = 'external'
        self.show_image(self.img_power)

    def read_mask(self):
        """Return the mask painted on the panel, decoded from a PNG snapshot of ``PaintPanel.map``."""
        self.PaintPanel.saveDraw()
        buffer = QtCore.QBuffer()
        buffer.open(QtCore.QBuffer.ReadWrite)
        try:
            self.PaintPanel.map.save(buffer, 'PNG')
            # BytesIO copies the bytes, so the buffer can be closed afterwards
            png = io.BytesIO(buffer.data())
        finally:
            buffer.close()
        return Image.open(png)

    def new_painter(self, image=None):
        """Replace the drawing panel with a fresh painter showing ``image`` (a QImage)."""
        old_panel = self._panel()
        self.PaintPanel = painter(self, image)
        self.PaintPanel.close()
        if image is not None:
            w, h = image.size().width(), image.size().height()
            # offsets are the layout constants of this window
            self.stackedWidget.setGeometry(QtCore.QRect(250 + int(512 - w / 2), 100 + int(128 - h / 8), w, h))
        self.stackedWidget.insertWidget(0, self.PaintPanel)
        self.stackedWidget.setCurrentWidget(self.PaintPanel)
        if old_panel is not None:
            # drop the previous panel so panels do not pile up in the stack
            self.stackedWidget.removeWidget(old_panel)
            old_panel.deleteLater()

    def change_thickness(self, num):
        """Set the brush width, also on the painter panel if one exists."""
        self.CurrentWidth = num
        panel = self._panel()
        if panel is not None:
            panel.CurrentWidth = num

    def draw_mask(self, masktype):
        """Switch the mask mode.

        'center' selects the fixed centre box. Any other value ('line',
        'rectangle') selects free drawing with that shape.
        """
        if masktype == 'center':
            self.mask_type = 'center'
            if self.img_power is not None:
                self.show_image(self.img_power)
            return
        leaving_center = self.mask_type == 'center'
        self.mask_type = 'draw'
        self.shape = masktype
        if leaving_center and self.img_power is not None:
            # repaint the clean image: the old centre box is not part of a drawn mask
            self.show_image(self.img_power)
        panel = self._panel()
        if panel is not None:
            panel.shape = masktype

    def clear_mask(self):
        """Switch to drawing mode and toggle the painter's ``Brush`` flag.

        NOTE(review): presumably this switches between brush and eraser;
        confirm in gui.ui_draw.
        """
        self.mask_type = 'draw'
        panel = self._panel()
        if panel is not None:
            panel.Brush = not panel.Brush

    def set_input(self):
        """Build the network inputs from the current image and mask.

        Sets:
            ``self.mask``: (1, 1, H, W); 1 = keep, 0 = hole.
            ``self.img_truth`` / ``self.img_org``: the preprocessed and the
                original image, scaled to [-1, 1].
            ``self.img_m``: ``img_truth`` with holes set to 0.

        Raises:
            ValueError: if ``self.mask_type`` is not 'draw', 'center' or 'external'.
        """
        img_o = self.pil2tensor(self.img_o).unsqueeze(0)
        img = self.pil2tensor(self.img_power).unsqueeze(0)
        if self.mask_type == 'draw':
            # get the test mask from painter (first channel only)
            mask = self.pil2tensor(self.read_mask()).unsqueeze(0)[:, 0:1, :, :]
        elif self.mask_type == 'center':
            mask = torch.zeros_like(img)[:, 0:1, :, :]
            left, top, right, bottom = self._center_box()
            # tensors are (N, C, H, W): rows span the height, columns the width
            mask[:, :, top:bottom, left:right] = 1
        elif self.mask_type == 'external':
            with Image.open(self.mname) as raw_mask:
                mask = self.pil2tensor(raw_mask.resize((self.pw, self.ph)).convert('L')).unsqueeze(0)
        else:
            raise ValueError('unsupported mask type: %r' % (self.mask_type,))
        # values >= 0.5 mark holes in every branch above; flip to 1 = keep, 0 = hole
        mask = (mask < 0.5).float()
        if len(self.opt.gpu_ids) > 0:
            device = self.opt.gpu_ids[0]
            img = img.cuda(device)
            mask = mask.cuda(device)
            img_o = img_o.cuda(device)
        self.mask = mask
        self.img_org = img_o * 2 - 1
        self.img_truth = img * 2 - 1
        self.img_m = self.mask * self.img_truth

    def fill_image(self):
        """Run the model on the current image and mask, then display the result."""
        if not hasattr(self, 'model'):
            self._warn('Please choose a model first')
            return
        if self.img_power is None or self._panel() is None:
            self._warn('Please load an image first')
            return
        if self.mask_type not in ('center', 'external'):
            # no box/file mask chosen: whatever is painted on the panel is the mask
            self.mask_type = 'draw'
        self.set_input()
        if self.PaintPanel.iteration < 1:
            fixed = [self.opt.fixed_size, self.opt.fixed_size]
            with torch.no_grad():
                # the netE -> netT -> netG chain runs at opt.fixed_size
                fixed_img = F.interpolate(self.img_m, size=fixed, mode='bicubic', align_corners=True).clamp(-1, 1)
                # bicubic resampling smears mask edges; only pixels staying above 0.9 count as kept
                fixed_mask = (F.interpolate(self.mask, size=fixed, mode='bicubic', align_corners=True) > 0.9).type_as(fixed_img)
                out, mask = self.model.netE(fixed_img, mask=fixed_mask, return_mask=True)
                out = self.model.netT(out, mask, bool_mask=False)
                self.img_g = self.model.netG(out)
                img_g_org = F.interpolate(self.img_g, size=self.img_truth.size()[2:], mode='bicubic', align_corners=True).clamp(-1, 1)
                # keep known pixels; take generated content only inside the holes
                self.img_out = self.mask * self.img_truth + (1 - self.mask) * img_g_org
                if 'refine' in self.opt.coarse_or_refine:
                    img_ref = self.model.netG_Ref(self.img_out, mask=self.mask)
                    self.img_ref_out = self.mask * self.img_truth + (1 - self.mask) * img_ref
            print('finish the completion')
            self.show_result_flag = True
        if self.show_result_flag:
            self.show_result()