| """Image Tiler.""" | |
| # Copyright (C) 2020 Intel Corporation | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, | |
| # software distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions | |
| # and limitations under the License. | |
| from itertools import product | |
| from math import ceil | |
| from typing import Optional, Sequence, Tuple, Union | |
| import torch | |
| import torchvision.transforms as T | |
| from torch import Tensor | |
| from torch.nn import functional as F | |


class StrideSizeError(Exception):
    """Raised when the stride size is greater than the tile size."""


def compute_new_image_size(image_size: Tuple, tile_size: Tuple, stride: Tuple) -> Tuple:
    """This function checks if image size is divisible by tile size and stride.

    If not divisible, it resizes the image size to make it divisible.

    Args:
        image_size (Tuple): Original image size
        tile_size (Tuple): Tile size
        stride (Tuple): Stride

    Examples:
        >>> compute_new_image_size(image_size=(512, 512), tile_size=(256, 256), stride=(128, 128))
        (512, 512)

        >>> compute_new_image_size(image_size=(512, 512), tile_size=(222, 222), stride=(111, 111))
        (555, 555)
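
        Each resized edge follows ``ceil((edge - tile) / stride) * stride + tile``;
        for the second example above, ceil((512 - 222) / 111) * 111 + 222
        = 3 * 111 + 222 = 555.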

    Returns:
        Tuple: Updated image size that is divisible by tile size and stride.
    """

    def __compute_new_edge_size(edge_size: int, tile_size: int, stride: int) -> int:
        """This function makes the resizing within the edge level."""
        if (edge_size - tile_size) % stride != 0:
            edge_size = (ceil((edge_size - tile_size) / stride) * stride) + tile_size

        return edge_size

    resized_h = __compute_new_edge_size(image_size[0], tile_size[0], stride[0])
    resized_w = __compute_new_edge_size(image_size[1], tile_size[1], stride[1])

    return resized_h, resized_w


def upscale_image(image: Tensor, size: Tuple, mode: str = "padding") -> Tensor:
    """Upscale image to the desired size via either padding or interpolation.

    Args:
        image (Tensor): Image
        size (Tuple): Tuple to which image is upscaled.
        mode (str, optional): Upscaling mode. Defaults to "padding".

    Examples:
        >>> image = torch.rand(1, 3, 512, 512)
        >>> image = upscale_image(image, size=(555, 555), mode="padding")
        >>> image.shape
        torch.Size([1, 3, 555, 555])

        >>> image = torch.rand(1, 3, 512, 512)
        >>> image = upscale_image(image, size=(555, 555), mode="interpolation")
        >>> image.shape
        torch.Size([1, 3, 555, 555])
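
        In "padding" mode the original content stays at the top-left and zeros fill
        the bottom/right margin; a quick illustrative check (an addition, not from
        the original docstring):

        >>> x = torch.rand(1, 3, 512, 512)
        >>> padded = upscale_image(x, size=(555, 555), mode="padding")
        >>> torch.equal(padded[..., :512, :512], x)
        True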

    Returns:
        Tensor: Upscaled image.
    """
    image_h, image_w = image.shape[2:]
    resize_h, resize_w = size

    if mode == "padding":
        pad_h = resize_h - image_h
        pad_w = resize_w - image_w

        image = F.pad(image, [0, pad_w, 0, pad_h])
    elif mode == "interpolation":
        image = F.interpolate(input=image, size=(resize_h, resize_w))
    else:
        raise ValueError(f"Unknown mode {mode}. Only padding and interpolation are available.")

    return image


def downscale_image(image: Tensor, size: Tuple, mode: str = "padding") -> Tensor:
    """Opposite of upscaling. This function downscales the image to the desired size.

    Args:
        image (Tensor): Input image
        size (Tuple): Size to which image is downscaled.
        mode (str, optional): Downscaling mode. Defaults to "padding".

    Examples:
        >>> x = torch.rand(1, 3, 512, 512)
        >>> y = upscale_image(x, size=(555, 555), mode="padding")
        >>> y = downscale_image(y, size=(512, 512), mode="padding")
        >>> torch.allclose(x, y)
        True

    Returns:
        Tensor: Downscaled image
    """
    input_h, input_w = size

    if mode == "padding":
        # padding-based upscaling is undone by cropping back to the original size
        image = image[:, :, :input_h, :input_w]
    else:
        image = F.interpolate(input=image, size=(input_h, input_w))

    return image


class Tiler:
    """Tile Image into (non)overlapping Patches. Images are tiled in order to efficiently process large images.

    Args:
        tile_size: Tile dimension for each patch
        stride: Stride length between patches
        remove_border_count: Number of border pixels to be removed from tile before untiling
        mode: Upscaling mode for image resize. Supported formats: padding, interpolation
        tile_count: Number of random tiles to crop when random tiling is used. Defaults to 4.

    Examples:
        >>> import torch
        >>> from torchvision import transforms
        >>> from skimage.data import camera
        >>> tiler = Tiler(tile_size=256, stride=128)
        >>> image = transforms.ToTensor()(camera())
        >>> tiles = tiler.tile(image)
        >>> image.shape, tiles.shape
        (torch.Size([3, 512, 512]), torch.Size([9, 3, 256, 256]))

        >>> # Perform your operations on the tiles.

        >>> # Untile the patches to reconstruct the image
        >>> reconstructed_image = tiler.untile(tiles)
        >>> reconstructed_image.shape
        torch.Size([1, 3, 512, 512])
    """

    def __init__(
        self,
        tile_size: Union[int, Sequence],
        stride: Union[int, Sequence],
        remove_border_count: int = 0,
        mode: str = "padding",
        tile_count: int = 4,
    ) -> None:
        self.tile_size_h, self.tile_size_w = self.__validate_size_type(tile_size)
        self.tile_count = tile_count
        self.stride_h, self.stride_w = self.__validate_size_type(stride)
        self.remove_border_count = int(remove_border_count)
        self.overlapping = not (self.stride_h == self.tile_size_h and self.stride_w == self.tile_size_w)
        self.mode = mode

        if self.stride_h > self.tile_size_h or self.stride_w > self.tile_size_w:
            raise StrideSizeError(
                "Larger stride size than kernel size produces unreliable tiling results. "
                "Please ensure stride size is less than or equal to the tile size."
            )

        if self.mode not in ["padding", "interpolation"]:
            raise ValueError(f"Unknown tiling mode {self.mode}. Available modes are padding and interpolation")

        self.batch_size: int
        self.num_channels: int

        self.input_h: int
        self.input_w: int

        self.pad_h: int
        self.pad_w: int

        self.resized_h: int
        self.resized_w: int

        self.num_patches_h: int
        self.num_patches_w: int

    @staticmethod
    def __validate_size_type(parameter: Union[int, Sequence]) -> Tuple[int, ...]:
        """Unpack a tile size or stride into a (height, width) tuple."""
        if isinstance(parameter, int):
            output = (parameter, parameter)
        elif isinstance(parameter, Sequence):
            output = (parameter[0], parameter[1])
        else:
            raise ValueError(f"Unknown type {type(parameter)} for tile or stride size. Could be int or Sequence type.")

        if len(output) != 2:
            raise ValueError(f"Length of the size type must be 2 for height and width. Got {len(output)} instead.")

        return output

    def __random_tile(self, image: Tensor) -> Tensor:
        """Randomly crop tiles from the given image.

        Args:
            image: input image to be cropped

        Returns: Randomly cropped tiles from the image
        """
        # Note: ``T.RandomCrop`` with an int size crops square tiles of
        # (tile_size_h, tile_size_h), so tile_size_w is not used here.
        return torch.vstack([T.RandomCrop(self.tile_size_h)(image) for _ in range(self.tile_count)])

    def __unfold(self, tensor: Tensor) -> Tensor:
        """Unfolds tensor into tiles.

        This is the core function to perform tiling operation.

        Args:
            tensor: Input tensor from which tiles are generated.

        Returns: Generated tiles
        """
        # identify device type based on input tensor
        device = tensor.device

        # extract and calculate parameters
        batch, channels, image_h, image_w = tensor.shape

        self.num_patches_h = int((image_h - self.tile_size_h) / self.stride_h) + 1
        self.num_patches_w = int((image_w - self.tile_size_w) / self.stride_w) + 1

        # create an empty torch tensor for output
        tiles = torch.zeros(
            (self.num_patches_h, self.num_patches_w, batch, channels, self.tile_size_h, self.tile_size_w),
            device=device,
        )

        # fill-in output tensor with spatial patches extracted from the image
        for (tile_i, tile_j), (loc_i, loc_j) in zip(
            product(range(self.num_patches_h), range(self.num_patches_w)),
            product(
                range(0, image_h - self.tile_size_h + 1, self.stride_h),
                range(0, image_w - self.tile_size_w + 1, self.stride_w),
            ),
        ):
            tiles[tile_i, tile_j, :] = tensor[
                :, :, loc_i : (loc_i + self.tile_size_h), loc_j : (loc_j + self.tile_size_w)
            ]

        # rearrange the tiles in order [tile_count * batch, channels, tile_height, tile_width]
        tiles = tiles.permute(2, 0, 1, 3, 4, 5)
        tiles = tiles.contiguous().view(-1, channels, self.tile_size_h, self.tile_size_w)

        return tiles
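
    # An equivalent formulation (an illustrative sketch, not part of the original
    # module) replaces the explicit loop in ``__unfold`` with PyTorch's built-in
    # ``Tensor.unfold``:
    #
    #     tiles = tensor.unfold(2, self.tile_size_h, self.stride_h).unfold(3, self.tile_size_w, self.stride_w)
    #     tiles = tiles.permute(0, 2, 3, 1, 4, 5).contiguous().view(-1, channels, self.tile_size_h, self.tile_size_w)
    #
    # Both orderings yield [batch * num_patches, channels, tile_h, tile_w].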

    def __fold(self, tiles: Tensor) -> Tensor:
        """Fold the tiles back into the original tensor.

        This is the core method to reconstruct the original image from its tiled version.

        Args:
            tiles: Tiles from the input image, generated via __unfold method.

        Returns:
            Output that is the reconstructed version of the input tensor.
        """
        # number of channels differs between image and anomaly map, so infer from input tiles.
        _, num_channels, tile_size_h, tile_size_w = tiles.shape
        scale_h, scale_w = (tile_size_h / self.tile_size_h), (tile_size_w / self.tile_size_w)

        # identify device type based on input tensor
        device = tiles.device

        # calculate tile size after borders removed
        reduced_tile_h = tile_size_h - (2 * self.remove_border_count)
        reduced_tile_w = tile_size_w - (2 * self.remove_border_count)

        # reconstructed image dimension
        image_size = (self.batch_size, num_channels, int(self.resized_h * scale_h), int(self.resized_w * scale_w))

        # rearrange input tiles in format [tile_count, batch, channel, tile_h, tile_w]
        tiles = tiles.contiguous().view(
            self.batch_size,
            self.num_patches_h,
            self.num_patches_w,
            num_channels,
            tile_size_h,
            tile_size_w,
        )
        tiles = tiles.permute(0, 3, 1, 2, 4, 5)
        tiles = tiles.contiguous().view(self.batch_size, num_channels, -1, tile_size_h, tile_size_w)
        tiles = tiles.permute(2, 0, 1, 3, 4)

        # remove tile borders by defined count
        tiles = tiles[
            :,
            :,
            :,
            self.remove_border_count : reduced_tile_h + self.remove_border_count,
            self.remove_border_count : reduced_tile_w + self.remove_border_count,
        ]

        # create tensors to store intermediate results and outputs
        img = torch.zeros(image_size, device=device)
        lookup = torch.zeros(image_size, device=device)
        ones = torch.ones(reduced_tile_h, reduced_tile_w, device=device)

        # reconstruct image by adding patches to their respective location and
        # create a lookup for patch count in every location
        for patch, (loc_i, loc_j) in zip(
            tiles,
            product(
                range(
                    self.remove_border_count,
                    int(self.resized_h * scale_h) - reduced_tile_h + 1,
                    int(self.stride_h * scale_h),
                ),
                range(
                    self.remove_border_count,
                    int(self.resized_w * scale_w) - reduced_tile_w + 1,
                    int(self.stride_w * scale_w),
                ),
            ),
        ):
            img[:, :, loc_i : (loc_i + reduced_tile_h), loc_j : (loc_j + reduced_tile_w)] += patch
            lookup[:, :, loc_i : (loc_i + reduced_tile_h), loc_j : (loc_j + reduced_tile_w)] += ones

        # divide the reconstructed image by the lookup to average out the values
        img = torch.divide(img, lookup)
        # alternative way of removing nan values (isnan not supported by openvino)
        img[img != img] = 0  # pylint: disable=comparison-with-itself

        return img
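
    # A worked illustration of the averaging above (an addition, not part of the
    # original module): with tile_size=2 and stride=1 on a length-3 edge, the middle
    # pixel is covered by two tiles, so ``lookup`` counts 2 there and the division
    # averages both contributions. Locations never covered (possible when
    # ``remove_border_count`` > 0) divide by zero, and the ``img != img`` line
    # zeroes the resulting NaNs.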

    def tile(self, image: Tensor, use_random_tiling: Optional[bool] = False) -> Tensor:
        """Tiles an input image to either overlapping, non-overlapping or random patches.

        Args:
            image: Input image to tile.
            use_random_tiling: If True, randomly crop ``tile_count`` tiles instead of
                unfolding the full grid. Defaults to False.

        Examples:
            >>> from anomalib.data.tiler import Tiler
            >>> tiler = Tiler(tile_size=512, stride=256)
            >>> image = torch.rand(size=(2, 3, 1024, 1024))
            >>> image.shape
            torch.Size([2, 3, 1024, 1024])
            >>> tiles = tiler.tile(image)
            >>> tiles.shape
            torch.Size([18, 3, 512, 512])

        Returns:
            Tiles generated from the image.
        """
        if image.dim() == 3:
            image = image.unsqueeze(0)

        self.batch_size, self.num_channels, self.input_h, self.input_w = image.shape

        if self.input_h < self.tile_size_h or self.input_w < self.tile_size_w:
            raise ValueError(
                f"One of the edges of the tile size {self.tile_size_h, self.tile_size_w} "
                f"is larger than that of the image {self.input_h, self.input_w}."
            )

        self.resized_h, self.resized_w = compute_new_image_size(
            image_size=(self.input_h, self.input_w),
            tile_size=(self.tile_size_h, self.tile_size_w),
            stride=(self.stride_h, self.stride_w),
        )

        image = upscale_image(image, size=(self.resized_h, self.resized_w), mode=self.mode)

        if use_random_tiling:
            image_tiles = self.__random_tile(image)
        else:
            image_tiles = self.__unfold(image)

        return image_tiles

    def untile(self, tiles: Tensor) -> Tensor:
        """Untiles patches to reconstruct the original input image.

        If the tiles are overlapping, the function averages the overlapping pixels
        and returns the reconstructed image.

        Args:
            tiles: Tiles from the input image, generated via tile().

        Examples:
            >>> from anomalib.data.tiler import Tiler
            >>> tiler = Tiler(tile_size=512, stride=256)
            >>> image = torch.rand(size=(2, 3, 1024, 1024))
            >>> image.shape
            torch.Size([2, 3, 1024, 1024])
            >>> tiles = tiler.tile(image)
            >>> tiles.shape
            torch.Size([18, 3, 512, 512])
            >>> reconstructed_image = tiler.untile(tiles)
            >>> reconstructed_image.shape
            torch.Size([2, 3, 1024, 1024])
            >>> torch.equal(image, reconstructed_image)
            True

        Returns:
            Output that is the reconstructed version of the input tensor.
        """
        image = self.__fold(tiles)
        image = downscale_image(image=image, size=(self.input_h, self.input_w), mode=self.mode)

        return image
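

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original module):
    # tile a random batch into overlapping patches and verify the round trip.
    tiler = Tiler(tile_size=256, stride=128)
    dummy_batch = torch.rand(2, 3, 512, 512)

    dummy_tiles = tiler.tile(dummy_batch)  # (512 - 256) / 128 + 1 = 3 patches per edge
    print(dummy_tiles.shape)  # torch.Size([18, 3, 256, 256])

    restored = tiler.untile(dummy_tiles)
    print(restored.shape)  # torch.Size([2, 3, 512, 512])

    # overlapping regions are averaged, so identical contributions reproduce the input
    print(torch.allclose(dummy_batch, restored))  # True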