Spaces:
Runtime error
Runtime error
| import os | |
| from argparse import Namespace | |
| import torchvision | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader | |
| import sys | |
| import time | |
| from tqdm import tqdm | |
| from Project.mapper.training.train_utils import convert_s_tensor_to_list | |
| sys.path.append(".") | |
| sys.path.append("..") | |
| from Project.mapper.datasets.latents_dataset import LatentsDataset, StyleSpaceLatentsDataset | |
| from Project.mapper.options.test_options import TestOptions | |
| from Project.mapper.styleclip_mapper import StyleCLIPMapper | |
| def run(test_opts): | |
| out_path_results = os.path.join(test_opts.exp_dir, 'inference_results') | |
| os.makedirs(out_path_results, exist_ok=True) | |
| # update test options with options used during training | |
| ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu') | |
| opts = ckpt['opts'] | |
| opts.update(vars(test_opts)) | |
| opts = Namespace(**opts) | |
| net = StyleCLIPMapper(opts) | |
| net.eval() | |
| test_latents = torch.load(opts.latents_test_path) | |
| if opts.work_in_stylespace: | |
| dataset = StyleSpaceLatentsDataset(latents=[l.cpu() for l in test_latents], opts=opts) | |
| else: | |
| dataset = LatentsDataset(latents=test_latents, opts=opts) | |
| dataloader = DataLoader(dataset, | |
| batch_size=opts.test_batch_size, | |
| shuffle=False, | |
| num_workers=int(opts.test_workers), | |
| drop_last=True) | |
| if opts.n_images is None: | |
| opts.n_images = len(dataset) | |
| global_i = 0 | |
| global_time = [] | |
| for input_batch in tqdm(dataloader): | |
| if global_i >= opts.n_images: | |
| break | |
| with torch.no_grad(): | |
| if opts.work_in_stylespace: | |
| input_cuda = convert_s_tensor_to_list(input_batch) | |
| input_cuda = [c for c in input_cuda] | |
| else: | |
| input_cuda = input_batch | |
| input_cuda = input_cuda | |
| tic = time.time() | |
| result_batch = run_on_batch(input_cuda, net, opts.couple_outputs, opts.work_in_stylespace) | |
| toc = time.time() | |
| global_time.append(toc - tic) | |
| for i in range(opts.test_batch_size): | |
| im_path = str(global_i).zfill(5) | |
| if test_opts.couple_outputs: | |
| couple_output = torch.cat([result_batch[2][i].unsqueeze(0), result_batch[0][i].unsqueeze(0)]) | |
| torchvision.utils.save_image(couple_output, os.path.join(out_path_results, f"{im_path}.jpg"), normalize=True, value_range=(-1, 1)) | |
| else: | |
| torchvision.utils.save_image(result_batch[0][i], os.path.join(out_path_results, f"{im_path}.jpg"), normalize=True, value_range=(-1, 1)) | |
| torch.save(result_batch[1][i].detach().cpu(), os.path.join(out_path_results, f"latent_{im_path}.pt")) | |
| global_i += 1 | |
| stats_path = os.path.join(opts.exp_dir, 'stats.txt') | |
| result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time)) | |
| print(result_str) | |
| with open(stats_path, 'w') as f: | |
| f.write(result_str) | |
| def run_on_batch(inputs, net, couple_outputs=False, stylespace=False): | |
| w = inputs | |
| with torch.no_grad(): | |
| if stylespace: | |
| delta = net.mapper(w) | |
| w_hat = [c + 0.1 * delta_c for (c, delta_c) in zip(w, delta)] | |
| x_hat, _, w_hat = net.decoder([w_hat], input_is_latent=True, return_latents=True, | |
| randomize_noise=False, truncation=1, input_is_stylespace=True) | |
| else: | |
| w_hat = w + 0.1 * net.mapper(w) | |
| x_hat, w_hat, _ = net.decoder([w_hat], input_is_latent=True, return_latents=True, | |
| randomize_noise=False, truncation=1) | |
| result_batch = (x_hat, w_hat) | |
| if couple_outputs: | |
| x, _ = net.decoder([w], input_is_latent=True, randomize_noise=False, truncation=1, input_is_stylespace=stylespace) | |
| result_batch = (x_hat, w_hat, x) | |
| return result_batch | |
| if __name__ == '__main__': | |
| test_opts = TestOptions().parse() | |
| run(test_opts) | |