| """ |
| Utility functions for the DIE demo. |
| """ |
|
|
|
|
| import torch |
| from PIL import Image |
| from torch import Tensor |
| from torchvision import transforms |
|
|
|
|
def resize_image(
        image: Image.Image,
        max_size: int = 1024
) -> Image.Image:
    """
    Resize an image so that its larger side equals ``max_size``, keeping the aspect ratio.

    The shorter side is scaled by the same factor and rounded to the nearest pixel.
    It is clamped to at least one pixel: for extremely elongated images the rounded
    value would otherwise be 0, and ``Image.resize`` rejects a zero-sized target.
    :param image: PIL image
    :param max_size: size of the larger side of the new image, in pixels (must be >= 1)
    :return: the resized PIL image
    :raises ValueError: if ``max_size`` is smaller than 1
    """
    if max_size < 1:
        raise ValueError(f"max_size must be a positive integer, got {max_size}.")

    width, height = image.size

    # Ties (square images) take the "height is larger" branch, so both sides become max_size.
    if height >= width:
        height_new = max_size
        width_new = max(1, round((height_new / height) * width))
    else:
        width_new = max_size
        height_new = max(1, round((width_new / width) * height))

    return image.resize((width_new, height_new))
|
|
|
|
def make_image_square(
        image: Image.Image,
        image_size: int = 1024
) -> Image.Image:
    """
    Pad an image onto a white square canvas.

    The input is pasted at the top-left corner, so the white padding ends up on the
    right and/or bottom. The canvas side is ``image_size``, unless the image's longer
    side exceeds it, in which case the longer side is used instead.
    :param image: PIL image (mode 'L' or 'RGB')
    :param image_size: requested side length of the square image
    :return: the square-sized PIL image
    :raises NotImplementedError: if the image mode is neither 'L' nor 'RGB'
    """
    # White fill value for each supported mode.
    white_by_mode = {
        'L': (255,),
        'RGB': (255, 255, 255),
    }

    # Never shrink the canvas below the image's longer side.
    side = max(image_size, *image.size)

    mode = image.mode
    if mode not in white_by_mode:
        raise NotImplementedError("Not implemented image mode.")

    canvas = Image.new(mode, (side, side), white_by_mode[mode])
    canvas.paste(image, (0, 0))
    return canvas
|
|
|
|
def cast_pil_image_to_torch_tensor_with_4_channel_dim(
        image: Image.Image,
        device: str | None = None
) -> Tensor:
    """
    Convert a PIL image into a 4-channel float32 torch tensor.

    The first three channels are the RGB version of the image and the fourth is its
    grayscale ('L') version, concatenated along dimension 0. Pixel values are divided
    by 255.
    :param image: input image
    :param device: target device for the result; the tensor is not moved when None
    :return: torch tensor (4 channel dim)
    """
    pil_to_tensor = transforms.PILToTensor()

    def to_unit_float(pil_image: Image.Image) -> Tensor:
        # Cast to float32 first so the division by 255 is a floating-point division.
        return pil_to_tensor(pil_image).to(torch.float32) / 255.0

    # The grayscale channel is derived from the input image itself, not from its RGB copy.
    gray_image = image.convert('L')

    rgb_channels = to_unit_float(image.convert('RGB'))
    gray_channel = to_unit_float(gray_image)

    stacked = torch.cat((rgb_channels, gray_channel), dim=0)

    if device is None:
        return stacked
    return stacked.to(device)
|
|
|
|
def _height_and_width(image: Image.Image | Tensor) -> tuple[int, int]:
    """Return ``(height, width)`` of a PIL image, or the first two dims of a tensor."""
    if isinstance(image, Image.Image):
        width, height = image.size
    else:
        # Non-PIL inputs are read as (height, width, ...) - TODO confirm against callers.
        height, width = image.shape[:2]
    return height, width


def _resize_height_width_tensor(tensor: Tensor, height: int, width: int) -> Tensor:
    """Bilinearly resize an (H, W) or (H, W, C) tensor to ``(height, width)``."""
    if tensor.dim() not in (2, 3):
        raise ValueError("Only (H, W) and (H, W, C) tensors can be resized back to the original size.")

    has_channels = tensor.dim() == 3
    channels_last = tensor if has_channels else tensor.unsqueeze(-1)

    # interpolate() expects a (N, C, H, W) batch.
    batch = channels_last.permute(2, 0, 1).unsqueeze(0)

    # Bilinear interpolation needs a floating dtype; integer inputs are converted and restored.
    source_dtype = batch.dtype
    if not batch.is_floating_point():
        batch = batch.to(torch.float32)

    resized = torch.nn.functional.interpolate(
        batch, size=(height, width), mode="bilinear", align_corners=False
    )
    if resized.dtype != source_dtype:
        resized = resized.round().to(source_dtype)

    resized = resized.squeeze(0).permute(1, 2, 0)
    return resized if has_channels else resized.squeeze(-1)


def remove_square_padding(
        original_image: Image.Image | Tensor,
        square_image: Image.Image | Tensor,
        resize_back_to_original: bool = False
) -> Image.Image | Tensor:
    """
    Remove the padding that was added to an image to make it square.

    The content is taken from the top-left corner of the square image. Its larger side
    spans the full square, and its shorter side is scaled by the same factor and
    rounded. The rounding matches ``resize_image``, so the crop has exactly the size of
    the resized image instead of being truncated by one pixel.

    Tensor inputs are read as (height, width, ...), i.e. their first two dimensions.
    :param original_image: the image with the original size
    :param square_image: the image with the square size
    :param resize_back_to_original: defines if we want to resize the square image back to the original size
    :return: square image with the original size ratio (and the original size, if requested)
    :raises ValueError: if a tensor that is neither (H, W) nor (H, W, C) has to be resized back
    """
    original_height, original_width = _height_and_width(original_image)
    square_height, square_width = _height_and_width(square_image)

    # The shorter side is clamped to one pixel so the crop is never empty.
    if original_width > original_height:
        new_width = square_width
        new_height = max(1, round((square_width / original_width) * original_height))
    else:
        new_height = square_height
        new_width = max(1, round((square_height / original_height) * original_width))

    if isinstance(square_image, Image.Image):
        cropped = square_image.crop((0, 0, new_width, new_height))
    else:
        cropped = square_image[:new_height, :new_width]

    if resize_back_to_original:
        if isinstance(cropped, Tensor):
            # Tensor.resize is a deprecated reshape, not an image resize, so interpolate instead.
            cropped = _resize_height_width_tensor(cropped, original_height, original_width)
        else:
            cropped = cropped.resize((original_width, original_height))

    return cropped
|
|