# Inference script for the pSp (pixel2style2pixel) encoder.
import os
import sys
import time
from argparse import Namespace

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

# Make the repository root importable before pulling in project-local packages.
sys.path.append(".")
sys.path.append("..")

from configs import data_configs
from datasets.inference_dataset import InferenceDataset
from models.psp import pSp
from options.test_options import TestOptions
from utils.common import tensor2im, log_input_image
def run():
    """Run pSp inference over a dataset and save the resulting images.

    Loads the checkpoint named by ``--checkpoint_path``, rebuilds the options
    stored inside it (overridden by the current test options), then feeds every
    image under ``--data_path`` through the network. Outputs are written to
    ``<exp_dir>/inference_results`` and, when requested, side-by-side
    comparisons to ``<exp_dir>/inference_coupled``. Average per-batch runtime
    is written to ``<exp_dir>/stats.txt``.
    """
    test_opts = TestOptions().parse()

    if test_opts.resize_factors is not None:
        # Super-resolution inference supports exactly one downsampling factor.
        assert len(test_opts.resize_factors.split(',')) == 1, \
            "When running inference, provide a single downsampling factor!"
        out_path_results = os.path.join(test_opts.exp_dir, 'inference_results',
                                        'downsampling_{}'.format(test_opts.resize_factors))
        out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled',
                                        'downsampling_{}'.format(test_opts.resize_factors))
    else:
        out_path_results = os.path.join(test_opts.exp_dir, 'inference_results')
        out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled')

    os.makedirs(out_path_results, exist_ok=True)
    os.makedirs(out_path_coupled, exist_ok=True)

    # Update the options saved in the checkpoint with the test-time options.
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    # Backward compatibility with checkpoints saved before these keys existed.
    if 'learn_in_w' not in opts:
        opts['learn_in_w'] = False
    if 'output_size' not in opts:
        opts['output_size'] = 1024
    opts = Namespace(**opts)

    net = pSp(opts)
    net.eval()
    net.cuda()

    print('Loading dataset for {}'.format(opts.dataset_type))
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    dataset = InferenceDataset(root=opts.data_path,
                               transform=transforms_dict['transform_inference'],
                               opts=opts)
    # NOTE(review): drop_last=True silently skips a trailing partial batch, so
    # up to test_batch_size - 1 images are never processed — confirm intended.
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=True)

    if opts.n_images is None:
        opts.n_images = len(dataset)

    global_i = 0
    global_time = []
    for input_batch in tqdm(dataloader):
        if global_i >= opts.n_images:
            break
        with torch.no_grad():
            input_cuda = input_batch.cuda().float()
            tic = time.time()
            result_batch = run_on_batch(input_cuda, net, opts)
            toc = time.time()
            global_time.append(toc - tic)

        # Iterate over the actual batch length rather than opts.test_batch_size:
        # the two only coincide for full batches, and indexing past the real
        # batch would raise IndexError on a partial one.
        for i in range(result_batch.shape[0]):
            result = tensor2im(result_batch[i])
            im_path = dataset.paths[global_i]

            # Periodically (and when --couple_outputs is set) save a
            # side-by-side comparison of input and output.
            if opts.couple_outputs or global_i % 100 == 0:
                input_im = log_input_image(input_batch[i], opts)
                resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size)
                if opts.resize_factors is not None:
                    # Super-resolution: save the original, down-sampled, and output.
                    source = Image.open(im_path)
                    res = np.concatenate([np.array(source.resize(resize_amount)),
                                          np.array(input_im.resize(resize_amount, resample=Image.NEAREST)),
                                          np.array(result.resize(resize_amount))], axis=1)
                else:
                    # Otherwise save just the input and the output.
                    res = np.concatenate([np.array(input_im.resize(resize_amount)),
                                          np.array(result.resize(resize_amount))], axis=1)
                Image.fromarray(res).save(os.path.join(out_path_coupled, os.path.basename(im_path)))

            im_save_path = os.path.join(out_path_results, os.path.basename(im_path))
            Image.fromarray(np.array(result)).save(im_save_path)
            global_i += 1

    stats_path = os.path.join(opts.exp_dir, 'stats.txt')
    # Guard the empty case: np.mean([]) / np.std([]) warn and produce nan.
    if global_time:
        result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time))
    else:
        result_str = 'Runtime N/A (no batches were processed)'
    print(result_str)
    with open(stats_path, 'w') as f:
        f.write(result_str)
def run_on_batch(inputs, net, opts):
    """Encode a batch of images through the pSp network.

    When ``opts.latent_mask`` is unset, the whole batch goes through the
    network in a single call. Otherwise each image is processed individually
    with a freshly sampled random style vector injected at the layer indices
    listed (comma-separated) in ``opts.latent_mask``, mixed with weight
    ``opts.mix_alpha``. Returns the batched output tensor.
    """
    if opts.latent_mask is None:
        return net(inputs, randomize_noise=False, resize=opts.resize_outputs)

    latent_mask = [int(layer) for layer in opts.latent_mask.split(",")]
    outputs = []
    for single_input in inputs:
        # Sample a random latent code whose style will be injected.
        random_code = np.random.randn(1, 512).astype('float32')
        _, latent_to_inject = net(torch.from_numpy(random_code).to("cuda"),
                                  input_code=True,
                                  return_latents=True)
        # Produce the output image with the injected style vector.
        output = net(single_input.unsqueeze(0).to("cuda").float(),
                     latent_mask=latent_mask,
                     inject_latent=latent_to_inject,
                     alpha=opts.mix_alpha,
                     resize=opts.resize_outputs)
        outputs.append(output)
    return torch.cat(outputs, dim=0)
# Script entry point.
if __name__ == '__main__':
    run()