Spaces:
Runtime error
Runtime error
| import os | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| from swapae.evaluation import BaseEvaluator | |
| import swapae.util as util | |
class SwapVisualizationEvaluator(BaseEvaluator):
    """Evaluator that renders "swap" grids and writes them to an HTML page.

    For each group of gathered images the model encodes every image into two
    codes (called ``sp`` and ``gl`` in this file), decodes every (sp, gl)
    pairing, and lays the results out in a square grid whose first row and
    first column show the original images.

    NOTE(review): ``sp``/``gl`` are presumably the spatial (structure) and
    global (texture) codes of the autoencoder -- confirm against the model.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register the swap-visualization options on ``parser``.

        Declared as a static method because it takes neither ``self`` nor
        ``cls``; without the decorator it could only be called on the class,
        never on an instance.

        Args:
            parser: object exposing ``add_argument`` (argparse-style).
            is_train: unused here; kept for interface compatibility.

        Returns:
            The same ``parser`` with the extra options added.
        """
        parser.add_argument("--swap_num_columns", type=int, default=4,
                            help="number of images to be shown in the swap visualization grid. Setting this value will result in 4x4 swapping grid, with additional row and col for showing original images.")
        parser.add_argument("--swap_num_images", type=int, default=16,
                            help="total number of images to perform swapping. In the end, (swap_num_images / swap_num_columns) grid will be saved to disk")
        return parser

    def gather_images(self, dataset):
        """Collect enough single-image tensors from ``dataset`` for one grid.

        Args:
            dataset: an iterator consumed with ``next``; each item is a dict
                with a "real_A" tensor and optionally a "real_B" tensor, both
                batched along dimension 0.

        Returns:
            A pair ``(images, exhausted)``:
              * ``images`` -- list of tensors, each sliced to batch size 1,
                or ``None`` when no image could be read at all;
              * ``exhausted`` -- ``True`` when the iterator ran out (always
                ``True`` when ``images`` is ``None``).
        """
        target_count = max(self.opt.swap_num_columns, self.opt.num_gpus)
        collected = []
        exhausted = False
        while len(collected) < target_count:
            try:
                data = next(dataset)
            except StopIteration:
                print("Exhausted the dataset at %s" % (self.opt.dataroot))
                exhausted = True
                break
            for i in range(data["real_A"].size(0)):
                # Slice with i:i+1 so the batch dimension (size 1) is kept.
                collected.append(data["real_A"][i:i + 1])
                if "real_B" in data:
                    collected.append(data["real_B"][i:i + 1])
                if len(collected) >= target_count:
                    break
        if not collected:
            # Must be a 2-tuple: evaluate() unpacks exactly two values. The
            # previous 3-tuple raised ValueError on an exhausted dataset.
            return None, True
        return collected, exhausted

    def generate_mix_grid(self, model, images):
        """Build the swap grid for ``images`` and return it as a PIL image.

        Args:
            model: callable supporting ``command="encode"`` (returns the two
                codes for a batch) and ``command="decode"`` (takes the two
                codes and returns an image batch).
            images: list of tensors, each with batch size 1.

        Returns:
            A ``PIL.Image`` holding an (N+1) x (N+1) grid of cells, each
            ``opt.load_size`` pixels wide and tall, where N is ``len(images)``.
            Row 0 and column 0 hold the originals; cell (i+1, j+1) holds the
            decode of image i's ``sp`` code with image j's ``gl`` code.
        """
        sps, gls = [], []
        for image in images:
            assert image.size(0) == 1
            # The single image is replicated num_gpus times along the batch
            # dimension and only the first result is kept.
            # NOTE(review): presumably so a data-parallel model has one sample
            # per GPU -- confirm.
            sp, gl = model(image.expand(self.opt.num_gpus, -1, -1, -1), command="encode")
            sps.append(sp[:1])
            gls.append(gl[:1])
        gl = torch.cat(gls, dim=0)

        cell = self.opt.load_size

        def put_img(img, canvas, row, col):
            """Paste ``img`` (H x W x C array) centred in grid cell (row, col)."""
            h, w = img.shape[0], img.shape[1]
            start_x = int(cell * col + (cell - w) * 0.5)
            start_y = int(cell * row + (cell - h) * 0.5)
            canvas[start_y:start_y + h, start_x: start_x + w] = img

        # One extra row and column for the original images.
        grid_side = cell * (gl.size(0) + 1)
        # NOTE(review): np.ones with uint8 fills the canvas with value 1
        # (near-black), not white; kept unchanged to preserve existing output.
        grid_img = np.ones((grid_side, grid_side, 3), dtype=np.uint8)

        # Header row and header column: the untouched input images.
        for i, image in enumerate(images):
            image_np = util.tensor2im(image, tile=False)[0]
            put_img(image_np, grid_img, 0, i + 1)
            put_img(image_np, grid_img, i + 1, 0)

        # Row i+1: image i's sp code combined with every gl code in turn.
        for i, sp in enumerate(sps):
            sp_for_current_row = sp.repeat(gl.size(0), 1, 1, 1)
            mix_row = model(sp_for_current_row, gl, command="decode")
            mix_row = util.tensor2im(mix_row, tile=False)
            for j, mix in enumerate(mix_row):
                put_img(mix, grid_img, i + 1, j + 1)

        return Image.fromarray(grid_img)

    def evaluate(self, model, dataset, nsteps):
        """Generate swap grids and save them to an HTML page on disk.

        Args:
            model: passed through to ``generate_mix_grid``.
            dataset: iterator passed through to ``gather_images``.
            nsteps: training step count used to label the output, or ``None``
                to use ``opt.resume_iter`` as the label instead.

        Returns:
            An empty dict (this evaluator reports no scalar metrics).
        """
        if nsteps is None:
            nsteps = self.opt.resume_iter
        else:
            # e.g. 25000 -> "25k"
            nsteps = str(round(nsteps / 1000)) + "k"

        savedir = os.path.join(self.output_dir(), "%s_%s" % (self.target_phase, nsteps))
        os.makedirs(savedir, exist_ok=True)
        webpage_title = "Swap Visualization of %s. iter=%s. phase=%s" % \
                        (self.opt.name, str(nsteps), self.target_phase)
        webpage = util.HTML(savedir, webpage_title)

        # Each grid uses max(swap_num_columns, num_gpus) images, so this many
        # grids are needed to cover swap_num_images.
        images_per_grid = max(self.opt.swap_num_columns, self.opt.num_gpus)
        num_repeats = int(np.ceil(self.opt.swap_num_images / images_per_grid))
        for i in range(num_repeats):
            images, should_break = self.gather_images(dataset)
            if images is None:
                break
            mix_grid = self.generate_mix_grid(model, images)
            webpage.add_images([mix_grid], ["%04d.png" % i])
            if should_break:
                # Dataset ran out; the grid from the partial batch is saved.
                break
        webpage.save()
        return {}