Spaces:
Runtime error
Runtime error
| """ | |
| @Author: Nicolo' Bonettini | |
| @Author: Luca Bondi | |
| @Author: Francesco Picetti | |
| """ | |
| import random | |
| import numpy as np | |
| from skimage.util import view_as_windows, view_as_blocks | |
| # Score functions --- | |
| def mid_intensity_high_texture(in_content): | |
| """ | |
| Quality function that returns higher scores for mid intensity patches with high texture levels. Empirical. | |
| :type in_content: ndarray | |
| :param in_content : 2D or 3D ndarray. Values are expected in [0,1] if in_content is float, in [0,255] if in_content is uint8 | |
| :return score: float | |
| score in [0,1]. | |
| """ | |
| if in_content.dtype == np.uint8: | |
| in_content = in_content / 255. | |
| mean_std_weight = .7 | |
| in_content = in_content.flatten() | |
| mean_val = in_content.mean() | |
| std_val = in_content.std() | |
| ch_mean_score = -4 * mean_val ** 2 + 4 * mean_val | |
| ch_std_score = 1 - np.exp(-2 * np.log(10) * std_val) | |
| score = mean_std_weight * ch_mean_score + (1 - mean_std_weight) * ch_std_score | |
| return score | |
| def count_patches(in_size, patch_size, patch_stride): | |
| """ | |
| Compute the number of patches | |
| :param in_size: | |
| :param patch_size: | |
| :param patch_stride: | |
| :return: | |
| """ | |
| win_indices_shape = (((np.array(in_size) - np.array(patch_size)) | |
| // np.array(patch_stride)) + 1) | |
| return int(np.prod(win_indices_shape)) | |
| class PatchExtractor: | |
| def __init__(self, dim, offset=None, stride=None, rand=None, function=None, threshold=None, | |
| num=None, indexes=None): | |
| """ | |
| N-dimensional patch extractor | |
| Args: | |
| :param in_content : ndarray | |
| the content to process as a numpy array of ndim dimensions | |
| :param dim : tuple | |
| patch_array dimensions as a tuple of ndim elements | |
| Named args: | |
| :param offset : tuple | |
| the offsets along each axis as a tuple of ndim elements | |
| :param stride : tuple | |
| the stride of each axis as a tuple of ndim elements | |
| :param rand : bool | |
| randomize patch_array order. Mutually exclusive with function_handler | |
| :param function : function | |
| patch quality function handler. Mutually exclusive with rand | |
| :param threshold: float | |
| minimum quality threshold | |
| :param num : int | |
| maximum number of returned patch_array. Mutually exclusive with indexes | |
| :param indexes : list|ndarray | |
| explicitly return corresponding patch indexes (function_handler or C order used to index patch_array). | |
| Mutually exclusive with num | |
| :return ndarray: patch_array | |
| array of patch_array | |
| if rand==False and function_handler==None and num==None and indexes==None: | |
| patch_array.ndim = 2 * in_content.ndim | |
| else: | |
| patch_array.ndim = 1 + in_content.ndim | |
| """ | |
| # Arguments parser --- | |
| if not isinstance(dim, tuple): | |
| raise ValueError('dim must be a tuple') | |
| self.dim = dim | |
| ndim = len(dim) | |
| self.ndim = ndim | |
| if offset is None: | |
| offset = tuple([0] * ndim) | |
| if not isinstance(offset, tuple): | |
| raise ValueError('offset must be a tuple') | |
| if len(offset) != ndim: | |
| raise ValueError('offset must a tuple of length {:d}'.format(ndim)) | |
| self.offset = offset | |
| if stride is None: | |
| stride = dim | |
| if not isinstance(stride, tuple): | |
| raise ValueError('stride must be a tuple') | |
| if len(stride) != ndim: | |
| raise ValueError('stride must a tuple of length {:d}'.format(ndim)) | |
| self.stride = stride | |
| if rand is not None and function is not None: | |
| raise ValueError('rand and function cannot be set at the same time') | |
| if rand is None: | |
| rand = False | |
| if not isinstance(rand, bool): | |
| raise ValueError('rand must be a boolean') | |
| self.rand = rand | |
| if function is not None and not callable(function): | |
| raise ValueError('function must be a function handler') | |
| self.function_handler = function | |
| if threshold is None: | |
| threshold = 0.0 | |
| if not isinstance(threshold, float): | |
| raise ValueError('threshold must be a float') | |
| self.threshold = threshold | |
| if num is not None and indexes is not None: | |
| raise ValueError('num and indexes cannot be set at the same time') | |
| if num is not None and not isinstance(num, int): | |
| raise ValueError('num must be an int') | |
| self.num = num | |
| if indexes is not None and not isinstance(indexes, list) and not isinstance(indexes, np.ndarray): | |
| raise ValueError('indexes must be an list or a 1d ndarray') | |
| if indexes is not None: | |
| indexes = np.array(indexes).flatten() | |
| self.indexes = indexes | |
| self.in_content_original_shape = None | |
| self.in_content_cropped_shape = None | |
| def extract(self, in_content): | |
| if not isinstance(in_content, np.ndarray): | |
| raise ValueError('in_content must be of type: ' + str(np.ndarray)) | |
| if in_content.ndim != self.ndim: | |
| raise ValueError('in_content shape must a tuple of length {:d}'.format(self.ndim)) | |
| self.in_content_original_shape = in_content.shape | |
| # Offset --- | |
| for dim_idx, dim_offset in enumerate(self.offset): | |
| dim_max = in_content.shape[dim_idx] | |
| in_content = in_content.take(range(dim_offset, dim_max), axis=dim_idx) | |
| # Patch list --- | |
| if self.dim == self.stride: | |
| in_content_crop = in_content | |
| for dim_idx in range(self.ndim): | |
| dim_max = (in_content.shape[dim_idx] // self.dim[dim_idx]) * self.dim[dim_idx] | |
| in_content_crop = in_content_crop.take(range(0, dim_max), axis=dim_idx) | |
| patch_array = view_as_blocks(in_content_crop, self.dim) | |
| else: | |
| patch_array = view_as_windows(in_content, self.dim, self.stride) | |
| patch_array = np.ascontiguousarray(patch_array) | |
| patch_idx = patch_array.shape[:self.ndim] | |
| self.in_content_cropped_shape = tuple((np.asarray(patch_idx) - 1) * np.asarray(self.stride) + np.asarray(self.dim)) | |
| # Evaluate patch_array or rand sort --- | |
| if self.rand: | |
| patch_array.shape = (-1,) + self.dim | |
| random.shuffle(patch_array) | |
| else: | |
| if self.function_handler is not None: | |
| patch_array.shape = (-1,) + self.dim | |
| patch_scores = np.asarray(list(map(self.function_handler, patch_array))) | |
| sort_idxs = np.argsort(patch_scores)[::-1] | |
| patch_scores = patch_scores[sort_idxs] | |
| patch_array = patch_array[sort_idxs] | |
| patch_array = patch_array[patch_scores >= self.threshold] | |
| if self.num is not None: | |
| patch_array.shape = (-1,) + self.dim | |
| patch_array = patch_array[:self.num] | |
| if self.indexes is not None: | |
| patch_array.shape = (-1,) + self.dim | |
| patch_array = patch_array[self.indexes] | |
| return patch_array | |
| def extract_call(self, args): # TODO: verify | |
| in_content = args.pop('in_content') | |
| dim = args.pop('dim') | |
| return self.extract(in_content) | |
| def reconstruct(self, patch_array): | |
| """ | |
| Reconstruct the N-dim image from the patch_array that has been extracted previously | |
| :param patch_array: array of patches as output of patch_extractor | |
| :return: | |
| """ | |
| # Arguments parser --- | |
| if not isinstance(patch_array, np.ndarray): | |
| raise ValueError('patch_array must be of type: ' + str(np.ndarray)) | |
| ndim = patch_array.ndim // 2 | |
| # if not isinstance(patch_stride, tuple): | |
| # raise ValueError('patch_stride must be a tuple') | |
| # if len(patch_stride) != ndim: | |
| # raise ValueError('patch_stride must be a tuple of length {:d}'.format(ndim)) | |
| # | |
| # if not isinstance(image_shape, tuple): | |
| # raise ValueError('patch_idx must be a tuple') | |
| # if len(image_shape) != ndim: | |
| # raise ValueError('patch_idx must be a tuple of length {:d}'.format(ndim)) | |
| patch_stride = self.stride | |
| image_shape = self.in_content_cropped_shape | |
| patch_shape = patch_array.shape[-ndim:] | |
| patch_idx = patch_array.shape[:ndim] | |
| image_shape_computed = tuple((np.array(patch_idx) - 1) * np.array(patch_stride) + np.array(patch_shape)) | |
| if not image_shape == image_shape_computed: | |
| raise ValueError('There is something wrong with the dimensions!') | |
| if ndim > 4: | |
| raise ValueError('For now, it works only in 4D, sorry!') | |
| numpatches = count_patches(image_shape, patch_shape, patch_stride) | |
| patch_array_unwrapped = patch_array.reshape(numpatches, *patch_shape) | |
| image_recon = np.zeros(image_shape) | |
| norm_mask = np.zeros(image_shape) | |
| counter = 0 | |
| for h in np.arange(0, image_shape[0] - patch_shape[0] + 1, patch_stride[0]): | |
| if ndim > 1: | |
| for i in np.arange(0, image_shape[1] - patch_shape[1] + 1, patch_stride[1]): | |
| if ndim > 2: | |
| for j in np.arange(0, image_shape[2] - patch_shape[2] + 1, patch_stride[2]): | |
| if ndim > 3: | |
| for k in np.arange(0, image_shape[3] - patch_shape[3] + 1, patch_stride[3]): | |
| image_recon[h:h + patch_shape[0], i:i + patch_shape[1], j:j + patch_shape[2], | |
| k:k + patch_shape[3]] += patch_array_unwrapped[counter, :, :, :, :] | |
| norm_mask[h:h + patch_shape[0], i:i + patch_shape[1], j:j + patch_shape[2], | |
| k:k + patch_shape[3]] += 1 | |
| counter += 1 | |
| else: | |
| image_recon[h:h + patch_shape[0], i:i + patch_shape[1], | |
| j:j + patch_shape[2]] += patch_array_unwrapped[counter, :, :, :] | |
| norm_mask[h:h + patch_shape[0], i:i + patch_shape[1], j:j + patch_shape[2]] += 1 | |
| counter += 1 | |
| else: | |
| image_recon[h:h + patch_shape[0], i:i + patch_shape[1]] += patch_array_unwrapped[counter, :, :] | |
| norm_mask[h:h + patch_shape[0], i:i + patch_shape[1]] += 1 | |
| counter += 1 | |
| else: | |
| image_recon[h:h + patch_shape[0]] += patch_array_unwrapped[counter, :] | |
| norm_mask[h:h + patch_shape[0]] += 1 | |
| counter += 1 | |
| image_recon /= norm_mask | |
| return image_recon | |
| def main(): | |
| in_shape = (644, 481, 3) | |
| dim = (120, 120, 3) | |
| stride = (7, 90, 90, 3) | |
| offset = (1, 0, 0, 0) | |
| in_content = np.random.randint(256, size=in_shape).astype(np.uint8) | |
| # args = {'in_content': in_content, | |
| # 'dim': dim, | |
| # 'offset': offset, | |
| # 'stride': stride, | |
| # } | |
| # patch_array = patch_extractor_call(args) | |
| pe = PatchExtractor(dim) | |
| patch_array = pe.extract(in_content) | |
| print('patch_array.shape = ' + str(patch_array.shape)) | |
| img_recon = pe.reconstruct(patch_array) | |
| print('img_recon.shape = ' + str(img_recon.shape)) | |
| if __name__ == "__main__": | |
| main() | |