Spaces:
Runtime error
Runtime error
| # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # This work is licensed under the Creative Commons Attribution-NonCommercial | |
| # 4.0 International License. To view a copy of this license, visit | |
| # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to | |
| # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. | |
| import numpy as np | |
| import scipy.ndimage | |
| #---------------------------------------------------------------------------- | |
| def get_descriptors_for_minibatch(minibatch, nhood_size, nhoods_per_image): | |
| S = minibatch.shape # (minibatch, channel, height, width) | |
| assert len(S) == 4 and S[1] == 3 | |
| N = nhoods_per_image * S[0] | |
| H = nhood_size // 2 | |
| nhood, chan, x, y = np.ogrid[0:N, 0:3, -H:H+1, -H:H+1] | |
| img = nhood // nhoods_per_image | |
| x = x + np.random.randint(H, S[3] - H, size=(N, 1, 1, 1)) | |
| y = y + np.random.randint(H, S[2] - H, size=(N, 1, 1, 1)) | |
| idx = ((img * S[1] + chan) * S[2] + y) * S[3] + x | |
| return minibatch.flat[idx] | |
| #---------------------------------------------------------------------------- | |
| def finalize_descriptors(desc): | |
| if isinstance(desc, list): | |
| desc = np.concatenate(desc, axis=0) | |
| assert desc.ndim == 4 # (neighborhood, channel, height, width) | |
| desc -= np.mean(desc, axis=(0, 2, 3), keepdims=True) | |
| desc /= np.std(desc, axis=(0, 2, 3), keepdims=True) | |
| desc = desc.reshape(desc.shape[0], -1) | |
| return desc | |
| #---------------------------------------------------------------------------- | |
| def sliced_wasserstein(A, B, dir_repeats, dirs_per_repeat): | |
| assert A.ndim == 2 and A.shape == B.shape # (neighborhood, descriptor_component) | |
| results = [] | |
| for repeat in range(dir_repeats): | |
| dirs = np.random.randn(A.shape[1], dirs_per_repeat) # (descriptor_component, direction) | |
| dirs /= np.sqrt(np.sum(np.square(dirs), axis=0, keepdims=True)) # normalize descriptor components for each direction | |
| dirs = dirs.astype(np.float32) | |
| projA = np.matmul(A, dirs) # (neighborhood, direction) | |
| projB = np.matmul(B, dirs) | |
| projA = np.sort(projA, axis=0) # sort neighborhood projections for each direction | |
| projB = np.sort(projB, axis=0) | |
| dists = np.abs(projA - projB) # pointwise wasserstein distances | |
| results.append(np.mean(dists)) # average over neighborhoods and directions | |
| return np.mean(results) # average over repeats | |
| #---------------------------------------------------------------------------- | |
| def downscale_minibatch(minibatch, lod): | |
| if lod == 0: | |
| return minibatch | |
| t = minibatch.astype(np.float32) | |
| for i in range(lod): | |
| t = (t[:, :, 0::2, 0::2] + t[:, :, 0::2, 1::2] + t[:, :, 1::2, 0::2] + t[:, :, 1::2, 1::2]) * 0.25 | |
| return np.round(t).clip(0, 255).astype(np.uint8) | |
| #---------------------------------------------------------------------------- | |
| gaussian_filter = np.float32([ | |
| [1, 4, 6, 4, 1], | |
| [4, 16, 24, 16, 4], | |
| [6, 24, 36, 24, 6], | |
| [4, 16, 24, 16, 4], | |
| [1, 4, 6, 4, 1]]) / 256.0 | |
| def pyr_down(minibatch): # matches cv2.pyrDown() | |
| assert minibatch.ndim == 4 | |
| return scipy.ndimage.convolve(minibatch, gaussian_filter[np.newaxis, np.newaxis, :, :], mode='mirror')[:, :, ::2, ::2] | |
| def pyr_up(minibatch): # matches cv2.pyrUp() | |
| assert minibatch.ndim == 4 | |
| S = minibatch.shape | |
| res = np.zeros((S[0], S[1], S[2] * 2, S[3] * 2), minibatch.dtype) | |
| res[:, :, ::2, ::2] = minibatch | |
| return scipy.ndimage.convolve(res, gaussian_filter[np.newaxis, np.newaxis, :, :] * 4.0, mode='mirror') | |
| def generate_laplacian_pyramid(minibatch, num_levels): | |
| pyramid = [np.float32(minibatch)] | |
| for i in range(1, num_levels): | |
| pyramid.append(pyr_down(pyramid[-1])) | |
| pyramid[-2] -= pyr_up(pyramid[-1]) | |
| return pyramid | |
| def reconstruct_laplacian_pyramid(pyramid): | |
| minibatch = pyramid[-1] | |
| for level in pyramid[-2::-1]: | |
| minibatch = pyr_up(minibatch) + level | |
| return minibatch | |
| #---------------------------------------------------------------------------- | |
| class API: | |
| def __init__(self, num_images, image_shape, image_dtype, minibatch_size): | |
| self.nhood_size = 7 | |
| self.nhoods_per_image = 128 | |
| self.dir_repeats = 4 | |
| self.dirs_per_repeat = 128 | |
| self.resolutions = [] | |
| res = image_shape[1] | |
| while res >= 16: | |
| self.resolutions.append(res) | |
| res //= 2 | |
| def get_metric_names(self): | |
| return ['SWDx1e3_%d' % res for res in self.resolutions] + ['SWDx1e3_avg'] | |
| def get_metric_formatting(self): | |
| return ['%-13.4f'] * len(self.get_metric_names()) | |
| def begin(self, mode): | |
| assert mode in ['warmup', 'reals', 'fakes'] | |
| self.descriptors = [[] for res in self.resolutions] | |
| def feed(self, mode, minibatch): | |
| for lod, level in enumerate(generate_laplacian_pyramid(minibatch, len(self.resolutions))): | |
| desc = get_descriptors_for_minibatch(level, self.nhood_size, self.nhoods_per_image) | |
| self.descriptors[lod].append(desc) | |
| def end(self, mode): | |
| desc = [finalize_descriptors(d) for d in self.descriptors] | |
| del self.descriptors | |
| if mode in ['warmup', 'reals']: | |
| self.desc_real = desc | |
| dist = [sliced_wasserstein(dreal, dfake, self.dir_repeats, self.dirs_per_repeat) for dreal, dfake in zip(self.desc_real, desc)] | |
| del desc | |
| dist = [d * 1e3 for d in dist] # multiply by 10^3 | |
| return dist + [np.mean(dist)] | |
| #---------------------------------------------------------------------------- | |