# DVD/utils/image_utils.py
# (file-viewer residue preserved: haodongli, commit "init-1", 4b35c4e)
from PIL import Image
import matplotlib
import numpy as np
from typing import List, Union
import PIL.Image
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
def concatenate_images(*image_lists):
    """Concatenate rows of PIL images into one image.

    Each positional argument is a list of images forming one row; rows are
    stacked top-to-bottom, and images within a row are placed side by side.
    All images in a row are assumed to share the first image's height.

    Args:
        *image_lists: One or more lists of `PIL.Image.Image`.

    Returns:
        `PIL.Image.Image`: The combined RGB image.

    Raises:
        ValueError: If no lists are given or the first list is empty.
    """
    if not image_lists or not image_lists[0]:
        raise ValueError("At least one non-empty image list must be provided")
    max_width = 0
    total_height = 0
    row_heights = []
    # Compute per-row dimensions. Empty rows contribute zero height so that
    # row_heights stays index-aligned with image_lists; the original skipped
    # them, causing an IndexError in the paste loop below whenever a
    # non-first row was empty.
    for image_list in image_lists:
        if image_list:
            width = sum(img.width for img in image_list)
            height = image_list[0].height  # assumes uniform height per row
        else:
            width = 0
            height = 0
        max_width = max(max_width, width)
        total_height += height
        row_heights.append(height)
    new_image = Image.new('RGB', (max_width, total_height))
    # Paste each row at its running vertical offset.
    y_offset = 0
    for i, image_list in enumerate(image_lists):
        x_offset = 0
        for img in image_list:
            new_image.paste(img, (x_offset, y_offset))
            x_offset += img.width
        y_offset += row_heights[i]  # move down to the next row
    return new_image
def colorize_depth_map(depth, mask=None, reverse_color=False):
    """Colorize a single-channel depth map using matplotlib's 'Spectral' map.

    Args:
        depth: 2-D depth values ([H, W]); a `torch.Tensor` is converted to
            numpy. Values are min-max normalized to [0, 1] before colorizing.
        mask: Optional boolean mask ([H, W]); pixels outside the mask are
            rendered black. Accepts a `torch.Tensor` or array-like.
        reverse_color: If True, invert the normalized depth before applying
            the colormap.

    Returns:
        `PIL.Image.Image`: RGB visualization of the depth map.
    """
    cm = matplotlib.colormaps["Spectral"]
    if isinstance(depth, torch.Tensor):
        depth = depth.detach().cpu().numpy()
    depth = np.asarray(depth)
    # Min-max normalize; guard against a constant map, which would otherwise
    # divide by zero and fill the image with NaNs.
    d_min, d_max = depth.min(), depth.max()
    if d_max > d_min:
        depth = (depth - d_min) / (d_max - d_min)
    else:
        depth = np.zeros_like(depth, dtype=float)
    # colorize (drop the alpha channel -> (h, w, 3))
    if reverse_color:
        img_colored_np = cm(1 - depth, bytes=False)[:, :, 0:3]
    else:
        img_colored_np = cm(depth, bytes=False)[:, :, 0:3]
    depth_colored = (img_colored_np * 255).astype(np.uint8)
    if mask is not None:
        # Accept torch or numpy masks; the original required a CPU tensor.
        if isinstance(mask, torch.Tensor):
            mask_np = mask.detach().cpu().numpy()
        else:
            mask_np = np.asarray(mask)
        masked_image = np.zeros_like(depth_colored)
        masked_image[mask_np] = depth_colored[mask_np]
        return Image.fromarray(masked_image)
    return Image.fromarray(depth_colored)
def resize_max_res(
    img: torch.Tensor,
    max_edge_resolution: int,
    resample_method: InterpolationMode = InterpolationMode.BILINEAR,
) -> torch.Tensor:
    """
    Downscale a batched image tensor so its longest edge fits a limit.

    A single scale factor is applied to both spatial dimensions, preserving
    the aspect ratio while keeping H and W at or below the limit.

    Args:
        img (`torch.Tensor`):
            Image tensor to be resized. Expected shape: [B, C, H, W]
        max_edge_resolution (`int`):
            Maximum edge length (pixel).
        resample_method (`InterpolationMode`):
            Resampling method used to resize images.
    Returns:
        `torch.Tensor`: Resized image.
    """
    assert img.dim() == 4, f"Invalid input shape {img.shape}"
    height, width = img.shape[-2], img.shape[-1]
    # One shared factor keeps the aspect ratio intact.
    scale = min(max_edge_resolution / width, max_edge_resolution / height)
    target_hw = (int(height * scale), int(width * scale))
    return resize(img, target_hw, resample_method, antialias=True)
def resize_back(
    img: Union[torch.Tensor, np.ndarray, PIL.Image.Image, List[PIL.Image.Image]],
    target_size: Union[int, tuple[int, int]],
    resample_method: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> Union[torch.Tensor, np.ndarray, PIL.Image.Image, List[PIL.Image.Image]]:
    """
    Resize image(s) to a target (height, width).

    Args:
        img: Image to resize. Supported forms:
            - `torch.Tensor` of shape [B, C, H, W]
            - `np.ndarray` of shape [B, H, W, C]
            - a single `PIL.Image.Image`
            - a list of `PIL.Image.Image`
        target_size: Target size as (height, width). An int is only valid
            for the tensor/array branches (torchvision semantics); the PIL
            branches index it as a pair.
        resample_method: Resampling method. NOTE(review): the PIL branches
            need a PIL resampling int (see `get_pil_resample_method`) — the
            `InterpolationMode` default is only valid for the tensor/array
            branches; confirm callers pass the right kind.

    Returns:
        The resized image, in the same form as the input.

    Raises:
        TypeError: If `img` is not one of the supported types (the original
            raised an `UnboundLocalError` here instead).
    """
    if isinstance(img, torch.Tensor):  # [B, C, H, W]
        resized_img = resize(img, target_size, resample_method, antialias=True)
    elif isinstance(img, np.ndarray):  # [B, H, W, C]
        # Convert to torch [B, C, H, W], resize, then convert back.
        img_t = torch.tensor(img).permute(0, 3, 1, 2)
        resized_t = resize(img_t, target_size, resample_method, antialias=True)
        resized_img = resized_t.permute(0, 2, 3, 1).numpy()
    elif isinstance(img, PIL.Image.Image):
        pil_size = (target_size[1], target_size[0])  # PIL uses (width, height)
        resized_img = img.resize(pil_size, resample_method)
    elif isinstance(img, list) and all(isinstance(i, PIL.Image.Image) for i in img):
        pil_size = (target_size[1], target_size[0])  # PIL uses (width, height)
        resized_img = [i.resize(pil_size, resample_method) for i in img]
    else:
        raise TypeError(f"Unsupported image type: {type(img)}")
    return resized_img
def get_pil_resample_method(method_str: str) -> int:
    """Map a resampling-method name to its PIL resampling constant.

    Args:
        method_str (`str`): One of "bilinear", "bicubic", "nearest".

    Returns:
        `int`: The corresponding PIL resampling filter.

    Raises:
        ValueError: If `method_str` is not a known method name.
    """
    resample_method_dict = {
        "bilinear": Image.BILINEAR,
        "bicubic": Image.BICUBIC,
        "nearest": Image.NEAREST,
    }
    resample_method = resample_method_dict.get(method_str, None)
    if resample_method is None:
        # Report the unknown *name*; the original interpolated the looked-up
        # value, which is always None on this path.
        raise ValueError(f"Unknown resampling method: {method_str}")
    return resample_method
def get_tv_resample_method(method_str: str) -> InterpolationMode:
    """Map a resampling-method name to a torchvision `InterpolationMode`.

    Note that both "nearest" and "nearest-exact" resolve to
    `InterpolationMode.NEAREST_EXACT`.

    Args:
        method_str (`str`): One of "bilinear", "bicubic", "nearest",
            "nearest-exact".

    Returns:
        `InterpolationMode`: The corresponding interpolation mode.

    Raises:
        ValueError: If `method_str` is not a known method name.
    """
    resample_method_dict = {
        "bilinear": InterpolationMode.BILINEAR,
        "bicubic": InterpolationMode.BICUBIC,
        "nearest": InterpolationMode.NEAREST_EXACT,
        "nearest-exact": InterpolationMode.NEAREST_EXACT,
    }
    resample_method = resample_method_dict.get(method_str, None)
    if resample_method is None:
        # Report the unknown *name*; the original interpolated the looked-up
        # value, which is always None on this path.
        raise ValueError(f"Unknown resampling method: {method_str}")
    return resample_method