Spaces:
Paused
Paused
| from models import create_model | |
| from util.get_transform import get_transform | |
| from util.util import tensor2im | |
| from PIL import Image | |
| import os | |
# Folder holding the model checkpoints: a `checkpoints` directory located
# next to this source file.
ckp_path = os.path.join(os.path.split(__file__)[0], 'checkpoints')
class Options:
    """Plain attribute container built from mappings and keyword arguments.

    Every key of each positional mapping becomes an attribute (set with
    ``setattr``), in argument order; keyword arguments are applied last, so
    later sources override earlier ones.
    """

    def __init__(self, *initial_data, **kwargs):
        for mapping in (*initial_data, kwargs):
            for name in mapping:
                setattr(self, name, mapping[name])
class ModelLoader:
    """Build the image-translation model on demand and run single-image inference.

    The options below are a fixed, test-time configuration passed to
    ``create_model``; the model itself is only created on the first call to
    ``inference()`` (or an explicit ``load()``).
    """

    def __init__(self, gpu_ids='', max_img_wh=512) -> None:
        """Prepare the inference options and the input transform.

        Args:
            gpu_ids: comma-separated GPU ids, e.g. ``'0'`` or ``'0,1'``.
                Empty string means CPU; negative ids are dropped, so ``'-1'``
                also means CPU.
            max_img_wh: maximum image width/height, forwarded to the
                preprocessing as ``yarflam_img_wh``.
        """
        self.opt = Options({
            'isGradio': True,  # Custom
            'name': 'original',  # Checkpoints name
            'checkpoints_dir': ckp_path,  # Checkpoint folder
            'gpu_ids': self._parse_gpu_ids(gpu_ids),
            'init_gain': 0.02,  # Scaling Factor
            'init_type': 'xavier',  # options: 'normal', 'xavier', 'kaiming', 'orthogonal'
            'input_nc': 3,  # 3 -> RGB, 1 -> Grayscale
            'output_nc': 3,
            'isTrain': False,
            'model': 'cwr',
            'nce_idt': False,
            'nce_layers': '0',
            'ndf': 64,  # Nb of discrim filters in the first conv layer
            'netD': 'basic',
            'netG': 'resnet_9blocks',
            'netF': 'mlp_sample',
            'netF_nc': 256,
            'ngf': 64,  # Nb of gen filters in the last conv layer
            'no_antialias_up': False,
            'no_antialias': False,
            'no_dropout': True,
            'normD': 'instance',
            'normG': 'instance',
            'preprocess': 'yarflam_auto',  # see more: util.get_transform
            'dataroot': 'placeholder',
            'num_threads': 1,  # test code only supports num_threads = 1
            'batch_size': 1,  # test code only supports batch_size = 1
            # Data-loader options: no dataset is built by this class.
            'serial_batches': False,
            'no_flip': True,  # no flip
            'display_id': -1,  # no visdom display
            'direction': 'AtoB',  # inference
            'flip_equivariance': False,
            'load_size': 1680,  # not used
            'crop_size': 512,  # not used
            'yarflam_img_wh': max_img_wh,  # max width|height + auto scale down
        })
        self.transform = get_transform(self.opt, grayscale=False)
        self.model = None  # created lazily by load()

    @staticmethod
    def _parse_gpu_ids(gpu_ids):
        """Turn ``'0,1'`` into ``[0, 1]``; empty parts and negative ids are dropped."""
        # NOTE(review): integer ids are what torch device / DataParallel APIs
        # accept; string ids such as '0' are not valid torch devices.
        if not gpu_ids:
            return []
        ids = [int(part) for part in gpu_ids.split(',') if part.strip()]
        return [i for i in ids if i >= 0]

    def load(self) -> None:
        """Create the model from ``self.opt`` and load the 'latest' weights."""
        self.model = create_model(self.opt)
        self.model.load_networks('latest')

    def _prepare_input(self, src, image_pil):
        """Return the transformed input tensor with a leading batch dim of 1.

        Uses ``image_pil`` when it is a PIL image, otherwise opens ``src``.

        Raises:
            FileNotFoundError: if ``src`` is needed but is not an existing file.
        """
        if isinstance(image_pil, Image.Image):
            source = image_pil.convert('RGB')
        else:
            if not os.path.isfile(src):
                raise FileNotFoundError('The image %s is not found!' % src)
            print('Loading the image %s' % src)
            # Close the file handle; convert() returns an independent, loaded
            # image that remains valid after the file is closed.
            with Image.open(src) as opened:
                source = opened.convert('RGB')
        return self.transform(source).unsqueeze(0)

    def inference(self, src='', image_pil=None):
        """Translate one image and return the result as a PIL image.

        Args:
            src: path to an image file, used when ``image_pil`` is not a PIL
                image. It is also forwarded as ``A_paths``/``B_paths``.
            image_pil: optional in-memory ``PIL.Image.Image``; takes
                precedence over ``src``.

        Returns:
            A ``PIL.Image.Image`` built from the second entry of the model's
            ``get_current_visuals()``.

        Raises:
            FileNotFoundError: if ``image_pil`` is not a PIL image and ``src``
                is not an existing file (a subclass of ``Exception``).
        """
        import torch  # local import: only needed for no_grad below

        if self.model is None:
            self.load()
        img = self._prepare_input(src, image_pil)
        # The same tensor is supplied for both domains; presumably set_input
        # requires both keys even though only A->B is used — TODO confirm.
        self.model.set_input({
            'A': img, 'A_paths': src,
            'B': img, 'B_paths': src,
        })
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            self.model.forward()
        # Positional pick of the second visual as the generated output.
        # NOTE(review): depends on the model's visual ordering — confirm.
        out_data = list(self.model.get_current_visuals().items())[1][1]
        return Image.fromarray(tensor2im(out_data))