Spaces:
Configuration error
Configuration error
| import os | |
| from tqdm import tqdm | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from logger import Logger, Visualizer | |
| import numpy as np | |
| import imageio | |
| from sync_batchnorm import DataParallelWithCallback | |
def reconstruction(config, generator, kp_detector, checkpoint, log_dir, dataset):
    """Reconstruct each dataset video from its first frame and report the mean L1 loss.

    For every video, keypoints are detected on frame 0 (the source) and on each
    frame (the driving frame); the generator then re-renders every driving frame
    from the source image. Per video, the predicted frames are written side by
    side as one PNG under ``<log_dir>/reconstruction/png`` and the per-frame
    visualizations are written as an animation under ``<log_dir>/reconstruction``.
    Finally the mean absolute error between prediction and driving frame, over
    all processed frames, is printed.

    Args:
        config: Parsed config dict. Reads ``config['reconstruction_params']``
            (``'num_videos'``: max number of videos or None for all; ``'format'``:
            file extension appended to the video name) and
            ``config['visualizer_params']`` (kwargs for ``Visualizer``).
        generator: Module called as ``generator(source, kp_source=..., kp_driving=...)``
            and returning a dict containing a ``'prediction'`` tensor.
        kp_detector: Module applied to a single frame to get its keypoints.
        checkpoint: Checkpoint loaded into both modules via ``Logger.load_cpk``;
            required.
        log_dir: Root output directory.
        dataset: Dataset whose items are dicts with ``'video'`` (frames along
            axis 2) and ``'name'``.

    Raises:
        AttributeError: If ``checkpoint`` is None.
    """
    png_dir = os.path.join(log_dir, 'reconstruction/png')
    log_dir = os.path.join(log_dir, 'reconstruction')

    if checkpoint is None:
        raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
    Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    # png_dir is nested inside log_dir, so this creates both directories.
    os.makedirs(png_dir, exist_ok=True)

    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    num_videos = config['reconstruction_params']['num_videos']
    video_format = config['reconstruction_params']['format']
    # The visualizer configuration is identical for every frame, so build it once.
    visualizer = Visualizer(**config['visualizer_params'])

    loss_list = []
    # Wrap the dataloader itself so tqdm knows the total length.
    for it, x in enumerate(tqdm(dataloader)):
        # Process exactly `num_videos` videos (a `>` test would run one extra).
        if num_videos is not None and it >= num_videos:
            break
        with torch.no_grad():
            predictions = []
            visualizations = []
            if torch.cuda.is_available():
                x['video'] = x['video'].cuda()

            # Frame 0 is the fixed source image; its keypoints are reused for every frame.
            source = x['video'][:, :, 0]
            kp_source = kp_detector(source)

            for frame_idx in range(x['video'].shape[2]):
                driving = x['video'][:, :, frame_idx]
                kp_driving = kp_detector(driving)
                out = generator(source, kp_source=kp_source, kp_driving=kp_driving)
                out['kp_source'] = kp_source
                out['kp_driving'] = kp_driving
                # Drop 'sparse_deformed' before visualization; pop() avoids a
                # KeyError when the generator does not return it.
                out.pop('sparse_deformed', None)

                # batch_size=1, so [0] selects the single video; CHW -> HWC for image output.
                predictions.append(out['prediction'][0].cpu().numpy().transpose(1, 2, 0))
                visualizations.append(
                    visualizer.visualize(source=source, driving=driving, out=out))
                loss_list.append(torch.abs(out['prediction'] - driving).mean().item())

        # Lay the predicted frames side by side along the width axis.
        predictions = np.concatenate(predictions, axis=1)
        # Clip before scaling so out-of-range values cannot wrap around in uint8.
        png_image = (255 * np.clip(predictions, 0, 1)).astype(np.uint8)
        imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), png_image)

        image_name = x['name'][0] + video_format
        imageio.mimsave(os.path.join(log_dir, image_name), visualizations)

    # np.mean([]) warns and yields nan; report nan explicitly when nothing was processed.
    mean_loss = np.mean(loss_list) if loss_list else float('nan')
    print("Reconstruction loss: %s" % mean_loss)